P5138 题解

· · 题解

首先我们有一个重要的式子(其中 $f$ 为斐波那契数列:$f_0=0,\ f_1=f_2=1,\ f_n=f_{n-1}+f_{n-2}$):

$$f_{n+m}=f_{n}f_{m+1}+f_{n-1}f_{m}$$

证明:

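对 $m$ 归纳(对所有 $n$ 同时进行)。当 $m=0$ 时,$f_nf_1+f_{n-1}f_0=f_n$,结论成立。假设 $m=k\ (k\in\mathbb{N})$ 时结论对所有 $n$ 成立,则 $m=k+1$ 时,有: $$ \begin{aligned} f_{n}f_{k+2}+f_{n-1}f_{k+1} &= f_nf_{k+1}+f_{n}f_{k}+f_{n-1}f_{k+1} \\ &= (f_{n-1}f_{k}+f_{n}f_{k+1})+(f_{n-2}f_{k}+f_{n-1}f_{k+1}) \\ &= f_{n+k}+f_{n+k-1} \\ &= f_{n+k+1} \end{aligned} $$ 其中第一步用了 $f_{k+2}=f_{k+1}+f_k$;第二步把 $f_nf_k$ 拆成 $f_{n-1}f_k+f_{n-2}f_k$;第三步对两个括号分别使用 $(n,k)$ 与 $(n-1,k)$ 处的归纳假设。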

假设我们要计算 $u$ 的子树内的一个节点 $v$ 的增加量。记 $d=dep_v-dep_u$ 为 $v$ 到 $u$ 的距离,在上式中取 $n=dep_v,\ m=k-dep_u$,则有:

$$f_{d+k}=f_{dep_v+(k-dep_u)}=f_{dep_v}f_{k-dep_u+1}+f_{dep_v-1}f_{k-dep_u}$$

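注意到 $k-dep_u$ 可能小于 $0$。将数列延拓到负下标:$f_{-n}=(-1)^{n-1}f_n$(即 $n$ 为奇数时取正,偶数时取负)。延拓后递推式 $f_n=f_{n-1}+f_{n-2}$ 对所有整数 $n$ 仍然成立。固定 $n$ 后,上面等式的两边作为 $m$ 的函数都满足 $g(m)=g(m-1)+g(m-2)$,且在 $m=0,1$ 处相等;向两个方向递推即知它对所有整数 $m$(包括负数)都成立。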

对固定的节点 $v$,$f_{dep_v}$ 与 $f_{dep_v-1}$ 均为常数。因此线段树上每个区间维护以下信息:两个常数和 $\sum f_{dep_v}$ 与 $\sum f_{dep_v-1}$;两个懒标记,分别是所有作用于该区间的修改的系数 $f_{k-dep_u+1}$ 之和与 $f_{k-dep_u}$ 之和;以及区间内节点的权值之和。由于 $k\le 10^{18}$,需要用矩阵快速幂计算 $f_i$。

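树链剖分后用线段树维护即可:子树修改对应一段连续的 dfs 序区间,复杂度为 $O(\log n+\log k)$;路径询问需跳 $O(\log n)$ 条重链,复杂度为 $O(\log^2 n)$。建树时每个叶子需计算两次 $f$,总复杂度为 $O(n\log n+m(\log^2 n+\log k))$。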

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

// Competitive-programming trick: every later `int` (struct fields, locals,
// parameters and return types alike) is textually replaced by `long long`,
// so the product of two values below P never overflows. This is also why the
// entry point has to be spelled `signed main()`.
#define int long long

typedef long long ll;
// Pair of 64-bit values; the segment tree stores its two Fibonacci
// coefficient components (val and tag) in this type.
typedef pair<ll, ll> pii;

// N: capacity of the per-node arrays (the segment tree uses N<<2 slots).
// P: the modulus every stored value is reduced by.
const int N = 2e5+10, P = 1e9+7;

// n: number of tree nodes, m: number of operations (both read in main).
int n, m;
// Adjacency lists of the tree; main inserts each edge in both directions.
vector<int> e[N];

// A 2x2 matrix stored 1-indexed inside a 4x4 array: only num[1..2][1..2]
// carries data. All entries start at zero.
struct Matrix {
    int num[4][4];
    Matrix() : num{} {}   // value-initialisation zeroes the whole array
};

// Product of two 1-indexed 2x2 matrices, each entry reduced modulo P.
// Entries are expected in [0, P), so every single product fits in 64 bits.
Matrix operator *(const Matrix &a, const Matrix &b) {
    Matrix prod;
    for (int row = 1; row <= 2; ++row) {
        for (int col = 1; col <= 2; ++col) {
            ll acc = 0;
            for (int mid = 1; mid <= 2; ++mid)
                acc = (acc + static_cast<ll>(a.num[row][mid]) * b.num[mid][col]) % P;
            prod.num[row][col] = acc;
        }
    }
    return prod;
}

// Binary exponentiation: returns a^b modulo P (the identity matrix when b <= 0).
// `a` is taken by value because the loop squares it in place; taking it by
// reference would leave the caller's matrix overwritten with a^(2^t).
Matrix power(Matrix a, int b) {
    Matrix res;
    for (int i = 1; i <= 2; ++i) res.num[i][i] = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) res = res * a;
        a = a * a;
    }
    return res;
}

// Returns F(k) mod P, normalised into [0, P), for any 64-bit k.
// F(0) = 0, F(1) = F(2) = 1. Negative indices use F(-n) = (-1)^(n+1) F(n),
// which keeps F(n) = F(n-1) + F(n-2) valid for every integer n.
int calc(ll k) {
    int sign = 1;
    if (k < 0) {
        k = -k;
        if (k % 2 == 0) sign = -1;   // F(-n) is negative exactly for even n
    }
    if (k == 0) return 0;
    // Q = [[1,1],[1,0]] gives Q^n = [[F(n+1),F(n)],[F(n),F(n-1)]],
    // so the top-left entry of Q^(k-1) is F(k).
    Matrix q;
    q.num[1][1] = q.num[1][2] = q.num[2][1] = 1;
    const int fk = power(q, k - 1).num[1][1];
    return (sign * fk % P + P) % P;
}

// Heavy-light decomposition state.
// idx    : running counter used by get_top to hand out dfs positions.
// id[u]  : dfs position of tree node u (1-based).
// dfn[i] : tree node at dfs position i (the inverse of id).
int idx, id[N], dfn[N];
// dep[u]: depth (get_son(1, 1) gives the root depth 1), siz[u]: subtree size,
// top[u]: head of u's heavy chain, son[u]: heavy child (0 for a leaf),
// f[u]: parent of u (get_son(1, 1) makes the root its own parent).
int dep[N], siz[N], top[N], son[N], f[N];

void get_son(int u, int fa) {
    siz[u] = 1, dep[u] = dep[fa]+1, f[u] = fa;
    for (auto v : e[u]) {
        if (v == fa) continue;
        get_son(v, u), siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;
    }
}

void get_top(int u, int t) {
    top[u] = t, id[u] = ++ idx, dfn[idx] = u;
    if (son[u]) get_top(son[u], t);
    for (auto v : e[u]) {
        if (v == f[u] || v == son[u]) continue;
        get_top(v, v);
    }
}

// One segment-tree node covering a range of HLD dfs positions.
struct Node {
    int l, r;   // covered dfs positions [l, r]
    int sum;    // sum of the current node weights in the range, mod P
    pii val;    // (sum of F(dep[x]), sum of F(dep[x] - 1)) over nodes x in the range
    pii tag;    // coefficients already applied to `sum` but not yet to the children
};
// Segment-tree storage: root at index 1, children of u at 2u and 2u + 1.
Node seg[N<<2];

void pushup(int u) {
    seg[u].val.first = (seg[u<<1].val.first + seg[u<<1|1].val.first) % P;
    seg[u].val.second = (seg[u<<1].val.second + seg[u<<1|1].val.second) % P;
    seg[u].sum = (seg[u<<1].sum + seg[u<<1|1].sum) % P;
}

// Applies the tag carried by `u` to node `now`: now.sum grows by
// val.first * tag.first + val.second * tag.second (val being now's own
// coefficient sums), and u's tag is accumulated into now's pending tag.
void pushdown(Node &u, Node &now) {
    const ll a = u.tag.first;
    const ll b = u.tag.second;
    const ll gain = static_cast<ll>(now.val.first) * a % P
                  + static_cast<ll>(now.val.second) * b % P;
    now.sum = (now.sum + gain) % P;
    now.tag.first = (static_cast<ll>(now.tag.first) + a) % P;
    now.tag.second = (static_cast<ll>(now.tag.second) + b) % P;
}

void pushdown(int u) {
    pushdown(seg[u], seg[u<<1]), pushdown(seg[u], seg[u<<1|1]);
    seg[u].tag = make_pair(0, 0);
}

void build(int u, int l, int r) {
    seg[u].l = l, seg[u].r = r, seg[u].tag = make_pair(0, 0);
    if (l == r) seg[u].val.first = calc(dep[dfn[l]]), seg[u].val.second = calc(dep[dfn[l]]-1);
    else {
        int mid = l + r >> 1;
        build(u<<1, l, mid), build(u<<1|1, mid+1, r);
        pushup(u);
    }
}

int query(int u, int l, int r) {
    if (seg[u].l >= l && seg[u].r <= r) return seg[u].sum;
    pushdown(u);
    int mid = seg[u].l + seg[u].r >> 1, ans = 0;
    if (l <= mid) ans = ((ll)ans + query(u<<1, l, r)) % P;
    if (r > mid) ans = ((ll)ans + query(u<<1|1, l, r)) % P;
    return ans;
}

void modify(int u, int l, int r, ll val1, ll val2) {
    if (seg[u].l >= l && seg[u].r <= r) {
        Node tmp; tmp.tag = make_pair(val1, val2);
        pushdown(tmp, seg[u]);
        return ;
    }
    pushdown(u);
    int mid = seg[u].l + seg[u].r >> 1;
    if (l <= mid) modify(u<<1, l, r, val1, val2);
    if (r > mid) modify(u<<1|1, l, r, val1, val2);
    pushup(u); 
    return ;
}

int ask_path(int u, int v) {
    int sum = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        sum = ((ll)sum+query(1, id[top[u]], id[u])) % P;
        u = f[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    sum = ((ll)sum + query(1, id[u], id[v])) % P;
    return sum;
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0), cout.tie(0);

    cin >> n >> m;
    for (int i = 1; i < n; ++i) {
        int u, v; cin >> u >> v;
        e[u].push_back(v), e[v].push_back(u);
    }
    get_son(1, 1), get_top(1, 1);
    build(1, 1, n);

    char op; int x; ll y;
    for (int i = 1; i <= m; ++i) {
        cin >> op >> x >> y;
        if (op == 'U') modify(1, id[x], id[x]+siz[x]-1, calc(y-dep[x]+1), calc(y-dep[x]));
        else cout << ask_path(x, y) << '\n';
    }
    return 0;
}