P13495

· · 题解

题意描述

有一个 n\times m01 矩阵 a,其中有 k 个位置已经填好了数,问有多少种方案把矩阵剩下的位置填满,使得每行每列都有奇数个 1,对 998244353 取模。

解法

这里介绍一个时间复杂度严格线性的做法。

考虑将 n 行的限制看做二分图的 n 个左部点,将 m 列的限制看做二分图的 m 个右部点,然后每个位置 (i,j) 看做第 i 个左部点和第 j 个右部点之间连了一条边权为 a_{i,j} 的无向边。要求变为每个点周围都连有奇数条边权为 1 的边。

考虑把已经固定好边权的 k 条边的边权先贡献给两边的点,然后直接把这些边删掉。

考虑剩下的图中的一个连通块,如果这个连通块内左部点要求连有奇数条边权为 1 的边的点数和右部点要求连有奇数条边权为 1 的边的点数的奇偶性不同,那么显然无解,方案数为 0

否则,考虑任选一个生成树,不在生成树上的边的边权可以任意填。然后考虑生成树的一个叶子节点,不难发现这个叶子节点向外连的边的边权已经可以由这个叶子节点的限制唯一确定,可以把这个叶子结点和这个叶子节点向外连的边一起删除。最后剩下两个点的时候两个点的的限制肯定是相同的,所以这种情况的方案数就是 2^{|E|-|V|+1}。其中 |E| 表示边数,|V| 表示点数。

最终的方案数就是所有连通块的方案数的乘积,如果不是 0,那就是 2^{\sum|E|-\sum|V|+cnt}=2^{nm-k-n-m+cnt}cnt 表示连通块的个数)。

接下来考虑如何求出每个点属于哪个连通块。

考虑使用若干个双向链表维护所有的右部点,每个双向链表维护一个连通块内的所有右部点,所有双向链表之间也直接用双向链表串起来。

接下来按顺序加入左部点,每加入一个左部点就按顺序扫描所有链表,一旦在某个链表内找到一个和当前左部点有边相连的右部点就将这个链表标记为与当前左部点连通,然后立刻跳到下一个链表,最后把所有标记过的链表合并为一个链表。不难证明这样做的复杂度是 O(n+m+k)

这样做求出所有右部点连通块,如法炮制也可以求出所有左部点连通块。最后只需枚举一个右部点连通块,这个右部点连通块要么是个孤立点,不连向任何左部点,要么刚好连向一个左部点连通块。最后,找到所有左部点连通块中的孤立点就可以还原出所有连通块。

这样做的复杂度是 O(n+m+k),规避了并查集以及 \log 数据结构。但这样就是严格线性了吗?

仔细观察我们的复杂度 O(n+m+k),不难发现这个复杂度甚至不是多项式级的,又怎么会是严格线性的呢?

考虑当 nm 至少有一个超过 k 的情况下如何保持严格线性的复杂度。

如果 nm 均超过 k,那么不难发现一定只有 1 个连通块,当 nm 奇偶性相同时有解,奇偶性不同时无解。

如果 nm 其中之一超过 k,不妨假设 m 超过 k。此时除了左部点可能存在的孤立点以外一定只有一个连通块,所以只需找出左部点存在的所有孤立点即可,这个容易使用散列表做到线性。

最后,算上快速幂,我们的时间复杂度是 O(\log n+\log m+k),是严格线性的。

参考代码

以下是一份时间复杂度严格线性的参考代码。

