题目链接:学霸大帅哥zyh
做法:dsu on tree(树上启发式合并),用线段树维护答案。每次加入或删除一个结点时只会改变一个位置上的数,因此线段树上只有从对应叶子到根的一条链需要更新,非常好写。复杂度 \(O(n\log^2 n)\)。题是好题,就是数据太水了,\(O(n^2\log n)\) 的假做法都能过。
#include <bits/stdc++.h>

// For every edge i (in input order) of a tree rooted at vertex 1, consider the
// subtree hanging below that edge. For each value c let
//     score(c) = min(#vertices with value c inside the subtree,
//                    #vertices with value c outside the subtree).
// The program prints, per edge, the value c with the largest score (the
// largest such c on ties), or -1 when every score is 0.
//
// Technique: small-to-large merging on the tree ("dsu on tree"). A max segment
// tree indexed by value holds score(c) for the subtree currently loaded, so
// adding or removing one vertex is a single point update.
//
// Assumes every vertex value lies in [1, kMaxValue]; sgt::modify asserts this.
// NOTE(review): dfs0/work/dfs recurse to the depth of the tree, so a path-like
// tree needs a correspondingly large stack.

constexpr int maxn = 100000 + 5;   // capacity of all per-vertex / per-value arrays
constexpr int kMaxValue = 100000;  // the segment tree covers values 1..kMaxValue

std::vector<std::pair<int, int>> g[maxn];  // g[u] = list of (neighbour, edge id)
int tot[maxn];          // tot[c]: number of vertices with value c in the whole tree
int h[maxn];            // h[u]: value stored at vertex u
int sz[maxn];           // sz[u]: size of the subtree of u
int now[maxn];          // now[c]: number of loaded vertices with value c
int son[maxn];          // son[u]: child with the largest subtree, 0 if u is a leaf
int parent_edge[maxn];  // parent_edge[u]: id of the edge from u to its parent
int ans[maxn];          // ans[i]: answer for edge i

namespace sgt {

int t[maxn << 2];   // t[p]: maximum leaf value in the range of node p
int id[maxn << 2];  // id[p]: largest position attaining t[p], or -1 if t[p] == 0

// Recomputes node p from its two children.
void push_up(int p) {
  const int lc = p << 1;
  const int rc = p << 1 | 1;
  t[p] = std::max(t[lc], t[rc]);
  if (t[lc] != t[rc]) {
    id[p] = t[lc] > t[rc] ? id[lc] : id[rc];
  } else {
    // Equal maxima: prefer the larger position; an all-zero range has no answer.
    id[p] = t[p] > 0 ? std::max(id[lc], id[rc]) : -1;
  }
}

// Builds the tree over [l, r]; all leaf values start at 0.
void build(int p, int l, int r) {
  if (l == r) {
    id[p] = l;
    return;
  }
  const int mid = (l + r) >> 1;
  build(p << 1, l, mid);
  build(p << 1 | 1, mid + 1, r);
  push_up(p);
}

// Sets the leaf at position x to k; node p covers [l, r].
void modify(int p, int x, int l, int r, int k) {
  if (l == r) {
    assert(l == x);
    t[p] = k;
    return;
  }
  const int mid = (l + r) >> 1;
  if (x <= mid) {
    modify(p << 1, x, l, mid, k);
  } else {
    modify(p << 1 | 1, x, mid + 1, r, k);
  }
  push_up(p);
}

// Returns the position of the global maximum, or -1 if every leaf is 0.
int query() { return id[1]; }

}  // namespace sgt

// Computes sz[], son[] and parent_edge[] for the subtree of u (parent f).
void dfs0(int u, int f) {
  sz[u] = 1;
  for (const auto& e : g[u]) {
    const int v = e.first;
    if (v == f) continue;
    parent_edge[v] = e.second;
    dfs0(v, u);
    sz[u] += sz[v];
    if (sz[v] > sz[son[u]]) son[u] = v;
  }
}

// Adds (val = +1) or removes (val = -1) every vertex in the subtree of u
// (parent f), except the subtree of the direct child `skip`. Vertex ids start
// at 1, so skip = 0 excludes nothing.
void work(int u, int f, int val, int skip = 0) {
  const int c = h[u];
  now[c] += val;
  sgt::modify(1, c, 1, kMaxValue, std::min(now[c], tot[c] - now[c]));
  for (const auto& e : g[u]) {
    const int v = e.first;
    if (v == f || v == skip) continue;
    work(v, u, val);  // `skip` is a child of the top-level u only
  }
}

// Answers the edge above every vertex in the subtree of u (parent f).
// If op is non-zero the subtree's contribution stays loaded on return;
// otherwise it is removed again.
void dfs(int u, int f, int op) {
  // Light children first: each is solved and then unloaded.
  for (const auto& e : g[u]) {
    const int v = e.first;
    if (v == f || v == son[u]) continue;
    dfs(v, u, 0);
  }
  // The heavy child is solved last and its contribution is kept.
  if (son[u]) dfs(son[u], u, 1);
  // Load u and its light subtrees on top of the heavy child's data.
  work(u, f, 1, son[u]);
  if (f != 0) ans[parent_edge[u]] = sgt::query();  // the root has no parent edge
  if (!op) work(u, f, -1);
}

// Reads n, the n vertex values and the n - 1 edges from stdin, then prints one
// answer per edge, in input order, to stdout.
void solve() {
  int n;
  std::cin >> n;
  for (int i = 1; i <= n; ++i) {
    std::cin >> h[i];
    ++tot[h[i]];
  }
  for (int i = 1, u, v; i < n; ++i) {
    std::cin >> u >> v;
    g[u].emplace_back(v, i);
    g[v].emplace_back(u, i);
  }
  dfs0(1, 0);
  sgt::build(1, 1, kMaxValue);
  dfs(1, 0, 1);
  for (int i = 1; i < n; ++i) {
    std::cout << ans[i] << '\n';
  }
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  // Single test case: the global state is never reset between runs.
  solve();
  return 0;
}