题解:AT_abc400_f [ABC400F] Happy Birthday! 3

· · 题解

在此处感谢翻译以及 @int_stl 大佬成功让我读错题面。

这题很显然是区间 dp,于是考虑转移。首先对于 c_l=c_{l+1} 以及 c_r=c_{r-1} 等情况进行特判,分别用 f_{l+1,r}f_{l,r-1} 来转移,当 c_l=c_r 时便考虑先将 l\sim r 都染成 c_l 的颜色,然后再对颜色不合法的区间进行转移。

然后是正常操作,也就是枚举断点 k,用 f_{l,k}+f_{k+1,r} 来转移;当然需要注意当 c_k=c_r 时我们可以进行特别的操作,也就是先将 l\sim r 染成 c_k 的颜色,然后对区间 [l,k][k+1,r-1] 进行操作,也就是 f_{l,k}+f_{k+1,r-1}+r-k,注意到转移式不应是 f_{l,k}+f_{k+1,r-1}+r-l+1+x_{c_k},因为 x_{c_k}k-l+1 已经被 f_{l,k} 算进去了。

Code

#include <bits/stdc++.h>
using namespace std;
#define int long long
int f[805][805], c[805], x[805];

signed main() {
    int n;
    cin >> n;
    memset(f, 0x3f, sizeof(f));
    for (int i = 1; i <= n; i++)
        cin >> c[i];
    for (int i = 1; i <= n; i++)
        cin >> x[i];
    for (int i = n + 1; i <= 2 * n; i++)
        c[i] = c[i - n];
    for (int i = 1; i <= 2 * n; i++)
        f[i][i] = 1 + x[c[i]];
    for (int len = 2; len <= 2 * n; len++) {
        for (int l = 1; l + len - 1 <= 2 * n; l++) {
            int r = l + len - 1;
            if (c[l] == c[l + 1])
                f[l][r] = min(f[l][r], f[l + 1][r] + 1);
            else
                f[l][r] = min(f[l][r], f[l + 1][r] + 1 + x[c[l]]);
            if (c[r - 1] == c[r])
                f[l][r] = min(f[l][r], f[l][r - 1] + 1);
            else
                f[l][r] = min(f[l][r], f[l][r - 1] + 1 + x[c[r]]);
            if (c[l] == c[r]) {
                int ll = l, rr = r;
                while (ll + 1 <= r && c[ll + 1] == c[l])
                    ll++;
                while (rr - 1 >= l && c[rr - 1] == c[r])
                    rr--;
                f[l][r] = min(f[l][r], f[ll + 1][rr - 1] + r - l + 1 + x[c[l]]);
            }
            for (int k = l; k <= r; k++)
                f[l][r] = min(f[l][r], f[l][k] + f[k + 1][r]);
            for (int k = l; k <= r; k++) {
                if (c[k] == c[r])
                    f[l][r] = min(f[l][r], f[l][k] + f[k + 1][r - 1] + r - k);
            }
        }
    }
    int ans = f[0][0];
    for (int i = 1; i + n - 1 <= 2 * n; i++)
        ans = min(ans, f[i][i + n - 1]);
    cout << ans;
}

也许我的方法中有一些多余的转移。