[NOI Online #2 提高组]子序列问题

皎月半洒花

2020-04-25 16:11:51

Solution

我来说一个很不正常的解法…不正常在他特别麻烦…特别难调… 我的做法是先算出全部区间的贡献: $$\sum_{i=1}^n\sum_{j=i}^n(j-i)^2=\sum_{i=1}^n(1^2+2^2+3^2+\cdots +(n-i)^2)$$ 也就是 $$\sum_{i=1}^n\frac{(n-i)(n-i+1)(2\times (n-i)+1)}{6}$$ 然后考虑减掉那些不合法的。具体的,预处理出每个位置左边最近的那个相同颜色的下标 $pre_x$ 。那么 $x$ 和 $pre_x$ 会对左端点在 $1\sim pre_x$ ,右端点在 $x+1\sim n$ 的区间产生贡献。贡献怎么算呢? 考虑假设一个区间长为 $L$ 。那么第一组 $(pre_x,x)$ 出现时,会有 $$L^2\to (L-1)^2=L^2-2L+1=L^2-(2L-1)$$ 第二组出现时: $$(L-1)^2\to (L-2)^2=(L-1)^2-2(L-1)+1=(L-1)^2-(2L-3)$$ 以此类推,当一个区间存在 $t$ 个重复颜色时(即假设某种颜色的数量为 $c$,那么这种颜色的「重复颜色数」为 $c-1$),他需要减去 $(2\cdot t\cdot L-t^2)$ 的贡献。 考虑拆成两半做: 1、$2\cdot t\cdot L$ 需要枚举每个位置 $i$ ,设 $j=pre_i$ 。记 $p=\max\{(n-i+1),i\},q=\min\{i,(n-i+1)\}$ 。即 $p$ 是左右两边较长的那个区间,$q$ 是较短的那个。同时记当前区间长度为 $d$,即 $d=i-pre_i$ 。以下默认省略前面的系数 $2$ 。 那么需要再分三类讨论会被产生贡献的区间长度 $L$ ,以下在计算 $L$ 时,用 $d+\Delta$ 来代替: > (1)$d+1\leq L\leq q+d$ > > 对于每个这样的 $L$ ,会存在 $L-d$ 个区间产生合法贡献,所以这部分贡献就是 $$\sum_{L=d+1}^{q+d}L\cdot i=\sum_{i=1}^{q}(d+i)\cdot i$$ > 可以通过预处理 $\sum i$ ,$\sum i^2$ 快速计算。 > > (2) $q+d+1\leq L\leq p+d$ > > 对于每个这样的 $L$ ,由于不能全部取到,所以至多会有 $q$ 个。所以这部分贡献是: $$\sum_{L=q+d+1}^{p+d}L\cdot q=\sum_{i=q+1}^{p}(d+i)\cdot q$$ > 这部分比较好算。 > > (3) $p+d+1\leq L\leq n$ > > 对于每个这样的 $L$ ,发现最多只能取到 $n-L+1$ 次。所以这部分贡献是 $$\sum_{L=p+d+1}^{n}L\cdot (n-L+1)=\sum_{i=p+1}^{n-d}(n-d-i+1)\cdot (d+i)$$ > 这一部分同样可以通过预处理来快速计算。 综上,这一部分的复杂度是排序外线性。 2、$-t^2$ 设 $i$ 右边第一个和 $i$ 同颜色的元素为 $r_i$ 。 也就是现在把问题转化成了「{区间内重复出现的数字个数 $-1$ 的平方和」。考虑扫描线。一开始将所有的数都加进线段树。从左开始,每次都删掉一个最左边的元素 $i$。如果这个元素的颜色依旧出现在后面的序列中,那么可以知道对于所有右端点 $\geq r_i$ 的区间,都会少掉一个 $(i, r_i)$ 组成的 `pair`,也就是会少掉一个重复颜色的元素。所以就是后缀减 $-1$ and 询问后缀的平方和,线段树维护即可。 这一部分复杂度 $O(n\log n)$ 。 如何卡常: 1、不要用 `map` . 2、(mayaohua 在 uoj 群里的高论)发现中间,一段区间内部的平方的和本质上是不会爆 `long long` 的,所以可以减少取模次数。 Upd: 发现代码里最难懂的部分没有注明 `add(),dec(),addn(),decn()` 都是什么…我的锅。 ```cpp template <typename T> il void add(T &x, T y, ll mod = P){ x += y ; x = x >= mod ? x - mod : x ; } template <typename T> il void dec(T &x, T y, ll mod = P){ x -= y ; x = x < 0 ? x + mod : x ; } template <typename T> il T addn(T x, T y, ll mod = P){ x += y ; return (x = x > mod ? x - mod : x) ; } template <typename T> il T decn(T x, T y, ll mod = P){ x -= y ; return (x = x < 0 ? x + mod : x) ; } /* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */ int val[N] ; ll s[N * 3] ; ll t[N * 3] ; ll tg[N * 3] ; void build(int rt, int l, int r){ if (l == r){ s[rt] = val[l] ; t[rt] = s[rt] * s[rt] ; return ; } int mid = (l + r) >> 1 ; build(ls, l, mid) ; build(rs, mid + 1, r) ; s[rt] = s[ls] + s[rs] ; t[rt] = t[ls] + t[rs] ; } void _down(int rt, int l, int r){ if (tg[rt]){ ll p = tg[rt] * tg[rt] % P ; ll pr = r - ((l + r) >> 1) ; ll pl = ((l + r) >> 1) - l + 1 ; tg[ls] += tg[rt], tg[rs] += tg[rt] ; dec(t[ls], decn(2ll * s[ls] * tg[rt] % P, p * pl)) ; dec(t[rs], decn(2ll * s[rs] * tg[rt] % P, p * pr)) ; dec(s[ls], tg[rt] * pl % P) ; dec(s[rs], tg[rt] * pr % P) ; tg[rt] = 0 ; } } void update(int rt, int l, int r, int ul, int ur){ if (ul <= l && r <= ur){ dec(t[rt], decn(2ll * s[rt] % P, 1ll * (r - l + 1))) ; dec(s[rt], 1ll * (r - l + 1)) ; tg[rt] += 1 ; return ; } int mid = (l + r) >> 1 ; _down(rt, l, r) ; if (ul <= mid) update(ls, l, mid, ul, ur) ; if (ur > mid) update(rs, mid + 1, r, ul, ur) ; s[rt] = s[ls] + s[rs] ; t[rt] = t[ls] + t[rs] ; } ll query(int rt, int l, int r, int ul, int ur){ if (ul <= l && r <= ur) return t[rt] ; int mid = (l + r) >> 1 ; ll res = 0 ; _down(rt, l, r) ; if (ul <= mid) res += query(ls, l, mid, ul, ur) ; if (ur > mid) res += query(rs, mid + 1, r, ul, ur) ; return res ; } int n ; ll ans ; ll res ; ll sum1[N] ; ll sum2[N] ; int pos[N] ; int nxt[N] ; int buc[N] ; int tmp[N] ; int base[N] ; ll fuck[M][M] ; const ll Inv6 = 166666668ll ; int main(){ cin >> n ; int len ; for (int i = 1 ; i <= n ; ++ i) base[i] = tmp[i] = qr() ; sort(tmp + 1, tmp + n + 1) ; len = unique(tmp + 1, tmp + n + 1) - tmp - 1 ; for (int i = 1 ; i <= n ; ++ i) base[i] = lb(tmp + 1, tmp + len + 1, base[i]) - tmp ; for (int i = 1 ; i <= n ; ++ i){ if (buc[base[i]]) pos[i] = buc[base[i]] ; buc[base[i]] = i ; } for (ll i = 1 ; i <= n ; ++ i) sum1[i] = addn(sum1[i - 1], i) ; for (ll i = 1 ; i <= n ; ++ i) sum2[i] = addn(sum2[i - 1], i * i) ; for (ll i = 0 ; i <= n ; ++ i) add(ans, (i + 1) * i * (2ll * i + 1ll) % P) ; ans = ans * Inv6 % P ; ll q, maxx, minx, m, p, len1, len2, d ; for (int i = 1 ; i <= n ; ++ i){ if (!pos[i]) continue ; q = n - i + 1 ; maxx = q, minx = pos[i] ; p = i - pos[i], m = n - p ; d = decn(sum1[m], sum1[maxx]) ; if (minx > maxx) swap(minx, maxx) ; len2 = m - maxx, len1 = maxx - minx ; //part1 add(res, sum2[minx]) ; add(res, sum1[minx] * p) ; //part2 add(res, p * minx * len1) ; add(res, minx * decn(sum1[maxx], sum1[minx], P) % P) ; //part3 dec(res, 2ll * p * d % P) ; dec(res, p * p * len2 % P) ; add(res, 1ll * (n + 1) * d % P) ; dec(res, decn(sum2[m], sum2[maxx])) ; add(res, 1ll * (n + 1) * p * len2 % P) ; dec(ans, 2ll * res % P) ; res = 0 ; } if (n <= 1000){ for (int i = 1 ; i <= n ; ++ i){ fill(buc, buc + n + 1, 0) ; for (int j = i ; j <= n ; ++ j){ buc[base[j]] ++ ; fuck[i][j] = fuck[i][j - 1] + (buc[base[j]] > 1) ; add(ans, fuck[i][j] * fuck[i][j]) ; } } cout << ans << '\n' ; return 0 ; } fill(buc, buc + n + 1, 0) ; for (int i = n ; i >= 1 ; -- i) nxt[i] = buc[base[i]] ? buc[base[i]] : n + 1, buc[base[i]] = i ; fill(buc, buc + n + 1, 0) ; for (int i = 1 ; i <= n ; ++ i) buc[base[i]] ++, val[i] = val[i - 1] + (buc[base[i]] > 1) ; build (1, 1, n) ; for (int i = 1 ; i < n ; ++ i){ add(ans, query(1, 1, n, i, n) % P) ; if (nxt[i] <= n) update(1, 1, n, nxt[i], n) ; } cout << ans << '\n' ; return 0 ; } ```