题解:P16905 [CCO 2026] Tree Traversals

· · 题解

给定一棵 n 个点的树。

定义 f(K) 表示 n 的排列数量,满足:

\mathcal{O}(n^2q) 的暴力 dp

由第一条可以得到,满足条件的排列一定是原树从上到下一层一层组成的,可以通过 dep 的值分成最大深度 mxd 个块。设值为 d 的块内部点的集合为 bel_d

块与块之间的限制可以通过 dp 考虑,现在看块内的情况。结论:块内任意两点 u, v 必有距离 dis(u, v) \le K,否则无解。

::::info[无解结论的证明] 假设存在合法排列同一块的连续三个点 u, v, wdis(u, v) \le K, dis(v, w) \le K, dis(u, w) \gt K

因为 dis(i, j) = dep_i + dep_j - 2 \times dep_{lca(i, j)},有:

dep_u + dep_v - 2\times dep_{lca(u, v)} \le K \\ dep_v + dep_w - 2\times dep_{lca(v, w)} \le K \\ dep_u + dep_w - 2\times dep_{lca(u, w)} \gt K \\ \end{cases}

d = dep_u = dep_v = dep_w,整理式子得:

dep_{lca(u, v)} \ge \frac{2d - K} 2 \\ dep_{lca(v, w)} \ge \frac{2d - K} 2 \\ dep_{lca(u, w)} \lt \frac{2d - K} 2 \\ \end{cases}

由前两个不等式可知,u, v, w 在某个点 u_0 的子树中,且 dep_{u_0} \ge \frac{2d - K} 2。这说明 dep_{lca(u, w)} \ge dep_{u_0} \ge \frac{2d - K} 2,与第三个不等式矛盾,假设不成立。

其实画图感性理解一下就是,总有一对相邻点的路径会经过所有点的 lca,这条路径就是最远的路径。 ::::

f_u 表示深度小于等于 dep_u 的块中,值为 dep_u 的块以 u 为结尾的排列数量。

u 在层 d = dep_u,枚举下一层 d + 1 的点 v 作为下一块的开头,如果 dis(u, v) \le K 说明这两块可以拼接,是合法的;再枚举下一层的点 i \, (i \ne v) 作为下一块的结尾;而 vi 之间的点随意排都可以。设层 dep_u 的点数为 c_u,有转移:

f_i \gets f_u \times (c_{d + 1} - 2)!

初始根节点 f_1 = 1,答案是 \sum_{u \in bel_{mxd}} f_u

可以获得 2pts。

::::success[2pts Code]

#define LL long long 
const int Mod = 1e9 + 7, N = 5e5 + 5;
vector<int> e[N], bel[N];
int dep[N];
LL fac[N], f[N];
void dfs(int u, int fa){
    bel[dep[u] = dep[fa] + 1].emplace_back(u);
    for(int v : e[u]){
        if(v == fa) continue;
        dfs(v, u);
    }
}
namespace LCA{/*预处理O(nlogn),单次O(1)的dfs序求lca*/}
inline int getdis(int u, int v){
    return dep[u] + dep[v] - 2 * dep[LCA::lca(u, v)];
}
inline void amo(LL &x, LL y){
    x += y;
    if(x >= Mod) x -= Mod;
}
inline void solve(int n, int q, int K, int mxd){
    for(int d = 1; d <= mxd; ++d){
        for(int i = 0; i < bel[d].size(); ++i){
            for(int j = i + 1; j < bel[d].size(); ++j){
                if(getdis(bel[d][i], bel[d][j]) > K){
                    putchar('0'), putchar(' ');
                    return;
                }
            }
        }
    }
    for(int i = 1; i <= n; ++i) f[i] = 0;
    f[1] = 1;
    for(int d = 1; d < mxd; ++d){
        for(int u : bel[d]){
            for(int v : bel[d + 1]){
                if(getdis(u, v) <= K){
                    if(bel[d + 1].size() == 1) amo(f[v], f[u]);
                    else{
                        for(int i : bel[d + 1]){
                            if(i != v){
                                amo(f[i], f[u] * fac[bel[d + 1].size() - 2] % Mod);
                            }
                        }
                    }

                }
            }
        }
    }
    LL ans = 0;
    for(int u : bel[mxd]) amo(ans, f[u]);
    write(ans), putchar(' ');
}
inline void run(){
    int n, q, K, mxd = 0;
    read(n, q);
    for(int i = 1; i <= n; ++i){
        vector<int>().swap(e[i]);
        vector<int>().swap(bel[i]);
        fac[i] = fac[i - 1] * i % Mod;
    }
    for(int i = 1; i < n; ++i){
        int u, v;
        read(u, v);
        e[u].emplace_back(v);
        e[v].emplace_back(u);
    }
    dfs(1, 0);
    for(int i = 1; i <= n; ++i) mxd = max(mxd, dep[i]);
    LCA::solve(n);
    while(q--){
        read(K);
        solve(n, q, K, mxd);
    }
    putchar('\n');
}
int main(){
//  freopen("tree.in", "r", stdin);
//  freopen("tree.out", "w", stdout);
    int T;
    read(T);
    fac[0] = 1;
    while(T--) run();
    return 0;
}

