题解 P5050 【【模板】多项式多点求值】

officeyutong

2019-01-14 20:14:48

Solution

![](https://cdn.luogu.com.cn/upload/pic/48589.png) 一定要及时清空存储多项式数组中不相关的项。 请忽略大量调试代码。 ```cpp #pragma GCC optimize("Ofast") #include <assert.h> #include <algorithm> #include <bitset> #include <cmath> #include <cstring> #include <ctime> #include <iostream> #include <vector> using int_t = int; using std::cin; using std::cout; using std::endl; #ifdef DEBUG #define debug(x) cout << #x << " = " << x << endl; #else #define debug(x) #endif const int_t mod = 998244353; const int_t g = 3; const int_t LARGE = 1 << 20; int_t power(int_t base, int_t index); void transform(int_t *A, int_t size, int_t arg); constexpr int_t upper2n(int_t x); void poly_mul(int_t *A, int_t *B, int_t size); void poly_inv(int_t *A, int_t *inv, int_t n); void poly_div(int_t *A, int_t n, int_t *B, int_t m, int_t *Q, int_t *R); void poly_dc_mul(int_t *base, int_t n); void poly_evaluation(int_t *A, const int_t *seq, int_t *result, int_t left, int_t right); template <class T> std::ostream &operator<<(std::ostream &os, std::vector<T> &x) { os << "{ "; for (int_t v : x) os << v << " "; os << "} "; return os; } inline int_t bitRev(int_t x, int_t size2) { int_t res = 0; for (int_t i = 1; i < size2; i++) { res |= (x & 1); res <<= 1; x >>= 1; } return res | (x & 1); }; int bitRevs[20][LARGE + 1]; int main() { #ifdef TIME auto begin = clock(); #endif for (int_t i = 1; (1 << i) < LARGE; i++) { for (int_t j = 0; j < LARGE; j++) { bitRevs[i][j] = bitRev(j, i); } } // { // static int_t A[LARGE], B[LARGE]; // int_t n, m; // cin >> n >> m; // for (int_t i = 0; i <= n; i++) cin >> A[i]; // for (int_t i = 0; i <= m; i++) cin >> B[i]; // int_t size = upper2n(n + m + 1); // transform(A, size); // for (int_t i = 0; i < size; i++) cout << A[i] << " "; // cout << endl; // transform(B, size); // for (int_t i = 0; i < size; i++) cout << B[i] << " "; // cout << endl; // for (int_t i = 0; i < size; i++) A[i] = 1ll * A[i] * B[i] % mod; // transform<-1>(A, size); // for (int_t i = 0; i <= n + m; i++) // cout << 1ll * A[i] * power(size, -1) % mod << " "; // return 0; // } int n, m; static int_t A[LARGE], seq[LARGE], result[LARGE]; scanf("%d%d", &n, &m); for (int_t i = 0; i <= n; i++) scanf("%d", &A[i]); for (int_t i = 0; i < m; i++) scanf("%d", &seq[i]); poly_evaluation(A, seq, result, 0, std::max(n + 1, m)); for (int_t i = 0; i < m; i++) { printf("%d\n", (int)result[i]); } #ifdef TIME auto end = clock(); std::cerr << (end - begin) / (1.0 * CLOCKS_PER_SEC) << endl; #endif return 0; } //计算n-1次多项式A代入x的值 int_t sub(int_t *A, int_t n, int_t x) { int_t result = 0; for (int_t i = 0; i < n; i++) { result = ((int64_t)result * x + A[n - i - 1]) % mod; } return result; } //[left,right]闭区间 //下标从0开始 void poly_evaluation(int_t *A, const int_t *seq, int_t *result, int_t left, int_t right) { if (right - left <= 2000) { for (int_t i = left; i <= right; i++) { result[i] = sub(A + left, right - left + 1, seq[i]); } return; } if (left == right) { result[left] = A[left]; return; } static int_t P0[LARGE], P1[LARGE], Q[LARGE], R0[LARGE], R1[LARGE]; int_t mid = (right + left) / 2; int_t size = upper2n(right - left + 2); for (int_t i = 0; i < size; i++) { P0[i] = P1[i] = 0; } int_t size1 = upper2n((mid - left + 1) * 2 + 1), size2 = upper2n((right - mid) * 2 + 1); for (int_t i = 0; i < std::max(size1, size2); i += 2) { P0[i] = P1[i] = 1; } for (int_t i = 0; i < mid - left + 1; i++) { P0[i * 2] = -seq[i + left]; P0[i * 2 + 1] = 1; } for (int_t i = 0; i < right - mid; i++) { P1[i * 2] = -seq[i + mid + 1]; P1[i * 2 + 1] = 1; } #ifdef DEBUG cout << "processing at interval " << left << "," << right << endl; cout << "mid = " << mid << endl; // cout << "size0=" << size0 << endl; cout << "size1=" << size1 << ",size2=" << size2 << endl; // cout << "size=" << size << endl; cout << "A "; for (int_t i = left; i <= right; i++) cout << A[i] << " "; cout << endl; cout << "P0 "; for (int_t i = 0; i < size1; i++) cout << P0[i] << " "; cout << endl; cout << "P1 "; for (int_t i = 0; i < size2; i++) cout << P1[i] << " "; cout << endl; #endif poly_dc_mul(P0, size1); poly_dc_mul(P1, size2); poly_div(A + left, right - left, P0, mid - left + 1, Q, R0); poly_div(A + left, right - left, P1, right - mid, Q, R1); for (int_t i = mid - left + 1; i < size1; i++) R0[i] = 0; for (int_t i = right - mid; i < size2; i++) R1[i] = 0; #ifdef DEBUG cout << "P0 "; for (int_t i = 0; i <= mid - left + 1; i++) cout << P0[i] << " "; cout << endl; cout << "P1 "; for (int_t i = 0; i <= right - mid; i++) cout << P1[i] << " "; cout << endl; cout << "R0 "; for (int_t i = 0; i < size1; i++) cout << R0[i] << " "; cout << endl; cout << "R1 "; for (int_t i = 0; i < size2; i++) cout << R1[i] << " "; cout << endl; #endif for (int_t i = 0; i < mid - left + 1; i++) { A[i + left] = R0[i]; } for (int_t i = 0; i < right - mid; i++) { A[i + mid + 1] = R1[i]; } #ifdef DEBUG cout << "A modfied "; for (int_t i = left; i <= right; i++) cout << A[i] << " "; cout << endl; #endif #ifdef DEBUG cout << endl << endl; #endif poly_evaluation(A, seq, result, left, mid); poly_evaluation(A, seq, result, mid + 1, right); } //计算分治乘法,n是项数,必须是2的幂次 //初始时base中存放着若干个次数为1的多项式,每个占了两个位置 inline void poly_dc_mul(int_t *base, int_t n) { #ifdef DEBUG // cout << "calc " << endl; // for (int_t i = 0; i < n; i++) // cout << base[i] << " "; // cout << "=" // << " "; assert((1 << (int_t)log2(n)) == n); #endif //枚举块大小 for (int_t i = 4; i <= n; i *= 2) { //每块前后两部分相乘 for (int_t j = 0; j < n; j += i) { static int_t A[LARGE], B[LARGE]; std::fill(A, A + i, 0); std::fill(B, B + i, 0); std::copy(base + j, base + j + i / 2, A); std::copy(base + j + i / 2, base + j + i, B); poly_mul(A, B, i); std::copy(A, A + i, base + j); } } #ifdef DEBUG // for (int_t i = 0; i < n; i++) // cout << base[i] << " "; // cout << endl // << endl; #endif } //计算n次多项式A除以m次多项式B的商和余数。 //需要确保高次项干净。 inline void poly_div(int_t *A, int_t n, int_t *B, int_t m, int_t *Q, int_t *R) { const int_t size = upper2n(n + m + 1); // #ifdef DEBUG // cout<<"moding " // #endif static int_t Ax[LARGE], Bx[LARGE], Qx[LARGE], Binv[LARGE]; for (int_t i = 0; i <= n; i++) Ax[i] = A[n - i]; for (int_t i = 0; i <= m; i++) Bx[i] = B[m - i]; const int_t len = size - (n - m + 1); memset(&Ax[n - m + 1], 0, sizeof(int_t) * len); memset(&Bx[n - m + 1], 0, sizeof(int_t) * len); poly_inv(Bx, Binv, n - m + 1); memcpy(Bx, Binv, sizeof(int_t) * (n - m + 1)); memset(&Ax[n - m + 1], 0, sizeof(int_t) * len); poly_mul(Ax, Bx, size); for (int_t i = 0; i <= n - m; i++) { Qx[i] = Q[i] = Ax[n - m - i]; } memset(&Qx[n - m + 1], 0, sizeof(int_t) * len); poly_mul(Qx, B, size); for (int_t i = 0; i <= m - 1; i++) R[i] = ((int64_t)A[i] - Qx[i] + 2 * mod) % mod; } inline void poly_mul(int_t *A, int_t *Bx, int_t size) { static int_t B[LARGE + 1]; memcpy(B, Bx, sizeof(int_t) * size); transform(A, size, 1); transform(B, size, 1); for (int_t i = 0; i < size; i++) A[i] = (int64_t)A[i] * B[i] % mod; transform(A, size, -1); const int_t inv = power(size, -1); for (int_t i = 0; i < size; i++) A[i] = (int64_t)A[i] * inv % mod; } void poly_inv(int_t *A, int_t *inv, int_t n) { static int_t Ax[LARGE]; if (n == 1) { inv[0] = power(A[0], -1); // for (int_t i = 1; i < upper2n(3); i++) inv[i] = 0; return; } poly_inv(A, inv, n / 2 + n % 2); // C(x)<-2B(x)-A(x)B(x)^2 int_t size = upper2n(3 * n + 1); for (int_t i = 0; i < size; i++) { if (i < n) Ax[i] = A[i]; else Ax[i] = 0; } for (int_t i = n; i < size; i++) inv[i] = 0; // for(int_t i=) transform(Ax, size, 1); transform(inv, size, 1); for (int_t i = 0; i < size; i++) inv[i] = ((int64_t)2 * inv[i] - (int64_t)Ax[i] * inv[i] % mod * inv[i] % mod + 2 * mod) % mod; transform(inv, size, -1); const int_t size_inv = power(size, -1); for (int_t i = 0; i < size; i++) { if (i < n) inv[i] = (int64_t)inv[i] * size_inv % mod; else inv[i] = 0; } } inline constexpr int_t upper2n(int_t x) { int_t res = 1; while (res < x) res *= 2; return res; } inline void transform(int_t *A, int_t size, int_t arg) { const int_t size2 = log2(size); for (int_t i = 0; i < size; i++) { int_t x = bitRevs[size2][i]; if (x > i) std::swap(A[x], A[i]); } for (int_t i = 2; i <= size; i *= 2) { int_t mr = power(g, arg * (mod - 1) / i); int_t *p1 = A; int_t *p2 = p1 + i / 2; int_t counter = 0; int_t curr = 1; while (p2 < A + size) { int_t u = *p1, t = *p2 * (int64_t)curr % mod; *p1 = (u + t) % mod; *p2 = ((int64_t)u - t + (int64_t)2 * mod) % mod; curr = (int64_t)curr * mr % mod; counter += 1; p1++; p2++; if (counter == (i >> 1)) { counter = 0; p1 += (i >> 1); p2 += (i >> 1); curr = 1; } } } } int_t power(int_t base, int_t index) { const int_t phi = mod - 1; base = (base % mod + mod) % mod; index = (index % phi + phi) % phi; int_t result = 1; while (index) { if (index & 1) result = (int64_t)result * base % mod; base = (int64_t)base * base % mod; index >>= 1; } return result; } ``` ------------ 为什么洛谷还不支持多行Latex啊... AxMath导出都没法用了..