题解 CF1098F 【Ж-function】

· · 题解

这真的是一道神题……

首先要求的这个东西\sum_{i=l}^rlcp(s[l,r],s[i,r])蕴含了一个不能超出r的限制,就相当于\sum_{i=l}^r\min(r-i+1,lcp(suf_l,suf_i)),并不是完全的后缀的LCP

不过我们可以考虑完全意义上的对后缀的LCP求和怎么做,就是把那个\min去掉算\sum_{i=l}^rlcp(suf_i,suf_l),我们知道两个后缀的LCP等于后缀树上两个节点的LCA的深度,那么相当于求\sum_{i=l}^rdeep(LCA(i,l)),就是这题,做法就是考虑l的每个祖先的贡献,我们对每个祖先求出子树里有多少个[l,r]的点就好了。同样的这题也需要这个思路。

(注意以下为方便描述我们直接称l这个后缀在后缀树上的节点为l

那么我们就枚举l在后缀树上的每个深度不超过r-l+1的祖先u(因为LCP显然不能超过区间长度),我们要统计有多少个x\in[l,r]使得lcp(s[x,r],s[l,r])>=deep_u,那么r-x+1>=deep[u](即x<=r+1-deep[u])必须要满足,那么我们就是相当于u的子树里有多少个x\in[l,r+1-deep[u]]

蛤,要找的x是和deep_u相关的?完了好像没法统计贡献了……

既然不好搞,那我们干脆直接把这个东西前缀和相减一下(因为我们统计的deep[u]都是<=r-l+1的,所以[l,r+1-deep[u]]一定不是空集),就分成了两个步骤:

①:求u子树里有多少个x<=l-1

②:求u子树里有多少个x<=r+1-deep[u]

上述两个都要对l的所有深度<=r-l+1的祖先求和,显然这也是某个点到根的一条链。其实我们干的事情就是统计这条直链与别的不在这条链上的点(可能在其细枝末节的子树里)产生的贡献,或者说,一条直链与一个点x的贡献会在x到根的这条链与这条直链的交上,那么我们可以把这两条链树链剖分,我们对一个询问[l,r]把这个询问存到l的深度<=r-l+1的链的所有重链上,并且在每条重链上记录与这条重链的交的长度LA(因为可能只和重链交一部分)。我们再把所有后缀节点x存到到根的所有重链上,记录与重链交的长度len。接下来我们将在每一条重链上统计所有后缀节点能够给所有询问的贡献。

先考虑①怎么做。我们其实要求这个:\sum_{(x,len)\in u,x<=l-1}\min(LA,len)(注意u在这里是重链的链顶),这个对应到二维平面上长这样:

就是每个(x,len)都对应了一条与纵轴平行的线段,每个询问相当于求一个左下角的矩形区域内线段被框柱部分的长度和。这个直接扫描线,对(x,len)x排序,询问按l-1排序,每个(x,len)在线段树上区间加然后区间查询就行了,这还是比较容易的。

那个②就比较恶心了,首先我们应该有这样的转化,对于存在重链u上的每个(x,len),相当于这个x在这条重链从上到下依次对应着x+deep_u,x+deep_u+1,……x+deep_u+len-1这些连续递增的权值(要知道重链也是直的),那么就相当于要求\sum_{(x,len)\in u}\max(0,\min(LA,len,r+1-x-deep_u))

蛤?你在逗我吗?这么复杂的东西能算啊……

其实我们把这个问题拆开分析一下就会发现它还是能算的。我们可以对询问的LA和后缀节点的len排序然后对这个扫描线,可以讨论LAlen的大小关系:

①:len<=LA,我们需要在扫到LA这个询问的时候对前面已经扫过(x,len)求答案,那么既然x+deep_u,x+deep_u+1,……x+deep_u+len-1这组连续的数的长度不会超过LA的限制,那么我们只需要让他们不超过r+1的限制就行了,这个做法就和上面那个①一样了,在线段树上区间加[x+deep_u,x+deep_u+len-1],在询问的时候查<=r+1的部分就行了。

②:len>LA,我们相当于把那个len的限制去掉变成了\sum_{(x,len)\in u,len>LA}\max(0,\min(LA,r+2-x-deep_u)),即考虑还未扫过的部分,首先我们希望r+2-x-deep_u>0,即-x-deep_u>=-r-1,我们把式子改一下:

\max(0,\min(LA-r-2,-x-deep_u)+r+2)

那么我们对于每个x,需要维护的仅仅是与-x-deep_u相关的,可以对这个进行大力讨论:

①:-x-deep_u\in[-r-1,LA-r-2],这个就特别好因为LA>0,显然-r-1<=LA-r-2所以这个区间不会是空集,我们直接取这部分的\sum-x-deep_u即可。

-x-deep_u>LA-r-2,这个肯定要和LA-r-2\min,所以找出有多少个在这个范围内的-x-deep_u然后乘个LA-r-2即可。

另外外面加的那个r+2找出有多少个-1-deep_u>=-r-1再乘上r+2就行了。

上面的这3部分都可以用树状数组维护,由于我们是维护的未扫过的部分,所以我们可以先把所有x在树状数组上都加一下,然后每扫到一个就减掉就行了。另外由于是负数下标还要加偏移量什么的……

哦对了有一个问题就是我们实际上用的后缀树把很多真正的后缀树上的节点都缩到了一条边上,所以我们在统计的时候所用的len其实并不是这条重链上交点的个数而是后缀树上的真正长度,并且链顶的管辖范围也要做符合实际的变化,反正后缀树上的问题都得切合实际啊……所以这题细节比较多……

哦另外我们最后要把在1点统计的答案减掉……因为1表示空串,而我们显然不能统计空串……

那么这题就做完了,由于每个询问和后缀节点被存到了O(\log n)条重链上,在每条重链上做的扫描线是O(n\log n)的,所以总复杂度均摊O((n+q)\log^2n)

上代码~

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define ll long long
#define offset 500000
#define K 1000000
#define N 400000
#define up(_o) data[_o] = data[ls(_o)] + data[rs(_o)]
#define ls(_o) (_o << 1)
#define rs(_o) ((_o << 1) | 1)
using namespace std;
namespace ywy {
    inline int get() {
        int n = 0;
        char c;
        while ((c = getchar()) || 23333) {
            if (c >= '0' && c <= '9')
                break;
            if (c == '-')
                goto s;
        }
        n = c - '0';
        while ((c = getchar()) || 23333) {
            if (c >= '0' && c <= '9')
                n = n * 10 + c - '0';
            else
                return (n);
        }
    s:
        while ((c = getchar()) || 23333) {
            if (c >= '0' && c <= '9')
                n = n * 10 - c + '0';
            else
                return (n);
        }
    }
    int sam[N][26], mxdfn[N], len[N], fa[N], top[N], zhongson[N], size[N], dfn[N], fan[N], gdfn = 1;
    typedef struct _b {
        int dest;
        int nxt;
    } bian;
    bian memchi[1000001];
    int gn = 1, heads[N], ance[N][20];
    inline void add(int s, int t) {
        memchi[gn].dest = t;
        memchi[gn].nxt = heads[s];
        heads[s] = gn;
        gn++;
    }
    ll anss[N];
    typedef struct _n {
        int x;
        int len;
        int lim;
        int pos;
        unsigned char isque;
        _n() { x = len = lim = pos = isque = 0; }
        friend bool operator<(const _n &a, const _n &b) {
            if (a.pos == b.pos)
                return (a.isque < b.isque);
            return (a.pos < b.pos);
        }
    } node;
    vector<node> leaf[N], quel[N], quer1[N];
    void dfs(int pt) {
        size[pt] = 1;
        top[pt] = pt;
        int mx = 0, best = 0;
        for (register int i = heads[pt]; i; i = memchi[i].nxt) {
            dfs(memchi[i].dest);
            size[pt] += size[memchi[i].dest];
            ance[memchi[i].dest][0] = pt;
            if (size[memchi[i].dest] > mx)
                mx = size[memchi[i].dest], best = memchi[i].dest;
        }
        zhongson[pt] = best;
    }
    void efs(int pt) {
        dfn[pt] = gdfn;
        fan[gdfn] = pt;
        gdfn++;
        if (zhongson[pt])
            top[zhongson[pt]] = top[pt], efs(zhongson[pt]);
        for (register int i = heads[pt]; i; i = memchi[i].nxt) {
            if (memchi[i].dest == zhongson[pt])
                continue;
            efs(memchi[i].dest);
        }
    }
    int gsam = 2, nd[N];
    inline int zhuanyi(int p, int x) {
        int me = gsam;
        gsam++;
        len[me] = len[p] + 1;
        while (p && !sam[p][x]) sam[p][x] = me, p = fa[p];
        if (!p) {
            fa[me] = 1;
            return (me);
        }
        int q = sam[p][x];
        if (len[q] == len[p] + 1) {
            fa[me] = q;
            return (me);
        }
        int nq = gsam;
        gsam++;
        len[nq] = len[p] + 1;
        fa[nq] = fa[q];
        fa[q] = fa[me] = nq;
        for (register int i = 0; i < 26; i++) sam[nq][i] = sam[q][i];
        while (p && sam[p][x] == q) sam[p][x] = nq, p = fa[p];
        return (me);
    }
    int adds[2000001];
    ll data[2000001];
    inline void down(int tree, int l, int r) {
        if (adds[tree]) {
            int mid = (l + r) >> 1;
            data[ls(tree)] += (ll)adds[tree] * (mid - l + 1);
            data[rs(tree)] += (ll)adds[tree] * (r - mid);
            adds[ls(tree)] += adds[tree];
            adds[rs(tree)] += adds[tree];
            adds[tree] = 0;
        }
    }
    void add(int rl, int rr, int l, int r, int tree) {
        if (rl == l && rr == r) {
            data[tree] += (r - l + 1);
            adds[tree]++;
            return;
        }
        int mid = (l + r) >> 1;
        down(tree, l, r);
        if (rl > mid)
            add(rl, rr, mid + 1, r, rs(tree));
        else {
            if (rr <= mid)
                add(rl, rr, l, mid, ls(tree));
            else {
                add(rl, mid, l, mid, ls(tree));
                add(mid + 1, rr, mid + 1, r, rs(tree));
            }
        }
        up(tree);
    }
    void sub(int rl, int rr, int l, int r, int tree) {
        if (rl == l && rr == r) {
            adds[tree]--;
            data[tree] -= (r - l + 1);
            return;
        }
        int mid = (l + r) >> 1;
        down(tree, l, r);
        if (rl > mid)
            sub(rl, rr, mid + 1, r, rs(tree));
        else {
            if (rr <= mid)
                sub(rl, rr, l, mid, ls(tree));
            else {
                sub(rl, mid, l, mid, ls(tree));
                sub(mid + 1, rr, mid + 1, r, rs(tree));
            }
        }
        up(tree);
    }
    ll query(int rl, int rr, int l, int r, int tree) {
        if (rl == l && rr == r)
            return (data[tree]);
        int mid = (l + r) >> 1;
        down(tree, l, r);
        if (rl > mid)
            return (query(rl, rr, mid + 1, r, rs(tree)));
        if (rr <= mid)
            return (query(rl, rr, l, mid, ls(tree)));
        return (query(rl, mid, l, mid, ls(tree)) + query(mid + 1, rr, mid + 1, r, rs(tree)));
    }
    node ints[N + 2];
    ll c1[K + 1], c2[K + 1];
    inline ll query(int l, int r, ll *c) {
        l += offset;
        r += offset;
        if (l > r)
            return (0);
        ll tot = 0;
        for (register int i = r; i > 0; i -= (i & -i)) tot += c[i];
        for (register int i = l - 1; i > 0; i -= (i & -i)) tot -= c[i];
        return (tot);
    }
    inline void cadd(int pos, ll x, ll *c) {
        pos += offset;
        for (register int i = pos; i <= K; i += (i & -i)) c[i] += x;
    }
    char str[222222];
    inline int getance(int pt, int mxlen) {  //找到第一个len>=mxlen的祖先
        for (register int i = 19; i >= 0 && len[pt] > mxlen; i--)
            if (len[ance[pt][i]] >= mxlen)
                pt = ance[pt][i];
        return (pt);
    }
    void print(ll num) {
        if (num < 0)
            putchar('-'), num = -num;
        if (num >= 10)
            print(num / 10);
        putchar(num % 10 + '0');
    }
    void ywymain() {
        scanf("%s", str + 1);
        int n = strlen(str + 1);
        len[0] = -1;
        int p = 1;
        for (register int i = n; i >= 1; i--) {
            p = zhuanyi(p, str[i] - 'a');
            nd[i] = p;
        }
        for (register int i = 2; i < gsam; i++) add(fa[i], i);
        dfs(1);
        efs(1);
        for (register int i = 1; i <= 19; i++) {
            for (register int j = 1; j < gsam; j++) ance[j][i] = ance[ance[j][i - 1]][i - 1];
        }
        for (register int i = 1; i <= n; i++) {
            int cur = top[nd[i]], ppl = len[nd[i]] - len[fa[cur]];
            while (cur) {
                node cjr;
                cjr.len = ppl;
                cjr.x = i;
                leaf[cur].push_back(cjr);
                ppl = len[fa[cur]] - len[fa[top[fa[cur]]]];
                cur = top[fa[cur]];
            }
        }
        int q = get();
        for (register int i = 1; i <= q; i++) {
            int l = get(), r = get();
            int cyh = getance(nd[l], r - l + 1);
            anss[i] = -(min(n, r + 1) - l + 1);
            int cur = top[cyh], ppl = r - l + 1 - len[fa[cur]];
            while (cur) {
                node cjr;
                cjr.len = ppl;
                cjr.x = i;
                cjr.lim = l - 1;
                cjr.isque = 1;
                quel[cur].push_back(cjr);
                cjr = _n();
                cjr.len = ppl;
                cjr.x = i;
                cjr.lim = r + 1;
                cjr.isque = 1;
                quer1[cur].push_back(cjr);
                ppl = len[fa[cur]] - len[fa[top[fa[cur]]]];
                cur = top[fa[cur]];
            }
        }
        for (register int u = 1; u < gsam; u++) {
            if (top[u] != u)
                continue;
            int ptr = 1;
            for (register int i = 0; i < quel[u].size(); i++) {
                node cjr = quel[u][i];
                cjr.pos = cjr.lim;
                ints[ptr] = cjr;
                ptr++;
            }
            for (register int i = 0; i < leaf[u].size(); i++) {
                node cjr = leaf[u][i];
                cjr.pos = cjr.x;
                ints[ptr] = cjr;
                ptr++;
            }
            ptr--;
            sort(ints + 1, ints + 1 + ptr);
            for (register int i = 1; i <= ptr; i++) {
                if (!ints[i].isque)
                    add(1, ints[i].len, 1, N, 1);
                else {
                    anss[ints[i].x] -= query(1, ints[i].len, 1, N, 1);
                }
            }
            for (register int i = 0; i < leaf[u].size(); i++) sub(1, leaf[u][i].len, 1, N, 1);
            ptr = 1;
            for (register int i = 0; i < quer1[u].size(); i++) {
                node cjr = quer1[u][i];
                cjr.pos = cjr.len;
                ints[ptr] = cjr;
                ptr++;
            }
            for (register int i = 0; i < leaf[u].size(); i++) {
                node cjr = leaf[u][i];
                cjr.pos = cjr.len;
                ints[ptr] = cjr;
                ptr++;
            }
            ptr--;
            sort(ints + 1, ints + 1 + ptr);
            for (register int i = 0; i < leaf[u].size(); i++) {
                cadd(-leaf[u][i].x - (len[fa[u]] + 1), 1, c1);
                cadd(-leaf[u][i].x - (len[fa[u]] + 1), -leaf[u][i].x - (len[fa[u]] + 1), c2);
            }
            for (register int i = 1; i <= ptr; i++) {
                if (!ints[i].isque) {
                    cadd(-ints[i].x - (len[fa[u]] + 1), -1, c1);
                    cadd(-ints[i].x - (len[fa[u]] + 1), ints[i].x + (len[fa[u]] + 1), c2);
                    add(ints[i].x + len[fa[u]] + 1, ints[i].x + len[fa[u]] + ints[i].len, 1, N, 1);
                } else {
                    anss[ints[i].x] += query(1, ints[i].lim, 1, N, 1);
                    anss[ints[i].x] += (ll)(ints[i].lim + 1) * query(-ints[i].lim, K - offset, c1);
                    anss[ints[i].x] +=
                        (ll)(ints[i].len - ints[i].lim - 1) * query(ints[i].len - ints[i].lim, K - offset, c1);
                    anss[ints[i].x] += query(-ints[i].lim, ints[i].len - ints[i].lim - 1, c2);
                }
            }
            for (register int i = 0; i < leaf[u].size(); i++) {
                sub(leaf[u][i].x + len[fa[u]] + 1, leaf[u][i].x + len[fa[u]] + leaf[u][i].len, 1, N, 1);
            }
        }
        for (register int i = 1; i <= q; i++) print(anss[i]), putchar('\n');
    }
}
int main() {
    ywy::ywymain();
    return (0);
}