题解:CF1923E Count Paths
Nuclear_Pasta
2024-03-15 22:02:02
~~教你玩爆CF2000的题~~
# 1.方法本质
这道题的解法不难,我们将所求路径分为两种情况:
1. 两点为祖先关系,如图:

2. 两点并非祖先关系,如图:

为了去重,我们按如下计数顺序处理:
1. 有祖先关系时,层数较小的节点处理此路线。
2. 无祖先关系的,后遍历的节点处理此路线。
# 2.整体处理
我们用 $a_i$ 处理 $i$ 号颜色的情况 $1$ 路径计数,$b_i$ 处理 $i$ 号颜色情况 $2$ 的方案数,接下来深搜遍历每个节点 $v$ 并按顺序作如下操作:
1. 预存 $va = a_{c_v},vb = b_{c_v}$;
2. 得到 $vb$ 条对应情况 $2$ 的路径;
3. 遍历子节点,因为父节点拦截了子树外的同色节点,所以遍历前 $b_{c_v} = 0$;
4. 我们得到了从子节点来的 $a_{c_v} - va$ 条对应情况 $1$ 的新路径。
5. 因为该节点会拦截子树内同色节点,并贡献该节点,最后改 $a_{c_v} = va + 1,b_{c_v} = vb + 1$。
# 3 选择原因
1. 速度:时间复杂度 $\operatorname{O}(n)$,并且常数小,本人代码可以在使用 `cin` 与 `cout` 且不关同步时仍然卡进 $1 \text{ s}$。
3. 空间:复杂度 $\operatorname{O}(n)$,且用的数组数量少,具体的,本人代码空间 $17700 \text{ KB}$。
2. 难度:思维和代码难度简单,支持普及组实力选手学习和使用。
# 4.代码
用 `scanf` 和 `printf` (时间 $240\text{ ms}$ 左右,空间 $17700\text{ KB}$ 左右):
```cpp
#include<cstdio>
#include<vector>
using namespace std;
int t,n,c[200009],u,v;
int cntc[200009],up[200009];
vector<int>e[200005];
long long ans;
void srh(int v,int fa){
int u = cntc[c[v]],s = up[c[v]];
ans += s;
for(int i = 0; i < e[v].size(); i ++){
if(e[v][i] == fa)
continue;
up[c[v]] = 0;
srh(e[v][i],v);
}
ans += cntc[c[v]] - u;
cntc[c[v]] = u + 1;
up[c[v]] = s + 1;
}
int main(){
scanf("%d",&t);
while(t--){
scanf("%d",&n);
for(int i = 1; i <= n; i ++){
scanf("%d",&c[i]);
cntc[c[i]] = up[c[i]] = 0;
e[i].clear();
}
ans = 0;
for(int i = 1; i < n; i ++){
scanf("%d %d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
srh(1,0);
printf("%lld\n",ans);
}
return 0;
}
```
用 `cin`,`cout` 代码(不关同步,空间不变,时间 $820\text{ ms}$ 左右):
```cpp
#include<cstdio>
#include<vector>
#include<iostream>
using namespace std;
int t,n,c[200009],u,v;
int cntc[200009],up[200009];
vector<int>e[200005];
long long ans;
void srh(int v,int fa){
int u = cntc[c[v]],s = up[c[v]];
ans += s;
for(int i = 0; i < e[v].size(); i ++){
if(e[v][i] == fa)
continue;
up[c[v]] = 0;
srh(e[v][i],v);
}
ans += cntc[c[v]] - u;
cntc[c[v]] = u + 1;
up[c[v]] = s + 1;
}
int main(){
cin >> t;
while(t--){
cin >> n;
for(int i = 1; i <= n; i ++){
cin >> c[i];
cntc[c[i]] = up[c[i]] = 0;
e[i].clear();
}
ans = 0;
for(int i = 1; i < n; i ++){
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
srh(1,0);
cout << ans << endl;
}
return 0;
}
```