ALR20F - EDITORIAL

PROBLEM LINK:

Practice

Author: Ritesh Gupta
Tester: Shubham Gupta
Editorialist: Romit Kumar

DIFFICULTY:

Hard

PREREQUISITES:

Two pointers, segment tree

PROBLEM:

Given an array, we must find the minimum-length subarray such that, by changing at most $X$ elements inside that subarray, the total sum of the whole array can be made equal to $K$.

EXPLANATION:

Given a subarray $A_l, A_{l+1}, \ldots, A_r$, we first need a way to check whether a total array sum of $K$ can be achieved by changing the intensities of at most $X$ elements of that subarray only.

Let the total sum of the array be $S$, and let $len = r - l + 1$ be the length of the subarray. Let $smls$ be the sum of the $X$ smallest elements of the subarray and $bigs$ the sum of its $X$ biggest elements. If the subarray has fewer than $X$ elements, all of its elements are taken for both sums.

A changed element can be set to any value from $1$ to $N$. The minimum array sum achievable by changing at most $X$ elements of our subarray therefore comes from setting its biggest elements to $1$: $from = S - bigs + \min(X, len)$. Similarly, the maximum achievable array sum comes from setting its smallest elements to $N$: $to = S - smls + \min(X, len) \cdot N$.

Every integer between these two extremes is also reachable, because each changed value can be adjusted one unit at a time. So if $from \leq K \leq to$, we can achieve a total array sum of $K$ by changing at most $X$ elements of our subarray.

Now we need to minimise the subarray length. This can be done with the two-pointer technique. It works because enlarging a subarray can only lower $from$ and raise $to$, so any subarray that contains a valid subarray is also valid.

We keep two pointers, $i$ and $j$, both starting at $0$; $A_i, A_{i+1}, \ldots, A_j$ is the subarray currently under consideration. We shift the $j$ pointer right until the subarray satisfies our condition. Then we shift the $i$ pointer right for as long as the condition remains satisfied. Throughout, we keep track of the minimum length of a valid subarray seen so far.

The only thing left is to efficiently compute the sums of the smallest and largest $X$ elements in our subarray. This can be done using a segment tree.

The segment tree is initially empty. When an element becomes part of the current subarray (because the $j$ pointer shifts right), we insert it into the segment tree. When an element leaves the current subarray (because the $i$ pointer shifts right), we remove it from the segment tree.

Note that each element is inserted at the position it occupies in the sorted order of $A$. For example, suppose an element $e$ is at index $i_1$ in $A$ but occupies index $i_2$ after sorting $A$. Then, when $e$ becomes part of the current subarray, we make an update at index $i_2$ of the segment tree.

This way, instead of finding the sum of the smallest $X$ elements in the segment tree, we can equivalently find the sum of the leftmost $X$ elements. Similarly, the sum of the largest $X$ elements equals the sum of the rightmost $X$ elements. This works because the elements in the segment tree are always stored in sorted order.

The sum of the leftmost or rightmost $X$ elements can be computed easily by maintaining a segment tree in which each node stores 2 values: the number of elements and the sum of the elements in the range the node represents.

As an example, suppose we want the sum of the leftmost $X$ elements. If the root holds at most $X$ elements, the answer is simply its stored sum. Otherwise, let the left child of the root have count $c_1$ and sum $s_1$, and the right child have count $c_2$ and sum $s_2$. If $c_1 \geq X$, we recursively query for the sum of the leftmost $X$ elements at root->left. If $c_1 < X$, the answer is $s_1$ plus the result of querying for the sum of the leftmost $X - c_1$ elements at root->right.

SOLUTIONS:

Setter's Solution

import java.io.*;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.InputMismatchException;

/**
 * Setter's solution for ALR20F.
 *
 * <p>For each test case, prints the length of the shortest subarray such that changing at most
 * {@code x} of its elements can make the array total equal {@code k}. Prints 0 if the total already
 * equals {@code k}, and -1 if no subarray works. The reachability bounds assume a changed element
 * can be set to any value from 1 to {@code n}.
 */
