MAXAND18 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2

Setter: SHUBHAM KHANDELWAL
Tester: Trung Nguyen
Editorialist: Taranpreet Singh

DIFFICULTY

Simple

PREREQUISITES

Greedy

PROBLEM

Given a sequence A of N integers, you must choose an integer X such that X has exactly K bits set and \displaystyle\sum_{i = 1}^N A_i \wedge X is maximized where \wedge denotes bitwise AND. Among all such X, you must find the minimum possible value of X.

QUICK EXPLANATION

  • We can write the above expression in terms of the contribution of each bit individually. If there are C_b integers with b-th bit on, b-th bit can contribute C_b*2^b to the answer.
  • We find contribution for each bit and sort, first in decreasing order of value and then by increasing order of the bits (to minimize X) and select first K bits.

EXPLANATION

Let’s notice how AND function behaves. Suppose X has x-th and y-th bit ON and rest all bits OFF. We can see that we can write the sum as \displaystyle\sum_{i = 1}^N A_i \wedge 2^x + \sum_{i = 1}^N A_i \wedge 2^y since AND operation works independently for each bit.

This allows us to separate the contribution of each bit into the final sum for some X. Formally, if some bit b is set in X, we add 2^b exactly C_b times, where C_b is the number of integers in A having b-th bit ON.

Hence, for each bit, we can calculate C_b beforehand. Now b-th bit, if set in X, contributes 2^b*C_b to final sum. Hence, we need to select K bits among these, to maximize sum and then minimize X for maximum sum.

Hence, we can sort these (bit, gain) pairs, first in non-increasing order of gain and then in increasing order of bit (to minimize X) and choose first K bits in this sorting. This allows us to reconstruct optimum value of X.

Bonus:
Maximize \displaystyle\sum_{i = 1}^N A_i \vee X where vee denotes bitwise OR, by choosing some X with K bits set.

TIME COMPLEXITY

The time complexity is O(N*B) where B = 30 denotes the number of bits.

SOLUTIONS

Setter's Solution
#include<bits/stdc++.h>
using namespace std;
#define FIO ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0)
#define mod 1000000007
#define test ll tx; cin>>tx; while(tx--)
typedef long long int ll;
int main() {
	FIO;
	ll power[31];
	ll i;
	power[0]=1;
	for(i=1;i<=30;i++){
	    power[i]=2*power[i-1];
	}
	test{
	    ll n,k,x,i,j;
	    cin>>n>>k;
	    ll a[n];
	    for(i=0;i<n;i++){
	        cin>>a[i];
	    }
	    ll count[31];
	    for(i=0;i<31;i++){
	        count[i]=0;
	    }
	    for(i=0;i<n;i++){
	        x=a[i];
	        j=0;
	        while(x!=0){
	            if(x%2==1){
	                count[j]++;
	            }
	            x/=2;
	            j++;
	        }
	    }
	    vector<pair<ll,ll>>res;
	    x=1;
	    for(i=0;i<=30;i++){
	        res.push_back(make_pair(count[i]*x,50-i));
	        x*=2;
	    }
	    sort(res.begin(),res.end());
	    ll ans=0;
	    for(i=30;i>30-k;i--){
	        ans+=power[50-res[i].second];
	    }
	    cout<<ans<<endl;
	}
	return 0;
}
Tester's Solution
#include <bits/stdc++.h>
using namespace std;
//#pragma GCC optimize("Ofast")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
 
