Link!
给出\(n\)个模式串(长度\(<= 5\))
定义\(chainword\)为满足下三个条件的字符串和一对划分
要求\(chainword\)的个数。
\(n <= 8,m <= 10^9\)
我先是想到了一个不太可做的\(dp\),记录后四位字符然后每次转移一个模式串,合法的后四位字符并不多,便考虑矩阵优化,但是问题在于,由于转移的是模式串,\(chainword\)的长度每次不是增加\(1\),不能用矩阵转移。所以我们要考虑设计一个可以每次转移一个字符,状态数不多,且能保证最后答案是一整个一整个模式串转移来的状态。
题解给出了一个巧妙的状态,设\(dp_{i,u,v}\)表示\(chainword\)长度为\(i\),第一个划分走到了\(u\),第二个走到了\(v\),其中\(u,v\)是模式串字典树上的节点。每次枚举字符转移,最终答案是\(dp_{m,rt,rt}\),\(rt\)是字典树根节点。\(u,v\)到根形成的字符串,一定有一个是另一个的后缀,这样的\(u,v\)是不多的。关于个数的计算,我们考虑对反串建字典树,\(u,v\)就是树上的一对祖先和后代,及每个节点的深度之和。不失一般性地,我们假设\(u <= v\)可以算出这样的\(u,v\)只有\(161\)对。求出转移矩阵后快速幂求解即可。
#include<bits/stdc++.h> #define ll long long #define N #define mod 998244353 #define rep(i,a,n) for (int i=a;i<=n;i++) #define per(i,a,n) for (int i=n;i>=a;i--) #define inf 0x3f3f3f3f #define pb push_back #define mp make_pair #define pii pair<int,int> #define fi first #define se second #define lowbit(i) ((i)&(-i)) #define VI vector<int> #define all(x) x.begin(),x.end() #define SZ(x) ((int)x.size()) #define end qwq using namespace std; int n,m; namespace Trie{ int tr[405][26],rt = 0,end[405],cnt; void insert(char *s){ int cur = rt; for(int i = 1;s[i];++i){ if(!tr[cur][s[i]-'a']) tr[cur][s[i]-'a'] = ++cnt; cur = tr[cur][s[i]-'a']; } end[cur] = 1; } } using namespace Trie; struct matrix{ int a[205][205]; matrix(){memset(a,0,sizeof a);} int* operator[](int i){return a[i];} matrix operator*(matrix lhs) const{ matrix res; rep(i,0,200) rep(j,0,200) rep(k,0,200){ res[i][k] = (res[i][k] + 1ll*a[i][j]*lhs[j][k])%mod; } return res; } }base; queue<pii> Q; map<pii,int> id; int tot; int get(pii x){ if(x.fi > x.se) swap(x.fi,x.se); if(id.count(x) > 0) return id[x]; else{ id[x] = tot; Q.push(x); return tot++; } } matrix qpow(matrix a,int b){ matrix res; rep(i,0,200) res[i][i] = 1; while(b){ if(b&1) res = res*a; a = a*a; b >>= 1; } return res; } int main(){ //freopen(".in","r",stdin); //freopen(".out","w",stdout); scanf("%d%d",&n,&m); rep(i,1,n){ char s[105]; scanf("%s",s+1); insert(s); } // printf("%d\n",cnt); get(mp(0,0)); while(!Q.empty()){ pii u = Q.front(); Q.pop(); int x = u.fi,y = u.se,cid = get(mp(x,y)); // printf("%d %d %d\n",x,y,cid); rep(i,0,25){ int tx = tr[x][i],ty = tr[y][i]; if(!tx || !ty) continue; base[cid][get(mp(tx,ty))]++; if(end[tx]) base[cid][get(mp(0,ty))]++; if(end[ty]) base[cid][get(mp(0,tx))]++; if(end[tx] && end[ty]) base[cid][get(mp(0,0))]++; } } base = qpow(base,m); printf("%d\n",base[0][0]); return 0; }