『GROI-R1』 古朴而优雅

· · 题解

赛时所有结论都推出来了,代码没调出来(

结论 1

考虑一棵树 (V,E),它的遍历方案数为:

f(V,E)=\prod_u \text{deg}_u!

其中 \text{deg}_u 表示结点 u 的子节点个数。

证明:

对于每个点的所有子节点 S_u=\{v\mid u\to v\},其中 |S_u|=\text{deg}_u,存在 \text{deg}_u! 种排列方式。

考虑每一次遍历都优先选择当前结点 u 在集合 S_u 中 第一个子结点。则每种排列都对应唯一一种遍历顺序。根据乘法原理,结论得证。

结论 2

对于任意图,遍历过程中,所有经过的边构成了一棵树。

证明:

证明其是一棵树,可分解为:

  1. 边集 E' 是连通的

  2. 点集 V 的大小为边集大小加一。

对于 1,结论显然成立。

对于 2,考虑到初始时,点集大小为 1V'=\{1\}),边集为空集。每次多遍历一个点,点集与边集同时增加一个元素,故差值不变为 1,证毕。

结论 3

对于一棵基环树,有且仅有一条边没有被遍历到。形式化地,设基环树为 (V,E),并且遍历中经过的所有边为 E',则有 |E\backslash E'|=1

证明:

由结论 2,有 |E'|=n-1。又基环树有 n 条边,得证。

结论 4

没有被遍历到的这条边一定在环上。

证明:

考虑反证。若这条边不在环上,由于基环树除环边外,剩余的每条边都是割边,故若这条边没有遍历到,E' 不连通,矛盾,得证。

结论 5

若在原树上增加边 (x,y),则环为 x\operatorname{lca}(x,y) 路径与 y\operatorname{lca}(x,y) 路径的交加上边 (x,y)

证明:

显然的,若基环树上一条边被去掉,剩余部分仍然连通,则这条边在环上。

考虑 x\operatorname{lca}(x,y) 这条路径上任意一边。若割掉 (u,v),其中 uv 的父亲,则 u 仍然可以通过 u\to \operatorname{lca}(x,y)\to y\to x\to v,原图仍然连通。y\operatorname{lca}(x,y) 同理。

结论 6

若在原树上增加边 (x,y),则没有被遍历到的边属于 \operatorname{lca}(x,y) 在环上的两边。

证明:

由于我们假定 1 为根,故遍历过程中一定会进入环。由于环上深度最低的点是 \operatorname{lca}(x,y),所以进入环后,有两条路径选择:顺时针走与逆时针走。由深搜“走到死路才回头”的性质,结论得证。

结论 7

若在原树上增加边 (x,y),则答案为将没有遍历到的边各个删去,得到的树的答案之和。形式化地,

\text{ans}=\sum f(V,E')

其中

f(V,E)=\prod_u \text{deg}_u!

证明:

没有遍历到的边对答案没有贡献,删去后答案不变。证毕。

结论 8

在原树上删去一条边再添加一条边使得仍然连通,则最多只有 4 个结点的 \text{deg} 发生了变化,且这四个结点为删去边与添加边的端点。

具体地,若删去一条边 (u,v),则 \text{deg}_u\gets \text{deg}_u-1\text{deg}_v\gets \text{deg}_v-1;若添加边 (x,y),则 \text{deg}_x\gets \text{deg}_x+1\text{deg}_y\gets \text{deg}_y+1

证明:

删去一条边 (u,v)uv 的父亲)后,显然的 \text{deg}_u\gets \text{deg}_u-1(少了一条通向 v 的边)。而对于 v,手玩可以发现,其一个子结点成为了它的父亲结点,所以 \text{deg}_v\gets \text{deg}_v-1

加上一条边同理。

总结

利用结论 8,考虑到我们只关心 \text{deg} 的变化,预处理出阶乘及其逆元,可以在 \Theta(q\log n) 的复杂度内完成问题。

#include <bits/stdc++.h>
#define int long long

using namespace std;

const int maxn = 5e5 + 1;
const int mod = 1e9 + 7;
const int maxd = 20;

typedef pair<int, int> pii;

struct edge {
    int to, next;
};

int n, q;
int head[maxn];
edge e[maxn << 1];
int cnt;

void add_edge(int u, int v) {
    e[++cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt;
}

int fac[maxn];
int inv[maxn];

int power(int base, int freq, int mod) {
    int ans = 1, tmp = base;
    while (freq) {
        if (freq & 1) ans = ans * tmp % mod;
        freq >>= 1;
        tmp = tmp * tmp % mod;
    }
    return ans;
}

void init() {
    fac[0] = 1;
    for (int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % mod;
    inv[maxn - 1] = power(fac[maxn - 1], mod - 2, mod);
    for (int i = maxn - 2; ~i; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
}

int fa[maxn][maxd];
int dep[maxn];

void get_dep(int u, int depth, int fat) {
    dep[u] = depth;
    for (int i = head[u]; i; i = e[i].next) {
        if (e[i].to != fat) {
            get_dep(e[i].to, depth + 1, u);
        }
    }
} 

void get_fa(int cur, int fat) {
    fa[cur][0] = fat;
    for (int i = 1; i <= log2(dep[cur]) + 1; i++) {
        fa[cur][i] = fa[fa[cur][i - 1]][i - 1];
    }
    for (int i = head[cur]; i; i = e[i].next) {
        if (e[i].to != fat) {
            get_fa(e[i].to, cur);
        }
    }
}

int lca(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    while (dep[u] > dep[v]) {
        u = fa[u][(int) log2(dep[u] - dep[v])];
    }
    if (u == v) return u;
    for (int i = log2(dep[u]); i >= 0; i--) {
        if (fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}

int dist(int u, int v) {
    return dep[u] + dep[v] - 2 * dep[lca(u, v)];
}

int ful = 1;
int deg[maxn];

void dfs(int u, int fa, int &mul) {
    int tot = 0;
    for (int i = head[u]; i; i = e[i].next) {
        int v = e[i].to;
        if (v == fa) continue;
        tot++;
        dfs(v, u, mul);
    }
    mul = mul * fac[tot] % mod;
    deg[u] = tot;
}

int getans(int u, int v, int x, int y) {
    int cur = ful * inv[deg[u]] % mod * inv[deg[v]] % mod * inv[deg[x]] % mod * inv[deg[y]] % mod;
    deg[u]--, deg[v]--, deg[x]++, deg[y]++;
    cur = cur * fac[deg[u]] % mod * fac[deg[v]] % mod * fac[deg[x]] % mod * fac[deg[y]] % mod;
    deg[u]++, deg[v]++, deg[x]--, deg[y]--;
    return cur;
}

int firstson(int u, int v) {
    int d = dep[v] - dep[u] - 1;
    for (int i = log2(d + 1); ~i; i--) {
        if (d >> i & 1) v = fa[v][i];
    }
    return v;
}

pii getson(int rt, int u, int v) {
    if (dep[u] > dep[v]) swap(u, v);
    if (u == rt) {
        return make_pair(v, firstson(u, v));
    }
    return make_pair(firstson(rt, u), firstson(rt, v));
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    init();
    cin >> n >> q;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        add_edge(u, v);
        add_edge(v, u);
    }
    get_dep(1, 0, 0);
    get_fa(1, 0);
    dfs(1, 0, ful);
    while (q--) {
        int x, y;
        cin >> x >> y;
        int z = lca(x, y);
        if (dist(x, y) == 1) {
            cout << ful << endl;
            continue;
        }
        pii ts = getson(z, x, y);
        int xt = ts.first, yt = ts.second;
        int ans1 = getans(z, xt, x, y), ans2 = getans(z, yt, x, y);
        cout << (ans1 + ans2) % mod << endl;
    }
    return 0;
}