Pairwise AND - Editorial

Problem Link:

practice

Authors and Editorialists: Shivam Sahni, Amey Kulkarni

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Bitwise AND, Number Theory

PROBLEM:

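Given a number N, compute the sum of the bitwise AND of every pair of consecutive numbers from 0 to N (inclusive), i.e. (0 AND 1) + (1 AND 2) + … + ((N - 1) AND N). Both solutions below print this sum modulo 10^9 + 7.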

QUICK EXPLANATION:

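Instead of summing the ANDs number by number, go bit by bit: for each bit position, count how many consecutive pairs have that bit set in both numbers, and multiply that count by the bit's value. For bit k these pairs follow a pattern that repeats every 2^k numbers, so each bit can be handled with O(1) arithmetic.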

EXPLANATION:

Number the bit positions from 1 (least significant), so bit k has value 2^{k-1}. A consecutive pair (i, i + 1) contributes 2^{k-1} to the sum exactly when bit k is set in both numbers. Each bit follows a repeating pattern:

Bit position → Observation

1 → Contributes 0. The last bit alternates, so the AND of any two consecutive numbers always has it cleared.

2 → In every interval [4n, 4n + 3], two consecutive numbers (4n + 2, 4n + 3) have the 2^{nd} bit set to 1. They form 1 pair worth 2, so add 1 \times 2 to the sum for each such interval.

3 → In every interval [8n, 8n + 7], four consecutive numbers ([8n + 4, 8n + 7]) have the 3^{rd} bit set to 1. They form 3 pairs worth 4 each, so add 3 \times 4 to the sum for each such interval.

4 → In every interval [16n, 16n + 15], eight consecutive numbers ([16n + 8, 16n + 15]) have the 4^{th} bit set to 1. They form 7 pairs worth 8 each, so add 7 \times 8 to the sum for each such interval.

… and so on.

In general,
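k → In every interval [2^k n, 2^k n + 2^k - 1], there are 2^{k-1} consecutive numbers ([2^k n + 2^{k-1}, 2^k n + 2^k - 1]) that have the k^{th} bit set to 1. They form 2^{k-1} - 1 pairs worth 2^{k-1} each, so add (2^{k-1} - 1) \times 2^{k-1} to the sum for each such interval. A pair that straddles two intervals never contributes, because the first number of every interval has the k^{th} bit cleared.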
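As a worked check of the per-interval formula, take N = 7. The eight numbers 0 to 7 split into exactly \lfloor 8/2^k \rfloor complete intervals for k = 1, 2, 3, with nothing left over:

$$
\begin{aligned}
k = 1:&\quad \lfloor 8/2 \rfloor \cdot (2^{0}-1) \cdot 2^{0} = 4 \cdot 0 \cdot 1 = 0 \\
k = 2:&\quad \lfloor 8/4 \rfloor \cdot (2^{1}-1) \cdot 2^{1} = 2 \cdot 1 \cdot 2 = 4 \\
k = 3:&\quad \lfloor 8/8 \rfloor \cdot (2^{2}-1) \cdot 2^{2} = 1 \cdot 3 \cdot 4 = 12 \\
\text{total}:&\quad 0 + 4 + 12 = 16
\end{aligned}
$$

No number from 0 to 7 has bit 4 or higher set, so those bits add nothing. Summing directly gives the same value: 0 + 0 + 2 + 0 + 4 + 4 + 6 = 16.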

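When the last interval is incomplete

The numbers 0 to N fill \lfloor (N + 1)/2^k \rfloor complete intervals for bit k. Unless N + 1 is a multiple of 2^k, a partial interval of r = (N + 1) \bmod 2^k numbers remains at the end, and it does not contain all 2^{k-1} - 1 pairs, so its pairs must be counted separately. The key observation is that the numbers with the k^{th} bit set always sit at the end of an interval (offsets 2^{k-1} to 2^k - 1). Hence the partial interval contains \max(0, r - 2^{k-1}) numbers with the k^{th} bit set, which form \max(0, r - 2^{k-1} - 1) contributing pairs, each worth 2^{k-1}.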
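Putting the complete-interval and partial-interval terms together gives a closed form. Here S(N) denotes the required sum and M is shorthand introduced for N + 1; the terms vanish once 2^{k-1} \ge M:

$$
S(N) = \sum_{k \ge 1} 2^{k-1} \left( \left\lfloor \frac{M}{2^k} \right\rfloor \left(2^{k-1} - 1\right) + \max\left(0,\; (M \bmod 2^k) - 2^{k-1} - 1\right) \right), \qquad M = N + 1
$$

For example, with N = 5 (M = 6): bit 2 has one complete interval worth 1 \times 2 = 2 and a leftover of 2 numbers (4 and 5) with no pair; bit 3 has no complete interval, but its partial interval 0..5 has the bit set on 4 and 5, giving one pair worth 4. The total 2 + 4 = 6 matches 0 + 0 + 2 + 0 + 4.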

Time Complexity

O(\log N) per test case: the loop runs once per bit position of N + 1, and each iteration does O(1) arithmetic.

SOLUTIONS:

Setter's Solution
```cpp
#include <bits/stdc++.h>
#define mod 1000000007
#define ll long long

using namespace std;

void test(){
    ll n;
    cin>>n;
    // We have n + 1 numbers: 0, 1, ..., n
    ll k = 2; // Interval length (2^bit position); the bit's value is k / 2
    ll ans = 0;
    // Intervals longer than 2 * (n + 1) cannot contain a contributing pair
    while(k <= 2 * (n + 1)){
        // Complete intervals: each has (k/2 - 1) pairs, each worth k/2
        ll num_grps = (n + 1) / k;
        ll term1 = (((num_grps) * (k / 2 - 1)) % mod) * ((k / 2) % mod);
        term1 %= mod;
        ans += term1;
        ans %= mod;
        // Partial last interval: the bit is set on its last (leftover - k/2)
        // numbers, which form (leftover - k/2 - 1) contributing pairs
        ll leftover = (n + 1) % k;
        ll term2;
        ll pairs = leftover - k / 2 - 1;
        if(pairs <= 0)
            term2 = 0;
        else
            term2 = (pairs % mod) * ((k / 2) % mod);
        term2 %= mod;
        ans += term2;
        ans %= mod;
        k *= 2;
    }
    cout<<ans<<endl;
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    int t;
    cin>>t;
    while(t--){
        test();
    }
}
```
Editorialist's Solution
```cpp
#include <bits/stdc++.h>
#define ll long long
#define MOD 1000000007
using namespace std;

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    int t; cin>>t;
    while (t--){
        ll n; cin>>n;
        ++n; // n now counts the numbers 0..N
        ll ans = 0;
        // Bit i (value 2^i) repeats with period 2^(i+1); 60 iterations assume N < 2^60
        for (int i=0; i<60; i++){
            ll t1 = 1LL << (i+1); // interval length (integer shift instead of floating-point pow)
            ll t2 = t1/2;         // bit value = count of numbers with the bit set per interval
            // Complete intervals: (t2 - 1) pairs each, every pair worth t2
            ans = (ans + ( ( ( (t2 - 1LL + MOD) % MOD) * (t2 % MOD) ) % MOD * ( ( n/t1 ) % MOD) ) % MOD) % MOD;
            // Partial interval: t3 = how many of its numbers have the bit set
            ll t3 = ( n - ( t1 * ( n/t1 ) ) ) - t2;
            if (t3>0){
                // Those t3 numbers form (t3 - 1) pairs, each worth t2
                ans = ( ans + ( ( (t3 - 1LL + MOD) % MOD) * (t2 % MOD) ) % MOD ) % MOD;
            }
        }
        cout<<ans<<endl;
    }

    return 0;
}
```