不知道多久才能做完多项式全家桶 qaq
直接上链接(
快速傅里叶变换(FFT)详解 - 自为风月马前卒
总的来说就是先 DFT 从系数表示法到点值表示法,再 IDFT 从点值表示法到系数表示法。
简单说一下不太理解的,在 DFT 中 \(\omega_n^k = -\omega_n^{k+\frac{n}{2}}\),其实就是在单位圆上旋转了 \(180°\)。
DFT 和 IDFT 互为逆运算:IDFT 只需把 \(\omega_n\) 换成它的倒数 \(\omega_n^{-1}\)(也就是共轭复数,在单位圆上反方向旋转,代码里体现为虚部 \(\sin\) 乘上 \(-1\)),最后再把结果除以 \(n\)。
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
using namespace std;

// Minimal fast I/O helpers.
namespace IO
{
    // Reads an optionally negative decimal integer into x (T may be integral or floating).
    // The character is kept in an int so EOF is detected instead of looping forever.
    template <typename T> void read(T &x)
    {
        x = 0;
        bool f = 0;
        int c = gc();
        while(c != EOF && !isdigit(c)) f |= c == '-', c = gc();
        while(isdigit(c)) x = x * 10 + c - '0', c = gc();
        if(f) x = -x;
    }
    // Writes an integer in decimal.
    template <typename T> void write(T x)
    {
        if(x < 0) pc('-'), x = -x;
        if(x > 9) write(x / 10);
        pc('0' + x % 10);
    }
}
using namespace IO;

// Minimal complex number with the three operations the FFT needs.
struct Complex
{
    db x, y;
    Complex(db _x = 0.0, db _y = 0.0) { x = _x, y = _y; }
    Complex operator + (const Complex b) const { return Complex(x + b.x, y + b.y); }
    Complex operator - (const Complex b) const { return Complex(x - b.x, y - b.y); }
    Complex operator * (const Complex b) const { return Complex(x * b.x - y * b.y, x * b.y + y * b.x); }
};

const int N = 1e6 + 5;
const db pi = acos(-1.0);
int n, m;
Complex f[N << 2], g[N << 2];
int len = 1, bit, rev[N << 2];

// In-place iterative FFT of a[0..len) using the global bit-reversal table rev[].
// type = 1 : forward DFT (coefficients -> point values).
// type = -1: inverse DFT using the conjugate root; afterwards the real part is
//            divided by len (only the real part is read by main).
void FFT(Complex a[], int type)
{
    for(int i = 0; i < len; i++)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int mid = 1; mid < len; mid <<= 1)
    {
        // Primitive (2 * mid)-th root of unity, or its conjugate when type == -1.
        Complex wn(cos(pi / mid), type * sin(pi / mid));
        for(int i = 0; i < len; i += (mid << 1))
        {
            Complex w(1, 0);
            for(int j = 0; j < mid; j++, w = w * wn)
            {
                Complex x = a[i + j], y = w * a[i + mid + j];
                a[i + j] = x + y;
                a[i + mid + j] = x - y;
            }
        }
    }
    if(type == -1)
        for(int i = 0; i < len; i++) a[i].x /= len;
}

// Reads n, m, then n + 1 coefficients of F and m + 1 coefficients of G,
// and prints the n + m + 1 coefficients of F * G.
int main()
{
    read(n), read(m);
    for(int i = 0; i <= n; i++) read(f[i].x);
    for(int i = 0; i <= m; i++) read(g[i].x);
    while(len <= n + m) len <<= 1, bit++;
    // rev[0] is always 0. Starting at i = 1 avoids shifting by (bit - 1) == -1,
    // which is undefined behavior when len == 1 (n = m = 0).
    for(int i = 1; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    FFT(f, 1);
    FFT(g, 1);
    for(int i = 0; i < len; i++) f[i] = f[i] * g[i];
    FFT(f, -1);
    // llround rounds to nearest for negative results as well; (ll)(x + 0.5) does not.
    for(int i = 0; i <= n + m; i++) printf("%lld ", llround(f[i].x));
    pc('\n');
    return 0;
}
// A.S.
继续上链接(
快速数论变换(NTT)小结 - 自为风月马前卒
就是把 FFT 中的 \(\omega_n\) 改为了 \(g^{\frac{p-1}{n}}\),其中 \(g\) 为模数的原根。
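\[\omega_n \equiv g^{\frac{p-1}{n}}\ (\bmod p) \]这要求 \(n \mid p-1\)。\(998244353 = 119\times 2^{23}+1\),原根 \(g=3\),所以长度不超过 \(2^{23}\) 的变换都可以做。证明就算了
然后在 IDFT 时把 \(g\) 换成它的逆元 \(g^{-1}\)(模 \(998244353\) 下 \(3^{-1}=332748118\)),最后把每一项乘上 \(n^{-1}\) 就行了(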
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
using namespace std;

// Minimal fast I/O helpers.
namespace IO
{
    // Reads an optionally negative decimal integer into x.
    // The character is kept in an int so EOF is detected instead of looping forever.
    template <typename T> void read(T &x)
    {
        x = 0;
        bool f = 0;
        int c = gc();
        while(c != EOF && !isdigit(c)) f |= c == '-', c = gc();
        while(isdigit(c)) x = x * 10 + c - '0', c = gc();
        if(f) x = -x;
    }
    // Writes an integer in decimal.
    template <typename T> void write(T x)
    {
        if(x < 0) pc('-'), x = -x;
        if(x > 9) write(x / 10);
        pc('0' + x % 10);
    }
}
using namespace IO;

const int N = 1e6 + 5;
const int p = 998244353;  // 119 * 2^23 + 1
const int G = 3;          // primitive root modulo p
const int Gi = 332748118; // G^{-1} modulo p
int n, m, len = 1, bit;
ll f[N << 2], g[N << 2];
int rev[N << 2];

// Computes a^b mod p by binary exponentiation (a is expected to be in [0, p)).
ll qpow(ll a, int b)
{
    ll res = 1;
    while(b)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p, b >>= 1;
    }
    return res % p;
}

// In-place iterative NTT of a[0..len) modulo p using the global table rev[].
// type = 1 : forward transform with root G.
// type = -1: inverse transform with root Gi, then every entry is multiplied by len^{-1}.
// Every a[i] must already be in [0, p).
void NTT(ll a[], int type)
{
    for(int i = 0; i < len; i++)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int mid = 1; mid < len; mid <<= 1)
    {
        // Primitive (2 * mid)-th root of unity modulo p (or its inverse).
        ll wn = qpow(type == 1 ? G : Gi, (p - 1) / (mid << 1));
        for(int i = 0; i < len; i += (mid << 1))
        {
            ll w = 1;
            for(int j = 0; j < mid; j++, w = w * wn % p)
            {
                ll x = a[i + j], y = w * a[i + mid + j] % p;
                a[i + j] = (x + y) % p;
                a[i + mid + j] = (x - y + p) % p;
            }
        }
    }
    if(type == -1)
    {
        // Scale the array that was passed in, not a global.
        ll inv = qpow(len, p - 2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}

// Reads n, m, then n + 1 coefficients of F and m + 1 coefficients of G,
// and prints the n + m + 1 coefficients of F * G modulo p.
int main()
{
    read(n), read(m);
    // read() accepts negative values and values >= p; bring them into [0, p).
    for(int i = 0; i <= n; i++) read(f[i]), f[i] = (f[i] % p + p) % p;
    for(int i = 0; i <= m; i++) read(g[i]), g[i] = (g[i] % p + p) % p;
    while(len <= n + m) len <<= 1, bit++;
    // rev[0] is always 0. Starting at i = 1 avoids shifting by (bit - 1) == -1
    // (undefined behavior) when len == 1.
    for(int i = 1; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    NTT(f, 1);
    NTT(g, 1);
    for(int i = 0; i < len; i++) f[i] = f[i] * g[i] % p;
    NTT(f, -1);
    for(int i = 0; i <= n + m; i++) write(f[i]), pc(' ');
    pc('\n');
    return 0;
}
// A.S.
对于一个多项式 \(F(x)\) 求 一个多项式 \(G(x)\),满足 \(F(x)*G(x)\equiv 1\ (\bmod x^n)\),系数对 \(998244353\) 取模。
就是多项式的逆元。
推一下式子
设
\[A(x)*B(x) \equiv 1\ (\bmod x^n)\\ A(x)*C(x)\equiv 1\ (\bmod x^{\frac{n}{2}}) \]那么
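\[A(x)*(B(x)-C(x))\equiv 0\ (\bmod x^{\frac{n}{2}}) \]因为 \(A(x)\) 的常数项不为 \(0\)(否则逆元不存在),两边可以消去 \(A(x)\):\[B(x)-C(x)\equiv 0\ (\bmod x^{\frac{n}{2}}) \]我们要把模数改为 \(x^n\),只需要平方一下(\(B(x)-C(x)\) 的前 \(\frac{n}{2}\) 项系数都为 \(0\),所以平方后前 \(n\) 项系数都为 \(0\))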
\[[B(x)-C(x)]^2\equiv 0\ (\bmod x^n) \\ B^2(x)-2B(x)*C(x)+C^2(x)\equiv 0\ (\bmod x^n)\\ \]将等式左边乘上 \(A(x)\) 不影响等号
因为
\[A(x)*B(x)\equiv 1\ (\bmod x^n) \]所以可以把其中的 \(A(x)*B(x)\) 都换成 \(1\)(即 \(A(x)*B^2(x)\equiv B(x)\),\(2A(x)*B(x)*C(x)\equiv 2C(x)\)),得到
\[B(x)-2C(x)+A(x)*C^2(x)\equiv 0\ (\bmod x^n) \\ B(x)\equiv 2C(x)-A(x)*C^2(x)\ (\bmod x^n) \]我们只需要求出 \(C(x)\),而 \(C(x)\) 与 \(B(x)\) 的形式是一样的,只是模数不一样,所以可以递归求解,当然也可以递推
复杂度满足 \(T(n)=T(\frac{n}{2})+O(n\log n)\),所以总复杂度为 \(O(n\log n)\)
这里写的递推,但是常数好像不是很优秀(
\(bas\) 是当前多项式的项数,\(len\) 是当前多项式乘起来后的项数,也就是 \(2\times bas\)。
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
using namespace std;

// Minimal fast I/O helpers.
namespace IO
{
    // Reads an optionally negative decimal integer into x.
    // The character is kept in an int so EOF is detected instead of looping forever.
    template <typename T> void read(T &x)
    {
        x = 0;
        bool f = 0;
        int c = gc();
        while(c != EOF && !isdigit(c)) f |= c == '-', c = gc();
        while(isdigit(c)) x = x * 10 + c - '0', c = gc();
        if(f) x = -x;
    }
    // Writes an integer in decimal.
    template <typename T> void write(T x)
    {
        if(x < 0) pc('-'), x = -x;
        if(x > 9) write(x / 10);
        pc('0' + x % 10);
    }
}
using namespace IO;

const int N = 1e5 + 5;
const int p = 998244353;  // 119 * 2^23 + 1
const int G = 3;          // primitive root modulo p
const int Gi = 332748118; // G^{-1} modulo p
int n;
ll a[N << 2], b[2][N << 2];
int rev[N << 2];

// Computes a^b mod p by binary exponentiation.
ll qpow(ll a, int b)
{
    ll res = 1;
    while(b)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p, b >>= 1;
    }
    return res;
}

// Builds the bit-reversal permutation for a transform of length len = 2^bit (bit >= 1).
void calcrev(int len, int bit)
{
    for(int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

// Modular add / subtract for operands already in [0, p).
ll add(ll x, ll y) { return (x + y >= p) ? (x + y - p) : (x + y); }
ll sub(ll x, ll y) { return add(x, p - y); }

// In-place iterative NTT of a[0..len) modulo p using the global table rev[].
// type = 1 : forward transform; type = -1: inverse transform scaled by len^{-1}.
void NTT(ll a[], int len, int type)
{
    for(int i = 0; i < len; i++)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int mid = 1; mid < len; mid <<= 1)
    {
        ll wn = qpow(type == 1 ? G : Gi, (p - 1) / (mid << 1));
        for(int i = 0; i < len; i += (mid << 1))
        {
            ll w = 1;
            for(int j = 0; j < mid; j++, w = w * wn % p)
            {
                ll x = a[i + j], y = w * a[i + mid + j] % p;
                a[i + j] = add(x, y);
                a[i + mid + j] = sub(x, y);
            }
        }
    }
    if(type == -1)
    {
        ll inv = qpow(len, p - 2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}

ll X[N << 2], Y[N << 2];

// x[0..len) <- (first len/2 terms of x) * (first len/2 terms of y), cyclic length len.
// rev[] must already be built for this len.
void mul(ll x[], ll y[], int len)
{
    int half = len >> 1;
    for(int i = 0; i < half; i++) X[i] = x[i], Y[i] = y[i];
    // Only the slots the transform touches need clearing; zeroing the whole
    // scratch arrays on every call cost O(N) regardless of len.
    for(int i = half; i < len; i++) X[i] = Y[i] = 0;
    NTT(X, len, 1);
    NTT(Y, len, 1);
    for(int i = 0; i < len; i++) X[i] = X[i] * Y[i] % p;
    NTT(X, len, -1);
    for(int i = 0; i < len; i++) x[i] = X[i];
}

// Newton iteration B <- 2C - A * C^2: after the round with term count bas,
// b[k] holds the inverse of a modulo x^bas. Prints the first n coefficients.
void solve()
{
    int k = 0, bas = 1, bit = 1, len = 2;
    b[k][0] = qpow(a[0], p - 2);
    while(bas < (n << 1))
    {
        calcrev(len, bit);
        k ^= 1;
        for(int i = 0; i < bas; i++) b[k][i] = add(b[k ^ 1][i], b[k ^ 1][i]);
        mul(b[k ^ 1], b[k ^ 1], len);
        mul(b[k ^ 1], a, len);
        for(int i = 0; i < bas; i++) b[k][i] = sub(b[k][i], b[k ^ 1][i]);
        bas <<= 1, len <<= 1, bit++;
    }
    for(int i = 0; i < n; i++) write(b[k][i]), pc(' ');
    pc('\n');
}

int main()
{
    read(n);
    // Normalize into [0, p); read() also accepts negative numbers.
    for(int i = 0; i < n; i++) read(a[i]), a[i] = (a[i] % p + p) % p;
    solve();
    return 0;
}
// A.S.
给定一个多项式 \(A(x)\),求一个多项式 \(B(x)\),满足 \(B(x)\equiv \ln A(x)\ (\bmod x^n)\)。这里要求 \(A(x)\) 的常数项为 \(1\),否则 \(\ln A(x)\) 在形式幂级数意义下没有定义。
设 \(f(x) = \ln x\)
\(\ln\) 不好处理,但是对 \(\ln\) 求导后就很好算了,\(f'(x)=\dfrac{1}{x}\)
所以将同余号两边同时求导(求导后模数变为 \(x^{n-1}\)),\(B'(x)\equiv f'(A(x))*A'(x)\ (\bmod x^{n-1})\)(复合函数求导)
因为 \(f'(A(x))=\dfrac{1}{A(x)}\)
所以 \(B'(x)\equiv \dfrac{A'(x)}{A(x)}\),这里除以 \(A(x)\) 就是乘上 \(A(x)\) 的逆元(用上面的多项式求逆)
有了 \(B'(x)\) 后再积分求出 \(B(x)\)
求导公式:\((x^a)'=ax^{a-1}\)
积分公式:\(\int x^a\,dx=\dfrac{1}{a+1}x^{a+1}+C\)。积分会丢掉常数项,这里取 \(C=0\),因为 \(B(0)=\ln A(0)=\ln 1=0\)
积分就是求导的逆运算,你会发现把 \(\dfrac{1}{a+1}x^{a+1}\) 求导后就是 \(x^a\)
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
using namespace std;

// Minimal fast I/O helpers.
namespace IO
{
    // Reads an optionally negative decimal integer into x.
    // The character is kept in an int so EOF is detected instead of looping forever.
    template <typename T> void read(T &x)
    {
        x = 0;
        bool f = 0;
        int c = gc();
        while(c != EOF && !isdigit(c)) f |= c == '-', c = gc();
        while(isdigit(c)) x = x * 10 + c - '0', c = gc();
        if(f) x = -x;
    }
    // Writes an integer in decimal.
    template <typename T> void write(T x)
    {
        if(x < 0) pc('-'), x = -x;
        if(x > 9) write(x / 10);
        pc('0' + x % 10);
    }
}
using namespace IO;

const int N = 1e5 + 5;
const int p = 998244353;  // 119 * 2^23 + 1
const int G = 3;          // primitive root modulo p
const int Gi = 332748118; // G^{-1} modulo p
int n;
ll f[N << 2], g[N << 2];
ll a[N << 2], b[N << 2];

// Computes a^b mod p by binary exponentiation.
ll qpow(ll a, int b)
{
    ll res = 1;
    while(b)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p, b >>= 1;
    }
    return res;
}

// Single conditional correction after an add (x in [0, 2p)) or a subtract (x in (-p, p)).
ll add(ll x) { return x >= p ? x - p : x; }
ll sub(ll x) { return x < 0 ? x + p : x; }
void Copy(ll *x, ll *y, int len) { for(int i = 0; i < len; i++) x[i] = y[i]; }

int rev[N << 2];

// Builds the bit-reversal permutation for length len = 2^bit.
// rev[0] is always 0; starting at 1 avoids a shift by -1 when bit == 0.
void calcrev(int len, int bit)
{
    for(int i = 1; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

// In-place iterative NTT of a[0..len) modulo p using the global table rev[].
// type = 1 : forward transform; type = -1: inverse transform scaled by len^{-1}.
void NTT(ll *a, int len, int type)
{
    for(int i = 0; i < len; i++)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int mid = 1; mid < len; mid <<= 1)
    {
        ll wn = qpow(type == 1 ? G : Gi, (p - 1) / (mid << 1));
        for(int i = 0; i < len; i += (mid << 1))
        {
            ll w = 1;
            for(int j = 0; j < mid; j++, w = w * wn % p)
            {
                ll x = a[i + j], y = w * a[i + mid + j] % p;
                a[i + j] = add(x + y);
                a[i + mid + j] = sub(x - y);
            }
        }
    }
    if(type == -1)
    {
        ll leni = qpow(len, p - 2);
        for(int i = 0; i < len; i++) a[i] = a[i] * leni % p;
    }
}

ll X[N << 2], Y[N << 2];

// Polynomial multiplication:
// x[0..len) <- (first len/2 terms of x) * (first len/2 terms of y), cyclic length len.
// len must be a power of two. The bit-reversal table is rebuilt here for this exact len.
// Previously callers had to do that, and Ln() did not, which broke it whenever n was a power of two.
void mul(ll *x, ll *y, int len)
{
    int bit = 0;
    while((1 << bit) < len) bit++;
    calcrev(len, bit);
    int half = len >> 1;
    Copy(X, x, half), Copy(Y, y, half);
    // Only the slots the transform touches need clearing.
    fill(X + half, X + len, 0LL), fill(Y + half, Y + len, 0LL);
    NTT(X, len, 1);
    NTT(Y, len, 1);
    for(int i = 0; i < len; i++) X[i] = X[i] * Y[i] % p;
    NTT(X, len, -1);
    Copy(x, X, len);
}

ll inv[2][N << 2];

// Polynomial inverse: y[0..n) <- x^{-1} mod x^n via Newton iteration B <- 2C - A * C^2.
// Requires x[0] != 0.
void Inv(ll *x, ll *y, int n)
{
    int bas = 1, len = 2, k = 0;
    inv[k][0] = qpow(x[0], p - 2);
    while(bas < (n << 1))
    {
        k ^= 1;
        for(int i = 0; i < bas; i++) inv[k][i] = add(inv[k ^ 1][i] + inv[k ^ 1][i]);
        mul(inv[k ^ 1], inv[k ^ 1], len);
        mul(inv[k ^ 1], x, len);
        for(int i = 0; i < bas; i++) inv[k][i] = sub(inv[k][i] - inv[k ^ 1][i]);
        bas <<= 1, len <<= 1;
    }
    Copy(y, inv[k], n);
}

// Derivative: y[0..n) <- d/dx of x[0..n) (the top coefficient becomes 0).
void Differential(ll *x, ll *y, int n)
{
    for(int i = 1; i < n; i++) y[i - 1] = i * x[i] % p;
    y[n - 1] = 0;
}

// Integral with zero constant term: y[i] = x[i - 1] / i for 1 <= i < n, y[0] = 0.
void Integral(ll *x, ll *y, int n)
{
    for(int i = 1; i < n; i++) y[i] = x[i - 1] * qpow(i, p - 2) % p;
    y[0] = 0;
}

// Smallest power of two strictly greater than 2n.
int calclen(int n)
{
    int len = 1;
    while(len <= (n << 1)) len <<= 1;
    return len;
}

// y[0..n) <- ln(x) mod x^n = integral(x' * x^{-1}); expects x[0] == 1.
void Ln(ll *x, ll *y, int n)
{
    Differential(x, a, n);
    Inv(x, b, n);
    mul(a, b, calclen(n));
    Integral(a, y, n);
}

int main()
{
    read(n);
    // Normalize into [0, p); read() also accepts negative numbers and values >= p.
    for(int i = 0; i < n; i++) read(f[i]), f[i] = (f[i] % p + p) % p;
    Ln(f, g, n);
    for(int i = 0; i < n; i++) write(g[i]), pc(' ');
    pc('\n');
    return 0;
}
// A.S.