题解 CF1254E 【Send Tree to Charlie】

· · 题解

一道跟CSP的某道题很像的题。

首先可以发现,最后构成的树的形态一样,当且仅当对于每个点而言它连出去的边的操作的相对顺序固定。

所以可以考虑把每个点连出去的边的相对顺序的方案数求出来之后,每个点的方案乘起来就是答案了。

考虑一条路径对这些边的相对顺序有什么影响:

于是可以发现,对于每个点连出去的边,一定构成了若干条链(如果不构成链显然不合法),这些链除了钦定的起点和终点都是可以任意排列的,于是这个点的方案就是可以任意移动的链的数量的阶乘。

那么对于每个点都计算出来方案乘起来即可。

实现有点恶心。

复杂度O(n\log n)

#include<bits/stdc++.h>
#define maxn 500010
#define mod 1000000007
using namespace std;
typedef long long ll;
typedef pair<int,int>pii;
int read()
{
    int x=0,f=1;
    char ch=getchar();
    while(ch-'0'<0||ch-'0'>9){if(ch=='-') f=-1;ch=getchar();}
    while(ch-'0'>=0&&ch-'0'<=9){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int n;
int deg[maxn];
int head[maxn],nxt[maxn*2],to[maxn*2],tot;
void add(int u,int v)
{
    tot++;
    nxt[tot]=head[u];
    head[u]=tot;
    to[tot]=v;
}
int dep[maxn],up[maxn][22],fa_id[maxn],fa[maxn];
void dfs(int x,int las)
{
    for(int i=head[x];i;i=nxt[i])
    {
        if(to[i]==las)  continue;
        fa[to[i]]=x;
        dep[to[i]]=dep[x]+1;
        up[to[i]][0]=x;
        fa_id[to[i]]=(i+1)/2;
        dfs(to[i],x);
    }
}
int LCA(int x,int y)
{
    if(dep[x]<dep[y])  swap(x,y);
    for(int i=20;i>=0;i--)
      if(dep[y]+(1<<i)<=dep[x])  x=up[x][i];
    if(x==y)  return x;
    for(int i=20;i>=0;i--)
      if(up[x][i]!=up[y][i])  x=up[x][i],y=up[y][i];
    return up[x][0];
}
int get_dis(int x,int y)
{
    return dep[x]+dep[y]-2*dep[LCA(x,y)];
}
int a[maxn];
int en[maxn],st[maxn];
vector<pii>E[maxn];
vector<int>V,e[maxn];
int get_path(int x,int lca,int ty)
{
    if(x==lca)  return 0;
    while(fa[x]!=lca)
    {
        int y=fa[x];
        if(!ty)  E[y].push_back(mp(fa_id[x],fa_id[fa[x]]));
        else     E[y].push_back(mp(fa_id[fa[x]],fa_id[x]));
        x=y;
    }
    return x;
}
int deg_in[maxn],deg_out[maxn],col[maxn];
bool check()
{
    for(int i=0;i<V.size();i++)
      if(deg[V[i]]>2)  return true;
    return false;
}
queue<int>q;
bool top_sort(int id)
{
    for(int i=0;i<V.size();i++)  deg[V[i]]=0;
    for(int i=0;i<E[id].size();i++)  deg[E[id][i].se]++;
    for(int i=0;i<V.size();i++)
      if(!deg[V[i]])  q.push(V[i]),col[V[i]]=V[i];
    int cnt=0;
    while(q.size())
    {
        int now=q.front();q.pop();cnt++;
        for(int i=0;i<e[now].size();i++)
        {
            deg[e[now][i]]--;
            if(!deg[e[now][i]])  col[e[now][i]]=col[now],q.push(e[now][i]);
        }
    }
    return (cnt!=V.size());
}
int fac[maxn];
int solve(int id)
{
    V.clear();
    for(int i=head[id];i;i=nxt[i])  V.push_back((i+1)>>1);
    for(int i=0;i<V.size();i++)  deg_in[V[i]]=deg_out[V[i]]=deg[V[i]]=0;
    for(int i=0;i<V.size();i++)  e[V[i]].clear();
    for(int i=0;i<E[id].size();i++)
    {
        deg_out[E[id][i].fi]++,deg_in[E[id][i].se]++;
        deg[E[id][i].fi]++,deg[E[id][i].se]++;
        e[E[id][i].fi].push_back(E[id][i].se);
    }
    if(check())  puts("0"),exit(0);
    if(top_sort(id))  puts("0"),exit(0);
    if(st[id]&&deg_in[st[id]])  puts("0"),exit(0);
    if(en[id]&&deg_out[en[id]])  puts("0"),exit(0);
    int cnt=0;
    for(int i=0;i<V.size();i++)
      if(!deg_in[V[i]])  cnt++;
    if(st[id]&&en[id]&&col[st[id]]==col[en[id]])
    {
        if(cnt==1)  return 1;
        else        puts("0"),exit(0);
    }
    return fac[cnt-(st[id]!=0)-(en[id]!=0)];
}
int ans=1;
int main()
{
    n=read();
    fac[0]=1;
    for(int i=1;i<=n;i++)  fac[i]=1ll*fac[i-1]*i%mod;
    for(int i=1;i<n;i++)  
    {
        int q=read(),w=read();
        add(q,w);add(w,q);
    }
    dfs(1,0);
    for(int j=1;j<=20;j++)
      for(int i=1;i<=n;i++)  up[i][j]=up[up[i][j-1]][j-1];
    for(int i=1;i<=n;i++)  a[i]=read();
    ll sum=0;
    for(int i=1;i<=n;i++)
      if(a[i])  sum+=get_dis(i,a[i]);
    if(sum>2*n-2)  return puts("0"),0;
    for(int i=1;i<=n;i++)
    {
        if(a[i])
        {
            if(i==a[i])   return puts("0"),0;
            int lca=LCA(i,a[i]);
            int u=get_path(a[i],lca,0);
            int v=get_path(i,lca,1);
            if(lca!=i&&lca!=a[i])  E[lca].push_back(mp(fa_id[u],fa_id[v]));
            if(st[a[i]]||en[i])  return puts("0"),0;
            if(lca==a[i])  st[a[i]]=fa_id[v];
            else           st[a[i]]=fa_id[a[i]];
            if(lca==i)     en[i]=fa_id[u];
            else           en[i]=fa_id[i];
        }
    }
    for(int i=1;i<=n;i++)  ans=1ll*ans*solve(i)%mod;
    cout<<ans<<endl;
    return 0;
}