PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: munch_01
Preparation: iceknight1093
Tester: mridulahi
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Basic math
PROBLEM:
You have an N\times M grid. Initially, one cell is colored black; then the following process repeats every second:
- Every cell adjacent to a black cell (sharing a side or a corner with it) also turns black.
For a starting cell (x, y), let f(x, y) denote the time required for the entire grid to turn black.
Find the sum of x\cdot y over all cells (x, y) that minimize f(x, y).
EXPLANATION:
Let’s first try to figure out which cells have minimum f(x, y).
Without loss of generality, let N \leq M. If not, swap them: this transposes the grid and maps each cell (x, y) to (y, x), and since x\cdot y = y\cdot x, the answer doesn't change.
Consider a starting cell (x, y).
Each second, the black region grows by one cell in all eight directions, so reaching another cell (x', y') takes exactly \max(|x-x'|, |y-y'|) seconds.
This means that the horizontal and vertical components of movement are, in a sense, independent of each other: if we’ve reached some cell in row i (meaning at least |x-i| seconds have passed), and some cell in column j (at least |y-j| seconds have passed), then we’ve definitely reached cell (i, j), since at least \max(|x-i|, |y-j|) seconds have passed.
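To sanity-check the claim that the arrival time at a cell is exactly its Chebyshev distance from the start, here is a small brute-force sketch that simulates the process with a BFS. It assumes "adjacent" means the eight surrounding cells, which is what the \max formula implies; the function name spread_times is illustrative and not taken from the original solutions.

```python
from collections import deque


def spread_times(n, m, x, y):
    """Simulate the process with a BFS from (x, y) on an n x m grid (1-indexed).

    Returns a dict mapping every cell to the second at which it turns black.
    Assumes 8-directional adjacency (cells sharing a side or a corner).
    """
    times = {(x, y): 0}
    queue = deque([(x, y)])
    while queue:
        cx, cy = queue.popleft()
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                nx, ny = cx + dx, cy + dy
                # Stay inside the grid and visit each cell only once.
                if 1 <= nx <= n and 1 <= ny <= m and (nx, ny) not in times:
                    times[(nx, ny)] = times[(cx, cy)] + 1
                    queue.append((nx, ny))
    return times


# Every cell's arrival time equals max(|x - i|, |y - j|).
n, m, x, y = 4, 6, 2, 5
times = spread_times(n, m, x, y)
assert len(times) == n * m
assert all(t == max(abs(x - i), abs(y - j)) for (i, j), t in times.items())
```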
That is, f(x, y) is exactly the maximum of the times needed for us to reach the four “borders” of the grid: rows 1 and N, and columns 1 and M.
So, f(x, y) = \max(x-1, N-x, y-1, M-y). Our objective is to find cells that minimize this value.
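As a quick check of this reduction, the sketch below compares the closed form against the definition (the largest Chebyshev distance from the start to any cell) on every small grid and every starting cell. The names f_formula and f_brute are ours, chosen for illustration.

```python
def f_formula(n, m, x, y):
    """Closed form: time to reach the farthest of the four borders."""
    return max(x - 1, n - x, y - 1, m - y)


def f_brute(n, m, x, y):
    """Largest Chebyshev distance from (x, y) to any cell of the n x m grid."""
    return max(max(abs(x - i), abs(y - j))
               for i in range(1, n + 1) for j in range(1, m + 1))


# Exhaustive comparison over all grids up to 6 x 6 and all starting cells.
for n in range(1, 7):
    for m in range(1, 7):
        for x in range(1, n + 1):
            for y in range(1, m + 1):
                assert f_formula(n, m, x, y) == f_brute(n, m, x, y)
```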
Intuitively, these cells should be near the middle of the grid.
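A bit more formally, notice that (x-1) + (N-x) = N-1, so the larger of the two terms is at least \left\lceil \frac{N-1}{2} \right\rceil = \left\lfloor \frac{N}{2} \right\rfloor, i.e., \max(x-1, N-x) \geq \left\lfloor \frac{N}{2} \right\rfloor. Similarly, \max(y-1, M-y) \geq \left\lfloor \frac{M}{2} \right\rfloor.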
Attaining these bounds is quite easy too: choosing x = \left\lceil \frac{N}{2} \right\rceil and y = \left\lceil \frac{M}{2} \right\rceil, for instance, attains both at once. Visually, this corresponds to choosing a cell from the middle of the grid, which matches our intuition.
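The one-dimensional bound can be checked numerically: for a single row or column of length n, the best achievable value of \max(x-1, n-x) is n // 2, and x = (n + 1) // 2 (that is, \left\lceil n/2 \right\rceil) attains it. The helper best_1d is a hypothetical name used only for this sketch.

```python
def best_1d(n):
    """Minimum over x in [1, n] of max(x - 1, n - x): the time needed to
    reach both ends of a single row (or column) of length n."""
    return min(max(x - 1, n - x) for x in range(1, n + 1))


for n in range(1, 50):
    # The lower bound floor(n / 2) is tight...
    assert best_1d(n) == n // 2
    # ...and the middle index ceil(n / 2) attains it.
    x = (n + 1) // 2
    assert max(x - 1, n - x) == n // 2
```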
Now, recall that we’re considering only N \leq M, so \left\lfloor \frac{N}{2} \right\rfloor \leq \left\lfloor \frac{M}{2} \right\rfloor.
That means the minimum possible f(x, y) is \left\lfloor \frac{M}{2} \right\rfloor; we just need to figure out exactly which cells attain it.
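A brute-force enumeration makes the claim concrete: for every small grid with n <= m the minimum of f is m // 2, and the optimal cells need not be unique. On a 3 x 5 grid, for example, the whole middle column is optimal. The function optimal_cells is our own illustrative helper.

```python
def optimal_cells(n, m):
    """Return (minimum f, sorted list of cells attaining it) for an n x m grid,
    trying every starting cell with f(x, y) = max(x-1, n-x, y-1, m-y)."""
    f = {(x, y): max(x - 1, n - x, y - 1, m - y)
         for x in range(1, n + 1) for y in range(1, m + 1)}
    best = min(f.values())
    return best, sorted(cell for cell, value in f.items() if value == best)


# For n <= m, the minimum is always floor(m / 2).
for n in range(1, 9):
    for m in range(n, 9):
        assert optimal_cells(n, m)[0] == m // 2

print(optimal_cells(3, 5))  # (2, [(1, 3), (2, 3), (3, 3)])
```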
Let’s look at a couple of cases.
Odd M
Suppose M = 2k+1.
Then, the minimum time needed is \left\lfloor \frac{M}{2} \right\rfloor = k.
Further, the only cells that can attain this are those in the “middle” column, i.e., column k+1: the column conditions y-1 \leq k and M-y \leq k together force y = k+1.
Anything to the left or right requires extra time.
This leaves us with a choice of rows.
If we choose row x, it must satisfy the following conditions:
- x-k \leq 1 (i.e., the first row must be reached in time). This translates to x \leq k+1.
- x+k \geq N (the last row must be reached in time). This translates to x \geq N - k.
- Of course, 1 \leq x \leq N must also hold.
So, we’re free to choose any x such that \max(1, N-k) \leq x \leq \min(N, 1+k). This range is never empty, since N \leq M = 2k+1 gives N-k \leq k+1.
The column y = k+1 is the same for all of them.
The answer is thus k+1 multiplied by the sum of x over that range.
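A minimal sketch of the odd-M case, assuming n <= m and m odd (the function name answer_odd_m is hypothetical). For readability it sums the row range with a plain loop; the closed-form sum from the text removes that loop. It is cross-checked against directly enumerating every cell with f(x, y) = k.

```python
def answer_odd_m(n, m):
    """Sum of x*y over optimal cells, assuming n <= m and m = 2k + 1 is odd.
    The column is fixed at y = k + 1; x ranges over [max(1, n-k), min(n, k+1)]."""
    assert n <= m and m % 2 == 1, "sketch only covers n <= m with odd m"
    k = m // 2
    lo, hi = max(1, n - k), min(n, k + 1)
    return (k + 1) * sum(range(lo, hi + 1))  # linear loop, kept for clarity


# Cross-check against every cell whose f(x, y) equals the minimum k.
for m in range(1, 16, 2):
    k = m // 2
    for n in range(1, m + 1):
        direct = sum(x * y for x in range(1, n + 1) for y in range(1, m + 1)
                     if max(x - 1, n - x, y - 1, m - y) == k)
        assert answer_odd_m(n, m) == direct
```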
The sum of integers from L to R is given by
\frac{R\cdot (R+1)}{2} - \frac{L\cdot (L-1)}{2}
so the answer can be found in \mathcal{O}(1) time.
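The range-sum identity is easy to verify directly. In this sketch (range_sum is an illustrative name), both halves are exact integers because i(i+1) and i(i-1) are always even. Python integers cannot overflow; in C++ with 64-bit integers the product R\cdot(R+1) stays below about 9.2\cdot 10^{18} as long as R is at most around 3\cdot 10^9.

```python
def range_sum(lo, hi):
    """Sum of lo, lo+1, ..., hi in O(1): hi*(hi+1)/2 - lo*(lo-1)/2."""
    return hi * (hi + 1) // 2 - lo * (lo - 1) // 2


for lo in range(1, 30):
    for hi in range(lo, 30):
        assert range_sum(lo, hi) == sum(range(lo, hi + 1))
```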
Even M
Suppose M = 2k.
Then the minimum time required is k, and this time either of the two middle columns can be chosen, i.e., column k or column k+1 (the conditions y-1 \leq k and M-y \leq k give k \leq y \leq k+1).
The analysis for the choice of x remains exactly the same, and we see that
\max(1, N-k) \leq x \leq \min(N, 1+k)
should hold.
Once again, the sum of x in this range can be found in \mathcal{O}(1) time.
Multiply this value by k + (k+1) = M+1 to account for the choice of column.
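Putting both cases together gives an O(1) routine. The sketch below (solve and brute are illustrative names) returns the exact sum; both reference solutions additionally reduce it modulo 998244353, which can be applied to the result. It is checked against brute force on all grids up to 9 x 9, including n > m to exercise the swap.

```python
def solve(n, m):
    """O(1) sum of x*y over all optimal starting cells (exact, no modulus)."""
    if n > m:                 # WLOG n <= m; transposing keeps x*y unchanged
        n, m = m, n
    k = m // 2                # minimum possible f(x, y)
    lo, hi = max(1, n - k), min(n, k + 1)
    rows = hi * (hi + 1) // 2 - lo * (lo - 1) // 2
    # Odd m: only column k+1. Even m: columns k and k+1, summing to m+1.
    cols = k + 1 if m % 2 == 1 else m + 1
    return rows * cols


def brute(n, m):
    """Enumerate every cell, keep those with minimum f, and sum x*y."""
    f = {(x, y): max(x - 1, n - x, y - 1, m - y)
         for x in range(1, n + 1) for y in range(1, m + 1)}
    best = min(f.values())
    return sum(x * y for (x, y), v in f.items() if v == best)


for n in range(1, 10):
    for m in range(1, 10):
        assert solve(n, m) == brute(n, m)
```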
TIME COMPLEXITY:
\mathcal{O}(1) per testcase.
CODE:
Tester's code (C++)
#include<bits/stdc++.h>
using namespace std;
#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define int long long
const int mod = 998244353;
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t;
    cin >> t;
    while (t--) {
        int n, m;
        cin >> n >> m;
        // Here n is made the LARGER dimension (the editorial uses N <= M instead).
        if (n < m) swap(n, m);
        int d = n / 2;  // minimum possible f(x, y)
        // Valid indices along the smaller dimension form the range [l, r].
        int r = min(d + 1, m);
        int l = max(1ll, m - d);
        int sm = (r - l + 1) * (l + r) / 2 % mod;  // sum of l..r
        // Odd n: single middle index (n+1)/2. Even n: indices n/2 and n/2 + 1.
        int ans = (n + 1) / 2 * sm % mod;
        if (n % 2 == 0) ans += (n + 2) / 2 * sm % mod;
        ans %= mod;
        cout << ans << "\n";
    }
}
Editorialist's code (Python)
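mod = 998244353
for _ in range(int(input())):
    n, m = map(int, input().split())
    if n > m: n, m = m, n  # ensure n <= m
    req = m//2  # minimum possible f(x, y)
    # x+req >= n and x-req <= 1
    lo, hi = max(1, n - req), min(n, 1 + req)
    # sum of all valid row indices x in [lo, hi]
    ans = (hi*(hi+1)//2 - lo*(lo-1)//2) % mod
    # multiply by the sum of valid column indices y
    if m%2 == 0: ans *= m + 1  # columns m/2 and m/2 + 1
    else: ans *= (m+1)//2      # only the middle column
    print(ans % mod)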