题解:P6962 [NEERC 2017] Knapsack Cryptosystem
做法
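首先,如果私钥序列 $a$ 是超递增的($a_i > \sum_{j<i} a_j$)且 $\sum_{i=1}^{n} a_i < 2^{64}$,那么前缀和每一步至少翻倍,于是 $2^{n-1} a_1 \le \sum_{i=1}^{n} a_i < 2^{64}$,即 $a_1 < 2^{65-n}$。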
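当 $n \le 44$ 时,使用折半搜索(meet-in-the-middle):先把前 $\lfloor n/2 \rfloor$ 个 $b_i$ 的所有子集和(模 $2^{64}$)存入哈希表,再枚举后半部分的每个子集和 $s$,在表中查找 $m - s \pmod{2^{64}}$。每一半至多 $2^{22}$ 个子集。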
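当 $n > 44$ 时,$a_1 < 2^{65-n} \le 2^{20}$,可以直接枚举 $a_1$。由 $a_1 r \equiv b_1 \pmod{2^{64}}$ 求出所有奇数解 $r$(只有奇数 $r$ 在模 $2^{64}$ 下可逆),令 $a_i = b_i r^{-1} \bmod 2^{64}$,检查 $a$ 是否超递增且总和小于 $2^{64}$;若是,则从大到小贪心分解 $m r^{-1} \bmod 2^{64}$,恰好分解完即得到明文。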
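枚举的复杂度证明:设 $X = 2^{65-n}$,$a_1 = 2^{k} c$($c$ 为奇数)。方程 $a_1 r \equiv b_1 \pmod{2^{64}}$ 有解当且仅当 $2^{k} \mid b_1$,此时恰有 $2^{k}$ 个解,且这些解为奇数当且仅当 $k = v_2(b_1)$。$[1, X]$ 中 $v_2$ 恰为 $k$ 的数约有 $X / 2^{k+1}$ 个,故需要做 $O(n)$ 检查的奇数 $r$ 总共约 $X/2$ 个;对 $k < v_2(b_1)$ 的 $a_1$,其解全为偶数,$O(1)$ 跳过,总数不超过 $\tfrac{X}{2} \cdot v_2(b_1) < 32X$。因此总复杂度为 $O\bigl(2^{65-n}(n + 64)\bigr)$。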
代码
#include<bits/stdc++.h>
#define int long long
#define ll long long
#define ull unsigned long long
#define pii pair<ll,ll>
#define fi first
#define se second
#define i128 __int128
#define ALL(x) x.begin(),x.end()
#define popcount(x) __builtin_popcountll(x)
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
using namespace std;
// Global constants and state shared by solve1/solve2 below.
// (INF, MOD1, MOD2 and MOD are template constants not referenced by this solution.)
const int INF=1e18;
// Capacity of a[]/b[]; this solution only touches indices 1..n.
const int N=4000005;
const int MOD1=1e9+7,MOD2=998244353;
const int MOD=MOD1;
// Number of elements in the public key b.
int n;
// Ciphertext: the solvers look for a subset of b whose sum equals m modulo 2^64.
ull m;
// b: public key read by solve_ (1-indexed).
// a: candidate private sequence rebuilt by solve2 (1-indexed).
ull a[N],b[N];
// 2^64, the modulus used by exgcd calls and the sum bound in solve2.
const i128 t=((i128)1<<64);
// Reads one whitespace-delimited token from `is` and parses it as a signed
// decimal __int128 (optional leading '+' or '-').
// n is reset to 0 first. On EOF, or when the token is not a valid number,
// n stays 0; a malformed token also sets failbit.
std::istream&operator>>(std::istream&is,__int128&n){
    n=0;
    std::string s;
    if(!(is>>s))return is;
    std::size_t pos=0;
    const bool neg=(s[0]=='-');
    if(neg||s[0]=='+')pos=1;
    const std::size_t firstDigit=pos;
    // Accumulate in unsigned arithmetic so the minimum __int128 parses without overflow.
    unsigned __int128 mag=0;
    while(pos<s.size()&&s[pos]>='0'&&s[pos]<='9'){
        mag=mag*10+(unsigned)(s[pos]-'0');
        ++pos;
    }
    if(pos==firstDigit||pos!=s.size()){
        // Either no digits at all or trailing garbage: report instead of returning nonsense.
        is.setstate(std::ios::failbit);
        return is;
    }
    n=neg?(__int128)((unsigned __int128)0-mag):(__int128)mag;
    return is;
}
// Writes `n` to `os` in base 10 with a leading '-' for negative values.
// The whole number is emitted as a single string, so stream width applies to it as one unit.
std::ostream&operator<<(std::ostream&os,__int128 n){
    const bool neg=n<0;
    // Take the magnitude in unsigned arithmetic: `-n` overflows for the minimum __int128.
    unsigned __int128 mag=neg?(unsigned __int128)0-(unsigned __int128)n:(unsigned __int128)n;
    std::string digits;
    do{
        digits+=(char)('0'+(int)(mag%10));
        mag/=10;
    }while(mag!=0);
    if(neg)digits+='-';
    std::reverse(digits.begin(),digits.end());
    return os<<digits;
}
void solve1(){
int x=n/2;
int len1=0,len2=0;
unordered_map<ull,int> mp;
for(int i=0;i<(1ll<<x);i++){
ull sum=0;
for(int j=1;j<=x;j++){
if((i>>(j-1))&1){
sum=(ull)(sum+b[j]);
}
}
mp[sum]=i;
}
for(int i=0;i<(1ll<<(n-x));i++){
ull sum=0;
for(int j=1;j<=(n-x);j++){
if((i>>(j-1))&1){
sum=(ull)(sum+b[j+x]);
}
}
if(mp.count((ull)((ull)m-sum))){
int tmp=mp[(ull)((ull)m-sum)];
for(int j=0;j<x;j++){
cout<<((tmp>>j)&1);
}
for(int j=0;j<(n-x);j++){
cout<<((i>>j)&1);
}
exit(0);
}
}
exit(1);
}
// Extended Euclid, iterative form. Returns g = gcd(a, b) (the last nonzero
// remainder) and stores Bezout coefficients in x, y with a*x + b*y == g.
// The coefficients equal those of the classic recursive formulation.
// Both compute entries of the same product of symmetric quotient matrices
// [[0,1],[1,-q]].
i128 exgcd(i128 a,i128 b,i128 &x,i128 &y){
    // Invariants (a0, b0 = original inputs):
    //   a0*px + b0*py == a
    //   a0*cx + b0*cy == b
    i128 px=1,py=0;
    i128 cx=0,cy=1;
    while(b!=0){
        const i128 q=a/b;
        i128 nxt=a-q*b;   // == a % b
        a=b;
        b=nxt;
        nxt=px-q*cx;
        px=cx;
        cx=nxt;
        nxt=py-q*cy;
        py=cy;
        cy=nxt;
    }
    x=px;
    y=py;
    return a;
}
int ans[70];
void solve2(){
i128 mx=((i128)1<<(64-n+1));
for(i128 i=1;i<=mx;i++){
i128 r0=0,tmp=0;
i128 d=exgcd(i,t,r0,tmp);
r0=(ull)r0;
i128 delta=t/d;
r0%=delta;
if(b[1]%d!=0)continue;
for(i128 tmpr=r0;tmpr<t;tmpr+=delta){
ull r=tmpr*(ull)(b[1]/d);
assert((ull)((ull)i*r)==b[1]);
if(r%2==0){
continue;
}
i128 invr=0;
exgcd(r,t,invr,tmp);
invr=(ull)invr;
assert((ull)((ull)invr*r)==1);
a[1]=i;
i128 s=a[1];
bool flg=1;
for(int j=2;j<=n;j++){
a[j]=(ull)((ull)b[j]*(ull)invr);
if(a[j]<=s){
flg=0;
break;
}
s+=a[j];
}
if(s>=t){
continue;
}
if(flg){
ull tmpm=(ull)((ull)m*invr);
for(int j=n;j>=1;j--){
if(tmpm>=a[j]){
tmpm-=a[j];
ans[j]=1;
}
}
if(!tmpm){
for(int j=1;j<=n;j++){
cout<<ans[j];
}
exit(0);
}
}
}
}
exit(1);
}
// Reads one instance and dispatches to the matching solver. Input order is
// n, then the public key b[1..n], then the ciphertext m.
void solve_(){
    cin>>n;
    for(int idx=0;idx<n;){
        cin>>b[++idx];
    }
    cin>>m;
    // With n <= 44 each meet-in-the-middle half has at most 2^22 subsets.
    // Beyond that a[1] < 2^(65-n) is small enough to enumerate (solve2).
    if(n>44){
        solve2();
        return;
    }
    solve1();
}
signed main(){
    // Fast iostreams: all I/O goes through cin/cout, never C stdio.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    // Single test case by default; set multitest to read a case count first.
    const int multitest=0;
    int testcase=1;
    if(multitest){
        cin>>testcase;
    }
    for(;testcase!=0;--testcase){
        solve_();
    }
    return 0;
}