Luogu P12389 COmPoUNdS

· · 题解

官方题解是不是有点幽默了。

考虑这个区间加的操作:\forall i\in [l, r], a_i\leftarrow (a_i + c)\bmod k
此时会发现 \forall i\in (l, r], (a_i - a_{i - 1})\bmod k 的值依然不会变。
即模意义下差分值依然只有 l, r + 1 会有变化。

此时再考虑这个判相等,其实可以根据差分拆成两个条件:

于是可以对条件 1 上一个树状数组(维护不带取模,最后 check 再取模)。
对于条件 2 上一个单点改区间查 hash 值的线段树即可(是不是树状数组还更好写来着)。

时间复杂度 \mathcal{O}(n + q\log n)

#include<bits/stdc++.h>

using u64 = unsigned long long;
constexpr u64 p = 13331, mod = 998244853;

constexpr int maxn = 1e6 + 10;
int n, K, q;

struct info_ {
    u64 sum, pw;
    inline info_(u64 sum_ = 0, u64 pw_ = 1) : sum(sum_), pw(pw_) {}
    inline info_ operator + (const info_ &oth) const {
        return info_((sum + pw * oth.sum) % mod, pw * oth.pw % mod);
    }
} tr[maxn * 4];

int a[maxn];
inline void build(int k = 1, int l = 1, int r = n) {
    if (l == r) {
        tr[k] = info_((a[l] - a[l - 1] + K) % K, p);
        return ;
    }
    int mid = l + r >> 1;
    build(k << 1, l, mid), build(k << 1 | 1, mid + 1, r);
    tr[k] = tr[k << 1] + tr[k << 1 | 1];
}
inline void update(int x, int y, int k = 1, int l = 1, int r = n) {
    if (l == r) {
        tr[k].sum = (tr[k].sum + y) % K;
        return ;
    }
    int mid = l + r >> 1;
    if (x <= mid) update(x, y, k << 1, l, mid);
    else update(x, y, k << 1 | 1, mid + 1, r);
    tr[k] = tr[k << 1] + tr[k << 1 | 1];
}
inline info_ query(int x, int y, int k = 1, int l = 1, int r = n) {
    if (x <= l && r <= y) return tr[k];
    int mid = l + r >> 1;
    if (y <= mid) return query(x, y, k << 1, l, mid);
    if (mid < x) return query(x, y, k << 1 | 1, mid + 1, r);
    return query(x, y, k << 1, l, mid) + query(x, y, k << 1 | 1, mid + 1, r);
}

u64 sum[maxn];
inline void add(int x, int y) {
    for (; x <= n; x += x & -x) sum[x] += y;
}
inline u64 qry(int x) {
    u64 y = 0;
    for (; x >= 1; x -= x & -x) y += sum[x];
    return y;
}

int main() {
    scanf("%d%d%d", &n, &K, &q);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);

    build();
    for (int i = 1; i <= n; i++) {
        sum[i] += (u64)a[i] - a[i - 1];
        if (i + (i & -i) <= n) {
            sum[i + (i & -i)] += sum[i];
        }
    }

    for (int i = 1; i <= q; i++) {
        int op;
        scanf("%d", &op);

        if (op == 1) {
            int l, r, c;
            scanf("%d%d%d", &l, &r, &c);

            update(l, c), add(l, c);
            if (r < n) update(r + 1, K - c), add(r + 1, -c);
        } else {
            int l1, r1, l2, r2;
            scanf("%d%d%d%d", &l1, &r1, &l2, &r2);

            if (qry(l1) % K != qry(l2) % K) {
                puts("No");
            } else if (l1 < r1 && query(l1 + 1, r1).sum != query(l2 + 1, r2).sum) {
                puts("No");
            } else {
                puts("Yes");
            }
        }
    }

    return 0;
}