PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: satyam_343
Tester: udhav2003
Editorialist: iceknight1093
DIFFICULTY:
3012
PREREQUISITES:
Combinatorics, in particular stars and bars
PROBLEM:
Given N and M, find, modulo 998244353, the sum of \displaystyle \left( \sum_{i=1}^N \left| A_i - B_i \right|\right)^2 across all pairs of arrays A and B of length N such that:
- 0 \leq A_i, B_i \leq M
- sum(A_i) = sum(B_i) = M
EXPLANATION:
First, let’s analyze the value contributed by a single pair of arrays.
|A_i - B_i| is somewhat awkward to work with, so let’s rewrite it as \max(A_i, B_i) - \min(A_i, B_i); after all, that’s exactly their absolute difference.
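Then, notice that
\displaystyle \sum_{i=1}^N \left| A_i - B_i \right| = \sum_{i=1}^N \left( \max(A_i, B_i) - \min(A_i, B_i) \right) = \sum_{i=1}^N \left( A_i + B_i - 2\min(A_i, B_i) \right) = 2M - 2\sum_{i=1}^N \min(A_i, B_i),
since \max(A_i, B_i) + \min(A_i, B_i) = A_i + B_i and both arrays sum to M.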
So, the sum of absolute differences is completely determined by the sum of \min(A_i, B_i) across all indices.
Also, notice that \sum_{i=1}^N \min(A_i, B_i) must lie between 0 and M (it is at most \sum_{i=1}^N A_i = M), so it can take only M+1 possible values.
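As a quick sanity check, separate from the original solution, the identity \sum |A_i - B_i| = 2M - 2\sum \min(A_i, B_i) and the bound 0 \leq S \leq M can be tested on random valid arrays. The names `abs_diff_sum`, `via_min_sum` and `random_array` are hypothetical helpers, and `random_array` is just one simple way to produce arrays of N non-negative integers summing to M (it is not a uniform sampler).

```python
import random

def abs_diff_sum(a, b):
    """Left-hand side: the sum of |A_i - B_i|."""
    return sum(abs(x - y) for x, y in zip(a, b))

def via_min_sum(a, b, m):
    """Right-hand side: 2M - 2 * sum(min(A_i, B_i)); only valid when sum(a) == sum(b) == m."""
    return 2 * m - 2 * sum(min(x, y) for x, y in zip(a, b))

def random_array(n, m, rng):
    """Hypothetical helper: n non-negative integers summing to m (cut points in [0, m]; not uniform)."""
    cuts = sorted(rng.randint(0, m) for _ in range(n - 1))
    bounds = [0] + cuts + [m]
    return [bounds[i + 1] - bounds[i] for i in range(n)]

rng = random.Random(1)
for _ in range(1000):
    n, m = rng.randint(1, 6), rng.randint(0, 10)
    a, b = random_array(n, m, rng), random_array(n, m, rng)
    s = sum(min(x, y) for x, y in zip(a, b))
    assert abs_diff_sum(a, b) == via_min_sum(a, b, m)
    assert 0 <= s <= m  # S always lies in [0, M]
print("identity holds on all samples")
```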
Let’s fix S = \sum_{i=1}^N \min(A_i, B_i), and try to count how many pairs of arrays A and B attain this value of S.
Each such pair contributes (2M - 2S)^2 to the final answer, so if we can quickly count these pairs for a fixed S, we’re done.
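For small N and M, a brute force that groups all pairs by S is handy for checking the counting argument. This is a sketch assuming N \geq 1; `compositions`, `pairs_by_min_sum` and `brute_force_answer` are hypothetical names, and the answer is computed exactly, without the modulus.

```python
from collections import Counter

def compositions(n, m):
    """Yield every tuple of n non-negative integers summing to m (assumes n >= 1)."""
    if n == 1:
        yield (m,)
        return
    for first in range(m + 1):
        for rest in compositions(n - 1, m - first):
            yield (first,) + rest

def pairs_by_min_sum(n, m):
    """Map each S = sum(min(A_i, B_i)) to the number of pairs (A, B) attaining it."""
    arrays = list(compositions(n, m))
    counts = Counter()
    for a in arrays:
        for b in arrays:
            counts[sum(min(x, y) for x, y in zip(a, b))] += 1
    return counts

def brute_force_answer(n, m):
    """Sum of (2M - 2S)^2 over all pairs, grouped by S (exact, no modulus)."""
    return sum(cnt * (2 * m - 2 * s) ** 2 for s, cnt in pairs_by_min_sum(n, m).items())
```

For example, with N = M = 2 there are three valid arrays, and `pairs_by_min_sum(2, 2)` gives 2 pairs with S = 0, 4 with S = 1 and 3 with S = 2, so `brute_force_answer(2, 2)` is 2\cdot 16 + 4\cdot 4 = 48.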
This counting can be done in several steps, as follows:
- First, let’s fix the distribution of minimum values across all N indices.
  This is equivalent to choosing N non-negative integers that sum to S, and by stars and bars there are \binom{N+S-1}{S} such sequences.
- Once this is fixed, we need to add a total of M-S more to each of A and B so that their respective sums reach M.
  However, the minimums are already fixed, so this extra amount can only go to the maximums; in particular, at each index we can increase at most one of A_i and B_i, not both.
- To account for this:
  - Fix the number of positions such that A_i is strictly greater than B_i; suppose there are x such positions. Note that x can vary from 0 to N.
  - Once x is fixed, also fix which x positions these are: \binom{N}{x} choices in total.
  - After this, we need to distribute M-S among the A_i values of these x positions such that each of them receives at least 1 (at these positions B_i is the minimum), and distribute M-S among the B_i values of the other N-x positions such that each of them receives \geq 0 (there A_i is the minimum).
    Both counts are direct applications of stars and bars: \binom{M-S-1}{x-1} ways for A and \binom{M-S+N-x-1}{N-x-1} ways for B, and we multiply them.
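The counting steps translate almost line by line into code. Below is a minimal sketch using Python’s exact `math.comb`; `stars_and_bars` and `count_pairs_with_min_sum` are hypothetical helpers, and the cases with zero parts (x = 0 or x = N) are handled explicitly instead of relying on binomial-coefficient conventions.

```python
from math import comb

def stars_and_bars(total, parts, positive=False):
    """Ways to write `total` as an ordered sum of `parts` integers,
    each >= 1 if positive else >= 0."""
    if positive:
        total -= parts  # hand one unit to every part up front
    if total < 0:
        return 0
    if parts == 0:
        return 1 if total == 0 else 0
    return comb(total + parts - 1, parts - 1)

def count_pairs_with_min_sum(n, m, s):
    """Count pairs (A, B) with sum(min(A_i, B_i)) == s, following the steps above."""
    minima = stars_and_bars(s, n)  # distribute S among the N minimums
    rest = m - s                   # extra amount each of A and B still needs
    total = 0
    for x in range(n + 1):         # x = number of positions with A_i > B_i
        choose_positions = comb(n, x)
        a_extra = stars_and_bars(rest, x, positive=True)  # those x A_i each get >= 1
        b_extra = stars_and_bars(rest, n - x)             # the other N - x B_i get >= 0
        total += choose_positions * a_extra * b_extra
    return minima * total
```

Since every pair has exactly one value of S, summing `count_pairs_with_min_sum(n, m, s)` over s = 0..m should recover the total number of pairs, \binom{M+N-1}{N-1}^2.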
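Once S and x are fixed, we only need to compute a few binomial coefficients and multiply them together; with factorials and inverse factorials precomputed, each binomial coefficient takes \mathcal{O}(1) time.
There are M+1 possible values of S and N+1 of x, so the solution runs in \mathcal{O}(N\cdot M) time in total.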
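tl;dr the answer (modulo 998244353) is
\displaystyle \sum_{S=0}^{M} (2M-2S)^2 \binom{N+S-1}{S} \sum_{x=0}^{N} \binom{N}{x}\binom{M-S-1}{x-1}\binom{M-S+N-x-1}{N-x-1}
where \binom{n}{r} is taken to be 0 whenever r < 0 or r > n.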
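To gain some confidence in the closed form, that is, the sum over S and x of (2M-2S)^2 \binom{N+S-1}{S}\binom{N}{x}\binom{M-S-1}{x-1}\binom{M-S+N-x-1}{N-x-1}, it can be compared against direct enumeration on tiny inputs. In this sketch, `binom`, `formula_answer` and `brute_answer` are hypothetical names; `binom` returns 0 for out-of-range arguments because `math.comb` rejects negative ones, and the S = M term is skipped since its weight (2M-2S)^2 is 0.

```python
from itertools import product
from math import comb

MOD = 998244353  # modulus used by both reference solutions

def binom(n, r):
    """Binomial coefficient, taken as 0 outside 0 <= r <= n."""
    return comb(n, r) if 0 <= r <= n else 0

def formula_answer(n, m):
    """Evaluate the closed form exactly; S = M is skipped since its weight is 0."""
    ans = 0
    for s in range(m):
        weight = (2 * m - 2 * s) ** 2 * binom(n + s - 1, s)
        for x in range(n + 1):
            ans += weight * binom(n, x) * binom(m - s - 1, x - 1) * binom(m - s + n - x - 1, n - x - 1)
    return ans

def brute_answer(n, m):
    """Enumerate every pair of valid arrays directly (tiny inputs only)."""
    arrays = [a for a in product(range(m + 1), repeat=n) if sum(a) == m]
    return sum(sum(abs(x - y) for x, y in zip(a, b)) ** 2 for a in arrays for b in arrays)

for n in range(1, 4):
    for m in range(5):
        assert formula_answer(n, m) == brute_answer(n, m), (n, m)
print(formula_answer(2, 2) % MOD)  # prints 48
```

Exact integers keep the comparison free of modular arithmetic; a real submission would instead use precomputed factorials modulo 998244353, as both solutions below do.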
TIME COMPLEXITY
\mathcal{O}(N\cdot M) per test case.
CODE:
Author's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#define ll long long
const ll INF_MUL=1e13;
const ll INF_ADD=1e18;
#define pb push_back
#define mp make_pair
#define nline "\n"
#define f first
#define s second
#define pll pair<ll,ll>
#define all(x) x.begin(),x.end()
#define vl vector<ll>
#define vvl vector<vector<ll>>
#define vvvl vector<vector<vector<ll>>>
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
void _print(int x){cerr<<x;}
void _print(ll x){cerr<<x;}
void _print(char x){cerr<<x;}
void _print(string x){cerr<<x;}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
const ll MOD=998244353;
const ll MAX=500500;
vector<ll> fact(MAX+2,1),inv_fact(MAX+2,1);
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
void precompute(ll MOD){
    for(ll i=2;i<MAX;i++){
        fact[i]=(fact[i-1]*i)%MOD;
    }
    inv_fact[MAX-1]=inverse(fact[MAX-1],MOD);
    for(ll i=MAX-2;i>=0;i--){
        inv_fact[i]=(inv_fact[i+1]*(i+1))%MOD;
    }
}
ll nCr(ll a,ll b,ll MOD){
    if(a==b){
        return 1;
    }
    if((a<0)||(a<b)||(b<0))
        return 0;
    ll denom=(inv_fact[b]*inv_fact[a-b])%MOD;
    return (denom*fact[a])%MOD;
}
void solve(){
    ll n,m; cin>>n>>m;
    ll ans=0;
    for(ll i=0;i<=m;i++){
        ll mul=(nCr(n+i-1,n-1,MOD)*(2ll*(m-i)))%MOD;
        mul=(mul*(2ll*(m-i)))%MOD;
        for(ll j=0;j<=n;j++){
            ll now=(nCr(n,j,MOD)*nCr(m-i-1,j-1,MOD))%MOD;
            now=(now*nCr(m-i+n-j-1,n-j-1,MOD))%MOD;
            ans=(ans+now*mul)%MOD;
        }
    }
    cout<<ans<<nline;
    return;
}
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases=1;
    cin>>test_cases;
    precompute(MOD);
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n";
}
Editorialist's code (Python)
mod = 998244353
MX = 200005
fac = [1]
invfac = [1]
for i in range(1, MX):
    fac.append(i * fac[i-1] % mod)
    invfac.append(pow(fac[-1], mod-2, mod))
def C(n, r):
    if n < r or r < 0: return 0
    return fac[n] * invfac[r] % mod * invfac[n-r] % mod
for _ in range(int(input())):
    n, m = map(int, input().split())
    ans = 0
    for s in range(m+1):
        mul = (2*m - 2*s) * (2*m - 2*s) % mod * C(n+s-1, s) % mod
        for x in range(n+1):
            indices = C(n, x)
            ai_distr = C(m-s-1, x-1) # At least one to each
            bi_distr = C(n-x+m-s-1, m-s) # >= 0
            ans += mul * indices % mod * ai_distr % mod * bi_distr % mod
    print(ans % mod)