题解 P5431 【【模板】乘法逆元2】

2019-05-31 17:32:13


简单套路。

如果直接做,则复杂度为 $O(n\log p)$ ,不可能通过。

因此我们考虑优化。

可以先回忆阶乘的逆元是怎么算的。

我们计算阶乘的逆元的时候,是先求一个前缀积,然后求出最后一个前缀积的逆元后,一个个乘回去就可以得到每个阶乘的逆元了。

这个可以也这样处理。我们求出前缀积后,直接一个个乘回去就可以了。注意我们处理出的是每个前缀积的逆元,因此每个逆元再乘上上一个前缀积就可以得到这一个数字的逆元了。

用代码来说就是:

$$prd[i]=prd[i-1]*a[i]$$

$$inv[n]=inv(prd[n])$$

$$inv[k]=inv[k+1]*a[k]$$

$$realinv[k]=inv[k]*prd[k-1]$$

复杂度 $O(n+\log p)$ 。

注意这道题实际上并不需要将逆元存起来。我们只需要将数组翻转,然后在倒着乘回去的时候顺便计算答案就可以了。你也可以将前缀积改成后缀积。

代码:

#include<cctype>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<iostream>
#define Rep(i,a,b) for(register int i=(a);i<=(b);++i)
#define Repe(i,a,b) for(register int i=(a);i>=(b);--i)
#define rep(i,a,b) for(register int i=(a);i<(b);++i)
#define pb push_back
#define mp make_pair
typedef unsigned long long uint64;
typedef unsigned int uint32;
typedef long long ll;
using namespace std;

namespace IO
{
    const uint32 Buffsize=1<<15,Output=1<<23;
    static char Ch[Buffsize],*S=Ch,*T=Ch;
    inline char getc()
    {
        return((S==T)&&(T=(S=Ch)+fread(Ch,1,Buffsize,stdin),S==T)?0:*S++);
    }
    static char Out[Output],*nowps=Out;

    inline void flush(){fwrite(Out,1,nowps-Out,stdout);nowps=Out;}

    template<typename T>inline void read(T&x)
    {
        x=0;static char ch;T f=1;
        for(ch=getc();!isdigit(ch);ch=getc())if(ch=='-')f=-1;
        for(;isdigit(ch);ch=getc())x=x*10+(ch^48);
        x*=f;
    }

    template<typename T>inline void write(T x,char ch='\n')
    {
        if(!x)*nowps++='0';
        if(x<0)*nowps++='-',x=-x;
        static uint32 sta[111],tp;
        for(tp=0;x;x/=10)sta[++tp]=x%10;
        for(;tp;*nowps++=sta[tp--]^48);
        *nowps++=ch;
    }
}
using IO::read;
using IO::write;
using IO::getc;
using IO::flush;

inline void Chkmin(int&u,int v){u>v?u=v:0;}

inline void Chkmax(int&u,int v){u<v?u=v:0;}

inline void file()
{
#ifndef ONLINE_JUDGE
    freopen("water.in","r",stdin);
    freopen("water.out","w",stdout);
#endif
}

const int MAXN=5e6+7;

static int n,mod,k,a[MAXN];

inline void init()
{
    read(n),read(mod),read(k);
    Rep(i,1,n)read(a[i]);
}

static int sfd[MAXN];

inline int power(int u,int v)
{
    register int sm=1;
    for(;v;v>>=1,u=(ll)u*u%mod)if(v&1)
        sm=(ll)sm*u%mod;
    return sm;
}

inline void solve()
{
    sfd[n+1]=1;
    Repe(i,n,1)sfd[i]=(ll)sfd[i+1]*a[i]%mod;
    register int iv=power(sfd[1],mod-2),pwk=1,ans=0;
    Rep(i,1,n)
    {
        pwk=(ll)pwk*k%mod;
        ans=(ans+(ll)pwk*iv%mod*sfd[i+1])%mod;
        iv=(ll)iv*a[i]%mod;
    }
    cout<<ans<<endl;
}

int main()
{
    file();
    init();
    solve();
    return 0;
}