NSTROT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: Hasin Rayhan Dewan Dhruboo
Tester: Teja Vardhan Reddy
Editorialist: Taranpreet Singh

DIFFICULTY:

Medium

PREREQUISITES:

Observation, implementation; the contribution technique and difference arrays are helpful.

PROBLEM:

Given two permutations $A$ and $B$, each of length $N$, and an integer $V$, find the value of result computed by the following pseudocode.

Pseudo Code
function Log2(x):
    if x is at most 1:
        return 0
    else:
        return Log2(floor(x / 2)) + 1

function F(A[1...N], B[1...N]):
    posA[i] = index j such that A[j] == i
    posB[i] = index j such that B[j] == i
    ret = 0
    for i = 1 to N:
        ret += Log2(absolute_value(posA[i] - posB[i]))
    return ret

result = 1
for i = 0 to V - 1:
    R = cyclic right shift of B by i
    result *= F(A, R)
result %= 998244353
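To check faster solutions on small inputs, here is a minimal C++ sketch that follows the pseudocode literally. The helper names (log2Floor, F, rotateRightOnce, bruteForce) are our own. The permutations are assumed to be stored 1-indexed in std::vector<int>, with index 0 unused. A cyclic right shift is taken to move the element at position j to j + 1, with position N wrapping to 1. This matches how the setter's solution treats rotations.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Log2 exactly as in the pseudocode: 0 for x <= 1, otherwise Log2(x / 2) + 1.
int log2Floor(int x) {
    return x <= 1 ? 0 : log2Floor(x / 2) + 1;
}

// F(A, B): sum over values i of Log2(|posA[i] - posB[i]|).
// A and B are 1-indexed permutations of 1..N stored with index 0 unused.
long long F(const std::vector<int>& A, const std::vector<int>& B) {
    assert(A.size() == B.size());
    int n = static_cast<int>(A.size()) - 1;
    std::vector<int> posA(n + 1), posB(n + 1);
    for (int j = 1; j <= n; ++j) {
        posA[A[j]] = j;
        posB[B[j]] = j;
    }
    long long ret = 0;
    for (int i = 1; i <= n; ++i) ret += log2Floor(std::abs(posA[i] - posB[i]));
    return ret;
}

// One cyclic right shift: position j moves to j + 1, position N wraps to 1.
std::vector<int> rotateRightOnce(const std::vector<int>& B) {
    int n = static_cast<int>(B.size()) - 1;
    std::vector<int> R(n + 1, 0);
    for (int j = 1; j <= n; ++j) R[j % n + 1] = B[j];
    return R;
}

// result = product of F(A, B^i) for i = 0..V-1, modulo 998244353.
// Recomputes F from scratch for every shift, so this is only for small tests.
long long bruteForce(const std::vector<int>& A, std::vector<int> B, int V) {
    const long long MOD = 998244353;
    long long result = 1;
    for (int i = 0; i < V; ++i) {
        result = result * (F(A, B) % MOD) % MOD;
        B = rotateRightOnce(B);  // B now holds B^(i+1)
    }
    return result;
}
```

The example used later in this editorial is A = 1 4 2 3 5 6 and B = 5 1 2 4 3 6. For it, F(A, B) is 3, and bruteForce(A, B, 5) is 3 * 6 * 8 * 5 * 4 = 2880.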

QUICK EXPLANATION

  • Let $S_i = F(A, B^i)$, where $B^i$ denotes the $i$-th cyclic right shift of $B$, for each $0 \leq i < V$. The answer is the product of all $S_i$ modulo 998244353.
  • For each value, we add its contribution to every cyclic shift. Doing this naively, one value and one shift at a time, gives an $O(N \cdot V)$ solution.
  • To speed this up, notice that $\log(x)$ for $0 \leq x \leq N$ changes value only $O(\log N)$ times. Hence the contribution of each value can be applied with $O(\log N)$ range-addition updates.

EXPLANATION

First of all, look at the function F. The term contributed by value $i$ depends only on the positions of $i$ in the two permutations, so it is independent of the contributions of all other values.

Hence, we can calculate the contribution of each value to each shift separately and then combine the results.

Let $S_i = F(A, B^i)$, where $B^i$ is the $i$-th cyclic right shift of $B$.

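Now, we need to find the contribution of a value $v$ (a single value, not to be confused with the shift count $V$). Let $P_v$ denote the position of $v$ in $A$ and $Q_v$ the position of $v$ in the given $B$. Each right shift moves $v$ one position to the right, wrapping from $N$ back to $1$.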
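For reference, here is the same setup written as formulas. The symbol $\mathrm{pos}_i(v)$ is a name introduced here for the position of $v$ in $B^i$:

$$\mathrm{pos}_i(v) = \big((Q_v - 1 + i) \bmod N\big) + 1$$

$$S_i = \sum_{v=1}^{N} \mathrm{Log2}\big(|P_v - \mathrm{pos}_i(v)|\big), \qquad \text{result} = \Big(\prod_{i=0}^{V-1} S_i\Big) \bmod 998244353$$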

Three cases arise. In each case, we stop as soon as all $V$ shifts are used up:

  • $P_v < Q_v$
    In this case, value $v$ first moves from position $Q_v$ to $N$, then from position $1$ to $P_v$, and then from $P_v+1$ to $Q_v-1$.
  • $P_v = Q_v$
    In this case, value $v$ first moves from position $Q_v$ to $N$, and then from position $1$ to $Q_v-1$.
  • $P_v > Q_v$
    In this case, value $v$ first moves from position $Q_v$ to $P_v-1$, then from $P_v$ to $N$, and then from $1$ to $Q_v-1$.
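The three cases can be turned into code roughly as follows. This is a sketch under our own naming: Run, monotoneRuns and directDist are hypothetical helpers, not taken from the reference solutions. The sketch lists the runs for all N shifts, and a caller would clip them to the first V. directDist recomputes the distance from scratch so that the decomposition can be checked.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// A stretch of consecutive shifts during which the distance
// |P_v - position of v| changes by the same step (+1 or -1) per shift.
struct Run {
    int firstShift;  // shift index where the run starts
    int length;      // number of consecutive shifts in the run
    int firstDist;   // distance at shift firstShift
    int step;        // +1 if the distance grows, -1 if it shrinks
};

// Splits shifts 0..N-1 into the monotone runs of the three cases.
// P = position of v in A, Q = position of v in B (both 1-indexed).
std::vector<Run> monotoneRuns(int P, int Q, int N) {
    assert(1 <= P && P <= N && 1 <= Q && Q <= N);
    std::vector<Run> runs;
    int shift = 0;
    // Appends a non-empty run and advances the shift counter.
    auto push = [&](int len, int dist, int step) {
        if (len > 0) runs.push_back({shift, len, dist, step});
        shift += len;
    };
    if (P < Q) {
        push(N - Q + 1, Q - P, +1);  // positions Q..N
        push(P, P - 1, -1);          // positions 1..P
        push(Q - P - 1, 1, +1);      // positions P+1..Q-1
    } else if (P == Q) {
        push(N - Q + 1, 0, +1);      // positions Q..N
        push(Q - 1, P - 1, -1);      // positions 1..Q-1
    } else {
        push(P - Q, P - Q, -1);      // positions Q..P-1
        push(N - P + 1, 0, +1);      // positions P..N
        push(Q - 1, P - 1, -1);      // positions 1..Q-1
    }
    return runs;
}

// Distance of v at shift i computed directly (right shift: j -> j + 1, N -> 1).
int directDist(int P, int Q, int N, int i) {
    int pos = (Q - 1 + i) % N + 1;
    return std::abs(P - pos);
}
```

For value 4 in the example (P = 2, Q = 4, N = 6), it returns three runs, given as (first shift, length, first distance, step):
- (0, 3, 2, +1)
- (3, 2, 1, -1)
- (5, 1, 1, +1)

These match the three bullets of the worked example.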

Consider the example permutations

A = 1 4 2 3 5 6
B = 5 1 2 4 3 6

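Let’s find the contribution of value 4. Here $P_4 = 2$ and $Q_4 = 4$, so the first case applies. Below, $\log$ denotes the Log2 function from the pseudocode.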

  • It moves from position 4 to 6. It contributes $\log(|4-2|)$ to $S_0$, $\log(|5-2|)$ to $S_1$ and $\log(|6-2|)$ to $S_2$.
  • It moves from position 1 to 2. It contributes $\log(|1-2|)$ to $S_3$ and $\log(|2-2|)$ to $S_4$.
  • It moves from position 3 to 3. It contributes $\log(|3-2|)$ to $S_5$.

Performing the above process naively, one value and one shift at a time, gives an $O(N \cdot V)$ solution (see the commented-out loops in the editorialist's solution).
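A sketch of that naive contribution approach in C++. It assumes A and B are stored 1-indexed in std::vector<int> with index 0 unused. naiveShiftSums is a hypothetical name, and it returns $S_0, \ldots, S_{V-1}$ rather than their product.

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>

// Log2 as defined in the pseudocode.
int log2Floor(int x) { return x <= 1 ? 0 : log2Floor(x / 2) + 1; }

// Naive contribution approach: for every value v and every shift i < V,
// add Log2(|P_v - pos_i(v)|) to S_i directly. O(N * V) time overall.
std::vector<long long> naiveShiftSums(const std::vector<int>& A, const std::vector<int>& B, int V) {
    assert(V >= 1 && A.size() == B.size());
    int n = static_cast<int>(A.size()) - 1;
    std::vector<int> P(n + 1), Q(n + 1);
    for (int j = 1; j <= n; ++j) { P[A[j]] = j; Q[B[j]] = j; }
    std::vector<long long> S(V, 0);
    for (int v = 1; v <= n; ++v) {
        for (int i = 0; i < V; ++i) {
            int pos = (Q[v] - 1 + i) % n + 1;  // position of v after i right shifts
            S[i] += log2Floor(std::abs(P[v] - pos));
        }
    }
    return S;
}
```

On the example permutations with V = 6, it yields S = 3, 6, 8, 5, 4, 0.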

You might be wondering why I chose the above ranges.
Within each of them, the distance $|P_v - \text{position of } v|$ either increases by exactly 1 per shift throughout the range, or decreases by exactly 1 per shift throughout the range.

Now observe how $\log(x)$ behaves as $x$ increases: its value changes only when $x$ reaches a power of 2.
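Stated precisely, this follows directly from the recursive definition of Log2 in the pseudocode:

$$\mathrm{Log2}(0) = \mathrm{Log2}(1) = 0, \qquad \mathrm{Log2}(x) = k \iff 2^k \le x \le 2^{k+1} - 1 \quad (k \ge 1)$$

So for $0 \le x \le N$, Log2 takes only $\lfloor \log_2 N \rfloor + 1$ distinct values. For example, $x = 0, 1, \ldots, 9$ gives $0, 0, 1, 1, 2, 2, 2, 2, 3, 3$.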

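So, inside each range, consecutive shifts whose distances lie between the same two consecutive powers of 2 all receive the same increment. The corresponding $S_i$ can therefore be updated with a single range addition.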

Considering our example again, in the first range, $S_0$ increases by $\log(2)$, $S_1$ by $\log(3)$ and $S_2$ by $\log(4)$. Since $\log(2) = \log(3) = 1$ and $\log(4) = 2$, we can further split this range $[0, 2]$ into $[0, 1]$ and $[2, 2]$. All shifts within each sub-range are then increased by the same value.
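One way to do the splitting, sketched in C++ with hypothetical names (Block, splitRun):
- Walk along a run whose distances move by step = +1 or -1.
- At each point, jump to the furthest distance in that direction that still has the same Log2 value.

This sketch computes that jump with bit shifts instead of the precomputed prev/nxt tables of the editorialist's solution.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Log2 as defined in the pseudocode.
int log2Floor(int x) { return x <= 1 ? 0 : log2Floor(x / 2) + 1; }

// A block of consecutive shifts that all receive the same Log2 value.
struct Block {
    int offset;  // offset from the start of the run
    int length;  // number of shifts in the block
    int value;   // Log2 of the distance, identical across the block
};

// Splits a run with distances firstDist, firstDist + step, ... (len terms)
// into maximal blocks of equal Log2 value.
std::vector<Block> splitRun(int firstDist, int step, int len) {
    assert(step == 1 || step == -1);
    std::vector<Block> blocks;
    int off = 0;
    while (off < len) {
        int d = firstDist + step * off;  // current distance
        int k = log2Floor(d);
        // Furthest distance in the walking direction with Log2 == k.
        int lastSame;
        if (step == 1) lastSame = (d <= 1) ? 1 : (1 << (k + 1)) - 1;
        else           lastSame = (d <= 1) ? 0 : (1 << k);
        int blockLen = std::min(len - off, step * (lastSame - d) + 1);
        blocks.push_back({off, blockLen, k});
        off += blockLen;
    }
    return blocks;
}
```

splitRun(2, +1, 3) reproduces the split of the example: blocks of lengths 2 and 1 with values 1 and 2. Each run produces at most about $\log_2 N + 2$ blocks.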

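For $0 \leq x \leq N$, $\log(x)$ takes at most $\lfloor \log_2 N \rfloor + 1$ distinct values. So each of the (at most 3) ranges of a value splits into $O(\log N)$ sub-ranges, and each value needs at most about $3 \log N$ range-addition updates. Each update takes $O(1)$ with a difference array.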

In the end, we recover $S_0, \ldots, S_{V-1}$ from the difference array with prefix sums and take their product modulo 998244353.
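Putting the pieces together, here is a compact C++ sketch. It combines three things: the three-case decomposition into monotone runs, the splitting of each run into blocks of equal log value, and a difference array. It is our own arrangement, not a transcription of any of the reference solutions. addRun, fastShiftSums and productMod are hypothetical names. It assumes $V \leq N$, which the editorialist's hold(st == v) check also implies.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Log2 as defined in the pseudocode.
int log2Floor(int x) { return x <= 1 ? 0 : log2Floor(x / 2) + 1; }

// Adds Log2(firstDist + step * t) to S[start + t] for t = 0..len-1 (step = +1 or -1),
// ignoring shifts >= V. Each maximal block of equal Log2 values costs one
// O(1) range addition on the difference array diff (size V + 1).
void addRun(std::vector<long long>& diff, int V, int start, int len, int firstDist, int step) {
    len = std::min(len, V - start);
    for (int off = 0; off < len;) {
        int d = firstDist + step * off;  // distance at shift start + off
        int k = log2Floor(d);
        // Furthest distance in the walking direction that still has Log2 == k.
        int lastSame;
        if (step == 1) lastSame = (d <= 1) ? 1 : (1 << (k + 1)) - 1;
        else           lastSame = (d <= 1) ? 0 : (1 << k);
        int blockLen = std::min(len - off, step * (lastSame - d) + 1);
        diff[start + off] += k;             // S[start+off .. start+off+blockLen-1] += k
        diff[start + off + blockLen] -= k;  // index <= V, so it fits in diff
        off += blockLen;
    }
}

// Returns S_0..S_{V-1}. A and B are 1-indexed permutations (index 0 unused).
std::vector<long long> fastShiftSums(const std::vector<int>& A, const std::vector<int>& B, int V) {
    int n = static_cast<int>(A.size()) - 1;
    assert(1 <= V && V <= n);
    std::vector<int> P(n + 1), Q(n + 1);
    for (int j = 1; j <= n; ++j) { P[A[j]] = j; Q[B[j]] = j; }
    std::vector<long long> diff(V + 1, 0);
    for (int v = 1; v <= n; ++v) {
        int p = P[v], q = Q[v], s = 0;  // s = first shift of the next run
        auto run = [&](int len, int dist, int step) {
            addRun(diff, V, s, len, dist, step);
            s += len;
        };
        if (p < q) {         // positions q..n, 1..p, p+1..q-1
            run(n - q + 1, q - p, +1); run(p, p - 1, -1); run(q - p - 1, 1, +1);
        } else if (p == q) { // positions q..n, 1..q-1
            run(n - q + 1, 0, +1); run(q - 1, p - 1, -1);
        } else {             // positions q..p-1, p..n, 1..q-1
            run(p - q, p - q, -1); run(n - p + 1, 0, +1); run(q - 1, p - 1, -1);
        }
    }
    std::vector<long long> S(V);
    long long acc = 0;
    for (int i = 0; i < V; ++i) { acc += diff[i]; S[i] = acc; }  // prefix sums give S_i
    return S;
}

// Product of all S_i modulo 998244353.
long long productMod(const std::vector<long long>& S) {
    const long long MOD = 998244353;
    long long result = 1;
    for (long long s : S) result = result * (s % MOD) % MOD;
    return result;
}
```

On the example permutations, fastShiftSums(A, B, 6) gives 3, 6, 8, 5, 4, 0, the same values as evaluating each F(A, B^i) directly. Runs that start at or beyond shift V are clipped to zero length inside addRun, so no extra bookkeeping is needed when V < N.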

For the splitting process, it is helpful to precompute, for every $x$, the largest power of 2 not exceeding $x$ and the smallest power of 2 not less than $x$ (the prev and nxt arrays in the editorialist's solution).

If anything is still unclear, refer to the solutions below.

TIME COMPLEXITY

The time complexity is O(N*log(N)) per test case.

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
const int mod = 998244353;
const int maxn = 100009;
inline int add(int a, int b) {return (a + b) >= mod ? a + b - mod : a + b;}
inline int sub(int a, int b) {return add(a, mod - b);}
inline int mul(int a, int b) {return (a * 1LL * b) % mod;}

int n, m;
int ara1[maxn], ara2[maxn];
int posP[maxn], posQ[maxn];
int perPosMax[maxn];

int perRot[maxn];
int logval[maxn];

int whichRot(int L, int R, int x)
{
	if(R >= L){
	    if(x < L || x > R) return -1;
	    return x - L;
	}
	else{
	    if(x > R && x < L) return -1;
	    if(x >= L) return x - L;
	    else return x + n - L;
	}
}

int main()
{
	int t, cs = 1;
	cin >> t;
	logval[0] = logval[1] = 0;
	for(int i = 2; i < maxn; i++) logval[i] = logval[i / 2] + 1;

	while(t--){
	    scanf("%d %d", &n, &m);
	    for(int i = 1; i <= n; i++) scanf("%d", &ara1[i]), posP[ara1[i]] = i;
	    for(int i = 1; i <= n; i++) scanf("%d", &ara2[i]), posQ[ara2[i]] = i;
	    memset(perRot, 0, sizeof(perRot));

	    for(int i = 1; i <= n; i++){
	        perPosMax[i] = (posQ[ara1[i]] + m - 1) % n;
	        if(perPosMax[i] == 0) perPosMax[i] = n;

	        int L = posQ[ara1[i]], R = perPosMax[i];

	        for(int j = 2, sc = 1; j <= n; j = j * 2, sc++){
	            int k = i + j;
	            if(k > n) break;
	            int rot = whichRot(L, R, k);
//                if(rot == 1) cout << i << ' ' << j << endl;
	            perRot[rot]++;
	        }

	        for(int j = 2, sc = 1; j <= n; j = j * 2, sc++){
	            int k = i - j + 1;
	            if(k <= 1) break;
	            int rot = whichRot(L, R, k);
//                if(rot == 1) cout << "x : " << i << ' ' << j << endl;
	            perRot[rot]--;
	        }

	    }
	    int curSc = 0;
	    int ans = 1;
	    for(int i = 1; i <= n; i++) curSc = add(curSc, logval[abs(posP[i] - posQ[i])]);
	    ans = curSc;
	    for(int i = 1; i < m; i++){
	        curSc += perRot[i];
	        int lst = n - i + 1;
	        int lstval = ara2[lst];
	        curSc = sub(curSc, logval[n - posP[lstval]]);
	        curSc = add(curSc, logval[posP[lstval] - 1]);
	        ans = mul(ans, curSc);
	    }
	    printf("%d\n", ans);
	}

	return 0;
}
/*

3
8
4 1 7 8 2 5 6 3
8 5 2 6 7 3 1 4

8
4 1 7 8 2 5 6 3
8 5 2 4 7 3 1 6

8
4 1 2 6 3 8 7 5
3 5 2 8 6 7 4 1

17 8 13 12 19 20 23 28
1
4 2
4 1 2 3
3 2 4 1

*/
Tester's Solution
//teja349
#include <bits/stdc++.h>
#include <vector>
#include <set>
#include <map>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <climits>
#include <utility>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <iomanip>
//setbase - cout << setbase (16); cout << 100 << endl; Prints 64
//setfill -   cout << setfill ('x') << setw (5); cout << 77 << endl; prints xxx77
//setprecision - cout << setprecision (14) << f << endl; Prints x.xxxx
//cout.precision(x)  cout<<fixed<<val;  // prints x digits after decimal in val
 
using namespace std; 
#define f(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) f(i,0,n)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define pb push_back
#define mp make_pair
#define vi vector< int >
#define vl vector< ll >
#define ss second
#define ff first
#define ll long long
#define pii pair< int,int >
#define pll pair< ll,ll >
#define inf (1000*1000*1000+5)
#define all(a) a.begin(),a.end()
#define tri pair<int,pii>
#define vii vector<pii>
#define vll vector<pll>
#define viii vector<tri>
#define mod (998244353)
#define pqueue priority_queue< int >
#define pdqueue priority_queue< int,vi ,greater< int > >
#define flush fflush(stdout) 
#define primeDEN 727999983
 

int arr[123456];
int p[123456],q[123456],invq[123456],invp[123456];
int gg[123456];
int applypos(int l,int r,int st){
	r++;
	int val,nexti,en;
	while(l<r){
		val=gg[st];
		nexti=(1<<(val+1));
		en=min(l+nexti-st,r);
		arr[l]+=val;
		arr[en]-=val;
		l=en;
		st=nexti;
	}
	return 0;
} 

int applyneg(int l,int r,int st){
	//cout<<l<<" "<<r<<" "<<st<<endl;
	int val,nexti,en;
	st=st-(r-l);
	r++;
	while(l<r){
		val=gg[st];
		nexti=(1<<(val+1));
		en=max(l,r-(nexti-st));
		arr[en]+=val;
		arr[r]-=val;
		r=en;
		st=nexti;
	}
	return 0;
}
int main(){
	std::ios::sync_with_stdio(false); cin.tie(NULL);
	int i;
	gg[0]=0;
	gg[1]=0;
	int nexti=2;
	f(i,2,123456){
		gg[i]=gg[i-1];
		if(i==nexti){
			gg[i]++;
			nexti*=2;
		}
	}

	int t;
	cin>>t;
	while(t--){
		int n,v;
		cin>>n>>v;
		int i;
		rep(i,n+10){
			arr[i]=0;
		}
		rep(i,n){
			cin>>p[i];
			p[i]--;
			invp[p[i]]=i;
		}
		rep(i,n){
			cin>>q[i];
			q[i]--;
			invq[q[i]]=i;
		}
		rep(i,n){
			if(invq[i]<invp[i]){
				applyneg(0,invp[i]-invq[i]-1,invp[i]-invq[i]);
				applypos(invp[i]-invq[i]+1,n-1-invq[i],1);
				applyneg(n-invq[i],n-1,invp[i]);
			}
			else if(invq[i]==invp[i]){
				applypos(1,n-1-invq[i],1);
				applyneg(n-invq[i],n-1,invp[i]);
			}
			else if(invq[i]>invp[i]){
				applypos(0,n-1-invq[i],invq[i]-invp[i]);
				applyneg(n-invq[i],n-invq[i]+invp[i]-1,invp[i]);
				applypos(n-invq[i]+invp[i]+1,n-1,1);
			}
		}

		f(i,1,n){
			arr[i]+=arr[i-1];
			arr[i]%=mod;
		}
		ll ans=1;
		rep(i,v){
			//cout<<arr[i]<<endl;
			ans*=arr[i];
			ans%=mod;
		}
		cout<<ans<<endl;

	}
}
Editorialist's Solution
import java.util.*;
import java.io.*;
import java.text.*;
class NSTROT{
	//SOLUTION BEGIN
	int mx = (int)1e5;
	long MOD = 998244353;
	int[] prev = new int[1+mx], nxt = new int[1+mx];
	void pre() throws Exception{
	    nxt[1] = 1;prev[1] = 1;
	    int p = 2;
	    for(int i = 2; i<= mx; i++){
	        if(i == p){
	            nxt[i] = p;
	            prev[i] = p;
	            p *= 2;
	        }else{
	            nxt[i] = p;
	            prev[i] = p/2;
	        }
	    }
	}
	void solve(int TC) throws Exception{
	    int n = ni(), v = ni();
	    int[] a = new int[1+n], b = new int[1+n];
	    for(int i = 1; i<= n; i++)a[i] = ni();
	    for(int i = 1; i<= n; i++)b[i] = ni();
	    long[] ans = new long[v];
	    int[] pA = new int[1+n], pB = new int[1+n];
	    for(int i = 1; i<= n; i++)pA[a[i]] = i;
	    for(int i = 1; i<= n; i++)pB[b[i]] = i;
	    
	    for(int i = 1; i<= n; i++){
	        int st = 0;
	        //Cases as mentioned in editorial
	        if(pA[i] < pB[i]){
	            int count = Math.min(v-st, n-pB[i]+1);
	            updatePos(ans, st, st+count-1, pB[i]-pA[i]);
	            st += count;
	            
	            count = Math.min(v-st, pA[i]-1);
	            updateNeg(ans, st, st+count-1, pA[i]-1);
	            st += count;
	            
	            count = Math.min(v-st, pB[i]-pA[i]);
	            updatePos(ans, st, st+count-1, 0);
	            st += count;
	        }else if(pA[i] == pB[i]){
	            int count = Math.min(v-st, n-pA[i]+1);
	            updatePos(ans, st, st+count-1, 0);
	            st += count;
	            
	            count = Math.min(v-st, pA[i]-1);
	            updateNeg(ans, st, st+count-1, pA[i]-1);
	            st += count;
	        }else if(pA[i] > pB[i]){
	            int count = Math.min(v-st, pA[i]-pB[i]);
	            updateNeg(ans, st, st+count-1, pA[i]-pB[i]);
	            st += count;

	            count = Math.min(v-st, n-pA[i]+1);
	            updatePos(ans, st, st+count-1, 0);
	            st += count;
	            
	            count = Math.min(v-st, pB[i]-1);
	            updateNeg(ans, st, st+count-1, pA[i]-1);
	            st += count;
	        }
	        hold(st == v);
	    }
	    for(int i = 1; i< v; i++)ans[i] += ans[i-1];
	    
	    long product = 1;
	    for(long l:ans)product = (product*l)%MOD;
	    pn(product);
	}
	//Handling updates where value increases starting from L
	void updatePos(long[] ans, int le, int ri, int L) throws Exception{
	    if(ri < le)return;
//        Slow, works for first subtask
//        for(int i = le; i <= ri; i++)ans[i] += log2(L+(i-le));
	    int R = L+(ri-le);
	    for(int i = le; i<= ri; ){
	        int cur = (L+(i-le));
	        int end = Math.min(R, Math.max(cur, nxt[cur]-1))-L+le;
	        //Shifts i..end share the same log value val, so S_i..S_end each increase by val
	        int val = log2(cur);
	        ans[i] += val;
	        if(end+1 < ans.length)ans[end+1] -= val;
	        i = end+1;
	    }
	}
	//Handles updates where value decreases starting from R
	void updateNeg(long[] ans, int le, int ri, int R) throws Exception{
	    if(ri < le)return;
//        Slow, works for first subtask
//        for(int i = le; i <= ri; i++)ans[i] += log2(R-(i-le));
	    int L = R-ri+le;
	    for(int i = le; i<= ri; ){
	        int cur = R-(i-le);
	        int end = i+cur-Math.max(L, prev[cur]);
	        //Shifts i..end share the same log value val, so S_i..S_end each increase by val
	        int val = log2(cur);
	        ans[i] += val;
	        if(end+1 < ans.length)ans[end+1] -= val;
	        i = end+1;
	    }
	}
	int log2(int x){
	    if(x <= 1)return 0;
	    return 1+log2(x/2);
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	DecimalFormat df = new DecimalFormat("0.00000000000");
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new NSTROT().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcome as always.


I couldn’t get the last part of it. Could you please explain
how you updated curSc for every rotation?
What is lst here, and why are we subtracting it?

Hi,
Could you provide pseudocode, commented code, or a flow chart of your approach?
I am unable to follow the editorial after the line
“Since log(N) changes value …”, and
I cannot figure out how to turn what is written above it into code.
Could you provide some links that can help?
Any help will be appreciated.