题解:P6962 [NEERC 2017] Knapsack Cryptosystem

· · 题解

做法

首先,如果 b_i 是任意的数,那么本题就是一个值域巨大的背包问题,显然不可做。这提示我们必须利用 b_i 的生成方式所带来的性质。发现 a_i 增长非常快,是 2^n 级别的,所以当 n 较大时,合法的 a_1 个数会很少。考虑对 n 大和 n 小时分别做。

n\leq 40 时,直接使用折半搜索即可做到 O(2^{\frac{n}{2}})。具体地,把数组分成两半,每次只要枚举 \frac{n}{2} 个数的方案。然后在第一半的和里枚举,使用 umap 查找第二段里面对应的方案即可。

n>40 时,可能的 a_1 至多 2^{64-n} 种。直接枚举 a_1,使用 exgcd 解不定方程 a_1\times r\equiv b_1 \pmod {2^{64}} 即可求出题目中的 r,然后即可求出 a_1 和不取模的 s(题目中保证了 s<2^{64},直接乘 r 的逆元即可)。注意求 r 的时候因为 a_1 可能与 2^{64} 不互质,因此可能有多个 a_1。对于有多个 a_1 的情况,直接枚举即可。

枚举的复杂度证明:设 b_1 最大的是 2 的幂的因子为 x,因为 r 是奇数,所以 a_1 一定是 x 的倍数。则 d=\gcd(a_1,2^{64})\geq x。而对于每个 a_1 只会枚举 d 次,则总的枚举 r 的个数不超过 \frac{m}{d}\times d,其中 ma_1 可能的最大值。

代码

#include<bits/stdc++.h>
#define int long long
#define ll long long
#define ull unsigned long long
#define pii pair<ll,ll>
#define fi first
#define se second
#define i128 __int128
#define ALL(x) x.begin(),x.end()
#define popcount(x) __builtin_popcountll(x)
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
using namespace std;
const int INF=1e18;
const int N=4000005;
const int MOD1=1e9+7,MOD2=998244353;
const int MOD=MOD1;
int n;
ull m;
ull a[N],b[N];
const i128 t=((i128)1<<64);
istream&operator>>(istream&is,__int128&n){string s;is>>s;n=0;for(char c:s){n=n*10+(c-'0');}return is;}
ostream&operator<<(ostream&os,__int128 n){if(n==0)return os<<"0";if(n<0){os<<"-";n=-n;}string s;while(n>0){s+='0'+n%10;n/=10;}reverse(s.begin(),s.end());return os<<s;}
void solve1(){
    int x=n/2;
    int len1=0,len2=0;
    unordered_map<ull,int> mp;
    for(int i=0;i<(1ll<<x);i++){
        ull sum=0;
        for(int j=1;j<=x;j++){
            if((i>>(j-1))&1){
                sum=(ull)(sum+b[j]);
            }
        }
        mp[sum]=i;
    }
    for(int i=0;i<(1ll<<(n-x));i++){
        ull sum=0;
        for(int j=1;j<=(n-x);j++){
            if((i>>(j-1))&1){
                sum=(ull)(sum+b[j+x]);
            }
        }
        if(mp.count((ull)((ull)m-sum))){
            int tmp=mp[(ull)((ull)m-sum)];
            for(int j=0;j<x;j++){
                cout<<((tmp>>j)&1);
            }
            for(int j=0;j<(n-x);j++){
                cout<<((i>>j)&1);
            }
            exit(0);
        }
    }
    exit(1);
}
i128 exgcd(i128 a,i128 b,i128 &x,i128 &y){
    if(!b){
        x=1,y=0;
        return a;
    }
    i128 d=exgcd(b,a%b,x,y);
    i128 t=x;
    x=y;
    y=t-(a/b)*y;
    return d;
}
int ans[70];
void solve2(){
    i128 mx=((i128)1<<(64-n+1));
    for(i128 i=1;i<=mx;i++){
        i128 r0=0,tmp=0;
        i128 d=exgcd(i,t,r0,tmp);
        r0=(ull)r0;
        i128 delta=t/d;
        r0%=delta;
        if(b[1]%d!=0)continue;
        for(i128 tmpr=r0;tmpr<t;tmpr+=delta){
            ull r=tmpr*(ull)(b[1]/d);
            assert((ull)((ull)i*r)==b[1]);
            if(r%2==0){
                continue;
            }
            i128 invr=0;
            exgcd(r,t,invr,tmp);
            invr=(ull)invr;
            assert((ull)((ull)invr*r)==1);
            a[1]=i;
            i128 s=a[1];
            bool flg=1;
            for(int j=2;j<=n;j++){
                a[j]=(ull)((ull)b[j]*(ull)invr);
                if(a[j]<=s){
                    flg=0;
                    break;
                }
                s+=a[j];
            }
            if(s>=t){
                continue;
            }
            if(flg){
                ull tmpm=(ull)((ull)m*invr);
                for(int j=n;j>=1;j--){
                    if(tmpm>=a[j]){
                        tmpm-=a[j];
                        ans[j]=1;
                    }
                }
                if(!tmpm){
                    for(int j=1;j<=n;j++){
                        cout<<ans[j];
                    }
                    exit(0);
                }
            }
        }
    }
    exit(1);
}
void solve_(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>b[i];
    }
    cin>>m;
    if(n<=44){
        solve1();
    }else{
        solve2();
    }
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int testcase,multitest=0;
    if(multitest)cin>>testcase;
    else testcase=1;
    while(testcase--){
        solve_();
    }
    return 0;
}