CNTR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: hellolad
Tester: kingmessi
Editorialist: iceknight1093

DIFFICULTY:

Easy

PREREQUISITES:

Dynamic programming, combinatorics

PROBLEM:

You’re given an array A with some elements missing.
Define f(A) to be the number of non-empty strictly increasing subsequences in A.
Compute the sum of f(A) across all ways of replacing the missing elements with integers in [1, K].

EXPLANATION:

First, let’s solve the problem when there are no missing elements, i.e. just “count the number of increasing subsequences of the given array”.
This is a fairly standard dynamic programming task, with one of its solutions being as follows:

  • Let f(i, x) denote the number of increasing subsequences among the first i elements, such that the last element of the subsequence is x.
  • Then, we have:
    f(i, x) = f(i-1, x) for x \neq A_i, of course.
    f(i, A_i) = 1 + f(i-1, 1) + f(i-1, 2) + \ldots + f(i-1, A_i).
    • This is because, to obtain a subsequence ending with the value A_i, we have three options.
    • First, we can take the singleton element [A_i] at index i.
    • Second, we can take a subsequence that ends with the value A_i but from among the first i-1 elements, which by definition can be done in f(i-1, A_i) ways.
    • Third, we can take a subsequence that ends with some value \lt A_i, and extend it by appending this copy of A_i to it.
      There are f(i-1, 1) + f(i-1, 2) + \ldots + f(i-1, A_i - 1) such subsequences ending with a smaller value.

There are \mathcal{O}(NK) states here, but at each index most of them require constant work, and only one (f(i, A_i)) requires \mathcal{O}(K) work, so the overall complexity remains \mathcal{O}(NK).
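
The recurrence above can be sketched in a few lines of Python. This is only an illustration under stated assumptions: the function name `count_increasing_subsequences` is made up for this example, values are assumed to lie in [1, K], "increasing" is read as strictly increasing, and the table is rolled into a single array `f` indexed by value, since f(i, x) only differs from f(i-1, x) at x = A_i.

```python
def count_increasing_subsequences(a, k):
    """Count non-empty strictly increasing subsequences of a (values in 1..k)."""
    # f[x] = number of subsequences seen so far whose last element is x.
    # Index 0 is unused so that values can be used as indices directly.
    f = [0] * (k + 1)
    for v in a:
        # 1 for the singleton [v]; f[v] for subsequences already ending in v;
        # f[1..v-1] for subsequences that get extended by this copy of v.
        f[v] = 1 + sum(f[1:v + 1])
    return sum(f)


print(count_increasing_subsequences([1, 3, 2, 3], 3))  # 9
```

For [1, 3, 2, 3] the nine subsequences are the four singletons, the four pairs (1,3), (1,2), (1,3), (2,3), and the triple (1,2,3).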


Now, we extend this solution to when there are missing elements.
For this, we need to change the definition of what we’re counting a bit, to account for some elements being undefined.

So, we redefine f(i, x) to be the number of increasing subsequences among the first i indices ending with the value x, summed across all possible ways of replacing the missing elements among the first i indices.
For example, suppose A = [1, -1, 2]. Then, the subsequence [1, 2] with indices 1 and 3 will always be an increasing subsequence, no matter what the -1 is replaced by. So, it should be counted K times, which is exactly what the redefined f(i, x) does.
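
To make the quantity being summed concrete, here is a small brute-force sketch that enumerates every replacement and every subsequence. It is exponential and only meant for checking tiny inputs by hand; the name `brute_force_total` is hypothetical, and -1 is assumed to mark a missing element, as in the example above.

```python
from itertools import combinations, product


def brute_force_total(a, k):
    """Sum, over all fillings of the -1 entries with 1..k, of the number of
    non-empty strictly increasing subsequences. Exponential time."""
    missing = [i for i, v in enumerate(a) if v == -1]
    total = 0
    for fill in product(range(1, k + 1), repeat=len(missing)):
        b = list(a)
        for idx, v in zip(missing, fill):
            b[idx] = v
        # combinations() picks positions in order, so each index set is
        # counted once even when values repeat.
        for r in range(1, len(b) + 1):
            for sub in combinations(b, r):
                if all(x < y for x, y in zip(sub, sub[1:])):
                    total += 1
    return total


print(brute_force_total([1, -1, 2], 2))  # 10
```

With K = 2 the two fillings are [1, 1, 2] and [1, 2, 2]; each has 5 increasing subsequences, and the pair at indices 1 and 3 is counted once in each, i.e. K times in total.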

Let’s now go back to computing f(i, x).
There are now two possibilities: A_i = -1, and A_i \neq -1.


First, let’s look at A_i \neq -1.
In this case, the transition is pretty much the exact same as the initial version.
Specifically, we still have f(i, x) = f(i-1, x) for x \neq A_i.
However, f(i, A_i) is slightly different: it will equal

f(i-1, 1) + f(i-1, 2) + \ldots + f(i-1, A_i) + K^m

where m is the number of missing elements before index i.
Essentially, the 1 in the original transition has been replaced by K^m, with everything else remaining the same.

To see why: recall that there were three different ways of obtaining a subsequence ending with the value A_i.
Two of them remain the same: namely, taking an existing subsequence from before index i, and extending an existing subsequence from before index i. Both are already summed over all replacements by the redefinition of f(i-1, \cdot), and A_i is fixed, so no extra factor appears.
The remaining option is to take just the element [A_i] alone. This subsequence is valid no matter how the earlier missing elements were replaced, so it must be counted once for each such replacement, of which there are precisely K^m.
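
As a rough sketch, the update for a fixed element could look like the following. The helper name `apply_fixed` and the 1-indexed layout of `f` are assumptions made for this illustration; `pw` stands for K^m reduced modulo 998244353, the modulus used in the code at the end of this editorial.

```python
MOD = 998244353


def apply_fixed(f, v, pw):
    """Update f in place for a fixed element with value v.

    f[x] holds f(i-1, x) for x in 1..K (f[0] is unused and stays 0);
    pw is K^m mod MOD, with m the number of missing elements before index i.
    """
    # Existing subsequences ending in v, extensions of those ending in a
    # smaller value, and the singleton [v] counted once per earlier filling.
    f[v] = (pw + sum(f[1:v + 1])) % MOD


# State after processing [1, -1] with K = 2: f(2, 1) = 3, f(2, 2) = 2, m = 1.
f = [0, 3, 2]
apply_fixed(f, 2, 2)
print(f)  # [0, 3, 7]
```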


Next, we look at A_i = -1, which means there are K options for what value this element can take.

With this, for every 1 \leq x \leq K we obtain

f(i, x) = K^m + f(i-1, 1) + f(i-1, 2) + \ldots + f(i-1, x-1) + K\cdot f(i-1, x)

This is because:

  • If we choose A_i = x, the singleton subsequence is available to us in K^m ways, just as in the previous case.
    Note here that m is the number of missing elements strictly before index i (meaning this one isn’t being counted).
  • If we choose A_i = x, then any previous subsequence ending with a smaller value is uniquely extended.
    This gives us f(i-1, 1) + f(i-1, 2) + \ldots + f(i-1, x-1) options.
  • Finally, there are subsequences that don’t include index i at all.
    There are f(i-1, x) such subsequences, and they don’t care what value is placed here so all K replacements are equally valid for them.
    This gives the K\cdot f(i-1, x) term.
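
Written out literally, the three terms translate into something like the sketch below. It deliberately recomputes the inner sum for every x, so it costs \mathcal{O}(K^2) per missing element; the name `apply_missing_naive` and the 1-indexed array layout are assumptions for illustration only.

```python
MOD = 998244353


def apply_missing_naive(f, k, pw):
    """Return the new row f(i, .) for a missing element at index i.

    f[x] holds f(i-1, x) for x in 1..k (f[0] is unused and stays 0);
    pw is K^m mod MOD, with m counting missing elements strictly before i.
    """
    new = [0] * (k + 1)
    for x in range(1, k + 1):
        singleton = pw                # choose A_i = x and take [x] alone
        extended = sum(f[1:x])        # choose A_i = x and extend a smaller ending
        skipped = k * f[x]            # index i unused, so any of K values works
        new[x] = (singleton + extended + skipped) % MOD
    return new


# State after processing [1] with K = 2: f(1, 1) = 1, f(1, 2) = 0, m = 0.
print(apply_missing_naive([0, 1, 0], 2, 1))  # [0, 3, 2]
```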

Now, while it looks like we have to do \mathcal{O}(K) work for each x, observe that the only part that’s slow to compute is f(i-1, 1) + f(i-1, 2) + \ldots + f(i-1, x-1).
This, however, is just a prefix sum of the f(i-1, \cdot) array, and so can be precomputed and looked up in constant time.
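
One possible way to realise the prefix-sum lookup is sketched here with `itertools.accumulate`. The function name `apply_missing_fast` and the convention that `f[0]` is an unused zero slot are assumptions of this example, not part of the reference solutions.

```python
from itertools import accumulate

MOD = 998244353


def apply_missing_fast(f, k, pw):
    """Return f(i, .) for a missing element in O(K) using prefix sums.

    f[x] holds f(i-1, x) for x in 1..k, with f[0] == 0 as padding;
    pw is K^m mod MOD for the m missing elements strictly before index i.
    """
    # prefix[x] = f[0] + f[1] + ... + f[x], so prefix[x - 1] is exactly
    # f(i-1, 1) + ... + f(i-1, x-1).
    prefix = list(accumulate(f))
    return [0] + [(pw + prefix[x - 1] + k * f[x]) % MOD
                  for x in range(1, k + 1)]


print(apply_missing_fast([0, 1, 0], 2, 1))  # [0, 3, 2]
```

Building the prefix array once and reading it K times keeps the work at this index linear in K.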

This way, we end up doing only \mathcal{O}(K) work at this index across all x, so the complexity remains \mathcal{O}(NK).
The final answer is, of course, the sum of all f(N, x) values for 1 \leq x \leq K.
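
As a sanity check, here is a hand trace of the two transitions on the earlier example A = [1, -1, 2], with K = 2 chosen purely for illustration (m is the number of missing elements strictly before the current index):

```latex
\begin{aligned}
&\text{start:} && f(0,1) = 0,\quad f(0,2) = 0 \\
&i = 1,\ A_1 = 1,\ m = 0: && f(1,1) = f(0,1) + K^0 = 1,\quad f(1,2) = f(0,2) = 0 \\
&i = 2,\ A_2 = -1,\ m = 0: && f(2,1) = K^0 + K \cdot f(1,1) = 1 + 2 = 3 \\
& && f(2,2) = K^0 + f(1,1) + K \cdot f(1,2) = 1 + 1 + 0 = 2 \\
&i = 3,\ A_3 = 2,\ m = 1: && f(3,1) = f(2,1) = 3 \\
& && f(3,2) = f(2,1) + f(2,2) + K^1 = 3 + 2 + 2 = 7 \\
&\text{answer:} && f(3,1) + f(3,2) = 10
\end{aligned}
```

This agrees with direct counting: the fillings [1, 1, 2] and [1, 2, 2] each contain 5 increasing subsequences.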

TIME COMPLEXITY:

\mathcal{O}(NK) per testcase.

CODE:

Editorialist's code (PyPy3)
mod = 998244353
for _ in range(int(input())):
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    
    dp = [0]*k
    pw = 1
    for x in a:
        if x == -1:
            s = 0
            for i in range(k):
                sv = dp[i]
                dp[i] = (k*dp[i] + s + pw) % mod
                s += sv
            pw = (pw * k) % mod
        else:
            x -= 1
            dp[x] = (pw + sum(dp[:x+1])) % mod
    print(sum(dp) % mod)
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long

ll mod=998244353;
ll dp[5001][5001];
ll dp2[5001][5001];

int main() {
    ll tt;
    cin>>tt;
    while(tt--){
        ll n,k;
        cin>>n>>k;
        vector<ll> a(n+1); // 1-indexed; -1 marks a missing element
        for(int i=1;i<=n;i++){
            cin>>a[i];
        }
        for(int i=0;i<=n;i++){
            for(int j=0;j<=5000;j++){
                dp[i][j]=dp2[i][j]=0;
            }
        }
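        // Value 0 stands for the empty subsequence: dp[i][0] stays equal to K^m,
        // which supplies the singleton term when it is "extended" by a real value.
        dp[0][0]=1;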
        for(int i=1;i<=n;i++){
            if(a[i]==-1){
                for(int j=1;j<=k;j++){
                    dp2[i][j]=dp2[i][j-1]+dp[i-1][j-1];
                    dp2[i][j]%=mod;
                }
                for(int j=0;j<=5e3;j++){
                    dp[i][j]=dp[i-1][j]*k+dp2[i][j];
                    dp[i][j]%=mod;
                }
            }else{
                for(int j=0;j<a[i];j++){
                    dp2[i][a[i]]+=dp[i-1][j];
                    dp2[i][a[i]]%=mod;
                }
                for(int j=0;j<=5e3;j++){
                    dp[i][j]=dp[i-1][j]+dp2[i][j];
                    dp[i][j]%=mod;
                }
            }
        }
        ll ans=0;
        for(int i=1;i<=5e3;i++){
            ans+=dp[n][i];
            ans%=mod;
        }
        cout<<ans<<"\n";    
    }
}