BAL01 - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Srikkanth R
Tester: Tejas Pandey and Utkarsh Gupta
Editorialist: Taranpreet Singh

DIFFICULTY

Cakewalk

PREREQUISITES

None

PROBLEM

Given a string of length N containing only the characters '0', '1' and '?', replace each '?' with either '0' or '1' so as to minimize the absolute difference between the number of occurrences of '0' and '1', and output the resulting string.

QUICK EXPLANATION

  • First compute the number of occurrences of '0' and '1' and '?'. Let’s denote these by c_0, c_1 and c_?
  • If c_? \leq |c_0-c_1|, then the minimum difference we can achieve is |c_0-c_1| - c_?.
  • Otherwise, the minimum difference we can achieve is N \bmod 2.
  • To construct this string after computing c_0 and c_1, we can iterate from left to right, and every time we get an occurrence of '?', we replace it with '0' if c_0 \leq c_1, and with '1' otherwise. We also update c_0 or c_1 accordingly.

EXPLANATION

Let’s try to compute the minimum difference possible first. Let’s compute the number of occurrences of '0' and '1' and '?'. Let’s denote these by c_0, c_1 and c_?.

Now, replacing any '?' with '0' or '1' is the same as decrementing c_? and incrementing c_0 or c_1. We need to do this exactly c_? times such that, at the end, |c_0-c_1| is minimized.

It is easy to see that it is optimal to increment the smaller of c_0 and c_1 at each operation, because only that reduces the gap between c_0 and c_1. When c_0 = c_1, we can increment either one.

Difference obtained

Suppose c_? \leq |c_1-c_0|. Then the final difference is |c_1 - c_0| - c_?, as every operation goes toward reducing the original gap. Otherwise, we spend |c_0-c_1| operations closing the initial gap, after which the remaining c_? - |c_0-c_1| operations alternate between incrementing c_0 and c_1, leaving a final difference of (c_? - |c_0-c_1|) \bmod 2. Since N = c_0 + c_1 + c_? and c_0 + c_1 \equiv |c_0 - c_1| \pmod 2, this equals N \bmod 2.

Replacing '?' in the given string

After computing c_0 and c_1, we can iterate over the string, and at each occurrence of '?' replace it with the less frequent character among '0' and '1' (choosing '0' when c_0 = c_1), updating c_0 or c_1 accordingly.

TIME COMPLEXITY

The time complexity is O(N) per test case.

SOLUTIONS

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
// Template boilerplate carried by the setter. MOD, vint, mat, rng and the
// rand(l, r) macro are not referenced anywhere in this solution.
const int MOD = 998244353;
typedef vector<int> vint;
typedef vector<vector<int>> mat;
#define LL long long
// 64-bit Mersenne Twister seeded from the steady clock.
LL seed = chrono::steady_clock::now().time_since_epoch().count();
mt19937_64 rng(seed);
// rand(l, r) draws a uniform integer in [l, r] from rng.
// NOTE(review): this function-like macro rewrites every later `rand(...)`
// token sequence in the translation unit, including calls to std::rand.
#define rand(l, r) uniform_int_distribution<LL>(l, r)(rng)
// Program start time; main() reports the elapsed time on stderr.
clock_t start = clock();

// Route all getchar() calls below to POSIX getchar_unlocked, which skips
// stdio stream locking for faster single-threaded input.
#define getchar getchar_unlocked

// Reads a non-negative decimal integer from stdin, consuming characters up to
// and including the terminator `endd`.
//
// The character is held in an int so EOF can be detected. The old `char`
// version never matched `endd` at end of input and looped forever while
// overflowing `ret`. EOF and non-digit characters now trip an assertion,
// matching the validating style of the range-checked overload.
long long readInt(char endd) {
    long long ret = 0;
    int c = getchar();
    while (c != endd) {
        assert(c != EOF && '0' <= c && c <= '9');
        ret = (ret * 10) + (c - '0');
        c = getchar();
    }
    return ret;
}

// Reads an integer terminated by `endd` and asserts it lies in [L, R].
long long readInt(long long L, long long R, char endd) {
    const long long value = readInt(endd);
    assert(L <= value && value <= R);
    return value;
}
// Range-checked integer followed by a single space.
long long readIntSp(long long L, long long R) {
    const char separator = ' ';
    return readInt(L, R, separator);
}
// Range-checked integer that ends its line.
long long readIntLn(long long L, long long R) {
    const char terminator = '\n';
    return readInt(L, R, terminator);
}

