P8560 约定(Promise)

· · 题解

update:修正了笔误

题目中的组合推导并不复杂,仿照有标号有根树计数的方法,设 n!k![x^ny^k]F 表示所有 n 个点的有标号有根二叉树权值 k 次方和,可以列出方程

F=x\text e^y+x\text e^{dy}F + \frac{1}{2}x\text e^{dy}F^2

如果不能直接推出这个方程,也可以先发现树的权值只与其节点数和叶子数量有关。先计量 n 个点、m 个叶子的二叉树个数,经过处理也可以得到一样的结果。

x 这一维使用拉格朗日反演,把原式先化为

x=\frac{F}{\text e^y+\text e^{dy}F+\text e^{dy}F^2/2}

得到

[x^ny^k]F=\frac 1n[x^{n-1}y^k]\left( \text e^y+\text e^{dy}x+\frac{\text e^{dy}x^2}{2}\right)^n =\frac 1n[y^k]\sum_{i=0}^n \binom ni \text e^{(n-i+di)y}[x^{n-1}]\left( x+\frac{x^2}{2}\right)^i

这个 \text e 上的指数看着很难受,设 m(d-1)\equiv n \pmod p,化为

\frac{(d-1)^k}{n}[y^k]\text e^{my}\sum_{i=0}^n \binom ni \text e^{iy}[x^{n-1}]\left( x+\frac{x^2}{2}\right)^i

现在设

g(z)=\sum_{i=0}^n \binom ni z^i[x^{n-1}] \left( x+\frac{x^2}{2}\right)^i=[x^{n-1}]\left( 1+zx+\frac{zx^2}{2}\right)^n

注意 g(z) 是一个超几何级数,其系数满足简单的一阶整式递推:

(2i-n)(2i+1-n)g_i = 2(n-i+1)(n-i)g_{i-1}

写为 ODE 就是

(4z^2-2z^3)g''+((6-4n)z+(-4+4n)z^2)g'+(n(n-1)+2n(1-n)z)g=0

现在我们想求 [z^k]\text e^{mz}g(\text e^z),可以应用 EI 提出的一套方法。在此之前需要先知道 G(z)=z^mg(z) 的 ODE。对 G 求二阶导可得

zG'=mG+z^{m+1}g' z^2G''=2mzG'-m(m+1)G+z^{m+2}g''

再把 g''g',g 表示就可以了:

z^2(4-2z)G''-((-6+4n+8m)z+4(1-n-m)z^2)G' n (n - 1) (2 z - 1))G =0

现在要求 [z^k]G(\text e^z),第一步是求 H(z)=G(z+1)\bmod z^{k+1} 满足的 ODE,但这又必须知道 H(z) 的前两项:

[z^0]H=G(1)=g(1)=[x^{n-1}]\left( 1+x+\frac{x^2}{2}\right)^n [z^1]H=H'(0)=mg'(1)+g(1)

这可以用 整式递推 算法以 \Theta(\sqrt n \log n) 的时间复杂度求解(由于模数固定,也可以分块打表)。在求解过程中可以对整式递推模板微调一下,让最后答案不用再乘 n!

然后就可以递推求 H(z) 的系数,以得到一个多项式 D(z) 使得

P_0(z)H(z)+P_1(z)H'(z)+P_2(z)H''(z)=D(z)

