Reverse Pairs

Can anyone please explain the solution using Binary Indexed Tree and Segment tree for this problem?

A bit seat-of-the-pants, but I think this is right - I'm a bit confused about the input format, though. Let me know if you want anything clarified. There are some clues in the comments in solveOptimised.

// Simon St James (ssjgz) - 2019-08-28
//#define SUBMISSION
#define BRUTE_FORCE
#ifdef SUBMISSION
#undef BRUTE_FORCE
#define NDEBUG
#endif
#include <iostream>
#include <vector>
#include <algorithm>

#include <cassert>

#include <sys/time.h> // TODO - this is only for random testcase generation.  Remove it when you don't need new random testcases!

using namespace std;

// Extracts a single value of type T from cin via operator>> and returns it.
// Debug builds assert that the extraction did not fail.
template <typename T>
T read()
{
    T value;
    cin >> value;
    assert(!cin.fail());
    return value;
}

// Typical SegmentTree - you can find similar implementations all over the place :)
// Despite the name, this is a Fenwick tree (binary indexed tree): it keeps an integer
// count per 0-based position and answers inclusive range-count queries.
// Positions are shifted to 1-relative internally, because stepping by the lowest set
// bit only works for indices >= 1.
class SegmentTree
{
    public:
        // An empty tree: addValueAt() is a no-op and only the empty range may be queried.
        SegmentTree() = default;
        // Allocates 2 * maxId + 1 counters, so 0-based positions 0 .. 2 * maxId - 1 are usable.
        SegmentTree(int maxId)
            : m_maxId{maxId},
            m_numElements{2 * maxId},
            m_elements(m_numElements + 1)
            {
            }

        // Adds value to the count stored at pos (0-based, must be non-negative).
        // Positions at or beyond the capacity are ignored.
        void addValueAt(int pos, int value)
        {
            // pos == -1 would map to internal index 0, whose lowest set bit is 0,
            // so the loop below would never advance.
            assert(pos >= 0);
            for (int node = pos + 1; node <= m_numElements; node += (node & -node))
            {
                m_elements[node] += value;
                assert(m_elements[node] >= 0);
            }
        }

        // Find the number in the given range (inclusive) in O(log2(maxId)).
        // start and end are 0-based; end == start - 1 denotes the empty range and yields 0.
        int numInRange(int start, int end) const
        {
            assert(start >= 0);
            assert(end < m_numElements); // Internal index end + 1 must be within the table.
            int sum = 0;
            // Prefix count over 0-based positions [0, end] ...
            for (int node = end + 1; node > 0; node -= (node & -node))
            {
                sum += m_elements[node];
            }
            // ... minus the prefix count over [0, start - 1].
            for (int node = start; node > 0; node -= (node & -node))
            {
                sum -= m_elements[node];
            }
            return sum;
        }
    private:
        // In-class initialisers keep a default-constructed tree in a well-defined (empty) state.
        int m_maxId = 0;
        int m_numElements = 0;
        vector<int> m_elements;

};

// Reference O(n^2) solution: counts the pairs (i, j) with i < j and
// nums[i] > 2 * nums[j].  Used to cross-check the optimised solution.
int64_t solveBruteForce(const vector<int>& nums)
{
    int64_t result = 0;
    const int n = static_cast<int>(nums.size());

    for (int j = 0; j < n; j++)
    {
        // Widen before doubling: 2 * nums[j] overflows int (undefined behaviour)
        // as soon as |nums[j]| exceeds 2^30.
        const int64_t twiceRight = 2 * static_cast<int64_t>(nums[j]);
        for (int i = 0; i < j; i++)
        {
            if (nums[i] > twiceRight)
                result++;
        }
    }

    return result;
}

// Counts the pairs (i, j) with i < j and nums[i] > 2 * nums[j] in O(n log n).
//
// Values are inserted into the index tree from largest to smallest.  Each index j is
// "settled" at the moment the tree holds exactly the indices i with
// nums[i] > 2 * nums[j]; its contribution is then the number of stored indices that
// lie strictly to the left of j.
int64_t solveOptimised(const vector<int>& nums)
{
    int64_t result = 0;
    const int n = static_cast<int>(nums.size());
    SegmentTree segmentTree(n + 1);

    struct ValueAndIndex
    {
        int value = -1;
        int index = -1;
    };
    vector<ValueAndIndex> numsAndIndicesDecreasing;
    numsAndIndicesDecreasing.reserve(n);
    for (int i = 0; i < n; i++)
    {
        numsAndIndicesDecreasing.push_back({nums[i], i});
    }
    sort(numsAndIndicesDecreasing.begin(), numsAndIndicesDecreasing.end(), [](const auto& lhs, const auto& rhs)
            {
                if (lhs.value != rhs.value)
                    return lhs.value > rhs.value;
                return lhs.index < rhs.index;
            });

    struct TwiceValueAndIndex
    {
        // 64-bit: doubling an int can exceed the int range.
        int64_t twiceValue = -1;
        int index = -1;
    };
    vector<TwiceValueAndIndex> twiceValuesAndIndicesIncreasing;
    twiceValuesAndIndicesIncreasing.reserve(n);
    for (int i = 0; i < n; i++)
    {
        twiceValuesAndIndicesIncreasing.push_back({2 * static_cast<int64_t>(nums[i]), i});
    }
    sort(twiceValuesAndIndicesIncreasing.begin(), twiceValuesAndIndicesIncreasing.end(), [](const auto& lhs, const auto& rhs)
            {
                if (lhs.twiceValue != rhs.twiceValue)
                    return lhs.twiceValue < rhs.twiceValue;
                return lhs.index < rhs.index;
            });

    // Settles the pending j with the largest doubled value: counts the indices already
    // in the tree that are strictly less than j, then discards j.
    const auto settleLargestPending = [&]()
    {
        result += segmentTree.numInRange(0, twiceValuesAndIndicesIncreasing.back().index - 1);
        twiceValuesAndIndicesIncreasing.pop_back();
    };

    for (const auto& valueAndIndex : numsAndIndicesDecreasing)
    {
        while (!twiceValuesAndIndicesIncreasing.empty() && twiceValuesAndIndicesIncreasing.back().twiceValue >= valueAndIndex.value)
        {
            // The set of things we've added to the tree so far is precisely the set of indices i
            // such that nums[i] > twiceValuesAndIndicesIncreasing.back().twiceValue, because the
            // current value (not yet inserted) and everything after it are <= that doubled value.
            settleLargestPending();
        }

        segmentTree.addValueAt(valueAndIndex.index, 1);
    }

    // Whatever is still pending has a doubled value below *every* element of nums (this
    // happens only for negative nums[j]).  All indices are in the tree by now, so each
    // such j pairs with every index to its left.
    while (!twiceValuesAndIndicesIncreasing.empty())
    {
        settleLargestPending();
    }

    return result;
}

// Entry point.
//  - With the single argument "--test": prints a random testcase (N, then N values) and exits.
//  - Otherwise: reads N and then N integers from stdin and prints the number of reverse pairs.
//    When BRUTE_FORCE is defined, the brute-force answer is printed as well and asserted
//    to match the optimised one.
int main(int argc, char* argv[])
{
    ios::sync_with_stdio(false);
    if (argc == 2 && string(argv[1]) == "--test")
    {
        // Generate random testcase.
        struct timeval time;
        gettimeofday(&time, nullptr);
        // Seed from the wall clock in milliseconds so successive runs differ.
        srand((time.tv_sec * 1000) + (time.tv_usec / 1000));

        const int N = rand() % 100 + 1;
        const int maxA = rand() % 1000 + 1;

        cout << N << endl;
        for (int i = 0; i < N; i++)
        {
            // Note: only non-negative values in [0, maxA) are produced.
            cout << ((rand() % maxA)) << " " << endl;
        }

        return 0;
    }

    const int N = read<int>();

    vector<int> nums(N);
    for (auto& x : nums)
    {
        x = read<int>();
    }


#ifdef BRUTE_FORCE
    const auto solutionBruteForce = solveBruteForce(nums);
    cout << "solutionBruteForce: " << solutionBruteForce << endl;
    const auto solutionOptimised = solveOptimised(nums);
    cout << "solutionOptimised: " << solutionOptimised << endl;

    assert(solutionOptimised == solutionBruteForce);
#else
    // solveOptimised takes the input array; calling it without arguments does not compile.
    const auto solutionOptimised = solveOptimised(nums);
    cout << solutionOptimised << endl;
#endif

    assert(cin);
}
3 Likes

