又是向杨大佬学习的一天
求树上所有叶子节点距离的平方和
就是求上面这个式子
求上面这个式子
----------------------------------------------------------------------------------------------------------
这两个题挺相似的,如果都考虑朴素做法的话,都是两层for枚举两个节点,复杂度已经到n^2了
所以考虑树上点分治,一种logn的复杂度解决树上静态询问的问题的数据结构(我瞎说的)
动态的话就要点分树了(学长跟我说的)
这个题我们可以先处理子树间的贡献,这样的话就不需要常规分治算法的合并的这个步骤了
点分治是要每次找一个新的重心
每次以新重心为根节点处理一部分子树
每次找子树的重心最多找logN次,那么需要一个O(N) 的计算贡献的方法
如果每次新来一个叶子节点,我需要和每个已经处理好的一部分叶子节点算贡献的话,如果枚举处理好的那部分
复杂度又是O(N^2)的
假设当前重心是root
因为我们选择先处理子树间的贡献,这样可以不用最后的合并操作
设已经处理好的一部分子树有叶子节点pre个,距离的平方和为sum1,距离和为sum2
我们展开一下(a[i]+a[j])^2这个式子(新来的是a[j]这个叶子节点)
a[i]^2 + 2*a[i]*a[j] + a[j]^2(下称第一项第二项第三项)
因为a[i]_1 a[i]_2....a[i]_cnt都要算一遍
第一项的贡献就是Sum1,第二项的贡献就是2*sum2*a[j] (这里O(n)枚举就行),第三项的贡献是a[j]*a[j]*pre(同第二项枚举计算就好)
// Problem: Gene Tree // Contest: NowCoder // URL: https://ac.nowcoder.com/acm/contest/15644/B // Memory Limit: 524288 MB // Time Limit: 2000 ms // // Powered by CP Editor (https://cpeditor.org) #include <bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int, int> PII; typedef unsigned long long ull; const int inf = 0x3f3f3f3f; const long long INF = 1e18; const int maxn = 2e5 + 7; const ll mod = 1e9 + 7; #define pb push_back #define debug(x) cout << #x << ":" << x << endl; #define mst(x, a) memset(x, a, sizeof(x)) #define rep(i, a, b) for (int i = (a); i <= (b); ++i) #define dep(i, a, b) for (int i = (a); i >= (b); --i) inline ll read() { ll x = 0; bool f = 0; char ch = getchar(); while (ch < '0' || '9' < ch) f |= ch == '-', ch = getchar(); while ('0' <= ch && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); return f ? -x : x; } void out(ll x) { int stackk[20]; if (x < 0) { putchar('-'); x = -x; } if (!x) { putchar('0'); return; } int top = 0; while (x) stackk[++top] = x % 10, x /= 10; while (top) putchar(stackk[top--] + '0'); } ll qpow(ll a, ll b) { ll ans = 1; while (b) { if (b & 1) ans = ans * a % mod; a = a * a % mod; b >>= 1; } return ans; } #define int ll int n,cnt,head[maxn]; struct node{ int u,v,w,next; }e[maxn]; void add(int u,int v,int w) { e[cnt].u=u,e[cnt].v=v,e[cnt].w=w; e[cnt].next = head[u],head[u]=cnt++; } int d[maxn]; int vis[maxn];// the root is visit int maxson[maxn]; int siz[maxn],Smer,Mx,root; int ans; void getroot(int u,int p) { siz[u] = 1,maxson[u]=0; for(int i = head[u];~i;i=e[i].next) { int v = e[i].v; if(v==p||vis[v]) continue; getroot(v,u); siz[u]+=siz[v]; if(maxson[u]<siz[v]) maxson[u] = siz[v]; } maxson[u] = max(maxson[u],Smer - siz[u]); if(maxson[u]<Mx) Mx = maxson[u],root = u; } int temp[maxn]; void solve(int u,int p,int len) { if(d[u]==1) { temp[++cnt] = len; return ; } for(int i = head[u];~i;i=e[i].next) { int v = e[i].v; if(vis[v]||v==p) continue; solve(v,u,len+e[i].w); } } void Divide(int tr) { //solve(tr,0,0); vis[tr]=1; int sum1=0,sum2=0; int pre=0; for(int i = head[tr];~i;i=e[i].next) { int v = e[i].v; if(vis[v]) continue; cnt=0; solve(v,0,e[i].w); for(int j=1;j<=cnt; j++) ans+=(sum1+2ll*temp[j]*sum2+temp[j]*temp[j]*pre); for(int j=1;j<=cnt; j++) sum1+=(temp[j]*temp[j]),sum2+=temp[j],pre++; Smer = siz[v];root=0; Mx=inf;getroot(v,0); Divide(root); } } #define int int int main() { // ios::sync_with_stdio(false); mst(head,-1); n = read(); for(int i=1 ;i<n ;i++) { ll u,v,w; u = read(),v=read(),w=read(); add(u,v,w),add(v,u,w); d[u]++,d[v]++; } Mx = inf,Smer = n; getroot(1,0),Divide(root); out(ans); return 0; } /* */View Code
跟上面一样
step1:先处理子树之间的贡献
step2:写一个式子,如果新来了一个点,如何计算和之前所有点的贡献
step3:AC
计算贡献的方法,首先,我们对当前处理的所有子树的点,按照权值大小排序
假设一共有m个节点,新来的是二号节点,那么二号节点最为最小值影响到的是[3,m]的节点
val = (dis[2][root] + dis[k][root])*minn -->这个式子可以用距离和去维护,然后维护一个前缀值就能算区间的了
But
这里我们处理子树之间的时候,因为不同于上一个题是叶子节点,这个题是所有节点,
所以会处理了一部分相同子树内部的节点,这部分贡献是要减去的
#include <bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int, int> PII; typedef unsigned long long ull; const int inf = 0x3f3f3f3f; const int maxn = 2e5 + 7; const ll mod = 998244353 ; #define pb push_back #define debug(x) cout << #x << ":" << x << endl; #define mst(x, a) memset(x, a, sizeof(x)) #define rep(i, a, b) for (int i = (a); i <= (b); ++i) #define dep(i, a, b) for (int i = (a); i >= (b); --i) inline ll read() { ll x = 0; bool f = 0; char ch = getchar(); while (ch < '0' || '9' < ch) f |= ch == '-', ch = getchar(); while ('0' <= ch && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); return f ? -x : x; } void out(ll x) { int stackk[20]; if (x < 0) { putchar('-'); x = -x; } if (!x) { putchar('0'); return; } int top = 0; while (x) stackk[++top] = x % 10, x /= 10; while (top) putchar(stackk[top--] + '0'); } ll qpow(ll a, ll b) { ll ans = 1; while (b) { if (b & 1) ans = ans * a % mod; a = a * a % mod; b >>= 1; } return ans; } int n,cnt,head[maxn*2]; struct node{ int u,v,w,next; }e[maxn*2]; void add(int u,int v,int w) { e[cnt].u=u,e[cnt].v=v,e[cnt].w=w; e[cnt].next = head[u],head[u]=cnt++; } int vis[maxn];// the root is visit int maxson[maxn]; int siz[maxn],Smer,Mx,root; int ans; void getroot(int u,int p) { siz[u] = 1,maxson[u]=0; for(int i = head[u];~i;i=e[i].next) { int v = e[i].v; if(v==p||vis[v]) continue; getroot(v,u); siz[u]+=siz[v]; if(maxson[u]<siz[v]) maxson[u] = siz[v]; } maxson[u] = max(maxson[u],Smer - siz[u]); if(maxson[u]<Mx) Mx = maxson[u],root = u; } int val[maxn],sum[maxn]; struct Node{ int id,dis,val; }no[maxn]; int tot; bool cmp(Node x ,Node y ){return x.val<y.val;} void get_dis(int u,int p,int len) { no[++tot] = {u,len,val[u]}; for(int i=head[u];~i;i=e[i].next) { int v = e[i].v; if(v==p||vis[v]) continue; get_dis(v,u,len+e[i].w); } } ll solve(int u,int p,int len) { tot=0; get_dis(u,p,len); sort(no+1,no+1+tot,cmp); ll temp=0; for(int i=1 ;i<=tot; i++) sum[i] = (sum[i-1] + no[i].dis)%mod; for(int i=1 ;i<=tot ; i++) { temp=((temp+((tot-i+1)%mod*(no[i].val*no[i].dis)%mod)%mod+(no[i].val*(sum[tot] - sum[i])))%mod+mod)%mod; } return temp*2%mod; } void Divide(int tr) { ans+=solve(tr,0,0); ans%=mod; vis[tr]=1; for(int i = head[tr];~i;i=e[i].next) { int v = e[i].v; if(vis[v]) continue; ans=((ans - solve(v,0,e[i].w))%mod + mod )%mod; Smer = siz[v];root=0; Mx=inf;getroot(v,0); Divide(v); } } int main() { // ios::sync_with_stdio(false); mst(head,-1); n = read(); rep(i,1,n) val[i] = read(); for(int i=1 ;i<n ;i++) { ll u,v,w; u = read(),v=read(),w=1; add(u,v,w),add(v,u,w); } Mx = inf,Smer = n; getroot(1,0),Divide(root); out(ans); return 0; } /* */View Code