非常好 D2T2
Otomachi_Una_
·
·
题解
显然我们只需要考虑每个叶子节点 u 新图上距离为 1 的点 S_u。原题即要求每个叶子 u,v,S_u,S_v 的交非空。
根据经典的“点减边”容斥,我们只需要算出“钦定经过某个点的方案树和”减去“钦定经过某个边的方案树”即可。
考虑怎么算经过 $u$ 的方案数。我们只需要关心叶子。我们以 $u$ 为根,考虑容斥,钦定若干叶子不能和 $u$ 联通。dp状态是容易的,假设 $f_i$ 表示当前已经钦定 $i$ 个叶子不能和 $u$ 联通,那么考虑 $u$ 和 $v$ 合并。假设 $siz,lef$ 表示当前的子树大小,子树叶子个数,转移:
$$
tf_{i+j}\leftarrow f_i\times \binom{lef_v}{j}\times 2^{(siz_u-i)\times(siz_v-j)}
$$
$tf$ 是转移后的 $f$ 数组。
这样子单个 $u$ 复杂度已经是 $\mathcal O(n^2)$ 了,总复杂度是 $\mathcal O(n^3)$。考虑优化。
注意到树上背包卷积复杂度只有 $\mathcal O(n^2)$,考虑只容斥子树内元素。对于子树外元素钦定必须经过,这样子是可以直接算答案的,由于不需要卷积子树外的元素,复杂度只有 $\mathcal O(n^2)$。
```cpp
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define MP make_pair
mt19937 rnd(time(0));
const int MAXN=3e3+5;
const int MOD=998244353;
void add(ll &x,ll y){x=(x+y%MOD+MOD)%MOD;}
ll ksm(ll a,int b){ll r=1;while(b){if(b&1)r=r*a%MOD;a=a*a%MOD,b>>=1;}return r;}
int n;
ll fac[MAXN],inf[MAXN],ans;
vector<int> edg[MAXN];
ll f[MAXN],g[MAXN];
int lef[MAXN],siz[MAXN],rt=1,sl=0;
ll C(int x,int y){return fac[y]*inf[x]%MOD*inf[y-x]%MOD;}
void dfs(int u,int fa){
if(edg[u].size()==1){
lef[u]=siz[u]=1;
return;
}
for(int v:edg[u]) if(v!=fa) dfs(v,u);
memset(f,0,sizeof(f));
siz[u]=f[0]=1;
int in=0;
for(int v:edg[u]) if(v!=fa){
memset(g,0,sizeof(g));
for(int i=0;i<=lef[u];i++) for(int j=0;j<=lef[v];j++)
add(g[i+j],f[i]*(j&1?MOD-1:1)%MOD*C(j,lef[v])%MOD*ksm(2,(siz[u]-i)*(siz[v]-j))%MOD);
swap(f,g);
siz[u]+=siz[v];lef[u]+=lef[v];in+=siz[v]*(siz[v]-1)/2;
}
int os=n-siz[u],of=sl-lef[u];in+=os*(os-1)/2;
for(int i=0;i<=lef[u];i++) add(ans,f[i]*ksm(2,in)%MOD*ksm(2,(os-of)*(siz[u]-i))%MOD*ksm(ksm(2,siz[u]-i)-1,of)%MOD);
if(fa){
in=os*(os-1)/2+siz[u]*(siz[u]-1)/2;
for(int i=0;i<=lef[u];i++) add(ans,(i&1?1:MOD-1)*C(i,lef[u])%MOD*ksm(2,in)%MOD*ksm(2,(siz[u]-i)*(os-of))%MOD*ksm(ksm(2,siz[u]-i)-1,of)%MOD);
}
}
int main(){
ios::sync_with_stdio(false);
fac[0]=inf[0]=1;
for(int i=1;i<MAXN;i++) inf[i]=ksm(fac[i]=fac[i-1]*i%MOD,MOD-2);
cin>>n;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
edg[u].push_back(v);
edg[v].push_back(u);
}
if(n==2) return cout<<1<<endl,0;
while(edg[rt].size()==1) rt++;
for(int i=1;i<=n;i++) if(edg[i].size()==1) sl++;
dfs(rt,0);
cout<<ans<<endl;
return 0;
}
```