BOWSAFE - Editorial

Problem statement
Contest source

Author : Sumit Kumar Sahu
Editorialist : Sumit Kumar Sahu
Tester : Miten Shah

DIFFICULTY:

Medium

PREREQUISITES:

DP, Matrix Exponentiation, Graphs

PROBLEM:

You are given three integers L, R and K. Write every integer between L and R (both endpoints included) in binary using exactly K bits. The bits of each of these numbers may be rearranged (permuted) independently. Count the number of ways to do so such that the bitwise XOR of all the resulting numbers is maximum. (Both reference solutions print this count modulo 10^9+7.)

EXPLANATION:

Sub Problem 1: For every i (1 \leq i \leq K), find the number of integers between 1 and R (both inclusive) that have exactly i set bits.

Consider R = 1011011 (In binary form) and K=7.

First, consider all numbers of the form 0_ _ _ _ _ _ (all of them are less than R). The number of such numbers with exactly i set bits is \binom{6}{i}.

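Next, consider all numbers of the form 100_ _ _ _. Here we keep the leading 1 of R, replace the next set bit of R (counting from the left) with 0, and leave the remaining 4 bits free. One set bit is already fixed, so the number of such numbers with exactly i set bits is \binom{4}{i-1}.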

Then consider all numbers of the form 1010_ _ _, where again the next set bit of R is replaced by 0. Two set bits are fixed, so the number of such numbers with exactly i set bits is \binom{3}{i-2}.

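Continuing in the same way, the form 101100_ contributes \binom{1}{i-3} numbers with i set bits. The form 1011010 contributes exactly one number, which has 4 set bits. Finally, R = 1011011 itself (5 set bits) must be counted separately, since the range includes R.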

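Sub Problem 2: Given two integers K and M, build a (K+1) \times (K+1) matrix T_M. The entry T_M[i][j] is the number of ways to obtain one fixed K-bit number Z with j set bits as X \oplus Y, where X has i set bits and Y is an arrangement of a number with M set bits. If j is reachable, let t = (i+M-j)/2 be the number of positions where both X and Y are 1. Then T_M[i][j] = \binom{K-j}{t}\binom{j}{i-t}: choose which of the K-j zero positions of Z are (1,1), and which of the j one positions of Z come from X. Otherwise T_M[i][j] = 0.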

Consider a number with a set bits and a number with b set bits, where a \leq b. Since the bits of both numbers can be rearranged, the maximum number of set bits in their XOR is:

  • a+b, if (a+b) \leq K
  • 2*K-(a+b), if (a+b)>K

The minimum number of set bits in their XOR is b-a.

Moreover, the achievable counts go from the minimum to the maximum in steps of 2: b-a, b-a+2, \ldots, up to the maximum. Hence, for each i, we loop j (the number of set bits of the result) from 0 to K. We fill T_M[i][j] only for the values of j that satisfy these conditions.

Main Logic of the problem :

Transition matrix T_M represents the number of ways to transform a number with i set bits into a number with j set bits using a number having M set bits (sub-problem 2).

Let’s say we have X numbers with M set bits.

Then the number of ways to go from a number with i set bits to a fixed number with j set bits using these X numbers is T_M^{X}[i][j]. Here T_M^X is the X-th matrix power, computed by fast (binary) exponentiation. This is analogous to counting the paths from node i to node j that use exactly X edges.

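Apply Sub Problem 1 to R and to L-1 and subtract the results. This gives, for every 1 \leq i \leq K, the number of integers in [L, R] with exactly i set bits. Store these counts in a vector NUM of size K+1. They are used as exponents, so they must be kept exact rather than reduced modulo 10^9+7.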

Then the transition matrix for all numbers from L to R together is T = T_1^{NUM[1]} \cdot T_2^{NUM[2]} \cdots T_{K-1}^{NUM[K-1]} \cdot T_K^{NUM[K]}. Its entry T[i][j] is the number of ways to go from a number with i set bits to a fixed number with j set bits using all numbers from L to R.

The final answer is T[0][x], where x is the largest index (0 \leq x \leq K) such that T[0][x] is non-zero. Row 0 is used because the XOR of zero numbers is 0.

SOLUTION :

c++ Solution (Setter's)
// created by phantom654
#include <bits/stdc++.h>
using namespace std;
#define fastio                                                                 \
        ios_base::sync_with_stdio(false);                                      \
        cin.tie(NULL);                                                         \
        cout.tie(NULL)
// Shorthand for the 64-bit integer type used throughout the solution.
#define ll long long
// Matrix of residues stored as rows of ll; the transition matrices built in
// this solution are (K+1) x (K+1).
#define matrix vector<vector<ll>>
// Prime modulus for all counts (1e9 + 7); the double literal converts exactly to ll.
const ll mod = 1e9 + 7;

// Modular exponentiation: returns a^b mod `mod` by iterative square-and-multiply
// (O(log b) multiplications). b is expected to be non-negative; b <= 0 yields 1.
// The base is first normalized into [0, mod) so that every 64-bit product below
// stays under (mod-1)^2 and cannot overflow, even for large or negative a.
ll bin(ll a, ll b) {
        a %= mod;
        if (a < 0)
                a += mod;
        ll result = 1;
        while (b > 0) {
                if (b & 1)
                        result = (result * a) % mod;
                a = (a * a) % mod;
                b >>= 1;
        }
        return result;
}

// fact[i] = i! mod `mod` for 0 <= i < 40 (filled in main()).
ll fact[40];

// Binomial coefficient C(n, r) mod `mod`, using Fermat's little theorem
// (mod is prime) to invert r! * (n-r)! with a single exponentiation.
// Returns 0 when r < 0 or r > n instead of reading fact[] out of bounds.
// Requires 0 <= n < 40.
ll NCR(ll n, ll r) {
        if (r < 0 || r > n)
                return 0;
        const ll denominator = (fact[r] * fact[n - r]) % mod;
        return (fact[n] * bin(denominator, mod - 2)) % mod;
}

// Counts, for every c in [0, k], how many integers in [1, r] have exactly c set
// bits; ans[0] is always 0. Returns a vector of size k + 1.
// Method: walk the set bits of r from the most significant one down. With the
// set bits above the current one (`prefixBits` of them) kept as in r, the
// current bit (position `top`) cleared and the `top` lower bits free, C(top, j)
// numbers below r have prefixBits + j set bits (j >= 1). The number made of the
// prefix plus the current bit alone is counted separately; on the last step
// that number is r itself.
// Precondition: 0 <= r < 2^k (otherwise ans would be indexed out of range).
std::vector<long long> findNumberofNumsOfBits(long long r, long long k) {
        std::vector<long long> ans(k + 1, 0);
        int prefixBits = 0;  // set bits of the input above the current top bit
        while (r > 0) {
                // Exact integer position of the highest set bit. log2() on a
                // double may round up to p for large values just below 2^p.
                int top = 0;
                while ((r >> (top + 1)) != 0)
                        ++top;
                // Binomials C(top, j) built incrementally; each division is exact.
                long long binom = 1;  // C(top, 0)
                for (int j = 1; j <= top; ++j) {
                        binom = binom * (top - j + 1) / j;
                        ans[prefixBits + j] += binom;
                }
                // Prefix plus bit `top`, all lower bits zero.
                ans[prefixBits + 1] += 1;
                ++prefixBits;
                // 1LL: an int shift is undefined for top >= 31.
                r -= 1LL << top;
        }
        return ans;
}

// Builds the (k+1) x (k+1) transition matrix T_m for XOR-ing in a k-bit number
// with m set bits. dp[i][j] = NCR(#(1,1) + #(0,0), #(1,1)) *
// NCR(#(1,0) + #(0,1), #(1,0)) mod `mod`, i.e. the number of pairs (X, Y) with
// popcount(X) = i and popcount(Y) = m whose XOR is one fixed number with j set
// bits. Unreachable j stay 0.
matrix iTojUsingM(int k, int m) {
        matrix dp(k + 1, vector<ll>(k + 1, 0));
        for (int i = 0; i <= k; i++) {
                const int big = max(i, m);
                const int small = min(i, m);
                const int lowest = big - small;  // fewest set bits the XOR can have
                int highest = big + small;       // most set bits, before folding at k
                if (highest > k)
                        highest = 2 * k - highest;
                highest = max(highest, lowest);

                // Reachable results share the parity of `lowest`.
                for (int j = lowest; j <= highest; j += 2) {
                        const int shift = (j - lowest) / 2;
                        const int both = small - shift;       // 1 in both numbers
                        const int neither = k - big - shift;  // 0 in both numbers
                        const int onlyBig = lowest + shift;   // 1 only in the bigger-popcount number
                        const int onlySmall = shift;          // 1 only in the smaller-popcount number
                        ll ways = NCR(both + neither, both);
                        ways = (ways * NCR(onlyBig + onlySmall, onlyBig)) % mod;
                        dp[i][j] = ways;
                }
        }
        return dp;
}

// Returns the product A * B with every entry reduced mod `mod`.
// Works for any conformable shapes (A is n x m, B is m x p); for the square
// matrices used in this solution it is the ordinary product. Operands are taken
// by const reference, so neither is modified and temporaries may be passed.
// The i-k-j loop order walks B and res row by row, which is cache-friendly.
matrix mul(const matrix &A, const matrix &B) {
        const size_t rows = A.size();
        const size_t inner = B.size();
        const size_t cols = B.empty() ? 0 : B[0].size();
        matrix res(rows, vector<ll>(cols, 0));
        for (size_t i = 0; i < rows; i++) {
                for (size_t k = 0; k < inner; k++) {
                        const ll aik = A[i][k];
                        if (aik == 0)
                                continue;
                        // Entries are below mod, so res + aik * B < 2^63.
                        for (size_t j = 0; j < cols; j++)
                                res[i][j] = (res[i][j] + aik * B[k][j]) % mod;
                }
        }
        return res;
}

// Returns A^n (mod `mod`) by binary exponentiation; n <= 0 yields the identity.
// A is taken by value so the caller's matrix is left untouched (squaring it in
// place through a reference would silently clobber the argument).
matrix matExpo(matrix A, long long n) {
        const size_t dimension = A.size();
        matrix res(dimension, vector<ll>(dimension, 0));
        for (size_t i = 0; i < dimension; i++)
                res[i][i] = 1;
        while (n > 0) {
                if (n & 1)
                        res = mul(res, A);  // powers of A commute, so order is irrelevant
                n >>= 1;
                if (n > 0)
                        A = mul(A, A);  // skip the final squaring, which would be unused
        }
        return res;
}

void I_m_Beast() {
        ll l, r, k;
        cin >> l >> r >> k;

        assert(2<= l);
        assert(l< r);
        assert(r< (1<<k));
        
        // get the number of numbers having i setbits
        vector<ll> ans = findNumberofNumsOfBits(r, k);
        vector<ll> ans2 = findNumberofNumsOfBits(l - 1, k);
        for (int i = 0; i < ans.size(); i++) {
                ans[i] = (ans[i] - ans2[i] + mod) % mod;
        }

        // identity matirx
        matrix res(k + 1, vector<ll>(k + 1, 0));
        for (int i = 0; i <= k; i++) {
                res[i][i] = 1;
        }

        for (auto i = 0; i <= k; i++) {
                matrix temp = iTojUsingM(k, i);
                temp = matExpo(temp, ans[i]);
                res = mul(res, temp);
        }

        for (int i = k; i >= 0; i--) {
                if (res[0][i] > 0) {
                        cout << res[0][i] << endl;
                        break;
                }
        }
}

// Entry point: reads the number of test cases, precomputes factorials for
// NCR(), and solves each case. (The unused srand(time(NULL)) call was dropped:
// nothing in this program draws random numbers.)
int main() {
        fastio;
        int t = 1;
        cin >> t;
        // fact[i] = i! mod `mod`.
        fact[0] = 1;
        for (int i = 1; i < 40; i++)
                fact[i] = (fact[i - 1] * i) % mod;
        while (t--)
                I_m_Beast();
        return 0;
}

c++ Solution (Tester's)
// created by mtnshh

#include<bits/stdc++.h>
using namespace std;
#define ll long long int
#define pb push_back
#define rb pop_back
#define ti tuple<int, int, int>
#define pii pair<int, int>
#define pli pair<ll, int>
#define pll pair<ll, ll>
#define mp make_pair
#define mt make_tuple
 
#define rep(i,a,b) for(ll i=a;i<b;i++)
#define repb(i,a,b) for(ll i=a;i>=b;i--)
 
#define err() cout<<"--------------------------"<<endl; 
#define errA(A) for(auto i:A)   cout<<i<<" ";cout<<endl;
#define err1(a) cout<<#a<<" "<<a<<endl
#define err2(a,b) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<endl
#define err3(a,b,c) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<endl
#define err4(a,b,c,d) cout<<#a<<" "<<a<<" "<<#b<<" "<<b<<" "<<#c<<" "<<c<<" "<<#d<<" "<<d<<endl

#define all(A)  A.begin(),A.end()
#define allr(A)    A.rbegin(),A.rend()
#define ft first
#define sd second

#define V vector<ll>
#define S set<ll>
#define VV vector<V>
#define Vpll vector<pll>
 
#define endl "\n"

// Reads one integer from stdin that must be terminated by exactly `endd`, and
// asserts a strict format:
// - at most one '-', placed before any digit;
// - at least one digit;
// - no leading zeros and no "-0";
// - at most 19 digits, and a 19-digit value must start with 1 so it fits in
//   long long;
// - l <= value <= r.
// Any other character aborts via assert.
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // digits read so far
    int fi=-1;          // first digit, -1 until one is read
    bool is_neg=false;
    while(true){
        // int, not char: EOF (-1) must stay distinct from the byte 0xFF, and
        // must not depend on whether char is signed.
        int g=getchar();
        if(g=='-'){
            assert(fi==-1 && !is_neg);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==static_cast<unsigned char>(endd)){
            // An empty token (or a lone '-') is malformed, not zero.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }
            assert(l<=x && x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to, but not including, the delimiter `endd`
// (which is consumed). Asserts that EOF is not reached first and that the
// token length lies in [l, r].
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // int, not char: where char is unsigned, a char would never equal -1
        // and the loop would spin forever at end of input.
        int g=getchar();
        assert(g != EOF);
        if(g==static_cast<unsigned char>(endd)){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Reads an integer in [l, r] that must be followed by a single space.
long long readIntSp(long long l,long long r){
    const char delimiter = ' ';
    return readInt(l, r, delimiter);
}
// Reads an integer in [l, r] that must be followed by a newline.
long long readIntLn(long long l,long long r){
    const char delimiter = '\n';
    return readInt(l, r, delimiter);
}
// Reads a string of length in [l, r] terminated by a newline.
string readStringLn(int l,int r){
    const char delimiter = '\n';
    return readString(l, r, delimiter);
}
// Reads a string of length in [l, r] terminated by a single space.
string readStringSp(int l,int r){
    const char delimiter = ' ';
    return readString(l, r, delimiter);
}

// Maximum number of test cases accepted from the input.
const int max_q = 5;

// Size of the factorial tables f[] and invf[].
const ll N = 100005;
// NOTE(review): INF is not referenced anywhere in this solution.
const ll INF = 1e12;
// Prime modulus for all arithmetic (1e9 + 7).
const ll M = 1000000007;

// Matrices using vectors

// Returns (a + b) reduced into [0, M), also for negative sums.
ll add(ll a,ll b){
    ll sum = (a + b) % M;
    if (sum < 0)
        sum += M;
    return sum;
}

// Returns (a * b) reduced into [0, M), also for negative products.
// Assumes |a|, |b| are already below M so that a * b cannot overflow.
ll mult(ll a,ll b){
    ll product = (a * b) % M;
    if (product < 0)
        product += M;
    return product;
}

// Square SZ x SZ matrix of residues mod M, stored row-major in `arr`.
struct matrix{
    ll SZ;
    vector<vector<ll>> arr;

    // Zero matrix of size n x n. `explicit` so that an integer can never be
    // converted into a matrix implicitly (e.g. `res = res * 3` must not compile).
    explicit matrix(ll n){
        SZ=n;
        arr.resize(n,V(n,0));
    }
    // Sets every entry to 0.
    void null(){
        rep(i,0,SZ) rep(j,0,SZ) arr[i][j]=0;
    }
    // Turns this matrix into the identity.
    void iden(){
        null();
        rep(i,0,SZ) arr[i][i]=1;
    }
    // Entry-wise sum mod M; both operands must have the same size.
    matrix operator + (const matrix &o) const{
        matrix res(SZ);
        rep(i,0,SZ) rep(j,0,SZ) res.arr[i][j]=add(arr[i][j],o.arr[i][j]);
        return res;
    }
    // Matrix product mod M; both operands must have the same size.
    matrix operator * (const matrix &o) const{
        matrix res(SZ);
        rep(i,0,SZ) rep(j,0,SZ){
            ll temp=0;
            rep(k,0,SZ) temp = add(temp, mult(arr[i][k],o.arr[k][j]));
            res.arr[i][j]=temp;
        }
        return res;
    }
    // Debug dump: one row per line, followed by a "--" separator line.
    void print() const{
        rep(i,0,SZ) {rep(j,0,SZ)    cout<<arr[i][j]<<" ";cout<<endl;}cout<<"--"<<endl;
    }
};

// Returns a^n (mod M) by binary exponentiation; n <= 0 yields the identity
// (a negative n used to produce an arbitrary power). `a` is a by-value copy,
// so squaring it does not affect the caller.
matrix matexpo(matrix a,ll n){
    matrix res(a.SZ);
    res.iden();
    while(n > 0){
        if(n&1) res = res*a;
        n/=2;
        if(n > 0) a = a*a;   // the final squaring would never be used
    }
    return res;
}

// Returns a^n mod m by square-and-multiply; n <= 0 yields 1 % m.
// The base is normalized into [0, m) first so that a*a cannot overflow for
// large or negative a (requires (m-1)^2 to fit in long long, i.e. m <~ 3e9).
// The loop tests n > 0 because a negative n would never reach 0 under >>= 1.
ll power(ll a,ll n,ll m=M){
    a %= m;
    if(a < 0) a += m;
    ll ans = 1 % m;
    while(n > 0){
        if(n&1) ans = (ans*a) % m;
        a = (a*a) % m;
        n >>= 1;
    }
    return ans;
}

ll f[N], invf[N];

void pre(){
    f[0] = 1;
    rep(i,1,N)  f[i] = (f[i-1] * i) % M;
    rep(i,0,N)  invf[i] = power(f[i], M-2) % M; 
}

// Binomial coefficient C(n, r) mod M from the tables filled by pre().
// Returns 0 for r < 0 or r > n (instead of reading invf[] out of range);
// requires n < N.
ll ncr(ll n, ll r){
    if(r < 0 || r > n) return 0;
    return (f[n] * ((invf[r] * invf[n-r]) % M)) % M;
}

// Returns res of size k+1 where res[c] = number of integers in [0, x] with
// exactly c set bits. Only bits 0..k-1 of x are examined; the value 0 itself
// contributes to res[0].
vector<ll> get(ll x, ll k){
    vector<ll> res(k+1);
    ll ones = 0;  // set bits of x seen so far, scanning from bit k-1 downwards
    for (ll bit = k-1; bit >= 0; --bit) {
        if (!(x & (1LL << bit)))
            continue;
        // Same prefix as x, this bit cleared, the `bit` lower bits free.
        for (ll low = 0; low <= bit; ++low)
            res[ones + low] += ncr(bit, low);
        ++ones;
    }
    // x itself.
    ++res[ones];
    return res;
}

// Transition matrix for XOR-ing in a k-bit number with `a` set bits:
// res.arr[i][j] = number of pairs (X, Y) with popcount(X) = i and
// popcount(Y) = a whose XOR equals one fixed number with j set bits (mod M).
// Unreachable j stay 0.
matrix get_transition_matrix(ll a, ll k){
    matrix res(k+1);
    for (ll i = 0; i <= k; ++i) {
        const ll hi = max(i, a);
        const ll lo = min(i, a);
        const ll first = hi - lo;                     // fewest set bits in the XOR
        const ll last = min(hi + lo, 2*k - hi - lo);  // most set bits in the XOR
        // Bit-position counts for j = first; each step j += 2 turns one (1,1)
        // and one (0,0) position into a (1,0) and a (0,1) position.
        ll both = lo, onlyHi = hi - lo, onlyLo = 0, neither = k - hi;
        for (ll j = first; j <= last; j += 2) {
            ll ways = ncr(onlyHi + onlyLo, onlyLo);
            ways = (ways * ncr(neither + both, neither)) % M;
            res.arr[i][j] = ways;
            --both; ++onlyLo; --neither; ++onlyHi;
        }
    }
    return res;
}

// Reads and validates one test case "L R K", then prints the number of ways
// (mod M) to rearrange the bits so that the XOR of all numbers is maximal.
void solve(){
    ll l = readIntSp(2, INT_MAX), r = readIntSp(2, INT_MAX), k = readIntLn(2, 30);
    // Every number is written with K bits, so it must be at most 2^K - 1.
    // get() only inspects bits 0..K-1 and would silently miscount R = 2^K.
    assert(l < (1LL<<k));
    assert(r < (1LL<<k));
    assert(l<r);
    // v[c] = number of integers in [l, r] with exactly c set bits (exact counts,
    // used as exponents, so they are not reduced mod M).
    vector<ll> v1 = get(r,k), v2 = get(l-1,k);
    vector<ll> v(k+1);
    rep(i,0,k+1)  v[i]=v1[i]-v2[i];
    // res = product over c of T_c^{v[c]}.
    matrix res(k+1);
    res.iden();
    rep(i,0,k+1){
        matrix t = get_transition_matrix(i,k);
        res = (res * matexpo(t, v[i]));
    }
    // Row 0 (the XOR of zero numbers is 0): the largest popcount with a non-zero
    // count is the maximal XOR value.
    // NOTE(review): a genuine count divisible by M would read as 0 here.
    repb(i,k,0){
        if(res.arr[0][i]>0){
            cout << res.arr[0][i] << endl;
            break;
        }
    }
}

int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
    ll q = readIntLn(1, max_q);
    pre();
    while(q--){
        solve();
    }
    assert(getchar()==-1);
}