题解:P1763 埃及分数

· · 题解

先放上代码,然后再讲解:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll deep,s[11],ans[11],flag,a,b;
ll gcd(ll x,ll y){
    if(y==0)return x;
    else return gcd(y,x%y);
}
void dfs(ll a,ll b,int x){
    if(x>deep)return;
    if(a==1&&b>s[x-1]){
        s[x]=b;
        if(!flag||s[x]<ans[x])memcpy(ans,s,sizeof(ll)*(deep+1));
        flag=1;
        return;
    }
    ll l=max(b/a+1,s[x-1]+1);
    ll r=(deep-x+1)*b/a;
    if(flag&&r>=ans[deep])r=ans[deep]-1;
    for(ll i=l;i<r;i++){
        s[x]=i;
        ll gcdd=gcd(a*i-b,b*i);
        dfs((a*i-b)/gcdd,b*i/gcdd,x+1);
    }
}
int main(){
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
    cin>>a>>b;
    ll g=gcd(a,b);
    a/=g,b/=g;
    for(deep=1;deep<=10;deep++){
        dfs(a,b,1);
        if(flag){
            for(int i=1;i<=deep;i++)cout<<ans[i]<<' ';
            return 0;
        }
    }
}

代码中有一个全局变量 deep,用来规定搜索深度的上限。将其设为全局变量是为了方便跨函数调用(设成函数参数也可以)。

接下来就是 dfs 函数来深搜,其中有一个变量 x1 表示搜索深度。函数中如果当前面的搜索深度大于规定搜索深度 deep,则直接返回,表示在当前规定搜索深度下,此状态一定不能到达搜索目标状态。

在主函数中,从小到大枚举深度界限 deep,每次让深度界限增加 1,并调用 dfs 函数,在新的深度界限下再次搜索。当 dfs 搜索到目标时后不再继续增加深度界限,此时 deep 的值就能搜到的答案的最小深度。

但是,众所周知单单就这样写代码 hack 会 TLE(即上面刚刚的代码不能 AC!

所以我们还需要优化:

首先很容易发现暴力枚举每个位置的分母还是太慢了。我们可以发现,当搜到最后两个分母时,我们没有必要枚举,而是可以直接把最后两项分母列方程解出来,这样会比枚举快很多。那现在我们设最后两项分母为 xy,其中 x<y ,它们的总和是 \frac{a}{b}, 即:\frac{a}{b}=\frac{1}{x}+\frac{1}{y}=\frac{x+y}{xy}

因为 ab 是互质的,等式要是有整数解,则存在一个整数 k 满足:\frac{x+y}{xy}=\frac{ak}{bk},则有:

\\ xy=bk \end{matrix}\right.

y 消掉以后,有:x^2-akx+bk=0

这个方程要是有解,则判别式为:\Delta =a^2k^2-4bk≥0 满足这个条件的话,则:k≥\left \lceil \frac{4b}{a^2} \right \rceil

因此我们可以枚举 k,找到某个 k 使得 \sqrt{\Delta} 为整数,此时两个解分别为:

\\ y=\frac{ak+\sqrt{\Delta}}{2} \end{matrix}\right.

要求这两个数也必须是整数,且 y 要小于等于求出过的最后一项的分母,或者小于 10^7。如果 y 大于 10^7 可直接 break

好了,这就是优化思路,最后我放上求最后两项分母的代码:

if(x1==deep-1){
    ll minK=ceil(sqrt(4*b/(a*a)));
        for(ll k=minK;;k++){
            ll delta=a*a*k*k-4*b*k;
            ll t=sqrt(delta),gd=-1;
            if(t*t==delta)gd=t;
            else if((t-1)*(t-1)==delta)gd=t-1;
            else if((t+1)*(t+1)==delta)gd=t+1;
            ll x=(a*k-gd)/2;
            ll y=(a*k+gd)/2;
            if(y>1e7||(flag&&y>=ans[deep]))break;
            if(gd<=0||(a*k-gd)%2!=0)continue;
            s[deep-1]=x;
            s[deep]=y;
            memcpy(ans,s,sizeof(ll)*(deep+1));
            flag=1;
            break;
        }
        return;
    }