OPMXOR - Editorial

PROBLEM LINK:

Practice

Author: Shril

Tester: Janvi

Editorialist: Shril

DIFFICULTY:

Easy

PREREQUISITES:

Pattern identification and knowledge of XOR

PROBLEM:

Given N and K, construct an array A of N+1 elements as follows:

  • A[1] = K

  • A[i] = A[i-1] + 2^{i-1} for 2 \le i \le N+1

Find the XOR of all numbers of this array and print it modulo 10^9+7.
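To have a reference point for checking the formulas, here is a minimal brute-force sketch that builds the array straight from the recurrence and XORs every element. The helper names `build_array` and `brute_xor` are hypothetical. Because A[N+1] has roughly N+1 bits, this is only practical for small N; it is a checking tool, not the intended solution.

```python
from functools import reduce
from operator import xor


def build_array(n, k):
    """Build A[1..N+1] with A[1] = K and A[i] = A[i-1] + 2^(i-1)."""
    arr = [k]
    for i in range(2, n + 2):  # i runs from 2 to N+1 inclusive
        arr.append(arr[-1] + (1 << (i - 1)))
    return arr


def brute_xor(n, k):
    """Exact XOR of every element of the array (no modulo applied)."""
    return reduce(xor, build_array(n, k), 0)


print(build_array(3, 2))  # [2, 4, 8, 16]
print(brute_xor(3, 2))    # 30
```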

QUICK EXPLANATION:

For K=2, the XOR of all elements is 2^{N+2}-2. For K=3, the XOR is 2^{N+2}-2 when N is odd and 2^{N+2}-1 when N is even.

EXPLANATION:

A pattern emerges if we compute the XOR for small values of N; mathematical induction then generalizes it into a formula.
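The small cases can be reproduced with a short script that prints the running (prefix) XOR of the array. This is a sketch; `prefix_xors` is a hypothetical helper, and `count` is the number of elements, so the prefix of length m corresponds to N = m-1.

```python
def prefix_xors(k, count):
    """Return the XOR of the first 1, 2, ..., count elements of the array."""
    out = []
    value = k    # current element A[i], starting from A[1] = K
    running = 0  # XOR of A[1..i]
    for i in range(1, count + 1):
        if i > 1:
            value += 1 << (i - 1)  # A[i] = A[i-1] + 2^(i-1)
        running ^= value
        out.append(running)
    return out


for k in (2, 3):
    print(k, prefix_xors(k, 9))
# 2 [2, 6, 14, 30, 62, 126, 254, 510, 1022]
# 3 [3, 6, 15, 30, 63, 126, 255, 510, 1023]
```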

For K=2 and N=8, the array is:

  • [2, 4, 8, 16, 32, 64, 128, 256, 512]

The prefix XORs (the XOR of the first i elements, for i = 1, 2, 3, \cdots) are:

  • [2, 6, 14, 30, 62, 126, 254, 510, 1022]

Every value is exactly 2 less than a power of 2: the XOR of the first m elements is 2^{m+1}-2.

Notice also that each prefix XOR equals the corresponding prefix sum. Every element is a distinct power of 2 (A[i] = 2^i), so no two elements share a set bit and XOR behaves exactly like addition. We can therefore apply the geometric progression sum formula to the N+1 elements, which gives 2^{N+2}-2.
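Written out as a derivation (using only the recurrence from the problem statement), the K=2 case is:

```latex
A[1] = 2 = 2^{1}, \qquad
A[i] = A[i-1] + 2^{i-1} = 2^{i-1} + 2^{i-1} = 2^{i} \quad (2 \le i \le N+1)

\bigoplus_{i=1}^{N+1} A[i] \;=\; \sum_{i=1}^{N+1} 2^{i} \;=\; 2\,(2^{N+1} - 1) \;=\; 2^{N+2} - 2
```

The first equality holds only because the terms have pairwise disjoint bits; in general XOR and addition differ.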

For K=3 and N=8, the array is:

  • [3, 5, 9, 17, 33, 65, 129, 257, 513]

The prefix XORs are:

  • [3, 6, 15, 30, 63, 126, 255, 510, 1023]

At even (0-based) positions [0, 2, 4, \cdots] of this list, the XOR is 1 less than a power of 2.

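In binary, such a value has all bits set. The reason is that every element here is A[i] = 2^i + 1: the power-of-2 parts are the same as in the K=2 case, and the extra 1 (bit 0) appears once per element. After an odd number of elements, bit 0 survives and fills in the only zero bit of 2^{m+1}-2, giving 2^{m+1}-1. At odd positions [1, 3, 5, \cdots] (an even number of elements), the 1s cancel, so the XOR is the same as in the K=2 case, i.e. 2 less than a power of 2.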
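The same argument as a derivation for the whole array of N+1 elements (for i \ge 1, bit 0 of 2^i is clear, so adding 1 is the same as XOR-ing with 1):

```latex
A[i] = 2^{i} + 1 = 2^{i} \oplus 1 \qquad (1 \le i \le N+1)

\bigoplus_{i=1}^{N+1} A[i]
  = \Bigl(\bigoplus_{i=1}^{N+1} 2^{i}\Bigr) \oplus
    \underbrace{1 \oplus 1 \oplus \cdots \oplus 1}_{N+1 \text{ times}}
  = \begin{cases}
      2^{N+2} - 2, & N \text{ odd (the } N+1 \text{ ones cancel)} \\
      (2^{N+2} - 2) \oplus 1 = 2^{N+2} - 1, & N \text{ even (one 1 survives)}
    \end{cases}
```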

Since the full array has N+1 elements, for K=3:

when N is odd, the answer is 2^{N+2}-2

when N is even, the answer is 2^{N+2}-1
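A quick way to gain confidence in these closed forms is to compare them with a direct computation for small N. This is a sketch: `closed_form` and `brute_xor` are hypothetical names, both work with exact integers (no modulo), and `closed_form` only covers the two values of K treated in this editorial.

```python
def closed_form(n, k):
    """Exact XOR predicted by the editorial's formulas (K = 2 or K = 3 only)."""
    if k == 2:
        return (1 << (n + 2)) - 2
    if k == 3:
        return (1 << (n + 2)) - (2 if n % 2 == 1 else 1)
    raise ValueError("the formulas only cover K = 2 and K = 3")


def brute_xor(n, k):
    """XOR of A[1..N+1] computed directly from the recurrence."""
    value, total = k, k
    for i in range(2, n + 2):
        value += 1 << (i - 1)  # A[i] = A[i-1] + 2^(i-1)
        total ^= value
    return total


# Cross-check the formulas against the direct computation for small N.
for k in (2, 3):
    for n in range(1, 30):
        assert closed_form(n, k) == brute_xor(n, k), (n, k)
print("formulas match for N = 1..29")
```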

So the answer is a closed-form expression: apart from computing the power of 2, each test case needs only O(1) work.

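To print the answer modulo 10^9+7, compute 2^{N+2} modulo 10^9+7 with fast modular exponentiation, reducing after every multiplication so that intermediate values never grow too large. Then subtract 1 or 2 and reduce modulo 10^9+7 once more so the result stays non-negative. This takes O(\log N) time per test case.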
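For reference, here is a minimal square-and-multiply sketch of fast modular exponentiation. Python's built-in three-argument `pow(base, exp, mod)`, which the setter's solution uses, computes the same value; `mod_pow` and `answer` are hypothetical names written out only to show the O(\log N) loop.

```python
MOD = 10**9 + 7


def mod_pow(base, exp, mod):
    """Square-and-multiply: base**exp % mod using O(log exp) multiplications."""
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:                # lowest remaining bit of exp is set
            result = result * base % mod
        base = base * base % mod   # square for the next bit of exp
        exp >>= 1
    return result


def answer(n, k):
    """Editorial formula reduced modulo 10^9+7 (K = 2 or K = 3 only)."""
    p = mod_pow(2, n + 2, MOD)
    # Python's % is non-negative for a positive modulus, so p - 2 < 0 is safe.
    if k == 2:
        return (p - 2) % MOD
    if k == 3:
        return (p - 1) % MOD if n % 2 == 0 else (p - 2) % MOD
    raise ValueError("the formulas only cover K = 2 and K = 3")
```

Because the loop halves `exp` on every iteration, even N around 10^18 needs only about 60 iterations.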

SOLUTIONS:

Setter's Solution

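NUM = 10**9 + 7  # the required modulus

for _ in range(int(input())):
    N, K = map(int, input().split())
    # Built-in fast modular exponentiation: 2^(N+2) mod NUM
    val = pow(2, N + 2, NUM)
    if K == 2:
        # XOR of [2, 4, ..., 2^(N+1)] is 2^(N+2) - 2
        print((val - 2) % NUM)
    elif K == 3:
        if N & 1:
            # N odd: the N+1 low bits cancel, same as the K = 2 case
            print((val - 2) % NUM)
        else:
            # N even: one low bit survives, so the answer is 2^(N+2) - 1
            print((val - 1) % NUM)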