题解 P8260【[CTS2022] 燃烧的呐球】
FunnyCreatress · · 题解
算是比较套路的一个题了。
观察到这是一道 CTS 的完全图最小生成树题,考虑 Boruvka 算法。那么我们以一个
注意到这不是普通的单点修矩形查,而是在一棵树上,因此树剖+线段树的结构不够优美。换成全局平衡二叉树后,单次查询复杂度骤降两个
然后 lxl 就说这个题有 2log 做法。
发现我们刚刚的做法其实有一些浪费,因为我们的仅仅是要求不在同一个连通块内而没有严格的顺序限制,于是就可以优化一下。
考虑一个子问题:维护若干个元素的集合,每个元素包含颜色和权值两个信息,支持加入元素,合并集合,查询某个集合中颜色不为给定颜色的元素中最大的那一个。
这是一个傻逼题,我们显然只需要维护权值最大的元素和与这个元素颜色不同的元素中权值最大的那一个,答案必定在这两者中产生。于是可以用类似的思想优化。
再就是大力讨论时间了。
Case 1: 两维都无关,啥也不用说。
Case 2: 两维中恰好一维为祖先后代关系,这个直接无脑搜一遍即可,数多了也没关系因为显然不优。
Case 3: 当前算的点两维都是祖先,线并维护上面那个东西就行。
Case 4: 当前算的点一维祖先一维后代,按后代维 dfs,祖先维用线段树维护。
Case 5: 两维都是后代,这个就 dfs 一维,全局平衡二叉树维护另一维,支持加点,撤销,点到根的链查最小值,都是容易的。为什么另一篇题解里这个部分这么复杂啊
这样就把复杂度优化到
但是为什么 2log 打不过 3log 啊
#include<bits/stdc++.h>
#define fir first
#define sec second
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N=1e6+5,M=1e5+5,INF=1e9;
int read(){
int res=0;char c=getchar();;
while(c<48)c=getchar();
while(c>=48)res=res*10+c-48,c=getchar();
return res;
}
int n,m,fa[N],sz[N],sb[N],hs[N],rt,ch[N][2],ffa[N],p[N],cntn,id[N],cnt,tp,tp0,tp1;pii a[M],ed[M];vector<int> son[N],hc,vx[N],vy[N],wq;ll ans;
template<typename T> void minto(T &x,T y){if(y<x)x=y;}
void dfs(int x){
sz[x]=1,id[x]=++cnt;
for(int y:son[x])dfs(y),sz[x]+=sz[y];
for(int y:son[x])if(sz[y]>sz[hs[x]])hs[x]=y;
sb[x]=sz[x]-sz[hs[x]];
}
int divide(int l,int r,int S){
if(l==r)return hc[l];
int mid=l,ts=0;
while((ts+=sb[hc[mid]])<=(S>>1)&&mid<r)mid++;
if(l!=mid)ffa[ch[hc[mid]][0]=divide(l,mid-1,ts-sb[hc[mid]])]=hc[mid];
if(r!=mid)ffa[ch[hc[mid]][1]=divide(mid+1,r,S-ts)]=hc[mid];
return hc[mid];
}
int dfs2(int x){
for(int y=x;y;y=hs[y])
for(int z:son[y])if(z!=hs[y])ffa[dfs2(z)]=y;
hc.clear();
for(int y=x;y;y=hs[y])hc.push_back(y);
return divide(0,hc.size()-1,sz[x]);
}
int fnd(int x){return p[x]==x?x:p[x]=fnd(p[x]);}
bool check(){for(int i=2;i<=m;i++)if(fnd(i)!=fnd(1))return 0;return 1;}
struct result{
pii mn,mn2;
result(){mn=mn2=make_pair(INF,0);}
result(pii Mn,pii Mn2=make_pair(INF,0)){mn=Mn,mn2=Mn2;}
friend result operator*(result u,result v){
if(u.mn.fir<v.mn.fir)return result(u.mn,min(u.mn2,v.mn.sec==u.mn.sec?v.mn2:v.mn));
else return result(v.mn,min(v.mn2,u.mn.sec==v.mn.sec?u.mn2:u.mn));
}
pii getmn(int c){return mn.sec==c?mn2:mn;}
} m0,aw[N],s[N];pair<int,result> bckup0[M],bckup1[N+M*22];
struct node{int ls,rs;result re;} st[N+M*22];pair<int,node> bckup[N+M*22];
pii operator+(pii x,int y){return make_pair(x.first+y,x.second);}
void solve_c1(){
m0=result();
for(int i=1;i<=m;i++)m0=m0*result(make_pair(sz[a[i].fir]+sz[a[i].sec],fnd(i)));
for(int i=1;i<=m;i++)minto(ed[fnd(i)],m0.getmn(fnd(i))+(sz[a[i].fir]+sz[a[i].sec]));
}
void init(int nn){for(int i=0;i<=cntn;i++)st[i]=(node){0,0,result()};cntn=nn;}
void pushup(int x){st[x].re=st[st[x].ls].re*st[st[x].rs].re;}
void insrt(int l,int r,int x,int p,pii v,bool saving=0){
if(saving)bckup[++tp]=make_pair(x,st[x]);
if(l==r){st[x].re=st[x].re*result(v);return;}
int mid=l+r>>1;
if(p<=mid){if(!st[x].ls)st[x].ls=++cntn;insrt(l,mid,st[x].ls,p,v,saving);}
else{if(!st[x].rs)st[x].rs=++cntn;insrt(mid+1,r,st[x].rs,p,v,saving);}
pushup(x);
}
void merge(int l,int r,int x,int y){
if(l==r){st[x].re=st[x].re*st[y].re;return;}
int mid=l+r>>1;
if(st[y].ls)st[x].ls?merge(l,mid,st[x].ls,st[y].ls),0:st[x].ls=st[y].ls;
if(st[y].rs)st[x].rs?merge(mid+1,r,st[x].rs,st[y].rs),0:st[x].rs=st[y].rs;
pushup(x);
}
result query(int l,int r,int x,int L,int R){
if(!x||l>R||r<L)return result();
if(l>=L&&r<=R)return st[x].re;
int mid=l+r>>1;
return query(l,mid,st[x].ls,L,R)*query(mid+1,r,st[x].rs,L,R);
}
result solve_c2x(int x,result af){
result res;
for(int i:vx[x])af=af*result(make_pair(sz[a[i].fir]+sz[a[i].sec],fnd(i)));
for(int i:vx[x])minto(ed[fnd(i)],af.getmn(fnd(i))+(sz[a[i].sec]-sz[a[i].fir])),res=res*result(make_pair(sz[a[i].sec]-sz[a[i].fir],fnd(i)));
for(int y:son[x])res=res*solve_c2x(y,af);
for(int i:vx[x])minto(ed[fnd(i)],res.getmn(fnd(i))+(sz[a[i].fir]+sz[a[i].sec]));
return res;
}
result solve_c2y(int x,result af){
result res;
for(int i:vy[x])af=af*result(make_pair(sz[a[i].fir]+sz[a[i].sec],fnd(i)));
for(int i:vy[x])minto(ed[fnd(i)],af.getmn(fnd(i))+(sz[a[i].fir]-sz[a[i].sec])),res=res*result(make_pair(sz[a[i].fir]-sz[a[i].sec],fnd(i)));
for(int y:son[x])res=res*solve_c2y(y,af);
for(int i:vy[x])minto(ed[fnd(i)],res.getmn(fnd(i))+(sz[a[i].fir]+sz[a[i].sec]));
return res;
}
void solve_c3(int x){
for(int i:vx[x])insrt(1,n,x,id[a[i].sec],make_pair(-sz[a[i].fir]-sz[a[i].sec],fnd(i)));
for(int y:son[x])solve_c3(y),merge(1,n,x,y);
for(int i:vx[x])minto(ed[fnd(i)],query(1,n,x,id[a[i].sec],id[a[i].sec]+sz[a[i].sec]-1).getmn(fnd(i))+(sz[a[i].fir]+sz[a[i].sec]));
}
void retrace(int x){while(tp>x)st[bckup[tp].fir]=bckup[tp].sec,tp--;}
void solve_c4x(int x){
int t=tp;
for(int i:vx[x])insrt(1,n,1,id[a[i].sec],make_pair(sz[a[i].fir]-sz[a[i].sec],fnd(i)),1);
for(int i:vx[x])minto(ed[fnd(i)],query(1,n,1,id[a[i].sec],id[a[i].sec]+sz[a[i].sec]-1).getmn(fnd(i))+(sz[a[i].sec]-sz[a[i].fir]));
for(int y:son[x])solve_c4x(y);
retrace(t);
}
void solve_c4y(int x){
int t=tp;
for(int i:vy[x])insrt(1,n,1,id[a[i].fir],make_pair(sz[a[i].sec]-sz[a[i].fir],fnd(i)),1);
for(int i:vy[x])minto(ed[fnd(i)],query(1,n,1,id[a[i].fir],id[a[i].fir]+sz[a[i].fir]-1).getmn(fnd(i))+(sz[a[i].fir]-sz[a[i].sec]));
for(int y:son[x])solve_c4y(y);
retrace(t);
}
bool isroot(int x){return x&&ch[ffa[x]][0]!=x&&ch[ffa[x]][1]!=x;}
void updt(int x,pii v){
bckup0[++tp0]=make_pair(x,aw[x]),aw[x]=aw[x]*result(v);
for(int y=0;!isroot(y);y=x,x=ffa[x])bckup1[++tp1]=make_pair(x,s[x]),s[x]=s[x]*result(v);
}
result queryt(int x){
result res;
for(int y=-1;x;y=x,x=ffa[x])
if(ch[x][0]!=y)res=res*aw[x]*s[ch[x][0]];
return res;
}
void retrace0(int x){while(tp0>x)aw[bckup0[tp0].fir]=bckup0[tp0].sec,tp0--;}
void retrace1(int x){while(tp1>x)s[bckup1[tp1].fir]=bckup1[tp1].sec,tp1--;}
void solve_c5(int x){
int t0=tp0,t1=tp1;
for(int i:vx[x])updt(a[i].sec,make_pair(sz[a[i].fir]+sz[a[i].sec],fnd(i)));
for(int i:vx[x])minto(ed[fnd(i)],queryt(a[i].sec).getmn(fnd(i))+(-sz[a[i].fir]-sz[a[i].sec]));
for(int y:son[x])solve_c5(y);
retrace0(t0),retrace1(t1);
}
void solve(){
wq.clear();
for(int i=1;i<=m;i++)if(fnd(i)==i)ed[i]=make_pair(INF,0),wq.push_back(i);
solve_c1(),solve_c2x(1,result()),solve_c2y(1,result());
init(n),solve_c3(1);
init(1),solve_c4x(1);
init(1),solve_c4y(1);
solve_c5(1);
for(int i:wq){
if(fnd(ed[i].sec)==fnd(i))continue;
ans+=ed[i].fir,p[fnd(i)]=fnd(ed[i].sec);
}
}
int main(){
n=read(),m=read();
for(int i=2;i<=n;i++)fa[i]=read(),son[fa[i]].push_back(i);
for(int i=1,x,y;i<=m;i++)x=read(),y=read(),a[i]=make_pair(x,y),vx[x].push_back(i),vy[y].push_back(i);
dfs(1),rt=dfs2(1);
for(int i=1;i<=m;i++)p[i]=i;
while(!check())solve();
printf("%lld\n",ans);
return 0;
}