C.solution
题解
一句话总结:质因数分解后建模,树上启发式合并。
先考虑不在树上该如何解决。
假设现在数字分为两组
我们对其中的每个数字质因数分解,并且每个质因子只保留一位,具体地,假设一个数为
那么我们把它变成
对这个
所以一个数在变形后最多只有
所以一个数最多产生
处理完所有数后发现被连了边的数的因子中没有质数的平方。我们遍历这些数。对于一个数
如果
如果
这两个求和实际上是可以写成这样的:
那么求个和乘起来即可。
为什么这样做呢?我们考虑两个数
的地方
的地方
会算重的质因子
放在树上只需要加一个启发式合并,在加一个数的时候先计算答案再把这个数要连的边加上即可,多一个
总复杂度
但是这题如果直接按因数建虚树也能过,只是常数会大,为了尽量卡这种做法我就把时间调的比较小了QAQ。
不过有人写
Code
#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int M=2e5+10,N=5e5+10,mod=998244353;
int n,a[M];
vector<int>G[M];
int p[N],isp[N],fr[N],mu[N];
vector<int>r1[N],rc[N];
LL pb[N],pc[N][20];
LL b[M],c[M];
int sz[M],son[M];
bool pd[N];
void dfs(int k,int m,int last,int sum,vector<int>&ps,vector<int>&st){
if(m-k>(int)st.size()-last)return;
if(k==m){
ps.push_back(sum);
return;
}
for(int i=last;i<(int)st.size();i++){
dfs(k+1,m,i+1,sum*st[i],ps,st);
}
}
void init(){
isp[1]=mu[1]=1;
for(int i=2;i<N;i++){
if(!isp[i])p[++p[0]]=i,fr[i]=i,mu[i]=-1;
for(int j=1;j<=p[0]&&i*p[j]<N;j++){
isp[i*p[j]]=true;
fr[i*p[j]]=p[j];
if(i%p[j]==0)break;
mu[i*p[j]]=-mu[i];
}
int s=i;
if(pd[s]){
while(s>1){
int np=fr[s];
while(fr[s]==np)s/=np;
r1[i].push_back(np);
}
for(int j=1;j<=(int)r1[i].size();j++)dfs(0,j,0,1,rc[i],r1[i]);
}
}
}
void init(int u,int f){
sz[u]=1;
for(int v:G[u]){
if(v==f)continue;
init(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
}
void dfs_count(int u,int f,int fr){
for(int v:rc[a[u]])b[fr]-=mu[v]*pb[v]%mod*a[u]%mod,b[fr]%=mod;
for(int v:G[u]){
if(v==f)continue;
dfs_count(v,u,fr);
}
}
void dfs_add(int u,int f){
for(int v:rc[a[u]])pb[v]+=a[u];
for(int v:G[u]){
if(v==f)continue;
dfs_add(v,u);
}
}
void dfs_del(int u,int f){
for(int v:rc[a[u]])pb[v]-=a[u];
for(int v:G[u]){
if(v==f)continue;
dfs_del(v,u);
}
}
void dfs(int u,int f){
for(int v:G[u]){
if(v==f||v==son[u])continue;
dfs(v,u);
dfs_del(v,u);
}
if(son[u]){
dfs(son[u],u);
for(int v:rc[a[u]]){
b[u]-=mu[v]*pb[v]%mod*a[u]%mod;
b[u]%=mod;
pb[v]+=a[u];
}
for(int v:G[u]){
if(v==f||v==son[u])continue;
dfs_count(v,u,u);
dfs_add(v,u);
}
}else for(int v:rc[a[u]])pb[v]+=a[u];
}
int read() {
int x = 0, w = 1;
char ch = 0;
while (ch < '0' || ch > '9') {
if (ch == '-') w = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
x = x * 10 + (ch - '0');
ch = getchar();
}
return x * w;
}
void write(LL x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
signed main(){
cin>>n;
for(int i=1;i<=n;i++)a[i]=read(),pd[a[i]]=true;
init();
for(int i=1;i<n;i++){
int u=read(),v=read();
G[u].push_back(v),G[v].push_back(u);
}
init(1,0);
dfs(1,0);
for(int i=1;i<=n;i++)write((b[i]+mod)%mod),putchar('\n');
return 0;
}