Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093
Combinatorics, specifically stars-and-bars
From an array A of length N-1 that contains only ones and twos and has A_1 = 1, we construct a tree on N vertices as follows:
- For each 1 \leq i \leq N-1:
  - If A_i = 1, add an edge between i and i+1.
  - Otherwise, add an edge between i-1 and i+1.
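For small inputs, the construction can be simulated directly, which gives a brute-force reference to test against. The sketch below is only an illustration: `tree_height` is a hypothetical helper (not from the original codes), it takes A as a 0-indexed Python list (so `a[0]` is A_1), and it assumes the input is valid (ones and twos only, `a[0] == 1`). It adds the edges exactly as listed and runs a BFS from vertex 1.

```python
from collections import deque

def tree_height(a):
    """Build the tree described by a (a[0] is A_1) and return its height.

    Assumes len(a) == N - 1, only 1s and 2s, and a[0] == 1.
    """
    n = len(a) + 1
    adj = [[] for _ in range(n + 1)]  # vertices are 1..n
    for i in range(1, n):             # i runs over 1..N-1
        if a[i - 1] == 1:
            u, v = i, i + 1           # A_i = 1: edge (i, i+1)
        else:
            u, v = i - 1, i + 1       # A_i = 2: edge (i-1, i+1)
        adj[u].append(v)
        adj[v].append(u)
    # BFS from vertex 1 gives the distance of every vertex
    dist = [-1] * (n + 1)
    dist[1] = 0
    queue = deque([1])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return max(dist[1:])

print(tree_height([1, 1, 1]))     # → 3 (the path 1-2-3-4)
print(tree_height([1, 2, 2, 1]))  # → 3
```

For A = [1, 2, 2, 1] the edges are 1-2, 1-3, 2-4 and 4-5, so vertex 5 is at distance 3 from vertex 1.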
The height of a tree is the maximum distance of any vertex from vertex 1.
You’re given N and K. Across all arrays of length N-1 with first element 1 and exactly K ones (the other elements being twos), compute the sum of the heights of the constructed trees, modulo 998244353.
Rather than adding edges between vertices, we can think of the tree construction process as A_i telling us which existing vertex to attach i+1 to.
We start with a single vertex, 1.
- If A_i = 1, attach vertex i+1 to vertex i.
- Otherwise, attach i+1 to i-1 (this vertex always exists, since A_1 = 1 means a two can only appear at some i \geq 2).
If d_i denotes the distance of vertex i from vertex 1 (so d_1 = 0), we have:
- d_{i+1} = 1 + d_i in the first case.
- d_{i+1} = 1 + d_{i-1} in the second.
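The recurrence computes every d_i in one left-to-right pass, without storing any edges. A minimal sketch, where `distances` is a hypothetical name and the input convention (0-indexed list, `a[0] == 1`) is our own assumption:

```python
def distances(a):
    """Return d with d[i] = distance of vertex i from vertex 1 (d[0] unused)."""
    n = len(a) + 1
    d = [0] * (n + 1)                 # d[1] = 0: vertex 1 is the root
    for i in range(1, n):
        if a[i - 1] == 1:
            d[i + 1] = 1 + d[i]       # i+1 attached to i
        else:
            d[i + 1] = 1 + d[i - 1]   # i+1 attached to i-1 (i >= 2 here)
    return d

print(distances([1, 2, 2, 1])[1:])  # → [0, 1, 1, 2, 3]
```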
Observe that the array d is non-decreasing, i.e., d_i \leq d_{i+1} for every i.
We prove this by induction on i. If d_{i+1} = 1 + d_i, this is obvious. Otherwise d_{i+1} = 1 + d_{i-1}; since d_i equals either 1 + d_{i-1} or 1 + d_{i-2} \leq 1 + d_{i-1} (by the induction hypothesis), d_i cannot exceed d_{i-1} by more than 1, so d_i \leq 1 + d_{i-1} = d_{i+1}.
In particular, vertex N is always one of the vertices furthest from 1, so the height of the tree equals d_N.
That is, we want to find the sum of d_N across all arrays A.
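Both claims (d is non-decreasing, and the height equals d_N) are easy to check exhaustively for small N. The sketch below is a sanity check under our own conventions, not a proof: `check_monotone` and `max_n` are hypothetical names, and it enumerates every array of length N-1 with A_1 = 1 for N up to `max_n`.

```python
from itertools import product

def distances(a):
    """d[i] = distance of vertex i from vertex 1, via the recurrence."""
    n = len(a) + 1
    d = [0] * (n + 1)
    for i in range(1, n):
        d[i + 1] = 1 + (d[i] if a[i - 1] == 1 else d[i - 1])
    return d

def check_monotone(max_n):
    """True if d is non-decreasing and max(d) == d_N for every N <= max_n."""
    for n in range(2, max_n + 1):
        for tail in product((1, 2), repeat=n - 2):
            d = distances((1,) + tail)[1:]   # d_1..d_N
            if any(d[j] > d[j + 1] for j in range(len(d) - 1)):
                return False
            if max(d) != d[-1]:
                return False
    return True

print(check_monotone(12))  # → True
```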
Now, let’s analyze the structure of the tree that is formed.
There’ll be a prefix of ones, forming a chain of vertices.
Then we’ll have several twos, which give us two chains of approximately equal length (at any point, their lengths differ by at most 1).
Then we’ll have some ones, which end one of the chains and continue the other.
Next, we have some twos again; as before, the single chain splits into two of approximately equal length, and so on.
Since we only care about the distance between 1 and N, it helps to view the tree as a single long chain (the main path) connecting 1 to N, with several smaller chains hanging off some parts of it; nothing more complicated than that can occur.
There are N-1 edges in total, so the distance from 1 to N can be obtained by subtracting the total length of these smaller chains from N-1, since each edge in a smaller chain doesn’t contribute to the path between 1 and N.
From the construction process, recall that the smaller chains are formed by the contiguous occurrences of 2 in A.
In particular, if there’s a block of m twos, it will create \left\lceil \frac{m}{2} \right\rceil edges not on the main path.
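The block claim can also be checked by brute force. In the sketch below (hypothetical helpers `off_path_edges` and `predicted`, not from the original codes), each vertex i+1 stores the vertex it was attached to, so the path from 1 to N is recovered by following these parent pointers back from N. The number of edges not on that path is compared with the sum of \left\lceil \frac{m}{2} \right\rceil over maximal blocks of m twos, written `(m + 1) // 2` in integer arithmetic.

```python
from itertools import groupby, product

def off_path_edges(a):
    """Count tree edges that are not on the path from vertex 1 to vertex N."""
    n = len(a) + 1
    parent = [0] * (n + 1)
    for i in range(1, n):
        parent[i + 1] = i if a[i - 1] == 1 else i - 1
    # Walk from N back to the root, counting the edges on the path
    path_len, v = 0, n
    while v != 1:
        v = parent[v]
        path_len += 1
    return (n - 1) - path_len

def predicted(a):
    """Sum of ceil(m / 2) over maximal blocks of m consecutive twos."""
    total = 0
    for value, group in groupby(a):
        if value == 2:
            m = len(list(group))
            total += (m + 1) // 2
    return total

ok = all(
    off_path_edges((1,) + tail) == predicted((1,) + tail)
    for n in range(2, 13)
    for tail in product((1, 2), repeat=n - 2)
)
print(ok)  # → True
```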
Now, A contains exactly K ones, and since A_1 = 1, every two lies after some 1. So A splits into exactly K blocks of twos, one immediately after each 1 (some of these blocks may be empty, which is fine).
If the K blocks have sizes x_1, x_2, \ldots, x_K, the distance from 1 to N will then be
d_N = N - 1 - \sum_{i=1}^{K} \left\lceil \frac{x_i}{2} \right\rceil
Our task is thus to compute the sum of this quantity over all sequences of non-negative integers x_1, \ldots, x_K such that \sum_{i=1}^K x_i = N-1-K (each such sequence corresponds to exactly one array A).
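Before optimizing, this sum can be evaluated by brute force for small N and K, which is handy for testing a fast solution. The sketch below assumes 1 \leq K \leq N-1; `compositions`, `brute_answer` and `brute_by_arrays` are hypothetical helpers. The first two enumerate the block sizes x_i and apply d_N = N - 1 - \sum \left\lceil \frac{x_i}{2} \right\rceil; the third enumerates the arrays A themselves and reads off d_N from the recurrence, so agreement also checks the correspondence between arrays and block sizes.

```python
from itertools import product

def compositions(total, parts):
    """Yield every tuple of `parts` non-negative ints summing to `total`."""
    if parts == 1:
        yield (total,)
        return
    for first in range(total + 1):
        for rest in compositions(total - first, parts - 1):
            yield (first,) + rest

def brute_answer(n, k):
    """Sum of N-1-sum(ceil(x_i/2)) over all block-size sequences (exact)."""
    return sum(
        (n - 1) - sum((x + 1) // 2 for x in xs)
        for xs in compositions(n - 1 - k, k)
    )

def brute_by_arrays(n, k):
    """Same quantity, enumerating arrays A with A_1 = 1 and exactly k ones."""
    total = 0
    for tail in product((1, 2), repeat=n - 2):
        a = (1,) + tail
        if a.count(1) != k:
            continue
        d = [0] * (n + 1)
        for i in range(1, n):
            d[i + 1] = 1 + (d[i] if a[i - 1] == 1 else d[i - 1])
        total += d[n]
    return total

print(brute_answer(4, 2), brute_by_arrays(4, 2))  # → 4 4
```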
Dealing with the \left\lceil \frac{x_i}{2} \right\rceil term is annoying since it depends on whether x_i is odd or even.
Let’s fix the number of x_i that are odd; suppose r of them are odd.
There are \binom{K}{r} ways to choose which r are odd.
Let x_i = 2y_i + 1 for the odd x_i, and x_i = 2y_i for the even ones.
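Plugging this into
\sum_{i=1}^{K} x_i = N-1-K
gives us
2\sum_{i=1}^{K} y_i + r = N-1-K, \quad\text{i.e.,}\quad 2\sum_{i=1}^{K} y_i = N-1-K-r, \quad y_i \geq 0.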
For a fixed choice of which r of the x_i are odd, counting the sequences x_i is equivalent to counting the sequences of non-negative integers y_i satisfying this equation.
If the right side is odd (or negative), no solution exists; otherwise, by stars-and-bars, the number of solutions is
\binom{\frac{N-1-K-r}{2} + K - 1}{K - 1}
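The counting step can be sanity-checked by enumeration. In this sketch (hypothetical helper names; `total` plays the role of N-1-K), the number of block-size sequences with exactly r odd entries is compared with \binom{K}{r} times the stars-and-bars count, taking 0 whenever N-1-K-r is odd or negative. It uses Python's exact `math.comb`, so no modulus is involved.

```python
from math import comb

def compositions(total, parts):
    """Yield every tuple of `parts` non-negative ints summing to `total`."""
    if parts == 1:
        yield (total,)
        return
    for first in range(total + 1):
        for rest in compositions(total - first, parts - 1):
            yield (first,) + rest

def count_by_enumeration(total, k, r):
    """Number of (x_1..x_k), sum = total, with exactly r odd entries."""
    return sum(1 for xs in compositions(total, k)
               if sum(x % 2 for x in xs) == r)

def count_by_formula(total, k, r):
    """C(k, r) * C(s + k - 1, k - 1) with s = (total - r) / 2, or 0."""
    if total - r < 0 or (total - r) % 2:
        return 0
    s = (total - r) // 2
    return comb(k, r) * comb(s + k - 1, k - 1)

print(count_by_enumeration(2, 2, 0), count_by_formula(2, 2, 0))  # → 2 2
```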
Further, note that \left\lceil \frac{x_i}{2} \right\rceil is y_i when x_i is even, and y_i + 1 otherwise.
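That means, for this fixed r with N-1-K-r even and non-negative, write S = \frac{N-1-K-r}{2}. Every valid sequence has \sum_{i=1}^{K} \left\lceil \frac{x_i}{2} \right\rceil = \sum_{i=1}^{K} y_i + r = S + r, hence d_N = N-1-r-S (which simplifies to K+S). So the contribution to the answer is
\binom{K}{r} \binom{S + K - 1}{K - 1} \left(N - 1 - r - S\right)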
which can be computed in constant time with precomputed factorials and inverse factorials.
Sum this up across all r from 0 to K to obtain the overall answer.
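Putting everything together, here is a sketch of the \mathcal{O}(K) summation in Python, modeled on the editorialist's approach but not identical to it: `sum_of_heights`, `binom` and `LIMIT` are our own names, and `LIMIT = 2000` is an assumption that only suits small tests (the reference codes precompute factorials up to 5 \cdot 10^5). Inverse factorials come from a single Fermat inverse, which relies on 998244353 being prime.

```python
MOD = 998244353
LIMIT = 2000  # assumption: large enough for this sketch only

# Factorials and inverse factorials modulo MOD
fact = [1] * (LIMIT + 1)
for i in range(1, LIMIT + 1):
    fact[i] = fact[i - 1] * i % MOD
inv_fact = [1] * (LIMIT + 1)
inv_fact[LIMIT] = pow(fact[LIMIT], MOD - 2, MOD)  # Fermat inverse
for i in range(LIMIT, 0, -1):
    inv_fact[i - 1] = inv_fact[i] * i % MOD

def binom(n, r):
    """C(n, r) modulo MOD, or 0 when r is out of range."""
    if r < 0 or n < r:
        return 0
    return fact[n] * inv_fact[r] % MOD * inv_fact[n - r] % MOD

def sum_of_heights(n, k):
    """Sum of tree heights over all arrays with exactly k ones, modulo MOD."""
    twos = n - 1 - k                  # N-1-K, the sum of the x_i
    ans = 0
    for r in range(k + 1):            # r = number of odd blocks
        if twos - r < 0 or (twos - r) % 2:
            continue                  # no valid y_i for this r
        s = (twos - r) // 2           # S, the sum of the y_i
        ways = binom(k, r) * binom(s + k - 1, k - 1) % MOD
        ans = (ans + ways * (n - 1 - r - s)) % MOD
    return ans

print(sum_of_heights(5, 2))  # → 8
```

For N = 5, K = 2 the three arrays [1,1,2,2], [1,2,1,2] and [1,2,2,1] have heights 3, 2 and 3, matching the printed 8.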
Time complexity: \mathcal{O}(N) per testcase (the loop over r has K+1 \leq N iterations), after precomputing factorials once.
Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
const int MOD = 998244353;
const int N = 5e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
ll fact[N], ifact[N];
ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}
ll invmod(ll a) {
    return bexp(a, MOD - 2);
}
ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}
ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}
void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;
    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;
}
void solve(int test_case){
    ll n,k; cin >> n >> k;
    ll ans = 0;
    // c = sum of the y_i; then exactly n-1-k-2c blocks of twos have odd length
    for (ll c = 0; 2 * c <= n - 1 - k; ++c) {
        // choose the odd blocks, then spread c among the k blocks (stars and bars)
        ll ways = ncr(c+k-1,c)*ncr(k,n-1-k-2*c)%MOD;
        // every such array has height d_N = k + c
        ll val = c+k;
        ans += ways*val;
        ans %= MOD;
    }
    cout << ans << endl;
}
int main()
{
    fastio;
    precalc(N - 1);
    int t = 1;
    cin >> t;
    rep1(i, t) {
        solve(i);
    }
    return 0;
}
Editorialist's code (PyPy3)
mod = 998244353
maxN = 500005
fac = [1]
for n in range(1, maxN): fac.append(fac[-1] * n % mod)
inv = fac[:]
inv[-1] = pow(inv[-1], mod-2, mod)
for n in reversed(range(maxN-1)): inv[n] = inv[n+1] * (n+1) % mod
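def C(n, r):
    # binomial coefficient modulo mod; 0 when r is out of range
    if n < r or r < 0: return 0
    return fac[n] * inv[r] * inv[n-r] % mod
for _ in range(int(input())):
    n, k = map(int, input().split())
    ans = 0
    for odd in range(k+1):  # odd = number of blocks of twos with odd length
        if (n-1-k-odd)%2: continue
        # (n-1-k-odd)//2 is the sum of the y_i (C returns 0 if it is negative);
        # every array in this group has height n-1-odd-(n-1-k-odd)//2
        ans += (n-1-odd - (n-1-k-odd)//2) * C(k, odd) * C((n-1-k-odd)//2 + k - 1, k - 1) % mod
    print(ans % mod)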