PMXXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: biggestotaku
Tester: pols_agyi_pols
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Linearity of expectation, combinatorics

PROBLEM:

For an array B, define f(B) to be the XOR of its prefix maximums.
An index i contains a prefix maximum iff B_i \ge B_j for all j \le i.

You’re given an array A.
Find the expected value of f(B) across all distinct rearrangements B of A.
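
To make the definition concrete, here is a minimal brute-force sketch. The helper names `prefix_max_xor` and `brute_expected` are made up for illustration and are not part of the reference solution; the enumeration is exponential, so it is only useful for checking tiny inputs by hand.

```python
from fractions import Fraction
from itertools import permutations

def prefix_max_xor(b):
    """XOR of every B_i that satisfies B_i >= B_j for all j <= i."""
    best = None
    acc = 0
    for v in b:
        if best is None or v >= best:
            best = v
            acc ^= v
    return acc

def brute_expected(a):
    """Exact average of f(B) over the distinct rearrangements of a (tiny inputs only)."""
    distinct = set(permutations(a))
    return Fraction(sum(prefix_max_xor(b) for b in distinct), len(distinct))
```

For instance, `[1, 2]` has the two rearrangements `(1, 2)` and `(2, 1)` with values 3 and 2, so the expected value is 5/2.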

EXPLANATION:

Let p_k denote the probability that the k-th bit is set in the value of f(B).
Bit k contributes 2^k to f(B) exactly when it is set, so by linearity of expectation, the answer we’re looking for is exactly \displaystyle\sum_{k\ge 0} p_k 2^k.
So, it suffices to be able to compute p_k for a fixed k.
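
The per-bit decomposition can be sanity-checked on small arrays. The sketch below is illustrative only: `f` and `expected_via_bits` are hypothetical names, the values are assumed to be non-negative, and p_k is obtained by plain enumeration rather than by the counting argument that follows.

```python
from fractions import Fraction
from itertools import permutations

def f(b):
    # XOR of prefix maximums; assumes all values are non-negative.
    best, acc = -1, 0
    for v in b:
        if v >= best:
            best, acc = v, acc ^ v
    return acc

def expected_via_bits(a, bits=30):
    """Sum of p_k * 2^k, with each p_k found by brute force."""
    arrangements = set(permutations(a))
    total = Fraction(0)
    for k in range(bits):
        set_count = sum((f(b) >> k) & 1 for b in arrangements)
        p_k = Fraction(set_count, len(arrangements))
        total += p_k * 2**k
    return total
```

On `[1, 2, 3]` the six rearrangements give f values 0, 2, 1, 1, 3, 3, so p_0 = 4/6 and p_1 = 3/6, and the sum is 4/6 + 2 * 3/6 = 5/3, matching the direct average 10/6.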

To do this, we can count the distinct rearrangements B for which f(B) has the k-th bit set, and divide this by the total number of distinct rearrangements.

Now, the k-th bit will be set if and only if an odd number of prefix maximums have it set.
To count the number of such configurations, let’s try building the array by placing elements in descending order, so that elements placed later do not affect existing prefix maximums.


Suppose the distinct elements present in A are
x_1 \gt x_2 \gt \ldots \gt x_m
with their respective counts in A being
c_1, c_2, \ldots, c_m

We start with an empty array.
We also maintain two values w_0 and w_1, where w_0 denotes the number of arrangements so far that don’t have the k-th bit set in the XOR, and w_1 denotes the number of arrangements that do.

First, we place all c_1 copies of x_1.
Note that these c_1 elements will always be prefix maximums, no matter what we do with the remaining elements.
So,

  • If x_1 has the k-th bit set, then the current XOR will either surely have it set, or surely have it unset, depending on whether c_1 is odd or even respectively.
    This corresponds to the states w_1 = 1, w_0 = 0 or w_1 = 0, w_0 = 1 respectively.
  • If x_1 has the k-th bit unset, the XOR will definitely not have the k-th bit set.
    This corresponds to w_1 = 0, w_0 = 1.
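
The base case above can be written as a tiny helper. This is only a sketch; the name `initial_state` and the `(w0, w1)` return order are choices made here for illustration.

```python
def initial_state(x1, c1, k):
    """Return (w0, w1) after placing all c1 copies of the largest value x1."""
    # Every copy of x1 is a prefix maximum, so bit k of the XOR is set
    # exactly when x1 has bit k set and c1 is odd.
    if (x1 >> k) & 1 and c1 % 2 == 1:
        return 0, 1
    return 1, 0
```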

Next, let’s try placing the copies of x_2.
Here, note that there are now c_1 + 1 “gaps” around the existing copies of x_1, including the one before the first copy and the one after the last copy. We can place as many copies of x_2 as we want into each gap, as long as the total number of copies across all gaps equals c_2.
Each such method of distribution leads to a different array.

However, the only copies of x_2 that will be prefix maximums are those placed into the first gap, i.e. before the first copy of x_1.
So, we only really care about (the parity of) the number of elements in this gap.

Let’s fix y to be the number of elements that go into the first gap, so that c_2 - y copies have to be distributed among the remaining c_1 gaps.
The number of valid distributions equals \binom{c_1 + c_2 - y - 1}{c_1-1} by stars-and-bars.
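
As a quick check of the stars-and-bars count, the following sketch compares the closed form with direct enumeration for small parameters. Both function names are hypothetical, and `brute_ways` is deliberately naive.

```python
from itertools import product
from math import comb

def ways_after_first_gap(c1, c2, y):
    """Closed form: split the c2 - y remaining copies among the c1 later gaps."""
    return comb(c1 + c2 - y - 1, c1 - 1)

def brute_ways(c1, c2, y):
    """Enumerate every tuple of c1 non-negative gap sizes summing to c2 - y."""
    rest = c2 - y
    return sum(1 for parts in product(range(rest + 1), repeat=c1)
               if sum(parts) == rest)
```

For example, with c_1 = 2, c_2 = 3 and y = 1, the two remaining copies can be split as (0, 2), (1, 1) or (2, 0), which agrees with \binom{3}{1} = 3.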

Then,

  • If y is even or x_2 doesn’t have bit k set, the parity of bit k is unaffected.
  • If y is odd and x_2 has bit k set, bit k of the prefix-maximum XOR flips.

So, if t_{even} and t_{odd} denote the number of ways to place the copies of x_2 with an even/odd number of them before the first instance of x_1, then:

  • If x_2 doesn’t have bit k set, both w_0 and w_1 will be multiplied by (t_{even} + t_{odd}).
  • If x_2 does have bit k set,
    • w_0 changes to w_0\cdot t_{even} + w_1 \cdot t_{odd}
    • w_1 changes to w_1\cdot t_{even} + w_0 \cdot t_{odd}
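
The update rules above could be sketched as a single function. The name `apply_value` and its argument order are assumptions made for this illustration; exact integers are used here, whereas a real submission would reduce modulo the required prime.

```python
def apply_value(w0, w1, t_even, t_odd, has_bit):
    """Return the new (w0, w1) after placing all copies of one value."""
    if not has_bit:
        # The bit cannot change, so every placement keeps the current state.
        total = t_even + t_odd
        return w0 * total, w1 * total
    # An odd count in the first gap flips the bit; an even count keeps it.
    return w0 * t_even + w1 * t_odd, w1 * t_even + w0 * t_odd
```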

In general, when we’re placing the c_i copies of x_i,

  • There are already S_i = c_1 + c_2 + \ldots + c_{i-1} existing larger elements.
    So, there are (S_i+1) gaps to place the copies of x_i.
    However, only the copies placed in the first gap can affect bit k of the prefix-max XOR.
  • So, compute t_{even} and t_{odd} to be the number of configurations with an even/odd number of occurrences in the first gap.
    This can be done by iterating through all choices of counts in the first gap (which has c_i+1 options) and then applying stars and bars.
  • Once t_{even} and t_{odd} are known, w_0 and w_1 can be updated appropriately as seen above.
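
Computing t_{even} and t_{odd} for one value might look like the sketch below, where `s` stands for S_i and `c` for c_i. The function name `gap_counts` is hypothetical, and `math.comb` is used in place of the precomputed modular factorials a full solution would need.

```python
from math import comb

def gap_counts(s, c):
    """Return (t_even, t_odd) for c new copies and s already-placed larger elements."""
    t_even = t_odd = 0
    for y in range(c + 1):                 # y copies go into the first gap
        ways = comb(s + c - y - 1, s - 1)  # the rest go into the other s gaps
        if y % 2 == 0:
            t_even += ways
        else:
            t_odd += ways
    return t_even, t_odd
```

A useful sanity check is that t_{even} + t_{odd} equals \binom{S_i + c_i}{c_i}, the total number of ways to interleave the new copies with the existing elements.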

After processing all elements, the number of configurations we’re looking for is exactly w_1.
So, the probability equals w_1 divided by the total number of configurations, which is

\frac{N!}{c_1! \cdot c_2! \cdot\ldots \cdot c_m!}

Note that with k fixed, the overall complexity is \mathcal{O}(N), because c_1 + c_2 + \ldots + c_m = N and we do \mathcal{O}(c_i) work when processing x_i (each binomial coefficient takes \mathcal{O}(1) time with precomputed factorials).

Repeating this for every bit will give an additional logarithmic factor to the complexity, which is fine for the given constraints.
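
Putting the steps together, here is an exact-arithmetic sketch of the whole procedure, meant only for cross-checking small cases by hand or against brute force. It is not the reference solution: it uses `Fraction` instead of modular arithmetic, the name `expected_prefix_max_xor` is made up, and the default of 30 bits is an assumption about the value range.

```python
from collections import Counter
from fractions import Fraction
from math import comb, factorial

def expected_prefix_max_xor(a, bits=30):
    # Distinct values with their counts, largest value first.
    groups = sorted(Counter(a).items(), reverse=True)
    # Number of distinct rearrangements: N! / (c_1! * ... * c_m!).
    total = factorial(len(a))
    for _, c in groups:
        total //= factorial(c)

    answer = Fraction(0)
    for k in range(bits):
        x1, c1 = groups[0]
        w0, w1 = (0, 1) if (x1 >> k) & 1 and c1 % 2 == 1 else (1, 0)
        placed = c1  # number of larger elements already placed
        for x, c in groups[1:]:
            t_even = t_odd = 0
            for y in range(c + 1):
                ways = comb(placed + c - y - 1, placed - 1)
                if y % 2 == 0:
                    t_even += ways
                else:
                    t_odd += ways
            if (x >> k) & 1:
                w0, w1 = w0 * t_even + w1 * t_odd, w1 * t_even + w0 * t_odd
            else:
                w0, w1 = w0 * (t_even + t_odd), w1 * (t_even + t_odd)
            placed += c
        answer += Fraction(w1, total) * 2**k  # p_k * 2^k
    return answer
```

On `[1, 2, 3]` this gives w_1 = 4 for bit 0 and w_1 = 3 for bit 1 out of 6 rearrangements, for a total of 4/6 + 2 * 3/6 = 5/3.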

TIME COMPLEXITY:

\mathcal{O}(N \log N + N\log(\max A_i)) per testcase.

CODE:

Editorialist's code (PyPy3)
mod = 998244353
N = 10**6 + 10
# fac[i] = i! modulo mod
fac = list(range(N))
fac[0] = 1
for i in range(1, N): fac[i] = fac[i-1] * i % mod
# inv[i] = modular inverse of i! (inverse factorials, not inverses of i)
inv = fac[:]
inv[-1] = pow(inv[-1], mod-2, mod)
for i in reversed(range(N-1)): inv[i] = inv[i+1] * (i+1) % mod

def C(n, r):
    if n < r or r < 0: return 0
    return (fac[n] * inv[r] % mod * inv[n-r]) % mod

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))
    
    a.sort()
    val, ct = [], []
    cur = 0
    for i in range(n):
        if i == 0 or a[i] != a[i-1]:
            if cur > 0: ct.append(cur)
            cur = 0
            val.append(a[i])
        cur += 1
    ct.append(cur)
    val = val[::-1]
    ct = ct[::-1]
    
    sz = len(val)
    ans = 0
    for bit in reversed(range(30)):
        dp0, dp1 = 0, 0
        if ct[0]%2 == 1 and val[0] & (1 << bit): dp1 = 1
        else: dp0 = 1
        
        tot = ct[0]
        for i in range(1, sz):
            even, odd = 0, 0
            
            for x in range(ct[i]+1):
                # stars and bars: x copies go in the first gap, and the other
                # ct[i] - x copies are split among the tot remaining gaps
                ways = C(tot + ct[i] - x - 1, tot - 1)
                if x%2 == 0: even += ways
                else: odd += ways
            
            if val[i] & (1 << bit):
                dp0, dp1 = (dp0 * even + dp1 * odd) % mod, (dp0 * odd + dp1 * even) % mod
            else:
                dp0 = dp0 * (even + odd) % mod
                dp1 = dp1 * (even + odd) % mod
            tot += ct[i]
        
        ans += dp1 * (1 << bit) % mod
    
    # Divide by the number of distinct rearrangements, n! / (ct[0]! * ct[1]! * ...)
    for i in range(sz):
        ans = (ans * fac[ct[i]]) % mod
    print(ans * inv[n] % mod)