题解 CF1149C【Tree Generator™】

· · 题解

UPD on 2021.3.24:修了一个 typo

安利个人 blog:https://www.cnblogs.com/ET2006/

Codeforces 题目传送门 & 洛谷题目传送门

首先考虑这个所谓的“括号树”与直径的本质是什么。考虑括号树上两点 x,y,我们不妨用一个“DFS”的过程来理解,在 DFS 过程中假设我们在第 l 个字符后访问 x,显然接下来会访问 x 的子树并回到 x,也就是说对应的括号序列是一个合法括号序列,也就是说它的左右括号相抵消了,紧接着我们会向上回溯到 \text{LCA}(x,y),对于在回溯的过程中访问的点 z,我们可能还会访问它的其它子树,不过由于最终回到了 z,所经过的括号串一定是一个合法括号序列,最终不能相抵消的部分一定是 dep_x-dep_{\text{LCA}(x,y)} 个右括号,换句话说,从 x\text{LCA}(x,y),其经过的路径进行左右括号抵消后一定是若干个右括号拼起来的字符串。同理,从 \text{LCA}(x,y)\to y 一定是若干个左括号拼起来的字符串,也就是说 x,y 之间的路径长度就是 [l,r] 进行左括号相抵消后剩余部分的长度,我们记该值为 f(l,r)。而显然 \forall 1\leq l\leq r\leq 2(n-1),区间 [l,r] 都对应一对 (x,y)。故答案即为 \max\limits_{1\leq l\leq r\leq 2(n-1)}f(l,r)

然后我就在那儿一直想怎样直接维护 f(l,r),心态爆炸……似乎 ycx 也卡在了这个地方?

根据 f(l,r) 的定义不难发现这玩意儿直接维护是不太容易的,因为合并两个区间时还需考虑左右括号相消的问题。如果我们能够将其变成类似于求和、取 \max 的东西是不是就比较好维护了呢?

我们假设 [l,r] 消完之后剩余 x 个右括号,y 个左括号。考虑套路地将 ( 看作 1) 看作 -1。对 [l,r] 进行一遍前缀和得到数组 s_i(或者说 s_i 表示区间 [l,i] 中左括号个数 - 右括号个数),那么显然 \min_{i=l}^rs_i=-x

看到这个 \min 能想到什么呢?