Okay, I have one solution using Merge Sort Tree(A type of Segment Tree where each node stores a vector instead of a number). Not very sure if this is the best approach but AFAIK it should work.
First sort the original array; for ease, let's call the sorted array ar_sorted and the original array ar. Now create another array, say ar_index, which maps each element in ar_sorted to its index in ar.
For example:

ar = [1, 3, 2, 3, 1]
ar_sorted = [1, 1, 2, 3, 3]
ar_index = [0, 4, 2, 1, 3]

We can create both ar_sorted and ar_index in $O(N \log N)$.
Next, create a Merge Sort Tree for ar_index, we will use it later.
Now, let's try to find a solution to our original problem. First, we will traverse every element in ar_sorted and try to find the last element that satisfies the condition given in the question. Assume that we are currently at index i, processing the element ar_sorted[i], and let's say the element at index j, i.e. ar_sorted[j], is the last element that satisfies $\mathrm{ar\_sorted}[i] > 2 \times \mathrm{ar\_sorted}[j]$. Then we are sure that:
All elements at an index < j will satisfy the above condition and trivially, by our definition of j, no element at an index above j satisfies the condition.
The last task is to find all the elements of ar_index in the range [0, j] that are greater than the element of ar_index at index i. To put it in simpler words, we will use the Merge Sort Tree to count all the original indices stored in ar_index[0..j] that are greater than the original index of the element ar_sorted[i].
I haven't written code but, to me, it looks like this approach should get it done.

1 Like

I think we can also tackle this using Merge Sort Tree?

1 Like

I'm betting there's a whole bunch of ways of skinning this particular cat :slight_smile:

1 Like

