 # MAXBITSUM Editorial

Setter: Yash Goyal
Testers: Abhinav Sharma (inov_360), Nishank Suresh (iceknight1093)
Editorialist: Pratiyush Mishra

# DIFFICULTY:

Easy-Medium

# PREREQUISITES:

None

# PROBLEM:

You are given two arrays A and B, each of length N.

Find the value \sum_{i=1}^N\sum_{j=1}^N \max(A_i\oplus B_j,A_i\& B_j).
Here \oplus and \& denote the bitwise XOR and bitwise AND operations, respectively.

# Quick Explanation:

First, let us look at the pattern in \max(x \oplus y, x \& y).
If the MSBs of x and y are i and j respectively, then:

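• if i \neq j, then x \oplus y > x \& y, because bit \max(i, j) is set in x \oplus y but not in x \& y;
• if i = j, then x \oplus y < x \& y, because bit i is set in x \& y but cleared in x \oplus y.

Neither value has any bit above \max(i, j) set, so whichever one has that highest bit ON is the larger. A value of 0 has no MSB and counts as "different" from every nonzero value (if both are 0, both results are 0).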

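So, for array B, group the values by their MSB and precompute, for each bit position, how many values of B have that bit set, both overall and within each MSB group. Then iterate over the elements of A. Each A_i is combined by AND with the elements of B in its own MSB group and by XOR with all other elements of B, and both contributions can be added bit by bit from the precomputed counts.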

# Explanation

We define the following arrays over A:
msb[i] → number of elements in A whose most significant bit is i (zeros are not counted).
bits[i] → number of elements in A with their i_{th} bit on.
bit\_table[i][j] → number of elements in A whose most significant bit is i and whose j_{th} bit is on.

All three arrays can be filled in a single pass over A using simple bit manipulation.

We then loop through the elements of B to compute the sum. Consider an element B_i whose most significant bit is k.

• If B_i = 0, we simply add the sum of array A to the answer, since \max(A_j \oplus 0, A_j \& 0) = A_j for every 0 \leq j < n.

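• Otherwise, we loop over every bit position j, including the positions where B_i has a 0. Elements of A with MSB k are paired with B_i through AND, and all other elements of A (including zeros) through XOR: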

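• if the j_{th} bit of B_i is set, bit j appears in the AND with every element of A that has MSB k and bit j on, and in the XOR with every other element of A that has bit j off, so we add:
answer += (1 << j) \times (bit\_table[k][j] + (n - msb[k]) - (bits[j] - bit\_table[k][j]))
• if the j_{th} bit of B_i is not set, the AND contributes nothing, and bit j appears in the XOR with every element of A that has a different MSB and bit j on, so we add:
answer += (1 << j) \times (bits[j] - bit\_table[k][j])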

# TIME COMPLEXITY:

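O(N \cdot L) for each test case, where L is the number of bit positions considered (about \log_2 of the largest value). This is linear in N for a fixed bit width.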

# SOLUTION:


Thanks for such a fast and nice editorial.

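```cpp
#include <iostream>
#include <vector>

#define int long long int

void solve(){
    int n;
    std::cin >> n;
    // bit[x]: number of elements of B with bit x set (values assumed < 2^22)
    std::vector <int> A(n, 0), B(n, 0), bit(22, 0);
    // v[len]: elements of B with bit-length len (MSB + 1, or 0 for B = 0);
    // lengths go up to 22 for values < 2^22, hence 23 buckets
    std::vector <std::vector<int>> v(23);
    for(auto &i:A)
        std::cin >> i;
    for(auto &i:B)
        std::cin >> i;
    for(int i=0; i<n; i++){
        int k = B[i], j = 0;
        while(k){
            bit[j] += (k%2);
            k /= 2;
            j += 1;
        }
        v[j].push_back(B[i]);  // j is now the bit-length of B[i]
    }

    int ans = 0;
    for(int i=0; i<n; i++){
        int k = A[i], j = 0, m = 0;
        while(k){
            j += 1;
            k /= 2;
        }
        // Same MSB as A[i]: the maximum is the AND, so replace the XOR
        // (added below for every pair) by the AND. This scans the whole
        // group, which is O(N^2) overall if most values share one MSB.
        for(auto l:v[j]){
            ans += (A[i]&l);
            ans -= (A[i]^l);
        }
        // Add the XOR of A[i] with every element of B, bit by bit.
        k = A[i];
        for(int x = 0; x<22; x++){
            if(k%2){
                ans += (n-bit[x])*(1<<m);  // elements of B with bit x unset
            }
            else{
                ans += bit[x]*(1<<m);      // elements of B with bit x set
            }
            k /= 2;
            m += 1;
        }
    }
    std::cout << ans << "\n";
}

signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
    int t = 0;
    std::cin >> t;
    while(t--){
        solve();
    }
}
```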



Your solution is actually incorrect; consider this test case:

```
2
1 3
4 0
```


Why are we adding this term?

First, note that the XOR value is only used for pairs whose most significant bits differ.
Say we are processing an element of B whose MSB is k. It is combined by XOR with every element of A whose MSB is not k, and there are n - cnt[k] such elements, where cnt[i] is the number of elements of A with MSB i (msb[i] in the editorial).

So when bit j of that element is set, we count how many of those n - cnt[k] elements have a 0 at position j. Each of them contributes 1 << j to the answer, since bit j is set in their XOR.

That count of zeroes is the number of remaining elements minus the number of them that have bit j set, which equals:

count of zeroes = (n - cnt[k]) - (bits[j] - bit_table[k][j])

So, in addition to the (1 << j) times bit_table[k][j] from the AND part, the XOR part contributes (1 << j) times the count of zeroes.

Actually, I modified it a bit, and now the answer is 16. Thanks, though. Can you check again? It's the correct approach now.