#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#define int long long
#define mod 998244353
using namespace std;
int n,a[1005],mx,mn=2e9;
int f[1005][2];
inline int pw(int a,int p){
if (p==0)return 1;
int t=pw(a,p/2);
t=t*t%mod;
if (p%2==1)t=t*a%mod;
return t;
}
inline int inv(int a){
return pw(a,mod-2);
}
inline int dfs(int now,int flag){
if (a[now]==mn)return 0;
if (f[now][flag]!=-1)return f[now][flag];
int tot=0,s=0;
for (int i=1;i<=n;i++){
if (i==now)continue;
if (a[i]<a[now])tot++;
else if (a[i]==a[now]&&flag==0)tot++;
}
int _inv=inv(tot);
for (int i=1;i<=n;i++){
if (i==now)continue;
if (a[i]<a[now])s+=_inv*(dfs(i,0)+abs(now-i)),s%=mod;
else if (a[i]==a[now]&&flag==0)s+=_inv*(dfs(i,1)+abs(now-i)),s%=mod;
}
return f[now][flag]=s;
}
signed main(){
cin>>n;
for (int i=1;i<=n;i++)cin>>a[i],mx=max(mx,a[i]),mn=min(mn,a[i]);
memset(f,-1,sizeof(f));
for (int i=1;i<=n;i++){
if (a[i]==mx){
cout<<dfs(i,0)<<endl;
return 0;
}
}
return 0;
}
然后考虑满分dp,先按高度排序,然后相同高度一起转移。
有两种转移情况:相同高度与不同高度。
设 g_{i} 表示由不同高度转移,f_{i} 表示两种转移均可。
第二种转移情况则由 $g_{i}$ 之和 与 当前点到相同高度点的距离和求出(使用容斥原理算出:用当前点到所有点的距离和减去当前点到之前点的距离和)。
$f_{i}$ 直接用两种转移乘上相应概率求出。
对于距离和的求法,可以参考[这题](https://www.luogu.com.cn/problem/P2345)。
详见代码:
```cpp
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
#define mod 998244353
using namespace std;
int n,g[500005],f[500005],dis[500005];
int ans,sum;
struct node{
int pos,h;
}a[500005];
bool cmp(node a,node b){
return a.h<b.h;
}
inline int pw(register int a,register int p){
if (p==0)return 1;
register int t=pw(a,p/2);
t=t*t%mod;
if (p%2==1)t=t*a%mod;
return t;
}
inline int inv(register int a){
return pw(a,mod-2);
}
inline void getmod(register int &a){
a=(a%mod+mod)%mod;
return;
}
struct BIT{
int tree[1000005];
inline void mem(){
memset(tree,0,sizeof(tree));
return;
}
inline int lowbit(register int &x){
return x&(-x);
}
inline int query(register int pos){
int ans=0;
while(pos)ans+=tree[pos],getmod(ans),pos-=lowbit(pos);
return ans;
}
inline void add(register int pos,register int val){
while(pos<=n)tree[pos]+=val,getmod(tree[pos]),pos+=lowbit(pos);
return;
}
}bit1,bit2;
inline void add(int pos){
bit1.add(pos,1);
bit2.add(pos,pos);
return;
}
inline void del(int pos){
bit1.add(pos,-1);
bit2.add(pos,-pos);
}
inline int ask(int pos){
register int ans=bit1.query(pos-1)*pos-bit2.query(pos-1);
getmod(ans);
ans+=(bit2.query(n)-bit2.query(pos))-(bit1.query(n)-bit1.query(pos))*pos;
getmod(ans);
return ans;
}
signed main(){
cin>>n;
for (int i=1;i<=n;i++)scanf("%lld",&a[i].h),a[i].pos=i;
sort(a+1,a+1+n,cmp);
bit1.mem();
bit2.mem();
for (int i=1;i<=n;){
register int j=i,nowsum=0;
while(a[j].h==a[i].h)j++;
for (register int k=i;k<j;k++){
dis[k]=ask(a[k].pos);
getmod(dis[k]);
g[k]=sum+dis[k];
getmod(g[k]);
g[k]*=inv(i-1);
getmod(g[k]);
nowsum+=g[k];
getmod(nowsum);
}
for (register int k=i;k<j;k++)add(a[k].pos);
for (register int k=i;k<j;k++){
register int now=nowsum+ask(a[k].pos)-dis[k]-g[k];
now*=inv(j-i-1);
getmod(now);
f[k]=(inv(j-2)*(i-1)%mod)*g[k]%mod+(inv(j-2)*(j-i-1)%mod)*now%mod;
getmod(f[k]);
}
for (register int k=i;k<j;k++){
sum+=f[k];
getmod(sum);
}
i=j;
}
cout<<f[n]<<endl;
return 0;
}
```
考场上JK用严格线性的算法吊打了std...