「CEOI2021」Diversity 题解
ForgotMe
·
·
题解
发现这个总多样性很丑陋,我们用柿子来表示。
可以发现一种元素 i 产生的贡献为
\dfrac{n\times (n+1)}{2}-\sum_{j=1}^{cnt_i}\dfrac{f_{i,j}(f_{i,j}+1)}{2}
于是可以得到总多样性
$$
S=\sum_{i=1}^K(\dfrac{n\times (n+1)}{2}-\sum_{j=1}^{cnt_i}\dfrac{f_{i,j}(f_{i,j}+1)}{2})
$$
其中 $K$ 表示不同元素个数,$n$ 表示区间长度。
又注意到 $n-c_i=\sum_{i=1}^{cnt_i}f_{i,j}$,$c_i$ 为元素 $i$ 的出现次数。
于是拆拆拆,就可以得到:
$$
S=\dfrac{1}{2}(KN(N+1)-N(K-1)-\sum_{i=1}^K\sum_{j=1}^{cnt_i}f_{i,j}^2)
$$
于是任务为最大化 $\sum_{i=1}^K\sum_{j=1}^{cnt_i}f_{i,j}^2$。
然后就可以猜结论了。
## 结论 1
在一个最优解中,重排后的序列每一种数字一定都是连续的排在一起的,不会分开这里放一个那里放一个。
这个结论挺简单,证明略。
## 结论 2
知道了前面那个东西还是不太能做,考虑一下应该怎么摆放,显然分析到现在数字的大小并没有什么意义,而是在于每种数字的个数。
然后经过一顿猜,发现最优策略是把每种数字的个数从小到大排序,然后先第一种数字都塞头,再把第二种数字都塞尾,第三种数字都塞头,交错着塞。
感性理解一下,这样子分配会很均匀,所以不会劣。(如果要看证明的请去官方题解看那 3 页数学公式。。。)
猜到这个东西就可以 $\mathcal{O(nq\log n)}$ 了。64 pts 到手。
考虑如何维护这个东西,发现可以把每种数字的个数都塞进一个桶里,形式化的来说,设 $t_i$ 表示有 $t_i$ 种数字的个数都为 $i$。
惊奇的发现这个桶里有值的地方最多 $\mathcal{O(\sqrt{n})}$ 个,因为最坏情况为 $1+2+3+4+...+L\le n$,$L\le \sqrt{n}$。
如果我们知道桶里那些地方有值,该如何计算贡献?
假设左边已经塞了 $l$ 个数字右边塞了 $r$ 个数字,当前左边还要塞 $lput$ 种,右边还要塞 $rput$ 种,每种个数为 $len$ 的数字。
可以左边塞的这些数字造成的左边空隙的贡献之和为:
$$
\sum_{i=0}^{lput-1}(i\times len+l)^2
$$
这个东西很好算,拆开就是自然数 $1$ 次幂和,$2$ 次幂和。
其他的 $3$ 种情况可以仿照上面这种情况计算。
剩下的问题就是如何维护这个桶了,考虑莫队,开个 `set` 暴力维护桶里哪些地方有值,算贡献的时候直接把有值的地方拉出来就行了,但是这样子复杂度 $\mathcal{O}(q\sqrt{n}\log n)$,极其的不优美,而且由于洛谷的 120s 限制开不到 7s,无法通过。
考虑一下转移的时候 $\mathcal{O}(1)$ 转移,可以开一个操作栈记录下所有的转移(插入一个数字/删除一个数字),把数字拉出来的时候就看最后一次操作到底是插入还是删除(显然只有最后一次操作才有影响),但是拉出来的数字可能无序,需要排序,注意到拉出来的数字最多会有 $\mathcal{O}(\sqrt{n})$ 个,于是复杂度就只有 $\mathcal{O}(q\sqrt{n}+q\sqrt{n\log n})$,可以通过此题,如果你想 $\mathcal{O}(q\sqrt{n})$,请使用基数排序。。。
丑陋的代码
```cpp
#include <cstdio>
#include <iostream>
#include <stack>
#include <queue>
#include <set>
#include <algorithm>
#include <cmath>
using namespace std;
#define LL long long
inline int rd(){
char c;
int x=0,f=1;
for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())x=x*10+c-'0';
return x*f;
}
#define mp make_pair
#define pp pair<int,int>
int n,q,B,bl[300005],tong[300005],col[300005],len,c[2][300005],a[300005],K,oplen,num[1200005],Num[300005];
bool ok[300005];
pp opr[1200005];
LL ans[50005];
struct que{
int l,r,id;
}Q[50005];
bool cmp(que a,que b) {
return (bl[a.l]^bl[b.l])?bl[a.l]<bl[b.l]:((bl[a.l]&1)?a.r<b.r:a.r>b.r);
}
set<int>S;
set<int>::iterator it;
void add(int x){
tong[col[a[x]]]--;
if(!tong[col[a[x]]])opr[++oplen]=mp(-1,col[a[x]]);
col[a[x]]++;
if(col[a[x]]==1)K++;
tong[col[a[x]]]++;
if(tong[col[a[x]]]==1)opr[++oplen]=mp(1,col[a[x]]);
}
void dec(int x){
tong[col[a[x]]]--;
if(!tong[col[a[x]]])opr[++oplen]=mp(-1,col[a[x]]);
col[a[x]]--;
if(col[a[x]]==0)K--;
tong[col[a[x]]]++;
if(tong[col[a[x]]]==1)opr[++oplen]=mp(1,col[a[x]]);
}
LL Sum1(int n){
return 1ll*n*(n+1)/2;
}
LL Sum2(int n){
return 1ll*n*(n+1)*(2*n+1)/6;
}
//0 1 2
LL calc1(int N,int c,int len){
return 1ll*c*c*(N+1)+1ll*Sum2(N)*len*len+2ll*Sum1(N)*len*c;
}
LL calc2(int N,int c,int len){
return 1ll*c*c*(N+1)+1ll*Sum2(N)*len*len-2ll*Sum1(N)*len*c;
}
//i^2len^2+2*i*lenl
LL solve(int L,int R){
int nlen=0;
for(int i=oplen;i>=1;i--){
if(ok[opr[i].second])continue;
ok[opr[i].second]=1;
if(opr[i].first==1)Num[++nlen]=opr[i].second;
}
for(int i=1;i<=len;i++){
if(ok[num[i]])continue;
ok[num[i]]=1;
Num[++nlen]=num[i];
}
for(int i=oplen;i>=1;i--)ok[opr[i].second]=0;
for(int i=1;i<=len;i++)ok[num[i]]=0;
len=nlen;
for(int i=1;i<=len;i++)num[i]=Num[i];
sort(num+1,num+len+1);
for(int i=1;i<=len;i++){
c[0][i]=num[i];
c[1][i]=tong[num[i]];
}
//表示长度为 (*it) 的有这么多个
int l=0,r=0,tot=0,N=R-L+1;
LL Ans=0;
for(int i=1;i<=len;i++){
int lput=0,rput=0,Len=c[0][i];
if(tot%2==0){
lput=(c[1][i]+1)/2;
rput=c[1][i]/2;
}else{
rput=(c[1][i]+1)/2;
lput=c[1][i]/2;
}
Ans+=calc1(lput-1,l,Len);
Ans+=calc2(lput-1,N-l-Len,Len);
Ans+=calc1(rput-1,r,Len);
Ans+=calc2(rput-1,N-r-Len,Len);
l+=lput*Len,r+=rput*Len;
tot+=c[1][i];
}
return 1ll*(1ll*K*N*(N+1)-1ll*N*(K-1)-Ans)/2;
}
int main() {
scanf("%d %d",&n,&q);
B=sqrt(n*1.0);
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
for(int i=1;i<=n;i++)bl[i]=(i-1)/B+1;
for(int i=1;i<=q;i++)scanf("%d %d",&Q[i].l,&Q[i].r),Q[i].id=i;
sort(Q+1,Q+q+1,cmp);
int l=1,r=0;
for(int i=1;i<=q;i++){
int L=Q[i].l,R=Q[i].r;
oplen=0;
while(l<L)dec(l),l++;
while(L<l)l--,add(l);
while(r<R)r++,add(r);
while(R<r)dec(r),r--;
ans[Q[i].id]=solve(L,R);
}
for(int i=1;i<=q;i++)printf("%lld\n",ans[i]);
return 0;
}
```