题解:P6666 [清华集训2016] 数据交互

· · 题解

题目链接:[清华集训2016] 数据交互

省流:[清华集训2016] 数据交互 = 仓鼠找 sugar + 【模板】"动态 DP"&动态树分治

把取消操作看作加上原路径的相反数,这样就把操作全部转化为了添加路径。

对于路径 \left(u,v\right),若路径 \left(x,y\right) 与其相交,当且仅当:

  1. \text{LCA}\left(x,y\right)\in \left(u,v\right)
  2. \text{LCA}\left(u,v\right)\in \left(x,y\right)

val1_i=\sum\limits_{\text{LCA}\left(u_j,v_j\right)=i}w_jval2_i=\sum\limits_{i\in\left(u_j,v_j\right)}w_j

f_{i} 表示以 i 为链头的链中相交路径的权值最大值,ans_{i} 表示 \text{LCA}\left(u,v\right)=i 的路径 \left(u,v\right) 中相交路径的权值最大值。

max1_u 表示 u 的儿子中 f 的最大值,max2_u 表示 u 的儿子中的 f 的次大值。

易得转移如下:

f_{u}=max1_u+val1_u ans_{u} = max1_u+max2_u+val2_u

在每次增加路径后重新计算一遍,\max\limits_{i=1}^n ans_i 即为答案。这样我们就得到了一个 O\left(nm\right) 的算法。

该算法是修改一些信息后进行 \text{DP},考虑 \text{DDP}

套路地,对该树进行重链剖分,令 u 的重儿子为 w_u

f 的含义同上,ans 表示重链上所有上述含义下 ans 的最大值,max1_u 表示 u轻儿子f 的最大值,max2_u 表示 u轻儿子 中的 f 的次大值。

则易得转移:

f_{u}=\max\left(max1_u,f_{w_u}\right)+val1_u ans_{u} = \max\left(\max\left(max1_u+max2_u,max1_u+f_{w_u}\right)+val2_u,ans_{w_u}\right)

定义广义矩阵乘法 C=A\times B,其结果满足:C_{i,j}=\max\limits_{k}\left(A_{i,k}+B_{k,j}\right)

上述转移式写成矩阵乘法的形式有:

val1_u & -\infin & val1_u+max1_u\\ val2_u+max1_u & 0 & val2_u+max1_u+max2_u\\ -\infin & -\infin & 0 \end{bmatrix}\times \begin{bmatrix} f_{w_u}\\ ans_{w_u}\\ 0 \end{bmatrix}=\begin{bmatrix} f_u\\ ans_u\\ 0 \end{bmatrix}

每次加入一条新边时,会单点修改 val1,修改一个路径上的 val2 以及该路径上的重链的链头的父亲的 max1max2

比较难以处理的是区间加 val2

考虑矩阵乘法中区间加 val2 的影响:

a & -\infin & b\\ c & 0 & d\\ -\infin & -\infin & 0 \end{bmatrix}\times \begin{bmatrix} e & -\infin & f\\ g & 0 & h\\ -\infin & -\infin & 0 \end{bmatrix}=\begin{bmatrix} a+e & -\infin & \max\left(a+f,b\right)\\ \max\left(c+e,g\right) & 0 & \max\left(c+f,d,h\right)\\ -\infin & -\infin & 0 \end{bmatrix}

观察此式,不难发现区间加 val2 后相乘的结果就是加在了原答案矩阵的 val2 上,直接打个 \text{lazytag} 维护即可。

维护整个路径写起来很麻烦,差分成 4 段就好写了很多。\left(+\left(u\right),+\left(v\right),-\left(\text{LCA}\left(u,v\right)\right),-\left(fa_{\text{LCA}\left(u,v\right)}\right)\right)

时间复杂度为 O\left(m\log^2 n\right),但是如果写朴素的矩阵乘法,光是相乘的常数就有 27,完全过不了。按照上式,每个矩阵里面有用的元素只有 a,b,c,d 四个,因此矩阵里面只要维护 4 个元素即可,且相乘后的结果也是固定的,这样常数就可以缩小很多。

下面贴上丑陋的 \text{6.67KB} 的代码:

#include <bits/stdc++.h>
#define int long long
using namespace std;

