【P5327 [ZJOI2019]语言】题解

· · 题解

P5327 [ZJOI2019]语言

线段树合并 + 点分治,总时间复杂度为 O(n\log^2n)

貌似题解区里没有点分治的做法,那就交一发点分治。

首先将链分别挂在两个端点的 vector 上。

选取了分治中心后,对中心对应的树求一遍主席树,让每个节点 u 上都挂有一颗装着前缀链上所有点的线段树,设线段树根为 rt1(u)

中心下挂着若干个子树,对不同子树的节点染不同色(col(u)),中心也染上不同的色。

然后对每一颗子树分别处理和统计答案,方法是对每个节点维护它通过分治中心所能到达的节点的集合(这仍然是一颗线段树),记为 rt2(u)

在某一颗子树内时,若某节点 u 上有跳出这颗子树的链 (u,v),也即 col(u)\neq col(v),说明 rt1(v) 这一条前缀链上的所有点(包括分治中心)都是 u 通过分治中心能到达的,所以让 rt2(u):=\text{merge}(rt2(u),rt1(v))

节点 u 上跳出这个子树的链 (u,v) 会被分治中心截断,不妨设这颗子树的根是 top,则需要将这条链修改为 (u,top),记住在 topvector 上也需要加入一条链 (top,u)

其次 u 的子节点通过分治中心所能到达的节点也一定是 u 能到的,故递归处理完子节点 w 后,让 rt2(u)=\text{merge}(rt2(u), rt2(w))

然后将 size(rt2(u)) 加入最终答案。

处理完分治中心的每颗子树后,还要记得求出分治中心所能到达的节点(注意到自身的贡献要减去,若有的话)并把贡献加入到最终答案。

点分治结束后输出答案时要将答案除 2

每一条链至多被遍历 O(\log n) 次,每条链被截断时至多在 top 上产生 1 条新链,每条链的截断次数为 O(\log n),所以对链遍历的总时间复杂度为 O(m\log^2 n),当然这个界远远达不到。

对于分治的每一层,若这一层的大小为 s,则线段树合并的时间复杂度为 O(s\log s),于是线段树合并的总时间复杂度为 O(n\log^2 n)

#include <stdio.h>
#include <algorithm>
#include <string.h>
#include <iostream>
#include <assert.h>
#include <vector>
using namespace std;

#define re register
#define sf scanf
#define pf printf
#define nl() putchar('\n')
#define ms(x, val) memset(x, val, sizeof(x))
#define ll long long
#define db double
#define ull unsigned long long
#define _for(i, a, b) for(re int i = (a); i < (b); ++i)
#define _rfor(i, a, b) for(re int i = (a); i <= (b); ++i)
#define _fev(p, u) for(re int p = head[u]; p; p = nex[p])
#define inf 0x7fffffff
#define maxn 100005
#define maxnn (maxn*120) 
#define mod 1000000007ll

template <class T>
void print(string name, T arr[], int n, int flag = 1){
    cout<<name<<":";
    _for(i, 0, n)cout<<" "<<arr[i+flag];
    cout<<endl;
}

int rdnt(){
    re int x = 0, sign = 1;
    re char c = getchar();
    while (c < '0' || c > '9') { if (c == '-') sign = -1; c = getchar(); }
    while (c >= '0' && c <= '9') x = (x<<3) + (x<<1) + (c ^ 48), c = getchar();
    return x * sign;
}

inline void uad(int &x, const int &y){ if ((x+=y)>=mod) x-=mod; }
template<class T>
inline void umx(T &x, const T &y){ if (y > x) x = y; }
template<class T>
inline void umi(T &x, const T &y){ if (y < x) x = y; }

ll ans = 0;

int seg_ncnt, rg, epw[maxn], one[maxn];
struct Seg{ int lc, rc, sz, hsh; } seg[maxnn];
#define lc(x) seg[x].lc
#define rc(x) seg[x].rc
#define sz(x) seg[x].sz
#define hsh(x) seg[x].hsh

int max_seg = 0;
int new_node(re int nn){
    assert(seg_ncnt+1 < maxnn);
    umx(max_seg, seg_ncnt);
    re int x = ++seg_ncnt;
    seg[x] = seg[nn];
    return x;
}

inline void ud(re int x, re int len){
    sz(x) = sz(lc(x)) + sz(rc(x));
    hsh(x) = (hsh(lc(x))+(ll)hsh(rc(x))*epw[len]%mod)%mod;
}

void adn(int x, int &y, int p, int tl, int tr){
    y = new_node(x);
    if (tl == tr){ sz(y) = 1; hsh(y) = 1; return; }
    int mi = (tl+tr)>>1;
    if (p <= mi) adn(lc(x), lc(y), p, tl, mi);
    else adn(rc(x), rc(y), p, mi+1, tr);
    ud(y, mi-tl+1); 
}

