PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: wuhudsm
Tester: sushil2006
Editorialist: iceknight1093
DIFFICULTY:
Easy
PREREQUISITES:
Cycle decomposition of permutations
PROBLEM:
For a permutation P, define f(P, A, B) to be the minimum cost of satisfying P_i = i for all i using the following operations:
- Choose an index i and modify P_i to any value, with a cost of A.
- Choose indices i and j and swap P_i with P_j, with a cost of B.
You’re given P. Compute
\sum_{A=1}^{N} \sum_{B=1}^{N} f(P, A, B)
modulo 998244353.
EXPLANATION:
First off, note that all “modify value” operations can be done after all swap operations have been done - if we modify a value and then swap it, we could just as well swap first and modify at the swapped position instead.
This means our sequence of operations will involve using some swaps to fix a few values, and then using the modify operation on everything else.
Using this idea, let’s analyze how to compute f(P, A, B) for fixed constants A and B.
The key question is: which elements should we fix using the swap operation?
To answer that, it’s helpful to understand how exactly a permutation can be sorted using swaps.
This is a very classical problem with an equally well-known solution: the minimum number of swaps needed equals N minus the number of cycles in the cycle decomposition of the permutation.
The cycle decomposition of a permutation P is an undirected graph on N vertices with edges of the form (i, P_i) for all i.
If you’re unfamiliar with the cycle decomposition, this blog is a good read, though the important points are:
- Every connected component of the graph with edges (i, P_i) constructed above is a cycle (a fixed point P_i = i is a cycle of length 1).
- A single swap either merges two cycles together or splits a cycle into two, depending on whether the elements being swapped belong to different cycles or not, respectively.
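To make these facts concrete, here is a small C++ sketch. The helper names `cycle_lengths`, `min_swaps` and `count_swaps_by_simulation` are our own and do not come from the original solutions. It assumes P is stored 0-indexed with values 1..N, finds cycles by following i → P_i, and cross-checks "N minus the number of cycles" against a simple strategy that repeatedly swaps a value into its home position.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Lengths of all cycles of p, where p holds the values 1..n in 0-indexed storage.
// Fixed points (p[i] == i + 1) are reported as cycles of length 1.
std::vector<int> cycle_lengths(const std::vector<int>& p) {
    const int n = static_cast<int>(p.size());
    std::vector<bool> seen(n, false);
    std::vector<int> lengths;
    for (int i = 0; i < n; ++i) {
        if (seen[i]) continue;
        int len = 0;
        for (int j = i; !seen[j]; j = p[j] - 1) {  // follow the edge j -> p[j]
            assert(1 <= p[j] && p[j] <= n);         // p must be a permutation of 1..n
            seen[j] = true;
            ++len;
        }
        lengths.push_back(len);
    }
    return lengths;
}

// Minimum number of swaps that sorts p: N minus the number of cycles.
int min_swaps(const std::vector<int>& p) {
    return static_cast<int>(p.size()) - static_cast<int>(cycle_lengths(p).size());
}

// Cross-check: keep swapping the value at position i into its home position p[i] - 1.
// Every swap puts at least one value in its final place.
int count_swaps_by_simulation(std::vector<int> p) {
    int swaps = 0;
    for (int i = 0; i < static_cast<int>(p.size()); ++i) {
        while (p[i] != i + 1) {
            std::swap(p[i], p[p[i] - 1]);
            ++swaps;
        }
    }
    return swaps;
}
```

For example, P = [2, 3, 1, 5, 4] has cycles of lengths 3 and 2, so both `min_swaps` and `count_swaps_by_simulation` report 5 - 2 = 3 swaps.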
Since we’re working with the cycle decomposition perspective, let’s look at just a single cycle. How can we sort all its elements?
If the cycle has length x, we have two extremal options:
- Use the modify operation on each of the x elements, for an overall cost of A\cdot x.
- Use the swap operation x-1 times, for an overall cost of B\cdot (x-1).
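Note that the optimal choice will definitely be one of these: it’s never going to be optimal to mix swaps and modifications.
This is because each swap increases the number of cycles by at most one. After s \lt x-1 swaps inside the cycle there are at most s+1 cycles, and since at least one of them still has length 2 or more, at most s elements are fixed. Only the full x-1 swaps finish the job, with the last swap fixing both elements of a 2-cycle.
So, if we perform a mix of swaps and modifications, we’ll need at least x operations in total, each costing at least \min(A, B). That gives a cost of at least x\cdot \min(A, B) \ge \min(A\cdot x, B\cdot (x-1)), so mixing is never better than the two options above.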
Now, given a cycle length x and the values A and B, we know the optimal cost is \min(A\cdot x, B\cdot (x-1)).
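When there are multiple cycles, we never need swaps that “merge” cycles. Merging cycles of lengths x and y costs one swap and leaves a single cycle of length x+y, which is no cheaper to fix than the two original cycles were (with swaps alone it needs x+y-1 more swaps instead of (x-1)+(y-1) in total). So it’s better to just work on each cycle separately.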
So, all cycles in the decomposition are independent of each other: f(P, A, B) is the sum of the costs of the individual cycles. This means the sum of f(P, A, B) across all (A, B) equals the sum, over every cycle C in the decomposition, of the sum of f(C, A, B) across all (A, B).
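The claims so far can be sanity-checked on tiny inputs by exhaustive search. These claims are that modifications can wait until after all swaps, that a cycle of length x costs \min(A\cdot x, B\cdot (x-1)), and that cycles are independent. The sketch below is our own illustration, not the editorialist's. `f_formula` is a hypothetical helper that sums the per-cycle cost, and `f_bruteforce` runs Dijkstra over every array with entries in 1..N. It assumes that modifying an element to a value outside 1..N is never useful. It is only feasible for very small N, since there are N^N states.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

using ll = long long;

// f(P, A, B) from the cycle argument: a cycle of length x costs min(A*x, B*(x-1)).
// p holds the values 1..n in 0-indexed storage.
ll f_formula(const std::vector<int>& p, ll A, ll B) {
    const int n = static_cast<int>(p.size());
    std::vector<bool> seen(n, false);
    ll total = 0;
    for (int i = 0; i < n; ++i) {
        if (seen[i]) continue;
        ll x = 0;
        for (int j = i; !seen[j]; j = p[j] - 1) { seen[j] = true; ++x; }
        total += std::min(A * x, B * (x - 1));
    }
    return total;
}

// Exhaustive f(P, A, B): Dijkstra over every array with entries in 1..n,
// each array encoded as a base-n number. Only feasible for tiny n (n^n states).
ll f_bruteforce(const std::vector<int>& p, ll A, ll B) {
    const int n = static_cast<int>(p.size());
    auto encode = [n](const std::vector<int>& a) {
        int code = 0;
        for (int v : a) code = code * n + (v - 1);
        return code;
    };
    auto decode = [n](int code) {
        std::vector<int> a(n);
        for (int i = n - 1; i >= 0; --i) { a[i] = code % n + 1; code /= n; }
        return a;
    };
    int states = 1;
    for (int i = 0; i < n; ++i) states *= n;
    std::vector<int> identity(n);
    for (int i = 0; i < n; ++i) identity[i] = i + 1;
    const int target = encode(identity);

    std::vector<ll> dist(states, -1);  // -1 means "not finalized yet"
    using Item = std::pair<ll, int>;   // (cost so far, encoded array)
    std::priority_queue<Item, std::vector<Item>, std::greater<Item>> pq;
    pq.push({0, encode(p)});
    while (!pq.empty()) {
        ll d = pq.top().first;
        int code = pq.top().second;
        pq.pop();
        if (dist[code] != -1) continue;  // stale queue entry
        dist[code] = d;
        if (code == target) return d;
        const std::vector<int> a = decode(code);
        auto relax = [&](const std::vector<int>& b, ll step) {
            int next = encode(b);
            if (dist[next] == -1) pq.push({d + step, next});
        };
        for (int i = 0; i < n; ++i) {
            for (int v = 1; v <= n; ++v) {     // modify a[i] to v, cost A
                std::vector<int> b = a;
                b[i] = v;
                relax(b, A);
            }
            for (int j = i + 1; j < n; ++j) {  // swap a[i] and a[j], cost B
                std::vector<int> b = a;
                std::swap(b[i], b[j]);
                relax(b, B);
            }
        }
    }
    return -1;  // not reached: the identity is always reachable via modifications
}
```

Comparing the two functions over all permutations of length at most 4 and all A, B in 1..4 is a cheap way to gain confidence in the per-cycle formula before building on it.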
Now, consider a cycle of length x. Let’s compute the sum of costs of sorting only this cycle across all cost pairs (A, B).
For a fixed (A, B), we know the cost is \min(Ax, B\cdot (x-1)), so we want to find
\sum_{A=1}^{N} \sum_{B=1}^{N} \min(A\cdot x, B\cdot (x-1))
Suppose we fix the value of B.
Then, \min(Ax, B\cdot (x-1)) = B\cdot (x-1) \iff B\cdot (x-1) \leq Ax \iff B\cdot \frac{x-1}{x} \le A
That is, for “large enough” values of A, the minimum cost is just always going to be B\cdot (x-1).
On the other hand, for all smaller values, the minimum cost is A\cdot x.
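So, if the breakpoint we find is A_0 = \left\lceil \frac{B\cdot (x-1)}{x} \right\rceil, the overall cost is:
\sum_{A=1}^{A_0-1} A\cdot x + \sum_{A=A_0}^{N} B\cdot (x-1)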
This can be computed in constant time. The first sum is x\cdot (1 + 2 + \ldots + (A_0 - 1)) = x\cdot \frac{A_0\cdot (A_0-1)}{2}. The second is just B\cdot (x-1) added N - A_0 + 1 times; this count is never negative, because \frac{x-1}{x} \lt 1 gives A_0 \le B \le N.
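Here is a minimal sketch of this constant-time step in the text’s notation, where A is the modification cost, B the swap cost, and A ranges over 1..N. `sum_over_A` and `sum_over_A_naive` are hypothetical names. Two small assumptions are built in. First, the breakpoint is clamped to at least 1, so fixed points are handled uniformly (for x = 1 the raw ceiling is 0). Second, the ranges are clamped, so the function stays correct even if B exceeds N.

```cpp
#include <algorithm>
#include <cassert>

using ll = long long;

// Sum over A = 1..N of min(A*x, B*(x-1)): the cost of one cycle of length x,
// with the swap cost B fixed and the modification cost A ranging over 1..N. O(1).
ll sum_over_A(ll N, ll x, ll B) {
    assert(x >= 1);
    // Smallest A >= 1 with B*(x-1) <= A*x (ceiling division); from here on, swapping wins.
    ll A0 = std::max<ll>(1, (B * (x - 1) + x - 1) / x);
    // A = 1..k pay the modification cost A*x, where k = min(A0 - 1, N).
    ll k = std::min(A0 - 1, N);
    ll modify_part = x * k * (k + 1) / 2;
    // A = A0..N pay the swap cost B*(x-1).
    ll swap_part = B * (x - 1) * std::max<ll>(0, N - A0 + 1);
    return modify_part + swap_part;
}

// Direct O(N) loop over A, used only to cross-check the closed form.
ll sum_over_A_naive(ll N, ll x, ll B) {
    ll total = 0;
    for (ll A = 1; A <= N; ++A) total += std::min(A * x, B * (x - 1));
    return total;
}
```

For instance, take N = 4, x = 3 and B = 4. The breakpoint is A_0 = 3, so the terms for A = 1, 2 contribute 3 + 6 = 9 and the terms for A = 3, 4 contribute 8 + 8 = 16, for a total of 25.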
From the above discussion, we know that if both the cycle length x and the swap cost B are fixed, the sum of costs across all A can be computed in constant time.
This lends itself to a solution in \mathcal{O}(N^2) time immediately: there can be \mathcal{O}(N) cycles, and for each of them there are \mathcal{O}(N) values of B to try.
We can now throw in a simple optimization: if there are k cycles with the same length x, it’s enough to perform the computation for just one of them, and then multiply the result by k.
After all, each of these cycles will give the same result.
This simple-looking optimization actually improves our complexity!
Note that the complexity of the algorithm is now N multiplied by the number of distinct cycle lengths.
Because the cycle sizes are positive and add up to N, there can be only \mathcal{O}(\sqrt N) distinct cycle sizes. If there are K distinct sizes, their sum is at least 1 + 2 + \ldots + K = \frac{K\cdot (K+1)}{2}, and this quantity exceeds N once K \gt \sqrt{2N}.
Our final algorithm is thus as follows:
- Compute all the cycles of the given permutation.
- For each distinct cycle length x that exists, compute the answer in \mathcal{O}(N) by iterating through all values of B.
- Multiply the answer for x by the number of cycles of length x, and add that to the answer.
This leads to an \mathcal{O}(N\sqrt N) algorithm overall which is easily fast enough.
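Putting the pieces together, a compact sketch of the whole algorithm might look as follows. It uses the text’s notation, with an outer loop over the swap cost B and a closed form over the modification cost A. The editorialist’s code below loops over the swap cost under the name `a`, and both approaches give the same total. Like that code, the sketch assumes that A and B range over 1..N and that the answer is wanted modulo 998244353. The names `cycle_counts`, `solve` and `solve_naive` are hypothetical, and input parsing is omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using ll = long long;
const ll MOD = 998244353;

// cnt[x] = number of cycles of length x in p (values 1..n, 0-indexed storage).
std::vector<ll> cycle_counts(const std::vector<int>& p) {
    const int n = static_cast<int>(p.size());
    std::vector<ll> cnt(n + 1, 0);
    std::vector<bool> seen(n, false);
    for (int i = 0; i < n; ++i) {
        if (seen[i]) continue;
        int len = 0;
        for (int j = i; !seen[j]; j = p[j] - 1) { seen[j] = true; ++len; }
        ++cnt[len];
    }
    return cnt;
}

// Sum of f(P, A, B) over A, B in 1..N, modulo MOD, in O(N * #distinct cycle lengths).
ll solve(const std::vector<int>& p) {
    const ll n = static_cast<ll>(p.size());
    const std::vector<ll> cnt = cycle_counts(p);
    ll ans = 0;
    for (ll x = 1; x <= n; ++x) {
        if (cnt[x] == 0) continue;  // at most O(sqrt N) lengths get past this check
        ll per_cycle = 0;           // sum over all (A, B) for ONE cycle of length x
        for (ll B = 1; B <= n; ++B) {
            ll A0 = std::max<ll>(1, (B * (x - 1) + x - 1) / x);  // A >= A0: swapping wins
            ll modify_part = x * (A0 - 1) * A0 / 2;               // A = 1..A0-1
            ll swap_part = B * (x - 1) * (n - A0 + 1);            // A = A0..N (A0 <= B <= N)
            per_cycle = (per_cycle + modify_part + swap_part) % MOD;
        }
        ans = (ans + per_cycle * (cnt[x] % MOD)) % MOD;  // all cycles of length x agree
    }
    return ans;
}

// O(N^3) reference straight from the per-cycle formula, for small inputs only.
ll solve_naive(const std::vector<int>& p) {
    const ll n = static_cast<ll>(p.size());
    const std::vector<ll> cnt = cycle_counts(p);
    ll ans = 0;
    for (ll A = 1; A <= n; ++A)
        for (ll B = 1; B <= n; ++B)
            for (ll x = 1; x <= n; ++x)
                ans = (ans + cnt[x] * std::min(A * x, B * (x - 1))) % MOD;
    return ans;
}
```

On P = [2, 3, 1], for example, both `solve` and `solve_naive` sum \min(3A, 2B) over A, B \in \{1, 2, 3\} and return 32.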
TIME COMPLEXITY:
\mathcal{O}(N \sqrt N) per testcase.
CODE:
Editorialist's code (C++)
// #include <bits/allocator.h>
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());
struct DSU {
private:
    std::vector<int> parent_or_size;
public:
    DSU(int n = 1): parent_or_size(n, -1) {}
    int get_root(int u) {
        if (parent_or_size[u] < 0) return u;
        return parent_or_size[u] = get_root(parent_or_size[u]);
    }
    int size(int u) { return -parent_or_size[get_root(u)]; }
    bool same_set(int u, int v) {return get_root(u) == get_root(v); }
    bool merge(int u, int v) {
        u = get_root(u), v = get_root(v);
        if (u == v) return false;
        if (parent_or_size[u] > parent_or_size[v]) std::swap(u, v);
        parent_or_size[u] += parent_or_size[v];
        parent_or_size[v] = u;
        return true;
    }
    std::vector<std::vector<int>> group_up() {
        int n = parent_or_size.size();
        std::vector<std::vector<int>> groups(n);
        for (int i = 0; i < n; ++i) {
            groups[get_root(i)].push_back(i);
        }
        groups.erase(std::remove_if(groups.begin(), groups.end(), [&](auto &s) { return s.empty(); }), groups.end());
        return groups;
    }
};
int main()
{
    ios::sync_with_stdio(false); cin.tie(0);
    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector p(n, 0);
        for (int &x : p) cin >> x;
        // Group indices into cycles: merge i with p[i] (converted to 0-indexed).
        DSU D(n);
        for (int i = 0; i < n; ++i)
            D.merge(i, p[i]-1);
        // cyc[x] = number of cycles of length x
        vector cyc(n+1, 0);
        for (auto g : D.group_up()) cyc[g.size()] += 1;
        const int mod = 998244353;
        ll ans = 0;
        for (int x = 1; x <= n; ++x) {
            if (!cyc[x]) continue;
            // Note: in this code, a is the swap cost and b is the modification cost
            // (B and A respectively in the explanation above).
            for (int a = 1; a <= n; ++a) {
                // for which b is a*(x-1) <= bx?
                // a*(x-1)/x <= b
                int lim = (1ll*a*(x-1) + x - 1) / x;
                // for lim <= b <= n, ans = a*(x-1)
                // for 1 <= b < lim, ans = b*x
                ll cost = 1ll*a*(x-1)*max(0, n - lim + 1);
                cost += 1ll*x*lim*(lim-1)/2;
                cost %= mod;
                // all cyc[x] cycles of length x contribute the same amount
                ans = (ans + cost * cyc[x]) % mod;
            }
        }
        cout << ans % mod << '\n';
    }
}