二宫凉夏
takanashi_mifuru · · 题解
没懂为啥题解全是八个状态。
首先肯定要能做
然后我们考虑在这个条件下,我们把一个节点进进退退所涉及到的两个节点之间的边直接点亮,那么最后点亮的边一定是若干条链。
这是为啥呢,你考虑假设不是链,那就有地方度数在
然后考虑这是条链还不够,还得每一步点集唯一,那我已经构造出一组合法方案,只要踏出一步使得和这个方案不一样那就是非法方案。
什么情况下会出现一个非法状态呢,我们先把这个图的大概样式写出来,对于每条链他由两个部分组成,链头一回合有点一回合没有点,链中间始终有点。
然后我们发现如果链头即将消失,旁边有一个不在链上的点的话,这个点就可以跑过来使得这个方案不唯一。
这启发我们链中间不能和链头拼在一起。
链头是可以和链头拼在一起的,但是显然他们的状态受到限制。
而链中间也是可以和链中间拼在一起的。
还有把所有点亮的边都点亮之后,点也都被点亮了,一旦出现空点我就可以直接踩过去然后方案就不唯一了。
好的,那么有这个我们就可以来设计一个状态,设
然后你发现这个东西有个问题,那就是链头如果被上面的点接过去了,那他不就变成链中间了吗?
所以加一维,设
然后就做完啦!不过注意到因为有
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int P=998244353;
int n;
vector<int> ljb[200005];
int f[200005][2];
int g[200005];
int power(int x,int y=P-2){
if(y==0)return 1;
int tmp=power(x,y>>1);
if(y&1)return tmp*tmp%P*x%P;
return tmp*tmp%P;
}
void getf(int cur,int fa){
for(int i=0;i<ljb[cur].size();i++){
int v=ljb[cur][i];
if(v==fa)continue;
getf(v,cur);
}
int mul=1;
int cnt=0;
int num=0;
for(int i=0;i<ljb[cur].size();i++){//
int v=ljb[cur][i];
if(v==fa)continue;
cnt++;
if(!f[v][0]){
num++;
continue;
}
mul=mul*f[v][0]%P;
}
if(!num){
for(int i=0;i<ljb[cur].size();i++){
int v=ljb[cur][i];
if(v==fa)continue;
f[cur][0]+=f[v][1]*mul%P*power(f[v][0])%P*power(power(2,cnt-1))%P;
if(f[cur][0]>=P)f[cur][0]-=P;
}
}
if(num==1){
for(int i=0;i<ljb[cur].size();i++){
int v=ljb[cur][i];
if(v==fa)continue;
if(f[v][0])continue;
f[cur][0]+=f[v][1]*mul%P*power(power(2,cnt-1))%P;
if(f[cur][0]>=P)f[cur][0]-=P;
}
}
int all=1;
for(int i=0;i<ljb[cur].size();i++){
int v=ljb[cur][i];
if(v==fa)continue;
all=all*f[v][0]%P*power(2)%P;
}
all=all*2%P;
f[cur][1]=all;
mul=1;
num=0;
for(int i=0;i<ljb[cur].size();i++){
int v=ljb[cur][i];
if(v==fa)continue;
if(!g[v]){
num++;
continue;
}
mul=mul*g[v]%P;
}
if(!num){
for(int i=0;i<ljb[cur].size();i++){
int v=ljb[cur][i];
if(v==fa)continue;
f[cur][1]+=f[v][1]*mul%P*power(g[v])%P;
if(f[cur][1]>=P)f[cur][1]-=P;
}
}
if(num==1){
for(int i=0;i<ljb[cur].size();i++){
int v=ljb[cur][i];
if(v==fa)continue;
if(g[v])continue;
f[cur][1]+=f[v][1]*mul%P;
if(f[cur][1]>=P)f[cur][1]-=P;
}
}
if(num==0){
int sum=0;
for(int i=ljb[cur].size()-1;i>=0;i--){
int v1=ljb[cur][i];
if(v1==fa)continue;
int val1=(f[v1][1]*power(g[v1])%P);
g[cur]+=sum*val1;
g[cur]%=P;
sum+=(f[v1][1]*power(g[v1])%P)%P*power(2)%P*mul%P;
if(sum>=P)sum-=P;
}
}
if(num==1){
for(int i=0;i<ljb[cur].size();i++){
int v1=ljb[cur][i];
if(v1==fa)continue;
if(g[v1])continue;
for(int j=0;j<ljb[cur].size();j++){
int v2=ljb[cur][j];
if(v1==fa||v2==fa)continue;
if(v1==v2)continue;
g[cur]+=f[v1][1]*f[v2][1]%P*power(2)%P*mul%P*power(g[v2])%P;
if(g[cur]>=P)g[cur]-=P;
}
}
}
if(num==2){
for(int i=0;i<ljb[cur].size();i++){
int v1=ljb[cur][i];
if(v1==fa)continue;
if(g[v1])continue;
for(int j=i+1;j<ljb[cur].size();j++){
int v2=ljb[cur][j];
if(v1==fa||v2==fa)continue;
if(v1==v2)continue;
if(g[v2])continue;
g[cur]+=f[v1][1]*f[v2][1]%P*power(2)%P*mul%P;
if(g[cur]>=P)g[cur]-=P;
}
}
}
return;
}
signed main(){
scanf("%lld",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%lld%lld",&u,&v);
ljb[u].push_back(v);
ljb[v].push_back(u);
}
getf(1,0);
printf("%lld\n",(g[1]+f[1][0])%P);
return 0;
}