XOR_PERM - EDITORIAL

PROBLEM LINK:

Frodo and the mines of Moria

Author: chef_hamster
Tester: ultimate_st
Editorialist: chef_hamster

PROBLEM:

You are given an array A of N integers, each with an even number of bits in its binary representation. You are also given the following operation:

Choose any element of the array and replace it with its XOR with any permutation of its binary representation. Within a single operation you may repeat this on the chosen element as many times as you want.

For example: suppose you have [8,2,9] and you choose 8. In binary, 8 is 1000, so you can XOR it with 1000, 0100, 0010 or 0001. Taking 1000 \oplus 0001 gives 9, which is 1001. You can then repeat the process by XORing with any permutation of 1001.
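To make a single step concrete, the sketch below lists every value reachable from x with one XOR. It assumes, as in the example, that the permutation is taken over the current binary representation of x with leading zeros allowed in the permuted string. Under that assumption a permutation is any string of the same length with the same number of ones. The helper name `oneStep` is hypothetical, and the code relies on the GCC/Clang builtin `__builtin_popcountll`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: all values reachable from x with ONE XOR against a
// permutation of x's binary representation (assumed: same bit length as x,
// leading zeros allowed, same number of ones).
std::vector<long long> oneStep(long long x) {
    int len = 0;                                   // bit length of x
    while ((x >> len) != 0) len++;
    int ones = __builtin_popcountll(x);            // GCC/Clang builtin
    std::vector<long long> result;
    for (long long p = 0; p < (1LL << len); p++) {
        // p is a permutation of x exactly when it has the same number of ones
        if (__builtin_popcountll(p) == ones) result.push_back(x ^ p);
    }
    return result;
}
```

For x = 8 the loop visits p = 1, 2, 4, 8 in that order and returns 9, 10, 12, 0. These are the results of the choices 0001, 0010, 0100 and 1000 listed above.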

You have to print the maximum possible sum of the array after applying this operation to at most K elements.


Prerequisites:

  • Knowledge of properties of XOR
  • Basic Mathematics
  • Patience

Hint:

Hint 1

Try finding the largest reachable number for some 2-, 4-, 6-, 8- and 10-bit numbers and look for a pattern.
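A brute force makes this experiment easy for small inputs. The sketch below uses the hypothetical name `bestReachable` and the same modelling assumption as before: permutations range over the current bit length, and leading zeros are allowed. It runs a BFS over every value reachable by repeated XORs and returns the largest one. The start value x is counted too, because an element may be left unchanged. The search is exponential in the bit length, so it only suits the 2- to 10-bit experiments suggested here.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <set>

// Hypothetical brute force: largest value reachable from x (x itself included).
long long bestReachable(long long x) {
    std::set<long long> seen = {x};
    std::queue<long long> todo;
    todo.push(x);
    long long best = x;
    while (!todo.empty()) {
        long long cur = todo.front();
        todo.pop();
        best = std::max(best, cur);
        int len = 0;                                        // bit length of cur
        while ((cur >> len) != 0) len++;
        int ones = __builtin_popcountll(cur);               // GCC/Clang builtin
        // Every len-bit string p with the same number of ones is a permutation of cur.
        for (long long p = 0; p < (1LL << len); p++) {
            if (__builtin_popcountll(p) != ones) continue;
            long long next = cur ^ p;
            if (seen.insert(next).second) todo.push(next);  // enqueue on first visit only
        }
    }
    return best;
}
```

For example, it reports 15 for 8 (1000 → 1100 → 1111). For the 6-bit value 32 it reports 60, not 63.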

Hint 2

For some integers we can set all the bits, while for others a few bits always remain unset. Try to relate this to the number of bits in the integer.

Hint 3

For an integer with n bits, if \frac{n}{2} is even, we can set all of its bits.

Hint 4

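For an integer with n bits, if \frac{n}{2} is odd and the number of set bits is exactly \frac{n}{2}, we can again set all the bits. Otherwise the operation can set at most n-2 bits, namely the top n-2 bits, which gives the value 2^n-4. An element that is already larger than that is best left unchanged.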
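Hints 3 and 4 fit in a few lines of code. The sketch below uses the hypothetical name `hintBest` and assumes x has an even bit length of at most 62. It returns the larger of the hinted value and x itself, because an element that is already all ones, or already above 2^n - 4, is best left alone.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical closed form of Hints 3 and 4 for x with even bit length n <= 62.
long long hintBest(long long x) {
    int n = 0;                                   // bit length of x
    while ((x >> n) != 0) n++;
    int ones = __builtin_popcountll(x);          // GCC/Clang builtin
    long long allOnes = (1LL << n) - 1;          // all n bits set
    long long reach;
    if ((n / 2) % 2 == 0 || ones == n / 2)
        reach = allOnes;                         // Hint 3, or the first case of Hint 4
    else
        reach = allOnes - 3;                     // Hint 4: top n-2 bits, i.e. 2^n - 4
    return std::max(x, reach);                   // leaving x untouched is also allowed
}
```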


Explanation:

First of all, let’s derive some important points.

Claim 1: For a binary number of length n (n \ge 2), if the number of set bits is less than n, then we can always reach a binary number of the same length with exactly 2 set bits.

Proof

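Let B = 11…100…0 be a binary number with s ones followed by (n-s) zeros. XOR it with B^* = 011…100…0, which is its cyclic shift by one position and hence a permutation of B. The result is 10…010…0, with exactly 2 set bits: the leading bit and the bit just after the block of ones. The same works for any arrangement of the s ones. Let P be B with its leading 1 moved to a position j where B has a 0 (one exists since s < n). P has s set bits, so it is a permutation of B, and B \oplus P has exactly two set bits: the leading bit and bit j.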
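The construction in this proof is easy to check mechanically. The sketch below uses the hypothetical name `claim1` and stores bits in an unsigned 64-bit integer. B is s ones followed by n - s zeros. B^* is B shifted right by one position, which is a cyclic shift here because B ends in 0 when s < n.

```cpp
#include <cassert>

// Hypothetical check of the Claim 1 construction for 2 <= n <= 63, 1 <= s < n.
unsigned long long claim1(int n, int s) {
    unsigned long long B = ((1ULL << s) - 1) << (n - s);  // 1...1 0...0 (s ones)
    unsigned long long Bstar = B >> 1;                    // 0 1...1 0...0 (cyclic shift, last bit was 0)
    return B ^ Bstar;                                     // leading bit and bit n-s-1 remain
}
```

For instance, claim1(4, 3) computes 1110 \oplus 0111 = 1001.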

Claim 2: After performing the operation at least once, the integer always has an even number of set bits.

Proof

Let B have s set bits and let P be a permutation of B, so P also has s set bits. If x positions are set in both, those x positions cancel in B \oplus P. The result therefore has (s-x) + (s-x) = 2s-2x set bits, which is clearly even.
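Claim 2 can also be confirmed exhaustively for small lengths. The sketch below uses the hypothetical name `claim2Holds`. It treats every len-bit string with the same number of ones as B as a permutation of B, then checks both the count 2s - 2x and its parity.

```cpp
#include <cassert>

// Hypothetical exhaustive check of Claim 2 for all len-bit strings (len <= 12 or so).
bool claim2Holds(int len) {
    for (unsigned b = 0; b < (1u << len); b++) {
        int s = __builtin_popcount(b);                 // set bits of B
        for (unsigned p = 0; p < (1u << len); p++) {
            if (__builtin_popcount(p) != s) continue;  // p must be a permutation of b
            int x = __builtin_popcount(b & p);         // overlapping set bits
            int r = __builtin_popcount(b ^ p);         // set bits of the result
            if (r != 2 * s - 2 * x || r % 2 != 0) return false;
        }
    }
    return true;
}
```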


Now, let's move on to the proof of the solution:

Claim 3: If we have a binary number B with 2n bits (n \ge 2) and n is even, then we can set all the bits in the number using the above operation.

Proof

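It is enough to reach a value with exactly n set bits whose leading bit is still set. Its bitwise complement within 2n bits also has n set bits, so it is a permutation of that value, and XORing the two sets all 2n bits. So let's prove that we can reach exactly n set bits. Using Claim 1 we can reach two set bits at the top, 110…0. A top block of s set bits can be doubled to 2s by XORing it with the same block shifted right by s positions, so we can reach a top block of 2^a set bits. Take the maximum a such that 2^a \le n, so that \frac{n}{2} < 2^a \le n. Now, to get n set bits, take the XOR of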

  • B = \underbrace{1\cdots1}_{2^a}\,\underbrace{0\cdots0}_{2n-2^a} with
  • B^* = \underbrace{0\cdots0}_{\frac{n}{2}}\,\underbrace{1\cdots1}_{2^a}\,\underbrace{0\cdots0}_{\frac{3n}{2}-2^a}.

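Reading from the most significant bit:

  • The first \frac{n}{2} bits are 1 in B and 0 in B^*, so they stay 1.
  • The next 2^a - \frac{n}{2} bits are 1 in both, so they become 0. This overlap is positive because 2^a > \frac{n}{2}.
  • The next \frac{n}{2} bits are 0 in B and 1 in B^*, so they become 1.
  • All remaining bits are 0 in both.

This gives a total of 2^a + 2^a - 2(2^a - \frac{n}{2}) = n set bits. The shift \frac{n}{2} is an integer only because n is even, which is where we need that condition.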

The construction fits within 2n bits: the ones of B^* end after \frac{n}{2} + 2^a \le \frac{3n}{2} < 2n positions. The first \frac{n}{2} \ge 1 bits of the result are 1, so the leading bit, and hence the length, is preserved.
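The steps of this proof can be replayed on concrete bit strings. The sketch below uses the hypothetical names `claim3` and `Claim3Result`. It assumes n is even with 2 \le n \le 30, so that 2n bits fit in an unsigned 64-bit integer. It starts from the top block 110…0, doubles it up to 2^a ones, and XORs with B^*. It returns the value with n set bits together with its complement within 2n bits.

```cpp
#include <cassert>

struct Claim3Result {
    unsigned long long withN;       // value with exactly n set bits
    unsigned long long complement;  // its complement within 2n bits (also n set bits)
};

// Hypothetical replay of Claim 3 for a 2n-bit number, n even, 2 <= n <= 30.
Claim3Result claim3(int n) {
    int len = 2 * n;
    // `count` ones starting `start` positions below the most significant bit.
    auto topBlock = [len](int start, int count) {
        return ((1ULL << count) - 1) << (len - start - count);
    };
    unsigned long long B = topBlock(0, 2);           // 110...0
    int s = 2;
    while (2 * s <= n) {                             // doubling: s top ones -> 2s top ones
        B ^= B >> s;
        s *= 2;
    }
    // B is now a top block of s = 2^a ones, the largest power of 2 that is <= n.
    unsigned long long Bstar = topBlock(n / 2, s);   // n/2 zeros, 2^a ones, then zeros
    unsigned long long withN = B ^ Bstar;
    unsigned long long mask = (1ULL << len) - 1;     // all 2n bits set
    return {withN, ~withN & mask};
}
```

For n = 4 this gives 11110000 \oplus 00111100 = 11001100, and XORing with 00110011 yields 11111111.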

Claim 4: If we have a binary number B with 2n bits (n \ge 1 and n odd) and its number of set bits is not equal to n, then every value the operation produces has at most 2n-2 set bits. The top 2n-2 bits can indeed be set, which gives 2^{2n}-4.

Proof

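To set all 2n bits, the last XOR must combine some value C with a permutation of C that has ones exactly where C has zeros, so C must have exactly n set bits. By Claim 2, every value obtained after at least one XOR has an even number of set bits. Since n is odd, C would have to be the original B, which does not have n set bits. Claim 2 also rules out 2n-1 set bits for any produced value, so at most 2n-2 bits can be set. This bound is reachable. Since n-1 is even, the construction of Claim 3 (with shift \frac{n-1}{2} and the largest 2^a \le n-1) gives a value C with exactly n-1 set bits, all within the top 2n-2 positions. XORing C with the permutation that has ones on the other n-1 of those top positions sets the top 2n-2 bits, giving 2^{2n}-4.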
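The same kind of check works for Claim 4. The sketch below uses the hypothetical names `claim4` and `Claim4Result`. It assumes n is odd with 3 \le n \le 31, and that a top block of 2^a ones has already been reached by doubling as in Claim 3. It builds C with n-1 set bits and takes P as the remaining top positions. It then checks that P is a permutation of C and that C \oplus P = 2^{2n}-4.

```cpp
#include <cassert>

struct Claim4Result {
    unsigned long long value;   // C ^ P
    bool validPermutation;      // C and P both have n-1 ones, so P is a permutation of C
};

// Hypothetical replay of Claim 4 for a 2n-bit number, n odd, 3 <= n <= 31.
Claim4Result claim4(int n) {
    int len = 2 * n, m = n - 1;                      // m = n - 1 is even
    auto topBlock = [len](int start, int count) {    // `count` ones, `start` positions below the top
        return ((1ULL << count) - 1) << (len - start - count);
    };
    int p = 2;
    while (2 * p <= m) p *= 2;                       // largest power of 2 that is <= n - 1
    // Claim 3's XOR with n - 1 in place of n: exactly n - 1 ones inside the top 2n - 2 bits.
    unsigned long long C = topBlock(0, p) ^ topBlock(m / 2, p);
    unsigned long long top = topBlock(0, len - 2);   // the top 2n - 2 positions
    unsigned long long P = top & ~C;                 // ones where C is 0 within `top`
    bool valid = __builtin_popcountll(C) == m && __builtin_popcountll(P) == m;
    return {C ^ P, valid};
}
```

For n = 3 this gives C = 101000 and P = 010100, whose XOR is 111100 = 60 = 2^6 - 4.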


With the claims proved, the rest is implementation. For each element we compute its gain, the largest reachable value minus the element. We then sort the gains in decreasing order and add the K largest. We stop early at a negative gain, since leaving an element untouched is always allowed.
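As a small worked example, take the array [8, 2, 9] from the statement. The statement gives no value of K, so the three cases below are only for illustration. Here 8 = 1000 and 9 = 1001 have 4 bits, and \frac{4}{2} = 2 is even, so both can reach 15. The number 2 = 10 has 2 bits with exactly one set bit, so it can reach 3:

```latex
\begin{aligned}
8 + 2 + 9 &= 19 \\
\text{gains: } 15 - 8 &= 7,\quad 3 - 2 = 1,\quad 15 - 9 = 6 \\
K = 1:\ 19 + 7 &= 26 \\
K = 2:\ 19 + 7 + 6 &= 32 \\
K = 3:\ 19 + 7 + 6 + 1 &= 33
\end{aligned}
```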

Setter's Solution
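#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

// Largest value reachable from a by the operation. It may be smaller than a;
// main() skips negative gains, so such elements are left untouched.
ll solve(ll a){
    ll ct_bit=0,ct_set=0;                  // bit length and number of set bits of a
    while(a){
        ct_bit++;
        ct_set+=a&1;
        a>>=1;
    }
    if(((ct_bit)/2)%2){                    // half of the bit length is odd
        if(ct_bit==2*ct_set)return (1LL<<ct_bit)-1;  // exactly half the bits set
        return (1LL<<ct_bit)-4;            // only the top ct_bit-2 bits can be set
    }
    return (1LL<<ct_bit)-1;                // half of the bit length is even: all bits
}

int main(){
    int test;
    cin>>test;
    while(test--){
        int n,k;
        cin>>n>>k;
        vector<ll> arr(n);
        for(int i=0;i<n;i++) cin>>arr[i];
        ll sum = accumulate(arr.begin(),arr.end(),ll(0));
        vector<ll> val(n);                 // gain obtainable for each element
        for(int i=0;i<n;i++){
            val[i] = solve(arr[i])-arr[i];
        }
        sort(val.begin(),val.end(),greater<ll>());
        // add the k largest gains, stopping at the first negative one
        for(int i=0;i<min(k,n);i++){
            if(val[i]<0)break;
            sum+=val[i];
        }
        cout<<sum<<endl;
    }
    return 0;
}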
Tester's Solution
import java.util.*;

public class Main
{
    // Largest value reachable from n; may be smaller than n (negative gains are skipped in main).
    static long xorPerm(long n){
        int bits=countBits(n);
        int set=countSetBits(n);
        long max=(1L<<bits);             // 1L avoids int overflow for large inputs
        int n1=bits/2;
        if(n1%2==0){
            return max-1;                // half of the bit length is even: all bits can be set
        }
        else{
            if(2*set==bits)
                return max-1;            // exactly half of the bits are set
            return max-4;                // only the top bits-2 bits can be set
        }
    }
    static int countSetBits(long n)
    {
        int count = 0;
        while (n > 0) {
            count += n & 1;
            n >>= 1;
        }
        return count;
    }
    static int countBits(long n)
    {
        int count = 0;
        while (n > 0) {
            count++;
            n >>= 1;
        }
        return count;
    }

    public static void main (String[] args) throws java.lang.Exception
    {
        Scanner sc = new Scanner(System.in);
        int t=sc.nextInt();
        while(t-->0){
            int n=sc.nextInt();
            int k=sc.nextInt();
            long arr[]=new long[n];
            long sum=0;
            for(int i=0;i<n;i++){
                arr[i]=sc.nextLong();
                sum+=arr[i];
            }
            long diff[]=new long[n];
            for(int i=0;i<n;i++){
                diff[i]=xorPerm(arr[i])-arr[i];
            }
            Arrays.sort(diff);           // ascending, so the largest gains are at the end
            for(int i=0;i<Math.min(k,n);i++){
                if(diff[n-1-i]<0)
                    break;
                sum+=diff[n-1-i];
            }
            System.out.println(sum);
        }
    }
}