P7420 题解

· · 题解

无需多项式,直接转换一下题意然后 dp 即可。

首先可以发现拆分和字符串具体内容无关,只和 n 也就是 c(a,b) 有关。

第一反应是求元素总和等于 n 的上升序列的数量,但观察一下样例就知道这是假的。

为什么呢?看样例,n=8 却可以拆分成 \text{noi(1)noinonoin(2)oiionoinoinoionoi(4)} 总和为 7 的上升序列。观察发现,这时第二段多拿了一个 \text n,导致一个原本会产生贡献的 \text{noi} 被拆开而没有产生贡献。

具体地,一个长为 k 的上升序列元素总和是 [n-k+1,n] 都是合法,因为每次拆出一个新段时都可能拆开一个 b 串使贡献少 1

然后就很简单了,设 f_{i,j} 表示长为 i 元素总和为 j 的上升序列的个数,显然有 f_{i,j}=f_{i-1,j-i}+f_{i,j-i}

元素总和为 n 的上升序列长度是 O(\sqrt n) 级别的,因此 dp 的时间复杂度为 O(n\sqrt n)

注意会卡内存,滚动数组优化一下。

#include <iostream>
#include <cstdio>
#include <cstring>
#define bas 917
#define mod 899678209
using namespace std;
typedef long long ll;
ll n,m,n1,n2,h1[1000010],h2,f[2][200010],ans;
char str[1000010];
int main() {
    scanf("%s",str+1);
    n1=strlen(str+1);
    for(int i=1;i<=n1;++i) h1[i]=(h1[i-1]*bas+str[i]-'a'+1)%mod;
    scanf("%s",str+1);
    n2=strlen(str+1);
    ll pw=1;
    for(int i=1;i<=n2;++i) {
        h2=(h2*bas+str[i]-'a'+1)%mod;
        pw=pw*bas%mod;
    }
    for(int i=n2;i<=n1;++i)
        if(((h1[i]-h1[i-n2]*pw)%mod+mod)%mod==h2) {
            ++n;
            i+=n2-1;
        }
    ll l=1,r=n,mid;
    while(l<=r) {
        mid=l+r>>1;
        if((mid*(mid+1)>>1)<=n) {
            m=mid;
            l=mid+1;
        }
        else r=mid-1;
    }
    if((m*(m+1)>>1)<n) ++m;
    f[0][0]=1;
    for(int i=1;i<=m;++i) {
        ll x=i&1,y=x^1;
        for(int j=0;j<=n;++j)
            if(i>j) f[x][j]=0;
            else {
                f[x][j]=f[y][j-i]+f[x][j-i];
                if(f[x][j]>=mod) f[x][j]-=mod;
            }
        if(i==1) continue;
        for(int j=n-i+1;j<=n;++j) {
            ans+=f[x][j];
            if(ans>=mod) ans-=mod;
        }
    }
    printf("%lld",ans);
    return 0;
}