题解 P3714 【[BJOI2017]树的难题】

· · 题解

点分治

以下的同色联通块以及异色联通块中的色指的是与分治中心直接相连的边的颜色

考虑对于一个分治中心,每种颜色单独处理,也就是对分治中心下方的每颗子树,以其同颜色联通块最大深度为第一关键字,颜色为第二关键字 排序,然后再在同色的联通块中,以深度排序得到一个遍历序列

这样我们就可以使用数据结构来维护同色联通块以及异色联通块的产生的贡献了

使用单调队列

异色联通块中默认存在一条权值和长度都为0的路径,而同色联通块不行

对同色联通块和异色联通块各开一个单调队列,单调队列 中存下当一条路径 l 长度为 x 时权值最大的可以与 l 组成一条路径的路径权值最大值,每次在在DFS的时候更新答案

当遍历完一颗子树时,如果下一次遍历的子树颜色与这颗子树相同,那么更新同色联通块的权值,否则更新异色联通块的权值并 清空同色联通块的权值

注意同色联通块的单调队列里面的权值要减去分治中心直接相连的边的颜色的权值

时间复杂度:O(N\log N)

#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <functional>

#define DEBUG(args...) fprintf(stderr, args)

typedef long long LL;

#define FOR(i, l, r) for(int i = (l), i##_end = (r); i <= i##_end; ++i)
#define REP(i, l, r) for(int i = (l), i##_end = (r); i <  i##_end; ++i)
#define DFR(i, l, r) for(int i = (l), i##_end = (r); i >= i##_end; --i)
#define DRP(i, l, r) for(int i = (l), i##_end = (r); i >  i##_end; --i)

template<class T>T Min(const T &a, const T &b) {return a < b ? a : b;}
template<class T>T Max(const T &a, const T &b) {return a > b ? a : b;}
template<class T>bool Chkmin(T &a, const T &b) {return a > b ? a = b, 1 : 0;}
template<class T>bool Chkmax(T &a, const T &b) {return a < b ? a = b, 1 : 0;}

class fast_input {
private:
    static const int SIZE = 1 << 15 | 1;
    char buf[SIZE], *front, *back;

    void Next(char &c) {
        if(front == back) back = (front = buf) + fread(buf, 1, SIZE, stdin);
        c = front == back ? (char)EOF : *front++;
    }

public :
    template<class T>void operator () (T &x) {
        char c, f = 1;
        for(Next(c); !isdigit(c); Next(c)) if(c == '-') f = -1;
        for(x = 0; isdigit(c); Next(c)) x = x * 10 + c - '0';
        x *= f;
    }
    void operator () (char &c, char l = 'a', char r = 'z') {
        for(Next(c); c > r || c < l; Next(c)) ;
    }
}input;

struct triple {
    int c, d, x, md;
    explicit triple(int c = 0, int d = 0, int x = 0) : c(c), d(d), x(x) {}
    bool operator < (const triple &o) const { 
        return c != o.c ? c < o.c : d < o.d;
    }
    bool operator > (const triple &o) const {
        return md != o.md ? md < o.md : c != o.c ? c < o.c : d < o.d;
    }
};

const int SN = 200000 + 47;
const int SE = 400000 + 47;
const int INF = 2000100000;

int head[SN], nxt[SE], to[SE], col[SE];
int val[SN], n, left, right;
int size[SN], max_deep[SN];
int vis[SN];
triple a[SN];
int ans;

void Add(int, int, int);
void GetSize(int, int);
int GetDeep(int, int, int);

int __size, __root, __max_size;
void GetRoot(int, int);

void Solve(int);
void DFS(int, int, int, int, int);

int main() {

#ifdef Cai
    freopen("s.in", "r", stdin);
#endif

    int x, y, z, m;

    input(n), input(m), input(left), input(right);
    FOR(i, 1, m) input(val[i]);
    FOR(i, 2, n) input(x), input(y), input(z), Add(x, y, z);

    GetSize(1, -1);
    __size = n, __max_size = INF, GetRoot(1, -1);
    ans = -INF, Solve(__root);

    if(ans < -2000000000) return printf("%d\n", -INF), 0;

    printf("%d\n", ans);

    return 0;

}

void Add(int x, int y, int z) {
    static int _ = 0;
    nxt[++_] = head[x], head[x] = _, to[_] = y, col[_] = z;
    nxt[++_] = head[y], head[y] = _, to[_] = x, col[_] = z;
}

void GetSize(int x, int y) {
    size[x] = 1;
    for(int i = head[x]; i; i = nxt[i])
        if(to[i] != y && !vis[to[i]]) {
            GetSize(to[i], x);
            size[x] += size[to[i]];
        }
}

