PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: iceknight1093
Tester: apoorv_me
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Combinatorics, prefix sums
PROBLEM:
An array is called prefix-balanced if, for each of its prefixes, no two elements’ frequencies differ by more than 1.
You’re given a partially-filled array with elements between 1 and M (empty positions are denoted by 0).
Find the number of ways to fill in the array such that it’s prefix-balanced.
EXPLANATION:
The solution will continue from the easier version, so it’s recommended to read that editorial first if you haven’t (link).
To recap, an array containing K distinct elements is prefix-balanced if and only if, when it is broken into consecutive blocks of length K (the last block possibly shorter), every block consists of distinct elements.
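A small brute-force check of this recap may help. The sketch below assumes that the frequencies in the definition are taken over every value occurring anywhere in the array, so a value absent from a prefix has frequency 0 there. That is the reading under which the block characterization holds. The function names are illustrative only.

```python
from itertools import product


def is_prefix_balanced(arr):
    """Definition, assuming frequencies range over every value in arr
    (a value absent from the current prefix counts as frequency 0)."""
    counts = dict.fromkeys(set(arr), 0)
    for v in arr:
        counts[v] += 1
        if max(counts.values()) - min(counts.values()) > 1:
            return False
    return True


def blocks_distinct(arr):
    """Recap criterion for a non-empty arr: cut it into blocks of length
    K = number of distinct values (last block possibly shorter); every
    block must be duplicate-free."""
    k = len(set(arr))
    return all(len(set(arr[s:s + k])) == len(arr[s:s + k])
               for s in range(0, len(arr), k))


# Exhaustive comparison over all arrays of length 1..6 with values in 1..3.
for n in range(1, 7):
    for arr in product(range(1, 4), repeat=n):
        assert is_prefix_balanced(arr) == blocks_distinct(arr), arr
```

For example, (1, 2, 3, 2) passes both tests, while (1, 2, 1, 3) fails both: its first block (1, 2, 1) repeats a value.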
Looking back at the solution to the easy version, we did the following:
- Fix K, the size of S_A (the set of distinct values appearing in the final array).
- Fix the elements of S_A.
- Fix an arrangement of them within each block.
Let’s try to recreate the same idea here.
Fixing K remains the same, so let’s do that.
When fixing the elements of S_A, however, we need to be a bit careful: some values already appear in A, so those must be taken into account.
Specifically, if A already contains x distinct non-zero values, we can only choose another (K - x) elements to reach a size of K.
Further, this choice must be made from the (M-x) elements that aren’t already in A; for a total of \binom{M-x}{K-x} choices.
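Here is a quick sanity check of the \binom{M-x}{K-x} count, written as a Python sketch with hypothetical function names. It compares the formula against direct enumeration of value sets for small M. The parameter `present` stands for the set of non-zero values already in A.

```python
from itertools import combinations
from math import comb


def count_value_sets(m, present, k):
    """Sets S of size k with present <= S <= {1..m}: the other k - x members
    are chosen from the m - x values that are not already in A."""
    x = len(present)
    return comb(m - x, k - x) if k >= x else 0


def count_value_sets_brute(m, present, k):
    """Direct enumeration of all k-subsets of {1..m} that contain `present`."""
    return sum(1 for s in combinations(range(1, m + 1), k)
               if present <= set(s))


for m in range(1, 7):
    for present in ({1}, {2, 3}, {1, 4, 5}):
        if max(present) <= m:
            for k in range(m + 1):
                assert count_value_sets(m, present, k) == count_value_sets_brute(m, present, k)
```

For instance, with M = 5, present values {2, 4} and K = 3, the third value must come from {1, 3, 5}, giving \binom{3}{1} = 3 sets.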
Next, we look at the blocks. Unlike in the easy version, they are no longer homogeneous (each may already contain some fixed values), so we need to handle each of them separately.
Consider a block spanning indices i to \min(N, i+K-1).
For such a block,
- If it already contains duplicate non-zero elements, the condition already fails, and there’s no way at all to have |S_A| = K.
- Otherwise, suppose there are y distinct elements already present, and z zeros.
We can then choose which z of the K-y remaining values of S_A fill the zeros, and arrange them in any of z! orders, for a total of \binom{K-y}{z} \cdot z! ways.
(Note that if the block has size K, we’ll have K-y = z and so \binom{K-y}{z} = 1; however, stating it this way lets us handle the last block without special-casing it.)
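The per-block factor \binom{K-y}{z} \cdot z! can be checked the same way. In this sketch (names hypothetical), the values 1..k stand in for S_A. Every block length up to k is tried, so the shorter last block is covered too.

```python
from itertools import product
from math import comb, factorial


def block_ways(k, y, z):
    """Fillings of one block: choose which z of the k - y values of S_A not yet
    in the block appear, then order them over the z empty slots."""
    return comb(k - y, z) * factorial(z)


def block_ways_brute(k, block):
    """Try every assignment of values 1..k (standing in for S_A) to the zeros
    of `block` and keep those that leave the block duplicate-free."""
    z = block.count(0)
    good = 0
    for fill in product(range(1, k + 1), repeat=z):
        it = iter(fill)
        filled = [v if v else next(it) for v in block]
        good += len(set(filled)) == len(filled)
    return good


# Blocks of every length up to k whose fixed values are distinct members of 1..k.
for k in range(1, 5):
    for length in range(1, k + 1):
        for block in product(range(k + 1), repeat=length):
            fixed = [v for v in block if v]
            if len(set(fixed)) == len(fixed):
                assert block_ways(k, len(fixed), block.count(0)) == block_ways_brute(k, block)
```

For example, with K = 4 the block (1, 0, 3, 0) has y = 2 and z = 2, giving 2 fillings. The short last block (0, 2) has y = 1 and z = 1, giving \binom{3}{1} = 3 fillings.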
Also, note that the starting points of the blocks of size K will be 1, K+1, 2K+1, \ldots
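There will be exactly \left\lceil \frac{N}{K} \right\rceil such blocks, so across all K, the number of blocks we process will be \sum_{K=1}^{N} \left\lceil \frac{N}{K} \right\rceil.
Since \left\lceil \frac{N}{K} \right\rceil \leq \frac{N}{K} + 1, this sum is at most N\cdot H_N + N = \mathcal{O}(N\log N), where H_N is the N-th harmonic number. So as long as we’re able to check each block fast enough, the blocks being non-homogeneous doesn’t matter!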
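To get a feel for the constant, this small sketch (hypothetical helper names) evaluates the block count exactly and compares it with the harmonic bound. For instance, N = 10 gives 10+5+4+3+2+2+2+2+2+1 = 33 blocks over all K.

```python
def blocks_processed(n):
    """Total blocks examined over all K = 1..n, i.e. the sum of ceil(n / K)."""
    return sum((n + k - 1) // k for k in range(1, n + 1))


def harmonic_bound(n):
    """Upper bound n + n * H_n, from ceil(n / K) <= n / K + 1."""
    return n + n * sum(1 / k for k in range(1, n + 1))


for n in (1, 10, 1000, 10**5):
    assert blocks_processed(n) <= harmonic_bound(n)
```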
From the above discussion, we’re left with two tasks for each block: check whether it contains repeated non-zero elements, and if it doesn’t, count the distinct non-zero elements in it. Both need to be fast, ideally constant time, since there’s already a factor of N\log N.
Notice that the second part is actually quite trivial if we can achieve the first: after all, if the non-zero elements aren’t repeated, then the number of distinct non-zero elements simply equals the number of non-zero elements present in the range!
That count can be obtained in \mathcal{O}(1) per block using prefix sums of the number of non-zero elements.
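Here is a minimal sketch of that prefix-sum step in Python. It uses 1-indexed, inclusive ranges like the text, and the helper names are illustrative, not taken from the original code.

```python
def nonzero_prefix(a):
    """pref[i] = number of non-zero entries among the first i elements of a."""
    pref = [0] * (len(a) + 1)
    for i, v in enumerate(a):
        pref[i + 1] = pref[i] + (v != 0)
    return pref


def nonzero_in_range(pref, l, r):
    """Non-zero entries in the 1-indexed, inclusive range [l, r], in O(1)."""
    return pref[r] - pref[l - 1]


a = [3, 0, 1, 0, 0, 2]
pref = nonzero_prefix(a)
print(nonzero_in_range(pref, 2, 5))  # 1: positions 2..5 hold [0, 1, 0, 0]
```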
So, all we’re really left with is checking whether the non-zero elements in some range are distinct or not.
To do that, let’s precompute for each index i the position R_i: the smallest index \geq i such that the range [i, R_i] contains a repeated non-zero element, or N+1 if no such index exists.
This can be computed in \mathcal{O}(N+M) time as follows:
- Iterate i from N down to 1.
- If A_i = 0, R_i = R_{i+1} (with R_{N+1} = N+1 as the base case).
- Otherwise, R_i = \min(R_{i+1}, j), where j \gt i is the first position such that A_j = A_i (or N+1 if there is none).
This nearest next occurrence can be maintained during the same backward sweep by remembering, for each value, the most recent position where it was seen, which gives \mathcal{O}(1) lookup.
The logic is straightforward: the repeated pair either involves A_i itself, in which case we look at its nearest next occurrence; or it involves something else, in which case it is also a repeat for the subarray starting at index i+1.
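The backward sweep can be sketched in Python as follows, together with a quadratic reference implementation to compare against. The array names R and nxt are chosen for illustration; the editorialist's code calls them next_dup and last_seen.

```python
from itertools import product


def first_repeat_end(a, m):
    """R[i] for 1-indexed i: the smallest j >= i such that a[i..j] contains a
    repeated non-zero value, or n + 1 if there is none. Values lie in 0..m."""
    n = len(a)
    R = [n + 1] * (n + 2)      # R[n + 1] = n + 1 is the base case
    nxt = [n + 1] * (m + 1)    # nxt[v] = nearest position > i holding v so far
    for i in range(n, 0, -1):
        v = a[i - 1]
        if v == 0:
            R[i] = R[i + 1]
        else:
            R[i] = min(R[i + 1], nxt[v])
            nxt[v] = i
    return R


def first_repeat_end_brute(a):
    """Quadratic reference: scan right from each i until a non-zero repeats."""
    n = len(a)
    R = [n + 1] * (n + 2)
    for i in range(1, n + 1):
        seen = set()
        for j in range(i, n + 1):
            v = a[j - 1]
            if v and v in seen:
                R[i] = j
                break
            seen.add(v)
    return R


for n in range(1, 7):
    for a in product(range(3), repeat=n):
        assert first_repeat_end(a, 2) == first_repeat_end_brute(a)
```

For A = [1, 0, 2, 1, 2], the values R_1, \ldots, R_5 come out as 4, 5, 5, 6, 6.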
With this in hand, checking whether some subarray [i, j] contains repeat elements is trivial: just check if R_i \leq j or not!
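Putting the pieces together, the following sketch counts valid fillings exactly and compares the result with brute force over all fillings. It uses no modulus, so it is only suitable for small inputs. It assumes that 0 marks an empty slot and that frequencies are taken over every value in the array, with absent values counting as 0. The function names are hypothetical; this is a test harness, not a replacement for the reference solutions.

```python
from itertools import product
from math import comb, factorial


def count_fillings(a, m):
    """Exact count (no modulus) following the steps above:
    sum over K of C(m - x, K - x) * prod over blocks of C(K - y, z) * z!."""
    n = len(a)
    x = len({v for v in a if v})
    pref = [0] * (n + 1)                    # prefix counts of non-zero entries
    for i, v in enumerate(a):
        pref[i + 1] = pref[i] + (v != 0)
    R = [n + 1] * (n + 2)                   # first index closing a repeat
    nxt = [n + 1] * (m + 1)
    for i in range(n, 0, -1):
        v = a[i - 1]
        R[i] = R[i + 1] if v == 0 else min(R[i + 1], nxt[v])
        if v:
            nxt[v] = i
    total = 0
    for k in range(max(1, x), min(n, m) + 1):
        ways = comb(m - x, k - x)
        for l in range(1, n + 1, k):
            r = min(n, l + k - 1)
            if R[l] <= r:                   # repeated non-zero value in [l, r]
                ways = 0
                break
            y = pref[r] - pref[l - 1]       # non-zero entries, all distinct here
            z = (r - l + 1) - y
            ways *= comb(k - y, z) * factorial(z)
        total += ways
    return total


def is_prefix_balanced(arr):
    counts = dict.fromkeys(set(arr), 0)
    for v in arr:
        counts[v] += 1
        if max(counts.values()) - min(counts.values()) > 1:
            return False
    return True


def count_brute(a, m):
    """Try every filling of the zeros with values 1..m."""
    zeros = [i for i, v in enumerate(a) if v == 0]
    total = 0
    for fill in product(range(1, m + 1), repeat=len(zeros)):
        b = list(a)
        for i, v in zip(zeros, fill):
            b[i] = v
        total += is_prefix_balanced(b)
    return total


for m in range(1, 4):
    for n in range(1, 5):
        for a in product(range(m + 1), repeat=n):
            assert count_fillings(a, m) == count_brute(a, m), (a, m)
```

For example, with M = 2 the array [0, 1, 0] has 3 valid fillings: [1, 1, 1], [2, 1, 1] and [2, 1, 2].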
TIME COMPLEXITY:
\mathcal{O}(N\log N + M) per testcase.
CODE:
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
struct input_checker {
string buffer;
int pos;
const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const string lower = "abcdefghijklmnopqrstuvwxyz";
input_checker() {
pos = 0;
while (true) {
int c = cin.get();
if (c == -1) {
break;
}
buffer.push_back((char) c);
}
}
int nextDelimiter() {
int now = pos;
while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
now++;
}
return now;
}
string readOne() {
assert(pos < (int) buffer.size());
int nxt = nextDelimiter();
string res;
while (pos < nxt) {
res += buffer[pos];
pos++;
}
return res;
}
string readString(int minl, int maxl, const string &pattern = "") {
assert(minl <= maxl);
string res = readOne();
assert(minl <= (int) res.size());
assert((int) res.size() <= maxl);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
}
return res;
}
int readInt(int minv, int maxv) {
assert(minv <= maxv);
int res = stoi(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}
long long readLong(long long minv, long long maxv) {
assert(minv <= maxv);
long long res = stoll(readOne());
assert(minv <= res);
assert(res <= maxv);
return res;
}
auto readInts(int n, int minv, int maxv) {
assert(n >= 0);
vector<int> v(n);
for (int i = 0; i < n; ++i) {
v[i] = readInt(minv, maxv);
if (i+1 < n) readSpace();
}
return v;
}
auto readLongs(int n, long long minv, long long maxv) {
assert(n >= 0);
vector<long long> v(n);
for (int i = 0; i < n; ++i) {
v[i] = readLong(minv, maxv);
if (i+1 < n) readSpace();
}
return v;
}
void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
pos++;
}
void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
pos++;
}
void readEof() {
assert((int) buffer.size() == pos);
}
} inp;
namespace mint_ns {
template<auto P>
struct Modular {
using value_type = decltype(P);
value_type value;
Modular(long long k = 0) : value(norm(k)) {}
friend Modular<P>& operator += ( Modular<P>& n, const Modular<P>& m) { n.value += m.value; if (n.value >= P) n.value -= P; return n; }
friend Modular<P> operator + (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r += m; }
friend Modular<P>& operator -= ( Modular<P>& n, const Modular<P>& m) { n.value -= m.value; if (n.value < 0) n.value += P; return n; }
friend Modular<P> operator - (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r -= m; }
friend Modular<P> operator - (const Modular<P>& n) { return Modular<P>(-n.value); }
friend Modular<P>& operator *= ( Modular<P>& n, const Modular<P>& m) { n.value = n.value * 1ll * m.value % P; return n; }
friend Modular<P> operator * (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r *= m; }
friend Modular<P>& operator /= ( Modular<P>& n, const Modular<P>& m) { return n *= m.inv(); }
friend Modular<P> operator / (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r /= m; }
Modular<P>& operator ++ ( ) { return *this += 1; }
Modular<P>& operator -- ( ) { return *this -= 1; }
Modular<P> operator ++ (int) { Modular<P> r = *this; *this += 1; return r; }
Modular<P> operator -- (int) { Modular<P> r = *this; *this -= 1; return r; }
friend bool operator == (const Modular<P>& n, const Modular<P>& m) { return n.value == m.value; }
friend bool operator != (const Modular<P>& n, const Modular<P>& m) { return n.value != m.value; }
explicit operator int() const { return value; }
explicit operator bool() const { return value; }
explicit operator long long() const { return value; }
constexpr static value_type mod() { return P; }
value_type norm(long long k) {
if (!(-P <= k && k < P)) k %= P;
if (k < 0) k += P;
return k;
}
Modular<P> inv() const {
value_type a = value, b = P, x = 0, y = 1;
while (a != 0) { value_type k = b / a; b -= k * a; x -= k * y; swap(a, b); swap(x, y); }
return Modular<P>(x);
}
friend void __print(Modular<P> x) {
cerr << x;
}
};
template<auto P> Modular<P> pow(Modular<P> m, long long p) {
Modular<P> r(1);
while (p) {
if (p & 1) r *= m;
m *= m;
p >>= 1;
}
return r;
}
template<auto P> ostream& operator << (ostream& o, const Modular<P>& m) { return o << m.value; }
template<auto P> istream& operator >> (istream& i, Modular<P>& m) { long long k; i >> k; m.value = m.norm(k); return i; }
template<auto P> string to_string(const Modular<P>& m) { return to_string(m.value); }
}
constexpr int mod = 998244353;
using mod_int = mint_ns::Modular<mod>;
using mi = mod_int;
constexpr int maxn = 1e6 + 3;
vector<mi> fct(maxn, 1), invf(maxn, 1);
void calc_fact() {
for(int i = 1 ; i < maxn ; i++) {
fct[i] = fct[i - 1] * i;
}
invf.back() = mi(1) / fct.back();
for(int i = maxn - 1 ; i ; i--)
invf[i - 1] = i * invf[i];
}
mi choose(int n, int r) { // choose r elements out of n elements
if(r > n) return mi(0);
assert(r <= n);
return fct[n] * invf[r] * invf[n - r];
}
mi place(int n, int r) { // number of solutions of x1 + x2 + ... + xr = n with every xi >= 0
assert(r > 0);
return choose(n + r - 1, r - 1);
}
template<typename T>
struct BIT {
vector<T> tree; int N;
BIT(int N_ = 0) {
N = N_;
tree.resize(N + 1);
}
void update(int ind, T val) {
for(++ind ; ind <= N ; ind += ind & -ind)
tree[ind] += val;
}
T query(int ind) {
T sum = 0;
for(++ind ; ind > 0 ; ind -= ind & -ind) // must stop at 0: there ind & -ind == 0 and the loop would never end
sum += tree[ind];
return sum;
}
T query(int L, int R) {
return query(R) - query(L - 1);
}
};
int32_t main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
const int NN = 5e5 + 1;
vector<vector<int>> F(NN);
for (int i = 1 ; i < NN ; ++i) {
for(int j = 0 ; j < NN ; j += i) // F[j] collects the divisors of j; F has only NN entries, so j must stay below NN
F[j].push_back(i);
}
calc_fact();
int sumN = 0;
auto __solve_testcase = [&](int test) {
int N = inp.readInt(1, (int)5e5); inp.readSpace(); sumN += N;
int M = inp.readInt(1, (int)5e5); inp.readEoln();
vector<int> A = inp.readInts(N, 0, M); inp.readEoln();
int s = set<int>(A.begin(), A.end()).size();
if(count(A.begin(), A.end(), 0))
--s;
vector<mod_int> B(N + 1);
for(int x = max(1, s) ; x <= min(N, M) ; ++x)
B[x] = choose(M - s, x - s);
vector<int> P(N + 1);
for(int i = 0 ; i < N ; ++i) {
P[i + 1] = P[i] + (A[i] == 0);
}
map<int, int> ind;
int bad = 2 * N; // smallest j >= i such that [i, j] repeats a non-zero value (2N means none)
for(int i = N - 1 ; i >= 0 ; --i) {
if(A[i] && ind.find(A[i]) != ind.end())
bad = min(bad, ind[A[i]]);
ind[A[i]] = i;
for(auto &x: F[i]) { // block sizes x that have a block starting at 0-indexed i (x divides i; i = 0 starts every size)
if(x > N) break;
if(x + i > bad) {
B[x] = 0;
continue;
}
int numz = P[min(i + x, N)] - P[i];
int tot = min(i + x, N) - i;
B[x] *= choose(x - (tot - numz), numz) * fct[numz];
}
}
cout << accumulate(B.begin(), B.end(), mod_int(0)) << '\n';
};
int NumTest = 1;
NumTest = inp.readInt(1, (int)1e5); inp.readEoln();
for(int testno = 1; testno <= NumTest ; ++testno) {
__solve_testcase(testno);
}
inp.readEof();
return 0;
}
Editorialist's code (Python)
mod = 998244353
N = 500005
fac = [1]*N
for i in range(1, N): fac[i] = fac[i-1] * i % mod
inv = fac[:]
inv[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(N-1)): inv[i] = inv[i+1] * (i+1) % mod
def C(n, r):
    if n < 0 or n < r: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod
for _ in range(int(input())):
    n, m = map(int, input().split())
    a = list(map(int, input().split()))
    mark = [0]*(m+1)
    for x in a: mark[x] = 1
    already_have = mark[1:].count(1)
    pref = [0]*(n+1)
    for i in range(n):
        if a[i] > 0: pref[i+1] = 1
        pref[i+1] += pref[i]
    next_dup = [n+1]*(n+2)
    last_seen = [n+1]*(m+1)
    for i in reversed(range(n)):
        if a[i] == 0: next_dup[i+1] = next_dup[i+2]
        else:
            next_dup[i+1] = min(next_dup[i+2], last_seen[a[i]])
            last_seen[a[i]] = i+1
    ans = 0
    for i in range(max(1, already_have), n+1):
        # C(m - already_have, i - already_have) choices for the other elements
        choices = C(m - already_have, i - already_have)
        arrangements = 1
        for L in range(1, n+1, i):
            R = min(n, L+i-1)
            # check if everything is distinct in [L, R]
            if next_dup[L] > R:
                # if it is, count the non-zero entries in [L, R]
                in_range = pref[R] - pref[L-1]
                zeros = R-L+1 - in_range
                arrangements = arrangements * C(i - in_range, zeros) % mod * fac[zeros] % mod
            else:
                arrangements = 0
        ans = (ans + choices * arrangements) % mod
    print(ans)