FREQXOR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: tibinyte
Tester: tabr
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Familiarity with bitwise operations, prefix sums

PROBLEM:

You’re given X, L, R, M. Consider the sequence of numbers

(X\oplus L, X\oplus (L+1), \ldots, X\oplus R)

all taken modulo M.

Find the maximum frequency of an element in this sequence.
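To pin down the statement, here is a brute-force reference. `bruteMaxFreq` is a hypothetical helper written for this editorial, not part of the original solutions. It counts (X \oplus i) \bmod M for every i in [L, R] directly, so it is only practical when R - L is small. It is still handy for testing the faster approach.

```cpp
#include <algorithm>
#include <cassert>
#include <unordered_map>

// Naive reference (hypothetical helper): count (X xor i) mod M for every
// i in [L, R] and return the largest count. O(R - L + 1) time.
long long bruteMaxFreq(long long X, long long L, long long R, long long M) {
    std::unordered_map<long long, long long> freq;
    long long best = 0;
    for (long long i = L; i <= R; ++i) {
        long long v = (X ^ i) % M;
        best = std::max(best, ++freq[v]);
    }
    return best;
}
```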

EXPLANATION:

Let’s solve a simplified version of the task first: with X = 0.

The sequence is simply A = [L\bmod M, (L+1)\bmod M, \ldots, R\bmod M].

Let \text{freq} be an array where \text{freq}[y] denotes the frequency of y in A.
Note that A looks like this:

  1. L\bmod M, (L\bmod M) + 1, \ldots, (M-1) to start.
  2. Several blocks of length M which are just 0, 1, 2, \ldots, M-1 repeated.
  3. 0, 1, 2, \ldots, R\bmod M to end.

Of course, there are minor edge cases, for example when R-L is small enough that we start at L\bmod M and end at R\bmod M without ever wrapping around.
A more general characterization that captures every case is this: we have some number (possibly zero) of complete segments [0, 1, \ldots, M-1], each of which adds 1 to every \text{freq} value, and we additionally add 1 to the cyclic subarray from L\bmod M to R\bmod M (which is either a single subarray, or a suffix plus a prefix when it wraps around).
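The X = 0 characterization can be written down directly. The sketch below is our own illustration; `Decomposition`, `decompose` and `toFreq` are hypothetical names. It assumes 0 \leq l \leq r and M \geq 1. The arc runs from l \bmod M to r \bmod M and has length between 1 and M, and everything else splits into complete blocks.

```cpp
#include <cassert>
#include <vector>

// Residues of l, l+1, ..., r (mod M) = `fullBlocks` complete copies of
// 0..M-1 plus one nonempty cyclic arc from l % M to r % M.
struct Decomposition {
    long long fullBlocks;  // +1 to every residue this many times
    long long arcStart;    // l % M
    long long arcEnd;      // r % M (arcEnd < arcStart means it wraps around)
};

Decomposition decompose(long long l, long long r, long long M) {
    long long arcStart = l % M, arcEnd = r % M;
    long long arcLen = (arcEnd - arcStart + M) % M + 1;  // always in [1, M]
    long long total = r - l + 1;                         // total ≡ arcLen (mod M)
    return {(total - arcLen) / M, arcStart, arcEnd};
}

// Expands a decomposition into an explicit frequency array (small M only).
std::vector<long long> toFreq(const Decomposition& d, long long M) {
    std::vector<long long> freq(M, d.fullBlocks);
    for (long long y = d.arcStart;; y = (y + 1) % M) {
        ++freq[y];
        if (y == d.arcEnd) break;
    }
    return freq;
}
```

For example, `decompose(1, 10, 3)` yields 3 full blocks plus the one-element arc \{1\}, so residue 1 appears 4 times.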


Now, let’s consider a different simplification.
Let’s fix L = 0, so we’re looking at the sequence X, X\oplus 1, X\oplus 2, \ldots, X\oplus R.
Ignore the modulo M part for now.

The important observation here is that the range [0, R] breaks up into O(\log R) mutually disjoint ranges [x_1, y_1], [x_2, y_2], \ldots, [x_k, y_k], such that all the numbers within a range [x_i, y_i] will map to a contiguous range of numbers after taking their XOR with X.

More detail

Let’s write the binary expansion of R, as R = 2^{b_1} + 2^{b_2} + \ldots + 2^{b_k}, where b_1 \gt b_2 \gt\ldots \gt b_k.

Now, note that for every y from 0 to 2^{b_1} - 1, the bits of y\oplus X at positions \geq b_1 are the same for all y (they are exactly the bits of X at those positions), while the lowest b_1 bits take each of their 2^{b_1} possible combinations exactly once.
That is, if M_1 is the mask of bits of X that are \geq b_1, every value in the range [M_1, M_1 + 2^{b_1}) will appear exactly once.

Similarly, for 2^{b_1} \leq y \lt 2^{b_1} + 2^{b_2}, the bits \geq b_2 of y\oplus X stay constant, while the lower b_2 bits take all 2^{b_2} possible values; once again we get a range of contiguous values.

This repeats for each prefix of the set bits of R. Since R has O(\log R) set bits, we obtain O(\log R) disjoint ranges, each of which maps to contiguous values. These ranges cover exactly [0, R-1]; R itself is handled as one final single-element range.
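As a small worked example (our own, not from the original), take R = 5 = 2^2 + 2^0 and X = 6 = 110_2:

```latex
\begin{aligned}
y \in [0, 3] &\;\Rightarrow\; y \oplus 6 \in \{6, 7, 4, 5\} = [4, 7] \\
y \in [4, 4] &\;\Rightarrow\; y \oplus 6 = 2 \\
y = 5 = R    &\;\Rightarrow\; y \oplus 6 = 3
\end{aligned}
```

The first two ranges come from the set bits b_1 = 2 and b_2 = 0 of R, and the last one is R itself. Together they give \{2, 3, 4, 5, 6, 7\}, which is exactly \{y \oplus 6 : 0 \leq y \leq 5\}.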

Now, let’s find each of these ranges (which, as noted in the spoiler above, are determined by the binary representation of R).
For range [x_i, y_i], suppose after XOR-ing with X the resulting range is [l_i, r_i].
Then, we want to add 1 to \text{freq}[y\bmod M] for each l_i \leq y \leq r_i.
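One way to compute the ranges [l_i, r_i] directly is sketched below. `xorRanges` is a hypothetical name, and the sketch assumes 0 \leq R, X \lt 2^{31}, which comfortably holds for values up to 10^9. Each set bit b of R contributes the block of y that agree with R above b and have bit b cleared. After XOR with X, that block becomes an aligned run of 2^b consecutive values.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Split {y ^ X : 0 <= y <= R} into O(log R) disjoint contiguous ranges [lo, hi].
// Assumes 0 <= R, X < 2^31.
std::vector<std::pair<long long, long long>> xorRanges(long long R, long long X) {
    std::vector<std::pair<long long, long long>> out;
    for (int b = 31; b >= 0; --b) {
        long long bit = 1LL << b;
        if (R & bit) {
            // y keeps R's bits above b, has bit b cleared, lower bits free.
            long long prefix = (R >> (b + 1)) << (b + 1);
            // y ^ X then has fixed bits >= b and all 2^b patterns below b.
            long long lo = ((prefix ^ X) >> b) << b;
            out.push_back({lo, lo + bit - 1});
        }
    }
    out.push_back({R ^ X, R ^ X});  // R itself is not covered by the loop
    return out;
}
```

For R = 5 and X = 6 this returns [4, 7], [2, 2] and [3, 3].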

Note that this has now reduced to exactly the initial simplification we considered, where X = 0, and we already know how to process that - in particular, it will correspond to increasing the entirety of \text{freq} by some constant (for the [0, 1, 2, \ldots, M-1] blocks), and then adding 1 to some cyclic subarray.

The additions that cover all of \text{freq} are the same for every index, so they can be accumulated in a single counter. We then only need to worry about the additions of 1 to cyclic subarrays (or equivalently, adding 1 to at most two subarrays of \text{freq}).
This problem of processing several range additions to an array is a rather well-known one, and can be handled using a combination of a difference array and prefix sums.
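The step from one cyclic subarray to at most two ordinary subarrays is small enough to show directly. `splitArc` is a hypothetical helper name, and it assumes both endpoints are already reduced modulo M.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Cyclic arc a, a+1, ..., b (mod M), with 0 <= a, b < M:
// one interval [a, b] if a <= b, otherwise the suffix [a, M-1] plus the prefix [0, b].
std::vector<std::pair<long long, long long>> splitArc(long long a, long long b, long long M) {
    std::vector<std::pair<long long, long long>> parts;
    if (a <= b) {
        parts.emplace_back(a, b);
    } else {
        parts.emplace_back(a, M - 1);
        parts.emplace_back(0, b);
    }
    return parts;
}
```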

Details

Consider an array D, initially filled with 0's.

To add 1 to the range [l, r], we simply add 1 to D_l and subtract 1 from D_{r+1}.
After all updates, the prefix sums of D will give the actual values at each index.

For Q updates on an array of size N, this takes \mathcal{O}(N + Q) time.
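The array version of the difference-array technique might look like this. `RangeAdd` and `applyRangeAdds` are hypothetical names, and indices are assumed to be 0-based with 0 \leq l \leq r \lt N.

```cpp
#include <cassert>
#include <vector>

struct RangeAdd { int l, r; long long delta; };  // hypothetical: +delta on [l, r]

// Apply all updates through a difference array D, then take prefix sums: O(N + Q).
std::vector<long long> applyRangeAdds(int N, const std::vector<RangeAdd>& updates) {
    std::vector<long long> D(N + 1, 0);  // extra slot so index r + 1 is always valid
    for (const auto& u : updates) {
        D[u.l] += u.delta;
        D[u.r + 1] -= u.delta;
    }
    std::vector<long long> values(N);
    long long running = 0;
    for (int i = 0; i < N; ++i) {
        running += D[i];  // prefix sum of D gives the value at index i
        values[i] = running;
    }
    return values;
}
```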

In this case, M can be as large as 10^9, so we can’t use an array directly, but we can use a map instead.
Taking prefix sums along the map will give us segments of equal values, with the number of segments being linear in the number of updates.

So, by repeating the process for each of the O(\log R) ranges, we can compute the \text{freq} array (as segments of equal values) in \mathcal{O}(\log R \log\log R) time (the \log\log R part comes from accessing a map with O(\log R) values in it).
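With a `std::map` in place of the array, a sketch could look like the code below. `SparseDiff` is a hypothetical name. The code forces a key at index 0 so the first segment is always counted, and it treats keys \geq M purely as end markers. `base` is an assumed constant already added to every index, which plays the role of the full-block counter.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <map>

// Sparse difference array over indices [0, M): stores only O(#updates) keys.
struct SparseDiff {
    std::map<long long, long long> D;

    void add(long long l, long long r, long long delta) {  // +delta on [l, r]
        D[l] += delta;
        D[r + 1] -= delta;
    }

    // Maximum value over [0, M), where every index starts at `base`.
    long long maxValue(long long M, long long base = 0) {
        D[0] += 0;  // make sure index 0 begins a segment
        long long best = LLONG_MIN, running = base;
        for (const auto& [key, delta] : D) {
            if (key >= M) break;  // keys >= M only close the last segment
            running += delta;     // value on [key, next key)
            best = std::max(best, running);
        }
        return best;
    }
};
```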


Now, from the above we can obtain the frequency array for [0, R].

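Then repeat the same process for [0, L-1], but this time subtract values instead of adding them. Since [L, R] = [0, R] \setminus [0, L-1], what remains is exactly the frequency array of the original sequence.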
The number of segments in the resulting frequency array is still \mathcal{O}(\log R), so we can simply iterate through them all and compute the maximum.
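Putting the pieces together gives the end-to-end sketch below. It is our own illustration, not the author's or tester's code. `xorRanges`, `maxFreqXor` and `bruteMaxFreq` are hypothetical names, and the sketch assumes 0 \leq X, R \lt 2^{31}. It follows the editorial's cyclic-arc formulation. The tester's code reaches the same result with half-open ranges and a slightly different wrap-around trick.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>

// Contiguous value ranges [lo, hi] covering {y ^ X : 0 <= y <= T}.
std::vector<std::pair<long long, long long>> xorRanges(long long T, long long X) {
    std::vector<std::pair<long long, long long>> out;
    for (int b = 31; b >= 0; --b) {
        if (T & (1LL << b)) {
            long long prefix = (T >> (b + 1)) << (b + 1);
            long long lo = ((prefix ^ X) >> b) << b;
            out.push_back({lo, lo + (1LL << b) - 1});
        }
    }
    out.push_back({T ^ X, T ^ X});
    return out;
}

// Max frequency of (X ^ i) mod M over L <= i <= R: +1 for [0, R], -1 for [0, L-1].
long long maxFreqXor(long long X, long long L, long long R, long long M) {
    long long base = 0;                // full blocks: added to every residue
    std::map<long long, long long> D;  // sparse difference array over [0, M)
    D[0] += 0;                         // index 0 always starts a segment

    // delta on the cyclic residue arc a, a+1, ..., b (mod M)
    auto addArc = [&](long long a, long long b, long long delta) {
        if (a <= b) {
            D[a] += delta; D[b + 1] -= delta;
        } else {  // wraps: suffix [a, M-1] plus prefix [0, b]
            D[a] += delta; D[M] -= delta;
            D[0] += delta; D[b + 1] -= delta;
        }
    };
    // delta for every value X ^ y with 0 <= y <= T
    auto addPrefix = [&](long long T, long long delta) {
        for (auto [lo, hi] : xorRanges(T, X)) {
            long long a = lo % M, b = hi % M;
            long long arcLen = (b - a + M) % M + 1;
            base += delta * ((hi - lo + 1 - arcLen) / M);
            addArc(a, b, delta);
        }
    };

    addPrefix(R, +1);
    if (L > 0) addPrefix(L - 1, -1);

    long long best = LLONG_MIN, running = base;
    for (const auto& [key, d] : D) {
        if (key >= M) break;  // key M only closes the last segment
        running += d;
        best = std::max(best, running);
    }
    return best;
}

// Naive O(R - L + 1) check, for testing on small inputs only.
long long bruteMaxFreq(long long X, long long L, long long R, long long M) {
    std::unordered_map<long long, long long> freq;
    long long best = 0;
    for (long long i = L; i <= R; ++i) best = std::max(best, ++freq[(X ^ i) % M]);
    return best;
}
```

Only O(\log R) keys ever enter the map, which matches the \mathcal{O}(\log R \log\log R) bound stated below.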

TIME COMPLEXITY:

\mathcal{O}(\log R \log\log R) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>

using namespace std;


#define int long long

const int lgmax = 30;
const int inf = 1e16;

vector<pair<int, int>> find_ranges(int n, int x) // 0 ^ x, 1 ^ x, ... , n^x
{
    vector<pair<int, int>> ans;
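    int frog = 0;  // bits of n strictly above the current bit
    int frog2 = 0; // bits of x at the current bit and above
    for (int bit = lgmax; bit >= 0; --bit)
    {
        frog2 += x & (1 << bit);
        if (n & (1 << bit))
        {
            // every y that agrees with n above `bit` and has `bit` cleared:
            // y ^ x sweeps one aligned block of 2^bit consecutive values
            ans.push_back({(frog ^ frog2), (frog ^ frog2) + (1 << bit) - 1});
        }
        frog += n & (1 << bit);
    }
    ans.push_back({frog ^ frog2, frog ^ frog2}); // n itself (not covered by the loop)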
    int cnt = 0;
    for (auto i : ans)
    {
        cnt += i.second - i.first + 1;
    }
    assert(cnt == n + 1);
    return ans;
}

int ceil(int a, int b)
{
    return (a + b - 1) / b;
}

int brute(int n, int st, int dr, int x)
{
    map<int, int> fr;
    for (int i = st; i <= dr; ++i)
    {
        int val = n ^ i;
        fr[val % x]++;
    }
    int ans = 0;
    for (auto it : fr)
    {
        ans = max(ans, it.second);
    }
    return ans;
}

int32_t main()
{
    cin.tie(nullptr)->sync_with_stdio(false);
    int q;
    cin >> q;
    while (q--)
    {
        int n, st, dr, x;
        cin >> n >> st >> dr >> x;

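        // ranges of values n ^ i for i in [0, dr] and for i in [0, st - 1]
        vector<pair<int, int>> a = find_ranges(dr, n);
        vector<pair<int, int>> b = find_ranges(st - 1, n);

        // remove b's ranges from a; afterwards a covers exactly {n ^ i : st <= i <= dr}
        for (auto it : b)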
        {
            vector<pair<int, int>> new_a;
            for (auto i : a)
            {
                pair<int, int> I = {max(i.first, it.first), min(i.second, it.second)};
                if (I.first <= I.second)
                {
                    if (i.first <= I.first - 1)
                    {
                        new_a.push_back({i.first, I.first - 1});
                    }
                    if (I.second + 1 <= i.second)
                    {
                        new_a.push_back({I.second + 1, i.second});
                    }
                }
                else
                {
                    new_a.push_back(i);
                }
            }
            a = new_a;
        }

        int cnt = 0;
        for (auto it : a)
        {
            cnt += it.second - it.first + 1;
        }

        assert(cnt == (dr - st + 1));

        int ans = 0;

        map<int, int> mars; // sparse difference array over residues mod x

        function<void(int, int)> add = [&](int st, int dr)
        {
            if (st > dr)
            {
                return;
            }
            st %= x;
            dr %= x;
            if (st > dr)
            {
                swap(st, dr);
            }
            mars[st]++;
            mars[dr + 1]--;
        };

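        // [A, B) consists of full blocks of x values (+1 to every residue, counted in ans);
        // the leftover pieces [it.first, A - 1] and [B, it.second] go into mars
        for (auto it : a)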
        {
            int A = ceil(it.first, x) * x;
            int B = (it.second / x) * x;
            if (A <= B)
            {
                ans += (B - A) / x;
                add(it.first, A - 1);
                add(B, it.second);
            }
            else
            {
                add(it.first, it.second);
            }
        }
        int sum = 0;
        int maxi = 0;
        for (auto it : mars)
        {
            sum += it.second;
            maxi = max(maxi, sum);
        }
        ans += maxi;

        cout << ans << '\n';

        //cout << ans << ' ' << brute(n, st, dr, x) << endl;
        //assert(ans == brute(n, st, dr, x));
    }
}

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;

void solve(istringstream cin) {
    long long n, l, r, x;
    cin >> n >> l >> r >> x;
    map<long long, long long> pref;
    // +coeff to the residue (mod x) of every value in [low, high); whole cycles go to pref[0]
    auto add = [&](long long low, long long high, int coeff) {
        pref[low % x] += coeff;
        pref[high % x] -= coeff;
        pref[0] += coeff * (high / x - low / x);
    };
    // +coeff for every value n ^ y with 0 <= y < t, one aligned block per set bit of t
    auto go = [&](long long t, int coeff) {
        for (int i = 30; i >= 0; i--) {
            if (t & (1LL << i)) {
                long long u = ((n ^ t ^ (1LL << i)) >> i) << i;
                add(u, u + (1LL << i), coeff);
            }
        }
    };
    go(r + 1, +1);
    go(l, -1);
    long long ans = 0;
    long long sum = 0;
    for (auto p : pref) {
        sum += p.second;
        ans = max(ans, sum);
    }
    cout << ans << '\n';
}

////////////////////////////////////////

// #define IGNORE_CR

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
#ifdef IGNORE_CR
            if (c == '\r') {
                continue;
            }
#endif
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            assert(!isspace(buffer[pos]));
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    input_checker in;
    int tt = in.readInt(1, 1e4);
    in.readEoln();
    while (tt--) {
        long long n = in.readInt(1, 1e9);
        in.readSpace();
        long long l = in.readInt(1, 1e9);
        in.readSpace();
        long long r = in.readInt(1, 1e9);
        in.readSpace();
        long long x = in.readInt(2, 1e9);
        in.readEoln();
        assert(l <= r);
        ostringstream sout;
        sout << n << " " << l << " " << r << " " << x << '\n';
        solve(istringstream(sout.str()));
    }
    in.readEof();
    return 0;
}

Where can I read more about this? Breaking [0, R] into \log R segments seems like a quite useful, non-obvious but well-known property of XOR. This trick of XOR-ing a range of elements also seems quite useful.

If you can suggest more problems you are aware of, or similar tricks, articles, or anything else, I would be grateful.