TOOFAR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics, prefix sums

PROBLEM:

An array is called prefix-balanced if, for each of its prefixes, no two elements’ frequencies differ by more than 1.
You’re given a partially-filled array with elements between 1 and M.
Find the number of ways to fill in the array such that it’s prefix-balanced.

EXPLANATION:

The solution will continue from the easier version, so it’s recommended to read that editorial first if you haven’t (link).

To recap, an array containing K distinct elements is prefix-balanced if and only if, when it is broken up into blocks of length K (the last one possibly having length \lt K), each such block contains distinct elements.
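As a sanity check of this recap, the sketch below compares a direct check against the block criterion on all small arrays. It assumes that frequencies are compared over the distinct values of the whole array (a value not yet seen counts as frequency 0), which is the reading under which the block characterization holds. The function names are illustrative and do not come from the original solution.

```python
from collections import Counter
from itertools import product

def is_prefix_balanced(arr):
    """Direct check: in every prefix, frequencies differ by at most 1.

    Assumption: frequencies are taken over the distinct values of the
    whole array, so a value not yet seen has frequency 0.
    """
    values = set(arr)
    freq = Counter()
    for v in arr:
        freq[v] += 1
        counts = [freq[u] for u in values]
        if max(counts) - min(counts) > 1:
            return False
    return True

def blocks_are_distinct(arr):
    """Block criterion: split into blocks of length K = number of distinct values."""
    k = len(set(arr))
    for start in range(0, len(arr), k):
        block = arr[start:start + k]
        if len(set(block)) != len(block):
            return False
    return True

# Compare both checks on every non-empty array of length <= 5 over {1, 2, 3}.
agree = all(
    is_prefix_balanced(list(arr)) == blocks_are_distinct(list(arr))
    for n in range(1, 6)
    for arr in product(range(1, 4), repeat=n)
)
```

Both functions expect a non-empty, fully filled array (no zeros).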

Looking back at the solution to the easy version, we did the following:

  • Fix K, the size of S_A.
  • Fix the elements of S_A.
  • Fix a rearrangement of them into each block.

Let’s try to recreate the same idea here.
Fixing K remains the same, so let’s do that.

When trying to fix the elements of S_A however, we need to be a bit careful: some values already exist in A, so those need to be taken into consideration.
Specifically, if A already contains x distinct non-zero values, we can only choose another (K - x) elements to reach a size of K.
Further, this choice must be made from the (M-x) elements that aren’t already in A; for a total of \binom{M-x}{K-x} choices.
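A small sketch of this count, cross-checked by enumeration on a made-up instance. The helper name `count_value_sets` and the example values are assumptions chosen for illustration.

```python
from itertools import combinations
from math import comb

def count_value_sets(m, present, k):
    """Number of ways to extend the values already present to a set of size k."""
    x = len(present)
    if k < x:
        return 0
    return comb(m - x, k - x)

# Hypothetical example: M = 5, values {2, 4} already appear, target K = 3.
m, present, k = 5, {2, 4}, 3

# Enumerate all size-K subsets of {1..M} that contain the present values.
brute = sum(1 for s in combinations(range(1, m + 1), k) if present <= set(s))
formula = count_value_sets(m, present, k)
```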

Next, we look at the blocks. These are no longer homogeneous, so we need to look at each of them separately.
Consider a block spanning indices i to \min(N, i+K-1).
For such a block,

  • If it already contains duplicate non-zero elements, the condition is already a failure - and there’s no way at all to have |S_A| = K.
  • Otherwise, suppose there are y distinct elements already present, and z zeros.
    We can then choose z elements out of the K-y we have, and arrange them in any of z! orders, for a total of \binom{K-y}{z} \cdot z! ways.
    (Note that if the block has size K, we’ll have K-y = z and so \binom{K-y}{z} = 1, leaving just z!; however stating it in this fashion allows us to take care of the last block without explicitly having to special-case it).
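The per-block count (choose the values for the zeros, then order them) can be sketched as follows and compared against a direct enumeration. The names and the example block are hypothetical.

```python
from itertools import product
from math import comb, factorial

def block_ways(k, y, z):
    """Choose z of the k - y unused values, then order them: C(k-y, z) * z!."""
    if z > k - y:
        return 0
    return comb(k - y, z) * factorial(z)

def block_ways_brute(chosen, block):
    """Count fillings of the zeros in `block` that keep all entries distinct."""
    zeros = [i for i, v in enumerate(block) if v == 0]
    count = 0
    for fill in product(chosen, repeat=len(zeros)):
        filled = list(block)
        for pos, v in zip(zeros, fill):
            filled[pos] = v
        if len(set(filled)) == len(filled):
            count += 1
    return count

# Hypothetical example: S_A = {1, 2, 3, 4} (K = 4), short last block [3, 0, 0].
chosen = [1, 2, 3, 4]
short_block = [3, 0, 0]      # y = 1 value present, z = 2 zeros
full_block = [3, 0, 0, 0]    # y = 1 value present, z = 3 zeros
```

For the short block the formula gives C(3, 2) * 2! = 6, and for the full block C(3, 3) * 3! = 6, matching the enumeration.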

Also, note that the starting points of the blocks of size K will be 1, K+1, 2K+1, \ldots
There will be exactly \left\lceil \frac{N}{K} \right\rceil such blocks, so across all K, the number of blocks we process will be

\left\lceil \frac{N}{1} \right\rceil + \left\lceil \frac{N}{2} \right\rceil + \left\lceil \frac{N}{3} \right\rceil + \ldots + \left\lceil \frac{N}{N} \right\rceil

It’s well-known that this is \mathcal{O}(N\log N), so as long as we’re able to check each block fast enough, the blocks being non-homogeneous doesn’t matter!
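To see the growth concretely, the sketch below evaluates the sum and compares it to the bound N(\ln N + 2), which follows from \lceil N/K \rceil \leq N/K + 1 and the harmonic sum being at most \ln N + 1. The function name is illustrative.

```python
from math import log

def total_blocks(n):
    """Sum of ceil(n / k) over k = 1..n, i.e. blocks processed across all K."""
    return sum((n + k - 1) // k for k in range(1, n + 1))

for n in (10, 1000, 100000):
    # Bound from ceil(n/k) <= n/k + 1 and H_n <= ln(n) + 1.
    bound = n * (log(n) + 2)
    print(n, total_blocks(n), round(bound))
```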


From the above discussion, we’re now left with two things to do: check whether each block contains repeated non-zero elements, and if it doesn’t, count the number of distinct non-zero elements present in it. Both need to be done quickly, ideally in constant time since there’s already a multiplier of N\log N.

Notice that the second part is actually quite trivial if we can achieve the first: after all, if the non-zero elements aren’t repeated, then the number of distinct non-zero elements simply equals the number of non-zero elements present in the range!
That can easily be computed in \mathcal{O}(1) time using prefix sums.
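A minimal sketch of that prefix-sum lookup, using 1-indexed positions as in the text. The helper names and the sample array are assumptions for illustration.

```python
def build_nonzero_prefix(a):
    """pref[i] = number of non-zero entries among the first i elements of a."""
    pref = [0] * (len(a) + 1)
    for i, v in enumerate(a):
        pref[i + 1] = pref[i] + (1 if v != 0 else 0)
    return pref

def nonzero_in_range(pref, l, r):
    """Number of non-zero entries among positions l..r (1-indexed, inclusive)."""
    return pref[r] - pref[l - 1]

a = [3, 0, 1, 0, 0, 2]
pref = build_nonzero_prefix(a)

# Block covering positions 3..6, i.e. [1, 0, 0, 2].
y = nonzero_in_range(pref, 3, 6)   # filled positions (distinct if no repeats)
z = (6 - 3 + 1) - y                # zeros still to be filled
```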

So, all we’re really left with is checking whether the non-zero elements in some range are distinct or not.
To do that, let’s precompute for each index i the position R_i, which is the smallest index \geq i such that the range [i, R_i] contains repeated non-zero elements (with R_i = N+1 if no such index exists, and R_{N+1} = N+1 as the base case).
This can be computed in \mathcal{O}(N+M) time as follows:

  • Iterate i from N down to 1.
  • If A_i = 0, R_i = R_{i+1}.
  • Otherwise, R_i = \min(R_{i+1}, j), where j \gt i is the first position such that A_i = A_j (take j = N+1 if A_i doesn’t occur again).
    This nearest next equal element can also be precomputed for each element, for \mathcal{O}(1) lookup.

The logic should be simple enough: either the repeat element is A_i itself, in which case we look for its nearest occurrence; or it’s something else in which case it’ll also be a repeat for the subarray starting at index i+1.

With this in hand, checking whether some subarray [i, j] contains repeat elements is trivial: just check if R_i \leq j or not!
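The computation of R and the range check can be sketched as below, with a brute-force comparison on a small made-up array. It uses N+1 as the "no repeat" sentinel, which is an implementation choice of this sketch; the function names are illustrative.

```python
def build_next_dup(a, m):
    """R[i] = smallest j such that a[i..j] (1-indexed) repeats a non-zero value.

    R[i] = n + 1 when no such j exists; R has n + 2 entries so that
    R[n + 1] is available as the base case.
    """
    n = len(a)
    R = [n + 1] * (n + 2)
    next_pos = [n + 1] * (m + 1)   # next_pos[v]: nearest later position holding v
    for i in range(n, 0, -1):
        v = a[i - 1]
        R[i] = R[i + 1]
        if v != 0:
            R[i] = min(R[i], next_pos[v])
            next_pos[v] = i
    return R

def has_repeat(R, i, j):
    """True if positions i..j contain a repeated non-zero value."""
    return R[i] <= j

def has_repeat_brute(a, i, j):
    """Same question answered directly, for cross-checking."""
    vals = [v for v in a[i - 1:j] if v != 0]
    return len(set(vals)) != len(vals)

a = [1, 0, 2, 1, 0, 2, 3]
R = build_next_dup(a, 3)
```

Here `has_repeat(R, 2, 5)` is false (the range is 0, 2, 1, 0), while `has_repeat(R, 2, 6)` is true because the value 2 appears twice.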

TIME COMPLEXITY:

\mathcal{O}(N\log N + M) per testcase.

CODE:

Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;

struct input_checker {
    string buffer;
    int pos;
 
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
 
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
 
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
 
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
 
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
 
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
 
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
 
    void readEof() {
        assert((int) buffer.size() == pos);
    }
} inp;

namespace mint_ns {
template<auto P>
struct Modular {
    using value_type = decltype(P);
    value_type value;
 
    Modular(long long k = 0) : value(norm(k)) {}
 
    friend Modular<P>& operator += (      Modular<P>& n, const Modular<P>& m) { n.value += m.value; if (n.value >= P) n.value -= P; return n; }
    friend Modular<P>  operator +  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r += m; }
 
    friend Modular<P>& operator -= (      Modular<P>& n, const Modular<P>& m) { n.value -= m.value; if (n.value < 0)  n.value += P; return n; }
    friend Modular<P>  operator -  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r -= m; }
    friend Modular<P>  operator -  (const Modular<P>& n)                      { return Modular<P>(-n.value); }
 
    friend Modular<P>& operator *= (      Modular<P>& n, const Modular<P>& m) { n.value = n.value * 1ll * m.value % P; return n; }
    friend Modular<P>  operator *  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r *= m; }
 
    friend Modular<P>& operator /= (      Modular<P>& n, const Modular<P>& m) { return n *= m.inv(); }
    friend Modular<P>  operator /  (const Modular<P>& n, const Modular<P>& m) { Modular<P> r = n; return r /= m; }
 
    Modular<P>& operator ++ (   ) { return *this += 1; }
    Modular<P>& operator -- (   ) { return *this -= 1; }
    Modular<P>  operator ++ (int) { Modular<P> r = *this; *this += 1; return r; }
    Modular<P>  operator -- (int) { Modular<P> r = *this; *this -= 1; return r; }
 
    friend bool operator == (const Modular<P>& n, const Modular<P>& m) { return n.value == m.value; }
    friend bool operator != (const Modular<P>& n, const Modular<P>& m) { return n.value != m.value; }
 
    explicit    operator       int() const { return value; }
    explicit    operator      bool() const { return value; }
    explicit    operator long long() const { return value; }
 
    constexpr static value_type mod()      { return     P; }
 
    value_type norm(long long k) {
        if (!(-P <= k && k < P)) k %= P;
        if (k < 0) k += P;
        return k;
    }
 
    Modular<P> inv() const {
        value_type a = value, b = P, x = 0, y = 1;
        while (a != 0) { value_type k = b / a; b -= k * a; x -= k * y; swap(a, b); swap(x, y); }
        return Modular<P>(x);
    }
    friend void __print(Modular<P> x) {
        cerr << x;
    }
};
template<auto P> Modular<P> pow(Modular<P> m, long long p) {
    Modular<P> r(1);
    while (p) {
        if (p & 1) r *= m;
        m *= m;
        p >>= 1;
    }
    return r;
}
 
template<auto P> ostream& operator << (ostream& o, const Modular<P>& m) { return o << m.value; }
template<auto P> istream& operator >> (istream& i,       Modular<P>& m) { long long k; i >> k; m.value = m.norm(k); return i; }
template<auto P> string   to_string(const Modular<P>& m) { return to_string(m.value); }
 
}
constexpr int mod = 998244353;
using mod_int = mint_ns::Modular<mod>;
using mi = mod_int;
constexpr int maxn = 1e6 + 3;
vector<mi> fct(maxn, 1), invf(maxn, 1);

void calc_fact() {
    for(int i = 1 ; i < maxn ; i++) {
        fct[i] = fct[i - 1] * i;
    }
    invf.back() = mi(1) / fct.back();
    for(int i = maxn - 1 ; i ; i--)
        invf[i - 1] = i * invf[i];
}
 
mi choose(int n, int r) { // choose r elements out of n elements
    if(r > n)   return mi(0);
    assert(r <= n);
    return fct[n] * invf[r] * invf[n - r];
}
 
mi place(int n, int r) { // number of solutions to x1 + x2 + ... + xr = n with every xi >= 0 (stars and bars)
    assert(r > 0);
    return choose(n + r - 1, r - 1);
}

template<typename T>
struct BIT {
	vector<T> tree;	int N;
	BIT(int N_ = 0) {
		N = N_;
		tree.resize(N + 1);
	}
	void update(int ind, T val) {
		for(++ind ; ind <= N ; ind += ind & -ind)
			tree[ind] += val;
	}
	T query(int ind) {
		T sum = 0;
		for(++ind ; ind > 0 ; ind -= ind & -ind)
			sum += tree[ind];
		return sum;
	}
	T query(int L, int R) {
		return query(R) - query(L - 1);
	}
};

int32_t main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);

	const int NN = 5e5 + 1;
	vector<vector<int>> F(NN);
	for (int i = 1 ; i < NN ; ++i) {
		for(int j = 0 ; j < NN ; j += i)
			F[j].push_back(i);
	}
	calc_fact();

	int sumN = 0;
	auto __solve_testcase = [&](int test) {
		int N = inp.readInt(1, (int)5e5);	inp.readSpace();	sumN += N;
		int M = inp.readInt(1, (int)5e5);	inp.readEoln();
		vector<int> A = inp.readInts(N, 0, M);	inp.readEoln();
		int s = set<int>(A.begin(), A.end()).size();
		if(count(A.begin(), A.end(), 0))
			--s;
		vector<mod_int> B(N + 1);
		for(int x = max(1, s) ; x <= min(N, M) ; ++x)
			B[x] = choose(M - s, x - s);
		vector<int> P(N + 1);
		for(int i = 0 ; i < N ; ++i) {
			P[i + 1] = P[i] + (A[i] == 0);
		}

		map<int, int> ind;
		int bad = 2 * N;
		for(int i = N - 1 ; i >= 0 ; --i) {
			if(A[i] && ind.find(A[i]) != ind.end())
				bad = min(bad, ind[A[i]]);
			ind[A[i]] = i;

			for(auto &x: F[i]) {
			    if(x > N)   break;
				if(x + i > bad) {
					B[x] = 0;
					continue;
				}
				int numz = P[min(i + x, N)] - P[i];
				int tot = min(i + x, N) - i;
				
				B[x] *= choose(x - (tot - numz), numz) * fct[numz];
			}
		}
		cout << accumulate(B.begin(), B.end(), mod_int(0)) << '\n';
	};
	
	int NumTest = 1;
	NumTest = inp.readInt(1, (int)1e5);	inp.readEoln();
	for(int testno = 1; testno <= NumTest ; ++testno) {
		__solve_testcase(testno);
	}

	inp.readEof();
	
	return 0;
}

Editorialist's code (Python)
mod = 998244353
N = 500005
fac = [1]*N
for i in range(1, N): fac[i] = fac[i-1] * i % mod
inv = fac[:]
inv[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(N-1)): inv[i] = inv[i+1] * (i+1) % mod
def C(n, r):
    if n < 0 or n < r: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod

for _ in range(int(input())):
    n, m = map(int, input().split())
    a = list(map(int, input().split()))
    mark = [0]*(m+1)
    for x in a: mark[x] = 1
    already_have = mark[1:].count(1)
    pref = [0]*(n+1)
    for i in range(n):
        if a[i] > 0: pref[i+1] = 1
        pref[i+1] += pref[i]
    
    next_dup = [n+1]*(n+2)
    last_seen = [n+1]*(m+1)
    for i in reversed(range(n)):
        if a[i] == 0: next_dup[i+1] = next_dup[i+2]
        else:
            next_dup[i+1] = min(next_dup[i+2], last_seen[a[i]])
            last_seen[a[i]] = i+1
    
    ans = 0
    for i in range(max(1, already_have), n+1):
        # C(m - already_have, i - already_have) choices for the other elements
        choices = C(m - already_have, i - already_have)
        arrangements = 1
        for L in range(1, n+1, i):
            R = min(n, L+i-1)
            # check if everything is distinct in [L, R]
            if next_dup[L] > R:
                # if it is, find count of things in [L, R]
                in_range = pref[R] - pref[L-1]
                zeros = R-L+1 - in_range
                arrangements = arrangements * C(i - in_range, zeros) % mod * fac[zeros] % mod
            else:
                arrangements = 0
        ans = (ans + choices * arrangements) % mod
    print(ans)

I did almost the same as you (submission), but I am getting TLE.

@iceknight1093 can you please see it

Your ncr function has a complexity of \mathcal{O}(\log {MOD}) because of inverse computation, but you can make it \mathcal{O}(1) by precomputing inverse factorials - just doing that brings it down to about 0.3s (submission).

okay Sir Thanks