图腾 题解
Eraine
·
2024-08-13 18:33:03
·
题解
来源:国家队选拔 2008
编号:P4528
tag:容斥,数据结构
难度:黑
这题显然不好统计分别贡献,那就让我们想到容斥做。将原式中的每一项都拆成相对好处理的部分方便后续计数。
先搬运其他题解的式子.jpg
1324-1243-1432
=(1x2x-1423)-(12xx-1234)-(14xx-1423)
=1x2x+1234-(12xx+14xx)
=1x2x+1234-(1xxx-13xx)
=1x2x+1234-1xxx+13xx
出现 x 的位置表示待确定排名的位置。
相信大家都会 1234 和 1xxx 的方法吧。前者跑 3 遍树状数组,后者也用树状数组。求出 R_i 表示在 a_i 右边且比 a_i 大的元素个数。R_i\choose 3 即为 a_i 作为 1 的贡献。
接下来难点在于 1x2x。我们考虑如下做法:枚举 2 的位置,对于右边的 x 显然很好处理,关键是左边的 1x。要求左边的 x 要比当前枚举的 2 大,但我们发现左边两个元素相对于当前枚举的 2 的大小关系互异,不太方便处理。我们能处理的要不是只有一个存在大小关系,要不是两个大小关系相同且与两个元素互相的大小关系相同。这里考虑固定 1 的大小关系,保证 1<x,1<2。冗余情况为 1<x<2,即 1234,直接减去即可。
问题变成求 a_u\lt a_v,a_u\lt a_w,u\lt v\lt w 。考虑枚举 w ,设当前 u 的权值为满足 a_u\lt a_v,u\lt v 求出当前满足 a_u<a_v,a_u<a_w 的 u 的权值和即可。那么如何求出 u 的权值呢?考虑权值线段树维护。每加入一个元素时,当作为 u 时做单点修改打标记,当作为 v 时要对值域在 [1,a_v) 的所有已标记元素做 +1 操作(这显然可以通过 u 的标记保存在所有祖先上进行区间操作),当作为 w 时统计 [1,a_w) 的权值和即可。线段树具体操作详见代码。
还有一种情况 13xx。我们充分发扬人类智慧,发现 13xx 的逆排列即为 1x2x,所以相当于将 a 转换为其逆排列再做一遍就好了。
```cpp
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e5+5;
const ll mod=16777216;
int n,a[N],b[N];
struct bitTree{
ll sub[N];
#define lowbit(i) (i&(-i))
void clear(){
for(int i=1;i<=n;i++){
sub[i]=0;
}
}
void add(int x,ll k){
for(int i=x;i<=n;i=i+lowbit(i)){
sub[i]=sub[i]+k;
}
}
ll getsum(int x){
ll res=0;
for(int i=x;i;i=i-lowbit(i)){
res=res+sub[i];
}
return res;
}
}bit;
struct node{
int ch[2];
ll sz,tag,sum;
}tr[N<<2];
int rt,idx;
struct segTree{
#define lc tr[i].ch[0]
#define rc tr[i].ch[1]
#define mid (l+r)/2
void build(int &i,int l,int r){
i=++idx;
if(l==r){
return;
}
build(lc,l,mid);
build(rc,mid+1,r);
}
void pushup(int i){
tr[i].sum=tr[lc].sum+tr[rc].sum;
}
void lazy(int i,ll val){
tr[i].tag=tr[i].tag+val;
tr[i].sum=tr[i].sum+val*tr[i].sz;
}
void pushdown(int i){
lazy(lc,tr[i].tag);
lazy(rc,tr[i].tag);
tr[i].tag=0;
}
void modify(int i,int l,int r,int x){
++tr[i].sz;
if(l==r){
return;
}
pushdown(i);
if(x<=mid){
modify(lc,l,mid,x);
}else{
modify(rc,mid+1,r,x);
}
pushup(i);
}
void update(int i,int l,int r,int L,int R){
if(L<=l&&R>=r){
lazy(i,1);
return;
}
pushdown(i);
if(L<=mid){
update(lc,l,mid,L,R);
}
if(R>mid){
update(rc,mid+1,r,L,R);
}
pushup(i);
}
ll query(int i,int l,int r,int L,int R){
if(L<=l&&R>=r){
return tr[i].sum;
}
pushdown(i);
if(R<=mid){
return query(lc,l,mid,L,R);
}else if(L>mid){
return query(rc,mid+1,r,L,R);
}else{
return query(lc,l,mid,L,R)+query(rc,mid+1,r,L,R);
}
}
#undef lc
#undef rc
#undef mid
}sgt;
ll f[N];
ll solve1234(){
for(int i=1;i<=n;i++){
f[i]=1;
}
for(int dep=2;dep<=4;dep++){
bit.clear();
for(int i=1;i<=n;i++){
bit.add(a[i],f[i]);
f[i]=bit.getsum(a[i]-1);
}
}
ll res=0;
for(int i=1;i<=n;i++){
res=(res+f[i])%mod;
}
return res;
}
ll sum1234;
ll R[N];
void solveR(){
bit.clear();
for(int i=n;i;i--){
R[i]=(n-i)-bit.getsum(a[i]);
bit.add(a[i],1);
}
}
ll solve1x2x(){
solveR();
sgt.build(rt,1,n);
ll res=0;
for(int i=1;i<=n;i++){
if(1<=a[i]-1){
res=(res+sgt.query(rt,1,n,1,a[i]-1)*R[i])%mod;
sgt.update(rt,1,n,1,a[i]-1);
}
sgt.modify(rt,1,n,a[i]);
}
return (res-sum1234+mod)%mod;
}
ll solve1xxx(){
solveR();
ll res=0;
for(int i=1;i<=n;i++){
res=(res+(R[i])*(R[i]-1)*(R[i]-2)/6)%mod;
}
return res;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
sum1234=solve1234();
ll res=sum1234;
res=(res+mod-solve1xxx())%mod;
res=(res+solve1x2x())%mod;
for(int i=1;i<=n;i++){
b[i]=a[i];
}
for(int i=1;i<=n;i++){
a[b[i]]=i;
}
res=(res+solve1x2x())%mod;
printf("%lld\n",res);
return 0;
}
```
若有疑问或错误请指出,虚心接受您的意见。