# MAXXMIN - Editorial

Setter: Evgeny Karpovich
Tester: Istvan Nagy
Editorialist: Taranpreet Singh

# DIFFICULTY

Medium

# PREREQUISITES

Combinatorics, Pointers, and Patience.

# PROBLEM

We are given a matrix $A$ with $N$ rows and $M$ columns, and an integer $X$. Let $f(X)$ denote the number of submatrices $B$ of $A$ such that $\min(B) \oplus \max(B) = X$, where $\oplus$ is bitwise XOR.

Find the sum of $f(X)$ over all $N!$ permutations of the rows of $A$.

# EXPLANATION

### From permutations to subsets

Assume rows and columns are 1-indexed. In some permutation of the rows, a submatrix $B$ consists of a range of columns $[L, R]$, $1 \leq L \leq R \leq M$, and a contiguous block of rows $U$ to $D$. These rows correspond to some subset of rows of the original matrix $A$, which need not be contiguous there. Let $S$ denote this subset. We count the permutations in which the rows of $S$ appear next to each other, in any order, since only then can they form the rows of a submatrix.

Let $C = |S|$. The order of the rows inside the block does not matter, so there are $C!$ orderings of the rows within it. Treating the whole block as a single row leaves $N - C + 1$ items that can be ordered freely. Hence, the number of permutations in which $S$ appears as a contiguous block of rows is $C! \cdot (N-C+1)!$.
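To sanity-check the count $C! \cdot (N-C+1)!$, here is a small Java sketch (class and method names are ours, not part of the problem) that enumerates all permutations of $N$ rows and counts those in which a subset, given as a bitmask, is contiguous. It is meant only for tiny $N$.

```java
class ContiguousCount {
    // Closed form from the text: C! orders inside the block times (N - C + 1)! orders
    // of the block together with the remaining N - C rows.
    static long formula(int n, int c) {
        return factorial(c) * factorial(n - c + 1);
    }

    static long factorial(int k) {
        long f = 1;
        for (int i = 2; i <= k; i++) f *= i;
        return f;
    }

    // Brute force: count permutations of rows 0..n-1 in which the rows of `mask` are adjacent.
    static long brute(int n, int mask) {
        int[] perm = new int[n];
        for (int i = 0; i < n; i++) perm[i] = i;
        long count = 0;
        do {
            int first = -1, last = -1;
            for (int pos = 0; pos < n; pos++) {
                if (((mask >> perm[pos]) & 1) == 1) {
                    if (first == -1) first = pos;
                    last = pos;
                }
            }
            // The rows are adjacent iff the span they occupy has no gaps.
            if (first != -1 && last - first + 1 == Integer.bitCount(mask)) count++;
        } while (nextPermutation(perm));
        return count;
    }

    // Rearranges `a` into the next lexicographic permutation; returns false after the last one.
    static boolean nextPermutation(int[] a) {
        int i = a.length - 2;
        while (i >= 0 && a[i] >= a[i + 1]) i--;
        if (i < 0) return false;
        int j = a.length - 1;
        while (a[j] <= a[i]) j--;
        int t = a[i]; a[i] = a[j]; a[j] = t;
        for (int l = i + 1, r = a.length - 1; l < r; l++, r--) {
            t = a[l]; a[l] = a[r]; a[r] = t;
        }
        return true;
    }

    public static void main(String[] args) {
        // Rows {0, 2} of a 4-row matrix stay adjacent in 2! * 3! = 12 of the 24 permutations.
        System.out.println(brute(4, 0b0101) + " " + formula(4, 2)); // prints "12 12"
    }
}
```

For every small $N$ and every non-empty mask, `brute` and `formula` should agree.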

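For a fixed set of rows $S$, let $W_S$ be the number of column ranges $[L, R]$ such that the submatrix formed by exactly the rows of $S$ and columns $L$ to $R$ satisfies $\min(B) \oplus \max(B) = X$. In every permutation where $S$ is contiguous, exactly $W_S$ valid submatrices use precisely the rows of $S$, so the answer is

$$\sum_{\emptyset \neq S \subseteq \{1, \ldots, N\}} W_S \cdot |S|! \cdot (N-|S|+1)!$$

where the empty set is excluded (equivalently, $W_\emptyset = 0$).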
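The reduction can be checked end to end on tiny inputs. The sketch below (illustrative names, brute force only) computes the left-hand side by enumerating every row order and every submatrix, and the right-hand side with the subset formula, counting $W_S$ naively. The two totals should match.

```java
import java.util.Arrays;

class SubsetReduction {
    // W_S from column aggregates: ranges [l, r] with min(colMin[l..r]) ^ max(colMax[l..r]) == x.
    static long countColumnRanges(int[] colMin, int[] colMax, int x) {
        long w = 0;
        for (int l = 0; l < colMin.length; l++) {
            int mn = Integer.MAX_VALUE, mx = Integer.MIN_VALUE;
            for (int r = l; r < colMin.length; r++) {
                mn = Math.min(mn, colMin[r]);
                mx = Math.max(mx, colMax[r]);
                if ((mn ^ mx) == x) w++;
            }
        }
        return w;
    }

    // f(X) for one fixed row order: every row block [u, d] combined with every column range.
    static long countForOrder(int[][] a, int[] order, int x) {
        int n = order.length, m = a[0].length;
        long total = 0;
        for (int u = 0; u < n; u++) {
            int[] colMin = new int[m], colMax = new int[m];
            Arrays.fill(colMin, Integer.MAX_VALUE);
            Arrays.fill(colMax, Integer.MIN_VALUE);
            for (int d = u; d < n; d++) {
                int[] row = a[order[d]];
                for (int c = 0; c < m; c++) {
                    colMin[c] = Math.min(colMin[c], row[c]);
                    colMax[c] = Math.max(colMax[c], row[c]);
                }
                total += countColumnRanges(colMin, colMax, x);
            }
        }
        return total;
    }

    // Left-hand side: sum of f(X) over all N! row orders, by enumeration.
    static long overPermutations(int[][] a, int x) {
        int n = a.length;
        int[] order = new int[n];
        for (int i = 0; i < n; i++) order[i] = i;
        long sum = 0;
        do {
            sum += countForOrder(a, order, x);
        } while (nextPermutation(order));
        return sum;
    }

    // Right-hand side: sum over non-empty row subsets S of W_S * |S|! * (N - |S| + 1)!.
    static long overSubsets(int[][] a, int x) {
        int n = a.length, m = a[0].length;
        long[] fact = new long[n + 2];
        fact[0] = 1;
        for (int i = 1; i <= n + 1; i++) fact[i] = fact[i - 1] * i;
        long sum = 0;
        for (int mask = 1; mask < (1 << n); mask++) {
            int[] colMin = new int[m], colMax = new int[m];
            Arrays.fill(colMin, Integer.MAX_VALUE);
            Arrays.fill(colMax, Integer.MIN_VALUE);
            for (int r = 0; r < n; r++) {
                if (((mask >> r) & 1) == 0) continue;
                for (int c = 0; c < m; c++) {
                    colMin[c] = Math.min(colMin[c], a[r][c]);
                    colMax[c] = Math.max(colMax[c], a[r][c]);
                }
            }
            int size = Integer.bitCount(mask);
            sum += countColumnRanges(colMin, colMax, x) * fact[size] * fact[n - size + 1];
        }
        return sum;
    }

    // Next lexicographic permutation; returns false after the last one.
    static boolean nextPermutation(int[] a) {
        int i = a.length - 2;
        while (i >= 0 && a[i] >= a[i + 1]) i--;
        if (i < 0) return false;
        int j = a.length - 1;
        while (a[j] <= a[i]) j--;
        int t = a[i]; a[i] = a[j]; a[j] = t;
        for (int l = i + 1, r = a.length - 1; l < r; l++, r--) {
            t = a[l]; a[l] = a[r]; a[r] = t;
        }
        return true;
    }

    public static void main(String[] args) {
        int[][] a = {{1}, {2}};
        // Each of the 2 row orders has one valid submatrix (both rows: 1 ^ 2 = 3).
        System.out.println(overPermutations(a, 3) + " " + overSubsets(a, 3)); // prints "2 2"
    }
}
```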

### Computing $W_S$

Let's start with a naive approach. For a subset $S$, each submatrix corresponds to a contiguous range of columns $[L, R]$ with $1 \leq L \leq R \leq M$. Trying all pairs, we count those for which the submatrix formed by columns $L$ to $R$ of the rows in $S$ satisfies $\min(B) \oplus \max(B) = X$. This takes $O(M^2)$ per subset, $O(2^N \cdot M^2)$ overall, and will time out.

Code
```java
import java.util.*;
import java.io.*;
class MAXXMIN{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni(), X = ni();
        long[] fact = new long[1+N];
        fact[0] = 1;
        for(int i = 1; i<= N; i++)fact[i] = fact[i-1]*i;
        int[][] A = new int[N][M];
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                A[i][j] = ni();

        long ans = 0;
        int[] imin = new int[M], imax = new int[M];
        for(int mask = 1; mask < (1<<N); mask++){
            int cnt = bit(mask);//|S|
            //imin[c] / imax[c]: minimum / maximum of column c over the rows in mask
            Arrays.fill(imin, Integer.MAX_VALUE);
            Arrays.fill(imax, Integer.MIN_VALUE);
            for(int r = 0; r< N; r++){
                if(((mask>>r)&1) == 0)continue;
                for(int c = 0; c< M; c++){
                    imin[c] = Math.min(imin[c], A[r][c]);
                    imax[c] = Math.max(imax[c], A[r][c]);
                }
            }
            //W_S: try every column range [L, R]
            long subarrays = 0;
            for(int L = 0; L < M; L++){
                int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE;
                for(int R = L; R< M; R++){
                    min = Math.min(min, imin[R]);
                    max = Math.max(max, imax[R]);
                    if((min^max) == X)subarrays++;
                }
            }
            //|S|! * (N-|S|+1)! permutations keep the rows of S contiguous
            ans += (N-cnt+1)*fact[cnt]*fact[N-cnt]*subarrays;
        }
        pn(ans);
    }
    int bit(int x){return x == 0?0:(1+bit(x&(x-1)));}

    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;
    PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Thread(null, new Runnable() {public void run(){try{new MAXXMIN().run();}catch(Exception e){e.printStackTrace();System.exit(1);}}}, "1", 1 << 28).start();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
```


### Computing $W_S$ in $O(M \log M)$

Let $min_c = \min_{r \in S} A_{r, c}$ and $max_c = \max_{r \in S} A_{r, c}$. We now need to count pairs $(L, R)$ such that $\left(\min_{c = L}^{R} min_c\right) \oplus \left(\max_{c = L}^{R} max_c\right) = X$.
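The arrays $min_c$ and $max_c$ need not be recomputed from all $N$ rows for every subset. The sketch below follows the idea in the setter's solution: a subset is its own subset without its lowest row, plus that one row, so each subset costs $O(M)$. The Java names are ours.

```java
import java.util.Arrays;

class ColumnAggregates {
    // Returns {colMin, colMax}: colMin[mask][c] = min of a[r][c] over rows r in mask, colMax likewise.
    // Each mask reuses (mask without its lowest row), so all subsets together cost O(2^N * M) time.
    static int[][][] build(int[][] a) {
        int n = a.length, m = a[0].length;
        int[][] colMin = new int[1 << n][m], colMax = new int[1 << n][m];
        Arrays.fill(colMin[0], Integer.MAX_VALUE);          // empty subset: identity values
        Arrays.fill(colMax[0], Integer.MIN_VALUE);
        for (int mask = 1; mask < (1 << n); mask++) {
            int low = Integer.numberOfTrailingZeros(mask);  // row added at this step
            int rest = mask & (mask - 1);                   // mask with that row removed
            for (int c = 0; c < m; c++) {
                colMin[mask][c] = Math.min(colMin[rest][c], a[low][c]);
                colMax[mask][c] = Math.max(colMax[rest][c], a[low][c]);
            }
        }
        return new int[][][]{colMin, colMax};
    }

    public static void main(String[] args) {
        int[][] a = {{5, 1, 4}, {2, 3, 6}, {7, 0, 2}};
        int[][][] agg = build(a);
        // Rows {0, 2}: column minima, then column maxima.
        System.out.println(Arrays.toString(agg[0][0b101]) + " " + Arrays.toString(agg[1][0b101])); // prints [5, 0, 2] [7, 1, 4]
    }
}
```

Storing every subset costs $O(2^N \cdot M)$ memory, as in the setter's code; the tester's solution instead walks the subsets with a DFS over rows and keeps only one array per depth.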

Let's use divide and conquer. For a range $[L, R]$, let $p$ be any position of the minimum $\min_{c = L}^{R} min_c$. Then every interval $(l, r)$ with $L \leq l \leq p \leq r \leq R$ has $\min_{c = l}^{r} min_c = min_p$. With the minimum fixed, we need to count the pairs $(l, r)$ with $L \leq l \leq p \leq r \leq R$ and $\max_{c = l}^{r} max_c = min_p \oplus X$. Every other interval lies entirely inside $[L, p-1]$ or $[p+1, R]$, so we recurse on those two ranges.

Since $T = min_p \oplus X$ is fixed, we can binary search with range-maximum queries. Let $L_g$ be the largest position in $[L, p]$ with $max_{L_g} > T$, and $L_{ge}$ the largest position in $[L, p]$ with $max_{L_{ge}} \geq T$. Similarly, let $R_g$ be the smallest position in $[p, R]$ with $max_{R_g} > T$, and $R_{ge}$ the smallest position in $[p, R]$ with $max_{R_{ge}} \geq T$. If no such position exists on the left, use $L-1$; if none exists on the right, use $R+1$.

We need the subarrays with left endpoint in $[L, p]$ and right endpoint in $[p, R]$ whose maximum is exactly $min_p \oplus X$. Their number is $(p-L_g) \cdot (R_g-p) - (p-L_{ge}) \cdot (R_{ge}-p)$: the first product counts the subarrays with maximum $\leq min_p \oplus X$, and the second counts those with maximum $< min_p \oplus X$.
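The counting formula can be tried on a single position. In this sketch (names are ours), `mx` plays the role of the $max$ array, `[lo, hi]` is the range in which $p$ is the minimum, and `t` stands for $min_p \oplus X$. The four positions are found by linear scans purely for illustration, where the editorial uses binary search.

```java
class FixedMinCount {
    // Pairs (l, r) with lo <= l <= p <= r <= hi and max(mx[l..r]) == t, computed as
    // (p - Lg)(Rg - p) - (p - Lge)(Rge - p).
    static long count(int[] mx, int lo, int hi, int p, int t) {
        int lg = p;
        while (lg >= lo && mx[lg] <= t) lg--;     // largest position <= p with mx > t, else lo - 1
        int lge = p;
        while (lge >= lo && mx[lge] < t) lge--;   // largest position <= p with mx >= t, else lo - 1
        int rg = p;
        while (rg <= hi && mx[rg] <= t) rg++;     // smallest position >= p with mx > t, else hi + 1
        int rge = p;
        while (rge <= hi && mx[rge] < t) rge++;   // smallest position >= p with mx >= t, else hi + 1
        // subarrays with maximum <= t minus subarrays with maximum < t
        return (long) (p - lg) * (rg - p) - (long) (p - lge) * (rge - p);
    }

    // Direct enumeration, for comparison.
    static long brute(int[] mx, int lo, int hi, int p, int t) {
        long c = 0;
        for (int l = lo; l <= p; l++) {
            int best = Integer.MIN_VALUE;
            for (int r = l; r <= hi; r++) {
                best = Math.max(best, mx[r]);
                if (r >= p && best == t) c++;
            }
        }
        return c;
    }

    public static void main(String[] args) {
        int[] mx = {3, 1, 4, 1, 5};
        // Subarrays containing index 1 with maximum 4: [0..2], [0..3], [1..2], [1..3].
        System.out.println(count(mx, 0, 4, 1, 4) + " " + brute(mx, 0, 4, 1, 4)); // prints "4 4"
    }
}
```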

By building RMQ structures (for example, sparse tables) on both the min and max arrays, each subset is handled in $O(M \log M)$, for a total of $O(2^N \cdot M \log M)$ per test. This is likely still too slow to get AC, although some optimizations might make it pass.
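Putting the pieces together, one possible $O(M \log M)$ implementation of $W_S$ for a single subset is sketched below. It is our own illustration, not code from the editorial: sparse tables answer range min/max queries in $O(1)$, binary searches locate $L_g, L_{ge}, R_g, R_{ge}$, and an explicit stack replaces recursion because the recursion depth can reach $M$.

```java
import java.util.ArrayDeque;
import java.util.Deque;

class DivideAndConquerW {
    // Sparse table over indices: query(l, r) returns an index of the minimum (isMin) or maximum of a[l..r].
    static final class RangeTable {
        final int[] a;
        final int[][] idx;
        final boolean isMin;

        RangeTable(int[] a, boolean isMin) {
            this.a = a;
            this.isMin = isMin;
            int n = a.length, levels = 32 - Integer.numberOfLeadingZeros(n);
            idx = new int[levels][];
            idx[0] = new int[n];
            for (int i = 0; i < n; i++) idx[0][i] = i;
            for (int k = 1; k < levels; k++) {
                int len = n - (1 << k) + 1;
                idx[k] = new int[len];
                for (int i = 0; i < len; i++) idx[k][i] = better(idx[k - 1][i], idx[k - 1][i + (1 << (k - 1))]);
            }
        }

        int better(int i, int j) {
            if (isMin) return a[i] <= a[j] ? i : j;
            return a[i] >= a[j] ? i : j;
        }

        int query(int l, int r) {
            int k = 31 - Integer.numberOfLeadingZeros(r - l + 1);
            return better(idx[k][l], idx[k][r - (1 << k) + 1]);
        }
    }

    // Largest q in [lo, p] with v[q] > t (strict) or v[q] >= t (non-strict); lo - 1 if none.
    static int lastLeft(RangeTable mx, int[] v, int lo, int p, int t, boolean strict) {
        int ans = lo - 1, a = lo, b = p;
        while (a <= b) {
            int mid = (a + b) >>> 1;
            int best = v[mx.query(mid, p)];          // max(v[mid..p]) shrinks as mid grows
            if (strict ? best > t : best >= t) { ans = mid; a = mid + 1; } else b = mid - 1;
        }
        return ans;
    }

    // Smallest q in [p, hi] with v[q] > t (strict) or v[q] >= t (non-strict); hi + 1 if none.
    static int firstRight(RangeTable mx, int[] v, int p, int hi, int t, boolean strict) {
        int ans = hi + 1, a = p, b = hi;
        while (a <= b) {
            int mid = (a + b) >>> 1;
            int best = v[mx.query(p, mid)];          // max(v[p..mid]) grows with mid
            if (strict ? best > t : best >= t) { ans = mid; b = mid - 1; } else a = mid + 1;
        }
        return ans;
    }

    // W_S: number of (l, r) with min(colMin[l..r]) ^ max(colMax[l..r]) == x; arrays must be non-empty.
    static long count(int[] colMin, int[] colMax, int x) {
        int m = colMin.length;
        RangeTable mn = new RangeTable(colMin, true), mx = new RangeTable(colMax, false);
        long total = 0;
        Deque<int[]> todo = new ArrayDeque<>();      // explicit stack of ranges still to process
        todo.push(new int[]{0, m - 1});
        while (!todo.isEmpty()) {
            int[] seg = todo.pop();
            int lo = seg[0], hi = seg[1];
            if (lo > hi) continue;
            int p = mn.query(lo, hi);                // intervals in [lo, hi] containing p have minimum colMin[p]
            int t = colMin[p] ^ x;                   // the maximum such an interval needs
            int lg = lastLeft(mx, colMax, lo, p, t, true), lge = lastLeft(mx, colMax, lo, p, t, false);
            int rg = firstRight(mx, colMax, p, hi, t, true), rge = firstRight(mx, colMax, p, hi, t, false);
            total += (long) (p - lg) * (rg - p) - (long) (p - lge) * (rge - p);
            todo.push(new int[]{lo, p - 1});
            todo.push(new int[]{p + 1, hi});
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[]{1, 0, 2}, new int[]{3, 2, 2}, 3)); // prints 2
    }
}
```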

Code
```java
import java.util.*;
import java.io.*;
class MAXXMIN{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni(), X = ni();
        long[] fact = new long[1+N];
        fact[0] = 1;
        for(int i = 1; i<= N; i++)fact[i] = fact[i-1]*i;
        int[][] A = new int[N][M];
        int[] all = new int[N*M];
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                all[i*M+j] = A[i][j] = ni();

        Arrays.sort(all);
        int C = 1;
        for(int i = 1; i< all.length; i++)if(all[i] != all[C-1])all[C++] = all[i];
        all = Arrays.copyOf(all, C);
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                A[i][j] = Arrays.binarySearch(all, A[i][j]);

        int[] nxt = new int[C];
        for(int i = 0; i< C; i++){
            int pos = Arrays.binarySearch(all, all[i]^X);
            if(pos >= 0)nxt[i] = pos;
            else nxt[i] = -1;
        }
        int[] stack = new int[M];//Stack

        int[] leftMinPos = new int[M], rightMinPos = new int[M], leftMaxPos = new int[M], rightMaxPos = new int[M];

        int[] lst = new int[C], leftNext = new int[M], rightNext = new int[M];
        long ans = 0;
        int[] imin = new int[M], imax = new int[M];
        for(int mask = 1; mask < (1<<N); mask++){
            int cnt = bit(mask);
            Arrays.fill(imin, Integer.MAX_VALUE);
            Arrays.fill(imax, Integer.MIN_VALUE);
            for(int r = 0; r< N; r++){
                if(((mask>>r)&1) == 0)continue;
                for(int c = 0; c< M; c++){
                    imin[c] = Math.min(imin[c], A[r][c]);
                    imax[c] = Math.max(imax[c], A[r][c]);
                }
            }

            for(int c = 0, ptr = 0; c< M; c++){
                while(ptr > 0 && imin[stack[ptr-1]] > imin[c])ptr--;
                leftMinPos[c] = ptr == 0?-1:stack[ptr-1];
                stack[ptr++] = c;
            }
            for(int c = 0, ptr = 0; c< M; c++){
                while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
                leftMaxPos[c] = ptr == 0?-1:stack[ptr-1];
                stack[ptr++] = c;
            }

            for(int c = M-1, ptr = 0; c>= 0; c--){
                while(ptr > 0 && imin[stack[ptr-1]] >= imin[c])ptr--;
                rightMinPos[c] = ptr == 0?M:stack[ptr-1];
                stack[ptr++] = c;
            }
            for(int c = M-1, ptr = 0; c>= 0; c--){
                while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
                rightMaxPos[c] = ptr == 0?M:stack[ptr-1];
                stack[ptr++] = c;
            }

            Arrays.fill(leftNext, -1);Arrays.fill(rightNext, M);
            Arrays.fill(lst, -1);
            for(int c = 0; c< M; c++){
                lst[imax[c]] = c;
                if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)leftNext[c] = lst[nxt[imin[c]]];
            }
            Arrays.fill(lst, -1);
            for(int c = M-1; c >= 0; c--){
                lst[imax[c]] = c;
                if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)rightNext[c] = lst[nxt[imin[c]]];
            }

            long subarrays = 0;
            for(int c = 0; c< M; c++){
                if(nxt[imin[c]] == -1)continue;//min^X doesn't appear in matrix
                int Lmin = leftMinPos[c]+1, Rmin = rightMinPos[c]-1;
                int Lmax = leftNext[c], Rmax = rightNext[c];
                if(Lmax != -1 && rightMaxPos[Lmax] >= c){
                    Lmin = Math.max(Lmin, leftMaxPos[Lmax]+1);
                    Rmin = Math.min(Rmin, rightMaxPos[Lmax]-1);
                }else if(Lmax != -1){
                    Lmin = Math.max(Lmin, rightMaxPos[Lmax]+1);
                }

                if(Rmax != M && leftMaxPos[Rmax] <= c){
                    Lmin = Math.max(Lmin, leftMaxPos[Rmax]+1);
                    Rmin = Math.min(Rmin, rightMaxPos[Rmax]-1);
                }else if(Rmax != M){
                    Rmin = Math.min(Rmin, leftMaxPos[Rmax]-1);
                }
                Lmax = Math.max(Lmax, Lmin-1);
                Rmax = Math.min(Rmax, Rmin+1);
                long count = (c-Lmin+1)*(long)(Rmin-c+1) - (c-Lmax)*(long)(Rmax-c);
                subarrays += count;
            }
            ans += (N-cnt+1)*fact[cnt]*fact[N-cnt]*subarrays;
        }
        pn(ans);
    }
    int bit(int x){return x == 0?0:(1+bit(x&(x-1)));}

    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;
    PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new Thread(null, new Runnable() {public void run(){try{new MAXXMIN().run();}catch(Exception e){e.printStackTrace();System.exit(1);}}}, "1", 1 << 28).start();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }
            return str;
        }
    }
}
```


### Computing $W_S$ in $O(M)$

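Building an RMQ structure already costs $O(M \log M)$, so we avoid it. Instead, using monotonic stacks, we compute the previous and next smaller elements of $min$ and the previous and next greater elements of $max$. Also, for each position $c$, we compute the smallest $p \geq c$ such that $max_p = min_c \oplus X$ and the largest $p \leq c$ such that $max_p = min_c \oplus X$; these take one left-to-right and one right-to-left pass that remember the last position of every value of $max$.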
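All of these arrays can be computed in $O(M)$. The sketch below is an illustration with our own names; the comparison directions match the editorialist's solution. Unlike that code, which indexes a last-occurrence array by compressed values, it uses a `HashMap` from value to last position, which is simpler but has a larger constant factor.

```java
import java.util.HashMap;
import java.util.Map;

class StackBounds {
    // Largest j < i with a[j] <= a[i], or -1 (pops strictly larger values).
    static int[] prevLessOrEqual(int[] a) {
        int n = a.length, top = 0;
        int[] res = new int[n], st = new int[n];
        for (int i = 0; i < n; i++) {
            while (top > 0 && a[st[top - 1]] > a[i]) top--;
            res[i] = top == 0 ? -1 : st[top - 1];
            st[top++] = i;
        }
        return res;
    }

    // Smallest j > i with a[j] < a[i], or n (pops values >= a[i]).
    static int[] nextLess(int[] a) {
        int n = a.length, top = 0;
        int[] res = new int[n], st = new int[n];
        for (int i = n - 1; i >= 0; i--) {
            while (top > 0 && a[st[top - 1]] >= a[i]) top--;
            res[i] = top == 0 ? n : st[top - 1];
            st[top++] = i;
        }
        return res;
    }

    // Largest j < i with a[j] > a[i], or -1.
    static int[] prevGreater(int[] a) {
        int n = a.length, top = 0;
        int[] res = new int[n], st = new int[n];
        for (int i = 0; i < n; i++) {
            while (top > 0 && a[st[top - 1]] <= a[i]) top--;
            res[i] = top == 0 ? -1 : st[top - 1];
            st[top++] = i;
        }
        return res;
    }

    // Smallest j > i with a[j] > a[i], or n.
    static int[] nextGreater(int[] a) {
        int n = a.length, top = 0;
        int[] res = new int[n], st = new int[n];
        for (int i = n - 1; i >= 0; i--) {
            while (top > 0 && a[st[top - 1]] <= a[i]) top--;
            res[i] = top == 0 ? n : st[top - 1];
            st[top++] = i;
        }
        return res;
    }

    // Largest j <= i with mx[j] == (mn[i] ^ x), or -1.
    static int[] leftEq(int[] mn, int[] mx, int x) {
        int n = mn.length;
        int[] res = new int[n];
        Map<Integer, Integer> last = new HashMap<>();  // value -> latest position seen so far
        for (int i = 0; i < n; i++) {
            last.put(mx[i], i);                        // insert first so that j == i qualifies
            res[i] = last.getOrDefault(mn[i] ^ x, -1);
        }
        return res;
    }

    // Smallest j >= i with mx[j] == (mn[i] ^ x), or n.
    static int[] rightEq(int[] mn, int[] mx, int x) {
        int n = mn.length;
        int[] res = new int[n];
        Map<Integer, Integer> last = new HashMap<>();
        for (int i = n - 1; i >= 0; i--) {
            last.put(mx[i], i);
            res[i] = last.getOrDefault(mn[i] ^ x, n);
        }
        return res;
    }

    public static void main(String[] args) {
        int[] a = {2, 1, 2, 1, 3};
        System.out.println(java.util.Arrays.toString(prevLessOrEqual(a))); // prints [-1, -1, 1, 1, 3]
        System.out.println(java.util.Arrays.toString(nextLess(a)));        // prints [1, 5, 3, 5, 5]
    }
}
```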

Our aim stays the same: consider every position $p$, find the interval $[L, R]$ in which $min_p$ is the minimum, and among the subarrays $(l, r)$ with $L \leq l \leq p \leq r \leq R$, count those with maximum $min_p \oplus X$, now obtaining $L_g$, $L_{ge}$, $R_g$ and $R_{ge}$ from these arrays in $O(1)$. I have added comments to my code for better understanding.

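Implementation note: when $min_p$ occurs several times, be careful not to double count. For example, stop at an element $\leq min_p$ on the left but only at an element $< min_p$ on the right, so that each subarray is assigned to the leftmost occurrence of its minimum.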
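A small experiment shows why the tie-breaking matters. This quadratic sketch (our own, for illustration only) attributes to each position $i$ the subarrays in which $a_i$ is the minimum, with a choice of whether equal values stop the expansion on each side. Only the asymmetric choice counts each of the $M(M+1)/2$ subarrays exactly once.

```java
class TieBreaking {
    // For each i, extend left/right while the neighbour is larger than a[i] (or equal, unless told
    // to stop on equal values), then add the number of subarrays for which i is taken as the minimum.
    static long attributedCount(int[] a, boolean stopLeftOnEqual, boolean stopRightOnEqual) {
        int n = a.length;
        long total = 0;
        for (int i = 0; i < n; i++) {
            int l = i - 1;
            while (l >= 0 && (a[l] > a[i] || (!stopLeftOnEqual && a[l] == a[i]))) l--;
            int r = i + 1;
            while (r < n && (a[r] > a[i] || (!stopRightOnEqual && a[r] == a[i]))) r++;
            total += (long) (i - l) * (r - i);
        }
        return total;
    }

    public static void main(String[] args) {
        int[] a = {2, 1, 2, 1, 3};                                // 5 * 6 / 2 = 15 subarrays in total
        System.out.println(attributedCount(a, true, false));   // prints 15: each subarray counted once
        System.out.println(attributedCount(a, false, false));  // prints 19: 4 subarrays holding both 1s counted twice
    }
}
```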

# TIME COMPLEXITY

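The time complexity is $O(NM \log(NM) + 2^N \cdot M)$ per test case: $O(NM \log(NM))$ for value compression, and $O(M)$ per subset when each subset's column minima and maxima are derived from a smaller subset, as in the setter's and tester's solutions. Recomputing them over all $N$ rows for every subset, as in the editorialist's solution, gives $O(2^N \cdot NM)$ instead.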

# SOLUTIONS

Setter's Solution
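```cpp
#include<bits/stdc++.h>

using namespace std;

typedef long long ll;
int const maxn = 9, maxm = 1e5 + 5;
int a[maxn][maxm], f[maxn], inf = 1e9 + 7;
int all_element[maxn * maxm];
int nxt[maxn * maxm];
int imin[(1 << (maxn - 1))][maxm];
int imax[(1 << (maxn - 1))][maxm];
int lmin[maxm], rmin[maxm], lmax[maxm], rmax[maxm], Q[maxm];
int lnxt[maxm], rnxt[maxm];
int lst[maxn * maxm];

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n, m, x, N = 0;
    cin >> n >> m >> x;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            cin >> a[i][j];
            all_element[++N] = a[i][j];
        }
    }
    f[0] = 1;
    for (int i = 1; i <= n; ++i) f[i] = f[i - 1] * i;
    sort(all_element + 1, all_element + N + 1);
    // compress: a[i][j] becomes the index of the first occurrence of its value
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            a[i][j] = lower_bound(all_element + 1, all_element + N + 1, a[i][j]) - all_element;
        }
    }
    // nxt[i]: index of the value all_element[i] ^ x, or 0 if that value does not occur
    for (int i = 1; i <= N; ++i) {
        int pos = lower_bound(all_element + 1, all_element + N + 1, (all_element[i]^x)) - all_element;
        if (pos <= N && all_element[pos] == (all_element[i]^x)) {
            nxt[i] = pos;
        }
    }
    for (int i = 1; i <= m; ++i) {
        imin[0][i] = inf;
        imax[0][i] = -inf;
    }
    ll ans = 0;
    for (int mask = 1; mask < (1 << n); ++mask) {
        int b = __builtin_ctz(mask) + 1;     // lowest row of the subset (rows are 1-indexed)
        int cnt = __builtin_popcount(mask);  // |S|
        int where = (mask^(1 << (b - 1)));   // the same subset without row b
        for (int j = 1; j <= m; ++j) {
            imin[mask][j] = min(imin[where][j], a[b][j]);
            imax[mask][j] = max(imax[where][j], a[b][j]);
        }
        // lmin / rmin: previous <= and next < column minimum; lmax / rmax: previous and next > column maximum
        int ptr = 0;
        for (int j = 1; j <= m; ++j) {
            while (ptr > 0 && imin[mask][Q[ptr]] > imin[mask][j]) --ptr;
            lmin[j] = Q[ptr];
            Q[++ptr] = j;
        }
        ptr = 0;
        for (int j = 1; j <= m; ++j) {
            while (ptr > 0 && imax[mask][Q[ptr]] <= imax[mask][j]) --ptr;
            lmax[j] = Q[ptr];
            Q[++ptr] = j;
        }
        ptr = 0;
        for (int j = m; j >= 1; --j) {
            while (ptr > 0 && imin[mask][Q[ptr]] >= imin[mask][j]) --ptr;
            if (ptr == 0) rmin[j] = m + 1;
            else rmin[j] = Q[ptr];
            Q[++ptr] = j;
        }
        ptr = 0;
        for (int j = m; j >= 1; --j) {
            while (ptr > 0 && imax[mask][Q[ptr]] <= imax[mask][j]) --ptr;
            if (ptr == 0) rmax[j] = m + 1;
            else rmax[j] = Q[ptr];
            Q[++ptr] = j;
        }
        // lnxt[j] / rnxt[j]: nearest column at or left / right of j whose maximum equals (minimum at j) ^ x
        for (int j = 1; j <= m; ++j) {
            lst[imax[mask][j]] = j;
            if (nxt[imin[mask][j]] && lst[nxt[imin[mask][j]]]) lnxt[j] = lst[nxt[imin[mask][j]]];
            else lnxt[j] = 0;
        }
        for (int j = 1; j <= m; ++j) lst[imax[mask][j]] = 0;
        for (int j = m; j >= 1; --j) {
            lst[imax[mask][j]] = j;
            if (nxt[imin[mask][j]] && lst[nxt[imin[mask][j]]]) rnxt[j] = lst[nxt[imin[mask][j]]];
            else rnxt[j] = m + 1;
        }
        for (int j = 1; j <= m; ++j) lst[imax[mask][j]] = 0;
        ll add = 0;
        for (int j = 1; j <= m; ++j) {
            int L = lmin[j] + 1, R = rmin[j] - 1;
            int lx = lnxt[j], rx = rnxt[j];
            if (lx != 0 && rmax[lx] >= j) {
                R = min(R, rmax[lx] - 1);
                L = max(L, lmax[lx] + 1);
            }
            else if (lx != 0) {
                L = max(L, rmax[lx] + 1);
            }
            if (rx != m + 1 && lmax[rx] <= j) {
                L = max(L, lmax[rx] + 1);
                R = min(R, rmax[rx] - 1);
            }
            else if (rx != m + 1) {
                R = min(R, lmax[rx] - 1);
            }
            lx = max(lx, L - 1), rx = min(rx, R + 1);
            ll val = (ll)(j - L + 1) * (ll)(R - j + 1) - (ll)(j - lx) * (ll)(rx - j);
            add += val;
        }
        ans += add * (ll)(f[cnt] * f[n - cnt] * (n - cnt + 1));
    }
    cout << ans << '\n';
    return 0;
}
```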

Tester's Solution
```cpp
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
using namespace std;

#ifdef LOCAL
#define eprintf(...) fprintf(stderr, __VA_ARGS__);fflush(stderr);
#else
#define eprintf(...) 42
#endif

using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
    return (ull)rng() % B;
}

#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second

clock_t startTime;
double getCurrentTime() {
    return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}

const int N = 100100;
const int M = 8;
const int K = N * M;
int a[M][N];
int b[M + 1][N], c[M + 1][N];
int n, m, k;
int xs[K];
int pr[K];
ll ans[M + 1];
int smL[N], smR[N], bgL[N], bgR[N], wL[N], wR[N];
int st[N];
int stSz;
int lst[K];

ll solve() {
    stSz = 0;
    st[0] = -1;
    for (int i = 0; i < n; i++) {
        while(stSz > 0 && b[m][i] < b[m][st[stSz]]) stSz--;
        smL[i] = st[stSz];
        st[++stSz] = i;
    }
    stSz = 0;
    st[0] = n;
    for (int i = n - 1; i >= 0; i--) {
        while(stSz > 0 && b[m][i] <= b[m][st[stSz]]) stSz--;
        smR[i] = st[stSz];
        st[++stSz] = i;
    }
    stSz = 0;
    st[0] = -1;
    for (int i = 0; i < n; i++) {
        while(stSz > 0 && c[m][i] >= c[m][st[stSz]]) stSz--;
        bgL[i] = st[stSz];
        st[++stSz] = i;
    }
    stSz = 0;
    st[0] = n;
    for (int i = n - 1; i >= 0; i--) {
        while(stSz > 0 && c[m][i] >= c[m][st[stSz]]) stSz--;
        bgR[i] = st[stSz];
        st[++stSz] = i;
    }
    for (int i = 0; i < n; i++) {
        lst[c[m][i]] = i;
        if (pr[b[m][i]] != -1 && lst[pr[b[m][i]]] != -1)
            wL[i] = lst[pr[b[m][i]]];
        else
            wL[i] = -1;
    }
    for (int i = 0; i < n; i++)
        lst[c[m][i]] = -1;
    for (int i = n - 1; i >= 0; i--) {
        lst[c[m][i]] = i;
        if (pr[b[m][i]] != -1 && lst[pr[b[m][i]]] != -1)
            wR[i] = lst[pr[b[m][i]]];
        else
            wR[i] = n;
    }
    for (int i = 0; i < n; i++)
        lst[c[m][i]] = -1;
    ll res = 0;
    for (int i = 0; i < n; i++) {
        int x = pr[b[m][i]];
        if (x == -1) continue;
        if (c[m][i] == x) {
            res += (ll)(i - max(smL[i], bgL[i])) * (min(smR[i], bgR[i]) - i);
        } else {
            int l = smL[i], r = smR[i];
            int p = -1, q = -1;
            if (wL[i] > l) {
                p = wL[i];
                if (bgR[p] <= i) p = -1;
            }
            if (wR[i] < r) {
                q = wR[i];
                if (bgL[q] >= i) q = -1;
            }
            if (p != -1) {
                if (q != -1) {
                    assert(bgL[p] == bgL[q]);
                    assert(bgR[p] == bgR[q]);
                    l = max(l, bgL[p]);
                    r = min(r, bgR[p]);
                    res += (ll)(i - l) * (r - i) - (ll)(i - p) * (q - i);
                } else {
                    l = max(l, bgL[p]);
                    r = min(r, bgR[p]);
                    res += (ll)(p - l) * (r - i);
                }
            } else if (q != -1) {
                l = max(l, bgL[q]);
                r = min(r, bgR[q]);
                res += (ll)(i - l) * (r - q);
            }
        }
    }
    return res;
}

void brute(int p, int cnt) {
    if (p == m) {
        if (cnt > 0) ans[cnt] += solve();
        return;
    }
    for (int i = 0; i < n; i++) {
        b[p + 1][i] = b[p][i];
        c[p + 1][i] = c[p][i];
    }
    brute(p + 1, cnt);
    for (int i = 0; i < n; i++) {
        b[p + 1][i] = min(b[p + 1][i], a[p][i]);
        c[p + 1][i] = max(c[p + 1][i], a[p][i]);
    }
    brute(p + 1, cnt + 1);
}

int main()
{
    startTime = clock();
//	freopen("input.txt", "r", stdin);
//	freopen("output.txt", "w", stdout);

    int X;
    scanf("%d%d%d", &m, &n, &X);
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++) {
            scanf("%d", &a[i][j]);
            xs[k++] = a[i][j];
        }
    sort(xs, xs + k);
    k = unique(xs, xs + k) - xs;
    for (int i = 0; i < k; i++) {
        int x = xs[i] ^ X;
        int p = lower_bound(xs, xs + k, x) - xs;
        if (p < k && xs[p] == x)
            pr[i] = p;
        else
            pr[i] = -1;
    }
    for (int i = 0; i < k; i++)
        lst[i] = -1;
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            a[i][j] = lower_bound(xs, xs + k, a[i][j]) - xs;
    for (int i = 0; i < n; i++) {
        b[0][i] = k;
        c[0][i] = 0;
    }
    brute(0, 0);
    ll res = 0;
    for (int k = 1; k <= m; k++) {
        ll w = ans[k];
        for (int i = 1; i <= k; i++)
            w *= i;
        for (int i = 1; i <= m + 1 - k; i++)
            w *= i;
        res += w;
    }
    printf("%lld\n", res);

    return 0;
}
```

Editorialist's Solution
import java.util.*;
import java.io.*;
class MAXXMIN{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int N = ni(), M = ni(), X = ni();
long[] fact = new long[1+N];
fact[0] = 1;
for(int i = 1; i<= N; i++)fact[i] = fact[i-1]*i;

int[][] A = new int[N][M];
int[] all = new int[N*M];//Contains all distinct elements
for(int i = 0; i< N; i++)
for(int j = 0; j< M; j++)
all[i*M+j] = A[i][j] = ni();

Arrays.sort(all);
int C = 1;
for(int i = 1; i< all.length; i++)if(all[i] != all[C-1])all[C++] = all[i];
all = Arrays.copyOf(all, C);

for(int i = 0; i< N; i++)
for(int j = 0; j< M; j++)
A[i][j] = Arrays.binarySearch(all, A[i][j]);//Value compression, all[A[r][c]] now gets the original value of A[r][c]

int[] nxt = new int[C];//all[nxt[i]] = all[i]^X, or nxt[i] = -1 if no such position exists
for(int i = 0; i< C; i++){
int pos = Arrays.binarySearch(all, all[i]^X);
if(pos >= 0)nxt[i] = pos;
else nxt[i] = -1;
}

int[] stack = new int[M];//Temporary stack

int[] leftMinPos = new int[M], rightMinPos = new int[M], leftMaxPos = new int[M], rightMaxPos = new int[M];
//leftMinPos[i] -> largest p < i such that imin[p] <= imin[i], or -1
//rightMinPos[i] -> smallest p > i such that imin[p] < imin[i], or M

//leftMaxPos[i] -> largest p < i such that imax[p] > imax[i], or -1
//rightMaxPos[i] -> smallest p > i such that imax[p] > imax[i], or M

int[] lst = new int[C], leftNext = new int[M], rightNext = new int[M];
//leftNext[i] = largest p <= i such that max[p] = min[i]^X
//rightNext[i] = smallest p >= i such that max[p] = min[i]^X
long ans = 0;
int[] imin = new int[M], imax = new int[M];
for(int mask = 1; mask < (1<<N); mask++){
int cnt = bit(mask);//|S|, number of rows in the subset
Arrays.fill(imin, Integer.MAX_VALUE);//imin[c] = min_{r \in mask} A[r][c]
Arrays.fill(imax, Integer.MIN_VALUE);//imax[c] = max_{r \in mask} A[r][c]
for(int r = 0; r< N; r++){
if(((mask>>r)&1) == 1){
for(int c = 0; c< M; c++){
imin[c] = Math.min(imin[c], A[r][c]);
imax[c] = Math.max(imax[c], A[r][c]);
}
}
}
//Computing leftMinPos
for(int c = 0, ptr = 0; c< M; c++){
while(ptr > 0 && imin[stack[ptr-1]] > imin[c])ptr--;
leftMinPos[c] = ptr == 0?-1:stack[ptr-1];
stack[ptr++] = c;
}
//Computing leftMaxPos
for(int c = 0, ptr = 0; c< M; c++){
while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
leftMaxPos[c] = ptr == 0?-1:stack[ptr-1];
stack[ptr++] = c;
}
//Computing rightMinPos
for(int c = M-1, ptr = 0; c>= 0; c--){
while(ptr > 0 && imin[stack[ptr-1]] >= imin[c])ptr--;
rightMinPos[c] = ptr == 0?M:stack[ptr-1];
stack[ptr++] = c;
}
//Computing rightMaxPos
for(int c = M-1, ptr = 0; c>= 0; c--){
while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
rightMaxPos[c] = ptr == 0?M:stack[ptr-1];
stack[ptr++] = c;
}

Arrays.fill(leftNext, -1);Arrays.fill(rightNext, M);
//leftNext[i] = largest p <= i such that max[p] = min[i]^X
//rightNext[i] = smallest p >= i such that max[p] = min[i]^X
Arrays.fill(lst, -1);//lst[i] = last updated position of occurrence of i.
for(int c = 0; c< M; c++){
lst[imax[c]] = c;//updating position of imax[c] to c
if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)leftNext[c] = lst[nxt[imin[c]]];
}
Arrays.fill(lst, -1);
for(int c = M-1; c >= 0; c--){
lst[imax[c]] = c;//updating position of imax[c] to c
if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)rightNext[c] = lst[nxt[imin[c]]];
}

long subarrays = 0;
for(int c = 0; c< M; c++){
if(nxt[imin[c]] == -1)continue;//min^X doesn't appear in matrix
int Lmin = leftMinPos[c]+1, Rmin = rightMinPos[c]-1;
int Lmax = leftNext[c], Rmax = rightNext[c];
if(Lmax != -1 && rightMaxPos[Lmax] >= c){
Lmin = Math.max(Lmin, leftMaxPos[Lmax]+1);
Rmin = Math.min(Rmin, rightMaxPos[Lmax]-1);
}else if(Lmax != -1){
Lmin = Math.max(Lmin, rightMaxPos[Lmax]+1);
}

if(Rmax != M && leftMaxPos[Rmax] <= c){
Lmin = Math.max(Lmin, leftMaxPos[Rmax]+1);
Rmin = Math.min(Rmin, rightMaxPos[Rmax]-1);
}else if(Rmax != M){
Rmin = Math.min(Rmin, leftMaxPos[Rmax]-1);
}
Lmax = Math.max(Lmax, Lmin-1);
Rmax = Math.min(Rmax, Rmin+1);
long count = (c-Lmin+1)*(long)(Rmin-c+1) - (c-Lmax)*(long)(Rmax-c);
subarrays += count;
}
ans += (N-cnt+1)*fact[cnt]*fact[N-cnt]*subarrays;
}
pn(ans);
}
int bit(int x){return x == 0?0:(1+bit(x&(x-1)));}

//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = false;
FastReader in;
PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new MAXXMIN().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}

class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}

String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException  e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}

String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();