Author: grayhathacker
Tester: jay_1048576
Editorialist: iceknight1093
DIFFICULTY:
1868
PREREQUISITES:
Binary search, prefix sums
PROBLEM:
For an array A of length N, the multiset S = \{\min(A_i, A_j, A_k) \mid 1 \leq i \lt j \lt k \leq N\} is called its triplet array.
You’re given an array A and Q queries.
For each query, given an integer K, output the K-th smallest element of the triplet array of A.
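To make the definition concrete, here is a minimal brute-force sketch that builds the triplet array directly. The function name `triplet_array` and the sample array are chosen for illustration only. It runs in \mathcal{O}(N^3), so it is only useful for checking small cases by hand.

```python
from itertools import combinations

def triplet_array(a):
    """Sorted list of min(A_i, A_j, A_k) over all index triples i < j < k."""
    # combinations() yields every choice of 3 positions in increasing index order.
    return sorted(min(t) for t in combinations(a, 3))

# Illustrative array, not taken from the problem's samples.
print(triplet_array([3, 1, 4, 2]))  # [1, 1, 1, 2]
```

With this array, the 1st to 3rd smallest elements are 1 and the 4th is 2.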
EXPLANATION:
First, let’s attempt to answer a single query reasonably quickly.
Let’s sort A, so that A_1 \leq A_2 \leq \ldots \leq A_N. This doesn’t change the triplet array.
Then, we have the following:
- There are \binom{N-1}{2} triplets whose smallest index is 1, and hence whose minimum value is A_1 (pick the other two indices out of 2, 3, 4, \ldots, N).
- There are \binom{N-2}{2} triplets whose smallest index is 2, and hence whose minimum value is A_2 (pick the other two indices out of 3, 4, \ldots, N).
- \vdots
- In general, there are \binom{N-i}{2} triplets whose smallest index is i, and hence whose minimum value is A_i, for any i.
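The counting argument above can be checked on a small case. The sketch below is only a sanity check: `counts_by_position` is a hypothetical helper name, positions are 0-indexed (so position i here corresponds to A_{i+1}), and `math.comb` needs Python 3.8 or newer.

```python
from itertools import combinations
from math import comb

def counts_by_position(n):
    """counts[i] = number of index triples whose smallest index is i (0-indexed)."""
    return [comb(n - 1 - i, 2) for i in range(n)]

n = 4  # illustrative size
counts = counts_by_position(n)
print(counts)  # [3, 1, 0, 0]

# Brute-force cross-check: in a sorted array the minimum of a triple
# sits at its smallest index, so we tally triples by that index.
brute = [0] * n
for t in combinations(range(n), 3):
    brute[t[0]] += 1
print(brute == counts)  # True
```

The last two counts are always 0, since no triplet can start at one of the last two positions.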
This already gives us an algorithm in \mathcal{O}(N) to solve for a single K: all we need to do is find the smallest i such that
\binom{N-1}{2} + \binom{N-2}{2} + \ldots + \binom{N-i}{2} \geq K
and we know that the answer is A_i.
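A minimal sketch of this linear scan for a single query might look as follows. The function name `kth_smallest_linear` is hypothetical, the code is 0-indexed, and it assumes 1 \leq K \leq \binom{N}{3}.

```python
def kth_smallest_linear(a, k):
    """Return the k-th smallest (1-indexed) element of the triplet array of a."""
    a = sorted(a)
    n = len(a)
    running = 0  # number of triplets whose minimum is among a[0..i]
    for i in range(n):
        m = n - 1 - i               # elements to the right of position i
        running += m * (m - 1) // 2  # C(m, 2) ways to pick the other two
        if running >= k:
            return a[i]
    raise ValueError("k exceeds the number of triplets")

# Illustrative array, not taken from the problem's samples.
print([kth_smallest_linear([3, 1, 4, 2], k) for k in range(1, 5)])  # [1, 1, 1, 2]
```

The sort is done inside the function only to keep the sketch self-contained; with many queries it should be done once up front.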
To speed this up, we can use binary search and prefix sums.
In particular, let P_i = \binom{N-1}{2} + \binom{N-2}{2} + \ldots + \binom{N-i}{2} be the prefix sums of the counts we calculated above.
For a fixed K, we want to find the smallest i such that P_i \geq K.
Since P_i \leq P_{i+1} for all i (every count is non-negative, and the last two counts are 0), the P array is sorted, so this can be done by just binary searching on it!
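As a side note that is not part of the original solution, the hockey-stick identity gives a closed form for these prefix sums, which can serve as a sanity check on a computed P array:

```latex
P_i = \sum_{j=1}^{i} \binom{N-j}{2} = \binom{N}{3} - \binom{N-i}{3}
```

In particular P_N = \binom{N}{3}, the total number of triplets, so every valid K satisfies 1 \leq K \leq P_N.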
Now we’re able to answer a single query in \mathcal{O}(\log N) time after \mathcal{O}(N\log N) precomputation (sorting A, then building P in \mathcal{O}(N)).
Since the precomputation remains the same across all queries, this gives us a fairly simple solution in \mathcal{O}(N\log N + Q\log N), which is enough to get AC.
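Putting the pieces together, one possible sketch of the precompute-then-query approach is shown below. The helper names `build_prefix` and `answer_queries` are hypothetical, indices are 0-based, and every K is assumed to lie in 1 \leq K \leq \binom{N}{3}.

```python
from bisect import bisect_left
from itertools import accumulate

def build_prefix(n):
    """prefix[i] = number of triplets whose smallest index is at most i (0-indexed)."""
    return list(accumulate((n - 1 - i) * (n - 2 - i) // 2 for i in range(n)))

def answer_queries(a, ks):
    """Answer each K in ks with the K-th smallest element of the triplet array."""
    a = sorted(a)                  # done once, shared by all queries
    prefix = build_prefix(len(a))  # non-decreasing, so binary search applies
    # bisect_left finds the first position whose prefix count is >= K.
    return [a[bisect_left(prefix, k)] for k in ks]

# Illustrative array, not taken from the problem's samples.
print(answer_queries([3, 1, 4, 2], [1, 2, 3, 4]))  # [1, 1, 1, 2]
```

Here `bisect_left(prefix, k)` plays the same role as `lower_bound` in the C++ solutions and as `bisect_right(counts, k-1)` in the Python one.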
TIME COMPLEXITY:
\mathcal{O}(N\log N + Q\log N) per testcase (the \mathcal{O}(N\log N) term comes from sorting A).
CODE:
Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
#define mod 1000000007
typedef set<string> ss;
typedef vector<int> vs;
typedef map<int, char> msi;
typedef pair<int, int> pa;
typedef long long int ll;
ll n, q, i, a[300005], cnt[300005], k;
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
#ifndef ONLINE_JUDGE
    freopen("inputf.in", "r", stdin);
    // freopen("output.txt", "w", stdout);
#endif
    int t;
    cin >> t;
    while (t--)
    {
        cin >> n >> q;
        for (i = 0; i < n; i++)
            cin >> a[i];
        sort(a, a + n);
        // cnt[i] = number of triplets whose minimum is among a[0..i] (prefix sums of C(n-i-1, 2))
        for (i = 0; i < n; i++)
        {
            cnt[i] = (n - i - 1) * (n - i - 2) / 2;
            if (i > 0)
                cnt[i] += cnt[i - 1];
        }
        while (q--)
        {
            cin >> k;
            // first index whose prefix count reaches k
            cout << a[lower_bound(cnt, cnt + n, k) - cnt] << "\n";
        }
    }
    return 0;
}
Tester's code (C++)
/*...................................................................*
*............___..................___.....____...______......___....*
*.../|....../...\........./|...../...\...|.............|..../...\...*
*../.|...../.....\......./.|....|.....|..|.............|.../........*
*....|....|.......|...../..|....|.....|..|............/...|.........*
*....|....|.......|..../...|.....\___/...|___......../....|..___....*
*....|....|.......|.../....|...../...\.......\....../.....|./...\...*
*....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
*....|.....\...../.........|....|.....|.......|.../........\...../..*
*..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
*...................................................................*
*/
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 1000000000000000000
#define MOD 1000000007
void solve(int tc)
{
    int n,q;
    cin >> n >> q;
    int a[n];
    for(int i=0;i<n;i++)
        cin >> a[i];
    sort(a,a+n);
    int pre[n];
    for(int i=0;i<n;i++)
        pre[i]=(n-i-1)*(n-i-2)/2;
    for(int i=1;i<n;i++)
        pre[i]+=pre[i-1];
    while(q--)
    {
        int k;
        cin >> k;
        cout << a[lower_bound(pre,pre+n,k)-pre] << '\n';
    }
}
int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tc=1;
    cin >> tc;
    for(int ttc=1;ttc<=tc;ttc++)
        solve(ttc);
    return 0;
}
Editorialist's code (Python)
from bisect import bisect_right
def C2(x):
    return x * (x-1) // 2
for _ in range(int(input())):
    n, q = map(int, input().split())
    a = sorted(list(map(int, input().split())))
    counts = []
    for i in range(n): counts.append(C2(n-1-i))
    for i in range(1, n): counts[i] += counts[i-1]
    for i in range(q):
        k = int(input())
        pos = bisect_right(counts, k-1)
        print(a[pos])