STRANGE_BST - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Testers: Nishank Suresh, Takuki Kurokawa
Editorialist: Nishank Suresh

DIFFICULTY:

3010

PREREQUISITES:

Sorting, (optional) Range-min data structures

PROBLEM:

There are N items, the i-th of which has weight W_i and value G_i; all weights are distinct. The items are arranged into a binary search tree ordered by weight, and the key K_i of the i-th item equals its value plus the sum of the weights of all its descendants (excluding itself).

Across all possible arrangements into a BST, find the minimum possible value of \max_{i=1}^N K_i.
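To make the definition of K_i concrete, the following C++ sketch evaluates \max_i K_i for one given arrangement. The `Tree` struct, its child-index layout and the function names are hypothetical choices for illustration. The sketch does not check that the arrangement is a valid BST.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical representation of one arrangement: node i has weight w[i],
// value g[i], and children lc[i], rc[i] (-1 when the child is absent).
struct Tree {
    std::vector<long long> w, g;
    std::vector<int> lc, rc;
};

// Returns the total weight of the subtree rooted at v (v included) and
// updates `best` with K_u = G_u + (sum of weights of u's descendants)
// for every node u in that subtree.
long long subtreeWeight(const Tree& t, int v, long long& best) {
    if (v < 0) return 0;  // empty subtree
    assert(v < (int)t.w.size());
    long long below = subtreeWeight(t, t.lc[v], best) + subtreeWeight(t, t.rc[v], best);
    best = std::max(best, t.g[v] + below);  // descendants only, so W_v is excluded
    return below + t.w[v];
}

// max_i K_i for the tree rooted at `root`.
long long maxKey(const Tree& t, int root) {
    long long best = 0;
    subtreeWeight(t, root, best);
    return best;
}
```

Take weights 1, 2, 3 and values 5, 1, 4. Rooting the tree at the middle item gives keys 5, 5, 4, so the maximum is 5. The path where each item's right child is the next heavier item gives the root a key of 5 + 2 + 3 = 10.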

EXPLANATION:

Let’s determine what the optimal choice of root is. Since the weights are distinct, choosing the root immediately partitions the remaining vertices into two sets. The vertices with smaller weight form the left subtree, and those with larger weight form the right subtree. We can then try to solve the problem on these smaller sets.

Computing the key of the root, we see that

K_{root} = G_{root} + \sum_{\substack{i=1 \\ i \neq root}}^N W_i = (G_{root} - W_{root}) + \sum_{i = 1}^N W_i

Note that \sum_{i=1}^N W_i is a constant, so the key value of the root is minimized by choosing whichever element u has the lowest value of G_u - W_u.

It turns out this choice is optimal! This means that we can simply find such a u, make it the root, then find the sets of vertices on its left and right and solve for them recursively.
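A direct translation of this recursion might look like the C++17 sketch below. The `Item` struct and the `greedyNaive` name are illustrative and do not come from the reference solutions. The sketch rescans and repartitions the whole set at every level, so it takes \mathcal{O}(N^2) time in the worst case.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// One item: weight W_i and value G_i (weights assumed pairwise distinct).
struct Item {
    long long w, g;
};

// Direct translation of the greedy: make the item with minimum G - W the
// root, split the others by weight, and recurse. Returns max K_i over the
// resulting tree (0 for an empty set). O(N^2) in the worst case.
long long greedyNaive(const std::vector<Item>& items) {
    if (items.empty()) return 0;
    long long total = 0;
    std::size_t root = 0;
    for (std::size_t i = 0; i < items.size(); ++i) {
        total += items[i].w;
        if (items[i].g - items[i].w < items[root].g - items[root].w) root = i;
    }
    // Key of the root: its value plus the weights of every other item in the set.
    long long rootKey = items[root].g + (total - items[root].w);
    std::vector<Item> left, right;  // lighter items go left, heavier go right
    for (const Item& it : items) {
        if (it.w < items[root].w) left.push_back(it);
        else if (it.w > items[root].w) right.push_back(it);
    }
    assert(left.size() + right.size() + 1 == items.size());  // distinct weights
    return std::max({rootKey, greedyNaive(left), greedyNaive(right)});
}
```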

Proof

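Consider some element x, and let p be its parent. Our claim boils down to showing that some optimal tree satisfies G_p - W_p \leq G_x - W_x for every such pair. In other words, the root of every subtree has the minimum value of G - W in that subtree.

So, suppose that in some optimal tree this fails for a pair, i.e., G_p - W_p \gt G_x - W_x.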
Without loss of generality, let x be the left child of p (the proof for when it is the right child is similar).
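Essentially, p has x as its left child and some subtree S as its right child, while x has a left subtree L and a right subtree R. Here, S, L, R all represent (possibly empty) subtrees and not single nodes.

Let’s rotate the BST so that x is now the parent of p. To keep the BST structure:

  • x takes the place of p and keeps L as its left subtree.
  • p becomes the right child of x and keeps S as its right subtree.
  • R, whose weights lie between W_x and W_p, becomes the left subtree of p.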

Now, note that:

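  • The key values of vertices outside the subtree of p don’t change at all, since this subtree (now rooted at x) contains the same set of vertices as before.
  • The key values of any vertices lying in the subtrees S, L, R don’t change at all, since each of these subtrees is moved as a whole.
  • The new key value of x is strictly less than the old key value of p. The new subtree of x contains exactly the same vertices as the old subtree of p, and G_x - W_x \lt G_p - W_p.
  • The new key value of p is strictly less than the old key value of p. It is now the root of a smaller subtree that no longer contains x, and all weights are positive.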
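Writing the keys out makes the last two bullets precise. Let T be the vertex set of p’s subtree before the rotation, which is exactly x’s subtree afterwards. Let T_p^{new} \subseteq T \setminus \{x\} be p’s subtree after the rotation. Weights are assumed positive, as in the tester’s input checker (W_i \geq 1).

```latex
\begin{aligned}
K_p^{\text{old}} &= G_p + \sum_{v \in T,\, v \neq p} W_v = (G_p - W_p) + \sum_{v \in T} W_v \\
K_x^{\text{new}} &= (G_x - W_x) + \sum_{v \in T} W_v
  = K_p^{\text{old}} - \big[(G_p - W_p) - (G_x - W_x)\big] < K_p^{\text{old}} \\
K_p^{\text{new}} &= (G_p - W_p) + \sum_{v \in T_p^{\text{new}}} W_v
  \le (G_p - W_p) + \sum_{v \in T} W_v - W_x < K_p^{\text{old}}
\end{aligned}
```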

So, the maximum K_i in the new tree is no larger than the maximum K_i in the old tree, which is exactly what we wanted.

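To turn this into a proof that our algorithm is optimal, note that the argument above also works when G_x - W_x = G_p - W_p. In that case the new key of x equals the old key of p, so the maximum still does not increase.

Now take any optimal tree, and let u be the element our algorithm chooses as the root, i.e., one with minimum G_u - W_u. Every ancestor p of u satisfies G_u - W_u \leq G_p - W_p, so we can repeatedly rotate u with its parent. Each rotation decreases the depth of u by 1, so after finitely many rotations u is the root, and the maximum key has not increased.

Rotations inside the left or right subtree of u do not change that subtree’s vertex set, so they do not change the key of u or of any other vertex outside it. Applying the same argument to both subtrees inductively, we end up with exactly the tree built by our algorithm, without ever increasing the maximum key.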
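One way to gain confidence in the exchange argument is to compare the greedy against an exhaustive answer on small random inputs. The C++ sketch below assumes the items are already sorted by weight, so every subtree of a BST occupies a contiguous range (a standard BST property). `exactDP` tries every root of every range, which is an \mathcal{O}(N^3) interval DP covering all BST shapes. `greedySorted` follows the rule above. The function names and random-test parameters are illustrative choices, not part of the original solutions.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <random>
#include <vector>

// Exact optimum over all BST shapes via interval DP, O(N^3): small N only.
// w (ascending, distinct) and g are parallel arrays sorted by weight.
long long exactDP(const std::vector<long long>& w, const std::vector<long long>& g) {
    assert(w.size() == g.size());
    int n = (int)w.size();
    std::vector<long long> pref(n + 1, 0);
    for (int i = 0; i < n; ++i) pref[i + 1] = pref[i] + w[i];
    // dp[l][r] = smallest possible max key for the items in [l, r); 0 if empty.
    std::vector<std::vector<long long>> dp(n + 1, std::vector<long long>(n + 1, 0));
    for (int len = 1; len <= n; ++len)
        for (int l = 0; l + len <= n; ++l) {
            int r = l + len;
            long long best = LLONG_MAX;
            for (int m = l; m < r; ++m) {  // try every item as the root
                long long key = g[m] + (pref[r] - pref[l] - w[m]);
                best = std::min(best, std::max({key, dp[l][m], dp[m + 1][r]}));
            }
            dp[l][r] = best;
        }
    return dp[0][n];
}

// The greedy rule on the same sorted arrays, for the range [l, r).
long long greedySorted(const std::vector<long long>& w, const std::vector<long long>& g,
                       int l, int r) {
    if (l >= r) return 0;
    int m = l;
    long long sum = 0;
    for (int i = l; i < r; ++i) {
        sum += w[i];
        if (g[i] - w[i] < g[m] - w[m]) m = i;
    }
    return std::max({g[m] + sum - w[m], greedySorted(w, g, l, m), greedySorted(w, g, m + 1, r)});
}

// Random small tests; returns true if greedy matched the exact DP every time.
bool crossCheck(int trials, unsigned seed) {
    std::mt19937 rng(seed);
    for (int t = 0; t < trials; ++t) {
        int n = 1 + (int)(rng() % 7);
        // Distinct weights: shuffle 1..3n, keep the first n, then sort them.
        std::vector<long long> pool(3 * n);
        for (int i = 0; i < 3 * n; ++i) pool[i] = i + 1;
        std::shuffle(pool.begin(), pool.end(), rng);
        std::vector<long long> w(pool.begin(), pool.begin() + n), g(n);
        std::sort(w.begin(), w.end());
        for (long long& x : g) x = 1 + (long long)(rng() % 20);  // ties in G - W are likely
        if (exactDP(w, g) != greedySorted(w, g, 0, n)) return false;
    }
    return true;
}
```

If the argument above is right, `crossCheck` returns true for any seed. A mismatch would point to a bug in one of the two functions.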

However, directly implementing this gives us a solution in \mathcal{O}(N^2) in the worst case. We use \mathcal{O}(k) time to pick the root of a set of size k and split it, and then recurse into its two parts. If the tree degenerates into a path, this comes out to \mathcal{O}(N) + \mathcal{O}(N-1) + \mathcal{O}(N-2) + \ldots = \mathcal{O}(N^2).

There are a couple of ways to optimize the solution, but they all depend on a simple observation: suppose we sort the elements in ascending order of W_i. Then, any set we solve for is a subarray of this sorted array.

Proof

This is easy to prove by induction. We start off with the entire range [1, N]. Suppose we are solving for a range [L, R] and choose a root r in it. We then recurse into the elements with weights \lt W_r and \gt W_r. Since the array is sorted by weight, these are exactly the ranges [L, r-1] and [r+1, R].
Now we apply the same process to these smaller ranges, and so on.

So, let’s sort the elements by W_i.

Now, to solve for a range [L, R], we need to:

  • Find L \leq u \leq R such that G_u - W_u is minimum in this range
  • Compute the key value of u
  • Recurse into [L, u-1] and [u+1, R].

The first and second steps are what consume the most time. However, note that they simply require range-min and range-sum respectively.

So, build some range-min data structure on the G_i - W_i values, and also the prefix sum array of W, say P. Then,

  • The first step can be done in \mathcal{O}(1) or \mathcal{O}(\log N) depending on the structure used (sparse table/segment tree), by querying the position of rangemin(L, R).
  • The second step can be done in \mathcal{O}(1), since K_u = (G_u - W_u) + P_R - P_{L-1}.

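Building the structure takes \mathcal{O}(N\log N) for a sparse table or \mathcal{O}(N) for a segment tree. We then perform one \mathcal{O}(1)/\mathcal{O}(\log N) query at each vertex of an N-sized tree. Together with the initial sort, the total runtime is \mathcal{O}(N\log N).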
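Putting the range-min and prefix-sum steps together, one possible C++17 sketch is shown below. It is close in spirit to the reference solutions in the CODE section, but it processes ranges with an explicit stack instead of recursion. That choice, the `solveSparse` name and the (W_i, G_i) pair input format are assumptions of this sketch. `__builtin_clz` is a GCC/Clang builtin, as in the tester’s and editorialist’s code.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Sketch: sparse table for range-argmin of G_i - W_i plus prefix sums of W.
// Items are (W_i, G_i) pairs with distinct weights. An explicit stack of
// pending ranges replaces recursion, so a path-shaped tree cannot overflow
// the call stack.
long long solveSparse(std::vector<std::pair<long long, long long>> items) {
    int n = (int)items.size();
    if (n == 0) return 0;
    std::sort(items.begin(), items.end());    // ascending weight
    using Entry = std::pair<long long, int>;  // (G - W, index); min breaks ties by index
    std::vector<long long> pref(n + 1, 0);
    std::vector<std::vector<Entry>> table(1, std::vector<Entry>(n));
    for (int i = 0; i < n; ++i) {
        if (i > 0) assert(items[i - 1].first < items[i].first);  // distinct weights
        pref[i + 1] = pref[i] + items[i].first;
        table[0][i] = {items[i].second - items[i].first, i};
    }
    // table[k][i] holds the minimum entry over [i, i + 2^k).
    for (int k = 1; (1 << k) <= n; ++k) {
        table.emplace_back(n - (1 << k) + 1);
        for (int i = 0; i + (1 << k) <= n; ++i)
            table[k][i] = std::min(table[k - 1][i], table[k - 1][i + (1 << (k - 1))]);
    }
    // Position of the minimum of G - W on the half-open range [l, r).
    auto argmin = [&](int l, int r) {
        assert(l < r);
        int k = 31 - __builtin_clz((unsigned)(r - l));
        return std::min(table[k][l], table[k][r - (1 << k)]).second;
    };
    long long ans = 0;
    std::vector<std::pair<int, int>> pending = {{0, n}};  // ranges [l, r) still to solve
    while (!pending.empty()) {
        auto [l, r] = pending.back();
        pending.pop_back();
        if (l >= r) continue;
        int u = argmin(l, r);
        // K_u = (G_u - W_u) + sum of W over [l, r)
        ans = std::max(ans, table[0][u].first + pref[r] - pref[l]);
        pending.push_back({l, u});
        pending.push_back({u + 1, r});
    }
    return ans;
}
```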

It is also possible to solve the problem without any range-query data structure. The tree we build after sorting is exactly the Cartesian tree of the sequence G_i - W_i, with the minimum at the root. Such a tree can be built in linear time, for example with a monotonic stack. The initial sort still costs \mathcal{O}(N\log N).
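One way to realize the linear-time idea without building an explicit tree is sketched below in C++ (`solveStack` is a hypothetical name). It relies on a standard fact about Cartesian trees: the subtree of i spans the maximal range around i in which i is the chosen minimum, and two monotonic-stack passes can find this range. Ties between equal values of G_i - W_i are broken toward the leftmost element. This is an assumption of the sketch, and the greedy rule allows any element of minimum G_u - W_u to be chosen as the root.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Sketch: with items sorted by weight, the greedy tree is the min-Cartesian
// tree of a_i = G_i - W_i, ties broken by making the leftmost minimum of a
// range its root. Rather than building the tree, compute each node's
// subtree range with two monotonic stacks:
//   prevLE[i] = nearest j < i with a_j <= a_i (or -1),
//   nextLT[i] = nearest j > i with a_j <  a_i (or n),
// so node i's subtree is exactly the index range [prevLE[i] + 1, nextLT[i]).
long long solveStack(std::vector<std::pair<long long, long long>> items) {  // (W_i, G_i)
    int n = (int)items.size();
    std::sort(items.begin(), items.end());  // ascending weight
    std::vector<long long> a(n), pref(n + 1, 0);
    for (int i = 0; i < n; ++i) {
        if (i > 0) assert(items[i - 1].first < items[i].first);  // distinct weights
        a[i] = items[i].second - items[i].first;
        pref[i + 1] = pref[i] + items[i].first;
    }
    std::vector<int> prevLE(n), nextLT(n), stk;
    for (int i = 0; i < n; ++i) {  // left to right
        while (!stk.empty() && a[stk.back()] > a[i]) stk.pop_back();
        prevLE[i] = stk.empty() ? -1 : stk.back();
        stk.push_back(i);
    }
    stk.clear();
    for (int i = n - 1; i >= 0; --i) {  // right to left
        while (!stk.empty() && a[stk.back()] >= a[i]) stk.pop_back();
        nextLT[i] = stk.empty() ? n : stk.back();
        stk.push_back(i);
    }
    long long ans = 0;
    for (int i = 0; i < n; ++i)  // K_i = a_i + total weight of i's subtree range
        ans = std::max(ans, a[i] + pref[nextLT[i]] - pref[prevLE[i] + 1]);
    return ans;
}
```

Each index is pushed and popped at most once per pass, so everything after the sort runs in \mathcal{O}(N).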

TIME COMPLEXITY:

\mathcal{O}(N\log N) per test case.

CODE:

Setter's code (C++)
#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=100010;
const int LOGN=28;
int n;
int lg2[N];
ll  S[N];
int mn[N][LOGN],pos[N][LOGN];

struct nod
{
	int w,g;
	
	friend bool operator<(nod x,nod y)
	{
		return x.w<y.w;	
	}
}nd[N];

void init()
{
	for(int i=1;i<=n;i++) S[i]=S[i-1]+nd[i].w;
	for(int i=1;i<=n;i++) lg2[i]=(int)log2(i);
	for(int i=1;i<=n;i++) mn[i][0]=nd[i].g-nd[i].w,pos[i][0]=i;
	for(int i=1;i<LOGN;i++)
	{
		for(int j=1;j<=n;j++)
		{
			int p=j+(1<<(i-1));
			if(p>n) mn[j][i]=mn[j][i-1],pos[j][i]=pos[j][i-1];
			else    
			{
				if(mn[j][i-1]<mn[p][i-1]) mn[j][i]=mn[j][i-1],pos[j][i]=pos[j][i-1];
				else					  mn[j][i]=mn[p][i-1],pos[j][i]=pos[p][i-1];
			}
		}
	}
}

int getmnpos(int L,int R)
{
	int t=lg2[R-L+1];
	return mn[L][t]<mn[R-(1<<t)+1][t]?pos[L][t]:pos[R-(1<<t)+1][t];
} 

ll cal(int L,int R)
{
	int x=getmnpos(L,R);
	ll  ans=nd[x].g+S[x-1]-S[L-1]+S[R]-S[x];
	if(x!=L) ans=max(ans,cal(L,x-1));
	if(x!=R) ans=max(ans,cal(x+1,R));
	return ans;
}

int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++) scanf("%d",&nd[i].w);
	for(int i=1;i<=n;i++) scanf("%d",&nd[i].g);
	sort(nd+1,nd+n+1);
	init();
	printf("%lld\n",cal(1,n));

	return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

struct sparse {
    using T = pair<long long, int>;
    int n;
    int h;
    vector<vector<T>> table;

    T op(T x, T y) {
        return min(x, y);
    }

    sparse(const vector<T> &v) {
        n = (int) v.size();
        h = 32 - __builtin_clz(n);
        table.resize(h);
        table[0] = v;
        for (int j = 1; j < h; j++) {
            table[j].resize(n - (1 << j) + 1);
            for (int i = 0; i <= n - (1 << j); i++) {
                table[j][i] = op(table[j - 1][i], table[j - 1][i + (1 << (j - 1))]);
            }
        }
    }

    T get(int l, int r) {
        assert(0 <= l && l < r && r <= n);
        int k = 31 - __builtin_clz(r - l);
        return op(table[k][l], table[k][r - (1 << k)]);
    }
};

int main() {
    input_checker in;
    int n = in.readInt(1, 100000);
    in.readEoln();
    vector<long long> w(n), g(n);
    for (int i = 0; i < n; i++) {
        w[i] = in.readInt(1, 1e9);
        (i == n - 1 ? in.readEoln() : in.readSpace());
    }
    for (int i = 0; i < n; i++) {
        g[i] = in.readInt(1, 1e9);
        (i == n - 1 ? in.readEoln() : in.readSpace());
        g[i] -= w[i];
    }
    in.readEof();
    vector<pair<long long, long long>> wg(n);
    for (int i = 0; i < n; i++) {
        wg[i] = make_pair(w[i], g[i]);
    }
    sort(wg.begin(), wg.end());
    for (int i = 0; i < n; i++) {
        tie(w[i], g[i]) = wg[i];
    }
    vector<long long> pref(n + 1);
    for (int i = 0; i < n; i++) {
        pref[i + 1] = pref[i] + w[i];
    }
    vector<pair<long long, int>> a(n);
    for (int i = 0; i < n; i++) {
        a[i] = make_pair(g[i], i);
    }
    for (int i = 1; i < n; i++) {
        assert(w[i - 1] != w[i]);
    }
    sparse s(a);
    long long ans = 0;
    function<void(int, int)> Dfs = [&](int l, int r) {
        if (l >= r) {
            return;
        }
        int m = s.get(l, r).second;
        Dfs(l, m);
        Dfs(m + 1, r);
        ans = max(ans, g[m] + pref[r] - pref[l]);
    };
    Dfs(0, n);
    cout << ans << endl;
    return 0;
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

/**
 * Sparse Table
 * Source: kactl
 * Description: Given an idempotent function f that can be evaluated in O(T) and a static array V,
 *              finds f(V[L], V[L+1], ..., V[R-1]) in O(T) using O(nlogn) memory
 * Time: O(Tnlogn) precomputation, O(1) query
 * Note: Ranges are half-open, i.e, [L, R)
 */

template<class T>
struct SparseTable {
	T f(T a, T b) {return min(a, b);}
	vector<vector<T>> jmp;
	SparseTable(const vector<T>& V) : jmp(1, V) {
		for (int pw = 1, k = 1; pw * 2 <= (int)V.size(); pw *= 2, ++k) {
			jmp.emplace_back(V.size() - pw * 2 + 1);
			for (int j = 0; j < (int)jmp[k].size(); ++j)
				jmp[k][j] = f(jmp[k - 1][j], jmp[k - 1][j + pw]);
		}
	}
	T query(int a, int b) {
		assert(a < b); // or return unit if a == b
		int dep = 31 - __builtin_clz(b - a);
		return f(jmp[dep][a], jmp[dep][b - (1 << dep)]);
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int n; cin >> n;
	vector<array<int, 2>> v(n);
	for (int i = 0; i < n; ++i) cin >> v[i][0];
	for (int i = 0; i < n; ++i) cin >> v[i][1];
	sort(begin(v), end(v));
	vector<array<int, 2>> vals(n);
	for (int i = 0; i < n; ++i) vals[i] = {v[i][1] - v[i][0], i};
	SparseTable ST(vals);
	vector<ll> pref(n);
	for (int i = 0; i < n; ++i) {
		pref[i] = v[i][0];
		if (i) pref[i] += pref[i-1];
	}
	ll ans = 0;
	auto solve = [&] (const auto &self, int L, int R) {
		if (L > R) return;
		auto mn = ST.query(L, R+1);
		int root = mn[1];
		ll sum = pref[R];
		if (L) sum -= pref[L-1];
		auto [w, g] = v[root];
		ans = max(ans, sum + g - w);
		self(self, L, root-1);
		self(self, root+1, R);
	};
	solve(solve, 0, n-1);
	cout << ans << '\n';
}