SUMOVERALL - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: yash_daga
Editorialist: iceknight1093

DIFFICULTY:

3033

PREREQUISITES:

Combinatorics

PROBLEM:

The beauty of an array B of length M equals \sum_{i=1}^M |B_i - B_{M+1-i}|.
The score of an array B equals the maximum beauty across all permutations of B.

Given N and X, find the sum of scores of all X^N arrays of length N whose elements are integers between 1 and X (inclusive), modulo 998244353.
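For tiny inputs, the definitions can be checked directly. The Python sketch below is only a brute-force reference and not the intended solution: it enumerates all X^N arrays and all permutations of each, so it is usable only for very small N and X. The helper names `beauty`, `score_brute` and `answer_brute` are illustrative, and the result is the exact integer, without the modulus.

```python
from itertools import permutations, product

def beauty(b):
    # Sum of |B_i - B_{M+1-i}| for i = 1..M, written with 0-based indices.
    m = len(b)
    return sum(abs(b[i] - b[m - 1 - i]) for i in range(m))

def score_brute(b):
    # Maximum beauty over all permutations of b.
    return max(beauty(p) for p in permutations(b))

def answer_brute(n, x):
    # Sum of scores over all x**n arrays of length n with entries in [1, x].
    return sum(score_brute(b) for b in product(range(1, x + 1), repeat=n))
```

For example, with N = 2 and X = 2, the arrays (1, 2) and (2, 1) each have score 2 and the two constant arrays have score 0, so `answer_brute(2, 2)` returns 4.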

EXPLANATION:

First, let’s figure out what the score of an array is.

Notice that indices i and M+1-i are symmetric about the center of the array, so the indices pair up, and the difference within each pair is counted twice (once for i and once for M+1-i).
Consider a pair of distinct indices i and M+1-i, with B_i = x and B_{M+1-i} = y.
Together, these two indices contribute 2 \cdot |x-y| = 2\max(x, y) - 2\min(x, y) to the beauty of the array.

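All arrays in this problem have length N, so from now on M = N. There are \left\lfloor \frac{N}{2} \right\rfloor such pairs, so from a global perspective, exactly \left\lfloor \frac{N}{2} \right\rfloor elements contribute positively (with coefficient +2) to the beauty, and exactly \left\lfloor \frac{N}{2} \right\rfloor elements contribute negatively (with coefficient -2). When N is odd, the middle element is paired with itself and contributes 0.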

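So every arrangement has beauty equal to twice the sum of some \left\lfloor \frac{N}{2} \right\rfloor elements minus twice the sum of \left\lfloor \frac{N}{2} \right\rfloor other elements. This is at most twice the sum of the largest \left\lfloor \frac{N}{2} \right\rfloor values minus twice the sum of the smallest \left\lfloor \frac{N}{2} \right\rfloor values, so that is the best we can hope for.
This bound is easily achieved: simply sort the array! Then for every i \leq \left\lfloor \frac{N}{2} \right\rfloor, B_i is among the smallest \left\lfloor \frac{N}{2} \right\rfloor values and B_{N+1-i} is among the largest \left\lfloor \frac{N}{2} \right\rfloor values.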

The above discussion tells us exactly what the score of a fixed array B is — it’s twice the sum of the largest \left\lfloor \frac{N}{2} \right\rfloor elements, minus twice the sum of the smallest \left\lfloor \frac{N}{2} \right\rfloor elements.
Let’s use this information to solve the problem.
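A minimal Python sketch of this closed form, assuming the array is given as a list of integers; `score` is a hypothetical helper name, and `score_brute` takes the maximum over all permutations purely as a cross-check on tiny arrays.

```python
from itertools import permutations

def beauty(b):
    # Sum of |B_i - B_{M+1-i}| for i = 1..M, written with 0-based indices.
    m = len(b)
    return sum(abs(b[i] - b[m - 1 - i]) for i in range(m))

def score(b):
    # Twice the sum of the largest floor(N/2) values minus twice the sum of the
    # smallest floor(N/2) values. s[len(s) - k:] is used instead of s[-k:]
    # because s[-0:] would be the whole list when k = 0 (N = 1).
    s = sorted(b)
    k = len(s) // 2
    return 2 * (sum(s[len(s) - k:]) - sum(s[:k]))

def score_brute(b):
    # Exhaustive maximum over all permutations, for cross-checking only.
    return max(beauty(p) for p in permutations(b))
```

For instance, `score([3, 1, 4, 1, 5])` sorts the list to `[1, 1, 3, 4, 5]` and returns 2 \cdot (4+5) - 2 \cdot (1+1) = 14, which is also the beauty of the sorted list itself.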

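Suppose we were able to calculate \text{ct}[i][y], the number of arrays whose i-th smallest element is y.
This would allow us to solve the problem quite easily: for each pair (i, y), the answer decreases by 2\cdot y\cdot \text{ct}[i][y] if i \leq \left\lfloor \frac{N}{2} \right\rfloor, increases by 2\cdot y\cdot \text{ct}[i][y] if i \gt N - \left\lfloor \frac{N}{2} \right\rfloor, and is unaffected when N is odd and i = \frac{N+1}{2} (the middle element).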
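To make the sign rule concrete, here is a small Python sketch that, assuming N and X are tiny enough to enumerate all X^N arrays, counts \text{ct}[i][y] by brute force and then weights each count by -2, +2 or 0 according to i. The names `coef` and `answer_from_ct` are hypothetical; the point is only to show how \text{ct} determines the answer.

```python
from itertools import product

def coef(i, n):
    # Coefficient of the i-th smallest element (1-indexed) in the score.
    k = n // 2
    if i <= k:
        return -2
    if i > n - k:
        return 2
    return 0  # middle element when n is odd

def answer_from_ct(n, x):
    # Brute-force ct[i][y]: number of arrays whose i-th smallest element equals y.
    ct = [[0] * (x + 1) for _ in range(n + 1)]
    for b in product(range(1, x + 1), repeat=n):
        for i, v in enumerate(sorted(b), start=1):
            ct[i][v] += 1
    # Combine the counts with the sign rule.
    return sum(coef(i, n) * y * ct[i][y]
               for i in range(1, n + 1) for y in range(1, x + 1))
```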

However, directly computing \text{ct}[i][y] is not easy.

Instead, let’s try to relax the problem a little.
Let g[i][y] be the number of arrays whose i-th smallest element is \leq y.
That is, g[i][y] is the number of arrays containing at least i elements that are \leq y.
Then, \text{ct}[i][y] = g[i][y] - g[i][y-1] (where g[i][0] = 0 for i \geq 1, since no element is \leq 0), so if we’re able to compute g we can easily compute \text{ct}.
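The relation between g and \text{ct} can be checked on small cases. This Python sketch computes g by brute force (an approach that only works for tiny N and X) and recovers \text{ct} by differencing along y; `g_brute` and `ct_from_g` are illustrative names.

```python
from itertools import product

def g_brute(n, x):
    # g[i][y] = number of arrays in [1, x]^n with at least i elements <= y.
    g = [[0] * (x + 1) for _ in range(n + 1)]
    for b in product(range(1, x + 1), repeat=n):
        for y in range(x + 1):
            cnt = sum(v <= y for v in b)
            for i in range(cnt + 1):
                g[i][y] += 1
    return g

def ct_from_g(g, n, x):
    # ct[i][y] = g[i][y] - g[i][y-1]: arrays whose i-th smallest element is exactly y.
    # Relies on g[i][0] = 0 for every i >= 1.
    ct = [[0] * (x + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for y in range(1, x + 1):
            ct[i][y] = g[i][y] - g[i][y - 1]
    return ct
```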

However, once again it’s not clear how we would compute g. This time, there’s too much freedom: sure, we can place i elements \leq y in the array, but there might be other elements \leq y as well. How to account for those?

Well, we can try adding that as a constraint.
Let h[i][y] denote the number of arrays whose i-th smallest element is \leq y, and whose (i+1)-th smallest element is \gt y (the second condition is dropped when i = N).
In other words, h[i][y] is the number of arrays with exactly i elements \leq y.

h[i][y] is in fact quite easy to compute:

  • Fix the positions of the i elements that are \leq y. This can be done in \binom{N}{i} ways.
  • Each of these i positions has y choices, since we can place any of 1, 2, 3, \ldots, y there.
    This is y^i ways.
  • Each of the other positions has X-y choices, since we can place any of y+1, y+2, \ldots, X there.
    This is (X-y)^{N-i} ways.

So, h[i][y] = \binom{N}{i} \cdot y^i\cdot (X-y)^{N-i}, which can be computed in \mathcal{O}(\log N) with fast exponentiation, or in \mathcal{O}(1) after precomputing factorials and a table of powers.
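As a sanity check, the following Python sketch compares the closed form for h with direct counting on tiny inputs. `h_formula` and `h_brute` are hypothetical names; `math.comb` (Python 3.8+) supplies the binomial coefficient, and Python's convention `0**0 == 1` gives the right count when i = N or y = X.

```python
from itertools import product
from math import comb

def h_formula(n, x, i, y):
    # Choose the i positions holding values in [1, y] (y choices each); the
    # remaining n - i positions hold values in [y+1, x] (x - y choices each).
    return comb(n, i) * y**i * (x - y)**(n - i)

def h_brute(n, x, i, y):
    # Count arrays over [1, x]^n with exactly i elements <= y.
    return sum(1 for b in product(range(1, x + 1), repeat=n)
               if sum(v <= y for v in b) == i)
```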

Since h[i][y] is the number of arrays with exactly i elements \leq y, and g[i][y] is the number of arrays with at least i elements \leq y, we have

g[i][y] = h[i][y] + h[i+1][y] + \ldots + h[N][y]

This is a column-wise suffix sum of the h array: for each fixed y, g[i][y] = h[i][y] + g[i+1][y] with g[N+1][y] = 0, so g is easy to compute once h is known.
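A minimal Python sketch of this step with exact integers (no modulus), assuming h is built directly from the closed form above; `g_from_h` is a hypothetical name. An extra all-zero row plays the role of g[N+1][\cdot].

```python
from math import comb

def g_from_h(n, x):
    # h[i][y] = C(n, i) * y^i * (x - y)^(n - i): arrays with exactly i elements <= y.
    h = [[comb(n, i) * y**i * (x - y)**(n - i) for y in range(x + 1)]
         for i in range(n + 1)]
    # Column-wise suffix sum: g[i][y] = h[i][y] + g[i+1][y], with g[n+1][y] = 0.
    g = [[0] * (x + 1) for _ in range(n + 2)]
    for i in range(n, -1, -1):
        for y in range(x + 1):
            g[i][y] = h[i][y] + g[i + 1][y]
    return g
```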

So,

  • Compute all the h[i][y] values, which can be done in \mathcal{O}(N\cdot X).
  • Take column-wise suffix sums of h to obtain the g[i][y] values, which once again takes \mathcal{O}(N\cdot X) time.
  • From the g[i][y] values, compute \text{ct}[i][y] = g[i][y] - g[i][y-1], and add or subtract 2\cdot y\cdot \text{ct}[i][y] according to the position i.
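Putting the three steps together, here is a minimal Python sketch of the whole computation modulo 998244353, the modulus used by the codes below. It is not the editorialist's code: for simplicity it calls Python's built-in three-argument `pow` instead of precomputing a power table (adding a logarithmic factor), and it keeps only two rows of g at a time. The function name `solve` is illustrative.

```python
MOD = 998244353

def solve(n, x):
    # Sum of scores of all arrays of length n with entries in [1, x], modulo MOD.
    binom = [1] * (n + 1)  # binom[i] = C(n, i) mod MOD
    for i in range(1, n + 1):
        binom[i] = binom[i - 1] * (n - i + 1) % MOD * pow(i, MOD - 2, MOD) % MOD
    k = n // 2
    g_next = [0] * (x + 1)  # g[i+1][*]; all zero when i = n
    ans = 0
    for i in range(n, 0, -1):
        g_cur = [0] * (x + 1)  # g[i][0] = 0 for i >= 1
        for y in range(1, x + 1):
            # h[i][y] = C(n, i) * y^i * (x - y)^(n - i); pow(0, 0, MOD) == 1
            h = binom[i] * pow(y, i, MOD) % MOD * pow(x - y, n - i, MOD) % MOD
            g_cur[y] = (h + g_next[y]) % MOD
        c = -2 if i <= k else (2 if i > n - k else 0)  # sign rule for position i
        if c:
            for y in range(1, x + 1):
                ans += c * y * (g_cur[y] - g_cur[y - 1])  # ct[i][y] = g[i][y] - g[i][y-1]
            ans %= MOD
        g_next = g_cur
    return ans % MOD
```

Processing i from N down to 1 lets each row of g be formed from the row below it, so memory stays \mathcal{O}(X) per test case.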

TIME COMPLEXITY

\mathcal{O}(N\cdot X) per test case, plus a one-time precomputation of a table of powers.

CODE:

Setter's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")

#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;  
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;  
#define pb push_back                 
#define mp make_pair            
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>             
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE      
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);    
#endif 
void _print(ll x){cerr<<x;}         
void _print(int x){cerr<<x;}  
void _print(char x){cerr<<x;}     
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;    
const ll MAX=5005;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;  
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){  
    return binpow(a,MOD-2,MOD); 
} 
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){  
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;  
    }          
}  
ll nCr(ll a,ll b,ll MOD){    
    if((a<0)||(a<b)||(b<0))    
        return 0;       
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;    
}  
ll power_val[MAX][MAX]; 
void solve(){  
    ll n,x; cin>>n>>x;
    ll ans=0;  
    for(ll i=1;i<x;i++){
        for(ll j=1;j<=n;j++){
            ll ways=(power_val[i][j]*power_val[x-i][n-j])%MOD; 
            ways=(ways*nCr(n,j,MOD))%MOD; 
            ans=(ans+2ll*min(j,n-j)*ways)%MOD;
        }
    }
    cout<<ans<<nline; 
    return;     
}                                                     
int main()                                                                                            
{                
    ios_base::sync_with_stdio(false);                             
    cin.tie(NULL);          
    #ifndef ONLINE_JUDGE                        
    freopen("input.txt", "r", stdin);                                                  
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif                            
    ll test_cases=1;                   
    cin>>test_cases;
    precompute(MOD);
    for(ll i=0;i<MAX;i++){
        power_val[i][0]=1;
        for(ll j=1;j<MAX;j++){
            power_val[i][j]=(power_val[i][j-1]*i)%MOD; 
        }
    }
    while(test_cases--){  
        solve();
    } 
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}   
Tester's code (C++)
//clear adj and visited vector declared globally after each test case
//check for long long overflow   
//Incase of close mle change language to c++17 or c++14  
//Check ans for n=1 
#pragma GCC target ("avx2")    
#pragma GCC optimize ("O3")  
#pragma GCC optimize ("unroll-loops")
#include <bits/stdc++.h>                   
#include <ext/pb_ds/assoc_container.hpp>  
#define int long long      
#define IOS std::ios::sync_with_stdio(false); cin.tie(NULL);cout.tie(NULL);cout.precision(dbl::max_digits10);
#define pb push_back 
#define lld long double
#define mii map<int, int> 
#define pii pair<int, int>
#define ll long long 
#define ff first
#define ss second 
#define all(x) (x).begin(), (x).end()
#define rep(i,x,y) for(int i=x; i<y; i++)    
#define fill(a,b) memset(a, b, sizeof(a))
#define vi vector<int>
#define setbits(x) __builtin_popcountll(x)
#define print2d(dp,n,m) for(int i=0;i<=n;i++){for(int j=0;j<=m;j++)cout<<dp[i][j]<<" ";cout<<"\n";}
typedef std::numeric_limits< double > dbl;
using namespace __gnu_pbds;
using namespace std;
typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> indexed_set;
//member functions :
//1. order_of_key(k) : number of elements strictly lesser than k
//2. find_by_order(k) : k-th element in the set
#define prev prev2
const long long N=5005, INF=2000000000000000000;
const int inf=2e9 + 5, mod=998244353;
lld pi=3.1415926535897932;
int lcm(int a, int b)
{
    int g=__gcd(a, b);
    return a/g*b;
}
int power(int a, int b, int p)
    {
        if(a==0)
        return 0;
        int res=1;
        a%=p;
        while(b>0)
        {
            if(b&1)
            res=(1ll*res*a)%p;
            b>>=1;
            a=(1ll*a*a)%p;
        }
        return res;
    }
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int getRand(int l, int r)
{
    uniform_int_distribution<int> uid(l, r);
    return uid(rng);
}

int po[N][N];
int fact[N], inv[N];
void pre()
{
    fact[0]=inv[0]=1;
    rep(i,1,N)
    fact[i]=(fact[i-1]*i)%mod;
    rep(i,1,N)
    inv[i]=power(fact[i], mod-2, mod);
}
int nCr(int n, int r)
{
    if(min(n, r)<0 || r>n)
    return 0;
    if(n==r)
    return 1;
    return (((fact[n]*inv[r])%mod)*inv[n-r])%mod;
}

int32_t main()
{
    IOS;
    pre();
    rep(i,0,N)
    {
        po[i][0]=1;
        rep(j,1,N)
        po[i][j]=(po[i][j-1]*i)%mod;
    }
    int t;
    cin>>t;
    while(t--)
    {
        int n, x, ans=0;
        cin>>n>>x;
        rep(i,1,x)
        {
            rep(j,1,n)
            {
                int mul=(po[i][j]*po[x-i][n-j]%mod*nCr(n, j))%mod;
                ans=(ans + (2ll*mul*min(j, n-j)))%mod;
            }
        }
        cout<<ans<<"\n";
    }
}
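The setter's and tester's codes above do not build g explicitly. As far as we can tell, they rely on an equivalent threshold identity: for a sorted array b, b_{N+1-i} - b_i counts the thresholds y \in [1, X-1] with b_i \leq y \lt b_{N+1-i}, and if exactly j elements are \leq y, then exactly \min(j, N-j) of the pairs with i \leq \left\lfloor \frac{N}{2} \right\rfloor straddle that threshold. Summing 2\min(j, N-j) over the \binom{N}{j} y^j (X-y)^{N-j} arrays with exactly j elements \leq y gives the double loop in both codes. The Python transcription below uses exact integers instead of modular arithmetic; `solve_threshold` is a hypothetical name.

```python
from math import comb

def solve_threshold(n, x):
    total = 0
    for y in range(1, x):          # thresholds y = 1..x-1 (y = x adds nothing)
        for j in range(1, n + 1):  # j = number of elements <= y (j = 0 adds nothing)
            # C(n, j) * y^j * (x - y)^(n - j) arrays have exactly j elements <= y,
            # and min(j, n - j) of the sorted pairs straddle this threshold.
            ways = comb(n, j) * y**j * (x - y)**(n - j)
            total += 2 * min(j, n - j) * ways
    return total
```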
Editorialist's code (Python)
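mod = 998244353
lim = 5005

# pows[e][b] = b^e mod p (exponent first, base second), with 0^0 = 1
pows = [ [0 for i in range(lim)] for j in range(lim) ]
pows[0][0] = 1
for j in range(1, lim):
	pows[0][j] = 1
	for i in range(1, lim):
		pows[i][j] = (pows[i-1][j] * j) % mod

# two rolling rows: dp[i%2][y] holds g[i][y]; index 0 is never written, so it stays 0
dp = [ [0 for i in range(lim)] for _ in range(2) ]
inv = [1]*lim
for i in range(1, lim): inv[i] = pow(i, mod-2, mod)
for _ in range(int(input())):
	n, x = map(int, input().split())
	
	C = 1  # binom(n, i), starting from binom(n, n)
	ans = 0
	for i in reversed(range(1, n+1)):
		cur, prv = dp[i%2], dp[(i+1)%2]
		p1, p2 = pows[i], pows[n-i]
		for y in range(1, x+1):
			# h[i][y] = binom(n, i) * y^i * (x-y)^(n-i)
			cur[y] = C * p1[y] % mod * p2[x-y] % mod
			# suffix sum: g[i][y] = h[i][y] + g[i+1][y]
			if i < n: cur[y] = (cur[y] + prv[y]) % mod
			
			# the middle element (n odd) does not affect the score
			if 2*i != n+1:
				mul = 2 if 2*i > n else -2
				# ct[i][y] = g[i][y] - g[i][y-1]
				ways = cur[y] - cur[y-1]
				ans += mul * ways % mod * y % mod
				ans %= mod
		# binom(n, i-1) = binom(n, i) * i / (n-i+1)
		C = C * i % mod * inv[n-i+1] % mod
	print(ans)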