题解:P14254 分割(divide)
感觉没到蓝,场上想的时间甚至比第一题短。
solution
:::info[引理 1]
子树在原树深度上的深度集合是一个连续区间。
:::success[证明]
设
子树中任一节点的原树深度不小于
:::info[引理 2]
所有被选节点的原树深度必须相等。
:::success[证明]
序列要求
两边的最小元素(即区间左端点)相等。左端点分别是
又因非降序
因此我们可以按深度
对深度为
按上面的引理,每个被选节点对应的
把深度为
固定某个节点
令
必须满足:
- 对于
i=2,\dots,k ,它们对应的M_{b_i} 都不能小于t ,否则交集上界会小于t 。因此其他k-1 个被选节点必须从M\ge t 的节点中选。 - 根所在子树在去掉这些
k 条边后,仍然有一个深度不小于t ,即剩余部分的最大深度R\ge t 。当且仅当并非层内所有M\ge t 的节点都被选掉时,根所在剩余部分才含有深度\ge t 。换句话说,若把层内所有M\ge t 的节点全部包含在选集里,那么剩下的树里没有深度\ge t 的节点,导致R<t ,因此该种选择非法。
把层内节点按
- 设
a 为该组中M=t 的节点数。 - 设
G 为层内满足M\ge t 的节点总数。
考虑把首位
但若
当
因此,固定
把该层上所有不同的
预处理阶乘逆元后如果用 sort 是
:::success[code]
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read(){
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')s=s*10+ch-'0',ch=getchar();
return s*w;
}
inline void out(int x){
if(x==0){putchar('0');return;}
int len=0,k1=x,c[10005];
if(k1<0)k1=-k1,putchar('-');
while(k1)c[len++]=k1%10+'0',k1/=10;
while(len--)putchar(c[len]);
}
const int N=1e6+5,mod=998244353;
int fa[N],dep[N],cnt[N],maxd[N],pos[N],m[N],inv[N],fac[N];
int addmod(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
int submod(int a,int b){a-=b;if(a<0)a+=mod;return a;}
int mulmod(int a,int b){return (a*b)%mod;}
int qpow(int a,int b){
int ans=1;
while(b){
if(b&1)ans=ans*a%mod;
b>>=1,a=a*a%mod;
}return ans;
}
void init(int n){
inv[0]=inv[1]=1;fac[0]=fac[1]=1;
for(int i=2;i<=n;i++)fac[i]=fac[i-1]*i%mod;
inv[n]=qpow(fac[n],mod-2);
for(int i=n-1;i>=2;i--)inv[i]=inv[i+1]*(i+1)%mod;
}
int c(int n,int m){
if(n<m)return 0;
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
signed main(){
// freopen("divide6.in","r",stdin);
// freopen("divide.out","w",stdout);
int n=read(),k=read(),maxn=1;dep[1]=1;
for(int i=2;i<=n;i++)fa[i]=read();init(n+1);
for(int i=2;i<=n;i++){
dep[i]=dep[fa[i]]+1;
maxn=max(maxn,dep[i]);
}for(int i=1;i<=n;i++)maxd[i]=dep[i];
for(int i=n;i>=2;i--)maxd[fa[i]]=max(maxd[fa[i]],maxd[i]);
// for(int i=2;i<=n;i++)cout<<dep[i]<<" ";puts("");
for(int i=2;i<=n;i++)cnt[dep[i]]++;int tot=0;
for(int d=1;d<=maxn;d++)pos[d]=tot,tot+=cnt[d];
vector<int>cur(pos,pos+maxn+1);int ans=0;
for(int i=2;i<=n;i++)m[cur[dep[i]]++]=maxd[i];
// cout<<maxn<<"\n";
for(int d=2;d<=maxn;d++){
int mm=cnt[d];
if(mm<k)continue;
int l=pos[d],r=l+mm;
sort(m+l,m+r);int pre=0,idx=l;
while(idx<r){
int j=idx,val=m[idx];
while(j<r&&m[j]==val)j++;
int a=j-idx,lcnt=pre,g=mm-lcnt,b=g-a;
if(k<g){
int t1=submod(c(g-1,k-1),c(b,k-1));
t1=mulmod(a,t1);int t2=0;
if((k-1)==b&&a>=2)t2=a%mod;
int add=addmod(t1%mod,t2);
add=mulmod(add,fac[k-1]);ans=addmod(ans,add);
}pre+=a;idx=j;
}
}cout<<ans;
return 0;
}
:::