题解:P14523 【MX-S11-T4】Ice Drop

· · 题解

为了方便,下文中 p 一律表示质数,定义当 l > rf(a[l,r]) = 0

首先我们考虑如何判断一个序列是否是好的。

我们容易发现,其实我们不需要判断所有数字,只要检查所有满足 x = p^{k} 形式的正整数(k 为正整数),如果 x 的倍数在序列 a 中都是连续的一段,那么序列 a 是好的。

既然我们可以把质因子拆开,那么我们先考虑长度为 m 的序列 bb_{1}=p^{k_{1}},b_{m}=p^{k_{m}},b_{2}=b_{3}=\dots=b_{m-1}=1k 由正整数构成且 k_{1} \leq k_{m}) 的替换过程中,\operatorname{LCM}(b_{1},b_{2},\dots,b_{m}) = \operatorname{LCM}(b^{\prime}_{1},b^{\prime}_{2},\dots,b^{\prime}_{m}) 这一限制条件的意义。注意到完成替换后,b^{\prime}_{i} 一定可以被表示为 p^{k_{i}} 的形式。由于 \operatorname{LCM}(b_{1},b_{2},\dots,b_{m}) = p^{k_{m}},所以 \forall 1 \leq i < mk_{i} \leq k_{m},而又要满足上一段中提到的性质,因此 \forall 1 \leq i \leq m-1 都有 k_{i} \leq k_{i+1}

思考一下如何计算方案数。替换方案数其实就是计算可能的差分序列数量,现在我们就把这个问题转换成了经典问题:有 b_{m}-b_{1} 个无标号小球要放入 m-1 个有标号盒子,由插板法易知方案数为 \binom{b_{m}-b_{1}+m-2}{m-2}

取消掉对 b_{1},b_{m} 的限制,你会发现我们只需要在原来的基础上,对 \operatorname{lcm}(b_{1},b_{m}) 每一个质因子算一遍,最后将方案数相乘即可得到 f(b) 的值。

现在我们可以回到原问题了。\sum\limits_{x=l}^{r} \sum\limits_{y=x}^{r} f(a[x,y]) 这样的格式不难让人想到利用扫描线与历史和线段树(区间乘、单点赋值与历史和)维护。我们对 r 扫描,对于线段树上每个点(令其表示的区间范围为 [L,R])我们维护这样的一个矩阵(记录本层答案与历史和):

\begin{bmatrix} \sum\limits_{i=L}^{R} f(a[i,r]) & \sum\limits_{i=L}^{R}\sum\limits_{j=1}^{r} f(a[i,j]) \\ 0 & 0 \end{bmatrix}

并维护懒惰标记矩阵方便区间修改。

k 的操作可以表示为乘上矩阵 \begin{bmatrix} k & 0 \\ 0 & 1 \end{bmatrix},而 r 加一可以表示为乘上矩阵 \begin{bmatrix} 1 & 1 \\ 0 & 1 \end{bmatrix}。这样我们就完成了维护。

每扫到一个 r,我们都需要把无论如何赋值得到的序列都不好情况去掉,也就是一段前缀乘 0;记 r 前面离 r 最近的非 1 值所在的位置为 c,每一次我们需要重新计算前面的贡献,容易用区间乘做到(将上一次乘的方案数撤销,在乘上这一轮的方案数),当 a_{r} > 1 我们还需要暴力重新计算 c \sim r-1 中每个点这一轮的贡献。此外,我们还需要把 r 所在的位置赋值为 1

总的时间复杂度为 \mathcal{O(k^{3} n \log n)},其中 k 为矩阵的大小,即 2。常数略大。

