题解:CF2223B Zhily and Barknights

· · 题解

你是一名退役的 OIer,有一天你打开了 CF,突然发现今天有之前非常少见的 CF 下午场,于是你开始追忆了。

你切掉了前三题,来到了 D,发现题目的意思就是给你两个长度为 n 数组 ab,求对于 b 的所有重排顺序,中长度为 n 的数组 c_i = a_i \times b_i 的逆序对期望个数。

你开始回想你当初是怎么做这种题的,看到这种统计逆元题什么都不要想,直接撕贡献,先考虑答案的分子,也就是在所有排列里有多少个乘积后的逆序对。

就是这个地方,撕开,转化成对于 a_ia_j 有多少个排列能对应这一对 a 产生逆序对。

考虑这个东西会对答案产生什么贡献,接着撕贡献,撕成对于一组 a_ia_j 有多少个 b_kb_l 能满足 a_i \times b_k > a_j \times b_l,接着剩下的数直接排列组合算一下,剩下的数总共是 (n - 2)! 种方法,也就是说对于逆序对,能对答案产生 (n-2)! 的贡献,注意到分母是一个 n! 直接约掉变成 n \times (n - 1),其实现在这题基本做完了。

但是你发现但是这样做是 O(n ^ 3) 的,过不去。

你开始考虑优化,这个对于一组 a_ia_j 有多少个 b_kb_l 能满足 a_i \times b_k > a_j \times b_l 的东西一看就很能二分的样子,考虑将所有的 b_kb_l 存出来,接着考虑构造单调性。

将原来的式子拆一下就可以得到:

\frac{a_i}{a_j} > \frac{b_l}{b_k}

你开始考虑将每组 b_kb_l 排序之后二分,考虑该如何排序。

直接按分数来排序非常容易丢精度,考虑对于 \frac{b_j}{b_i}\frac{b_l}{b_k} 这两个分数比较大小。

就当比赛快要结束时,你突然发现这个式子可以拆开,如果 \frac{b_j}{b_i}\frac{b_l}{b_k} 大,当且仅当:

b_j \times b_k > b_i \times b_l

直接按这个排序就做完了,你飞快的码完了代码,提交通过了。

我常常追忆过去。

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<ll,ll>
#define dgd priority_queue<int>
#define xgd priority_queue<int,vector<int>,greater<int> >
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define per(i,a,b) for(int i=(a);i>=(b);i--)
#define nl cout << "\n"
const int INF=1e9;
const ll Inf=4e18;
const int abc=26;
const double pi=3.1415926;
int read(){
    int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-48;c=getchar();}
    return x*f;
}
ll n,ans,mod=998244353,a[2010],b[2010];
vector<pii > g;
bool cmp(pii a,pii b){
    return a.first*b.second<a.second*b.first;
}
ll qpow(ll a,ll b){
    ll res=1;
    while(b){
        if(b&1) res=(res*a)%mod;
        a=(a*a)%mod;
        b>>=1;
    }
    return res;
}
void solve(){
    cin >> n;
    for(int i=1;i<=n;i++) cin >> a[i];
    for(int i=1;i<=n;i++) cin >> b[i];
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            if(i==j) continue;
            g.push_back({b[i],b[j]});
        }
    }
    sort(g.begin(),g.end(),cmp);
    for(int i=1;i<=n;i++){
        for(int j=i+1;j<=n;j++){
            ll x=a[i],y=a[j];
            ll l=0,r=n*(n-1);
            while(l<r){
                ll mid=(l+r)>>1;
                pii p=g[mid];
                if(p.first*x<=p.second*y) l=mid+1;
                else r=mid;
            }
            ans+=(n*(n-1)-l);
        }
    }
    ans%=mod;
    cout << (ans*qpow(n*(n-1),mod-2))%mod;
    ans=0;
    g.clear();
    cout << "\n";
    return;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    int t=1;
    cin >> t;
    while(t--) solve();
    return 0;
}