题解 P7468 【[NOI Online 2021 提高组] 愤怒的小N】

· · 题解

另一种推式子方法。

题目大意

给定 nk-1 次多项式 F(x)=\sum_{i=0}^{k-1}a_ix^i ,求 \sum_{i=0}^{n-1} F(i)[popcnt(i)~mod~2=1] 的值,答案取模 10^9+7

n\le 2^{500000},1\le k\le 500,0\le a_i<10^9+7

解题思路

m=n-1

p_i=(-1)^{popcnt(i)} ,则 Ans=\dfrac 1 2(\sum_{i=0}^m F(i)+\sum_{i=0}^m p_iF(i))

左侧式子可以拉格朗日插值或斯特林数 O(k^2) 解决。考虑如何计算右侧式子。

考虑将 2i2i+1 进行配对。这两个数二进制中 1 的个数的奇偶性相反,即 p_{2i}=-p_{2i+1} ,且 p_{i}=p_{2i} 。那么有

\sum_{i=0}^m p_iF(i)=\sum_{i=0}^{\frac m 2} p_i(F(2i)-F(2i+1))-p_{m+1}F(m+1)[m\text{为偶数}]

发现一个奇妙的性质, F_1(x)=F(2x)-F(2x+1) 的次数正好是 k-2 !将 F_1(x) 作为新的 F(x) ,就可以带入式子继续递归下去,而不超过 k 次,这个多项式就会变成 0

根据二项式定理展开,我们可以预处理 2 的次幂,阶乘和阶乘逆元,在 O(k^3) 的时间复杂度内算出所有这些多项式。而每次 m 也就是 n-1 舍弃最末若干位的值,在 m 为偶数时,将一个值代入多项式计算并累加进答案即可。

需要特别注意, \log_2 n\le k 时,需要在前面补 0k 位,不然会出现问题。

时间复杂度 O(n+k^3) 。可以用多项式科技优化新的多项式的计算,使时间复杂度达到 O(n+k^2\log n)

参考代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=500010;
const int mod=1000000007;
#define int long long
char n[maxn];
int ksm(int a,int k)
{
    int x=a,ret=1;
    while(k)
    {
        if(k&1)ret=ret*x%mod;
        x=x*x%mod;k>>=1;
    }
    return ret;
}
int l,k,a[maxn],b[maxn],pww[maxn],frac[maxn],invf[maxn],val[maxn],pre[maxn],x[maxn],y[maxn],ans;
int F(int x){int ret=0,nww=1;for(int i=0;i<k;i++)ret=(ret+a[i]*nww)%mod,nww=nww*x%mod;return ret;}
int C(int n,int m){return frac[n]*invf[m]%mod*invf[n-m]%mod;}
signed main()
{
    scanf("%s%lld",n+1,&k);l=strlen(n+1);
    if(l<=k+1){for(int i=l;i>=1;i--)n[k+1-l+i]=n[i];for(int i=k+1-l;i>=0;i--)n[i]='0';l=k+1;}
    //printf("%s\n",n+1);
    for(int i=l;i>=1;i--)if(n[i]=='0')n[i]='1';else{n[i]='0';break;}
    for(int i=0;i<k;i++)scanf("%lld",a+i);
    for(int i=1;i<=l;i++)val[i]=(val[i-1]*2+n[i]-'0')%mod,pre[i]=pre[i-1]+n[i]-'0';
    pww[0]=frac[0]=invf[0]=1;
    for(int i=1;i<=k;i++)pww[i]=pww[i-1]*2%mod,frac[i]=frac[i-1]*i%mod,invf[i]=ksm(frac[i],mod-2);
    for(int i=0;i<=k+1;i++)
    {
        x[i]=i;
        if(i)y[i]=(y[i-1]+F(i))%mod;
        else y[i]=F(i);
    }
    if(val[l]<=k+1)ans=y[val[l]];
    else
    {
        for(int i=0;i<=k+1;i++)
        {
            int prod=y[i],divi=1;
            for(int j=0;j<=k+1;j++)if(i!=j)prod=prod*(val[l]-x[j]+mod)%mod,divi=divi*(x[i]-x[j]+mod)%mod;
            ans=(ans+prod*ksm(divi,mod-2)%mod)%mod;
        }
    }
    //printf("%lld\n",ans);
    for(int i=k-1,j=l;i>=0&&j>=1;i--,j--)
    {
        if(n[j]=='0')
        {
            if(pre[j]&1)ans=(ans+F(val[j]+1))%mod;
            else ans=(ans+mod-F(val[j]+1))%mod;
        }
        for(int t=0;t<=k;t++)b[t]=0;
        for(int k=0;k<=i;k++)for(int t=0;t<k;t++)
        {
            b[t]=(b[t]+mod-a[k]*pww[t]%mod*C(k,t)%mod)%mod;
        }
        for(int t=0;t<=k;t++)a[t]=b[t];
        //for(int t=0;t<=k;t++)printf("%d ",a[t]);printf("\n");
    }
    printf("%lld\n",ans*ksm(2,mod-2)%mod);
    return 0;
}