P11678 [USACO25JAN] Watering the Plants P

· · 题解

P11678 [USACO25JAN] Watering the Plants P

题意:

n 个植物和 n-1 个水管,第 i 个水管可以给 i,i+1 两个植物一起提供任意单位的水,每单位水的花费是 c_i。第 i 个植物至少需要 w_i 单位的水。

对于每个 1 \le i < n 求只考虑前 i 个水管,满足前 i+1 个植物的最小花费。

**思路:** 注意到这个问题很想网络流,但是至少 $w_i$ 单位的水不太好做,考虑对偶一下: 原来的限制为: $$ 1 \le i < n,x_i + x_{i+1} \ge w_i\\ \text{minimize} \sum_{i=1}^{n-1}c_ix_i $$ 对偶之后变成: $$ 1 \le i < n, y_i + y_{i+1} \le c_i\\ \text{maximize}\sum_{i=1}^n w_iy_i $$ 这样我们就可以考虑费用流模型,和种树那题很想,大致如下: ![](https://cdn.luogu.com.cn/upload/image_hosting/v2be87ck.png?x-oss-process=image/resize,m_lfit,h_170,w_225) 但是显然数据范围不支持我们真的去跑费用流。所以我们只能贪心或者 dp 凸优化。这里我们考虑 dp 凸优化。 设 $f_i(j)$ 表示前 $i$ 个,满足 $y_i = j$ 的最大的 $\sum w_iy_i$。我们发现随着 $j$ 的增大 $f_i(j)$ 是凸函数。 这相当于不断给最后一条边加 $1$ 的流量,由于是费用流,所以增广路不会增大,所以是凸的。 转移如下: $$ f_i(j) = j \times w_i + \max_{k \le c_i - j}f_{i-1}(k) $$ 假设我们维护了 $f_{i-1}$ 的凸函数,那么这相当于以下三件事: - 将凸函数范围改成 $[0,c_i]$。 - 反转凸函数。 - 整体取前缀 $\max$。 - 加上直线 $y=w_ix$。 我们可以用 FHQ_Treap 来维护,一个节点储存差分相等的区间的长度,同时还要储存差分是多少,以及差分乘上区间长度的值。 我们要支持全局加法,全局反转,打两个 tag 就行。 当然这里其实可以用官方解法那样用双端队列就行,考场上没想那么多。时间复杂度 $O(n \log n)$,可以通过。 ```cpp #include <iostream> #include <cstdio> #include <queue> #include <ctime> #include <random> #include <vector> #include <cstring> #include <algorithm> using namespace std; const int N = 5e5 + 5; const int A = 1e6 + 5; typedef long long ll; #define debug(x) cout << #x << "=" << x << endl int n; int w[N] = {0}, c[N] = {0}; struct LazyTag { int rev; ll add;//翻转后加上 add LazyTag (int _rev = 0, ll _add = 0ll) : rev(_rev), add(_add) {} }; LazyTag operator+(LazyTag x, LazyTag y) { return LazyTag((x.rev ^ y.rev), y.add + (y.rev ? -x.add : x.add)); } struct Node { int ch[2], pri, sz; ll d, p; ll sd, sp, sum;//d之和, p 之和, p * d之和 LazyTag tag; Node () { ch[0] = ch[1] = sz = 0, pri = -1; d = p = sd = sp = sum = 0ll; tag = LazyTag(); } Node (ll _d, ll _p) { sd = d = _d, sp = p = _p, sum = p * d; ch[0] = ch[1] = 0, sz = 1, pri = rand(); tag = LazyTag(); } } a[N]; #define ls(x) a[x].ch[0] #define rs(x) a[x].ch[1] void upd(Node &x, Node &y, Node &z) { x.sd = y.sd + z.sd + x.d; x.sp = y.sp + z.sp + x.p; x.sum = y.sum + z.sum + x.d * x.p; x.sz = y.sz + z.sz + 1; } void pushup(int x) { upd(a[x], a[ls(x)], a[rs(x)]); } void mdf(Node &x, LazyTag v) { if (v.rev) { x.d = -x.d; x.sd = -x.sd; x.sum = -x.sum; swap(x.ch[0], x.ch[1]); } x.d += v.add; x.sd += v.add * x.sz; x.sum += v.add * x.sp; x.tag = x.tag + v; } void pushtag(int x, LazyTag v) { mdf(a[x], v); } void pushdown(int x) { pushtag(ls(x), a[x].tag); pushtag(rs(x), a[x].tag); a[x].tag = LazyTag(); } void spt_sp(int x, ll v, int &L, int &R) { if (x == 0) { L = R = 0; return; } pushdown(x); if (a[ls(x)].sp + a[x].p <= v) L = x, spt_sp(rs(x), v - a[ls(x)].sp - a[x].p, rs(x), R); else R = x, spt_sp(ls(x), v, L, ls(x)); pushup(x); } void spt_d(int x, ll v, int &L, int &R) { if (x == 0) { L = R = 0; return; } pushdown(x); if (a[x].d > v) L = x, spt_d(rs(x), v, rs(x), R); else R = x, spt_d(ls(x), v, L, ls(x)); pushup(x); } void spt_sz(int x, int k, int &L, int &R) { if (x == 0) { L = R = 0; return; } pushdown(x); if (a[ls(x)].sz + 1 <= k) L = x, spt_sz(rs(x), k - a[ls(x)].sz - 1, rs(x), R); else R = x, spt_sz(ls(x), k, L, ls(x)); pushup(x); } int mrg(int L, int R) { if (L == 0 || R == 0) return L + R; if (a[L].pri > a[R].pri) { pushdown(L); rs(L) = mrg(rs(L), R); pushup(L); return L; } else { pushdown(R); ls(R) = mrg(L, ls(R)); pushup(R); return R; } } int rt;//当前的根 ll res;//初始截距 int tot = 0;//节点个数 void show(int x) { if (x == 0) return; pushdown(x); show(ls(x)); printf("(%lld, %lld) ", a[x].p, a[x].d); show(rs(x)); } void upd_lim(ll k) {//只保留 <= k 的部分 int L, M, R; spt_sp(rt, k, L, R); // debug(k); // show(L);cout << endl; show(R);cout << endl; if (!R) {//表示当前不够了 // cout << a[L].sp << endl; a[++tot] = Node(0ll, k - a[L].sp); rt = mrg(L, tot); // show(rt);cout << endl; return; } spt_sz(R, 1, M, R); a[M] = Node(a[M].d, k - a[L].sp); rt = mrg(L, M); } void upd_rev(ll v) {//整体翻转然后所有数加上 v res += a[rt].sum;//截距变化 // show(rt);cout << endl; // debug(rt); pushtag(rt, LazyTag(1, v)); // debug(v); // show(rt);cout << endl; } void upd_min() {//取前缀 max,只保留 d > 0 的所有部分 int L, R; // show(rt);cout << endl; spt_d(rt, 0ll, L, R); rt = L; } #undef ls #undef rs int main() { // freopen("a.in", "r", stdin); // freopen("a.out", "w", stdout); scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%d", &w[i]); for (int i = 1; i < n; i++) scanf("%d", &c[i]); res = 0ll; rt = ++tot; a[rt] = Node(w[1], 1000000ll); for (int i = 2; i <= n; i++) { // if (i == 9) // cout << "\n\n\n\n"; upd_lim(c[i - 1]); upd_rev(w[i]); upd_min(); // show(rt); // printf("res: %lld\n", res); printf("%lld\n", res + a[rt].sum); // cout << "nxt \n\n" << endl; // if (i == 9) // break; } return 0; } ```