题解:AT_arc209_e [ARC209E] I hate ABC

· · 题解

Solution

考虑到 N-K\le 100,我们尝试将 K 转化为 N-K。将价值改写为删除尽量少的项使得原串不含 \tt ABC 子序列,则相当于问价值等于 N-K 的串有多少个,这就完成了 K\gets N-K 的转化(下文认为 K\gets N-K)。

考虑如何判断一个串是否不含任何 \tt ABC 子序列。进行一些感受后可以发现这相当于能够将原串划分为三段,第一段不含 \tt A,第二段不含 \tt B,且第三段不含 \tt C。据此我们可以写出计算一个序列的价值的代码:

int A=0,B=0,C=0;
for(int i=1;i<=N;i++){
    if(s[i]=='A')A=A+1;
    if(s[i]=='B')B=min(B+1,A);
    if(s[i]=='C')C=min(C+1,B);
}

其中变量 A,B,C 分别表示目前划分了 1,2,3 段的代价。

尝试计算价值 \ge K 的序列数量,一步差分即可求得原答案。将原代码改写为

int A=0,B=0,C=0;
for(int i=1;i<=N;i++){
    if(s[i]=='A')A=min(A+1,K);
    if(s[i]=='B')B=min(B+1,A);
    if(s[i]=='C')C=min(C+1,B);
}

若程序结束后有 A=B=C=K 则说明序列 s 的价值 \ge K

考察三元组 (A,B,C) 的变化,不难发现它会改变恰好 3K 次。考察在相邻两次变化中间出现的字符,不难发现这些字符均不会使 (A,B,C) 发生变化(废话)。设判定过程中有 1,2,3 种字符不使三元组发生变化的三元组分别出现了 p,q,r 次,则一个三元组序列对答案的贡献就是

[x^{N-3K}]\frac{1}{(1-x)^p(1-2x)^q(1-3x)^r}

不难发现我们有 p,q\le 3Kr=1(这是因为当且仅当 A=B=C=K 时三种字符均不会使三元组发生变化)。我们将原序列通分为

[x^{N-3K}]\frac{f(x)}{(1-x)^{3K}(1-2x)^{3K}(1-3x)}

则只要求出了 f(x) 的和我们就能获得答案的表达式,这个显然可以 O(K^4) 计算。

但是这样之后求答案还是有些困难。考虑分式分解,将原式转化为

\frac{f(x)}{(1-x)^{3K}(1-2x)^{3K}(1-3x)}=\frac{f_1(x)}{(1-x)^{3K}}+\frac{f_2(x)}{(1-2x)^{3K}}+\frac{f_3(x)}{1-3x}

则只要我们求出了 f_1(x),f_2(x),f_3(x) 就容易在 O(K) 时间内回答单组询问。接下来有两种求 f_1(x),f_2(x),f_3(x) 的方法:

  1. 高斯消元。将上述等式乘以 (1-x)^{3K}(1-2x)^{3K}(1-3x) 后对照系数可以获得 O(K) 个方程,高消即可求解。这样做是 O(K^4) 的,但因为常数原因可能过不了。
  2. 考虑 CRT。根据 CRT,我们容易将限制转化为如下三个方程:

    \begin{cases}f_1(x)(1-2x)^{3K}(1-3x)\equiv f(x)\pmod{(1-x)^{3K}}\\f_2(x)(1-x)^{3K}(1-3x)\equiv f(x)\pmod{(1-2x)^{3K}}\\f_3(x)(1-x)^{3K}(1-2x)^{3K}\equiv f(x)\pmod{(1-3x)}\end{cases}

    则关键在于求出形如 (1-wx)^n 的多项式在 \bmod (1-sx)^m 意义下的逆元。一种方法是 exgcd 算,简单分析一下可以发现一次求逆的复杂度是 O(K^2) 的;另一种方法是换元:设 y=1-sx,则相当于计算 (ky+b)^{-n}\bmod y^m,这是容易计算的,算完后将 y 换回去即可,一次求逆的复杂度仍然是 O(K^2) 的。这部分的复杂度即为 O(K^3)

