CF1922F 题解

· · 题解

感觉大部分(也许是所有)题解都没说到点上,于是我来发一篇。

思路

结论:存在一种最优解使得区间之间没有相交且不包含的关系。

证明:考虑将连续相同的数缩成一个点,那么操作一次相当于将这个区间缩成一个点,如果能证明这样操作答案不变,就证明了结论(不会有区间与一个长度为 1 的区间相交且不包含)。考虑归纳证明,根据最优解的次数从小到大归纳。如果之后存在某次操作的区间包含了当前区间的一部分,那么这个的答案相当于将所有都操作了的数列加上一个点(因为根据归纳证明,若数都相同则可以缩成一个)。而加上一个点一定不如不加,所以这样操作不优。

因此操作的区间形成了树形结构,可以使用区间 dp。设 dp_{i,j,k,0/1} 表示将区间 [l,r] 全变成不是/是 k 的数需要的操作次数。首先判断区间是否已经满足了。否则根据当前区间是否操作可以得到转移:

dp_{i,j,k,0}=\min(\min_{l=i}^{j-1} dp_{i,l,k,0}+dp_{l+1,j,k,0},\min_{l=1}^x dp_{i,j,l,0}+1)

这个转移方程的意思是,如果当前区间不操作,那么它一定需要分成两半,并分别操作。否则,当前区间会变成一个不为 k 的数 l,在此之前,区间需要满足全都不为 l,需要的次数为 dp_{i,j,l,0}

dp_{i,j,k,1}=\min(\min_{l=i}^{j-1} dp_{i,l,k,1}+dp_{l+1,j,k,1},dp_{i,j,k,0}+1)

这个转移方程和上面的类似,如果当前区间不操作,则它需要分成两半分别操作。否则,当前区间操作的前提是区间所有数都不为 k,次数即为 dp_{i,j,k,0}

在第一个转移方程中存在互相转移的情况。但是可以发现只有 dp 值最小的才会更新别人,而它不会被别人更新,所以直接先求出来第一部分的 dp 值,再求出最小值,用它加一更新别人。

这样暴力做的复杂度是 \mathcal O(n^3x) 已经可以通过,不过这个做法还可以优化。

发现对于一个左端点,使用特定次数操作使得满足条件的右端点存在单调性,所以考虑交换下标和值。设 dp_{i,j,k,0/1} 表示左端点为 i、操作小于等于 j 次能使区间全变成不是/是 k 的数的最大的右端点。转移和上面类似。这样做的好处是,状态和转移中都有关于答案大小的枚举,而答案大小并不很大。考虑将序列中出现次数最小的数作为最后相等的数,它的出现次数 \le \dfrac{n}{x},那么将它把序列划分的次数{}+1 段都操作一次,次数 \le \dfrac{n}{x}+1。所以答案的上界就是这个数。复杂度为 \mathcal O(nx+\dfrac{n^3}{x})

代码

暴力的代码:

#include<bits/stdc++.h>
using namespace std;
int n,m;
int a[110];
int dp[110][110][110];
int ans[110][110][110];
int main()
{
    int t; cin>>t; while(t--)
    {
        cin>>n>>m;
        for(int i=1; i<=n; ++i) cin>>a[i];
        for(int len=1; len<=n; ++len)
        {
            for(int i=1; i<=n-len+1; ++i)
            {
                int j=i+len-1;
                for(int k=1; k<=m; ++k)
                {
                    bool flag=1;
                    for(int l=i; l<=j; ++l)
                    {
                        if(a[l]==k) { flag=0; break; }
                    }
                    if(flag) { dp[i][j][k]=0; continue; }
                    dp[i][j][k]=1e9;
                    for(int l=i; l<j; ++l) dp[i][j][k]=min(dp[i][j][k],dp[i][l][k]+dp[l+1][j][k]);
                }
                for(int k=1; k<=m; ++k)
                {
                    bool flag=1;
                    for(int l=i; l<=j; ++l)
                    {
                        if(a[l]!=k) { flag=0; break; }
                    }
                    if(flag) { ans[i][j][k]=0; continue; }
                    ans[i][j][k]=1e9;
                    for(int l=i; l<j; ++l) ans[i][j][k]=min(ans[i][j][k],ans[i][l][k]+ans[l+1][j][k]);
                }
                int in=1e9;
                for(int k=1; k<=m; ++k) in=min(in,dp[i][j][k]);
                for(int k=1; k<=m; ++k) dp[i][j][k]=min(dp[i][j][k],in+1),ans[i][j][k]=min(ans[i][j][k],dp[i][j][k]+1);
            }
        }
        int aans=1e9;
        for(int i=1; i<=m; ++i) aans=min(aans,ans[1][n][i]);
        cout<<aans<<'\n';
    }
    return 0;
}

优化后的代码:

#include<bits/stdc++.h>
using namespace std;
int n,m;
int a[110];
int dp[110][110][110];
int ans[110][110][110];
int lst[110][110],llst[110][110];
int main()
{
    int t; cin>>t; while(t--)
    {
        cin>>n>>m;
        for(int i=1; i<=n; ++i) cin>>a[i];
        memset(dp,0,sizeof(dp));
        memset(ans,0,sizeof(ans));
        int lim=n/m+1,aans=1e9;
        for(int i=1; i<=m; ++i) lst[n+1][i]=llst[n+1][i]=n;
        for(int i=n; i>=1; --i)
        {
            for(int j=1; j<=m; ++j)
            {
                if(a[i]==j) lst[i][j]=i-1,llst[i][j]=llst[i+1][j];
                else lst[i][j]=lst[i+1][j],llst[i][j]=i-1;
            }
            for(int j=0; j<=lim; ++j)
            {
                for(int k=1; k<=m; ++k)
                {
                    if(j==0) { dp[i][j][k]=lst[i][k],ans[i][j][k]=llst[i][k]; continue; }
                    for(int l=0; l<j; ++l)
                    {
                        if(dp[i][l][k]==n) dp[i][j][k]=n;
                        else dp[i][j][k]=max(dp[i][j][k],dp[dp[i][l][k]+1][j-l][k]);
                        if(ans[i][l][k]==n) ans[i][j][k]=n;
                        else ans[i][j][k]=max(ans[i][j][k],ans[ans[i][l][k]+1][j-l][k]);
                    }
                }
                if(j!=0)
                {
                    int ax=0;
                    for(int k=1; k<=m; ++k) ax=max(ax,dp[i][j-1][k]);
                    for(int k=1; k<=m; ++k)
                    {
                        dp[i][j][k]=max(dp[i][j][k],ax),ans[i][j][k]=max(ans[i][j][k],dp[i][j-1][k]);
                        dp[i][j][k]=lst[dp[i][j][k]+1][k];
                        ans[i][j][k]=llst[ans[i][j][k]+1][k];
                    }
                }
                if(i==1)
                {
                    for(int k=1; k<=m; ++k) if(ans[i][j][k]==n) aans=min(aans,j);
                }
            }
        }
        cout<<aans<<'\n';
    }
    return 0;
}