PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
The inclusion-exclusion principle
PROBLEM:
You’re given an array A, each of whose elements has at most four set bits.
Compute
$$\sum_{i=1}^{N-1} \sum_{j=i+1}^{N} \left\lceil \frac{A_i + A_j}{A_i \vee A_j} \right\rceil.$$
Here, \vee denotes the bitwise OR operation.
EXPLANATION:
Let’s analyze what values \left\lceil \frac{x + y}{x \vee y} \right\rceil can take.
First, we invoke a well-known equation: x+y = (x\vee y) + (x\land y), where \land denotes bitwise AND.
That is, the sum of two numbers equals the sum of their bitwise OR and bitwise AND.
This is not hard to see: look at the binary representations of x and y, and observe which powers of 2 are being added once and which are being added twice.
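With this,
$$\left\lceil \frac{x + y}{x \vee y} \right\rceil = \left\lceil \frac{(x \vee y) + (x \land y)}{x \vee y} \right\rceil = 1 + \left\lceil \frac{x \land y}{x \vee y} \right\rceil.$$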
Now, note that every bit set in x\land y will also be set in x\vee y, so x\land y \leq x\vee y.
This means 0 \leq \left\lceil \frac{x \land y}{x \vee y} \right\rceil \leq 1, and in particular the ceiling is 1 if and only if the fraction is non-zero — which in turn means x\land y \neq 0, i.e. the bitwise AND of x and y is non-zero.
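Now, putting this back into the original problem, what we want to compute is
$$\sum_{1 \leq i < j \leq N} \left(1 + \left\lceil \frac{A_i \land A_j}{A_i \vee A_j} \right\rceil\right).$$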
The 1 is a constant and will be added \frac{N\cdot (N-1)}{2} times, so we can ignore it for now.
The other term, as noted above, is either 0 or 1 - and will be 1 if and only if A_i\land A_j \neq 0, meaning A_i and A_j share a common bit.
So, all we really need to do is compute the number of pairs of elements that share at least one common bit - this count plus \frac{N\cdot (N-1)}{2} will be the answer.
Here’s where we’ll use the fact that every element of A has at most four set bits.
Suppose we’re looking at an element A_i, with its set bits being b_1, b_2, \ldots, b_k.
Then, one way to count the number of j such that A_i\land A_j \neq 0 is as follows:
- First, for each b_i, add the number of elements that have b_i set.
- Now, if some number has two of the b_i set, it’ll have been added twice (but we want to add it only once).
So, for each pair of set bits, subtract the number of elements that have both of them set.
- But then if some element has three of the b_i set, we would’ve added it three times (in the first step) and subtracted it three times (in the second step), so it’s no longer being counted.
So, for each triple of set bits, add back the number of elements that have all three of these bits set.
- Once again, we run into the same problem with four bits: such elements are added 4 times, subtracted 6 times and added back 4 times, so they end up counted twice and must be subtracted once more to correct for this, and so on.
This is really nothing but the inclusion-exclusion principle, which can be stated succinctly as follows:
- For each subset S of the b_i, if S has odd size then add the number of elements that have every bit in S; otherwise subtract this count.
There are 2^k subsets to go through - but the problem’s constraints guarantee that k \leq 4 so we perform \leq 2^4 = 16 checks per element which is pretty fast (numerically, it’s effectively a factor of \log N).
Now that we have a working solution, only implementation remains.
Let f[\text{mask}] denote the number of array elements that have all the bits in \text{mask} set (though they might have other bits too).
Then, we can do the following:
- For each i from 1 to N:
- Find the bits set in A_i.
- Go through all non-empty subsets of these bits, letting \text{mask} be the bitwise OR of the bits in the subset. If the subset size is odd, add f[\text{mask}] to the count, otherwise subtract it.
- Then, add 1 to f[\text{mask}] because this element needs to be counted in the future.
The value of \text{mask} can be large, but we’ll touch no more than 16\cdot N distinct values of it, so f can be safely stored in a map — alternatively, you can precompute all 16N masks, store them in a sorted array, and then use binary search to do the counting.
TIME COMPLEXITY:
\mathcal{O}(16N\log N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Every later `int` token in this program is replaced by `long long`
// (64-bit); main below is therefore declared as int32_t.
#define int long long
// Large sentinel (10^18); not referenced by Solve or main.
#define INF (int)1e18
// 64-bit Mersenne Twister seeded from the steady clock; not referenced by Solve or main.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// Solves one test case read from std::cin: N, then A_1 .. A_N.
// Prints  sum_{i<j} ceil((A_i + A_j) / (A_i | A_j)).
//
// Since x + y = (x | y) + (x & y) and x & y <= x | y, every pair contributes
// 2 if A_i & A_j != 0 and 1 otherwise. The number of pairs with
// A_i & A_j == 0 is found by inclusion-exclusion over the non-empty
// sub-masks of every element (at most 2^4 - 1 = 15 of them, since every
// element has at most four set bits).
//
// Bits are extracted with 64-bit operations, so values are not limited to
// 30 bits; A_i is assumed to be non-negative.
void Solve()
{
    long long n;
    std::cin >> n;
    std::vector<long long> a(n);
    for (auto &x : a) std::cin >> x;

    // b collects every non-empty sub-mask of every element. After sorting,
    // a run of v equal masks means exactly v elements contain all bits of
    // that mask, i.e. v * (v - 1) / 2 pairs share all of those bits.
    std::vector<long long> b;
    b.reserve(a.size() * 15);
    for (const long long value : a){
        // popcountll inspects all 64 bits (plain popcount would truncate).
        assert(__builtin_popcountll(value) <= 4);
        std::vector<long long> bits;  // single-bit values 2^p set in value
        for (long long rest = value; rest != 0; rest &= rest - 1){
            bits.push_back(rest & -rest);  // lowest set bit of rest
        }
        const std::size_t m = bits.size();
        for (std::size_t sub = 1; sub < (std::size_t{1} << m); sub++){
            long long go = 0;
            for (std::size_t k = 0; k < m; k++){
                if (sub >> k & 1){
                    go |= bits[k];
                }
            }
            b.push_back(go);
        }
    }
    std::sort(b.begin(), b.end());

    // Start from C(N, 2) and remove, by inclusion-exclusion, every pair that
    // shares at least one bit; what remains is the number of disjoint pairs.
    const long long total = n * (n - 1) / 2;
    long long count = total;
    for (std::size_t i = 0; i < b.size(); ){
        std::size_t j = i;
        while (j < b.size() && b[j] == b[i]) j++;
        const long long v = static_cast<long long>(j - i);
        const long long pairs = v * (v - 1) / 2;
        if (__builtin_popcountll(b[i]) & 1){
            count -= pairs;  // odd-sized bit set: subtract
        } else {
            count += pairs;  // even-sized bit set: add back
        }
        i = j;
    }
    // Disjoint pairs contribute 1, every other pair contributes 2.
    const long long ans = 2 * total - count;
    std::cout << ans << "\n";
}
int32_t main()
{
    // Wall-clock timing of the whole run, reported on stderr only.
    const auto begin = std::chrono::high_resolution_clock::now();
    // Only C++ streams are used, so they can be decoupled from C stdio.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // freopen("in", "r", stdin);
    // freopen("out", "w", stdout);
    int t = 1;
    std::cin >> t;
    while (t-- > 0) {
        Solve();
    }
    const auto end = std::chrono::high_resolution_clock::now();
    const auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    std::cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Order-statistics set: a std::set<T>-like container that additionally
// supports find_by_order / order_of_key (GNU policy-based data structures).
template <typename T>
using Tree = __gnu_pbds::tree<T, __gnu_pbds::null_type, std::less<T>, __gnu_pbds::rb_tree_tag,
                              __gnu_pbds::tree_order_statistics_node_update>;
// Short type aliases used throughout the solution.
using ll = long long int;
using ld = long double;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
// Contest shorthand macros. Every macro argument is parenthesized in its
// expansion so expression arguments (e.g. `a | b`, `x & y`) keep their
// intended precedence: previously `rep(i, a & b)` became `i < a & b`,
// i.e. `(i < a) & b`.
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'  // plain newline: output is not flushed on each line
#define sz(a) ((int)(a).size())
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) (((x)+(y)-1)/(y))  // ceil(x / y) for non-negative x, positive y
#define all(a) (a).begin(), (a).end()
#define rall(a) (a).rbegin(), (a).rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < (n); ++i)
#define rep1(i,n) for(int i = 1; i <= (n); ++i)
#define rev(i,s,e) for(int i = (s); i >= (e); --i)
#define trav(i,a) for(auto &i : (a))
// Lowers a to b when b is strictly smaller (the same choice std::min makes).
template<typename T>
void amin(T &a, T b) {
    a = (b < a) ? b : a;
}
// Raises a to b when a is strictly smaller (the same choice std::max makes).
template<typename T>
void amax(T &a, T b) {
    a = (a < b) ? b : a;
}
// In local builds (LOCAL defined), debug(...) comes from a project-local
// "debug.h" that is not shown here. Otherwise debug(...) expands to the
// constant expression 42, so debug calls have no effect in judge builds.
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
// General-purpose contest constants.
constexpr int MOD = 1000000007;                  // 1e9 + 7
constexpr int N = 100000 + 5;                    // 1e5 plus slack
constexpr int inf1 = 1000000000 + 5;             // "infinity" for int
constexpr ll inf2 = 1000000000000000000LL + 5;   // "infinity" for long long
// Solves one test case read from std::cin: N, then A_1 .. A_N, and prints
//   sum_{i<j} ceil((A_i + A_j) / (A_i | A_j)).
// Every pair contributes 1 if A_i & A_j == 0 and 2 otherwise, so it suffices
// to count the disjoint pairs. Elements are processed left to right and
// mp[m] holds how many earlier elements contain every bit of m.
// Inclusion-exclusion over all subsets of the current element's bits (the
// empty subset, mask 0, counts all earlier elements) yields the number of
// earlier elements disjoint from it.
// Bits are isolated with 64-bit `x & -x` instead of `1 << bit`, so bits at
// position >= 31 are handled; A_i is assumed to be non-negative.
// test_case is the 1-based case index and is currently unused.
void solve(int test_case){
    long long n; std::cin >> n;
    std::vector<long long> a(n);
    for (auto &value : a) std::cin >> value;
    long long one_cnt = 0;  // pairs with A_i & A_j == 0 (each contributes 1)
    std::map<long long, long long> mp;
    for (const long long x : a){
        std::vector<long long> bits;  // single-bit values 2^p set in x
        for (long long rest = x; rest != 0; rest &= rest - 1){
            bits.push_back(rest & -rest);  // lowest set bit of rest
        }
        const int siz = static_cast<int>(bits.size());
        long long curr_cnt = 0;
        for (int mask = 0; mask < (1 << siz); ++mask){
            long long currv = 0;
            for (int k = 0; k < siz; ++k){
                if (mask >> k & 1) currv |= bits[k];
            }
            // Even-sized subsets are added, odd-sized ones subtracted.
            const long long coeff = (__builtin_popcount(mask) & 1) ? -1 : 1;
            long long &seen = mp[currv];  // one lookup for read + update
            curr_cnt += coeff * seen;
            ++seen;
        }
        one_cnt += curr_cnt;
    }
    const long long ans = one_cnt + (n * (n - 1) / 2 - one_cnt) * 2;
    std::cout << ans << '\n';
    /*
    Brute-force sanity check that every term is at most 2:
    rep1(x,(1<<10)-1){
        rep1(y,(1<<10)-1){
            assert(ceil2((x+y),(x|y)) <= 2);
        }
    }
    */
}
int main()
{
fastio;
int t = 1;
cin >> t;
rep1(i, t) {
solve(i);
}
return 0;
}
Editorialist's code (PyPy3)
import sys
from collections import defaultdict


def pair_sum(a):
    """Return sum over i < j of ceil((a[i] + a[j]) / (a[i] | a[j])).

    Since x + y == (x | y) + (x & y) and x & y <= x | y, every pair
    contributes 2 when a[i] & a[j] != 0 and 1 otherwise.  The result is
    therefore C(n, 2) plus the number of pairs sharing a set bit, which is
    counted by inclusion-exclusion over the non-empty submasks of each
    element (at most 2**4 - 1 = 15 per element when every value has at most
    four set bits).  Elements are assumed to be positive.
    """
    freq = defaultdict(int)  # freq[m]: earlier elements having every bit of m
    sharing = 0  # pairs (j < i) with a[i] & a[j] != 0
    for x in a:
        sub = x
        while sub > 0:
            # bin(...).count works on every Python 3; int.bit_count needs 3.10+.
            if bin(sub).count("1") % 2 == 1:
                sharing += freq[sub]
            else:
                sharing -= freq[sub]
            freq[sub] += 1
            sub = (sub - 1) & x  # next smaller non-empty submask of x
    n = len(a)
    return n * (n - 1) // 2 + sharing


def main():
    """Read T test cases (N, then N integers) from stdin; print one answer per line."""
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    pos = 1
    out = []
    for _ in range(t):
        n = int(data[pos])
        a = [int(tok) for tok in data[pos + 1:pos + 1 + n]]
        pos += 1 + n
        out.append("%d\n" % pair_sum(a))
    sys.stdout.write("".join(out))


if __name__ == "__main__":
    main()