在任意代数结构上的多项式乘法 学习笔记

· · 算法·理论

cnblogs

前言

Stop learning useless algorithms, go and solve some problems, learn how to use binary search.

以下内容大多是作者看完《如何在任意代数结构上做多项式乘法》[^1] 后口胡的,所以可能和原文章不太一样。如果错了或者有更好的做法请告诉我。

分圆多项式

定义为 \Phi_n(x) = \prod_{1 \le k \lt n,\gcd(k, n)=1}(x - \omega_n^k).

也可以感性理解为 x=\omega_nx^n-1=0,约掉一些“显然”不为 0 的因式后剩下的素多项式。

分圆多项式都是整系数素多项式,且 \Phi_n(x) 最高次数为 \varphi(n)

结论:多项式 f(x) 代入 x=\omega_n 后做运算得到的结果(用 \omega_n 表示,最后再把 \omega_n 换成 x)等于先做运算再 \bmod \Phi_n(x) 得到的结果。感性理解就是 \Phi_n(\omega_n)=0

n=p^mp 为素数,m \ge 1)时,\Phi_n(x)=\sum\limits_{i=0}^{\varphi(n)}x^i

算法原理

要求:三个群 (A,+_A),(B,+_B),(C,+_C),乘法运算 \cdot:A\times B \rightarrow C 具有分配律 (a_1 +_A a_2) \cdot (b_1 +_B b_2) = a_1 \cdot b_1 +_C a_1 \cdot b_2 +_C a_2 \cdot b_1 +_C a_2 \cdot b_2

此时必有 \forall b \in B, e_A \times b = e_c\forall a \in A,a \times e_b = e_c,其中 e_A,e_B,e_C 分别为 A,B,C 中的单位元。于是将 AB 的高位填对应的单位元即可。

证明:

