ADIVITYA11 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh_07
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

You’re given an array C of length N.
Consider the following process for a permutation P of length N:

  • Let sum = C_{P_1}, cnt = 1.
  • For each i = 2, 3, \ldots, N, if C_{P_i} \geq sum/cnt, add C_{P_i} to sum and 1 to cnt.

For each K from 1 to \sum C_i, find out whether there exists a permutation P for which the final value of sum is exactly K.

EXPLANATION:

We essentially want to find, for each possible subset sum of C, whether there’s a way to obtain it as the result of this process.
Let’s make a few observations about what’s going on. Suppose S is some non-empty subset of elements that we want to choose.

  1. The average of the chosen items is non-decreasing, because each time we choose an item iff it’s not less than the current average.
  2. Consider some element x that’s not in S.
    To skip x, we need to have a strictly higher average than x when we reach it.
    Coupled with the observation that the average is non-decreasing, it’s clearly best to first take every element of S, and only then try to skip x.
  3. In particular, this means that only the largest missing element from S matters: if we’re able to skip it, we can definitely skip everything \leq it.

So, let’s fix a value y, and see what subset sums we can achieve where y is the largest missing element.
Since y is the largest missing element, certainly anything \gt y will be chosen.
Let there be m_1 such values, and B denote their sum.

Now, consider some subset sum K consisting of only values that are \leq y.
We want two things:

  • K should be achievable as a subset sum at all, of course.
  • Further, y should be skippable.
    This means, if m_2 elements sum to K, our overall average = \frac{K+B}{m_1+m_2} should be strictly larger than y.

Since B, m_1, K are all fixed here, clearly our best option is to attempt to minimize m_2, i.e, minimize the number of elements that sum to K.
This requirement allows for a dynamic programming solution.
Sort the array C, so that C_i \leq C_{i+1}.
Let dp_{i, x} denote the size of the smallest subset of the first i elements that sums up to x (or \infty if no such subset exists), with dp_{0, 0} = 0.
We have, depending on whether C_i is included or not:

dp_{i, x} = \min(dp_{i-1, x}, dp_{i-1, x-C_i} + 1)

Now, for each i such that C_i \neq C_{i+1}, do the following:

  • Find m_1 and B, the number of elements \gt C_i and their sum.
  • For each K from 0 to \sum C_i, let m_2 = dp_{i-1, K}.
    Note that we use index i-1 to ensure that at least one copy of C_i is skipped.
  • If C_i \lt \frac{B+K}{m_1 + m_2}, it’s possible to get a final score of B+K, so set ans_{B+K} = 1.

At the end of this process, ans holds our answer.
The complexity of this is \mathcal{O}(N\cdot \sum C_i), which is fast enough since \sum C_i is bounded by 10^5 and N is at most 1000.
Note that the dynamic programming sketched above has \mathcal{O}(N\cdot \sum C_i) states, which might use too much memory; however, it’s quite easy to reduce the memory usage by a factor of N, since only the previous row needs to be stored (the time complexity remains the same, but using lots of memory often slows down code).

TIME COMPLEXITY:

\mathcal{O}(N\cdot \sum C_i) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Every `int` written after this line is really a 64-bit `long long`;
// this is why main() further down has to be declared as `int32_t main()`.
#define int long long
// Value stored in dp[] by Solve() for sums that cannot be formed.
#define INF (int)1e9
// Shorthands for pair members (not used by the author's code shown here).
#define f first
#define s second

// Time-seeded 64-bit Mersenne Twister (not used by the author's code shown here).
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Capacity of a[]; Solve() fills indices 1..n.
const int N = 1005;
// Capacity of dp[]; Solve() indexes it by partial sums 0..sum.
const int M = 1e5 + 1;
// n   - number of elements in the current test case
// a   - the elements, 1-indexed and sorted ascending by Solve(); a[0] stays 0
// dp  - dp[v] = fewest processed elements summing to v, INF if impossible
// sum - total of all elements in the current test case
int n, a[N], dp[M], sum;