#define ms(s, n) memset(s, n, sizeof(s))
#define FOR(i, a, b) for (int i = (a); i < (b); ++i)
#define FORd(i, a, b) for (int i = (a) - 1; i >= (b); --i)
#define FORall(it, a) for (__typeof((a).begin()) it = (a).begin(); it != (a).end(); it++)
#define sz(a) int((a).size())
#define present(t, x) (t.find(x) != t.end())
#define all(a) (a).begin(), (a).end()
#define uni(a) (a).erase(unique(all(a)), (a).end())
#define pb push_back
#define pf push_front
#define mp make_pair
#define fi first
#define se second
#define prec(n) fixed<<setprecision(n)
#define bit(n, i) (((n) >> (i)) & 1)
#define bitcount(n) __builtin_popcountll(n)
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pi;
typedef vector<int> vi;
typedef vector<pi> vii;
const int MOD = (int) 1e9 + 7;
const int FFTMOD = 119 << 23 | 1;
const int INF = (int) 1e9 + 23111992;
const ll LINF = (ll) 1e18 + 23111992;
const ld PI = acos((ld) -1);
const ld EPS = 1e-9;
inline ll gcd(ll a, ll b) {ll r; while (b) {r = a % b; a = b; b = r;} return a;}
inline ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
inline ll fpow(ll n, ll k, int p = MOD) {ll r = 1; for (; k; k >>= 1) {if (k & 1) r = r * n % p; n = n * n % p;} return r;}
template<class T> inline int chkmin(T& a, const T& val) {return val < a ? a = val, 1 : 0;}
template<class T> inline int chkmax(T& a, const T& val) {return a < val ? a = val, 1 : 0;}
inline ull isqrt(ull k) {ull r = sqrt(k) + 1; while (r * r > k) r--; return r;}
inline ll icbrt(ll k) {ll r = cbrt(k) + 1; while (r * r * r > k) r--; return r;}
inline void addmod(int& a, int val, int p = MOD) {if ((a = (a + val)) >= p) a -= p;}
inline void submod(int& a, int val, int p = MOD) {if ((a = (a - val)) < 0) a += p;}
inline int mult(int a, int b, int p = MOD) {return (ll) a * b % p;}
inline int inv(int a, int p = MOD) {return fpow(a, p - 2, p);}
inline int sign(ld x) {return x < -EPS ? -1 : x > +EPS;}
inline int sign(ld x, ld y) {return sign(x - y);}
mt19937 mt(chrono::high_resolution_clock::now().time_since_epoch().count());
inline int mrand() {return abs((int) mt());}
inline int mrand(int k) {return abs((int) mt()) % k;}
#define db(x) cerr << "[" << #x << ": " << (x) << "] ";
#define endln cerr << "\n";

void chemthan() {
	int test; cin >> test;
	assert(1 <= test && test <= 1e3);
	while (test--) {
	    int n, k; cin >> n >> k;
	    assert(1 <= n && n <= 1e5);
	    assert(1 <= k && k <= 30);
	    vi a(n);
	    vi f(30);
	    FOR(i, 0, n) {
	        cin >> a[i];
	        assert(1 <= a[i] && a[i] <= 1e9);
	        FOR(j, 0, 30) {
	            f[j] += bit(a[i], j);
	        }
	    }
	    vector<pair<long long, int>> vals;
	    FOR(j, 0, 30) {
	        vals.pb({f[j] * (1LL << j), -j});
	    }
	    sort(all(vals)), reverse(all(vals));
	    vals.resize(k);
	    int res = 0;
	    FOR(i, 0, k) res |= 1 << -vals[i].se;
	    cout << res << "\n";
	}
}