$e_B$ 是类似的。
### Part 1. 解决除法 IDFT 最后要除以长度,而 $C$ 中没有定义自然数乘的逆。 一个解决方法是,分别做长为 $2$ 的幂的 DFT 和长为 $3$ 的幂的 DFT,这样每个元素的 $2^{c_2}$ 倍和 $3^{c_3}$ 倍都已知($c_2$ 和 $c_3$ 取决于长度),类似辗转相除做即可。 ### Part 2. 解决单位根 这是一个很神仙的做法。 考虑把**一部分** $x$ 代入 $\omega_m$ 满足 $\varphi(m) \gt \deg A(x)+\deg B(x)$,然后将 $m$ 拆成 $m=pq$。取 $p=q=\sqrt m$ 能保证最优复杂度,读者自(wo)证(bu)不(hui)难(zheng)。具体实现可以参考代码。 具体地, $$ \begin{aligned} A(x) & = \sum\limits_{j=0}^{q-1}(\sum\limits_{i=0}^{\varphi(p)-1}a_{iq+j}x^{iq})x^j \\ & = \sum\limits_{j=0}^{q-1}(\sum\limits_{i=0}^{\varphi(p)-1}a_{iq+j}\omega_{pq}^{iq})x^j \\ & = \sum\limits_{j=0}^{q-1}(\sum\limits_{i=0}^{\varphi(p)-1}a_{iq+j}\omega_{p}^{i})x^j \end{aligned} $$ 然后将内层带 $\omega_p$ 的东西看成系数对外层做 DFT。实现时可以做成指针套数组的形式。这个部分可能不太好理解,可以看代码。 做完 DFT 要进行内层元素相乘,可以递归。 最后对分圆多项式取模即可。实现时可以暴力将高位减到低位。 ## 应用 好像没啥用... 目前想到的就是做 $c_k = \prod_{i+j=k}a_i^{b_j}$ 之类的卷积? ## 实现 给出一份大常数的实现。期待有大佬能优化。 题目是 lgP3803. ```cpp #define DEBUG 0 #include <iostream> #include <algorithm> #include <cmath> #define UP(i,s,e) for(auto i=s; i<e; ++i) #define DOWN(i,e,s) for(auto i=e; i-->s;) using std::cin; using std::cout; namespace Poly{ // }{{{ template<int BASE, typename T> void change(T* arr, int len){ int *rev = new int[len]; rev[0] = 0; UP(i, 1, len){ rev[i] = rev[i/BASE]/BASE; rev[i] += i%BASE*(len/BASE); } UP(i, 0, len) if(rev[i] > i) std::swap(arr[i], arr[rev[i]]); delete[] rev; } template<int BASE, class A> void fft(A **a, int len, int siz, bool idft){ // siz == len(a[0]) static A *tmp[BASE]; UP(i, 0, BASE){ tmp[i] = new A[siz]; //UP(j, siz, siz*BASE){ // tmp[i][j].unit(); //} } change<BASE>(a, len); int wn = siz/BASE; for(int h=BASE; h<=len; h*=BASE){ for(int st=0; st<len; st+=h){ int w=0; UP(i, st, st+h/BASE){ UP(j, 0, BASE) std::swap(a[i+h/BASE*j], tmp[j]); UP(j, 0, BASE){ auto &now = a[i+h/BASE*j]; std::copy(tmp[0], tmp[0]+siz, now); UP(k, 1, BASE){ UP(l, 0, siz){ int idx = l-(idft?-1:1)*(w+siz/BASE*j)*k; idx %= siz; idx = idx < 0 ? idx + siz : idx; now[l] += tmp[k][idx]; } } } w += wn; } } wn /= BASE; } UP(i, 0, BASE) delete[] tmp[i]; //delete[] tmp; } // mod Phi_len(x) // len = BASE**n template<int BASE, class A, class B, class C> int polymul_base(A *a, B *b, C *ret, int len #if DEBUG , int test=0 #endif ){ UP(i, 0, len/BASE*(BASE-1)) ret[i].unit(); if(len < 100 #if DEBUG && !test #endif ){ int phi_len = len / BASE * (BASE-1); UP(i, 0, len) UP(j, 0, len){ if((i+j)%len >= phi_len) UP(k, 1, BASE){ ret[(i+j)%len-len/BASE*k] += (a[i]*b[j]).inv(); } else { ret[(i+j)%len] += a[i]*b[j]; } } return 1; } int tim = std::round(std::log(len)/std::log(BASE)); int p = std::round(std::pow(BASE, tim/2+1)); int q = std::round(std::pow(BASE, (tim-1)/2)); A **aa = new A*[BASE*q]; B **bb = new B*[BASE*q]; C **cc = new C*[BASE*q]; UP(i, 0, BASE*q){ aa[i] = new A[p]; bb[i] = new B[p]; cc[i] = new C[p]; } UP(i, 0, q*BASE) UP(j, 0, p){ aa[i][j].unit(); bb[i][j].unit();// cc[i][j].unit(); } UP(i, 0, q*BASE) UP(j, p/BASE*(BASE-1), p){ cc[i][j].unit(); } UP(i, 0, q){ UP(j, 0, p){ if(j*q+i >= len){ break; //aa[i][j].unit(); bb[i][j].unit(); } else { aa[i][j] = a[j*q+i]; bb[i][j] = b[j*q+i]; } } //UP(j, p/BASE*(BASE-1), p){ aa[i][j].unit(); bb[i][j].unit(); } } //UP(i, q, BASE*q){ //UP(j, 0, p){ aa[i][j].unit(); bb[i][j].unit(); } //} fft<BASE>(aa, BASE*q, p, false); fft<BASE>(bb, BASE*q, p, false); int scale; UP(i, 0, BASE*q){ scale = polymul_base<BASE>(aa[i], bb[i], cc[i], p #if DEBUG , test ? test-1 : 0 #endif ); } UP(i, 0, BASE*q){ delete[] aa[i]; delete[] bb[i]; } delete[] aa; delete[] bb; fft<BASE>(cc, BASE*q, p, true); int pq = p*q; int phi_pq = pq/BASE*(BASE-1); UP(i, 0, BASE*q) UP(j, 0, p){ int pl = (i+j*q)%pq; if(pl >= phi_pq) UP(k, 1, BASE) ret[(pl-pq/BASE*k)%len] += cc[i][j].inv(); else ret[pl%len] += cc[i][j]; } UP(i, 0, BASE*q) delete[] cc[i]; delete[] cc; return scale * BASE * q; } template<class A, class B, class C> void polymul(A *a, B *b, C *ret, int len){ bool swapped = false; C *tmp = new C[len*2]; int l2 = std::round(std::pow(2, std::ceil(std::log(len*2) / std::log(2)))); int l3 = std::round(std::pow(3, std::ceil(std::log(len*3/2) / std::log(3)))); int tim2 = polymul_base<2>(a, b, tmp, l2); int tim3 = polymul_base<3>(a, b, ret, l3); while(tim3 != 1){ if(tim2 > tim3){ int scale = tim2 / tim3; UP(i, 0, len) tmp[i] += ret[i].inv() * scale; tim2 %= tim3; } std::swap(tim2, tim3); std::swap(ret, tmp); swapped ^= 1; } if(swapped){ std::swap(ret, tmp); std::copy(tmp, tmp+len, ret); } delete[] tmp; } } // {}}} namespace m{ // }{{{ constexpr int N = 5e6+2; struct u32{ unsigned val; u32(){} u32(unsigned v):val(v){} void unit(){val = 0;} u32 inv(){ return -val; } u32 &operator+=(u32 b){ val += b.val; return *this; } u32 &operator*=(u32 b){ val *= b.val; return *this; } u32 operator*(u32 b){ return b *= *this;} u32 operator*(unsigned x){ return val*x; } } ia[N], ib[N], ic[N]; int in, im; void work(){ cin >> in >> im; UP(i, 0, in+1){ cin >> ia[i].val; } UP(i, 0, im+1){ cin >> ib[i].val; } Poly::polymul(ia, ib, ic, in+im+1); UP(i, 0, in+im+1){ cout << ic[i].val << ' '; } } } // {}}} int main(){cin.tie(0)->sync_with_stdio(0); m::work(); return 0;} ``` [^1]: <https://www.cnblogs.com/whx1003/p/16214952.html>