【题解】CF2029I Variance Challenge

· · 题解

性质:函数 f(x)=\sum_{i=1}^n(a_i-x)^2x=\bar{a} 时取到最小值,且最小值为方差乘 n

因为总共进行 m 次操作,每次操作最多令 na_ik,所以平均数可能的取值有 n\times m 个。

枚举每一种可能的取值,设为 x_0,建立费用流模型。即令源点、汇点都与每个点各连一条流量为 1、费用为 0 的边;对于 1n 每个点,i 号点都向 i+1 号点连 m 条边,流量均为 1,第 j 条边的费用为 (a_i+j\times k-x_0)^2-(a_i+(j-1)\times k-x_0)^2,跑最小费用流。

可以使用最小子段和进行模拟,由于边的编号越大,正边的费用越大,反边的费用越小,所以正边的编号越小越好,而反边的编号则越大越好,因此最优的正边与反边一定相邻,用一个数组记录正边最小编号即可。

代码如下,具体看注释:

#include<bits/stdc++.h>
using namespace std;
const __int128 inf=1e25;
int n,m,k,a[5010],p[5010];
long long sum,b[5010];
__int128 ans[5010];
long long calc(int i,int j){
    return (b[i]+j*k)*(b[i]+j*k)-(b[i]+(j-1)*k)*(b[i]+(j-1)*k);
}
void solve(long long x){
    __int128 flow=0;
    for(int i=1;i<=n;i++)
    {
        b[i]=a[i]-x;
        flow+=b[i]*b[i];
        p[i]=0;
        //p[i] 表示第 i 个位置已选的编号最小的正边的编号 
    }
    for(int j=1;j<=m;j++)
    {
        int l,r,lst,op;
        //l、r 分别记录选取子段的左、右端点,lst 记录当前最小后缀的左端点,op 记录所取子段是正边还是反边 
        __int128 rm=inf,mi=inf;
        //rm 记录最小后缀和,mi 记录最小子段和 
        for(int i=1;i<=n;i++)
        {
            __int128 ts=calc(i,p[i]+1);
            rm=min(rm+ts,ts);
            if(rm==ts)lst=i;
            if(rm<mi)
            {
                mi=rm;
                l=lst;
                r=i;
                op=1;
            }
        }
        rm=inf;
        for(int i=1;i<=n;i++)
        {
            __int128 ts=-calc(i,p[i]);
            if(p[i]==0)ts=inf;
            //如果当前位置没取过正边,则没有反边可取 
            rm=min(rm+ts,ts);
            if(rm==ts)lst=i;
            if(rm<mi)
            {
                mi=rm;
                l=lst;
                r=i;
                op=-1;
            }
        }
        for(int i=l;i<=r;i++)
            p[i]+=op;
        //更新正边最小编号数组 
        flow+=mi;
        ans[j]=min(ans[j],flow);
    }
}
int main(){
    int t;
    scanf("%d",&t);
    while(t--)
    {
        scanf("%d%d%d",&n,&m,&k);
        k*=n;
        sum=0;
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&a[i]);
            sum+=a[i];
            a[i]*=n;
        }
        //为了避免平均值是小数,先将代入函数的所有值乘 n,则最后求出的值是方差的 n^3 倍,也就是答案的 n 倍 
        for(int j=1;j<=m;j++)
            ans[j]=inf;
        for(int i=1;i<=n*m;i++)
            solve(sum+k/n*i);
        //枚举所有可能的平均值,找到最小答案 
        for(int j=1;j<=m;j++)
            printf("%lld ",(long long)(ans[j]/n));
        printf("\n");
    }
    return 0;
}