TREHUNT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Aryan Agarwala
Tester: Istvan Nagy
Editorialist: Aman Dwivedi

DIFFICULTY:

Easy - Medium

PREREQUISITES:

Maths, Observations

PROBLEM:

You are given an N \times M grid. For a distance K, you have to count the unordered pairs of distinct cells whose Manhattan distance is exactly K.

Let A_K be the number of such pairs at Manhattan distance exactly K, and let C = \sum_{i=1}^{N+M-2} (A_i \times 31^{i-1}). You have to find the value of C.

The answer may be large, so you need to find it modulo 998244353.

Note: The Manhattan distance between two points (x_1,y_1) and (x_2,y_2) is defined as |x_1-x_2|+|y_1-y_2|.

EXPLANATION:

Let’s try to think of a brute force approach using some examples.

Let’s take N=2 and M=3.
Now following are the points between whom the distance is 1.

  • (1,1) and (1,2)
  • (1,2) and (1,3)
  • (2,2) and (2,3)
  • (2,1) and (2,2)
  • (1,1) and (2,1)
  • (1,2) and (2,2)
  • (1,3) and (2,3)

The pairs above are listed in a deliberate order. Can you see why?

Hint

Look at the difference between their X and Y coordinates.

Answer

Let i be the difference between the X coordinates of the two cells; then the difference between their Y coordinates is j = K - i.

If we take points:

  • (1,1) and (1,2)
  • (1,2) and (1,3)
  • (2,2) and (2,3)
  • (2,1) and (2,2)

then the differences in the X and Y coordinates are 0 and 1, respectively.

Similarly, if we take the points:

  • (1,1) and (2,1)
  • (1,2) and (2,2)
  • (1,3) and (2,3)

then the differences in the X and Y coordinates are 1 and 0, respectively.

So, to get A_{K_l} for some 1 \leq K_l \leq N+M-2, we can iterate i from 0 to K_l inclusive, set j = K_l - i, and add up the number of pairs whose coordinate differences are exactly (i, j).

But how many pairs have coordinate differences exactly (i, j)?

Observe that for the pairs above whose X coordinates differ by 0 and whose Y coordinates differ by 1, there are 4 pairs in total: 2 choices of the row times (3 - 1) = 2 positions inside that row.

In general, such a pair exists only if 0 \leq i \leq N-1 and 0 \leq j \leq M-1. Take the cell with the smaller X coordinate: its X coordinate has (N - i) possible values, and the other cell's X coordinate is then fixed (it is exactly i larger). If the other cell's Y coordinate is j larger, the first cell's Y coordinate has (M - j) possible values. If i \neq 0 and j \neq 0, the other cell's Y coordinate can instead be j smaller, which gives another (M - j) pairs; when i = 0 or j = 0, these two cases describe the same pairs, so they are counted only once.

Therefore the split (i, j) contributes 2 \times (N-i) \times (M-j) pairs when i \neq 0 and j \neq 0, and (N-i) \times (M-j) pairs otherwise (and no pairs at all if i > N-1 or j > M-1).

Pseudo Code - Brute Force
// Modulus from the problem statement; the snippet used MOD without defining it.
const long long MOD = 998244353;

// Returns (a * b) mod MOD.
// The product is formed in 64-bit arithmetic: with both operands below
// MOD (~1e9) it does not fit in a 32-bit int.
long long mult(long long a, long long b) {
  return (a * b) % MOD;
}
// Returns (a + b) mod MOD. Assumes both inputs are already reduced modulo
// MOD, in which case their sum stays within int range.
int add(int a, int b) {
  const int sum = a + b;
  return sum % MOD;
}

// Brute force: for every distance k, enumerate each split k = x + y into a
// row difference x and a column difference y, count the cell pairs having
// exactly those differences, and fold A_k * 31^(k-1) into the answer.
// Runs in O((N + M)^2) per test case.
int32_t main() {

  int t;
  cin >> t;
  while (t--) {
    long long n, m;
    cin >> n >> m;
    long long tr = 0;  // running value of C (mod MOD)
    long long tm = 1;  // 31^(k-1) (mod MOD)
    for (long long k = 1; k <= (n + m - 2); k++) {
      long long ans = 0;  // A_k (mod MOD)
      for (long long x = 0; x <= k; x++) {
        const long long y = k - x;
        // 0 <= x <= k, so x and y are never negative; only the grid size
        // can make this split impossible.
        if (x >= n || y >= m) continue;
        long long ta = mult(n - x, m - y);
        // With both differences non-zero the second cell can sit on either
        // diagonal side of the first one, so the count doubles.
        if (x != 0 && y != 0) ta = mult(ta, 2);
        ans = add(ans, ta);
      }
      tr = add(tr, mult(ans, tm));
      tm = mult(tm, 31);
    }
    cout << tr << '\n';
  }
}

Now let's optimize the brute force by replacing the inner loop with a closed-form formula.
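Currently, to compute A_{K} we evaluate (N-i) \times (M-j) with j = K - i for every valid i. Expanding this term gives (N-i) \times (M-K+i) = N \times (M-K) + (N-M+K) \times i - i^2. The valid values of i form the contiguous range [\max(0, K-M+1), \min(K, N-1)], so the sum over that range is a difference of two prefix sums over i = 0 \dots x, each of which has a closed form: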

  • \sum_{i=0}^{x} (N-M+K) \times i = (N-M+K) \times \frac{x(x+1)}{2}, and \sum_{i=0}^{x} N \times (M-K) = (x+1) \times N \times (M-K).
  • \sum_{i=0}^{x} i^2 = \frac{x(x+1)(2x+1)}{6}.

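Doubling this range sum, and then subtracting one copy of the i = 0 and i = K terms (those splits must be counted only once), removes the nested loop: each A_K is obtained in O(1). The divisions by 2 and 6 are done by multiplying with the modular inverses of 2 and 6.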

TIME COMPLEXITY:

O(N+M) per test case: one O(1) evaluation for each of the N+M-2 distances.

SOLUTIONS:

Author
#include <bits/stdc++.h>
// Every `int` in this program is really a 64-bit `long long`, so products of
// two values below MOD do not overflow; `main` is therefore declared as
// `signed main()`.
#define int long long
//#include <sys/resource.h>
// NOTE(review): template leftovers that this solution never uses. The `rand`
// macro also rewrites every later use of the token `rand` (including std::rand).
#define initrand mt19937 mt_rand(time(0));
#define rand mt_rand()
// The problem asks for C modulo 998244353 (the tester uses the same modulus).
#define MOD 998244353
// Modular inverses of 2 and 6, used for the closed forms of sum i and sum i^2.
// They depend on MOD; the static_asserts catch a mismatch if MOD is changed.
constexpr long long INV2 = 499122177;
constexpr long long INV6 = 166374059;
static_assert(2 * INV2 % MOD == 1, "INV2 must be the inverse of 2 modulo MOD");
static_assert(6 * INV6 % MOD == 1, "INV6 must be the inverse of 6 modulo MOD");
using namespace std;

// Returns (a * b) mod MOD. Operands are expected to be non-negative and small
// enough that a * b fits in 64 bits.
long long mult(long long a, long long b){
    return (a*b)%MOD;
}

// Returns (a + b) mod MOD for non-negative operands.
long long add(long long a, long long b){
    return (a+b)%MOD;
}

// Returns (a - b) mod MOD, normalised into [0, MOD) even when a < b.
long long sub(long long a, long long b){
    return (((a-b)%MOD)+MOD)%MOD;
}

// Closed form of 2 * sum_{i=0}^{x} (n - i) * (m - k + i)   (mod MOD).
// Term i counts the cell pairs whose row difference is i and column difference
// is k - i in one diagonal orientation; the factor 2 adds the mirrored
// orientation (the caller removes the double count for i = 0 and i = k).
// Expanding (n - i)(m - k + i) = n(m - k) + (n - m + k) i - i^2 gives
//   (x + 1) n (m - k) + (n - m + k) x(x + 1)/2 - x(x + 1)(2x + 1)/6.
long long calcAns(long long n, long long m, long long k, long long x){
    const long long sumI = mult(mult(x, x + 1), INV2);                   // sum_{i=0}^{x} i
    const long long sumI2 = mult(mult(mult(x, x + 1), 2 * x + 1), INV6); // sum_{i=0}^{x} i^2
    long long tr = 0;
    tr = add(tr, mult(x + 1, mult(n, m)));
    tr = sub(tr, mult(x + 1, mult(n, k)));
    tr = add(tr, mult(sumI, n));
    tr = add(tr, mult(sumI, k));
    tr = sub(tr, mult(sumI, m));
    tr = sub(tr, sumI2);
    return mult(tr, 2);
}
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int t;
    cin>>t;
    while(t--) {
        int n, m;
        cin>>n>>m;
        int tr = 0;
        int tm = 1;
        for (int k = 1; k <= (n + m - 2); k++) { //we need k-i to be less than or equal to (m-1), implies
            int minv = max(0ll, k - m + 1);
            int maxv = min(k, n - 1);
            int ans = calcAns(n, m, k, maxv);
            if (minv > 0) ans = sub(ans, calcAns(n, m, k, minv - 1));
            if (minv == 0) ans = sub(ans, mult(n, m-k));
            if (minv <= k && maxv >= k) {
                ans = sub(ans, mult(n - k, m));
            }
            tr += ((ans * tm) % MOD);
            tr %= MOD;
            tm *= 31;
            tm %= MOD;
        }
        cout << tr << endl;
    }
}

Tester
// created by mtnshh

#include<bits/stdc++.h>
using namespace std;
// Template shorthands. The solution below relies on ll, rep, V and endl; most
// of the others are unused leftovers.
#define ll long long int
#define pb push_back
#define rb pop_back
#define ti tuple<int, int, int>
#define pii pair<int, int>
#define pli pair<ll, int>
#define pll pair<ll, ll>
#define mp make_pair
#define mt make_tuple
 
// rep: ascending loop over [a, b); repb: descending loop from a down to b inclusive.
#define rep(i,a,b) for(ll i=a;i<b;i++)
#define repb(i,a,b) for(ll i=a;i>=b;i--)
 
// Debug printing helpers (not used by the solution below).
// NOTE(review): errA expands to two statements, so it is unsafe as the body of
// an unbraced if/for.
#define err() cout<<"--------------------------"<<endl; 
#define errA(A) for(auto i:A)   cout<<i<<" ";cout<<endl;
#define err1(a) cout<<#a<<" "<<a<<endl
#define err2(a,b) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<endl
#define err3(a,b,c) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<endl
#define err4(a,b,c,d) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<" "<<#d<<" "<<d<<endl

// Iterator-pair shorthands for <algorithm> calls, and pair member shorthands.
#define all(A)  A.begin(),A.end()
#define allr(A)    A.rbegin(),A.rend()
#define ft first
#define sd second

// Container shorthands.
#define V vector<ll>
#define S set<ll>
#define VV vector<V>
#define Vpll vector<pll>
 
// Every later `endl` becomes a plain newline, so output is not flushed per line.
#define endl "\n"

// Reads one integer from stdin terminated by the character `endd` and asserts
// that it lies in [l, r]. Malformed input aborts through assert: a '-' that is
// repeated or appears after a digit, leading zeros, "-0", a token without any
// digit, any unexpected character, end of input, or more than 19 digits
// (19 digits are accepted only when the first digit is 1).
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit, or -1 while no digit has been read
    bool is_neg=false;
    while(true){
        // Keep getchar()'s int result so EOF cannot alias a real character.
        int g=getchar();
        if(g=='-'){
            // At most one sign, and only before the first digit.
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            const int digit=g-'0';
            if(cnt==0){
                fi=digit;
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"
            // Validate the length before accumulating so x*10 cannot overflow.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+digit;
        } else if(g==endd){
            assert(cnt>0);                    // the token must contain a digit
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd` and
// asserts that the token length lies in [l, r]. Reaching end of input before
// the terminator aborts through assert.
std::string readString(int l,int r,char endd){
    std::string ret="";
    int cnt=0;
    while(true){
        // getchar() returns an int: storing it in a char first makes the EOF
        // check fail on platforms where char is unsigned (infinite loop).
        const int raw=getchar();
        assert(raw != EOF);
        const char g=static_cast<char>(raw);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads a space-terminated integer that must lie in [lo, hi].
long long readIntSp(long long lo,long long hi){
    const char separator=' ';
    return readInt(lo,hi,separator);
}
// Reads a newline-terminated integer that must lie in [lo, hi].
long long readIntLn(long long lo,long long hi){
    const char terminator='\n';
    return readInt(lo,hi,terminator);
}
// Reads a newline-terminated token whose length must lie in [min_len, max_len].
string readStringLn(int min_len,int max_len){
    const char terminator='\n';
    return readString(min_len,max_len,terminator);
}
// Reads a space-terminated token whose length must lie in [min_len, max_len].
string readStringSp(int min_len,int max_len){
    const char separator=' ';
    return readString(min_len,max_len,separator);
}

// Validator limits: number of test cases, and the bound on n, m and on the
// sums of n and of m over all test cases.
const int max_q = 5;
const int max_n = 10000000;

// Template leftovers (not referenced by the solution).
const long long N = 200005;
const long long INF = 1000000000000LL;
// Output modulus.
const long long M = 998244353;

// Returns a^n mod m by binary exponentiation; n is expected to be non-negative
// (a negative n now returns 1 % m instead of looping forever).
// The base is reduced into [0, m) up front, so every product stays below m*m
// and fits in 64 bits for any modulus below about 3e9, even for large or
// negative a.
ll power(ll a,ll n,ll m=M){
    a%=m;
    if(a<0) a+=m;
    ll ans=1%m;         // 1 % m so that m == 1 yields 0
    while(n>0){
        if(n&1) ans=ans*a%m;
        a=a*a%m;
        n>>=1;
    }
    return ans;
}

// Prints C = sum_{i=1}^{n+m-2} A_i * 31^(i-1) (mod M) for an n x m grid, where
// A_i is the number of unordered cell pairs at Manhattan distance i.
// Each A_i is computed in O(1) and folded into C immediately, so memory use is
// O(1) instead of an O(n + m) table (up to 2e7 entries at the limits).
void solve(ll n, ll m){
    const ll inv_six = power(6, M-2, M);
    ll final = 0;  // C accumulated so far (mod M)
    ll p = 1;      // 31^(i-1) (mod M)
    rep(i,1,n+m-1){
        ll cur = 0;  // A_i (mod M)

        // Pairs whose row difference d and column difference i - d are both
        // non-zero: 1 <= d <= n - 1 and 1 <= i - d <= m - 1, i.e. d lies in
        // [l_n, r_n]. Each such split has two diagonal orientations.
        const ll l_n = max(1LL, i - (m - 1));
        const ll r_n = min(i - 1, n - 1);
        if(l_n <= r_n){
            // With d = l_n + t, term t is (x - t) * (y + t) = x*y + (x - y)*t - t^2.
            const ll x = n - l_n;
            const ll y = m - (i - l_n);
            const ll cnt_elem = r_n - l_n;
            const ll res_1 = (x * y) % M;
            const ll res_2 = ((cnt_elem * (cnt_elem + 1) / 2)) % M;
            const ll res_3 = ((((cnt_elem * (cnt_elem + 1)) % M * (2 * cnt_elem + 1)) % M * inv_six)) % M;
            ll z = ((cnt_elem + 1) * res_1 - (y - x) * res_2 - res_3) % M;
            z = (z + M) % M;
            cur = (2 * z) % M;
        }
        // Axis-aligned pairs have a single orientation: same column with row
        // difference i, and same row with column difference i.
        if(i < n) cur = (cur + (n - i) * m) % M;
        if(i < m) cur = (cur + n * (m - i)) % M;

        final = (final + cur * p) % M;
        p = (p * 31) % M;
    }
    cout << final << endl;
}

int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    ll q = readIntLn(1, max_q);
    ll sum_n = 0, sum_m = 0;
    while(q--){
        ll n = readIntSp(1, max_n), m = readIntLn(1, max_n);
    	solve(n, m);
        sum_n += n;
        sum_m += m;
    }
    assert(sum_n <= max_n and sum_m <= max_n);
    assert(getchar()==-1);   
}

1 Like

Nice explanation :raised_hands: