An easier method for polynomial composition: no transposition theorem needed!
Petit_Souris · · 算法·理论
Just some random thoughts arising from a simple problem — though it seems I'm reinventing the wheel once again :(
If there are any mistakes in this article, please don't hesitate to point them out in the comments.
Introduction
Last night, a friend asked me a question about polynomials (his profile picture has been pixelated for privacy):
(Translation: Ciallo! Is there any solution for
Unfortunately, I couldn't figure out a solution at first; I even thought the problem was unsolvable:
(Translation: It's actually unsolvable, I guess — isn't it?)
But then I suddenly realized that the latest algorithm for polynomial composition performs a similar process on
So it might be solvable in
Why not give it a try?
Solution
We need to get coefficients for every possible
Recall how we handle linear homogeneous recurrences with polynomials: we transform the problem into extracting a certain coefficient of a rational function (a fraction of two polynomials), i.e. computing $[x^n]\dfrac{P(x)}{Q(x)}$:
For the bivariate polynomial with
After this iteration, we only need to keep about half of the coefficients, depending on the parity of
Consider the time complexity of the algorithm: after each iteration, although the degree of
We handled
Polynomial Composition
Now let us move on to polynomial composition. Suppose we have two polynomials
We need to assign a coefficient for each single
Thus, we are actually multiplying the whole expression by
The following steps are the same as in the previous part. Moreover, we need to store every coefficient from
Suppose we are handling
Can we avoid operations on
For the first iteration, since we need the coefficient of
As mentioned in the previous part,
Thus, polynomial composition modulo $x^{n+1}$ can be computed in $O(n\log^2 n)$ time without relying on the transposition theorem.
Code
:::success[P10249 (n = 2e5) runtime = 11s on luogu]
#include <bits/stdc++.h>
using ll = long long;
using ld = long double;
using ull = unsigned long long;
using namespace std;
// Shorthand used throughout the file: Ve<T> is vector<T>.
template <class T>
using Ve = vector<T>;
// ALL(v): full iterator range of a container, e.g. reverse(ALL(v)).
#define ALL(v) (v).begin(), (v).end()
#define pii pair<ll, ll>
// rep / per: inclusive ascending / descending int loops over [a, b].
#define rep(i, a, b) for(int i = (a); i <= (b); ++i)
#define per(i, a, b) for(int i = (a); i >= (b); --i)
#define pb push_back
// Start marker for main's static-memory estimate (&Med - &Mbe).
// NOTE(review): this relies on the compiler laying globals out in declaration
// order, which is not guaranteed.
bool Mbe;
// Reads one decimal integer from stdin, skipping every non-digit character
// before it; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit, instead of looping
// forever (the character is kept as an int so that EOF is distinguishable).
long long read() {
    long long x = 0, f = 1;
    int ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == EOF) return 0;
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') x = x * 10 + (ch - '0'), ch = getchar();
    return x * f;
}
// Writes x in decimal to stdout (no separator). The magnitude is computed in
// unsigned arithmetic so that LLONG_MIN is printed correctly; negating it as a
// signed value would be undefined behaviour.
void write(long long x) {
    unsigned long long mag = static_cast<unsigned long long>(x);
    if(x < 0) {
        putchar('-');
        mag = 0ULL - mag;
    }
    char buf[20];  // 2^64 - 1 has 20 digits
    int len = 0;
    do {
        buf[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while(mag);
    while(len) putchar(buf[--len]);
}
// N scales the static NTT scratch tables (rev / tw are declared with N << 5 slots).
constexpr ll N = 2e5 + 9;
// Mod = 998244353 = 119 * 2^23 + 1 is NTT-friendly; G = 3 is the root used by
// the forward transform and iG = (Mod + 1) / 3 is the inverse of 3 mod Mod.
constexpr ll Mod = 998244353, G = 3, iG = (Mod + 1) / 3;
// Polynomial stored as its coefficient vector: a[i] is the coefficient of x^i.
// Indexing goes through operator[], which asserts the index is in range.
struct poly {
    Ve<int> a;

    // Number of stored coefficients.
    int size() const { return static_cast<int>(a.size()); }

    // Grows (zero-filling) or shrinks the coefficient vector.
    void resize(int n) { a.resize(n); }

    // Range-checked read of coefficient n.
    int operator[] (int n) const {
        assert(n >= 0 && n < static_cast<int>(a.size()));
        return a[n];
    }

    // Range-checked mutable access to coefficient n.
    int &operator[] (int n) {
        assert(n >= 0 && n < static_cast<int>(a.size()));
        return a[n];
    }
};
// Global scratch tables shared by Init / NTT: rev[] holds the bit-reversal
// permutation for the most recent Init() length, tw[] the twiddle factors of
// the butterfly stage NTT is currently processing. Every transform length must
// fit in N << 5 slots; NOTE(review): neither Init nor NTT checks this bound.
int rev[N << 5], tw[N << 5];
// Prepares rev[] for an NTT and returns its length: the smallest power of two
// strictly greater than n. rev[i] is i with its lg low bits reversed.
int Init(int n) {
    int len = 1, lg = 0;
    for(; len <= n; len <<= 1) ++lg;
    rev[0] = 0;
    for(int i = 1; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    return len;
}
// Binary exponentiation: x^p mod Mod (intended for p >= 0, 0 <= x < Mod).
int pw(int x, int p) {
    long long result = 1, base = x;
    for(; p; p >>= 1) {
        if(p & 1) result = result * base % Mod;
        base = base * base % Mod;
    }
    return static_cast<int>(result);
}
// Modular addition for operands already reduced into [0, Mod).
int Add(int x, int y) {
    int sum = x + y;
    if(sum >= Mod) sum -= Mod;
    return sum;
}
// Modular subtraction for operands already reduced into [0, Mod).
int Sub(int x, int y) {
    int diff = x - y;
    if(diff < 0) diff += Mod;
    return diff;
}
// In-place iterative number-theoretic transform of a[0..n-1] modulo Mod.
// n must be the length returned by the latest Init() call (rev[] is used).
// sgn == 1 runs the forward transform with root G; any other value uses iG,
// and sgn == -1 additionally scales by 1/n (the inverse transform).
void NTT(poly &a, int n, int sgn) {
    for(int i = 0; i < n; ++i)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int half = 1; half < n; half <<= 1) {
        const int step = half << 1;
        const int root = pw(sgn == 1 ? G : iG, (Mod - 1) / step);
        tw[0] = 1;
        for(int k = 1; k < half; ++k) tw[k] = 1ll * tw[k - 1] * root % Mod;
        for(int base = 0; base < n; base += step) {
            for(int k = 0; k < half; ++k) {
                const int u = a[base + k];
                const int v = 1ll * a[base + k + half] * tw[k] % Mod;
                a[base + k] = Add(u, v);
                a[base + k + half] = Sub(u, v);
            }
        }
    }
    if(sgn != -1) return;
    const int inv_n = pw(n, Mod - 2);
    for(int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv_n % Mod;
}
// Polynomial product modulo Mod via NTT. Returns a.size() + b.size() - 1
// coefficients, or an empty polynomial if either factor is empty.
// Overwrites the global rev[] table (through Init).
poly operator * (const poly &a, const poly &b) {
    ll n = a.size(), m = b.size();
    if(!n || !m) return poly();
    // The linear convolution has n + m - 1 coefficients. Init(x) returns the
    // smallest power of two strictly greater than x, so Init(n + m - 2) is the
    // smallest length that avoids cyclic wrap-around.
    ll p = Init(n + m - 2);
    poly f = a, g = b; f.resize(p), g.resize(p);
    NTT(f, p, 1), NTT(g, p, 1);
    rep(i, 0, p - 1) f[i] = 1ll * f[i] * g[i] % Mod;
    NTT(f, p, -1), f.resize(n + m - 1);
    return f;
}
// Power-series inverse: returns f with f * h == 1 (mod x^n, mod Mod), using the
// Newton step f = 2 f0 - f0^2 h from an inverse f0 mod x^ceil(n/2).
// Requires h[0] != 0 (mod Mod). Returns an empty polynomial for n <= 0.
// Overwrites the global rev[] table (through Init).
poly Inv(poly h, int n){
    if(n <= 0) return poly();
    // Only h mod x^n matters; higher terms would otherwise wrap around the
    // cyclic NTT of length len below and corrupt the low coefficients.
    h.resize(n);
    if(n == 1){
        poly f;
        f.resize(1), f[0] = pw(h[0], Mod - 2);
        return f;
    }
    int n0 = (n + 1) >> 1;
    poly h0 = h; h0.resize(n0);
    poly f0 = Inv(h0, n0); f0.resize(n);
    // f0^2 h has degree <= 2(n0 - 1) + (n - 1) <= 2n - 2 < len: no wrap-around.
    int len = Init(n << 1);
    poly _f0 = f0; _f0.resize(len), h.resize(len);
    NTT(_f0, len, 1), NTT(h, len, 1);
    rep(i, 0, len - 1) _f0[i] = 1ll * _f0[i] * _f0[i] % Mod * h[i] % Mod;
    NTT(_f0, len, -1), _f0.resize(n);
    rep(i, 0, n - 1) f0[i] = (f0[i] * 2ll - _f0[i] + Mod) % Mod;
    return f0;
}
// One level of the transposition-free composition recursion (a bivariate,
// Bostan–Mori style halving step on x).
//
// Q[i] holds the x-coefficients of the y^i row of a bivariate polynomial
// Q(x, y); maxy = Q.size() - 1 is its y-degree. Only x^0..x^n are relevant,
// so x-coefficients of Q above x^n are ignored. P is passed down unchanged and
// used only in the base case (n == 0).
//
// Returns maxy polynomials, each with n + 1 x-coefficients. main calls this with
// Q = 1 + y * g and prints row 0 as the composition result.
//
// Init() rewrites the global rev[] table, and so do the recursive call and the
// base case's Inv / operator*; hence the Init(len - 1) after recursing.
Ve<poly> work(const poly &P, const Ve<poly> &Q, int n) {
    int maxy = (int)Q.size() - 1;
    if(!n) {
        // Base case: only x^0 remains, so Q reduces to Q(0, y). Invert it as a
        // power series in y (maxy + 1 terms), multiply by reversed P and keep
        // the maxy coefficients ending at y^(lenP - 1).
        poly Q0; Q0.resize(maxy + 1);
        rep(i, 0, maxy) Q0[i] = Q[i][0];
        Q0 = Inv(Q0, maxy + 1);
        int lenP = P.size(); poly R = P;
        reverse(ALL(R.a));
        R = R * Q0;
        poly res; res.resize(maxy);
        rep(i, 0, (int)R.size() - 1) {
            int cy = i - (lenP - 1);
            if(cy <= 0 && cy >= -maxy + 1) res[-cy] = R[i];
        }
        Ve<poly> ret; ret.resize(maxy);
        rep(i, 0, maxy - 1) ret[i].resize(1), ret[i][0] = res[i];
        return ret;
    }
    // Flatten Q(x, y) (pos) and Q(-x, y) (neg) with row stride n2 = 2n + 1.
    // Every row is clamped to degree n, so each row of a product has degree
    // <= 2n < n2 and rows never bleed into each other; len is large enough
    // for both products below, so no cyclic wrap-around occurs either.
    int n2 = (n << 1) | 1;
    int len = Init(n2 * max(maxy + 1, maxy * 2) + n2 * (maxy + 1));
    poly pos, neg; pos.resize(len), neg.resize(len);
    rep(i, 0, maxy) {
        // Terms above x^n cannot affect the result but would overflow the row
        // stride (e.g. when the inner polynomial has degree > 2n).
        int lim = min((int)Q[i].size() - 1, n);
        rep(j, 0, lim) {
            int pwr = i * n2 + j;
            pos[pwr] = Q[i][j];
            if(j & 1) neg[pwr] = Sub(0, Q[i][j]);
            else neg[pwr] = Q[i][j];
        }
    }
    NTT(neg, len, 1);
    NTT(pos, len, 1);
    rep(i, 0, len - 1) pos[i] = 1ll * pos[i] * neg[i] % Mod;
    NTT(pos, len, -1);
    // Q(x, y) Q(-x, y) is even in x: keep its even x-powers up to x^n as the
    // next level's polynomial with x^2 -> x (its y-degree doubles to 2 * maxy).
    Ve<poly> Q0; Q0.resize(maxy * 2 + 1);
    rep(i, 0, maxy * 2) Q0[i].resize((n >> 1) + 1);
    rep(i, 0, (int)pos.size() - 1) {
        int cy = i / n2, cx = i - cy * n2;
        if(cx & 1) continue;
        if(cx > n || cy > maxy * 2) continue;
        Q0[cy][cx >> 1] = pos[i];
    }
    Ve<poly> ret0 = work(P, Q0, n >> 1);
    int len0 = ret0.size();
    reverse(ALL(ret0));
    // The recursive call overwrote rev[]; rebuild it for length len
    // (Init(len - 1) returns len because len is a power of two).
    Init(len - 1);
    // Spread the returned rows back onto even x-powers (x -> x^2) and multiply
    // by Q(-x, y), whose transform is still held in neg.
    pos.a.clear(), pos.resize(len);
    rep(i, 0, len0 - 1) {
        rep(j, 0, (int)ret0[i].size() - 1) {
            int pwr = i * n2 + j * 2;
            pos[pwr] = ret0[i][j];
        }
    }
    NTT(pos, len, 1);
    rep(i, 0, len - 1) pos[i] = 1ll * pos[i] * neg[i] % Mod;
    NTT(pos, len, -1);
    // ret[k] is product row (len0 - 1 - k), restricted to x-powers up to n.
    Ve<poly> ret; ret.resize(maxy);
    rep(i, 0, maxy - 1) ret[i].resize(n + 1);
    rep(i, 0, (int)pos.size() - 1) {
        int cy = i / n2, cx = i - cy * n2;
        cy -= (len0 - 1);
        if(cy <= 0 && cy >= -maxy + 1) {
            if(cx <= n) ret[-cy][cx] = pos[i];
        }
    }
    return ret;
}
int n, m;
poly f, g;
bool Med;
int main() {
cerr << fabs(&Med - &Mbe) / 1048576.0 << "MB\n";
n = read(), m = read();
f.resize(n + 1), g.resize(m + 1);
rep(i, 0, n) f[i] = read();
rep(i, 0, m) g[i] = Mod - read();
poly con; con.resize(1), con[0] = 1;
Ve<poly> h = work(f, {con, g}, n);
assert(h.size() == 1);
rep(i, 0, n) write(h[0][i]), putchar(' ');
putchar('\n');
cerr << "\n" << clock() * 1.0 / CLOCKS_PER_SEC * 1000 << "ms\n";
return 0;
}
:::
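References:
- Yasunori Kinoshita, Baitian Li: Power Series Composition in Near-Linear Time.
- alpha1022: 多项式不存在了:多项式复合(逆)的 O(nlog^2n) 做法 (roughly: "Polynomials no longer exist: an $O(n\log^2 n)$ method for polynomial composition (and its inverse)").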