题解 P6668【[清华集训2016] 连通子树】

· · 题解

首先先考虑单组询问。这时应该怎么处理?

显然是一个背包问题,我们设 f_{i,a,b,c} 表示 i 的子树中三种颜色各选了 a,b,c 个时的方案数,然后简单背包转移即可。

但是,合并两个背包的复杂度是 (abc)^2 的,无法通过。

怎么办呢?事实上,注意到这种背包往其中添加一个数的操作是 O(abc) 的,因此我们尝试构建一种方法,其中全程仅涉及到添加数,不涉及到两个背包间的卷积。

我们考虑对树淀粉质。考虑当前的分治树。

我们钦定分治树的树根必须在背包里。这就意味着,如果儿子选了,父亲就一定要被选。进而我们可以构思出如下 DP 方式:

深搜整棵树。当尝试进入一个儿子时,我们首先将父亲的 DP 数组拷贝一份给儿子,然后进入儿子递归,递归完儿子后再把父子的背包合并即可。

明显此时背包仅仅涉及到添加一个数,因而复杂度是 O(abcn\log n)

然后考虑回答多组询问。因为题目保证每种颜色的元素数都不会太多,故首先显然我们可以对关心的三种颜色建虚树。

建完虚树后,我们可以在虚树上做淀粉质,并且套用上述合并过程。

考虑合并过程中所有需要的值。

首先我们需要知道某棵子树中选择一个包含根的非空连通块的方案数。这可以令 sub_x=\prod\limits_{y\text{ is a son of }x}(sub_y+1) 得出。但是子树不一定是以 1 为根的子树,故我们需要再换个根得到 ful_x 表示以 x 为根时的答案。结合 subful,我们可以关于虚树中每个点找出其所有不包含任何虚树上点的子树中任意选择的方案数,记作 lam_x。因为 sub 的递推式表明其满足可减性,故 lam 可以直接由 ful 除去所有虚树上与之邻接的点对应子树的方案数得到。

然后考虑转移,这时我们就发现转移时要知道链的方案数:假如某个儿子不选,则链的方案就是链顶强制选、链底强制不选的方案数;选,就是链顶链底都选的方案数。因为淀粉质时每条链从上从下被考虑都是有可能的,所以对于每条在原树中从上往下的链,我们需要知道链底选而链顶不选的方案数(记作 bot)、链底不选而链顶选的方案数(记作 top)、链顶链底都选的方案数(记作 whl)。

列出式子就会发现当一条链的链顶与另一条链的链底相同时,这两条链可以运用上述值加以合并,需要知道的辅助值是相同点周围的点任意选择或不选择的方案数。于是可以倍增处理,每次合并两条链。

但是,如果把合并两条链时所要进行的式子全都列出来,我们最终甚至可以发现链间具有可减性,进而直接一个前缀和就解决了。

需要注意的是,当三种颜色需要的数量都为 0 时,此时选出的连通块可能不包含任何虚树中的点,也即其来自于某点在实树上的子树中,或者在一条链的中央。

前者可以在换根 DP 时顺手预处理掉(因为虚树以 1 为根,所以实树上子树必然不可能是父亲方向的子树,故只需要对 sub 在子树上求个和即可)。后者需要额外记一个 unl 表示链顶链底都不选的方案数,可以和 bot,top,whl 一样转移,且一样满足可减性。

时间复杂度为做到 n\log n+\sum abc(a+b+c)\log(a+b+c)

虽然但是,本人的实现方式常数很大,花了好久才卡过 LG 的评测,并始终无法在 UOJ 上通过。这里介绍一种优化常数的方法:

考虑不在递归时 DP,而是在递归时求出 dfs 序。这样,如果一个儿子被选就等价于用 f_i 转移到 f_{i+1},不被选就等价于用 f_i 转移到 f_{lea_{i+1}},其中 lea_{i+1} 意味着 i+1 所在子树中 dfs 序最大点的编号加一,这步的意义是跳过整棵 i+1 子树不选。

这种 DP 方式减少了很多拷贝数组的操作,故常数比起之前要小。

又臭又长的代码:

