PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: raysh07
Tester: iceknight1093
Editorialist: iceknight1093
DIFFICULTY:
Medium
PREREQUISITES:
NTT
PROBLEM:
For a binary array S, define f(S) to be the number of distinct arrays that can be reached by performing the following operation several times:
- Choose an index i such that S_i = S_{i+2}, and flip S_{i+1}.
You’re given an array A satisfying A_i \in \{-1, 0, 1\}.
Compute the sum of f(S) across all distinct binary strings S that can be formed by replacing the -1’s in A by 0/1.
Here, N and \sum N are both \le 2\cdot 10^5.
EXPLANATION:
The easy version tells us that the base solution idea is as follows:
- To compute f(S) for a fixed S,
- Define a new array C such that C_i = S_i \oplus S_{i+1} \oplus (i\bmod 2)
- Let K denote the number of ones in C.
- Then f(S) = \binom{N-1}{K}
- To sum up across all completions of A we then ran a quadratic DP to compute the distribution of strings hitting each count-of-ones in C.
Of course, the slow part is the DP, and it needs to be optimized.
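As a small illustration of the per-string formula recalled above, here is a minimal sketch that builds C for a fixed binary string and evaluates \binom{N-1}{K}. The names countOnesInC, binomSmall and fOfString are hypothetical helpers (not from the original solution), indices are assumed to be 0-based, and the binomial is computed exactly, so this is only meant for small N.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Number of ones in C, where C_i = S_i xor S_{i+1} xor (i mod 2), 0-indexed.
int countOnesInC(const std::string& s) {
    int k = 0;
    for (std::size_t i = 0; i + 1 < s.size(); ++i) {
        int c = (s[i] - '0') ^ (s[i + 1] - '0') ^ static_cast<int>(i % 2);
        k += c;
    }
    return k;
}

// Exact binomial coefficient via one row of Pascal's triangle (small n only).
std::uint64_t binomSmall(int n, int r) {
    assert(n >= 0);
    if (r < 0 || r > n) return 0;
    std::vector<std::uint64_t> row(n + 1, 0);
    row[0] = 1;
    for (int i = 1; i <= n; ++i)
        for (int j = i; j >= 1; --j) row[j] += row[j - 1];
    return row[r];
}

// f(S) = binom(N-1, K), the formula quoted from the easy version.
std::uint64_t fOfString(const std::string& s) {
    assert(!s.empty());
    return binomSmall(static_cast<int>(s.size()) - 1, countOnesInC(s));
}
```

For instance, fOfString("0101") evaluates to 3, which matches the three reachable arrays 0101, 0001 and 0111.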
To optimize the solution, note that C_i = S_i \oplus S_{i+1} \oplus (i\bmod 2) is a “local” quantity, in the sense that it only depends on a couple of surrounding values.
This allows us to break up the array into several ‘blocks’, and solve for each block separately.
That is, suppose i and j are two indices such that:
- A_i and A_j are fixed values, i.e. neither is -1
- A_k = -1 for all i \lt k \lt j.
Then, observe that every choice of fixing the values at indices [i+1, j-1] will affect only the C_k values for k \in [i, j-1]. Nothing else will be affected.
In particular this means the count of ones at these positions will also be independent of all other similar ‘blocks’, and so we can simply compute its distribution for now and merge all the distributions later.
So, what exactly is the distribution for a block of the form [A_i, -1, -1, \ldots, -1, A_j]?
To answer this, we need to understand what constraints are placed on the array C.
Recall that we defined the array B as B_i = A_i\oplus A_{i+1}, and C_i = B_i \oplus (i\bmod 2).
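In particular, via the array B, the intermediate values cancel in pairs and we see that
B_i \oplus B_{i+1} \oplus \ldots \oplus B_{j-1} = A_i \oplus A_j
Since C_i = B_i \oplus (i\bmod 2) \iff B_i = C_i \oplus (i\bmod 2), we obtain
(C_i \oplus (i\bmod 2)) \oplus \ldots \oplus (C_{j-1} \oplus ((j-1)\bmod 2)) = A_i \oplus A_j
Note that this then fixes the value of C_i\oplus\ldots\oplus C_{j-1} to
A_i \oplus A_j \oplus ((i\bmod 2)\oplus\ldots\oplus ((j-1)\bmod 2))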
In particular, this tells us whether the number of ones in C, between indices i and j-1, must be even or odd: only one of the parities is valid (depending on A_i, A_j, i, and j).
However, once the parity of ones is known, any number of them (and in any distribution) can be attained in this range.
That is, if we let p = A_i\oplus A_j\oplus ((i\bmod 2)\oplus\ldots\oplus ((j-1)\bmod 2)), then the counts of ones in this range can be p, p+2, p+4, \ldots
Further, if we want x ones, there are \binom{j-i}{x} ways to choose their positions; and all such choices are valid.
For a block of length L, we can compute its distribution in \mathcal{O}(L) time (assuming binomial coefficients are found in constant time.)
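The distribution of a single two-sided block can be sketched as follows. The function name blockDistribution is hypothetical, indices are assumed to be 0-based, and exact 64-bit binomials stand in for modular ones, so this is only an illustration for short blocks.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Entry x of the result = number of ways to fill the -1's strictly between
// the fixed positions i < j so that C_i..C_{j-1} contains exactly x ones.
std::vector<std::uint64_t> blockDistribution(int i, int j, int ai, int aj) {
    assert(0 <= i && i < j);
    assert((ai == 0 || ai == 1) && (aj == 0 || aj == 1));
    int len = j - i;  // number of C positions covered by this block
    // Required parity p = A_i xor A_j xor (i mod 2) xor ... xor ((j-1) mod 2).
    int p = ai ^ aj;
    for (int k = i; k < j; ++k) p ^= (k % 2);
    // Row `len` of Pascal's triangle, computed exactly (small len only).
    std::vector<std::uint64_t> row(len + 1, 0);
    row[0] = 1;
    for (int n = 1; n <= len; ++n)
        for (int r = n; r >= 1; --r) row[r] += row[r - 1];
    // Keep only the counts with the valid parity.
    for (int x = 0; x <= len; ++x)
        if (x % 2 != p) row[x] = 0;
    return row;
}
```

For example, the block [0, -1, -1, 1] at indices 0 to 3 gives the distribution 1, 0, 3, 0: one filling produces no ones in C and three fillings produce two ones.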
We now need to merge the results of different blocks.
To understand how to do this, let’s assume we have only two blocks.
Since the blocks are independent, having x ones in one block and y ones in the other gives us a total of x+y ones; and the number of ways is obtained by simply multiplying the respective counts.
This is just a convolution of the values corresponding to the two blocks, assuming we represent them as polynomials (where the coefficient of x^i is the number of ways of having i ones in the C array).
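To make the two-block merge concrete, here is a minimal quadratic convolution. The name convolve is hypothetical, and the modulus 998244353 is an assumption taken from the editorialist's code rather than from the statement text; a real solution would replace the double loop with an NTT-based product.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::uint64_t MOD = 998244353;  // assumed modulus

// res[x + y] accumulates a[x] * b[y]: x ones in one block, y in the other.
std::vector<std::uint64_t> convolve(const std::vector<std::uint64_t>& a,
                                    const std::vector<std::uint64_t>& b) {
    assert(!a.empty() && !b.empty());
    std::vector<std::uint64_t> res(a.size() + b.size() - 1, 0);
    for (std::size_t x = 0; x < a.size(); ++x)
        for (std::size_t y = 0; y < b.size(); ++y)
            res[x + y] = (res[x + y] + (a[x] % MOD) * (b[y] % MOD)) % MOD;
    return res;
}
```

Merging a block with distribution 0, 2, 0 and a block with distribution 1, 1 yields 0, 2, 2, 0: two ways to get one 1 in total and two ways to get two.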
In general, we have several blocks, and we want to convolve them all to obtain the overall answer.
This can be done in \mathcal{O}(N\log ^2 N) time using a couple of different methods:
- Repeatedly take the two lowest-degree remaining polynomials and multiply them; or
- Run a divide-and-conquer on the polynomials: split the polynomials into two halves, recursively find the answer for each half, and then multiply the results.
Either method can be shown to run in \mathcal{O}(N\log ^2 N) time, because the sum of degrees of all polynomials we have is \le N.
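The first strategy (always multiplying the two smallest remaining polynomials) can be sketched with a priority queue. The names Poly, multiplyNaive and mergeSmallestFirst are hypothetical, the modulus is assumed from the editorialist's code, and the naive product is only a stand-in: the stated complexity requires an NTT-based multiplication in its place.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

using Poly = std::vector<std::uint64_t>;
constexpr std::uint64_t MOD = 998244353;  // assumed modulus

// Quadratic product; swap in an NTT-based multiply for real inputs.
Poly multiplyNaive(const Poly& a, const Poly& b) {
    Poly res(a.size() + b.size() - 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b.size(); ++j)
            res[i + j] = (res[i + j] + (a[i] % MOD) * (b[j] % MOD)) % MOD;
    return res;
}

// Repeatedly pops the two shortest polynomials and pushes their product.
Poly mergeSmallestFirst(std::vector<Poly> polys) {
    assert(!polys.empty());
    for (const Poly& p : polys) assert(!p.empty());
    // Comparator makes the queue a min-heap on polynomial length.
    auto longer = [](const Poly& x, const Poly& y) { return x.size() > y.size(); };
    std::priority_queue<Poly, std::vector<Poly>, decltype(longer)> pq(
        longer, std::move(polys));
    while (pq.size() > 1) {
        Poly a = pq.top(); pq.pop();
        Poly b = pq.top(); pq.pop();
        pq.push(multiplyNaive(a, b));
    }
    return pq.top();
}
```

The editorialist's code uses the divide-and-conquer variant instead; both orders produce the same final polynomial, since the product does not depend on the order of multiplication.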
There are a couple more details to take care of.
First, note that we only considered blocks of the form [A_i, -1, -1, \ldots, -1, A_j], i.e. with both ends having fixed values.
However, there are (up to) two blocks that aren’t of this form, namely a prefix and a suffix of the array.
These are not too hard to deal with, however: since we’re constrained on only one side instead of both, we simply don’t have the parity condition at all and there are just \binom{L}{x} ways to have x ones for every 0 \le x \le L.
This gives us another couple of polynomials to work with, which can just be thrown on to the pile before multiplying everything.
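Finally, the case of A consisting entirely of -1’s is also special and needs to be handled separately, since it doesn’t fall into any of the previous cases. Here all N-1 values of C are free, and each array C arises from exactly two strings (one for each choice of the first element), so there are 2\binom{N-1}{x} ways to have x ones for every 0 \le x \le N-1.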
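When implementing the block decomposition, a brute-force reference is handy for checking the edge cases (prefix, suffix, all -1’s) on tiny inputs. The sketch below enumerates every completion of A and sums \binom{N-1}{K} directly; bruteForceSum is a hypothetical helper, it uses 0-based indices and exact arithmetic without any modulus, and it is only usable for small N.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sum of binom(N-1, K) over all completions of a, by direct enumeration.
std::uint64_t bruteForceSum(const std::vector<int>& a) {
    int n = static_cast<int>(a.size());
    assert(n >= 1 && n <= 20);
    // Row n-1 of Pascal's triangle.
    std::vector<std::uint64_t> row(n, 0);
    row[0] = 1;
    for (int i = 1; i <= n - 1; ++i)
        for (int j = i; j >= 1; --j) row[j] += row[j - 1];
    std::uint64_t total = 0;
    for (std::uint32_t mask = 0; mask < (1u << n); ++mask) {
        // Bit i of mask is the candidate value of S_i; skip mismatches with a.
        bool ok = true;
        for (int i = 0; i < n && ok; ++i) {
            int bit = static_cast<int>((mask >> i) & 1u);
            if (a[i] != -1 && a[i] != bit) ok = false;
        }
        if (!ok) continue;
        // K = number of ones in C_i = S_i xor S_{i+1} xor (i mod 2).
        int k = 0;
        for (int i = 0; i + 1 < n; ++i) {
            int si = static_cast<int>((mask >> i) & 1u);
            int sj = static_cast<int>((mask >> (i + 1)) & 1u);
            k += si ^ sj ^ (i % 2);
        }
        total += row[k];
    }
    return total;
}
```

For A = [-1, -1, -1] this returns 12, which agrees with the all -1’s count: each of the four possible arrays C arises from two strings, giving 2(1 + 2 + 2 + 1) = 12.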
TIME COMPLEXITY:
\mathcal{O}(N \log^2 N) per testcase.
CODE:
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
// https://judge.yosupo.jp/submission/69895
namespace ntt {
template <class T, class F = multiplies<T>>
T power(T a, long long n, F op = multiplies<T>(), T e = {1}) {
// assert(n >= 0);
T res = e;
while (n) {
if (n & 1) res = op(res, a);
if (n >>= 1) a = op(a, a);
}
return res;
}
constexpr int mod = int(1e9) + 7;
constexpr int nttmod = 998'244'353;
template <std::uint32_t P>
struct ModInt32 {
public:
using i32 = std::int32_t;
using u32 = std::uint32_t;
using i64 = std::int64_t;
using u64 = std::uint64_t;
using m32 = ModInt32;
using internal_value_type = u32;
private:
u32 v;
static constexpr u32 get_r() {
u32 iv = P;
for (u32 i = 0; i != 4; ++i) iv *= 2U - P * iv;
return -iv;
}
static constexpr u32 r = get_r(), r2 = -u64(P) % P;
static_assert((P & 1) == 1);
static_assert(-r * P == 1);
static_assert(P < (1 << 30));
static constexpr u32 pow_mod(u32 x, u64 y) {
u32 res = 1;
for (; y != 0; y >>= 1, x = u64(x) * x % P)
if (y & 1) res = u64(res) * x % P;
return res;
}
static constexpr u32 reduce(u64 x) {
return (x + u64(u32(x) * r) * P) >> 32;
}
static constexpr u32 norm(u32 x) { return x - (P & -(x >= P)); }
public:
static constexpr u32 get_pr() {
u32 tmp[32] = {}, cnt = 0;
const u64 phi = P - 1;
u64 m = phi;
for (u64 i = 2; i * i <= m; ++i)
if (m % i == 0) {
tmp[cnt++] = i;
while (m % i == 0) m /= i;
}
if (m != 1) tmp[cnt++] = m;
for (u64 res = 2; res != P; ++res) {
bool flag = true;
for (u32 i = 0; i != cnt && flag; ++i)
flag &= pow_mod(res, phi / tmp[i]) != 1;
if (flag) return res;
}
return 0;
}
constexpr ModInt32() : v(0){};
~ModInt32() = default;
constexpr ModInt32(u32 _v) : v(reduce(u64(_v) * r2)) {}
constexpr ModInt32(i32 _v) : v(reduce(u64(_v % P + P) * r2)) {}
constexpr ModInt32(u64 _v) : v(reduce((_v % P) * r2)) {}
constexpr ModInt32(i64 _v) : v(reduce(u64(_v % P + P) * r2)) {}
constexpr ModInt32(const m32& rhs) : v(rhs.v) {}
constexpr u32 get() const { return norm(reduce(v)); }
explicit constexpr operator u32() const { return get(); }
explicit constexpr operator i32() const { return i32(get()); }
constexpr m32& operator=(const m32& rhs) { return v = rhs.v, *this; }
constexpr m32 operator-() const {
m32 res;
return res.v = (P << 1 & -(v != 0)) - v, res;
}
constexpr m32 inv() const { return pow(P - 2); }
constexpr m32& operator+=(const m32& rhs) {
return v += rhs.v - (P << 1), v += P << 1 & -(v >> 31), *this;
}
constexpr m32& operator-=(const m32& rhs) {
return v -= rhs.v, v += P << 1 & -(v >> 31), *this;
}
constexpr m32& operator*=(const m32& rhs) {
return v = reduce(u64(v) * rhs.v), *this;
}
constexpr m32& operator/=(const m32& rhs) {
return this->operator*=(rhs.inv());
}
friend m32 operator+(const m32& lhs, const m32& rhs) {
return m32(lhs) += rhs;
}
friend m32 operator-(const m32& lhs, const m32& rhs) {
return m32(lhs) -= rhs;
}
friend m32 operator*(const m32& lhs, const m32& rhs) {
return m32(lhs) *= rhs;
}
friend m32 operator/(const m32& lhs, const m32& rhs) {
return m32(lhs) /= rhs;
}
friend bool operator==(const m32& lhs, const m32& rhs) {
return norm(lhs.v) == norm(rhs.v);
}
friend bool operator!=(const m32& lhs, const m32& rhs) {
return norm(lhs.v) != norm(rhs.v);
}
friend std::istream& operator>>(std::istream& is, m32& rhs) {
return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
}
friend std::ostream& operator<<(std::ostream& os, const m32& rhs) {
return os << rhs.get();
}
constexpr m32 pow(i64 y) const {
// assumes P is a prime
i64 rem = y % (P - 1);
if (y > 0 && rem == 0)
y = P - 1;
else
y = rem;
m32 res(1), x(*this);
for (; y != 0; y >>= 1, x *= x)
if (y & 1) res *= x;
return res;
}
};
using mint = ModInt32<nttmod>;
void ntt(vector<mint>& a, bool inverse) {
static array<mint, 30> dw{}, idw{};
if (dw[0] == 0) {
mint root = 2;
while (power(root, (nttmod - 1) / 2) == 1) root += 1;
for (int i = 0; i < 30; ++i)
dw[i] = -power(root, (nttmod - 1) >> (i + 2)),
idw[i] = 1 / dw[i];
}
int n = (int)a.size();
assert((n & (n - 1)) == 0);
if (not inverse) {
for (int m = n; m >>= 1;) {
mint w = 1;
for (int s = 0, k = 0; s < n; s += 2 * m) {
for (int i = s, j = s + m; i < s + m; ++i, ++j) {
auto x = a[i], y = a[j] * w;
a[i] = x + y;
a[j] = x - y;
}
w *= dw[__builtin_ctz(++k)];
}
}
} else {
for (int m = 1; m < n; m *= 2) {
mint w = 1;
for (int s = 0, k = 0; s < n; s += 2 * m) {
for (int i = s, j = s + m; i < s + m; ++i, ++j) {
auto x = a[i], y = a[j];
a[i] = x + y;
a[j] = (x - y) * w;
}
w *= idw[__builtin_ctz(++k)];
}
}
auto inv = 1 / mint(n);
for (auto&& e : a) e *= inv;
}
}
vector<mint> operator*(vector<mint> l, vector<mint> r) {
if (l.empty() or r.empty()) return {};
int n = (int)l.size(), m = (int)r.size(),
sz = 1 << __lg(2 * (n + m - 1) - 1);
if (min(n, m) < 30) {
vector<mint> res(n + m - 1);
for (int i = 0; i < n; ++i)
for (int j = 0; j < m; ++j) res[i + j] += (l[i] * r[j]);
return {begin(res), end(res)};
}
bool eq = l == r;
l.resize(sz), ntt(l, false);
if (eq)
r = l;
else
r.resize(sz), ntt(r, false);
for (int i = 0; i < sz; ++i) l[i] *= r[i];
ntt(l, true), l.resize(n + m - 1);
return l;
}
vector<mint>& operator*=(vector<mint>& l, vector<mint> r) {
if (l.empty() or r.empty()) {
l.clear();
return l;
}
int n = (int)l.size(), m = (int)r.size(),
sz = 1 << __lg(2 * (n + m - 1) - 1);
if (min(n, m) < 30) {
vector<mint> res(n + m - 1);
for (int i = 0; i < n; ++i)
for (int j = 0; j < m; ++j) res[i + j] += (l[i] * r[j]);
l = {begin(res), end(res)};
return l;
}
bool eq = l == r;
l.resize(sz), ntt(l, false);
if (eq)
r = l;
else
r.resize(sz), ntt(r, false);
for (int i = 0; i < sz; ++i) l[i] *= r[i];
ntt(l, true), l.resize(n + m - 1);
return l;
}
} // namespace ntt
using ntt::mint;
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    // Factorials modulo 998244353, used for binomial coefficients.
    vector<mint> fact(200005);
    fact[0] = 1;
    for (int i = 1; i < 200005; ++i) fact[i] = fact[i-1] * i;
    auto C = [&] (int n, int r) {
        if (n < r or r < 0) return mint(0);
        return fact[n] / (fact[r] * fact[n-r]);
    };
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector a(n, 0);
        for (int &x : a) cin >> x;
        // One polynomial per block; the coefficient of x^k counts the ways
        // to get k ones in that block's part of C.
        vector<vector<mint>> polys;
        for (int i = 0; i+1 < n; ++i) {
            if (a[i] == -1) continue;
            // Block starting at the fixed index i; j is the next fixed index (or n).
            int j = i+1;
            while (j < n and a[j] == -1) ++j;
            int len = j - i;
            vector<mint> poly(len+1);
            if (j < n) {
                // Both ends fixed: only counts with parity par are possible.
                int par = a[i] ^ a[j] ^ ((i/2)%2 != (j/2)%2);
                for (j = par; j <= len; j += 2) poly[j] = C(len, j);
            }
            else {
                // Suffix block: only len-1 positions of C, no parity condition.
                --len;
                for (j = 0; j <= len; ++j) poly[j] = C(len, j);
            }
            polys.push_back(poly);
        }
        if (a[0] == -1) {
            // Prefix of -1's: i becomes the first fixed index (or n).
            int i = 1;
            while (i < n and a[i] == -1) ++i;
            vector<mint> poly(i+1);
            if (i < n) {
                for (int j = 0; j <= i; ++j) poly[j] = C(i, j);
            }
            else {
                // All -1: n-1 free positions of C, each array C arising from two strings.
                --i;
                for (int j = 0; j <= i; ++j) poly[j] = 2*C(i, j);
            }
            polys.push_back(poly);
        }
        // Divide and conquer: multiply all block polynomials together.
        auto rec = [&] (const auto &self, int L, int R) -> vector<mint> {
            if (L == R) return polys[L];
            int mid = (L+R)/2;
            auto left = self(self, L, mid);
            auto right = self(self, mid+1, R);
            return left*right;
        };
        auto res = rec(rec, 0, size(polys)-1);
        // Weight each total count i of ones by f = C(n-1, i).
        mint ans = 0;
        for (int i = 0; i < size(res); ++i) {
            ans += res[i] * C(n-1, i);
        }
        cout << ans << '\n';
    }
}