题解 P6477 【[NOI Online #2 提高组]子序列问题】

· · 题解

这题在考场上线段树pushdown l写成r,直接挂成5分。 论对拍的重要性

这里给出一种不用维护平方和的线段树做法。

首先看到 \left[1,10^9\right]的数据范围需要离散化。

然后考虑对每个i(1\le i \le n)求出 \sum\limits_{l=1}^if(l,i)^2,然后相加即为答案。

我们用lst_i表示a_i上次出现的位置(如果之前没出现过则为0),用ans_i表示 \sum\limits_{l=1}^if(l,i)^2

所以我们只要在O(\log(n))的时间内实现ans的转移即可。

l \le lst_i时,a_i 不会对f(l,i)产生贡献,所以 f(l,i) = f(l,i-1)

lst_i < l < i时,a_i会对f(l,i)产生1的贡献,所以f(l,i) = f(l,i-1)+1f(l,i)^2-f(l,i-1)^2 = 2f(l,i-1)+1

l=i时,f(l,i) = 1

所以可以得到ans_i = ans_{i-1} + \sum\limits_{l=lst_i}^{i-1}\left(f(l,i)^2- f(l,i-1)^2\right) + 1= ans_{i-1} + \sum\limits_{l=lst_i}^{i-1}\left(f(l,i-1) *2+1\right) + 1

\sum\limits_{l=lst_i}^{i-1}\left(f(l,i-1) *2+1\right) 用线段树维护序列和即可,并不需要维护平方和。

最后的答案为\sum\limits_{i = 1}^n ans_i

代码不卡常不加O2在洛谷也能AC,不加O2 9.29s,加O2 5.38s。

#include<bits/stdc++.h>
using namespace std;
const int mod = 1e9 + 7;
const int N = 1e6 + 7;
int n;
int a[N],b[N];
int ans[N];
int totans = 0;
int cnt = 0;
int lst[N];
void discrete(){
    sort(b+1,b+n+1);
    cnt = unique(b+1,b+n+1) - (b+1);
    for(int i = 1;i<=n;i++){
        a[i] = lower_bound(b+1,b+cnt+1,a[i]) - b;
    }
}
struct Node{
    int lazy,c,l,r;
}tr[N*4] = {0};
void bt(int nw,int l,int r){
    tr[nw].l = l;
    tr[nw].r = r;
    if(l==r) return;
    int mid = (l+r)/2;
    bt(nw*2,l,mid);
    bt(nw*2+1,mid+1,r);
}
void upd(int nw){
    if(tr[nw].lazy){
        tr[nw*2].lazy += tr[nw].lazy;
        tr[nw*2+1].lazy += tr[nw].lazy;
        tr[nw*2].c =  (1ll * tr[nw*2].c + 1ll * tr[nw].lazy * (tr[nw*2].r - tr[nw*2].l + 1)) % mod;
        tr[nw*2+1].c = (1ll * tr[nw*2+1].c + 1ll * tr[nw].lazy * (tr[nw*2+1].r - tr[nw*2+1].l + 1)) % mod;
        tr[nw].lazy = 0;
    }
}
void chg(int nw,int l,int r,int k){
    if(l>r) return;
    if((l == tr[nw].l)&&(r == tr[nw].r)){
        tr[nw].c = (tr[nw].c + (r-l+1) * k)%mod;
        tr[nw].lazy += k;
        return;
    }
    upd(nw);
    int mid = (tr[nw].r + tr[nw].l)/2;
    if(r<=mid) chg(nw*2,l,r,k);
    else if(l>mid) chg(nw*2+1,l,r,k);
    else chg(nw*2,l,mid,k),chg(nw*2+1,mid+1,r,k);
    tr[nw].c = (tr[nw*2].c + tr[nw*2+1].c) % mod; 
}
int qry(int nw,int l,int r){
    if(l>r) return 0;
    if((l == tr[nw].l)&&(r == tr[nw].r)){
        return tr[nw].c;
    }
    upd(nw);
    int mid = (tr[nw].r + tr[nw].l)/2;
    if(r<=mid) return qry(nw*2,l,r);
    else if(l>mid) return qry(nw*2+1,l,r);
    else return (qry(nw*2,l,mid)+qry(nw*2+1,mid+1,r)) % mod;
}
int main(){
    scanf("%d",&n);
    for(int i = 1;i<=n;i++){
        scanf("%d",&a[i]);
        b[i] = a[i];
    }
    discrete();
    bt(1,1,n);
    for(int i = 1;i<=n;i++){
        ans[i] = ans[i-1];
        ans[i] += (qry(1,lst[a[i]]+1,i) * 2 + (i-lst[a[i]])) % mod;
        chg(1,lst[a[i]]+1,i,1);
        ans[i] %= mod;
        lst[a[i]] = i;
        totans = (totans + ans[i]) % mod;
    }
    printf("%d",totans);
    return 0;
}