P14251 [集训队互测 2025] Everlasting Friends?

· · 题解

考虑 tp = 1。枚举连通块在 T_{\max} 的根 x,只保留 T 两端都在 x 子树内的边。发现 T 中每条边覆盖 T_{\max} 的一条祖孙链,且每条边至少被覆盖一次。

然后有一个比较深刻的观察是,把选连通块看成断一些边,一条边能断当且仅当它只被覆盖一次,感性理解就是如果被覆盖两次,那么断了这条边会导致那两个上端点不连通。

那么考虑 DP,f_u 表示以 u 为根的连通块数量,那么 f_u = \prod\limits_{v \in son_u} (f_v + [d_v = 1]),其中 d_v(u, v) 这条边被覆盖的次数。答案即为 f_x。时间复杂度 O(n^2)

考虑优化。固然可以 DDP 优化到 O(n \operatorname{polylog}(n)),但是有更简单的方法。

只做一次 DFS,递归到 u 时,找到它在 T 中的所有边 (u, v), u > v。设 u 往下的边为 u \to w,相当于是把 T_{\max}u \to v 路径上的边全部设成不可断开($$)。

从上往下设,那么考虑到一条原本可以断开的边 x \to y 时,f_yf_w 的贡献原本是 f_y + 1 现在变成 f_y,那么只需令 f_w \gets f_w \times \frac{f_y}{f_y + 1} 即可。除 0 的问题可以将每个数表示成 a \times 0^b 解决(集合还在追我()。时间复杂度 O(n (\log n + \log P)),其中 P = 998244353

考虑 tp = 2。固定连通块在 T_{\max} 上的根 xT_{\min} 上的根 y。可以归纳证明连通块 ST 中也是连通块。

并且可能的 S 最多有一个,因为设初始连通块为 x \to yT 中路径的点集,连通块不断拓展的过程中,若 y < z < xz 不在连通块且 zT 中与连通块中的点 w 相邻,那么 z 一定在 T_{\max}T_{\min} 中是 w 的祖先,从而一定要被加入连通块。

通过上述过程可以观察出结论:S 只可能是 T_{\max}x 子树与 T_{\min}y 子树的交。

于是问题转化成有多少对 (x, y),满足:

数连通块自然考虑点减边容斥。枚举 x,给每个 y 设一个权值 val_y(初始若 yT_{\max}x 子树内则 val_y = 2,否则 val_y = +\infty)。若一个点同时在两棵子树内,对 val_y2 的贡献,对于 T_{\max}x 子树内的一条边 (u, v),若 u, v 都在 y 子树内,对 val_y-1 的贡献,对于 T_{\min}y 子树内的一条边 (u, v) 同理。val_y 一定 \ge 2,若 val_y = 2 说明 y 合法。

那么考虑在 T_{\max} 上做线段树合并,对 val_y 的修改相当于若干个在 T_{\min} 上的链加,统计答案相当于查 xT_{\min} 上到根的路径的最小值和最小值个数。

时空复杂度均为 O(n \log^2 n),感受一下空间很难卡满,加上垃圾回收后可以通过。

:::info[代码]

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;

const int maxn = 200100;
const int logn = 20;
const int maxm = 16000100;
const int inf = 0x3f3f3f3f;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
    ll res = 1;
    while (p) {
        if (p & 1) {
            res = res * b % mod;
        }
        b = b * b % mod;
        p >>= 1;
    }
    return res;
}

ll n, type;
int fa[maxn], p[maxn], pa[maxn];
vector<int> G[maxn], G1[maxn], G2[maxn];

