给出一棵大小为 \(n\) 的树,\(q\) 次询问,每次给出一大小为 \(m\) 的点集,判断是否存在一条链覆盖这些点,注意这条链可以经过其他点。\(n,\sum m \leq 2\times 10^5\) ,\(q \leq 10^5\)。
由于 \(q\) 次询问的 \(\sum m \le 2 \times 10^5\) ,那么我们可以考虑对于每次询问使用虚树来解决,将树浓缩为所有关键点以及这些点的 \(LCA\),然后通过 \(DFS\) 来判断每个点的儿子个数来判断是否是一条链即可。
具体的判断方法为:
#include <bits/stdc++.h> using namespace std; #define rep(i, b, s) for(int i = (b); i <= (s); ++ i) #define dec(i, b, s) for(int i = (b); i >= (s); -- i) using ll = long long; #ifdef LOCAL #include <debugger> #else #define debug(...) 42 #endif template <typename T> void chkmax(T &x, T y) { x = max(x, y); } template <typename T> void chkmin(T &x, T y) { x = min(x, y); } constexpr int N = 2E5 + 10; vector<int> son[N]; vector<int> g[N]; int tp, a[N], use[N]; bool ans, has_root, f; int dfn[N], depth[N], sz[N], hson[N], top[N], parent[N]; void dfs1(int u, int fa, int d) { depth[u] = d; sz[u] = 1; parent[u] = fa; for (int &v: son[u]) if (v != fa) { dfs1(v, u, d + 1); sz[u] += sz[v]; if (hson[u] == -1 || sz[hson[u]] < sz[v]) hson[u] = v; } } void dfs2(int u, int id) { top[u] = id; dfn[u] = ++ tp; if (hson[u] == 0) return ; dfs2(hson[u], id); for (int &v: son[u]) if (v != parent[u] && v != hson[u]) { dfs2(v, v); } } int lca(int u, int v) { while (top[u] != top[v]) { if (depth[top[u]] > depth[top[v]]) { u = parent[top[u]]; } else { v = parent[v]; } } return depth[u] < depth[v] ? u : v; } void dfs3(int u, int fa) { int now = 0; for (int &v: g[u]) { if (v != fa) dfs3(v, u); if (!f || v != 1) { ++ now; } } if (!(f && u == 1) && now > 2) { ans = true; } g[u].clear(); } void solve() { int n; cin >> n; for (int i = 1; i < n; i ++ ) { int u, v; cin >> u >> v; son[u].emplace_back(v); son[v].emplace_back(u); } dfs1(1, 0, 1); dfs2(1, 1); int q; cin >> q; while (q -- ) { int k; cin >> k; has_root = false; for (int i = 1; i <= k; i ++ ) { cin >> a[i]; if (a[i] == 1) has_root = true; } sort(a + 1, a + 1 + k, [&] (int x, int y){ return dfn[x] < dfn[y]; }); auto add = [&] (int u, int v) { g[u].emplace_back(v); g[v].emplace_back(u); }; vector<int> stk{1}; for (int i = 1; i <= k; i ++ ) { if (a[i] != 1) { int p = lca(a[i], stk.back()); if (p != stk.back()) { while (dfn[p] < dfn[stk[(int)stk.size() - 2]]) { add(stk.back(), stk[(int)stk.size() - 2]); stk.pop_back(); } add(p, stk.back()), stk.pop_back(); if (dfn[p] > dfn[stk.back()]) stk.emplace_back(p); } stk.emplace_back(a[i]); } } while (stk.size() > 1) { if (stk.back() != 1) { add(stk.back(), stk[(int)stk.size() - 2]); stk.pop_back(); } } f = (!has_root && (int)g[1].size() == 1); // cout << (f ? "TRUE" : "FALSE") << "\n"; // 表示关键点里面没有根,并且根(一定存在)在虚树中只连了一个点 dfs3(1, 0); if (ans) { cout << "NO\n"; } else { cout << "YES\n"; } ans = false; } } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int T = 1; // cin >> T; while(T -- ) solve(); return 0; }
在树上判断三个点 \((x,y,z)\) 是否在一条脸上的方法是,判断 \(\{dis(x,y), dis(y,z), dis(x,z)\}\) 中两条短边的和是否等于最长边。
那么我们可以维护两个点 \((x,y)\) 为链的两端点,然后依次枚举剩余的点,每次更新 \((x,y)\) 即可。
#include <bits/stdc++.h> using namespace std; using ll = long long; #ifdef LOCAL #include <debugger> #else #define debug(...) 42 #endif template <typename T> void chkmax(T &x, T y) { x = max(x, y); } template <typename T> void chkmin(T &x, T y) { x = min(x, y); } const int N = 2e5 + 10; vector<int> son[N]; ll ans[N]; int d[N]; int depth[N]; int fa[N][20]; int q[N]; void bfs(int root) { memset(depth, 0x3f, sizeof depth); depth[0] = 0, depth[root] = 1; int hh = 0, tt = 0; q[0] = root; while(hh <= tt) { int u = q[hh ++ ]; int num = son[u].size(); for(int i = 0; i < num; i ++ ) { int v = son[u][i]; if(depth[v] > depth[u] + 1) { // father[v] = u; depth[v] = depth[u] + 1; fa[v][0] = u; for(int k = 1; k < 20; k ++ ) { fa[v][k] = fa[fa[v][k - 1]][k - 1]; } q[++ tt] = v; } } } } int lca(int a, int b) { if(depth[a] < depth[b]) swap(a, b); for(int k = 19; k >= 0; k -- ) { if(depth[fa[a][k]] >= depth[b]) { a = fa[a][k]; } } if(a == b) return a; for(int k = 19; k >= 0; k -- ) { if(fa[a][k] != fa[b][k]) { a = fa[a][k]; b = fa[b][k]; } } return fa[a][0]; } void solve() { int n; cin >> n; vector<int> a(n); for(int &x: a) cin >> x; //mp[x] ++; if(n == 1) { cout << "YES\n"; return ; } int x = a[0], y = a[1]; auto dis = [&] (int X, int Y) { return depth[X] + depth[Y] - 2 * depth[lca(X, Y)]; }; for(int i = 2; i < n; i ++ ) { int z = a[i]; vector<int> l(3); l[0] = dis(x, y), l[1] = dis(y, z), l[2] = dis(x, z); auto r = l; sort(l.begin(), l.end()); if(l[0] + l[1] != l[2]) { cout << "NO\n"; return ; } if(r[1] == l.back()) { x = z; } else if(r[2] == l.back()) { y = z; } } cout << "YES\n"; } int main() { cin.tie(nullptr)->sync_with_stdio(false); int n; cin >> n; for(int i = 1; i < n; i ++ ) { int u, v; cin >> u >> v; son[u].push_back(v); son[v].push_back(u); } bfs(1); int Q; cin >> Q; while(Q -- ) solve(); return 0; }