CF1878G 题解

· · 题解

题目链接

解题思路

我们考虑按二进制位进行统计。

首先我们使用倍增解决 LCA 的问题,在做 dfs 的时候顺便再统计一个 C_{i,j} 表示在 i 这个节点时往上最早的节点 k 使得其满足 a_k 二进制的第 j 位上为1。

然后对于每对询问,我们可以考虑将 x \rightarrow lcalca \rightarrow y 拉成一条链,并在上面建立一棵线段树。然后对于每一个二进制位进行遍历,找到从 x 开始的离其最近的在此时遍历的二进制位上为1的数字,将这个位置一直到 y 区间加一。y 同理。答案即为在遍历完后整棵线段树的区间最大值。

复杂度 O(n \log V^2)

代码

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=200010;
int T;
int n,a[N],q;
int head[N],cnt;
int f[N][35],c[N][35],dep[N];
ll ans;
struct Tree{
    int v,tag[2];
    //tag 0 : plus ; tag 1 cover
}tr[N<<2];
struct Edge{
    int to,nxt;
}e[N<<1];
void add(int u,int v){
    e[++cnt].to=v;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
void pushup(int node){
    tr[node].v=max(tr[node<<1].v,tr[node<<1|1].v);
}
void build(int l,int r,int node){
    if(l==r){
        tr[node].v=tr[node].tag[0]=0;
        tr[node].tag[1]=-1;
        return;
    }
    int mid=l+r>>1;
    build(l,mid,node<<1);
    build(mid+1,r,node<<1|1);
    pushup(node);
}
void pushdown(int node){
    if(~tr[node].tag[1]){
        tr[node<<1].tag[1]=tr[node].tag[1];
        tr[node<<1|1].tag[1]=tr[node].tag[1];
        tr[node<<1].v=tr[node].tag[1];
        tr[node<<1|1].v=tr[node].tag[1];
        tr[node<<1].tag[0]=tr[node<<1|1].tag[0]=0;
        tr[node].tag[1]=-1;
    }
    tr[node<<1].v+=tr[node].tag[0];tr[node<<1|1].v+=tr[node].tag[0];
    tr[node<<1].tag[0]+=tr[node].tag[0];tr[node<<1|1].tag[0]+=tr[node].tag[0];
    tr[node].tag[0]=0;
}
void Plus(int l,int r,int x,int y,int v,int node){
    if(x<=l && r<=y){
        tr[node].v+=v;
        tr[node].tag[0]+=v;
        return;
    }
    int mid=l+r>>1;
    pushdown(node);
    if(x<=mid) Plus(l,mid,x,y,v,node<<1);
    if(y>mid) Plus(mid+1,r,x,y,v,node<<1|1);
    pushup(node);
}
void Cover(int l,int r,int x,int y,int v,int node){
    if(x<=l && r<=y){
        tr[node].v=v;
        tr[node].tag[0]=0;
        tr[node].tag[1]=v;
        return;
    }
    int mid=l+r>>1;
    pushdown(node);
    if(x<=mid) Cover(l,mid,x,y,v,node<<1);
    if(y>mid) Cover(mid+1,r,x,y,v,node<<1|1);
    pushup(node);
}
void dfs(int x,int fa){
    f[x][0]=fa;
    dep[x]=dep[fa]+1;
    for(int j=0;j<=31;j++){
        if((a[x]>>j)&1) c[x][j]=x;
        else c[x][j]=c[fa][j];
    }
    for(int j=1;j<=31;j++)
        f[x][j]=f[f[x][j-1]][j-1];
    for(int i=head[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa) continue;
        dfs(y,x);
    }
}
int lca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(int j=31;j>=0;j--)
        if(dep[f[x][j]]>=dep[y])
            x=f[x][j];
    if(x==y) return x;
    for(int j=31;j>=0;j--)
        if(f[x][j]!=f[y][j])
            x=f[x][j],y=f[y][j];
    return f[x][0];
}
void work(int x,int y){
    int fa=lca(x,y);
    int len=dep[x]-dep[fa]+dep[y]-dep[fa]+1;
    Cover(1,len,1,len,0,1);
    int lx,ly,rx,ry;
    for(int i=0;i<=31;i++){
        lx=c[x][i];ly=c[y][i];
        if(dep[lx]<dep[fa] && dep[ly]>=dep[fa]){
            ry=y;
            for(int p=31;p>=0;p--)
                if(dep[c[f[ry][p]][i]]>=dep[fa]){
                    ry=f[ry][p];
                }
            Plus(1,len,len-(dep[y]-dep[ry]),len,1,1);
        }//如果此时x到lca的路径上没有,就去lca到y的路径上找
        else if(dep[lx]>=dep[fa]){  
            Plus(1,len,1+dep[x]-dep[lx],len,1,1);
        }
        if(dep[ly]<dep[fa] && dep[lx]>=dep[fa]){
            rx=x;
            for(int p=31;p>=0;p--)
                if(dep[c[f[rx][p]][i]]>=dep[fa])
                    rx=f[rx][p];
            Plus(1,len,1,1+dep[x]-dep[rx],1,1);
        }//同理
        else if(dep[ly]>=dep[fa]){          
            Plus(1,len,1,len-(dep[y]-dep[ly]),1,1);
        }
    }
    cout<<tr[1].v<<' ';
}
void solve(){
    cin>>n;
    build(1,n,1);
    cnt=0;
    for(int i=1;i<=n;i++){
        cin>>a[i];  
        head[i]=0;dep[i]=0;
    }
    int u,v;
    for(int i=1;i<n;i++){
        cin>>u>>v;
        add(u,v);add(v,u);
    }
    dfs(1,0);
    cin>>q;
    int x,y;
    while(q--){
        cin>>x>>y;
        ans=0;
        work(x,y);
    }
    cout<<endl;
}
int main(){
    cin>>T;
    while(T--)
        solve();
    return 0;
}