int main(int argc, char* argv[]) {
	ios_base::sync_with_stdio(0), cin.tie(0);
	if (argc > 1) {
	    assert(freopen(argv[1], "r", stdin));
	}
	if (argc > 2) {
	    assert(freopen(argv[2], "wb", stdout));
	}
	chemthan();
	cerr << "\nTime elapsed: " << 1000 * clock() / CLOCKS_PER_SEC << "ms\n";
	return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class MAXAND18{
	//SOLUTION BEGIN
	void pre() throws Exception{}
	void solve(int TC) throws Exception{
	    int N = ni(), K = ni(), B = 30;
	    int[] A = new int[N];
	    for(int i = 0; i< N; i++)A[i] = ni();
	    //Computing gain for each bit
	    long[][] gain = new long[B][2];
	    for(int b = 0; b< B; b++)gain[b] = new long[]{b, 0L};
	    for(int i = 0; i< N; i++)
	        for(int b = 0; b< B; b++)
	            gain[b][1] += A[i]&(1L<<b);
	    Arrays.sort(gain, (long[] l1, long[] l2) -> {
	        if(l1[1] != l2[1])return Long.compare(l2[1], l1[1]);//Sorting in non-increasing order of gain
	        return Long.compare(l1[0], l2[0]);//Choosing smallest bit among all bits having same gain
	    });
	    long X = 0;
	    for(int b = 0; b< K; b++)X |= 1<<gain[b][0];
	    pn(X);
	}
	//SOLUTION END
	void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
	static boolean multipleTC = true;
	FastReader in;PrintWriter out;
	void run() throws Exception{
	    in = new FastReader();
	    out = new PrintWriter(System.out);
	    //Solution Credits: Taranpreet Singh
	    int T = (multipleTC)?ni():1;
	    pre();for(int t = 1; t<= T; t++)solve(t);
	    out.flush();
	    out.close();
	}
	public static void main(String[] args) throws Exception{
	    new MAXAND18().run();
	}
	int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
	void p(Object o){out.print(o);}
	void pn(Object o){out.println(o);}
	void pni(Object o){out.println(o);out.flush();}
	String n()throws Exception{return in.next();}
	String nln()throws Exception{return in.nextLine();}
	int ni()throws Exception{return Integer.parseInt(in.next());}
	long nl()throws Exception{return Long.parseLong(in.next());}
	double nd()throws Exception{return Double.parseDouble(in.next());}

	class FastReader{
	    BufferedReader br;
	    StringTokenizer st;
	    public FastReader(){
	        br = new BufferedReader(new InputStreamReader(System.in));
	    }

	    public FastReader(String s) throws Exception{
	        br = new BufferedReader(new FileReader(s));
	    }

	    String next() throws Exception{
	        while (st == null || !st.hasMoreElements()){
	            try{
	                st = new StringTokenizer(br.readLine());
	            }catch (IOException  e){
	                throw new Exception(e.toString());
	            }
	        }
	        return st.nextToken();
	    }

	    String nextLine() throws Exception{
	        String str = "";
	        try{   
	            str = br.readLine();
	        }catch (IOException e){
	            throw new Exception(e.toString());
	        }  
	        return str;
	    }
	}
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:

16 Likes
#include <iostream>
using namespace std;
unsigned int countSetBits(int n)
    {
        unsigned int count = 0;
        while (n) {
            n &= (n - 1);
            count++;
        }
        return count;
    }
    int fun(unsigned int n)
{
    return n & (n - 1);
}
int main()
{
    long int t;
    cin>>t;
    while(t--)
    {
        long long int n,i;
        cin>>n;
        long long int a[n];
        long long int k;
        cin>>k;

        for(i=0;i<n;i++)
        {
            cin>>a[i];

        }
        long long int p=a[0];
        for(i=1;i<n;i++)
        {
            p=(p|a[i]);
        }
        long long int bits = countSetBits(p);
        long long int val=bits-k ;
        if(val<0)
        {
            val=val*(-1);
            while(val--)
            {
                p=p<<1;
                p=p+1;
            }
        }
        else
        while(val--)
        {
           p=fun(p);
        }
        long long int sum=0;
        for(i=0;i<n;i++)
        {

            sum=sum+(p&a[i]);
        }
        cout<<p<<endl;

    }



    return 0; }

why is this wrong?

1 Like

@taran_1407 @dean_student

#include <bits/stdc++.h>
#include <fstream>
#define vi vector<int>
#define pi pair<int, int>
#define vs vector<string>
#define mp make_pair
#define mi map<int, int>
#define ull unsigned long long int
#define fo(i, k, n) for(int i = k; i < n; i++)
#define pqi priority_queue<int>
#define ll long long int
#define pb push_back
#define mod 1000000007
#define INF 1e18
#define f first
#define s second

using namespace std;
bool sortcmp(pi a, pi b){
    if(a.f == b.f){
        return a.s < b.s;
    }
    return a.f > b.f;
}
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);
	int t;
	cin >> t;
	while(t--){
        int n, k;
        cin >> n >> k;
        int a[n];
        for(int i = 0; i < n; i++) cin >> a[i];
        vi bcount(31, 0);
        for(int i = 0; i < n; i++){
            for(int j = 0; j <= 30; j++){
                if(a[i] & (1 << j)) bcount[j]++;
            }
        }
        vector<pi> cz(31);
        for(int i = 0; i <= 30; i++){
            cz[i] = mp(pow(2, i)*bcount[i], i);
        }
        sort(cz.begin(), cz.end(), sortcmp);    
        int ans = 0;
        for(int i = 0; i < k; i++){
            ans = ans|(1 << (cz[i].s));
        }
        cout << ans << "\n";
	}
	return 0;
}

