GCDXOR2HD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: hjroh0315
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Dynamic programming, binary exponentiation

PROBLEM:

You’re given N and D. Count the number of pairs of integers (X, Y) such that:

  1. 1 \leq X \leq Y \lt 2^N
  2. \gcd(X, Y) = X\oplus Y
  3. Y-X = D

N \leq 10^9 and D \leq 2000 in this version.

EXPLANATION:

From the easy version of the problem, we know exactly what the structure of the solution is: we start with (X, Y) = (0, D), and then add the same non-empty set of bits (with positions below N) to both X and Y. The chosen bits must not be set in D, and the sum of the chosen powers of two (that is, X itself) must be 0 modulo D.
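As a sanity check on this characterization, here is a minimal brute-force sketch for tiny inputs. It assumes D \geq 1, and the function names `brute` and `by_structure` are illustrative only, not part of any reference solution. The first function applies the three conditions from the statement directly; the second counts values X that avoid the bits of D and are divisible by D.

```python
from math import gcd

def brute(n, d):
    """Count pairs (X, X + d) directly from the conditions in the statement."""
    return sum(
        1
        for x in range(1, (1 << n) - d)      # guarantees X + d < 2^n
        if gcd(x, x + d) == (x ^ (x + d))
    )

def by_structure(n, d):
    """Count X in [1, 2^n) that share no bit with d and are divisible by d."""
    return sum(
        1
        for x in range(1, 1 << n)
        if x & d == 0 and x % d == 0
    )

if __name__ == "__main__":
    for n in range(1, 8):
        for d in range(1, 25):
            assert brute(n, d) == by_structure(n, d)
    print("characterization agrees with brute force on small inputs")
```

Both functions are exponential in N, so they are only useful for cross-checking a fast solution on small cases.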

In the easy version, this was done with DP in \mathcal{O}(ND) time, which is of course too slow now that N can be as large as 10^9.
The key idea behind optimizing for larger N is periodicity.

We’re interested in the values 2^0\bmod D, 2^1\bmod D, 2^2\bmod D, \ldots
Each value is obtained by doubling the previous one modulo D, so a value completely determines everything after it. There are only D possible remainders, so some value must repeat within the first D+1 elements; from the very first repeat onward, the values cycle over and over again.
Further, this cycle will have length no more than D.
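The periodicity can be observed with a short sketch like the one below. The helper name `power_cycle` and its return convention are assumptions made for illustration: it returns the index at which the cycle starts and the cycle length.

```python
def power_cycle(d):
    """Return (start, length): 2^i mod d is periodic with this length for i >= start."""
    first_seen = {}          # remainder -> first exponent at which it appeared
    value, i = 1 % d, 0
    while value not in first_seen:
        first_seen[value] = i
        value = value * 2 % d
        i += 1
    start = first_seen[value]
    return start, i - start

if __name__ == "__main__":
    print(power_cycle(7))    # remainders 1, 2, 4, 1, ... -> (0, 3)
    print(power_cycle(12))   # remainders 1, 2, 4, 8, 4, ... -> (2, 2)
```

For D = 12 the values 1 and 2 appear only once before the sequence settles into the cycle 4, 8, which is why the small exponents are handled separately later on.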


Let’s find this cycle (which can be done in \mathcal{O}(D) by just bruteforcing it).
Suppose the remainders on this cycle are r_1, r_2, \ldots, r_k. Note that k \leq D.

Now, it’s pretty easy to compute the contribution of just this cycle: directly using the DP we had for the easy version gives a complexity of \mathcal{O}(kD) which is fast enough.
Let A be the array of length D we obtain by running the DP on this cycle.
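A minimal sketch of this DP, assuming the cycle is given as a list of remainders: entry s of the result counts the subsets of the list whose sum is congruent to s modulo D. The name `subset_dp` is illustrative; the modulus is the one used in the code at the end.

```python
MOD = 998244353

def subset_dp(residues, d):
    """dp[s] = number of subsets of `residues` with sum congruent to s mod d."""
    dp = [0] * d
    dp[0] = 1                                  # the empty subset
    for r in residues:
        r %= d
        # Either skip this element (dp[s]) or take it (dp[s - r]).
        dp = [(dp[s] + dp[(s - r) % d]) % MOD for s in range(d)]
    return dp

if __name__ == "__main__":
    # For D = 7 the cycle of remainders is 1, 2, 4.
    print(subset_dp([1, 2, 4], 7))             # [2, 1, 1, 1, 1, 1, 1]
```

Each element costs \mathcal{O}(D) work, so a cycle of length k costs \mathcal{O}(kD) in total.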

We have several repetitions of this same cycle, so our next step is figuring out how to use A to compute the answer for that.
To that end, let’s define A^{(m)} to be the DP array obtained when the cycle is repeated m times, so that A^{(1)} = A.

Observe that, for any integers x, y, we have the following relation:

A^{(x+y)}_i = \sum_{j=0}^{D-1} A^{(x)}_j \cdot A^{(y)}_{(i-j) \bmod D}

It’s easy to see why: if we have x+y repetitions of the cycle, we can split it into two parts with sizes x and y; and the remainders obtained from both parts simply get added up.

So, A^{(x+y)} is simply the convolution of A^{(x)} and A^{(y)}.
In fact, this property holds for any two disjoint sets of bits as well: the DP array corresponding to their union is simply the (cyclic) convolution of their respective DP arrays.

This is a very useful property, because it allows us to compute A^{(x)} using \mathcal{O}(\log x) convolutions by simply running the binary exponentiation algorithm.
Each convolution can be done in \mathcal{O}(D^2) naively which is allowed by the limits.
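The two ingredients can be sketched as follows: a naive cyclic convolution, and binary exponentiation built on top of it. The names `cyclic_convolve` and `dp_power` are assumptions for this illustration; the structure mirrors the usual square-and-multiply loop, with the array [1, 0, \ldots, 0] (only the empty subset) acting as the identity.

```python
MOD = 998244353

def cyclic_convolve(a, b):
    """Cyclic convolution of two length-D arrays, indices taken modulo D."""
    d = len(a)
    c = [0] * d
    for i, ai in enumerate(a):
        if ai == 0:
            continue
        for j, bj in enumerate(b):
            k = (i + j) % d
            c[k] = (c[k] + ai * bj) % MOD
    return c

def dp_power(a, m):
    """DP array for m repetitions of the block whose DP array is `a`."""
    result = [1] + [0] * (len(a) - 1)   # identity: zero repetitions
    base = a[:]
    while m > 0:
        if m & 1:
            result = cyclic_convolve(result, base)
        base = cyclic_convolve(base, base)
        m >>= 1
    return result

if __name__ == "__main__":
    a = [2, 1, 1, 1, 1, 1, 1]           # DP array of the cycle 1, 2, 4 for D = 7
    print(dp_power(a, 3))
```

Since the entries of a DP array sum to the number of subsets, the entries of `dp_power(a, 3)` above sum to 8^3 = 512, which is a quick way to spot-check an implementation.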


We are now able to find the cycle in \mathcal{O}(D) time, and compute the answer for x repetitions of this cycle in \mathcal{O}(D^2 \log x) time.

This allows us to put together a full solution, as follows:

  • Compute the DP naively for the first D powers.
    This is because small powers are special in that they might have to be skipped (if they’re set in D), but once we get to “large enough” powers (any exponent \gt \log_2 D is safe, since such a bit cannot be set in D) we’ll be fine.
    This takes \mathcal{O}(D^2) time.
  • Now, starting from the (D+1)-th power, find the cycle and count how many full copies of it fit before we pass exponent N-1.
    Computing the contribution of this part takes \mathcal{O}(D^2 \log N) time.
  • Next, we’ll have a handful of bits remaining at the end which don’t fit into a full cycle.
    Since the cycle has length \leq D, there will also be no more than D values remaining, meaning we can just compute the answer for them naively in \mathcal{O}(D^2) too.
  • Finally, simply merge all three of the parts above to obtain the final DP array.
    This also takes \mathcal{O}(D^2) time.
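The bookkeeping for the three parts can be sketched as below. The helper `split_bits` and its return tuple are hypothetical, and it uses a prefix of exactly D exponents; any prefix length past the small exponents works equally well (the code at the end uses a slightly longer one).

```python
def split_bits(n, d):
    """Split exponents 0..n-1 into a naive prefix, repeated cycles, and a tail.

    Returns (prefix_len, cycle_len, reps, tail_start):
      exponents [0, prefix_len)          -> naive DP
      exponents [prefix_len, tail_start) -> `reps` copies of a cycle of length cycle_len
      exponents [tail_start, n)          -> naive DP
    """
    if n <= d:
        return n, 1, 0, n                # everything fits in the naive prefix
    prefix_len = d
    start = pow(2, prefix_len, d)        # already on the cycle at this point
    value, cycle_len = start * 2 % d, 1
    while value != start:
        value = value * 2 % d
        cycle_len += 1
    reps = (n - prefix_len) // cycle_len
    tail_start = prefix_len + reps * cycle_len
    return prefix_len, cycle_len, reps, tail_start

if __name__ == "__main__":
    print(split_bits(20, 7))             # (7, 3, 4, 19)
```

For N = 20 and D = 7, exponents 0 to 6 are handled naively, exponents 7 to 18 form four copies of a cycle of length 3, and only exponent 19 is left for the tail.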

Every convolution that took \mathcal{O}(D^2) can be optimized to \mathcal{O}(D\log D) by using a faster convolution algorithm, though that isn’t needed to get AC.

TIME COMPLEXITY:

\mathcal{O}(D^2 \log N) per testcase.

CODE:

Editorialist's code (PyPy3)
mod = 998244353
def merge(a, b):
    # Cyclic convolution (indices mod n): combines the DP arrays of two disjoint sets of bits.
    n = len(a)
    c = [0]*n
    for i in range(n):
        for j in range(n):
            c[(i+j)%n] += a[i]*b[j]
    for i in range(n): c[i] %= mod
    return c

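def calc(d, L, R):
    # Naive DP over bits L..R: dp[s] = number of subsets of usable bits with sum = s (mod d).
    dp = [0]*d
    dp[0] = 1
    for b in range(L, R+1):
        # Bits set in d cannot be used; d <= 2000 < 2^11, so only small b need the check.
        if b < 20 and (d & (1 << b)): continue

        val = pow(2, b, d)
        ndp = [0]*d
        for i in range(d):
            # Target (i + val) % d: either skip bit b, or take it starting from remainder i.
            ndp[(i + val) % d] = (dp[i] + dp[(i + val) % d]) % mod
        dp = ndp
    return dp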

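for _ in range(int(input())):
    n, d = map(int, input().split())
    
    # Small n: the naive DP is fast enough. Subtract 1 for the empty subset (X = 0).
    if n <= max(3*d, d+10):
        ans = calc(d, 0, n-1)[0]
        print((ans + mod - 1) % mod)
        continue

    # Part 1: naive DP on the first d+6 bits, which covers every bit that may be set in d.
    dp1 = calc(d, 0, d+5)
    
    # Find the cycle of 2^b mod d starting at b = d+6 (already periodic by then).
    L, R = d+6, d+7
    mark = [0]*d
    mark[pow(2, L, d)] = 1
    while mark[pow(2, R, d)] == 0:
        mark[pow(2, R, d)] = 1
        R += 1
    
    # [L, R) cycle
    siz = R - L
    dp2 = calc(d, L, R-1)
    reps = (n - (d+6)) // siz
    nxt = L + siz*reps
    # Part 2: binary exponentiation, raising the cycle's DP array to the power reps.
    res = [1] + [0]*(d-1)
    while reps > 0:
        if reps & 1: res = merge(res, dp2)
        dp2 = merge(dp2, dp2)
        reps //= 2
    # Part 3: leftover bits that do not fill a whole cycle (fewer than siz of them).
    dp3 = calc(d, nxt, n-1)
    ans = merge(dp1, merge(res, dp3))[0]
    print((ans + mod - 1) % mod)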