题解:P11235 「KTSC 2024 R1」最大化平均值

· · 题解

你们韩国人原来喜欢出这种题吗。

模拟赛题,调了一下午都在挂,最后发现是预处理少了个特判,发怒。

我们称一个区间是好的当且仅当其满足题目中『封闭序列』的定义,第一步观察就是发现这个条件是比较苛刻的,好的区间数量应该不会太多。

事实上确实不多,好的区间只有 O(n) 个,下面我们来说明这件事情。不妨假设好区间 [l,r] 满足 a_l\ge a_r,因为限制相当于是要区间的两端比中间都小,所以对于一个固定的 r 合法的 l 相当于是 [1,r-1] 里所有的后缀最小值的位置。而如果我们从左往右扫序列并同时维护一个元素递增的单调栈,容易发现此时合法的 l 就对应了加入 a_r 时栈内被弹出的位置。

对于 a_l<a_r 的情况是对称的,改成从右往左扫即可找出每个 l 所对应的 r(注意这里一定要判不能取等,你猜我前面一下午因为啥没调出来),所以说好的区间总数量应该是不超过 2n 的,同时我们可以很方便地把好区间全都求出来。

根据前面的过程很难不发现,好的区间之间要么包含么不交,不会出现相交而不包含的情况。这种结构一种经典想法是按照区间的包含关系建树,也就是每个区间以包含它的区间中长度最小的那个为它的父亲,这样会形成一个森林,可以手动加个根变成树。

当然我们不能暴力建树,一个方便的做法是把这些区间按照长度降序排序,对每个位置维护当前长度最小的覆盖它的区间是哪个,这样对每个区间的端点做个单点查询就找到了它的父亲,然后相当于要做一个区间染色把覆盖的编号改成自己,拉个 ODT 或者线段树来即可。

由于每次询问的区间一定也是好区间,所以现在相当于是每次选定一棵子树,可以在这棵子树内选择任意多个两两不成祖先后代关系的区间删去,最大化剩下的平均值。

一个暴力的想法是直接在这棵树上 dp,可以设 f_{u,i} 表示 u 子树内一共保留了 i 个数时留下的和的最大值。转移比较简单,先考虑 u 不选的情况就是把 u 所有儿子的背包拉出来做个 (\max,+) 卷积,然后因为不选 u 时区间内有些数一定不会被删掉所以要再做个类似整体加和下标平移的操作。而选了 u 时区间内会只剩下两个数,有 a_l+a_r\to f_{u,2}。直接实现这个 dp 的复杂度同树上背包,是 O(n^2) 的。

考虑在这个 dp 的基础上继续优化。我们要做的其实是对每个 u 查询 \frac{f_{u,i}}{i} 的最小值,不妨把每个 f_{u,i} 看成坐标系上的点 (i,f_{u,i}),其实就是在若干个点里找到一个点使它和原点连线的斜率最大。容易发现点集内不在上凸壳上的点一定不可能是最优的点,因为考虑三个下凸的点 A,B,C,不妨设 x_A<x_B<x_C,分类讨论后发现无论什么情况 A,C 两个点都至少有一个比 B 要优,所以可以归纳出答案只可能在上凸壳上的点取到。

因此只需要保留上凸壳上的点,具体来说,对于每个 u 我们维护上凸壳上每条边对应的向量,把儿子拉出来做 (\max,+) 卷积相当于对儿子的向量集合归并,整体加和整体平移就相当于是在最前方加入一个向量(加入这个向量后为了保持凸性可能会删掉前缀的一些边,这是好处理的)。查询时,相当于找到凸壳上最靠左的点使得再走下一条边会使答案变劣,这个是有单调性的可以直接二分。

但是本题是放到树上的,不太能直接存凸壳上的每个向量。我们使用平衡树来维护凸包上的向量,同时维护向量的前缀和,这样把儿子的向量归并可以直接 dsu on tree 来维护,维护整体平移只需要不断把平移的向量和第一个向量比较直到不破坏上凸时插入,查询答案的时候直接在平衡树上二分即可。

