MATMOD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: fremder
Tester: wasd2401
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Elementary combinatorics

PROBLEM:

An N\times N matrix is called K-strange if all its elements are distinct numbers from 1 to N^2, and (A_{i, j} - A_{j,i}) is a multiple of K for all (i, j).
Let M_N be the largest number such that there exists an M_N-strange N\times N matrix.

You’re given a matrix A with some elements missing.
Find the number of ways to fill in the missing elements such that the resulting matrix is M_N-strange.

EXPLANATION:

There’s a lot of stuff going on in the statement, so a good first step is to take things one at a time.
Our solution process can be broken up into three steps:

  • First, find M_N for a given N.
  • Second, characterize all M_N-strange matrices.
  • Finally, find out how many of them can be obtained from A.
Finding M_N

We want to find the largest K for which there exists a K-strange matrix.
Suppose we fix K. When can there be a K-strange matrix?

Notice that:

  • A_{1, 2} and A_{2, 1} should be the same, modulo K. This means A_{1, 2} - A_{2, 1} must be a multiple of K.
    Since they must be distinct integers, their values must differ by at least K - that is, \min(A_{1, 2}, A_{2,1}) + K \leq \max(A_{1, 2}, A_{2, 1}).
  • In fact, the exact same reasoning applies to any pair of elements (i, j) with i \lt j: the pair (A_{i, j}, A_{j, i}) should have a difference of at least K between them.

There are \frac{N\cdot (N-1)}{2} such pairs of opposite elements.
Since everything must be distinct, the smaller elements of these pairs are \frac{N\cdot (N-1)}{2} distinct positive integers, so there will definitely be some (i, j) such that \min(A_{i, j}, A_{j, i}) \geq \frac{N\cdot (N-1)}{2}.

Since the larger element of this pair should be at least K away from the smaller one, the larger element is at least K + \frac{N\cdot (N-1)}{2}.
However, this larger element should also be within N^2, giving us the inequality K + \frac{N\cdot (N-1)}{2} \leq N^2.

This immediately gives us an upper bound on K: K \leq N^2 - \frac{N\cdot (N-1)}{2} = \frac{N\cdot (N+1)}{2}.

With a bit of trial and error, it’s not hard to see that you can also construct a matrix that satisfies this K.
For example, with N = 4, we get K = 4\cdot 5 / 2 = 10, and one valid matrix is:

\begin{bmatrix} 7 & 11 & 12 & 14 \\ 1 & 8 & 13 & 15 \\ 2 & 3 & 9 & 16 \\ 4 & 5 & 6 & 10 \\ \end{bmatrix}

More generally: place the numbers from 1 to N\cdot (N-1)/2 in the lower triangle of the matrix; then add K = N\cdot (N+1)/2 to each of them and place the result in the mirrored cell of the upper triangle; finally, place the N remaining numbers on the diagonal.
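As a quick sanity check of this construction, here is a minimal Python sketch. The helper names build_strange and is_k_strange are hypothetical (they are not taken from the reference codes): the first fills the lower triangle row by row, which happens to reproduce the N = 4 example above, and the second checks the definition directly.

```python
def build_strange(n):
    """Construct an N x N matrix as described above (assumes n >= 1)."""
    k = n * (n + 1) // 2          # K = N(N+1)/2
    a = [[0] * n for _ in range(n)]
    nxt = 1
    # Lower triangle, row by row: 1, 2, ..., N(N-1)/2; the mirrored cell gets value + K.
    for i in range(n):
        for j in range(i):
            a[i][j] = nxt
            a[j][i] = nxt + k
            nxt += 1
    # The N values left over, N(N-1)/2 + 1 .. K, go on the diagonal.
    for i in range(n):
        a[i][i] = nxt
        nxt += 1
    return a


def is_k_strange(a, k):
    """True if a holds each of 1..N^2 exactly once and every A[i][j] - A[j][i] is a multiple of k."""
    n = len(a)
    if sorted(v for row in a for v in row) != list(range(1, n * n + 1)):
        return False
    return all((a[i][j] - a[j][i]) % k == 0 for i in range(n) for j in range(n))


for row in build_strange(4):
    print(row)
```

Printing the rows of build_strange(4) gives exactly the 4\times 4 matrix shown above.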

We thus have an upper bound for K, and also a way to achieve this upper bound.
So, M_N = N\cdot (N+1)/2.
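For tiny N, the value of M_N can also be confirmed by brute force. The sketch below (hypothetical helper max_strange_k, only practical for N \leq 3) tries every arrangement of 1 to N^2. For a fixed matrix, the valid K are exactly the common divisors of the differences |A_{i,j} - A_{j,i}|, so the largest one is their gcd.

```python
from functools import reduce
from itertools import permutations
from math import gcd


def max_strange_k(n):
    """Brute-force the largest K for which a K-strange N x N matrix exists (tiny N >= 2 only)."""
    # One representative (i, j) with i > j for every mirrored pair of cells.
    pairs = [(i, j) for i in range(n) for j in range(i)]
    best = 0
    for perm in permutations(range(1, n * n + 1)):  # perm[i * n + j] plays the role of A[i][j]
        # The matrix is K-strange exactly when K divides every |A[i][j] - A[j][i]|,
        # so the largest such K for this matrix is the gcd of those differences.
        g = reduce(gcd, (abs(perm[i * n + j] - perm[j * n + i]) for i, j in pairs), 0)
        best = max(best, g)
    return best
```

max_strange_k(2) returns 3 and max_strange_k(3) returns 6, matching N\cdot (N+1)/2. The N = 3 call checks all 9! = 362880 arrangements, so it is noticeably slower. The function assumes N \geq 2, since for N = 1 there are no off-diagonal pairs to take a gcd over.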

Which matrices?

Next, let’s figure out which matrices can possibly be M_N-strange.
We know M_N = N\cdot (N+1)/2.
It can be observed that:

  • If an integer x satisfies both x + M_N \gt N^2 and x - M_N \lt 1, it can only be on the diagonal of the matrix, since no other value in [1, N^2] differs from it by a multiple of M_N (so it can’t be paired with anything else).
    There are exactly N such integers: M_N, M_N-1, M_N-2, \ldots, M_N-N+1.
    There are also exactly N diagonal cells.
    So, the diagonal must consist of exactly these elements, in some order!
  • Now, let’s look at the values other than these N.
    Each remaining value x \leq N\cdot (N-1)/2 has exactly one possible partner, namely x + M_N (because x - M_N \lt 1 and x + 2M_N \gt N^2), so these values “pair up”, forming the pairs
    (1, M_N+1), (2, M_N+2), (3, M_N+3), \ldots, (N\cdot (N-1)/2, N^2).
  • For each of these pairs, one element must be in the lower triangle, and the other in the mirrored position of the upper triangle.
    Beyond that, it doesn’t matter which pair goes to which pair of mirrored cells, since the pairs are independent of each other.
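The characterization can be written down directly. In this sketch (the function name strange_structure is hypothetical), the first list holds the N values that must go on the diagonal, and the second holds the (x, x + M_N) pairs that must occupy mirrored off-diagonal cells.

```python
def strange_structure(n):
    """Split 1..N^2 into the diagonal-only values and the (x, x + M_N) pairs."""
    m = n * (n + 1) // 2                     # M_N
    half = n * (n - 1) // 2                  # number of mirrored off-diagonal cell pairs
    diagonal = list(range(half + 1, m + 1))  # these values have no partner within [1, N^2]
    pairs = [(x, x + m) for x in range(1, half + 1)]
    return diagonal, pairs


diagonal, pairs = strange_structure(4)
print(diagonal)  # [7, 8, 9, 10]
print(pairs)     # [(1, 11), (2, 12), (3, 13), (4, 14), (5, 15), (6, 16)]
```

Together, the two lists cover every value from 1 to N^2 exactly once.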
Filling in A

Now, we have a good idea of what an M_N-strange matrix should look like. Let’s see how many ways A can be filled in to satisfy this.

First, let’s look at the elements that must be on the diagonal - M_N, M_N-1, M_N-2, \ldots, M_N-N+1.

  • If any of them appears outside the diagonal, the answer is 0.
  • If any other element appears on the diagonal, the answer is 0.
  • Otherwise, suppose k of them are already on the diagonal. The remaining N-k don’t appear anywhere in the matrix, and exactly N-k diagonal cells are empty.
    These N-k elements can be arranged in any order along the empty diagonal cells, for (N-k)! ways in total.
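A minimal sketch of just these diagonal checks follows (the helper diagonal_ways is hypothetical; 0 marks an empty cell, as in the input). It uses exact integers for readability, whereas the reference codes reduce everything modulo 10^9 + 7.

```python
from math import factorial


def diagonal_ways(a):
    """Ways to fill the empty diagonal cells of a (0 = empty), or 0 if a diagonal rule is broken.

    Only the diagonal rules are checked here; off-diagonal pairs are handled separately.
    """
    n = len(a)
    lo, hi = n * (n - 1) // 2 + 1, n * (n + 1) // 2  # diagonal-only values M_N-N+1 .. M_N
    k_present = 0  # "k" in the text: diagonal-only values already on the diagonal
    for i in range(n):
        for j in range(n):
            v = a[i][j]
            if v == 0:
                continue
            diagonal_only = lo <= v <= hi
            if (i == j) != diagonal_only:
                return 0  # diagonal-only value off the diagonal, or another value on it
            if i == j:
                k_present += 1
    # The N - k missing diagonal-only values can be permuted freely over the empty diagonal cells.
    return factorial(n - k_present)
```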

Next, let’s look at the other elements.
Consider some pair (x, x+K) where K = M_N and x \leq N\cdot (N-1)/2.

  • If x and x+K both appear in the matrix, they must be opposite each other; otherwise the answer is 0.
  • If exactly one of x and x+K appears, the cell opposite to it must be empty (otherwise the answer is 0), and its value is uniquely fixed to the other element of the pair.

Now, suppose there are c pairs such that neither x nor x+K appears in A.
Then,

  • There will also be exactly c pairs of opposite cells that are empty; everything else will be filled (since if one cell of a pair was originally filled, the other is uniquely determined).
  • We can assign these c pairs of values to the c empty pairs of cells in any order, for c! ways.
  • Further, for each pair we can choose whether the smaller or the larger value goes in the lower triangle.
    With 2 choices for each pair, we get 2^c choices.
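The off-diagonal part can be sketched the same way (hypothetical helper offdiagonal_ways; 0 marks an empty cell). It walks over each mirrored pair of cells once. Diagonal cells are left to the diagonal checks described earlier, but a diagonal-only value found off the diagonal is also rejected here, so the function is safe to call on its own.

```python
from math import factorial


def offdiagonal_ways(a):
    """Ways to fill the empty off-diagonal cells of a (0 = empty), or 0 if a rule is broken."""
    n = len(a)
    m = n * (n + 1) // 2      # K = M_N: the two values of a pair differ by exactly m
    half = n * (n - 1) // 2   # the value pairs are (x, x + m) for 1 <= x <= half
    present = {v for row in a for v in row if v}
    c = 0                     # mirrored cell pairs that are completely empty
    for i in range(n):
        for j in range(i):    # (i, j) is in the lower triangle, (j, i) is its mirror
            low, up = a[i][j], a[j][i]
            if low and up:
                if abs(low - up) != m:
                    return 0  # both cells filled, but not with a valid pair
            elif low or up:
                v = low or up
                if half < v <= m:
                    return 0  # a diagonal-only value placed off the diagonal
                partner = v + m if v <= half else v - m
                if partner in present:
                    return 0  # the partner exists, but not in the mirrored cell
            else:
                c += 1
    # Once the checks pass, c also equals the number of value pairs absent from a.
    return factorial(c) * 2 ** c
```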

So, the final answer is

(N-k)!\times c! \times 2^c

where:

  • k diagonal elements are already present; and
  • c pairs of values don’t appear in A at all.

Of course, this formula only applies once all the checks above have passed; otherwise, the answer is 0.
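Since it is easy to miss a case in these checks, a brute-force counter is handy for testing on tiny inputs. This sketch (hypothetical brute_count; exponential time, so only for N \leq 3 with few empty cells) tries every way to place the missing values and counts the M_N-strange results. Its outputs can be compared against (N-k)!\times c!\times 2^c.

```python
from itertools import permutations


def brute_count(a):
    """Count the M_N-strange completions of a (0 = empty cell) by trying every filling.

    Assumes the filled values are distinct and within 1..N^2; only usable for tiny inputs.
    """
    n = len(a)
    m = n * (n + 1) // 2  # M_N
    used = {v for row in a for v in row if v}
    missing = [v for v in range(1, n * n + 1) if v not in used]
    empty = [(i, j) for i in range(n) for j in range(n) if a[i][j] == 0]
    count = 0
    for perm in permutations(missing):
        b = [row[:] for row in a]
        for (i, j), v in zip(empty, perm):
            b[i][j] = v
        # Distinctness holds by construction, so only the modular condition remains.
        if all((b[i][j] - b[j][i]) % m == 0 for i in range(n) for j in range(i)):
            count += 1
    return count
```

For example, brute_count([[0, 7, 0], [0, 0, 0], [0, 0, 5]]) returns 16, which matches the formula with k = 1 and c = 2: 2!\times 2!\times 2^2 = 16.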

Note that you need some factorials and powers of two. These can be computed in a brute-force fashion (or precomputed once), since N-k \leq N and c \leq N\cdot (N-1)/2 are both at most N^2, and the sum of N^2 over all test cases is bounded.
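A sketch of the precomputation and the final combination. The modulus 10^9 + 7 is the one used in all three reference codes, and the table size assumes N \leq 500, as asserted in the author’s code; the names precompute and final_answer are hypothetical.

```python
MOD = 10**9 + 7  # modulus used by the reference codes


def precompute(limit):
    """Return (fact, pow2): i! and 2^i modulo MOD for every 0 <= i <= limit."""
    fact = [1] * (limit + 1)
    pow2 = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
        pow2[i] = pow2[i - 1] * 2 % MOD
    return fact, pow2


def final_answer(n, k_present, c, fact, pow2):
    """(N - k)! * c! * 2^c modulo MOD; call only after all validity checks have passed."""
    return fact[n - k_present] * fact[c] % MOD * pow2[c] % MOD


# Assumed bound N <= 500, so N - k <= N and c <= N(N-1)/2 both fit below N^2 = 250000.
fact, pow2 = precompute(500 * 500)
print(final_answer(3, 1, 2, fact, pow2))  # 16
```

Building the tables once and reusing them across test cases keeps each test case at the cost of its own checks.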

TIME COMPLEXITY:

\mathcal{O}(N^2) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;

#define MOD 1000000007
#define INF 1e18
#define endl "\n"
#define pb push_back
#define ppb pop_back
#define ff first
#define ss second
#define sz(x) ((int)(x).size())
#define all(x) (x).begin(), (x).end()
#define int long long

const int N = 1e6;
vector<int> pow2(N), factorial(N);

int sum = 0;

void solve() {
    int n;
    cin >> n;

    assert(n >= 1 and n <= 500);

    sum += n * n;

    int k = n * (n + 1) / 2;
    set<int> diagonal, nonDiagonal;
    for(int i = 1; i <= n * n; i++)
        nonDiagonal.insert(i);

    for(int i = (n * (n - 1) / 2) + 1; i <= k; i++) {
        diagonal.insert(i);
        nonDiagonal.erase(i);
    }

    std::vector<vector<int>> mat(n, vector<int>(n));
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < n; j++) {
            cin >> mat[i][j];

            assert((mat[i][j] >= 0) and (mat[i][j] <= n * n));
        }
    }

    map<int, int> cnt;
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < n; j++) {
            if(mat[i][j] == 0) continue;

            if(i == j) {
                if(diagonal.count(mat[i][j]) == 0) {
                    cout << 0 << endl;
                    return;
                }
                else
                    diagonal.erase(mat[i][j]);
            }

            cnt[mat[i][j]]++;
        }
    }

    for(auto &pr: cnt) assert(pr.ss == 1);

    bool flag = 1;
    int pairs = 0;
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < i; j++) {
            int a = mat[i][j], b = mat[j][i];

            if(a and (nonDiagonal.count(a) == 0)) 
                flag = 0;

            if(b and (nonDiagonal.count(b) == 0)) 
                flag = 0;

            if(a and b) {
                if((a % k) != (b % k))
                    flag = 0;
            }
            else if(a) {
                if(a <= n * (n - 1) / 2) {
                    if(cnt[a + k])
                        flag = 0;
                }
                else {
                    if(cnt[a - k])
                        flag = 0;
                }
            }
            else if(b) {
                if(b <= n * (n - 1) / 2) {
                    if(cnt[b + k])
                        flag = 0;
                }
                else {
                    if(cnt[b - k])
                        flag = 0;
                }
            }
            else {
                pairs++;
            }
        }
    }   

    if(!flag) {
        cout << 0 << endl;
        return;
    }

    int ans = (factorial[pairs] * pow2[pairs]) % MOD;
    ans *= factorial[sz(diagonal)]; ans %= MOD;

    cout << ans << endl;
}

signed main() {
    pow2[0] = factorial[0] = 1;
    for(int i = 1; i < N; i++) {
        pow2[i] = (pow2[i - 1] * 2) % MOD;
        factorial[i] = (factorial[i - 1] * i) % MOD;
    }
    
    int tt = 1;
    cin >> tt; assert(tt >= 1 and tt <= 100000);
    while(tt--) {
        solve();       
    }

    assert(sum <= 500000);
    return 0;
}
Tester's code (C++)
/*

*       *  *  ***       *       *****
 *     *   *  *  *     * *        *
  *   *    *  ***     *****       *
   * *     *  * *    *     *   *  *
    *      *  *  *  *       *   **

                                 *
                                * *
                               *****
                              *     *
        *****                *       *
      _*     *_
     | * * * * |                ***
     |_*  _  *_|               *   *
       *     *                 *  
        *****                  *  **
       *     *                  ***
  {===*       *===}
      *  IS   *                 ***
      *  IT   *                *   *
      * RATED?*                *  
      *       *                *  **
      *       *                 ***
       *     *
        *****                  *   *
                               *   *
                               *   *
                               *   *
                                ***   

*/

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace __gnu_pbds;
using namespace std;

#define osl tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update>
#define ll long long
#define ld long double
#define forl(i, a, b) for(ll i = a; i < b; i++)
#define rofl(i, a, b) for(ll i = a; i > b; i--)
#define fors(i, a, b, c) for(ll i = a; i < b; i += c)
#define fora(x, v) for(auto x : v)
#define vl vector<ll>
#define vb vector<bool>
#define pub push_back
#define pob pop_back
#define fbo find_by_order
#define ook order_of_key
#define yesno(x) cout << ((x) ? "YES" : "NO")
#define all(v) v.begin(), v.end()

const ll N = 2e5 + 4;
const ll mod = 1e9 + 7;
// const ll mod = 998244353;

ll modinverse(ll a) {
	ll m = mod, y = 0, x = 1;
	while (a > 1) {
		ll q = a / m;
		ll t = m;
		m = a % m;
		a = t;
		t = y;
		y = x - q * y;
		x = t;
	}
	if (x < 0) x += mod;
	return x;
}
ll gcd(ll a, ll b) {
	if (b == 0)
		return a;
	return gcd(b, a % b);
}
ll lcm(ll a, ll b) {
	return (a / gcd(a, b)) * b;
}
bool poweroftwo(ll n) {
	return !(n & (n - 1));
}
ll power(ll a, ll b, ll md = mod) {
	ll product = 1;
	a %= md;
	while (b) {
		if (b & 1) product = (product * a) % md;
		a = (a * a) % md;
		b /= 2;
	}
	return product % md;
}
struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && !isspace(buffer[now])) {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};
ll summ=0;
void panipuri(input_checker &jaggu) {
	ll n, m = 0, k = -1, c = 0, sum = 0, q = 0, ans = 0, p = 1;
	string s;
	bool ch = true;
	set<ll> t;
	n=jaggu.readInt(1,500);
	jaggu.readEoln();
	summ+=n*n;
	assert(summ<=5e5);
	vector<vl> a(n,vl(n));
	forl(i, 0, n) {
		vector<int> v=jaggu.readInts(n,0,n*n);
		jaggu.readEoln();
		forl(j,0,n) {
			a[i][j]=v[j];
			if(v[j]){
				assert(t.count(v[j])==0);
				t.insert(v[j]);
			}
		}
	}
	t.clear();
	k=(n*(n+1))/2;
	ans=1;
	forl(i,0,n){
		forl(j,i+1,n){
			if(a[i][j] && a[j][i]){
				if(a[i][j]%k!=a[j][i]%k){
					cout<<0;
					return;
				}
			}
			else if(a[i][j]){
				if(t.count(a[i][j]%k) || a[i][j]%k==0 || a[i][j]%k>(n*(n-1))/2){
					cout<<0;
					return;
				}
				t.insert(a[i][j]%k);
			}
			else if(a[j][i]){
				if(t.count(a[j][i]%k) || a[j][i]%k==0 || a[j][i]%k>(n*(n-1))/2){
					cout<<0;
					return;
				}
				t.insert(a[j][i]%k);
			}
			else{
				ans*=2*p;
				ans%=mod;
				p++;
			}
		}
	}
	p=1;
	forl(i,0,n){
		if(a[i][i]==0){
			ans*=p;
			ans%=mod;
			p++;
		}
		else{
			if(a[i][i]<=(n*(n-1))/2 || a[i][i]>(n*(n+1))/2){
				cout<<0;
				return;
			}
		}
	}
	cout<<ans;
	return;
}
int main() {
// 	ios::sync_with_stdio(false);
// 	cin.tie(NULL);
	#ifndef ONLINE_JUDGE
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
	#endif
	int laddu = 1;
	// cin >> laddu;
	input_checker jaggu;
	laddu=jaggu.readInt(1,1e5);
	jaggu.readEoln();
	forl(i, 1, laddu + 1) {
		// cout << "Case #" << i << ": ";
		panipuri(jaggu);
		cout << '\n';
	}
	jaggu.readEof();
}
Editorialist's code (Python)
mod = 10**9 + 7
# Factorials modulo mod; c and N - k never exceed N^2
fac = [1]*10**6
for i in range(1, 10**6): fac[i] = fac[i-1] * i % mod

for _ in range(int(input())):
    n = int(input())
    a = []
    for i in range(n): a.append(list(map(int, input().split())))
    k = n*(n+1)//2  # M_N
    ans, diag_free, ndiag_free = 1, 0, 0
    # The diagonal-only values are exactly dL..dR
    dL, dR = n*(n-1)//2+1, n*(n+1)//2
    
    # mark[x] = (i, j): lower-triangle cell of the half-filled mirrored pair that holds x
    mark = [[-1, -1]]*(n*n + 1)
    for i in range(n):
        for j in range(i+1):
            if i == j:
                # Diagonal cell: either empty or holding a diagonal-only value
                if a[i][j] == 0: diag_free += 1
                else:
                    if a[i][j] < dL or a[i][j] > dR: ans = 0
            else:
                # Mirrored cells (i, j) and (j, i)
                if a[i][j] == 0 and a[j][i] == 0: ndiag_free += 1
                elif a[i][j] == 0 or a[j][i] == 0:
                    x = max(a[i][j], a[j][i])
                    mark[x] = (i, j)
                else:
                    if abs(a[i][j] - a[j][i]) != k: ans = 0
                # Diagonal-only values may not appear off the diagonal
                if dL <= a[i][j] <= dR: ans = 0
                if dL <= a[j][i] <= dR: ans = 0
    # x and x+k both in half-filled mirrored pairs means they are not opposite: invalid
    for i in range(dL):
        if mark[i][0] == -1 or mark[i+k][0] == -1: continue
        x1, y1 = mark[i]
        x2, y2 = mark[i+k]
        if x1 != y2 or x2 != y1: ans = 0
    
    # ans = diag_free! * ndiag_free! * 2^(ndiag_free)
    ans = (ans * pow(2, ndiag_free, mod)) % mod
    ans = (ans * fac[diag_free] * fac[ndiag_free]) % mod
    print(ans)

My approach is also similar to the editorial. First, I precompute k for all possible n.
Then I check the possible answers: whether the diagonal contains only suitable values, the number of free non-diagonal positions, and whether the grid is valid.

Here is my code
https://www.codechef.com/viewsolution/1046484287


Tester code’s panipuri and laddu are tasty


Is this matrix considered K-strange?
\begin{bmatrix} 3 & 1 \\ 3 & 4 \end{bmatrix}

I don’t quite understand how to construct a K-strange 2\times 2 matrix (N = 2) with K = 3, since that is the maximum K for N = 2?

The editorial says to form pairs like (a, a+k), where k = n*(n+1)/2. But what happens if both a and a+k are present in the matrix in non-symmetric positions?

You just print 0 then.

It’s not K-strange because 3 appears twice.
For N = 2 and K = 3,
\begin{bmatrix} 3 & 4 \\ 1 & 2 \end{bmatrix}
would be K-strange, since reducing every element modulo 3 gives
\begin{bmatrix} 0 & 1 \\ 1 & 2 \end{bmatrix},
which is a symmetric matrix.
It was given in the second test case.

I have used a slightly different approach (completely explained in my code with comments), and I have tested it against other users’ accepted solutions with all the tests I could think of, but I am still getting Wrong Answer. Please tell me why.
My solution: https://www.codechef.com/viewsolution/1046692247

https://www.codechef.com/viewsolution/1046694564

My solution is also similar: in my code, finalpair is c and ans is n-k.

I think I am making a mistake in the implementation, not in the logic. Please help me find the bug in my solution.

What about the test case 1 1 1? Its answer should be 1 (I saw this in the pro version), but if the matrix is already filled, how can the answer be 1?