题解:P11235 [KTSC 2024 R1] 最大化平均值

· · 题解

题目给出了每次询问的 l,r 满足是封闭序列,考虑有什么作用。可以发现对于任意两个封闭序列,他们相交的长度是 \le 1 的。进一步的,不难发现,封闭序列最多只有 n 个,因为对于一个封闭序列 [l,r],它只会向一个方向进行拓展,且拓展的总量是 O(n) 的。

这启发我们对于封闭序列的包含关系进行建树。这个是好做的,可以用单调栈+线段树解决。

设有 tot封闭序列,考虑用 [1,tot] 依次给这 tot 个序列标号。

然后考虑 dp,设 f_{u,i} 表示编号为 u 的子树中,选了 i 个点(不包含 u 对应区间的两个端点),和最大是多少。

对于 f_{u,0},其值就是 0

对于 i>0 的转移,就是儿子先做 (\max,+) 卷积,然后再考虑 u 所有儿子对应的区间的两个端点的贡献。

然后考虑优化,发现我们要求的是 \frac{f_{u,i}}{i} 的最大值,而将 (i,f_{u,i}) 视作二维平面上的点后,我们发现,如果存在 3 个点(按 x 轴递增依次是 A,B,C),使得这 3 个点的坐标构成一个下凸壳,则 C 一定是不优于 A,B 中的一个的,这是容易证明的,考虑如果 A 不优于 C,那么 B 就优于 C 了。

所以 dp 数组中有用的值会构成一个上凸壳。

考虑维护其对应的斜率,只不过这里斜率是一个向量 (sum,cnt)

然后重定义一些运算,我们称向量 (sum,cnt) 小于等于 (sum',cnt') 当且仅当 \frac{sum}{cnt}\le \frac{sum'}{cnt'},加法是两维直接加。

然后 (\max,+) 卷积直接启发式合并+平衡树,这里用平衡树是为了等会儿算答案方便一点。

然后还要考虑 u 所有儿子对应的区间的两个端点的贡献,设其对应的向量是 p

p 加入当前维护的凸包的开头,然后把不满足上凸的点与 p 合并,一直到不能操作为止。

然后就是算答案,这一步是简单的,可以直接二分一个位置,然后看这个位置是否还能继续拓展,平衡树上是好操作的。

由于启发式合并了,所以时间复杂度 O(n\log^2 n)

