BRTXORS - Editorial

PROBLEM LINK:

Practice
Div1
Div2
Div3

Setter: Manuj Nanthan
Tester: Aryan Choudhary
Editorialist: Ajit Sharma Kasturi

DIFFICULTY:

SIMPLE

PREREQUISITES:

Bitwise-xor

PROBLEM:

Given a number N, we need to find the number of distinct values of i \oplus j over all pairs with 1 \leq i, j \leq N. The answer is printed modulo 10^9 + 7.

EXPLANATION:

  • If N = 1, the only possible xor value is 1 \oplus 1 = 0, so there is 1 distinct value.

  • If N = 2, the possible xor values are 1 \oplus 1 = 0, 1 \oplus 2 = 3 and 2 \oplus 2 = 0, giving 2 distinct values.

  • For the remaining cases, we consider N \geq 3.

  • Let x be the largest integer such that 2^x \leq N.

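  • Suppose 2^x \lt N. Since i, j \leq N \lt 2^{x+1}, neither i nor j has a bit above bit x set, so every xor value is less than 2^{x+1}. I claim that we can get every number from 0 to 2^{x+1} - 1, so the answer is 2^{x+1}. The pairs can be constructed as follows: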

  • Every num greater than 2^x with bit x set can be formed by the pair (2^x, 2^x \oplus num). Here 1 \leq 2^x \oplus num \leq 2^x - 1, so both elements are at most N. For example, let N = 12, so x = 3. Then 12 (1100 in binary) is formed by (2^3, 4), since 1000 \oplus 0100 = 1100.

  • The number 0 is formed by (1, 1), since 1 \oplus 1 = 0. The number 1 is formed by (2, 3), since 2 \oplus 3 = 1; this needs N \geq 3, which is why N = 2 is a special case. The remaining numbers 1 \lt num \leq 2^x are formed by (1, num \oplus 1). For example, 2 = 1 \oplus 3, 3 = 1 \oplus 2, and so on.

  • Since xor with 1 flips the lowest bit, num \oplus 1 is either num + 1 or num - 1. So if 1 \lt num \leq 2^x \lt N, then 1 \leq num \oplus 1 \leq N, and these pairs are always valid.

  • Now what happens if 2^x = N? All of the above constructions still work (for num \lt 2^x we have num \oplus 1 \leq num + 1 \leq 2^x = N) except for num = 2^x. This value cannot be obtained from any pair (i, j) with 1 \leq i, j \leq 2^x: one of i, j must have bit x set, and the only such number is 2^x, so say i = 2^x. Then i \oplus j = 2^x forces j = 0, which is not allowed since j \geq 1. Hence, in this case, every number from 0 to 2^{x+1} - 1 except 2^x is obtainable, and the answer is 2^{x+1} - 1 = 2N - 1.
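The case analysis above can be checked mechanically. The following Python sketch is only an illustration (the function names `witness` and `check` are hypothetical, not from the editorial): `witness` returns the pair (i, j) that the corresponding bullet prescribes for a given num, assuming N ≥ 3, and returns None for the single unreachable value 2^x when N = 2^x. `check` asserts that every returned pair lies in range and xors to num.

```python
def witness(num: int, n: int):
    """Return a pair (i, j) with 1 <= i, j <= n and i ^ j == num, following the
    case analysis above (assumes n >= 3). Returns None for num == 2**x when
    n == 2**x, the one value that cannot be reached."""
    x = n.bit_length() - 1       # largest x with 2**x <= n
    top = 1 << x                 # 2**x
    if num == 0:
        return (1, 1)
    if num == 1:
        return (2, 3)            # valid because n >= 3
    if num > top:                # bit x set and num > 2**x
        return (top, top ^ num)
    if num == top == n:          # n is a power of two: 2**x is unreachable
        return None
    return (1, num ^ 1)          # 1 < num <= 2**x


def check(n: int) -> None:
    """Assert that every constructed pair is in range and xors to num."""
    for num in range(1 << n.bit_length()):   # 0 .. 2**(x+1) - 1
        pair = witness(num, n)
        if pair is None:
            continue
        i, j = pair
        assert 1 <= i <= n and 1 <= j <= n and i ^ j == num, (n, num, pair)


for n in range(3, 200):
    check(n)
```

For the editorial's example, `witness(12, 12)` gives (8, 4), and the loop finishes silently when every construction is valid.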

TIME COMPLEXITY:

O(\log N) for each testcase.
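Putting the cases together, the answer is N for N ≤ 2, 2N - 1 when N ≥ 3 is a power of 2, and 2^{x+1} otherwise. The O(log N) in the solutions comes from locating x by repeated doubling. As a sketch, Python's built-in `int.bit_length()` returns x + 1 directly. The function name below is an assumption for illustration:

```python
MOD = 10**9 + 7


def count_distinct_xors(n: int) -> int:
    """Number of distinct i ^ j with 1 <= i, j <= n, modulo 10**9 + 7."""
    if n <= 2:
        return n                  # N = 1 -> {0}, N = 2 -> {0, 3}
    x = n.bit_length() - 1        # largest x with 2**x <= n
    if n == 1 << x:               # n is a power of two
        return (2 * n - 1) % MOD
    return pow(2, x + 1, MOD)     # every value in [0, 2**(x+1)) is reachable


print(count_distinct_xors(12))    # 16
print(count_distinct_xors(8))     # 15
```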

SOLUTION:

Editorialist's solution

#include <bits/stdc++.h>
#define ll long long int
using namespace std;

int main() {
	int tests;
	cin >> tests;
	while (tests--) {
	    ll n;
	    cin >> n;
	    
	    ll ans = 1;
	    int MOD = 1e9 + 7;
	    
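	    // Special case: N = 2 gives {0, 3}, i.e. 2 values
	    if (n == 2) {
	        cout << 2 << endl;
	        continue;
	    }
	    
	    // Smallest power of 2 that is >= n
	    while (ans < n) {
	        ans *= 2;
	    }
	    
	    // n itself is a power of 2 (including n = 1): answer is 2n - 1
	    if (ans == n) {
	        ans *= 2;
	        ans--;
	    }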
	    
	    cout << ans % MOD << endl;
	}
	return 0;
}


Setter's solution

import sys
input=sys.stdin.readline
def inp():
    return int(input())
t=inp()
p=10**9 + 7
while(t):
    t-=1
    n=inp()
    if(n<=2):
        # N = 1 gives 1 value, N = 2 gives 2 values
        print(n)
    else:
        x=bin(n)[2:]                  # binary digits of n; len(x) = (index of highest bit) + 1
        res=pow(2,len(x),p)           # 2^(x+1) in the editorial's notation
        if(x.count('1')==1):          # n is a power of 2, so 2^x itself is unreachable
            res-=1
        print(res%p)


Tester's solution
def main():
    mod=10**9+7
    for _ in range(int(input())):
        n=int(input())
        if n<=2:
            print(n)
        elif n&(n-1):
            print(pow(2,len(bin(n))-2,mod))
        else:
            print((2*n-1)%mod)

main()

Please comment below if you have any questions, alternate solutions, or suggestions. :)


This is my solution which I believe runs in O(1) for each testcase. Tell me if the complexity is wrong.

#include <bits/stdc++.h>
using namespace std;

#define ll long long
const int MOD = 1e9+7;

// To find the highest power of 2 less than or equal to the given number.
ll highestPow(ll n) {
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;
    n |= n >> 32;
    
    return n ^ (n >> 1);
}

void solve() {
    ll n;
    cin >> n;

    if(n == 2) {
        cout << 2 << '\n';
        return;
    }

    if(n == highestPow(n)) {
        n <<= 1;
        cout << (n - 1) % MOD   << '\n';
    }
    else {
        n = highestPow(n);
        n <<= 1;
        cout << n % MOD << '\n';
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int tt;
    cin >> tt;
    while(tt--) {
        solve();
    }
    return 0;
}
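The `highestPow` helper uses a bit-smearing trick: OR-ing n with right shifts of 1, 2, 4, ..., 32 copies the highest set bit into every lower position, so n becomes 2^{x+1} - 1, and `n ^ (n >> 1)` then keeps only the top bit. Because the number of shift steps is fixed for 64-bit inputs, the O(1)-per-test-case claim holds. A Python transcription for experimenting (the name `highest_pow` is ours, not the commenter's), valid for 1 ≤ n < 2^64:

```python
def highest_pow(n: int) -> int:
    """Largest power of two <= n, for 1 <= n < 2**64 (same steps as highestPow)."""
    for shift in (1, 2, 4, 8, 16, 32):
        n |= n >> shift           # smear the top set bit into all lower bits
    return n ^ (n >> 1)           # n is now 2**(x+1) - 1; keep only bit x


print(highest_pow(12))            # 8
```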

Also, thanks to the editorialist for a clean solution. I haven't seen one this clean in a while.

Hi bro, thanks for the solution!
Can you please tell me where I am going wrong?
void solve(){

ll n;
cin >> n;
set<int> st;

if (n == 1){
    cout << 1 << '\n'; return;
}
if (n == 2){
    cout << 2 << '\n'; return;
}

vector<int> x;
while (n != 0){
    x.pb(1);
    n = n/2;
}
ll ans = 0;
for (ll i = 0 ; i< (ll) x.size() ; i++){
    ans = (ans + power(2ll , i))%MOD;
}
cout << (ans+1)%MOD << '\n';

}

I tried a few test cases with your code and they give the same answers, but mine is still getting WA.

I did the same thing in my solution (in C). I used induction to prove that for N \geq 3, the answer is the smallest power of 2 greater than N, except when N is itself a power of 2, in which case it is 2N - 1. My solution, which I believe is also O(\log N) (it uses the same logic as the editorial above), is here for your reference:

#include <stdio.h>
#define PRIME 1000000007


int main(void) {
	// your code goes here
	int t, n, num, val;
	scanf("%d", &t);
	while(t--){
	    val=1;
	   scanf("%d", &n);
	   if(n==1) num=1;
	   else if(n==2) num=2;
	   else{
	       while(n>val){
	           val*=2;
	       }
	       if(val==n){
	           num=2*val-1;
	       }
	       else num=val;
	   }
	   printf("%d\n", num%(PRIME));  	   
	}
	return 0;
}

I got a time limit exceeded (TLE) verdict on two tasks, and I don't know why.

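Hi, you used int, but the problem statement says 1 <= N <= 10^12, which does not fit in the int data type; that is why you are getting TLE. For example, once n holds a value above 2^30, doubling val past 2^30 overflows (in practice it wraps to a negative number and then to 0), so the loop while(n>val) never ends. If you change int to long long and the format specifier to %lld, you won't get TLE.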
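To see why overflow can turn into TLE rather than a wrong answer, here is a small Python model of the loop `while(n>val) val*=2;` with val as a 32-bit int. Signed overflow is undefined behaviour in C, so the wraparound modelled by the hypothetical helper `wrap_int32` is only the common behaviour in practice, not a guarantee. The step cap in `doubling_loop` (also hypothetical) stands in for the time limit:

```python
def wrap_int32(v: int) -> int:
    """Reduce v to the range of a 32-bit two's-complement int (typical wraparound)."""
    return (v + 2**31) % 2**32 - 2**31


def doubling_loop(n: int, max_steps: int = 100):
    """Simulate `while (n > val) val *= 2;` with a wrapping 32-bit val.
    Returns the final val, or None if the loop is still running after max_steps."""
    val = 1
    for _ in range(max_steps):
        if not n > val:
            return val
        val = wrap_int32(val * 2)
    return None


print(doubling_loop(1000))          # 1024
print(doubling_loop(2**30 + 1))     # None
```

For n = 2^30 + 1, val goes 2^30 → -2^31 → 0 and then stays 0, so n > val holds forever.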


Can someone please tell me why the 3 solutions below are getting WA?

Solution 1
Solution 2
Solution 3

I did a little bit of “cheating” here.
First I wrote a brute-force solution, computed the values for the first 10 numbers, searched for the sequence in OEIS, and got this

Then I tried to write a solution based on the OEIS formula above, but got WA.
Can someone please tell me whether my approach is wrong, whether I am calculating the mod incorrectly, or whether I have strayed off the correct path completely?
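For reference, a brute force of the kind described simply enumerates every pair (the commenter's own code was not shared, so this is a sketch, and `brute_force` is a hypothetical name). At O(N^2) it is only usable for small N, but it is handy for generating terms or checking a formula:

```python
def brute_force(n: int) -> int:
    """Count distinct i ^ j over all pairs 1 <= i, j <= n (O(N^2), small N only)."""
    return len({i ^ j for i in range(1, n + 1) for j in range(1, n + 1)})


print([brute_force(n) for n in range(1, 11)])
# [1, 2, 4, 7, 8, 8, 8, 15, 16, 16]
```

These values agree with the editorial: 2N - 1 at N = 4 and N = 8, and the next power of 2 above N for the other N ≥ 3.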

I believe you are not handling the case where N is a power of 2: your code outputs N itself for every N that is a power of 2.
For example, for N = 4 your output is 4, which is wrong; it should be 7.
The other cases are fine.


Can you also share the implementation of the power() function you used in your submission?

ll power(ll base , ll n){
    ll ans = 1;
    while (n != 0){
        if (n & 1){
            ans = (ans * base) % MOD;
            n = n - 1;
        }
        else{
            base = (base * base) % MOD;
            n = n >> 1;
        }
    }
    return ans;
}
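Reading the garbled lines as `ans = (ans*base)%MOD` and `base = (base*base)%MOD`, the function above is standard iterative binary exponentiation: an odd exponent moves one factor of base into the result, and an even exponent squares base and halves n. A Python transcription, assuming MOD = 10^9 + 7 as in the problem, can be cross-checked against Python's built-in three-argument `pow`:

```python
MOD = 10**9 + 7   # assumed, as in the problem statement


def power(base: int, n: int) -> int:
    """base**n % MOD by iterative binary exponentiation (same steps as the C++ power())."""
    ans = 1
    while n != 0:
        if n & 1:                 # odd exponent: move one factor of base into ans
            ans = ans * base % MOD
            n -= 1
        else:                     # even exponent: square base and halve n
            base = base * base % MOD
            n >>= 1
    return ans


print(power(2, 40) == pow(2, 40, MOD))   # True
```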
Here it is, bro. Kindly tell me if there is a mistake.

Well, thanks, it worked for me.

So whenever N is a power of 2, your code outputs N itself,
e.g. N = 4 gives 4 and N = 8 gives 8,
but it should output (next power of 2) - 1.
E.g. for N = 4 the output should be nextPowerOf2 - 1 = 8 - 1 = 7 (since 8 is the next power of 2 after 4); similarly, for
N = 8 it should be nextPowerOf2 - 1 = 16 - 1 = 15 (since 16 is the next power of 2 after 8).
I hope this helps. Let me know if something is unclear.

My solution

I am still getting WA.

That's because you did not handle the edge case: for N = 2 the answer is not nextPowerOf2 - 1 = 3, it is 2 itself. You need to handle that case separately.


Yes, as mentioned by @ankitksh81, you are not handling the case when N is a power of 2, correctly.

Thank you very much it worked

Can you please see my code and tell me what is wrong here? Python code inside.

https://www.codechef.com/viewsolution/55050431

We need to print the answer modulo 10^9 + 7. When I added this to your code, it got AC.

Solution link:
https://www.codechef.com/viewsolution/55105444


Ohhh, I missed that part; I need to learn to read the description better :) . Thank you for your reply.

I tried this problem several times but couldn't figure out where I am going wrong. Here is the code:
#include <iostream>
#include <cmath>
typedef long long ll;
using namespace std;
int main(){
    ll T, N, i;
    ll mod = pow(10, 9) + 7;
    cin >> T;
    while(T--){
        cin >> N;
        if(N == 2){
            cout << 2 << endl;
        }
        else if(N & (N-1)==0){
            cout << (2*N-1) % mod << endl;
        }
        else{
            ll bit = log2(N);
            ll power = pow(2, bit+1);
            cout << power % mod << endl;
        }
    }
}
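One thing that stands out in the code above: in C++, == has higher precedence than &, so `N & (N-1)==0` is parsed as `N & ((N-1)==0)`. For N ≥ 3 the inner comparison is false (0), so the whole condition is 0 and powers of 2 fall through to the else branch; for example, N = 4 prints 8 instead of 7. Writing `(N & (N-1)) == 0` fixes the precedence, and using integer operations instead of `log2`/`pow` on doubles also sidesteps any floating-point rounding concerns for large N. A Python sketch of the intended logic (the names `cpp_condition` and `answer` are hypothetical):

```python
MOD = 10**9 + 7


def cpp_condition(n: int) -> int:
    """Models how C++ parses N & (N-1)==0, namely N & ((N-1) == 0)."""
    return n & int((n - 1) == 0)


def answer(n: int) -> int:
    """Intended logic with integer-only operations (no log2/pow on doubles)."""
    if n == 2:
        return 2
    if (n & (n - 1)) == 0:        # parenthesised: n is a power of two
        return (2 * n - 1) % MOD
    return pow(2, n.bit_length(), MOD)   # 2**(floor(log2(n)) + 1)


print(cpp_condition(4), answer(4))   # 0 7
```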