题解 P3784 【[SDOI2017]遗忘的集合】

· · 题解

a_i=0/1表示元素i是否在集合S中,元素i的生成函数,为\displaystyle\left(\frac{1}{1-x^i}\right)^{a_i},那么f的生成函数\displaystyle F(x)=\prod_{i\geq 1}\left(\frac{1}{1-x^i}\right)^{a_i},现在我们知道了F(x),要求a_i

两边同时取负对数,得到\displaystyle-\ln F(x)=\sum_{i\geq 1}a_i\ln(1-x^i),对右边的\ln(1-x^i)泰勒展开,得到\displaystyle-\ln F(x)=\sum_{i\geq 1}a_i\sum_{j\geq 1}-\frac{x^{ij}}{j},枚举ij的值,得到\displaystyle\ln F(x)=\sum_{T\geq 1}x^T\sum_{i\mid T}a_i\times\frac{i}{T}。我们对F求一个多项式的\ln,那么就可以知道右边x^T的系数,即对于所有T知道了\frac{1}{T}\sum_{i\mid T}a_i\times i,设G(x)=\ln F(x),即有[x^T]G(x)=\frac{1}{T}\sum_{i\mid T}a_i\times i,那么只需要莫比乌斯反演即可求出a_i。最后只需要对于每个a_i\neq 0输出i即为答案。由生成函数的构造过程即可证明方案唯一。

复杂度O(n\log n),多项式求\ln

#include<bits/stdc++.h>

#define For(i,_beg,_end) for(int i=(_beg),i##end=(_end);i<=i##end;++i)
#define Rep(i,_beg,_end) for(int i=(_beg),i##end=(_end);i>=i##end;--i)

template<typename T>T Max(const T &x,const T &y){return x<y?y:x;}
template<typename T>T Min(const T &x,const T &y){return x<y?x:y;}
template<typename T>int chkmax(T &x,const T &y){return x<y?(x=y,1):0;}
template<typename T>int chkmin(T &x,const T &y){return x>y?(x=y,1):0;}
template<typename T>void read(T &x){
    T f=1;char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
    for(x=0;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
    x*=f;
}

typedef long long LL;
const int maxn=1<<19|1,N=1<<19;
const double pi=acos(-1);
struct E{
    double x,y;
    E(){}
    E(double a,double b){x=a;y=b;}
    E operator=(const int &a){x=a;y=0;return *this;}
    E conj(){return E(x,-y);}
}A[maxn],B[maxn],dfta[maxn],dftb[maxn],dftc[maxn],o[maxn];
E operator+(const E &a,const E &b){return E(a.x+b.x,a.y+b.y);}
E operator-(const E &a,const E &b){return E(a.x-b.x,a.y-b.y);}
E operator*(const E &a,const E &b){return E(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int n,p,s[maxn],a[maxn],f[maxn],g[maxn],r[maxn],tmp[maxn];
int prime[maxn],cnt,mu[maxn],ans;
bool ok[maxn];

int power(int,int);
void Init(int);
void fft(E*,int);
void mul(int*,int,int*,int,int*,int);
void ln(int*,int,int*);
void inverse(int*,int,int*);

int main(){
    read(n);read(p);
    Init(n);f[0]=1;
    For(i,1,n) read(f[i]);
    ln(f,n,g);
    For(i,1,n) g[i]=1LL*g[i]*i%p;
    For(i,1,n) for(int j=i;j<=n;j+=i) a[j]+=mu[i]*g[j/i];
    For(i,1,n) if(a[i]) ans++;
    printf("%d\n",ans);
    For(i,1,n) if(a[i]) printf("%d ",i);
    putchar(10);
    return 0;
}

void ln(int *x,int n,int *res){
    int inv[maxn],d[maxn];
    For(i,1,n) d[i-1]=1LL*i*x[i]%p;
    memset(inv,0,sizeof inv);
    inverse(x,n+1,inv);
    mul(d,n,inv,n,res,n);
    Rep(i,n,1) res[i]=1LL*res[i-1]*power(i,p-2)%p;
}
void inverse(int *x,int y,int *res){
    if(y==1)return res[0]=power(x[0],p-2),void();
    inverse(x,(y+1)>>1,res);
    mul(x,y-1,res,y-1,tmp,y-1);
    tmp[0]=2-tmp[0]+p;
    For(i,1,y-1) tmp[i]=p-tmp[i];
    mul(tmp,y-1,res,y-1,res,y-1);
}
void fft(E *a,int n){
    For(i,0,n-1) if(i<r[i]) std::swap(a[i],a[r[i]]);
    for(int i=1;i<n;i<<=1)
        for(int j=0,tmp=i<<1;j<n;j+=tmp)
            For(k,0,i-1){
                E x=a[j+k],y=o[N/i*k]*a[j+k+i];
                a[j+k]=x+y;a[j+k+i]=x-y;
            }
}
void mul(int* a,int n,int *b,int m,int *res,int k){
    int M=sqrt(p),nn=n,mm=m,L=-1;
    For(i,0,n) A[i].x=a[i]/M,A[i].y=a[i]%M;
    For(i,0,m) B[i].x=b[i]/M,B[i].y=b[i]%M;
    for(m+=n,n=1;n<=m;n<<=1) L++;
    For(i,0,n-1) r[i]=(r[i>>1]>>1)|((i&1)<<L);
    For(i,nn+1,n-1) A[i]=0;
    For(i,mm+1,n-1) B[i]=0;
    fft(A,n);fft(B,n);
    For(i,0,n-1){
        int j=(n-i)&(n-1);
        E ai=(A[i]+A[j].conj())*E(0.5,0);
        E bi=(A[i]-A[j].conj())*E(0,-0.5);
        E ci=(B[i]+B[j].conj())*E(0.5,0);
        E di=(B[i]-B[j].conj())*E(0,-0.5);
        dfta[i]=ai*ci;
        dftb[i]=ai*di+bi*ci;
        dftc[i]=bi*di;
    }
    For(i,0,n) A[i]=dfta[i]+dftb[i]*E(0,1);
    fft(A,n);fft(dftc,n);
    For(i,0,k){
        int j=(n-i)&(n-1);
        int ai=(LL)(A[j].x/n+0.5)%p;
        int bi=(LL)(A[j].y/n+0.5)%p;
        int ci=(LL)(dftc[j].x/n+0.5)%p;
        res[i]=(1LL*M*M*ai+1LL*M*bi+ci)%p;
    }
    For(i,k+1,n-1) res[i]=0;
}
int power(int x,int y){
    int res=1;
    for(;y;y>>=1,x=1LL*x*x%p) if(y&1) res=1LL*res*x%p;
    return res;
}
void Init(int n){
    mu[1]=1;
    For(i,2,n){
        if(!ok[i]) prime[++cnt]=i,mu[i]=-1;
        for(int j=1;j<=cnt&&i*prime[j]<=n;j++){
            ok[i*prime[j]]=true;
            if(i%prime[j]) mu[i*prime[j]]=-mu[i];
            else mu[i*prime[j]]=0;
        }
    }
    For(i,0,N) o[i]=E(cos(pi*i/N),sin(pi*i/N));
}