P6839 [BJOI2016]打字机 题解

· · 题解

这是一道三合一的题目,我们将数据点 1~6,7~8,9~10 分别处理。

数据点 1~6

一个经典的 AC 自动机上 dp。(BJOI 传统艺能)

考虑到打错字符这一扰动的特殊性我们并不能直接正向做决策,但是可以倒序决策,这样就能把扰动的“最坏”与正常决策的“最好”一起转移。

先将串建成 AC 自动机,然后对于每个串在结束节点添加价值,最后每个节点的价值 val 定义为其在 fail 树上到根的价值和。

dp_{i,j,k} 表示在填后 i 个字符,起始在自动机上的节点 j,在之前已经打错了 k 次时,能够在后面获得的最大价值。

转移的时候,分类讨论是否打错,如果打错则取最坏情况,否则取最好情况,在此基础上再选择最坏情况。因此有转移:

dp_{i,j,k} = \min \{ \max_{x \in to_j} \{dp_{i-1,x,k}\}, \min_{x \in to_j}\{ dp_{i-1,x,k+1}\}\}+val_j

直接做,复杂度为 O(mk|\Sigma|\sum|S|)

数据点 7~8

没有扰动了,因此可以简化成一个更简单的 dp。

先同样建出自动机,并得出节点的价值。

dp_{i,j} 为填了 i 个字符,当前在自动机上的节点 j 时已经获得的最大价值。有转移:

dp_{i,j} = \max_{j \in to_x} \{ dp_{i-1,x}\}+val_j

注意到这个转移式子是简洁的 (max,+) 递推,可以使用类似 【NOI2020】美食家 的方法,用倍增或者矩阵的形式加速 dp。

数据点 9~10

略显玄学的部分。

观察到扰动最多 5 次,\sum|S| 只有 50,而 m 高达 10^9。这说明在两次扰动之间会隔有很长的一段。对于这一段,其决策不受扰动影响。

如果这一段的长度是数倍于自动机节点个数的,则可以预见到,最优的决策一定会经过一个环,在环上不断地走,获得环上的价值。

在一次扰动之后,可能会有多种情况,比如走上另一个环,或者走一条路径等,但实际上也都是一个类似之前决策的过程。

如果走上的路径长度是自动机节点个数的数倍,那必然不如走一个环有意义。所以走路径只会是最后几步的决策。

而对于不停走一个环的情况,在开始循环后在第几次循环扰动都是等价的。

因此可以声明:存在一种符合题意的最优情况,使得所有扰动都集中在末尾,且末尾这段的长度是 O(k\sum|S|) 级别。

那么对于前面的部分视作没有扰动,按数据点 7~8 的方法处理,最后选择一个末尾长度(理论上需要是 k\sum|S| 的数倍以上,我选择了 2000)按数据点 1~6 的方法处理,然后拼起来计算答案。

