题解 P4690 【[Ynoi2016]镜中的昆虫】

· · 题解

相对 Ynoi 来说更偏结论和模板的 cdq 题。

前置知识:CDQ 分治、时间复杂度分析、卡常。

首先这玩意看起来不太好做,考虑弱化为单点修改。

单点修改区间颜色数有一个非常套路的转换,就是维护每一个点的颜色的前驱(即对每一个 i 维护最大的 p_i 使得 p_i<ia_{p_i}=a_i,如果不存在则为 0),然后每次统计满足 l\leq i\leq r, 1\leq p_i<li 的数量。

显然这玩意就是一个二维数点。

但是带上单点修改之后就没法简单地维护了怎么办呢?

显然,每次单点修改 ix,只会改变 i 的前驱、i 的原后继的前驱、i 的新后继的前驱。

修改一个点可以拆分为先删除再加入,然后在 t 时间删除一个点 (x,y) 可以通过加一个权值为 -1 的三维空间内的点 (t,x,y),然后再加入一个权值为 1 的点 (t,x',y') 来实现。

然后,因为询问只考虑到询问时间之前的修改的影响,所以这就转换为了一个三维偏序,就可以树套树/CDQ 维护了。

那么区间修改怎么做呢?

【区间修改的时候,p 数组最多发生多少次改变?】
那还用说吗,每次改变 O(n) 个位置,显然是 O(nm)
【好,你说是 O(nm),怎么卡满这个复杂度?】
嗯……
【试了 15 分钟】
为啥就死活卡不满呢……
【也许不是 O(nm)?】
也许吧……试着证一下吧。

每次推平一段区间的时候,所有 p 改变的位置一定是 一段极长连续同色子段的左端点,剩下的都将 维持为下标减 1 不变

推平之后,这些连续同色子段都会被 删除

如果我们删掉了 O(x) 个节点,花费的代价一定是 O(x) 的。

删除的过程中,先要执行 ODT 的 split,这一步增加的点的数量是 O(1) 的。

然后花费 O(x) 的代价删去 O(x) 个节点。

然后再把这一段加回来,增加的点的数量是 O(1) 的。

所以增加的点的数量是 O(m) 的。

一开始有 O(n) 个点,所以总点数是 O(n+m) 的。

显然只有 O(n+m) 个点,删除的总次数也只有 O(n+m)

每次删除对总修改次数的贡献为 O(1)

得到结论:

所有修改操作完成后,p 数组上单点修改的次数为 O(n+m)

好!那么这就变成了单点修改!

现在的问题就变成了如何维护出需要修改的点。

显然还是要拿一棵 ODT,但是区间修改的时候要是直接讨论需要修改的位置……看着……看着……就不想写了……

换一个思路,如果把所有需要修改的位置存下来,然后排除掉 p 数组直接修改 ODT,然后再对这些位置看要改成啥……诶!好像非常好写!

为了方便处理前驱,ODT 上面还要再额外对每一种颜色维护一个 set。

那么现在就只需要把所有修改的位置处理出来,然后使用单点修改的方法拆点转换为三维偏序,这题就做完啦~

最终时间复杂度 O((n+m)\log ^2 n),空间复杂度 O(n+m)

然后直接写好了交上去,就 TLE+MLE 了 QAQ

我们发现一个点的空间开销较大,需要 5 个 int,然后这个点数的 O(n) 带大常数,所以在 CDQ 分治的归并排序过程中不去维护点的临时数组 p',而是维护 点的下标 的临时数组 v',其中 v'_i 实际对应的点是 p_{v'_i}

这样临时数组的空间就优掉了很多,于是就不 MLE 了。

然后有两个主要的剪枝:

  1. 对于所有查询点维护出最大的 X、最大的 Y 和最大的 Z,然后将所有满足 x_0>Xy_0>Yz_0>Z修改点剔除(因为这些点不可能产生贡献);
  2. 树状数组清零的时候可以另写一个函数,清到一个位置为 0 的时候就没有必要往下走了。

加上这两个剪枝和一些常用的卡常技巧(如快读)再打开 O2 就可以爆过去啦~

代码:

