题解 P5277 【【模板】多项式开根(加强版)】

ikka

2019-03-30 16:43:31

Solution

**Update** 两处公式挂了换成了图片。 --- ## 前置知识 - 多项式开根 - Cipolla算法 ### 多项式开根 模板在[P5205 【模板】多项式开根](https://www.luogu.org/problemnew/show/P5277) 。 这里不多做讨论。 ### Cipolla算法 求解$x^2\equiv n \pmod p$的算法,由于`ntt`的模是奇素数所以这里只讨论$p$为奇素数的情况。 当方程有解时称$n$在模$p$意义下是二次剩余,否则称$n$在模$p$意义下是非二次剩余。 我们定义**勒让德符号**: ![勒让德符号](https://i.loli.net/2019/03/30/5c9f2e12389c9.png) 有一个**欧拉判别准则**: $$ (\frac{n}{p})\equiv n^{\frac{p-1}{2}}\pmod p $$ ~~口胡~~证明: 1. 当$p|n$时,显然有$0\equiv n^{\frac{p-1}{2}}\pmod p$ 2. 假设$x^2\equiv n \pmod p$,若$x^{p-1}\equiv 1\pmod p$时,根据费马小定理,假设成立,$n$在模$p$意义下是二次剩余。 若$x^{p-1}\equiv -1\pmod p$时,根据费马小定理,假设不成立,$n$在模$p$意义下是二次非剩余。 接下来让我们证明三个结论。 **结论一:$n^2 \equiv (p-n)^2 \pmod p$** 证明展开式子中的$(p-n)^2$即可。 **结论二:有$\frac{p-1}{2}$个数是模$p$意义下的二次非剩余。** 根据结论一,可以得出$0$到$p-1$的平方中有$\frac{p+1}{2}$个不同的数,剩下的数即为模$p$意义下的二次非剩余。 **结论三:$(a+b)^p\equiv a^p+b^p \pmod p$** 二项式定理展开,由于$p$是素数,所以除了首尾两项其他项中的组合数中都有质因子$p$,故同余式成立。 #### Cipolla算法流程 第一步,判断原方程是否有解(利用欧拉判别准则)。 第二步,随机找到一个$a$使得$\omega\equiv(a^2-n)\pmod p$在模$p$意义下是二次非剩余。(根据结论二,期望随机次数为$2$) 第三步,找到一个解$x\equiv(a+\sqrt{\omega})^{\frac{p+1}{2}} \pmod p$。但是$\sqrt{\omega}$显然不存在,我们可以把$\sqrt{\omega}$当做虚数$i$来看,就把问题解决了。 证明: ![Cipolla证明](https://i.loli.net/2019/03/30/5c9f2e1259b12.png) ## Code ``` cpp #include <cstdio> #include <cstring> #include <ctime> #include <cstdlib> #include <algorithm> const int maxn = 400010; const int mod = 998244353; const int g = 3; const int invg = 332748118; const int inv2 = 499122177; int inline pls(int a, int b) { int m = a + b; return m < mod ? m : m - mod; } int inline dec(int a, int b) { int m = a - b; return m + ((m >> 31) & mod); } int inline mul(int a, int b) { return 1ll * a * b % mod; } int inline pow(int a, int b) { int ans = 1; while (b) { if (b & 1) ans = mul(ans, a); a = mul(a, a); b >>= 1; } return ans; } struct cp { int r, i; cp(int x = 0, int y = 0) : r(x), i(y) {} }; cp inline mul(cp a, cp b, int w) { return cp(pls(mul(a.r, b.r), mul(w, mul(a.i, b.i))), pls(mul(a.r, b.i), mul(a.i, b.r))); } int inline pow(cp a, int b, int w) { cp ans(1, 0); while (b) { if (b & 1) ans = mul(ans, a, w); a = mul(a, a, w); b >>= 1; } return ans.r; } int inline cipolla(int x) { srand(time(0)); if (pow(x, (mod - 1) >> 1) == mod - 1) return -1; while (true) { int a = mul(rand(), rand()); int w = dec(mul(a, a), x); if (pow(w, (mod - 1) >> 1) == mod - 1) { return pow(cp(a, 1), (mod + 1) >> 1, w); } } } int a[maxn], b[maxn], r[maxn]; void inline ntt(int *a, int n, int f) { for (int i = 0; i < n; ++i) if (i < r[i]) std::swap(a[i], a[r[i]]); for (int i = 1; i < n; i <<= 1) { int wn = pow(f ? g : invg, (mod - 1) / (i << 1)); for (int *j = a; j < a + n; j += i << 1) { int w = 1; for (int k = 0; k < i; ++k, w = mul(w, wn)) { int x = *(j + k), y = mul(w, *(i + j + k)); *(j + k) = pls(x, y), *(i + j + k) = dec(x, y); } } } if (!f) { int rv = pow(n, mod - 2); for (int *i = a; i < a + n; ++i) *i = mul(*i, rv); } } void inline inv(int *a, int *b, int n) { b[0] = pow(a[0], mod - 2); static int A[maxn], B[maxn], len, lim; for (len = 1; len < n << 1; len <<= 1) { lim = len << 1; memcpy(A, a, len << 2); memcpy(B, b, len << 2); for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0); ntt(A, lim, 1); ntt(B, lim, 1); for (int i = 0; i < lim; ++i) b[i] = dec((B[i] << 1) % mod, mul(A[i], mul(B[i], B[i]))); ntt(b, lim, 0); memset(b + len, 0, len << 2); } memset(A, 0, len << 2); memset(B, 0, len << 2); memset(b + n, 0, (len - n) << 2); } void inline sqrt(int *a, int *b, int n) { int sr = cipolla(a[0]); b[0] = std::min(sr, mod - sr); static int A[maxn], B[maxn], len, lim; for (len = 1; len < n << 1; len <<= 1) { lim = len << 1; memcpy(A, a, len << 2); inv(b, B, len); for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0); ntt(A, lim, 1); ntt(B, lim, 1); for (int i = 0; i < lim; ++i) A[i] = mul(A[i], B[i]); ntt(A, lim, 0); for (int i = 0; i < len; ++i) b[i] = mul(inv2, pls(A[i], b[i])); memset(b + len, 0, len << 2); } memset(A, 0, len << 2); memset(B, 0, len << 2); memset(b + n, 0, (len - n) << 2); } int main() { int n; scanf("%d", &n); for (int *i = a; i < a + n; ++i) scanf("%d", i); sqrt(a, b, n); for (int *i = b; i < b + n; ++i) printf("%d%c", *i, " \n"[i == b + n - 1]); return 0; } ``` ~~抄我的代码会TLE233~~