题解 P6012 【[P5087] 数学 加强版】
对于弱化版,可以 0/1 背包,时间复杂度是
对于每个数,可以选或者不选,那我们构造其生成函数为
那么要求恰好
#include<queue>
#include<vector>
#include<math.h>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define SZ(x) ((int)x.size())
using namespace std;
typedef long long ll;
typedef vector<int> poly;
const int sz=17,N=1<<17|5;
const double Pi=acos(-1);
const int P=1e9+7;
int L,rev[N];
struct Cp {
double x, y;
Cp(double a = 0, double b = 0) : x(a), y(b) {}
Cp operator + (Cp t) { return Cp(x + t.x, y + t.y); }
Cp operator - (Cp t) { return Cp(x - t.x, y - t.y); }
Cp operator * (Cp t) { return Cp(x * t.x - y * t.y, x * t.y + y * t.x); }
Cp operator / (double t) { return Cp(x / t, y / t); }
Cp operator ~ () { return Cp(x, -y); }
}w[N],_[N];
void init(){
L=1<<sz;
for(int i=0;i<L;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(sz-1));
w[L/2]=Cp(1,0);
for(int i=1;i<L/2;i++)w[i+L/2]=Cp(cos(2*Pi*i/L),sin(2*Pi*i/L));
for(int i=L/2-1;i>=0;i--)w[i]=w[i<<1];
}
void dft(int n,Cp*a){
for(int i=0,s=sz-__builtin_ctz(n);i<n;i++)_[rev[i]>>s]=a[i];
for(int l=1;l<n;l<<=1)for(int i=0;i<n;i+=l+l)for(int j=0;j<l;j++)
{Cp k=_[i+j+l]*w[j+l];_[i+j+l]=_[i+j]-k;_[i+j]=_[i+j]+k;}
memcpy(a,_,sizeof(Cp)*n);
}
struct __mtt {
int Mod;
Cp a[N], b[N], c[N], d[N];
__mtt(int _Md = 1000000007) : Mod(_Md) {}
poly operator()(const poly& p,const poly& q) {
int l;
for(l=1; l<=SZ(p)+SZ(q)-2; l<<=1);
memset(a,0,sizeof(Cp)*l);
memset(b,0,sizeof(Cp)*l);
for(int i=0; i<SZ(p); i++)
a[i]=Cp(p[i]&32767,p[i]>>15);
for(int i=0; i<SZ(q); i++)
b[i]=Cp(q[i]&32767,q[i]>>15);
dft(l,a);dft(l,b);
for(int i=0; i!=l; i++){
int j=(l-i)&(l-1);
c[i]=(~a[i]+a[j])*Cp(0.5,0)*b[j];
d[i]=(~a[i]-a[j])*Cp(0,0.5)*b[j];
}
dft(l,c);dft(l,d);
auto Z=[&](double x){return(ll)(0.5+x)%Mod;};
poly v(SZ(p)+SZ(q)-1);
for(int i=0; i<SZ(v); i++)
v[i]=(Z(c[i].x/l)+(Z(c[i].y/l)+Z(d[i].x/l))%Mod*32768+Z(d[i].y/l)%Mod*32768%Mod*32768)%Mod;
return v;
}
}mtt;
int n,k,x;
poly f[N];
poly solve(int l,int r)
{
if(l==r)return f[l];
return mtt(solve(l,(l+r)/2),solve((l+r)/2+1,r));
}
int main()
{
init();
scanf("%d%d",&n,&k);
for(int i=1; i<=n; i++)
{
scanf("%d",&x);
f[i]=(poly){1,x};
}
printf("%d",solve(1,n)[k]);
return 0;
}