What is the error in the above solution which makes it fail the first sub-task?

Video explanation

5 Likes

Can anyone tell me where i am going wrong ?

t = int(input())
for _ in range(t):
    sums = [0] * 35
    n,l = map(int,input().split())
    a = list(map(int,input().split()))
    for i in range(n):
        ele = a[i]
        binrep = bin(ele)[2:]
        binrep = binrep[::-1]
        length = len(binrep)
        mul = 1
        for j in range(length):
            ch = binrep[j]
            if ch == '1':
                sums[j]+=mul
            mul*=2
    unique = set()
    freq = {}
    for i in range(31):
        if sums[i]!=0:
            unique.add(sums[i])
            if freq.get(sums[i],[]) == []:
                freq[sums[i]] = [i]
            else:
                freq[sums[i]].append(i)
                
#     print(freq)

    unique = sorted(list(unique), reverse=True)
    answer = []
    for i in unique:
        count = freq[i]
        if len(count) >= l:
            count.sort()
            for i in range(l):
                answer.append(count[i])
                
        else :
            for i in count:
                answer.append(i)
        l-=len(count)
        if(l < 1):
            break
    binary_answer = ['0'] * 31
    for i in answer:
        binary_answer[i] = '1'
    result = "".join(binary_answer)
    result = result[::-1]
    result = int(result,2)
    print(result)

I passed the first subtask, but other subtasks show wrong answer.

Hey,
Tried your solution using long long int it worked fine

Is there a specific case where x exceeds 2^31 - 1?

Can anyone tell, what’s wrong with my code. Unable to pass subtask #1 but passed subtask #2.

#include <bits/stdc++.h>
using namespace std;

bool compare(pair<int,int>p1, pair<int,int>p2){
if(p1.first == p2.first){
return p1.second>p2.second;
}
else return p1.first<p2.first;
}

int main(){
int t;
cin>>t;
while(t–){
int n,k;
cin>>n>>k;
int arr[n];
for(int i=0;i<n;i++){
cin>>arr[i];
}
int bits[32]={0};
for(int i=0;i<n;i++){
int temp = arr[i];
int count=0;
while(temp){
if(temp&1) bits[count]++;
temp = temp>>1;
count++;
}
}
vector<pair<int,int>>vec;
for(int i=0;i<=31;i++){
vec.push_back(make_pair(pow(2,i)*bits[i],i));
}
sort(vec.begin(),vec.end(),compare);
int tot = vec.size();
int ans = 0;
for(int i=0;i<k;i++){
ans += pow(2,vec[tot-i-1].second);
}
cout<<ans<<’\n’;
}
return 0;
}

I thought of using long long int but then didn’t use it as I could not think of when it would exceed the limit of int.

P.S. Thanks a lot

That was completely random test case where your solution failed always use long long

https://www.codechef.com/viewsolution/34813628
please help in figuring out why its giving wrong ans