考虑设 s_k=-x,我们不妨将区间 [l,r]k 处劈开,劈成两个子区间 [l,k],[k+1,r],显然 [l,k] 中左右相消后一定是 x 个右括号,[k+1,r] 中左右相消后一定是 y 个左括号。如果我们记 sum(l,r)[l,r] 中所有数字和。那么有 sum(l,k)=-x,sum(k+1,r)=y,故 sum(k+1,r)-sum(l,k)=x+y。而对于某个 k'\in[l,r),k'\neq k,由 \min_{i=l}^rs_i=-xs_k\le s'_ksum(k+1,r)-sum(l,k')=sum(l,r)-2sum(l,k)=sum(l,r)-2s_{k'}\le sum(l,r)-2s_k=x+y,故 f(l,r)=x+y=\max\limits_{k=l}^{r-1}\{sum(k+1,r)-sum(l,k)\}

于是最终答案即为 \max\limits_{1\leq l\leq r\leq 2(n-1)}\max\limits_{k=l}^{r-1}\{sum(k+1,r)-sum(l,k)\},也就是第一篇题解中所说的“选择相邻两段并做差的最大值”。

这个就可以用线段树维护了,每个节点 [l,r] 维护以下八个值:

以上八个标记都可 \mathcal O(1) pushup,具体见代码。

至于修改……这个就相当容易了罢,直接单点修改即可。

时间复杂度线对。

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;}
template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;}
typedef pair<int,int> pii;
typedef long long ll;
typedef unsigned int u32;
typedef unsigned long long u64;
namespace fastio{
    #define FILE_SIZE 1<<23
    char rbuf[FILE_SIZE],*p1=rbuf,*p2=rbuf,wbuf[FILE_SIZE],*p3=wbuf;
    inline char getc(){return p1==p2&&(p2=(p1=rbuf)+fread(rbuf,1,FILE_SIZE,stdin),p1==p2)?-1:*p1++;}
    inline void putc(char x){(*p3++=x);}
    template<typename T> void read(T &x){
        x=0;char c=getchar();T neg=0;
        while(!isdigit(c)) neg|=!(c^'-'),c=getchar();
        while(isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();
        if(neg) x=(~x)+1;
    }
    template<typename T> void recursive_print(T x){if(!x) return;recursive_print(x/10);putc(x%10^48);}
    template<typename T> void print(T x){if(!x) putc('0');if(x<0) putc('-'),x=~x+1;recursive_print(x);}
    void print_final(){fwrite(wbuf,1,p3-wbuf,stdout);}
}
const int MAXN=2e5;
int n,qu;char str[MAXN+5];
struct node{int l,r,sum,lmx,rmx,lmn,rmn,mx1,mx2,mx;} s[MAXN*4+5];
void pushup(int k){
    s[k].sum=s[k<<1].sum+s[k<<1|1].sum;
    s[k].lmx=max(s[k<<1].lmx,s[k<<1].sum+s[k<<1|1].lmx);
    s[k].lmn=min(s[k<<1].lmn,s[k<<1].sum+s[k<<1|1].lmn);
    s[k].rmx=max(s[k<<1|1].rmx,s[k<<1|1].sum+s[k<<1].rmx);
    s[k].rmn=min(s[k<<1|1].rmn,s[k<<1|1].sum+s[k<<1].rmn);
    s[k].mx1=max(s[k<<1].mx1,max(-s[k<<1].sum+s[k<<1|1].mx1,s[k<<1|1].lmx+s[k<<1].rmx*2-s[k<<1].sum));
    s[k].mx2=max(s[k<<1|1].mx2,max(s[k<<1|1].sum+s[k<<1].mx2,-s[k<<1].rmn+s[k<<1|1].sum-2*s[k<<1|1].lmn));
    s[k].mx=max(max(s[k<<1].mx,s[k<<1|1].mx),max(s[k<<1].mx2+s[k<<1|1].lmx,-s[k<<1].rmn+s[k<<1|1].mx1));
}
void build(int k=1,int l=1,int r=n){
    s[k].l=l;s[k].r=r;
    if(l==r){
        if(str[l]=='(') s[k].sum=1,s[k].lmx=1,s[k].rmx=1,s[k].lmn=0,s[k].rmn=0;
        else s[k].sum=-1,s[k].lmx=0,s[k].rmx=0,s[k].lmn=-1,s[k].rmn=-1;
        s[k].mx1=1;s[k].mx2=1;s[k].mx=1;return;
    } int mid=l+r>>1;build(k<<1,l,mid);build(k<<1|1,mid+1,r);
    pushup(k);
}
void modify(int k,int p,int v){
    if(s[k].l==s[k].r){
        s[k].lmx=s[k].rmx=max(v,0);
        s[k].lmn=s[k].rmn=min(v,0);
        s[k].sum=v;
        s[k].mx1=1;s[k].mx2=1;s[k].mx=1;return;
    } int mid=s[k].l+s[k].r>>1;
    if(p<=mid) modify(k<<1,p,v);
    else modify(k<<1|1,p,v);
    pushup(k);
}
int main(){
    scanf("%d%d%s",&n,&qu,str+1);n=(n-1)<<1;
    build(1,1,n);printf("%d\n",s[1].mx);
    while(qu--){
        int x,y;scanf("%d%d",&x,&y);
        if(str[x]!=str[y]){
            swap(str[x],str[y]);
            modify(1,x,(str[x]=='(')?1:-1);
            modify(1,y,(str[y]=='(')?1:-1);
        } printf("%d\n",s[1].mx);
    }
    return 0;
}