题解:P11363 [NOIP2024] 树的遍历

· · 题解

一个关键性质是,一个点连接的所有边在最终的树上会形成一条链,且链与链之间的结构不互相影响。注意这并不意味着你可以把每个点的合法链方案乘起来,因为一个方案是否合法与你选定的起点有关。博主因为这个问题虚空调试 1h 并最终直接导致爆炸。

对于 k=1 的情况,我们发现链的端点一定是来源方向的边,其他边任意。方案数就显然了:\prod_{i=1}^n(deg_i-1)!

对于 k=2 的情况,假设两个关键边为 x,y,先分别计算以每个边为起点的树的数量和(其实就是 k=1 的答案乘 2),然后我们需要考虑哪些树是算重的。由于我们要求链的起点是来源方向的边,因此 x,y 之间的所有点的链的两端都固定了;其它点由于两个关键边对应点上的同一条边,答案仍然不变。于是算重的树的个数为 \prod_{i\in S}(deg_i-2)!\prod_{i\notin S}(deg_i-1)!,其中 Sx,y 之间的点的集合。

对于 k 更大的情况,我们发现,若一棵树被边集 S 里的关键边同时算重,则 S 内的边一定形成在同一条链内。否则一定会存在一个点,要求其至少三条邻边都在链的端点。这是不合法的。基于这个性质,我们只需要枚举任意两个关键边,使得它们之间没有关键边,然后计算同时被这两个关键边算到的方案数,从总答案中减去这些方案即可。

不难发现,对于一棵这样的树,他会在初始时被算到 |S| 次,然后被减去 |S|-1 次,故这样的容斥是正确的。

对于第二部分的计算,我们可以设 f_{i,0/1/2} 表示 i 子树内无选定的关键边;有一个选定的关键边;有两个选定的关键边的方案数。设 ji 的儿子转移如下:

通过预处理 \prod f_{j,0} 可以做到 O(\log n),瓶颈在求逆元。总复杂度 O(n\log n)

#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=j;i<=k;i++)
#define repp(i,j,k) for(int i=j;i>=k;i--)
#define pii pair<int,int>
#define mp make_pair
#define fir first
#define sec second
#define ls(x) (x<<1)
#define rs(x) ((x<<1)|1)
#define lowbit(i) (i&-i)
#define int long long
#define qingbai 666
using namespace std;
typedef long long ll;
const int N=1e5+5,inf=(ll)1e18+7,mo=1e9+7;
void read(int &p){
    int w=1,x=0;
    char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-')w=-1;
        ch=getchar();
    }
    while(isdigit(ch)){
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    p=w*x;
}
int T;
int n,m;
vector<pii>e[N];
int deg[N],jc[N],qj[N],f[N][3],g[3];
int quick_power(int base,int x){
    int res=1;
    while(x){
        if(x&1)res*=base,res%=mo;
        base*=base,base%=mo;
        x>>=1;
    }
    return res;
}
bool imp[N];
void dfs(int x,int p){
    int prod=1,sum2=0;
    for(auto j:e[x])
        if(j.fir!=p)dfs(j.fir,x),prod*=f[j.fir][0],prod%=mo;
    rep(i,0,2)
        f[x][i]=g[i]=0;
    g[0]=1;
    f[x][0]=prod*jc[deg[x]-1]%mo;
    for(auto j:e[x]){
        if(j.fir==p)continue;
        int inv0=quick_power(f[j.fir][0],mo-2);
        if(deg[x]>=2){
            if(imp[j.sec])f[x][1]++,f[x][1]%=mo;
            else f[x][1]+=f[j.fir][1]*inv0%mo,f[x][1]%=mo;
        }
        repp(k,2,0){
            g[k]=g[k]*f[j.fir][0]%mo;
            if(k){
                if(imp[j.sec])g[k]+=g[k-1]*f[j.fir][0]%mo;
                else g[k]+=g[k-1]*f[j.fir][1]%mo;
            }
            g[k]%=mo;
        }
        if(imp[j.sec])f[x][2]+=f[j.fir][1]*inv0%mo,f[x][2]%=mo;
        sum2+=f[j.fir][2]*inv0%mo,sum2%=mo;
    }
    f[x][2]*=prod*jc[deg[x]-1]%mo,f[x][2]%=mo;
    if(deg[x]>=2){
        f[x][1]*=prod*jc[deg[x]-2]%mo,f[x][1]%=mo;
        f[x][2]+=g[2]*jc[deg[x]-2]%mo;
    }
    if(e[x].size()>=2||x==1)f[x][2]+=sum2*prod%mo*jc[deg[x]-1],f[x][2]%=mo;
}
void solve(){
    read(n),read(m);
    rep(i,1,n)
        deg[i]=0,e[i].clear(),imp[i]=0;
    rep(i,1,n-1){
        int x,y;
        read(x),read(y);
        e[x].push_back(mp(y,i)),e[y].push_back(mp(x,i));
        deg[x]++,deg[y]++;
    }
    rep(i,1,m){
        int x;
        read(x),imp[x]=1;
    }
    int ans=1;
    rep(i,1,n)
        ans*=jc[deg[i]-1],ans%=mo;
    ans*=m,ans%=mo;
    dfs(1,0);
    ans=ans+mo-f[1][2],ans%=mo;
    printf("%lld\n",ans);
}
int cid;
signed main(){
    jc[0]=qj[0]=1;
    rep(i,1,100000)
        jc[i]=jc[i-1]*i%mo,qj[i]=quick_power(jc[i],mo-2);
    read(cid),read(T);
    while(T--)
        solve();
    return 0;
}