题解 P6630 【[ZJOI2020] 传统艺能】

Thinking

2020-07-08 15:42:33

Solution

首先,每个结点都是独立的,所以不妨分开计算贡献。 按照[P5280](https://www.luogu.com.cn/problem/P5280)的套路,我们定义 $P_o$ 为 $o$ 上有标记的概率,$P_f$ 为 $o$ 及其祖先上有标记的概率,同时记 $o$ 控制的区间为 $[L,R]$,$o$ 的父亲的为 $[l,r]$(可以避免对左/右儿子的分类讨论),修改区间为 $[x,y]$。 考虑以下五种情况: 1. 修改不会进入 $o$ 的父亲。则 $1\le x\le y<l$ 或 $r<x\le y\le n$,此时 $P_o'=P_o,P_f'=P_f$。 2. 修改进入 $o$ 的父亲,并在 $o$ 上打了标记。则 $1\le x\le L,R\le y<r$ 或 $l<x\le L,R\le y\le n$,此时 $P_o'=P_f'=1$。 3. 修改标记了 $o$ 的祖先。则 $1\le x\le l,r\le y\le n$,此时 $P_o'=P_o,P_f'=1$。 4. 修改下推了 $o$ 的父亲的标记,但没有进入 $o$。则 $l\le y<L,1\le x\le y$ 或 $R<x\le r,x\le y\le n$,此时 $P_o'=P_f'=P_f$。 5. 修改进入了 $o$。则 $P_o'=P_f'=0$。 显然每次转移的系数都不变,因此可以用矩阵加速。同时,有一个可以大大减小常数的 trick:可以注意到 $3\times 3$ 矩阵中在变化的值只有 5 个,因此只需要计算这 5 个值就可以了,可以将乘法次数减小到 7 次,取模次数减小到 5 次。 code: ```cpp #include<cstdio> typedef long long ll; const int mod=998244353; const int BUF=1<<21; char rB[BUF],*rS,*rT; inline char gc(){return rS==rT&&(rT=(rS=rB)+fread(rB,1,BUF,stdin),rS==rT)?EOF:*rS++;} inline int rd(){ char c=gc(); while(c<48||c>57)c=gc(); int x=c&15; for(c=gc();c>=48&&c<=57;c=gc())x=x*10+(c&15); return x; } int n,k,div,ans; inline void inc(int&a,int b){a+=b;if(a>=mod)a-=mod;} void exgcd(int a,int b,int&x,int&y){ if(!b){x=1;y=0;} else{ exgcd(b,a%b,y,x); y-=a/b*x; } } inline int inv(int a){ int x,y; exgcd(a,mod,x,y); return x+((x>>31)&mod); } struct mat{ int a11,a12,a13,a22,a23; mat(int a,int b,int c,int d,int e):a11(a),a12(b),a13(c),a22(d),a23(e){} inline mat operator*(const mat&b)const{return mat((ll)a11*b.a11%mod,((ll)a11*b.a12+(ll)a12*b.a22)%mod,((ll)a11*b.a13+(ll)a12*b.a23+a13)%mod,(ll)a22*b.a22%mod,((ll)a22*b.a23+a23)%mod);} }; inline int fp(mat a){ mat s(1,0,0,1,0); for(int p=k;p;p>>=1){ if(p&1)s=s*a; a=a*a; } return s.a13; } void solve(int L,int R,int l,int r){ if(L<R){ int M=rd(); solve(L,M,L,R);solve(M+1,R,L,R); } if(!l)inc(ans,div); //o 为根节点 else{ int a=((ll)l*(l-1)+(ll)(n-r)*(n-r+1)>>1)%mod,b=((ll)L*(r-R)+(ll)(n-R+1)*(L-l))%mod,c=(ll)l*(n-r+1)%mod,d=((ll)(l+L-1)*(L-l)+(ll)(n*2-R-r+1)*(r-R)>>1)%mod; //a,b,c,d 对应情况 1,2,3,4 inc(ans,fp(mat((ll)(a+c)*div%mod,(ll)d*div%mod,(ll)b*div%mod,(ll)(a+d)*div%mod,(ll)(b+c)*div%mod))); } } int main(){ n=rd();k=rd();div=inv(((ll)n*(n+1)>>1)%mod); solve(1,n,0,0); printf("%d",ans); return 0; } ```