// Reads the maximal run of '0' / '1' / '?' characters from stdin. The first
// character outside that alphabet ends the token and is consumed (not
// returned). Asserts that the token length lies in [l, r].
string readString(int l, int r) {
    string token;
    for (char ch = getchar(); ch == '0' || ch == '1' || ch == '?'; ch = getchar()) {
        token.push_back(ch);
    }
    const int len = static_cast<int>(token.size());
    assert(l <= len && len <= r);
    return token;
}

// Limits enforced by this checker: the test-case count read in main() must be
// in [1, TMAX], and the total string length over all test cases at most SUM_N.
const int TMAX = 1'00'000;
const int SUM_N = 1'00'000;

// Running total of string lengths; solve() adds to it and main() asserts it
// does not exceed SUM_N after all test cases are read.
int sum_n = 0;
// Handles one test case: reads the pattern, counts the fixed digits, then
// gives each '?' to whichever digit is currently rarer ('0' when tied) and
// prints the completed string.
void solve() {
    string s = readString(1, SUM_N);
    sum_n += s.size();

    int zeros = count(s.begin(), s.end(), '0');
    int ones = count(s.begin(), s.end(), '1');

    for (char &ch : s) {
        if (ch != '?') {
            continue;
        }
        if (zeros <= ones) {
            ch = '0';
            ++zeros;
        } else {
            ch = '1';
            ++ones;
        }
    }
    cout << s << '\n';
}

// Reads the test-case count, solves each case, checks the total-length
// limit, and reports elapsed time on stderr.
int main() {
    for (int remaining = readIntLn(1, TMAX); remaining--; ) {
        solve();
    }
    assert(sum_n <= SUM_N);
    // Trailing-input check, left disabled by the setter:
    // assert(getchar() == EOF);
    const long double elapsed = (clock() - start) / static_cast<long double>(CLOCKS_PER_SEC);
    cerr << fixed << setprecision(10) << elapsed << " secs\n";
    return 0;
}
Tester's Solution 1
#include <bits/stdc++.h>
using namespace std;


/*
------------------------Input Checker----------------------------------
*/