#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define fir first
#define sec second
#define mk make_pair
using namespace std;
inline int read(){
    int x=0;bool f=0;char ch=getchar();
    while(ch<'0'||ch>'9')f^=(ch=='-'),ch=getchar();
    while('0'<=ch&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return f?-x:x;
}
bool st;
const int Maxn=3e5+5;
int n,a[Maxn];
ll sum[Maxn];
inline ll getsum(int l,int r){return sum[r]-sum[l-1];}
struct Node{
    int l,r;
}g[Maxn],gg[Maxn];
int stk[Maxn],len;
map<pii,int>id,vis;
int fa[Maxn],tot;
int si[Maxn],son[Maxn];
vector<int>G[Maxn];
struct seg{
    struct Tree{
        int tag,val;
    }t[Maxn<<2];
    void build(int x,int l,int r){
        t[x].val=t[x].tag=tot+1;if(l==r)return;
        int mid=l+r>>1;
        build(x<<1,l,mid);build(x<<1|1,mid+1,r);
    }
    inline void add_(int x,int p){
        t[x].val=min(t[x].val,p);
        t[x].tag=min(t[x].tag,p);
    }
    inline void spread(int x){
        add_(x<<1,t[x].tag);add_(x<<1|1,t[x].tag);
    }
    void change(int x,int l,int r,int L,int R,int p){
        if(L<=l&&r<=R)return void(add_(x,p));
        int mid=l+r>>1;spread(x);
        if(mid>=L)change(x<<1,l,mid,L,R,p);
        if(mid<R)change(x<<1|1,mid+1,r,L,R,p);
        t[x].val=min(t[x<<1].val,t[x<<1|1].val);
    }
    int query(int x,int l,int r,int L,int R){
        if(L<=l&&r<=R)return t[x].val;
        int mid=l+r>>1,res=n+1;spread(x);
        if(mid>=L)res=query(x<<1,l,mid,L,R);
        if(mid<R)res=min(res,query(x<<1|1,mid+1,r,L,R));
        return res;
    }
}T;
struct node{
    ll sum,cnt;
    inline node operator+(const node&b)const{
        return (node){sum+b.sum,cnt+b.cnt};
    }
    inline bool operator<(const node&b)const{
        return sum*b.cnt<b.sum*cnt;
    }
    inline bool operator<=(const node&b)const{
        return sum*b.cnt<=b.sum*cnt;
    }
}ans[Maxn];
#define rnd() abs((int)_rnd())
mt19937 _rnd(147744151);
struct treap{
    int rt[Maxn],cnt;
    struct tree{
        int ls,rs,fa,key;
        node val,sum;
    }t[Maxn<<5];
    inline void update(int x){
        if(!x)return;
        t[x].fa=0;
        t[t[x].ls].fa=x;t[t[x].rs].fa=x;
        t[x].sum=t[x].val+t[t[x].ls].sum+t[t[x].rs].sum;
    }
    void split2(int root,int&x,int&y,node val){
        if(!root)return void(x=y=0);
        if(t[root].val<val){
            x=root;split2(t[x].rs,t[x].rs,y,val);
        }else{
            y=root;split2(t[y].ls,x,t[y].ls,val);
        }
        update(root);
    }
    void split(int root,int&x,int&y,node val){
        if(!root)return void(x=y=0);
        if(t[root].val<=val){
            x=root;split(t[x].rs,t[x].rs,y,val);
        }else{
            y=root;split(t[y].ls,x,t[y].ls,val);
        }
        update(root);
    }
    void split_(int root,int&x,int&y,int p){
        if(!root)return void(x=y=0);
        if(root!=p){
            x=root;split_(t[x].rs,t[x].rs,y,p);
        }else{
            y=root;split_(t[y].ls,x,t[y].ls,p);
        }
        update(root);
    }
    int merge(int x,int y){
        if(!x||!y)return x^y;
        if(t[x].key>t[y].key){
            t[x].rs=merge(t[x].rs,y);
            update(x);
            return x;
        }
        t[y].ls=merge(x,t[y].ls);
        update(y);
        return y;
    }
    inline void insert(int&root,int now,node val){
        now=++cnt;
        t[now]={0,0,0,rnd(),val,val};
        int x,y;
        split(root,x,y,val);
        root=merge(merge(x,now),y);
    }
    inline void move(int x,int&root){
        if(!x)return;
        insert(root,x,t[x].val);
        move(t[x].ls,root);
        move(t[x].rs,root);
    }
    inline int getmx(int x){
        while(t[x].rs)x=t[x].rs;
//      printf("mn %d\n",x);
        return x;
    }
    inline void del(int&root,int p){
        int x,y;
        split_(root,x,y,p);
        root=x;
    }
    node calc(int x,node tp){
        if(!x)return tp;
        if(tp+t[t[x].rs].sum<t[x].val){
            tp=tp+t[x].val+t[t[x].rs].sum;
            return calc(t[x].ls,tp);
        }
        return calc(t[x].rs,tp);
    }
}fhq;

void dfs(int u){
    node tp={getsum(gg[u].l+1,gg[u].r-1),gg[u].r-gg[u].l-1};
    si[u]=1;son[u]=0;
    for(int y:G[u]){
        dfs(y);si[u]+=si[y];
        if(si[son[u]]<si[y])son[u]=y;
        tp.sum-=getsum(gg[y].l+1,gg[y].r-1);
        tp.cnt-=gg[y].r-gg[y].l-1;
    }if(u==tot+1)return;
    fhq.rt[u]=fhq.rt[son[u]];
    for(int y:G[u])if(y!=son[u]){
        fhq.move(fhq.rt[y],fhq.rt[u]);
    }
    int p=fhq.getmx(fhq.rt[u]);
    while(p&&tp<fhq.t[p].val){
        tp=tp+fhq.t[p].val;
        fhq.del(fhq.rt[u],p),p=fhq.getmx(fhq.rt[u]);
    }
    fhq.insert(fhq.rt[u],0,tp);
    tp={a[gg[u].l]+a[gg[u].r],2};
    ans[u]=fhq.calc(fhq.rt[u],tp);
}
bool en;

void initialize(vector<int>A){
    n=A.size();
    for(int i=1;i<=n;i++){
        a[i]=A[i-1];sum[i]=sum[i-1]+a[i];
        while(len&&a[stk[len]]>=a[i])len--;
        g[i].l=stk[len];
        stk[++len]=i;
    }
    len=0;
    for(int i=n;i;i--){
        while(len&&a[stk[len]]>=a[i])len--;
        g[i].r=stk[len];
        stk[++len]=i;
    }
    sort(g+1,g+1+n,[&](Node a,Node b){
        if(a.r==b.r)return a.l>b.l;
        return a.r<b.r;
    });

    for(int i=1;i<=n;i++)if(g[i].l&&g[i].r&&!id[mk(g[i].l,g[i].r)])id[mk(g[i].l,g[i].r)]=++tot;
    T.build(1,1,n);
    for(int i=n;i;i--)if(!vis[mk(g[i].l,g[i].r)]&&id[mk(g[i].l,g[i].r)]){
        vis[mk(g[i].l,g[i].r)]=1;
        int p=id[mk(g[i].l,g[i].r)];
        gg[p]=g[i];
        fa[p]=T.query(1,1,n,g[i].l+1,g[i].r-1);
        G[fa[p]].push_back(p);
        T.change(1,1,n,g[i].l+1,g[i].r-1,p);
    }
    dfs(tot+1);
}
array<ll,2>maximum_average(int l,int r){l++;r++;
    if(r-l+1==2){
        ll x=a[l]+a[r],y=2,tpp=__gcd(x,y);
        x/=tpp;y/=tpp;
        return {x,y};
    }
    int p=id[mk(l,r)];
    ll x=ans[p].sum,y=ans[p].cnt,tpp=__gcd(x,y);
    x/=tpp;y/=tpp;
    return {x,y};
}
/*
g++ seq.cpp -o seq -std=c++14 -O2
./ seq
*/