::::

通过同层 dp 值相同的性质,优化为 \mathcal{O}(nq)

无解部分是最好优化的。找到该层 d 所有点的最近公共祖先 L = lca_{u \in bel_d},最长距离即 mxl_d = d + d - 2dep_L。只要 \forall mxl_d \le K 就有解。

上述暴力 dp 有一种转移方程和某些变量无关的感觉,但无从下手(可以打印 dp 值发现值相同的规律)。不妨考虑 dis(u, v) \le Ku, v 的关系,到底什么时候 u 才能转移到 v

在层 d 转移该层末尾点 v,现在枚举层 d - 1 的点 u,只要 dis(u, v) \le Ku 作为层 d - 1 的结尾就可以和层 d 的任意点作为开头拼接!即:

f_v = \sum_{dis(u, v) \le K} f_u \times (c_d - 1) \times (c_d - 2)! = \sum_{dis(u, v) \le K} f_u \times (c_d - 1)!

dp_d 表示层 d 每个点的 f 值,cnt_d 表示满足 dis(u, v) \le Ku 的个数,有:

dp_d = dp_{d - 1} \times (c_d - 1)! \times cnt_d

通过这个递推式我们完全可以放弃动态规划,由 ans = \sum_{u \in bel_{mxd}} f_u,直接表示出答案:

ans = c_{mxd} \times \prod_{d = 2} ^{mxd}cnt_d \cdot (c_d - 1)!

我们化简出了如此优美简洁的式子!其中,c_d 可以 \mathcal{O}(n) 求出,总时间复杂度 \mathcal{O}(nq),得分 17pts。

::::success[17pts Code]

#define LL long long 
const int Mod = 1e9 + 7, N = 5e5 + 5;
vector<int> e[N], bel[N];
int dep[N];
LL fac[N], f[N];
void dfs(int u, int fa){
    bel[dep[u] = dep[fa] + 1].emplace_back(u);
    for(int v : e[u]){
        if(v == fa) continue;
        dfs(v, u);
    }
}
namespace LCA{/*预处理O(nlogn),单次O(1)的dfs序求lca*/}
inline int getdis(int u, int v){
    return dep[u] + dep[v] - 2 * dep[LCA::lca(u, v)];
}
inline void solve(int n, int q, int K, int mxd){
    for(int i = 1; i <= mxd; ++i){
        int lca = bel[i][0];
        for(int j : bel[i]) lca = LCA::lca(lca, j);
        if(i + i - 2 * dep[lca] > K){
            putchar('0'), putchar(' ');
            return;
        }
    }
    LL ans = 1;
    for(int i = 2; i <= mxd; ++i){
        int cnt = 0;
        for(int j : bel[i - 1]){
            cnt += (getdis(bel[i][0], j) <= K);
        }
        ans = ans * cnt % Mod * fac[bel[i].size() - 1] % Mod;
    }
    ans = ans * bel[mxd].size() % Mod;
    write(ans), putchar(' ');
}
inline void run(){
    int n, q, K, mxd = 0;
    read(n, q);
    for(int i = 1; i <= n; ++i){
        vector<int>().swap(e[i]);
        vector<int>().swap(bel[i]);
        fac[i] = fac[i - 1] * i % Mod;
    }
    for(int i = 1; i < n; ++i){
        int u, v;
        read(u, v);
        e[u].emplace_back(v);
        e[v].emplace_back(u);
    }
    dfs(1, 0);
    for(int i = 1; i <= n; ++i) mxd = max(mxd, dep[i]);
    LCA::solve(n);
    while(q--){
        read(K);
        solve(n, q, K, mxd);
//      for(int i = 1; i <= mxd; ++i){
//          cerr << "depth = " << i << ": ";
//          for(int j : bel[i]) cerr << f[j] << ' ';
//          cerr << '\n';
//      }
    }
    putchar('\n');
}
int main(){
//  freopen("tree.in", "r", stdin);
//  freopen("tree.out", "w", stdout);
    int T;
    read(T);
    fac[0] = 1;
    while(T--) run();
    return 0;
}

::::

离线查询+双指针优化至 \mathcal{O}(n)\mathcal{O}(n \log n)

先把和 cnt_d 无关的阶乘提出来。把计算 cnt_d 需要的相邻层 dis(u, v) 存在一起,从小到大排序;把查询的 K 从小到大排序。通过双指针,找到在满足查询 i 的最大 dis 指针 j。移动指针的过程中维护实时 cnt,当 cnt 变化时,在继承了上次答案的 ans 中除以掉原来的 cnt 再乘新的 cnt,提前预处理逆元即可。别忘了把阶乘乘回去,最后把无解答案 0 强制赋值给对应询问。

