题解 P5406 【[THUPC2019]找树】

· · 题解

首先,这不是一道最优化题,这是一道计数题,如果对每个 i,我们求出权值为 i 的生成树的数量,那么答案就是最大的生成树数量不为 0i

那么问题就转化成了:求权值为 i 的生成树数量。

首先矩阵树定理是少不了的,我们知道,矩阵树定理实际上求的是

\sum\limits_{T}\prod_{e\in T}w_{Te}

注意到这个 \prod,它不一定非得是数相乘,只要保证 <W,+,\times> 构成一个环,那么这个式子就成立。其中 W 是边权所在的全集。

所以我们可以对于每一条边构造一个集合幂级数 f_v=x^{v},乘法定义为输入的集合卷积(如果这一位为或,这一位就是集合或卷积,其它同理),这样求出的行列式的 x^v 项的系数就是权值为 v 的生成树个数。(集合幂级数形成了一个交换环)

具体实现我们注意到 FWT 是线性变换,所以我们一开始先把矩阵做好 FWT,求完行列式再 IFWT 回去,而且直接求生成树的数量太大,我们可以对大质数取模,复杂度是 O(n^32^w),信仰跑吧。关于集合卷积如何按位做不明白的可以看代码。

代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll const p=1e9+7,inv=500000004;
int n,m,w;
ll mat[75][75][4097],r[75][75],c[4097];
string str;
ll mod(ll x){return x>=p?x-p:x;}
ll pw(ll x,ll y)
{
    ll res=1;
    while(y)
    {
        if(y&1)res*=x,res%=p;
        x*=x,x%=p;
        y>>=1;
    }
    return res;
}
void fwt(ll *f,int n,int op)
{
    for(int len=2,b=0;len<=n;len<<=1,b++)
    {
        int q=(len>>1);
        if(str[b]=='|')
        {
            for(int i=0;i<n;i+=len)
                for(int j=i;j<i+q;j++)
                    if(op==1)f[j+q]=mod(f[j+q]+f[j]);
                    else f[j+q]=mod(f[j+q]+p-f[j]);
        }
        else if(str[b]=='&')
        {
            for(int i=0;i<n;i+=len)
                for(int j=i;j<i+q;j++)
                    if(op==1)f[j]=mod(f[j]+f[j+q]);
                    else f[j]=mod(f[j]+p-f[j+q]);
        }
        else
        {
            for(int i=0;i<n;i+=len)
                for(int j=i;j<i+q;j++)
                {
                    f[j]=mod(f[j]+f[j+q]);
                    f[j+q]=(f[j]-2*f[j+q]+2*p)%p;
                    if(op==-1)f[j]*=inv,f[j]%=p,f[j+q]*=inv,f[j+q]%=p;
                }
        }
    }
}
ll det()
{
    ll ans=1;
    for(int i=1;i<n;i++)
    {
        for(int j=i+1;j<n;j++)
            if((!r[i][i])&&r[j][i]){ans=p-ans,swap(r[i],r[j]);break;}
        if(!r[i][i])return 0;
        ll t=pw(r[i][i],p-2);
        for(int j=i+1;j<n;j++)
        {
            ll t2=t*r[j][i]%p;
            for(int k=i;k<n;k++)
                r[j][k]=mod(r[j][k]-t2*r[i][k]%p+p);
        }
    }
    for(int i=1;i<n;i++)ans*=r[i][i],ans%=p;
    return ans;
}
int main()
{
    int x,y,v;
    scanf("%d%d",&n,&m);
    cin>>str;
    w=str.size();
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d",&x,&y,&v);
        mat[x][y][v]--,mat[y][x][v]--;
        mat[x][x][v]++,mat[y][y][v]++;
    }
    for(int i=1;i<n;i++)
        for(int j=1;j<n;j++)
        {
            for(int k=0;k<(1<<w);k++)
                if(mat[i][j][k]<0)mat[i][j][k]+=p;
        }
    for(int i=1;i<n;i++)
        for(int j=1;j<n;j++)
            fwt(mat[i][j],(1<<w),1);
    for(int i=0;i<(1<<w);i++)
    {
        memset(r,0,sizeof(r));
        for(int j=1;j<=n;j++)
            for(int k=1;k<=n;k++)
                r[j][k]=mat[j][k][i];
        c[i]=det();
    }
    fwt(c,(1<<w),-1);
    for(int i=(1<<w)-1;i>=0;i--)
        if(c[i]){printf("%d",i);return 0;}
    puts("-1");
    return 0;
}