题解 P4338 【[ZJOI2018]历史】

· · 题解

题意

给出一棵树,给定每一个点的access次数,计算轻重链切换次数的最大值,带修改.

题解

先考虑不带修改怎么做

可以发现一个点u,只有u子树里的点进行access才会影响u的答案,并且每个点都是独立的,可以分开计算

假设u的子树发生了两次access,那么当且仅当这两次access的点来自u的两个不同的儿子的子树,答案才会+1

A_0=a_u,A_i=[u的第i个儿子子树access次数之和],m=|Son_u|

既然当且仅当两次access来自不同子树会使得答案+1

要使得答案最大,就是尽量让所有相邻发生的access都来自不同子树

转化一下也就是有[0,m]种颜色,每种颜色有A_i个,最大化相邻的颜色不同次数

t=\sum_{i=0}^mA_i,h=\max_{i=0}^m\{A_i\}

那么答案就是\min\{t-1,2\times(t-h)\}

把同类型的数挪到一边就是当2\times h\ge t+1时,答案是2(t-h),否则是t-1

这里举个例子说明一下答案为啥是那个

第一种情况

颜色A\times3,B\times4,C\times5

一组可行解是CACBCACBCBAB

答案是11=12-1

第二种情况

颜色A\times2,B\times3,C\times7

一组可行解是CACBCACBCBCC

答案是10=2\times(12-7)

至于你说严格的证明我可以跟你说我最多只能通过举例子理解了吗,当然你可以问数学好的dalao

这样我们就可以在O(n)的时间内dfs一下这棵树就可以得到答案了(ps:记得开long\ long)

考虑待修改怎么做

我们可以根据这个[2\times h\ge t+1]的"分界线"去想一想怎么维护答案

f_i表示i的子树access的总次数,如果2f_i\ge f_{fa_i}+1那么连实边(i,fa_i),其他的都是虚边

考虑到如果把i子树里的点j的权值加上w,他只会影响j到根路径上的点的答案和虚实边关系

因为2(f_i+w)\ge(f_{fa_i}+w)+1,所以实边还是实边,并且答案不会变化(\Delta=2[(f_{fa_i}+w)-(f_i+w)]-2(f_{fa_i}-f_i)=0)

所以我们只需要找到路径上的虚边进行修改就好了

并且可以知道一条路径上虚边的数量\le\log\sum a(证明和树剖的证明一样,有一个"重儿子")

所以我们可以找到路径上所有虚边,然后暴力修改就好了

至于怎么找到路径上的虚边,可以用树剖+在线段树/树状数组上二分找到

当然还可以写LCT,因为这里的的操作其实就和access差不多,只是不一定会把所有边都变成实边,所以特判一下就好了,然后ans+=这个点更新后的答案-更新前的答案就好了

#include<bits/stdc++.h>
#define fp(i,a,b) for(register int i=a,I=b+1;i<I;++i)
#define fd(i,a,b) for(register int i=a,I=b-1;i>I;--i)
#define go(u) for(register int i=fi[u],v=e[i].to;i;v=e[i=e[i].nx].to)
#define file(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
template<class T>inline bool cmax(T&a,const T&b){return a<b?a=b,1:0;}
template<class T>inline bool cmin(T&a,const T&b){return a>b?a=b,1:0;}
using namespace std;
char ss[1<<17],*A=ss,*B=ss;
inline char gc(){return A==B&&(B=(A=ss)+fread(ss,1,1<<17,stdin),A==B)?-1:*A++;}
template<class T>inline void sd(T&x){
    char c;T y=1;while(c=gc(),(c<48||57<c)&&c!=-1)if(c==45)y=-1;x=c-48;
    while(c=gc(),47<c&&c<58)x=x*10+c-48;x*=y;
}
char sr[1<<21],z[20];int C=-1,Z;
inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
template<class T>inline void we(T x){
    if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
    while(z[++Z]=x%10+48,x/=10);
    while(sr[++C]=z[Z],--Z);sr[++C]='\n';
}
const int N=4e5+5;
typedef int arr[N];
typedef long long ll;
struct eg{int nx,to;}e[N*2];
int n,m,ce,fi[N];ll ans;
struct LCT{
    int fa[N],ch[N][2];ll s[N],val[N],vs[N];
    #define lc(u) (ch[u][0])
    #define rc(u) (ch[u][1])
    inline bool gf(int u){return rc(fa[u])==u;}
    inline bool ir(int u){return lc(fa[u])^u&&rc(fa[u])^u;}
    inline void up(int u){s[u]=s[lc(u)]+s[rc(u)]+val[u]+vs[u];}
    inline void rot(int u){
        int p=fa[u],k=gf(u);
        if(!ir(p))ch[fa[p]][gf(p)]=u;
        if(ch[u][!k])fa[ch[u][!k]]=p;
        ch[p][k]=ch[u][!k],ch[u][!k]=p;
        fa[u]=fa[p],fa[p]=u,up(p);
    }
    void splay(int u){
        for(int f=fa[u];!ir(u);rot(u),f=fa[u])
            if(!ir(f))rot(gf(f)==gf(u)?f:u);
        up(u);
    }
    inline ll calc(int u,ll t,ll h){return rc(u)?(t-h)*2:(val[u]*2>t?(t-val[u])*2:t-1);}
    inline void mdy(int u,int w){
        splay(u);int v;
        ll t=s[u]-s[lc(u)],h=s[rc(u)];
        ans-=calc(u,t,h);s[u]+=w,val[u]+=w,t+=w;
        if(h*2<t+1)vs[u]+=h,rc(u)=0;
        ans+=calc(u,t,h);up(u);
        //access
        for(u=fa[v=u];u;u=fa[v=u]){
            splay(u);t=s[u]-s[lc(u)],h=s[rc(u)];
            ans-=calc(u,t,h);s[u]+=w,vs[u]+=w,t+=w;
            if(h*2<t+1)vs[u]+=h,rc(u)=0,h=0;
            if(s[v]*2>t)vs[u]-=s[v],rc(u)=v,h=s[v];
            ans+=calc(u,t,h);up(u);
        }
    }
    void dfs(int u){
        s[u]=val[u];int p=0;ll mx=val[u];
        go(u)if(v^fa[u]){
            fa[v]=u,dfs(v),s[u]+=s[v];
            if(s[v]>mx)mx=s[p=v];
        }
        ans+=min(s[u]-1,(s[u]-mx)*2);
        if(mx*2>=s[u]+1)rc(u)=p;
        vs[u]=s[u]-val[u]-s[rc(u)];
    }
}t;
inline void add(int u,int v){e[++ce]={fi[u],v},fi[u]=ce;}
int main(){
    #ifndef ONLINE_JUDGE
        file("s");
    #endif
    sd(n),sd(m);int u,v;
    fp(i,1,n)sd(t.val[i]);
    fp(i,2,n)sd(u),sd(v),add(u,v),add(v,u);
    t.dfs(1);we(ans);
    while(m--){
        sd(u),sd(v);
        t.mdy(u,v);
        we(ans);
    }
return Ot(),0;
}

那个calc(u,t,h)或许这么写会好理解一些?

inline ll calc(int u,ll t,ll h){
    if(rc(u))return (t-h)*2;
    else if(val[u]>=t+1)return (t-val[u])*2;
    else return t-1
}