传送门
题面:给两个序列\(a,b\),将\(b\)中的所有元素按任意顺序插入\(a\)中,求形成的新的序列的最小逆序对数。
这题首先最好观察出这么个结论:如果把\(b_i\)插在\(p_i\)(即\(a_{i-1}\)和\(a_i\)之间)得到的逆序对最少,那么当\(b_i < b_j\)时,一定有\(p_i < p_j\).即这个最优插入位置是随着\(b_i\)增大而单调递增的。
知道这个结论后,如果我们能算出来所有的\(p_i\),那么逆序对只会在\(a\)序列本身和\(a,b\)之间产生,\(b\)本身是不会产生逆序对的。
求\(p_i\)的方法,除了\(O(n^2)\)暴力外有两种方法:
线段树。将\(a\)和\(b\)放在一块并从小到大排序。刚开始所有\(a_i\)都比任意一个\(b_i\)大,那么有\(p_i=i-1( i \in [1,n+1])\)。考虑遇到一个\(a_i\),那么接下来的\(b_j\)一定比这个\(a_i\)大,那么放在\(a_i\)后面就不会和\(a_i\)产生逆序对,而放在\(a_i\)前面反而会多产生一个逆序对,因此把\(a_i\)在原数组位置之后的\(p_j\)都减1,而之前的\(p_j\)都加1.
如果遇到\(b_i\),直接查询当前\(p_j\)的最小值即可。这些操作都可以用线段树实现。
但是\(a,b\)中可能有相同的元素,假设\(a_i=b_j\),那么将\(b_j\)放在\(a_i\)的前面和后面都不会产生逆序对,对比未用\(a_i\)更新数组的情况,会发现我们应该先将\(a_i\)之后的所有\(p_k\)减1,再查询\(b_j\),最后再把\(a_i\)之前的所有\(p_k\)加1.
这样线段树的时间复杂度就是\(O(m\log n)\).
分治。这是题解的做法。先把\(b_i\)从小到大排序,因为\(p_i\)随\(b_i\)递增,因此对于当前的一段\(b[L,R]\),先求\(b_{\frac{L+R}{2}}\)的最优位置,再递归到左右子区间求解。这样左右子区间扫描的长度之和就只有当前区间的长度。又因为递归了\(\log n\)层,因此时间复杂度也是\(O(n\log n)\)的。
我这里用的第一种写法,模板化,好写(不过分治也挺好写的)。
#include<bits/stdc++.h> using namespace std; #define enter puts("") #define space putchar(' ') #define Mem(a, x) memset(a, x, sizeof(a)) #define In inline #define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt) typedef long long ll; typedef double db; const int INF = 0x3f3f3f3f; const db eps = 1e-8; const int maxn = 1e6 + 5; In ll read() { ll ans = 0; char ch = getchar(), las = ' '; while(!isdigit(ch)) las = ch, ch = getchar(); while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar(); if(las == '-') ans = -ans; return ans; } In void write(ll x) { if(x < 0) x = -x, putchar('-'); if(x >= 10) write(x / 10); putchar(x % 10 + '0'); } int n, m, N, cnt = 0; struct Node { int val, pos, flg; In bool operator < (const Node& oth)const { if(val ^ oth.val) return val < oth.val; if(flg ^ oth.flg) return flg < oth.flg; return pos < oth.pos; } }t[maxn * 3]; int c[maxn]; In int lowbit(int x) {return x & -x;} In void add(int pos) //树状数组也要考虑ai相同的情况…… { for(; pos <= N; pos += lowbit(pos)) c[pos]++; } In int query(int pos) { int ret = 0; for(; pos; pos -= lowbit(pos)) ret += c[pos]; return ret; } int l[maxn << 2], r[maxn << 2], Min[maxn << 2], lzy[maxn << 2]; In void build(int L, int R, int now) { l[now] = L, r[now] = R; lzy[now] = 0; if(L == R) {Min[now] = L - 1; return;} int mid = (L + R) >> 1; build(L, mid, now << 1), build(mid + 1, R, now << 1 | 1); Min[now] = min(Min[now << 1], Min[now << 1 | 1]); } In void pushdown(int now) { if(lzy[now]) { Min[now << 1] += lzy[now], lzy[now << 1] += lzy[now]; Min[now << 1 | 1] += lzy[now], lzy[now << 1 | 1] += lzy[now]; lzy[now] = 0; } } In void update(int L, int R, int now, int d) { if(l[now] == L && r[now] == R) { Min[now] += d, lzy[now] += d; return; } pushdown(now); int mid = (l[now] + r[now]) >> 1; if(R <= mid) update(L, R, now << 1, d); else if(L > mid) update(L, R, now << 1 | 1, d); else update(L, mid, now << 1, d), update(mid + 1, R, now << 1 | 1, d); Min[now] = min(Min[now << 1], Min[now << 1 | 1]); } int main() { int T = read(); while(T--) { n = read(), m = read(); N = n + 1, cnt = 0; fill(c, c + N + 1, 0); build(1, N, 1); for(int i = 1; i <= n; ++i) { int x = read(); t[++cnt] = (Node){x, i, -1}; t[++cnt] = (Node){x, i, 1}; } for(int i = 1; i <= m; ++i) { int x = read(); t[++cnt] = (Node){x, 0, 0}; } sort(t + 1, t + cnt + 1); ll ans = 0; for(int i = 1; i <= cnt; ++i) { if(t[i].flg == -1) { ans += query(N) - query(t[i].pos); add(t[i].pos); update(t[i].pos + 1, N, 1, -1); } else if(t[i].flg == 0) ans += Min[1]; else update(1, t[i].pos, 1, 1); } write(ans), enter; } return 0; }