SMOLLAST - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: himanio
Preparation: iceknight1093
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

Alice and Bob play a game on an array A.
The following is repeated N-1 times: Alice sorts the array, then gives Bob two elements from it. Bob must choose to insert either their sum or their (absolute) difference into the array; the original two elements are deleted.

Find a sequence of moves for Bob that minimizes the last element.

EXPLANATION:

First, let’s try to find a lower bound on the answer.

Note that no matter what Bob’s moves are, the final element will be of the form

|\pm A_1 \pm A_2 \pm A_3 \pm \ldots \pm A_N|

where we choose either + or - for each \pm.

Let S_1 denote the subset of elements for which we choose +, and S_2 denote the subset for which - is chosen.
Then, the value we obtain is |sum(S_1) - sum(S_2)|.

This gives us a lower bound: let D be the smallest integer such that there exists a partition of A into subsets S_1 and S_2, such that |sum(S_1) - sum(S_2)| = D.
Then the final element will definitely be \geq D.


As it turns out, it’s always possible to make the final element D.

Let’s find a partition of A into S_1 and S_2 such that |sum(S_1) - sum(S_2)| = D.
This is pretty much just the subset-sum problem, and can be solved in \mathcal{O}(N^2\cdot \max(A)) with dynamic programming.
In particular, let’s also have sum(S_1) \geq sum(S_2) (so we’re assigning + to the elements of S_1 and - to the rest).

Then, if the two elements chosen by Alice belong to the same subset, insert their sum into A.
Otherwise, insert their difference.
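Note that you’ll have to assign a subset to the newly created element; this isn’t too hard:

  • If inserting the sum, assign it the same subset as the two elements it’s combining.
  • If inserting the difference, choose the subset of the larger element (if the two are equal, the new element is 0 and either subset works).

Either way, sum(S_1) - sum(S_2) is unchanged by the move, so it still equals D when only one element remains. That element is non-negative and D \geq 0, so the final element is exactly D.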

Each step can be simulated in \mathcal{O}(N\log N) time by just sorting the remaining array and erasing elements in linear time, for \mathcal{O}(N^2 \log N) overall after S_1 and S_2 are found.

TIME COMPLEXITY:

\mathcal{O}(N^2\cdot \max(A)) per testcase.

CODE:

Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

/**
 * Strict, whitespace-sensitive reader used to validate contest input.
 *
 * The constructor slurps all of standard input into `buffer`; every `read*`
 * method then consumes from `buffer` starting at `pos` and `assert`s that the
 * input has exactly the expected shape (token lengths, alphabets, numeric
 * ranges, single spaces, newlines, end of file).
 */
struct input_checker {
    std::string buffer;  // entire contents of stdin
    int pos = 0;         // index of the next unread character in `buffer`

    // Ready-made alphabets for the `pattern` argument of readString().
    const std::string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const std::string number = "0123456789";
    const std::string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const std::string lower = "abcdefghijklmnopqrstuvwxyz";

    /// Reads standard input until end of file into `buffer`.
    input_checker() {
        while (true) {
            const int c = std::cin.get();
            if (c == std::char_traits<char>::eof()) {
                break;
            }
            buffer.push_back(static_cast<char>(c));
        }
    }

    /// Returns the index of the first whitespace character at or after `pos`
    /// (or `buffer.size()` if there is none). Does not advance `pos`.
    int nextDelimiter() {
        int now = pos;
        // The cast matters: passing a negative `char` to isspace is undefined.
        while (now < (int) buffer.size() &&
               !std::isspace(static_cast<unsigned char>(buffer[now]))) {
            now++;
        }
        return now;
    }

    /// Consumes and returns the token starting at `pos` (possibly empty if
    /// `pos` is on a whitespace character). Asserts that input remains.
    std::string readOne() {
        assert(pos < (int) buffer.size());
        const int nxt = nextDelimiter();
        std::string res(buffer.begin() + pos, buffer.begin() + nxt);
        pos = nxt;
        return res;
    }

    /// Reads one token, asserting its length is in [minl, maxl] and, when
    /// `pattern` is non-empty, that every character occurs in `pattern`.
    std::string readString(int minl, int maxl, const std::string &pattern = "") {
        assert(minl <= maxl);
        std::string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != std::string::npos);
        }
        return res;
    }

    /// Reads one token as an `int` and asserts it lies in [minv, maxv].
    /// The whole token must be numeric: "12abc" is rejected rather than
    /// silently parsed as 12. `std::stoi` throws on empty/overflowing tokens.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        std::size_t used = 0;
        const int res = std::stoi(token, &used);
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    /// Same as readInt(), for `long long` values.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        const std::string token = readOne();
        std::size_t used = 0;
        const long long res = std::stoll(token, &used);
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    /// Reads `n` ints in [minv, maxv] separated by exactly one space each.
    /// The trailing newline is NOT consumed.
    std::vector<int> readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        std::vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    /// Reads `n` long longs in [minv, maxv] separated by exactly one space
    /// each. The trailing newline is NOT consumed.
    std::vector<long long> readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        std::vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    /// Consumes exactly one ' ' character.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    /// Consumes exactly one '\n' character.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    /// Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Capacity of the subset-sum table below. main() only indexes DP up to half
// of the array sum; presumably 2e6 covers N, A_i <= 2000 as in the
// commented-out validator calls in main() — TODO confirm against the
// actual constraints.
constexpr int S = 2e6 + 1;
// Subset-sum table shared across test cases (main() re-initialises the
// prefix it needs). DP[x] == -1 means sum x is not reachable; otherwise it
// holds the 1-based index of the element that first made x reachable
// (DP[0] == 0).
short int DP[S + 1];

/**
 * Tester's solution.
 *
 * Per test case:
 *   1. Sort A and run a subset-sum DP (global `DP`) over sums up to half of
 *      the total, remembering for each sum the element that first reached it.
 *   2. Pick the largest reachable sum <= total/2 and walk the DP backwards to
 *      mark the chosen elements (`select[i] == 1`); this is the partition
 *      with the smallest difference.
 *   3. Simulate the game. Every current array element is a "group": a list of
 *      signed 1-based indices into A whose signed sum is the element's value.
 *      The first entry of a group is always positive and acts as the group's
 *      representative for deciding which side of the partition it is on.
 *      For each pair given by Alice the two groups are merged with equal
 *      signs (same side -> sum) or opposite signs (different sides ->
 *      difference), and the resulting value is printed.
 */
int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    // input_checker input;
    int sum_n = 0;
    // int T = input.readInt(1, 2000);  input.readEoln();
    int T;
    cin >> T;
    while (T-- > 0) {
        // int N = input.readInt(1, 2000); input.readEoln();
        int N;
        cin >> N;
        sum_n += N;
        // auto A = input.readInts(N, 1, 2000);    input.readEoln();
        vector<int> A(N);
        for (auto &a : A)
            cin >> a;
        sort(A.begin(), A.end());

        // Only sums up to half of the total are needed: the smaller side of
        // the best partition never exceeds it.
        const int half = accumulate(A.begin(), A.end(), 0) / 2;
        for (int x = 0; x <= half; ++x)
            DP[x] = -1;
        DP[0] = 0;

        // `s` tracks the largest sum that could be reachable so far (capped
        // at `half`) so the inner loop does no wasted work.
        int s = 0;
        for (int i = 0; i < N; ++i) {
            s = min(s + A[i], half);
            const int y = A[i];
            // Descending order so each element is used at most once.
            for (int x = s; x >= y; --x) {
                if (DP[x] == -1 && DP[x - y] != -1)
                    DP[x] = static_cast<short>(i + 1);
            }
        }
        dbg(s);
        // Largest reachable sum not exceeding `half` (DP[0] stops the loop).
        while (DP[s] == -1)
            --s;
        dbg(s);

        // Recover one subset with that sum: DP[s] == i + 1 means element i
        // was the one that first produced sum s.
        vector<int> select(N);
        for (int i = N - 1; i >= 0; --i) {
            if (DP[s] == i + 1) {
                select[i] = 1;
                s -= A[i];
            }
        }
        dbg(select);

        // val[g] is the current value of group g; to_upd[g] its signed
        // 1-based member indices.
        vector<int> val(N);
        vector<vector<int>> to_upd(N);
        for (int i = 0; i < N; ++i)
            to_upd[i] = {i + 1};

        // Recomputes `val` from scratch for the given list of groups.
        auto update = [&](const vector<vector<int>> &groups) {
            val.resize(groups.size());
            for (int g = 0; g < (int)groups.size(); ++g) {
                val[g] = 0;
                for (const int member : groups[g])
                    val[g] += (member < 0 ? -1 : 1) * A[abs(member) - 1];
            }
        };

        update(to_upd);

        for (int step = 1; step < N; ++step) {
            // int u = input.readInt(1, N - i); input.readSpace();
            // int v = input.readInt(u + 1, N - i + 1);    input.readEoln();
            int u, v;
            cin >> u >> v;
            --u, --v;

            // Alice's indices refer to the array sorted by value; map them
            // back to group indices.
            vector<int> ord(to_upd.size());
            iota(ord.begin(), ord.end(), 0);
            sort(ord.begin(), ord.end(), [&](const int lhs, const int rhs) {
                return val[lhs] < val[rhs];
            });
            u = ord[u], v = ord[v];

            // Same side of the partition -> add; different sides -> subtract.
            // Must be read before group v is moved from below.
            const bool same_side =
                select[to_upd[u].front() - 1] == select[to_upd[v].front() - 1];

            // New group list: every untouched group, then v merged with u.
            vector<vector<int>> nupd;
            nupd.reserve(to_upd.size() - 1);
            for (int g = 0; g < (int)to_upd.size(); ++g) {
                if (g != u && g != v)
                    nupd.push_back(std::move(to_upd[g]));
            }
            nupd.push_back(std::move(to_upd[v]));
            for (const int member : to_upd[u])
                nupd.back().push_back(same_side ? member : -member);

            update(nupd);
            to_upd = std::move(nupd);

            if (val.back() < 0) {
                // Keep values non-negative: flip every sign, then sort in
                // descending order so a positive index is the representative.
                for (auto &member : to_upd.back())
                    member = -member;
                val.back() = -val.back();
                sort(to_upd.back().rbegin(), to_upd.back().rend());
            }
            cout << val.back() << '\n';
        }
    }

    assert(sum_n <= 2000);
    // input.readEof();

    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
// Shorthand for a 64-bit signed integer (not used by the code visible below).
using ll = long long int;
// 64-bit Mersenne Twister seeded from the high-resolution clock (not used by
// the code visible below).
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Fixed-width bitset used as one row of the subset-sum table in main():
// bit x set means sum x is reachable. Presumably sized for N, A_i <= 750
// (total sum <= 750*750) — TODO confirm against the problem constraints.
using bs = bitset<750*750 + 10>;

// Editorialist's solution.
//
// For every test case: build prefix subset-sum bitsets, pick the reachable
// sum closest to half of the total, recover which elements form it, then
// replay Alice's queries, answering "sum" for two elements on the same side
// of the partition and "difference" otherwise.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);

    int tests;
    cin >> tests;
    for (; tests != 0; --tests) {
        int n;
        cin >> n;

        // reach[i] has bit x set iff some subset of a[1..i] sums to x.
        vector<int> a(n + 1);
        vector<bs> reach(n + 1);
        reach[0][0] = 1;
        for (int i = 1; i <= n; ++i) {
            cin >> a[i];
            reach[i] = reach[i - 1] | (reach[i - 1] << a[i]);
        }

        // All queries are read up front (1-based positions in sorted order).
        vector<array<int, 2>> queries(n - 1);
        for (auto &query : queries)
            cin >> query[0] >> query[1];

        // Reachable sum in [0, total/2] that minimises |total - 2*sum|.
        const int total = accumulate(begin(a), end(a), 0);
        int best = 0;
        for (int cand = 0; cand <= total / 2; ++cand) {
            if (!reach[n][cand])
                continue;
            if (abs(total - 2 * cand) < abs(total - 2 * best))
                best = cand;
        }

        // Walk the table backwards: take a[idx] whenever the remainder is
        // still reachable using only the earlier elements.
        vector<int> mark(n + 1);
        for (int idx = n, need = best; idx >= 1; --idx) {
            if (a[idx] <= need && reach[idx - 1][need - a[idx]]) {
                mark[idx] = 1;
                need -= a[idx];
            }
        }

        // Current array as {value, side} pairs.
        vector<array<int, 2>> vals;
        for (int i = 1; i <= n; ++i)
            vals.push_back({a[i], mark[i]});

        for (const auto &query : queries) {
            const int p = query[0] - 1;
            const int q = query[1] - 1;

            sort(begin(vals), end(vals));
            const auto [x, t1] = vals[p];
            const auto [y, t2] = vals[q];
            // Erase the later position first so `p` stays valid.
            vals.erase(vals.begin() + q);
            vals.erase(vals.begin() + p);

            if (t1 == t2) {
                // Same side: combine by adding, side unchanged.
                cout << x + y << '\n';
                vals.push_back({x + y, t1});
            }
            else {
                // Opposite sides: subtract, and inherit the side of the
                // strictly larger element (otherwise the second one's).
                const int diff = abs(x - y);
                cout << diff << '\n';
                vals.push_back({diff, x > y ? t1 : t2});
            }
        }
    }
}