# FAULTY_TREE - Editorial

Author: youknow_who
Tester: yash_daga
Editorialist: iceknight1093

TBD

# PREREQUISITES:

The process of Huffman coding

# PROBLEM:

Alice and Bob play a game, independently, on an array of N elements. At each step, a player will take two elements x and y from the remaining set, delete them, and add x+y to the set.
They also add x+y to their score.

Alice has a well-defined strategy: (except for the very first move) she will always use the newly combined element and combine it again.
Bob, meanwhile, can choose his moves as he likes.

Bob wins if his score is strictly less than Alice’s.
To achieve this, Bob can change elements of the array as he likes, before any moves are made.
Find the minimum number of changes so that Bob, with optimal play, can guarantee a win over Alice. Also find one such array.

# EXPLANATION:

The first step to solving this problem is analyzing Alice’s and Bob’s best strategies, and seeing when exactly Bob can beat Alice.

Alice's strategy

Let $A_1 \leq A_2 \leq \ldots \leq A_N$.

Alice’s best strategy is then as follows:

• Merge $A_1$ and $A_2$
• Merge $A_1+A_2$ and $A_3$
• Merge $A_1+A_2+A_3$ and $A_4$
$\vdots$

Alice’s best score is thus the sum of all prefix sums of (sorted) A, minus A_1.

Bob's strategy

Bob is free to make moves as he likes, so the optimal strategy is to always pick the two smallest remaining elements, Huffman-style. A proof of this optimality can be found here.

Now that we know their strategies, it’s also not hard to see when exactly Bob can beat Alice: there should be a move when Alice doesn’t choose the smallest two elements.
That is, in the sorted array $A$, there should exist an $i$ such that $A_1 + A_2 + \ldots + A_i$ is not among the two smallest remaining elements, where the other remaining elements are $[A_{i+1}, A_{i+2}, \ldots, A_N]$.

In particular, notice that this can only happen when $A_1 + \ldots + A_i \gt A_{i+2}$.

So, Bob’s aim is to reach an array where this is the case.

This leads us to a solution that deals with a few cases:

• If $N \leq 3$, Bob can never win.
• Otherwise, if this condition is already satisfied, nothing more needs to be done — Bob already wins.
• Otherwise, we need to see what we can modify.

First, let’s check if one modification is enough.
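If $i$ is fixed, we want to see if $A_1 + \ldots + A_i \gt A_{i+2}$ can be made to hold.
The best we can do is bring $A_{i+2}$ down to $A_{i+1}$ (since sorted order needs to be maintained), so if $A_1 + \ldots + A_i \gt A_{i+1}$, we’re done: change $A_{i+2}$ to $A_{i+1}$.
The tester's and editorialist's solutions below also try the opposite single change — raising $A_i$ to $A_{i+1}$ — which is enough when $A_1 + \ldots + A_{i-1} + A_{i+1} \gt A_{i+2}$.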

If this is not satisfied for any $1 \leq i \lt N-1$, it’s always possible to achieve our aim with 2 changes by setting $A_2 = A_3 = A_4$, at which point $A_1 + A_2 \gt A_4$ (the new $A_2$ equals $A_4$, so this holds whenever $A_1 \gt 0$).

Note that we sorted A to see which changes need to be made, but the actual changes need to be made on the original positions of A; so make sure to keep track of indices as well.

# TIME COMPLEXITY

$\mathcal{O}(N \log N)$ per test case.

# CODE:

Setter's code (C++)
#include "bits/stdc++.h"
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

// Shorthand macros. Of these, the solve() shown in this listing uses `all`, `ll`,
// `fi` and `se`; `pb`, `sz`, `lbd` and `ubd` are not referenced in the visible code.
// all(x): expands to the begin/end iterator pair of x (two arguments, e.g. for sort).
#define all(x)      x.begin(), x.end()
#define pb          push_back
// sz(x): container size cast to int.
#define sz(x)       (int)(x.size())
#define ll          long long
// fi / se: first and second member of a std::pair.
#define fi          first
#define se          second
#define lbd         lower_bound
#define ubd         upper_bound

// Alias for the GNU policy-based tree with order-statistics node updates.
// Not referenced by the solve() shown in this listing.
template <typename T>
using ordered_set = tree<T, null_type,
less<T>, rb_tree_tag,
tree_order_statistics_node_update>;

// Template constants; none of them is referenced by the solve() shown in this listing.
const int MOD = 1e9 + 7;
const double eps = 1e-10;
const long long INF = 1e12;
const int N = 2e5 + 10;

// Setter's solution for a single test case.
// Reads n and n values from standard input, then prints either "NO", or "YES"
// followed by a (possibly modified) array listed in the original input order.
// No trailing newline is printed here; main() emits it after each call.
void solve() {
int n;
cin >> n;
// Each entry is (value, original position) so that edits made on the sorted
// order can be written back to the position the value was read from.
vector<pair<ll, int>> v(n);
for (int i = 0; i < n; i++) {
cin >> v[i].fi;
v[i].se = i;
}

// Ascending by value (ties broken by original position, via pair comparison).
sort(all(v));

// Zero-change check: at iteration i, `pre` is the sum of v[0..i-1]. If that
// prefix already exceeds v[i+1], the array is printed unchanged.
ll pre = 0;
for (int i = 0; i + 1 < n; i++) {
if (pre > v[i + 1].fi) {
cout << "YES\n";
vector<ll> ans(n);
for (int j = 0; j < n; j++) {
ans[v[j].se] = v[j].fi;
}
for (int j = 0; j < n; j++) {
cout << ans[j] << ' ';
}
return;
}
pre += v[i].fi;
}

// With three or fewer elements the answer is reported as impossible.
if (n <= 3) {
cout << "NO";
return;
}

// From here on an answer is always produced, so the header is printed once.
cout << "YES\n";
// NOTE(review): the next two `if` blocks look damaged. `v` is a vector, so `v.fi`
// (with no subscript) does not compile, and both blocks are textually identical.
// Numeric subscripts such as [0], [1], ... appear to have been lost from this
// listing; they must be restored from the original submission — TODO confirm.
if (v.fi + v.fi > v.fi) {
v.fi = v.fi;
vector<ll> ans(n);
for (int j = 0; j < n; j++) {
ans[v[j].se] = v[j].fi;
}
for (int j = 0; j < n; j++) {
cout << ans[j] << ' ';
}
return;
}
if (v.fi + v.fi > v.fi) {
v.fi = v.fi;
vector<ll> ans(n);
for (int j = 0; j < n; j++) {
ans[v[j].se] = v[j].fi;
}
for (int j = 0; j < n; j++) {
cout << ans[j] << ' ';
}
return;
}

// One-change search. `pre` is again the sum of v[0..i-1]; `mx` is the index with
// the largest adjacent gap v[mx+1] - v[mx] among the indices processed so far.
pre = 0;
int mx = 0;
for (int i = 0; i + 1 < n; i++) {
// Lowering option: the prefix already beats v[i], so copy v[i] into v[i+1].
if (pre > v[i].fi) {
v[i + 1].fi = v[i].fi;
vector<ll> ans(n);
for (int j = 0; j < n; j++) {
ans[v[j].se] = v[j].fi;
}
for (int j = 0; j < n; j++) {
cout << ans[j] << ' ';
}
return;
}

// Raising option: lifting v[mx] up to v[mx+1] adds the largest recorded gap to
// the prefix; accept if the enlarged prefix exceeds v[i+1].
if (pre + v[mx + 1].fi - v[mx].fi > v[i + 1].fi) {
v[mx].fi = v[mx + 1].fi;
vector<ll> ans(n);
for (int j = 0; j < n; j++) {
ans[v[j].se] = v[j].fi;
}
for (int j = 0; j < n; j++) {
cout << ans[j] << ' ';
}
return;
}
// Record index i if its gap to the next element is strictly larger.
if (v[i + 1].fi - v[i].fi > v[mx + 1].fi - v[mx].fi) {
mx = i;
}
pre += v[i].fi;
}

// Two-change fallback: make the three largest values equal to the maximum.
v[n - 2].fi = v[n - 1].fi;
v[n - 3].fi = v[n - 1].fi;
vector<ll> ans(n);
for (int j = 0; j < n; j++) {
ans[v[j].se] = v[j].fi;
}
for (int j = 0; j < n; j++) {
cout << ans[j] << ' ';
}
}

// Entry point of the setter's solution: reads the number of test cases and
// runs solve() once per case, terminating each case's output with a newline.
int main() {
    // Untie C++ streams from C stdio and from each other for faster I/O.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int tt = 1;
    std::cin >> tt;
    for (; tt != 0; --tt) {
        solve();
        std::cout << '\n';
    }
    return 0;
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long

// Tester's solution for a single test case.
//
// Reads n and n values from standard input. Prints "NO" when n <= 3; otherwise
// prints "YES" and an array (in the original input order) with as few edits as
// this strategy finds:
//   0 edits - some prefix sum of the sorted array already exceeds the element
//             two positions past the prefix;
//   1 edit  - lower the element two past the prefix to its predecessor, or
//             raise the last prefix element to its successor;
//   2 edits - set the 2nd and 3rd smallest values equal to the 4th smallest.
//
// Parameter tc is the 1-based test-case number; it is not used.
//
// The subscripts on pre[0] / b[0] and in the two-edit fallback were missing
// from the listing and are restored here; the fallback follows the editorial's
// "A_2 = A_3 = A_4" construction. Fixed-size storage uses std::vector instead
// of variable-length arrays, which are not standard C++.
void solve(int tc)
{
    (void)tc;

    int n;
    std::cin >> n;
    std::vector<long long> a(n);
    // (value, original position) so edits decided on the sorted order can be
    // applied to the array in input order.
    std::vector<std::pair<long long, int>> b(n);
    for (int i = 0; i < n; i++)
    {
        std::cin >> a[i];
        b[i].first = a[i];
        b[i].second = i;
    }
    if (n <= 3)
    {
        std::cout << "NO\n";
        return;
    }
    std::sort(b.begin(), b.end());

    // pre[i] = sum of the i+1 smallest values.
    std::vector<long long> pre(n);
    pre[0] = b[0].first;
    for (int i = 1; i < n; i++)
        pre[i] = pre[i - 1] + b[i].first;

    // Zero edits: a prefix already exceeds the element two past its end.
    bool ok = false;
    for (int i = 2; i < n; i++)
        if (pre[i - 2] > b[i].first)
            ok = true;

    // One edit, lowering: the prefix exceeds b[i-1], so pull b[i] down to it.
    if (!ok)
    {
        for (int i = 2; i < n; i++)
        {
            if (pre[i - 2] > b[i - 1].first)
            {
                a[b[i].second] = b[i - 1].first;
                ok = true;
                break;
            }
        }
    }

    // One edit, raising: replace b[i-2] by b[i-1] inside the prefix.
    if (!ok)
    {
        for (int i = 3; i < n; i++)
        {
            if (pre[i - 3] + b[i - 1].first > b[i].first)
            {
                a[b[i - 2].second] = b[i - 1].first;
                ok = true;
                break;
            }
        }
    }

    // Two edits: make the 2nd and 3rd smallest equal to the 4th smallest, so the
    // smallest value plus the new 2nd exceeds the 4th.
    if (!ok)
    {
        a[b[1].second] = b[3].first;
        a[b[2].second] = b[3].first;
    }

    std::cout << "YES\n";
    for (int i = 0; i < n; i++)
        std::cout << a[i] << " ";
    std::cout << '\n';
}

// Entry point of the tester's solution: reads the number of test cases and
// calls solve() with the 1-based index of each case in turn.
int32_t main()
{
    // Decouple C++ streams from C stdio and from each other for faster I/O.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int tc = 1;
    std::cin >> tc;

    int ttc = 1;
    while (ttc <= tc)
    {
        solve(ttc);
        ++ttc;
    }
    return 0;
}

Editorialist's code (Python)
def solve(n, a):
    """Edit `a` in place so that Bob can win, using as few changes as this strategy finds.

    `a` holds the n array values in input order. Returns True when a winning
    array exists (`a` then holds it, still in input order) and False when
    n <= 3 and no single-change check succeeded.

    The indentation and the numeric subscripts on `indices` were missing from
    the listing and are restored here; the two-change fallback follows the
    editorial's "A_2 = A_3 = A_4" construction.
    """
    # Positions of `a` ordered by value, so edits land on the original slots.
    indices = list(range(n))
    indices.sort(key=lambda x: a[x])

    # Zero changes: at step i, `pref` is the sum of the i-1 smallest values;
    # succeed if it already exceeds the element at sorted position i.
    pref = 0
    for i in range(1, n):
        if pref > a[indices[i]]:
            return True
        pref += a[indices[i - 1]]

    # One change: `pref` is the sum of the i-1 smallest values (ending at x).
    pref = a[indices[0]]
    for i in range(2, n):
        x, y, z = indices[i - 2], indices[i - 1], indices[i]
        # Raise x to y's value: the enlarged prefix must exceed z.
        if pref + a[y] - a[x] > a[z]:
            a[x] = a[y]
            return True
        # Lower z to y's value: the prefix must already exceed y.
        if pref > a[y]:
            a[z] = a[y]
            return True
        pref += a[y]

    if n <= 3:
        return False

    # Two changes: the 2nd and 3rd smallest both take the 4th smallest value.
    a[indices[1]] = a[indices[2]] = a[indices[3]]
    return True


def main():
    """Read the test cases from standard input and print one answer per case."""
    for _ in range(int(input())):
        n = int(input())
        a = list(map(int, input().split()))
        if solve(n, a): print('Yes\n', *a)
        else: print('No')


if __name__ == '__main__':
    main()