题解 P4152 【[WC2014]时空穿梭】

· · 题解

考虑枚举所有直线来计算答案,一条直线可以由第一个点以及最后一个点和第一个点的差向量确定. 确定差向量 x=(x_1,x_2,\cdots ,x_n) 之后第一个点有 (m_1-x_1)(m_2-x_2)\cdots(m_n-x_n) 种取法,令 d=\gcd(x_1,x_2,\cdots,x_n),那么该差向量上的整点只有 d+1 个,即 (x_1/d,x_2/d,\cdots,x_n/d)0d 倍. 由于已经钦定了第一个点和最后一个点,所以这条直线的贡献就是

$$ \sum_{1\leq x_i\leq m_i}\binom{\gcd\{x_i\}-1}{c-2}\prod_{i=1}^n (m_i-x_i) $$ 根据套路反演,设 $$ g(n)=\sum_{d|n}\binom{d-1}{c-2}\mu(\frac{n}{d}) $$ 对所有 $c$ 求出 $g$,这部分可以做到 $O(cm\log\log m)$. 那么就有 $$ \begin{aligned} &\sum_{1\leq x_i\leq m_i}\binom{\gcd\{x_i\}-1}{c-2}\prod_{i=1}^n (m_i-x_i)\\ =&\sum_{1\leq x_i\leq m_i}\left(\prod_{i=1}^n(m_i-x_i)\right)\sum_{d\mid x_i}g(d)\\ =&\sum_{d=1}^{\min\{m_i\}}g(d)\sum_{1\leq x_i\leq m_i/d}\prod_{i=1}^n(m_i-x_id)\\ =&\sum_{d=1}^{\min\{m_i\}}g(d)\prod_{i=1}^n\left(\sum_{x_i=1}^{m_i/d}(m_i-x_id)\right) \end{aligned} $$ 那么直接做就是 $O(Tnm)$ 的,我也不知道能不能过. 由于 $n$ 很小所以考虑把复杂度往 $n$ 上倾斜一下,考虑整除分块,对于 $\lfloor\frac{m_i}{d}\rfloor$ 都相同的一段拿出来计算. 注意到后面那个乘积是一个关于 $d$ 的 $n$ 次多项式,那么假设已经算出系数 $t_i$,带进去就得到 $$ \begin{aligned} &\sum_{d=L}^{R}g(d)\sum_{i=0}^nt_id^i\\ =&\sum_{i=0}^nt_i\sum_{d=L}^{R}g(d)d^i \end{aligned} $$ 我们可以用 $O(cnm)$ 的复杂度预处理 $g(d)d^k$ 的前缀和之后 $O(n)$ 计算一段 $d$ 的和. 段数是 $O(n\sqrt{m})$ 的所以这部分复杂度是 $O(n^2\sqrt{m})$ 的. 接下来要解决的问题是多项式系数怎么算,如果每次都暴力计算的话复杂度是 $O(n^3\sqrt{m})$,不太好. 考虑这个多项式实际上是一堆一次多项式的积,而这些一次多项式的变化次数是 $O(n\sqrt{m})$ 的. 那么我们可以直接维护积的多项式,需要乘或者除以一个一次多项式,这可以在 $O(n)$ 时间内完成,那么总复杂度 $O(n^2\sqrt{m})$. 最终的复杂度为 $O(Tn^2\sqrt{m}+cnm+cm\log\log m)

