MOUNTAIN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Testers: tabr, iceknight1093
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Greedy algorithms

PROBLEM:

There’s an N\times M grid, with A_{i, j} = i.

A mountain is described by integers P, K, L_1, L_2, \ldots, L_K: for each 1 \le i \le K, you pick the first L_i integers from row (P+i-1).

Answer Q queries of the following form:

  • Given S, find a mountain with sum S.

EXPLANATION:

This task can be solved greedily.

Suppose we want a sum of S.
Let’s go from the first row to the last, each time taking as many numbers as possible, until the running total first reaches or exceeds S.
As soon as we exceed it, we can throw out (at most) one number to attain the exact sum we want.

That is, initialize a variable \text{sum} = 0 and set P = 1 since we’re starting from the first row.
Then, for each i from 1 to N:

  • If \text{sum} + i\cdot M \lt S, take all M elements from this row, add i\cdot M to \text{sum}, and continue with the next row. In other words, set L_i = M.
  • Otherwise, let j be the smallest integer such that \text{sum} + i\cdot j \geq S. Take these j numbers, i.e., set L_i = j and add i\cdot j to \text{sum}. Row i is the last row used, so the answer has P = 1 and K = i.
  • Now, if \text{sum} = S we’re done.
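  • Otherwise \text{sum} \gt S. Because j was chosen as small as possible, \text{sum} - i \lt S, so 1 \le \text{sum} - S \le i - 1. Every element of row r equals r, so removing one element from row (\text{sum} - S) lowers the total by exactly \text{sum} - S; i.e., decrement L_{\text{sum} - S} by one. That row lies strictly above row i, so it was taken in full (L = M \ge 2) and stays non-empty, which makes this always possible.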

TIME COMPLEXITY

\mathcal{O}(N) or \mathcal{O}(N+M) per query.

CODE:

Setter's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int N=1000010;
const int LOGN=28;
const ll  TMD=0;
const ll  INF=2147483647;
// Grid size (n rows, m columns) and number of queries q, filled in by main().
// main() prints each query's answer as soon as it is computed, so no
// per-query storage is kept: large static arrays here would only cost memory
// (and, for arrays of vectors, one constructor/destructor per element).
int n,m,q;

int main()
{
	scanf("%d%d%d",&n,&m,&q);
	for(int i=1;i<=q;i++)
	{
		ll  t,sum;
		int L=0,R=n+1,M,p;
		scanf("%lld",&t);
		while(L+1!=R)
		{
			M=(L+R)>>1;
			if((ll)m*(ll)M*(M+1)/2<t) L=M;
			else R=M;
		}
		p=R;sum=(ll)m*(ll)L*(L+1)/2;
		for(int j=1;j<=m;j++)
		{
			sum+=p;
			if(sum>=t)
			{
				printf("%d %d\n",1,p);
				for(int k=1;k<p;k++) printf("%d ",k==sum-t?m-1:m);
				printf("%d\n",j);
				break;
			}
		}
	}

	return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

// Strict validator for the problem's input. The whole of std::cin is read
// into `buffer` on construction; the read* methods then consume it token by
// token and assert() on any deviation from the expected format.
struct input_checker {
    std::string buffer;  // the entire standard input
    int pos;             // index in `buffer` of the next character to consume

    // Character sets callers may pass as `pattern` to readString().
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads all of std::cin into `buffer`.
    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == -1) {  // end of input
                break;
            }
            buffer.push_back(static_cast<char>(c));
        }
    }

    // Index of the first ' ' or '\n' at or after `pos` (buffer.size() if none).
    // Does not move `pos`.
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    // Consumes and returns the characters from `pos` up to (not including) the
    // next delimiter. Asserts that some input remains; the token may be empty.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token of length in [minl, maxl]; if `pattern` is non-empty every
    // character must occur in it.
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // True iff `s` is an optional leading '-' followed by one or more decimal
    // digits and nothing else.
    static bool isIntegerToken(const std::string &s) {
        std::size_t i = (!s.empty() && s[0] == '-') ? 1 : 0;
        if (i == s.size()) {
            return false;
        }
        for (; i < s.size(); i++) {
            if (s[i] < '0' || s[i] > '9') {
                return false;
            }
        }
        return true;
    }

    // Reads an int token and asserts minv <= value <= maxv.
    // std::stoi on its own stops at the first non-digit (so "12abc" would be
    // accepted as 12) and tolerates a leading '+', hence the shape check.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isIntegerToken(token));
        int res = std::stoi(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Same as readInt, for 64-bit values.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(isIntegerToken(token));
        long long res = std::stoll(token);
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int n = in.readInt(2, 30000);
    in.readSpace();
    int m = in.readInt(2, 30000);
    in.readSpace();
    int q = in.readInt(1, 10);
    in.readEoln();
    vector<long long> s(q);
    for (int i = 0; i < q; i++) {
        s[i] = in.readLong(1, m * 1LL * n * (n + 1) / 2);
        (i == q - 1 ? in.readEoln() : in.readSpace());
    }
    for (auto t : s) {
        long long now = 0;
        vector<int> l;
        for (int i = 0; i < n; i++) {
            if (now + (i + 1) <= t) {
                now += i + 1;
                l.emplace_back(1);
            } else {
                break;
            }
        }
        for (int i = (int) l.size() - 1; i >= 0; i--) {
            if (now + (i + 1) * 1LL * (m - 1) <= t) {
                l[i] += m - 1;
                now += (i + 1) * 1LL * (m - 1);
            }
            while (now + (i + 1) <= t && l[i] < m) {
                l[i]++;
                now += i + 1;
            }
        }
        cout << 1 << " " << l.size() << '\n';
        for (int i = 0; i < (int) l.size(); i++) {
            cout << l[i] << " \n"[i == (int) l.size() - 1];
        }
    }
    return 0;
}
Editorialist's code (Python)
n, m, q = map(int, input().split())
queries = list(map(int, input().split()))
for s in queries:
    # Take rows 1, 2, ... in full while the running total stays strictly below s.
    row = 1
    cursum = 0
    while cursum + m * row < s:
        cursum += m * row
        row += 1
    # Row `row` is the last one used: take the fewest of its cells that reach s.
    take = -((cursum - s) // row)  # == ceil((s - cursum) / row)
    cursum += row * take
    # 0 <= excess < row; a positive excess is removed by dropping one cell
    # (worth exactly `excess`) from the full row with that index.
    excess = cursum - s
    lengths = [m - 1 if r == excess else m for r in range(1, row)]
    lengths.append(take)
    print(1, row)
    print(' '.join(map(str, lengths)))
2 Likes

I did the exact same thing in Java, but it gave a runtime error. After the contest I verified it with C++ and the solution passed.
My java solution: CodeChef: Practical coding for everyone
My c++ solution: CodeChef: Practical coding for everyone
Here's my solution; if I have made a mistake, please state the cause after verifying.

Your Java code reads s as an int (instead of a long), which is an issue.

1 Like

// Forum-submitted Java attempt, quoted for discussion. This is a fragment:
// readInt(), readLong(), next() and `out` are helpers whose definitions are
// not shown here.
int n=readInt();
int m=readInt();
int q=readInt();
// sum = 1 + 2 + ... + n, i.e. the value of one full column of the grid.
long sum=((long)n*((long)n+1L))/2L;

next();
long[] arr=new long[q];
for(int i=0;i<q;i++)
arr[i]=readLong();
for(int i=0;i<q;i++)
{
long m2=(long)m;

long sum1=(long)sum;
// sum1 = m * sum = total of the whole grid.
sum1=sum1*m2;
// Cannot trigger: the constraints guarantee S <= M*N*(N+1)/2.
if(arr[i]>sum1)
out.println(-1);
else{
    // BUG: this greedy takes rows n, n-1, ..., 1 whenever they still fit and
    // silently skips rows that do not, so the chosen rows need not be
    // consecutive. The output, however, claims rows n1 .. n1+count-1.
    // Example: n = 3, S = 4 picks rows 3 and 1, but prints "1 2" / "1 1",
    // i.e. rows 1 and 2 with a sum of 1 + 2 = 3.
    if(sum>arr[i])
    {
       long n1=(long)n;
        long total=arr[i];
       StringBuilder s=new StringBuilder();
       long count=0L;
       while(n1>0)
       {
                if(total>=n1)
                {
                    s.append(1+" ");
                    total-=n1;
                    count++;
                }
                
                if(total==0)
                break;
                n1--;
       }
       StringBuilder s1=new StringBuilder();
       s1.append((n1)+" "+count);
       out.println(s1);
       out.println(s);
    }
    // S is an exact multiple of the column sum: take S/sum cells from every row.
    else if((arr[i]%sum)==0)
    {
        long n1=arr[i]/sum;
        //StringBuilder s1=new StringBuilder();
        StringBuilder s=new StringBuilder();
        for(int j=0;j<n;j++)
        s.append(n1+" ");
        out.println(1+" "+n);
        out.println(s);
    }
    // Here S > sum, so n3 >= 1 full columns are taken from every row, and the
    // remainder (< sum) is added as at most one extra cell per row, greedily
    // from row n down. Every row keeps >= 1 cell, so the rows stay consecutive.
    else
    {
        long n3=arr[i]/sum;
          long n1=(long)n;
        long total=arr[i]-(sum*n3);
       StringBuilder s=new StringBuilder();
       long[] a=new long[n];
       Arrays.fill(a,n3);
       long count=0;
       while(n1>0L)
       {
                if(total>=n1)
                {
                    a[(int)n1-1]+=1L;
                    total-=n1;
                    count++;
                }
                
                if(total==0)
                break;
                n1--;
       }
       for(int j=0;j<n;j++)
       s.append(a[j]+" ");
       out.println(1+" "+n);
       out.println(s);
    }
}

Can anyone tell where this code fails?

This code will fail if the input is more than the maximum achievable sum, i.e., if one of the query values is 100, it should print ‘-1’, but it does not.

The constraint 1\leq S_i \leq N\cdot M \cdot \frac{N+1}{2} guarantees that this case will never happen.

Why isn’t this code working?

#include<bits/stdc++.h>
		 
// Competitive-programming template: type aliases and shorthand macros.
// Most of them are not used by this program.
typedef long long int ll;
#define pb push_back
#define pf push_front
#define ss second
#define ff first
#define cY cout<<"YES\n"
#define cN cout<<"NO\n"
#define cy cout<<"Yes\n"
#define cn cout<<"No\n"
#define cyy cout<<"yes\n"
#define cnn cout<<"no\n"
using namespace std;
// farr(i,n): loop i over 0..n-1; fo(i,start,end): loop i over start..end-1.
#define farr(i,n) for(int i=0; i<n; i++)
#define fo(i, start, end) for(int i=start; i<end; i++)
#define all(x) (x).begin(), (x).end()
typedef vector<int> vint;
typedef vector<long long int> vll;
typedef set<int> sint;
typedef set<long long int> sll;
typedef pair<int,int> pint;
typedef pair<ll,ll> pll;
#define minv(a) *min_element(all(a))
#define maxv(a) *max_element(all(a))

// Returns a / b (C++ truncating division) plus one when b does not divide a
// exactly -- i.e. ceil(a / b) for positive a and b. Not called in this program.
ll sifdiv(ll a, ll b){
	const ll quotient = a / b;
	const bool inexact = (a % b != 0);
	return inexact ? quotient + 1 : quotient;
}
#define MOD 1000000007
#define PI acos(-1)

// Appends to `a` the number of cells to take from each row, bottom row first.
// presum[r] is the total of rows 1..r taken in full. Walking from row n up to
// row 1, whenever the remaining target s exceeds what rows above could give,
// this row must contribute: take the fewest cells (ceil) that bring the
// remainder down to at most presum[row-1]. `m` is not used.
void soln(vll &presum, vint &a, int n, int m, ll s){
    int row = n;
    while (row >= 1) {
        const ll above = presum[row - 1];
        if (s > above) {
            const ll cells = (s - above + row - 1) / row;
            a.push_back(cells);
            s -= row * cells;
        }
        --row;
    }
}
 
 
 
int main(){
	#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
    #endif
	ios::sync_with_stdio(false);
	cin.tie(NULL); cout.tie(NULL);
	int tc, n, m;
	cin>>n>>m;
	// presum[i] = m * (1 + 2 + ... + i): the total of rows 1..i taken in full.
	// With n and m up to 3*10^4 this reaches ~1.35*10^13, so the product must be
	// formed in 64-bit arithmetic; computing m*(i+1)*i in int overflows (UB).
	vll presum(n+1,0);
	farr(i, n+1){
	    presum[i]=(ll)m*(i+1)*i/2;
	}
	cin>>tc;
	while(tc--){
	    vint a;
	    ll s;
	    cin>>s;
		soln(presum, a, n, m, s);
		// soln() appends rows from the bottom up, so print them in reverse
		// to list row 1 first.
		cout<<1<<" "<<a.size()<<"\n";
		for(int i=(int)a.size()-1; i>=0;i--){
		    cout<<a[i]<<" ";
		}
		cout<<"\n";
	}
	return 0;
}

Kindly let me know my error here. I have been trying this for a day. I tried various inputs and the output is correct, but the submission is not accepted for even one subtask.

// Forum-submitted attempt, quoted for discussion (its #includes are not shown).
// Lines marked BUG are concrete defects; the NOTE is an unverified concern.
int main() {
    int m,n,q;
    cin>>n>>m>>q;
    for(int i=1;i<=q;i++){
        // BUG: `int arr[n][m]` is a variable-length array (a compiler extension,
        // not standard C++). With N and M up to 30000 it is ~9*10^8 ints on the
        // stack, far beyond any stack limit; it is filled below but never read.
        // BUG: S can be as large as M*N*(N+1)/2 (~1.35*10^13 for N = M = 30000),
        // which does not fit in `int s`; m*((n*(n+1)/2)) below overflows int too.
        int arr[n][m],p,k,s;
        cin>>s;
        if(s>m*((n*(n+1)/2)) || s<0) {cout<<"-1"<<endl; continue;}
        for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) arr[i-1][j-1]=i;
        int count=s/((n*(n+1)/2));
        int count2=s%((n*(n+1)/2));
        int kn[n],last[n];
        for(int i=0;i<n;i++) {kn[i]=count; last[i]=i+1;}
        for(int i=0;i<n;i++){
        	int j=0;
        	// BUG: kn[j] is read before `j<n` is tested, so once j reaches n
        	// while count2 != 0 this reads one element past the end of kn.
        	while(count2!=0 && kn[j]<m && j<n){
        		if(last[j]<=count2){count2=count2-last[j];
        		kn[j]++;} j++;}
        }
        
        if(count2!=0){
        	// NOTE(review): this fallback removes one cell from each of rows
        	// 1..n-1 and re-adds rows greedily from the bottom; it is not evident
        	// that the counts stay in [0, m] or that the non-zero rows stay
        	// consecutive -- unverified.
        	for(int i=1;i<n;i++) {count2+=i; kn[i-1]--;}
        	for(int i=n-1;i>=0;i--){if(last[i]<=count2) {count2-=last[i]; kn[i]++;}}
        }
        //cout<<count2<<' '<<endl;;	

        // BUG: p and k are never initialised; if every kn[i] is 0 this loop
        // never assigns them and they are read uninitialised below.
        for(int i=0;i<n;i++) if(kn[i]!=0){p=i+1;k=n-i; break;}
        for(int i=n-1;i>=0;i--) {if(kn[i]==0) k--; else break;}
       	if(p!=0 && k!=0) cout<<p<<' '<<k<<endl;
       	else cout<<"-1";
       	for(int i=p-1;i<p+k-1;i++) cout<<kn[i]<<' ';
       	cout<<endl;

     
    }
	return 0;
}

I have a different greedy strategy; I will describe it below (the proof of correctness is left to the reader as an exercise).
First take all the cells, i.e., start with the sum n*(n+1)/2 * m. Now we have to exclude the extra part: suppose we want to remove x from this to get the sum s (for the current query); then my claim is:
First remove as many n’s (elements from the last row) as you can, then move on to n-1 and again remove as many as you can, and so on, until you have removed as much as you wanted or you have reached 0. Any possible x can be achieved this way. There is a classical problem like this: you have all the denominations 1, 2, …, n and you want to make a sum x out of them using the minimum number of coins; the greedy strategy of choosing the maximum value while you can and then moving on to the next maximum is indeed correct! The process that I described in my greedy solution to the original question is exactly the same.