Editorial-TBDY

PROBLEM LINK:

Taste-Buddy

Author: aditya2024
Tester: aditya2024, chef_hamster, chef_oggy
Editorialist: aditya2024

Prerequisites:

  • Two pointers (sliding window) and ordered sets (or priority queues).

PROBLEM:

Bob is a food lover. Every day he enjoys his lunch break in his office canteen where n different food items are kept in a row. After eating the i^{th} food item, Bob gets a tastiness of a_{i}. The lunch break lasts for k minutes and it takes him exactly t_{i} minutes to finish the i^{th} item.
Bob eats his lunch in the following manner:

  • He chooses a particular item and starts eating it.
  • He can either finish this item or decide to eat half of it. After that, he immediately moves to the next item. (Note that once Bob starts, he cannot skip any food item. Formally, if he begins with food item i, he then moves to item i+1, then i+2, and so on.)
  • If he decides to eat half of the i^{th} food item, he will spend exactly \lceil \frac{t_{i}}{2} \rceil minutes (t_{i}/2 rounded up) eating it and will still get its tastiness. For example, if the time required to finish it is 10 minutes, he will spend 5 minutes, and if the time required is 7 minutes, he will spend 4 minutes.
  • The total number of partially eaten food items must not exceed x. Bob stops when the lunch break is over or when he finishes eating the n^{th} item. Note that if the remaining time is less than \lceil \frac{t_{j}}{2} \rceil minutes for the item j he is currently eating, Bob will not get the tastiness from it, and hence it will not count as a partially eaten food item.
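
The rounding rule only needs integer arithmetic. In the sketch below, the helper names `half_time` and `saving` are ours, not part of the problem: eating an item partially takes \lceil \frac{t}{2} \rceil = (t + 1) // 2 minutes, and therefore saves t - \lceil \frac{t}{2} \rceil = \lfloor \frac{t}{2} \rfloor minutes compared with finishing it.

```python
def half_time(t: int) -> int:
    """Minutes spent when an item that needs t minutes is eaten partially: ceil(t / 2)."""
    return (t + 1) // 2


def saving(t: int) -> int:
    """Minutes saved by eating partially instead of fully: t - ceil(t / 2) == floor(t / 2)."""
    return t - half_time(t)


# The two examples from the statement.
print(half_time(10), half_time(7))  # 5 4
print(saving(10), saving(7))        # 5 3
```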

The overall tastiness of Bob’s meal is the sum of the tastiness values he gets from the individual food items.
Being Bob’s buddy, help him find the maximum tastiness he can get while eating no more than x items partially.
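
To pin the model down, here is an exponential brute force that follows the statement directly: it tries every starting item, every last item, and every choice of at most x partially eaten items. The function name `brute_force` is ours, and it assumes non-negative tastiness values, so the answer is the best segment that fits into k minutes. It is only meant for checking faster solutions on tiny inputs.

```python
from itertools import combinations


def brute_force(a, t, k, x):
    """Max tastiness over all segments [l, r] and all sets of <= x partially eaten items."""
    n = len(a)
    best = 0
    for l in range(n):
        for r in range(l, n):
            seg = range(l, r + 1)
            full_time = sum(t[i] for i in seg)
            # Eating item i partially saves t[i] // 2 minutes.
            fits = any(
                full_time - sum(t[i] // 2 for i in halves) <= k
                for m in range(min(x, r - l + 1) + 1)
                for halves in combinations(seg, m)
            )
            if fits:
                best = max(best, sum(a[i] for i in seg))
    return best


# Made-up input: start at the second item, eat it partially (4 min), finish the third (2 min).
print(brute_force([1, 1, 1], [10, 8, 2], 6, 1))  # 2
```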


EXPLANATION:

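To solve this problem, we can use two pointers. For a fixed segment of items, Bob should eat partially the (at most) x items with the largest times: eating item i partially saves t_{i} - \lceil \frac{t_{i}}{2} \rceil = \lfloor \frac{t_{i}}{2} \rfloor minutes, and this saving does not decrease as t_{i} grows. Extending a segment never lowers its minimum total time, so as the right end moves forward, the smallest valid left end never moves back. We maintain two sets: one for the food items of the current segment that are completely consumed, and one for those that are partially consumed. The left end of the current segment is denoted ‘l’ and the right end ‘r’. The set of partially eaten items is referred to as ‘half’, and the set of fully eaten items as ‘full’.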
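
Because the saving \lfloor \frac{t}{2} \rfloor never shrinks as t grows, the cheapest way to eat a fixed segment is to eat its (at most) x longest items partially. The sketch below (helper name `min_time` is ours) computes that minimum directly by sorting; it serves only as a reference for the quantity that the ‘half’ and ‘full’ sets maintain incrementally.

```python
def min_time(times, x):
    """Least total minutes to eat every item in `times`, eating at most x of them partially."""
    longest_first = sorted(times, reverse=True)
    # Partially eating the x longest items saves the most: floor(t / 2) each.
    saved = sum(t // 2 for t in longest_first[:x])
    return sum(times) - saved


t = [10, 8, 2]
print(min_time(t[0:2], 1))  # 13: 10 eaten partially (5 min) + 8 eaten fully
print(min_time(t[1:3], 1))  # 6: 8 eaten partially (4 min) + 2 eaten fully
# Appending an item never makes a segment cheaper, so if [l, r] does not fit
# into k minutes, no longer segment starting at l fits either.
```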
Both sets store pairs consisting of the time required to finish the item and its index. To move the right end, we first check whether ‘half’ still holds fewer than x items; if so, the new item is eaten partially, added to ‘half’, and the total time grows by \lceil \frac{t_{r}}{2} \rceil. Otherwise, we have two options: eat the new item completely and add it to ‘full’, or eat it partially and move the item of ‘half’ with the smallest time to ‘full’. We choose the option that results in less total time, which means the new item goes into ‘half’ exactly when its time is larger than the smallest time in ‘half’. If the total time then exceeds the lunch break duration ‘k’, the segment cannot be extended and the left end has to move. Whenever the total time is at most ‘k’, the final answer is updated with the total tastiness of the current segment.
To move the left end: if item l was eaten completely, it is removed from ‘full’ and the total time decreases by t_{l}. If it was eaten partially, it is removed from ‘half’ and the total time decreases by \lceil \frac{t_{l}}{2} \rceil. Removing a partially eaten item frees a slot in ‘half’, so we then move the item of ‘full’ with the largest time (if there is one) to ‘half’, which lowers the total time by a further \lfloor \frac{t}{2} \rfloor, where t is that item's time.
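
Putting both pointer moves together gives the following Python sketch. It is not one of the official solutions: the name `max_tastiness` is ours, it assumes non-negative tastiness values, and it keeps ‘half’ and ‘full’ as sorted lists of (time, index) pairs, so each list insertion or deletion costs O(N) instead of the O(log N) of a balanced tree or heap.

```python
from bisect import bisect_left, insort


def max_tastiness(a, t, k, x):
    """Best total tastiness of a segment [l, r] whose minimum eating time is at most k.

    half / full hold sorted (time, index) pairs of the segment's items eaten
    partially / fully; half always holds the (at most) x longest items.
    """
    half, full = [], []
    total = score = best = 0
    l = 0
    for r in range(len(a)):
        # Move the right end: eat item r partially for now.
        insort(half, (t[r], r))
        total += (t[r] + 1) // 2
        if len(half) > x:
            # Too many partial items: finish the shortest one instead,
            # which costs t - ceil(t / 2) == t // 2 extra minutes.
            shortest = half.pop(0)
            insort(full, shortest)
            total += shortest[0] // 2
        score += a[r]
        # Move the left end while the segment does not fit into k minutes.
        while total > k:
            key = (t[l], l)
            i = bisect_left(half, key)
            if i < len(half) and half[i] == key:
                half.pop(i)
                total -= (t[l] + 1) // 2
                if full:
                    # A partial slot is free: give it to the longest full item.
                    longest = full.pop()
                    insort(half, longest)
                    total -= longest[0] // 2
            else:
                full.pop(bisect_left(full, key))
                total -= t[l]
            score -= a[l]
            l += 1
        best = max(best, score)
    return best


print(max_tastiness([1, 1, 1], [10, 8, 2], 6, 1))  # 2
```

On that made-up input, the segment starting at the second item fits because that item is eaten partially (4 + 2 = 6 minutes).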
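Each pointer move inserts, erases or moves at most two items between the sets, every set operation costs O(log N), and each pointer moves at most N times, so the time complexity of the solution is O(N log N), which satisfies the constraints.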

SOLUTIONS:

Setter's solution (JAVA)

import java.util.*;
import java.lang.*;
import java.io.*;

public class Main
{
static PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
static FastReader sc = new FastReader();

static long mod = (int)1e9+7;
static long mod2 = 998244353;
static class Pair implements Comparable<Pair>{
    int a, b;
    Pair(int a, int b){
        this.a=a;
        this.b=b;
    }
    public int compareTo(Pair o){
        return this.b-o.b;
    }
}
public static void main (String[] args) throws java.lang.Exception {
    int t = 1;//sc.nextInt();
    while (t-- > 0) {
        solve();
    }
}
public static void solve() {
    int n = i();
    long k = l();
    int limit = i();
    
    int[] a = ia(n);
    int[] s = ia(n);
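    // min: partially eaten items of the window, shortest time on top
    // (items that left the window are discarded lazily when they reach the top)
    PriorityQueue<Integer> min = new PriorityQueue<>((b, c) -> Integer.compare(s[b], s[c]));
    // max: fully eaten items of the window, longest time on top
    PriorityQueue<Integer> max = new PriorityQueue<>((b, c) -> Integer.compare(s[c], s[b]));
    // indices of the window that are currently eaten partially
    HashSet<Integer> set = new HashSet<>();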
    long res = 0;
    long score = 0, time = 0;
    int l = 0, r = -1;
    while (r<n) {
        if (k>=0) {
            res = Math.max(res,score);
            r++;
            if (r==n) break;
            score += a[r];
            while (min.size()>0 && min.peek()<l) min.poll();
            if (min.size()==0 || r-l+1<=limit) {
                min.add(r);
                k -= s[r]/2 + s[r]%2;
                set.add(r);
            } else if (s[r]>s[min.peek()]) {
                min.add(r);
                k -= s[r]/2 + s[r]%2;
                int ind = min.poll();
                max.add(ind);
                int x = s[ind];
                set.add(r);
                set.remove(ind);
                k -= x/2;
            } else {
                max.add(r);
                k -= s[r];
            }
        } else {
            score -= a[l];
            boolean ad = false;
            while (min.size()>0 && min.peek()<l) min.poll();
            while (max.size()>0 && max.peek()<=l) max.poll();
            if (set.contains(l)) {
                k += s[l]/2+s[l]%2;
                set.remove(l);
                if (max.size()>0) {
                    int ind = max.poll();
                    min.add(ind);
                    set.add(ind);
                    int x = s[ind];
                    k += x;
                    k -= x/2+x%2;
                }
            } else {
                k += s[l];
            }
            l++;
        }
    }

    out.println(res);
    out.flush();
}

static int i() {
    return sc.nextInt();
}
static String s() {
    return sc.next();
}
static long l() {
    return sc.nextLong();
}
static int[] ia(int n){
    int[] arr= new int[n];
    for(int i = 0;i<n;++i){
        arr[i] = i();
    }
    return arr;
}

static class FastReader {
    BufferedReader br;
    StringTokenizer st;

    public FastReader()
    {
        br = new BufferedReader(
                new InputStreamReader(System.in));
    }

    String next()
    {
        while (st == null || !st.hasMoreElements()) {
            try {
                st = new StringTokenizer(br.readLine());
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }
        return st.nextToken();
    }

    int nextInt() { return Integer.parseInt(next()); }

    long nextLong() { return Long.parseLong(next()); }

    double nextDouble()
    {
        return Double.parseDouble(next());
    }

    String nextLine()
    {
        String str = "";
        try {
            str = br.readLine();
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        return str;
    }
}

}

Tester's solution (C++)

                          //The more you know, the more you don't know//
#include <bits/stdc++.h>

#define loop(i,m,n) for(ll i=m;i<n;i++)
#define rev(i,m,n) for(ll i=m;i>=n;i--)
#define ll long long
#define in(n) cin>>n
#define out(n) cout<<n<<"\n"
#define o(n) cout<<n<<" "
#define nl cout<<"\n"
#define yes cout<<"YES"<<"\n"
#define no cout<<"NO"<<"\n"
#define pb push_back
#define ppb pop_back
#define size(x) x.size()
#define all(x) x.begin(),x.end()
#define rall(x) x.rbegin(),x.rend()
#define MAX(x) *max_element(all(x))
#define MIN(x) *min_element(all(x))
#define SUM(x) accumulate(all(x), 0LL)
#define SORT(x) is_sorted(all(x))
#define lcm(x,y) (x*y)/__gcd(x,y)
#define dgt(n) floor(log10(n)+1)
#define point cout<<fixed<<setprecision(10)
#define debug(x) cout<<#x<<" "<<x<<"\n";
#define print(v) for(auto x : v) o(x); nl;
#define ub(v,val) upper_bound(all(v),val)-v.begin()
#define lb(v,val) lower_bound(all(v),val)-v.begin()
const long long M = 1e9 + 7;
const long long inf = LONG_LONG_MAX;

using namespace std;

//--------------------[…Never Used Functions…]---------------------//
ll binpow(ll a, ll b){ ll res=1; while(b>0){ if(b&1) res=res*a; a=a*a; b>>=1; } return res; }
ll powmod(ll a, ll b, ll m) {ll ans = 1; while (b) {if (b & 1) {ans = (ans * a) % m;} a = (a * a) % m; b = b >> 1;} return ans % m;}
vector<ll> div(ll n){ vector<ll> v; for(int i=1; i<=sqrt(n); i++){ if (n%i == 0){ v.pb(i); if(n/i!=i) v.pb(n/i); } } return v; }
vector<ll> primediv(ll n){ vector<ll> v; for(ll i=2;i<=sqrt(n);i++){ while(n%i==0){ n=n/i; v.pb(i); } } if(n>1) v.pb(n); return v; }
ll modINV(ll n,ll mod){ return powmod(n,mod-2,mod); }
ll ncrmod(ll n,ll r,ll p){ if(r==0) return 1; unsigned long long fac[n + 1]={1}; for(int i=1;i<=n;i++) fac[i]=(fac[i-1]*i)%p; return (((fac[n]*modINV(fac[r],p))%p)*modINV(fac[n-r],p))%p; }
double logg(ll n,ll k){ return log2(n)/log2(k); }
//---------------------------------------------------------------------//

/*------------------------[…DSU…]------------------------//
ll parent[100007]; ll Size[100007];
void make(ll v){ parent[v]=v; Size[v]=1; }
ll find(ll v){ if(v==parent[v]) return v; return parent[v] = find(parent[v]); }
void Union(ll a,ll b){ a = find(a); b = find(b); if(a!=b){ if(Size[a] < Size[b]) swap(a,b); parent[b] = a; Size[a] += Size[b]; } }
//----------------------------------------------------------*/

void drizzle()
{
ll n,k,x; in(n>>k>>x);
vector<ll> v(n); loop(i,0,n) in(v[i]);
vector<ll> t(n); loop(i,0,n) in(t[i]);

// half: items of the window eaten partially, full: items eaten completely.
// Both are keyed by {t[i], i}; half always holds the (at most) x longest items.
set<pair<ll,ll>> half, full;
ll left=0,time=0,cnt=0,res=0,ans=0;
loop(right,0,n)
{
    // tentatively eat the new item partially
    half.insert({t[right],right});
    time+=(t[right]+1)/2;
    cnt++;
    if(cnt>x)
    {
        // too many partial items: the shortest one is eaten fully instead,
        // which costs t - ceil(t/2) = floor(t/2) extra minutes
        auto it=half.begin();
        time+=it->first/2;
        full.insert(*it);
        half.erase(it);
        cnt--;
    }
    res+=v[right];

    // shrink the window from the left until it fits into k minutes
    while(time>k)
    {
        pair<ll,ll> key={t[left],left};
        if(half.count(key))
        {
            half.erase(key);
            time-=(t[left]+1)/2;
            cnt--;
            if(!full.empty())
            {
                // a partial slot is free: give it to the longest fully eaten item
                auto it=prev(full.end());
                time-=it->first/2;
                half.insert(*it);
                full.erase(it);
                cnt++;
            }
        }
        else
        {
            full.erase(key);
            time-=t[left];
        }
        res-=v[left];
        left++;
    }
    ans=max(ans,res);
}

out(ans);

return;

}

int main()
{
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);

int testcase=1;     
// cin>>testcase;

while(testcase--)
{
    drizzle();
}

return 0;

}

Tester's solution (PYTHON)

import heapq
import sys

input = sys.stdin.readline
output = sys.stdout.write

class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __lt__(self, other):
        return self.b < other.b

def solve():
    [n, k, limit] = ia(3)

    a = ia(n)
    s = ia(n)
    # min_heap: (time, index) of partially eaten items, shortest time on top
    # max_heap: (-time, index) of fully eaten items, longest time on top
    # Items that left the window are discarded lazily when they reach the top.
    min_heap = []
    max_heap = []
    set_ = set()  # indices of the window that are eaten partially
    res = score = l = 0
    r = -1

    while r < n:
        if k >= 0:
            res = max(res, score)
            r += 1

            if r == n:
                break

            score += a[r]

            while len(min_heap) > 0 and min_heap[0][1] < l:
                heapq.heappop(min_heap)

            if len(min_heap) == 0 or r - l + 1 <= limit:
                heapq.heappush(min_heap, (s[r], r))
                k -= s[r] // 2 + s[r] % 2
                set_.add(r)

            elif s[r] > min_heap[0][0]:
                # r is eaten partially; the shortest partial item is finished instead
                heapq.heappush(min_heap, (s[r], r))
                k -= s[r] // 2 + s[r] % 2
                x, ind = heapq.heappop(min_heap)
                heapq.heappush(max_heap, (-x, ind))
                set_.add(r)
                set_.remove(ind)
                k -= x // 2

            else:
                heapq.heappush(max_heap, (-s[r], r))
                k -= s[r]

        else:
            score -= a[l]

            while len(min_heap) > 0 and min_heap[0][1] < l:
                heapq.heappop(min_heap)

            while len(max_heap) > 0 and max_heap[0][1] <= l:
                heapq.heappop(max_heap)

            if l in set_:
                k += s[l] // 2 + s[l] % 2
                set_.remove(l)

                if len(max_heap) > 0:
                    # promote the longest fully eaten item to partial
                    neg_x, ind = heapq.heappop(max_heap)
                    x = -neg_x
                    heapq.heappush(min_heap, (x, ind))
                    set_.add(ind)
                    k += x
                    k -= x // 2 + x % 2

            else:
                k += s[l]

            l += 1

    output(f"{res}\n")

def i():
    return int(input())

def s():
    return input().strip()

def l():
    return int(input())

def ia(n):
    return list(map(int, input().split()))[:n]

if __name__ == "__main__":
    t = 1  # i()
    for _ in range(t):
        solve()