CROSSPATH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: gunpoint_88
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

2930

PREREQUISITES:

Dynamic programming, prefix sums

PROBLEM:

You’re given N and M.
First, a random path from (1, 1) to (N, M) is chosen, moving only down or right each time.
Then, Bob moves from (1, M) to (N, 1), via down or left steps.
He gets one gold coin each time he visits a cell that was on the original path.

What’s the expected maximum number of coins he’ll receive?

EXPLANATION:

First, an edge case: if N = M = 1, the answer is obviously 1.

Now, let’s look at the general case.
Suppose the path from (1, 1) to (N, M) is fixed.
Then, Bob receives one coin for each cell from this path he’s able to visit - in other words, he receives coins equal to the size of the intersection between the two paths.

Observe that this intersection can only be a vertical or horizontal segment.

Proof

If the intersection is of size 1, the condition is trivially true; so now suppose there are at least two cells.
Among all the cells in the intersection, let (x_1, y_1) denote the lexicographically smallest among them, and (x_2, y_2) denote the largest.

We have x_1 \leq x_2.

If x_1 \lt x_2,

  • The cells belong to the (1, 1) \to (N, M) path, so y_1 \leq y_2 must hold.
  • The cells belong to the (1, M)\to (N, 1) path, so y_1\geq y_2 must hold.

So, y_1 = y_2 must hold, i.e, both cells belong to the same column.
Clearly then, everything between them will also belong to the intersection; and since we chose the lexicographically smallest and largest elements, nothing else can possibly belong to the intersection.
So, the intersection is a vertical segment, as claimed.

If x_1 = x_2 instead, then we must have y_1 \lt y_2.
In this case, similarly we see that the intersection is exactly the horizontal segment between (x_1, y_1) and (x_2, y_2).

These are the only possibilities, and hence our claim is proved.

In particular, if the (1, 1) \to (N, M) path is fixed, Bob will clearly choose the longest horizontal/vertical segment in it.
So, if there are c_x paths from (1, 1) \to (N, M) whose longest segment has length exactly x, they contribute a total of x\cdot \frac{c_x}{T} to the expectation (where T is the total number of paths).

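Clearly, x \leq \max(N, M), so if we’re able to compute all the c_x values, we’d be done. (Also, unless N = M = 1, we have x \geq 2: every straight segment of a path with at least two cells contains at least two cells.)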

Let’s fix x and try to compute c_x.
This can be done with the help of dynamic programming.
Let f_x(i, j, 0) denote the number of paths from (1, 1) to (i, j) such that the longest segment has length at most x, and the last segment of the path was downwards.
Similarly, let f_x(i, j, 1) be the same thing, but where the last segment of the path was rightwards.
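Then, fixing the number of steps k taken in the last segment (note that a segment of x cells spans only x-1 steps), we have

f_x(i, j, 0) = \sum_{k = 1}^{x-1} f_x(i-k, j, 1) \\ f_x(i, j, 1) = \sum_{k = 1}^{x-1} f_x(i, j-k, 0)

where terms with i-k < 1 or j-k < 1 are zero, and the starting cell acts as the end of both a vertical and a horizontal segment, i.e. f_x(1, 1, 0) = f_x(1, 1, 1) = 1.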

For each fixed x, this gives us a dynamic programming solution in \mathcal{O}(NMx), which is easily sped up to \mathcal{O}(NM) per x by maintaining prefix sums for each row/column (since each transition is the sum of dp values over a contiguous range of a row/column).

Finally, the number of paths reaching (N, M) whose longest segment is of length at most x, is f_x(N, M, 0) + f_x(N, M, 1).
To obtain the number of paths whose longest segment has length exactly x, subtract from it the number of paths whose longest segment has length \leq x-1 (which is just f_{x-1}(N, M, 0) + f_{x-1}(N, M, 1)).

TIME COMPLEXITY:

\mathcal{O}(NM\cdot\max(N, M)) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll mod=1e9+7;

// Computes base^exp modulo `mod` by binary exponentiation.
// exp must be non-negative (a negative exp would never reach 0 under >>=).
// base may be any value; negatives are normalized into [0, mod) first,
// so the result is always in [0, mod).
ll fastexp(ll base,ll exp) {
	assert(exp>=0);
	ll res=1;
	base%=mod;
	if(base<0) base+=mod;   // C++ % keeps the sign of the dividend
	while(exp) {
		if(exp&1) res=(res*base)%mod;
		base=(base*base)%mod;
		exp>>=1;
	}
	return res;
}

// Modular inverse of num via Fermat's little theorem: num^(mod-2).
// Yields 0 when num is a multiple of mod (no inverse exists).
ll modinv(ll num) {
	return fastexp(num, mod-2);
}

// Returns the expected maximum number of coins as a residue mod `mod`.
// For each cap k (in cells) on the longest straight run, counts monotone
// (0,0)->(n-1,m-1) paths whose runs all have at most k cells; consecutive
// differences of these counts give the number of paths whose longest run is
// exactly k, and the answer is sum(k * exactly_k) / total_paths.
ll solution(ll n,ll m) {
	// 1x1 grid: the single cell is the whole path.
	if(n==1 && m==1) return 1;
	ll weighted=0;     // sum over k of k * (#paths with longest run exactly k)
	ll prevAtMost=0;   // #paths with every run <= k-1 cells
	ll totalPaths=0;   // running sum of the "exactly k" counts
	const ll limit=max(n,m);
	for(ll k=2;k<=limit;k++) {
		// all[r][c]: paths to (r,c) with every run <= k cells.
		// canDown[r][c] / canRight[r][c]: the subset that may take one more
		// down / right step without a run exceeding k cells.
		vector<vector<ll>> all(n,vector<ll>(m,0));
		vector<vector<ll>> canRight=all;
		vector<vector<ll>> canDown=all;
		// Top row: reachable only by moving right, as one run of c+1 cells.
		for(ll c=0;c<m;c++) {
			all[0][c]=(c<k);
			canRight[0][c]=(c+1<k);
			canDown[0][c]=(c<k);
		}
		// Left column: reachable only by moving down, as one run of r+1 cells.
		for(ll r=0;r<n;r++) {
			all[r][0]=(r<k);
			canDown[r][0]=(r+1<k);
			canRight[r][0]=(r<k);
		}
		for(ll r=1;r<n;r++) {
			for(ll c=1;c<m;c++) {
				const ll here=(canDown[r-1][c]+canRight[r][c-1])%mod;
				all[r][c]=here;
				// Paths whose vertical run here already spans k cells entered
				// column c at row r-(k-1) from the left; they cannot go down.
				canDown[r][c]=(r>=k-1) ? (here-canRight[r-(k-1)][c-1]+mod)%mod : here;
				// Symmetrically for a horizontal run of k cells ending here.
				canRight[r][c]=(c>=k-1) ? (here-canDown[r-1][c-(k-1)]+mod)%mod : here;
			}
		}
		const ll atMost=all[n-1][m-1];
		const ll exactly=(atMost-prevAtMost+mod)%mod;
		prevAtMost=atMost;
		weighted=(weighted+exactly*k)%mod;
		totalPaths=(totalPaths+exactly)%mod;
	}
	return weighted*modinv(totalPaths)%mod;
}

int main() {
    ll t; cin>>t;
    while(t--) {
        ll n,m; cin>>n>>m;
        cout<<solution(n,m)<<"\n";
    }
	return 0;
}
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

// Strict input validator: slurps all of stdin in the constructor, then reads
// tokens and exact separators from that buffer, aborting via assert on any
// deviation from the expected format or bounds.
struct input_checker {
	string buffer;  // entire contents of stdin
	int pos;        // index of the next unread character in buffer

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == char_traits<char>::eof()) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	// True iff s is the canonical decimal spelling of an integer: an optional
	// '-', then one or more digits, with no leading zeros and no "-0".
	// Rejects anything stoi/stoll would silently accept by prefix parsing,
	// such as "12x", "+5", "05" or a trailing '\r'.
	static bool isCanonicalInteger(const string &s) {
		size_t start = (!s.empty() && s[0] == '-') ? 1 : 0;
		if (start == s.size()) return false;
		for (size_t i = start; i < s.size(); i++) {
			if (s[i] < '0' || s[i] > '9') return false;
		}
		if (s[start] == '0') return s == "0";
		return true;
	}

	// Position of the next ' ' or '\n' at or after pos (or end of buffer).
	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	// Consumes and returns the characters up to (not including) the next delimiter.
	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		string token = readOne();
		assert(isCanonicalInteger(token));
		int res = stoi(token);
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		string token = readOne();
		assert(isCanonicalInteger(token));
		long long res = stoll(token);
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	// Reads n integers separated by single spaces.
	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	// Reads n long longs separated by single spaces.
	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

constexpr int mod = (int)1e9 + 7;
// Integer modulo `mod`. Invariant: v is always in [0, mod); every operation,
// including construction, stream input and increment/decrement, preserves it.
struct mi {
    int64_t v; explicit operator int64_t() const { return v % mod; }
    mi() { v = 0; }
    mi(int64_t _v) {
        v = (-mod < _v && _v < mod) ? _v : _v % mod;
        if (v < 0) v += mod;
    }
    friend bool operator==(const mi& a, const mi& b) {
        return a.v == b.v; }
    friend bool operator!=(const mi& a, const mi& b) {
        return !(a == b); }
    friend bool operator<(const mi& a, const mi& b) {
        return a.v < b.v; }

    mi& operator+=(const mi& m) {
        if ((v += m.v) >= mod) v -= mod;
        return *this; }
    mi& operator-=(const mi& m) {
        if ((v -= m.v) < 0) v += mod;
        return *this; }
    mi& operator*=(const mi& m) {
        v = v*m.v%mod; return *this; }
    mi& operator/=(const mi& m) { return (*this) *= inv(m); }
    // a^p by binary exponentiation; p must be non-negative.
    friend mi pow(mi a, int64_t p) {
        mi ans = 1; assert(p >= 0);
        for (; p; p /= 2, a *= a) if (p&1) ans *= a;
        return ans;
    }
    // Inverse via Fermat's little theorem (mod is prime); a must be nonzero.
    friend mi inv(const mi& a) { assert(a.v != 0);
        return pow(a,mod-2); }

    mi operator-() const { return mi(-v); }
    mi& operator++() { return *this += 1; }
    mi& operator--() { return *this -= 1; }
    // Postfix forms return the old value and step via the wrapping prefix
    // operators, so mod-1 wraps to 0 and 0 wraps to mod-1.
    mi operator++(int32_t) { mi old = *this; ++*this; return old; }
    mi operator--(int32_t) { mi old = *this; --*this; return old; }
    friend mi operator+(mi a, const mi& b) { return a += b; }
    friend mi operator-(mi a, const mi& b) { return a -= b; }
    friend mi operator*(mi a, const mi& b) { return a *= b; }
    friend mi operator/(mi a, const mi& b) { return a /= b; }
    friend ostream& operator<<(ostream& os, const mi& m) {
        os << m.v; return os;
    }
    // Reads a signed 64-bit integer and reduces it into [0, mod).
    friend istream& operator>>(istream& is, mi& m) {
        int64_t x = 0; is >> x;
        m = mi(x);
        return is;
    }
    friend void __print(const mi &x) {
        cerr << x.v;
    }
};


// Trial-division primality test. Values below 2 (0, 1, negatives) are not prime.
bool prime(int s) {
    if (s < 2) return false;
    // i <= s / i is equivalent to i * i <= s but cannot overflow int for s
    // near INT_MAX.
    for (int i = 2 ; i <= s / i ; i++) {
        if (s % i == 0)  return false;
    }
    return true;
}

// DP tables for the tester's solution, indexed [row][col][last] with 1-based
// rows/cols; row 0 and column 0 are zero sentinels. Sized for n, m <= 250.
// dp[..][1]: paths whose final straight run is horizontal (entered from the left);
// dp[..][0]: final run vertical (entered from above).
// pf[r][c][0] is a running sum of dp[r][.][0] along row r;
// pf[r][c][1] is a running sum of dp[.][c][1] along column c.
constexpr int N = 251;
mi dp[N][N][2];
mi pf[N][N][2];

// Zeroes dp and pf on rows [0, n] and columns [0, m], which is every cell
// the DP pass reads or writes for an n x m grid.
void init(int n, int m) {
    for(int i = 0 ; i <= n ; i++)
        for(int j = 0 ; j <= m ; j++)
            for(int last = 0 ; last < 2 ; last++)
                dp[i][j][last] = pf[i][j][last] = 0;
}

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;
    int sum_n = 0, sum_m = 0;
    int T = input.readInt(1, 100);  input.readEoln();
    // int T;  cin >> T;
    while(T--) {
        int n = input.readInt(1, 250);  input.readSpace();
        int m = input.readInt(1, 250);  input.readEoln();

        // int n, m;   cin >> n >> m;
        sum_n += n, sum_m += m;
        int K = max(n, m);
        if(n == 1 || m == 1) {
            cout << max(n, m) << '\n';
            continue;
        }

        vector<mi> val(K + 1);
        for(int k = 2 ; k <= K ; k++) {
            init(n, m);
            pf[1][1][1] = pf[1][1][0] = 1;
            for(int row = 1 ; row <= n ; row++) {
                for(int col = 1 ; col <= m ; col++) {
                    if(row == 1 && col == 1)    continue;
                    dp[row][col][1] = pf[row][col - 1][0] - pf[row][max(0, col - k)][0];
                    pf[row][col][1] = pf[row - 1][col][1] + dp[row][col][1];

                    dp[row][col][0] = pf[row - 1][col][1] - pf[max(0, row - k)][col][1];
                    pf[row][col][0] += pf[row][col - 1][0] + dp[row][col][0];
                }
            }
            val[k] = dp[n][m][1] + dp[n][m][0];
        }

        mi res = 0, div = val[K];
        for(int i = K ; i >= 1 ; i--) {
            val[i] -= val[i - 1];
            res += i * val[i];
        }
        cout << res * inv(div) << '\n';
    }
    assert(sum_n <= 250 && sum_m <= 250);

    input.readEof();

    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7


def solve(n, m):
    """Return the expected maximum number of coins Bob collects, mod 1e9+7.

    For each cap k on the length (in cells) of the longest straight segment,
    count monotone (1,1)->(n,m) paths whose segments all have at most k cells.
    Consecutive differences give the number of paths whose longest segment is
    exactly k, and the answer is sum(k * exact_k) / total_paths.
    """
    if n == 1 and m == 1:
        return 1
    # dp[i][j][0] / dp[i][j][1]: paths to (i, j) with every segment <= k cells
    # whose last segment is vertical / horizontal.
    # pref[i][j][0]: running sum of dp[i][.][0] along row i up to column j.
    # pref[i][j][1]: running sum of dp[.][j][1] along column j up to row i.
    # Row 0 and column 0 are zero sentinels. Tables are sized per test (no
    # fixed upper bound). Every cell read for a given k is recomputed earlier
    # in the same pass, so they can be reused across k.
    dp = [[[0, 0] for _ in range(m + 1)] for _ in range(n + 1)]
    pref = [[[0, 0] for _ in range(m + 1)] for _ in range(n + 1)]
    # The start cell ends both a vertical and a horizontal segment.
    # Two separate lists, so dp and pref never alias.
    dp[1][1] = [1, 1]
    pref[1][1] = [1, 1]

    ans, prv = 0, 0
    # A path with at least two cells has no segment shorter than 2 cells.
    for k in range(2, max(n, m) + 1):
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                if i == 1 and j == 1:
                    continue
                # Vertical last segment: came down 1..k-1 steps from a cell whose
                # last segment was horizontal.
                dp[i][j][0] = (pref[i - 1][j][1] - pref[max(0, i - k)][j][1]) % mod
                pref[i][j][0] = (dp[i][j][0] + pref[i][j - 1][0]) % mod

                # Horizontal last segment: came right 1..k-1 steps.
                dp[i][j][1] = (pref[i][j - 1][0] - pref[i][max(0, j - k)][0]) % mod
                pref[i][j][1] = (dp[i][j][1] + pref[i - 1][j][1]) % mod
        at_most = (dp[n][m][0] + dp[n][m][1]) % mod
        ans = (ans + k * (at_most - prv)) % mod
        prv = at_most
    # prv now counts all paths, which is nonzero mod p for the given bounds.
    return ans * pow(prv, mod - 2, mod) % mod


def main():
    """Read T test cases of "n m" from stdin and print one answer per line."""
    for _ in range(int(input())):
        n, m = map(int, input().split())
        print(solve(n, m))


if __name__ == "__main__":
    main()
1 Like

Nice problem and nice approach of solution i really liked it