C.solution

· · 题解

题解

一句话总结:质因数分解后建模,树上启发式合并。

先考虑不在树上该如何解决。

假设现在数字分为两组 S,T(看成是两颗子树中的数)。

我们对其中的每个数字质因数分解,并且每个质因子只保留一位,具体地,假设一个数为

c_i=p_1^{a_1}\times p_2^{a_2}\times \dots \times p_n^{a_n}

那么我们把它变成

c_i'=p_1\times p_2\times \dots \times p_n

对这个 c_i 我们向 c_i'所有因数连边,可以发现,因为

2\times 3\times 5\times 7\times 11\times 13\times 17=510510>500000

所以一个数在变形后最多只有 6 个质因子,那么它的因数个数最大就为

\binom{6}{1}+\binom{6}{2}+\dots+\binom{6}{6}=63

所以一个数最多产生 63 条边。

处理完所有数后发现被连了边的数的因子中没有质数的平方。我们遍历这些数。对于一个数 x,设 T 中向 x 连边的数的集合为 QS 中向 x 连边的数的集合为 P

如果 x奇数个质因子就会产生贡献:

\sum_{x\in Q}\sum_{y\in P}x\times y

如果 x偶数个质因子就会产生贡献:

-\sum_{x\in Q}\sum_{y\in P}x\times y

这两个求和实际上是可以写成这样的:

\sum_{x\in Q}x\sum_{y\in P}y

那么求个和乘起来即可。

为什么这样做呢?我们考虑两个数 a,b,先对它们做最开始的处理变成 a',b'。不难发现,如果 \gcd(a,b)>1 那么 \gcd(a',b')>1,所以我们现在只需要判断 a',b' 的关系,按照上面的式子,在

p_1,p_2\dots p_n

的地方 a',b' 都会被统计进答案,这明显算多了,但是在

p_1p_2,p_1p_3\dots p_np_{n-1}

的地方 a',b' 又产生了负贡献,到这里已经可以发现这个其实就是容斥了。

会算重的质因子 a',b' 一定都有,那么这些质因子组合出来的数同样也是 a',b' 的因子,所以按照这样算完刚好满足容斥的式子,也即每一对数刚好被贡献了一次。时间复杂度 \Theta (n) 有一个 63 的常数。

放在树上只需要加一个启发式合并,在加一个数的时候先计算答案再把这个数要连的边加上即可,多一个 \log

总复杂度 \Theta (n\log{n}) 有一个 63 的常数,不过因为质因子分解几乎卡不满,所以会稍微快一点点,勉强通过。

但是这题如果直接按因数建虚树也能过,只是常数会大,为了尽量卡这种做法我就把时间调的比较小了QAQ。

不过有人写 \Theta(n) 的虚树加上卡常依然能过,只能说我这道题出的真的不好,考虑的太少了QAQ。

Code

#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int M=2e5+10,N=5e5+10,mod=998244353;
int n,a[M];
vector<int>G[M];
int p[N],isp[N],fr[N],mu[N];
vector<int>r1[N],rc[N];
LL pb[N],pc[N][20];
LL b[M],c[M];
int sz[M],son[M];
bool pd[N];
void dfs(int k,int m,int last,int sum,vector<int>&ps,vector<int>&st){
    if(m-k>(int)st.size()-last)return;
    if(k==m){
        ps.push_back(sum);
        return;
    }
    for(int i=last;i<(int)st.size();i++){
        dfs(k+1,m,i+1,sum*st[i],ps,st);
    }
}
void init(){
    isp[1]=mu[1]=1;
    for(int i=2;i<N;i++){
        if(!isp[i])p[++p[0]]=i,fr[i]=i,mu[i]=-1;
        for(int j=1;j<=p[0]&&i*p[j]<N;j++){
            isp[i*p[j]]=true;
            fr[i*p[j]]=p[j];
            if(i%p[j]==0)break;
            mu[i*p[j]]=-mu[i];
        }
        int s=i;
        if(pd[s]){
            while(s>1){
                int np=fr[s];
                while(fr[s]==np)s/=np;
                r1[i].push_back(np);
            }
            for(int j=1;j<=(int)r1[i].size();j++)dfs(0,j,0,1,rc[i],r1[i]);
        }
    }
}
void init(int u,int f){
    sz[u]=1;
    for(int v:G[u]){
        if(v==f)continue;
        init(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]])son[u]=v;
    }
}
void dfs_count(int u,int f,int fr){
    for(int v:rc[a[u]])b[fr]-=mu[v]*pb[v]%mod*a[u]%mod,b[fr]%=mod;
    for(int v:G[u]){
        if(v==f)continue;
        dfs_count(v,u,fr);
    }
}
void dfs_add(int u,int f){
    for(int v:rc[a[u]])pb[v]+=a[u];
    for(int v:G[u]){
        if(v==f)continue;
        dfs_add(v,u);
    }
}
void dfs_del(int u,int f){
    for(int v:rc[a[u]])pb[v]-=a[u];
    for(int v:G[u]){
        if(v==f)continue;
        dfs_del(v,u);
    }
}
void dfs(int u,int f){
    for(int v:G[u]){
        if(v==f||v==son[u])continue;
        dfs(v,u);
        dfs_del(v,u);
    }
    if(son[u]){
        dfs(son[u],u);
        for(int v:rc[a[u]]){
            b[u]-=mu[v]*pb[v]%mod*a[u]%mod;
            b[u]%=mod;
            pb[v]+=a[u];
        }
        for(int v:G[u]){
            if(v==f||v==son[u])continue;
            dfs_count(v,u,u);
            dfs_add(v,u);
        }
    }else for(int v:rc[a[u]])pb[v]+=a[u];
}
int read() {
    int x = 0, w = 1;
    char ch = 0;
    while (ch < '0' || ch > '9') {
        if (ch == '-') w = -1;       
        ch = getchar();              
    }
    while (ch >= '0' && ch <= '9') {  
        x = x * 10 + (ch - '0'); 
        ch = getchar();  
    }
    return x * w;  
}
void write(LL x) {
    if (x < 0) {  
        x = -x;
        putchar('-');
    }
    if (x > 9) write(x / 10); 
    putchar(x % 10 + '0');  
}
signed main(){
    cin>>n;
    for(int i=1;i<=n;i++)a[i]=read(),pd[a[i]]=true;
    init();
    for(int i=1;i<n;i++){
        int u=read(),v=read();
        G[u].push_back(v),G[v].push_back(u);
    }
    init(1,0);
    dfs(1,0);
    for(int i=1;i<=n;i++)write((b[i]+mod)%mod),putchar('\n');
    return 0;
}