P11363 [NOIP2024] 树的遍历 题解
_fairytale_ · · 题解
前言
来点魔怔点分治做法。
考场上最后五分钟胡出
观察性质
首先
接下来我们容斥,用
我们先尝试研究,新图的 dfs 树有什么样的性质。
性质 1:原树中交于同一点的边在 dfs 树中必然是一条连续的链。
这是因为,如果 dfs 过程中从当前蓝边进入了其他子树,那么不可能通过它们回到蓝边,所以只能把其他蓝边作为当前蓝边在 dfs 树上的儿子继续遍历。
我们记这种链为“特殊链”。
这个性质告诉我们:一条边可以作为根生成某棵 dfs 树的充要条件是这条边是所有包含它的特殊链的端点;在交于同一点的边中,最多有两条可以同时作为根生成同一棵 dfs 树。
性质 2:若两条边可以作为根生成同一棵 dfs 树,那么原树中这两条边之间的所有边都可以作为根生成这棵 dfs 树。
这是因为,当蓝边作为根时,那么必然从蓝边进入
因此,我们可以总结出,如果一个边集中的边可以同时作为根生成某棵 dfs 树,那么它们在原树中的虚树必须是一条链。因为假如不是,虚树中必定存在一个
\mathcal O(nk^2) 的做法
我们由此得到了一个
我们枚举一对关键边,满足它们之间没有其它关键边,计算同时以这一对边为根的 dfs 树的数量,然后减掉。这样做的正确性在于,考虑一棵 dfs 树,它能被
如何计算同时以某对边为根的 dfs 树的数量?设在这对边之间的点(不包括两个相距最远的点)构成的点集为
优化
我觉得 DP 写起来很麻烦啊!!!注意到我们要计算的问题是树上满足条件的路径的贡献之和,不难想到点分治。
考虑把当前重心
设
- 若
u 为端点,则(u,v) 需要为关键边,另一条关键边需要出现在v 的子树中,可以通过一次 dfs 求出。 - 若
u 不为端点,则两条关键边分属不同的子树,我们只需计算f_v 表示u 到v 子树中所有满足以出现过的第一条关键边结束的路径中点的\dfrac{1}{deg-1} 之和,同时维护f 的前缀和pre ,每次让sum+=f[v]*pre,pre+=f[v]即可。
时间复杂度
#include<bits/stdc++.h>
bool Mst;
#define ll long long
#define pii pair<int,int>
#define fi first
#define se second
#define pb push_back
#define rep(x,qwq,qaq) for(int x=(qwq);x<=(qaq);++x)
#define per(x,qwq,qaq) for(int x=(qwq);x>=(qaq);--x)
using namespace std;
#define m107 1000000007
template<class _T>
void ckmax(_T &x,_T y) {
x=max(x,y);
}
int ri() {
int x;
cin>>x;
return x;
}
#define inf 0x3f3f3f3f
#define mod m107
template<int MOD>
struct modint{
int x;
modint(){x=0;}
template<typename T>
int norm(T y){return (y%MOD+MOD)%MOD;}
template<typename T>
modint(T y){x=norm(y);}
friend modint operator +(modint a,modint b){return a.x+b.x;}
friend modint operator -(modint a,modint b){return a.x-b.x;}
friend modint operator *(modint a,modint b){return 1ll*a.x*b.x;}
modint& operator +=(modint b){return x=norm(x+b.x),*this;}
modint& operator -=(modint b){return x=norm(x-b.x),*this;}
modint& operator *=(modint b){return x=norm(1ll*x*b.x),*this;}
friend istream& operator >>(istream&is,modint &x){
ll v;
return is>>v,x.x=norm(v),is;
}
friend ostream& operator <<(ostream&os,modint &x){
return os<<x.x,os;
}
};
using mint=modint<m107>;
mint inv(int v) {
auto exgcd=[&](auto &exgcd,int a,int b,int &x,int &y)->int {
if(b==0)return x=1,y=0,a;
int g=exgcd(exgcd,b,a%b,y,x);
y-=a/b*x;
return g;
};
int x,y,g=exgcd(exgcd,v,mod,x,y);
assert(g==1);
return x;
}
mint fac[101000],iv[101000];
void Solve_() {
int n,k;
cin>>n>>k;
vector<int>deg(n+1);
vector<pii>G(n+n+2);
vector<int>bg(n+1),ed(n+1);
{
vector<vector<pii>>g(n+1);
for(int i=1,u,v; i<=n-1; ++i) {
cin>>u>>v;
g[u].pb({v,i}),g[v].pb({u,i});
++deg[u],++deg[v];
}
int cnt=0;
rep(i,1,n) {
bg[i]=cnt+1,ed[i]=bg[i]+deg[i]-1;
for(pii t:g[i])G[++cnt]=t;
}
}
vector<int>e(n+1);
rep(i,1,k)e[ri()]=1;
mint P=1,ans=k;
rep(i,1,n)P*=fac[deg[i]-1];
ans*=P;
vector<int>siz(n+1),vis(n+1);
int rt,maxp;
auto getrt=[&](auto &getrt,int u,int in,int s,int &maxp)->void {
siz[u]=1;
int nw=0;
rep(j,bg[u],ed[u]) {
int v=G[j].fi,i=G[j].se;
if(i==in||vis[v])continue;
getrt(getrt,v,i,s,maxp);
siz[u]+=siz[v];
ckmax(nw,siz[v]);
}
ckmax(nw,s-siz[u]);
if(nw<maxp)rt=u,maxp=nw;
};
auto calc=[&](int u,int in)->mint {
mint res=0;
auto dfs=[&](auto &dfs,int u,int in,mint prd)->void{
if(e[in])return void(res+=prd);
prd*=iv[deg[u]-1];
rep(j,bg[u],ed[u]) {
int v=G[j].fi,i=G[j].se;
if(i==in||vis[v])continue;
dfs(dfs,v,i,prd);
}
};
dfs(dfs,u,in,1);
return res;
};
auto calc2=[&](int u)->mint {
mint res=0;
auto dfs=[&](auto &dfs,int u,int in,mint prd)->void{
if(e[in])return void(res+=prd);
prd*=iv[deg[u]-1];
rep(j,bg[u],ed[u]) {
int v=G[j].fi,i=G[j].se;
if(i==in||vis[v])continue;
dfs(dfs,v,i,prd);
}
};
dfs(dfs,u,0,1);
return res;
};
auto solve=[&](auto &solve,int u)->void {
getrt(getrt,u,0,siz[u],maxp=inf);
vis[u=rt]=1;
mint sum=0,pre=0;
rep(j,bg[u],ed[u]) {
int v=G[j].fi,i=G[j].se;
if(vis[v])continue;
mint f=calc(v,i);
sum+=pre*f*iv[deg[u]-1];
pre+=f;
if(e[i])sum+=calc2(v);
}
ans-=P*sum;
rep(j,bg[u],ed[u]) {
int v=G[j].fi;
if(vis[v])continue;
solve(solve,v);
}
};
siz[1]=n;
solve(solve,1);
cout<<ans<<'\n';
}
bool Med;
signed main() {
fac[0]=1;
rep(i,1,100000)fac[i]=fac[i-1]*i,iv[i]=inv(i);
cerr<<(&Mst-&Med)/1024.0/1024.0<<" MB\n";
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
int c,Testcases=1;
cin>>c>>Testcases;
while(Testcases--)Solve_();
return 0;
}