LEXMINBIN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: munch_01
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Prefix sums

PROBLEM:

You’re given a binary string S.
In one move, you can choose two adjacent unequal characters and delete them from S.
Find the lexicographically minimum final string.

EXPLANATION:

Observe that if we have a string containing an equal number of zeros and ones, we can always delete all of its characters: as long as the string is non-empty it contains both characters, so some pair of adjacent unequal characters exists. Deleting that pair keeps the counts equal, so we can repeat until nothing is left.
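As a small illustration of this observation, here is a sketch that performs the deletions with a stack. The name `reduce_pairs` is hypothetical (it is not part of the problem or of any reference solution). Whatever remains on the stack consists of a single repeated character, so a string with equal counts reduces to the empty string.

```python
def reduce_pairs(s: str) -> str:
    """Delete adjacent unequal characters until none remain (stack simulation)."""
    stack = []
    for c in s:
        if stack and stack[-1] != c:
            stack.pop()      # c and the current last character form a deletable pair
        else:
            stack.append(c)  # nothing to pair with yet
    return "".join(stack)


print(repr(reduce_pairs("0110")))   # ''
print(repr(reduce_pairs("00101")))  # '0'
```

The stack top is always adjacent to the incoming character in the current string, so every pop corresponds to a legal move.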

Let’s try to build the answer greedily from left to right.
Let A denote the answer string.

If S contains an equal number of zeros and ones, as noted above we can delete them all which is clearly optimal, and A will be empty.
Otherwise, it’s impossible to delete everything, and A will definitely have length \geq 1.
Since our objective is to make A be as small (lexicographically) as possible, it would be best if the first character of A is a 0.

For the first character of A to be 0, there should exist an index i such that:

  • S_i = 0
  • Among the characters S_1, S_2, \ldots, S_{i-1}, there are an equal number of zeros and ones.

Essentially, there should exist a prefix that can be completely deleted, followed by a 0.
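The condition above can be checked with a single left-to-right scan. The following is only a sketch; `valid_zero_indices` is a hypothetical helper name, and positions are reported 1-indexed to match the notation S_i.

```python
def valid_zero_indices(s: str) -> list:
    """1-indexed positions i with S_i = '0' and a balanced prefix S_1..S_{i-1}."""
    result = []
    balance = 0  # (#ones - #zeros) among the characters before position i
    for i, c in enumerate(s, start=1):
        if c == "0" and balance == 0:
            result.append(i)
        balance += 1 if c == "1" else -1
    return result


print(valid_zero_indices("010101"))  # [1, 3, 5]
print(valid_zero_indices("110"))     # []
```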

We now have two details left to figure out:

  1. There can exist multiple such indices (for example, consider S = 010101\ldots, where every odd index satisfies the property) - so we need to decide on which of them to choose.
  2. We also need to decide on what to do when no valid index exists.

The first problem can be solved greedily: it’s always optimal to choose the leftmost valid index.
The second is in fact quite similar to the first: we cannot choose a 0 at all, so we must choose a 1.
But an index can serve as the first 1 only if the prefix before it contains an equal number of zeros and ones (so that it can be deleted). Once again it’s optimal to choose the leftmost valid index, which in this case is just the first index: since no valid 0 exists, S_1 must be 1.

Proof of greedy's correctness

Observe that any final string is obtained by choosing some indices i_1 \lt i_2 \lt \ldots \lt i_k to keep, such that between i_j and i_{j+1} there are an equal number of zeros and ones (allowing everything in between to be deleted). The same must hold for the part before i_1 and the part after i_k.

Consider an optimal solution. Suppose i_1 isn’t the leftmost valid index, meaning there exists an index i \lt i_1 such that S_i = S_{i_1}, with an equal number of zeros and ones before index i.

Observe that the sequence (i, i_2, i_3, \ldots, i_k) is also a valid sequence, because:

  • Before index i and after index i_2, there is no issue: things remain the same.
  • There will be an equal number of zeros and ones from index i to index i_1 - 1.
  • There are an equal number of zeros and ones from index i_1+1 to index i_2 - 1.
  • So, there’s an equal number of zeros and ones from index i+1 to index i_2 - 1 (combine the above two sets, exclude index i, and include index i_1 - but we presumed that S_i = S_{i_1}).
    Thus, replacing i_1 with i results in a solution that’s still valid, and it produces the same string since S_i = S_{i_1}.

So, any optimal solution can be converted to one where the leftmost valid index is chosen.
Now repeat this argument for the remaining part.
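To make the exchange step concrete, here is a small sketch that checks whether a set of kept indices is valid, i.e. whether every deleted block is balanced. The names `balanced` and `is_valid_selection` and the sample string are illustrative assumptions, not taken from the problem.

```python
def balanced(t: str) -> bool:
    return t.count("0") == t.count("1")


def is_valid_selection(s: str, kept: list) -> bool:
    """kept: increasing 1-indexed positions of s that survive all deletions."""
    prev = 0  # last kept position (0 means "before the string")
    for i in kept:
        if not balanced(s[prev:i - 1]):  # positions prev+1 .. i-1 get deleted
            return False
        prev = i
    return balanced(s[prev:])            # everything after the last kept index


s = "0100"
before = [3, 4]  # i_1 = 3 is valid, but it is not the leftmost valid index
after = [1, 4]   # i = 1 also has S_i = '0' and an (empty) balanced prefix
print(is_valid_selection(s, before), is_valid_selection(s, after))  # True True
print("".join(s[i - 1] for i in before) == "".join(s[i - 1] for i in after))  # True
```

Both selections keep the string 00, which is what the argument predicts: moving the first kept index to the left changes neither validity nor the result.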


The above discussion either eliminates the entire string (so we’re done), or gives us the first element of A (and its corresponding index in S).
We’re now left with a suffix of S, so we simply repeat the procedure for this suffix.
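The procedure described so far can be written down directly, without any optimization. The sketch below is a quadratic-time version meant only for checking the idea on small inputs; `lex_min_greedy` and `lex_min_brute` are hypothetical names, and the brute force simply explores every string reachable by moves.

```python
from itertools import product


def lex_min_greedy(s: str) -> str:
    n, x, ans = len(s), 0, []  # x = start (0-indexed) of the current suffix
    while x < n:
        rest = s[x:]
        if rest.count("0") == rest.count("1"):
            break  # the whole remaining suffix can be deleted
        pick, balance = None, 0
        for i in range(x, n):
            if balance == 0 and s[i] == "0":
                pick = i  # leftmost 0 preceded by a deletable block
                break
            balance += 1 if s[i] == "1" else -1
        if pick is None:
            ans.append("1")  # no valid 0 exists, so s[x] is necessarily '1'
            x += 1
        else:
            ans.append("0")
            x = pick + 1
    return "".join(ans)


def lex_min_brute(s: str) -> str:
    seen, todo = {s}, [s]
    while todo:
        t = todo.pop()
        for i in range(len(t) - 1):
            if t[i] != t[i + 1]:
                u = t[:i] + t[i + 2:]
                if u not in seen:
                    seen.add(u)
                    todo.append(u)
    return min(seen)  # Python compares strings lexicographically


# Compare the two on every binary string of length 1..10.
for length in range(1, 11):
    for bits in product("01", repeat=length):
        t = "".join(bits)
        assert lex_min_greedy(t) == lex_min_brute(t), t
print(lex_min_greedy("1011"))  # 1011
```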

The process takes at most N steps (since at worst, we’ll have A = S), so as long as each step can be done quickly enough the algorithm will be fast.
This is not too hard:

  • We need to check whether a suffix has an equal number of zeros and ones - this can be precomputed.
  • Then, for a suffix, we want to find the first instance of a 0 that occurs after deleting an equal number of zeros and ones.
    For this, we use a small trick: replace the 0's with -1's, so that “equal number of zeros and ones” turns into “subarray with sum 0”.
    • Dealing with subarray sums is nice because we can now think in terms of prefix sums.
    • Specifically, if you’re at the suffix starting at index x, you’re then looking for the smallest index i \geq x such that S_i = 0 and P_i = P_{x-1} - 1 (where P is the prefix sum array): the block S_x, \ldots, S_{i-1} must sum to 0, and S_i itself contributes -1.
      This can be found by, for example, maintaining a list of indices corresponding to each prefix sum and binary searching on it.
      Alternatively, looking at the minimum (or maximum, depending on the implementation) prefix sum on each suffix allows for an \mathcal{O}(N) solution - see the author’s code for this.
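The prefix-sum lookup with binary search can be sketched as follows. It assumes the convention P_0 = 0 and P_i = sum of the first i transformed values (1 becomes +1, 0 becomes -1), so that "the block S_x..S_{i-1} is balanced" reads P_{i-1} = P_{x-1}. The helper names `build`, `suffix_is_balanced` and `first_valid_zero` are hypothetical.

```python
from bisect import bisect_left
from collections import defaultdict


def build(s: str):
    """Return (P, zero_pos): prefix sums, and for each value v the sorted
    1-indexed positions i with S_i = '0' and P[i] = v."""
    P = [0]
    zero_pos = defaultdict(list)
    for i, c in enumerate(s, start=1):
        P.append(P[-1] + (1 if c == "1" else -1))
        if c == "0":
            zero_pos[P[i]].append(i)
    return P, zero_pos


def suffix_is_balanced(P, x: int) -> bool:
    """True if S_x..S_N has equally many zeros and ones."""
    return P[-1] == P[x - 1]


def first_valid_zero(P, zero_pos, x: int):
    """Smallest i >= x with S_i = '0' and S_x..S_{i-1} balanced, else None."""
    candidates = zero_pos.get(P[x - 1] - 1, [])
    k = bisect_left(candidates, x)
    return candidates[k] if k < len(candidates) else None


P, zero_pos = build("1100")
print(P)                                 # [0, 1, 2, 1, 0]
print(first_valid_zero(P, zero_pos, 1))  # None
print(first_valid_zero(P, zero_pos, 2))  # 4
print(suffix_is_balanced(P, 2))          # False
```

Each query costs one dictionary lookup plus one binary search, which gives the \mathcal{O}(N\log N) bound when used once per character of the answer.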

TIME COMPLEXITY:

\mathcal{O}(N) or \mathcal{O}(N\log N) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

string fast(string s){
    int n = s.length();
    vector <int> p(n + 1, 0);
    for (int i = 1; i <= n; i++){
        if (s[i - 1] == '0') p[i] = +1;
        else p[i] = -1;
        p[i] += p[i - 1];
    }
    
    string ans = "";
    if (p[n] >= 0){
        for (int i = 1; i <= p[n]; i++){
            ans += "0";
        }
        return ans;
    }
    
    vector <int> mx(n + 1, 0);
    
    mx[n] = p[n];
    for (int i = n - 1; i >= 0; i--){
        mx[i] = max(mx[i + 1], p[i]);
    }
    
    int f = 1;
    while (f <= n){
        if (p[n] == p[f - 1]) break;
        if (mx[f] - p[f - 1] <= 0){
            f++;
            ans += "1";
            continue;
        }
        
        int sum = 0;
        while (sum <= 0){
            if (s[f - 1] == '0') sum++;
            else sum--;
            f++;
        }
        
        ans += "0";
    }
    return ans;
}

void Solve() 
{
    int n; cin >> n;
    string s; cin >> s;
    string a = fast(s);
    if (a.size() == 0){
        cout << "EMPTY\n";
    } else {
        cout << a << "\n";
    }
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...)
#endif

#ifndef LOCAL
struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
#else

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
      string X; cin >> X;
      return X;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res;  cin >> res;
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
    }

    void readEoln() {
    }

    void readEof() {
    }
};
#endif

int32_t main() {
  ios_base::sync_with_stdio(0);
  cin.tie(0);

  input_checker inp;
  int T = inp.readInt(1, (int)1e4), NN = 0; inp.readEoln();
  while(T-- > 0) {
    int N = inp.readInt(1, (int)5e5); inp.readEoln();
    NN += N;
    string S = inp.readString(N, N, "01");  inp.readEoln();
    vector<int> pf(N + 1);
    for(int i = 0 ; i < N ; ++i) {
      pf[i + 1] = pf[i] + (S[i] == '0' ? 1 : -1);
    }
    vector<int> nx(N + 1, N + 1), stk;
    for(int i = N ; i >= 0 ; --i) {
      while(!stk.empty() && pf[stk.back()] <= pf[i])
        stk.pop_back();
      if(!stk.empty())
        nx[i] = stk.back();
      stk.push_back(i);
    }

    int i = 0;
    string res;
    dbg(nx);
    while(i < N) {
      if(pf[i] == pf[N])  break;
      dbg(i, res);
      if(nx[i] == N + 1) {
        res.push_back('1'); ++i;
      } else {
        res.push_back('0');  i = nx[i]; // 0 -1 0 
      }
    }
    cout << (res.empty() ? "EMPTY" : res) << '\n';
  }
  assert(NN <= (int)5e5);
  inp.readEof();
  
  return 0;
}
Editorialist's code (Python)
from collections import defaultdict
from bisect import bisect_left
for _ in range(int(input())):
    n = int(input())
    s = input()
    val = 0
    pos = defaultdict(lambda: defaultdict(list))
    for i, c in enumerate(s):
        val += 1 if c == '0' else -1
        pos[c][val].append(i)
    
    ans = []
    ptr, cur = 0, 0
    while cur != val:
        # try to get a 0 -> cur increases by 1
        if cur+1 in pos['0'] and pos['0'][cur+1][-1] >= ptr:
            ans.append('0')
            cur += 1
        else:
            ans.append('1')
            cur -= 1
        what = bisect_left(pos[ans[-1]][cur], ptr)
        ptr = pos[ans[-1]][cur][what] + 1
        
    if len(ans) == 0: print('empty')
    else: print(*ans, sep = '')