MNXR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: pols_agyi_pols
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Familiarity with bitwise XOR

PROBLEM:

Given N, M, X, Y, compute the xor-sum of the Manhattan distances from (X, Y) to all the cells of an N\times M grid.
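A direct brute force is handy for sanity-checking everything below on small inputs. It's a sketch of my own (the name brute_force is not from the problem). It uses 1-indexed cells as in the statement and takes O(NM) time, so it only works for tiny grids.

```python
def brute_force(n, m, x, y):
    """XOR of |i - x| + |j - y| over every cell (i, j) of an n x m grid, 1-indexed."""
    ans = 0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            ans ^= abs(i - x) + abs(j - y)
    return ans


print(brute_force(2, 2, 1, 1))  # distances 0, 1, 1, 2 -> prints 2
```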

EXPLANATION:

There are likely several ways to approach this task; below, I'll detail one of them.

Observe that the grid can be split into four regions about (X, Y), much like the four quadrants in 2D coordinate space.
The ‘axes’ in this case will be row X and column Y.

For example, here's a 6\times 7 grid with (X, Y) = (3, 3):

[Image: the 6\times 7 grid with (X, Y) = (3, 3), split into four colored quadrants by row 3 and column 3.]

The colored regions above represent the different quadrants, while the axes remain white (and the source is green).
Let’s try to solve for each axis and each quadrant independently, and xor all their answers together afterwards.
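To see the actual distances for this example, here's a small sketch (distance_grid is a hypothetical helper). Row 3 and column 3 are the "axes", and the four blocks around them are the quadrants.

```python
def distance_grid(n, m, x, y):
    """Manhattan distance from (x, y) for every cell of an n x m grid (1-indexed)."""
    return [[abs(i - x) + abs(j - y) for j in range(1, m + 1)] for i in range(1, n + 1)]


for row in distance_grid(6, 7, 3, 3):
    print(" ".join(map(str, row)))
# Output:
# 4 3 2 3 4 5 6
# 3 2 1 2 3 4 5
# 2 1 0 1 2 3 4
# 3 2 1 2 3 4 5
# 4 3 2 3 4 5 6
# 5 4 3 4 5 6 7
```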

Solving for an axis

Let's look at the Manhattan distances of the cells in row X, the same row as the source.
It’s easy to see that this distance starts out at 0 at (X, Y), and then increases by 1 for each step we take away from column Y.

So, looking at just the cells to the left, we’ll have the values 1, 2, 3, \ldots, Y-1.
What’s the xor-sum of these numbers?


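Computing 1\oplus 2\oplus\ldots\oplus K for a fixed integer K is a rather well-known problem.
The idea is to look at K modulo 4: every block of four consecutive numbers 4k, 4k+1, 4k+2, 4k+3 has a xor-sum of 0, so all complete blocks cancel out and only the last, partial block matters.
So, writing f(K) for this value (it equals 0\oplus 1\oplus\ldots\oplus K, since including 0 changes nothing):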

  • If K is a multiple of 4, f(K) = K.
  • If K-1 is a multiple of 4, f(K) = 1.
  • If K-2 is a multiple of 4, f(K) = K+1.
  • If K-3 is a multiple of 4, f(K) = 0.

A detailed explanation can be found here.
There's also a nice visualization at the bottom of that page of why things cancel out every four steps.
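The four cases translate directly into code. Here's a minimal Python sketch (assuming K \ge 0) with a brute-force cross-check of the period-4 pattern:

```python
def f(k):
    """0 xor 1 xor ... xor k in O(1), using the period-4 pattern (k >= 0)."""
    r = k % 4
    if r == 0:
        return k
    if r == 1:
        return 1
    if r == 2:
        return k + 1
    return 0  # r == 3


# Cross-check against a running prefix xor for small k.
acc = 0
for k in range(100):
    acc ^= k
    assert f(k) == acc
```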

So, 1\oplus 2\oplus\ldots\oplus (Y-1) = f(Y-1) can be computed in constant time.
The cells to the right of the source are handled the same way; they give f(M-Y).
The vertical axis (column Y) works in exactly the same fashion: the cells above and below the source give f(X-1) and f(N-X), respectively.
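The four half-axes can be combined as in the sketch below. axes_xor is a hypothetical name, and f is the period-4 formula written as a lookup list. Since f(0) = 0, a half-axis that is empty because the source lies on the border contributes nothing.

```python
def f(k):
    """0 xor 1 xor ... xor k for k >= 0, via the period-4 pattern."""
    return [k, 1, k + 1, 0][k % 4]


def axes_xor(n, m, x, y):
    """XOR of the distances of all cells in row x or column y (the source adds 0)."""
    left, right = f(y - 1), f(m - y)  # row x: 1..Y-1 to the left, 1..M-Y to the right
    up, down = f(x - 1), f(n - x)     # column y: 1..X-1 above, 1..N-X below
    return left ^ right ^ up ^ down


print(axes_xor(6, 7, 3, 3))  # prints 4
```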

Solving for a quadrant

Let's look at the bottom-right quadrant; in the example image above, this is the part shaded red.

This is a rectangle of dimensions (N-X)\times (M-Y), and the Manhattan distances of its cells follow a simple pattern: the cell (X+i, Y+j), with 1 \le i \le N-X and 1 \le j \le M-Y, is at distance i+j from (X, Y).

Each row therefore holds a contiguous range of values; in particular, the i-th row contains the values i+1 to M-Y+i.
We already know how to compute f(K), the xor-sum of the integers 0 to K.
This of course allows us to compute g(L, R), the xor-sum of integers from L to R, as
g(L, R) = f(R)\oplus f(L-1).
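Here's a short sketch of g, with f and g named as in the text. It assumes 1 \le L \le R+1, where L = R+1 means an empty range.

```python
def f(k):
    """0 xor 1 xor ... xor k for k >= 0."""
    return [k, 1, k + 1, 0][k % 4]


def g(l, r):
    """XOR of the integers l..r: the prefix up to r, with the prefix up to l-1 cancelled."""
    return f(r) ^ f(l - 1)


print(g(3, 5))  # 3 ^ 4 ^ 5 = 2
```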

Going row by row, the value we want to compute is exactly

\bigoplus_{i=1}^{N-X} g(i+1, M-Y+i) = \bigoplus_{i=1}^{N-X} \left(f(M-Y+i) \oplus f(i)\right)

Since XOR is associative and commutative, this splits into two parts: f(1)\oplus f(2)\oplus\ldots\oplus f(N-X), and
f(M-Y+1)\oplus\ldots \oplus f(M-Y+N-X).
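Before speeding this up, the row-by-row sum above can be coded directly in O(N-X) time, which is useful as a reference. In this sketch a = N-X and b = M-Y, and quadrant_rows is a hypothetical name.

```python
def f(k):
    """0 xor 1 xor ... xor k for k >= 0."""
    return [k, 1, k + 1, 0][k % 4]


def quadrant_rows(a, b):
    """XOR of i + j over 1 <= i <= a, 1 <= j <= b, one row at a time (O(a) time).

    Row i holds i+1 .. b+i, whose xor is g(i+1, b+i) = f(b+i) ^ f(i).
    """
    ans = 0
    for i in range(1, a + 1):
        ans ^= f(b + i) ^ f(i)
    return ans


print(quadrant_rows(3, 4))  # bottom-right quadrant of the 6 x 7 example: prints 4
```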

Notice that for both of them, we just need to be able to find the prefix xor of f(i) quickly.
That is, if h(K) = f(1)\oplus f(2)\oplus \ldots\oplus f(K), we need to be able to quickly compute the values of h.
Recall that f(K) was computed quickly by relying on the fact that things cancelled out every four steps; so intuitively it seems to be a good idea to check if something similar happens to h.
As it turns out, this is indeed the case!


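Rather than every 4 steps, things now cancel out every 8 steps: f(8k)\oplus f(8k+1)\oplus\ldots\oplus f(8k+7) = 0 for every k \ge 0, since these eight values are 8k, 1, 8k+3, 0, 8k+4, 1, 8k+7, 0.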

Specifically, if r = K\bmod 8,

  • If r \in \{0, 1\}, we have h(K) = K.
  • If r \in \{2, 3\}, we have h(K) = 2.
  • If r \in \{4, 5\}, we have h(K) = K+2.
  • If r \in \{6, 7\}, we have h(K) = 0.

This allows h(K) to be computed in constant time.
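In code, again as a sketch assuming K \ge 0, with a brute-force check that the prefix xor of f really follows this period-8 pattern:

```python
def f(k):
    """0 xor 1 xor ... xor k for k >= 0."""
    return [k, 1, k + 1, 0][k % 4]


def h(k):
    """f(1) xor f(2) xor ... xor f(k) for k >= 0, via the period-8 pattern."""
    r = k % 8
    if r <= 1:
        return k
    if r <= 3:
        return 2
    if r <= 5:
        return k + 2
    return 0


# Cross-check: accumulate f(1), f(2), ... directly and compare.
acc = 0
for k in range(1, 200):
    acc ^= f(k)
    assert h(k) == acc
```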

With this, the value we’re looking for is exactly h(N-X+M-Y)\oplus h(M-Y)\oplus h(N-X).
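Here's a sketch of this closed form. quadrant is a hypothetical name, and a, b are the quadrant's dimensions (here N-X and M-Y). Because h(0) = 0, an empty quadrant (a = 0 or b = 0) gives 0 without a special case.

```python
def h(k):
    """f(1) xor ... xor f(k) for k >= 0, via the period-8 pattern."""
    return [k, k, 2, 2, k + 2, k + 2, 0, 0][k % 8]


def quadrant(a, b):
    """XOR of i + j over 1 <= i <= a, 1 <= j <= b, in O(1)."""
    return h(a + b) ^ h(b) ^ h(a)


print(quadrant(3, 4))  # bottom-right quadrant of the 6 x 7 example: prints 4
```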
The other three quadrants are solved the same way; only the dimensions change, to (X-1)\times (Y-1), (X-1)\times (M-Y), and (N-X)\times (Y-1).

Each half-axis and each quadrant is solved in constant time, and there are only four of each, so the overall solution runs in constant time too.
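Here's a compact sketch that assembles the pieces into the whole O(1) computation, with a randomized comparison against the brute force on small grids. solve, quadrant and brute_force are my own names. This mirrors the decomposition described above but isn't meant as the reference solution.

```python
import random


def f(k):
    """0 xor 1 xor ... xor k for k >= 0."""
    return [k, 1, k + 1, 0][k % 4]


def h(k):
    """f(1) xor ... xor f(k) for k >= 0."""
    return [k, k, 2, 2, k + 2, k + 2, 0, 0][k % 8]


def quadrant(a, b):
    """XOR of distances in an a x b quadrant (0 if empty)."""
    return h(a + b) ^ h(b) ^ h(a)


def solve(n, m, x, y):
    up, down, left, right = x - 1, n - x, y - 1, m - y
    ans = f(up) ^ f(down) ^ f(left) ^ f(right)  # the four half-axes
    for a in (up, down):
        for b in (left, right):
            ans ^= quadrant(a, b)               # the four quadrants
    return ans


def brute_force(n, m, x, y):
    ans = 0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            ans ^= abs(i - x) + abs(j - y)
    return ans


random.seed(1)
for _ in range(2000):
    n, m = random.randint(1, 12), random.randint(1, 12)
    x, y = random.randint(1, n), random.randint(1, m)
    assert solve(n, m, x, y) == brute_force(n, m, x, y)
print("all random tests passed")
```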

TIME COMPLEXITY:

\mathcal{O}(1) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
ll xorr(ll n)
{
    if(n<0) return 0;
    if(n%4==0)  return n;
    if(n%4==1)  return 1;
    if(n%4==2)  return n+1;
    return 0;  // remaining case: n%4==3
}
ll lxorr(ll l, ll r)
{
    if(l>r) return 0;
    ll t= (2*(xorr((l-1)/2))) ^ (2*xorr(r/2));
    if(r%2)
        t^= (xorr(r)^xorr(l-1));
    return t;
}

ll fx(ll x,ll y,ll a,ll b)
{
    ll ans=0;
    if(x==a)
    {
        ans=(xorr(x)^xorr(y-1));return ans;
    }
    if(x==y)
    {
        ans=(xorr(x)^(xorr(a-1)));return ans;
    }
    if(a%2==x%2)
    {
        ans=(xorr(a)^(xorr(b-1)));a++;b++;
    }
    ans^=lxorr(a,x);
    ans^=lxorr(b,y-1);
    return ans;
}
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
    ll kitne_cases_hain;
    //freopen("input.txt","r",stdin);freopen("output.txt","w",stdout);
    cin>>kitne_cases_hain;
    assert(1<=kitne_cases_hain && kitne_cases_hain<=1e5);
    while(kitne_cases_hain--){     
        ll n,m;cin>>n>>m;
        assert(1<=n && n<=1e9 && 1<=m && m<=1e9);
        ll x,y;cin>>x>>y;
        assert(1<=x && x<=n && 1<=y && y<=m);
        ll ans=0;
        ans^=fx(x+y-2,x-1,y-1,0);
        if(y!=m)
            ans^=fx(m-y+x-1,x,m-y,1);
        if(x!=n)
            ans^=fx(n-x+y-1,n-x,y,1);
        if(x!=n && y!=m)
            ans^=fx(n+m-x-y,n-x+1,m-y+1,2);
        cout<<ans<<"\n";
    }
	return 0;
} 
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

#define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// 0 ^ 1 ^ ... ^ t
int xorsum(int t) {
    int res = 0;
    for (int i = t / 4 * 4; i <= t; i++) {
        res ^= i;
    }
    return res;
}

// (t % 2) ^ ... ^ (t - 2) ^ t
int xorsum2(int t) {
    if (t % 2 == 1) {
        return xorsum(t) ^ xorsum2(t - 1);
    } else {
        int res = 0;
        for (int i = t / 8 * 8; i <= t; i += 2) {
            res ^= i;
        }
        return res;
    }
}

int main() {
    {
        int t = 0;
        for (int i = 0; i <= 10000; i++) {
            t ^= i;
            assert(t == xorsum(i));
        }
        t = 0;
        for (int i = 1; i <= 10000; i += 2) {
            t ^= i;
            assert(t == xorsum2(i));
        }
        t = 0;
        for (int i = 2; i <= 10000; i += 2) {
            t ^= i;
            assert(t == xorsum2(i));
        }
    }
    input_checker in;
    int tt = in.readInt(1, 1e5);
    in.readEoln();
    while (tt--) {
        int n = in.readInt(1, 1e9);
        in.readSpace();
        int m = in.readInt(1, 1e9);
        in.readSpace();
        int x = in.readInt(1, n);
        in.readSpace();
        int y = in.readInt(1, m);
        in.readEoln();
        x--;
        y--;
        x = min(x, n - 1 - x);
        y = min(y, m - 1 - y);
        int ans = 0;
        int a = n - 1 - x - x;
        int b = m - 1 - y - y;
        ans ^= xorsum(x) ^ xorsum(x + a);
        ans ^= xorsum(y) ^ xorsum(y + b);
        if (a > b) {
            swap(a, b);
        }
        int c = x + 1 + y + 1;
        if (a % 2 == 1) {
            ans ^= xorsum(c + a - 2) ^ xorsum(c + a - 1 + b - 1);
            a--;
        }
        ans ^= xorsum2(c - 2) ^ xorsum2(c + a - 2);
        ans ^= xorsum2(c + b - 2) ^ xorsum2(c + b + a - 2);
        cout << ans << '\n';
    }
    in.readEof();
    return 0;
}
Editorialist's code (Python)
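# pre_xor(n) = 0 ^ 1 ^ ... ^ n  (f in the text)
def pre_xor(n):
    if n%4 == 0: return n
    if n%4 == 1: return 1
    if n%4 == 2: return n+1
    return 0
# prepre_xor(n) = pre_xor(1) ^ ... ^ pre_xor(n)  (h in the text)
def prepre_xor(n):
    if n%8 <= 1: return n
    if n%8 <= 3: return 2
    if n%8 <= 5: return n+2
    return 0
# pre_xor(l) ^ pre_xor(l+1) ^ ... ^ pre_xor(r)
def range_prexor(l, r):
    if l == 0: return prepre_xor(r)
    return prepre_xor(r) ^ prepre_xor(l-1)
# xor-sum of distances over an n x m quadrant:
# (f(m+1) ^ ... ^ f(m+n)) ^ (f(1) ^ ... ^ f(n))
def calc(n, m):
    if min(n, m) <= 0: return 0
    res = range_prexor(m+1, n+m)
    res ^= range_prexor(1, n)
    return res

for _ in range(int(input())):
    n, m, x, y = map(int, input().split())
    # the four half-axes (row x and column y)
    ans = pre_xor(x-1) ^ pre_xor(y-1) ^ pre_xor(n-x) ^ pre_xor(m-y)
    # the four quadrants
    for N in [x-1, n-x]:
        for M in [y-1, m-y]:
            ans ^= calc(N, M)
    print(ans)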

This is brilliant!!
