COVAXIN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Utkarsh Gupta
Tester: Istvan Nagy
Editorialist: Aman Dwivedi

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Binary Search

PROBLEM:

There are two types of vaccines available: Covaxin and Covishield.

A black marketeer has X coins and wants to buy as many vaccines as possible. Due to black-marketing concerns, the government has enforced the following policy:

  • i^{th} dose of Covaxin costs a + (i - 1)\cdot b coins for every i \geq 1.

  • i^{th} dose of Covishield costs c + (i - 1)\cdot d coins for every i \geq 1.

The values of the four parameters a, b, c and d, however, are not constant and may vary from query to query. In general, the values of these four parameters for the i^{th} query are A_i, B_i, C_i and D_i respectively.
Let ans_i be the maximum total number of vaccines the black marketeer can buy in the i^{th} query. For each query, you have to find the value of ans_i.

You are given integers I (the number of queries), X, A_1, B_1, C_1, D_1, P, Q, R, S, T and M, which define all the queries.

For 1 \le i \le I - 1:

  • A_{i+1} = (A_i + ans_i\cdot T) \bmod M + P
  • B_{i+1} = (B_i + ans_i\cdot T) \bmod M + Q
  • C_{i+1} = (C_i + ans_i\cdot T) \bmod M + R
  • D_{i+1} = (D_i + ans_i\cdot T) \bmod M + S

EXPLANATION:

Let us set the queries aside for a moment and focus on finding the number of vaccines we can buy when a, b, c and d are given.

Suppose that we only need to buy one type of vaccine, say Covaxin. Then the problem becomes easy: we can binary search on the number of Covaxin doses we can buy with X coins.

How can we do this with binary search?

Look at the cost of the i^{th} dose of Covaxin:

a + (i-1)\cdot b

This is the i^{th} term of an arithmetic progression (AP) with first term a and common difference b. Since every dose has a positive price, the total cost of the first mid doses grows with mid, so we can binary search for the largest mid satisfying:

\sum_{i=1}^{mid} \left(a + (i-1)\cdot b\right) = mid\cdot a + b\cdot\frac{mid\,(mid-1)}{2} \le X

where the left-hand side is the cost of purchasing the first mid doses of Covaxin.
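A minimal C++ sketch of this single-vaccine binary search, assuming a \ge 1 and b \ge 1 and a GCC/Clang compiler (it uses the non-standard `unsigned __int128` type so that intermediate products stay exact). The helper names `affordable` and `maxDoses` are illustrative and do not come from the reference solutions.

```cpp
#include <cassert>

using u64 = unsigned long long;
using u128 = unsigned __int128;  // GCC/Clang extension, not standard C++

// True if the first k doses (prices a, a+b, a+2b, ...) cost at most X,
// i.e. k*a + b*k*(k-1)/2 <= X, evaluated without overflow.
bool affordable(u64 X, u64 a, u64 b, u64 k) {
    u128 linear = (u128)k * a;             // both factors < 2^64, so this fits
    if (linear > X) return false;
    u128 tri = (u128)k * (k - 1) / 2;      // 0 + 1 + ... + (k - 1); k = 0 gives 0
    if (tri > X) return false;             // b >= 1, so b * tri > X as well
    u128 extra = tri * b;                  // tri < 2^64 and b < 2^64, so this fits
    return linear + extra <= X;            // linear < 2^64 too, so the sum fits
}

// Largest k with affordable(X, a, b, k); the cost grows with k, so binary search works.
u64 maxDoses(u64 X, u64 a, u64 b) {
    assert(a >= 1 && b >= 1);
    u64 lo = 0, hi = X / a;                // every dose costs at least a
    while (lo < hi) {
        u64 mid = hi - (hi - lo) / 2;      // upper middle: lo < mid <= hi
        if (affordable(X, a, b, mid)) lo = mid;
        else hi = mid - 1;
    }
    return lo;
}
```

For example, with X = 100, a = 5 and b = 10 the doses cost 5, 15, 25, 35, 45, ...; the first four cost 80 in total and a fifth would bring the total to 125, so `maxDoses(100, 5, 10)` returns 4.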

Let us say that we can buy K doses of Covaxin with the X coins available. Hence the total number of vaccines we have is K. Now let us introduce the second vaccine, Covishield, into the problem.

Our goal is to maximize the total number of vaccines purchased. One way to view the problem is that we can sell some of the Covaxin doses we bought and use the freed coins to buy Covishield doses. It is optimal to sell the most expensive Covaxin doses we hold and buy the cheapest Covishield doses available, whenever this increases the total.

Two questions now arise:

  • How many of the purchased Covaxin doses should be sold?
  • How many Covishield doses should be bought?

Again, we can binary search on the number of Covaxin doses to sell: for a candidate number, check whether selling that many Covaxin doses and spending the freed coins on Covishield increases the total number of doses. If it does, we try selling more Covaxin and buying more Covishield; otherwise, we try selling fewer.
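A binary search like this is easiest to validate against a brute force. The sketch below (hypothetical helper names; plain 64-bit arithmetic, so it is only meant for small inputs where nothing overflows) keeps the k cheapest Covaxin doses, spends the remaining coins on as many Covishield doses as possible, and takes the best k by a linear scan. It is a stress-testing reference, not the intended per-query algorithm.

```cpp
#include <algorithm>
#include <cassert>

using u64 = unsigned long long;

// Cost of the first k doses with prices a, a+b, a+2b, ...
// No overflow checks: small inputs only.
u64 prefixCost(u64 k, u64 a, u64 b) {
    return k * a + b * (k * (k - 1) / 2);
}

// Largest j with prefixCost(j, a, b) <= budget, found by linear search.
u64 maxDosesSlow(u64 budget, u64 a, u64 b) {
    u64 j = 0;
    while (prefixCost(j + 1, a, b) <= budget) ++j;
    return j;
}

// Best total over every possible number k of Covaxin doses kept.
u64 bestByScanOverK(u64 X, u64 a, u64 b, u64 c, u64 d) {
    assert(a >= 1 && b >= 1 && c >= 1 && d >= 1);
    u64 best = 0;
    for (u64 k = 0; prefixCost(k, a, b) <= X; ++k) {
        u64 rest = X - prefixCost(k, a, b);           // coins left for Covishield
        best = std::max(best, k + maxDosesSlow(rest, c, d));
    }
    return best;
}
```

With X = 10 and a = b = c = d = 1, keeping two Covaxin doses (cost 3) leaves 7 coins for three Covishield doses (cost 6), so `bestByScanOverK(10, 1, 1, 1, 1)` returns 5.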

Finally, we obtain the maximum total number of doses (Covaxin and Covishield combined) that we can buy.
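The same maximum can also be reached without an explicit sell-and-buy step, along the lines of the author's solution below: since both price sequences are increasing, the n cheapest doses overall always form a prefix of each sequence. So one can binary search on a price threshold p, buy every dose priced at most p, and then top up with the doses priced exactly p + 1 that are still affordable. The C++ sketch below assumes a, b, c, d \ge 1, X < 2^{63}, and GCC/Clang `unsigned __int128`; the function names are illustrative, and the top-up step is written differently from the author's code.

```cpp
#include <algorithm>
#include <cassert>

using u64 = unsigned long long;
using u128 = unsigned __int128;  // GCC/Clang extension, not standard C++

// Number of doses priced at most p in the sequence a, a+b, a+2b, ...
u64 countUpTo(u64 p, u64 a, u64 b) {
    return p < a ? 0 : (p - a) / b + 1;
}

// Cost of the first n doses, capped at X + 1 so it never overflows.
u128 cappedCost(u64 n, u64 a, u64 b, u64 X) {
    const u128 over = (u128)X + 1;
    u128 linear = (u128)n * a;
    u128 tri = (u128)n * (n - 1) / 2;
    if (linear > X || tri > X) return over;   // b >= 1, so tri > X means cost > X
    u128 total = linear + tri * b;            // linear, tri and b are all < 2^64: fits
    return std::min(total, over);
}

u64 bestByThreshold(u64 X, u64 a, u64 b, u64 c, u64 d) {
    assert(a >= 1 && b >= 1 && c >= 1 && d >= 1);
    // Capped total cost of every dose of either vaccine priced at most p.
    auto costUpTo = [&](u64 p) {
        return cappedCost(countUpTo(p, a, b), a, b, X) + cappedCost(countUpTo(p, c, d), c, d, X);
    };
    // Largest threshold p in [0, X] whose doses are all affordable (p = 0 always is).
    u64 lo = 0, hi = X;
    while (lo < hi) {
        u64 mid = hi - (hi - lo) / 2;
        if (costUpTo(mid) <= X) lo = mid;
        else hi = mid - 1;
    }
    u128 left = (u128)X - costUpTo(lo);       // coins left after buying all doses <= lo
    u128 next = (u128)lo + 1;                 // the next price level
    u64 atNext = 0;                           // doses priced exactly lo + 1 (at most one per vaccine)
    if (next >= a && (next - a) % b == 0) ++atNext;
    if (next >= c && (next - c) % d == 0) ++atNext;
    u128 topUp = std::min<u128>(atNext, left / next);
    return countUpTo(lo, a, b) + countUpTo(lo, c, d) + (u64)topUp;
}
```

After the top-up no further dose can be afforded: if every dose priced lo + 1 had been affordable, threshold lo + 1 would have passed the binary search, so at least one dose at lo + 1 remains unbought and the leftover coins are below lo + 1.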

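For the remaining queries, we compute the next values of a, b, c and d from the recurrence, which uses the answer to the current query, and run the same procedure again. Because every query depends on the previous answer, the queries have to be answered in order.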
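A sketch of the per-query parameter update in C++ (hypothetical names; GCC/Clang `unsigned __int128` assumed). It forms ans_i \cdot T in 128 bits, so it follows the recurrence exactly as written; note that the reference solutions below form that product in 64-bit arithmetic, so on inputs where it does not fit in 64 bits the two can disagree. The results stay below 2^{64} as long as M, P, Q, R and S are at most 10^{18}, which is what the tester's input checks require.

```cpp
#include <cassert>

using u64 = unsigned long long;
using u128 = unsigned __int128;  // GCC/Clang extension, not standard C++

struct Params { u64 a, b, c, d; };

// One step of x_{i+1} = (x_i + ans_i * T) mod M + add for each of a, b, c, d,
// with add = P, Q, R, S respectively.
Params nextParams(const Params& cur, u64 ans, u64 P, u64 Q, u64 R, u64 S, u64 T, u64 M) {
    assert(M >= 1);
    const u128 shift = (u128)ans * T % M;                 // (ans * T) mod M, exact
    auto step = [&](u64 x, u64 add) -> u64 {
        return (u64)((x % M + shift) % M + add);          // (x + ans * T) mod M + add
    };
    return {step(cur.a, P), step(cur.b, Q), step(cur.c, R), step(cur.d, S)};
}
```

A driver would answer the current query, print ans_i, and then replace the parameters with `nextParams(...)` before moving on to the next query.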

TIME COMPLEXITY:

O(\log^2 C) per query

where C \approx \min(X/A, \sqrt{2X/B}) is an upper bound on the number of doses of a single vaccine.
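For a sense of scale, assuming X \le 10^{18}, B \ge 1 and at most I = 5\cdot 10^{5} queries (the bounds the tester's input validation checks), the bound works out as follows:

C \le \sqrt{2X/B} \le \sqrt{2\cdot 10^{18}} \approx 1.41\cdot 10^{9}, \quad \lceil \log_2 C \rceil \le 31, \quad I\cdot\log^2 C \lesssim 5\cdot 10^{5}\cdot 31^{2} \approx 4.8\cdot 10^{8}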

SOLUTIONS:

Author's
//Utkarsh.25dec
#include <bits/stdc++.h>
#include <chrono>
#include <random>
#define ll long long int
#define ull unsigned long long int
#define pb push_back
#define mp make_pair
#define rep(i,n) for(ll i=0;i<n;i++)
#define loop(i,a,b) for(ll i=a;i<=b;i++)
#define vi vector <int>
#define vs vector <string>
#define vc vector <char>
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
#define max3(a,b,c) max(max(a,b),c)
#define min3(a,b,c) min(min(a,b),c)
#define deb(x) cerr<<#x<<' '<<'='<<' '<<x<<'\n'
using namespace std;
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp> 
using namespace __gnu_pbds; 
#define ordered_set tree<int, null_type,less<int>, rb_tree_tag,tree_order_statistics_node_update>
// ordered_set s ; s.order_of_key(val)  no. of elements strictly less than val
// s.find_by_order(i)  iterator to ith element (0 indexed)
typedef vector<vector<ll>> matrix;
const int N=500023;
bool vis[N];
vector <int> adj[N];
void solve()
{
    ll qu;
    cin>>qu;
    ll x,a,b,c,d;
    cin>>x>>a>>b>>c>>d;
    ll p,q,rr,s,t,m;
    cin>>p>>q>>rr>>s>>t>>m;
    ll mod=m;
    t%=mod;
    while(qu--)
    {
        // cout<<a<<' '<<b<<' '<<c<<' '<<d<<'\n';
        ll l=0,r=x;
        ll ans=0;
        while(l<=r)
        {
            ll mid=(l+r)/2;
            ll Q1=(mid-a)/b+1;
            if(mid<a)
                Q1=0;
            ll Q2=(mid-c)/d+1;
            if(mid<c)
                Q2=0;
            if(max(Q1,Q2)>2e9)
            {
                r=mid-1;
                continue;
            }
            if((Q1-1)>((ll)5e18)/b)
            {
                r=mid-1;
                continue;
            }
            if((Q2-1)>((ll)5e18)/d)
            {
                r=mid-1;
                continue;
            }
            // cout<<mid<<' '<<Q1<<' '<<Q2<<'\n';
            ll temp=(2*a+(Q1-1)*b);
            if(Q1>0 && Q1>((ll)5e18)/temp)
            {
                r=mid-1;
                continue;
            }
            // cout<<mid<<' '<<Q1<<' '<<Q2<<'\n';
            ll sumA=(Q1*temp)/2;
            if(sumA>x)
            {
                r=mid-1;
                continue;
            }
            // cout<<mid<<' '<<Q1<<' '<<Q2<<' '<<temp<<'\n';
            temp=(2*c+(Q2-1)*d);
            if(Q2>0 && Q2>((ll)5e18)/temp)
            {
                r=mid-1;
                continue;
            }
            ll sumB=(Q2*temp)/2;
            if(sumB>x)
            {
                r=mid-1;
                continue;
            }
            if((sumA+sumB)>x)
            {
                if((a+(Q1-1)*b)==mid && (c+(Q2-1)*d)==mid)
                {
                    if((sumA+sumB-mid)<=x)
                    {
                        ans=max(ans,Q1+Q2-1);
                    }
                }
                r=mid-1;
                continue;
            }
            else
            {
                ans=max(ans,Q1+Q2);
                l=mid+1;
                continue;
            }
        }
        cout<<ans<<'\n';
        ll tmp=ans%mod;
        tmp*=t;
        tmp%=mod;
        a=(a+tmp)%mod+p;
        b=(b+tmp)%mod+q;
        c=(c+tmp)%mod+rr;
        d=(d+tmp)%mod+s;
    }
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int T=1;
    // cin>>T;
    int t=0;
    while(t++<T)
    {
        //cout<<"Case #"<<t<<":"<<' ';
        solve();
        //cout<<'\n';
    }
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Tester's
#include <iostream>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include <random>

#ifdef HOME
#include <windows.h>
#endif

#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define forn(i, n) for (int i = 0; i < (int)(n); ++i)
#define for1(i, n) for (int i = 1; i <= (int)(n); ++i)
#define ford(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
#define fore(i, a, b) for (int i = (int)(a); i <= (int)(b); ++i)

template<class T> bool umin(T& a, T b) { return a > b ? (a = b, true) : false; }
template<class T> bool umax(T& a, T b) { return a < b ? (a = b, true) : false; }

using namespace std;

long long readInt(long long l, long long r, char endd) {
	long long x = 0;
	int cnt = 0;
	int fi = -1;
	bool is_neg = false;
	while (true) {
		char g = getchar();
		if (g == '-') {
			assert(fi == -1);
			is_neg = true;
			continue;
		}
		if ('0' <= g && g <= '9') {
			x *= 10;
			x += g - '0';
			if (cnt == 0) {
				fi = g - '0';
			}
			cnt++;
			assert(fi != 0 || cnt == 1);
			assert(fi != 0 || is_neg == false);

			assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
		}
		else if (g == endd) {
			assert(cnt > 0);
			if (is_neg) {
				x = -x;
			}
			assert(l <= x && x <= r);
			return x;
		}
		else {
			assert(false);
		}
	}
}

string readString(int l, int r, char endd) {
	string ret = "";
	int cnt = 0;
	while (true) {
		char g = getchar();
		assert(g != -1);
		if (g == endd) {
			break;
		}
		cnt++;
		ret += g;
	}
	assert(l <= cnt && cnt <= r);
	return ret;
}
long long readIntSp(long long l, long long r) {
	return readInt(l, r, ' ');
}
long long readIntLn(long long l, long long r) {
	return readInt(l, r, '\n');
}
string readStringLn(int l, int r) {
	return readString(l, r, '\n');
}
string readStringSp(int l, int r) {
	return readString(l, r, ' ');
}

uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t m)
{
	assert(a < m&& b < m && 0 < m && m < (1ull << 63));
	long double x = a;
	uint64_t c(x * b / m);
	int64_t r = int64_t(a * b - c * m) % int64_t(m);
	return r < 0 ? r + m : r;
}

bool isMulOverflow(uint64_t a, uint64_t b)
{
	uint64_t x = a * b;
	if (a != 0 && x / a != b) {
		return true;
	}
	return false;
}

bool isSmaller(uint64_t X, uint64_t A, uint64_t B, uint64_t mul, uint64_t& res)
{//A*mul+B*mul*(mul-1)/2<=X
	//check A*mul > X ?
	if (isMulOverflow(A, mul))
		return false;
	A *= mul;
	if (A > X)
		return false;
	if (isMulOverflow(mul-1, mul))
		return false;
	uint64_t tmp = mul * (mul - 1);
	tmp /= 2;

	if (isMulOverflow(B, tmp))
		return false;

	B *= tmp;
	if (B > X) // total already exceeds X; also prevents overflow in A + B below
		return false;
	res = A + B;
	return res <= X;
}

uint64_t solveSimple(uint64_t X, uint64_t A, uint64_t B)
{
	//X = B * k * (k-1)/2 => 2*X/B = k*(k-1) => k^2 - k - val = 0 => k = (1 + sqrt(1+4*val))/2 ~ 1/2 + sqrt(val)
	uint64_t low = 0, hi = min<uint64_t>(X / A , sqrt(2 * X / B));
	uint64_t tmp = 0;
	while (isSmaller(X, A, B, hi, tmp))
		++hi;
	uint64_t res = 0;

	while (low < hi)
	{
		uint64_t m = (low + hi + 1) / 2;
		// overflow-safe check of A*m + B*m*(m-1)/2 <= X (m may equal the infeasible hi)
		if (isSmaller(X, A, B, m, tmp))
		{
			low = m;
		}
		else
		{
			hi = m - 1;
		}
	}
	return low;
}

uint64_t solve(uint64_t X, uint64_t A, uint64_t B, uint64_t C, uint64_t D)
{
	uint64_t res = 0;

	uint64_t TA = solveSimple(X, A, B);
	uint64_t TB = solveSimple(X, C, D);

	uint64_t low = 0;
	uint64_t hi = TA;

	const auto f = [&](uint64_t mid) {
		uint64_t t1 = A * mid + B * mid * (mid - 1) / 2;
		uint64_t t2 = solveSimple(X - t1, C, D);
		uint64_t t1r = mid + t2;

		uint64_t t3 = t1 + A + B * mid;
		uint64_t t4 = 0;
		if (t3 <= X)
			t4 = solveSimple(X - t3, C, D);
		uint64_t t3r = t3 <= X ? t4 + mid + 1 : 0;
		if (t1r == t3r)
		{
			uint64_t v1 = t1 + (C * t2 + D * t2 * (t2 - 1) / 2);
			uint64_t v2 = t3 + (C * t4 + D * t4 * (t4 - 1) / 2);
			if (v1 < v2)
				--t3r;
			else
				--t1r;
		}
		return make_pair(t1r, t3r);
	};

	while (low + 1 < hi)
	{
		uint64_t mid = (low + hi) / 2;
		auto res = f(mid);

		if (res.first > res.second)
		{
			hi = mid;
		}
		else
		{
			low = mid + 1;
		}
	}
	auto tmp = f(low);
	return max<uint64_t>(tmp.first, tmp.second);
}

int main(int argc, char** argv)
{
#ifdef HOME
	if (IsDebuggerPresent())
	{
		freopen("../in.txt", "rb", stdin);
		freopen("../out.txt", "wb", stdout);
	}
#endif
	uint32_t I = readIntLn(1, 500'000);
	uint64_t X = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t A = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t B = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t C = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t D = readIntLn(1, 1'000'000'000'000'000'000ull);

	uint64_t P = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t Q = readIntSp(1, 1'000'000'000'000'000'000ull);
	
	uint64_t R = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t S = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t T = readIntSp(1, 1'000'000'000'000'000'000ull);
	uint64_t M = readIntLn(1, 1'000'000'000'000'000'000ull);
	uint64_t ans = 0;

	uint64_t newA = A, newB = B, newC = C, newD = D;
	T %= M;
	forn(tc, I)
	{
		ans = solve(X, A, B, C, D);
		printf("%llu\n", (unsigned long long)ans);
		ans %= M;
		// 64-bit product: wraps modulo 2^64 if ans * T does not fit
		ans *= T;
		ans %= M;

		A += ans;
		A %= M;
		A += P;
		B += ans;
		B %= M;
		B += Q;
		C += ans;
		C %= M;
		C += R;
		D += ans;
		D %= M;
		D += S;
	}
	assert(getchar() == -1);
	return 0;
}

Author’s and Tester’s solutions are the same :frowning: