TRIXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Utkarsh Gupta
Tester: Tejas Pandey
Editorialist: Nishank Suresh

DIFFICULTY:

2256

PREREQUISITES:

None

PROBLEM:

You are given N (6 \leq N \leq 1000) non-negative integers. In one move, you can choose three of them, say A, B, C, and replace these three with A\oplus B, A\oplus C, B\oplus C.

Find a sequence of at most 11111 moves to make all N numbers 0.

EXPLANATION:

In a lot of constructive tasks like this, it’s not immediately obvious how to proceed. Instead, start by making some observations and looking at simpler cases.

One useful observation is the following:

  • If we pick A = B = C, then all three get replaced with 0.

This is nice, since our aim is to make everything 0.
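
As a quick sanity check, the move itself can be written as a tiny helper. This is only a minimal sketch; the name `apply_move` is made up for illustration and does not appear in the reference solutions.

```python
def apply_move(a, b, c):
    """Replace (a, b, c) with their pairwise XORs, as in the problem statement."""
    return a ^ b, a ^ c, b ^ c


# Three equal values cancel out completely.
print(apply_move(5, 5, 5))  # (0, 0, 0)
# Two zeros and a one turn into one zero and two ones.
print(apply_move(0, 0, 1))  # (0, 1, 1)
```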

Now, let’s try looking at a special case: what if every integer was either 0 or 1?

Solution

One algorithm to solve this case is as follows:

  • First, if there are at least 3 ones, we can use our earlier observation to make all 3 of them zeros.
  • Repeatedly doing this leaves us with either 0, 1, or 2 ones.
    • If there are 0 ones, we’re done, since everything is 0.
    • If there is exactly one 1, our only non-trivial move is to pick \{0, 0, 1\}, which gives us \{0, 1, 1\}, i.e., we are now in the case with 2 ones. All that remains is to solve that case.
    • If there are two 1's, we can use the fact that N \geq 6 to ensure that there are also at least two zeros.
      • First, pick \{0, 0, 1\} (two zeros and one of the ones), which gives us \{0, 1, 1\}.
      • Together with the 1 we did not touch, we now have three ones, so pick those three and we’re done.
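
The case analysis above can be sketched as follows. This is an illustrative version, not one of the reference solutions: the function name `solve_bits` and the choice to record moves as index triples are assumptions made for the example.

```python
def solve_bits(bits):
    """Return index triples that turn a 0/1 list (length >= 6) into all zeros."""
    bits = list(bits)
    moves = []

    def move(i, j, k):
        # Record the move, then replace the three entries with their pairwise XORs.
        moves.append((i, j, k))
        bits[i], bits[j], bits[k] = bits[i] ^ bits[j], bits[i] ^ bits[k], bits[j] ^ bits[k]

    ones = [i for i, b in enumerate(bits) if b]
    # Clear the ones three at a time.
    while len(ones) >= 3:
        move(ones.pop(), ones.pop(), ones.pop())
    # With 1 or 2 ones left, use {0, 0, 1} -> {0, 1, 1} until there are exactly 3 ones.
    while 0 < len(ones) < 3:
        zeros = [i for i, b in enumerate(bits) if b == 0][:2]
        move(zeros[0], zeros[1], ones[-1])
        ones.append(zeros[1])  # zeros[1] now holds a 1; ones[-1] kept its 1
    if ones:
        move(*ones)
    return moves


print(solve_bits([1, 0, 0, 0, 0, 0]))  # [(1, 2, 0), (1, 3, 2), (0, 2, 3)]
```

A single leftover 1 costs three moves and two leftover ones cost two, which matches the counts used later when bounding the total.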

Now that we have a solution to the simpler case of 0's and 1's, extending it to the general case isn’t too hard: simply apply this solution independently to each bit!

That is, first apply this solution to make the 0-th bit of all the numbers 0.
Then, make the first bit of everything 0.
Then, make the second bit 0, and so on.
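
Handling the bits one after another works because XOR acts on each bit position separately: a bit that is 0 in all three chosen numbers is still 0 in all three results, so bits that were cleared earlier stay cleared. A small check of this, using a hypothetical `apply_move` helper and made-up values:

```python
def apply_move(a, b, c):
    """Replace (a, b, c) with their pairwise XORs."""
    return a ^ b, a ^ c, b ^ c


# Bit 0 is already clear in all three values; bits 1 and 2 are not.
before = (0b110, 0b010, 0b100)
after = apply_move(*before)
print(after)  # (4, 2, 6)

# Bit 0 is still clear everywhere after the move.
assert all((v & 1) == 0 for v in after)
```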

Since A_i \leq 10^9, this only needs to be done 30 times. What about the number of moves?

Answer

Let’s look at the number of moves for a single bit, then multiply this by 30.

In the worst case, we make \lfloor N/3 \rfloor moves that each clear three ones, followed by 3 extra moves when exactly one 1 is left.
For N = 1000 this gives a worst case of 333 + 3 = 336 moves per bit; let’s say 340 is an upper bound.

Multiplying this by 30 gives 10200, which is within the limit of 11111.
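
The bound can be double-checked with a few lines of arithmetic. This is a minimal sketch; `moves_for` is a hypothetical helper that counts the moves spent on one bit, following the case analysis for 0, 1, or 2 leftover ones.

```python
def moves_for(ones):
    """Moves used on one bit when `ones` numbers have that bit set."""
    # floor(ones / 3) clearing moves, plus 0, 3, or 2 moves for 0, 1, or 2 leftovers.
    return ones // 3 + {0: 0, 1: 3, 2: 2}[ones % 3]


n = 1000
per_bit = max(moves_for(k) for k in range(n + 1))
total = 30 * per_bit
print(per_bit, total)  # 336 10080
```

The exact worst case of 10080 sits just under the rounded figure of 10200 and below the limit of 11111.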

Solving for a single bit can easily be done in \mathcal{O}(N) time, giving us an \mathcal{O}(30\cdot N) solution overall.

TIME COMPLEXITY:

\mathcal{O}(30\cdot N) per test case.

CODE:

Setter's code (C++)
//Utkarsh.25dec
#include <bits/stdc++.h>
#define ll long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;
ll power(ll a,ll b) {ll res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;}
ll modInverse(ll a){return power(a,mod-2);}
const int N=500023;
bool vis[N];
vector <int> adj[N];
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
void solve()
{
    int n=readInt(6,1000,'\n');
    ll A[n+1]={0};
    for(int i=1;i<=n;i++)
    {
        if(i==n)
            A[i]=readInt(0,1000000000,'\n');
        else
            A[i]=readInt(0,1000000000,' ');
    }
    vector <tuple<ll,ll,ll>> opers;
    for(int bit=0;bit<=30;bit++)
    {
        set <int> indices;
        for(int i=1;i<=n;i++)
        {
            if((A[i]&(1<<bit))!=0)
                indices.insert(i);
        }
        while(indices.size()>=3)
        {
            int p=(*indices.begin());
            indices.erase(p);
            int q=(*indices.begin());
            indices.erase(q);
            int r=(*indices.begin());
            indices.erase(r);
            ll P=A[p], Q=A[q], R=A[r];
            opers.pb(make_tuple(P, Q, R));
            A[p]=(P^Q);
            A[q]=(Q^R);
            A[r]=(R^P);
        }
        if(indices.size()==1)
        {
            int p=(*indices.begin());
            vector <int> fun;
            fun.pb(p);
            for(int i=1;i<=n;i++)
            {
                if(fun.size()==3)
                    break;
                if(i==p)
                    continue;
                fun.pb(i);
            }
            for(auto it:fun)
                indices.insert(it);
            p=fun[0];
            int q=fun[1];
            int r=fun[2];
            ll P=A[p], Q=A[q], R=A[r];
            opers.pb(make_tuple(P, Q, R));
            A[p]=(P^Q);
            A[q]=(Q^R);
            A[r]=(R^P);
            indices.erase(q);
        }
        if(indices.size()==2)
        {
            int p=(*indices.begin());
            indices.erase(p);
            int q=(*indices.begin());
            vector <int> fun;
            fun.pb(p);
            for(int i=1;i<=n;i++)
            {
                if(fun.size()==3)
                    break;
                if(i==p)
                    continue;
                if(i==q)
                    continue;
                fun.pb(i);
            }
            p=fun[0];
            q=fun[1];
            int r=fun[2];
            ll P=A[p], Q=A[q], R=A[r];
            opers.pb(make_tuple(P, Q, R));
            A[p]=(P^Q);
            A[q]=(Q^R);
            A[r]=(R^P);
            indices.insert(p);
            indices.insert(r);
        }
        if(indices.size()==3)
        {
            int p=(*indices.begin());
            indices.erase(p);
            int q=(*indices.begin());
            indices.erase(q);
            int r=(*indices.begin());
            indices.erase(r);
            ll P=A[p], Q=A[q], R=A[r];
            opers.pb(make_tuple(P, Q, R));
            A[p]=(P^Q);
            A[q]=(Q^R);
            A[r]=(R^P);
        }
    }
    cout<<opers.size()<<'\n';
    for(auto it:opers)
        cout<<get<0>(it)<<' '<<get<1>(it)<<' '<<get<2>(it)<<'\n';
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL),cout.tie(NULL);
    int T=readInt(1,10,'\n');
    while(T--)
        solve();
    assert(getchar()==-1);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

void modify(int a, int b, int c, vector<int> &v) {
    int v1 = (v[a]^v[b]), v2 = (v[b]^v[c]), v3 = (v[c]^v[a]);
    v[a] = v1;
    v[b] = v2;
    v[c] = v3;
}

int main() {
	int t;
	cin >> t;
	while(t--) {
	    int n;
	    cin >> n;
	    vector<int> a(n + 1);
	    for(int i = 1; i <= n; i++) cin >> a[i];
	    vector<tuple<int, int, int>> ops;
	    for(int i = 0; i < 31; i++) {
	        vector<int> set; int cnt = 0;
	        for(int j = 1; j <= n; j++) {
	            cnt += ((a[j]&(1LL<<i)) > 0);
	            if(((a[j]&(1LL<<i)) > 0))
	                set.push_back(j);
	        }
	       for(int j = 0; j < cnt - cnt%3; j+=3) {
	           ops.push_back({a[set[j]], a[set[j + 1]], a[set[j + 2]]}); 
	           modify(set[j], set[j + 1], set[j + 2], a);
	       }
	       if(cnt%3 == 1) {
	           int z1 = 0, z2 = 0;
	           for(int j = 1; j <= 6; j++)
	            if((a[j]&(1LL<<i)) == 0)
	                z1 = z2, z2 = j;
	           ops.push_back({a[set[cnt - 1]], a[z1], a[z2]});
	           modify(set[cnt - 1], z1, z2, a);
	           for(int j = 1; j <= 6; j++)
	            if((a[j]&(1LL<<i)) == 0)
	                z1 = z2, z2 = j;
	           ops.push_back({a[set[cnt - 1]], a[z1], a[z2]});
	           modify(set[cnt - 1], z1, z2, a);
	       } else if(cnt%3 == 2) {
	           int z1 = 0, z2 = 0;
	           for(int j = 1; j <= 6; j++)
	            if((a[j]&(1LL<<i)) == 0)
	                z1 = z2, z2 = j;
	           ops.push_back({a[set[cnt - 1]], a[z1], a[z2]});
	           modify(set[cnt - 1], z1, z2, a);
	       }
	       set.clear(); cnt = 0;
	       for(int j = 1; j <= n; j++) {
	            cnt += ((a[j]&(1LL<<i)) > 0);
	            if(((a[j]&(1LL<<i)) > 0))
	                set.push_back(j);
	        }
	       for(int j = 0; j < cnt - cnt%3; j+=3) {
	           ops.push_back({a[set[j]], a[set[j + 1]], a[set[j + 2]]}); 
	           modify(set[j], set[j + 1], set[j + 2], a);
	       }
	    }
	    cout << ops.size() << "\n";
	    for(int i = 0; i < ops.size(); i++) cout << get<0>(ops[i]) << " " << get<1>(ops[i]) << " " << get<2>(ops[i]) << "\n";
	}
	return 0;
}

Editorialist's code (Python)
for _ in range(int(input())):
	n = int(input())
	a = list(map(int, input().split()))
	ops = []
	for bit in range(30):
		ind = []
		for i in range(n):
			if (a[i] >> bit) & 1:
				ind.append(i)
		while len(ind) >= 3:
			x, y, z = ind.pop(), ind.pop(), ind.pop()
			A, B, C = a[x] ^ a[y], a[x] ^ a[z], a[y] ^ a[z]
			ops.append([a[x], a[y], a[z]])
			a[x], a[y], a[z] = A, B, C
		if len(ind) == 0:
			continue
		while len(ind) < 3:
			x, y = 0, 0
			while x in ind: x += 1
			y = x+1
			while y in ind: y += 1
			z = ind[-1]
			A, B, C = a[x] ^ a[y], a[x] ^ a[z], a[y] ^ a[z]
			ops.append([a[x], a[y], a[z]])
			a[x], a[y], a[z] = A, B, C
			ind.append(y)
		x, y, z = ind[0], ind[1], ind[2]
		ops.append([a[x], a[y], a[z]])
		A, B, C = a[x] ^ a[y], a[x] ^ a[z], a[y] ^ a[z]
		a[x], a[y], a[z] = A, B, C
	print(len(ops))
	for x, y, z in ops:
		print(x, y, z)

This problem can also be solved in at most n + 2 operations (see submission), using the following results:

  • (x, x, y) \rightarrow (x \oplus y, x \oplus y, 0), which can be rewritten as (z, z, 0) for some value of z
  • (x, 0, 0) \rightarrow (x, x, 0)
  • (x, x, x) \rightarrow (0, 0, 0)

So if we have two equal elements, then we can make the rest of the elements 0 using operations of the first type. Then we can apply an operation of the second type with one of the two initial elements to create an extra copy, and finally apply an operation of the third type to make them all 0. This would take at most n operations: n - 2 of the first type, then one each of the second and third types.

Now to obtain two equal elements, consider the following 2 operations on arbitrary elements x, y, z, w.

  • (x, y, z) \rightarrow (x \oplus y, y \oplus z, z \oplus x)
  • (y \oplus z, z \oplus x, w) \rightarrow (x \oplus y, z \oplus x \oplus w, w \oplus y \oplus z)

In doing so we have two occurrences of (x \oplus y) and can solve the rest of the problem as described above.
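
Here is a minimal sketch of this construction, written from the description above rather than taken from the linked submission. The name `solve_linear` is hypothetical, and moves are recorded as value triples on the assumption that this matches what the reference solutions print. Since the three chosen numbers are simply replaced by the three XORs, the sketch picks the slot order so that the equal pair always stays at indices 0 and 1.

```python
def solve_linear(values):
    """Zero a list of length n >= 4 in exactly n + 2 moves; returns (moves, final list)."""
    a = list(values)
    moves = []

    def move(i, j, k):
        # Record the chosen values, then store the pairwise XORs in slots i, j, k.
        moves.append((a[i], a[j], a[k]))
        a[i], a[j], a[k] = a[i] ^ a[j], a[i] ^ a[k], a[j] ^ a[k]

    # Two setup moves: afterwards a[0] == a[1] == x ^ y.
    move(0, 1, 2)
    move(1, 2, 3)
    # (v, v, c) -> (v ^ c, 0, v ^ c): slot t becomes 0, the pair stays equal.
    for t in range(2, len(a)):
        move(0, t, 1)
    # (v, 0, 0) -> (v, v, 0), then (v, v, v) -> (0, 0, 0).
    move(0, 2, 3)
    move(0, 1, 2)
    return moves, a


moves, final = solve_linear([3, 5, 9, 12, 7, 1])
print(len(moves), final)  # 8 [0, 0, 0, 0, 0, 0]
```

The sketch always spends n + 2 moves; skipping targets that are already 0 would only lower that count.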

wow brilliant

Such a brilliant explanation
Hats off

In my submission in the D language, I used a priority queue to solve this problem. On each iteration, fetch the 3 largest elements (let’s label them a, b, c). If the largest element a is 0, then we are done, since everything has already been converted to 0. Otherwise, we do the XOR transformation and put the 3 transformed elements back into the priority queue.

Now the tricky case is when the top bit of a is also set in b, but not in c (which is similar to the case with 2 ones from the editorial). In this situation I fetch one more element from the priority queue to replace b (it won’t have the top bit of a set in it) and put b back into the priority queue.
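
A rough Python sketch of this idea, reconstructed from the description above and not a translation of the actual submission. The name `solve_pq` is hypothetical, at least 4 elements are assumed, and no bound on the number of moves is claimed here.

```python
import heapq


def solve_pq(values):
    """Return the value triples chosen, in order; assumes len(values) >= 4."""
    heap = [-v for v in values]  # negate so heapq behaves as a max-heap
    heapq.heapify(heap)
    ops = []
    while True:
        a, b, c = (-heapq.heappop(heap) for _ in range(3))
        if a == 0:
            return ops  # the largest element is 0, so everything is 0
        top = 1 << (a.bit_length() - 1)
        if (b & top) and not (c & top):
            # Tricky case: swap b for the next largest element and put b back.
            d = -heapq.heappop(heap)
            heapq.heappush(heap, -b)
            b = d
        ops.append((a, b, c))
        for v in (a ^ b, a ^ c, b ^ c):
            heapq.heappush(heap, -v)


print(len(solve_pq([1, 0, 0, 0, 0, 0])))  # 3
```

The loop terminates because every case either removes the current top bit from all numbers or moves to a state where the next one or two iterations do, and XOR never creates a bit above the current maximum.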

The same submission converted to C++

You have a very nice solution with a proven low upper bound for the number of operations. That’s a “branchless” SIMD approach with all 30 lanes being processed simultaneously, and its number of operations depends only on N, not on the values. It also produces fewer lines than the intended solution from the editorial.

But my priority queue based solution produces fewer lines than yours when processing random input. I wonder if this is always the case? Are there even more optimal solutions?

Great solution, I was just in shock when I analyzed it.

Very nice problem. Thanks!

Applying an operation of the first type gives us 1 zero, but an operation of the second type needs 2 zeros, so how do we get that extra zero?

It’s easier to look at an example. If the array initially was [x_1, x_1, a, b, c, d, ... ], then we can use the pair of x_1 values to zero everything else:

  • On the first step we do (x_1, x_1, a) → (x_1 ⊕ a, x_1 ⊕ a, 0) operation and the array becomes [x_2, x_2, 0, b, c, d, ... ], where x_2 = x_1 ⊕ a
  • On the second step we do (x_2, x_2, b) → (x_2 ⊕ b, x_2 ⊕ b, 0) operation and the array becomes [x_3, x_3, 0, 0, c, d, ... ], where x_3 = x_2 ⊕ b
  • On the last step the array becomes [x_k, x_k, 0, 0, 0, 0, ... ]

but an operation of the second type needs 2 zeros, so how do we get that extra zero?

See the example above. On the last step everything is already zero except for the first two elements. As the minimal array size is 6, we have at least 4 zeroes at our disposal. The finishing steps:

  • After applying the operation (x_k, 0, 0) → (x_k, x_k, 0), our array becomes [x_k, x_k, x_k, 0, 0, 0, ... ]
  • And finally applying (x_k, x_k, x_k) → (0, 0, 0) converts the array into [0, 0, 0, 0, 0, 0, ... ]