PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: munch_01
Preparer: jay_1048576
Tester: yash_daga
Editorialist: iceknight1093
DIFFICULTY:
2127
PREREQUISITES:
None
PROBLEM:
Given N and K, define the function f(x) = (x\bmod K) \times ((N-x)\bmod K).
Find any 0 \leq x \leq N that maximizes f(x).
EXPLANATION:
N can be quite large, so of course trying every x is out of the question.
Instead, let’s attempt to reduce the number of x we need to check.
One way to build intuition for this is to look at the case when K \gt N, in which case x\bmod K = x and (N-x)\bmod K = (N-x) (since they’re both less than K already).
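This results in f(x) = x\cdot (N-x), and it’s not hard to see that this function is maximized when x is as close as possible to \frac{N}{2}, i.e. at x = \left\lfloor \frac{N}{2} \right\rfloor (or equivalently \left\lceil \frac{N}{2} \right\rceil, which gives the same value).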
For example, a quick way of seeing it is to observe that when x \lt N-x, we have x\cdot (N-x) \leq (x+1)\cdot (N-x-1), so it’s always better to bring x closer to N-x.
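In fact, this idea applies to the case when K \leq N as well!
To see why, let p = x\bmod K and q = (N-x)\bmod K. Since p + q \equiv N \pmod K and 0 \leq p, q \leq K-1, the sum p + q can only be N\bmod K or (N\bmod K) + K. For a fixed sum, the product p\cdot q is largest when p and q are as close to each other as possible.
That is, the choice of x that maximizes f(x) will be such that x\bmod K and (N-x) \bmod K differ by at most 1.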
We just need to figure out when this is possible at all, i.e., which values of x satisfy this condition.
It turns out there are only two cases:
- The first is x = \left\lfloor \frac{N\bmod K}{2} \right\rfloor.
This is the direct generalization of what we had for the N\lt K case, where here we instead start at 0 and N\bmod K and keep moving them closer towards each other.
- The second is to just move in the other direction! That is, the point x = \left\lfloor \frac{(N\bmod K) + K}{2} \right\rfloor.
This is what we get by starting at N\bmod K and K (which is equivalent to 0 modulo K) and moving them closer to each other.
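Note that this option is only available if x \leq N. This always holds when N \geq K (since x \lt K), but it can fail when N \lt K. The option is also useless when N\bmod K = K-1, because then (N-x)\bmod K = 0. Comparing f at both candidates handles this automatically.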
Now that you have (at most) two options for x, simply evaluate f(x) at both of them and choose the best one.
TIME COMPLEXITY
\mathcal{O}(1) per testcase.
CODE
Tester's code (C++)
// Input Checker
// Input verification
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
// Convenience alias for 64-bit integers; not referenced in the code visible here.
using ll = long long int;
// 64-bit Mersenne Twister seeded from the high-resolution clock.
// NOTE(review): not referenced in the code visible here — presumably left over
// from a shared template; confirm before removing.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
/// Strict validator for competitive-programming input files.
///
/// The constructor loads the entire input into `buffer`. The read* methods then
/// consume it token by token and `assert` on anything unexpected, so a
/// malformed test file aborts instead of being silently accepted.
///
/// Tokens are delimited by ' ' or '\n'. Callers must consume every delimiter
/// explicitly via readSpace() / readEoln(), and finish with readEof().
struct input_checker {
    string buffer;  // entire input text
    int pos;        // index of the next unread character in `buffer`
    // Ready-made character sets for readString's `pattern` argument.
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    /// Reads all of standard input into the buffer.
    input_checker() : pos(0) {
        buffer.assign(istreambuf_iterator<char>(cin), istreambuf_iterator<char>());
    }

    /// Validates the given in-memory text instead of standard input.
    explicit input_checker(string data) : buffer(std::move(data)), pos(0) {}

    /// Returns the index of the first ' ' or '\n' at or after `pos`,
    /// or buffer.size() if there is none.
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    /// Consumes and returns the characters from `pos` up to (not including) the
    /// next delimiter. The result is empty if `pos` already sits on a delimiter.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    /// Reads a token whose length lies in [minl, maxl].
    /// If `pattern` is non-empty, every character of the token must occur in it.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    /// Reads a canonical decimal integer token in [minv, maxv].
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        // Range-check in long long so values that do not fit in int fail the
        // assert below instead of overflowing or throwing.
        const long long res = parseInteger(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return static_cast<int>(res);
    }

    /// Reads a canonical decimal integer token in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const long long res = parseInteger(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    /// Reads n ints in [minv, maxv]: space-separated, and the last one is
    /// followed by a newline. When n == 0 nothing is consumed.
    auto readIntVec(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
            else readEoln();
        }
        return v;
    }

    /// Reads n long longs in [minv, maxv]: space-separated, and the last one is
    /// followed by a newline. When n == 0 nothing is consumed.
    auto readLongVec(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
            else readEoln();
        }
        return v;
    }

    /// Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    /// Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    /// Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }

private:
    /// Parses `token` as a canonical base-10 integer and asserts it is well formed.
    ///
    /// Well formed means: an optional '-', then at least one digit, with no
    /// leading zeros and nothing else, and a value that fits in long long.
    /// This rejects inputs such as "", "-", "+5", "007", "-0", "12abc" and " 5".
    /// stoi/stoll would otherwise accept most of them, or throw on them.
    static long long parseInteger(const string &token) {
        const bool negative = !token.empty() && token[0] == '-';
        const size_t digitsStart = negative ? 1 : 0;
        // At least one digit must follow the optional sign.
        assert(token.size() > digitsStart);
        // No leading zeros; the only token allowed to start with '0' is "0".
        assert(token[digitsStart] != '0' || token.size() == 1);
        long long value = 0;
        const char *last = token.data() + token.size();
        const auto [ptr, ec] = from_chars(token.data(), last, value);
        // The whole token must be consumed and the value must not overflow.
        assert(ec == errc() && ptr == last);
        return value;
    }
};
/// Tester's solution: validates the input format while answering each query.
///
/// For each test case it prints an x in [0, n] maximising
/// (x mod k) * ((n - x) mod k), using the two candidate points from the editorial.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    input_checker inp;
    const int testCount = inp.readInt(1, 1e5);
    inp.readEoln();

    for (int tc = 0; tc < testCount; ++tc) {
        const int n = inp.readInt(0, 1e9);
        inp.readSpace();
        const int k = inp.readInt(1, 1e9);
        inp.readEoln();

        const int rem = n % k;
        // Candidate 1: split rem as evenly as possible, giving x = floor(rem / 2).
        int best = rem / 2;
        // Candidate 2: split rem + k evenly. rem / 2 + (k + 1) / 2 equals either
        // floor or ceil of (rem + k) / 2.
        // It requires n >= k so that x <= n, and gives 0 when rem == k - 1,
        // because then (n - x) mod k == 0.
        if (n >= k && rem != k - 1) {
            best += (k + 1) / 2;
        }
        cout << best << '\n';
    }

    inp.readEof();
}
Editorialist's code (Python)
def f(x, n, k):
    """Objective from the problem statement: (x mod k) * ((n - x) mod k)."""
    left = x % k
    right = (n - x) % k
    return left * right
t = int(input())
for _ in range(t):
    n, k = map(int, input().split())
    rem = n % k
    # Option 1: split n mod k evenly between x mod k and (n - x) mod k.
    best = rem // 2
    # Option 2: split (n mod k) + k evenly; only usable when x <= n.
    alt = (rem + k) // 2
    if alt <= n and f(alt, n, k) > f(best, n, k):
        best = alt
    print(best)