等式左边就是 G 的 ODE 中直接把 z 换为 z+1 得到的,多项式平移可以暴力处理。右边多出的这个 D(z) 就是因为截断了 z^k 之后的项产生的扰动,它只有常数个非零项。(因此直接对着 z^k 附近提取左侧系数,即可得到 D(z)

最后求出 \mathcal G(z)=H(z-1) 的系数,这样 [z^k]\mathcal G(\text e^z) 就是答案,线性筛求幂即可。注意递推过程中 D(z-1) 的系数可以直接提取(几个二项式系数的线性组合),不需要展开。

另外这里可以在 \Theta(1) 的时间内求出 \mathcal G(z) 最高次项,然后倒着推回去,会有不错的常数优化。不过需要注意一点:若 n+m<kH(z)k 次项为零,这里倒着推会出错,当然也没必要做前面那么复杂,直接线性筛后暴力 \Theta(n+m) 计算即可。

总时间复杂度 \mathcal O(k+\sqrt n \log n)。std 几乎不可读,不建议直接贺(

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define N 262147
#define M 5000005
#define ll long long
#define reg register
#define p 998244353
using namespace std;

struct Z{
    int v;
    inline Z(const int _v=0):v(_v){}
};

inline Z operator + (const Z& lhs,const Z& rhs){ return lhs.v+rhs.v<p ? lhs.v+rhs.v : lhs.v+rhs.v-p; }
inline Z operator - (const Z& lhs,const Z& rhs){ return lhs.v<rhs.v ? lhs.v-rhs.v+p : lhs.v-rhs.v; }
inline Z operator - (const Z& x){ return x.v?p-x:0; }
inline Z operator * (const Z& lhs,const Z& rhs){ return (ll)lhs.v*rhs.v%p; }
inline Z& operator += (Z& lhs,const Z& rhs){ lhs.v = lhs.v+rhs.v<p ? lhs.v+rhs.v : lhs.v+rhs.v-p; return lhs; }
inline Z& operator -= (Z& lhs,const Z& rhs){ lhs.v = lhs.v<rhs.v ? lhs.v-rhs.v+p : lhs.v-rhs.v; return lhs; }
inline Z& operator *= (Z& lhs,const Z& rhs){ lhs.v = (ll)lhs.v*rhs.v%p; return lhs; }
inline bool operator ! (const Z& x){ return x.v==0; }

struct poly{
    Z a[8];
    int t;
    inline Z operator [] (const int& x) const{ return a[x]; }
    inline Z& operator [] (const int& x){ return a[x]; }

    inline Z eval(const int& x){
        Z res = a[t];
        for(reg int i=t-1;~i;--i) res = a[i]+res*x;
        return res;
    }
}P[8];

struct ode{
    poly b[8];
    int ord,deg;
    inline poly operator [] (const int& x) const{ return b[x]; }
    inline poly& operator [] (const int& x) { return b[x]; }

    inline void update(){
        for(int i=0;i<8;++i) b[i].t = deg;
    }
};

inline Z check1(const Z* f,const ode& G,int n){
    Z res = 0,rfac;
    for(int j=0;j<=min(n,G.deg);++j){
        rfac = 1;
        for(int i=0;i<=G.ord;++i){
            res += G[i][j]*rfac*f[n-j+i];
            rfac *= (n-j+1+i);
        }
    }
    return res;
}

inline Z power(Z a,int t){
    Z res = 1;
    while(t){
        if(t&1) res *= a;
        a *= a;
        t >>= 1;
    }
    return res;
}

Z fpw[M];
int pr[348515];
bool vis[M];

void sieve(int n,int k){
    fpw[1] = 1;
    int cnt = 0;
    for(int i=2;i<=n;++i){
        if(!vis[i]){
            vis[i] = true;
            pr[++cnt] = i;
            fpw[i] = power(i,k);
        }
        for(int j=1;j<=cnt&&i*pr[j]<=n;++j){
            fpw[i*pr[j]] = fpw[i]*fpw[pr[j]];
            vis[i*pr[j]] = true;
            if(i%pr[j]==0) break;
        }
    }
}

int ms,deg;

struct matrix{
    Z a[2][2];
    inline matrix(){ memset(a,0,sizeof(a)); }

    inline matrix operator * (const matrix& b) const{
        matrix res;
        res.a[0][0] = a[0][0]*b.a[0][0]+a[0][1]*b.a[1][0]; 
        res.a[1][0] = a[1][0]*b.a[0][0]+a[1][1]*b.a[1][0];
        res.a[0][1] = a[0][0]*b.a[0][1]+a[0][1]*b.a[1][1];
        res.a[1][1] = a[1][0]*b.a[0][1]+a[1][1]*b.a[1][1];
        return res;    
    }
}I;

inline matrix getmat(int x){
    matrix res = matrix();
    Z p0 = P[0].eval(x+ms);
    for(reg int i=0;i!=ms-1;++i) res.a[i+1][i] = p0;
    for(reg int i=0;i!=ms;++i) res.a[i][ms-1] = -P[ms-i].eval(x+ms);
    return res;
}

Z fac[N],ifac[N],rt[N],facm[N],inv[M];
int rev[N];
int siz;

inline int getlen(int n){ return 1<<(32-__builtin_clz(n)); }

void init(int n,int k){
    int lim = 1;
    while(lim<=n) lim <<= 1,++siz;
    for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
    Z w = power(3,(p-1)>>siz);
    inv[1] = fac[0] = fac[1] = ifac[0] = ifac[1] = rt[lim>>1] = 1;
    for(int i=lim>>1|1;i!=lim;++i) rt[i] = rt[i-1]*w;
    for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    for(int i=2;i<=n;++i) fac[i] = fac[i-1]*i;
    ifac[n] = power(fac[n],p-2);
    for(int i=n-1;i;--i) ifac[i] = ifac[i+1]*(i+1);
    for(int i=2;i<=k;++i) inv[i] = inv[p%i]*(p-p/i);
    I.a[0][0] = I.a[1][1] = 1;
}

inline void dft(Z *f,int lim){
    static unsigned long long a[N];
    reg int x,shift = siz-__builtin_ctz(lim);
    for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i].v;
    for(reg int mid=1;mid!=lim;mid<<=1)
    for(reg int j=0;j!=lim;j+=(mid<<1))
    for(reg int k=0;k!=mid;++k){
        x = a[j|k|mid]*rt[mid|k].v%p;
        a[j|k|mid] = a[j|k]+p-x;
        a[j|k] += x;
    }
    for(reg int i=0;i!=lim;++i) f[i] = a[i]%p;
}

inline void idft(Z *f,int lim){
    reverse(f+1,f+lim);
    dft(f,lim);
    reg int x = p-((p-1)>>__builtin_ctz(lim));
    for(reg int i=0;i!=lim;++i) f[i] *= x;
}

inline void lagrange(const matrix* F1,int n,Z m,matrix* R1){
    static Z pre[N],suf[N],f1[N],f2[N],g[N],inv_[N],ifcm[N],mul;
    int k = n<<1|1,lim = getlen(n<<1);
    facm[0] = 1;
    for(reg int i=0;i<=n;++i){
        facm[0] *= m-n+i;
        ifcm[i] = ifac[i]*ifac[n-i];
    }
    pre[0] = suf[k+1] = 1;
    for(reg int i=1;i<=k;++i) pre[i] = pre[i-1]*(m-n+i-1);
    for(reg int i=k;i;--i) suf[i] = suf[i+1]*(m-n+i-1);
    mul = power(pre[k],p-2);
    for(reg int i=1;i<=k;++i) inv_[i] = mul*pre[i-1]*suf[i+1];
    for(reg int i=1;i<=n;++i) facm[i] = facm[i-1]*(m+i)*inv_[i];
    for(reg int i=0;i!=k;++i) g[i] = inv_[i+1];
    memset(g+k,0,(lim-k+1)<<2);
    dft(g,lim);
    for(reg int i=0;i!=ms;++i)
    for(reg int j=0;j!=ms;++j){
        for(reg int t=0;t<=n;++t) f1[t] = ifcm[t]*((n-t)&1?-F1[t].a[i][j]:F1[t].a[i][j]);
        memset(f1+n+1,0,(lim-n)<<2);
        dft(f1,lim);
        for(reg int t=0;t!=lim;++t) f1[t] *= g[t];
        idft(f1,lim);
        for(reg int t=0;t<=n;++t) R1[t].a[i][j] = f1[t+n]*facm[t];
    }
}

inline matrix ff(int d,int x){
    matrix res = getmat(x);
    for(reg int i=1;i!=d;++i) res = res*getmat(x+i);
    return res;
}

inline Z gg(int d,int x){
    Z res = P[0].eval(x);
    for(reg int i=1;i!=d;++i) res *= P[0].eval(x+i);
    return res;
}

int kk;

matrix magic(int s,int t){ 
    static Z invs = power(s,p-2);
    static matrix f[N],fd[N];
    int st[30],top = 0,x = s,d = 1,kd;
    while(x){
        st[++top] = x;
        x >>= 1;
    }
    for(reg int i=0;i<=kk;++i){
        x = i*s;
        f[i] = getmat(x);
    }
    --top;
    while(top--){
        kd = kk*d;
        lagrange(f,kd,kd+1,f+kd+1);
        f[kd<<1|1] = matrix();
        lagrange(f,kd<<1,d*invs,fd);
        for(reg int i=0;i<=(kd<<1);++i) f[i] = f[i]*fd[i];
        d <<= 1;
        if(!(st[top+1]&1)) continue;
        kd = kk*(d+1);
        for(reg int i=kk*d+1;i<=kd;++i){
            x = i*s;
            f[i] = ff(d,x);
        }
        for(reg int i=0;i<=kd;++i){
            x = i*s;
            f[i] = f[i]*getmat(x+d);
        }
        ++d;
    }
    matrix r1 = I;
    for(reg int i=0;i<=t;++i) r1 = r1*f[i];
    return r1;
}

Z P_recursive(const Z *a,int n){
    int tn = n-ms+1,s;
    s = ceil(sqrt(tn*1.0/kk))+1;
    matrix mul = magic(s,(tn-s)/s);
    Z res = 0;
    for(reg int i=(tn/s)*s;i!=tn;++i) mul = mul*getmat(i);
    for(int i=0;i!=ms;++i) res += a[i]*mul.a[i][ms-1];
    return res;
}

inline Z binom(int n,int m){
    if(n<m) return Z(0);
    return fac[n]*ifac[m]*ifac[n-m];
}

Z prepare(int k,int n){ 
    static Z a[N];
    deg = kk = 1;
    ms = 2;
    P[0][1] = 1;
    P[1][0] = p-(k+1),P[1][1] = 1;
    P[2][0] = p-(k+1),P[2][1] = inv[2];
    a[0] = 1,a[1] = k;
    P[0].t = P[1].t = P[2].t = 1;
    if(n<=1000){
        for(int i=2;i<=n;++i){
            Z res = P[1].eval(i)*a[i-1]+P[2].eval(i)*a[i-2];
            a[i] = -res*power(P[0].eval(i),p-2);
        }
        return a[n]*fac[n];
    }
    return P_recursive(a,n);
}

ode G,H;

void poly_shift(){
    for(int i=0;i<=G.ord;++i)
    for(int k=0;k<=G.deg;++k)
    for(int j=k;j<=G.deg;++j)
        H[i][k] += G[i][j]*binom(j,k);
}

int n,k,d,len,lim;
Z g[M],h[M],pre[M],suf[M];
Z ans,m,r1,r2,r3;

inline Z check2(const int& n){
    Z res = H[0][0]*h[n]+H[1][0]*h[n+1]*(n+1);
    res += H[0][1]*h[n-1]+(H[1][1]*h[n]+H[2][1]*h[n+1]*(n+1))*n;
    res += (H[1][2]*h[n-1]+H[2][2]*h[n]*n)*(n-1);
    return res+H[2][3]*h[n-1]*(n-1)*(n-2);
}

int main(){ 
    scanf("%d%d%d",&n,&k,&d);
    m = power(d-1,p-2)*n;
    init(131075,max(1000,k)+3); 
    if(n+m.v<k){
        Z pw2 = power(2,p-n);
        for(int i=0;i<=n;++i){
            g[i] = fac[n-1]*binom(n,i)*binom(i,n-i-1)*pw2;
            pw2 += pw2;
        }
        sieve(n+m.v,k);
        for(int i=0;i<=n;++i) ans += fpw[m.v+i]*g[i];
        ans *= power(d-1,k); 
        printf("%d\n",ans.v);
        return 0;
    }
    Z _n = n,tmp;
    G[0][0] = p-m*(m+1)*4+m*(6-4*_n)+_n*(1+p-_n);
    G[0][1] = m*(m+1)*2+m*(4*_n-4)+2*_n*(_n-1);
    G[1][1] = 4*_n+8*m-6,G[1][2] = 4*(1+p-_n-m);
    G[2][2] = p-4,G[2][3] = 2;
    G.ord = H.ord = 2,G.deg = H.deg = 3;
    poly_shift();
    G.update(),H.update();
    h[0] = prepare(n,n-1);
    if(k==0){
        printf("%d",h[0].v);
        return 0;
    }
    h[1] = (h[0]-prepare(n-1,n-1))*n+h[0]*m;
    Z invh0 = power(H[2][0],p-2);
    for(int i=0;i<=min(k-2,2);++i) h[i+2] = -check1(h,H,i)*invh0*inv[i+1]*inv[i+2];
    for(int i=3;i<=k-2;++i) h[i+2] = -check2(i)*invh0*inv[i+1]*inv[i+2];
    r1 = check1(h,H,k-1),r2 = check1(h,H,k),r3 = check1(h,H,k+1);
    g[k] = h[k],g[k-1] = h[k-1]-h[k]*k;
    pre[0] = suf[k+1] = 1;
    for(int j=1;j<=k;++j) pre[j] = G[0][1]+(G[1][2]+(j-2)*2)*(j-1);
    for(int j=k;j;--j) suf[j] = suf[j+1]*pre[j];
    for(int j=1;j<=k;++j) pre[j] *= pre[j-1]; 
    Z Inv = power(pre[k],p-2),c1 = r1,c2 = k*r2,c3 = inv[2]*(k+1)*k*r3,falfac = 1;
    if(Inv.v==0){
        return 1;
    }
    for(int j=k-1;j>1;--j){
        Z tmp1 = (G[0][0]+j*(G[1][1]-(j-1)*4))*g[j];
        Z tmp2 = (k-j)&1?(c1-c2+c3):(c2-c1-c3);
        g[j-1] = (tmp2*falfac-tmp1)*Inv*pre[j-1]*suf[j+1];
        c1 *= inv[k-j],c2 *= inv[k-j+1],c3 *= inv[k-j+2];
        falfac *= j;
    }
    sieve(k,k);
    for(int i=1;i<=k;++i) ans += fpw[i]*g[i];
    ans *= power(d-1,k);
    printf("%d\n",ans.v);
    return 0;   
}