PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: wuhudsm
Testers: tabr, iceknight1093
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Greedy algorithms
PROBLEM:
There’s an N\times M grid, with A_{i, j} = i.
A mountain is described by integers P, K and L_1, L_2, \ldots, L_K: for each i from 1 to K, you pick the first L_i integers from row (P+i-1). The sum of the mountain is the sum of all picked integers.
Answer Q queries of the following form:
- Given S, find a mountain with sum S.
EXPLANATION:
This task can be solved greedily.
Suppose we want a sum of S.
Let’s go from the first row to the last, each time taking as many numbers as possible, until the running sum first reaches or exceeds S.
As soon as it does, we can throw out (at most) one number to attain the exact sum we want.
That is, initialize a variable \text{sum} = 0 and set P = 1 since we’re starting from the first row.
Then, for each i from 1 to N:
- If \text{sum} + i\cdot M \lt S, take all M elements from this row and continue. In other words, set L_i = M and add i\cdot M to \text{sum}.
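- Otherwise, let j be the smallest integer such that \text{sum} + i\cdot j \geq S, i.e., j = \left\lceil (S - \text{sum}) / i \right\rceil (note j \le M, since \text{sum} + i\cdot M \geq S in this branch). Take these j numbers into the sum, i.e., set L_i = j and add i\cdot j to \text{sum}. This is the last row we use, so K = i.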
- Now, if \text{sum} = S we’re done.
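- Otherwise, remove one element from row (\text{sum} - S), i.e., decrement L_{\text{sum} - S} by one. Every element of that row equals \text{sum} - S, so this brings the total down to exactly S. This is always possible: by the minimality of j we have \text{sum} - i \lt S, so 1 \le \text{sum} - S \le i - 1. Hence row \text{sum} - S is one of the earlier rows, all of which were taken in full (L_{\text{sum} - S} = M \ge 2), and it still keeps at least one element after the decrement.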
TIME COMPLEXITY:
\mathcal{O}(N) or \mathcal{O}(N+M) per query.
CODE:
Setter's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// Shared state for the setter's program. main() uses only the ll alias and the
// three sizes below. Large unused global arrays (and a global `p` that main's
// local `p` would shadow) are intentionally not declared here.
using ll = long long;
int n, m, q;  // grid rows, grid columns, number of queries (filled by scanf in main)
int main()
{
scanf("%d%d%d",&n,&m,&q);
for(int i=1;i<=q;i++)
{
ll t,sum;
int L=0,R=n+1,M,p;
scanf("%lld",&t);
while(L+1!=R)
{
M=(L+R)>>1;
if((ll)m*(ll)M*(M+1)/2<t) L=M;
else R=M;
}
p=R;sum=(ll)m*(ll)L*(L+1)/2;
for(int j=1;j<=m;j++)
{
sum+=p;
if(sum>=t)
{
printf("%d %d\n",1,p);
for(int k=1;k<p;k++) printf("%d ",k==sum-t?m-1:m);
printf("%d\n",j);
break;
}
}
}
return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif
// Strict validator for test input. The constructor reads all of stdin into
// `buffer`; each read* call then consumes the next token or delimiter and
// asserts that it has the expected format and lies within the given bounds.
struct input_checker {
    string buffer;  // entire standard input
    int pos;        // index of the next unread character in buffer
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Index of the first ' ' or '\n' at or after pos (buffer.size() if none).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the characters before the next ' ' or '\n'.
    // The result is empty when pos already sits on a delimiter.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token whose length is in [minl, maxl]; when pattern is non-empty,
    // every character of the token must occur in pattern.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // True iff s is a canonical decimal integer: an optional '-' followed by one
    // or more digits, with no leading zeros and not "-0".
    static bool isCanonicalInteger(const string &s) {
        size_t start = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (start == s.size()) {
            return false;  // empty token or a lone '-'
        }
        for (size_t i = start; i < s.size(); i++) {
            if (s[i] < '0' || s[i] > '9') {
                return false;
            }
        }
        if (s[start] == '0') {
            return start == 0 && s.size() == 1;  // only "0" itself may start with '0'
        }
        return true;
    }

    // Reads an integer in [minv, maxv]; the bounds are ints, so the value fits.
    int readInt(int minv, int maxv) {
        return (int) readLong(minv, maxv);
    }

    // Reads a canonical integer token in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        string token = readOne();
        // stoll on its own would accept "+5", "007", "\t5" or "5\r" (it skips
        // leading whitespace and stops at the first non-digit), so check the
        // exact token shape first.
        assert(isCanonicalInteger(token));
        long long res = stoll(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Tester's driver: validates the input format and constraints, then for each
// query greedily builds a mountain starting at row 1 and prints it as
// "1 K" followed by L_1 ... L_K.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    // First line: "n m q".
    int n = in.readInt(2, 30000);
    in.readSpace();
    int m = in.readInt(2, 30000);
    in.readSpace();
    int q = in.readInt(1, 10);
    in.readEoln();
    // Second line: q space-separated queries, each at most the sum of the whole
    // grid, m * (1 + 2 + ... + n). The bound is loop-invariant.
    const long long gridSum = m * 1LL * n * (n + 1) / 2;
    vector<long long> s(q);
    for (int i = 0; i < q; i++) {
        s[i] = in.readLong(1, gridSum);
        (i == q - 1 ? in.readEoln() : in.readSpace());
    }
    // Nothing may follow the last query line.
    in.readEof();
    for (auto t : s) {
        long long now = 0;
        vector<int> l;
        // Phase 1: take one cell from each of rows 1, 2, ... while the total stays <= t.
        for (int i = 0; i < n; i++) {
            if (now + (i + 1) > t) {
                break;
            }
            now += i + 1;
            l.emplace_back(1);
        }
        // Phase 2: from the last chosen row down to row 1, add further cells worth
        // (i + 1) each while they fit, first trying the remaining m - 1 at once and
        // never exceeding m cells in a row.
        for (int i = (int) l.size() - 1; i >= 0; i--) {
            const long long rowValue = i + 1;
            if (now + rowValue * (m - 1) <= t) {
                l[i] += m - 1;
                now += rowValue * (m - 1);
            }
            while (now + rowValue <= t && l[i] < m) {
                l[i]++;
                now += rowValue;
            }
        }
        cout << 1 << " " << l.size() << '\n';
        for (int i = 0; i < (int) l.size(); i++) {
            cout << l[i] << " \n"[i == (int) l.size() - 1];
        }
    }
    return 0;
}
Editorialist's code (Python)
# Editorialist's solution: answer every query with the greedy described above.
n, m, q = map(int, input().split())
queries = list(map(int, input().split()))


def build_mountain(target):
    """Return (last_row, lengths) for a mountain starting at row 1 with sum `target`.

    Whole rows 1, 2, ... are taken while the running total plus the next whole row
    stays below target. From the next row, the fewest cells that reach or pass
    target are taken. The earlier row whose index equals the resulting overshoot
    (if any) gets length m - 1 instead of m.
    """
    total = 0
    row = 1
    while total + m * row < target:
        total += m * row
        row += 1
    take = (target - total + row - 1) // row  # ceil((target - total) / row)
    total += row * take
    lengths = [m - 1 if total - i == target else m for i in range(1, row)]
    lengths.append(take)
    return row, lengths


for s in queries:
    last_row, lengths = build_mountain(s)
    print(1, last_row)
    print(*lengths)