# SUBARRAYREM - Editorial

Author: Jeevan Jyot Singh
Testers: Tejas Pandey, Hriday
Editorialist: Nishank Suresh

TBD

None

# PROBLEM:

You have a boolean array A. In one move, you can choose any subarray of A with length at least 2, add the bitwise XOR of its elements to your score, and delete this subarray from A (the parts to its left and right then become adjacent).

Find the maximum possible score you can achieve.

# EXPLANATION:

Since the array is boolean, each move increases our score by either zero or one.
Note that a move which increases our score by zero is pointless, so we really want to find the maximum number of moves that each increase our score by 1.
Let’s call a subarray with xor 1 a good subarray.

Further, it’s better to use shorter subarrays if possible, since deleting fewer elements leaves more freedom for future moves. The shortest possible good subarrays are [0, 1] and [1, 0], so let’s keep choosing these as long as it’s possible to do so.

Suppose we can’t choose any more subarrays of this kind. Then no 0 is adjacent to a 1, so every element of A must be the same, i.e., A consists entirely of 0's or entirely of 1's.

• In the first case, all 0's, nothing more can be done, since any remaining subarray has xor 0.
• In the second case, we still have good subarrays: any subarray of odd length (at least 3) is good, i.e., $[1, 1, 1], [1, 1, 1, 1, 1], \ldots$
• Suppose the length of A is now $K$. Every good subarray uses at least $3$ of these elements, so the best we can do is $\left\lfloor \frac{K}{3} \right\rfloor$ subarrays, each of the form $[1, 1, 1]$. So, add this value to the answer.

This gives us the final solution:

• While the array contains both 0's and 1's, some 0 is adjacent to some 1: remove such a pair (one 0 and one 1), and increase the answer by 1.
• When the array contains only a single type of element, if it’s 1 and there are $K$ of them, add $\left\lfloor \frac{K}{3} \right\rfloor$ to the answer.

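Once the counts $c_0$ and $c_1$ of 0's and 1's are known, this can be done in $\mathcal{O}(1)$: the answer is $m + \left\lfloor \frac{c_1 - m}{3} \right\rfloor$ with $m = \min(c_0, c_1)$. A direct simulation in $\mathcal{O}(N)$ will also pass.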

# TIME COMPLEXITY

$\mathcal{O}(N)$ per test case.

# CODE:

Setter's code (C++, formula)
#ifdef WTSH
#include <wtsh.h>
#else
#include <bits/stdc++.h>
using namespace std;
#define dbg(...)
#endif

#define int long long
#define endl "\n"
#define sz(w) (int)(w.size())
using pii = pair<int, int>;

// -------------------- Input Checker Start --------------------

// Input-validator helper: reads one integer from stdin that must be immediately
// followed by the character `endd`, and returns it.
// Aborts (via assert) unless the token is a well-formed integer in [l, r]:
//   - at most one leading '-', and at least one digit;
//   - no leading zeros and no "-0";
//   - at most 19 digits (with a leading 1 if exactly 19), so it fits in long long.
long long readInt(long long l, long long r, char endd)
{
    long long x = 0;
    int cnt = 0, fi = -1;   // number of digits read so far / value of the first digit
    bool is_neg = false;
    while(true)
    {
        // Keep getchar()'s result as int so EOF is never mistaken for a real byte.
        int g = getchar();
        if(g == '-')
        {
            // The sign must come before every digit, and only once.
            assert(fi == -1 && !is_neg);
            is_neg = true;
            continue;
        }
        if('0' <= g && g <= '9')
        {
            x *= 10;
            x += g - '0';
            if(cnt == 0)
                fi = g - '0';
            cnt++;
            assert(fi != 0 || cnt == 1);                    // no leading zeros
            assert(fi != 0 || is_neg == false);             // no "-0"
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));   // stays within long long
        }
        else if(g == static_cast<unsigned char>(endd))
        {
            // An empty token ("" or a lone "-") is not a number.
            assert(cnt > 0);
            if(is_neg)
                x = -x;
            if(!(l <= x && x <= r))
            {
                cerr << "L: " << l << ", R: " << r << ", Value Found: " << x << '\n';
                assert(false);
            }
            return x;
        }
        else
        {
            assert(false);
        }
    }
}

// Input-validator helper: reads characters from stdin up to (not including) the
// terminator `endd`, consumes the terminator, and returns the text read.
// Asserts that the input does not end before the terminator and that the
// number of characters read lies in [l, r].
string readString(int l, int r, char endd)
{
    string ret = "";
    int cnt = 0;
    while(true)
    {
        // Keep getchar()'s result as int: stored in a char, EOF collides with
        // byte 0xFF, and where char is unsigned a `g != -1` check never fires.
        int g = getchar();
        assert(g != EOF);
        // Also stop on EOF so the loop terminates even when asserts are disabled.
        if(g == EOF || g == static_cast<unsigned char>(endd))
            break;
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

// Thin wrappers around readInt / readString that fix the expected terminator:
// "...Sp" variants require a following space, "...Ln" variants a newline.
long long readIntSp(long long l, long long r)
{
    return readInt(l, r, ' ');
}

long long readIntLn(long long l, long long r)
{
    return readInt(l, r, '\n');
}

string readStringSp(int l, int r)
{
    return readString(l, r, ' ');
}

string readStringLn(int l, int r)
{
    return readString(l, r, '\n');
}

// Asserts that no input remains.
void readEOF()
{
    assert(getchar() == EOF);
}

// Reads n integers, each in [l, r], separated by single spaces and terminated
// by a newline, and returns them in order. For n == 0 nothing is read and an
// empty vector is returned.
vector<int> readVectorInt(int n, long long l, long long r)
{
    assert(n >= 0);
    vector<int> a(n);
    if(n == 0)
        return a;   // a[n - 1] below would be out of bounds
    for(int i = 0; i + 1 < n; i++)
        a[i] = readIntSp(l, r);
    a[n - 1] = readIntLn(l, r);
    return a;
}

// -------------------- Input Checker End --------------------

int sumN = 0;

// Reads one test case (N, then the boolean array) and prints the maximum score.
// Every [0, 1] / [1, 0] removal is worth one point, so first pair as many 0's
// with 1's as possible; any 1's left over can only score in groups of three.
void solve()
{
    const int n = readIntLn(1, 1e5);
    sumN += n;
    const vector<int> a = readVectorInt(n, 0, 1);

    const int zeros = count(a.begin(), a.end(), 0);
    const int ones = count(a.begin(), a.end(), 1);

    const int pairs = min(zeros, ones);
    const int leftoverOnes = ones - pairs;
    cout << pairs + leftoverOnes / 3 << '\n';
}

// Entry point: reads the number of test cases, solves each one, and finally
// checks the constraint on the total length of all arrays.
int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    const int testCases = readIntLn(1, 1e5);
    for(int done = 0; done < testCases; ++done)
        solve();

    assert(sumN <= 2e5);
    return 0;
}

Editorialist's code (Python, simulation)
for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    # Stack of still-unmatched elements. It always holds copies of a single value,
    # because an element that differs from the top cancels with it: that is one
    # [0, 1] / [1, 0] removal worth one point.
    stack = []
    ans = 0
    for x in a:
        if stack and stack[-1] != x:
            stack.pop()
            ans += 1
        else:
            stack.append(x)
    # What survives is all-equal; leftover 1's still score one point per [1, 1, 1].
    if stack and stack[0] == 1:
        ans += len(stack) // 3
    print(ans)

2 Likes

Q: My code is giving the wrong answer. Why?

// Reads one test case (n, then the 0/1 array) and prints the maximum score.
//
// The original attempt scanned left to right for the first window with XOR 1 and
// restarted right after it. That misses removals that only become possible once an
// inner subarray is deleted: on 0 0 1 1 it scored 1, but deleting the middle
// [0, 1] leaves [0, 1], which can be deleted too, so the answer is 2. It also
// XOR-ed a[r+1] twice when restarting. Only the counts of 0's and 1's matter,
// so count them and apply the formula instead.
void solve()
{
    long long n;
    cin >> n;
    long long zeros = 0, ones = 0;
    for (long long i = 0; i < n; ++i) {
        int x;
        cin >> x;
        if (x == 0)
            ++zeros;
        else
            ++ones;
    }
    // Each [0, 1] / [1, 0] removal scores one point...
    long long pairs = min(zeros, ones);
    // ...and the leftover 1's score one point per [1, 1, 1].
    cout << pairs + (ones - pairs) / 3 << endl;
    return ;
}

1 Like

All the accepted codes are giving wrong answers for the input 0001111.
Please review.

1 Like

Yes, I have the same doubt: when N == 1 we cannot select any L and R, since we need L < R (L != R). So for the array 0 1 0 1 1 the answer should be 1,
but the solution that got accepted gives 2.

Following is my accepted code. Though I had this solution during the contest, I kept getting WA because of the confusion above.

public static void main (String[] args) throws java.lang.Exception
{
    // For each test case: count the 1's; every 1 can be paired with a 0 when there
    // are enough 0's, otherwise pair all 0's and score leftover 1's three at a time.
    Scanner in = new Scanner(System.in);
    int T = in.nextInt();
    for (int tc = 0; tc < T; tc++) {
        int N = in.nextInt();
        int ones = 0;
        for (int i = 0; i < N; i++) {
            if (in.nextInt() != 0) {
                ones++;
            }
        }
        int zeros = N - ones;

        int ans;
        if (zeros >= ones) {
            ans = ones;
        } else {
            ans = zeros + (ones - zeros) / 3;
        }
        System.out.println(ans);
    }
}


Hi, my solution here with Rust uses this same algorithm, but I get a timeout for one test case (by a small margin). How can this be further optimized?

fn main() {
    use std::io::{self, BufWriter, Read, Write};

    // The original never read any input into `buff`, so parsing `t` panicked, and
    // deleting pairs with `Vec::remove` made each test case O(N^2). Only the counts
    // of 0's and 1's matter, so read all input at once and count in O(N).
    let mut input = String::new();
    io::stdin().read_to_string(&mut input).unwrap();
    let mut tokens = input.split_ascii_whitespace();

    let stdout = io::stdout();
    let mut out = BufWriter::new(stdout.lock());

    let t: usize = tokens.next().unwrap().parse().unwrap();
    for _ in 0..t {
        let n: usize = tokens.next().unwrap().parse().unwrap();
        let mut zeros: u64 = 0;
        let mut ones: u64 = 0;
        for token in tokens.by_ref().take(n) {
            if token == "0" {
                zeros += 1;
            } else {
                ones += 1;
            }
        }
        // Each [0, 1] / [1, 0] removal scores 1; leftover 1's score 1 per [1, 1, 1].
        let pairs = zeros.min(ones);
        let score = pairs + (ones - pairs) / 3;
        writeln!(out, "{}", score).unwrap();
    }
    out.flush().unwrap();
}


I guess I needn’t have removed any elements. I think that made it O(n²) instead of linear?

Can anyone please tell me the test case on which my solution fails?

// Reads one test case (n, then the 0/1 array) and prints the maximum score.
//
// The original greedy only paired elements that were adjacent in the ORIGINAL
// array and never revisited elements that become adjacent after a deletion.
// Failing test: 0 0 1 1. It removed the middle [0, 1] and stopped with score 1,
// but the remaining 0 and 1 are then adjacent and can be removed too (answer 2).
// The pasted code was also missing one closing brace.
// Only the counts of 0's and 1's matter, so count them and apply the formula.
void solve(){
    int n;
    cin >> n;
    long long zeros = 0, ones = 0;
    for (int i = 0; i < n; i++) {
        int x;
        cin >> x;
        if (x == 0)
            zeros++;
        else
            ones++;
    }
    // Each [0, 1] / [1, 0] removal scores one point; leftover 1's score one
    // point per [1, 1, 1].
    long long pairs = min(zeros, ones);
    long long ans = pairs + (ones - pairs) / 3;
    cout << ans << '\n';
}

No, it gives the correct answer, which is 3. First, it takes the [0, 1] formed by the third and fourth elements (0 0 [0 1] 1 1 1), adds 1 to the score, and removes that subarray from the array.
The array becomes 0 0 1 1 1, and the same process repeats: 0 0 1 1 1 → 0 1 1 → 1, adding 1 each time, for a total score of 3.

Yes, directly removing an element from the middle of an array generally takes $\mathcal{O}(N)$ time, which makes your solution $\mathcal{O}(N^2)$.

It’s definitely possible to simulate the process in $\mathcal{O}(N)$ or similar, though you need to be a bit smart about it. One way of doing it can be seen in my Python code linked in the editorial.

Count the number of zeroes (z) and the number of ones (o) in the input.

If (z>=o) answer = o.
If (z < o) answer = z +((o-z)/3)

If z >= o, it’s always possible to pair every one with a zero by repeatedly choosing a subarray (01) or (10) and deleting it.

If z < o, then every zero can be paired with a one in the same way, and the leftover ones can be grouped in triplets, each adding 1 to the score.

My solution: CodeChef

Simple python code I submitted during contest

for _ in range(int(input())):
    n = int(input())
    arr = list(map(int, input().split()))
    ones = arr.count(1)
    zeros = arr.count(0)
    if ones > n // 2:
        # More 1's than 0's: pair every 0 with a 1, then score leftover 1's in triples.
        print(zeros + (ones - zeros) // 3)
    else:
        # Every 1 can be paired with a distinct 0.
        print(ones)