[ABC390G] Permutation Concatenation 题解

· · 题解

ABC390 录播

记录第一道自己做出来的 poly!

不难看出答案为:

\begin{aligned} \sum_{x=1}^nx\sum_{j=0}^{n-1}(n-j-1)!j!\sum_{i_1+\dots+i_6=j}\prod_{k=1}^6{a_k\choose i_k}10^{ki_k} \end{aligned}

其中 a_i 表示去掉 x 后十进制位数为 i 的数的数量。

意思是,我们枚举一个数 x,计算出它对答案的贡献,考虑在排列中放在它后面的数,统计在所有方案中它们位数的和:j 表示一共放了多少个数,而 i_1\dots i_k 表示每种长度的数的数量。

可以把 x 按照位数分成 6 组,每组后面的那一坨是一样的。

考虑一个 DP,设 f_{p,j} 表示考虑了 a_1\dots a_p 的数字,已经有 i_1+\dots+i_p=jj! 后面那个式子的总和。

转移时枚举 t=i_p

f_{p,j}=\sum_{t=0}^{a_p}f_{p-1,j-t}{a_p\choose t}10^{pt}

发现转移是一个卷积,直接 NTT 即可。

时间复杂度 \mathcal O(n\log^2n),其中第二个 \log10 为底。

#include<bits/stdc++.h>
#define rep(x,qwq,qaq) for(int x=(qwq);x<=(qaq);++x)
#define per(x,qwq,qaq) for(int x=(qwq);x>=(qaq);--x)
using namespace std;
#define m998 998244353
#define mod m998
template<typename Tp>
int qp(int x,Tp y) {
    assert(y>=0);
    x%=mod;
    int res=1;
    while(y) {
        if(y&1)res=1ll*res*x%mod;
        x=1ll*x*x%mod;
        y>>=1;
    }
    return res;
}
int inv(int x) {
    return qp(x,mod-2);
}

template <int MOD>
struct modint {
    int val;
    static int norm(const int& x) {
        return x < 0 ? x + MOD : x;
    }
    static constexpr int get_mod() {
        return MOD;
    }
    modint() : val(0) {}
    modint(const int& m) : val(norm(m)) {}
    modint(const long long& m) : val(norm(m % MOD)) {}
    modint operator-() const {
        return modint(norm(-val));
    }
    bool operator==(const modint& o) {
        return val == o.val;
    }
    bool operator<(const modint& o) {
        return val < o.val;
    }
    modint& operator+=(const modint& o) {
        return val = norm(val + o.val-MOD), *this;
    }
    modint& operator-=(const modint& o) {
        return val = norm(1ll * val - o.val), *this;
    }
    modint& operator*=(const modint& o) {
        return val = static_cast<int>(1ll * val * o.val % MOD), *this;
    }
    modint& operator/=(const modint& o) {
        return *this *= o.inv();
    }
    modint& operator^=(const modint& o) {
        return val ^= o.val, *this;
    }
    modint& operator>>=(const modint& o) {
        return val >>= o.val, *this;
    }
    modint& operator<<=(const modint& o) {
        return val <<= o.val, *this;
    }
    modint operator-(const modint& o) const {
        return modint(*this) -= o;
    }
    modint operator+(const modint& o) const {
        return modint(*this) += o;
    }
    modint operator*(const modint& o) const {
        return modint(*this) *= o;
    }
    modint operator/(const modint& o) const {
        return modint(*this) /= o;
    }
    modint operator^(const modint& o) const {
        return modint(*this) ^= o;
    }
    bool operator!=(const modint& o) {
        return val != o.val;
    }
    modint operator>>(const modint& o) const {
        return modint(*this) >>= o;
    }
    modint operator<<(const modint& o) const {
        return modint(*this) <<= o;
    }
    friend std::istream& operator>>(std::istream& is, modint& a) {
        long long v;
        return is >> v, a.val = norm(v % MOD), is;
    }
    friend std::ostream& operator<<(std::ostream& os, const modint& a) {
        return os << a.val;
    }
    friend std::string tostring(const modint& a) {
        return std::to_string(a.val);
    }
    template <typename T>
    friend modint qpow(const modint a, const T& b) {
        assert(b >= 0);
        modint x = a, res = 1;
        for (T p = b; p; x *= x, p >>= 1)
            if (p & 1) res *= x;
        return res;
    }
    modint inv() const {
        return qpow(*this,MOD-2);
    }
};
using M107 = modint<1000000007>;
using M998 = modint<998244353>;

using mint = M998;

struct Combinatorics {
#define Lim 2000000
    int fac[Lim+10],invfac[Lim+10];
    Combinatorics() {
        fac[0]=invfac[0]=1;
        rep(i,1,Lim)fac[i]=1ll*fac[i-1]*i%mod;
        invfac[Lim]=inv(fac[Lim]);
        per(i,Lim-1,1)invfac[i]=1ll*invfac[i+1]*(i+1)%mod;
    }
    mint C(int n,int m) {
        if(n<m||n<0||m<0)return 0;
        return 1ll*fac[n]*invfac[m]%mod*invfac[n-m]%mod;
    }
    int A(int n,int m) {
        if(n<m||n<0||m<0)return 0;
        return 1ll*fac[n]*invfac[n-m]%mod;
    }
} comb;
const mint g=3;
void NTT(vector<mint>&f,const int N,const int op) {
    vector<int>rev(N);
    int t=__lg(N)-1;
    rep(i,0,N-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<t);
    rep(i,0,N-1)if(i<rev[i])swap(f[i],f[rev[i]]);
    for(int n=2; n<=N; n<<=1) {
        mint w1=qpow(op==1?g:g.inv(),(m998-1)/n);
        for(int j=0; j<N; j+=n) {
            //[j,j+n/2)[j+n/2,j+n)
            mint wk=1;
            for(int i=j; i<j+n/2; ++i,wk*=w1) {
                mint f0=f[i],f1=wk*f[i+n/2];
                f[i]=f0+f1,f[i+n/2]=f0-f1;
            }
        }
    }
}
//------------------------------------------------------------------以上是模板
int n;
int a[10];
int lg(int x) {
    int res=0;
    do {
        ++res;
        x/=10;
    } while(x);
    return res;
}
vector<mint>f[10];
bool Med;
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n;
    rep(i,1,n)++a[lg(i)];
    mint ans=0;
    int N=1;
    while(N<n)N<<=1;
    rep(i,1,6){
        f[i].resize(N);
        rep(j,0,a[i]){
            f[i][j]=comb.C(a[i],j)*qp(10,i*j);
        }
        NTT(f[i],N,1);
    }
    mint iv=(mint)1/N;
    rep(i,1,6) {
        --a[i];
        fill(f[i].begin(),f[i].end(),0);
        rep(k,0,a[i])f[i][k]=comb.C(a[i],k)*qp(10,i*k);
        NTT(f[i],N,1);
        vector<mint>F=f[1];
        rep(j,2,6)rep(p,0,N-1)F[p]*=f[j][p];
        NTT(F,N,-1);
        rep(p,0,N-1)F[p]*=iv;
        mint sum=0;
        rep(j,1,n)if(lg(j)==i)sum+=j;
        rep(j,0,n-1)ans+=F[j]*comb.fac[j]*sum*comb.fac[n-j-1];
        ++a[i];
        fill(f[i].begin(),f[i].end(),0);
        rep(k,0,a[i])f[i][k]=comb.C(a[i],k)*qp(10,i*k);
        NTT(f[i],N,1);
    }
    cout<<ans<<'\n';
    return 0;
}