GRANDPAPA2 - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Cozma Tiberiu-Stefan
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

3155

PREREQUISITES:

None

PROBLEM:

You are given two numbers N and M. A pair of arrays (A, B) is called beautiful if and only if:

  • |A| = |B| = N (i.e. both their lengths are equal to N);
  • For all 1 \le i \le N, 1 \le A_i \le M and 1 \le B_i \le M;
  • median(A) \leq median(B);
  • A is lexicographically smaller than B.

Find the number of beautiful pairs of arrays. Since this number can be huge, print it modulo MOD.

Note:

  • median(X) denotes the \lceil \frac{|X|}{2} \rceil ^{th} element after sorting the array X.
  • An array X is lexicographically smaller than Y if and only if there exists some index i such that X_i \lt Y_i and X_j = Y_j for 1 \le j \lt i.
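
For very small N and M, the definitions above can be checked by direct enumeration. The following is a minimal sketch, not part of the intended solution; the helper names `medianOf`, `nextArray` and `bruteForceBeautifulPairs` are hypothetical and chosen only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Median as defined in the statement: the ceil(|X| / 2)-th element after sorting.
int medianOf(std::vector<int> x) {
    assert(!x.empty());
    std::sort(x.begin(), x.end());
    return x[(x.size() + 1) / 2 - 1];
}

// Advances `a` to the next array over {1..m} (odometer style).
// Returns false after the last array, leaving `a` reset to all ones.
bool nextArray(std::vector<int>& a, int m) {
    for (int i = (int)a.size() - 1; i >= 0; --i) {
        if (a[i] < m) {
            ++a[i];
            return true;
        }
        a[i] = 1;
    }
    return false;
}

// Hypothetical brute-force checker: enumerates all M^N * M^N pairs,
// so it is only usable for tiny N and M.
long long bruteForceBeautifulPairs(int n, int m, long long mod) {
    long long count = 0;
    std::vector<int> a(n, 1);
    do {
        std::vector<int> b(n, 1);
        do {
            // std::vector's operator< compares lexicographically.
            if (a < b && medianOf(a) <= medianOf(b)) count = (count + 1) % mod;
        } while (nextArray(b, m));
    } while (nextArray(a, m));
    return count;
}
```

For example, with N = 3 and M = 2 there are 28 pairs with A lexicographically smaller than B, and only the pair ([1, 2, 2], [2, 1, 1]) violates the median condition, so the checker returns 27.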

QUICK EXPLANATION:

The median is hard to handle directly, so let's fix the medians of A and B beforehand. A straightforward DP computes f(x, y), the number of pairs (A, B) such that the median of A is x, the median of B is y, and A is lexicographically smaller than B. Each such value can be computed in O(N^3) with dp[i][less_than_x][less_than_y][has_x][has_y][different], the number of pairs of length-i prefixes of A and B with the given number of values less than x in A, the given number of values less than y in B, a flag for whether x has appeared in A, a flag for whether y has appeared in B, and a flag for whether the two prefixes already differ. Unfortunately, there are M^2 pairs (x, y), so this leads to an O(N^3 M^2) solution.
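
As a sanity check of the counting idea behind these states, the median can be recognised from two counts alone: median(A) = x exactly when fewer than \lceil \frac{N}{2} \rceil elements are less than x and at least \lceil \frac{N}{2} \rceil elements are at most x. The sketch below illustrates this; the function names are hypothetical, and it is not taken from the setter's or tester's code.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Returns true when median(a) == x, using only two counts instead of sorting.
bool medianEqualsByCounting(const std::vector<int>& a, int x) {
    assert(!a.empty());
    int mid = ((int)a.size() + 1) / 2;  // position of the median, ceil(N / 2)
    int less = 0, lessOrEqual = 0;
    for (int v : a) {
        if (v < x) ++less;
        if (v <= x) ++lessOrEqual;
    }
    // The mid-th smallest element is x iff it is not below x and not above x.
    return less < mid && lessOrEqual >= mid;
}

// Reference definition from the statement, used only for comparison.
int medianBySorting(std::vector<int> a) {
    assert(!a.empty());
    std::sort(a.begin(), a.end());
    return a[(a.size() + 1) / 2 - 1];
}
```

Both counts only ever grow as the arrays are extended one position at a time, which is what makes them convenient DP dimensions.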

To speed this up, consider g(x, y), the number of pairs (A, B) such that the median of A is less than or equal to x, the median of B is less than or equal to y, and A is lexicographically smaller than B. Since median(A) \le x holds exactly when at least \lceil \frac{N}{2} \rceil elements of A are at most x, g(x, y) can be computed by a DP very similar to the one for f(x, y), still in O(N^3). Notice that g(x, y) is a two-dimensional prefix sum of f: g(x, y) = \sum_{a = 1}^{x} \sum_{b = 1}^{y} f(a, b). The answer we are looking for is \sum_{x = 1}^{M} \sum_{y = x}^{M} f(x, y). Instead of evaluating each f(x, y) in the inner sum separately, we obtain the whole inner sum from 4 values of g by inclusion-exclusion: \sum_{y = x}^{M} f(x, y) = g(x, M) - g(x - 1, M) - g(x, x - 1) + g(x - 1, x - 1). Each x therefore costs O(N^3), and the problem can be solved in O(N^3 M).
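
The prefix-sum idea can be sketched compactly as follows. This is an illustrative simplification rather than the setter's code: the names `prefixCount` and `countBeautifulPairs` are hypothetical, the per-position weights are enumerated in O(M^2) instead of using closed forms, and MOD is assumed to fit comfortably in 32 bits so that products fit in 64 bits.

```cpp
#include <array>
#include <cassert>
#include <vector>

using ll = long long;

// g(x, y): number of pairs (A, B) with median(A) <= x, median(B) <= y and
// A lexicographically smaller than B, modulo `mod`. Assumes 0 <= x, y <= m.
ll prefixCount(int n, int m, ll mod, int x, int y) {
    assert(n >= 1 && m >= 1 && mod >= 1);
    if (x <= 0 || y <= 0) return 0;
    // w[rel][p][q]: value pairs (a, b) with p = (a <= x), q = (b <= y);
    // rel = 0 collects a == b, rel = 1 collects a < b.
    ll w[2][2][2] = {};
    for (int a = 1; a <= m; ++a)
        for (int b = a; b <= m; ++b)
            ++w[a < b][a <= x][b <= y];
    // dp[j][k][d]: j elements of A are <= x, k elements of B are <= y,
    // d = 1 once the prefixes differ (with A smaller at the first difference).
    using Table = std::vector<std::vector<std::array<ll, 2>>>;
    const Table zero(n + 1, std::vector<std::array<ll, 2>>(n + 1, std::array<ll, 2>{0, 0}));
    Table dp = zero;
    dp[0][0][0] = 1 % mod;
    for (int i = 0; i < n; ++i) {
        Table nxt = zero;
        for (int j = 0; j <= i; ++j)
            for (int k = 0; k <= i; ++k)
                for (int p = 0; p < 2; ++p)
                    for (int q = 0; q < 2; ++q) {
                        // Unconstrained pairs (a, b), used once the prefixes differ.
                        ll ca = p ? x : m - x, cb = q ? y : m - y;
                        ll any = ca * cb % mod;
                        ll same = dp[j][k][0], diff = dp[j][k][1];
                        auto& cell = nxt[j + p][k + q];
                        cell[0] = (cell[0] + same * (w[0][p][q] % mod)) % mod;
                        cell[1] = (cell[1] + same * (w[1][p][q] % mod)) % mod;
                        cell[1] = (cell[1] + diff * any) % mod;
                    }
        dp = nxt;
    }
    // median <= threshold iff at least ceil(n / 2) elements are <= threshold.
    int mid = (n + 1) / 2;
    ll total = 0;
    for (int j = mid; j <= n; ++j)
        for (int k = mid; k <= n; ++k)
            total = (total + dp[j][k][1]) % mod;
    return total;
}

// Sums f(x, y) over x <= y using four values of g per x.
ll countBeautifulPairs(int n, int m, ll mod) {
    ll answer = 0;
    for (int x = 1; x <= m; ++x) {
        ll row = prefixCount(n, m, mod, x, m) - prefixCount(n, m, mod, x - 1, m)
               - prefixCount(n, m, mod, x, x - 1) + prefixCount(n, m, mod, x - 1, x - 1);
        answer = ((answer + row) % mod + mod) % mod;
    }
    return answer;
}
```

Replacing the O(M^2) weight enumeration with closed-form counts, as the setter's solution does, removes the extra factor and gives the stated O(N^3 M) bound.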

TIME COMPLEXITY:

Time complexity is O(N^3 M) per test case.

SOLUTION:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
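int n, m, mod;
// dp[i][j][k][d]: number of pairs of length-i prefixes where j elements of A are <= lima,
// k elements of B are <= limb, and d tells whether the prefixes already differ (A being smaller).
int dp[51][51][51][2];
// Counts pairs (A, B) with median(A) <= lima, median(B) <= limb and A lexicographically smaller than B.
int solve(int lima, int limb)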
{
    if (lima == 0 || limb == 0)
    {
        return 0;
    }
    // gauss(n) = n * (n - 1) / 2, the number of value pairs a < b among n values.
    function<int(int)> gauss = [&](int n)
    {
        return ((1ll * n * (n - 1)) / 2ll) % mod;
    };
    for (int i = 0; i <= n; ++i)
    {
        for (int j = 0; j <= i; ++j)
        {
            for (int k = 0; k <= i; ++k)
            {
                dp[i][j][k][true] = dp[i][j][k][false] = 0;
            }
        }
    }
    dp[0][0][0][0] = 1;
    if (lima <= limb)
    {
        int coef1 = gauss(limb - lima);
        int coef2 = ((1ll * (limb - lima) * (m - limb)) % mod + gauss(m - limb)) % mod;
        int coef3 = ((1ll * lima * (limb - lima)) % mod + gauss(lima)) % mod;
        int coef4 = (1ll * lima * (m - limb)) % mod;
        int coef5 = (1ll * (m - lima) * (m - limb)) % mod;
        int coef6 = (1ll * lima * (m - limb)) % mod;
        int coef7 = (1ll * (m - lima) * limb) % mod;
        int coef8 = (1ll * lima * limb) % mod;
        for (int i = 0; i < n; ++i)
        {
            for (int j = 0; j <= i; ++j)
            {
                for (int k = 0; k <= i; ++k)
                {
                    if (dp[i][j][k][false])
                    {
                        dp[i + 1][j][k + 1][true] += (1ll * dp[i][j][k][false] * coef1) % mod;
                        if (dp[i + 1][j][k + 1][true] >= mod)
                            dp[i + 1][j][k + 1][true] -= mod;
                        dp[i + 1][j][k][true] += (1ll * dp[i][j][k][false] * coef2) % mod;
                        if (dp[i + 1][j][k][true] >= mod)
                            dp[i + 1][j][k][true] -= mod;
                        dp[i + 1][j + 1][k + 1][true] += (1ll * dp[i][j][k][false] * coef3) % mod;
                        if (dp[i + 1][j + 1][k + 1][true] >= mod)
                            dp[i + 1][j + 1][k + 1][true] -= mod;
                        dp[i + 1][j + 1][k][true] += (1ll * dp[i][j][k][false] * coef4) % mod;
                        if (dp[i + 1][j + 1][k][true] >= mod)
                            dp[i + 1][j + 1][k][true] -= mod;
                        dp[i + 1][j][k][false] += (1ll * dp[i][j][k][false] * (m - limb)) % mod;
                        if (dp[i + 1][j][k][false] >= mod)
                            dp[i + 1][j][k][false] -= mod;
                        dp[i + 1][j][k + 1][false] += (1ll * dp[i][j][k][false] * (limb - lima)) % mod;
                        if (dp[i + 1][j][k + 1][false] >= mod)
                            dp[i + 1][j][k + 1][false] -= mod;
                        dp[i + 1][j + 1][k + 1][false] += (1ll * dp[i][j][k][false] * lima) % mod;
                        if (dp[i + 1][j + 1][k + 1][false] >= mod)
                            dp[i + 1][j + 1][k + 1][false] -= mod;
                    }
                    if (dp[i][j][k][true])
                    {
                        dp[i + 1][j][k][true] += (1ll * dp[i][j][k][true] * coef5) % mod;
                        if (dp[i + 1][j][k][true] >= mod)
                            dp[i + 1][j][k][true] -= mod;
                        dp[i + 1][j + 1][k][true] += (1ll * dp[i][j][k][true] * coef6) % mod;
                        if (dp[i + 1][j + 1][k][true] >= mod)
                            dp[i + 1][j + 1][k][true] -= mod;
                        dp[i + 1][j][k + 1][true] += (1ll * dp[i][j][k][true] * coef7) % mod;
                        if (dp[i + 1][j][k + 1][true] >= mod)
                            dp[i + 1][j][k + 1][true] -= mod;
                        dp[i + 1][j + 1][k + 1][true] += (1ll * dp[i][j][k][true] * coef8) % mod;
                        if (dp[i + 1][j + 1][k + 1][true] >= mod)
                            dp[i + 1][j + 1][k + 1][true] -= mod;
                    }
                }
            }
        }
    }
    else
    {
        int coef1 = (((1ll * limb * (m - limb)) % mod + gauss(lima - limb)) % mod + (1ll * (lima - limb) * (m - lima)) % mod) % mod;
        int coef2 = gauss(m - lima);
        int coef3 = gauss(limb);
        int coef4 = (1ll * (m - lima) * (m - limb)) % mod;
        int coef5 = (1ll * lima * (m - limb)) % mod;
        int coef6 = (1ll * (m - lima) * limb) % mod;
        int coef7 = (1ll * lima * limb) % mod;
        for (int i = 0; i < n; ++i)
        {
            for (int j = 0; j <= i; ++j)
            {
                for (int k = 0; k <= i; ++k)
                {
                    if (dp[i][j][k][false])
                    {
                        dp[i + 1][j + 1][k][true] += (1ll * dp[i][j][k][false] * coef1) % mod;
                        if (dp[i + 1][j + 1][k][true] >= mod)
                            dp[i + 1][j + 1][k][true] -= mod;
                        dp[i + 1][j][k][true] += (1ll * dp[i][j][k][false] * coef2) % mod;
                        if (dp[i + 1][j][k][true] >= mod)
                            dp[i + 1][j][k][true] -= mod;
                        dp[i + 1][j + 1][k + 1][true] += (1ll * dp[i][j][k][false] * coef3) % mod;
                        if (dp[i + 1][j + 1][k + 1][true] >= mod)
                            dp[i + 1][j + 1][k + 1][true] -= mod;
                        dp[i + 1][j + 1][k + 1][false] += (1ll * dp[i][j][k][false] * limb) % mod;
                        if (dp[i + 1][j + 1][k + 1][false] >= mod)
                            dp[i + 1][j + 1][k + 1][false] -= mod;
                        dp[i + 1][j + 1][k][false] += (1ll * dp[i][j][k][false] * (lima - limb)) % mod;
                        if (dp[i + 1][j + 1][k][false] >= mod)
                            dp[i + 1][j + 1][k][false] -= mod;
                        dp[i + 1][j][k][false] += (1ll * dp[i][j][k][false] * (m - lima)) % mod;
                        if (dp[i + 1][j][k][false] >= mod)
                            dp[i + 1][j][k][false] -= mod;
                    }
                    if (dp[i][j][k][true])
                    {
                        dp[i + 1][j][k][true] += (1ll * dp[i][j][k][true] * coef4) % mod;
                        if (dp[i + 1][j][k][true] >= mod)
                            dp[i + 1][j][k][true] -= mod;
                        dp[i + 1][j + 1][k][true] += (1ll * dp[i][j][k][true] * coef5) % mod;
                        if (dp[i + 1][j + 1][k][true] >= mod)
                            dp[i + 1][j + 1][k][true] -= mod;
                        dp[i + 1][j][k + 1][true] += (1ll * dp[i][j][k][true] * coef6) % mod;
                        if (dp[i + 1][j][k + 1][true] >= mod)
                            dp[i + 1][j][k + 1][true] -= mod;
                        dp[i + 1][j + 1][k + 1][true] += (1ll * dp[i][j][k][true] * coef7) % mod;
                        if (dp[i + 1][j + 1][k + 1][true] >= mod)
                            dp[i + 1][j + 1][k + 1][true] -= mod;
                    }
                }
            }
        }
    }
    int ans = 0;
    for (int i = (n + 1) / 2; i <= n; ++i)
    {
        for (int j = (n + 1) / 2; j <= n; ++j)
        {
            ans += dp[n][i][j][true];
            if (ans >= mod)
                ans -= mod;
        }
    }
    return ans;
}
int main()
{
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> n >> m >> mod;
    int ans = 0;
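    // query(i1, j1, i2, j2): sum of f(x, y) over i1 <= x <= i2 and j1 <= y <= j2,
    // obtained from four calls to solve by 2D inclusion-exclusion.
    function<int(int, int, int, int)> query = [&](int i1, int j1, int i2, int j2)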
    {
        return (((solve(i2, j2) - solve(i1 - 1, j2) + mod) % mod - solve(i2, j1 - 1) + mod) % mod + solve(i1 - 1, j1 - 1)) % mod;
    };
    for (int i = 1; i <= m; ++i)
    {
        ans += query(i, i, i, m);
        if (ans >= mod)
            ans -= mod;
    }
    cout << ans;
}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define fi first
#define se second
ll mod;
int n,m;
ll dp[51][51][51][2];
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	cin >> n >> m >> mod;
	ll ans=0;
	int mid=(n+1)/2;
	for(int z=1; z<=m ;z++){
		{
			for(int i=0; i<=n ;i++){
				for(int j=0; j<=n ;j++){
					for(int k=0; k<=n ;k++){
						dp[i][j][k][0]=dp[i][j][k][1]=0;
					}
				}
			}
			dp[0][0][0][0]=1;
			for(int i=1; i<=n ;i++){
				for(int j=0; j<i ;j++){
					for(int k=0; k<i ;k++){
						{//0->0
							ll w11=z-1;
							ll w10=1;
							ll w01=0;
							ll w00=m-z;
							dp[i][j+1][k+1][0]=(dp[i][j+1][k+1][0]+w11*dp[i-1][j][k][0])%mod;
							dp[i][j+1][k+0][0]=(dp[i][j+1][k+0][0]+w10*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+1][0]=(dp[i][j+0][k+1][0]+w01*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+0][0]=(dp[i][j+0][k+0][0]+w00*dp[i-1][j][k][0])%mod;
						}
						{//0->1
							ll w11=(z-1)*(z-2)/2;
							ll w10=z*(m-z+1)-1;
							ll w01=0;
							ll w00=(m-z)*(m-z-1)/2;
							dp[i][j+1][k+1][1]=(dp[i][j+1][k+1][1]+w11*dp[i-1][j][k][0])%mod;
							dp[i][j+1][k+0][1]=(dp[i][j+1][k+0][1]+w10*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+1][1]=(dp[i][j+0][k+1][1]+w01*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+0][1]=(dp[i][j+0][k+0][1]+w00*dp[i-1][j][k][0])%mod;
							
						}
						{//1->1
							ll w11=z*(z-1);
							ll w10=z*(m-z+1);
							ll w01=(m-z)*(z-1);
							ll w00=(m-z)*(m-z+1);
							dp[i][j+1][k+1][1]=(dp[i][j+1][k+1][1]+w11*dp[i-1][j][k][1])%mod;
							dp[i][j+1][k+0][1]=(dp[i][j+1][k+0][1]+w10*dp[i-1][j][k][1])%mod;
							dp[i][j+0][k+1][1]=(dp[i][j+0][k+1][1]+w01*dp[i-1][j][k][1])%mod;
							dp[i][j+0][k+0][1]=(dp[i][j+0][k+0][1]+w00*dp[i-1][j][k][1])%mod;
						}
					}
				}
			}
			for(int i=mid; i<=n ;i++){
				for(int j=0; j<mid ;j++){
					ans=(ans+dp[n][i][j][1])%mod;
				}
			}
			//cout << ans << '\n';
		}
		{
			for(int i=0; i<=n ;i++){
				for(int j=0; j<=n ;j++){
					for(int k=0; k<=n ;k++){
						dp[i][j][k][0]=dp[i][j][k][1]=0;
					}
				}
			}
			dp[0][0][0][0]=1;
			for(int i=1; i<=n ;i++){
				for(int j=0; j<i ;j++){
					for(int k=0; k<i ;k++){
						{//0->0
							ll w11=z;
							ll w10=0;
							ll w01=0;
							ll w00=m-z;
							dp[i][j+1][k+1][0]=(dp[i][j+1][k+1][0]+w11*dp[i-1][j][k][0])%mod;
							dp[i][j+1][k+0][0]=(dp[i][j+1][k+0][0]+w10*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+1][0]=(dp[i][j+0][k+1][0]+w01*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+0][0]=(dp[i][j+0][k+0][0]+w00*dp[i-1][j][k][0])%mod;
						}
						{//0->1
							ll w11=(z-1)*z/2;
							ll w10=z*(m-z);
							ll w01=0;
							ll w00=(m-z)*(m-z-1)/2;
							dp[i][j+1][k+1][1]=(dp[i][j+1][k+1][1]+w11*dp[i-1][j][k][0])%mod;
							dp[i][j+1][k+0][1]=(dp[i][j+1][k+0][1]+w10*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+1][1]=(dp[i][j+0][k+1][1]+w01*dp[i-1][j][k][0])%mod;
							dp[i][j+0][k+0][1]=(dp[i][j+0][k+0][1]+w00*dp[i-1][j][k][0])%mod;
							
						}
						{//1->1
							ll w11=z*z;
							ll w10=z*(m-z);
							ll w01=(m-z)*z;
							ll w00=(m-z)*(m-z);
							dp[i][j+1][k+1][1]=(dp[i][j+1][k+1][1]+w11*dp[i-1][j][k][1])%mod;
							dp[i][j+1][k+0][1]=(dp[i][j+1][k+0][1]+w10*dp[i-1][j][k][1])%mod;
							dp[i][j+0][k+1][1]=(dp[i][j+0][k+1][1]+w01*dp[i-1][j][k][1])%mod;
							dp[i][j+0][k+0][1]=(dp[i][j+0][k+0][1]+w00*dp[i-1][j][k][1])%mod;
						}
					}
				}
			}
			for(int i=mid; i<=n ;i++){
				for(int j=0; j<mid ;j++){
					ans=(ans+mod-dp[n][i][j][1])%mod;
				}
			}
			//cout << ans << '\n';
		}
	}
	cout << ans << '\n';
}