// Handles one test case: reads n and the n values, then prints a string of
// `sum` characters whose K-th character is 1 iff some order of play ends with
// a final total of exactly K.
//
// Idea: sort ascending and fix the largest-index element that is left out
// (`skip`). Everything after it is taken; among the elements before it we want,
// for each partial sum, the fewest elements that form it, because fewer elements
// means a larger average and the skipped element must be strictly below it.
void Solve()
{
    cin >> n;
    sum = 0;
    for (int idx = 1; idx <= n; ++idx) {
        cin >> a[idx];
        sum += a[idx];
    }

    // dp[v] = fewest elements (among those processed so far) adding up to v.
    dp[0] = 0;
    fill(dp + 1, dp + sum + 1, INF);

    sort(a + 1, a + n + 1);

    vector<bool> ans(sum + 1, false);
    ans[sum] = true;  // keeping every element is always achievable

    int tail = sum;  // becomes a[skip+1] + ... + a[n] inside the loop
    for (int skip = 1; skip < n; ++skip) {
        tail -= a[skip];

        // Fold the previous element into the knapsack so dp covers a[1..skip-1].
        // For skip == 1 this reads a[0]; this function never writes a[0] and it
        // is a zero-initialised global, so the pass changes nothing.
        const int prev = a[skip - 1];
        for (int v = sum; v >= prev; --v) {
            dp[v] = min(dp[v], dp[v - prev] + 1);
        }

        const int forced = n - skip;  // count of elements after `skip`, all taken
        for (int low = 0; low + tail <= sum; ++low) {
            const int total = low + tail;
            const int count = dp[low] + forced;
            // a[skip] can be skipped iff a[skip] < total / count, written
            // without division as count * a[skip] < total. Unreachable `low`
            // has count >= INF, which makes the test fail.
            if (count * a[skip] < total) {
                ans[total] = true;
            }
        }
    }

    for (int k = 1; k <= sum; ++k) {
        cout << ans[k];
    }
    cout << "\n";
}

// Entry point: reads the number of test cases and runs Solve() once per case.
// Declared as int32_t because `int` is macro-expanded to long long in this file.
int32_t main()
{
    // Decouple C++ streams from C stdio and from each other for fast bulk I/O.
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int t = 0;
    cin >> t;
    while (t-- > 0)
    {
        Solve();
    }
    return 0;
}
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

// Strict reader used to validate test files. The constructor slurps all of
// stdin into `buffer`; the read* methods then consume it token by token and
// assert on any deviation from the expected format (value ranges, exactly one
// space between values, '\n' line ends, nothing left over at the end).
struct input_checker {
    string buffer;  // everything that was available on stdin
    int pos = 0;    // index in `buffer` of the next unread character

