传送门
有几档暴力不会写,巨丢人
\(m=2\) 的话两个人之间的距离会覆盖整棵树上所有可能的路径,所以就是求所有树上路径长度的总和
成链且 \(m\) 为奇数的话,集中点肯定是中位数那个点
考场上想偏了,只会用这个性质求一些给定的人应该集中在哪个点
但实际上可以枚举中位数这个点,求出一共有多少匹配的方案
然后正解
点并不好考虑,所以考虑边
如果一条边两边人数不等那数量较少的那些人肯定都得经过这条边
令 \(s\) 为这条边一边的人数
于是一条边的贡献为 \(\sum\limits_{i=1}^{m-1}\binom{s}{i}\binom{n-s}{m-i}min(i, m-i)\)
直接求复杂度会炸,于是转化一下,先考虑弄掉那个min
等价于 \(\sum\limits_{i=1}^{\frac{m-1}{2}}\binom{s}{i}\binom{n-s}{m-i}i + \binom{n-s}{i}\binom{s}{m-i}i\)
令 \(k=\frac{m-1}{2}\)
观察这个式子 \(\sum\limits_{i=1}^{k}\binom{s}{i}\binom{n-s}{m-i}i\),试着把外面的 \(i\) 去掉
于是令 \(G(s)=\sum\limits_{i=1}^{k}\binom{s-1}{i-1}\binom{n-s}{m-i}\),则原式等于 \(s*G(s)\)
考虑组合意义,即为在 \(n-1\) 个物品里选 \(m-1\) 个,要求前 \(s-1\) 中最多能选 \(k-1\) 个的方案数
推到 \(G(s+1)\) 的话会少了前 \(s-1\) 中选了 \(k-1\) 个,且第 \(s+1\) 个也被选了的情况
于是可以 \(O(n)\) 递推
与这个类似的另一部分实际上就是 \(G(n-s)\),不用另算了
#include <bits/stdc++.h> using namespace std; #define INF 0x3f3f3f3f #define N 1000010 #define ll long long // #define int long long char buf[1<<21], *p1=buf, *p2=buf; #define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf, 1, 1<<21, stdin)), p1==p2?EOF:*p1++) inline int read() { int ans=0, f=1; char c=getchar(); while (!isdigit(c)) {if (c=='-') f=-f; c=getchar();} while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48), c=getchar(); return ans*f; } int n, m; int head[N], size; const ll mod=1000000007; struct edge{int to, next;}e[N<<1]; inline void add(int s, int t) {e[++size].to=t; e[size].next=head[s]; head[s]=size;} namespace force{ int vis, dp[30], siz[30], minn; ll ans; void dfs1(int u) { siz[u]=(vis&(1<<(u-1)))?1:0; for (int i=head[u],v; ~i; i=e[i].next) { v = e[i].to; dfs1(v); siz[u]+=siz[v]; dp[u]+=dp[v]+siz[v]; } } void dfs2(int u, int sum) { // cout<<"dfs2: "<<u<<' '<<sum<<endl; minn=min(minn, sum+dp[u]); for (int i=head[u],v; ~i; i=e[i].next) { v = e[i].to; dfs2(v, sum+dp[u]-dp[v]-siz[v]+(m-siz[v])); } } void solve() { memset(head, -1, sizeof(head)); for (int i=2; i<=n; ++i) add(read(), i); int lim=1<<n; for (int s=1,s2,cnt; s<lim; ++s) { s2=s; cnt=0; do {++cnt; s2&=s2-1;} while (s2) ; if (cnt!=m) goto jump; vis=s; // cout<<"s: "<<bitset<5>(s)<<endl; memset(dp, 0, sizeof(dp)); // memset(siz, 0, sizeof(siz)); dfs1(1); minn=INF; dfs2(1, 0); // cout<<"siz: "; for (int i=1; i<=n; ++i) cout<<siz[i]<<' '; cout<<endl; // cout<<"dp: "; for (int i=1; i<=n; ++i) cout<<dp[i]<<' '; cout<<endl; // cout<<"minn: "<<minn<<endl; ans=(ans+minn)%mod; jump: ; } printf("%lld\n", ans); exit(0); } } namespace task1{ int siz[N]; ll fac[N], inv[N], ans; inline ll C(int n, int k) {return n<k?0ll:fac[n]*inv[n-k]%mod*inv[k]%mod;} void dfs(int u) { siz[u]=1; for (int i=head[u],v; ~i; i=e[i].next) { v = e[i].to; dfs(v); siz[u]+=siz[v]; } } void solve() { memset(head, -1, sizeof(head)); for (int i=2; i<=n; ++i) add(read(), i); dfs(1); fac[0]=fac[1]=1; inv[0]=inv[1]=1; for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod; for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod; for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod; for (int i=1; i<n; ++i) { for (int j=1; j<m; ++j) { int s=siz[e[i].to]; ans=(ans+C(s, j)*C(n-s, m-j)%mod*min(j, m-j)%mod)%mod; } } printf("%lld\n", ans); exit(0); } } namespace task{ int siz[N]; ll fac[N], inv[N], ans, G[N], H[N]; inline ll C(int n, int k) {return n<k?0ll:fac[n]*inv[n-k]%mod*inv[k]%mod;} inline ll H2(int s) { int k=(m-1)/2; ll ans=0; for (int i=1; i<=k; ++i) ans=(ans+C(n-s-1, i-1)*C(s, m-i)%mod)%mod; return ans; } inline ll G2(int s) { int k=(m-1)/2; ll ans=0; for (int j=1; j<=k; ++j) ans=(ans+C(s-1, j-1)*C(n-s, m-j)%mod)%mod; return ans; } void dfs(int u) { siz[u]=1; for (int i=head[u],v; ~i; i=e[i].next) { v = e[i].to; dfs(v); siz[u]+=siz[v]; } } void solve() { memset(head, -1, sizeof(head)); for (int i=2; i<=n; ++i) add(read(), i); dfs(1); fac[0]=fac[1]=1; inv[0]=inv[1]=1; for (int i=2; i<=n; ++i) fac[i]=fac[i-1]*i%mod; for (int i=2; i<=n; ++i) inv[i]=(mod-mod/i)*inv[mod%i]%mod; for (int i=2; i<=n; ++i) inv[i]=inv[i-1]*inv[i]%mod; G[1]=G2(1); int k=(m-1)/2; for (int i=1; i<n; ++i) G[i+1]=(G[i]-C(i-1, k-1)*C(n-i-1, m-k-1)%mod)%mod; #if 0 cout<<"G: "; for (int i=1; i<=n; ++i) cout<<(G[i]+mod)%mod<<' '; cout<<endl; for (int i=1; i<=n; ++i) { ll t=0; for (int j=1; j<=k; ++j) t=(t+C(i-1, j-1)*C(n-i, m-j)%mod)%mod; cout<<t<<' '; } cout<<endl; #endif #if 1 // H[n-1]=H2(n-1); // for (int i=n-1; i; --i) H[i-1]=(H[i]-C(n-i-1, k-1)*(i-1, m-k-1)%mod)%mod; // for (int i=1; i<=n; ++i) H[i+1]=(H[i]+C(n-i-1, k-1)*(i-1, m-k-1)%mod)%mod; // cout<<"H2: "; for (int i=1; i<=n; ++i) cout<<H2(i)<<' '; cout<<endl; // cout<<"H: "; for (int i=1; i<=n; ++i) cout<<H[i]<<' '; cout<<endl; #endif for (int i=1; i<n; ++i) { int s=siz[e[i].to]; // cout<<"H: "<<H2(s)<<endl; ans=(ans+s*G[s]%mod+(n-s)*G[n-s]%mod+((m&1)?0:(C(s, m/2)*C(n-s, m/2)%mod*(m/2)%mod)))%mod; } printf("%lld\n", (ans%mod+mod)%mod); exit(0); } } signed main() { freopen("meeting.in", "r", stdin); freopen("meeting.out", "w", stdout); n=read(); m=read(); // force::solve(); task::solve(); return 0; }