COUNTONES - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4
Practice

Setter: Lavish Gupta
Testers: Tejas Pandey and Abhinav Sharma
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

Recursion

PROBLEM

Alice recently converted all the positive integers from 1 to 2^N - 1 (both inclusive) into binary strings and stored them in an array S. Note that the binary strings do not have leading zeroes.
While she was out, Bob sorted all the elements of S in lexicographically increasing order.

Let S_i denote the i^{th} string in the sorted array.
Alice defined a function F such that F(S_i) is equal to the number of 1s in the string S_i.
For example, F(101) = 2 and F(1000) = 1.

Given a positive integer K, find the value of \displaystyle\sum_{i = 1}^K F(S_i).

String P is lexicographically smaller than string Q if one of the following holds:

  • P is a prefix of Q and P \neq Q.
  • There exists an index i such that P_i \lt Q_i and for all j \lt i, P_j=Q_j.

EXPLANATION

Observation 1: All binary strings start with 1. Since we are counting the number of ones in the first K strings, we can add K to the answer and remove the first character from every string.

Hence, from now on we remove the first bit from every string and add K to the answer. (This convention is followed throughout the editorial.)

For example, for N = 3, the strings in sorted order are \{"1", "10", "100", "101", "11", "110", "111"\}. After removing the first bit, they become \{"", "0", "00", "01", "1", "10", "11"\}. (Note that the strings "0" and "00" are considered different.)
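For small N, the whole process can be checked by brute force. The sketch below (bruteForce is a hypothetical name, practical only for N up to about 20) writes out the binary strings of 1 to 2^N - 1 without leading zeroes, sorts them with std::string's lexicographic comparison, which ranks a proper prefix first exactly as in the definition above, and sums the ones in the first K strings. It is meant only as a reference for testing a faster solution.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical brute-force reference (small N only): sum of F(S_i) for i = 1..K.
long long bruteForce(int N, long long K) {
    std::vector<std::string> S;
    for (long long x = 1; x < (1LL << N); ++x) {
        std::string s;
        for (long long y = x; y > 0; y >>= 1) s += char('0' + (y & 1));
        std::reverse(s.begin(), s.end());  // most significant bit first, no leading zeroes
        S.push_back(s);
    }
    std::sort(S.begin(), S.end());  // std::string compares lexicographically, prefix first
    assert(1 <= K && K <= (long long)S.size());
    long long total = 0;
    for (long long i = 0; i < K; ++i)
        total += std::count(S[i].begin(), S[i].end(), '1');  // F(S_i)
    return total;
}
```

For instance, bruteForce(3, 4) sums the ones in "1", "10", "100", "101", giving 5.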

Simpler problem

Let's first assume K = 2^N-1 (the size of the set), i.e. we need the total number of set bits over all strings in the set.

For N = 1, the set of strings S_1 is \{""\}
For N = 2, the set of strings S_2 is \{"", "0", "1"\}
For N = 3, the set of strings S_3 is \{"", "0", "00", "01","1", "10", "11"\}
and so on.

Let f(N) denote the total number of ones over all strings in S_N, and sz(N) denote the number of strings in S_N.

As base cases, we have sz(1) = 1 and f(1) = 0. Let's see how S_N is constructed.

We can see that S_N = \{""\} \bigcup (\{"0", "1"\} \times S_{N-1}), where the order of the strings matters.

First, we add the empty string to S_N. Then we prepend 0 to each string in S_{N-1} and add the results to S_N in the same order, and then we prepend 1 to each string in S_{N-1} and add those to S_N in the same order.

Conversely, the first string of S_4 is the empty string.
If we group the remaining strings by their first character, one group contains all strings of S_3 prepended with 0, and the other contains all strings of S_3 prepended with 1.
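To make the construction concrete, here is a small sketch that builds S_N exactly as described: the empty string first, then the block prefixed with 0, then the block prefixed with 1. The function name buildSet is hypothetical, and it is only practical for small N since |S_N| = 2^N - 1. For N = 3 it reproduces the list given above.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: builds S_N (strings with the leading 1 removed) in sorted order.
std::vector<std::string> buildSet(int N) {
    assert(N >= 1);
    if (N == 1) return {""};                  // base case: S_1 = {""}
    std::vector<std::string> prev = buildSet(N - 1);
    std::vector<std::string> cur{""};         // the empty string comes first
    for (const std::string& s : prev) cur.push_back("0" + s);  // block prefixed with 0
    for (const std::string& s : prev) cur.push_back("1" + s);  // block prefixed with 1
    return cur;
}
```

Because both blocks keep the order of S_{N-1}, the result is already lexicographically sorted, which is why the recursive decomposition matches the order Bob produced.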

Based on the above, we can make a few observations. First, sz(N) = 1+2*sz(N-1). Next, let's derive a formula for f(N), assuming we already know f(N-1). Since S_{N-1} is included twice, every one in S_{N-1} appears twice in S_N, so we add 2*f(N-1) to f(N).

Secondly, when we add S_{N-1} the second time, every string is prepended with 1. This adds sz(N-1) more ones.

Hence, the recurrence relation for the number of ones is f(N) = 2*f(N-1) + sz(N-1).
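The recurrence also has a closed form, which can help when checking values by hand. After removing the leading bit, S_N contains 2^L strings of each length L from 0 to N-1, and each position is 1 in exactly half of them, so the strings of length L contribute L \cdot 2^{L-1} ones. This per-length sum is what the setter's f array accumulates (the setter's f[i] corresponds to f(i+1) here):

```latex
f(N) = \sum_{L=1}^{N-1} L \cdot 2^{L-1} = (N-2)\,2^{N-1} + 1
```

For example, f(3) = 1 + 4 = 5 and f(4) = 1 + 4 + 12 = 17, which matches 2 f(3) + sz(3) = 10 + 7.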

Using this relation, we can solve the simpler version.
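A minimal sketch of the simpler version, assuming N \leq 50 as in the testers' input checkers. The names Tables and solveFull are hypothetical. Each of the 2^N - 1 strings also had its leading 1 removed, so the answer for K = 2^N - 1 is sz(N) + f(N).

```cpp
#include <array>
#include <cassert>

// Hypothetical tables: sz[n] = |S_n| and f[n] = total ones in S_n (leading 1 removed).
struct Tables {
    std::array<long long, 51> sz{}, f{};
    Tables() {
        sz[1] = 1;  // S_1 = {""}, so f[1] stays 0
        for (int n = 2; n <= 50; ++n) {
            sz[n] = 1 + 2 * sz[n - 1];        // empty string + two copies of S_{n-1}
            f[n] = 2 * f[n - 1] + sz[n - 1];  // ones of both copies + prepended 1s
        }
    }
};

// Simpler problem (K = 2^N - 1): K leading ones plus f(N).
long long solveFull(const Tables& t, int N) {
    assert(1 <= N && N <= 50);
    return t.sz[N] + t.f[N];
}
```

For N = 3 this gives 7 + 5 = 12, the total number of ones in 1, 10, 100, 101, 11, 110, 111.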

Original problem

Now, we know how to solve this for K = 2^N-1. Let’s try solving it for general K.

For given N and K, let solve(N, K) denote the number of ones in the first K strings of S_N (with the leading bit already removed).

Based on K, there are three cases.

  • K = 1: only the empty string is counted, so no ones are added.
  • 1 \lt K \leq 1+sz(N-1): this counts the empty string and the first K-1 strings of S_{N-1} prepended with 0. Since a prepended 0 adds no ones, this is solve(N-1, K-1) (K is reduced by one for the empty string).
  • 1+sz(N-1) \lt K: the empty string, the whole set S_{N-1} prepended with 0, and part of the set S_{N-1} prepended with 1 are counted.

For the third case, the empty string contributes 0 ones and S_{N-1} prepended with 0 contributes f(N-1) ones. We are left with counting the ones in the first K-(1+sz(N-1)) strings of S_{N-1} prepended with 1.

The leading 1s contribute K-(1+sz(N-1)) ones, and the rest is an identical subproblem, solve(N-1, K-1-sz(N-1)).

Hence, we can write solve(N, K) as
solve(N, K) = \begin{cases} 0 & \quad \text{if } K = 1\\ solve(N-1, K-1) & \quad \text{if } 1 < K \leq 1+sz(N-1) \\ f(N-1) + K' + solve(N-1, K') & \quad \text{otherwise, where } K' = K-(1+sz(N-1)) \end{cases}
Since 1 \leq K \leq sz(N) and sz(1) = 1, the recursion never goes below N = 1. The final answer is K + solve(N, K).
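As a worked check of the recurrence, take N = 3 and K = 4. The first four strings are "1", "10", "100", "101", with 1 + 1 + 1 + 2 = 5 ones. Using sz(1) = 1, sz(2) = 3 and f(1) = 0:

```latex
\begin{aligned}
solve(3, 4) &= solve(2, 3) && \text{since } 1 < 4 \leq 1 + sz(2) = 4 \\
solve(2, 3) &= f(1) + K' + solve(1, K') && \text{since } 3 > 1 + sz(1) = 2, \text{ so } K' = 3 - 2 = 1 \\
            &= 0 + 1 + solve(1, 1) = 1 \\
\text{answer} &= K + solve(3, 4) = 4 + 1 = 5
\end{aligned}
```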

We can precompute f and sz for all 50 values of N.
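Putting the pieces together, a minimal C++ sketch of the recursion, assuming N \leq 50. The names Tables, tables, solveRec and countOnes are hypothetical; solveRec follows the three cases of solve(N, K) directly, and countOnes adds back the K leading ones from Observation 1.

```cpp
#include <array>
#include <cassert>

// Hypothetical tables: sz[n] = |S_n| and f[n] = total ones in S_n (leading 1 removed).
struct Tables {
    std::array<long long, 51> sz{}, f{};
    Tables() {
        sz[1] = 1;  // S_1 = {""}, so f[1] stays 0
        for (int n = 2; n <= 50; ++n) {
            sz[n] = 1 + 2 * sz[n - 1];
            f[n] = 2 * f[n - 1] + sz[n - 1];
        }
    }
};

const Tables& tables() {
    static const Tables t;  // built once, on first use
    return t;
}

// solve(N, K): ones in the first K strings of S_N.
long long solveRec(int N, long long K) {
    const Tables& t = tables();
    assert(1 <= N && N <= 50 && 1 <= K && K <= t.sz[N]);
    if (K == 1) return 0;                                      // only the empty string
    if (K <= 1 + t.sz[N - 1]) return solveRec(N - 1, K - 1);   // inside the 0-block
    long long kRest = K - (1 + t.sz[N - 1]);                   // strings taken from the 1-block
    return t.f[N - 1] + kRest + solveRec(N - 1, kRest);
}

// Full answer: the K leading ones removed by Observation 1, plus solve(N, K).
long long countOnes(int N, long long K) {
    return K + solveRec(N, K);
}
```

With 64-bit integers no modulus is needed: the largest possible answer, for N = 50 and K = 2^{50} - 1, is 50 \cdot 2^{49}.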

General Note
This technique is widely applicable whenever we have a large recursively defined sequence and need to compute some arithmetic function over its first K elements. It traces the path from the root (the empty string) to a specific node, such that the nodes on the path divide the first K nodes into complete groups (for this problem, a complete group is a whole copy of S_{N-1}).
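The same path can also be walked with a loop instead of recursion. The sketch below (countOnesIterative is a hypothetical name, assuming N \leq 50) descends one level per iteration: it skips the empty string, and whenever the whole 0-block fits inside the remaining K strings, it adds that complete group at once and moves into the 1-block.

```cpp
#include <array>
#include <cassert>

// Hypothetical iterative variant of solve(N, K), including the K leading ones.
long long countOnesIterative(int N, long long K) {
    assert(1 <= N && N <= 50 && 1 <= K && K < (1LL << N));
    std::array<long long, 51> sz{}, f{};
    sz[1] = 1;  // S_1 = {""}, f[1] = 0
    for (int n = 2; n <= N; ++n) {
        sz[n] = 1 + 2 * sz[n - 1];
        f[n] = 2 * f[n - 1] + sz[n - 1];
    }
    long long ans = K;                  // every string starts with a 1
    for (int n = N; K > 1; --n) {       // K == 1 means only the empty string is left
        K -= 1;                         // skip the empty string at this level
        if (K > sz[n - 1]) {            // the whole 0-block is a complete group
            K -= sz[n - 1];             // K strings remain, all in the 1-block
            ans += f[n - 1] + K;        // ones of the 0-block + their leading 1s
        }                               // otherwise continue inside the 0-block
    }
    return ans;
}
```

This uses O(1) extra stack space instead of O(N) recursion depth, though with N \leq 50 either form is fine.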

TIME COMPLEXITY

Since solve(N, K) makes at most one call to solve(N-1, \cdot), N decreases by one with each call. So the running time satisfies T(N) = T(N-1) + O(1), which is O(N).

Hence, the time complexity is O(N) per test case.

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std ;

#define ll long long

ll z = 1e18 ; // modulus used below; larger than any answer when N <= 50

ll p[62], f[62];


// Adds to ans the number of ones in the first k strings of S_n
// (editorial notation: strings with the leading 1 removed).
void get_ans(ll k, ll n, ll &ans)
{
    if(n == 0)
        return ;

    ll half_size = p[n-1]-1 ;   // sz(n-1) = 2^(n-1) - 1
    k-- ;                       // skip the empty string

    if(k == 0)
        return ;

    if(k <= half_size)
    {
        // all remaining strings lie in the block prefixed with 0
        get_ans(k , n-1 , ans) ;
    }
    else
    {
        // whole 0-block (f[n-2] ones), plus one leading 1 per string taken from the 1-block
        ans += f[n-2] ;
        ans += (k - half_size) ;
        ans %= z ;
        get_ans(k - half_size , n-1 , ans) ;
    }

    return ;
}
 
void solve()
{
    ll n, k ;
    cin >> n >> k ;
    ll ans = 0 ;
    ans += k ;
    ans %= z ;

    get_ans(k, n, ans) ;
    cout << ans << '\n' ;

}


int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
 
    ll t = 1;
    cin >> t ;

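    // N <= 50, so indices up to 50 suffice; going up to 61 would overflow p[i-1]*i
    p[0] = 1 ;
    for(ll i = 1 ; i <= 50 ; i++)
        p[i] = (p[i-1]*2)%z ;

    // f[i] = sum of L * 2^(L-1) for L = 1..i, the total ones in strings of length <= i
    f[0] = 0 ;
    for(ll i = 1 ; i <= 50 ; i++)
    {
        f[i] = (p[i-1]*i)%z ;
        f[i] = (f[i-1] + f[i])%z ;
    }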

    while(t--)
    {
        solve() ;
    }
    
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
 
    return 0;
}
Tester's Solution 1
#include <bits/stdc++.h>
using namespace std;


/*
------------------------Input Checker----------------------------------
*/

long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}


/*
------------------------Main code starts here----------------------------------
*/

const int MAX_T = 3000;
const int MAX_N = 50;
const int MAX_SUM_N = 100000;

#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ll long long int
long long int sum_len=0;

ll mpow(ll a, ll b) {
    ll res = 1;
    while(b) {
        if(b&1) res *= a;
        a *= a;
        b >>= 1;
    }
    return res;
}

void solve()
{
    int n = readIntSp(1, MAX_N);
    sum_len += n;
    assert(sum_len <= MAX_SUM_N);
    long long k = readIntLn(1, (1LL<<n) - 1);
    long long ans = k;
    k -= n;
    for(int i = 1; k > 0; i++) {
        long long int nums = (1LL<<i) - 1;
        if(k < nums) {
            ans += k;
            k -= i;
            i = 0;
            continue;
        }
        ans += (i*(1LL<<(i - 1)));
        k -= nums;
    }
    cout << ans << "\n";
}

signed main()
{
    //fast;
    #ifndef ONLINE_JUDGE
    //freopen("input.txt", "r", stdin);
    //freopen("output.txt", "w", stdout);
    #endif


    int t = readIntLn(1, MAX_T);

    for(int i=1;i<=t;i++)
    {
        solve();
    }

    assert(getchar() == -1);
}
Tester's Solution 2
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;
 
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
 
int sum_n = 0, sum_m = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 998244353;

ll po(ll x, ll n){ 
    ll ans=1;
    while(n>0){ if(n&1) ans=(ans*x)%mod; x=(x*x)%mod; n/=2;}
    return ans;
}


void solve()
{   
    ll n = readIntSp(1, 50);
    sum_n+=n;
    ll k = readIntLn(0, (1ll<<n)-1);

    ll ans = k;

    k-=n;
    ll sz = n-1;
    while(k>0){

        for(ll i=1; i<=sz; i++){
            if(k<((1ll<<i)-1)){
                ans += k;
                sz = i-1;
                k-=i;
                break;
            }

            ans += (i*(1ll<<(i-1)));
            k -= ((1ll<<i)-1);

            //cout<<ans<<" "<<k<<'\n';
        }
    }

    cout<<ans<<'\n';

    
}
 
signed main()
{

    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;
    
    int t = 1;
    
    t = readIntLn(1,3000);
    
    for(int i=1;i<=t;i++)
    {    
       solve();
    }
   
    assert(getchar() == -1);
    assert(sum_n<=1e5);
 
    cerr<<"SUCCESS\n";
    cerr<<"Tests : " << t << '\n';
    cerr<<"Sum of lengths : " << sum_n <<'\n';
    cerr<<"Maximum length : " << max_n <<'\n';
    // cerr<<"Total operations : " << total_ops << '\n';
    //cerr<<"Answered yes : " << yess << '\n';
    //cerr<<"Answered no : " << nos << '\n';
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class COUNTONES{
    //SOLUTION BEGIN
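    // cnt[i] = number of strings in S_{i+1} (leading bit removed), i.e. sz(i+1) = 2^(i+1) - 1
    // f[i]   = total number of ones over all strings of S_{i+1}, i.e. f(i+1)
    long[] f, cnt;
    void pre() throws Exception{
        f = new long[51];
        cnt = new long[51];
        cnt[0] = 1;
        f[0] = 0;
        for(int i = 1; i<= 50; i++){
            cnt[i] = cnt[i-1]*2+1;
            f[i] = f[i-1]*2+cnt[i-1];
        }
    }
    void solve(int TC) throws Exception{
        int N = ni();
        long K = nl();
        // K leading ones (Observation 1) plus the ones in the first K strings of S_N
        pn(K + sum(N-1, K));
    }
    // sum(B, K) computes solve(B+1, K) from the editorial (indices shifted by one)
    long sum(int B, long K) throws Exception{
        if(K == 1)return 0;                      // only the empty string
        if(B == 0)hold(false);                   // unreachable for valid input
        K--;                                     // skip the empty string
        if(K <= cnt[B-1])return sum(B-1, K);     // stays inside the block prefixed with 0
        // whole 0-block, one leading 1 per string taken from the 1-block, then recurse into it
        return f[B-1]+(K-cnt[B-1])+sum(B-1, K-cnt[B-1]);
    }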
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?ni():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        new COUNTONES().run();
    }
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}

    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }

        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }

        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
}

Feel free to share your approach. Suggestions are welcome as always.
