CF1928F 题解

Register_int

2024-02-12 15:17:30

Solution

首先题目中的条件是骗人的。单独对于一个格子 $(i,j)$ 考虑: - $a_i+b_j\not=a_{i+1}+b_j$,也即 $a_i\not=a_{i+1}$。 - $a_i+b_j\not=a_i+b_{j+1}$,也即 $b_i\not=b_{i+1}$。 因此一个以 $(i,j)$ 为左上角,边长为 $d$ 的正方形合法,当且仅当: - 对于 $x=i\sim i+d-2,y=j\sim j+d-2$,有 $a_x\not=a_{x+1}$ 且 $b_y\not=b_{y+1}$。 可以发现 $a,b$ 是独立的,所以若要计算方案数,可以单独计算出 $a,b$ 中长度为 $d$ 的合法子串数量再相乘。问题就变成了维护这个东西。 然后真正恶心的部分来了。来考虑 $a$ 中一个长度为 $x$ 的极长合法子串与 $b$ 中一个长度为 $y$ 的极长合法子串,他们能组成多少个合法正方形。不妨 $x\le y$,则有: $$ \begin{aligned} f(x,y)&=\sum^x_{i=1}(x-i+1)(y-i+1)\\ &=\sum^x_{i=1}i(y-x+i)\\ &=\sum^x_{i=1}i^2+i(y-x)\\ &=\frac{x(x+1)(2x+1)}6+(y-x)\times\frac{x(x+1)}2\\ &=\frac{x(x+1)(2x+1)}6-\frac{x^2(x+1)}2+\frac{x(x+1)y}2\\ \end{aligned} $$ 设 $a$ 中的极长连续段长度为 $x_1,x_2,\cdots,x_p$,$b$ 中的是 $y_1,y_2,\cdots,y_q$。再分别设 $u_i=i,v_i=\frac{x(x+1)}2,w_i=\frac{i(i+1)(2i+1)}6-\frac{i^2(i+1)}2$。则我们要求的就是: $$ \begin{aligned} &\sum_i\sum_jf(x_i,y_j)\\ =&\sum_i\sum_j[x_i\le y_j](w_{x_i}+v_{x_i}u_{y_j})+[x_i>y_j](w_{y_j}+v_{y_j}u_{x_i})\\ =&\sum_i\left(w_{x_i}\sum_j[x_i\le y_j]+v_{x_i}\sum_j[x_i\le y_j]u_{y_j}+\sum_j[x_i>y_j]w_{y_j}+u_{x_i}\sum_j[x_i>y_j]v_{y_j}\right)\\ =&\sum_i\left(w_{x_i}\sum_{y_j\ge x_i}1+v_{x_i}\sum_{y_j\ge x_i}u_{y_j}+\sum_{y_j<x_i}w_{y_j}+u_{x_i}\sum_{y_j<x_i}v_{y_j}\right)\\ \end{aligned} $$ 由于每个区间加只会造成 $O(1)$ 个区间的加入与删除,所以第一维求和可以直接暴力求和维护。第二维则是很多个前缀求和与后缀求和,可以简单树状数组解决。啊但是这时候改的如果是 $b$,又变成很麻烦的样子了。所以我们还得推改 $b$ 时的贡献: $$ \begin{aligned} &\sum_j\sum_if(x_i,y_j)\\ =&\sum_j\sum_i[x_i\le y_j](w_{x_i}+v_{x_i}u_{y_j})+[x_i>y_j](w_{y_j}+v_{y_j}u_{x_i})\\ =&\sum_j\left(\sum_i[x_i\le y_j]w_{x_i}+u_{y_j}\sum_i[x_i\le y_j]v_{x_i}+w_{y_j}\sum_i[x_i>y_j]1+v_{y_j}\sum_i[x_i>y_j]u_{x_i}\right)\\ =&\sum_j\left(\sum_{x_i\le y_j}w_{x_i}+u_{y_j}\sum_{x_i\le y_j}v_{x_i}+w_{y_j}\sum_{x_i>y_j}1+v_{y_j}\sum_{x_i>y_j}u_{x_i}\right)\\ =&\sum_j\left(w_{y_j}\sum_{x_i\ge y_j}1+v_{y_j}\sum_{x_i\ge y_j}u_{x_i}+\sum_{x_i<y_j}w_{x_i}+u_{y_j}\sum_{x_i<y_j}v_{x_i}\right)\\ \end{aligned} $$ 那这两部分实际上完全一样,可以节省些码量。维护区间可以用 set 暴力查找,时间复杂度为 $O(n\log n)$。 # AC 代码 ```cpp #include <bits/stdc++.h> using namespace std; typedef long long ll; const int MAXN = 3e5 + 10; inline ll u(ll n) { return n; } inline ll v(ll n) { return n * (n + 1) / 2; } inline ll w(ll n) { return n * (n + 1) * (2 * n + 1) / 6 - n * n * (n + 1) / 2; } struct bit { int n; vector<ll> c; bit(int n) : n(n), c(n + 1) {} inline void add(int k, ll x) { if (k <= 0) return ; for (int i = k; i <= n; i += i & -i) c[i] += x; } inline ll ask(int k) { if (k <= 0) return 0; if (k > n) k = n; ll res = 0; for (int i = k; i; i &= i - 1) res += c[i]; return res; } }; struct maintainer { int n; bit a, b, c, d; maintainer(int n) : n(n), a(n), b(n), c(n), d(n) {} inline void insert(int k) { a.add(n - k + 1, 1), b.add(n - k + 1, u(k)); c.add(k, v(k)), d.add(k, w(k)); } inline void erase(int k) { a.add(n - k + 1, -1), b.add(n - k + 1, -u(k)); c.add(k, -v(k)), d.add(k, -w(k)); } inline ll query(int m) { return w(m) * a.ask(n - m + 1) + v(m) * b.ask(n - m + 1) + u(m) * c.ask(m - 1) + d.ask(m - 1); } }; struct node { maintainer &a, &b; ll ans; vector<int> d; set<int> s; inline void insert(int k) { a.insert(k), ans += b.query(k); } inline void erase(int k) { a.erase(k), ans -= b.query(k); } node(vector<int> t, maintainer &a, maintainer &b) : s(), d(t.size()), a(a), b(b), ans() { s.insert(0), s.insert(t.size()); for (int i = 1; i < t.size(); i++) { d[i] = t[i] - t[i - 1]; if (!d[i]) s.insert(i); } for (auto it = s.begin(); next(it) != s.end(); ++it) insert(*next(it) - *it); } inline void change(int k, ll x) { if (!k || k == a.n) return ; if (!d[k] && d[k] + x) { auto it = s.find(k), pre = prev(it), nxt = next(it); erase(*it - *pre), erase(*nxt - *it), insert(*nxt - *pre); s.erase(it); } if (d[k] && !(d[k] + x)) { auto it = s.insert(k).first, pre = prev(it), nxt = next(it); erase(*nxt - *pre), insert(*it - *pre), insert(*nxt - *it); } d[k] += x; } inline void add(int l, int r, ll x) { change(l - 1, x), change(r, -x); } }; int n, m, q; int main() { scanf("%d%d%d", &n, &m, &q); vector<int> a(n), b(m); for (int &x : a) scanf("%d", &x); for (int &x : b) scanf("%d", &x); maintainer ta(n), tb(m); node pa(a, ta, tb), pb(b, tb, ta); printf("%lld\n", pa.ans + pb.ans); for (int t, l, r, x; q--;) { scanf("%d%d%d%d", &t, &l, &r, &x); t == 1 ? pa.add(l, r, x) : pb.add(l, r, x); printf("%lld\n", pa.ans + pb.ans); } } ```