GRDXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: iiii63027
Tester: jay_1048576
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Familiarity with bitwise XOR

PROBLEM:

You’re given an N\times M grid A.
You can choose a cell (i, j), delete the row i and column j, and compute the sum of A_{i, j} \oplus x across x remaining in the grid.

Find the maximum possible value of this sum if (i, j) is chosen optimally.

EXPLANATION:

Let’s fix a cell (i, j) and try to quickly compute the resulting sum.

Iterating over all remaining elements would be too slow, so we need to do a bit better.
As always, when dealing with bitwise operations, it helps to look at each bit individually.

So, let’s fix a bit b, and see how many of the A_{i, j} \oplus x values have this bit set.
If this count is k_b, then bit b contributes k_b \cdot 2^b to the overall sum.
We only care about values of b up to 30, so if we can find k_b for each one quickly, we’ll be done.

There are two cases:

  • Case 1: A_{i, j} itself has bit b set.
    Here, A_{i, j} \oplus x will have b set if and only if x doesn’t have b set.
    So, we’d like to know the total number of x that don’t have bit b set.
  • Case 2: A_{i, j} doesn’t have bit b set.
    Here, A_{i, j}\oplus x will have b set if and only if x has b set.
    So, we’d like to know the total number of x that do have bit b set.

Notice that both cases require us to know similar information: for fixed (i, j) and bit b, how many elements outside the i-th row and j-th column have (or don’t have) bit b set.

This can be found in \mathcal{O}(1) with a bit of precomputation.
In particular,

  • Let \text{ct}_b be the total number of values in the grid that have bit b set.
  • Let \text{row}_{i, b} be the number of values in the i-th row that have bit b set.
  • Let \text{col}_{j, b} be the number of values in the j-th column that have bit b set.
  • Then, the number of elements outside row i and column j that have bit b set equals
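    \text{ct}_b - \text{row}_{i, b} - \text{col}_{j, b} + (1 \text{ if } A_{i, j} \text{ has bit } b \text{ set, and } 0 \text{ otherwise}).
    The last term is needed because A_{i, j} lies in both row i and column j, so it gets subtracted twice.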
  • Exactly (N-1)\cdot (M-1) elements lie outside row i and column j, so the number of them that don’t have bit b set equals (N-1)\cdot (M-1) minus the above quantity.

The values of \text{ct}_b, \text{row}_{i, b} and \text{col}_{j, b} can be precomputed for all i, j, b by going through the grid once.

After this, the answer for each (i, j) can be found by iterating across each bit b and computing k_b.
We check 30 bits for each cell, so the overall complexity is \mathcal{O}(30\cdot NM), which is fast enough.

TIME COMPLEXITY:

\mathcal{O}(30\cdot NM) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
#define int long long
#define ll __int128

#define mod (int)(1e9 + 7)
using namespace std;

int n, m;

// Scores one candidate cell.
//   rows  - per-bit count of set bits in the cell's row
//   cols  - per-bit count of set bits in the cell's column
//   total - per-bit count of set bits in the whole grid (taken by value, adjusted locally)
//   indi  - indi[b] is 1 when the cell itself has bit b set, 0 otherwise
// Returns the sum of (cell XOR x) over every x outside the cell's row and column.
// Uses the globals n and m for the grid dimensions.
int solve(vector<int> &rows, vector<int> &cols, vector<int> total, vector<int> &indi)
{
	// Number of cells left once one row and one column are removed: (n - 1) * (m - 1).
	int remaining = n * m - n - m + 1;
	int sum = 0;
	for (int bit = 0; bit < 32; bit++)
	{
		// Inclusion-exclusion: the cell sits in both its row and its column, so when
		// it has this bit set it was subtracted twice and must be added back once.
		total[bit] = total[bit] - (rows[bit] + cols[bit]) + indi[bit];
		// The XOR has this bit set exactly for the values that differ from the cell here.
		int differing = indi[bit] ? remaining - total[bit] : total[bit];
		sum += (1ll << bit) * differing;
	}
	return sum;
}

void solve()
{
	cin >> n >> m;
	vector<vector<int>> v(n, vector<int>(m));
	vector<vector<int>> rows(n, vector<int>(32)), cols(m, vector<int>(32));
	vector<int> total(32);
	for (int i = 0; i < n; i++)
	{
		for (int j = 0; j < m; j++)
		{
			cin >> v[i][j];
			int it = v[i][j];
			int tep = 0;
			while (it != 0)
			{
				if (it & 1)
				{
					rows[i][tep] += 1;
					cols[j][tep] += 1;
					total[tep] += 1;
				}
				tep++;
				it >>= 1;
			}
		}
	}
	int ans = 0;
	vector<int> indi(32);
	for (int i = 0; i < n; i++)
	{
		for (int j = 0; j < m; j++)
		{
			fill(indi.begin(), indi.end(), 0);
			int it = v[i][j], tep = 0;
			while (it != 0)
			{
				if (it & 1)
				{
					indi[tep] = 1;
				}
				else
					indi[tep] = 0;
				tep++;
				it >>= 1;
			}
			ans = max(ans, solve(rows[i], cols[j], total, indi));
		}
	}
	cout << ans << '\n';
}
// Entry point: switches to fast stream I/O, reads the number of test cases and
// runs solve() once per test case.
signed main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	int tests;
	cin >> tests;
	while (tests > 0)
	{
		solve();
		--tests;
	}
}
Tester's code (C++)
/*...................................................................*
 *............___..................___.....____...______......___....*
 *.../|....../...\........./|...../...\...|.............|..../...\...*
 *../.|...../.....\......./.|....|.....|..|.............|.../........*
 *....|....|.......|...../..|....|.....|..|............/...|.........*
 *....|....|.......|..../...|.....\___/...|___......../....|..___....*
 *....|....|.......|.../....|...../...\.......\....../.....|./...\...*
 *....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
 *....|.....\...../.........|....|.....|.......|.../........\...../..*
 *..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
 *...................................................................*
 */
 
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF 1000000000000000000
#define MOD 1000000007

/**
 * Solves one test case read from standard input.
 *
 * Reads n and m followed by an n x m grid, then prints the maximum, over all
 * cells (i, j), of the sum of (a[i][j] XOR x) taken over every element x lying
 * outside row i and column j. Only bits 0..29 of the values are examined.
 *
 * All storage lives in std::vector instead of variable-length arrays: VLAs are
 * a compiler extension in C++ and would place several n x m arrays on the stack.
 *
 * @param tc index of the test case as passed by the caller; unused.
 */
void solve(int tc)
{
    (void)tc;
    int n, m;
    std::cin >> n >> m;
    std::vector<std::vector<long long>> a(n, std::vector<long long>(m));
    for (auto &row : a)
        for (auto &cell : row)
            std::cin >> cell;

    // ans[i][j] accumulates the contribution of every bit for the choice (i, j).
    std::vector<std::vector<long long>> ans(n, std::vector<long long>(m, 0));
    std::vector<long long> rowCount(n), colCount(m);
    // Number of elements left after deleting one row and one column.
    const long long remaining = static_cast<long long>(n - 1) * (m - 1);

    for (int b = 0; b < 30; b++)
    {
        const long long mask = 1LL << b;

        // Count how many values have bit b set: overall, per row and per column.
        std::fill(rowCount.begin(), rowCount.end(), 0);
        std::fill(colCount.begin(), colCount.end(), 0);
        long long total = 0;
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < m; j++)
            {
                if (a[i][j] & mask)
                {
                    rowCount[i]++;
                    colCount[j]++;
                    total++;
                }
            }
        }

        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < m; j++)
            {
                const bool own = (a[i][j] & mask) != 0;
                // Inclusion-exclusion: a[i][j] belongs to both its row and its
                // column, so it is added back once when it has the bit set.
                long long s = total - rowCount[i] - colCount[j] + (own ? 1 : 0);
                // If the chosen cell has the bit, the XOR has it exactly for the
                // remaining elements that do NOT have it.
                if (own)
                    s = remaining - s;
                ans[i][j] += s * mask;
            }
        }
    }

    long long mx = 0;
    for (const auto &row : ans)
        for (long long value : row)
            mx = std::max(mx, value);
    std::cout << mx << '\n';
}

// Entry point: configures fast stream I/O, reads the number of test cases and
// calls solve() with the 1-based index of each one.
int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int tc = 1;
    cin >> tc;
    int current = 1;
    while (current <= tc)
    {
        solve(current);
        current++;
    }
    return 0;
}
Editorialist's code (Python)
# For each test case: count, per bit, how many values in the whole grid, in each
# row and in each column have that bit set. Then score every cell (r, c) as the
# sum of grid[r][c] XOR x over all x outside row r and column c, and print the
# best score.
for _ in range(int(input())):
    n, m = map(int, input().split())
    total_bits = [0] * 31
    row_bits = [[0] * 31 for _ in range(n)]
    col_bits = [[0] * 31 for _ in range(m)]

    grid = []
    for r in range(n):
        values = list(map(int, input().split()))
        grid.append(values)
        for c in range(m):
            value = values[c]
            for b in range(30):
                if (value >> b) & 1:
                    total_bits[b] += 1
                    row_bits[r][b] += 1
                    col_bits[c][b] += 1

    # Number of cells left once one row and one column are removed.
    remaining = (n - 1) * (m - 1)
    best = 0
    for r in range(n):
        for c in range(m):
            value = grid[r][c]
            score = 0
            for b in range(30):
                outside = total_bits[b] - row_bits[r][b] - col_bits[c][b]
                if (value >> b) & 1:
                    # The cell was subtracted with both its row and its column, so
                    # add it back once; the XOR then has bit b set for the
                    # remaining values that lack it.
                    outside = remaining - (outside + 1)
                score += outside * (1 << b)
            best = max(best, score)
    print(best)