Friday, December 18, 2015

Count of Smaller Numbers After Self [LeetCode]

Problem Description
You are given an integer array nums and you have to return a new counts array. The counts array has the property where counts[i] is the number of smaller elements to the right of nums[i].
Example:
Given nums = [5, 2, 6, 1]

To the right of 5 there are 2 smaller elements (2 and 1).
To the right of 2 there is only 1 smaller element (1).
To the right of 6 there is 1 smaller element (1).
To the right of 1 there is 0 smaller element.
Return the array [2, 1, 1, 0].
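To make the definition concrete, here is a minimal brute-force sketch that follows the problem statement literally. The class name BruteForceCountSmaller is only illustrative and is not part of the original post; it runs in O(n^2) time, so it is mainly useful as a reference for checking faster solutions.

```java
import java.util.ArrayList;
import java.util.List;

class BruteForceCountSmaller {
    // For each position i, scan everything to its right and count smaller values.
    static List<Integer> countSmaller(int[] nums) {
        List<Integer> counts = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            int smaller = 0;
            for (int j = i + 1; j < nums.length; j++) {
                if (nums[j] < nums[i]) {
                    smaller++;
                }
            }
            counts.add(smaller);
        }
        return counts;
    }

    public static void main(String[] args) {
        // The example from the problem description.
        System.out.println(countSmaller(new int[] {5, 2, 6, 1})); // prints [2, 1, 1, 0]
    }
}
```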
Solution
In this post, I'm going to write a little bit about the Binary Indexed Tree and how it applies to the problem above. The problem itself is not really hard, and it can be solved in many ways, including a Binary Search Tree, a Segment Tree, merge sort, or language-specific tools such as lower_bound in C++ or TreeSet (or SortedSet) in Java with its lower method (see some of these at the end).
Once we know how to use a Binary Indexed Tree (BIT for short), we can solve many other problems, especially in programming contests, since a BIT is very easy to implement.

I suggest that you spend some time reading this article from Topcoder: Binary Indexed Tree.
Basically, in this problem, we use a BIT to count the number of integers that are less than a specific number.
Suppose that a number N > 0 is written as A1B in binary, where B consists only of 0s (so the 1 shown is the lowest set bit of N). The array tree is a BIT in which tree[N] counts the integers seen so far whose values lie between A0B and A1B - 1, inclusive.
So if we let f[N] be the number of integers that are less than N, how do we calculate its value?
Yes, you are correct: f[N] = tree[N] + f[A0B] (where A0B is in binary representation), with f[0] = 0.
We also know that A0B = N & (N-1) using bit manipulation. (NOTE: the Topcoder article uses A0B = N - (N & -N), which gives the same value.)
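The two ways of clearing the lowest set bit can be compared with a small sketch. The class and method names here are hypothetical and chosen only for illustration.

```java
class LowestBitDemo {
    // Form used in this post: A1B & (A1B - 1) = A0B.
    static int clearLowestBit(int n) {
        return n & (n - 1);
    }

    // Form used in the Topcoder article: subtract the lowest set bit (n & -n).
    static int clearLowestBitTopcoder(int n) {
        return n - (n & -n);
    }

    public static void main(String[] args) {
        // 12 is 1100 in binary; clearing the lowest set bit leaves 1000, which is 8.
        System.out.println(clearLowestBit(12));         // prints 8
        System.out.println(clearLowestBitTopcoder(12)); // prints 8
    }
}
```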
With this in mind, to solve the problem we walk the array from the back, one element at a time. At position i, with N = nums[i], we simply calculate f[N] and put it into the result. Then we need to update the BIT, because we have seen another integer. So the natural question is: which elements of the BIT do we need to update? First, we update tree[N+1] by increasing its value by 1, since the range covered by tree[N+1] ends at the value N. But we do not stop there. Let N+1 = C1D, where D consists only of 0s, and let g[N+1] = C1D + 1D, that is, N+1 plus its lowest set bit. We also need to update tree[g[N+1]], then tree[g[g[N+1]]], and so on, until the index runs past the end of the tree.
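To see which entries of tree are touched, the following sketch lists the indices visited by an update and by a query. The names BitIndexChains, updateChain and queryChain are made up for this illustration, and the tree size of 16 is an arbitrary small value.

```java
import java.util.ArrayList;
import java.util.List;

class BitIndexChains {
    // Indices updated when a value is recorded: idx, g[idx], g[g[idx]], ...
    // idx must be at least 1, otherwise adding the lowest set bit never advances.
    static List<Integer> updateChain(int idx, int max) {
        List<Integer> chain = new ArrayList<>();
        while (idx < max) {
            chain.add(idx);
            idx += idx & -idx; // add the lowest set bit: C1D + 1D
        }
        return chain;
    }

    // Indices summed when computing f[idx]: idx, then idx with its lowest bit cleared, ...
    static List<Integer> queryChain(int idx) {
        List<Integer> chain = new ArrayList<>();
        while (idx > 0) {
            chain.add(idx);
            idx &= idx - 1; // clear the lowest set bit: A1B -> A0B
        }
        return chain;
    }

    public static void main(String[] args) {
        // Recording the value N = 2 starts the update at index N + 1 = 3.
        System.out.println(updateChain(3, 16)); // prints [3, 4, 8]
        // Computing f[7] sums tree[7], tree[6] and tree[4].
        System.out.println(queryChain(7));      // prints [7, 6, 4]
    }
}
```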

Here is the Java implementation.
import java.util.Arrays;
import java.util.List;

public class Solution {
    
    /*
    In this solution, we use a binary indexed tree (BIT).
    Our assumption is that, after shifting, all elements in nums
    are non-negative and smaller than MAX - 1.
    */
    
    static int MAX = 11000; // size of the tree; indices 1..MAX-1 are usable
    int[] tree = new int[MAX];
    
    public List<Integer> countSmaller(int[] nums) {
        Integer[] result = new Integer[nums.length];
        
        // shift the elements so that none is negative, while maintaining their order
        makePositive(nums);
    
        for(int i=nums.length-1; i>=0; i--){
            result[i] = get(nums[i]);
            add(nums[i]+1, 1);
        }
        return Arrays.asList(result);
    }
    
    public void makePositive(int[] nums){
        int min = MAX;
        for(int i=0; i<nums.length; i++)    
            min = Math.min(min, nums[i]);
        if(min < 0){
            min = -min+1;
            for(int i=0; i<nums.length; i++)
                nums[i] += min;
        }
    }
    
    public void add(int idx, int val){
        while(idx<MAX){
            tree[idx] += val;
            idx += (idx & (-idx));
        }
    }
    
    public int get(int idx){
        int result = 0;
        while(idx>0){
            result += tree[idx];
            idx &= (idx-1);
        }
        return result;
    }
}
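The fixed MAX above limits the range of values the tree can hold. One possible variation, which is not part of the original solution, is to swap the shift in makePositive for coordinate compression: each value is replaced by its rank among the sorted distinct values, so the tree only needs one slot per distinct value. The sketch below assumes Java 8 or later, and the class name CompressedBitCountSmaller is hypothetical.

```java
import java.util.Arrays;
import java.util.List;

class CompressedBitCountSmaller {
    static List<Integer> countSmaller(int[] nums) {
        // Sorted distinct values; the rank of a value is its index in this array.
        int[] sorted = Arrays.stream(nums).distinct().sorted().toArray();
        int[] tree = new int[sorted.length + 1]; // 1-based BIT over ranks
        Integer[] result = new Integer[nums.length];

        for (int i = nums.length - 1; i >= 0; i--) {
            // Every value of nums is present in sorted, so the search always succeeds.
            int rank = Arrays.binarySearch(sorted, nums[i]);

            // f[rank]: how many values seen so far have a smaller rank.
            int count = 0;
            for (int idx = rank; idx > 0; idx &= idx - 1) {
                count += tree[idx];
            }
            result[i] = count;

            // Record this value, starting at index rank + 1.
            for (int idx = rank + 1; idx < tree.length; idx += idx & -idx) {
                tree[idx]++;
            }
        }
        return Arrays.asList(result);
    }

    public static void main(String[] args) {
        System.out.println(countSmaller(new int[] {5, 2, 6, 1})); // prints [2, 1, 1, 0]
    }
}
```

The extra sort costs O(n log n), which matches the cost of the BIT passes, so the overall complexity stays the same while negative and very large values are handled without a fixed bound.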
Appendix A: Binary Search Tree solution (Java) - Credited to yavinci
public class Solution {
    class Node {
        Node left, right;
        int val, sum, dup = 1;
        public Node(int v, int s) {
            val = v;
            sum = s;
        }
    }
    public List<Integer> countSmaller(int[] nums) {
        Integer[] ans = new Integer[nums.length];
        Node root = null;
        for (int i = nums.length - 1; i >= 0; i--) {
            root = insert(nums[i], root, ans, i, 0);
        }
        return Arrays.asList(ans);
    }
    private Node insert(int num, Node node, Integer[] ans, int i, int preSum) {
        if (node == null) {
            node = new Node(num, 0);
            ans[i] = preSum;
        } else if (node.val == num) {
            node.dup++;
            ans[i] = preSum + node.sum;
        } else if (node.val > num) {
            node.sum++;
            node.left = insert(num, node.left, ans, i, preSum);
        } else {
            node.right = insert(num, node.right, ans, i, preSum + node.dup + node.sum);
        }
        return node;
    }
}
Appendix B: Segment Tree solution (JavaScript) - Credited to opmiss
/**
 * @param {number[]} nums
 * @return {number[]}
 */
var countSmaller = function(nums) {
    if (nums.length<1) return []; 
    var SegmentTreeNode = function(s, e){
        this.start = s;
        this.end = e; 
        this.left = null; 
        this.right = null; 
        this.count = 0; 
    }; 
    var max = nums[0]; 
    var min = nums[0]; 
    nums.forEach(function(num){
        max = (max<num)?num:max; 
        min = (min>num)?num:min; 
    }); 
    var root = new SegmentTreeNode(min, max);
    var insert = function(node, num){
        ++node.count; 
        if (node.start===node.end){
            return 0; 
        }
        if (node.left===null){
            var mid = (node.start+node.end)>>1; 
            node.left = new SegmentTreeNode(node.start, mid); 
            node.right = new SegmentTreeNode(mid+1, node.end); 
        }
        if (num>node.left.end){
            var res=node.left.count+insert(node.right, num);
            return res; 
        }
        return insert(node.left, num); 
    }; 

    var res = []; 
    while (nums.length>0){
       res.unshift(insert(root, nums.pop()));  
    }
    return res; 
};
Appendix C: Merge sort (Java) - Credited to lzyfriday
int[] count;
public List<Integer> countSmaller(int[] nums) {
    List<Integer> res = new ArrayList<Integer>();     

    count = new int[nums.length];
    int[] indexes = new int[nums.length];
    for(int i = 0; i < nums.length; i++){
        indexes[i] = i;
    }
    mergesort(nums, indexes, 0, nums.length - 1);
    for(int i = 0; i < count.length; i++){
        res.add(count[i]);
    }
    return res;
}
private void mergesort(int[] nums, int[] indexes, int start, int end){
    if(end <= start){
        return;
    }
    int mid = (start + end) / 2;
    mergesort(nums, indexes, start, mid);
    mergesort(nums, indexes, mid + 1, end);

    merge(nums, indexes, start, end);
}
private void merge(int[] nums, int[] indexes, int start, int end){
    int mid = (start + end) / 2;
    int left_index = start;
    int right_index = mid+1;
    int rightcount = 0;     
    int[] new_indexes = new int[end - start + 1];

    int sort_index = 0;
    while(left_index <= mid && right_index <= end){
        if(nums[indexes[right_index]] < nums[indexes[left_index]]){
            new_indexes[sort_index] = indexes[right_index];
            rightcount++;
            right_index++;
        }else{
            new_indexes[sort_index] = indexes[left_index];
            count[indexes[left_index]] += rightcount;
            left_index++;
        }
        sort_index++;
    }
    while(left_index <= mid){
        new_indexes[sort_index] = indexes[left_index];
        count[indexes[left_index]] += rightcount;
        left_index++;
        sort_index++;
    }
    while(right_index <= end){
        new_indexes[sort_index++] = indexes[right_index++];
    }
    for(int i = start; i <= end; i++){
        indexes[i] = new_indexes[i - start];
    }
}
Appendix D: Merge sort (Python) - Credited to StefanPochmann
def countSmaller(self, nums):
    def sort(enum):
        half = len(enum) // 2  # integer division, so the slices below also work on Python 3
        if half:
            left, right = sort(enum[:half]), sort(enum[half:])
            for i in range(len(enum))[::-1]:
                if not right or left and left[-1][1] > right[-1][1]:
                    smaller[left[-1][0]] += len(right)
                    enum[i] = left.pop()
                else:
                    enum[i] = right.pop()
        return enum
    smaller = [0] * len(nums)
    sort(list(enumerate(nums)))
    return smaller
