MAXBITSUM Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Yash Goyal
Testers: Abhinav Sharma (inov_360), Nishank Suresh (iceknight1093)
Editorialist: Pratiyush Mishra

DIFFICULTY:

Easy-Medium

PREREQUISITES:

None

PROBLEM:

Given two arrays A and B each of length N.

Find the value \sum_{i=1}^N\sum_{j=1}^N \max(A_i\oplus B_j,A_i\& B_j).
Here \oplus and \& denote the bitwise XOR operation and bitwise AND operation respectively.

Quick Explanation:

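First, observe the pattern in \max(x \oplus y, x \& y).
If the MSBs (most significant set bits) of x and y are at positions i and j respectively, then:

  • if i \neq j, then x \oplus y > x \& y.
  • if i=j, then x \oplus y < x \& y. This is because the larger of the two values is the one with the highest bit ON: when the MSBs coincide, the AND keeps that bit while the XOR clears it, and when they differ, the XOR keeps the higher MSB while the AND cannot have it set. (If either value is 0, the AND is 0, so the XOR is the maximum.)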
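The observation can be sketched in C++ as follows. This is a minimal illustration, not the setter's code; the helper names `msb_index` and `max_xor_and` are hypothetical and chosen only for this example.

```cpp
#include <cstdint>

// Hypothetical helper: index of the highest set bit of x, or -1 when x == 0.
int msb_index(std::uint32_t x) {
    int idx = -1;
    while (x != 0) {
        ++idx;
        x >>= 1;
    }
    return idx;
}

// Hypothetical helper: max(x ^ y, x & y), decided only by comparing MSBs.
std::uint32_t max_xor_and(std::uint32_t x, std::uint32_t y) {
    // Same MSB (both values nonzero): AND keeps the top bit, XOR clears it.
    if (x != 0 && msb_index(x) == msb_index(y)) {
        return x & y;
    }
    // Different MSB, or at least one value is zero: XOR is at least as large.
    return x ^ y;
}
```

For instance, `max_xor_and(5, 6)` takes the AND branch because both values have bit 2 as their MSB, while `max_xor_and(1, 4)` takes the XOR branch.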

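Now, for each bit position i, precompute information about the elements of one array whose MSB is i (how many there are, and how many of them have each bit set). Then iterate through the elements of the other array and add each element's contribution to the sum according to its MSB.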

Explanation

We will define arrays to store different values as follows:

  • msb[i] → number of elements in A having their most significant bit as i.
  • bits[i] → number of elements in A having their i_{th} bit on.
  • bit\_table[i][j] → number of elements in A having their most significant bit as i, with j_{th} bit on.
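One possible way to fill these three arrays is sketched below. The bound `BITS = 30` is an assumption (the constraints are not restated here), and the `Tables` struct and `build_tables` function are hypothetical names used only for illustration.

```cpp
#include <array>
#include <cstdint>
#include <vector>

constexpr int BITS = 30;  // assumed bound: every value is below 2^30

// Hypothetical container for the three arrays described above.
struct Tables {
    std::array<std::int64_t, BITS> msb{};   // msb[i]: elements whose MSB is i
    std::array<std::int64_t, BITS> bits{};  // bits[j]: elements with bit j set
    // bit_table[i][j]: elements whose MSB is i and whose bit j is set
    std::array<std::array<std::int64_t, BITS>, BITS> bit_table{};
};

Tables build_tables(const std::vector<std::uint32_t>& a) {
    Tables t;
    for (std::uint32_t x : a) {
        if (x == 0) continue;  // zero has no MSB and no set bits
        int top = 0;
        for (int j = 0; j < BITS; ++j) {
            if ((x >> j) & 1u) top = j;  // last hit is the highest set bit
        }
        ++t.msb[top];
        for (int j = 0; j <= top; ++j) {
            if ((x >> j) & 1u) {
                ++t.bits[j];
                ++t.bit_table[top][j];
            }
        }
    }
    return t;
}
```

Elements equal to 0 are deliberately left out of every table; they are still counted in n, which matters when the tables are combined later.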

We will calculate values for these arrays using simple bit manipulation and then proceed to calculate the sum as follows.

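We loop through the elements of B. For each element B_i, let k be the position of its most significant bit.

  • If all bits of B_i are unset (B_i = 0), we simply add the sum of array A to our answer, since A_j \oplus B_i = A_j and A_j \& B_i = 0 for all 0 \leq j < n.

  • Otherwise, we loop through all bit positions j (including those above k) and, for each one:

    • if the j_{th} bit of B_i is set, we add the following to our final answer:
      answer += (1<<j) \times bit\_table[k][j]
      answer += (1<<j) \times (n-msb[k]-(bits[j]-bit\_table[k][j]))
      The first term is the AND contribution of the elements of A whose MSB is also k. The second is the XOR contribution of the elements with a different MSB (or equal to 0) whose j_{th} bit is unset.
    • if the j_{th} bit of B_i is not set, we add the following:
      answer += (1<<j) \times (bits[j]-bit\_table[k][j])
      This is the XOR contribution of the elements of A with a different MSB whose j_{th} bit is set.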
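Putting the steps together, the whole computation might look like the sketch below. It is not the setter's or testers' solution; the function name `max_bit_sum` and the bound `BITS = 30` are assumptions made for this example.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int BITS = 30;  // assumed bound: every value is below 2^30

// Hypothetical helper: highest set bit of a nonzero value below 2^BITS.
int top_bit(std::uint32_t x) {
    int top = 0;
    for (int j = 0; j < BITS; ++j) {
        if ((x >> j) & 1u) top = j;
    }
    return top;
}

// Sum over all pairs of max(A_i ^ B_j, A_i & B_j) using the counting tables.
std::int64_t max_bit_sum(const std::vector<std::uint32_t>& a,
                         const std::vector<std::uint32_t>& b) {
    assert(a.size() == b.size());
    const std::int64_t n = static_cast<std::int64_t>(a.size());

    std::vector<std::int64_t> msb(BITS, 0), bits(BITS, 0);
    std::vector<std::vector<std::int64_t>> bit_table(
        BITS, std::vector<std::int64_t>(BITS, 0));
    std::int64_t sum_a = 0;

    // Build msb, bits and bit_table over A; zeros only add to sum_a (nothing).
    for (std::uint32_t x : a) {
        sum_a += x;
        if (x == 0) continue;
        const int k = top_bit(x);
        ++msb[k];
        for (int j = 0; j <= k; ++j) {
            if ((x >> j) & 1u) {
                ++bits[j];
                ++bit_table[k][j];
            }
        }
    }

    std::int64_t answer = 0;
    for (std::uint32_t y : b) {
        if (y == 0) {  // XOR with 0 leaves A_j unchanged, AND is 0
            answer += sum_a;
            continue;
        }
        const int k = top_bit(y);
        for (int j = 0; j < BITS; ++j) {
            const std::int64_t weight = std::int64_t{1} << j;
            // Elements with bit j set whose MSB differs from k.
            const std::int64_t other_set = bits[j] - bit_table[k][j];
            if ((y >> j) & 1u) {
                answer += weight * bit_table[k][j];           // AND, same MSB
                answer += weight * (n - msb[k] - other_set);  // XOR, bit j is 0
            } else {
                answer += weight * other_set;                 // XOR, bit j is 1
            }
        }
    }
    return answer;
}
```

With A = {1, 3} and B = {4, 0}, the four pair values are 5, 1, 7 and 3, so the function returns 16.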

TIME COMPLEXITY:

O(N \cdot \log(\max(A_i, B_i))) for each test case, which is O(N) when the number of bits is treated as a constant.

SOLUTION:

Setter’s Solution
Tester1’s Solution
Tester2’s Solution


Thanks for such a fast and nice editorial.

#include <iostream>
#include <vector>

#define int long long int

void solve(){
    int n;
    std::cin >> n;
    std::vector <int> A(n, 0), B(n, 0), bit(22, 0);
    std::vector <std::vector<int>> v(22);
    for(auto &i:A)
        std::cin >> i;
    for(auto &i:B)
        std::cin >> i;
    for(int i=0; i<n; i++){
        int k = B[i], j = 0;
        while(k){
            bit[j] += (k%2);
            k /= 2;
            j += 1;
        }
        v[j].push_back(B[i]);
    }
    
    int ans = 0;
    for(int i=0; i<n; i++){
        int k = A[i], j = 0, m = 0;
        while(k){
            j += 1;
            k /= 2;
        }
        //std::cout << j << " ";
        for(auto l:v[j]){
            //std::cout << l << " ";
            ans += (A[i]&l);
            ans -= (A[i]^l);
        }
        //std::cout << "\n";
        k = A[i];
        for(int x = 0; x<22; x++){
            if(k%2){
                ans += (n-bit[x])*(1<<m);
            }
            else{
                ans += bit[x]*(1<<m);
            }
            k /= 2;
            m += 1;
        }
    }
    std::cout << ans << "\n";
}
     
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    int t = 0;
    std::cin >> t;
    while(t--){
        solve();
    }
}

This is a similar, correct approach, but it gives TLE on some test cases. Can someone please help?

Your solution is actually incorrect. Consider this test case:

2
1 3
4 0

The correct answer is 16.

Why are we adding this term?

answer += (1<<j) * (n - msb[k] - (bits[j] - bit_table[k][j]))

First of all, the XOR part of the sum is contributed only by pairs with different most significant bits.
Let's say we are considering all the numbers in a with their MSB as k. Then the XOR is contributed by pairs formed from this group and the elements of b whose MSB is different from k.
The number of such elements is n - cnt[k], where cnt[i] is the number of elements in b with MSB i.

So when the bit at position j is set, we count how many of those numbers have a 0 at that position, and each of them contributes 1 << j to the answer.

That count of zeroes is the number of remaining elements minus the number of them that have the j-th bit set, which equals:

count of zeroes = (n - cnt[k]) - (bits[j] - bit_table[k][j])

So the net contribution is (1 << j) times the count of zeroes.
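The counting identity above can be checked on small inputs with a sketch like the following. The function names are hypothetical, and the per-call recounting is only for illustration (a real solution would precompute the tables once).

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: index of the highest set bit, or -1 for zero.
int highest_bit(std::uint32_t x) {
    int idx = -1;
    while (x != 0) {
        ++idx;
        x >>= 1;
    }
    return idx;
}

// Direct scan: values whose MSB differs from k and whose bit j is 0.
std::int64_t zeros_by_scan(const std::vector<std::uint32_t>& vals, int k, int j) {
    std::int64_t count = 0;
    for (std::uint32_t v : vals) {
        if (highest_bit(v) != k && ((v >> j) & 1u) == 0) ++count;
    }
    return count;
}

// Same count via (n - cnt[k]) - (bits[j] - bit_table[k][j]).
std::int64_t zeros_by_formula(const std::vector<std::uint32_t>& vals, int k, int j) {
    const std::int64_t n = static_cast<std::int64_t>(vals.size());
    std::int64_t cnt_k = 0, bits_j = 0, table_kj = 0;
    for (std::uint32_t v : vals) {
        const bool same_msb = (highest_bit(v) == k);
        const bool bit_set = ((v >> j) & 1u) != 0;
        if (same_msb) ++cnt_k;                // cnt[k]
        if (bit_set) ++bits_j;                // bits[j]
        if (same_msb && bit_set) ++table_kj;  // bit_table[k][j]
    }
    return (n - cnt_k) - (bits_j - table_kj);
}
```

For the values {1, 3, 4, 0} with k = 2 and j = 0, the elements with a different MSB are 1, 3 and 0, and only 0 has bit 0 unset, so both functions give 1.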

Actually, I modified it a bit, and now the answer is 16. Thanks, though. Can you check again? The approach is correct now.