题解 P7439 【入门级的排列计数】

· · 题解

本文同步发表于我的博客:https://www.alpha1022.me/articles/lg-7439.htm。

首先把给定的多项式转成牛顿级数,即转写成

F(x) = \sum\limits_{i=0}^{k-1} a_i \binom xi

然后套路的拆开项,考虑计算

\sum\limits_{\pi} \binom{{\rm cyc}_{\pi}}{k}

让我们考虑一个错排,显然它是由非自环的循环置换为基本单位构成的,即

{\rm e}^{-x-\ln(1-x)}

然后考虑从中选 k 个。不难发现这只需要对一个循环置换附加一个因子 (1+y) 即可做到:

{\rm e}^{(1+y)(-x-\ln(1-x))}

首先求考虑求 [x^n] {\rm e}^{y(-x-\ln(1-x))},然后使用二项式展开并卷积处理。

\frac{f^2}2 = -x-\ln(1-x),显然其存在复合逆,设复合逆为 g

[x^n]{\rm e}^{yf^2/2} = \frac1n[x^{n-1}] \frac{\partial{\rm e}^{yx^2/2}}{\partial x} \left(\frac xg\right)^n

我们知道

\frac{\partial{\rm e}^{yx^2/2}}{\partial x} = xy {\rm e}^{yx^2/2} = \sum\limits_{i\ge 0} \frac{x^{2i+1}y^{i+1}}{2^ii!}

因此

\frac1n[x^{n-1} y^k] \frac{\partial{\rm e}^{yx^2/2}}{\partial x} \left(\frac xg\right)^n = \frac1{n2^{k-1} (k-1)!} [x^{n-2k}] \left(\frac xg\right)^n

考虑求 g。我们知道

\frac{f^2}2 = -x-\ln(1-x)

所以

\frac{x^2}2 = -g-\ln(1-g)

不幸的是如果我们直接对这个式子牛顿迭代,会出现一些意料之外的问题。
所以不妨将其改写成

\sqrt{-2g-2\ln(1-g)}-x=0

然后对其牛顿迭代。

时间复杂度 O(n \log^2 n)(视 n,k 同阶)。

