PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Setter: Tia Shi Wei
Tester: Harris Leung
Editorialist: Trung Dang
DIFFICULTY:
Medium
PREREQUISITES:
GCD, GCD Convolution
PROBLEM:
Chef is bored cooking his regular dishes, so he decided to cook permutations instead. Since this is his first time cooking permutations, he takes some help.
Chef has two permutations A and B, each of size N. Chef cooks a permutation C of size N^2 by doing the following:
for i = 1 to N:
for j = 1 to N:
C[i + (j-1) * N] = A[i] + (B[j] - 1) * N
Chef is allowed to make some (maybe 0) moves on permutation C. In one move, Chef can:
- Pick two indices 1 \leq i \lt j \leq N^2 and swap the elements C_i and C_j.
Find the minimum number of moves required to sort permutation C in ascending order.
EXPLANATION:
For any permutation P of size N, we can build a graph G_P by adding a directed edge i \to P_i for all 1 \le i \le N. Then the minimum number of swaps needed to sort the permutation is N minus the number of directed cycles in G_P (each swap changes the number of cycles by exactly one, and the sorted permutation has N cycles). However, C has N^2 elements, so we cannot afford to build its graph explicitly; instead, we need to count its cycles more cleverly.
Because of the formula in the given code, we can interpret the permutation C as an N \times N grid, where cell (i, j) (position i + (j - 1) \cdot N) holds C_{i, j} = A_i + (B_j - 1) \cdot N. That value is exactly the position of cell (A_i, B_j), so cell (i, j) has a directed edge to cell (A_i, B_j). Therefore, if node i is in a cycle of length l_A in G_A, and node j is in a cycle of length l_B in G_B, then the cycle containing cell (i, j) in G_C has length lcm(l_A, l_B).
Therefore, for any cycle C_A of length l_A in G_A and any cycle C_B of length l_B in G_B, the l_A \cdot l_B cells (i, j) with i \in C_A and j \in C_B are mapped among themselves, and every cycle they form has length lcm(l_A, l_B). Hence the number of cycles of G_C made up of these cells is \frac{l_A \cdot l_B}{lcm(l_A, l_B)} = gcd(l_A, l_B).
This leads to the following formula to count the number of cycles in G_C:
LA = array of lengths of all cycles in A
LB = array of lengths of all cycles in B
ans = 0
for u in LA:
for v in LB:
ans += gcd(u, v)
return ans
Looping naively like this is too slow: each permutation can have up to N cycles, so there can be O(N^2) pairs. However, there are several ways to fix this:
- We can do a direct GCD convolution. There are several ways to implement it; the following works well:
- Calculate FA, where FA_i is the number of elements in LA divisible by i.
- Similarly, calculate FB.
- Calculate FC where FC_i = FA_i \cdot FB_i. Intuitively, FC_i is the number of pairs (u, v) in (LA, LB) such that gcd(u, v) is divisible by i.
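- Calculate C such that C_i is the number of pairs (u, v) in (LA, LB) such that gcd(u, v) = i. We can calculate C from FC using inclusion-exclusion over multiples: C_i = FC_i - \sum_{k \ge 2} C_{k i}, computed for i from N down to 1. The answer is then N^2 - \sum_i i \cdot C_i.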
This takes O(N \log N).
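- Alternatively, notice that the cycle lengths of a permutation sum to N, so LA contains only O(\sqrt{N}) distinct values (and similarly LB). We can therefore loop over pairs of distinct values in (LA, LB), weighting each gcd by the product of their multiplicities, instead of over all pairs. This takes O(\sqrt{N} \cdot \sqrt{N} \cdot \log{N}) = O(N \log N).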
TIME COMPLEXITY:
Time complexity is O(N \log N) per test case.
SOLUTION:
Setter's Solution
// Hallelujah, praise the one who set me free
// Hallelujah, death has lost its grip on me
// You have broken every chain, There's salvation in your name
// Jesus Christ, my living hope
#include <bits/stdc++.h>
using namespace std;
// "Minimise to": replaces a with b when a > b.
// Returns true exactly when a was changed.
template <class T>
inline bool mnto(T& a, T b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}
// "Maximise to": replaces a with b when a < b.
// Returns true exactly when a was changed.
template <class T>
inline bool mxto(T& a, T b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Loop shorthands: REP runs i = s .. e-1 upward, RREP runs i = s .. e downward.
#define REP(i, s, e) for (int i = s; i < e; i++)
#define RREP(i, s, e) for (int i = s; i >= e; i--)
// Common type and helper abbreviations.
typedef long long ll;
typedef long double ld;
#define MP make_pair
#define FI first
#define SE second
typedef pair<int, int> ii;
typedef pair<ll, ll> pll;
#define MT make_tuple
typedef tuple<int, int, int> iii;
#define ALL(_a) _a.begin(), _a.end()
#define pb push_back
typedef vector<int> vi;
typedef vector<ll> vll;
typedef vector<ii> vii;
// Unless DEBUG is defined, `cerr << ...;` expands to `if (0) cerr << ...;` and is
// never executed. NOTE(review): inside an unbraced if/else the injected `if` can
// capture a following `else`.
#ifndef DEBUG
#define cerr if (0) cerr
#endif
// Sentinel constants; not referenced by the visible code.
#define INF 1000000005
#define LINF 1000000000000000005ll
// Capacity of the 1-indexed arrays below.
#define MAXN 200005
// State shared by main() across test cases.
int t;                       // number of test cases
int n;                       // size of the current permutations A and B
int a[MAXN], b[MAXN];        // permutations A and B, 1-indexed
int cnta[MAXN], cntb[MAXN];  // cnta[l] / cntb[l]: number of cycles of length l in A / B
bool vis[MAXN];              // visited flags used while walking cycles
vii va, vb;                  // (cycle length, multiplicity) for every length that occurs
int main() {
#ifndef DEBUG
ios::sync_with_stdio(0), cin.tie(0);
#endif
cin >> t;
while (t--) {
cin >> n;
va.clear(); vb.clear();
REP (i, 0, n + 1) {
cnta[i] = cntb[i] = 0;
vis[i] = 0;
}
REP (i, 1, n + 1) {
cin >> a[i];
}
REP (i, 1, n + 1) {
cin >> b[i];
}
REP (i, 1, n + 1) {
if (vis[i]) continue;
vis[i] = 1;
int u = i, l = 1;
while (a[u] != i) {
u = a[u];
vis[u] = 1;
l++;
}
cnta[l]++;
}
REP (i, 0, n + 1) {
vis[i] = 0;
}
REP (i, 1, n + 1) {
if (vis[i]) continue;
vis[i] = 1;
int u = i, l = 1;
while (b[u] != i) {
u = b[u];
vis[u] = 1;
l++;
}
cntb[l]++;
}
REP (i, 1, n + 1) {
if (cnta[i] == 0) continue;
va.pb(MP(i, cnta[i]));
}
REP (i, 1, n + 1) {
if (cntb[i] == 0) continue;
vb.pb(MP(i, cntb[i]));
}
ll sm = 0;
for (auto [x, ox] : va) {
for (auto [y, oy] : vb) {
int g = __gcd(x, y);
sm += (ll) g * ox * oy;
}
}
ll ans = (ll) n * n - sm;
cout << ans << '\n';
}
return 0;
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// NOTE: these object-like macros rewrite every later `fi`/`se` token in this file,
// including local variables with those names (a variable declared as `fi` is really
// named `first`). Avoid introducing new identifiers called `fi` or `se`.
#define fi first
#define se second
// Reads one decimal integer from stdin that must be immediately followed by the
// character `endd`, and asserts that it lies in [l, r]. Rules enforced with assert:
//  - at most one '-', and only before the first digit;
//  - at least one digit (an empty token or a lone "-" is rejected);
//  - no leading zeros and no "-0";
//  - at most 19 digits, and a 19-digit value must start with 1, so it fits in long long.
// Any other character, including EOF, is rejected.
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;           // number of digits read so far
    int first_digit=-1;  // value of the first digit, -1 until one has been read
    bool is_neg=false;
    while(true) {
        // Keep getchar()'s result as int so EOF cannot be confused with a data byte.
        int g=getchar();
        if(g=='-') {
            assert(first_digit==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            if(cnt==0) {
                first_digit=g-'0';
            }
            cnt++;
            assert(first_digit!=0 || cnt==1);
            assert(first_digit!=0 || is_neg==false);
            // Check the length limit before accumulating so x can never overflow.
            assert(!(cnt>19 || (cnt==19 && first_digit>1)));
            x=x*10+(g-'0');
        } else if(g!=EOF && static_cast<char>(g)==endd) {
            assert(cnt>0);
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd`, consumes
// the terminator, and asserts that the token length lies in [l, r].
// Reaching end-of-file before the terminator is a validation failure.
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        // getchar() returns int. Storing it in a char made `g != -1` always true where
        // char is unsigned, so EOF was never detected and the loop spun forever.
        int g=getchar();
        assert(g!=EOF);
        if(static_cast<char>(g)==endd) {
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
// Convenience readers with a fixed terminator: "Sp" tokens must be followed by a
// single space, "Ln" tokens by a newline.
long long readIntSp(long long l, long long r) {
    const char terminator = ' ';
    return readInt(l, r, terminator);
}
long long readIntLn(long long l, long long r) {
    const char terminator = '\n';
    return readInt(l, r, terminator);
}
string readStringLn(int l, int r) {
    const char terminator = '\n';
    return readString(l, r, terminator);
}
string readStringSp(int l, int r) {
    const char terminator = ' ';
    return readString(l, r, terminator);
}
// Asserts that the whole input has been consumed: the next read must report EOF.
void readEOF(){
    assert(EOF == getchar());
}
// Not referenced by the visible code (template leftover).
const ll mod=998244353;
// Array size: 1-indexed storage for n <= 200000.
const int N=2e5+1;
int n;  // size of the current test case's permutations
// Not referenced by the visible code.
const int iu=1e5;
int a[N],b[N];   // permutations A and B, 1-indexed
ll ca[N],cb[N];  // per-length cycle counts of A and B; solve() transforms them in place
bool vis[N];     // visited flags for the cycle walks
// Caps n via min(200000,tot) in solve(); presumably the budget for the sum of N over
// all test cases — NOTE(review): confirm against the problem constraints.
int tot=1e6;
int c[N],d[N];   // sorted copies of a and b, used to check both are permutations of 1..n
void solve(){
n=readInt(1,min(200000,tot),'\n');
for(int i=1; i<=n ;i++) ca[i]=cb[i]=0;
for(int i=1; i<=n ;i++){
if(i!=n) a[i]=readInt(1,n,' ');
else a[i]=readInt(1,n,'\n');
}
for(int i=1; i<=n ;i++){
if(i!=n) b[i]=readInt(1,n,' ');
else b[i]=readInt(1,n,'\n');
}
for(int i=1; i<=n ;i++){
c[i]=a[i],d[i]=b[i];
}
sort(c+1,c+n+1);sort(d+1,d+n+1);
for(int i=1; i<=n ;i++){
assert(c[i]==i);assert(d[i]==i);
}
for(int i=1; i<=n ;i++) vis[i]=false;
for(int i=1; i<=n ;i++){
if(vis[i]) continue;
int x=i;
int c=0;
while(!vis[x]){
vis[x]=true;
c++;
x=a[x];
}
ca[c]++;
}
for(int i=1; i<=n ;i++) vis[i]=false;
for(int i=1; i<=n ;i++){
if(vis[i]) continue;
int x=i;
int c=0;
while(!vis[x]){
vis[x]=true;
c++;
x=b[x];
}
cb[c]++;
}
for(int i=1; i<=n ;i++){
for(int j=2*i; j<=n ;j+=i){
ca[i]+=ca[j];cb[i]+=cb[j];
}
}
for(int i=1; i<=n ;i++) ca[i]*=cb[i];
for(int i=n; i>=1 ;i--){
for(int j=2*i; j<=n ;j+=i){
ca[i]-=ca[j];
}
}
ll ans=1LL*n*n;
for(int i=1; i<=n ;i++) ans-=ca[i]*i;
cout << ans << '\n';
}
// Reads the number of test cases, solves each one, then checks that the input
// contains nothing after the last test case.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    int t = readInt(1, 100000, '\n');
    while (t-- != 0) {
        solve();
    }
    readEOF();
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;
// Returns the cycle-length histogram of the 0-indexed permutation `vec`: entry l of
// the result (size vec.size() + 1) counts the cycles of length l.
// Side effect: every visited entry of `vec` is overwritten with -1, which marks it
// as already consumed; for a permutation, vec ends up filled with -1.
vector<int> decompose(vector<int> &vec) {
    vector<int> histogram(vec.size() + 1);
    const int sz = static_cast<int>(vec.size());
    for (int start = 0; start < sz; ++start) {
        if (vec[start] == -1) continue;
        int length = 0;
        int cur = start;
        while (vec[cur] != -1) {
            const int nxt = vec[cur];
            vec[cur] = -1;
            cur = nxt;
            ++length;
        }
        ++histogram[length];
    }
    return histogram;
}
// In-place transform over multiples on indices 1..vec.size()-1 (index 0 untouched).
//  inv == false: vec[d] becomes the sum of vec[m] over all multiples m of d.
//  inv == true : the inverse (Mobius over multiples), undoing the forward transform.
template<typename T>
void convolute(vector<T> &vec, bool inv) {
    const int limit = static_cast<int>(vec.size()) - 1;
    if (inv) {
        // Largest d first, so every vec[m] with m > d is already inverted when used.
        for (int d = limit; d >= 1; --d)
            for (int m = d + d; m <= limit; m += d)
                vec[d] -= vec[m];
        return;
    }
    // Smallest d first, so every vec[m] with m > d is still an original value when used.
    for (int d = 1; d <= limit; ++d)
        for (int m = d + d; m <= limit; m += d)
            vec[d] += vec[m];
}
// For each test case, prints N^2 minus the number of cycles of C, computed by a
// GCD convolution of the cycle-length histograms of A and B.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin >> t;
    while (t-- != 0) {
        int n;
        cin >> n;
        // Reads a permutation (converted to 0-indexed) and returns its cycle-length
        // histogram: hist[l] = number of cycles of length l.
        auto read_cycle_histogram = [&]() {
            vector<int> perm(n);
            for (int &v : perm) {
                cin >> v;
                --v;
            }
            return decompose(perm);
        };
        vector<int> a = read_cycle_histogram();
        vector<int> b = read_cycle_histogram();
        // Now a[d] / b[d] count cycles whose length is a multiple of d.
        convolute(a, false);
        convolute(b, false);
        // pairs[d]: (A-cycle, B-cycle) pairs with both lengths divisible by d; after the
        // inverse transform, pairs whose gcd is exactly d.
        vector<long long> pairs(n + 1);
        for (int d = 1; d <= n; ++d) {
            pairs[d] = 1LL * a[d] * b[d];
        }
        convolute(pairs, true);
        // A pair with gcd d yields d cycles of C.
        long long cycles = 0;
        for (int d = 1; d <= n; ++d) {
            cycles += pairs[d] * d;
        }
        cout << 1LL * n * n - cycles << '\n';
    }
}