题解:CF18E Flag 2

· · 题解

题解:CF18E Flag 2

UPD:增加了亿点点内容。

一道调了 5 个小时的 DP 题。

解题思路

我们由题面可推知,每一行的字母都是交替摆放的,所以可以去枚举每行所使用的两个字母分别是什么就可以了。用图片来解释一下,也就是,只能按照下图的方法放字母:

我首先写了一个记搜,结果:TLE。

这里的主要问题是记忆化不一定每个位置都能一次跑到最优,比如 CF 上的 #3,我最后一层记搜被跑了整整 5 遍!

有句话说得好,叫做:“任何记搜,都能写成 DP。”所以我将代码改成了 DP (这就是为什么 DP 代码里面会有 memory 这个数组名)

此时,我们的思路可以大致理为:

接下来是最重要的推 DP 转移方程。若我们设行数为 i,上一行所出现的两个数为:j , k(列数为 1 时也就当它有两个数就行),操作次数为 memory_{i , j , k}。则最后的转移方程为:

dp_{i , j , k} = \min(dp_{i , j1 , k1} , dp_{i - 1 , j2 , k2} + memory_{i , j1 , k1})

由此,写出代码不难。但是要注意,因为我们的处理顺序是从下到上,所以我们要逆序输出。

参考代码

