题解 [ARC066D] Contest with Drinks Hard

皎月半洒花

2020-02-23 16:33:57

Solution

很不懂啊,另一篇题解的内容完全和cdsfcesf的课件内容一模一样,就几乎是一个字也没改,不是很明白这个操作…… > 有 $n$ 个点,选择第 $i$ 个点有 $a_i$ 的代价。如果一个极大的、长度为 $L$ 的区间内的点都被选了,这个区间会带来$\frac{L^2+L}{2}$ 的收益。 > > > 形式化地讲,给定 $\{t_n\}$ ,求一组 $\{x_n\},\forall i \in[1, n], x_{i} \in[0,1]$ 使可以得到: $$\max \left\{\sum_{i=1}^{n} \sum_{j=i}^{n}\left(\prod_{k=i}^{j} x_{k}\right)-\sum_{i=1}^{n} x_{i} t_{i}\right\}$$ > > 询问共 $q$ 组,为对 $a_i$ 的单点修改。 询问彼此独立。 > $n, q ≤ 3 × 10^5$ 一道神题…本质上虽然不是很难,但是碍于个人对斜率优化这个东西掌握的实在十分差劲,所以几乎把所有能踩的坑全踩了一遍。 # $\rm Sol$ 首先考虑对于每个询问,主要思想是维护一个前缀的 $f_i$ 表示不选 $i$ ,前 $i-1$ 个元素的最优解;$g_i$ 则是用来维护一个后缀,表示不选 $i$,$i+1\sim n$ 中元素的最优解。同时,再维护一个 $h_i$ 表示选了 $i$ 的全局最优解。那么这样就可以如此回答询问: ```cpp while (m --){ scanf("%d%d", &p, &x) ; printf("%lld\n", max(f[p] + g[p], o[p] + base[p] - x)) ; } ``` 于是考虑这三个东西应该怎么转移。先考虑 $f$ : $$\begin{aligned} f_i&=\max\{f_j+\frac{(i-j)(i-j-1)}{2}-(s_i-s_j)\}\\ f_i&=\max\{f_i,f_{i-1}\} \end{aligned}$$ 枚举的是 $[j,i-1]$ 这段区间。 变形可得 $$ f_j+s_j+\frac{j^2+j}{2}=i\cdot j+f_i+s_i- \frac{i^2-i}{2} $$ 于是就发现,这是一个以 $f_j+s_j+\frac{j^2+j}{2}$ 为纵坐标,$i$ 为斜率的斜率优化。观察可知,横坐标 $j$ 单增,同时斜率 $i$ 也单增。所以对于这种类型的 $dp$ 只能用**单调栈**来维护。 同时,关于 $g$,只需要倒着做一遍 $f$ 的 $dp$ 即可。 ```cpp LL calc0(LL x, LL y){ return (x - y) * (x - y - 1) / 2ll ; } LL Y(LL x){ return f[x] + s[x] + (x * x + x) / 2ll ; } LL YY(LL x){ return g[x] + s[x] + (x * x + x) / 2ll ; } double slope1(int x, int y){ return 1.0 * (1.0 * Y(x) - 1.0 * Y(y)) / (1.0 * x - 1.0 * y) ; } double slope2(int x, int y){ return 1.0 * (1.0 * YY(x) - 1.0 * YY(y)) / (1.0 * x - 1.0 * y) ; } ...... stk[++ t] = 0 ; s[n + 1] = s[n] ; for (int i = 1 ; i <= n + 1 ; ++ i){ while (t > 1 && slope1(stk[t], stk[t - 1]) <= i) stk[t --] = 0 ; f[i] = max(f[i - 1], f[stk[t]] + calc0(i, stk[t]) - s[i - 1] + s[stk[t]]) ; while (t > 1 && slope1(stk[t], stk[t - 1]) <= slope1(i, stk[t])) stk[t --] = 0 ; stk[++ t] = i ; } reverse(base + 1, base + n + 1) ; for (int i = 1 ; i <= n + 1 ; ++ i) s[i] = s[i - 1] + base[i] ; while (t) stk[t --] = 0 ; t = 0 ; stk[++ t] = 0 ; for (int i = 1 ; i <= n + 1 ; ++ i){ while (t > 1 && slope2(stk[t], stk[t - 1]) <= i) stk[t --] = 0 ; g[i] = max(g[i - 1], g[stk[t]] + calc0(i, stk[t]) - s[i - 1] + s[stk[t]]) ; while (t > 1 && slope2(stk[t], stk[t - 1]) <= slope2(i, stk[t])) stk[t --] = 0 ; stk[++ t] = i ; } reverse(g + 1, g + n + 1) ; ``` 考虑关于于 $h$ 的转移,朴素的转移是 $O(n^3)$ 的,类似这样: $$ h_i=\max_{x\leq i\leq y}\{f_x+\frac{(y-x+1)(y-x+2)}{2}-s_y+s_x+g_y\} $$ 发现并不可以过。发现转移是一个区间的形式,于是考虑分治。但是这个分治有个 $bug$ ,就是不是很容易维护跨过 $mid$ 的信息。 但发现一个性质,对于 $[mid+1,r]$ 之间的所有 $h$ 值,左边的所有点仅仅是帮助其成为最优解,也就是左边对右边的点满足单调性。那么我们就可以用单调栈来维护上一个不选的(即维护 $x$)同时枚举最后选的(枚举 $y$),即用 $[l,mid]$ 的值去更新 $[mid+1,r]$;同理对于右半边我们也可以这么做。 具体的,以用 $[l,mid]$ 的值去更新 $[mid+1,r]$ 为例,在算左半边的决策集合的时候,不需要考虑 $g_y$ 的贡献,所以就是一个裸的斜率优化;之后转移就只需要弹栈转移即可。 发现在转移 $h$ 的时候,由于固定了其中一个端点,依然满足斜率优化,于是复杂度为 $\max\{n\log n, Q\}$。 以下是坑点: 1、注意如果按照上述方式来转移,两个转移常数有些微不同(但是这些不同很致命 2、注意因为一开始 $\{g_n\}$ 这个东西是按照 `reverse` 的前缀和来转移的,所以在分治的时候需要重新计算前缀和。 3、在分治的时候,斜率啊什么的变换频繁,所以如果想要推明白,可以在维护上凸壳的时候偷懒这么写: ``` #define tranl(b,a) g[a + 1] - s[a] + S[b] + F[b] + 1ll * ((a - b)*(a - b + 1)/ 2ll) #define tranr(b,a) f[a - 1] - s[a] + S[b] + G[b] + 1ll * ((b - a)*(b - a + 1)/ 2ll) ``` 然后每次比较就不需要斜率了,只需要 `while (top > 1 && tranl(stk[top], i) <= tranr(stk[top - 1], i)) -- top` ; 虽然显然本质上没什么不同,但这样或许会简单很多。 ```cpp const int N = 600010 ; typedef long long LL ; LL f[N] ; LL g[N] ; LL o[N] ; LL s[N] ; LL tmp[N] ; int t ; int n, m ; int stk[N] ; int base[N] ; LL calc(LL x, LL y){ return (x - y) * (x - y + 1) / 2ll ; } LL calc0(LL x, LL y){ return (x - y) * (x - y - 1) / 2ll ; } LL Y(LL x){ return f[x] + s[x] + (x * x + x) / 2ll ; } LL YY(LL x){ return g[x] + s[x] + (x * x + x) / 2ll ; } double slope1(int x, int y){ return 1.0 * (1.0 * Y(x) - 1.0 * Y(y)) / (1.0 * x - 1.0 * y) ; } double slope2(int x, int y){ return 1.0 * (1.0 * YY(x) - 1.0 * YY(y)) / (1.0 * x - 1.0 * y) ; } LL Y2(LL x){ return f[x] + s[x] + (x * x - x) / 2ll ; } LL YY2(LL x){ return g[x] + s[x] + (x * x - x) / 2ll ; } double Slope1(int x, int y){ return 1.0 * (1.0 * Y2(x) - 1.0 * Y2(y)) / (1.0 * x - 1.0 * y) ; } double Slope2(int x, int y){ return 1.0 * (1.0 * YY2(x) - 1.0 * YY2(y)) / (1.0 * x - 1.0 * y) ; } void solve(int l, int r){ if (l == r) return o[l] = max(o[l], f[l - 1] + g[r + 1] + 1 - base[l]), void() ; int mid = (l + r) >> 1 ; LL ans ; solve(l, mid) ; solve(mid + 1, r) ; s[l - 1] = s[r + 1] = 0 ; t = 0 ; for (int i = l ; i <= r ; ++ i) tmp[i] = -(1ll << 62), s[i] = s[i - 1] + base[i] ; for (int i = l - 1 ; i <= mid ; ++ i){ while (t > 1 && slope1(stk[t], stk[t - 1]) <= slope1(i, stk[t])) stk[t --] = 0 ; stk[++ t] = i ; } for (int i = mid + 1 ; i <= r ; ++ i){ while (t > 1 && Slope1(stk[t], stk[t - 1]) <= i) stk[t --] = 0 ; tmp[i] = f[stk[t]] + calc(i, stk[t]) - s[i] + s[stk[t]] + g[i + 1] ; } ans = -(1ll << 62) ; for (int i = r ; i >= mid + 1 ; -- i) ans = max(ans, tmp[i]), o[i] = max(o[i], ans) ; t = 0 ; for (int i = r ; i >= l ; -- i) s[i] = s[i + 1] + base[i] ; for (int i = r + 1 ; i > mid ; -- i){ while (t > 1 && slope2(stk[t], stk[t - 1]) >= slope2(i, stk[t])) stk[t --] = 0 ; stk[++ t] = i ; } for (int i = mid ; i >= l ; -- i){ while (t > 1 && Slope2(stk[t], stk[t - 1]) >= i) stk[t --] = 0 ; tmp[i] = g[stk[t]] + calc(stk[t], i) - s[i] + s[stk[t]] + f[i - 1] ; } ans = -(1ll << 62) ; for (int i = l ; i <= mid ; ++ i) ans = max(ans, tmp[i]), o[i] = max(o[i], ans) ; } int main(){ freopen("2.in", "r", stdin) ; freopen("1.out", "w", stdout) ; cin >> n ; t = 0 ; int p, x ; for (int i = 1 ; i <= n ; ++ i){ scanf("%d", &base[i]), o[i] = - (1ll << 62), s[i] = s[i - 1] + base[i] ; } stk[++ t] = 0 ; s[n + 1] = s[n] ; for (int i = 1 ; i <= n + 1 ; ++ i){ while (t > 1 && slope1(stk[t], stk[t - 1]) <= i) stk[t --] = 0 ; f[i] = max(f[i - 1], f[stk[t]] + calc0(i, stk[t]) - s[i - 1] + s[stk[t]]) ; while (t > 1 && slope1(stk[t], stk[t - 1]) <= slope1(i, stk[t])) stk[t --] = 0 ; stk[++ t] = i ; } reverse(base + 1, base + n + 1) ; for (int i = 1 ; i <= n + 1 ; ++ i) s[i] = s[i - 1] + base[i] ; while (t) stk[t --] = 0 ; t = 0 ; stk[++ t] = 0 ; for (int i = 1 ; i <= n + 1 ; ++ i){ while (t > 1 && slope2(stk[t], stk[t - 1]) <= i) stk[t --] = 0 ; g[i] = max(g[i - 1], g[stk[t]] + calc0(i, stk[t]) - s[i - 1] + s[stk[t]]) ; while (t > 1 && slope2(stk[t], stk[t - 1]) <= slope2(i, stk[t])) stk[t --] = 0 ; stk[++ t] = i ; } reverse(g + 1, g + n + 1) ; g[0] = g[n + 1], g[n + 1] = 0 ; reverse(base + 1, base + n + 1) ; //for (int i = 0 ; i <= n + 1 ; ++ i) cout << f[i] << " " ; puts("") ; //for (int i = 0 ; i <= n + 1 ; ++ i) cout << g[i] << " " ; puts("") ; solve(1, n) ; cin >> m ; //for (int i = 1 ; i <= n ; ++ i) cout << o[i] << " " ; puts("") ; while (m --){ scanf("%d%d", &p, &x) ; printf("%lld\n", max(f[p] + g[p], o[p] + base[p] - x)) ; } return 0 ; } ```