题解:P7607 [THUPC 2021] 赌徒问题

· · 题解

想了很久自己为什么跑这么慢,结果发现是完全背包多写了一个 n,没有人类了。

考虑枚举 s,然后每次要做一个完全背包,可以获得一个 O(nm^3) 的做法。

看起来不太好优化,但是考虑到这里可能有很多数被重复加入删除,可以想到对于所有 (i,j) 连边,边权是从 s=is=j 需要修改的物品个数,然后跑一个最小生成树,记边权和为 w,那么后面部分复杂度可以做到 O(nmw)。发现 m=2000 的时候 w=8006,可以轻松通过。

最后说一下怎么求 MST。发现只保留 (i,ki)(k\in \mathbb Z) 的边也是可以的,边权可以 O(m) 或者 O(\frac{m}w) 计算,那求 MST 的复杂度就是 O(\frac{m^2\log n}w)

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int mod=1e9+7;
int n,m,k,ans,f[12][2005],fd[2005],tot;
bitset<2005> vis,e[2005];
int find(int x){
    return fd[x]?fd[x]=find(fd[x]):x;
}
struct edge{
    int u,v,w;
    bool operator <(edge &x) const{
        return w<x.w;
    }
}ed[50005];
vector<int> g[2005];
inline void add(int x){
    if(vis[x]) return ;
    vis[x]=1;
    for(int i=0;i<n;i++)
        for(int k=0;k+x<=m;k++)
            f[i+1][k+x]=(f[i+1][k+x]+f[i][k])%mod;
}
inline void del(int x){
    if(!vis[x]) return ;
    vis[x]=0;
    for(int i=n-1;~i;i--)
        for(int k=0;k+x<=m;k++)
            f[i+1][k+x]=(f[i+1][k+x]-f[i][k])%mod;
}
void dfs(int x,int fa){
    for(int i=1;i<=m;i++)
        if(1ll*k*x%i==0)
            add(i);
    ans=(ans+f[n][x])%mod;
    for(int v:g[x])
        if(v^fa){
            int g[12][2005];
            bitset<2005> nw=vis;
            memcpy(g,f,sizeof(g));
            for(int i=1;i<=m;i++)
                if(1ll*k*x%i==0&&1ll*k*v%i)
                    del(i);
            dfs(v,x);
            memcpy(f,g,sizeof(f)),vis=nw;
        }
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0);
    cin>>n>>m>>k;
//  n=10,m=2000,k=440000000;
    vis=0;
    for(int i=1;i<=m;i++)
        for(int j=1;j<=m;j++)
            e[i][j]=(1ll*i*k%j==0);
    for(int i=1;i<=m;i++)
        for(int j=i+i;j<=m;j+=i)
            ed[++tot]={i,j,(e[i]^e[j]).count()};
    sort(ed+1,ed+tot+1);
    int all=0;
    for(int i=1;i<=tot;i++){
        int u=ed[i].u,v=ed[i].v;
        if(find(u)==find(v)) continue;
        all+=ed[i].w,g[u].push_back(v),g[v].push_back(u);
        fd[find(u)]=find(v);
    }
    cerr<<all<<'\n';
    f[0][0]=1;
    dfs(1,0);
    cout<<ans<<'\n';
    return 0;
}