int find(int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

namespace Sub1 {
    ll ans;

    struct node {
        ll x, y;
        node(ll _x = 0, ll _y = 0) : x(_x), y(_y) {}
    } f[maxn];

    inline node operator + (const node &a, const node &b) {
        if (a.y < b.y) {
            return a;
        } else if (a.y > b.y) {
            return b;
        } else {
            ll x = (a.x + b.x) % mod;
            if (x == 0) {
                return node(1, a.y + 1);
            } else {
                return node(x, a.y);
            }
        }
    }

    inline node operator * (const node &a, const node &b) {
        return node(a.x * b.x % mod, a.y + b.y);
    }

    inline node operator / (const node &a, const node &b) {
        return node(a.x * qpow(b.x, mod - 2) % mod, a.y - b.y);
    }

    void dfs(int u) {
        f[u] = node(1, 0);
        for (int v : G1[u]) {
            dfs(v);
            int w = p[v];
            for (int x = find(w); x != v; x = find(x)) {
                f[v] = f[v] / (f[x] + node(1, 0)) * f[x];
                fa[x] = pa[x];
            }
            f[u] = f[u] * (f[v] + node(1, 0));
        }
        ans = (ans + (f[u].y ? 0 : f[u].x)) % mod;
    }

    void solve() {
        for (int i = 1; i <= n; ++i) {
            fa[i] = i;
        }
        dfs(n);
        printf("%lld\n", ans);
    }
}

int st1[logn][maxn], st2[logn][maxn], dfn1[maxn], dfn2[maxn], tim;

inline int get1(int i, int j) {
    return dfn1[i] < dfn1[j] ? i : j;
}

inline int get2(int i, int j) {
    return dfn2[i] < dfn2[j] ? i : j;
}

inline int qlca1(int x, int y) {
    if (x == y) {
        return x;
    }
    x = dfn1[x];
    y = dfn1[y];
    if (x > y) {
        swap(x, y);
    }
    ++x;
    int k = __lg(y - x + 1);
    return get1(st1[k][x], st1[k][y - (1 << k) + 1]);
}

inline int qlca2(int x, int y) {
    if (x == y) {
        return x;
    }
    x = dfn2[x];
    y = dfn2[y];
    if (x > y) {
        swap(x, y);
    }
    ++x;
    int k = __lg(y - x + 1);
    return get2(st2[k][x], st2[k][y - (1 << k) + 1]);
}

void dfs(int u, int t) {
    dfn1[u] = ++tim;
    st1[0][tim] = t;
    for (int v : G1[u]) {
        dfs(v, u);
    }
}

int sz[maxn], son[maxn], dep[maxn], top[maxn];

int dfs2(int u, int f, int d) {
    fa[u] = f;
    sz[u] = 1;
    dep[u] = d;
    int mx = -1;
    for (int v : G2[u]) {
        sz[u] += dfs2(v, u, d + 1);
        if (sz[v] > mx) {
            son[u] = v;
            mx = sz[v];
        }
    }
    return sz[u];
}

void dfs3(int u, int tp) {
    top[u] = tp;
    dfn2[u] = ++tim;
    st2[0][tim] = fa[u];
    if (!son[u]) {
        return;
    }
    dfs3(son[u], tp);
    for (int v : G2[u]) {
        if (!dfn2[v]) {
            dfs3(v, v);
        }
    }
}

inline pii operator + (const pii &a, const pii &b) {
    if (a.fst < b.fst) {
        return a;
    } else if (a.fst > b.fst) {
        return b;
    } else {
        return mkp(a.fst, a.scd + b.scd);
    }
}

namespace SGT {
    int ls[maxm], rs[maxm], tag[maxm], nt, stk[maxm], top;
    pii a[maxm];

    inline void init() {
        for (int i = 0; i < maxm; ++i) {
            a[i] = pii(inf, 0);
        }
    }

    inline void pushup(int x) {
        a[x] = a[ls[x]] + a[rs[x]];
        a[x].fst += tag[x];
    }

    inline void pushtag(int x, int y) {
        if (!x) {
            return;
        }
        a[x].fst += y;
        tag[x] += y;
    }

    inline void delnode(int x) {
        a[x] = pii(inf, 0);
        ls[x] = rs[x] = tag[x] = 0;
        if (top + 1 < maxm) {
            stk[++top] = x;
        }
    }

    inline int newnode() {
        assert(nt + 1 < maxm);
        return top ? stk[top--] : (++nt);
    }

    void update(int &rt, int l, int r, int ql, int qr, int x) {
        if (!rt) {
            rt = newnode();
        }
        if (ql <= l && r <= qr) {
            pushtag(rt, x);
            return;
        }
        int mid = (l + r) >> 1;
        if (ql <= mid) {
            update(ls[rt], l, mid, ql, qr, x);
        }
        if (qr > mid) {
            update(rs[rt], mid + 1, r, ql, qr, x);
        }
        pushup(rt);
    }

    void modify(int &rt, int l, int r, int x) {
        if (!rt) {
            rt = newnode();
        }
        if (l == r) {
            a[rt].fst -= inf;
            a[rt].scd = 1;
            return;
        }
        int mid = (l + r) >> 1;
        (x <= mid) ? modify(ls[rt], l, mid, x) : modify(rs[rt], mid + 1, r, x);
        pushup(rt);
    }

    int merge(int u, int v, int l, int r) {
        if (!u || !v) {
            return u | v;
        }
        tag[u] += tag[v];
        if (l == r) {
            bool fl = (a[u].fst > 1e9) && (a[v].fst > 1e9);
            if (a[u].fst >= inf) {
                a[u].fst -= inf;
            }
            if (a[v].fst >= inf) {
                a[v].fst -= inf;
            }
            a[u].fst = a[u].fst + a[v].fst + (fl ? inf : 0);
            a[u].scd |= a[v].scd;
            delnode(v);
            return u;
        }
        int mid = (l + r) >> 1;
        ls[u] = merge(ls[u], ls[v], l, mid);
        rs[u] = merge(rs[u], rs[v], mid + 1, r);
        pushup(u);
        delnode(v);
        return u;
    }

    pii query(int rt, int l, int r, int ql, int qr) {
        if (!rt) {
            return pii(inf, 0);
        }
        if (ql <= l && r <= qr) {
            return a[rt];
        }
        int mid = (l + r) >> 1;
        pii res(inf, 0);
        if (ql <= mid) {
            res = res + query(ls[rt], l, mid, ql, qr);
        }
        if (qr > mid) {
            res = res + query(rs[rt], mid + 1, r, ql, qr);
        }
        res.fst += tag[rt];
        return res;
    }
}

vector<int> vc[maxn];
ll ans;
int rt[maxn];

inline void update(int &rt, int x, int y) {
    while (x) {
        SGT::update(rt, 1, n, dfn2[top[x]], dfn2[x], y);
        x = fa[top[x]];
    }
}

inline pii query(int rt, int x) {
    pii res(inf, 0);
    while (x) {
        res = res + SGT::query(rt, 1, n, dfn2[top[x]], dfn2[x]);
        x = fa[top[x]];
    }
    return res;
}

void dfs4(int u) {
    for (int v : G1[u]) {
        dfs4(v);
        rt[u] = SGT::merge(rt[u], rt[v], 1, n);
    }
    SGT::modify(rt[u], 1, n, dfn2[u]);
    update(rt[u], u, 2);
    for (int v : G1[u]) {
        int w = qlca2(u, v);
        update(rt[u], w, -1);
    }
    for (int v : vc[u]) {
        update(rt[u], v, -1);
    }
    pii p = query(rt[u], u);
    if (p.fst == 2) {
        ans += p.scd;
    }
}

void solve() {
    scanf("%lld%lld", &type, &n);
    for (int i = 1; i <= n; ++i) {
        fa[i] = i;
    }
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        G[u].pb(v);
        G[v].pb(u);
    }
    for (int i = 1; i <= n; ++i) {
        for (int j : G[i]) {
            if (j < i && find(i) != find(j)) {
                int k = find(j);
                p[k] = j;
                fa[k] = i;
                pa[k] = i;
                G1[i].pb(k);
            }
        }
    }
    for (int i = 1; i <= n; ++i) {
        fa[i] = i;
    }
    for (int i = n; i; --i) {
        for (int j : G[i]) {
            if (j > i && find(j) != find(i)) {
                int k = find(j);
                fa[k] = i;
                G2[i].pb(k);
            }
        }
    }
    if (type == 1) {
        Sub1::solve();
        return;
    }
    dfs(n, 0);
    tim = 0;
    dfs2(1, 0, 1);
    dfs3(1, 1);
    for (int j = 1; (1 << j) <= n; ++j) {
        for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
            st1[j][i] = get1(st1[j - 1][i], st1[j - 1][i + (1 << (j - 1))]);
            st2[j][i] = get2(st2[j - 1][i], st2[j - 1][i + (1 << (j - 1))]);
        }
    }
    for (int i = 1; i <= n; ++i) {
        for (int j : G2[i]) {
            vc[qlca1(i, j)].pb(i);
        }
    }
    SGT::init();
    dfs4(n);
    printf("%lld\n", ans % mod);
}

int main() {
    int T = 1;
    // scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}

:::