#include <iostream>
#include <cmath>
#include <cstring>
#include <cstdio>
#include <set>
#include <vector>
#include <tr1/unordered_map>
using namespace std;
using namespace std::tr1;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;

inline int qread() {
    register char c = getchar();
    register int x = 0, f = 1;
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = (x << 3) + (x << 1) + c - 48;
        c = getchar();
    }
    return x * f;
}

inline int Abs(const int& x) {return (x > 0 ? x : -x);}
inline int Max(const int& x, const int& y) {return (x > y ? x : y);}
inline int Min(const int& x, const int& y) {return (x < y ? x : y);}

struct Node {
    int l, r, v;
    bool operator < (const Node& b) const {return l < b.l;}
    Node() {}
    Node(int l, int r, int v) : l(l), r(r), v(v) {}
};
struct Point {
    int _time, x, y, val;
    bool type;
    Point() {}
    Point(int _time, int x, int y, int val, bool type) : _time(_time), x(x), y(y), val(val), type(type) {}
};
int n, prv[100005], a[100005], m, mptop, lst[100005], ans[100005], c[100005], idx[1200005];
vector <Point> vc;
unordered_map <int, int> mp;
struct ChthollyTree {
    typedef set<Node>::iterator SNI;
    set <Node> tr, col[200005];
    inline SNI Insert(int l, int r, int v) {
        col[v].insert(Node(l, r, v));
        return tr.insert(Node(l, r, v)).first;
    }
    inline void Delete(int l, int r, int v) {
        col[v].erase(Node(l, r, v));
        tr.erase(Node(l, r, v));
    }
    inline int Prev(int pos) {
        SNI it = tr.upper_bound(Node(pos, 0, 0));
        it--;
        if (it->l < pos) return pos - 1;
        else {
            SNI rm = col[it->v].lower_bound(Node(pos, 0, 0));
            if (rm != col[it->v].begin()) {
                rm--;
                return rm->r;
            }
            return 0;
        }
    }
    inline SNI Split(int pos) {
        SNI it = tr.lower_bound(Node(pos, 0, 0));
        if (it != tr.end() && it->l == pos) return it;
        --it;
        int l = it->l, r = it->r;
        long long v = it->v;
        Delete(l, r, v);
        Insert(l, pos - 1, v);
        return Insert(pos, r, v);
    }
    inline void Assign(int l, int r, int v, int _time) {
        //printf("%d:\n", _time);
        SNI itr = Split(r + 1), itl = Split(l);
        vector <int> mpnt;
        for (SNI it = itl;it != itr;it++) {
            if (it != itl) mpnt.push_back(it->l);
            SNI nxt = col[it->v].upper_bound(*it);
            if (nxt != col[it->v].end()) mpnt.push_back(nxt->l);
            col[it->v].erase(*it);
        }
        tr.erase(itl, itr);
        tr.insert(Node(l, r, v));
        col[v].insert(Node(l, r, v));
        mpnt.push_back(l);
        SNI nxt = col[v].upper_bound(Node(l, r, v));
        if (nxt != col[v].end()) mpnt.push_back(nxt->l);
        register int siz = mpnt.size();
        for (register int i = 0;i < siz;i++) {
            if (mpnt[i] > n || mpnt[i] < 1) continue;
            vc.push_back(Point(_time, mpnt[i], prv[mpnt[i]], -1, 0));
            prv[mpnt[i]] = Prev(mpnt[i]);
            //printf("%d %d\n", mpnt[i], prv[mpnt[i]]);
            vc.push_back(Point(_time, mpnt[i], prv[mpnt[i]], 1, 0));
        }
        //puts("End");
    }
};
ChthollyTree tr;

inline int Lowbit(int x) {
    return x & -x;
}

inline void Update(int i, int x) {
    for (register int j = i;j <= n;j += Lowbit(j)) c[j] += x;
}

inline void Remove(int i) {
    for (register int j = i;j <= n;j += Lowbit(j)) {
        if (!c[j]) break;
        else c[j] = 0;
    }
}

inline int Query(int i) {
    register int ans = 0;
    for (register int j = i;j >= 1;j -= Lowbit(j)) ans += c[j];
    return ans;
}

