GRANDPAPA2 - Editorial

Setter: Cozma Tiberiu-Stefan
Tester: Harris Leung
Editorialist: Trung Dang

3155

None

PROBLEM:

You are given two numbers N and M. A pair of arrays (A, B) is called beautiful if and only if:

• |A| = |B| = N (i.e. both their lengths are equal to N);
• For all 1 \le i \le N, 1 \le A_i \le M and 1 \le B_i \le M;
• median(A) \leq median(B);
• A is lexicographically smaller than B.

Find the number of beautiful pairs of arrays. Since this number can be huge, print it modulo MOD.

Note:

• median(X) denotes the \lceil \frac{|X|}{2} \rceil ^{th} element after sorting the array X.
• An array X is lexicographically smaller than Y if and only if there exists some index i such that X_i \lt Y_i and X_j = Y_j for 1 \le j \lt i.

QUICK EXPLANATION:

Median is hard to manage, so let's fix the medians of A and B beforehand. A straightforward DP solution is to calculate f(x, y), the number of pairs (A, B) such that the median of A is x, the median of B is y, and A is lexicographically smaller than B. We can calculate each such value in O(N^3): dp[i][less_than_x][less_than_y][has_x][has_y][different] is the number of pairs of i-length prefixes of A and B in which A has this many values less than x, B has this many values less than y, x has (or has not) appeared in A, y has (or has not) appeared in B, and A and B already differ (or are still equal). Unfortunately, this leads to an O(N^3 M^2) solution.

To speed this up, let's see how we can calculate g(x, y), the number of pairs (A, B) such that the median of A is less than or equal to x, the median of B is less than or equal to y, and A is lexicographically smaller than B. This is very similar to how we calculated f(x, y) before, so this part is still a straightforward O(N^3) DP. The median of an array is at most x exactly when at least \lceil N/2 \rceil of its elements are at most x. Notice that g(x, y) is a two-dimensional “prefix sum” of f(x, y). The answer we are looking for is \sum_{x = 1}^m \sum_{y = x}^m f(x, y). Instead of calculating \sum_{y = x}^m f(x, y) by looping over y and computing each f(x, y) separately, we can obtain it from 4 values of g by inclusion–exclusion: \sum_{y = x}^m f(x, y) = g(x, m) - g(x - 1, m) - g(x, x - 1) + g(x - 1, x - 1). Therefore, the problem can be solved in O(N^3 M).

TIME COMPLEXITY:

Time complexity is O(N^3 M) per test case.

SOLUTION:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
int n, m, mod;
int dp[51][51][51][2];

// Adds a non-negative `value` to `cell` modulo `mod`. The sum is formed in 64 bits:
// two residues can add up to 2 * mod - 2, which overflows int once mod exceeds 2^30.
static void add_mod(int &cell, long long value)
{
    cell = static_cast<int>((cell + value % mod) % mod);
}

