PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: iiii63027
Tester: jay_1048576
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Familiarity with bitwise XOR
PROBLEM:
You’re given an N\times M grid A.
You can choose a cell (i, j), delete row i and column j, and compute the sum of A_{i, j} \oplus x over all values x remaining in the grid.
Find the maximum possible value of this sum if (i, j) is chosen optimally.
EXPLANATION:
Let’s fix a cell (i, j) and try to quickly compute the resulting sum.
Iterating over all remaining elements would be too slow, so we need to do a bit better.
As always, when dealing with bitwise operations, it helps to look at each bit individually.
So, let’s fix a bit b, and see how many of the A_{i, j} \oplus x values have this bit set.
If this count is k_b, then bit b contributes k_b \cdot 2^b to the overall sum.
We only care about bits b < 30 (bits 0 through 29), so if we can find each k_b quickly, we’ll be done.
There are two cases:
- Case 1: A_{i, j} itself has bit b set.
Here, A_{i, j} \oplus x will have b set if and only if x doesn’t have b set.
So, we’d like to know the total number of remaining x that don’t have bit b set.
- Case 2: A_{i, j} doesn’t have bit b set.
Here, A_{i, j}\oplus x will have b set if and only if x has b set.
So, we’d like to know the total number of remaining x that do have bit b set.
Notice that both cases require us to know similar information: for fixed (i, j) and bit b, how many elements outside the i-th row and j-th column have (or don’t have) bit b set.
This can be found in \mathcal{O}(1) with a bit of precomputation.
In particular,
- Let \text{ct}_b be the total number of values in the grid that have bit b set.
- Let \text{row}_{i, b} be the number of values in the i-th row that have bit b set.
- Let \text{col}_{j, b} be the number of values in the j-th column that have bit b set.
- Then, the number of elements outside row i and column j that have bit b set equals
\text{ct}_b - \text{row}_{i, b} - \text{col}_{j, b} + (1 \text{ if } A_{i, j} \text{ has bit } b \text{ set, and } 0 \text{ otherwise}), where the final term is needed because A_{i, j} lies in both row i and column j and is therefore subtracted twice.
- The number of remaining elements that don’t have bit b set equals (N-1)\cdot (M-1) minus the above quantity.
The values of \text{ct}_b, \text{row}_{i, b} and \text{col}_{j, b} can be precomputed for all i, j, b by going through the grid once.
After this, the answer for each (i, j) can be found by iterating across each bit b and computing k_b.
We check 30 bits for each cell, so the overall complexity is \mathcal{O}(30\cdot NM), which is fast enough.
TIME COMPLEXITY
\mathcal{O}(30\cdot NM) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
#define int long long // every `int` below is 64-bit; this is why main is declared `signed`
#define ll __int128 // unused in this solution
#define mod (int)(1e9 + 7) // unused in this solution
using namespace std;
int n, m;
// Sum of A[i][j] XOR x over every x left after deleting row i and column j,
// computed bit by bit from precomputed counts (reads the grid size from n, m).
//   rows  : rows[b]  = number of values in row i with bit b set (32 entries)
//   cols  : cols[b]  = number of values in column j with bit b set (32 entries)
//   total : total[b] = number of values in the whole grid with bit b set; taken
//           by value on purpose, because it is rewritten into the count for the
//           remaining (n-1) x (m-1) sub-grid
//   indi  : indi[b] == 1 iff A[i][j] has bit b set
int solve(const std::vector<int> &rows, const std::vector<int> &cols, std::vector<int> total, const std::vector<int> &indi)
{
    const int remaining = (n - 1) * (m - 1); // cells left after the deletion
    int ans = 0;
    for (int i = 0; i < 32; i++)
    {
        // Inclusion-exclusion: A[i][j] lies in both row i and column j, so it was
        // subtracted twice and is added back once.
        total[i] += indi[i] - rows[i] - cols[i];
        // XOR sets bit i where exactly one operand has it: count the x without
        // bit i when A[i][j] has it, otherwise the x with it.
        const int withBit = indi[i] ? remaining - total[i] : total[i];
        ans += (1ll << i) * withBit;
    }
    return ans;
}
// Solves one test case: reads the n x m grid into the globals n, m and v, builds
// per-bit set counts for the whole grid, every row and every column, then prints
// the best value of the per-cell helper solve(rows, cols, total, indi) over all
// cells (i, j).
void solve()
{
    const int kBits = 32; // must match the 32-entry tables the per-cell helper reads
    cin >> n >> m;
    vector<vector<int>> v(n, vector<int>(m));
    vector<vector<int>> rows(n, vector<int>(kBits)), cols(m, vector<int>(kBits));
    vector<int> total(kBits);
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < m; j++)
        {
            cin >> v[i][j];
            // Bounded bit loop: a shift-until-zero loop would index past the
            // 32-entry tables for values >= 2^32, and for negative input the
            // arithmetic shift never reaches 0.
            for (int b = 0; b < kBits; b++)
            {
                if ((v[i][j] >> b) & 1)
                {
                    rows[i][b]++;
                    cols[j][b]++;
                    total[b]++;
                }
            }
        }
    }
    int ans = 0;
    vector<int> indi(kBits); // indi[b] == 1 iff the current cell has bit b set
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < m; j++)
        {
            for (int b = 0; b < kBits; b++)
                indi[b] = (v[i][j] >> b) & 1;
            ans = max(ans, solve(rows[i], cols[j], total, indi));
        }
    }
    cout << ans << '\n';
}
// Entry point: fast I/O, then one call to solve() per test case.
signed main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int testCases = 0; // named apart from the global grid dimension n
    cin >> testCases;
    while (testCases-- > 0)
        solve();
}
Tester's code (C++)
/*...................................................................*
*............___..................___.....____...______......___....*
*.../|....../...\........./|...../...\...|.............|..../...\...*
*../.|...../.....\......./.|....|.....|..|.............|.../........*
*....|....|.......|...../..|....|.....|..|............/...|.........*
*....|....|.......|..../...|.....\___/...|___......../....|..___....*
*....|....|.......|.../....|...../...\.......\....../.....|./...\...*
*....|....|.......|../_____|__..|.....|.......|..../......|/.....\..*
*....|.....\...../.........|....|.....|.......|.../........\...../..*
*..__|__....\___/..........|.....\___/...\___/.../..........\___/...*
*...................................................................*
*/
#include <bits/stdc++.h>
using namespace std;
#define int long long // every `int` below is 64-bit; this is why main is declared int32_t
#define INF 1000000000000000000 // unused in this solution
#define MOD 1000000007 // unused in this solution
// Solves one test case: reads an n x m grid a and prints the maximum, over all
// cells (i, j), of the sum of a[i][j] XOR x taken over the cells that remain
// after deleting row i and column j. Each bit b in [0, 30) is handled on its own.
// tc is the 1-based test-case index; it is not needed here.
void solve(int tc)
{
    (void)tc;
    long long n, m;
    std::cin >> n >> m;

    // Heap-allocated grids instead of VLAs (`int a[n][m]`): runtime-sized arrays
    // are a compiler extension rather than ISO C++, and three n*m arrays of
    // 64-bit values on the stack can overflow it for large grids.
    using Grid = std::vector<std::vector<long long>>;
    Grid a(n, std::vector<long long>(m));
    for (auto &row : a)
        for (auto &x : row)
            std::cin >> x;

    Grid ans(n, std::vector<long long>(m, 0));
    Grid pre(n, std::vector<long long>(m)); // reused for every bit
    for (int b = 0; b < 30; b++)
    {
        // pre[i][j] = number of cells in rows 0..i and columns 0..j with bit b set.
        for (long long i = 0; i < n; i++)
            for (long long j = 0; j < m; j++)
                pre[i][j] = (a[i][j] >> b) & 1;
        for (long long i = 1; i < n; i++)
            for (long long j = 0; j < m; j++)
                pre[i][j] += pre[i - 1][j];
        for (long long i = 0; i < n; i++)
            for (long long j = 1; j < m; j++)
                pre[i][j] += pre[i][j - 1];

        for (long long i = 0; i < n; i++)
        {
            for (long long j = 0; j < m; j++)
            {
                // Grid total minus row i minus column j. a[i][j] is subtracted
                // twice, so this undercounts by one when a[i][j] has bit b set.
                long long s = pre[n - 1][m - 1] - pre[i][m - 1] - pre[n - 1][j];
                if (i > 0)
                    s += pre[i - 1][m - 1];
                if (j > 0)
                    s += pre[n - 1][j - 1];
                if ((a[i][j] >> b) & 1)
                {
                    // True count is s + 1; the XOR has bit b set exactly for the
                    // (n-1)(m-1) - (s+1) = nm - n - m - s remaining cells without it.
                    s = n * m - n - m - s;
                }
                ans[i][j] += s * (1LL << b);
            }
        }
    }

    long long mx = 0;
    for (const auto &row : ans)
        for (long long value : row)
            mx = std::max(mx, value);
    std::cout << mx << '\n';
}
// Entry point: untie/unsync the streams for speed, then solve each test case
// with its 1-based index.
int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int testCount = 1;
    cin >> testCount;
    for (int caseNo = 1; caseNo <= testCount; ++caseNo)
    {
        solve(caseNo);
    }
    return 0;
}
Editorialist's code (Python)
# Editorialist's solution: for every cell, add up each bit's contribution to the
# sum of (cell XOR x) over the cells left after deleting the cell's row and column.
for _ in range(int(input())):
    n, m = map(int, input().split())
    grid = []

    # Per-bit set counts for the whole grid, every row and every column
    # (31 slots are allocated, bits 0..29 are filled).
    full_cnt = [0] * 31
    row_cnt = [[0] * 31 for _ in range(n)]
    col_cnt = [[0] * 31 for _ in range(m)]
    for r in range(n):
        grid.append(list(map(int, input().split())))
        line = grid[r]
        for c in range(m):
            val = line[c]
            for b in range(30):
                if (val >> b) & 1:
                    full_cnt[b] += 1
                    row_cnt[r][b] += 1
                    col_cnt[c][b] += 1

    left = (n - 1) * (m - 1)  # cells surviving the deletion of one row and one column
    best = 0
    for r in range(n):
        for c in range(m):
            val = grid[r][c]
            total = 0
            for b in range(30):
                with_bit = full_cnt[b] - row_cnt[r][b] - col_cnt[c][b]
                if (val >> b) & 1:
                    # val was subtracted twice; re-add it, then count cells WITHOUT bit b.
                    with_bit = left - (with_bit + 1)
                total += with_bit << b
            best = max(best, total)
    print(best)