[线段树维护哈希] CF213E

· · 题解

前言:线段树维护哈希好题,想通一点就容易做了,但是不看题解切掉还是挺困难的。

考虑如何匹配,有枚举值、匹配下标和枚举下标、匹配值两种方案,由于要求合法的 x 的数量,因此枚举值看起来更合理点(或者两种都试试)。

考虑枚举 x,等价于在 B 上枚举区间 [x,x+n-1]。考虑下标如何匹配,我们令 A_i 表示 ia 中的下标,同理有 B_i,那么只需比较 B[x,x+n-1] 离散化后与 A 是否完全相同即可。

考虑哈希,维护 B[x,x+n-1] 离散化后的哈希值 HASH。暴力处理 B[1,n] 的哈希值,那么接下来考虑相邻区间的转化即可,即删去第一个元素并在后面增加一个元素。

补充概念:字符串哈希等价于 base 进制,即 HASH(S[1,n])=\sum S_i\times base^{n-i},我们称 base^{n-i} 项为 S_i 的权重。

定义 query(l,r) 表示区间内元素的权重之和,rk(k) 表示 k 离散化后的值。前者开一棵权值线段树维护,后者使用树状数组维护。

考虑删去区间第一个元素 k。首先将 HASH-rk(k)\times base^{n-1},除此之外 HASH-query(k+1,m),这是因为删去 k 后区间内 >k 的元素离散化后的值都会 -1,即少一个权重。同时维护线段树。

考虑在区间后面接一个元素 k。同上,先将 HASH\times base+rk(k),然后 HASH+query(k+1,m)。同时维护线段树。

复杂度 O(n\log n)

代码

#include<bits/stdc++.h>
using namespace std;

#define int long long
const int N = 3e5 + 5, base = 1331, mod = 1e9 + 7;
int n, m, a[N], b[N], t[N], bas[N];
int A[N], B[N], S[N], hsh, hsh2, ans;

inline void upd(int a, int k) {for(; a < N; a += a & -a) t[a] += k;}
inline int rk(int a) {int res = 0; for(; a > 0; a -= a & -a) res += t[a]; return res;}
namespace Sg_Tree{
    #define lt (u << 1)
    #define rt (u << 1 | 1)
    #define mid (l + r >> 1)
    int s[N << 2], tag[N << 2];

    inline void psup(int u) {s[u] = (s[lt] + s[rt]) % mod;}
    inline void psdw(int u) {
        s[lt] = s[lt] * tag[u] % mod, s[rt] = s[rt] * tag[u] % mod;
        tag[lt] = tag[lt] * tag[u] % mod, tag[rt] = tag[rt] * tag[u] % mod;
        tag[u] = 1;
    }
    inline void build(int u, int l, int r){
        tag[u] = 1;
        if(l == r) {s[u] = S[l]; return ;}
        build(lt, l, mid), build(rt, mid + 1, r);
        psup(u);
    }
    inline void upd1(int u, int l, int r, int k, int op){
        if(l == r) {s[u] = tag[u] = op; return ;}
        psdw(u);
        if(k <= mid) upd1(lt, l, mid, k, op);
        else upd1(rt, mid + 1, r, k, op);
        psup(u);
    }
    inline void upd2() {s[1] = s[1] * base % mod; tag[1] = tag[1] * base % mod;}
    inline int query(int u, int l, int r, int ll, int rr){
        if(ll > rr) return 0;
        if(ll <= l && r <= rr) return s[u];
        psdw(u);
        int res = 0;
        if(ll <= mid) res = query(lt, l, mid, ll, rr);
        if(rr > mid) res = (res + query(rt, mid + 1, r, ll, rr)) % mod;
        psup(u);
        return res;
    }
}
using namespace Sg_Tree;
signed main(){
    cin >> n >> m;
    for(int i = 0; i <= m; ++i) bas[i] = !i ? 1 : bas[i - 1] * base % mod;
    for(int i = 1; i <= n; ++i) scanf("%lld", &a[i]), A[a[i]] = i;
    for(int i = 1; i <= m; ++i) scanf("%lld", &b[i]), B[b[i]] = i;
    if(n > m) {cout << 0; return 0;}

    for(int i = 1; i <= n; ++i) 
        hsh = (hsh * base % mod + A[i]) % mod, upd(B[i], 1);
    for(int i = 1; i <= n; ++i) 
        hsh2 = (hsh2 * base % mod + rk(B[i])) % mod;

    S[B[n]] = 1;
    for(int i = n - 1; i >= 1; --i) S[B[i]] = S[B[i + 1]] * base % mod; 
    build(1, 1, m);

    for(int i = 1; i <= m - n + 1; ++i){
        if(hsh == hsh2) ++ans;
        if(i + n <= m){
            int R = i + n;
            //删首
            hsh2 = (((hsh2 - bas[n - 1] * rk(B[i]) % mod)) % mod + mod) % mod; 
            hsh2 = ((hsh2 - query(1, 1, m, B[i] + 1, m)) % mod + mod) % mod;
            upd1(1, 1, m, B[i], 0), upd(B[i], -1);
            //加尾
            upd2(), hsh2 = (hsh2 * base + rk(B[R]) + 1)% mod;
            hsh2 = (hsh2 + query(1, 1, m, B[R] + 1, m)) % mod;
            upd1(1, 1, m, B[R], 1), upd(B[R], 1);
        }
    }
    cout << ans;
    return 0;
}