XM - EDITORIAL

PROBLEM LINK

Practice

Contest: Division 1

Contest: Division 2

Setter: Vivek Chauhan

Tester: Michael Nematollahi

Editorialist: Taranpreet Singh

DIFFICULTY

Medium.

PREREQUISITES

Precomputation, prefix arrays, and Dynamic Programming.

PROBLEM

Given an array A of length N, we have to answer Q queries of the following form: given a range [L, R], output \sum_{i=L}^{R} (A_i \oplus (i-L)), where \oplus is the xor operation.
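As a baseline, here is a minimal brute-force sketch. The function name `bruteQuery` and the 1-indexed `std::vector` layout with an unused `A[0]` are assumptions for illustration, not part of the original solutions. It answers one query in O(R-L+1). That is too slow for the full constraints, but it is handy for checking a faster implementation.

```cpp
#include <cassert>
#include <vector>

// Brute-force answer for one query [L, R] (1-indexed, A[0] unused):
// sum of A[i] xor (i - L) over i in [L, R]. O(R - L + 1) per query.
long long bruteQuery(const std::vector<long long>& A, int L, int R) {
    assert(1 <= L && L <= R && R < (int)A.size());
    long long total = 0;
    for (int i = L; i <= R; ++i)
        total += A[i] ^ static_cast<long long>(i - L);
    return total;
}
```

For example, with A = [5, 3, 7, 1] (stored as {0, 5, 3, 7, 1}), the query [2, 4] gives (3 \oplus 0) + (7 \oplus 1) + (1 \oplus 2) = 3 + 6 + 3 = 12.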

EXPLANATION

Xor affects each bit independently of the other bits. So we can count, for each bit separately, how many terms of the sum have that bit set, and then combine these counts to get the final answer.
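The following sketch illustrates the per-bit decomposition. It uses the hypothetical name `queryByBits` and assumes 0 \leq A_i < 2^{31}. For each bit b, it counts the positions where bit b of A_i \oplus (i-L) is set and adds count \cdot 2^b. The result equals the direct sum of xors. It still costs O((R-L+1) \cdot B) per query, so it only shows that the decomposition is valid.

```cpp
#include <cassert>
#include <vector>

// Same answer as summing the xors directly, but computed bit by bit:
// for every bit b, count positions where bit b of (A[i] xor (i - L)) is 1,
// then add count * 2^b. Assumes 0 <= A[i] < 2^31 (bits = 31); A is 1-indexed.
long long queryByBits(const std::vector<long long>& A, int L, int R) {
    const int bits = 31;
    long long total = 0;
    for (int b = 0; b < bits; ++b) {
        long long count = 0;
        for (int i = L; i <= R; ++i)
            count += ((A[i] ^ static_cast<long long>(i - L)) >> b) & 1;
        total += count << b;
    }
    return total;
}
```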

For a query [L, R], A_L is xored with 0, A_{L+1} is xored with 1, A_{L+2} is xored with 2 and so on.

So we are basically xoring the interval, element by element, with the sequence 0 1 2 3 4 5 6 7 8 9 ...

Considering only 0-th bit of this sequence, it becomes 0 1 0 1 0 1 0 1 0 1
Considering only 1-st bit of this sequence, it becomes 0 0 1 1 0 0 1 1 0 0
Considering only 2-nd bit of this sequence, it becomes 0 0 0 0 1 1 1 1 0 0
\ldots
Considering only k-th bit of this sequence, it becomes 2^k 0s, 2^k 1s, 2^k 0s, 2^k 1s and so on.

Basically, for the k-th bit, the first 2^k positions of the interval are xor-ed with 0, the next 2^k positions with 1, the next 2^k positions with 0, and so on.
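Here is a small check of this pattern, written as a sketch with the hypothetical helper `offsetBitPattern`. Bit k of an offset t equals \lfloor t / 2^k \rfloor \bmod 2, which is the parity of the index of the size-2^k block containing t.

```cpp
#include <cassert>
#include <string>

// Bit k of the offsets 0, 1, ..., len - 1, written as a string of '0'/'1'.
// Bit k of t equals (t / 2^k) % 2, so the string consists of alternating runs
// of 2^k zeros and 2^k ones, starting with zeros.
std::string offsetBitPattern(int k, int len) {
    std::string s;
    for (int t = 0; t < len; ++t) {
        assert(((t >> k) & 1) == (t / (1 << k)) % 2);  // shift form == block parity
        s += ((t >> k) & 1) ? '1' : '0';
    }
    return s;
}
```

For instance, `offsetBitPattern(2, 10)` reproduces the third line of the listing above: `0000111100`.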

Now, define two functions:

- set(b, L, R), the number of positions i in [L, R] such that A_i has the b-th bit set.
- unset(b, L, R), the number of positions i in [L, R] such that A_i has the b-th bit not set.

Both can be answered in O(1) with prefix counts for each bit.
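Here is a minimal sketch of set and unset using per-bit prefix counts. The struct `BitCounts` and its members are hypothetical names, and A is assumed 1-indexed with A[0] unused. Like the setter's countSet/countUnset, it clamps the right end to N and returns 0 for an empty range.

```cpp
#include <cassert>
#include <vector>

// pre[b][i] = number of positions j in [1, i] such that A[j] has bit b set.
// A is 1-indexed (A[0] unused). set/unset queries are then O(1).
struct BitCounts {
    int n;
    std::vector<std::vector<int>> pre;

    BitCounts(const std::vector<long long>& A, int bits)
        : n((int)A.size() - 1), pre(bits, std::vector<int>(A.size(), 0)) {
        for (int b = 0; b < bits; ++b)
            for (int i = 1; i <= n; ++i)
                pre[b][i] = pre[b][i - 1] + (int)((A[i] >> b) & 1);
    }

    // set(b, L, R): right end clamped to n; an empty range gives 0.
    int setCount(int b, int L, int R) const {
        assert(L >= 1);
        if (R > n) R = n;
        if (L > R) return 0;
        return pre[b][R] - pre[b][L - 1];
    }

    // unset(b, L, R) = (length of the clamped range) - set(b, L, R).
    int unsetCount(int b, int L, int R) const {
        if (R > n) R = n;
        if (L > R) return 0;
        return (R - L + 1) - setCount(b, L, R);
    }
};
```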

Now, to answer a query for the k-th bit, divide the range [L, R] into blocks of size 2^k. The first block remains as it is, all bits in the second block are flipped, the third block is not flipped, the fourth block is flipped, and so on.

Since the flipping happens on whole blocks, and the block layout depends only on the bit and the starting position, this points us toward precomputation.

Let us use Dynamic Programming with state (b, p, st). Consider the b-th bit and the range [p, N], and divide [p, N] into blocks of size 2^b (the last one may be partial). The state stores the number of positions having the b-th bit set after flipping. If st = 0, the first block is not flipped, the second is flipped, the third is not, and so on. If st = 1, the first block is flipped, the second is not, the third is flipped, and so on.

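For example, consider bit b, position p and st = 0.
State (b, p, 0) represents set(b, p, p+2^b-1) + unset(b, p+2^b, p+2 \cdot 2^b-1) + set(b, p+2 \cdot 2^b, p+3 \cdot 2^b-1) + \ldots In other words, we divide the range [p, N] into blocks of size 2^b and flip every alternate block, starting from the second block (with st = 1, starting from the first). For b = 2 this is set(2, p, p+3) + unset(2, p+4, p+7) + set(2, p+8, p+11) + \ldots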
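To make the definition concrete, the sketch below evaluates dp(b, p, st) directly from it, without the recurrence. It uses the hypothetical name `dpByDefinition` and a 1-indexed A with unused A[0]. Block number t = 0, 1, 2, \ldots of [p, N] is flipped when t + st is odd. Each state costs O(N), so this is only meant for understanding or testing.

```cpp
#include <cassert>
#include <vector>

// dp(b, p, st) evaluated straight from its definition (no recurrence).
// [p, N] is cut into blocks of size 2^b; block t = 0, 1, 2, ... is flipped
// when (t + st) is odd. Counts positions whose bit b is 1 after flipping.
// A is 1-indexed (A[0] unused); p may be N + 1, which gives 0. O(N) per call.
int dpByDefinition(const std::vector<long long>& A, int b, int p, int st) {
    int n = (int)A.size() - 1;
    assert(1 <= p && p <= n + 1 && (st == 0 || st == 1));
    int count = 0;
    for (int i = p; i <= n; ++i) {
        int block = (i - p) >> b;          // index t of the block containing i
        int flipped = (block + st) & 1;    // 1 if that block is flipped
        count += (int)((A[i] >> b) & 1) ^ flipped;
    }
    return count;
}
```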

Suppose that for a query, r-l+1 = k \cdot 2^b. Then the query range splits into k full blocks. The first block is not flipped, the second block is flipped, and so on, so the k-th block is flipped exactly when k is even. The block that starts at r+1 is the (k+1)-th block.

- If k is even, that block is not flipped, so the answer is dp(b, l, 0) - dp(b, r+1, 0).
- If k is odd, that block is flipped, which gives dp(b, l, 0) - dp(b, r+1, 1).

The idea is that we just remove the contribution of the range [r+1, N] from that of [l, N], taking dp(b, p, st) = 0 for p > N.
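The exact-multiple case can be sketched as follows, assuming the same 1-indexed layout. The names `dpDirect` and `countBitFullBlocks` are hypothetical. To keep the example short, dp is evaluated by a direct loop instead of the precomputed table.

```cpp
#include <cassert>
#include <vector>

// Direct evaluation of dp(b, p, st); block t of [p, N] is flipped iff t + st is odd.
int dpDirect(const std::vector<long long>& A, int b, int p, int st) {
    int n = (int)A.size() - 1, count = 0;
    for (int i = p; i <= n; ++i)
        count += (int)((A[i] >> b) & 1) ^ ((((i - p) >> b) + st) & 1);
    return count;
}

// Positions i in [l, r] where bit b of (A[i] xor (i - l)) is 1, for the case
// r - l + 1 = k * 2^b. Blocks 0..k-1 fill [l, r]; block k starts at r + 1 and
// is flipped iff k is odd, so we subtract dp(b, r + 1, k % 2).
int countBitFullBlocks(const std::vector<long long>& A, int b, int l, int r) {
    int len = r - l + 1;
    int k = len >> b;
    assert(len > 0 && (len & ((1 << b) - 1)) == 0);  // len must be a multiple of 2^b
    return dpDirect(A, b, l, 0) - dpDirect(A, b, r + 1, k & 1);
}
```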

But what if r-l+1 is not a multiple of 2^b? In this case, there is a partial block at the end. Let rr be the largest position with rr \leq r such that rr-l is a multiple of 2^b, i.e. rr = l + \lfloor (r-l)/2^b \rfloor \cdot 2^b. The range [l, rr-1] consists of k = (rr-l)/2^b full blocks, so its answer follows from the idea above: dp(b, l, 0) - dp(b, rr, k \bmod 2). The good news is that the remaining positions [rr, r] (at most 2^b of them) all belong to the same block. That block is flipped if k is odd and not flipped if k is even. So we add unset(b, rr, r) if k is odd and set(b, rr, r) otherwise, which answers the query.
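Here is the general case for one bit, as a sketch under the same assumptions. The names `countBit`, `dpDirect`, `setDirect` and `bitOf` are hypothetical. For brevity, dp and set are computed by direct loops here, whereas a real solution would read them from the table and the prefix counts. It mirrors the editorialist's query loop: first k full blocks, then the final block [rr, r].

```cpp
#include <cassert>
#include <vector>

// Bit b of x as 0/1.
int bitOf(long long x, int b) { return (int)((x >> b) & 1); }

// Direct dp(b, p, st): block t of [p, N] is flipped iff t + st is odd.
int dpDirect(const std::vector<long long>& A, int b, int p, int st) {
    int n = (int)A.size() - 1, count = 0;
    for (int i = p; i <= n; ++i)
        count += bitOf(A[i], b) ^ ((((i - p) >> b) + st) & 1);
    return count;
}

// Direct set(b, l, r) for l <= r <= N.
int setDirect(const std::vector<long long>& A, int b, int l, int r) {
    int count = 0;
    for (int i = l; i <= r; ++i) count += bitOf(A[i], b);
    return count;
}

// Positions i in [l, r] where bit b of (A[i] xor (i - l)) is 1, for any l <= r.
int countBit(const std::vector<long long>& A, int b, int l, int r) {
    assert(1 <= l && l <= r && r < (int)A.size());
    int k = (r - l) >> b;      // number of full blocks in [l, rr - 1]
    int rr = l + (k << b);     // start of the final block [rr, r], of size 1..2^b
    int st = k & 1;            // that block is flipped iff k is odd
    int full = dpDirect(A, b, l, 0) - dpDirect(A, b, rr, st);
    int ones = setDirect(A, b, rr, r);
    int tail = st ? (r - rr + 1) - ones : ones;  // unset(b, rr, r) or set(b, rr, r)
    return full + tail;
}
```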

To compute the DP, we can write dp(b, p, 0) = set(b, p, p+2^b-1) + dp(b, p+2^b, 1). This adds the contribution of the range [p, p+2^b-1], which is not flipped, and that of the range [p+2^b, N], whose first block is flipped.

Similarly, dp(b, p, 1) = unset(b, p, p+2^b-1) + dp(b, p+2^b, 0). Using these transitions, with the right end clamped to N and dp(b, p, st) = 0 for p > N, we can compute the whole table from the end to the start, for each bit separately.
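Here is a sketch of the backward table construction for all bits. It uses the hypothetical name `buildDp` and a 1-indexed A with unused A[0]. The table has N+2 entries per bit, so position N+1 holds the base case 0. As in the solutions, the right end of the first block is clamped to N.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

using Table = std::vector<std::vector<std::array<int, 2>>>;

// dp[b][p][st] for p = N..1, built from the end using
//   dp(b, p, 0) = set(b, p, p + 2^b - 1)   + dp(b, p + 2^b, 1)
//   dp(b, p, 1) = unset(b, p, p + 2^b - 1) + dp(b, p + 2^b, 0)
// A is 1-indexed (A[0] unused); dp[b][N + 1] stays {0, 0} as the base case.
Table buildDp(const std::vector<long long>& A, int bits) {
    int n = (int)A.size() - 1;
    // pre[b][i] = number of j in [1, i] with bit b of A[j] set
    std::vector<std::vector<int>> pre(bits, std::vector<int>(n + 1, 0));
    for (int b = 0; b < bits; ++b)
        for (int i = 1; i <= n; ++i)
            pre[b][i] = pre[b][i - 1] + (int)((A[i] >> b) & 1);

    Table dp(bits, std::vector<std::array<int, 2>>(n + 2));  // zero-initialized
    for (int b = 0; b < bits; ++b) {
        long long sz = 1LL << b;
        for (int p = n; p >= 1; --p) {
            int end = (int)std::min<long long>(n, p + sz - 1);  // first block [p, end]
            int ones = pre[b][end] - pre[b][p - 1];
            int zeros = (end - p + 1) - ones;
            long long next = p + sz;                              // start of the next block
            dp[b][p][0] = ones + (next <= n ? dp[b][next][1] : 0);
            dp[b][p][1] = zeros + (next <= n ? dp[b][next][0] : 0);
        }
    }
    return dp;
}
```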

Finally, suppose that for the b-th bit there are c positions in the query range whose bit is set after flipping. Then this bit contributes c \cdot 2^b to the final answer for the query.
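Here is a compact sketch of the whole method. The struct `XorOffsetSum` is a hypothetical name. It assumes 0 \leq A_i < 2^{31} and a 1-indexed A with unused A[0]. Precomputation takes O(N \cdot B) and each query takes O(B).

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

// Prefix counts + dp table per bit; query(l, r) = sum of A[i] xor (i - l).
// Assumes 0 <= A[i] < 2^31; A is 1-indexed (A[0] unused).
struct XorOffsetSum {
    static constexpr int BITS = 31;
    int n;
    std::vector<std::vector<int>> pre;                 // pre[b][i]: set bits in [1, i]
    std::vector<std::vector<std::array<int, 2>>> dp;   // dp[b][p][st], zero at p = n + 1

    explicit XorOffsetSum(const std::vector<long long>& A)
        : n((int)A.size() - 1),
          pre(BITS, std::vector<int>(A.size(), 0)),
          dp(BITS, std::vector<std::array<int, 2>>(A.size() + 1)) {
        for (int b = 0; b < BITS; ++b) {
            for (int i = 1; i <= n; ++i)
                pre[b][i] = pre[b][i - 1] + (int)((A[i] >> b) & 1);
            long long sz = 1LL << b;
            for (int p = n; p >= 1; --p) {
                int end = (int)std::min<long long>(n, p + sz - 1);
                int ones = countSet(b, p, end);
                int zeros = (end - p + 1) - ones;
                long long next = p + sz;
                dp[b][p][0] = ones + (next <= n ? dp[b][next][1] : 0);
                dp[b][p][1] = zeros + (next <= n ? dp[b][next][0] : 0);
            }
        }
    }

    int countSet(int b, int l, int r) const { return pre[b][r] - pre[b][l - 1]; }

    long long query(int l, int r) const {
        long long ans = 0;
        for (int b = 0; b < BITS; ++b) {
            long long k = (r - l) >> b;            // full blocks before rr
            int rr = (int)(l + (k << b));          // final block is [rr, r]
            int st = (int)(k & 1);                 // final block flipped iff k is odd
            int ones = countSet(b, rr, r);
            int tail = st ? (r - rr + 1) - ones : ones;
            long long c = dp[b][l][0] - dp[b][rr][st] + tail;
            ans += c << b;                         // bit b contributes c * 2^b
        }
        return ans;
    }
};
```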

Still confused? We all are, so refer to the implementations for more details.

TIME COMPLEXITY

The time complexity is O((N+Q) \cdot B) per test case, where B = \log_2(\max A_i) is the number of bits.

SOLUTIONS:

Setter's Solution
#include <bits/stdc++.h>
using namespace std;
typedef long long int ll;
typedef long double ld;
const int N = 100005;
ll inf = 1e16;
ll mod = 1e9 + 7;

char en = '\n';
ll power(ll x, ll n, ll mod) {
  ll res = 1;
  x %= mod;
  while (n) {
	if (n & 1)
	  res = (res * x) % mod;
	x = (x * x) % mod;
	n >>= 1;
  }
  return res;
}
int setBit[N][33];
int dp[N][33][2];
ll n;
inline int countSet(ll l, ll r, ll bit) {
  if (l > r)
	return 0;
  r = min(r, n);
  return setBit[r][bit] - setBit[l - 1][bit];
}
inline int countUnset(ll l, ll r, ll bit) {
  if (l > r)
	return 0;
  r = min(r, n);
  return r - l + 1 - (setBit[r][bit] - setBit[l - 1][bit]);
}

inline int getValue(ll pos, ll bit, ll type) {
  if (pos > n)
	return 0;
  return dp[pos][bit][type];
}
int32_t main() {
  ios_base::sync_with_stdio(false);
  cin.tie(NULL);

  ll t;
  cin >> t;
  while (t--) {
	ll q;
	cin >> n >> q;

	memset(dp, 0, sizeof(dp));
	memset(setBit, 0, sizeof(setBit));
	vector<ll> arr(n + 5); // std::vector instead of a non-standard variable-length array
	for (ll i = 1; i <= n; i++)
	  cin >> arr[i];

	for (ll i = 1; i <= n; i++) {
	  for (ll j = 0; j <= 30; j++) {
	    setBit[i][j] = setBit[i - 1][j] + ((arr[i] >> j) & 1);
	  }
	}

	for (ll i = n; i >= 1; i--) {
	  for (ll j = 0; j <= 30; j++) {
	    dp[i][j][0] =
	        countSet(i, i + (1 << j) - 1, j) + getValue(i + (1 << j), j, 1);
	    dp[i][j][1] =
	        countUnset(i, i + (1 << j) - 1, j) + getValue(i + (1 << j), j, 0);
	  }
	}

	while (q--) {
	  ll l, r;
	  cin >> l >> r;
	  ll res = 0;
	  for (ll j = 0; j <= 30; j++) {
	    ll block = (r - l) / (1 << j);
	    ll rightEnd = l + (ll)(block + 1) * (1 << j);

	    ll currRes = 0;
	    if (block % 2 == 0) {
	      currRes = getValue(l, j, 0) - getValue(rightEnd, j, 1) -
	                countSet(r + 1, rightEnd - 1, j);
	    } else {
	      currRes = getValue(l, j, 0) - getValue(rightEnd, j, 0) -
	                countUnset(r + 1, rightEnd - 1, j);
	    }

	    res += (ll)(1 << j) * currRes;
	  }
	  cout << res << en;
	}
  }

  return 0;
}
Tester's Solution
#include<bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int, int> pii;

#define F first
#define S second

const int MAXN = 1e5 + 10;
const int LOG = 17;

int n, q, cnt[LOG][MAXN];
ll a[LOG][MAXN];

int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	int te; cin >> te;
	while (te--){
		cin >> n >> q;
		for (int i = 0; i < n; i++){
			cin >> a[0][i];
			for (int w = 0; w < LOG; w++)
				cnt[w][i+1] = cnt[w][i] + (a[0][i]>>w&1);
		}
		for (int w = 1; w < LOG; w++)
			for (int i = 0; i < n; i++)
				if (i + (1<<w) <= n){
					// second half: its offsets gain bit (w-1), so that bit's contribution is flipped
					a[w][i] = a[w-1][i] + a[w-1][i+(1<<(w-1))];
					int t = cnt[w-1][i+(1<<w)] - cnt[w-1][i+(1<<(w-1))];
					a[w][i] -= 1ll*t*(1<<(w-1));
					a[w][i] += 1ll*((1<<(w-1))-t)*(1<<(w-1));
				}

		while (q--){
			int l, r; cin >> l >> r, l--;
			ll ans = 0;
			int cur = 0;
			for (int w = LOG-1; ~w; w--)
				if (l + (1<<w) <= r){
					ans += a[w][l];
					for (int j = LOG-1; j > w; j--)
						if (cur>>j&1){
							int t = cnt[j][l + (1<<w)] - cnt[j][l];
							ans -= 1ll*t*(1<<j);
							ans += 1ll*((1<<w)-t)*(1<<j);
						}

					l += 1<<w;
					cur ^= 1<<w;
				}
			cout << ans << "\n";
		}
	}
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class XM{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int n = ni(), q = ni(), B = 30;
	    //sum[b][i] -> Number of positions in range [1, i] which have bth bit set
	    int[][] sum = new int[B][2+n];
	    for(int i = 1; i<= n; i++){
	        for(int b = 0; b< B; b++)sum[b][i] = sum[b][i-1];
	        int x = ni();
	        for(int b = 0; b< B; b++, x>>=1)sum[b][i]+=(x&1);
	    }
	    //If considering from position i and bit b, divide range [i, n] into blocks of size pow(2, b)
	    //We can see that these blocks shall be flipped alternatively for current bit.
	    //dp[b][0][i] -> Number of set bits in range [i, n] after flipping, if first block is not flipped, second is flipped, third is not and so on
	    //dp[b][1][i] -> Number of set bits in range [i, n] after flipping, if first block is flipped, second is not flipped, third is flipped and so on
	    int[][][] dp = new int[B][2][2+n];
	    //b- bit, sz - size of block
	    for(int b = 0, sz = 1; b< B; b++, sz <<=1)
	        for(int i = n+1; i> 0; i--){
	            dp[b][0][i] = sum[b][Math.min(n, i+sz-1)]-sum[b][i-1];
	            dp[b][1][i] = (Math.min(n, i+sz-1)-i+1)-(sum[b][Math.min(n, i+sz-1)]-sum[b][i-1]);
	            if(i+sz <= n+1){
	                dp[b][0][i] += dp[b][1][i+sz];
	                dp[b][1][i] += dp[b][0][i+sz];
	            }
	        }
	    while(q-->0){
	        int l = ni(), r = ni();
	        long ans = 0;
	        for(int b = 0, sz = 1; b< B; b++, sz <<= 1){
	            int bl = (r-l)/sz;
	            int rr = l+(bl*sz), cnt = 0;
	            if(bl%2==0) cnt = dp[b][0][l] - dp[b][0][rr] + sum[b][r]-sum[b][rr-1];
	            else cnt = dp[b][0][l] - dp[b][1][rr] + (r-rr+1)-(sum[b][r]-sum[b][rr-1]);
	            ans += cnt*(long)sz;
	        }
	        pn(ans);
	    }
	}
	int countSet(int[][] sum, int b, int l, int r){
	    if(l>r)return 0;
	    return sum[b][r] - sum[b][l-1];
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new XM().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach, if you want to (even if it's the same :stuck_out_tongue:). Suggestions are welcome, as always. :slight_smile:


Very nice problem! Thanks for the editorial.
But dp(b, p, 0) = set(b, p, p + pow(2, b) - 1) + dp(b, p + pow(2, b), 1) is still unclear.
Please help.

See, dp(b, p, 0) means the following. Consider the b-th bit and the range [p, N], divide this range into blocks of size 2^b, and do not flip the first block. Then dp(b, p, 0) is the number of set bits.

So dp(b, p, 0) is the number of set bits in the current block (which is not flipped) plus dp(b, p+2^b, 1), since the next block starts at p+2^b and is flipped.
Hope that makes sense.

That was a nice problem.
I found the editorial a little confusing. For reference, I looked at this solution, which is clean and readable.

For the k-th bit, over the blocks where the offset bit is 0, the numbers with the k-th bit set are added. Over the blocks where the offset bit is 1, the numbers with the k-th bit unset are added, because 1 \oplus x = 1 implies x = 0.

I implemented a slower solution with two phases of prefix sums.

For the low bits, with block size below 512 (2^9), I compute the combined query result for these bits using 512 prefix sums, one per start position (modulo 512).

For the high bits, with block size 512 or more, a query range contains at most about 100000 / 512 blocks.
So the query result for each of these bits can be accumulated directly, block by block.
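Only the high-bit phase of this approach is sketched below, as a hypothetical illustration. The names `buildPrefix` and `highBitsContribution` and the parameter `lowBits` are assumptions, and the commenter's actual code may differ. For every bit b \geq lowBits, the sketch walks [L, R] block by block. In block t the offset bit b equals t \bmod 2, so the block contributes the elements with bit b set (t even) or clear (t odd). With lowBits = 9, each bit has roughly 100000 / 512 blocks or fewer.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// pre[b][i] = number of j in [1, i] with bit b of A[j] set (A 1-indexed, A[0] unused).
std::vector<std::vector<int>> buildPrefix(const std::vector<long long>& A, int bits) {
    int n = (int)A.size() - 1;
    std::vector<std::vector<int>> pre(bits, std::vector<int>(n + 1, 0));
    for (int b = 0; b < bits; ++b)
        for (int i = 1; i <= n; ++i)
            pre[b][i] = pre[b][i - 1] + (int)((A[i] >> b) & 1);
    return pre;
}

// Contribution of bits b >= lowBits to sum of A[i] xor (i - L) over [L, R],
// walking the range in blocks of size 2^b (block t has offset bit b = t & 1).
long long highBitsContribution(const std::vector<std::vector<int>>& pre,
                               int L, int R, int lowBits) {
    long long total = 0;
    for (int b = lowBits; b < (int)pre.size(); ++b) {
        long long sz = 1LL << b, count = 0;
        for (long long s = L, t = 0; s <= R; s += sz, ++t) {
            int e = (int)std::min<long long>(R, s + sz - 1);   // block [s, e]
            int ones = pre[b][e] - pre[b][s - 1];
            count += (t & 1) ? (e - s + 1) - ones : ones;       // flipped block counts zeros
        }
        total += count << b;
    }
    return total;
}
```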

The time complexity is higher than that of the editorial's DP version. The running time was close to the time limit but, luckily, did not exceed it.

my solution
