题解:P10787 [NOI2024] 树的定向

· · 题解

Solution

先把最容易定向的边处理了。发现如果 a_i \to b_i 的路径上只有 1 条边没有被定向,且其他边都是按照 a \to b 的顺序了,那么剩下的边是可以确定的。

处理完一波之后,只剩下了距离 \ge 2 的点。而你对着缩点后的结果做黑白染色,那么一定有解,且每条边一定可以 01 都行。

所以你就又确定了一条边,继续做就行。

这样是 O(nm) 的。我们只需要加速找到“恰有一条边”的路径。

为啥大家不会想到 KDT 呢。就是你把路径按照 (\min\{dfn_u,dfn_v\},\max\{dfn_u,dfn_v\}) 扔到平面上,你发现每次只有把两个矩形中的路径剩余长度减一。

那么使用 KDT 维护容易做到 O(n \sqrt n),轻松通过 84 分。

鉴于现在没有这样的代码,我在末尾放一个。

考虑实现报警器类似物(树上倍增)。对于长度为 2^0 的报警器,它只监视一条边(相当于最开始已经报警一起)。其他所有报警器都只看着另外两个报警器。如果一个报警器的子报警器一共报警了 3 次,那么它会报警一次;子报警器一共报警了 4 次,他还会再报警一次。

直接暴力实现报警复杂度显然是 O(n \log n)。感觉常数大的离谱,跑得和根号差不多。

#include<bits/stdc++.h>
#define ffor(i,a,b) for(int i=(a);i<=(b);i++)
#define roff(i,a,b) for(int i=(a);i>=(b);i--)
using namespace std;
const int MAXN=5e5+10;
int cid,n,m,tot,dsu[MAXN],u[MAXN],v[MAXN],pu[MAXN],pv[MAXN],ans[MAXN],fa[MAXN][20],dep[MAXN],tr[2][MAXN];
void update(int pos,const int op,const int v) {while(pos<=n) tr[op][pos]+=v,pos+=pos&-pos;return ;}
int query(int pos,const int op) {int ans=0;while(pos) ans+=tr[op][pos],pos-=pos&-pos;return ans;}
vector<int> G[MAXN];
set<int> st;
int dfn[MAXN],sze[MAXN],fid[MAXN];
inline void dfs(const int u,const int f) {
    fa[u][0]=f,dep[u]=dep[f]+1,sze[u]=1,dfn[u]=++tot;
    ffor(i,1,19) fa[u][i]=fa[fa[u][i-1]][i-1];
    for(auto v:G[u]) if(v!=f) dfs(v,u),sze[u]+=sze[v];
    return ;
}
inline int find(const int k) {return (dsu[k]==k)?k:(dsu[k]=find(dsu[k]));}
struct KDT {
    int a,b,l,ls,rs,del,len;
    int x,y,xmn,xmx,ymn,ymx;
}t[MAXN];
int del[MAXN],mn[MAXN];
inline void push_up(const int u,const int op=0){
    if(mn[t[u].ls]>n) t[u].ls=0;
    if(mn[t[u].rs]>n) t[u].rs=0;
    mn[u]=min(t[u].len,min(mn[t[u].ls],mn[t[u].rs]))-t[u].del;
    if(op) {
        t[u].xmn=t[u].xmx=t[u].x,t[u].ymn=t[u].ymx=t[u].y;
        if(t[u].ls) {
            int l=t[u].ls;
            t[u].xmn=min(t[u].xmn,t[l].xmn),t[u].xmx=max(t[u].xmx,t[l].xmx);
            t[u].ymn=min(t[u].ymn,t[l].ymn),t[u].ymx=max(t[u].ymx,t[l].ymx);
        }
        if(t[u].rs) {
            int r=t[u].rs;
            t[u].xmn=min(t[u].xmn,t[r].xmn),t[u].xmx=max(t[u].xmx,t[r].xmx);
            t[u].ymn=min(t[u].ymn,t[r].ymn),t[u].ymx=max(t[u].ymx,t[r].ymx);
        }
    }
    return ;
}
inline void modify(const int u,const int x,const int y,const int X,const int Y) {
    if(!u||x>t[u].xmx||y>t[u].ymx||X<t[u].xmn||Y<t[u].ymn) return ;
    if(x<=t[u].xmn&&t[u].xmx<=X&&y<=t[u].ymn&&t[u].ymx<=Y) return t[u].del++,mn[u]--,void();
    if(x<=t[u].x&&t[u].x<=X&&y<=t[u].y&&t[u].y<=Y) t[u].len--;
    modify(t[u].ls,x,y,X,Y),modify(t[u].rs,x,y,X,Y);
    return push_up(u),void(); 
}
set<int> ke;
inline int lca(int u,int v) {
    if(dep[u]<dep[v]) swap(u,v);
    roff(i,19,0) if((dep[u]-dep[v])&(1<<i)) u=fa[u][i];
    if(u==v) return u;
    roff(i,19,0) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}