class sol {

    /**
     * Segment tree indexed by sorted position of the array values. {@code tree[node][0]} counts the
     * current-window elements in the node's range and {@code tree[node][1]} holds their sum. Node 0
     * is the root; the children of {@code node} are {@code 2*node+1} and {@code 2*node+2}.
     */
    long[][] tree;

    /** Values of the current test case. */
    long[] a;

    /** Reads every test case from {@link #is} and writes one answer per line to standard output. */
    void solve() throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int t = ni();
        while (t-- > 0) {
            int n = ni(), x = ni();
            long k = nl();
            a = nb(n);
            long total = 0;
            for (long v : a) {
                total += v;
            }
            // Nothing has to change at all.
            if (total == k) {
                out.println(0);
                continue;
            }
            out.println(shortestWindow(n, x, k, total));
        }
        out.flush();
    }

    /**
     * Two-pointer search for the shortest window of {@link #a} from which total {@code k} is
     * reachable. For each right end, the left end is advanced while the window stays feasible,
     * recording every feasible length seen. Rebuilds {@link #tree} for this test case.
     *
     * @return the minimum window length, or -1 if no window works
     */
    private int shortestWindow(int n, int x, long k, long total) {
        int[] rank = sortedRanks(n);
        tree = new long[4 * n][2];
        int best = Integer.MAX_VALUE;
        int left = 0;
        // A length-1 answer cannot be improved, so stop as soon as it is found.
        for (int right = 0; right < n && best > 1; right++) {
            update(0, 0, n - 1, rank[right], a[right], 1);
            while (left <= right && reachable(total, k, x, n, right - left + 1)) {
                best = Math.min(best, right - left + 1);
                update(0, 0, n - 1, rank[left], -a[left], -1);
                left++;
            }
        }
        return best == Integer.MAX_VALUE ? -1 : best;
    }

    /**
     * Checks whether the array total can be made equal to {@code k} by changing at most {@code x}
     * elements of the window currently stored in {@link #tree} (which has {@code len} elements).
     */
    private boolean reachable(long total, long k, int x, int n, int len) {
        int changed = Math.min(x, len);
        // Lowest total: the largest `changed` window elements are replaced by 1.
        long low = total - query2(0, x) + changed;
        // Highest total: the smallest `changed` window elements are replaced by n.
        long high = total - query1(0, x) + (long) changed * n;
        return low <= k && k <= high;
    }

    /**
     * Returns {@code rank} where {@code rank[i]} is the position of {@code a[i]} in ascending order.
     * Ties keep their original index order because {@code Arrays.sort(Object[])} is stable.
     */
    private int[] sortedRanks(int n) {
        data[] order = new data[n];
        for (int i = 0; i < n; i++) {
            order[i] = new data(a[i], i);
        }
        Arrays.sort(order);
        int[] rank = new int[n];
        for (int p = 0; p < n; p++) {
            rank[order[p].id] = p;
        }
        return rank;
    }

    /**
     * Returns the sum of the leftmost (smallest) {@code x} elements stored under {@code node}, or of
     * all of them if it holds at most {@code x}.
     */
    long query1(int node, long x) {
        if (x <= 0) return 0;
        if (tree[node][0] <= x) return tree[node][1];
        // Take as much as possible from the left child, the remainder from the right child.
        return query1(2 * node + 1, x) + query1(2 * node + 2, x - tree[2 * node + 1][0]);
    }

    /**
     * Returns the sum of the rightmost (largest) {@code x} elements stored under {@code node}, or of
     * all of them if it holds at most {@code x}.
     */
    long query2(int node, long x) {
        if (x <= 0) return 0;
        if (tree[node][0] <= x) return tree[node][1];
        // Take as much as possible from the right child, the remainder from the left child.
        return query2(2 * node + 1, x - tree[2 * node + 2][0]) + query2(2 * node + 2, x);
    }

    /**
     * Adds {@code yy} to the count and {@code y} to the sum of leaf position {@code x}, then
     * recomputes the ancestors. {@code node} covers positions {@code s..e}.
     */
    void update(int node, int s, int e, int x, long y, int yy) {
        if (s == e) {
            tree[node][0] += yy;
            tree[node][1] += y;
            return;
        }
        int mid = (s + e) / 2;
        if (s <= x && x <= mid) {
            update(2 * node + 1, s, mid, x, y, yy);
        } else {
            update(2 * node + 2, mid + 1, e, x, y, yy);
        }
        tree[node][0] = tree[2 * node + 1][0] + tree[2 * node + 2][0];
        tree[node][1] = tree[2 * node + 1][1] + tree[2 * node + 2][1];
    }

    /** An array value paired with its original index, ordered by value. */
    static class data implements Comparable<data> {
        long e;
        int id;

        public data(long e, int id) {
            this.e = e;
            this.id = id;
        }

        @Override
        public int compareTo(data o) {
            // Long.compare avoids the overflow of casting (this.e - o.e) to int.
            return Long.compare(this.e, o.e);
        }
    }

    public static void main(String[] args) throws IOException {
        new sol().solve();
    }

    private byte[] inbuf = new byte[1024];
    public int lenbuf = 0, ptrbuf = 0;
    /** Input source for all readers below. */
    static InputStream is = System.in;

    /** Returns the next input byte, refilling the buffer as needed; -1 at end of input. */
    private int readByte() {
        if (lenbuf == -1) {
            throw new InputMismatchException();
        }
        if (ptrbuf >= lenbuf) {
            ptrbuf = 0;
            try {
                lenbuf = is.read(inbuf);
            } catch (IOException e) {
                throw new InputMismatchException();
            }
            if (lenbuf <= 0) {
                return -1;
            }
        }
        return inbuf[ptrbuf++];
    }

    /** True for any byte outside printable ASCII ('!'..'~'). */
    private boolean isSpaceChar(int c) {
        return !(c >= 33 && c <= 126);
    }

    /** Skips non-printable bytes and returns the first printable one (or -1). */
    private int skip() {
        int b;
        while ((b = readByte()) != -1 && isSpaceChar(b));
        return b;
    }

    private double nd() {
        return Double.parseDouble(ns());
    }

    private char nc() {
        return (char) skip();
    }

    /** Reads the next whitespace-delimited token. */
    private String ns() {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while (!(isSpaceChar(b))) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Reads up to {@code n} characters of the next token. */
    private char[] ns(int n) {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while (p < n && !(isSpaceChar(b))) {
            buf[p++] = (char) b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m) {
        char[][] map = new char[n][];
        for (int i = 0; i < n; i++) {
            map[i] = ns(m);
        }
        return map;
    }

    private int[] na(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = ni();
        }
        return a;
    }

    private int[] na1(int n) {
        int[] a = new int[n + 1];
        for (int i = 1; i < n + 1; i++) {
            a[i] = ni();
        }
        return a;
    }

    /** Reads {@code n} longs into a 0-indexed array. */
    private long[] nb(int n) {
        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = nl();
        }
        return a;
    }

    private long[] nb1(int n) {
        long[] a = new long[n + 1];
        for (int i = 1; i < n + 1; i++) {
            a[i] = nl();
        }
        return a;
    }

    /** Reads the next (optionally negative) decimal int. */
    private int ni() {
        int num = 0, b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Reads the next (optionally negative) decimal long. */
    private long nl() {
        long num = 0;
        int b;
        boolean minus = false;
        while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        while (true) {
            if (b >= '0' && b <= '9') {
                num = num * 10 + (b - '0');
            } else {
                return minus ? -num : num;
            }
            b = readByte();
        }
    }
}

1 Like