Let's start with the given input array: [4, 3, 9, 8, 5, 1, 2, 1].
(This is the list of numbers that we want to run range queries and point updates on.)
The segment tree is a data structure that stores the maximums of specific ranges, and these can be combined to find the maximum of any range within the given array. The segment tree array gets double the number of elements of the input array: since our input array has 8 elements, our segment tree array should have 16 elements.
Rather than using an array for the segment tree, you could create Node and Tree classes and build a more intuitive, pointer-based segment tree. However, I am going for runtime efficiency here with the bottom-up segment tree, so I will stick to the one-dimensional array method.
Since we are implementing this segment tree bottom-up, we place the input array in the last n slots. Our input array has 8 elements, so those 8 elements go in the last 8 indices of the segment tree (indices 8 to 15).
At this point in time, our segtree array should look like: [0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 9, 8, 5, 1, 2, 1]
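In Python, that initial layout is a single list concatenation. This is a minimal sketch: `arr` and `n` are illustrative names, and the leading zeros are placeholders that the build step overwrites.

```python
arr = [4, 3, 9, 8, 5, 1, 2, 1]
n = len(arr)

# First n slots are placeholders for internal nodes; the last n hold the input.
segtree = [0] * n + arr
print(segtree)  # [0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 9, 8, 5, 1, 2, 1]
```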
Do you notice a pattern in the indices? The parent of a node always sits at half its index, rounded down, and the children of node i sit at indices 2i and 2i + 1. For example, once the tree is built, the element 4 at index 8 is a child of the node at index 4 (which will hold max(4, 3) = 4), and the same goes for the element 3 at index 9. In general: parent node index = current node index >> 1. (Right bit shifting by 1 is floor division by 2 for non-negative integers, done with a single cheap bit operation.) Index 1 is the root, and index 0 is never used.
Knowing this, we can traverse up and down the segment tree using nothing but array indices.
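To make the index arithmetic concrete, here is a small sketch. The helpers `parent`, `left_child` and `right_child` are hypothetical names, and `built` is the fully built max tree for our example, worked out by hand from the rule that every internal node holds the max of its two children.

```python
def parent(i):
    return i >> 1          # floor(i / 2)

def left_child(i):
    return 2 * i

def right_child(i):
    return 2 * i + 1

# Fully built max segment tree for [4, 3, 9, 8, 5, 1, 2, 1]; index 0 is unused.
built = [0, 9, 9, 5, 4, 9, 5, 2, 4, 3, 9, 8, 5, 1, 2, 1]

# Leaves 8 and 9 (values 4 and 3) share parent 4, which holds max(4, 3).
print(parent(8), parent(9), built[4])  # 4 4 4

# Every internal node (indices 1..7) equals the max of its two children.
for i in range(1, 8):
    assert built[i] == max(built[left_child(i)], built[right_child(i)])
```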
To construct the segment tree, we run an update at every point of the given array within the segment tree; in other words, we update the last n indices of the segment tree. The update() function starts at an index and applies the associative operation (here, max) at every ancestor of that node until it reaches the root node.
Here is a quick implementation of our update() function for this scenario:
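```python
def update(segtree, index):
    # Walk from the leaf's parent up to the root (index 1), recomputing
    # each node as the max of its two children. Recomputing, rather than
    # taking max(parent, new value), keeps parents correct when a value decreases.
    index >>= 1
    while index >= 1:
        segtree[index] = max(segtree[2 * index], segtree[2 * index + 1])
        index >>= 1
```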
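As a sanity check, the sketch below builds the example tree by calling update() on every leaf and then lowers the 9 at index 10 to 0. It uses the children-recomputing form of update() (each parent is set to the max of its two children), because simply taking max(parent, new value) would leave the old 9 in the ancestors after a decrease.

```python
def update(segtree, index):
    # Recompute every ancestor of `index`, up to and including the root.
    index >>= 1
    while index >= 1:
        segtree[index] = max(segtree[2 * index], segtree[2 * index + 1])
        index >>= 1

arr = [4, 3, 9, 8, 5, 1, 2, 1]
n = len(arr)
segtree = [0] * n + arr

# Build: run update() from every leaf (O(n log n) in total).
for i in range(n, 2 * n):
    update(segtree, i)
print(segtree)  # [0, 9, 9, 5, 4, 9, 5, 2, 4, 3, 9, 8, 5, 1, 2, 1]

# Point update: overwrite the leaf holding 9 (index 10), then re-run update().
segtree[10] = 0
update(segtree, 10)
print(segtree[1])  # 8
```

Calling update() from every leaf costs O(n log n) overall. If that ever matters, a common alternative is to fill the internal nodes once, from index n - 1 down to 1, with segtree[i] = max(segtree[2 * i], segtree[2 * i + 1]), which builds the whole tree in O(n).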
After calling update() on each of the last n indices, the segment tree is fully constructed. To change a value later, all that needs to be done is to overwrite the value at that leaf and call update() on that index again. Because the tree has about log2(n) levels, update() runs in O(log n) time, which is faster than updating a prefix array, where changing one element can force up to n prefix values to be recomputed (O(n)).
To answer a range query, say the maximum of positions 3 to 6 (1-indexed) of our input array, we set up two pointers: the left pointer is at index 2+n = 2+8 = 10 and the right pointer is at index 5+n = 5+8 = 13. I subtracted 1 from the range ends because the query positions are 1-indexed while our array is 0-indexed, and I added n because the input elements live in the last n slots of the segment tree. From there, we move both pointers up the tree until they reach the same node; that node is the last piece of our range.
However, there is one twist. When the left pointer is on a right child (odd index), its parent would also cover the left sibling, which lies outside our range. Likewise, when the right pointer is on a left child (even index), its parent would also cover the right sibling. In either case, we save that node's value separately and move the pointer one step inwards to close the range. At the end, we combine the saved values with the value of the node where the pointers meet (here, by taking the maximum), and the result is our answer. This explanation can be a little tough to comprehend at first, so here is a simple implementation of the query() function for our example:
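```python
def query(segtree, left, right, max_value):
    # Call as query(segtree, l + n, r + n, float("-inf")) for 0-indexed positions l..r.
    # Pointers met: this node is the last piece of the range.
    if left == right:
        return max(max_value, segtree[left])
    # Left pointer on a right child (odd index): save it, step inwards.
    if left & 1:
        max_value = max(max_value, segtree[left])
        left += 1
    if left == right:
        return max(max_value, segtree[left])
    # Right pointer on a left child (even index): save it, step inwards.
    if right % 2 == 0:
        max_value = max(max_value, segtree[right])
        right -= 1
    # Move both pointers up to their parents.
    return query(segtree, left >> 1, right >> 1, max_value)
```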
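The same pointer moves can also be written as a loop instead of recursion. This is a minimal sketch: the name `query_iterative` is hypothetical, `segtree` is the fully built max tree for our example, and the running answer starts at negative infinity so that arrays with negative numbers also work. It runs our example query (positions 3 to 6) and cross-checks every range against a brute-force max.

```python
def query_iterative(segtree, left, right):
    # Max over leaf indices left..right (inclusive) of a bottom-up max tree.
    best = float("-inf")
    while left != right:
        if left & 1:                  # left pointer on a right child: save it
            best = max(best, segtree[left])
            left += 1
            if left == right:
                break
        if right % 2 == 0:            # right pointer on a left child: save it
            best = max(best, segtree[right])
            right -= 1
        left >>= 1                    # move both pointers up one level
        right >>= 1
    return max(best, segtree[left])   # pointers met: include that node


arr = [4, 3, 9, 8, 5, 1, 2, 1]
n = len(arr)
segtree = [0, 9, 9, 5, 4, 9, 5, 2, 4, 3, 9, 8, 5, 1, 2, 1]

# Positions 3 to 6 (1-indexed) map to leaf indices 2+n = 10 and 5+n = 13.
print(query_iterative(segtree, 2 + n, 5 + n))  # 9

# Cross-check every range against a brute-force max.
for i in range(n):
    for j in range(i, n):
        assert query_iterative(segtree, i + n, j + n) == max(arr[i:j + 1])
```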
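Finally, here is the full C++ implementation. Note that it uses addition instead of max as the associative operation, so it answers range sum queries, and its update adjusts every ancestor by the difference between the new and the old value. It reads n and the number of queries, then the n array values, then each query as three numbers q k w: when q is 1, position k (1-indexed) is set to w; otherwise, the sum of positions k through w is printed.

```cpp
#include <iostream>
#include <vector>
#include <cstdint>
using namespace std;

// Build: add every leaf value (indices a..2a-1) into all of its ancestors.
void constructTree(vector<uint64_t>& nums, uint64_t a) {
    for (uint64_t i = a; i < 2 * a; i++) {
        uint64_t b = i;
        while (b > 1) {
            b >>= 1;
            nums[b] += nums[i];
        }
    }
}

// Point update: set leaf k to w, then swap the old value for the new one in every ancestor.
void updateNode(vector<uint64_t>& nums, uint64_t k, uint64_t w) {
    uint64_t v = nums[k];
    uint64_t b = k;
    nums[k] = w;
    while (b > 1) {
        b >>= 1;
        nums[b] -= v;
        nums[b] += w;
    }
}

// Sum over leaf indices k..w (inclusive), moving both pointers up one level per call.
uint64_t getSumRange(const vector<uint64_t>& nums, uint64_t k, uint64_t w, uint64_t sum) {
    if (k == w) return nums[k] + sum;
    if (k % 2 == 1) { sum += nums[k]; k++; }  // left pointer on a right child
    if (k == w) return nums[k] + sum;
    if (w % 2 == 0) { sum += nums[w]; w--; }  // right pointer on a left child
    return getSumRange(nums, k >> 1, w >> 1, sum);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    uint64_t a, b;  // a = array size, b = number of queries
    cin >> a >> b;
    vector<uint64_t> nums(2 * a, 0);
    for (uint64_t i = 0; i < a; i++) cin >> nums[a + i];
    constructTree(nums, a);
    for (uint64_t i = 0; i < b; i++) {
        uint64_t q, k, w;
        cin >> q >> k >> w;
        if (q == 1) updateNode(nums, a + k - 1, w);                       // set position k to w
        else cout << getSumRange(nums, a + k - 1, a + w - 1, 0) << '\n';  // sum of positions k..w
    }
}
```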
Written by BooleanCube :]