题解 P6668【[清华集训2016] 连通子树】
xtx1092515503 · · 题解
首先先考虑单组询问。这时应该怎么处理?
显然是一个背包问题,我们设
但是,合并两个背包的复杂度是
怎么办呢?事实上,注意到这种背包往其中添加一个数的操作是
我们考虑对树淀粉质。考虑当前的分治树。
我们钦定分治树的树根必须在背包里。这就意味着,如果儿子选了,父亲就一定要被选。进而我们可以构思出如下 DP 方式:
深搜整棵树。当尝试进入一个儿子时,我们首先将父亲的 DP 数组拷贝一份给儿子,然后进入儿子递归,递归完儿子后再把父子的背包合并即可。
明显此时背包仅仅涉及到添加一个数,因而复杂度是
然后考虑回答多组询问。因为题目保证每种颜色的元素数都不会太多,故首先显然我们可以对关心的三种颜色建虚树。
建完虚树后,我们可以在虚树上做淀粉质,并且套用上述合并过程。
考虑合并过程中所有需要的值。
首先我们需要知道某棵子树中选择一个包含根的非空连通块的方案数。这可以令
然后考虑转移,这时我们就发现转移时要知道链的方案数:假如某个儿子不选,则链的方案就是链顶强制选、链底强制不选的方案数;选,就是链顶链底都选的方案数。因为淀粉质时每条链从上从下被考虑都是有可能的,所以对于每条在原树中从上往下的链,我们需要知道链底选而链顶不选的方案数(记作
列出式子就会发现当一条链的链顶与另一条链的链底相同时,这两条链可以运用上述值加以合并,需要知道的辅助值是相同点周围的点任意选择或不选择的方案数。于是可以倍增处理,每次合并两条链。
但是,如果把合并两条链时所要进行的式子全都列出来,我们最终甚至可以发现链间具有可减性,进而直接一个前缀和就解决了。
需要注意的是,当三种颜色需要的数量都为
前者可以在换根 DP 时顺手预处理掉(因为虚树以
时间复杂度为做到
虽然但是,本人的实现方式常数很大,花了好久才卡过 LG 的评测,并始终无法在 UOJ 上通过。这里介绍一种优化常数的方法:
考虑不在递归时 DP,而是在递归时求出 dfs 序。这样,如果一个儿子被选就等价于用
这种 DP 方式减少了很多拷贝数组的操作,故常数比起之前要小。
又臭又长的代码:
#include<bits/stdc++.h>
using namespace std;
const int N=100100;
const int LG=17;
typedef long long ll;
constexpr int mod=1e9+7;
int ksm(int x,int y=mod-2){int z=1;for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*z*x%mod;return z;}
void ADJ(int&x){if(x>=mod)x-=mod;}
int RED(int x){return x>=mod?x-mod:x;}
namespace FIFO{
char buf[1<<23],*p1=buf,*p2=buf;
#ifndef Troverld
#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#else
#define gc() getchar()
#endif
template<typename T>void read(T&x){
x=0;
char c=gc();
while(c>'9'||c<'0')c=gc();
while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=gc();
}
template<typename T>void print(T x){if(x<=9)putchar('0'+x);else print(x/10),putchar('0'+x%10);}
}using namespace FIFO;
int n,m,col[N];
vector<int>cc[N];
namespace TRE{
struct Chain{
int bot,top,whl,unl;
Chain(int B,int T,int W,int U){bot=B,top=T,whl=W,unl=U;}
Chain(){bot=top=unl=0,whl=1;}
Chain REV(){return Chain(top,bot,whl,unl);}
// void pr(){printf("[%d %d %d %d]",bot,top,whl,unl);}
}chn[N];
Chain merge(Chain u,Chain v,int val,int bon){
return Chain(
(1ll*u.whl*v.bot%mod*val+u.bot)%mod,
(1ll*u.top*v.whl%mod*val+v.top)%mod,
1ll*u.whl*v.whl%mod*val%mod,
(1ll*u.top*v.bot%mod*val+u.unl+v.unl+bon)%mod
);
}
Chain split(Chain u,Chain v,int val,int bon){
int inv=ksm(1ll*v.whl*val%mod);
int whl=1ll*u.whl*inv%mod;
int bot=RED(u.bot+mod-1ll*whl*v.bot%mod*val%mod);
int top=1ll*(u.top+mod-v.top)*inv%mod;
int unl=RED(u.unl+mod-(1ll*top*v.bot%mod*val+v.unl+bon)%mod);
return Chain(bot,top,whl,unl);
}
vector<int>v[N];
void aev(int x,int y){v[x].push_back(y),v[y].push_back(x);}
int sub[N],ful[N],anc[N][LG],dep[N],dfn[N],tot,sum[N],ult[N],isb[N],iut[N];
int pas(int x,int y){if(dep[x]>dep[y])swap(x,y);return 1ll*sub[x]*isb[y]%mod;}
void dfs1(int x,int fa){
sub[x]=1,dep[x]=dep[fa]+1,dfn[x]=++tot;
for(auto y:v[x])if(y!=fa)
dfs1(y,x),sub[x]=1ll*sub[x]*(sub[y]+1)%mod,ADJ(sum[x]+=RED(sum[y]+sub[y]));
isb[x]=ksm(sub[x]+1);
}
void dfs2(int x,int fa){
for(auto y:v[x])if(y!=fa)
ult[y]=1ll*ful[x]*isb[y]%mod,iut[y]=ksm(ult[y]+1),
ful[y]=1ll*(ult[y]+1)*sub[y]%mod,dfs2(y,x);
}
void dfs3(int x,int fa){
anc[x][0]=fa;
for(int i=1;i<LG;i++)anc[x][i]=anc[anc[x][i-1]][i-1];
for(auto y:v[x])if(y!=fa)
chn[y]=merge(Chain(1,1,1,0),chn[x],pas(x,y),RED(sum[x]+mod-RED(sum[y]+sub[y]))),
dfs3(y,x);
}
int JUM(int x,int y){for(int i=LG-1;i>=0;i--)if(dep[x]<dep[y]-(1<<i))y=anc[y][i];return y;}
Chain Path(int x,int y){
int z=JUM(x,y);
return split(chn[y],chn[x],pas(x,z),RED(sum[x]+mod-RED(sum[z]+sub[z])));
}
int LCA(int x,int y){
if(dep[x]>dep[y])swap(x,y);
for(int i=LG-1;i>=0;i--)if(dep[x]<=(dep[y]-(1<<i)))y=anc[y][i];
if(x==y)return x;
for(int i=LG-1;i>=0;i--)if(anc[x][i]!=anc[y][i])x=anc[x][i],y=anc[y][i];
return anc[x][0];
}
}using namespace TRE;
namespace IMG{
int stk[N],tp,arr[N],head[N],lam[N],cnt;
struct node{int to,next,all,par;}edge[N<<1];
bool vis[N];
int f[N][6][6][6],RT,SZ,sz[N],msz[N];
int a,b,c,A,B,C,res;
void aeh(int x,int y){
if(dep[x]>dep[y])swap(x,y);
Chain z=Path(x,y);
// printf("%d %d\n",x,y);
if(!A&&!B&&!C)ADJ(res+=z.unl);
edge[cnt].next=head[x],edge[cnt].to=y;
edge[cnt].all=z.whl,edge[cnt].par=z.top,head[x]=cnt++;
int zz=JUM(x,y);
lam[x]=1ll*lam[x]*isb[zz]%mod;
if(!A&&!B&&!C)ADJ(res+=mod-RED(sum[zz]+sub[zz]));
edge[cnt].next=head[y],edge[cnt].to=x;
edge[cnt].all=z.whl,edge[cnt].par=z.bot,head[y]=cnt++;
lam[y]=1ll*lam[y]*iut[y]%mod;
}
void ins(int x){
if(!tp){stk[++tp]=x,SZ++;return;}
int lca=LCA(stk[tp],x);
while(tp>=2&&dep[stk[tp-1]]>=dep[lca])aeh(stk[tp-1],stk[tp]),tp--;
if(dep[stk[tp]]>dep[lca])aeh(lca,stk[tp--]);
if(stk[tp]!=lca)stk[++tp]=lca,SZ++;
if(stk[tp]!=x)stk[++tp]=x,SZ++;
}
void fin(){while(tp>=2)aeh(stk[tp-1],stk[tp]),tp--;tp=0;}
#define iter for(int i=head[x],y;i!=-1;i=edge[i].next)if(!vis[y=edge[i].to])
void getroot(int x,int fa){
sz[x]=1,msz[x]=0;
iter if(y!=fa)getroot(y,x),sz[x]+=sz[y],msz[x]=max(msz[x],sz[y]);
msz[x]=max(msz[x],SZ-sz[x]);
if(msz[x]<msz[RT])RT=x;
}
void getsz(int x,int fa){sz[x]=1;iter if(y!=fa)getsz(y,x),sz[x]+=sz[y];}
#define trit for(int _a=0;_a<=A;_a++)for(int _b=0;_b<=B;_b++)for(int _c=0;_c<=C;_c++)
#define trin _a][_b][_c
int num,rec[N],lea[N],ALL[N],PAR[N],LAM[N];
void getdfn(int x,int fa){
int id=++num;
// printf("%d ",x);
rec[id]=col[x],LAM[id]=lam[x];
for(int i=head[x],y;i!=-1;i=edge[i].next){
if((y=edge[i].to)==fa)continue;
if(vis[edge[i].to])LAM[id]=1ll*LAM[id]*edge[i].par%mod;
else PAR[num+1]=edge[i].par,ALL[num+1]=edge[i].all,getdfn(y,x);
}
lea[id]=num;
}
void calc(int x){
getdfn(x,0);
for(int i=1;i<=num;i++)trit f[i][trin]=0;
f[1][rec[1]==a][rec[1]==b][rec[1]==c]=LAM[1];
for(int i=2;i<=num;i++){
trit{
int _A=_a+(rec[i]==a),_B=_b+(rec[i]==b),_C=_c+(rec[i]==c);
if(_A<=A&&_B<=B&&_C<=C)
f[i][_A][_B][_C]=
(1ll*f[i-1][trin]*ALL[i]%mod*LAM[i]+f[i][_A][_B][_C])%mod;
}
trit f[lea[i]][trin]=(1ll*f[i-1][trin]*PAR[i]+f[lea[i]][trin])%mod;
}
ADJ(res+=f[num][A][B][C]),num=0;
}
void solve(int x){
calc(x),getsz(x,0),vis[x]=true;
iter SZ=sz[y],RT=0,getroot(y,0),solve(RT);
}
void nil(int x,int fa){ADJ(res+=sum[x]);iter if(y!=fa)nil(y,x);}
void reset(int x,int fa){
for(int i=head[x],y;i!=-1;i=edge[i].next)if((y=edge[i].to)!=fa)reset(y,x);
lam[x]=ful[x],vis[x]=false,head[x]=-1;
}
void query(){
read(a),read(A),read(b),read(B),read(c),read(C);
if(A>cc[a].size()||B>cc[b].size()||C>cc[c].size()){putchar('0'),putchar('\n');return;}
for(auto x:cc[a])arr[++arr[0]]=x;
for(auto x:cc[b])arr[++arr[0]]=x;
for(auto x:cc[c])arr[++arr[0]]=x;
sort(arr+1,arr+arr[0]+1,[](int x,int y){return dfn[x]<dfn[y];});
if(arr[1]!=1)ins(1);for(int i=1;i<=arr[0];i++)ins(arr[i]);fin();
if(!A&&!B&&!C)nil(1,0);
getroot(1,0),solve(RT);
print(res),putchar('\n');
RT=SZ=res=cnt=0;for(int i=1;i<=arr[0];i++)arr[i]=0;arr[0]=0,reset(1,0);
}
}using namespace IMG;
int main(){
read(n),read(m),memset(head,-1,sizeof(head));
for(int i=1;i<=n;i++)read(col[i]),cc[col[i]].push_back(i);
for(int i=1,x,y;i<n;i++)read(x),read(y),aev(x,y);
dfs1(1,0),ful[1]=sub[1],dfs2(1,0),dfs3(1,0);
msz[0]=n;
for(int i=1;i<=n;i++)lam[i]=ful[i];
for(int i=1;i<=m;i++)query();
return 0;
}