CF1007D Ants 题解

· · 题解

题意

有一棵包含 n 个顶点的树,树上住着 m 只蚂蚁,每只蚂蚁都有自己的颜色,第 i 只蚂蚁有两对最喜欢的顶点对:(a_i, b_i)(c_i, d_i)。 你需要判断是否可以用 m 种颜色为树的边染色,使得每只蚂蚁都能仅通过自己颜色的边,从其某一对最喜欢的顶点对中的一个顶点走到另一个顶点;如果可以,请输出每只蚂蚁应该选择哪一对顶点对。

n\leq 10^5,m\leq 10^4

思路

很好的一道树剖优化建图的题。

首先看到题目首先想到 2-SAT 直接建图求方案,但是稍加分析会发现边数高达 O(m^2),时间空间都会炸。

考虑使用数据结构来减少边的数量,由于每只蚂蚁给的都是树上的一条链,考虑使用树链剖分优化。

具体的,将一条链拆成 O(\log n) 条树链,放到线段树上,最后一条链会拆成 O(\log^2 n) 个线段树上的区间。对于每个区间,如果选择这个区间,那么他的所有祖先和他子树内的区间都不可以被选择。可以根据这个建出图来,在图上跑 2-SAT。

但是这样做会发现样例过不去,因为如果一个区间选择颜色 x,由于这个区间在他自己的子树内,通过上面的部分可以得出他不能选择 x 的结论,这显然是错的。于是我们连边的时候需要避开区间本身这个节点,但是这样会忽略在这个区间内但是颜色不是 x 的。我们可以选择将每个区间表示的节点上可能的颜色记录下来,最后用前后缀优化建图来解决这部分的影响。

总的时间空间复杂度均为 O(n\log^2 n)

Code

#include <bits/stdc++.h>
#define ret return
#define pb push_back
#define mid (l+r>>1)
using namespace std;
int read(){int s=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){s=s*10+ch-'0';ch=getchar();}ret s*f;}
namespace T{
    const int N=1e5+1;vector<int> v[N];
    int dep[N],siz[N],son[N],f[N],dfn[N],tot,top[N];
    void dfs(int u,int fa){
        dep[u]=dep[f[u]=fa]+1,siz[u]=1;
        for(auto i:v[u]){
            if(i==fa)continue;
            dfs(i,u);siz[u]+=siz[i];
            if(siz[son[u]]<siz[i])son[u]=i;
        }
    }
    void init(int u,int now){
        top[u]=now,dfn[u]=++tot;
        if(son[u])init(son[u],now);
        for(auto i:v[u])if(!top[i])init(i,i);
    }
}
const int N=4e6+10;
namespace G{
    int vis[N],dfn[N],low[N],bl[N],tot,cnt;stack<int> stk;vector<int> v[N];
    void tarjan(int u){
        int d;vis[u]=1,low[u]=dfn[u]=++tot;stk.push(u);
        for(auto i:v[u]){
            if(!dfn[i]){tarjan(i);low[u]=min(low[u],low[i]);}
            else if(vis[i])low[u]=min(low[u],dfn[i]);
        }
        if(dfn[u]==low[u]){++cnt;do{d=stk.top();stk.pop();bl[d]=cnt;vis[d]=0;}while(u!=d);}
    }
}
namespace S{
    int ls[N],rs[N],f[N],tot,rup,rdown;
    vector<int> v[N],p[N],n[N];
    void build(int l,int r,int &q,int op){
        q=++tot;if(l==r)ret;
        build(l,mid,ls[q],op);f[ls[q]]=q;
        build(mid+1,r,rs[q],op);f[rs[q]]=q;
        if(op){G::v[ls[q]].pb(q);G::v[rs[q]].pb(q);}
        else{G::v[q].pb(ls[q]);G::v[q].pb(rs[q]);}
    }
    void init(int n,int m){tot=m*2+1;build(1,n,rup,1);build(1,n,rdown,0);}
    void change(int l,int r,int q,int x,int y,int in,int out,int op){
        if(x>y)ret;
        if(l==x&&r==y){
            if(op){if(ls[q])G::v[in].pb(ls[q]);if(rs[q])G::v[in].pb(rs[q]);G::v[q].pb(out);}
            else{if(f[q])G::v[in].pb(f[q]);G::v[q].pb(out);v[q].pb(out);}ret;
        }
        if(y<=mid)ret change(l,mid,ls[q],x,y,in,out,op);
        if(x>mid)ret change(mid+1,r,rs[q],x,y,in,out,op);
        change(l,mid,ls[q],x,mid,in,out,op);
        change(mid+1,r,rs[q],mid+1,y,in,out,op);
    }
    void add(int x,int y,int in,int out,int n){
        while(T::top[x]!=T::top[y]){
            if(T::dep[T::top[x]]<T::dep[T::top[y]])swap(x,y);
            change(1,n,rup,T::dfn[T::top[x]],T::dfn[x],in,out,0);
            change(1,n,rdown,T::dfn[T::top[x]],T::dfn[x],in,out,1);x=T::f[T::top[x]];
        }
        if(T::dep[x]>T::dep[y])swap(x,y);
        change(1,n,rup,T::dfn[x]+1,T::dfn[y],in,out,0);
        change(1,n,rdown,T::dfn[x]+1,T::dfn[y],in,out,1);
    }
    void update(int m){
        int cnt=tot;
        for(int i=m*2+2;i<=cnt;i++){
            int len=v[i].size();
            p[i].resize(len);
            n[i].resize(len);
            for(int j=0;j<len;j++){
                p[i][j]=++tot;
                if(j!=0)G::v[p[i][j]].pb(p[i][j-1]);
                G::v[p[i][j]].pb(v[i][j]);
            }
            for(int j=len-1;j>=0;j--){
                n[i][j]=++tot;
                if(j!=len-1)G::v[n[i][j]].pb(n[i][j+1]);
                G::v[n[i][j]].pb(v[i][j]);
            }
            for(int j=0;j<len;j++){
                if(j!=0)G::v[v[i][j]^1].pb(p[i][j-1]);
                if(j!=len-1)G::v[v[i][j]^1].pb(n[i][j+1]);
            }
        }
    }
}
void solve(){
    int n=read(),x,y;
    for(int i=1;i<n;i++){int x=read(),y=read();T::v[x].pb(y);T::v[y].pb(x);}
    T::dfs(1,0);T::init(1,1);int m=read();S::init(n,m);
    for(int i=1;i<=m;i++){
        x=read(),y=read();S::add(x,y,i*2,i*2+1,n);
        x=read(),y=read();S::add(x,y,i*2+1,i*2,n);
    }
    S::update(m);
    for(int i=2;i<=m*2+1;i++)if(!G::dfn[i])G::tarjan(i);
    for(int i=1;i<=m;i++)if(G::bl[i*2]==G::bl[i*2+1]){puts("NO");ret;}
    puts("YES");for(int i=1;i<=m;i++)printf("%d\n",G::bl[i*2]<G::bl[i*2+1]?1:2);
}
signed main(){solve();ret 0;}