[ZJOI2019] Minimax搜索

· · 题解

一个 O(n\log n) 的做法,单纯的线段树合并,思维代码都很简单。

考虑先计算出最后的值 x,若 S 包含了 x,则 w(S)=1。我们考虑 S 不包含 x 的情况。

考虑将 x 到 1 号节点的路径剖出来,这个路径将整个树分成了若干个小树。对应的,一个集合 S 就被分到了若干颗树上。

若最后 1 号节点的值不同,那么 x 的值需要在某一点的时候被替换掉。若该层为 \max,则需要满足其对应的小数中有一个大于 x,反之小于 x,我们称这个要求为“小树的要求”。设 S=\cup T_i,若 w(T_i) 表示满足对应小树的要求所需要的最少的能量,那么明显 w(S)=\min w(T_i)

考虑一颗小树,我们计算一个序列 a_i,表示 w(S)=iS 的数量。设该小树的根是 rt,父亲是 u,同时父亲是 \min 类型。当然父亲是 \max 同理。

此时,对于 rt 来说,S 又被分到了若干个子树中,设他们为 T_i。由于 rt 的类型一定是 \max,所以要求所有子树最后的值都小于 x,所以 w(S)=\max w(T_i)

对于 rt 的儿子 u,由于 u 的类型一定是 \min,所以只要求某一个儿子的值小于 x,所以 w(S)=\min w(T_i)

通过找规律发现,最后我们的计算贡献一定是 \max ,\min 交替的。

由于每一个儿子至多只有两种贡献,一颗树有值的 a_i 的数量和该树的叶子数相等。

那么直接线段树合并就行了。

代码简洁易懂,应该一看就能明白:

#include<bits/stdc++.h>
using namespace std;
#define N 200005
#define p 998244353
#define ll long long
int n,L,R,x,dep[N],dad[N];
basic_string<int> G[N];
int ls[N<<5],rs[N<<5],tot;ll tr[N<<5],tag[N<<5];
int dfs(int u,int fa){
    dep[u]=dep[dad[u]=fa]+1;
    if(G[u].size()==1&&u!=1)return u;
    int ans;
    if(dep[u]&1){ans=0;for(int v:G[u])if(v!=fa)ans=max(ans,dfs(v,u));}
    else{ans=n;for(int v:G[u])if(v!=fa)ans=min(ans,dfs(v,u));}
    return ans;
}
inline void push(int u){
    if(tag[u]!=1)
        tag[ls[u]]=tag[ls[u]]*tag[u]%p,
        tag[rs[u]]=tag[rs[u]]*tag[u]%p,
        tr[ls[u]]=tr[ls[u]]*tag[u]%p,
        tr[rs[u]]=tr[rs[u]]*tag[u]%p,
        tag[u]=1;
}
inline void pull(int u){
    tr[u]=(tr[ls[u]]+tr[rs[u]])%p;
}
void upd(int &u,int l,int r,int x,int y){
    if(!u)u=++tot,tag[u]=1;
    if(l==r)return tr[u]=y,void();
    int m=(l+r)>>1;
    if(m>=x)upd(ls[u],l,m,x,y);
    else upd(rs[u],m+1,r,x,y);
    pull(u);
}
int merge1(int u,int v,int l,int r,ll mn1,ll mn2){//max 合并
    if(!u||!v){
        if(!u)return tr[v]=tr[v]*mn1%p,tag[v]=tag[v]*mn1%p,v;
        else return tr[u]=tr[u]*mn2%p,tag[u]=tag[u]*mn2%p,u;
    }
    if(l==r)return tr[u]=(tr[u]*mn2+tr[v]*mn1+tr[u]*tr[v])%p,u;
    push(u);push(v);
    int m=(l+r)>>1;
    rs[u]=merge1(rs[u],rs[v],m+1,r,(mn1+tr[ls[u]])%p,(mn2+tr[ls[v]])%p);
    ls[u]=merge1(ls[u],ls[v],l,m,mn1,mn2);
    return pull(u),u;
}
int merge2(int u,int v,int l,int r,ll mx1,ll mx2){//min 合并
    if(!u||!v){
        if(!u)return tr[v]=tr[v]*mx1%p,tag[v]=tag[v]*mx1%p,v;
        else return tr[u]=tr[u]*mx2%p,tag[u]=tag[u]*mx2%p,u;
    }
    if(l==r)return tr[u]=(tr[u]*mx2+tr[v]*(mx1+tr[u]))%p,u;
    push(u);push(v);
    int m=(l+r)>>1;
    ls[u]=merge2(ls[u],ls[v],l,m,(mx1+tr[rs[u]])%p,(mx2+tr[rs[v]])%p);
    rs[u]=merge2(rs[u],rs[v],m+1,r,mx1,mx2);
    return pull(u),u;
}
void output(int u,int l,int r,int L,int R){
    if(l==r)return printf("%lld ",(tr[u]-(l==n)+p)%p),void();
    int m=(l+r)>>1;push(u);
    if(m>=L)output(ls[u],l,m,L,R);
    if(m<R)output(rs[u],m+1,r,L,R);
}
int dfs(int u,int fa,int opt){
    int root=0;
    if(G[u].size()==1){
        if(!opt){//需要小于 x
            if(u<x)upd(root,0,n,0,2);
            else upd(root,0,n,n,1),upd(root,0,n,u-(x-1),1);
        }
        else{//需要大于 x
            if(u>x)upd(root,0,n,0,2);
            else upd(root,0,n,n,1),upd(root,0,n,(x+1)-u,1);
        }
    }else{
        if((dep[u]&1)^opt){//max 类型
            upd(root,0,n,0,1);
            for(int v:G[u])if(v!=fa)root=merge1(root,dfs(v,u,opt),0,n,0,0);
        }
        else{
            upd(root,0,n,n,1);//min 类型
            for(int v:G[u])if(v!=fa)root=merge2(root,dfs(v,u,opt),0,n,0,0);
        }    
    }
    return root;
}
int main(){
    scanf("%d%d%d",&n,&L,&R);
    for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),G[u]+=v,G[v]+=u;
    x=dfs(1,0);int rt=0,rt2;
    upd(rt,0,n,1,1);upd(rt,0,n,n,1);
    for(int i=x,rt2;dad[i];i=dad[i]){
        for(int v:G[dad[i]])if(v!=i&&v!=dad[dad[i]])
            rt2=dfs(v,dad[i],dep[dad[i]]&1),rt=merge2(rt,rt2,0,n,0,0);
    }
    output(rt,0,n,L,R);
}