int GetDeep(int x, int y, int d) {
    int max_d = d;
    for(int i = head[x]; i; i = nxt[i])
        if(to[i] != y && !vis[to[i]]) 
            Chkmax(max_d, GetDeep(to[i], x, d + 1));
    return max_d;
}

void GetRoot(int x, int y) {
    int max_size = __size - size[x];
    for(int i = head[x]; i; i = nxt[i])
        if(to[i] != y && !vis[to[i]]) {
            GetRoot(to[i], x);
            Chkmax(max_size, size[to[i]]);
        }
    if(Chkmin(__max_size, max_size)) __root = x;
}

// q0 for all
// q1 for same col
int q0[SN], f0, b0, t0[SN], v0[SN];
int q1[SN], f1, b1, t1[SN], v1[SN];
int c[SN], cc, b[SN], cb, d[SN], cd;

void Solve(int x) {
    int cnt = 0;
    GetSize(x, -1), vis[x] = 1;
    for(int i = head[x]; i; i = nxt[i])
        if(!vis[to[i]]) 
            a[++cnt] = triple(col[i], GetDeep(to[i], x, 1), to[i]);
    std::sort(a + 1, a + cnt + 1);
    DFR(i, cnt, 1)
        if(a[i].c == a[i + 1].c) a[i].md = a[i + 1].md;
        else a[i].md = a[i].d;
    std::sort(a + 1, a + cnt + 1, std::greater<triple>());
    int root_all = 0, root_col = 0;
    FOR(i, 1, a[1].md) q0[i] = q1[i] = c[i] = d[i] = v0[i] = v1[i] = -INF;
    FOR(i, left, Min(a[1].md, right)) v1[i] = 0;
    FOR(i, 1, cnt) {
        cb = 0, DFS(a[i].x, x, 1, val[a[i].c], a[i].c); // ans : updated
        //if(ans == -4) {DEBUG("%d %d\n", x, i); throw ;}
        if(a[i].c == a[i + 1].c) {
            cc = cb;
            FOR(j, 1, cb) Chkmax(c[j], b[j]);
            f0 = 0, b0 = -1, c[0] = val[a[i].c];
            DFR(j, Min(cc, right - 1), left - 1) {
                while(f0 <= b0 && q0[b0] <= c[j]) --b0;
                q0[++b0] = c[j], t0[b0] = j;
            }
            FOR(j, 1, a[i + 1].d) {
                if(f0 > b0) v0[j] = -INF;
                else v0[j] = q0[f0] - val[a[i].c]; // val[a[i].c] will be calc twice
                while(f0 <= b0 && t0[f0] + j + 1 > right) ++f0;
                if(left - j - 1 >= 0 && left - j - 1 <= cc) {
                    while(f0 <= b0 && q0[b0] <= c[left - j - 1]) --b0;
                    q0[++b0] = c[left - j - 1], t0[b0] = left - j - 1;
                }
            }
        }
        else {
            cd = cb;
            FOR(j, 1, cd) Chkmax(d[j], Max(b[j], c[j]));
            f1 = 0, b1 = -1, d[0] = 0;
            DFR(j, Min(cd, right - 1), left - 1) {
                while(f1 <= b1 && q1[b1] <= d[j]) --b1;
                q1[++b1] = d[j], t1[b1] = j;
            }
            FOR(j, 1, a[i + 1].md) {
                if(f1 > b1) v1[j] = -INF;
                else v1[j] = q1[f1];
                while(f1 <= b1 && t1[f1] + j + 1 > right) ++f1;
                if(left - j - 1 >= 0 && left - j - 1 <= cd) {
                    while(f1 <= b1 && q1[b1] <= d[left - j - 1]) --b1;
                    q1[++b1] = d[left - j - 1], t1[b1] = left - j - 1;
                }
            }
            cc = 0;
            FOR(j, 1, a[i + 1].md) c[j] = v0[j] = -INF;
            FOR(j, a[i].md + 1, a[i + 1].md) d[j] = -INF;
        }
    }
    for(int i = head[x]; i; i = nxt[i])
        if(!vis[to[i]]) {
            __size = size[to[i]], __max_size = INF, GetRoot(to[i], x);
            Solve(__root);
        }
}

void DFS(int x, int y, int deep, int v, int lc) {
    if(deep <= cb) Chkmax(b[deep], v);
    else b[cb = deep] = v;
    Chkmax(ans, v + Max(v0[deep], v1[deep]));
    for(int i = head[x]; i; i = nxt[i])
        if(to[i] != y && !vis[to[i]])
            if(col[i] == lc)
                DFS(to[i], x, deep + 1, v, lc);
            else
                DFS(to[i], x, deep + 1, v + val[col[i]], col[i]);
}

/*
g++ -o s s.cpp -O2; for((i = 1; i <= 10; ++i)) do cp journey$i.in s.in; ./s > s.out; diff journey$i.out s.out -w > s.res; echo $i : $?; done
*/