XORPALIN - Editorial


Contest: Division 1
Contest: Division 2

Setter: Nishant Shah
Tester: Rahul Dugar
Editorialist: Taranpreet Singh




Prerequisites

Trie, Bitmasking, Combinatorics, Segment Tree.


Problem

All binary strings in this problem are considered to have length K, padded with zeros on the left if needed. We start with an empty sequence A of binary strings. In the i-th query, we append the binary representations of all values in [L_i, R_i] to A. After each query, we must count the pairs (i, j) with 1 \leq i < j \leq |A| such that the binary string A_i \oplus A_j is a palindrome.


Quick Explanation

  • For two binary strings P and Q, P \oplus Q is a palindrome if and only if P_i \oplus Q_i = P_{K-i-1} \oplus Q_{K-i-1}, or equivalently P_i \oplus P_{K-i-1} = Q_i \oplus Q_{K-i-1}, for every 0 \leq i < K/2, where P_i denotes the i-th bit of P.
  • So each binary string maps to exactly one bitmask with K/2 bits, whose i-th bit is P_i \oplus P_{K-i-1}. Strings with the same mask form an equivalence class: every pair inside a class is a good pair, and no pair spanning two classes is.
  • Hence, if we maintain the number of strings in each class after each update, together with the number of good pairs before the update, we can easily compute the number of good pairs after the update.
  • To avoid storing the frequency of every mask explicitly, we can find groups of masks whose frequencies always remain equal and compress each group into a single unit, which saves both time and memory.



  • P_i denotes the i-th least significant bit (character) of string P, indexed from 0.
  • \oplus denotes the bitwise XOR operation.

Basic Intuition

Let’s consider two binary strings P and Q of length K. P \oplus Q is a palindrome exactly when, for each i with 0 \leq i < K/2, we have P_i \oplus Q_i = P_{K-i-1} \oplus Q_{K-i-1}. XORing both sides with Q_i \oplus P_{K-i-1} turns this into P_{i} \oplus P_{K-i-1} = Q_{i} \oplus Q_{K-i-1}, and the step is reversible.

Consider a binary string P and build a mask M with exactly K/2 bits (rounded down), where M_i = P_i \oplus P_{K-i-1}. Let’s call M the mask of string P. We can prove that each binary string of length K has a unique mask, i.e. every string maps to exactly one mask, although many strings can share the same mask.

The XOR of two K-bit binary strings P and Q is a palindrome if and only if both strings have the same mask.
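To make the claim concrete, here is a small C++ sketch (not taken from the setter's, tester's or editorialist's code; `palMask`, `isPalindrome` and `claimHolds` are hypothetical names) that computes the mask exactly as defined above and brute-forces the "if and only if" for small K:

```cpp
#include <cassert>
#include <cstdint>

// Mask of the K-bit string p: bit i is P_i xor P_{K-1-i} for 0 <= i < K/2
// (K/2 rounded down), with bits indexed from the least significant end.
uint64_t palMask(uint64_t p, int K) {
    uint64_t m = 0;
    for (int i = 0; i < K / 2; i++) {
        uint64_t lo = (p >> i) & 1, hi = (p >> (K - 1 - i)) & 1;
        m |= (lo ^ hi) << i;
    }
    return m;
}

// True if the K-bit string x (zero-padded on the left) reads the same both ways.
bool isPalindrome(uint64_t x, int K) {
    for (int i = 0; i < K / 2; i++)
        if (((x >> i) & 1) != ((x >> (K - 1 - i)) & 1)) return false;
    return true;
}

// Exhaustively checks the claim for one K: P xor Q is a palindrome
// exactly when P and Q have the same mask.
bool claimHolds(int K) {
    for (uint64_t p = 0; p < (1ULL << K); p++)
        for (uint64_t q = 0; q < (1ULL << K); q++)
            if (isPalindrome(p ^ q, K) != (palMask(p, K) == palMask(q, K)))
                return false;
    return true;
}
```

For example, palMask(62, 7) is 0 and palMask(56, 7) is 6, which matches the walkthrough further down.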

If f_m strings in A have mask m, they contribute \binom{f_m}{2} pairs. Since no pair is counted under two different masks, the answer is simply \displaystyle\sum_{m = 0}^{2^{K/2}-1} \binom{f_m}{2}, where f_m denotes the number of strings with mask m.

Naive Solution

This gives us an easy way to solve the problem. Let’s keep a frequency table of masks, and whenever we get an update, loop over every value from L_i to R_i, updating the frequencies and the total number of pairs. Used with a map (a frequency array of size 2^{K/2} is not feasible), this solution can solve the second subtask.
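A minimal sketch of this naive approach, assuming the answer is reported modulo 998244353 as in all three reference solutions (`maskOf` and `naiveAnswers` are hypothetical names). Appending a string whose mask currently has frequency f creates exactly f new pairs, so the running total always equals the sum of \binom{f_m}{2}:

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

const uint64_t MOD = 998244353;  // modulus used by all three reference solutions

// Mask of the K-bit string p: bit i is P_i xor P_{K-1-i}, 0 <= i < K/2.
uint64_t maskOf(uint64_t p, int K) {
    uint64_t m = 0;
    for (int i = 0; i < K / 2; i++)
        m |= (((p >> i) ^ (p >> (K - 1 - i))) & 1) << i;
    return m;
}

// Naive solution: walk every value in [L, R] and keep a map mask -> frequency.
// Returns the answer after each query.
std::vector<uint64_t> naiveAnswers(int K, const std::vector<std::pair<uint64_t, uint64_t>>& queries) {
    std::unordered_map<uint64_t, uint64_t> freq;
    uint64_t pairs = 0;
    std::vector<uint64_t> answers;
    for (const auto& q : queries) {
        for (uint64_t x = q.first; x <= q.second; x++) {
            uint64_t& f = freq[maskOf(x, K)];
            pairs = (pairs + f) % MOD;  // the new string pairs with the f earlier ones
            f++;
        }
        answers.push_back(pairs);
    }
    return answers;
}
```

With K = 4, appending [0, 15] gives each of the four 2-bit masks 4 strings, i.e. 4 \cdot \binom{4}{2} = 24 pairs; appending [0, 3] afterwards raises each of the four masks to 5, giving 40.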

Improving a bit

Let’s assume for now that the whole frequency array fits in memory, and focus on performing updates quickly. Appending the strings in [L_i, R_i] can be seen as appending [0, R_i] and then removing [0, L_i-1] (skipped when L_i = 0). So let’s first figure out how to append [0, R_i]; removal is handled symmetrically.

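Let’s look carefully at a string P and its mask. Take the K/2 most significant bits of P (for odd K, the middle bit is not among them) and reverse them; call the result U (for upper). The bitwise XOR of U and the K/2 least significant bits of P is exactly the mask of P. Now consider an aligned block of 2^{K/2} consecutive strings sharing their upper (K+1)/2 bits: U is the same for the whole block, while the lower K/2 bits take every value in [0, 2^{K/2}-1] exactly once. So the masks generated by the block, U \oplus y for y \in [0, 2^{K/2}-1], are exactly [0, 2^{K/2}-1], each appearing once. Read this line carefully.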
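Both the identity "mask = U \oplus lower bits" and the block property are easy to check by brute force. A hedged C++ sketch (helper names are hypothetical; `reverseBits` plays the same role as the editorialist's helper of that name but is written independently here):

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Direct definition: bit i of the mask is P_i xor P_{K-1-i}, 0 <= i < K/2.
uint64_t maskDirect(uint64_t p, int K) {
    uint64_t m = 0;
    for (int i = 0; i < K / 2; i++)
        m |= (((p >> i) ^ (p >> (K - 1 - i))) & 1) << i;
    return m;
}

// Reverse the lowest B bits of x (bit 0 becomes bit B-1, and so on).
uint64_t reverseBits(uint64_t x, int B) {
    uint64_t r = 0;
    for (int b = 0; b < B; b++) {
        r = (r << 1) | (x & 1);
        x >>= 1;
    }
    return r;
}

// Same mask via U: take the K/2 most significant bits (this skips the middle
// bit when K is odd), reverse them, and XOR with the K/2 least significant bits.
uint64_t maskViaU(uint64_t p, int K) {
    int h = K / 2;
    uint64_t U = reverseBits(p >> (K - h), h);
    uint64_t lower = p & ((1ULL << h) - 1);
    return U ^ lower;
}

// Within one aligned block of 2^(K/2) strings sharing their upper (K+1)/2
// bits, U is fixed and the lower bits take every value once, so the masks
// should form a permutation of [0, 2^(K/2) - 1].
bool blockCoversAllMasks(uint64_t block, int K) {
    int h = K / 2;
    std::set<uint64_t> seen;
    for (uint64_t y = 0; y < (1ULL << h); y++)
        seen.insert(maskViaU((block << h) | y, K));
    return seen.size() == (1ULL << h);
}
```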

The implication of the last statement is that when appending all strings in an interval [0, X], the frequency of every mask increases by \displaystyle \bigg \lfloor \frac{X}{2^{K/2}} \bigg \rfloor (once per complete aligned block), and for the remaining \displaystyle X - \bigg \lfloor \frac{X}{2^{K/2}} \bigg \rfloor \cdot 2^{K/2} + 1 = (X \bmod 2^{K/2}) + 1 strings, the most significant (K+1)/2 bits are fixed, while the lower K/2 bits lie in the range [0, X \bmod 2^{K/2}].
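A sketch of this prefix decomposition, checked against brute force (hypothetical names `prefixFreq` and `prefixFreqBrute`). It still walks the partial block one string at a time; the interval trick described next removes that loop:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mask of the K-bit string p: bit i is P_i xor P_{K-1-i}, 0 <= i < K/2.
uint64_t maskOf(uint64_t p, int K) {
    uint64_t m = 0;
    for (int i = 0; i < K / 2; i++)
        m |= (((p >> i) ^ (p >> (K - 1 - i))) & 1) << i;
    return m;
}

// Reverse the lowest B bits of x.
uint64_t reverseBits(uint64_t x, int B) {
    uint64_t r = 0;
    for (int b = 0; b < B; b++, x >>= 1) r = (r << 1) | (x & 1);
    return r;
}

// Frequencies of all masks after appending every K-bit string in [0, X]:
// floor(X / 2^h) complete blocks add 1 to every mask, and the last block,
// whose lower h bits run over [0, X mod 2^h] with U fixed, adds 1 to U ^ y.
std::vector<uint64_t> prefixFreq(uint64_t X, int K) {
    int h = K / 2;
    uint64_t lowMask = (1ULL << h) - 1;
    std::vector<uint64_t> f(1ULL << h, X >> h);   // complete blocks
    uint64_t U = reverseBits(X >> (K - h), h);     // same for the whole last block
    for (uint64_t y = 0; y <= (X & lowMask); y++)  // (X mod 2^h) + 1 leftover strings
        f[U ^ y]++;
    return f;
}

// Brute force for comparison.
std::vector<uint64_t> prefixFreqBrute(uint64_t X, int K) {
    std::vector<uint64_t> f(1ULL << (K / 2), 0);
    for (uint64_t x = 0; x <= X; x++) f[maskOf(x, K)]++;
    return f;
}
```

For K = 7 and X = 62, prefixFreq gives 8 for masks 0, 2, 3, 4, 5, 6, 7 and 7 for mask 1.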

Let’s break this [0, X \bmod 2^{K/2}] interval, from left to right, into intervals whose lengths are powers of two, greedily taking the largest power of two first. Each of these pieces maps to a contiguous interval of masks, which we can find by walking the bits from the most significant one, in a trie-like manner.

We can compute the actual intervals just as a trie is used to find the minimum XOR pair, or the value in a list with the k-th smallest XOR with a given value.

A walkthrough example of the above

Let’s find the actual intervals on an example. Assume K = 7 and X = 62, whose binary representation is 0111110. We need to increase the frequency of every mask in \{\text{mask}(x) : x \in [0, 62]\}. Considering all x in the range \displaystyle [0, \lfloor \frac{X}{2^{K/2}} \rfloor \cdot 2^{K/2} - 1] = [0, 55], every mask appears \lfloor \frac{X}{2^{K/2}} \rfloor = \lfloor \frac{62}{8} \rfloor = 7 times.

Now handle the interval [56, 62]: all these strings share every bit except the last K/2 bits, so for each of them U = \text{reverse}(011) = 110. Since only the last K/2 bits vary, the range is equivalent to [0, 6], and we need to increase the frequency of masks 0 \oplus U, 1 \oplus U, \ldots, 6 \oplus U by one. In order, these masks are [110, 111, 100, 101, 010, 011, 000] = [6, 7, 4, 5, 2, 3, 0].

The interval [0, 6] is broken greedily into [0, 3], [4, 5] and [6, 6] (lengths 4, 2 and 1). Each of these intervals represents a contiguous interval in our frequency array, whose start point depends on U.

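Now consider bits b from K/2-1 down to 0, keeping track of the mask bits fixed so far (initially none). If the decomposition contains an interval of length 2^b, it covers all masks whose bits above b equal the ones fixed so far, whose b-th bit equals U’s b-th bit, and whose lower b bits are arbitrary; all subsequent intervals then have their b-th bit opposite to U’s. Otherwise, all subsequent intervals have their b-th bit equal to U’s.
In our example (U = 110),

  • There’s an interval of length 2^2 = 4, so its 2nd bit equals U’s 2nd bit (1), i.e. interval [4, 7]. Later intervals have 2nd bit 0.
  • There’s an interval of length 2^1 = 2: its 2nd bit is 0 and its 1st bit equals U’s 1st bit (1), i.e. interval [2, 3]. Later intervals have 1st bit 0.
  • There’s an interval of length 2^0 = 1: its 2nd and 1st bits are 0 and its 0th bit equals U’s 0th bit (0), i.e. interval [0, 0].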

Together, [4, 7], [2, 3] and [0, 0] cover exactly the masks \{6, 7, 4, 5, 2, 3, 0\} listed above, so they are the actual updates needed.
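The bit walk fits in a few lines. Here is a C++ sketch (hypothetical name `maskIntervals`; it follows the same idea as the editorialist's addUpdate, including the special case where the last block is actually complete). Here h = K/2 and c = (X \bmod 2^{K/2}) + 1 is the number of strings in the last block, with 1 \leq c \leq 2^h:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Disjoint mask intervals [lo, hi] whose union is {U ^ y : 0 <= y < c},
// for h-bit U and 1 <= c <= 2^h, found by walking bits from high to low.
std::vector<std::pair<uint64_t, uint64_t>> maskIntervals(uint64_t U, uint64_t c, int h) {
    std::vector<std::pair<uint64_t, uint64_t>> out;
    if (c == (1ULL << h)) {  // the last block is complete: every mask once
        out.push_back({0, (1ULL << h) - 1});
        return out;
    }
    uint64_t cur = U;  // bits above b: fixed so far; bit b and below: still U's
    for (int b = h - 1; b >= 0; b--) {
        if ((c >> b) & 1) {
            // Piece of length 2^b: masks agree with cur on bits >= b,
            // and their lower b bits are free.
            uint64_t lo = cur & ~((1ULL << b) - 1);
            out.push_back({lo, lo | ((1ULL << b) - 1)});
            cur ^= 1ULL << b;  // later pieces have bit b opposite to U's
        }
    }
    return out;
}
```

maskIntervals(6, 7, 3) returns [4, 7], [2, 3], [0, 0], matching the example.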


Now we can update the frequencies of the masks efficiently, but recomputing the number of pairs from scratch every time is still a pain. Notice that \binom{x+y}{2} = \binom{x}{2} + \binom{y}{2} + x \cdot y. So if the frequency of a mask m is increased by x, the number of pairs increases by \binom{x}{2} + x \cdot f_{m}. Similarly, if the frequency of every mask in a range [L, R] is increased by x, the number of pairs increases by \displaystyle \binom{x}{2} \cdot (R-L+1) + x \cdot \sum_{m = L}^R f_m, where f_m are the frequencies before the update.

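Removal works the same way in reverse: if x strings of mask m are removed, so that its frequency drops to f_m (the value after the removal), the number of pairs decreases by \binom{x}{2} + x \cdot f_m. For a range [L, R], it decreases by \displaystyle \binom{x}{2} \cdot (R-L+1) + x \cdot \sum_{m = L}^R f_m, with the frequencies taken after the removal.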
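A plain-array sketch of this bookkeeping, using exact 64-bit counts with no modulus, for small inputs only (`PairCounter` is a hypothetical name). Note that addition uses the frequencies before the update, while removal uses the frequencies after it:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Plain-array illustration of the pair-count bookkeeping (exact, no modulus).
struct PairCounter {
    std::vector<int64_t> f;  // f[m] = number of strings with mask m
    int64_t pairs = 0;       // always equals the sum over m of C(f[m], 2)

    explicit PairCounter(int size) : f(size, 0) {}

    static int64_t c2(int64_t x) { return x * (x - 1) / 2; }

    int64_t rangeSum(int L, int R) const {
        int64_t s = 0;
        for (int m = L; m <= R; m++) s += f[m];
        return s;
    }

    // Append x strings to every mask in [L, R]: pairs grow by
    // C(x,2) * (R-L+1) + x * (sum of f over [L, R] before the update).
    void add(int L, int R, int64_t x) {
        pairs += c2(x) * (R - L + 1) + x * rangeSum(L, R);
        for (int m = L; m <= R; m++) f[m] += x;
    }

    // Remove x strings from every mask in [L, R]: pairs shrink by
    // C(x,2) * (R-L+1) + x * (sum of f over [L, R] after the update).
    void remove(int L, int R, int64_t x) {
        for (int m = L; m <= R; m++) f[m] -= x;
        pairs -= c2(x) * (R - L + 1) + x * rangeSum(L, R);
    }

    // Recomputes the answer from scratch, for checking.
    int64_t recompute() const {
        int64_t s = 0;
        for (int64_t v : f) s += c2(v);
        return s;
    }
};
```

For instance, adding 7 to all of [0, 7] and then 1 to [4, 7], [2, 3] and [0, 0] reproduces the K = 7, X = 62 example and leaves 217 pairs.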

From our example, we have two kinds of operations to process on the array indexed by [0, 2^{K/2}-1]:

  • Increase the frequency of every mask in a range [L, R] by x.
  • Query the sum of frequencies over a range [L, R], which the formula above needs in order to update the number of pairs.

A segment tree with lazy propagation (range add, range sum) over the frequencies, together with a single running pair count, is a decent choice for both operations.
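Before bringing in the segment tree, the pieces so far can be glued together on a plain array for small K, which also makes a handy reference to test a faster solution against. This is a hedged sketch, not any of the reference solutions (`PlainSolver` and its members are hypothetical). Each query [L, R] is applied as appending [0, R] and removing [0, L-1], and the loops over [lo, hi] inside `rangeAdd` are exactly the two operations a lazy segment tree would replace:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Whole update procedure on a plain array (small K only, exact counts).
struct PlainSolver {
    int K, h;
    std::vector<int64_t> f;  // f[m] = number of strings with mask m
    int64_t pairs = 0;       // sum over m of C(f[m], 2)

    explicit PlainSolver(int k) : K(k), h(k / 2), f(1ULL << (k / 2), 0) {}

    static int64_t c2(int64_t x) { return x * (x - 1) / 2; }

    static uint64_t reverseBits(uint64_t x, int B) {
        uint64_t r = 0;
        for (int b = 0; b < B; b++, x >>= 1) r = (r << 1) | (x & 1);
        return r;
    }

    // sign = +1: add x to every mask in [lo, hi]; sign = -1: subtract x.
    // Pairs change by C(x,2)*(hi-lo+1) + x*(sum of the smaller frequencies).
    void rangeAdd(uint64_t lo, uint64_t hi, int64_t x, int sign) {
        if (sign < 0)
            for (uint64_t m = lo; m <= hi; m++) f[m] -= x;
        int64_t s = 0;
        for (uint64_t m = lo; m <= hi; m++) s += f[m];
        int64_t delta = c2(x) * (int64_t)(hi - lo + 1) + x * s;
        if (sign > 0) {
            pairs += delta;
            for (uint64_t m = lo; m <= hi; m++) f[m] += x;
        } else {
            pairs -= delta;
        }
    }

    // Append (sign = +1) or remove (sign = -1) every string in [0, X].
    void applyPrefix(uint64_t X, int sign) {
        uint64_t n = 1ULL << h, full = X >> h, c = (X & (n - 1)) + 1;
        if (full > 0) rangeAdd(0, n - 1, (int64_t)full, sign);  // complete blocks
        if (c == n) { rangeAdd(0, n - 1, 1, sign); return; }
        uint64_t cur = reverseBits(X >> (K - h), h);  // U of the last block
        for (int b = h - 1; b >= 0; b--) {
            if ((c >> b) & 1) {
                uint64_t lo = cur & ~((1ULL << b) - 1);
                rangeAdd(lo, lo | ((1ULL << b) - 1), 1, sign);
                cur ^= 1ULL << b;
            }
        }
    }

    int64_t query(uint64_t L, uint64_t R) {
        applyPrefix(R, +1);
        if (L > 0) applyPrefix(L - 1, -1);
        return pairs;
    }
};
```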

Reducing memory Consumption

In all of the above, we assumed we could store an array of length 2^{K/2}, but life ain’t a bed of roses: for large K, such an array does not fit in memory.



But notice that each query results in O(K) range updates, for a total of O(N \cdot K) updates. Their endpoints split the frequency array into O(N \cdot K) disjoint intervals such that all frequencies within an interval always stay equal, so each interval can be handled as a single unit (coordinate compression). A small modification to the segment tree, using the sum of leaf weights in a subtree (the number of masks it represents) instead of the range length, lets us handle these weighted leaves. Refer to my implementation for details.
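A sketch of the weighted variant in C++. The editorialist’s own LazySegmentTree is in Java; this is an independent rewrite with hypothetical names, assuming values are kept modulo 998244353. Leaf i stands for w[i] masks that always share one frequency (for example w[i] = val[i+1] - val[i] after sorting the distinct interval endpoints), and a lazy add of x to a node increases its sum by x times the node’s total weight. Removal can be done by adding MOD - x:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

const uint64_t MOD = 998244353;

// Lazy segment tree over compressed leaves: range add, weighted range sum.
struct WeightedLazySegTree {
    int n;
    std::vector<uint64_t> sum, lazy, weight;  // weight[node] = total leaf weight below node

    explicit WeightedLazySegTree(const std::vector<uint64_t>& w)
        : n((int)w.size()), sum(4 * n), lazy(4 * n), weight(4 * n) {
        build(1, 0, n - 1, w);
    }
    void build(int node, int l, int r, const std::vector<uint64_t>& w) {
        if (l == r) { weight[node] = w[l] % MOD; return; }
        int mid = (l + r) / 2;
        build(2 * node, l, mid, w);
        build(2 * node + 1, mid + 1, r, w);
        weight[node] = (weight[2 * node] + weight[2 * node + 1]) % MOD;
    }
    // Add x to every mask below node: the sum grows by x * (number of masks).
    void apply(int node, uint64_t x) {
        sum[node] = (sum[node] + x * weight[node]) % MOD;
        lazy[node] = (lazy[node] + x) % MOD;
    }
    void push(int node) {  // only called on internal nodes
        if (lazy[node]) {
            apply(2 * node, lazy[node]);
            apply(2 * node + 1, lazy[node]);
            lazy[node] = 0;
        }
    }
    void update(int ql, int qr, uint64_t x, int node, int l, int r) {
        if (qr < l || r < ql) return;
        if (ql <= l && r <= qr) { apply(node, x % MOD); return; }
        push(node);
        int mid = (l + r) / 2;
        update(ql, qr, x, 2 * node, l, mid);
        update(ql, qr, x, 2 * node + 1, mid + 1, r);
        sum[node] = (sum[2 * node] + sum[2 * node + 1]) % MOD;
    }
    uint64_t query(int ql, int qr, int node, int l, int r) {
        if (qr < l || r < ql) return 0;
        if (ql <= l && r <= qr) return sum[node];
        push(node);
        int mid = (l + r) / 2;
        return (query(ql, qr, 2 * node, l, mid) + query(ql, qr, 2 * node + 1, mid + 1, r)) % MOD;
    }
    void add(int l, int r, uint64_t x) { update(l, r, x, 1, 0, n - 1); }
    uint64_t rangeSum(int l, int r) { return query(l, r, 1, 0, n - 1); }
};
```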

Important Note

This is just one way to approach the problem. Readers experienced with tries may instead use a dynamic trie, which avoids both the segment tree and the coordinate compression, since all it needs to maintain is the sum of frequencies.


Time Complexity

The time complexity is O(N \cdot K \cdot \log(N \cdot K)) per test case, with a high constant.
The memory complexity is O(N \cdot K) per test case, since we only need to store the O(N \cdot K) updates and a segment tree over O(N \cdot K) compressed leaves.


Setter's Solution
#include <bits/stdc++.h>
using namespace std;
#define int long long 
#define fast ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)

const int MOD = 998244353;
const int K = 62;

int pw2[K+K];

int add(int x, int y) {
	x += y;
	if (x >= MOD) return x - MOD;
	return x;

int sub(int x, int y) {
	x -= y;
	if (x < 0) return x + MOD;
	return x;

int mult(int x, int y) {
	return (x*y) % MOD;

int mod_pow(int x,int y)
	int res = 1;
	while(y > 0)
	    if(y & 1) res = mult(res,x);
	    x = mult(x,x);
	return res;

int inv(int x)
	return mod_pow(x,MOD-2);

typedef struct data
	data* bit[2];
	int cnt = 0,sum = 0;

trie* head;
int n,k;
int cnt_tot,single[K];
int res = 0;   

void insert(trie* cur,int x,int ad,int len,int sign)
	if(sign == 1)
	    cur->sum = add(cur->sum,pw2[len]);
	    cur->sum = sub(cur->sum,pw2[len]);
	for(int i=k/2-1;i>=0 && ad > 0;i--,ad--)
		int b = (x>>i) & 1;
		if(!cur->bit[b]) cur->bit[b] = new trie();
		cur = cur->bit[b];
	    if(sign == 1)
	        cur->sum = add(cur->sum,pw2[len]);
	        cur->sum = sub(cur->sum,pw2[len]);

pair<int,int> get(trie* cur,int x,int v)
	pair<int,int> ret = {0,0};
	for(int i=k/2-1;i>=0&&v>0;i--,v--)
	    cur = cur->bit[(x>>i) & 1];
	    if(!cur) break;
	if(cur) ret.first = cur->sum;
	return ret;

void update(int num,int len,int sign)
	if(len <= k/2)
	    if(sign == -1) 
	        cnt_tot = sub(cnt_tot,pw2[len-(k/2)]);
	        cnt_tot = add(cnt_tot,pw2[len-(k/2)]);

int calc(int num,int i)
	int temp = single[i];
	int bts = (k/2) - i;

	pair<int,int> dat = get(head,num,bts);
	if(i <= k/2)
	    temp = add(temp,dat.first);
	    temp = add(temp,mult(pw2[i],dat.second));
	    temp = add(temp,mult(dat.first,pw2[i-(k/2)]));
	temp = add(temp,mult(cnt_tot,pw2[i]));
	return temp;

void add_seg(int r,int sign)
	int num = 0;
	int num_rev = 0;
	for(int i=k;i>=0;i--)
	    if(num + (1LL<<i) > r) continue;
	    int num2 = (num^num_rev) & ((1LL<<(k/2)) - 1);
	    if(sign == 1) 
	        res = add(res,calc(num2,i)); 
	        res = sub(res,calc(num2,i));
	    if(i != k) num_rev+=(1LL<<(k-i-1));

void solve()
	cin >> n >> k;

	head = new trie(); 
	cnt_tot = 0;
	for(int i=0;i<=k;i++)
	 single[i] = sub(pw2[i+i-min(k/2,i)],pw2[i]);
	 if(single[i] & 1) single[i] += MOD;
	res = 0;
	for(int i=1;i<=n;i++)
	    int l,r;
	    cin >> l >> r;
	    if(l > 0) add_seg(l-1,-1);
		cout << res << '\n';

signed main()
	pw2[0] = 1;
	for(int i=1;i<K+K;i++) pw2[i] = add(pw2[i-1],pw2[i-1]);
	int t = 1;
	cin >> t;
	while(t--) solve();
Tester's Solution
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
#ifndef rd
#define trace(...)
#define endl '\n'
#define pb push_back
#define fi first
#define se second
#define int long long
typedef long long ll;
typedef long double f80;
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
#define fr(a,b,c) for(int a=b; a<=c; a++)
#define rep(a,b,c) for(int a=b; a<c; a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
const ll infl=0x3f3f3f3f3f3f3f3fLL;
const int infi=0x3f3f3f3f;
const int mod=998244353;
//const int mod=1000000007;
typedef vector<int> vi;
typedef vector<ll> vl;

typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> oset;
auto clk=clock();
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
	uniform_int_distribution<int> uid(0,lim-1);
	return uid(rang);

int powm(int a, int b) {
	int res=1;
	while(b) {
	return res;

long long readInt(long long l, long long r, char endd) {
	long long x=0;
	int cnt=0;
	int fi=-1;
	bool is_neg=false;
	while(true) {
		char g=getchar();
		if(g=='-') {
		if('0'<=g&&g<='9') {
			if(cnt==0) {
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);

			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd) {
			if(is_neg) {
			return x;
		} else {
string readString(int l, int r, char endd) {
	string ret="";
	int cnt=0;
	while(true) {
		char g=getchar();
		if(g==endd) {
	return ret;
long long readIntSp(long long l, long long r) {
	return readInt(l,r,' ');
long long readIntLn(long long l, long long r) {
	return readInt(l,r,'\n');
string readStringLn(int l, int r) {
	return readString(l,r,'\n');
string readStringSp(int l, int r) {
	return readString(l,r,' ');

int ans=0;
int cntr=0;
int to[20000005][2];
int sum[20000005],sum2[20000005];
int pp;
// sum: lower part
// sum2: upper part
void add(int at, int num, int i, int cnt, int pots, int tots) {
	if(i==cnt) {
	} else
int sum_n=0;
int subtask_n=25000,subtask_k=60,subtask_s=25000;
void solve() {
	int n=readIntSp(1,subtask_n),k=readIntLn(1,subtask_k);
	while(n--) {
		int l=readIntSp(0,(1LL<<k)-1);
		int r=readIntLn(l,(1LL<<k)-1);
		int num=r-l;
		vector<pii> poo;
		while(l<r) {
			int tea=__builtin_ffsll(l)-1,teb=__builtin_ffsll(r)-1;
			if(l==0||tea>teb) {
			} else {
		for(auto i:poo) {
			int tol=0;
			for(int j=i.se; j<pp; j++) {
			if(i.se>pp) {
				int cnt=(1LL<<(i.se-pp));

signed main() {
	int t=readIntLn(1,10);
#ifdef rd
	cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
Editorialist's Solution
import java.util.*;
import java.io.*;
	final long MOD = 998244353, inv2 = (MOD+1)/2;
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	int N = ni(), K = ni();
	long[] delta = new long[N];//delta[i] denote the increase in frequency of all masks in ith update
	ArrayList<long[]>[] upd = new ArrayList[N];
	long DOWN = (1L<<(K/2))-1;
	for(int i = 0; i< N; i++){
		upd[i] = new ArrayList<>();
		long L = nl()-1, R = nl();
		delta[i] += R>>(K/2);
		if(L != -1)delta[i] -= L>>(K/2);
		addUpdate(upd[i], reverseBits(R>>((K+1)/2), K/2), 1+(R&DOWN), K/2, 1);
		if(L != -1)addUpdate(upd[i], reverseBits(L>>((K+1)/2), K/2), 1+(L&DOWN), K/2, -1);
	    long[] val = new long[N*K*4];int cnt = 0;
	    for(ArrayList<long[]> up:upd)for(long[] u:up){
	        val[cnt++] = u[0];
	        val[cnt++] = u[1]+1;
	    val = Arrays.copyOf(val, cnt);
	    cnt = 1;
	    for(int i = 1; i< val.length; i++)if(val[i] != val[cnt-1])val[cnt++] = val[i];
	    val = Arrays.copyOf(val, cnt);
	long[] weight = new long[val.length];
	for(int i = 0; i+1 < val.length; i++)weight[i] = val[i+1]-val[i];
	    //ith leaf in segment tree represent all masks in interval [val[i], val[i+1]-1], so weight[i] = val[i+1]-1 - val[i] +1 = val[i+1]-val[i]
	LazySegmentTree st = new LazySegmentTree(cnt, weight);
	long ans = 0;
	for(int i = 0; i< N; i++){
	        //Handling update to whole array
		ans += nC2(delta[i])*(1L<<(K/2))%MOD;
		if(ans >= MOD)ans -= MOD;
		ans += delta[i]*st.sum(0, cnt-1)%MOD;
		if(ans >= MOD)ans -= MOD;
		st.update(0, cnt-1, delta[i]);
	        //Handling intervals 
		for(long[] u:upd[i]){
	            int l = Arrays.binarySearch(val, u[0]), r = Arrays.binarySearch(val, u[1]+1)-1;
		if(u[2] == 1)ans += st.sum(l, r);
		st.update(l, r, u[2]);
		if(u[2] == -1)ans += MOD-st.sum(l, r);
		if(ans >= MOD)ans -= MOD;
	long nC2(long n){
	return n*(n-1)%MOD * inv2%MOD;
	void addUpdate(ArrayList<long[]> upd, long mask, long count, int B, int sign){
	    if(count == 1L<<B){
	        upd.add(new long[]{0, (1L<<B)-1, sign});
	    long prefix = 0;
	    for(int b = B-1; b >= 0; b--){
	        prefix |= 1L<<b;
	            upd.add(new long[]{mask&prefix, (mask&prefix)|((1L<<b)-1), sign});
	            mask ^= 1<<b;
	long reverseBits(long x, int B){
	long rev = 0;
	for(int b = 0; b< B; b++){
		rev |= (x&1);
	return rev;
	class LazySegmentTree{
	    //node i is parent of nodes 2*i and 2*i+1, rooted at 1
	    //w[i] denote sum of weights of leaves in subtree of node i
	int m = 1;
	long[] t, lazy, w;//w denote weights
	public LazySegmentTree(int n, long[] weight){
		t = new long[m<<1];
		lazy = new long[m<<1];
		w = new long[m<<1];
		for(int i = 0; i< n; i++)w[i+m] = weight[i]%MOD;
		for(int i = m-1; i> 0; i--)w[i] = (w[i<<1]+w[i<<1|1])%MOD;
	private void push(int i, int ll, int rr){
		if(lazy[i] != 0){
		t[i] += lazy[i]*w[i]%MOD;           //weighted update
		if(t[i] >= MOD)t[i] -= MOD;
		if(i < m){
			lazy[i<<1] = (lazy[i]+lazy[i<<1])%MOD;
			lazy[i<<1|1] = (lazy[i<<1|1]+lazy[i])%MOD;
		lazy[i] = 0;
	public void update(int l, int r, long x){u(l, r, 0, m-1, 1, x);}
	public long sum(int l, int r){return s(l, r, 0, m-1, 1);}

	private void u(int l, int r, int ll, int rr, int i, long x){
		push(i, ll, rr);
		if(l == ll && r == rr){
		lazy[i] += x;
		push(i, ll, rr);return;
		int mid = (ll+rr)/2;
		if(r <= mid){
		u(l, r, ll, mid, i<<1, x);
		push(i<<1|1, mid+1, rr);
		}else if(l > mid){
		push(i<<1, ll, mid);
		u(l, r, mid+1, rr, i<<1|1, x);
		u(l, mid, ll, mid, i<<1, x);
		u(mid+1, r, mid+1, rr, i<<1|1, x);
		t[i] = (t[i<<1]+t[i<<1|1])%MOD;
	private long s(int l, int r, int ll, int rr, int i){
		push(i, ll, rr);
		if(l == ll && r == rr)return t[i];
		int mid = (ll+rr)/2;
		if(r <= mid)return s(l ,r, ll, mid, i<<1);
		else if(l > mid)return s(l, r, mid+1, rr, i<<1|1);
		else return (s(l, mid, ll, mid, i<<1)+s(mid+1, r, mid+1, rr, i<<1|1))%MOD;
	void dbg(Object... o){System.err.println(Arrays.deepToString(o));}
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	public static void main(String[] args) throws Exception{
	    new XORPALIN().run();
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	        return st.nextToken();

	    String nextLine() throws Exception{
	        String str = "";
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        return str;

Feel free to share your approach. Suggestions are welcome as always. :)


We can prove that each binary string of length K has a unique mask.

I’m not sure why this statement is there and I think it’s wrong. For example, 00 and 11 have the same mask 0.

You misinterpreted the statement slightly.

Correct, strings 00 and 11 have the same mask 0, but no string has multiple masks.

A string corresponds to exactly one mask, while a mask may correspond to multiple strings (each mask in this problem corresponds to 2^{(K+1)/2} strings).


Thank you for the explanation.