COOKPERM - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Tia Shi Wei
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

Medium

PREREQUISITES:

GCD, GCD Convolution

PROBLEM:

Chef is bored cooking his regular dishes, so he decided to cook permutations instead. Since this is his first time cooking permutations, he takes some help.

Chef has two permutations A and B, each of size N. Chef cooks a permutation C of size N^2 by doing the following:

for i = 1 to N:
    for j = 1 to N:
        C[i + (j-1) * N] = A[i] + (B[j] - 1) * N

Chef is allowed to make some (maybe 0) moves on permutation C. In one move, Chef can:

  • Pick two indices 1 \leq i \lt j \leq N^2 and swap the elements C_i and C_j.

Find the minimum number of moves required to sort permutation C in ascending order.

EXPLANATION:

For any permutation P of size N, we can build a graph G_P by adding a directed edge i \to P_i for every 1 \le i \le N. The minimum number of swaps needed to sort P is N - c(P), where c(P) is the number of directed cycles in G_P. Each swap changes the number of cycles by exactly one, and swapping two elements of the same cycle splits it, so N - c(P) swaps are both necessary and sufficient. However, the permutation C defined above has N^2 elements, so we cannot build its graph explicitly; instead, we need to count its cycles cleverly.

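Because of the formula in the given code, we can interpret the permutation C as an N \times N grid. Cell (i, j) corresponds to index i + (j - 1) \cdot N and holds C_{i, j} = A_i + (B_j - 1) \cdot N. That value is exactly the index of cell (A_i, B_j), so cell (i, j) has a directed edge to cell (A_i, B_j). After k steps, cell (i, j) has therefore moved to (A^k(i), B^k(j)), where A^k means applying A k times. Suppose node i lies on a cycle of length l_A in G_A and node j lies on a cycle of length l_B in G_B. Then we are back at (i, j) exactly when both l_A and l_B divide k. Hence the cycle containing cell (i, j) in G_C has length lcm(l_A, l_B).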

Therefore, take any cycle C_A of length l_A in G_A and any cycle C_B of length l_B in G_B. The l_A \cdot l_B cells (i, j) with i \in C_A and j \in C_B only move among themselves, and they split into cycles of length lcm(l_A, l_B) each. So the number of cycles in G_C formed by these cells is \frac{l_A \cdot l_B}{lcm(l_A, l_B)} = gcd(l_A, l_B).

This leads to the following formula to count the number of cycles in G_C:

LA = array of lengths of all cycles in A
LB = array of lengths of all cycles in B
ans = 0
for u in LA:
    for v in LB:
        ans += gcd(u, v)
return ans

Each permutation can have up to N cycles (for example, the identity), so this double loop performs up to N^2 gcd computations, which is too slow. There are several ways around this:

  • We can do direct GCD convolution. There are many optimizations to do this, but I find the following works:
    • Calculate FA, where FA_i is the number of elements in LA divisible by i.
    • Similarly, calculate FB.
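    • Calculate FC, where FC_i = FA_i \cdot FB_i. Then FC_i is exactly the number of pairs (u, v) with u \in LA and v \in LB such that gcd(u, v) is divisible by i. This holds because i divides gcd(u, v) if and only if i divides both u and v.
    • Calculate G, where G_i is the number of pairs (u, v) in (LA, LB) such that gcd(u, v) = i. Since FC_i = \sum_{i \mid k} G_k, we can recover G from FC by inclusion-exclusion, processing i from N down to 1: G_i = FC_i - \sum_{k \in \{2i, 3i, \ldots\}, k \le N} G_k. The number of cycles in G_C is then \sum_i i \cdot G_i.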

Each step iterates over the multiples of every i \le N, so it takes O(N/1 + N/2 + \dots + N/N) = O(N \log N) time.

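  • Alternatively, note that LA contains only O(\sqrt{N}) distinct values (and similarly for LB). The cycle lengths sum to N, and k distinct lengths sum to at least 1 + 2 + \dots + k = k(k+1)/2. So we can loop over pairs of distinct values (u, v) in (LA, LB) instead of over all values. Each gcd(u, v) is weighted by the number of occurrences of u in LA times the number of occurrences of v in LB. This takes O(\sqrt{N} \cdot \sqrt{N} \cdot \log N) = O(N \log N) time.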

TIME COMPLEXITY:

Time complexity is O(N \log N) per test case.

SOLUTION:

Setter's Solution

// Hallelujah, praise the one who set me free
// Hallelujah, death has lost its grip on me
// You have broken every chain, There's salvation in your name
// Jesus Christ, my living hope
#include <bits/stdc++.h> 
using namespace std;

// Lowers `a` to `b` when `b` is strictly smaller (tested as `a > b`).
// Returns true iff `a` was changed.
template <class T>
inline bool mnto(T& a, T b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}
// Raises `a` to `b` when `b` is strictly larger (tested as `a < b`).
// Returns true iff `a` was changed.
template <class T>
inline bool mxto(T& a, T b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Competitive-programming shorthands. Macro arguments are parenthesised so
// that an expression argument such as REP(i, 0, x | y) keeps its meaning;
// unparenthesised, `i < x | y` would parse as `(i < x) | y`.
#define REP(i, s, e) for (int i = (s); i < (e); i++)
#define RREP(i, s, e) for (int i = (s); i >= (e); i--)
using ll = long long;
using ld = long double;
#define MP make_pair
#define FI first
#define SE second
using ii = pair<int, int>;
using pll = pair<ll, ll>;
#define MT make_tuple
using iii = tuple<int, int, int>;
#define ALL(_a) (_a).begin(), (_a).end()
#define pb push_back
using vi = vector<int>;
using vll = vector<ll>;
using vii = vector<ii>;

// Compile out debug logging unless DEBUG is defined. The `while (false)` form
// is used instead of `if (0)` to avoid the dangling-else trap: with
// `if (0) cerr`, the code `if (c) cerr << x; else f();` attaches the `else`
// to the hidden `if (0)` and runs f() exactly when c is true.
#ifndef DEBUG
#define cerr while (false) cerr
#endif

// Problem-wide constants as typed constexpr values instead of macros.
// INF / LINF are sentinel values (not referenced by this solution);
// MAXN is the capacity of the 1-indexed global arrays.
constexpr int INF = 1000000005;
constexpr long long LINF = 1000000000000000005LL;
constexpr int MAXN = 200005;

// Global scratch state for the setter's solution, reset per test case.
int t, n;                    // test-case count; current permutation size
int a[MAXN], b[MAXN];        // permutations A and B, 1-indexed
int cnta[MAXN], cntb[MAXN];  // cnta[l] / cntb[l]: number of cycles of length l
bool vis[MAXN];              // visited marks used while walking cycles
vii va, vb;                  // (length, count) pairs for the lengths that occur

int main() {
#ifndef DEBUG
    ios::sync_with_stdio(0), cin.tie(0);
#endif
    cin >> t;
    while (t--) {
        cin >> n;
        va.clear(); vb.clear();
        REP (i, 0, n + 1) {
            cnta[i] = cntb[i] = 0;
            vis[i] = 0;
        }
        REP (i, 1, n + 1) {
            cin >> a[i];
        }
        REP (i, 1, n + 1) {
            cin >> b[i];
        }
        REP (i, 1, n + 1) {
            if (vis[i]) continue;
            vis[i] = 1;
            int u = i, l = 1;
            while (a[u] != i) {
                u = a[u];
                vis[u] = 1;
                l++;
            }
            cnta[l]++;
        }
        REP (i, 0, n + 1) {
            vis[i] = 0;
        }
        REP (i, 1, n + 1) {
            if (vis[i]) continue;
            vis[i] = 1;
            int u = i, l = 1;
            while (b[u] != i) {
                u = b[u];
                vis[u] = 1;
                l++;
            }
            cntb[l]++;
        }
        REP (i, 1, n + 1) {
            if (cnta[i] == 0) continue;
            va.pb(MP(i, cnta[i]));
        }
        REP (i, 1, n + 1) {
            if (cntb[i] == 0) continue;
            vb.pb(MP(i, cntb[i]));
        }
        ll sm = 0;
        for (auto [x, ox] : va) {
            for (auto [y, oy] : vb) {
                int g = __gcd(x, y);
                sm += (ll) g * ox * oy;
            }
        }
        ll ans = (ll) n * n - sm;
        cout << ans << '\n';
    }
    return 0;
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
// Reads one integer token from stdin, terminated by exactly the character
// `endd`, and returns it after checking that it lies in [l, r].
// All checks are asserts, so they are inactive under NDEBUG:
//   * an optional single leading '-' followed by at least one digit;
//   * no leading zeros and no "-0";
//   * at most 19 digits, and a 19-digit value must start with '1', so the
//     accumulation below cannot overflow a long long;
//   * any other character, including EOF, is rejected.
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;      // number of digits read so far
    int fi=-1;      // first digit, or -1 before any digit has been read
    bool is_neg=false;
    while(true) {
        // Keep getchar()'s result in an int: a char cannot represent both
        // EOF and every byte value.
        int g=getchar();
        assert(g!=EOF);
        if(g=='-') {
            // '-' is allowed once, and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(static_cast<char>(g)==endd) {
            // Reject an empty token or a lone "-", which were previously
            // accepted and returned as 0.
            assert(cnt>0);
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd`,
// which is consumed. Asserts that EOF is not reached first and that the
// token length lies in [l, r].
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        // getchar() returns int: storing it in a char makes EOF collide with
        // byte 0xFF, and where char is unsigned the old `g != -1` check could
        // never fail, so the loop spun forever at EOF.
        int g=getchar();
        assert(g!=EOF);
        if(static_cast<char>(g)==endd) {
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
// Convenience wrappers: read a token terminated by a space (...Sp) or by a
// newline (...Ln), validating the integer's range or the string's length.
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }

// Asserts that stdin has been fully consumed. The read itself sits inside the
// assert, so it is skipped entirely when NDEBUG is defined.
void readEOF() {
    assert(EOF == getchar());
}
// Global state for the tester's solution.
constexpr ll mod = 998244353;  // not referenced in this solution
constexpr int N = 2e5 + 1;     // capacity of the 1-indexed arrays
int n;                         // size of the current test case
constexpr int iu = 1e5;        // not referenced in this solution
int a[N], b[N];                // permutations A and B (1-indexed)
ll ca[N], cb[N];               // cycle-length histograms, transformed in place
bool vis[N];                   // visited marks for the cycle walks
// Upper bound used when reading n; presumably meant as the remaining budget
// for the sum of n over all test cases. NOTE(review): confirm.
int tot = 1e6;
int c[N], d[N];                // sorted copies of A and B for the permutation check
void solve(){
	n=readInt(1,min(200000,tot),'\n');
	for(int i=1; i<=n ;i++) ca[i]=cb[i]=0;
	for(int i=1; i<=n ;i++){
		if(i!=n) a[i]=readInt(1,n,' ');
		else a[i]=readInt(1,n,'\n');
	}
	for(int i=1; i<=n ;i++){
		if(i!=n) b[i]=readInt(1,n,' ');
		else b[i]=readInt(1,n,'\n');
	}
	for(int i=1; i<=n ;i++){
		c[i]=a[i],d[i]=b[i];
	}
	sort(c+1,c+n+1);sort(d+1,d+n+1);
	for(int i=1; i<=n ;i++){
		assert(c[i]==i);assert(d[i]==i);
	}
	for(int i=1; i<=n ;i++) vis[i]=false;
	for(int i=1; i<=n ;i++){
		if(vis[i]) continue;
		int x=i;
		int c=0;
		while(!vis[x]){
			vis[x]=true;
			c++;
			x=a[x];
		}
		ca[c]++;
	}
	for(int i=1; i<=n ;i++) vis[i]=false;
	for(int i=1; i<=n ;i++){
		if(vis[i]) continue;
		int x=i;
		int c=0;
		while(!vis[x]){
			vis[x]=true;
			c++;
			x=b[x];
		}
		cb[c]++;
	}
	for(int i=1; i<=n ;i++){
		for(int j=2*i; j<=n ;j+=i){
			ca[i]+=ca[j];cb[i]+=cb[j];
		}
	}
	for(int i=1; i<=n ;i++) ca[i]*=cb[i];
	for(int i=n; i>=1 ;i--){
		for(int j=2*i; j<=n ;j+=i){
			ca[i]-=ca[j];
		}
	}
	ll ans=1LL*n*n;
	for(int i=1; i<=n ;i++) ans-=ca[i]*i;
	cout << ans << '\n';
}
// Validator-style driver: reads the number of test cases, solves each one,
// then asserts that no input remains.
int main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	int t=readInt(1,100000,'\n');
	for(; t!=0; --t){
		solve();
	}
	readEOF();
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

// Given a 0-indexed permutation `vec`, returns a histogram `h` of size
// vec.size() + 1 where h[len] is the number of cycles of length len.
// Destroys the input: every visited entry of `vec` is overwritten with -1.
vector<int> decompose(vector<int> &vec) {
    const size_t n = vec.size();
    vector<int> lengthCount(n + 1);
    for (size_t start = 0; start < n; ++start) {
        if (vec[start] == -1) continue;  // already consumed by an earlier cycle
        int len = 0;
        int cur = static_cast<int>(start);
        while (vec[cur] != -1) {
            const int next = vec[cur];
            vec[cur] = -1;
            cur = next;
            ++len;
        }
        ++lengthCount[len];
    }
    return lengthCount;
}

// In-place transform over multiples on vec[1..n], where n = vec.size() - 1
// and index 0 is ignored.
//   inv == false: vec[d] becomes the sum of the original vec[m] over all
//                 multiples m of d (m <= n).
//   inv == true : the inverse transform (inclusion-exclusion from large d down
//                 to small), recovering the original values.
template<typename T>
void convolute(vector<T> &vec, bool inv) {
    const int n = static_cast<int>(vec.size()) - 1;
    if (inv) {
        for (int d = n; d >= 1; --d)
            for (int m = d + d; m <= n; m += d) vec[d] -= vec[m];
        return;
    }
    for (int d = 1; d <= n; ++d)
        for (int m = d + d; m <= n; m += d) vec[d] += vec[m];
}

// For each test case: reads n and two permutations, turns each into a
// cycle-length histogram, and GCD-convolves the histograms to count the
// cycles of the cooked permutation. Prints n^2 minus that count.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int t;
    cin >> t;
    while (t--) {
        int n;
        cin >> n;
        // Reads one permutation (converted to 0-indexed) and returns its
        // cycle-length histogram.
        auto readHistogram = [&]() {
            vector<int> perm(n);
            for (int &v : perm) {
                cin >> v;
                --v;
            }
            return decompose(perm);
        };
        vector<int> a = readHistogram();
        vector<int> b = readHistogram();
        convolute(a, false);
        convolute(b, false);
        vector<long long> res(n + 1);
        for (int i = 1; i <= n; i++) res[i] = 1LL * a[i] * b[i];
        convolute(res, true);  // res[g] = number of cycle pairs with gcd exactly g
        long long cycles = 0;
        for (int g = 1; g <= n; ++g) cycles += res[g] * g;
        cout << 1LL * n * n - cycles << '\n';
    }
}
1 Like

Why is the length of the cycle containing cell (i, j) equal to lcm(l_A, l_B)?