首先可以差分将限制转化为 \((a_{r_1}\oplus a_{r_2})+(r_1-l_1+1)\le k\)。
将 \(\texttt{SAM}\) 建出来后对于每个本质不同子串的 \(\operatorname{endpos}\) 考虑。设点 \(x_1,x_2\) 分别对应原序列中 \(r_1,r_2\) 在 \(\texttt{parent tree}\) 上的位置,设 \(y=\operatorname{lca}(x_1,x_2)\),那么点对 \(x_1,x_2\) 的贡献为 \(\sum\limits_{i=1}^{\operatorname{len}_y}[(a_{r_1}\oplus a_{r_2})+i\le k]\)。
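注意到我们要统计所有的 \(\operatorname{endpos}\) 点对,可以考虑使用 \(\texttt{dsu on tree}\) 优化。于是你要维护一个数据结构,维护一个可重集 \(S\),实现以下操作:
- 向 \(S\) 中插入一个数 \(y\);
- 给定 \(x,k,d\),查询 \(A=\sum\limits_{y\in S}[x\oplus y\le k]\min\{k-(x\oplus y),d\}\)。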
这个查询其实还有点阴间。我们记 \(c_k=\sum\limits_{y\in S}[x\oplus y\le k],g_k=\sum\limits_{y\in S}[x\oplus y\le k](x\oplus y)\)(约定 \(k<0\) 时 \(c_k=g_k=0\))。满足 \(x\oplus y\le k-d\) 的项贡献 \(d\),其余满足 \(x\oplus y\le k\) 的项贡献 \(k-(x\oplus y)\),那么就有 \(A=\sum\limits_{y\in S,\,x\oplus y\le k}\min\{k-(x\oplus y),d\}=d\cdot c_{k-d}+k\cdot (c_{k}-c_{k-d})-(g_{k}-g_{k-d})\)。于是我们只需考虑如何求出 \(c_k,g_k\)。
建出 \(\texttt{01trie}\)。查询 \(c_k\) 是基操,不用多说;而查询 \(g_k\) 时,只需要在 \(\texttt{trie}\) 的每个结点上拆位维护每一位 \(1\) 的个数即可。
总时间复杂度为 \(\mathcal O(n\log n\log ^2V)\),空间复杂度为 \(\mathcal O(n\log ^2V)\)。由于 \(\texttt{dsu on tree}\) 和 \(\texttt{01trie}\) 的常数都很小,就过了。
#include <bits/stdc++.h>
using namespace std;

// All answers are reported modulo this prime.
static constexpr int mod = 998244353;

// Branch-free modular helpers. Operands are expected to already lie in
// [0, mod); `x >> 31 & mod` adds `mod` back exactly when the intermediate
// result went negative (relies on arithmetic right shift of a negative int).
inline int add(int x, int y) { return x += y - mod, x + (x >> 31 & mod); }
inline int sub(int x, int y) { return x -= y, x + (x >> 31 & mod); }
inline int mul(int x, int y) { return (int64_t)x * y % mod; }
inline void add_eq(int &x, int y) { x += y - mod, x += (x >> 31 & mod); }
inline void sub_eq(int &x, int y) { x -= y, x += (x >> 31 & mod); }
inline void mul_eq(int &x, int y) { x = (int64_t)x * y % mod; }

static constexpr int Maxn = 2e5 + 5, MaxS = 26;

int n;            // length of the input string
int en;           // number of edges stored in e[]
int head[Maxn];   // head[u]: index of the first outgoing edge of u (0 = none)
int dn;           // DFS pre-order counter
int ans;          // accumulated answer, kept in [0, mod)
int wl, wr;       // inclusive bounds read from the last input line
int w[Maxn];      // w[i]: weight attached to string position i (1-based)
char str[Maxn];   // input string, stored from index 1

// Forward-star adjacency list of the link tree (parent -> child).
struct Edge {
  int to, nxt;
} e[Maxn];

// Adds the directed edge u -> v.
void add_edge(int u, int v) {
  e[++en] = Edge{v, head[u]};
  head[u] = en;
}

// One automaton state: transitions (0 means "absent"; state 0 is the root and
// is never a transition target), suffix link, and longest string length.
struct state {
  int ch[MaxS], link, len;
} tr[Maxn];

int last;          // state of the whole prefix processed so far
int sn;            // index of the most recently created state
int edp[Maxn];     // edp[i]: state created for the prefix of length i
int iedp[Maxn];    // iedp[s]: prefix length that created state s, 0 for the
                   // root and for cloned states

// Appends character `c` (0-based letter index) to the suffix automaton.
void extend(int c) {
  int p = last;
  const int cur = last = ++sn;
  tr[cur].len = tr[p].len + 1;
  edp[tr[cur].len] = cur;
  iedp[cur] = tr[cur].len;

  for (; ~p && !tr[p].ch[c]; p = tr[p].link) tr[p].ch[c] = cur;
  if (p == -1) return;  // link stays 0 (the root) from zero-initialization

  const int q = tr[p].ch[c];
  if (tr[q].len == tr[p].len + 1) {
    tr[cur].link = q;
    return;
  }

  // Split q: the clone keeps q's transitions but gets the shorter length.
  const int r = ++sn;
  tr[r].len = tr[p].len + 1;
  memcpy(tr[r].ch, tr[q].ch, sizeof tr[q].ch);
  for (; ~p && tr[p].ch[c] == q; p = tr[p].link) tr[p].ch[c] = r;
  tr[r].link = tr[q].link;
  tr[q].link = tr[cur].link = r;
}

// Binary trie over the low LOG bits of the inserted values. Each node stores
// how many values pass through it and, per bit position, how many of those
// values have that bit set; this is enough to get both the count and the sum
// of (w xor y) over any subtree.
namespace trie {

static constexpr int LOG = 17;

struct node {
  int ch[2];   // children, 0 = absent
  int c;       // number of values in this subtree
  int s[LOG];  // s[k]: number of values in this subtree with bit k set
  node() = default;  // not user-provided, so `node()` zero-initializes
};

// Node 0 is the permanent all-zero "null" node, node 1 is the root.
// Generous capacity: the trie is cleared between light subtrees and only
// states with iedp != 0 are inserted, each adding at most LOG nodes.
node tr[Maxn * LOG * 2];
int tn = 1;

// Allocates a zeroed node and returns its index.
inline int newnode(void) { return tr[++tn] = node(), tn; }

// Empties the trie, leaving only a zeroed root.
inline void clear(void) { tr[tn = 1] = node(); }

// Registers value `w` in the counters of node `id`.
inline void touch(int id, int w) {
  tr[id].c++;
  for (int k = 0; k < LOG; ++k) tr[id].s[k] += (w >> k & 1);
}

// Inserts `w`; only its low LOG bits are represented.
void insert(int w) {
  int p = 1;
  touch(p, w);
  for (int i = LOG - 1; i >= 0; --i) {
    const int dir = w >> i & 1;
    if (!tr[p].ch[dir]) tr[p].ch[dir] = newnode();
    p = tr[p].ch[dir];
    touch(p, w);
  }
}

// Adds to `cnt` the number of values under node `id`, and to `sum` (mod `mod`)
// the total of (w xor y) over those values. Node 0 contributes nothing.
inline void absorb(int id, int w, int &cnt, int &sum) {
  const node &nd = tr[id];
  cnt += nd.c;
  for (int k = 0; k < LOG; ++k) {
    // Values whose bit k differs from w's bit k contribute 2^k to the xor.
    const int differing = (w >> k & 1) ? nd.c - nd.s[k] : nd.s[k];
    add_eq(sum, ((int64_t)differing << k) % mod);
  }
}

// Returns {count, sum mod `mod`} of (w xor y) over all stored y with
// (w xor y) <= r. A negative r yields {0, 0}.
pair<int, int> ask(int w, int r) {
  if (r < 0) return {0, 0};
  // The descent below only looks at bits [0, LOG) of r, so a bound of 2^LOG
  // or more used to be silently truncated. Every stored xor value is below
  // 2^LOG (assuming w itself fits in LOG bits, as insert() already does), so
  // such a bound admits everything: clamp it to the largest LOG-bit value.
  if (r >= (1 << LOG)) r = (1 << LOG) - 1;

  int p = 1, cnt = 0, sum = 0;
  for (int i = LOG - 1; i >= 0 && p; --i) {
    const int wbit = w >> i & 1;
    const int rbit = r >> i & 1;
    // If r has a 1 here, the whole branch with xor bit 0 is strictly smaller.
    if (rbit) absorb(tr[p].ch[wbit], w, cnt, sum);
    // Continue along the branch whose xor bit equals r's bit.
    p = tr[p].ch[wbit ^ rbit];
  }
  // Whatever is left at p has xor exactly equal to r (or p == 0: nothing).
  absorb(p, w, cnt, sum);
  return {cnt, sum};
}

}  // namespace trie

// Returns, modulo `mod`, the sum over stored y with (w xor y) <= r of
// min(r - (w xor y), len), i.e. the number of lengths i in [1, len] with
// (w xor y) + i <= r, summed over all stored y.
inline int ask(int w, int r, int len) {
  const auto below = trie::ask(w, r - len);  // terms that contribute `len`
  const auto upto = trie::ask(w, r);         // all terms with xor <= r
  const int capped = mul(below.first, len);
  const int partial =
      sub(mul(upto.first - below.first, r), sub(upto.second, below.second));
  return add(capped, partial);
}

// Same quantity restricted to wl <= (w xor y) + i <= wr.
inline int query(int w, int len) {
  return sub(ask(w, wr, len), ask(w, wl - 1, len));
}

int sz[Maxn];    // subtree size
int hson[Maxn];  // child with the largest subtree, -1 for a leaf
int dep[Maxn];   // depth in the link tree
int dfn[Maxn];   // pre-order index of a node
int idfn[Maxn];  // inverse of dfn

// Computes depth, subtree size, heavy child and pre-order numbering.
void sack_init(int u, int depth) {
  dep[u] = depth;
  sz[u] = 1;
  hson[u] = -1;
  dfn[u] = ++dn;
  idfn[dn] = u;
  for (int i = head[u]; i; i = e[i].nxt) {
    const int v = e[i].to;
    sack_init(v, depth + 1);
    sz[u] += sz[v];
    if (hson[u] == -1 || sz[v] > sz[hson[u]]) hson[u] = v;
  }
}

// Small-to-large merge over the link tree. For every pair of prefix states
// that first meet at `u`, adds query(...) with len = tr[u].len to `ans`.
// When `keep` is false the trie is emptied before returning.
void sack(int u, bool keep) {
  // Light children first; each leaves the trie empty.
  for (int i = head[u]; i; i = e[i].nxt) {
    const int v = e[i].to;
    if (v != hson[u]) sack(v, false);
  }
  // The heavy child goes last and its contents stay in the trie.
  if (hson[u] != -1) sack(hson[u], true);

  if (iedp[u] != 0) {
    add_eq(ans, query(w[iedp[u]], tr[u].len));
    trie::insert(w[iedp[u]]);
  }

  for (int i = head[u]; i; i = e[i].nxt) {
    const int v = e[i].to;
    if (v == hson[u]) continue;
    const int first = dfn[v], past = dfn[v] + sz[v];
    // Query the whole light subtree before inserting any of it, so pairs
    // inside that subtree are not counted here.
    for (int t = first; t < past; ++t) {
      const int x = idfn[t];
      if (iedp[x] != 0) add_eq(ans, query(w[iedp[x]], tr[u].len));
    }
    for (int t = first; t < past; ++t) {
      const int x = idfn[t];
      if (iedp[x] != 0) trie::insert(w[iedp[x]]);
    }
  }

  if (!keep) trie::clear();
}

// Reads n, the string, n weights and the bounds wl wr; prints the answer.
int main(void) {
  scanf("%d%s", &n, str + 1);

  last = sn = 0;
  tr[0].link = -1;
  for (int i = 1; i <= n; ++i) extend(str[i] - 'a');
  for (int i = 1; i <= sn; ++i) add_edge(tr[i].link, i);

  for (int i = 1; i <= n; ++i) scanf("%d", &w[i]);
  scanf("%d%d", &wl, &wr);

  dn = 0;
  sack_init(0, 0);
  sack(0, false);

  printf("%d\n", ans);
  return EXIT_SUCCESS;
}