题解:P8339 [AHOI2022] 钥匙

· · 题解

我们以每种颜色的点为关键点建立虚树。注意到同种钥匙数量极少,考虑枚举配对 (key,box)

容易发现,一个配对能够产生贡献的条件是:从 key 沿简单路径走到 box,满足路径任意真前缀这种颜色钥匙数量 > 盒子数量,且整条路径上两数量相等。

因为建了虚树,暴力搜复杂度不会爆炸,直接分别以每个钥匙为虚树的根把树搜遍即可。复杂度 O(n)

这样我们得到所有合法配对。考虑这些配对能够对哪些起点和终点产生贡献。

keybox 没有祖先关系,显然起点终点需要分别在 keybox 的子树中。

keybox 祖先,终点还是在 box 子树里。设 z 为从 keybox 路径的第二个点,起点则在 z 子树的补里。

反过来是同理的。

上述子树问题容易用 dfs 序维护,下文称作 dfn。把两个点的 dfn 看作平面上两个维度,一个限制可以转化为一个或两个矩形。把询问的两点 dfn 看作点的两维,答案即为这个点被多少个矩形覆盖。扫描线维护即可。

复杂度 O((n+m)\log{n})

Code

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 500100;
const int M = 1001000;
int n,m,root,F[N],c[N];
int op[N];
vector <int> key[N],box[N];
struct edge{
    int nxt,to;
}e[N*2];
int cnt,head[N];
inline void add(int u,int v){
    F[v]=u;
    e[++cnt].nxt=head[u];
    e[cnt].to=v;
    head[u]=cnt;
    return;
}
int dfn[N],dt,fa[N][21],dep[N],L[N],R[N];
inline void dfs(int u,int f){
    dfn[u]=++dt;
    fa[u][0]=f;
    dep[u]=dep[f]+1;
    for(int i=1;i<=20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f) continue;
        dfs(v,u);
    }
    L[u]=dfn[u];
    R[u]=dt;
    return;
}
int lca(int x,int y){
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=20;i>=0;i--){
        if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
    }
    if(x==y) return x;
    for(int i=20;i>=0;i--){
        if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    }
    return fa[x][0];
}
int jump(int x,int y){
    for(int i=20;i>=0;i--){
        if(dep[fa[x][i]]>dep[y]) x=fa[x][i];
    }
    return x;
}
int st[N],top,b[N];
inline void insert(int x){
    if(!top) return st[++top]=x,void();
    int z=lca(x,st[top]);
    while(dep[st[top-1]]>dep[z]) add(st[top-1],st[top]),top--;
    if(dep[st[top]]>dep[z]) add(z,st[top]),top--;
    if(st[top]!=z) st[++top]=z;
    st[++top]=x;
    return;
}
bool cmp0(int a,int b){
    return dfn[a]<dfn[b];
}
bool tag[10][N];
inline void dfs2(int st,int u,int sk,int sb,int f,int col){
    if(op[u]==1){
        if(c[u]==col) sk++;
    }
    else if(op[u]==2){
        if(c[u]==col){
            sb++;
            if(sk==sb) tag[st][u]=1;
        }
    }else if(u==root) return;
    if(sb>=sk) return;
    if(F[u]!=f) dfs2(st,F[u],sk,sb,u,col);
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f) continue;
        dfs2(st,v,sk,sb,u,col);
    }
    return;
}
int tot;
struct segment{
    int l,r,val,y;
}sg[10*N];
bool cmp3(segment a,segment b){
    return a.y<b.y;
}
inline void Add(int X1,int X2,int Y1,int Y2){
    if(X1>X2 || Y1>Y2) return;
    sg[++tot]=(segment){X1,X2,1,Y1};
    sg[++tot]=(segment){X1,X2,-1,Y2+1};
    return;
}
struct Query{
    int s,t,id;
}q[M];
bool cmp4(Query a,Query b){
    return dfn[a.t]<dfn[b.t];
}
int tree[N];
int lowbit(int x){
    return x & (-x);
}
inline void updata(int x,int d){
    int u=x;
    while(u<=n+1){
        tree[u]+=d;
        u+=lowbit(u);
    }
    return;
}
int query(int x){
    int u=x,ans=0;
    while(u){
        ans+=tree[u];
        u-=lowbit(u);
    }
    return ans;
}
int Ans[M];
inline void clr(int u){
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        clr(v);
    }
    head[u]=0;
    F[u]=0;
    return;
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        cin>>op[i]>>c[i];
        if(op[i]==1) key[c[i]].push_back(i);
        else box[c[i]].push_back(i);
    }
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        add(u,v);
        add(v,u);
    }
    root=n+1;
    add(root,1);
    add(1,root);
    dfs(root,0);
    cnt=0;
    memset(head,0,sizeof(head));
    memset(e,0,sizeof(e));
    memset(F,0,sizeof(F));
    for(int i=1;i<=n;i++){
        int k=key[i].size()+box[i].size();
        for(int j=0;j<key[i].size();j++){
            b[j+1]=key[i][j];
        }
        for(int j=0;j<box[i].size();j++){
            for(int o=0;o<key[i].size();o++) tag[o][box[i][j]]=0;
            b[j+1+key[i].size()]=box[i][j];
        }
        sort(b+1,b+1+k,cmp0);
        insert(root);
        for(int j=1;j<=k;j++) insert(b[j]);
        while(top>1) add(st[top-1],st[top]),top--;
        for(int j=0;j<key[i].size();j++){
            dfs2(j,key[i][j],0,0,0,i);
        }
        for(int j=0;j<key[i].size();j++){
            for(int o=0;o<box[i].size();o++){
                int x=key[i][j],y=box[i][o];
                if(tag[j][y]){
                    if(L[x]<=dfn[y] && R[x]>=dfn[y]){
                        int z=jump(y,x);
                        Add(1,L[z]-1,L[y],R[y]);
                        Add(R[z]+1,n+1,L[y],R[y]);
                    }else if(L[y]<=dfn[x] && R[y]>=dfn[x]){
                        int z=jump(x,y);
                        Add(L[x],R[x],1,L[z]-1);
                        Add(L[x],R[x],R[z]+1,n+1);
                    }else{
                        Add(L[x],R[x],L[y],R[y]);
                    }
                }
            }
        }
        top=0;
        cnt=0;
        clr(root);
    }
    sort(sg+1,sg+1+tot,cmp3);
    sg[tot+1].y=n+2;
    for(int i=1;i<=m;i++){
        cin>>q[i].s>>q[i].t;
        q[i].id=i;
    }
    int now=1;
    sort(q+1,q+1+m,cmp4);
    while(dfn[q[now].t]<sg[1].y && now<=m) now++;
    for(int i=1;i<=tot;i++){
        updata(sg[i].l,sg[i].val);
        updata(sg[i].r+1,-sg[i].val);
        while(dfn[q[now].t]<sg[i+1].y && now<=m){
            Ans[q[now].id]=query(dfn[q[now].s]);
            now++;
        }
    }
    for(int i=1;i<=m;i++) cout<<Ans[i]<<'\n';
    return 0;
}