P14251 [集训队互测 2025] Everlasting Friends?
EuphoricStar · · 题解
考虑
然后有一个比较深刻的观察：把选连通块看成断开一些边，则一条边能断开当且仅当它只被覆盖一次。直观地理解：如果一条边被覆盖两次，那么断开这条边会导致那两个上端点不连通。
那么考虑 DP,
考虑优化。固然可以 DDP 优化到
只做一次 DFS,递归到
从上往下设,那么考虑到一条原本可以断开的边 集合还在追我()。时间复杂度
考虑
并且可能的
通过上述过程可以观察出如下结论：
于是问题转化成有多少对
数连通块自然考虑点减边容斥。枚举
那么考虑在
时空复杂度均为
:::info[代码]
#include <bits/stdc++.h>
// Short aliases used throughout the solution.
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
// (value, count) pairs; the segment tree stores (minimum, number of positions attaining it).
typedef pair<int, int> pii;
// Capacity of every per-vertex array (vertices are labelled 1..n).
const int maxn = 200100;
// Number of levels in the sparse tables st1/st2 (first dimension).
const int logn = 20;
// Capacity of the node pool of the mergeable segment tree in namespace SGT.
const int maxm = 16000100;
// Large sentinel; SGT also uses it as the offset that marks a leaf as not yet activated.
const int inf = 0x3f3f3f3f;
// Modulus for all modular arithmetic and for the printed answer.
const ll mod = 998244353;
// b^p modulo `mod` by square-and-multiply. Intended for p >= 0 (callers pass
// mod - 2 to obtain a Fermat inverse); b is presumably already reduced below
// mod so that the ll products cannot overflow.
inline ll qpow(ll b, ll p) {
    ll acc = 1;
    for (; p != 0; p >>= 1) {
        if (p & 1) {
            acc = acc * b % mod;
        }
        b = b * b % mod;
    }
    return acc;
}
ll n, type;
int fa[maxn], p[maxn], pa[maxn];
vector<int> G[maxn], G1[maxn], G2[maxn];
// DSU lookup over `fa` with full path compression: returns the representative
// of x and re-points every vertex on the walk directly at it. Written as two
// loops (locate the root, then compress) instead of recursion, so long parent
// chains do not consume call stack.
int find(int x) {
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    while (fa[x] != root) {
        const int up = fa[x];
        fa[x] = root;
        x = up;
    }
    return root;
}
// Solver for type == 1 (dispatched from solve()). Runs a DP over G1 whose
// products and quotients are taken modulo `mod`, using `node` so that dividing
// out a factor that is 0 mod p stays possible.
namespace Sub1 {
// Accumulated answer, kept reduced modulo `mod`.
ll ans;
// A value of the form x * 0^y: x is a residue modulo `mod`, y counts zero
// factors. f[u] holds the DP value of vertex u.
struct node {
ll x, y;
node(ll _x = 0, ll _y = 0) : x(_x), y(_y) {}
} f[maxn];
// Sum of two x * 0^y values. The operand with fewer zero factors dominates.
// With equal counts the coefficients are added. A coefficient that cancels to
// 0 mod p is recorded as one more zero factor (coefficient reset to 1), so no
// coefficient built here is ever 0 and operator/ can always invert it.
inline node operator + (const node &a, const node &b) {
if (a.y < b.y) {
return a;
} else if (a.y > b.y) {
return b;
} else {
ll x = (a.x + b.x) % mod;
if (x == 0) {
return node(1, a.y + 1);
} else {
return node(x, a.y);
}
}
}
// Product: multiply coefficients, add zero-factor counts.
inline node operator * (const node &a, const node &b) {
return node(a.x * b.x % mod, a.y + b.y);
}
// Quotient: multiply by the Fermat inverse of b.x (mod is prime) and subtract
// zero-factor counts.
inline node operator / (const node &a, const node &b) {
return node(a.x * qpow(b.x, mod - 2) % mod, a.y - b.y);
}
// Post-order DP over G1 rooted at u (recursive; depth = height of G1).
// f[u] starts at 1 and is multiplied by (f[v] + 1) for every G1 child v.
// Before that, starting from p[v] (the neighbour, inside v's component, through
// which solve() attached v to u), the loop walks up with find() until it
// reaches v. For every representative x met on the way, f[v] is rescaled by
// f[x] / (f[x] + 1), and x is linked to its G1 parent pa[x] so that later
// walks skip it.
// Every u whose f[u] has no zero factor adds f[u].x to ans.
// NOTE(review): the combinatorial meaning of the rescaling is described in the
// write-up, not in the code; confirm against it before changing this loop.
void dfs(int u) {
f[u] = node(1, 0);
for (int v : G1[u]) {
dfs(v);
int w = p[v];
for (int x = find(w); x != v; x = find(x)) {
f[v] = f[v] / (f[x] + node(1, 0)) * f[x];
fa[x] = pa[x];
}
f[u] = f[u] * (f[v] + node(1, 0));
}
ans = (ans + (f[u].y ? 0 : f[u].x)) % mod;
}
// Resets the DSU to singletons and runs the DP from n. In solve()'s
// increasing-label sweep n is attached last, so it is the G1 root for a
// connected input. Then prints the answer.
void solve() {
for (int i = 1; i <= n; ++i) {
fa[i] = i;
}
dfs(n);
printf("%lld\n", ans);
}
}
int st1[logn][maxn], st2[logn][maxn], dfn1[maxn], dfn2[maxn], tim;
// Sparse-table combiners: of two vertices, keep the one entered earlier in the
// respective DFS (G1 order for get1, G2 order for get2); on a tie keep j.
inline int get1(int i, int j) {
    return dfn1[j] <= dfn1[i] ? j : i;
}
inline int get2(int i, int j) {
    return dfn2[j] <= dfn2[i] ? j : i;
}
// O(1) LCA by DFS order. st*[0][t] holds the parent of the vertex entered at
// time t. For x != y, the LCA is the earliest-entered of those parents over
// entry times (min(dfn), max(dfn)].
inline int qlca1(int x, int y) {
    if (x == y) {
        return x;
    }
    int lo = dfn1[x];
    int hi = dfn1[y];
    if (hi < lo) {
        swap(lo, hi);
    }
    lo += 1;
    const int lev = __lg(hi - lo + 1);
    return get1(st1[lev][lo], st1[lev][hi - (1 << lev) + 1]);
}
// Same as qlca1, but for G2 (entry times dfn2, table st2).
inline int qlca2(int x, int y) {
    if (x == y) {
        return x;
    }
    int lo = dfn2[x];
    int hi = dfn2[y];
    if (hi < lo) {
        swap(lo, hi);
    }
    lo += 1;
    const int lev = __lg(hi - lo + 1);
    return get2(st2[lev][lo], st2[lev][hi - (1 << lev) + 1]);
}
// Preorder walk of G1. Records u's entry time in dfn1 and stores u's parent t
// in st1[0] at that time (t = 0 for the root), seeding the G1 LCA table.
void dfs(int u, int t) {
    const int entered = ++tim;
    dfn1[u] = entered;
    st1[0][entered] = t;
    for (const int child : G1[u]) {
        dfs(child, u);
    }
}
int sz[maxn], son[maxn], dep[maxn], top[maxn];
// First heavy-light pass over G2. Stores u's G2 parent in fa (overwriting the
// DSU state) and its depth in dep. Accumulates subtree sizes and picks as heavy
// child son[u] the first child with the largest subtree. Returns sz[u].
int dfs2(int u, int f, int d) {
    fa[u] = f;
    dep[u] = d;
    sz[u] = 1;
    int best = -1;
    for (const int child : G2[u]) {
        sz[u] += dfs2(child, u, d + 1);
        if (best < sz[child]) {
            best = sz[child];
            son[u] = child;
        }
    }
    return sz[u];
}
// Second heavy-light pass over G2.
// - Assigns dfn2 so that every heavy chain occupies consecutive positions.
// - Records each chain head in top.
// - Stores u's G2 parent (fa, as set by dfs2) in st2[0] at u's position,
//   seeding the G2 LCA table.
void dfs3(int u, int tp) {
    const int stamp = ++tim;
    top[u] = tp;
    dfn2[u] = stamp;
    st2[0][stamp] = fa[u];
    const int heavy = son[u];
    if (heavy) {
        // The heavy child continues the current chain...
        dfs3(heavy, tp);
        // ...and every other (still unvisited) child starts its own chain.
        for (const int child : G2[u]) {
            if (!dfn2[child]) {
                dfs3(child, child);
            }
        }
    }
}
// Combines two (minimum, multiplicity) pairs: keeps the smaller minimum, and
// when both minima are equal adds their multiplicities.
inline pii operator + (const pii &a, const pii &b) {
    if (a.fst != b.fst) {
        return a.fst < b.fst ? a : b;
    }
    return mkp(a.fst, a.scd + b.scd);
}
// Mergeable dynamic segment tree over positions [1, n] (callers index it by
// dfn2). Every node stores a pii (minimum value, number of positions attaining
// it), combined with the global pii operator+.
//
// Tags are permanent: they are never pushed down. a[x] includes tag[x] but not
// the tags of x's ancestors; query() adds those on the way back up.
//
// Leaves start at inf ("inactive"); modify() activates a leaf by removing that
// offset, so only activated positions can contribute a finite minimum.
namespace SGT {
// ls/rs: child ids (0 = null); tag: permanent additive tag; nt: largest id
// handed out so far; stk/top: free list of recycled ids (inside this namespace
// `top` shadows the global ::top array). Node 0 is the shared null node and
// holds (inf, 0) after init().
int ls[maxm], rs[maxm], tag[maxm], nt, stk[maxm], top;
pii a[maxm];
// A leaf whose stored value exceeds this still carries its +inf offset, i.e.
// it is inactive. Active leaves store small tag sums. Inactive leaves store inf
// plus small tag sums, which may be negative because a leaf keeps only its own
// tag. Any cut-off well between the two ranges separates them.
const int inactive_cut = 1000000000;
// Resets every node, including the null node 0, to (inf, 0).
inline void init() {
    for (int i = 0; i < maxm; ++i) {
        a[i] = pii(inf, 0);
    }
}
// Recomputes a[x] from its children and re-applies x's own tag.
inline void pushup(int x) {
    a[x] = a[ls[x]] + a[rs[x]];
    a[x].fst += tag[x];
}
// Adds y to every position under x by recording it as x's tag (no-op for the
// null node).
inline void pushtag(int x, int y) {
    if (!x) {
        return;
    }
    a[x].fst += y;
    tag[x] += y;
}
// Clears node x and returns it to the free list. If the free list is full the
// id is simply not recycled.
inline void delnode(int x) {
    a[x] = pii(inf, 0);
    ls[x] = rs[x] = tag[x] = 0;
    if (top + 1 < maxm) {
        stk[++top] = x;
    }
}
// Hands out a cleared node id, preferring recycled ones. Only a fresh id
// consumes pool capacity, so only that path checks for exhaustion; previously
// the check also fired when a recycled id was available.
inline int newnode() {
    if (top) {
        return stk[top--];
    }
    assert(nt + 1 < maxm);
    return ++nt;
}
// Adds x to every position in [ql, qr]; nodes are created on demand.
void update(int &rt, int l, int r, int ql, int qr, int x) {
    if (!rt) {
        rt = newnode();
    }
    if (ql <= l && r <= qr) {
        pushtag(rt, x);
        return;
    }
    int mid = (l + r) >> 1;
    if (ql <= mid) {
        update(ls[rt], l, mid, ql, qr, x);
    }
    if (qr > mid) {
        update(rs[rt], mid + 1, r, ql, qr, x);
    }
    pushup(rt);
}
// Activates position x: removes the leaf's inf offset (keeping the tags it has
// already received) and marks it as one counted position. Each position is
// expected to be activated at most once.
void modify(int &rt, int l, int r, int x) {
    if (!rt) {
        rt = newnode();
    }
    if (l == r) {
        a[rt].fst -= inf;
        a[rt].scd = 1;
        return;
    }
    int mid = (l + r) >> 1;
    (x <= mid) ? modify(ls[rt], l, mid, x) : modify(rs[rt], mid + 1, r, x);
    pushup(rt);
}
// Merges tree v into tree u over [l, r] and returns the merged root; v's nodes
// are freed. Tags are summed level by level, so each position's total becomes
// the sum of its totals in u and v.
int merge(int u, int v, int l, int r) {
    if (!u || !v) {
        return u | v;
    }
    tag[u] += tag[v];
    if (l == r) {
        const bool u_idle = a[u].fst > inactive_cut;
        const bool v_idle = a[v].fst > inactive_cut;
        // Strip the inf offset from each inactive side and add the remaining tag
        // sums. Then restore the offset once if neither side was active.
        // "Inactive" must use the same predicate for both decisions: an inactive
        // leaf may sit slightly below inf, and stripping only values >= inf
        // would leave a stray inf in the sum.
        a[u].fst = (u_idle ? a[u].fst - inf : a[u].fst)
                 + (v_idle ? a[v].fst - inf : a[v].fst)
                 + (u_idle && v_idle ? inf : 0);
        a[u].scd |= a[v].scd;
        delnode(v);
        return u;
    }
    int mid = (l + r) >> 1;
    ls[u] = merge(ls[u], ls[v], l, mid);
    rs[u] = merge(rs[u], rs[v], mid + 1, r);
    pushup(u);
    delnode(v);
    return u;
}
// Returns (minimum, multiplicity) over [ql, qr], adding the permanent tags of
// the nodes on the way back up. Missing subtrees count as (inf, 0).
pii query(int rt, int l, int r, int ql, int qr) {
    if (!rt) {
        return pii(inf, 0);
    }
    if (ql <= l && r <= qr) {
        return a[rt];
    }
    int mid = (l + r) >> 1;
    pii res(inf, 0);
    if (ql <= mid) {
        res = res + query(ls[rt], l, mid, ql, qr);
    }
    if (qr > mid) {
        res = res + query(rs[rt], mid + 1, r, ql, qr);
    }
    res.fst += tag[rt];
    return res;
}
}
vector<int> vc[maxn];
ll ans;
int rt[maxn];
// Adds y to every position on the G2 path from x up to the G2 root, in the
// segment tree rooted at `rt`, one heavy chain per step. fa holds G2 parents
// once dfs2 has run; the root's parent is 0, which ends the walk.
inline void update(int &rt, int x, int y) {
    for (int cur = x; cur; cur = fa[top[cur]]) {
        const int head = top[cur];
        SGT::update(rt, 1, n, dfn2[head], dfn2[cur], y);
    }
}
// (minimum, multiplicity) over the positions on the G2 path from x up to the
// G2 root, in the segment tree rooted at `rt`, combined chain by chain.
inline pii query(int rt, int x) {
    pii best(inf, 0);
    while (x != 0) {
        const int head = top[x];
        best = best + SGT::query(rt, 1, n, dfn2[head], dfn2[x]);
        x = fa[head];
    }
    return best;
}
void dfs4(int u) {
for (int v : G1[u]) {
dfs4(v);
rt[u] = SGT::merge(rt[u], rt[v], 1, n);
}
SGT::modify(rt[u], 1, n, dfn2[u]);
update(rt[u], u, 2);
for (int v : G1[u]) {
int w = qlca2(u, v);
update(rt[u], w, -1);
}
for (int v : vc[u]) {
update(rt[u], v, -1);
}
pii p = query(rt[u], u);
if (p.fst == 2) {
ans += p.scd;
}
}
void solve() {
scanf("%lld%lld", &type, &n);
for (int i = 1; i <= n; ++i) {
fa[i] = i;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
for (int i = 1; i <= n; ++i) {
for (int j : G[i]) {
if (j < i && find(i) != find(j)) {
int k = find(j);
p[k] = j;
fa[k] = i;
pa[k] = i;
G1[i].pb(k);
}
}
}
for (int i = 1; i <= n; ++i) {
fa[i] = i;
}
for (int i = n; i; --i) {
for (int j : G[i]) {
if (j > i && find(j) != find(i)) {
int k = find(j);
fa[k] = i;
G2[i].pb(k);
}
}
}
if (type == 1) {
Sub1::solve();
return;
}
dfs(n, 0);
tim = 0;
dfs2(1, 0, 1);
dfs3(1, 1);
for (int j = 1; (1 << j) <= n; ++j) {
for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
st1[j][i] = get1(st1[j - 1][i], st1[j - 1][i + (1 << (j - 1))]);
st2[j][i] = get2(st2[j - 1][i], st2[j - 1][i + (1 << (j - 1))]);
}
}
for (int i = 1; i <= n; ++i) {
for (int j : G2[i]) {
vc[qlca1(i, j)].pb(i);
}
}
SGT::init();
dfs4(n);
printf("%lld\n", ans % mod);
}
// Program entry: processes a single test case.
int main() {
    const int cases = 1;
    for (int c = 0; c < cases; ++c) {
        solve();
    }
    return 0;
}
:::