题解:P10707 永恒(Eternity)

· · 题解

切的最快的一个,感觉比 A 简单。

题目看起来就很线性基。我们先统计有多少种典型基能凑出 m。由于统计的是典型基的数量,因此不需要去重,且一定是全选。考虑对 m 的二进制位逐位 dp,每次进行的操作是选择是否插入一个最高位为 i 的典型基,并确定之前插入的典型基第 i 位的值。

具体来说,假设 m 的第 i 位是 1,那么要么插入一个最高位为 i 的典型基,要么将之前插入的典型基中选奇数个,将他们的第 i 位赋为 1。设 f_{i,j} 表示 dp 到第 i 位,已经插入了 j 个典型基的方案数,则转移方程就是 f_{i,j} = f_{i+1,j-1} + 2^{j-1} f_{i+1,j}。如果 m 的第 i 位是 0,则只能在之前的典型基中选择偶数个,将他们的第 i 位赋值为 1,转移方程是 f_{i,j} = 2^{j-1} f_{i+1,j}。注意特判 j=0 的情况。

现在我们已经确定了典型基,还要做的就是根据典型基确定原序列。假设典型基的大小为 S,一个简单的想法就是,在这个典型基能凑出的 2^S 个数中可重复的选出 n 个,即 \binom{2^S + n - 1}{n} 种方案,但是这样是不行的。因为选出的 n 个数构成的典型基可能只是这个典型基的一个子基。那么只需要去掉这些序列就行了。设 g_i 表示大小为 i 的典型基能构成多少种序列,则枚举 j,表示选择一个大小为 j 的子基,然后减去他能构成的序列即可。

现在问题就是求出大小为 n 的子集有多少个大小为 m 的子基。设这个答案为 h_{n,m},不难发现,你从小往大考虑,每次考虑第 i 小的典型基是否需要被删除,如果不被删除,则他可以异或上前面任意多个被删除的基。也就是说,转移方程为 h_{i,j} = h_{i-1,j} + 2^{i-j} h_{i-1,j-1}。而 g_i = \binom{2^S + n - 1}{n} - \sum\limits_{j=0}^{i-1} h_{i,j} g_j。最后输出 \sum\limits_{i=0}^{60} f_{0,i} g_i 即可,复杂度 O(60n),瓶颈在于算 \binom{2^S + n - 1}{n}

Code:

#include<bits/stdc++.h>
#define ll long long
#define pn putchar('\n')
#define mclear(a) memset(a,0,sizeof a)
#define fls() fflush(stdout)
#define maxn 100005
#define int ll
#define mod 998244353
using namespace std;
int re()
{
    int x=0;
    bool t=1;
    char ch=getchar();
    while(ch>'9'||ch<'0')
        t=ch=='-'?0:t,ch=getchar();
    while(ch>='0'&&ch<='9')
        x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return t?x:-x;
}
int n,m,ans;
int f[70][70],g[70],h[70][70],mi[70];
int ksm(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=ret*x%mod;
        x=x*x%mod,y>>=1;
    }
    return ret;
}
namespace ZH
{
    int jc[maxn],inv[maxn];
    void zh_init()
    {
        jc[0]=1;
        for(int i=1;i<=n;i++)
            jc[i]=jc[i-1]*i%mod;
        inv[n]=ksm(jc[n],mod-2);
        for(int i=n;i;i--)
            inv[i-1]=inv[i]*i%mod;
    }
    int A(int x,int y)
    {
        if(x<y)return 0;
        return jc[x]*inv[x-y]%mod;
    }
    int C(int x,int y)
    {
        return A(x,y)*inv[y]%mod;
    }
}
using namespace ZH;
int fu(int x)
{
    return x&1?-1:1;
}
signed main()
{
    n=re(),m=re();
    zh_init();
    mi[0]=1;
    for(int i=1;i<=60;i++)
        mi[i]=mi[i-1]*2%mod;
    f[60][0]=1;
    for(int i=59;~i;i--)
    {
        if(m>>i&1)
        {
            for(int j=1;j<=60;j++)
                f[i][j]=(f[i+1][j-1]+mi[j-1]*f[i+1][j])%mod;
        }
        else
        {
            f[i][0]=f[i+1][0];
            for(int j=1;j<=60;j++)
                f[i][j]=mi[j-1]*f[i+1][j]%mod;
        }
    }
    for(int i=0;i<=60;i++)
    {
        int x=1;
        for(int j=1;j<=n;j++)
            x=(mi[i]+n-j)%mod*x%mod;
        g[i]=x*inv[n]%mod;
    }
    h[0][0]=1;
    for(int i=1;i<=60;i++)
    {
        h[i][0]=1;
        for(int j=1;j<=i;j++)
            h[i][j]=(h[i-1][j-1]+mi[j]*h[i-1][j])%mod;
    }
    for(int i=0;i<=60;i++)
    {
        for(int j=1;j<=i;j++)
            (g[i]-=h[i][j]*g[i-j])%=mod;
        (ans+=g[i]*f[0][i])%=mod;
    }
    if(ans<0)
        ans+=mod;
    printf("%lld",ans);
    return 0;
}