int read()
{
    int f = 1;
    char c = getchar();
    while (!isdigit(c))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    int x = 0;
    while (isdigit(c))
    {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * f;
}

int buf[15];

void write(int x)
{
    int p = 0;
    if (x < 0)
    {
        putchar('-');
        x = -x;
    }
    if (x == 0) putchar('0');
    else
    {
        while (x)
        {
            buf[++p] = x % 10;
            x /= 10;
        }
        for (int i = p; i >= 1; i--)
            putchar('0' + buf[i]);
    }
}

int n, m, vistime, fa[100005][25], tag[400005], sz[100005], w[100005], dep[100005], top[100005], dfn[100005], rdfn[100005], End[100005], f[100005], ans[100005], max1[100005], max2[100005];

vector < int > G[100005];

struct node
{
    int u, v, w;
} a[100005];

multiset < int > s[100005], res;

struct Matrix
{
    int A, B, C, D;
    Matrix friend operator * (const Matrix &A, const Matrix &B)
    {
        Matrix Ans;
        Ans.A = A.A + B.A;
        Ans.B = max(A.A + B.B, A.B);
        Ans.C = max(A.C + B.A, B.C);
        Ans.D = max({A.C + B.B, A.D, B.D});
        return Ans;
    }
} tree[400005];

void dfs1(int u, int fath)
{
    sz[u] = 1;
    fa[u][0] = fath;
    dep[u] = dep[fath] + 1;
    for (int i = 0; i < G[u].size(); i++)
    {
        int v = G[u][i];
        if (v == fath) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[w[u]]) w[u] = v;
    }
}

void dfs2(int u, int Top)
{
    top[u] = Top;
    dfn[u] = ++vistime;
    rdfn[vistime] = u;
    if (w[u]) dfs2(w[u], Top);
    else End[Top] = u;
    for (int i = 0; i < G[u].size(); i++)
    {
        int v = G[u][i];
        if (v == w[u] || v == fa[u][0]) continue;
        s[u].insert(0);
        dfs2(v, v);
    }
}

int LCA(int u, int v)
{
    if (dep[u] < dep[v]) swap(u, v);
    int t = dep[u] - dep[v];
    for (int i = 0; i <= 20; i++)
        if (t & (1 << i)) u = fa[u][i];
    if (u == v) return u;
    for (int i = 20; i >= 0; i--)
        if (fa[u][i] != fa[v][i])
        {
            u = fa[u][i];
            v = fa[v][i];
        }
    return fa[u][0];
}

void build(int rt, int l, int r)
{
    if (l == r)
    {
        tree[rt].A = 0;
        tree[rt].B = 0;
        tree[rt].C = 0;
        tree[rt].D = 0;
    }
    else
    {
        int mid = (l + r) / 2;
        build(rt * 2, l, mid);
        build(rt * 2 + 1, mid + 1, r);
        tree[rt] = tree[rt * 2] * tree[rt * 2 + 1];
    }
}

void update1(int rt, int l, int r, int t, int x)
{
    if (l == r) 
    {
        tree[rt].A += x;
        tree[rt].B += x;
    }
    else
    {
        int mid = (l + r) / 2;
        if (tag[rt] != 0)
        {
            tree[rt * 2].C += tag[rt];
            tree[rt * 2].D += tag[rt];
            tag[rt * 2] += tag[rt];
            tree[rt * 2 + 1].C += tag[rt];
            tree[rt * 2 + 1].D += tag[rt];
            tag[rt * 2 + 1] += tag[rt];
            tag[rt] = 0;
        }
        if (t <= mid) update1(rt * 2, l, mid, t, x);
        else update1(rt * 2 + 1, mid + 1, r, t, x);
        tree[rt] = tree[rt * 2] * tree[rt * 2 + 1];
    }
}

void update2(int rt, int l, int r, int tl, int tr, int x)
{
    if (tl <= l && r <= tr)
    {
        tree[rt].C += x;
        tree[rt].D += x;
        tag[rt] += x;
    }
    else
    {
        int mid = (l + r) / 2;
        if (tag[rt] != 0)
        {
            tree[rt * 2].C += tag[rt];
            tree[rt * 2].D += tag[rt];
            tag[rt * 2] += tag[rt];
            tree[rt * 2 + 1].C += tag[rt];
            tree[rt * 2 + 1].D += tag[rt];
            tag[rt * 2 + 1] += tag[rt];
            tag[rt] = 0;
        }
        if (tl <= mid) update2(rt * 2, l, mid, tl, tr, x);
        if (tr > mid) update2(rt * 2 + 1, mid + 1, r, tl, tr, x);
        tree[rt] = tree[rt * 2] * tree[rt * 2 + 1];
    }
}

