PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Setter: Yahor Dubovik
Tester: Harris Leung
Editorialist: Trung Dang
DIFFICULTY:
2502
PREREQUISITES:
XOR
PROBLEM:
You are given an array A, consisting of N distinct integers.
Calculate the number of pairs (i,j) (1 \le i < j \le N) such that 2 \cdot (A_i \oplus A_j) = A_i + A_j, where \oplus denotes bitwise XOR.
EXPLANATION:
The condition mixes addition with XOR, so let's try to transform it into one that consists only of bitwise operations. We know that A + B = (A \oplus B) + 2 \cdot (A \land B), where \land is bitwise AND (semantically, A \oplus B is carry-less addition, while A \land B marks the positions that generate a carry). Substituting this into 2 \cdot (A \oplus B) = A + B and cancelling one copy of A \oplus B, the condition becomes A \oplus B = 2 \cdot (A \land B). There is another way to read this condition: the lowest bit of A \oplus B must be 0, and for any i, bit i of A \land B must equal bit i + 1 of A \oplus B.
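As a sanity check of the two facts above, here is a minimal brute-force sketch. The function name `identity_holds_up_to` and the `limit` parameter are made up for illustration and do not come from any of the solutions below; the check only covers small values and is not a proof.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: checks, for every pair 0 <= a, b < limit, that
//   (1) a + b == (a ^ b) + 2 * (a & b), and
//   (2) 2 * (a ^ b) == a + b holds exactly when (a ^ b) == 2 * (a & b).
bool identity_holds_up_to(std::uint64_t limit) {
    for (std::uint64_t a = 0; a < limit; ++a) {
        for (std::uint64_t b = 0; b < limit; ++b) {
            // XOR is the carry-less sum, AND marks the carry positions.
            if (a + b != (a ^ b) + 2 * (a & b)) return false;
            bool original = (2 * (a ^ b) == a + b);
            bool rewritten = ((a ^ b) == 2 * (a & b));
            if (original != rewritten) return false;
        }
    }
    return true;
}
```

Calling `identity_holds_up_to(256)` exercises all pairs of 8-bit values.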
Let's see how this condition helps us. We know for sure that the first (lowest) bit of A \oplus B is 0, because A \oplus B is even. Since we also know the first bit of A (A is fixed), we know the first bit of B, which in turn gives us the first bit of A \land B. Now the condition comes in: from it we infer the second bit of A \oplus B, then the second bit of B, then the second bit of A \land B, then the third bit of A \oplus B, and so on. Simply put, any fixed A determines one and only one possible B.
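The bit-by-bit inference above can be sketched as a small function. This is an illustrative sketch that follows the same scan as the setter's solution; the name `partner` is hypothetical, and it assumes 0 \le x < 2^{60} so that the result fits in a signed 64-bit integer.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: returns the unique y with (x ^ y) == 2 * (x & y),
// assuming 0 <= x < 2^60 (an assumption of this sketch).
std::int64_t partner(std::int64_t x) {
    std::int64_t y = 0;
    int bit = 0;
    while (bit < 60) {
        if (((x >> bit) & 1) == 0) {
            // No pending carry and x has 0 here, so y must also have 0 here.
            ++bit;
            continue;
        }
        // x has 1 here and the XOR bit must be 0, so y has 1 here as well.
        y |= (1LL << bit);
        // The AND bit is 1, so the next XOR bit is 1: y's next bit is the
        // complement of x's next bit.
        if (((x >> (bit + 1)) & 1) == 0) {
            y |= (1LL << (bit + 1));
        }
        // The two bits at position bit + 1 differ, so their AND is 0 and the
        // chain restarts cleanly two positions later.
        bit += 2;
    }
    assert((x ^ y) * 2 == x + y);
    return y;
}
```

For instance, `partner(1)` is 3 and `partner(3)` is 1: the pair (1, 3) satisfies 2 \cdot (1 \oplus 3) = 4 = 1 + 3.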
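Therefore, the problem becomes simple: for each element in the array, construct the only possible partner value, then check whether that value is also an element of the array. If it is, we increase the answer by 1. Because the condition is symmetric in A and B, doing this for every element counts each valid pair twice, so we either halve the final count (as the setter's solution does) or look for the partner only among the elements read so far (as the tester's and editorialist's solutions do).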
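The counting step can be sketched as follows, with a quadratic brute force included for cross-checking on small inputs. The names `partner_of`, `count_pairs` and `count_pairs_brute` are hypothetical, the partner construction uses the XOR-flip formulation from the tester's solution, and the sketch assumes distinct positive values below 2^{60}. It uses a hash set rather than the sorted array or ordered containers used in the solutions below.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical helper: the unique y with (x ^ y) == 2 * (x & y).
// Start from y = x; whenever x has bit j set, flip bit j + 1 and skip it.
std::int64_t partner_of(std::int64_t x) {
    std::int64_t y = x;
    for (int j = 0; j < 60; ++j) {
        if ((x >> j) & 1) {
            y ^= (1LL << (j + 1));
            ++j;  // bit j + 1 is already decided
        }
    }
    return y;
}

// Counts pairs i < j with 2 * (a[i] ^ a[j]) == a[i] + a[j].
// Assumes the values are distinct, positive and below 2^60.
long long count_pairs(const std::vector<std::int64_t>& a) {
    std::unordered_set<std::int64_t> seen;
    long long ans = 0;
    for (std::int64_t u : a) {
        ans += static_cast<long long>(seen.count(partner_of(u)));
        seen.insert(u);  // only earlier elements are searched, so no halving
    }
    return ans;
}

// O(N^2) reference implementation, for checking small inputs only.
long long count_pairs_brute(const std::vector<std::int64_t>& a) {
    long long ans = 0;
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = i + 1; j < a.size(); ++j)
            if (2 * (a[i] ^ a[j]) == a[i] + a[j]) ++ans;
    return ans;
}
```

On the array {1, 2, 3, 6} both functions return 2, for the pairs (1, 3) and (2, 6).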
TIME COMPLEXITY:
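Time complexity is O(N \log N): each lookup in a sorted array, set, or map costs O(\log N), and constructing the partner of one element takes a fixed 60 bit steps.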
SOLUTION:
Setter's Solution
#ifdef DEBUG
#define _GLIBCXX_DEBUG
#endif
//#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
typedef long double ld;
typedef long long ll;
int n;
const int maxN = 1e6 + 10;
ll a[maxN];
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // freopen("input.txt", "r", stdin);
    cin >> n;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    // Sort so that partners can be found by binary search.
    sort(a, a + n);
    int ans = 0;
    for (int i = 0; i < n; i++) {
        ll x = a[i];
        // Build the unique partner y of x bit by bit.
        ll y = 0;
        int bit = 0;
        while (bit < 60) {
            int bt = (int)((a[i] >> bit) & 1);
            if (bt == 0) {
                // x has 0 here, so y has 0 here too.
                bit++;
                continue;
            }
            // x has 1 here, so y has 1 here and the next bits must differ.
            y |= (1LL << bit);
            if (!(a[i] & (1LL << (bit + 1)))) {
                y |= (1LL << (bit + 1));
            }
            bit += 2;
        }
        assert((x ^ y) * 2 == (x + y));
        int pos = lower_bound(a, a + n, y) - a;
        if (pos < n && a[pos] == y) {
            ans++;
        }
    }
    // Every valid pair was found once from each side.
    assert(ans % 2 == 0);
    ans /= 2;
    cout << ans << '\n';
    return 0;
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
const int N=2e5+1;
const int iu=30;
map<ll,int>mp;
int n;
int main(){
    ios::sync_with_stdio(false);cin.tie(0);
    cin >> n;
    ll ans=0;
    for(int i=1; i<=n ;i++){
        ll x;cin >> x;mp[x]++;
        // Partner of x: start from x, and for each set bit j flip bit j+1, then skip it.
        ll y=x;
        for(int j=0; j<60 ;j++){
            if((x>>j)&1){
                y^=(1LL<<(j+1));
                j++;
            }
        }
        // Count the partner among the values read so far.
        ans+=mp[y];
    }
    cout << ans << '\n';
}
Editorialist's Solution
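#include <bits/stdc++.h>
using namespace std;
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n; cin >> n;
    set<long long> se;  // values read so far
    int ans = 0;
    while (n--) {
        long long u; cin >> u;
        // Build the unique partner of u.
        long long oth = 0;
        for (int i = 0; i < 60; i++) {
            if (u >> i & 1) {
                // Copy the set bit, then take the complement of u's next bit.
                oth ^= (1LL << i);
                i++;
                oth ^= (((u >> i & 1) ^ 1) << i);
            }
        }
        // Only earlier elements are in the set, so each pair is counted once.
        ans += se.count(oth);
        se.insert(u);
    }
    cout << ans;
}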