题解 CF1172E 【Nauuo and ODT】
ouuan
2019-06-08 13:56:06
对每种颜色分别考虑不含该颜色的简单路径条数。
令当前处理的颜色为 $c$,把颜色为 $c$ 的视为白色,不是 $c$ 的视为黑色,那么不含 $c$ 的路径条数就是每个黑联通块的大小的平方和,修改就是当颜色是 $c$ $\leftrightarrow$ 颜色不是 $c$ 时翻转一个点的颜色。所以,问题转化成了黑白两色的树,单点翻转颜色,维护黑联通块大小的平方和。这个转化后的问题可以用很多数据结构做,比如:~~lxl:top tree 随便维护~~。这篇题解里使用 Link/cut Tree.
对每个点维护子树大小,儿子大小平方和,在 link/cut 的时候更新答案即可。有一个~~大家熟知的~~ trick,就是每个黑点向父亲连边,这样真正的联通块就是 Link/cut Tree 里的联通块删掉根。
具体看图吧,图讲的挺清楚的:
![](https://i.loli.net/2019/06/08/5cfb4d7d77b8678330.png)
![](https://i.loli.net/2019/06/08/5cfb4d7cda7be15979.png)
![](https://i.loli.net/2019/06/08/5cfb4d7d71d2055170.png)
LCT 的部分就是这样,计算答案的时候先初始化一个全是黑点的树,离线处理每个颜色,处理完一个颜色反着改回去。
```cpp
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;
const int N = 400010;
struct Node
{
int fa, ch[2], siz, sizi;
ll siz2i;
ll siz2() { return (ll) siz * siz; }
} t[N];
bool nroot(int x);
void rotate(int x);
void Splay(int x);
void access(int x);
int findroot(int x);
void link(int x);
void cut(int x);
void pushup(int x);
void add(int u, int v);
void dfs(int u);
int head[N], nxt[N << 1], to[N << 1], cnt;
int n, m, c[N], f[N];
ll ans, delta[N];
bool bw[N];
vector<int> mod[N][2];
int main()
{
int i, j, u, v;
ll last;
scanf("%d%d", &n, &m);
for (i = 1; i <= n; ++i)
{
scanf("%d", c + i);
mod[c[i]][0].push_back(i);
mod[c[i]][1].push_back(0);
}
for (i = 1; i <= n + 1; ++i) t[i].siz = 1;
for (i = 1; i < n; ++i)
{
scanf("%d%d", &u, &v);
add(u, v);
add(v, u);
}
for (i = 1; i <= m; ++i)
{
scanf("%d%d", &u, &v);
mod[c[u]][0].push_back(u);
mod[c[u]][1].push_back(i);
c[u] = v;
mod[v][0].push_back(u);
mod[v][1].push_back(i);
}
f[1] = n + 1;
dfs(1);
for (i = 1; i <= n; ++i) link(i);
for (i = 1; i <= n; ++i)
{
if (!mod[i][0].size())
{
delta[0] += (ll)n * n;
continue;
}
if (mod[i][1][0])
{
delta[0] += (ll)n * n;
last = (ll)n * n;
} else
last = 0;
for (j = 0; j < mod[i][0].size(); ++j)
{
u = mod[i][0][j];
if (bw[u] ^= 1)
cut(u);
else
link(u);
if (j == mod[i][0].size() - 1 || mod[i][1][j + 1] != mod[i][1][j])
{
delta[mod[i][1][j]] += ans - last;
last = ans;
}
}
for (j = mod[i][0].size() - 1; ~j; --j)
{
u = mod[i][0][j];
if (bw[u] ^= 1)
cut(u);
else
link(u);
}
}
ans = (ll) n * n * n;
for (i = 0; i <= m; ++i)
{
ans -= delta[i];
printf("%I64d ", ans);
}
return 0;
}
bool nroot(int x) { return x == t[t[x].fa].ch[0] || x == t[t[x].fa].ch[1]; }
void rotate(int x)
{
int y = t[x].fa;
int z = t[y].fa;
int k = x == t[y].ch[1];
if (nroot(y)) t[z].ch[y == t[z].ch[1]] = x;
t[x].fa = z;
t[y].ch[k] = t[x].ch[k ^ 1];
t[t[x].ch[k ^ 1]].fa = y;
t[x].ch[k ^ 1] = y;
t[y].fa = x;
pushup(y);
pushup(x);
}
void Splay(int x)
{
while (nroot(x))
{
int y = t[x].fa;
int z = t[y].fa;
if (nroot(y)) (x == t[y].ch[1]) ^ (y == t[z].ch[1]) ? rotate(x) : rotate(y);
rotate(x);
}
}
void access(int x)
{
for (int y = 0; x; x = t[y = x].fa)
{
Splay(x);
t[x].sizi += t[t[x].ch[1]].siz;
t[x].sizi -= t[y].siz;
t[x].siz2i += t[t[x].ch[1]].siz2();
t[x].siz2i -= t[y].siz2();
t[x].ch[1] = y;
pushup(x);
}
}
int findroot(int x)
{
access(x);
Splay(x);
while (t[x].ch[0]) x = t[x].ch[0];
Splay(x);
return x;
}
void link(int x)
{
int y = f[x];
Splay(x);
ans -= t[x].siz2i + t[t[x].ch[1]].siz2();
int z = findroot(y);
access(x);
Splay(z);
ans -= t[t[z].ch[1]].siz2();
t[x].fa = y;
Splay(y);
t[y].sizi += t[x].siz;
t[y].siz2i += t[x].siz2();
pushup(y);
access(x);
Splay(z);
ans += t[t[z].ch[1]].siz2();
}
void cut(int x)
{
int y = f[x];
access(x);
ans += t[x].siz2i;
int z = findroot(y);
access(x);
Splay(z);
ans -= t[t[z].ch[1]].siz2();
Splay(x);
t[x].ch[0] = t[t[x].ch[0]].fa = 0;
pushup(x);
Splay(z);
ans += t[t[z].ch[1]].siz2();
}
void pushup(int x)
{
t[x].siz = t[t[x].ch[0]].siz + t[t[x].ch[1]].siz + t[x].sizi + 1;
}
void add(int u, int v)
{
nxt[++cnt] = head[u];
head[u] = cnt;
to[cnt] = v;
}
void dfs(int u)
{
int i, v;
for (i = head[u]; i; i = nxt[i])
{
v = to[i];
if (v != f[u])
{
f[v] = u;
dfs(v);
}
}
}
```