这样最后的复杂度是 O(n\log^2 n+q),因为我们可以预处理每个区间的答案查询时直接查表。

代码不算太难写,大概 5k 左右的码量。我也在下面的代码里加入了一些注释,可以帮助理解维护的过程,还是比较清晰的。

#include<bits/stdc++.h>
#define int long long
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define x0 guanzhuyongchutafei
#define y0 xiexiemiao
using namespace std;
const int N = 1e6 + 5;
int n, a[N], sum[N]; 
int fa[N], st[N], ed[N], sz[N], son[N];
vector<int> e[N];
int S(int l, int r){return sum[r] - sum[l - 1];}
struct node{  //注意这个存的是向量不是点
    int x, y;
    node(int x = 0, int y = 0):x(x), y(y){}
    friend int cross(const node &a, const node &b){return a.x * b.y - a.y * b.x;}
    friend bool operator<(const node &a, const node &b){return cross(a, b) < 0;}
    friend node operator+(const node &a, const node &b){return node(a.x + b.x, a.y + b.y);}
    friend node operator-(const node &a, const node &b){return node(a.x - b.x, a.y - b.y);}
};
mt19937 nrd(114514);
struct treap{  //平衡树维护凸壳上的向量
    int tot, top, sz[N], key[N], ls[N], rs[N], st[N];
    node val[N], sum[N];
    int newnode(node w){
        int p = top ? st[top--] : ++tot;  //注意 dsu on tree 不加垃圾回收空间复杂度是会多 log 的
        ls[p] = rs[p] = 0;
        key[p] = nrd(); sz[p] = 1;
        return val[p] = sum[p] = w, p;
    }
    void pushup(int now){
        sz[now] = sz[ls[now]] + sz[rs[now]] + 1;
        sum[now] = sum[ls[now]] + sum[rs[now]] + val[now];
    }
    void split(int now, int k, int &x, int &y){
        if(!now) return x = y = 0, void();
        if(sz[ls[now]] + 1 <= k) x = now, split(rs[now], k - sz[ls[now]] - 1, rs[now], y);
        else y = now, split(ls[now], k, x, ls[now]);
        pushup(now);
    }
    int merge(int x, int y){
        if(!x || !y) return x | y;
        if(key[x] < key[y]) return rs[x] = merge(rs[x], y), pushup(x), x;
        return ls[y] = merge(x, ls[y]), pushup(y), y;
    }
    int getmn(int now){if(!now) return 0; return !ls[now] ? now : getmn(ls[now]);}
    void popmn(int &now){int x, y; split(now, 1, x, y); now = y;}
    int getrk(int now, node w){
        if(!now) return 0;
        if(val[now] < w) return sz[ls[now]] + 1 + getrk(rs[now], w);
        return getrk(ls[now], w);
    }
    void insert(int &now, node w){
        int k = getrk(now, w), x, y;
        split(now, k, x, y);
        now = merge(merge(x, newnode(w)), y);
    }
    void del(int now, vector<node> &vec){
        st[++top] = now;
        if(ls[now]) del(ls[now], vec);
        vec.push_back(val[now]);
        if(rs[now]) del(rs[now], vec);
    }
    void solve(int now, node w, node &ans){  //平衡树上二分查询答案
        if(!now) return;
        node tmp = sum[ls[now]];
        if(cross(w + tmp, val[now]) >= 0) ans = w + tmp + val[now], solve(rs[now], w + tmp + val[now], ans);
        else solve(ls[now], w, ans);
    }
}fhq;
struct ODT{  //ODT 维护连续段,用来建树
    struct range{
        int l, r, v;
        range(int l = 0, int r = 0, int v = 0):l(l), r(r), v(v){}
        friend bool operator<(const range& a, const range &b){
            return a.l < b.l;
        }
    };
    set<range> st;
    using sit = set<range>::iterator;
    ODT(){st.insert(range(1, n, 0));}
    sit split(int pos){
        if(pos == n + 1) return st.end();
        auto it = st.lower_bound(range(pos));
        if(it != st.end() && it->l == pos) return it;
        --it; auto[l, r, v] = *it; st.erase(it);
        st.insert(range(l, pos - 1, v));
        return st.insert(range(pos, r, v)).first;
    }
    void assign(int l, int r, int v){
        auto itr = split(r + 1), itl = split(l);
        st.erase(itl, itr); st.insert(range(l, r, v));
    }
    int query(int pos){
        auto it = st.lower_bound(range(pos));
        if(it != st.end() && it->l == pos) return it->v;
        --it; return it->v;
    }
}odt;
map<pair<int, int>, node> ans;
int rt[N];
void dfs(int now){
    sz[now] = 1;
    int x0 = ed[now] - st[now] + 1, y0 = S(st[now], ed[now]);  //这个表示区间内不会被删的点的贡献形成的向量
    for(int to : e[now]){
        dfs(to); sz[now] += sz[to];
        if(sz[to] > sz[son[now]]) son[now] = to;
        x0 -= ed[to] - st[to] + 1, y0 -= S(st[to], ed[to]);
    }
    if(son[now]) rt[now] = rt[son[now]];
    vector<node> rub;
    for(int to : e[now]){
        if(to == son[now]) continue;
        rub.clear(); fhq.del(rt[to], rub);
        for(auto w : rub) fhq.insert(rt[now], w);  //先直接把儿子的向量归并过来
    }
    int p = fhq.getmn(rt[now]);
    while(p){
        if(cross(fhq.val[p], node(x0, y0)) < 0){  //在前方加入平移的向量,删去前缀加入后下凸的边
            x0 += fhq.val[p].x, y0 += fhq.val[p].y;
            fhq.popmn(rt[now]); p = fhq.getmn(rt[now]);
        }
        else break;
    }
    fhq.insert(rt[now], node(x0, y0));
    node tmp;
    fhq.solve(rt[now], node(2, a[st[now] - 1] + a[ed[now] + 1]), tmp);  //(2,a[l]+a[r]) 一定是这个凸壳上的第一个点,其它的点根据我们维护的向量可以还原出来
    ans[{st[now] - 1, ed[now] + 1}] = tmp;
}
array<int, 2> maximum_average(signed i, signed j){
    int x = 0, y = 0; i++, j++;
    if(j == i + 1) x = a[i] + a[j], y = 2;
    else x = ans[{i, j}].y, y = ans[{i, j}].x;
    int d = __gcd(x, y);
    return {x / d, y / d};
}
void initialize(vector<signed> A){
    n = A.size();
    For(i, 1, n) a[i] = A[i - 1], sum[i] = sum[i - 1] + a[i];
    vector<pair<int, int>> vec;
    static int top, stk[N];
    For(i, 1, n){
        while(top && a[stk[top]] >= a[i]){
            if(stk[top] < i - 1) vec.emplace_back(stk[top] + 1, i - 1);
            top--;
        }
        stk[++top] = i;
    }
    top = 0;
    Rof(i, n, 1){
        while(top && a[stk[top]] >= a[i]){
            if(stk[top] > i + 1 && a[i] != a[stk[top]]) vec.emplace_back(i + 1, stk[top] - 1);
            top--;
        }
        stk[++top] = i;
    }
    sort(vec.begin(), vec.end(), [](pair<int, int> a, pair<int, int> b){
        int x = a.second - a.first + 1, y = b.second - b.first + 1;
        return x > y;
    });
    int np = vec.size();
    For(i, 1, np) st[i] = vec[i - 1].first, ed[i] = vec[i - 1].second;
    For(i, 1, np){
        e[fa[i] = odt.query(ed[i])].push_back(i);
        odt.assign(st[i], ed[i], i);
    }
    st[0] = 1; ed[0] = n; dfs(0);
}