题解 P7279 【光棱碎片】

· · 题解

本文同步发表于我的博客:https://www.alpha1022.me/articles/lg-7279.htm。

首先将限制转化为 \le R
建出反串的后缀树,执行静态链分治(即树上启发式合并)。
设当前结点包含字符串的最长长度为 \rm len,则考虑子树内的后缀 i,j 的贡献。
显然是 \min(R - (a_i \oplus a_j),{\rm len})

考虑分类讨论。
a_i \oplus a_j \le R - {\rm len},需要这部分的数对的个数。
R - {\rm len} < a_i \oplus a_j \le R,需要这部分的数对的 a_i \oplus a_j 之和以及个数。

则问题变为维护一个多重集 S,支持对于 i \in S,查询满足 i \oplus v \le Ri 的个数以及 i \oplus v 的和。
可以考虑使用 Trie 来维护,查询时在 Trie 上走 v \oplus R 的路径,当某个时刻 R 的当前位为 1,则意味着若走向另一个儿子,异或值的这一位会变为 0,故将答案加上另一个儿子子树内的贡献即可。
当然,还要加上 v \oplus R 本身的贡献。
而同时要查询和,可以套路地维护每一位 1 的个数来支持查询。

时间复杂度 O(n \log n \log^2 a_i)
因为这是一棵后缀树,可以相信良心出题人没有卡树剖,所以是可以过的。

当然,在 SA 的 height 数组上并查集启发式合并或所谓「启发式分裂」并使用可持久化 Trie 也是可以的。
实际上这两种做法都相当于在 height 数组的笛卡尔树上做静态链分治,而 height 数组的笛卡尔树和后缀树某种意义上是等价的。

代码:

