P8442(LMOI R1 F) 详解

· · 题解

题目链接

今天比赛快结束时,听说有人阿克了这场,就进来看了看 F,这不是典题吗,,

首先这个模数是个合数,可以分解为 24000001\times 30000001,分别设为 p_1p_2。那么只要分别对 p_1p_2 求解后用 crt 合并即可。

可以发现给出的 n 除了 13 都可以整除 p_1-1p_2-1,这些情况下 \omega_n 在模 p_1p_2 下都存在,直接用

\frac 1n\sum_{i=0}^{n-1}(\omega_n^i+\omega_n^{-i})^m \ \ \texttt{(1)}

计算,时间复杂度是 \Theta(n \log p) 的。

还有另外一种做法是进行二项式展开,单位根反演即得

\sum_{k=0}^m \binom mk \frac 1n\sum_{i=0}^{n-1}\omega_n^{(2k-m)i}=\sum_{k=0}^m\binom mk[n|(2k-m)] \ \ \texttt{(2)}

用这种方法计算一组数据的时间复杂度是 \Theta(p/n),需要 \Theta(p) 预处理阶乘。

综合一下这两种算法,解方程 n \log p = p/n 可以得到 n=1000 时为临界数据,小于 1000 的用方法 \texttt{(1)},大于 1000 的用方法 \texttt{(2)} 即可。

对于 n=13 这个特别数据可以发现,在模 p=30000001 意义下有:

(x+x^{12})^p \equiv x^5+x^8 \pmod{(x^n-1)} (x^5+x^8)^p\equiv 2+x^2+x^{11}\equiv (x+x^{12})^2 \pmod{(x^n-1)}

故直接将 mp^2-1 取模即可。特别地,对于 p=24000001 时循环节长度只有 p-1

最后就是常数优化,主要有五点:

  1. 对于题目中限定的 n \neq 13 的数据,m 为奇数时答案为零。

  2. 由单位根的性质,若 m 为偶数,在利用方法 \texttt{(1)}i 只用取到一半,之后取二倍即可。

  3. 对于 n=1000 的情况,若使用方法 \texttt{(1)},可以对快速幂做优化,做预处理。这里我选用 8 进制快速幂。

  4. 在使用方法 \texttt{(2)} 时,可以发现答案也是对称的,也可以只算一半。

  5. 使用无符号整数运算,效率会明显提升。

参考代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define ll unsigned long long
#define int unsigned int
#define N 30000003
#define P 720000054000001
#define i128 __int128_t
using namespace std;

const int pr[2] = {24000001,30000001};
const int g[2] = {23,14};
int ifac[N],sw[503][9][8];
int p,inv1k;

inline int power(int a,int t,int m){
    int res = 1;
    while(t){
        if(t&1) res = (ll)res*a%m;
        a = (ll)a*a%m;
        t >>= 1;
    }
    return res;
}

void init(int id){
    inv1k = power(1000,p-2,p);
    ifac[0] = 1,ifac[p-1] = p-1;
    for(int i=p-2;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
    int r = power(g[id],(p-1)/1000,p),w = 1,iw = 1;
    int ir = power(r,p-2,p);
    for(int i=0;i<500;++i){ 
        sw[i][0][0] = 1,sw[i][0][1] = w+iw;
        for(int k=2;k<8;++k) sw[i][0][k] = (ll)sw[i][0][k-1]*sw[i][0][1]%p;
        for(int j=1;j<9;++j){
            sw[i][j][0] = 1,sw[i][j][1] =  (ll)sw[i][j-1][7]*sw[i][j-1][1]%p;
            for(int k=2;k<8;++k) sw[i][j][k] = (ll)sw[i][j][k-1]*sw[i][j][1]%p;
        }
        w = (ll)w*r%p,iw = (ll)iw*ir%p;
    }
}

inline void multiply(const int *f,const int *g,int n,int *r){
    static int h[28];
    memset(h,0,sizeof(h));
    for(int i=0;i<n;++i)
    for(int j=0;j<n;++j)
        h[i+j] = (h[i+j]+(ll)f[i]*g[j])%p;
    for(int i=0;i<n;++i) r[i] = (h[i]+h[i+n])%p;
}

inline int solve(int n,ll m,int rt){
    int pw,t;
    ll res = 0;
    if(n==13){
        static int f[28],g[28];
        memset(f,0,sizeof(f));
        memset(g,0,sizeof(g));
        f[1] = f[n-1] = g[0] = 1;
        while(1){
            if(m&1) multiply(f,g,n,g);
            m >>= 1;
            if(m==0) break;
            multiply(f,f,n,f);
        }
        return g[0];
    }
    if(m&1) return 0;
    if(n==1000){
        for(int i=0;i<500;++i){
            pw = 1,t = m;
            for(int j=0;t;++j){
                if(t&7) pw = (ll)pw*sw[i][j][t&7]%p;
                t >>= 3;
            }
            res += pw;
        }
        return (res<<1)%p*inv1k%p;
    }
    if(n<=100){
        int r = power(rt,(p-1)/n,p);
        int ir = power(r,p-2,p),w = 1,iw = 1;
        n >>= 1;
        for(int i=0;i<n;++i){
            res += power(w+iw,m,p);
            w = (ll)w*r%p,iw = (ll)iw*ir%p;
        }
        return (res<<1)%p*power(n<<1,p-2,p)%p;
    }
    int st = (m>>1)%(n>>1),len = n>>1;
    for(int k=st;(k<<1)<m;k+=len) res += (ll)ifac[k]*ifac[m-k]%p;
    return ((res<<1)+(ll)ifac[m>>1]*ifac[m>>1])%p*power(ifac[m],p-2,p)%p;
}

inline ll crt(int x,int y){
    return ((i128)x*pr[1]*power(pr[1],pr[0]-2,pr[0])+(i128)y*pr[0]*power(pr[0],pr[1]-2,pr[1]))%P;
}

struct query{
    int n,m0,m1,x,y;
    ll m13;
    inline query(int _n=0,int _m0=0,int _m1=0,int _x=0,int _y=0):n(_n),m0(_m0),m1(_m1),x(_x),y(_y){}
}a[2503];

char str[10003];
int m0,m1,x,y,l,cnt;
ll m13;

signed main(){
    int n;
    p = pr[0];
    init(0);
    while(scanf("%d%s",&n,str)==2){
        l = strlen(str);
        m1 = m0 = m13 = 0;
        for(int i=0;i<l;++i) m0 = (m0*10+str[i]-'0')%(pr[0]-1);
        if(n==13){
            for(int i=0;i<l;++i) 
                m13 = (m13*10+str[i]-'0')%((ll)pr[1]*pr[1]-1);
        }else{
            for(int i=0;i<l;++i)
                m1 = (m1*10+str[i]-'0')%(pr[1]-1);
        }
        x = solve(n,m0,g[0]);
        a[++cnt] = (query(n,m0,m1,x));
        a[cnt].m13 = m13;
    }
    p = pr[1];
    init(1);
    for(int i=1;i<=cnt;++i){
        x = a[i].x;
        y = solve(a[i].n,a[i].n==13?a[i].m13:a[i].m1,g[1]);
        printf("%lld\n",crt(x,y));
    }
    return 0;
} 

后记:我没看比赛奖评,但这个出题人咋跑那么快啊?更神必的是代码量达到了惊人的 29.49\text {KB},不知道是牛逼科技还是打表优化常数。