// Strictly parses one decimal integer from stdin terminated by `endd`
// (the terminator is consumed). Asserts on malformed input:
//   - at most one '-', and only before the first digit;
//   - at least one digit;
//   - no leading zeros and no "-0";
//   - at most 19 digits, with a leading digit of at most 1 when there are 19;
//   - l <= value <= r (on failure the bounds and value are printed to stderr).
// Any other character, including EOF, is rejected.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit, or -1 before any digit is read
    bool is_neg=false;
    while(true){
        // int (not char) keeps EOF distinguishable on unsigned-char platforms.
        int g=getchar();
        if(g=='-'){
            // A single sign is allowed only before the first digit
            // (previously "--5" was accepted).
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            // Accumulate only after the length check so x cannot overflow:
            // the largest accepted value is 1999999999999999999.
            x=x*10+(g-'0');
        } else if(g==endd){
            // An empty token ("\n" or "-\n") is not a number.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (and consuming) `endd`, returning them
// without the terminator. Asserts that EOF is not reached first and that the
// token length lies in [l, r].
string readString(int l,int r,char endd){
    string ret="";
    while(true){
        // int (not char): with an unsigned char the old `g != -1` check could
        // never fire, so EOF bytes would be appended forever.
        int g=getchar();
        assert(g!=EOF);
        if(g==EOF || g==endd){
            break;
        }
        // Fail as soon as the token exceeds r instead of buffering all of it.
        assert(static_cast<int>(ret.size()) < r);
        ret+=static_cast<char>(g);
    }
    const int cnt=static_cast<int>(ret.size());
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Range-checked integer followed by a space.
long long readIntSp(long long l,long long r){
    const char delimiter=' ';
    return readInt(l,r,delimiter);
}
// Range-checked integer that ends its line.
long long readIntLn(long long l,long long r){
    const char delimiter='\n';
    return readInt(l,r,delimiter);
}
// Length-checked string that ends its line.
string readStringLn(int l,int r){
    const char delimiter='\n';
    return readString(l,r,delimiter);
}
// Length-checked string followed by a space.
string readStringSp(int l,int r){
    const char delimiter=' ';
    return readString(l,r,delimiter);
}


/*
------------------------Main code starts here----------------------------------
*/

// Limits enforced by this checker: test-case count, length of one string,
// and total length of all strings.
const int MAX_T = 100000;
const int MAX_N = 100000;
const int MAX_SUM_LEN = 100000;

// Stream setup used at the top of main(): stop syncing C++ streams with C
// stdio and untie cin/cout.
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)

// Running total of string lengths; solve() asserts it stays <= MAX_SUM_LEN.
int sum_len=0;

// Reads and validates one test string, then replaces each '?' greedily with
// the currently rarer digit ('0' on ties), which minimises |#0 - #1|, and
// prints the result.
void solve()
{
    string s=readStringLn(1, MAX_N);
    int n=s.length();
    sum_len+=n;
    assert(sum_len <= MAX_SUM_LEN);

    int c[2] = {0, 0};   // c[0] = number of '0's, c[1] = number of '1's
    for(char ch : s){
        // Only '0', '1' and '?' are legal. Counting just the two digits also
        // keeps c[] in bounds: the old c[s[i] - '0'] indexed past the array
        // for any other character.
        assert(ch == '0' || ch == '1' || ch == '?');
        if(ch == '0' || ch == '1') c[ch - '0']++;
    }
    for(char &ch : s){
        if(ch != '?') continue;
        const int d = (c[0] <= c[1]) ? 0 : 1;   // rarer digit, '0' on ties
        ch = static_cast<char>('0' + d);
        c[d]++;
    }
    cout << s << "\n";
}

// Entry point: configures streams, optionally redirects to local files,
// solves every test case, and asserts that the entire input was consumed.
signed main()
{
    fast;
    // Presumably ONLINE_JUDGE is defined on the judge; local runs use files.
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif

    const int t = readIntLn(1, MAX_T);
    for(int done=0; done<t; ++done)
    {
        solve();
    }

    // No input may remain after the last test case.
    assert(getchar() == -1);
}
Tester's Solution 2
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
// Strictly parses one decimal integer from stdin terminated by `endd`
// (the terminator is consumed). Asserts on malformed input:
//   - at most one '-', and only before the first digit;
//   - at least one digit;
//   - no leading zeros and no "-0";
//   - at most 19 digits, with a leading digit of at most 1 when there are 19;
//   - l <= value <= r (on failure the bounds and value are printed to stderr).
// Any other character, including EOF, is rejected.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit, or -1 before any digit is read
    bool is_neg=false;
    while(true){
        // int (not char) keeps EOF distinguishable on unsigned-char platforms.
        int g=getchar();
        if(g=='-'){
            // A single sign is allowed only before the first digit
            // (previously "--5" was accepted).
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            // Accumulate only after the length check so x cannot overflow:
            // the largest accepted value is 1999999999999999999.
            x=x*10+(g-'0');
        } else if(g==endd){
            // An empty token ("\n" or "-\n") is not a number.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (and consuming) `endd`, returning them
// without the terminator. Asserts that EOF is not reached first and that the
// token length lies in [l, r].
string readString(int l,int r,char endd){
    string ret="";
    while(true){
        // int (not char): with an unsigned char the old `g != -1` check could
        // never fire, so EOF bytes would be appended forever.
        int g=getchar();
        assert(g!=EOF);
        if(g==EOF || g==endd){
            break;
        }
        // Fail as soon as the token exceeds r instead of buffering all of it.
        assert(static_cast<int>(ret.size()) < r);
        ret+=static_cast<char>(g);
    }
    const int cnt=static_cast<int>(ret.size());
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Range-checked integer followed by a space.
long long readIntSp(long long l,long long r){
    const char delimiter=' ';
    return readInt(l,r,delimiter);
}
// Range-checked integer that ends its line.
long long readIntLn(long long l,long long r){
    const char delimiter='\n';
    return readInt(l,r,delimiter);
}
// Length-checked string that ends its line.
string readStringLn(int l,int r){
    const char delimiter='\n';
    return readString(l,r,delimiter);
}
// Length-checked string followed by a space.
string readStringSp(int l,int r){
    const char delimiter=' ';
    return readString(l,r,delimiter);
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
// Limits enforced by this checker: test-case count, length of one string,
// and total length of all strings.
const int MAX_T = 100000;
const int MAX_N = 100000;
const int MAX_SUM_LEN = 100000;
 
// Stream setup used at the top of main(): stop syncing C++ streams with C
// stdio and untie cin/cout.
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)

// Running total of string lengths; solve() asserts it stays <= MAX_SUM_LEN.
int sumN=0;

// Reads and validates one test string, then replaces each '?' greedily with
// the currently rarer digit ('0' on ties), which minimises |#0 - #1|, and
// prints the result.
void solve()
{
    string s=readString(1,MAX_N,'\n');
    int n=s.length();
    sumN+=n;
    assert(sumN<=MAX_SUM_LEN);
    int cnt0=0,cnt1=0;
    for(int i=0;i<n;i++)
    {
        // Validate the alphabet: the checker previously let any other
        // character through unchanged.
        assert(s[i]=='0' || s[i]=='1' || s[i]=='?');
        if(s[i]=='0')
            cnt0++;
        else if(s[i]=='1')
            cnt1++;
    }
    for(int i=0;i<n;i++)
    {
        if(s[i]!='?')
            continue;
        // Give the '?' to the rarer digit; '0' wins ties.
        if(cnt0<=cnt1)
        {
            cnt0++;
            s[i]='0';
        }
        else
        {
            cnt1++;
            s[i]='1';
        }
    }
    cout<<s<<'\n';
}

// Entry point: configures streams, optionally redirects to local files,
// solves every test case, and asserts that the entire input was consumed.
signed main()
{
    fast;
    // Presumably ONLINE_JUDGE is defined on the judge; local runs use files.
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif

    const int t = readInt(1,MAX_T,'\n');
    for(int done=0; done<t; ++done)
    {
        solve();
    }

    // No input may remain after the last test case.
    assert(getchar() == -1);
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class BAL01{
    //SOLUTION BEGIN
    void pre() throws Exception{}

    /**
     * Solves one test case: reads the pattern and replaces every '?' with the
     * digit that is currently rarer ('0' on ties), which minimises |#0 - #1|.
     *
     * @param TC 1-based test-case index (unused)
     */
    void solve(int TC) throws Exception{
        // NOTE(review): N is read only to consume a token and is otherwise
        // unused. The setter's and testers' C++ checkers in this editorial read
        // each string with no preceding length token; confirm the judge's
        // input format.
        int N = ni();
        char[] ch = n().toCharArray();
        int c0 = 0, c1 = 0;
        for(char x : ch){
            if(x == '0') c0++;
            else if(x == '1') c1++;
        }
        for(int i = 0; i< ch.length; i++){
            if(ch[i] != '?')continue;
            if(c0 <= c1){
                ch[i] = '0';
                c0++;
            }else{
                ch[i] = '1';
                c1++;
            }
        }
        // Build the output directly from the array instead of appending char by char.
        pn(new String(ch));
    }
    //SOLUTION END
    /** Throws if the condition is false (debugging helper). */
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    /** Reads T (when multipleTC is set), runs solve() once per test case, then flushes output. */
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new BAL01().run();
    }
    /** Number of set bits in n (clears the lowest set bit until n is zero). */
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    /** Whitespace-token reader over a BufferedReader. */
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        /** Returns the next whitespace-separated token, reading more lines as needed. */
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                String line;
                try{
                    line = br.readLine();
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
                // readLine() returns null at end of input; report that clearly
                // instead of letting StringTokenizer throw a NullPointerException.
                if(line == null) throw new Exception("Unexpected end of input");
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        /** Returns the rest of the current input line (null at end of input). */
        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always. :slight_smile:

Hi,
I am not sure why my code is giving a wrong answer.
Could you please provide the test cases for this problem so that I can check where it might be failing? Thanks in advance.

/* package codechef; // don't place package name! */

import java.util.*;
import java.lang.*;
import java.io.*;

/* Name of the class has to be "Main" only if the class is public. */
class Codechef
{
	public static void main (String[] args) throws java.lang.Exception
	{
		Scanner sc = new Scanner(System.in);
		int T = sc.nextInt();
		for(int i=0;i<T;i++){
			int len = sc.nextInt();
			sc.nextLine();
			String s = sc.nextLine();
			Object result = new Codechef().balance(len,s);
			System.out.println(result);
		}
	}

	/**
	 * Replaces every '?' among the first {@code len} characters of
	 * {@code input} with '0' or '1' so that |#0 - #1| is minimised.
	 *
	 * Greedy: scan left to right and give each '?' to the digit that is
	 * currently rarer ('0' on ties). The previous version filled every '?'
	 * with the minority digit whenever the counts differed. That overshoots
	 * when there are more '?' than the initial gap, e.g. "1???" became "1000"
	 * (difference 2) instead of "1001" (difference 0).
	 *
	 * @param len   number of characters of {@code input} to process
	 * @param input pattern over '0', '1' and '?'
	 * @return the completed string
	 */
	public Object balance(int len,String input){
		char[] c = input.toCharArray();
		int zeros = 0;
		int ones = 0;
		for(int i=0;i<len;i++){
			if(c[i]=='0') zeros++;
			else if(c[i]=='1') ones++;
		}
		for(int i=0;i<len;i++){
			if(c[i]!='?') continue;
			if(zeros<=ones){
				c[i]='0';
				zeros++;
			}else{
				c[i]='1';
				ones++;
			}
		}
		return new String(c);
	}
}