XOREQN - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Daanish Mahajan and Surya Prakash
Tester: Samarth Gupta
Editorialist: Taranpreet Singh

DIFFICULTY

Easy

PREREQUISITES

Bitwise operations

PROBLEM

You are given an array A of N non-negative integers, where N is odd. Find the minimum non-negative integer x that satisfies the equation

(A_1 + x) \oplus (A_2 + x) \oplus \dots \oplus (A_N + x) = 0

where \oplus denotes the bitwise XOR operation. If no such x exists, print -1.
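To make the definition concrete, here is a brute-force reference that tries x = 0, 1, 2, \dots up to a caller-chosen bound. It is only a sketch for checking tiny inputs by hand, not a solution. The name bruteForce and the limit parameter are hypothetical. A result of -1 here only means that no x below limit works.

```cpp
#include <cassert>
#include <vector>

// Hypothetical brute-force checker for tiny inputs (not a real solution).
// Tries x = 0, 1, ..., limit - 1 and returns the first x for which
// (A_1 + x) ^ (A_2 + x) ^ ... ^ (A_N + x) == 0.
// Returns -1 if no such x exists below `limit`.
long long bruteForce(const std::vector<long long>& A, long long limit) {
    for (long long x = 0; x < limit; ++x) {
        long long c = 0;
        for (long long a : A) c ^= a + x;  // XOR of all A_i + x
        if (c == 0) return x;
    }
    return -1;
}
```

For example, A = {0, 1, 2} gives x = 1, because 1 \oplus 2 \oplus 3 = 0.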

QUICK EXPLANATION

  • Process the bits from lowest to highest, deciding for each bit b whether it should be set in x.
  • When bit b is processed, all lower bits of x are already fixed, so the b^{th} bit of every A_i + x is known (higher bits of x cannot affect it). If an odd number of these values have the b^{th} bit set, the b^{th} bit of x must be set.
  • If (A_1 + x) \oplus (A_2 + x) \oplus \dots \oplus (A_N + x) = 0 still does not hold after all bits are processed, no valid x exists.

EXPLANATION

In most problems on bitwise operations, it is a good approach to try to handle each bit separately and combine the answer.

Simpler problem

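Consider a simpler problem: given an array A of N elements, where N is odd, find x such that (A_1 \oplus x) \oplus (A_2 \oplus x) \oplus \dots \oplus (A_N \oplus x) = 0. In other words, we want to make the bitwise XOR of the array zero by XOR-ing every element with the same x.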

Let C = \bigoplus_{i = 1}^N A_i be the bitwise XOR of all elements of A. Suppose the lowest bit of C is set. This happens exactly when an odd number of elements of A have their lowest bit set. We then need to flip this bit in every element, i.e. set the lowest bit of x.

Since N is odd and the number k of elements with the lowest bit set is also odd, exactly N - k elements have that bit set after it is flipped in every element. N - k is even, so the lowest bit of the XOR becomes zero.

The same argument applies to every bit, and the decision for each bit is independent of the others.
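Because N is odd, x appears an odd number of times on the left-hand side. So (A_1 \oplus x) \oplus \dots \oplus (A_N \oplus x) = C \oplus x, and the bit-by-bit rule above collapses to x = C. Below is a minimal sketch of this simpler variant. The function names are hypothetical, chosen only for illustration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: XOR of (a ^ x) over all a in A.
long long xorAll(const std::vector<long long>& A, long long x) {
    long long c = 0;
    for (long long a : A) c ^= a ^ x;
    return c;
}

// Simpler (XOR-only) variant, assuming A.size() is odd.
// x occurs an odd number of times, so xorAll(A, x) == C ^ x,
// where C is the XOR of all elements; hence x = C.
long long solveXorVariant(const std::vector<long long>& A) {
    long long c = 0;
    for (long long a : A) c ^= a;
    return c;
}
```

For A = {1, 2, 4} this gives x = 7, and (1 \oplus 7) \oplus (2 \oplus 7) \oplus (4 \oplus 7) = 6 \oplus 5 \oplus 3 = 0.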

Original Problem

The only difference between the simpler problem and the given one is that the given problem uses addition instead of \oplus.

XOR and addition are closely related: XOR is addition mod 2 performed on each bit independently. The only difference is that addition can carry from lower bits into higher bits.
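To see that addition is XOR plus carries, the sketch below computes a sum using only XOR, AND and shifts. Here a ^ b is the carry-free per-bit sum, and (a & b) << 1 holds the carries, which always move one position toward the more significant bits. The function name addWithCarries is hypothetical. It uses unsigned 64-bit integers so that a carry shifted out of the top bit is well defined.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration: addition = XOR (sum without carry)
// plus carries that are pushed one bit upward each round.
uint64_t addWithCarries(uint64_t a, uint64_t b) {
    while (b != 0) {
        uint64_t carry = (a & b) << 1;  // positions that produce a carry, moved up by one
        a ^= b;                         // per-bit sum, ignoring carries
        b = carry;                      // add the carries in the next round
    }
    return a;
}
```

For example, addWithCarries(5, 3) returns 8.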

Suppose we have already processed bits 0 to b-1 and found x_0 such that the lowest b bits of \bigoplus_{i = 1}^N (A_i + x_0) are zero. We now want x such that the lowest b+1 bits of \bigoplus_{i = 1}^N (A_i + x) are zero.

Since the lowest b bits are already handled, we should not choose an x that disturbs them, so x and x_0 differ only by a multiple of 2^b. Adding a multiple of 2^b changes no bit below position b, because carries only move toward higher bits.
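A quick sanity check of that claim: adding k \cdot 2^b to A_i + x_0 never changes its lowest b bits. The helpers lowMask and lowBitsUnchanged are hypothetical names used only for this illustration. They assume unsigned arithmetic and 0 \le b < 64.

```cpp
#include <cassert>
#include <cstdint>

// Mask selecting the lowest b bits (assumes 0 <= b < 64).
uint64_t lowMask(int b) { return (uint64_t{1} << b) - 1; }

// True when adding k * 2^b to (a + x0) leaves its lowest b bits unchanged.
// k << b has no set bits below position b, and carries only move upward,
// so the low bits of the sum equal the low bits of (a + x0).
bool lowBitsUnchanged(uint64_t a, uint64_t x0, uint64_t k, int b) {
    uint64_t before = a + x0;
    uint64_t after = before + (k << b);
    return (before & lowMask(b)) == (after & lowMask(b));
}
```

For example, with a = 13, x_0 = 6, b = 3 and k = 5, the values are 19 and 59, and both end in the bits 011.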

Let C = \bigoplus_{i = 1}^N (A_i + x_0). Two cases arise:

  • If the b^{th} bit of C is 0, nothing needs to be done, so x = x_0.
  • Otherwise, set x = x_0 + 2^b. Adding 2^b to every A_i + x_0 flips the b^{th} bit of each of them (any carry goes only into higher bits). Since N is odd, the b^{th} bit of the XOR flips from 1 to 0.

Since carries only propagate from less significant to more significant bits, we can repeat this process from the lowest bit to the highest, updating x whenever needed.
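Putting the steps together, here is a minimal C++ sketch of the greedy. It mirrors the editorialist's Java solution, which also processes 62 bits and recomputes the XOR for each bit. The names xorAfterAdd and solveXorEqn are hypothetical. The sketch assumes 0 \le A_i \le 10^{18}, so no sum overflows.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: XOR of (a + x) over all a in A.
long long xorAfterAdd(const std::vector<long long>& A, long long x) {
    long long c = 0;
    for (long long a : A) c ^= a + x;
    return c;
}

// Greedy: fix the bits of x from lowest to highest.
// Returns -1 if the XOR is still nonzero after 62 bits.
long long solveXorEqn(const std::vector<long long>& A) {
    const int kBits = 62;  // same bound as the editorialist's solution
    long long x = 0;
    for (int b = 0; b < kBits; ++b) {
        long long c = xorAfterAdd(A, x);   // bits 0..b-1 of c are already zero
        if ((c >> b) & 1) x |= 1LL << b;   // bit b of x is 0 here, so |= adds exactly 2^b
    }
    return xorAfterAdd(A, x) == 0 ? x : -1;
}
```

Recomputing the XOR once per bit costs O(62 \cdot N) per test case. The tester's solution instead adds 2^b to every element directly and has the same complexity.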

Why won’t this overflow?

Processing the first 62 bits is sufficient. That way X stays below 2^{62}, so A_i + X < 2^{62} + 10^{18} < 2^{63}, which fits in a signed 64-bit integer (long in Java, long long in C++).
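The bound can also be checked at compile time. The sketch below takes A_i \le 10^{18} from the constraints enforced by the tester's input reader and X \le 2^{62} - 1 after 62 bits. The constant names are hypothetical.

```cpp
#include <cassert>
#include <limits>

// Bounds from the overflow argument (names are illustrative).
constexpr long long kMaxA = 1000000000000000000LL;  // A_i <= 10^18
constexpr long long kMaxXExclusive = 1LL << 62;      // X < 2^62 after 62 bits

// (2^62 - 1) + 10^18 must not exceed the largest long long.
static_assert(kMaxA <= std::numeric_limits<long long>::max() - (kMaxXExclusive - 1),
              "A_i + X could overflow long long");
```

The largest possible sum is 5611686018427387903, well below 2^{63} - 1 = 9223372036854775807.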

TIME COMPLEXITY

The time complexity is O(N*log(max(A_i))) per test case.

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;

void solve(){
    int n;
    cin >> n;
    vector<long long> A(n);
    for(int i = 0; i < n; i++){
	    cin >> A[i];
    }
    vector<long long> B = A;
    int iter = 0;
    long long answer = 0, mul = 1;
    while(true){
	    int cnt0 = 0;
	    int cnt1 = 0;
	    int cntx = 0;
	    long long mxv = 0;
	    for(int i = 0; i < n; i++){
		    if(A[i]&1)cnt1++;
		    if(A[i] == 0)cnt0++;
		    if(A[i] == A[0])cntx++;
		    mxv = max(mxv, A[i]);
	    }
	    if(cnt0 == n)break;
	    if(cntx == n){
		    cout << -1 << '\n';
		    return;
	    }
	    int r = cnt1&1;
	    if(r == 1){
		    if(mxv == 1){
			    cout << -1 << '\n';
			    return;
		    }
		    answer += mul;
	    }
	    for(int i = 0; i < n; i++){
		    A[i] = (A[i] + r) >> 1;
	    }
	    mul <<= 1;
	    iter++;
    }
    long long val = 0;
    for(long long &x : B){
	    val = val^(answer + x);
    }
    assert(val == 0);
    cout << answer << '\n';
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int t;
    cin >> t;
    for(int i = 0; i < t; i++){
	    solve();
    }

    return 0;
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;

long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}
 
void readEOF(){
    assert(getchar()==EOF);
}

int main() {
    // your code goes here
    int t;
    t = readIntLn(1, 1e5);
    int sum = 0;
    while(t--){
        int n = readIntLn(1, 1e6);
        sum += n;
        assert(sum <= 1e6);
        vector<long long> vec(n);
        for(int i = 0; i < n ; i++){
            if(i == n - 1){
                vec[i] = readIntLn(0, 1e18);
            }
            else{
                vec[i] = readIntSp(0, 1e18);
            }
        }
        long long x = 0;
        for(int bit = 0; bit < 61; bit++){
            int cnt = 0;
            for(int i = 0; i < n ; i++){
                if((vec[i] >> bit)&1)
                    cnt++;
            } 
            if(cnt%2 == 0)
                continue;
            x += (1ll << bit);
            for(int i = 0; i < n ; i++)
                vec[i] += (1ll << bit);
        }
        long long ok = 0;
        for(int i = 0; i < n ; i++)
            ok ^= vec[i];
        if(ok == 0)
            cout << x << '\n';
        else
            cout << -1 << '\n';
    }
    readEOF();
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class Main{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), B = 62;
        long[] A = new long[N];
        for(int i = 0; i< N; i++)A[i] = nl();
        long X = 0;
        for(int b = 0; b< B; b++){
            long C = f(A, X);
            if(((C>>b)&1) == 1)X |= 1L<<b;
        }
        if(f(A, X) == 0)pn(X);
        else pn(-1);
    }
    long f(long[] A, long X){
        long C = 0;
        for(long x:A)C ^= x+X;
        return C;
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Main().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.

Can anyone please help me with a basic doubt?
Correct: Solution: 53881617 | CodeChef
Partially correct: Solution: 53881559 | CodeChef
The only difference between the two codes is at line 8.
When I use long long it gives the correct answer, but with int it gives a wrong answer. Please tell me why. The max values of x and y can be 2 and 60 respectively.


What is the significance of saying "find the minimum non-negative integer" when only a single integer satisfies the equation?

#include <iostream>
using namespace std;
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    while (t--)
    {
        long long int n,i,k,j;
        cin>>n;
        long long int temp;
        long long int carry[n];
        long long int binary[n][64];
        for(i=0;i<64;i++)
        {
            for(j=0;j<n;j++)
            {
                binary[j][i]=0;
            }
        }
        for(i=0;i<n;i++)
        {
            carry[i]=0;
            cin>>temp;
            k=0;
            while(temp)
            {
                binary[i][k]=temp%2;
                k++;
                temp=temp/2;
            }
        }
        long long int count=0,flag=0;
        long long int add=1,answer=0;
        for(i=0;i<64;i++)
        {

            count=0;
            for(j=0;j<n;j++)
            {
                binary[j][i]=binary[j][i]+carry[j];
                if(binary[j][i]==2)
                {
                    binary[j][i]=0;
                    carry[j]=1;
                }
                else
                {
                    carry[j]=0;
                }
                count=count+binary[j][i];
            }
            if(count%2==0)
            {
                add=add*2;
                continue;
            }
            else
            {
                if(i>62)
                {
                    flag=1;
                    break;
                }
                answer=answer+add;
                for(j=0;j<n;j++)
                {
                    if(binary[j][i]+1==2)
                    {
                        carry[j]++;
                    }
                }
            }
            add=add*2;
        }
        if(flag==0)
        {
            cout<<answer<<endl;
        }
        else
        {
            cout<<-1<<endl;
        }
    }
    return 0;
}

Why is this code giving TLE?
Its complexity is O(n*64*2).

What is the meaning of this statement?

@arpit92_8
Suppose the number of elements with the last bit set is 5 (odd) and the total number of elements is 11 (odd). When you flip that bit in every element (1 -> 0, 0 -> 1), the number of elements with the last bit set becomes 11 - 5 = 6. An odd total minus an odd count is always even.


This question really made me realize that I need to learn bit manipulation more. It gave me a hard time during the contest, and I still wasn't able to solve it.

Easy implementation

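#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const ll MAX = 4e18;  // loop stops at p = 2^62, so bits 0..61 are processed
void solve(){
    int n; cin >> n;
    vector<ll> arr(n);  // vector instead of a variable-length array (n can be large)
    for(int i=0;i<n;i++) cin >> arr[i];
    ll temp = arr[0];  // original A_1, used to recover x at the end

    ll p = 1;
    while(p < MAX){
        int cnt = 0;
        for(int i=0;i<n;i++){
            if(p&arr[i]){
                cnt++;
            }
        }
        // odd count at this bit: add p to every element (sets this bit of x)
        if(cnt%2){
            for(int i=0;i<n;i++) arr[i] += p;
        }
        p <<= 1;
    }
    ll XOR = 0;
    for(int i=0;i<n;i++){
        XOR ^= arr[i];
    }
    ll ans = arr[0] - temp;  // total amount added to every element
    if(XOR == 0){
        cout << ans << '\n';
    }else{
        cout << -1 << '\n';
    }
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int t; cin >> t;
    while(t--) solve();
    return 0;
}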

Thanks for replying and helping. Can you tell me how you arrived at this complexity?

For each of the n numbers we compute its binary digits, which takes time logarithmic in its value. Try working it out yourself and it will become clear.

I'm not sure this applies to the CodeChef evaluator, and I didn't look very closely at your code. But traversing columns first and then rows (i.e. binary[j][i] with j in the inner loop) can significantly affect your runtime.

I have a similar solution. Please let me know if you find the reason.