以下代码中含有大量卡常技术,可能不太能看, 轻松最优解.

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
using namespace std;
const int mod=10007,mod2=mod*mod;
const int N=100005;
struct FJ
{
    int w,id;
}fj[N];
int finalans[1005],C[20][N],prime[N],p[N],prime_cnt,T,n,m[20],inv[N],pres[20],pw[20][N],s[20][N];
int qmod1(const int &x){return x>=mod?x-mod:x;}
int qmod2(const int &x){return x+(x>>31&mod);}
void prework(int n)
{
    C[0][0]=1;
    for(int i=1;i<=n;i++)
    {
        C[0][i]=1;
        for(int j=1;j<=i&&j<=18;j++)
            C[j][i]=qmod1(C[j][i-1]+C[j-1][i-1]);
    }
    for(int j=0;j<=11;j++)pw[j][1]=1;
    for(int i=2;i<=n;i++)
    {
        if(!p[i])
        {
            prime[++prime_cnt]=i;
            pw[0][i]=1;for(int j=1;j<=11;j++)pw[j][i]=1ll*pw[j-1][i]*i%mod;
        }
        for(int j=1;j<=prime_cnt&&i*prime[j]<=n;j++)
        {
            int x=i*prime[j];p[x]=1;
            for(int k=0;k<=11;k++)pw[k][x]=pw[k][i]*pw[k][prime[j]]%mod;
            if(i%prime[j]==0)break;
        }
    }
    for(int d=0;d<=18;d++)
    {
        for(int j=n;j>=1;j--)C[d][j]=C[d][j-1];C[d][0]=0;
        for(int i=1;i<=prime_cnt;i++)
            for(int j=n/prime[i];j>=1;j--)
                C[d][j*prime[i]]=qmod2(C[d][j*prime[i]]-C[d][j]);
    }
 //   for(int i=1;i<=10;i++)cout<<C[3][i]<<" ";puts("");
    inv[1]=1;
    for(int i=2;i<mod;i++)inv[i]=(mod-mod/i)*inv[mod%i]%mod;
}
int tmp[20];
struct Poly
{
    int len,a[20];
    Poly(){len=0;a[0]=1;}
    void clear(){len=0;a[0]=1;}
    void mul(int u,int v)
    {
        ++len;a[len]=0;
        for(int i=len;i>=1;i--)a[i]=(u*a[i-1]+v*a[i])%mod;
        a[0]=v*a[0]%mod;
    }
    void div(int u,int v)
    {
        for(int i=len;i>=1;i--)
        {
            tmp[i-1]=a[i]*inv[u]%mod;
            a[i-1]=(a[i-1]-tmp[i-1]*v)%mod;
        }
        --len;
        for(int i=0;i<=len;i++)a[i]=tmp[i];
    }
    int get(int i)
    {
        return a[i];
    }
    void print()
    {
        cout<<"deg:"<<len<<endl;
        for(int i=0;i<=len;i++)cout<<a[i]<<" ";puts("");
    }
}G;
struct mypair{int a,b;}w[20];
struct Q
{
    int n,c,m[20],id;
}q[1005];
int maxn[20];
int cmp(const FJ &a,const FJ &b){return a.w<b.w;}
int cmp1(const Q &a,const Q &b){return a.c<b.c;}
int main()
{
    prework(100000);
    scanf("%d",&T);
    for(int i=1;i<=T;i++)
    {
        scanf("%d%d",&q[i].n,&q[i].c);q[i].id=i;maxn[q[i].c-2]=max(maxn[q[i].c-2],q[i].n);
        for(int j=1;j<=q[i].n;j++)scanf("%d",&q[i].m[j]);
    }
    sort(q+1,q+T+1,cmp1);
    for(int nowT=1;nowT<=T;nowT++)
    {
        n=q[nowT].n;int c=q[nowT].c-2;int M=100000,tn=0;
        for(int i=1;i<=n;i++)m[i]=q[nowT].m[i],M=min(M,m[i]);
        if(c!=q[nowT-1].c-2)
        {
            for(int k=0;k<=maxn[c];k++)
            {
                for(int i=1;i<=100000;i++)s[k][i]=C[c][i]*pw[k][i]%mod;
                for(int i=1;i<=100000;i++)s[k][i]=qmod1(s[k][i]+s[k][i-1]);
            }
        }
        int sqm=sqrt(M);
        for(int i=1;i<=n;i++)
            for(int d=sqm+1,lt;d<=M;d=lt+1)
            {
                lt=min(M,m[i]/(m[i]/d));
                fj[++tn]=(FJ){lt,i};
            }
        sort(fj+1,fj+tn+1,cmp);
 //       for(int i=1;i<=tn;i++)cout<<fj[i].w<<" "<<fj[i].id<<endl;
        fj[0].w=sqm;long long ans=0;
        for(int d=1;d<=sqm;d++)
        {
            int prod=1;
            for(int i=1;i<=n;i++)
            {
                int f=m[i]/d,w;
                w=(1ll*f*m[i]-(1ll*f*(f+1)>>1)*d)%mod;
                prod=prod*w%mod;
            }
            ans+=prod*C[c][d];
 //           cout<<ans<<endl;
        }
        for(int i=1;i<=n;i++)
        {
            int f=m[i]/(sqm+1);int c1=1ll*f*m[i]%mod,c2=(1ll*f*(f+1)>>1)%mod;
            w[i]=(mypair){c2?mod-c2:0,c1};
//            cout<<i<<" "<<c1<<" "<<c2<<endl;
        }   
        G.clear();
        for(int i=1;i<=n;i++)G.mul(w[i].a,w[i].b);
        ans%=mod;
        for(int id=1;id<=tn;)
        {
            int d=fj[id].w,pred=fj[id-1].w;

//            G.print();
            for(int j=0;j<=n;j++)
            {
 //               cout<<G.get(j)<<" "<<s[j][d]-s[j][pred]<<endl;
                ans+=G.get(j)*(s[j][d]-s[j][pred]);
            }
//            cout<<(ans%mod+mod)%mod<<endl;            
            while(id<=tn&&fj[id].w==d)
            {
                int i=fj[id].id;
                G.div(w[i].a,w[i].b);
                int f=m[i]/(d+1),c1=1ll*f*m[i]%mod,c2=(1ll*f*(f+1)>>1)%mod;
                w[i]=(mypair){c2?mod-c2:0,c1};
                G.mul(w[i].a,w[i].b);
                ++id;
            }
//            cout<<ans<<endl;
        }

        ans%=mod;ans=(ans+mod)%mod;
        finalans[q[nowT].id]=ans;
    }
    for(int i=1;i<=T;i++)printf("%d\n",finalans[i]);
}