题解 P2371 【[国家集训队]墨墨的等式】

· · 题解

这道题果然是国家集训队的神仙题,思维独特。

  首先若方程 \sum_{i=1}^{n}a_ix_i=k 有非负整数解,那么显然对于每一个 a_i 方程 \sum_{i=1}^{n}a_ix_i=k 都必有非负整数解。于是若取 Min=\min(a_i) ,那么对于任意 j \in [0,min) ,若对于自然数数 k \sum_{i=1}^{n}a_ix_i=k (k \equiv j (mod \ Min)) 有解,则对于一切自然数 B>k(B \equiv j (mod \ Min)) ,方程 \sum_{i=1}^{n}a_ix_i=k 都必有解。因此,我们只需对每一个 j 求出对应的 k 值。(我代码中实际求的是 k/Min ,取 Min 是为了让状态数尽可能小)

  但是,这个值怎么求呢?根据上文,若 \sum_{i=1}^{n}a_ix_i=k ,则我们可以尝试将 k 分别加上每一个 a_i 得到新的合法数值。这其实相当于对于每一个 i \in [1,n] ,从 k (k+a_i)mod \ Min 连一条长度为 a_i 的边。同时,我们要使 k 尽可能小,所以要在原图中跑最短路求解。

  代码:

1、dijkstra+堆(这个如果dijkstra写得不好在洛谷上会tle)

#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<algorithm>
#include<queue>
#include<vector>
#include<map>
#define ll long long
#define ull unsigned long long
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
#define lowbit(x) (x& -x)
#define mod 1000000000
#define inf 0x3f3f3f3f
#define eps 1e-18
#define maxn 5010
inline ll read(){ll tmp=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())tmp=(tmp<<3)+(tmp<<1)+c-'0'; return tmp*f;}
inline ll power(ll a,ll b){ll ans=1; for(;b;b>>=1){if(b&1)ans=ans*a%mod; a=a*a%mod;} return ans;}
inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
inline void swap(int &a,int &b){int tmp=a; a=b; b=tmp;}
struct data{
    int id,dis;
    friend bool operator < (data a,data b){
        return a.dis>b.dis;
    }
};
std::priority_queue<data>heap;
struct edge{
    int to,nxt,d;
}e[6000010];
int fir[500010],mark[500010],dist[500010];
int a[20];
int n,m,tot=0;
ll l,r;
void add(int x,int y,int z){e[tot].to=y; e[tot].d=z; e[tot].nxt=fir[x]; fir[x]=tot++;}
void dij(int S)
{
    memset(dist,0x3f,sizeof(dist));
    memset(mark,0,sizeof(mark));
    data now; now.id=S; now.dis=0; dist[S]=0; heap.push(now);
    while(!heap.empty()){
        now=heap.top(); heap.pop();
        if(mark[now.id])continue;
        mark[now.id]=1;
        for(int i=fir[now.id];~i;i=e[i].nxt)
            if(dist[e[i].to]>dist[now.id]+e[i].d){
                data tmp; tmp.id=e[i].to; tmp.dis=dist[now.id]+e[i].d;
                heap.push(tmp); dist[e[i].to]=tmp.dis;
            }
    }
}
int main()
{
    memset(fir,255,sizeof(fir));
    n=read(); l=read(); r=read();
    int mn=inf;
    for(int i=1;i<=n;i++)
        a[i]=read(),mn=min(mn,a[i]);
    for(int i=0;i<mn;i++)
        for(int j=1;j<=n;j++){
            int nxt=i+a[j];
            add(i,nxt%mn,nxt/mn);
        }
    dij(0);
//  for(int i=0;i<=n;i++)printf("%d %d\n",i,dist[i]);
    ll ans=0;
    for(int i=0;i<mn;i++)
        if(1ll*dist[i]*mn+i<=r){
            ll L=max(dist[i]*mn+i,l),R=r;
            if(L/mn<R/mn){
                ans+=R/mn-L/mn-1;
                if(L<=L/mn*mn+i&&L/mn*mn+i<=R)++ans;
                if(L<=R/mn*mn+i&&R/mn*mn+i<=R)++ans;
            }
            else if(L<=L/mn*mn+i&&L/mn*mn+i<=R)++ans;
        }
    printf("%lld\n",ans);
}

2、spfa(这个就很清真了,跑得飞快)

#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<algorithm>
#include<queue>
#include<vector>
#include<map>
#define ll long long
#define ull unsigned long long
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
#define lowbit(x) (x& -x)
#define mod 1000000000
#define inf 0x3f3f3f3f
#define eps 1e-18
#define maxn 5010
inline ll read(){ll tmp=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())tmp=(tmp<<3)+(tmp<<1)+c-'0'; return tmp*f;}
inline ll power(ll a,ll b){ll ans=1; for(;b;b>>=1){if(b&1)ans=ans*a%mod; a=a*a%mod;} return ans;}
inline ll gcd(ll a,ll b){return b?gcd(b,a%b):a;}
inline void swap(int &a,int &b){int tmp=a; a=b; b=tmp;}
using namespace std;
struct edge{
    int to,nxt,d;
}e[6000010];
int fir[500010],inq[500010],dist[500010];
int q[10000010];
int a[20];
int n,m,tot=0;
ll l,r;
void add(int x,int y,int z){e[tot].to=y; e[tot].d=z; e[tot].nxt=fir[x]; fir[x]=tot++;}
void spfa(int S)
{
    memset(dist,0x3f,sizeof(dist));
    memset(inq,0,sizeof(inq));
    int h=1,t=1; q[1]=S; inq[S]=1; dist[S]=0;
    while(h<=t){
        for(int i=fir[q[h]];~i;i=e[i].nxt)
            if(dist[q[h]]+e[i].d<dist[e[i].to]){
                dist[e[i].to]=dist[q[h]]+e[i].d;
                if(!inq[e[i].to]){
                    q[++t]=e[i].to; inq[e[i].to]=1;
                }
            }
        inq[q[h++]]=0;
    }
}
int main()
{
    memset(fir,255,sizeof(fir));
    n=read(); l=read(); r=read();
    int mn=inf;
    for(int i=1;i<=n;i++)
        a[i]=read(),mn=min(mn,a[i]);
    for(int i=0;i<mn;i++)
        for(int j=1;j<=n;j++){
            int nxt=i+a[j];
            add(i,nxt%mn,nxt/mn);
        }
    spfa(0);
    ll ans=0;
    for(int i=0;i<mn;i++)
        if(1ll*dist[i]*mn+i<=r){
            ll L=max(dist[i]*mn+i,l),R=r;
            if(L/mn<R/mn){
                ans+=R/mn-L/mn-1;
                if(L<=L/mn*mn+i&&L/mn*mn+i<=R)++ans;
                if(L<=R/mn*mn+i&&R/mn*mn+i<=R)++ans;
            }
            else if(L<=L/mn*mn+i&&L/mn*mn+i<=R)++ans;
        }
    printf("%lld\n",ans);
}