网上看了一圈,看到的几个线段树实现都是用数组写的。
我用树结构(指针节点)重写了一遍:
#ifndef SEGMENTTREE_H
#define SEGMENTTREE_H

#include <cstddef>
#include <limits>
#include <memory>
#include <stdexcept>
#include <vector>

/**
 * @brief Pointer-based segment tree supporting range add and range sum.
 *
 * The tree is built over the index range [0, n-1] of the input vector.
 * Every internal node stores the sum of its range plus a pending ("lazy")
 * addition that has not yet been pushed to its children.
 *
 * Requirements on T: value-initialisation yields the additive identity,
 * and T supports +, +=, *, != and construction from int via static_cast.
 *
 * The tree owns its nodes through std::unique_ptr, so it is movable but
 * not copyable.
 */
template<typename T>
class SegmentTree {
public:
    /**
     * @brief Builds the tree from the values in @p a.
     * @param a Initial values; a[i] becomes element i. May be empty, in
     *          which case every update/query throws std::out_of_range.
     * @throws std::length_error if a.size() does not fit in an int.
     */
    SegmentTree(const std::vector<T> &a) {
        const int n = checkedSize(a.size());
        // An empty input gets no root at all; recursing on [0, -1] would
        // never terminate.
        if (n > 0) {
            root = build(a, 0, n - 1);
        }
    }

    virtual ~SegmentTree() = default;

    // Nodes are uniquely owned: copying is disabled, moving is allowed.
    SegmentTree(const SegmentTree &) = delete;
    SegmentTree &operator=(const SegmentTree &) = delete;
    SegmentTree(SegmentTree &&) noexcept = default;
    SegmentTree &operator=(SegmentTree &&) noexcept = default;

    /**
     * @brief Adds @p value to the element at @p index (it does not assign).
     * @throws std::out_of_range if @p index is outside [0, n-1].
     */
    void updateElement(T value, int index) {
        this->updateSegment(value, index, index);
    }

    /**
     * @brief Adds @p c to every element in the inclusive range [l, r].
     * @throws std::out_of_range unless 0 <= l <= r <= n-1.
     */
    void updateSegment(T c, int l, int r) {
        this->checkRange(l, r);
        addRange(*this->root, c, l, r);
    }

    /**
     * @brief Returns the sum of the elements in the inclusive range [l, r].
     *
     * Non-const because pending additions are pushed down while descending.
     * @throws std::out_of_range unless 0 <= l <= r <= n-1.
     */
    T querySegment(int l, int r) {
        this->checkRange(l, r);
        return sumRange(*this->root, l, r);
    }

private:
    struct Node {
        std::unique_ptr<Node> lchild;
        std::unique_ptr<Node> rchild;
        int l;   // first index covered (inclusive)
        int r;   // last index covered (inclusive)
        T lazy;  // addition applied to sum but not yet to the children
        T sum;   // sum of all elements in [l, r]

        Node(int lo, int hi) : l(lo), r(hi), lazy(), sum() {}

        bool isLeaf() const { return l == r; }

        // Written as l + (r - l) / 2 so that l + r cannot overflow.
        int mid() const { return l + (r - l) / 2; }

        // Adds delta to every element of this node's range.
        void apply(const T &delta) {
            sum += delta * static_cast<T>(r - l + 1);
            // A leaf has no children to defer to, so it keeps no tag.
            if (!isLeaf()) {
                lazy += delta;
            }
        }

        // Recomputes sum from the two children. Internal nodes only.
        void pushUp() { sum = lchild->sum + rchild->sum; }

        // Hands the pending addition to both children. Internal nodes only.
        void pushDown() {
            if (lazy != T()) {
                lchild->apply(lazy);
                rchild->apply(lazy);
                lazy = T();
            }
        }
    };

    std::unique_ptr<Node> root;

    // Converts the input size to int, rejecting sizes that do not fit.
    static int checkedSize(std::size_t n) {
        if (n > static_cast<std::size_t>(std::numeric_limits<int>::max())) {
            throw std::length_error("SegmentTree: input too large for int indexing");
        }
        return static_cast<int>(n);
    }

    // Number of elements covered by the tree (0 when empty or moved-from).
    int size() const { return root ? root->r + 1 : 0; }

    // Rejects ranges that are reversed or not fully inside [0, size()-1].
    void checkRange(int l, int r) const {
        if (l < 0 || r < l || r >= this->size()) {
            throw std::out_of_range("SegmentTree: segment [l, r] is out of range");
        }
    }

    // Builds the subtree covering [l, r]; children exist only on internal nodes.
    static std::unique_ptr<Node> build(const std::vector<T> &a, int l, int r) {
        std::unique_ptr<Node> node(new Node(l, r));
        if (l == r) {
            node->sum = a[static_cast<std::size_t>(l)];
            return node;
        }
        const int m = node->mid();
        node->lchild = build(a, l, m);
        node->rchild = build(a, m + 1, r);
        node->pushUp();
        return node;
    }

    // Adds value to [l, r]. Precondition: [l, r] lies within rt's range.
    static void addRange(Node &rt, const T &value, int l, int r) {
        if (rt.l == l && rt.r == r) {
            rt.apply(value);
            return;
        }
        // Not an exact match, so rt is internal and must be split.
        rt.pushDown();
        const int m = rt.mid();
        if (r <= m) {
            addRange(*rt.lchild, value, l, r);
        } else if (l > m) {
            addRange(*rt.rchild, value, l, r);
        } else {
            addRange(*rt.lchild, value, l, m);
            addRange(*rt.rchild, value, m + 1, r);
        }
        rt.pushUp();
    }

    // Sums [l, r]. Precondition: [l, r] lies within rt's range.
    static T sumRange(Node &rt, int l, int r) {
        if (rt.l == l && rt.r == r) {
            return rt.sum;
        }
        rt.pushDown();
        const int m = rt.mid();
        if (r <= m) {
            return sumRange(*rt.lchild, l, r);
        }
        if (l > m) {
            return sumRange(*rt.rchild, l, r);
        }
        return sumRange(*rt.lchild, l, m) + sumRange(*rt.rchild, m + 1, r);
    }
};

#endif // SEGMENTTREE_H