题解:P14254 分割(divide)

· · 题解

感觉没到蓝,场上想的时间甚至比第一题短。

solution

:::info[引理 1] 子树在原树深度上的深度集合是一个连续区间。 :::success[证明] 设 u 的原树深度为 d_u,子树中原树最深的深度为 M_u。那么子树中出现过的原树深度集合:

S_u={d_u,d_u+1,\dots,M_u}

子树中任一节点的原树深度不小于 d_u(因为从根到该节点经过 u),不大于 M_u。对于区间中任意整数 x 满足 d_u\le x\le M_u,沿着从 u 到达某个深度为 M_u 的节点的路径上必然存在深度为 x 的节点,因此区间内每个整数都出现,集合为闭区间。 :::

:::info[引理 2] 所有被选节点的原树深度必须相等。 :::success[证明] 序列要求 1<d_{b_1}\le d_{b_2}\le\cdots\le d_{b_k}。由条件:

S_1=\bigcap_{i=2}^{k+1} S_i,

两边的最小元素(即区间左端点)相等。左端点分别是 d_{b_1}\max(d_{b_2},\dots,d_{b_k},1)。因此:

d_{b_1}=\max(d_{b_2},\dots,d_{b_k},1)

又因非降序 d_{b_1}\le d_{b_2}\le\cdots\le d_{b_k},可推出所有 d_{b_i} 必然相等。设共同深度为 D>1

因此我们可以按深度 D 独立地统计:只考虑原树中深度恰为 D 的节点,把所有合法序列的项都限制在该层。 :::

对深度为 D 的每个节点 u,令子树中原树最深的深度为 M_u

按上面的引理,每个被选节点对应的 S 都是区间 [D, M_u]

把深度为 D 的所有节点按 M_u 从小到大排序。设该层共有 m 个节点。

固定某个节点 u且令其被放在序列的第 1 位。

t=M_u。要使得:

S_1=[D,t]=\bigcap_{i=2}^{k+1} S_i,

必须满足:

  1. 对于 i=2,\dots,k,它们对应的 M_{b_i} 都不能小于 t,否则交集上界会小于 t。因此其他 k-1 个被选节点必须从 M\ge t 的节点中选。
  2. 根所在子树在去掉这些 k 条边后,仍然有一个深度不小于 t,即剩余部分的最大深度 R\ge t。当且仅当并非层内所有 M\ge t 的节点都被选掉时,根所在剩余部分才含有深度 \ge t。换句话说,若把层内所有 M\ge t 的节点全部包含在选集里,那么剩下的树里没有深度 \ge t 的节点,导致 R<t,因此该种选择非法。

把层内节点按 M 排序并分组。对于某个具体的 t

考虑把首位 b_1 选为该组中某个节点。其余 k-1 个位置必须从剩下的 G-1M\ge t 节点中 有序不重复地选出。方案数为:

P(G-1,k-1)=(G-1)(G-2)\cdots(G-k+1)

但若 k=G,则根所在部分不含深度 \ge t,不满足条件。

k=GP(G-1,k-1)=(G-1)!,此类全部被选掉的序列数是 a(G-1)!,需要剔除。

因此,固定 t对合法有序序列贡献:

\begin{cases} aP(G-1,k-1), & G\ne k\\ a\bigl(P(G-1,k-1)-(G-1)!\bigr)=0, & G=k \end{cases}

把该层上所有不同的 t 值的贡献累加,得到该深度 D 的总贡献。对所有 D\ge2 累加即为全树结果。

预处理阶乘逆元后如果用 sortO(n\log n) 的。换成基数排序可以做到 O(n)

:::success[code]

#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read(){
    int s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')s=s*10+ch-'0',ch=getchar();
    return s*w;
}
inline void out(int x){
    if(x==0){putchar('0');return;}
    int len=0,k1=x,c[10005];
    if(k1<0)k1=-k1,putchar('-');
    while(k1)c[len++]=k1%10+'0',k1/=10;
    while(len--)putchar(c[len]);
}
const int N=1e6+5,mod=998244353;
int fa[N],dep[N],cnt[N],maxd[N],pos[N],m[N],inv[N],fac[N];
int addmod(int a,int b){a+=b;if(a>=mod)a-=mod;return a;}
int submod(int a,int b){a-=b;if(a<0)a+=mod;return a;}
int mulmod(int a,int b){return (a*b)%mod;}
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=ans*a%mod;
        b>>=1,a=a*a%mod;
    }return ans;
}
void init(int n){
    inv[0]=inv[1]=1;fac[0]=fac[1]=1;
    for(int i=2;i<=n;i++)fac[i]=fac[i-1]*i%mod;
    inv[n]=qpow(fac[n],mod-2);
    for(int i=n-1;i>=2;i--)inv[i]=inv[i+1]*(i+1)%mod;
}
int c(int n,int m){
    if(n<m)return 0;
    return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
signed main(){
    // freopen("divide6.in","r",stdin);
    // freopen("divide.out","w",stdout);
    int n=read(),k=read(),maxn=1;dep[1]=1;
    for(int i=2;i<=n;i++)fa[i]=read();init(n+1);
    for(int i=2;i<=n;i++){
        dep[i]=dep[fa[i]]+1;
        maxn=max(maxn,dep[i]);
    }for(int i=1;i<=n;i++)maxd[i]=dep[i];
    for(int i=n;i>=2;i--)maxd[fa[i]]=max(maxd[fa[i]],maxd[i]);
    // for(int i=2;i<=n;i++)cout<<dep[i]<<" ";puts("");
    for(int i=2;i<=n;i++)cnt[dep[i]]++;int tot=0;
    for(int d=1;d<=maxn;d++)pos[d]=tot,tot+=cnt[d];
    vector<int>cur(pos,pos+maxn+1);int ans=0;
    for(int i=2;i<=n;i++)m[cur[dep[i]]++]=maxd[i];
    // cout<<maxn<<"\n";
    for(int d=2;d<=maxn;d++){
        int mm=cnt[d];
        if(mm<k)continue;
        int l=pos[d],r=l+mm;
        sort(m+l,m+r);int pre=0,idx=l;
        while(idx<r){
            int j=idx,val=m[idx];
            while(j<r&&m[j]==val)j++;
            int a=j-idx,lcnt=pre,g=mm-lcnt,b=g-a;
            if(k<g){
                int t1=submod(c(g-1,k-1),c(b,k-1));
                t1=mulmod(a,t1);int t2=0;
                if((k-1)==b&&a>=2)t2=a%mod;
                int add=addmod(t1%mod,t2);
                add=mulmod(add,fac[k-1]);ans=addmod(ans,add);
            }pre+=a;idx=j;
        }
    }cout<<ans;
    return 0;
}

:::