题解【P3311 [SDOI2014] 数数】

· · 题解

原题

前置知识:AC 自动机,数位 DP。

降级版相似题:[JSOI2007]文本生成器

[JSOI2007]文本生成器 题解:

这是一道 AC 自动机上 DP 的题目。AC 自动机上 DP 的基本思路是设 f[i][j] 表示当前串长 i、在 j 号节点。那么本题就可以采用这一思路。设 f[i][j] 表示当前串长 i、在 trie 上 j 号节点,且当前串不包含任何一个单词的方案数。由于在 trie 结构中从某号节点转移到它的子节点比子节点从父节点转移更加便捷,因此考虑用 f[i][j] 更新 f[i+1][trie[j][k]],k\in[0,26)

for(int b=0;b<26;b++)
    if(!trie[j][b])(f[i+1][0]+=f[i][j])%=mod;
    else if(!cnt[trie[j][b]])(f[i+1][trie[j][b]]+=f[i][j])%=mod;

上面的转移:当该子节点为空时,说明这个子节点对应的字母跟当前串不可能构成单词,因此一定可以接在后面,而后开启“新纪元”(回到 trie 根)。当该子节点不为空时,如果在 AC 自动机上跳着统计完发现不会统计到单词,则可以放,否则放上这个字母就会出现单词而不合法。代码中的 cnt 为根据 fail 树预处理出的每个点会统计到的答案。

#include <bits/stdc++.h>
using namespace std;
const int mod=1e4+7;
int n,m,tot,trie[6005][26],fail[6005],f[105][6005],cnt[6005];
string s;
vector<int>G[6005];
queue<int>Q;
void insert(){
    int rt=0;
    for(int i=0;i<s.length();i++){
        int b=s[i]-'A';
        if(!trie[rt][b])trie[rt][b]=++tot;
        rt=trie[rt][b];
    }
    cnt[rt]++;
}
void get_fail(){
    for(int i=0;i<26;i++)if(trie[0][i])Q.push(trie[0][i]),G[0].push_back(trie[0][i]);
    while(!Q.empty()){
        int x=Q.front();Q.pop();
        for(int i=0;i<26;i++){
            if(trie[x][i]){
                fail[trie[x][i]]=trie[fail[x]][i];
                G[fail[trie[x][i]]].push_back(trie[x][i]);
                Q.push(trie[x][i]);
            }
            else trie[x][i]=trie[fail[x]][i];
        }
    }
}
void sumup(int x){
    for(int i=0;i<G[x].size();i++)cnt[G[x][i]]+=cnt[x],sumup(G[x][i]);
}
int main(){
    cin>>n>>m;
    for(int i=1;i<=n;i++)cin>>s,insert();
    get_fail(),sumup(0);
    f[0][0]=1;
    for(int i=0;i<m;i++){
        for(int j=0;j<=tot;j++){
            int k=0;
            for(int b=0;b<26;b++)
                if(!trie[j][b])k++;
                else if(!cnt[trie[j][b]])(f[i+1][trie[j][b]]+=f[i][j])%=mod;
            (f[i+1][0]+=k*f[i][j]%mod)%=mod;
        }
    }
    int ans=0,pro=1;
    for(int i=1;i<=m;i++)pro=pro*26%mod;
    for(int i=0;i<=tot;i++)ans=(ans+f[m][i])%mod;
    cout<<((pro-ans)%mod+mod)%mod;
}

[SDOI2014]数数 题解:

这道题可以沿用上面 AC 自动机上 DP 的思想,且大同小异。小异在这题有了数位 DP“卡上界”的限制,但其实本质还是上面粗体字的基本思路。

f_{u,o,i,j}(u,o\in\{0,1\}) 为状态,u,o,i,j 的含义见下表:

u o i j
是(1)否(0)串的 1\sim i 位置都是前导 0 是(1)否(0)卡上界 当前串长 在 trie 上 j 号节点

于是:

//t为给定的大数字(字符串),下标从0到m(注意m与原题中不是同一个)
for(int i=0;i<m;i++){
    for(int u=0;u<=1;u++){
        for(int o=0;o<=1;o++){
            for(int j=0;j<=tot;j++){
                for(int b=0;b<=(o?t[i]-'0':9);b++){
                    if(u&&!b)f[1][0][i+1][0]=1;//如果还是都是前导0,则一定不卡上界,而且不转移
                    else if(!trie[j][b])(f[0][o&(b==(o?t[i]-'0':9))][i+1][0]+=f[u][o][i][j])%=mod;
                    else if(!cnt[trie[j][b]])(f[0][o&(b==(o?t[i]-'0':9))][i+1][trie[j][b]]+=f[u][o][i][j])%=mod;
                }
            }
        }
    }
}

答案即为 \sum_{u,o\in\{0,1\},0\le j\le tot}f[u][o][m][j]tot 表示 trie 中节点最大编号)。

#include <bits/stdc++.h>
using namespace std;
const int mod=1e9+7;
int n,m,tot,trie[1505][10],fail[1505],f[2][2][1205][1505],cnt[1505],f[1205],p[1205];
string t,s;
vector<int>f[1505];
queue<int>Q;
void insert(){
    int rt=0,l=s.length();
    for(int i=0;i<l;i++){
        int b=s[i]-'0';
        if(!trie[rt][b])trie[rt][b]=++tot;
        rt=trie[rt][b];
    }
    cnt[rt]++;
}
void get_fail(){
    for(int i=0;i<10;i++)if(trie[0][i])Q.push(trie[0][i]),f[0].push_back(trie[0][i]);
    while(!Q.empty()){
        int x=Q.front();Q.pop();
        for(int i=0;i<10;i++){
            if(trie[x][i]){
                fail[trie[x][i]]=trie[fail[x]][i];
                f[fail[trie[x][i]]].push_back(trie[x][i]);
                Q.push(trie[x][i]);
            }
            else trie[x][i]=trie[fail[x]][i];
        }
    }
}
void sumdn(int x){
    for(int i=0;i<f[x].size();i++)cnt[f[x][i]]+=cnt[x],sumdn(f[x][i]);
}
int main(){
    cin>>t>>n;m=t.length();
    for(int i=1;i<=n;i++)cin>>s,insert();
    get_fail(),sumdn(0);
    f[1][1][0][0]=1;
    for(int i=0;i<m;i++){
        for(int u=0;u<=1;u++){
            for(int o=0;o<=1;o++){
                for(int j=0;j<=tot;j++){
                    for(int b=0;b<=(o?t[i]-'0':9);b++){
                        if(u&&!b)f[1][0][i+1][0]=1;
                        else if(!trie[j][b])(f[0][o&(b==(o?t[i]-'0':9))][i+1][0]+=f[u][o][i][j])%=mod;
                        else if(!cnt[trie[j][b]])(f[0][o&(b==(o?t[i]-'0':9))][i+1][trie[j][b]]+=f[u][o][i][j])%=mod;
                    }
                }
            }
        }
    }
    int ans=0;
    for(int i=0;i<=tot;i++)ans=(((((((ans+=f[0][0][m][i])%=mod)+=f[1][0][m][i])%=mod)+=f[0][1][m][i])%=mod)+=f[1][1][m][i])%=mod;
    cout<<ans-1;
}