题解 CF1083D 【The Fair Nut's getting crazy】
Owen_codeisking
2020-06-13 09:50:54
~~CF 上这题好像 3500,像极了洛谷的恶意评分~~
首先枚举交集 $[l,r]$,然后计算每个交集的答案。
定义
$\text{pre}_i$ 为数值和 $a_i$ 相同,最靠近 $i$ 且 $<i$ 的位置,若不存在为 $0$。
$\text{suc}_i$ 为数值和 $a_i$ 相同,最靠近 $i$ 且 $>i$ 的位置,若不存在为 $n+1$。
记
$$f(i,j)=\max(\text{pre}_i,...,\text{pre}_j)+1$$
$$g(i,j)=\min(\text{suc}_i,...,\text{suc}_j)-1$$
则
$$ans=\sum_{i=1}^n\sum_{j=i}^n [i\ge f(i,j)][g(i,j)\ge j](i-f(i,j))\times (g(i,j)-j)$$
考虑怎么算这个答案。
我们先枚举 $i$。这个 $f(i,j)$ 不降,$g(i,j)$ 不升,一定有一个最大的 $j$,使得 $[i,i],...,[i,j]$ 计入答案。我们发现对于 $i_1<i_2$,$j_1<j_2$,这样就可以双指针计算 $j$。
现在的问题就是对于每个 $i$,怎么快速计算
$$\sum_{k=i}^j (i-f(i,k))\times (g(i,k)-k)$$
现在把括号拆开
$$=\sum_{k=i}^j g(i,k)-i\times \sum_{k=i}^j k-\sum_{k=i}^j f(i,k)\times g(i,k)+\sum_{k=i}^j f(i,k)\times k$$
这些东西可以在线段树上搞。每次 $i$ 左移时,把所有值取 $\max/\min$ 的操作换成区间赋值。
现在你需要维护几种操作:
1. 对 $a/b$ 序列区间赋值
2. 求 $a/b$ 序列区间和
3. 求 $a/b$ 序列区间 $(a/b)_i\times i$ 和
4. 求 $a_i\times b_i$ 和
时间 $\mathcal{O}(n\log n)$
```cpp
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=100005;
const int inf=0x3f3f3f3f;
const int mod=1000000007;
int n,a[maxn],mp[maxn],lst[maxn],cnt[maxn],pre[maxn],suc[maxn],fp[maxn],fs[maxn],sta[maxn],top;
ll sum[maxn<<2][4],lazy[maxn<<2][2];
inline ll sum_2(ll x) { return x*(x+1)/2; }
inline ll S(int l,int r) { return sum_2(r)-sum_2(l-1); }
#define ls (rt<<1)
#define rs (rt<<1|1)
inline void pushup(int rt) { for(int i=0;i<4;i++) sum[rt][i]=sum[ls][i]+sum[rs][i]; }
inline void pushlazy(int rt,int l,int r,int z,int op)
{
sum[rt][2]=sum[rt][op^1]*z;
sum[rt][op]=(ll)(r-l+1)*z;
lazy[rt][op]=z;
if(op==0) sum[rt][3]=z*S(l,r);
}
inline void pushdown(int rt,int l,int r)
{
int mid=(l+r)>>1;
if(lazy[rt][0])
{
pushlazy(ls,l,mid,lazy[rt][0],0);
pushlazy(rs,mid+1,r,lazy[rt][0],0);
lazy[rt][0]=0;
}
if(lazy[rt][1])
{
pushlazy(ls,l,mid,lazy[rt][1],1);
pushlazy(rs,mid+1,r,lazy[rt][1],1);
lazy[rt][1]=0;
}
}
void update(int rt,int l,int r,int x,int y,int z,int op)
{
if(x<=l && r<=y) { pushlazy(rt,l,r,z,op); return; }
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(x<=mid) update(ls,l,mid,x,y,z,op);
if(y>mid) update(rs,mid+1,r,x,y,z,op);
pushup(rt);
}
ll query(int rt,int l,int r,int x,int y,int z)
{
if(x>y) return 0;
if(x<=l && r<=y) return sum[rt][z];
pushdown(rt,l,r);
int mid=(l+r)>>1; ll ans=0;
if(x<=mid) ans+=query(ls,l,mid,x,y,z);
if(y>mid) ans+=query(rs,mid+1,r,x,y,z);
return ans;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),mp[i]=a[i];
sort(mp+1,mp+n+1);
int tot=unique(mp+1,mp+n+1)-mp-1;
for(int i=1;i<=n;i++)
a[i]=lower_bound(mp+1,mp+tot+1,a[i])-mp,pre[i]=lst[a[i]]+1,lst[a[i]]=i;
for(int i=1;i<=tot;i++) lst[i]=n+1;
for(int i=n;i;i--) suc[i]=lst[a[i]]-1,lst[a[i]]=i;
for(int i=n;i;i--)
{
while(top && pre[i]>pre[sta[top]]) top--;
fp[i]=top?sta[top]-1:n,sta[++top]=i;
}
top=0;
for(int i=n;i;i--)
{
while(top && suc[i]<suc[sta[top]]) top--;
fs[i]=top?sta[top]-1:n,sta[++top]=i;
}
ll ans=0;
for(int i=n,j=n;i;i--)
{
update(1,1,n,i,fp[i],pre[i],0);
update(1,1,n,i,fs[i],suc[i],1);
cnt[a[i]]++;
while(cnt[a[i]]>1) cnt[a[j]]--,j--;
(ans+=query(1,1,n,i,j,3)+i*query(1,1,n,i,j,1)-i*S(i,j)-query(1,1,n,i,j,2))%=mod;
}
printf("%lld\n",(ans+mod)%mod);
return 0;
}
```