题解 P8493 [IOI2022]数字电路

· · 题解

我们发现算方案数很麻烦,于是考虑计算节点 0 的值为 1 的概率。设 f_{i,j} 表示以 i 为根的子树,i 的位置参数为 j 时,i 的值为 1 的概率。

但这样显然是过不去的,考虑记 c_i 表示 i 的儿子数,g_x=\dfrac{1}{c_x}\sum\limits_{i=1}^{c_x}f_{x,i},我们发现我们只关心 g_x 的值,考虑化简。

为了方便,记 x 的儿子依次为 a_1,a_2,a_3,\dots,a_{c_x},记 p_1=\sum\limits_{i=1}^{c_x} g_{a_i}p_2=\sum\limits_{i<j}g_{a_i}\times g_{a_j},以下以此类推。那么根据容斥原理,有:

f_{x,1}=p_1-p_2+p_3-p_4+p_5\dots f_{x,2}=p_2-2p_3+3p_4-4p_5\dots f_{x,3}=p_3-3p_4\dots+6p_5 f_{x,4}=p_4-4p_5\dots

通过观察我们发现后面几项在加起来的时候都能被消掉,于是有 g_x=\dfrac{1}{c_x}p_1=\sum\limits_{i=1}^{c_x}g_{a_i}

那么一个叶子节点对概率的贡献就是自身的初始值除以所有在这个叶子节点到根的路径上的节点的儿子数的积。由于答案还要乘上 \prod\limits_{i=0}^{n+m-1}c_i,所以一个节点的贡献就是所有不在这个叶子节点到根的路径上的节点的儿子数的积乘上自身初始权值。这个可以在 dfs 时顺便求出。

至于区间翻转,在线段树上维护两个值 s_0s_1 表示区间内权值为 0 的贡献和以及权值为 1 的贡献和即可。

时间复杂度:\Theta(n+m+q\log m)

Code:

#include <cstdio>
#include <algorithm>
#include <vector>
#include "circuit.h"
using namespace std ;
const int MAXN = 2e5 + 10 , mod = 1000002022 ;
typedef long long ll ;
typedef vector<int> vi ;
int n , m , a[MAXN] , fir[MAXN] , tot , id[MAXN] , f[MAXN] ;
ll mul[MAXN] , w[MAXN] , pre[MAXN] , suf[MAXN] ;
struct edge {int to , nxt ;} e[MAXN] ;
void add (int u , int v) {
    e[++tot].to = v ; e[tot].nxt = fir[u] ; fir[u] = tot ;
}
void dfs1 (int x) {
    int cnt = 0 ; 
    for (int i = fir[x] ; i ; i = e[i].nxt) cnt++ ;
    if (!fir[x]) mul[x] = 1 ;
    else mul[x] = cnt ;
    for (int i = fir[x] ; i ; i = e[i].nxt)
        dfs1 (e[i].to) , mul[x] = mul[x] * mul[e[i].to] % mod ;
}
void dfs2 (int x) {
    if (!fir[x]) return ;
    int cnt = 0 ; 
    for (int i = fir[x] ; i ; i = e[i].nxt)
        id[e[i].to] = ++cnt , pre[cnt] = suf[cnt] = mul[e[i].to] ;
    pre[0] = suf[cnt + 1] = 1 ;
    for (int i = 2 ; i <= cnt ; i++) pre[i] = pre[i - 1] * pre[i] % mod ;
    for (int i = cnt - 1 ; i ; i--) suf[i] = suf[i + 1] * suf[i] % mod ;
    for (int i = fir[x] ; i ; i = e[i].nxt) {
        int v = e[i].to ;
        w[v] = w[x] * pre[id[v] - 1] % mod * suf[id[v] + 1] % mod ;
    }
    for (int i = fir[x] ; i ; i = e[i].nxt) dfs2 (e[i].to) ;
}
#define lc (o << 1)
#define rc (o << 1 | 1)
#define mid ((l + r) >> 1)
ll s0[MAXN << 2] , s1[MAXN << 2] ;
int lz[MAXN << 2] ;
void build (int o , int l , int r) {
    if (l == r) {f[l] ? (s1[o] = w[l + n - 1]) : (s0[o] = w[l + n - 1]) ; return ;}
    build (lc , l , mid) , build (rc , mid + 1 , r) ;
    s0[o] = (s0[lc] + s0[rc]) % mod , s1[o] = (s1[lc] + s1[rc]) % mod ;
}
void pushdown (int o) {
    if (!o || !lz[o]) return ;
    swap (s0[lc] , s1[lc]) , swap (s0[rc] , s1[rc]) , lz[lc] ^= 1 , lz[rc] ^= 1 , lz[o] = 0 ;
}
void upd (int o , int l , int r , int x , int y) {
    if (x <= l && r <= y) {lz[o] ^= 1 , swap (s0[o] , s1[o]) ; return ;}
    pushdown (o) ;
    if (x <= mid) upd (lc , l , mid , x , y) ;
    if (mid < y) upd (rc , mid + 1 , r , x , y) ;
    s0[o] = (s0[lc] + s0[rc]) % mod , s1[o] = (s1[lc] + s1[rc]) % mod ;
}
void init (int N , int M , vi p , vi a) {
    n = N , m = M ;
    for (int i = 1 ; i < n + m ; i++) add (p[i] , i) ;
    for (int i = 0 ; i < m ; i++) f[i + 1] = a[i] ;
    dfs1 (0) , w[0] = 1 , dfs2 (0) ;
    build (1 , 1 , m) ;
}
int count_ways (int x , int y) {
    upd (1 , 1 , m , x - n + 1 , y - n + 1) ;
    return s1[1] ;
}