平衡树 + 类欧几里得

· · 题解

这题用平衡树会更方便一些,不用离散化。直接写个 WBLT,然后写个结点回收,就也用不了多少空间,还跑得飞快。

具体讲一下思路,就是每个结点处存当前线段参数 (a,b,c,n) 和所求的和式 su。显然,对于线段

y=(ax+b)/c,~1\le x\le n

来说,所求和式等于

\begin{aligned} \sum_{i=1}^n(ai+b)\bmod c &= \sum_{i=1}^n\left((ai+b)-c\left\lfloor\dfrac{ai+b}{c}\right\rfloor\right) \\ &= \dfrac{a}{2}n(n+1)+bn-c\sum_{i=1}^n\left\lfloor\dfrac{ai+b}{c}\right\rfloor. \end{aligned}

最后一项可以通过类欧几里得算法实现。

WBLT 中,除了叶子结点外,不需要存储多余的信息,把所求和式穿上去就行。

每次覆盖一条新的线段的时候,直接分裂出相应的子树,然后替换成新的线段的参数,再合并,就完成了一次更新。

分裂的时候需要算一下两边的线段参数,其中,ac 是不变的,只要把右边线段的截距 b 算出来就可以了。然后重新更新一下两侧的和 su

#include <iostream>
#include <vector>
#define fast_io std::ios::sync_with_stdio(false), std::cin.tie(nullptr)

long long sum_floor(long long a, long long b, long long c, long long n) {
    long long n2 = n * (n + 1) / 2;
    if (a >= c || b >= c) {
        long long aa = a / c, bb = b / c;
        return sum_floor(a % c, b % c, c, n) + aa * n2 + bb * (n + 1);
    } else if (a) {
        long long m = (a * n + b) / c;
        return m * n - sum_floor(c, c - b - 1, a, m - 1);
    } else {
        return 0;
    }
}

class WBLT {
    struct Node {
        int ch[2], sz, a, b, c, n;
        long long su;
        Node() : ch{}, sz{}, a{}, b{}, c{}, n{}, su{} {}

        void update() {
            su = (long long)n * b + (long long)(n + 1) * n / 2 * a
                - (long long)c * (sum_floor(a, b, c, n) - (b / c));
        }
    };

    int id, rt, top;
    std::vector<int> pool;
    std::vector<Node> node;

#define getter(var) \
    auto var(int x) const { return node[x].var; } \
    auto& var(int x) { return node[x].var; }

    getter(sz)
    getter(a)
    getter(b)
    getter(c)
    getter(n)
    getter(su)

#undef getter

    auto& ch(int x, int i) { return node[x].ch[i]; }
    auto ch(int x, int i) const { return node[x].ch[i]; }

    void push_up(int x) {
        sz(x) = sz(ch(x, 0)) + sz(ch(x, 1));
        su(x) = su(ch(x, 0)) + su(ch(x, 1));
        n(x) = n(ch(x, 0)) + n(ch(x, 1));
    }

    void push_down(int x) {}

    int new_node() {
        int x = top ? pool[--top] : ++id;
        node[x] = Node();
        return x;
    }

    void del_node(int& x) {
        pool[top++] = x;
        x = 0;
    }

    int new_leaf(int _a, int _b, int _c, int _n) {
        int x = new_node();
        sz(x) = 1;
        a(x) = _a;
        b(x) = _b;
        c(x) = _c;
        n(x) = _n;
        node[x].update();
        return x;
    }

    int join(int x, int y) {
        int z = new_node();
        ch(z, 0) = x;
        ch(z, 1) = y;
        push_up(z);
        return z;
    }

    auto cut(int& x) {
        push_down(x);
        int y = ch(x, 0);
        int z = ch(x, 1);
        del_node(x);
        return std::pair(y, z);
    }

    bool too_heavy(int sx, int sy) {
        return sx * 2 > sy * 5;
    }

    int merge(int x, int y) {
        if (!x || !y) return x | y;
        if (too_heavy(sz(x), sz(y))) {
            auto [a, b] = cut(x);
            if (too_heavy(sz(b) + sz(y), sz(a))) {
                auto [c, d] = cut(b);
                return merge(merge(a, c), merge(d, y));
            } else {
                return merge(a, merge(b, y));
            }
        } else if (too_heavy(sz(y), sz(x))) {
            auto [a, b] = cut(y);
            if (too_heavy(sz(a) + sz(x), sz(b))) {
                auto [c, d] = cut(a);
                return merge(merge(x, c), merge(d, b));
            } else {
                return merge(merge(x, a), b);
            }
        } else {
            return join(x, y);
        }
    }

    std::pair<int, int> split(int x, int k) {
        if (!x) return { 0, 0 };
        if (!k) return { 0, x };
        if (k == n(x)) return { x, 0 };
        if (sz(x) == 1) {
            auto y = new_leaf(a(x), (b(x) + (long long)a(x) * k) % c(x), c(x), n(x) - k);
            n(x) = k;
            su(x) -= su(y);
            return { x, y };
        }
        auto [a, b] = cut(x);
        if (k <= n(a)) {
            auto [ll, rr] = split(a, k);
            return { ll, merge(rr, b) };
        } else {
            auto [ll, rr] = split(b, k - n(a));
            return { merge(a, ll), rr };
        }
    }

    void del_tree(int& x) {
        if (!x) return;
        del_tree(ch(x, 0));
        del_tree(ch(x, 1));
        del_node(x);
    }

public:
    WBLT(int n, int N) : id(0), rt(0), top(0), pool(n << 1), node(n << 1) {
        rt = new_leaf(0, 0, 1, N);
    }

    void modify(int l, int r, int a, int b) {
        int ll, rr;
        std::tie(rt, rr) = split(rt, r);
        std::tie(ll, rt) = split(rt, l - 1);
        del_tree(rt);
        rt = new_leaf(a, 0, b, r - l + 1);
        rt = merge(merge(ll, rt), rr);
    }

    long long find_sum(int l, int r) {
        int ll, rr;
        std::tie(rt, rr) = split(rt, r);
        std::tie(ll, rt) = split(rt, l - 1);
        auto res = su(rt);
        rt = merge(merge(ll, rt), rr);
        return res;
    }
};

int main() {
    fast_io;
    int n, q;
    std::cin >> n >> q;
    WBLT st(q << 1, n);
    for (; q; --q) {
        int op;
        std::cin >> op;
        if (op == 1) {
            int l, r, a, b;
            std::cin >> l >> r >> a >> b;
            st.modify(l, r, a, b);
        } else if (op == 2) {
            int l, r;
            std::cin >> l >> r;
            std::cout << st.find_sum(l, r) << '\n';
        }
    }
    return 0;
}