代码实现起来有亿点点困难,我花了 3h+

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
template<class T>
void in(T &x){
    char c=getchar();T f=1; x=0;
    while(c<'0'||c>'9'){
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=x*10+c-'0';
        c=getchar();
    }
    x*=f;
}
const int N=500010,M=500000,Mod=1000000007;
int p[N],f[N],cnt,l,r; bool isp[N];
int n,q,a[N],pre[N],lstapp[N],cls[N];
ll ksm(ll a,ll b){
    ll res=1;
    while(b){
        if(b&1) res=res*a%Mod;
        a=a*a%Mod; b>>=1;
    }
    return res;
}
ll fac[N<<1],inv[N<<1],ifac[N<<1];
struct Queries{
    int l,r,h,id;
    bool operator<(const Queries &a){return h<a.h;}
}qs[N];
ll ans[N];
void getprime(){
    memset(isp,1,sizeof(isp));
    isp[0]=isp[1]=0;
    for(int i=2;i<=M;i++){
        if(isp[i]){
            p[++cnt]=i;
            f[i]=i;
        }
        for(int j=1;j<=cnt&&i*p[j]<=M;j++){
            isp[i*p[j]]=0;
            f[i*p[j]]=p[j];
            if(i%p[j]==0) break;
        }
    }
}
void getfac(){
    fac[0]=inv[0]=ifac[0]=fac[1]=inv[1]=ifac[1]=1;
    for(int i=2;i<=(M<<1);i++){
        fac[i]=fac[i-1]*i%Mod;
        inv[i]=(Mod-Mod/i)*inv[Mod%i]%Mod;
        ifac[i]=ifac[i-1]*inv[i]%Mod;
    }
}
ll C(int n,int m){
    if(n<0&&m==0) return 1;
    if(n<0||m<0||n<m) return 0;
    return fac[n]*ifac[m]%Mod*ifac[n-m]%Mod;
}
struct Disc{
    int val[30],cnt[30],len;
    Disc(){
        memset(val,0,sizeof(val));
        memset(cnt,0,sizeof(cnt));
        len=0;
    }
    ll calc(int x){
        ll res=1;
        for(int i=1;i<=len;i++)
            if(cnt[i]) (res*=C(cnt[i]+x-1,cnt[i]))%=Mod;
        return res;
    }
}dis[N];
Disc operator+(const Disc &a,const Disc &b){
    Disc res;
    int j=1;
    for(int i=1;i<=a.len;i++){
        while(j<=b.len&&b.val[j]<a.val[i]){
            res.val[++res.len]=b.val[j];
            res.cnt[res.len]=b.cnt[j];
            j++;
        }
        if(b.val[j]==a.val[i]){
            res.val[++res.len]=b.val[j];
            res.cnt[res.len]=b.cnt[j]+a.cnt[i]-min(a.cnt[i],b.cnt[j])*2;
            j++;
        }else{
            res.val[++res.len]=a.val[i];
            res.cnt[res.len]=a.cnt[i];
        }
    }
    while(j<=b.len){
        res.val[++res.len]=b.val[j];
        res.cnt[res.len]=b.cnt[j];
        j++;
    }
    return res;
}
struct Matrix{
    ll a[2][2];
    ll* operator[](bool x){return a[x];}
    const ll* operator[](bool x)const{return a[x];}
    void operator*=(const Matrix &b){
        Matrix res; res.zero();
        res[0][0]=(a[0][0]*b[0][0]+a[0][1]*b[1][0])%Mod;
        res[0][1]=(a[0][0]*b[0][1]+a[0][1]*b[1][1])%Mod;
        res[1][0]=(a[1][0]*b[0][0]+a[1][1]*b[1][0])%Mod;
        res[1][1]=(a[1][0]*b[0][1]+a[1][1]*b[1][1])%Mod;
        a[0][0]=res[0][0]; a[0][1]=res[0][1]; a[1][0]=res[1][0]; a[1][1]=res[1][1];
    }
    void zero(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=1;}
    void one(){a[0][0]=a[1][1]=1; a[0][1]=a[1][0]=0;}
};
Matrix operator+(const Matrix &a,const Matrix &b){
    return (Matrix){{{a[0][0]+b[0][0],a[0][1]+b[0][1]},{a[1][0]+b[1][0],a[1][1]+b[1][1]}}};
}
Matrix operator*(const Matrix &a,const Matrix &b){
    Matrix res; res.zero();
    res[0][0]=(a[0][0]*b[0][0]+a[0][1]*b[1][0])%Mod;
    res[0][1]=(a[0][0]*b[0][1]+a[0][1]*b[1][1])%Mod;
    res[1][0]=(a[1][0]*b[0][0]+a[1][1]*b[1][0])%Mod;
    res[1][1]=(a[1][0]*b[0][1]+a[1][1]*b[1][1])%Mod;
    return res;
}
struct SegTree{
    Matrix a[N<<2],lz[N<<2];
    #define ls (u<<1)
    #define rs (u<<1|1)
    void pushup(int u){a[u]=a[ls]+a[rs];}
    void pushdown(int u){
        a[ls]*=lz[u];
        a[rs]*=lz[u];
        lz[ls]*=lz[u];
        lz[rs]*=lz[u];
        lz[u].one();
    }
    void update(int l,int r,int tl,int tr,int u,const Matrix &v){
        if(tl<=l&&r<=tr){
            a[u]*=v;
            lz[u]*=v;
            return;
        }
        pushdown(u);
        int m=l+r>>1;
        if(tl<=m) update(l,m,tl,tr,ls,v);
        if(m+1<=tr) update(m+1,r,tl,tr,rs,v);
        pushup(u);
    }
    void updpoint(int l,int r,int t,int u,ll v){
        if(l==r){
            ll prev=a[u][0][1];
            a[u]={{{v,prev},{0,0}}};
            return;
        }
        pushdown(u);
        int m=l+r>>1;
        if(t<=m) updpoint(l,m,t,ls,v);
        else updpoint(m+1,r,t,rs,v);
        pushup(u);
    }
    void updmul(int tl,int tr,ll mul){if(tl<=tr)update(1,n,tl,tr,1,(Matrix){{{mul,0},{0,1}}});}
    void updval(int t,ll val){updpoint(1,n,t,1,val);}
    void updtime(){update(1,n,1,n,1,(Matrix){{{1,1},{0,1}}});}
    ll queryval(int l,int r,int tl,int tr,int u){
        if(tl<=l&&r<=tr) return a[u][0][1];
        pushdown(u);
        int m=l+r>>1; ll res=0;
        if(tl<=m) res=queryval(l,m,tl,tr,ls);
        if(m+1<=tr) (res+=queryval(m+1,r,tl,tr,rs))%=Mod;
        return res;
    }
    ll query(int tl,int tr){return tl<=tr?queryval(1,n,tl,tr,1):0;}
}seg;
int main(){
    getprime();
    getfac();
    in(n); in(q);
    int pv=0;
    for(int i=1;i<=n;i++){
        in(a[i]); int x=a[i],mul=1,lst=0;
        pre[i]=pv;
        while(x!=1){
            if(f[x]==lst){
                dis[i].cnt[dis[i].len]++;
                mul*=f[x];
            }else{
                dis[i].val[++dis[i].len]=f[x];
                dis[i].cnt[dis[i].len]++;
                mul=f[x]; lst=f[x];
            }
            if(lstapp[mul]&&lstapp[mul]!=pv){
                cls[i]=max(cls[i],lstapp[mul]);
            }
            lstapp[mul]=i;
            x/=f[x];
        }
        if(a[i]>1) pv=i;
    }
    for(int i=1;i<=q;i++){
        in(l); in(r);
        qs[i]=(Queries){l,r,r,i};
    }
    sort(qs+1,qs+q+1);
    int now=0;
    ll lstval=1;
    for(int i=1;i<=q;i++){
        while(now<qs[i].h){
            now++;
            if(a[now]>1){
                seg.updmul(1,cls[now],0);
                for(int j=pre[now]+1;j<=now;j++){
                    ll val=dis[now].calc(now-j+1);
                    seg.updval(j,val);
                }
                Disc tmp=dis[now]+dis[pre[now]];
                seg.updmul(1,pre[now],ksm(lstval,Mod-2));
                lstval=tmp.calc(now-pre[now]);
                seg.updmul(1,pre[now],lstval);
                lstval=1;
            }else{
                seg.updval(now,1);
                seg.updmul(1,pre[now],ksm(lstval,Mod-2));
                lstval=dis[pre[now]].calc(now-pre[now]+1);
                seg.updmul(1,pre[now],lstval);
            }
            seg.updtime();
        }
        ans[qs[i].id]=seg.query(qs[i].l,qs[i].r);
    }
    for(int i=1;i<=q;i++) printf("%lld\n",ans[i]);
}