题解:P11363 [NOIP2024] 树的遍历
一个关键性质是,一个点连接的所有边在最终的树上会形成一条链,且链与链之间的结构不互相影响。注意这并不意味着你可以把每个点的合法链方案乘起来,因为一个方案是否合法与你选定的起点有关。博主因为这个问题虚空调试 1h 并最终直接导致爆炸。
对于
对于
对于
不难发现,对于一棵这样的树,他会在初始时被算到
对于第二部分的计算,我们可以设
-
f_{i,0}=\prod f_{j,0}\times (deg_i-1)! - 对于有一个边的情况,考察
(i,j) 是否是关键边:- 若不是,则
f_{i,1}\gets f_{j,1}\times \prod_{v\neq j}f_{v,0}\times (deg_i-2)! - 否则,由于我们要求两个关键边之间没有其他边,则选定的关键边只能是
(i,j) ,有f_{i,1}\gets \prod f_{j,0}\times (deg_i-2)! 。
- 若不是,则
- 对于有两个边的情况,分下面几种情况讨论:
- 在某个儿子内就已经完成了合并:
f_{i,2}\gets f_{j,2}\times \prod_{v\neq j}f_{v,0}\times (deg_i-1)! 。 - 两个不同的子树内各选了一个:此时另设
g_{0/1/2} 表示选了多少个的方案数,根据每条边是否是关键边以及选不选讨论,做背包合并即可,有f_{i,2}\gets g_2 。 - 选了某个关键边
(i,j) ,另一个在j 的子树内:此时有f_{i,2}\gets f_{j,1}\times \prod_{v\neq j}f_{v,0}\times (deg_i-1)! 。
- 在某个儿子内就已经完成了合并:
通过预处理
#include<bits/stdc++.h>
#define rep(i,j,k) for(int i=j;i<=k;i++)
#define repp(i,j,k) for(int i=j;i>=k;i--)
#define pii pair<int,int>
#define mp make_pair
#define fir first
#define sec second
#define ls(x) (x<<1)
#define rs(x) ((x<<1)|1)
#define lowbit(i) (i&-i)
#define int long long
#define qingbai 666
using namespace std;
typedef long long ll;
const int N=1e5+5,inf=(ll)1e18+7,mo=1e9+7;
void read(int &p){
int w=1,x=0;
char ch=getchar();
while(!isdigit(ch)){
if(ch=='-')w=-1;
ch=getchar();
}
while(isdigit(ch)){
x=(x<<1)+(x<<3)+ch-'0';
ch=getchar();
}
p=w*x;
}
int T;
int n,m;
vector<pii>e[N];
int deg[N],jc[N],qj[N],f[N][3],g[3];
int quick_power(int base,int x){
int res=1;
while(x){
if(x&1)res*=base,res%=mo;
base*=base,base%=mo;
x>>=1;
}
return res;
}
bool imp[N];
void dfs(int x,int p){
int prod=1,sum2=0;
for(auto j:e[x])
if(j.fir!=p)dfs(j.fir,x),prod*=f[j.fir][0],prod%=mo;
rep(i,0,2)
f[x][i]=g[i]=0;
g[0]=1;
f[x][0]=prod*jc[deg[x]-1]%mo;
for(auto j:e[x]){
if(j.fir==p)continue;
int inv0=quick_power(f[j.fir][0],mo-2);
if(deg[x]>=2){
if(imp[j.sec])f[x][1]++,f[x][1]%=mo;
else f[x][1]+=f[j.fir][1]*inv0%mo,f[x][1]%=mo;
}
repp(k,2,0){
g[k]=g[k]*f[j.fir][0]%mo;
if(k){
if(imp[j.sec])g[k]+=g[k-1]*f[j.fir][0]%mo;
else g[k]+=g[k-1]*f[j.fir][1]%mo;
}
g[k]%=mo;
}
if(imp[j.sec])f[x][2]+=f[j.fir][1]*inv0%mo,f[x][2]%=mo;
sum2+=f[j.fir][2]*inv0%mo,sum2%=mo;
}
f[x][2]*=prod*jc[deg[x]-1]%mo,f[x][2]%=mo;
if(deg[x]>=2){
f[x][1]*=prod*jc[deg[x]-2]%mo,f[x][1]%=mo;
f[x][2]+=g[2]*jc[deg[x]-2]%mo;
}
if(e[x].size()>=2||x==1)f[x][2]+=sum2*prod%mo*jc[deg[x]-1],f[x][2]%=mo;
}
void solve(){
read(n),read(m);
rep(i,1,n)
deg[i]=0,e[i].clear(),imp[i]=0;
rep(i,1,n-1){
int x,y;
read(x),read(y);
e[x].push_back(mp(y,i)),e[y].push_back(mp(x,i));
deg[x]++,deg[y]++;
}
rep(i,1,m){
int x;
read(x),imp[x]=1;
}
int ans=1;
rep(i,1,n)
ans*=jc[deg[i]-1],ans%=mo;
ans*=m,ans%=mo;
dfs(1,0);
ans=ans+mo-f[1][2],ans%=mo;
printf("%lld\n",ans);
}
int cid;
signed main(){
jc[0]=qj[0]=1;
rep(i,1,100000)
jc[i]=jc[i-1]*i%mo,qj[i]=quick_power(jc[i],mo-2);
read(cid),read(T);
while(T--)
solve();
return 0;
}