Rd795 题解滞销,帮帮我(误
不难想到换根 dp,然后我们可以写出如下式子:
\[f_u=\left(\dbinom{\mathrm{sz}_u}{k}-\sum_{v}\dbinom{\mathrm{sz}_v}{k}\right)\mathrm{sz}_u+\sum_{v}f_v \]具体含义就是:考虑这 \(k\) 个点,什么时候 \(\operatorname{lca}\) 为 \(u\)。
减法原理,考虑什么时候不为 \(u\),当且仅当这 \(k\) 个点来自于同一颗子树。减去\(\sum\limits_{v}\dbinom{\mathrm{sz}_v}{k}\) 即可。
最后把子树答案合并上去即可。
观察这个式子,可以拆出「只和 \(v\) 有关」、「只和 \(u\) 有关」、「和 \(u,v\) 都有关」三种。为了好换根,我们需要让和 \(v\) 有关的都比较好统计。于是考虑重定义 \(f_u\) 和 \(g_u\)。\(f_u\) 表示所有儿子的答案,\(g_u\) 表示 \(\sum\limits_{v}\dbinom{\mathrm{sz}_v}{k}\)。
此时子树 \(u\) 的答案应该是:
\[\mathrm{ans}_u=\left(\dbinom{\mathrm{sz}_u}{k}-g_u\right)\mathrm{sz}_u+f_u \]此时换根只需要修改 \(g\) 和 \(f\) 即可,减减加加上面提到的组合数,相信大家都会。
没什么注意事项,没有细节,很好写。
// Problem: F. K-Set Tree // From: Codeforces - CodeCraft-22 and Codeforces Round #795 (Div. 2) // URL: https://codeforces.com/contest/1691/problem/F // Time: 2022-05-31 22:35 // Author: lingfunny #include <bits/stdc++.h> #define LL long long using namespace std; const int mxn = 2e5+10, mod = 1e9+7; int jc[mxn], ijc[mxn], n, f[mxn], g[mxn], k, sz[mxn], ans; // f: 子树 ans 和 // g: 子树 (sz, k) 和 // ans[u]: (C(S-1, k-1) + C(S-1, k) - g[u]) * S + f[u] vector <int> G[mxn]; inline int mul() { return 1; } template <typename... A> int mul(int x, A... a) { return (LL)x*mul(a...)%mod; } inline int qpow(int x, int y) { int res = 1; for(; y; x = mul(x, x), y >>= 1) if(y&1) res = mul(res, x); return res; } inline int C(int n, int m) { if(m > n) return 0; return mul(jc[n], ijc[m], ijc[n-m]); } void add(int &x) { x = x; } template <typename... A> void add(int &x, int y, A... a) { x += (x + y) >= mod ? y - mod : y; add(x, a...); } inline int calc(int S, int F, int G) { return (mul((C(S, k) - G + mod)%mod, S) + F) % mod; } void dfs(int u, int fa) { sz[u] = 1; for(int v: G[u]) if(v != fa) { dfs(v, u), sz[u] += sz[v]; add(f[u], calc(sz[v], f[v], g[v])); add(g[u], C(sz[v], k)); } } void dfs2(int u, int fa) { if(fa) { int F = f[fa], G = g[fa]; add(F, mod - calc(sz[u], f[u], g[u])), add(G, mod - C(sz[u], k)); add(f[u], calc(n - sz[u], F, G)); add(g[u], C(n - sz[u], k)); } add(ans, calc(n, f[u], g[u])); for(int v: G[u]) if(v != fa) dfs2(v, u); } signed main() { scanf("%d%d", &n, &k); jc[0] = 1; for(int i = 1; i <= n; ++i) jc[i] = mul(jc[i-1], i); ijc[n] = qpow(jc[n], mod-2); for(int i = n - 1; ~i; --i) ijc[i] = mul(ijc[i+1], i+1); for(int i = 1, u, v; i < n; ++i) scanf("%d%d", &u, &v), G[u].emplace_back(v), G[v].emplace_back(u); dfs(1, 0); dfs2(1, 0); printf("%d\n", ans); return 0; }