SORTSET7 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07
Tester: satyam_343
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Combinatorics

PROBLEM:

For a permutation P of length N and a parameter K, define f(P, K) as follows.

  • Let S be a set of pairs (u, v) with 1 \le u \lt v \le N.
  • Suppose it’s possible to sort P by utilizing only swaps of the form (u, v) for (u, v) \in S.
  • Define \text{score}(S) = \sum_{(u, v) \in S} (u + v + K).
  • f(P, K) is then defined to be the minimum possible score across all valid sets S.

You’re given N and K.
Compute the sum of f(P, K) across all permutations P.

EXPLANATION:

First, we need to understand when exactly it’s possible to sort a permutation using some set of swaps.

Perhaps the simplest way to see this is to look at things in terms of graphs.
Let S denote a set of swaps.
Consider a graph on N vertices, with the undirected edge (u, v) existing if and only if (u, v) \in S.

Observe that in this graph, if x and y lie in the same connected component, then we can swap the values at indices x and y via a sequence of swaps: take any path x \to y and swap along it in order (this carries the value at x over to y), then swap back along the same path in reverse order, skipping its last edge (this carries the value originally at y over to x and restores every intermediate position).
On the other hand, if x and y lie in different components, then the value initially at position x can never be brought to position y, no matter how many swaps are made.

To sort the permutation, the value P_i, which starts at index i, must end up at index P_i, for every 1 \le i \le N.
This means that i and P_i must be in the same connected component of the above graph.
It’s not hard to see that this condition is both necessary and sufficient (necessary because otherwise reachability is broken; sufficient because applying the path + reverse path procedure from earlier to swap each misplaced value directly into its target position gives an explicit sequence of swaps that sorts P.)
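A minimal Python sketch of this criterion, using a small disjoint-set union over positions. It only checks the condition (it does not produce the swaps), and the function name `can_sort` and the list encoding (`p[i-1]` holds P_i) are illustrative choices, not part of the original solution.

```python
def can_sort(p, swaps):
    """Check the criterion: every i must share a component with P_i.
    p is a list with p[i-1] = P_i; swaps is an iterable of pairs (u, v)."""
    n = len(p)
    parent = list(range(n + 1))  # disjoint-set union over positions 1..n

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in swaps:
        parent[find(u)] = find(v)  # union the components of u and v

    return all(find(i) == find(p[i - 1]) for i in range(1, n + 1))


# P = [2, 3, 1] is one 3-cycle, so positions 1, 2, 3 must all be connected.
print(can_sort([2, 3, 1], [(1, 2), (1, 3)]))  # True
print(can_sort([2, 3, 1], [(1, 2)]))          # False
```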


Now, let’s look at some fixed permutation P, and try to decide which swaps we want to use.

We have N pairs of the form (i, P_i), and we need to ensure that both elements of each pair are in the same component.
This naturally lends itself to looking at the cycle decomposition of the permutation (if you’re unaware of what this is, read this.)

So, let’s look at a single cycle of the permutation.
Let the cycle have length L, and its elements be x_1, x_2, \ldots, x_L.
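For reference, a short sketch of the cycle decomposition (the name `cycles` is hypothetical). Because positions are scanned in increasing order, each cycle is entered at its minimum element, so `cyc[0]` is the m used below.

```python
def cycles(p):
    """Split a permutation into cycles. p[i-1] = P_i, values are 1..n.
    Returns a list of cycles, each a list of positions starting at its minimum."""
    n = len(p)
    seen = [False] * (n + 1)
    result = []
    for start in range(1, n + 1):
        if seen[start]:
            continue
        cyc = []
        x = start
        while not seen[x]:
            seen[x] = True
            cyc.append(x)
            x = p[x - 1]  # follow i -> P_i
        result.append(cyc)
    return result


print(cycles([3, 1, 2, 5, 4, 6]))  # [[1, 3, 2], [4, 5], [6]]
```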

Due to the connectivity constraint, the swaps we choose must ensure that all the x_i values lie in the same connected component (of the swap-graph we considered initially.)
There are now two options available to us: either make these L elements form their own component; or have them connected via some “outside” vertices.


First, let’s look at the case where these elements form their own isolated component.
Since the cost of edge (u, v) is u+v+K, it makes sense to prioritize using lower-indexed vertices.
In particular, if m = \min(x_1, \ldots, x_L) is the minimum element in the cycle, then the optimal solution is to just have all edges of the form (m, x_i) for x_i \ne m.
This is optimal among solutions that use no outside vertices: root any spanning tree on these L elements at m; every other element x then has a parent edge costing at least x + m + K, so no tree is cheaper than the star. (Equivalently, running any MST algorithm on these elements yields a tree of exactly this cost.)

The cost of this set of swaps is then equal to

(m + x_1 + K) + (m + x_2 + K) + \ldots + (m + x_L + K) - (m + m + K)

since m is connected to every x_i other than itself (the subtracted term cancels the x_i = m summand). This becomes

(x_1 + \ldots + x_L) + (L-1)\cdot (m+K) - m
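To sanity-check the star claim and the closed form, the sketch below compares three quantities for one cycle's element set: the star cost, a Kruskal MST on the complete graph over those elements, and (x_1 + \ldots + x_L) + (L-1)\cdot (m+K) - m. All function names are hypothetical.

```python
from itertools import combinations


def star_cost(xs, k):
    """Connect every element of xs to the minimum m: sum of (m + x + k)."""
    m = min(xs)
    return sum(m + x + k for x in xs if x != m)


def mst_cost(xs, k):
    """Kruskal's algorithm on the complete graph over xs, edge weight u + v + k."""
    parent = {x: x for x in xs}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total = 0
    for u, v in sorted(combinations(xs, 2), key=lambda e: e[0] + e[1]):
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            total += u + v + k
    return total


def closed_form(xs, k):
    """(x_1 + ... + x_L) + (L-1)(m+K) - m."""
    L, m = len(xs), min(xs)
    return sum(xs) + (L - 1) * (m + k) - m


print(star_cost([7, 3, 9, 4], 5), mst_cost([7, 3, 9, 4], 5), closed_form([7, 3, 9, 4], 5))  # 44 44 44
```

Note that for a single element (L = 1) all three return 0, matching the fact that a fixed point needs no swaps.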

Next, we consider the case where some “outside” vertex is used.

Here, it can be proved that the optimal choice is to connect every x_i directly to vertex 1.
(Proof outline: if some x_i is connected to an outside vertex other than 1, connecting it to 1 instead lowers the cost; after that, connecting every other x_i directly to 1 also lowers the cost while maintaining connectivity.)

Connecting every vertex to 1 has a cost of

(x_1 + 1 + K) + (x_2 + 1 + K) + \ldots + (x_L + 1 + K)

which reduces to

(x_1 + \ldots + x_L) + (1+K)\cdot L

Observe that either way, the cost contains the term (x_1 + \ldots + x_L).
The only remaining choice is between (1+K)\cdot L and (L-1)\cdot (m+K) - m, and we take whichever is smaller.


Since the (x_1 + \ldots + x_L) part always appears, when summed up across all cycles in a permutation this simply becomes
(1 + 2 + \ldots + N) = \frac{N\cdot (N+1)}{2}

This cost appears for every permutation, so we can simply add N! \times \frac{N\cdot (N+1)}{2} to the answer to begin with and forget about this part entirely.

Now, for a cycle with length L and minimum element m, the associated cost is simply

\min(L\cdot (1+K), (L-1)\cdot (m+K) - m)

This depends only on the cycle’s parameters L and m.
So, we can compute the contribution of a single cycle and multiply it by the number of permutations it appears in; summing this up across all cycles will give us the overall answer.
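The per-cycle formula can be cross-checked against a brute force that tries every edge set S, which is only feasible for very small N. In this sketch (the names `f_brute` and `f_formula` are hypothetical), `f_formula` adds N(N+1)/2 to \min(L\cdot (1+K), (L-1)\cdot (m+K) - m) over all cycles. The example uses the 3-cycle (4, 5, 6) with K = 0, where attaching to vertex 1 is strictly cheaper than the star; for N \le 5 that never happens strictly, hence N = 6.

```python
from itertools import combinations, permutations


def f_brute(p, k):
    """Minimum score over every swap set S that sorts p (exponential: tiny n only).
    p is a list with p[i-1] = P_i."""
    n = len(p)
    edges = list(combinations(range(1, n + 1), 2))
    best = None
    for mask in range(1 << len(edges)):
        chosen = [e for j, e in enumerate(edges) if mask >> j & 1]
        comp = list(range(n + 1))  # naive component labels
        for u, v in chosen:
            a, b = comp[u], comp[v]
            comp = [b if c == a else c for c in comp]  # merge u's component into v's
        if all(comp[i] == comp[p[i - 1]] for i in range(1, n + 1)):
            score = sum(u + v + k for u, v in chosen)
            best = score if best is None else min(best, score)
    return best


def f_formula(p, k):
    """N(N+1)/2 plus min(L(1+K), (L-1)(m+K) - m) for every cycle of p."""
    n = len(p)
    seen = [False] * (n + 1)
    total = n * (n + 1) // 2
    for m in range(1, n + 1):  # scanning upward, m is the minimum of its cycle
        if seen[m]:
            continue
        L, x = 0, m
        while not seen[x]:
            seen[x] = True
            L += 1
            x = p[x - 1]
        total += min(L * (1 + k), (L - 1) * (m + k) - m)
    return total


# Cycle (4 5 6) with K = 0: attaching to vertex 1 (cost 18) beats the star (19).
p = [1, 2, 3, 5, 6, 4]
print(f_brute(p, 0), f_formula(p, 0))  # 18 18
```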


Suppose we fix L, the length of the cycle.
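Observe that for L \ge 3, (L-1)\cdot (m+K) - m = (L-2)\cdot m + (L-1)\cdot K is increasing in m, so \min(L\cdot (1+K), (L-1)\cdot (m+K) - m) equals (L-1)\cdot (m+K) - m for “small” m and L\cdot (1+K) for “large enough” m. For L \le 2 the isolated option is never worse (it gives -m when L = 1 and K when L = 2), so every m counts as “small”.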

Let m_0 be the largest m for which the isolated option is no worse.
We solve for m \le m_0 and m \gt m_0 separately.
(Finding m_0 can be done algebraically, or just use binary search and let the computer do the work for you.)
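Doing the algebra: for L \ge 3, (L-1)\cdot (m+K) - m \le L\cdot (1+K) is equivalent to (L-2)\cdot m \le L + K, so m_0 = \min(N, \lfloor (L+K)/(L-2) \rfloor), while m_0 = N for L \le 2. The sketch below (hypothetical names, assuming K \ge 0) checks this closed form against a binary search on the same monotone predicate.

```python
def m0_closed(n, L, k):
    """Largest m in [1, n] where the isolated star is no worse, i.e.
    (L-1)(m+k) - m <= L(1+k). Assumes k >= 0."""
    if L <= 2:
        return n  # the isolated star is never worse for L = 1, 2
    return min(n, (L + k) // (L - 2))  # (L-2)m <= L + k


def m0_binary(n, L, k):
    """Same threshold via binary search on the monotone predicate."""
    lo, hi = 0, n  # answer lies in [lo, hi]; the predicate holds up to it
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if (L - 1) * (mid + k) - mid <= L * (1 + k):
            lo = mid
        else:
            hi = mid - 1
    return lo


print(m0_closed(6, 3, 0), m0_binary(6, 3, 0))  # 3 3
```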

First, let’s look at m \le m_0.

  • The value under consideration is (L-1)\cdot (m+K) - m.
  • There are \binom{N-m}{L-1} \cdot (L-1)! possible cycles of length L with m as the minimum element (choose the other L-1 elements from the N-m values larger than m, then choose one of the (L-1)! orders in which they follow m around the cycle).
  • Each such cycle appears in (N-L)! permutations overall.
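The counting in the list above can be checked by enumeration for small N. In this sketch (hypothetical names), `count_cycles_brute` counts (permutation, cycle) pairs where the cycle has length L and minimum m, which is exactly the number of appearances the formula \binom{N-m}{L-1} \cdot (L-1)! \cdot (N-L)! is meant to count.

```python
from itertools import permutations
from math import comb, factorial


def count_cycles_brute(n, L, m):
    """Number of (permutation, cycle) pairs where the cycle has length L and minimum m."""
    total = 0
    for p in permutations(range(1, n + 1)):
        seen = [False] * (n + 1)
        for start in range(1, n + 1):
            if seen[start]:
                continue
            length, x = 0, start
            while not seen[x]:
                seen[x] = True
                length += 1
                x = p[x - 1]
            if length == L and start == m:  # scanning upward, start is the cycle minimum
                total += 1
    return total


def count_cycles_formula(n, L, m):
    """C(N-m, L-1) * (L-1)! cycles, each appearing in (N-L)! permutations."""
    return comb(n - m, L - 1) * factorial(L - 1) * factorial(n - L)


print(count_cycles_brute(4, 2, 1), count_cycles_formula(4, 2, 1))  # 6 6
```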

Thus, we want to compute the sum

\sum_{m=1}^{m_0} \binom{N-m}{L-1} \cdot \left((L-1)\cdot (m+K) - m \right)

(multiplied by (L-1)! (N-L)!, which is constant for a fixed L).

There are two ways to do this quickly.

Method 1

Simply iterate over every m \le m_0 by brute force. This is fast enough!

This is because, for L \ge 3, m_0 works out to \min\left(N, \left\lfloor \frac{L+K}{L-2} \right\rfloor\right) \le \frac{2N}{L-2}, using K \le N and L \le N; for L \le 2 we have m_0 = N, which costs only \mathcal{O}(N) each.
So, summing up across all L is bounded by \mathcal{O}(N\log N) via the harmonic summation, which is fast enough for us.
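Written out, the harmonic bound looks as follows; this is a sketch assuming K \ge 0 (so that m_0(L) = N for L \le 2) and N \ge 2, with m_0(L) the threshold for cycle length L and H_n the n-th harmonic number.

```latex
\sum_{L=1}^{N} m_0(L)
  \;\le\; \underbrace{N + N}_{L = 1,\,2} \;+\; \sum_{L=3}^{N} \frac{2N}{L-2}
  \;=\; 2N + 2N \sum_{j=1}^{N-2} \frac{1}{j}
  \;=\; 2N\,\bigl(1 + H_{N-2}\bigr)
  \;=\; \mathcal{O}(N \log N)
```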

Method 2

It’s also possible to compute the required value in constant time.
Since (L-1)\cdot (m+K) - m = (L-1)\cdot K + (L-2)\cdot m, the sum breaks into two parts.
One is (L-1)\cdot K multiplied by \sum_{m \le m_0} \binom{N-m}{L-1}, while the other is (L-2) multiplied by \sum_{m \le m_0} \binom{N-m}{L-1}\cdot m.

Both of these can be computed in constant time.
Range sums of binomial coefficients with a varying upper index and a fixed lower index can be computed using the hockey stick identity. For the second sum, substitute x = N - m, so that m\binom{N-m}{L-1} = N\binom{x}{L-1} - x\binom{x}{L-1}; range sums of the form \sum x\binom{x}{L} can then be computed via an identity derived similarly. It might be helpful to note that

\sum_{x=0}^n x\binom{x}{L} = (L+1)\cdot \binom{n+2}{L+2} - \binom{n+1}{L+1}

In any case, the value for all m \le m_0 can be computed quickly enough.
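A sketch of Method 2 with exact integers (hypothetical names; `math.comb` is used instead of modular factorial tables to keep it easy to verify). It applies the substitution x = N - m, the hockey stick identity for \sum \binom{x}{L-1}, and the identity above (with L replaced by L-1) for \sum x\binom{x}{L-1}, then compares against the direct loop from Method 1.

```python
from math import comb


def C(n, r):
    """Binomial coefficient, treated as 0 outside 0 <= r <= n (also for n < 0)."""
    return comb(n, r) if 0 <= r <= n else 0


def sum_C(a, b, r):
    """C(a, r) + C(a+1, r) + ... + C(b, r) for a >= 0 (hockey stick identity)."""
    return C(b + 1, r + 1) - C(a, r + 1) if a <= b else 0


def sum_xC(a, b, r):
    """a*C(a, r) + ... + b*C(b, r), from
    sum_{x=0}^{t} x*C(x, r) = (r+1)*C(t+2, r+2) - C(t+1, r+1)."""
    def prefix(t):
        return (r + 1) * C(t + 2, r + 2) - C(t + 1, r + 1) if t >= 0 else 0
    return prefix(b) - prefix(a - 1) if a <= b else 0


def small_m_sum(n, L, k, m0):
    """sum_{m=1}^{m0} C(n-m, L-1) * ((L-1)(m+k) - m) with O(1) binomials (0 <= m0 <= n)."""
    r = L - 1
    a, b = n - m0, n - 1              # x = n - m ranges over [n - m0, n - 1]
    s0 = sum_C(a, b, r)               # sum of C(n-m, L-1)
    s1 = n * s0 - sum_xC(a, b, r)     # sum of m * C(n-m, L-1), since m = n - x
    return (L - 1) * k * s0 + (L - 2) * s1


def small_m_sum_brute(n, L, k, m0):
    return sum(C(n - m, L - 1) * ((L - 1) * (m + k) - m) for m in range(1, m0 + 1))


print(small_m_sum(10, 4, 3, 7), small_m_sum_brute(10, 4, 3, 7))  # 2814 2814
```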


Next, let’s look at m \gt m_0.
Here, the cost of every such m is just L\cdot (K+1), which is a constant.
The computation for number of cycles and their contribution remains the same, so we want to compute

\sum_{m\gt m_0} \binom{N-m}{L-1}

multiplied by (L-1)! (N-L)! \cdot L \cdot (K+1). By the hockey stick identity, the above sum equals \binom{N-m_0}{L}, so it can be computed in constant time.
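A quick numerical check of this hockey stick step: with x = N - m running over 0, \ldots, N-m_0-1, the sum \sum_{m \gt m_0} \binom{N-m}{L-1} collapses to \binom{N-m_0}{L}. The function names are hypothetical.

```python
from math import comb


def large_m_count(n, L, m0):
    """sum_{m=m0+1}^{n} C(n-m, L-1) via the hockey stick identity (needs 0 <= m0 <= n)."""
    return comb(n - m0, L)


def large_m_count_brute(n, L, m0):
    return sum(comb(n - m, L - 1) for m in range(m0 + 1, n + 1))


print(large_m_count(10, 3, 4), large_m_count_brute(10, 3, 4))  # 20 20
```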

This allows us to process a single value of L with a constant number of binomial coefficient computations (or \mathcal{O}(N\log N) overall, depending on what you choose to do for m \le m_0); so simply running this for all L and adding up the answers is fast enough.
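Putting the pieces together, here is a hedged sketch of the whole constant-per-L approach (closed-form m_0, Method 2 for m \le m_0, the hockey stick identity for m \gt m_0). To keep it easy to check, it uses exact Python integers via `math.comb` rather than factorial tables modulo 998244353, so it is only meant for small N; a real submission would reduce modulo the prime as in the editorialist's code. The names `solve`, `solve_brute`, `C` and `xc_prefix` are hypothetical, and K \ge 0 is assumed. `solve_brute` enumerates every permutation and adds up the per-cycle formula.

```python
from itertools import permutations
from math import comb, factorial


def C(n, r):
    """Binomial coefficient, treated as 0 outside 0 <= r <= n (also for n < 0)."""
    return comb(n, r) if 0 <= r <= n else 0


def xc_prefix(t, r):
    """sum_{x=0}^{t} x*C(x, r), or 0 when t < 0."""
    return (r + 1) * C(t + 2, r + 2) - C(t + 1, r + 1) if t >= 0 else 0


def solve(n, k):
    """Sum of f(P, K) over all permutations of 1..n as an exact integer
    (no modulus), using O(1) binomials per L. Assumes k >= 0."""
    ans = factorial(n) * (n * (n + 1) // 2)
    for L in range(1, n + 1):
        r = L - 1
        # largest m for which the isolated star is no worse
        m0 = n if L <= 2 else min(n, (L + k) // (L - 2))
        # m <= m0: cost (L-1)*k + (L-2)*m; substitute x = n - m in [n - m0, n - 1]
        s0 = C(n, L) - C(n - m0, L)                                     # sum of C(n-m, r)
        s1 = n * s0 - (xc_prefix(n - 1, r) - xc_prefix(n - m0 - 1, r))  # sum of m*C(n-m, r)
        small = (L - 1) * k * s0 + (L - 2) * s1
        # m > m0: attach all L elements to vertex 1, cost L*(1+k)
        large = L * (1 + k) * C(n - m0, L)
        # (L-1)! cyclic orders, (N-L)! arrangements of the remaining elements
        ans += (small + large) * factorial(L - 1) * factorial(n - L)
    return ans


def solve_brute(n, k):
    """Enumerate permutations and add N(N+1)/2 + min(...) per cycle."""
    total = 0
    for p in permutations(range(1, n + 1)):
        seen = [False] * (n + 1)
        total += n * (n + 1) // 2
        for m in range(1, n + 1):  # scanning upward, m is its cycle's minimum
            if seen[m]:
                continue
            L, x = 0, m
            while not seen[x]:
                seen[x] = True
                L += 1
                x = p[x - 1]
            total += min(L * (1 + k), (L - 1) * (m + k) - m)
    return total


print(solve(2, 5))  # 8
```

For N = 2 only P = (2, 1) contributes, with the single swap (1, 2) of score 1 + 2 + K, which matches the printed 8 for K = 5.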

TIME COMPLEXITY:

\mathcal{O}(N) or \mathcal{O}(N\log N) per testcase.

CODE:

Editorialist's code (PyPy3)
mod = 998244353
N = 10**6
fac = [1]*(N+1)
for i in range(1, N+1): fac[i] = i * fac[i-1] % mod
ifac = fac[:]
ifac[-1] = pow(ifac[-1], mod-2, mod)
for i in reversed(range(N)): ifac[i] = ifac[i+1] * (i+1) % mod
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * ifac[r] * ifac[n-r] % mod

def rsum(l, r, k): # C(l, k) + C(l+1, k) + ... + C(r, k)
    return C(r+1, k+1) - C(l, k+1)

for _ in range(int(input())):
    n, k = map(int, input().split())

    ans = (n*(n+1)//2) * fac[n] % mod
    for L in range(1, n+1):
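        # binary search for the first m at which attaching the cycle to vertex 1
        # is no worse than the isolated star (n+1 if there is none, e.g. L <= 2)
        lo, hi = 1, n+1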
        while lo < hi:
            mid = (lo + hi)//2
            if min(L*(1+k), (L-1)*(mid+k)-mid) == L*(1+k): hi = mid
            else: lo = mid+1
        
        m0 = lo
        # m < m0: isolated star around the cycle minimum m, cost (L-1)*(m+k) - m
        for m in range(1, m0):
            cost = ((L-1)*(m+k)-m) % mod
            ans += cost * C(n-m, L-1) % mod * fac[L-1] % mod * fac[n-L] % mod

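        # m >= m0: attach all L elements to vertex 1, cost L*(1+k);
        # rsum sums C(n-m, L-1) over m0 <= m <= n via the hockey stick identity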
        ans += L*(1+k) % mod * fac[L-1] % mod * fac[n-L] % mod * rsum(0, n-m0, L-1)
        ans %= mod
    print(ans % mod)