题解:P10787 [NOI2024] 树的定向

· · 题解

首先要想出本题的关键结论,如果每条路径 a_i,b_i 之间都有至少 2 条边没被定向,那么一定有解。

证明如下,考虑特殊性质 A,可以简单构造出一组解:

一般情况中,有一些边已经定向了。

考虑缩边,对于一条已经定向的边,考虑经过它的所有路径,不同向的路径可以直接删除,仅保留同向的路径,然后这条边就没用了,可以直接把两个端点合并。

于是,如果每条路径 a_i,b_i 之间都有至少 2 条边没被定向,那么我们可以通过缩边变为特殊性质 A,所以一定有解。

而如果存在一条路径上只有 1 条边没定向,那这条边的方向可以唯一确定。

因此我们得到了一个 \operatorname{poly}(nm) 的做法,按编号从小到大枚举每一条边的方向,判断有没有解。

注意到在特殊性质 A 中,二分图染色后黑白点并没有本质区别,所以我们给某一条边随意定向后依旧有解。

那么我们可以省掉枚举方向的步骤,如果存在一条的路径中只有 1 条边没定向,那么直接定向,否则,我们直接选择编号最小的边定向为 u\rightarrow v

现在瓶颈在于,如何支持删边和快速找到只有 1 条边没定向的路径。

注意到这是一个类似拓扑排序的过程,只不过加入队列的条件变为度数 \le 1

我们需要将每条边向经过它的所有路径连边,可以倍增优化建图做到 O(n\log n)

我们找的一条路径后,还需要判断它是否已经不可达了,如果还可达的情况下,我们需要找到那条还没定向的边

这里有个很聪明的做法,考虑用树上带权并查集维护,删边后就把这个点和父亲合并,这样找到没定向的边是容易的。

判断是否已经不可达,相当于求路径上是否有某个方向的边,因为路径上最多只有 1 条边没定向,所以路径已经合并为 O(1) 个集合,带权并查集维护每个点到祖先的边权和即可。

时空复杂度 O((n+m)\log n),常数较大。

参考代码:

#include<bits/stdc++.h>
using namespace std;
const int N=5e5+5,M=N*21;
char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<23,stdin),p1==p2)?EOF:*p1++)
template<class rd>
void read(rd &x){
    char c=getchar();
    for(;c<48||c>57;c=getchar());
    for(x=0;c>47&&c<58;c=getchar())x=(x<<1)+(x<<3)+(c^48);
}
int tid,n,m,u[N],v[N],fa[N][19],id[N][19],dp[N],tot;
int a[N],b[N],c[N],ans[N],f[N],d[N],lg[N];
int sz[M],cur[M],hd[M],to[3*M],nx[3*M],num;
vector<int>e[N];
queue<int>q;
void adde(int x,int y){
    nx[++num]=hd[x],hd[x]=num,to[num]=y;
}
void dfs(int x,int y){
    fa[x][0]=y,dp[x]=dp[y]+1,id[x][0]=++tot;
    sz[tot]=cur[tot]=1;
    for(int i=1;i<19;++i){
        fa[x][i]=fa[fa[x][i-1]][i-1];
        if(!fa[x][i])break;
        id[x][i]=++tot,sz[tot]=cur[tot]=1<<i;
        adde(id[x][i-1],tot);
        adde(id[fa[x][i-1]][i-1],tot);
    }
    for(auto v:e[x])if(v!=y)dfs(v,x);
}
int lca(int x,int y){
    if(dp[x]<dp[y])swap(x,y);
    for(int i=lg[dp[x]-dp[y]];~i;--i)if(dp[fa[x][i]]>=dp[y])x=fa[x][i];
    if(x==y)return x;
    for(int i=lg[dp[x]];~i;--i)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
void link(int x,int y,int i){
    for(int j=lg[dp[x]-dp[y]];~j;--j)if(dp[fa[x][j]]>=dp[y])
        adde(id[x][j],i),x=fa[x][j];
}
int gf(int x){
    if(f[x]==x)return x;
    int y=gf(f[x]);
    d[x]+=d[f[x]];
    f[x]=y;
    return y;
}
inline bool chk(int x,int y,bool fg){
    return fg?d[x]-d[y]<dp[x]-dp[y]:d[x]>d[y];
}
int find(int i){
    int x=gf(a[i]),y=gf(b[i]),z=gf(c[i]);
    if(x==y)return -1;
    if(x==z){
        if(chk(b[i],y,0)||chk(a[i],c[i],1))return -1;
        int o=fa[y][0];
        if(gf(o)!=z||chk(o,c[i],0))return -1;
        ans[y]=d[y]=1,f[y]=o;
        return y;
    }
    if(y==z){
        if(chk(b[i],c[i],0)||chk(a[i],x,1))return -1;
        int o=fa[x][0];
        if(gf(o)!=z||chk(o,c[i],1))return -1;
        ans[x]=d[x]=0,f[x]=o;
        return x;
    }
    return -1;
}
void upd(int x){
    int dt=sz[x]-cur[x];
    if(x<=m&&cur[x]==1)q.push(x);
    for(int i=hd[x],v;v=to[i],i;i=nx[i]){
        cur[v]-=dt;
        if(cur[v]<2)upd(v);
    }
    sz[x]=cur[x];
}
int main(){
    read(tid),read(n),read(m),tot=m,lg[0]=-1;
    for(int i=1;i<=n;++i)lg[i]=lg[i>>1]+1;
    for(int i=1;i<n;++i){
        read(u[i]),read(v[i]);
        e[u[i]].emplace_back(v[i]);
        e[v[i]].emplace_back(u[i]);
    }
    dfs(1,0);
    for(int i=1;i<=m;++i){
        read(a[i]),read(b[i]);
        c[i]=lca(a[i],b[i]);
        link(a[i],c[i],i);
        link(b[i],c[i],i);
        sz[i]=cur[i]=dp[a[i]]+dp[b[i]]-2*dp[c[i]];
        if(sz[i]==1)q.push(i);
    }
    for(int i=1;i<=n;++i)ans[i]=-1,f[i]=i,d[i]=0;
    for(int i=1;i<n;++i){
        while(!q.empty()){
            int x=q.front();
            q.pop();
            int o=find(x);
            if(~o)--cur[id[o][0]],upd(id[o][0]);
        }
        int x=u[i],y=v[i];
        if(fa[y][0]==x)swap(x,y);
        if(~ans[x])continue;
        ans[x]=d[x]=(x==u[i]),f[x]=y;
        --cur[id[x][0]],upd(id[x][0]);
    }
    for(int i=1;i<n;++i){
        if(fa[u[i]][0]==v[i])putchar(48+(ans[u[i]]^1));
        else putchar(48+ans[v[i]]);
    }
    putchar(10);
    return 0;
}