#include <cstdio>
#include <vector>
#include <cstring>
#include <utility>
#include <algorithm>
using namespace std;
const int N = 1e5;
const int LG = 16;
const int mod = 998244353;
int n,L,R,w[N + 5];
char s[N + 5];
int ans;
namespace TRIE
{
    struct node
    {
        int ch[2];
        int cnt,sum[LG + 5];
    } tr[(N << 5) + 5];
    int tot = 1;
    inline void insert(int v,int d)
    {
        int p = 1;
        tr[p].cnt += d;
        for(register int k = 0;k <= LG;++k)
            tr[p].sum[k] += ((v >> k) & 1) * d;
        for(register int i = LG;~i;--i)
        {
            !(tr[p].ch[(v >> i) & 1]) && (tr[p].ch[(v >> i) & 1] = ++tot),
            p = tr[p].ch[(v >> i) & 1];
            tr[p].cnt += d;
            for(register int k = 0;k <= LG;++k)
                tr[p].sum[k] += ((v >> k) & 1) * d;
        }
    }
    inline pair<int,int> query(int v,int r)
    {
        if(r < 0)
            return make_pair(0,0);
        int p = 1;
        int cnt = 0,sum = 0;
        for(register int i = LG;~i && p;--i)
        {
            if((r >> i) & 1)
            {
                cnt += tr[tr[p].ch[(((v ^ r) >> i) & 1) ^ 1]].cnt;
                for(register int k = 0;k <= LG;++k)
                    if((v >> k) & 1)
                        sum = (sum + (tr[tr[p].ch[(((v ^ r) >> i) & 1) ^ 1]].cnt - tr[tr[p].ch[(((v ^ r) >> i) & 1) ^ 1]].sum[k]) * (1LL << k)) % mod;
                    else
                        sum = (sum + tr[tr[p].ch[(((v ^ r) >> i) & 1) ^ 1]].sum[k] * (1LL << k)) % mod;
            }
            p = tr[p].ch[((v ^ r) >> i) & 1];
        }
        cnt += tr[p].cnt;
        for(register int k = 0;k <= LG;++k)
            if((v >> k) & 1)
                sum = (sum + (tr[p].cnt - tr[p].sum[k]) * (1LL << k)) % mod;
            else
                sum = (sum + tr[p].sum[k] * (1LL << k)) % mod;
        return make_pair(cnt,sum);
    }
}
namespace SAM
{
    struct node
    {
        int ch[26];
        int fa,len;
    } sam[(N << 1) + 5];
    int las = 1,tot = 1;
    int c[N + 5],a[(N << 1) + 5];
    int sz[(N << 1) + 5],son[(N << 1) + 5];
    vector<int> edp[(N << 1) + 5];
    inline void insert(int x,int pos)
    {
        int cur = las,p = ++tot;
        sam[p].len = sam[cur].len + 1;
        for(;cur && !sam[cur].ch[x];cur = sam[cur].fa)
            sam[cur].ch[x] = p;
        if(!cur)
            sam[p].fa = 1;
        else
        {
            int q = sam[cur].ch[x];
            if(sam[cur].len + 1 == sam[q].len)
                sam[p].fa = q;
            else
            {
                int nxt = ++tot;
                sam[nxt] = sam[q],sam[nxt].len = sam[cur].len + 1,sam[p].fa = sam[q].fa = nxt;
                for(;cur && sam[cur].ch[x] == q;cur = sam[cur].fa)
                    sam[cur].ch[x] = nxt;
            }
        }
        ++sz[las = p],edp[las].push_back(pos);
    }
    int to[(N << 1) + 5],pre[(N << 1) + 5],first[(N << 1) + 5];
    inline void add(int u,int v)
    {
        static int tot = 0;
        to[++tot] = v,pre[tot] = first[u],first[u] = tot;
    }
    inline void build()
    {
        for(register int i = 1;i <= tot;++i)
            ++c[sam[i].len],i > 1 && (add(sam[i].fa,i),1);
        for(register int i = 1;i <= n;++i)
            c[i] += c[i - 1];
        for(register int i = tot;i > 1;--i)
            a[c[sam[i].len]--] = i;
        for(register int i = tot;i > 1;--i)
        {
            sz[sam[a[i]].fa] += sz[a[i]];
            if(!son[sam[a[i]].fa] || sz[a[i]] > sz[son[sam[a[i]].fa]])
                son[sam[a[i]].fa] = a[i];
        }
    }
    void dfs(int p)
    {
        int len = sam[p].len;
        pair<int,int> res1,res2;
        for(register int i = first[p];i;i = pre[i])
            if(to[i] ^ son[p])
            {
                dfs(to[i]);
                for(register vector<int>::iterator it = edp[to[i]].begin();it != edp[to[i]].end();++it)
                    TRIE::insert(w[*it],-1);
            }
        if(son[p])
            dfs(son[p]);
        for(register vector<int>::iterator it = edp[p].begin();it != edp[p].end();++it)
            res1 = TRIE::query(w[*it],R - len),
            res2 = TRIE::query(w[*it],R),
            ans = (ans + (long long)res1.first * len) % mod,
            ans = (ans + (long long)(res2.first - res1.first) * R % mod - (res2.second - res1.second + mod) % mod + mod) % mod,
            res1 = TRIE::query(w[*it],L - 1 - len),
            res2 = TRIE::query(w[*it],L - 1),
            ans = (ans - (long long)res1.first * len % mod + mod) % mod,
            ans = (ans - (long long)(res2.first - res1.first) * (L - 1) % mod + (res2.second - res1.second + mod) % mod + mod) % mod;
        for(register vector<int>::iterator it = edp[p].begin();it != edp[p].end();++it)
            TRIE::insert(w[*it],1);
        if(son[p])
        {
            edp[p].swap(edp[son[p]]);
            for(register vector<int>::iterator it = edp[son[p]].begin();it != edp[son[p]].end();++it)
                edp[p].push_back(*it);
            vector<int>().swap(edp[son[p]]);
        }
        for(register int i = first[p];i;i = pre[i])
            if(to[i] ^ son[p])
            {
                for(register vector<int>::iterator it = edp[to[i]].begin();it != edp[to[i]].end();++it)
                    res1 = TRIE::query(w[*it],R - len),
                    res2 = TRIE::query(w[*it],R),
                    ans = (ans + (long long)res1.first * len) % mod,
                    ans = (ans + (long long)(res2.first - res1.first) * R % mod - (res2.second - res1.second + mod) % mod + mod) % mod,
                    res1 = TRIE::query(w[*it],L - 1 - len),
                    res2 = TRIE::query(w[*it],L - 1),
                    ans = (ans - (long long)res1.first * len % mod + mod) % mod,
                    ans = (ans - (long long)(res2.first - res1.first) * (L - 1) % mod + (res2.second - res1.second + mod) % mod + mod) % mod;
                for(register vector<int>::iterator it = edp[to[i]].begin();it != edp[to[i]].end();++it)
                    edp[p].push_back(*it),TRIE::insert(w[*it],1);
                vector<int>().swap(edp[to[i]]);
            }
    }
}
int main()
{
    scanf("%d%s",&n,s + 1);
    for(register int i = 1;i <= n;++i)
        scanf("%d",w + i);
    scanf("%d%d",&L,&R);
    for(register int i = 1;i <= n;++i)
        SAM::insert(s[i] - 'a',i);
    SAM::build(),SAM::dfs(1);
    printf("%d\n",ans);
}