CF1928F 题解
Register_int
2024-02-12 15:17:30
首先题目中的条件是骗人的。单独对于一个格子 $(i,j)$ 考虑:
- $a_i+b_j\not=a_{i+1}+b_j$,也即 $a_i\not=a_{i+1}$。
- $a_i+b_j\not=a_i+b_{j+1}$,也即 $b_i\not=b_{i+1}$。
因此一个以 $(i,j)$ 为左上角,边长为 $d$ 的正方形合法,当且仅当:
- 对于 $x=i\sim i+d-2,y=j\sim j+d-2$,有 $a_x\not=a_{x+1}$ 且 $b_y\not=b_{y+1}$。
可以发现 $a,b$ 是独立的,所以若要计算方案数,可以单独计算出 $a,b$ 中长度为 $d$ 的合法子串数量再相乘。问题就变成了维护这个东西。
然后真正恶心的部分来了。来考虑 $a$ 中一个长度为 $x$ 的极长合法子串与 $b$ 中一个长度为 $y$ 的极长合法子串,他们能组成多少个合法正方形。不妨 $x\le y$,则有:
$$
\begin{aligned}
f(x,y)&=\sum^x_{i=1}(x-i+1)(y-i+1)\\
&=\sum^x_{i=1}i(y-x+i)\\
&=\sum^x_{i=1}i^2+i(y-x)\\
&=\frac{x(x+1)(2x+1)}6+(y-x)\times\frac{x(x+1)}2\\
&=\frac{x(x+1)(2x+1)}6-\frac{x^2(x+1)}2+\frac{x(x+1)y}2\\
\end{aligned}
$$
设 $a$ 中的极长连续段长度为 $x_1,x_2,\cdots,x_p$,$b$ 中的是 $y_1,y_2,\cdots,y_q$。再分别设 $u_i=i,v_i=\frac{x(x+1)}2,w_i=\frac{i(i+1)(2i+1)}6-\frac{i^2(i+1)}2$。则我们要求的就是:
$$
\begin{aligned}
&\sum_i\sum_jf(x_i,y_j)\\
=&\sum_i\sum_j[x_i\le y_j](w_{x_i}+v_{x_i}u_{y_j})+[x_i>y_j](w_{y_j}+v_{y_j}u_{x_i})\\
=&\sum_i\left(w_{x_i}\sum_j[x_i\le y_j]+v_{x_i}\sum_j[x_i\le y_j]u_{y_j}+\sum_j[x_i>y_j]w_{y_j}+u_{x_i}\sum_j[x_i>y_j]v_{y_j}\right)\\
=&\sum_i\left(w_{x_i}\sum_{y_j\ge x_i}1+v_{x_i}\sum_{y_j\ge x_i}u_{y_j}+\sum_{y_j<x_i}w_{y_j}+u_{x_i}\sum_{y_j<x_i}v_{y_j}\right)\\
\end{aligned}
$$
由于每个区间加只会造成 $O(1)$ 个区间的加入与删除,所以第一维求和可以直接暴力求和维护。第二维则是很多个前缀求和与后缀求和,可以简单树状数组解决。啊但是这时候改的如果是 $b$,又变成很麻烦的样子了。所以我们还得推改 $b$ 时的贡献:
$$
\begin{aligned}
&\sum_j\sum_if(x_i,y_j)\\
=&\sum_j\sum_i[x_i\le y_j](w_{x_i}+v_{x_i}u_{y_j})+[x_i>y_j](w_{y_j}+v_{y_j}u_{x_i})\\
=&\sum_j\left(\sum_i[x_i\le y_j]w_{x_i}+u_{y_j}\sum_i[x_i\le y_j]v_{x_i}+w_{y_j}\sum_i[x_i>y_j]1+v_{y_j}\sum_i[x_i>y_j]u_{x_i}\right)\\
=&\sum_j\left(\sum_{x_i\le y_j}w_{x_i}+u_{y_j}\sum_{x_i\le y_j}v_{x_i}+w_{y_j}\sum_{x_i>y_j}1+v_{y_j}\sum_{x_i>y_j}u_{x_i}\right)\\
=&\sum_j\left(w_{y_j}\sum_{x_i\ge y_j}1+v_{y_j}\sum_{x_i\ge y_j}u_{x_i}+\sum_{x_i<y_j}w_{x_i}+u_{y_j}\sum_{x_i<y_j}v_{x_i}\right)\\
\end{aligned}
$$
那这两部分实际上完全一样,可以节省些码量。维护区间可以用 set 暴力查找,时间复杂度为 $O(n\log n)$。
# AC 代码
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 3e5 + 10;
inline ll u(ll n) { return n; }
inline ll v(ll n) { return n * (n + 1) / 2; }
inline ll w(ll n) { return n * (n + 1) * (2 * n + 1) / 6 - n * n * (n + 1) / 2; }
struct bit {
int n; vector<ll> c;
bit(int n) : n(n), c(n + 1) {}
inline void add(int k, ll x) {
if (k <= 0) return ;
for (int i = k; i <= n; i += i & -i) c[i] += x;
}
inline ll ask(int k) {
if (k <= 0) return 0; if (k > n) k = n; ll res = 0;
for (int i = k; i; i &= i - 1) res += c[i];
return res;
}
};
struct maintainer {
int n; bit a, b, c, d;
maintainer(int n) : n(n), a(n), b(n), c(n), d(n) {}
inline
void insert(int k) {
a.add(n - k + 1, 1), b.add(n - k + 1, u(k));
c.add(k, v(k)), d.add(k, w(k));
}
inline
void erase(int k) {
a.add(n - k + 1, -1), b.add(n - k + 1, -u(k));
c.add(k, -v(k)), d.add(k, -w(k));
}
inline
ll query(int m) {
return w(m) * a.ask(n - m + 1) + v(m) * b.ask(n - m + 1)
+ u(m) * c.ask(m - 1) + d.ask(m - 1);
}
};
struct node {
maintainer &a, &b; ll ans;
vector<int> d; set<int> s;
inline void insert(int k) { a.insert(k), ans += b.query(k); }
inline void erase(int k) { a.erase(k), ans -= b.query(k); }
node(vector<int> t, maintainer &a, maintainer &b) : s(), d(t.size()), a(a), b(b), ans() {
s.insert(0), s.insert(t.size());
for (int i = 1; i < t.size(); i++) {
d[i] = t[i] - t[i - 1];
if (!d[i]) s.insert(i);
}
for (auto it = s.begin(); next(it) != s.end(); ++it) insert(*next(it) - *it);
}
inline
void change(int k, ll x) {
if (!k || k == a.n) return ;
if (!d[k] && d[k] + x) {
auto it = s.find(k), pre = prev(it), nxt = next(it);
erase(*it - *pre), erase(*nxt - *it), insert(*nxt - *pre);
s.erase(it);
}
if (d[k] && !(d[k] + x)) {
auto it = s.insert(k).first, pre = prev(it), nxt = next(it);
erase(*nxt - *pre), insert(*it - *pre), insert(*nxt - *it);
}
d[k] += x;
}
inline void add(int l, int r, ll x) { change(l - 1, x), change(r, -x); }
};
int n, m, q;
int main() {
scanf("%d%d%d", &n, &m, &q);
vector<int> a(n), b(m);
for (int &x : a) scanf("%d", &x);
for (int &x : b) scanf("%d", &x);
maintainer ta(n), tb(m);
node pa(a, ta, tb), pb(b, tb, ta);
printf("%lld\n", pa.ans + pb.ans);
for (int t, l, r, x; q--;) {
scanf("%d%d%d%d", &t, &l, &r, &x);
t == 1 ? pa.add(l, r, x) : pb.add(l, r, x);
printf("%lld\n", pa.ans + pb.ans);
}
}
```