PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: harsh_h
Tester: kaori1
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Prefix sums, binary search, math
PROBLEM:
You are given an array A of length N.
For any permutation P of \{1, 2, \ldots, N\}, define f(P, A) as the sum, over all cycles of P, of the minimum A_i among the indices i in that cycle.
Define g(A) to be the sum of f(P, A) across all permutations P.
Answer Q independent queries: each time, given i and x, set A_i := x and then compute g(A).
Updates are not persistent.
The answer is to be found modulo M, which may not be prime.
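The snippet below makes the definitions concrete. It is a brute-force sketch that evaluates f(P, A) by walking each cycle and computes g(A) by enumerating all N! permutations. It only works for tiny N, uses 0-based indices and ignores the modulus. The names `f_value` and `g_brute` are hypothetical helpers, not part of the original solutions.

```python
from itertools import permutations


def f_value(p, a):
    """Sum over the cycles of permutation p (0-based, p[j] is the image of j)
    of the smallest a-value in that cycle."""
    n = len(p)
    seen = [False] * n
    total = 0
    for start in range(n):
        if seen[start]:
            continue
        # Walk the cycle containing `start`, tracking its minimum a-value.
        cur, best = start, a[start]
        while not seen[cur]:
            seen[cur] = True
            best = min(best, a[cur])
            cur = p[cur]
        total += best
    return total


def g_brute(a):
    """Sum of f(P, A) over all permutations P of the indices of a."""
    return sum(f_value(p, a) for p in permutations(range(len(a))))


print(g_brute([3, 1, 2]))  # 18
```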
EXPLANATION:
First, we need to figure out how to compute g(A) quickly.
For convenience, let B denote the sorted array A, so that B_1 \leq B_2 \leq\ldots\leq B_N.
Let's count the number of permutations in which B_i is the minimum of its cycle, meaning every other index in its cycle is larger. When values are tied, the one with the smallest sorted index counts as the minimum. This doesn't change the sum.
First, let's look at just the indices 1 through i (in sorted order).
For B_i to be the minimum of its cycle, it can't share a cycle with any of the smaller indices.
So the number of permutations of only the first i indices in which B_i is the minimum of its cycle equals (i-1)!. To see this, choose any permutation of the other i-1 indices and keep index i as a fixed point.
Now, let's extend this permutation by inserting the larger indices one at a time. A larger index can never become the minimum of an existing cycle, so B_i stays the minimum of its cycle throughout.
- Index i+1 has i+1 choices. It can be placed somewhere in an existing cycle, which gives i choices because it can be inserted in the 'middle' of any of the i existing edges. Alternatively, it can start a new cycle on its own.
- Similarly, index i+2 has i+2 choices, i+3 has i+3 choices, and so on until we reach N, which has N choices.
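Each of these choices is independent of the previous ones, and every valid permutation arises from exactly one sequence of choices. So the total number of permutations we can build is just their product:
(i-1)! \cdot (i+1) \cdot (i+2) \cdots N = \frac{N!}{i}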
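For small N, the count N!/i can be sanity-checked by brute force. This sketch uses a hypothetical helper `cycle_min_counts` that scans the indices of each permutation in increasing order. The first unvisited index of a cycle is necessarily its smallest index, so the sketch tallies that index. It uses exact integers and ignores the modulus.

```python
from itertools import permutations
from math import factorial


def cycle_min_counts(n):
    """counts[i] = number of permutations of {1..n} in which index i is the
    smallest index of its cycle (counts[0] is unused)."""
    counts = [0] * (n + 1)
    for perm in permutations(range(1, n + 1)):
        p = (0,) + perm  # shift to 1-based: p[j] is the image of j
        seen = [False] * (n + 1)
        for start in range(1, n + 1):
            if seen[start]:
                continue
            # Indices are scanned in increasing order, so `start` is the
            # smallest index of the cycle it belongs to.
            cur = start
            while not seen[cur]:
                seen[cur] = True
                cur = p[cur]
            counts[start] += 1
    return counts


for n in range(1, 7):
    counts = cycle_min_counts(n)
    assert all(counts[i] == factorial(n) // i for i in range(1, n + 1))
print(cycle_min_counts(4)[1:])  # [24, 12, 8, 6]
```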
This means g(A) is, quite simply,
g(A) = \sum_{i=1}^{N} B_i \cdot \frac{N!}{i}
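In code, the closed form is a sort followed by a single pass. This sketch uses a hypothetical function `g_formula` and exact integers instead of working modulo M. For A = [3, 1, 2] it gives 1*6 + 2*3 + 3*2 = 18, which matches enumerating all six permutations by hand.

```python
from math import factorial


def g_formula(a):
    """g(A) = sum over i = 1..N of B_i * N!/i, where B is A sorted ascending.
    Exact integers, no modulus."""
    b = sorted(a)
    n = len(b)
    n_fact = factorial(n)
    return sum(b[i - 1] * (n_fact // i) for i in range(1, n + 1))


print(g_formula([3, 1, 2]))  # 18
```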
Now, we need to process updates.
Suppose A_i = y initially, and is now updated to x. Then, observe that:
- Any value which is \leq \min(x, y) doesn’t change its multiplier.
- Similarly, any value which is \geq \max(x, y) doesn’t change its multiplier either.
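- For elements strictly between x and y, the sorted position shifts by exactly one. If x > y they move one position left, so the multiplier goes from \frac{N!}{i} to \frac{N!}{i-1}. If x < y they move one position right, so it goes to \frac{N!}{i+1}.
- Finally, the term for y is removed, and x is added with the multiplier of its new sorted position.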
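The sketch below illustrates these rules using exact integers with no modulus. The name `g_after_update` is hypothetical. Starting from g of the sorted array `b`, it removes y's term, moves the values strictly between x and y by one position, and adds x at its new position. It walks the shifted block directly, so each update costs O(N). Prefix sums remove that loop.

```python
import bisect
from math import factorial


def g_after_update(b, y, x):
    """Exact g(A) after replacing one copy of y by x, where b is the sorted
    original array (0-based). Uses the multiplier-shift rules, no re-sort."""
    n = len(b)
    n_fact = factorial(n)

    def mult(pos):  # multiplier N!/pos for a 1-based sorted position
        return n_fact // pos

    total = sum(b[k] * mult(k + 1) for k in range(n))
    if x > y:
        lo = bisect.bisect_right(b, y)  # last copy of y sits at 0-based lo-1
        hi = bisect.bisect_left(b, x)   # b[lo:hi] lies strictly between y and x
        total -= y * mult(lo)
        for k in range(lo, hi):         # each moves one position left
            total += b[k] * (mult(k) - mult(k + 1))
        total += x * mult(hi)           # x lands at 0-based hi-1
    elif x < y:
        lo = bisect.bisect_right(b, x)  # x lands at 0-based lo
        hi = bisect.bisect_left(b, y)   # first copy of y sits at 0-based hi
        total -= y * mult(hi + 1)
        for k in range(lo, hi):         # each moves one position right
            total += b[k] * (mult(k + 2) - mult(k + 1))
        total += x * mult(lo + 1)
    return total


print(g_after_update([1, 2, 3], 1, 4))  # 29, i.e. g([4, 2, 3])
```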
Since updates are not persistent, every query starts from the same sorted array B, so all of this can be handled with prefix sums that are computed once.
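Let's store, for each i, the values \frac{B_i \cdot N!}{i}, \frac{B_i \cdot N!}{i-1}, and \frac{B_i \cdot N!}{i+1}. These are the contributions of B_i at its own position, after moving one position left, and after moving one position right. In shorthand, we write them as \frac{B_i}{i}, \frac{B_i}{i-1}, and \frac{B_i}{i+1}.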
Also build prefix sums on these three arrays.
Then, when processing an update,
- Some prefix and some suffix will use the \frac{B_i}{i} values.
- The remaining subarray will use one of the other two.
Either way, all three parts can be obtained in \mathcal{O}(1) time using prefix sums.
Finding the appropriate prefix and suffix is a simple exercise in binary search.
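Here is a sketch of the prefix-sum update in Python. It rests on a few simplifying assumptions: exact integers instead of arithmetic modulo M (so N!/i is an ordinary integer division), 0-based query indices, and a hypothetical class name `UpdateQueries`. `bisect` finds the boundaries of the shifted block. Each query then costs three prefix-sum differences plus the new term for x.

```python
import bisect
from math import factorial


class UpdateQueries:
    """Answers "set A[idx] := x, then report g(A)" with O(log N) work per query.
    Exact integers, no modulus; updates are not persistent."""

    def __init__(self, a):
        self.a = list(a)
        self.b = sorted(a)
        self.n = n = len(a)
        self.fact = f = factorial(n)
        # Prefix sums over 1-based sorted positions; entry [i] covers 1..i.
        self.same = [0] * (n + 1)  # B_i * N!/i      (position unchanged)
        self.down = [0] * (n + 1)  # B_i * N!/(i-1)  (moved one slot left)
        self.up = [0] * (n + 1)    # B_i * N!/(i+1)  (moved one slot right)
        for i in range(1, n + 1):
            v = self.b[i - 1]
            self.same[i] = self.same[i - 1] + v * (f // i)
            self.down[i] = self.down[i - 1] + (v * (f // (i - 1)) if i > 1 else 0)
            self.up[i] = self.up[i - 1] + (v * (f // (i + 1)) if i < n else 0)

    def query(self, idx, x):
        """idx is 0-based; returns g(A) with A[idx] replaced by x."""
        y, n, f = self.a[idx], self.n, self.fact
        s, d, u = self.same, self.down, self.up
        if x == y:
            return s[n]
        if x > y:
            lo = bisect.bisect_right(self.b, y)  # drop y from position lo
            hi = bisect.bisect_left(self.b, x)   # lo+1..hi shift left, x lands at hi
            return s[lo - 1] + (d[hi] - d[lo]) + (s[n] - s[hi]) + x * (f // hi)
        lo = bisect.bisect_right(self.b, x)      # x lands at position lo+1
        hi = bisect.bisect_left(self.b, y) + 1   # drop y from position hi
        # positions lo+1..hi-1 shift right
        return s[lo] + (u[hi - 1] - u[lo]) + (s[n] - s[hi]) + x * (f // (lo + 1))


print(UpdateQueries([3, 1, 2]).query(1, 4))  # 29
```

For the modulo-M version, each `f // pos` would be replaced by a precomputed N!/pos mod M, and the prefix sums would be reduced mod M.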
We’re only left with one problem now: the answer must be found modulo M, but division might not be possible if M is composite!
Getting around this is quite simple, however.
Observe that we only really need division to find \frac{N!}{i} for some 1\leq i \leq N.
Looking back at how we obtained this, it equals (i-1)!\cdot (i+1)\cdot (i+2)\cdot \ldots\cdot N.
Let’s break this into two parts: the (i-1)! part, and the (i+1)\cdot (i+2)\cdot\ldots\cdot N part.
Each of these can be precomputed for every i, requiring only multiplication. Then, combining appropriate prefix and suffix products for each i requires only multiplication again, so we’ve avoided division entirely!
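Below is a minimal sketch of the division-free computation; the helper name `multipliers_mod` is hypothetical. It builds prefix products i! and suffix products i \cdot (i+1) \cdots N modulo m, then multiplies (i-1)! by (i+1) \cdots N. The cross-check uses the composite modulus 100, where even numbers have no modular inverse.

```python
from math import factorial


def multipliers_mod(n, m):
    """Return mult with mult[i] = N!/i mod m for 1 <= i <= n, using only
    multiplication, so m may be composite. mult[0] is unused."""
    pre = [1 % m] * (n + 1)   # pre[i] = i! mod m
    for i in range(1, n + 1):
        pre[i] = pre[i - 1] * i % m
    suf = [1 % m] * (n + 2)   # suf[i] = i * (i+1) * ... * n mod m; suf[n+1] = 1
    for i in range(n, 0, -1):
        suf[i] = suf[i + 1] * i % m
    # N!/i = (i-1)! * (i+1) * ... * N
    return [0] + [pre[i - 1] * suf[i + 1] % m for i in range(1, n + 1)]


# Cross-check against exact integer division for a composite modulus.
n, m = 6, 100
assert multipliers_mod(n, m)[1:] == [factorial(n) // i % m for i in range(1, n + 1)]
print(multipliers_mod(n, m)[1:])  # [20, 60, 40, 80, 44, 20]
```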
TIME COMPLEXITY:
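\mathcal{O}((N + Q)\log N) per testcase: sorting takes \mathcal{O}(N\log N), and each query performs a constant number of binary searches.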
CODE:
Author's code (C++)
```cpp
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    ll t;cin>>t;
    while(t--){
        ll n,q,m;cin>>n>>q>>m;
        ll Mod = m;
        vector<ll> v(n+1);
        vector<ll> b(1);
        for(ll i = 1; i <= n; i++){
            cin>>v[i];
            b.push_back(v[i]);
        }
        // coeff[i] = N!/i mod M = (i-1)! * (i+1) * ... * N, computed without division.
        // Sized n+2 (zero-filled) because bip[n] below reads coeff[n+1].
        vector<ll> coeff(n+2, 0);
        vector<ll> pref(n+1,1); // pref[i] = i! mod M
        vector<ll> suf(n+2,1);  // suf[i] = i * (i+1) * ... * N mod M
        for(ll i = 1; i <= n; i++ ){
            pref[i] = pref[i-1]*i;
            pref[i] %= Mod;
        }
        for(ll i = n; i >= 1; i-- ){
            suf[i] = suf[i+1]*i;
            suf[i] %= Mod;
        }
        coeff[1] = suf[2];
        coeff[n] = pref[n-1];
        for(ll i = 2; i <= n-1 ; i++){
            coeff[i] = (pref[i-1] * suf[i+1])%Mod;
        }
        sort(b.begin()+1,b.end());
        vector<ll> bi(n+1); // stores prefix sum bi/i
        vector<ll> bip(n+1); // stores prefix sum bi/(i+1)
        vector<ll> bim(n+1); // stores prefix sum bi/(i-1);
        for(ll i = 1; i <= n; i++){
            bi[i] = (bi[i-1] + b[i]*coeff[i])%Mod;
            bip[i] = (bip[i-1] + b[i]*coeff[i+1])%Mod;
            if(i > 1) bim[i] = (bim[i-1] + b[i]*coeff[i-1])%Mod;
        }
        for(ll i = 1; i<= q;i++){
            ll pos, x;cin>>pos>>x;
            ll curr = v[pos];
            ll ans = 0;
            if(curr == x){
                ans=(bi[n])%Mod;
            }else if(curr > x){
                // positions newPos..index-1 shift right (bip); x lands at newPos
                ll index = lower_bound(b.begin(),b.end(),curr)-b.begin();
                ll newPos = lower_bound(b.begin(),b.end(),x)-b.begin();
                ans = (((bi[newPos-1] + bi.back() - bi[index] + bip[index-1] - bip[newPos-1] + (x*coeff[newPos])%Mod)%Mod)%Mod + Mod)%Mod;
            }else if(curr < x){
                // positions index+1..newPos shift left (bim); x lands at newPos
                ll index = lower_bound(b.begin(),b.end(),curr)-b.begin();
                ll newPos = lower_bound(b.begin(),b.end(),x)-b.begin()-1;
                ans = (((bi[index-1] + bi.back() - bi[newPos] + bim[newPos] - bim[index] + (x*coeff[newPos])%Mod)%Mod)%Mod + Mod)%Mod;
            }
            cout<<ans<<"\n";
        }
    }
}
```
Tester's code (C++)
```cpp
//#pragma GCC optimize("O3")
//#pragma GCC optimize("Ofast")
//#pragma GCC optimize("unroll-loops")
//#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
struct custom_hash {
    static uint64_t splitmix64(uint64_t x) {
        // http://xorshift.di.unimi.it/splitmix64.c
        x += 0x9e3779b97f4a7c15;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
        x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
        return x ^ (x >> 31);
    }
    size_t operator()(uint64_t x) const {
        static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
        return splitmix64(x + FIXED_RANDOM);
    }
};
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<typename T>
T rand(T a, T b){
    return uniform_int_distribution<T>(a, b)(rng);
}
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<pair<int,int>, null_type, less<pair<int,int>>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef long long ll;
typedef long double ld;
typedef vector<ll> vl;
typedef vector<int> vi;
#define rep(i, a, b) for(int i = a; i < b; i++)
#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
//#define int long long
ll mod;
const ll INF = 1e18;
int norm (int x) {
    if (x < 0) x += mod;
    if (x >= mod) x -= mod;
    return x;
}
template<class T>
T power(T a, int b) {
    T res = 1;
    for (; b; b /= 2, a *= a) {
        if (b & 1) res *= a;
    }
    return res;
}
struct Z {
    int x;
    Z(int x = 0) : x(norm(x)) {}
    int val() const {
        return x;
    }
    Z operator-() const {
        return Z(norm(mod - x));
    }
    Z inv() const {
        return power(*this, mod - 2);
    }
    Z &operator*=(const Z &rhs) {
        x = 1ll * x * rhs.x % mod;
        return *this;
    }
    Z &operator+=(const Z &rhs) {
        x = norm(x + rhs.x);
        return *this;
    }
    Z &operator-=(const Z &rhs) {
        x = norm(x - rhs.x);
        return *this;
    }
    Z &operator/=(const Z &rhs) {
        return *this *= rhs.inv();
    }
    friend Z operator*(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res *= rhs;
        return res;
    }
    friend Z operator+(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res += rhs;
        return res;
    }
    friend Z operator-(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res -= rhs;
        return res;
    }
    friend Z operator/(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res /= rhs;
        return res;
    }
    friend std::istream &operator>>(std::istream &is, Z &a) {
        int v;
        is >> v;
        a = Z(v);
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const Z &a) {
        return os << a.val();
    }
};
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int t;
    cin >> t;
    while (t--) {
        int n, q;
        cin >> n >> q >> mod;
        int a[n];
        for (auto &x : a) cin >> x;
        Z pref[n + 1], suf[n + 2];
        pref[0] = 1;
        for (int i = 1; i <= n; i++) pref[i] = pref[i - 1] * i;
        suf[n + 1] = 1;
        for (int i = n; i >= 0; i--) suf[i] = suf[i + 1] * i;
        array<int, 2> b[n];
        int pos[n];
        for (int i = 0; i < n; i++) b[i] = {a[i], i};
        sort(b, b + n);
        for (int i = 0; i < n; i++) pos[b[i][1]] = i;
        Z ans1[n], ans2[n], ans3[n];
        for (int i = 0; i < n; i++) {
            ans1[i] = pref[i] * suf[i + 2] * b[i][0];
            if (i) ans2[i] = pref[i - 1] * suf[i + 1] * b[i][0];
            if (i < n - 1) ans3[i] = pref[i + 1] * suf[i + 3] * b[i][0];
        }
        Z p1[n], p2[n], p3[n];
        for (int i = 0; i < n; i++) p1[i] = ans1[i], p2[i] = ans2[i], p3[i] = ans3[i];
        for (int i = 1; i < n; i++) {
            p1[i] += p1[i - 1];
            p2[i] += p2[i - 1];
            p3[i] += p3[i - 1];
        }
        while (q--) {
            int i, x;
            cin >> i >> x;
            i--;
            i = pos[i];
            array<int, 2> f = {x, -1};
            int j = lower_bound(b, b + n, f) - b;
            if (j == i) j++;
            if (j > i) {
                Z ans = (i > 0 ? p1[i - 1] : Z(0)) + p1[n - 1] - p1[j - 1] + p2[j - 1] - p2[i];
                j--;
                ans += pref[j] * suf[j + 2] * x;
                cout << ans << "\n";
            }
            else {
                Z ans = (j > 0 ? p1[j - 1] : Z(0)) + p1[n - 1] - p1[i] + p3[i - 1] - (j > 0 ? p3[j - 1] : Z(0));
                ans += pref[j] * suf[j + 2] * x;
                cout << ans << "\n";
            }
        }
    }
}
```
Editorialist's code (Python)
```python
import bisect
for _ in range(int(input())):
    n, q, mod = map(int, input().split())
    a = list(map(int, input().split()))
    b = sorted(a)
    pprod, sprod = [1]*(n+1), [1]*(n+2)
    for i in range(1, n+1):
        pprod[i] = pprod[i-1] * i % mod
    for i in reversed(range(n+1)):
        sprod[i] = sprod[i+1] * i % mod
    pre1 = [0]*n
    pre2 = [0]*n
    pre3 = [0]*n
    for i in range(n):
        pre1[i] = b[i] * pprod[i] % mod * sprod[i+2] % mod
        if i+1 < n: pre3[i] = b[i] * pprod[i+1] % mod * sprod[i+3] % mod
        if i > 0:
            pre1[i] = (pre1[i] + pre1[i-1]) % mod
            pre3[i] = (pre3[i] + pre3[i-1]) % mod
            pre2[i] = (pre2[i-1] + b[i] * pprod[i-1] % mod * sprod[i+1]) % mod
    for ii in range(q):
        i, x = map(int, input().split())
        i -= 1
        if x > a[i]:
            pref = bisect.bisect_right(b, a[i])
            suf = bisect.bisect_left(b, x)
            ans = pre1[-1] - pre1[suf-1] + pre2[suf-1] + pre1[pref-1] - pre2[pref-1]
            ans -= a[i] * pprod[pref-1] % mod * sprod[pref+1] % mod
            ans += x * pprod[suf-1] % mod * sprod[suf+1] % mod
            print(ans % mod)
        elif x < a[i]:
            pref = bisect.bisect_right(b, x)
            suf = bisect.bisect_left(b, a[i])
            ans = pre1[-1]
            if suf > 0:
                ans += pre3[suf-1] - pre1[suf-1]
            if pref > 0:
                ans += pre1[pref-1] - pre3[pref-1]
            ans -= a[i] * pprod[suf] % mod * sprod[suf+2] % mod
            ans += x * pprod[pref] % mod * sprod[pref+2] % mod
            print(ans % mod)
        else: print(pre1[-1])
```