#loj2983 / P5206 WC2019 数树
shadowice1984 · · 题解
技不如人肝败吓疯
orzrqy
前置芝士:多项式exp
不会exp的出门左转你站膜板区自行学习
前置芝士:生成函数,指数型生成函数
如果对ogf和egf不是很熟悉的话可以去找本具体数学翻翻/度娘/翻rqy博客找题做来学习
前置芝士:矩阵树定理/行列式
如果对矩阵树定理和行列式不太了解的话可以自行度娘或者扒别人博客了解一下
前置芝士:容斥原理
请保证您对容斥原理有一定的认识不然可能无法理解本题解当中的一些式子
本题题解
我们将题目的意思重新描述一下就是这样的
定义一个树的pair
其中
那么子任务0就是给出
子任务1就是给出
子任务2就是求所有的
然后我们发现权值函数是
我们不妨改写一下权值函数
子任务1
直接按照题意模拟一下,数一数那些边出现了两次就行了
子任务2
我们发现权值函数里有个取交集,这引起了我们的极度不适
我们考虑把这个东西给容斥掉,怎么做呢?
我们定义
那么答案就是
但是我们发现
我们定义
你可以将
那么我们可以搞出来一个非常诡异的容斥式子
这个式子出人意料的简单,那么它是怎么推出来的呢?
其实我们将
那么我们将
整理一下就是
那么这个式子就十分好理解了,我们计数的对象是有n个点的树T2,并且T2含有S中的边,限制是这个树和T1的交集恰好是S,换句话讲限制就是有些边必须不能出现
根据传统的容斥套路,我们枚举一个限制的集合,打破这个集合里的所有限制,最后配上正负1的容斥系数就能get到正确的答案
那么上面的式子其实就是在枚举那些不能出现的边出现了,然后搞了一个容斥出来
这样我们的容斥式子就成立了
接下来就是喜闻乐见的交换求和号时间~
考虑每个
补上一些1我们就可以凑出来一个二项式定理出来
然后我们把
我们令
那么我们就得到了这样的一个式子
wow,看来这个式子出人意料的简洁,不过到此为止我们的算法还是指数级的,接下来我们需要搞点trick让我们的算法变成多项式级别的
我们考虑一下如何将
假设
证明的话我们采取矩阵树定理+手玩行列式的方法来证明它
当然我们需要先构造出一张图来,具体点来讲我们将第i个联通块当中的点缩成一个点i,然后我们在点i和点j之间连接
那么这样建出来图的基尔霍夫矩阵应该长这样
然后我们删去一行一列之后会变成这样
接下来我们将第
然后我们将第2列到第k-1列加到第1列上,会得到
接下来我们用第1行去减其他的行,会得到
这样矩阵就被我们削成了一个上三角阵,它的行列式是
这样我们就证明了我们的式子是正确的
重新把答案的式子写一下就是
然后我们把n和P分别配到
我们令
一个简单的实现方式是直接dp,令
直接实现的复杂度是
看起来我们需要胡乱优化一波
那么我们考虑每一个dp数组对应的生成函数,我们把点u的生成函数记做
那么我们合并
而我们最后求的答案其实就是
那么不妨设
那么我们考虑合并了一个子树的时候
看起来这个式子十分的简单,如果我们设
边界条件是
然后就可以愉悦的一遍dfs算出答案了~,最后我们需要输出
子任务3
好了我们发现subtask2的容斥可以原封不动的拿来用,这里我们我们发现
因此我们的答案就是
注意这里的T枚举的是n个点所有森林的边集
还是老套路我们把
此时我们把
此时我们设
一个大小为
一个森林的权值为里面所有树的权值之积
求所有森林的权值之和
这是一个十分传统的问题,告诉你一个无向图的权值为这个无向图中所有连通块的权值之积,求所有大小为
那么我们构造
那么我们会得到一个相当令人吃鲸的结论
对,这就是多项式exp和ln的组合意义,他们本质上是带标号的多重背包和对应的反演
至于这两个式子为什么是对的可以将exp暴力泰勒展开一下就会发现exp本来的定义加上指数型函数刚好凑出了一个多重集合的排列,从而分配好了标号,这里就不在泰勒展开了
那么我们现在要做的就是把多项式
做一个多项式exp然后取第n项,乘上
上代码~
#include<cstdio>
#include<algorithm>
#include<map>
using namespace std;const int N=262144+10;typedef long long ll;const ll mod=998244353;
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
int n;ll bas;int op;
namespace solve0
{
map <int,int> mp[N];int tot;
inline void mainsolve()
{
for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),mp[u][v]++,mp[v][u]++;
for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),mp[u][v]++,mp[v][u]++;
tot=n;for(int i=1;i<=n;i++)
for(map <int,int>:: iterator it=mp[i].begin();it!=mp[i].end();++it)
if(it->first<i&&it->second==2)tot--;
printf("%lld",po(bas,tot));
}
}
namespace solve1
{
int v[N<<1];int x[N<<1];int ct;int al[N];ll f[N];ll g[N];ll p;ll k;
inline void add(int u,int V){v[++ct]=V;x[ct]=al[u];al[u]=ct;}
inline void dfs(int u,int fa)
{
f[u]=1;g[u]=k;
for(int i=al[u];i;i=x[i])
if(v[i]!=fa)
{
dfs(v[i],u);ll ret=(f[v[i]]+g[v[i]])%mod;
g[u]=(g[u]*ret+g[v[i]]*f[u])%mod;
(f[u]*=ret)%=mod;
}
}
inline void mainsolve()
{
if(bas==1){printf("%lld",po(n,n-2));return;}
p=(po(bas,mod-2)+mod-1)%mod;k=n*po(p,mod-2)%mod;
for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),add(u,v),add(v,u);dfs(1,0);
printf("%lld",g[1]*po(p*bas%mod,n)%mod*po((ll)n*n%mod,mod-2)%mod);
}
}
namespace solve2
{
int rv[22][N];ll rt[2][22];ll tr[N];ll tr1[N];ll inv[N];ll fac[N];ll ifac[N];
ll f[N];ll g[N];
inline void pre()
{
for(int d=1;d<=18;d++)
for(int i=1;i<(1<<d);i++)rv[d][i]=(rv[d][i>>1]>>1)|((i&1)<<(d-1));
inv[1]=inv[0]=1;for(int i=2;i<262144;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
fac[0]=1;for(int i=1;i<262144;i++)fac[i]=fac[i-1]*i%mod;
ifac[0]=1;for(int i=1;i<262144;i++)ifac[i]=ifac[i-1]*inv[i]%mod;
for(int t=(mod-1)>>1,i=1;i<=20;t>>=1,i++)rt[0][i]=po(3,t);
for(int t=(mod-1)>>1,i=1;i<=20;t>>=1,i++)rt[1][i]=po(332748118,t);
}
# define md(x) (x=(x>=mod)?x-mod:x)
inline void ntt(ll* a,int len,int d,int o)
{
for(int i=1;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]);
for(int k=1,j=1;k<len;k<<=1,j++)
for(int s=0;s<len;s+=(k<<1))
for(int i=s,w=1;i<s+k;i++,w=w*rt[o][j]%mod)
{ll a1=a[i+k]*w%mod;a[i+k]=a[i]+mod-a1;a[i]+=a1;md(a[i+k]);md(a[i]);}
if(o){ll iv=po(len,mod-2);for(int i=0;i<len;i++)(a[i]*=iv)%=mod;}
}
inline void poly_inv(ll* a,ll* b,int len)
{
b[0]=po(a[0],mod-2);
for(int k=2,d=1;k<=len;k<<=1,d++)
{
for(int i=0;i<k;i++)tr[i]=a[i];for(int i=k;i<(k<<1);i++)tr[i]=0;
ntt(tr,k<<1,d+1,0);ntt(b,k<<1,d+1,0);
for(int i=0;i<(k<<1);i++)b[i]=b[i]*(2+(mod-tr[i])*b[i]%mod)%mod;
ntt(b,k<<1,d+1,1);for(int i=k;i<(k<<1);i++)b[i]=0;
}
}
inline void poly_der(ll* a,int len)
{for(int i=0;i+1<len;i++)a[i]=a[i+1]*(i+1)%mod;a[len-1]=0;}
inline void poly_int(ll* a,int len)
{for(int i=len-1;i>=1;i--)a[i]=a[i-1]*inv[i]%mod;a[0]=0;}
inline void poly_ln(ll* a,ll* b,int len,int d)
{
for(int i=0;i<(len<<1);i++)b[i]=0;poly_inv(a,b,len);
for(int i=0;i<(len<<1);i++)tr[i]=a[i];poly_der(tr,len);ntt(b,len<<1,d+1,0);
ntt(tr,len<<1,d+1,0);for(int i=0;i<(len<<1);i++)b[i]=b[i]*tr[i]%mod;
ntt(b,len<<1,d+1,1);for(int i=len;i<(len<<1);i++)b[i]=0;poly_int(b,len);
}
inline void poly_exp(ll* a,ll* b,int len)
{
b[0]=1;
for(int k=2,d=1;k<=len;k<<=1,d++)
{
poly_ln(b,tr1,k,d);for(int i=0;i<k;i++)tr1[i]=(mod-tr1[i]+a[i])%mod;(tr1[0]+=1)%=mod;
ntt(tr1,k<<1,d+1,0);ntt(b,k<<1,d+1,0);
for(int i=0;i<(k<<1);i++)b[i]=b[i]*tr1[i]%mod;ntt(b,k<<1,d+1,1);
for(int i=k;i<(k<<1);i++)b[i]=0;
}
}
inline void mainsolve()
{
if(bas==1){printf("%lld",po(n,n-2)*po(n,n-2)%mod);return;}
int LEN=1;for(;LEN<=n;LEN<<=1);pre();
ll p=(po(bas,mod-2)+mod-1)%mod;ll k=po(p,mod-2)*n%mod*n%mod;
for(int i=1;i<LEN;i++)f[i]=k*po(i,i)%mod*ifac[i]%mod;
poly_exp(f,g,LEN);ll ans=g[n]*fac[n]%mod;(ans*=po(p,n))%=mod;
(ans*=po(po(n,4),mod-2))%=mod;(ans*=po(bas,n))%=mod;printf("%lld",ans);
}
}
int main()
{
freopen("tree.in","r",stdin);freopen("tree.out","w",stdout);
scanf("%d%lld%d",&n,&bas,&op);
switch(op)
{
case 0:{solve0::mainsolve();break;}
case 1:{solve1::mainsolve();break;}
case 2:{solve2::mainsolve();break;}
}return 0;
}