理一下思路,这道题我们首先的想法是二分答案一个 $check $ 没问题吧?
我们二分了一个值,考虑 \(check\) 的过程。
我们每次记录每个节点还没被并入的链长度,然后实行在树上进行一个 \(dfs\) 的过程。
然后就是考虑这个 \(dfs\) 的过程中我们每次遍历完子树,然后我们其实就将这个里面划分成了若干条链。
考虑我们当前节点是起的一个什么作用?
它可以充当链顶端以及作为两个节点的 \(LCA\) 然后进行合并。
考虑如果说一个链长度如果加上到当前节点的长度已经满足大于等于我们设置的 \(lim\) 那么就很显然的就可以将它们进行一个合并,然后增大我们的链次数。
题解中多数地方写到我们要考虑最大化合并次数,为什么?
因为你首先是要满足的是你能划分出来大于等于 \(m\) 条链,而不是说让这里面的链长有足够的大。
这样你才可以让你限制的这个 \(lim\) 尽量可能的大。
所以说这是为什么我们要进行最大化合并次数的原因。
那么我们考虑充当链顶端的方法很简单,你直接在遍历子节点的时候 \(check\) 一下就好了。
问题在于充当 \(lca\) 时怎么办
这个也好解决。
我们其实可以选一个大的和一个小的然后去 \(check\) 。
然后注意一件事,我们不能贸然的将排序后的拿去左端点右端点匹配。
我们要考虑更高级一点的东西。
我们是要尽量合并最多
所以我们控制的是第一个可以合并的更他进行合并
那么这个可以利用 \(lowerbound\) 进行一个实现
然后我们考虑合并完了之后很明显的我们要将这个给删除掉显然吧
因为不删除掉就会造成影响吧。
可以用 \(multiset\) 进行实现,但是看了大佬的做法发现可以用一个并查集代替这个过程。
然后就很妙这就是个很牛逼的思路,感觉可以用在一些有趣的地方。
#include <bits/stdc++.h> #define int long long using namespace std; namespace IO { int len = 0; char ibuf[(1 << 20) + 1], *iS, *iT, out[(1 << 25) + 1]; #define gh() \ (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), \ (iS == iT ? EOF : *iS++) : *iS++) inline int read() { char ch = gh(); int x = 0; char t = 0; while (ch < '0' || ch > '9') t |= ch == '-', ch = gh(); while (ch >= '0' && ch <= '9') x = x * 10 + (ch ^ 48), ch = gh(); return t ? -x : x; } inline void putc(char ch) { out[len++] = ch; } template <class T> inline void write(T x) { if (x < 0) putc('-'), x = -x; if (x > 9) write(x / 10); out[len++] = x % 10 + 48; } string getstr(void) { string s = ""; char c = gh(); while (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF) c = gh(); while (!(c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == EOF)) s.push_back(c), c = gh(); return s; } void putstr(string str, int begin = 0, int end = -1) { if (end == -1) end = str.size(); for (int i = begin; i < end; i++) putc(str[i]); return; } inline void flush() { fwrite(out, 1, len, stdout); len = 0; } } // namespace IO using IO::flush; using IO::getstr; using IO::putc; using IO::putstr; using IO::read; using IO::write; #define pr pair<int, int> const int N = 2e5; int n, m, lim, num, fa[N], ans[N], s[N], vis[N]; vector<pr> ver[N]; int getfa(int x) { if (fa[x] != x) fa[x] = getfa(fa[x]); return fa[x]; } void dfs(int x, int ff) { s[x] = 0; vector<int> q; q.clear(); // ok for (int i = 0; i < ver[x].size(); i++) { int to = ver[x][i].first, val = ver[x][i].second; if (to == ff) continue; dfs(to, x); if (s[to] + val >= lim) num++; else q.push_back(s[to] + val); } if (!q.size()) return; if (q.size() == 1) { s[x] = q[0]; return; } sort(q.begin(), q.end()); // ok for (int i = 0; i <= q.size() + 1; i++) vis[i] = 0, fa[i] = i; for (int i = 0; i < q.size() - 1; i++) { if (vis[i]) continue; if (i >= q.size() - 1 || i == -1) break; int t = lower_bound(q.begin() + i + 1, q.end(), lim - q[i]) - q.begin(); if (t >= q.size()) continue; t = getfa(t); if (t >= q.size()) continue; // printf("%lld %lld | %lld %lld \n", i, q[i], t, q[t]); int val = 0; val = q[t] + q[i]; if (val < lim) continue; num++; vis[t] = 1; vis[i] = 1; fa[t] = t + 1; } for (int i = q.size() - 1; i >= 0; i--) { if (!vis[i]) { s[x] = q[i]; return; } } } bool check(int l) { lim = l; num = 0; // printf("\nLen :%lld\n", l); dfs(1, 0); // printf("Num : %lld\n", num); return num >= m; } signed main() { int l = 1, r = 0, ans = 0; n = read(), m = read(); for (int i = 1; i < n; i++) { int u = read(), v = read(), w = read(); ver[u].push_back(make_pair(v, w)); ver[v].push_back(make_pair(u, w)); r += w; } while (l <= r) { int mid = (l + r) >> 1; if (check(mid)) { ans = mid; l = mid + 1; } else r = mid - 1; } printf("%lld ", ans); return 0; }