SPLITADD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: triggered_code and iceknight1093
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

2622

PREREQUISITES:

Recursion

PROBLEM:

You’re given N, L, and R. N is a power of 2.
Consider the following process f(A) for an array A:

  • If |A| = 2, return A.
  • Otherwise, split the array into two halves: the elements at odd indices and the elements at even indices.
    Recursively process these arrays, then return their concatenation.

Find the sum of the L-th through R-th elements of f([1, 2, \ldots, N]), modulo 10^9 + 7.
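To make the process concrete, here is a minimal brute-force sketch, assuming the array being processed is [1, 2, \ldots, N] as the explanation does. The names `f` and `brute` are illustrative only. Arrays of length at most 2 are returned unchanged; splitting a length-2 array into its two single elements and concatenating gives the same array, so this also covers a length-1 base case. It is only usable for small N, but it is handy for checking faster solutions.

```python
def f(a):
    """Brute-force version of the process; only practical for small arrays."""
    if len(a) <= 2:
        return list(a)
    odd_positions = a[0::2]   # elements at 1-based positions 1, 3, 5, ...
    even_positions = a[1::2]  # elements at 1-based positions 2, 4, 6, ...
    return f(odd_positions) + f(even_positions)


def brute(n, l, r, mod=10**9 + 7):
    """Sum of the l-th through r-th elements (1-based) of f([1..n])."""
    result = f(list(range(1, n + 1)))
    return sum(result[l - 1:r]) % mod


print(f(list(range(1, 9))))  # [1, 5, 3, 7, 2, 6, 4, 8]
```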

EXPLANATION:

The given process is recursive, so let’s also try to use recursion to solve the problem.

Subtask 1

We’re looking for a single value here. Let K = L = R be that index.

Let ans(N, K) denote the value we’re looking for.
f(A) partitions the array into two halves, and then recursively applies the same process to each half.
So, we have two cases for which half K is in: either K \leq \frac{N}{2}, or K \gt \frac{N}{2}. Let’s look at them separately.

  • If K \gt \frac{N}{2}, then we want to find the (K - \frac{N}{2})-th element of the array f([2, 4, 6, 8, \ldots, N]).
    Notice that this array we operate on is just \left[1, 2, 3, \ldots, \frac{N}{2}\right], but with all its elements multiplied by 2.
    So, the value we want is just 2\cdot ans\left(\frac{N}{2}, K - \frac{N}{2}\right).
  • If K \leq \frac{N}{2}, we want to find the K-th element of f([1, 3, 5, \ldots, N-1]).
    Everything here is odd - in particular, the i-th element is 2i - 1.
    So, we could instead find the K-th element of f([1, 2, 3, \ldots, \frac{N}{2}]), and then multiply it by 2 and subtract 1.
    That is, we want 2\cdot ans(\frac{N}{2}, K) - 1.

Putting both cases together, we have:

ans(N, K) = \begin{cases} 2\cdot ans(\frac{N}{2}, K) - 1, & \text{if } K \leq \frac{N}{2} \\ 2\cdot ans(\frac{N}{2}, K - \frac{N}{2}), & \text{if } K \gt \frac{N}{2} \end{cases}

The base case is ans(1, 1) = 1, since the process doesn’t change arrays of length 1.
At each step of the recursion, we halve N, so we find the answer in \mathcal{O}(\log N) time.
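The recurrence above can be sketched directly in Python. The function name `ans` mirrors the notation in the text. The value is returned without reduction, so a caller would still take it modulo 10^9 + 7 before printing.

```python
def ans(n, k):
    """k-th element (1-based) of f([1, 2, ..., n]), where n is a power of 2."""
    if n == 1:
        return 1  # base case: a single-element array is unchanged
    half = n // 2
    if k <= half:
        # left half holds the odd numbers: 2 * (value in the smaller problem) - 1
        return 2 * ans(half, k) - 1
    # right half holds the even numbers: 2 * (value in the smaller problem)
    return 2 * ans(half, k - half)


print([ans(8, k) for k in range(1, 9)])  # [1, 5, 3, 7, 2, 6, 4, 8]
```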

There are other ways to solve this subtask: for example, if you stare at the outputs for small N and K, you can observe a pattern: the K-th element is 1 plus the number obtained by reversing the \log_2 N-bit binary representation of K-1.

Subtask 2

Let’s generalize our recursion from subtask 1 to ranges.
Let ans(N, L, R) denote the answer we’re looking for.
Let H = \frac{N}{2} be half of N.
There are three cases:

  • If R \leq H, then the entire range lies in the left half.
    Here, we get ans(N, L, R) = 2\cdot ans(H, L, R) - (R-L+1).
    This is because, as noted earlier, the left half is elements of the form
    [(2\cdot 1 - 1), (2\cdot 2 - 1), (2\cdot 3 - 1), \ldots]
    So, we need to find the answer for [1, 2, \ldots, H]; then multiply everything by 2 and subtract 1 for each element.
  • If L \gt H, the entire range lies in the right half.
    Here, we get ans(N, L, R) = 2\cdot ans(H, L-H, R-H); just as we had in subtask 1.
  • Finally, we have the case when L \leq H \lt R.
    We can split this into two ranges [L, H] and [H+1, R], and then apply both cases above.

Putting it all together, we obtain:

ans(N, L, R) = \begin{cases} 2\cdot ans(H, L, R) - (R-L+1), & \text{if } R \leq H \\ 2\cdot ans(H, L-H, R-H), & \text{if } L \gt H \\ ans(N, L, H) + ans(N, H+1, R), &\text{otherwise } \end{cases}

For now, we can also set ans(1, 1, 1) = 1 as our base case.

While this recursion is correct, it is unfortunately too slow by itself: since we “split” into two branches whenever we encounter the third case, the number of branches can get quite large.
In fact, we’ll actually just visit every single integer in the range [L, R] via this recursion, so its complexity is \mathcal{O}(N) in the worst case, which is certainly too slow.

However, optimizing it is in fact quite easy!
Instead of setting ans(1, 1, 1) = 1 as the base case, we set ans(N, 1, N) = \frac{N\cdot (N+1)}{2} as the base case - that is, if the query is for the whole range, just return its sum without recursing any further.
Though this optimization might seem simple, it brings our time complexity down to \mathcal{O}(\log N).

Proof

If you’re familiar with segment trees, you might notice that this is exactly how segment tree queries work (at least in most recursive implementations), so the proof of complexity being \mathcal{O}(\log N) carries over from there.
A proof can be found in the linked page (the “sum queries” section).

If you don’t want to read the proof, here’s some intuition.
Consider the first time a query ‘splits’ into two, and look at the left branch of the split.
If the left branch splits any further, the right branch of this second split has to be of the form ans(N, 1, N); because it’ll be one entire half.
Since we set this to be our base case, this branch won’t need to proceed any further.

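So, after the first split, every further split has one branch that terminates immediately (the same argument applies symmetrically to the right branch of the first split); only a constant number of calls are made per level, which is where the \mathcal{O}(\log N) bound comes from.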
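One way to see the effect of the base case is to count recursive calls with and without it. The sketch below is an illustration under assumptions: `range_sum`, the `counter` list, and the `full_range_base` flag are names made up for this experiment, and the split case recurses straight into the two halves instead of going through ans(N, L, H) first.

```python
MOD = 10**9 + 7


def range_sum(n, l, r, counter, full_range_base=True):
    """Sum of elements l..r of f([1..n]) modulo MOD; counter[0] counts calls."""
    counter[0] += 1
    if full_range_base and l == 1 and r == n:
        return n * (n + 1) // 2 % MOD  # whole range: closed-form sum
    if n == 1:
        return 1  # fallback base case ans(1, 1, 1) = 1
    h = n // 2
    if r <= h:
        return (2 * range_sum(h, l, r, counter, full_range_base) - (r - l + 1)) % MOD
    if l > h:
        return 2 * range_sum(h, l - h, r - h, counter, full_range_base) % MOD
    # the range straddles the midpoint: handle [l, h] and [h+1, r] separately
    left = 2 * range_sum(h, l, h, counter, full_range_base) - (h - l + 1)
    right = 2 * range_sum(h, 1, r - h, counter, full_range_base)
    return (left + right) % MOD


n = 1 << 16
slow, fast = [0], [0]
a = range_sum(n, 2, n - 1, slow, full_range_base=False)
b = range_sum(n, 2, n - 1, fast, full_range_base=True)
print(a == b, slow[0], fast[0])
```

Both variants return the same value, but the call count with the whole-range base case stays proportional to \log N, while the other grows with R - L.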

Make sure to print the answer modulo 10^9 + 7.
Also watch out for overflow errors, especially when computing \frac{N\cdot (N+1)}{2}, since N can be quite large.
To get around it, perform the division by 2 first (one of N and N+1 is always even), and then reduce both factors modulo 10^9 + 7 before performing the multiplication.
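As a small sketch of the overflow concern: Python integers do not overflow, so the code below only mirrors the order of operations a 64-bit implementation might use. The helper name `full_sum_mod` is hypothetical.

```python
MOD = 10**9 + 7


def full_sum_mod(n):
    """n*(n+1)/2 modulo MOD, multiplying only values already reduced below MOD."""
    # Divide the even factor by 2 first, so the division is exact.
    if n % 2 == 0:
        a, b = n // 2, n + 1
    else:
        a, b = n, (n + 1) // 2
    # Each reduced factor is below MOD (about 2^30), so the product fits in 64 bits.
    return (a % MOD) * (b % MOD) % MOD


print(full_sum_mod(8))  # 36
```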

TIME COMPLEXITY

\mathcal{O}(\log N) per testcase.

CODE:

Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};
constexpr int mod = (int)1e9 + 7;
struct mi {
    int64_t v; explicit operator int64_t() const { return v % mod; }
    mi() { v = 0; }
    mi(int64_t _v) {
        v = (-mod < _v && _v < mod) ? _v : _v % mod;
        if (v < 0) v += mod;
    }
    friend bool operator==(const mi& a, const mi& b) {
        return a.v == b.v; }
    friend bool operator!=(const mi& a, const mi& b) {
        return !(a == b); }
    friend bool operator<(const mi& a, const mi& b) {
        return a.v < b.v; }

    mi& operator+=(const mi& m) {
        if ((v += m.v) >= mod) v -= mod;
        return *this; }
    mi& operator-=(const mi& m) {
        if ((v -= m.v) < 0) v += mod;
        return *this; }
    mi& operator*=(const mi& m) {
        v = v*m.v%mod; return *this; }
    mi& operator/=(const mi& m) { return (*this) *= inv(m); }
    friend mi pow(mi a, int64_t p) {
        mi ans = 1; assert(p >= 0);
        for (; p; p /= 2, a *= a) if (p&1) ans *= a;
        return ans;
    }
    friend mi inv(const mi& a) { assert(a.v != 0);
        return pow(a,mod-2); }

    mi operator-() const { return mi(-v); }
    mi& operator++() { return *this += 1; }
    mi& operator--() { return *this -= 1; }
    mi operator++(int32_t) { mi temp; temp.v = v++; return temp; }
    mi operator--(int32_t) { mi temp; temp.v = v--; return temp; }
    friend mi operator+(mi a, const mi& b) { return a += b; }
    friend mi operator-(mi a, const mi& b) { return a -= b; }
    friend mi operator*(mi a, const mi& b) { return a *= b; }
    friend mi operator/(mi a, const mi& b) { return a /= b; }
    friend ostream& operator<<(ostream& os, const mi& m) {
        os << m.v; return os;
    }
    friend istream& operator>>(istream& is, mi& m) {
        int64_t x; is >> x;
        m.v = x;
        return is;
    }
    friend void __print(const mi &x) {
        cerr << x.v;
    }
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;
    int T = input.readInt(0, (int)1e5); input.readEoln();
    while(T-- > 0) {
        long long N = input.readLong(2, (1ll << 60));   input.readSpace();
        long long L = input.readLong(1, N); input.readSpace();
        long long R = input.readLong(L, N); input.readEoln();

        assert((N & (N - 1)) == 0);
        function<pair<mi, mi>(long long, long long, long long)> solve = [&](long long A, long long L, long long R) -> pair<mi, mi> {
            dbg(L, R, A);
            if(R < 1 || L > A)   return make_pair(mi(0), mi(0));
            if(R >= A && L <= 1) {
                return make_pair(mi(A / 2) * mi(A - 1) + mi(A), mi(A));
            }
            assert(A != 1);
            auto pl = solve(A / 2, L, R);
            auto pr = solve(A / 2, L - A / 2, R - A / 2);
            return make_pair(2 * (pl.first + pr.first) - pl.second, pl.second + pr.second);
        };

        cout << solve(N, L, R).first << '\n';
    }
    input.readEof();

    return 0;
}
Editorialist's code (Python)
mod = 10**9 + 7
def calc(n, l, r):
    if l == 1 and r == n: return (n*(n+1)//2 ) % mod
    mid = n//2
    if r <= mid: return (2*calc(n//2, l, r) - (r-l+1)) % mod
    if l > mid: return (2*calc(n//2, l - mid, r - mid))%mod
    ret = (2*calc(n//2, l, n//2) - (n//2 - l + 1)) % mod + (2*calc(n//2, 1, r - mid) % mod)
    return ret % mod

import sys
input = sys.stdin.readline
for _ in range(int(input())):
    n, l, r = map(int, input().split())
    print(calc(n, l, r))

A slight simplification is to write a function F(N, R) that computes the prefix sum from 1 to R and compute the answer as F(N, R) - F(N, L-1). This also has the advantage of making it obvious that the complexity is O(\log N) because there is only one recursive call at each step.
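A possible sketch of this prefix-sum idea follows. The names `prefix` and `solve` are illustrative, and the code uses the fact that the left half of f([1..N]) consists of the odd numbers 1, 3, \ldots, N-1, whose sum is (N/2)^2.

```python
MOD = 10**9 + 7


def prefix(n, r):
    """Sum of the first r elements of f([1..n]) modulo MOD; n is a power of 2."""
    if r == 0:
        return 0
    if r == n:
        return n * (n + 1) // 2 % MOD
    h = n // 2
    if r <= h:
        # prefix lies inside the left (odd) half: each element is 2x - 1
        return (2 * prefix(h, r) - r) % MOD
    # whole left half (odd numbers 1..n-1 sum to h*h) plus a prefix of the right half
    return (h * h + 2 * prefix(h, r - h)) % MOD


def solve(n, l, r):
    """Sum of elements l..r as a difference of two prefix sums."""
    return (prefix(n, r) - prefix(n, l - 1)) % MOD


print(solve(8, 2, 6))  # 23
```

There is at most one recursive call per level, so each `prefix` call makes at most \log_2 N + 1 calls in total.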


@iceknight1093 Why does my submission which does not implement the base-case cutoff optimization mentioned in the editorial pass the TL?

This problem can also be solved another way.
First, let’s find the K-th element of the array. Indices range from 0 to N-1, so \log_2(N) bits are enough to represent any index.
The K-th element is at index K-1. The element at any index equals 1 plus the value obtained by reversing the \log_2(N)-bit binary representation of that index.

Example: let’s say N=8 and K=4.
The number of bits required is \log_2(N) = 3.
The K-th (4th) element is at index 3, which is 011 in binary. Reversing gives 110, which is 6 in decimal, so the 4th element is 1+6 = 7.
The array for N=8 is [1,5,3,7,2,6,4,8].
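The bit-reversal rule in this reply can be sketched as follows. The function name `kth_by_bit_reversal` is made up for the illustration.

```python
def kth_by_bit_reversal(n, k):
    """k-th element (1-based) of f([1..n]) via reversing the bits of k - 1."""
    bits = n.bit_length() - 1   # log2(n) when n is a power of 2
    index = k - 1               # 0-based position
    rev = 0
    for _ in range(bits):
        # move the lowest remaining bit of index onto the end of rev
        rev = (rev << 1) | (index & 1)
        index >>= 1
    return rev + 1


print([kth_by_bit_reversal(8, k) for k in range(1, 9)])  # [1, 5, 3, 7, 2, 6, 4, 8]
```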

Now we know how to calculate the K-th element. Let’s define f(K) as the sum of the first K elements of the array. To compute it, look at the indices 0 to K-1 and count, for each bit position, how many of them have that bit set. Because the bits get reversed, the count for the most significant bit is multiplied by 2^0, the next one by 2^1, and so on up to 2^{\log_2(N)-1} for the least significant bit. Finally add K, since 1 is added to each element. The answer is then f(R)-f(L-1).

Example: N=8, L=2, R=6.
To calculate f(R):
the indices of the first 6 elements are 0 to 5, i.e.
000
001
010
011
100
101
The number of set bits at each position, from most to least significant, is [2,2,3].
So the sum of the first 6 elements is (2\cdot 2^0 + 2\cdot 2^1 + 3\cdot 2^2) + 6 = 18 + 6 = 24.
Similarly, f(L-1)=f(2-1)=f(1)=1.
So ans = f(R)-f(L-1) = 24-1 = 23.
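A rough sketch of this counting approach, assuming the closed form for how often a given bit is set among 0..K-1. The helper names `ones_at_bit` and `prefix_by_bits` are hypothetical.

```python
MOD = 10**9 + 7


def ones_at_bit(count, b):
    """How many of the integers 0..count-1 have bit b set."""
    period = 1 << (b + 1)             # bit b repeats with this period
    full, rest = divmod(count, period)
    # each full period contributes 2^b ones; the partial period contributes the overflow past 2^b
    return full * (1 << b) + max(0, rest - (1 << b))


def prefix_by_bits(n, k):
    """Sum of the first k elements of f([1..n]) modulo MOD; n is a power of 2."""
    bits = n.bit_length() - 1
    total = k                         # the +1 added to every element
    for b in range(bits):
        weight = 1 << (bits - 1 - b)  # bit b lands on position bits-1-b after reversal
        total += ones_at_bit(k, b) * weight
    return total % MOD


print(prefix_by_bits(8, 6), prefix_by_bits(8, 1))  # 24 1
```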

Here is submission CodeChef: Practical coding for everyone


Your implementation differs a bit from the one I described in the editorial. In fact, the way you implemented the recursion is how queries for iterative segment trees are done, which is why it works fast.
Here’s a blogpost about them.

I tried to solve this question using a segment-tree-query-like approach, and I think I am doing it the right way: simply dividing repeatedly and computing the answer recursively.
But I am getting confused about where to take the MOD.
Please have a look at my code and guide me on where I am going wrong with the modulo!

@iceknight1093 I am still finding it difficult to understand the left recursive part of subtask 1. There are odd numbers in the left half, but it’s difficult to understand how the recursion ends up with the correct odd number at each position.
2*i + 1 is odd, but why does the sequence again get divided into (1,2,3,4 … N/2)?

Cleaner code, using same approach…
https://www.codechef.com/viewsolution/1021867544


If you understand how the right half recursion works, the left half is pretty much the same way.
For example, if the right half is [2, 4, 6, 8] then clearly you can find the answer for [1, 2, 3, 4] and multiply it by two right?

Similarly, if the left half is [1, 3, 5, 7], you can instead find the answer for [1, 2, 3, 4]; then multiply it by 2 and subtract 1.
Essentially, you have

[1, 3, 5, 7, \ldots, N-1] = [(2\cdot 1) - 1, (2\cdot 2) - 1, (2\cdot 3) - 1, \ldots, (2\cdot \frac{N}{2}) - 1]

so you go in reverse, by finding the answer for [1, 2, 3, \ldots, \frac{N}{2}], then multiply it by 2 to get numbers [2, 4, 6, 8, \ldots, N], then subtract 1 to get [1, 3, 5, 7, \ldots, N-1].
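The identity in this reply can be checked on a small case. The brute-force `f` below is an illustrative helper (elements at odd 1-based positions first, then even), not code from the editorial.

```python
def f(a):
    """Brute-force version of the process for small arrays."""
    if len(a) <= 2:
        return list(a)
    return f(a[0::2]) + f(a[1::2])  # odd 1-based positions, then even ones


half = f([1, 2, 3, 4])    # answer for the smaller problem
left = f([1, 3, 5, 7])    # left half of the N = 8 problem
right = f([2, 4, 6, 8])   # right half of the N = 8 problem

print(left == [2 * x - 1 for x in half])  # True
print(right == [2 * x for x in half])     # True
```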