代码:

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define add(a,b) (a + b >= mod ? a + b - mod : a + b)
#define dec(a,b) (a < b ? a - b + mod : a - b)
#define ls (p << 1)
#define rs (ls | 1)
using namespace std;
const int N = 1e5;
const int mod = 998244353;
int n,k,m;
inline int fpow(int a,int b)
{
    int ret = 1;
    for(;b;b >>= 1)
        (b & 1) && (ret = (long long)ret * a % mod),a = (long long)a * a % mod;
    return ret;
}
namespace Poly
{
    const int LG = 17;
    const int N = 1 << LG + 1;
    const int G = 3;
    int lg2[N + 5];
    int rev[N + 5],fac[N + 5],ifac[N + 5],inv[N + 5];
    int rt[N + 5];
    inline void init()
    {
        for(register int i = 2;i <= N;++i)
            lg2[i] = lg2[i >> 1] + 1;
        rt[0] = 1,rt[1 << LG] = fpow(G,(mod - 1) >> LG + 2);
        for(register int i = LG;i;--i)
            rt[1 << i - 1] = (long long)rt[1 << i] * rt[1 << i] % mod;
        for(register int i = 1;i < N;++i)
            rt[i] = (long long)rt[i & i - 1] * rt[i & -i] % mod;
        fac[0] = 1;
        for(register int i = 1;i <= N;++i)
            fac[i] = (long long)fac[i - 1] * i % mod;
        ifac[N] = fpow(fac[N],mod - 2);
        for(register int i = N;i;--i)
            ifac[i - 1] = (long long)ifac[i] * i % mod;
        for(register int i = 1;i <= N;++i)
            inv[i] = (long long)ifac[i] * fac[i - 1] % mod;
    }
    struct poly
    {
        vector<int> a;
        inline poly(int x = 0)
        {
            x && (a.push_back(x),1);
        }
        inline poly(const vector<int> &o)
        {
            a = o,shrink();
        }
        inline poly(const poly &o)
        {
            a = o.a,shrink();
        }
        inline void shrink()
        {
            for(;!a.empty() && !a.back();a.pop_back());
        }
        inline int size() const
        {
            return a.size();
        }
        inline void resize(int x)
        {
            a.resize(x);
        }
        inline int operator[](int x) const
        {
            if(x < 0 || x >= size())
                return 0;
            return a[x];
        }
        inline void clear()
        {
            vector<int>().swap(a);
        }
        inline poly rever() const
        {
            return poly(vector<int>(a.rbegin(),a.rend()));
        }
        inline void dif()
        {
            int n = size();
            for(register int i = 0,len = n >> 1;len;++i,len >>= 1)
                for(register int j = 0,*w = rt;j < n;j += len << 1,++w)
                    for(register int k = j,R;k < j + len;++k)
                        R = (long long)*w * a[k + len] % mod,
                        a[k + len] = dec(a[k],R),
                        a[k] = add(a[k],R);
        }
        inline void dit()
        {
            int n = size();
            for(register int i = 0,len = 1;len < n;++i,len <<= 1)
                for(register int j = 0,*w = rt;j < n;j += len << 1,++w)
                    for(register int k = j,R;k < j + len;++k)
                        R = add(a[k],a[k + len]),
                        a[k + len] = (long long)(a[k] - a[k + len] + mod) * *w % mod,
                        a[k] = R;
            reverse(a.begin() + 1,a.end());
            for(register int i = 0;i < n;++i)
                a[i] = (long long)a[i] * inv[n] % mod;
        }
        inline void ntt(int type = 1)
        {
            type == 1 ? dif() : dit();
        }
        friend inline poly operator+(const poly &a,const poly &b)
        {
            vector<int> ret(max(a.size(),b.size()));
            for(register int i = 0;i < ret.size();++i)
                ret[i] = add(a[i],b[i]);
            return poly(ret);
        }
        friend inline poly operator-(const poly &a,const poly &b)
        {
            vector<int> ret(max(a.size(),b.size()));
            for(register int i = 0;i < ret.size();++i)
                ret[i] = dec(a[i],b[i]);
            return poly(ret);
        }
        friend inline poly operator*(poly a,poly b)
        {
            if(a.a.empty() || b.a.empty())
                return poly();
            if(a.size() < 40 || b.size() < 40)
            {
                if(a.size() > b.size())
                    swap(a,b);
                poly ret;
                ret.resize(a.size() + b.size() - 1);
                for(register int i = 0;i < ret.size();++i)
                    for(register int j = 0;j <= i && j < a.size();++j)
                        ret.a[i] = (ret[i] + (long long)a[j] * b[i - j]) % mod;
                ret.shrink();
                return ret;
            }
            int lim = 1,tot = a.size() + b.size() - 1;
            for(;lim < tot;lim <<= 1);
            a.resize(lim),b.resize(lim);
            a.ntt(),b.ntt();
            for(register int i = 0;i < lim;++i)
                a.a[i] = (long long)a[i] * b[i] % mod;
            a.ntt(-1),a.shrink();
            return a;
        }
        poly &operator+=(const poly &o)
        {
            resize(max(size(),o.size()));
            for(register int i = 0;i < o.size();++i)
                a[i] = add(a[i],o[i]);
            return *this;
        }
        poly &operator-=(const poly &o)
        {
            resize(max(size(),o.size()));
            for(register int i = 0;i < o.size();++i)
                a[i] = dec(a[i],o[i]);
            return *this;
        }
        poly &operator*=(poly o)
        {
            return (*this) = (*this) * o;
        }
        poly deriv() const
        {
            if(a.empty())
                return poly();
            vector<int> ret(size() - 1);
            for(register int i = 0;i < size() - 1;++i)
                ret[i] = (long long)(i + 1) * a[i + 1] % mod;
            return poly(ret);
        }
        poly integ() const
        {
            if(a.empty())
                return poly();
            vector<int> ret(size() + 1);
            for(register int i = 0;i < size();++i)
                ret[i + 1] = (long long)a[i] * inv[i + 1] % mod;
            return poly(ret);
        }
        inline poly modxn(int n) const
        {
            if(a.empty())
                return poly();
            n = min(n,size());
            return poly(vector<int>(a.begin(),a.begin() + n));
        }
        inline poly inver(int m) const
        {
            poly ret(fpow(a[0],mod - 2)),f,g;
            for(register int k = 1;k < m;)
            {
                k <<= 1,f.resize(k),g.resize(k);
                for(register int i = 0;i < k;++i)
                    f.a[i] = (*this)[i],g.a[i] = ret[i];
                f.ntt(),g.ntt();
                for(register int i = 0;i < k;++i)
                    f.a[i] = (long long)f[i] * g[i] % mod;
                f.ntt(-1);
                for(register int i = 0;i < (k >> 1);++i)
                    f.a[i] = 0;
                f.ntt();
                for(register int i = 0;i < k;++i)
                    f.a[i] = (long long)f[i] * g[i] % mod;
                f.ntt(-1);
                ret.resize(k);
                for(register int i = (k >> 1);i < k;++i)
                    ret.a[i] = dec(0,f[i]);
            }
            return ret.modxn(m);
        }
        inline pair<poly,poly> div(poly o) const
        {
            if(size() < o.size())
                return make_pair(poly(),*this);
            poly f,g;
            f = (rever().modxn(size() - o.size() + 1) * o.rever().inver(size() - o.size() + 1)).modxn(size() - o.size() + 1).rever();
            g = (modxn(o.size() - 1) - o.modxn(o.size() - 1) * f.modxn(o.size() - 1)).modxn(o.size() - 1);
            return make_pair(f,g);
        }
        inline poly log(int m) const
        {
            return (deriv() * inver(m)).integ().modxn(m);
        }
        inline poly exp(int m) const
        {
            poly ret(1),iv,it,d = deriv(),itd,itd0,t1;
            if(m < 70)
            {
                ret.resize(m);
                for(register int i = 1;i < m;++i)
                {
                    for(register int j = 1;j <= i;++j)
                        ret.a[i] = (ret[i] + (long long)j * operator[](j) % mod * ret[i - j]) % mod;
                    ret.a[i] = (long long)ret[i] * inv[i] % mod;
                }
                return ret;
            }
            for(register int k = 1;k < m;)
            {
                k <<= 1;
                it.resize(k >> 1);
                for(register int i = 0;i < (k >> 1);++i)
                    it.a[i] = ret[i];
                itd = it.deriv(),itd.resize(k >> 1);
                iv = ret.inver(k >> 1),iv.resize(k >> 1);
                it.ntt(),itd.ntt(),iv.ntt();
                for(register int i = 0;i < (k >> 1);++i)
                    it.a[i] = (long long)it[i] * iv[i] % mod,
                    itd.a[i] = (long long)itd[i] * iv[i] % mod;
                it.ntt(-1),itd.ntt(-1),it.a[0] = dec(it[0],1);
                for(register int i = 0;i < k - 1;++i)
                    itd.a[i % (k >> 1)] = dec(itd[i % (k >> 1)],d[i]);
                itd0.resize((k >> 1) - 1);
                for(register int i = 0;i < (k >> 1) - 1;++i)
                    itd0.a[i] = d[i];
                itd0 = (itd0 * it).modxn((k >> 1) - 1);
                t1.resize(k - 1);
                for(register int i = (k >> 1) - 1;i < k - 1;++i)
                    t1.a[i] = itd[(i + (k >> 1)) % (k >> 1)];
                for(register int i = k >> 1;i < k - 1;++i)
                    t1.a[i] = dec(t1[i],itd0[i - (k >> 1)]);
                t1 = t1.integ();
                for(register int i = 0;i < (k >> 1);++i)
                    t1.a[i] = t1[i + (k >> 1)];
                for(register int i = (k >> 1);i < k;++i)
                    t1.a[i] = 0;
                t1.resize(k >> 1),t1 = (t1 * ret).modxn(k >> 1),t1.resize(k);
                for(register int i = (k >> 1);i < k;++i)
                    t1.a[i] = t1[i - (k >> 1)];
                for(register int i = 0;i < (k >> 1);++i)
                    t1.a[i] = 0;
                ret -= t1;
            }
            return ret.modxn(m);
        }
        inline poly sqrt(int m) const
        {
            poly ret(1),f,g;
            for(register int k = 1;k < m;)
            {
                k <<= 1;
                f = ret,f.resize(k >> 1);
                f.ntt();
                for(register int i = 0;i < (k >> 1);++i)
                    f.a[i] = (long long)f[i] * f[i] % mod;
                f.ntt(-1);
                for(register int i = 0;i < k;++i)
                    f.a[i % (k >> 1)] = dec(f[i % (k >> 1)],(*this)[i]);
                g = (2 * ret).inver(k >> 1),f = (f * g).modxn(k >> 1),f.resize(k);
                for(register int i = (k >> 1);i < k;++i)
                    f.a[i] = f[i - (k >> 1)];
                for(register int i = 0;i < (k >> 1);++i)
                    f.a[i] = 0;
                ret -= f;
            }
            return ret.modxn(m);
        }
        inline poly pow(int m,int k1,int k2 = -1) const
        {
            if(a.empty())
                return poly();
            if(k2 == -1)
                k2 = k1;
            int t = 0;
            for(;t < size() && !a[t];++t);
            if((long long)t * k1 >= m)
                return poly();
            poly ret;
            ret.resize(m);
            int u = fpow(a[t],mod - 2),v = fpow(a[t],k2);
            for(register int i = 0;i < m - t * k1;++i)
                ret.a[i] = (long long)operator[](i + t) * u % mod;
            ret = ret.log(m - t * k1);
            for(register int i = 0;i < ret.size();++i)
                ret.a[i] = (long long)ret[i] * k1 % mod;
            ret = ret.exp(m - t * k1),t *= k1,ret.resize(m);
            for(register int i = m - 1;i >= t;--i)
                ret.a[i] = (long long)ret[i - t] * v % mod;
            for(register int i = 0;i < t;++i)
                ret.a[i] = 0;
            return ret;
        }
    };
}
using Poly::init;
using Poly::poly;
inline int C(int n,int m)
{
    return n < m ? 0 : (long long)Poly::fac[n] * Poly::ifac[m] % mod * Poly::ifac[n - m] % mod;
}
poly f,seg[(N << 2) + 5],g;
poly t1,t2;
inline poly comp_xplusa(const poly &f,int a)
{
    int n = f.size();
    poly t1,t2,ret;
    t1.resize(n),t2.resize(n);
    for(register int i = 0,pw = 1;i < n;++i,pw = (long long)pw * a % mod)
        t1.a[n - 1 - i] = (long long)Poly::fac[i] * f[i] % mod,
        t2.a[i] = (long long)Poly::ifac[i] * pw % mod;
    t1 *= t2;
    ret.resize(n);
    for(register int i = 0;i < n;++i)
        ret.a[i] = (long long)Poly::ifac[i] * t1[n - 1 - i] % mod;
    return ret;
}
void build(int n,int p)
{
    if(n == 1)
    {
        seg[p].resize(2),seg[p].a[1] = 1;
        return ;
    }
    int mid = n + 1 >> 1;
    build(mid,ls),build(n - mid,rs);
    seg[p] = seg[ls] * comp_xplusa(seg[rs],dec(0,mid));
}
poly solve(const poly &f,int n,int p)
{
    if(n == 1)
        return f;
    int mid = n + 1 >> 1;
    poly ret;
    ret.resize(n);
    pair<poly,poly> res = f.div(seg[ls]);
    poly t1 = solve(res.second,mid,ls);
    poly t2 = solve(comp_xplusa(res.first,mid),n - mid,rs);
    for(register int i = 0;i < mid;++i)
        ret.a[i] = t1[i];
    for(register int i = mid;i < n;++i)
        ret.a[i] = t2[i - mid];
    return ret;
}
inline poly calc(int m)
{
    poly ret,t,t1,t2,t3;
    ret.resize(2),ret.a[1] = 1;
    for(register int k = 2;k < m;)
    {
        k <<= 1;
        t = 0 - 2 * ret - 2 * (1 - ret).log(k + 1),t.resize(k + 1);
        for(register int i = 0;i < k - 1;++i)
            t.a[i] = t[i + 2];
        t.resize(k - 1),t = t.sqrt(k - 1),t.resize(k);
        for(register int i = k - 1;i;--i)
            t.a[i] = t[i - 1];
        t.a[0] = 0,t1 = t,t1.resize(k),t1.a[1] = dec(t1[1],1);
        for(register int i = 0;i < (k >> 1);++i)
            t1.a[i] = t1[i + (k >> 1)];
        for(register int i = (k >> 1);i < k;++i)
            t1.a[i] = 0;
        t2 = ((1 - ret) * t).modxn((k >> 1) + 1),t2.resize((k >> 1) + 1);
        for(register int i = 0;i < (k >> 1);++i)
            t2.a[i] = t2[i + 1];
        t3.resize(k >> 1);
        for(register int i = 0;i < (k >> 1);++i)
            t3.a[i] = ret[i + 1];
        t2.resize(k >> 1),t2 = (t2.inver(k >> 1) * t3).modxn(k >> 1);
        t1 = (t1 * t2.inver(k >> 1)).modxn(k >> 1),t1.resize(k);
        for(register int i = (k >> 1);i < k;++i)
            t1.a[i] = t1[i - (k >> 1)];
        for(register int i = 0;i < (k >> 1);++i)
            t1.a[i] = 0;
        ret -= t1;
    }
    return ret.modxn(m);
}
int ans;
int main()
{
    init();
    scanf("%d%d",&n,&k),f.resize(k);
    for(register int i = 0;i < k;++i)
        scanf("%d",&f.a[i]);
    build(k,1),f = solve(f,k,1);
    for(register int i = 0;i < k;++i)
        f.a[i] = (long long)f[i] * Poly::fac[i] % mod;
    g = calc(n + 1),g.resize(n + 1);
    for(register int i = 0;i < n;++i)
        g.a[i] = g[i + 1];
    g.resize(n),g = g.pow(n,mod - n,mod - n - 1);
    m = (n >> 1) + 1,t1.resize(m),t2.resize(m);
    for(register int i = 0,pw = 2;i < m;++i,pw = (long long)pw * Poly::inv[2] % mod)
        i && (t1.a[m - i - 1] = (long long)Poly::fac[n - 1] * pw % mod * i % mod * g[n - 2 * i] % mod),
        t2.a[i] = Poly::ifac[i];
    t1 = t1 * t2;
    for(register int i = 0;i < min(k,m);++i)
        ans = (ans + (long long)f[i] * Poly::ifac[i] % mod * t1[m - 1 - i]) % mod;
    printf("%d\n",ans);
}