INC0XOR - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Alex
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

3949

PREREQUISITES:

XOR

PROBLEM:

You are given an array A of N non-negative integers A_1, A_2, \ldots, A_N. In one operation, you can increase any one of the elements by 1.

Find the minimum number of operations needed to make the XOR of all elements (A_1 \oplus A_2 \oplus \ldots \oplus A_N) equal to 0. Here \oplus denotes the bitwise XOR operation.

EXPLANATION:

Let X be the XOR of all elements. If X is 0, the answer is 0. Otherwise, we need to apply at least one operation.

We first solve the problem in the case where at least one element of A is 0. Assume the array A is sorted. We have the following observation:

  • After we perform the operations, we can assume the array is still sorted. Let B be the final array. If some A_i < A_j ended up as B_i > B_j, we could instead transform A_i to B_j and A_j to B_i. This is valid because B_i > B_j \ge A_j > A_i, and it changes neither the XOR nor the total number of operations.

Noting this, there is a greedy strategy to solve the subproblem above:

  • Let P be the highest set bit of X. Among all elements of A whose P-th bit is zero, we take one whose lowest P bits (the bits below the P-th) form the largest value, then add just enough that its P-th bit becomes 1 and all lower bits become 0. We repeat this until X becomes 0.

This greedy strategy works because:

  • There is no reason for us to make a bit higher than P appear in X.
  • Therefore, we only consider bits P and below of every value. Our goal now is to transform some value into 2^P.
  • By the previous observation, we take the largest value that is less than 2^P and transform it into 2^P.
  • Such an element always exists, because A contains 0.
  • Additionally, after this step all bits below the P-th bit of the transformed element are 0, so A still contains an element that acts as 0 in the following steps.

Therefore we can solve this subproblem in O(N \log \max(A)).
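The greedy above can be sketched as follows. This is a minimal illustration rather than any of the reference solutions: the names `greedy_with_zero` and `kTopBit` are made up here, and the code assumes that every value is below 2^60 and that the array really contains a 0.

```cpp
#include <cassert>
#include <vector>

// Greedy for the special case where the array contains at least one 0.
// Illustrative sketch: assumes all values are below 2^60.
long long greedy_with_zero(std::vector<long long> a) {
    const int kTopBit = 59;  // assumed upper bound on the highest bit
    long long x = 0;
    bool has_zero = false;
    for (long long v : a) {
        x ^= v;
        has_zero = has_zero || (v == 0);
    }
    assert(has_zero);  // precondition of this special case

    long long ops = 0;
    for (int p = kTopBit; p >= 0; --p) {
        if (!((x >> p) & 1)) continue;  // bit p of the XOR is already 0
        const long long mask = (1LL << p) - 1;
        int best = -1;
        for (int i = 0; i < (int)a.size(); ++i) {
            if ((a[i] >> p) & 1) continue;  // bit p must currently be 0
            if (best == -1 || (a[i] & mask) > (a[best] & mask)) best = i;
        }
        assert(best != -1);  // guaranteed by the 0 element
        const long long low = a[best] & mask;
        const long long delta = (1LL << p) - low;  // sets bit p, clears lower bits
        ops += delta;
        a[best] += delta;
        x ^= low ^ (1LL << p);
    }
    return ops;
}
```

For example, on A = [0, 5, 6] we have X = 3, so P = 1. The elements with bit 1 equal to zero are 0 and 5, and 5 has the larger low bit, so it is raised to 6 with one operation, giving [0, 6, 6] with XOR 0.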

In the general case, however, we don’t have a 0 element in A. Let’s make one artificially! Observe that there must be some highest bit at which a 0 bit gets changed into a 1 bit in some value. From here on, X denotes this bit position rather than the XOR of the array. We loop over all possible values of X; for a fixed X, we only need to consider the lowest X bits of the elements of A. We apply the same step as in our greedy strategy, this time as a preprocessing step:

  • Among all elements of A whose X-th bit is zero, we take the one or two whose lowest X bits form the largest values, then add just enough that the X-th bit of these elements becomes 1. We take one element if the X-th bit of the XOR of the array is 1, and two elements if it is 0, so that this bit is 0 in the XOR afterwards.

After this preprocessing step, the chosen elements have all bits below the X-th equal to 0, so we are back to the subproblem above, which we know how to solve.
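The loop over the top bit can be sketched as below, before the final optimization. It mirrors the structure of the setter's `solve(f, X)` but is only an illustration: the names `cost_with_top_bit`, `min_ops_slow`, `kInf` and `kBits` are made up here, and values are assumed to be below 2^60.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

constexpr long long kInf = 1LL << 62;  // sentinel: no valid choice for this top bit
constexpr int kBits = 61;              // assumed: all values are below 2^60

// Cost when f is the highest bit at which some element gains a 1.
// x is the XOR of the array; a is taken by value because it is modified.
long long cost_with_top_bit(std::vector<long long> a, int f, long long x) {
    long long ops = 0;
    for (int k = f; k >= 0; --k) {
        int need = (x >> k) & 1;             // one element if bit k of the XOR is set
        if (need == 0 && k == f) need = 2;   // at the top bit, otherwise take two
        if (need == 0) continue;
        // Keep the `need` best (low bits, index) pairs among elements with bit k clear.
        std::pair<long long, int> best[2] = {{-1, -1}, {-1, -1}};
        for (int i = 0; i < (int)a.size(); ++i) {
            if ((a[i] >> k) & 1) continue;
            std::pair<long long, int> t{a[i] & ((1LL << k) - 1), i};
            for (int j = 0; j < need; ++j)
                if (best[j] < t) std::swap(best[j], t);
        }
        if (best[need - 1].second == -1) return kInf;  // not enough candidates
        for (int j = 0; j < need; ++j) {
            const long long low = best[j].first, target = 1LL << k;
            ops += target - low;
            x ^= low ^ target;
            a[best[j].second] += target - low;
        }
    }
    return ops;
}

// Tries every top bit from high to low, stopping at the highest set bit of the XOR.
long long min_ops_slow(const std::vector<long long>& a) {
    long long x = 0;
    for (long long v : a) x ^= v;
    if (x == 0) return 0;
    long long ans = kInf;
    for (int f = kBits - 1; f >= 0; --f) {
        ans = std::min(ans, cost_with_top_bit(a, f, x));
        if ((x >> f) & 1) break;  // the top bit cannot be lower than this
    }
    return ans;
}
```

For example, on A = [1, 2, 4] the XOR is 7. With top bit 2, the element 2 is raised to 4 (two operations) and then one of the 4s is raised to 5 (one operation), giving [1, 5, 4] for a total of 3.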

This algorithm runs in O(N \log^2 \max(A)), which is not fast enough. We have one final observation:

  • Suppose we loop X from \log \max(A) + 1 down to 0, and each time remove from A the two elements whose X-th bit is zero and whose lowest X bits form the largest values (or fewer, if fewer such elements exist), putting them into an array B. The elements of B are the only ones that can be chosen by any step of either the preprocessing part or the greedy algorithm, so we only need to solve the problem on B.
    Since B has O(\log \max(A)) elements, the solution is optimized to O(N \log \max(A) + \log^3 \max(A)).
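The construction of B could look like the sketch below. It follows the same idea as the selection loops in the solutions, but `select_candidates` and `kBits` are illustrative names, and values are again assumed to be below 2^60.

```cpp
#include <utility>
#include <vector>

// Builds the reduced array B: for each bit from high to low, take up to two
// not-yet-taken elements whose bit k is 0 and whose low k bits are largest.
std::vector<long long> select_candidates(const std::vector<long long>& a) {
    const int kBits = 61;  // assumed: all values are below 2^60
    std::vector<bool> taken(a.size(), false);
    std::vector<long long> b;
    for (int k = kBits - 1; k >= 0; --k) {
        std::pair<long long, int> best[2] = {{-1, -1}, {-1, -1}};
        for (int i = 0; i < (int)a.size(); ++i) {
            if (taken[i] || ((a[i] >> k) & 1)) continue;
            std::pair<long long, int> t{a[i] & ((1LL << k) - 1), i};
            for (auto& slot : best)
                if (slot < t) std::swap(slot, t);
        }
        for (const auto& slot : best) {
            if (slot.second == -1) continue;  // fewer than two candidates
            taken[slot.second] = true;
            b.push_back(a[slot.second]);
        }
    }
    return b;
}
```

Each bit contributes at most two elements, so under the assumption above B never has more than 2 * 61 = 122 elements, no matter how large N is.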

TIME COMPLEXITY:

Time complexity is O(N \log \max(A) + \log^3 \max(A)).

SOLUTION:

Setter's Solution
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
#include <array>
using namespace std;

#ifdef LOCAL
	#define eprintf(...) {fprintf(stderr, __VA_ARGS__);fflush(stderr);}
#else
	#define eprintf(...) 42
#endif

using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
	return (ull)rng() % B;
}

#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second

clock_t startTime;
double getCurrentTime() {
	return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}

const ll INF = (1LL << 62);
const int K = 61;
const int N = (int)1e6 + 7;
ll a[N];
ll aa[N];
bool tk[N];
int n;

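// Cost of the greedy when f is the highest bit at which some element gains a 1.
// X is the XOR of the array; returns INF if some step has too few candidates.
ll solve(int f, ll X) {
	for (int i = 0; i < n; i++)
		a[i] = aa[i];
	ll ans = 0;
	pli b[2];
	for (int k = f; k >= 0; k--) {
		// z = number of elements to raise at bit k: one if bit k of X is set,
		// two at the top bit f when it is not set, none otherwise.
		int z = (X >> k) & 1;
		if (z == 0 && k == f) z = 2;
		if (z == 0) continue;
		b[0] = b[1] = mp(-1LL, -1);
		// Keep the z best (low k bits, index) pairs among elements with bit k clear.
		for (int i = 0; i < n; i++) {
			if ((a[i] >> k) & 1) continue;
			pli t = mp(a[i] & ((1LL << k) - 1), i);
			for (int j = 0; j < z; j++)
				if (b[j] < t)
					swap(b[j], t);
		}
		if (b[z - 1].second == -1) return INF;
		// Raise each chosen element until bit k is set and all lower bits are 0.
		for (int j = 0; j < z; j++) {
			int v = b[j].second;
			ll x = b[j].first;
			ll y = 1LL << k;
			ans += y - x;
			X ^= x ^ y;
			a[v] += y - x;
		}
	}
	assert(X == 0);
	return ans;
}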

void solve() {
	scanf("%d", &n);
	ll X = 0;
	for (int i = 0; i < n; i++) {
		tk[i] = false;
		scanf("%lld", &a[i]);
		X ^= a[i];
	}
	if (X == 0) {
		printf("0\n");
		return;
	}
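	// Candidate selection: for each bit from high to low, mark up to two unmarked
	// elements with bit k clear and the largest low k bits.
	pli b[2];
	for (int k = K - 1; k >= 0; k--) {
		b[0] = b[1] = mp(-1LL, -1);
		for (int i = 0; i < n; i++) {
			if (tk[i]) continue;
			if ((a[i] >> k) & 1) continue;
			pli t = mp(a[i] & ((1LL << k) - 1), i);
			for (int j = 0; j < 2; j++)
				if (b[j] < t)
					swap(b[j], t);
		}
		for (int j = 0; j < 2; j++)
			if (b[j].second != -1)
				tk[b[j].second] = 1;
	}
	// Keep only the marked elements; the search below runs on this reduced array.
	int sz = 0;
	for (int i = 0; i < n; i++) if (tk[i])
		aa[sz++] = a[i];
	n = sz;
	// Try every top bit from high to low, stopping at the highest set bit of X.
	ll ans = INF;
	for (int k = K - 1; k >= 0; k--) {
		ans = min(ans, solve(k, X));
		if ((X >> k) & 1) break;
	}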
	printf("%lld\n", ans);
}

int main()
{
	startTime = clock();
//	freopen("input.txt", "r", stdin);
//	freopen("output.txt", "w", stdout);

	int t;
	scanf("%d", &t);
	while(t--) solve();

	return 0;
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const int N=1e6+1;
const int iu=60;
int n;
ll a[N];
bool ban[N];
priority_queue<pair<ll,int>,vector<pair<ll,int> >,greater<pair<ll,int> > >pq[iu+2];
ll king[66][111];
ll w;
int m;
ll cost=0;
void sink(int bit){
	//cout << "sink " << bit << endl;
	for(int j=1; j<=m ;j++){
		int x=king[bit][j];
		if(x==0) break;
		if(!ban[x]){
			//cout << "get " << x << endl;
			ban[x]=true;
			ll del=(1LL<<bit)-(a[x]&((1LL<<bit)-1));
			w^=a[x];
			w^=a[x]+del;
			cost+=del;
			return;
		}
	}
	//cout << "Get " << 0 << endl;
	w^=(1LL<<bit);
	cost+=(1LL<<bit);
	
}
void solve(){
	cin >> n;m=min(100,n);
	ll y=0;
	for(int i=1; i<=n ;i++){
		cin >> a[i];
		for(int j=0; j<=iu ;j++){
			if(((a[i]>>j)&1)==0){
				ll cur=a[i]&((1LL<<j)-1);
				pq[j].push({cur,i});
				if(pq[j].size()>m) pq[j].pop();
			}
		}
		y^=a[i];
	}
	for(int i=0; i<=iu ;i++){
		for(int j=1; j<=m ;j++) king[i][j]=0;
		int z=pq[i].size();
		for(int j=z; j>=1 ;j--){
			king[i][j]=pq[i].top().se;
			pq[i].pop();
			z++;
		}
	}
	
	if(y==0){
		cout << "0\n";
		return;
	}
	ll ans=9e18;
	bool last=false;
	for(int i=iu; i>=0  && !last ;i--){
		//cout << "Round " << i << endl;
		for(int j=1; j<=n ;j++) ban[j]=false;
		w=y;cost=0;
		bool ok=true;
		if((1LL<<i)<=y){
			last=true;
			if(king[i][1]==0) ok=false;
			sink(i);
		}
		else{
			if(king[i][2]==0) ok=false;
			sink(i);sink(i);
		}
		if(!ok) continue;
		
		//cout << "thonk " << w << ' ' << cost << endl;
		for(int j=i-1; j>=0 ;j--){
			if(w&(1LL<<j)) sink(j);
		}
		//cout << "Found!!!!!!! " << cost << endl;
		ans=min(ans,cost);
		if(last) break;
	}
	cout << ans << '\n';
}
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	int t;cin >> t;while(t--) solve();
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    auto take_last = [](long long x, int m) { return x & ((1LL << (m + 1)) - 1); };
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        long long sx = 0;
        vector<long long> a(n), can;
        vector<bool> used(n);
        for (int i = 0; i < n; i++) {
            cin >> a[i];
            sx ^= a[i];
        }
        if (sx == 0) {
            cout << "0\n";
            continue;
        }
        long long ans = numeric_limits<long long>::max();
        for (int mx = 60; mx >= 0; mx--) {
            priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<pair<long long, int>>> pq;
            for (int i = 0; i < n; i++) {
                if (used[i]) {
                    continue;
                }
                if (!(a[i] >> mx & 1)) {
                    pq.push({take_last(a[i], mx), i});
                    while (pq.size() > 2) {
                        pq.pop();
                    }
                }
            }
            for (; !pq.empty(); pq.pop()) {
                int ind = pq.top().second;
                used[ind] = true;
                can.push_back(a[ind]);
            }
        }
        a = can; n = a.size();
        for (int mx = 60; mx >= 0; mx--) {
            long long cur_ans = 0, suf = take_last(sx, mx), pre = suf ^ sx;
            if (pre > 0) {
                break;
            }
            vector<long long> b = a;
            bool ok = false;
            for (int bit = mx; bit >= 0; bit--) {
                if (bit != mx && !(suf >> bit & 1)) {
                    continue;
                }
                sort(b.begin(), b.end(), [&](long long u, long long v) {
                    return take_last(u, bit) > take_last(v, bit);
                });
                for (int i = 0; i < n; i++) {
                    if (!(b[i] >> bit & 1) && (!ok || (suf >> bit & 1))) {
                        ok = true;
                        long long to = (1LL << bit);
                        cur_ans += to - take_last(b[i], bit);
                        suf ^= take_last(b[i], bit) ^ to;
                        b[i] ^= take_last(b[i], bit) ^ to;
                    }
                }
            }
            if (ok && suf == 0) {
                ans = min(ans, cur_ans);
            }
        }
        cout << ans << '\n';
    }
}