RICY - Editorial

PROBLEM LINK:

Practice
Div-2 Contest
Div-1 Contest

Author: Shresth Walia
Tester: daanish
Editorialist: Abhishek Chopra

DIFFICULTY:

Easy-medium

PREREQUISITES:

Binary Search, Sorting

PROBLEM:

Given a list a of length n and another list b of length m, evaluate the following:
ans = \sum_{i = 1}^{m}\sum_{j = i}^{m} rangeMin(b_i, b_j).

where rangeMin(l, r) returns the minimum element of a over the indices from l to r (inclusive).

QUICK EXPLANATION:

Consider the contribution of each element to the answer: its value multiplied by the number of pairs (b_i, b_j) for which it is the minimum of the corresponding sub-array. Sort the pairs [a_i, i] and process them in increasing order. For each pair, use binary search on b to compute its contribution.

EXPLANATION:

The expression requires us to calculate rangeMin(b_i, b_j) \forall i, j ; i \leq j .
The idea is to simplify the problem by calculating the contribution of each element, from smallest to largest. While processing the elements in sorted order we keep track of the indices of the smaller elements that were already processed, and therefore of the range our current value can span if it has to be the minimum.
We can track this easily with a set in C++ or a similar data structure. When the elements are processed in sorted order, the set S consists of the indices of the elements that are smaller than the current one. Using a simple binary search on S, we can find the maximal range [L, R] that our current element can contribute to, since any range extending outside [L, R] has a different minimum element whose contribution has already been calculated.

Once we have the span, we have to count the pairs [b_i, b_j] that lie inside the interval [L, R] with b_i \le pos and b_j \ge pos, where pos is the index of the current element whose contribution we are calculating. We again use binary search, this time on b, to find the number of such [b_i, b_j]: it is the number of values of b in [L, pos] multiplied by the number of values of b in [pos, R]. Let this be f.
We add f * curValue to ans.

Please refer to editorial solution for implementation details.

Time Complexity:

We sort the [a_i, i] pair values and do binary search on set S and list b for every pair value.

O(n*logn + n*logm).

SOLUTIONS:

Setter's Solution
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#include<bits/stdc++.h>
using namespace std;
#define hackcyborg shresth_walia
#define all(a) a.begin(), a.end()
#define ll long long 
#define ld long double
#define pb push_back
#define mod 1000000007
#define IO ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
#define int long long
#define ordered_set tree<int,null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define sz(x) (int)(x).size()
typedef pair<int, int> pii;
typedef vector<int> vi;
// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
// A non-positive exponent yields 1, because the loop body never runs.
// The base is first reduced into [0, mod) so that the products below stay
// inside long long even when the caller passes a base far larger than the
// modulus (or a negative one).
ll bp(ll a,ll b)
{
    a%=mod;
    if(a<0)
        a+=mod;
    ll res=1;
    while(b>0)
    {
        // Fold the current power of the base in whenever the low bit is set.
        if(b&1)
            res=(a*res)%mod;
        a=(a*a)%mod;
        b/=2;
    }
    return res;
}
FILE *fp;
ofstream outfile;
 
