【P5327 [ZJOI2019]语言】题解
P5327 [ZJOI2019]语言
线段树合并 + 点分治,总时间复杂度为
貌似题解区里没有点分治的做法,那就交一发点分治。
首先将链分别挂在两个端点的 vector 上。
选取了分治中心后,对中心对应的树求一遍主席树,让每个节点
中心下挂着若干个子树,对不同子树的节点染不同色(
然后对每一颗子树分别处理和统计答案,方法是对每个节点维护它通过分治中心所能到达的节点的集合(这仍然是一颗线段树),记为
在某一颗子树内时,若某节点
节点 vector 上也需要加入一条链
其次
然后将
处理完分治中心的每颗子树后,还要记得求出分治中心所能到达的节点(注意到自身的贡献要减去,若有的话)并把贡献加入到最终答案。
点分治结束后输出答案时要将答案除
每一条链至多被遍历
对于分治的每一层,若这一层的大小为
#include <stdio.h>
#include <algorithm>
#include <string.h>
#include <iostream>
#include <assert.h>
#include <vector>
using namespace std;
#define re register
#define sf scanf
#define pf printf
#define nl() putchar('\n')
#define ms(x, val) memset(x, val, sizeof(x))
#define ll long long
#define db double
#define ull unsigned long long
#define _for(i, a, b) for(re int i = (a); i < (b); ++i)
#define _rfor(i, a, b) for(re int i = (a); i <= (b); ++i)
#define _fev(p, u) for(re int p = head[u]; p; p = nex[p])
#define inf 0x7fffffff
#define maxn 100005
#define maxnn (maxn*120)
#define mod 1000000007ll
template <class T>
void print(string name, T arr[], int n, int flag = 1){
cout<<name<<":";
_for(i, 0, n)cout<<" "<<arr[i+flag];
cout<<endl;
}
int rdnt(){
re int x = 0, sign = 1;
re char c = getchar();
while (c < '0' || c > '9') { if (c == '-') sign = -1; c = getchar(); }
while (c >= '0' && c <= '9') x = (x<<3) + (x<<1) + (c ^ 48), c = getchar();
return x * sign;
}
inline void uad(int &x, const int &y){ if ((x+=y)>=mod) x-=mod; }
template<class T>
inline void umx(T &x, const T &y){ if (y > x) x = y; }
template<class T>
inline void umi(T &x, const T &y){ if (y < x) x = y; }
ll ans = 0;
int seg_ncnt, rg, epw[maxn], one[maxn];
struct Seg{ int lc, rc, sz, hsh; } seg[maxnn];
#define lc(x) seg[x].lc
#define rc(x) seg[x].rc
#define sz(x) seg[x].sz
#define hsh(x) seg[x].hsh
int max_seg = 0;
int new_node(re int nn){
assert(seg_ncnt+1 < maxnn);
umx(max_seg, seg_ncnt);
re int x = ++seg_ncnt;
seg[x] = seg[nn];
return x;
}
inline void ud(re int x, re int len){
sz(x) = sz(lc(x)) + sz(rc(x));
hsh(x) = (hsh(lc(x))+(ll)hsh(rc(x))*epw[len]%mod)%mod;
}
void adn(int x, int &y, int p, int tl, int tr){
y = new_node(x);
if (tl == tr){ sz(y) = 1; hsh(y) = 1; return; }
int mi = (tl+tr)>>1;
if (p <= mi) adn(lc(x), lc(y), p, tl, mi);
else adn(rc(x), rc(y), p, mi+1, tr);
ud(y, mi-tl+1);
}
void merge(int &z, int x, int y, int tl, int tr){
if (!x || !y){ z = x+y; return; }
if ((hsh(x) == hsh(y) && sz(x) == sz(y)) || sz(x) == tr-tl+1){ z = x; return; }
if (sz(y) == tr-tl+1){ z = y; return; }
assert(tl < tr);
int mi = (tl+tr)>>1;
z = new_node(0);
merge(lc(z), lc(x), lc(y), tl, mi);
merge(rc(z), rc(x), rc(y), mi+1, tr);
ud(z, mi-tl+1);
}
int ecnt = 1, icnt = 0,
head[maxn],
to[maxn*2],
nex[maxn*2],
col[maxn],
rt1[maxn],
rt2[maxn],
id[maxn],
son[maxn],
sz[maxn];
bool vis[maxn];
vector<int> adj[maxn];
void add_edge(re int u, re int v){
to[++ecnt] = v; nex[ecnt] = head[u]; head[u] = ecnt;
to[++ecnt] = u; nex[ecnt] = head[v]; head[v] = ecnt;
}
void get_son(int u, int fa){
sz[u] = 1; son[u] = 0; rt1[u] = rt2[u] = 0; id[u] = ++icnt;
_fev(p, u){
int v = to[p];
if (vis[v] || v == fa) continue;
get_son(v, u);
sz[u] += sz[v];
if (!son[u] || sz[v] > sz[son[u]]) son[u] = v;
}
}
int get_centroid(re int u){
if (!son[u]) return u;
re int S = sz[u], v;
while(sz[v = son[u]]*2 > S) u = v;
return u;
}
void get_cmt(int u, int c, int fa){
col[u] = c;
adn(rt1[fa], rt1[u], id[u], 1, rg);
_fev(p, u){
int v = to[p];
if (vis[v] || v == fa) continue;
get_cmt(v, c, u);
}
}
void get_ans(int u, int top, int fa){
static bool yes[maxn], flag; static int stk[maxn], tp;
tp = 0; flag = false;
for(auto &v : adj[u]){
if (col[v] != col[u]){
merge(rt2[u], rt2[u], rt1[v], 1, rg);
v = top;
flag = true;
}
if (!yes[v] && v != u) stk[tp++] = v, yes[v] = true;
}
adj[u].clear();
while(tp) adj[u].push_back(stk[--tp]), yes[stk[tp]] = false;
if (flag && top != u) adj[top].push_back(u);
_fev(p, u){
int v = to[p];
if (vis[v] || v == fa) continue;
get_ans(v, top, u);
merge(rt2[u], rt2[u], rt2[v], 1, rg);
}
ans += sz(rt2[u]);
}
void divide(int u){
if (!son[u]) return; icnt = 0;
u = get_centroid(u); get_son(u, 0);
assert(son[u]); assert(!vis[u]);
vis[u] = true;
//pf("u:%d\n", u);
rg = icnt; rt1[0] = rt2[0] = seg_ncnt = 0;
adn(0, rt1[u], id[u], 1, rg);
int ccnt = col[u] = 1;
_fev(p, u){
int v = to[p];
if (vis[v]) continue;
++ccnt;
get_cmt(v, ccnt, u);
}
_fev(p, u){
int v = to[p];
if (vis[v]) continue;
get_ans(v, v, u);
merge(rt2[u], rt2[u], rt2[v], 1, rg);
}
for(auto &v : adj[u]){
if (col[v] != col[u]){
merge(rt2[u], rt2[u], rt1[v], 1, rg);
}
}
if (sz(rt2[u])) ans += sz(rt2[u])-1;
//pf("ans:%lld\n", ans);
//nl();
_fev(p, u) if (!vis[to[p]]) divide(to[p]);
}
void init(re int n){
rt1[0] = rt2[0] = 0;
seg[0] = {0, 0, 0, 0};
epw[0] = one[0] = 1;
_rfor(i, 1, n) epw[i] = (epw[i-1]<<1)%mod, one[i] = (one[i-1]<<1|1)%mod;
}
int main(){
#ifndef ONLINE_JUDGE
freopen("sample.in", "r", stdin);
//freopen("sample.out", "w", stdout);
#endif
re int n = rdnt(), m = rdnt();
init(n);
_rfor(i, 1, n-1) add_edge(rdnt(), rdnt());
_rfor(i, 1, m){
re int s = rdnt(), t = rdnt();
adj[s].push_back(t);
adj[t].push_back(s);
}
get_son(1, 0);
divide(1);
assert((ans&1) == 0);
pf("%lld\n", ans/2);
//pf("%d\n", max_seg);
return 0;
}