DISTSUM - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

Given an array A of length N, define f(A) as follows.

  • Create a directed graph with edges i \to j for each (i, j) such that i \lt j \le i + A_i.
  • Then, f(A) equals the length of the shortest path from 1 to N.

Given N and a modulus P, compute the sum of f(A) across all N^N arrays of length N with elements in [1, N], modulo P.
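For very small N, the definition can be checked by brute force. The sketch below is only a testing aid, not part of the intended solution, and the helper names `shortest_path_length` and `brute_force_sum` are hypothetical. It runs a BFS from index 1 on the jump graph, drops any edge that would go past N, and sums f(A) over all N^N arrays without any modulus.

```python
from collections import deque
from itertools import product


def shortest_path_length(a):
    """f(A) for a jump array; a[j - 1] holds A_j (indices are 1-based in the text)."""
    n = len(a)
    dist = [-1] * (n + 1)  # dist[j] = BFS distance from index 1 to index j
    dist[1] = 0
    queue = deque([1])
    while queue:
        i = queue.popleft()
        # edges i -> j for every i < j <= i + A_i (indices beyond N do not exist)
        for j in range(i + 1, min(n, i + a[i - 1]) + 1):
            if dist[j] == -1:
                dist[j] = dist[i] + 1
                queue.append(j)
    return dist[n]  # always reachable, since A_i >= 1 gives the edge i -> i+1


def brute_force_sum(n):
    """Sum of f(A) over all N^N arrays with entries in [1, N]; tiny N only."""
    return sum(shortest_path_length(a) for a in product(range(1, n + 1), repeat=n))


print([brute_force_sum(n) for n in range(1, 5)])  # [0, 4, 36, 400]
```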

EXPLANATION

First, let’s understand how to compute f(A) for a fixed array A.

We start at position 1, and our first move can take us to any position in [2, 1+A_1].
Among these positions, it’s clearly optimal to jump to whichever one allows for the best possible second move - that is, we should choose whichever j in this range has the largest possible value of j + A_j.
This can be proved with a simple exchange argument.

Once the first move is made, the same argument applies to the second move, then the third, and so on - we always choose to jump to whichever index has the largest value of j + A_j across all j in the appropriate range.
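The greedy rule can be sketched as a direct simulation. The name `greedy_jumps` is hypothetical, and one detail not spelled out above is assumed: if N is already within reach, we jump there immediately. Otherwise we move to a reachable index that maximizes j + A_j, with ties broken arbitrarily.

```python
def greedy_jumps(a):
    """Number of jumps used by the greedy; a[j - 1] holds A_j (1-indexed)."""
    n = len(a)
    pos, jumps = 1, 0
    while pos < n:
        reach = pos + a[pos - 1]
        if reach >= n:  # N is directly reachable: one final jump
            return jumps + 1
        # among j in [pos + 1, reach], move to the one maximizing j + A_j
        pos = max(range(pos + 1, reach + 1), key=lambda j: j + a[j - 1])
        jumps += 1
    return jumps  # only reached when N == 1


print(greedy_jumps([1, 1, 1]), greedy_jumps([2, 1, 1, 1]), greedy_jumps([3, 1, 1]))  # 2 2 1
```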

This means we’ll have a sequence of ‘breakpoints’, say i_0 \lt i_1 \lt i_2 \lt \ldots such that:

  • i_0 = 1, and
  • For any x \ge 1, we have i_x = \min(N, \max(j + A_j)), where j \in [i_{x-2}+1, i_{x-1}] (treating i_{-1} as 0).
    That is, the next breakpoint is determined purely by the values between the previous two breakpoints.

Equivalently, i_x is the farthest index reachable using at most x moves, so f(A) is the smallest x such that i_x = N.
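This is a minimal sketch of computing the breakpoints straight from the definition, with each breakpoint capped at N. The names `breakpoints` and `f_from_breakpoints` are hypothetical. Since the last breakpoint is N, f(A) is the number of breakpoints minus one.

```python
def breakpoints(a):
    """Breakpoints i_0 < i_1 < ... (1-indexed, capped at N); a[j - 1] holds A_j."""
    n = len(a)
    bps = [1]  # i_0 = 1
    prev = 0   # plays the role of i_{-1}
    while bps[-1] < n:
        # i_x = min(N, max of j + A_j over j in [i_{x-2} + 1, i_{x-1}])
        nxt = min(n, max(j + a[j - 1] for j in range(prev + 1, bps[-1] + 1)))
        prev = bps[-1]
        bps.append(nxt)
    return bps


def f_from_breakpoints(a):
    return len(breakpoints(a)) - 1  # smallest x with i_x = N


print(breakpoints([2, 1, 1, 1]), f_from_breakpoints([2, 1, 1, 1]))  # [1, 3, 4] 2
```

Each iteration strictly increases the last breakpoint, because j = i_{x-1} alone already gives i_{x-1} + A_{i_{x-1}} > i_{x-1}, so the loop always terminates.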

The above observation allows us to use dynamic programming.

Specifically, let’s define dp[i][x][y] to be the sum of answers across all ways in which:

  • The elements A_1, A_2, \ldots, A_i have been fixed.
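  • The upcoming breakpoint (the smallest breakpoint that is \ge i) is x.
  • The breakpoint after x, which is determined by the values up to index x, is y. Since only A_1, \ldots, A_i are known, y is the running maximum of j + A_j over the indices of the current segment fixed so far.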

We now need to figure out transitions.
There are two possible cases: either index i-1 is a breakpoint, or it is not.


Case 1: Index i-1 is not a breakpoint.
Now, indices i and i-1 belong to the same ‘segment’ of indices between breakpoints.
So, if the state of breakpoints at index i-1 was (x_0, y_0), we must have:

  1. x_0 = x, because the values in this segment don’t affect their own breakpoint - only the next one.
  2. y_0 \le y, since the next breakpoint cannot decrease as we go along the segment.
    More precisely, y = \max(y_0, i + A_i) must hold.

So, we only need to look at states of the form dp[i-1][x][y_0] for y_0 \le y.
Now,

  • If y_0 \lt y, we’re forced to have y = i + A_i, which fixes the value of A_i uniquely.
    For this choice of A_i, any smaller y_0 will do; so we need to take the sum across all y_0 \lt y, which corresponds to a prefix sum of the dp[i-1][x] array.
  • If y_0 = y, then A_i must satisfy i + A_i \le y, which leaves y - i possible values (A_i \in [1, y-i]), any of which can be chosen.
    Thus, we add (y-i) \cdot dp[i-1][x][y].

Note that there is one exceptional case, namely y = N.
Here, A_i may also satisfy i + A_i \gt N, since every A_i \ge N - i leads to y = N. So when y_0 = y = N all N values of A_i are allowed, and when y_0 \lt y = N the prefix sum is multiplied by i+1 (for i \lt N) instead of 1.

As long as prefix sums of the dp[i-1][x] array are known, this case can thus be processed in constant time.
Note that if values of y are processed in ascending order, the prefix sums can be maintained on the fly and don't need to be stored separately.
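For a fixed x, Case 1 only reads the row dp[i-1][x] and writes the row dp[i][x]. It can therefore be sketched as one ascending pass over y with a running prefix sum, including the y = N edge case. This is an illustrative sketch, not the editorialist's code. The name `case1_row` and the list layout (index y from 0 to N, with y capped at N) are assumptions, and so is the use of min(N, i+1) for the y = N multiplier, which also covers i = N. Because Case 1 adds no edge, the same function updates the ct row as well.

```python
def case1_row(old_row, i, n, mod):
    """Case 1 for one fixed x: old_row[y0] = dp[i-1][x][y0]; returns dp[i][x][y].

    Case 1 uses no extra edge, so the same function also updates the ct row.
    """
    new_row = [0] * (n + 1)
    prefix = 0  # sum of old_row[y0] over all y0 < y
    lo = min(i + 1, n)  # after fixing A_i we always have y >= i + 1 (capped at N)
    for y in range(1, n + 1):
        if y >= lo:
            # y0 == y: A_i must keep i + A_i <= y (any of the N values if y == N)
            same = (y - i) if y < n else n
            # y0 < y: A_i must make min(N, i + A_i) == y exactly
            exact = 1 if y < n else min(n, i + 1)
            new_row[y] = (same * old_row[y] + exact * prefix) % mod
        prefix = (prefix + old_row[y]) % mod
    return new_row


# N = 5, i = 2, only y0 = 3 present: A_2 = 1 keeps y = 3, A_2 = 2 gives 4, A_2 >= 3 gives 5
print(case1_row([0, 0, 0, 1, 0, 0], 2, 5, 10**9 + 7))  # [0, 0, 0, 1, 1, 3]
```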


Case 2: Index i-1 is a breakpoint.
Here, note that the previous state must be exactly of the form dp[i-1][i-1][x], since our current breakpoint is x (and was determined by the previous segment), while the breakpoint for the previous segment was i-1.

Further, y = i+A_i must hold, since index i is currently the only one on the current segment, and so must itself determine the next breakpoint.
Once again, y = N is an edge case: every A_i \ge N - i gives y = N (i+1 values when i \lt N), so the transition into y = N gets that multiplier instead of 1.

For the transition, we add dp[i-1][i-1][x], but this time we also need to add 1 to account for using an edge.
This 1 will be added exactly once for each possible array reaching this state.
Thus, we also need to add ct[i-1][i-1][x] to the value, where ct[i][x][y] is the number of ways to choose the first i elements such that the appropriate breakpoints are x and y.

ct[i][x][y] can itself be similarly calculated using dynamic programming, applying the two cases above.
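The Case 2 transition from a single state can be sketched as follows. The name `case2_scatter` is hypothetical, and the layout is the same assumed one as in Case 1, with y capped at N. The sketch shows where ct enters: each of the c arrays in the old state pays one extra edge, so dp receives (d + c) times the number of valid A_i rather than d times it.

```python
def case2_scatter(c, d, i, n, mod):
    """Case 2 from the single state with x0 = i - 1 and y0 = x.

    c = ct[i-1][i-1][x], d = dp[i-1][i-1][x]. Returns (add_ct, add_dp), indexed by y,
    to be added to ct[i][x][y] and dp[i][x][y].
    """
    add_ct = [0] * (n + 1)
    add_dp = [0] * (n + 1)
    for y in range(min(i + 1, n), n + 1):
        # number of A_i in [1, N] with min(N, i + A_i) == y
        ways = 1 if y < n else min(n, i + 1)
        add_ct[y] = c * ways % mod
        # every one of the c arrays uses one more edge, hence the extra + c
        add_dp[y] = (d + c) * ways % mod
    return add_ct, add_dp


print(case2_scatter(1, 1, 3, 4, 10**9 + 7))  # ([0, 0, 0, 0, 4], [0, 0, 0, 0, 8])
```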


This gives us a solution with a complexity of \mathcal{O}(N^3).
Note that while it also uses \mathcal{O}(N^3) memory, this is easy to reduce to \mathcal{O}(N^2), since dp[i] depends only on dp[i-1]. Depending on your implementation, this may even be necessary for speed, since working with a 600\times 600\times 600 array can be slow.
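Combining both cases with only two layers in memory gives the following sketch of the \mathcal{O}(N^3)-time, \mathcal{O}(N^2)-memory DP. It follows the Case 1 / Case 2 formulation above literally, with y capped at N and the answer read from dp[N][N][N], so its indexing differs from the editorialist's code below. The name `distsum_two_cases` is hypothetical, and the bookkeeping choices are assumptions rather than the author's method.

```python
def distsum_two_cases(n, mod):
    """Sum of f(A) over all N^N arrays, mod `mod`, via Case 1 / Case 2.

    ct[x][y], dp[x][y] describe prefixes A_1..A_i; y is capped at N.
    Only the layer for i - 1 is kept while layer i is built.
    """
    def fresh():
        return [[0] * (n + 1) for _ in range(n + 1)]

    ct, dp = fresh(), fresh()
    # i = 1: index 1 is the breakpoint i_0, and y = min(N, 1 + A_1)
    for y in range(min(2, n), n + 1):
        ct[1][y] = (1 if y < n else min(n, 2)) % mod
    for i in range(2, n + 1):
        new_ct, new_dp = fresh(), fresh()
        lo = min(i + 1, n)
        # Case 1: index i - 1 is not a breakpoint (x >= i), so x is kept
        for x in range(i, n + 1):
            pref_ct = pref_dp = 0
            for y in range(1, n + 1):
                if y >= lo:
                    same = (y - i) if y < n else n
                    exact = 1 if y < n else min(n, i + 1)
                    new_ct[x][y] = (same * ct[x][y] + exact * pref_ct) % mod
                    new_dp[x][y] = (same * dp[x][y] + exact * pref_dp) % mod
                pref_ct = (pref_ct + ct[x][y]) % mod
                pref_dp = (pref_dp + dp[x][y]) % mod
        # Case 2: index i - 1 is a breakpoint, and its next breakpoint y0 becomes x
        for y0 in range(i, n + 1):
            c, d = ct[i - 1][y0], dp[i - 1][y0]
            if c == 0 and d == 0:
                continue
            for y in range(lo, n + 1):
                ways = 1 if y < n else min(n, i + 1)
                new_ct[y0][y] = (new_ct[y0][y] + c * ways) % mod
                new_dp[y0][y] = (new_dp[y0][y] + (d + c) * ways) % mod
        ct, dp = new_ct, new_dp  # drop layer i - 1
    return dp[n][n] % mod


print([distsum_two_cases(n, 10**9 + 7) for n in range(1, 5)])  # [0, 4, 36, 400]
```

Allocating fresh layers for every i keeps the sketch simple. Reusing two preallocated layers would avoid the repeated allocations.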

TIME COMPLEXITY:

\mathcal{O}(N^3) per testcase.

CODE:

Editorialist's code (PyPy3)
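```python
for _ in range(int(input())):
    n, p = map(int, input().split())

    # ct[x][y]: number of ways to fix the prefix processed so far such that the
    # upcoming breakpoint is x and the next breakpoint is (so far) y, where y
    # starts at x when a new segment begins. dp[x][y]: total number of edges
    # used over those ways. At iteration i, A_1..A_{i-1} are fixed and A_i is chosen.
    ct = [[0]*(n+1) for _ in range(n+1)]
    dp = [[0]*(n+1) for _ in range(n+1)]
    ct[1][1] = 1

    for i in range(1, n):
        # index i is not a breakpoint (x > i): A_i can only raise y
        for x in range(i+1, n+1):
            pref_ct = pref_dp = 0  # sums of the old values at y0 < y
            for y in range(x, n+1):
                cur_ct, cur_dp = ct[x][y], dp[x][y]

                # mul: choices of A_i keeping y (i + A_i <= y, or any A_i if y == n)
                # pre_mul: choices of A_i raising some y0 < y to exactly y
                mul = y-i if y < n else y
                pre_mul = 1 if y < n else i+1
                ct[x][y] = (ct[x][y] * mul + pre_mul * pref_ct) % p
                dp[x][y] = (dp[x][y] * mul + pre_mul * pref_dp) % p

                pref_ct = (pref_ct + cur_ct) % p
                pref_dp = (pref_dp + cur_dp) % p

        # index i is a breakpoint: A_i completes the segment, so the new
        # breakpoint is max(y, i + A_i) and every array uses one more edge
        for y in range(i, n+1):
            # skip only truly empty states (ct can be 0 mod p while dp is not)
            if ct[i][y] == 0 and dp[i][y] == 0: continue

            cur_ct, cur_dp = ct[i][y], dp[i][y]

            # small a[i]: i + a[i] <= y keeps breakpoint y (y - i choices);
            # when y == i+1, that single choice is counted below instead
            if i+1 < y < n:
                ct[y][y] = (ct[y][y] + cur_ct * (y-i)) % p
                dp[y][y] = (dp[y][y] + (cur_ct + cur_dp) * (y-i)) % p

            # larger a[i], but can't reach n: new breakpoint k = i + a[i]
            start = y + 1 if y > i + 1 else i + 1
            for k in range(start, n):
                ct[k][k] = (ct[k][k] + cur_ct) % p
                dp[k][k] = (dp[k][k] + (cur_ct + cur_dp)) % p

            # reach n: i + a[i] >= n (i+1 choices), or any a[i] if y == n already
            mul = i+1 if y < n else n
            ct[n][n] = (ct[n][n] + cur_ct * mul) % p
            dp[n][n] = (dp[n][n] + (cur_ct + cur_dp) * mul) % p

    # A_n never affects the path, so it contributes a factor of n
    ans = (dp[n][n] * n) % p
    print(ans)
```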