P6630 [ZJOI2020] 传统艺能 题解

· · 题解

同步于我的 Blog

2023.6.10 UPD: 3 * 3 的式子原来有笔误,感谢 灰鹤在此 和 llzer 的指出。

有一棵广义线段树,每个节点有一个 m 值。一开始 tag 数组均为 0,Bob 会执行 k 次操作,每次操作等概率随机选择区间 [l, r] 并执行 MODIFY(root,1,n,l,r);。 最后所有 Node 中满足 tag[Node]=1 的期望数量。

n \le 2\times 10^5

  看着题解想锤人的题。。。

  很多题解都不讲清状态的设计,以及为什么只要一个 3\times3 的矩阵,为什么状态可以这么设计,于是自行研究了一个 4\times 4 的矩阵的做法才明白原因。于是这个题解讲一个比较好理解的方式。

  首先,考虑根据期望的线性性计算答案,计算每个节点最后有 tag 的概率,比较显然的事情就是 tag 只有两种:祖先的 tag(是否存在一个祖先的 tag = 1) 和自己的 tag。接着我们可以只记录这两种 tag,并将每次操作的节点分类:

  然后按照编号 ABCDE 计算方案数。

关于方案数的计算,我们可以发现,单个节点与操作区间 没有交集/被包含/没有被包含但是有交集 的概率是比较好计算的。 那么 A,B 可以从父亲节点继承过来,C 就是自己的没有交集的概率-父亲没有交集的概率,D 就是自己被包含的概率-父亲被包含的概率,最后一个,E =1-A-B-C-D 即可。

  之后魔幻的事情就来了,大部分题解都直接说维护一个 3\times3 的矩阵然后搞搞搞搞搞,甚至有的没有讲清状态的定义,令人一脸蒙 B,看了半年(尤其是式子比较复杂,还不如自己推一推,于是我就推了大半天才明白)。

  最朴素直接的想法应该是设 f_0, f_1, f_2, f_3 表示 祖先和自己 tag 都为 0 、只有自己的 tag = 1、存在一个祖先 tag = 1、自己 tag = 1 并且存在一个祖先 tag = 1 的概率,转移的话可以根据定义以及变化列出以下式子:

\left\{ \begin{aligned} f'_0 &= Af_0 + Cf_0 + E(f_0 + f_1 + f_2 + f_3)\\ f'_1 &= Af_1 + C(f_1+f_2+f_3)+D(f_0+f_1+f_2+f_3)\\ f'_2 &= Af_2 + B(f_0 + f_2)\\ f'_3 &= Af_3 + B(f_1 + f_3) \end{aligned} \right.

  这个可以直接使用 4\times 4 的矩阵维护,最后的答案就是 f_1+f_3

  但是为什么很多题解都拿的 3\times 3 的矩阵呢?

  我们可以注意到,最后的答案就是 f_1+f_3,然后再自己观察一下两个的转移的式子,就可以发现可以将两个式子合并到一起,于是可以直接维护一个 3 \times 3 的矩阵了。

  具体地,将 f_1 状态设为 自己的 tag=1,祖先的 tag 任意的概率,其余不变,然后式子大概长这样:

\left\{ \begin{aligned} f'_0 &= Af_0 + Cf_0 + E(f_0 + f_1 + f_2)\\ f'_1 &= Af_1 + Bf_1 + C(f_1+f_2)+D(f_0+f_1+f_2)\\ f'_2 &= Af_2 + B(f_0 + f_2)\\ \end{aligned} \right.

  然后可以使用 3\times3 的矩阵维护了,也可以根据 f_0 + f_1+f_2 = 1 ,在矩阵中维护两个状态 + 一个单位元。

  因为 4\times4 的可以直接过,于是 3\times 3 代码没写,下面的是 4\times 4 的代码,所以如果 3\times3 的式子有问题可以与我联系。

  代码(gitee)

 #include <bits/stdc++.h>

#define eb emplace_back
#define ep emplace
#define fi first
#define se second
#define in read<int>()
#define lin read<ll>()
#define rep(i, x, y) for(int i = (x); i <= (y); i++)
#define per(i, x, y) for(int i = (x); i >= (y); i--)

using namespace std;

using ll = long long;
using db = double;
using pii = pair < int, int >;
using vec = vector < int >;
using veg = vector < pii >;

template < typename T > T read() {
    T x = 0; bool f = 0; char ch = getchar();
    while(!isdigit(ch)) f |= ch == '-', ch = getchar();
    while(isdigit(ch)) x = x * 10 + (ch ^ 48), ch = getchar();
    return f ? -x : x;
}

template < typename T > void chkmax(T &x, const T &y) { x = x > y ? x : y; }
template < typename T > void chkmin(T &x, const T &y) { x = x < y ? x : y; }

const int N = 2e5 + 10;
const int mod = 998244353;
const int inv2 = (mod + 1) >> 1;

int n, K, iv, ans;

ll qp(ll x, int t) { ll res = 1; for(; t; t >>= 1, x = x * x % mod) if(t & 1) res = res * x % mod; return res; }

struct node {
    int a, b, c; // a : 不交, b : 被包含, c : 不被包含但是有交
    node(int l, int r) {
        a = (1ll * l * (l - 1) / 2 % mod * iv % mod + 1ll * (n - r + 1) * (n - r) / 2 % mod * iv % mod) % mod;
        b = 1ll * l * (n - r + 1) % mod * iv % mod;
        c = (mod * 2 - a - b + 1) % mod;
    } node(int _a, int _b, int _c) : a(_a), b(_b), c(_c) { }
    node() { a = b = c = 0; }
};

struct mat {
    int a[4][4];
    mat() { memset(a, 0, sizeof a); }
    int* operator [](int x) { return a[x]; }
    friend mat operator * (mat a, mat b) {
        mat c;
        rep(i, 0, 3) rep(j, 0, 3) if(a[i][j])
            rep(k, 0, 3) c[i][k] = (c[i][k] + 1ll * a[i][j] * b[j][k] % mod) % mod;
        return c;
    }
} t;

void solve(int l, int r, const node &lst) {
    t = mat(); node c(l, r);
    ll A = lst.a, B = lst.b, C = (c.a - lst.a + mod) % mod, D = (c.b - lst.b + mod) % mod, E = (1ll - A - B - C - D + mod * 4ll) % mod;

    t[0][0] = (A + C + E) % mod; t[0][1] = D;                 t[0][2] = B;
    t[1][0] = E;                 t[1][1] = (A + C + D) % mod;                              t[1][3] = B;
    t[2][0] = E;                 t[2][1] = (C + D) % mod;     t[2][2] = (A + B) % mod; 
    t[3][0] = E;                 t[3][1] = (C + D) % mod;     t[3][2] = 0;                 t[3][3] = (A + B) % mod;

    mat res = t; for(int v = K - 1; v; v >>= 1, t = t * t) if(v & 1) res = res * t;
    ans = (ans + res[0][1]) % mod; ans = (ans + res[0][3]) % mod;
    if(l == r) return; int mid = in; solve(l, mid, c); solve(mid + 1, r, c);
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("1.in", "r", stdin);
#endif
    n = in, K = in; iv = qp(1ll * n * (n + 1) / 2 % mod, mod - 2);
    solve(1, n, node(0, 0, 1)); printf("%d\n", ans); return 0;
}