ANDPREF - Editorial

PROBLEM LINK:

Practice
Contest

Author: Ashish Gupta
Tester: Rahul Dugar
Editorialist: Aman Dwivedi

DIFFICULTY:

Medium

PREREQUISITES:

SOS DP, Bitmasking

PROBLEM:

You are given a sequence of N integers. You want to reorder the sequence so that, in the resulting sequence, the sum of the bitwise ANDs of all prefixes is as large as possible.

QUICK EXPLANATION:

Recognizing the type of the problem is the key step. Once we see that it can be solved with SOS-DP (sum over subsets), the rest follows. Prior familiarity with SOS-DP helps a lot here.

  • For each mask x, we calculate how many array elements are supermasks of x.

  • We will use SOS-DP to pre-compute the count of supermasks.

Finally, we go through the masks in decreasing order. For every set bit of a mask, we clear that bit to obtain a child mask and add the child's contribution. We keep the maximum over all ways of reaching that child.

EXPLANATION:

The first basic observation is that the AND of two numbers is at most the minimum of the two numbers.

Why

The AND operation never sets a new bit: a bit that is 0 stays 0, and a bit that is 1 either stays 1 or becomes 0. Hence x & y is a submask of both x and y, so x & y ≤ min(x, y). Equality holds exactly when the smaller number is a submask of the larger one, so in the best case the result is the minimum of the two.

Next, for every mask x we need the number of array elements that are supermasks of x. We precompute all of these counts with SOS-DP.

Code Snippet with comments
// Before the call, freq[v] must hold how many times value v occurs in the array.
// After the call, freq[mask] holds how many array elements are supermasks of mask.

void SOS(){
  // Handle one bit position per pass, from bit 0 up to max_bit - 1.
  for (int bit = 0; bit < max_bit; bit++) {
    const int high = 1 << bit;
    for (int mask = 0; mask < (1 << max_bit); mask++) {
      // Masks that already have this bit set receive nothing in this pass.
      if (mask & high)
        continue;
      // The mask with this bit additionally set is a supermask of `mask`.
      freq[mask] += freq[mask | high];
    }
  }
}

The SOS-DP table freq is initialized with the number of occurrences of each value in the given array. We then process the bit positions one at a time, from bit 0 to max_bit − 1.

Now let's see which numbers can produce a given mask. An AND operation never sets a new bit: a 0 bit remains 0, while a 1 bit either remains 1 or becomes 0. Hence a prefix AND equal to mask can only involve supermasks of mask. These are the numbers that have a 1 in every position where mask has a 1, and either a 0 or a 1 in every position where mask has a 0.

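Hence, whenever mask has a 0 at the current bit, we add the count of the supermask that has this bit set: freq[mask] += freq[mask ^ (1 << bit)]. After all bits have been processed, freq[mask] is the number of array elements that are supermasks of mask.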

How can we obtain the final answer now?

Take a look at the code snippet below first, and then we will walk through it.

Code Snippet with comments
// Computes the final answer from the SOS table.
// Expects freq[mask] = number of array elements that are supermasks of mask
// (the result of SOS()), and dp[] zero-initialised
// (NOTE(review): assumed; the snippet does not show dp's declaration).
// Returns dp[0], the maximum sum of prefix ANDs.
int ans() {
  const int full = (1ll << bits) - 1;
  // Seed the DP. Elements equal to `full` can be placed first, and each of those
  // prefixes has AND == full. Without this seed dp[full] stays 0, and the result
  // is too small whenever the all-ones value occurs in the array.
  dp[full] = freq[full] * full;
  // Visit masks in decreasing order. dp[mask] is only relaxed from masks with
  // one more bit set (which are larger), so it is final when visited.
  for (int mask = full; mask >= 0; mask--) {
    // Try clearing each set bit, from bit 0 (least significant) upward.
    for (int bit = 0; bit < bits; bit++) {
      if (mask & (1 << bit)) {
        // Clearing this bit lets the freq[child] - freq[mask] elements that are
        // supermasks of child (but not of mask) be appended. Each such prefix
        // is credited with the value child.
        int child = mask ^ (1 << bit);
        dp[child] = max(dp[child], dp[mask] + (freq[child] - freq[mask]) * child);
      }
    }
  }
  // Best total once the prefix AND has dropped to 0.
  return dp[0];
}

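Here, dp[mask] is the maximum total of prefix ANDs we can collect while the current prefix AND is mask, so only supermasks of mask have been placed so far. Masks are processed in decreasing order, so dp[mask] is already final when we reach mask. For each set bit of mask, clearing it gives a child mask. The freq[child] − freq[mask] elements that are supermasks of child but not of mask can now be appended, and each of these prefixes contributes at least child. We take the maximum over all ways of reaching child. The DP is seeded with dp[full] = freq[full] · full, where full = 2^bits − 1, and the answer is dp[0].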

You can also check out this video for SOS-DP.

TIME COMPLEXITY:

O(M·2^M) per test case, where M is the number of bits, i.e. the smallest M with 2^M ≥ N.

SOLUTIONS:

Setter
#include <bits/stdc++.h>
using namespace std;
 
// Fast, unsynchronised iostreams (all I/O in this program uses cin/cout).
#define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
// Plain newline instead of std::endl, so output is not flushed after every line.
#define endl "\n"
// Every plain `int` in this program becomes a 64-bit long long.
#define int long long
 
// Capacity of every global table: 2^21 entries.
const int N = 1LL << 21;

int n;         // number of elements in the current test case
int m;         // smallest bit count with 2^m >= n
int a[N];      // input values, stored 1-indexed
int cnt[N];    // counts over complemented values (see main())
int cache[N];  // memo for dp(); -1 means "not computed yet"
 
// Flips each of the lowest m bits of x (complement within m bits).
int get(int x)
{
	const int all_ones = (1 << m) - 1;
	return all_ones ^ x;
}
 
// Memoised search over masks of m bits. The recursion tries every unset bit i
// of `mask`, credits 2^i * cnt[get(mask | 2^i)] and continues from mask | 2^i.
// It returns the best total over all orders of adding the missing bits.
// cache[mask] == -1 means "not computed yet".
int dp(int mask)
{
	if(cache[mask] != -1)
		return cache[mask];
	int best = 0;
	for(int i = 0; i < m; i++)
	{
		const int bit = 1 << i;
		if(mask & bit)
			continue;
		// Recursion only reaches strict supermasks, so cache[mask] is never read
		// before it is written below.
		const int grown = mask | bit;
		best = max(best, cnt[get(grown)] * bit + dp(grown));
	}
	cache[mask] = best;
	return best;
}
 
// Reads t test cases. For each one, builds the supermask counts with SOS-DP over
// complemented values and prints dp(0).
int32_t main()
{
	IOS;
	int t;
	cin >> t;
	while(t--)
	{
		m = 0;
		cin >> n;
		// a[] is 1-indexed up to a[n], so n itself must be a valid index.
		assert(n < N);
		// m = smallest bit count with 2^m >= n; every valid value (< n) fits in m bits.
		while((1 << m) < n)
			m++;
		// cnt[] and cache[] are indexed by masks in [0, 2^m), so 2^m must not exceed N.
		assert((1LL << m) <= N);
		for(int i = 0; i < (1 << m); i++)
			cache[i] = -1, cnt[i] = 0;
		for(int i = 1; i <= n; i++)
		{
			cin >> a[i];
			// Validate before indexing. get() only flips the low m bits, so a
			// negative or too-large value would index cnt[] out of bounds.
			assert(0 <= a[i] && a[i] < n);
			cnt[get(a[i])]++;
		}
		// Subset sums over the complemented values: afterwards cnt[get(S)] is the
		// number of elements that are supermasks of S.
		for(int i = 0; i < m; i++)
		{
			for(int mask = 0; mask < (1 << m); mask++)
			{
				if(mask >> i & 1)
					cnt[mask] += cnt[mask ^ (1 << i)];
			}
		}
		int ans = dp(0);
		cout << ans << endl;
	}
	return 0;
}
Tester
#include <bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace __gnu_cxx;
// Unless `rd` is defined (it also enables the timing report in main()),
// trace(...) expands to nothing and endl is a plain '\n' that does not flush.
#ifndef rd
#define trace(...)
#define endl '\n'
#endif
// Short aliases. Note: `fi` also renames the local variable `fi` inside readInt().
#define pb push_back
#define fi first
#define se second
// Every plain `int` below (including in function signatures) becomes a 64-bit long long.
#define int long long
typedef long long ll;
typedef long double f80;
// Every `double` below (e.g. the cast in main()'s timing report) becomes long double.
#define double long double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define sz(x) ((long long)x.size())
// fr: inclusive loop a = b..c; rep: half-open loop a = b..c-1.
#define fr(a,b,c) for(int a=b; a<=c; a++)
#define rep(a,b,c) for(int a=b; a<c; a++)
#define trav(a,x) for(auto &a:x)
#define all(con) con.begin(),con.end()
// Large sentinel values (every byte 0x3f) and the modulus used by powm().
const ll infl = 0x3f3f3f3f3f3f3f3fLL;
const int infi = 0x3f3f3f3f;
// Alternative modulus, currently unused: 998244353
const int mod = 1000000007;

using vi = vector<int>;
using vl = vector<ll>;

// Order-statistics set from the GNU policy-based data structures.
using oset = tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update>;
// Start time, reported by main() when `rd` is defined.
auto clk = clock();
// 64-bit Mersenne Twister seeded from the high-resolution clock; used by rng().
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
	uniform_int_distribution<int> uid(0,lim-1);
	return uid(rang);
}
// Computes a^b modulo `mod` by binary exponentiation.
// b must be non-negative. a may be any value; it is first reduced into [0, mod).
int powm(int a, int b) {
	// A negative exponent never reaches 0 under the arithmetic shift below,
	// so the loop would never terminate.
	assert(b >= 0);
	// Reducing the base keeps a*a below mod^2 (which fits in the 64-bit `int`
	// used here), and gives a non-negative result for a negative base.
	a %= mod;
	if(a < 0)
		a += mod;
	int res=1;
	while(b > 0) {
		if(b&1)
			res=(res*a)%mod;
		a=(a*a)%mod;
		b>>=1;
	}
	return res;
}
 
// Reads one integer token from stdin and asserts that it is well formed.
// A valid token is an optional single leading '-', then decimal digits with no
// leading zeros (and no "-0"), at most 19 digits, immediately followed by `endd`.
// The value must lie in [l, r]. Any other character, including EOF, fails an assert.
long long readInt(long long l, long long r, char endd) {
	long long x=0;
	int cnt=0;   // number of digits read so far
	int fi=-1;   // first digit, or -1 before any digit has been read
	bool is_neg=false;
	while(true) {
		// int, not char, so that EOF stays distinct from every byte value.
		int g=getchar();
		if(g=='-') {
			// At most one sign, and only before the first digit.
			assert(fi==-1 && !is_neg);
			is_neg=true;
			continue;
		}
		if('0'<=g&&g<='9') {
			x*=10;
			x+=g-'0';
			if(cnt==0) {
				fi=g-'0';
			}
			cnt++;
			// No leading zeros, and no negative zero.
			assert(fi!=0 || cnt==1);
			assert(fi!=0 || is_neg==false);
			// At most 19 digits; a 19-digit value is accepted only if it starts with 1.
			assert(!(cnt>19 || ( cnt==19 && fi>1) ));
		} else if(g==endd) {
			// An empty token ("" or a lone "-") must not silently parse as 0.
			assert(cnt>0);
			if(is_neg) {
				x=-x;
			}
			assert(l<=x&&x<=r);
			return x;
		} else {
			assert(false);
		}
	}
}
// Reads characters from stdin up to (not including) `endd` and returns them.
// Asserts that EOF is not reached first and that the length lies in [l, r].
std::string readString(int l, int r, char endd) {
	std::string ret="";
	int cnt=0;
	while(true) {
		// int, not char. With a plain char, EOF never equals -1 when char is
		// unsigned (infinite loop), and byte 0xFF looks like EOF when char is signed.
		int g=getchar();
		assert(g!=EOF);
		if(g==endd) {
			break;
		}
		cnt++;
		ret+=static_cast<char>(g);
	}
	assert(l<=cnt&&cnt<=r);
	return ret;
}
// Thin wrappers that fix the terminator. The *Sp variants expect the token to be
// followed by a space; the *Ln variants expect it to be followed by a newline.
long long readIntSp(long long l, long long r) { return readInt(l, r, ' '); }
long long readIntLn(long long l, long long r) { return readInt(l, r, '\n'); }
string readStringLn(int l, int r) { return readString(l, r, '\n'); }
string readStringSp(int l, int r) { return readString(l, r, ' '); }
 
 
 
 
// Running total of n over all test cases; solve() asserts it stays <= 2^20.
int sum_n = 0;
// Tables used by solve(): cnt[] holds per-mask counts, dp[] the DP values.
int dp[1 << 20];
int cnt[1 << 20];
void solve() {
	int n=readIntLn(1,(1LL<<20));
	sum_n+=n;
	assert(sum_n<=(1LL<<20));
	int m=0;
	while(n>(1<<m))
		m++;
	memset(cnt,0,sizeof(int)*(1<<m));
	memset(dp,0,sizeof(int)*(1<<m));
	fr(i,1,n) {
		int te;
		if(i!=n)
			te=readIntSp(0,n-1);
		else
			te=readIntLn(0,n-1);
		cnt[te]++;
	}
	for(int j=0; j<m; j++)
		for(int i=0; i<(1<<m); i++)
			if((i>>j)&1)
				cnt[i^(1<<j)]+=cnt[i];
	dp[(1<<m)-1]=cnt[(1<<m)-1]*((1<<m)-1);
	for(int i=(1<<m)-1; i>=0; i--) {
		for(int j=0; j<m; j++) {
			if((i>>j)&1) {
				dp[i^(1<<j)]=max(dp[i^(1<<j)],dp[i]+(cnt[i^(1<<j)]-cnt[i])*(i^(1<<j)));
			}
		}
	}
	cout<<dp[0]<<endl;
}
 
signed main() {
	ios_base::sync_with_stdio(0),cin.tie(0);
	srand(chrono::high_resolution_clock::now().time_since_epoch().count());
	cout<<fixed<<setprecision(10);
	int t=readIntLn(1,1000);
//	int t;
//	cin>>t;
	fr(i,1,t)
		solve();
#ifdef rd
	cerr<<endl<<endl<<endl<<"Time Elapsed: "<<((double)(clock()-clk))/CLOCKS_PER_SEC<<endl;
#endif
}
 
Editorialist
#include<bits/stdc++.h>
using namespace std;
 
// Every plain `int` in the code below (counts, masks, DP values) becomes a 64-bit long long.
#define int long long
 
// Returns the maximum possible sum of prefix ANDs over all orderings of `a`.
// Values must be non-negative. The tables have 2^M entries, where M is the
// smallest bit count with 2^M >= a.size() and 2^M > max(a), so time and memory
// are O(M * 2^M).
static long long max_prefix_and_sum(const std::vector<long long>& a) {
  const long long n = static_cast<long long>(a.size());
  long long max_value = 0;
  for (const long long x : a) {
    assert(x >= 0);
    max_value = std::max(max_value, x);
  }
  // Wide enough for both the problem's bound (values < n) and the actual data,
  // so freq[x] below can never index out of range.
  int bits = 0;
  while ((1LL << bits) < n || (1LL << bits) <= max_value) ++bits;
  const long long full = (1LL << bits) - 1;

  std::vector<long long> freq(full + 1, 0);
  std::vector<long long> dp(full + 1, 0);
  for (const long long x : a) ++freq[x];

  // SOS-DP over supermasks: afterwards freq[mask] = number of elements x with (x & mask) == mask.
  for (int bit = 0; bit < bits; ++bit) {
    for (long long mask = 0; mask <= full; ++mask) {
      if (!(mask >> bit & 1)) freq[mask] += freq[mask | (1LL << bit)];
    }
  }

  // Seed: elements equal to `full` go first, and each of those prefixes has AND == full.
  dp[full] = freq[full] * full;
  // Larger masks first: dp[mask] is only relaxed from masks with one more bit set,
  // so it is final when visited.
  for (long long mask = full; mask >= 0; --mask) {
    for (int bit = 0; bit < bits; ++bit) {
      if (mask >> bit & 1) {
        // Clearing this bit lets the freq[child] - freq[mask] new supermasks of child
        // be appended; each of those prefixes is credited with child.
        const long long child = mask ^ (1LL << bit);
        dp[child] = std::max(dp[child], dp[mask] + (freq[child] - freq[mask]) * child);
      }
    }
  }
  return dp[0];
}

// Reads one test case (n, then n values) and prints the answer.
void solve(){
  int n;
  std::cin >> n;
  std::vector<long long> a(n);
  for (long long& x : a) std::cin >> x;
  std::cout << max_prefix_and_sum(a) << "\n";
}
 
int32_t main(){
  // All input and output goes through cin/cout, so C stdio syncing can be turned off.
  ios_base::sync_with_stdio(false);
  cin.tie(nullptr);

  int t;
  cin >> t;
  // One call to solve() per test case.
  while (t--) solve();

  return 0;
}

VIDEO EDITORIAL:

4 Likes