题解 P5050 【【模板】多项式多点求值】
officeyutong
2019-01-14 20:14:48
![](https://cdn.luogu.com.cn/upload/pic/48589.png)
一定要及时清空存储多项式的数组中不相关的项（即超出当前有效长度的部分），否则上一次计算残留的数据会混入后续的 NTT，导致结果错误。
请忽略大量调试代码。
```cpp
#pragma GCC optimize("Ofast")
#include <assert.h>
#include <algorithm>
#include <bitset>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>
#include <vector>
// Integer type used for coefficients, sizes and indices throughout the file.
using int_t = int;
using std::cin;
using std::cout;
using std::endl;
// debug(x) prints "x = <value>" only when compiled with -DDEBUG; otherwise it
// expands to nothing.
#ifdef DEBUG
#define debug(x) cout << #x << " = " << x << endl;
#else
#define debug(x)
#endif
// NTT-friendly prime modulus (998244353 = 119 * 2^23 + 1).
const int_t mod = 998244353;
// Generator used by transform() to build the roots of unity g^((mod-1)/len).
const int_t g = 3;
// Capacity of every static work buffer (2^20 entries).
const int_t LARGE = 1 << 20;
// Forward declarations; the definitions follow main().
int_t power(int_t base, int_t index);
void transform(int_t *A, int_t size, int_t arg);
constexpr int_t upper2n(int_t x);
void poly_mul(int_t *A, int_t *B, int_t size);
void poly_inv(int_t *A, int_t *inv, int_t n);
void poly_div(int_t *A, int_t n, int_t *B, int_t m, int_t *Q, int_t *R);
void poly_dc_mul(int_t *base, int_t n);
void poly_evaluation(int_t *A, const int_t *seq, int_t *result, int_t left,
int_t right);
// Debug helper: prints a vector as "{ a b c } ".
// The vector is taken by const reference so const vectors and temporaries can
// be printed as well, and each element is printed with its own type instead of
// being forced through int_t (which truncated wider element types).
template <class T>
std::ostream &operator<<(std::ostream &os, const std::vector<T> &x) {
    os << "{ ";
    for (const auto &v : x) os << v << " ";
    os << "} ";
    return os;
}
// Returns the lowest `size2` bits of x in reversed order; higher bits of x
// are ignored. For size2 <= 1 the result is just the lowest bit of x.
// Example: bitRev(0b001, 3) == 0b100.
inline int_t bitRev(int_t x, int_t size2) {
    int_t reversed = x & 1;
    for (int_t k = 1; k < size2; ++k) {
        x >>= 1;
        reversed = (reversed << 1) | (x & 1);
    }
    return reversed;
}
int bitRevs[20][LARGE + 1];
int main() {
#ifdef TIME
auto begin = clock();
#endif
for (int_t i = 1; (1 << i) < LARGE; i++) {
for (int_t j = 0; j < LARGE; j++) {
bitRevs[i][j] = bitRev(j, i);
}
}
// {
// static int_t A[LARGE], B[LARGE];
// int_t n, m;
// cin >> n >> m;
// for (int_t i = 0; i <= n; i++) cin >> A[i];
// for (int_t i = 0; i <= m; i++) cin >> B[i];
// int_t size = upper2n(n + m + 1);
// transform(A, size);
// for (int_t i = 0; i < size; i++) cout << A[i] << " ";
// cout << endl;
// transform(B, size);
// for (int_t i = 0; i < size; i++) cout << B[i] << " ";
// cout << endl;
// for (int_t i = 0; i < size; i++) A[i] = 1ll * A[i] * B[i] % mod;
// transform<-1>(A, size);
// for (int_t i = 0; i <= n + m; i++)
// cout << 1ll * A[i] * power(size, -1) % mod << " ";
// return 0;
// }
int n, m;
static int_t A[LARGE], seq[LARGE], result[LARGE];
scanf("%d%d", &n, &m);
for (int_t i = 0; i <= n; i++) scanf("%d", &A[i]);
for (int_t i = 0; i < m; i++) scanf("%d", &seq[i]);
poly_evaluation(A, seq, result, 0, std::max(n + 1, m));
for (int_t i = 0; i < m; i++) {
printf("%d\n", (int)result[i]);
}
#ifdef TIME
auto end = clock();
std::cerr << (end - begin) / (1.0 * CLOCKS_PER_SEC) << endl;
#endif
return 0;
}
// Evaluates the polynomial with n coefficients A[0..n-1] (degree n-1, lowest
// degree first) at x with Horner's rule, reducing modulo `mod` at each step.
// Returns 0 when n <= 0.
int_t sub(int_t *A, int_t n, int_t x) {
    int_t acc = 0;
    for (int_t k = n - 1; k >= 0; --k)
        acc = ((int64_t)acc * x + A[k]) % mod;
    return acc;
}
// Multipoint evaluation by recursive remainder splitting.
// On entry A[left..right] holds a polynomial with (right - left + 1)
// coefficients, lowest degree first. On exit result[i] holds that polynomial
// evaluated at seq[i] (mod `mod`) for every i in the closed, 0-based range
// [left, right]. A[left..right] is overwritten with intermediate remainders.
void poly_evaluation(int_t *A, const int_t *seq, int_t *result, int_t left,
                     int_t right) {
    // Below this interval width, direct Horner evaluation is used instead of
    // another level of NTT-based division. This also covers left == right.
    constexpr int_t BRUTE_FORCE_THRESHOLD = 2000;
    if (right - left <= BRUTE_FORCE_THRESHOLD) {
        for (int_t i = left; i <= right; i++) {
            result[i] = sub(A + left, right - left + 1, seq[i]);
        }
        return;
    }
    static int_t P0[LARGE], P1[LARGE], Q[LARGE], R0[LARGE], R1[LARGE];
    const int_t mid = (right + left) / 2;
    const int_t cnt0 = mid - left + 1;  // points in [left, mid]
    const int_t cnt1 = right - mid;     // points in [mid + 1, right], <= cnt0
    // Lengths handed to poly_dc_mul: two slots per linear factor, rounded up
    // to a power of two.
    const int_t size1 = upper2n(cnt0 * 2 + 1), size2 = upper2n(cnt1 * 2 + 1);
    // poly_div(A + left, right - left, P, cnt, ...) transforms P with length
    // upper2n((right - left) + cnt + 1), so every slot up to that length must
    // be clean.
    // The previous version cleared only upper2n(right - left + 2) slots. It
    // also wrote padding 1s into P1 up to size1 instead of size2. For interval
    // lengths of the form 2^k - 1, those stray entries wrapped into the low
    // coefficients of the remainder. It further left odd slots of P0 in
    // [upper2n(right - left + 2), size1) uncleared.
    const int_t clean = std::max(upper2n((right - left) + cnt0 + 1), size1);
    std::fill(P0, P0 + clean, 0);
    std::fill(P1, P1 + clean, 0);
    // Pad each buffer with constant factors 1 (coefficient pair {1, 0}) ...
    for (int_t i = 0; i < size1; i += 2) P0[i] = 1;
    for (int_t i = 0; i < size2; i += 2) P1[i] = 1;
    // ... then store the linear factors (x - seq[j]) as pairs {-seq[j], 1}.
    for (int_t i = 0; i < cnt0; i++) {
        P0[i * 2] = -seq[i + left];
        P0[i * 2 + 1] = 1;
    }
    for (int_t i = 0; i < cnt1; i++) {
        P1[i * 2] = -seq[i + mid + 1];
        P1[i * 2 + 1] = 1;
    }
#ifdef DEBUG
    cout << "processing at interval " << left << "," << right << endl;
#endif
    // P0 / P1 become the products of the linear factors of each half.
    poly_dc_mul(P0, size1);
    poly_dc_mul(P1, size2);
    // A mod P0 (resp. P1) takes the same values as A at the points of that
    // half, and has only cnt0 (resp. cnt1) coefficients.
    poly_div(A + left, right - left, P0, cnt0, Q, R0);
    poly_div(A + left, right - left, P1, cnt1, Q, R1);
    std::copy(R0, R0 + cnt0, A + left);
    std::copy(R1, R1 + cnt1, A + mid + 1);
    poly_evaluation(A, seq, result, left, mid);
    poly_evaluation(A, seq, result, mid + 1, right);
}
// Divide-and-conquer product of degree-1 polynomials.
// On entry base[0..n) holds n/2 degree-1 polynomials; polynomial t occupies
// base[2t] and base[2t + 1]. Each round multiplies the two halves of every
// block with poly_mul and writes the result back over the block. The block
// width doubles from 4 up to n. n must be a power of two.
inline void poly_dc_mul(int_t *base, int_t n) {
#ifdef DEBUG
    assert((1 << (int_t)log2(n)) == n);
#endif
    static int_t lhs[LARGE], rhs[LARGE];
    for (int_t width = 4; width <= n; width <<= 1) {
        const int_t half = width >> 1;
        for (int_t *block = base; block < base + n; block += width) {
            // Zero-pad both halves to the full block width before multiplying.
            std::fill(lhs, lhs + width, 0);
            std::fill(rhs, rhs + width, 0);
            std::copy(block, block + half, lhs);
            std::copy(block + half, block + width, rhs);
            poly_mul(lhs, rhs, width);
            std::copy(lhs, lhs + width, block);
        }
    }
}
// Polynomial division modulo `mod`: A (degree n, coefficients A[0..n]) divided
// by B (degree m, coefficients B[0..m]) yields
//   Q[0..n-m] - the quotient (written only when n >= m), and
//   R[0..m-1] - the remainder, normalised to [0, mod).
// Only A[0..n] and B[0..m] are read, so the caller's buffers may contain
// anything past those coefficients. B[m] must be non-zero modulo `mod`.
inline void poly_div(int_t *A, int_t n, int_t *B, int_t m, int_t *Q, int_t *R) {
    static int_t Ax[LARGE], Bx[LARGE], Qx[LARGE], Bc[LARGE], Binv[LARGE];
    if (n < m) {
        // deg A < deg B: the quotient is 0 and the remainder is A itself.
        // The previous version called poly_inv with length <= 0 here and
        // recursed forever.
        for (int_t i = 0; i < m; i++)
            R[i] = i <= n ? (int_t)(((int64_t)A[i] % mod + mod) % mod) : 0;
        return;
    }
    const int_t qlen = n - m + 1;  // number of quotient coefficients
    // The transform length must hold Q*B (n + 1 terms) and the full product
    // of two length-qlen series (2*qlen - 1 terms) without cyclic wrap-around.
    const int_t size = upper2n(std::max(n + m + 1, 2 * qlen - 1));
    // With rev(P)(x) = x^deg(P) * P(1/x):
    //   rev(Q) = rev(A) * rev(B)^(-1) mod x^qlen.
    // Both buffers are cleared completely so no stale data from earlier
    // calls reaches poly_inv or poly_mul.
    std::fill(Ax, Ax + size, 0);
    std::fill(Bx, Bx + size, 0);
    for (int_t i = 0; i < qlen; i++) Ax[i] = A[n - i];
    for (int_t i = 0; i < qlen && i <= m; i++) Bx[i] = B[m - i];
    poly_inv(Bx, Binv, qlen);
    std::fill(Bx, Bx + size, 0);
    std::copy(Binv, Binv + qlen, Bx);
    poly_mul(Ax, Bx, size);
    // Un-reverse the low qlen coefficients to obtain Q.
    std::fill(Qx, Qx + size, 0);
    for (int_t i = 0; i < qlen; i++) Qx[i] = Q[i] = Ax[qlen - 1 - i];
    // R = A - Q*B; only the low m coefficients are needed. B is copied into a
    // zero-padded buffer so that data after B[m] in the caller's array cannot
    // leak into Q*B.
    std::fill(Bc, Bc + size, 0);
    std::copy(B, B + m + 1, Bc);
    poly_mul(Qx, Bc, size);
    for (int_t i = 0; i < m; i++)
        R[i] = ((int64_t)A[i] - Qx[i] + 2 * mod) % mod;
}
// Cyclic convolution of length `size` modulo `mod`: A <- A * Bx mod (x^size - 1).
// size must be a power of two. A is overwritten in place; Bx[0..size) is only
// read, because it is copied into a scratch buffer before being transformed.
inline void poly_mul(int_t *A, int_t *Bx, int_t size) {
    static int_t B[LARGE + 1];
    std::copy(Bx, Bx + size, B);
    transform(A, size, 1);
    transform(B, size, 1);
    // Point-wise product in the transformed domain.
    for (int_t *a = A, *b = B; a != A + size; ++a, ++b)
        *a = (int64_t)*a * *b % mod;
    transform(A, size, -1);
    // The inverse transform is unscaled; divide by size here.
    const int_t scale = power(size, -1);
    std::for_each(A, A + size, [scale](int_t &coef) {
        coef = (int64_t)coef * scale % mod;
    });
}
// Power-series inverse modulo `mod`: computes inv[0..n) with
// A * inv == 1 (mod x^n).
// It first computes the inverse modulo x^ceil(n/2) recursively, then applies
// one Newton step:
//   inv <- 2*inv - A*inv^2  (mod x^n)
// Reads A[0..n). A[0] must be non-zero modulo `mod`. inv is also used as NTT
// scratch space, so it must have room for upper2n(3n + 1) entries; only
// inv[0..n) is meaningful on return. Does nothing for n <= 0; the previous
// version recursed forever in that case.
void poly_inv(int_t *A, int_t *inv, int_t n) {
    static int_t Ax[LARGE];
    if (n <= 0) return;
    if (n == 1) {
        inv[0] = power(A[0], -1);
        return;
    }
    const int_t half = (n + 1) / 2;  // == n / 2 + n % 2
    poly_inv(A, inv, half);
    // Length for A (< n terms) times inv^2 (< 2n terms) without wrap-around.
    const int_t size = upper2n(3 * n + 1);
    for (int_t i = 0; i < size; i++) Ax[i] = i < n ? A[i] : 0;
    // Keep only the half-length inverse. The base case writes just inv[0], so
    // inv[half..n) may hold stale data; Newton's step tolerates that, but
    // clearing it makes the input well defined.
    std::fill(inv + half, inv + size, 0);
    transform(Ax, size, 1);
    transform(inv, size, 1);
    // C(x) <- 2B(x) - A(x)B(x)^2, evaluated point-wise.
    for (int_t i = 0; i < size; i++)
        inv[i] = ((int64_t)2 * inv[i] -
                  (int64_t)Ax[i] * inv[i] % mod * inv[i] % mod + 2 * mod) %
                 mod;
    transform(inv, size, -1);
    // Undo the unscaled inverse transform and truncate to n terms.
    const int_t size_inv = power(size, -1);
    for (int_t i = 0; i < size; i++)
        inv[i] = i < n ? (int_t)((int64_t)inv[i] * size_inv % mod) : 0;
}
// Smallest power of two that is >= x; returns 1 for any x <= 1.
// Recurrence: for x >= 2, the answer is twice the answer for ceil(x / 2).
inline constexpr int_t upper2n(int_t x) {
    return x <= 1 ? 1 : 2 * upper2n((x + 1) / 2);
}
// In-place iterative number-theoretic transform of length `size` modulo `mod`.
// size must be a power of two with a table row in bitRevs.
// arg = 1 uses the roots g^((mod-1)/len); arg = -1 uses their inverses. The
// arg = -1 direction is not scaled by 1/size; callers do that.
inline void transform(int_t *A, int_t size, int_t arg) {
    // Bit-reversal permutation using the precomputed row for log2(size).
    const int *rev = bitRevs[(int_t)log2(size)];
    for (int_t i = 0; i < size; i++) {
        if (rev[i] > i) std::swap(A[rev[i]], A[i]);
    }
    // Butterfly passes: merge pairs of blocks of length len/2 into blocks of
    // length len.
    for (int_t len = 2; len <= size; len *= 2) {
        const int_t half = len / 2;
        const int_t step = power(g, arg * (mod - 1) / len);
        for (int_t start = 0; start < size; start += len) {
            int_t w = 1;
            for (int_t k = start; k < start + half; k++) {
                const int_t u = A[k];
                const int_t t = A[k + half] * (int64_t)w % mod;
                A[k] = (u + t) % mod;
                A[k + half] = ((int64_t)u - t + (int64_t)2 * mod) % mod;
                w = (int64_t)w * step % mod;
            }
        }
    }
}
// Modular exponentiation: returns base^index mod `mod`.
// base is first reduced into [0, mod). index is reduced modulo mod - 1 and
// may therefore be negative; power(a, -1) is the modular inverse of a when a
// is not divisible by mod. Because of that reduction, power(0, k) returns 1
// whenever k is a multiple of mod - 1.
int_t power(int_t base, int_t index) {
    const int_t phi = mod - 1;
    int64_t b = ((int64_t)base % mod + mod) % mod;
    int64_t e = ((int64_t)index % phi + phi) % phi;
    int64_t acc = 1;
    for (; e > 0; e >>= 1, b = b * b % mod) {
        if (e & 1) acc = acc * b % mod;
    }
    return (int_t)acc;
}
```
------------
为什么洛谷还不支持多行 LaTeX 啊……
AxMath 导出的公式都没法用了……