PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Authors: Daanish Mahajan
Testers: Abhinav Sharma and Lavish Gupta
Editorialist: Nishank Suresh
DIFFICULTY:
Easy-Medium
PREREQUISITES:
Sum-over-subsets DP
PROBLEM:
You are given an array M = [M_1, M_2, \ldots, M_N]. Count the number of ordered triples of distinct indices (i, j, k) such that (M_i \oplus M_j) \mathbin{\&} M_k = M_i \oplus (M_j \mathbin{\&} M_k).
QUICK EXPLANATION:
- Note that any triple that satisfies this condition must satisfy M_i \mathbin{\&} M_k = M_i, i.e., M_i must be a submask of M_k.
- This is also a sufficient condition for the triple to be good.
- Upon fixing a k, the number of possible choices of i can be found using sum-over-subsets DP.
- j can be chosen arbitrarily once i and k are fixed, leaving it with N-2 choices.
EXPLANATION:
A common trick with problems dealing with bitwise operations is to treat each bit independently, so let’s do that here.
This means it is enough to analyze a single bit position, where each of M_i, M_j, M_k is either 0 or 1.
- Suppose M_i = 0.
Then, (M_i \oplus M_j) \mathbin{\&} M_k = M_j \mathbin{\&} M_k and M_i \oplus (M_j \mathbin{\&} M_k) = M_j \mathbin{\&} M_k, so both expressions are already equal regardless of the choice of M_j and M_k.
- Suppose M_i = 1.
Then, if M_k = 0 we have (M_i \oplus M_j) \mathbin{\&} M_k = 0 and M_i \oplus (M_j \mathbin{\&} M_k) = 1 regardless of what M_j is, which means they will never be equal.
Thus, we must have M_k = 1. In this case, both expressions evaluate to 1 \oplus M_j (AND-ing with 1 leaves a bit unchanged), so they’re equal and, once again, the value of M_j doesn’t matter.
This tells us that what M_j is doesn’t matter at all, while M_i can be 1 only if M_k is also 1.
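Since the bits are independent, the two sides are equal exactly when this condition holds at every bit position simultaneously; that is, a triple is good if and only if M_i is a submask of M_k (M_i \mathbin{\&} M_k = M_i).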
Thus, we have the following algorithm to compute the number of triples:
- First, fix which index is chosen as k.
- Then, count the number of indices which can possibly be i - this is exactly the count of integers M_i such that M_i is a submask of M_k.
- Finally, j can be freely chosen to be any of the remaining N-2 indices.
The first part is easy to do — simply iterate over every index of the array. The third part is also trivial, which leaves the second.
The second part essentially requires us to solve the following problem:
Let F_x denote the number of indices i such that M_i = x. We would like to compute, for every mask x,
S_x = \sum_{y \subseteq x} F_y,
where y \subseteq x means that y is a submask of x.
This is a classical problem, which can be solved in \mathcal{O}(B2^B) using sum-over-subsets DP, where B is the number of bits.
In this case, the bound M_i \leq 10^6 gives us B = 20, because 2^{20} > 10^6.
If you do not know what sum-over-subsets DP is, please go through the Codeforces blog on SOS DP.
The final solution to the problem is then simply:
- Compute the array S using SOS DP.
- iterate over each index 1\leq k \leq N.
- Add (S_{M_k} - 1)\cdot(N-2) to the answer. We subtract 1 from S_{M_k} because M_k is itself a submask of M_k, and we can’t choose i = k.
TIME COMPLEXITY:
\mathcal{O}(N + B\cdot 2^B) per test case, where B = 20 for this problem.
SOLUTIONS:
Setter's Solution (C++)
#include<bits/stdc++.h>
using namespace std;
const int BITS = 20;

// Counts ordered triples (i, j, k) of pairwise-distinct indices such that A[i] is a
// submask of A[k]; by the editorial's argument these are exactly the good triples.
// Every value of A must lie in [0, 2^BITS).
static long long count_good_triples(const vector<int>& A){
    const long long n = (long long)A.size();
    // sub[x] starts as the frequency of x. After the sum-over-subsets passes it holds
    // the number of elements of A that are submasks of x. This single in-place array
    // replaces the former (2^BITS) x (BITS+1) vector-of-vectors table, which took over
    // 100 MB and 2^BITS separate heap allocations per test case.
    vector<int> sub(1<<BITS, 0);
    for(int v : A) sub[v]++;
    for(int b=0; b<BITS; b++){
        for(int mask=0; mask<(1<<BITS); mask++){
            if(mask & (1<<b)) sub[mask] += sub[mask ^ (1<<b)];
        }
    }
    long long ans = 0;
    for(int v : A){
        // i: any element that is a submask of A[k], except k itself (hence -1);
        // j: any of the remaining n-2 indices.
        ans += (sub[v] - 1) * (n - 2);
    }
    return ans;
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int t; cin>>t;
    while(t--){
        int n; cin>>n;
        vector<int> A(n);
        for(int &x : A) cin>>x;
        cout<<count_good_triples(A)<<'\n';
    }
    return 0;
}
Tester's Solution (C++)
#include <bits/stdc++.h>
using namespace std;
/*
------------------------Input Checker----------------------------------
*/
// Reads one integer token from stdin that must be terminated by exactly `endd`,
// and returns it after asserting that it lies in [l, r].
// The token format is validated strictly. Each of the following trips an assert:
// a character other than a digit, a leading '-', or `endd`; an empty token or a
// lone '-'; more than one '-'; leading zeros; "-0"; more than 19 digits, or 19
// digits starting with a digit above 1 (which could overflow long long).
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit of the token, -1 until one is read
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            // A single minus sign is allowed, and only before the first digit.
            assert(cnt==0 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);        // no leading zeros
            assert(fi!=0 || is_neg==false); // no "-0"
            // Length check happens BEFORE accumulating, so x can never overflow.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+(g-'0');
        } else if(g==endd){
            assert(cnt>0);  // an empty token (or a lone '-') is not a number
            if(is_neg){
                x= -x;
            }
            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to the terminator `endd` (consumed, not stored)
// and returns them after asserting that the length lies in [l, r]. Reaching EOF
// before the terminator trips an assert.
std::string readString(int l,int r,char endd){
    std::string ret="";
    int cnt=0;
    while(true){
        // Keep getchar()'s int result until EOF is ruled out. Stored in a plain
        // `char` (unsigned on some platforms), EOF never compares equal to -1, so a
        // missing terminator would loop forever instead of failing the assert.
        int c=getchar();
        assert(c!=EOF);
        char g=static_cast<char>(c);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads an integer in [l, r] that must be immediately followed by one space.
long long readIntSp(long long l,long long r){
    const char terminator = ' ';
    return readInt(l, r, terminator);
}
// Reads an integer in [l, r] that must be immediately followed by a newline.
long long readIntLn(long long l,long long r){
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
string readStringLn(int l,int r){
return readString(l,r,'\n');
}
string readStringSp(int l,int r){
return readString(l,r,' ');
}
/*
------------------------Main code starts here----------------------------------
*/
// Declared input limits; these three constants are not referenced anywhere in the
// visible code.
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
// Competitive-programming shorthands. In the visible code only `fast` (in main)
// and `ll` are used; `ff`, `ss` and `mp` are unused.
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
// Statistics accumulated by solve() and reported on stderr by main().
int sum_len = 0;
int max_n = 0;
// Counters that appear only in commented-out output lines in main().
int yess = 0;
int nos = 0;
int total_ops = 0;
// Template leftovers: fac/ifac are never filled or read in the visible code;
// `mod` is the modulus used by po().
const ll MX=200000;
ll fac[MX], ifac[MX];
const ll mod = 1e9+7;
// Returns x^n modulo `mod` by binary exponentiation, as a value in [0, mod).
// A non-positive exponent returns 1 (the loop body never runs).
ll po(ll x, ll n ){
    // Reduce the base first. An unreduced |x| above ~3.04e9 would overflow in
    // x*x, and a negative base would otherwise produce a negative result.
    x %= mod;
    if(x < 0) x += mod;
    ll ans=1;
    while(n>0){
        if(n&1) ans=(ans*x)%mod;
        x=(x*x)%mod;
        n/=2;
    }
    return ans;
}
// Reads and validates one test case (n in [3, 1e5], then n values in [0, 1e6] on a
// single line) and prints the number of good ordered triples:
// (n - 2) * sum over k of (number of elements that are submasks of v[k], minus k itself).
void solve()
{
    int n;
    n = readIntLn(3, 1e5);
    sum_len+=n;
    max_n = max(max_n, n);
    // dp[x]: frequency of x, turned by the SOS pass below into the number of array
    // elements that are submasks of x. Kept on the heap: the former local
    // `int dp[1<<20]` put 4 MiB on the stack, which can overflow a default-sized
    // stack (e.g. 1 MiB on Windows).
    vector<int> dp(1<<20, 0);
    vector<int> v(n);
    int x;
    for(int i=0; i<n-1; i++){
        x = readIntSp(0, 1e6);
        v[i]=x;
        dp[x]++;
    }
    x = readIntLn(0, 1e6);
    v[n-1]=x;
    dp[x]++;
    // Sum over subsets: push each count up to the mask that additionally has bit j.
    for(int j=0; j<20; j++){
        for(int i=0; i<(1<<20); i++){
            if(!((i>>j)&1)) dp[(i^(1<<j))]+= dp[i];
        }
    }
    ll ans = 0;
    for(int i=0; i<n; i++){
        ans += (dp[v[i]]-1);   // exclude index i itself as its own submask
    }
    ans*=(n-2);                // j may be any of the other n-2 indices
    cout<<ans<<'\n';
}
// Validator driver: reads the number of test cases and solves each one. It then
// asserts that the whole input was consumed and the total-length bound holds, and
// reports statistics on stderr.
signed main()
{
#ifndef ONLINE_JUDGE
    // Local runs read input.txt and write output.txt.
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
#endif
    fast;
    const int t = readIntLn(1,5);
    for(int tc=0; tc<t; ++tc)
    {
        solve();
    }
    // Nothing may follow the last test case.
    assert(getchar() == -1);
    assert(sum_len <= 1e5);
    cerr << "SUCCESS\n"
         << "Tests : " << t << '\n'
         << "Sum of lengths : " << sum_len << '\n'
         << "Maximum length : " << max_n << '\n';
}
Editorialist's Solution (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
// Integer type used for the answer accumulator in main().
using ll = long long int;
// Time-seeded 64-bit Mersenne Twister; it is never used in this solution.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
int main()
{
ios::sync_with_stdio(0); cin.tie(0);
int t; cin >> t;
while (t--) {
int n; cin >> n;
vector<int> a(n), freq(1<<20);
for (int &x : a) {
cin >> x;
++freq[x];
}
vector<int> subct(1<<20);
for (int i = 0; i < 20; ++i) {
for (int mask = 0; mask < 1<<20; ++mask) {
if (i == 0) {
subct[mask] = freq[mask];
}
if (mask & (1<<i)) {
subct[mask] += subct[mask ^ (1<<i)];
}
}
}
ll ans = 0;
for (int i = 0; i < n; ++i) {
// Fix i to be the 3rd element of the triple
// Then, the first element can be anything which is a submask of a[i] (except a[i] itself), which is subct[a[i]] - 1
ll ways = subct[a[i]] - 1;
// Second element can be any among the remaining
ways *= n-2;
ans += ways;
}
cout << ans << '\n';
}
}