题解 CF493E 【Vasya and Polynomial】

· · 题解

题目传送门

好的阅读体验

题意简述

给定正整数 t,a,b,求满足 P(t)=a,P(a)=b 的非负整系数多项式 P(x) 的个数。

题目分析

显然 P(x) \geq x,故必须满足 t \leq a \leq b ,否则无解。

P(x)=p_0+p_1x+p_2x^2+ \dots +p_nx^n

首先有一个很重要的性质:每个系数都小于等于 a,即 p_i \leq a

由上面的证明可以知道当存在 p_i=a 时只能有一项,可以分类讨论。

存在 p_i=a

此时多项式只有一项,故 i=n

  1. n=0,P(x)=a

a 代入得 P(a)=a ,故 P(x) 存在当且仅当 a=b

  1. n>1,P(x)=p_nx^n

t 代入得 P(t)=at^n=a ,得 t^n=1 ,故 t=1

a 代入得 P(a)=a^{n+1}=b

P(x) 存在当且仅当 t=1a^{n+1}=b

注意 a=1 时只有可能无解或无穷解:

#### 所有 $p_i<a

b 表示成 a 进制,就可以构造出一个 P(a)=b ,并且至多只有一个

以下是证明,不过简单来说,就是b 表示成 a 进制的形式是唯一的

P(a)=p_0+p_1a+p_2a^2+ \dots +p_na^n=b

假设存在另一个 P'(a)=b

P'(a)=p_0'+p_1'a+p_2'a^2+ \dots +p_n'a^n=b

两式相减

(p_1-p_1')a+(p_2-p_2')a^2+ \dots +(p_n-p_n')a^n=0

因为 P(x) \neq P'(x) ,存在 m 使 p_m-p_m' \neq 0 ,令 m 取最大值(也就是最高次数的不为0的项)。

(p_1-p_1')a+(p_2-p_2')a^2+ \dots +(p_m-p_m')a^m=0 |(p_1-p_1')a+(p_2-p_2')a^2+ \dots +(p_m-p_m')a^m|=|(p_m-p_m')a^m| (*)

又因为 0 \leq p_i,p_i' \leq a-1 ,故 |p_i-p_i'| < a

|(p_1-p_1')a|<a^2 |(p_1-p_1')a+(p_2-p_2')a^2<a^2+a^2(a-1)|=a^3 |(p_1-p_1')a+(p_2-p_2')a^2+(p_3-p_3')a^3<a^3+a^3(a-1)|=a^4

以此类推,故 |(p_1-p_1')a+(p_2-p_2')a^2+ \dots +(p_{m-1}-p_{m-1}')a^{m-1}|<a^m

|(p_m-p_m')a^m|>=a^m

(*) 矛盾。

上面那一堆东西简单来说,就是b 表示成 a 进制的形式是唯一的

求出 P(x) ,再将 t 代入检验一下 P(t)=a 就行了。

代码

#include<bits/stdc++.h>
using namespace std;
long long T,t,a,b;
long long cnt,tot;
long long ans[65];

bool check(long long x,long long y){//检验P(t)=a 
    long long tot=0;
    long long k=1;
    for(long long i=0;i<=cnt;i++){
        tot+=k*ans[i];
        k*=x;
        if(tot>y) return false;
    }
    return (tot==y);
}

long long cal(long long a,long long b){//计算a^k=b 
    long long cn=0;
    while(b%a==0){
        cn++;
        b/=a;
    }
    if(b!=1) return -1;
    return cn;
}

int main(){
    scanf("%lld%lld%lld",&t,&a,&b);
    if(t==1 && a==1 && b==1){//特判无数解 
        printf("inf");
        return 0;
    }

    if(a==b) tot++;//p_i=a,n=0
    if(t==1 && a>1){//p_i=a,n>0
        long long tmp=cal(a,b);
        if(tmp!=-1 && tmp>1) tot++;
    }

    if(a>1){//p_i<a
        cnt=-1;
        long long tmp=b;
        while(tmp) ans[++cnt]=tmp%a,tmp/=a;
        if(check(t,a)) tot++;
    }
    printf("%lld",tot);
}