MAXXMIN - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Evgeny Karpovich
Tester: Istvan Nagy
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Combinatorics, Pointers, and Patience.

PROBLEM

Given a matrix A with N rows and M columns and an integer X, let f(X) denote the number of submatrices B of A such that min(B) \oplus max(B) = X, where \oplus is bitwise XOR.

Compute the sum of f(X) over all N! permutations of the rows of A.

QUICK EXPLANATION

EXPLANATION

From permutations to subsets

Assume rows and columns are 1-indexed. In a fixed permutation of rows, a submatrix B consists of a range [L, R], 1 \leq L \leq R \leq M, of columns of the rows U to D of that permutation. These rows correspond to some subset S of rows of the original matrix A (not necessarily contiguous in A). Since min(B) and max(B) do not depend on the order of rows, whether B satisfies min(B) \oplus max(B) = X depends only on S and [L, R]. Conversely, in every permutation where the rows of S appear as one contiguous block, they form exactly one row range. So for each subset S, we count the permutations in which the rows of S appear together, in any order (only then can they be chosen as the rows of a submatrix).

Let C = |S| be the number of rows in the subset. The rows of S can be ordered among themselves in C! ways. Treating the whole block as a single row, we are left with N-C+1 items that can be ordered freely, in (N-C+1)! ways. Hence, the number of permutations of rows in which S appears as a contiguous block is C! * (N-C+1)!.
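The counting argument can be checked by brute force on small N. The sketch below (the class `ContiguousCount` and its methods are hypothetical names, not part of any reference solution) enumerates all N! orders and compares the number in which a fixed subset of C rows is contiguous with C! * (N-C+1)!. By symmetry it uses rows {0, ..., C-1} as the representative subset.

```java
// Brute-force check (hypothetical helper class) of the claim that a fixed subset S of
// C rows appears as one contiguous block in exactly C! * (N-C+1)! of the N! row orders.
class ContiguousCount {
    static long factorial(int n) {
        long f = 1;
        for (int i = 2; i <= n; i++) f *= i;
        return f;
    }

    // Closed form derived in the text (valid for 1 <= c <= n).
    static long formula(int n, int c) {
        return factorial(c) * factorial(n - c + 1);
    }

    // By symmetry every subset of size c behaves the same, so we use rows {0, ..., c-1}.
    static long brute(int n, int c) {
        return permute(new int[n], new boolean[n], 0, c);
    }

    // Generates every permutation of {0, ..., n-1} into perm[] and counts the good ones.
    private static long permute(int[] perm, boolean[] used, int depth, int c) {
        if (depth == perm.length) return isContiguous(perm, c) ? 1 : 0;
        long total = 0;
        for (int v = 0; v < perm.length; v++) {
            if (used[v]) continue;
            used[v] = true;
            perm[depth] = v;
            total += permute(perm, used, depth + 1, c);
            used[v] = false;
        }
        return total;
    }

    // Rows 0..c-1 form a block iff their first and last positions are exactly c-1 apart.
    private static boolean isContiguous(int[] perm, int c) {
        int first = -1, last = -1;
        for (int pos = 0; pos < perm.length; pos++) {
            if (perm[pos] < c) {
                if (first == -1) first = pos;
                last = pos;
            }
        }
        return last - first + 1 == c;
    }
}
```

For example, with N = 4 and C = 2 both methods give 2! * 3! = 12.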

For a fixed set of rows S, let W_S denote the number of column ranges [L, R] such that the submatrix B formed by the rows of S and columns L to R satisfies min(B) \oplus max(B) = X. Then the answer is \displaystyle \sum_{S \in P(R)} W_S * (|S|!)*(N-|S|+1)!, where P(R) denotes the power set of rows, and W_S = 0 for S = \emptyset.
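To sanity-check the whole reduction, one can evaluate the answer twice on a tiny matrix: once straight from the definition (every row permutation, every submatrix) and once with the subset formula above. The class `SubsetReductionCheck` is a hypothetical, deliberately slow checker for small inputs only.

```java
// Hypothetical checker: evaluates the answer straight from the definition (every row
// permutation, every submatrix) and via the subset formula sum_S W_S * |S|! * (N-|S|+1)!.
class SubsetReductionCheck {
    static long byPermutations(int[][] a, int x) {
        int[] perm = new int[a.length];
        for (int i = 0; i < perm.length; i++) perm[i] = i;
        return permute(a, x, perm, 0);
    }

    // Swap-based generation of all permutations of perm[k..].
    private static long permute(int[][] a, int x, int[] perm, int k) {
        if (k == perm.length) return countInOrder(a, x, perm);
        long total = 0;
        for (int i = k; i < perm.length; i++) {
            swap(perm, k, i);
            total += permute(a, x, perm, k + 1);
            swap(perm, k, i);
        }
        return total;
    }

    private static void swap(int[] p, int i, int j) {
        int t = p[i]; p[i] = p[j]; p[j] = t;
    }

    // f(X) for one row order: rows perm[u..d], columns l..r.
    private static long countInOrder(int[][] a, int x, int[] perm) {
        int n = a.length, m = a[0].length;
        long cnt = 0;
        for (int u = 0; u < n; u++)
            for (int d = u; d < n; d++)
                for (int l = 0; l < m; l++)
                    for (int r = l; r < m; r++) {
                        int mn = Integer.MAX_VALUE, mx = Integer.MIN_VALUE;
                        for (int i = u; i <= d; i++)
                            for (int j = l; j <= r; j++) {
                                mn = Math.min(mn, a[perm[i]][j]);
                                mx = Math.max(mx, a[perm[i]][j]);
                            }
                        if ((mn ^ mx) == x) cnt++;
                    }
        return cnt;
    }

    // Subset formula; W_S counts column ranges [l, r] of the rows in S with min ^ max == x.
    static long bySubsets(int[][] a, int x) {
        int n = a.length, m = a[0].length;
        long[] fact = new long[n + 2];
        fact[0] = 1;
        for (int i = 1; i <= n + 1; i++) fact[i] = fact[i - 1] * i;
        long total = 0;
        for (int mask = 1; mask < (1 << n); mask++) {
            long w = 0;
            for (int l = 0; l < m; l++) {
                int mn = Integer.MAX_VALUE, mx = Integer.MIN_VALUE;
                for (int r = l; r < m; r++) {
                    for (int i = 0; i < n; i++) {
                        if (((mask >> i) & 1) == 1) {
                            mn = Math.min(mn, a[i][r]);
                            mx = Math.max(mx, a[i][r]);
                        }
                    }
                    if ((mn ^ mx) == x) w++;
                }
            }
            int c = Integer.bitCount(mask);
            total += w * fact[c] * fact[n - c + 1];
        }
        return total;
    }
}
```

On A = {{1, 2}, {3, 0}} with X = 3 both methods return 6: each of the three nonempty subsets has W_S = 1 and weight 2.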

Computing W_S

Let’s start with a naive approach. For a subset S, each submatrix corresponds to a contiguous range of columns, from L to R with 1 \leq L \leq R \leq M. Trying all pairs, we can count the pairs (L, R) such that the submatrix formed by columns L to R of the rows in S has min(B) \oplus max(B) = X. Maintaining the running min and max while extending R, this takes O(M^2) per subset (plus O(N*M) to build the column-wise minima and maxima), i.e. O(2^N * M^2) overall, which is too slow and will time out.

Code
import java.util.*;
import java.io.*;
class MAXXMIN{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni(), X = ni();
        long[] fact = new long[1+N];
        fact[0] = 1;
        for(int i = 1; i<= N; i++)fact[i] = fact[i-1]*i;
        int[][] A = new int[N][M];
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                A[i][j] = ni();
                
        long ans = 0;
        for(int mask = 1; mask < 1<<N; mask++){
            int[] imin = new int[M], imax = new int[M];
            Arrays.fill(imin, Integer.MAX_VALUE);
            Arrays.fill(imax, Integer.MIN_VALUE);
            for(int r = 0; r< N; r++){
                if(((mask>>r)&1) == 1){
                    for(int c = 0; c< M; c++){
                        imin[c] = Math.min(imin[c], A[r][c]);
                        imax[c] = Math.max(imax[c], A[r][c]);
                    }
                }
            }
            long subarrays = 0;
            for(int L = 0; L < M; L++){
                int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE;
                for(int R = L; R< M; R++){
                    min = Math.min(min, imin[R]);
                    max = Math.max(max, imax[R]);
                    if((min^max) == X)subarrays++;
                }
            }
            int cnt = bit(mask);
            ans += (N-cnt+1)*fact[cnt]*fact[N-cnt]*subarrays;
        }
        pn(ans);
    }
    int bit(int x){return x == 0?0:(1+bit(x&(x-1)));}
    
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
//        new MAXXMIN().run();
        new Thread(null, new Runnable() {public void run(){try{new MAXXMIN().run();}catch(Exception e){e.printStackTrace();System.exit(1);}}}, "1", 1 << 28).start();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Computing W_S in O(M*logM)

Let \displaystyle min_c = \min_{r \in S} A_{r, c} and \displaystyle max_c = \max_{r \in S} A_{r, c} be the column-wise minimum and maximum over the rows of S. Now we need to count pairs (L, R) such that \displaystyle \left(\min_{c = L}^R min_c\right) \oplus \left(\max_{c = L}^R max_c\right) = X.

Let’s try divide and conquer here. For a range [L, R], let p be any position of the minimum \displaystyle\min_{c = L}^R min_c. Then every interval (l, r) with L \leq l \leq p \leq r \leq R has \displaystyle\min_{c = l}^r min_c = min_p. With the minimum fixed, we need to count the pairs (l, r) with L \leq l \leq p \leq r \leq R such that \displaystyle \max_{c = l}^r max_c = min_p \oplus X. Every interval inside [L, R] that does not contain p lies entirely in [L, p-1] or in [p+1, R], so those are handled by recursing on the two halves.

Since min_p \oplus X is a fixed value, we can now use binary search (over range-maximum queries on max_c). Find the largest position L_g \leq p such that max_{L_g} > min_p \oplus X, the largest position L_{ge} \leq p such that max_{L_{ge}} \geq min_p \oplus X, the smallest position R_g \geq p such that max_{R_g} > min_p \oplus X, and the smallest position R_{ge} \geq p such that max_{R_{ge}} \geq min_p \oplus X. If L_g or L_{ge} does not exist or is smaller than L, use L-1 instead; similarly, use R+1 for the right ends.

We need subarrays with left endpoint in [L, p] and right endpoint in [p, R] whose maximum is exactly min_p \oplus X. The number of such subarrays is (p-L_g)*(R_g-p) - (p-L_{ge})*(R_{ge}-p): the term (p-L_g)*(R_g-p) counts the subarrays with maximum \leq min_p \oplus X, and (p-L_{ge})*(R_{ge}-p) counts those with maximum < min_p \oplus X.
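The difference formula for a single pivot can be checked in isolation. In this sketch (hypothetical class `FixedPivotCount`), `mx` plays the role of max_c, `lo`/`hi` the window [L, R], `p` the pivot and `v` the target value min_p \oplus X. The four boundaries are found by linear scans instead of binary search to keep the check obviously correct, and the result is compared with a brute-force count.

```java
// Counts subarrays [l, r] with lo <= l <= p <= r <= hi whose maximum of mx[] is exactly v,
// using the L_g / L_ge / R_g / R_ge difference formula (boundaries found by linear scans).
class FixedPivotCount {
    static long viaFormula(int[] mx, int lo, int hi, int p, int v) {
        int lg = lastFrom(mx, lo, p, v, true);   // largest l <= p with mx[l] >  v, else lo-1
        int lge = lastFrom(mx, lo, p, v, false); // largest l <= p with mx[l] >= v, else lo-1
        int rg = firstFrom(mx, hi, p, v, true);  // smallest r >= p with mx[r] >  v, else hi+1
        int rge = firstFrom(mx, hi, p, v, false);// smallest r >= p with mx[r] >= v, else hi+1
        // (p-lg)*(rg-p): ranges with max <= v; (p-lge)*(rge-p): ranges with max < v.
        return (long) (p - lg) * (rg - p) - (long) (p - lge) * (rge - p);
    }

    private static int lastFrom(int[] mx, int lo, int p, int v, boolean strict) {
        for (int i = p; i >= lo; i--)
            if (strict ? mx[i] > v : mx[i] >= v) return i;
        return lo - 1;
    }

    private static int firstFrom(int[] mx, int hi, int p, int v, boolean strict) {
        for (int i = p; i <= hi; i++)
            if (strict ? mx[i] > v : mx[i] >= v) return i;
        return hi + 1;
    }

    // Reference count by enumerating every (l, r) around p.
    static long brute(int[] mx, int lo, int hi, int p, int v) {
        long cnt = 0;
        for (int l = lo; l <= p; l++)
            for (int r = p; r <= hi; r++) {
                int best = Integer.MIN_VALUE;
                for (int i = l; i <= r; i++) best = Math.max(best, mx[i]);
                if (best == v) cnt++;
            }
        return cnt;
    }
}
```

For mx = {2, 5, 1, 5, 3}, window [0, 4], p = 2 and v = 5, we get L_g = -1, R_g = 5, L_{ge} = 1, R_{ge} = 3, so the count is 3*3 - 1*1 = 8.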

This way, by building RMQ structures (e.g. sparse tables) on both the min and max arrays, we can compute W_S in O(M*log(M)) for each set S, leading to a time complexity of O(2^N*M*log(M)) per test, which is still too slow to get AC. Some optimizations might make it pass.
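A compact sketch of this O(M log M) divide and conquer for a single set S follows, assuming the column-wise arrays min_c and max_c are already computed (passed as `mn` and `mx`). The class name and sparse-table layout are illustrative choices, not taken from the reference solutions. The recursion depth can reach M (for example when `mn` is increasing), so for large M it would need a thread with a bigger stack, as the editorial's Java template uses, or an explicit stack.

```java
// Sketch of the O(M log M) divide and conquer for W_S, given the column-wise arrays
// mn[] (min_c) and mx[] (max_c). Sparse tables answer "position of range minimum" and
// "range maximum" in O(1); binary searches find L_g, L_ge, R_g, R_ge.
class DivideAndConquerW {
    private final int[] mn, mx;
    private final int x;
    private final int[] log2;     // log2[len] = floor(log2(len))
    private final int[][] minPos; // minPos[k][i]: index of the minimum of mn[i .. i+2^k-1]
    private final int[][] maxVal; // maxVal[k][i]: maximum of mx[i .. i+2^k-1]

    DivideAndConquerW(int[] mn, int[] mx, int x) {
        this.mn = mn;
        this.mx = mx;
        this.x = x;
        int m = mn.length;
        log2 = new int[m + 1];
        for (int i = 2; i <= m; i++) log2[i] = log2[i / 2] + 1;
        int levels = log2[m] + 1;
        minPos = new int[levels][m];
        maxVal = new int[levels][m];
        for (int i = 0; i < m; i++) {
            minPos[0][i] = i;
            maxVal[0][i] = mx[i];
        }
        for (int k = 1; k < levels; k++) {
            int half = 1 << (k - 1);
            for (int i = 0; i + (1 << k) <= m; i++) {
                int a = minPos[k - 1][i], b = minPos[k - 1][i + half];
                minPos[k][i] = mn[a] <= mn[b] ? a : b;
                maxVal[k][i] = Math.max(maxVal[k - 1][i], maxVal[k - 1][i + half]);
            }
        }
    }

    private int argMin(int l, int r) {
        int k = log2[r - l + 1];
        int a = minPos[k][l], b = minPos[k][r - (1 << k) + 1];
        return mn[a] <= mn[b] ? a : b;
    }

    private int rangeMax(int l, int r) {
        int k = log2[r - l + 1];
        return Math.max(maxVal[k][l], maxVal[k][r - (1 << k) + 1]);
    }

    private static boolean exceeds(int value, int v, boolean strict) {
        return strict ? value > v : value >= v;
    }

    // Largest l in [lo, p] with mx[l] > v (strict) or >= v (non-strict); lo-1 if none.
    private int leftBound(int lo, int p, int v, boolean strict) {
        int a = lo, b = p, res = lo - 1;
        while (a <= b) {
            int mid = (a + b) >>> 1;
            if (exceeds(rangeMax(mid, p), v, strict)) { res = mid; a = mid + 1; } else b = mid - 1;
        }
        return res;
    }

    // Smallest r in [p, hi] with mx[r] > v (strict) or >= v (non-strict); hi+1 if none.
    private int rightBound(int p, int hi, int v, boolean strict) {
        int a = p, b = hi, res = hi + 1;
        while (a <= b) {
            int mid = (a + b) >>> 1;
            if (exceeds(rangeMax(p, mid), v, strict)) { res = mid; b = mid - 1; } else a = mid + 1;
        }
        return res;
    }

    // W_S: number of column ranges [l, r] with min(mn[l..r]) XOR max(mx[l..r]) == x.
    long count() {
        return solve(0, mn.length - 1);
    }

    private long solve(int lo, int hi) {
        if (lo > hi) return 0;
        int p = argMin(lo, hi); // every range in [lo, hi] containing p has minimum mn[p]
        int v = mn[p] ^ x;      // so its maximum must be exactly v
        long atMost = (long) (p - leftBound(lo, p, v, true)) * (rightBound(p, hi, v, true) - p);
        long below = (long) (p - leftBound(lo, p, v, false)) * (rightBound(p, hi, v, false) - p);
        return atMost - below + solve(lo, p - 1) + solve(p + 1, hi);
    }
}
```

Usage: `new DivideAndConquerW(minC, maxC, X).count()` gives W_S for one subset; summing it with the weights |S|! * (N-|S|+1)! over all masks gives the answer.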

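Code (this listing computes the boundaries with monotonic stacks instead of RMQ and binary search, so it is the same O(M) approach as the Editorialist's Solution below)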
import java.util.*;
import java.io.*;
class MAXXMIN{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni(), X = ni();
        long[] fact = new long[1+N];
        fact[0] = 1;
        for(int i = 1; i<= N; i++)fact[i] = fact[i-1]*i;
        int[][] A = new int[N][M];
        int[] all = new int[N*M];
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                all[i*M+j] = A[i][j] = ni();
        
        Arrays.sort(all);
        int C = 1;
        for(int i = 1; i< all.length; i++)if(all[i] != all[C-1])all[C++] = all[i];
        all = Arrays.copyOf(all, C);
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                A[i][j] = Arrays.binarySearch(all, A[i][j]);
        
        int[] nxt = new int[C];
        for(int i = 0; i< C; i++){
            int pos = Arrays.binarySearch(all, all[i]^X);
            if(pos >= 0)nxt[i] = pos;
            else nxt[i] = -1;
        }
        int[] stack = new int[M];//Stack
        
        int[] leftMinPos = new int[M], rightMinPos = new int[M], leftMaxPos = new int[M], rightMaxPos = new int[M];

        int[] lst = new int[C], leftNext = new int[M], rightNext = new int[M];
        long ans = 0;
        for(int mask = 1; mask < 1<<N; mask++){
            int[] imin = new int[M], imax = new int[M];
            Arrays.fill(imin, Integer.MAX_VALUE);
            Arrays.fill(imax, Integer.MIN_VALUE);
            for(int r = 0; r< N; r++){
                if(((mask>>r)&1) == 1){
                    for(int c = 0; c< M; c++){
                        imin[c] = Math.min(imin[c], A[r][c]);
                        imax[c] = Math.max(imax[c], A[r][c]);
                    }
                }
            }
            
            for(int c = 0, ptr = 0; c< M; c++){
                while(ptr > 0 && imin[stack[ptr-1]] > imin[c])ptr--;
                leftMinPos[c] = ptr == 0?-1:stack[ptr-1];
                stack[ptr++] = c;
            }
            for(int c = 0, ptr = 0; c< M; c++){
                while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
                leftMaxPos[c] = ptr == 0?-1:stack[ptr-1];
                stack[ptr++] = c;
            }
            
            for(int c = M-1, ptr = 0; c>= 0; c--){
                while(ptr > 0 && imin[stack[ptr-1]] >= imin[c])ptr--;
                rightMinPos[c] = ptr == 0?M:stack[ptr-1];
                stack[ptr++] = c;
            }
            for(int c = M-1, ptr = 0; c>= 0; c--){
                while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
                rightMaxPos[c] = ptr == 0?M:stack[ptr-1];
                stack[ptr++] = c;
            }
            
            Arrays.fill(leftNext, -1);Arrays.fill(rightNext, M);
            Arrays.fill(lst, -1);
            for(int c = 0; c< M; c++){
                lst[imax[c]] = c;
                if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)leftNext[c] = lst[nxt[imin[c]]];
            }
            Arrays.fill(lst, -1);
            for(int c = M-1; c >= 0; c--){
                lst[imax[c]] = c;
                if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)rightNext[c] = lst[nxt[imin[c]]];
            }
            
            long subarrays = 0;
            for(int c = 0; c< M; c++){
                if(nxt[imin[c]] == -1)continue;//min^X doesn't appear in matrix
                int Lmin = leftMinPos[c]+1, Rmin = rightMinPos[c]-1;
                int Lmax = leftNext[c], Rmax = rightNext[c];
                if(Lmax != -1 && rightMaxPos[Lmax] >= c){
                    Lmin = Math.max(Lmin, leftMaxPos[Lmax]+1);
                    Rmin = Math.min(Rmin, rightMaxPos[Lmax]-1);
                }else if(Lmax != -1){
                    Lmin = Math.max(Lmin, rightMaxPos[Lmax]+1);
                }
                
                if(Rmax != M && leftMaxPos[Rmax] <= c){
                    Lmin = Math.max(Lmin, leftMaxPos[Rmax]+1);
                    Rmin = Math.min(Rmin, rightMaxPos[Rmax]-1);
                }else if(Rmax != M){
                    Rmin = Math.min(Rmin, leftMaxPos[Rmax]-1);
                }
                Lmax = Math.max(Lmax, Lmin-1);
                Rmax = Math.min(Rmax, Rmin+1);
                long count = (c-Lmin+1)*(long)(Rmin-c+1) - (c-Lmax)*(long)(Rmax-c);
                subarrays += count;
            }
            int cnt = bit(mask);
            ans += (N-cnt+1)*fact[cnt]*fact[N-cnt]*subarrays;
        }
        pn(ans);
    }
    int bit(int x){return x == 0?0:(1+bit(x&(x-1)));}
    
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
//        new MAXXMIN().run();
        new Thread(null, new Runnable() {public void run(){try{new MAXXMIN().run();}catch(Exception e){e.printStackTrace();System.exit(1);}}}, "1", 1 << 28).start();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Computing W_S in O(M)

We can no longer afford RMQ (building it alone takes O(M*log(M))), so instead we use monotonic stacks to compute, for each position, the previous and next smaller elements and the previous and next greater elements. Also, for each position c, we compute the largest p \leq c and the smallest p \geq c such that max_p = min_c \oplus X.
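These arrays can be built in O(M) with a generic monotonic-stack helper and a "last occurrence" array over compressed values. The sketch below uses hypothetical names (`LinearBoundaries`, `nearest`, `nearestEqual`) and the editorialist's conventions: "no such position" is -1 on the left and M on the right, and `want[c]` stands for `nxt[imin[c]]`.

```java
import java.util.Arrays;
import java.util.function.BiPredicate;

// Sketch of the O(M) preprocessing (hypothetical helper names). Index conventions follow
// the editorialist's code: "no such position" is -1 on the left and n (= M) on the right.
class LinearBoundaries {
    // For each i, the nearest j (j < i if fromLeft, else j > i) with keep.test(a[j], a[i]),
    // using a monotonic stack. Correct for keep being one of <, <=, >, >= (transitive
    // comparisons): an index that fails against a[i] can never beat i for later positions.
    // Boxing through BiPredicate is fine for a sketch but slower than hand-written loops.
    static int[] nearest(int[] a, boolean fromLeft, BiPredicate<Integer, Integer> keep) {
        int n = a.length;
        int[] res = new int[n], stack = new int[n];
        int top = 0;
        for (int step = 0; step < n; step++) {
            int i = fromLeft ? step : n - 1 - step;
            while (top > 0 && !keep.test(a[stack[top - 1]], a[i])) top--;
            res[i] = top > 0 ? stack[top - 1] : (fromLeft ? -1 : n);
            stack[top++] = i;
        }
        return res;
    }

    // For each i, the nearest j (j <= i if fromLeft, else j >= i) with mx[j] == want[i].
    // Values are assumed compressed to 0..k-1; want[i] == -1 means min_i ^ X never occurs.
    static int[] nearestEqual(int[] mx, int[] want, int k, boolean fromLeft) {
        int n = mx.length;
        int[] last = new int[k], res = new int[n];
        Arrays.fill(last, -1);
        for (int step = 0; step < n; step++) {
            int i = fromLeft ? step : n - 1 - step;
            last[mx[i]] = i; // record before querying so that j == i is allowed
            int j = want[i] == -1 ? -1 : last[want[i]];
            res[i] = j != -1 ? j : (fromLeft ? -1 : n);
        }
        return res;
    }
}
```

In terms of the editorialist's arrays: leftMinPos = nearest(imin, true, (p, q) -> p <= q), rightMinPos = nearest(imin, false, (p, q) -> p < q), leftMaxPos = nearest(imax, true, (p, q) -> p > q), rightMaxPos = nearest(imax, false, (p, q) -> p > q), and leftNext / rightNext come from nearestEqual with fromLeft set to true / false.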

Our aim is still the same: consider all positions p one by one, find the maximal interval [L, R] in which position p is the minimum, and among subarrays (l, r) with L \leq l \leq p \leq r \leq R, count those with maximum min_p \oplus X by computing L_g, L_{ge}, R_g and R_{ge} from these arrays. I have added comments in my code for better understanding.

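Implementation note: when min_p occurs multiple times, be careful not to double-count. The code attributes each subarray to the leftmost position of its minimum: the left boundary stops at the previous position with min_c \leq min_p, while the right boundary stops only at the next position with min_c < min_p.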
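A small O(M^2) illustration (hypothetical class `TieBreakDemo`) of why the asymmetric rule matters: for each c it takes the maximal window in which a[c] is a minimum and counts the subarrays of that window containing c. With the asymmetric rule the total is exactly M(M+1)/2, so every subarray is counted once; stopping only at strictly smaller values on both sides counts subarrays with repeated minima more than once.

```java
// O(M^2) illustration (hypothetical class) of the tie-breaking rule. For every c we take the
// maximal window where a[c] is a minimum and count the subarrays of that window containing c.
class TieBreakDemo {
    // asymmetric = true: left boundary is the previous index with a[j] <= a[c] and the right
    //   boundary is the next index with a[j] < a[c] (the leftMinPos / rightMinPos rule).
    // asymmetric = false: both boundaries stop only at strictly smaller values.
    static long attributedSubarrays(int[] a, boolean asymmetric) {
        int n = a.length;
        long total = 0;
        for (int c = 0; c < n; c++) {
            int left = c - 1;
            while (left >= 0 && (asymmetric ? a[left] > a[c] : a[left] >= a[c])) left--;
            int right = c + 1;
            while (right < n && a[right] >= a[c]) right++;
            total += (long) (c - left) * (right - c);
        }
        return total;
    }
}
```

For a = {2, 1, 1, 2} the asymmetric rule gives 10 = 4*5/2, while the symmetric variant gives 14, counting the four subarrays that contain both 1s twice.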

TIME COMPLEXITY

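The time complexity is O(2^N * M) per test case, provided the column-wise min and max arrays of each subset are built in O(M) from a subset with one row fewer (as in the setter's and tester's solutions). The editorialist's solution rebuilds them from scratch for every mask, which costs O(2^N * N * M).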
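The incremental construction can be sketched as follows, in the spirit of the setter's solution: the column-wise min/max of `mask` equal those of `mask` without its lowest set bit, combined with that single row. The class name is hypothetical, and storing all 2^N tables assumes 2^N * M ints fit in memory; the tester's solution avoids this by recursing over rows and keeping only one table per recursion depth.

```java
// Sketch (hypothetical class) of building column-wise min/max for every row subset
// in O(M) per mask, by extending the mask with its lowest set bit removed.
// Memory: two tables of 2^N * M ints, so this assumes N is small.
class IncrementalMasks {
    // Returns {mn, mx} where mn[mask][c] / mx[mask][c] are the min / max of column c
    // over the rows in mask (mask 0 holds the identity values).
    static int[][][] build(int[][] a) {
        int n = a.length, m = a[0].length;
        int[][] mn = new int[1 << n][m], mx = new int[1 << n][m];
        java.util.Arrays.fill(mn[0], Integer.MAX_VALUE);
        java.util.Arrays.fill(mx[0], Integer.MIN_VALUE);
        for (int mask = 1; mask < (1 << n); mask++) {
            int row = Integer.numberOfTrailingZeros(mask); // lowest row in the subset
            int rest = mask & (mask - 1);                  // same subset without that row
            for (int c = 0; c < m; c++) {
                mn[mask][c] = Math.min(mn[rest][c], a[row][c]);
                mx[mask][c] = Math.max(mx[rest][c], a[row][c]);
            }
        }
        return new int[][][] {mn, mx};
    }
}
```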

SOLUTIONS

Setter's Solution
#include<bits/stdc++.h>
 
using namespace std;
 
typedef long long ll;
int const maxn = 9, maxm = 1e5 + 5;
int a[maxn][maxm], f[maxn], inf = 1e9 + 7;
int all_element[maxn * maxm];
int nxt[maxn * maxm];
int imin[(1 << (maxn - 1))][maxm];
int imax[(1 << (maxn - 1))][maxm];
int lmin[maxm], rmin[maxm], lmax[maxm], rmax[maxm], Q[maxm];
int lnxt[maxm], rnxt[maxm];
int lst[maxn * maxm];
 
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n, m, x, N = 0;
    cin >> n >> m >> x;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            cin >> a[i][j];
            all_element[++N] = a[i][j];
        }
    }
    f[0] = 1;
    for (int i = 1; i <= n; ++i) f[i] = f[i - 1] * i;
    sort(all_element + 1, all_element + N + 1);
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            a[i][j] = lower_bound(all_element + 1, all_element + N + 1, a[i][j]) - all_element;
        }
    }
    for (int i = 1; i <= N; ++i) {
        int pos = lower_bound(all_element + 1, all_element + N + 1, (all_element[i]^x)) - all_element;
        if (pos <= N && all_element[pos] == (all_element[i]^x)) {
            nxt[i] = pos;
        }
    }
    for (int i = 1; i <= m; ++i) {
        imin[0][i] = inf;
        imax[0][i] = -inf;
    }
    ll ans = 0;
    for (int mask = 1; mask < (1 << n); ++mask) {
        int cnt = __builtin_popcount(mask);
        int b = __builtin_ffs(mask);
        int where = (mask^(1 << (b - 1)));
        for (int j = 1; j <= m; ++j) {
            imin[mask][j] = min(imin[where][j], a[b][j]);
            imax[mask][j] = max(imax[where][j], a[b][j]);
        }
        ll add = 0;
        int ptr = 0;
        for (int j = 1; j <= m; ++j) {
            while (ptr && imin[mask][Q[ptr]] > imin[mask][j]) ptr--;
            lmin[j] = Q[ptr];
            Q[++ptr] = j;
        }
        ptr = 0;
        for (int j = 1; j <= m; ++j) {
            while (ptr && imax[mask][Q[ptr]] <= imax[mask][j]) ptr--;
            lmax[j] = Q[ptr];
            Q[++ptr] = j;
        }
        ptr = 0;
        for (int j = m; j >= 1; --j) {
            while (ptr && imin[mask][Q[ptr]] >= imin[mask][j]) ptr--;
            if (ptr == 0) rmin[j] = m + 1;
            else rmin[j] = Q[ptr];
            Q[++ptr] = j;
        }
        ptr = 0;
        for (int j = m; j >= 1; --j) {
            while (ptr && imax[mask][Q[ptr]] <= imax[mask][j]) ptr--;
            if (ptr == 0) rmax[j] = m + 1;
            else rmax[j] = Q[ptr];
            Q[++ptr] = j;
        }
        for (int j = 1; j <= m; ++j) {
            lst[imax[mask][j]] = j;
            lnxt[j] = lst[nxt[imin[mask][j]]];
        }
        for (int j = m; j >= 1; --j) {
            lst[imax[mask][j]] = j;
            if (lst[nxt[imin[mask][j]]] >= j) {
                rnxt[j] = lst[nxt[imin[mask][j]]];
            }
            else rnxt[j] = m + 1;
        }
        for (int j = 1; j <= m; ++j) lst[imax[mask][j]] = 0;
        for (int j = 1; j <= m; ++j) {
            if (nxt[imin[mask][j]] == 0) continue;
            int L = lmin[j] + 1, R = rmin[j] - 1;
            int lx = lnxt[j], rx = rnxt[j];
            if (lx != 0 && rmax[lx] >= j) {
                R = min(R, rmax[lx] - 1);
                L = max(L, lmax[lx] + 1);
            }
            else if (lx != 0) {
                L = max(L, rmax[lx] + 1);
            }
            if (rx != m + 1 && lmax[rx] <= j) {
                L = max(L, lmax[rx] + 1);
                R = min(R, rmax[rx] - 1);
            }
            else if (rx != m + 1) {
                R = min(R, lmax[rx] - 1);
            }
            lx = max(lx, L - 1), rx = min(rx, R + 1);
            ll val = (ll)(j - L + 1) * (ll)(R - j + 1) - (ll)(j - lx) * (ll)(rx - j);
            add += val;
        }
        ans += add * (ll)(f[cnt] * f[n - cnt] * (n - cnt + 1));
    }
    cout << ans << '\n';
    return 0;
}
Tester's Solution
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <vector>
#include <set>
#include <map>
#include <unordered_set>
#include <unordered_map>
#include <queue>
#include <ctime>
#include <cassert>
#include <complex>
#include <string>
#include <cstring>
#include <chrono>
#include <random>
#include <bitset>
using namespace std;

#ifdef LOCAL
    #define eprintf(...) fprintf(stderr, __VA_ARGS__);fflush(stderr);
#else
    #define eprintf(...) 42
#endif

using ll = long long;
using ld = long double;
using uint = unsigned int;
using ull = unsigned long long;
template<typename T>
using pair2 = pair<T, T>;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
ll myRand(ll B) {
    return (ull)rng() % B;
}

#define pb push_back
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define fi first
#define se second

clock_t startTime;
double getCurrentTime() {
    return (double)(clock() - startTime) / CLOCKS_PER_SEC;
}

const int N = 100100;
const int M = 8;
const int K = N * M;
int a[M][N];
int b[M + 1][N], c[M + 1][N];
int n, m, k;
int xs[K];
int pr[K];
ll ans[M + 1];
int smL[N], smR[N], bgL[N], bgR[N], wL[N], wR[N];
int st[N];
int stSz;
int lst[K];

ll solve() {
    stSz = 0;
    st[0] = -1;
    for (int i = 0; i < n; i++) {
	    while(stSz > 0 && b[m][i] < b[m][st[stSz]]) stSz--;
	    smL[i] = st[stSz];
	    st[++stSz] = i;
    }
    stSz = 0;
    st[0] = n;
    for (int i = n - 1; i >= 0; i--) {
	    while(stSz > 0 && b[m][i] <= b[m][st[stSz]]) stSz--;
	    smR[i] = st[stSz];
	    st[++stSz] = i;
    }
    stSz = 0;
    st[0] = -1;
    for (int i = 0; i < n; i++) {
	    while(stSz > 0 && c[m][i] >= c[m][st[stSz]]) stSz--;
	    bgL[i] = st[stSz];
	    st[++stSz] = i;
    }
    stSz = 0;
    st[0] = n;
    for (int i = n - 1; i >= 0; i--) {
	    while(stSz > 0 && c[m][i] >= c[m][st[stSz]]) stSz--;
	    bgR[i] = st[stSz];
	    st[++stSz] = i;
    }
    for (int i = 0; i < n; i++) {
	    lst[c[m][i]] = i;
	    if (pr[b[m][i]] != -1 && lst[pr[b[m][i]]] != -1)
		    wL[i] = lst[pr[b[m][i]]];
	    else
		    wL[i] = -1;
    }
    for (int i = 0; i < n; i++)
	    lst[c[m][i]] = -1;
    for (int i = n - 1; i >= 0; i--) {
	    lst[c[m][i]] = i;
	    if (pr[b[m][i]] != -1 && lst[pr[b[m][i]]] != -1)
		    wR[i] = lst[pr[b[m][i]]];
	    else
		    wR[i] = n;
    }
    for (int i = 0; i < n; i++)
	    lst[c[m][i]] = -1;
    ll res = 0;
    for (int i = 0; i < n; i++) {
	    int x = pr[b[m][i]];
	    if (x == -1) continue;
	    if (c[m][i] == x) {
		    res += (ll)(i - max(smL[i], bgL[i])) * (min(smR[i], bgR[i]) - i);
	    } else {
		    int l = smL[i], r = smR[i];
		    int p = -1, q = -1;
		    if (wL[i] > l) {
			    p = wL[i];
			    if (bgR[p] <= i) p = -1;
		    }
		    if (wR[i] < r) {
			    q = wR[i];
			    if (bgL[q] >= i) q = -1;
		    }
		    if (p != -1) {
			    if (q != -1) {
				    assert(bgL[p] == bgL[q]);
				    assert(bgR[p] == bgR[q]);
				    l = max(l, bgL[p]);
				    r = min(r, bgR[p]);
				    res += (ll)(i - l) * (r - i) - (ll)(i - p) * (q - i);
			    } else {
				    l = max(l, bgL[p]);
				    r = min(r, bgR[p]);
				    res += (ll)(p - l) * (r - i);
			    }
		    } else if (q != -1) {
			    l = max(l, bgL[q]);
			    r = min(r, bgR[q]);
			    res += (ll)(i - l) * (r - q);
		    }
	    }
    }
    return res;
}

void brute(int p, int cnt) {
    if (p == m) {
	    if (cnt > 0) ans[cnt] += solve();
	    return;
    }
    for (int i = 0; i < n; i++) {
	    b[p + 1][i] = b[p][i];
	    c[p + 1][i] = c[p][i];
    }
    brute(p + 1, cnt);
    for (int i = 0; i < n; i++) {
	    b[p + 1][i] = min(b[p + 1][i], a[p][i]);
	    c[p + 1][i] = max(c[p + 1][i], a[p][i]);
    }
    brute(p + 1, cnt + 1);
}

int main()
{
    startTime = clock();
//	freopen("input.txt", "r", stdin);
//	freopen("output.txt", "w", stdout);

    int X;
    scanf("%d%d%d", &m, &n, &X);
    for (int i = 0; i < m; i++)
	    for (int j = 0; j < n; j++) {
		    scanf("%d", &a[i][j]);
		    xs[k++] = a[i][j];
	    }
    sort(xs, xs + k);
    k = unique(xs, xs + k) - xs;
    for (int i = 0; i < k; i++) {
	    int x = xs[i] ^ X;
	    int p = lower_bound(xs, xs + k, x) - xs;
	    if (p < k && xs[p] == x)
		    pr[i] = p;
	    else
		    pr[i] = -1;
    }
    for (int i = 0; i < k; i++)
	    lst[i] = -1;
    for (int i = 0; i < m; i++)
	    for (int j = 0; j < n; j++)
		    a[i][j] = lower_bound(xs, xs + k, a[i][j]) - xs;
    for (int i = 0; i < n; i++) {
	    b[0][i] = k;
	    c[0][i] = 0;
    }
    brute(0, 0);
    ll res = 0;
    for (int k = 1; k <= m; k++) {
	    ll w = ans[k];
	    for (int i = 1; i <= k; i++)
		    w *= i;
	    for (int i = 1; i <= m + 1 - k; i++)
		    w *= i;
	    res += w;
    }
    printf("%lld\n", res);

    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class MAXXMIN{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = ni(), M = ni(), X = ni();
        long[] fact = new long[1+N];
        fact[0] = 1;
        for(int i = 1; i<= N; i++)fact[i] = fact[i-1]*i;
        
        int[][] A = new int[N][M];
        int[] all = new int[N*M];//Contains all distinct elements
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                all[i*M+j] = A[i][j] = ni();
        
        Arrays.sort(all);
        int C = 1;
        for(int i = 1; i< all.length; i++)if(all[i] != all[C-1])all[C++] = all[i];
        all = Arrays.copyOf(all, C);
        
        for(int i = 0; i< N; i++)
            for(int j = 0; j< M; j++)
                A[i][j] = Arrays.binarySearch(all, A[i][j]);//Value compression, all[A[r][c]] now gets the original value of A[r][c]
        
        int[] nxt = new int[C];//all[nxt[i]] = all[i]^X, or nxt[i] = -1 if no such position exists
        for(int i = 0; i< C; i++){
            int pos = Arrays.binarySearch(all, all[i]^X);
            if(pos >= 0)nxt[i] = pos;
            else nxt[i] = -1;
        }
        
        int[] stack = new int[M];//Temporary stack
        
        int[] leftMinPos = new int[M], rightMinPos = new int[M], leftMaxPos = new int[M], rightMaxPos = new int[M];
        //leftMinPos[i] -> largest p < i such that imin[p] <= imin[i], or -1 if none
        //rightMinPos[i] -> smallest p > i such that imin[p] < imin[i], or M if none
        
        //leftMaxPos[i] -> largest p < i such that imax[p] > imax[i], or -1 if none
        //rightMaxPos[i] -> smallest p > i such that imax[p] > imax[i], or M if none
        
        int[] lst = new int[C], leftNext = new int[M], rightNext = new int[M];
        //leftNext[i] = largest p <= i such that max[p] = min[i]^X
        //rightNext[i] = smallest p >= i such that max[p] = min[i]^X
        long ans = 0;
        for(int mask = 1; mask < 1<<N; mask++){
            int[] imin = new int[M], imax = new int[M];
            Arrays.fill(imin, Integer.MAX_VALUE);//imin[c] = min_{r \in mask} A[r][c]
            Arrays.fill(imax, Integer.MIN_VALUE);//imax[c] = max_{r \in mask} A[r][c]
            for(int r = 0; r< N; r++){
                if(((mask>>r)&1) == 1){
                    for(int c = 0; c< M; c++){
                        imin[c] = Math.min(imin[c], A[r][c]);
                        imax[c] = Math.max(imax[c], A[r][c]);
                    }
                }
            }
            //Computing leftMinPos
            for(int c = 0, ptr = 0; c< M; c++){
                while(ptr > 0 && imin[stack[ptr-1]] > imin[c])ptr--;
                leftMinPos[c] = ptr == 0?-1:stack[ptr-1];
                stack[ptr++] = c;
            }
            //Computing leftMaxPos
            for(int c = 0, ptr = 0; c< M; c++){
                while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
                leftMaxPos[c] = ptr == 0?-1:stack[ptr-1];
                stack[ptr++] = c;
            }
            //Computing rightMinPos
            for(int c = M-1, ptr = 0; c>= 0; c--){
                while(ptr > 0 && imin[stack[ptr-1]] >= imin[c])ptr--;
                rightMinPos[c] = ptr == 0?M:stack[ptr-1];
                stack[ptr++] = c;
            }
            //Computing rightMaxPos
            for(int c = M-1, ptr = 0; c>= 0; c--){
                while(ptr > 0 && imax[stack[ptr-1]] <= imax[c])ptr--;
                rightMaxPos[c] = ptr == 0?M:stack[ptr-1];
                stack[ptr++] = c;
            }
            
            Arrays.fill(leftNext, -1);Arrays.fill(rightNext, M);
            //leftNext[i] = largest p <= i such that max[p] = min[i]^X
            //rightNext[i] = smallest p >= i such that max[p] = min[i]^X
            Arrays.fill(lst, -1);//lst[i] = last updated position of occurrence of i.
            for(int c = 0; c< M; c++){
                lst[imax[c]] = c;//updating position of imax[c] to c
                if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)leftNext[c] = lst[nxt[imin[c]]];
            }
            Arrays.fill(lst, -1);
            for(int c = M-1; c >= 0; c--){
                lst[imax[c]] = c;//updating position of imax[c] to c
                if(nxt[imin[c]] != -1 && lst[nxt[imin[c]]] != -1)rightNext[c] = lst[nxt[imin[c]]];
            }
            
            long subarrays = 0;
            for(int c = 0; c< M; c++){
                if(nxt[imin[c]] == -1)continue;//min^X doesn't appear in matrix
                int Lmin = leftMinPos[c]+1, Rmin = rightMinPos[c]-1;
                int Lmax = leftNext[c], Rmax = rightNext[c];
                if(Lmax != -1 && rightMaxPos[Lmax] >= c){
                    Lmin = Math.max(Lmin, leftMaxPos[Lmax]+1);
                    Rmin = Math.min(Rmin, rightMaxPos[Lmax]-1);
                }else if(Lmax != -1){
                    Lmin = Math.max(Lmin, rightMaxPos[Lmax]+1);
                }
                
                if(Rmax != M && leftMaxPos[Rmax] <= c){
                    Lmin = Math.max(Lmin, leftMaxPos[Rmax]+1);
                    Rmin = Math.min(Rmin, rightMaxPos[Rmax]-1);
                }else if(Rmax != M){
                    Rmin = Math.min(Rmin, leftMaxPos[Rmax]-1);
                }
                Lmax = Math.max(Lmax, Lmin-1);
                Rmax = Math.min(Rmax, Rmin+1);
                long count = (c-Lmin+1)*(long)(Rmin-c+1) - (c-Lmax)*(long)(Rmax-c);
                subarrays += count;
            }
            int cnt = bit(mask);
            ans += (N-cnt+1)*fact[cnt]*fact[N-cnt]*subarrays;
        }
        pn(ans);
    }
    int bit(int x){return x == 0?0:(1+bit(x&(x-1)));}
    
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new MAXXMIN().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always. :slight_smile:
