SNOI 2024 D1T1 题解

· · 题解

我现在发现我去年写的这东西太抽象了,所以我选择直接重写一遍,想看我发一堆无意义牢骚的可以看这里。

有一些 corner:关键点的 f 值为它本身,故如果存在一个 x\in[1,k] 没有在 f 中出现非法;f 值相同的点一定形如一个连通块。以上的这些东西都可以简单判断。

为方便叙述,我们令 g_x=a_{f_x},及距离 x 最近的关键点的标号。d(x,y)xy 的距离。

注意到一条满足 f_x\ne f_y 的树边 (x,y),不妨令 f_x<f_y,可以解释为:d(x,g_x)-d(y,g_y)\in[0,1]

我们关于上文“颜色的连通块”dp。设 dp_x 为满足了 x 所在联通块和该连通块的子树的所有限制,x 是关键点的方案数。初值为 f_x=1,答案是 \sum\limits_{v\le n}[f_v=f_{root}]dp_v

我们仍然关于连通块整体转移:现在存在一条满足 f_x\ne f_y 的树边 (x,y)yx 的子节点,我们要将 y 所在连通块的信息向 x 转移。另外设 w_i\sum\limits_{v\le n\land f_v=f_y\land d(v,y)=i}dp_v,根据以上的讨论,对于 x 所在连通块的一个点 o,会在这一步乘上 \begin{cases}w_{d(x,o)}+w_{d(x,o)+1}\quad f_x>f_y\\w_{d(x,o)}+w_{d(x,o)-1}\quad f_x<f_y\end{cases} 的系数。

根据这个做转移,时间复杂度 O(n^2)。不过貌似常数非常小。以下这份代码是我很久前写的,变量名与题解不同,请自己辨别。

#include<iostream>
#include<vector>
#define pb push_back
const int p = 998244353;
int n, fl, k, rs, a[3050], f[3050], ct[3050];
std::vector <int> e[3050], o[3050];
inline void gr(std::vector<int> &vc, int x, int fa, int d){
    (vc[d] += f[x]) %= p; for(auto v:e[x]) 
        if(v != fa && a[v] == a[x]) gr(vc, v, x, d + 1);
}
inline void dp(std::vector<int> &vc, int x, int fa, int k, int d){
    f[x] = 1ll * f[x] * (vc[d] + vc[d + (a[x] < k ? -1 : 1)]) % p;
    for(auto v:e[x]) if(v != fa && a[v] == a[x]) dp(vc, v, x, k, d + 1);
}
inline void dfs(int x, int fa){
    ct[a[x]] += a[x] != a[fa];
    if(ct[a[x]] >= 2) return void(fl = 1);
    for(auto v:e[x]) if(v != fa){
        dfs(v, x); if(a[v] != a[x]){
            std::vector<int>vc(n + 5);
            gr(vc, v, x, 1); dp(vc, x, v, a[v], 1);
        }
    }
}
int main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);std::cout.tie(0);
    int T; std::cin >> T; while(T--){
        std::cin >> n >> k; rs = fl = 0;
        for(int i = 1; i <= n; i++)
            e[i].clear(), o[i].clear(), ct[i] = 0;
        for(int i = 1, x, y; i < n; i++)
            std::cin >> x >> y, e[x].pb(y), e[y].pb(x);
        for(int i = 1; i <= n; i++) 
            std::cin >> a[i], o[a[i]].pb(i);
        for(int i = 1; i <= n; i++) f[i] = 1;
        dfs(1, 0); for(int i = 1; i <= k; i++)
            if(!ct[i]) {fl = 1; break;}
        if(fl){std::cout << 0 << '\n'; continue;}
        for(auto v:o[a[1]]) (rs += f[v]) %= p;
        std::cout << rs << '\n';
    }
}