人赢的跳棋 题解

· · 题解

使用点分治,朴素地求出分治区域内所有点的信息,例如点 u 的信息为 (u_1,u_2) 表示从 u 走到 mid 可以求的所有 win 值的最大值midu 方向的那条和 mid 相邻边的边权三元组 e

一条路径可以表示为 u\to mid\to v,其中 uv 必须属于 mid 的不同子树。路径权值为 \max(u_1,v_1,win(u_2,v_2))

把所有三元组放在一起跑三维偏序,至此整题得到 O(n\log^3 n) 的做法。

考虑在三维偏序使用归并排序的过程。按照 e_2 排序会使得相同子树的点被排到相邻位置,形成一个区间。对于同一区间内的元素不会构成路径;此时相当于归并若干个区间。按照区间长度分治归并即可。对于时间复杂度,可以注意到分治时两层长度一定减半,否则已经递归至底层;类似全局平衡二叉树的分析,整个分治树的高度为 O(\log n),总时间复杂度加上 BIT 为 O(n\log^2 n)

#include<bits/stdc++.h>
#define endl '\n'
using namespace std;
using ll = long long;
const int N = 3e5 + 5;
int n;
ll ans;
int a[N], b[N];
vector<pair<int, int>> e[N];
int sz[N], mx[N];
bool vis[N], vis2[N];
int ttot, tt[N];
void getrt(int u, int fa){
    tt[++ttot] = u;
    sz[u] = 1;
    mx[u] = 0;
    for (auto& [v, id] : e[u])
        if (!vis[v] && v != fa){
            getrt(v, u);
            sz[u] += sz[v];
            mx[u] = max(sz[v], mx[u]);
        }
}
int f[N];
int tot, tot2, tot3, tr[N], ed[N];
pair<int, int> A[N], B[N];
int res[N];
pair<int, int> tmp[N];
int win(int x, int y){
    if (x < y && b[x] < b[y]) return a[x];
    else if (x > y && b[x] > b[y]) return a[y];
    else return 0;
}
void dfs2(int u, int fae){
    tr[++tot] = f[u];
    for (auto& [v, id] : e[u])
        if (vis2[v] && id != fae){
            f[v] = max(f[u], win(fae, id));
            dfs2(v, id);
        }
}
struct tree2{
    int a[N];
    void add(int x, int k){
        x++;
        for (; x <= n + 1; x += x & -x)
            a[x] += k;
    }
    int ask(int x){
        x++;
        int res = 0;
        for (; x; x -= x & -x)
            res += a[x];
        return res;
    }
}T1, T2;
void dfs3(int l, int r){
    if (l == r) return;
    pair<int, int> mn(1e9, 0);
    for (int i = l; i < r; i++)
        mn = min(make_pair(abs(2 * ed[i] - ed[l - 1] - ed[r]), i), mn);
    int mid = mn.second;
//  int mid = (l + r) >> 1;
    dfs3(l, mid);
    dfs3(mid + 1, r);
    {//x2 < y2 B贡献 
        int j = ed[r];
        for (int i = ed[mid]; i > ed[l - 1]; i--){
            for (; j > ed[mid] && A[j] > B[i]; j--)
                T1.add(A[j].second, 1);
            ans += 1ll * T1.ask(B[i].second) * B[i].second;
        }
        for (int k = j + 1; k <= ed[r]; k++)
            T1.add(A[k].second, -1);
    }
    {//x2 < y2 A贡献 
        int j = ed[l - 1] + 1;
        for (int i = ed[mid] + 1; i <= ed[r]; i++){
            for (; j <= ed[mid] && B[j] < A[i]; j++)
                T1.add(B[j].second, 1);
            ans += 1ll * T1.ask(A[i].second) * A[i].second;
        }
        for (int k = ed[l - 1] + 1; k < j; k++)
            T1.add(B[k].second, -1);
    }
    {//x2 > y2 Al
        int j = ed[mid] + 1;
        for (int i = ed[l - 1] + 1; i <= ed[mid]; i++){
            for (; j <= ed[r] && A[j] < A[i]; j++)
                T1.add(A[j].second, 1);
            ans += 1ll * T1.ask(A[i].second) * A[i].second;
        }
        for (int k = ed[mid] + 1; k < j; k++)
            T1.add(A[k].second, -1);
    }
    {//x2 > y2 Ar
        int j = ed[mid];
        for (int i = ed[r]; i > ed[mid]; i--){
            for (; j > ed[l - 1] && A[j] > A[i]; j--)
                T1.add(A[j].second, 1);
            ans += 1ll * T1.ask(A[i].second) * A[i].second;
        }
        for (int k = j + 1; k <= ed[mid]; k++)
            T1.add(A[k].second, -1);
    }
    merge(A + ed[l - 1] + 1, A + ed[mid] + 1, A + ed[mid] + 1, A + ed[r] + 1, tmp + 1);
    copy(tmp + 1, tmp + 1 + ed[r] - ed[l - 1], A + ed[l - 1] + 1);
    merge(B + ed[l - 1] + 1, B + ed[mid] + 1, B + ed[mid] + 1, B + ed[r] + 1, tmp + 1);
    copy(tmp + 1, tmp + 1 + ed[r] - ed[l - 1], B + ed[l - 1] + 1);
}
void dfs1(int u){
    ttot = 0;
    getrt(u, 0);
    int S = sz[u];
    for (int i = 1; i <= ttot; i++){
        int v = tt[i];
        if (max(mx[v], S - sz[v]) * 2 <= S)
            u = v;
    }
    vis[u] = 1;
    if (S == 1){
        vis2[u] = 1;
        return;
    }
    vector<pair<int, int>> son;
    for (auto [v, id] : e[u])
        if (!vis[v]){
            son.emplace_back(v, id);
            S = sz[v];
            dfs1(v);
        }
    int len = son.size();
    tot2 = tot3 = 0;
    for (auto [v, id] : son){
        f[v] = 0;
        tot = 0;
        dfs2(v, id);
        for (int i = 1; i <= tot; i++){
            int x = tr[i];
            ans += x;
            A[tot2 + i] = {b[id], x};
            B[tot2 + i] = {b[id], max(x, a[id])};
        }
        tot2 += tot;
        ed[++tot3] = tot2;
    }
    dfs3(1, tot3);
    vis2[u] = 1;
}
char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
inline int read() {
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)) x=x*10+(ch^48),ch=getchar();
    return x*f;
}
int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    n = read();
    for (int i = 1; i < n; i++){
        int u = read(), v = read();
        a[i] = read(), b[i] = read();
        e[u].emplace_back(v, i);
        e[v].emplace_back(u, i);
    }
    dfs1(1);
    cout << ans;
    return 0;
}