题解 CF1254E 【Send Tree to Charlie】
justin_cao · · 题解
一道跟CSP的某道题很像的题。
首先可以发现,最后构成的树的形态一样,当且仅当对于每个点而言它连出去的边的操作的相对顺序固定。
所以可以考虑把每个点连出去的边的相对顺序的方案数求出来之后,每个点的方案乘起来就是答案了。
考虑一条路径对这些边的相对顺序有什么影响:
- 起点连出去的这条边一定是这个点连出去的边中第一个操作的。
- 连向终点的这条边一定是这个点相连的边中最后一个被操作的。
- 对于中间的点,连向它的边操作之后一定马上接着它连出去的边。
于是可以发现,对于每个点连出去的边,一定构成了若干条链(如果不构成链显然不合法),这些链除了钦定的起点和终点都是可以任意排列的,于是这个点的方案就是可以任意移动的链的数量的阶乘。
那么对于每个点都计算出来方案乘起来即可。
实现有点恶心。
复杂度
#include<bits/stdc++.h>
#define maxn 500010
#define mod 1000000007
using namespace std;
typedef long long ll;
typedef pair<int,int>pii;
int read()
{
int x=0,f=1;
char ch=getchar();
while(ch-'0'<0||ch-'0'>9){if(ch=='-') f=-1;ch=getchar();}
while(ch-'0'>=0&&ch-'0'<=9){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n;
int deg[maxn];
int head[maxn],nxt[maxn*2],to[maxn*2],tot;
void add(int u,int v)
{
tot++;
nxt[tot]=head[u];
head[u]=tot;
to[tot]=v;
}
int dep[maxn],up[maxn][22],fa_id[maxn],fa[maxn];
void dfs(int x,int las)
{
for(int i=head[x];i;i=nxt[i])
{
if(to[i]==las) continue;
fa[to[i]]=x;
dep[to[i]]=dep[x]+1;
up[to[i]][0]=x;
fa_id[to[i]]=(i+1)/2;
dfs(to[i],x);
}
}
int LCA(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;i--)
if(dep[y]+(1<<i)<=dep[x]) x=up[x][i];
if(x==y) return x;
for(int i=20;i>=0;i--)
if(up[x][i]!=up[y][i]) x=up[x][i],y=up[y][i];
return up[x][0];
}
int get_dis(int x,int y)
{
return dep[x]+dep[y]-2*dep[LCA(x,y)];
}
int a[maxn];
int en[maxn],st[maxn];
vector<pii>E[maxn];
vector<int>V,e[maxn];
int get_path(int x,int lca,int ty)
{
if(x==lca) return 0;
while(fa[x]!=lca)
{
int y=fa[x];
if(!ty) E[y].push_back(mp(fa_id[x],fa_id[fa[x]]));
else E[y].push_back(mp(fa_id[fa[x]],fa_id[x]));
x=y;
}
return x;
}
int deg_in[maxn],deg_out[maxn],col[maxn];
bool check()
{
for(int i=0;i<V.size();i++)
if(deg[V[i]]>2) return true;
return false;
}
queue<int>q;
bool top_sort(int id)
{
for(int i=0;i<V.size();i++) deg[V[i]]=0;
for(int i=0;i<E[id].size();i++) deg[E[id][i].se]++;
for(int i=0;i<V.size();i++)
if(!deg[V[i]]) q.push(V[i]),col[V[i]]=V[i];
int cnt=0;
while(q.size())
{
int now=q.front();q.pop();cnt++;
for(int i=0;i<e[now].size();i++)
{
deg[e[now][i]]--;
if(!deg[e[now][i]]) col[e[now][i]]=col[now],q.push(e[now][i]);
}
}
return (cnt!=V.size());
}
int fac[maxn];
int solve(int id)
{
V.clear();
for(int i=head[id];i;i=nxt[i]) V.push_back((i+1)>>1);
for(int i=0;i<V.size();i++) deg_in[V[i]]=deg_out[V[i]]=deg[V[i]]=0;
for(int i=0;i<V.size();i++) e[V[i]].clear();
for(int i=0;i<E[id].size();i++)
{
deg_out[E[id][i].fi]++,deg_in[E[id][i].se]++;
deg[E[id][i].fi]++,deg[E[id][i].se]++;
e[E[id][i].fi].push_back(E[id][i].se);
}
if(check()) puts("0"),exit(0);
if(top_sort(id)) puts("0"),exit(0);
if(st[id]&°_in[st[id]]) puts("0"),exit(0);
if(en[id]&°_out[en[id]]) puts("0"),exit(0);
int cnt=0;
for(int i=0;i<V.size();i++)
if(!deg_in[V[i]]) cnt++;
if(st[id]&&en[id]&&col[st[id]]==col[en[id]])
{
if(cnt==1) return 1;
else puts("0"),exit(0);
}
return fac[cnt-(st[id]!=0)-(en[id]!=0)];
}
int ans=1;
int main()
{
n=read();
fac[0]=1;
for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
for(int i=1;i<n;i++)
{
int q=read(),w=read();
add(q,w);add(w,q);
}
dfs(1,0);
for(int j=1;j<=20;j++)
for(int i=1;i<=n;i++) up[i][j]=up[up[i][j-1]][j-1];
for(int i=1;i<=n;i++) a[i]=read();
ll sum=0;
for(int i=1;i<=n;i++)
if(a[i]) sum+=get_dis(i,a[i]);
if(sum>2*n-2) return puts("0"),0;
for(int i=1;i<=n;i++)
{
if(a[i])
{
if(i==a[i]) return puts("0"),0;
int lca=LCA(i,a[i]);
int u=get_path(a[i],lca,0);
int v=get_path(i,lca,1);
if(lca!=i&&lca!=a[i]) E[lca].push_back(mp(fa_id[u],fa_id[v]));
if(st[a[i]]||en[i]) return puts("0"),0;
if(lca==a[i]) st[a[i]]=fa_id[v];
else st[a[i]]=fa_id[a[i]];
if(lca==i) en[i]=fa_id[u];
else en[i]=fa_id[i];
}
}
for(int i=1;i<=n;i++) ans=1ll*ans*solve(i)%mod;
cout<<ans<<endl;
return 0;
}