题解:CF1923E Count Paths

· · 题解

教你玩爆CF2000的题

1.方法本质

这道题的解法不难,我们将所求路径分为两种情况:

  1. 两点为祖先关系,如图:

  2. 两点并非祖先关系,如图: 为了去重,我们按如下计数顺序处理:

  3. 有祖先关系时,层数较小的节点处理此路线。

  4. 无祖先关系的,后遍历的节点处理此路线。

2.整体处理

我们用 a_i 处理 i 号颜色的情况 1 路径计数,b_i 处理 i 号颜色情况 2 的方案数,接下来深搜遍历每个节点 v 并按顺序作如下操作:

  1. 预存 va = a_{c_v},vb = b_{c_v}
  2. 得到 vb 条对应情况 2 的路径;
  3. 遍历子节点,因为父节点拦截了子树外的同色节点,所以遍历前 b_{c_v} = 0
  4. 我们得到了从子节点来的 a_{c_v} - va 条对应情况 1 的新路径。
  5. 因为该节点会拦截子树内同色节点,并贡献该节点,最后改 a_{c_v} = va + 1,b_{c_v} = vb + 1

3 选择原因

  1. 速度:时间复杂度 \operatorname{O}(n),并且常数小,本人代码可以在使用 cincout 且不关同步时仍然卡进 1 \text{ s}
  2. 空间:复杂度 \operatorname{O}(n),且用的数组数量少,具体的,本人代码空间 17700 \text{ KB}
  3. 难度:思维和代码难度简单,支持普及组实力选手学习和使用。

4.代码

scanfprintf (时间 240\text{ ms} 左右,空间 17700\text{ KB} 左右):

#include<cstdio>
#include<vector>
using namespace std;
int t,n,c[200009],u,v;
int cntc[200009],up[200009];
vector<int>e[200005];
long long ans;
void srh(int v,int fa){
    int u = cntc[c[v]],s = up[c[v]];
    ans += s;
    for(int i = 0; i < e[v].size(); i ++){
        if(e[v][i] == fa)
            continue;
        up[c[v]] = 0;
        srh(e[v][i],v);
    }
    ans += cntc[c[v]] - u;
    cntc[c[v]] = u + 1;
    up[c[v]] = s + 1;
}
int main(){
    scanf("%d",&t);
    while(t--){
        scanf("%d",&n);
        for(int i = 1; i <= n; i ++){
            scanf("%d",&c[i]);
            cntc[c[i]] = up[c[i]] = 0;
            e[i].clear();
        }
        ans = 0;
        for(int i = 1; i < n; i ++){
            scanf("%d %d",&u,&v);
            e[u].push_back(v);
            e[v].push_back(u);
        }
        srh(1,0);
        printf("%lld\n",ans);
    }
    return 0;
}

cincout 代码(不关同步,空间不变,时间 820\text{ ms} 左右):

#include<cstdio>
#include<vector>
#include<iostream>
using namespace std;
int t,n,c[200009],u,v;
int cntc[200009],up[200009];
vector<int>e[200005];
long long ans;
void srh(int v,int fa){
    int u = cntc[c[v]],s = up[c[v]];
    ans += s;
    for(int i = 0; i < e[v].size(); i ++){
        if(e[v][i] == fa)
            continue;
        up[c[v]] = 0;
        srh(e[v][i],v);
    }
    ans += cntc[c[v]] - u;
    cntc[c[v]] = u + 1;
    up[c[v]] = s + 1;
}
int main(){
    cin >> t;
    while(t--){
        cin >> n;
        for(int i = 1; i <= n; i ++){
            cin >> c[i];
            cntc[c[i]] = up[c[i]] = 0;
            e[i].clear();
        }
        ans = 0;
        for(int i = 1; i < n; i ++){
            cin >> u >> v;
            e[u].push_back(v);
            e[v].push_back(u);
        }
        srh(1,0);
        cout << ans << endl;
    }
    return 0;
}