P11648 【MX-X8-T7】「TAOI-3」2236 A.D. 题解

· · 题解

写一个详细点的题解。

考虑暴力怎么做,对于每个点 u,我们尝试求出 \text{lca} (x,y) =u 的点对 (x,y) 的贡献,每个点的答案就是子树内贡献的和。暴力做法就是每次把点 u 两个子树合并,计算 xy 分别位于这两部分时的贡献。

这里如果你直接枚举两部分子树内的点然后贡献,按照树上背包的方法分析就是 O(n^2) 的,如果你每次合并做 FWT 就是 O(nk2^k) 的。

你注意这两种合并方式中,前者的复杂度与合并的两部分大小有关,后者只与 k 有关,这启发我们结合一下两部分然后平衡一下复杂度。设一个阈值 B,每个节点维护两部分,一部分是暴力合并的部分,另一部分是 FWT 合并的部分,记前者为散点,后者为整块,每次合并时二者对应合并,如果散点数量超过阈值了就清空并全部塞到整块里面。

算答案怎么办?散点可以暴力,这部分的复杂度是 O(nB) 的,因为一个散点最多和 O(B) 个点被暴力考虑,否则就会被塞到整块里。整块和整块可以 FWT,这只会操作 O(\frac{n}{B}) 次,也是对的。

你发现似乎只有散点和整块的贡献不好处理,而这需要对每个 x 算一个类似 a_x=\sum_{y} c_{x|y} 的东西,注意到 c_{x|y} 可以拆开算,变成 a_x=c_x\sum_{y} c_{(x|y) -x},然后后者可以在一个类似 FWT 的过程中算掉,考虑在 FWT 中你每次对于某一位为 0 和 1 其他位一样的数作变换,设这两个数为 xy,令 x\leftarrow (x+wy)y\leftarrow (x+y),其中 w 是这一位的贡献,这部分复杂度是 O(nB) 的。

如果你 T 了可以考虑块长开大点,因为常数有点大。


#include<bits/stdc++.h>

#define ll long long
#define pi pair<int,int>
#define vi vector<int>
#define cpy(x,y,s) memcpy(x,y,sizeof(x[0])*(s))
#define mset(x,v,s) memset(x,v,sizeof(x[0])*(s))
#define all(x) begin(x),end(x)
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define ary array
#define eb emplace_back
#define IL inline
#define For(i,j,k) for(int i=(j);i<=(k);i++)
#define Fol(i,k,j) for(int i=(k);i>=(j);i--)

using namespace std;

#define B 1400
#define N 500005
#define K 33005
#define mod 998244353

int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0' || ch>'9')f=(ch=='-'?-1:f),ch=getchar();
    while(ch>='0' && ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return x*f;
}
void write(int x){
    if(x<0)x=-x,putchar('-');
    if(x/10)write(x/10);
    putchar(x%10+'0');
}

void debug(auto &&...x){
    ((cerr<<x<<' '),...);
    cerr<<'\n';
}

IL int pls(int x,int y){return (x+y>=mod?x+y-mod:x+y);}
IL int sub(int x,int y){return (x-y<0?x-y+mod:x-y);}
IL void Add(int &x,int y){x=pls(x,y);}
IL void Dec(int &x,int y){x=sub(x,y);}
IL int mul(int x,int y){return x*1ll*y%mod;}
IL int qp(int x,int y=mod-2){int ans=1;while(y){if(y&1)ans=mul(ans,x);x=mul(x,x);y>>=1;}return ans;}

int id[N],w[N],c[N],a[N];
basic_string<int> e[N];
int vis[B][16],f[B][K],_f[B][K],coef[B][K],ans[N],cnt,tmp[K];
vi S[N];
int n,k;

void FWT(int *g,int opt){
    for(int len=2,mid=1;len<=(1<<k);len<<=1,mid<<=1){
        for(int j=0;j<(1<<k);j+=len){
            for(int k=j;k<j+mid;k++){
                if(!opt)Add(g[k+mid],g[k]);
                else Dec(g[k+mid],g[k]);
            }
        }
    }
}