Actually, I was asking you to read the approach I posted above and tell me what you think (sorry if that's odd :stuck_out_tongue:), because I haven't written code but, to me, the algorithm is correct. It's already 1:30 AM here and I have to go to sleep, so I can't write it until I wake up in the morning. :sweat_smile:

Oh, I see - will have a ponder tomorrow morning, UK-time :slight_smile:

1 Like

Did you submit your code? If so, did you get AC?
My approach is correct, it gives the correct output for each test case but it Times Out(TLE), IDK why, The time complexity of Merge Sort Tree is O(N * log(N) * log(N)), which should be acceptable as per the constraints(N <= 50k) but it times out.

Would you mind helping me out, what can I Optimise?

#include <bits/stdc++.h>
using namespace std;

// Merge sort tree storage, indexed by node id: the root is node 1 and node idx has
// children idx << 1 and idx << 1 | 1.  buildTree fills each node with the sorted
// merge of its children's vectors; reversePairs re-creates it on every call.
vector<vector<int>> seg_tree;

// Appends the sorted merge of l and r (each assumed already sorted ascending) to res.
// On ties the element from l is emitted first.  Existing contents of res are kept.
void merge(vector<int> &l, vector<int> &r, vector<int> &res) {
    auto left = l.cbegin();
    auto right = r.cbegin();
    const auto leftEnd = l.cend();
    const auto rightEnd = r.cend();

    while(left != leftEnd && right != rightEnd) {
        // Take from the right side only when it is strictly smaller.
        if(*right < *left) {
            res.push_back(*right);
            ++right;
        } else {
            res.push_back(*left);
            ++left;
        }
    }

    // At most one of these tails is non-empty.
    for(; left != leftEnd; ++left)
        res.push_back(*left);
    for(; right != rightEnd; ++right)
        res.push_back(*right);
}

// Builds the merge sort tree node idx covering ar[l..r] (inclusive) into seg_tree:
// a leaf stores its single element, an inner node stores the sorted merge of its children.
// ar is taken by reference-to-const: the function recurses once per node, so taking the
// vector by value would copy the whole array at every call.
void buildTree(int idx, int l, int r, const vector<int> &ar) {
    if(l == r) {
        seg_tree[idx].push_back(ar[l]);
        return;
    }
    
    int mid = l + ((r - l) >> 1);
    buildTree(idx << 1, l, mid, ar);
    buildTree(idx << 1 | 1, mid + 1, r, ar);
    merge(seg_tree[idx << 1], seg_tree[idx << 1 | 1], seg_tree[idx]);
}

// Counts the stored values strictly greater than num among positions [l, r] (inclusive).
// Node idx of seg_tree covers positions [start, end].
int query(int idx, int start, int end, int l, int r, int num) {
    // Node range and query range do not intersect.
    const bool disjoint = (r < start) || (end < l);
    if(disjoint)
        return 0;

    // Node range lies entirely inside the query range: binary-search its sorted vector.
    const bool covered = (start >= l) && (r >= end);
    if(covered) {
        const vector<int> &sortedValues = seg_tree[idx];
        return sortedValues.end() - upper_bound(sortedValues.begin(), sortedValues.end(), num);
    }

    // Partial overlap: combine the answers of both halves.
    const int mid = start + ((end - start) >> 1);
    const int fromLeft = query(idx << 1, start, mid, l, r, num);
    const int fromRight = query(idx << 1 | 1, mid + 1, end, l, r, num);
    return fromLeft + fromRight;
}

// Reverse Pairs via a merge sort tree built over "original index, ordered by value".
class Solution {
public:
    // Returns the number of pairs (i, j) with i < j and nums[i] > 2 * nums[j].
    int reversePairs(vector<int>& nums) {
        const int size = nums.size();
        if(size == 0)
            return 0;

        // Fresh tree storage for this call.
        seg_tree = vector<vector<int>>(4 * size + 1);

        vector<int> sortedVals = nums;
        sort(sortedVals.begin(), sortedVals.end());

        // originalIndex[k] is the position in nums of the element at sortedVals[k].
        vector<int> originalIndex(size);
        vector<bool> taken(size, false);

        for(int pos = 0; pos < size; ++pos) {
            const int value = nums[pos];

            // Binary search for the first slot holding `value` that is not yet taken.
            int lo = 0, hi = size - 1;
            while(lo < hi) {
                const int mid = lo + (hi - lo) / 2;
                if(sortedVals[mid] > value)
                    hi = mid - 1;
                else if(sortedVals[mid] < value || taken[mid])
                    lo = mid + 1;      // Too small, or an equal value already claimed.
                else
                    hi = mid;          // Equal and free: could be the first such slot.
            }

            taken[lo] = true;
            originalIndex[lo] = pos;
        }

        buildTree(1, 0, size - 1, originalIndex);

        int res = 0;
        for(int i = 0; i < size; ++i) {
            // Find the last slot whose doubled value is still below sortedVals[i].
            int lo = 0, hi = size - 1;
            while(lo < hi) {
                const int mid = lo + (hi - lo + 1) / 2;
                if(2LL * sortedVals[mid] >= sortedVals[i])
                    hi = mid - 1;
                else
                    lo = mid;
            }

            // Count partners in sorted slots [0, lo] that appear later in nums.
            if(sortedVals[i] > 2LL * sortedVals[lo])
                res += query(1, 0, size - 1, 0, lo, originalIndex[i]);
        }

        return res;
    }
};

No, I don't have a Leetcode account :slight_smile:

I actually can't even figure out what the time constraints are XD I just tried your solution with a randomly-generated array of size 100,000 (exceeding the constraints, I know ...) and it took about 16 seconds, which is very, very long for an O(N * log(N) * log(N)) solution, so something else is going on there.

valgrind wasn't much help - will investigate a bit further.

Edit:

It appears to be building the tree that's taking all the time.

1 Like

Thank you for taking such pains to help out. I can't use Valgrind at the moment, so, thank you.

1 Like

Aha - I've found your "deliberate mistake" :wink:

void buildTree(int idx, int l, int r, vector<int> ar /* Ouch! */) {
1 Like

Got it. Thank you. So silly of me; it should have been obvious. I think declaring the array global would help too - it's the repeated copying of ar for each function call that is causing the timeout? Thank you so much man, you're awesome as always.

1 Like

Yes - taking it by reference-to-const instead decreased the runtime for the testcase I was using by a factor of 114 :slight_smile:

Edit:

This is why my pre-Submission regimen always consists of:

a) Pass at least a thousand small random testcases (i.e. clever optimised result matches naive brute-force result); and
b) Pass at least a couple of random testcases maxing out the constraints in acceptable time - this won't necessarily flush out "worst-case" behaviour, but is a good "smoke-test" :slight_smile:

