非常好 D2T2

· · 题解

显然我们只需要考虑每个叶子节点 u 新图上距离为 1 的点 S_u。原题即要求每个叶子 u,vS_u,S_v 的交非空。

根据经典的“点减边”容斥,我们只需要算出“钦定经过某个点的方案树和”减去“钦定经过某个边的方案树”即可。 考虑怎么算经过 $u$ 的方案数。我们只需要关心叶子。我们以 $u$ 为根,考虑容斥,钦定若干叶子不能和 $u$ 联通。dp状态是容易的,假设 $f_i$ 表示当前已经钦定 $i$ 个叶子不能和 $u$ 联通,那么考虑 $u$ 和 $v$ 合并。假设 $siz,lef$ 表示当前的子树大小,子树叶子个数,转移: $$ tf_{i+j}\leftarrow f_i\times \binom{lef_v}{j}\times 2^{(siz_u-i)\times(siz_v-j)} $$ $tf$ 是转移后的 $f$ 数组。 这样子单个 $u$ 复杂度已经是 $\mathcal O(n^2)$ 了,总复杂度是 $\mathcal O(n^3)$。考虑优化。 注意到树上背包卷积复杂度只有 $\mathcal O(n^2)$,考虑只容斥子树内元素。对于子树外元素钦定必须经过,这样子是可以直接算答案的,由于不需要卷积子树外的元素,复杂度只有 $\mathcal O(n^2)$。 ```cpp #include<bits/stdc++.h> using namespace std; #define ll long long #define MP make_pair mt19937 rnd(time(0)); const int MAXN=3e3+5; const int MOD=998244353; void add(ll &x,ll y){x=(x+y%MOD+MOD)%MOD;} ll ksm(ll a,int b){ll r=1;while(b){if(b&1)r=r*a%MOD;a=a*a%MOD,b>>=1;}return r;} int n; ll fac[MAXN],inf[MAXN],ans; vector<int> edg[MAXN]; ll f[MAXN],g[MAXN]; int lef[MAXN],siz[MAXN],rt=1,sl=0; ll C(int x,int y){return fac[y]*inf[x]%MOD*inf[y-x]%MOD;} void dfs(int u,int fa){ if(edg[u].size()==1){ lef[u]=siz[u]=1; return; } for(int v:edg[u]) if(v!=fa) dfs(v,u); memset(f,0,sizeof(f)); siz[u]=f[0]=1; int in=0; for(int v:edg[u]) if(v!=fa){ memset(g,0,sizeof(g)); for(int i=0;i<=lef[u];i++) for(int j=0;j<=lef[v];j++) add(g[i+j],f[i]*(j&1?MOD-1:1)%MOD*C(j,lef[v])%MOD*ksm(2,(siz[u]-i)*(siz[v]-j))%MOD); swap(f,g); siz[u]+=siz[v];lef[u]+=lef[v];in+=siz[v]*(siz[v]-1)/2; } int os=n-siz[u],of=sl-lef[u];in+=os*(os-1)/2; for(int i=0;i<=lef[u];i++) add(ans,f[i]*ksm(2,in)%MOD*ksm(2,(os-of)*(siz[u]-i))%MOD*ksm(ksm(2,siz[u]-i)-1,of)%MOD); if(fa){ in=os*(os-1)/2+siz[u]*(siz[u]-1)/2; for(int i=0;i<=lef[u];i++) add(ans,(i&1?1:MOD-1)*C(i,lef[u])%MOD*ksm(2,in)%MOD*ksm(2,(siz[u]-i)*(os-of))%MOD*ksm(ksm(2,siz[u]-i)-1,of)%MOD); } } int main(){ ios::sync_with_stdio(false); fac[0]=inf[0]=1; for(int i=1;i<MAXN;i++) inf[i]=ksm(fac[i]=fac[i-1]*i%MOD,MOD-2); cin>>n; for(int i=1;i<n;i++){ int u,v;cin>>u>>v; edg[u].push_back(v); edg[v].push_back(u); } if(n==2) return cout<<1<<endl,0; while(edg[rt].size()==1) rt++; for(int i=1;i<=n;i++) if(edg[i].size()==1) sl++; dfs(rt,0); cout<<ans<<endl; return 0; } ```