题解 P6177 【Count on a tree II/【模板】树分块】
Count on a Tree II
推荐到 Luogu 博客或者到 这里 阅读
以前写的用可持久化分块维护。。既难写而且常数还大。。最后在 Luogu 还是开了读优挂才跑过去的。。感觉题解区的 Bitset 维护比我这个不知道高到哪里去了。。
考虑随机在这个树上打
然后我们先预处理出关键点间两两的答案,做法是从每个关键点跑一次 DFS 。
然后考虑一次询问,对于一次询问我们需要查询这两个点到关键点的路径上有没有什么颜色是不存在于关键点的路径上的。也就是说我们需要维护根到每个节点的某个颜色的数量。直接主席树维护得带
还需要用 ST 表
代码写的很丑。。
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
using namespace std;
#define MAXN 40006
#define B 206
#define swap( a , b ) ( (a) ^= (b) , (b) ^= (a) , (a) ^= (b) )
int n , m , blo , sz , bl;
int c[MAXN] , C[MAXN];
int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn;
void ade( int u , int v ) {
to[++ ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
}
int ps[B] , im[MAXN] , pre[B][B] , cid , ee;
int w[MAXN] , an;
void dfs( int u , int fa ) {
if( !w[c[u]] ) ++ an;
++ w[c[u]];
if( im[u] ) pre[cid][im[u]] = an;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
dfs( v , u );
}
-- w[c[u]];
if( !w[c[u]] ) -- an;
}
int kk[MAXN << 2][B] , idx = 0 , rt[MAXN];
void build( ) {
idx = rt[0] = 1;
for( int i = 1 ; ( i - 1 ) * bl < sz ; ++ i ) kk[1][i] = i + 1 , ++ idx;
}
void add( int rt , int old , int x ) {
memcpy( kk[rt] , kk[old] , sizeof kk[old] );
++ idx;
memcpy( kk[idx] , kk[kk[rt][( x - 1 ) / bl + 1]] , sizeof kk[1] );
kk[rt][( x - 1 ) / bl + 1] = idx;
++ kk[idx][( x - 1 ) % bl + 1];
}
int que( int rt , int x ) {
int t = kk[rt][( x - 1 ) / bl + 1];
return kk[t][( x - 1 ) % bl + 1];
}
int dfn[MAXN << 1] , dep[MAXN << 1] , pos[MAXN << 1] , en = 0 , d[MAXN] , par[MAXN];
int dfs1( int u , int fa ) {
dfn[++en] = u , pos[u] = en , dep[en] = d[u] , par[u] = fa;
rt[u] = ++ idx;
add( rt[u] , rt[fa] , c[u] );
int re = 0;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
d[v] = d[u] + 1;
re = max( re , dfs1( v , u ) );
dfn[++ en] = u , dep[en] = d[u];
}
if( re - d[u] >= blo && d[u] % blo == 0 ) ps[++ ee] = u , im[u] = ee;
return re ? re : d[u];
}
int lg[MAXN << 1];
int st[MAXN << 1][17];
void ST() {
for( int i = 1 ; i <= en ; ++ i ) lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for( int i = 1 ; i <= en ; ++ i ) st[i][0] = i;
for( int i = 1 ; ( 1 << i ) <= en ; ++ i )
for (int j = 1; j + (1 << i) - 1 <= en; j++)
st[j][i] = dep[st[j][i - 1]] < dep[st[j + (1 << (i - 1))][i - 1]] ? st[j][i - 1] : st[j + (1 << (i - 1))][i - 1];
}
int lca( int l , int r ) {
l = pos[l] , r = pos[r];
if( l > r ) swap( l , r );
int k = lg[r - l + 1] - 1;
return dep[st[l][k]] <= dep[st[r-(1<<k)+1][k]] ? dfn[st[l][k]] : dfn[st[r-(1<<k)+1][k]];
}
int col[MAXN] , cn , l;
int jump( int u ) {
if( im[u] ) return im[u];
col[++ cn] = c[u];
if( u == l ) return -1;
return jump( par[u] );
}
int rejjump( int u ) {
if( im[u] ) return im[u];
col[++ cn] = c[u];
return rejjump( par[u] );
}
int oc[MAXN];
int main() {
// freopen("10.in","r",stdin);
// freopen("ot","w",stdout);
cin >> n >> m;
blo = sqrt( n );
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&c[i]) , C[i] = c[i];
sort( C + 1 , C + 1 + n ); sz = unique( C + 1 , C + 1 + n ) - C - 1;
bl = sqrt( sz );
for( int i = 1 ; i <= n ; ++ i ) c[i] = lower_bound( C + 1 , C + 1 + sz , c[i] ) - C;
for( int i = 1 , u , v ; i < n ; ++ i ) {
scanf("%d%d",&u,&v);
ade( u , v ) , ade( v , u );
}
build( );
dfs1( 1 , 0 );
// cout << que( rt[3] , 3 ) << endl;
ST( );
for( int i = 1 ; i <= blo ; ++ i ) {
int p = ps[i]; cid = i;
dfs( p , p );
}
int u , v , psu , psv , flg = 0 , re , last = 0 , t , pr;
while( m-- ) {
scanf("%d%d",&u,&v);
u ^= last;
// cout << u << ' ' << v << endl;
l = lca( u , v );
cn = re = 0;
psu = jump( u ) , psv = jump( v );
if( psu == psv ) {
oc[c[l]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) oc[col[i]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
oc[c[l]] = 0;
} else if( ~psu && ~psv ) {
re = pre[psu][psv];
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psu]] , col[i] ) + que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) * 2 == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
} else {
if( ~psu ) swap( u , v ) , swap( psu , psv );
pr = l; t = cn;
while( !im[pr] ) {
pr = par[pr];
if( !oc[c[pr]] ) {
oc[c[pr]] = 1;
if (que(rt[ps[psv]], c[pr]) - que(rt[par[l]], c[pr]) == 0)
--re;
col[++ cn] = c[pr];
}
}
re += pre[im[pr]][psv];
for( int i = t + 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
cn = t;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
}
printf("%d\n",last = re);
}
}