COUNTGOOD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author:
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Easy-Medium

PREREQUISITES:

Familiarity with bitwise operations

PROBLEM:

Given N, K, Q, answer Q queries of the following form:

  • Given L and R, count the number of good numbers in [L, R]. x is a good number if it can be written as i_1\mid i_2\mid \cdots \mid i_M for some integers K \leq i_1 \lt i_2 \lt \cdots \lt i_M \lt N+K.

EXPLANATION:

First, let’s figure out which numbers are good.

In simple words, a good number is one that can be written as the bitwise OR of several values that are all in the range [K, N+K).

Let’s work with a simple version first: what if K = 0?
That is, we simply have all numbers in the range [0, N).
In this case, if 2^h is the largest power of 2 that’s less than N, we can make every integer that uses only bits 0 through h by combining 2^0, 2^1, \ldots, 2^h, and can’t make anything beyond that, since no available number has a higher bit set.
So, in this case, the good numbers exactly form the range [0, 2^{h+1}).
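The K = 0 claim is easy to sanity-check by brute force. The sketch below is only an illustration, not part of the intended solution; the helper names `or_closure` and `predicted_k0` are made up here, and N = 1 (where only 0 is available) is handled as a special case.

```python
def or_closure(values):
    """All values obtainable as the OR of a non-empty subset of `values`."""
    reachable = set()
    for v in values:
        # Either v alone, or v OR-ed into something already reachable.
        reachable |= {v | r for r in reachable} | {v}
    return reachable


def predicted_k0(n):
    """Good numbers for K = 0 according to the claim: [0, 2^(h+1))."""
    if n == 1:
        return {0}                      # only the value 0 is available
    h = (n - 1).bit_length() - 1        # largest h with 2^h < n
    return set(range(1 << (h + 1)))


print(sorted(or_closure(range(5))))     # [0, 1, 2, 3, 4, 5, 6, 7]
```

For N = 5 the available numbers are 0..4, so h = 2 and the good numbers are exactly [0, 8).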


Now, let’s see what happens with a more general K.
First, let P be the longest common (binary) prefix of all the integers available to us.
P can in fact be found as simply the longest common prefix of K and N+K-1, the two extreme values: if they share a prefix, everything between them will also share the same prefix.
Since every available number has P as a prefix, every number we can form with their bitwise OR will also have P as its prefix. This means we only need to care about smaller bits.

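Now, let b be the highest bit where K and N+K-1 differ. (If N = 1 there is no such bit, and K is the only good number.)
From here on, P also denotes the numeric value of the prefix with bit b and everything below it set to 0, so P + 2^b is the smallest number that has prefix P and bit b set.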
We now have two separate sets of values: [K, P + 2^b) which doesn’t have b set, and [P+2^b, N+K) which does have b set.
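As a rough sketch of how P and b might be computed, the XOR of the two extreme values has its highest set bit exactly at b. The function name `split_point` is hypothetical, and P is returned as a number whose bits b and below are zero.

```python
def split_point(k, n):
    """Return (P, b) for the values [k, n + k); b is None when n == 1."""
    lo, hi = k, n + k - 1
    diff = lo ^ hi
    if diff == 0:
        return lo, None                 # a single value: nothing to split
    b = diff.bit_length() - 1           # highest bit where lo and hi differ
    P = (lo >> (b + 1)) << (b + 1)      # common prefix, lower bits zeroed
    return P, b


P, b = split_point(9, 3)                # values 9, 10, 11
print(P, b)                             # 8 1
```

With K = 9 and N = 3, this gives P = 8 and b = 1, so the two sets are [9, 10) without bit 1 and [10, 12) with it.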

Looking at them separately,

  • The range [K, P + 2^b) can only form numbers within it using the bitwise OR operation.
    No smaller value is possible because the bitwise OR of a set of numbers can’t be smaller than a number in the set; and no larger value is possible because none of these values have the b-th bit set.
  • The range [P + 2^b, N+K) is very similar to the K = 0 case we worked on initially.
    Indeed, since everything in this range has a prefix of P + 2^b, we’re functionally working with only the lower bits, i.e. the range [0, N+K - P - 2^b).
    The range of good values here is something already known to us - in particular, if h is the largest integer such that 2^h \lt N+K - P - 2^b, everything in [0, 2^{h+1}) is good.
    Translating back to the original, everything in [P + 2^b, P + 2^b + 2^{h+1}) is good.

Combining the above, we see that everything in the range [K, P + 2^b + 2^{h+1}) is good: the two pieces [K, P + 2^b) and [P + 2^b, P + 2^b + 2^{h+1}) are adjacent.

However, the above discussion considered numbers with the bit b set and unset separately.
We also need to consider bitwise ORs obtained by choosing numbers of both types.
Here, observe that everything in the range [K + 2^b, P + 2^b + 2^b) = [K+2^b, P+2^{b+1}) is certainly obtainable, simply by choosing any number in [K, P+2^b) (which doesn’t have b set) and then choosing P+2^b which sets b and doesn’t change the lower bits.

It’s not hard to see that these are the only numbers obtainable this way.
After all, anything larger would result in the (b+1)-th bit of the prefix changing (which is impossible, since we know all bits larger than the b-th are fixed), while anything smaller than K + 2^b is impossible the instant we choose a number x from [K, P + 2^b) in the first place: the result contains every bit of x as well as bit b, so it is at least x + 2^b \geq K + 2^b.
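The claim about mixed ORs can also be checked by brute force on small inputs. This is only an illustrative sketch: `or_closure` and `mixed_ors` are hypothetical helpers, and the function assumes N >= 2 so that bit b exists.

```python
def or_closure(values):
    """All values obtainable as the OR of a non-empty subset of `values`."""
    reachable = set()
    for v in values:
        reachable |= {v | r for r in reachable} | {v}
    return reachable


def mixed_ors(k, n):
    """ORs using at least one number without bit b and one with it (n >= 2)."""
    lo, hi = k, n + k - 1
    b = (lo ^ hi).bit_length() - 1
    P = (lo >> (b + 1)) << (b + 1)
    mid = P + (1 << b)                                  # first value with bit b
    low_side = or_closure(range(k, mid))                # bit b unset
    high_side = or_closure(range(mid, n + k))           # bit b set
    brute = {x | y for x in low_side for y in high_side}
    predicted = set(range(k + (1 << b), P + (1 << (b + 1))))
    return brute, predicted


brute, predicted = mixed_ors(5, 6)      # values 5..10, so b = 3 and P = 0
print(sorted(brute))                    # [13, 14, 15]
print(brute == predicted)               # True
```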


To summarize, the set of good numbers is in fact quite easy to characterize.
Let P be the longest common (binary) prefix of K and N+K-1, b be the highest bit at which they differ, and h be the largest integer such that 2^h \lt N+K-P-2^b.
Then, the good numbers are simply the union of the intervals
[K, P+2^{b} + 2^{h+1}) and [K+2^b, P + 2^{b+1}).
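Putting the pieces together, here is a minimal sketch of the characterization, with the first interval ending at P + 2^b + 2^{h+1} as obtained by joining the two halves. The name `good_intervals` and the half-open tuple representation are choices made for this example. Two edge cases are assumed: N = 1 yields just {K}, and when no h exists (N+K-1 equals P + 2^b) the term 2^{h+1} is replaced by 1.

```python
def or_closure(values):
    """Brute force: ORs of all non-empty subsets, used only for checking."""
    reachable = set()
    for v in values:
        reachable |= {v | r for r in reachable} | {v}
    return reachable


def good_intervals(n, k):
    """Half-open intervals whose union is the set of good numbers."""
    lo, hi = k, n + k - 1
    if lo == hi:
        return [(k, k + 1)]                  # N = 1: only K itself
    b = (lo ^ hi).bit_length() - 1           # highest differing bit
    P = (lo >> (b + 1)) << (b + 1)           # common prefix, lower bits zeroed
    mid = P + (1 << b)
    rest = hi - mid                          # part of hi below bit b
    span = 1 << rest.bit_length()            # 2^(h+1), or 1 if rest == 0
    return [(k, mid + span), (k + (1 << b), P + (1 << (b + 1)))]


print(good_intervals(6, 5))                  # [(5, 12), (13, 16)]
```

For N = 6, K = 5 the value 12 is correctly excluded: it would need 8 and 4 with no other bits, and 4 is not available.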

Once these two intervals are known, answering queries [L, R] is quite easy: simply find the intersection of [L, R] with the above intervals and add up the lengths.
Note that the above intervals might not be disjoint, so take care to not overcount values.
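One possible way to avoid overcounting is to sweep the intervals in sorted order and only count the part that lies beyond what has already been counted. The sketch below assumes half-open intervals (start, end) and an inclusive query [L, R]; `count_in_union` is a name made up for this example, and the sample intervals are illustrative.

```python
def count_in_union(intervals, L, R):
    """Count integers in [L, R] covered by at least one half-open interval."""
    total = 0
    covered_to = L                    # first integer not yet counted
    for start, end in sorted(intervals):
        start = max(start, covered_to)
        end = min(end, R + 1)
        if start < end:               # something new inside the query
            total += end - start
            covered_to = end
    return total


print(count_in_union([(5, 12), (13, 16)], 10, 14))   # 4  (10, 11, 13, 14)
print(count_in_union([(0, 6), (4, 8)], 0, 7))        # 8  (overlap counted once)
```

Since there are only two intervals, each query takes constant time.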

TIME COMPLEXITY:

\mathcal{O}(\log(N+K) + Q) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define fast                       \
 ios_base::sync_with_stdio(0);      \
 cin.tie(0);                         \
 cout.tie(0);

int main(){
    fast;
    ll t;
    cin>>t;
    while(t--){
        ll n,k,q;
        cin>>n>>k>>q;
        ll n1=k,n2=n+k-1;
        ll xx=0;
        for(ll i=61;i>=0;i--){
            if((n2&(1LL<<i))==((n1&(1LL<<i)))){
                if((n1&(1LL<<i))>0){
                    xx+=(1LL<<i);
                    n1-=(1LL<<i);
                    n2-=(1LL<<i);
                }
            }
            else break;
        }
        vector<pair<ll,ll>> v1,v;
        v1.push_back({n1,n2});
        ll x=0,y=0;
        if(n2>0) x=1LL<<(__lg(n2));
        if(n2>x) y=1LL<<(__lg(n2-x));
        v1.push_back({x,x+2*y-1});
        v1.push_back({x+n1,2*x-1});
        sort(v1.begin(),v1.end());
        for(ll i=0;i<(ll)v1.size();i++){
            v1[i].first+=xx;
            v1[i].second+=xx;
        }
        for(ll i=0;i<(ll)v1.size();i++){
            if((ll)v.size()==0) v.push_back({v1[i].first,v1[i].second});
            else{
                if(v1[i].first<=v.back().second){
                    ll x=v.back().first,y=v.back().second;
                    v.pop_back();
                    x=min(x,v1[i].first);
                    y=max(y,v1[i].second);
                    v.push_back({x,y});
                }
                else v.push_back({v1[i].first,v1[i].second});
            }
        }
        while(q--){
            ll l,r;
            cin>>l>>r;
            ll ans=0;
            for(ll i=0;i<(ll)v.size();i++){
                ll x=v[i].first,y=v[i].second;
                if(l>=x && l<=y){
                    ans+=(min(r,y)-l+1);
                }
                else if(r>=x && r<=y){
                    ans+=(r-max(l,x)+1);
                }
                else if(l<=x && y<=r){
                    ans+=(y-x+1);
                }
            }
            cout<<ans<<"\n";
        }
    }
}
Editorialist's code (PyPy3)
import sys
input = sys.stdin.readline

for _ in range(int(input())):
    n, k, q = map(int, input().split())
    
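    # First interval [l1, r1] (inclusive); it stays {k} when n == 1.
    l1, r1 = k, k
    # Second interval [l2, r2]; the dummy [-1, -1] is kept when n == 1.
    l2, r2 = -1, -1
    
    pref = 0  # common prefix of k and n + k - 1, as a number
    for b in reversed(range(61)):
        if (k >> b) & 1 == ((n + k - 1) >> b) & 1:
            pref += k & (1 << b)
            continue
        
        # b is the highest bit where k and n + k - 1 differ.
        r1 = pref + (1 << b)
        r2 = r1 + (1 << b) - 1  # pref + 2^(b+1) - 1
        # The highest set bit b2 below b in n + k - 1 extends r1
        # to pref + 2^b + 2^(b2+1) - 1.
        for b2 in reversed(range(b)):
            if ((n + k - 1) >> b2) & 1:
                r1 += 2 << b2
                r1 -= 1
                break
        
        # Mixed ORs start at k + 2^b; clamp so the two intervals are disjoint.
        l2 = l1 + (1 << b)
        l2 = max(l2, r1 + 1)
        break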
    
    def intersect(a, b, x, y):
        if x > y or b < x or y < a: return 0
        return min(b, y) - max(a, x) + 1

    while q > 0:
        l, r = map(int, input().split())
        print(intersect(l, r, l1, r1) + intersect(l, r, l2, r2))
        q -= 1