inline void check(const int U,const int V,const int l) {
    int al=dep[U]+dep[V]-2*dep[l],cnt=query(dfn[U],0)+query(dfn[V],1)-query(dfn[l],0)-query(dfn[l],1);
    if(cnt+1==al) {
        int s=find(U),t=find(V);
        if(dep[s]<dep[t]) ans[fid[t]]=(t==v[fid[t]]),ke.insert(fid[t]);
        else ans[fid[s]]=(s!=v[fid[s]]),ke.insert(fid[s]);
    }
    return ;
}
inline void solve(const int u,int odel) {
    if(!u||mn[u]-odel>1) return ;
    odel+=t[u].del;
    if(t[u].len-odel==1) check(t[u].a,t[u].b,t[u].l),t[u].len=INT_MAX;
    else if(t[u].len-odel==0) t[u].len=INT_MAX;
    solve(t[u].ls,odel),solve(t[u].rs,odel),push_up(u);
    return ;
}
inline int build(vector<int> vc,const int op) {
    if(vc.empty()) return 0;
    if(op==0) sort(vc.begin(),vc.end(),[&](int u,int v){return t[u].x<t[v].x;});
    else sort(vc.begin(),vc.end(),[&](int u,int v){return t[u].y<t[v].y;});
    int s=vc.size();
    int rt=(s+1)/2,u=vc[rt-1];
    vector<int> L,R;
    ffor(i,1,rt-1) L.emplace_back(vc[i-1]);
    ffor(i,rt+1,vc.size()) R.emplace_back(vc[i-1]);
    t[u].ls=build(L,op^1),t[u].rs=build(R,op^1);
    return push_up(u,1),u;
}
int main() {
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    mn[0]=INT_MAX;
    cin>>cid>>n>>m;
    ffor(i,1,n) dsu[i]=i;
    ffor(i,1,n-1) cin>>u[i]>>v[i],G[u[i]].emplace_back(v[i]),G[v[i]].emplace_back(u[i]),ans[i]=-1,st.insert(i);
    dfs(1,0);
    ffor(i,1,n-1) if(fa[u[i]][0]==v[i]) fid[u[i]]=i;
    else fid[v[i]]=i;
    ffor(i,1,m) {
        int u=i;
        cin>>t[u].a>>t[u].b,t[u].l=lca(t[u].a,t[u].b),t[u].x=min(dfn[t[u].a],dfn[t[u].b]),t[u].y=max(dfn[t[u].a],dfn[t[u].b]);
        t[u].len=dep[t[u].a]+dep[t[u].b]-2*dep[t[u].l];
    }
    vector<int> vc;
    ffor(i,1,m) vc.push_back(i);
    int rt=build(vc,0),cnt=n-1;
    while(cnt) {
        while(mn[rt]<=1) {
            solve(rt,0);
            for(auto id:ke) {
                --cnt;
                int uid=u[id],vid=v[id];
                if(fa[uid][0]==vid) swap(uid,vid);
                dsu[find(vid)]=find(uid);
                if(dfn[vid]!=1) modify(rt,1,dfn[vid],dfn[vid]-1,dfn[vid]+sze[vid]-1);
                if(dfn[vid]+sze[vid]-1!=n) modify(rt,dfn[vid],dfn[vid]+sze[vid],dfn[vid]+sze[vid]-1,n);

                int frm=u[id],to=v[id];
                if(ans[id]==1) swap(frm,to);
                if(fa[frm][0]==to) update(dfn[frm],0,1),update(dfn[frm]+sze[frm],0,-1);
                else update(dfn[to],1,1),update(dfn[to]+sze[to],1,-1);
            }
            ke.clear();
        }
        if(!cnt) break ;
        while(!st.empty()&&ans[*st.begin()]!=-1) st.erase(st.begin());
        int id=*st.begin();
        ans[id]=0,--cnt;
        int uid=u[id],vid=v[id];
        if(fa[uid][0]==vid) swap(uid,vid);
        dsu[find(vid)]=find(uid);
        if(dfn[vid]!=1) modify(rt,1,dfn[vid],dfn[vid]-1,dfn[vid]+sze[vid]-1);
        if(dfn[vid]+sze[vid]-1!=n) modify(rt,dfn[vid],dfn[vid]+sze[vid],dfn[vid]+sze[vid]-1,n);
        int frm=u[id],to=v[id];
        if(ans[id]==1) swap(frm,to);

        if(fa[frm][0]==to) update(dfn[frm],0,1),update(dfn[frm]+sze[frm],0,-1);
        else update(dfn[to],1,1),update(dfn[to]+sze[to],1,-1);
    }
    ffor(i,1,n-1) cout<<ans[i];
    return 0;
}
#include<bits/stdc++.h>
#define ffor(i,a,b) for(int i=(a);i<=(b);i++)
#define roff(i,a,b) for(int i=(a);i>=(b);i--)
using namespace std;
const int MAXN=5e5+10;
int cid,n,m,tot,dsu[MAXN],u[MAXN],v[MAXN],pu[MAXN],pv[MAXN],ans[MAXN],fa[MAXN][20],dep[MAXN],tr[2][MAXN];
void update(int pos,const int op,const int v) {while(pos<=n) tr[op][pos]+=v,pos+=pos&-pos;return ;}
int query(int pos,const int op) {int ans=0;while(pos) ans+=tr[op][pos],pos-=pos&-pos;return ans;}
vector<int> G[MAXN],nxt[MAXN][20];
int nd[MAXN],bj[MAXN][20];
set<int> st;
int dfn[MAXN],sze[MAXN],fid[MAXN];
inline void dfs(const int u,const int f) {
    fa[u][0]=f,dep[u]=dep[f]+1,sze[u]=1,dfn[u]=++tot;
    ffor(i,1,19) fa[u][i]=fa[fa[u][i-1]][i-1];
    for(auto v:G[u]) if(v!=f) dfs(v,u),sze[u]+=sze[v];
    return ;
}
inline int find(const int k) {return (dsu[k]==k)?k:(dsu[k]=find(dsu[k]));}
int A[MAXN],B[MAXN]; 
set<int> ke;
inline int lca(int u,int v) {
    if(dep[u]<dep[v]) swap(u,v);
    roff(i,19,0) if((dep[u]-dep[v])&(1<<i)) u=fa[u][i];
    if(u==v) return u;
    roff(i,19,0) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}
