Chocolate Day (CC003) - Editorial

Contest Name:

CODICTION 2020

PROBLEM LINK:

Problem
Contest

Author: Vishal Rochlani
Tester: Yash Sharma
Editorialist: Vishal Rochlani

DIFFICULTY:

EASY-MEDIUM.

PREREQUISITES:

Segment Tree

PROBLEM:

There are N jars, each containing some chocolates. To win the game, one has to select the maximum number of consecutive jars such that the sum of the chocolate counts in the largest and second-largest jar of that range is less than or equal to K.

QUICK EXPLANATION:

We use two pointers that delimit the current subarray. If the sum of its maximum and second maximum elements is less than or equal to K, we increment the right pointer by one to make the subarray as large as possible. If the sum exceeds K, we increment the left pointer by one.

The maximum and second maximum elements of a subarray can be found with a segment tree, as described below.

EXPLANATION:

Let’s restate the problem slightly. We are given an array of size N and have to find the maximum length of a subarray such that the sum of its maximum and second maximum elements is less than or equal to a positive integer K.

This problem can be solved by solving the below two sub-problems:

  • Traversal/checking of all the subarrays to find the maximum subarray keeping in mind the given constraint.
  • Finding the max and second max element in the given range (subarray).

Let’s solve these two sub-problems one at a time; they are easier than they look.

How to traverse the array?

Define two pointers, s = 0 and e = 1, marking the start and end of the current subarray.
If the sum of the maximum and second maximum elements in [s, e] is at most K, record the length e − s + 1 and increment e to try to extend the subarray. Otherwise increment s to shrink it. If the subarray has only two elements, increment both s and e instead, so it never becomes shorter than two. Repeat this while e < N.
Shrinking is safe because extending a subarray can never decrease the sum of its two largest elements. Once [s, e] exceeds K, every subarray [s, e'] with e' ≥ e exceeds K as well. The largest recorded value of e − s + 1 is the output of our solution.

How to find the maximum and second maximum in a given range?

We build a segment tree in which every node stores a pair (maximum value in its range, index of that maximum). Let L and R be the left and right ends of the query range. One query on [L, R] returns the maximum element M together with its index ind in the array. The second maximum of [L, R] is then the larger of two further maxima: A, the maximum over [L, ind−1], and B, the maximum over [ind+1, R]. The corner cases are ind = L or ind = R, where one of these sub-ranges is empty. The query for an empty range returns a sentinel value (−1 in the setter's code) that is smaller than any chocolate count, so max(A, B) is still the correct second maximum.

So by combining these two we can easily solve this problem.

COMPLEXITY:

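Time complexity: O(N log N) per test. Each step of the two pointers advances s or e (or both), so there are at most 2N steps, and each step performs three O(log N) segment-tree queries. Building the tree takes O(N).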

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
int getMid(int s, int e)
{
    // Midpoint of the index range [s, e]. Written as "start plus half the
    // width" so the sum s + e is never formed (it could overflow int).
    const int halfWidth = (e - s) / 2;
    return s + halfWidth;
}
 
/* Recursive range-maximum query on the segment tree `st`.
   node   : index of the current tree node (root is 0, children 2*node+1 / 2*node+2)
   ss, se : the array range [ss, se] covered by `node`
   l, r   : the query range
   Returns (maximum value, its index) over [l, r] ∩ [ss, se], or the
   sentinel (-1, -1) when the two ranges do not overlap. */
pair<int,int> MaxUtil(pair<int,int>* st, int ss, int se, int l, int r, int node)
{
    // No overlap with the query: contribute the sentinel.
    if (ss > r || se < l)
        return make_pair(-1, -1);

    // Node's range lies entirely inside the query: the stored answer is exact.
    if (ss >= l && se <= r)
        return st[node];

    // Partial overlap: ask both children and keep the larger value.
    // With equal values the right child's answer is kept (strict '>').
    const int mid = getMid(ss, se);
    const pair<int,int> fromLeft = MaxUtil(st, ss, mid, l, r, 2 * node + 1);
    const pair<int,int> fromRight = MaxUtil(st, mid + 1, se, l, r, 2 * node + 2);
    return (fromLeft.first > fromRight.first) ? fromLeft : fromRight;
}
 
/* Range-maximum query over arr[l..r] on a tree built by constructST for n
   elements. Returns (maximum value, its index), or the sentinel (-1, -1)
   when [l, r] is empty or falls outside [0, n-1]. */
pair<int,int> getMax(pair<int,int>* st, int n, int l, int r)
{
    const bool rangeIsValid = (l >= 0) && (r <= n - 1) && (l <= r);
    if (!rangeIsValid)
        return make_pair(-1, -1);

    // Start the recursion at the root, which covers the whole array.
    return MaxUtil(st, 0, n - 1, l, r, 0);
}
 
/* Recursively fills node `si` of the segment tree `st` for the array range
   [ss, se] of `arr`. Each node receives (maximum value, index of that maximum);
   children of node si are 2*si+1 (left half) and 2*si+2 (right half).
   Returns the pair written into st[si]. */
pair<int,int> constructSTUtil(int arr[], int ss, int se, pair<int,int> *st, int si)
{
    // Leaf: a single element is its own maximum.
    if (ss == se)
        return st[si] = make_pair(arr[ss], ss);

    const int mid = getMid(ss, se);
    const pair<int,int> leftBest = constructSTUtil(arr, ss, mid, st, 2 * si + 1);
    const pair<int,int> rightBest = constructSTUtil(arr, mid + 1, se, st, 2 * si + 2);

    // Keep the larger child; on ties the right child's pair is stored.
    st[si] = (leftBest.first > rightBest.first) ? leftBest : rightBest;
    return st[si];
}
 
/* Build a max segment tree over arr[0..n-1].
   Every node stores (maximum value, index of that maximum) for its range;
   node 0 is the root and the children of node i are 2*i+1 and 2*i+2.
   The node array is allocated with new[]: the caller owns it and must
   release it with delete[]. Returns nullptr when n <= 0. */
pair<int,int>* constructST(int arr[], int n)
{
    // An empty array has nothing to index; the original code evaluated
    // log2(0) here, whose conversion to int is undefined.
    if (n <= 0)
        return nullptr;

    // Smallest power of two >= n, computed with integers instead of
    // ceil(log2(n)) and pow(), so no floating-point rounding is involved.
    int leaves = 1;
    while (leaves < n)
        leaves *= 2;

    // A complete binary tree with `leaves` leaves has 2*leaves - 1 nodes,
    // which bounds every index the recursive build can reach.
    const int max_size = 2 * leaves - 1;

    pair<int,int> *st = new pair<int,int>[max_size];
    constructSTUtil(arr, 0, n - 1, st, 0);
    return st;
}
 
/* Setter's driver. For each test, reads n, K and the n jar counts, then
   slides a window [l, r] (always at least two jars wide) and prints the
   largest window whose maximum plus second maximum is <= K. Windows of a
   single jar are never counted, so the output is 0 if no window qualifies. */
int main()
{
    // Unsynchronised iostreams: faster input, and this program does not mix
    // C stdio with iostreams.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;

    while (t--)
    {
        int n;
        long long k;  // 64-bit so the comparison with a 64-bit sum is exact
        cin >> n >> k;

        // std::vector instead of the non-standard variable-length array int a[n].
        vector<int> a(n);
        for (int &x : a)
            cin >> x;

        pair<int,int> *st = constructST(a.data(), n);

        int l = 0, r = 1;
        int ans = 0;
        while (r < n)
        {
            // Maximum of the window and its position.
            const pair<int,int> best = getMax(st, n, l, r);
            const int ind = best.second;

            // Second maximum = larger of the maxima on either side of ind.
            // An empty side yields the (-1, -1) sentinel from getMax.
            const pair<int,int> leftPart = getMax(st, n, l, ind - 1);
            const pair<int,int> rightPart = getMax(st, n, ind + 1, r);
            const int second = max(leftPart.first, rightPart.first);

            // Add in 64 bits: two counts near INT_MAX would overflow int.
            const long long sum = static_cast<long long>(best.first) + second;

            if (sum <= k)
            {
                // Valid window: record it and try to extend to the right.
                ans = max(ans, r - l + 1);
                ++r;
            }
            else if (r - l == 1)
            {
                // A two-jar window fails: slide it so it never drops below two jars.
                ++l;
                ++r;
            }
            else
            {
                // [l, r] fails, so every window starting at l and ending at or
                // after r fails too; advance the left end.
                ++l;
            }
        }
        cout << ans << "\n";

        // constructST allocates with new[]; free it so memory does not grow
        // with every test case.
        delete[] st;
    }
    return 0;
}

ALTERNATE SOLUTION:

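We can also solve this problem with a priority queue (max-heap) holding the elements of the current window. Entries that fall to the left of the window are discarded lazily when they reach the top. If you would like a detailed explanation of this approach, please let us know.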

Alternate Solution by Tester
    #include <bits/stdc++.h>
    #define ll long long
    #define dl double
    #define rep(i,n)  for(int i = 0; i < n; i++)
    #define all(cont) cont.begin(), cont.end()
    #define rall(cont) cont.rbegin(), cont.rend()
    #define FOREACH(it, l) for (auto it = l.begin(); it != l.end(); it++)
    #define IN(A, B, C) assert( B <= A && A <= C)
    #define unique(a)       sort((a).begin(), a.end()), (a).erase(unique((a).begin(), (a).end()),(a).end())
    #define fastio          ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0)
    #define Assert(x)       {if(!(x)){cerr<<"Assertion failed at line "<<__LINE__<<": "<<#x<<" = "<<(x)<<"\n";exit(1);}}
    #define MP make_pair
    #define PB push_back
    #define INF (int)1e9
    #define mod 1000000007
    using namespace std;
    /* Tester's alternate solution using a max-heap of (value, index).
       For every right end `en`, the left end `st` is the smallest start such
       that the window [st, en] is valid (max + second max <= K) or holds a
       single jar. Heap entries with index < st are stale and are discarded
       lazily when they reach the top. Only windows of at least two jars are
       counted; the output is 0 if none qualifies.
       Fixes over the previous version: no out-of-bounds read of a[1] when
       n == 1, windows ending at the last jar are checked, and a shrunk window
       is re-checked before the right end moves on. */
    int main()
    {
        std::ios_base::sync_with_stdio(false);
        std::cin.tie(nullptr);

        long long t;
        std::cin >> t;
        while (t--)
        {
            long long n, k;
            std::cin >> n >> k;

            std::vector<long long> a(n);
            for (long long &x : a)
                std::cin >> x;

            std::priority_queue<std::pair<long long, long long>> pq;
            long long st = 0;
            long long ans = 0;

            for (long long en = 0; en < n; ++en)
            {
                pq.push(std::make_pair(a[en], en));

                // While the window holds two or more jars, inspect its two largest.
                while (en - st + 1 >= 2)
                {
                    // All indices in [st, en] are still in the heap, so after
                    // discarding stale entries at least two valid ones remain.
                    while (pq.top().second < st)
                        pq.pop();
                    const std::pair<long long, long long> m1 = pq.top();
                    pq.pop();
                    while (pq.top().second < st)
                        pq.pop();
                    const std::pair<long long, long long> m2 = pq.top();
                    pq.push(m1);  // restore the maximum; it is still in the window

                    if (m1.first + m2.first <= k)
                    {
                        ans = std::max(ans, en - st + 1);
                        break;
                    }

                    // Any window (now or later) that contains both m1 and m2 is
                    // invalid, so the start must move past the earlier of them.
                    st = std::min(m1.second, m2.second) + 1;
                }
            }
            std::cout << ans << "\n";
        }
        return 0;
    }

Feel free to share your approach. If you have any queries, they are always welcome.

6 Likes

Can anyone help me understand why I am getting TLE (time limit exceeded) for my segment tree solution?
solution link
I have just done standard segment tree implementation