PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: iceknight1093
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Dynamic programming
PROBLEM:
For a binary array S, define f(S) to be the number of distinct arrays that can be reached by performing the following operation any number of times (possibly zero):
- Choose an index i such that S_i = S_{i+2}, and flip S_{i+1}.
You’re given an array A satisfying A_i \in \{-1, 0, 1\}.
Compute the sum of f(S) across all distinct binary strings S that can be formed by replacing the -1’s in A by 0/1.
Here, N \le 100.
EXPLANATION:
Consider a fixed binary array S.
Let’s perform a couple of transformations on it to make computing f(S) easier, because the current operation is a bit awkward to deal with.
In particular, the first thing we do is take the difference array of S.
That is, let B be an array of length N-1, where B_i = S_i \oplus S_{i+1} for each 1 \le i \lt N.
So, B_i = 1 if and only if S_i \ne S_{i+1}.
Under this lens, the operation on B becomes: choose an index i such that B_i = B_{i+1}, and flip both B_i and B_{i+1}.
This is because B_i = S_i \oplus S_{i+1} and B_{i+1} = S_{i+1} \oplus S_{i+2} are equal exactly when S_i = S_{i+2}, and flipping S_{i+1} flips both of them (no other entry of B involves S_{i+1}).
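As a quick sanity check of this correspondence, here is a brute-force sketch. It uses 0-indexed Python lists, and the helper names diff_array, op_on_s, op_on_b and check are purely illustrative. For every short binary array S and every valid index, it confirms that applying the original operation and then taking the difference array gives the same result as applying the B-operation directly.

```python
from itertools import product

def diff_array(s):
    """B[i] = s[i] xor s[i+1] (0-indexed); length len(s) - 1."""
    return [s[i] ^ s[i + 1] for i in range(len(s) - 1)]

def op_on_s(s, i):
    """Original operation: requires s[i] == s[i+2], flips s[i+1]."""
    assert s[i] == s[i + 2]
    t = list(s)
    t[i + 1] ^= 1
    return t

def op_on_b(b, i):
    """Induced operation on B: requires b[i] == b[i+1], flips both."""
    assert b[i] == b[i + 1]
    t = list(b)
    t[i] ^= 1
    t[i + 1] ^= 1
    return t

def check(n):
    for s in product([0, 1], repeat=n):
        b = diff_array(s)
        for i in range(n - 2):
            # S_i == S_{i+2} holds exactly when B_i == B_{i+1}
            assert (s[i] == s[i + 2]) == (b[i] == b[i + 1])
            if s[i] == s[i + 2]:
                # operate on S, then difference == difference, then operate on B
                assert diff_array(op_on_s(s, i)) == op_on_b(b, i)
    return True

print(all(check(n) for n in range(1, 9)))  # True
```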
It’s still a bit unclear how to figure out reachable states from just B, however.
So, we’ll look for another transformation.
Observe that each operation on B affects one even index and one odd index.
Consider what happens when we flip the values at all odd indices of B.
That is, consider a new array C such that:
- If i is odd, C_i = B_i \oplus 1
- If i is even, C_i = B_i
In this new array C, it can be seen that an operation on S corresponds to the following:
- Choose an index i such that C_i \ne C_{i+1}.
- Then, swap the values at these indices.
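This is because the operation on B was 00\to 11 or 11\to 00, and exactly one of the two affected indices is odd; after flipping the odd-index element, these become 01\to 10 and 10\to 01 in C.
The operation of “swap adjacent different elements” is quite powerful: swapping two adjacent equal elements would change nothing anyway, so in effect we may swap any adjacent pair, and adjacent swaps can produce every rearrangement. So we can rearrange C freely!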
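The following brute-force sketch checks this claim on all short arrays B. It uses 0-indexed lists, so the 1-indexed odd positions that get flipped are the even 0-indexed ones. The helper names to_c and check_swap are illustrative only.

```python
from itertools import product

def to_c(b):
    """Flip B at odd 1-indexed positions, i.e. at even 0-indexed positions."""
    return [x ^ 1 if j % 2 == 0 else x for j, x in enumerate(b)]

def check_swap(m):
    for b in product([0, 1], repeat=m):
        c = to_c(b)
        for i in range(m - 1):
            # exactly one of i, i+1 is flipped, so C_i != C_{i+1} iff B_i == B_{i+1}
            assert (b[i] == b[i + 1]) == (c[i] != c[i + 1])
            if b[i] == b[i + 1]:
                # apply the B-operation: flip both entries
                nb = list(b)
                nb[i] ^= 1
                nb[i + 1] ^= 1
                # in C this must look like swapping positions i and i+1
                swapped = list(c)
                swapped[i], swapped[i + 1] = swapped[i + 1], swapped[i]
                assert to_c(nb) == swapped
    return True

print(all(check_swap(m) for m in range(0, 9)))  # True
```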
Further, observe that the transformations S \to B and B \to C are bijections. (Strictly speaking, S \to B is only a bijection if we also remember S_1, but that’s harmless because the operation never changes S_1.)
Thus, any sequence of swap operations in C also corresponds to some sequence of operations on S.
Further, since these maps are bijections, distinct arrays reachable from C correspond to distinct arrays reachable from S.
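To see the bijection concretely, here is a small sketch that goes from S to C directly, using C_i = S_i \oplus S_{i+1} \oplus (i \bmod 2), and back again from S_1 and C. It works with 0-indexed lists, so the parity flip applies at even 0-indexed positions. The function names s_to_c, c_to_s and check_roundtrip are illustrative only. The round trip recovering every S is what makes the map S \mapsto (S_1, C) invertible.

```python
from itertools import product

def s_to_c(s):
    # 0-indexed j is 1-indexed position j+1, so flip when j is even
    return [s[j] ^ s[j + 1] ^ (1 if j % 2 == 0 else 0) for j in range(len(s) - 1)]

def c_to_s(first, c):
    s = [first]
    for j, x in enumerate(c):
        b = x ^ (1 if j % 2 == 0 else 0)  # undo the parity flip to recover B
        s.append(s[-1] ^ b)               # S_{i+1} = S_i xor B_i
    return s

def check_roundtrip(n):
    for s in product([0, 1], repeat=n):
        assert c_to_s(s[0], s_to_c(s)) == list(s)
    return True

print(all(check_roundtrip(n) for n in range(1, 11)))  # True
```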
Thus, f(S) simply equals the number of distinct strings we can transform C into using the operation of swapping adjacent different elements.
This, now, is easy to answer: if there are K ones in C, then it can turn into any of \binom{N-1}{K} possible binary arrays, since we can freely rearrange these K ones among the N-1 positions of C.
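The closed form can be cross-checked against a direct search on S. This sketch runs a breadth-first search over the original operation (0-indexed, starting state included since zero operations are allowed) and compares the count with \binom{N-1}{K} for every binary array up to length 10. The names f_bruteforce and f_formula are illustrative only.

```python
from collections import deque
from itertools import product
from math import comb

def f_bruteforce(s):
    """Count arrays reachable from s via: pick i with s[i] == s[i+2], flip s[i+1]."""
    start = tuple(s)
    seen = {start}
    queue = deque([start])
    while queue:
        cur = queue.popleft()
        for i in range(len(cur) - 2):
            if cur[i] == cur[i + 2]:
                nxt = list(cur)
                nxt[i + 1] ^= 1
                nxt = tuple(nxt)
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
    return len(seen)

def f_formula(s):
    """binom(N-1, K), where K is the number of ones in C (0-indexed parity flip)."""
    n = len(s)
    k = sum(s[j] ^ s[j + 1] ^ (1 if j % 2 == 0 else 0) for j in range(n - 1))
    return comb(n - 1, k)

for n in range(1, 11):
    for s in product([0, 1], repeat=n):
        assert f_bruteforce(s) == f_formula(s)
print("ok")
```

For example, S = 000 can only reach itself and 010, and indeed its C = 10 has K = 1, giving \binom{2}{1} = 2.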
Observe now that f(S) depends purely on the number of ones in the transformed array C, where C_i = S_i \oplus S_{i+1} \oplus (i\bmod 2).
So, to sum up f(S) over all binary completions S of the given array A, we only need to know the distribution of ones across all possible transformed arrays.
That is, if c_x denotes the number of completions whose transformed array has x ones, the answer is
\sum_{x=0}^{N-1} c_x \cdot \binom{N-1}{x}.
Our goal is now to compute the values c_x for all x.
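Before optimizing, the values c_x and the final sum can be computed by plain enumeration of all completions, which is exponential but handy as a reference on tiny inputs. This is a sketch with illustrative names count_by_ones and answer_bruteforce; it uses 0-indexed lists and returns exact (non-modular) integers.

```python
from itertools import product
from math import comb

def count_by_ones(a):
    """c[x] = number of completions of a (entries -1/0/1) whose C has x ones."""
    n = len(a)
    free = [j for j, v in enumerate(a) if v == -1]
    c = [0] * n  # x ranges over 0..n-1
    for bits in product([0, 1], repeat=len(free)):
        s = list(a)
        for j, bit in zip(free, bits):
            s[j] = bit
        # C_i = S_i xor S_{i+1} xor (i mod 2), written with 0-indexed j
        x = sum(s[j] ^ s[j + 1] ^ (1 if j % 2 == 0 else 0) for j in range(n - 1))
        c[x] += 1
    return c

def answer_bruteforce(a):
    n = len(a)
    return sum(cx * comb(n - 1, x) for x, cx in enumerate(count_by_ones(a)))

print(count_by_ones([-1, -1, -1]))      # [2, 4, 2]
print(answer_bruteforce([-1, -1, -1]))  # 12
```

The result 12 matches a direct count: of the 8 arrays of length 3, the 4 with S_1 = S_3 have f(S) = 2 and the other 4 have f(S) = 1.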
That’s fairly easy to do using dynamic programming.
Let’s define dp(i, x, p) to be the number of ways to fill in all -1’s at positions 1, 2, \ldots, i such that:
- S_i = p, and
- Exactly x of C_1, C_2, \ldots, C_{i-1} are ones.
To compute dp(i, x, p), first note that S_i = p is only allowed if A_i = -1 or A_i = p. Then try both values q of S_{i-1}: each one fixes C_{i-1} = q \oplus p \oplus ((i-1) \bmod 2), and we add dp(i-1, x - C_{i-1}, q). The base case is dp(1, 0, p) = 1 for every p allowed at position 1.
There are \mathcal{O}(N^2) states, each with \mathcal{O}(1) transitions, so this runs in \mathcal{O}(N^2), which is more than fast enough for N \le 100.
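Here is a compact sketch of this DP as a function, roughly mirroring the editorialist's code below but without input parsing. It keeps only the layer for the current i, stored as dp[x][p], and uses 0-indexed positions, so the entry of C between positions i-1 and i is C_i in 1-indexed terms and gets flipped when i is odd. The names count_by_ones_dp and solve are illustrative, and MOD is the modulus used in the editorialist's code.

```python
from math import comb

MOD = 998244353  # modulus used in the editorialist's code

def allowed(v):
    """Values a position may take: both if it is -1, otherwise just v."""
    return [0, 1] if v == -1 else [v]

def count_by_ones_dp(a):
    """c[x] mod MOD; dp[x][p] = ways for the prefix with last value p and x ones in C."""
    n = len(a)
    dp = [[0, 0] for _ in range(n)]
    for p in allowed(a[0]):
        dp[0][p] = 1
    for i in range(1, n):
        ndp = [[0, 0] for _ in range(n)]
        flip = i % 2  # 1-indexed C_i is flipped when i is odd
        for v in allowed(a[i]):
            for p in (0, 1):
                cur = v ^ p ^ flip  # value of the new entry of C
                for x in range(n - 1):
                    if dp[x][p]:
                        ndp[x + cur][v] = (ndp[x + cur][v] + dp[x][p]) % MOD
        dp = ndp
    return [(dp[x][0] + dp[x][1]) % MOD for x in range(n)]

def solve(a):
    """Sum of c_x * binom(N-1, x), reduced mod MOD."""
    n = len(a)
    c = count_by_ones_dp(a)
    return sum(cx * comb(n - 1, x) for x, cx in enumerate(c)) % MOD

print(solve([-1, -1, -1]))  # 12
```

Using math.comb keeps the sketch short; for larger N, the factorial table with modular inverses from the editorialist's code is the more usual choice.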
TIME COMPLEXITY:
\mathcal{O}(N^2) per testcase.
CODE:
Editorialist's code (PyPy3)
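```python
mod = 998244353

# Factorials modulo mod, enough for n <= 100.
fac = [1] + list(range(1, 105))
for i in range(1, 105): fac[i] = (fac[i-1] * i) % mod

def C(n, r):
    # Binomial coefficient n choose r modulo the prime mod.
    if n < r or r < 0: return 0
    return fac[n] * pow(fac[r] * fac[n-r], mod-2, mod) % mod

for _ in range(int(input())):
    n = int(input())
    a = list(map(int, input().split()))

    # dp[k][p]: ways to fill the prefix so that its last value is p
    # and the transformed array C has k ones so far.
    dp = [ [0, 0] for _ in range(n+1)]
    dp[0][0] = 1 if a[0] != 1 else 0
    dp[0][1] = 1 if a[0] != 0 else 0
    for i in range(1, n):
        ndp = [ [0, 0] for _ in range(n+1)]
        # Values allowed at position i.
        vals = [0, 1]
        if a[i] != -1: vals.remove(1 - a[i])
        for v in vals:
            for p in [0, 1]:
                # Entry of C between positions i-1 and i (0-indexed);
                # this is C_i in 1-indexed terms, flipped when i is odd.
                cur = v ^ p
                if i%2 == 1: cur ^= 1
                for k in range(n):
                    nk = k + cur
                    ndp[nk][v] = (ndp[nk][v] + dp[k][p]) % mod
        dp, ndp = ndp, dp

    # Answer = sum over k of c_k * C(n-1, k).
    ans = 0
    for i in range(n+1): ans += sum(dp[i]) * C(n-1, i)
    print(ans % mod)
```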