传送门
题意:
给出一个长度为 \(n\) 的数组,有 \(q\) 次查询,每次查询给出一个区间 \([l,r]\) ,求这段区间里面所有子区间的异或和的总和。
题解:
不难想到,要按位考虑贡献,对于第 \(i\) 位的贡献是 \(2^i\) 乘上区间\(1\)的个数为奇数的子区间的数量。
考虑利用线段树维护一个区间中包含奇数个\(1\)的子区间数量。
如何区间合并?
设左区间范围是 \([l,mid]\) ,右区间范围是 \([mid+1,r]\) ,\(ans\)表示区间中包含奇数个\(1\)的子区间数量。
那么 \(ans=left.ans+right.ans+\) 以\(mid\)为右端点包含奇数个\(1\)的区间数量 \(\times\) 以\(mid+1\)为左端点包含偶数个\(1\)的区间数量 \(+\) 以\(mid\)为右端点包含偶数个\(1\)的区间数量 \(\times\) 以\(mid+1\)为左端点包含奇数个\(1\)的区间数量。
那么维护三个东西即可:区间中包含奇数个\(1\)的子区间数量 ,以\(mid\)为右端点包含奇数个\(1\)的区间数量,以\(mid+1\)为左端点包含奇数个\(1\)的区间数量。
代码:
#pragma GCC diagnostic error "-std=c++11" #include <algorithm> #include <cmath> #include <cstdio> #include <cstring> #include <ctime> #include <iostream> #include <map> #include <queue> #include <set> #include <stack> #define iss ios::sync_with_stdio(false) using namespace std; typedef unsigned long long ull; typedef long long ll; typedef pair<int, int> pii; const int mod = 1e9 + 7; const int MAXN = 2e5 + 5; const int inf = 0x3f3f3f3f; int a[MAXN], base[22]; struct Node { int l, r, sum[22]; ll ans[22], lsum[22], rsum[22]; } node[MAXN << 2]; Node combine(Node x,Node y) { Node k; k.l = x.l; k.r = y.r; int len1 = x.r - x.l + 1; int len2 = y.r - y.l + 1; for (int i = 0; i <= 20;i++) { k.sum[i] = x.sum[i] + y.sum[i]; k.ans[i] = (x.ans[i] + y.ans[i] + x.rsum[i] * (len2 - y.lsum[i])%mod + (len1 - x.rsum[i]) * y.lsum[i]%mod)%mod; if(x.sum[i]&1) k.lsum[i] = x.lsum[i] + (len2 - y.lsum[i]); else k.lsum[i] = x.lsum[i] + y.lsum[i]; if(y.sum[i]&1) k.rsum[i] = y.rsum[i] + (len1 - x.rsum[i]); else k.rsum[i] = y.rsum[i] + x.rsum[i]; } return k; } void build(int l, int r, int num) { node[num].l = l; node[num].r = r; if (l == r) { for (int i = 20; i >= 0; i--) { if((a[l]>>i)&1){ node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]=1; } else { node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]= 0; } } return; } int mid = (l + r) >> 1; build(l, mid, num << 1); build(mid + 1, r, num << 1|1); node[num] = combine(node[num << 1], node[num << 1 | 1]); } Node query(int l,int r,int num) { if(node[num].l>=l&&node[num].r<=r) { return node[num]; } int mid = (l + r) >> 1; if(r<=mid) return query(l, r, num << 1); else if(l>mid) return query(l, r, num << 1 | 1); else { Node tmp1 = query(l, r, num << 1); Node tmp2 = query(l, r, num << 1 | 1); Node tmp = combine(tmp1, tmp2); return tmp; } } int main() { base[0] = 1; for (int i = 1; i <= 20; i++) base[i] = base[i - 1] * 2; int t; scanf("%d", &t); while (t--) { int n, q; scanf("%d%d", &n, &q); for (int i = 1; i <= n; i++) { scanf("%d", &a[i]); } build(1, n, 1); while ((q--)) { int l, r; scanf("%d%d", &l, &r); Node ans = query(l, r, 1); ll sum = 0; for (int i = 0; i <= 20;i++) { sum = (sum + base[i] * ans.ans[i]%mod)%mod; } printf("%lld\n", sum); } } }