GCD, GCD Convolution


Chef is bored cooking his regular dishes, so he decided to cook permutations instead. Since this is his first time cooking permutations, he takes some help.

Chef has two permutations A and B, each of size N. Chef cooks a permutation C of size N^2 by doing the following:

for i = 1 to N:
    for j = 1 to N:
        C[i + (j-1) * N] = A[i] + (B[j] - 1) * N

Chef is allowed to make some (maybe 0) moves on permutation C. In one move, Chef can:

  • Pick two indices 1 \leq i \lt j \leq N^2 and swap the elements C_i and C_j.

Find the minimum number of moves required to sort permutation C in ascending order.


For any permutation P of size N, we can build a graph G_P by adding a directed edge i \to P_i for all 1 \le i \le N. Then the number of swaps needed to sort the permutation is N - C, where C is the number of directed cycles in G_P. However, the permutation C defined above has N^2 elements, so we cannot afford to build its graph explicitly; we need to count its cycles cleverly.

Because of the formula in the given code, we can interpret the permutation C as an N \times N square, where C_{i, j} = A_i + (B_j - 1) \cdot N. This means the cell (i, j) has a directed edge to the cell (A_i, B_j). Therefore, if node i is in a cycle with length l_A in G_A, and node j is in a cycle with length l_B in G_B, then the cycle length of cell (i, j) in the graph G_C is lcm(l_A, l_B).

Therefore, for any cycle C_A with length l_A in G_A and any cycle C_B with length l_B in G_B, there are l_A \cdot l_B cells involved in the final permutation C, where each cycle has length lcm(l_A, l_B), which means the number of cycles in G_C that involves both C_A and C_B is \frac{l_A \cdot l_B}{lcm(l_A, l_B)} = gcd(l_A, l_B).

This leads to the following formula to count the number of cycles in G_C:

LA = array of lengths of all cycles in A
LB = array of lengths of all cycles in B
ans = 0
for u in LA:
    for v in LB:
        ans += gcd(u, v)
return ans

Obviously we cannot loop naively like this, since there can be up to N cycles in each permutation. However, there are several ways to speed this up:

  • We can do direct GCD convolution. There are many optimizations to do this, but I find the following works:
    • Calculate FA, where FA_i is the number of elements in LA divisible by i.
    • Similarly, calculate FB.
    • Calculate FC where FC_i = FA_i \cdot FB_i. Intuitively, FC_i is the number of pairs (u, v) in (LA, LB) such that gcd(u, v) is divisible by i.
    • Calculate C such that C_i is the number of pairs (u, v) in (LA, LB) such that gcd(u, v) = i. We can calculate C from FC using inclusion-exclusion.

This takes O(N \log N).

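  • We notice that there are only O(\sqrt{N}) unique values in LA (and similarly in LB), because the cycle lengths sum to N. So we can loop over pairs of unique values in (LA, LB), weighting each pair by how many times the two values occur, instead of looping over all values. This takes O(\sqrt{N} \cdot \sqrt{N} \cdot \log{N}).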


Time complexity is O(N \log N) per test case.


Setter's Solution

// Hallelujah, praise the one who set me free
// Hallelujah, death has lost its grip on me
// You have broken every chain, There's salvation in your name
// Jesus Christ, my living hope
#include <bits/stdc++.h> 
using namespace std;

// Lowers `a` to `b` when `b` is strictly smaller; returns whether `a` changed.
template <class T>
inline bool mnto(T& a, T b) {
    if (!(a > b)) {
        return false;
    }
    a = b;
    return true;
}
// Raises `a` to `b` when `b` is strictly larger; returns whether `a` changed.
template <class T>
inline bool mxto(T& a, T b) {
    if (!(a < b)) {
        return false;
    }
    a = b;
    return true;
}
// Ascending half-open loop: i takes the values s, s+1, ..., e-1.
#define REP(i, s, e) for (int i = s; i < e; i++)
// Descending closed loop: i takes the values s, s-1, ..., e.
#define RREP(i, s, e) for (int i = s; i >= e; i--)
// Short aliases for the wide integer and floating-point types.
typedef long long ll;
typedef long double ld;
// Shorthands for make_pair and for the two members of a pair.
#define MP make_pair
#define FI first
#define SE second
typedef pair<int, int> ii;
typedef pair<ll, ll> pll;
// Shorthand for make_tuple.
#define MT make_tuple
typedef tuple<int, int, int> iii;
// Expands to the two arguments `begin, end` of a container, for algorithm calls.
#define ALL(_a) _a.begin(), _a.end()
#define pb push_back
// Vector aliases built on the element aliases above.
typedef vector<int> vi;
typedef vector<ll> vll;
typedef vector<ii> vii;

// Turn every `cerr << ...` statement into dead code unless the build defines
// DEBUG (the same switch main uses to guard its fast-I/O setup).
// The listing had a stray `#endif` with no opening directive, which is a hard
// preprocessor error; the guard is restored here so the directives balance.
#ifndef DEBUG
#define cerr if (0) cerr
#endif

// Large sentinel values for int and long long; not referenced in the visible code.
#define INF 1000000005
#define LINF 1000000000000000005ll
// Capacity of the global arrays below.
#define MAXN 200005

int t;                       // number of test cases, read in main
int n;                       // size of the permutations in the current test case
int a[MAXN], b[MAXN];        // the two input permutations, filled at indices 1..n
int cnta[MAXN], cntb[MAXN];  // per-test tallies, zeroed in main; nonzero entries become (i, cnt[i]) pairs
bool vis[MAXN];              // visited marks used while walking cycles
vii va, vb;                  // (i, cnta[i]) / (i, cntb[i]) pairs for the nonzero tallies

int main() {
#ifndef DEBUG
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> t;
    while (t--) {
        cin >> n;
        va.clear(); vb.clear();
        REP (i, 0, n + 1) {
            cnta[i] = cntb[i] = 0;
            vis[i] = 0;
        REP (i, 1, n + 1) {
            cin >> a[i];
        REP (i, 1, n + 1) {
            cin >> b[i];
        REP (i, 1, n + 1) {
            if (vis[i]) continue;
            vis[i] = 1;
            int u = i, l = 1;
            while (a[u] != i) {
                u = a[u];
                vis[u] = 1;
        REP (i, 0, n + 1) {
            vis[i] = 0;
        REP (i, 1, n + 1) {
            if (vis[i]) continue;
            vis[i] = 1;
            int u = i, l = 1;
            while (b[u] != i) {
                u = b[u];
                vis[u] = 1;
        REP (i, 1, n + 1) {
            if (cnta[i] == 0) continue;
            va.pb(MP(i, cnta[i]));
        REP (i, 1, n + 1) {
            if (cntb[i] == 0) continue;
            vb.pb(MP(i, cntb[i]));
        ll sm = 0;
        for (auto [x, ox] : va) {
            for (auto [y, oy] : vb) {
                int g = __gcd(x, y);
                sm += (ll) g * ox * oy;
        ll ans = (ll) n * n - sm;
        cout << ans << '\n';
    return 0;
Tester's Solution
using namespace std;
typedef long long ll;
#define fi first
#define se second
// Strict integer reader: pulls characters from stdin with getchar() and
// returns x once the terminator character `endd` is read.
// NOTE(review): this listing is truncated. Closing braces are missing and
// several branches have no visible body: nothing shown here updates x, cnt, fi
// or is_neg, and the bounds l and r are never referenced. Presumably the full
// version accumulates digits into x and asserts l <= x <= r -- confirm against
// the original source before using this function.
long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        // Minus-sign branch; its body is not visible here.
        if(g=='-') {
        // Digit branch.
        if('0'<=g&&g<='9') {
            // First-digit case; its body is not visible here.
            if(cnt==0) {
            // When fi is 0, require exactly one digit and no sign
            // (presumably fi holds the first digit -- TODO confirm).
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            // Reject more than 19 digits, or 19 digits with fi > 1.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            // Terminator reached; the negative-number handling is not visible here.
            if(is_neg) {
            return x;
        } else {
// Reads characters from stdin with getchar() up to the terminator `endd` and
// returns the string ret.
// NOTE(review): this listing is truncated. Closing braces are missing, nothing
// shown here appends to ret or updates cnt, and l and r are never referenced.
// Presumably the full version collects the characters and checks the length
// against [l, r] -- confirm against the original source.
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        // Terminator branch; its body is not visible here.
        if(g==endd) {
    return ret;
// Forwards to readInt with a space as the expected terminator character.
// l and r are passed through unchanged.
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
// Forwards to readInt with a newline as the expected terminator character.
// l and r are passed through unchanged.
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
// Forwards to readString with a newline as the expected terminator character.
// l and r are passed through unchanged.
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
// Forwards to readString with a space as the expected terminator character.
// l and r are passed through unchanged.
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}

// NOTE(review): only the opening line of readEOF survives in this listing; its
// body and closing brace are missing. Presumably it asserts that stdin has no
// further input -- confirm against the original source.
void readEOF(){
const ll mod=998244353;  // not referenced in the visible code
const int N=2e5+1;       // capacity of the global arrays below
int n;                   // permutation size of the current test case
const int iu=1e5;        // not referenced in the visible code
int a[N],b[N];           // the two input permutations, filled at indices 1..n in solve
ll ca[N],cb[N];          // per-index tallies for a and b; multiplied element-wise in solve
bool vis[N];             // visited marks, reset before each scan in solve
int tot=1e6;             // not referenced in the visible code; presumably a budget for the sum of n -- TODO confirm
int c[N],d[N];           // not referenced in the visible code (solve declares a local named c that shadows this array)
// Handles one test case: reads a and b, then prints n*n minus the weighted sum
// of ca[i]*i over i = 1..n.
// NOTE(review): this listing is truncated. Closing braces are missing, several
// loop bodies are absent, and no statement that reads n is visible. The
// per-loop notes below describe only what can be seen; the "presumably"
// remarks follow the editorial's GCD-convolution outline and must be confirmed
// against the original source.
void solve(){
	// Clear the tallies for indices 1..n.
	for(int i=1; i<=n ;i++) ca[i]=cb[i]=0;
	// Read a: values are space-separated, the last one is newline-terminated.
	for(int i=1; i<=n ;i++){
		if(i!=n) a[i]=readInt(1,n,' ');
		else a[i]=readInt(1,n,'\n');
	// Read b in the same format.
	for(int i=1; i<=n ;i++){
		if(i!=n) b[i]=readInt(1,n,' ');
		else b[i]=readInt(1,n,'\n');
	// Two loops whose bodies are not visible here.
	for(int i=1; i<=n ;i++){
	for(int i=1; i<=n ;i++){
	// First scan over unvisited indices; the rest of the loop body is missing
	// (presumably a cycle walk over a that tallies cycle lengths into ca).
	for(int i=1; i<=n ;i++) vis[i]=false;
	for(int i=1; i<=n ;i++){
		if(vis[i]) continue;
		int x=i;
		int c=0;
	// Second scan, same shape (presumably the cycle walk over b into cb).
	for(int i=1; i<=n ;i++) vis[i]=false;
	for(int i=1; i<=n ;i++){
		if(vis[i]) continue;
		int x=i;
		int c=0;
	// Ascending pass over proper multiples j of i; body not visible
	// (presumably accumulates ca[j] and cb[j] into ca[i] and cb[i]).
	for(int i=1; i<=n ;i++){
		for(int j=2*i; j<=n ;j+=i){
	// Element-wise product of the two tallies.
	for(int i=1; i<=n ;i++) ca[i]*=cb[i];
	// Descending pass over proper multiples j of i; body not visible
	// (presumably subtracts ca[j] from ca[i]).
	for(int i=n; i>=1 ;i--){
		for(int j=2*i; j<=n ;j+=i){
	// Answer: n*n minus the sum of ca[i]*i.
	ll ans=1LL*n*n;
	for(int i=1; i<=n ;i++) ans-=ca[i]*i;
	cout << ans << '\n';
// Tester's entry point: reads the newline-terminated test-case count through
// readInt (called with bounds 1 and 100000) and runs solve() once per case.
int main(){
	int t;t=readInt(1,100000,'\n');while(t--) solve();
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

// Counts the cycles of a 0-indexed permutation by length.
//
// vec: the permutation, with vec[i] in [0, vec.size()). It is consumed by the
//      call: every entry is overwritten with -1 as its cycle is walked, and
//      -1 doubles as the "already visited" mark.
// Returns a vector of size vec.size() + 1 whose entry k is the number of
// cycles of length k (entry 0 is always 0).
std::vector<int> decompose(std::vector<int> &vec) {
    const std::size_t n = vec.size();
    std::vector<int> ans(n + 1);
    for (std::size_t i = 0; i < n; i++) {
        if (vec[i] == -1) {
            continue;  // i belongs to a cycle that was already walked
        }
        // Follow i -> vec[i] -> ... until the walk returns to an erased entry,
        // which can only be the starting point of this cycle.
        int len = 0;
        int x = static_cast<int>(i);
        while (vec[x] != -1) {
            const int next = vec[x];
            vec[x] = -1;
            x = next;
            ++len;
        }
        ans[len]++;
    }
    return ans;
}

// In-place sum-over-multiples transform on indices 1..n, where
// n = vec.size() - 1. Index 0 is never touched.
//
// inv == false: vec[i] becomes the sum of the original vec[j] over all
//               multiples j of i (j <= n). Going upward in i is what keeps
//               vec[j] unmodified at the moment it is added.
// inv == true:  undoes the forward transform. Going downward in i means every
//               vec[j] with j > i has already been restored when it is
//               subtracted, which is the inclusion-exclusion step.
template<typename T>
void convolute(std::vector<T> &vec, bool inv) {
    if (vec.empty()) {
        return;  // avoids computing size() - 1 on an empty vector
    }
    const int n = static_cast<int>(vec.size()) - 1;
    if (!inv) {
        for (int i = 1; i <= n; i++) {
            for (int j = 2 * i; j <= n; j += i) {
                vec[i] += vec[j];
            }
        }
    } else {
        for (int i = n; i >= 1; i--) {
            for (int j = 2 * i; j <= n; j += i) {
                vec[i] -= vec[j];
            }
        }
    }
}

// Editorialist's solution entry point.
// For each test case: reads both permutations (converted to 0-indexed), turns
// each into a count of cycles per length, and combines them with a GCD
// convolution:
//   a[i], b[i]  -> number of cycles whose length is divisible by i
//   res[i]      -> a[i] * b[i], the pairs whose gcd is divisible by i
//   inverse     -> res[i] = the pairs whose gcd is exactly i
// The combined permutation then has sum(res[i] * i) cycles, and the answer is
// n*n minus that count.
int main() {
    // Only cin/cout are used, so untying them from C stdio is safe.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int t; std::cin >> t;
    while (t--) {
        int n; std::cin >> n;
        std::vector<int> a(n);
        for (int &v : a) {
            std::cin >> v; v--;
        }
        a = decompose(a);
        std::vector<int> b(n);
        for (int &v : b) {
            std::cin >> v; v--;
        }
        b = decompose(b);
        convolute(a, false); convolute(b, false);
        std::vector<long long> res(n + 1);
        for (int i = 1; i <= n; i++) {
            res[i] = 1LL * a[i] * b[i];
        }
        convolute(res, true);
        long long ans = 1LL * n * n;
        for (int i = 1; i <= n; i++) {
            ans -= res[i] * i;
        }
        std::cout << ans << '\n';
    }
    return 0;
}
Why is the length of the cycle containing cell (i, j) equal to lcm(l_A, l_B)?