题解 P6012 【[P5087] 数学 加强版】

· · 题解

对于弱化版,可以 0/1 背包,时间复杂度是 O(nk) 的。

对于每个数,可以选或者不选,那我们构造其生成函数为 1+a_ix,表示这个人选了就可以为总乘积产生 a_i 的贡献。

那么要求恰好 k 个,答案就是 [x^k]\prod_{i=1}^n (1+a_ix),用分治 fft 计算即可,这个题中需要使用拆系数 fft。

#include<queue>
#include<vector>
#include<math.h>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define SZ(x) ((int)x.size())
using namespace std;
typedef long long ll;
typedef vector<int> poly;
const int sz=17,N=1<<17|5;
const double Pi=acos(-1);

const int P=1e9+7;

int L,rev[N];
struct Cp {
    double x, y;
    Cp(double a = 0, double b = 0) : x(a), y(b) {}
    Cp operator + (Cp t) { return Cp(x + t.x, y + t.y); }
    Cp operator - (Cp t) { return Cp(x - t.x, y - t.y); }
    Cp operator * (Cp t) { return Cp(x * t.x - y * t.y, x * t.y + y * t.x); }
    Cp operator / (double t) { return Cp(x / t, y / t); }
    Cp operator ~ () { return Cp(x, -y); }
}w[N],_[N];

void init(){
    L=1<<sz;
    for(int i=0;i<L;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(sz-1));
    w[L/2]=Cp(1,0);
    for(int i=1;i<L/2;i++)w[i+L/2]=Cp(cos(2*Pi*i/L),sin(2*Pi*i/L));
    for(int i=L/2-1;i>=0;i--)w[i]=w[i<<1];
}

void dft(int n,Cp*a){
    for(int i=0,s=sz-__builtin_ctz(n);i<n;i++)_[rev[i]>>s]=a[i];
    for(int l=1;l<n;l<<=1)for(int i=0;i<n;i+=l+l)for(int j=0;j<l;j++)
    {Cp k=_[i+j+l]*w[j+l];_[i+j+l]=_[i+j]-k;_[i+j]=_[i+j]+k;}
    memcpy(a,_,sizeof(Cp)*n);
}

struct __mtt {
    int Mod;
    Cp a[N], b[N], c[N], d[N];
    __mtt(int _Md = 1000000007) : Mod(_Md) {}
    poly operator()(const poly& p,const poly& q) {
        int l;
        for(l=1; l<=SZ(p)+SZ(q)-2; l<<=1);
        memset(a,0,sizeof(Cp)*l);
        memset(b,0,sizeof(Cp)*l);
        for(int i=0; i<SZ(p); i++)
            a[i]=Cp(p[i]&32767,p[i]>>15);
        for(int i=0; i<SZ(q); i++)
            b[i]=Cp(q[i]&32767,q[i]>>15);
        dft(l,a);dft(l,b);
        for(int i=0; i!=l; i++){
            int j=(l-i)&(l-1);
            c[i]=(~a[i]+a[j])*Cp(0.5,0)*b[j];
            d[i]=(~a[i]-a[j])*Cp(0,0.5)*b[j];
        }
        dft(l,c);dft(l,d);
        auto Z=[&](double x){return(ll)(0.5+x)%Mod;};
        poly v(SZ(p)+SZ(q)-1);
        for(int i=0; i<SZ(v); i++)
            v[i]=(Z(c[i].x/l)+(Z(c[i].y/l)+Z(d[i].x/l))%Mod*32768+Z(d[i].y/l)%Mod*32768%Mod*32768)%Mod;
        return v;
    }
}mtt;

int n,k,x;
poly f[N];

poly solve(int l,int r)
{
    if(l==r)return f[l];
    return mtt(solve(l,(l+r)/2),solve((l+r)/2+1,r));
}

int main()
{
    init();
    scanf("%d%d",&n,&k);
    for(int i=1; i<=n; i++)
    {
        scanf("%d",&x);
        f[i]=(poly){1,x};
    }
    printf("%d",solve(1,n)[k]);
    return 0;
}