STRENGTH - Editorial

PROBLEM LINK:

Strong Array

Author: aditya2024
Tester: mughda273, aditya2024
Editorialist: aditya2024

DIFFICULTY:

EASY-MEDIUM

PROBLEM:

You are given an array A of N non-negative integers.
Let A(i, j) denote the subarray
\{A[i], A[i+1], \ldots, A[j]\},
1 \leqslant i \leqslant j \leqslant N.

  • The maxshift of any subarray A(i, j) is denoted by D(i, j) and is defined as:
    D(i, j) = \sum_{k = i}^{k = j} ( max - A[ k ] ), where max denotes the maximum element in A(i, j).

  • The minshift of any subarray A(i, j) is denoted by d(i, j) and is defined as:
    d(i, j) = \sum_{k = i}^{k = j} ( A[ k ] - min ), where min denotes the minimum element in A(i, j).

  • We also define the strength of A(i, j) as:
    S(i, j) = D(i, j) + d(i, j) + \sum_{k = i}^{k = j} A[ k ]

You need to answer Q queries of the following types:

Type 1: 1 x y ( 1 \leqslant x \leqslant N )
Change the element at position x to y. Note that this change persists for all later queries.

Type 2: 2 \ell r ( 1 \leqslant \ell \leqslant r \leqslant N )
Output S(\ell, r).

Prerequisites:

  • knowledge of segment tree and its applications.

QUICK EXPLANATION:

The strength of a subarray A(l, r) reduces to:
( r - l + 1 ) * ( max - min ) + sum
where max, min and sum are the maximum, the minimum and the sum of A(l, r). All three quantities can be maintained with segment trees.

EXPLANATION:

First of all, let’s analyze the expression for strength. Let the array elements be {A[1], A[2], ....A[n]}. For any subarray {A[i], A[i+1], ...., A[j]}, the maxshift can be written as :

D(i, j) = (max - A[i]) + (max - A[i+1]) + (max - A[i+2]) + \ldots + (max - A[j])

D(i, j) = max*(j-i+1) - \sum_{k = i}^{k = j} A[ k ]

Similarly, the expression of minshift can be written as :

d(i, j) = (A[i] - min) + (A[i+1] - min) + (A[i+2] - min) + \ldots + (A[j] - min)

d(i, j) = \sum_{k = i}^{k = j} A[ k ] -min*(j-i+1)

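Therefore, adding D(i, j), d(i, j) and \sum_{k = i}^{k = j} A[ k ], the sum subtracted in D(i, j) cancels the sum added in d(i, j), leaving: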

S(i, j) = max*(j-i+1) -min*(j-i+1) + \sum_{k = i}^{k = j} A[ k ] .

S(i, j) = (max-min)*(j-i+1) + \sum_{k = i}^{k = j} A[ k ] .

So, for each type 2 query (which represents a subarray), we need three quantities: the minimum, the maximum and the sum of that subarray. Computing them by iterating over the subarray costs O(N) per query, i.e. O(Q \cdot N) overall, which will give TLE.
Instead, we maintain one segment tree each for range minimum, range maximum and range sum. Building them takes O(N), and each point update (type 1) and range query (type 2) takes O(\log N), so the overall complexity is O(N + Q \log N), which satisfies the given constraints.

SOLUTIONS:

Setter's Solution
import java.util.*;
import java.lang.*;
import java.io.*;

public class Main
{
    static PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
    static FastReader sc = new FastReader();

    static long mod = (int)1e9+7;
    static long mod2 = 998244353;

    /** Entry point: runs a single test case (reading a test-case count is disabled). */
    public static void main (String[] args) throws java.lang.Exception
    {
        int t = 1; // replace with sc.nextInt() for multiple test cases
        while (t-- > 0) {
            solve();
        }
    }

    /**
     * Reads N, Q, the N array values and Q queries "a b c" from standard input.
     * Type 1 (a == 1) assigns A[b] = c; type 2 prints S(b, c) = (c - b + 1) * (max - min) + sum.
     * Positions are 1-based. The range-sum tree stores longs because a sum of several int
     * values (e.g. two values of 2e9) overflows int.
     */
    public static void solve() {
        int n = i();
        int q = i();
        int[] arr = new int[n];
        for (int k = 0; k < n; ++k) {
            arr[k] = i();
        }

        int[] treeMin = constructSTMin(arr, n);
        int[] treeMax = constructSTMax(arr, n);
        long[] treeSum = constructSTSum(arr, n);

        while (q-- > 0) {
            int a = i(), b = i(), c = i();
            if (a == 1) {
                int idx = b - 1;
                // The sum tree is patched by (new - old), so it must run while arr[idx] still
                // holds the old value; the min/max updates then write c into arr[idx] again.
                updateValueSum(arr, treeSum, idx, c);
                updateValueMax(arr, treeMax, 0, n - 1, idx, c, 0);
                updateValueMin(arr, treeMin, 0, n - 1, idx, c, 0);
                continue;
            }

            int qs = b - 1, qe = c - 1;
            long len = qe - qs + 1;
            long max = getMax(treeMax, n, qs, qe);
            long min = getMin(treeMin, n, qs, qe);
            long sum = getSum(treeSum, n, qs, qe);
            out.println(len * (max - min) + sum);
        }

        out.flush();
    }

    /** Returns the midpoint of [s, e] without computing s + e (which could overflow). */
    static int getMid(int s, int e)
    {
        return s + (e - s) / 2;
    }

    /** Range maximum over [l, r] (0-based, inclusive); returns -1 for an invalid range. */
    static int getMax(int[] st, int n, int l, int r)
    {
        if (l < 0 || r > n - 1 || l > r) {
            return -1;
        }
        return MaxUtil(st, 0, n - 1, l, r, 0);
    }

    /** Range minimum over [l, r] (0-based, inclusive); returns -1 for an invalid range. */
    static int getMin(int[] st, int n, int l, int r)
    {
        if (l < 0 || r > n - 1 || l > r) {
            return -1;
        }
        return MinUtil(st, 0, n - 1, l, r, 0);
    }

    /** Range sum over [l, r] (0-based, inclusive) as a long; returns -1 for an invalid range. */
    static long getSum(long[] st, int n, int l, int r)
    {
        if (l < 0 || r > n - 1 || l > r) {
            return -1;
        }
        return SumUtil(st, 0, n - 1, l, r, 0);
    }

    /**
     * Maximum over [l, r] within tree node {@code node}, which covers [ss, se].
     * Nodes disjoint from the query contribute Integer.MIN_VALUE (the identity for max).
     */
    static int MaxUtil(int[] st, int ss, int se, int l, int r, int node)
    {
        if (l <= ss && r >= se)
            return st[node];

        if (se < l || ss > r)
            return Integer.MIN_VALUE;

        int mid = getMid(ss, se);

        return Math.max(MaxUtil(st, ss, mid, l, r, 2 * node + 1),
                MaxUtil(st, mid + 1, se, l, r, 2 * node + 2));
    }

    /**
     * Minimum over [l, r] within tree node {@code node}, which covers [ss, se].
     * Nodes disjoint from the query contribute Integer.MAX_VALUE (the identity for min).
     */
    static int MinUtil(int[] st, int ss, int se, int l, int r, int node)
    {
        if (l <= ss && r >= se)
            return st[node];

        if (se < l || ss > r)
            return Integer.MAX_VALUE;

        int mid = getMid(ss, se);

        return Math.min(MinUtil(st, ss, mid, l, r, 2 * node + 1),
                MinUtil(st, mid + 1, se, l, r, 2 * node + 2));
    }

    /** Sum over [l, r] within tree node {@code node}, which covers [ss, se]; disjoint nodes add 0. */
    static long SumUtil(long[] st, int ss, int se, int l, int r, int node)
    {
        if (l <= ss && r >= se)
            return st[node];

        if (se < l || ss > r)
            return 0;

        int mid = getMid(ss, se);

        return SumUtil(st, ss, mid, l, r, 2 * node + 1)
                + SumUtil(st, mid + 1, se, l, r, 2 * node + 2);
    }

    /** Sets arr[index] = value and recomputes the max tree along the root-to-leaf path. */
    static void updateValueMax(int[] arr, int[] st, int ss, int se, int index, int value, int node)
    {
        if (index < ss || index > se) {
            return;
        }

        if (ss == se) {
            arr[index] = value;
            st[node] = value;
        }
        else {
            int mid = getMid(ss, se);
            if (index <= mid)
                updateValueMax(arr, st, ss, mid, index, value, 2 * node + 1);
            else
                updateValueMax(arr, st, mid + 1, se, index, value, 2 * node + 2);

            st[node] = Math.max(st[2 * node + 1], st[2 * node + 2]);
        }
    }

    /** Sets arr[index] = value and recomputes the min tree along the root-to-leaf path. */
    static void updateValueMin(int[] arr, int[] st, int ss, int se, int index, int value, int node)
    {
        if (index < ss || index > se) {
            return;
        }

        if (ss == se) {
            arr[index] = value;
            st[node] = value;
        }
        else {
            int mid = getMid(ss, se);
            if (index <= mid)
                updateValueMin(arr, st, ss, mid, index, value, 2 * node + 1);
            else
                updateValueMin(arr, st, mid + 1, se, index, value, 2 * node + 2);

            st[node] = Math.min(st[2 * node + 1], st[2 * node + 2]);
        }
    }

    /**
     * Sets arr[i] = new_val and adds (new_val - old value) to every sum-tree node covering i.
     * The difference is computed in long so it cannot overflow int. Out-of-range i is ignored.
     */
    static void updateValueSum(int[] arr, long[] st, int i, int new_val)
    {
        int n = arr.length;
        if (i < 0 || i > n - 1) {
            return;
        }

        long diff = (long) new_val - arr[i];

        arr[i] = new_val;

        updateValueUtilSum(0, n - 1, i, diff, 0, st);
    }

    /** Adds diff to node si (covering [ss, se]) and its descendants that contain position i. */
    static void updateValueUtilSum(int ss, int se, int i, long diff, int si, long[] st)
    {
        if (i < ss || i > se)
            return;

        st[si] = st[si] + diff;
        if (se != ss) {
            int mid = getMid(ss, se);
            updateValueUtilSum(ss, mid, i, diff, 2 * si + 1, st);
            updateValueUtilSum(mid + 1, se, i, diff, 2 * si + 2, st);
        }
    }

    /** Fills the max tree for arr[ss..se] rooted at st[si]; returns the node's value. */
    static int constructSTMaxUtil(int[] arr, int ss, int se, int[] st, int si)
    {
        if (ss == se) {
            return st[si] = arr[ss];
        }

        int mid = getMid(ss, se);

        return st[si] = Math.max(constructSTMaxUtil(arr, ss, mid, st, si * 2 + 1),
                constructSTMaxUtil(arr, mid + 1, se, st, si * 2 + 2));
    }

    /** Fills the min tree for arr[ss..se] rooted at st[si]; returns the node's value. */
    static int constructSTMinUtil(int[] arr, int ss, int se, int[] st, int si)
    {
        if (ss == se) {
            return st[si] = arr[ss];
        }

        int mid = getMid(ss, se);

        return st[si] = Math.min(constructSTMinUtil(arr, ss, mid, st, si * 2 + 1),
                constructSTMinUtil(arr, mid + 1, se, st, si * 2 + 2));
    }

    /** Fills the (long-valued) sum tree for arr[ss..se] rooted at st[si]; returns the node's value. */
    static long constructSTSumUtil(int[] arr, int ss, int se, long[] st, int si)
    {
        if (ss == se) {
            return st[si] = arr[ss];
        }

        int mid = getMid(ss, se);

        return st[si] = constructSTSumUtil(arr, ss, mid, st, si * 2 + 1)
                + constructSTSumUtil(arr, mid + 1, se, st, si * 2 + 2);
    }

    /**
     * Returns the array length for a recursive segment tree over n leaves:
     * 2 * (smallest power of two >= n) - 1, computed with integers rather than Math.log.
     */
    static int treeSize(int n)
    {
        int leaves = 1;
        while (leaves < n) {
            leaves <<= 1;
        }
        return 2 * leaves - 1;
    }

    /** Builds a range-maximum segment tree over the first n elements of arr. */
    static int[] constructSTMax(int[] arr, int n)
    {
        int[] st = new int[treeSize(n)];
        if (n > 0) {
            constructSTMaxUtil(arr, 0, n - 1, st, 0);
        }
        return st;
    }

    /** Builds a range-minimum segment tree over the first n elements of arr. */
    static int[] constructSTMin(int[] arr, int n)
    {
        int[] st = new int[treeSize(n)];
        if (n > 0) {
            constructSTMinUtil(arr, 0, n - 1, st, 0);
        }
        return st;
    }

    /** Builds a range-sum segment tree (long nodes) over the first n elements of arr. */
    static long[] constructSTSum(int[] arr, int n)
    {
        long[] st = new long[treeSize(n)];
        if (n > 0) {
            constructSTSumUtil(arr, 0, n - 1, st, 0);
        }
        return st;
    }

    /** Reads the next int token from standard input. */
    static int i() {
        return sc.nextInt();
    }

    /** Reads the next whitespace-delimited token from standard input. */
    static String s() {
        return sc.next();
    }

    /** Reads the next long token from standard input. */
    static long l() {
        return sc.nextLong();
    }

    /** Whitespace-token reader over a BufferedReader on System.in. */
    static class FastReader {
        BufferedReader br;
        StringTokenizer st;

        public FastReader()
        {
            br = new BufferedReader(
                    new InputStreamReader(System.in));
        }

        /**
         * Returns the next token.
         *
         * @throws UncheckedIOException if reading fails
         * @throws NoSuchElementException if input ends before another token is found
         */
        String next()
        {
            while (st == null || !st.hasMoreElements()) {
                String line;
                try {
                    line = br.readLine();
                }
                catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        int nextInt() { return Integer.parseInt(next()); }

        long nextLong() { return Long.parseLong(next()); }

        double nextDouble()
        {
            return Double.parseDouble(next());
        }

        /** Returns the rest of the current input line, or null at end of input. */
        String nextLine()
        {
            try {
                return br.readLine();
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
    }
}
Tester's Solution
#include <bits/stdc++.h>
// #include<ext/pb_ds/assoc_container.hpp>
// #include<ext/pb_ds/tree_policy.hpp>
using namespace std;
// using namespace __gnu_pbds;
// ll: shorthand for a 64-bit signed integer.
#define ll long long
// Every later occurrence of the token `int` in this file is replaced by `long long`,
// which is why the entry point is declared as `int32_t main()`.
#define int long long
// General-purpose constants; M, SZ and inf are not referenced elsewhere in this listing.
const long long M = 1e9 + 7;
const long long SZ = 100000;
const long long inf = 1e18;
// Container shorthands.
#define all(a) (a).begin(), (a).end()
#define pb push_back
// Because of the `int` macro above, vi expands to vector<long long>.
#define vi vector<int>
// Disables C-stdio synchronisation and stream tying for faster I/O.
#define fastio() ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL)
// debug(x) prints the expression text and its value to stderr, unless ONLINE_JUDGE is
// defined, in which case it expands to nothing.
#ifndef ONLINE_JUDGE
#define debug(x) cerr << #x <<" "; cerr<<x; cerr << endl;
#else
#define debug(x)
#endif
// de(m): print the expression text, its value and a newline to stdout.
#define de(m) cout << #m << " " << m << endl;
// x(m): print a value followed by a space (note: any `x(` token sequence is expanded).
#define x(m) cout << (m) << " ";
#define vll vector<ll>
// nl: print a newline.
#define nl cout << "\n";
// fr(i, a, b): loop i over the half-open range [a, b).
#define fr(i, a, b) for (ll i = a; i < b; i++)
//*-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------*//
// Greatest common divisor by the Euclidean algorithm; gcd(0, b) == b.
ll gcd(ll a, ll b)
{
    while (a != 0) {
        const ll rem = b % a;
        b = a;
        a = rem;
    }
    return b;
}
// Trial-division primality test; returns false for n <= 1.
// The loop bound i <= n / i is the exact integer form of i * i <= n. It avoids the
// floating-point sqrt(n), which can round wrongly for large n and was re-evaluated every
// iteration, and it cannot overflow.
bool isprime(ll n)
{
    if (n <= 1) return false;
    for (ll i = 2; i <= n / i; i++)
        if (n % i == 0) return false;
    return true;
}
// Modular exponentiation by square-and-multiply: returns a^b mod m in [0, m) for m >= 1.
// - The base is reduced modulo m (and made non-negative) first, so a large or negative `a`
//   neither overflows a * a nor yields a negative result. Intermediate products still
//   require (m - 1)^2 to fit in long long.
// - The accumulator starts at 1 % m, so b == 0 with m == 1 correctly gives 0.
// - b < 0 is treated like b == 0 instead of looping forever on the sign-preserving shift.
ll binex(ll a, ll b, ll m)
{
    ll ans = 1 % m;
    a %= m;
    if (a < 0) a += m;
    while (b > 0) {
        if (b & 1) ans = (ans * a) % m;
        a = (a * a) % m;
        b >>= 1;
    }
    return ans;
}
// Euler's totient: the number of integers in [1, n] coprime to n, obtained by applying
// totient -= totient / p for each distinct prime factor p of n.
ll phi(ll n)
{
    ll totient = n;
    for (ll p = 2; p * p <= n; ++p) {
        if (n % p != 0) continue;
        do {
            n /= p;
        } while (n % p == 0);
        totient -= totient / p;
    }
    if (n > 1) totient -= totient / n; // one prime factor larger than sqrt(original n) remains
    return totient;
}
// Appends every positive divisor of n to v (unsorted: each i <= sqrt(n) is followed by n / i).
// The bound i <= n / i replaces the floating-point sqrt(n). sqrt can round below the true
// root of a large perfect square and drop that divisor, and it was recomputed every iteration.
void all_divisor(ll n, vll& v)
{
    for (ll i = 1; i <= n / i; i++) {
        if (n % i != 0) continue;
        v.push_back(i);
        if (n / i != i) v.push_back(n / i);
    }
}
// Factorises n by trial division. For each prime p it adds p's multiplicity to m[p] and the
// total number of prime factors (with multiplicity) to cnt. Existing contents of m and cnt
// are accumulated into, not reset.
void prime_divisor_map(ll n, map<ll, ll>& m, ll& cnt)
{
    for (ll p = 2; p * p <= n; ++p) {
        while (n % p == 0) {
            n /= p;
            ++m[p];
            ++cnt;
        }
    }
    if (n > 1) {
        ++m[n];
        ++cnt;
    }
}
// Appends the distinct prime factors of n to v in increasing order.
void prime_divisor(ll n, vll& v)
{
    for (ll p = 2; p * p <= n; ++p) {
        if (n % p != 0) continue;
        v.push_back(p);
        do {
            n /= p;
        } while (n % p == 0);
    }
    if (n > 1) v.push_back(n);
}
// Sieve of Eratosthenes: appends every prime p with 2 <= p <= n to v, in increasing order.
// The sieve lives in a heap-allocated vector. A variable-length stack array is not standard
// C++ and can overflow the stack for large n; it was also written out of bounds (index 1)
// for n == 0 and had a negative size for n < 0. Those inputs now simply append nothing.
void prime_generator(ll n, vll& v)
{
    if (n < 2) return;
    vector<char> is_prime(n + 1, 1);
    is_prime[0] = is_prime[1] = 0;
    for (ll i = 2; i * i <= n; i++) {
        if (!is_prime[i]) continue;
        for (ll j = i * i; j <= n; j += i) is_prime[j] = 0;
    }
    for (ll i = 2; i <= n; i++)
        if (is_prime[i]) v.push_back(i);
}
// Extended Euclidean algorithm. For non-negative a and b, returns g = gcd(a, b) and sets
// x and y so that a * x + b * y == g.
ll ex_euclidean(ll a, ll b, ll& x, ll& y)
{
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    ll xs, ys;
    const ll g = ex_euclidean(b, a % b, xs, ys);
    // Back-substitute: b * xs + (a % b) * ys == g  =>  a * ys + b * (xs - (a / b) * ys) == g.
    x = ys;
    y = xs - ys * (a / b);
    return g;
}
// Mersenne Twister seeded from the steady clock, so the sequence differs between runs.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Returns a uniformly distributed integer in the closed range [l, r].
ll getRandomNumber(ll l, ll r)
{
    uniform_int_distribution<ll> dist(l, r);
    return dist(rng);
}
//*--------------------------------------------------------------------------JAY SHREE RAM--------------------------------------------------------------------------------*//
// Midpoint of the index range [s, e], written so that s + e is never formed (no overflow).
int getMid(int s, int e)
{
    const int half = (e - s) / 2;
    return s + half;
}
// Sum of the query range [l, r] restricted to tree node `node`, which covers [ss, se].
int sumutil(vll &st, int ss, int se, int l, int r, int node)
{
    if (se < l || ss > r)          // node lies entirely outside the query
        return 0;
    if (l <= ss && se <= r)        // node lies entirely inside the query
        return st[node];

    const int mid = getMid(ss, se);
    const int from_left = sumutil(st, ss, mid, l, r, 2 * node + 1);
    const int from_right = sumutil(st, mid + 1, se, l, r, 2 * node + 2);
    return from_left + from_right;
}

// Range sum over positions [l, r] (0-based, inclusive); returns -1 if [l, r] is not a valid range.
int getsum(vll &st, int n, int l, int r)
{
    const bool in_range = 0 <= l && l <= r && r < n;
    return in_range ? sumutil(st, 0, n - 1, l, r, 0) : -1;
}
// Maximum of the query range [l, r] restricted to tree node `node`, which covers [ss, se].
// Nodes disjoint from the query return LLONG_MIN, the identity for max over long long.
// This file #defines int as long long and reads values into vll, so stored values are 64-bit;
// a 32-bit sentinel such as INT_MIN would win over any value below it.
int maxutil(vll &st, int ss, int se, int l, int r, int node)
{
    if (l <= ss && r >= se)
        return st[node];

    if (se < l || ss > r)
        return LLONG_MIN;

    int mid = getMid(ss, se);

    return max(maxutil(st, ss, mid, l, r, 2 * node + 1),
               maxutil(st, mid + 1, se, l, r, 2 * node + 2));
}
// Minimum of the query range [l, r] restricted to tree node `node`, which covers [ss, se].
// Nodes disjoint from the query return LLONG_MAX, the identity for min over long long.
// Values are read into vll (64-bit), so they may exceed INT_MAX. With an INT_MAX sentinel,
// a partially covered query over such values would wrongly report INT_MAX as the minimum.
int minutil(vll &st, int ss, int se, int l, int r, int node)
{
    if (l <= ss && r >= se)
        return st[node];

    if (se < l || ss > r)
        return LLONG_MAX;

    int mid = getMid(ss, se);

    return min(minutil(st, ss, mid, l, r, 2 * node + 1),
               minutil(st, mid + 1, se, l, r, 2 * node + 2));
}

// Adds diff to node si (covering [ss, se]) and to every descendant whose range contains i.
void updatevalueutilsum(ll ss, ll se, ll i, ll diff, ll si, vll &st)
{
    const bool contains = ss <= i && i <= se;
    if (!contains)
        return;

    st[si] += diff;
    if (ss == se)
        return; // reached the leaf for position i

    const int mid = getMid(ss, se);
    updatevalueutilsum(ss, mid, i, diff, 2 * si + 1, st);
    updatevalueutilsum(mid + 1, se, i, diff, 2 * si + 2, st);
}
// Point assignment for the sum tree: sets arr[i] = new_val and propagates the difference
// to every node covering i. Out-of-range i is ignored.
void updatevaluesum(vll &arr, vll &st, ll i, ll new_val)
{
    const ll n = arr.size();
    if (i >= 0 && i < n) {
        const ll delta = new_val - arr[i];
        arr[i] = new_val;
        updatevalueutilsum(0, n - 1, i, delta, 0, st);
    }
}
// Point assignment for the max tree: sets arr[index] = value and recomputes the maxima
// on the path from node `node` (covering [ss, se]) down to the leaf for index.
void updatevaluemax(vll &arr, vll &st, int ss, int se, int index, int value, int node)
{
    if (index < ss || index > se)
        return;

    if (ss == se) {
        arr[index] = value;
        st[node] = value;
        return;
    }

    const int mid = getMid(ss, se);
    const int lc = 2 * node + 1, rc = 2 * node + 2;
    if (index <= mid)
        updatevaluemax(arr, st, ss, mid, index, value, lc);
    else
        updatevaluemax(arr, st, mid + 1, se, index, value, rc);
    st[node] = max(st[lc], st[rc]);
}
// Point assignment for the min tree: sets arr[index] = value and recomputes the minima
// on the path from node `node` (covering [ss, se]) down to the leaf for index.
void updatevaluemin(vll &arr, vll &st, int ss, int se, int index, int value, int node)
{
    if (index < ss || index > se)
        return;

    if (ss == se) {
        arr[index] = value;
        st[node] = value;
        return;
    }

    const int mid = getMid(ss, se);
    const int lc = 2 * node + 1, rc = 2 * node + 2;
    if (index <= mid)
        updatevaluemin(arr, st, ss, mid, index, value, lc);
    else
        updatevaluemin(arr, st, mid + 1, se, index, value, rc);
    st[node] = min(st[lc], st[rc]);
}

// Recursively fills the max tree for arr[ss..se] rooted at st[si]; returns that node's value.
int constructstmaxutil(vll &arr, ll ss, int se, vll &st, int si)
{
    if (ss != se) {
        const ll mid = getMid(ss, se);
        const int lo = constructstmaxutil(arr, ss, mid, st, 2 * si + 1);
        const int hi = constructstmaxutil(arr, mid + 1, se, st, 2 * si + 2);
        st[si] = max(lo, hi);
    } else {
        st[si] = arr[ss];
    }
    return st[si];
}
// Recursively fills the min tree for arr[ss..se] rooted at st[si]; returns that node's value.
int constructstminutil(vll &arr, int ss, int se, vll &st, int si)
{
    if (ss != se) {
        const int mid = getMid(ss, se);
        const int lo = constructstminutil(arr, ss, mid, st, 2 * si + 1);
        const int hi = constructstminutil(arr, mid + 1, se, st, 2 * si + 2);
        st[si] = min(lo, hi);
    } else {
        st[si] = arr[ss];
    }
    return st[si];
}
// Recursively fills the sum tree for arr[ss..se] rooted at st[si]; returns that node's value.
int constructstsumutil(vll &arr, int ss, int se, vll &st, int si)
{
    if (ss != se) {
        const int mid = getMid(ss, se);
        const int lsum = constructstsumutil(arr, ss, mid, st, 2 * si + 1);
        const int rsum = constructstsumutil(arr, mid + 1, se, st, 2 * si + 2);
        st[si] = lsum + rsum;
    } else {
        st[si] = arr[ss];
    }
    return st[si];
}
// Allocates and fills a range-maximum segment tree over the first n elements of arr.
// Node count is 2 * 2^h - 1 with h = ceil(log(n) / log(2)), computed in floating point.
vll constructstmax(vll &arr, int n)
{
    const int height = (int)ceil(log(n) / log(2));
    const int nodes = 2 * (int)pow(2, height) - 1;
    vll st(nodes);
    constructstmaxutil(arr, 0, n - 1, st, 0);
    return st;
}
// Allocates and fills a range-minimum segment tree over the first n elements of arr.
// Node count is 2 * 2^h - 1 with h = ceil(log(n) / log(2)), computed in floating point.
vll constructstmin(vll &arr, int n)
{
    const int height = (int)ceil(log(n) / log(2));
    const int nodes = 2 * (int)pow(2, height) - 1;
    vll st(nodes);
    constructstminutil(arr, 0, n - 1, st, 0);
    return st;
}
// Allocates and fills a range-sum segment tree over the first n elements of arr.
// Node count is 2 * 2^h - 1 with h = ceil(log(n) / log(2)), computed in floating point.
vll constructstsum(vll &arr, int n)
{
    const int height = (int)ceil(log(n) / log(2));
    const int nodes = 2 * (int)pow(2, height) - 1;
    vll st(nodes);
    constructstsumutil(arr, 0, n - 1, st, 0);
    return st;
}
// Range maximum over positions [l, r] (0-based, inclusive); returns -1 if [l, r] is not a valid range.
int getmax(vll &st, int n, int l, int r)
{
    const bool in_range = 0 <= l && l <= r && r < n;
    if (!in_range)
        return -1;
    return maxutil(st, 0, n - 1, l, r, 0);
}
// Range minimum over positions [l, r] (0-based, inclusive); returns -1 if [l, r] is not a valid range.
int getmin(vll &st, int n, int l, int r)
{
    const bool in_range = 0 <= l && l <= r && r < n;
    if (!in_range)
        return -1;
    return minutil(st, 0, n - 1, l, r, 0);
}



void Champion_Patel()
{

   ll n, q;
   cin >> n >> q;

   vll arr(n);
   fr(i, 0, n)cin >> arr[i];

   vll treemin = constructstmin(arr, n);


   vll treemax = constructstmax(arr, n);

   vll treesum = constructstsum(arr, n);


   while (q--) {
       ll a, b, c;
       cin >> a >> b >> c;
       if (a == 1) {
           ll oldval = arr[b - 1];
           ll newval = c;
           updatevaluemax(arr, treemax, 0, n - 1, b - 1, newval, 0);
           arr[b - 1] = oldval;
           updatevaluemin(arr, treemin, 0, n - 1, b - 1, newval, 0);
           arr[b - 1] = oldval;
           updatevaluesum(arr, treesum, b - 1, newval);
           continue;
       }

       ll qs = b - 1, qe = c - 1;
       ll ans = qe - qs + 1;

       ll max = getmax(treemax, n, qs, qe);
       ll min = getmin(treemin, n, qs, qe);
       ll sum = getsum(treesum, n, qs, qe);

       ans *= (max - min);
       ans += sum;
       cout << ans << endl;



   }




}

// Entry point (int32_t because `int` is #defined to long long). Runs a single test case;
// reading a test-case count from input is intentionally disabled.
int32_t main() {
    fastio();

    int T = 1;
    for (; T > 0; --T)
        Champion_Patel();
}
1 Like