PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: shanu_singroha
Tester: satyam_343
Editorialist: iceknight1093
DIFFICULTY:
2808
PREREQUISITES:
DSU, familiarity with bitwise operations
PROBLEM:
Given N non-negative integers, find the maximum possible value of f(S_1)\oplus f(S_2), where f(S) denotes the bitwise OR of elements of S, and [S_1, S_2] is a partition of the given integers into two non-empty sets.
EXPLANATION:
Let’s try to build the answer greedily, from the largest bit down to the smallest.
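Let h be the highest bit such that there exists an element with it set, and an element without it set. (If no such bit exists, all the elements are equal, so f(S_1) = f(S_2) for every partition and the answer is 0.)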
Put all elements without it into S_1, and all elements with it into S_2.
Now, we’ve guaranteed that:
- S_1 and S_2 are non-empty
- The answer is at least 2^h
No matter what we do with the lower bits, it’s important to keep bit h set in the answer.
In particular, this means we can only move elements from S_1 into S_2, and not vice versa: all the elements with bit h set must stay together in S_2, since moving some of them into S_1 would set bit h on both sides.
We can now try and set lower bits greedily too.
For each bit b from h-1 down to 0, there are four cases.
Case 1: b is not present in S_1, and not present in S_2.
This means b is not present in the array at all, so we just ignore it.
Case 2: b is not present in S_1, but is present in S_2.
Notice that this means b is already set in the answer (and that won’t change in the future, because we aren’t moving elements from S_2 to S_1).
Once again, we have to do nothing here.
Case 3: b is present in S_1, but is not present in S_2.
Let x_1, x_2, \ldots, x_k be the elements from S_1 that have b set.
Then, if we want b to be present in the answer, either all the x_i should lie in S_1 (which is already the case); or all the x_i should lie in S_2.
In particular, if we move one x_i to S_2, we should then move them all.
However, note that if this results in S_1 becoming empty, we can’t perform the move.
Let’s set this aside for now; we’ll come back to it.
Case 4: b is present in both S_1 and S_2.
Once again, let x_1, \ldots, x_k be the elements from S_1 that have b set.
Now, for b to be set in the answer, we must move all the x_i from S_1 to S_2.
However, now that we’re forced to move things, notice that some elements might have constraints from case 3 via higher bits; so more elements will have to move.
In particular, we’ll need to move anything that’s linked to any of the x_i.
If this movement results in S_1 becoming empty, we can’t do it.
All of this can be nicely represented by maintaining a DSU of the elements.
In particular, keep a DSU of size N:
- Cases 1 and 2 do nothing.
- For case 3, merge the components of all the x_i.
This can be done quickly by uniting x_i and x_{i+1} for all i.
- For case 4, we need to check whether merging all the x_i causes their union to equal the entirety of S_1. If it does, bit b cannot be kept and we leave everything as it was; otherwise, we perform the merge and move the entire merged component into S_2.
This can be done in a variety of ways, for example:
- Compute the set of representatives of the components of the x_i, and check whether the sum of the sizes of these components equals the size of S_1; or
- Directly perform the merges, and if the merged component equals S_1 in the end, roll back the changes.
Either way, it’s possible to check for this quickly, which is all we need.
In the end, we know sets S_1 and S_2; so actually computing the answer is trivial from there.
In total, our worst case is performing \mathcal{O}(30\cdot N) merge operations on the DSU, which should run pretty fast (especially since there are only N vertices, so most of the merge operations won’t actually do anything).
TIME COMPLEXITY
\mathcal{O}(30\cdot N\alpha(N)) per testcase.
CODE:
Author's code (C++)
#include<bits/stdc++.h>
using namespace std;
// ---- Author's shorthand macros ----
// fo/foA/foD loop over an index variable that the caller must already have
// declared (the macros assign to it but do not declare it).
#define fo(i,n) for( i=0;i<n;i++)
#define foA(i,a,b) for(i=a;i<=b;i++)
#define foD(i,a,b) for( i=a;i>=b;i--)
// From here on every `int` token (locals, the DSU arrays, the typedefs below)
// is textually replaced by `long long`; main is therefore declared as
// int32_t so it keeps an int return type.
#define int long long
// deb/deb2/deb3 print "name=value" debug lines to cout, flushing via endl.
#define deb(x) cout << #x << "=" << x << endl
#define deb2(x, y) cout << #x << "=" << x << "," << #y << "=" << y << endl
#define deb3(x, y, z) cout << #x << "=" << x << "," << #y << "=" << y << "," << #z << "=" << z << endl
// Short aliases for common STL calls.
#define pb push_back
#define mp make_pair
#define all(x) x.begin(), x.end()
#define clr(x) memset(x, 0, sizeof(x))
#define sortall(x) sort(all(x))
#define el cout<<"\n"
// Maximum / minimum of three or four values.
#define max3(a,b,c) max(max((a),(b)),(c))
#define max4(a,b,c,d) max(max((a),(b)),max((c),(d)))
#define min3(a,b,c) min(min((a),(b)),(c))
#define min4(a,b,c,d) min(min((a),(b)),min((c),(d)))
/////////////////////
// Grid direction offsets: the first four pairs are orthogonal moves, the
// last four diagonal. Not used by this solution.
int dx[] = {0, 0, -1, 1, 1, 1, -1, -1};
int dy[] = {1, -1, 0, 0, -1, 1, 1, -1};
//////////////////for vectors
// Container helpers: max/min element, sums (integer via 0LL, floating via
// double(0)), and printers. The printers first copy their argument into a
// local `i`, then iterate over that copy.
# define maxv(a) (*max_element(a.begin(),a.end()))
# define minv(a) (*min_element(a.begin(),a.end()))
# define sumvi(a) (accumulate(a.begin(),a.end(),0LL))
# define sumvd(a) (accumulate(a.begin(),a.end(),double(0)))
# define printv(v) {auto i = v;for(auto j : i) cout<< j << ' ';cout << "\n";}
# define printvv(v) {auto i = v;for(auto j : i) {for(auto k : j) cout<< k << ' ';cout << "\n";}}
# define prints(s) {auto i = s;for(auto j : i) cout<< j << ' ';cout << "\n";}
# define printm(m) {auto i = m;for(auto j : i) cout<< j.first << ':' << j.second << ' ';cout << "\n";}
/////////////////////////
// Common aliases; because of `#define int long long` above, pii is
// pair<long long, long long> and vi is vector<long long>.
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<pii> vpii;
typedef vector<vi> vvi;
/////////////////////////
// Time-seeded 64-bit Mersenne Twister; consumed by rng().
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
// Returns a uniformly distributed integer in [0, lim - 1], drawn from the
// global generator `rang`. Requires lim >= 1 (uniform_int_distribution
// needs its lower bound <= upper bound).
int rng(int lim) {
    return uniform_int_distribution<int>(0, lim - 1)(rang);
}
/////////////////////
// Limits and moduli for the author's solution. With `#define int long long`
// in effect these are all long long; only N is referenced by this solution.
const int inf = 1e9;
const int INF = 1e18;
const int mod = 1000000007;
// const int mod = 998244353;
const int N = 3e5 + 5, M = N;

// Global DSU over element indices 1..n, reset per test case by initialize().
// parent[v] == v marks a root; sizeo[r] is a component size and is only
// meaningful at a root. tempparent / temppsizeo hold a snapshot that solve()
// copies back to undo a tentative merge.
int parent[N];
int sizeo[N];
int tempparent[N];
int temppsizeo[N];

// Returns the root of v's component and re-points every node on the path
// from v directly at that root (full path compression).
int findop(int v) {
    int root = v;
    while (parent[root] != root)
        root = parent[root];
    while (parent[v] != root) {
        const int next = parent[v];
        parent[v] = root;
        v = next;
    }
    return root;
}

// Merges the components of a and b using union by size. When both
// components have the same size, the root of a's component stays the root.
void setunionop(int a, int b) {
    int ra = findop(a);
    int rb = findop(b);
    if (ra == rb)
        return;
    if (sizeo[ra] < sizeo[rb])
        swap(ra, rb);
    parent[rb] = ra;
    sizeo[ra] += sizeo[rb];
}

// Turns each of 1..n into a singleton component (index 0 is left untouched).
void initialize(int n) {
    for (int v = n; v >= 1; v--) {
        parent[v] = v;
        sizeo[v] = 1;
    }
}
// Solves one test case: reads n and the array, then prints the maximum over
// partitions [S1, S2] (both non-empty) of OR(S1) ^ OR(S2).
//
// Values live 1-indexed in arr[1..n]; whichset[k] (1 or 2) is the side that
// element k currently belongs to. Elements start on side 2 iff they have the
// highest "split" bit, and afterwards elements only ever move from S1 to S2.
// The global DSU (parent/sizeo) links S1 elements that must move together;
// a bit present on both sides either moves a whole merged component to S2
// or, if that would empty S1, is abandoned and the merge is rolled back from
// the tempparent/temppsizeo snapshot.
void solve() {
    int n;
    cin >> n;
    vector<int> arr(n + 1);  // arr[0] is an unused placeholder
    for (int k = 1; k <= n; k++) cin >> arr[k];
    // Sort only the real elements; arr[0] is not part of the input.
    sort(arr.begin() + 1, arr.end());

    // All elements equal: both ORs equal that value, so the XOR is 0.
    if (arr[1] == arr[n]) {
        cout << 0 << "\n";
        return;
    }
    // With two elements the (non-empty) partition is forced.
    if (n == 2) {
        cout << (arr[1] ^ arr[2]) << "\n";
        return;
    }

    auto hasBit = [&](int k, int bit) { return ((1LL << bit) & arr[k]) != 0; };

    // Unites consecutive entries of v so that all of v share one component.
    auto linkAll = [](const vector<int>& v) {
        for (size_t t = 1; t < v.size(); t++) setunionop(v[t - 1], v[t]);
    };

    // Highest bit set in some but not all elements; it exists because the
    // array is not constant.
    int highestbit = 0;
    for (int bit = 30; bit >= 0; bit--) {
        int withBit = 0;
        for (int k = 1; k <= n; k++)
            if (hasBit(k, bit)) withBit++;
        if (withBit > 0 && withBit < n) {
            highestbit = bit;
            break;
        }
    }

    vector<int> whichset(n + 1);
    initialize(n);
    for (int k = 1; k <= n; k++) whichset[k] = hasBit(k, highestbit) ? 2 : 1;

    for (int bit = highestbit - 1; bit >= 0; bit--) {
        bool inOne = false, inTwo = false;
        vector<int> members;  // S1 elements that have this bit, in index order
        for (int k = 1; k <= n; k++) {
            if (!hasBit(k, bit)) continue;
            if (whichset[k] == 1) {
                inOne = true;
                members.push_back(k);
            } else {
                inTwo = true;
            }
        }

        // Bit absent from S1: either absent everywhere, or only in S2 (where it
        // stays set because nothing ever moves back to S1). Nothing to do.
        if (!inOne) continue;

        if (!inTwo) {
            // Bit only in S1: it survives as long as these elements never get
            // split, so force them to move together from now on.
            linkAll(members);
            continue;
        }

        // Bit on both sides: every S1 holder (and everything linked to it) must
        // move to S2. Snapshot roots and sizes so the merge can be undone.
        for (int k = 1; k <= n; k++) {
            const int r = findop(k);
            tempparent[k] = r;
            temppsizeo[k] = sizeo[r];
        }
        linkAll(members);

        const int countinsetone = count(whichset.begin() + 1, whichset.end(), 1);
        // Computed once: no unions happen while we scan for members below.
        const int root = findop(members[0]);
        if (sizeo[root] == countinsetone) {
            // Moving would empty S1: give up on this bit and undo the merge.
            for (int k = 1; k <= n; k++) {
                parent[k] = tempparent[k];
                sizeo[k] = temppsizeo[k];
            }
        } else {
            for (int k = 1; k <= n; k++)
                if (findop(k) == root) whichset[k] = 2;
        }
    }

    int or1 = 0, or2 = 0;
    for (int k = 1; k <= n; k++) {
        if (whichset[k] == 1)
            or1 |= arr[k];
        else
            or2 |= arr[k];
    }
    cout << (or1 ^ or2) << "\n";
}
// Entry point: reads the number of test cases and runs solve() on each.
// Declared int32_t because `int` is macro-replaced by `long long`.
int32_t main() {
    // All I/O in this program goes through cin/cout, so C stdio sync is off.
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    // Seeds C rand(); nothing in this file calls rand().
    srand(chrono::high_resolution_clock::now().time_since_epoch().count());
    int t = 1;
    cin >> t;
    for (; t != 0; --t) {
        solve();
    }
    return 0;
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
// Debug printing is enabled only when ONLINE_JUDGE is not defined; otherwise
// debug(x) expands to a lone `;`.
// NOTE(review): the enabled form calls `_print` and streams `nline`, neither of
// which is defined in this file; debug(x) is never used here and would not
// compile unless those are provided.
#ifndef ONLINE_JUDGE
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);
#endif
// Shorthand for long long used throughout the tester's code.
#define ll long long
/*
------------------------Input Checker----------------------------------
*/
// Strict input validator: reads one integer token from stdin that must be
// terminated by exactly `endd`, and returns it after asserting that
//   - it has at least one digit and an optional single leading '-',
//   - it has no leading zeros and is not "-0",
//   - it has at most 19 digits (19 only with first digit 1), so it fits,
//   - it lies in [l, r] (otherwise l, r, x are printed to stderr first).
// Any other character (including EOF) fails an assertion.
long long readInt(long long l,long long r,char endd){
    long long x = 0;
    int cnt = 0;           // digits read so far
    int fi = -1;           // first digit, -1 until one has been read
    bool is_neg = false;
    while (true) {
        // int, not char: keeps EOF distinct from every byte value.
        const int g = std::getchar();
        if (g == '-') {
            // A sign is allowed once, and only before the first digit.
            assert(fi == -1 && !is_neg);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);          // no leading zeros
            assert(fi != 0 || is_neg == false);   // no "-0"
            // Length check happens before accumulating so x never overflows.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        } else if (g == endd) {
            // Reject an empty token (e.g. two consecutive separators) or a lone '-'.
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            if (!(l <= x && x <= r)) {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd`,
// consuming the terminator, and asserts the string length is in [l, r].
// Hitting EOF before the terminator fails an assertion.
std::string readString(int l,int r,char endd){
    std::string ret;
    int cnt = 0;
    while (true) {
        // int, not char: with a char, EOF is indistinguishable from byte 0xFF
        // (signed char) or never equals -1 at all (unsigned char).
        const int g = std::getchar();
        assert(g != EOF);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
// Convenience wrappers around readInt / readString that fix the terminator:
// the *Sp variants expect the token to be followed by a single space, the
// *Ln variants by a newline.
long long readIntSp(long long l,long long r){
    constexpr char kTerminator = ' ';
    return readInt(l, r, kTerminator);
}
long long readIntLn(long long l,long long r){
    constexpr char kTerminator = '\n';
    return readInt(l, r, kTerminator);
}
string readStringLn(int l,int r){
    constexpr char kTerminator = '\n';
    return readString(l, r, kTerminator);
}
string readStringSp(int l,int r){
    constexpr char kTerminator = ' ';
    return readString(l, r, kTerminator);
}
/*
------------------------Main code starts here----------------------------------
*/
const ll MOD = 1e9 + 7;  // modulus used by gt()
// Reads one line of n integers, each validated to lie in [l, r]: n-1
// space-terminated values followed by one newline-terminated value.
// NOTE(review): for n <= 0 this still consumes and returns one value.
vector<ll> readv(ll n,ll l,ll r){
    vector<ll> values;
    for (ll remaining = n - 1; remaining > 0; remaining--) {
        values.push_back(readIntSp(l, r));
    }
    values.push_back(readIntLn(l, r));
    return values;
}
// NOTE(review): MAX is not referenced anywhere in this file.
const ll MAX=3000300;
// Running total of n over all test cases; main() asserts it stays within the limit.
ll sum_n=0;
// Debug helper: prints the elements of `a` to stdout, each followed by a
// space, then ends the line. Takes the vector by const reference (the
// original copied it on every call); std::endl is kept so debug output is
// flushed immediately, as before.
void dbug(const std::vector<long long>& a){
    for (const long long t : a) {
        std::cout << t << " ";
    }
    std::cout << std::endl;
}
// Computes a^b mod MOD by binary exponentiation and returns a value in
// [0, MOD). Expects MOD >= 1 and b >= 0, and requires (MOD - 1)^2 to fit in
// a long long (true for the moduli passed in this file: MOD and MOD - 1 with
// MOD = 1e9 + 7).
long long binpow(long long a,long long b,long long MOD){
    long long ans = 1 % MOD;   // 0 when MOD == 1, where every residue is 0
    a %= MOD;
    if (a < 0) {
        a += MOD;              // % keeps the dividend's sign; normalize into [0, MOD)
    }
    while (b) {
        if (b & 1) {
            ans = (ans * a) % MOD;
        }
        b /= 2;
        a = (a * a) % MOD;
    }
    return ans;
}
// Modular inverse of a via Fermat's little theorem: a^(MOD-2) mod MOD.
// Only meaningful when MOD is prime and a is not a multiple of MOD.
ll inverse(ll a,ll MOD){
    const ll exponent = MOD - 2;
    return binpow(a, exponent, MOD);
}
// Computes (n^(pw+1) - n^freq) * inverse(n - 1) mod MOD (global MOD), where
// pw = (2^k mod (MOD-1)) * freq mod (MOD-1). Not called anywhere in this file.
// The product below is < 2*MOD * MOD, which fits in a long long.
// NOTE(review): when n ≡ 1 (mod MOD), inverse(n - 1, MOD) is asked to invert
// 0, so the result is not a meaningful quotient — confirm intended use.
ll gt(ll n,ll freq,ll k){
    const ll pw = (binpow(2, k, MOD - 1) * freq) % (MOD - 1);
    const ll high = binpow(n, pw + 1, MOD);
    const ll low = binpow(n, freq, MOD);
    ll now = (high - low + MOD) * inverse(n - 1, MOD);
    now %= MOD;
    return now;
}
// Order-statistics trees from GNU pb_ds (tree_order_statistics_node_update),
// over ll values and over pairs of ll. None of them is used in this file.
// NOTE(review): ordered_multiset relies on less_equal to keep duplicates;
// with a non-strict comparator, lookups/erasure do not behave like a normal
// multiset — verify before using it.
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
bool check_distinct(vector<ll> a){
sort(a.begin(),a.end());
ll n=a.size();
for(ll i=1;i<n;i++){
assert(a[i]!=a[i-1]);
}
return true;
}
// Identity on ll. Call sites such as g(2e5) and g(1<<30) use it to turn a
// double or int literal into an ll value.
ll g(ll x){
    const ll value = x;
    return value;
}
// Disjoint-set union over 0..n-1 with union by size and path compression.
// Despite its name, height[r] stores the number of elements in the set rooted
// at r (it grows by the absorbed set's value on each union); len is the
// current number of sets. solve() relies on the implicit copy constructor and
// copy assignment to take and commit tentative snapshots.
struct dsu{
    vector<ll> parent, height;
    ll n, len;

    dsu(ll n){
        this->n = n;
        parent.resize(n);
        height.resize(n);
        len = n;
        // Every element starts as its own singleton root.
        iota(parent.begin(), parent.end(), 0LL);
        fill(height.begin(), height.end(), 1LL);
    }

    // Returns the root of x and re-points every node on the path at it.
    ll find_set(ll x){
        ll root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (parent[x] != root) {
            const ll next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    // Returns the root of x. If x is not itself a root, also compresses the
    // path above x and sets parent[orig] to that root.
    ll find_set(ll x, ll orig){
        if (parent[x] == x) {
            return x;
        }
        const ll root = find_set(parent[x]);
        parent[orig] = root;
        return root;
    }

    // Merges the sets containing u and v; the larger set's root survives
    // (u's root on ties).
    void union_set(ll u, ll v){
        u = find_set(u);
        v = find_set(v);
        if (u == v) {
            return;
        }
        --len;
        if (height[u] < height[v]) {
            swap(u, v);
        }
        parent[v] = u;
        height[u] += height[v];
    }

    // Size of the set containing l.
    ll getv(ll l){
        return height[find_set(l)];
    }
};
void solve(){
ll n=readIntLn(2,g(2e5));
sum_n+=n;
vector<ll> a=readv(n,0,g(1<<30)-1);
ll ans=0,node=-1;
dsu global(n);
for(ll b=29;b>=0;b--){
vector<ll> on;
for(ll i=0;i<n;i++){
if(a[i]&(1<<b)){
on.push_back(i);
if(node==-1){
node=i;
}
}
}
dsu cur=global;
for(auto it:on){
cur.union_set(on[0],it);
}
if(cur.getv(max(0ll,node))!=n){
global=cur;
}
}
ll l=0,r=0;
for(ll i=0;i<n;i++){
if(global.find_set(i)==global.find_set(max(0ll,node))){
l|=a[i];
}
else{
r|=a[i];
}
}
ans=l^r;
cout<<ans<<"\n";
return;
}
// Validator-style entry point: reads the number of tests, runs solve() on
// each, then asserts the global bound on the sum of n and that the whole
// input was consumed. Outside ONLINE_JUDGE, stdin/stdout/stderr are
// redirected to local files.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    freopen("error.txt", "w", stderr);
#endif
    ll test_cases = readIntLn(1, g(2e4));
    for (; test_cases != 0; --test_cases) {
        solve();
    }
    assert(sum_n<=g(2e5));
    assert(getchar()==-1);
    return 0;
}