P14164 [ICPC 2022 Nanjing R] 命题作文

· · 题解

根据经典结论,在第 i 次操作后,我们这样计算答案:

前两种容易用并查集或 set 维护。考虑计算第三种贡献。

我们把随机权值扔掉,直接对每条树边维护一个初始为空的 01 串,然后第 i 次操作给被 (u_i, v_i) 覆盖的树边的串加入 1,没被覆盖的树边的串加入 0

然后我们把所有串按照字典序排序,那么每个时刻权值的等价类形成一个区间。我们只需要求出相邻两个串的 LCP,表示它们在这个时刻之前属于同一个等价类。这个思想跟后缀数组用 height 求本质不同子串是一样的。

所以倒着扫描线,并查集维护等价类即可。

01 串排序和求 LCP 的操作可以主席树维护区间哈希值解决。

n, m 同阶,时间复杂度 O(n \log^2 n),不过跑得不算慢。

:::info[代码]

// Problem: P14164 [ICPC 2022 Nanjing R] 命题作文
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P14164
// Memory Limit: 512 MB
// Time Limit: 1500 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;

const int maxn = 250100;

int n, m, fa[maxn], p[maxn], sz[maxn];
ll ans[maxn], res;
mt19937_64 rnd(chrono::steady_clock::now().time_since_epoch().count());
ull f[maxn], val[maxn];
vector<int> vc[maxn];

int find(int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

inline void merge(int x, int y) {
    x = find(x);
    y = find(y);
    if (x == y) {
        return;
    }
    res -= 1LL * sz[x] * (sz[x] - 1) / 2;
    res -= 1LL * sz[y] * (sz[y] - 1) / 2;
    fa[x] = y;
    sz[y] += sz[x];
    res += 1LL * sz[y] * (sz[y] - 1) / 2;
}

int rt[maxn], ls[maxn * 50], rs[maxn * 50], nt;
ull hs[maxn * 50];

int update(int rt, int l, int r, int x, ull y) {
    int u = ++nt;
    ls[u] = ls[rt];
    rs[u] = rs[rt];
    hs[u] = hs[rt] ^ y;
    if (l == r) {
        return u;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) {
        ls[u] = update(ls[u], l, mid, x, y);
    } else {
        rs[u] = update(rs[u], mid + 1, r, x, y);
    }
    return u;
}

bool cmp(int u, int v, int l, int r) {
    if (!u) {
        return 1;
    }
    if (!v) {
        return 0;
    }
    if (l == r) {
        return hs[u] < hs[v];
    }
    int mid = (l + r) >> 1;
    return hs[ls[u]] != hs[ls[v]] ? cmp(ls[u], ls[v], l, mid) : cmp(rs[u], rs[v], mid + 1, r);
}

int lcp(int u, int v, int l, int r) {
    if (hs[u] == hs[v]) {
        return r - l + 1;
    }
    if (l == r) {
        return 0;
    }
    int mid = (l + r) >> 1;
    return hs[ls[u]] != hs[ls[v]] ? lcp(ls[u], ls[v], l, mid) : lcp(rs[u], rs[v], mid + 1, r) + mid - l + 1;
}

void solve() {
    for (int i = 0; i <= nt; ++i) {
        ls[i] = rs[i] = hs[i] = 0;
    }
    nt = 0;
    scanf("%d%d", &n, &m);
    if (n == 1) {
        while (m--) {
            scanf("%*d%*d");
            puts("0");
        }
        return;
    }
    for (int i = 1; i <= n; ++i) {
        f[i] = 0;
        fa[i] = i;
        vector<int>().swap(vc[i]);
    }
    set<int> S;
    ll cnt = n - 1;
    for (int i = 1, l, r; i <= m; ++i) {
        scanf("%d%d", &l, &r);
        if (l > r) {
            swap(l, r);
        }
        val[i] = rnd();
        for (auto it = S.lower_bound(l); it != S.end() && (*it) < r;) {
            it = S.erase(it);
        }
        for (int j = find(l); j < r; j = find(j)) {
            S.insert(j);
            fa[j] = j + 1;
            --cnt;
        }
        ans[i] = cnt * (n - 1 + i - cnt) + (int)S.size();
        vc[l].pb(i);
        vc[r].pb(i);
    }
    for (int i = 1; i < n; ++i) {
        rt[i] = rt[i - 1];
        for (int j : vc[i]) {
            rt[i] = update(rt[i], 1, m, j, val[j]);
        }
        p[i] = i;
    }
    stable_sort(p + 1, p + n, [&](const int &i, const int &j) {
        return cmp(rt[i], rt[j], 1, m);
    });
    for (int i = 1; i <= m; ++i) {
        vector<int>().swap(vc[i]);
    }
    res = 0;
    for (int i = 1; i < n; ++i) {
        fa[i] = i;
        sz[i] = 1;
    }
    for (int i = 1; i <= n - 2; ++i) {
        int u = p[i], v = p[i + 1];
        int k = lcp(rt[u], rt[v], 1, m);
        if (k) {
            vc[k].pb(i);
        }
    }
    for (int i = m; i; --i) {
        for (int j : vc[i]) {
            merge(j, j + 1);
        }
        ans[i] += res;
    }
    for (int i = 1; i <= m; ++i) {
        printf("%lld\n", ans[i]);
    }
}

int main() {
    int T = 1;
    scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}

:::