**Update**
两处公式挂了换成了图片。
---
## 前置知识
- 多项式开根
- Cipolla算法
### 多项式开根
模板在[P5205 【模板】多项式开根](https://www.luogu.org/problemnew/show/P5277) 。
这里不多做讨论。
### Cipolla算法
求解$x^2\equiv n \pmod p$的算法,由于`ntt`的模是奇素数所以这里只讨论$p$为奇素数的情况。
当方程有解时称$n$在模$p$意义下是二次剩余,否则称$n$在模$p$意义下是非二次剩余。
我们定义**勒让德符号**:
![勒让德符号](https://i.loli.net/2019/03/30/5c9f2e12389c9.png)
有一个**欧拉判别准则**:
$$
(\frac{n}{p})\equiv n^{\frac{p-1}{2}}\pmod p
$$
~~口胡~~证明:
1. 当$p|n$时,显然有$0\equiv n^{\frac{p-1}{2}}\pmod p$
2. 假设$x^2\equiv n \pmod p$,若$x^{p-1}\equiv 1\pmod p$时,根据费马小定理,假设成立,$n$在模$p$意义下是二次剩余。
若$x^{p-1}\equiv -1\pmod p$时,根据费马小定理,假设不成立,$n$在模$p$意义下是二次非剩余。
接下来让我们证明三个结论。
**结论一:$n^2 \equiv (p-n)^2 \pmod p$**
证明展开式子中的$(p-n)^2$即可。
**结论二:有$\frac{p-1}{2}$个数是模$p$意义下的二次非剩余。**
根据结论一,可以得出$0$到$p-1$的平方中有$\frac{p+1}{2}$个不同的数,剩下的数即为模$p$意义下的二次非剩余。
**结论三:$(a+b)^p\equiv a^p+b^p \pmod p$**
二项式定理展开,由于$p$是素数,所以除了首尾两项其他项中的组合数中都有质因子$p$,故同余式成立。
#### Cipolla算法流程
第一步,判断原方程是否有解(利用欧拉判别准则)。
第二步,随机找到一个$a$使得$\omega\equiv(a^2-n)\pmod p$在模$p$意义下是二次非剩余。(根据结论二,期望随机次数为$2$)
第三步,找到一个解$x\equiv(a+\sqrt{\omega})^{\frac{p+1}{2}} \pmod p$。但是$\sqrt{\omega}$显然不存在,我们可以把$\sqrt{\omega}$当做虚数$i$来看,就把问题解决了。
证明:
![Cipolla证明](https://i.loli.net/2019/03/30/5c9f2e1259b12.png)
## Code
``` cpp
#include <cstdio>
#include <cstring>
#include <ctime>
#include <cstdlib>
#include <algorithm>
const int maxn = 400010;
const int mod = 998244353;
const int g = 3;
const int invg = 332748118;
const int inv2 = 499122177;
int inline pls(int a, int b) { int m = a + b; return m < mod ? m : m - mod; }
int inline dec(int a, int b) { int m = a - b; return m + ((m >> 31) & mod); }
int inline mul(int a, int b) { return 1ll * a * b % mod; }
int inline pow(int a, int b) {
int ans = 1;
while (b) {
if (b & 1) ans = mul(ans, a);
a = mul(a, a);
b >>= 1;
}
return ans;
}
struct cp {
int r, i;
cp(int x = 0, int y = 0) : r(x), i(y) {}
};
cp inline mul(cp a, cp b, int w) {
return cp(pls(mul(a.r, b.r), mul(w, mul(a.i, b.i))), pls(mul(a.r, b.i), mul(a.i, b.r)));
}
int inline pow(cp a, int b, int w) {
cp ans(1, 0);
while (b) {
if (b & 1) ans = mul(ans, a, w);
a = mul(a, a, w);
b >>= 1;
}
return ans.r;
}
int inline cipolla(int x) {
srand(time(0));
if (pow(x, (mod - 1) >> 1) == mod - 1) return -1;
while (true) {
int a = mul(rand(), rand());
int w = dec(mul(a, a), x);
if (pow(w, (mod - 1) >> 1) == mod - 1) {
return pow(cp(a, 1), (mod + 1) >> 1, w);
}
}
}
int a[maxn], b[maxn], r[maxn];
void inline ntt(int *a, int n, int f) {
for (int i = 0; i < n; ++i) if (i < r[i]) std::swap(a[i], a[r[i]]);
for (int i = 1; i < n; i <<= 1) {
int wn = pow(f ? g : invg, (mod - 1) / (i << 1));
for (int *j = a; j < a + n; j += i << 1) {
int w = 1;
for (int k = 0; k < i; ++k, w = mul(w, wn)) {
int x = *(j + k), y = mul(w, *(i + j + k));
*(j + k) = pls(x, y), *(i + j + k) = dec(x, y);
}
}
}
if (!f) {
int rv = pow(n, mod - 2);
for (int *i = a; i < a + n; ++i) *i = mul(*i, rv);
}
}
void inline inv(int *a, int *b, int n) {
b[0] = pow(a[0], mod - 2);
static int A[maxn], B[maxn], len, lim;
for (len = 1; len < n << 1; len <<= 1) {
lim = len << 1;
memcpy(A, a, len << 2); memcpy(B, b, len << 2);
for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0);
ntt(A, lim, 1); ntt(B, lim, 1);
for (int i = 0; i < lim; ++i) b[i] = dec((B[i] << 1) % mod, mul(A[i], mul(B[i], B[i])));
ntt(b, lim, 0);
memset(b + len, 0, len << 2);
}
memset(A, 0, len << 2); memset(B, 0, len << 2);
memset(b + n, 0, (len - n) << 2);
}
void inline sqrt(int *a, int *b, int n) {
int sr = cipolla(a[0]);
b[0] = std::min(sr, mod - sr);
static int A[maxn], B[maxn], len, lim;
for (len = 1; len < n << 1; len <<= 1) {
lim = len << 1;
memcpy(A, a, len << 2);
inv(b, B, len);
for (int i = 1; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? len : 0);
ntt(A, lim, 1); ntt(B, lim, 1);
for (int i = 0; i < lim; ++i) A[i] = mul(A[i], B[i]);
ntt(A, lim, 0);
for (int i = 0; i < len; ++i) b[i] = mul(inv2, pls(A[i], b[i]));
memset(b + len, 0, len << 2);
}
memset(A, 0, len << 2); memset(B, 0, len << 2);
memset(b + n, 0, (len - n) << 2);
}
int main() {
int n;
scanf("%d", &n);
for (int *i = a; i < a + n; ++i) scanf("%d", i);
sqrt(a, b, n);
for (int *i = b; i < b + n; ++i) printf("%d%c", *i, " \n"[i == b + n - 1]);
return 0;
}
```
~~抄我的代码会TLE233~~