给定一棵 \(n\) 个点的树,每次从还活着的节点中随机选出一个点,把计数器加上其所在的树的大小并把这个点以及与之相连的边删除,求整棵树都被干掉时计数器上数字的期望。
\(n\leq 10^5\)
思路不难非常想,调代码过于恶心
考虑设 \(a_i\) 表示 \(i\) 的贡献的期望,也就是选 \(i\) 时所在树的大小,那么有:
\[a_i=\sum_j \frac1{\text{dist}(i,j)+1} \]所以答案就变成了
\[ans=E(\sum_{i=1}^n a_i)=\sum_{i=1}^nE(a_i)=\sum_{i=1}^n\sum_j\frac1{\text{dist}(i,j)+1} \]显然可以直接 \(O(n^2)\) 暴力求解,考虑优化
发现问题即是在求树上点对信息,考虑淀粉质维护
对于当前分治中心的每棵子树,设生成函数 \(F(x)=\sum cnt_ix^i\) 其中 \(cnt_i\) 表示到分治中心距离 \(i\) 的点的数量,卷一卷就可以了,时间复杂度 \(O(nlog^2n)\)
人调傻了...
#include<bits/stdc++.h> #define int long long using namespace std; const int N = 4e5+5; const int mod = 1e9+7; const double pi = acos(-1.0); char buf[1<<23],*p1=buf,*p2=buf; #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++) int read(){ int s=0,w=1; char ch=getchar(); while(!isdigit(ch)){if(ch=='-')w=-1;ch=getchar();} while(isdigit(ch))s=s*10+ch-'0',ch=getchar(); return s*w; } int ksm(int a,int b){ int res=1; for(;b;b>>=1,a=a*a%mod) if(b&1)res=res*a%mod; return res; } struct FFT{ struct cplx{ double x,y; cplx(double x=0,double y=0):x(x),y(y){} cplx operator+(const cplx&rhs)const{return cplx(x+rhs.x,y+rhs.y);} cplx operator-(const cplx&rhs)const{return cplx(x-rhs.x,y-rhs.y);} cplx operator*(const cplx&rhs)const{return cplx(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);} }a[N<<1]; int rev[N<<1],n; void init(int len){ for(n=1;n<len;n<<=1); for(int i=0;i<n;++i)rev[i]=(rev[i>>1]>>1)|((i&1)?(n>>1):0); } void work(cplx *a,int op){ for(int i=0;i<n;++i) if(i<rev[i])swap(a[i],a[rev[i]]); for(int mid=1;mid<n;mid<<=1){ cplx wn(cos(pi/mid),op*sin(pi/mid)); for(int r=mid<<1,j=0;j<n;j+=r){ cplx w(1,0),x,y; for(int k=0;k<mid;++k,w=w*wn) x=a[j+k],y=w*a[j+mid+k], a[j+k]=x+y,a[j+mid+k]=x-y; } } } void calc(int* b){ for(int i=0;i<n;++i)a[i]=cplx(b[i],0); work(a,1); for(int i=0;i<n;++i)a[i]=a[i]*a[i]; work(a,-1); for(int i=0;i<n;++i)b[i]=(int)(a[i].x/n+0.5); } }fft; int siz[N],mxp[N],t1[N<<1],t2[N<<1],ans[N<<1]; vector<int>Edge[N]; int n,rt,mx1,mx2,as; bool vis[N]; void getrt(int u,int fath,int all){ siz[u]=1,mxp[u]=0; for(auto v:Edge[u])if(v!=fath&&!vis[v]) getrt(v,u,all),siz[u]+=siz[v],mxp[u]=max(mxp[u],siz[v]); mxp[u]=max(mxp[u],all-siz[u]); if(mxp[u]<mxp[rt])rt=u; } void dfs(int u,int fath,int dep){ t2[dep]++,mx2=max(mx2,dep); for(auto v:Edge[u])if(v!=fath&&!vis[v])dfs(v,u,dep+1); } void calc(int u){ t1[0]=1,mx1=0; for(auto v:Edge[u])if(!vis[v]){ mx2=0,dfs(v,u,1),mx1=max(mx1,mx2); // printf(" v:%lld\n bef:\n",v); for(int i=0;i<=mx2;++i)t1[i]+=t2[i]/*,cout<<t2[i]<<" "*/;//puts(""); fft.init((mx2+1)<<1),fft.calc(t2); // printf(" now:\n"); for(int i=0;i<=(mx2<<1)+1;++i)ans[i]-=t2[i]/*,cout<<t2[i]<<" "*/;//puts(""); memset(t2,0,(mx2*2+5)<<3); } // printf(" u:%lld\n bef:\n",u); // for(int i=0;i<=mx1;++i)cout<<t1[i]<<" ";puts(""); fft.init((mx1+1)<<1),fft.calc(t1); for(int i=0;i<=(mx1<<1);++i)ans[i]+=t1[i]/*,cout<<t1[i]<<" "*/;//puts(""); memset(t1,0,(mx1*2+5)<<3); } void divide(int u){ vis[u]=1,calc(u); for(auto v:Edge[u])if(!vis[v]) mxp[rt=0]=n,getrt(v,0,siz[v]),divide(rt); } signed main(){ n=read(); for(int i=1,u,v;i<n;++i) u=read(),v=read(), Edge[u].push_back(v), Edge[v].push_back(u); mxp[rt=0]=n,getrt(1,0,n),divide(rt); // printf("ans:\n"); for(int i=0;i<n;++i)(as+=ans[i]%mod*ksm(i+1,mod-2)%mod)%=mod,/*cout<<ans[i]<<" "*/;//puts(""); //cout<<as<<endl; for(int i=1;i<=n;++i)(as*=i)%=mod; printf("%lld\n",as); return 0; }