题解:P5138 fibonacci

· · 题解

提供一种不使用 Fib_{n+m} 公式的做法

先写出 Fib 矩阵 M=\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}

每个点维护两个矩阵 A,B,其中 A 固定,B 会因为修改操作改变。

记一个节点的深度为 dep,根的深度为 1,则令这个点的 A=M^{dep}

对于子树 u 的加操作,转化为子树 u 的每个点的 B 加上 M^{k-dep_u+1}

不难发现,一个点真正的 Fib 矩阵就是 A\times B,Fib 矩阵右下角即为这个点的点权(即滑稽果个数)。

使用重链剖分和 DFS 序,将树上问题转为序列问题。

序列问题是:每个下标有两个矩阵 A,B,需要支持 B 的区间加、区间查询 A\times B 的和。

用线段树维护,线段树每个节点维护 Sa=\sum ASab=\sum A\times B,再维护一个懒标记 K

对线段树上一个节点进行加 B,有 Sab\leftarrow Sab+Sa\times BK\leftarrow K+B

下传标记时,设线段树上当前点的父节点为 fa,有 K\leftarrow K+K_{fa}Sab\leftarrow Sab+Sa\times K_{fa}

时间复杂度 O(n\log^2 n),八倍常数,完全不卡常。

注意 k 要开 long long

#include<bits/stdc++.h>
using namespace std;
constexpr int mod=1e9+7;
inline void upd(int &x,int y){
    x=(x+y)%mod;
}
struct Mat{
    int a00,a01,a10,a11;
    Mat operator +(Mat b){
        return {(a00+b.a00)%mod,(a01+b.a01)%mod,
                (a10+b.a10)%mod,(a11+b.a11)%mod};
    }
    Mat operator *(Mat b){
        Mat c={0,0,0,0};
        c.a00  =  1ll*a00*b.a00%mod;
        upd(c.a00,1ll*a01*b.a10%mod);
        c.a01  =  1ll*a00*b.a01%mod;
        upd(c.a01,1ll*a01*b.a11%mod);
        c.a10  =  1ll*a10*b.a00%mod;
        upd(c.a10,1ll*a11*b.a10%mod);
        c.a11  =  1ll*a10*b.a01%mod;
        upd(c.a11,1ll*a11*b.a11%mod);
        return c;
    }
    bool ck(){
        return a00||a11||a10||a01;
    }
    void set0(){
        a00=a01=a10=a11=0;
    }
    void set1(){
        a00=a11=1,a10=a01=0;
    }
};
int n,m;
vector<int> V[100005];
Mat A[100005],_A[100005];
int f[100005],sz[100005],dep[100005];
int son[100005],top[100005],dfn[100005],out[100005];
int tot;
Mat M={1,1,1,0};
Mat invM[100005];
#define ll long long
Mat ksm(Mat a,ll b){
    Mat res;res.set1();
    while(b){
        if(b&1) res=res*a;
        a=a*a;
        b/=2;
    }
    return res;
}
struct SGT{
    struct node{
        Mat Sa,Sab,K;
    }t[400005];
    void pushup(int p){
        t[p].Sab=t[p*2].Sab+t[p*2+1].Sab;
    }
    void mul(int p,Mat X){
        t[p].K=t[p].K+X;
        t[p].Sab=t[p].Sab+(t[p].Sa*X);
    }
    void pushdown(int p){
        if(t[p].K.ck()){
            mul(p*2,t[p].K);
            mul(p*2+1,t[p].K);
            t[p].K.set0();
        }
    }
    void build(int p,int l,int r,Mat A[]){
        if(l==r){
            t[p].Sa=A[l];
            t[p].Sab=M*invM[dep[l]]; 
            return;
        }
        int mid=(l+r)/2;
        build(p*2,l,mid,A);
        build(p*2+1,mid+1,r,A);
        t[p].Sa=t[p*2].Sa+t[p*2+1].Sa;
        pushup(p);
    }
    void modify(int p,int l,int r,int L,int R,Mat X){
        if(L<=l&&r<=R){
            mul(p,X);
            return;
        }
        pushdown(p);
        int mid=(l+r)/2;
        if(mid>=L) modify(p*2,l,mid,L,R,X);
        if(mid<R) modify(p*2+1,mid+1,r,L,R,X);
        pushup(p);
    }
    void query(int p,int l,int r,int L,int R,int &res){
        if(L<=l&&r<=R){
            upd(res,t[p].Sab.a11);
            return;
        }
        pushdown(p);
        int mid=(l+r)/2;
        if(mid>=L) query(p*2,l,mid,L,R,res);
        if(mid<R) query(p*2+1,mid+1,r,L,R,res);
    }
    void modify(int l,int r,ll k){
        if(k<=0){
            Mat T=invM[-k];
            modify(1,1,n,l,r,T);
        }else{
            Mat T=ksm(M,k);
//          cout<<k<<"\n";
//          cout<<T.a00<<" "<<T.a01<<"\n"<<T.a10<<" "<<T.a11<<"\n"; 
            modify(1,1,n,l,r,T);
        }

    }
    int query(int l,int r){
//      cout<<l<<" "<<r<<"\n"; 
        int res=0;
        query(1,1,n,l,r,res);
//      cout<<res<<"\n";
        return res;
    }
}sgt;
void dfs1(int u,int fa){
    if(fa) A[u]=A[fa]*M;
    else A[u]=M;
    dep[u]=dep[fa]+1;
    sz[u]=1;
    f[u]=fa;
    for(int v:V[u]){
        if(v==fa) continue;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(!son[u]||sz[v]>sz[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int tp){
    dfn[u]=++tot;
    _A[tot]=A[u];
    top[u]=tp;
    if(son[u]){
        dfs2(son[u],tp);
    }
    for(int v:V[u]){
        if(v!=f[u]&&v!=son[u]) dfs2(v,v);
    }
    out[u]=tot;
}
int Qry(int u,int v){
    int res=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        upd(res,sgt.query(dfn[top[u]],dfn[u]));
        u=f[top[u]];
    }
    upd(res,sgt.query(min(dfn[u],dfn[v]),
                      max(dfn[u],dfn[v])));
    return res;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        V[u].push_back(v);
        V[v].push_back(u);
    }
    dfs1(1,0);
    dfs2(1,1);
    sgt.build(1,1,n,_A);
    invM[0]={1,0,0,1};
    for(int i=1;i<=n;i++){
        invM[i]={invM[i-1].a01,invM[i-1].a11,
                 invM[i-1].a11,(invM[i-1].a01-invM[i-1].a11+mod)%mod};
    }
    for(int i=1;i<=m;i++){
        char op;
        cin>>op;
        if(op=='U'){
            int x;
            ll k;
            cin>>x>>k;
            sgt.modify(dfn[x],out[x],k-dep[x]+1);
        }else{
            int u,v;
            cin>>u>>v; 
            cout<<Qry(u,v)<<"\n";
        }
    }
}