因为题目保证了每条链上的温度是递减的,所以可以考虑线段树合并的做法。
具体的,我们先将所有询问离线,其中对于每个询问\((x,l,r)\)我们找到最大的满足以下条件的祖先\(y\):
\(\bullet\) \(t_y \leq r\)
这个部分可以倍增完成,然后我们对于每个询问在合并到\(y\)的时候处理即可。
对于刚开始不在区间内的直接特判。
关于线段树合并的部分,对于每个节点\(x\)的线段树,我们先在\(t_x\)的位置设置初值\(1\)代表这棵线段树第\(t_x\)的位置有\(1\)个数,然后不断将每个位置的节点个数向上合并。
#include <bits/stdc++.h> using namespace std; const int MAXN = 1e5 + 5; struct SegmentTreeNode { int l, r; int val; }tree[MAXN * 50]; int idx; class SegmentTree { #define ls(x) tree[x].l #define rs(x) tree[x].r private: int left, right; // 维护的值域 inline void pushup(int p) { tree[p].val = tree[ls(p)].val + tree[rs(p)].val; } inline void insert(int i, int l, int r, int &p) { if (!p) p = ++idx; if (l == r) { ++tree[p].val; return ; } int mid = (l + r) >> 1; if (i <= mid) insert(i, l, mid, ls(p)); else insert(i, mid + 1, r, rs(p)); pushup(p); } inline int merge(int l, int r, int x, int y) { if (!x || !y) return x + y; if (l == r) { tree[x].val += tree[y].val; return x; } int mid = (l + r) >> 1; ls(x) = merge(l, mid, ls(x), ls(y)); rs(x) = merge(mid + 1, r, rs(x), rs(y)); pushup(x); return x; } inline int query(int ql, int qr, int l, int r, int p) { if (!p) return 0; if (ql <= l && r <= qr) { return tree[p].val; } int mid = (l + r) >> 1, res = 0; if (ql <= mid) res += query(ql, qr, l, mid, ls(p)); if (qr > mid) res += query(ql, qr, mid + 1, r, rs(p)); return res; } public: int root; inline SegmentTree() { root = 0; } inline SegmentTree(int l, int r) { root = 0; left = l, right = r; } inline void setRange(int l, int r) { left = l, right = r; } inline void insert(int i) { insert(i, left, right, root); } inline void merge(SegmentTree x) { merge(left, right, root, x.root); } inline int query(int l, int r) { return query(l, r, left, right, root); } }; SegmentTree od[MAXN]; vector<int> e[MAXN]; void addedge(int u, int v) { e[u].push_back(v); e[v].push_back(u); } struct Query { int idx, l, r; }; vector<Query> Q[MAXN]; int w[MAXN], ans[MAXN]; // w存储温度, ans为每个询问的答案 int Log2[MAXN]; void init() { for (int i = 1; i < MAXN; ++i) { Log2[i] = Log2[i - 1] + (1 << Log2[i - 1] == i); } } int dp[MAXN][65], depth[MAXN]; void dfs2(int u, int fa) { dp[u][0] = fa; depth[u] = depth[fa] + 1; for (int i = 1; i <= Log2[depth[u]]; ++i) { int s = dp[u][i - 1]; dp[u][i] = dp[s][i - 1]; } for (int v: e[u]) { if (v == fa) continue; dfs2(v, u); } } int lca(int x, int r) { for (int i = Log2[depth[x]] - 1; i >= 0; --i) { if (dp[x][i] && w[dp[x][i]] <= r) { x = dp[x][i]; } } return x; } void dfs(int u, int fa) { for (int v: e[u]) { if (v == fa) continue; dfs(v, u); od[u].merge(od[v]); } for (auto [idx, l, r]: Q[u]) { ans[idx] = od[u].query(l, r); } } int main(int argc, char *argv[]) { ios::sync_with_stdio(false), cin.tie(0), cout.tie(0); init(); int n; cin >> n; for (int i = 1; i < n; ++i) { int u, v; cin >> u >> v; addedge(u, v); } for (int i = 1; i <= n; ++i) { cin >> w[i]; od[i].setRange(1, 1e9); od[i].insert(w[i]); } dfs2(1, 0); int q; cin >> q; for (int i = 1; i <= q; ++i) { int x, l, r; cin >> x >> l >> r; if (l <= w[x] && w[x] <= r) { Q[lca(x, r)].push_back({i, l, r}); } else { ans[i] = 0; } } dfs(1, 0); for (int i = 1; i <= q; ++i) { cout << ans[i] << '\n'; } system("pause"); return 0; }