PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Authors: triggered_code and iceknight1093
Tester: apoorv_me
Editorialist: iceknight1093
DIFFICULTY:
2622
PREREQUISITES:
Recursion
PROBLEM:
You’re given N, L, and R. N is a power of 2.
Consider the following process f(A) for an array A:
- If |A| = 2, return A.
- Otherwise, split the array into two halves: the elements at odd indices and the elements at even indices (1-indexed).
Recursively process these arrays, then return their concatenation, with the odd-index half first.
Find the sum of the L-th through R-th elements of f([1, 2, 3, \ldots, N]), modulo 10^9 + 7.
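For small N the process can simply be simulated. The sketch below is a brute-force reference under the reading above (1-indexed positions, odd-index half first, applied to [1, 2, \ldots, N]). The names f and brute_range_sum are illustrative, and the approach is only practical for tiny N.

```python
def f(a):
    """Apply the split-and-concatenate process from the statement to the list a."""
    if len(a) <= 2:
        # The statement returns arrays of length 2 unchanged; length 1 is also a fixed point.
        return list(a)
    odd_positions = a[0::2]   # 1-indexed odd positions: 1st, 3rd, 5th, ...
    even_positions = a[1::2]  # 1-indexed even positions: 2nd, 4th, 6th, ...
    return f(odd_positions) + f(even_positions)


def brute_range_sum(n, l, r, mod=10**9 + 7):
    """Sum of the l-th..r-th elements (1-indexed) of f([1..n]), by direct simulation."""
    return sum(f(list(range(1, n + 1)))[l - 1:r]) % mod


print(f(list(range(1, 9))))  # [1, 5, 3, 7, 2, 6, 4, 8]
```

For instance, f([1, 2, \ldots, 8]) = f([1, 3, 5, 7]) followed by f([2, 4, 6, 8]), which gives [1, 5, 3, 7, 2, 6, 4, 8].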
EXPLANATION:
The given process is recursive, so let’s also try to use recursion to solve the problem.
Subtask 1
In this subtask L = R, so we’re looking for a single value. Let K = L = R be its index.
Let ans(N, K) denote the value we’re looking for.
f(A) partitions the array into two halves, and then recursively applies the same process to each half.
So, we have two cases for which half K is in: either K \leq \frac{N}{2}, or K \gt \frac{N}{2}. Let’s look at them separately.
- If K \gt \frac{N}{2}, then we want to find the (K - \frac{N}{2})-th element of the array f([2, 4, 6, 8, \ldots, N]).
Notice that this array is just \left[1, 2, 3, \ldots, \frac{N}{2}\right] with all its elements multiplied by 2, and f only rearranges positions without looking at the values.
So, the value we want is just 2\cdot ans\left(\frac{N}{2}, K - \frac{N}{2}\right).
- If K \leq \frac{N}{2}, we want to find the K-th element of f([1, 3, 5, \ldots, N-1]).
Everything here is odd; in particular, the i-th element is 2i - 1.
So, we can instead find the K-th element of f([1, 2, 3, \ldots, \frac{N}{2}]), then multiply it by 2 and subtract 1.
That is, we want 2\cdot ans(\frac{N}{2}, K) - 1.
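Putting both cases together, we have:
ans(N, K) = \begin{cases} 2\cdot ans\left(\frac{N}{2}, K\right) - 1 & \text{if } K \leq \frac{N}{2} \\ 2\cdot ans\left(\frac{N}{2}, K - \frac{N}{2}\right) & \text{if } K \gt \frac{N}{2} \end{cases}
with the base case being ans(N, K) = K if N = 1, since the process doesn’t change arrays of length 1.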
At each step of the recursion, we halve N, so we find the answer in \mathcal{O}(\log N) time.
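A direct transcription of this recursion might look like the sketch below. The name ans_single is ours; it returns the exact value, which would still need to be reduced modulo 10^9 + 7 before printing.

```python
def ans_single(n, k):
    """K-th element (1-indexed) of f([1, 2, ..., n]), where n is a power of 2."""
    if n == 1:
        return k  # base case: k must be 1 here
    half = n // 2
    if k <= half:
        # Left half is f([1, 3, 5, ...]): element i of [1..half] maps to 2i - 1.
        return 2 * ans_single(half, k) - 1
    # Right half is f([2, 4, 6, ...]): element i of [1..half] maps to 2i.
    return 2 * ans_single(half, k - half)


print([ans_single(8, k) for k in range(1, 9)])  # [1, 5, 3, 7, 2, 6, 4, 8]
```

Each call halves n, so for N up to 2^{60} the recursion is only about 60 levels deep.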
There are other ways to solve this subtask: for example, if you try hard enough and stare at the outputs for small N and K, you can probably observe some sort of pattern based on N and K.
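One such pattern, offered here as an observation rather than as the intended solution: write K - 1 in binary using \log_2 N bits and reverse those bits; the K-th element is that reversed number plus 1. This is the bit-reversal permutation familiar from iterative FFT implementations. It follows from the two cases above: with v = ans - 1, the case K \leq \frac{N}{2} gives v = 2v' and the case K \gt \frac{N}{2} gives v = 2v' + 1, so the top bit of K - 1 becomes the lowest bit of v. A sketch, with the hypothetical name ans_by_bit_reversal:

```python
def ans_by_bit_reversal(n, k):
    """K-th element of f([1..n]) via the observed pattern: reverse the log2(n) bits of k - 1, add 1."""
    bits = n.bit_length() - 1  # n == 2**bits
    x = k - 1
    reversed_x = 0
    for _ in range(bits):
        reversed_x = (reversed_x << 1) | (x & 1)  # move the lowest bit of x into reversed_x
        x >>= 1
    return reversed_x + 1


print([ans_by_bit_reversal(8, k) for k in range(1, 9)])  # [1, 5, 3, 7, 2, 6, 4, 8]
```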
Subtask 2
Let’s generalize our recursion from subtask 1 to ranges.
Let ans(N, L, R) denote the answer we’re looking for.
Let H = \frac{N}{2} be half of N.
There are three cases:
- If R \leq H, then the entire range lies in the left half.
Here, we get ans(N, L, R) = 2\cdot ans(H, L, R) - (R-L+1).
This is because, as noted earlier, the left half consists of elements of the form
[(2\cdot 1 - 1), (2\cdot 2 - 1), (2\cdot 3 - 1), \ldots]
So, we need to find the answer for [1, 2, \ldots, H], then multiply it by 2 and subtract 1 for each of the R-L+1 elements in the range.
- If L \gt H, the entire range lies in the right half.
Here, we get ans(N, L, R) = 2\cdot ans(H, L-H, R-H), just as we had in subtask 1.
- Finally, we have the case when L \leq H \lt R.
We can split this into the two ranges [L, H] and [H+1, R], and then apply the matching case above to each.
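Putting it all together, we obtain:
ans(N, L, R) = \begin{cases} 2\cdot ans(H, L, R) - (R-L+1) & \text{if } R \leq H \\ 2\cdot ans(H, L-H, R-H) & \text{if } L \gt H \\ 2\cdot ans(H, L, H) - (H-L+1) + 2\cdot ans(H, 1, R-H) & \text{if } L \leq H \lt R \end{cases}
For now, we can also set ans(1, 1, 1) = 1 as our base case.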
While this recursion is correct, it is unfortunately too slow by itself: since we “split” into two branches whenever we encounter the third case, the number of branches can get quite large.
In fact, this recursion ends up visiting every single integer in the range [L, R] as its own length-1 leaf, so its running time is proportional to R - L + 1. That is \mathcal{O}(N) in the worst case, which is far too slow when N can be as large as 2^{60}.
However, optimizing it is in fact quite easy!
Instead of setting ans(1, 1, 1) = 1 as the base case, we set ans(N, 1, N) = \frac{N\cdot (N+1)}{2} as the base case - that is, if the query is for the whole range, just return its sum without recursing any further.
Though this optimization might seem simple, it brings our time complexity down to \mathcal{O}(\log N).
Proof
If you’re familiar with segment trees, you might notice that this is exactly how segment tree queries work (at least in most recursive implementations), so the proof that the complexity is \mathcal{O}(\log N) carries over from there.
A proof can be found in the linked page (the “sum queries” section).
If you don’t want to read the proof, here’s some intuition.
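Consider the first time a query ‘splits’ into two, and look at the left branch of the split: it is a suffix query [L, H] of its half.
If the left branch splits any further, the right branch of this second split has to be of the form ans(M, 1, M) for the appropriate size M, because it covers one entire half.
Since we set this to be our base case, that branch won’t need to proceed any further.
Symmetrically, the right branch of the first split is a prefix query, and whenever it splits, its left branch covers an entire half and stops immediately.
So, after the first split, every further split has one trivial side: only two chains keep recursing, each making at most two calls per level, which is where the \mathcal{O}(\log N) bound comes from.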
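To check this bound empirically, the optimized recursion can be instrumented with a call counter. This is a sketch; fast_ans and count_calls are our names. Below the first split there are two chains, each making at most two calls per level, so a query should need at most about 4\log_2 N + 1 calls in total.

```python
MOD = 10**9 + 7


def fast_ans(n, l, r, counter=None):
    """Sum of the l-th..r-th elements of f([1..n]) modulo MOD.
    If counter is a one-element list, counter[0] is incremented once per call."""
    if counter is not None:
        counter[0] += 1
    if l == 1 and r == n:
        # Whole-array base case: f([1..n]) is a permutation of 1..n.
        return n * (n + 1) // 2 % MOD
    h = n // 2
    if r <= h:
        return (2 * fast_ans(h, l, r, counter) - (r - l + 1)) % MOD
    if l > h:
        return 2 * fast_ans(h, l - h, r - h, counter) % MOD
    # Straddling case: [l, h] lies in the left half, [h + 1, r] in the right half.
    left = 2 * fast_ans(h, l, h, counter) - (h - l + 1)
    right = 2 * fast_ans(h, 1, r - h, counter)
    return (left + right) % MOD


def count_calls(n, l, r):
    """Number of fast_ans invocations used to answer one query."""
    counter = [0]
    fast_ans(n, l, r, counter)
    return counter[0]


print(count_calls(2**60, 3, 2**60 - 5))
```

Python's % always returns a non-negative result for a positive modulus, so the subtractions need no extra adjustment here.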
Make sure to print the answer modulo 10^9 + 7.
Also watch out for overflow errors, especially when computing \frac{N\cdot (N+1)}{2} since N can be quite large.
To get around it, you might want to take N modulo 10^9 + 7 before performing the multiplication.
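Reducing N first is safe. With p = 10^9 + 7, N \bmod p and (N+1) \bmod p are consecutive integers (or the second one is 0), so their product is still even; and because p is odd, halving that product gives the same result modulo p as halving N\cdot (N+1). The reduced product is below p^2 \approx 10^{18}, which fits in a signed 64-bit integer. Python integers never overflow, so the sketch below (triangular_mod is our name) just asserts the 64-bit bound that a C++ long long version would rely on.

```python
MOD = 10**9 + 7
INT64_MAX = 2**63 - 1


def triangular_mod(n):
    """Return n * (n + 1) / 2 modulo MOD without forming any product above INT64_MAX."""
    a = n % MOD
    b = (n + 1) % MOD  # equals a + 1, or 0 when a == MOD - 1
    product = a * b    # product of consecutive integers (or 0), hence even
    assert product <= INT64_MAX  # in C++ this product would still fit in a long long
    return product // 2 % MOD
```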
TIME COMPLEXITY
\mathcal{O}(\log N) per testcase.
CODE:
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif
struct input_checker {
    string buffer;
    int pos;
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};
constexpr int mod = (int)1e9 + 7;
struct mi {
    int64_t v; explicit operator int64_t() const { return v % mod; }
    mi() { v = 0; }
    mi(int64_t _v) {
        v = (-mod < _v && _v < mod) ? _v : _v % mod;
        if (v < 0) v += mod;
    }
    friend bool operator==(const mi& a, const mi& b) {
        return a.v == b.v; }
    friend bool operator!=(const mi& a, const mi& b) {
        return !(a == b); }
    friend bool operator<(const mi& a, const mi& b) {
        return a.v < b.v; }
    mi& operator+=(const mi& m) {
        if ((v += m.v) >= mod) v -= mod;
        return *this; }
    mi& operator-=(const mi& m) {
        if ((v -= m.v) < 0) v += mod;
        return *this; }
    mi& operator*=(const mi& m) {
        v = v*m.v%mod; return *this; }
    mi& operator/=(const mi& m) { return (*this) *= inv(m); }
    friend mi pow(mi a, int64_t p) {
        mi ans = 1; assert(p >= 0);
        for (; p; p /= 2, a *= a) if (p&1) ans *= a;
        return ans;
    }
    friend mi inv(const mi& a) { assert(a.v != 0);
        return pow(a,mod-2); }
    mi operator-() const { return mi(-v); }
    mi& operator++() { return *this += 1; }
    mi& operator--() { return *this -= 1; }
    mi operator++(int32_t) { mi temp; temp.v = v++; return temp; }
    mi operator--(int32_t) { mi temp; temp.v = v--; return temp; }
    friend mi operator+(mi a, const mi& b) { return a += b; }
    friend mi operator-(mi a, const mi& b) { return a -= b; }
    friend mi operator*(mi a, const mi& b) { return a *= b; }
    friend mi operator/(mi a, const mi& b) { return a /= b; }
    friend ostream& operator<<(ostream& os, const mi& m) {
        os << m.v; return os;
    }
    friend istream& operator>>(istream& is, mi& m) {
        int64_t x; is >> x;
        m.v = x;
        return is;
    }
    friend void __print(const mi &x) {
        cerr << x.v;
    }
};
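int32_t main() {
    ios_base::sync_with_stdio(0); cin.tie(0);
    input_checker input;
    int T = input.readInt(0, (int)1e5); input.readEoln();
    while(T-- > 0) {
        long long N = input.readLong(2, (1ll << 60)); input.readSpace();
        long long L = input.readLong(1, N); input.readSpace();
        long long R = input.readLong(L, N); input.readEoln();
        assert((N & (N - 1)) == 0);  // N must be a power of 2
        // solve(A, L, R) works on f([1..A]) and returns {sum over positions L..R, number of such positions}.
        // L and R may fall outside [1, A]; only the overlap with [1, A] is counted.
        function<pair<mi, mi>(long long, long long, long long)> solve = [&](long long A, long long L, long long R) -> pair<mi, mi> {
            dbg(L, R, A);
            if(R < 1 || L > A) return make_pair(mi(0), mi(0));  // no overlap with this block
            if(R >= A && L <= 1) {
                // Whole block covered: sum of 1..A, i.e. A(A+1)/2, computed as (A/2)(A-1) + A.
                return make_pair(mi(A / 2) * mi(A - 1) + mi(A), mi(A));
            }
            assert(A != 1);
            auto pl = solve(A / 2, L, R);                  // left half: values of the form 2x - 1
            auto pr = solve(A / 2, L - A / 2, R - A / 2);  // right half: values of the form 2x
            // Double both sums, then subtract 1 for every position taken from the left half.
            return make_pair(2 * (pl.first + pr.first) - pl.second, pl.second + pr.second);
        };
        cout << solve(N, L, R).first << '\n';
    }
    input.readEof();
    return 0;
}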
Editorialist's code (Python)
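mod = 10**9 + 7

def calc(n, l, r):
    # Sum of the l-th..r-th elements of f([1, 2, ..., n]), modulo mod
    if l == 1 and r == n: return (n*(n+1)//2 ) % mod  # whole range: 1 + 2 + ... + n
    mid = n//2
    # Range lies entirely in the left half, whose values are 2x - 1
    if r <= mid: return (2*calc(n//2, l, r) - (r-l+1)) % mod
    # Range lies entirely in the right half, whose values are 2x
    if l > mid: return (2*calc(n//2, l - mid, r - mid))%mod
    # Range straddles the middle: [l, mid] on the left, [mid + 1, r] on the right
    ret = (2*calc(n//2, l, n//2) - (n//2 - l + 1)) % mod + (2*calc(n//2, 1, r - mid) % mod)
    return ret % mod

import sys
input = sys.stdin.readline
for _ in range(int(input())):
    n, l, r = map(int, input().split())
    print(calc(n, l, r))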