#include<bits/stdc++.h>
#pragma GCC optimize(2)//卡常
using namespace std;
int n , m , dp[505][30][30] , dp2[505][30][30] , dp3[505][30][30] , memory[505][30][30] , ansn , minj , mink , p , q , tmp1 , tmp2 , cost , tmp;
char a[505][505] , ans[505][505];
int main()
{
    ios::sync_with_stdio(0);
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> m;
    for(int i = 1 ; i <= n ; i++)
    {
        for(int j = 1 ; j <= m ; j++)
        {
            cin >> a[i][j];
        }
    }
    for(int i = 0 ; i <= n ; i++)
    {
        for(int j = 0 ; j <= 'z' ; j++)
        {
            for(int k = 0 ; k <= 'z' ; k++)
            {
                if(!i)//卡常
                {
                    dp[i][j][k] = 0;
                }
                else
                {
                    dp[i][j][k] = 1000000007;
                }
            }
        }
    }
    for(int i = 1 ; i <= n ; i++)
    {
        for(int j1 = 0 ; j1 <= 25 ; j1++)
        {   
            for(int k1 = 0 ; k1 <= 25 ; k1++)
            {
                if(j1 != k1)//卡常
                {
                    for(int j2 = 0 ; j2 <= 25 ; j2++)
                    {
                        if(j1 != j2)//卡常
                        {
                            for(int k2 = 0 ; k2 <= 25 ; k2++)
                            {
                                if(k1 != k2 && j2 != k2)
                                {
                                    if(memory[i][j2][k2] != 0)
                                    {
                                        cost = memory[i][j2][k2];
                                    }
                                    else
                                    {
                                        tmp = 0;
                                        for(int l = 1 ; l <= m ; l++)
                                        {
                                            if(l % 2)
                                            {
                                                tmp += (a[i][l] != j2 + 'a');
                                            }
                                            else
                                            {
                                                tmp += (a[i][l] != k2 + 'a');
                                            }
                                        }
                                        memory[i][j2][k2] = tmp;
                                        cost = tmp;
                                    }
                                    if(dp[i - 1][j1][k1] + cost < dp[i][j2][k2])
                                    {
                                        dp[i][j2][k2] = dp[i - 1][j1][k1] + cost;
                                        dp2[i][j2][k2] = j1;
                                        dp3[i][j2][k2] = k1;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    ansn = 1000000007;
    for(int i = 0 ; i <= 25 ; i++)
    {
        for(int j = 0 ; j <= 25 ; j++)
        {
            if(dp[n][i][j] < ansn)
            {
                ansn = dp[n][i][j];
                minj = i;
                mink = j;
            }
        }
    }
    cout << ansn << '\n';
    p = minj;
    q = mink;
    for(int j = n ; j >= 1 ; j--)
    {
        for(int i = 1 ; i <= m ; i++)
        {
            if(i % 2)
            {
                ans[j][i] = p + 'a';
            }
            else
            {
                ans[j][i] = q + 'a';
            }
        }
        tmp1 = dp2[j][p][q];
        tmp2 = dp3[j][p][q];
        p = tmp1;
        q = tmp2;
    }
    for(int i = 1 ; i <= n ; i++)
    {
        for(int j = 1 ; j <= m ; j++)
        {
            cout << ans[i][j];
        }
        cout << "\n";
    }
    return 0;
}

卡常 1.9\text{s} 过。

优化介绍

优化后的参考代码:

#include<bits/stdc++.h>
#pragma GCC optimize(2)//卡常
using namespace std;
int n , m , dp[505][30][30][5] , memory[505][30][30] , ansn , minj , mink , p , q , tmp1 , tmp2 , cost , tmp;
char a[505][505] , ans[505][505];
int main()
{
    ios::sync_with_stdio(0);
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> m;
    for(int i = 1 ; i <= n ; i++)
    {
        for(int j = 1 ; j <= m ; j++)
        {
            cin >> a[i][j];
        }
    }
    for(int i = 0 ; i <= n ; i++)
    {
        for(int j = 0 ; j <= 'z' ; j++)
        {
            for(int k = 0 ; k <= 'z' ; k++)
            {
                if(!i)//卡常
                {
                    dp[i][j][k][1] = 0;
                }
                else
                {
                    dp[i][j][k][1] = 1000000007;
                }
            }
        }
    }
    for(int i = 1 ; i <= n ; i++)
    {
        for(int j1 = 0 ; j1 <= 25 ; j1++)
        {   
            for(int k1 = 0 ; k1 <= 25 ; k1++)
            {
                if(j1 != k1)//卡常
                {
                    for(int j2 = 0 ; j2 <= 25 ; j2++)
                    {
                        if(j1 != j2)//卡常
                        {
                            for(int k2 = 0 ; k2 <= 25 ; k2++)
                            {
                                if(k1 != k2 && j2 != k2)
                                {
                                    if(memory[i][j2][k2] != 0)
                                    {
                                        cost = memory[i][j2][k2];
                                    }
                                    else
                                    {
                                        tmp = 0;
                                        for(int l = 1 ; l <= m ; l++)
                                        {
                                            if(l & 1) 
                                            {
                                                tmp += (a[i][l] != j2 + 'a');
                                            }
                                            else
                                            {
                                                tmp += (a[i][l] != k2 + 'a');
                                            }
                                        }
                                        memory[i][j2][k2] = tmp;
                                        cost = tmp;
                                    }
                                    if(dp[i - 1][j1][k1][1] + cost < dp[i][j2][k2][1])
                                    {
                                        dp[i][j2][k2][1] = dp[i - 1][j1][k1][1] + cost;
                                        dp[i][j2][k2][2] = j1;
                                        dp[i][j2][k2][3] = k1;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    ansn = 1000000007;
    for(int i = 0 ; i <= 25 ; i++)
    {
        for(int j = 0 ; j <= 25 ; j++)
        {
            if(dp[n][i][j][1] < ansn)
            {
                ansn = dp[n][i][j][1];
                minj = i;
                mink = j;
            }
        }
    }
    cout << ansn << '\n';
    p = minj;
    q = mink;
    for(int j = n ; j >= 1 ; j--)
    {
        for(int i = 1 ; i <= m ; i++)
        {
            if(i & 1)//卡常 
            {
                ans[j][i] = p + 'a';
            }
            else
            {
                ans[j][i] = q + 'a';
            }
        }
        tmp1 = dp[j][p][q][2];
        tmp2 = dp[j][p][q][3];
        p = tmp1;
        q = tmp2;
    }
    for(int i = 1 ; i <= n ; i++)
    {
        for(int j = 1 ; j <= m ; j++)
        {
            cout << ans[i][j];
        }
        cout << "\n";
    }
    return 0;
}

优化后用时减少了 400\text{ms},最后用时 1.5\text{s}