CF1928F 题解

· · 题解

首先题目中的条件是骗人的。单独对于一个格子 (i,j) 考虑:

因此一个以 (i,j) 为左上角,边长为 d 的正方形合法,当且仅当:

可以发现 a,b 是独立的,所以若要计算方案数,可以单独计算出 a,b 中长度为 d 的合法子串数量再相乘。问题就变成了维护这个东西。

然后真正恶心的部分来了。来考虑 a 中一个长度为 x 的极长合法子串与 b 中一个长度为 y 的极长合法子串,他们能组成多少个合法正方形。不妨 x\le y,则有:

\begin{aligned} f(x,y)&=\sum^x_{i=1}(x-i+1)(y-i+1)\\ &=\sum^x_{i=1}i(y-x+i)\\ &=\sum^x_{i=1}i^2+i(y-x)\\ &=\frac{x(x+1)(2x+1)}6+(y-x)\times\frac{x(x+1)}2\\ &=\frac{x(x+1)(2x+1)}6-\frac{x^2(x+1)}2+\frac{x(x+1)y}2\\ \end{aligned}

a 中的极长连续段长度为 x_1,x_2,\cdots,x_pb 中的是 y_1,y_2,\cdots,y_q。再分别设 u_i=i,v_i=\frac{x(x+1)}2,w_i=\frac{i(i+1)(2i+1)}6-\frac{i^2(i+1)}2。则我们要求的就是:

\begin{aligned} &\sum_i\sum_jf(x_i,y_j)\\ =&\sum_i\sum_j[x_i\le y_j](w_{x_i}+v_{x_i}u_{y_j})+[x_i>y_j](w_{y_j}+v_{y_j}u_{x_i})\\ =&\sum_i\left(w_{x_i}\sum_j[x_i\le y_j]+v_{x_i}\sum_j[x_i\le y_j]u_{y_j}+\sum_j[x_i>y_j]w_{y_j}+u_{x_i}\sum_j[x_i>y_j]v_{y_j}\right)\\ =&\sum_i\left(w_{x_i}\sum_{y_j\ge x_i}1+v_{x_i}\sum_{y_j\ge x_i}u_{y_j}+\sum_{y_j<x_i}w_{y_j}+u_{x_i}\sum_{y_j<x_i}v_{y_j}\right)\\ \end{aligned}

由于每个区间加只会造成 O(1) 个区间的加入与删除,所以第一维求和可以直接暴力求和维护。第二维则是很多个前缀求和与后缀求和,可以简单树状数组解决。啊但是这时候改的如果是 b,又变成很麻烦的样子了。所以我们还得推改 b 时的贡献:

\begin{aligned} &\sum_j\sum_if(x_i,y_j)\\ =&\sum_j\sum_i[x_i\le y_j](w_{x_i}+v_{x_i}u_{y_j})+[x_i>y_j](w_{y_j}+v_{y_j}u_{x_i})\\ =&\sum_j\left(\sum_i[x_i\le y_j]w_{x_i}+u_{y_j}\sum_i[x_i\le y_j]v_{x_i}+w_{y_j}\sum_i[x_i>y_j]1+v_{y_j}\sum_i[x_i>y_j]u_{x_i}\right)\\ =&\sum_j\left(\sum_{x_i\le y_j}w_{x_i}+u_{y_j}\sum_{x_i\le y_j}v_{x_i}+w_{y_j}\sum_{x_i>y_j}1+v_{y_j}\sum_{x_i>y_j}u_{x_i}\right)\\ =&\sum_j\left(w_{y_j}\sum_{x_i\ge y_j}1+v_{y_j}\sum_{x_i\ge y_j}u_{x_i}+\sum_{x_i<y_j}w_{x_i}+u_{y_j}\sum_{x_i<y_j}v_{x_i}\right)\\ \end{aligned}

那这两部分实际上完全一样,可以节省些码量。维护区间可以用 set 暴力查找,时间复杂度为 O(n\log n)

AC 代码

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MAXN = 3e5 + 10;

inline ll u(ll n) { return n; }
inline ll v(ll n) { return n * (n + 1) / 2; }
inline ll w(ll n) { return n * (n + 1) * (2 * n + 1) / 6 - n * n * (n + 1) / 2; }

struct bit {

    int n; vector<ll> c;

    bit(int n) : n(n), c(n + 1) {}

    inline void add(int k, ll x) {
        if (k <= 0) return ;
        for (int i = k; i <= n; i += i & -i) c[i] += x;
    }

    inline ll ask(int k) {
        if (k <= 0) return 0; if (k > n) k = n; ll res = 0;
        for (int i = k; i; i &= i - 1) res += c[i];
        return res;
    }

};

struct maintainer {

    int n; bit a, b, c, d;

    maintainer(int n) : n(n), a(n), b(n), c(n), d(n) {}

    inline 
    void insert(int k) {
        a.add(n - k + 1, 1), b.add(n - k + 1, u(k));
        c.add(k, v(k)), d.add(k, w(k));
    }

    inline 
    void erase(int k) {
        a.add(n - k + 1, -1), b.add(n - k + 1, -u(k));
        c.add(k, -v(k)), d.add(k, -w(k));
    }

    inline 
    ll query(int m) {
        return w(m) * a.ask(n - m + 1) + v(m) * b.ask(n - m + 1) 
        + u(m) * c.ask(m - 1) + d.ask(m - 1);
    }

};

struct node {

    maintainer &a, &b; ll ans;
    vector<int> d; set<int> s;

    inline void insert(int k) { a.insert(k), ans += b.query(k); }
    inline void erase(int k) { a.erase(k), ans -= b.query(k); }

    node(vector<int> t, maintainer &a, maintainer &b) : s(), d(t.size()), a(a), b(b), ans() {
        s.insert(0), s.insert(t.size());
        for (int i = 1; i < t.size(); i++) {
            d[i] = t[i] - t[i - 1];
            if (!d[i]) s.insert(i);
        }
        for (auto it = s.begin(); next(it) != s.end(); ++it) insert(*next(it) - *it);
    }

    inline 
    void change(int k, ll x) {
        if (!k || k == a.n) return ;
        if (!d[k] && d[k] + x) {
            auto it = s.find(k), pre = prev(it), nxt = next(it);
            erase(*it - *pre), erase(*nxt - *it), insert(*nxt - *pre);
            s.erase(it);
        }
        if (d[k] && !(d[k] + x)) {
            auto it = s.insert(k).first, pre = prev(it), nxt = next(it);
            erase(*nxt - *pre), insert(*it - *pre), insert(*nxt - *it);
        }
        d[k] += x;
    }

    inline void add(int l, int r, ll x) { change(l - 1, x), change(r, -x); }

};

int n, m, q;

int main() {
    scanf("%d%d%d", &n, &m, &q);
    vector<int> a(n), b(m);
    for (int &x : a) scanf("%d", &x);
    for (int &x : b) scanf("%d", &x);
    maintainer ta(n), tb(m);
    node pa(a, ta, tb), pb(b, tb, ta);
    printf("%lld\n", pa.ans + pb.ans);
    for (int t, l, r, x; q--;) {
        scanf("%d%d%d%d", &t, &l, &r, &x);
        t == 1 ? pa.add(l, r, x) : pb.add(l, r, x);
        printf("%lld\n", pa.ans + pb.ans);
    }
}