求答案的过程是 \mathcal{O}(n) 的,逆元也可以 \mathcal{O}(n) 预处理,因此总时间复杂度取决于你求 LCA 和排序的方法,我用的 dfs 序求 LCA,预处理是 \mathcal{O}(n \log n) 的。

::::success[Accepted Code]

#include<bits/stdc++.h>
using namespace std;
template<typename T> inline void read(T &x){
    T s = 0; int st = 1; char c = getchar();
    while(c < '0' || c > '9') (c == '-') && (st = -1), c = getchar();
    while(c >= '0' && c <= '9') s = (s << 3) + (s << 1) + (c ^ 48), c = getchar();
    x = s * st;
}
template<typename T, typename... Args> inline void read(T &x, Args &...args){
    read(x), read(args...);
}
template<typename T> inline void write(T x){
    if(x < 0) putchar('-'), x = -x;
    if(x > 9) write(x / 10);
    putchar(x % 10 + 48);
}
#define LL long long 
#define PII pair<int, int> 
const int Mod = 1e9 + 7, N = 5e5 + 5;
vector<int> e[N], bel[N];
int dep[N], cnt[N];
LL fac[N], inv[N], ans[N];
PII req[N], dis[N];
void dfs(int u, int fa){
    bel[dep[u] = dep[fa] + 1].emplace_back(u);
    for(int v : e[u]){
        if(v == fa) continue;
        dfs(v, u);
    }
}
namespace LCA{
    int times;
    int dfn[N], mn[20][N], lg[N];
    inline int MN(int x, int y){
        if(dfn[x] < dfn[y]) return x;
        return y;
    }
    void init(int u, int fa){
        mn[0][dfn[u] = ++times] = fa;
        for(int v : e[u]){
            if(v == fa) continue;
            init(v, u);
        }
    }
    inline void solve(int n){
        times = 0;
        init(1, 0);
        lg[1] = 0;
        for(int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
        for(int i = 1; i <= 19; ++i){
            for(int j = 1; j + (1 << i) - 1 <= n; ++j){
                mn[i][j] = MN(mn[i - 1][j], mn[i - 1][j + (1 << i - 1)]);
            }
        }
    }
    inline int lca(int u, int v){
        if(u == v) return u;
        if(dfn[u] > dfn[v]) swap(u, v);
        int l = dfn[u] + 1, r = dfn[v], len = lg[r - l + 1];
        return MN(mn[len][l], mn[len][r - (1 << len) + 1]);
    }
}
inline int getdis(int u, int v){
    return dep[u] + dep[v] - 2 * dep[LCA::lca(u, v)];
}
inline void run(){
    int n, q, K, mxd = 0, mxl = 0, cntd = 0;
    read(n, q);
    for(int i = 1; i <= n; ++i){
        vector<int>().swap(e[i]);
        vector<int>().swap(bel[i]);
        fac[i] = fac[i - 1] * i % Mod;
    }
    inv[1] = 1;
    for(int i = 2; i <= n; ++i){
        int u, v;
        read(u, v);
        e[u].emplace_back(v);
        e[v].emplace_back(u);
        inv[i] = (Mod - Mod / i) * inv[Mod % i] % Mod;
    }
    dfs(1, 0);
    for(int i = 1; i <= n; ++i) mxd = max(mxd, dep[i]);
    LCA::solve(n);
    for(int i = 1; i <= q; ++i){
        read(req[i].first);
        req[i].second = i;
        ans[i] = 1;
    }
    sort(req + 1, req + 1 + q);
    LL tmp = 1;
    for(int i = 1; i <= mxd; ++i){
        int lca = bel[i][0];
        for(int j : bel[i]) lca = LCA::lca(lca, j);
        mxl = max(mxl, i + i - 2 * dep[lca]);
    }
    for(int i = 2; i <= mxd; ++i){
        for(int j : bel[i - 1]){
            dis[++cntd] = {getdis(bel[i][0], j), i};
        }
        cnt[i] = 0;
        tmp = tmp * fac[bel[i].size() - 1] % Mod;
    }
    tmp = tmp * bel[mxd].size() % Mod;
    sort(dis + 1, dis + 1 + cntd);
    for(int t = 1, j = 1, i; t <= q; ++t){
        i = req[t].second;
        if(t > 1) ans[i] = ans[req[t - 1].second];
        while(j <= cntd && dis[j].first <= req[t].first){
            if(++cnt[dis[j].second] > 1){
                ans[i] = ans[i] * inv[cnt[dis[j].second] - 1] % Mod * cnt[dis[j].second] % Mod; 
            }
            ++j;
        }
    }
    for(int t = 1; t <= q; ++t) if(mxl > req[t].first) ans[req[t].second] = 0;
    for(int i = 1; i <= q; ++i) write(ans[i] * tmp % Mod), putchar(' ');
    putchar('\n');
}
int main(){
//  freopen("tree.in", "r", stdin);
//  freopen("tree.out", "w", stdout);
    int T;
    read(T);
    fac[0] = 1;
    while(T--) run();
    return 0;
}

::::