GET_THE_SEG - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Testers: Nishank Suresh, Takuki Kurokawa
Editorialist: Nishank Suresh

DIFFICULTY:

2961

PREREQUISITES:

Binary Search

PROBLEM:

There is a hidden array containing only the integers 1 and 2. You want to find a subarray of this array with sum K. To achieve this, you can ask queries for the sum of any subarray.

Achieve the goal using at most 40 queries.

EXPLANATION:

The small number of queries (and the fact that this is an interactive task) should immediately lead you to look at binary search. But what to binary search on?

I will use sum(L, R) to denote the sum of the subarray A[L\ldots R].

Suppose we fix the left endpoint of the subarray, L. Binary search allows us to find the largest index R such that sum(L, R) \leq K.
The input guarantees that a subarray with sum K exists, so if we repeat this for every index L, we will eventually find the answer. However, this requires \mathcal{O}(N\log N) queries in the worst case, far more than the 40 we are allowed.
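As a minimal sketch of that inner binary search, the code below runs against a hypothetical offline stand-in for the interactive judge (the `Judge` struct and `largestR` are illustrative names, not part of the problem's protocol). It assumes A_L \leq K, so that R = L is always a valid starting point.

```cpp
#include <vector>

// Hypothetical offline stand-in for the interactive judge (not the real
// protocol): A[1..N] is the hidden array (A[0] is unused) and every
// sumQuery call is counted as one query.
struct Judge {
    std::vector<int> A;
    int queries = 0;
    explicit Judge(const std::vector<int>& a) : A(a) {}
    int sumQuery(int L, int R) {  // sum(L, R), 1-indexed, inclusive
        ++queries;
        int s = 0;
        for (int i = L; i <= R; ++i) s += A[i];
        return s;
    }
};

// Largest R in [L, N] with sum(L, R) <= K, assuming A[L] <= K.
// All entries are positive, so sum(L, R) strictly increases with R and the
// predicate "sum(L, R) <= K" is true exactly on a prefix of [L, N].
int largestR(Judge& judge, int L, int K) {
    int N = (int)judge.A.size() - 1;
    int lo = L, hi = N;  // invariant: sum(L, lo) <= K
    while (lo < hi) {
        int mid = (lo + hi + 1) / 2;  // round up so "lo = mid" always progresses
        if (judge.sumQuery(L, mid) <= K) lo = mid;
        else hi = mid - 1;
    }
    return lo;
}
```

Each call costs about \log_2 N queries, which is why repeating it for every L is too expensive.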

Instead, let’s first analyze a simpler case. Fix L = 1 and find the appropriate R. There are two possibilities:

  • First, if sum(1, R) = K, we are done and can immediately report the answer.
  • Otherwise, the only possibility is sum(1, R) = K-1, and A_{R+1} = 2. (Do you see why?)
  • In the second case, if A_1 = 1, it’s easy to see that sum(2, R+1) = K and we are done. However, if A_1 = 2, we get no further information.
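The identity behind the case A_1 = 1, followed by a small worked instance (the array A = (1, 2, 2, 1) and K = 4 are chosen purely for illustration):

```latex
\begin{aligned}
\operatorname{sum}(2, R+1) &= \operatorname{sum}(1, R) - A_1 + A_{R+1} = (K-1) - 1 + 2 = K. \\[4pt]
\text{Example: } A &= (1, 2, 2, 1),\ K = 4:\quad \operatorname{sum}(1,2) = 3 \le 4 < 5 = \operatorname{sum}(1,3) \;\Rightarrow\; R = 2, \\
\operatorname{sum}(1, 2) &= 3 = K - 1,\quad A_3 = 2,\quad A_1 = 1 \;\Rightarrow\; \operatorname{sum}(2, 3) = 2 + 2 = 4 = K.
\end{aligned}
```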

Note that the above results apply more generally to any index L such that sum(L, N) \geq K.
For any such index, if we find the largest R such that sum(L, R) \leq K, then as long as A_L = 1 either (L, R) or (L+1, R+1) is the answer.
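One way to gain confidence in this claim (a sanity check on small cases, not a proof) is to test it exhaustively. The hypothetical helper `claimHoldsForAllArrays` below enumerates every array of 1s and 2s up to a given length, and for every L with A_L = 1 and every K \leq sum(L, N) it finds R by brute force and checks that (L, R) or (L+1, R+1) has sum K.

```cpp
#include <vector>

// Exhaustive check of the claim over every array in {1,2}^n, 1 <= n <= maxN.
// R is found by a linear scan here (brute force), not by binary search.
bool claimHoldsForAllArrays(int maxN) {
    for (int n = 1; n <= maxN; ++n) {
        for (int mask = 0; mask < (1 << n); ++mask) {
            // Bit i-1 of mask decides A[i]: 0 -> 1, 1 -> 2. pre[i] = sum(1, i).
            std::vector<int> pre(n + 1, 0);
            for (int i = 1; i <= n; ++i)
                pre[i] = pre[i - 1] + (((mask >> (i - 1)) & 1) ? 2 : 1);
            auto sum = [&](int l, int r) { return pre[r] - pre[l - 1]; };
            for (int L = 1; L <= n; ++L) {
                if (sum(L, L) != 1) continue;           // need A[L] = 1
                for (int K = 1; K <= sum(L, n); ++K) {  // need sum(L, N) >= K
                    int R = L;                          // largest R with sum(L, R) <= K
                    while (R + 1 <= n && sum(L, R + 1) <= K) ++R;
                    bool ok = sum(L, R) == K ||
                              (R + 1 <= n && sum(L + 1, R + 1) == K);
                    if (!ok) return false;
                }
            }
        }
    }
    return true;
}
```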

This tells us that we need to find a 1 in the array: ideally, the leftmost 1.
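Finding the leftmost 1 can be done easily with binary search: find the first position p such that sum(1, p) \neq 2p. This works because sum(1, p) = 2p holds exactly when A_1, \ldots, A_p are all 2, so the condition is monotone in p.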
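A minimal sketch of this search, again against a hypothetical offline `Judge` stand-in (not the real interactive protocol). The illustrative function `leftmostOne` returns N + 1 when the array contains no 1 at all, so that case is detected for free.

```cpp
#include <vector>

// Hypothetical offline stand-in for the judge: A[1..N] is the hidden array
// (A[0] unused); each sumQuery call counts as one query.
struct Judge {
    std::vector<int> A;
    int queries = 0;
    explicit Judge(const std::vector<int>& a) : A(a) {}
    int sumQuery(int L, int R) {
        ++queries;
        int s = 0;
        for (int i = L; i <= R; ++i) s += A[i];
        return s;
    }
};

// First p with sum(1, p) != 2p (the leftmost 1), or N + 1 if all entries are 2.
int leftmostOne(Judge& judge) {
    int N = (int)judge.A.size() - 1;
    int lo = 1, hi = N + 1;  // the answer always lies in [lo, hi]
    while (lo < hi) {
        int mid = (lo + hi) / 2;  // mid <= N, so the query is always valid
        if (judge.sumQuery(1, mid) == 2 * mid) lo = mid + 1;  // A[1..mid] all 2s
        else hi = mid;
    }
    return lo;
}
```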

Now that we have our 1 at position p, if sum(p, N) \geq K we can solve the problem with another binary search, as described above. However, this need not always be the case: what happens when sum(p, N) \lt K?

The important thing to note in this case is that everything to the left of p is a 2. So,

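  • If K and sum(p, N) have the same parity, simply extend the subarray [p, N] to the left by (K - sum(p, N))/2 positions, giving the answer [p - (K - sum(p, N))/2, N]. This can be done without any queries at all, since only the values of K and sum(p, N) need to be known. There are enough 2's to the left of p because a valid subarray exists, so K \leq sum(1, N) = 2(p-1) + sum(p, N).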
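  • Otherwise, the answer subarray must exclude the last 1 in the array. It has to start before p (every subarray of [p, N] has sum at most sum(p, N) \lt K), so if it also ended at or after the last 1, its sum would differ from sum(p, N) only by some 2's and would have the wrong parity. Excluding the last 1 is the only way to change the parity of the sum.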
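The same-parity case is pure arithmetic. A minimal sketch, where `sameParityAnswer` is an illustrative name and s stands for sum(p, N), already obtained with one query:

```cpp
#include <utility>

// Same-parity case: s = sum(p, N) < K and s, K have the same parity.
// Positions 1..p-1 all hold 2, so prepending t = (K - s) / 2 of them raises
// the sum from s to exactly K. No further queries are needed.
std::pair<int, int> sameParityAnswer(int p, int N, int s, int K) {
    int t = (K - s) / 2;  // number of 2s to prepend
    // t <= p - 1 whenever a valid answer exists, since
    // K <= sum(1, N) = 2 (p - 1) + s.
    return {p - t, N};
}
```

For example, with A = (2, 2, 2, 1, 2) we have p = 4 and s = 3, so K = 7 gives [2, 5] and K = 5 gives [3, 5].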

To deal with the second case, once again we binary search to find the position of the last 1 in the array: say this is q.
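Then, sum(p, q-1) and K have the same parity (if q = p, the range [p, q-1] is empty and its sum is 0), so we can extend p to the left by adding 2's until we get a sum of K. Enough 2's are available because the answer lies inside [1, q-1], so K \leq sum(1, q-1).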
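A minimal sketch of the opposite-parity case against a hypothetical offline `Judge` stand-in. The names `lastOne` and `oppositeParityAnswer` are illustrative. The sketch assumes A_p = 1 (p comes from the leftmost-1 search) and skips the query entirely when q = p, since [p, q-1] is then empty.

```cpp
#include <utility>
#include <vector>

// Hypothetical offline stand-in for the judge: A[1..N] is the hidden array
// (A[0] unused); each sumQuery call counts as one query.
struct Judge {
    std::vector<int> A;
    int queries = 0;
    explicit Judge(const std::vector<int>& a) : A(a) {}
    int sumQuery(int L, int R) {
        ++queries;
        int s = 0;
        for (int i = L; i <= R; ++i) s += A[i];
        return s;
    }
};

// Last position q holding a 1, assuming A[p] = 1: the largest q >= p whose
// suffix sum(q, N) differs from 2 (N - q + 1), i.e. is not all 2s.
int lastOne(Judge& judge, int p) {
    int N = (int)judge.A.size() - 1;
    int lo = p, hi = N;  // invariant: [lo, N] contains a 1
    while (lo < hi) {
        int mid = (lo + hi + 1) / 2;
        if (judge.sumQuery(mid, N) == 2 * (N - mid + 1)) hi = mid - 1;  // all 2s
        else lo = mid;
    }
    return lo;
}

// Opposite-parity case: the answer is [p - (K - s) / 2, q - 1] with
// s = sum(p, q - 1), taken as 0 (no query sent) when q == p.
std::pair<int, int> oppositeParityAnswer(Judge& judge, int p, int K) {
    int q = lastOne(judge, p);
    int s = (q > p) ? judge.sumQuery(p, q - 1) : 0;
    return {p - (K - s) / 2, q - 1};
}
```

For A = (2, 2, 1, 2, 1, 2) and K = 7 this finds q = 5 and s = 3, giving [1, 4].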

Note that in every case we use at most two binary searches: one to find p, and then either one to find q, or one to find R (such that sum(p, R) \leq K).

This gives us about 2 \cdot \lceil \log_2{10^5} \rceil = 34 queries, along with a constant number more for any checks needed in between. My implementation linked below uses 35 queries in the worst case, so the limit of 40 is slightly lenient.
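To sanity-check the whole case analysis and the query budget together, here is an illustrative reimplementation (not the setter's, tester's, or editorialist's code) against a hypothetical offline `Judge` with prefix sums. `solve` and `checkAll` are hypothetical names. Like the editorialist's code, it derives sum(p, q-1) = sum(p, N) - 1 - 2(N - q) arithmetically instead of querying it. The exhaustive check only covers arrays up to length 10 plus one large example, so it supports but does not prove the 35-query bound.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Hypothetical offline judge: the hidden array is given 0-indexed, queries are
// 1-indexed and inclusive, and each ask() call is counted.
struct Judge {
    std::vector<int> pre;  // pre[i] = sum(1, i)
    int queries = 0;
    explicit Judge(const std::vector<int>& A) : pre(A.size() + 1, 0) {
        for (size_t i = 0; i < A.size(); ++i) pre[i + 1] = pre[i] + A[i];
    }
    int n() const { return (int)pre.size() - 1; }
    int ask(int L, int R) { ++queries; return pre[R] - pre[L - 1]; }
};

// Returns (l, r) with sum(l, r) == K, following the editorial's cases.
std::pair<int, int> solve(Judge& J, int K) {
    int N = J.n();
    int lo = 1, hi = N + 1;  // leftmost 1, or N + 1 if all 2s
    while (lo < hi) {
        int mid = (lo + hi) / 2;
        if (J.ask(1, mid) == 2 * mid) lo = mid + 1; else hi = mid;
    }
    int p = lo;
    if (p == N + 1) return {1, K / 2};  // all 2s, so K must be even
    int s = J.ask(p, N);
    if (s >= K) {
        // Largest R with sum(p, R) <= K; best tracks sum(p, l).
        int l = p, h = N, best = 1;  // sum(p, p) = A[p] = 1
        while (l < h) {
            int mid = (l + h + 1) / 2;
            int v = J.ask(p, mid);
            if (v <= K) { l = mid; best = v; } else h = mid - 1;
        }
        if (best == K) return {p, l};
        return {p + 1, l + 1};
    }
    if (s % 2 == K % 2) return {p - (K - s) / 2, N};  // prepend 2s
    int l = p, h = N;  // last 1 (q)
    while (l < h) {
        int mid = (l + h + 1) / 2;
        if (J.ask(mid, N) == 2 * (N - mid + 1)) h = mid - 1; else l = mid;
    }
    int q = l;
    int t = s - 1 - 2 * (N - q);  // sum(p, q - 1): everything after q is a 2
    return {p - (K - t) / 2, q - 1};
}

// Every array in {1,2}^n with n <= maxN and every achievable K;
// maxQueries receives the worst query count seen.
bool checkAll(int maxN, int& maxQueries) {
    maxQueries = 0;
    for (int n = 1; n <= maxN; ++n)
        for (int mask = 0; mask < (1 << n); ++mask) {
            std::vector<int> A(n);
            for (int i = 0; i < n; ++i) A[i] = ((mask >> i) & 1) ? 2 : 1;
            Judge ref(A);  // brute-force reference; its count is ignored
            for (int K = 1; K <= 2 * n; ++K) {
                bool exists = false;
                for (int l = 1; l <= n && !exists; ++l)
                    for (int r = l; r <= n && !exists; ++r)
                        exists = ref.ask(l, r) == K;
                if (!exists) continue;
                Judge J(A);
                std::pair<int, int> ans = solve(J, K);
                maxQueries = std::max(maxQueries, J.queries);
                if (ans.first < 1 || ans.second > n || ans.first > ans.second ||
                    ref.ask(ans.first, ans.second) != K)
                    return false;
            }
        }
    return true;
}
```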

TIME COMPLEXITY:

\mathcal{O}(\log N) per test case.

CODE:

Setter's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
typedef double db; 
typedef long long ll;
typedef unsigned long long ull;
const int N=100000;
const int LOGN=28;
const ll  TMD=0;
const ll  INF=2147483647;
int T;

int query(int L,int R)
{
	int tmp;
	printf("1 %d %d\n",L,R);
	fflush(stdout);
	scanf("%d",&tmp);
	return tmp;
}

void answer(int L,int R)
{
	int tmp;
	printf("2 %d %d\n",L,R);
	fflush(stdout);
	scanf("%d",&tmp);
}

int main()
{
	scanf("%d",&T);
	while(T--)
	{
		int n,k,L,R,M,pos1,pos2;
		scanf("%d%d",&n,&k);
		L=0;R=n+1;
		while(L+1!=R)
		{
			M=(L+R)>>1;
			if(query(1,M)==2*M) L=M;
			else R=M;
		}
		if(R==n+1) answer(1,k/2);
		else
		{
			pos1=R;
			if(query(pos1,n)>=k)
			{
				L=pos1;R=n+1;
				while(L+1!=R)
				{
					M=(L+R)>>1;
					if(query(pos1,M)<=k) L=M;
					else R=M;
				}
				if(query(pos1,L)==k) answer(pos1,L);
				else answer(pos1+1,L+1);
			}
			else
			{
				if((query(pos1,n)&1)^(k&1))
				{
					L=0;R=n+1;
					while(L+1!=R)
					{
						M=(L+R)>>1;
						if(query(M,n)==2*(n-M+1)) R=M;
						else L=M;
					}
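					pos2=L;
					// if the only 1 is at pos1, [pos1, pos2-1] is empty: use 0 instead of an invalid query
					int inner=(pos2>pos1)?query(pos1,pos2-1):0;
					answer(pos1-(k-inner)/2,pos2-1);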
				}
				else answer(pos1-(k-query(pos1,n))/2,n);
			}
		}
	}
	

	return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

int main() {
    auto Ask = [&](int c, int l, int r) {
        cout << c << " " << l + 1 << " " << r + 1 << endl;
        cin >> l;
        return l;
    };
    int tt;
    cin >> tt;
    while (tt--) {
        int n, k;
        cin >> n >> k;
        if (Ask(1, 0, n - 1) == 2 * n) {
            assert(Ask(2, 0, k / 2 - 1) == 1);
            continue;
        }
        int l;
        {
            int low = -1;
            int high = n - 1;
            while (high - low > 1) {
                int mid = (high + low) >> 1;
                if (Ask(1, 0, mid) < 2 * (mid + 1)) {
                    high = mid;
                } else {
                    low = mid;
                }
            }
            l = high;
        }
        int s = Ask(1, l, n - 1);
        if (s >= k) {
            int low = l - 1;
            int high = n - 1;
            while (high - low > 1) {
                int mid = (high + low) >> 1;
                if (Ask(1, l, mid) >= k) {
                    high = mid;
                } else {
                    low = mid;
                }
            }
            if (Ask(1, l, high) == k + 1) {
                l++;
            }
            assert(Ask(2, l, high) == 1);
        } else if (s % 2 == k % 2) {
            assert(Ask(2, l - (k - s) / 2, n - 1) == 1);
        } else {
            int r;
            {
                int low = 0;
                int high = n;
                while (high - low > 1) {
                    int mid = (high + low) >> 1;
                    if (Ask(1, mid, n - 1) < 2 * (n - mid)) {
                        low = mid;
                    } else {
                        high = mid;
                    }
                }
                r = low;
            }
            if (l == r) {
                if (k / 2 <= l) {
                    assert(Ask(2, 0, k / 2 - 1));
                } else {
                    assert(Ask(2, n - k / 2, n - 1));
                }
            } else {
                s = Ask(1, l, r - 1);
                assert(Ask(2, l - (k - s) / 2, r - 1) == 1);
            }
        }
    }
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	auto ask = [&] (int L, int R) {
		cout << 1 << ' ' << L << ' ' << R << endl;
		int res; cin >> res;
		assert(res != -1);
		return res;
	};
	auto ans = [&] (int L, int R) {
		cout << 2 << ' ' << L << ' ' << R << endl;
		int res; cin >> res;
		assert(res == 1);
	};

	int t; cin >> t;
	while (t--) {
		int n, k; cin >> n >> k;

		// find a 1
		int first1;
		{
			int lo = 1, hi = n;
			while (lo < hi) {
				int mid = (lo + hi)/2;
				int sum = ask(1, mid);
				if (sum == 2*mid) lo = mid+1;
				else hi = mid;
			}
			first1 = lo;
		}
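		int suf = ask(first1, n);
		// no 1 at all: the search above stops at n, and A[n] = 2
		if (first1 == n && suf == 2) {
			ans(1, k/2);
			continue;
		}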
		
		// answer immediately if possible
		if (suf >= k) {
			int lo = first1, hi = n, val = 1;
			while (lo < hi) {
				int mid = (lo + hi + 1)/2;
				int sum = ask(first1, mid);
				if (sum <= k) {
				    lo = mid;
				    val = sum;
				}
				else hi = mid - 1;
			}
			if (val == k) ans(first1, lo);
			else ans(first1+1, lo+1);
			continue;
		}
		if (suf%2 == k%2) {
			int reqd = (k - suf)/2;
			ans(first1 - reqd, n);
			continue;
		}
		suf += 2 * (first1 - 1);

		// find last 1
		int last1;
		{
			int lo = first1, hi = n;
			while (lo < hi) {
				int mid = (lo + hi + 1)/2;
				int sum = ask(mid, n);
				if (sum == 2*(n - mid + 1)) hi = mid-1;
				else lo = mid;
			}
			last1 = lo;
		}

		suf -= 2 * (n - last1) + 1;
		int reqd = (suf - k)/2;
		ans(1 + reqd, last1 - 1);
	}
}