// g(lima, limb): the number of pairs (A, B) of length-n arrays with values in [1, m]
// such that median(A) <= lima, median(B) <= limb and A is lexicographically smaller
// than B, modulo `mod`. Reads the globals n, m, mod and overwrites dp. The weight
// formulas treat m - lima and m - limb as counts, so 0 <= lima, limb <= m is expected.
//
// dp[i][j][k][s] counts pairs of length-i prefixes in which j elements of A are
// <= lima and k elements of B are <= limb. s = 1 means A is already smaller than B;
// s = 0 means the prefixes are still equal. A median is <= x exactly when at least
// (n + 1) / 2 elements are <= x.
int solve(int lima, int limb)
{
    if (lima == 0 || limb == 0)
    {
        return 0; // all values are >= 1, so no median can be <= 0
    }

    // C(cnt, 2): number of pairs a < b drawn from a run of `cnt` consecutive values.
    const auto choose2 = [](long long cnt) { return cnt * (cnt - 1) / 2 % mod; };

    const long long a_le = lima, b_le = limb, vals = m;

    // Transition weights indexed [dj][dk]: number of value pairs (a, b) appended to the
    // prefixes that increase j by dj (a <= lima) and k by dk (b <= limb).
    long long keep[2][2] = {};  // equal -> equal: a == b
    long long start[2][2] = {}; // equal -> less:  a < b
    long long stay[2][2] = {};  // less  -> less:  any a, b
    if (lima <= limb)
    {
        keep[1][1] = a_le;        // a = b <= lima
        keep[0][1] = b_le - a_le; // lima < a = b <= limb
        keep[0][0] = vals - b_le; // a = b > limb
        start[1][1] = a_le * (b_le - a_le) + choose2(a_le);                 // a <= lima, a < b <= limb
        start[1][0] = a_le * (vals - b_le);                                 // a <= lima, b > limb
        start[0][1] = choose2(b_le - a_le);                                 // lima < a < b <= limb
        start[0][0] = (b_le - a_le) * (vals - b_le) + choose2(vals - b_le); // a > lima, b > limb, a < b
    }
    else
    {
        keep[1][1] = b_le;        // a = b <= limb
        keep[1][0] = a_le - b_le; // limb < a = b <= lima
        keep[0][0] = vals - a_le; // a = b > lima
        start[1][1] = choose2(b_le); // a < b <= limb
        start[1][0] = b_le * (vals - b_le) + choose2(a_le - b_le) + (a_le - b_le) * (vals - a_le); // a <= lima, b > limb, a < b
        start[0][0] = choose2(vals - a_le); // lima < a < b
        // a > lima together with b <= limb < lima contradicts a < b, so start[0][1] stays 0.
    }
    stay[1][1] = a_le * b_le;
    stay[1][0] = a_le * (vals - b_le);
    stay[0][1] = (vals - a_le) * b_le;
    stay[0][0] = (vals - a_le) * (vals - b_le);
    for (int dj = 0; dj < 2; ++dj)
    {
        for (int dk = 0; dk < 2; ++dk)
        {
            keep[dj][dk] %= mod;
            start[dj][dk] %= mod;
            stay[dj][dk] %= mod;
        }
    }

    for (int i = 0; i <= n; ++i)
    {
        for (int j = 0; j <= i; ++j)
        {
            for (int k = 0; k <= i; ++k)
            {
                dp[i][j][k][0] = dp[i][j][k][1] = 0;
            }
        }
    }
    dp[0][0][0][0] = 1;

    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j <= i; ++j)
        {
            for (int k = 0; k <= i; ++k)
            {
                const long long eq = dp[i][j][k][0];
                const long long lt = dp[i][j][k][1];
                if (eq == 0 && lt == 0)
                {
                    continue;
                }
                for (int dj = 0; dj < 2; ++dj)
                {
                    for (int dk = 0; dk < 2; ++dk)
                    {
                        int *to = dp[i + 1][j + dj][k + dk];
                        // Both factors are below mod (< 2^31), so each product fits in 64 bits.
                        add_mod(to[0], eq * keep[dj][dk]);
                        add_mod(to[1], eq * start[dj][dk]);
                        add_mod(to[1], lt * stay[dj][dk]);
                    }
                }
            }
        }
    }

    const int half = (n + 1) / 2;
    int ans = 0;
    for (int j = half; j <= n; ++j)
    {
        for (int k = half; k <= n; ++k)
        {
            add_mod(ans, dp[n][j][k][1]);
        }
    }
    return ans;
}
int main()
{
cin.tie(nullptr)->sync_with_stdio(false);
cin >> n >> m >> mod;
int ans = 0;
function<int(int, int, int, int)> query = [&](int i1, int j1, int i2, int j2)
{
return (((solve(i2, j2) - solve(i1 - 1, j2) + mod) % mod - solve(i2, j1 - 1) + mod) % mod + solve(i1 - 1, j1 - 1)) % mod;
};
for (int i = 1; i <= m; ++i)
{
ans += query(i, i, i, m);
if (ans >= mod)
ans -= mod;
}
cout << ans;
}

Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll mod;
int n, m;
ll dp[51][51][51][2];

