PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: Prasant Kumar
Tester: Harris Leung
Editorialist: Nishank Suresh
DIFFICULTY:
TBD
PREREQUISITES:
Basic combinatorics, familiarity with binary representations
PROBLEM:
Find the sum of A_1 \land A_2 \land \ldots \land A_N over all arrays A of length N such that each A_i is an integer between 1 and M, modulo 998244353.
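To pin down exactly what is being summed, here is a minimal brute-force sketch (the name `brute_force_and_sum` is a hypothetical helper, not from the original). It enumerates all M^N arrays, so it is only usable for tiny inputs, and it omits the modulus for readability. It is mainly useful for cross-checking a faster solution.

```python
from functools import reduce
from itertools import product
from operator import and_


def brute_force_and_sum(n: int, m: int) -> int:
    """Sum of A_1 & A_2 & ... & A_n over all arrays with 1 <= A_i <= m.

    Enumerates all m**n arrays, so it is only practical for tiny n and m.
    No modulus is applied here.
    """
    total = 0
    for arr in product(range(1, m + 1), repeat=n):
        total += reduce(and_, arr)  # bitwise AND of the whole array
    return total


print(brute_force_and_sum(2, 2))  # (1,1),(1,2),(2,1),(2,2) -> 1 + 0 + 0 + 2 = 3
```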
EXPLANATION:
Note that bitwise AND operates independently on each bit, so we can do the following:
- For each 0 \leq i \leq 30, find the number of arrays such that the i-th bit is set in every element of A, i.e., the bitwise AND of A has the i-th bit set. Let this number be ct_i.
- Then, add ct_i \cdot 2^i to the answer.
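A small sanity check of this decomposition, as a sketch: the helper `and_sum_by_bits` (a hypothetical name) finds each ct_i by plain enumeration rather than by a formula, then rebuilds the answer as the sum of ct_i \cdot 2^i and compares it with the direct sum of ANDs. Only the per-bit reconstruction is being illustrated here; counting ct_i efficiently comes next.

```python
from functools import reduce
from itertools import product
from operator import and_


def and_sum_by_bits(n: int, m: int, bits: int = 31) -> int:
    """Rebuild the answer from per-bit counts ct_i (found here by enumeration)."""
    arrays = list(product(range(1, m + 1), repeat=n))
    total = 0
    for i in range(bits):
        # ct_i: number of arrays in which every element has bit i set
        ct_i = sum(1 for arr in arrays if all((x >> i) & 1 for x in arr))
        total += ct_i << i  # each such array contributes 2^i
    return total


def direct_and_sum(n: int, m: int) -> int:
    """Direct definition: sum of the AND of every array."""
    return sum(reduce(and_, arr) for arr in product(range(1, m + 1), repeat=n))


assert and_sum_by_bits(3, 5) == direct_and_sum(3, 5)
```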
The second step is easy, so all we really need to do is find ct_i.
Our problem thus reduces to this:
- Given i, how many arrays of length N with integers from 1 to M have the i-th bit set in all elements?
For this reduced problem, note that it doesn’t matter which integer is placed where: just that each position must contain an integer with the i-th bit set.
So, if we knew k_i, the number of integers from 1 to M with the i-th bit set, then ct_i = k_i^N, since each of the N positions has exactly k_i choices for what is placed there, and these choices are independent.
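The counting argument ct_i = k_i^N can be checked on small cases with a sketch like the following. Both helper names are hypothetical; `count_by_enumeration` counts qualifying arrays directly, while `count_by_power` counts k_i over 1..M and raises it to the N-th power.

```python
from itertools import product


def count_by_enumeration(n: int, m: int, i: int) -> int:
    """ct_i by brute force: arrays where every element has bit i set."""
    return sum(
        1
        for arr in product(range(1, m + 1), repeat=n)
        if all((x >> i) & 1 for x in arr)
    )


def count_by_power(n: int, m: int, i: int) -> int:
    """ct_i = k_i ** n, with k_i counted directly over 1..m."""
    k_i = sum(1 for x in range(1, m + 1) if (x >> i) & 1)
    return k_i ** n  # each of the n positions independently has k_i choices


for n in range(1, 4):
    for m in range(1, 8):
        for i in range(3):
            assert count_by_enumeration(n, m, i) == count_by_power(n, m, i)
```

For example, with N = 2, M = 3 and i = 0, the odd values are 1 and 3, so k_0 = 2 and ct_0 = 2^2 = 4.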
All that remains is to compute k_i given M and i. A little observation shows that this reduces to a simple closed-form formula.
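Every 2^{i+1} consecutive integers contain exactly 2^i integers with the i-th bit set. So, k_i can be computed by splitting the range into blocks of length 2^{i+1} starting from 0, and handling the last \lt 2^{i+1} elements separately (starting from 0 instead of 1 changes nothing, since 0 has no bits set). This yields the formula
k_i = \left\lfloor \frac{M}{2^{i+1}} \right\rfloor \cdot 2^i + \max\left(0, (M \% 2^{i+1}) - 2^i + 1\right)
where \% is the modulo operator.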
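The block-counting formula for k_i (full blocks of length 2^{i+1} each contribute 2^i, and the partial last block contributes max(0, (M mod 2^{i+1}) - 2^i + 1)) can be verified against a direct count with this sketch; `k_formula` and `k_direct` are hypothetical names used only for illustration.

```python
def k_formula(m: int, i: int) -> int:
    """Numbers in [1, m] with bit i set, via the block formula."""
    block = 1 << (i + 1)                   # block length 2^(i+1)
    half = 1 << i                          # 2^i numbers per block have bit i set
    full = (m // block) * half             # contribution of complete blocks
    rest = max(0, m % block - half + 1)    # offsets 2^i .. (m % block) in the last block
    return full + rest


def k_direct(m: int, i: int) -> int:
    """Reference count by checking every number."""
    return sum(1 for x in range(1, m + 1) if (x >> i) & 1)


for m in range(0, 300):
    for i in range(10):
        assert k_formula(m, i) == k_direct(m, i)

print(k_formula(10, 1))  # 2, 3, 6, 7, 10 have bit 1 set -> 5
```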
This gives us the final solution:
- For each i from 0 to 30, compute k_i using the formula above.
- Then, compute ct_i = k_i^N.
- Finally, add ct_i \cdot 2^i to the answer, reducing modulo 998244353.
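The three steps above can be packaged as a function and cross-checked against brute force on small inputs. This is a sketch: `and_sum` and `brute` are hypothetical names, and the modulus 998244353 is taken from the editorialist's code.

```python
from functools import reduce
from itertools import product
from operator import and_

MOD = 998244353  # modulus used by the editorialist's code


def and_sum(n: int, m: int) -> int:
    """Sum over bits i of k_i^n * 2^i, reduced modulo MOD."""
    ans = 0
    for i in range(31):  # bits 0..30
        block, half = 1 << (i + 1), 1 << i
        k_i = (m // block) * half + max(0, m % block - half + 1)
        ans = (ans + pow(k_i, n, MOD) * half) % MOD
    return ans


def brute(n: int, m: int) -> int:
    """Direct enumeration of all m**n arrays (tiny inputs only)."""
    return sum(reduce(and_, a) for a in product(range(1, m + 1), repeat=n)) % MOD


for n in range(1, 4):
    for m in range(1, 9):
        assert and_sum(n, m) == brute(n, m)
```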
TIME COMPLEXITY
\mathcal{O}(30\log N) per test case if k_i^N is computed with binary exponentiation, or \mathcal{O}(30N) if it is computed with N repeated multiplications.
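For reference, a minimal binary exponentiation sketch is shown below (`power_mod` is a hypothetical name). Python's built-in three-argument `pow` already performs fast modular exponentiation, which is what the editorialist's code relies on; the hand-written version only illustrates why the cost is \mathcal{O}(\log N) multiplications per bit.

```python
def power_mod(base: int, exp: int, mod: int) -> int:
    """Compute base**exp % mod using O(log exp) multiplications."""
    result = 1 % mod
    base %= mod
    while exp > 0:
        if exp & 1:                   # lowest remaining bit of exp is set
            result = result * base % mod
        base = base * base % mod      # square for the next bit of exp
        exp >>= 1
    return result


MOD = 998244353
assert power_mod(3, 10**18, MOD) == pow(3, 10**18, MOD)
print(power_mod(2, 10, MOD))  # 1024
```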
CODE:
Editorialist's code (Python)
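```python
mod = 998244353

for _ in range(int(input())):
    n, m = map(int, input().split())
    ans = 0
    for bit in range(31):  # bits 0..30, matching the explanation above
        # Complete blocks of length 2^(bit+1) each hold 2^bit numbers with this bit set
        ct = (m // (2 << bit)) * (1 << bit)
        # Partial last block: offsets 2^bit .. (m % 2^(bit+1)) have the bit set
        ct += max(0, (m % (2 << bit)) - (1 << bit) + 1)
        # ct^n arrays have this bit set in every element; each contributes 2^bit
        ans += pow(ct, n, mod) << bit
    print(ans % mod)
```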