P3346 [ZJOI2015]诸神眷顾的幻想乡

· · 题解

\text{Solution}

话说这种通过对树的形态特殊限制最近做了还有一道,以保证暴力做法复杂度。

链怎么做?是不是从上至下以及从下到上类似于 trie 的形式建广义 SAM 求不同子串。或者你思考现在只有一段序列,不同子串的求法,用 SAM 就好,那么正反都要,是不是我正向跑一次,反向跑一次,但这时候就要用到广义 SAM 了,然后这种正向依赖前一个字符的序列做法实际上就是链上依赖 trie 树的上一个字符。?

抽象的,我们只是把每一个前缀插入到 SAM 里面了,那么获取这些前缀的方法,我们可以类似于 trie ,但实际上 dfs 保证依赖关系就好了。

类比到树上:看图下,摸了几遍发现,假如我们要保证正序和倒序的字符串都能被找到,那么我们必须要从叶子节点出发,仿照链的做法即可。

注意下广义 SAM 的 pre 以及空间问题。

\text{Code}

#include <bits/stdc++.h>

#define N (int)(2e6+5)
#define ll long long

using namespace std;

struct edge {
    int nex,to;
}e[N];

int head[N],cnt;
int du[N],col[N],fa[N],son[10][N],len[N];
int n,m,tot=1;

int rd() {
    int f=1,sum=0; char ch=getchar();
    while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)) {sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
    return sum*f;
}

void add_edge(int x,int y) {
    e[++cnt].nex=head[x]; e[cnt].to=y;
    head[x]=cnt;
}

int ins(int c,int las) {
    if(son[c][las]) {
        int pre=las,y=son[c][las];
        if(len[pre]+1==len[y]) return y;
        int x=++tot; len[x]=len[pre]+1;
        fa[x]=fa[y]; fa[y]=x;
        for(int i=0;i<10;i++) son[i][x]=son[i][y];
        for(;pre&&son[c][pre]==y;pre=fa[pre]) son[c][pre]=x;
        return x;
    }
    int pre=las,x=++tot; len[x]=len[pre]+1;
    for(;pre&&!son[c][pre];pre=fa[pre]) son[c][pre]=x;
    int y=son[c][pre];
    if(!pre) fa[x]=1;
    else if(len[pre]+1==len[y]) fa[x]=y;
    else {
        int p=++tot; len[p]=len[pre]+1;
        fa[p]=fa[y]; fa[x]=fa[y]=p;
        for(int i=0;i<10;i++) son[i][p]=son[i][y];
        for(;pre&&son[c][pre]==y;pre=fa[pre]) son[c][pre]=p;
    }
    return x;
}

void dfs(int x,int fa,int las) {
    las=ins(col[x],las);
    for(int i=head[x];i;i=e[i].nex) {
        int y=e[i].to;
        if(y==fa) continue;
        dfs(y,x,las);
    }
}

int main() {
    int x,y;
    n=rd(); m=rd();
    for(int i=1;i<=n;i++) col[i]=rd();
    for(int i=1;i<n;i++) {
        x=rd(); y=rd();
        add_edge(x,y); add_edge(y,x);
        ++du[x]; ++du[y];
    }   
    for(int i=1;i<=n;i++) if(du[i]==1) dfs(i,0,1);
    ll ans=0;
    for(int i=2;i<=tot;i++) ans+=len[i]-len[fa[i]];
    printf("%lld",ans);
    return 0;
}