题解:P4384 [八省联考 2018] 制胡窜

· · 题解

提供一种码量较小,但是分讨细节较多的做法。

分析

首先,问题是将序列划分为三段,问有多少种划分方式使得三段中存在至少一段使得 s_{l,r} 在这段中出现过。

直接做比较麻烦,考虑容斥,计算有多少个不合法的划分方式。

这如当是说,s_{l,r} 的任意一次出现 s_{x,y} 都被两个划分线中的至少一个分开了。我们称将 i-1,i 划分开为“割在 i 上”。

首先,维护 s_{l,r} 出现的位置信息是经典的 SAM 问题。

考虑建出 parent 树之后,从叶子倍增找到一个节点使得这个节点表示的长度区间包含了 r-l+1

然后离线树上启发式合并就可以维护一个子树内的 endpos 集合了。

现在我们分讨一下 s_{l,r} 的每次出现的形态。

如果存在三个子串和 s_{l,r} 相等,且这三个串和不相交,则一定不能使得所有串都被分割开。继续观察发现,只要相邻两个子串的交长度至多为 1 则也不存在划分方式使得这三个串都被分开了。

所以这种情况下的答案就是 \binom{n-1}{2}

那么现在所有和 s_{l,r} 相等的子串的位置的并集只会形成至多两段连续段。

考虑如果是两段的话,那么分割线一定需要在两段分别的交集上。这种情况也是容易通过维护二分找到两段首尾的 endpos 求出的。对于两段中的一段,记录找到的最小的 endpos 为 first_pos,最大的 endpos 为 last_pos。则这一段中可以选择割的区间是 (\operatorname{last\_pos}-\operatorname{len}+1,\operatorname{first\_pos}]

只有一段的情况是最难求的。

不过我们通过对其的 border 的分析发现,出现位置的 endpos 构成一个等差数列!

在这个优雅的性质下,直觉告诉我们他是可以 O(1) 求的。

考虑分讨所有的子串的交集是否为空。

\Delta 为两个相邻 endpos 之间的差。同样记录找到的最小的 endpos 为 first_pos,最大的 endpos 为 last_pos

如果为空,考虑固定左边的分割点,计算有多少个合法的右端点,其从 first_pos 向前是一段相同元素出现 \Delta 次的公差为 -\Delta 的等差数列。(注意最开始的一段长度可能不是 \Delta 个,而是 (r-l+1)\mod\Delta 个)(再注意到等差数列会一直减少直到下一个元素 \le 0 为止,可以证明的是等差数列会在 first_pos-len+1 之前停下来)

计算左端点割在 first_pos 时右端点的割的方案数:

v=\lfloor\frac{\operatorname{len}-1+\Delta-1}{\Delta}\rfloor\times\Delta+\operatorname{first\_pos}-\operatorname{last\_pos}+\operatorname{len}-1

计算等差数列:

\frac{(v+v\mod \Delta)(\lfloor\frac{v}{\Delta}\rfloor+1)}{2}\Delta-((\Delta-(\operatorname{len}-1)\mod\Delta)\mod \Delta )\times v

如果不为空,则发现,如果左端点割在相交部分,右端点的取值是任意的。

剩下部分和上述类似的,固定左端点,右端点的取值个数同样是等差数列。

详见代码。

code

