为了统一记号,下文设矩形的行数为 \(n(\le 5000)\),列数为 \(m(\le 200)\),更新次数为 \(U(\le 500)\),查询次数为 \(Q(\le 2\times 10^5)\)。
最暴力的想法是每一次查询时直接DP,时间复杂度为 \(\mathcal O(Qnm^2)\)。这显然过不去,考虑优化。
我们发现 \(Q\) 非常大,而 \(U\) 比较小,所以可以考虑在每一次更新时维护答案。具体地,就是每一次更新后维护一个 \(m^2\) 的二维数组,代表第一行的每一个点走到最后一行的每一个点的最小代价,然后回答就是 \(\mathcal O(1)\) 的。暴力维护这个数组的代价为 \(\mathcal O(nm^2)\),因此总时间复杂度就是 \(\mathcal O(Unm^2+Q)\)。但这还远远不够。
注意到上述做法的瓶颈依然在维护 \(m^2\) 的二维数组,我们考虑优化这个过程。我们发现这玩意儿可以通过矩阵优化:我们将每一层的转移写成矩阵的形式,那么这个二维数组实际上就是由每一层的转移矩阵相乘得到的,而一次修改就是修改一层转移矩阵,于是我们用线段树维护转移矩阵。矩阵乘法的复杂度为 \(\mathcal O(m^3)\),那么总复杂度就是 \(\mathcal O(nm^2+Um^3\log n+Q)\)。
观察这个复杂度,我们发现瓶颈依旧在 \(Um^3\log n\),于是我们考虑优化矩阵乘法的复杂度。考虑观察矩阵乘法转移式的性质 \(z_{i,j}\leftarrow \min\limits_{k}\{x_{i,k}+y_{k,j}\}\)。我们发现,随着 \(i,j\) 的分别增大,最终的决策点 \(k\) 也会增大;也就是说这个转移时有决策单调性的。于是一次矩阵乘法的时间复杂度可以被降到 \(\mathcal O(m^2)\)。现在的总时间复杂度为 \(\mathcal O(nm^2+Um^2\log n+Q)\),可以通过。
然后我们就发现了一个问题:这样做的空间是 \(\mathcal O(nm^2)\),会爆炸。解决方法很简单,对原序列进行分块,用线段树维护块与块之间的转移,块内的修改直接暴力重新计算DP。假设块长为 \(S\),那么空间复杂度为 \(\mathcal O\left(\dfrac{nm^2}{S}\right)\),时间复杂度为 \(\mathcal O\left(nm^2+Um^2\left(S+\log\dfrac{n}{S}\right)+Q\right)\),可以通过。
#include <bits/stdc++.h> using namespace std; template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); } template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); } static constexpr int inf = 0x3f3f3f3f; static constexpr int Maxn = 5005, Maxm = 205; static constexpr int B = 16, MaxN = 1050; int n, m, q; int wr[Maxn][Maxm], sr[Maxn][Maxm]; int wc[Maxn][Maxm]; int tr[MaxN][Maxm][Maxm]; #define ls (p << 1 | 0) #define rs (p << 1 | 1) void pushup(int p, int mid) { static int ta[Maxm][Maxm]; memset(ta, 0, sizeof(ta)); for (int i = 1; i <= m; ++i) for (int j = m; j >= 1; --j) { int l = 1, r = m, res = inf, x = 0; if (ta[i - 1][j]) max_eq(l, ta[i - 1][j]); if (ta[i][j + 1]) min_eq(r, ta[i][j + 1]); for (int k = l; k <= r; ++k) if (tr[ls][i][k] + tr[rs][k][j] + wc[mid][k] < res) res = tr[ls][i][k] + tr[rs][k][j] + wc[mid][k], x = k; ta[i][j] = x, tr[p][i][j] = res; } } // pushup void pushall(int p, int l, int r) { static int f[Maxn][Maxm]; for (int i = 1; i <= m; ++i) { for (int j = 1; j <= m; ++j) f[l][j] = abs(sr[l][i] - sr[l][j]); for (int j = l + 1; j <= r; ++j) { for (int k = 1; k <= m; ++k) f[j][k] = f[j - 1][k] + wc[j - 1][k]; for (int k = 2; k <= m; ++k) min_eq(f[j][k], f[j][k - 1] + wr[j][k - 1]); for (int k = m - 1; k >= 1; --k) min_eq(f[j][k], f[j][k + 1] + wr[j][k]); } for (int j = 1; j <= m; ++j) tr[p][i][j] = f[r][j]; } } // pushall void build(int p, int l, int r) { if (r - l + 1 <= B) return pushall(p, l, r); int mid = (l + r) >> 1; build(ls, l, mid); build(rs, mid + 1, r); pushup(p, mid); } // build void modify(int p, int l, int r, int x) { if (r - l + 1 <= B) return pushall(p, l, r); int mid = (l + r) >> 1; if (x <= mid) modify(ls, l, mid, x); else modify(rs, mid + 1, r, x); pushup(p, mid); } // modify int main(void) { scanf("%d%d", &n, &m); for (int i = 1; i <= n; ++i) for (int j = 1; j < m; ++j) { scanf("%d", &wr[i][j]); sr[i][j + 1] = sr[i][j] + wr[i][j]; } for (int i = 1; i < n; ++i) for (int j = 1; j <= m; ++j) scanf("%d", &wc[i][j]); build(1, 1, n); for (scanf("%d", &q); q--; ) { int op; scanf("%d", &op); if (op == 1) { int u, v, w; scanf("%d%d", &u, &v), ++u, ++v; scanf("%d", &w), wr[u][v] = w; for (int j = 1; j < m; ++j) sr[u][j + 1] = sr[u][j] + wr[u][j]; modify(1, 1, n, u); } else if (op == 2) { int u, v, w; scanf("%d%d", &u, &v), ++u, ++v; scanf("%d", &w), wc[u][v] = w; modify(1, 1, n, u); } else { int u, v; scanf("%d%d", &u, &v), ++u, ++v; printf("%d\n", tr[1][u][v]); } } exit(EXIT_SUCCESS); } // main