题目链接
题意:
给定一棵\(N\)个节点的树,节点编号从\(1\)到\(N\),每个节点都有一个整数权值。
现在,我们要进行\(M\)次询问,格式为\(u\) \(v\),对于每个询问你需要回答从\(u\)到\(v\)的路径上(包括两端点)共有多少种不同的点权值。
思路:
树上莫队,预处理得到树的欧拉序列。
\(dfs\)序列:是指将一棵树被\(dfs\)遍历时所经过的节点顺序,回溯时不再记录。
欧拉序列:\(dfs\)遍历,第一次遇到该点时记录一次时间戳\(first\),回溯时再记录一次时间戳\(last\),由此得到一序列为欧拉序列,其中\(first\)和\(last\)分别记录该点在序列中的第一次遍历和第二次遍历的时间戳。
令\(first[u]\)<\(first[v]\),即先遍历\(u\),对于\(u\)到\(v\)的路径分两种情况:
1.当\(u\)与\(v\)的最近公共祖先为\(u\)时,则从\(u\)到\(v\)的路径上的点即为区间\([first[u], first[v]]\)上的点,且区间内每个点只出现一次。
2.当\(u\)与\(v\)不在一颗子树上时,即它们的\(LCA\)为另一点时,则从\(u\)到\(v\)的路径上的点即为区间\([last[u], first[v]]\)上的只出现过一次的点再加上点\(LCA(u,v)\),(如果\(dfs\)从\(u\)到\(v\)时经过其他点,则在区间内必刚好出现两次,对此可以设置\(used\)数组,不断令\(used\)^\(1\)判断其是否出现两次)。特别注意:当处理区间\([l,r]\)内的点时,需映射到欧拉序列中,得到该点序号再处理,而对于\(LCA(u,v)\),无需映射,因为求得的\(LCA\)本身就是点的序号。
对于每次询问,得到区间左右端点,然后普通莫队离线处理即可。
code:
#include <iostream> #include <cstdio> #include <string> #include <cstring> #include <algorithm> #include <queue> #include <vector> #include <deque> #include <cmath> #include <ctime> #include <map> #include <set> // #include <unordered_map> #define fi first #define se second #define pb push_back // #define endl "\n" #define debug(x) cout << #x << ":" << x << endl; #define bug cout << "********" << endl; #define all(x) x.begin(), x.end() #define lowbit(x) x & -x #define fin(x) freopen(x, "r", stdin) #define fout(x) freopen(x, "w", stdout) #define ull unsigned long long #define ll long long const double eps = 1e-15; const int inf = 0x3f3f3f3f; // const ll INF = 0x3f3f3f3f3f3f3f3f; const double pi = acos(-1.0); const int mod = 998244353; const int maxn = 1e6 + 10; using namespace std; int our[maxn], n, m, s[maxn], p[maxn], ans[maxn]; int dp[maxn][32], dep[maxn], first[maxn], last[maxn]; int vis[maxn], tot, block, ret, used[maxn]; vector<int> v[maxn]; struct node{ int l, r, lca, id; bool operator<(const node &a)const{ return (l/block == a.l/block) ? ((l/block) & 1 ? r < a.r : r > a.r) : l < a.l; } }cnt[maxn]; void dfs(int u, int fa){ dep[u] = dep[fa] + 1, dp[u][0] = fa; our[++ tot] = u; first[u] = tot; for(int i = 1; (1 << i) <= dep[fa]; i ++)dp[u][i] = dp[dp[u][i - 1]][i - 1]; for(auto i : v[u]){ if(i == fa)continue; dfs(i, u); } our[++ tot] = u; last[u] = tot; } int lca(int a, int b){ if(dep[a] < dep[b])swap(a, b); int h = dep[a] - dep[b]; for(int i = 25; i >= 0; i --){ if((h >> i) & 1)a = dp[a][i]; } if(a == b)return a; for(int i = 25; i >= 0; i --){ if(dp[a][i] != dp[b][i])a = dp[a][i], b = dp[b][i]; } return dp[a][0]; } void work(int x){ x = our[x]; if(used[x])vis[s[x]] --, ret -= !vis[s[x]]; else ret += !vis[s[x]], vis[s[x]] ++; used[x] ^= 1; } int main(){ scanf("%d%d", &n, &m); for(int i = 1; i <= n; i ++)scanf("%d", &s[i]), p[i] = s[i]; for(int i = 1, a, b; i < n; i ++){ scanf("%d%d", &a, &b); v[a].pb(b), v[b].pb(a); } block = sqrt(tot); sort(p + 1, p + n + 1); int q = unique(p + 1, p + n + 1) - p - 1; for(int i = 1; i <= n; i ++)s[i] = lower_bound(p + 1, p + q + 1, s[i]) - p; dfs(1, 0); for(int i = 1; i <= m; i ++){ int u, v; scanf("%d%d", &u, &v); int LCA = lca(u, v); if(first[u] > first[v])swap(u, v); if(LCA != u)cnt[i] = {last[u], first[v], LCA, i}; else cnt[i] = {first[u], first[v], 0, i}; } sort(cnt + 1, cnt + m + 1); int l = 1, r = 0; for(int i = 1; i <= m; i ++){ while(r < cnt[i].r)work(++ r); while(r > cnt[i].r)work(r --); while(l < cnt[i].l)work(l ++); while(l > cnt[i].l)work(-- l); if(cnt[i].lca)work(cnt[i].lca); ans[cnt[i].id] = ret; if(cnt[i].lca)work(cnt[i].lca); } for(int i = 1; i <= m; i ++)printf("%d\n", ans[i]); return 0; }