COUNTISFUN7 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Stars and bars

PROBLEM:

For a permutation P of length N, let F(P, K) denote the number of sorted (non-decreasing) arrays A of length K with elements from 1 to N, such that

\min_{j=1}^{A_1} P_j = \min_{j=1}^{A_i} P_j

for every 2 \leq i \leq K.

Given N, K, V find the sum of F(P, K) across all permutations P of \{1, 2, \ldots, N\} such that P_1 = V.

EXPLANATION:

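The condition \min_{j=1}^{A_1} P_j = \min_{j=1}^{A_i} P_j for a sorted array A means that the prefix minimum of P till index A_1 should equal the prefix minimum of P till index A_K. The rest is then satisfied automatically: prefix minima can only decrease as the index grows, so every A_i with A_1 \leq A_i \leq A_K is squeezed to the same value.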
In particular, this means that if M is the prefix minimum till A_1, the elements 1, 2, 3, \ldots, M-1 can appear only after A_K.
Since P_1 = V is fixed, the prefix minimum till A_1 also cannot exceed V, no matter what A_1 is.
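To sanity-check this reformulation, here is a small brute-force sketch in Python. It assumes, as the derivation below does, that A ranges over non-decreasing arrays with values in 1..N; the helper names are hypothetical. It counts F(P, K) straight from the definition and confirms, on every permutation of length at most 5, that comparing only the prefix minima at A_1 and A_K gives the same count.

```python
from itertools import combinations_with_replacement, permutations


def prefix_minima(p):
    """pm[i] is the minimum of p[0..i] (0-indexed)."""
    pm, cur = [], float("inf")
    for x in p:
        cur = min(cur, x)
        pm.append(cur)
    return pm


def f_definition(p, k):
    """F(P, K): non-decreasing A in [1, N]^K with pm(A_1) == pm(A_i) for all i."""
    pm = prefix_minima(p)
    return sum(
        1
        for a in combinations_with_replacement(range(1, len(p) + 1), k)
        if all(pm[a[0] - 1] == pm[a[i] - 1] for i in range(1, k))
    )


def f_endpoints(p, k):
    """Same count, comparing only the prefix minima at A_1 and A_K."""
    pm = prefix_minima(p)
    return sum(
        1
        for a in combinations_with_replacement(range(1, len(p) + 1), k)
        if pm[a[0] - 1] == pm[a[-1] - 1]
    )


# Exhaustive agreement check for N <= 5.
for n in range(1, 6):
    for p in permutations(range(1, n + 1)):
        for k in range(1, n + 1):
            assert f_definition(p, k) == f_endpoints(p, k)
```

`combinations_with_replacement` yields exactly the non-decreasing tuples, which is why it is a convenient way to enumerate sorted arrays here.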

Let’s fix M, the prefix minimum till A_1, and then count all pairs (P, A) such that A satisfies the condition for P and the prefix minimum till A_1 equals M.

This is rather hard to do directly; the trick is to look at differences.
Observe that any sorted array can be represented in terms of the differences of its adjacent elements (as long as at least one of the elements is known, of course).
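A tiny illustration of the difference representation (the helper names are hypothetical): a sorted array is fully recovered from its first element and its adjacent differences, all of which are non-negative.

```python
from itertools import accumulate


def to_differences(a):
    """Split a sorted array into (first element, adjacent differences)."""
    return a[0], [a[i] - a[i - 1] for i in range(1, len(a))]


def from_differences(first, diffs):
    """Rebuild the array as prefix sums of the differences (needs Python 3.8+ for initial=)."""
    return list(accumulate(diffs, initial=first))


first, diffs = to_differences([2, 2, 5, 7])
print(first, diffs)                    # 2 [0, 3, 2]
print(from_differences(first, diffs))  # [2, 2, 5, 7]
```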

So, let’s try to represent everything in terms of differences, or rather, distances.
We know P_1 = V, which can’t be changed.
Let d_1 denote the distance between index 1 and the position of M, the minimum.
d_1 can be any positive integer.

Once M's position is fixed, consider the array A. All elements of A must lie at or to the right of the position of M, since otherwise the prefix minimum till A_1 would exceed M.
So, we can define distances d_2, d_3, \ldots, d_{K+1}, where:

  • d_2 is the distance between the position of M and A_1.
  • For 3\leq i \leq K+1, d_i is the difference between A_{i-2} and A_{i-1}.

With this, everything up to A_K is determined by choosing values for the d_i.
Note that each of d_2, d_3, \ldots, d_{K+1} is a non-negative integer (A_1 may coincide with the position of M, and A may contain repeated values).
Next, recall that the elements 1, 2, 3, \ldots, M-1 must be placed after A_K.

So, let’s first fix their order (in (M-1)! ways, of course), then try to once again use distances to define their positions.
If you’ve been following so far, this is straightforward: we need to choose M-1 distinct indices after A_K, which can be modelled as M-1 positive distances.
We can denote these by d_{K+2}, d_{K+3}, \ldots, d_{K+M}; all of which are positive.

This completes our model: once we choose values for the d_i, the position of M, the array A, and the positions holding 1, 2, \ldots, M-1 are all determined.
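To make the model concrete, here is a sketch of the decoding step for the case M \lt V (the function name is hypothetical, and the Python list d is 0-indexed, so d[0] holds d_1). It walks the distances starting from index 1 and returns the position of M, the array A, and the M-1 positions reserved for 1, 2, \ldots, M-1. The last index it reaches is 1 + \sum d_i, which is why the sum must stay at most N-1.

```python
def decode_distances(d, k, m):
    """Case M < V. d = [d_1, ..., d_{K+M}] as a 0-indexed list.
    Returns (position of M, array A, positions reserved for 1..M-1)."""
    assert len(d) == k + m
    pos_m = 1 + d[0]               # d_1 >= 1: M lies strictly right of index 1
    a = [pos_m + d[1]]             # d_2 >= 0: A_1 is at or right of pos(M)
    for j in range(2, k + 1):      # d_3..d_{K+1} >= 0: gaps between consecutive A's
        a.append(a[-1] + d[j])
    small, cur = [], a[-1]
    for j in range(k + 1, k + m):  # d_{K+2}..d_{K+M} >= 1: strictly after A_K
        cur += d[j]
        small.append(cur)
    return pos_m, a, small


d = [1, 0, 2, 1, 3]               # K = 2, M = 3
print(decode_distances(d, 2, 3))  # (2, [2, 4], [5, 8])
print(1 + sum(d))                 # 8, the last index used
```

The order in which 1, 2, \ldots, M-1 occupy the returned positions is deliberately left open, matching the separate (M-1)! factor.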
We just need to take care of a couple more things:

  • We impose the constraint d_1+d_2+d_3+\ldots + d_{K+M} \leq N-1, because the final index we choose should be within the array (we start from index 1 and add differences; which is why the sum is bounded by N-1 and not N).
  • The order of elements other than 1, 2, \ldots, M and V hasn’t been decided, but it doesn’t matter: all of them are larger than M, so they affect neither A nor the prefix minima at the indices of A.
    So, we can choose any of the (N-M-1)! orders for them.

We’re almost done.
All we need to do is count the number of solutions to d_1+d_2+d_3+\ldots + d_{K+M} \leq N-1 and multiply by (M-1)! \cdot (N-M-1)! to get the total contribution of M.

If we had equality instead of an inequality, this wouldn’t be so hard: the number of solutions can be found using stars-and-bars.
As-is, we could enumerate all \mathcal{O}(N) choices for what the sum should be and find the answer, but there’s a quicker way.
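For comparison, a sketch of this slower route for the case M \lt V (the function name is hypothetical). Fix the exact sum S, add 1 to each of the K non-negative variables d_2, \ldots, d_{K+1} so that all M+K variables are positive and sum to S+K, count those with stars-and-bars as \binom{S+K-1}{M+K-1}, and add over S = 0, \ldots, N-1. This costs \mathcal{O}(N) per value of M.

```python
from math import comb


def count_by_enumerating_sum(n, k, m):
    """Solutions of d_1 + ... + d_{K+M} <= N-1 with M positive and K non-negative
    variables, obtained by summing over the exact value S of the left-hand side."""
    total = 0
    for s in range(n):  # S = 0, 1, ..., N-1
        # M+K positive integers summing to S+K; comb() returns 0 when S < M
        total += comb(s + k - 1, m + k - 1)
    return total
```

Summing over S is exactly what a slack variable avoids; the two counts agree by the hockey-stick identity.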
Let’s add a non-negative slack variable, d_{M+K+1}, which denotes how far below N-1 the sum is.

Then, all we need to do is count the number of solutions to

d_1 + d_2 + \ldots + d_{M+K+1} = N-1

which is a single application of stars-and-bars!
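For reference, the two stars-and-bars formulas with a brute-force check (a sketch; the function names are hypothetical): r positive integers summing to T can be chosen in \binom{T-1}{r-1} ways, and r non-negative integers summing to T in \binom{T+r-1}{r-1} ways, which follows from the positive case by adding 1 to every variable.

```python
from itertools import product
from math import comb


def count_positive(total, r):
    """Solutions of x_1 + ... + x_r = total with every x_i >= 1 (assumes r >= 1)."""
    return comb(total - 1, r - 1) if total >= r else 0


def count_nonnegative(total, r):
    """Solutions with every x_i >= 0: add 1 to each variable, then count positively."""
    return count_positive(total + r, r)


def brute_force(total, r, lowest):
    """Enumerate all r-tuples with entries in [lowest, total] and the right sum."""
    return sum(1 for xs in product(range(lowest, total + 1), repeat=r) if sum(xs) == total)


for total in range(0, 7):
    for r in range(1, 4):
        assert count_positive(total, r) == brute_force(total, r, 1)
        assert count_nonnegative(total, r) == brute_force(total, r, 0)
```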

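Be careful though: some of the d_i must be positive while others can be non-negative.
Stars-and-bars applies only when they’re all of the same type, so shift the appropriate variables by 1 (adjusting the right-hand side accordingly) to get an equation where all variables are of the same type.
Concretely, d_1, d_{K+2}, \ldots, d_{K+M} are positive (M variables), while d_2, \ldots, d_{K+1} and the slack d_{M+K+1} are non-negative (K+1 variables). Adding 1 to each non-negative variable gives M+K+1 positive integers summing to N+K, so the required binomial coefficient comes out to be \binom{N+K-1}{M+K}.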
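A brute-force check of this count for tiny inputs (a sketch with a hypothetical function name, case M \lt V): enumerate every vector d_1, \ldots, d_{K+M} with the stated sign constraints and sum at most N-1, and compare with \binom{N+K-1}{M+K}.

```python
from itertools import product
from math import comb


def count_distance_vectors(n, k, m):
    """Case M < V: d_1 >= 1, d_2..d_{K+1} >= 0, d_{K+2}..d_{K+M} >= 1, sum <= N-1.
    Every variable is at most N-1, so these ranges cover all solutions."""
    ranges = [range(1, n)] + [range(0, n)] * k + [range(1, n)] * (m - 1)
    return sum(1 for d in product(*ranges) if sum(d) <= n - 1)


for n in range(2, 6):
    for k in range(1, 4):
        for m in range(1, n):  # M < V <= N
            assert count_distance_vectors(n, k, m) == comb(n + k - 1, m + k)
```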


The above solves a single M \lt V.
For M = V, almost the same analysis holds: the only difference is that we no longer need the variable d_1 (since V is always fixed to the first position, d_2 now measures the distance from index 1 to A_1), and the number of “unfixed” elements is (N-M) instead of (N-M-1) (in the previous case V and 1, 2, \ldots, M were all fixed, but now V = M).
The process is still the same: there are now M+K variables, K+1 of them non-negative, so the count is \binom{N+K-1}{M+K-1}, and the contribution of M = V is \binom{N+K-1}{V+K-1} \cdot (V-1)! \cdot (N-V)!.
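The same kind of brute-force check for M = V (a sketch; the names are hypothetical). Now d_1 disappears, and the K non-negative plus M-1 positive distances must sum to at most N-1. The per-M contribution then multiplies the count by (V-1)! \cdot (N-V)!.

```python
from itertools import product
from math import comb, factorial


def count_distance_vectors_m_eq_v(n, k, m):
    """Case M = V: d_2..d_{K+1} >= 0, d_{K+2}..d_{K+M} >= 1, sum <= N-1."""
    ranges = [range(0, n)] * k + [range(1, n)] * (m - 1)
    return sum(1 for d in product(*ranges) if sum(d) <= n - 1)


def contribution_m_eq_v(n, k, v):
    """Contribution of M = V: binomial count times orders of 1..V-1 and of the rest."""
    return comb(n + k - 1, v + k - 1) * factorial(v - 1) * factorial(n - v)


for n in range(1, 6):
    for k in range(1, 4):
        for m in range(1, n + 1):
            assert count_distance_vectors_m_eq_v(n, k, m) == comb(n + k - 1, m + k - 1)
```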

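Now that a single M can be solved in \mathcal{O}(1) time, just iterate over M = 1, 2, \ldots, V and add up the contributions. With factorials and inverse factorials precomputed once (the largest factorial needed is (N+K-1)!, and K \leq N), this takes \mathcal{O}(V) time per test case, which is \mathcal{O}(N).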
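Putting the two cases together, here is a compact sketch that uses Python’s exact integers (`math.comb`, `math.factorial`) for readability instead of the precomputed modular factorials used in the codes below, along with an exhaustive brute force over permutations for small N. The function names are hypothetical, and A is again assumed to range over non-decreasing arrays with values in 1..N.

```python
from itertools import combinations_with_replacement, permutations
from math import comb, factorial

MOD = 998244353


def solve(n, k, v):
    """Sum of per-M contributions: the M < V terms plus the M = V term."""
    ans = 0
    for m in range(1, v):
        ans += comb(n + k - 1, m + k) * factorial(m - 1) * factorial(n - m - 1)
    ans += comb(n + k - 1, v + k - 1) * factorial(v - 1) * factorial(n - v)
    return ans % MOD


def brute(n, k, v):
    """Sum of F(P, K) over permutations with P_1 = V, straight from the definition."""
    total = 0
    for p in permutations(range(1, n + 1)):
        if p[0] != v:
            continue
        pm = [min(p[: i + 1]) for i in range(n)]              # prefix minima
        for a in combinations_with_replacement(range(n), k):  # 0-indexed A
            if all(pm[a[0]] == pm[a[i]] for i in range(1, k)):
                total += 1
    return total % MOD


for n in range(1, 6):
    for k in range(1, n + 1):
        for v in range(1, n + 1):
            assert solve(n, k, v) == brute(n, k, v)
```

Exact big integers keep the sketch short; a submission within the time limit would switch to modular arithmetic with precomputed factorial tables, as the codes below do.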

TIME COMPLEXITY:

\mathcal{O}(N) per testcase.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
#define pb push_back                  
#define mp make_pair          
#define nline "\n"                            
#define f first                                            
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>      
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=2000200;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
} 
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if(a==b){
        return 1;
    }
    if((a<0)||(a<b)||(b<0))
        return 0;   
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;  
}
void solve(){ 
    ll n,k,v; cin>>n>>k>>v;
    ll ans=0;  
    for(ll i=1;i<=v;i++){
        ll gaps=i+k+(i!=v);
        ll balls=n-i-(i!=v);
        ll now=(fact[i-1]*fact[balls])%MOD;
        ans=(ans+nCr(gaps+balls-1,gaps-1,MOD)*now)%MOD;
    }
    cout<<ans<<nline;
    return;    
}                                       
int main()                                                                               
{       
    ios_base::sync_with_stdio(false);                          
    cin.tie(NULL);                               
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;                 
    cin>>test_cases;
    precompute(MOD);
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
#ifdef IGNORE_CR
            if (buffer[pos] == '\r') {
                pos++;
                continue;
            }
#endif
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <long long mod>
struct modular {
    long long value;
    modular(long long x = 0) {
        value = x % mod;
        if (value < 0) value += mod;
    }
    modular& operator+=(const modular& other) {
        if ((value += other.value) >= mod) value -= mod;
        return *this;
    }
    modular& operator-=(const modular& other) {
        if ((value -= other.value) < 0) value += mod;
        return *this;
    }
    modular& operator*=(const modular& other) {
        value = value * other.value % mod;
        return *this;
    }
    modular& operator/=(const modular& other) {
        long long a = 0, b = 1, c = other.value, m = mod;
        while (c != 0) {
            long long t = m / c;
            m -= t * c;
            swap(c, m);
            a -= t * b;
            swap(a, b);
        }
        a %= mod;
        if (a < 0) a += mod;
        value = value * a % mod;
        return *this;
    }
    friend modular operator+(const modular& lhs, const modular& rhs) { return modular(lhs) += rhs; }
    friend modular operator-(const modular& lhs, const modular& rhs) { return modular(lhs) -= rhs; }
    friend modular operator*(const modular& lhs, const modular& rhs) { return modular(lhs) *= rhs; }
    friend modular operator/(const modular& lhs, const modular& rhs) { return modular(lhs) /= rhs; }
    modular& operator++() { return *this += 1; }
    modular& operator--() { return *this -= 1; }
    modular operator++(int) {
        modular res(*this);
        *this += 1;
        return res;
    }
    modular operator--(int) {
        modular res(*this);
        *this -= 1;
        return res;
    }
    modular operator-() const { return modular(-value); }
    bool operator==(const modular& rhs) const { return value == rhs.value; }
    bool operator!=(const modular& rhs) const { return value != rhs.value; }
    bool operator<(const modular& rhs) const { return value < rhs.value; }
};
template <long long mod>
string to_string(const modular<mod>& x) {
    return to_string(x.value);
}
template <long long mod>
ostream& operator<<(ostream& stream, const modular<mod>& x) {
    return stream << x.value;
}
template <long long mod>
istream& operator>>(istream& stream, modular<mod>& x) {
    stream >> x.value;
    x.value %= mod;
    if (x.value < 0) x.value += mod;
    return stream;
}

constexpr long long mod = 998244353;
using mint = modular<mod>;

mint power(mint a, long long n) {
    mint res = 1;
    while (n > 0) {
        if (n & 1) {
            res *= a;
        }
        a *= a;
        n >>= 1;
    }
    return res;
}

vector<mint> fact(1, 1);
vector<mint> finv(1, 1);

mint C(int n, int k) {
    if (n < k || k < 0) {
        return mint(0);
    }
    while ((int) fact.size() < n + 1) {
        fact.emplace_back(fact.back() * (int) fact.size());
        finv.emplace_back(mint(1) / fact.back());
    }
    return fact[n] * finv[k] * finv[n - k];
}

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e6);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1e6);
        in.readSpace();
        int k = in.readInt(1, n);
        in.readSpace();
        int v = in.readInt(1, n);
        in.readEoln();
        mint ans = 0;
        C(n + 10, 0);
        for (int x = 1; x < v; x++) {
            mint add = C(n + k - 1, k + x);
            add *= fact[x - 1];
            add *= fact[n - 1 - x];
            ans += add;
        }
        {
            mint add = C(n + k - 1, k + v - 1);
            add *= fact[v - 1];
            add *= fact[n - v];
            ans += add;
        }
        cout << ans << '\n';
    }
    assert(sn <= 1e6);
    in.readEof();
    return 0;
}
Editorialist's code (Python)
mod = 998244353
maxn = 3*10**6 + 100
fac = [1]*maxn
for i in range(1, maxn): fac[i] = i*fac[i-1] % mod
inv = [1]*maxn
inv[-1] = pow(fac[-1], mod-2, mod)
for i in reversed(range(maxn-1)): inv[i] = (i+1)*inv[i+1] % mod

def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * inv[r] % mod * inv[n-r] % mod

for _ in range(int(input())):
    n, k, v = map(int, input().split())
    ans = 0
    for m in range(1, v+1):
        ways = 0
        if m < v:
            ways = fac[m-1] * fac[n-m-1] % mod
            # m+k+1 integers summing to N-1
            # k+1 are non-negative -> m+k+1 positive integers summing to N+k
            ways = ways * C(n+k-1, m+k) % mod
        else:
            ways = fac[m-1] * fac[n-m] % mod
            # m+k integers summing to N-1
            # k+1 are non-negative -> m+k positive integers summing to N+k
            ways = ways * C(n+k-1, m+k-1) % mod
        ans = (ans + ways) % mod
    print(ans)

I think you meant to write prefix minimum instead of prefix maximum in this editorial.

Ah you’re right, I’ll fix it, thanks!
Edit: fixed.