给定一棵以 \(1\) 号点为根、共 \(n\) 个节点的树,记 \(p_k\) 为节点 \(k\) 的父亲。用 \(1 \sim n\) 的一个排列 \(c\) 给节点染色(颜色两两不同),要求每个非根节点 \(k\) 都满足 \(c_k \neq c_{p_k} - 1\)。求合法染色方案数,对 \(998\,244\,353\) 取模。
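容斥:设 \(d_v\) 为节点 \(v\) 的儿子个数(出度)。枚举钦定恰好 \(i\) 条边 \((p_k, k)\) 违反条件(即 \(c_k = c_{p_k} - 1\))时的方案数。由于颜色两两不同,同一个父亲至多有一个儿子能违反条件,因此节点 \(v\) 有 \(d_v\) 种方式贡献一条违反边,其生成函数为
\[g_v(x) = 1 + d_v \cdot x\]
被钦定的 \(i\) 条边把树剖分成 \(n-i\) 条(可能只含一个点的)自上而下的链,每条链上的颜色必须是一段连续递减的值,把这 \(n-i\) 个整体排列即得 \((n-i)!\) 种染色。因此答案为
\[\sum_{i=0}^{n} (-1)^i \,(n-i)!\,[x^i]\prod_{v} g_v(x)\]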
用分治 NTT(或启发式合并)把所有节点的生成函数乘起来即可,时间复杂度 \(O(n \log^2 n)\)。
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

using ll = long long;
using ull = unsigned long long;
using i64 = long long;
using poly = std::vector<int>;  // coefficients, lowest degree first

// Results of Poly::mul have trailing zero coefficients removed, so never read
// a[m] when a.size() <= m -- resize the polynomial to at least m + 1 first.

constexpr int MOD = 998244353;

namespace Poly {  // remember to resize

// N: largest transform length the static buffers can hold; g: primitive root of MOD.
const int N = (1 << 21), g = 3;

// The butterflies in dft() reduce lazily: every level may grow a value by
// less than MOD. An input below (j + 1) * MOD times a twiddle below MOD only
// fits in 64 bits while j + 1 <= 18 (18 * MOD^2 < 2^64 < 19 * MOD^2), so the
// work array is reduced modulo MOD every kReduceEvery levels.
constexpr int kReduceEvery = 18;

/// Returns x^p mod MOD by binary exponentiation (expects 0 <= x < MOD, p >= 0).
inline int power(int x, int p) {
    int res = 1;
    for (; p; p >>= 1, x = (ll)x * x % MOD)
        if (p & 1) res = (ll)res * x % MOD;
    return res;
}

/// Maps x in [0, 2 * MOD) into [0, MOD).
inline int fix(const int x) { return x >= MOD ? x - MOD : x; }

/// In-place forward NTT: resizes A to 2^n (expects 1 <= n <= 21) and replaces
/// it with its transform, every entry reduced into [0, MOD).
void dft(poly& A, int n) {
    // H[k] points at the 2^k powers of a primitive 2^(k+1)-th root of unity
    // used by butterfly level k. Levels are appended lazily and shared by all
    // later calls; level k needs exactly 2^k entries, so the table never
    // holds more than 2^21 - 1 values.
    static ull W[N << 1];
    static ull* H[30];
    static ull* las = W;
    static int mx = 0;
    for (; mx < n; ++mx) {
        H[mx] = las;
        const ull wn = power(g, (MOD - 1) >> (mx + 1));
        ull w = 1;
        for (int i = 0; i < (1 << mx); ++i) *las++ = w, w = w * wn % MOD;
    }

    const int len = 1 << n;
    A.resize(static_cast<std::size_t>(len));

    // Copy into the 64-bit work array in bit-reversed order.
    static ull a[N];
    for (int i = 0, j = 0; i < len; ++i) {
        a[i] = A[j];
        for (int k = 1 << (n - 1); (j ^= k) < k; k >>= 1) {
        }
    }

    for (int k = 0, d = 1; k < n; k++, d <<= 1) {
        // Keep (*r) * (*w) below 2^64; see kReduceEvery. No-op for n <= 18.
        if (k != 0 && k % kReduceEvery == 0)
            for (int i = 0; i < len; ++i) a[i] %= MOD;
        for (int i = 0; i < len; i += (d << 1)) {
            ull *l = a + i, *r = a + i + d, *w = H[k], t;
            for (int j = 0; j < d; j++, l++, r++) {
                t = (*r) * (*w++) % MOD;
                // Adding MOD keeps the unsigned difference non-negative.
                *r = *l + MOD - t, *l += t;
            }
        }
    }
    for (int i = 0; i < len; ++i) A[i] = a[i] % MOD;
}

/// In-place inverse NTT of length 2^n (reverse a[1..] then forward-transform,
/// then scale by (2^n)^-1).
void idft(poly& a, int n) {
    a.resize(static_cast<std::size_t>(1) << n);
    std::reverse(a.begin() + 1, a.end());
    dft(a, n);
    const int inv = power(1 << n, MOD - 2);
    for (int i = 0; i < (1 << n); ++i) a[i] = (ll)a[i] * inv % MOD;
}

/// Returns a with its trailing zero coefficients removed.
poly FIX(poly a) {
    while (!a.empty() && !a.back()) a.pop_back();
    return a;
}

/// Returns FIX(a * b) when t == 1, otherwise FIX(a * a * b); the transform
/// length is chosen to hold a.size() * t + b.size() coefficients.
/// Small products (t == 1, total size <= 24) use schoolbook multiplication.
// remember to resize
poly mul(poly a, poly b, int t = 1) {
    if (t == 1 && a.size() + b.size() <= 24) {
        poly c(a.size() + b.size(), 0);
        for (std::size_t i = 0; i < a.size(); ++i)
            for (std::size_t j = 0; j < b.size(); ++j)
                c[i + j] = (c[i + j] + (ll)a[i] * b[j]) % MOD;
        return FIX(c);
    }
    const int aim = static_cast<int>(a.size() * t + b.size());
    int n = 1;
    while ((1 << n) <= aim) n++;
    dft(a, n);
    dft(b, n);
    if (t == 1)
        for (int i = 0; i < (1 << n); ++i) a[i] = (ll)a[i] * b[i] % MOD;
    else
        for (int i = 0; i < (1 << n); ++i) a[i] = (ll)a[i] * a[i] % MOD * b[i] % MOD;
    idft(a, n);
    a.resize(static_cast<std::size_t>(aim));
    return FIX(a);
}

}  // namespace Poly

using namespace Poly;

/// Brings x from (-MOD, 2 * MOD) back into [0, MOD).
void norm(int& x) {
    if (x >= MOD) x -= MOD;
    if (x < 0) x += MOD;
}

// Reads a tree (n, then n - 1 edges, 1-indexed, rooted at node 1) and prints
// sum_i (-1)^i (n - i)! [x^i] prod_v (1 + children(v) * x)  mod MOD.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n < 0) return 1;  // malformed input
    if (n == 0) {
        // Empty product: only the empty coloring, 0! * 1 = 1.
        std::cout << 1;
        return 0;
    }

    std::vector<std::vector<int>> g(n);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        std::cin >> u >> v;
        --u;
        --v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // Divide-and-conquer product over nodes l in [l, r) of (1 + children(l) * x).
    // Node 0 is the root; every other node's degree also counts its parent edge.
    auto dnc = [&](auto self, int l, int r) -> poly {
        if (r - l == 1) return poly{1, (int)g[l].size() - (l != 0)};
        const int mid = (l + r) >> 1;
        return mul(self(self, l, mid), self(self, mid, r));
    };

    poly ans = dnc(dnc, 0, n);
    ans.resize(n + 1);  // mul() drops trailing zeros; make ans[i] valid for all i <= n

    std::vector<int> fac(n + 1);
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % MOD;

    // Inclusion-exclusion over the number i of forced violations.
    int res = 0;
    for (int i = 0; i <= n; ++i) {
        const int thiz = 1ll * fac[n - i] * ans[i] % MOD;
        res += (i & 1) ? MOD - thiz : thiz;
        norm(res);
    }
    std::cout << res;
    return 0;
}