IAI - EDITORIAL

PROBLEM LINK

Practice

Contest: Division 1

Contest: Division 2

Setter: Vivek Chauhan

Tester: Michael Nematollahi

Editorialist: Taranpreet Singh

DIFFICULTY

Medium-Hard

PREREQUISITES

Prefix arrays, Convex Hull Trick, Segment Tree and a keen eye for observation.

PROBLEM

Given an array A of length N and M intervals. For any subset of the given M intervals whose intersection [L, R] is non-empty, the value of the subset is \sum_{i = L}^{R} A_i*(i-L+1). Find the maximum value over all subsets with a non-empty intersection.

Let us denote the value of the intersection [L, R] by val(L, R).

EXPLANATION

We shall first solve the first subtask and then extend that solution to the second subtask.

Lemma: The intersection of any subset is determined by at most two intervals of the subset.
Proof: If the intersection is [L, R], then L is the largest left end and R is the smallest right end in the subset. So the subset contains an interval with left end L and a (possibly the same) interval with right end R. Ignoring all the other intervals of the subset leaves the intersection unchanged, so the intersection of the subset equals the intersection of this pair of intervals.

Using the lemma, we can try all M^2 pairs of intervals and compute the value of each intersection directly in O(N) time. This takes O(N*M^2) time, which is still too slow even for the first subtask.
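A minimal C++ sketch of this brute force, assuming 1-indexed input (A[0] is unused) and intervals given as (left, right) pairs; the function name bruteForce is hypothetical. Enumerating j >= i also covers subsets that consist of a single interval.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

// O(N * M^2) brute force based on the lemma: every achievable intersection is
// the intersection of some pair (i, j) of intervals (i == j is a single interval).
long long bruteForce(const std::vector<long long>& A,
                     const std::vector<std::pair<int, int>>& iv) {
    long long best = LLONG_MIN;  // stays LLONG_MIN only if iv is empty
    for (std::size_t i = 0; i < iv.size(); ++i) {
        assert(1 <= iv[i].first && iv[i].first <= iv[i].second &&
               iv[i].second < static_cast<int>(A.size()));
        for (std::size_t j = i; j < iv.size(); ++j) {
            int L = std::max(iv[i].first, iv[j].first);
            int R = std::min(iv[i].second, iv[j].second);
            if (L > R) continue;  // empty intersection: not a valid subset
            long long val = 0;    // val(L, R) computed term by term in O(N)
            for (int k = L; k <= R; ++k) val += A[k] * (k - L + 1);
            best = std::max(best, val);
        }
    }
    return best;
}
```

For example, with A = {0, 1, -2, 3} and intervals [1, 3], [2, 3], the best subset is {[1, 3]} with value 1 - 4 + 9 = 6.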

However, with some precomputation we can compute the value of any range in O(1), as follows.

Let us compute two prefix arrays S and T with S_0 = T_0 = 0, S_i = S_{i-1} + A_i and T_i = T_{i-1}+i*A_i.

Now T_X = A_1+2*A_2+3*A_3+\ldots+X*A_X, which is exactly val(1, X), and T_R-T_{L-1} = L*A_L+(L+1)*A_{L+1}+(L+2)*A_{L+2}+\ldots+R*A_R.
Also, val(L, R) = 1*A_L+2*A_{L+1}+\ldots+(R-L+1)*A_R. Each coefficient in T_R-T_{L-1} is exactly L-1 larger than the corresponding coefficient in val(L, R), so subtracting the two leaves (L-1) times the sum of the range [L, R], that is, (L-1)*(S_R-S_{L-1}).

This gives the following expression to compute the value of an interval.

val(L, R) = T_R-T_{L-1}-(L-1)*(S_R-S_{L-1}).

Using this expression, each pair is evaluated in O(1), so the solution becomes O(N+M^2), which passes the first subtask but not the second (what else did you expect? xD).
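The prefix arrays and the resulting O(M^2) pair enumeration could be sketched in C++ as follows. Prefix and solvePairs are illustrative names, not taken from the reference solutions, and A is again assumed to be 1-indexed.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

// S_i = S_{i-1} + A_i and T_i = T_{i-1} + i*A_i, with S_0 = T_0 = 0.
struct Prefix {
    std::vector<long long> S, T;
    explicit Prefix(const std::vector<long long>& A) : S(A.size(), 0), T(A.size(), 0) {
        for (std::size_t i = 1; i < A.size(); ++i) {
            S[i] = S[i - 1] + A[i];
            T[i] = T[i - 1] + static_cast<long long>(i) * A[i];
        }
    }
    // val(L, R) = T_R - T_{L-1} - (L-1)*(S_R - S_{L-1}), in O(1).
    long long val(int L, int R) const {
        assert(1 <= L && L <= R && R < static_cast<int>(S.size()));
        return T[R] - T[L - 1] - static_cast<long long>(L - 1) * (S[R] - S[L - 1]);
    }
};

// O(N + M^2): the same pair enumeration as the brute force, O(1) per pair.
long long solvePairs(const std::vector<long long>& A,
                     const std::vector<std::pair<int, int>>& iv) {
    Prefix P(A);
    long long best = LLONG_MIN;
    for (std::size_t i = 0; i < iv.size(); ++i)
        for (std::size_t j = i; j < iv.size(); ++j) {
            int L = std::max(iv[i].first, iv[j].first);
            int R = std::min(iv[i].second, iv[j].second);
            if (L <= R) best = std::max(best, P.val(L, R));
        }
    return best;
}
```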

Rewriting, val(L, R) = T_R-(L-1)*S_R-T_{L-1}+(L-1)*S_{L-1}. Notice that for a given left end L, the term -T_{L-1}+(L-1)*S_{L-1} is constant. Let us denote it by f(L) = -T_{L-1}+(L-1)*S_{L-1}; it can be precomputed from the arrays S and T.

Now we have to maximize val(L, R) = T_R-(L-1)*S_R+f(L).

If we fix the right end R of the intersection, both T_R and S_R are constant, and we just need to maximize -(L-1)*S_R+f(L). So for a fixed R, every candidate left end L gives a linear function (slope -(L-1), intercept f(L)), and we need the maximum of these functions evaluated at S_R.

For each position p, let us define the function g_p(x) = -(p-1)*x+f(p), so that val(L, R) = T_R+g_L(S_R).
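As a sanity check of this rewriting, the C++ sketch below evaluates val(L, R) only through T_R and the line g_L at x = S_R. The struct LineForm and its member names are hypothetical; A is assumed 1-indexed.

```cpp
#include <cassert>
#include <vector>

// val(L, R) split into an R-only part (T_R, S_R) and the line
// g_L(x) = -(L-1)*x + f(L), where f(L) = -T_{L-1} + (L-1)*S_{L-1}.
struct LineForm {
    std::vector<long long> S, T;
    explicit LineForm(const std::vector<long long>& A) : S(A.size(), 0), T(A.size(), 0) {
        for (std::size_t i = 1; i < A.size(); ++i) {
            S[i] = S[i - 1] + A[i];
            T[i] = T[i - 1] + static_cast<long long>(i) * A[i];
        }
    }
    long long slope(int p) const { return -static_cast<long long>(p - 1); }
    long long f(int p) const { return -T[p - 1] + static_cast<long long>(p - 1) * S[p - 1]; }
    long long g(int p, long long x) const { return slope(p) * x + f(p); }
    // val(L, R) = T_R + g_L(S_R)
    long long val(int L, int R) const {
        assert(1 <= L && L <= R && R < static_cast<int>(S.size()));
        return T[R] + g(L, S[R]);
    }
};
```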

Let us assume that the interval [L, R] defines the right end of the intersection. The other interval of the pair must then have its right end \geq R, and for the intersection to be [p, R], its left end p must lie in [L, R] (the interval [L, R] itself covers the case p = L). Hence, for all intervals with right end \geq R and left end p in [L, R], we evaluate g_p at S_R. The maximum of these values, plus T_R, is the maximum value we can obtain when [L, R] fixes the right end R of the intersection.

This way, we can try every interval as the one fixing the right end R and take the maximum answer.
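Before any speed-up, this "fix the right end" idea can be sketched in C++ as a plain double loop. solveFixRight is a hypothetical name, A is assumed 1-indexed, and the sketch still costs O(M^2) overall.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <utility>
#include <vector>

// For each interval [L, R] defining the right end, evaluate T_R + g_p(S_R)
// for the left end p of every interval with right end >= R and p in [L, R].
long long solveFixRight(const std::vector<long long>& A,
                        const std::vector<std::pair<int, int>>& iv) {
    const int n = static_cast<int>(A.size()) - 1;
    std::vector<long long> S(n + 1, 0), T(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        S[i] = S[i - 1] + A[i];
        T[i] = T[i - 1] + i * A[i];
    }
    auto g = [&](int p, long long x) {  // g_p(x) = -(p-1)*x + f(p)
        long long f = -T[p - 1] + static_cast<long long>(p - 1) * S[p - 1];
        return -static_cast<long long>(p - 1) * x + f;
    };
    long long best = LLONG_MIN;
    for (const auto& [L, R] : iv) {
        assert(1 <= L && L <= R && R <= n);
        for (const auto& [p, r2] : iv) {
            if (r2 < R || p < L || p > R) continue;  // intersection is not [p, R]
            best = std::max(best, T[R] + g(p, S[R]));
        }
    }
    return best;
}
```

Note that the filter p >= L matters: with A = {0, 10, -1, -1, -100} and intervals [1, 4], [2, 3], dropping it would report val(1, 3) = 5 instead of the correct answer -3.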

But this procedure is still slow: for each R, we need to find all intervals with right end \geq R and evaluate the function associated with each left end individually.

The following requires an understanding of the Convex Hull Trick, as explained here.

But we can speed it up. Instead of evaluating the functions individually, we can insert them into a dynamic convex hull and get the maximum value with a single query.

Now a query takes O(log(M)) time, but inserting the functions into the hull still takes O(M) time for each R.
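For reference, a dynamic convex hull for maximum queries can be kept as a multiset of lines ordered by slope, where each line also stores p, the last x at which it is optimal. The sketch below follows the LineContainer used in the setter's solution (a well-known KACTL structure) but replaces the global flag Q with a transparent comparator; it assumes C++17, integer slopes and intercepts, and no overflow in the intersection arithmetic. Insertion is amortized O(log M), and a query is O(log M).

```cpp
#include <cassert>
#include <climits>
#include <set>

struct Line {
    mutable long long k, m, p;  // the line k*x + m is optimal for x up to p
    bool operator<(const Line& o) const { return k < o.k; }  // order by slope
    bool operator<(long long x) const { return p < x; }       // lookup by x
};

struct LineContainer : std::multiset<Line, std::less<>> {
    static constexpr long long INF = LLONG_MAX;
    // Floored division (C++ '/' truncates toward zero).
    static long long floorDiv(long long a, long long b) {
        return a / b - ((a ^ b) < 0 && a % b);
    }
    // Sets x->p to where y overtakes x; returns true if x is no longer needed.
    bool isect(iterator x, iterator y) {
        if (y == end()) { x->p = INF; return false; }
        if (x->k == y->k) x->p = x->m > y->m ? INF : -INF;
        else x->p = floorDiv(y->m - x->m, x->k - y->k);
        return x->p >= y->p;
    }
    void add(long long k, long long m) {
        auto z = insert({k, m, 0}), y = z++, x = y;
        while (isect(y, z)) z = erase(z);                 // drop dominated right neighbours
        if (x != begin() && isect(--x, y)) isect(x, y = erase(y));  // new line useless?
        while ((y = x) != begin() && (--x)->p >= y->p)    // drop dominated left neighbours
            isect(x, erase(y));
    }
    // Maximum of k*x + m over the inserted lines; the hull must be non-empty.
    long long query(long long x) const {
        assert(!empty());
        auto l = *lower_bound(x);
        return l.k * x + l.m;
    }
};
```

Usage: after add(1, 0), add(-1, 0) and add(0, -5), query(3) gives 3 and query(-4) gives 4; the line 0*x - 5 is never optimal and is removed.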

Let us sort the intervals in non-increasing order of their right ends. As R decreases, the set of intervals with right end \geq R only grows, so a function, once inserted into the convex hull, never needs to be removed. This leaves only O(M) insertions into the convex hull in total.

But there is still a catch. If there are two intervals [A, B] and [C, D] with C < A and B \leq D, the function at position C has already been inserted when we evaluate the functions for the right end fixed at B by the interval [A, B]. But C cannot be the left end of that intersection, since A > C. Likewise, an interval whose left end is greater than B does not intersect [A, B] at all. So, for the current interval [L, R], we must evaluate only the functions at positions within [L, R].

Let’s assume we have a data structure that supports the following operations.

  • Insert (P, K, C), which inserts the function H(x) = K*x+C at position P.
  • Query (L, R, X), which returns the maximum value of the functions at positions in [L, R] evaluated at point X.

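We process the intervals in non-increasing order of their right ends. When we reach the right end R, we insert the function g_p at position p for every interval [p, r] with r \geq R; all intervals sharing the right end R are inserted before any of them is queried, so they can pair with each other. For the current interval [L, R] we then make the query (L, R, S_R). If it returns V, then T_R+V is the maximum value of any intersection ending at R whose left end lies in [L, R].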

Doing this for every interval, with that interval fixing the right end of the intersection, and taking the maximum gives the answer.

The data structure mentioned above is a segment tree with a convex hull at each node. The convex hull at each node stores all functions inserted at positions within the range represented by that node. Whenever we reach an interval [A, B] with B \geq R, we insert the function g_A at position A; it is added to the hull of every node on the path from leaf A to the root, that is, to O(log(N)) nodes. To find the maximum value of the functions at positions in [L, R] evaluated at S_R, we query the segment tree: [L, R] decomposes into O(log(N)) nodes, and we take the maximum of their hull queries.
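Putting the pieces together, here is a compact C++ sketch of the whole approach. It mirrors the setter's sweep, but with an iterative segment tree; the names Hull, HullSegTree and solve are hypothetical, input parsing and multiple test cases are omitted, A is assumed 1-indexed, and all values are assumed to fit in long long (C++17).

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <set>
#include <utility>
#include <vector>

// Max-hull in the style of the setter's LineContainer (KACTL).
struct Line {
    mutable long long k, m, p;  // line k*x + m, optimal for x up to p
    bool operator<(const Line& o) const { return k < o.k; }
    bool operator<(long long x) const { return p < x; }
};
struct Hull : std::multiset<Line, std::less<>> {
    static constexpr long long INF = LLONG_MAX;
    static long long floorDiv(long long a, long long b) { return a / b - ((a ^ b) < 0 && a % b); }
    bool isect(iterator x, iterator y) {
        if (y == end()) { x->p = INF; return false; }
        if (x->k == y->k) x->p = x->m > y->m ? INF : -INF;
        else x->p = floorDiv(y->m - x->m, x->k - y->k);
        return x->p >= y->p;
    }
    void add(long long k, long long m) {
        auto z = insert({k, m, 0}), y = z++, x = y;
        while (isect(y, z)) z = erase(z);
        if (x != begin() && isect(--x, y)) isect(x, y = erase(y));
        while ((y = x) != begin() && (--x)->p >= y->p) isect(x, erase(y));
    }
    long long query(long long x) const {  // requires a non-empty hull
        assert(!empty());
        auto l = *lower_bound(x);
        return l.k * x + l.m;
    }
};

// Segment tree over positions 0..size-1 (1..n used); node i holds the hull
// of every line inserted at a position inside its range.
struct HullSegTree {
    int size = 1;
    std::vector<Hull> node;
    explicit HullSegTree(int n) {
        while (size < n + 1) size <<= 1;
        node.resize(2 * size);
    }
    // Insert(P, K, C): add K*x + C to leaf P and all of its ancestors.
    void insert(int pos, long long k, long long c) {
        for (int i = pos + size; i > 0; i >>= 1) node[i].add(k, c);
    }
    // Query(L, R, X): max over lines at positions [l, r]; LLONG_MIN if none.
    long long query(int l, int r, long long x) const {
        long long best = LLONG_MIN;
        auto take = [&](int i) { if (!node[i].empty()) best = std::max(best, node[i].query(x)); };
        for (l += size, r += size + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) take(l++);
            if (r & 1) take(--r);
        }
        return best;
    }
};

long long solve(const std::vector<long long>& A, std::vector<std::pair<int, int>> iv) {
    const int n = static_cast<int>(A.size()) - 1;
    std::vector<long long> S(n + 1, 0), T(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        S[i] = S[i - 1] + A[i];
        T[i] = T[i - 1] + i * A[i];
    }
    auto f = [&](int p) { return -T[p - 1] + static_cast<long long>(p - 1) * S[p - 1]; };
    // Non-increasing right ends: lines are only ever added, never removed.
    std::sort(iv.begin(), iv.end(),
              [](const auto& a, const auto& b) { return a.second > b.second; });
    HullSegTree tree(n);
    long long best = LLONG_MIN;
    for (std::size_t i = 0; i < iv.size();) {
        const int R = iv[i].second;
        std::size_t j = i;
        // Insert g_p for every interval with this right end before querying any.
        for (; j < iv.size() && iv[j].second == R; ++j) {
            const int p = iv[j].first;
            tree.insert(p, -static_cast<long long>(p - 1), f(p));
        }
        // Query positions [L, R] at S_R; g_L of the interval itself is present.
        for (; i < j; ++i) best = std::max(best, T[R] + tree.query(iv[i].first, R, S[R]));
    }
    return best;
}
```

The range restriction is what handles the catch described earlier: with A = {0, 10, -1, -1, -100} and intervals [1, 4], [2, 3], the query for [2, 3] ignores the line at position 1, and the result is -3 rather than val(1, 3) = 5.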

This solves the problem efficiently; the exact complexity is derived below.

TIME COMPLEXITY

The following points determine the time complexity of our solution.

  • Sorting the intervals takes O(M*log(M)) time.
  • Construction of the segment tree takes O(N) time.
  • Inserting a function at a position of the segment tree inserts it into O(log(N)) convex hulls, each taking O(log(M)) amortized time, leading to O(log(N)*log(M)) time per insertion.
  • Similarly, a query to the segment tree visits O(log(N)) nodes, each answered in O(log(M)) time, so it also takes O(log(N)*log(M)) time.
  • We make a total of M insertions and M queries.

So, the final time complexity becomes O(N+M*log(M)*log(N)) per test case.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef long double ld;
const int N = 100005;
#define M 30005
ll inf = 1e16;
ll mod = 1e9 + 7;

char en = '\n';
ll power(ll x, ll n, ll mod) {
  ll res = 1;
  x %= mod;
  while (n) {
	if (n & 1)
	  res = (res * x) % mod;
	x = (x * x) % mod;
	n >>= 1;
  }
  return res;
}

bool Q;
struct Line {
  mutable ll k, m, p;
  bool operator<(const Line& o) const {
	return Q ? p < o.p : k < o.k;
  }
};
struct LineContainer : multiset<Line> {
  ll div(ll a, ll b){
	return a / b - ((a ^ b) < 0 && a % b);
  }
  bool isect(iterator x, iterator y) {
	if (y == end()) { x->p = inf; return false; }
	if (x->k == y->k) x->p = x->m > y->m ? inf : -inf;
	else x->p = div(y->m - x->m, x->k - y->k);
	return x->p >= y->p;
  }
  void add(ll k, ll m) {
	auto z = insert({k, m, 0}), y = z++, x = y;
	while (isect(y, z)) z = erase(z);
	if (x != begin() && isect(--x, y)) isect(x, y = erase(y));
	while ((y = x) != begin() && (--x)->p >= y->p)
	  isect(x, erase(y));
  }
  ll query(ll x) {
	// No line has been inserted into this hull yet.
	if(empty())
	{
	  return -inf;
	}
	// Q switches the comparator to order lines by breakpoint p for this lookup.
	Q = 1; auto l = *lower_bound({0,0,x}); Q = 0;
	return l.k * x + l.m;
  }
};
LineContainer tr[4*N];
ll query1(ll l,ll r,ll a,ll b,ll x,ll si=1)
{
  if(l>b or r<a)
	return -inf;

  if(l>=a && r<=b)
  {
	return tr[si].query(x);
  }
  ll mid=(l+r)>>1;
  ll left1=query1(l,mid,a,b,x,si<<1);
  ll right1=query1(mid+1,r,a,b,x,(si<<1)+1);

  return max(left1,right1);
}

void update(ll l,ll r,ll i,ll k,ll m,ll si=1)
{
  if(l>i or r<i)
	return;

  if(l==r)
  {
	tr[si].add(k,m);
	return;
  }

  tr[si].add(k,m);
  ll mid=(l+r)>>1;
  update(l,mid,i,k,m,(si<<1));
  update(mid+1,r,i,k,m,(si<<1)+1);
}
ll pre1[N],pre2[N];

int32_t main() {
  ios_base::sync_with_stdio(false);
  cin.tie(NULL);

  ll t;
  cin >> t;
  while (t--) {
	ll n,m;
	cin>>n>>m;

	for(ll i=0;i<=4*n+5;i++)
	{
	  tr[i].clear();
	}

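	// std::vector instead of variable-length arrays (a non-standard GCC
	// extension that also places O(N) objects on the stack).
	vector<ll> arr(n+5);
	for(ll i=1;i<=n;i++)
	  cin>>arr[i];
	vector<array<ll,2>> interval(m+5);
	// events[r] lists the left ends of all intervals with right end r.
	vector<vector<ll>> events(n+5);
	for(ll i=1;i<=m;i++)
	{
	  cin>>interval[i][0]>>interval[i][1];
	  events[interval[i][1]].push_back(interval[i][0]);
	}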

	memset(pre1,0,sizeof(pre1));
	memset(pre2,0,sizeof(pre2));

	for(ll i=1;i<=n;i++)
	{
	  pre1[i]=pre1[i-1]+arr[i]*i;
	  pre2[i]=pre2[i-1]+arr[i];
	}

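	ll res=-inf;
	// Sweep the right end i from n down to 1 (pre1 = T, pre2 = S).
	for(ll i=n;i>=1;i--)
	{
	  // First insert g_x(y) = -(x-1)*y + f(x) at position x for every interval [x, i] ...
	  for(ll x:events[i])
	  {
	    update(1,n,x,-(x-1),pre2[x-1]*(x-1)-pre1[x-1]);
	  }
	  // ... then, for each of them, query positions [x, i] at y = S_i.
	  for(ll x:events[i])
	  {
	    res=max(res,pre1[i]+query1(1,n,x,i,pre2[i]));
	  }
	}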

	cout<<res<<en;

  }

  return 0;
}
Tester's Solution
#include<bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int, int> pii;

#define F first
#define S second

const int MAXN = 1e5 + 10;
const int SQ = 400;
const ll INF = 1e18;

int n, m, a[MAXN], sec[MAXN];
ll p[MAXN], p2[MAXN], ms[MAXN], b[MAXN];
pii seg[MAXN];
bool on[MAXN];

//p2[r]-p2[l]-l*p[r]+l*p[l]

bool cmp(int u, int v){
	if (seg[u].S ^ seg[v].S)
		return seg[u].S > seg[v].S;
	return seg[u].F < seg[v].F;
}

ll meet(int u, int v){
	return (b[v]-b[u] + (ms[u]-ms[v]-1))/ (ms[u]-ms[v]);
}

vector<int> vec[SQ];
vector<ll> best[SQ];
void turnOn(int id){
	on[id] = true;
	ms[id] = seg[id].F;
	b[id] = -p2[seg[id].F] + 1ll*seg[id].F*p[seg[id].F];

	int block = id/SQ;
	int l = block*SQ;
	int r = min(m, l + SQ);

	vec[block].clear();
	best[block].clear();
	for (int i = l; i < r; i++){
		if (!on[i]) continue;
		if (vec[block].size() && ms[vec[block].back()] == ms[i]) continue;
		while (vec[block].size() > 1 && meet(i, vec[block].back()) <= meet(vec[block].back(), vec[block][vec[block].size()-2]))
			vec[block].pop_back(), best[block].pop_back();

		if (vec[block].empty())
			best[block].push_back(-INF);
		else
			best[block].push_back(meet(i, vec[block].back()));
		vec[block].push_back(i);
	}
}

ll eval(int v, ll x){return ms[v]*x+b[v];}

ll get(int l, int r, ll x){
	ll ret = -INF;
	while (l < r && l%SQ){
		if (on[l])
			ret = max(ret, eval(l, x));
		l++;
	}

	while (l/SQ ^ r/SQ){
		int block = l/SQ;
		if (vec[block].size()){
			int pos = upper_bound(best[block].begin(), best[block].end(), x) - best[block].begin() - 1;
			ret = max(ret, eval(vec[block][pos], x));
		}
	
		l += SQ;
	}

	while (l<r){
		if (on[l])
			ret = max(ret, eval(l, x));
		l++;
	}
	return ret;
}

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	int te;	cin >> te;
	while (te--){
		cin >> n >> m;
		for (int i = 0; i < n; i++) cin >> a[i], p[i+1] = p[i] + a[i], p2[i+1] = p2[i] + 1ll*(i+1)*a[i];
		for (int i = 0; i < SQ; i++) vec[i].clear(), best[i].clear();
		for (int i = 0; i < m; i++) cin >> seg[i].F >> seg[i].S, seg[i].F--;
		sort(seg, seg + m);

		iota(sec, sec + m, 0);
		sort(sec, sec + m, cmp);
		ll ans = -1e18;
		memset(on, 0, sizeof(on));
		for (int i = 0; i < m; i++){
			int q = sec[i];
			turnOn(q);
			ans = max(ans, get(lower_bound(seg, seg + m, make_pair(seg[q].F, -1))-seg, 
						lower_bound(seg, seg + m, make_pair(seg[q].S, -1))-seg, 
							-p[seg[q].S])+p2[seg[q].S]);
		}
		cout << ans << "\n";
	}
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class IAI{
	//SOLUTION BEGIN
	long IINF = (long)1e16;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), m = ni();
	    long[] a = new long[1+n];
	    long[] s1 = new long[1+n], s2 = new long[1+n];
	    long[] f = new long[1+n];
	    for(int i = 1; i<= n; i++){
	        a[i] = nl();
	        s1[i] = s1[i-1]+a[i];
	        s2[i] = s2[i-1]+i*a[i];
	        f[i] = -s2[i-1]+(i-1)*s1[i-1];
	    }
	    int[][] w = new int[m][];
	    for(int i = 0; i< m; i++)w[i] = new int[]{ni(), ni()};
	    Arrays.sort(w, (int[] i1, int[] i2) -> Integer.compare(i1[1], i2[1]));
	    SegTree t = new SegTree(1+n);
	    long ans = (long)-IINF;
	    for(int i = n, ptr = m-1; i>= 1; i--){
	        int pt = ptr;
	        while(pt>=0 && w[pt][1]==i){
	            t.add(w[pt][0], -(w[pt][0]-1), f[w[pt][0]]);
	            pt--;
	        }
	        while(ptr>=0 && w[ptr][1]==i){
	            long x = s2[i];
	            long y = t.query(w[ptr][0], i, s1[i]);
	            ans = Math.max(ans, x+y);
	            ptr--;
	        }
	    }
	    pn(ans);
	}
	class SegTree{
	    int m = 1;
	    CHT[] t;
	    public SegTree(int n){
	        while(m<n)m<<=1;
	        t = new CHT[m<<1];
	        for(int i = 0; i< m<<1; i++)t[i] = new CHT(1);
	    }
	    void add(int p, long s, long c){
	        t[p+=m].add(s, c);
	        for(p>>=1; p>0; p>>=1)t[p].add(s, c);
	    }
	    long query(int l, int r, int ll, int rr, int i, long x){
	        if(l==ll && r==rr)return t[i].size()==0?-IINF:t[i].query(x);
	        int mid = (ll+rr)/2;
	        if(r<=mid)return query(l,r,ll,mid,i<<1,x);
	        else if(l>mid)return query(l,r,mid+1,rr,i<<1|1,x);
	        else return Math.max(query(l,mid,ll,mid,i<<1,x), query(mid+1,r,mid+1,rr,i<<1|1,x));
	    }
	    long query(int l, int r, long x){
	        return query(l,r,0,m-1,1,x);
	    }
	}
	class CHT {
	    //http://codeforces.com/contest/932/submission/35323630
	    public final int MIN = -1;
	    public TreeSet<CHT.Line> hull;
	    int type;
	    boolean query = false;
	    Comparator<CHT.Line> comp = new Comparator<CHT.Line>() {
	        public int compare(CHT.Line a, CHT.Line b) {
	            if (!query) return type * Long.compare(a.m, b.m);
	            if (a.left == b.left)
	                return Long.compare(a.m, b.m);
	            return Double.compare(a.left, b.left);
	        }
	    };
	    public CHT(final int type) {
	        this.type = type;
	        hull = new TreeSet<>(comp);
	    }
	    public void add(long m, long b) {
	        add(new CHT.Line(m, b));
	    }
	    public void add(CHT.Line a) {
	        CHT.Line[] LR = {hull.lower(a), hull.ceiling(a)};
	        for (int i = 0; i < 2; i++)
	            if (LR[i] != null && LR[i].m == a.m) {
	                if (type == 1 && LR[i].b >= a.b)
	                    return;
	                if (type == -1 && LR[i].b <= a.b)
	                    return;
	                remove(LR[i]);
	            }
	        hull.add(a);
	        CHT.Line L = hull.lower(a), R = hull.higher(a);
	        if (L != null && R != null && a.inter(R) <= R.left) {
	            hull.remove(a);
	            return;
	        }
	        CHT.Line LL = (L != null) ? hull.lower(L) : null;
	        CHT.Line RR = (R != null) ? hull.higher(R) : null;
	        if (L != null) a.left = a.inter(L);
	        if (R != null) R.left = a.inter(R);
	        while (LL != null && L.left >= a.inter(LL)) {
	            remove(L);
	            a.left = a.inter(L = LL);
	            LL = hull.lower(L);
	        }
	        while (RR != null && R.inter(RR) <= a.inter(RR)) {
	            remove(R);
	            RR.left = a.inter(R = RR);
	            RR = hull.higher(R);
	        }
	    }   
	    public long query(long x) {
	        CHT.Line temp = new CHT.Line(0, 0, 0);
	        temp.left = x;
	        query = true;
	        long ans = hull.floor(temp).eval(x);
	        query = false;
	        return ans;
	    }
	    private void remove(CHT.Line x) {
	        hull.remove(x);
	    }
	    public int size() {
	        return hull.size();
	    }
	    public class Line {
	        long m;
	        long b;
	        double left = Long.MIN_VALUE;
	        public Line(long m, long x, long y) {
	            this.m = m;
	            this.b = -m * x + y;
	        }
	        public Line(long m, long b) {
	            this.m = m;
	            this.b = b;
	        }
	        public long eval(long x) {
	            return m * x + b;
	        }
	        public double inter(CHT.Line x) {
	            return (double) (x.b - this.b) / (double) (this.m - x.m);
	        }
	    }
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new IAI().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach, even if it's the same. :stuck_out_tongue: Suggestions are welcome, as always. :slight_smile:
