GLOBALWARMIN - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Combinatorics, specifically stars-and-bars

PROBLEM:

From an array A of length N-1 containing only ones and twos, with A_1 = 1, we construct a tree on N vertices as follows:

  • For each 1 \leq i \leq N-1,
    • If A_i = 1, add an edge between i and i+1.
    • Otherwise, add an edge between i-1 and i+1.

The height of a tree is the maximum distance of any vertex from vertex 1.

You’re given N and K. Consider all arrays of length N-1 whose first element is 1 and which contain exactly K ones (the remaining elements being twos). Compute the sum of the heights of the trees constructed from them, modulo 998244353.

EXPLANATION:

Rather than adding edges between vertices, we can think of the tree construction process as A_i telling us which existing vertex to attach i+1 to.
We start with a single vertex, 1.
Then,

  • If A_i = 1, attach vertex i+1 to vertex i.
  • Otherwise, attach i+1 to i-1.

If d_i denotes the distance of vertex i from 1, we have:

  • d_{i+1} = 1 + d_i in the first case.
  • d_{i+1} = 1 + d_{i-1} in the second.

Observe that the array d is non-decreasing, i.e., d_i \leq d_{i+1} for every i.
If d_{i+1} = 1 + d_i this is obvious. Otherwise d_{i+1} = 1 + d_{i-1}; since vertex i is attached to either i-1 or i-2, and d_{i-2} \leq d_{i-1} by induction, we get d_i \leq 1 + d_{i-1} = d_{i+1}.

In particular, vertex N is always a farthest vertex from 1 (possibly tied with others), so the height of the tree is exactly d_N.
That is, we want to find the sum of d_N across all arrays A.
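For small inputs, the recurrence for d can be simulated directly. The sketch below is a hypothetical brute-force checker (the names `distances` and `height_sum_brute` are ours, not from the reference solutions); it enumerates every valid array, asserts that d is non-decreasing and that d_N is the height, and sums d_N. It is exponential in N, so it is only meant for cross-checking tiny cases.

```python
from itertools import combinations


def distances(a):
    """Return d where d[v] is the distance of vertex v from vertex 1 (d[0] is unused).

    a holds [A_1, ..., A_{n-1}] 0-indexed, so A_i is a[i - 1].
    """
    n = len(a) + 1
    d = [0] * (n + 1)  # d[1] = 0 for the root
    for i in range(1, n):
        if a[i - 1] == 1:
            d[i + 1] = d[i] + 1      # A_i = 1: attach i+1 to i
        else:
            d[i + 1] = d[i - 1] + 1  # A_i = 2: attach i+1 to i-1 (i >= 2 since A_1 = 1)
    return d


def height_sum_brute(n, k):
    """Sum of heights over every valid array; exponential, small n only."""
    if k < 1:
        return 0  # A_1 = 1 forces at least one 1
    total = 0
    # A_1 is fixed to 1, so choose which of positions 2..n-1 hold the other k-1 ones.
    for ones in combinations(range(2, n), k - 1):
        a = [2] * (n - 1)
        a[0] = 1
        for p in ones:
            a[p - 1] = 1
        d = distances(a)
        # d is non-decreasing, so vertex n is always a farthest vertex.
        assert all(d[v] <= d[v + 1] for v in range(1, n))
        assert max(d[1:]) == d[n]
        total += d[n]
    return total
```

For example, `height_sum_brute(5, 2)` returns 8: the three valid arrays [1,1,2,2], [1,2,1,2] and [1,2,2,1] give heights 3, 2 and 3.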


Now, let’s analyze the structure of the tree that is formed.

There’ll be a prefix of ones, forming a chain of vertices.
Then come several twos, which grow two chains of approximately equal length (at any point, their lengths differ by at most 1).
Then come some ones, which end one of the chains and continue the other.
Next come some twos again; as before, the single chain splits into two of approximately equal length, and so on.

Since we only care about the distance between 1 and N, it helps to view the tree as a single long chain connecting 1 to N, with several smaller chains hanging off some of its vertices; nothing more complicated than that appears.

The tree has N-1 edges, and each one either lies on the path between 1 and N or belongs to one of the smaller chains. So the distance from 1 to N is N-1 minus the total number of edges in the smaller chains.

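From the construction process, the smaller chains are formed by the contiguous occurrences of 2 in A.
Concretely, suppose A_j = 1 is followed by a block of m twos. Then vertices j+1, j+3, j+5, \ldots form one chain hanging from j, and j+2, j+4, \ldots form the other. Together they contain the m+1 edges added at positions j, \ldots, j+m, and the main path continues through the chain containing the last vertex j+m+1, which has \left\lfloor \frac{m}{2} \right\rfloor + 1 of those edges. So a block of m twos creates \left\lceil \frac{m}{2} \right\rceil edges not on the main path.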
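The block claim can be checked mechanically by building the parent pointers and walking from N back to 1. This is a hypothetical sketch (the helpers `parents` and `off_path_edges` are illustrative names): it counts the edges off the 1 to N path for arrays of the form [1, 2, ..., 2, 1] with m twos and compares the result with \left\lceil \frac{m}{2} \right\rceil.

```python
def parents(a):
    """par[v] for v = 2..n, where a = [A_1, ..., A_{n-1}] (0-indexed); vertex 1 is the root."""
    n = len(a) + 1
    par = [0] * (n + 1)
    for i in range(1, n):
        # A_i = 1 attaches i+1 to i, A_i = 2 attaches i+1 to i-1
        par[i + 1] = i if a[i - 1] == 1 else i - 1
    return par


def off_path_edges(a):
    """Number of edges that do not lie on the path from vertex 1 to vertex n."""
    n = len(a) + 1
    par = parents(a)
    path_len, v = 0, n
    while v != 1:  # climb from n to the root, counting path edges
        v = par[v]
        path_len += 1
    return (n - 1) - path_len


# A single block of m twos, bracketed by ones, leaves ceil(m / 2) edges off the path.
for m in range(12):
    assert off_path_edges([1] + [2] * m + [1]) == (m + 1) // 2
```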

Now, A contains exactly K ones, and since A_1 = 1, every 2 comes after some 1. So A splits into exactly K blocks of twos, one immediately after each 1 (some of these blocks may be empty, which is fine).

If the K blocks have sizes x_1, x_2, \ldots, x_K, the distance from 1 to N will then be

N-1- \left( \sum_{i=1}^K \left\lceil \frac{x_i}{2} \right\rceil \right)

Conversely, any sequence of non-negative integers x_1, \ldots, x_K with \sum_{i=1}^K x_i = N-1-K describes exactly one valid array A. Our task is thus to compute the sum of this quantity across all such sequences.
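The reduction to block sizes can be tested on its own by enumerating every sequence x_1, \ldots, x_K directly. In this sketch, `compositions` and `height_sum_blocks` are hypothetical names; the second function sums N-1-\sum \left\lceil \frac{x_i}{2} \right\rceil over all weak compositions of N-1-K into K parts, and on small inputs it should agree with simulating the arrays themselves.

```python
def compositions(total, parts):
    """Yield every tuple of `parts` non-negative integers summing to `total` (parts >= 1)."""
    if parts == 1:
        yield (total,)
        return
    for first in range(total + 1):
        for rest in compositions(total - first, parts - 1):
            yield (first,) + rest


def height_sum_blocks(n, k):
    """Sum of N-1 - sum(ceil(x_i / 2)) over all block sizes x_1..x_K with sum N-1-K."""
    if k < 1 or n - 1 - k < 0:
        return 0
    total = 0
    for xs in compositions(n - 1 - k, k):
        # (x + 1) // 2 == ceil(x / 2) for non-negative x
        total += (n - 1) - sum((x + 1) // 2 for x in xs)
    return total
```

For N = 5, K = 2 the block sizes are (0, 2), (1, 1) and (2, 0), giving distances 3, 2 and 3, so `height_sum_blocks(5, 2)` returns 8.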


Dealing with the \left\lceil \frac{x_i}{2} \right\rceil term is annoying since it depends on whether x_i is odd or even.

Let’s fix the number of odd x_i; suppose r of them are odd.
There are \binom{K}{r} ways to choose which r are odd.

Let x_i = 2y_i + 1 for the odd x_i, and x_i = 2y_i for the even ones.
Plugging this into

x_1 + x_2 + \ldots + x_K = N-1-K

gives us

2\cdot (y_1 + y_2 + \ldots + y_K) = N-1-K-r

Once we fix which r of the x_i are odd, the sequences x_i are in bijection with sequences of non-negative integers y_1, \ldots, y_K satisfying this equation.
If the right side is odd, no solution exists; otherwise, by stars-and-bars (distributing \frac{N-1-K-r}{2} among K non-negative variables), the number of solutions is

\binom{\frac{N-1-K-r}{2} + K - 1}{K-1}
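As a quick sanity check of the stars-and-bars count, the sketch below brute-forces the number of non-negative y_1, \ldots, y_K with a given sum s and compares it with \binom{s+K-1}{K-1}. The helper name `count_solutions` is hypothetical; `math.comb` requires Python 3.8 or newer.

```python
from itertools import product
from math import comb


def count_solutions(s, k):
    """Brute-force count of (y_1, ..., y_k) with y_i >= 0 and y_1 + ... + y_k = s."""
    return sum(1 for ys in product(range(s + 1), repeat=k) if sum(ys) == s)


for s in range(6):
    for k in range(1, 5):
        assert count_solutions(s, k) == comb(s + k - 1, k - 1)
```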

Further, note that \left\lceil \frac{x_i}{2} \right\rceil is y_i when x_i is even, and y_i + 1 otherwise.
So,

\begin{aligned} N-1- \left( \sum_{i=1}^K \left\lceil \frac{x_i}{2} \right\rceil \right) &= N-1- \left( r + \sum_{i=1}^K y_i \right) \\ &= N-1- \left( r + \frac{N-1-K-r}{2} \right) \end{aligned}
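This expression can be simplified a little further. Writing S = \frac{N-1-K-r}{2} for the common value of \sum_{i=1}^K y_i, the constraint above says N-1 = K + r + 2S, so

\begin{aligned} N-1- \left( r + S \right) &= (K + r + 2S) - r - S \\ &= K + S \end{aligned}

In other words, the height only depends on K and S. This is the quantity val = c+k in the author's code, whose loop variable c plays the role of S while n-1-k-2c plays the role of r.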

So, for a fixed r with N-1-K-r even, the contribution to the answer is

\binom{K}{r} \cdot \left( N-1- \left( r + \frac{N-1-K-r}{2} \right) \right) \cdot \binom{\frac{N-1-K-r}{2} + K - 1}{K-1}

which can be computed in constant time with precomputed factorials and inverse factorials.

Sum this up across all r from 0 to K to obtain the overall answer.
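To sanity-check the final sum on small inputs, here is a direct transcription that uses Python's exact `math.comb` (Python 3.8+) in place of the modular factorial tables. `height_sum_formula` is a hypothetical name, and since no modulus is applied it is only meant for cross-checking small N against a brute force.

```python
from math import comb


def height_sum_formula(n, k):
    """Sum over r of C(K, r) * (N-1-(r+S)) * C(S+K-1, K-1), with S = (N-1-K-r)/2.

    Exact integers, no modulus; intended for cross-checking small cases.
    """
    if k < 1:
        return 0  # A_1 = 1 forces at least one 1
    total = 0
    for r in range(k + 1):             # r = number of odd blocks
        rem = n - 1 - k - r            # must equal 2 * (y_1 + ... + y_K)
        if rem < 0 or rem % 2:
            continue                   # no valid y_i for this r
        s = rem // 2
        dist = (n - 1) - (r + s)       # distance from 1 to N for this r
        total += comb(k, r) * dist * comb(s + k - 1, k - 1)
    return total
```

For N = 6, K = 3 it returns 21, which matches enumerating the six valid arrays by hand (three of height 4 and three of height 3).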

TIME COMPLEXITY:

\mathcal{O}(N) per testcase, after a one-time precomputation of factorials and inverse factorials.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "YES" << endl
#define no cout << "NO" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 5e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

ll fact[N], ifact[N];

ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;

    ll res = 1;

    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }

    return res;
}

ll invmod(ll a) {
    return bexp(a, MOD - 2);
}

ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}

ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}

void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;

    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;
}

void solve(int test_case){    
    ll n,k; cin >> n >> k;

    ll ans = 0;
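    // c plays the role of S = y_1 + ... + y_K, and n-1-k-2c that of r (the number of odd blocks).
    // Terms with 2c > n-1-k are zero, so stop there; this also keeps c+k-1 inside the factorial table.
    for (ll c = 0; 2 * c <= n - 1 - k; ++c) {
        ll ways = ncr(c+k-1,c)*ncr(k,n-1-k-2*c)%MOD;
        ll val = c+k;
        ans += ways*val;
        ans %= MOD;
    }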
    
    cout << ans << endl;
}

int main()
{
    precalc(N-1);

    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}

Editorialist's code (PyPy3)
mod = 998244353
maxN = 500005
fac = [1]
for n in range(1, maxN): fac.append(fac[-1] * n % mod)
inv = fac[:]
inv[-1] = pow(inv[-1], mod-2, mod)
for n in reversed(range(maxN-1)): inv[n] = inv[n+1] * (n+1) % mod

def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * inv[r] * inv[n-r] % mod

for _ in range(int(input())):
    n, k = map(int, input().split())

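    ans = 0
    for odd in range(k+1):  # odd = r, the number of odd-sized blocks
        if (n-1-k-odd)%2: continue  # 2*(y_1+...+y_K) = n-1-k-odd must be even
        # distance N-1-(r+S), times C(K, r) choices of odd blocks, times the stars-and-bars count
        ans += (n-1-odd - (n-1-k-odd)//2) * C(k, odd) * C((n-1-k-odd)//2 + k - 1, k - 1) % mod
    print(ans % mod)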