PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: satyam_343
Tester: yash_daga
Editorialist: iceknight1093
DIFFICULTY:
3033
PREREQUISITES:
Combinatorics
PROBLEM:
The beauty of an array B of length M equals \sum_{i=1}^{M} |B_i - B_{M+1-i}|.
The score of an array B equals the maximum beauty across all its permutations.
Given N and X, find the sum of scores of all X^N arrays of length N whose elements lie between 1 and X (inclusive).
EXPLANATION:
First, let’s figure out what the score of an array is.
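Notice that indices i and M+1-i are symmetric about the center of the array, so the indices pair up. The absolute difference of each pair is counted twice: once at index i and once at index M+1-i. When M is odd, the middle index is paired with itself and contributes 0.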
Consider some index i, with B_i = x and B_{M+1-i} = y.
Then, these indices contribute 2 \cdot |x-y| = 2\max(x, y) - 2\min(x, y) to the beauty of the array.
We only care about arrays of length N, so take M = N. Each of the \left\lfloor \frac{N}{2} \right\rfloor pairs counts its larger element with coefficient +2 and its smaller element with coefficient -2. When N is odd, the middle element gets coefficient 0. So from a global perspective, exactly \left\lfloor \frac{N}{2} \right\rfloor elements are counted positively and exactly \left\lfloor \frac{N}{2} \right\rfloor elements are counted negatively.
To maximize this, the best we can do is to have the largest \left\lfloor \frac{N}{2} \right\rfloor values counted positively and the smallest \left\lfloor \frac{N}{2} \right\rfloor values counted negatively.
This is easy to achieve: simply sort the array! The i-th smallest element is then paired with the i-th largest.
The above discussion tells us exactly what the score of a fixed array B is — it’s twice the sum of the largest \left\lfloor \frac{N}{2} \right\rfloor elements, minus twice the sum of the smallest \left\lfloor \frac{N}{2} \right\rfloor elements.
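The characterization above can be sanity-checked with a short Python sketch. It compares the sorted-array formula against a brute force over all permutations. The helper names (`beauty`, `score_sorted`, `score_brute`) are hypothetical, and the brute force is only feasible for tiny arrays.

```python
from itertools import permutations


def beauty(b):
    """Sum of |B_i - B_{M+1-i}| over all i (0-indexed: b[i] vs b[m-1-i])."""
    m = len(b)
    return sum(abs(b[i] - b[m - 1 - i]) for i in range(m))


def score_sorted(b):
    """Twice the sum of the largest floor(m/2) elements minus
    twice the sum of the smallest floor(m/2) elements."""
    s = sorted(b)
    k = len(s) // 2
    return 2 * sum(s[len(s) - k:]) - 2 * sum(s[:k])


def score_brute(b):
    """Maximum beauty over all permutations (exponential, tiny inputs only)."""
    return max(beauty(p) for p in permutations(b))


print(score_sorted([3, 1, 2]), score_brute([3, 1, 2]))  # 4 4
```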
Let’s use this information to solve the problem.
Suppose we were able to calculate \text{ct}[i][y], the number of arrays whose i-th smallest element is y.
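This would allow us to solve the problem quite easily. For each i and y, the answer increases by 2\cdot y\cdot \text{ct}[i][y] if i > \left\lceil \frac{N}{2} \right\rceil, that is, if the i-th smallest element is among the largest \left\lfloor \frac{N}{2} \right\rfloor. It decreases by 2\cdot y\cdot \text{ct}[i][y] if i \leq \left\lfloor \frac{N}{2} \right\rfloor. Otherwise it is unchanged, which only happens for the middle element when N is odd.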
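The sign rule can be checked with a small Python sketch. It fills \text{ct} by brute-force enumeration, so it is only for tiny N and X, and compares the weighted sum against summing scores array by array. Results are exact integers rather than reduced modulo 998244353. The names `coef`, `answer_via_ct` and `answer_direct` are hypothetical.

```python
from collections import Counter
from itertools import product


def coef(i, n):
    """Coefficient of the i-th smallest element (1-indexed) in the score."""
    k = n // 2
    if i <= k:        # among the smallest floor(n/2) elements
        return -2
    if i > n - k:     # among the largest floor(n/2) elements
        return 2
    return 0          # middle element when n is odd


def answer_via_ct(n, x):
    """Sum over (i, y) of coef(i) * y * ct[i][y], with ct from enumeration."""
    ct = Counter()
    for arr in product(range(1, x + 1), repeat=n):
        for i, y in enumerate(sorted(arr), start=1):
            ct[(i, y)] += 1
    return sum(coef(i, n) * y * c for (i, y), c in ct.items())


def answer_direct(n, x):
    """Sum of scores computed array by array from the sorted-array formula."""
    k = n // 2
    total = 0
    for arr in product(range(1, x + 1), repeat=n):
        s = sorted(arr)
        total += 2 * sum(s[n - k:]) - 2 * sum(s[:k])
    return total


print(answer_via_ct(2, 2), answer_direct(2, 2))  # 4 4
```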
However, directly computing \text{ct}[i][y] is not easy.
Instead, let’s try to relax the problem a little.
Let g[i][y] be the number of arrays whose i-th smallest element is \leq y.
That is, g[i][y] is the number of arrays containing at least i elements that are \leq y.
Then, \text{ct}[i][y] = g[i][y] - g[i][y-1], so if we’re able to compute g we can easily compute \text{ct}.
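The differencing identity can be confirmed on small cases with a brute-force Python sketch that builds both tables directly. The name `brute_tables` is hypothetical. Note that g[i][0] = 0 for i \geq 1, since no element is \leq 0.

```python
from itertools import product


def brute_tables(n, x):
    """Return (g, ct) by enumerating all x**n arrays (tiny n, x only).
    g[i][y]:  arrays whose i-th smallest element is <= y
    ct[i][y]: arrays whose i-th smallest element is exactly y"""
    g = [[0] * (x + 1) for _ in range(n + 1)]
    ct = [[0] * (x + 1) for _ in range(n + 1)]
    for arr in product(range(1, x + 1), repeat=n):
        s = sorted(arr)
        for i in range(1, n + 1):
            ct[i][s[i - 1]] += 1
            # the i-th smallest element is <= y for every y >= s[i-1]
            for y in range(s[i - 1], x + 1):
                g[i][y] += 1
    return g, ct


g, ct = brute_tables(3, 3)
assert all(ct[i][y] == g[i][y] - g[i][y - 1]
           for i in range(1, 4) for y in range(1, 4))
```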
However, once again it’s not clear how we would compute g. This time, there’s too much freedom: sure, we can place i elements \leq y in the array, but there might be other elements \leq y as well. How to account for those?
Well, we can try adding that as a constraint.
Let h[i][y] denote the number of arrays whose i-th smallest element is \leq y and whose (i+1)-th smallest element (if it exists) is \gt y.
In other words, h[i][y] is the number of arrays with exactly i elements \leq y.
h[i][y] is in fact quite easy to compute:
- Fix the positions of the i elements that are \leq y. This can be done in \binom{N}{i} ways.
- Each of these i positions has y choices, since we can place any of 1, 2, 3, \ldots, y there. This is y^i ways.
- Each of the other N-i positions has X-y choices, since we can place any of y+1, y+2, \ldots, X there. This is (X-y)^{N-i} ways.
So, h[i][y] = \binom{N}{i} \cdot y^i\cdot (X-y)^{N-i}. This can be computed in \mathcal{O}(\log N) with fast exponentiation, or in \mathcal{O}(1) after precomputing factorials and a table of powers.
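The counting argument can be checked with a short Python sketch. It evaluates the closed form with exact integers and compares it to direct enumeration. `math.comb` requires Python 3.8+. A modular version would instead use `pow(b, e, 998244353)` and precomputed factorials. The helper names are hypothetical.

```python
from itertools import product
from math import comb


def h_formula(n, x, i, y):
    """Arrays with exactly i elements <= y: choose the i positions,
    fill them with values in 1..y, fill the rest with values in y+1..x."""
    return comb(n, i) * y ** i * (x - y) ** (n - i)


def h_brute(n, x, i, y):
    """Same count by direct enumeration (tiny n, x only)."""
    return sum(1 for arr in product(range(1, x + 1), repeat=n)
               if sum(v <= y for v in arr) == i)


print(h_formula(3, 4, 2, 2), h_brute(3, 4, 2, 2))  # 24 24
```

Python evaluates `0 ** 0` as 1, so the edge cases y = X with i = N and y = 0 with i = 0 need no special handling.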
Since h[i][y] is the number of arrays with exactly i elements \leq y, and g[i][y] is the number of arrays with at least i elements \leq y, we have
g[i][y] = \sum_{j=i}^{N} h[j][y].
This is a suffix sum of the h array over i, taken separately for each y (column-wise), so it is easy to compute once h is known.
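The suffix sum can be sketched in Python with exact integers. The function name `build_g` is hypothetical. The assertion checks one closed form that follows from the definitions: "at least one element \leq y" is everything minus "no element \leq y".

```python
from math import comb


def build_g(n, x):
    """g[i][y] = sum_{j=i}^{n} h[j][y], built as a suffix sum over i."""
    # h[i][y] = C(n, i) * y^i * (x - y)^(n - i): exactly i elements <= y
    h = [[comb(n, i) * y ** i * (x - y) ** (n - i) for y in range(x + 1)]
         for i in range(n + 1)]
    g = [[0] * (x + 1) for _ in range(n + 2)]   # row n+1 stays 0
    for i in range(n, -1, -1):
        for y in range(x + 1):
            g[i][y] = g[i + 1][y] + h[i][y]
    return g


g = build_g(3, 4)
assert all(g[1][y] == 4 ** 3 - (4 - y) ** 3 for y in range(5))
```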
So,
- Compute all the h[i][y] values, which can be done in \mathcal{O}(N\cdot X).
- Take column-wise suffix sums of h to obtain the g[i][y] values, which once again takes \mathcal{O}(N\cdot X) time.
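- From the g[i][y] values, compute \text{ct}[i][y] = g[i][y] - g[i][y-1] (with g[i][0] = 0). Then add 2\cdot y\cdot \text{ct}[i][y] to the answer or subtract it from the answer, depending on i as described above.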
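The three steps can be put together in a minimal Python sketch, checked against brute force on tiny inputs. It assumes the answer is wanted modulo 998244353, the modulus all three reference solutions use. It keeps only one row of g at a time, similar in spirit to the editorialist's rolling array. The names `solve` and `brute` are hypothetical.

```python
from itertools import product

MOD = 998244353  # modulus used by all three reference solutions


def solve(n, x):
    """Sum of scores over all arrays of length n with values in 1..x, mod MOD."""
    # C[i] = C(n, i) mod MOD via C(n, i) = C(n, i-1) * (n - i + 1) / i
    C = [1] * (n + 1)
    for i in range(1, n + 1):
        C[i] = C[i - 1] * (n - i + 1) % MOD * pow(i, MOD - 2, MOD) % MOD
    k = n // 2
    ans = 0
    g_next = [0] * (x + 1)                 # g[n+1][*] = 0
    for i in range(n, 0, -1):
        g = [0] * (x + 1)                  # g[i][0] = 0 for i >= 1
        for y in range(1, x + 1):
            h = C[i] * pow(y, i, MOD) % MOD * pow(x - y, n - i, MOD) % MOD
            g[y] = (g_next[y] + h) % MOD   # suffix sum over i
        sign = -2 if i <= k else (2 if i > n - k else 0)
        if sign:
            for y in range(1, x + 1):
                ct = g[y] - g[y - 1]       # arrays whose i-th smallest is y
                ans = (ans + sign * y * ct) % MOD
        g_next = g
    return ans


def brute(n, x):
    """Direct enumeration with the sorted-array score (tiny n, x only)."""
    k = n // 2
    total = 0
    for arr in product(range(1, x + 1), repeat=n):
        s = sorted(arr)
        total += 2 * sum(s[n - k:]) - 2 * sum(s[:k])
    return total % MOD


print(solve(2, 2), brute(2, 2))  # 4 4
```

For brevity, the `pow` calls add a logarithmic factor. Precomputing a table of powers, as the reference solutions do, removes it and gives the \mathcal{O}(N\cdot X) bound.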
TIME COMPLEXITY
\mathcal{O}(N\cdot X) per test case.
CODE:
Setter's code (C++)
```cpp
#pragma GCC optimization("O3")
#pragma GCC optimization("Ofast,unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
void _print(ll x){cerr<<x;}
void _print(int x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;
const ll MAX=5005;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if((a<0)||(a<b)||(b<0))
        return 0;
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;
}
ll power_val[MAX][MAX];
void solve(){
    ll n,x; cin>>n>>x;
    ll ans=0;
    for(ll i=1;i<x;i++){
        for(ll j=1;j<=n;j++){
            ll ways=(power_val[i][j]*power_val[x-i][n-j])%MOD;
            ways=(ways*nCr(n,j,MOD))%MOD;
            ans=(ans+2ll*min(j,n-j)*ways)%MOD;
        }
    }
    cout<<ans<<nline;
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases=1;
    cin>>test_cases;
    precompute(MOD);
    for(ll i=0;i<MAX;i++){
        power_val[i][0]=1;
        for(ll j=1;j<MAX;j++){
            power_val[i][j]=(power_val[i][j-1]*i)%MOD;
        }
    }
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
}
```
Tester's code (C++)
```cpp
//clear adj and visited vector declared globally after each test case
//check for long long overflow
//Incase of close mle change language to c++17 or c++14
//Check ans for n=1
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#define int long long
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back
#define lld long double
#define mii map<int, int>
#define pii pair<int, int>
#define ll long long
#define ff first
#define ss second
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
#define prev prev2
const long long N=5005, INF=2000000000000000000;
const int inf=2e9 + 5, mod=998244353;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
{
    if(a==0)
        return 0;
    int res=1;
    a%=p;
    while(b>0)
    {
        if(b&1)
            res=(1ll*res*a)%p;
        b>>=1;
        a=(1ll*a*a)%p;
    }
    return res;
}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}
int po[N][N];
int fact[N], inv[N];
void pre()
{
    fact[0]=inv[0]=1;
    rep(i,1,N)
        fact[i]=(fact[i-1]*i)%mod;
    rep(i,1,N)
        inv[i]=power(fact[i], mod-2, mod);
}
int nCr(int n, int r)
{
    if(min(n, r)<0 || r>n)
        return 0;
    if(n==r)
        return 1;
    return (((fact[n]*inv[r])%mod)*inv[n-r])%mod;
}
int32_t main()
{
    IOS;
    pre();
    rep(i,0,N)
    {
        po[i][0]=1;
        rep(j,1,N)
            po[i][j]=(po[i][j-1]*i)%mod;
    }
    int t;
    cin>>t;
    while(t--)
    {
        int n, x, ans=0;
        cin>>n>>x;
        rep(i,1,x)
        {
            rep(j,1,n)
            {
                int mul=(po[i][j]*po[x-i][n-j]%mod*nCr(n, j))%mod;
                ans=(ans + (2ll*mul*min(j, n-j)))%mod;
            }
        }
        cout<<ans<<"\n";
    }
}
```
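The setter's and tester's loops do not build \text{ct} explicitly. As we read their code, for each threshold y from 1 to X-1 and each j they add 2\cdot\min(j, N-j) times h[j][y], the number of arrays with exactly j elements \leq y. This rests on an identity. Write each value v as 1 plus the number of thresholds t in [1, X-1] with v > t. Then, at each threshold t, the score counts how many of the top \left\lfloor \frac{N}{2} \right\rfloor elements exceed t minus how many of the bottom \left\lfloor \frac{N}{2} \right\rfloor do. That difference works out to \min(j_t, N-j_t), where j_t is the number of elements \leq t. The Python sketch below checks this numerically with exact integers (no modulus). The helper names are hypothetical.

```python
from itertools import product
from math import comb


def score_sorted(arr):
    """Twice (sum of largest floor(n/2)) minus twice (sum of smallest floor(n/2))."""
    s = sorted(arr)
    k = len(s) // 2
    return 2 * sum(s[len(s) - k:]) - 2 * sum(s[:k])


def score_by_thresholds(arr, x):
    """Sum over thresholds t = 1..x-1 of 2 * min(j, n - j),
    where j is the number of elements <= t."""
    n = len(arr)
    total = 0
    for t in range(1, x):
        j = sum(v <= t for v in arr)
        total += 2 * min(j, n - j)
    return total


def answer_threshold_style(n, x):
    """Mirrors the setter's/tester's double loop, without the modulus:
    C(n, j) * y^j * (x - y)^(n - j) arrays have exactly j elements <= y."""
    return sum(2 * min(j, n - j) * comb(n, j) * y ** j * (x - y) ** (n - j)
               for y in range(1, x) for j in range(1, n + 1))


print(score_by_thresholds([3, 1, 2], 3), score_sorted([3, 1, 2]))  # 4 4
```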
Editorialist's code (Python)
```python
mod = 998244353
lim = 5005

# pows[i][j] = j^i mod p (exponent first, base second); pows[0][0] = 1
pows = [ [0 for i in range(lim)] for j in range(lim) ]
pows[0][0] = 1
for j in range(1, lim):
    pows[0][j] = 1
    for i in range(1, lim):
        pows[i][j] = (pows[i-1][j] * j) % mod

# two rolling rows: dp[i%2][y] holds g[i][y]; index 0 stays 0
dp = [ [0 for i in range(lim)] for _ in range(2) ]
inv = [1]*lim
for i in range(1, lim): inv[i] = pow(i, mod-2, mod)

for _ in range(int(input())):
    n, x = map(int, input().split())
    C = 1  # C(n, i), starting from C(n, n) = 1
    ans = 0
    for i in reversed(range(1, n+1)):
        cur, prv = dp[i%2], dp[(i+1)%2]
        p1, p2 = pows[i], pows[n-i]
        for y in range(1, x+1):
            # h[i][y] = C(n, i) * y^i * (x-y)^(n-i)
            cur[y] = C * p1[y] % mod * p2[x-y] % mod
            # g[i][y] = h[i][y] + g[i+1][y]
            if i < n: cur[y] = (cur[y] + prv[y]) % mod
            # the middle element (n odd, i = (n+1)/2) contributes nothing
            if 2*i != n+1:
                mul = 2 if 2*i > n else -2
                ways = cur[y] - cur[y-1]  # ct[i][y]
                ans += mul * ways % mod * y % mod
                ans %= mod
        C = C * i % mod * inv[n-i+1] % mod  # C(n, i-1) from C(n, i)
    print(ans)
```