题解:P11126 [ROIR 2024] 三等分的数组 (Day 2)

· · 题解

要想到上界 O(m^3) 的 dp 是简单的,重点在于复杂度分析。

dp_{i,j,k} 表示当前考虑到值为 i 的数字,当前数字剩余 j 个,上一个数字剩余 k 个的方案数。容易有转移:

dp_{i,j,k}\to dp_{i+1,a_{i+1}-x,j-x},3\mid k-x

滚动数组加枚举 j,倒序枚举 x 即可做到一个上界 O(m^3) 的 dp。

由于数的总个数和 m 同阶,也许我们认为的上界会很松。

仔细分析一下,时间复杂度的贡献应该是 \sum_{i=1}^{m} cnt_{i-1}\times cnt_{i} 的,有 (a+b)^2\ge a^2+b^2\ge 2ab,所以上界应是 \left(\sum_{i=0}^{m}cnt_i\right)^2=n^2

#include<bits/stdc++.h>
using namespace std;
inline int read(){
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
    return x*f;
}
bool mbs;
#define ll long long
const int mod=1e9+7;
const int maxn=5e3+20;
ll sum[4],dp[2][maxn][maxn],n,a[maxn],m;
bool mbt;
int main(){
//  cerr<<(&mbs-&mbt)/1024.0/1024.0<<endl;
    n=read(),m=read();int x;
    for(int i=1;i<=n;i++) x=read(),a[x]++;
    dp[0][0][0]=1;
    for(int i=0;i<m;i++){
        int cur=(i&1);
        for(int j=0;j<=a[i];j++) for(int k=0;k<=a[i+1];k++) dp[cur^1][k][j]=0;
        for(int j=0;j<=a[i];j++){
            sum[0]=sum[1]=sum[2]=0;
            for(int x=(i==0?0:a[i-1]);x>=0;x--){
                sum[x%3]=(sum[x%3]+dp[cur][j][x])%mod;
                if(x<=min<ll>(j,a[i+1])) dp[cur^1][a[i+1]-x][j-x]=(dp[cur^1][a[i+1]-x][j-x]+sum[x%3])%mod;
            }
        }
    }
    ll ans=0;
    for(int i=0;i<=a[m-1];i++) for(int j=0;j<=a[m];j++) if(i%3==0&&j%3==0) ans=(ans+dp[m&1][j][i])%mod;
    printf("%lld\n",ans); 
    return 0;
}
/*
3 1
1 1 1
*/