题解 P4776 【[NOI2018]多边形】
状压dp,需要一定代码能力
对于每个子树,枚举其最左边k个叶子和最右边k个叶子的连通情况,其余的叶子节点必须处于“不是路径端点”的状态。每个叶子节点共有四种情况。
- 不是路径的端点。(记为0)
- 是子树根所在路径的端点。(记为1,合法状态中该节点个数小于等于2)
- 不与任何点相连,单独组成一种路径。(记为-1)
- 是不经过子树根的路径的端点。(每条这样的路径记为一种颜色)
预处理出状态两两之间的转移情况。只考虑左右子树叶子节点之间连边情况即可。之后直接进行树形dp即可。统计答案时,枚举最左k个叶子和最右k个叶子的连边情况即可。(当叶子总数小于等于2k时,由于叶子之间连边会被重复计算,只能将链缩为一个点后直接暴力)。具体详见代码。
k=3,合法状态数共有2665种。左右状态叶子个数均大于等于6时,转移共有201536种。可以通过。
#include<cstdio>
#include<algorithm>
#include<queue>
#include<cmath>
#include<iostream>
#include<map>
#include<cstring>
// Shorthand for the 64-bit integer type used for the modular counts.
#define ll long long
// Capacity constant used to size the global arrays.
#define maxn 1005
// Modulus applied to every count.
// NOTE: as a macro this replaces every identifier token `p` in the file.
#define p 998244353
// Inclusive counting loop: i runs over a, a+1, ..., b.
#define re(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
// Reads the next decimal integer from stdin.
// Every character before the first digit is skipped; each '-' among the
// skipped characters flips the sign of the result.
// Returns 0 when end of input is reached before any digit is seen. The
// character is kept in an int so that EOF is not truncated, which would
// otherwise make the skip loop spin forever at end of input.
inline int read()
{
    int c = getchar();
    int sign = 1;
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-')
            sign = -sign;
        c = getchar();
    }
    int value = 0;
    while (c >= '0' && c <= '9')
    {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
//_____________________________________________________________________________________________________
// Global state shared by the tree-rebuilding and brute-force stages.
//   k      : second integer read in main; solve1 joins each leaf to the next k leaves (cyclically).
//   used[] : indexed by input vertex id; dfs1 sets it to 1 for every vertex it does not skip
//            as a single-child vertex.
//   _n     : first integer read in main; vertex count of the input tree.
//   n      : vertex counter of the rebuilt tree (main starts it at 1; dfs1/build allocate ids with ++n).
//   m      : not referenced anywhere in the visible code.
//   _a     : input tree, _a[parent][child] = 1.
//   __a    : rebuilt tree, __a[parent][child] = 1 (solve1 also adds leaf-leaf edges and symmetrises it).
//   leaf[] : number of leaves below each rebuilt vertex (main caps it at 2*k).
//   _st    : _st[0] is the leaf count, _st[1.._st[0]] the rebuilt ids of the leaves in dfs1 visiting order.
int k,used[maxn],_n,n,m,_a[maxn][maxn],__a[maxn][maxn],leaf[maxn],_st[maxn];
// Answer accumulator, reduced modulo p where it is updated.
ll ans;
inline int dfs1(int x)//缩链
{
int cnt = 0,y;
re(i,1,_n)
if(_a[x][i])
cnt++;
if(cnt==1)
{
re(i,1,_n)
if(_a[x][i])
return dfs1(i);
}
used[x] = 1;
if(!cnt)
{
_st[++_st[0]] = ++n;
leaf[n]=1;
return n;
}
int _x = ++n;
re(i,1,_n)
if(_a[x][i])
{
int y = dfs1(i);
leaf[_x] += leaf[y] ;
if(used[i])
__a[_x][y] = 1;
else
{
leaf[++n] = leaf[y];
__a[_x][n] = __a[n][y] = 1;
}
}
return _x;
}
inline void build()
{
re(i,1,_n)
if(_a[1][i])
{
int y = dfs1(i);
leaf[1] += leaf[y];
if(used[i])
__a[1][y] = 1;
else
{
leaf[++n]=leaf[y];
__a[1][n] = __a[n][y] = 1;
}
}
}
int __dp[1<<21][25];
inline void solve1()//暴力
{
ans = 0;
re(i,1,_st[0])
re(j,1,k)
__a[ _st[i] ][ _st[ i+j>_st[0] ? i+j-_st[0] : i+j ] ] = 1;
re(i,1,n)
re(j,1,n)
if(__a[i][j])
__a[j][i] = 1;
int S = (1<<n)-1;__dp[1][1] = 1;
for(int s=1;s<=S;s+=2)
re(x,1,n)
if(__dp[s][x]&&((s>>x-1)&1))
re(y,1,n)
if(__a[x][y]&&(((s>>y-1)&1)==0))
(__dp[ s|(1<<y-1) ][y] += __dp[s][x])%=p;
re(i,1,n)
if(__a[1][i])
(ans+=__dp[S][i])%=p;
printf("%d\n",ans*(p+1)/2%p);
}
//______________________________________________
// Scratch table used by node::sort to map old labels to canonical ones.
int match[40];

// One DP state: the labels of up to 12 leaf slots plus a root counter.
//   len  : number of slots in use, stored in s[1..len] (s[0] is never used).
//   root : root counter of the state (0, 1 or 2 in the states built by main).
//   s[i] : 0, -1, 1, or a label greater than 1; the meaning of the labels is
//          given by the routines that build and merge states.
struct node
{
    int len,root;
    int s[13];

    // Resets every field to zero.
    // The array has 13 entries (indices 0..12); the previous loop ran over
    // 1..13 and wrote one element past the end of the array.
    inline void clear()
    {
        for (int i = 0; i < 13; ++i)
            s[i] = 0;
        len = root = 0;
    }

    // Debug helper: prints the root counter and the slots in use.
    void pri()
    {
        printf("%d|",root);
        for (int i = 1; i <= len; ++i)
            printf("%d ",s[i]);
        printf("\n________________________\n");
    }

    // Strict ordering used as the map key order: by root, then len, then the
    // slots 1..len lexicographically. Slots beyond len are ignored.
    inline bool operator <(const node& x)const
    {
        if (root != x.root) return root < x.root;
        if (len != x.len) return len < x.len;
        for (int i = 1; i <= len; ++i)
            if (s[i] != x.s[i])
                return s[i] < x.s[i];
        return false;
    }

    // Renumbers every label greater than 1 to 2, 3, ... in order of first
    // appearance, so that equivalent states compare equal.
    // sz must be at least the largest label present (and below 40).
    inline void sort(int sz)
    {
        int next = 1;
        for (int i = 2; i <= sz; ++i)
            match[i] = 0;
        for (int i = 1; i <= len; ++i)
        {
            if (s[i] <= 1)
                continue;
            if (!match[s[i]])
                match[s[i]] = ++next;
            s[i] = match[s[i]];
        }
    }
}node1,a[maxn*40];
// Maps a state to its index in a[]; a stored value of 0 means "no such state".
map<node,int> mp;
// Number of states stored in a[1..tot].
int tot;
void dfs2(int i,int cnt)
{
if(i==node1.len +1)
{
re(j,1,node1.len )
if(node1.s [j])
{
a[++tot] = node1;
mp[node1] = tot;
break;
}
return ;
}
dfs2(i+1,cnt);
if(!node1.s [i])
{
node1.s[i] = cnt+1;
re(j,i+1,node1.len )
if(!node1.s [j])
{
node1.s [j] = cnt+1;
dfs2(i+1,cnt+1);
node1.s [j] = 0;
}
node1.s [i] = 0;
}
}
// One precomputed transition: state index x combined with state index y
// produces state index to. All three fields default to 0.
struct turn
{
    int x;
    int y;
    int to;

    turn(int from_left = 0, int from_right = 0, int result = 0)
    {
        x = from_left;
        y = from_right;
        to = result;
    }
}
// f[a][b][1..]: transitions stored for a left state of a slots and a right state of b slots.
f[7][7][210000];
// tot_f[a][b]: number of transitions stored in f[a][b][1..].
int tot_f[7][7];
// Shorthand for a pair of ints.
// NOTE: as a macro this replaces every identifier token `pr` from here on.
#define pr pair<int,int>
// st1[0] counts the state indices collected by check(); st1[1..st1[0]] hold them.
// _st2 is not referenced in the visible code.
int st1[maxn*200],_st2;
// Not referenced in the visible code.
pr st2[maxn*200];
// Finalises one merged state produced by dfs3.
// The state is rejected if any slot strictly between the first k and the last
// k slots is non-zero. Otherwise it is reduced to at most 2*k slots (first k
// and last k), its labels are canonicalised with sort(cnt), and, if the result
// is a known state, its index is appended to st1.
// x is taken by value on purpose: the global node1 is overwritten here.
// cnt must be at least the largest label present in x.
inline void check(node x,int cnt)
{
    for (int i = k + 1; i <= x.len - k; ++i)
        if (x.s[i])
            return;

    node1.clear();
    if (x.len > 2 * k)
    {
        node1.root = x.root;
        node1.len = 2 * k;
        for (int i = 1; i <= k; ++i)
            node1.s[i] = x.s[i];
        for (int i = x.len - k + 1; i <= x.len; ++i)
            node1.s[i - x.len + 2 * k] = x.s[i];
    }
    else
        node1 = x;
    node1.sort(cnt);

    // Look the state up without operator[]: operator[] would insert a
    // zero-valued entry for every rejected candidate and keep growing the map.
    std::map<node,int>::iterator hit = mp.find(node1);
    if (hit != mp.end() && hit->second != 0)
        st1[++st1[0]] = hit->second;
}
// Not referenced anywhere in the visible code.
int vis[20];
// Enumerates the ways to join slots i..lim of state x to slots to their right
// (slots lim+1 .. min(x.len, i+k)), handing every finished state to check().
//   i   : slot currently being decided; the recursion stops at lim+1.
//   lim : last slot that may start a join (Turn passes the length of the left state).
//   x   : working state, passed by value so each branch edits its own copy.
//   cnt : label counter; cnt+1 is used as the fresh label at this level and
//         cnt+1 is passed down so deeper levels use larger labels.
// The global node1 is only scratch space here: it is rebuilt from x before
// every use, and the recursive call copies it (and check() overwrites it).
inline void dfs3(int i,int lim,node x,int cnt)
{
if(i==lim+1)
{
check(x,cnt+1);
return ;
}
// Choice: slot i is joined to nothing on its right.
dfs3(i+1,lim,x,cnt+1);
// Slot i carries a positive label: join it to exactly one slot j on the right.
if(x.s [i]>0)
re(j,lim+1,min(x.len ,i+k))
{
// The merged label is 1 as soon as one side is labelled 1, otherwise fresh.
int n_col = x.s [j]==1 || x.s[i] == 1 ? 1 : cnt+1;
// j also carries a positive label (a different one, or both are 1):
// relabel every slot holding either label, then clear i and j.
if(x.s [j]>0&&(x.s [j]!=x.s [i]||x.s[j]==1 ))
{
node1 = x;
re(t,1,x.len )
if( x.s [t] == x.s [i] || x.s [t]==x.s [j] )
node1.s[t] = n_col;
node1.s [i] = node1.s [j] = 0;
dfs3(i+1,lim,node1,cnt+1);
}
// j is -1: j takes over i's label and i is cleared.
if(x.s [j]==-1)
{
node1 = x;
node1.s [j] = node1.s [i];
node1.s [i] = 0;
dfs3(i+1,lim,node1,cnt+1);
}
}
// Slot i is -1: join it to one slot j, or to two slots j < t, on the right.
if(x.s [i]==-1)
re(j,lim+1,min(x.len,i+k))
if(x.s [j])
{
// Single join i-j: if j is also -1 both get the fresh label,
// otherwise i takes over j's label and j is cleared.
node1 = x;
if(x.s [j]==-1)
node1.s [i] = node1.s [j] = cnt+1;
else
node1.s [i] = node1.s [j],
node1.s [j] = 0;
dfs3(i+1,lim,node1,cnt+1);
// Double join j-i-t: slot i is cleared and j, t are combined.
// t must be -1, or carry a positive label that differs from j's or equals 1.
re(t,j+1,min(x.len,i+k))
if(x.s[t]==-1 ||(x.s[t]>0&&(x.s[t]!=x.s[j] || x.s[t]==1)))
{
node1 = x;
node1.s [i] = 0;
// Both -1: j and t share a fresh label.
if( x.s[j]==-1 && x.s[t] ==-1 )
node1.s [j] = node1.s [t] = cnt+1 ;
// Exactly one of them is -1: it takes over the other's label, the other is cleared.
if( x.s[j]==-1 && x.s[t]!=-1)
node1.s [j] = node1.s [t] , node1.s [t] = 0;
if( x.s [j]!=-1&& x.s[t]==-1)
node1.s [t] = node1.s [j] , node1.s [j] = 0;
// Neither is -1: relabel every slot holding either label, then clear j and t.
if( x.s [j]!=-1&&x.s[t]!=-1)
{
int n_col = x.s [j]==1 || x.s[t] == 1 ? 1 : cnt+1;
re(_t,1,x.len)
if(x.s[_t]==x.s[j]|| x.s[_t] == x.s[t] )
node1.s [_t] = n_col ;
node1.s [j] = node1.s [t] = 0;
}
dfs3(i+1,lim,node1,cnt+1);
}
}
}
// Starts the enumeration of all states obtainable by placing state r to the
// right of state l. Pairs whose root counters sum to more than 2 are skipped.
// The two slot lists are concatenated into the global node1; labels greater
// than 1 coming from r are shifted by k so they cannot clash with labels of l.
// dfs3 then tries the joins across the seam, starting at most k slots before
// the end of l, with 2*k+1 as the label counter.
inline void Turn(node l,node r)
{
    const int roots = l.root + r.root;
    if (roots > 2)
        return;

    node1.clear();
    node1.root = roots;
    node1.len = l.len + r.len;

    for (int i = 1; i <= l.len; ++i)
        node1.s[i] = l.s[i];

    for (int i = 1; i <= r.len; ++i)
    {
        int label = r.s[i];
        if (label > 1)
            label += k;
        node1.s[l.len + i] = label;
    }

    int start = l.len - k + 1;
    if (start < 1)
        start = 1;
    dfs3(start, l.len, node1, 2 * k + 1);
}
// Maps state index x to the state index used one level higher in the tree.
//   root counter 0 : returns 0 (no image).
//   root counter 1 : returns x unchanged.
//   otherwise      : builds the state with root counter 0 in which every slot
//                    labelled 1 gets an ordinary label, canonicalises it and
//                    returns its index from mp (0 if it is not a known state).
inline int _Turn(int x)
{
    const int roots = a[x].root;
    if (roots == 1)
        return x;
    if (roots == 0)
        return 0;

    node1 = a[x];
    node1.root = 0;
    // k+2 is used as the replacement label and as the bound handed to sort.
    const int fresh = k + 2;
    for (int i = 1; i <= node1.len; ++i)
        if (node1.s[i] == 1)
            node1.s[i] = fresh;
    node1.sort(fresh);
    return mp[node1];
}
// _f[1.._tot_f]: pairs (state index, _Turn(state index)) with a non-zero image, filled in main.
// cnt_5 is not referenced in the visible code.
pr _f[maxn*40];int _tot_f,cnt_5;
// dp[v][t]: count (mod p) for rebuilt vertex v and state index t.
// _dp is a scratch copy of one dp row used while a row is being recomputed.
ll dp[maxn][maxn*3],_dp[maxn*3];
// Counts the ways to finish a root state by joining the first k slots to the
// last k slots (slot i may be joined to slots k+i .. 2*k) so that every slot
// ends up 0.
//   i   : slot of the first half currently being decided (1..k).
//   x   : working state, by value; only states with len == 2*k and root == 2 count.
//   col : label counter; col+1 is the fresh label at this level.
// The global node1 is scratch space: it is rebuilt from x before each use and
// copied by the recursive call.
inline ll count(int i,node x,int col)
{
    if (x.len != 2 * k || x.root != 2)
        return 0;
    if (i == k + 1)
    {
        // Every slot must have been consumed.
        for (int t = 1; t <= 2 * k; ++t)
            if (x.s[t])
                return 0;
        return 1;
    }

    ll ways = 0;

    // Slot i is already 0: nothing to join.
    if (!x.s[i])
        ways = count(i + 1, x, col);

    // Slot i carries a positive label: join it to exactly one slot j.
    if (x.s[i] > 0)
        for (int j = k + i; j <= 2 * k; ++j)
            if (x.s[j])
            {
                node1 = x;
                if (x.s[j] > 0 && (x.s[j] == 1 || x.s[i] != x.s[j]))
                {
                    // Merge the two labels (1 wins, otherwise a fresh label), clear i and j.
                    int n_col = x.s[i] == 1 || x.s[j] == 1 ? 1 : col + 1;
                    for (int t = 1; t <= 2 * k; ++t)
                        if (x.s[t] == x.s[i] || x.s[t] == x.s[j])
                            node1.s[t] = n_col;
                    node1.s[i] = node1.s[j] = 0;
                    ways += count(i + 1, node1, col + 1);
                }
                if (x.s[j] == -1)
                {
                    // j takes over i's label.
                    node1.s[j] = node1.s[i];
                    node1.s[i] = 0;
                    ways += count(i + 1, node1, col + 1);
                }
            }

    // Slot i is -1: it must be joined to two slots j < t of the second half.
    if (x.s[i] == -1)
    {
        for (int j = k + i; j <= 2 * k; ++j)
            if (x.s[j])
            {
                for (int t = j + 1; t <= 2 * k; ++t)
                {
                    node1 = x;
                    if (x.s[t] == -1 || (x.s[t] > 0 && (x.s[t] == 1 || x.s[t] != x.s[j])))
                    {
                        node1.s[i] = 0;
                        if (x.s[j] == -1 && x.s[t] == -1)
                            node1.s[j] = node1.s[t] = col + 1;
                        if (x.s[j] == -1 && x.s[t] != -1)
                            node1.s[j] = node1.s[t], node1.s[t] = 0;
                        if (x.s[j] != -1 && x.s[t] == -1)
                            node1.s[t] = node1.s[j], node1.s[j] = 0;
                        if (x.s[j] != -1 && x.s[t] != -1)
                        {
                            // The fresh label must come from the label counter col.
                            // The previous code used the running result count here,
                            // which could produce 1 or an already used label.
                            int n_col = x.s[j] == 1 || x.s[t] == 1 ? 1 : col + 1;
                            for (int _t = 1; _t <= x.len; ++_t)
                                if (x.s[_t] == x.s[j] || x.s[_t] == x.s[t])
                                    node1.s[_t] = n_col;
                            node1.s[j] = node1.s[t] = 0;
                        }
                        ways += count(i + 1, node1, col + 1);
                    }
                }
            }
    }
    return ways;
}
void dfs4(int x)
{
int bo = 0,cnt_leaf=0;
re(y,1,n)
if(__a[x][y])
{
dfs4(y);
if(!bo)
{
memcpy(dp[x],dp[y],sizeof(dp[x]));
bo = 1;
cnt_leaf = leaf[y];
}
else
{
memcpy(_dp,dp[x],sizeof(_dp));
memset(dp[x],0,sizeof(dp[x]));
re(t,1,tot_f[ cnt_leaf ][ leaf[y] ])
{
turn _ = f[ cnt_leaf ][ leaf[y] ][t] ;
( dp[x][ _.to ] += _dp[ _.x ] * dp[y][ _.y ])%=p;
}
cnt_leaf = min( cnt_leaf + leaf[y] , 2*k);
}
}
if(!bo)
{
dp[x][1] = dp[x][2] = 1;
return ;
}
if(x!=1)
{
memcpy(_dp,dp[x],sizeof(_dp));
memset(dp[x],0,sizeof(dp[x]));
re(t,1,_tot_f)
( dp[x][ _f[t].second ] += _dp[ _f[t].first ] )%=p ;
}
else
re(t,1,tot)
{
( ans += count(1,a[t],k+1) * dp[x][t] )%=p;
}
}
int main()
{
//freopen("0.in","r",stdin);
_n = read(),k=read();
re(i,2,_n)
_a[read()][i] = 1;
n = 1;
build();
re(i,1,n)
leaf[i] = min( leaf[i],2*k );
if(n<=21)
{
solve1();
return 0;
}
re(i,1,2*k)
{
node1.len = i;
re(s,0,(1<<i)-1)
{
re(w,1,i)
node1.s [w] = - ((s>>w-1)&1);
node1.root = 0;
dfs2(1,1);
re(j,1,i)
if(node1.s [j]!=-1)
{
node1.root = 1;
node1.s [j] = 1;
dfs2(1,1);
node1.root = 2;
re(k,j+1,i)
if(node1.s [k]!=-1)
{
node1.s[k] = 1;
dfs2(1,1);
node1.s[k] = 0;
}
node1.s[j]=0;
}
}
}
node1.clear();
node1.root = 2;
node1.len = k*2;
a[++tot] = node1;
mp[node1] = tot;
re(i,1,tot)
re(j,1,tot)
{
st1[0] = 0;
Turn(a[i],a[j]);
if(st1[0])
re(t,1,st1[0])
f[ a[i].len ][ a[j].len ][ ++tot_f[a[i].len ][a[j].len ]] = turn(i,j,st1[t]);
}
re(i,1,tot)
{
int to = _Turn(i);
if(to)
_f[++_tot_f] = make_pair( i,to );
}
dfs4(1);
printf("%lld\n",ans);
return 0;
}