题解:P2019 四平方和定理

· · 题解

引入

这是一道十分好的数论题目。适合初学者做做练练手

介绍

拉格朗日四平方和定理

这个定理表述的是任何正整数必定可以表示为四个整数的平方和。

证明:

由欧拉恒等式:

(a^2+b^2+c^2+d^2)(x^2+y^2+z^2+w^2)=(ax+by+cz+dw)^2+(ay-bx+cw-dz)^2+(az-bw-cx+dy)^2+(aw+bz-cy-dx)^2

这个恒等式可以使用初中的恒等变形证明。由于展开式我懒得写比较长,这里忽略证明。 那么如果两个数都能表示为四个平方数的和,那么它们的乘积也能表示为四个平方数的和。因此,只需证明所有素数都能表示为四个平方数的和。

首先对于 2,显然地,2=1^2+0^2+0^2+1^2

其次对于奇素数 p,通过模 p 的二次剩余理论容易得出存在整数 x,y,使得 x^2+y^2+1 \equiv 0 \pmod{p}。构造一个数 m,且 m<p,使得 mp=x^2+y^2+z^2+w^2,通过无穷递降法即证。

雅可比四平方和定理

r_4(n) 表示自然数 n 表示为四个整数平方和的有序表示方式的数量。定义:

\theta_3(q)=\sum_{n=-\infty}^{\infty}q^{n^2},|q|<1

根据雅可比三重积恒等式,有

\prod_{n=1}^{\infty}(1-q^{2n})(1+q^{2n-1}z)(1+q^{2n-1}z^{-1})=\sum_{n=-\infty}^{\infty}q^{n^2}z^n

z=1,有

\theta_3(q)=\prod_{n=1}^{\infty}(1-q^{2n})(1+q^{2n-1})^2

展开 \theta_3(q)^4

\theta_3(q)^4=\left(\prod_{n=1}^{\infty}(1-q^{2n})(1+q^{2n-1})^2\right)^4=\prod_{n=1}^{\infty}(1-q^{2n})^4(1+q^{2n-1})^8

雅可比证明了(这里我实力不够不知道咋证,有大佬告诉我吗?)

\theta_3(q)^4=1+8\sum_{k=1}^{\infty}\frac{kq^k}{1+(-q)^k}

展开右边的系数,可以得到:

1+8\sum_{k=1}^{\infty}\frac{kq^k}{1+(-q)^k}=1+8\sum_{k=1}^{\infty}kq^k\left(1-\left(\sum_{i=1}^{\infty}(-(-q)^k)^{i}\right)\right)

提取 q^n 的系数,得到:

r_4(n)=8\sum_{d \mid n,4 \nmid n} d

最终得到结论。

代码实现

由于 n \le 10^{18},通过 MillerRabin 和 PollardRho 完成找因数的过程。可以通过此题。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int MOD = 1e9 + 7;

namespace MillerRabin {
    ll mul(ll a, ll b, ll mod) {
        ll res = 0;
        while (b) {
            if (b & 1) res = (res + a) % mod;
            a = (a + a) % mod;
            b >>= 1;
        }
        return res;
    }
    ll pow(ll a, ll b, ll mod) {
        ll res = 1;
        while (b) {
            if (b & 1)
                res = mul(res, a, mod);
            a = mul(a, a, mod);
            b >>= 1;
        }
        return res;
    }
    bool is_prime(ll n) {
        if (n < 2)
            return false;
        if (n == 2 || n == 3)
            return true;
        if (n % 2 == 0)
            return false;
        ll d = n - 1;
        int s = 0;
        while (d % 2 == 0) {
            d /= 2;
            s++;
        }
        for (int a : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
            if (a >= n)
                continue;
            ll x = pow(a, d, n);
            if (x == 1 || x == n - 1)
                continue;
            bool ok = false;
            for (int i = 0; i < s - 1; i++) {
                x = mul(x, x, n);
                if (x == n - 1) {
                    ok = true;
                    break;
                }
            }
            if (!ok)
                return false;
        }
        return true;
    }
}

namespace PollardRho {
    using namespace MillerRabin;
    mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
    ll get_factor(ll n) {
        if (n % 2 == 0)
            return 2;
        if (n % 3 == 0)
            return 3;
        if (n % 5 == 0)
            return 5;
        while (true) {
            ll x = rng() % (n - 2) + 2;
            ll y = x;
            ll c = rng() % (n - 1) + 1;
            ll d = 1;
            auto f = [&](ll x) {
                return (mul(x, x, n) + c) % n;
            };
            while (d == 1) {
                x = f(x);
                y = f(f(y));
                d = __gcd(abs(x - y), n);
            }
            if (d != n && is_prime(d)) return d;
        }
    }
    void solve2(ll n, ll* f1, int& count) {
        if (n == 1) return;
        if (is_prime(n)) {
            f1[count++] = n;
            return;
        }
        ll d = get_factor(n);
        solve2(d, f1, count);
        solve2(n / d, f1, count);
    }
}

struct Factor {
    ll prime;
    int exponent;
};

void sort_f1(ll* arr, int n) {
    sort(arr, arr + n);
}

void group_f1(ll* f1, int count, Factor* isg, int& icnt) {
    if (count == 0) return;
    icnt = 0;
    isg[0].prime = f1[0];
    isg[0].exponent = 1;
    for (int i = 1; i < count; i++)
        if (f1[i] == isg[icnt].prime)
            isg[icnt].exponent++;
        else {
            icnt++;
            isg[icnt].prime = f1[i];
            isg[icnt].exponent = 1;
        }
    icnt++;
}

void solve1(const Factor* f1, int cnt, ll* f2, int& dcnt, ll current = 1, int index = 0) {
    if (index == cnt) {
        f2[dcnt++] = current;
        return;
    }
    ll p = f1[index].prime;
    int e = f1[index].exponent;
    for (int i = 0; i <= e; i++) {
        solve1(f1, cnt, f2, dcnt, current, index + 1);
        current *= p;
    }
}

int solve(ll n) {
    if (n == 0)
        return 1;
    ll f1[105];
    int cnt = 0;
    PollardRho::solve2(n, f1, cnt);
    sort_f1(f1, cnt);
    Factor isg[MAX_FACTORS];
    int icnt = 0;
    group_f1(f1, cnt, isg, icnt);
    ll* f2 = new ll[1000005];
    int dcnt = 0;
    solve1(isg, icnt, f2, dcnt);
    ll sum = 0;
    for (int i = 0; i < dcnt; i++)
        if (f2[i] % 4 != 0)
            sum += f2[i];
    delete[] f2;
    sum %= MOD;
    sum = sum * 8 % MOD;
    return sum;
}

int main() {
    int T;
    cin >> T;
    while (T--) {
        ll n;
        cin >> n;
        cout << solve(n) << '\n';
    }
    return 0;
}

Update:

第一次过审的题解不够详细,感谢 Pentiment 大佬提出的宝贵意见。现在已经修改。