PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh_07
Tester: apoorv_me
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Dynamic programming
PROBLEM:
You’re given an array C of length N.
Consider the following process for a permutation P of length N:
- Let sum = C_{P_1}, cnt = 1.
- For each i = 2, 3, \ldots, N, if C_{P_i} \geq sum/cnt, add C_{P_i} to sum and 1 to cnt.
For each K from 1 to \sum C_i, find out whether there exists a permutation P for which the final value of sum is exactly K.
EXPLANATION:
We essentially want to find, for each possible subset sum of C, whether there’s a way to obtain it as the result of this process.
Let’s make a few observations about what’s going on. Suppose S is some non-empty subset of elements that we want to choose.
- The average of the chosen items is non-decreasing, because each time we choose an item iff it’s not less than the current average.
- Consider some element x that’s not in S.
To skip x, we need to have a strictly higher average than x when we reach it.
Coupled with the observation that the average is non-decreasing, it’s clearly best to first take every element of S, and only then try to skip x.
- In particular, this means that only the largest missing element from S matters: if we’re able to skip it, we can definitely skip everything \leq it.
So, let’s fix a value y, and see what subset sums we can achieve where y is the largest missing element.
Since y is the largest missing element, certainly anything \gt y will be chosen.
Let there be m_1 such values, and B denote their sum.
Now, consider some subset sum K consisting of only values that are \leq y.
We want two things:
- K should be achievable as a subset sum at all, of course.
- Further, y should be skippable.
This means, if m_2 elements sum to K, our overall average = \frac{K+B}{m_1+m_2} should be strictly larger than y.
Since B, m_1, K are all fixed here, clearly our best option is to attempt to minimize m_2, i.e, minimize the number of elements that sum to K.
This requirement allows for a dynamic programming solution.
Sort the array C, so that C_i \leq C_{i+1}.
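Let dp_{i, x} denote the minimum size of a subset of the first i elements that sums up to x (or \infty if no such subset exists).
We have, depending on whether C_i is included or not:
dp_{i, x} = \min(dp_{i-1, x}, dp_{i-1, x - C_i} + 1)
with the base case dp_{0, 0} = 0 and dp_{0, x} = \infty for x \gt 0.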
Now, for each i such that C_i \neq C_{i+1}, do the following:
- Find m_1 and B, the number of elements \gt C_i and their sum.
- For each K from 0 to \sum C_i, let m_2 = dp_{i-1, K}.
Note that we use index i-1 to ensure that at least one copy of C_i is skipped.
- If C_i \lt \frac{B+K}{m_1 + m_2}, it’s possible to get a final score of B+K, so set ans_{B+K} = 1.
At the end of this process, ans holds our answer.
The complexity of this is \mathcal{O}(N\cdot \sum C_i), which is fast enough since \sum C_i is bounded by 10^5 and N is at most 1000.
Note that the dynamic programming sketched above has \mathcal{O}(N\cdot \sum C_i) states, which might use too much memory; however, it’s quite easy to reduce the memory used by a factor of N, since only the previous row needs to be stored (the time complexity remains the same, but using lots of memory often slows down code).
TIME COMPLEXITY:
\mathcal{O}(N\cdot \sum C_i) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// From here on every `int` token is rewritten to `long long`, so all integer
// arithmetic below is 64-bit (this is why main is declared as int32_t).
#define int long long
// Large sentinel value; expands to (long long)1e9 because of the macro above.
#define INF (int)1e9
// Shorthands for the members of std::pair.
#define f first
#define s second
// 64-bit Mersenne Twister seeded from the steady clock at start-up.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Capacities of the file-scope arrays declared below.
const int N = 1005;
const int M = 1e5 + 1;
// File-scope storage (zero-initialised at program start): `a` has room for N
// values and `dp` for M entries, i.e. indices 0..1e5.
int n, a[N], dp[M], sum;
// Solves one test case.
//
// Input  (stdin):  N, followed by the N values of C.
// Output (stdout): sum(C) characters followed by a newline; the K-th character
//                  is '1' when some permutation ends with sum == K, else '0'.
//
// Approach: sort C ascending. A chosen set is reachable iff its average is
// strictly greater than the largest element that was left out. We therefore
// enumerate the position `skip` of the largest left-out element: everything
// after it is forced in, and any subset of the elements before it may be
// added. For a fixed extra sum the average is maximised by using as few
// elements as possible, which is exactly what the knapsack table `fewest`
// records.
//
// All working storage is local and sized from the input, so the function does
// not rely on fixed-capacity buffers or on a zero-valued sentinel slot.
void Solve()
{
    long long n = 0;
    std::cin >> n;
    if (n < 0) n = 0;

    std::vector<long long> values(static_cast<std::size_t>(n));
    long long sum = 0;
    for (long long& v : values) {
        std::cin >> v;
        sum += v;
    }
    std::sort(values.begin(), values.end());

    // No subset can use more than n elements, so n + 1 marks "no such subset".
    const long long kUnreachable = n + 1;
    std::vector<long long> fewest(static_cast<std::size_t>(sum) + 1, kUnreachable);
    fewest[0] = 0;

    std::string reachable(static_cast<std::size_t>(sum) + 1, '0');
    // Visiting the elements in ascending order takes every one of them.
    reachable[static_cast<std::size_t>(sum)] = '1';

    long long suffix = sum;  // sum of values[skip+1 .. n-1], all forced in
    long long prefix = 0;    // sum of values[0 .. skip-1], the optional pool

    // The largest element can never be the skipped one (nothing could push the
    // average strictly above it), hence skip stops at n - 2.
    for (long long skip = 0; skip + 1 < n; ++skip) {
        const long long y = values[static_cast<std::size_t>(skip)];
        suffix -= y;
        const long long forced = n - 1 - skip;

        // Only sums up to `prefix` can be formed from the optional pool.
        for (long long extra = 0; extra <= prefix; ++extra) {
            const long long count = fewest[static_cast<std::size_t>(extra)];
            if (count == kUnreachable) continue;
            // y < (suffix + extra) / (forced + count), kept in integers.
            if ((count + forced) * y < suffix + extra) {
                reachable[static_cast<std::size_t>(suffix + extra)] = '1';
            }
        }

        // Make values[skip] available to later iterations (0/1 knapsack,
        // descending so each element is used at most once).
        for (long long j = prefix + y; j >= y; --j) {
            const std::size_t to = static_cast<std::size_t>(j);
            const std::size_t from = static_cast<std::size_t>(j - y);
            fewest[to] = std::min(fewest[to], fewest[from] + 1);
        }
        prefix += y;
    }

    // Index 0 is not part of the answer.
    std::cout << reachable.substr(1) << "\n";
}
// Entry point of the author's solution: reads the number of test cases and
// runs Solve() once for each of them.
int32_t main()
{
    // Decouple iostreams from C stdio and from each other for faster I/O.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    long long t = 1;
    std::cin >> t;
    for (long long i = 1; i <= t; i++)
    {
        Solve();
    }
    return 0;
}
Tester's code (C++)
// When LOCAL is not defined, request O3 with loop unrolling from GCC and
// enable the listed x86 instruction-set extensions.
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif
#include <bits/stdc++.h>
using namespace std;
// With LOCAL defined a debug helper header is included; otherwise dbg(...)
// expands to an inert string literal so debug calls compile to nothing useful.
#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif
// Strict tokenizer used to validate the format of the test data.
//
// The constructor reads all of stdin into `buffer`; the read* members then
// consume it from position `pos` and assert() on any deviation from the
// expected format (wrong separator, value out of range, malformed number,
// trailing data).
struct input_checker {
    std::string buffer;  // entire contents of stdin
    int pos;             // index of the next unread character in `buffer`

    // Ready-made character classes for readString's `pattern` argument.
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads stdin until end of file.
    input_checker() {
        pos = 0;
        while (true) {
            int c = std::cin.get();
            if (c == std::char_traits<char>::eof()) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Returns the index of the first whitespace character at or after `pos`
    // (or buffer.size() if there is none).
    int nextDelimiter() {
        int now = pos;
        // isspace requires a value representable as unsigned char; passing a
        // negative plain char would be undefined behaviour.
        while (now < (int) buffer.size() &&
               !std::isspace(static_cast<unsigned char>(buffer[now]))) {
            now++;
        }
        return now;
    }

    // Consumes and returns the token starting at `pos` (possibly empty if
    // `pos` already sits on whitespace). Requires unread input to remain.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        std::string res = buffer.substr(pos, nxt - pos);
        pos = nxt;
        return res;
    }

    // Reads a token whose length is in [minl, maxl]; if `pattern` is
    // non-empty every character of the token must occur in it.
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    // Reads an int in [minv, maxv]. The whole token must be numeric: input
    // such as "12abc" is rejected instead of being silently read as 12.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(!token.empty());
        std::size_t used = 0;
        int res = std::stoi(token, &used);
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // 64-bit counterpart of readInt, with the same strictness.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        assert(!token.empty());
        std::size_t used = 0;
        long long res = std::stoll(token, &used);
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n ints in [minv, maxv] separated by exactly one space each.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        std::vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n long longs in [minv, maxv] separated by exactly one space each.
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        std::vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' ' character.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n' character.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
// Tester's program: validates the input format and constraints with
// input_checker and prints the answer for every test case.
int32_t main() {
    std::ios_base::sync_with_stdio(0); std::cin.tie(0);
    input_checker input;
    int T = input.readInt(1, (int)1e4); input.readEoln();
    int SS = 0;  // running total of sum(A) over all test cases
    while(T-- > 0) {
        int N = input.readInt(1, 1000); input.readEoln();
        std::vector<int> A = input.readInts(N, 1, (int)1e5); input.readEoln();
        int S = std::accumulate(A.begin(), A.end(), 0);
        SS += S;
        // Enforce the global limit right away: checking only after the loop
        // would let SS overflow int on invalid data and would allocate large
        // tables before the input is rejected.
        assert(SS <= (int)1e5);

        // dp[x] = fewest elements among those processed so far that sum to x;
        // kUnreachable marks sums that cannot be formed.
        const int kUnreachable = 1000000;
        std::vector<int> dp(S + 1, kUnreachable);
        dp[0] = 0;
        std::sort(A.begin(), A.end());

        // suf[i] = A[i] + A[i+1] + ... + A[N-1]
        std::vector<int> suf(N + 1);
        for(int i = N - 1 ; i >= 0 ; --i)
            suf[i] = suf[i + 1] + A[i];

        int s = 0;  // sum of the elements already added to dp
        std::vector<bool> good(S + 1);
        good[S] = 1;  // taking every element is always possible

        // True when total / count > value, evaluated without division.
        auto valid = [&](int total, int count, int value) {
            return total > (int64_t)count * value;
        };

        for(int i = 0 ; i < N ; ++i) {
            // 0/1 knapsack step for A[i]; descending j uses it at most once.
            for(int j = s ; j >= 0 ; --j)
                dp[j + A[i]] = std::min(dp[j + A[i]], dp[j] + 1);
            s += A[i];
            // Everything after index i is taken; x comes from A[0..i]. If the
            // average exceeds A[i], every left-out element can be skipped.
            for(int x = 0 ; x <= s ; ++x) {
                if(valid(x + suf[i + 1], N - i - 1 + dp[x], A[i])) {
                    good[x + suf[i + 1]] = 1;
                }
            }
        }
        for(int i = 1 ; i <= S ; ++i)
            std::cout << good[i];
        std::cout << '\n';
    }
    input.readEof();
    return 0;
}
Editorialist's code (Python)
def reachable_sums(c):
    """Return the answer string for one test case.

    Character K (1-based) of the result is '1' if some permutation of ``c``
    makes the process end with ``sum == K`` and '0' otherwise; the string has
    ``sum(c)`` characters.

    A chosen set is reachable iff its average is strictly greater than the
    largest value left out.  For every distinct value ``y`` below the maximum
    we force in all values greater than ``y`` and combine them with subsets of
    the values seen before the last copy of ``y`` (so at least one copy is
    skipped), always using the fewest elements for a given subset sum.
    """
    c = sorted(c)
    n = len(c)
    if n == 0:
        return ''
    total = sum(c)
    largest = c[-1]
    # Sums formed only from values below the maximum never reach `lim`.
    lim = total - largest * c.count(largest) + 1
    ans = [0] * (total + 1)
    unreachable = n + 1  # more elements than any subset can contain
    dp = [unreachable] * lim  # dp[x] = fewest elements seen so far summing to x
    dp[0] = 0
    larger = total  # sum of the elements after position i
    pref = 0        # sum of the elements already added to dp
    for i, y in enumerate(c):
        larger -= y
        if y == largest:
            # The maximum can never be the skipped element.
            break
        if y != c[i + 1]:
            # Last copy of y: everything after it is strictly larger and taken.
            taken_above = n - (i + 1)
            for x in range(pref + 1):
                if dp[x] == unreachable:
                    continue
                val = x + larger
                # y < val / size, kept in integers.
                if y * (dp[x] + taken_above) < val:
                    ans[val] = 1
        # 0/1 knapsack step; descending order uses this copy at most once.
        for x in range(pref + y, y - 1, -1):
            dp[x] = min(dp[x], dp[x - y] + 1)
        pref += y
    ans[total] = 1  # taking every element is always possible
    return ''.join(map(str, ans[1:]))


def main():
    """Read all test cases from stdin and print one answer line per case."""
    for _ in range(int(input())):
        int(input())  # N; the array length is taken from the list itself
        c = list(map(int, input().split()))
        print(reachable_sums(c))


if __name__ == "__main__":
    main()