「CEOI2021」Diversity 题解

· · 题解

发现这个总多样性很丑陋,我们用柿子来表示。

可以发现一种元素 i 产生的贡献为

\dfrac{n\times (n+1)}{2}-\sum_{j=1}^{cnt_i}\dfrac{f_{i,j}(f_{i,j}+1)}{2} 于是可以得到总多样性 $$ S=\sum_{i=1}^K(\dfrac{n\times (n+1)}{2}-\sum_{j=1}^{cnt_i}\dfrac{f_{i,j}(f_{i,j}+1)}{2}) $$ 其中 $K$ 表示不同元素个数,$n$ 表示区间长度。 又注意到 $n-c_i=\sum_{i=1}^{cnt_i}f_{i,j}$,$c_i$ 为元素 $i$ 的出现次数。 于是拆拆拆,就可以得到: $$ S=\dfrac{1}{2}(KN(N+1)-N(K-1)-\sum_{i=1}^K\sum_{j=1}^{cnt_i}f_{i,j}^2) $$ 于是任务为最大化 $\sum_{i=1}^K\sum_{j=1}^{cnt_i}f_{i,j}^2$。 然后就可以猜结论了。 ## 结论 1 在一个最优解中,重排后的序列每一种数字一定都是连续的排在一起的,不会分开这里放一个那里放一个。 这个结论挺简单,证明略。 ## 结论 2 知道了前面那个东西还是不太能做,考虑一下应该怎么摆放,显然分析到现在数字的大小并没有什么意义,而是在于每种数字的个数。 然后经过一顿猜,发现最优策略是把每种数字的个数从小到大排序,然后先第一种数字都塞头,再把第二种数字都塞尾,第三种数字都塞头,交错着塞。 感性理解一下,这样子分配会很均匀,所以不会劣。(如果要看证明的请去官方题解看那 3 页数学公式。。。) 猜到这个东西就可以 $\mathcal{O(nq\log n)}$ 了。64 pts 到手。 考虑如何维护这个东西,发现可以把每种数字的个数都塞进一个桶里,形式化的来说,设 $t_i$ 表示有 $t_i$ 种数字的个数都为 $i$。 惊奇的发现这个桶里有值的地方最多 $\mathcal{O(\sqrt{n})}$ 个,因为最坏情况为 $1+2+3+4+...+L\le n$,$L\le \sqrt{n}$。 如果我们知道桶里那些地方有值,该如何计算贡献? 假设左边已经塞了 $l$ 个数字右边塞了 $r$ 个数字,当前左边还要塞 $lput$ 种,右边还要塞 $rput$ 种,每种个数为 $len$ 的数字。 可以左边塞的这些数字造成的左边空隙的贡献之和为: $$ \sum_{i=0}^{lput-1}(i\times len+l)^2 $$ 这个东西很好算,拆开就是自然数 $1$ 次幂和,$2$ 次幂和。 其他的 $3$ 种情况可以仿照上面这种情况计算。 剩下的问题就是如何维护这个桶了,考虑莫队,开个 `set` 暴力维护桶里哪些地方有值,算贡献的时候直接把有值的地方拉出来就行了,但是这样子复杂度 $\mathcal{O}(q\sqrt{n}\log n)$,极其的不优美,而且由于洛谷的 120s 限制开不到 7s,无法通过。 考虑一下转移的时候 $\mathcal{O}(1)$ 转移,可以开一个操作栈记录下所有的转移(插入一个数字/删除一个数字),把数字拉出来的时候就看最后一次操作到底是插入还是删除(显然只有最后一次操作才有影响),但是拉出来的数字可能无序,需要排序,注意到拉出来的数字最多会有 $\mathcal{O}(\sqrt{n})$ 个,于是复杂度就只有 $\mathcal{O}(q\sqrt{n}+q\sqrt{n\log n})$,可以通过此题,如果你想 $\mathcal{O}(q\sqrt{n})$,请使用基数排序。。。 丑陋的代码 ```cpp #include <cstdio> #include <iostream> #include <stack> #include <queue> #include <set> #include <algorithm> #include <cmath> using namespace std; #define LL long long inline int rd(){ char c; int x=0,f=1; for(;!isdigit(c);c=getchar())if(c=='-')f=-1; for(;isdigit(c);c=getchar())x=x*10+c-'0'; return x*f; } #define mp make_pair #define pp pair<int,int> int n,q,B,bl[300005],tong[300005],col[300005],len,c[2][300005],a[300005],K,oplen,num[1200005],Num[300005]; bool ok[300005]; pp opr[1200005]; LL ans[50005]; struct que{ int l,r,id; }Q[50005]; bool cmp(que a,que b) { return (bl[a.l]^bl[b.l])?bl[a.l]<bl[b.l]:((bl[a.l]&1)?a.r<b.r:a.r>b.r); } set<int>S; set<int>::iterator it; void add(int x){ tong[col[a[x]]]--; if(!tong[col[a[x]]])opr[++oplen]=mp(-1,col[a[x]]); col[a[x]]++; if(col[a[x]]==1)K++; tong[col[a[x]]]++; if(tong[col[a[x]]]==1)opr[++oplen]=mp(1,col[a[x]]); } void dec(int x){ tong[col[a[x]]]--; if(!tong[col[a[x]]])opr[++oplen]=mp(-1,col[a[x]]); col[a[x]]--; if(col[a[x]]==0)K--; tong[col[a[x]]]++; if(tong[col[a[x]]]==1)opr[++oplen]=mp(1,col[a[x]]); } LL Sum1(int n){ return 1ll*n*(n+1)/2; } LL Sum2(int n){ return 1ll*n*(n+1)*(2*n+1)/6; } //0 1 2 LL calc1(int N,int c,int len){ return 1ll*c*c*(N+1)+1ll*Sum2(N)*len*len+2ll*Sum1(N)*len*c; } LL calc2(int N,int c,int len){ return 1ll*c*c*(N+1)+1ll*Sum2(N)*len*len-2ll*Sum1(N)*len*c; } //i^2len^2+2*i*lenl LL solve(int L,int R){ int nlen=0; for(int i=oplen;i>=1;i--){ if(ok[opr[i].second])continue; ok[opr[i].second]=1; if(opr[i].first==1)Num[++nlen]=opr[i].second; } for(int i=1;i<=len;i++){ if(ok[num[i]])continue; ok[num[i]]=1; Num[++nlen]=num[i]; } for(int i=oplen;i>=1;i--)ok[opr[i].second]=0; for(int i=1;i<=len;i++)ok[num[i]]=0; len=nlen; for(int i=1;i<=len;i++)num[i]=Num[i]; sort(num+1,num+len+1); for(int i=1;i<=len;i++){ c[0][i]=num[i]; c[1][i]=tong[num[i]]; } //表示长度为 (*it) 的有这么多个 int l=0,r=0,tot=0,N=R-L+1; LL Ans=0; for(int i=1;i<=len;i++){ int lput=0,rput=0,Len=c[0][i]; if(tot%2==0){ lput=(c[1][i]+1)/2; rput=c[1][i]/2; }else{ rput=(c[1][i]+1)/2; lput=c[1][i]/2; } Ans+=calc1(lput-1,l,Len); Ans+=calc2(lput-1,N-l-Len,Len); Ans+=calc1(rput-1,r,Len); Ans+=calc2(rput-1,N-r-Len,Len); l+=lput*Len,r+=rput*Len; tot+=c[1][i]; } return 1ll*(1ll*K*N*(N+1)-1ll*N*(K-1)-Ans)/2; } int main() { scanf("%d %d",&n,&q); B=sqrt(n*1.0); for(int i=1;i<=n;i++)scanf("%d",&a[i]); for(int i=1;i<=n;i++)bl[i]=(i-1)/B+1; for(int i=1;i<=q;i++)scanf("%d %d",&Q[i].l,&Q[i].r),Q[i].id=i; sort(Q+1,Q+q+1,cmp); int l=1,r=0; for(int i=1;i<=q;i++){ int L=Q[i].l,R=Q[i].r; oplen=0; while(l<L)dec(l),l++; while(L<l)l--,add(l); while(r<R)r++,add(r); while(R<r)dec(r),r--; ans[Q[i].id]=solve(L,R); } for(int i=1;i<=q;i++)printf("%lld\n",ans[i]); return 0; } ```