inline void check(const int U,const int V,const int l) {
    int al=dep[U]+dep[V]-2*dep[l],cnt=query(dfn[U],0)+query(dfn[V],1)-query(dfn[l],0)-query(dfn[l],1);
    int mz=query(dfn[U],0)+query(dfn[U],1)+query(dfn[V],0)+query(dfn[V],1)-2*query(dfn[l],0)-2*query(dfn[l],1);
    if(cnt+1==al&&mz==cnt) {
        int s=find(U),t=find(V);

        if(dep[s]<dep[t]) {if(ans[fid[t]]==-1) ans[fid[t]]=(t==v[fid[t]]),ke.insert(fid[t]);}
        else {if(ans[fid[s]]==-1) ans[fid[s]]=(s!=v[fid[s]]),ke.insert(fid[s]);}
    }
    return ;
}
int cnt;
void solve(int id) {
    int uid=u[id],vid=v[id];
    if(fa[uid][0]==vid) swap(uid,vid);
    dsu[find(vid)]=find(uid);

    int frm=u[id],to=v[id];
    if(ans[id]==1) swap(frm,to);
    if(fa[frm][0]==to) update(dfn[frm],0,1),update(dfn[frm]+sze[frm],0,-1);
    else update(dfn[to],1,1),update(dfn[to]+sze[to],1,-1);
    return ;
}
void alarm(int u,int j) {
    for(auto id:nxt[u][j]) {
        if(id<0) {
            int v=-id;
            --nd[v];
            if(nd[v]==1) check(A[v],B[v],lca(A[v],B[v])); 
        }
        else {
            int nu=id/20,nj=id%20;
            ++bj[nu][nj];
            if(bj[nu][nj]>=3) alarm(nu,nj);
        }
    }
    return ;
}
int main() {
    ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    cin>>cid>>n>>m;
    ffor(i,1,n) dsu[i]=i;
    ffor(i,1,n-1) cin>>u[i]>>v[i],G[u[i]].emplace_back(v[i]),G[v[i]].emplace_back(u[i]),ans[i]=-1,st.insert(i);
    dfs(1,0);
    ffor(i,1,n-1) if(fa[u[i]][0]==v[i]) fid[u[i]]=i;
    else fid[v[i]]=i;
    ffor(i,1,m) cin>>A[i]>>B[i];
    ffor(i,1,n) ffor(j,1,19) if(fa[i][j]) {
        nxt[i][j-1].push_back(i*20+j);
        nxt[fa[i][j-1]][j-1].push_back(i*20+j);
    }
    ffor(i,1,m) {
        int u=A[i],v=B[i];
        if(dep[u]<dep[v]) swap(u,v);
        roff(j,19,0) if((dep[u]-dep[v])&(1<<j)) nxt[u][j].push_back(-i),nd[i]+=2,u=fa[u][j];
        if(u!=v) {
            roff(j,19,0) if(fa[u][j]!=fa[v][j]) {
                nxt[u][j].push_back(-i);
                nxt[v][j].push_back(-i);
                u=fa[u][j],v=fa[v][j],nd[i]+=4;
            }
            nxt[u][0].push_back(-i),nxt[v][0].push_back(-i),nd[i]+=4;
        }
    }
    cnt=n-1;
    ffor(i,2,n) {
        alarm(i,0);
        while(!ke.empty()) {
            auto mzx=ke;
            ke.clear();
            vector<int> pos;
            for(auto id:mzx) {
                --cnt;
                if(fa[v[id]][0]==u[id]) pos.push_back(v[id]);
                else pos.push_back(u[id]);
                solve(id);
            }
            for(auto id:pos) alarm(id,0);
        }
    }
    while(cnt) {
        while(!ke.empty()) {
            auto mzx=ke;
            ke.clear();
            vector<int> pos;
            for(auto id:mzx) {
                --cnt;
                if(fa[v[id]][0]==u[id]) pos.push_back(v[id]);
                else pos.push_back(u[id]);
                solve(id);
            }
            for(auto id:pos) alarm(id,0);
        }
        if(!cnt) break ;
        while(!st.empty()&&ans[*st.begin()]!=-1) st.erase(st.begin());
        int id=*st.begin();
        ans[id]=0,--cnt,solve(id);
        if(fa[v[id]][0]==u[id]) alarm(v[id],0);
        else alarm(u[id],0);
    }
    ffor(i,1,n-1) cout<<ans[i];
    return 0;
}