MEXPERMDIF - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: Omkar Tripathi
Tester: Harris Leung
Editorialist: Nishank Suresh

DIFFICULTY:

2945

PREREQUISITES:

None

PROBLEM:

Given N and K, construct a permutation of the integers \{0, 1, 2, \ldots, N\} such that the absolute difference between the sum of its prefix mexes and the sum of its suffix mexes is K, or report that this is not possible.

EXPLANATION:

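First, to set aside a few edge cases, brute force over all (N+1)! permutations of \{0, 1, \ldots, N\} to find the answer when N \leq 4. For such small N the characterization below does not hold; for example, K = 0 is impossible when N = 1 or N = 2.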

For N \geq 5, we claim that a suitable permutation exists if and only if K \leq \frac{N\cdot (N+1)}{2}.

This condition is clearly necessary. The sum of prefix mexes is at least N+1: the full array always has mex N+1, and every other prefix has mex at least 0. It is at most 1 + 2 + \ldots + (N+1), because a prefix of length i+1 has mex at most i+1. The same bounds hold for the sum of suffix mexes. So the absolute difference between the two sums is at most (1 + 2 + \ldots + (N+1)) - (N+1) = \frac{N\cdot(N+1)}{2}.

To show sufficiency, it is enough to provide a construction for every such K (we assume N \geq 5 below).

The construction will be broken up into a couple of cases: K \geq N and K \lt N.

Large K

For K \geq N, it is always possible to obtain a required permutation with P_0 = 0. This can be done as follows:

  • First, note that P_0 = 0 forces the sum of suffix mexes to be exactly N+1, regardless of how the remaining elements are arranged.
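  • For the sum of prefix mexes: as long as 1 has not been placed (imagine it sitting at position N), the prefixes ending at positions 0, 1, 2, \ldots, N-1 each have mex 1 and the full array has mex N+1. So the prefix sum is 2N+1, and the difference is already (2N+1) - (N+1) = N before anything else is placed.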
  • Let us try to place the elements 1, 2, 3, \ldots in order at positions 1, 2, 3, \ldots.
  • Placing i at position i increases the contribution of positions i, i+1, \ldots, N-1 by 1 each. In other words, the difference increases by exactly N-i.
    • Let d be the current difference. If d+N-i \lt K, set P_i = i and continue, increasing d by N-i.
    • Otherwise, d+N-i \geq K. Find the position j \geq i such that d+N-j = K, and set P_j = i. Our difference is now exactly K, so we would like to place the remaining elements so as to not change the difference.
    • This is simple: place the elements N, N-1, N-2, \ldots, i+1 from left to right, each time choosing the first empty position available.

This finishes the construction for K \geq N.

Small K

For K \lt N, we make use of a different pattern.

First, consider a permutation such that P_0 = x and P_1 = 0, and positions 2 to N consist of the remaining elements sorted in descending order.

  • The suffix mexes of this permutation sum to x + N+1.
  • The prefix mexes of this permutation sum to 0 + \underbrace{1 + 1 + \ldots + 1}_{N-1 \text{ times}} + N+1 = 2N. Note that this only holds when x \gt 1.

The difference between these values is |N-x-1|.
For x \lt N, this is simply N-x-1. Along with our earlier condition that x\gt 1, we can obtain every difference from 0 to N-3 this way.

Only K = N-2 and K = N-1 are remaining. Once again, we use a similar pattern: P_0 = x, P_1 = y, P_2 = 0 with the rest in descending order, for appropriately chosen x and y.

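Computing the mexes as before, for \{x, y\} = \{N-2, N-1\}:
  • The prefix mexes are 0, 0, then 1 at positions 2 through N-1 (the element 1 is placed last), then N+1, for a total of 2N-1.
  • On the suffix side, the suffix starting at position 2 misses exactly x and y, so its mex is N-2.
  • The suffix starting at position 1 misses only x, so its mex is x.
  • The whole array gives N+1, and every shorter suffix has mex 0.
The difference is therefore (N-2) + x + (N+1) - (2N-1) = x, which gives: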

  • For K = N-2, choose x = N-2 and y = N-1.
  • For K = N-1, choose x = N-1 and y = N-2.

This completes the proof for every case.

The constructions given above are easily implemented in \mathcal{O}(N).

TIME COMPLEXITY

\mathcal{O}(N) per test case.

CODE:

Setter's code (C++)
// code by triggered_code
#include<bits/stdc++.h>
// Every `int` in this program is really a 64-bit long long (which is why the
// entry point below is declared as `signed main`).
#define int long long
#define pb push_back
const char nl = '\n';
// Not referenced anywhere in this program.
const int INF = LONG_MAX;
using namespace std;
using namespace __gnu_cxx;

/*
 * Returns |(sum of prefix mexes) - (sum of suffix mexes)| of the array a.
 *
 * a must be a permutation of {0, 1, ..., a.size()-1}. If it is not (some
 * value is negative, >= a.size(), or repeated), the function returns -1.
 * main() uses this to brute-force answers for small n.
 */
int check_the_permutation(const vector<int>& a){

   int n = a.size();

   // Validate: every value must lie in [0, n) and appear at most once.
   // (Value n itself is NOT allowed: a permutation of n items is 0..n-1.)
   vector<bool> present(n, false);
   for(int v : a){
      if(v < 0 || v >= n || present[v]) return -1;
      present[v] = true;
   }

   // Sum of mexes of the n growing windows a[start], a[start..start+step], ...
   // The running mex never decreases, so it is advanced incrementally.
   // Size n+1 so the scan can safely read slot n (the mex of the full set).
   auto mex_sum = [&](int start, int step){
      vector<bool> seen(n + 1, false);
      int total = 0, mex = 0;
      for(int i = 0, pos = start; i < n; i++, pos += step){
         seen[a[pos]] = true;
         while(seen[mex]) mex++;
         total += mex;
      }
      return total;
   };

   int prefix_mex_sum = mex_sum(0, 1);       // windows a[0..i]
   int suffix_mex_sum = mex_sum(n - 1, -1);  // windows a[i..n-1]

   return abs(prefix_mex_sum - suffix_mex_sum);
}


/*
 * Setter's constructive solution, called from main() for n > 4.
 * Prints a permutation of {0, 1, ..., n} intended to have
 * |prefix-mex sum - suffix-mex sum| == k, or -1 when k > n*(n+1)/2.
 * (`int` here is long long because of the `#define int long long` above.)
 *
 * ans          - the permutation being built, printed at the end.
 * sum_elements - values set aside to be appended after the "unused" values.
 * used         - used[v] == 1 once v has been placed in sum_elements.
 */
void solve(int n, int k){

   vector<int> ans, sum_elements, used(n+1,0);

   // Small k (k < n): output = [all other values ascending] + [0, chosen...].
   if(k < n){

      used[0] = 1;
      sum_elements.pb(0);

      // k == 0: tail is 0, n-1.
      if(k == 0){
      used[n-1] = 1;
      sum_elements.pb(n-1);
      }

      // k == 1: tail is 0, n.
      else if(k == 1){
      used[n] = 1;
      sum_elements.pb(n);
      }

      else{

      // k == n-2: tail is 0, n-3, n-1.
      if(n-2 == k){
      used[n-3] = 1;
      used[k+1] = 1;
      sum_elements.pb(min(n-3, k+1)); 
      sum_elements.pb(max(n-3, k+1));   

      }
      // Other 2 <= k < n: tail is 0, then k and n-2 in increasing order.
      else{
      used[n-2] = 1;
      used[k] = 1;
      sum_elements.pb(min(k,n-2)); 
      sum_elements.pb(max(k,n-2)); 
      }
      }

      // Values not reserved for the tail go first, in ascending order.
      for(int i = 0; i <= n; i++){
      if(!used[i]) ans.pb(i);
      }

      for(int i: sum_elements) ans.pb(i);

    }

    // Medium k (n <= k < n*(n+1)/2).
    else if(k < (n*(n+1))/2){

      // Greedily write k as a sum of distinct values taken from n downwards.
      // Since k >= n here, n is always the first value chosen,
      // so sum_elements[0] == n.
      int x = n;
      while(k){

      if(k >= x){
       k -= x;
       used[x] = 1;
       sum_elements.pb(x);
      }
      x--;
      }

      // Number of values chosen by the greedy decomposition.
      int n_ = sum_elements.size();

      if(used[n_]){

      // Make sure 1 is part of the tail group.
      if(!used[1]){
       sum_elements.pb(1);
       used[1] = 1;
      }

      // Unchosen values (including 0) go first, ascending.
      for(int i = 0; i <= n; i++){
       if(!used[i]) ans.pb(i);
      }

      // Tail: chosen values ascending, with the largest (n) removed and
      // re-inserted immediately before the value n_.
      sort(sum_elements.begin(), sum_elements.end());
      sum_elements.pop_back();
      for(int i: sum_elements){

       if(i == n_){
           ans.pb(n);
       }

       ans.pb(i);
      }

      }
      else{

      // n_ was not chosen: replace n (sum_elements[0]) by n_ in the tail
      // group; n becomes an "unused" value instead.
      sum_elements[0] = n_;
      used[n_] = 1;
      used[n] = 0;

      if(!used[1]){
       sum_elements.pb(1);
       used[1] = 1;
      }

      for(int i = 0; i <= n; i++){
       if(!used[i]) ans.pb(i);
      }

      sort(sum_elements.begin(), sum_elements.end());

      for(int i: sum_elements){
       ans.pb(i);
      }

      }

    }
    // Maximum k: the identity permutation 0, 1, ..., n.
    else if( k == (n*(n+1))/2 ){

      for(int i = 0; i <= n; i++) ans.pb(i);
    }
    // k > n*(n+1)/2: impossible.
    else{

        cout<<-1<<nl;
        return;
    }

   for(int i: ans) cout<<i<<' ';
   cout<<nl;

}


/*
 * Setter's driver.
 * For n <= 4, answers come from a table built by brute force over every
 * permutation of {0..n}. For larger n, the constructive solve() is used.
 * (`signed main` because `int` is #defined to long long in this program.)
 */
signed main() {

   ios_base::sync_with_stdio(false);
   cin.tie(NULL);

   // Local-testing convenience only: redirect the standard streams to files
   // unless ONLINE_JUDGE is defined (presumably set by the judge).
   #ifndef ONLINE_JUDGE
   freopen("input_t4.txt" , "r" , stdin) ;
   freopen("output.txt" , "w" , stdout) ;
   freopen("error.txt" , "w" , stderr) ;
   #endif

   // pre_result[{n, k}] = a permutation of {0..n} with difference k. When
   // several exist, the lexicographically last one wins (later entries
   // overwrite earlier ones).
   map<pair<int,int>,vector<int>> pre_result;

   for(int i = 1; i <= 4; i++){

      vector<int> v;
      for(int j = 0; j <= i; j++){
         v.push_back(j);
      }

      do{
         int k = check_the_permutation(v);
         if(k != -1) pre_result[{i,k}] = v;
      }while(next_permutation(v.begin(), v.end()));

   }

   // Initialized so a failed read yields zero test cases instead of reading
   // an indeterminate value.
   int T = 0;
   cin>>T;

   while(T--){

      int n = 0, k = 0;
      cin>>n>>k;

      if(n <= 4){

         // Single lookup; reuse the iterator instead of indexing again.
         auto it = pre_result.find({n,k});
         if(it == pre_result.end()){
            cout<<-1<<nl;
         }
         else{
            for(int i: it->second){
               cout<<i<<' ';
            }
            cout<<nl;
         }

      }
      else{
         solve(n,k);
      }
   }

   return 0;

}
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
// Not referenced in this program.
const ll mod=998244353;
// Capacity of a[] (unrelated to the problem's N). solve() writes a[1..n] with
// n = input N + 1, so this presumably relies on input N < 2e5 -- TODO confirm
// against the problem constraints.
const int N=2e5+1;
// Globals shared with solve(): n (incremented to input N + 1 inside solve),
// k, and the 1-indexed answer buffer a[].
ll n,k;
ll a[N];
// Tester's solution for one test case: reads N and K and prints a permutation
// of {0..N} intended to have |prefix-mex sum - suffix-mex sum| == K, or -1.
// Works on the globals n, k, a[]. After `n++`, n is the number of elements
// (N+1), so n*(n-1)/2 == N*(N+1)/2 and n-1 == N.
void solve(){
	cin >> n >> k;n++;
	// K above N*(N+1)/2 is impossible.
	if(k>n*(n-1)/2){
		cout << "-1\n";
		return;
	}
	// Small case: K < N.
	if(k<n-1){
		// K <= N-3: print 1..N ascending with x = N-1-K skipped, then 0, then x.
		if(k<=n-4){
			int x=n-2-k;
			for(int i=1; i<=n-2 ;i++){
				cout << i+(i>=x) << ' ';
			}
			cout << 0 << ' ' << x << '\n';
		}
		// K == N-1: hard-coded answers for N = 1, 2, otherwise
		// "3 1 4 5 ... N 0 2".
		if(k==n-2){
			if(n==2) cout << "-1\n";
			else if(n==3) cout << "1 0 2\n";
			else{
				cout << "3 1 ";
				for(int i=4; i<n ;i++) cout << i << ' ';
				cout << "0 2\n";
			}
		}
		// K == N-2: hard-coded answers for N = 2, 3, otherwise
		// "3 2 1 5 6 ... N 0 4".
		if(k==n-3){
			if(n==3) cout << "-1\n";
			else if(n==4) cout << "1 2 0 3\n";
			else{
				cout << "3 2 1 ";
				for(int i=5; i<n ;i++) cout << i << ' ';
				cout << "0 4\n";
			}
		}
		return;
	}
	// Large case: K >= N. a[1..n] holds the answer, and a[j] == 0 marks an empty
	// slot. a[n] is never written, so the value 0 ends up at the last position.
	for(int i=1; i<=n ;i++) a[i]=0;
	// Remove the base difference N (= n-1) that this layout starts with.
	k-=n-1;
	for(int i=1; i<=n ;i++){
		// Place i at position plus+1, taking as much of the remaining k as
		// this placement allows (at most n-i-1).
		ll plus=min(n-i-1,k);k-=plus;
		a[plus+1]=i;
		if(plus==0){
			// Nothing left to gain: fill the empty slots 1..n-1 left to right
			// with the remaining values in increasing order, then print.
			for(int j=1; j<n ;j++){
				if(a[j]==0) a[j]=++i;
			}
				
			for(int j=1; j<=n ;j++) cout << a[j] << ' ';
			cout << '\n';
			return;
		}
	}
}
// Tester's driver: fast iostreams, then one solve() call per test case.
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);

	int tests;
	cin >> tests;
	while(tests--){
		solve();
	}
}
Editorialist's code (C++)
#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long int;
// Clock-seeded 64-bit random generator; not used by the code below.
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	// Exhaustive search, used only for n <= 4. It walks all (n+1)!
	// permutations of {0, ..., n} in lexicographic order and prints the first
	// one whose |prefix-mex sum - suffix-mex sum| equals k, or -1 if none does.
	// k is taken as ll (like in solve), so a large k from the input is compared
	// exactly instead of being narrowed to int and aliasing a reachable value.
	auto brute = [] (int n, ll k) {
		vector<int> v(n+1); iota(begin(v), end(v), 0);
		do {
			int a = 0, b = 0;                    // prefix / suffix mex sums
			vector<int> pref(n+2), suf(n+2);     // presence flags per side
			int p = 0, q = 0;                    // running prefix / suffix mex
			for (int i = 0; i <= n; ++i) {
				pref[v[i]] = 1;                  // prefix v[0..i]
				suf[v[n-i]] = 1;                 // suffix v[n-i..n]
				while (pref[p]) ++p;
				while (suf[q]) ++q;
				a += p; b += q;
			}
			if (abs(a - b) == k) {
				for (int x : v) cout << x << ' ';
				cout << '\n';
				return;
			}
		} while(next_permutation(begin(v), end(v)));
		cout << -1 << '\n';
	};

	// Constructive solution for n >= 5 (see the editorial above).
	auto solve = [] (int n, ll k) {
		// The difference can never exceed n*(n+1)/2.
		if (k > 1LL * n * (n+1)/2) {
			cout << -1 << '\n';
			return;
		}

		if (k >= n) {
			// ans[0] = 0 (value-initialized), so the suffix mex sum is n+1 and
			// the starting difference is n. Place 1, 2, ... at positions
			// 1, 2, ... while the target is not yet reached.
			ll have = n;
			vector<int> ans(n+1);
			for (int i = 1; i <= n; ++i) {
				// Placing i at position i adds n-i to the difference.
				if (have + n-i < k) {
					ans[i] = i;
					have += n - i;
					continue;
				}

				// Overshoot: moving i one slot to the right removes 1 from the
				// gain, so shift it until the difference is exactly k.
				have += n-i;
				int pos = i;
				while (have > k) {
					++pos;
					--have;
				}
				ans[pos] = i;
				// Fill the remaining empty slots left to right with n, n-1, ...
				int place = n;
				for (int j = 1; j <= n; ++j) {
					if (ans[j] == 0) ans[j] = place--;
				}
				break;
			}

			for (int x : ans) cout << x << ' ';
			cout << '\n';
			return;
		}

		// Pattern "x 0 <rest descending>" with 1 < x < n:
		// suf = x + n+1, pref = 0 + (n-1) + (n+1) = 2n, so dif = n-x-1.
		// This covers k in [0, n-3]; k = n-2 and k = n-1 are handled below.
		if (k <= n-3) {
			cout << n-k-1 << ' ' << 0 << ' ';
			for (int i = n; i >= 1; --i) {
				if (i == n-k-1) continue;
				cout << i << ' ';
			}
			cout << '\n';
			return;
		}

		// Pattern "x y 0 <rest descending>" with {x, y} = {n-2, n-1}:
		// suf = 2*min(x, y) + n+1 + (x > y), pref = n+1 + n-2 = 2n-1,
		// so the difference is x.
		if (k == n-1) cout << n-1 << ' ' << n-2 << ' ';
		else cout << n-2 << ' ' << n-1 << ' ';
		cout << 0 << ' ';
		for (int i = n; i >= 1; --i) {
			if (i == n-1 or i == n-2) continue;
			cout << i << ' ';
		}
		cout << '\n';
	};

	int t; cin >> t;
	while (t--) {
		ll n, k; cin >> n >> k;
		if (n <= 4) {
			brute(n, k);
			continue;
		}
		solve(n, k);
	}
}
1 Like