#include <bits/stdc++.h>
#include <fstream>
#define vi vector<int>
#define pi pair<int, int>
#define vs vector<string>
#define mp make_pair
#define mi map<int, int>
#define ull unsigned long long int
#define fo(i, k, n) for(int i = k; i < n; i++)
#define pqi priority_queue<int>
#define ll long long int
#define pb push_back
#define mod 1000000007
#define INF 1e18
#define f first
#define s second

using namespace std;
bool sortcmp(pi a, pi b){
    if(a.f == b.f){
        return a.s < b.s;
    }
    return a.f > b.f;
}
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);
	int t;
	cin >> t;
	while(t--){
        int n, k;
        cin >> n >> k;
        int a[n];
        for(int i = 0; i < n; i++) cin >> a[i];
        vi bcount(31, 0);
        for(int i = 0; i < n; i++){
            for(int j = 0; j <= 30; j++){
                if(a[i] & (1 << j)) bcount[j]++;
            }
        }
        vector<pair<ll, ll>> cz(31);
        for(int i = 0; i <= 30; i++){
            cz[i] = mp(pow(2, i)*(1LL)*bcount[i], i);
        }
        sort(cz.begin(), cz.end(), sortcmp);
        ll ans = 0;
        for(int i = 0; i < k; i++){
            ans = ans|(1LL << (cz[i].s));
        }
        cout << ans << "\n";
	}
	return 0;
}

It still does not work

For those guys who are not passing subtask-1 only Use long long int

1 Like

https://www.codechef.com/viewsolution/34822788
can someone tell me whats wrong?

Hi guys try this video solution
Hope it helps :raised_hands::raised_hands:

1 Like

i think there’s no garuantee that map will give smallest x ,for eg let’s say two values {1,2} and {1,3}
it isnt necessary that {1,2} will always come before {1,3} in map trasversal

The sortcmp function takes care of that

I just figured out why this version failed.
I didn’t change my sorting function.

Thanks for the help.

I have just pasted the main part of my code
Please Tell me where i am wrong.

  • First gone thro all elements and inreased the count of set bits at each position
  • Then Sorted (position, value_of_set_bit(count_of_set_bit*(1<<position))) in descending order keeping in mind that if same value_of_set_bit then smaller position comes first (for minimum value of x)
  • then took the k bits with the most value.
bool f(pair<int, int> p1, pair<int, int> p2) {
    if (p1.second == p2.second) {
        return p1.first < p2.first;
    }
    return p1.second > p2.second;
}

void solve() {
    int i, j, n, k;
    cin >> n >> k;
    int a[n];
    map<int, int> hash;
    for (i = 0; i < n; ++i) {
        cin >> a[i];
        int temp = a[i];
        j = 0;
        while (temp) {
            if (temp & 1) {
                ++hash[j];
            }
            temp >>= 1;
            ++j;
        }
    }
    vector<pair<int, int>> v;
    for (auto pr : hash) {
        v.pb(make_pair(pr.first, (pr.second * (1 << pr.first))));
    }
    sort(v.begin(), v.end(), f);
    int res = 0;
    for (i = 0; i < k; ++i) {
        res += (1 << (v[i].first));
    }
    cout << res << endl;
}

Hey see my code almost same please help me

bool f(pair<int, int> p1, pair<int, int> p2) {
    if (p1.second == p2.second) {
        return p1.first < p2.first;
    }
    return p1.second > p2.second;
}

void solve() {
    int i, j, n, k;
    cin >> n >> k;
    int a[n];
    map<int, int> hash;
    for (i = 0; i < n; ++i) {
        cin >> a[i];
        int temp = a[i];
        j = 0;
        while (temp) {
            if (temp & 1) {
                ++hash[j];
            }
            temp >>= 1;
            ++j;
        }
    }
    vector<pair<int, int>> v;
    for (auto pr : hash) {
        v.pb(make_pair(pr.first, (pr.second * (1 << pr.first))));
    }
    sort(all(v), f);
    int res = 0;
    for (i = 0; i < k; ++i) {
        res += (1 << (v[i].first));
    }
    cout << res << endl;
}