AWESUM_OR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: shauryabhalla0
Tester: abhidot
Editorialist: iceknight1093

DIFFICULTY:

2152

PREREQUISITES:

Combinatorics or dynamic programming

PROBLEM:

Given an integer N, find the number of ordered triplets (X, Y, Z), modulo 10^9+7, such that

  • 0 \lt X, Y, Z \lt N
  • X+Y+Z = X\mid Y\mid Z = N

EXPLANATION:

For any non-negative integers X and Y, it's not hard to see that X\mid Y \leq X+Y, with equality holding if and only if X and Y don't share any set bits in their binary representations. This follows from the identity X+Y = (X\mid Y) + (X \mathbin{\&} Y).
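As a quick sanity check, the identity X+Y = (X\mid Y) + (X \mathbin{\&} Y) and the equality condition can be verified exhaustively for small values. This is a brute-force sketch, not part of the original solutions, and the helper name check_or_sum is ours:

```python
def check_or_sum(limit: int) -> bool:
    """Brute-force check, for all 0 <= x, y < limit, that
    x + y == (x | y) + (x & y), and that x | y == x + y exactly when x & y == 0."""
    for x in range(limit):
        for y in range(limit):
            if x + y != (x | y) + (x & y):
                return False
            if ((x | y) == x + y) != ((x & y) == 0):
                return False
    return True


print(check_or_sum(64))  # True
```

The check is quadratic in limit, so it is only meant for small ranges.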

This extends to three integers as well: X\mid Y\mid Z \leq X+Y+Z, and equality holds if and only if no bit is set in more than one of X, Y, Z.

Now, we know that X\mid Y\mid Z = N. This means that

  • If a bit b is not set in N, then it can’t be set in any of X, Y, Z
  • If a bit b is set in N, then it should be set in at least one of X, Y, Z.
    Our earlier discussion further tells us that it can’t be set in two or more of them, so it must be set in exactly one of them.

So, each set bit of N must be given to exactly one of X, Y, or Z, and each of X, Y, Z must receive at least one set bit (since they are all positive). Our objective is to count the number of ways to do this.
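To have something to test the formulas against, here is a brute-force counter that works straight from the problem statement. It is a hypothetical helper (brute is our name) and only practical for small N, since it tries O(N^2) pairs:

```python
def brute(n: int) -> int:
    """Count ordered triplets (x, y, z) with 0 < x, y, z < n and
    x + y + z == (x | y | z) == n, directly from the definition."""
    count = 0
    for x in range(1, n):
        for y in range(1, n - x):
            z = n - x - y  # forced by x + y + z == n; the loop bounds keep z >= 1
            if (x | y | z) == n:
                count += 1
    return count


print(brute(7), brute(11))  # 6 6
```

Both 7 = 111_2 and 11 = 1011_2 have three set bits, so they give the same count, consistent with only the number of set bits mattering.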

There are several ways to do this; here are a couple.

Direct math

The number of ways can be calculated directly, using the inclusion-exclusion principle.

First, let's ignore the “each value gets at least one bit” condition, i.e., allow zeros.
Counting the number of ways here is simple: each set bit in N has exactly 3 choices, for whether it’s given to X, Y, or Z.
So, if N has K set bits, the number of ways is 3^K.
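The 3^K count can be checked by literally enumerating the choices: give each set bit an owner among X, Y, Z with itertools.product and rebuild the triplet. This is an illustrative sketch (assignments_with_zeros is our name), not code from the editorial:

```python
from itertools import product


def assignments_with_zeros(n: int) -> list[tuple[int, int, int]]:
    """Give each set bit of n to one of x, y, z (zeros allowed) and
    return every resulting triplet; there are 3**popcount(n) of them."""
    bits = [1 << i for i in range(n.bit_length()) if (n >> i) & 1]
    triplets = []
    for owners in product(range(3), repeat=len(bits)):  # owner 0/1/2 means x/y/z
        parts = [0, 0, 0]
        for bit, owner in zip(bits, owners):
            parts[owner] |= bit
        triplets.append(tuple(parts))
    return triplets


trips = assignments_with_zeros(0b1011)
print(len(trips))  # 27
```

For N = 1011_2 this yields 27 = 3^3 triplets, each satisfying X+Y+Z = X\mid Y\mid Z = N; exactly 6 of them have no zero entry.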

Now, let’s remove the cases when some of the values are 0.
If X = 0 in the end, all the bits were distributed to Y and Z, i.e., 2 options for each bit, for 2^K in total.
By symmetry, this is also the number of ways in which Y = 0, or Z = 0.
So, the total number of ways to subtract here is 3\cdot 2^K.

Finally, we need to add back the cases where two or more of X, Y, Z are zero, since those were subtracted more than once.
It's easy to see that there are only three such ways: (N, 0, 0), (0, N, 0), (0, 0, N). All three being zero is impossible since N \gt 0.

So, the answer is simply

3^K - 3\cdot 2^K + 3
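A minimal Python sketch of this formula, cross-checked against a direct count of the assignments that use all three of X, Y, Z. The function names are ours; the modulus 10^9+7 follows the reference solutions:

```python
from itertools import product

MOD = 10**9 + 7


def count_triplets(n: int) -> int:
    """Inclusion-exclusion: 3^K - 3*2^K + 3 (mod 1e9+7), where K = popcount(n)."""
    k = bin(n).count("1")
    return (pow(3, k, MOD) - 3 * pow(2, k, MOD) + 3) % MOD


def surjections_onto_three(k: int) -> int:
    """Directly count ways to give k bits to x, y, z so that all three get a bit."""
    return sum(1 for owners in product(range(3), repeat=k) if len(set(owners)) == 3)


for k in range(1, 9):
    n = (1 << k) - 1  # any n with k set bits would do
    assert count_triplets(n) == surjections_onto_three(k)
print(count_triplets(0b1111))  # 36
```

Python's % always returns a non-negative result here, so the subtraction needs no extra correction.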

Dynamic programming

Suppose N has K set bits.
As our observations showed, only this number K matters; which bits were set doesn’t matter at all.

Our aim is to find the number of ways to split K bits into three non-empty labelled subsets (one each for X, Y and Z).
Since K is quite small (N \lt 2^{60}, so K \leq 60), this can be done using dynamic programming.

Let dp_i be the number of ways to split i bits into three non-empty subsets.

Then, if we fix the size j of the first subset (1 \leq j \lt i), we have:

  • \binom{i}{j} ways to choose which bits go into this subset.
  • 2^{i-j}-2 ways to distribute the remaining bits into two non-empty subsets: each of the i-j bits has 2 choices, minus the 2 assignments that leave one subset empty.

So, we have

dp_i = \sum_{j=1}^{i-1} \binom{i}{j} \left ( 2^{i-j} - 2 \right )

which can be precomputed for all K \leq 60 in \mathcal{O}(K^2) (or \mathcal{O}(K^3) with a simpler implementation); after that, each query only needs the number of set bits of N.
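The summation can be sketched in Python with math.comb (Python 3.8+). This illustrates the formula above and is not the setter's or tester's code; the names precompute and MAX_K are ours:

```python
from math import comb

MOD = 10**9 + 7
MAX_K = 60


def precompute(max_k: int = MAX_K) -> list[int]:
    """dp[i] = sum_{j=1}^{i-1} C(i, j) * (2^(i-j) - 2): choose the j bits of X,
    then split the other i-j bits into non-empty Y and Z."""
    dp = [0] * (max_k + 1)
    for i in range(1, max_k + 1):
        dp[i] = sum(comb(i, j) * (2 ** (i - j) - 2) for j in range(1, i)) % MOD
    return dp


dp = precompute()
# The summation agrees with the inclusion-exclusion formula for every K.
assert all(dp[k] == (3**k - 3 * 2**k + 3) % MOD for k in range(1, MAX_K + 1))
print(dp[bin(15).count("1")])  # 36
```

A query for N then reduces to looking up dp at the popcount of N.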

TIME COMPLEXITY:

\mathcal{O}(\log N) per test case.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;
#define int long long   

const int mod=1e9+7;
vector<int> pcom(61, 0);

int binexp(int a, int b, int mod){
    assert(b>=0);
    a=a%mod;
    int ans = 1;
    while(b){
        if(b&1){
            ans=ans*a%mod;
        }
        a=a*a%mod;
        b/=2;
    }
    return ans;
}

void solve(){
    int n;
    cin>>n;

    int x = __builtin_popcountll(n);
    cout<<pcom[x]*6%mod<<'\n';
    // cout<<((binexp(3, x, mod)-3*binexp(2, x, mod)%mod+mod)%mod+3)%mod<<'\n';
}

signed main(){

    ios::sync_with_stdio(false);
    cin.tie(0);  cout.tie(0);

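    // pcom[a] = S(a, 3), the Stirling number of the second kind: the number of ways
    // to split a set bits into three unlabelled non-empty groups.
    // solve() multiplies by 3! = 6 to assign the groups to X, Y, Z.
    for(int a=3; a<=60; a++){
        int sum = 0;
        for(int b = a-1; b>0; b--){
            for(int c = b-1; c>0; c--){
                sum = (sum + binexp(2, b-c-1, mod)*binexp(3, c-1, mod)%mod)%mod;
            }
        }
        pcom[a] = sum;
    }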

    int tt;
    cin>>tt;

    while(tt--) solve();
}     
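One way to read the setter's double loop (our interpretation, checked numerically below, not stated by the setter): scan the a bits in order; the first bit opens group 1, and if group 2 is first used at position p and group 3 at position q (1 \lt p \lt q \leq a), the bits strictly between p and q have 2 choices and the bits after q have 3. That gives 2^{q-p-1}\cdot 3^{a-q}, which matches the loop with b = a-p+1 and c = a-q+1. The total is S(a, 3), and 6\cdot S(a,3) = 3^a - 3\cdot 2^a + 3. The Python names below are ours:

```python
def setter_pcom(a: int) -> int:
    """Mirror of the setter's loop without the modulus: sum of 2^(b-c-1) * 3^(c-1)."""
    total = 0
    for b in range(a - 1, 0, -1):
        for c in range(b - 1, 0, -1):
            total += 2 ** (b - c - 1) * 3 ** (c - 1)
    return total


def stirling2_3(a: int) -> int:
    """S(a, 3) via the standard closed form (3^a - 3*2^a + 3) / 6."""
    return (3**a - 3 * 2**a + 3) // 6


for a in range(3, 61):
    assert setter_pcom(a) == stirling2_3(a)
print(setter_pcom(4), 6 * setter_pcom(4))  # 6 36
```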
Tester's code (C++)
// Problem: SUM OR
// Contest: CodeChef - STR84TST
// Memory Limit: 256 MB
// Time Limit: 1000 ms
// Author: abhidot

// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> 
#include <ext/pb_ds/tree_policy.hpp>
#define int long long
#define ll long long
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);
#define pb push_back
#define mod 1000000007
#define mod2 998244353
#define lld long double
#define pii pair<int, int>
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
#define uniq(v) (v).erase(unique(all(v)),(v).end())
#define rep(i,x,y) for(int i=x; i<y; i++)
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define V vector
#define setbits(x) __builtin_popcountll(x)
#define w(x)  int x; cin>>x; while(x--)
using namespace std;
using namespace __gnu_pbds; 
template <typename num_t> using ordered_set = tree<num_t, null_type, less<num_t>, rb_tree_tag, tree_order_statistics_node_update>;
const long long N=200005, INF=2000000000000000000, inf = 2e9+5;
 
int power(int a, int b, int p){
	if(a==0)
	return 0;
	int res=1;
	a%=p;
	while(b>0)
	{
		if(b&1)
		res=(res*a)%p;
		b>>=1;
		a=(a*a)%p;
	}
	return res;
}
 
 
void print(bool n){
    if(n){
        cout<<"YES\n";
    }else{
        cout<<"NO\n";
    }
}

int f[100], in[100]; 

int ncr(int n, int r){
	return f[n]*in[r]%mod*in[n-r]%mod;
}
 
int32_t main()
{
    IOS;
    f[0]=1, in[0]=1;
    for(int i=1;i<100;i++){
    	f[i]=i*f[i-1]%mod;
    	in[i]=power(f[i], mod-2, mod);
    }
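    // ans[s] = sum over x, y, z >= 1 with x+y+z = s of C(s, x) * C(s-x, y),
    // i.e. the multinomial s!/(x! y! z!): which x bits go to X, which y go to Y.
    int ans[61]={0};
    for(int x=1;x<=60;x++){
        for(int y=1;x+y<=60;y++){
            for(int z=1;x+y+z<=60;z++){
                int s = x+y+z;
                ans[s]+=(ncr(s, x)*ncr(s-x, y)%mod);
                ans[s]%=mod;
            }
        }
    }

    w(T){
        int n;
        cin>>n;
        cout<<ans[setbits(n)]<<"\n";
    }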
}
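The tester sums multinomial coefficients directly: a split with x, y, z bits (all at least 1) out of s can be realised in s!/(x!\,y!\,z!) ways. A Python transcription of that triple loop (our sketch, using math.comb instead of the tester's factorial tables) should agree with 3^K - 3\cdot 2^K + 3, since both count the same assignments:

```python
from math import comb

MOD = 10**9 + 7


def tester_table(max_s: int = 60) -> list[int]:
    """ans[s] sums C(s, x) * C(s - x, y) over x, y, z >= 1 with x + y + z = s,
    i.e. the multinomial s! / (x! y! z!)."""
    ans = [0] * (max_s + 1)
    for x in range(1, max_s + 1):
        for y in range(1, max_s - x + 1):
            for z in range(1, max_s - x - y + 1):
                s = x + y + z
                # choose which x bits go to X, then which y of the rest go to Y
                ans[s] = (ans[s] + comb(s, x) * comb(s - x, y)) % MOD
    return ans


table = tester_table()
print(table[3], table[4])  # 6 36
```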
Editorialist's code (Python)
mod = 10**9 + 7
for _ in range(int(input())):
    n = int(input())
    bits = bin(n).count('1')
    ans = pow(3, bits, mod) - 3*pow(2, bits, mod) + 3
    print(ans % mod)

Can we think of it like this: out of k bits, we first assign one bit each to a, b, c, so there are k*(k-1)*(k-2) ways; then we are left with k-3 bits, each of which has three choices, giving 3^(k-3). So the final answer would be

(k*(k-1)*(k-2)) + (3^(k-3)), but it gave WA. Why?


Suppose you distribute the bits and X = 3 in the end.

Did X receive the 1 from the first step and the 2 from the second one, or did it receive the 2 from the first step and the 1 from the second one?
Your formulation in fact counts both of them, so it’s overcounting by quite a lot.
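To see the overcounting concretely, one can follow the two-step scheme literally (taking the two steps as independent choices, so their counts multiply to k(k-1)(k-2)\cdot 3^{k-3}) and record how often each final triplet is produced. A triplet whose parts have a, b, c set bits comes out a\cdot b\cdot c times, once per choice of "first" bits. The sketch below is ours, with bits labelled 0..k-1:

```python
from collections import Counter
from itertools import permutations, product


def commenter_scheme(k: int) -> Counter:
    """Pick distinct first bits for X, Y, Z, then give each of the remaining k - 3
    bits to any of the three; count how often each final triplet appears."""
    seen = Counter()
    for first in permutations(range(k), 3):  # k*(k-1)*(k-2) choices
        rest = [b for b in range(k) if b not in first]
        for owners in product(range(3), repeat=len(rest)):  # 3^(k-3) choices
            parts = [1 << first[0], 1 << first[1], 1 << first[2]]
            for bit, owner in zip(rest, owners):
                parts[owner] |= 1 << bit
            seen[tuple(parts)] += 1
    return seen


seen = commenter_scheme(4)
print(sum(seen.values()), len(seen))  # 72 36
```

For k = 4 the scheme generates 72 sequences but only 36 distinct triplets, each produced twice.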


Is there a way of subtracting the overcounting done by this formula?


Not that I can see, at least not easily.

First off, the actual formula given there is logically wrong, it should be multiplication and not addition.
Second, the example I gave in my above comment obviously generalizes even further: if X has 10 set bits, you’re going to count it 10 different times.

The only way I can see to get rid of this is to actually apply the inclusion-exclusion principle, so you subtract the number of ways where X gets 2 bits, add the number of ways it gets 3, and so on.
The problem is this gets super messy since you need to deal with Y and Z as well; and the inclusion-exclusion itself is going to be slow doing it this way since your formulas get worse and worse each step.

Finally, if you’re using inc-exc anyway, you might as well use the method I detailed in the editorial for a simple one-line answer.


Got it, Sir.
Thanks for such a wonderful explanation.

Best explained on YouTube. It uses combinatorics like we do in class 12.

#include "bits/stdc++.h"
using namespace std;

#define int        long long int
#define now(x)     cout<<#x<<" : "<<x<<endl;

int M = 1e9 + 7;

int power(int a,int b){
    int ans = 1;
    while(b){
        if(b&1) ans = (ans * a) %M;
        a = (a*a) % M;
        b >>= 1;
    }
    return ans;
}

void solve(){
    int n;cin>>n;

    int bt = __builtin_popcountll(n);

    if(bt < 3) {
        cout<<0<<endl;
        return;
    }

    cout<<(( (power(3,bt)%M - (3*(power(2,bt)%M - 1 + M))%M + M)%M + M)%M)<<endl;
}

signed main() {
    ios_base::sync_with_stdio(false); cin.tie(NULL);
    int t = 1;
    cin >> t;  
    while (t--) solve();
    return 0;
} 


In the cout, why are you using mod when you already know that the power function returns a long long whose value is less than mod (1e9+7)? We can safely multiply that by 3. For the other terms there is also no need to reduce modulo M; we could do it once at the end.
Please correct me if I am missing something. Thanks.

I mean, why can't we simply write this line as
cout<< ( power(3,bt) - 3*power(2,bt) +3 + M) % M << endl;

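3*power(2,bt) can be as large as about 3*M, so power(3,bt) - 3*power(2,bt) + 3 + M can still be negative, and % on a negative number gives a negative result in C++. That product has to be brought back into range before subtracting.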

IceKnight Sir, the EXPLANATION section of CodeChef problems is not very beginner-friendly; I am unable to grasp the concept. The video solutions of CodeChef problems are poor, because they do not try to develop the concept of the problem; instead they focus on walking through unrelated code so that viewers can follow the code written by the SETTER and the EDITORIAL EXPLANATION section.
These small issues ruined the whole value of buying premium. But on my side, I WILL improve my coding by learning the concept-building part of problems.

If you have a doubt about some specific part, you can always ask. That’s the point of this comment section, after all.

Sorry to hear that, but I have nothing to do with the video editorials.
If you have constructive criticism on how they can be improved, feel free to inform the team at help@codechef.com

My approach, using dynamic programming and inclusion-exclusion:

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int mod = 1e9+7;

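// Returns 3^K mod p (K = number of set bits): dp[i][0] counts the ways to give each
// set bit among digits[0..i] to one of X, Y, Z, zeros allowed.
// dp[i][1] is not used in the result.
int tripleSum(vector<int> &digits){
    int n = digits.size();
    vector<array<int, 2>> dp(n);  // standard replacement for the variable-length array dp[n][2]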
    if(digits[0]&1){
        dp[0][0] = 3;
        dp[0][1] = 1;
    } else {
        dp[0][0] = 1;
        dp[0][1] = 0; 
    }
    for(int i=1;i<n;i++){
        if(digits[i]&1){
            dp[i][0] = (1ll*3*dp[i-1][0])%mod;
            dp[i][1] = (1ll*3*(dp[i-1][1] + dp[i-1][0]))%mod;
        } else {
            dp[i][0] = dp[i-1][0];
            dp[i][1] = 0;
        }
    }
    return (dp[n-1][0])%mod;
}

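// Returns 2^K mod p (K = number of set bits): dp[i][0] counts the ways to give each
// set bit among digits[0..i] to one of two values, zeros allowed.
int doubleSum(vector<int> &digits){
    int n = digits.size();
    vector<array<int, 2>> dp(n);  // standard replacement for the variable-length array dp[n][2]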
    if(digits[0]&1){
        dp[0][0] = 2;
        dp[0][1] = 0;
    } else {
        dp[0][0] = 1;
        dp[0][1] = 0; 
    }
    for(int i=1;i<n;i++){
        if(digits[i]&1){
            dp[i][0] = (2*dp[i-1][0])%mod;
            dp[i][1] = dp[i-1][1];
        } else {
            dp[i][0] = dp[i-1][0];
            dp[i][1] = 0;
        }
    }
    return (dp[n-1][0])%mod;
}

signed main() {
	
	int t;
	cin>>t;
	while(t--){
	    int n;
	    cin>>n;
	    
	    vector<int> digits;
	    while(n>0){
	        digits.push_back(n&1);
	        n>>=1;
	    }
	    
	    reverse(digits.begin(), digits.end());
	    int ans = tripleSum(digits);
	    int k = (1ll*3*doubleSum(digits))%mod;
	    ans = (ans-k+mod)%mod;
	    ans = (ans+3)%mod;
	    cout<<ans<<endl;
	}
	
	return 0;
}