题解 AT3956 【[AGC023E] Inversions】
题意
给出一个长度为 $n$ 的序列 $A$，求所有满足 $P_i\le A_i\ (1\le i\le n)$ 的 $1\sim n$ 的排列 $P$ 的逆序对数之和，答案对 $10^9+7$ 取模。
题解
搬运一下官方题解
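令 $cnt_i=\left(\sum_{j}[A_j\ge i]\right)-(n-i)$，即从大到小依次放置每个值时，值 $i$ 可选的位置数；若存在 $cnt_i=0$，则不存在合法排列，答案为 $0$。
首先考虑合法排列的总数：$S=\prod_{i=1}^{n}cnt_i$。
我们枚举一对位置 $i<j$。若 $A_i\le A_j$，把 $A_j$ 改成 $A_i$ 后两个位置的限制相同，$P_i>P_j$ 的方案恰好占一半，而此时 $(A_i,A_j]$ 内的每个 $cnt_k$ 都减少 $1$，所以贡献为 $\frac{S}{2}\prod_{k=A_i+1}^{A_j}\frac{cnt_k-1}{cnt_k}$；若 $A_i>A_j$，正难则反，用总数减去不构成逆序对的方案数，贡献为 $S-\frac{S}{2}\prod_{k=A_j+1}^{A_i}\frac{cnt_k-1}{cnt_k}$。
现在来考虑如何快速求和：令 $D_i=\prod_{k\le i}\frac{cnt_k-1}{cnt_k}$（跳过 $cnt_k=1$ 的零因子），则上面的连乘积等于两个 $D$ 的比值（较大值的 $D$ 除以较小值的 $D$），用树状数组按值维护 $1/D$ 的和即可；由于 $cnt_k=1$ 时该因子为 $0$，查询时只统计不小于「最后一个满足 $cnt_k=1$ 的 $k$」的那些值。时间复杂度 $O(n(\log n+\log p))$，其中 $p=10^9+7$。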
代码
想清楚了就非常好写。
#include<bits/stdc++.h>
// Shorthand aliases for common standard-library names.
// mp -> make_pair
#define mp make_pair
// pb -> push_back
#define pb push_back
// pc -> putchar; this is the output primitive used by the write helpers.
#define pc putchar
// chkmx(a,b): assigns max(a,b) to a. Note that `a` is evaluated twice.
#define chkmx(a,b) (a)=max((a),(b))
// chkmn(a,b): assigns min(a,b) to a. Note that `a` is evaluated twice.
#define chkmn(a,b) (a)=min((a),(b))
// fi / se -> the `first` / `second` member names (e.g. of a pair).
#define fi first
#define se second
using namespace std;
/// Reads one decimal integer from stdin into `x`.
///
/// Every character before the first digit is skipped; each '-' seen while
/// skipping toggles the sign. Reading stops at (and consumes) the first
/// non-digit after the number. If the input ends before any digit is found,
/// `x` is left at 0 and the function returns instead of looping forever.
template<class T>
void read(T&x){
    x=0;
    bool neg=false;
    // Keep the result of getchar() as an int: it returns EOF at end of input,
    // and isdigit() is only defined for EOF or values representable as
    // unsigned char, which a plain (possibly signed) char does not guarantee.
    int c=getchar();
    for(;c!=EOF&&!isdigit(c);c=getchar())neg^=(c=='-');
    for(;c!=EOF&&isdigit(c);c=getchar())x=x*10+(c-'0');
    if(neg)x=-x;
}
/// Reads several integers in argument order: read(a,b,c) behaves like
/// read(a); read(b); read(c).
template<class T,class ...ARK>void read(T&x,ARK&...ark){read(x);read(ark...);}
/// Prints the decimal representation of the integer `x` to stdout, with no
/// separator or newline.
///
/// Negative values are printed digit by digit without ever negating `x`
/// itself, so the most negative value of a signed type (e.g. INT_MIN) is
/// printed correctly instead of overflowing in `x=-x`.
template<class T>void write(T x){
    if(x<0){
        putchar('-');
        // For negative x, x/10 truncates toward zero and x%10 lies in [-9,0],
        // so -(x/10) always fits in T and -(x%10) is the last digit.
        const T head=x/10;
        if(head!=0)write(static_cast<T>(-head));
        putchar('0'-static_cast<int>(x%10));
        return;
    }
    if(x>=10)write(static_cast<T>(x/10));
    putchar(static_cast<int>(x%10)+'0');
}
/// Prints several integers separated by single spaces (no trailing space).
template<class T,class ...ARK>void write(T x,ARK...ark){write(x);putchar(' ');write(ark...);}
/// Prints all arguments space-separated through write(), then a newline.
template<class ...ARK>
void writeln(ARK...ark){
    write(ark...);
    putchar('\n');
}
// Alias for the widest standard signed integer used in this file.
typedef long long ll;
// lowbit(x): x & -x, i.e. the value of the lowest set bit of x (0 when x is 0).
// It is the step size of the Fenwick-tree loops in BIT.
#define lowbit(x) ((x)&-(x))
/// Prime modulus used for all modular arithmetic in this file.
const int mod=1e9+7;
/// Integer modulo `mod`. The stored residue `x` is always kept in [0, mod).
///
/// Implicit conversion from int is intentional so that expressions such as
/// `1/D[i]` or `S*count` work with plain integers.
struct mint{
    int x;
    /// Converts any int to its residue; negative or >= mod inputs are
    /// normalised so they cannot break the [0, mod) invariant.
    mint(int o=0){x=o%mod;if(x<0)x+=mod;}
    // Both operands are < mod, so the intermediate sum/difference fits in int.
    mint&operator+=(mint a){x+=a.x;if(x>=mod)x-=mod;return*this;}
    mint&operator-=(mint a){x-=a.x;if(x<0)x+=mod;return*this;}
    mint&operator*=(mint a){x=static_cast<int>(1LL*x*a.x%mod);return*this;}
    /// Exponentiation (NOT xor): raises *this to the power `b` by binary
    /// exponentiation. An exponent of 0 yields 1. A negative exponent is
    /// reduced modulo mod-1 (Fermat's little theorem, valid for non-zero
    /// bases) instead of looping forever on the sign-extending shift.
    mint&operator^=(int b){
        long long e=b;
        if(e<0)e=e%(mod-1)+(mod-1);
        mint base=*this;
        x=1;
        for(;e>0;e>>=1){
            if(e&1)*this*=base;
            base*=base;
        }
        return*this;
    }
    /// Multiplies by the modular inverse a^(mod-2). Dividing by zero is not
    /// detected: the "inverse" of 0 evaluates to 0, so the result is 0.
    mint&operator/=(mint a){return*this*=(a^=mod-2);}
    friend mint operator+(mint a,mint b){return a+=b;}
    friend mint operator-(mint a,mint b){return a-=b;}
    friend mint operator*(mint a,mint b){return a*=b;}
    friend mint operator/(mint a,mint b){return a/=b;}
};
// Capacity of every per-index array. main reads cnt[n+1], so n+1 must stay
// below N (the input bound is presumably 2e5 — confirm against the problem).
const int N=2e5+100;
// n      : number of elements, read from input.
// a[i]   : the i-th input value (1-based).
// cnt[v] : first a histogram of values, then its suffix sum, then reduced by
//          (n-v); main prints 0 and stops as soon as one entry is 0.
// x[v]   : how many indices k <= v have cnt[k]==1.
// lst[v] : the largest k <= v with cnt[k]==1, or 0 if there is none.
int n,a[N],cnt[N],x[N],lst[N];
// D[v] : running product over k <= v of (cnt[k]-1)/cnt[k], where an index with
//        cnt[k]==1 contributes only 1/cnt[k] (its zero numerator is skipped).
// S    : product of all cnt[k].
mint D[N],S;
/// Fenwick tree (binary indexed tree) over indices 1..n, where `n` is the
/// file-level element count. Supports point updates and prefix/range sums
/// in the value type T (T must support += and -, and T() must be zero).
template<class T>struct BIT{
    T tree[N];
    /// Resets every slot to T(). std::fill is used instead of memset so the
    /// reset is well-defined for class types such as mint.
    void clear(){std::fill(tree,tree+N,T());}
    /// Adds `val` at index x (1-based). Indices below 1 are ignored, because
    /// lowbit(0)==0 would never advance the loop and negative indices would
    /// address memory before the array.
    void add(int x,T val){
        if(x<1)return;
        for(;x<=n;x+=lowbit(x))tree[x]+=val;
    }
    /// Sum of positions 1..x; an empty prefix (x <= 0) yields T().
    T qry(int x){
        T res=T();
        for(;x>0;x-=lowbit(x))res+=tree[x];
        return res;
    }
    /// Sum of positions l..r, computed as prefix(r) - prefix(l-1). l may be 0
    /// (or smaller): the lower prefix is then empty.
    T qry(int l,int r){return qry(r)-qry(l-1);}
};
// ct: per-value counter of already-processed elements (used in main's second pass).
BIT<int>ct;
// sm: per-value sum of 1/D[value] over already-processed elements (both passes of main).
BIT<mint>sm;
signed main(){
read(n);
for(int i=1;i<=n;i++)read(a[i]),cnt[a[i]]++;
for(int i=n;i;i--)cnt[i]+=cnt[i+1];
for(int i=1;i<=n;i++)cnt[i]-=n-i,cnt[i]==0&&(puts("0"),exit(0),1);
D[0]=1,x[0]=0;S=1;
for(int i=1;i<=n;i++){
D[i]=D[i-1]/cnt[i];S*=cnt[i];
if(cnt[i]==1)x[i]=x[i-1]+1;
else x[i]=x[i-1],D[i]*=cnt[i]-1;
//writeln(i,cnt[i],D[i].x,x[i]);
}
for(int i=1;i<=n;i++)
if(x[i]==x[i-1])lst[i]=lst[i-1];
else lst[i]=i;
mint ans=0;
//考虑 i<j,Ai<=Aj 的情况
//贡献为 S/2 D[Aj]/D[Ai]
for(int j=1;j<=n;j++){
ans+=S/2*D[a[j]]*sm.qry(lst[a[j]],a[j]);
sm.add(a[j],1/D[a[j]]);
}
//再考虑 i<j,Ai>Aj 的情况
//正难则反考虑不满足就 S - S/2 D[Ai]/D[Aj]
sm.clear();
for(int i=n;i;i--){
ans+=S*ct.qry(1,a[i]-1)-S/2*D[a[i]]*sm.qry(lst[a[i]],a[i]-1);
sm.add(a[i],1/D[a[i]]);ct.add(a[i],1);
}
write(ans.x);
}