KBEAUTIFUL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: nicholask
Testers: tabr, iceknight1093
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Combinatorics

PROBLEM:

You have an array A of N integers. It’s called K-beautiful if every subarray of size K has the same sum.

You can perform at most M moves, each one increasing some A_i by 1.
How many distinct K-beautiful final arrays can be obtained? Since the count can be huge, the codes below report it modulo 998244353.
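For very small inputs, the statement can be pinned down with an exhaustive search. The sketch below is only a cross-checking aid, not part of the intended solution: the name `count_brute` is hypothetical, indices are 0-based, and no modulus is applied. It tries every way to distribute at most M unit increments over A and counts the K-beautiful results.

```python
from itertools import product

def count_brute(a, m, k):
    """Count distinct K-beautiful arrays reachable with at most m +1 moves.
    Exponential in len(a): tiny inputs only."""
    n = len(a)
    seen = set()
    for d in product(range(m + 1), repeat=n):  # d[i] = increments applied to a[i]
        if sum(d) > m:
            continue
        b = [x + y for x, y in zip(a, d)]
        window_sums = {sum(b[i:i + k]) for i in range(n - k + 1)}
        if len(window_sums) == 1:              # every length-k window has the same sum
            seen.add(tuple(b))
    return len(seen)
```

For example, `count_brute([1, 3, 2, 5], 5, 2)` gives 3: the reachable arrays are [2, 5, 2, 5], [2, 6, 2, 6] and [3, 5, 3, 5].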

EXPLANATION:

Suppose A was a K-beautiful array.
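Then the sum (A_i + A_{i+1} + \ldots + A_{i+K-1}) must be the same for every 1 \leq i \leq N-K+1.
In particular, (A_i + A_{i+1} + \ldots + A_{i+K-1}) = (A_{i+1} + A_{i+2} + \ldots + A_{i+K}) for every 1 \leq i \leq N-K; cancelling the common terms A_{i+1}, \ldots, A_{i+K-1} gives A_i = A_{i+K}. Conversely, if A_i = A_{i+K} for all such i, consecutive window sums are equal, so A is K-beautiful.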

That is, if we fix the first K elements of A, the remaining values are fixed.
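The equivalence between "all length-K window sums are equal" and "A_i = A_{i+K} for every valid i" can be spot-checked with two small predicates. Both names are hypothetical and indices are 0-based.

```python
def is_k_beautiful(a, k):
    """Direct definition: every window of length k has the same sum."""
    sums = {sum(a[i:i + k]) for i in range(len(a) - k + 1)}
    return len(sums) == 1

def is_k_periodic(a, k):
    """Derived condition: a[i] == a[i + k] for every valid i."""
    return all(a[i] == a[i + k] for i in range(len(a) - k))
```

For instance, [1, 2, 3, 1, 2] with K = 3 satisfies both (all window sums are 6), while [1, 2, 2, 1] with K = 2 satisfies neither.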

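First, however, we need to bring the array to a baseline: all elements in the same residue class modulo K (that is, A_i, A_{i+K}, A_{i+2K}, \ldots) must become equal.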
For each 1 \leq i \leq K, let M_i = \max(A_i, A_{i+K}, A_{i+2K}, \ldots).
Since moves can only increase values, we first need to raise all of these values to M_i, which takes (M_i - A_i) + (M_i - A_{i+K}) + \ldots operations.
Compute this for each i and sum up the number of operations needed.

If this baseline number of operations is larger than M, the answer is immediately 0, since no K-beautiful array can be reached. Otherwise, subtract these operations from M, so that we now have up to M ‘free’ operations left.
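The baseline step can be sketched in a few lines. This mirrors the `base_ops` computation in the editorialist's code; the function names are hypothetical and indices are 0-based.

```python
def baseline_ops(a, k):
    """Moves needed to raise every residue class (a[i], a[i+k], a[i+2k], ...)
    to that class's maximum."""
    total = 0
    for i in range(k):
        cls = a[i::k]
        total += len(cls) * max(cls) - sum(cls)
    return total

def free_moves(a, m, k):
    """Moves left after reaching the baseline, or None if the baseline
    already exceeds m (in which case the answer is 0)."""
    need = baseline_ops(a, k)
    return None if need > m else m - need
```

For A = [1, 3, 2, 5] and K = 2, the classes are [1, 2] and [3, 5], costing 1 and 2 moves, so the baseline is 3.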

We are now ready to actually solve the problem.
Note that the final array is entirely determined by the (final) values of A_1, A_2, \ldots, A_K.
We know their current values, so the only thing that matters is how much we add to each index.

So, suppose we add x_i to index i.
Each time we add 1 to index i, we must also add 1 to indices i+K, i+2K, i+3K, \ldots to preserve the K-beauty condition.
Let c_i be the number of such indices (including i itself).

The total number of operations we use is then c_1x_1 + c_2x_2 + \ldots + c_Kx_K.

Our objective is thus to count the number of solutions to

c_1x_1 + c_2x_2 + \ldots + c_Kx_K \leq M

such that each x_i is \geq 0.
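Before optimizing, the reduced counting problem can be brute-forced by recursion for small inputs, which is handy for testing faster versions. The name `count_solutions_brute` is hypothetical; it assumes every c_i \geq 1, which always holds since each class contains at least index i.

```python
def count_solutions_brute(c, m):
    """Count tuples of non-negative integers x with c[0]*x[0] + ... <= m."""
    def go(idx, budget):
        if idx == len(c):
            return 1
        # try every feasible value of x[idx], then recurse on the remaining budget
        return sum(go(idx + 1, budget - c[idx] * x)
                   for x in range(budget // c[idx] + 1))
    return go(0, m)
```

For example, with A = [1, 3, 2, 5], K = 2 and M = 5, the baseline costs 3 moves, leaving 2 free ones; here c = [2, 2], and `count_solutions_brute([2, 2], 2)` gives 3.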

A quadratic solution

Let’s fix m (0 \leq m \leq M) and count the number of solutions to c_1x_1 + c_2x_2 + \ldots + c_Kx_K = m.
This should remind you of the classical stars-and-bars problem, but it isn’t quite in that form.

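First, note that the c_i aren’t arbitrary integers: each of them is either \left\lceil \frac{N}{K} \right\rceil or \left\lfloor \frac{N}{K} \right\rfloor. More precisely, c_i = \left\lceil \frac{N}{K} \right\rceil for the N \bmod K indices i \leq N \bmod K, and c_i = \left\lfloor \frac{N}{K} \right\rfloor for the rest.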

For now, let’s assume K doesn’t divide N so those two numbers are different.
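Let y = \left\lfloor \frac{N}{K} \right\rfloor, and suppose r of the c_i are equal to y (so r = K - (N \bmod K), and the remaining K - r of them equal y+1). Relabel the indices so that c_1 = \ldots = c_r = y and c_{r+1} = \ldots = c_K = y+1.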
y and r are constants independent of m.
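The claim about the class sizes is easy to check numerically. The helper below (hypothetical name, 1-based i as in the text) computes c_i directly. Its usage confirms that the first N \bmod K classes have size y+1 and the others size y.

```python
def class_sizes(n, k):
    """c_i for i = 1..k: how many of the indices i, i+k, i+2k, ... are <= n."""
    return [len(range(i, n + 1, k)) for i in range(1, k + 1)]
```

For N = 7 and K = 3 this gives [3, 2, 2]: y = 2, one class (N \bmod K = 1) of size 3, and r = 2 classes of size 2.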

Rewriting the equation in terms of y and reducing it a bit, we see that

c_1x_1 + c_2x_2 + \ldots + c_Kx_K = m

y\cdot (x_1 + x_2 + \ldots + x_r) + (y+1)\cdot (x_{r+1} + x_{r+2} + \ldots + x_K) = m

Now, notice that (x_1 + x_2 + \ldots + x_r) and (x_{r+1} + x_{r+2} + \ldots + x_K) are independent summations, since they involve disjoint sets of variables.

Recall that we fixed m. Let’s also fix m_1 = x_1 + x_2 + \ldots + x_r.
This uniquely determines m_2 = x_{r+1} + x_{r+2} + \ldots + x_K = \frac{m - y\cdot m_1}{y+1}; if this is not a non-negative integer, there are no solutions for this pair (m, m_1).

The equations x_1 + \ldots + x_r = m_1 and x_{r+1} + \ldots + x_K = m_2 are both in stars-and-bars form, so with factorials precomputed, the number of solutions to each can be computed in \mathcal{O}(1) time.
The respective counts are \displaystyle\binom{m_1+r-1}{m_1} and \displaystyle\binom{m_2+K-r-1}{m_2}.

Simply multiply these two numbers together to obtain the number of solutions for these fixed values of m and m_1.
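Each factor is a plain stars-and-bars count. A minimal sketch using Python's `math.comb` follows; the name `stars_and_bars` is hypothetical. The convention for zero parts matches the editorialist's helper `f`.

```python
from math import comb

def stars_and_bars(parts, total):
    """Ways to write `total` as an ordered sum of `parts` non-negative integers:
    C(total + parts - 1, total); zero parts can only sum to zero."""
    if parts == 0:
        return 1 if total == 0 else 0
    return comb(total + parts - 1, total)
```

For example, `stars_and_bars(2, 3)` is 4, matching (0, 3), (1, 2), (2, 1), (3, 0). The count for a fixed (m, m_1) is then `stars_and_bars(r, m1) * stars_and_bars(K - r, m2)`.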

Iterating over each (m, m_1) pair gives us a solution in \mathcal{O}(M^2), which is too slow.
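The quadratic approach can be written down directly, which is useful as a reference for small cases. This is a plain-integer sketch (no modulus). The name `count_quadratic` is hypothetical. It takes m as the number of free moves left after the baseline step, and it assumes K does not divide N.

```python
from math import comb

def count_quadratic(n, k, m):
    """O(M^2) count of solutions to sum c_i x_i <= m (K must not divide N)."""
    y = n // k
    r = k - n % k          # classes of size y
    big = k - r            # classes of size y + 1 (= N mod K >= 1)
    total = 0
    for mm in range(m + 1):            # exact number of free moves used
        for m1 in range(mm // y + 1):  # sum of x over the size-y classes
            rest = mm - y * m1
            if rest % (y + 1):
                continue               # m_2 would not be an integer
            m2 = rest // (y + 1)
            total += comb(m1 + r - 1, m1) * comb(m2 + big - 1, m2)
    return total
```

For N = 7, K = 3 and 4 free moves, the classes have sizes [3, 2, 2], and the sketch counts the 7 solutions of 2a + 2b + 3c \leq 4.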

Note that if K divides N, every c_i equals y, so fixing m forces x_1 + x_2 + \ldots + x_K = \frac{m}{y}. There is no m_2 part at all, and no solutions unless y divides m, so this case can already be solved in \mathcal{O}(M) time.
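A sketch of the divisible case follows (hypothetical name, plain integers, m = free moves). Only multiples of y contribute, so the loop runs over s = m / y. Summing these counts with the hockey-stick identity also gives the closed form \binom{\lfloor M/y \rfloor + K}{K}. The tester's code uses this form directly.

```python
from math import comb

def count_divisible(n, k, m):
    """K divides N: every c_i equals y = N // K, so we count solutions of
    x_1 + ... + x_K = s for every s with y * s <= m."""
    y = n // k
    total = 0
    for s in range(m // y + 1):          # s = (moves used) / y
        total += comb(s + k - 1, s)      # stars and bars with K parts
    return total
```

For N = 4, K = 2 and 3 free moves, y = 2 and the solutions of 2a + 2b \leq 3 are (0, 0), (1, 0), (0, 1), so the count is 3.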

Speeding up

Notice that fixing any two of m, m_1, m_2 fixes the third one uniquely.
So, let’s fix just m_1 and see what we get.

If we also fix m_2, the resulting value of m must be \leq M.
In particular, m_2 must satisfy the inequality y\cdot m_1 + (y+1)\cdot m_2 \leq M.

This means that the valid values of m_2 are exactly 0, 1, \ldots, \left\lfloor \frac{M - y\cdot m_1}{y+1} \right\rfloor (assuming y\cdot m_1 \leq M).
For each one, we need to add \binom{m_1+r-1}{m_1}\binom{m_2+K-r-1}{m_2} to the answer.

Note that the first quantity is fixed since m_1 is fixed, so we just need the sum of \binom{m_2+K-r-1}{m_2} across all valid m_2.
This isn’t too hard: since the m_2 form a contiguous range starting from 0, we can create an array V where V_i = \binom{i+K-r-1}{i} and take its prefix sums.

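This allows us to process a single m_1 in \mathcal{O}(1) time. Since y\cdot m_1 \leq M, m_1 only ranges over 0, 1, \ldots, \left\lfloor \frac{M}{y} \right\rfloor; together with the \mathcal{O}(N) baseline computation, this solves the problem in \mathcal{O}(N+M) time.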
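The prefix-sum array V can be built in one pass. In the sketch below, `parts` plays the role of K - r, and the name `prefix_counts` is hypothetical. It assumes parts \geq 1, which always holds when K does not divide N. By the hockey-stick identity, each entry also equals \binom{t + \text{parts}}{t}, the closed form the tester's code relies on.

```python
from math import comb

def prefix_counts(parts, limit):
    """P[t] = sum_{i=0}^{t} C(i + parts - 1, i): the number of ways to give
    non-negative values to `parts` variables with total at most t."""
    prefix = []
    running = 0
    for i in range(limit + 1):
        running += comb(i + parts - 1, i)   # V_i: ways to reach total exactly i
        prefix.append(running)
    return prefix
```

For two parts, `prefix_counts(2, 3)` is [1, 3, 6, 10], the triangular numbers.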

TIME COMPLEXITY:

\mathcal{O}(N + M) per testcase.
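The steps combine into a plain-integer sketch of the fast count. The name `count_linear` is hypothetical, and m is the number of free moves left after the baseline step. It omits the modulus; the setter's and editorialist's codes reduce modulo 998244353 and precompute factorials instead of calling `math.comb`.

```python
from math import comb

def count_linear(n, k, m):
    """Count solutions to sum c_i x_i <= m with one pass over m_1."""
    y, big = n // k, n % k          # big = number of classes of size y + 1
    r = k - big                     # number of classes of size y
    if big == 0:
        # K divides N: all classes have size y
        return sum(comb(s + k - 1, s) for s in range(m // y + 1))
    # prefix[t] = ways for the size-(y+1) classes to receive m_2 <= t in total
    prefix, running = [], 0
    for t in range(m // (y + 1) + 1):
        running += comb(t + big - 1, t)
        prefix.append(running)
    total = 0
    for m1 in range(m // y + 1):
        t = (m - y * m1) // (y + 1)     # largest valid m_2 for this m_1
        total += comb(m1 + r - 1, m1) * prefix[t]
    return total
```

On N = 7, K = 3 with 4 free moves it returns 7, the same as direct enumeration of 2a + 2b + 3c \leq 4.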

CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
const int MXN=2000010;
const long long MOD=998244353,INF=1000000000;
long long f[MXN],inv[MXN],finv[MXN];
void Initialize(){
	f[0]=f[1]=inv[0]=inv[1]=finv[0]=finv[1]=1;
	for (int i=2; i<MXN; i++){
		f[i]=f[i-1]*i%MOD;
		inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
		finv[i]=finv[i-1]*inv[i]%MOD;
	}
}
long long nCr(int n,int r){
	if (n<r) return 0LL;
	return f[n]*finv[r]%MOD*finv[n-r]%MOD;
}
long long nHr(int n,int r){
	return nCr(n+r-1,r);
}
void solve(){
	int n,m,k;
	cin>>n>>m>>k;
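	// std::vector keeps these (up to ~10^6-element) arrays off the stack;
	// variable-length arrays are a non-standard GCC extension.
	vector<long long> a(n+1);
	for (int i=1; i<=n; i++) cin>>a[i];
	vector<long long> mx(k+1);
	for (int i=1; i<=k; i++){
		mx[i]=0;
		for (int j=i; j<=n; j+=k){
			mx[i]=max(mx[i],a[j]);
		}
		for (int j=i; j<=n; j+=k){
			m-=min(INF,mx[i]-a[j]);
			if (m<0){
				cout<<"0\n";
				return;
			}
		}
	}
	vector<long long> sum(m/(n/k)+1);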
	for (int i=0; i<=m/(n/k); i++){
		if (i>0) sum[i]=sum[i-1];
		else sum[i]=0;
		sum[i]+=nHr(k-n%k,i);
		sum[i]%=MOD; 
	}
	long long ans=0;
	if (n%k==0){
		int mxCnt=m/(n/k);
		cout<<sum[mxCnt]<<'\n';
		return;
	}
	for (int cntBig=0; cntBig<=m/((n+k-1)/k); cntBig++){
		int mxSmall=(m-(n+k-1)/k*cntBig)/(n/k);
		ans+=nHr(n%k,cntBig)*sum[mxSmall];
		ans%=MOD;
	}
	cout<<ans<<'\n';
}
int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	int T=1;
	cin>>T;
	Initialize();
	while (T--) solve();
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 20000);
    in.readEoln();
    int sn = 0, sm = 0;
    while (tt--) {
        int n = in.readInt(1, 1e6);
        in.readSpace();
        long long m = in.readInt(1, 1e6);
        in.readSpace();
        int k = in.readInt(1, n);
        in.readEoln();
        sn += n;
        sm += m;
        vector<long long> a = in.readLongs(n, 1, 1e18);
        in.readEoln();
        vector<vector<long long>> b(k);
        for (int i = 0; i < n; i++) {
            b[i % k].emplace_back(a[i]);
        }
        map<int, int> cnt;
        for (int i = 0; i < k; i++) {
            sort(b[i].begin(), b[i].end());
            cnt[(int) b[i].size()]++;
            for (int j = 0; j < (int) b[i].size(); j++) {
                m -= b[i].back() - b[i][j];

                // don't forget!
                if (m < 0) {
                    break;
                }

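            }
            // also leave the outer loop: otherwise each remaining group can subtract
            // up to ~10^18 more and m may overflow back to a non-negative value
            if (m < 0) {
                break;
            }
        }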
        if (m < 0) {
            cout << 0 << '\n';
            continue;
        }
        mint ans = 0;
        if (cnt.size() == 1) {
            ans = C(m / cnt.begin()->first + cnt.begin()->second, cnt.begin()->second);
        } else {
            for (int i = 0; i * 1LL * cnt.begin()->first <= m; i++) {
                long long t = m - i * 1LL * cnt.begin()->first;
                ans += C(i + cnt.begin()->second - 1, cnt.begin()->second - 1) * C(t / cnt.rbegin()->first + cnt.rbegin()->second, cnt.rbegin()->second);
            }
        }
        cout << ans << '\n';
    }
    in.readEof();
    assert(max(sn, sm) <= 1e6);
    return 0;
}
Editorialist's code (Python)
mod, maxn = 998244353, 2 * 10**6 + 20
fac = [1]
for i in range(1, maxn):
    fac.append(i * fac[i-1] % mod)

ifac = fac[:]
ifac[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(maxn-1)):
    ifac[i] = ifac[i+1] * (i+1) % mod

def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * ifac[r] * ifac[n-r] % mod

def f(n, k): # number of integer solutions to a1 + a2 + ... + an = k, each ai >= 0
    if n == 0: return 1 if k == 0 else 0
    return C(n+k-1, k)

for _ in range(int(input())):
    n, m, k = map(int, input().split())
    a = list(map(int, input().split()))
    base_ops = 0
    for i in range(k):
        cur = a[i::k]
        base_ops += len(cur) * max(cur) - sum(cur)
    if base_ops > m:
        print(0)
        continue
    m -= base_ops

    # lo, hi: the two class sizes; ct: number of classes of size lo when K does not
    # divide N (0 otherwise, in which case all K classes are handled through dp)
    lo, hi, ct = n//k, (n+k-1)//k, (-n)%k
    ans = 0

    # dp[t] = number of ways (mod p) for the other k - ct classes to get total <= t
    dp = [f(k - ct, 0)]
    for i in range(1, m+1):
        dp.append((dp[-1] + f(k - ct, i)) % mod)

    for i in range(m // lo + 1):
        y = m - lo*i
        ans += f(ct, i) * dp[y // hi] % mod
    print(ans % mod)

I did the same as the editorial but still got WA. Please help me figure out where I am going wrong.
My solution link: Solution Link


You forgot to reduce the pref[j] values modulo.


Thank you, I don’t know how I missed it.
I think it’s the right time to start using a mint template.

BTW, thanks for such a beautiful editorial. Learned a lot!!