void merge(int &z, int x, int y, int tl, int tr){
    if (!x || !y){ z = x+y; return; }
    if ((hsh(x) == hsh(y) && sz(x) == sz(y)) || sz(x) == tr-tl+1){ z = x; return; }
    if (sz(y) == tr-tl+1){ z = y; return; }
    assert(tl < tr);
    int mi = (tl+tr)>>1;
    z = new_node(0);
    merge(lc(z), lc(x), lc(y), tl, mi);
    merge(rc(z), rc(x), rc(y), mi+1, tr);
    ud(z, mi-tl+1); 
}

int ecnt = 1, icnt = 0,
    head[maxn],
    to[maxn*2],
    nex[maxn*2],
    col[maxn],
    rt1[maxn],
    rt2[maxn],
    id[maxn],
    son[maxn],
    sz[maxn];
bool vis[maxn];
vector<int> adj[maxn];

void add_edge(re int u, re int v){
    to[++ecnt] = v; nex[ecnt] = head[u]; head[u] = ecnt;
    to[++ecnt] = u; nex[ecnt] = head[v]; head[v] = ecnt;
}

void get_son(int u, int fa){
    sz[u] = 1; son[u] = 0; rt1[u] = rt2[u] = 0; id[u] = ++icnt;
    _fev(p, u){
        int v = to[p];
        if (vis[v] || v == fa) continue;
        get_son(v, u);
        sz[u] += sz[v];
        if (!son[u] || sz[v] > sz[son[u]]) son[u] = v;
    }
}

int get_centroid(re int u){
    if (!son[u]) return u;
    re int S = sz[u], v;
    while(sz[v = son[u]]*2 > S) u = v;
    return u;
}

void get_cmt(int u, int c, int fa){
    col[u] = c;
    adn(rt1[fa], rt1[u], id[u], 1, rg);
    _fev(p, u){
        int v = to[p];
        if (vis[v] || v == fa) continue;
        get_cmt(v, c, u);
    }
}

void get_ans(int u, int top, int fa){
    static bool yes[maxn], flag; static int stk[maxn], tp;
    tp = 0; flag = false;
    for(auto &v : adj[u]){
        if (col[v] != col[u]){
            merge(rt2[u], rt2[u], rt1[v], 1, rg);
            v = top;
            flag = true;
        }
        if (!yes[v] && v != u) stk[tp++] = v, yes[v] = true;
    }
    adj[u].clear();
    while(tp) adj[u].push_back(stk[--tp]), yes[stk[tp]] = false;
    if (flag && top != u) adj[top].push_back(u);

    _fev(p, u){
        int v = to[p];
        if (vis[v] || v == fa) continue;
        get_ans(v, top, u);
        merge(rt2[u], rt2[u], rt2[v], 1, rg);
    }
    ans += sz(rt2[u]);
}

void divide(int u){
    if (!son[u]) return; icnt = 0;
    u = get_centroid(u); get_son(u, 0);
    assert(son[u]); assert(!vis[u]);
    vis[u] = true;

    //pf("u:%d\n", u);

    rg = icnt; rt1[0] = rt2[0] = seg_ncnt = 0;
    adn(0, rt1[u], id[u], 1, rg);
    int ccnt = col[u] = 1;
    _fev(p, u){
        int v = to[p];
        if (vis[v]) continue;
        ++ccnt;
        get_cmt(v, ccnt, u);
    }

    _fev(p, u){
        int v = to[p];
        if (vis[v]) continue;
        get_ans(v, v, u);
        merge(rt2[u], rt2[u], rt2[v], 1, rg);
    }
    for(auto &v : adj[u]){
        if (col[v] != col[u]){
            merge(rt2[u], rt2[u], rt1[v], 1, rg);
        }
    }
    if (sz(rt2[u])) ans += sz(rt2[u])-1;
    //pf("ans:%lld\n", ans);

    //nl();
    _fev(p, u) if (!vis[to[p]]) divide(to[p]);
}

void init(re int n){
    rt1[0] = rt2[0] = 0;
    seg[0] = {0, 0, 0, 0};
    epw[0] = one[0] = 1;
    _rfor(i, 1, n) epw[i] = (epw[i-1]<<1)%mod, one[i] = (one[i-1]<<1|1)%mod;
}

int main(){
    #ifndef ONLINE_JUDGE
    freopen("sample.in", "r", stdin);
    //freopen("sample.out", "w", stdout);
    #endif

    re int n = rdnt(), m = rdnt();
    init(n);
    _rfor(i, 1, n-1) add_edge(rdnt(), rdnt());
    _rfor(i, 1, m){
        re int s = rdnt(), t = rdnt();
        adj[s].push_back(t);
        adj[t].push_back(s);
    }

    get_son(1, 0);
    divide(1);

    assert((ans&1) == 0);
    pf("%lld\n", ans/2);
    //pf("%d\n", max_seg);

    return 0;
}