题解:P7334 [JRKSJ R1] 吊打

· · 题解

考虑到 a^ka^{k \bmod 998244352}998244353 是同余的(费马小定理)。所以就可以把 a^{2^k} \bmod 998244353 转化为 a^{2^k \bmod 998244352} \bmod 998244353 就可以快速幂了。

这时候再进行线段树维护开根和平方的个数,注意到每次平方后开根可以抵消掉一次平方,于是每次在标记下传(pushdown)和制造标记时都考虑一下有没有抵消的情况,剩下的因为每个数最多开 \log \log V 次根就变成 1 了,所以开根的时间复杂度是有保障的。注意运算顺序是先开根再平方。然后就是线段树的区间加单点求值模板了。

时间复杂度 O((n+m) \log n+n \log \log V)

AC Code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N=200009;
const ll MOD=998244353;
struct Segment{
    ll l;
    ll r;
    ll toAddA;
    ll toAddB;
}tr[N*4];
ll a[N],n,m;
ll qpow(ll a,ll b,ll p){
    ll ans=1;
    while(b){
        if(b&1) (ans*=a)%=p;
        (a*=a)%=p;
        b>>=1;
    }
    return ans;
}
void build(ll u,ll l,ll r){
    tr[u]=(Segment){l,r,0,0};
    if(l==r) return;
    ll mid=(l+r)/2;
    build(u*2,l,mid);
    build(u*2+1,mid+1,r);
}
void pushdown(ll u){
    ll A=tr[u].toAddA,B=tr[u].toAddB;
    tr[u].toAddA=0;tr[u].toAddB=0;
    if(!A&&!B) return;
    if(tr[u*2].toAddB){
        if(tr[u*2].toAddB>A) tr[u*2].toAddB-=A;
        else tr[u*2].toAddA+=A-tr[u*2].toAddB,tr[u*2].toAddB=0;
    }
    else tr[u*2].toAddA+=A;
    if(tr[u*2+1].toAddB){
        if(tr[u*2+1].toAddB>A) tr[u*2+1].toAddB-=A;
        else tr[u*2+1].toAddA+=A-tr[u*2+1].toAddB,tr[u*2+1].toAddB=0;
    }
    else tr[u*2+1].toAddA+=A;
    tr[u*2].toAddB+=B;
    tr[u*2+1].toAddB+=B;
}
void add(ll u,ll l,ll r,ll op){
    if(tr[u].r<l||r<tr[u].l) return;
    if(l<=tr[u].l&&tr[u].r<=r){
        if(op==1){
            if(tr[u].toAddB) tr[u].toAddB--;
            else tr[u].toAddA++;
        }
        else tr[u].toAddB++;
        return;
    }
    pushdown(u);
    add(u*2,l,r,op);
    add(u*2+1,l,r,op);
}
ll query(ll u,ll x,ll op){
    if(tr[u].l==tr[u].r){
        if(op==1) return tr[u].toAddA;
        else return tr[u].toAddB;
    }
    pushdown(u);
    if(x<=tr[u*2].r) return query(u*2,x,op);
    else return query(u*2+1,x,op);
}
int main(){
    cin>>n>>m;
    for(ll i=1;i<=n;i++) cin>>a[i];
    build(1,1,n); 
    for(ll i=1;i<=m;i++){
        ll op,l,r;
        cin>>op>>l>>r;
        add(1,l,r,op);
    }
    ll ans=0;
    for(ll i=1;i<=n;i++){
        ll x=query(1,i,1);
        for(ll j=1;j<=x&&a[i]!=1;j++) a[i]=sqrt(a[i]);
        x=query(1,i,2); 
        a[i]=qpow(a[i],qpow(2,x,MOD-1),MOD);
        (ans+=a[i])%=MOD;
    }
    cout<<ans<<endl;
    return 0;
}