模式串匹配,即给定一个文本串 \(A\) 和一个模式串 \(B\),询问 \(B\) 在 \(A\) 中是否出现、出现的次数及每次出现的位置等。通常数据范围为 \(1\le|A|,|B|\le10^6\)。
显然,我们可以枚举 \(A\) 的下标 \(i\),对于每一个 \(i\),都尝试用 \(B\) 去匹配( \(n=|A|,m=|B|\)):
for (int i = 1; i <= n; i++) { bool flag = true; for (int j = 1; j <= m; j++) { if (a[i + j - 1] != b[j]) { flag = false; break; } } if (flag) { // 即 B 在 A 中出现过 } }
暴力的时间复杂度为 \(\operatorname{O}(nm)\),直接 和 cxr 一起 炸了。
我们会发现,暴力之所以慢,是因为第 \(i\) 位失配后 \(j\) 重置为 \(1\),会出现很多重复的匹配,而 \(\rm KMP\) 算法通过优化,使得我们不用再从头开始枚举 。
对于下面一组数据:
a | b | c | a | b | c | a | b | b |
---|---|---|---|---|---|---|---|---|
a | b | c | a | b | b |
当第 \(6\) 位失配时,我们不必将模式串一位一位往右移,而是直接将模式串右移 \(3\) 位:
a | b | c | a | b | c | a | b | b |
---|---|---|---|---|---|---|---|---|
a | b | c | a | b | b |
因为模式串中,第 \(1,2\) 位与第 \(4,5\) 位相同,第 \(4,5\) 位匹配且第 \(6\) 位失配时,就把第 \(1,2\) 位移过来,从第 \(6\) 位(原来的第 \(3\) 位)继续匹配。
\(\rm KMP\) 算法的核心就是找到模式串中像上面 \(1,2\) 与 \(4,5\) 相同的子串,我们用一个 \(nxt\) 数组,\(nxt_i\) 的意义为当第 \(i\) 位失配后要跳到哪一位,即模式串前缀 \(B[1\sim i]\) 中既是前缀又是后缀的子串(不能为自身)里长度最长的子串的长度,例如字符串 \(\text{ababcaababab}\), 前 \(10\) 位 \(\text{ababcaabab}\) 中,既是前缀又是后缀的有 \(\text{ab,abab}\),长度最长的是 \(\text{abab}\),长度为 \(4\),所以 \(nxt_{10}=4\)。这两种意义的 \(nxt\) 值相同,但对于不同的题目各有用处。
显然 \(nxt_0=nxt_1=0\),求 \(nxt\) 数组的的过程相当于自己和自己匹配。
for (int i = 2, j = 0; i <= m; i++) { while (j && b[i] != b[j + 1]) // 只要失配就不停往回跳(跳到 j=0 就直接从第一位开始匹配了) { j = nxt[j]; } if (b[i] == b[j + 1]) // 相同就可以往前 { j++; } nxt[i] = j; // 记录 i 失配后往哪跳 }
有了 \(nxt\) 数组后就可以直接与文本串匹配了:
for (int i = 1, j = 0; i <= n; i++) { while (j && a[i] != b[j + 1]) // 失配往回跳 { j = nxt[j]; } if (a[i] == b[j + 1]) { j++; } if (j == m) { // B 在 A 中出现了一次 } }
\(\rm KMP\) 算法的时间复杂度证明
每次执行 \(\operatorname{while}\) 循环时,\(j\) 的值都在不停减小,而在每层 \(\operatorname{for}\) 循环里 \(j\) 最多增加 \(1\),即 \(j\) 至多增加 \(n+m\) 次,因为 \(j\) 始终非负,所以减少的幅度不会超过增加的幅度,则减少的次数不会超过增加的次数,所以 \(j\) 最多变化 \(2(n+m)\) 次,\(\rm KMP\) 算法的时间复杂度在 \(\operatorname{O}(n)\) 级别。
求出模式串 \(B\) 在文本串 \(A\) 中所有出现的位置和 \(B\) 的每一个前缀的 \(nxt\) 值。
#include <iostream> #include <cstdio> #include <cstring> using namespace std; const int MAXN = 1e6 + 5; char a[MAXN], b[MAXN]; int nxt[MAXN]; int main() { scanf("%s%s", a + 1, b + 1); // 下标从1开始 int n = strlen(a + 1), m = strlen(b + 1); for (int i = 2, j = 0; i <= m; i++) { while (j && b[i] != b[j + 1]) { j = nxt[j]; } if (b[i] == b[j + 1]) { j++; } nxt[i] = j; } for (int i = 1, j = 0; i <= n; i++) { while (j && a[i] != b[j + 1]) { j = nxt[j]; } if (a[i] == b[j + 1]) { j++; } if (j == m) { printf("%d\n", i - m + 1); // 起点要减去 (m-1) } } for (int i = 1; i <= m; i++) { printf("%d ", nxt[i]); } return 0; }