XORTREEH - Editorial




Author: Adarsh Kumar

Tester: Alexey Zayakin

Editorialist: Jakub Safin




Fourier transform, number theoretic transform (NTT), fast (modular) exponentiation and modular inverse


You’re given an array A of N non-negative integers and integers K,X.

Define the XOR of two numbers a\oplus b = c in base K this way: the d-th digit of c in base K is c_d=a_d+b_d modulo K.
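The definition above can be sketched in a few lines of Python. The function name `xor_base_k` is made up for this illustration; it simply adds the base-K digits of both numbers without carries.

```python
def xor_base_k(a: int, b: int, k: int) -> int:
    """Base-k xor: add the base-k digits of a and b modulo k, without carries."""
    result, place = 0, 1
    while a > 0 or b > 0:
        digit = (a % k + b % k) % k   # d-th digit of the result
        result += digit * place
        place *= k
        a //= k
        b //= k
    return result
```

For K=2 this coincides with the usual bitwise xor, e.g. `xor_base_k(5, 3, 2)` equals `5 ^ 3`.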

Compute the probabilities p_i that the base K XOR of mex values of X randomly selected subsequences of A is equal to i, for all possible i \ge 0, modulo 330301441.


The mex won’t be too large, XORs won’t be too large either. Compute the probabilities for getting all possible mex-s of one subsequence. The probabilities for XOR of X subsequences can be computed using slow multidimensional number theoretic transform, where each dimension is a digit and array size is K instead of a power of 2, combined with fast exponentiation.


The problem statement mentions the answer should be a complicated sum \sum (i^2 p_i^3)^i over all p_i > 0 (obviously, the terms with p_i=0 don’t affect the sum since that only happens with i > 0). However, that’s not important. The sum is just a hash value that’s there to avoid having to print large numbers and we can compute it after finding all p_i. Let’s just mention that the i-th powers can be computed using fast exponentiation.

Since we need to compute the result modulo MOD, we work with fractions (all p_i are rational numbers) through their modular equivalents: dividing by Q corresponds to multiplying by its modular inverse. Since the given modulus is prime, Fermat’s little theorem gives the inverse Q^{-1}=Q^{MOD-2}, which can be computed using the fast exponentiation mentioned above.
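A minimal sketch of fast exponentiation and the Fermat inverse, with the helper names `mod_pow` and `mod_inverse` chosen just for this example (Python's built-in three-argument `pow` does the same job):

```python
MOD = 330301441  # the prime modulus from the problem

def mod_pow(base: int, exp: int, mod: int = MOD) -> int:
    """Square-and-multiply: computes base**exp % mod in O(log exp) steps."""
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:                    # current lowest bit of the exponent is set
            result = result * base % mod
        base = base * base % mod
        exp >>= 1
    return result

def mod_inverse(q: int, mod: int = MOD) -> int:
    """Inverse of q modulo a prime, by Fermat's little theorem (q not divisible by mod)."""
    return mod_pow(q, mod - 2, mod)
```

A fraction P/Q is then represented as `P * mod_inverse(Q) % MOD`.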

Which values of i give non-zero p_i? Obviously, the mex of an array of size N can’t be more than N, since a larger mex would require all N+1 integers between 0 and N to be present in A. We can interpret numbers \le N in base K as numbers with at most D digits, or exactly D digits including leading zeroes, where D=\left\lfloor \log_K N \right\rfloor + 1 (the smallest D with K^D > N). Xor-ing D-digit numbers gives a D-digit number again, so the xor of X numbers is < K^D. Therefore, it’s sufficient to compute p_i only for i < K^D, and since K^D \le KN, that makes O(KN) numbers. That’s not too much.
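To avoid floating-point logarithms, the number of digits can be found with a simple loop. This is only a sketch with a made-up name `num_digits`; it returns the smallest D such that K^D > N, so that every value from 0 to N fits into D base-K digits.

```python
def num_digits(n: int, k: int) -> int:
    """Smallest d with k**d > n, i.e. the number of base-k digits needed for 0..n."""
    d, power = 1, k
    while power <= n:
        power *= k
        d += 1
    return d
```

Note the boundary case: N=10 in base 10 needs two digits, and the array size K^D is then 100, which is still at most K*N.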

From mex to probabilities

Let’s find the probabilities P_1(i) that the mex of a random subsequence of A will be equal to i, i.e. the probabilities p_i if X=1. As mentioned above, we can limit ourselves to i \le N.

We can compute just the number of subsequences S_1(i) that give mex equal to i and then normalise those values – divide them by \sum S_1(i), or rather multiply by its multiplicative inverse – to get P_1(i).

If the mex of some subsequence is i, then no element A_j=i can be in the subsequence. Each of the elements A_j > i may or may not be in there, it doesn’t matter; if there are g such elements, that gives 2^g possibilities. Finally, for any 0 \le k < i, there must be at least one element A_j=k present in the subsequence. If there are s_k such elements for a given k, then there are 2^{s_k}-1 ways to choose them (any non-empty subset). We can express

S_1(i) = 2^g \prod_{k=0}^{i-1} \left(2^{s_k}-1\right) = \prod_{k=i+1}^{A_{max}} 2^{s_k} \prod_{k=0}^{i-1} \left(2^{s_k}-1\right)\,,

where A_{max} is the maximum element in A, since g is just the sum of s_k for k > i.

Using fast exponentiation, we can precompute all s_k, 2^{s_k}, their suffix products and prefix products of 2^{s_k}-1 (similarly to prefix sums) and compute S_1(i) and P_1(i) for all i \le N using the given formula in O(A_{max}+N\log N); it doesn’t even need to depend on A_{max} if we notice that since the mex can’t be greater than N, setting A_i := min(N+1,A_i) doesn’t affect the result.
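The counting described above can be sketched as follows. The function name `mex_distribution` and the array layout are assumptions of this example, not taken from the linked solutions. The product of (2^{s_k}-1) runs over the values k < i that must be present, values above i contribute powers of two, and the empty subsequence is counted (with mex 0), exactly as the counting argument in the text does.

```python
MOD = 330301441

def mex_distribution(a):
    """Return (S, P): S[i] = number of subsequences with mex i (mod MOD),
    P[i] = the same normalised to a probability modulo MOD, for 0 <= i <= N."""
    n = len(a)
    cnt = [0] * (n + 2)
    for v in a:
        cnt[min(v, n + 1)] += 1          # values above N never matter for the mex
    pow2 = [pow(2, c, MOD) for c in cnt]

    # suffix[i] = product of 2^{s_k} over all k >= i
    suffix = [1] * (n + 3)
    for k in range(n + 1, -1, -1):
        suffix[k] = suffix[k + 1] * pow2[k] % MOD

    counts = []
    prefix = 1                           # product of (2^{s_k} - 1) over k < i
    for i in range(n + 1):
        counts.append(prefix * suffix[i + 1] % MOD)
        prefix = prefix * (pow2[i] - 1) % MOD

    total_inv = pow(sum(counts) % MOD, MOD - 2, MOD)
    return counts, [c * total_inv % MOD for c in counts]
```

For A = [0, 1, 1] the counts are [4, 1, 3, 0]: four subsequences avoid 0, one contains 0 and no 1, three contain 0 and at least one 1, and none can reach mex 3.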

This step is fast: once the values above are precomputed, filling the array of P_1 values padded to K^D entries takes only O(K^D)=O(KN) time.

Walsh-Hadamard transform

Look at the straightforward way to compute probabilities P_2(i) for X=K=2 from P_1(i):

P_2(i) = \sum_{j=0}^N P_1(j) P_1(i\oplus j)\,.

It’s very similar to the convolution of two arrays; the only difference is that we’re using \oplus instead of +. The convolution of 2 arrays can be computed using fast Fourier transform by computing the FFT of both arrays extended to size 2^k \ge 2N, multiplying their corresponding elements and computing the inverse FFT of the resulting array. For this xor-convolution, the procedure is the same, but we use the fast Walsh-Hadamard transform instead of FFT.

For general X, the fast way to compute all P_X(i) is to compute the Walsh-Hadamard transform WH\lbrack P_1\rbrack(\nu), take B(\nu)=WH\lbrack P_1\rbrack^X(\nu) (using fast exponentiation) and compute P_X(i) as the inverse Walsh-Hadamard transform P_X=WH^{-1}\lbrack B\rbrack. However, this only works for K=2, where the conventional xor is defined.
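For K=2, the transform-power-inverse pipeline described above could look like this. It is a sketch under the assumption that the input length is a power of two and that values are kept modulo MOD; the names `wht` and `xor_power` are invented here.

```python
MOD = 330301441

def wht(values):
    """Fast Walsh-Hadamard transform modulo MOD; len(values) must be a power of 2."""
    a = list(values)
    n = len(a)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for j in range(start, start + h):
                x, y = a[j], a[j + h]
                a[j] = (x + y) % MOD
                a[j + h] = (x - y) % MOD
        h *= 2
    return a

def xor_power(p1, x):
    """X-fold xor-convolution of p1 with itself: WHT, pointwise X-th power, inverse WHT."""
    n = len(p1)
    transformed = [pow(v, x, MOD) for v in wht(p1)]
    inv_n = pow(n, MOD - 2, MOD)         # the inverse WHT is the WHT divided by n
    return [v * inv_n % MOD for v in wht(transformed)]
```

As a sanity check, `xor_power([1, 2, 3, 4], 2)` gives `[30, 28, 22, 20]`, matching the direct formula for P_2.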

The following is actually a generalised version of WHT for arbitrary K \ge 2.

A better approach: number theoretic transform

This approach uses the specific value of MOD. If we compute small factors of MOD-1, we can see that all numbers from 2 to 10 – all possible K – divide it! That means we can use the number theoretic transform, which allows us to treat the base-K xor as what it actually is: summation modulo K.

On the other hand, we’re going to need the multidimensional version. We can look at an index i as a vector of D digits (i_1,\dots,i_D); the xor of 2 vectors is actually just their vector sum and then taking the remainder mod K in each digit. The formula for P_2(i) then becomes \sum_j P_1(j_1,\dots,j_D) P_1((i_1-j_1)\%K,\dots,(i_D-j_D)\%K), which is just multidimensional convolution with indices modulo K.

So how do we do convolution with indices modulo K? We don’t need to do anything – turns out convolution using DFT or NTT is already done with indices modulo array size! That’s why we need the trick with padding the array with zeroes at the end to at least twice the size (it’s a power of 2 just so that it’d run fast): when we’re doing C_{i+j} += A_i B_j, we’re only adding a non-zero number if i+j < 2N, so taking it modulo 2N does nothing.
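The effect of padding is easy to see on a naive cyclic convolution, which is what the DFT/NTT approach computes. This is a plain quadratic illustration (the name `cyclic_convolution` is made up), not part of the fast solution.

```python
def cyclic_convolution(a, b):
    """Convolution with indices taken modulo the common array size."""
    n = len(a)
    assert len(b) == n
    c = [0] * n
    for i in range(n):
        for j in range(n):
            c[(i + j) % n] += a[i] * b[j]
    return c
```

With size 3, `cyclic_convolution([1, 2, 3], [1, 2, 3])` wraps around and gives `[13, 13, 10]`, which is exactly the behaviour we want for digits modulo K=3. After padding both arrays to size 6, nothing wraps and the result is the ordinary product `[1, 4, 10, 12, 9, 0]`.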

The reason why this happens is apparent if we look at how DFT or NTT works. For an array of size N, we choose a number w such that w^k \neq 1 for 0 < k < N and w^N=1 and compute F\lbrack A\rbrack(j) = \sum_{k=0}^{N-1} w^{jk} A_k for each 0 \le j < N. The inverse transformation uses 1/w instead of w and additionally divides each element by N. DFT uses w=e^{2\pi i / N}; for NTT, it’s a primitive N-th root of unity modulo MOD, i.e. a number for which the required conditions hold modulo MOD (one exists because N divides MOD-1). There’s no simple formula for such a root, but since we have a fixed modulus and the array sizes N (N=K here) are small, we can compute them e.g. by bruteforcing locally for all possible values of K and hardwire them into the code.
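One possible way to do the local search is sketched below. Instead of scanning all residues, it tries candidates of the form g^{(MOD-1)/K}, which automatically satisfy w^K=1 by Fermat's little theorem, and keeps the first one whose smaller powers all differ from 1. This is a variation chosen for this example, and `find_root_of_unity` is a hypothetical helper name.

```python
MOD = 330301441

def find_root_of_unity(k: int) -> int:
    """Return some w with w**k == 1 and w**e != 1 for 0 < e < k (all modulo MOD)."""
    assert (MOD - 1) % k == 0, "k must divide MOD - 1"
    for g in range(2, MOD):
        w = pow(g, (MOD - 1) // k, MOD)   # w**k == g**(MOD-1) == 1
        if all(pow(w, e, MOD) != 1 for e in range(1, k)):
            return w
    raise ValueError("no root found")
```

Running it once for every K from 2 to 10 gives the constants that can be hardwired into the solution.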

Anyway, since w^N=1, there’s no difference in what we’re computing if we take some index modulo N. We can run DFT or NTT to get F\lbrack P_1\rbrack without increasing its size to a power of 2, then take F\lbrack P_X\rbrack(j)=F\lbrack P_1\rbrack(j)^X for each j and finally compute P_X using an inverse transform and it gives us the probabilities we need.

There’s a small drawback here: we’re transforming arrays of size K, so there’s no way to use the “butterfly scheme” of classic FFT. However, we don’t need that. In order to stay in integers and get the required precision, we’re going to use multidimensional NTT modulo MOD by simply computing the required sums directly. Small array sizes help a lot, since this bruteforce approach runs in O(K) per element per dimension.

We have O(K^D)=O(KN) elements (the former bound is tighter and we need to work with fixed D anyway, so let’s use that) in P_1 and P_X, so each NTT (direct and reverse) runs in O(K^{D+1}D). Between them, there are O(K^D) fast exponentiations in O(\log MOD). The total time complexity is therefore O(K^D(KD+\log{MOD})); memory complexity O(K^D).
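Putting the pieces together, the direct multidimensional NTT could be sketched like this. The assumptions of this example: index i stores its base-K digits with the least significant digit first, the array has exactly K^D entries, and w is a primitive K-th root of unity modulo MOD found beforehand. The names `transform` and `xor_power_base_k` are invented here and this is not the code of the linked solutions.

```python
MOD = 330301441

def transform(values, k, d, w):
    """Apply the size-k transform with root w along each of the d digit positions.
    Costs O(k) per element per dimension, O(k^(d+1) * d) in total."""
    a = list(values)
    w_pow = [pow(w, e, MOD) for e in range(k)]
    stride = 1
    for _ in range(d):
        for base in range(len(a)):
            if (base // stride) % k != 0:
                continue                   # handle each line once, from its digit-0 element
            line = [a[base + t * stride] for t in range(k)]
            for j in range(k):
                a[base + j * stride] = sum(
                    line[t] * w_pow[j * t % k] for t in range(k)
                ) % MOD
        stride *= k
    return a

def xor_power_base_k(p1, k, d, x, w):
    """Distribution of the base-k xor of x independent values distributed as p1."""
    f = [pow(v, x, MOD) for v in transform(p1, k, d, w)]
    inv_w = pow(w, MOD - 2, MOD)
    inv_size = pow(len(p1), MOD - 2, MOD)  # divide by k^d after the inverse transform
    return [v * inv_size % MOD for v in transform(f, k, d, inv_w)]
```

For K=2 the root is w = MOD-1 (that is, -1) and the whole procedure reduces to the Walsh-Hadamard approach: `xor_power_base_k([1, 2, 3, 4], 2, 2, 2, MOD - 1)` returns `[30, 28, 22, 20]`.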


Setter’s solution

Tester’s solution

Editorialist’s solution


Exactly how I did it.
