ORXOR2 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: shanu_singroha
Tester: satyam_343
Editorialist: iceknight1093

DIFFICULTY:

2808

PREREQUISITES:

DSU, familiarity with bitwise operations

PROBLEM:

Given N non-negative integers, find the maximum possible value of f(S_1)\oplus f(S_2), where f(S) denotes the bitwise OR of the elements of S, and [S_1, S_2] is a partition of the given integers into two non-empty sets.
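For small inputs, the definition can be checked directly by trying every split. The sketch below is a hypothetical brute-force helper (`bruteForce` is our own name, not from the original). It assumes N ≤ 20 so that all 2^N assignments can be enumerated, and it skips the two assignments that leave one part empty.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Brute-force reference: tries every split of a into two non-empty parts and
// returns the largest f(S_1) XOR f(S_2), where f is the bitwise OR.
// Assumption: 2 <= a.size() <= 20, so all 2^N assignments can be enumerated.
long long bruteForce(const std::vector<long long>& a) {
    const int n = static_cast<int>(a.size());
    assert(n >= 2 && n <= 20);
    long long best = 0;
    // Bit i of mask set => a[i] goes to S_2. Masks 0 and 2^n - 1 leave a part empty.
    for (int mask = 1; mask < (1 << n) - 1; ++mask) {
        long long or1 = 0, or2 = 0;
        for (int i = 0; i < n; ++i) {
            if ((mask >> i) & 1) or2 |= a[i];
            else or1 |= a[i];
        }
        best = std::max(best, or1 ^ or2);
    }
    return best;
}
```

This is useful for cross-checking a faster solution on random small arrays.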

EXPLANATION:

Let’s try to build the answer greedily, from the largest bit down to the smallest.

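Let h be the highest bit such that some element has it set and some element does not.
If no such bit exists, all the elements are equal. Then f(S_1) = f(S_2) for every partition, and the answer is 0.
Otherwise, put all elements without bit h into S_1, and all elements with it into S_2.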
Now, we’ve guaranteed that:

  • S_1 and S_2 are non-empty
  • The answer is at least 2^h

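Bits above h can never appear in the answer, since each of them is set in either every element or none. Bit h, on the other hand, must stay set in the answer. The reason is that 2^h is larger than 2^{h-1} + \ldots + 2^0 = 2^h - 1, so no combination of lower bits can make up for losing it.
In particular, this means we can only move elements from S_1 into S_2, and not vice versa, because all the elements with bit h set must stay in the same part of the partition.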

We can now try to set the lower bits greedily too.
For each bit b from h-1 down to 0, call b present in a set if at least one element of that set has bit b set. There are four cases.


Case 1: b is not present in S_1, and not present in S_2.
This means b is not present in the array at all, so we just ignore it.


Case 2: b is not present in S_1, but is present in S_2.
Notice that this means b is already set in the answer (and that won’t change in the future, because we aren’t moving elements from S_2 to S_1).
Once again, we have to do nothing here.


Case 3: b is present in S_1, but is not present in S_2.
Let x_1, x_2, \ldots, x_k be the elements from S_1 that have b set.
Then, if we want b to be present in the answer, either all the x_i should lie in S_1 (which is already the case), or all of them should lie in S_2.
In particular, if we ever move one x_i to S_2, we must move them all.
However, if this would leave S_1 empty, we can’t perform the move.

Let’s set this aside for now; we’ll come back to it.


Case 4: b is present in both S_1 and S_2.
Once again, let x_1, \ldots, x_k be the elements from S_1 that have b set.
Now, for b to be set in the answer, we must move all the x_i from S_1 to S_2.

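However, moving the x_i can interact with constraints recorded in case 3 for higher bits. If some x_i must stay together with other elements of S_1, those elements have to move as well.
In particular, we’ll need to move everything that’s linked to any of the x_i, either directly or through a chain of such constraints.
If this movement would leave S_1 empty, we can’t do it, and bit b stays unset in the answer. Otherwise, we perform the move, which sets bit b in the answer.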


All of this can be nicely represented by maintaining a DSU of the elements.
In particular, keep a DSU of size N:

  • Cases 1 and 2 do nothing.
  • For case 3, merge the components of all the x_i.
    This can be done quickly by uniting x_i and x_{i+1} for all i.
  • For case 4, we need to check whether the components of all the x_i together cover the entirety of S_1. If they don’t, move every element of those components to S_2. Otherwise, leave the DSU and the partition unchanged.
    This can be done in a variety of ways, for example:
    • Compute the set of representatives of the components of the x_i, and check whether the sum of the sizes of these components equals the size of S_1; or
    • Directly perform the merges, and if the merged component equals S_1 in the end, roll back the changes (this is what the author’s code does).

Either way, it’s possible to check for this quickly, which is all we need.
In the end, we know the sets S_1 and S_2, so the answer is simply f(S_1)\oplus f(S_2).

In total, our worst case is performing \mathcal{O}(30\cdot N) merge operations on the DSU, which should run pretty fast (especially since there are only N vertices, so most of the merge operations won’t actually do anything).

TIME COMPLEXITY:

\mathcal{O}(30\cdot N\alpha(N)) per testcase.

CODE:

Author's code (C++)
#include<bits/stdc++.h>
using namespace std;

#define fo(i,n) for( i=0;i<n;i++)
#define foA(i,a,b) for(i=a;i<=b;i++)
#define foD(i,a,b) for( i=a;i>=b;i--)
#define int long long
#define deb(x) cout << #x << "=" << x << endl
#define deb2(x, y) cout << #x << "=" << x << "," << #y << "=" << y << endl
#define deb3(x, y, z) cout << #x << "=" << x << "," << #y << "=" << y << "," << #z << "=" << z << endl
#define pb push_back
#define mp make_pair
#define all(x) x.begin(), x.end()
#define clr(x) memset(x, 0, sizeof(x))
#define sortall(x) sort(all(x))
#define el cout<<"\n"
#define max3(a,b,c) max(max((a),(b)),(c))
#define max4(a,b,c,d) max(max((a),(b)),max((c),(d)))
#define min3(a,b,c) min(min((a),(b)),(c))
#define min4(a,b,c,d) min(min((a),(b)),min((c),(d)))
/////////////////////
int dx[] = {0, 0, -1, 1, 1, 1, -1, -1};
int dy[] = {1, -1, 0, 0, -1, 1, 1, -1};

//////////////////for vectors
# define maxv(a) (*max_element(a.begin(),a.end()))
# define minv(a) (*min_element(a.begin(),a.end()))
# define sumvi(a) (accumulate(a.begin(),a.end(),0LL))
# define sumvd(a) (accumulate(a.begin(),a.end(),double(0)))

# define printv(v) {auto i = v;for(auto j : i) cout<< j << ' ';cout << "\n";}
# define printvv(v) {auto i = v;for(auto j : i) {for(auto k : j) cout<< k << ' ';cout << "\n";}}
# define prints(s) {auto i = s;for(auto j : i) cout<< j << ' ';cout << "\n";}
# define printm(m) {auto i = m;for(auto j : i) cout<< j.first << ':' << j.second << ' ';cout << "\n";}
/////////////////////////
typedef pair<int, int>  pii;
typedef vector<int>   vi;
typedef vector<pii>   vpii;
typedef vector<vi>    vvi;
/////////////////////////
mt19937_64 rang(chrono::high_resolution_clock::now().time_since_epoch().count());
int rng(int lim) {
  uniform_int_distribution<int> uid(0, lim - 1);
  return uid(rang);
}
/////////////////////
const int inf = 1e9;
const int INF = 1e18;
const int mod = 1000000007;
// const int mod = 998244353;
const int N = 3e5 + 5, M = N;
////////////////

int parent[N];
int sizeo[N];
int tempparent[N];
int temppsizeo[N];

int findop(int v) {
  if (v == parent[v])
    return v;
  return parent[v] = findop(parent[v]);
}
void setunionop(int a , int b) {
  a = findop(a);
  b = findop(b);
  if (a == b)
    return;
  else {
    if (sizeo[a] < sizeo[b])
      swap(a, b);
    parent[b] = a;
    sizeo[a] += sizeo[b];
  }
}
void initialize(int n) {
  for (int i = 1; i < n + 1 ; i++) {
    sizeo[i] = 1;
    parent[i] = i;
  }
}
void solve() {
  int i, j, n, m;
  cin >> n;
  vector<int> arr(n + 1);

  fo(i, n) cin >> arr[i + 1];

  sort(arr.begin() , arr.end());

  if (arr[1] == arr[n]) {
    cout << 0 << "\n";
    return;
  }
  else if (n == 2) {
    cout << (arr[1] ^ arr[2]) << "\n";
    return;
  }


  // printv(arr);
  int highestbit = 0;
  for (int j = 30; j >= 0  ; j--) {
    int count = 0 ;
    for (int i = 0 ; i < n ; i++) {
      if ( (1ll << j) & arr[i + 1]) {
        count++;
      }
    }
    if (count > 0 && count < n) {
      highestbit = j;
      break;
    }
  }


  vector<int> whichset(n + 1);
  initialize(n);

  for (int i = 0 ; i < n ; i++) {
    if ( (1ll << highestbit) & arr[i + 1])
      whichset[i + 1] = 2;
    else whichset[i + 1] = 1;
  }

  for (int j = highestbit - 1 ; j >= 0 ; j--) {
    int activeintwo = 0;
    int activeinone = 0;
    for (int i = 0 ; i < n ; i++) {
      if ( (1ll << j) & arr[i + 1]) {
        if (whichset[i + 1] == 1) {
          activeinone =  1;
        }
        else activeintwo = 1;
      }
    }
    if (activeintwo == 0 && activeinone == 0) {
      continue;
    }
    else if ( activeintwo == 1 && activeinone == 0) {
      continue;
    }
    else if (activeintwo == 0 && activeinone == 1) {
      vector<int> members;
      for (int i = 0 ; i < n ; i++) {
        if ( (1ll << j) & arr[i + 1]) {
          members.pb(i + 1);
        }
      }
      for (int i = 1 ; i < members.size() ; i++) {
        setunionop( members[i - 1] , members[i]);
      }
    }
    else {
      vector<int> members;
      for (int i = 0 ; i < n ; i++) {
        if ( (1ll << j) & arr[i + 1]) {
          if ( whichset[i + 1] == 1)
            members.pb(i + 1);

        }
      }
      for (int i = 1 ; i <= n ; i++) {
        int a = findop(i);
        tempparent[i] = a;
        temppsizeo[i] = sizeo[a];
      }
      for (int i = 1 ; i < members.size() ; i++) {
        setunionop( members[i - 1] , members[i]);
      }
      int countinsetone = 0;
      fo(i, n) {
        if (whichset[i + 1] == 1) countinsetone++;
      }
      int sizeofmembers = sizeo[findop(members[0])];
      // deb2(sizeofmembers , countinsetone);
      if (countinsetone == sizeofmembers) {
        for (int i = 1 ; i <= n ; i++) {
          parent[i] = tempparent[i];
          sizeo[i] = temppsizeo[i];
        }
      }
      else {
        fo(i, n) {
          if ( findop(i + 1) == findop( members[0])) {
            whichset[i + 1] = 2;
          }
        }
      }
      // printv(members);

    }
    // printv(whichset);
    // fo(i, n + 1) cout << findop(i) << " ";
    // cout << "\n";
  }

  int or1 = 0 ;
  int or2 = 0 ;

  int ans = 0;

  fo(i, n) {
    if (whichset[i + 1] == 1) {
      or1 = or1 | arr[i + 1];
    }
    else
      or2 = or2 | arr[i + 1];
  }

  ans = or1 ^ or2;
  cout << ans << "\n";
}
int32_t main() {
  ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  srand(chrono::high_resolution_clock::now().time_since_epoch().count());
  int t = 1;
  cin >> t;
  while (t--) {
    solve();
  }
  return 0;
}
Tester's code (C++)
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;
using namespace std;
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif 
#define ll long long 
 
 
/*
------------------------Input Checker----------------------------------
*/
 
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){   
        char g=getchar();
        if(g=='-'){  
            assert(fi==-1);     
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);
 
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }
 
            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }
 
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        char g=getchar();
        assert(g!=-1);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
long long readIntSp(long long l,long long r){
    return readInt(l,r,' ');
}
long long readIntLn(long long l,long long r){
    return readInt(l,r,'\n');
}
string readStringLn(int l,int r){
    return readString(l,r,'\n');
}
string readStringSp(int l,int r){
    return readString(l,r,' ');
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
const ll MOD=1e9+7;
vector<ll> readv(ll n,ll l,ll r){
    vector<ll> a;
    ll x;
    for(ll i=1;i<n;i++){  
        x=readIntSp(l,r);  
        a.push_back(x);   
    }
    x=readIntLn(l,r);
    a.push_back(x);
    return a;  
}
const ll MAX=3000300;   
ll sum_n=0;     
void dbug(vector<ll> a){
    for(auto t:a){
        cout<<t<<" ";
    }   
    cout<<endl; 
}
ll binpow(ll a,ll b,ll MOD){
    ll ans=1;
    a%=MOD;
    while(b){
        if(b&1)
            ans=(ans*a)%MOD;
        b/=2;  
        a=(a*a)%MOD;
    }
    return ans;
}
ll inverse(ll a,ll MOD){
    return binpow(a,MOD-2,MOD);
}
ll gt(ll n,ll freq,ll k){
    ll pw=(binpow(2,k,MOD-1)*freq)%(MOD-1);
    ll now=(binpow(n,pw+1,MOD)-binpow(n,freq,MOD)+MOD)*inverse(n-1,MOD);
    now%=MOD;
    return now;
}
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
bool check_distinct(vector<ll> a){
    sort(a.begin(),a.end());
    ll n=a.size();
    for(ll i=1;i<n;i++){
        assert(a[i]!=a[i-1]);
    }
    return true;
}
ll g(ll x){
    return x;  
}
struct dsu{
    vector<ll> parent,height;
    ll n,len;
    dsu(ll n){
        this->n=n;
        parent.resize(n);
        height.resize(n);
        len=n;
        for(ll i=0;i<n;i++){
            parent[i]=i;
            height[i]=1;
        }
    }
    ll find_set(ll x){
        return find_set(x,x); 
    }
    ll find_set(ll x,ll orig){
        if(parent[x]==x){
            return x;
        }
        parent[orig]=find_set(parent[x]);
        return parent[orig]; 
    }
    void union_set(ll u,ll v){
        u=find_set(u),v=find_set(v);
        if(u==v){
            return;
        }
        len--; 
        if(height[u]<height[v]){
            swap(u,v); 
        }
        parent[v]=u;
        height[u]+=height[v]; 
    }
    ll getv(ll l){
        l=find_set(l);
        return height[l]; 
    }
};
void solve(){  
    ll n=readIntLn(2,g(2e5));
    sum_n+=n;
    vector<ll> a=readv(n,0,g(1<<30)-1);
    ll ans=0,node=-1;
    dsu global(n);
    for(ll b=29;b>=0;b--){
        vector<ll> on;
        for(ll i=0;i<n;i++){
            if(a[i]&(1<<b)){
                on.push_back(i);
                if(node==-1){
                    node=i;
                }
            }
        }
        dsu cur=global;
        for(auto it:on){
            cur.union_set(on[0],it);
        }
        if(cur.getv(max(0ll,node))!=n){
            global=cur;
        }
    }
    ll l=0,r=0; 
    for(ll i=0;i<n;i++){
        if(global.find_set(i)==global.find_set(max(0ll,node))){
            l|=a[i];
        }
        else{
            r|=a[i];
        }
    }
    ans=l^r;
    cout<<ans<<"\n";
    return;  
}  
int main(){
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                                
    freopen("output.txt", "w", stdout);  
    freopen("error.txt", "w", stderr);                          
    #endif         
    ll test_cases=readIntLn(1,g(2e4)); 
    while(test_cases--){
        solve();
    }
    assert(sum_n<=g(2e5)); 
    assert(getchar()==-1);
    return 0;
}