题解【AGC058D】Yet Another ABC String

· · 题解

显然要考虑容斥。一般的容斥是枚举不合法的字符串位置,但这题不合法的字符串可能会重叠,比较难以计算。一种好的容斥方式是容斥形如 ABCABCABCA... 这样的连续段,只要每个极长连续段长度都不超过 2 就行了。

这样会产生一个新的问题:我们并不能保证连续段是极长的。我们需要找到一个容斥系数,使得长度 \le 2 的连续段的所有划分方式容斥系数之和为 1,否则为 0。设 F(x) 为容斥系数的生成函数,应该有:

\begin{aligned} \sum_{i\ge 0}F(x)^i&=1+x+x^2\\ \frac{1}{1-F(x)}&=1+x+x^2\\ F(x)&=1-\frac{1}{1+x+x^2} \end{aligned}

通过暴力多项式求逆或者手算,可以得到:

[x^k]F(x)=\begin{cases} 0 & k=0 \operatorname{or} k\equiv 2\pmod 3\\ 1 & k\equiv 1\pmod 3\\ -1 & k\equiv 0\pmod 3 \end{cases}

我们只关心 \not= 0 的项。计算答案可以利用三元生成函数:

\begin{aligned} G&=(-3\sum_{i\ge 1}(abc)^i)+(\sum_{i\ge 0}(abc)^i(a+b+c))\\ ans&=[a^{A}b^{B}c^{C}]\sum_{i\ge 0}G^i=[a^{A}b^{B}c^{C}]\frac{1}{1-G} \end{aligned}

其中 -3 的系数是因为连续段长度为 3 的倍数时,有三种方案。整理一下:

\begin{aligned} G&= a+b+c+\sum_{i\ge 1}(abc)^i(a+b+c-3)\\ &= a+b+c+(a+b+c-3)\frac{abc}{1-abc}\\ &=\frac{(1-abc)(a+b+c)+abc(a+b+c-3)}{1-abc}\\ &=\frac{a+b+c-3abc}{1-abc}\\ \frac{1}{1-G}&=\frac{1-abc}{1-a-b-c+2abc}=(1-abc)(\sum_{i\ge 0}(a+b+c-2abc)^i) \end{aligned}

那么我们只需要枚举右边选了几个 -2abc 就可以 O(1) 求出答案。时间复杂度 O(n)

#include<bits/stdc++.h>
#define For(i,a,b) for(int i=(a);i<=(b);++i)
#define Rof(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
const int Maxn=3e6,Mod=998244353;

inline int Pow(int x,int y)
{
    int res=1;
    while(y)
    {
        if(y&1) res=1ll*res*x%Mod;
        x=1ll*x*x%Mod,y>>=1;
    }
    return res;
}

int A,B,C,n,fac[Maxn+5],inv[Maxn+5],pw[Maxn+5];
inline int F(int a,int b,int c)
{
    int m=min(a,min(b,c)),res=0; if(m<0) return 0;
    For(i,0,m)
    {
        int k=1ll*pw[i]*(i&1?Mod-1:1)%Mod;
        res=(res+1ll*k*fac[a+b+c-i-i]%Mod*inv[i]%Mod*
             inv[a-i]%Mod*inv[b-i]%Mod*inv[c-i])%Mod;
    } return res;
}

int main()
{
    cin>>A>>B>>C; n=A+B+C,fac[0]=inv[0]=pw[0]=1;
    For(i,1,n) pw[i]=2*pw[i-1]%Mod;
    For(i,1,n) fac[i]=1ll*fac[i-1]*i%Mod;
    inv[n]=Pow(fac[n],Mod-2);
    Rof(i,n-1,1) inv[i]=1ll*inv[i+1]*(i+1)%Mod;
    cout<<(F(A,B,C)-F(A-1,B-1,C-1)+Mod)%Mod<<endl;
    return 0;
}