[NOI Online #2 提高组]子序列问题
皎月半洒花
2020-04-25 16:11:51
我来说一个很不正常的解法…不正常在他特别麻烦…特别难调…
我的做法是先算出全部区间的贡献:
$$\sum_{i=1}^n\sum_{j=i}^n(j-i)^2=\sum_{i=1}^n(1^2+2^2+3^2+\cdots +(n-i)^2)$$
也就是
$$\sum_{i=1}^n\frac{(n-i)(n-i+1)(2\times (n-i)+1)}{6}$$
然后考虑减掉那些不合法的。具体的,预处理出每个位置左边最近的那个相同颜色的下标 $pre_x$ 。那么 $x$ 和 $pre_x$ 会对左端点在 $1\sim pre_x$ ,右端点在 $x+1\sim n$ 的区间产生贡献。贡献怎么算呢?
考虑假设一个区间长为 $L$ 。那么第一组 $(pre_x,x)$ 出现时,会有
$$L^2\to (L-1)^2=L^2-2L+1=L^2-(2L-1)$$
第二组出现时:
$$(L-1)^2\to (L-2)^2=(L-1)^2-2(L-1)+1=(L-1)^2-(2L-3)$$
以此类推,当一个区间存在 $t$ 个重复颜色时(即假设某种颜色的数量为 $c$,那么这种颜色的「重复颜色数」为 $c-1$),他需要减去 $(2\cdot t\cdot L-t^2)$ 的贡献。
考虑拆成两半做:
1、$2\cdot t\cdot L$
需要枚举每个位置 $i$ ,设 $j=pre_i$ 。记 $p=\max\{(n-i+1),i\},q=\min\{i,(n-i+1)\}$ 。即 $p$ 是左右两边较长的那个区间,$q$ 是较短的那个。同时记当前区间长度为 $d$,即 $d=i-pre_i$ 。以下默认省略前面的系数 $2$ 。
那么需要再分三类讨论会被产生贡献的区间长度 $L$ ,以下在计算 $L$ 时,用 $d+\Delta$ 来代替:
> (1)$d+1\leq L\leq q+d$
>
> 对于每个这样的 $L$ ,会存在 $L-d$ 个区间产生合法贡献,所以这部分贡献就是
$$\sum_{L=d+1}^{q+d}L\cdot i=\sum_{i=1}^{q}(d+i)\cdot i$$
> 可以通过预处理 $\sum i$ ,$\sum i^2$ 快速计算。
>
> (2) $q+d+1\leq L\leq p+d$
>
> 对于每个这样的 $L$ ,由于不能全部取到,所以至多会有 $q$ 个。所以这部分贡献是:
$$\sum_{L=q+d+1}^{p+d}L\cdot q=\sum_{i=q+1}^{p}(d+i)\cdot q$$
> 这部分比较好算。
>
> (3) $p+d+1\leq L\leq n$
>
> 对于每个这样的 $L$ ,发现最多只能取到 $n-L+1$ 次。所以这部分贡献是
$$\sum_{L=p+d+1}^{n}L\cdot (n-L+1)=\sum_{i=p+1}^{n-d}(n-d-i+1)\cdot (d+i)$$
> 这一部分同样可以通过预处理来快速计算。
综上,这一部分的复杂度是排序外线性。
2、$-t^2$
设 $i$ 右边第一个和 $i$ 同颜色的元素为 $r_i$ 。
也就是现在把问题转化成了「{区间内重复出现的数字个数 $-1$ 的平方和」。考虑扫描线。一开始将所有的数都加进线段树。从左开始,每次都删掉一个最左边的元素 $i$。如果这个元素的颜色依旧出现在后面的序列中,那么可以知道对于所有右端点 $\geq r_i$ 的区间,都会少掉一个 $(i, r_i)$ 组成的 `pair`,也就是会少掉一个重复颜色的元素。所以就是后缀减 $-1$ and 询问后缀的平方和,线段树维护即可。
这一部分复杂度 $O(n\log n)$ 。
如何卡常:
1、不要用 `map` .
2、(mayaohua 在 uoj 群里的高论)发现中间,一段区间内部的平方的和本质上是不会爆 `long long` 的,所以可以减少取模次数。
Upd: 发现代码里最难懂的部分没有注明 `add(),dec(),addn(),decn()` 都是什么…我的锅。
```cpp
template <typename T>
il void add(T &x, T y, ll mod = P){
x += y ; x = x >= mod ? x - mod : x ;
}
template <typename T>
il void dec(T &x, T y, ll mod = P){
x -= y ; x = x < 0 ? x + mod : x ;
}
template <typename T>
il T addn(T x, T y, ll mod = P){
x += y ; return (x = x > mod ? x - mod : x) ;
}
template <typename T>
il T decn(T x, T y, ll mod = P){
x -= y ; return (x = x < 0 ? x + mod : x) ;
}
/* - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - */
int val[N] ;
ll s[N * 3] ;
ll t[N * 3] ;
ll tg[N * 3] ;
void build(int rt, int l, int r){
if (l == r){
s[rt] = val[l] ;
t[rt] = s[rt] * s[rt] ;
return ;
}
int mid = (l + r) >> 1 ;
build(ls, l, mid) ;
build(rs, mid + 1, r) ;
s[rt] = s[ls] + s[rs] ;
t[rt] = t[ls] + t[rs] ;
}
void _down(int rt, int l, int r){
if (tg[rt]){
ll p = tg[rt] * tg[rt] % P ;
ll pr = r - ((l + r) >> 1) ;
ll pl = ((l + r) >> 1) - l + 1 ;
tg[ls] += tg[rt], tg[rs] += tg[rt] ;
dec(t[ls], decn(2ll * s[ls] * tg[rt] % P, p * pl)) ;
dec(t[rs], decn(2ll * s[rs] * tg[rt] % P, p * pr)) ;
dec(s[ls], tg[rt] * pl % P) ; dec(s[rs], tg[rt] * pr % P) ; tg[rt] = 0 ;
}
}
void update(int rt, int l, int r, int ul, int ur){
if (ul <= l && r <= ur){
dec(t[rt], decn(2ll * s[rt] % P, 1ll * (r - l + 1))) ;
dec(s[rt], 1ll * (r - l + 1)) ; tg[rt] += 1 ; return ;
}
int mid = (l + r) >> 1 ; _down(rt, l, r) ;
if (ul <= mid) update(ls, l, mid, ul, ur) ;
if (ur > mid) update(rs, mid + 1, r, ul, ur) ;
s[rt] = s[ls] + s[rs] ; t[rt] = t[ls] + t[rs] ;
}
ll query(int rt, int l, int r, int ul, int ur){
if (ul <= l && r <= ur) return t[rt] ;
int mid = (l + r) >> 1 ; ll res = 0 ; _down(rt, l, r) ;
if (ul <= mid) res += query(ls, l, mid, ul, ur) ;
if (ur > mid) res += query(rs, mid + 1, r, ul, ur) ;
return res ;
}
int n ;
ll ans ;
ll res ;
ll sum1[N] ;
ll sum2[N] ;
int pos[N] ;
int nxt[N] ;
int buc[N] ;
int tmp[N] ;
int base[N] ;
ll fuck[M][M] ;
const ll Inv6 = 166666668ll ;
int main(){
cin >> n ; int len ;
for (int i = 1 ; i <= n ; ++ i)
base[i] = tmp[i] = qr() ;
sort(tmp + 1, tmp + n + 1) ;
len = unique(tmp + 1, tmp + n + 1) - tmp - 1 ;
for (int i = 1 ; i <= n ; ++ i)
base[i] = lb(tmp + 1, tmp + len + 1, base[i]) - tmp ;
for (int i = 1 ; i <= n ; ++ i){
if (buc[base[i]]) pos[i] = buc[base[i]] ; buc[base[i]] = i ;
}
for (ll i = 1 ; i <= n ; ++ i)
sum1[i] = addn(sum1[i - 1], i) ;
for (ll i = 1 ; i <= n ; ++ i)
sum2[i] = addn(sum2[i - 1], i * i) ;
for (ll i = 0 ; i <= n ; ++ i)
add(ans, (i + 1) * i * (2ll * i + 1ll) % P) ;
ans = ans * Inv6 % P ;
ll q, maxx, minx, m, p, len1, len2, d ;
for (int i = 1 ; i <= n ; ++ i){
if (!pos[i])
continue ;
q = n - i + 1 ;
maxx = q, minx = pos[i] ;
p = i - pos[i], m = n - p ;
d = decn(sum1[m], sum1[maxx]) ;
if (minx > maxx) swap(minx, maxx) ;
len2 = m - maxx, len1 = maxx - minx ;
//part1
add(res, sum2[minx]) ;
add(res, sum1[minx] * p) ;
//part2
add(res, p * minx * len1) ;
add(res, minx * decn(sum1[maxx], sum1[minx], P) % P) ;
//part3
dec(res, 2ll * p * d % P) ;
dec(res, p * p * len2 % P) ;
add(res, 1ll * (n + 1) * d % P) ;
dec(res, decn(sum2[m], sum2[maxx])) ;
add(res, 1ll * (n + 1) * p * len2 % P) ;
dec(ans, 2ll * res % P) ; res = 0 ;
}
if (n <= 1000){
for (int i = 1 ; i <= n ; ++ i){
fill(buc, buc + n + 1, 0) ;
for (int j = i ; j <= n ; ++ j){
buc[base[j]] ++ ;
fuck[i][j] = fuck[i][j - 1] + (buc[base[j]] > 1) ;
add(ans, fuck[i][j] * fuck[i][j]) ;
}
}
cout << ans << '\n' ; return 0 ;
}
fill(buc, buc + n + 1, 0) ;
for (int i = n ; i >= 1 ; -- i)
nxt[i] = buc[base[i]] ? buc[base[i]] : n + 1, buc[base[i]] = i ;
fill(buc, buc + n + 1, 0) ;
for (int i = 1 ; i <= n ; ++ i)
buc[base[i]] ++, val[i] = val[i - 1] + (buc[base[i]] > 1) ;
build (1, 1, n) ;
for (int i = 1 ; i < n ; ++ i){
add(ans, query(1, 1, n, i, n) % P) ;
if (nxt[i] <= n) update(1, 1, n, nxt[i], n) ;
}
cout << ans << '\n' ; return 0 ;
}
```