MAGPER - Editorial

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Rami

Tester: Teja Vardhan Reddy

Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Observation, Segment Tree with Lazy Propagation.

PROBLEM:

You are given two permutations A and B, each of length N. Let a(x) denote the position p such that A_p = x, and let b(x) denote the position p such that B_p = x. The distance between A and B is defined as \sum_{i = 1}^{N} |a(i)-b(i)|.

There are Q updates on the permutations, each of one of the following types.

  • Shift: Cyclically shift permutation A to the left by z, i.e. move the first z elements to the end.
  • Swap: Swap the values at positions x and y in permutation B.

After each update, print the distance between A and B.

Important: This editorial assumes 0-based indexing.
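For reference, the distance for a given total shift can be computed directly in O(N). The brute-force sketch below uses the hypothetical name bruteDistance. It assumes values are 1..N and that, after a total left shift of s, position i of A holds A[(i + s) mod N]. It is only meant for checking faster approaches on small inputs.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Brute-force distance between A (cyclically shifted left by s) and B.
// Values are assumed to be 1..n; positions are 0-based.
long long bruteDistance(const std::vector<int>& A, const std::vector<int>& B, int s) {
    int n = static_cast<int>(A.size());
    std::vector<int> posA(n + 1), posB(n + 1);
    for (int i = 0; i < n; ++i) {
        posA[A[(i + s) % n]] = i;  // position i of the shifted A holds A[(i + s) mod n]
        posB[B[i]] = i;
    }
    long long total = 0;
    for (int v = 1; v <= n; ++v) total += std::abs(posA[v] - posB[v]);
    return total;
}
```

For example, with A = {1, 2, 3} and B = {3, 1, 2}, shifts 0, 1 and 2 give distances 4, 4 and 0.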

NOT SO QUICK EXPLANATION

  • Each value contributes to the answer independently of the other values, so we can add or remove the contribution of each value separately.

  • Build a segment tree whose i-th leaf stores the distance between the current permutation B and the i-th cyclic shift of permutation A.

  • For a value v, let x and y denote its positions in permutations A and B respectively. We consider three cases:

    1. x < y
      The contribution of v to each shift is as follows:
      • Cyclic shifts numbered in range [0, x] receive y-x, y-x+1, y-x+2 \ldots
      • Cyclic shifts numbered in range [x+1, x+n-y-1] receive n-y-1, n-y-2, n-y-3 \ldots
      • Cyclic shifts numbered in range [x+n-y, n-1] receive 0, 1, 2 \ldots
    2. x = y
      The contribution of v to each shift is as follows:
      • Cyclic shifts numbered in range [0, x] receive 0, 1, 2 \ldots
      • Cyclic shifts numbered in range [x+1, n-1] receive n-x-1, n-x-2, n-x-3 \ldots
    3. x > y
      The contribution of v to each shift is as follows:
      • Cyclic shifts numbered in range [0, x-y-1] receive x-y, x-y-1, x-y-2 \ldots
      • Cyclic shifts numbered in range [x-y, x] receive 0, 1, 2 \ldots
      • Cyclic shifts numbered in range [x+1, n-1] receive n-y-1, n-y-2, n-y-3 \ldots
  • For each Swap, first remove the contributions of the two swapped values, perform the swap, and then add their contributions back according to their new positions. For each Shift, we only need to update the index of the current shift of A. The answer is the value stored at that leaf.
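The case analysis can be checked mechanically. The sketch below uses the hypothetical names Piece, contributionPieces and directContribution. It encodes each case as at most three arithmetic progressions over ranges of shifts, using the same ranges as the editorialist's updateEffect. It also provides the direct formula |((x - k) mod N) - y| for the contribution at shift k, so the two can be compared.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Shifts k in [l, r] receive first + (k - l) * diff.
struct Piece {
    int l, r;
    long long first, diff;
};

// Pieces for one value at position x in A and y in B (0-based), following
// the three cases above; empty ranges (l > r) are skipped.
std::vector<Piece> contributionPieces(int x, int y, int n) {
    std::vector<Piece> out;
    auto add = [&](int l, int r, long long first, long long diff) {
        if (l <= r) out.push_back({l, r, first, diff});
    };
    if (x < y) {
        add(0, x, y - x, 1);
        add(x + 1, x + n - y - 1, n - y - 1, -1);
        add(x + n - y, n - 1, 0, 1);
    } else if (x == y) {
        add(0, x, 0, 1);
        add(x + 1, n - 1, n - x - 1, -1);
    } else {
        add(0, x - y - 1, x - y, -1);
        add(x - y, x, 0, 1);
        add(x + 1, n - 1, n - y - 1, -1);
    }
    return out;
}

// Direct formula: after k left shifts, the value sits at (x - k) mod n in A.
long long directContribution(int x, int y, int n, int k) {
    return std::abs(((x - k) % n + n) % n - y);
}
```

An exhaustive comparison over small N confirms that the pieces cover every shift exactly once and agree with the direct formula.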

DETAILED EXPLANATION

This section mostly explains the points of the Not So Quick Explanation and why they work, so don’t skip directly here. Trying to understand the explanation above is a useful exercise before we begin.

Firstly, by the definition of distance given in the problem, the contribution of a single value does not depend on the other elements. It depends only on N and on the positions of this value in both permutations. So, we can handle the contribution of each value separately and take the sum at the end.

Secondly, let us see how a cyclic shift affects the distance. Consider the following permutations with N = 5 (A on top, B below):

A: _ _ _ 1 _
B: _ 1 _ _ _

For value 1, a(1) = 3 and b(1) = 1. It contributes 2 to shift 0, 1 to shift 1 and 0 to shift 2, because so far the 1 in A keeps moving closer to the 1 in B. This is the case x = 3 > y = 1, covered by the first sub-point of the x > y case. This continues until v moves from position x to position y, which happens in shift number x-y.

But in the next shift, the distance increases, as 1 moves from position 1 to position 0. This continues until 1 reaches the start of the array, which happens in shift number x. This is covered by the second sub-point of the x > y case.

Finally, in shift x+1, the value v moves from the start to the end of the array, so its distance from position y becomes n-1-y. In each following shift it moves closer to y again, reducing the distance to n-y-2 and so on. This is covered by the third sub-point of the x > y case.

Hence, we have seen how a value v at positions x and y in permutations A and B respectively contributes to the distance for each shift of A when x > y. The cases x = y and x < y can be worked out similarly.
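For the example above (N = 5, x = 3, y = 1), the position of 1 in the k-th shift of A is (x - k) mod N, so its contribution is |((x - k) mod N) - y|. Tabulating this over all shifts shows the three pieces of the x > y case: shifts [0, 1], [2, 3] and [4, 4].

```latex
\begin{array}{c|ccccc}
\text{shift } k & 0 & 1 & 2 & 3 & 4 \\ \hline
(x - k) \bmod N & 3 & 2 & 1 & 0 & 4 \\
|((x - k) \bmod N) - y| & 2 & 1 & 0 & 1 & 3
\end{array}
```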

Hint for working out cases

Pay attention to the following shifts.

  • The shift where the value is moved from the first position to the last position of A.
  • The shift in which v has the same position in A as in B. Until this shift, the value was moving closer to position y in B; from the next shift on, it moves farther away.

These shifts act as split points. Between them, consecutive shifts receive values forming an arithmetic progression with common difference +1 or -1.

Now we know how each value contributes to the distance for each shift of A. So, for a swap, we can remove the contributions of the two values being swapped, perform the swap, and then add their contributions back based on their new positions. That is all we need for swaps.

Since we have the distances for all shifts, we only need to track Shift operations to know the current shift of A, and then print the distance stored for that shift.
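The swap and shift handling just described can be sketched with a plain array T, where T[k] is the distance for shift k. Each value's contribution is added or removed in O(N), so this version is far too slow for the real limits. It only illustrates the remove/swap/add order. NaiveShiftTable and its members are hypothetical names; values are assumed to be 1..N and positions 0-based.

```cpp
#include <cassert>
#include <cstdlib>
#include <utility>
#include <vector>

// Naive bookkeeping: T[k] holds the distance between B and the k-th left shift of A.
struct NaiveShiftTable {
    int n;
    std::vector<int> posA, posB, B;  // posA/posB are indexed by value (1..n)
    std::vector<long long> T;

    NaiveShiftTable(const std::vector<int>& A, const std::vector<int>& initialB)
        : n(static_cast<int>(A.size())), posA(n + 1), posB(n + 1), B(initialB), T(n, 0) {
        for (int i = 0; i < n; ++i) {
            posA[A[i]] = i;
            posB[B[i]] = i;
        }
        for (int v = 1; v <= n; ++v) apply(v, +1);
    }

    // Add (sign = +1) or remove (sign = -1) the contribution of value v to every shift.
    void apply(int v, int sign) {
        for (int k = 0; k < n; ++k) {
            int shiftedPos = ((posA[v] - k) % n + n) % n;
            T[k] += sign * std::abs(shiftedPos - posB[v]);
        }
    }

    // Swap positions p and q (0-based) of B: remove, swap, add back.
    void swapB(int p, int q) {
        int v1 = B[p], v2 = B[q];
        apply(v1, -1);
        apply(v2, -1);
        std::swap(B[p], B[q]);
        posB[v1] = q;
        posB[v2] = p;
        apply(v1, +1);
        apply(v2, +1);
    }
};
```

After a Shift, the answer is T[current shift]. The segment tree replaces the O(N) loop in apply with a few O(log N) range updates.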

Now, we need a data structure that supports the following operations over an array T of size N.

  • For every i in range [l, r], increase T_i by a+(i-l)*d, for specified l, r, a and d.
  • Return T_p for a specified p.

We can use a Segment Tree with Lazy Propagation, which is explained in many standard tutorials.

Lazy updates work here because two arithmetic-progression updates combine into one. Consider two updates covering the current node: the first adds a, a+d, a+2*d \ldots and the second adds b, b+e, b+2*e \ldots to consecutive positions in the node's range.

After both updates, the values are increased by a+b, (a+b)+(d+e), (a+b)+2*(d+e) \ldots, which is the same as a single update with a+b as the first term and d+e as the common difference. Hence, for each node, we can maintain two lazy values: the first term and the common difference. When a lazy tag (a, d) of a node covering [l, r] is pushed down, the left child [l, mid] receives (a, d) and the right child [mid+1, r] receives (a+(mid-l+1)*d, d).
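Here is a minimal sketch of such a tree, assuming 0-based indices and 64-bit values (the class name APSegTree is hypothetical). Each node keeps a lazy pair made of the first term at the node's left end and the common difference. Tags simply add together, and the right child's first term is shifted by (mid-lo+1) times the difference, as in the tester's solution. Only point queries are needed, so no range sums are stored.

```cpp
#include <cassert>
#include <vector>

class APSegTree {
    int n;
    std::vector<long long> first, diff;  // lazy tag of each node; a leaf's `first` is its value

    void addTag(int node, long long a, long long d) {
        first[node] += a;
        diff[node] += d;
    }

    // Push the node's tag to its children; the right child starts (mid - lo + 1) terms later.
    void push(int node, int lo, int mid) {
        addTag(2 * node, first[node], diff[node]);
        addTag(2 * node + 1, first[node] + (mid - lo + 1) * diff[node], diff[node]);
        first[node] = diff[node] = 0;
    }

    void update(int node, int lo, int hi, int l, int r, long long a, long long d) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) {
            addTag(node, a + (lo - l) * d, d);  // progression value at this node's left end
            return;
        }
        int mid = (lo + hi) / 2;
        push(node, lo, mid);
        update(2 * node, lo, mid, l, r, a, d);
        update(2 * node + 1, mid + 1, hi, l, r, a, d);
    }

    long long query(int node, int lo, int hi, int p) {
        if (lo == hi) return first[node];
        int mid = (lo + hi) / 2;
        push(node, lo, mid);
        return p <= mid ? query(2 * node, lo, mid, p) : query(2 * node + 1, mid + 1, hi, p);
    }

public:
    explicit APSegTree(int size) : n(size), first(4 * size, 0), diff(4 * size, 0) {}

    // T[i] += a + (i - l) * d for l <= i <= r (empty ranges are ignored).
    void addAP(int l, int r, long long a, long long d) {
        if (l <= r) update(1, 0, n - 1, l, r, a, d);
    }

    // Return T[p].
    long long get(int p) { return query(1, 0, n - 1, p); }
};
```

In the full solution, each range from the case analysis becomes one addAP call, with a and d negated when a contribution is removed. The answer after each query is get(current shift).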

The rest is implementation; you may refer to the solutions below (my solution is commented).

RELATED PROBLEMS

ADDMUL
SEGSQRSS
CF 446-C
CF 145-E

TIME COMPLEXITY

The time complexity is O((N+Q)*log(N)) per test case.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
#define ll  long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sc second
#define fr first
using namespace std;

const int N = 1e6+100;

int n,q,a[N],b[N],inA[N],inB[N];
ll dp[N*4],lazDw[N*4],nmDw[N*4],nmUp[N*4],lazUp[N*4];

void up(int p,int l,int h){
	int m = (l+h)/2;
	dp[2*p] += lazUp[p];
	lazUp[2*p] += lazUp[p];
	nmUp[2*p] += nmUp[p];

	dp[2*p+1] += lazUp[p] + nmUp[p]*(m-l+1);
	lazUp[2*p+1] += lazUp[p]+ nmUp[p]*(m-l+1);
	nmUp[2*p+1] += nmUp[p];

	lazUp[p] = nmUp[p] = 0;
}

void dw(int p,int l,int h){
	int m = (l+h)/2;
	dp[2*p] += lazDw[p];
	lazDw[2*p] += lazDw[p];
	nmDw[2*p] += nmDw[p];

	dp[2*p+1] += lazDw[p] - nmDw[p]*(m-l+1);
	lazDw[2*p+1] += lazDw[p]- nmDw[p]*(m-l+1);
	nmDw[2*p+1] += nmDw[p];

	lazDw[p] = nmDw[p] = 0;
}

void addUp(int a,int b,int x,int adOrDl,int l=0,int h=n-1,int p = 1){
	if(a == l && b == h){
	    dp[p] += x*adOrDl;
	    lazUp[p] += x*adOrDl;
	    nmUp[p] += adOrDl;
	    return ;
	}
	up(p,l,h);
	dw(p,l,h);
	int m = (l+h)/2;
	if(b <= m)
	    addUp(a,b,x,adOrDl,l,m,2*p);
	else if(a > m)
	    addUp(a,b,x,adOrDl,m+1,h,2*p+1);
	else{
	    addUp(a,m,x,adOrDl,l,m,2*p);
	    addUp(m+1,b,x+(m-a+1),adOrDl,m+1,h,2*p+1);
	}
}

void addDw(int a,int b,int x,int adOrDl,int l=0,int h=n-1,int p = 1){
	if(a == l && b == h){
	    dp[p] += x*adOrDl;
	    lazDw[p] += x*adOrDl;
	    nmDw[p] += adOrDl;
	    return ;
	}
	up(p,l,h);
	dw(p,l,h);
	int m = (l+h)/2;
	if(b <= m)
	    addDw(a,b,x,adOrDl,l,m,2*p);
	else if(a > m)
	    addDw(a,b,x,adOrDl,m+1,h,2*p+1);
	else{
	    addDw(a,m,x,adOrDl,l,m,2*p);
	    addDw(m+1,b,x-(m-a+1),adOrDl,m+1,h,2*p+1);
	}
}

ll cal(int in,int l=0,int h=n-1,int p = 1){
	if(h == l ){
	    return dp[p];
	}
	up(p,l,h);
	dw(p,l,h);
	int m = (l+h)/2;
	if(in <= m)
	    return cal(in,l,m,2*p);
	return cal(in,m+1,h,2*p+1);
}

void addVal(int x,int adOrDl){
	int ina = inA[x];
	int inb = inB[x];
	if(ina > inb){
	    int len = (ina - inb);
	    if(len)
	        addDw(0,len,len,adOrDl);
	    int st = len+1;
	    len = inb;
	    if(len)
	        addUp(st,st+len-1,1,adOrDl);
	    st += len;
	    len = n - inb-1;
	    if(st <= n-1)
	        addDw(st,n-1,len,adOrDl);
	}
	else{
	    int len = ina+1;
	    if(len)
	        addUp(0,len-1,(inb - ina),adOrDl);
	    int st = len;
	    len = n-inb-1;
	    if(len)
	        addDw(st,st+len-1,len,adOrDl);
	    st += len;
	    if(st <= n-1)
	        addUp(st,n-1,0,adOrDl);
	}
}

int main()  {
	int t;
	cin>>t;
	while(t--){
	    scanf("%d%d",&n,&q);
	    for(int i=0 ;i <=n*4 ; i++){
	        dp[i] = lazUp[i] = lazDw[i] = nmUp[i] = nmDw[i] = 0;
	    }
	    for(int i=0 ;i <n ;i ++){
	        scanf("%d",&a[i]);
	        inA[a[i]] = i;
	    }
	    for(int i=0 ;i <n ;i ++){
	        scanf("%d",&b[i]);
	        inB[b[i]] = i;
	    }
	    for(int i=1 ;i <=n ;i ++)
	        addVal(i,1);
	    int ty,x,y,sh = 0;
	    while(q--){
	       //printf("%lld\n",cal(sh));
	        scanf("%d%d",&ty,&x);
	        if(ty == 1){
	            sh += x;
	            sh %= n;
	        }
	        else{
	            scanf("%d",&y);
	            x--;
	            y--;
	            addVal(b[x],-1);
	            addVal(b[y],-1);
	            swap(b[x],b[y]);
	            swap(inB[b[x]],inB[b[y]]);
	            addVal(b[x],1);
	            addVal(b[y],1);
	        }
	        printf("%lld\n",cal(sh));
	    }
	}
	return 0;
}
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp> 
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std;
using namespace __gnu_pbds;
 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define sz(a) a.size()
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (1000*1000*1000+7)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
// find_by_order()  // order_of_key
typedef tree<
int,
null_type,
less<int>,
rb_tree_tag,
tree_order_statistics_node_update>
ordered_set;
 
#define int ll
int a[1234567],b[1234567],pos1[1234567],pos2[1234567];
int seg[4234567],lazy1[4234567],lazy2[4234567];
int n;
 
 
int build(int node,int s,int e){
	seg[node]=0;
	lazy1[node]=0;
	lazy2[node]=0;
	if(s==e)
		return 0;
	int mid=(s+e)/2;
	build(2*node,s,mid);
	build(2*node+1,mid+1,e);
	return 0;
}
int update(int node,int s,int e,int l,int r,int a,int d){
	int mid=(s+e)/2;
	if(lazy1[node] || lazy2[node]){
	
		if(s==e){
			seg[node]+=lazy1[node];
		}
		else{
			int mid=(s+e)/2;
			lazy1[2*node]+=lazy1[node];
			lazy2[2*node]+=lazy2[node];
			lazy1[2*node+1]+=lazy1[node]+(mid+1-s)*lazy2[node];
			lazy2[2*node+1]+=lazy2[node];
		}
		lazy1[node]=0;
		lazy2[node]=0;
	}
	if(l<=s && e<=r){
		lazy1[node]+=a+(s-l)*d;
		lazy2[node]+=d;
		return 0;
	}
	if(e<l || r<s){
		return 0;
	}
	update(2*node,s,mid,l,r,a,d);
	update(2*node+1,mid+1,e,l,r,a,d);
	return 0;
}
 
int query(int node,int s,int e,int pos){
	int mid=(s+e)/2;
	if(lazy1[node] || lazy2[node]){
	
		if(s==e){
			seg[node]+=lazy1[node];
		}
		else{
			int mid=(s+e)/2;
			lazy1[2*node]+=lazy1[node];
			lazy2[2*node]+=lazy2[node];
			lazy1[2*node+1]+=lazy1[node]+(mid+1-s)*lazy2[node];
			lazy2[2*node+1]+=lazy2[node];
		}
		lazy1[node]=0;
		lazy2[node]=0;
	}
	if(s==e){
		return seg[node];
	}
	if(pos<=mid){
		return query(2*node,s,mid,pos);
	}
	else{
		return query(2*node+1,mid+1,e,pos);
	}
}
int updateall(int val,int sig){
	if(pos1[val]<pos2[val]){
		update(1,0,n-1,0,pos1[val]-1,sig*(pos2[val]-pos1[val]),sig);
		if(pos2[val]!=n)
			update(1,0,n-1,pos1[val],n+pos1[val]-pos2[val]-1,sig*(n-pos2[val]),-1*sig);
		update(1,0,n-1,n+pos1[val]-pos2[val],n-1,0,sig);
	}
	else if(pos1[val]==pos2[val]){
		update(1,0,n-1,0,pos1[val]-1,sig*(pos2[val]-pos1[val]),sig);
		update(1,0,n-1,pos1[val],n-1,sig*(n-pos2[val]),-1*sig);		
	}
	else if(pos1[val]>pos2[val]){
		update(1,0,n-1,0,pos1[val]-pos2[val]-1,sig*(pos1[val]-pos2[val]),-1*sig);
		update(1,0,n-1,pos1[val]-pos2[val],pos1[val]-1,0,sig);
		if(pos1[val]!=n){
			update(1,0,n-1,pos1[val],n-1,sig*(n-pos2[val]),-1*sig);
		}
	}
	return 0;
 
}
signed main(){
	//std::ios::sync_with_stdio(false); cin.tie(NULL);
	int t;
	// cin>>t;
	scanf("%lld",&t);
	while(t--){
		int q;
		// cin>>n>>q;
		scanf("%lld",&n);
		scanf("%lld",&q);
		int i;
		f(i,1,n+1){
			// cin>>a[i];
			scanf("%lld",&a[i]);
			pos1[a[i]]=i;
		}
		f(i,1,n+1){
			// cin>>b[i];
			scanf("%lld",&b[i]);
			pos2[b[i]]=i;
		}
		build(1,0,n-1);
		f(i,1,n+1){
			updateall(i,1);
		}
		int typ,move=0;
		int x,y;
		rep(i,q){
			// cin>>typ;
			scanf("%lld",&typ);
			if(typ==1){
				// cin>>x;
				scanf("%lld",&x);
				move+=x;
				move%=n;
			}
			else{
				// cin>>x>>y;
				scanf("%lld",&x);
				scanf("%lld",&y);
				updateall(b[x],-1);
				updateall(b[y],-1);
				swap(b[x],b[y]);
				swap(pos2[b[x]],pos2[b[y]]);
				updateall(b[x],1);
				updateall(b[y],1);
			}
			//cout<<query(1,0,n-1,move)<<endl;
			printf("%lld\n",query(1,0,n-1,move));
 
		}
	}
	return 0;   
}
Editorialist's Solution (Commented)
import java.util.*;
import java.io.*;
import java.text.*;
class MAGPER{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), q = ni();
	    int[] a = new int[n], b = new int[n], indA = new int[n], indB = new int[n];
	    for(int i = 0; i< n; i++){
	        a[i] = ni()-1;
	        indA[a[i]] = i;
	    }
	    for(int i = 0; i< n; i++){
	        b[i] = ni()-1;
	        indB[b[i]] = i;
	    }
	    LazySegTree t = new LazySegTree(n);
	    //Calculation for initial permutatons
	    for(int i = 0; i< n; i++)
	        updateEffect(t, indA[i], indB[i], n, true);
	    int curShift = 0;
	    while(q-->0){
	        int ty = ni();
	        if(ty == 1){
	            curShift = (curShift+ni())%n;
	        }else{
	            int p1 = ni()-1, p2 = ni()-1;
	            int v1 = b[p1], v2 = b[p2];
	            //Removing effect before swap
	            updateEffect(t, indA[v1], indB[v1], n, false);
	            updateEffect(t, indA[v2], indB[v2], n, false);
	            //Swap
	            b[p1] = v2;b[p2] = v1;
	            indB[v1] = p2;indB[v2] = p1;
	            //Adding effect after the swap
	            updateEffect(t, indA[v1], indB[v1], n, true);
	            updateEffect(t, indA[v2], indB[v2], n, true);
	        }
	        //Printing answer of current shift
	        pn(t.query(curShift));
	    }
	}
	//x -> ind in A
	//y -> ind in B
	void updateEffect(LazySegTree t, int x, int y, int n, boolean add){
	    if(x<y){
	        t.update(0, x, y-x, 1, add);
	        t.update(x+1, x+n-y-1, n-y-1, -1, add);
	        t.update(x+n-y, n-1, 0, 1, add);
	    }else if(x == y){
	        t.update(0, x, 0, 1, add);
	        t.update(x+1, n-1, n-x-1, -1, add);
	    }else{
	        t.update(0, x-y-1, x-y, -1, add);
	        t.update(x-y, x, 0, 1, add);
	        t.update(x+1, n-1, n-y-1, -1, add);
	    }
	}
	class LazySegTree{
	    int m = 1;
	    long[] t, lazy[];
	    public LazySegTree(int n){
	        while(m<=n)m<<=1;
	        t = new long[m<<1];
	        lazy = new long[2][m<<1];
	        //lazy[0][i] denotes the constant term for lazy update for ith node
	        //lazy[1][i] denotes the common difference for lazy update for ith node
	    }
	    //Pushing lazy to it's children
	    void push(int i, int ll, int rr){
	        if(lazy[0][i] != 0 || lazy[1][i] != 0){
	            long sz = rr-ll+1;
	            t[i] += sz*lazy[0][i]+((sz*sz-sz)/2)*lazy[1][i];
	            if(ll != rr){
	                int mid = (ll+rr)/2;
	                lazy[0][i<<1] += lazy[0][i];
	                lazy[1][i<<1] += lazy[1][i];
	                lazy[0][i<<1|1] += lazy[0][i]+lazy[1][i]*(mid-ll+1); 
	    		//Since first term of right child is (mid-ll+2)th term of parent, so calculated first term of right child
	                lazy[1][i<<1|1] += lazy[1][i];
	            }
	            lazy[0][i] = lazy[1][i] = 0;
	        }
	    }
	    void update(int l, int r, long a, long d, boolean add){
	        if(l>r)return;//degenerate cases may arise, for that
	        // add tells whether we are adding or removing contribution.
	        if(!add){a*=-1;d*=-1;}
	        u(l, r, 0, m-1, 1, a, d);
	    }
	    void u(int l, int r, int ll, int rr, int i, long a, long d){
	        push(i, ll, rr);
	        if(l == ll && r == rr){
	            lazy[0][i] += a;
	            lazy[1][i] += d;
	            push(i, ll, rr);
	            return;
	        }
	        int mid = (ll+rr)/2;
	        if(r <= mid){
	            u(l, r, ll, mid, i<<1, a, d);
	            push(i<<1|1, mid+1, rr);
	        }
	        else if(l>mid){
	            push(i<<1, ll, mid);
	            u(l, r, mid+1, rr, i<<1|1, a, d);
	        }
	        else{
	            u(l, mid, ll, mid, i<<1, a, d);
	            u(mid+1, r, mid+1, rr, i<<1|1, a+(mid-l+1)*d, d);
	        }
	        lazy[0][i] = lazy[1][i] = 0;
	        t[i] = t[i<<1]+t[i<<1|1];
	    }
	    long query(int p){
	        return q(p, 0, m-1, 1);
	    }
	    long q(int p, int ll, int rr, int i){
	        push(i, ll, rr);
	        if(ll == p && p == rr)return t[i];
	        int mid = (ll+rr)/2;
	        if(p<=mid)return q(p, ll, mid, i<<1);
	        else return q(p, mid+1, rr, i<<1|1);
	    }
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new MAGPER().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach, even if it's the same. Suggestions are welcome, as always.

It can also be solved in O(n\sqrt{n}). First, calculate the answer for each shift. Then, when printing the answer after a query: if fewer than \sqrt{n} elements have been swapped since the answers were last calculated, the answer is the precomputed one plus the changes caused by the swapped elements (go through all of them and adjust the answer). Otherwise, recalculate the answer for each shift.
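A sketch of this idea makes a few assumptions: values are 1..N, positions are 0-based, and the threshold is \sqrt{N}. SqrtMagper and its members are hypothetical names. The full recalculation runs in O(N) by accumulating each value's arithmetic progressions in two difference arrays, using first + (k - l)*d = (first - d*l) + d*k.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <utility>
#include <vector>

// Answers for all shifts are rebuilt in O(N); between rebuilds only values moved
// by swaps ("dirty" values) are corrected when a query is answered.
class SqrtMagper {
    int n, limit;
    std::vector<int> posA, posB, B, basePosB;  // posA/posB/basePosB indexed by value 1..n
    std::vector<long long> base;               // base[k]: answer for shift k at the last rebuild
    std::vector<int> dirty;                    // values whose position in B may have changed
    std::vector<char> isDirty;

    // Contribution of value v at shift k if v sits at position pb in B.
    long long contrib(int v, int k, int pb) const {
        return std::abs(((posA[v] - k) % n + n) % n - pb);
    }

    // Recompute base[] from the current B with difference arrays c0 (constant) and c1 (slope).
    void rebuild() {
        std::vector<long long> c0(n + 1, 0), c1(n + 1, 0);
        auto addAP = [&](int l, int r, long long first, long long d) {
            if (l > r) return;
            c0[l] += first - d * l; c0[r + 1] -= first - d * l;
            c1[l] += d;             c1[r + 1] -= d;
        };
        for (int v = 1; v <= n; ++v) {
            int x = posA[v], y = posB[v];
            if (x < y) {
                addAP(0, x, y - x, 1);
                addAP(x + 1, x + n - y - 1, n - y - 1, -1);
                addAP(x + n - y, n - 1, 0, 1);
            } else if (x == y) {
                addAP(0, x, 0, 1);
                addAP(x + 1, n - 1, n - x - 1, -1);
            } else {
                addAP(0, x - y - 1, x - y, -1);
                addAP(x - y, x, 0, 1);
                addAP(x + 1, n - 1, n - y - 1, -1);
            }
        }
        long long s0 = 0, s1 = 0;
        for (int k = 0; k < n; ++k) {
            s0 += c0[k];
            s1 += c1[k];
            base[k] = s0 + s1 * k;
        }
        basePosB = posB;
        for (int v : dirty) isDirty[v] = 0;
        dirty.clear();
    }

public:
    SqrtMagper(const std::vector<int>& A, const std::vector<int>& initialB)
        : n(static_cast<int>(A.size())),
          limit(std::max(1, static_cast<int>(std::sqrt(static_cast<double>(n))))),
          posA(n + 1), posB(n + 1), B(initialB), basePosB(n + 1), base(n), isDirty(n + 1, 0) {
        for (int i = 0; i < n; ++i) { posA[A[i]] = i; posB[B[i]] = i; }
        rebuild();
    }

    // Swap positions p and q (0-based) of B; rebuild once too many values are dirty.
    void swapB(int p, int q) {
        int v1 = B[p], v2 = B[q];
        std::swap(B[p], B[q]);
        posB[v1] = q;
        posB[v2] = p;
        for (int v : {v1, v2})
            if (!isDirty[v]) { isDirty[v] = 1; dirty.push_back(v); }
        if (static_cast<int>(dirty.size()) > limit) rebuild();
    }

    // Answer for shift k: precomputed value plus corrections for dirty values.
    long long answer(int k) const {
        long long ans = base[k];
        for (int v : dirty) ans += contrib(v, k, posB[v]) - contrib(v, k, basePosB[v]);
        return ans;
    }
};
```

Each query costs O(\sqrt{N}) for the corrections, and a rebuild costs O(N) but happens at most once every \sqrt{N}/2 swaps.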