P3328 [SDOI2015]音质检测 题解

· · 题解

从某位不愿透露姓名的神仙处获得了一个做法。

一种线段树配合矩阵乘法的做法。

首先看到这个线性递推式子,显然可以直接矩阵快速幂没有什么问题,就把向量设为形如 F_i,F_{i-1},1 就可以了,我们叫这个向量是 V(i),然后我们设使得 i 加一的转移矩阵为 pl,减一的为 mi

然后看,最终我们计算贡献的式子是:F_{A_{i-1}+1}F_{A_{i+1}-1}

我们可以发现,如果按照我们写代码时,把向量放进矩阵里面,假设 V 是列向量,那么转置就是个行向量。用这两个对应的矩阵,列向量左乘行向量得到的结果就恰好是两个向量第一位之积。

那么这个乘积就可以描述成 (V(A_{i-1}+1)\times V(A_{i+1}-1)^T)[0][0] 的,这个东西有什么用呢?我们设 W(i)=V(A_{i-1}+1)\times V(A_{i+1}-1)^T

我们考虑对一个点 i,我们都维护一下 W(i)。这样有什么用呢?假设我想要让 A_{i-1}A_{i+1} 都加上 1,那么新的贡献一定可以通过 ((pl\times V(A_{i-1}+1))\times (pl\times V(A_{i+1}-1))^T)[0][0]=(pl\times W(i)\times pl^T)[0][0] 表示出来。

那此时使用线段树的思路就明了了,因为矩阵乘法对加法具有分配律,所以我们线段树维护区间 W 之和,然后修改的时候直接整体在前面或者后面乘上转移矩阵,就可以得出答案了!形式化地写,即:(以左右均乘加一矩阵为例)

pl\times (\sum_{i=l}^r W(i))\times pl^T=\sum_{i=l}^r pl\times W(i)\times pl^T

于是本题解决了,时间复杂度 O((n+Q)\log n)

另外不知道为什么这个题直接写普通线段树会 TLE,但如果写动态开点就能过,很迷惑。

代码

#include <bits/stdc++.h>
#define lint long long
using namespace std;
inline int read(){
    char c;int f=1,res=0;
    while(c=getchar(),!isdigit(c))if(c=='-')f*=-1;
    while(isdigit(c))res=res*10+c-'0',c=getchar();
    return res*f;
}
const int N=3e5+5;
const lint md=1e9+7;
#define add(x,y) x=(x+y)%md
struct mat{
    lint v[3][3];
    mat(lint x=0){
        memset(v,0,sizeof v);
        for(int i=0;i<3;++i)v[i][i]=x;
    }inline lint* operator[](int x)
        {return v[x];}
    inline mat operator*(mat x){
        mat r;
        for(int i=0;i<3;++i)
            for(int k=0;k<3;++k)
                for(int j=0;j<3;++j)
                    add(r[i][j],v[i][k]*x[k][j]);
        return r;
    }inline mat operator+(mat x){
        for(int i=0;i<3;++i)
            for(int j=0;j<3;++j)
                add(x[i][j],v[i][j]);
        return x;
    }inline mat T(){
        mat r;
        for(int i=0;i<3;++i)
            for(int j=0;j<3;++j)
                r[i][j]=v[j][i];
        return r;
    }inline mat qpow(lint x){
        mat r(1),t=*this;
        while(x){
            if(x&1)r=r*t;
            t=t*t;x>>=1;
        }return r;
    }
}pl,mi,tpl,tmi,f,tf;
int n,nq;lint a,b,v[N];
inline lint qpow(lint x,lint y=md-2){
    lint res=1;
    while(y){
        if(y&1)res=res*x%md;
        x=x*x%md;y>>=1;
    }return res;
}
inline void init(){
    pl[0][0]=pl[1][0]=pl[2][2]=1;
    pl[0][1]=a;pl[0][2]=b;
    f[0][0]=2;f[1][0]=f[2][0]=1;
    tpl=pl.T();tf=f.T();
    if(a){
        mi[0][1]=mi[2][2]=1;
        mi[1][0]=qpow(a);
        mi[1][1]=md-qpow(a);
        mi[1][2]=md-b*qpow(a)%md;
    }else{
        mi[0][1]=mi[1][1]=mi[2][2]=1;
        mi[1][2]=md-b;
    }tmi=mi.T();
}
mat s[N<<1],tl[N<<1],tr[N<<1];
int ch[N<<1][2],tot,rt;
inline void nupd(int x,mat lu,mat ru){
    tl[x]=lu*tl[x];tr[x]=tr[x]*ru;
    s[x]=lu*s[x]*ru;
}inline void pushdown(int x){
    nupd(ch[x][0],tl[x],tr[x]);
    nupd(ch[x][1],tl[x],tr[x]);
    tl[x]=tr[x]=mat(1);
}void build(int &x=rt,int l=2,int r=n-1){
    x=++tot;tl[x]=tr[x]=mat(1);
    if(l==r){
        s[x]=pl.qpow(v[l-1]-1)*f;
        s[x]=s[x]*tf*tpl.qpow(v[l+1]-3);
        return;
    }int mid=l+r>>1;
    build(ch[x][0],l,mid);
    build(ch[x][1],mid+1,r);
    s[x]=s[ch[x][0]]+s[ch[x][1]];
}
void upd(int lp,int rp,mat lu,mat ru,int x=rt,int l=2,int r=n-1){
    if(l>rp||r<lp)return;
    if(lp<=l&&r<=rp)return nupd(x,lu,ru);
    int mid=l+r>>1;pushdown(x);
    upd(lp,rp,lu,ru,ch[x][0],l,mid);
    upd(lp,rp,lu,ru,ch[x][1],mid+1,r);
    s[x]=s[ch[x][0]]+s[ch[x][1]];
}
mat query(int lp,int rp,int x=rt,int l=2,int r=n-1){
    if(l>rp||r<lp)return mat();
    if(lp<=l&&r<=rp)return s[x];
    int mid=l+r>>1;pushdown(x);
    mat res=query(lp,rp,ch[x][0],l,mid);
    res=res+query(lp,rp,ch[x][1],mid+1,r);
    return res;
}
int main(){
    n=read();nq=read();
    a=read();b=read();
    for(int i=1;i<=n;++i)v[i]=read();
    init();build();
    while(nq--){
        string op;cin>>op;
        int l=read(),r=read();
        if(op=="plus"){
            upd(l+1,r+1,pl,mat(1));
            upd(l-1,r-1,mat(1),tpl);
        }else if(op=="minus"){
            upd(l+1,r+1,mi,mat(1));
            upd(l-1,r-1,mat(1),tmi);
        }else if(op=="query")
            printf("%lld\n",query(l+1,r-1)[0][0]);
    }return 0;
}