基数+计数排序
不考虑暴力了,直接搞上正解。
我们设 \(sa[i],rk[i]\) 分别表示第 \(i\) 名的子串初始点在哪,以及以 \(i\) 开头的子串的排名。
我们考虑倍增的做法。先将长度为 \(1\) 的子串排序求出。
然后每次倍增长度,设长度为 \(w\),然后我们对于每个长度为 \(w\) 的子串以当前位置 \(x\) 的长度为 \(w/2\) 的子串的排名为第一关键字,以 \(x+w\) 位置的长度为 \(w/2\) 子串的排名为第二关键字,然后搞一波基数排序即可。
最后 \(w>n\) 时当场终止。
我们发现只是中途在长度刚好大于等于 \(w\) 时终止,也可以求出长度为 \(w\) 的子串的排序。
不难证明,这个复杂度是 \(O(n\log n)\) 的。
给出一个参考代码。
#include<bits/stdc++.h> #define ll long long #define db double #define filein(a) freopen(#a".in","r",stdin) #define fileot(a) freopen(#a".out","w",stdout) #define sky fflush(stdout); #define gc getchar #define pc putchar namespace IO{ inline bool blank(char c){ return c==' ' or c=='\n' or c=='\t' or c=='\r' or c==EOF; } inline void gs(char *s){ char ch=gc(); while(blank(ch) ) {ch=gc();} while(!blank(ch) ) {*s++=ch;ch=gc();} *s=0; } inline void gs(std::string &s){ char ch=gc();s+='#'; while(blank(ch) ) {ch=gc();} while(!blank(ch) ) {s+=ch;ch=gc();} } inline void ps(char *s){ while(*s!=0) pc(*s++); } inline void ps(const std::string &s){ for(auto it:s) if(it!='#') pc(it); } template<class T> inline void read(T &s){ s=0;char ch=gc();bool f=0; while(ch<'0'||'9'<ch) {if(ch=='-') f=1;ch=gc();} while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=gc();} if(ch=='.'){ db p=0.1;ch=gc(); while('0'<=ch&&ch<='9') {s=s+p*(ch^48);ch=gc();} } s=f?-s:s; } template<class T,class ...A> inline void read(T &s,A &...a){ read(s);read(a...); } }; using IO::read; using IO::gs; using IO::ps; const int S=1e6+3; int sa[S],rk[S<<1],lark[S]; int main(){ //filein(a);fileot(a); char *c=new char[S]; gs(c+1); int n=strlen(c+1); int *id=new int[n+3]; int m=std::max(n,300); int *cnt=new int[S+3]; for(int i=0;i<=m;++i) cnt[i]=0; for(int i=1;i<=n;++i) ++cnt[rk[i]=c[i] ]; for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1]; for(int i=n;i>=1;--i) sa[cnt[rk[i] ]--]=i; delete []cnt; cnt=NULL; for(int w=1;w<n;w<<=1){ cnt=new int[m+3]; for(int i=0;i<=m;++i) cnt[i]=0; for(int i=1;i<=n;++i) id[i]=sa[i]; for(int i=1;i<=n;++i) ++cnt[rk[id[i]+w] ]; for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1]; for(int i=n;i>=1;--i) sa[cnt[rk[id[i]+w] ]--]=id[i]; /*----cutline----*/ for(int i=0;i<=m;++i) cnt[i]=0; for(int i=1;i<=n;++i) id[i]=sa[i]; for(int i=1;i<=n;++i) ++cnt[rk[id[i] ] ]; for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1]; for(int i=n;i>=1;--i) sa[cnt[rk[id[i] ] ]--]=id[i]; /*----cutline----*/ for(int i=1;i<=n;++i) lark[i]=rk[i]; int p=0; for(int i=1;i<=n;++i){ if( i!=1 and lark[sa[i] ]==lark[sa[i-1] ] and lark[sa[i]+w]==lark[sa[i-1]+w] ){ rk[sa[i] ]=p; }else{ rk[sa[i] ]=++p; } } m=p; delete []cnt; cnt=NULL; } for(int i=1;i<=n;++i) printf("%d ",sa[i]); pc('\n'); delete []c; c=NULL; delete []cnt; cnt=NULL; return 0; }
我们交一下上面的代码,发现常数巨大,所以我们考虑卡一下常数。
我们上一回求的 \(sa\) 可以直接用。发现后面的位置不存在长度为 \(w\) 的子串,直接倒着打入前几名。
然后其他的直接枚举 \(sa[i]\), 这个是枚举的排名为 \(i\) 的第二关键字的位置,然后 \(sa[i]-w\) 就是当前位置,直接接着打入排名即可。
这个是常数大的次要原因。
每次给 \(rk\) 上值的时候记录一下有几个不同的,我们需要值域就为那么大。
有稍微的常数优化。
我们可以用数组存储一下,防止反复算相同的东西,具体可以看看代码。
如果每个子串排名互不相同就可以退了。
优化后的代码如下。
#include<bits/stdc++.h> #define ll long long #define db double #define filein(a) freopen(#a".in","r",stdin) #define fileot(a) freopen(#a".out","w",stdout) #define sky fflush(stdout); #define Better_IO 1 namespace IO{ inline bool blank(const char &c){ return c==' ' or c=='\n' or c=='\t' or c=='\r' or c==EOF; } #if Better_IO==true char buf[(1<<20)+3],*p1(buf),*p2(buf); char buf2[(1<<20)+3],*p3(buf2); const int lim=1<<20; inline char gc(){ if(p1==p2) p2=(p1=buf)+fread(buf,1,lim,stdin); return p1==p2?EOF:*p1++; } #define pc putchar #else #define gc getchar #define pc putchar #endif inline void gs(char *s){ char ch=gc(); while(blank(ch) ) {ch=gc();} while(!blank(ch) ) {*s++=ch;ch=gc();} *s=0; } inline void gs(std::string &s){ char ch=gc();s+='#'; while(blank(ch) ) {ch=gc();} while(!blank(ch) ) {s+=ch;ch=gc();} } inline void ps(char *s){ while(*s!=0) pc(*s++); } inline void ps(const std::string &s){ for(auto it:s) if(it!='#') pc(it); } template<class T> inline void read(T &s){ s=0;char ch=gc();bool f=0; while(ch<'0'||'9'<ch) {if(ch=='-') f=1;ch=gc();} while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=gc();} if(ch=='.'){ db p=0.1;ch=gc(); while('0'<=ch&&ch<='9') {s=s+p*(ch^48);ch=gc();} } s=f?-s:s; } template<class T,class ...A> inline void read(T &s,A &...a){ read(s);read(a...); } }; using IO::read; using IO::gs; using IO::ps; const int S=1e6+3; int sa[S],rk[S<<1],kc[S],lark[S]; inline bool cmps(int a,int b,int c){ return b!=0 and lark[a]==lark[b] and lark[a+c]==lark[b+c]; } int main(){ //filein(a);fileot(a); char *c=new char[S]; gs(c+1); int n=strlen(c+1); int *id=new int[n+3]; int m=std::max(n,128); int *cnt=new int[S+3]; for(int i=0;i<=m;++i) cnt[i]=0; for(int i=1;i<=n;++i) ++cnt[rk[i]=c[i] ]; for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1]; for(int i=n;i>=1;--i) sa[cnt[rk[i] ]--]=i; delete []cnt; cnt=NULL; for(int w=1;w<n;w<<=1){ cnt=new int[m+3]; int p=0; for(int i=n;i>n-w;--i) id[++p]=i; for(int i=1;i<=n;++i){ if(sa[i]>w) id[++p]=sa[i]-w; } /*----cutline----*/ for(int i=0;i<=m;++i) cnt[i]=0; for(int i=1;i<=n;++i) ++cnt[kc[i]=rk[id[i] ] ]; for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1]; for(int i=n;i>=1;--i) sa[cnt[kc[i] ]--]=id[i]; /*----cutline----*/ for(int i=1;i<=n;++i) lark[i]=rk[i]; m=0; for(int i=1;i<=n;++i){ rk[sa[i] ]=cmps(sa[i],sa[i-1],w)?m:++m; } delete []cnt; cnt=NULL; if(m==n) break; } for(int i=1;i<=n;++i) printf("%d ",sa[i]); pc('\n'); delete []c; c=NULL; delete []cnt; cnt=NULL; return 0; }