AC
自动机AC
自动机是一种用于解决多模式串以及一主串匹配的字符串算法。
问题通常是给出若干个模式串 S
以及主串 T
,询问若干个模式串分别在主串中的某些信息。
AC
自动机构建在 Trie 的结构基础上,结合了 KMP
算法的失配指针思想。
在进行多模式串匹配前,只有两个步骤需要去实现:
将所有模式串扔进一棵 Trie
树中。
对于 Trie
上的所有节点构建失配指针。
AC
自动机的时间复杂度约为 \(O(n+m)\)。
AC
自动机其实就是在 Trie
树上建几条边然后就没了,真的没了。
有关 Trie
树的知识,请出门右转,点这里。
一开始按照 Trie
树的基本构建方法搭建即可。
请注意,Trie
树节点的含义十分重要:
它表示的是某个模式串的前缀,也就是一个状态。
而 Trie
树的边就是状态的转移。
一般 Trie
树的每个节点都代表一个或多个字符串。
建树的代码如下:
const int MAXN = 500005; int nxt[MAXN][26], cnt; // nxt[i][c] 表示 i 号点所连、存储字符为 c + 'a' 的点的编号 void init() // 初始化 { memset(nxt, 0, sizeof(nxt)); cnt = 1; } void insert(const string &s) // 插入字符串 { int cur = 1; for (auto c : s) { // 尽可能重用之前的路径,如果做不到则新建节点 if (!nxt[cur][c - 'a']) nxt[cur][c - 'a'] = ++cnt; cur = nxt[cur][c - 'a']; // 继续向下 } }
好了,到了最重要的一点了,如何构建 Fail
指针?
什么是 Fail
指针呢?
如果一个 Trie
树上的节点 \(u\) 的 Fail
指针指向节点 \(v\),那么这就表示根节点到节点 \(v\) 的字符串是根节点到节点 \(u\) 的字符串的一个后缀。
注意,根节点的所有非空子节点的 Fail
指针都必须指向根节点。
如果看不懂可参考下面这张图。
例如求根节点 \(0\) 的左子树上的那个 \(c\) 节点的 fail
指针,观察可得,根节点到根节点的右子树上的那个 \(c\) 节点组成的字符串(bc
)是根节点到根节点的左子树上的那个 \(c\) 节点组成的字符串(abc
)的一个后缀,所以 \(fail_{左边的 \ c}=右边的 \ c \ 的编号\)。
再思考如何在程序上构建 Fail
指针。
对于一个 Trie
树上的节点 \(u\),设它的父节点为 \(f\),两个节点通过字符 \(c\) 连接,也就是说 \(trie_{f,c}=u\)。
那么求 Fail
指针有两个情况,如下:
如果 \(trie_{f,c}\) 不是空节点,那么就将节点 \(u\) 的 Fail
指针指向 \(trie_{fail_{f},c}\)(肯定满足 Fail
指针的性质)。
如果 \(trie_{f,c}\) 是空节点,那么我们令 \(son_f=son_{fail_{f}}\),即 \(tr_{f,i} = tr_{fail_{f},i}\)。这样做令 AC
自动机的实现相当于不断拓展字符串的后缀,尝试匹配最后一个字符,如果最后一个字符并不存在,那么我们跳转到下一个可能出现该字符的位置,直到结束为止。
这里的 get_fail
函数将 Trie
树上所有节点按照 BFS
的顺序入队,最后依次求 Fail
指针。
首先我们单独处理根节点,将根节点 \(0\) 的所有非空的子节点入队。
然后每次取出队首处理 Fail
指针,即遍历 \(26\) 个字符依次判断(根据题目判断)。
\(fail_{u}\) 就表示节点 \(u\) 的 Fail
指针指向的节点。
代码如下:
void get_fail() { queue<int> q; for(int i = 0; i < 26; ++i) if(tr[0][i]) { fail[tr[0][i]] = 0; q.push(tr[0][i]); } while(!q.empty()) { int now = q.front(); q.pop(); for(int i = 0; i < 26; ++i) { int v = tr[now][i]; if(v) { fail[v] = tr[fail[now]][i]; q.push(v); } else { tr[now][i] = tr[fail[now]][i]; } } } }
然后就没了。。。
不,肯定还没结束。
给定文本串和若干个模式串,求出有多少个不同的模式串在文本串中出现。
对若干个模式串构建好 AC
自动机后,对文本串的每一个前缀跳一遍 Fail
指针就行了。
因为一个字符串的每一个前缀的所有后缀就是这个字符串的所有子串。
一开始记录每个节点对应多少个完整的模式串就行了。
#include<bits/stdc++.h> using namespace std; #define _ (int)2e6 + 5 int n; int tot; int tr[_][27]; int fail[_]; int tag[_]; int num[_]; char c[_]; void insert(char *c) { int len = strlen(c); int u = 0; for(int i = 0; i < len; ++i) { int v = c[i] - 'a'; if(!tr[u][v]) tr[u][v] = ++tot; u = tr[u][v]; } num[u]++; } void get_fail() { queue<int> q; for(int i = 0; i < 26; ++i) if(tr[0][i]) { fail[tr[0][i]] = 0; q.push(tr[0][i]); } while(!q.empty()) { int now = q.front(); q.pop(); for(int i = 0; i < 26; ++i) { int v = tr[now][i]; if(v) { fail[v] = tr[fail[now]][i]; q.push(v); } else { tr[now][i] = tr[fail[now]][i]; } } } } int query(char *s) { int len = strlen(s); int res = 0; int u = 0; for(int i = 0; i < len; ++i) { int v = s[i] - 'a'; u = tr[u][v]; for(int j = u; j && !tag[j]; j = fail[j]) { res += num[j]; tag[j] = 1; } } return res; } signed main() { // freopen("P3808_2.in", "r", stdin); // freopen("2.out", "w", stdout); scanf("%d", &n); // printf("%d\n", n); for(int i = 1; i <= n; ++i) { scanf("%s", c); // printf("%s\n", c); insert(c); } get_fail(); scanf("%s", c); // printf("%s\n", c); printf("%d\n", query(c)); return 0; }
给出若干个模式串和一个文本串,求某个模式串在文本串中出现的最大次数和该模式串,且保证不存在两个相同的模式串。
我们考虑如何查询最大出现次数。
记 \(num_u\),为以 \(u\) 为结尾的那个唯一的字符串读入时的编号。
最后在统计答案时用一个 \(vis\) 数组存储出现的次数,取最大值。
统计答案的方法上面说过了。
然后遍历 \(vis\) 数组,当 \(vis_i\) 与最大值相同时,就输出第 \(i\) 个模式串。
多测记得清空。(别问为什么,血的教训)
#include<bits/stdc++.h> using namespace std; #define _ (int)5e5 + 5 int n; int tot; int tr[_][27]; int fail[_]; int tag[_]; int num[_]; int vis[_]; char c[_][151]; void insert(char *c, int id) { int len = strlen(c); int u = 0; for(int i = 0; i < len; ++i) { int v = c[i] - 'a'; if(!tr[u][v]) tr[u][v] = ++tot; u = tr[u][v]; } num[u] = id; } void get_fail() { queue<int> q; for(int i = 0; i < 26; ++i) if(tr[0][i]) { fail[tr[0][i]] = 0; q.push(tr[0][i]); } while(!q.empty()) { int now = q.front(); q.pop(); for(int i = 0; i < 26; ++i) { int v = tr[now][i]; if(v) { fail[v] = tr[fail[now]][i]; q.push(v); } else { tr[now][i] = tr[fail[now]][i]; } } } } int query(char *s) { int len = strlen(s); int res = 0; int u = 0; for(int i = 0; i < len; ++i) { int v = s[i] - 'a'; u = tr[u][v]; for(int j = u; j; j = fail[j]) { if(!num[j]) continue; vis[num[j]]++; } } for(int i = 1; i <= n; ++i) res = max(res, vis[i]); return res; } void init() { tot = 0; memset(tag, 0, sizeof tag); memset(vis, 0, sizeof vis); memset(tr, 0, sizeof tr); memset(fail, 0, sizeof fail); memset(num, 0, sizeof num); memset(vis, 0, sizeof vis); } signed main() { while(scanf("%d", &n) && n) { init(); for(int i = 1; i <= n; ++i) { scanf("%s", c[i]); insert(c[i], i); } get_fail(); scanf("%s", c[n + 1]); int ans = query(c[n + 1]); printf("%d\n", ans); for(int i = 1; i <= n; ++i) if(vis[i] == ans) printf("%s\n", c[i]); continue; } return 0; }
给你一个文本串 S
和若干个模式串,请你分别求出每个模式串在 S
中出现的次数。
我们可以建出 AC
自动机后把文本串在上面跑一遍,每到达一个节点就把树上这个节点到根路径上的节点计数器 \(+1\)。
然后建一棵 Fail
树,即连一条有向边 \(fail_{i} \to i\)。
那么,一个模式串,在文本串中出现的次数就是它结束的节点子树的权值和,没了。
#include<bits/stdc++.h> using namespace std; #define _ (int)2e6 + 5 int n; int cnt; int tr[_][27]; int fail[_]; int tag[_]; int num[_]; int vis[_]; int siz[_]; char c[_][250]; int tot, head[_], to[_ << 1], nxt[_ << 1]; void add(int u, int v) { to[++tot] = v; nxt[tot] = head[u]; head[u] = tot; } void insert(char *c, int id) { int len = strlen(c); int u = 0; for(int i = 0; i < len; ++i) { int v = c[i] - 'a'; if(!tr[u][v]) tr[u][v] = ++cnt; u = tr[u][v]; } num[id] = u; } void get_fail() { queue<int> q; for(int i = 0; i < 26; ++i) if(tr[0][i]) { fail[tr[0][i]] = 0; q.push(tr[0][i]); } while(!q.empty()) { int now = q.front(); q.pop(); for(int i = 0; i < 26; ++i) { int v = tr[now][i]; if(v) { fail[v] = tr[fail[now]][i]; q.push(v); } else { tr[now][i] = tr[fail[now]][i]; } } } } void dfs(int u) { for(int i = head[u]; i; i = nxt[i]) { int v = to[i]; dfs(v); siz[u] += siz[v]; } } void query(char *s) { int u = 0; int len = strlen(s); for (int i = 0; i < len; ++i) { int v = s[i] - 'a'; u = tr[u][v]; ++siz[u]; } for(int i = 1; i <= cnt; ++i) add(fail[i], i); dfs(0); for(int i = 1; i <= n; ++i) printf("%d\n", siz[num[i]]); } signed main() { scanf("%d", &n); for(int i = 1; i <= n; ++i) { scanf("%s", c[i]); insert(c[i], i); } get_fail(); scanf("%s", c[n + 1]); query(c[n + 1]); return 0; }
如果主串由许多模式串组成,请你求出知道每个模式串分别在主串中出现了多少次。
请自行到洛谷上看主串的定义。
这题也要建 Fail
树,具体看上面。
首先,定义一个节点的权值为该节点属于的字符串个数。
那么,一个节点表示的字符串,在整个字典树中出现的次数就是子树的权值和,没了。
#include<bits/stdc++.h> using namespace std; #define _ (int)2e6 + 5 int n; int cnt; int tr[_][27]; int fail[_]; int tag[_]; int num[_]; int vis[_]; int siz[_]; char c[_]; int tot, head[_], to[_ << 1], nxt[_ << 1]; void add(int u, int v) { to[++tot] = v; nxt[tot] = head[u]; head[u] = tot; } void insert(char *c, int id) { int len = strlen(c); int u = 0; for(int i = 0; i < len; ++i) { int v = c[i] - 'a'; if(!tr[u][v]) tr[u][v] = ++cnt; u = tr[u][v]; siz[u]++; } num[id] = u; } void get_fail() { queue<int> q; for(int i = 0; i < 26; ++i) if(tr[0][i]) { fail[tr[0][i]] = 0; q.push(tr[0][i]); } while(!q.empty()) { int now = q.front(); q.pop(); for(int i = 0; i < 26; ++i) { int v = tr[now][i]; if(v) { fail[v] = tr[fail[now]][i]; q.push(v); } else { tr[now][i] = tr[fail[now]][i]; } } } } void dfs(int u) { for(int i = head[u]; i; i = nxt[i]) { int v = to[i]; dfs(v); siz[u] += siz[v]; } } void query() { for(int i = 1; i <= cnt; ++i) add(fail[i], i); dfs(0); for(int i = 1; i <= n; ++i) printf("%d\n", siz[num[i]]); } signed main() { scanf("%d", &n); for(int i = 1; i <= n; ++i) { scanf("%s", c); insert(c, i); } get_fail(); query(); return 0; }
要求对于每一个模式串,求出其最长的前缀 \(p\),满足 \(p\) 是文本串的子串。
我们可以先找出文本串的所有子串结束的节点,标记为 \(1\)。
然后对于每一个模式串,判断这个模式串的前缀结束的节点是否被标记为 \(1\),最后取长度的最大值即可。
#include <bits/stdc++.h> using namespace std; #define MAXN (int) 1e7 + 7 #define MAXM (int) 1e5 + 7 #define MAXT (int) 100 + 7 int n, m; char kkk[MAXN]; char c[MAXM][MAXT]; int cnt; int tr[MAXN][4]; int tag[MAXN]; int fail[MAXN]; int change(char c) { if(c == 'E') return 0; if(c == 'S') return 1; if(c == 'W') return 2; if(c == 'N') return 3; } void insert(char *s) { int t = 0; int len = strlen(s); for(int i = 0; i < len; ++i) { int b = change(s[i]); if(!tr[t][b]) tr[t][b] = ++cnt; t = tr[t][b]; } } void get_fail() { queue<int> q; for(int i = 0; i < 4; ++i) { if(tr[0][i]) { q.push(tr[0][i]); } } while(!q.empty()) { int now = q.front(); q.pop(); for(int i = 0; i < 4; ++i) { int v = tr[now][i]; if(v) { fail[v] = tr[fail[now]][i]; q.push(v); } else { tr[now][i] = tr[fail[now]][i]; } } } } void Find(char *s) { int t = 0; int len = strlen(s); for(int i = 0; i < len; ++i) { int v = change(s[i]); t = tr[t][v]; for(int j = t; j && !tag[j]; j = fail[j]) { tag[j] = 1; } } } int query(char *s) { int res = 0; int t = 0; int len = strlen(s); for(int i = 0; i < len; ++i) { int v = change(s[i]); t = tr[t][v]; if(tag[t]) res = max(res, i + 1); } return res; } signed main() { scanf("%d%d", &n, &m); scanf("%s", kkk); for(int i = 1; i <= m; ++i) { scanf("%s", c[i]); insert(c[i]); } get_fail(); Find(kkk); for(int i = 1; i <= m; ++i) printf("%d\n", query(c[i])); return 0; }