[2019.2.15]LuoguP3723/BZOJ4827 [Hnoi2017]礼物(数学/FFT)

2019-02-15 16:27:38


设第一个串为 $a$ ,第二个串为 $b$ 。

不妨让我们的所有操作对 $b$ 进行。

如果我们将 $b$ 逆时针转动 $j(0\le j<n)$ 个单位,那么与 $a_i$ 对应的装饰物就是 $b_{i+j}$ 了。

注:在本文中,数组下标从1开始。对于 $b$ ,当数组下标大于数组长度,我们将其对数组长度取余。

我们对 $b$ 进行了 $j(0\le j<n)$ 次逆时针旋转,增加了 $c(-m\le c\le m)$ 的亮度,设 $ans_x$ 表示 $j=x$ 时的答案,那么

$ans_j=\sum_{i=1}^n(a_i-b_{i+j}-c)^2$

$ans_j=\sum_{i=1}^n[a_i^2+b_{i+j}^2-2c(a_i-b_{i+j})-2a_ib_{i+j}+c^2]$

$ans_j=\sum_{i=1}^na_i^2+\sum_{i=1}^nb_{i+j}^2-2c(\sum_{i=1}^na_i-\sum_{i=1}^nb_{i+j})+ nc^2-2\sum_{i=1}^na_ib_{i+j}$

$ans_j=\sum_{i=1}^na_i^2+\sum_{i=1}^nb_i^2-2c(\sum_{i=1}^na_i-\sum_{i=1}^nb_i)+nc^2-2\sum_{i=1}^na_ib_{i+j}$

发现只有最后一项与 $j$ 有关,且该项与 $c$ 无关。

考虑求 $\sum_{i=1}^na_ib_{i+j}$ 。

我们翻转 $b$ 。则原式变为

$\sum_{i=1}^na_ib_{n+j-i+1}$

发现上式是卷积的形式。

这样看不明显,但是如果我们认为当 $i>n$ , $a_i=0$ ,那原式等价于

$\sum_{i=1}^{n+j}a_ib_{n+j-i+1}$

就很明显了。

所以我们设当 $j=x$ ,原式值为 $v_{n+j}$ ,那么我们可以用 $FFT$ 求出 $v$ ,由于 $v_{n+j}$ 越大, $ans_j$ 越小,当 $0\le j<n$ , $n\le n+j<2n$ ,所以原式的最大值为 $max_{i=n}^{2n-1}v_i$ 。

我们枚举 $c$ ,即可求得其他项的答案,。

时间复杂度 $O(nm\log n)$

code:

注:方便进行 $FFT$ ,代码中下标从0开始。

code:

#include<bits/stdc++.h>
#define VAL(x) (int)(fabs(x.re)/(N*1.0)+0.5)
using namespace std;
const int INF=2e9;
const double pi=acos(-1);
struct com{
    double re,in;
}f[300010],f1[300010],f2[300010],cpy[300010];
com operator+(const com &x,const com &y){
    return (com){x.re+y.re,x.in+y.in};
}
com operator-(const com &x,const com &y){
    return (com){x.re-y.re,x.in-y.in};
}
com operator*(const com &x,const com &y){
    return (com){x.re*y.re-x.in*y.in,x.re*y.in+x.in*y.re};
}
int n,m,N=1,a,b,a1,b1,a2,b2,mx;
long long ans=INF;
void FFT(com *f,int l,int len,int op){
    if(len<2)return;
    int nl=len>>1,tmp=l-1;
    for(int i=l;i<l+len;++i)cpy[i]=f[i];
    for(int i=l;i<l+nl;++i)f[i]=cpy[++tmp],f[i+nl]=cpy[++tmp];
    FFT(f,l,nl,op),FFT(f,l+nl,nl,op);
    com w=(com){cos(pi/(1.0*nl)),sin(pi/(1.0*nl)*op)},nw=(com){1,0},t;
    for(int i=l;i<l+nl;++i,nw=nw*w)t=nw*f[i+nl],cpy[i]=f[i]+t,cpy[i+nl]=f[i]-t;
    for(int i=l;i<l+len;++i)f[i]=cpy[i];
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)scanf("%d",&a),f1[i-1]=(com){a,0},a1+=a,a2+=a*a;//len(f1)=n
    for(int i=1;i<=n;++i)scanf("%d",&b),f2[n-i]=f2[n-i+n]=(com){b,0},b1+=b,b2+=b*b;//len(f2)=2n
    //len(f)=3n-1
    while(N<3*n-1)N<<=1;
    FFT(f1,0,N,1),FFT(f2,0,N,1);
    for(int i=0;i<N;++i)f[i]=f1[i]*f2[i];
    FFT(f,0,N,-1);
    //f[x]=sum(0<=i<=x)f1[i]*f2[x-i]
    for(int i=n-1;i<2*n-1;++i)mx=max(mx,VAL(f[i]));
    for(int i=-m;i<=m;++i)ans=min(ans,1ll*n*i*i+2*i*(a1-b1));
    printf("%lld",ans+a2+b2-2*mx);
    return 0;
}