#include<ctime>
#include<cstdio>
#include<random>
#include<cstdlib>
using namespace std;
const long long mod=998244353,p=1999993;
long long n,m,k,x[1100000],y[1100000],z[1100000],val[500000],pre1[300000],nxt1[300000],ed1[300000],to1[300000],pre2[300000],nxt2[300000],ed2[300000],h1[2000000],h2[2000000],ht[2000000],hn[2000000],hv[2000000],tt[500000],wh,ans,i,j,st,t1,t2,t3;
long long power(long long a,long long b)
{
    long long k=a,ans=1;
    while(b)
    {
        if(b&1)
        ans=ans*k%mod;
        k=k*k%mod;
        b=b>>1;
    }
    return ans;
}
void in(long long x,long long y)
{
    long long k=(x*12345678+y*12)%p;
    while(h1[k]!=0)
    k=(k+1)%p;
    h1[k]=x;
    h2[k]=y;
}
bool find(long long x,long long y)
{
    long long k=(x*12345678+y*12)%p;
    while(h1[k]!=0)
    {
        if(h1[k]==x&&h2[k]==y)
        return true;
        k=(k+1)%p;
    }
    return false;
}
long long in2(long long x,long long v)
{
    long long k=(x*114514)%p;
    while(ht[k]!=0)
    {
        if(ht[k]==x)
        {
            hv[k]=v^hv[k];
            hn[k]++;
            if(hn[k]==n||hn[k]==m)
            if(hv[k]==0)
            return 1;
            else
            return -1;
            return 0;
        }
        k=(k+1)%p;
    }
    ht[k]=x;
    hv[k]=v^1;
    hn[k]=1;
    if(hn[k]==n||hn[k]==m)
    if(hv[k]==0)
    return 1;
    else
    return -1;
    return 0;
}
main()
{
    scanf("%lld%lld%lld",&n,&m,&k);
    for(i=1;i<=k;i++)
    {
        scanf("%lld%lld%lld",&x[i],&y[i],&z[i]);
        in(x[i],y[i]);
    }
    if(n%2!=m%2)
    {
        printf("0");
        return 0;
    }
    if(k<n&&k<m)
    {
        ans=power(2,n*m-n-m-k+1);
        printf("%lld",ans);
        return 0;
    }
    if(k<n)
    {
        ans=n*m-n-m-k+1;
        for(i=1;i<=k;i++)
        {
            t1=in2(x[i],z[i]);
            if(t1==-1)
            {
                printf("0");
                return 0;
            }
            else if(t1==1)
            ans++;
        }
        ans=power(2,ans);
        printf("%lld",ans);
        return 0;
    }
    if(k<m)
    {
        ans=n*m-n-m-k+1;
        for(i=1;i<=k;i++)
        {
            t1=in2(y[i],z[i]);
            if(t1==-1)
            {
                printf("0");
                return 0;
            }
            else if(t1==1)
            ans++;
        }
        ans=power(2,ans);
        printf("%lld",ans);
        return 0;
    }
    for(i=1;i<=n+m;i++)
    val[i]=1;
    for(i=1;i<=k;i++)
    if(z[i]==1)
    {
        val[x[i]]=!val[x[i]];
        val[n+y[i]]=!val[n+y[i]];
    }
    for(i=0;i<m;i++)
    nxt1[i]=i+1;
    for(i=1;i<=m;i++)
    pre1[i]=i-1;
    pre1[0]=m;
    for(i=1;i<=m;i++)
    ed1[i]=i;
    for(i=1;i<=n;i++)
    {
        wh=0;
        st=1;
        j=1;
        while(j!=0)
        {
            if(find(i,j))
            {
                if(ed1[j]==j)
                st=nxt1[j];
                j=nxt1[j];
                continue;
            }
            if(wh==0)
            wh=st;
            else
            {
                if(nxt1[ed1[wh]]!=st)
                {
                    t1=nxt1[ed1[wh]];
                    t2=pre1[st];
                    t3=nxt1[ed1[st]];
                    pre1[st]=ed1[wh];
                    nxt1[ed1[wh]]=st;
                    pre1[t1]=ed1[st];
                    nxt1[ed1[st]]=t1;
                    nxt1[t2]=t3;
                    pre1[t3]=t2;
                }
                ed1[ed1[wh]]=0;
                ed1[wh]=ed1[st];
            }
            j=nxt1[ed1[st]];
            st=j;
        }
    }
    for(i=0;i<n;i++)
    nxt2[i]=i+1;
    for(i=1;i<=n;i++)
    pre2[i]=i-1;
    pre2[0]=n;
    for(i=1;i<=n;i++)
    ed2[i]=i;
    for(i=1;i<=m;i++)
    {
        wh=0;
        st=1;
        j=1;
        while(j!=0)
        {
            if(find(j,i))
            {
                if(ed2[j]==j)
                st=nxt2[j];
                j=nxt2[j];
                continue;
            }
            if(wh==0)
            wh=st;
            else
            {
                if(nxt2[ed2[wh]]!=st)
                {
                    t1=nxt2[ed2[wh]];
                    t2=pre2[st];
                    t3=nxt2[ed2[st]];
                    pre2[st]=ed2[wh];
                    nxt2[ed2[wh]]=st;
                    pre2[t1]=ed2[st];
                    nxt2[ed2[st]]=t1;
                    nxt2[t2]=t3;
                    pre2[t3]=t2;
                }
                ed2[ed2[wh]]=0;
                ed2[wh]=ed2[st];
            }
            j=nxt2[ed2[st]];
            st=j;
        }
    }
    st=1;
    while(st!=0)
    {
        t1=1;
        j=1;
        while(j!=0)
        {
            if(find(st,j))
            {
                if(ed1[j]==j)
                t1=nxt1[j];
                j=nxt1[j];
                continue;
            }
            break;
        }
        if(j!=0)
        to1[t1]=st;
        st=nxt2[ed2[st]];
    }
    st=1;
    i=1;
    while(i!=0)
    {
        if(val[i+n]==1)
        if(to1[st]==0)
        tt[st+n]=!tt[st+n];
        else
        tt[to1[st]]=!tt[to1[st]];
        if(ed1[i]==i)
        st=nxt1[i];
        i=nxt1[i];
    }
    st=1;
    i=1;
    while(i!=0)
    {
        if(val[i]==1)
        tt[st]=!tt[st];
        if(ed2[i]==i)
        st=nxt2[i];
        i=nxt2[i];
    }
    for(i=1;i<=n+m;i++)
    if(tt[i]!=0)
    {
        printf("0");
        return 0;
    }
    st=1;
    i=1;
    while(i!=0)
    {
        if(to1[st]==0)
        tt[st+n]=1;
        else
        tt[to1[st]]=1;
        if(ed1[i]==i)
        st=nxt1[i];
        i=nxt1[i];
    }
    st=1;
    i=1;
    while(i!=0)
    {
        tt[st]=1;
        if(ed2[i]==i)
        st=nxt2[i];
        i=nxt2[i];
    }
    ans=n*m-k-n-m;
    for(i=1;i<=n+m;i++)
    if(tt[i]==1)
    ans++;
    ans=power(2,ans);
    printf("%lld",ans);
}

这份代码实际上跑得并不是很快(参考提交记录),主要原因是常数太大,以及这题并没有把 nm 的上限开到 10^9

另外,如果不追求严格线性,代码可以写得简便很多,而且速度和上面这份代码差不多。(参考另一个提交记录,这份代码的时间复杂度是 O(k+(n+m)\log(n+m)) 的,因为并查集只加了路径压缩)