int calc(int u,int v){
    int res=0;
    if(id[u] && id[v]){
        For(i,0,(1<<k)-1)tmp[i]=mul(_f[id[u]][i],_f[id[v]][i]);
        FWT(tmp,1);
        For(i,0,(1<<k)-1)Add(res,mul(tmp[i],w[i]));
    }
    if(id[v])for(auto val:S[u])Add(res,mul(coef[id[v]][val],w[val]));
    if(id[u])for(auto val:S[v])Add(res,mul(coef[id[u]][val],w[val]));
    for(auto val:S[u])for(auto val2:S[v])Add(res,w[val|val2]);
    return res;
}

void rebuild(int u,int v){
    if(!id[u])id[u]=++cnt;
    for(auto val:S[u])f[id[u]][val]++;S[u].clear();
    if(id[v])For(i,0,(1<<k)-1)f[id[u]][i]+=f[id[v]][i];
    For(i,0,(1<<k)-1)_f[id[u]][i]=f[id[u]][i],coef[id[u]][i]=f[id[u]][i];
    FWT(_f[id[u]],0);
    for(int len=2,mid=1;len<=(1<<k);len<<=1,mid<<=1){
        int z=__lg(len);
        for(int j=0;j<(1<<k);j+=len){
            for(int k=j;k<j+mid;k++){
                int x=coef[id[u]][k],y=coef[id[u]][k+mid];
                coef[id[u]][k]=pls(x,mul(y,c[z]));
                coef[id[u]][k+mid]=pls(x,y);
            }
        }
    }
    For(i,1,k)vis[id[u]][i]=0;
}

void opr(int u,int t){
    if(vis[id[u]][t])return;vis[id[u]][t]=1;
    For(S,0,(1<<k)-1)if(!(S&(1<<(t-1))))Add(f[id[u]][(S|(1<<(t-1)))],f[id[u]][S]),f[id[u]][S]=0;
    For(i,0,(1<<k)-1)_f[id[u]][i]=f[id[u]][i],coef[id[u]][i]=f[id[u]][i];
    FWT(_f[id[u]],0);
    for(int len=2,mid=1;len<=(1<<k);len<<=1,mid<<=1){
        int z=__lg(len);
        for(int j=0;j<(1<<k);j+=len){
            for(int k=j;k<j+mid;k++){
                int x=coef[id[u]][k],y=coef[id[u]][k+mid];
                coef[id[u]][k]=pls(x,mul(y,c[z]));
                coef[id[u]][k+mid]=pls(x,y);
            }
        }
    }
}

void dfs(int x,int fa){
    S[x].pb((1<<(a[x]-1)));ans[x]=c[a[x]];
    for(auto v:e[x]){
        if(v==fa)continue;
        dfs(v,x);
        //operate
        if(id[v])opr(v,a[x]);
        for(auto &val:S[v])val|=(1<<(a[x]-1));
        //calculate
        Add(ans[x],calc(x,v));Add(ans[x],ans[v]);
        //merge
        for(auto val:S[v])S[x].pb(val);vector<int>().swap(S[v]);
        if(id[v])swap(id[x],id[v]);
        if(id[v] || S[x].size()>B)rebuild(x,v);
    }
}

int main(){
    #ifdef EAST_CLOUD
    freopen("a.in","r",stdin);
    freopen("a.out","w",stdout);
    #endif

    n=read(),k=read();
    For(i,1,k)c[i]=read();
    For(i,1,n)a[i]=read();
    For(i,1,n-1){
        int u=read(),v=read();
        e[u]+=v;e[v]+=u;
    }
    For(S,0,(1<<k)-1){
        w[S]=1;
        For(j,1,k)if(S&(1<<(j-1)))w[S]=mul(w[S],c[j]);
    }
    dfs(1,0);
    For(i,1,n)write(ans[i]),putchar(' ');
    return 0;
}