# KBEAUTIFUL - Editorial

Testers: tabr, iceknight1093
Editorialist: iceknight1093

TBD

Combinatorics

# PROBLEM:

You have an array A of N integers. It’s called K-beautiful if every subarray of size K has the same sum.

You can perform at most M moves, each one increasing some A_i by 1.
How many distinct K-beautiful final arrays can be obtained?

# EXPLANATION:

Suppose A was a K-beautiful array.
Then, the sum (A_i + A_{i+1} + \ldots + A_{i+K-1}) should be constant across all i.
In particular, this means (A_i + A_{i+1} + \ldots + A_{i+K-1}) = (A_{i+1} + A_{i+2} + \ldots + A_{i+K}), which tells us that A_i = A_{i+K}.

That is, if we fix the first K elements of A, the remaining values are fixed.
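As a quick sanity check of this equivalence, the sketch below compares the window-sum definition against the periodicity condition A_i = A_{i+K}. The function names `is_k_beautiful` and `is_k_periodic` are illustrative (they do not come from the problem), and indices are 0-based.

```python
def is_k_beautiful(a, k):
    """True if every length-k window of a has the same sum (direct definition)."""
    window_sums = {sum(a[i:i + k]) for i in range(len(a) - k + 1)}
    return len(window_sums) <= 1


def is_k_periodic(a, k):
    """True if a[i] == a[i + k] wherever both indices exist."""
    return all(a[i] == a[i + k] for i in range(len(a) - k))


print(is_k_beautiful([1, 2, 3, 1, 2], 3), is_k_periodic([1, 2, 3, 1, 2], 3))
print(is_k_beautiful([1, 2, 3, 1, 3], 3), is_k_periodic([1, 2, 3, 1, 3], 3))
```

The first array passes both checks and the second fails both, as the argument above predicts.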

First however, we need to make everything equal to a baseline.
For each 1 \leq i \leq K, let M_i = \max(A_i, A_{i+K}, A_{i+2K}, \ldots).
Since values can only increase, we first need to bring all of these values up to M_i, which takes (M_i - A_i) + (M_i - A_{i+K}) + (M_i - A_{i+2K}) + \ldots operations.
Compute this for each i and sum up the number of operations needed.

If this baseline number of operations is larger than M, then the answer is immediately 0 since no K-beautiful array can be constructed. Otherwise, subtract these operations from M; from now on, M denotes the number of ‘free’ operations that remain.
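The baseline cost can be sketched as follows: every element is raised to the maximum of its residue class modulo K, which costs the class maximum minus the element. The name `baseline_ops` is illustrative, the array is 0-indexed, and the sketch assumes 1 \leq K \leq N.

```python
def baseline_ops(a, k):
    """Operations needed to raise each residue class (mod k) to its maximum."""
    total = 0
    for i in range(k):
        cls = a[i::k]  # elements at positions i, i + k, i + 2k, ...
        total += len(cls) * max(cls) - sum(cls)
    return total


a, k, m = [1, 3, 2, 2, 5], 2, 10
cost = baseline_ops(a, k)
# Classes are [1, 2, 5] (cost 7) and [3, 2] (cost 1).
print(cost)      # 8
print(m - cost)  # 2 free operations remain
```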

We are now ready to actually solve the problem.
Note that the final array is entirely determined by the (final) values of A_1, A_2, \ldots, A_K.
We know their current values, so the only thing that matters is how much we add to each index.

So, suppose we add x_i to index i.
Each time we add 1 to index i, we also need to add 1 to indices i+K, i+2K, i+3K, \ldots in order to preserve the K-beauty condition.
In particular, suppose there are c_i such indices (including i).

The total number of operations we use is then c_1x_1 + c_2x_2 + \ldots + c_Kx_K.

Our objective is thus to count the number of solutions to

c_1x_1 + c_2x_2 + \ldots + c_Kx_K \leq M

such that each x_i is \geq 0.
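For checking a faster solution on small inputs, this count can also be obtained with a simple knapsack-style DP in \mathcal{O}(KM) time. This is not the method of this editorial, only a reference sketch; `count_solutions_slow` is a hypothetical name, and the counts are exact integers rather than residues.

```python
def count_solutions_slow(c, m):
    """Count tuples (x_1, ..., x_K) of non-negative integers with sum(c_i * x_i) <= m."""
    ways = [1] + [0] * m  # ways[t] = number of tuples using exactly t operations
    for ci in c:
        # A forward in-place pass allows x_i to take any non-negative value.
        for total in range(ci, m + 1):
            ways[total] += ways[total - ci]
    return sum(ways)


# 3*x_1 + 2*x_2 <= 4 has the solutions (0,0), (0,1), (0,2), (1,0).
print(count_solutions_slow([3, 2], 4))  # 4
```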

Let’s fix m (0 \leq m \leq M) and count the number of solutions to c_1x_1 + c_2x_2 + \ldots + c_Kx_K = m.
This should remind you of the classical stars-and-bars problem, but it isn’t quite in that form.

First, note that the c_i aren’t arbitrary integers: in fact, each one of them is either \left\lceil \frac{N}{K} \right\rceil or \left\lfloor \frac{N}{K} \right\rfloor, depending on i.

For now, let’s assume K doesn’t divide N so those two numbers are different.
Let y = \left\lfloor \frac{N}{K} \right\rfloor, and suppose r of the c_i are equal to y, so that the other K - r are equal to y+1.
Both y and r are constants independent of m.
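A small sketch of how y and r can be read off from N and K. The helper name `split_classes` is an assumption for illustration; indices are 0-based and 1 \leq K \leq N is assumed.

```python
def split_classes(n, k):
    """Return (y, r, sizes): y = n // k, r = number of classes of size y."""
    sizes = [len(range(i, n, k)) for i in range(k)]  # c_i for each residue i
    y = n // k
    r = sum(1 for s in sizes if s == y)
    return y, r, sizes


print(split_classes(7, 3))  # (2, 2, [3, 2, 2])
print(split_classes(5, 2))  # (2, 1, [3, 2])
```

In both examples the first N mod K residues get the larger size y+1, so r = K - (N mod K).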

Relabel the indices so that c_1 = \ldots = c_r = y and c_{r+1} = \ldots = c_K = y+1. The equation

c_1x_1 + c_2x_2 + \ldots + c_Kx_K = m

then becomes

y\cdot (x_1 + x_2 + \ldots + x_r) + (y+1)\cdot (x_{r+1} + x_{r+2} + \ldots + x_K) = m

Now, notice that (x_1 + x_2 + \ldots + x_r) and (x_{r+1} + x_{r+2} + \ldots + x_K) are two essentially independent summations.

Recall that we fixed m. Let’s also fix m_1 = x_1 + x_2 + \ldots + x_r.
Notice that this uniquely fixes the value of x_{r+1} + x_{r+2} + \ldots + x_K, say to m_2.

The equations for m_1 and m_2 are in the stars-and-bars form, and so the number of solutions to each one can be computed in \mathcal{O}(1) time once factorials and inverse factorials are precomputed.
The respective counts are \displaystyle\binom{m_1+r-1}{m_1} and \displaystyle\binom{m_2+K-r-1}{m_2}.

Simply multiply these two numbers together to obtain the number of solutions for these fixed values of m and m_1.

Iterating over each (m, m_1) pair gives us a solution in \mathcal{O}(M^2), which is too slow.
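A direct transcription of this quadratic approach, as a sketch: it uses exact integers via `math.comb` instead of modular arithmetic, and the name `count_quadratic` and its parameters are illustrative. Here `big` stands for K - r, and both r and `big` are assumed to be at least 1 (that is, K does not divide N).

```python
from math import comb


def count_quadratic(y, r, big, limit):
    """Count solutions by iterating over every (m, m_1) pair with m <= limit."""
    total = 0
    for m in range(limit + 1):
        for m1 in range(m // y + 1):
            rest = m - y * m1
            if rest % (y + 1) != 0:
                continue  # m_2 would not be an integer
            m2 = rest // (y + 1)
            # Stars and bars for each group of variables.
            total += comb(m1 + r - 1, m1) * comb(m2 + big - 1, m2)
    return total


# N = 5, K = 2: y = 2, one class of size 2, one of size 3, 4 free operations.
print(count_quadratic(2, 1, 1, 4))  # 4
```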

Note that if K divides N then every c_i equals y, so r = K and fixing m also fixes m_1 = \frac{m}{y} (with no solutions unless y divides m; m_2 doesn’t even come into the picture), so this case can be solved in \mathcal{O}(M) time already.
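For this divisible case, a sketch of the \mathcal{O}(M) summation next to a closed form. By the hockey-stick identity, the sum of \binom{s+K-1}{s} over 0 \leq s \leq t equals \binom{t+K}{K}, which appears to be the expression the tester's code evaluates when all classes have the same size. The function names are illustrative and the values are exact integers, not residues.

```python
from math import comb


def count_divisible_sum(y, k, limit):
    """Sum the stars-and-bars counts over every m_1 with y * m_1 <= limit."""
    return sum(comb(s + k - 1, s) for s in range(limit // y + 1))


def count_divisible_closed(y, k, limit):
    """Same count via the hockey-stick identity."""
    return comb(limit // y + k, k)


print(count_divisible_sum(2, 2, 5), count_divisible_closed(2, 2, 5))  # 6 6
```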

### Speeding up

Notice that fixing any two of m, m_1, m_2 fixes the third one uniquely.
So, let’s fix just m_1 and see what we get.

If we also fix m_2, the resulting value of m must be \leq M.
In particular, m_2 must satisfy the inequality y\cdot m_1 + (y+1)\cdot m_2 \leq M.

This means that the valid values of m_2 form a consecutive range of integers, from 0 up to \left\lfloor \frac{M - y\cdot m_1}{y+1} \right\rfloor.

Note that the first factor, \binom{m_1+r-1}{m_1}, is fixed since m_1 is fixed, so we just need the sum of \binom{m_2+K-r-1}{m_2} across all valid m_2.
This isn’t too hard: since the m_2 form a contiguous range starting from 0, we can create an array V where V_i = \binom{i+K-r-1}{i} and take its prefix sums.

This allows us to process a single m_1 in \mathcal{O}(1) time, so iterating m_1 from 0 to \left\lfloor \frac{M}{y} \right\rfloor solves the problem in \mathcal{O}(N+M) time.
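The prefix-sum idea can be sketched as below. This is a simplified illustration with exact integers via `math.comb`; a real submission would work modulo 998244353 with precomputed factorials, and would reduce the prefix sums as well. The names `count_fast` and `free_ops` (the operations left after the baseline step) are assumptions, and 1 \leq K \leq N is assumed.

```python
from math import comb


def stars_and_bars(parts, total):
    """Number of ways to write total as an ordered sum of `parts` non-negative integers."""
    if parts == 0:
        return 1 if total == 0 else 0
    return comb(total + parts - 1, total)


def count_fast(n, k, free_ops):
    y = n // k
    big = n % k  # classes of size y + 1
    r = k - big  # classes of size y
    # prefix[j] = sum over m_2 = 0..j of the ways to split m_2 among the big classes
    prefix, running = [], 0
    for m2 in range(free_ops // (y + 1) + 1):
        running += stars_and_bars(big, m2)
        prefix.append(running)
    answer = 0
    for m1 in range(free_ops // y + 1):
        remaining = free_ops - y * m1
        answer += stars_and_bars(r, m1) * prefix[remaining // (y + 1)]
    return answer


print(count_fast(5, 2, 4))  # 4
```

When K divides N there are no big classes, so every prefix value is 1 and the loop reduces to the single summation over m_1.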

# TIME COMPLEXITY

\mathcal{O}(N + M) per testcase.

# CODE:

Setter's code (C++)
#include <bits/stdc++.h>
using namespace std;
const int MXN=2000010;
const long long MOD=998244353,INF=1000000000;
long long f[MXN],inv[MXN],finv[MXN];
void Initialize(){
    f[0]=f[1]=inv[0]=inv[1]=finv[0]=finv[1]=1;
    for (int i=2; i<MXN; i++){
        f[i]=f[i-1]*i%MOD;
        inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
        finv[i]=finv[i-1]*inv[i]%MOD;
    }
}
long long nCr(int n,int r){
    if (n<r) return 0LL;
    return f[n]*finv[r]%MOD*finv[n-r]%MOD;
}
long long nHr(int n,int r){
    return nCr(n+r-1,r);
}
void solve(){
    int n,m,k;
    cin>>n>>m>>k;
    long long a[n+1];
    for (int i=1; i<=n; i++) cin>>a[i];
    long long mx[k+1];
    for (int i=1; i<=k; i++){
        mx[i]=0;
        for (int j=i; j<=n; j+=k){
            mx[i]=max(mx[i],a[j]);
        }
        for (int j=i; j<=n; j+=k){
            m-=min(INF,mx[i]-a[j]);
            if (m<0){
                cout<<"0\n";
                return;
            }
        }
    }
    long long sum[m/(n/k)+1];
    for (int i=0; i<=m/(n/k); i++){
        if (i>0) sum[i]=sum[i-1];
        else sum[i]=0;
        sum[i]+=nHr(k-n%k,i);
        sum[i]%=MOD;
    }
    long long ans=0;
    if (n%k==0){
        int mxCnt=m/(n/k);
        cout<<sum[mxCnt]<<'\n';
        return;
    }
    for (int cntBig=0; cntBig<=m/((n+k-1)/k); cntBig++){
        int mxSmall=(m-(n+k-1)/k*cntBig)/(n/k);
        ans+=nHr(n%k,cntBig)*sum[mxSmall];
        ans%=MOD;
    }
    cout<<ans<<'\n';
}
int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    int T=1;
    cin>>T;
    Initialize();
    while (T--) solve();
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
string buffer;
int pos;

const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
const string number = "0123456789";
const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
const string lower = "abcdefghijklmnopqrstuvwxyz";

input_checker() {
pos = 0;
while (true) {
int c = cin.get();
if (c == -1) {
break;
}
buffer.push_back((char) c);
}
}

string readOne() {
assert(pos < (int) buffer.size());
string res;
while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
res += buffer[pos];
pos++;
}
return res;
}

string readString(int min_len, int max_len, const string& pattern = "") {
assert(min_len <= max_len);
string res = readOne();
assert(min_len <= (int) res.size());
assert((int) res.size() <= max_len);
for (int i = 0; i < (int) res.size(); i++) {
assert(pattern.empty() || pattern.find(res[i]) != string::npos);
}
return res;
}

int readInt(int min_val, int max_val) {
assert(min_val <= max_val);
int res = stoi(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
}

long long readLong(long long min_val, long long max_val) {
assert(min_val <= max_val);
long long res = stoll(readOne());
assert(min_val <= res);
assert(res <= max_val);
return res;
}

vector<int> readInts(int size, int min_val, int max_val) {
assert(min_val <= max_val);
vector<int> res(size);
for (int i = 0; i < size; i++) {
res[i] = readInt(min_val, max_val);
if (i != size - 1) {
readSpace();
}
}
return res;
}

vector<long long> readLongs(int size, long long min_val, long long max_val) {
assert(min_val <= max_val);
vector<long long> res(size);
for (int i = 0; i < size; i++) {
res[i] = readLong(min_val, max_val);
if (i != size - 1) {
readSpace();
}
}
return res;
}

void readSpace() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == ' ');
pos++;
}

void readEoln() {
assert((int) buffer.size() > pos);
assert(buffer[pos] == '\n');
pos++;
}

void readEof() {
assert((int) buffer.size() == pos);
}
};

template <long long mod>
struct modular {
long long value;
modular(long long x = 0) {
value = x % mod;
if (value < 0) value += mod;
}
modular& operator+=(const modular& other) {
if ((value += other.value) >= mod) value -= mod;
return *this;
}
modular& operator-=(const modular& other) {
if ((value -= other.value) < 0) value += mod;
return *this;
}
modular& operator*=(const modular& other) {
value = value * other.value % mod;
return *this;
}
modular& operator/=(const modular& other) {
long long a = 0, b = 1, c = other.value, m = mod;
while (c != 0) {
long long t = m / c;
m -= t * c;
swap(c, m);
a -= t * b;
swap(a, b);
}
a %= mod;
if (a < 0) a += mod;
value = value * a % mod;
return *this;
}
friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
modular& operator++() { return *this += 1; }
modular& operator--() { return *this -= 1; }
modular operator++(int) {
modular res(*this);
*this += 1;
return res;
}
modular operator--(int) {
modular res(*this);
*this -= 1;
return res;
}
modular operator-() const { return modular(-value); }
bool operator==(const modular& rhs) const { return value == rhs.value; }
bool operator!=(const modular& rhs) const { return value != rhs.value; }
bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
stream >> x.value;
x.value %= mod;
if (x.value < 0) x.value += mod;
return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
mint res = 1;
while (n > 0) {
if (n & 1) {
res *= a;
}
a *= a;
n >>= 1;
}
return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
if (n < k || k < 0) {
return mint(0);
}
while ((int) fact.size() < n + 1) {
fact.emplace_back(fact.back() * (int) fact.size());
finv.emplace_back(mint(1) / fact.back());
}
return fact[n] * finv[k] * finv[n - k];
}

int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
input_checker in;
int sn = 0, sm = 0;
while (tt--) {
long long m = in.readInt(1, 1e6);
sn += n;
sm += m;
vector<long long> a = in.readLongs(n, 1, 1e18);
vector<vector<long long>> b(k);
for (int i = 0; i < n; i++) {
b[i % k].emplace_back(a[i]);
}
map<int, int> cnt;
for (int i = 0; i < k; i++) {
sort(b[i].begin(), b[i].end());
cnt[(int) b[i].size()]++;
for (int j = 0; j < (int) b[i].size(); j++) {
m -= b[i].back() - b[i][j];

// don't forget!
if (m < 0) {
break;
}

}
}
if (m < 0) {
cout << 0 << '\n';
continue;
}
mint ans = 0;
if (cnt.size() == 1) {
ans = C(m / cnt.begin()->first + cnt.begin()->second, cnt.begin()->second);
} else {
for (int i = 0; i * 1LL * cnt.begin()->first <= m; i++) {
long long t = m - i * 1LL * cnt.begin()->first;
ans += C(i + cnt.begin()->second - 1, cnt.begin()->second - 1) * C(t / cnt.rbegin()->first + cnt.rbegin()->second, cnt.rbegin()->second);
}
}
cout << ans << '\n';
}
assert(max(sn, sm) <= 1e6);
return 0;
}

Editorialist's code (Python)
mod, maxn = 998244353, 2 * 10**6 + 20
fac = [1]
for i in range(1, maxn):
    fac.append(i * fac[i-1] % mod)

ifac = fac[:]
ifac[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(maxn-1)):
    ifac[i] = ifac[i+1] * (i+1) % mod

def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * ifac[r] * ifac[n-r] % mod

def f(n, k): # number of integer solutions to a1 + a2 + ... + an = k, each ai >= 0
    if n == 0: return 1 if k == 0 else 0
    return C(n+k-1, k)

for _ in range(int(input())):
    n, m, k = map(int, input().split())
    a = list(map(int, input().split()))
    base_ops = 0
    for i in range(k):
        cur = a[i::k]
        base_ops += len(cur) * max(cur) - sum(cur)
    if base_ops > m:
        print(0)
        continue
    m -= base_ops

    lo, hi, ct = n//k, (n+k-1)//k, (-n)%k
    ans = 0

    dp = [f(k - ct, 0)]
    for i in range(1, m+1):
        dp.append(dp[-1] + f(k - ct, i))

    for i in range(m // lo + 1):
        y = m - lo*i
        ans += f(ct, i) * dp[y // hi] % mod
    print(ans % mod)


I did the same as the editorial but still got WA. Please help me figure out where I am going wrong.
You forgot to reduce the pref[j] values modulo 998244353.