PERMCYCQUERY - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: harsh_h
Tester: kaori1
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Prefix sums, binary search, math

PROBLEM:

You are given an array A of length N.
For any permutation P of \{1, 2, \ldots, N\}, define f(P, A) to be the sum, over all cycles of P, of the minimum A_j among the indices j in that cycle.
Define g(A) to be the sum of f(P, A) across all permutations P.

Answer Q independent queries: each time, given i and x, set A_i := x and then compute g(A).
Updates are not persistent, i.e. every query starts from the original array.
Each answer must be found modulo M, which may not be prime.

EXPLANATION:

First, we need to figure out how to compute g(A) quickly.

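For convenience, let B denote A sorted in non-decreasing order, so that B_1 \leq B_2 \leq\ldots\leq B_N.
Equal values are ordered by their position in B, so within any cycle, the element with the smallest position in B also has the smallest value.
Let’s count the number of permutations in which B_i is the minimum of its cycle.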

First, let’s look at just the elements B_1 through B_i.
For B_i to be the minimum of its cycle, it can’t share a cycle with any of the smaller elements B_1, \ldots, B_{i-1}, so it must be a fixed point.
The other i-1 elements can be permuted arbitrarily, so there are (i-1)! permutations of the first i elements in which B_i is the minimum of its cycle.
Now, let’s extend such a permutation by inserting the larger elements B_{i+1}, \ldots, B_N one at a time.

  • Index i+1 has i+1 choices: it can either be placed somewhere in an existing cycle (there are i choices, since it can be placed in the ‘middle’ of any of the i existing edges), or it can start a new cycle on its own.
  • Similarly, index i+2 has i+2 choices, i+3 has i+3 choices, and so on till we reach N, which has N choices.

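Each of these choices is independent of the previous ones, and inserting a larger element never changes the minimum of an existing cycle, so B_i stays the minimum of its cycle. Conversely, every permutation of all N elements in which B_i is the minimum of its cycle is built exactly once (deleting N, N-1, \ldots, i+1 in turn recovers the choices). So the total count is just the product,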

(i-1)!\cdot (i+1)\cdot (i+2)\cdot\ldots\cdot N = \frac{N!}{i}
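The \frac{N!}{i} count can be sanity-checked by brute force for small N. The sketch below is only a testing aid, not part of the intended solution: it labels the sorted elements 0, \ldots, N-1 (so a smaller label means an earlier position in B) and counts the permutations in which label i-1 is the smallest label of its cycle. The helper names `cycles` and `count_min_of_cycle` are made up for this illustration.

```python
from itertools import permutations
from math import factorial


def cycles(p):
    """Split permutation p (p[j] is the image of j) into its cycles."""
    n = len(p)
    seen = [False] * n
    result = []
    for start in range(n):
        if seen[start]:
            continue
        cyc = []
        j = start
        while not seen[j]:
            seen[j] = True
            cyc.append(j)
            j = p[j]
        result.append(cyc)
    return result


def count_min_of_cycle(n, i):
    """Number of permutations of labels 0..n-1 in which label i-1 (i.e. B_i)
    is the smallest label of its cycle."""
    return sum(
        1
        for p in permutations(range(n))
        if any(min(c) == i - 1 for c in cycles(p))
    )


# Check the formula N!/i for every i and every small N.
for n in range(1, 7):
    for i in range(1, n + 1):
        assert count_min_of_cycle(n, i) == factorial(n) // i
```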

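Summing over all permutations, each B_i contributes its value once for every permutation in which it is the minimum of its cycle. This means g(A) is, quite simply,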

\sum_{i=1}^N \frac{N!}{i} \cdot B_i
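Likewise, the closed form for g(A) can be compared against direct enumeration over all N! permutations. This is a hypothetical test harness (`g_brute` and `g_formula` are illustrative names) using exact Python integers, so no modulus is involved.

```python
from itertools import permutations
from math import factorial


def g_brute(a):
    """Sum over all permutations P of f(P, A), the sum of cycle minima."""
    n = len(a)
    total = 0
    for p in permutations(range(n)):
        seen = [False] * n
        for start in range(n):
            if seen[start]:
                continue
            m = a[start]
            j = start
            while not seen[j]:      # walk the cycle containing start
                seen[j] = True
                m = min(m, a[j])
                j = p[j]
            total += m
    return total


def g_formula(a):
    """g(A) = sum_{i=1}^{N} (N!/i) * B_i, where B = sorted(A)."""
    n = len(a)
    b = sorted(a)
    return sum(factorial(n) // i * b[i - 1] for i in range(1, n + 1))
```

For A = [3, 1, 2], B = [1, 2, 3] and both functions return 6\cdot 1 + 3\cdot 2 + 2\cdot 3 = 18.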

Now, we need to process updates.

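Suppose A_i = y initially, and is now updated to x. In sorted order, one copy of y is removed from B and x is inserted at its correct position. Observe that:

  • Any value which is \leq \min(x, y) doesn’t change its multiplier.
  • Similarly, any value which is \geq \max(x, y) doesn’t change its multiplier either.
  • Elements strictly between x and y move by exactly one position. If x > y they move one step left, so their multiplier goes from \frac{N!}{i} to \frac{N!}{i-1}; if x < y they move one step right, so it goes to \frac{N!}{i+1}.
  • Finally, the term for y disappears, and x contributes x\cdot\frac{N!}{j}, where j is its position in the new sorted order.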
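The rules above can be checked with an \mathcal{O}(N)-per-query sketch that applies them literally to the original sorted array and compares the result with re-sorting from scratch. The function names are hypothetical and the arithmetic uses exact integers. For ties it assumes one convenient choice: when x > y it removes the last copy of y and inserts x before any equal values, and when x < y it removes the first copy of y and inserts x after any equal values. Since equal values are interchangeable, any such choice gives the same sum.

```python
from bisect import bisect_left, bisect_right
from math import factorial


def g_sorted(b):
    """g for an already sorted list b: sum over 1-indexed j of (N!/j) * b_j."""
    n = len(b)
    return sum(factorial(n) // j * b[j - 1] for j in range(1, n + 1))


def g_naive(a, i, x):
    """Reference: set A_i := x (1-indexed) on a copy, re-sort, recompute."""
    c = list(a)
    c[i - 1] = x
    return g_sorted(sorted(c))


def g_by_shifting(b, y, x):
    """Apply the multiplier rules to the ORIGINAL sorted list b when one copy of y becomes x."""
    n = len(b)
    F = factorial(n)
    if x == y:
        return g_sorted(b)
    total = 0
    if x > y:
        p = bisect_right(b, y)      # remove the last copy of y (1-indexed position p)
        q = bisect_left(b, x)       # positions p+1..q shift left; x lands at position q
        for j in range(1, n + 1):
            if j != p:
                total += (F // (j - 1) if p < j <= q else F // j) * b[j - 1]
    else:
        p = bisect_left(b, y) + 1   # remove the first copy of y (1-indexed position p)
        q = bisect_right(b, x) + 1  # x lands at position q; positions q..p-1 shift right
        for j in range(1, n + 1):
            if j != p:
                total += (F // (j + 1) if q <= j < p else F // j) * b[j - 1]
    return total + F // q * x       # x's own term at its new position q
```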

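Since updates are not persistent, every query can be answered from data precomputed on the original sorted array B, using prefix sums.
Let’s store, for each i, the values B_i\cdot\frac{N!}{i}, B_i\cdot\frac{N!}{i-1} and B_i\cdot\frac{N!}{i+1} (the second is only needed for i \geq 2, the third only for i \leq N-1).
Also build prefix sums on these three arrays.
Then, when processing an update,

  • Some prefix and some suffix of B keep their positions, and use the B_i\cdot\frac{N!}{i} values.
  • The remaining subarray (the elements strictly between x and y) uses one of the other two, depending on the direction of the shift.
    Each of these three parts is a range sum that the prefix sums give in \mathcal{O}(1) time. Finally, remove y’s old term and add x’s new one.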

Finding the appropriate prefix and suffix (that is, the old position of y in B and the new position of x) is a simple exercise in binary search.
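Combining the prefix sums with the two binary searches gives \mathcal{O}(\log N) per query. Below is a minimal sketch that assumes exact integers instead of arithmetic modulo M (handling the modulus is discussed next). The names `build`, `query`, `same`, `left` and `right` are hypothetical: `same`, `left` and `right` are 1-indexed prefix sums of B_j\cdot\frac{N!}{j}, B_j\cdot\frac{N!}{j-1} and B_j\cdot\frac{N!}{j+1}, with the unused entries stored as 0.

```python
from bisect import bisect_left, bisect_right
from math import factorial


def build(a):
    """Sort A and build the three prefix-sum arrays (index 0 holds 0)."""
    b = sorted(a)
    n = len(b)
    F = factorial(n)
    same = [0] * (n + 1)    # multiplier N!/j: position unchanged
    left = [0] * (n + 1)    # multiplier N!/(j-1): shifted one step left
    right = [0] * (n + 1)   # multiplier N!/(j+1): shifted one step right
    for j in range(1, n + 1):
        v = b[j - 1]
        same[j] = same[j - 1] + F // j * v
        left[j] = left[j - 1] + (F // (j - 1) * v if j > 1 else 0)
        right[j] = right[j - 1] + (F // (j + 1) * v if j < n else 0)
    return b, F, same, left, right


def query(state, y, x):
    """g(A) after one element with value y is changed to x."""
    b, F, same, left, right = state
    n = len(b)
    if x == y:
        return same[n]
    if x > y:
        p = bisect_right(b, y)   # remove the last copy of y at position p
        q = bisect_left(b, x)    # positions p+1..q shift left; x lands at q
        return same[p - 1] + (left[q] - left[p]) + (same[n] - same[q]) + F // q * x
    p = bisect_left(b, y) + 1    # remove the first copy of y at position p
    q = bisect_right(b, x) + 1   # x lands at q; positions q..p-1 shift right
    return same[q - 1] + (right[p - 1] - right[q - 1]) + (same[n] - same[p]) + F // q * x
```

Position p (the removed copy of y) lies outside all three ranges, so y’s old term is excluded without an explicit subtraction.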

We’re only left with one problem now: the answer must be found modulo M, and dividing by i needs a modular inverse of i, which may not exist when M is composite!
Getting around this is quite simple, however.
Observe that we only really need division to find \frac{N!}{i} for some 1\leq i \leq N.
Looking back at how we obtained this, it equals (i-1)!\cdot (i+1)\cdot (i+2)\cdot \ldots\cdot N.

Let’s break this into two parts: the prefix product (i-1)! and the suffix product (i+1)\cdot (i+2)\cdot\ldots\cdot N.
Both can be precomputed for every i using only multiplication, and multiplying the appropriate prefix and suffix products gives \frac{N!}{i} \bmod M, so we’ve avoided division entirely!
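A minimal sketch of the division-free multipliers, following the same prefix/suffix idea as `pprod` and `sprod` in the editorialist’s code; `multipliers_mod` is a hypothetical name. The loop at the end compares it with exact integer division for a few composite moduli.

```python
from math import factorial


def multipliers_mod(n, m):
    """coef[i] = (N!/i) mod m for i = 1..N, using only multiplication:
    N!/i = (i-1)! * ((i+1)*(i+2)*...*N)."""
    pre = [1] * (n + 1)          # pre[i] = i! mod m
    for i in range(1, n + 1):
        pre[i] = pre[i - 1] * i % m
    suf = [1] * (n + 2)          # suf[i] = i*(i+1)*...*N mod m, suf[N+1] = 1
    for i in range(n, 0, -1):
        suf[i] = suf[i + 1] * i % m
    coef = [0] * (n + 1)
    for i in range(1, n + 1):
        coef[i] = pre[i - 1] * suf[i + 1] % m
    return coef


# Compare with exact integer division for some composite moduli.
for n in range(1, 11):
    for m in (6, 12, 100, 10**9):
        coef = multipliers_mod(n, m)
        assert all(coef[i] == factorial(n) // i % m for i in range(1, n + 1))
```

For instance, multipliers_mod(4, 12)[1:] is [0, 0, 8, 6] (that is, 24, 12, 8, 6 reduced modulo 12), even though 2, 3 and 4 have no inverse modulo 12.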

TIME COMPLEXITY:

\mathcal{O}((N + Q)\log N) per testcase: \mathcal{O}(N\log N) for sorting, then \mathcal{O}(\log N) per query.

CODE:

Author's code (C++)
#include<bits/stdc++.h>

using namespace std;

using ll = long long;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    ll t;cin>>t;

    while(t--){
        ll n,q,m;cin>>n>>q>>m;
        ll Mod = m;
        vector<ll> v(n+1);
        vector<ll> b(1);
        for(ll i = 1; i <= n; i++){
            cin>>v[i];
            b.push_back(v[i]);
        }
        vector<ll> coeff(n+1);
        vector<ll> pref(n+1,1);
        vector<ll> suf(n+2,1);
        for(ll i = 1; i <= n; i++ ){
            pref[i] = pref[i-1]*i;
            pref[i] %= Mod;
        }
        for(ll i = n; i >= 1; i-- ){
            suf[i] = suf[i+1]*i;
            suf[i] %= Mod;
        }
        coeff[1] = suf[2];
        coeff[n] = pref[n-1];
        for(ll i = 2; i <= n-1 ; i++){
            coeff[i] = (pref[i-1] * suf[i+1])%Mod;
        }

        
        sort(b.begin()+1,b.end());
        vector<ll> bi(n+1); // prefix sums of b_i * N!/i
        vector<ll> bip(n+1); // prefix sums of b_i * N!/(i+1)
        vector<ll> bim(n+1); // prefix sums of b_i * N!/(i-1)
        for(ll i = 1; i <= n; i++){
            bi[i] = (bi[i-1] + b[i]*coeff[i])%Mod;
            // coeff[n+1] is out of bounds; the i = n term of bip is never queried, so skip it
            bip[i] = (bip[i-1] + (i < n ? b[i]*coeff[i+1] : 0LL))%Mod;
            if(i > 1) bim[i] = (bim[i-1] + b[i]*coeff[i-1])%Mod;
        }
        for(ll i = 1; i<= q;i++){
            ll pos, x;cin>>pos>>x;
            ll curr = v[pos];
            ll ans = 0;
            if(curr == x){
                ans=(bi[n])%Mod;
            }else if(curr > x){
                ll index = lower_bound(b.begin(),b.end(),curr)-b.begin();
                ll newPos = lower_bound(b.begin(),b.end(),x)-b.begin();
                ans = (((bi[newPos-1] + bi.back() - bi[index] + bip[index-1] - bip[newPos-1] + (x*coeff[newPos])%Mod)%Mod)%Mod + Mod)%Mod;
            }else if(curr < x){
                ll index = lower_bound(b.begin(),b.end(),curr)-b.begin();
                ll newPos = lower_bound(b.begin(),b.end(),x)-b.begin()-1;
                ans = (((bi[index-1] + bi.back() - bi[newPos] + bim[newPos] - bim[index] + (x*coeff[newPos])%Mod)%Mod)%Mod + Mod)%Mod;
            }
            cout<<ans<<"\n";
        }
        //  i x
        // x x y x x x x x x x - 3
        // x x x x x x x y x x - 8 
    }
}
Tester's code (C++)
//#pragma GCC optimize("O3")
//#pragma GCC optimize("Ofast")
//#pragma GCC optimize("unroll-loops")
//#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")


#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;


struct custom_hash {
        static uint64_t splitmix64(uint64_t x) {
                // http://xorshift.di.unimi.it/splitmix64.c
                x += 0x9e3779b97f4a7c15;
                x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
                x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
                return x ^ (x >> 31);
        }

        size_t operator()(uint64_t x) const {
                static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
                return splitmix64(x + FIXED_RANDOM);
        }
};
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<typename T>
T rand(T a, T b){
    return uniform_int_distribution<T>(a, b)(rng);
}

typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int,int>, null_type, less<pair<int,int>>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef long long ll;
typedef long double ld;
typedef vector<ll> vl;
typedef vector<int> vi;


#define rep(i, a, b) for(int i = a; i < b; i++)
#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
//#define int long long

ll mod;
const ll INF = 1e18;

int norm (int x) {
        if (x < 0) x += mod;
        if (x >= mod) x -= mod;
        return x;
}
template<class T>
T power(T a, int b) {
        T res = 1;
        for (; b; b /= 2, a *= a) {
                if (b & 1) res *= a;
        }
        return res;
}
struct Z {
        int x;
        Z(int x = 0) : x(norm(x)) {}
        int val() const {
                return x;
        }
        Z operator-() const {
                return Z(norm(mod - x));
        }
        Z inv() const {
                return power(*this, mod - 2);
        }
        Z &operator*=(const Z &rhs) {
                x = 1ll * x * rhs.x % mod;
                return *this;
        }
        Z &operator+=(const Z &rhs) {
                x = norm(x + rhs.x);
                return *this;
        }
        Z &operator-=(const Z &rhs) {
                x = norm(x - rhs.x);
                return *this;
        }
        Z &operator/=(const Z &rhs) {
                return *this *= rhs.inv();
        }
        friend Z operator*(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res *= rhs;
                return res;
        }
        friend Z operator+(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res += rhs;
                return res;
        }
        friend Z operator-(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res -= rhs;
                return res;
        }
        friend Z operator/(const Z &lhs, const Z &rhs) {
                Z res = lhs;
                res /= rhs;
                return res;
        }
        friend std::istream &operator>>(std::istream &is, Z &a) {
                int v;
                is >> v;
                a = Z(v);
                return is;
        }
        friend std::ostream &operator<<(std::ostream &os, const Z &a) {
                return os << a.val();
        }
};



signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);



        int t;
        cin >> t;

        while (t--) {

                int n, q;
                cin >> n >> q >> mod;
                int a[n];
                for (auto &x : a) cin >> x;

                Z pref[n + 1], suf[n + 2];
                pref[0] = 1;
                for (int i = 1; i <= n; i++) pref[i] = pref[i - 1] * i;
                suf[n + 1] = 1;
                for (int i = n; i >= 0; i--) suf[i] = suf[i + 1] * i;

                array<int, 2> b[n];
                int pos[n];
                for (int i = 0; i < n; i++) b[i] = {a[i], i};
                sort(b, b + n);
                for (int i = 0; i < n; i++) pos[b[i][1]] = i;

                Z ans1[n], ans2[n], ans3[n];
                for (int i = 0; i < n; i++) {
                        ans1[i] = pref[i] * suf[i + 2] * b[i][0];
                        if (i) ans2[i] = pref[i - 1] * suf[i + 1] * b[i][0];
                        if (i < n - 1) ans3[i] = pref[i + 1] * suf[i + 3] * b[i][0];
                }

                Z p1[n], p2[n], p3[n];
                for (int i = 0; i < n; i++) p1[i] = ans1[i], p2[i] = ans2[i], p3[i] = ans3[i];
                for (int i = 1; i < n; i++) {
                        p1[i] += p1[i - 1];
                        p2[i] += p2[i - 1];
                        p3[i] += p3[i - 1];
                }

                while (q--) {

                        int i, x;
                        cin >> i >> x;
                        i--;
                        i = pos[i];
                        array<int, 2> f = {x, -1};
                        int j = lower_bound(b, b + n, f) - b;
                        if (j == i) j++;
                        
                        if (j > i) {
                                Z ans = (i > 0 ? p1[i - 1] : Z(0)) + p1[n - 1] - p1[j - 1] + p2[j - 1] - p2[i];
                                j--;
                                ans += pref[j] * suf[j + 2] * x;
                                cout << ans << "\n";
                        }
                        else {
                                Z ans = (j > 0 ? p1[j - 1] : Z(0)) + p1[n - 1] - p1[i] + p3[i - 1] - (j > 0 ? p3[j - 1] : Z(0));
                                ans += pref[j] * suf[j + 2] * x;
                                cout << ans << "\n";
                        }

                }


        }
        
        
}
Editorialist's code (Python)
import bisect
for _ in range(int(input())):
    n, q, mod = map(int, input().split())
    a = list(map(int, input().split()))
    b = sorted(a)

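    # pprod[i] = i! mod M and sprod[i] = i*(i+1)*...*n mod M, so N!/j = pprod[j-1] * sprod[j+1] (no division)
    pprod, sprod = [1]*(n+1), [1]*(n+2)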
    for i in range(1, n+1):
        pprod[i] = pprod[i-1] * i % mod
    for i in reversed(range(n+1)):
        sprod[i] = sprod[i+1] * i % mod

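    # Prefix sums over the sorted array, where 0-indexed position i holds B_{i+1}:
    # pre1 uses N!/(i+1) (position unchanged), pre2 uses N!/i (shifted left),
    # pre3 uses N!/(i+2) (shifted right)
    pre1 = [0]*n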
    pre2 = [0]*n
    pre3 = [0]*n
    for i in range(n):
        pre1[i] = b[i] * pprod[i] % mod * sprod[i+2] % mod
        if i+1 < n: pre3[i] = b[i] * pprod[i+1] % mod * sprod[i+3] % mod
        if i > 0:
            pre1[i] = (pre1[i] + pre1[i-1]) % mod
            pre3[i] = (pre3[i] + pre3[i-1]) % mod

            pre2[i] = (pre2[i-1] + b[i] * pprod[i-1] % mod * sprod[i+1]) % mod
    for ii in range(q):
        i, x = map(int, input().split())
        i -= 1

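        if x > a[i]:
            # remove the last copy of a[i] (0-indexed pref-1); positions pref..suf-1 shift left; x lands at suf-1
            pref = bisect.bisect_right(b, a[i])
            suf = bisect.bisect_left(b, x)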
            ans = pre1[-1] - pre1[suf-1] + pre2[suf-1] + pre1[pref-1] - pre2[pref-1]
            
            ans -= a[i] * pprod[pref-1] % mod * sprod[pref+1] % mod
            ans += x * pprod[suf-1] % mod * sprod[suf+1] % mod
            print(ans % mod)
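        elif x < a[i]:
            # remove the first copy of a[i] (0-indexed suf); x lands at pref; positions pref..suf-1 shift right
            pref = bisect.bisect_right(b, x)
            suf = bisect.bisect_left(b, a[i])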
            ans = pre1[-1]
            if suf > 0:
                ans += pre3[suf-1] - pre1[suf-1]
            if pref > 0:
                ans += pre1[pref-1] - pre3[pref-1]
            ans -= a[i] * pprod[suf] % mod * sprod[suf+2] % mod
            ans += x * pprod[pref] % mod * sprod[pref+2] % mod
            print(ans % mod)
        else: print(pre1[-1])