inline void Read() {
    n = qread(); m = qread();
    for (register int i = 1;i <= n;i++) {
        a[i] = qread();
        if (!mp[a[i]]) {
            mptop++;
            mp[a[i]] = mptop;
        }
        a[i] = mp[a[i]];
    }
}

inline void Prefix() {
    vc.push_back(Point(-1, 0, 0, 0, 0));
    for (register int i = 1;i <= n;i++) {
        prv[i] = lst[a[i]];
        vc.push_back(Point(0, i, prv[i], 1, 0));
        lst[a[i]] = i;
        tr.tr.insert(Node(i, i, a[i]));
        tr.col[a[i]].insert(Node(i, i, a[i]));
    }
}

inline void Cdq(int l, int r) {
    //printf("cdq(%d,%d){\n", l, r);
    if (l == r) return;
    register int mid = l + r >> 1;
    Cdq(l, mid);
    //puts("}");
    Cdq(mid + 1, r);
    //puts("}");
    vector <int> tmp;
    int i = l, j = mid + 1;
    while (i <= mid && j <= r) {
        if (vc[idx[i]].x <= vc[idx[j]].x) {
            if (vc[idx[i]].type == 0) Update(vc[idx[i]].y, vc[idx[i]].val);
            tmp.push_back(idx[i]);
            i++;
        } else {
            if (vc[idx[j]].type == 1) {
                ans[vc[idx[j]]._time] += vc[idx[j]].val * Query(vc[idx[j]].y);
                //printf("%d %d %d %d\n", vc[j]._time, vc[j].x, vc[j].y, ans[vc[j]._time]);
            }
            tmp.push_back(idx[j]);
            j++;
        }
    }
    while (j <= r) {
        if (vc[idx[j]].type == 1) {
            ans[vc[idx[j]]._time] += vc[idx[j]].val * Query(vc[idx[j]].y);
            //printf("%d %d %d %d\n", vc[j]._time, vc[j].x, vc[j].y, ans[vc[j]._time]);
        }
        tmp.push_back(idx[j]);
        j++;
    }
    for (register int k = l;k < i;k++) {
        if (vc[idx[k]].type == 0) Remove(vc[idx[k]].y);
    }
    while (i <= mid) {
        tmp.push_back(idx[i]);
        i++;
    }
    for (i = l;i <= r;i++) idx[i] = tmp[i - l];
}

inline void Solve() {
    for (register int i = 1;i <= m;i++) {
        register int opt = qread();
        if (opt == 1) {
            register int l = qread(), r = qread(), v = qread();
            if (!mp[v]) {
                mptop++;
                mp[v] = mptop;
            }
            v = mp[v];
            tr.Assign(l, r, v, i);
        } else {
            register int l = qread(), r = qread();
            vc.push_back(Point(i, r, l - 1, 1, 1));
            vc.push_back(Point(i, l - 1, l - 1, -1, 1));
        }
    }
    register int pcnt = vc.size() - 1;
    for (register int i = 1;i <= pcnt;i++) vc[i].y++;
    // for (register int i = 1;i <= pcnt;i++) {
    //  printf("%d %d %d %d %d\n", vc[i].type, vc[i]._time, vc[i].x, vc[i].y, vc[i].val);
    // }
    register int maxx = 0, maxy = 0, maxz = 0;
    vector <int> ridx;
    for (register int i = 1;i <= pcnt;i++) {
        if (vc[i].type) {
            maxx = Max(maxx, vc[i]._time);
            maxy = Max(maxy, vc[i].x);
            maxz = Max(maxz, vc[i].y);
        }
    }
    for (register int i = 1;i <= pcnt;i++) {
        if (vc[i].type || (vc[i]._time <= maxx && vc[i].x <= maxy && vc[i].y <= maxz)) ridx.push_back(i);
    }
    pcnt = ridx.size();
    for (register int i = 1;i <= pcnt;i++) idx[i] = ridx[i - 1];
    Cdq(1, pcnt);
    for (register int i = 1;i <= m;i++) {
        if (ans[i]) printf("%d\n", ans[i]);
    }
}

int main() {
    Read();
    Prefix();
    Solve();
    return 0;
}