PERMSHOP - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: sushil2006
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy-Med

PREREQUISITES:

Combinatorics

PROBLEM:

For a permutation P of length N, define f(P) to be the number of integers K satisfying the following conditions:

  • 2 \leq K \leq N
  • Suppose you have K coins, and item i costs P_i coins.
  • Let S_1 be the set of items bought by greedily buying an item whenever possible from left to right.
  • Let S_2 be the set of items bought by greedily buying an item whenever possible from right to left.
  • Then S_1 = S_2.

You’re given a partially filled permutation P, where missing entries are given as -1. Compute the sum of f(P) across all ways of filling it in, modulo 998244353.

EXPLANATION:

First, let’s understand how to compute f(P) at all.

If you try a few small values of K by hand, you might notice the following:

  • K = 2 will never work, because in one direction we’ll buy 1 but not 2, and in the other direction we’ll buy 2 but not 1.
  • K = 3 works when 3 is present between 1 and 2, i.e. either [1, 3, 2] or [2, 3, 1] appears as a subsequence.
    In these two cases, we’ll buy both 1 and 2 from both directions.
    In all other cases, we’ll buy 3 from one direction and not from the other so K = 3 won’t work.
  • K = 4 and K = 5 will never work.
  • K = 6 works sometimes, for example with [1, 2, 4, 5, 6, 3].
    Further, when it does work, the set of items we buy will be \{1, 2, 3\} from both directions; and in fact, you might also notice that K = 6 works only in cases where K = 3 doesn’t.

Interestingly, it turns out that these are pretty much all the possible cases - that is, K \gt 6 will never work, no matter what N or the permutation are!
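These observations are easy to sanity-check by brute force. Below is a minimal Python sketch (the helper names `bought` and `good_ks` are illustrative, not taken from the original code) that simulates both greedy directions for every K, then checks every permutation of length up to 7 for the claim that each good K is 3 or 6, and never both. An exhaustive check of small cases supports the claim but does not prove it; the proof follows.

```python
from itertools import permutations

def bought(p, k, reverse=False):
    """Set of prices bought by greedily buying every affordable item."""
    picked = set()
    for price in (reversed(p) if reverse else p):
        if price <= k:  # affordable right now, so buy it
            k -= price
            picked.add(price)
    return picked

def good_ks(p):
    """All K in [2, N] for which both directions buy the same set."""
    return [k for k in range(2, len(p) + 1)
            if bought(p, k) == bought(p, k, reverse=True)]

print(good_ks([1, 3, 2]))           # [3]
print(good_ks([1, 2, 4, 5, 6, 3]))  # [6]

# Exhaustive check for N <= 7: every good K is 3 or 6, and never both at once.
for n in range(2, 8):
    for p in permutations(range(1, n + 1)):
        ks = good_ks(p)
        assert set(ks) <= {3, 6} and len(ks) <= 1
```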

Proof

Let’s fix some K \gt 6, and a permutation P.

Suppose K is good with respect to P, i.e. Alice (buying from left to right) and Bob (buying from right to left) pick the same set of items.
Then, it can be shown that the items picked must be \{1, 2, 3, \ldots, x\} for some integer x, i.e. a consecutive set of values starting from 1.
This is not hard to see: suppose y+1 is picked but y is not, and w.l.o.g. let y appear before y+1 in the permutation (otherwise, swap the roles of Alice and Bob).
Then, Alice picking y+1 means that she had at least y+1 coins remaining when reaching y+1, which in turn means she had at least y+1 coins when crossing y, contradicting her not buying it.

Note that the chosen items cost \frac{x\cdot (x+1)}{2} in total, and we can never spend more than K coins, so K \geq \frac{x\cdot (x+1)}{2}.
Since K \leq N, this means x is much smaller than N: it's about \sqrt{2N} at best.
This is particularly important because it means that x \lt N (as \frac{x\cdot (x+1)}{2} \leq N and N \geq 2), so the value x+1 does exist in P.

Let the sum of the chosen elements that appear before x+1 be S_1, and the sum of the chosen elements that appear after it be S_2, so that S_1 + S_2 = \frac{x\cdot (x+1)}{2}.
Then, it can be seen that:

  1. K - S_1 \leq x must hold.
    Alice has exactly K - S_1 coins left when she reaches x+1; if this were at least x+1, she would buy it.
  2. Similarly, K - S_2 \leq x must hold; if not, Bob will buy x+1.

Adding these up tells us that 2K \leq S_1 + S_2 + 2x = 2x + \frac{x\cdot (x+1)}{2}.
Combined with K \geq \frac{x\cdot (x+1)}{2}, this gives \frac{x\cdot (x+1)}{2} \leq 2x, i.e. x + 1 \leq 4.
Solving this inequality gives us x \leq 3.

So x \in \{1, 2, 3\}, and K must satisfy \frac{x\cdot (x+1)}{2} \leq K \leq x + \frac{x\cdot (x+1)}{4}.
For x = 1, 2, 3 the only integers in this range are K = 1, 3, 6 respectively, which contradicts K \gt 6.
K = 1 is disallowed by the problem, so only K = 3, 6 need to be considered.


This makes our task quite simple: compute the number of ways to fill in the permutation to make either K = 3 or K = 6 work.
In particular, this means that only the relative order of elements \leq 6 matters, since items costing more than 6 can never be bought: once this order is fixed, everything else can be freely rearranged.

Further simplifying our work is the fact that the K = 3 and K = 6 cases are disjoint, as noted above - which means that f(P) = 0 or f(P) = 1 for every permutation.
So, summing up f(P) across all P is equivalent to just finding the number of ways to fill in P to obtain f(P) = 1.

Let’s call a permutation Q of length 6 a “valid pattern”, if f(Q) = 1.
All valid patterns can be precomputed with brute force, since there are only 6! = 720 permutations of length 6, and each of them needs to be checked for only K = 3 or K = 6.
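A minimal sketch of this precomputation in Python, assuming (as argued above) that a length-6 pattern Q is valid exactly when K = 3 or K = 6 makes both directions agree, since K = 2, 4, 5 never work. The names `agrees` and `VALID_PATTERNS` are illustrative only.

```python
from itertools import permutations

def agrees(q, k):
    """True if greedy buying with k coins picks the same items from both ends."""
    def buy(seq, coins):
        picked = set()
        for price in seq:
            if price <= coins:  # buy whenever affordable
                coins -= price
                picked.add(price)
        return picked
    return buy(q, k) == buy(reversed(q), k)

# f(Q) = 1 iff K = 3 or K = 6 works; the two cases never overlap.
VALID_PATTERNS = [q for q in permutations(range(1, 7))
                  if agrees(q, 3) or agrees(q, 6)]
```

Exactly 240 of these patterns come from the K = 3 case, since 3 lies between 1 and 2 in two of the six relative orders of \{1, 2, 3\}.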

After precomputation, let’s fix a valid pattern Q, and count the number of ways to fill in the blanks of P such that Q appears as a subsequence of P.
This can be done using some simple combinatorics. First, if two values of Q that are already present in P appear in the wrong relative order, there are no ways at all.
Otherwise, treat the start and end of P as two extra present values Q_0 and Q_7, and consider i \lt j such that Q_i and Q_j both appear in P (with Q_i being before Q_j), while Q_k does not for every i \lt k \lt j.
Let there be M blanks between the positions of Q_i and Q_j in P.
Then, the (j-i-1) elements of Q between indices i and j must be placed in these blanks, but their order is fixed since we want Q to be a subsequence of the final permutation.
So, the number of ways of choosing their positions is just \binom{M}{j-i-1} (which is 0 when M \lt j-i-1).

The product of these binomial coefficients across all such (i, j) pairs will give us the number of ways of fixing just the positions of the elements 1, 2, \ldots, 6 such that Q appears as a subsequence of P.
Once this is computed, all elements \gt 6 can be freely permuted into the remaining positions, so if there are F missing elements \gt 6, that’s an additional multiplier of F!.
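The counting for a single pattern can be sketched as follows. This is an illustrative Python version (the function name `count_with_pattern` is hypothetical), assuming blanks are encoded as -1 and N \geq len(pattern). It treats the positions just before the start and just past the end of P as extra anchors, multiplies one binomial coefficient per gap between consecutive present values, and finally multiplies by F! modulo 998244353.

```python
from math import comb

MOD = 998244353

def count_with_pattern(p, pattern):
    """Count fillings of p (blanks are -1) in which `pattern` (a permutation
    of 1..m) appears as a subsequence, modulo MOD. Assumes len(p) >= m."""
    n, m = len(p), len(pattern)
    pos = {v: i for i, v in enumerate(p) if v != -1}
    blanks = [0] * (n + 1)  # blanks[i] = number of -1 entries in p[:i]
    for i, v in enumerate(p):
        blanks[i + 1] = blanks[i] + (v == -1)

    # prev = index of the last present anchor (-1 = just before the start)
    ways, prev, pending = 1, -1, 0
    for v in list(pattern) + [None]:  # None = anchor just past the end of p
        idx = n if v is None else pos.get(v)
        if idx is None:  # v is missing, so it must go into the current gap
            pending += 1
            continue
        if idx < prev:  # two present values appear in the wrong order
            return 0
        gap = blanks[idx] - blanks[prev + 1]  # blanks strictly between anchors
        ways = ways * comb(gap, pending) % MOD  # comb is 0 if pending > gap
        prev, pending = idx, 0

    # Missing values larger than m fill the remaining blanks in any order (F!).
    free = sum(1 for v in range(m + 1, n + 1) if v not in pos)
    for i in range(2, free + 1):
        ways = ways * i % MOD
    return ways
```

For example, with eight blanks and any length-6 pattern this gives \binom{8}{6}\cdot 2! = 56, which matches 8!/6!.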

Summing up this count across all valid patterns Q will give us the final answer.


Note that if N \lt 6, iterating through patterns of length 6 isn’t really valid.
However, in these cases, N is small enough that you can just iterate through all N! permutations and compute their results with brute force.
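A minimal brute-force sketch for small N (the name `brute_force_answer` is hypothetical; blanks are assumed to be -1, and no modulus is needed at this size). The editorialist's code below takes a different route for 3 \leq N \lt 6: it reuses the pattern counting with the two length-3 patterns [1, 3, 2] and [2, 3, 1], since only K = 3 can work there.

```python
from itertools import permutations

def brute_force_answer(p):
    """Sum of f over all fillings of p (blanks are -1); only for tiny N."""
    n = len(p)
    missing = [v for v in range(1, n + 1) if v not in p]

    def bought(seq, coins):
        picked = set()
        for price in seq:
            if price <= coins:  # buy whenever affordable
                coins -= price
                picked.add(price)
        return picked

    total = 0
    for fill in permutations(missing):
        it = iter(fill)
        full = [next(it) if v == -1 else v for v in p]  # fill blanks in order
        total += sum(bought(full, k) == bought(reversed(full), k)
                     for k in range(2, n + 1))
    return total
```

For N \leq 5 there are at most 5! = 120 fillings, each checked for at most 4 values of K.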

TIME COMPLEXITY:

\mathcal{O}(N + 6!\cdot 6) per testcase.

CODE:

Editorialist's code (PyPy3)
import itertools
# True if buying greedily with k coins picks the same items from both ends
def check(p, k):
    a, b = [], []
    ok = k
    for x in p:
        if x <= k:
            k -= x
            a.append(x)
    k = ok
    for x in reversed(p):
        if x <= k:
            k -= x
            b.append(x)
    b = b[::-1]
    return a == b

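# valid: length-6 patterns where K = 3 or K = 6 works (used when n >= 6)
# valid3: length-3 patterns where K = 3 works (used when 3 <= n < 6)
valid = []
valid3 = [[1, 3, 2], [2, 3, 1]]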
for p in itertools.permutations(list(range(1, 7))):
    if check(p, 6) or check(p, 3): valid.append(list(p))


mod = 998244353
fac = list(range(200005))
fac[0] = 1
for i in range(1, 200005): fac[i] = fac[i-1] * i % mod
inv = fac[:]
for i in range(len(inv)): inv[i] = pow(inv[i], mod-2, mod)
def C(n, r):
    return fac[n] * inv[r] * inv[n-r] % mod

for _ in range(int(input())):
    n = int(input())
    p = [0] + list(map(int, input().split()))

    if n < 3:
        print(0)
        continue

    # For n < 6 only K = 3 can work, so temporarily switch to the length-3 patterns
    if n < 6: valid, valid3 = valid3, valid

    # pos[v]: index of value v in p (0 if missing)
    # zer[v]: number of blanks before index pos[v]
    pos = [0]*(n+2)
    zer = [0]*(n+2)
    ct = 0
    for i in range(1, n+1):
        if p[i] > 0:
            pos[p[i]] = i
            zer[p[i]] = ct
        else:
            ct += 1
    zer[n+1] = ct
    pos[n+1] = n+1
    
    free = 0
    for i in range(len(valid[0]) + 1, n+1):
        if pos[i] == 0: free += 1
    
    ans = 0
    for pat in valid:
        # values 0 and n+1 act as sentinels before the start and after the end
        cur = [0] + pat + [n+1]
        prv, ways = 0, 1
        for i in range(1, len(cur)):
            if pos[cur[i]] == 0: continue
            if pos[cur[i]] < pos[cur[prv]]:
                ways = 0
                break

            between = i - prv - 1
            have = zer[cur[i]] - zer[cur[prv]]
            if between > have:
                ways = 0
                break

            ways = ways * C(have, between) % mod
            prv = i
        ans += ways
    
    if n < 6: valid, valid3 = valid3, valid
    print(ans * fac[free] % mod)

Author's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 998244353;
const int N = 2e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

ll fact[N], ifact[N];

ll bexp(ll a, ll b) {
    a %= MOD;
    if (a == 0) return 0;

    ll res = 1;

    while (b) {
        if (b & 1) res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }

    return res;
}

ll invmod(ll a) {
    return bexp(a, MOD - 2);
}

ll ncr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[r] % MOD * ifact[n - r] % MOD;
}

ll npr(ll n, ll r) {
    if (n < 0 or r < 0 or n < r) return 0;
    return fact[n] * ifact[n - r] % MOD;
}

void precalc(ll n) {
    fact[0] = 1;
    rep1(i, n) fact[i] = fact[i - 1] * i % MOD;

    ifact[n] = invmod(fact[n]);
    rev(i, n - 1, 0) ifact[i] = ifact[i + 1] * (i + 1) % MOD;
}

vector<vector<ll>> good_perms[2];

void precalc(){
    for(auto n : {3,6}){
        vector<ll> p(n);
        iota(all(p),1);

        ll k = n;
        ll t = (n == 6);
        
        do{

            vector<bool> take1(n), take2(n);
            ll currk = k;

            rep(i,n){
                ll x = p[i];
                if(currk-x >= 0){
                    currk -= x;
                    take1[i] = 1;
                }
            }

            currk = k;

            rev(i,n-1,0){
                ll x = p[i];
                if(currk-x >= 0){
                    currk -= x;
                    take2[i] = 1;
                }
            }

            if(take1 == take2){
                good_perms[t].pb(p);
            }

        } while(next_permutation(all(p)));
    }
}

void solve(int test_case){
    ll n; cin >> n;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];

    vector<ll> pos(n+5,-1);
    rep1(i,n) if(a[i] != -1) pos[a[i]] = i;

    vector<ll> pref(n+5);
    rep1(i,n) pref[i] = pref[i-1]+(a[i] == -1);

    // fix ordering of [1..6], add contrib to ans
    ll ans = 0;

    rep(t,2){
        ll currk = 3*(t+1);
        if(currk > n) break;
        ll curr_ways = 0;

        trav(p,good_perms[t]){
            ll prev_pos = 0;
            ll pending_cnt = 0;
            ll ways = 1;

            rep(i,sz(p)){
                ll x = p[i];
                ll ind = pos[x];
                if(ind == -1){
                    pending_cnt++;
                }
                else{
                    if(ind < prev_pos){
                        ways = 0;
                    }
                    else{
                        ll avail_spots = pref[ind-1]-pref[prev_pos];
                        ll mul = ncr(avail_spots,pending_cnt);
                        ways = ways*mul%MOD;
                        prev_pos = ind;
                        pending_cnt = 0;
                    }
                }
            }

            {
                ll ind = n+1;
                ll avail_spots = pref[ind-1]-pref[prev_pos];
                ll mul = ncr(avail_spots,pending_cnt);
                ways = ways*mul%MOD;
                prev_pos = ind;
                pending_cnt = 0;
            }

            curr_ways = (curr_ways+ways)%MOD;
        }

        ll remain_vals = 0;
        for(int i = currk+1; i <= n; ++i){
            if(pos[i] == -1){
                remain_vals++;
            }
        }

        ans += curr_ways*fact[remain_vals];
        ans %= MOD;
    }

    cout << ans << endl;
}

int main()
{
    precalc();
    precalc(N-1);

    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}