 BRTXORS - Editorial

Setter: Manuj Nanthan
Tester: Aryan Choudhary
Editorialist: Ajit Sharma Kasturi

DIFFICULTY:

SIMPLE

PREREQUISITES:

Bitwise-xor

PROBLEM:

Given a number N (1 \leq N \leq 10^{12}), we need to find the number of distinct values of i \oplus j over all pairs 1 \leq i, j \leq N, printed modulo 10^9 + 7.
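To make the definition concrete, here is a minimal brute-force sketch in Python. `brute_force` is our own helper name, not part of any official solution. It enumerates all N^2 pairs, so it only works for small N, but it is handy for checking the closed form derived in the explanation.

```python
def brute_force(n: int) -> int:
    """Count distinct values of i ^ j over 1 <= i, j <= n in O(n^2) time."""
    return len({i ^ j for i in range(1, n + 1) for j in range(1, n + 1)})


if __name__ == "__main__":
    # prints [1, 2, 4, 7, 8, 8, 8, 15, 16, 16]
    print([brute_force(n) for n in range(1, 11)])
```

Note the jumps to 2^{k+1} - 1 at N = 4 and N = 8. Those are exactly the power-of-two cases the explanation singles out.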

EXPLANATION:

• If N = 1, the only possible xor value is 1 \oplus 1 = 0, so there is 1 distinct value.

• If N = 2, the possible xor values are 1 \oplus 1 = 2 \oplus 2 = 0 and 1 \oplus 2 = 3, so there are 2 distinct values.

• For the remaining cases, we consider N \geq 3.

• Let x be the largest integer such that 2^x \leq N.

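• Since 1 \leq i, j \leq N \lt 2^{x+1}, every value i \oplus j lies in the range [0, 2^{x+1} - 1], so there are at most 2^{x+1} distinct values. Suppose 2^x \lt N. I claim that every number from 0 to 2^{x+1} - 1 can be obtained, so the answer is exactly 2^{x+1}. This can be done as follows: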

• Every num that has bit x set and is greater than 2^x can be formed by the pair (2^x, 2^x \oplus num): clearing bit x gives 1 \leq 2^x \oplus num \lt 2^x \leq N. For example, let N = 12, so x = 3. The number 12 (1100 in binary) can be formed by (2^3 (1000), 4 (0100)).

• 0 can be formed by (1, 1), since 1 \oplus 1 = 0, and 1 can be formed by (2, 3), since 2 \oplus 3 = 1 (this needs N \geq 3). Each remaining number 1 \lt num \leq 2^x can be formed by (1, num \oplus 1). For example, 2 = 1 \oplus 3, 3 = 1 \oplus 2, and so on.

• Since xor with 1 only flips the lowest bit, num \oplus 1 is either num + 1 or num - 1. So if 1 \lt num \leq 2^x \lt N, then 1 \leq num \oplus 1 \leq N, and these pairs are always valid.

• Now, what happens if 2^x = N? All of the above cases still hold except num = 2^x, because the pair (1, 2^x + 1) is no longer valid. In fact, 2^x cannot be obtained from any pair (i, j) with 1 \leq i, j \leq 2^x. The only number in this range with bit x set is 2^x itself, so exactly one of i, j (say i) must equal 2^x. Then i \oplus j = 2^x forces j = 0, which is not allowed since j \geq 1. Hence every number from 0 to 2^{x+1} - 1 except 2^x can be obtained, and the answer is 2^{x+1} - 1 = 2N - 1.
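The case analysis above can be checked mechanically. Below is a minimal Python sketch. The names `witness` and `distinct_xor_count` are ours, not from the setter or tester. `witness` returns the editorial's pair for a target value (or None when the analysis says the value is unreachable). `distinct_xor_count` evaluates the resulting closed form: 1 for N = 1, 2 for N = 2, 2N - 1 when N is a power of two, and 2^{x+1} otherwise, reduced modulo 10^9 + 7.

```python
MOD = 10**9 + 7


def witness(num: int, n: int):
    """Return a pair (i, j) with 1 <= i, j <= n and i ^ j == num, following the
    case analysis above, or None if the analysis gives no pair (assumes n >= 3)."""
    x = n.bit_length() - 1           # largest x with 2**x <= n
    p = 1 << x
    if num >= 2 * p:
        return None                  # i, j < 2**(x+1) forces i ^ j < 2**(x+1)
    if num > p:
        return (p, p ^ num)          # bit x is set: pair it with 2**x
    if num == 0:
        return (1, 1)
    if num == 1:
        return (2, 3)                # needs n >= 3
    if num == p and p == n:
        return None                  # 2**x is unreachable when N == 2**x
    return (1, num ^ 1)              # num ^ 1 is num + 1 or num - 1


def distinct_xor_count(n: int) -> int:
    """Number of distinct i ^ j over 1 <= i, j <= n, modulo 10**9 + 7."""
    if n <= 2:
        return n                     # special cases N = 1 and N = 2
    x = n.bit_length() - 1
    if n == 1 << x:                  # N is a power of two: 2N - 1
        return (2 * n - 1) % MOD
    return pow(2, x + 1, MOD)        # otherwise every value below 2**(x+1)
```

For example, `witness(12, 12)` returns `(8, 4)`, matching the N = 12 example above, and `distinct_xor_count(12)` returns 16. Using `int.bit_length()` avoids both the doubling loop and floating-point `log2`.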

TIME COMPLEXITY:

O(\log N) for each testcase.

SOLUTION:

Editorialist's solution

```cpp
#include <bits/stdc++.h>
#define ll long long int
using namespace std;

int main() {
    int tests;
    cin >> tests;
    while (tests--) {
        ll n;
        cin >> n;

        ll ans = 1;
        int MOD = 1e9 + 7;

        // Special case
        if (n == 2) {
            cout << 2 << endl;
            continue;
        }

        // Smallest power of two that is >= n
        while (ans < n) {
            ans *= 2;
        }

        // n is itself a power of two: the answer is 2n - 1
        if (ans == n) {
            ans *= 2;
            ans--;
        }

        cout << ans % MOD << endl;
    }
    return 0;
}
```

Setter's solution

```python
import sys

def inp():
    return int(input())
def st():
    return input().rstrip('\n')
def lis():
    return list(map(int, input().split()))
def ma():
    return map(int, input().split())

t = inp()
p = 10**9 + 7
while t:
    t -= 1
    n = inp()
    if n <= 2:
        print(n)
    else:
        x = bin(n)[2:]
        res = pow(2, len(x), p)
        if x.count('1') == 1:    # n is a power of two
            res -= 1
        print(res % p)
```

Tester's solution
```python
def main():
    mod = 10**9 + 7
    for _ in range(int(input())):
        n = int(input())
        if n <= 2:
            print(n)
        elif n & (n - 1):
            print(pow(2, len(bin(n)) - 2, mod))
        else:
            print((2 * n - 1) % mod)

main()
```

Please comment below if you have any questions, alternate solutions, or suggestions.

This is my solution which I believe runs in O(1) for each testcase. Tell me if the complexity is wrong.

```cpp
#include <bits/stdc++.h>
using namespace std;

#define ll long long
const int MOD = 1e9+7;

// To find the highest power of 2 less than or equal to the given number.
ll highestPow(ll n) {
    n |= n >> 1;
    n |= n >> 2;
    n |= n >> 4;
    n |= n >> 8;
    n |= n >> 16;
    n |= n >> 32;

    return n ^ (n >> 1);
}

void solve() {
    ll n;
    cin >> n;

    if(n == 2) {
        cout << 2 << '\n';
        return;
    }

    if(n == highestPow(n)) {
        n <<= 1;
        cout << (n - 1) % MOD << '\n';
    }
    else {
        n = highestPow(n);
        n <<= 1;
        cout << n % MOD << '\n';
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int tt;
    cin >> tt;
    while(tt--) {
        solve();
    }
    return 0;
}
```
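The `highestPow` function above uses a standard bit-smearing trick. Here is a Python transcription (our own, named `highest_pow`) that spells out why it works. It assumes 1 <= n < 2^64, so that the six shift passes reach every bit below the most significant one, as in the 64-bit C++ version.

```python
def highest_pow(n: int) -> int:
    """Largest power of two <= n, assuming 1 <= n < 2**64 (six passes cover 64 bits)."""
    for shift in (1, 2, 4, 8, 16, 32):
        # Copy set bits downward; afterwards the MSB and up to
        # 2 * shift - 1 bits directly below it are all 1.
        n |= n >> shift
    # n is now 0b11...1 from the MSB down; n ^ (n >> 1) keeps only the MSB.
    return n ^ (n >> 1)
```

For instance, `highest_pow(12)` is 8 and `highest_pow(10**12)` is 2^39. In Python, the same value is also available as `1 << (n.bit_length() - 1)`.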

Also, thanks to the editorialist for a clean solution. I haven't seen one this clean in a while.

Hi bro, thanks for the solution!
Can you please tell me where I am going wrong?
```cpp
void solve(){

    ll n;
    cin >> n;
    set<int> st;

    if (n == 1){
        cout << 1 << '\n'; return;
    }
    if (n == 2){
        cout << 2 << '\n'; return;
    }

    vector<int> x;
    while (n != 0){
        x.pb(1);
        n = n/2;
    }
    ll ans = 0;
    for (ll i = 0 ; i< (ll) x.size() ; i++){
        ans = (ans + power(2ll , i))%MOD;
    }
    cout << (ans+1)%MOD << '\n';

}
```

I tried a few test cases with your code and they give the same answers as mine, but mine still gets WA.

I did the same thing in my solution (in C). I used induction to prove that for N \geq 3 the answer is the smallest power of 2 greater than N, except when N is itself a power of 2, in which case it is 2N - 1. My solution, which I believe is also O(\log N) since it uses the same logic as the editorial above, is below for your reference:

```c
#include <stdio.h>
#define PRIME 1000000007

int main(void) {
    int t, n, num, val;
    scanf("%d", &t);
    while(t--){
        val=1;
        scanf("%d", &n);
        if(n==1) num=1;
        else if(n==2) num=2;
        else{
            while(n>val){
                val*=2;
            }
            if(val==n){
                num=2*val-1;
            }
            else num=val;
        }
        printf("%d\n", num%(PRIME));
    }
    return 0;
}
```

I got a time limit exceeded (TLE) verdict on two tasks and I don't know why.

Hi, you used int, but the problem statement says 1 <= N <= 10^12, which does not fit in the int data type. The value read into n is wrong, and val *= 2 can overflow as well, so the while loop may never terminate; that is why you are getting TLE. If you change int to long long and the format specifiers to %lld, you won't get TLE.
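To see how the doubling loop can hang, the following Python sketch models C's int as a 32-bit two's-complement integer. This is an assumption made for illustration: signed overflow is undefined behavior in C, so a compiler is not obliged to wrap around, and `to_int32` and `doubling_loop_int32` are our own hypothetical helpers. Under wraparound, once val reaches 2^30 while n is still larger, val *= 2 turns negative and then 0, and the loop never ends.

```python
def to_int32(v: int) -> int:
    """Wrap v to a signed 32-bit value (two's-complement truncation)."""
    v &= 0xFFFFFFFF
    return v - (1 << 32) if v >= (1 << 31) else v


def doubling_loop_int32(n: int, max_iters: int = 100):
    """Mimic `while (n > val) val *= 2;` with 32-bit wraparound.
    Returns the final val, or None if the loop is still running after max_iters."""
    n = to_int32(n)
    val = 1
    for _ in range(max_iters):
        if n <= val:
            return val
        val = to_int32(val * 2)     # 2**30 * 2 wraps to -2**31, then to 0
    return None
```

Here `doubling_loop_int32(5)` returns 8, while `doubling_loop_int32(2**30 + 1)` returns None: val gets stuck at 0, which is always less than n. With long long the same loop stops normally for every N up to 10^12.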


Can someone please tell me why my 3 solutions below are getting WA?

I did a little bit of “cheating” here.
First, I wrote a brute-force solution, got the values for the first 10 numbers, searched for the series in OEIS, and found this.

Then I tried to write a solution based on the above OEIS formula, but it got WA.
Can someone please tell me whether my approach is wrong, whether I am calculating the mod incorrectly, or whether I have strayed off the correct path completely?

I believe you are not handling the case where N is a power of 2: your code outputs N itself for every N that is a power of 2.
For example, for N = 4 your output is 4, which is wrong; it should be 7.
The other cases are fine.


Can you also share the implementation of the power() function you used in your submission?

```cpp
ll power(ll base , ll n){
    ll ans = 1;
    while (n!= 0){
        if (n&1){
            ans = (ans*base)%MOD;
            n = n-1;
        }
        else{
            base = (base*base)%MOD;
            n = n >> 1;
        }
    }
    return ans;
}
```
Here it is, bro. Kindly tell me if there is a mistake.
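For reference, the `power()` function above is ordinary iterative binary exponentiation: an odd exponent moves one factor into the result, and an even exponent squares the base and halves the exponent. Below is a Python transcription (our own, mirroring the posted loop) that can be checked against Python's built-in three-argument `pow`. The extra `base %= mod` is our addition, so that the sketch also accepts bases of MOD or more.

```python
MOD = 10**9 + 7


def power(base: int, n: int, mod: int = MOD) -> int:
    """Iterative binary exponentiation: base**n % mod in O(log n) steps."""
    ans = 1
    base %= mod                      # keep intermediate products small
    while n != 0:
        if n & 1:                    # odd exponent: move one factor of base into ans
            ans = ans * base % mod
            n -= 1
        else:                        # even exponent: square base and halve n
            base = base * base % mod
            n >>= 1
    return ans
```

Since both factors are below MOD before each multiplication, every product stays under 10^18, which is why the C++ version is safe with `ll`.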

Well, thanks, it worked for me.

So whenever N is a power of 2, your code outputs N itself,
e.g. N = 4 gives 4 and N = 8 gives 8,
but it should output (next power of 2) - 1.
E.g. for N = 4 the output should be nextPowerOf2 - 1 = 8 - 1 = 7 (8 is the next power of 2 after 4); similarly,
for N = 8 it should be nextPowerOf2 - 1 = 16 - 1 = 15 (16 is the next power of 2 after 8).
I hope this helps. Let me know if anything is unclear.

My solution

I am still getting WA.

That's because you did not handle the edge case N = 2: the answer there is not nextPowerOf2 - 1 = 3, it is 2 itself. You need to handle that case separately.


Yes, as mentioned by @ankitksh81, you are not handling the case where N is a power of 2 correctly.

Thank you very much, it worked.

Can you please look at my code and tell me what is wrong? It's Python code.

https://www.codechef.com/viewsolution/55050431

We need to print the answer modulo 10^9 + 7. When I added this to your code, it got AC.

https://www.codechef.com/viewsolution/55105444


Ohhh, I missed that part; I need to learn to read the description more carefully. Thank you for your reply.

I tried this problem several times but couldn’t figure out where I am getting it wrong. Here is the code:
```cpp
#include <iostream>
#include <cmath>
typedef long long ll;
using namespace std;
int main(){
    ll T, N, i;
    ll mod=pow(10, 9)+7;
    cin >> T;
    while(T--){
        cin>> N;
        if(N==2){
            cout<<2<<endl;
        }
        else if(N & (N-1)==0){
            cout<<(2*N-1)%mod<<endl;
        }
        else{
            ll bit=log2(N);
            ll power=pow(2, bit+1);
            cout<<power% mod<<endl;
        }
    }
}
```
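A likely cause, offered as a suggestion: in C++, == has higher precedence than &, so N & (N-1)==0 is evaluated as N & ((N-1)==0). That is nonzero only for N = 1. A power of two such as N = 4 therefore falls through to the else branch and prints 8 instead of 7. The log2/pow calls also go through double, which is best avoided for N up to 10^12. The Python sketch below uses our own helper names (`cpp_parse`, `intended`, `answer`). It contrasts the two readings of the condition and recomputes the answer with integer operations only. Note that in Python & binds tighter than ==, so the same unparenthesized expression would actually behave as intended there.

```python
def cpp_parse(n: int) -> bool:
    """How C++ reads `N & (N-1)==0`: == binds tighter, so it is N & ((N-1) == 0)."""
    return bool(n & int((n - 1) == 0))


def intended(n: int) -> bool:
    """The intended power-of-two test: (N & (N-1)) == 0."""
    return (n & (n - 1)) == 0


def answer(n: int, mod: int = 10**9 + 7) -> int:
    """The posted logic with the power-of-two test parenthesized and no floating point."""
    if n == 2:
        return 2
    if intended(n):
        return (2 * n - 1) % mod
    return pow(2, n.bit_length(), mod)   # 2**(floor(log2 N) + 1) for N not a power of two
```

In C++ the fix is to write `(N & (N-1)) == 0`. Over N = 1..64, `cpp_parse` is true only for N = 1, while `intended` is true for 1, 2, 4, ..., 64.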