// Strict integer reader used to validate the input format.
//
// Consumes characters from stdin up to and including the delimiter `endd`
// and returns the parsed value. Every format violation aborts via assert:
//   - any character other than '-', a decimal digit or `endd`
//     (end of input included),
//   - a repeated '-' or a '-' after the first digit,
//   - an empty digit sequence,
//   - leading zeros and "-0",
//   - more than 19 digits, or 19 digits whose first digit is above 1
//     (checked before accumulating, so the value never leaves long long),
//   - a value outside [l, r].
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;      // number of digits consumed so far
    int fi=-1;      // first digit, -1 until one has been seen
    bool is_neg=false;
    while(true){
        // Keep getchar()'s int result: narrowing it to char would make EOF
        // indistinguishable from the byte 0xFF and depend on char signedness.
        const int g=getchar();
        // char g = getc(fp);
        if(g=='-'){
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            // Reject an over-long number before it can overflow x.
            assert(cnt<18 || (cnt==18 && fi<=1));
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
        } else if(g==static_cast<unsigned char>(endd)){
            // A delimiter with no digits before it is not a number.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (and consuming) the delimiter `endd` and
// returns them without the delimiter. Aborts via assert when the input ends
// before the delimiter is seen, or when the number of characters read is
// outside [l, r].
std::string readString(int l,int r,char endd){
    std::string ret;
    int cnt=0;
    while(true){
        // getchar() returns an int so that EOF stays distinct from every
        // byte value; storing it in a char loses that (0xFF looks like EOF
        // with signed char, EOF is never detected with unsigned char).
        const int g=getchar();
        // char g=getc(fp);
        assert(g != EOF);
        if(g==static_cast<unsigned char>(endd)){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Forwards to readInt with a single space as the required terminator.
long long readIntSp(long long l,long long r)
{
    constexpr char kTerminator = ' ';
    return readInt(l, r, kTerminator);
}
// Forwards to readInt with a newline as the required terminator.
long long readIntLn(long long l,long long r)
{
    constexpr char kTerminator = '\n';
    return readInt(l, r, kTerminator);
}
// Forwards to readString with a newline as the required terminator.
string readStringLn(int l,int r)
{
    constexpr char kTerminator = '\n';
    return readString(l, r, kTerminator);
}
// Forwards to readString with a single space as the required terminator.
string readStringSp(int l,int r)
{
    constexpr char kTerminator = ' ';
    return readString(l, r, kTerminator);
}
 
// Upper limits handed to the readInt* validators in main().
// maxt bounds the number of test cases.
const int maxt = 1e5;
// maxn bounds the length n of the array a.
const int maxn = 1e5;
// maxm bounds the length m of the list b.
const int maxm = 1e5;
// maxv bounds every value of a.
const int maxv = 1e9;
 
main()
{
   IO
   int t = readIntLn(1, maxt);
   // cin>>t;
   while(t--)
   { 
    int n = readIntSp(1, maxn),m = readIntLn(1, maxm);
     // cin>>n>>m;
   	 vector<pii> a(n);
   	 for(int x=0;x<n-1;x++)
   	 	{//cin>>a[x].first; 
        a[x].first = readIntSp(1, maxv);
   	 	a[x].second=x+1;}
      a[n - 1].first = readIntLn(1, maxv);
      a[n - 1].second=n;
 
   	 sort(all(a));
   	 vector<ll> b(m);
   	 set<ll> k;
     int pv = 0;
   	 for(int x=0;x<m-1;x++)
   	    {//cin>>b[x];
          b[x] = readIntSp(1, n);
          assert(b[x] > pv);
          pv = b[x];
   	    k.insert(b[x]);}
      b[m - 1] = readIntLn(1, n);
          assert(b[m - 1] > pv);
        k.insert(b[m - 1]);
 
   	set<ll> au;
   	au.insert(0);
   	au.insert(n+1);
   	ll ans=0;
   	for(int x=0;x<n;x++)
   	{
   		auto it=au.upper_bound(a[x].second);
   		int r=*it;
   		--it;
   		int l=*it;
   		ll i1=upper_bound(all(b),a[x].second-1)-upper_bound(all(b),l);
   		ll i2=upper_bound(all(b),r-1)-upper_bound(all(b),a[x].second);
   		ans+=(a[x].first*(i1*i2));
   		if(k.count(a[x].second))
   		ans+=(a[x].first*(i1+i2+1ll));
   		au.insert(a[x].second);
   	}
   	cout<<ans<<"\n";
   }
   assert(getchar()==-1);
}
Tester's Solution
#pragma GCC optimize("Ofast")
#pragma GCC target("avx,avx2,fma")
#include<bits/stdc++.h>
using namespace std;
#define hackcyborg shresth_walia
#define all(a) a.begin(), a.end()
#define ll long long 
#define ld long double
#define pb push_back
#define mod 1000000007
#define IO ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
#define int long long
#define ordered_set tree<int,null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define sz(x) (int)(x).size()
typedef pair<int, int> pii;
typedef vector<int> vi;
// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
// A non-positive exponent yields 1, because the loop body never runs.
// The base is first reduced into [0, mod) so that the products below stay
// inside long long even when the caller passes a base far larger than the
// modulus (or a negative one).
ll bp(ll a,ll b)
{
    a%=mod;
    if(a<0)
        a+=mod;
    ll res=1;
    while(b>0)
    {
        // Fold the current power of the base in whenever the low bit is set.
        if(b&1)
            res=(a*res)%mod;
        a=(a*a)%mod;
        b/=2;
    }
    return res;
}
FILE *fp;
ofstream outfile;
 
// Strict integer reader used to validate the input format.
//
// Consumes characters from stdin up to and including the delimiter `endd`
// and returns the parsed value. Every format violation aborts via assert:
//   - any character other than '-', a decimal digit or `endd`
//     (end of input included),
//   - a repeated '-' or a '-' after the first digit,
//   - an empty digit sequence,
//   - leading zeros and "-0",
//   - more than 19 digits, or 19 digits whose first digit is above 1
//     (checked before accumulating, so the value never leaves long long),
//   - a value outside [l, r].
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;      // number of digits consumed so far
    int fi=-1;      // first digit, -1 until one has been seen
    bool is_neg=false;
    while(true){
        // Keep getchar()'s int result: narrowing it to char would make EOF
        // indistinguishable from the byte 0xFF and depend on char signedness.
        const int g=getchar();
        // char g = getc(fp);
        if(g=='-'){
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            // Reject an over-long number before it can overflow x.
            assert(cnt<18 || (cnt==18 && fi<=1));
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
        } else if(g==static_cast<unsigned char>(endd)){
            // A delimiter with no digits before it is not a number.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (and consuming) the delimiter `endd` and
// returns them without the delimiter. Aborts via assert when the input ends
// before the delimiter is seen, or when the number of characters read is
// outside [l, r].
std::string readString(int l,int r,char endd){
    std::string ret;
    int cnt=0;
    while(true){
        // getchar() returns an int so that EOF stays distinct from every
        // byte value; storing it in a char loses that (0xFF looks like EOF
        // with signed char, EOF is never detected with unsigned char).
        const int g=getchar();
        // char g=getc(fp);
        assert(g != EOF);
        if(g==static_cast<unsigned char>(endd)){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Forwards to readInt with a single space as the required terminator.
long long readIntSp(long long l,long long r)
{
    constexpr char kTerminator = ' ';
    return readInt(l, r, kTerminator);
}
// Forwards to readInt with a newline as the required terminator.
long long readIntLn(long long l,long long r)
{
    constexpr char kTerminator = '\n';
    return readInt(l, r, kTerminator);
}
// Forwards to readString with a newline as the required terminator.
string readStringLn(int l,int r)
{
    constexpr char kTerminator = '\n';
    return readString(l, r, kTerminator);
}
// Forwards to readString with a single space as the required terminator.
string readStringSp(int l,int r)
{
    constexpr char kTerminator = ' ';
    return readString(l, r, kTerminator);
}
 
// Upper limits used by the input validation in main().
// maxt bounds the number of test cases.
const int maxt = 1e5;
// maxn bounds the length n of the array a in a single test case.
const int maxn = 1e5;
// maxtn bounds the sum of n over all test cases (asserted after the loop).
const int maxtn = 1e6;
// maxm bounds the length m of the list b.
const int maxm = 1e5;
// maxv bounds every value of a.
const int maxv = 1e9;
 
main()
{
   IO
   int t = readIntLn(1, maxt);
   // cin>>t;
   int tn = 0;
   while(t--)
   { 
    int n = readIntSp(1, maxn),m = readIntLn(1, maxm);
    tn += n;
     // cin>>n>>m;
     vector<pii> a(n);
     for(int x=0;x<n-1;x++)
      {//cin>>a[x].first; 
        a[x].first = readIntSp(1, maxv);
      a[x].second=x+1;}
      a[n - 1].first = readIntLn(1, maxv);
      a[n - 1].second=n;
 
     sort(all(a));
     vector<ll> b(m);
     set<ll> k;
     int pv = 0;
     for(int x=0;x<m-1;x++)
        {//cin>>b[x];
          b[x] = readIntSp(1, n);
          assert(b[x] > pv);
          pv = b[x];
        k.insert(b[x]);}
      b[m - 1] = readIntLn(1, n);
          assert(b[m - 1] > pv);
        k.insert(b[m - 1]);
 
    set<ll> au;
    au.insert(0);
    au.insert(n+1);
    ll ans=0;
    for(int x=0;x<n;x++)
    {
      auto it=au.upper_bound(a[x].second);
      int r=*it;
      --it;
      int l=*it;
      ll i1=upper_bound(all(b),a[x].second-1)-upper_bound(all(b),l);
      ll i2=upper_bound(all(b),r-1)-upper_bound(all(b),a[x].second);
      ans+=(a[x].first*(i1*i2));
      if(k.count(a[x].second))
      ans+=(a[x].first*(i1+i2+1ll));
      au.insert(a[x].second);
    }
    cout<<ans<<"\n";
   }
   assert(tn <= maxtn);
   assert(getchar()==-1);
}  
Editorialist's Solution
#pragma GCC optimize("O3")
#pragma GCC target("sse4")
#include "bits/stdc++.h"
//#include <ext/pb_ds/assoc_container.hpp>
//#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
//using namespace __gnu_pbds;
#define int long long int
#define SYNC std::ios_base::sync_with_stdio(0);cin.tie(NULL);cout.tie(NULL);
#define FRE freopen("input.txt","r",stdin);freopen("output.txt","w",stdout);
typedef long double ld;
typedef pair<int,int> ii;
typedef pair<int,ii> iii;
typedef vector<int> vi;
typedef vector<ii>   vii;
//typedef tree<int, null_type, less<int>, rb_tree_tag,
//             tree_order_statistics_node_update>
//    ost;
#define rep(i,l,r)   for (int i = (l); i < (r); i++)
#define here cout << " Hey!!\n";
#define  pb push_back
#define  F  first
#define  S  second
#define all(v) (v).begin(),(v).end()
#define sz(a) (int)((a).size())
#define sq(x) ((x)*(x))

// General-purpose template constants. Neither solve() nor main() below reads
// any of them.
const int MOD = 1e9+7;
const int MOD1 = 998244353;
const int N = 2e5+5;
// Large sentinel; the LL suffix matches the `#define int long long int` above.
const int INF = 1000111000111000111LL;
const ld EPS = 1e-12;
const ld PI = 3.141592653589793116;

void solve() {
    int n, m; cin >> n >> m;
    vector<int> a(n), b(m);
    for (int &x : a)    cin >> x;
    for (int &x : b) {
        cin >> x;
        x--; // 0 based indexing
    }
    vector<pair<int, int> > v;
    for (int i = 0; i < n; i++) {
        v.push_back({a[i], i});
    }
    sort(v.begin(), v.end()); // Calculating answer in sorted order
    set<int> st;
    st.insert(-1); st.insert(n); // The extreme indices, assuming them to -INF
    int ans = 0;
    // Funtion to find number of [b_i, b_j] between [L, R] in  b
    auto get_range = [&] (int L, int R) -> int{
        return upper_bound(b.begin(), b.end(), R) - lower_bound(b.begin(), b.end(), L);
    };
    for (int i = 0; i < n; i++) {
        int pos = v[i].S;
        int cur = v[i].F;
        int L = *(--st.lower_bound(pos)) + 1;
        int R = *(st.lower_bound(pos)) - 1;
        // The number of times cur will be added in answer 
        ans += get_range(L, pos) * get_range(pos, R) * cur;
        st.insert(pos);
    }
    cout << ans << '\n';
}

// Program entry point: applies the SYNC stream setup, reads the number of
// test cases and calls solve() once for each of them.
int32_t main()
{
    SYNC
    int T;
    cin >> T;
    for (; T != 0; --T) {
        solve();
    }
    return 0;
}
1 Like

Interesting problem. The implementation was hard but the idea was pretty interesting.

1 Like

You can do better and solve it in O(nlogm) (no need of sorting A). You can use a stack to maintain the valid segment for each element.

There is an easier way to solve this problem.
It can be explained using an example
Consider 1-based indexing

A[]  {10, 3, 7, 9, 1, 19, 2}
B[]  {1, 4, 6}

Let MIN(L, R) denote the minimum in the range L to R of array A
We compute an array C by taking MIN(B[i], B[i + 1])

C[] = {MIN(1, 4), MIN(4, 6)}
C[] = {3, 1}

Idea behind forming this array is: MIN(1, 6) = min(MIN(1, 4) ,MIN(4, 6))
Now we compute sum of minimum elements of all subarrays in C
(this can be done by finding in how many subarrays a given element is minimum. )
(Sum of minimum elements of all subarrays - GeeksforGeeks)
Finally we add the single element ranges MIN(1, 1), MIN(4, 4), MIN(6, 6) to the sum, which gives us the answer

11 Likes

shouldn’t it be O(n + n*logm)?O(n) for stack traversing

Yeah, but I just didn’t write the O(n) term since it’s always going to be smaller than O(nlogm).

1 Like

It can also be done in O(n):

        long long ans = 0;
		for (int i = 1; i <= n; i++) {
			int ql = pref[i] - pref[left[i]];
			int qr = pref[right[i] - 1] - pref[i - 1];
			ans += (long long)a[i] * ql * qr;
		}
		cout << ans << "\n";

pref[i] = count of j s.t b_j <= i
a[i] is min. in (left[i], right[i])

3 Likes

wow nice!!

Yeah, using stack was a better approach indeed. I will add it as an alternative solution thanks!

Can you explain the function get_range in your code? I understand what it is doing, but I’m unable to understand how it is doing it.
And also these two lines. How are you choosing L and R?

int L = *(--st.lower_bound(pos)) + 1;
int R = *(st.lower_bound(pos)) - 1;

Mainly what I’m asking for is that how you are using Binary Search to get these values?

Implementation of the stack method in O(n).
#include <bits/stdc++.h>
using namespace std;
#define ll long long
int main() {
// your code goes here
int t; cin>>t; while(t–){
int n,m; cin>>n>>m;
ll a[n+1],b[n+1];
ll cum[n+1];
a[0]=0;
for(int i=1;i<=n;i++){
cin>>a[i];
b[i]=0;
cum[i]=0;
}

for(int i=1;i<=m;i++){
cin>>b[0];
b[b[0]]=1;
}
b[0]=0;
cum[0]=0;
for(int i=1;i<=n;i++){
cum[i]=cum[i-1]+b[i];
}
int psm[n+1],nsm[n+1];
psm[0]=0;nsm[0]=0;
stack s;
for(int i=1;i<=n;i++){
while(!s.empty() && a[i] < a[s.top()]){
nsm[s.top()]=i;
s.pop();
}
s.push(i);
}
while(!s.empty()){ nsm[s.top()]=n+1; s.pop();}
for(int i=n;i>=1;i–){
while(!s.empty() && a[i] < a[s.top()]){
psm[s.top()]=i;
s.pop();
}
s.push(i);
}
while(!s.empty()){ psm[s.top()]=0; s.pop();}
// for(int i=0;i<=n;i++)cout<<b[i]<<" “; cout<<endl;
// for(int i=0;i<=n;i++)cout<<cum[i]<<” “; cout<<endl;
ll ans = 0;
for(int i=1;i<=n;i++){
ll lb = psm[i]+1;
ll ub = nsm[i]-1;
ll left,right;
right = cum[ub]-cum[i-1];
left = cum[i]-cum[lb-1];
ll add = rightlefta[i];
// cout<<add<<” ";
ans+=add;

	}
	//cout<<endl;
	cout<<ans<<endl;
}
return 0;

}

Very intresting problem . Had fun solving this one.

--st.lower_bound(pos) gives the largest index smaller than pos that has already been visited (in other words, whose contribution has already been counted). Adding 1 gives the left end of the interval I can use, because any range extending further left has a different minimum. I hope you will now be able to work out the same for R as well.

1 Like

Please help me in solving I don’t get your solution. I have tried something easy and got Time Limit exceeded. I have tried a lot to optimize to O(n) from O(n*n) but I was not able to do it.

That’s my Solution
#include <bits/stdc++.h>
#include
using namespace std;

// Returns the minimum of arr over the 1-based, inclusive positions a..b.
// Requires 1 <= a <= b, with position b inside the array.
int minEle(int a, int b, int arr[])
{
    // The end iterator must be one past position b (index b - 1), i.e.
    // arr + b; stopping at arr + b - 1 would leave the last element out.
    return *std::min_element(arr + (a - 1), arr + b);
}

int main() {
// your code goes here
int T;
cin>>T;
for(int i=0; i<T; i++)
{
int n, m;
cin>>n>>m;
int a[n], b[m];
for(int i=0; i<n; i++)
cin>>a[i];
for(int i=0; i<m; i++)
cin>>b[i];

    int sum = 0;
    for(int i=0; i<m; i++)
    {
        for(int j=i; j<m; j++)
        {
            // cout<<minEle(b[i], b[j], a)<<endl;
            sum += minEle(b[i], b[j], a);
        }
    }
    cout<<sum<<endl;
    
}
return 0;

}