SUBARRAYREM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Jeevan Jyot Singh
Testers: Tejas Pandey, Hriday
Editorialist: Nishank Suresh

DIFFICULTY:

TBD

PREREQUISITES:

None

PROBLEM:

You have a boolean array A. In one move, you can choose any subarray of A with length at least 2, add its bitwise XOR to your score, and delete this subarray from A.

Find the maximum possible score you can achieve.

EXPLANATION:

Since the array is boolean, each move increases our score by either zero or one.
Note that making a move that increases our score by zero is pointless, so we really want to count the maximum number of 1 moves that we can make.
Let’s call a subarray with xor 1 a good subarray.
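For 0/1 values, the XOR of a subarray is just the parity of its count of 1's, which is why a move scores either 0 or 1. Below is a minimal sketch that checks this. `subarray_xor` and `is_good` are hypothetical helper names, and indices are assumed to be 0-based and inclusive.

```python
from functools import reduce
from operator import xor


def subarray_xor(a, l, r):
    """XOR of a[l..r] (0-indexed, inclusive)."""
    return reduce(xor, a[l:r + 1], 0)


def is_good(a, l, r):
    """True if a[l..r] is a legal move (length >= 2) whose XOR is 1."""
    return r - l + 1 >= 2 and subarray_xor(a, l, r) == 1


a = [1, 0, 1, 1]
# For 0/1 values, the XOR of a subarray equals the parity of its number of 1's.
for l in range(len(a)):
    for r in range(l, len(a)):
        assert subarray_xor(a, l, r) == sum(a[l:r + 1]) % 2
print(is_good(a, 0, 1), is_good(a, 0, 2), is_good(a, 0, 3))  # True False True
```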

Further, it’s better to use shorter subarrays if possible, since that gives us more freedom in the future. The shortest possible good subarrays are [0, 1] and [1, 0], so let’s keep choosing these as long as it’s possible to do so.
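Finding such a pair just means scanning for two adjacent unequal elements. The exhaustive check below is a sketch (`find_adjacent_pair` is a hypothetical helper name) that confirms, on every 0/1 array of length 1 to 10, that a pair [0, 1] or [1, 0] exists exactly when the array contains both values.

```python
from itertools import product


def find_adjacent_pair(a):
    """Return an index i with a[i] != a[i + 1], or None if there is none."""
    for i in range(len(a) - 1):
        if a[i] != a[i + 1]:
            return i
    return None


# Exhaustive check: an adjacent unequal pair exists exactly when
# both 0 and 1 occur somewhere in the array.
for n in range(1, 11):
    for bits in product((0, 1), repeat=n):
        has_both = 0 in bits and 1 in bits
        assert (find_adjacent_pair(bits) is not None) == has_both
```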

Suppose we can’t choose any more subarrays of this kind. Then every element of A must be the same (if both values were present, some 0 would be adjacent to some 1), i.e., A consists of all 0's or all 1's.

  • In the first case, all 0's, nothing more can be done, since any remaining subarray has xor 0.
  • In the second case, we still have good subarrays: exactly those of odd length, i.e., [1, 1, 1], [1, 1, 1, 1, 1], \ldots
    • Suppose the length of A is now K. Since a move needs length at least 2, every good subarray here has length at least 3, so at most \lfloor \frac{K}{3} \rfloor moves can score. Choosing \lfloor \frac{K}{3} \rfloor subarrays of the form [1, 1, 1] achieves this, so add this value to the answer.
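The \lfloor \frac{K}{3} \rfloor claim for an all-ones array can be checked by brute force. Deleting any block of length L from K contiguous ones leaves K - L contiguous ones and scores L mod 2, so a small recursion over L covers every possible play. This is a sketch; `best_all_ones` is a hypothetical name.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def best_all_ones(k):
    """Brute-force maximum score for an array consisting of k ones."""
    best = 0  # making no move at all is always allowed
    # Deleting a block of length L (L >= 2) leaves k - L ones, still contiguous;
    # the move scores L % 2, the XOR of L ones.
    for length in range(2, k + 1):
        best = max(best, length % 2 + best_all_ones(k - length))
    return best


for k in range(40):
    assert best_all_ones(k) == k // 3
```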

This gives us the final solution:

  • While the array contains both 0's and 1's, remove an adjacent pair [0, 1] or [1, 0] (one always exists), and increase the answer by 1.
  • When the array contains only a single value, if it’s 1 and there are K of them, add \lfloor \frac{K}{3} \rfloor to the answer.

This can be done in \mathcal{O}(1) by knowing the counts of 0's and 1's, although simulation in \mathcal{O}(N) will still pass.
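As a sanity check (not a proof), the count formula can be compared against an exhaustive search over all sequences of moves on every 0/1 array of length at most 8. This sketch uses hypothetical helper names `formula` and `brute`; `brute` tries every subarray of length at least 2 and memoizes on the remaining array.

```python
from functools import lru_cache
from itertools import product


def formula(a):
    """Greedy answer from counts: pair 0's with 1's, then triples of 1's."""
    c0, c1 = a.count(0), a.count(1)
    pairs = min(c0, c1)
    return pairs + (c1 - pairs) // 3


@lru_cache(maxsize=None)
def brute(a):
    """Exhaustive maximum score; a is a tuple of 0/1 values."""
    best = 0  # stopping is always allowed
    for i in range(len(a)):
        x = a[i]
        for j in range(i + 1, len(a)):
            x ^= a[j]  # x is now the XOR of a[i..j], which has length >= 2
            best = max(best, x + brute(a[:i] + a[j + 1:]))
    return best


for n in range(1, 9):
    for bits in product((0, 1), repeat=n):
        assert brute(bits) == formula(list(bits)), bits
```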

TIME COMPLEXITY

\mathcal{O}(N) per test case.

CODE:

Setter's code (C++, formula)
#ifdef WTSH
    #include <wtsh.h>
#else
    #include <bits/stdc++.h>
    using namespace std;
    #define dbg(...)
#endif

#define int long long
#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

// -------------------- Input Checker Start --------------------

long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;
    bool is_neg = false;
    while(true)
    {
        char g = getchar();
        if(g == '-')
        {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if(g == endd)
        {
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << "L: " << l << ", R: " << r << ", Value Found: " << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}

string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        char g = getchar();
        assert(g != -1);
        if(g == endd)
            break;
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
void readEOF() { assert(getchar() == EOF); }

vector<int> readVectorInt(int n, long long l, long long r)
{
    vector<int> a(n);
    for(int i = 0; i < n - 1; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

// -------------------- Input Checker End --------------------

int sumN = 0;

void solve()
{
    int n = readIntLn(1, 1e5);
    sumN += n;
    vector<int> a = readVectorInt(n, 0, 1);
    int cnt0 = count(a.begin(), a.end(), 0);
    int cnt1 = count(a.begin(), a.end(), 1);
    int take = min(cnt0, cnt1);
    int ans = take;
    cnt0 -= take, cnt1 -= take;
    ans += cnt1 / 3;
    cout << ans << endl;
}

int32_t main()
{
    ios::sync_with_stdio(0); 
    cin.tie(0);
    int T = readIntLn(1, 1e5);
    for(int tc = 1; tc <= T; tc++)
    {
        // cout << "Case #" << tc << ": ";
        solve();
    }
    assert(sumN <= 2e5);
    readEOF();
    return 0;
}
Editorialist's code (Python, simulation)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    cur = []
    ans = 0
    for x in a:
        if len(cur) == 0 or x == cur[-1]: cur.append(x)
        else:
            ans += 1
            cur.pop()
    if len(cur) > 0 and cur[0] == 1:
        ans += len(cur)//3
    print(ans)

Q: My code is giving a wrong answer. Why?

void solve()
{
    ll n;
    cin >> n;
    vector<ll> a(n, 0);
    for (int i = 0; i < n; ++i)
        cin >> a[i];
    int c = 0;
    int l = 0;
    int r = 1;
    int xe = a[0];
    while (r < n) {
        xe = xe ^ a[r];
        if (xe == 1 && l != r) {
            c += 1;
            if (r + 1 < n)
                xe = a[r + 1];
            else
                xe = 0;
            l = r + 1;
        }
        r++;
    }
    cout << c << endl;
    return;
}


All the accepted codes are giving wrong answers for the input 0001111.
Please review.


Yes, I have the same doubt: when N == 1 we cannot select any L and R, since L < R (L != R). So for the array 0 1 0 1 1 the answer should be 1, but the accepted solution gives 2.

Following is my accepted code. Though I had this solution during the contest, I kept getting WA because of the above confusion.

import java.util.Scanner;

// Class wrapper added so the snippet compiles on its own; the name Main is arbitrary.
class Main {
    public static void main(String[] args) throws java.lang.Exception {
        Scanner in = new Scanner(System.in);
        int T = in.nextInt();
        while (T-- > 0) {
            int N = in.nextInt();
            int[] arr = new int[N];
            int cnt0 = 0;
            int cnt1 = 0;
            for (int i = 0; i < N; i++) {
                arr[i] = in.nextInt();
                if (arr[i] == 0) {
                    cnt0++;
                } else {
                    cnt1++;
                }
            }

            int ans = 0;
            if (cnt0 >= cnt1) {
                ans = cnt1;
            } else {
                ans += cnt0;
                cnt1 -= cnt0;
                ans += cnt1 / 3;
                // cnt1 %= 3;
                // if (cnt1 == 1 && N != 1) {
                //     ans--;
                // }
            }
            System.out.println(ans);
        }
    }
}
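On the 0 1 0 1 1 question above: the L < R condition applies to each individual move, but the array shrinks after every move and the next move is chosen on the new array. Assuming the move rule from the problem statement (delete a subarray of length at least 2 and gain its XOR), the trace below reaches a score of 2 using only length-2 moves. `apply_move` is a hypothetical helper written for this illustration.

```python
def apply_move(a, l, r):
    """Delete a[l..r] (0-indexed, inclusive) and return (new_array, gain)."""
    assert l < r, "a move needs a subarray of length at least 2"
    gain = 0
    for x in a[l:r + 1]:
        gain ^= x  # XOR of the deleted subarray
    return a[:l] + a[r + 1:], gain


a = [0, 1, 0, 1, 1]
score = 0
a, gain = apply_move(a, 0, 1)  # delete [0, 1]; a becomes [0, 1, 1]
score += gain
a, gain = apply_move(a, 0, 1)  # delete [0, 1]; a becomes [1]
score += gain
print(a, score)  # [1] 2
```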

Hi, my solution here with Rust uses this same algorithm, but I get a timeout for one test case (by a small margin). How can this be further optimized?

fn main() {
    let mut buff = String::new();
    std::io::stdin().read_line(&mut buff).unwrap();
    let t = buff.trim().parse::<u32>().unwrap();
    for _ in 0..t {
        buff.clear();
        std::io::stdin().read_line(&mut buff).unwrap();
        buff.clear();
        std::io::stdin().read_line(&mut buff).unwrap();
        let mut a = buff
            .trim()
            .chars()
            .filter(|x| !x.is_whitespace())
            .map(|x| if x == '0' { false } else { true })
            .collect::<Vec<bool>>();
        let mut score = 0;
        'out: loop {
            if a.iter().all(|x| *x == false) {
                break;
            } else if a.iter().all(|x| *x == true) {
                if a.len() > 2 {
                    score += a.len() as u32 / 3;
                    break;
                } else {
                    break;
                }
            } else {
                //println!("{:?}", a);
                let mut i = 0;
                loop {
                    if i > a.len() - 2 {
                        break;
                    }
                    //println!("x{i}");
                    if a[i] != a[i + 1] {
                        a.remove(i);
                        a.remove(i);
                        score += 1;
                        //println!("{:?}", a);
                        continue 'out;
                    }
                    i += 1;
                }
            }
        }
        println!("{}", score);
    }
}

I guess I needn’t have removed any elements. I think that made it O(n²) instead of linear?

Can anyone please tell me the test case on which my solution fails?

void solve(){
    int n;
    cin >> n;

    int a[n + 1];
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }

    int ans = 0;
    int carry = 1;
    vector<int> mark(n + 1, -1);

    for (int i = 0; i < n; i++) {
        if ((i + 1) < n) {
            if (a[i] == 1 && a[i + 1] == 0 && mark[i] == -1 && mark[i + 1] == -1) {
                ans++;
                mark[i] = 0;
                mark[i + 1] = 0;
            }
            else if (a[i] == 1 && a[i + 1] == 1 && mark[i] == -1) {
                carry++;
                mark[i] = 0;
            }
            else if (a[i] == 0 && a[i + 1] == 1 && mark[i] == -1 && mark[i + 1] == -1) {
                mark[i + 1] = 0;
                mark[i] = 0;
                ans++;
            }

            if (carry % 2 != 0 && carry > 1) {
                carry = 1;
                mark[i] = 0;
                ans++;
            }

            // cout << carry << nl;
        }
    }

    cout << ans << nl;
}

No, it gives the correct answer, which is 3.
First it takes the subarray [0, 1] in 0 0 (0 1) 1 1 1, adds 1 to the score, and removes that subarray from the array.
Now the array becomes 0 0 1 1 1, and the same process is repeated: 0 0 1 1 1 → 0 1 1 → 1, scoring 1 each time, for a total of 3.

Yes, directly removing an element from the middle of an array is generally going to require \mathcal{O}(N) time, which makes your solution \mathcal{O}(N^2).

It’s definitely possible to simulate the process in \mathcal{O}(N) or similar, though you need to be a bit smart about it. One way of doing it can be seen in my Python code linked in the editorial.
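Another way to keep the "delete an adjacent unequal pair" idea from the Rust code while staying linear is to store the array as a doubly linked list, so each deletion is \mathcal{O}(1). After deleting a pair, step back one node, because its left and right neighbours have just become adjacent. This is a sketch, not the editorialist's code; `simulate_linked`, `prv` and `nxt` are hypothetical names.

```python
def simulate_linked(a):
    """Repeatedly delete an adjacent unequal pair, using index arrays as a
    doubly linked list so each deletion is O(1); amortized O(N) overall."""
    n = len(a)
    prv = [i - 1 for i in range(n)]  # -1 means "no left neighbour"
    nxt = [i + 1 for i in range(n)]  # n means "no right neighbour"
    score = 0
    i = 0  # invariant: the list up to node i has no adjacent unequal pair
    while i < n and nxt[i] < n:
        j = nxt[i]
        if a[i] != a[j]:
            # Unlink i and j by splicing their neighbours p and q together.
            p, q = prv[i], nxt[j]
            if p >= 0:
                nxt[p] = q
            if q < n:
                prv[q] = p
            score += 1
            # Step back one node, since p and q are now adjacent.
            i = p if p >= 0 else q
        else:
            i = j
    # Each deletion removed exactly one 1, and the survivors are all equal,
    # so any remaining 1's form a single all-ones block.
    ones_left = sum(a) - score
    return score + ones_left // 3


print(simulate_linked([0, 0, 0, 1, 1, 1, 1]))  # 3
```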

Count the number of zeroes z and the number of ones o in the input array.

If (z >= o), answer = o.
If (z < o), answer = z + ((o - z) / 3), using integer division.

If z >= o, it’s always possible to pair every one with a zero by repeatedly choosing a subarray (01) or (10) and deleting it.

If z < o, then every zero can be paired with a one in the same way, and the o - z leftover ones can be grouped into triplets, each scoring 1.

My solution: CodeChef

Simple Python code I submitted during the contest:

for _ in range(int(input())):
    n = int(input())
    arr = list(map(int, input().split()))
    cnt0 = arr.count(0)
    cnt1 = arr.count(1)
    if(cnt1 <= n//2):
        print(cnt1)
        
    else:
        ans = cnt0
        ans += (cnt1-cnt0)//3
        print(ans)