THREEBIKES - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Authors: Daanish Mahajan
Testers: Abhinav Sharma and Lavish Gupta
Editorialist: Nishank Suresh

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Sum-over-subsets DP

PROBLEM:

You are given an array M = [M_1, M_2, \ldots, M_N]. Count the number of ordered triples of distinct indices (i, j, k) such that (M_i \oplus M_j) \mathbin{\&} M_k = M_i \oplus (M_j \mathbin{\&} M_k).
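A direct brute force pins down the condition and can cross-check faster solutions on small inputs. The sketch below tries every ordered triple of distinct indices, so it takes \mathcal{O}(N^3) time and is far too slow for the actual constraints. The function name `countBrute` is just an illustrative choice.

```cpp
#include <vector>

// Brute force over all ordered triples (i, j, k) of pairwise distinct indices,
// counting those with ((M[i] ^ M[j]) & M[k]) == (M[i] ^ (M[j] & M[k])). O(N^3).
long long countBrute(const std::vector<int>& M) {
    const int n = static_cast<int>(M.size());
    long long count = 0;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int k = 0; k < n; ++k) {
                if (i == j || j == k || i == k) continue;
                if (((M[i] ^ M[j]) & M[k]) == (M[i] ^ (M[j] & M[k]))) ++count;
            }
    return count;
}
```

For example, `countBrute({1, 3, 2})` returns 2. In 1-based indices, the good triples are (i, j, k) = (1, 3, 2) and (3, 1, 2).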

QUICK EXPLANATION:

  • Note that any triple that satisfies this condition must satisfy M_i \mathbin{\&} M_k = M_i, i.e., M_i must be a submask of M_k.
  • This is also a sufficient condition for the triple to be good.
  • Upon fixing k, the number of valid choices of i can be found using sum-over-subsets DP.
  • Once i and k are fixed, j can be chosen arbitrarily, which leaves N-2 choices for it.

EXPLANATION:

A common trick with problems dealing with bitwise operations is to treat each bit independently, so let’s do that here.
Since XOR and AND act on each bit position separately, it suffices to look at a single bit, where each of M_i, M_j, M_k is either 0 or 1.

  • Suppose M_i = 0.
    Then, (M_i \oplus M_j) \mathbin{\&} M_k = M_j \mathbin{\&} M_k and M_i \oplus (M_j \mathbin{\&} M_k) = M_j \mathbin{\&} M_k, so both expressions are equal regardless of the choice of M_j and M_k.
  • Suppose M_i = 1.
    Then, if M_k = 0, we have (M_i \oplus M_j) \mathbin{\&} M_k = 0 and M_i \oplus (M_j \mathbin{\&} M_k) = 1 regardless of what M_j is, which means they are never equal.
    Thus, we must have M_k = 1. In this case, (1 \oplus M_j) \mathbin{\&} 1 = 1 \oplus M_j and 1 \oplus (M_j \mathbin{\&} 1) = 1 \oplus M_j, so the two expressions are equal and, once again, the value of M_j doesn’t matter.

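This tells us that the value of M_j doesn’t matter at all, while M_i can be 1 only if M_k is also 1.
Two numbers are equal exactly when they agree at every bit, so extending this to all bits tells us that a triple is good if and only if every bit set in M_i is also set in M_k, i.e., M_i is a submask of M_k.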
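The per-bit argument can also be sanity-checked mechanically for small bit widths. The hypothetical helper `submaskClaimHolds` below enumerates every triple of `bits`-bit values. It confirms that the identity holds exactly when the first value is a submask of the third. This only checks small widths; the per-bit argument above is what covers B = 20.

```cpp
#include <cassert>

// Checks, for every a, b, c in [0, 2^bits), that
//   ((a ^ b) & c) == (a ^ (b & c))   holds exactly when   (a & c) == a,
// i.e. exactly when a is a submask of c. Returns false on the first counterexample.
bool submaskClaimHolds(int bits) {
    assert(0 <= bits && bits <= 8);  // 2^(3 * bits) triples, so keep widths small
    const int limit = 1 << bits;
    for (int a = 0; a < limit; ++a)
        for (int b = 0; b < limit; ++b)
            for (int c = 0; c < limit; ++c) {
                const bool identity = ((a ^ b) & c) == (a ^ (b & c));
                const bool submask = (a & c) == a;
                if (identity != submask) return false;
            }
    return true;
}
```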

Thus, we have the following algorithm to compute the number of triples:

  • First, fix which index is chosen as k.
  • Then, count the number of indices that can be i: this is exactly the number of indices i \neq k such that M_i is a submask of M_k.
  • Finally, j can be freely chosen to be any of the remaining N-2 indices.
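Before optimizing, the three steps can be implemented directly by scanning the whole array for valid i at each k. This sketch runs in \mathcal{O}(N^2), which is too slow for N = 10^5, but it shows the role of the second step explicitly: that inner scan is exactly what the SOS DP replaces. The function name `countQuadratic` is illustrative.

```cpp
#include <cstddef>
#include <vector>

// The three steps done directly: for each k, count indices i != k such that
// M[i] is a submask of M[k], then multiply by the N - 2 choices for j. O(N^2).
long long countQuadratic(const std::vector<int>& M) {
    const long long n = static_cast<long long>(M.size());
    long long total = 0;
    for (std::size_t k = 0; k < M.size(); ++k) {
        long long choicesForI = 0;
        for (std::size_t i = 0; i < M.size(); ++i)
            if (i != k && (M[i] & M[k]) == M[i]) ++choicesForI;
        total += choicesForI * (n - 2);
    }
    return total;
}
```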

The first part is easy to do — simply iterate over every index of the array. The third part is also trivial, which leaves the second.

The second part essentially requires us to solve the following problem:
Let F_x denote the number of indices i such that M_i = x. We would like to compute

S_x = \sum_{y \subseteq x} F_y

where y\subseteq x means that y is a submask of x.
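For a single x, S_x can be computed straight from the definition by enumerating the submasks of x with the standard `y = (y - 1) & x` trick. Doing this for every x costs \mathcal{O}(3^B) in total (about 3.5 \cdot 10^9 steps for B = 20). So the hypothetical `submaskSum` helper below only makes the definition concrete and is useful for testing.

```cpp
#include <cassert>
#include <vector>

// S_x = sum of F[y] over all submasks y of x, computed from the definition.
// F is indexed by value, so x (and hence every submask of x) must be a valid index.
long long submaskSum(const std::vector<int>& F, int x) {
    assert(0 <= x && x < static_cast<int>(F.size()));
    long long s = 0;
    for (int y = x; ; y = (y - 1) & x) {  // visits each submask of x once, in decreasing order
        s += F[y];
        if (y == 0) break;                // 0 is always the last submask visited
    }
    return s;
}
```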

This is a classical problem, which can be solved in \mathcal{O}(B \cdot 2^B) using sum-over-subsets (SOS) DP, where B is the number of bits.
In this case, the bound M_i \leq 10^6 gives us B = 20, because 2^{20} > 10^6.
If you do not know what sum-over-subsets DP is, please go through this Codeforces blog.
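Here is a minimal sketch of the SOS DP itself. It uses a single array of size 2^B, updated in place one bit at a time. This is the same recurrence as in the editorialist's solution below, but the name `sosInPlace` and its interface are illustrative assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sum-over-subsets DP in place. On entry S[x] = F_x; on exit S[x] is the
// sum of F_y over all submasks y of x. Uses O(bits * 2^bits) time, 2^bits memory.
void sosInPlace(std::vector<int>& S, int bits) {
    assert(S.size() == (std::size_t{1} << bits));
    for (int b = 0; b < bits; ++b)               // one pass per bit
        for (int x = 0; x < (1 << bits); ++x)
            if (x & (1 << b))                    // x has bit b set:
                S[x] += S[x ^ (1 << b)];         // add the sums for x with bit b cleared
}
```

After the pass for bit b, S[x] holds the sum of F_y over submasks y of x that agree with x on every bit above b, so after the last pass it is exactly S_x. Reading S[x ^ (1 << b)] during pass b is safe because that entry has bit b cleared and is never updated in that pass.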

The final solution to the problem is then simply:

  • Compute the array S using SOS DP.
  • Iterate over each index 1\leq k \leq N.
  • Add (S_{M_k} - 1)\cdot(N-2) to the answer. We subtract 1 from S_{M_k} because M_k is itself a submask of M_k, and we can’t choose i = k.
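The sketch below puts the steps together and computes the answer for one test case from an in-memory array rather than from input. The function name and the `bits` parameter (defaulting to B = 20 as derived above) are illustrative assumptions. It expects N \geq 3 and 0 \leq M_i < 2^{bits}.

```cpp
#include <cassert>
#include <vector>

// Answer for one test case: build the frequency array, run SOS DP on it,
// then add (S[M_k] - 1) * (N - 2) for every k. Assumes N >= 3, 0 <= M_i < 2^bits.
long long countGoodTriples(const std::vector<int>& M, int bits = 20) {
    const int size = 1 << bits;
    std::vector<int> S(size, 0);
    for (int v : M) {
        assert(0 <= v && v < size);
        ++S[v];                                  // S starts as the frequency array F
    }
    for (int b = 0; b < bits; ++b)               // sum-over-subsets DP, in place
        for (int x = 0; x < size; ++x)
            if (x & (1 << b)) S[x] += S[x ^ (1 << b)];
    const long long n = static_cast<long long>(M.size());
    long long answer = 0;
    for (int v : M)                              // v plays the role of M_k
        answer += static_cast<long long>(S[v] - 1) * (n - 2);  // i != k; j is free
    return answer;
}
```

For instance, `countGoodTriples({1, 2, 3, 4})` returns 4. Only M_k = 3 has other submasks in the array (1 and 2), and each such pair (i, k) leaves N - 2 = 2 choices for j.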

TIME COMPLEXITY:

\mathcal{O}(N + B\cdot 2^B) per test case, where B = 20 for this problem.

SOLUTIONS:

Setter's Solution (C++)
#include<bits/stdc++.h>
using namespace std;

const int BITS = 20;

int main(){
  ios_base::sync_with_stdio(false);
  cin.tie(NULL);
  cout.tie(NULL);

  int t; cin>>t; while(t--){

    int n; cin>>n;

    vector<int> A(n);
    for(int i=0; i<n; ++i)
      cin>>A[i];

    //make freq arr from input
    vector<int> F(1<<BITS,0);
    for(int i:A) F[i]++;

    // dp[t][i] = sum of F[y] over submasks y of t that agree with t on all bits >= i
    vector<vector<int>> dp(1<<BITS,vector<int>(BITS+1));
    for(int t=0; t<(1<<BITS); t++){
      dp[t][0]=F[t];
    }

    for(int t=0; t<(1<<BITS); t++){
      for(int i=1; i<=BITS; i++){
        if(t&(1<<(i-1))){
          dp[t][i] = dp[t][i-1]+dp[t^(1<<(i-1))][i-1];
        }else{
          dp[t][i] = dp[t][i-1];
        }
      }
    }

    long long ans = 0;
    for(int i=0; i<n; i++){
      ans += (1ll*(dp[A[i]][BITS]-1)*(n-2));
    }

    cout<<ans<<'\n';
  }

  return 0;
}
Tester's Solution (C++)
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
 
int sum_len = 0;
int max_n = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;

const ll MX=200000;
ll fac[MX], ifac[MX];

const ll mod = 1e9+7;

ll po(ll x, ll n ){ 
    ll ans=1;
     while(n>0){
        if(n&1) ans=(ans*x)%mod;
        x=(x*x)%mod;
        n/=2;
     }
     return ans;
}

void solve()
{   
  int n;
  n = readIntLn(3, 1e5);
  sum_len+=n;
  max_n = max(max_n, n);
  
  vector<int> dp(1<<20, 0); // on the heap: a 4 MB local array risks overflowing the stack
  vector<int> v(n);
  int x;
  for(int i=0; i<n-1; i++){
    x = readIntSp(0, 1e6);
    v[i]=x;
    dp[x]++;
  }

  x = readIntLn(0, 1e6);
  v[n-1]=x;
  dp[x]++;

  for(int j=0; j<20; j++){
    for(int i=0; i<(1<<20); i++){
      if(!((i>>j)&1)) dp[(i^(1<<j))]+= dp[i];
    }
  }

  ll ans = 0;

  for(int i=0; i<n; i++){
    ans += (dp[v[i]]-1);
  }

  ans*=(n-2);
  cout<<ans<<'\n';

}
 
signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    
    int t = 1;
    
    t = readIntLn(1,5);
        
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
    
    assert(getchar() == -1);
    assert(sum_len <= 1e5);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_len << '\n';
    cerr<<"Maximum length : " << max_n << '\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n), freq(1<<20);
        for (int &x : a) {
            cin >> x;
            ++freq[x];
        }
        vector<int> subct(1<<20);
        for (int i = 0; i < 20; ++i) {
            for (int mask = 0; mask < 1<<20; ++mask) {
                if (i == 0) {
                    subct[mask] = freq[mask];
                }
                if (mask & (1<<i)) {
                    subct[mask] += subct[mask ^ (1<<i)];
                }
            }
        }
        ll ans = 0;
        for (int i = 0; i < n; ++i) {
            // Fix i to be the 3rd element of the triple
            // Then, the first element can be anything which is a submask of a[i] (except a[i] itself), which is subct[a[i]] - 1
            ll ways = subct[a[i]] - 1;
            // Second element can be any among the remaining
            ways *= n-2;
            ans += ways;
        }
        cout << ans << '\n';
    }
}

Constraints were too tight: long longs TLEd, then int overflowed :(

The constraints were not especially tight — my code ran in 0.14s and the tester’s took 0.10s. Every accepted submission in the contest also ran in less than 0.2s.

The main issue with your code is the implementation. SOS DP can be implemented using 2^B memory, while you’ve used B \cdot 2^B.
Related to this, your 2D dp table has its dimensions in the wrong order. When you have two dimensions and one is much smaller than the other, it is often better to put the smaller dimension first, especially when using vectors.
Part of the reason is that every inner vector is a separate heap allocation with its own bookkeeping, so putting the smaller dimension second gives you 2^B tiny vectors instead of B+1 large ones. Cache misses also play a role, though I’m not qualified enough to go into detail about that.

If you have a multiple-dimensional dp and think it seems to be running slower than it should, try switching around the order of states to see if that speeds it up.

Here’s your first TLE submission with just the states flipped, which now runs in 0.33s, well within the limit.
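The layout advice can be illustrated on the setter’s recurrence from above. The hypothetical helper `sosLayered` keeps the same dp definition but stores it as dp[i][t]. It therefore allocates B + 1 contiguous rows of length 2^B instead of 2^B vectors of length B + 1. It still uses B \cdot 2^B memory, whereas SOS DP can be done with 2^B, as noted above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Setter-style layered SOS table stored as dp[i][t] instead of dp[t][i]:
// dp[i][t] = sum of F[y] over submasks y of t that agree with t on all bits >= i.
// Returns dp[bits], i.e. S[t] for every t. Still uses (bits + 1) * 2^bits ints.
std::vector<int> sosLayered(const std::vector<int>& F, int bits) {
    const int size = 1 << bits;
    assert(F.size() == static_cast<std::size_t>(size));
    std::vector<std::vector<int>> dp(bits + 1, std::vector<int>(size));
    dp[0] = F;                                   // layer 0: just the frequencies
    for (int i = 1; i <= bits; ++i) {
        const int bit = 1 << (i - 1);
        for (int t = 0; t < size; ++t)           // contiguous sweep over one row
            dp[i][t] = dp[i - 1][t] + ((t & bit) ? dp[i - 1][t ^ bit] : 0);
    }
    return dp[bits];
}
```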


Didn’t know about that, thanks for letting me know.