另外据说官方题解是利用 a_i=1 性质的一个所谓“最小比率环”的做法,但是作者实在找不到相关资源,只能不了了之了。

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int N=505,N2=205;
char s[N][N];
int val[N],son[N][26],fail[N],ndc=1;
int n,m,k,que[N],h,t;
inline void ins_AC(int id,int v)
{
    int i,len=strlen(s[id]+1),now=1;
    for(i=1;i<=len;i++)
    {
        int c=s[id][i]-'a';
        if(!son[now][c]) son[now][c]=++ndc;
        now=son[now][c];
    }
    val[now]+=v;
}
void build_AC()
{
    for(int i=0;i<26;i++)
    {
        if(son[1][i]) fail[son[1][i]]=1,que[t++]=son[1][i];
        else son[1][i]=1;
    }
    while(h<t)
    {
        int now=que[h++];
        for(int i=0;i<26;i++)
        {
            if(son[now][i]) fail[son[now][i]]=son[fail[now]][i],que[t++]=son[now][i];
            else son[now][i]=son[fail[now]][i];
        }
    }
    for(int i=0;i<t;i++)
    {
        int now=que[i];
        val[now]+=val[fail[now]];
    }
}
int dp[N*5][N][6];
void solve1()
{
    int i,j,tk,x;
    for(i=1;i<=ndc;i++)
        for(j=0;j<=k;j++)
            dp[0][i][j]=val[i];
    for(i=1;i<=m;i++)
    {
        for(j=1;j<=ndc;j++)
        {
            for(tk=0;tk<=k;tk++)
            {
                int mx=0,mn=1e9;
                for(x=0;x<26;x++) mx=max(mx,dp[i-1][son[j][x]][tk]);
                if(tk<k) {for(x=0;x<26;x++) mn=min(mn,dp[i-1][son[j][x]][tk+1]);}
                dp[i][j][tk]=min(mx,mn)+val[j];
            }
        }
    }
}
long long dp2[35][N2][N2],tdp[N2],tdp2[N2];
void solve2()
{
    int i,j,tk,x;
    for(i=1;i<=ndc;i++) 
    {
        for(j=1;j<=ndc;j++) dp2[0][i][j]=-1;
        for(j=0;j<26;j++) dp2[0][i][son[i][j]]=val[i]+val[son[i][j]];
    }
    for(i=1;(1<<i)<=m;i++)
    {
        for(j=1;j<=ndc;j++)
        {
            for(tk=1;tk<=ndc;tk++)
            {
                long long mx=-1;
                for(x=1;x<=ndc;x++) if(dp2[i-1][j][x]!=-1&&dp2[i-1][x][tk]!=-1) mx=max(mx,dp2[i-1][j][x]+dp2[i-1][x][tk]-val[x]);
                dp2[i][j][tk]=mx;
            }
        }
    }
}
int main()
{
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i<=n;i++)
    {
        int v;
        scanf("%s%d",s[i]+1,&v);
        ins_AC(i,v);
    }
    build_AC();
    if(m<=500)
    {
        solve1();
        printf("%d",dp[m][1][0]);
    }
    else if(k==0)
    {
        solve2();
        int i,j,tk;
        for(i=1;i<=ndc;i++) tdp[i]=tdp2[i]=-1;
        tdp[1]=0;
        for(i=0;(1<<i)<=m;i++)
        {
            if(!((1<<i)&m)) continue;
            for(j=1;j<=ndc;j++)
            {
                if(tdp[j]==-1) continue;
                for(tk=1;tk<=ndc;tk++)
                {
                    if(dp2[i][j][tk]==-1) continue;
                    tdp2[tk]=max(tdp2[tk],tdp[j]+dp2[i][j][tk]-val[j]);
                }
            }
            for(j=1;j<=ndc;j++) tdp[j]=tdp2[j],tdp2[j]=-1;
        }
        long long ans=0;
        for(i=1;i<=ndc;i++) ans=max(ans,tdp[i]);
        printf("%lld",ans);
    }
    else
    {
        if(m<=2500)
        {
            solve1();
            printf("%d",dp[m][1][0]);
            return 0;
        }
        solve2();
        int i,j,tk;
        for(i=1;i<=ndc;i++) tdp[i]=tdp2[i]=-1;
        tdp[1]=0;
        for(i=0;(1<<i)<=m-2000;i++)
        {
            if(!((1<<i)&(m-2000))) continue;
            for(j=1;j<=ndc;j++)
            {
                if(tdp[j]==-1) continue;
                for(tk=1;tk<=ndc;tk++)
                {
                    if(dp2[i][j][tk]==-1) continue;
                    tdp2[tk]=max(tdp2[tk],tdp[j]+dp2[i][j][tk]-val[j]);
                }
            }
            for(j=1;j<=ndc;j++) tdp[j]=tdp2[j],tdp2[j]=-1;
        }
        m=2000;
        solve1();
        long long ans=0;
        for(i=1;i<=ndc;i++) if(tdp[i]!=-1&&dp[m][i][0]!=-1) ans=max(ans,tdp[i]+dp[m][i][0]-val[i]);
        printf("%lld",ans);
    }
    return 0;
}