1 Like

Yeah, that's why I said there is so much you can offer to the community despite not being, to quote you, a 'whipper-snapper'. Most competitive programmers won't bother doing that, but you do, as you probably carried such practices over from being a developer (am I right?). Thank you, I will try to make it a habit (at least during practice and long contests).

1 Like

Got AC finally and the only thing I changed is that I made the arg ar a reference to a constant vector.

1 Like

Thanks a lot, @anon62928982 @ssjgz for providing these solutions. :slightly_smiling_face:

2 Likes

Okay so after reading the editorial solution and the above provided code.I managed to get an AC for the problem. Here is my solution.
Edit: we can also use C++ PBDS in place of BIT.

// Fenwick tree (binary indexed tree) of int counters over 1-based positions
// 1 .. size - 1; slot 0 is unused.
struct BIT{
    vector<int> nums;
    BIT(int size) : nums(size){}
    // Lowest set bit of x.
    int lowbit(int x) { return x & -x; }
    // Adds cnt at position x.  Positions at or beyond nums.size() are ignored.
    void update(int x, int cnt){
        // lowbit(0) == 0, so x == 0 would never advance; non-positive positions are not valid.
        if (x <= 0)
            return;
        const int limit = static_cast<int>(nums.size());
        while (x < limit) {
            nums[x] += cnt;
            x += lowbit(x);
        }
    }
    
    // Returns the sum over positions 1 .. x (0 when x <= 0).  x must be below nums.size().
    int query(int x) {
        int sum = 0;
        while(x > 0) {
            sum += nums[x];
            x -= lowbit(x);
        }
        return sum;
    }
};

class Solution {
public:
    int reversePairs(vector& nums) {
        // Copy array
        vector copy;
        for(auto x:nums) {
            copy.push_back(x);
            copy.push_back(2*(long long)x+1);
        }
        sort(copy.begin(),copy.end());
        int ans=0;
        BIT bit(copy.size() + 10);
        for(int i = nums.size()-1; i >= 0; --i){
            int r1=upper_bound(copy.begin(),copy.end(),(long long)nums[i])-copy.begin()+1;
            int r2=upper_bound(copy.begin(),copy.end(),2*(long long)nums[i])-copy.begin()+1;
            ans += bit.query(r1-1);
            bit.update(r2, 1);
        }
        return ans;

    }
};
1 Like

Yup. PBDS would work since you have two parameters to sort on the basis of, viz - index and values but PBDS are harder to code, I think. BIT is the way to go for problems of these type. I knew how to solve it using BIT, it is not very different from counting inversions in an array. I wanted to try something different, so, Merge Sort Tree.