void update3(int rt, int l, int r, int t, int x1, int x2)
{
    if (l == r)
    {
        tree[rt].B += x1;
        tree[rt].C += x1;
        tree[rt].D += (x1 + x2);
    }
    else
    {
        int mid = (l + r) / 2;
        if (tag[rt] != 0)
        {
            tree[rt * 2].C += tag[rt];
            tree[rt * 2].D += tag[rt];
            tag[rt * 2] += tag[rt];
            tree[rt * 2 + 1].C += tag[rt];
            tree[rt * 2 + 1].D += tag[rt];
            tag[rt * 2 + 1] += tag[rt];
            tag[rt] = 0;
        }
        if (t <= mid) update3(rt * 2, l, mid, t, x1, x2);
        else update3(rt * 2 + 1, mid + 1, r, t, x1, x2);
        tree[rt] = tree[rt * 2] * tree[rt * 2 + 1];
    }
}

Matrix query(int rt, int l, int r, int tl, int tr)
{
    if (tl <= l && r <= tr) return tree[rt];
    int mid = (l + r) / 2;
    if (tag[rt] != 0)
    {
        tree[rt * 2].C += tag[rt];
        tree[rt * 2].D += tag[rt];
        tag[rt * 2] += tag[rt];
        tree[rt * 2 + 1].C += tag[rt];
        tree[rt * 2 + 1].D += tag[rt];
        tag[rt * 2 + 1] += tag[rt];
        tag[rt] = 0;
    }
    if (tr <= mid) return query(rt * 2, l, mid, tl, tr);
    if (tl > mid) return query(rt * 2 + 1, mid + 1, r, tl, tr);
    Matrix X = query(rt * 2, l, mid, tl, tr), Y = query(rt * 2 + 1, mid + 1, r, tl, tr);
    return X * Y;
}

void solve1(int u, int w)
{
    update2(1, 1, n, dfn[top[u]], dfn[u], w);
    while (1)
    {
        u = top[u];
        Matrix Ans = query(1, 1, n, dfn[u], dfn[End[u]]);
        int F = max(Ans.A, Ans.B), ANS = max(Ans.C, Ans.D);
        auto it = res.lower_bound(ans[u]);
        res.erase(it);
        res.insert(ANS);
        ans[u] = ANS;
        if (u == 1) break;
        it = s[fa[u][0]].lower_bound(f[u]);
        s[fa[u][0]].erase(it);
        s[fa[u][0]].insert(F);
        f[u] = F;
        u = fa[u][0];
        int MAX1 = *s[u].rbegin(), MAX2 = *++s[u].rbegin();
        update3(1, 1, n, dfn[u], MAX1 - max1[u], MAX2 - max2[u]);
        max1[u] = MAX1;
        max2[u] = MAX2;
        update2(1, 1, n, dfn[top[u]], dfn[u], w);
    }
}

void solve2(int u, int w)
{
    update1(1, 1, n, dfn[u], w);
    while (1)
    {
        u = top[u];
        Matrix Ans = query(1, 1, n, dfn[u], dfn[End[u]]);
        int F = max(Ans.A, Ans.B), ANS = max(Ans.C, Ans.D);
        auto it = res.lower_bound(ans[u]);
        res.erase(it);
        res.insert(ANS);
        ans[u] = ANS;
        if (u == 1) break;
        it = s[fa[u][0]].lower_bound(f[u]);
        s[fa[u][0]].erase(it);
        s[fa[u][0]].insert(F);
        f[u] = F;
        u = fa[u][0];
        int MAX1 = *s[u].rbegin(), MAX2 = *++s[u].rbegin();
        update3(1, 1, n, dfn[u], MAX1 - max1[u], MAX2 - max2[u]);
        max1[u] = MAX1;
        max2[u] = MAX2;
    }
}

signed main()
{
    n = read(), m = read();
    for (int i = 1; i < n; i++)
    {
        int u, v;
        u = read(), v = read();
        G[u].push_back(v);
        G[v].push_back(u);
    }
    for (int i = 1; i <= n; i++)
    {
        s[i].insert(0);
        s[i].insert(0);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    for (int i = 1; i <= n; i++)
        if (top[i] == i) res.insert(0);
    for (int i = 1; i <= 20; i++)
        for (int j = 1; j <= n; j++) 
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
    build(1, 1, n);
    for (int i = 1; i <= m; i++)
    {
        char c;
        c = getchar();
        if (c == '+') a[i].u = read(), a[i].v = read(), a[i].w = read();
        else
        {
            int t;
            t = read();
            a[i].u = a[t].u;
            a[i].v = a[t].v;
            a[i].w = -a[t].w;
        }
        int lca = LCA(a[i].u, a[i].v);
        solve1(a[i].u, a[i].w);
        solve1(a[i].v, a[i].w);
        solve1(lca, -a[i].w);
        if (lca != 1) solve1(fa[lca][0], -a[i].w);
        solve2(lca, a[i].w);
        write(*res.rbegin());
        putchar('\n');
    }
    return 0;
}