题解 P8493 [IOI2022]数字电路
我们发现算方案数很麻烦,于是考虑计算节点
但这样显然是过不去的,考虑记
为了方便,记
通过观察我们发现后面几项在加起来的时候都能被消掉,于是有
那么一个叶子节点对概率的贡献就是自身的初始值除以所有在这个叶子节点到根的路径上的节点的儿子数的积。由于答案还要乘上
至于区间翻转,在线段树上维护两个值
时间复杂度:
Code:
#include <cstdio>
#include <algorithm>
#include <vector>
#include "circuit.h"
using namespace std ;
const int MAXN = 2e5 + 10 , mod = 1000002022 ;
typedef long long ll ;
typedef vector<int> vi ;
int n , m , a[MAXN] , fir[MAXN] , tot , id[MAXN] , f[MAXN] ;
ll mul[MAXN] , w[MAXN] , pre[MAXN] , suf[MAXN] ;
struct edge {int to , nxt ;} e[MAXN] ;
void add (int u , int v) {
e[++tot].to = v ; e[tot].nxt = fir[u] ; fir[u] = tot ;
}
void dfs1 (int x) {
int cnt = 0 ;
for (int i = fir[x] ; i ; i = e[i].nxt) cnt++ ;
if (!fir[x]) mul[x] = 1 ;
else mul[x] = cnt ;
for (int i = fir[x] ; i ; i = e[i].nxt)
dfs1 (e[i].to) , mul[x] = mul[x] * mul[e[i].to] % mod ;
}
void dfs2 (int x) {
if (!fir[x]) return ;
int cnt = 0 ;
for (int i = fir[x] ; i ; i = e[i].nxt)
id[e[i].to] = ++cnt , pre[cnt] = suf[cnt] = mul[e[i].to] ;
pre[0] = suf[cnt + 1] = 1 ;
for (int i = 2 ; i <= cnt ; i++) pre[i] = pre[i - 1] * pre[i] % mod ;
for (int i = cnt - 1 ; i ; i--) suf[i] = suf[i + 1] * suf[i] % mod ;
for (int i = fir[x] ; i ; i = e[i].nxt) {
int v = e[i].to ;
w[v] = w[x] * pre[id[v] - 1] % mod * suf[id[v] + 1] % mod ;
}
for (int i = fir[x] ; i ; i = e[i].nxt) dfs2 (e[i].to) ;
}
#define lc (o << 1)
#define rc (o << 1 | 1)
#define mid ((l + r) >> 1)
ll s0[MAXN << 2] , s1[MAXN << 2] ;
int lz[MAXN << 2] ;
void build (int o , int l , int r) {
if (l == r) {f[l] ? (s1[o] = w[l + n - 1]) : (s0[o] = w[l + n - 1]) ; return ;}
build (lc , l , mid) , build (rc , mid + 1 , r) ;
s0[o] = (s0[lc] + s0[rc]) % mod , s1[o] = (s1[lc] + s1[rc]) % mod ;
}
void pushdown (int o) {
if (!o || !lz[o]) return ;
swap (s0[lc] , s1[lc]) , swap (s0[rc] , s1[rc]) , lz[lc] ^= 1 , lz[rc] ^= 1 , lz[o] = 0 ;
}
void upd (int o , int l , int r , int x , int y) {
if (x <= l && r <= y) {lz[o] ^= 1 , swap (s0[o] , s1[o]) ; return ;}
pushdown (o) ;
if (x <= mid) upd (lc , l , mid , x , y) ;
if (mid < y) upd (rc , mid + 1 , r , x , y) ;
s0[o] = (s0[lc] + s0[rc]) % mod , s1[o] = (s1[lc] + s1[rc]) % mod ;
}
void init (int N , int M , vi p , vi a) {
n = N , m = M ;
for (int i = 1 ; i < n + m ; i++) add (p[i] , i) ;
for (int i = 0 ; i < m ; i++) f[i + 1] = a[i] ;
dfs1 (0) , w[0] = 1 , dfs2 (0) ;
build (1 , 1 , m) ;
}
int count_ways (int x , int y) {
upd (1 , 1 , m , x - n + 1 , y - n + 1) ;
return s1[1] ;
}