题解:P11235 「KTSC 2024 R1」最大化平均值
KingPowers · · 题解
你们韩国人原来喜欢出这种题吗。
模拟赛题,调了一下午都在挂,最后发现是预处理少了个特判,发怒。
我们称一个区间是好的当且仅当其满足题目中『封闭序列』的定义,第一步观察就是发现这个条件是比较苛刻的,好的区间数量应该不会太多。
事实上确实不多,好的区间只有
对于
根据前面的过程很难不发现,好的区间之间要么包含么不交,不会出现相交而不包含的情况。这种结构一种经典想法是按照区间的包含关系建树,也就是每个区间以包含它的区间中长度最小的那个为它的父亲,这样会形成一个森林,可以手动加个根变成树。
当然我们不能暴力建树,一个方便的做法是把这些区间按照长度降序排序,对每个位置维护当前长度最小的覆盖它的区间是哪个,这样对每个区间的端点做个单点查询就找到了它的父亲,然后相当于要做一个区间染色把覆盖的编号改成自己,拉个 ODT 或者线段树来即可。
由于每次询问的区间一定也是好区间,所以现在相当于是每次选定一棵子树,可以在这棵子树内选择任意多个两两不成祖先后代关系的区间删去,最大化剩下的平均值。
一个暴力的想法是直接在这棵树上 dp,可以设
考虑在这个 dp 的基础上继续优化。我们要做的其实是对每个
因此只需要保留上凸壳上的点,具体来说,对于每个
但是本题是放到树上的,不太能直接存凸壳上的每个向量。我们使用平衡树来维护凸包上的向量,同时维护向量的前缀和,这样把儿子的向量归并可以直接 dsu on tree 来维护,维护整体平移只需要不断把平移的向量和第一个向量比较直到不破坏上凸时插入,查询答案的时候直接在平衡树上二分即可。
这样最后的复杂度是
代码不算太难写,大概 5k 左右的码量。我也在下面的代码里加入了一些注释,可以帮助理解维护的过程,还是比较清晰的。
#include<bits/stdc++.h>
#define int long long
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define x0 guanzhuyongchutafei
#define y0 xiexiemiao
using namespace std;
const int N = 1e6 + 5;
int n, a[N], sum[N];
int fa[N], st[N], ed[N], sz[N], son[N];
vector<int> e[N];
int S(int l, int r){return sum[r] - sum[l - 1];}
struct node{ //注意这个存的是向量不是点
int x, y;
node(int x = 0, int y = 0):x(x), y(y){}
friend int cross(const node &a, const node &b){return a.x * b.y - a.y * b.x;}
friend bool operator<(const node &a, const node &b){return cross(a, b) < 0;}
friend node operator+(const node &a, const node &b){return node(a.x + b.x, a.y + b.y);}
friend node operator-(const node &a, const node &b){return node(a.x - b.x, a.y - b.y);}
};
mt19937 nrd(114514);
struct treap{ //平衡树维护凸壳上的向量
int tot, top, sz[N], key[N], ls[N], rs[N], st[N];
node val[N], sum[N];
int newnode(node w){
int p = top ? st[top--] : ++tot; //注意 dsu on tree 不加垃圾回收空间复杂度是会多 log 的
ls[p] = rs[p] = 0;
key[p] = nrd(); sz[p] = 1;
return val[p] = sum[p] = w, p;
}
void pushup(int now){
sz[now] = sz[ls[now]] + sz[rs[now]] + 1;
sum[now] = sum[ls[now]] + sum[rs[now]] + val[now];
}
void split(int now, int k, int &x, int &y){
if(!now) return x = y = 0, void();
if(sz[ls[now]] + 1 <= k) x = now, split(rs[now], k - sz[ls[now]] - 1, rs[now], y);
else y = now, split(ls[now], k, x, ls[now]);
pushup(now);
}
int merge(int x, int y){
if(!x || !y) return x | y;
if(key[x] < key[y]) return rs[x] = merge(rs[x], y), pushup(x), x;
return ls[y] = merge(x, ls[y]), pushup(y), y;
}
int getmn(int now){if(!now) return 0; return !ls[now] ? now : getmn(ls[now]);}
void popmn(int &now){int x, y; split(now, 1, x, y); now = y;}
int getrk(int now, node w){
if(!now) return 0;
if(val[now] < w) return sz[ls[now]] + 1 + getrk(rs[now], w);
return getrk(ls[now], w);
}
void insert(int &now, node w){
int k = getrk(now, w), x, y;
split(now, k, x, y);
now = merge(merge(x, newnode(w)), y);
}
void del(int now, vector<node> &vec){
st[++top] = now;
if(ls[now]) del(ls[now], vec);
vec.push_back(val[now]);
if(rs[now]) del(rs[now], vec);
}
void solve(int now, node w, node &ans){ //平衡树上二分查询答案
if(!now) return;
node tmp = sum[ls[now]];
if(cross(w + tmp, val[now]) >= 0) ans = w + tmp + val[now], solve(rs[now], w + tmp + val[now], ans);
else solve(ls[now], w, ans);
}
}fhq;
struct ODT{ //ODT 维护连续段,用来建树
struct range{
int l, r, v;
range(int l = 0, int r = 0, int v = 0):l(l), r(r), v(v){}
friend bool operator<(const range& a, const range &b){
return a.l < b.l;
}
};
set<range> st;
using sit = set<range>::iterator;
ODT(){st.insert(range(1, n, 0));}
sit split(int pos){
if(pos == n + 1) return st.end();
auto it = st.lower_bound(range(pos));
if(it != st.end() && it->l == pos) return it;
--it; auto[l, r, v] = *it; st.erase(it);
st.insert(range(l, pos - 1, v));
return st.insert(range(pos, r, v)).first;
}
void assign(int l, int r, int v){
auto itr = split(r + 1), itl = split(l);
st.erase(itl, itr); st.insert(range(l, r, v));
}
int query(int pos){
auto it = st.lower_bound(range(pos));
if(it != st.end() && it->l == pos) return it->v;
--it; return it->v;
}
}odt;
map<pair<int, int>, node> ans;
int rt[N];
void dfs(int now){
sz[now] = 1;
int x0 = ed[now] - st[now] + 1, y0 = S(st[now], ed[now]); //这个表示区间内不会被删的点的贡献形成的向量
for(int to : e[now]){
dfs(to); sz[now] += sz[to];
if(sz[to] > sz[son[now]]) son[now] = to;
x0 -= ed[to] - st[to] + 1, y0 -= S(st[to], ed[to]);
}
if(son[now]) rt[now] = rt[son[now]];
vector<node> rub;
for(int to : e[now]){
if(to == son[now]) continue;
rub.clear(); fhq.del(rt[to], rub);
for(auto w : rub) fhq.insert(rt[now], w); //先直接把儿子的向量归并过来
}
int p = fhq.getmn(rt[now]);
while(p){
if(cross(fhq.val[p], node(x0, y0)) < 0){ //在前方加入平移的向量,删去前缀加入后下凸的边
x0 += fhq.val[p].x, y0 += fhq.val[p].y;
fhq.popmn(rt[now]); p = fhq.getmn(rt[now]);
}
else break;
}
fhq.insert(rt[now], node(x0, y0));
node tmp;
fhq.solve(rt[now], node(2, a[st[now] - 1] + a[ed[now] + 1]), tmp); //(2,a[l]+a[r]) 一定是这个凸壳上的第一个点,其它的点根据我们维护的向量可以还原出来
ans[{st[now] - 1, ed[now] + 1}] = tmp;
}
array<int, 2> maximum_average(signed i, signed j){
int x = 0, y = 0; i++, j++;
if(j == i + 1) x = a[i] + a[j], y = 2;
else x = ans[{i, j}].y, y = ans[{i, j}].x;
int d = __gcd(x, y);
return {x / d, y / d};
}
void initialize(vector<signed> A){
n = A.size();
For(i, 1, n) a[i] = A[i - 1], sum[i] = sum[i - 1] + a[i];
vector<pair<int, int>> vec;
static int top, stk[N];
For(i, 1, n){
while(top && a[stk[top]] >= a[i]){
if(stk[top] < i - 1) vec.emplace_back(stk[top] + 1, i - 1);
top--;
}
stk[++top] = i;
}
top = 0;
Rof(i, n, 1){
while(top && a[stk[top]] >= a[i]){
if(stk[top] > i + 1 && a[i] != a[stk[top]]) vec.emplace_back(i + 1, stk[top] - 1);
top--;
}
stk[++top] = i;
}
sort(vec.begin(), vec.end(), [](pair<int, int> a, pair<int, int> b){
int x = a.second - a.first + 1, y = b.second - b.first + 1;
return x > y;
});
int np = vec.size();
For(i, 1, np) st[i] = vec[i - 1].first, ed[i] = vec[i - 1].second;
For(i, 1, np){
e[fa[i] = odt.query(ed[i])].push_back(i);
odt.assign(st[i], ed[i], i);
}
st[0] = 1; ed[0] = n; dfs(0);
}