SPREADCT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: munch_01
Preparation: iceknight1093
Tester: mridulahi
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Basic math

PROBLEM:

You have an N\times M grid. Initially, exactly one cell is black; then, every second, the following happens:

  • Every cell adjacent to a black cell (sharing a side or a corner with it) turns black.

For a starting cell (x, y), let f(x, y) denote the number of seconds needed for every cell of the grid to become black.
Find the sum of x\cdot y over all cells (x, y) that attain the minimum value of f(x, y), modulo 998244353.

EXPLANATION:

Let’s first try to figure out which cells have minimum f(x, y).

Without loss of generality, let N \leq M; if not, swap them. Swapping just transposes the grid, which leaves every product x\cdot y unchanged, so the answer stays the same.

Consider starting cell (x, y).
Then cell (x', y') turns black after exactly \max(|x-x'|, |y-y'|) seconds, since every second the black region grows by one step in every direction, diagonals included.
This means that the horizontal and vertical components of movement are, in a sense, independent of each other. If enough time has passed to reach some cell in row i (at least |x-i| seconds) and some cell in column j (at least |y-j| seconds), then cell (i, j) has definitely been reached, since at least \max(|x-i|, |y-j|) seconds have passed.
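The distance claim can be checked by simulating the process directly. Below is a minimal Python sketch, assuming 1-indexed cells and that "adjacent" includes the 8 neighbours (sides and corners), which is what the \max formula implies; `spread_times` is a hypothetical helper name, not taken from the original solutions.

```python
from collections import deque

def spread_times(n, m, sx, sy):
    """BFS over the 8 neighbours of each cell, starting from (sx, sy).
    Returns t where t[i][j] is the second at which (i, j) turns black.
    Cells are 1-indexed; row 0 and column 0 are unused padding."""
    t = [[-1] * (m + 1) for _ in range(n + 1)]
    t[sx][sy] = 0
    q = deque([(sx, sy)])
    while q:
        i, j = q.popleft()
        for di in (-1, 0, 1):
            for dj in (-1, 0, 1):
                a, b = i + di, j + dj
                if 1 <= a <= n and 1 <= b <= m and t[a][b] == -1:
                    t[a][b] = t[i][j] + 1
                    q.append((a, b))
    return t

# The time to reach (i, j) should equal max(|sx - i|, |sy - j|).
n, m, sx, sy = 4, 6, 2, 5
t = spread_times(n, m, sx, sy)
assert all(t[i][j] == max(abs(sx - i), abs(sy - j))
           for i in range(1, n + 1) for j in range(1, m + 1))
```

A BFS is used because every spreading step costs exactly one second, so BFS layers coincide with the seconds of the process.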

That is, f(x, y) is exactly the maximum of the times needed to reach the four “borders” of the grid: rows 1 and N, and columns 1 and M.
So, f(x, y) = \max(x-1, N-x, y-1, M-y). Our objective is to find the cells that minimize this value.
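To sanity-check the border formula, the following Python sketch compares it with a brute force that takes the largest Chebyshev distance from (x, y) to any cell of the grid. Both function names are hypothetical and chosen only for illustration.

```python
def f_closed(n, m, x, y):
    # f(x, y) = max(x-1, N-x, y-1, M-y): distance to the farthest border row/column
    return max(x - 1, n - x, y - 1, m - y)

def f_brute(n, m, x, y):
    # the last cell turns black at the largest Chebyshev distance from (x, y)
    return max(max(abs(x - i), abs(y - j))
               for i in range(1, n + 1) for j in range(1, m + 1))

for n in range(1, 6):
    for m in range(1, 6):
        for x in range(1, n + 1):
            for y in range(1, m + 1):
                assert f_closed(n, m, x, y) == f_brute(n, m, x, y)
```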

Intuitively, these cells should be near the middle of the grid.
More formally, since (x-1) + (N-x) = N-1, the larger of the two terms is at least half of N-1, so \max(x-1, N-x) \geq \left\lceil \frac{N-1}{2} \right\rceil = \left\lfloor \frac{N}{2} \right\rfloor. Similarly,
\max(y-1, M-y) \geq \left\lfloor \frac{M}{2} \right\rfloor.

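Both bounds can be attained simultaneously, for instance by choosing x = \left\lceil \frac{N}{2} \right\rceil and y = \left\lceil \frac{M}{2} \right\rceil. Visually, this corresponds to choosing cells from the middle of the grid, which matches our intuition. Hence the minimum value of f(x, y) is \max\left(\left\lfloor \frac{N}{2} \right\rfloor, \left\lfloor \frac{M}{2} \right\rfloor\right).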
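A quick numerical check of the one-dimensional bound and of where it is attained, as a small Python sketch; `row_cost` is a hypothetical helper name.

```python
def row_cost(n, x):
    # seconds needed from row x to reach both row 1 and row N
    return max(x - 1, n - x)

for n in range(1, 50):
    best = min(row_cost(n, x) for x in range(1, n + 1))
    assert best == n // 2                        # the bound floor(N/2) is tight
    assert row_cost(n, (n + 1) // 2) == n // 2   # and is attained at x = ceil(N/2)
```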

Now, recall that we’re considering only N \leq M, so \left\lfloor \frac{N}{2} \right\rfloor \leq \left\lfloor \frac{M}{2} \right\rfloor.
That means the minimum possible f(x, y) is \left\lfloor \frac{M}{2} \right\rfloor; we just need to figure out exactly which cells attain it.

Let’s split into cases based on the parity of M.

Odd M

Suppose M = 2k+1.
Then, the minimum time needed is \left\lfloor \frac{M}{2} \right\rfloor = k.
Further, the only cells that can attain this are the ones in the “middle” column, i.e., column k+1.
Any other column y has \max(y-1, M-y) > k, so it requires extra time.

This leaves us with a choice of rows.
If we choose row x, it must satisfy the following conditions:

  • x-k \leq 1 (i.e., the first row must be reached in time).
    This translates to x \leq k+1.
  • x+k \geq N (the last row must be reached in time).
    This translates to x \geq N - k.
  • Of course, 1 \leq x \leq N must also hold.

So, we’re free to choose any x such that \max(1, N-k) \leq x \leq \min(N, 1+k).
y = k+1 is fixed for all of them.
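The row conditions above translate directly into code. This Python sketch (with the hypothetical helper `valid_rows`) computes the range and checks it against a direct scan of all rows, for every N \leq M = 2k+1 in a small range.

```python
def valid_rows(n, k):
    # rows x with x - k <= 1 and x + k >= n, clipped to the grid [1, n]
    lo = max(1, n - k)
    hi = min(n, 1 + k)
    return lo, hi

for k in range(0, 8):
    for n in range(1, 2 * k + 2):          # N <= M = 2k + 1
        lo, hi = valid_rows(n, k)
        brute = [x for x in range(1, n + 1) if max(x - 1, n - x) <= k]
        assert brute == list(range(lo, hi + 1))
```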

The answer is thus (k+1) multiplied by the sum of x over that range.
The sum of the integers from L to R (here L = \max(1, N-k) and R = \min(N, 1+k)) is
\frac{R\cdot (R+1)}{2} - \frac{L\cdot (L-1)}{2}
so the answer can be found in \mathcal{O}(1) time.
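The range-sum formula as a one-line Python helper (`range_sum` is a hypothetical name). The products R(R+1) and L(L-1) are always even, so integer division is exact.

```python
def range_sum(lo, hi):
    # sum of the integers lo..hi inclusive: R(R+1)/2 - L(L-1)/2
    return hi * (hi + 1) // 2 - lo * (lo - 1) // 2

assert range_sum(2, 5) == 2 + 3 + 4 + 5
assert range_sum(1, 1) == 1
```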

Even M

Suppose M = 2k.
Then, the minimum time required is k, and this time either of the two middle columns can be chosen, i.e., column k or column k+1: for y = k we get \max(k-1, k) = k, and for y = k+1 we get \max(k, k-1) = k.
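Which columns are optimal can be listed directly for both parities. In this Python sketch, `best_columns` is a hypothetical helper; it relies on the earlier fact that the minimum of \max(y-1, M-y) is \left\lfloor \frac{M}{2} \right\rfloor.

```python
def best_columns(m):
    # columns y minimising max(y-1, M-y); that minimum equals floor(M/2)
    k = m // 2
    return [y for y in range(1, m + 1) if max(y - 1, m - y) == k]

assert best_columns(7) == [4]      # odd M = 2k+1: only column k+1
assert best_columns(8) == [4, 5]   # even M = 2k: columns k and k+1
```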

The analysis for the choice of x remains exactly the same, and we see that
\max(1, N-k) \leq x \leq \min(N, 1+k)
should hold.

Once again, the sum of x over this range can be found in \mathcal{O}(1) time.
Multiply this value by k + (k+1) = M+1, since each valid row x can be paired with either of the two middle columns.
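Putting both cases together, here is a Python sketch of the \mathcal{O}(1) formula (without the final reduction modulo 998244353), cross-checked against a brute force over all cells for small grids. `solve` and `brute` are hypothetical names; the reference solutions follow below.

```python
def solve(n, m):
    # O(1) answer from the case analysis, before reducing modulo 998244353
    if n > m:
        n, m = m, n                      # ensure N <= M; transposing keeps x*y
    k = m // 2                           # minimum possible f(x, y)
    lo, hi = max(1, n - k), min(n, 1 + k)
    row_sum = hi * (hi + 1) // 2 - lo * (lo - 1) // 2
    col_sum = (k + 1) if m % 2 == 1 else (k + (k + 1))
    return row_sum * col_sum

def brute(n, m):
    # evaluate f(x, y) = max(x-1, N-x, y-1, M-y) for every cell
    f = {(x, y): max(x - 1, n - x, y - 1, m - y)
         for x in range(1, n + 1) for y in range(1, m + 1)}
    best = min(f.values())
    return sum(x * y for (x, y), v in f.items() if v == best)

for n in range(1, 9):
    for m in range(1, 9):
        assert solve(n, m) == brute(n, m)
```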

TIME COMPLEXITY:

\mathcal{O}(1) per testcase.

CODE:

Tester's code (C++)
#include<bits/stdc++.h>

using namespace std;

#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long

const int mod = 998244353;

signed main() {

        ios::sync_with_stdio(0);
        cin.tie(0);
        cout.tie(0);

        int t;
        cin >> t;

        while (t--) {

                int n, m;
                cin >> n >> m;
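                if (n < m) swap(n, m);  // n = larger side (M in the editorial), m = smaller side (N)

                int d = n / 2;                  // minimum f(x, y) = floor(max(N, M) / 2)
                int r = min(d + 1, m);          // valid rows of the shorter side form [l, r]
                int l = max(1ll, m - d);
                int sm = (r - l + 1) * (l + r) / 2 % mod;  // sum of valid row indices
                
                // odd n: single middle column (n+1)/2; even n: columns n/2 and n/2 + 1
                int ans = (n + 1) / 2 * sm % mod;
                if (n % 2 == 0) ans += (n + 2) / 2 * sm % mod;
                ans %= mod;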
                ans %= mod;

                cout << ans << "\n";

        }
        
}
Editorialist's code (Python)
mod = 998244353
for _ in range(int(input())):
    n, m = map(int, input().split())
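    if n > m: n, m = m, n          # ensure n <= m
    req = m//2                     # minimum possible f(x, y)
    # x+req >= n and x-req <= 1
    lo, hi = max(1, n - req), min(n, 1 + req)
    ans = (hi*(hi+1)//2 - lo*(lo-1)//2) % mod   # sum of valid rows x
    if m%2 == 0: ans *= m + 1      # columns k and k+1: k + (k+1) = m + 1
    else: ans *= (m+1)//2          # single middle column k+1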
    print(ans % mod)