虑$u$对答案的贡献(指以$u$为第一次战争)——
注意到$v$崛起时有贡献,当且仅当上一次崛起在$u$与$v$不同的儿子中(将$u$自身也看作一棵子树)
换言之,问题可以抽象为有$A_{i}$个$i$,将这$\sum_{i=1}^{k}A_{i}$个数任意排列后最大交替次数
记$x=\max_{1\le i\le k}A_{i},s=\sum_{i=1}^{k}A_{i}$,则最大交替次数为$\begin{cases}s-1&2x\le s\\2(s-x)&2x>s\end{cases}$(具体可以归纳证明)
代入原问题,有$x=\max(a_{u},S_{son})$且$s=S_{u}$,进而即可$o(n)$求出答案
进一步的,称$son$为$u$的重儿子当且仅当$2S_{son}>S_{u}$(显然一个点至多有一个重儿子)
由于$w_{i}\ge 0$,即修改仅会使得$x_{i}$到根节点上某些非重儿子变为重儿子(代替原来的重儿子,若存在)
不难发现,这与access的过程是类似地(但并不一定将所有边均变为实边),以此实现即可
关于复杂度,LCT中splay和虚边的复杂度是独立的,并分别考虑:
前者仍是$o(n\log n+m\log n)$,后者根据权值和翻倍的性质为$o(m\log V)$(其中$V$为值域)
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 400005 4 #define ll long long 5 vector<int>e[N]; 6 int n,m,x,y,fa[N],ch[N][2]; 7 ll ans,a[N],sz[N],sum_vir[N],sum[N]; 8 int which(int k){ 9 return ch[fa[k]][1]==k; 10 } 11 bool check(int k){ 12 return ch[fa[k]][which(k)]!=k; 13 } 14 void up(int k){ 15 sz[k]=a[k]+sum[ch[k][1]]+sum_vir[k]; 16 sum[k]=sz[k]+sum[ch[k][0]]; 17 } 18 void rotate(int k){ 19 int f=fa[k],g=fa[f],p=which(k); 20 fa[k]=g; 21 if (!check(f))ch[g][which(f)]=k; 22 fa[ch[k][p^1]]=f,ch[f][p]=ch[k][p^1]; 23 fa[f]=k,ch[k][p^1]=f,up(f),up(k); 24 } 25 void splay(int k){ 26 for(int i=fa[k];!check(k);i=fa[k]){ 27 if (!check(i)){ 28 if (which(i)^which(k))rotate(k); 29 else rotate(i); 30 } 31 rotate(k); 32 } 33 } 34 ll get_ans(int k){ 35 if ((a[k]<<1)>sz[k])return (sz[k]-a[k]<<1); 36 if (!ch[k][1])return sz[k]-1; 37 return (sz[k]-sum[ch[k][1]]<<1); 38 } 39 void dfs(int k,int f){ 40 fa[k]=f; 41 for(int i=0;i<e[k].size();i++) 42 if (e[k][i]!=f)dfs(e[k][i],k),sz[k]+=sz[e[k][i]]; 43 sum_vir[k]=sz[k],sz[k]+=a[k],sum[k]=sz[k]; 44 for(int i=0;i<e[k].size();i++) 45 if ((e[k][i]!=f)&&((sum[e[k][i]]<<1)>sz[k])){ 46 ch[k][1]=e[k][i]; 47 sum_vir[k]-=sum[e[k][i]]; 48 } 49 ans+=get_ans(k); 50 } 51 void update(int k,int x){ 52 for(int i=k;i;i=fa[i])splay(i),ans-=get_ans(i); 53 a[k]+=x,up(k); 54 if ((sum[ch[k][1]]<<1)<=sz[k]){ 55 sum_vir[k]+=sum[ch[k][1]]; 56 ch[k][1]=0; 57 } 58 ans+=get_ans(k); 59 for(int i=k;fa[i];i=fa[i]){ 60 sum_vir[fa[i]]+=x,up(fa[i]); 61 if ((sum[ch[fa[i]][1]]<<1)<=sz[fa[i]]){ 62 sum_vir[fa[i]]+=sum[ch[fa[i]][1]]; 63 ch[fa[i]][1]=0; 64 if ((sum[i]<<1)>sz[fa[i]]){ 65 sum_vir[fa[i]]+=sum[ch[fa[i]][1]]-sum[i]; 66 ch[fa[i]][1]=i; 67 } 68 } 69 ans+=get_ans(fa[i]); 70 } 71 } 72 int main(){ 73 scanf("%d%d",&n,&m); 74 for(int i=1;i<=n;i++)scanf("%lld",&a[i]); 75 for(int i=1;i<n;i++){ 76 scanf("%d%d",&x,&y); 77 e[x].push_back(y),e[y].push_back(x); 78 } 79 dfs(1,0),printf("%lld\n",ans); 80 for(int i=1;i<=m;i++){ 81 scanf("%d%d",&x,&y); 82 update(x,y),printf("%lld\n",ans); 83 } 84 return 0; 85 }View Code