    // Character classes that callers can pass as `pattern` to readString().
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads stdin until end of file.
    input_checker() {
        while (true) {
            int c = cin.get();
            if (c == char_traits<char>::eof()) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Returns the index of the first whitespace character at or after `pos`
    // (or buffer.size() if there is none).
    int nextDelimiter() {
        int now = pos;
        // isspace() is only defined for unsigned-char values and EOF, so bytes
        // >= 0x80 must not be passed as a (possibly negative) plain char.
        while (now < (int) buffer.size() && !isspace(static_cast<unsigned char>(buffer[now]))) {
            now++;
        }
        return now;
    }

    // Consumes and returns the token starting at `pos` (possibly empty when
    // `pos` already sits on whitespace). Asserts that input remains.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res = buffer.substr(pos, nxt - pos);
        pos = nxt;
        return res;
    }

    // Reads one token, asserting that its length is within [minl, maxl] and, when
    // `pattern` is non-empty, that every character occurs in `pattern`.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads one token as an int in [minv, maxv]. stoi throws if the token is
    // empty, non-numeric or out of int range.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const string token = readOne();
        size_t used = 0;
        int res = stoi(token, &used);
        // stoi stops at the first non-digit; reject tokens such as "12abc".
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Same as readInt, for long long values.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const string token = readOne();
        size_t used = 0;
        long long res = stoll(token, &used);
        // Reject trailing garbage glued to the number.
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads `n` ints in [minv, maxv] separated by exactly one space each.
    // The line terminator is left for the caller to consume.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads `n` long longs in [minv, maxv] separated by exactly one space each.
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' ' character.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n' character.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Tester's solution: validates the input format and constraints with
// input_checker, then for each test case prints a 0/1 string whose K-th
// character says whether a final sum of exactly K is achievable.
int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;

    int T = input.readInt(1, (int)1e4);     input.readEoln();
    int SS = 0;  // running total of S over all test cases
    while(T-- > 0) {
        int N = input.readInt(1, 1000); input.readEoln();
        vector<int> A = input.readInts(N, 1, (int)1e5); input.readEoln();
        int S = accumulate(A.begin(), A.end(), 0);
        SS += S;
        // Enforce the global bound immediately: checking it only after the loop
        // would let an invalid file allocate O(S) tables of ~1e8 entries and
        // overflow SS before the violation is ever reported.
        assert(SS <= (int)1e5);

        // dp[x] = fewest elements among those processed so far that sum to x
        // (1e6 acts as "unreachable").
        vector<int> dp(S + 1, 1e6);
        dp[0] = 0;
        sort(A.begin(), A.end());
        // suf[i] = A[i] + A[i+1] + ... + A[N-1]
        vector<int> suf(N + 1);
        for(int i = N - 1 ; i >= 0 ; --i)
            suf[i] = suf[i + 1] + A[i];

        int s = 0;  // sum of the elements already folded into dp
        vector<bool> good(S + 1);
        good[S] = 1;  // taking every element is always possible
        // True iff total a is strictly greater than b elements times value c,
        // i.e. the average a/b strictly exceeds c (64-bit product to avoid overflow).
        auto valid = [&](int a, int b, int c) {
            return a > (int64_t)b * c;
        };
        for(int i = 0 ; i < N ; ++i) {
            // 0/1 knapsack step adding A[i]; descending j keeps each element single-use.
            for(int j = s ; j >= 0 ; --j)
                dp[j + A[i]] = min(dp[j + A[i]], dp[j] + 1);
            s += A[i];
            // Take all of A[i+1..] plus a subset of A[0..i] summing to x. Every
            // left-out element is <= A[i], so the set is reachable whenever its
            // average is strictly above A[i].
            for(int x = 0 ; x <= s ; ++x) {
                if(valid(x + suf[i + 1], N - i - 1 + dp[x], A[i])) {
                    good[x + suf[i + 1]] = 1;
                }
            }
        }
        for(int i = 1 ; i <= S ; ++i)
            cout << good[i];
        cout << '\n';
    }

    input.readEof();

    return 0;
}

Editorialist's code (Python)
# For every test case, print a 0/1 string whose K-th character (1-based) tells
# whether the process can end with a final sum of exactly K.
for _ in range(int(input())):
    n = int(input())
    c = list(map(int, input().split()))
    c.sort()

    total = sum(c)
    biggest = c[-1]
    # The scan below stops at the first copy of the maximum, so the knapsack
    # only ever holds sums built from elements strictly smaller than it.
    cap = total - biggest * c.count(biggest) + 1
    unreachable = n + 1
    fewest = [unreachable] * cap  # fewest[x] = min #elements seen so far summing to x
    fewest[0] = 0
    reachable = [0] * (total + 1)

    above = total  # sum of the elements after the current index
    below = 0      # sum of the elements before the current index
    for idx, y in enumerate(c):
        above -= y
        if y == biggest:
            break
        # Treat y as the largest skipped value only at its last copy, so that
        # earlier copies of y are already available in the knapsack.
        if c[idx + 1] != y:
            forced = n - idx - 1  # everything after idx is taken
            for x in range(below + 1):
                cnt = fewest[x]
                if cnt == unreachable:
                    continue
                final = x + above
                # y is skippable iff y < final / (cnt + forced).
                if y * (cnt + forced) < final:
                    reachable[final] = 1
        # 0/1 knapsack step for y (descending x keeps each element single-use).
        for x in range(cap - 1, y - 1, -1):
            cand = fewest[x - y] + 1
            if cand < fewest[x]:
                fewest[x] = cand
        below += y
    reachable[total] = 1  # taking every element is always possible
    print(''.join(map(str, reachable[1:])))