题解:P7531 [USACO21OPEN] Routing Schemes P

· · 题解

思路

很容易想到欧拉回路。

我们考虑增加一个超级源点 C,将 C 与所有起点之间连一条有向边,再增加一个超级汇点 D,将所有终点与 D 之间连一条有向边,DC 之间再连边,就变成了一个有向图欧拉回路计数问题。

于是考虑 BEST 定理。

BEST 定理

G 为有向欧拉图,k 为任意顶点,则 G 的不同欧拉回路个数

\operatorname{euc}(G)=t^{\mathrm{in}}(G,k)\prod\limits_{v\in V}(\deg(v)-1)!

BEST 定理表明,对于 \forall k,k'\in V_G,都有 t^{\mathrm{in}}(G,k)=t^{\mathrm{in}}(G,k')

证明

考虑图 G 的任意一棵内向树,对于每个节点 u,我们给以 u 为起点的所有不在内向树上的 \deg(u)-1出边一个顺序,称这个根向树及这个出边的排列顺序为一个组合。

于是我们只需要证明组合和欧拉回路一一对应。

考虑从根节点开始,每到达一个节点,若不在内向树上的出边都被走过了,就沿着根向树上的边走向其父亲,否则就按照出边的排列顺序走向下一个节点。

注意到这样只会经过每个节点至多一次,现在证明这样会经过且仅经过每个节点一次

不妨设到达节点 u 后无法移动,考虑分类讨论。

u 不是根节点,我们经过 u 时会经过其一条入边和一条出边,而无法移动说明只经过了 u 的一条入边,说明 \deg^{\mathrm{in}}(u)=\deg^{\mathrm{out}}(u)+1,与 G 为欧拉图矛盾。

这样我们就证明了这种方案一定会形成一个欧拉回路。

现在我们证明了一个组合对应一个欧拉回路,接下来考虑证明一个欧拉回路对应一个组合。

e_uu 最后访问的入边,下面证明所有 e_u 构成一棵内向树。

不妨设 e_u 构成的图中有环,首先根节点必然不会出现在环上。现在环上找出任意一个节点 u,容易发现 u 沿着环的方向可以再次回到 u。由于原图是欧拉图,\deg^{\mathrm{in}}(u)=\deg^{\mathrm{out}}(u),而 u 在环上回到 u 会导致 \deg^{\mathrm{in}}(u)=\deg^{\mathrm{out}}(u)+1,矛盾,故所有 e_u 构成树。

于是一个组合和一个欧拉回路一一对应

由上文我们知道

\operatorname{euc}(G)=t^{\mathrm{in}}(G,k)\prod\limits_{v\in V}(\deg(v)-1)!

在本题中即为

\operatorname{euc}(G)=t^{\mathrm{in}}(G,S_b)\deg(S_b)(\deg (S_b)-1)!(\deg(S_e)-1)!\prod\limits_{i=1}^{n}(\deg(i)-1)!

其中 S_b 是超级源点,S_e 是超级汇点。

但是仅仅套用 BEST 定理还不够。

而本题中由于我们建立了超级源点和超级汇点,与这两个点相连的边的决策带来的额外方案数不能被算入答案中,由排列的基本知识可以得到这样会令答案额外乘

\deg(S_b)!\deg(S_e)!

注意 e=(S_b,S_e) 不计入两点度数。

于是有

\begin{aligned}ans(G)&=\frac{t^{\mathrm{in}}(G,S_b)\deg(S_b)(\deg (S_b)-1)!(\deg(S_e)-1)!\prod\limits_{i=1}^{n}(\deg(i)-1)!}{\deg(S_b)!\deg(S_e)!}\\&=\frac{t^{\mathrm{in}}(G,S_b)\prod\limits_{i=1}^{n}(\deg(i)-1)!}{\deg(S_e)}\end{aligned}

显然,\deg(S_e) 即为给定的字符串中 S 的个数。

注意到题目描述中的一句话:每条边使用恰好一次,所以要删除孤立点后再计算。

代码

#include <bits/stdc++.h>
//#include <bits/extc++.h>
#define int long long
#define mod (int)(1e9+7)
#define __MULTITEST__
//#undef __MULTITEST__
using namespace std;
//using namespace __gnu_cxx;
//using namespace __gnu_pbds;
int a[105][105];
string s[105];
int ind[105],oud[105],iid[105];
int cntiid;
int quick_pow(int a,int b,int p)
{
    int ret=1;
    while(b)
    {
        if(b&1)
            ret=ret*a%p;
        a=a*a%p;
        b>>=1;
    }
    return ret;
}
void updeg(int u,int v)
{
    ind[v]++;
    oud[u]++;
}
void upcon(int u,int v)
{
    u=iid[u];
    v=iid[v];
    a[u][v]=(a[u][v]+mod-1)%mod;
    a[v][v]=(a[v][v]+1)%mod;
}
int det(int n)
{
    int ans=1,f=1;
    for(int i=1;i<=n;i++)
        for(int j=i+1;j<=n;j++)
        {
            while(a[i][i])
            {
                int tmp=a[j][i]/a[i][i];
                for(int k=i;k<=n;k++)
                    a[j][k]=(a[j][k]-a[i][k]*tmp%mod+mod)%mod;
                for(int k=1;k<=n;k++)
                    swap(a[i][k],a[j][k]);
                f=-f;
            }
            for(int k=1;k<=n;k++)
                swap(a[i][k],a[j][k]);
            f=-f; 
        }
    for(int i=1;i<=n;i++)
        if(a[i][i])
            ans=ans*a[i][i]%mod;
    return (ans*f%mod+mod)%mod; 
}
int n,k,ans;
int fac[305];
signed main()
{
    fac[0]=1;
    for(int i=1;i<=300;i++)
        fac[i]=fac[i-1]*i%mod; 
    #ifdef __MULTITEST__
    signed T;
    scanf("%d",&T);
    while(T--)
    {
    #endif
        memset(ind,0,sizeof ind);
        memset(oud,0,sizeof oud);
        memset(iid,0,sizeof iid);
        memset(a,0,sizeof a);
        cntiid=0;
        scanf("%lld%lld",&n,&k);
        string st;
        cin>>st;
        st='#'+st;
        for(int i=1;i<=n;i++)
        {
            if(st[i]=='S')
            {
                updeg(n+1,i);
                updeg(n+2,n+1);
            }
            if(st[i]=='R')
                updeg(i,n+2);
        }
        for(int i=1;i<=n;i++)
        {
            cin>>s[i];
            s[i]='#'+s[i];
            for(int j=1;j<=n;j++)
                if(s[i][j]=='1')
                    updeg(i,j);
        }
        for(int i=1;i<=n+2;i++)
            if(ind[i]!=oud[i])
            {
                printf("0\n");
                goto label;
            }
        for(int i=1;i<=n+2;i++)
            if(i!=n+1&&ind[i])
                iid[i]=++cntiid;
        for(int i=1;i<=n;i++)
        {
            if(st[i]=='S')
            {
                upcon(n+1,i);
                upcon(n+2,n+1);
            }
            if(st[i]=='R')
                upcon(i,n+2);
        }
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
                if(s[i][j]=='1')
                    upcon(i,j);
        ans=det(cntiid);
        for(int i=1;i<=n;i++)
            if(oud[i])
                ans=(ans*fac[oud[i]-1]%mod+mod)%mod;
        ans=ans*quick_pow(oud[n+1],mod-2,mod)%mod;
        printf("%lld\n",ans);
        label:
            ;
    #ifdef __MULTITEST__
    }
    #endif
    return 0;
}