题解:P6707 [COCI 2010/2011 #7] UPIT

· · 题解

看到这种区间操作加上插入元素的题,很明显线段树维护不了,于是就可以想到 Splay 维护区间操作。

如果你不知道怎么用 Splay 维护区间操作,请先完成文艺平衡树。

那么接下来的问题就是如何设计懒标记了。这道题有两个区间操作,因此需要两个懒标记:区间推平标记和等差数列标记。对于一个区间,钦定懒标记表示的顺序为先进行区间推平再加上一个等差数列。

对于等差数列的标记,为了可并性和可拆分性,可以对一个区间 $[l,r]$ 打上一个 $(v,c)$ 的标记,表示把这个区间从 $a_1,a_2,\cdots,a_{r-l+1}$ 变换为

$$a_1+v,\ a_2+v+c,\ \cdots,\ a_k+v+(k-1)c,\ \cdots,\ a_{r-l+1}+v+(r-l)c$$

即加上一个首项为 $v$、公差为 $c$ 的等差数列。

于是两个标记 $(v_1,c_1)$ 和 $(v_2,c_2)$ 可以合并为 $(v_1+v_2,c_1+c_2)$。一个 $[l,r]$ 的懒标记 $(v,c)$ 可以拆分为 $[l,mid]$ 的标记 $(v,c)$ 和 $[mid+1,r]$ 的标记 $(v+c(mid-l+1),c)$。

:::success[代码]

#include <cstdint>
#include <iostream>
#define int int64_t
#define endl '\n'

using namespace std;

// Node pool capacity: nodes 1..n+2 are used by the initial build and insert()
// allocates more via ++idx, so n + (number of inserts) + 2 must stay below N.
constexpr int N = 2e5 + 10;
int n, m;
// tagm[u] = {whether a range-assign tag is pending, the assigned value}
pair<bool, int> tagm[N];
// taga[u] = {whether an arithmetic-progression tag is pending, {first term, common difference}}
pair<bool, pair<int, int>> taga[N];
// Splay tree stored in arrays (`int` is macro-expanded to int64_t): rt = root,
// idx = last allocated node, val = node value, siz / sum = subtree size / sum,
// fa = parent, ch = {left, right} child; index 0 stands for "no node".
int rt, idx, val[N], siz[N], sum[N], fa[N], ch[N][2];

// Reset node u's value, subtree size and both lazy tags (rotate() uses this on
// node 0). sum, fa and ch are intentionally left as they are.
void clear(int u) {
    taga[u] = {false, {0, 0}};
    tagm[u] = {false, 0};
    siz[u] = 0;
    val[u] = 0;
}
// Returns 1 if u is the right child of its parent, otherwise 0.
int where(int u) {
    const int parent = fa[u];
    return ch[parent][1] == u ? 1 : 0;
}

void pushup(int u) {
    siz[u] = siz[ch[u][0]] + 1 + siz[ch[u][1]];
    sum[u] = sum[ch[u][0]] + val[u] + sum[ch[u][1]];
}

// Apply "assign x to every element of u's subtree" to node u: its own value and
// subtree sum are updated now, the tag is recorded for its children.
void addtagm(int u, int x) {
    // A newer assignment makes any pending arithmetic-progression tag irrelevant.
    taga[u].first = false;
    tagm[u] = {true, x};
    val[u] = x;
    sum[u] = siz[u] * x;
}

// 添加等差数列 tag
void addtaga(int u, int vl, int vc) {
    val[u] += vl + vc * siz[ch[u][0]];
    // sum 要累加整个等差数列的贡献
    sum[u] += vl * siz[u] + (siz[u] - 1) * siz[u] / 2 * vc;
    if (!taga[u].first) taga[u].second = {0, 0};
    taga[u].second.first += vl, taga[u].second.second += vc;
    taga[u].first = true;
}

void pushdown(int u) {
    if (tagm[u].first) {
        if (ch[u][0]) addtagm(ch[u][0], tagm[u].second);
        if (ch[u][1]) addtagm(ch[u][1], tagm[u].second);
        tagm[u].first = false;
    }
    if (taga[u].first) {
        if (ch[u][0]) addtaga(ch[u][0], taga[u].second.first, taga[u].second.second);
        if (ch[u][1]) addtaga(ch[u][1], taga[u].second.first + (siz[ch[u][0]] + 1) * taga[u].second.second, taga[u].second.second);
        taga[u].first = false;
    }
}

// Rotate x above its parent y while keeping the in-order sequence unchanged.
// z is y's former parent; k is the side (0 = left, 1 = right) of x under y.
void rotate(int x) {
    int y = fa[x], z = fa[fa[x]], k = where(x);
    // x takes y's slot under z.
    fa[ch[z][where(y)] = x] = z;
    // x's child on the side opposite k becomes y's k-side child.
    fa[ch[y][k] = ch[x][k ^ 1]] = y;
    // y becomes x's child on the side opposite k.
    fa[ch[x][k ^ 1] = y] = x;
    // y is now below x, so it is recomputed first. clear(0) resets node 0's
    // val/siz/tags; note the assignments above may also write fa[0] / ch[0]
    // (when z or the moved child is 0), which clear() does not reset.
    pushup(y), pushup(x), clear(0);
}

// Splay u upward until its parent is rt's original parent, then store u in rt.
// rt may be the global root or a child slot such as ch[root][1].
void splay(int& rt, int u) {
    const int stop = fa[rt];
    while (fa[u] != stop) {
        const int p = fa[u];
        // Same side: zig-zig (rotate the parent first); otherwise zig-zag.
        if (fa[p] != stop) rotate(where(u) == where(p) ? p : u);
        rotate(u);
    }
    rt = u;
}

// Find the k-th node (1-based, in-order) of the subtree rooted at rt, pushing
// tags down along the path, and splay it into rt's position.
void loc(int& rt, int k) {
    int cur = rt;
    for (;;) {
        pushdown(cur);
        const int leftSize = siz[ch[cur][0]];
        if (k <= leftSize) {
            cur = ch[cur][0];
        } else if (k == leftSize + 1) {
            break;
        } else {
            k -= leftSize + 1;
            cur = ch[cur][1];
        }
    }
    splay(rt, cur);
}

// Assign `val` to real elements l..r. Node 1 is a sentinel at tree position 1,
// so real element i sits at tree position i + 1; after splaying positions l and
// r + 2, the range is the left subtree of the root's right child.
void updm(int l, int r, int val) {
    loc(rt, l);
    loc(ch[rt][1], r - l + 2);
    const int right = ch[rt][1];
    const int mid = ch[right][0];
    addtagm(mid, val);
    pushdown(mid);
    pushup(right);
    pushup(rt);
}

// Add vl, vl + vc, vl + 2vc, ... to real elements l..r (range isolated the same
// way as in updm: it ends up as the left subtree of the root's right child).
void upda(int l, int r, int vl, int vc) {
    loc(rt, l);
    loc(ch[rt][1], r - l + 2);
    const int right = ch[rt][1];
    const int mid = ch[right][0];
    addtaga(mid, vl, vc);
    pushdown(mid);
    pushup(right);
    pushup(rt);
}

// Sum of real elements l..r: isolate the range as the left subtree of the
// root's right child and read its cached subtree sum.
int query(int l, int r) {
    loc(rt, l);
    loc(ch[rt][1], r - l + 2);
    return sum[ch[ch[rt][1]][0]];
}

void insert(int k, int x) {
    int u = rt, v = 0, p = 0;
    while (u) {
        pushdown(u);
        v = u;
        if (k <= siz[ch[u][0]] + 1) u = ch[u][0], p = 0;
        else k -= siz[ch[u][0]] + 1, u = ch[u][1], p = 1;
    }
    u = ++idx;
    fa[ch[v][p] = u] = v;
    val[u] = x;
    pushup(u), pushup(v);
    splay(rt, u);
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    // Build a right-leaning chain 1 -> 2 -> ... -> n + 2: nodes 1 and n + 2 are
    // zero-valued sentinels and node i + 1 holds the i-th input value.
    rt = 1;
    fa[1] = 0;
    for (int i = 1; i <= n; i++) {
        ch[i][1] = i + 1;
        fa[i + 1] = i;
        cin >> val[i + 1];
    }
    ch[n + 1][1] = n + 2;
    fa[n + 2] = n + 1;
    val[n + 2] = 0;
    sum[n + 2] = 0;
    // The chain never had pushup() called; splaying its deepest node rotates
    // every node on it, and those rotations' pushup() calls fill in siz/sum.
    splay(rt, n + 2);
    idx = n + 2;
    for (; m > 0; --m) {
        int opt;
        cin >> opt;
        switch (opt) {
            case 1: {
                int l, r, x;
                cin >> l >> r >> x;
                updm(l, r, x);
                break;
            }
            case 2: {
                int l, r, x;
                cin >> l >> r >> x;
                upda(l, r, x, x);
                break;
            }
            case 3: {
                int p, x;
                cin >> p >> x;
                // real position p is tree position p + 1 because of the left sentinel
                insert(p + 1, x);
                break;
            }
            default: {
                int l, r;
                cin >> l >> r;
                cout << query(l, r) << endl;
                break;
            }
        }
    }
    return 0;
}

:::

当然,如果你是分块的忠实爱好者,看到插入操作也可以用块状链表。实测它跑得比 Splay 还快,单 $\log$ 竟然跑不过根号(或许只是我常数大)。不过懒标记仍然沿用上面的设计。

:::success[块状链表代码]

#include <cmath>
#include <cstdint>
#include <iostream>
#define int int64_t
#define endl '\n'

using namespace std;

// N: capacity of the initial input array a[]; M: capacity of a block's val[],
// which is indexed from 1, so a block can hold at most M - 1 values.
constexpr int N = 2e5 + 10, M = 1010;
// a: initial values (1-based); blocklen: block length used when building the
// list in main; mxlen: length limit used by insert's block-splitting code.
// (`int` is macro-expanded to int64_t.)
int n, m, a[N], blocklen, mxlen;

// One node of the block list. Values live 1-based in val[1..len]; while a tag
// is pending val[] may be stale, but sum already includes every pending tag.
// Tag order: the assignment (tagm) applies first, then the progression (taga).
struct Block {
    Block* pre = nullptr;                            // previous block, if any
    Block* nex = nullptr;                            // next block, if any
    int len = 0;                                     // number of stored values
    int val[M];                                      // values, 1-based
    int sum = 0;                                     // block sum incl. pending tags
    pair<bool, int> tagm{false, 0};                  // {assign pending, value}
    pair<bool, pair<int, int>> taga{false, {0, 0}};  // {AP pending, {first term, difference}}

    Block() = default;
};

// Lazily assign x to every value of block b: the sum is exact right away and
// val[] is rewritten by the next pushdown.
void addtagm(Block* b, int x) {
    b->taga.first = false;  // an assignment overrides any pending progression
    b->tagm = {true, x};
    b->sum = b->len * x;
}

// Lazily add vl, vl + vc, ..., vl + (len - 1) * vc to block b's values, updating
// the sum now and merging the progression into the pending tag.
void addtaga(Block* b, int vl, int vc) {
    const int k = b->len;
    b->sum += k * vl + k * (k - 1) / 2 * vc;
    auto& tag = b->taga;
    if (!tag.first) tag = {true, {0, 0}};
    tag.second.first += vl;
    tag.second.second += vc;
}

// Materialize block b's pending tags into val[1..len], recompute its sum from
// scratch and clear both tags.
void pushdown(Block* b) {
    const bool assign = b->tagm.first;
    const bool arith = b->taga.first;
    const int setTo = b->tagm.second;
    const int start = b->taga.second.first, step = b->taga.second.second;
    int total = 0;
    for (int i = 1; i <= b->len; i++) {
        int v = assign ? setTo : b->val[i];
        if (arith) v += start + step * (i - 1);
        b->val[i] = v;
        total += v;
    }
    b->sum = total;
    b->tagm.first = false;
    b->taga.first = false;
}

// First block of the list (allocated in main); the rest are reached via nex.
Block *head;

void updm(int l, int r, int x) {
    int pos = 0;
    Block *cur = head;
    while (pos + cur->len < l)
        pos += cur->len, cur = cur->nex;
    if (r <= pos + cur->len) {
        pushdown(cur);
        for (int i = l; i <= r; i++) {
            cur->sum -= cur->val[i - pos];
            cur->val[i - pos] = x;
            cur->sum += cur->val[i - pos];
        }
    } else {
        pushdown(cur);
        for (int i = l - pos; i <= cur->len; i++) {
            cur->sum -= cur->val[i];
            cur->val[i] = x;
            cur->sum += cur->val[i];
        }
        Block* ruc = cur->nex;
        pos += cur->len;
        while (r > pos + ruc->len) {
            addtagm(ruc, x);
            pos += ruc->len, ruc = ruc->nex;
        }
        pushdown(ruc);
        for (int i = 1; i <= r - pos; i++) {
            ruc->sum -= ruc->val[i];
            ruc->val[i] = x;
            ruc->sum += ruc->val[i];
        }
    }
}

void upda(int l, int r, int vl, int vc) {
    int pos = 0;
    Block *cur = head;
    while (pos + cur->len < l)
        pos += cur->len, cur = cur->nex;
    if (r <= pos + cur->len) {
        pushdown(cur);
        for (int i = l; i <= r; i++) {
            cur->sum -= cur->val[i - pos];
            cur->val[i - pos] += vl + (i - l) * vc;
            cur->sum += cur->val[i - pos];
        }
    } else {
        pushdown(cur);
        for (int i = l - pos; i <= cur->len; i++) {
            cur->sum -= cur->val[i];
            cur->val[i] += vl + (i + pos - l) * vc;
            cur->sum += cur->val[i];
        }
        Block* ruc = cur->nex;
        pos += cur->len;
        while (r > pos + ruc->len) {
            addtaga(ruc, vl + (pos - l + 1) * vc, vc);
            pos += ruc->len, ruc = ruc->nex;
        }
        pushdown(ruc);
        for (int i = 1; i <= r - pos; i++) {
            ruc->sum -= ruc->val[i];
            ruc->val[i] += vl + (i + pos - l) * vc;
            ruc->sum += ruc->val[i];
        }
    }
}

int query(int l, int r) {
    int pos = 0, res = 0;
    Block *cur = head;
    while (pos + cur->len < l)
        pos += cur->len, cur = cur->nex;
    if (r <= pos + cur->len) {
        pushdown(cur);
        for (int i = l; i <= r; i++)
            res += cur->val[i - pos];
    } else {
        pushdown(cur);
        for (int i = l - pos; i <= cur->len; i++)
            res += cur->val[i];
        int tmp = res;
        Block* ruc = cur->nex;
        pos += cur->len;
        while (r > pos + ruc->len) {
            res += ruc->sum;
            pos += ruc->len, ruc = ruc->nex;
        }
        tmp = res;
        pushdown(ruc);
        for (int i = 1; i <= r - pos; i++)
            res += ruc->val[i];
    }
    return res;
}

void insert(int k, int x) {
    int pos = 0;
    Block *cur = head;
    while (!(pos < k && k <= pos + cur->len + 1))
        pos += cur->len, cur = cur->nex;
    pushdown(cur);
    for (int i = cur->len; i >= k - pos; i--)
        cur->val[i + 1] = cur->val[i];
    cur->val[k - pos] = x, cur->sum += x, cur->len++;
    return;
    if (cur->len == mxlen) {
        Block *temp = new Block;
        cur->len = mxlen / 2;
        temp->len = mxlen - cur->len;
        for (int i = cur->len + 1; i <= mxlen; i++)
            temp->val[i - cur->len] = cur->val[i],
            cur->sum -= cur->val[i], temp->sum += cur->val[i];
        if (cur->nex != nullptr)
            temp->nex = cur->nex, cur->nex->pre = temp;
        cur->nex = temp, temp->pre = cur;
    }
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    blocklen = sqrt(1.5 * n), mxlen = 2 * blocklen;
    Block *cur = head = new Block;
    for (int i = 1; i <= n; i++) {
        if (cur->len == blocklen) {
            cur->nex = new Block;
            cur->nex->pre = cur;
            cur = cur->nex;
        }
        cur->val[++cur->len] = a[i];
        cur->sum += a[i];
    }
    while (m --> 0) {
        int opt; cin >> opt;
        if (opt == 1) {
            int l, r, x;
            cin >> l >> r >> x;
            updm(l, r, x);
        } else if (opt == 2) {
            int l, r, x;
            cin >> l >> r >> x;
            upda(l, r, x, x);
        } else if (opt == 3) {
            int p, x;
            cin >> p >> x;
            insert(p, x);
        } else {
            int l, r;
            cin >> l >> r;
            cout << query(l, r) << endl;
        }
    }
    return 0;
}

:::