#include<bits/stdc++.h>
using namespace std;
const int N=100100;
const int LG=17;
typedef long long ll;
constexpr int mod=1e9+7;
int ksm(int x,int y=mod-2){int z=1;for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*z*x%mod;return z;}
void ADJ(int&x){if(x>=mod)x-=mod;}
int RED(int x){return x>=mod?x-mod:x;}
namespace FIFO{
char buf[1<<23],*p1=buf,*p2=buf;
#ifndef Troverld
#define gc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#else
#define gc() getchar()
#endif
template<typename T>void read(T&x){
    x=0;
    char c=gc();
    while(c>'9'||c<'0')c=gc();
    while(c>='0'&&c<='9')x=(x<<3)+(x<<1)+(c^48),c=gc();
}
template<typename T>void print(T x){if(x<=9)putchar('0'+x);else print(x/10),putchar('0'+x%10);}
}using namespace FIFO;
int n,m,col[N];
vector<int>cc[N];
namespace TRE{
struct Chain{
    int bot,top,whl,unl;
    Chain(int B,int T,int W,int U){bot=B,top=T,whl=W,unl=U;}
    Chain(){bot=top=unl=0,whl=1;}
    Chain REV(){return Chain(top,bot,whl,unl);}
//  void pr(){printf("[%d %d %d %d]",bot,top,whl,unl);}
}chn[N];
Chain merge(Chain u,Chain v,int val,int bon){
    return Chain(
    (1ll*u.whl*v.bot%mod*val+u.bot)%mod,
    (1ll*u.top*v.whl%mod*val+v.top)%mod,
    1ll*u.whl*v.whl%mod*val%mod,
    (1ll*u.top*v.bot%mod*val+u.unl+v.unl+bon)%mod
    );
}
Chain split(Chain u,Chain v,int val,int bon){
    int inv=ksm(1ll*v.whl*val%mod);
    int whl=1ll*u.whl*inv%mod;
    int bot=RED(u.bot+mod-1ll*whl*v.bot%mod*val%mod);
    int top=1ll*(u.top+mod-v.top)*inv%mod;
    int unl=RED(u.unl+mod-(1ll*top*v.bot%mod*val+v.unl+bon)%mod);
    return Chain(bot,top,whl,unl);
}
vector<int>v[N];
void aev(int x,int y){v[x].push_back(y),v[y].push_back(x);}
int sub[N],ful[N],anc[N][LG],dep[N],dfn[N],tot,sum[N],ult[N],isb[N],iut[N];
int pas(int x,int y){if(dep[x]>dep[y])swap(x,y);return 1ll*sub[x]*isb[y]%mod;}
void dfs1(int x,int fa){
    sub[x]=1,dep[x]=dep[fa]+1,dfn[x]=++tot;
    for(auto y:v[x])if(y!=fa)
        dfs1(y,x),sub[x]=1ll*sub[x]*(sub[y]+1)%mod,ADJ(sum[x]+=RED(sum[y]+sub[y]));
    isb[x]=ksm(sub[x]+1);
}
void dfs2(int x,int fa){
    for(auto y:v[x])if(y!=fa)
        ult[y]=1ll*ful[x]*isb[y]%mod,iut[y]=ksm(ult[y]+1),
        ful[y]=1ll*(ult[y]+1)*sub[y]%mod,dfs2(y,x);
}
void dfs3(int x,int fa){
    anc[x][0]=fa;
    for(int i=1;i<LG;i++)anc[x][i]=anc[anc[x][i-1]][i-1];
    for(auto y:v[x])if(y!=fa)
        chn[y]=merge(Chain(1,1,1,0),chn[x],pas(x,y),RED(sum[x]+mod-RED(sum[y]+sub[y]))),
        dfs3(y,x);
}
int JUM(int x,int y){for(int i=LG-1;i>=0;i--)if(dep[x]<dep[y]-(1<<i))y=anc[y][i];return y;}
Chain Path(int x,int y){
    int z=JUM(x,y);
    return split(chn[y],chn[x],pas(x,z),RED(sum[x]+mod-RED(sum[z]+sub[z])));
}
int LCA(int x,int y){
    if(dep[x]>dep[y])swap(x,y);
    for(int i=LG-1;i>=0;i--)if(dep[x]<=(dep[y]-(1<<i)))y=anc[y][i];
    if(x==y)return x;
    for(int i=LG-1;i>=0;i--)if(anc[x][i]!=anc[y][i])x=anc[x][i],y=anc[y][i];
    return anc[x][0];
}
}using namespace TRE;
namespace IMG{
int stk[N],tp,arr[N],head[N],lam[N],cnt;
struct node{int to,next,all,par;}edge[N<<1];
bool vis[N];
int f[N][6][6][6],RT,SZ,sz[N],msz[N];
int a,b,c,A,B,C,res;
void aeh(int x,int y){
    if(dep[x]>dep[y])swap(x,y);
    Chain z=Path(x,y);
//  printf("%d %d\n",x,y);
    if(!A&&!B&&!C)ADJ(res+=z.unl);
    edge[cnt].next=head[x],edge[cnt].to=y;
    edge[cnt].all=z.whl,edge[cnt].par=z.top,head[x]=cnt++;
    int zz=JUM(x,y);
    lam[x]=1ll*lam[x]*isb[zz]%mod;
    if(!A&&!B&&!C)ADJ(res+=mod-RED(sum[zz]+sub[zz]));
    edge[cnt].next=head[y],edge[cnt].to=x;
    edge[cnt].all=z.whl,edge[cnt].par=z.bot,head[y]=cnt++;
    lam[y]=1ll*lam[y]*iut[y]%mod;
}
void ins(int x){
    if(!tp){stk[++tp]=x,SZ++;return;}
    int lca=LCA(stk[tp],x);
    while(tp>=2&&dep[stk[tp-1]]>=dep[lca])aeh(stk[tp-1],stk[tp]),tp--;
    if(dep[stk[tp]]>dep[lca])aeh(lca,stk[tp--]);
    if(stk[tp]!=lca)stk[++tp]=lca,SZ++;
    if(stk[tp]!=x)stk[++tp]=x,SZ++;
}
void fin(){while(tp>=2)aeh(stk[tp-1],stk[tp]),tp--;tp=0;}
#define iter for(int i=head[x],y;i!=-1;i=edge[i].next)if(!vis[y=edge[i].to])
void getroot(int x,int fa){
    sz[x]=1,msz[x]=0;
    iter if(y!=fa)getroot(y,x),sz[x]+=sz[y],msz[x]=max(msz[x],sz[y]);
    msz[x]=max(msz[x],SZ-sz[x]);
    if(msz[x]<msz[RT])RT=x;
}
void getsz(int x,int fa){sz[x]=1;iter if(y!=fa)getsz(y,x),sz[x]+=sz[y];}
#define trit for(int _a=0;_a<=A;_a++)for(int _b=0;_b<=B;_b++)for(int _c=0;_c<=C;_c++)
#define trin _a][_b][_c
int num,rec[N],lea[N],ALL[N],PAR[N],LAM[N];
void getdfn(int x,int fa){
    int id=++num;
//  printf("%d ",x);
    rec[id]=col[x],LAM[id]=lam[x];
    for(int i=head[x],y;i!=-1;i=edge[i].next){
        if((y=edge[i].to)==fa)continue;
        if(vis[edge[i].to])LAM[id]=1ll*LAM[id]*edge[i].par%mod;
        else PAR[num+1]=edge[i].par,ALL[num+1]=edge[i].all,getdfn(y,x);
    }
    lea[id]=num;
}
void calc(int x){
    getdfn(x,0);
    for(int i=1;i<=num;i++)trit f[i][trin]=0;
    f[1][rec[1]==a][rec[1]==b][rec[1]==c]=LAM[1];
    for(int i=2;i<=num;i++){
        trit{
            int _A=_a+(rec[i]==a),_B=_b+(rec[i]==b),_C=_c+(rec[i]==c);
            if(_A<=A&&_B<=B&&_C<=C)
                f[i][_A][_B][_C]=
                    (1ll*f[i-1][trin]*ALL[i]%mod*LAM[i]+f[i][_A][_B][_C])%mod;
        }
        trit f[lea[i]][trin]=(1ll*f[i-1][trin]*PAR[i]+f[lea[i]][trin])%mod;
    }
    ADJ(res+=f[num][A][B][C]),num=0;
}
void solve(int x){
    calc(x),getsz(x,0),vis[x]=true;
    iter SZ=sz[y],RT=0,getroot(y,0),solve(RT);
}
void nil(int x,int fa){ADJ(res+=sum[x]);iter if(y!=fa)nil(y,x);}
void reset(int x,int fa){
    for(int i=head[x],y;i!=-1;i=edge[i].next)if((y=edge[i].to)!=fa)reset(y,x);
    lam[x]=ful[x],vis[x]=false,head[x]=-1;
}
void query(){
    read(a),read(A),read(b),read(B),read(c),read(C);
    if(A>cc[a].size()||B>cc[b].size()||C>cc[c].size()){putchar('0'),putchar('\n');return;}
    for(auto x:cc[a])arr[++arr[0]]=x;
    for(auto x:cc[b])arr[++arr[0]]=x;
    for(auto x:cc[c])arr[++arr[0]]=x;
    sort(arr+1,arr+arr[0]+1,[](int x,int y){return dfn[x]<dfn[y];});
    if(arr[1]!=1)ins(1);for(int i=1;i<=arr[0];i++)ins(arr[i]);fin();
    if(!A&&!B&&!C)nil(1,0);
    getroot(1,0),solve(RT);
    print(res),putchar('\n');
    RT=SZ=res=cnt=0;for(int i=1;i<=arr[0];i++)arr[i]=0;arr[0]=0,reset(1,0);
}
}using namespace IMG;
int main(){
    read(n),read(m),memset(head,-1,sizeof(head));
    for(int i=1;i<=n;i++)read(col[i]),cc[col[i]].push_back(i);
    for(int i=1,x,y;i<n;i++)read(x),read(y),aev(x,y);
    dfs1(1,0),ful[1]=sub[1],dfs2(1,0),dfs3(1,0);
    msz[0]=n;
    for(int i=1;i<=n;i++)lam[i]=ful[i];
    for(int i=1;i<=m;i++)query();
    return 0;
}