DVL - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Kasra Mazaheri

Tester: Arshia

Editorialist: Taranpreet Singh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Prefix sums, segment trees, and a lot of precomputation!

PROBLEM:

Given an array A of length N \leq 3000, find the number of tuples (a, b, c, d, e, f, g) such that

  • 1 \leq a < b < c < d < e < f < g \leq N
  • A_g < A_e < A_f < A_d < A_b < A_c < A_a
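To make the two conditions concrete, here is a minimal brute-force sketch that enumerates the tuples directly. It is not part of any of the solutions below; the function name `bruteForce` is hypothetical, positions are 0-based, and it is only meant for checking a faster solution on very small arrays.

```cpp
#include <cassert>
#include <vector>

// Counts tuples a < b < c < d < e < f < g with
// A[g] < A[e] < A[f] < A[d] < A[b] < A[c] < A[a] by direct enumeration.
// Far too slow for N = 3000; intended only as a reference for tiny inputs.
long long bruteForce(const std::vector<int>& A) {
    const int n = static_cast<int>(A.size());
    long long count = 0;
    for (int a = 0; a < n; ++a)
        for (int b = a + 1; b < n; ++b) {
            if (!(A[b] < A[a])) continue;
            for (int c = b + 1; c < n; ++c) {
                if (!(A[b] < A[c] && A[c] < A[a])) continue;
                for (int d = c + 1; d < n; ++d) {
                    if (!(A[d] < A[b])) continue;
                    for (int e = d + 1; e < n; ++e) {
                        if (!(A[e] < A[d])) continue;
                        for (int f = e + 1; f < n; ++f) {
                            if (!(A[e] < A[f] && A[f] < A[d])) continue;
                            for (int g = f + 1; g < n; ++g)
                                if (A[g] < A[e]) ++count;
                        }
                    }
                }
            }
        }
    return count;
}
```

For example, the array 7 5 6 4 2 3 1 contains exactly one valid tuple (all seven positions), while a strictly decreasing array contains none because A_b < A_c can never hold.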

NOT SO QUICK EXPLANATION

  • Fix each d in turn. We now have two nearly identical subproblems: counting the tuples (a, b, c) in the subarray A[1, d-1] such that A_d < A_b < A_c < A_a, and counting the tuples (e, f, g) in the subarray A[d+1, N] such that A_g < A_e < A_f < A_d. For a fixed d, the number of 7-tuples is the product of these two counts, and the final answer is the sum of this product over all values of d.
  • For subproblem (a, b, c), we wish to compute C(p, x), the number of triples (a, b, c) such that c = p and x = A_b. Since A_b < A_c < A_a and A_d < A_b, the number of triples for position d is \sum_{p = 1}^{d-1} \sum_{y = A_d+1}^{mx}C(p, y), where mx is the maximum value present in the array.
    Let’s define CC(d, x) as \sum_{p = 1}^{d-1} \sum_{y = x+1}^{mx}C(p, y). Since the pair (d, x) can take at most N^2 different values, we can precompute the whole table using prefix sums on C.
  • Similarly, we can define E(p, x) as the number of triples (e, f, g) such that p = e and x = A_f, and EE(d, x) = \sum_{p = d+1}^{N} \sum_{y = 0}^{x-1} E(p, y). We can precompute this too.
  • The final answer is \sum_{d = 1}^{N} CC(d, A_d)*EE(d, A_d).

EXPLANATION

Woah! Scary 7-tuple. Seems hard to handle, so let’s break it first! But before that, since we only care about relative inequalities, we can compress the initial array A such that it contains at most N distinct values.
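A minimal sketch of this compression step, assuming the input values fit in `long long` (the function name `compress` is hypothetical). Each value is replaced by its 0-based rank among the distinct values, so equal values stay equal and every strict inequality is preserved.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Replace each value by its rank among the distinct values of the array.
// The result uses values in [0, k), where k <= N is the number of distinct values.
std::vector<int> compress(const std::vector<long long>& a) {
    std::vector<long long> sorted(a);
    std::sort(sorted.begin(), sorted.end());
    sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());

    std::vector<int> rank(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        rank[i] = static_cast<int>(
            std::lower_bound(sorted.begin(), sorted.end(), a[i]) - sorted.begin());
    }
    return rank;
}
```

After this step, tables indexed by value need at most N columns, which is what makes the O(N^2) precomputation possible.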

Let us count tuples for each value of d separately. For a fixed d, the number of 7-tuples is the product of the following two counts.

  • Number of tuples (a, b, c, d) such that A_d < A_b < A_c < A_a for a given d
  • Number of tuples (d, e, f, g) such that A_g < A_e < A_f < A_d for a given d

These are actually the same problem: if we reverse the array and negate every value, the second subproblem turns into the first one (the setter's solution does exactly this for one of the two sides). So we’ll focus only on the first subproblem; the second can be solved similarly.
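The equivalence can be checked on small inputs with a sketch like the one below. The helper names `countLeft`, `countRight` and `reverseAndNegate` are hypothetical, positions are 0-based, and the counting is deliberately naive (cubic per call) since it only illustrates the transformation.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Triples (a, b, c) with a < b < c < d and A[d] < A[b] < A[c] < A[a].
long long countLeft(const std::vector<long long>& A, int d) {
    long long cnt = 0;
    for (int a = 0; a < d; ++a)
        for (int b = a + 1; b < d; ++b)
            for (int c = b + 1; c < d; ++c)
                if (A[d] < A[b] && A[b] < A[c] && A[c] < A[a]) ++cnt;
    return cnt;
}

// Triples (e, f, g) with d < e < f < g and A[g] < A[e] < A[f] < A[d].
long long countRight(const std::vector<long long>& A, int d) {
    const int n = static_cast<int>(A.size());
    long long cnt = 0;
    for (int e = d + 1; e < n; ++e)
        for (int f = e + 1; f < n; ++f)
            for (int g = f + 1; g < n; ++g)
                if (A[g] < A[e] && A[e] < A[f] && A[f] < A[d]) ++cnt;
    return cnt;
}

// Reverse the positions and flip the sign of every value.
std::vector<long long> reverseAndNegate(std::vector<long long> A) {
    std::reverse(A.begin(), A.end());
    for (long long& v : A) v = -v;
    return A;
}
```

With B = reverseAndNegate(A), position d in A becomes position N-1-d in B, and countRight(A, d) equals countLeft(B, N-1-d): g plays the role of a, f the role of b, and e the role of c.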

Let C(p, x) denote the number of triples (a, b, c) such that a < b < c, A_b < A_c < A_a, p = c and x = A_b.

For position p = c, consider all positions b < c such that A_b < A_c. Each such position b contributes (the number of elements in range [1, b-1] greater than A_c) to C(p, A_b).
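One possible way to fill this table is sketched below, assuming the array has already been compressed to values in [0, m) and using 0-based positions. Instead of the prefix table used in the editorialist's solution, this sketch keeps a running counter while scanning b from left to right for a fixed c; that is a swapped-in simplification, and the name `buildC` is hypothetical.

```cpp
#include <cassert>
#include <vector>

// C[p][x] = number of triples (a, b, c) with a < b < c = p, A[b] = x
// and A[b] < A[c] < A[a]. Values are assumed to lie in [0, m).
// No modulus is applied here: each entry is at most N^2, which fits easily.
std::vector<std::vector<long long>> buildC(const std::vector<int>& A, int m) {
    const int n = static_cast<int>(A.size());
    std::vector<std::vector<long long>> C(n, std::vector<long long>(m, 0));
    for (int c = 0; c < n; ++c) {
        int greater = 0;  // positions a < b seen so far with A[a] > A[c]
        for (int b = 0; b < c; ++b) {
            if (A[b] < A[c]) C[c][A[b]] += greater;  // b is a valid middle element
            if (A[b] > A[c]) ++greater;              // b can serve as a for later b
        }
    }
    return C;
}
```

The double loop is O(N^2) in time. The table itself takes O(N*m) memory, which for N = 3000 is on the order of tens of megabytes with 64-bit entries.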

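Now, let’s compute CC(p, x) as defined in the Not So Quick Explanation.

For a position p, CC(p, x) = CC(p-1, x) + \sum_{y= x+1}^{mx} C(p, y). This recurrence also includes the triples with c = p, whereas the earlier definition stops at d-1; the two agree at x = A_p, the only value we query, because C(p, y) is nonzero only for y < A_p. Processing x in decreasing order while maintaining a running suffix sum of C(p, y), we can compute this table too in O(MX*N) time.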
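The recurrence can be sketched as follows, assuming a table C laid out as C[p][x] with 0-based positions and compressed values. The function name `buildCC` is hypothetical; the modulus is passed in as a parameter (the solutions below use 987654319).

```cpp
#include <cassert>
#include <vector>

// CC[p][x] = sum of C[q][y] over all q <= p and all y > x, taken modulo mod.
// For each position p, x is processed in decreasing order so that `suffix`
// always holds the sum of C[p][y] for y > x.
std::vector<std::vector<long long>> buildCC(
        const std::vector<std::vector<long long>>& C, long long mod) {
    const std::size_t n = C.size();
    const std::size_t m = C.empty() ? 0 : C[0].size();
    std::vector<std::vector<long long>> CC(n, std::vector<long long>(m, 0));
    for (std::size_t p = 0; p < n; ++p) {
        long long suffix = 0;
        for (std::size_t x = m; x-- > 0;) {
            const long long previous = (p > 0) ? CC[p - 1][x] : 0;
            CC[p][x] = (previous + suffix) % mod;
            suffix = (suffix + C[p][x]) % mod;
        }
    }
    return CC;
}
```

The same loop with x processed in increasing order gives the mirrored table for the (e, f, g) side, where values smaller than x are needed.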

Similarly, we compute E(p, x) and EE(p, x) in almost the same manner. If we do not reverse the array, the only changes are the directions of the inequalities and of the iteration.

For the final answer, we need to take the sum of CC(d, A_d)*EE(d, A_d) over all values of d.

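[details="Number of values in range [L, R] greater than X"]
Another two-dimensional prefix array. :stuck_out_tongue:
freq[x][y] denotes the number of values in prefix [1, x] with value y. Accumulating freq over y gives the number of values in prefix [1, x] greater than X, and subtracting the count for prefix [1, L-1] from the count for prefix [1, R] answers the query.
[/details]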
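A minimal sketch of such a table, assuming compressed values in [0, m) and 0-based positions. It stores the accumulated "greater than x" counts directly, much like the `more` array in the editorialist's solution; the struct name `GreaterCounter` and its `query` method are hypothetical.

```cpp
#include <cassert>
#include <vector>

// more[x][i] = number of positions j < i with A[j] > x.
// Building the table costs O(N * m) time and memory; each query is O(1).
struct GreaterCounter {
    std::vector<std::vector<int>> more;

    GreaterCounter(const std::vector<int>& A, int m)
        : more(m, std::vector<int>(A.size() + 1, 0)) {
        for (std::size_t i = 0; i < A.size(); ++i)
            for (int x = 0; x < m; ++x)
                more[x][i + 1] = more[x][i] + (A[i] > x ? 1 : 0);
    }

    // Number of positions in the inclusive range [l, r] whose value exceeds x.
    int query(int l, int r, int x) const {
        return more[x][r + 1] - more[x][l];
    }
};
```

For the array 2 0 1 2, for instance, three values in the whole array exceed 0, and only one value in positions 1..2 does.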

I know it seems confusing, so refer to the editorialist's solution, which uses this approach.

Incidentally, the two subproblems can also be solved using a segment tree instead of prefix arrays, as done by the setter and the tester, whose solutions you may refer to below.

This problem is worth a try after solving this problem.

TIME COMPLEXITY

Time complexity is O(N^2) per test case with the prefix-array approach (editorialist's solution) or O(N^2*log(N)) with the segment tree approach (setter's and tester's solutions).

SOLUTIONS:

Setter's Solution
// ItnoE
#include<bits/stdc++.h>
using namespace std;
const int N = 3003, Mod = 987654319;
struct Task
{
	#define lc (id << 1)
	#define rc (id << 1 ^ 1)
	#define md ((l + r >> 1))
	int n;
	vector < int > A, C, S, Lz;
	inline void Shift(int id)
	{
	    Lz[lc] += Lz[id];
	    Lz[rc] += Lz[id];
	    S[lc] += Lz[id] * C[lc];
	    S[rc] += Lz[id] * C[rc];
	    Lz[id] = 0;
	}
	void Plus(int i, int id, int l, int r)
	{
	    C[id] ++;
	    if (r - l < 2)
	        return ;
	    Shift(id);
	    if (i < md)
	        Plus(i, lc, l, md);
	    else
	        Plus(i, rc, md, r);
	}
	void Add(int le, int ri, int id, int l, int r)
	{
	    if (ri <= l || r <= le)
	        return ;
	    if (le <= l && r <= ri)
	    {
	        S[id] += C[id];
	        Lz[id] ++;
	        return ;
	    }
	    Shift(id);
	    Add(le, ri, lc, l, md);
	    Add(le, ri, rc, md, r);
	    S[id] = S[lc] + S[rc];
	}
	int Get(int le, int ri, int id, int l, int r)
	{
	    if (ri <= l || r <= le)
	        return (0);
	    if (le <= l && r <= ri)
	        return (S[id]);
	    Shift(id);
	    return (Get(le, ri, lc, l, md) + Get(le, ri, rc, md, r));
	}
	inline int Solve()
	{
	    int tot = 0;
	    n = (int)A.size();
	    vector < int > U = A;
	    C = vector < int > (n * 4 , 0);
	    S = vector < int > (n * 4 , 0);
	    Lz = vector < int > (n * 4 , 0);
	    sort(U.begin(), U.end());
	    for (int i = 0; i < n; i ++)
	    {
	        A[i] = (int)(lower_bound(U.begin(), U.end(), A[i]) - U.begin());
	        tot = (tot + Get(A[i] + 1, n, 1, 0, n)) % Mod;
	        Add(0, A[i], 1, 0, n); Plus(A[i], 1, 0, n);
	    }
	    return (tot);
	}
};
int n, A[N];
Task P, S;
int main()
{
	scanf("%d", &n);
	for (int i = 1; i <= n; i ++)
	    scanf("%d", &A[i]);
	int tot = 0;
	for (int d = 1; d <= n; d ++)
	{
	    P.A.clear();
	    for (int j = 1; j < d; j ++)
	        if (A[j] > A[d])
	            P.A.push_back(-A[j]);
	    reverse(P.A.begin(), P.A.end());
	    S.A.clear();
	    for (int j = d + 1; j <= n; j ++)
	        if (A[j] < A[d])
	            S.A.push_back(A[j]);
	    tot = (tot + 1LL * P.Solve() * S.Solve()) % Mod;
	}
	return !printf("%d\n", tot);
}
Tester's Solution
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <exception>
#include <fstream>
#include <functional>
#include <iomanip>
#include <ios>
#include <iosfwd>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <locale>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <typeinfo>
#include <utility>
#include <valarray>
#include <vector>
#if __cplusplus >= 201103L
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <forward_list>
#include <future>
#include <initializer_list>
#include <mutex>
#include <random>
#include <ratio>
#include <regex>
#include <scoped_allocator>
#include <system_error>
#include <thread>
#include <tuple>
#include <typeindex>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#endif

int gcd(int a, int b) {return b == 0 ? a : gcd(b, a % b);}

#define ll int
#define pb push_back
#define ld long double
#define mp make_pair
#define F first
#define S second
#define pii pair<ll,ll>

using namespace :: std;

const ll maxn=3100;
const ll inf=1e9+1;
const ll mod=987654319;

ll seg[maxn*4];
ll lazy[maxn*4];
ll on[maxn*4];


inline ll ok(ll a){
	if(a>=mod)return a-mod;
	return a;
}
void pushh(ll node,ll v){
	seg[node]+=(on[node]*v)%mod;
	seg[node]%=mod;
	lazy[node]+=v;
}
void shift(ll node){
	pushh(2*node,lazy[node]);
	pushh(2*node+1,lazy[node]);
	lazy[node]=0;
}
ll findSum(ll l,ll r,ll L,ll R,ll node){
	if(l<=L && R<=r){
		return seg[node];
	}
	if(r<=L || R<=l)return 0;
	ll mid=(L+R)/2;
	shift(node);
	return ok(findSum(l,r,L,mid,node*2)+findSum(l,r,mid,R,2*node+1));
}
void tern_on(ll x,ll L,ll R,ll node){
	if(L+1==R){
		on[node]=1;
		return ;
	}
	ll mid=(L+R)/2;
	shift(node);
	if(x<mid){
		tern_on(x,L,mid,2*node);
	}else{
		tern_on(x,mid,R,2*node+1);
	}
	on[node]=on[2*node+1]+on[2*node];
}
void add(ll l,ll r,ll L,ll R,ll node){
	if(l<=L && R<=r){
		seg[node]+=on[node];
		lazy[node]++;
		return;
	}
	if(r<=L || R<=l)return ;
	ll mid=(L+R)/2;
	shift(node);
	add(l,r,L,mid,node*2);
	add(l,r,mid,R,2*node+1);
	seg[node]=ok(seg[node*2]+seg[2*node+1]);
}


void comp_vec(vector<ll> &a){
	ll n=a.size();
	vector<ll> co;
	co.resize(n);
	for(ll i=0;i<n;i++){
		co[i]=a[i];
	}
	sort(co.begin(),co.end());
	for(ll i=0;i<n;i++){
		a[i]=lower_bound(co.begin(),co.end(),a[i])-co.begin();
	}
}
vector<ll> ww[maxn];
ll findAns(vector<ll> a){
	ll n=a.size();
	comp_vec(a);
	ll ans=0;
	for(ll i=0;i<n;i++){
		ww[i].clear();
	}
	fill(seg,seg+4*n,0);
	fill(lazy,lazy+4*n,0);
	fill(on,on+4*n,0);

	for(ll i=0;i<n;i++){
		ww[a[i]].pb(i);
	}
	for(ll i=0;i<n;i++){
		for(auto r:ww[i]){
			ans+=findSum(r,n,0,n,1);
			if(ans>=mod)ans-=mod;
			add(0,r,0,n,1);
		}
		for(auto r:ww[i]){
			tern_on(r,0,n,1);
		}
	}
	return ans;
}
ll a[maxn];
int main(){
	ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	ll n;
	cin>>n;
	for(ll i=0;i<n;i++){
		cin>>a[i];
	}

	long long ans=0;
	for(ll i=3;i<n-3;i++){
		vector<ll> vecl,vecr;
		for(ll j=0;j<i;j++){
			if(a[j]>a[i]){
				vecl.pb(a[j]);
			}
		}
		for(ll j=i+1;j<n;j++){
			if(a[i]>a[j]){
				vecr.pb(a[j]);
			}
		}
		reverse(vecr.begin(),vecr.end());
		for(ll j=0;j<vecr.size();j++){
			vecr[j]=inf-vecr[j];
		}
		ans+=(1LL*findAns(vecl)*findAns(vecr))%mod;
		ans%=mod;
	}
	cout<<ans;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class DVL{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni();
	    long[] aa=  new long[n];
	    TreeSet<Long> set = new TreeSet<>();TreeMap<Long, Integer> mp = new TreeMap<>();
	    for(int i = 0; i< n; i++){aa[i] = nl();set.add(aa[i]);}
	    int cnt = 0;
	    for(long l:set)mp.put(l, cnt++);
	    int[] A = new int[1+n];
	    int[][] freq = new int[cnt][1+n], less = new int[cnt][1+n], more = new int[cnt][1+n];
	    for(int i = 1; i<= n; i++){
	        A[i] = mp.get(aa[i-1]);
	        freq[A[i]][i]++;
	        for(int j = 0; j< cnt; j++){
	            if(A[i] < j)less[j][i]++;
	            else if(A[i] > j)more[j][i]++;
	        }
	        for(int j = 0; j< cnt; j++){
	            freq[j][i] += freq[j][i-1];
	            less[j][i] += less[j][i-1];
	            more[j][i] += more[j][i-1];
	        }
	    }
	    
	    long[][] C = new long[cnt][1+n];
	    for(int c = 1; c <= n; c++)
	        for(int b = 1; b < c; b++)
	            if(A[b] < A[c]){
	                C[A[b]][c] += more[A[c]][b];
	                if(C[A[b]][c] >= mod)C[A[b]][c] -= mod;
	            }
	    long[][] E = new long[cnt][1+n];
	    for(int e = n; e >= 1; e--)
	        for(int f = n; f > e; f--)
	            if(A[f] > A[e]){
	                E[A[f]][e] += less[A[e]][n]-less[A[e]][f];
	                if(E[A[f]][e] >= mod)E[A[f]][e] -= mod;
	            }
	    
	    long[][] Cmore = new long[cnt][1+n], Eless = new long[cnt][2+n];
	    for(int i = 1; i<= n; i++){
	        for(int j = 0; j< cnt; j++)Cmore[j][i] = Cmore[j][i-1];
	        long sum = 0;
	        for(int j = cnt-1; j>= 0; j--){
	            Cmore[j][i] = (Cmore[j][i]+sum)%mod;
	            sum = (sum+C[j][i])%mod;
	        }
	    }
	    for(int i = n; i>= 1; i--){
	        for(int j = 0; j< cnt; j++)Eless[j][i] = Eless[j][i+1];
	        long sum = 0;
	        for(int j = 0; j< cnt; j++){
	            Eless[j][i] = (Eless[j][i]+sum)%mod;
	            sum = (sum+E[j][i])%mod;
	        }
	    }
	    long ans = 0;
	    for(int i = 1; i<= n; i++)
	        ans = (ans+(Cmore[A[i]][i]*Eless[A[i]][i])%mod)%mod;
	    pn(ans);
	}
	long mod = 987654319;
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = false;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new DVL().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:

It is not a must-do. It’s just a problem worth trying in my opinion, and it has slightly similar inequalities.

It’s up to you whether to do it or not.

Personally, I feel like the challenge here is the implementation. :smiley:

I will get back to this thread after I implement, I guess.

Great editorial, btw!

Sorry, my mistake.
Shall correct it by evening.

What I meant was CC(p, x) = CC(p-1, x) + \sum_{y = x+1}^{mx}C(p, y)

@taran_1407 During the contest I coded this solution, which is purely DP, but it kept getting WA. I can’t think of why the solution fails. Here is my solution from the contest. CodeChef: Practical coding for everyone.

i is the index in the array and j is the number of elements selected so far.
If the i-th index is selected, then f == 0 selects b or e, f == 1 selects c or f, and f == 2 selects a, d or g.

Can you explain your DP approach?

dp[i][j][f][p]: i is the current index, which is either selected or skipped (the call solve(i + 1, j, f, p)). If it is selected, we then look for an index k that can be selected next, depending on what we need to select. If a, d or g has to be selected, its value only needs to be less than the value at index i, which is represented by f = 2. If b or e has to be selected, its value needs to be less than the value at i as well as the value at the previous index p that was selected before i. If c or f has to be selected, its value needs to be greater than the value at i but less than the value at p, which is represented by f = 1.

@roll_no_1 I have updated it. Thanks for pointing out.

Please let me know if anything’s still unclear.

I actually found @spellstaker’s submission to be very helpful in understanding the solution after reading the definitions in the quick explanation section: CodeChef: Practical coding for everyone

Updated.
Check now.

Nice, a shout out from @roll_no_1; I’m practically famous now.

Your implementation was really clean and easy to understand. Only your code helped me understand the solution to this problem.
