Article #003: Bottom-Up Segment Tree

Published: 12/26/2022

In this article, we will explore how to implement a basic segment tree by building and using it in a bottom-up fashion. Surprisingly, this approach is relatively unpopular compared to the top-down implementation, even though it is simpler and usually faster. To learn how to implement a top-down segment tree, I highly recommend reading the HackerEarth article!

DISCLAIMER: If your situation requires range updates, I recommend implementing the segment tree in a top-down fashion instead, since lazy propagation is much easier to implement that way and it is much better suited to such tasks!

Segment Tree Introduction

A segment tree is a data structure that speeds up range queries for binary, associative operations such as summation and maximum, while still allowing the underlying values to be updated.

"Prefix-Sum arrays can calculate the sum over a range in O(1) runtime which is very fast. How does a segment tree help?"
Updating a value in a prefix-sum array takes O(n) time, because every prefix sum after the updated element has to change. This gets slow as your arrays grow larger.
A segment tree provides a good middle ground: both point updates and range queries run in O(log n) time. With lazy propagation, you can even update every element in a range in O(log n) time rather than O(n log n).
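To make this trade-off concrete, here is a minimal Python sketch of a prefix-sum array. The helper names build_prefix, range_sum, and point_update are hypothetical, chosen just for this illustration: a range sum is a single subtraction, but one point update has to touch every prefix after it.

```python
from itertools import accumulate


def build_prefix(arr):
    # prefix[i] holds arr[0] + ... + arr[i-1]; prefix[0] is 0
    return [0] + list(accumulate(arr))


def range_sum(prefix, l, r):
    # Sum of arr[l..r] (0-indexed, inclusive) in O(1)
    return prefix[r + 1] - prefix[l]


def point_update(arr, prefix, i, value):
    # Changing arr[i] shifts every prefix sum after position i: O(n)
    delta = value - arr[i]
    arr[i] = value
    for j in range(i + 1, len(prefix)):
        prefix[j] += delta


arr = [4, 3, 9, 8, 5, 1, 2, 1]
prefix = build_prefix(arr)
print(range_sum(prefix, 2, 5))  # 23 (9 + 8 + 5 + 1)
point_update(arr, prefix, 0, 10)
print(range_sum(prefix, 0, 3))  # 30 (10 + 3 + 9 + 8)
```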

Segment Tree Explanation

Let's start with the given input array: [4, 3, 9, 8, 5, 1, 2, 1].
(These are the numbers we will run range queries and point updates on.)

In our example, each internal node of the segment tree stores the maximum of a specific range, and these stored maximums can be combined to find the maximum of any range within the input array. The segment tree array holds twice as many elements as the input array: since our input has 8 elements, the segment tree has 16. Instead of an array, you could build the tree out of Node and Tree classes, which some find more intuitive, but since I am aiming for runtime efficiency with the bottom-up segment tree, I will stick with a one-dimensional array.

Since we are building this segment tree from the bottom up, we place the input array in the last n elements. Our input array has 8 elements, so they go in the last 8 indices (8 through 15) of the segment tree.

At this point in time, our segtree array should look like: [0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 9, 8, 5, 1, 2, 1]
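In Python, this initial layout is essentially a one-liner. The sketch below assumes arr holds the input and simply prepends n placeholder zeros for the internal nodes (index 0 is never used).

```python
arr = [4, 3, 9, 8, 5, 1, 2, 1]
n = len(arr)

# Internal nodes (indices 0..n-1) start as placeholders; leaves fill indices n..2n-1.
segtree = [0] * n + arr

# The leaf for input position i lives at index n + i.
print(segtree)         # [0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 9, 8, 5, 1, 2, 1]
print(segtree[n + 2])  # 9, the third input element
```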

Do you notice a pattern in the indices? Once fully built, the max segment tree for our example is [0, 9, 9, 5, 4, 9, 5, 2, 4, 3, 9, 8, 5, 1, 2, 1] (index 0 is unused). The parent of a node always sits at half its index, rounded down. For example, the element 4 at index 8 is a child of the element 4 at index 4, and so is the element 3 at index 9. In other words, parent index = current index >> 1 (right-shifting by 1 is a fast way of doing floor division by 2), and the children of node i sit at 2i and 2i + 1.
Knowing this lets us move up and down the segment tree using nothing but array indices.
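The index arithmetic can be sketched with a few hypothetical helpers (parent, children, and covered_range are illustration names, not from the article). covered_range walks down to the leftmost and rightmost leaves under a node to recover which input positions it summarizes; this sketch assumes n is a power of two, as in our 8-element example, so that all leaves sit on the same level.

```python
def parent(i):
    return i >> 1  # floor division by 2


def children(i):
    return 2 * i, 2 * i + 1


def covered_range(node, n):
    # Descend to the leftmost and rightmost leaf under `node`
    # (assumes n is a power of two, so every leaf is on the bottom level).
    lo, hi = node, node
    while lo < n:
        lo, hi = 2 * lo, 2 * hi + 1
    return lo - n, hi - n  # 0-indexed input positions, inclusive


n = 8
print(parent(8), parent(9))  # 4 4
print(children(4))           # (8, 9)
print(covered_range(2, n))   # (0, 3)
print(covered_range(5, n))   # (2, 3)
print(covered_range(1, n))   # (0, 7)
```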

To construct the segment tree, we run an update at every leaf, in other words at each of the last n indices of the segment tree. The update() function starts at an index and walks up to the root (index 1), applying the associative operation (here, max) at every ancestor along the way. Here is a quick implementation of update() for our max segment tree:


    def update(segtree, index):
        # Start at the leaf's parent and recompute every ancestor up to the root (index 1).
        index >>= 1
        while index >= 1:
            segtree[index] = max(segtree[2 * index], segtree[2 * index + 1])
            index >>= 1
 
After calling update() on each of the last n indices, the segment tree is fully built. To change a value later, all you need to do is overwrite that leaf and call update() on its index again. Each call does one step of work per level, so it runs in O(log n) time, which is faster than the O(n) update of a prefix-sum array.
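Calling update() once per leaf builds the tree in O(n log n) time. A common alternative, shown here as a sketch rather than as the article's method, fills the internal nodes in a single pass from index n - 1 down to 1, so both children are final before each parent is computed, for O(n) total. The helper names build and set_value are hypothetical, and set_value takes a 0-indexed input position.

```python
def build(arr):
    n = len(arr)
    segtree = [0] * n + list(arr)
    # Walk internal nodes right-to-left so both children are ready before each parent.
    for i in range(n - 1, 0, -1):
        segtree[i] = max(segtree[2 * i], segtree[2 * i + 1])
    return segtree


def set_value(segtree, pos, value):
    # Overwrite one leaf, then recompute each ancestor from its two children: O(log n).
    n = len(segtree) // 2
    i = n + pos
    segtree[i] = value
    i >>= 1
    while i >= 1:
        segtree[i] = max(segtree[2 * i], segtree[2 * i + 1])
        i >>= 1


segtree = build([4, 3, 9, 8, 5, 1, 2, 1])
print(segtree)     # [0, 9, 9, 5, 4, 9, 5, 2, 4, 3, 9, 8, 5, 1, 2, 1]
set_value(segtree, 2, 0)  # lower the 9; recomputing from children handles decreases
print(segtree[1])  # 8
```

Recomputing each parent from its children (instead of only comparing against the new value) is what lets an update lower a value correctly.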

Here comes the tricky part: how do we answer a query over ANY range? Some ranges are easy. The max over indices 1-4 (1-indexed) is simply the element 9 at index 2, because that single node covers exactly those four elements. But what about a range like indices 3-6?
To query a range, we need two pointers marking its ends: a left pointer and a right pointer. Both start at the leaves. In our example, the left pointer starts at index 2+n = 2+8 = 10 and the right pointer at index 5+n = 5+8 = 13. I subtracted 1 from the range ends because the input array is 0-indexed, and added n because the input elements live in the last n slots of the segment tree, which is where we start before moving up.

From there, both pointers climb the tree one level at a time (index >> 1) until they reach the same node; once they meet, that node covers everything that remains of the range. There is one twist. If the left pointer is on a right child (odd index), its parent would also cover the element just left of the range, so we save that node's value separately and move the pointer one step inward (left += 1). Likewise, if the right pointer is on a left child (even index), we save it and move it one step inward (right -= 1). At the end, we combine these saved values with the node where the pointers met, and the result is our answer. This explanation can be a little tough to comprehend at first, so here is a simple implementation of the query() function for our example:


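    def query(segtree, left, right, max_value):
        # left and right are segtree indices (input position + n), inclusive.
        if left == right:                  # pointers met: this node covers the rest
            return max(max_value, segtree[left])
        if left & 1:                       # left pointer on a right child: save it, move inward
            max_value = max(max_value, segtree[left])
            left += 1
        if left == right:
            return max(max_value, segtree[left])
        if right % 2 == 0:                 # right pointer on a left child: save it, move inward
            max_value = max(max_value, segtree[right])
            right -= 1
        return query(segtree, left >> 1, right >> 1, max_value)  # climb one level

    # query(segtree, 10, 13, float("-inf")) returns 9 (range 3-6) on the example tree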
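For comparison, here is a loop-based sketch of the same two-pointer idea (the function name query_max is hypothetical). It takes 0-indexed, inclusive input positions and also records which nodes were folded into the answer, so you can follow the 3-6 example: both pointers climb one level without saving anything, then node 5 (positions 2-3) and node 6 (positions 4-5) are combined.

```python
def query_max(segtree, left, right):
    # left and right are 0-indexed input positions, inclusive.
    n = len(segtree) // 2
    left, right = left + n, right + n
    best = float("-inf")
    used = []  # nodes whose values were folded into the answer
    while left <= right:
        if left % 2 == 1:   # left pointer is a right child: save it, step inward
            best = max(best, segtree[left])
            used.append(left)
            left += 1
        if right % 2 == 0:  # right pointer is a left child: save it, step inward
            best = max(best, segtree[right])
            used.append(right)
            right -= 1
        left >>= 1          # move both pointers up to their parents
        right >>= 1
    return best, used


# Fully built max tree for [4, 3, 9, 8, 5, 1, 2, 1]
segtree = [0, 9, 9, 5, 4, 9, 5, 2, 4, 3, 9, 8, 5, 1, 2, 1]
print(query_max(segtree, 2, 5))  # (9, [5, 6])
print(query_max(segtree, 4, 7))  # (5, [3])
```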
 


Segment Tree Implementation

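The following C++ implementation of dynamic range sum queries was written by BooleanCube. It uses sum instead of max as the associative operation, but the structure is the same. It reads the array size and the number of queries, then the array values, then each query as either "1 k u" (set position k to u) or "2 a b" (print the sum of positions a through b), with 1-indexed positions: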

    #include <cstdint>
    #include <iostream>
    #include <vector>
    using namespace std;

    // Build: add every leaf (indices a..2a-1) into each of its ancestors.
    void constructTree(vector<uint64_t>& nums, uint64_t a) {
        for(uint64_t i=a; i<2*a; i++) {
            uint64_t b = i;
            while(b>1) {
                b>>=1;
                nums[b] += nums[i];
            }
        }
    }

    // Point update: set leaf k to w and fix the sum stored in every ancestor.
    void updateNode(vector<uint64_t>& nums, uint64_t k, uint64_t w) {
        uint64_t v = nums[k];
        uint64_t b = k;
        nums[k] = w;
        while(b>1) {
            b>>=1;
            nums[b] -= v;  // remove the old leaf value
            nums[b] += w;  // add the new leaf value
        }
    }

    // Sum over leaves k..w (inclusive), climbing one level per call.
    uint64_t getSumRange(const vector<uint64_t>& nums, uint64_t k, uint64_t w, uint64_t sum) {
        if(k == w) return nums[k] + sum;
        if(k%2 == 1) { sum += nums[k]; k++; }  // left pointer on a right child
        if(k == w) return nums[k] + sum;
        if(w%2 == 0) { sum += nums[w]; w--; }  // right pointer on a left child
        return getSumRange(nums, k>>1, w>>1, sum);
    }

    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        uint64_t a, b; cin >> a >> b;   // a = array size, b = number of queries
        vector<uint64_t> nums(2*a, 0);  // index 0 unused, leaves at a..2a-1
        for(uint64_t i=0; i<a; i++) cin >> nums[a+i];
        constructTree(nums, a);
        for(uint64_t i=0; i<b; i++) {
            uint64_t q, k, w; cin >> q >> k >> w;
            if(q == 1) updateNode(nums, a+k-1, w);                     // set position k to w
            else cout << getSumRange(nums, a+k-1, a+w-1, 0) << '\n';  // sum of positions k..w
        }
    }
 


Written by BooleanCube :]