P7537 [COCI2016-2017#4] Rima

· · 题解

很有趣的一道题(?),模拟赛的时候没写出来,老师讲的时候是真的感觉茅塞顿开。

仔细读题,想想最终的序列能长成什么样子。不难发现,最终的序列上下相邻的两个字符串的长度差距一定为 01

由于没有重复的字符串,所以一定不会出现一个型如 Sc+SS 的子序列(此处 S 代表字符串,c 代表字符)。也就是说,最终的字符串长度序列不会在中间出现峰值(只可能在序列两侧出现)。

这是字符串序列最终可能出现的一种情况,我们只需要找出最多个字符串,使其形成的序列能呈现出这样的性质即可。

考虑相邻的两个字符串。将它们分别反向插入一棵 trie 树中并标记出结尾所在的点,一定会呈现出以下两种情况的任意一种。(红点、蓝点分别表示两个字符串的第一个字符,也可以理解为反向后的最后一个字符)

其中,第一种情况对应长度相同、首字符不同的情况;第二种情况对应长度相差 1,一个字符串为另一个字符串的后缀的的情况。

结合上面推出来的两个结论,就可以解决这题了。首先将所有的字符串反向插入 trie 树,然后在这棵 trie 树上做一次树形 dp 即可。

我们需要找出一棵被标记节点最多的子树,这棵子树需要满足:

  1. 根节点只有 1 \sim 2 个子节点。
  2. 除根节点外,其他所有节点均一定被标记。
  3. 除根节点外,每个节点的儿子中最多只有一个节点不是叶子。

接下来是最不重要的代码。(已经尽量减少压行)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define rep(i, a, b) for(int i=(a), i##up=(b); i<=i##up; ++i)
#define ref(i, a) for(int i=1, i##up=(a); i<=i##up; ++i)
#define rer(i, a, b) for(int i=(a), i##dn=(b); i>=i##dn; --i)

inline int read() {
    int t=0, f=1; char c;
    while(!isdigit(c=getchar())) f=c^45;
    while(isdigit(c)) t=(t<<1)+(t<<3)+(c^48), c=getchar();
    return f? t: -t;
}

const int N=5e5+5, M=3e6+5;

int n;
char s[M];

int son[M][26], ptot, cnt[M];
inline void insert(char *s) {
    int u=0, d;
    rer(i, strlen(s+1), 1) {
        d=s[i]-'a';
        if(!son[u][d]) son[u][d]=++ptot;
        u=son[u][d];
    }
    cnt[u]++;
}

int dp[M], ans;
void dfs(int u) {
    int v, siz=0, mx1=0, mx2=0;
    rep(d, 0, 25) if(v=son[u][d]) {
        dfs(v), siz+=cnt[v];
        if(mx1<dp[v]) mx2=mx1, mx1=dp[v];
        else if(mx2<dp[v]) mx2=dp[v];
    }
    if(cnt[u]) dp[u]=mx1+max(siz, 1);
    ans=max(ans, mx1+mx2+cnt[u]+max(siz-2, 0));
}

int main() {
    n=read();
    ref(i, n) scanf("%s", s+1), insert(s);
    dfs(0);
    printf("%d", ans);
}

听说有人想看压行版本?(并没有人想看)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define rep(i, a, b) for(int i=(a), i##up=(b); i<=i##up; ++i)
#define ref(i, a) for(int i=1, i##up=(a); i<=i##up; ++i)
#define rer(i, a, b) for(int i=(a), i##dn=(b); i>=i##dn; --i)

inline int read() {
    int t=0, f=1; char c;
    while(!isdigit(c=getchar())) f=c^45;
    while(isdigit(c)) t=(t<<1)+(t<<3)+(c^48), c=getchar();
    return f? t: -t;
}

const int N=5e5+5, M=3e6+5;

int n;
char s[M];

int son[M][26], ptot, cnt[M];
inline void insert(char *s) {
    int u=0, d;
    rer(i, strlen(s+1), 1) son[u][d=s[i]-'a']||(son[u][d]=++ptot), u=son[u][d];
    cnt[u]++;
}
int dp[M], ans;
void dfs(int u) {
    int v, siz=0, mx1=0, mx2=0;
    rep(d, 0, 25) if(v=son[u][d]) dfs(v), siz+=cnt[v], mx1<dp[v]&&(mx2=mx1, mx1=dp[v], 1) || mx2<dp[v]&&(mx2=dp[v]);
    cnt[u]&&(dp[u]=mx1+max(siz, 1)), ans=max(ans, mx1+mx2+cnt[u]+max(siz-2, 0));
}

int main() {
    ref(i, n=read()) scanf("%s", s+1), insert(s);
    dfs(0), printf("%d", ans);
}