P14363 Solution: (In)Persistent Trie

· · 题解

谨以此文章纪念我自学习 OI 以来第一次在正式赛场上获得一道紫题的 AC。

由于题目保证 t_1\neq t_2,显然可以直接忽略所有 s_1=s_2 的替换串对。同时,为确保完整的正确性,需要特判 |t_1|\neq|t_2| 的情况。以下均假设 s_1\neq s_2|t_1|=|t_2| 成立。字符串下标从 0 开始计算,使用 s[l,r] 表示字符串 s 中下标在区间 [l,r] 中的字符形成的子串。

对于一个替换字符串对 (s_1,s_2),我们可以确定一个非负整数对 (l,r),同时满足 s_1[l]\neq s_2[l],s_1[r]\neq s_2[r],s_1[0,l-1]=s_2[0,l-1],s_1[r+1,|s_1|-1]=s_2[r+1,|s_2|-1]。根据定义,该非负整数对显然是唯一的。类似的,对于查询字符串对 (t_1,t_2),可以使用类似的方法求得唯一非负整数对 (l,r)

由于替换操作只能进行一次,一个替换字符串对可以对一个查询字符串对进行合法操作,当且仅当其 (l,r) 内的内容是完全重合的,且在 (l,r) 之外,替换字符串对的位置没有超出查询字符串对,而且内容依旧匹配。

形式化的,使用上述方法得到 l_s,r_s,l_t,r_t,那么可以替换的条件为 s_1[l_s,r_s]=t_1[l_t,r_t],s_2[l_s,r_s]=t_2[l_t,r_t]s_1[0,l_s-1]t_1[0,l_t-1] 的后缀,s_1[r_s+1,|s_1|-1]t_1[r_t+1,|t_1|-1] 的前缀。

由此,我们可以考虑使用两个字典树来维护前后缀。

首先,对于每个 (s_1,s_2) 对,我们将其 [l,r] 区间内的字符串提取出来,用作分类标准;其次,将其 [0,l-1][r+1,|s_1|-1] 区间内的字符串提取出来,作为前缀和后缀。

对于分出的每一类字符串对,首先按照其前缀长度升序排序,之后按照顺序,将前缀逆向后插入前缀字典树。前缀字典树上的每一个节点对应的是一个后缀字典树上的根节点,若新创建前缀字典树的节点,那么其将继承该节点的父节点所对应的后缀字典树根节点。

得到前缀对应的根节点后,新创建一个与该根节点相同的节点,并将前缀对应的根节点指向这个新节点。从这个新节点开始,遍历后缀,每遍历到一个节点都要创建新节点。最终,对最后走到的节点,将其权值增加一。

做完上面的所有操作后,对于每个查询字符串对 (t_1,t_2) 找到 (l,r)逆向遍历前缀,找到在 (l,r) 内字符串所对应的类型中最后一个有效的根,并从这个根开始遍历后缀,在途中加上所有的权值。将这个加和作为答案输出。

中间对替换字符串对的分类需要使用哈希。假如你用的是 map<pair<string,string>,int> 的话,准备好被时间复杂度为 O(n)operator< 创思吧。

时间复杂度 O(n\log n+L_1+L_2|\Sigma|)

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>
using namespace std;
const int N = 5.2e6 + 10, base = 29, mod = 1e9 + 9, M = 2e5 + 10;
// int stt;
struct st
{
    int sn[26], v;
} tr1[N], tr2[N];
// int ed;
int idx1, idx2;
int n, q, lp, rp;
using ull = unsigned long long;
using pli = pair<ull, int>;
string x, y;
ull hsv1, hsv2;
using pll = pair<ull, ull>;
map<pll, int> rts;
pll tmp;
map<pll, vector<pair<string, string>>> mpt;
int main()
{
    freopen("replace.in", "r", stdin);
    freopen("replace.out", "w", stdout);
    cin.tie(0);
    cout.tie(0);
    ios::sync_with_stdio(0);
    // cerr<<(&stt-&ed)/4.0/1024.0/1024.0<<'\n';
    cin >> n >> q;
    for (int i = 1; i <= n; i++)
    {
        cin >> x >> y;
        lp = 0, rp = x.size() - 1;
        while (lp < x.size() and x[lp] == y[lp])
            lp++;
        while (~rp and x[rp] == y[rp])
            rp--;
        if (lp >= x.size())
            continue;
        hsv1 = hsv2 = 0;
        for (int j = lp; j <= rp; j++)
            hsv1 = (hsv1 * base + x[j] - 'a') % mod, hsv2 = (hsv2 * base + y[j] - 'a') % mod;
        tmp = {hsv1, hsv2};
        mpt[tmp].emplace_back(x.substr(0, lp), y.substr(rp + 1));
    }
    for (auto &i : mpt)
    {
        // cout<<i.first.first<<' '<<i.first.second<<'\n';
        sort(i.second.begin(), i.second.end(),
             [&](pair<string, string> &p1, pair<string, string> &p2) { return p1.first.size() < p2.first.size(); });
        int crt = ++idx1;
        rts[i.first] = crt;
        for (auto &j : i.second)
        {
            int tbuf = crt;
            reverse(j.first.begin(), j.first.end());
            for (auto &k : j.first)
            {
                if (!tr1[tbuf].sn[k - 'a'])
                {
                    tr1[tbuf].sn[k - 'a'] = ++idx1;
                    tr1[idx1].v = tr1[tbuf].v;
                }
                tbuf = tr1[tbuf].sn[k - 'a'];
            }
            // cout<<tbuf<<'\n';
            tr2[++idx2] = tr2[tr1[tbuf].v];
            tr1[tbuf].v = idx2;
            tbuf = idx2;
            for (auto &k : j.second)
            {
                tr2[++idx2] = tr2[tr2[tbuf].sn[k - 'a']];
                tr2[tbuf].sn[k - 'a'] = idx2;
                tbuf = idx2;
            }
            tr2[tbuf].v++;
        }
    }
    for (int i = 1, res; i <= q; i++)
    {
        cin >> x >> y;
        lp = 0, rp = x.size() - 1;
        while (lp < x.size() and x[lp] == y[lp])
            lp++;
        while (~rp and x[rp] == y[rp])
            rp--;
        hsv1 = hsv2 = 0;
        for (int j = lp; j <= rp; j++)
            hsv1 = (hsv1 * base + x[j] - 'a') % mod, hsv2 = (hsv2 * base + y[j] - 'a') % mod;
        tmp = {hsv1, hsv2};
        // cout<<hsv1<<' '<<hsv2<<'\n';
        int tbuf = rts[tmp];
        for (int j = lp - 1; ~j; j--)
        {
            if (!tr1[tbuf].sn[x[j] - 'a'])
                break;
            tbuf = tr1[tbuf].sn[x[j] - 'a'];
        }
        tbuf = tr1[tbuf].v;
        if (!tbuf)
        {
            cout << 0 << '\n';
            continue;
        }
        // cout<<"ok\n";
        res = tr2[tbuf].v;
        for (int j = rp + 1; j < x.size(); j++)
        {
            if (!tr2[tbuf].sn[x[j] - 'a'])
                break;
            tbuf = tr2[tbuf].sn[x[j] - 'a'];
            res += tr2[tbuf].v;
        }
        cout << res << '\n';
    }
}