MASTRIAN - Editorial (Kodeathon 15.2)

PROBLEM LINK:

Practice

Contest

Author: Satyam Gupta

Tester: Pulkit Sharma

Editorialists: Satyam Gupta, Pulkit Sharma

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Math

PROBLEM:

You have an n x m grid of square cells. Each cell has a line drawn from its bottom-left corner to its top-right corner. The task is to count the number of triangles (of all possible sizes) that are formed inside the grid.

EXPLANATION:

Let TR_i denote a triangle whose base and height are both equal to i cells.

[Figure: a TR_1 triangle (blue) and a TR_2 triangle (red)]

Here, the blue triangle is a TR_1 triangle and the red triangle is a TR_2 triangle.

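For now, we count only the triangles that lie below the diagonal (right angle at the bottom-right corner). Later, we multiply the count by 2 to account for the triangles above the diagonal: rotating the grid by 180 degrees maps every triangle below a diagonal onto one above it, so the two counts are equal.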

Let’s take an example of a 2 x 3 grid.

TR_1 triangles are shown in red:

[Figure: TR_1 triangles in the 2 x 3 grid]

No. of TR_1 triangles: 2*3 = 6 (excluding the triangles above the diagonal; we will multiply by 2 later to account for them)

TR_2 triangles:

[Figure: TR_2 triangles in the 2 x 3 grid]

No. of TR_2 triangles: 1*2 = 2

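In general, the number of TR_i triangles below the diagonal is (n-(i-1)) * (m-(i-1)): each one occupies an i x i block of cells, and such a block can start at any of n-(i-1) row positions and m-(i-1) column positions.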
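The counting argument can be checked for small grids by enumerating every i x i block directly and comparing with the formula. This is a minimal sketch; the function names `countTRByEnumeration` and `countTRFormula` are our own, and the enumeration assumes (as argued above) that each i x i block contains exactly one TR_i triangle below its diagonal.

```cpp
#include <cassert>

// Hypothetical helper: counts TR_i triangles below the diagonal in an n x m grid
// by enumerating the top-left corner of every i x i block of cells.
// Each block contributes exactly one TR_i triangle below its diagonal.
// Small grids only.
long long countTRByEnumeration(long long n, long long m, long long i) {
    assert(n >= 1 && m >= 1 && i >= 1);
    long long count = 0;
    for (long long row = 0; row + i <= n; ++row)      // top row of the block
        for (long long col = 0; col + i <= m; ++col)  // left column of the block
            ++count;
    return count;
}

// Hypothetical helper: the editorial's formula (n-(i-1)) * (m-(i-1)),
// returning 0 when an i x i block does not fit in the grid.
long long countTRFormula(long long n, long long m, long long i) {
    if (i > n || i > m) return 0;
    return (n - (i - 1)) * (m - (i - 1));
}
```

For the 2 x 3 example, both functions give 6 for i = 1 and 2 for i = 2.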

Since a TR_i triangle needs a grid of at least i x i cells, the largest triangle that fits into the grid has base and height equal to the minimum of n and m.
Hence, the largest triangle is TR_{min(n,m)}.

Let x = min(n,m).
So we need to count the TR_i triangles for every i from 1 to x.

Let ans be the number of triangles below the diagonal (the final answer is 2*ans).

ans = TR_1 + TR_2 + \ldots + TR_x

ans = (n*m) + ( (n-1)*(m-1)) + \ldots + ( (n-(x-1)) * (m-(x-1)) )

Here is code for this (a direct loop over i):

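// n, m, x, i, ans are long long; MOD = 1000000007
ans=0;
x=min(n,m);
for(i=1;i<=x;i++)
{
	// TR_i triangles below the diagonal; reduce each factor mod MOD so the product fits in long long
	long long int y = ((n-(i-1))%MOD) * ((m-(i-1))%MOD) % MOD;
	ans = (ans + y) % MOD;
}
ans = 2*ans % MOD; //For taking in the count of triangles above the diagonal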

However, this code runs in O(min(n,m)) time, and since n and m can be as large as 10^{12}, it will not pass the time limit.

Therefore, we must find a faster way to calculate the answer.

Deriving the formula

We can substitute m with (n+b), where b = m-n (b is negative when m < n; the steps below hold for either sign). Now,

ans = (n*(n+b)) + ( (n-1)*( n+b-1)) + \ldots + ( (n-(x-1)) * (n+b-(x-1)) )

ans = \sum_{i = n-(x-1)}^n i*(i+b)

ans = \sum_{i = n-(x-1)}^n (i^2 + i*b)

ans = \sum_{i = n-(x-1)}^n i^2 + \sum_{i = n-(x-1)}^n i*b

Splitting the first summation into a difference of two prefix sums,

ans = \sum_{i = 0}^n i^2 - \sum_{i = 0}^{n-x} i^2 + \sum_{i = n-(x-1)}^n i*b

Applying the sum-of-squares formula to the first two terms and the sum of an arithmetic progression (A.P.) to the third term:

ans = (n*(n+1)*(2*n +1))/6 -((n-x)*(n-x+1)*(2*(n-x) +1) )/6 + (x/2)*( (n-(x-1))*b + n*b)
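To check this step, the expression can be evaluated with exact integers for small grids and compared with the direct loop. This is a minimal sketch with names of our own choosing; no modulo is applied, so it is only meant for small n and m. The A.P. term (x/2) * ((n-(x-1))*b + n*b) is rewritten as x*(2n-x+1)/2 * b, because x*(2n-x+1) is always even and the integer division is then exact.

```cpp
#include <algorithm>
#include <cassert>

// 0^2 + 1^2 + ... + k^2 for k >= 0.
long long sumSquares(long long k) {
    return k * (k + 1) * (2 * k + 1) / 6;
}

// Hypothetical helper: below-diagonal count from the sum-of-squares / A.P. form,
// with b = m - n and x = min(n, m). Exact arithmetic, small inputs only.
long long belowDiagonalDerived(long long n, long long m) {
    assert(n >= 1 && m >= 1);
    long long x = std::min(n, m);
    long long b = m - n;
    // x * (2n - x + 1) is even, so dividing by 2 before multiplying by b is exact.
    long long apTerm = x * (2 * n - x + 1) / 2 * b;
    return sumSquares(n) - sumSquares(n - x) + apTerm;
}

// Hypothetical helper: the direct sum of (n-(i-1)) * (m-(i-1)) for i = 1..x.
long long belowDiagonalLoop(long long n, long long m) {
    long long x = std::min(n, m), total = 0;
    for (long long i = 1; i <= x; ++i)
        total += (n - (i - 1)) * (m - (i - 1));
    return total;
}
```

For the 2 x 3 grid both functions return 8 (6 TR_1 plus 2 TR_2 triangles), and the same holds for 3 x 2, where b is negative.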

Substituting (m-n) for b and then expanding the equation,

ans = 2*m*n*x - m*x^2 + m*x - n*x^2 + n*x + (2*x^3 + x)/3 - x^2
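The expanded expression can be compared against the brute-force loop for small grids. The sketch below (function names are our own; exact integers, small inputs only) compares it with twice the below-diagonal sum, that is, with the total count of triangles on both sides of the diagonal. For the 2 x 3 grid both give 16.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: the expanded expression, evaluated with exact integers.
// (2x^3 + x) = x(2x^2 + 1) is always divisible by 3, so "/ 3" is exact.
long long expandedFormula(long long n, long long m) {
    assert(n >= 1 && m >= 1);
    long long x = std::min(n, m);
    return 2 * m * n * x - m * x * x + m * x - n * x * x + n * x
           + (2 * x * x * x + x) / 3 - x * x;
}

// Hypothetical helper: brute force, below-diagonal triangles times 2.
long long totalByLoop(long long n, long long m) {
    long long x = std::min(n, m), below = 0;
    for (long long i = 1; i <= x; ++i)
        below += (n - (i - 1)) * (m - (i - 1));
    return 2 * below;
}
```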

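Since the answer is required modulo 10^9+7, the modulus must be applied after every operation. Because n and m can be as large as 10^{12}, each operand should also be reduced modulo 10^9+7 before multiplying; otherwise the products overflow 64-bit integers.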
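A small helper makes the "reduce before multiplying" rule concrete. This is a sketch; `mulMod` is a name of our own, and it assumes non-negative operands. Since 10^9+7 < 2^30, two reduced factors multiply to less than 2^60, which fits in a signed 64-bit `long long`.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;

// Hypothetical helper: (a * b) mod MOD for non-negative a, b.
// Reducing each operand first keeps both factors below 2^30, so the product
// stays below 2^60 even when a and b are as large as 10^12.
long long mulMod(long long a, long long b) {
    assert(a >= 0 && b >= 0);
    return (a % MOD) * (b % MOD) % MOD;
}
```

For example, 10^12 reduces to MOD - 7000, so `mulMod(1000000000000LL, 1000000000000LL)` returns 7000^2 = 49000000, whereas multiplying 10^12 by 10^12 directly would overflow.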

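Since the expression divides by 3, we multiply by the modular inverse of 3 instead of dividing. This gives the correct result because 2*x^3 + x = x*(2*x^2 + 1) is always divisible by 3, so the division is exact.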

The modular inverse of 3 modulo 10^9+7 is 333333336 (since 3 * 333333336 = 1000000008 ≡ 1 (mod 10^9+7)).
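The editorial does not say how 333333336 was obtained. One standard way, sketched below under the assumption that the modulus is prime (10^9+7 is), is Fermat's little theorem: a^(p-2) is the inverse of a modulo p, computed with fast exponentiation. The function names `powMod` and `inverseMod` are our own.

```cpp
#include <cassert>

const long long MOD = 1000000007LL;

// Hypothetical helper: base^exp mod MOD by repeated squaring.
long long powMod(long long base, long long exp) {
    assert(exp >= 0);
    long long result = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp & 1) result = result * base % MOD;  // include this power of base
        base = base * base % MOD;                   // square for the next bit
        exp >>= 1;
    }
    return result;
}

// Hypothetical helper: inverse of a modulo the prime MOD (a not divisible by MOD),
// via Fermat's little theorem: a^(MOD-2) * a = a^(MOD-1) = 1 (mod MOD).
long long inverseMod(long long a) {
    return powMod(a, MOD - 2);
}
```

`inverseMod(3)` returns 333333336, the constant used below.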

(You can read more about modular arithmetic here.)
So,

ans = 2*m*n*x - m*x^2 + m*x - n*x^2 + n*x + (2*x^3 + x)* 333333336 - x^2

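(Note: apply the modulus after each operation and to each term. Because the expression contains subtractions, an intermediate result can be negative in C++, so add 10^9+7 before taking the final modulus, as the author's solution does with (ans%MOD +MOD)%MOD.)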

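Note that no further multiplication by 2 is needed: the expanded expression above is exactly twice the sum obtained in the A.P. step (for the 2 x 3 grid it gives 16 = 2 * 8), so it already counts the triangles above the diagonal as well. The author's solution below accordingly prints this expression without doubling it.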

Now, the answer is calculated in O(1).
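Putting the pieces together, the whole O(1) computation can be written as one function. This is a sketch, not the author's code; `countTriangles` and `normalizeMod` are names of our own. It evaluates 2mnx - mx^2 + mx - nx^2 + nx + (2x^3 + x)/3 - x^2 modulo 10^9+7 term by term, reducing every operand first and normalizing after each subtraction. Like the author's solution, it prints the expression without any further doubling.

```cpp
#include <algorithm>
#include <cassert>

const long long MOD = 1000000007LL;
const long long INV3 = 333333336LL;  // modular inverse of 3 modulo MOD

// Maps a possibly negative value into [0, MOD).
long long normalizeMod(long long v) { return ((v % MOD) + MOD) % MOD; }

// Hypothetical helper: total number of triangles in an n x m grid, modulo MOD.
long long countTriangles(long long n, long long m) {
    assert(n >= 1 && m >= 1);
    long long x = std::min(n, m);
    long long nm = n % MOD, mm = m % MOD, xm = x % MOD;
    long long xx = xm * xm % MOD;                         // x^2
    long long xxx = xx * xm % MOD;                        // x^3
    long long ans = 2 * mm % MOD * nm % MOD * xm % MOD;   // 2mnx
    ans = normalizeMod(ans - mm * xx % MOD);              // - m x^2
    ans = (ans + mm * xm) % MOD;                          // + m x
    ans = normalizeMod(ans - nm * xx % MOD);              // - n x^2
    ans = (ans + nm * xm) % MOD;                          // + n x
    ans = (ans + (2 * xxx + xm) % MOD * INV3) % MOD;      // + (2x^3 + x)/3
    ans = normalizeMod(ans - xx);                         // - x^2
    return ans;
}
```

For example, `countTriangles(2, 3)` returns 16 for the 2 x 3 grid used above (8 triangles below the diagonal and 8 above).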

AUTHOR’S SOLUTION:

#include<bits/stdc++.h>

using namespace std;

#define MOD 1000000007
#define ll long long int

int main()
{
	ll t,x,i,n,m,ans,z,y,ans2;
	//freopen("triangles.in","r",stdin);
	//freopen("triangles.out","w",stdout);

	cin>>t;
	
	while(t--)
	{
		cin>>n>>m;
		
		ans=0;
		x=min(n,m);
		
		//Equation without MOD = 2*m*n*x - m*x*x + m*x - n*x*x + n*x + (2*x*x*x + x)/3 - x*x ;
		
		ans = ( ((((2*m)%MOD * (n%MOD))%MOD * (x%MOD) )%MOD)  -  (((((m%MOD)*(x%MOD))%MOD)*(x%MOD))%MOD)  +  ((m%MOD)*(x%MOD))%MOD )%MOD;
		ans = ( ans - (((((n%MOD)*(x%MOD))%MOD)*(x%MOD))%MOD)  +  ((n%MOD)*(x%MOD))%MOD )%MOD;
		ans = ( ans + (((((((((2*x%MOD)*(x%MOD))%MOD)*(x%MOD))%MOD) + x)%MOD)* 333333336 )%MOD) - ((x%MOD)*(x%MOD))%MOD )%MOD;
		
		ans=(ans%MOD +MOD)%MOD;
		
		cout<<ans<<endl;
	}
	
	return 0;
}

http://www.codechef.com/viewsolution/7007451

Link to my solution above; no idea why I'm getting a timeout.
Help me.