\(\operatorname{FFT}\)算法支持在\(O(n log n)\)时间内计算两个\(n\)度的多项式的乘法。也可以用来加速大整数乘法运算。
用一个多项式的各项系数来表达这个多项式(升幂顺序):
\[\large f(x) = a_0 + a_1x + a_2x^2 + \ldots + a_nx^n \iff f(x) = \{a_0,a_1, \ldots ,a_n \} \]则它的计算公式为:
\[\large f(x) = \sum _{i = 0}^{n} a_i * x^i \]时间复杂度:\(O(n^2)\)
把一个多项式看成一个坐标系中的函数图像,从图像上选取\(n +1\)个点,利用这\(n + 1\)个点来唯一地表示这个函数。
为什么\(n + 1\)个点可以唯一地表示\(f(n)\)
\[由小学数学可知,两点确定一条直线(一次函数),三点确定一条抛物线(二次函数)... 以此类推,n + 1个点可以确定一个n次函数,就是变量x的最高次幂为n的函数。 \]知道了正确性,那么我们可以设:
\[\large f(x_0) = y_0 = a_0 + a_1x_0 + a_2x_0^2 + \ldots +a_nx_0^n \]\[\large f(x_1) = y_1 = a_0+ a_1x_1 + a_2x_1^2 + \ldots + a_nx_1^n \]\[\large f(x_2)= y_2= a_0+a_1x_2+ a_2x_2^2+\ldots+ a_nx_2^n \]\[\ldots \]\[\large f(x_n) = y_n = a_0 + a_1x_n + a_2x_n^2+\dots + a_nx_n^n \]那么用点值表示法表示\(f(x)\)如下:
\[\large f(x) = a_0 + a_1x + a_2x^2 + a_3x^3 + \ldots +a_nx^n \iff f(x) = \{(x_0,y_0),(x_1,y_1),(x_2,y_2),(x_3,y_3),\ldots,(x_n,y_n)\} \]它的计算公式为:
\[\large f(i) = \sum _{j = 0}^{n - 1} a_j * x_i^j \]时间复杂度:\(O(n^2)\)
\(x^n = 1\)在复数意义下的解释\(n\)次负根。这样的解有\(n\)个。
设:\(\large \omega_n = e^{\frac{2 \pi i}{n}}\)
则:\(x^n\)的解集表示为:
\[\large \{\omega_n | k = 0,1,\ldots,n - 1\} \]称\(\omega_n\)是\(n\)次单位负根,其他负根均可以用单位负根的幂表示。
由于我数学太菜了所以不会证明。
根据欧拉公式可得:
\[\large \omega_n = e^{\frac{2 \pi i}{n}} = cos(\frac{2\pi}{n}) + i * sin(\frac{2\pi}{n}) \]柿子很好记,可是我不知道怎么证,可是记住就行了。
单位负根。对于任意正整数\(n\)和整数\(k\):
\[\large \omega ^n_n = 1 \]\[\large \omega^k_n = \omega^{2k}_{2n} \]\[\large \omega^{k + n}_{2n} = -\omega^k_{2n} \]\(\operatorname{FFT}\)的基本思想是分治。
先看\(\operatorname{DFT}\),它分治地来求当\(x = \omega^k_n\)的时候\(f(x)\)的值,分治思想体现在将多项式分为奇次项和偶次项处理。
令\(f(x)\)为一个\(n\)次的多项式。
令\(G(x)\)表示偶次项系数建立的新函数:
\[\large G(x) = a_0 + a_2x + a_4x^2 +\ldots+ a_{n -2}x^{\frac{n - 2}{2}} \]令\(H(x)\)表示奇次项系数建立的新函数:
\[\large H(x) = a_1 + a_3x + a_5x^2 +\ldots+a_{n - 1}x^{\lfloor \frac{n - 1}{2}\rfloor} \]那么原来的\(f(x)\)用新的两个函数表示为:
\[\large f(x) = G(x^2) + x * H(x ^ 2) \]利用单位负根的性质推柿子:
\[\large \operatorname{DFT}(f(\omega^k_n)) = \large\operatorname{DFT}(G((\omega^k_n)^2)) + \operatorname{DFT}(H((\omega^k_n)^2)) \]\[=\large\operatorname{DFT}(G(\omega^{2k}_n)) + \omega^k_n * \operatorname{DFT}(H(\omega^{2k}_n)) \]\[=\large\operatorname{DFT}(G(\omega^k_{\frac{n}{2}})) + \omega^k_n * \operatorname{DFT}(H(\omega^k_{\frac{n}{2}})) \]同理:
\[\large \operatorname{DFT}(f(\omega^{k + \frac{n}{2}}_n)) = \large\operatorname{DFT}(G((\omega^k_n)^2)) + \operatorname{DFT}(H((\omega^k_n)^2)) \]\[=\large\operatorname{DFT}(G(\omega^{2k}_n)) + \omega^k_n * \operatorname{DFT}(H(\omega^{2k}_n)) \]\[=\large\operatorname{DFT}(G(\omega^k_{\frac{n}{2}})) + \omega^k_n * \operatorname{DFT}(H(\omega^k_{\frac{n}{2}})) \]因此,我们求出\(\operatorname{DFT}(G(\omega^k_{\frac{n}{2}}))\)和\(DFT(H(\omega^k_{\frac{n}{2}}))\)后,就可以同时求出\(\operatorname{DFT}(f(\omega^k_n)\)和\(\operatorname{DFT}(f(\omega^{k + \frac{n}{2}}_n))\),于是对\(G\)和\(H\)分别递归\(\operatorname {DFT}\)即可。
注意:考虑到分治\(\operatorname {DFT}\)能处理的多项式长度只能是\(2^m(m \in N^*)\),否则在分治的时候左式和右式长度不等,右式就取不到系数了。所以要在第一次\(\operatorname{DFT}\)之前就把序列向上补成长度为\(2^m(m \in N^*)\),最高项次数为\(2^m - 1\)的多项式(高次系数补0)。
在代入值的时候,因为要代入\(n\)个不同值,所以我们代入\(\omega^0_n,\omega^1_n,\omega^2_n\ldots,\omega^{n - 1}_n(n = 2^m(m\in N^*))\)一共\(2^m\)个不同值。
把点值表示法转化为系数表示法。
考虑原本的多项式:
\[\large f(x) = a_0 + a_1x + a_2x^2+\ldots+a_{n - 1}x^{n - 1} = \sum^{n - 1}_{i = 0}a_ix^i \]考虑构造法。
我们已知\(y_i = f(\omega^i_n),i\in\{0,1,\ldots,n - 1\}\),要求\(\{a_0,a_1,\ldots,a_{n - 1}\}\)
构造多项式:
\[\large A(x) = \sum^{n - 1}_{i = 0}y_ix^i \]也就是把\(\{y_0,y_1,\ldots,y_{n - 1}\}\)当做多项式\(A\)的系数表示法。
设\(b_i = \omega^{-i}_n\),则多项式\(A\)在\(x = b_0,b_1,\ldots,b_{n-1}\)处的点值表示法为:\(\{A(b_0),A(b_1),\ldots,A(b_{n - 1})\}\)
对\(A(x)\)的定义式做一下变换,可以将\(A(b_k)\)表示为:
\[\large A(b_k) = \sum^{n-1}_{i=0}f(\omega^i_n)\omega^{-ik}_n \]\[\large =\sum^{n-1}_{i=0}\omega^{-ik}_{n}\sum^{n-1}_{j=0}a_j(\omega^i_n)^j \]\[\large =\sum^{n-1}_{i=0}\sum^{n-1}_{j=0}a_j\omega^{i(j-k)}_{n} \]\[\large =\sum^{n-1}_{j=0}a_j\sum^{n-1}_{i=0}(\omega^{j-k}_{n})^i \]记\(S( \omega^a_n)=\sum^{n-1}{i=0}(\omega^a_n)^i\)
当\(a=0(\bmod n)\)时,\(S(\omega^a_n) = n\)
当\(a\ne0(\bmod n)\)时,错位相减:
\[\large S(\omega^a_n)=\sum^{n-i}_{i=0}(\omega^a_n)^i \]\[\large\omega^a_nS(\omega^a_n)=\sum^n_{i=1}(\omega^a_n)^i \]\[\large S(\omega^a_n)=\dfrac{(\omega^a_n)^n-(\omega^a_n)^0}{\omega^a_n-1}=0 \]也就是:
\[\large S(\omega^a_n) = \begin{cases}n&a=0\\0&a\ne0\end{cases} \]代回原式得:
\[\large A(b_k)=\sum^{n-1}_{j=0}a_jS(\omega^{j-k}_n)=a_k\cdot n \]那么多项式\(A\)的点值表示法为:
\[\begin{aligned}\large \{(b_0,A(b_0),(b_1,A(b_1)),\ldots,(b_{n-1},A(b_{n-1})))\}\\=\large \{(b_0,a_0\cdot n),(b_1,a_1\cdot n),\ldots,(b_{n-1},a_{n-1}\cdot n)\}\end{aligned} \]总结:我们取单位根为其倒数,对\(\{y_0,y_1,\ldots,y_{n-1}\}\)跑一遍\(\operatorname{FFT}\),然后除以\(n\)即可得到\(f(x)\)的系数表示。
代码实现:luogu P3803[模板]多项式乘法(FFT)
#include <bits/stdc++.h> using namespace std; //#define ll long long #define ri register int const int maxn = 5e6 + 5; const double pi = acos(-1.0); int n,m,l,r[maxn],lim = 1; inline int read() { int s = 0,w = 1; char ch = getchar(); while (ch < '0' || ch > '9') {if (ch == '-') w = -1; ch = getchar();} while (ch >= '0' && ch <= '9') s = s * 10 + ch - '0',ch = getchar(); return s * w; } struct cpx { double x,y; cpx(double a = 0,double b = 0) {x = a; y = b;} cpx operator + (const cpx &a) {return cpx(x + a.x,y + a.y);} cpx operator - (const cpx &a) {return cpx(x - a.x,y - a.y);} cpx operator * (const cpx &a) {return cpx(x * a.x - y * a.y,x * a.y + y * a.x);} }a[maxn],b[maxn]; inline void fft(cpx *f,int type) { for (ri i = 0;i < lim;i++) if (i < r[i]) swap(f[i],f[r[i]]); for (ri mid = 1;mid < lim;mid <<= 1) { cpx W(cos(pi / mid),type * sin(pi / mid)); int len = (mid << 1); for (ri j = 0;j < lim;j += len) { cpx pw(1,0); for (ri k = 0;k < mid;k++,pw = pw * W) { cpx x = f[j + k],y = pw * f[j + k + mid]; f[j + k] = x + y; f[j + k + mid] = x - y; } } } } int main() { n = read(); m = read(); for (ri i = 0;i <= n;i++) a[i].x = read(); for (ri i = 0;i <= m;i++) b[i].x = read(); while (lim <= n + m) {lim <<= 1; l++;} for (ri i = 0;i < lim;i++) r[i] = (r[i >> 1] >> 1) | (i & 1) << (l - 1); fft(a,1); fft(b,1); // for (int i = 0;i <= lim;i++) cout << a[i].x << " " << a[i].y << endl; for (ri i = 0;i <= lim;i++) a[i] = a[i] * b[i]; fft(a,-1); for (ri i = 0;i <= n + m;i++) printf("%d ",(int)(a[i].x / lim + 0.5)); return 0; }
本文大部分内容来自:oi-wiki