综上,我们可以在 O(K^4+TK) 的复杂度内解决原问题。

Code

bool Mst;
#include<bits/stdc++.h>
using namespace std;
using ui=unsigned int;
using ll=long long;
using ull=unsigned long long;
using i128=__int128;
using u128=__uint128_t;
using pii=pair<int,int>;
#define fi first
#define se second
constexpr int N=1e6+105,K=105,mod=998244353;
inline ll add(ll x,ll y){return (x+=y)>=mod&&(x-=mod),x;}
inline ll Add(ll &x,ll y){return x=add(x,y);}
inline ll sub(ll x,ll y){return (x-=y)<0&&(x+=mod),x;}
inline ll Sub(ll &x,ll y){return x=sub(x,y);}
inline ll qpow(ll a,ll b){
    ll res=1;
    for(;b;b>>=1,a=a*a%mod)
        if(b&1)res=res*a%mod;
    return res;
}
using poly=vector<ll>;
const poly w1=poly{1,mod-1},w2=poly{1,mod-2},w3=poly{1,mod-3};
inline ostream& operator <<(ostream &ouf,const poly &f){
    ouf<<'{';
    if(f.size()){
        ouf<<f[0];
        for(int i=1;i<(int)f.size();i++)ouf<<", "<<f[i];
    }
    ouf<<'}';
    return ouf;
}
inline void shrink(poly &f){
    while(f.size()&&!f.back())
        f.pop_back();
}
inline poly operator +(const poly &f,const poly &g){
    poly h(max(f.size(),g.size()));
    for(int i=0;i<(int)f.size();i++)Add(h[i],f[i]);
    for(int i=0;i<(int)g.size();i++)Add(h[i],g[i]);
    return shrink(h),h;
}
inline poly operator -(const poly &f,const poly &g){
    poly h(max(f.size(),g.size()));
    for(int i=0;i<(int)f.size();i++)Add(h[i],f[i]);
    for(int i=0;i<(int)g.size();i++)Sub(h[i],g[i]);
    return shrink(h),h;
}
inline poly operator *(const poly &f,const poly &g){
    if(!f.size()||!g.size())return poly{};
    poly h(f.size()+g.size()-1);
    for(int i=0;i<(int)f.size();i++)
        for(int j=0;j<(int)g.size();j++)
            Add(h[i+j],f[i]*g[j]%mod);
    return shrink(h),h;
}
inline pair<poly,poly> div(const poly &f,const poly &g){
    if(f.size()<g.size())return make_pair(poly{},f);
    int lf=f.size(),lg=g.size();
    poly p(f.size()-g.size()+1),q(f);
    ll inv=qpow(g.back(),mod-2);
    for(int i=lf-1;i>=lg-1;i--){
        if(!q[i])continue;
        ll coef=q[i]*inv%mod;p[i-lg+1]=coef;
        for(int j=0;j<lg;j++)
            Sub(q[i-lg+j+1],coef*g[j]%mod);
    }
    return shrink(p),shrink(q),make_pair(p,q);
}
inline poly operator /(const poly &f,const poly &g){
    return div(f,g).fi;
}
inline poly operator %(const poly &f,const poly &g){
    return div(f,g).se;
}
inline void operator +=(poly &f,const poly &g){f=f+g;}
inline void operator -=(poly &f,const poly &g){f=f-g;}
inline void operator *=(poly &f,const poly &g){f=f*g;}
inline void operator /=(poly &f,const poly &g){f=f/g;}
inline void operator %=(poly &f,const poly &g){f=f%g;}
inline void exgcd(const poly &a,const poly &b,poly &x,poly &y){
    if(!b.size())
        x=poly{qpow(a[0],mod-2)},y=poly{};
    else{
        pair<poly,poly> o=div(a,b);
        exgcd(b,o.se,y,x),y-=o.fi*x;
    }
}
inline poly inv(const poly &p,const poly &q){
    poly x,y;
    exgcd(p,q,x,y);
    return x;
}
inline void decomp(const poly &f,const poly &p,const poly &q,const poly &r,poly &fp,poly &fq,poly &fr){
    fp=f%p*inv(q,p)%p*inv(r,p)%p;
    fq=f%q*inv(p,q)%q*inv(r,q)%q;
    fr=f%r*inv(p,r)%r*inv(q,r)%r;
}
ll fac[N],ifac[N],pw2[N],pw3[N],ipw2[N],ipw3[N];
inline void init(int n){
    fac[0]=1;
    for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;
    ifac[n]=qpow(fac[n],mod-2);
    for(int i=n;i>=1;i--)ifac[i-1]=ifac[i]*i%mod;
    pw2[0]=1;
    for(int i=1;i<=n;i++)pw2[i]=pw2[i-1]*2%mod;
    ipw2[n]=qpow(pw2[n],mod-2);
    for(int i=n;i>=1;i--)ipw2[i-1]=ipw2[i]*2%mod;
    pw3[0]=1;
    for(int i=1;i<=n;i++)pw3[i]=pw3[i-1]*3%mod;
    ipw3[n]=qpow(pw3[n],mod-2);
    for(int i=n;i>=1;i--)ipw3[i-1]=ipw3[i]*3%mod;
}
inline ll binom(int n,int m){
    if(m<0||m>n)return 0;
    return fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
poly f[K][K],g[K][K];
poly f1[K],f2[K],f3[K];
inline void Init(int n){
    f[0][0]=poly{1};
    poly p1{1},p2{1},p3=w3,q1=w1*w1*w1,q2=w2*w2*w2,q=q1*q2,qw1=q/w1,qw2=q/w2;
    for(int i=0;i<=n;i++,p1*=q1,p2*=q2){
        for(int j=0;j<=i;j++)
            for(int k=0;k<=j;k++)
                g[j][k]=f[j][k];
        for(int j=0;j<=i;j++)
            for(int k=0;k<=i;k++){
                int c=3-(j<i)-(k<j);
                if(c==1)g[j][k]/=w1;
                if(c==2)g[j][k]/=w2;
                if(j<i)g[j+1][k]+=g[j][k];
                if(k<j)g[j][k+1]+=g[j][k];
            }
        decomp(g[i][i],p1,p2,p3,f1[i],f2[i],f3[i]);
        for(int j=0;j<=i;j++)
            for(int k=0;k<=j;k++)
                g[j][k]=f[j][k]*q;
        for(int j=0;j<=i;j++)
            for(int k=0;k<=i;k++){
                int c=3-1-(j<i)-(k<j);
                if(c==1)g[j][k]/=w1;
                if(c==2)g[j][k]/=w2;
                f[j][k]=g[j][k];
                if(j<i)g[j+1][k]+=g[j][k];
                if(k<j)g[j][k+1]+=g[j][k];
            }
    }
}
inline ll qry(int n,int k){
    n-=k*3;
    if(n<0)return 0;
    ll ans=0;int k3=k*3;
    for(int i=0;i<=n&&i<(int)f1[k].size();i++)Add(ans,f1[k][i]*binom(n-i+k3-1,n-i)%mod);
    for(int i=0;i<=n&&i<(int)f2[k].size();i++)Add(ans,f2[k][i]*binom(n-i+k3-1,n-i)%mod*pw2[n-i]%mod);
    for(int i=0;i<=n&&i<(int)f3[k].size();i++)Add(ans,f3[k][i]*pw3[n-i]%mod);
    return ans;
}
bool Med;
int main(){
    cerr<<abs(&Mst-&Med)/1048576.0<<endl;
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    init(1000101);
    Init(101);
    int _Test;cin>>_Test;
    while(_Test--){
        int n,k;cin>>n>>k,k=n-k;
        cout<<sub(qry(n,k),qry(n,k+1))<<'\n';
    }
    return 0;
}