#include<bits/stdc++.h>
#define int long long
#define rep(i,a,b) for(register int i=(a);i<=(b);++i)
#define per(i,a,b) for(register int i=(a);i>=(b);--i)
#define edge(i,u) for(int i=head[u];i;i=e[i].next)
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define sed second
#define Max(a,b) (a=max(a,b))
#define Min(a,b) (a=min(a,b))
using namespace std;
const int N=2e5+10,M=1e6+10,inf=1e9,mod=1e9+7;
bool MS;int used;
namespace SAM
{
    vector<int>edge[N];
    int f[N][20];
    int pos[N];
    struct st
    {
        int fsp,len,fail;
        int next[10];
    }st[N<<1];
    int res=0,last;
    void init()
    {
        st[0].fail=-1;
    }
    int Insert(int c)
    {
        int u=++res;
        int p=last;
        st[u].len=st[p].len+1;
        st[u].fsp=st[u].len;
        pos[u]=st[u].len;
        while(p!=-1&&!st[p].next[c])
        {
            st[p].next[c]=u;
            p=st[p].fail;
        }
        if(p==-1)
        st[u].fail=0;
        else
        {
            int v=st[p].next[c];
            if(st[v].len==st[p].len+1)
            {
                st[u].fail=v;
            }
            else
            {
                int copy=++res;
                st[copy]=st[v];
                st[copy].len=st[p].len+1;
                while(p!=-1&&st[p].next[c]==v)
                {
                    st[p].next[c]=copy;
                    p=st[p].fail;
                }
                st[u].fail=st[v].fail=copy;
            }
        }
        last=u;
        return u;
    }
    void dfs(int u)
    {
        rep(i,1,19)
        f[u][i]=f[f[u][i-1]][i-1];
        for(auto v:edge[u])
        {
            f[v][0]=u;
            dfs(v);
        }
    }
}using namespace SAM;
int n,m;
int val[N];
string s;
void build()
{
    init();
    rep(i,1,n)
    val[i]=Insert(s[i]-'0');
    rep(i,1,res)
    edge[st[i].fail].pb(i);
    dfs(0);
}
set<int>vl[N];
vector<pii>q[N];
int ans[N<<1];
void getans(int u)
{
    if(pos[u])
    vl[u].insert(pos[u]);
    for(auto v:edge[u])
    {
        getans(v);
        if(vl[v].size()>vl[u].size())
        std::swap(vl[u],vl[v]);
        for(auto s:vl[v])
        vl[u].insert(s);
    }
    for(auto s:q[u])
    {
        int len=s.fst;
        int from=s.sed;
        int fps=*vl[u].begin();
        int lps=*vl[u].rbegin();
        if(vl[u].size()==1)
        {
            ans[from]=(n-len)*(n-len-1)/2;
            continue;
        }//比较特殊的情况
        auto g=vl[u].lower_bound(fps+len-1);
        if(g==vl[u].end()||*g>lps-len+1)
        {
            auto r=vl[u].upper_bound(lps-len+1);
            auto l=vl[u].lower_bound(fps+len-1);
            l--;
            if(*l<=*r-len)
            {
                ans[from]-=(fps-(*l-len)-1)*(*r-(lps-len)-1);
            }
            else
            {
                int delta=(lps-fps)/(vl[u].size()-1);
                len--;
                int dl=fps-len+1;
                int dr=fps;
                int tl=lps-len+1;
                int tr=lps;
                if(tl<=dr)
                {
                    ans[from]-=(fps-len)*(dr-tl+1)+(n-tl-1+n-dr-1)*(dr-tl+1)/2;
                    int val=(dr-tl+1)+delta;
                    ans[from]-=(val+((len-(dr-tl+1))/delta-1)*delta+val)*((len-(dr-tl+1))/delta)/2*delta+(len-(dr-tl+1))%delta*((len-(dr-tl+1))/delta*delta+val);
                }
                else
                {
                    int val=(len+delta-1)/delta*delta+fps-lps+len;
                    ans[from]-=(val+val%delta)*(val/delta+1)/2*delta-(delta-len%delta)%delta*val;
                }
            }
        }
    }
}
bool MT;
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>m;
    cin>>s;
    s=" "+s;
    build();
    rep(w,1,m)
    {
        int l,r;
        cin>>l>>r;
        int len=r-l+1;
        int u=val[r];
        per(i,19,0)
        {
            if(st[f[u][i]].len>=(r-l+1))
            u=f[u][i];
        }
        ans[w]=(n-1)*(n-2)/2;
        q[u].pb(mp(len,w));
    }
    getans(0);
    rep(i,1,m)
    cout<<ans[i]<<'\n';
    cerr<<"Memory:"<<(&MS-&MT)/1048576.0<<"MB Time:"<<clock()/1000.0<<"s\n";
}