「LAOI-12」Calculate 题解

· · 题解

首先不难发现将 a 重排后不影响答案,所以我们先从小到大排列。

对于这个权值和有一个贪心策略:这个序列必然是一小一大的,具体证明的话可以任意交换,交换后一定更劣。

那么我们便可以对每个 a_i,a_j 组合进行计数,由于序列已排序,那么便是在 i 左边有 x 个数,j 左边有 y 个数,那么先假设左右选的数相同,方案数即为 \sum\limits_{i=0}^{\min(a,b)}\binom{a}{i}\binom{b}{i}=\binom{a+b}{\min(a,b)},上述式子可以用范德蒙德卷积证明,然后中间再选偶数个数乘起来。显然左右选的个数可以不对等,右边可以多一个,那方案便是 \sum\limits_{i=0}^{\min(a,b-1)}\binom{a}{i}\binom{b}{i+1}=\binom{a+b}{a+1},然后中间再选奇数个数乘起来,故时间复杂度 O(n^2+n\log p)

#include<bits/stdc++.h>
#include<ext/rope>
using namespace __gnu_cxx;
#define mp make_pair
#define pb push_back
#define dbg puts("-------------qaqaqaqaqaqaqaqaqaq------------")
#define pl (p<<1)
#define pr ((p<<1)|1)
#define pii pair<int,int>
#define int long long
#define mod 998244353
#define eps 1e-9
#define ent putchar('\n')
#define sp putchar(' ')
using namespace std;
inline int read(){
    char c=getchar(),f=0;int t=0;
    for(;c<'0'||c>'9';c=getchar()) if(!(c^45)) f=1;
    for(;c>='0'&&c<='9';c=getchar()) t=(t<<1)+(t<<3)+(c^48);
    return f?-t:t;
}
inline void write(int x){
    static int t[25];int tp=0;
    if(x==0) return(void)(putchar('0'));else if(x<0) putchar('-'),x=-x;
    while(x) t[tp++]=x%10,x/=10;
    while(tp--) putchar(t[tp]+48);
}
int a[5009],p2[10009];
int f[10009],inv[10009];
int qpow(int x,int y){
    int res=1;
    while(y){
        if(y&1) res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void init(){
    f[0]=inv[0]=1;
    p2[0]=1;
    for(int i=1;i<=10000;i++){
        p2[i]=p2[i-1]*2%mod;
        f[i]=f[i-1]*i%mod;
        inv[i]=qpow(f[i],mod-2);
    }
}
int calc(int n,int m){
    if(n==m||m==0) return 1;
    if(n<m) return 0;
    return f[n]*inv[n-m]%mod*inv[m]%mod;
}
int get(int n){return p2[n-1];}
signed main(){
    init();
    for(int task=0;task<=0;task++){
        char awa[30];
        //sprintf(awa,"calc%d.in",task);
        //freopen(awa,"r",stdin);
        //sprintf(awa,"calc%d.out",task);
        //freopen(awa,"w",stdout);
        int n=read();
        for(int i=1;i<=n;i++){
            a[i]=read();
        }
        int ans=0;
        sort(a+1,a+n+1);
        for(int i=1;i<=n;i++){
            for(int j=i+1;j<=n;j++){
                int las=ans;
                int mid=j-i-1,res=(a[j]-a[i])*(a[j]-a[i])%mod;
                int lsum=i-1,rsum=n-j;
                if(mid==0){
                    ans=(ans+res*(calc(lsum+rsum,min(lsum,rsum)))%mod)%mod;
                    //write(i),sp,write(j),sp,write(ans-las),dbg; 
                    continue;
                }
                int temp=calc(lsum+rsum,min(lsum,rsum))*get(mid)%mod;
                //if(rsum!=0) temp=(temp+get(mid)*calc(lsum+rsum,lsum+1)%mod)%mod;
                if(lsum!=0) temp=(temp+get(mid)*calc(lsum+rsum,rsum+1)%mod)%mod;
                //res=res*temp%mod;
                ans=(ans+(temp)*res%mod)%mod;
                //write(i),sp,write(j),sp,write(temp+1),sp,write(ans-las),dbg;
            }
        }
        write(ans),ent;
    }
    return 0;
}