# PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Contest: Division 3

Contest: Division 4

**Author:** sushil2006

**Tester:** iceknight1093

**Editorialist:**

# DIFFICULTY:

Easy-Medium

# PREREQUISITES:

Combinatorics, specifically stars-and-bars

# PROBLEM:

From an array A of length N-1 that contains only ones and twos and has A_1 = 1, we construct a tree on N vertices as follows:

- For each 1 \leq i \leq N-1:
  - If A_i = 1, add an edge between i and i+1.
  - Otherwise, add an edge between i-1 and i+1.
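As a hedged illustration, the construction can be transcribed literally into Python. This sketch uses its own conventions, which are assumptions and not part of the original: `a` is a 0-indexed list holding A_1 through A_{N-1}, vertices are numbered 1 to N, and the helper names `build_tree` and `height` are hypothetical. The height is then found with a plain BFS from vertex 1.

```python
from collections import deque


def build_tree(a):
    """Adjacency lists of the tree built from a = [A_1, ..., A_{N-1}] (a[0] == 1)."""
    n = len(a) + 1
    adj = [[] for _ in range(n + 1)]  # vertices 1..N; index 0 is unused
    for i in range(1, n):  # i runs over 1..N-1, and A_i is a[i - 1]
        u = i if a[i - 1] == 1 else i - 1
        adj[u].append(i + 1)
        adj[i + 1].append(u)
    return adj


def height(a):
    """Maximum distance from vertex 1, found by BFS over the built tree."""
    adj = build_tree(a)
    dist = [-1] * len(adj)
    dist[1] = 0
    queue = deque([1])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return max(dist[1:])


print(height([1, 2, 2, 2]))  # edges 1-2, 1-3, 2-4, 3-5, so this prints 2
```

This is exponential to use across all arrays, so it only serves as a reference for small cases.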

The height of a tree is the maximum distance of any vertex from vertex 1.

You’re given N and K. Across all arrays of length N-1 that contain exactly K ones (the other elements being twos) and whose first element is 1, compute the sum of heights of the constructed trees, modulo 998244353.

# EXPLANATION:

Rather than adding edges between vertices, we can think of the tree construction process as A_i telling us which existing vertex to attach i+1 to.

We start with a single vertex, 1.

Then,

- If A_i = 1, attach vertex i+1 to vertex i.
- Otherwise, attach i+1 to i-1.

If d_i denotes the distance of vertex i from 1, we have:

- d_{i+1} = 1 + d_i in the first case.
- d_{i+1} = 1 + d_{i-1} in the second.

Observe that the array d is *non-decreasing*, i.e., d_i \leq d_{i+1} for every i.

If d_{i+1} = 1 + d_i this is obvious. Otherwise d_{i+1} = 1 + d_{i-1}; vertex i is attached to either i-1 or i-2, and d_{i-2} \leq d_{i-1} by induction, so d_i cannot exceed d_{i-1} by more than 1, which means it can’t be larger than d_{i+1} either.
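The recurrence for d, and the monotonicity claim, can be checked exhaustively for small N. This is a sketch under the assumption that `a` is a 0-indexed Python list of A_1 through A_{N-1} starting with 1; `distances` and `check_non_decreasing` are hypothetical helper names.

```python
from itertools import product


def distances(a):
    """d[1..N] from the recurrence; d[0] is an unused placeholder.

    a = [A_1, ..., A_{N-1}] as a 0-indexed list, assumed to start with 1.
    """
    n = len(a) + 1
    d = [0] * (n + 1)
    for i in range(1, n):
        parent = i if a[i - 1] == 1 else i - 1  # vertex that i+1 is attached to
        d[i + 1] = d[parent] + 1
    return d


def check_non_decreasing(max_n):
    """Exhaustively test d_i <= d_{i+1} for every valid array with 2 <= N <= max_n."""
    for n in range(2, max_n + 1):
        for rest in product((1, 2), repeat=n - 2):
            d = distances([1, *rest])
            if any(d[i] > d[i + 1] for i in range(1, n)):
                return False
    return True


print(check_non_decreasing(12))  # True
```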

In particular, this means that vertex N is always among the furthest vertices from 1, so the height is exactly d_N.

That is, we want to find the sum of d_N across all arrays A.
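A brute force over all valid arrays gives a reference value to test the closed form against. This is a hedged sketch (`brute_force` is a hypothetical name, and it assumes k \geq 1, which holds since A_1 = 1); it is exponential in N and returns the exact sum without any modulus.

```python
from itertools import combinations


def brute_force(n, k):
    """Sum of d_N over all arrays of length n-1 with exactly k ones and A_1 = 1.

    Exponential in n; intended only as a reference for small inputs (k >= 1).
    """
    total = 0
    # A_1 is forced to 1, so choose where the other k-1 ones go among A_2..A_{N-1}.
    for ones in combinations(range(1, n - 1), k - 1):
        a = [2] * (n - 1)
        a[0] = 1
        for p in ones:
            a[p] = 1
        d = [0] * (n + 1)
        for i in range(1, n):
            d[i + 1] = d[i if a[i - 1] == 1 else i - 1] + 1
        total += d[n]  # the height equals d_N
    return total


print(brute_force(5, 2))  # arrays 1122, 1212, 1221 give 3 + 2 + 3 = 8
```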

Now, let’s analyze the structure of the tree that is formed.

There’ll be a prefix of ones, forming a chain of vertices.

Then, we’ll have several twos - which will give us two chains of approximately equal length (at any point, their lengths will differ by at most 1).

Then, we’ll have some ones - this will end one of the chains and continue the other one.

Next, we have some twos again - as before, the single chain will split into two of approximately equal length, and so on and so forth.

Since we care about the distance between 1 and N, from that perspective the tree will look like a single long chain connecting 1 to N, and then there will be several smaller chains hanging off some parts of this long chain; but nothing more complicated than that.

There are N-1 edges in total, so the distance from 1 to N can be obtained by subtracting the total length of these smaller chains from N-1, since each edge in a smaller chain doesn’t contribute to the path between 1 and N.

From the construction process, recall that the smaller chains are formed by the contiguous occurrences of 2 in A.

In particular, if there’s a block of m twos, it will create \left\lceil \frac{m}{2} \right\rceil edges not on the main path.
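The \left\lceil \frac{m}{2} \right\rceil claim can be sanity-checked directly: take a prefix of ones followed by a single block of m twos, and count how many of the N-1 edges are not on the path from 1 to N. In this sketch, `last_distance` and `off_path_edges` are hypothetical helpers, and the ceiling is computed in integers as `(m + 1) // 2`.

```python
def last_distance(a):
    """d_N for a = [A_1, ..., A_{N-1}] (0-indexed list with a[0] == 1)."""
    n = len(a) + 1
    d = [0] * (n + 1)
    for i in range(1, n):
        d[i + 1] = d[i if a[i - 1] == 1 else i - 1] + 1
    return d[n]


def off_path_edges(prefix_ones, m):
    """Edges not on the 1..N path for A = [1]*prefix_ones + [2]*m."""
    a = [1] * prefix_ones + [2] * m
    return len(a) - last_distance(a)  # len(a) = N-1 edges in total


for p in range(1, 4):
    for m in range(10):
        assert off_path_edges(p, m) == (m + 1) // 2  # integer ceil(m / 2)

print(off_path_edges(2, 3))  # 2
```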

Now, we know that A must contain exactly K ones. Since A_1 = 1, every 2 comes after some 1, so the twos form exactly K blocks, one right after each 1 (some of these blocks may be empty, that’s ok).

If the K blocks have sizes x_1, x_2, \ldots, x_K, the distance from 1 to N will then be

$$N - 1 - \sum_{i=1}^K \left\lceil \frac{x_i}{2} \right\rceil$$

Our task is thus to compute the sum of this quantity across all sequences of non-negative integers x_1, x_2, \ldots, x_K such that \sum_{i=1}^K x_i = N-1-K.
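To gain some confidence in the expression d_N = N - 1 - \sum_i \left\lceil \frac{x_i}{2} \right\rceil, here is a hedged exhaustive check for small N. The helpers `block_sizes`, `block_formula` and `check_block_formula` are hypothetical; `block_sizes` splits the array into the K (possibly empty) blocks of twos that follow each 1.

```python
from itertools import product


def last_distance(a):
    """d_N via the recurrence, for a = [A_1, ..., A_{N-1}] with a[0] == 1."""
    n = len(a) + 1
    d = [0] * (n + 1)
    for i in range(1, n):
        d[i + 1] = d[i if a[i - 1] == 1 else i - 1] + 1
    return d[n]


def block_sizes(a):
    """Sizes x_1..x_K of the (possibly empty) blocks of twos after each 1."""
    sizes = []
    for v in a:
        if v == 1:
            sizes.append(0)  # a new, initially empty block starts after this 1
        else:
            sizes[-1] += 1   # safe because a[0] == 1
    return sizes


def block_formula(a):
    """N - 1 - sum of ceil(x_i / 2), with ceil computed as (x + 1) // 2."""
    return len(a) - sum((x + 1) // 2 for x in block_sizes(a))


def check_block_formula(max_n):
    for n in range(2, max_n + 1):
        for rest in product((1, 2), repeat=n - 2):
            a = [1, *rest]
            if block_formula(a) != last_distance(a):
                return False
    return True


print(block_sizes([1, 2, 1, 1, 2, 2]))  # [1, 0, 2]
print(check_block_formula(12))          # True
```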

Dealing with the \left\lceil \frac{x_i}{2} \right\rceil term is annoying since it depends on whether x_i is odd or even.

Let’s fix the number of x_i that are odd - suppose r of them are odd.

There are \binom{K}{r} ways to choose which r are odd.

Let x_i = 2y_i + 1 for the odd x_i, and x_i = 2y_i for the even ones.

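Plugging this into

$$\sum_{i=1}^K x_i = N-1-K$$

gives us

$$2\sum_{i=1}^K y_i + r = N-1-K, \quad \text{i.e.,} \quad \sum_{i=1}^K y_i = \frac{N-1-K-r}{2}$$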

Once we’ve fixed which r of the x_i are odd, counting the sequences of x_i is equivalent to counting the sequences of non-negative integers y_i satisfying this equation.
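The stars-and-bars count used next says that the number of tuples of k non-negative integers summing to s is \binom{s+k-1}{k-1}. A small brute-force check of that identity (a sketch; `count_weak_compositions` is a hypothetical name, and `math.comb` requires Python 3.8+):

```python
from itertools import product
from math import comb


def count_weak_compositions(s, k):
    """Brute-force count of tuples (y_1, ..., y_k) of non-negative ints summing to s."""
    return sum(1 for ys in product(range(s + 1), repeat=k) if sum(ys) == s)


for s in range(6):
    for k in range(1, 5):
        assert count_weak_compositions(s, k) == comb(s + k - 1, k - 1)

print(count_weak_compositions(3, 2))  # (0,3), (1,2), (2,1), (3,0): prints 4
```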

If the right side is odd, no solution exists; otherwise, by stars-and-bars, the number of solutions is

$$\binom{\frac{N-1-K-r}{2} + K - 1}{K - 1}$$

Further, note that \left\lceil \frac{x_i}{2} \right\rceil is y_i when x_i is even, and y_i + 1 otherwise.

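So,

$$\sum_{i=1}^K \left\lceil \frac{x_i}{2} \right\rceil = r + \sum_{i=1}^K y_i = r + \frac{N-1-K-r}{2}$$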

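That means, for this fixed r, with N-1-K-r being even, the contribution to the answer is

$$\binom{K}{r} \binom{\frac{N-1-K-r}{2} + K - 1}{K - 1} \left(N - 1 - r - \frac{N-1-K-r}{2}\right)$$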

which can be computed in constant time with precomputed factorials and inverse factorials.
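The per-r contribution can be cross-checked in exact integer arithmetic (Python's big integers with `math.comb`, so no modulus is involved) against a brute force that buckets every array by r, the number of odd-sized blocks. This is a hedged verification sketch; `contribution` and `brute_by_r` are hypothetical names and assume k \geq 1.

```python
from itertools import combinations
from math import comb


def contribution(n, k, r):
    """Exact contribution of arrays whose K blocks contain exactly r odd sizes."""
    rem = n - 1 - k - r
    if rem < 0 or rem % 2:
        return 0
    s = rem // 2  # s = y_1 + ... + y_K
    return comb(k, r) * comb(s + k - 1, k - 1) * (n - 1 - r - s)


def brute_by_r(n, k):
    """Brute-force sum of d_N, bucketed by r = number of odd-sized blocks."""
    totals = [0] * (k + 1)
    for ones in combinations(range(1, n - 1), k - 1):
        a = [2] * (n - 1)
        a[0] = 1
        for p in ones:
            a[p] = 1
        d = [0] * (n + 1)
        for i in range(1, n):
            d[i + 1] = d[i if a[i - 1] == 1 else i - 1] + 1
        sizes = []
        for v in a:
            if v == 1:
                sizes.append(0)
            else:
                sizes[-1] += 1
        totals[sum(x % 2 for x in sizes)] += d[n]
    return totals


print(brute_by_r(5, 2))                           # [6, 0, 2]
print([contribution(5, 2, r) for r in range(3)])  # [6, 0, 2]
```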

Sum this up across all r from 0 to K to obtain the overall answer.
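The author’s code below indexes the same sum by c = \sum_i y_i instead of by r: then r = N-1-K-2c, and the distance N-1-r-c simplifies to K+c. A hedged exact-arithmetic sketch showing the two indexings agree (`answer_by_r` and `answer_by_c` are hypothetical names; the real solutions reduce modulo 998244353):

```python
from math import comb


def answer_by_r(n, k):
    """Exact answer, summing the per-r contributions (r = number of odd blocks)."""
    total = 0
    for r in range(k + 1):
        rem = n - 1 - k - r
        if rem >= 0 and rem % 2 == 0:
            s = rem // 2
            total += comb(k, r) * comb(s + k - 1, k - 1) * (n - 1 - r - s)
    return total


def answer_by_c(n, k):
    """Same sum indexed by c = sum of the y_i, as in the author's loop."""
    total = 0
    c = 0
    while 2 * c <= n - 1 - k:
        r = n - 1 - k - 2 * c  # number of odd blocks; comb(k, r) is 0 if r > k
        total += comb(c + k - 1, c) * comb(k, r) * (c + k)
        c += 1
    return total


print(answer_by_r(5, 2), answer_by_c(5, 2))  # 8 8
```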

# TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

# CODE:

## Author's code (C++)

```cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}
template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif
/*
*/
const int MOD = 998244353;
const int N = 5e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;
ll fact[N], ifact[N];
ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}
ll invmod(ll a) {
    return bexp(a, MOD - 2);
}
ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}
ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}
void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;
    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;
}
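void solve(int test_case){
    ll n,k; cin >> n >> k;
    ll ans = 0;
    // c = sum of the y_i; the number of odd blocks is then r = n-1-k-2c.
    // Stop once r would be negative: those terms are 0, and stopping also
    // keeps c+k-1 below n, so fact[] is never indexed out of range.
    for(ll c = 0; 2*c <= n-1-k; ++c){
        ll ways = ncr(c+k-1,c)*ncr(k,n-1-k-2*c)%MOD;
        ll val = c+k; // distance from 1 to N for these arrays
        ans += ways*val;
        ans %= MOD;
    }
    cout << ans << endl;
}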
int main()
{
    precalc(N-1);
    fastio;
    int t = 1;
    cin >> t;
    rep1(i, t) {
        solve(i);
    }
    return 0;
}
```

## Editorialist's code (PyPy3)

```python
mod = 998244353
maxN = 500005
# Precompute factorials and inverse factorials modulo mod.
fac = [1]
for n in range(1, maxN): fac.append(fac[-1] * n % mod)
inv = fac[:]
inv[-1] = pow(inv[-1], mod-2, mod)
for n in reversed(range(maxN-1)): inv[n] = inv[n+1] * (n+1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * inv[r] * inv[n-r] % mod
for _ in range(int(input())):
    n, k = map(int, input().split())
    ans = 0
    for odd in range(k+1):  # odd = r, the number of odd-sized blocks of twos
        if (n-1-k-odd)%2: continue
        # (n-1-k-odd)//2 is the sum of the y_i; the first factor is the distance from 1 to N
        ans += (n-1-odd - (n-1-k-odd)//2) * C(k, odd) * C((n-1-k-odd)//2 + k - 1, k - 1) % mod
    print(ans % mod)
```