树链剖分 + 二分
思路：用树链剖分从 v 沿重链逐段向根跳，对每一段的 dfn 区间查询是否含有黑点，并记录最靠近根的那条含黑点的链段；然后在该链段的 dfn 区间上二分，找出 dfn 最小（即深度最小、离根最近）的黑点。若从根到 v 的路径上没有黑点，则输出 -1。
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

// Heavy-light decomposition over a tree whose vertices are toggled between
// white (0) and black (1). A point-toggle / range-sum segment tree is built
// over the DFS numbering (dfn), in which every heavy chain occupies a
// contiguous dfn range.

const int maxn = 1e5 + 10;

int dep[maxn], siz[maxn], hson[maxn], fa[maxn];     // depth, subtree size, heavy child (-1 = none), parent
int dfn[maxn], rnk[maxn], tr[maxn << 2], top[maxn]; // dfn: vertex -> position, rnk: position -> vertex,
                                                    // tr: segment tree of black counts, top: chain head
vector<int> gra[maxn];                              // adjacency lists (cleared by init())
int root = 1;                                       // root chosen by the most recent init()

// First pass: record parent, depth and subtree size of every vertex and pick
// its heavy child (the child with the largest subtree).
// NOTE: recursion depth equals the tree height, which can reach n on a
// path-shaped tree; make sure the stack is large enough for maxn.
void dfs1(int now, int pre, int d) {
    dep[now] = d;
    hson[now] = -1;
    siz[now] = 1;
    fa[now] = pre;
    for (int nex : gra[now]) {
        if (nex == pre) continue;
        dfs1(nex, now, d + 1);
        siz[now] += siz[nex];
        if (hson[now] == -1 || siz[hson[now]] < siz[nex]) hson[now] = nex;
    }
}

int tp = 0;  // last dfn handed out

// Second pass: number vertices so that the heavy child always receives the
// next dfn, which makes each heavy chain a contiguous dfn range whose head is
// stored in top[].
void dfs2(int now, int t) {
    dfn[now] = ++tp;
    rnk[tp] = now;
    top[now] = t;
    if (hson[now] == -1) return;  // no children
    dfs2(hson[now], t);           // heavy child continues the current chain
    for (int nex : gra[now]) {
        if (nex == fa[now] || nex == hson[now]) continue;
        dfs2(nex, nex);           // each light child starts a new chain
    }
}

// Builds the decomposition for the tree stored in gra[] (vertices 1..n),
// rooted at rt, and marks every vertex white. The adjacency lists are
// released afterwards because queries only use fa/top/dfn/rnk.
void init(int n, int rt = 1) {
    tp = 0;
    root = rt;
    fill(tr, tr + (maxn << 2), 0);  // reset colours so init() can be called again
    dfs1(rt, rt, 1);
    dfs2(rt, rt);
    for (int i = 0; i <= n; i++) gra[i].clear();
}

// Toggles the colour stored at dfn position x (segment-tree node `now`
// covering [l, r]) and refreshes the black counts on the way back up.
void update(int now, int l, int r, int x) {
    if (l == r) {
        tr[now] ^= 1;
        return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) update(now << 1, l, mid, x);
    else update(now << 1 | 1, mid + 1, r, x);
    tr[now] = tr[now << 1] + tr[now << 1 | 1];
}

// Returns the number of black vertices whose dfn lies in [L, R]
// (segment-tree node `now` covers [l, r]).
int query(int now, int l, int r, int L, int R) {
    if (L <= l && r <= R) return tr[now];
    int mid = (l + r) >> 1;
    int ans = 0;
    if (L <= mid) ans += query(now << 1, l, mid, L, R);
    if (R > mid) ans += query(now << 1 | 1, mid + 1, r, L, R);
    return ans;
}

// Returns the black vertex closest to the root on the path root -> v, or -1
// if that path has no black vertex.
//
// The loop climbs chain segment by chain segment from v towards the root; the
// last segment found to contain a black vertex is the one nearest the root.
// Within a chain dfn increases with depth, so the answer is the black vertex
// with the smallest dfn in that segment, located by binary search.
int solve(int v, int n) {
    int node = 0;  // lower end of the root-most chain segment holding a black vertex
    while (true) {
        int head = top[v];
        if (query(1, 1, n, dfn[head], dfn[v])) node = v;
        if (head == root) break;
        v = fa[head];
    }
    if (node == 0) return -1;

    // Search node's chain segment, not v's: after the climb v lies on the root
    // chain, which may contain no black vertex at all.
    int l = dfn[top[node]], r = dfn[node];
    while (l < r) {
        int mid = (l + r) >> 1;
        if (query(1, 1, n, l, mid)) r = mid;  // a black vertex lies in [l, mid]
        else l = mid + 1;
    }
    return rnk[l];
}

// Input: n m, then n-1 edges; each of the m operations is either
// "0 i" (toggle the colour of vertex i) or "1 v" (print solve(v)).
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        gra[x].push_back(y);
        gra[y].push_back(x);
    }
    init(n, 1);
    while (m--) {
        int t;
        scanf("%d", &t);
        if (t == 0) {
            int i;
            scanf("%d", &i);
            update(1, 1, n, dfn[i]);
        } else {
            int v;
            scanf("%d", &v);
            printf("%d\n", solve(v, n));
        }
    }
    return 0;
}