AVOIDWALK - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iceknight1093
Tester: mridulahi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dynamic programming

PROBLEM:

There’s an N\times M grid, on which you want to move from (1, 1) to (N, M) by moving right or down each step.
It costs A_{i, j} to move onto cell (i, j).
Further, some cells have birds on them. A bird attacks you when you’re on a cell adjacent to it; each bird attacks at most once, and each attack costs an extra K units of time.

Find the shortest time required to reach (N, M).
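To pin down the cost model, here is a brute-force checker that simulates every right/down path on a tiny grid. It is only a sketch under stated assumptions. The start cell’s cost is not paid, since we only pay for cells we move onto. Each bird attacks at most once. No bird sits on (1, 1), (1, 2) or (2, 1): both reference solutions below start from 0 at (1, 1) and never charge birds on those cells, so we assume the input rules them out. The names bruteForce, a, bird and K are hypothetical, and the code uses 0-based indices.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <string>
#include <vector>
using namespace std;

// Brute force over every right/down path; only usable for tiny grids.
// a[i][j]: cost of moving onto (i, j) (0-indexed); bird[i][j] == '1' marks a bird.
// Assumption: no bird on (0, 0), (0, 1) or (1, 0), matching the reference codes.
long long bruteForce(const vector<vector<long long>>& a,
                     const vector<string>& bird, long long K) {
    const int n = static_cast<int>(a.size());
    const int m = static_cast<int>(a[0].size());
    const int steps = n + m - 2;          // every path makes exactly this many moves
    assert(steps < 25);                   // 2^steps masks are enumerated below
    const int di[4] = {-1, 1, 0, 0}, dj[4] = {0, 0, -1, 1};

    long long best = LLONG_MAX;
    // Bit s of mask decides move s + 1: 1 = down, 0 = right. A mask whose path
    // stays inside the grid for all moves necessarily ends at (n-1, m-1).
    for (int mask = 0; mask < (1 << steps); ++mask) {
        vector<vector<bool>> attacked(n, vector<bool>(m, false));
        int i = 0, j = 0;
        long long cost = 0;
        bool valid = true;
        for (int s = 0; s <= steps; ++s) {
            if (s > 0) {
                if ((mask >> (s - 1)) & 1) ++i; else ++j;
                if (i >= n || j >= m) { valid = false; break; }  // left the grid
                cost += a[i][j];          // entering cost; the start cell is free
            }
            // Every bird next to the current cell that has not attacked yet attacks now.
            for (int d = 0; d < 4; ++d) {
                int x = i + di[d], y = j + dj[d];
                if (x < 0 || y < 0 || x >= n || y >= m) continue;
                if (bird[x][y] == '1' && !attacked[x][y]) {
                    attacked[x][y] = true;
                    cost += K;
                }
            }
        }
        if (valid) best = min(best, cost);
    }
    return best;
}
```

For example, on a 2×2 grid with costs {{0, 1}, {2, 3}}, K = 10 and a single bird on (2, 2), both paths pass a neighbour of that bird, so the result is 4 + 10 = 14. A checker like this is handy for stress-testing a faster solution on random small grids.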

EXPLANATION:

The setup (a grid with right/down moves and a minimization objective) screams dynamic programming, which is exactly what we’ll use for the solution.

Let f(i, j) denote the minimum time required to reach cell (i, j).
When we reach this cell, a few things happen:

  • A cost of A_{i, j} is always incurred.
  • Birds at (i+1, j) and (i, j+1), if they exist, will attack us.
    Note that these birds cannot have attacked us earlier on the path: every earlier cell lies in rows \leq i and columns \leq j, and no such cell other than (i, j) itself is adjacent to (i+1, j) or (i, j+1).
  • If there’s a bird on (i-1, j) that hasn’t attacked us yet, it attacks now.
    This can only happen when moving right from (i, j-1). If we instead moved down from (i-1, j), we were standing on the bird’s own cell, and the cell before it on the path ((i-2, j) or (i-1, j-1)) is adjacent to it, so the bird has already attacked.
  • Similarly, if there’s a bird on (i, j-1) that hasn’t attacked us yet, it attacks now.
    This time, the case only applies when moving down from (i-1, j), not when moving right from (i, j-1).

So, to accurately compute transitions, we need the following information:

  • When moving right from (i, j-1) to (i, j), has the bird at (i-1, j) attacked us already (if it exists)?
  • When moving down from (i-1, j), has the bird at (i, j-1) attacked us already (if it exists)?

With this in mind, let’s augment our states a bit.
When moving away from (i, j), we need to know the states of birds at (i-1, j+1) (for the right movement), and (i+1, j-1) (for the down movement).
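So, let’s define f(i, j, k) to be the minimum cost to end up at (i, j), where:

  • k = 0 means that neither the bird at (i-1, j+1) nor the bird at (i+1, j-1) is pending, i.e. each one either doesn’t exist or has already attacked.
  • k = 1 means that there’s a bird at (i-1, j+1) that hasn’t attacked yet.
  • k = 2 means that there’s a bird at (i+1, j-1) that hasn’t attacked yet.

These three cases cover everything, because both birds can never be pending at once: if we arrived from the left, the bird at (i+1, j-1) was adjacent to (i, j-1) and has already attacked; if we arrived from above, the bird at (i-1, j+1) was adjacent to (i-1, j) and has already attacked.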

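Computing transitions is now relatively straightforward. Every transition into (i, j) pays A_{i, j}, plus K for each bird on (i+1, j) or (i, j+1); the cases below describe only what is paid on top of that.

  • When moving rightwards from (i, j-1) to (i, j):
    • States (i, j-1, 0) and (i, j-1, 2) result in no extra cost, and contribute to either (i, j, 0) or (i, j, 1) (depending on whether there’s a bird at (i-1, j+1)).
    • State (i, j-1, 1) adds an extra K to the cost, since the bird at (i-1, j) attacks now, and again contributes to either (i, j, 0) or (i, j, 1).
  • Similarly, when moving downwards from (i-1, j) to (i, j):
    • States (i-1, j, 0) and (i-1, j, 1) have no extra cost, and contribute to either (i, j, 0) or (i, j, 2) (depending on whether there’s a bird at (i+1, j-1)).
    • State (i-1, j, 2) adds an extra K, since the bird at (i, j-1) attacks now, and contributes to either (i, j, 0) or (i, j, 2).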
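
The transitions above can be summarized as recurrences. Here b_{i, j} \in \{0, 1\} indicates a bird on (i, j) (taken as 0 outside the grid), and the helper names c, L and U are introduced only for this summary:

```latex
\begin{aligned}
c(i,j)   &= A_{i,j} + K\,(b_{i+1,j} + b_{i,j+1}) \\
L(i,j)   &= c(i,j) + \min\bigl(f(i,j-1,0),\ f(i,j-1,2),\ f(i,j-1,1) + K\bigr) \\
U(i,j)   &= c(i,j) + \min\bigl(f(i-1,j,0),\ f(i-1,j,1),\ f(i-1,j,2) + K\bigr) \\
f(i,j,1) &= L(i,j) \text{ if } b_{i-1,j+1} = 1, \text{ else } \infty \\
f(i,j,2) &= U(i,j) \text{ if } b_{i+1,j-1} = 1, \text{ else } \infty \\
f(i,j,0) &= \min\bigl(\{L(i,j) : b_{i-1,j+1} = 0\} \cup \{U(i,j) : b_{i+1,j-1} = 0\}\bigr)
\end{aligned}
```

These hold for (i, j) \neq (1, 1), with f(i, 0, \cdot) = f(0, j, \cdot) = \infty, \min\varnothing = \infty, base case f(1, 1, 0) = 0, f(1, 1, 1) = f(1, 1, 2) = \infty, and answer f(N, M, 0).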

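The base case is f(1, 1, 0) = 0, and the answer is f(N, M, 0), since at (N, M) both (N-1, M+1) and (N+1, M-1) lie outside the grid. This way, we have \mathcal{O}(N\cdot M) states and \mathcal{O}(1) transitions from each, leading to a solution in \mathcal{O}(N\cdot M) overall.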
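
The three-state DP can be sketched in C++ as follows. This is a minimal, 0-indexed sketch under the same assumption as the reference codes (the start cell is free and no bird is charged there). The names minTime, a, bird and K are hypothetical, and the state labels follow the explanation above rather than either reference code.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <limits>
#include <string>
#include <vector>
using namespace std;

// f[i][j][k] = minimum time to stand on (i, j) (0-indexed), where
//   k = 0: no pending bird on (i-1, j+1) or (i+1, j-1),
//   k = 1: a bird on (i-1, j+1) has not attacked yet,
//   k = 2: a bird on (i+1, j-1) has not attacked yet.
long long minTime(const vector<vector<long long>>& a,
                  const vector<string>& bird, long long K) {
    const long long INF = numeric_limits<long long>::max() / 4;
    const int n = static_cast<int>(a.size());
    const int m = static_cast<int>(a[0].size());
    assert(static_cast<int>(bird.size()) == n);

    // b(i, j) = 1 if (i, j) is inside the grid and holds a bird, else 0.
    auto b = [&](int i, int j) -> long long {
        return (i >= 0 && j >= 0 && i < n && j < m && bird[i][j] == '1') ? 1 : 0;
    };

    vector<vector<array<long long, 3>>> f(
        n, vector<array<long long, 3>>(m, {INF, INF, INF}));
    f[0][0][0] = 0;  // start: nothing paid

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            if (i == 0 && j == 0) continue;
            // Paid on every arrival: the cell itself and the birds below / to the right.
            const long long base = a[i][j] + K * (b(i + 1, j) + b(i, j + 1));
            if (j > 0) {
                // From the left: a pending bird on (i-1, j) (state 1 of (i, j-1)) attacks now.
                const auto& p = f[i][j - 1];
                long long v = base + min({p[0], p[2], p[1] + K});
                int s = b(i - 1, j + 1) ? 1 : 0;
                f[i][j][s] = min(f[i][j][s], v);
            }
            if (i > 0) {
                // From above: a pending bird on (i, j-1) (state 2 of (i-1, j)) attacks now.
                const auto& p = f[i - 1][j];
                long long v = base + min({p[0], p[1], p[2] + K});
                int s = b(i + 1, j - 1) ? 2 : 0;
                f[i][j][s] = min(f[i][j][s], v);
            }
        }
    }
    // At (n-1, m-1) both tracked cells lie outside the grid, so only state 0 is used.
    return f[n - 1][m - 1][0];
}
```

For instance, on a 2×3 grid with costs {{1, 100, 1}, {1, 1, 1}}, K = 10 and a single bird in the top-right corner, the cheapest route goes down first and then right twice. It pays 3 for the cells, plus 10 when the pending bird attacks on the last move, so minTime returns 13.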

TIME COMPLEXITY:

\mathcal{O}(N \cdot M) per testcase.

CODE:

Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long

const int inf = 1e18;

vector<array<int, 2>> adj = {{-1, 0}, {0, -1}, {1, 0}, {0, 1}};

signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);

        int t;
        cin >> t;

        while (t--) {

                int n, m, k;
                cin >> n >> m >> k;
                vector<vector<int>> a(n, vector<int>(m));  // heap-allocated; a VLA here is a GCC extension
                for (int i = 0; i < n; i++) {
                        for (int j = 0; j < m; j++) cin >> a[i][j];
                }

                vector<string> s(n);  // s[i][j] == '1' marks a bird
                for (int i = 0; i < n; i++) cin >> s[i];

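                // dp[i][j][0]: best time on reaching (i, j) with a right move,
                // dp[i][j][1]: best time on reaching (i, j) with a down move.
                vector<vector<array<int, 2>>> dp(n, vector<array<int, 2>>(m, {inf, inf}));

                // The start cell is treated as reached by either kind of move.
                dp[0][0][0] = 0;
                dp[0][0][1] = 0;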
                
                for (int i = 0; i < n; i++) {
                        for (int j = 0; j < m; j++) {
                                
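                                if (i + 1 < n) {
                                        // Move down to (i + 1, j): always pay for the cell and for
                                        // birds to its right (i + 1, j + 1) and below it (i + 2, j).
                                        int f = 0;
                                        if (j + 1 < m && s[i + 1][j + 1] == '1') f += k;
                                        if (i + 2 < n && s[i + 2][j] == '1') f += k;
                                        // Came into (i, j) moving right: the bird at (i + 1, j - 1)
                                        // was below (i, j - 1) and has already attacked.
                                        dp[i + 1][j][1] = min(dp[i + 1][j][1], dp[i][j][0] + a[i + 1][j] + f);
                                        // Came into (i, j) moving down: that bird attacks now.
                                        if (j - 1 >= 0 && s[i + 1][j - 1] == '1') f += k;
                                        dp[i + 1][j][1] = min(dp[i + 1][j][1], dp[i][j][1] + a[i + 1][j] + f);
                                }
                                if (j + 1 < m) {
                                        // Move right to (i, j + 1): always pay for the cell and for
                                        // birds below it (i + 1, j + 1) and to its right (i, j + 2).
                                        int f = 0;
                                        if (i + 1 < n && s[i + 1][j + 1] == '1') f += k;
                                        if (j + 2 < m && s[i][j + 2] == '1') f += k;
                                        // Came into (i, j) moving down: the bird at (i - 1, j + 1)
                                        // was right of (i - 1, j) and has already attacked.
                                        dp[i][j + 1][0] = min(dp[i][j + 1][0], dp[i][j][1] + a[i][j + 1] + f);
                                        // Came into (i, j) moving right: that bird attacks now.
                                        if (i - 1 >= 0 && s[i - 1][j + 1] == '1') f += k;
                                        dp[i][j + 1][0] = min(dp[i][j + 1][0], dp[i][j][0] + a[i][j + 1] + f);
                                }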
                        }
                        
                }

                cout << min(dp[n - 1][m - 1][0], dp[n - 1][m - 1][1]) << "\n";


        }
        
}
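
The tester’s code uses a different, two-state formulation. Reading it, dp[i][j][0] appears to hold the best time on reaching (i, j) with a right move and dp[i][j][1] the best time with a down move. In 1-indexed form, with b_{i, j} \in \{0, 1\} marking a bird (0 outside the grid) and our own names g, R, D and c, it corresponds to:

```latex
\begin{aligned}
c(i,j)   &= A_{i,j} + K\,(b_{i+1,j} + b_{i,j+1}) \\
g(i,j,R) &= c(i,j) + \min\bigl(g(i,j-1,D),\ g(i,j-1,R) + K\,b_{i-1,j}\bigr) \\
g(i,j,D) &= c(i,j) + \min\bigl(g(i-1,j,R),\ g(i-1,j,D) + K\,b_{i,j-1}\bigr) \\
g(1,1,R) &= g(1,1,D) = 0, \qquad \text{answer} = \min\bigl(g(N,M,R),\ g(N,M,D)\bigr)
\end{aligned}
```

Intuitively, after a down move into (i, j-1), the bird at (i-1, j) was next to (i-1, j-1) and has already attacked. After a right move into (i, j-1), no earlier cell was next to it. The down move is symmetric.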
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

const int N = 1005;
const ll inf = 1e18;
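// dp[i][j][x] for x: 0 = no pending bird, 1 = pending bird at (i+1, j-1),
// 2 = pending bird at (i-1, j+1). Labels 1 and 2 are swapped relative to the explanation.
ll dp[N][N][4];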
int a[N][N], mark[N][N];

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    for (int i = 0; i < N; ++i) for (int x = 0; x < 4; ++x)
        dp[0][i][x] = dp[i][0][x] = inf;

    int t; cin >> t;
    while (t--) {
        ll n, m, k; cin >> n >> m >> k;
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= m; ++j)
                cin >> a[i][j];
        for (int i = 1; i <= n+1; ++i) {
            string s;
            if (i <= n) cin >> s;
            else s = string(m, '0');
            for (int j = 1; j <= m; ++j) mark[i][j] = s[j-1] == '1';
            mark[i][m+1] = 0;
        }
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= m; ++j)
                for (int x = 0; x < 4; ++x)
                    dp[i][j][x] = inf;

        dp[1][1][0] = 0;
        for (int i = 1; i <= n; ++i) for (int j = 1; j <= m; ++j) {
            if (i == 1 and j == 1) continue;

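            ll add = k*mark[i+1][j] + k*mark[i][j+1] + a[i][j]; // Always taken: cell cost + birds below and to the right

            // From left -> either 0 or 2. States 0 and 1 of (i, j-1) pay nothing extra;
            // state 2 means the bird at (i-1, j) has not attacked yet, so it attacks now.
            ll fromLeft = min({dp[i][j-1][0] + add, dp[i][j-1][1] + add, dp[i][j-1][2] + add + k});
            if (mark[i-1][j+1]) dp[i][j][2] = fromLeft;
            else dp[i][j][0] = fromLeft;

            // From up -> either 0 or 1. States 0 and 2 of (i-1, j) pay nothing extra;
            // state 1 means the bird at (i, j-1) has not attacked yet, so it attacks now.
            ll fromUp = min({dp[i-1][j][0] + add, dp[i-1][j][2] + add, dp[i-1][j][1] + add + k});
            if (mark[i+1][j-1]) dp[i][j][1] = fromUp;
            else dp[i][j][0] = min(dp[i][j][0], fromUp);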
        }
        cout << dp[n][m][0] << '\n';
    }
}