[Math×Girl] 平均律 - Official题解
Naszt
·
·
题解
平均律
观前提醒
感谢 @飞雨烟雁 大佬提供了一种线性的做法。
思路分析
找突破口
由于近似分数一定是既约分数,下文「分数」指的都是既约分数。
我们记一个分数 x 的 前驱/后继 为 x^-/x^+,
定义为分母比 x 的分母小,数值小于/大于 x 的最接近 x 的分数。
枚举分母为 n 的分数 x,答案就是近似分数为 x 的区间和:
ans=\sum\max(\min(x^+-\delta,x+\delta)-\max(x^-+\delta,x-\delta),0)
计算前驱/后继
※ 这一小节是 \log 做法。
有:
\begin{aligned}
x^+&=\lfloor x\rfloor+(x-\lfloor x\rfloor)^+\\
x^+&=1/(1/x)^-\\
x^-&=\lfloor x\rfloor+(x-\lfloor x\rfloor)^-\\
x^-&=1/(1/x)^+
\end{aligned}
若 x=\frac1n,则 x^-=\frac01,x^+=\frac1{n-1}
由此便可以递归求解。
其中的 (1)(3) 式是因为每一段整数都是等价的。
而以上方法的本质其实就是 简单连分数,更本质一点就是 辗转相除法。
时间复杂度也自然是 $O(\log n)$,总时间复杂度是 $O(n\log n)$。
### 线性方法
对于分数 $\frac dn$,设其前驱和后继为 $\frac xy<\frac dn<\frac zw$,则:
$$
\begin{aligned}
\frac dn-\frac xy&=\frac{dy-nx}{ny}=\frac 1{ny}\\
\frac zw-\frac dn&=\frac{zn-dw}{nw}=\frac 1{nw}
\end{aligned}
$$
根据 $dy\equiv 1\pmod n,dw\equiv -1\pmod n$ 即可求出 $y,w$。
但是你不能线性递推的求逆元,因为 $n$ 不一定是质数,
你可以根据逆元的积性和线性筛预处理。
这样的时间复杂度就是 $O(n)$ 的。
## 代码实现
### 出题人代码
```cpp
#include<bits/stdc++.h>
typedef unsigned long long i8;
const i8 MOD=998244353,MX=10000005;
i8 Inv[MX],Invp[MX],da,db,v;
long double I=1,V;
std::vector<i8>Prime;
bool vis[MX];
const __int128 II=1;
void exgcd(i8 a,i8 b,i8&x,i8&y) {
if(!b)x=1,y=0;
else exgcd(b,a%b,y,x),y-=a/b*x;
}
void sieve(const i8 n){
Inv[1]=1;Prime={};
for(i8 i=2;i<=n;i++){
if(!vis[i]){
Prime.push_back(i);
if(n%i==0)Inv[i]=0;
else{i8 x,y;exgcd(i,n,x,y);Inv[i]=(x+n)%n;}
}
for(i8 p:Prime){
if(i*p>n)break;
vis[i*p]=1;Inv[i*p]=Inv[i]*Inv[p]%n;
if(i%p==0)break;
}
}
}
struct frac{
#define il inline __attribute__((__always_inline__))
i8 a,b;char f;//frac{a}{b}+f\delta
il frac(i8 A,i8 B,char F):a(A),b(B),f(F){}
il friend i8 model(frac x){return (x.a*Invp[x.b]+(x.f==1?v:MOD-v));}
il friend bool operator<(frac x,frac y){
// return I*x.a/x.b+x.f*I*da/db<I*y.a/y.b+y.f*I*da/db;//精度会炸
return (y.f-x.f)*II*da*y.b*x.b>(II*x.a*y.b-II*y.a*x.b)*db;
}
};
void slove(){
i8 n,g,ans=0;
std::cin>>n>>da>>db;
g=std::gcd(da,db),da/=g,db/=g;
i8 x,y;exgcd(db%MOD,MOD,x,y);v=da%MOD*(x+MOD)%MOD;
if(n==1){std::cout<<(2*da<db?2*v%MOD:1)<<"\n";return;}
if((__int128)2*n*da>=db){std::cout<<"0\n";return;}
sieve(n);Invp[1]=1;
for(i8 i=2;i<=n;i++)Invp[i]=Invp[MOD%i]*(MOD-MOD/i)%MOD;
for(i8 d=1;d<n;d++){
if(!Inv[d])continue;
frac _x=std::max(frac((d*Inv[d]-1)/n,Inv[d],+1),frac(d,n,-1));
frac x_=std::min(frac((d*(n-Inv[d])+1)/n,n-Inv[d],-1),frac(d,n,+1));
if(_x<x_)ans=(ans+model(x_)+MOD-model(_x)%MOD)%MOD;
}
std::cout<<ans<<"\n";
}
int main(){
std::ios::sync_with_stdio(0);
std::cin.tie(0),std::cout.tie(0);
i8 T;std::cin>>T;
while(T--)slove();
return 0;
}
```
### 验题人代码
```cpp
#include <iostream>
#include <cstdio>
#include <cmath>
#define ll long long
#define lll __int128
using namespace std;
const int Mx = 1e7 + 5, Mod = 998244353;
int n;
ll a, b;
int inv_mod(ll a){ // a < Mod
int res = 1, b = Mod - 2;
while(b){
if(b & 1) res = res * a % Mod;
a = a * a % Mod, b >>= 1;
}
return res;
}
bool coprime[Mx], vis[Mx];
int prime[Mx], tot;
int inv[Mx], invp[Mx];
void sieve(){
tot = 0, inv[1] = coprime[1] = 1;
for(int i = 2; i < n; ++i){
if(!vis[i]) prime[++tot] = i, coprime[i] = (n % i > 0);
for(int j = 1; j <= tot && prime[j] * i < n; ++j){
vis[i * prime[j]] = true;
coprime[i * prime[j]] = coprime[i] & coprime[prime[j]];
if(i % prime[j] == 0) break;
}
}
for(int i = 2; i < n; ++i){
if(coprime[i]) inv[i] = 1ll * inv[i - 1] * i % n;
else inv[i] = inv[i - 1];
}
int back = 1, temp = inv[n - 1];
for(int i = n - 1; i > 1; --i) if(coprime[i]){
inv[i] = 1ll * inv[i - 1] * back % n;
back = 1ll * back * i % n;
}
if(temp == n - 1) for(int i = 2; i < n; ++i) inv[i] = n - inv[i];
}
int solve(){
int delta = a % Mod * inv_mod(b % Mod) % Mod;
if(n == 1) return 2 * a < b ? 2 * delta % Mod : 1;
sieve();
invp[1] = 1;
for(int i = 2; i <= n; i++) invp[i] = -invp[Mod % i] * (ll)(Mod / i) % Mod;
ll Ans = 0;
for(int d = 1; d < n; ++d) if(coprime[d]){
int y = inv[d], w = n - inv[d];
int x = 1ll * d * y / n, z = (1 + 1ll * d * w) / n;
if((lll)b * z * y <= (lll)2 * a * w * y + (lll)b * w * x) continue;
if(b / 2 / a / n >= y) Ans -= 1ll * d * invp[n] % Mod - delta;
else Ans -= 1ll * x * invp[y] % Mod + delta;
if(b / 2 / a / n >= w) Ans += 1ll * d * invp[n] % Mod + delta;
else Ans += 1ll * z * invp[w] % Mod - delta;
}
Ans %= Mod;
return Ans < 0 ? Ans + Mod : Ans;
}
int main(){
int T;
scanf("%d", &T);
while(T--){
scanf("%d%lld%lld", &n, &a, &b);
printf("%d\n", solve());
}
return 0;
}
```