// Transition weights indexed [dj][dk]: the number of value pairs (a, b) appended to
// the prefixes that increase the A-counter j by dj and the B-counter k by dk.
using Weights = std::array<std::array<ll, 2>, 2>;

// Builds a Weights table from four counts named wXY for (dj = X, dk = Y).
static Weights make_weights(ll w11, ll w10, ll w01, ll w00)
{
    Weights w{};
    w[1][1] = w11;
    w[1][0] = w10;
    w[0][1] = w01;
    w[0][0] = w00;
    return w;
}

// Runs the prefix DP on the global dp table and returns, modulo mod, the number of
// pairs (A, B) with A lexicographically smaller than B, j >= mid and k < mid, where
// mid = (n + 1) / 2 and j / k are the counters advanced according to the weights.
//   dp[i][j][k][0]: length-i prefixes that are still equal;
//   dp[i][j][k][1]: length-i prefixes with A already smaller than B.
static ll count_pairs(const Weights &equal_to_equal, const Weights &equal_to_less, const Weights &less_to_less)
{
    for (int i = 0; i <= n; ++i)
        for (int j = 0; j <= n; ++j)
            for (int k = 0; k <= n; ++k)
                dp[i][j][k][0] = dp[i][j][k][1] = 0;
    dp[0][0][0][0] = 1;
    for (int i = 1; i <= n; ++i)
    {
        for (int j = 0; j < i; ++j)
        {
            for (int k = 0; k < i; ++k)
            {
                const ll from_equal = dp[i - 1][j][k][0];
                const ll from_less = dp[i - 1][j][k][1];
                for (int dj = 0; dj < 2; ++dj)
                {
                    for (int dk = 0; dk < 2; ++dk)
                    {
                        ll *to = dp[i][j + dj][k + dk];
                        to[0] = (to[0] + equal_to_equal[dj][dk] * from_equal) % mod;
                        to[1] = (to[1] + equal_to_less[dj][dk] * from_equal) % mod;
                        to[1] = (to[1] + less_to_less[dj][dk] * from_less) % mod;
                    }
                }
            }
        }
    }
    const int mid = (n + 1) / 2;
    ll total = 0;
    for (int j = mid; j <= n; ++j)
        for (int k = 0; k < mid; ++k)
            total = (total + dp[n][j][k][1]) % mod;
    return total;
}

// Reads N, M and MOD and prints the number of beautiful pairs modulo MOD.
// For each z, pass 1 counts pairs with median(A) <= z and median(B) >= z, and pass 2
// counts pairs with median(A) <= z and median(B) > z. Their difference is the number
// of pairs with median(A) <= median(B) == z.
int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> m >> mod;
    ll ans = 0;
    for (int z = 1; z <= m; ++z)
    {
        // Computed in 64 bits so that products such as le * (gt + 1) cannot overflow int.
        const ll le = z;     // number of values <= z
        const ll gt = m - z; // number of values > z

        // Pass 1: j counts A-values <= z, k counts B-values < z.
        const ll median_b_at_least_z = count_pairs(
            make_weights(le - 1, 1, 0, gt),
            make_weights((le - 1) * (le - 2) / 2, le * (gt + 1) - 1, 0, gt * (gt - 1) / 2),
            make_weights(le * (le - 1), le * (gt + 1), gt * (le - 1), gt * (gt + 1)));
        ans = (ans + median_b_at_least_z) % mod;

        // Pass 2: j counts A-values <= z, k counts B-values <= z.
        const ll median_b_above_z = count_pairs(
            make_weights(le, 0, 0, gt),
            make_weights(le * (le - 1) / 2, le * gt, 0, gt * (gt - 1) / 2),
            make_weights(le * le, le * gt, gt * le, gt * gt));
        ans = (ans + mod - median_b_above_z) % mod;
    }
    std::cout << ans << '\n';
}

1 Like