FINDINGSUM - Editorial


Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: udhav2003
Editorialist: iceknight1093




Combinatorics, in particular stars and bars


Given N and M, find the sum of \displaystyle \left( \sum_{i=1}^N \left| A_i - B_i \right|\right)^2 across all pairs of arrays A and B of length N such that:

  • 0 \leq A_i, B_i \leq M
  • sum(A_i) = sum(B_i) = M


First, let’s analyze the value contributed by a single pair of arrays (A, B).
|A_i - B_i| is somewhat hard to deal with, so let’s rewrite it as \max(A_i, B_i) - \min(A_i, B_i); after all, that’s what their difference is.

Then, notice that

\begin{align*} \sum_{i=1}^N \left|A_i - B_i\right| &= \sum_{i=1}^N \left(\max(A_i, B_i) - \min(A_i, B_i) \right) \\ &= \sum_{i=1}^N \left(\max(A_i, B_i) + \min(A_i, B_i) \right) - 2\cdot\sum_{i=1}^N \min(A_i, B_i) \\ &= 2M - 2\cdot \sum_{i=1}^N \min(A_i, B_i) \end{align*}

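Here, the last step uses \max(A_i, B_i) + \min(A_i, B_i) = A_i + B_i, and the sum of all A_i and B_i together is 2M.
So, the sum of differences is fixed if we fix the sum of \min(A_i, B_i) across all indices.
Also, notice that \sum_{i=1}^N \min(A_i, B_i) must lie between 0 and M, so there aren’t too many possible values for it.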
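
As a sanity check of the identity above, here is a minimal brute-force sketch for tiny N and M. The helper names (`compositions`, `check_identity`) are made up for this illustration and are not part of either solution.

```python
from itertools import product

def compositions(total, parts):
    """Yield every tuple of `parts` non-negative integers summing to `total`."""
    for t in product(range(total + 1), repeat=parts):
        if sum(t) == total:
            yield t

def check_identity(n, m):
    """Check sum|A_i - B_i| == 2m - 2*sum(min(A_i, B_i)) over all valid pairs."""
    arrays = list(compositions(m, n))
    for a in arrays:
        for b in arrays:
            diff = sum(abs(p - q) for p, q in zip(a, b))
            s = sum(min(p, q) for p, q in zip(a, b))
            # s must also stay within [0, m], as claimed in the text.
            if diff != 2 * m - 2 * s or not 0 <= s <= m:
                return False
    return True

print(check_identity(3, 4))  # True
```

This only enumerates small cases, so it illustrates the identity rather than proving it.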

Let’s fix S = \sum_{i=1}^N \min(A_i, B_i), and try to count how many pairs of arrays A and B attain this value of S.
Each such pair will contribute (2M - 2S)^2 to the final answer, so if we can quickly count them for a fixed S, we’ll be done.

This counting can be done in several steps, as follows:

  • First, let’s fix the distribution of minimum values across all N indices.
    This is equivalent to saying we have N non-negative integers that sum up to S; and we know there are \binom{N+S-1}{S} such sequences of integers.
  • Once this is fixed, we need to distribute a total of M-S more to each of A and B so that their respective sums reach M.
    However, we’ve already fixed the values of the minimums, so we need to distribute these values only to the maximums; in particular, we can’t increase both A_i and B_i now, but only at most one of them.
  • To account for this, we have the following:
    • Fix the number of positions such that A_i is strictly greater than B_i; suppose there are x such positions.
      Note that x can vary from 0 to N.
    • Once x is fixed, also fix which x positions these are: \binom{N}{x} choices in total.
    • After this, we need to distribute M-S to the A_i values of these x positions such that all of them receive at least 1, and we need to distribute M-S to the B_i values of the other N-x positions such that all of them receive \geq 0. Positions with A_i = B_i fall into the second group (receiving 0), so no pair is counted twice.
      Both counts are direct applications of stars-and-bars, namely \binom{M-S-1}{x-1} and \binom{N-x+M-S-1}{M-S} respectively, giving us the product of two binomial coefficients.
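
The two stars-and-bars counts used in the steps above can be checked against direct enumeration. This is a small sketch under the assumption that both the total and the number of positions are at least 1; the function names (`brute_count`, `at_least_zero`, `at_least_one`) are hypothetical.

```python
from itertools import product
from math import comb

def brute_count(total, parts, minimum):
    """Count `parts`-tuples of integers >= minimum that sum to `total` by enumeration."""
    return sum(1 for t in product(range(minimum, total + 1), repeat=parts)
               if sum(t) == total)

def at_least_zero(total, parts):
    """Stars and bars: `parts` non-negative integers summing to `total` (parts >= 1)."""
    return comb(parts + total - 1, total)

def at_least_one(total, parts):
    """Stars and bars: `parts` positive integers summing to `total` (parts, total >= 1)."""
    return comb(total - 1, parts - 1)

# Compare both closed forms with enumeration on small inputs.
for total in range(1, 6):
    for parts in range(1, 4):
        assert brute_count(total, parts, 0) == at_least_zero(total, parts)
        assert brute_count(total, parts, 1) == at_least_one(total, parts)

print(at_least_one(4, 2), at_least_zero(4, 2))  # 3 5
```

With total = M-S, `at_least_one` corresponds to the A_i distribution over x positions and `at_least_zero` to the B_i distribution over N-x positions.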

Note that after S and x are fixed, we only compute a few binomial coefficients and multiply them out; with factorials and inverse factorials precomputed, each of these takes \mathcal{O}(1).
There are M+1 possible values of S and N+1 of x, so this solution is \mathcal{O}(N\cdot M) in total.

tl;dr the answer is

\sum_{S=0}^M \sum_{x=0}^N \left((2M-2S)^2 \cdot \binom{N+S-1}{S} \binom{N}{x}\binom{M-S-1}{x-1}\binom{N-x+M-S-1}{M-S} \right)
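
To see the formula in action, the sketch below compares it with a direct enumeration of all pairs for tiny N and M. It assumes the convention that \binom{n}{r} = 0 whenever r < 0 or r > n (the S = M terms vanish anyway because of the (2M-2S)^2 factor), and it takes the modulus 998244353 from the solutions below. The names `binom`, `answer_by_formula` and `answer_by_brute_force` are illustrative only.

```python
from itertools import product
from math import comb

MOD = 998244353  # modulus used by the solutions below

def binom(n, r):
    """Binomial coefficient, taken to be 0 outside 0 <= r <= n."""
    return comb(n, r) if 0 <= r <= n else 0

def answer_by_formula(n, m):
    """Evaluate the double sum over S and x from the editorial."""
    total = 0
    for s in range(m + 1):
        weight = (2 * m - 2 * s) ** 2 * binom(n + s - 1, s)
        for x in range(n + 1):
            total += (weight * binom(n, x) * binom(m - s - 1, x - 1)
                      * binom(n - x + m - s - 1, m - s))
    return total % MOD

def answer_by_brute_force(n, m):
    """Sum (sum |A_i - B_i|)^2 over all valid pairs by enumeration."""
    arrays = [t for t in product(range(m + 1), repeat=n) if sum(t) == m]
    return sum(sum(abs(p - q) for p, q in zip(a, b)) ** 2
               for a in arrays for b in arrays) % MOD

print(answer_by_formula(2, 2), answer_by_brute_force(2, 2))  # 48 48
```

The brute force is exponential and only meant for cross-checking small cases.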


\mathcal{O}(N\cdot M) per test case.


Author's code (C++)
#pragma GCC optimization("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long  
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;  
#define pb push_back               
#define mp make_pair        
#define nline "\n"                         
#define f first                                          
#define s second                                               
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()   
#define vl vector<ll>         
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif
void _print(int x){cerr<<x;}    
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;} 
void _print(string x){cerr<<x;}     
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); 
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
const ll MOD=998244353;   
const ll MAX=500500;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
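// Fast modular exponentiation: a^b mod MOD.
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1) ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
// Modular inverse via Fermat's little theorem (MOD is prime).
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
// Fill fact[] and inv_fact[] for all indices below MAX.
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if(a==b){
        return 1;  
    }
    if((a<0)||(a<b)||(b<0)){
        return 0;     
    }
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;  
    return (denom*fact[a])%MOD;     
}
void solve(){   
    ll n,m; cin>>n>>m;
    ll ans=0;
    // i plays the role of S (sum of minimums), j the role of x.
    for(ll i=0;i<=m;i++){
        ll mul=(nCr(n+i-1,n-1,MOD)*(2ll*(m-i)))%MOD;
        mul=(mul*(2ll*(m-i)))%MOD;
        for(ll j=0;j<=n;j++){
            ll now=(nCr(n,j,MOD)*nCr(m-i-1,j-1,MOD))%MOD;
            now=(now*nCr(n-j+m-i-1,m-i,MOD))%MOD;
            ans=(ans+now*mul)%MOD;
        }
    }
    cout<<ans<<nline;
}
int main()                                                                             
{
    #ifndef ONLINE_JUDGE               
    freopen("input.txt", "r", stdin);                                             
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                        
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    ll test_cases=1;                   
    cin>>test_cases;
    precompute(MOD);
    while(test_cases--){
        solve();
    }
}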
Editorialist's code (Python)
mod = 998244353
MX = 200005
fac = [1]
invfac = [1]
for i in range(1, MX):
	fac.append(i * fac[i-1] % mod)
	invfac.append(pow(fac[-1], mod-2, mod))

def C(n, r):
	if n < r or r < 0: return 0
	return fac[n] * invfac[r] % mod * invfac[n-r] % mod

for _ in range(int(input())):
	n, m = map(int, input().split())
	ans = 0
	for s in range(m+1):
		mul = (2*m - 2*s) * (2*m - 2*s) % mod * C(n+s-1, s) % mod
		for x in range(n+1):
			indices = C(n, x)
			ai_distr = C(m-s-1, x-1) # At least one to each
			bi_distr = C(n-x+m-s-1, m-s) # >= 0
			ans += mul * indices % mod * ai_distr % mod * bi_distr % mod
	print(ans % mod)

Will there be any overcounting issue here?

E.g., if ai = bi for some i, then we can choose either ai or bi as the minimum, so I guess it is counted twice, but it needs to be counted exactly once.

A = 1 2 0
B = 0 2 1

Here, for S = 2:
x = 2: I choose {1,2} as my indices
x = 1: I choose only {1} as my index where ai > bi

Never mind, i got the approach