BINARYGA - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Le Duc Minh, Nguyen Anh Quân
Testers: Shubham Anand Jain, Aryan
Editorialist: Nishank Suresh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

Backtracking

PROBLEM:

You are given 16n positive integers a_1, a_2, \dotsc, a_{16n}. You can permute this sequence as you please.
Let
x = (a_1\oplus a_2)\otimes (a_3\oplus a_4)\otimes \cdots \otimes(a_{8n-1}\oplus a_{8n}) and
y = (a_{8n+1}\oplus a_{8n+2})\otimes (a_{8n+3}\oplus a_{8n+4})\otimes \cdots \otimes(a_{16n-1}\oplus a_{16n})
where \oplus denotes bitwise XOR and \otimes denotes bitwise AND.
Find the maximum value of x-y.
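
To make the objective concrete, here is a minimal sketch that evaluates x - y for one fixed ordering of the sequence. The helper names `andOfPairXors` and `scoreOfArrangement` are illustrative and not part of the original solutions; the sketch assumes the sequence length is a multiple of 4 so that both halves split cleanly into pairs.

```cpp
#include <cstddef>
#include <vector>

// AND of the XORs of consecutive pairs in a[lo, hi).
// Assumes hi - lo is even and at least 2.
long long andOfPairXors(const std::vector<int>& a, std::size_t lo, std::size_t hi) {
    long long acc = a[lo] ^ a[lo + 1];
    for (std::size_t i = lo + 2; i + 1 < hi; i += 2) {
        acc &= (a[i] ^ a[i + 1]);
    }
    return acc;
}

// x is computed from the first half, y from the second half; returns x - y.
long long scoreOfArrangement(const std::vector<int>& a) {
    const std::size_t half = a.size() / 2;
    return andOfPairXors(a, 0, half) - andOfPairXors(a, half, a.size());
}
```

The task is then to maximize `scoreOfArrangement` over all permutations, which is far too many to enumerate directly; the rest of the editorial explains how to avoid that.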

EXPLANATION:

All numbers are treated as binary strings of length 30, possibly having leading zeroes.

Lemma: If n\geq 2, given any permutation of the 16n numbers, it is possible to permute the last 8n of them so that y = 0.

Proof

Let f(n) be the number of ways to permute a_{8n+1}, a_{8n+2}, \dotsc, a_{16n}, ignoring the order of the two operands of \oplus within a pair and the order of the pairs joined by \otimes. So, for example, the sequences (1, 2, 3, 4), (2, 1, 3, 4), and (4, 3, 1, 2) are all treated as the same arrangement.
Let g(n) be defined similarly, except only considering sequences whose y-value is > 0.

There are (8n)! permutations of the 8n numbers. Under the equivalence defined above, there are (4n)! ways to rearrange the pairs, and 2^{4n} ways to choose the order of elements within each pair.
Thus, \displaystyle f(n) = \frac{(8n)!}{2^{4n}(4n)!}

Let g_i(n) be the number of ways to arrange the numbers so that the i-th digit of y is 1. This can happen only when each of the 4n pairs has a number with i-th digit 1, and the other with i-th digit 0. So, among a_{8n+1}, a_{8n+2}, \dotsc, a_{16n} exactly 4n of the numbers should have the i-th bit set. If this is not the case, g_i(n) = 0.
If this does happen to be the case, fix the position of numbers whose i-th digit is 1, and we have exactly (4n)! ways to place the rest - distributing one to each pair. Some of these ways might be the same, based on the equivalence above, but either way, we have g_i(n) \leq (4n)!.

Now, the set of arrangements which are counted by g(n) is the union of the sets of arrangements counted by each g_i(n), because if y > 0 then at least one of its bits must be set.
Thus, by the union bound, g(n) \leq g_1(n) + g_2(n) + \dotsc + g_{30}(n) \leq 30\cdot (4n)! (this generalizes |A\cup B| \leq |A| + |B| to 30 sets).

We claim that f(n) > g(n) when n\geq 2.
For n = 2 this holds because f(2) = 2027025 > 30\cdot 8! = 1209600 \geq g(2). Taking n = 2 as the base case, the inequality f(n) > 30\cdot (4n)! \geq g(n) follows for all larger n by induction.
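
One way to fill in the base case and the inductive step is sketched below. It uses only the formula for f(n) and the bound g(n) \leq 30\cdot(4n)! derived above; the rewriting of f(n) as a product of odd numbers is a standard identity, not something stated in the original proof.

```latex
\begin{align*}
f(n) &= \frac{(8n)!}{2^{4n}(4n)!} = 1\cdot 3\cdot 5\cdots(8n-1), \\
f(2) &= 1\cdot 3\cdot 5\cdot 7\cdot 9\cdot 11\cdot 13\cdot 15 = 2027025
      \;>\; 30\cdot 8! = 1209600, \\
\frac{f(n+1)}{f(n)} &= (8n+1)(8n+3)(8n+5)(8n+7)
      \;>\; (4n+1)(4n+2)(4n+3)(4n+4) = \frac{(4n+4)!}{(4n)!}.
\end{align*}
```

So if f(n) > 30\cdot(4n)!, multiplying both sides by the respective ratios gives f(n+1) > 30\cdot(4n+4)! \geq g(n+1). Note that the bound is not strong enough for n = 1, where f(1) = 105 < 30\cdot 4! = 720, which is why that case is handled separately.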

And of course, f(n) > g(n) means that there is some way to permute the numbers to obtain y = 0.

Thus, if n\geq 2, the problem reduces to maximizing x by choosing some 8n numbers and permuting them appropriately.
We maximize x greedily - first try to set the highest bit, then the second highest, and so on.

Initially, consider all 16n numbers to be a single group.
To set the highest bit in x, we need at least 4n numbers with the highest bit set and 4n with it unset, to pair with each other. If this is possible, divide the numbers into groups - one where the highest bit is set, and one where it isn’t. If it is not possible, do nothing and continue to the next bit.
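
The feasibility test for the first bit can be sketched as follows. The function name `canSetTopBit` and the parameter `pairsNeeded` (which would be 4n in this problem) are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Returns true if at least `pairsNeeded` disjoint pairs can be formed in which
// one number has `bit` set and the other has it clear.
bool canSetTopBit(const std::vector<int>& a, int bit, std::size_t pairsNeeded) {
    std::size_t ones = 0;
    for (int v : a) {
        if ((v >> bit) & 1) ++ones;
    }
    const std::size_t zeros = a.size() - ones;
    return std::min(ones, zeros) >= pairsNeeded;
}
```

Only the counts matter at this stage, because any number with the bit set can be paired with any number that has it clear.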

For the other digits, we can follow a similar strategy.
Suppose we want to check whether fixing the digit i is possible, given some prefix has already been maximized.
Then, excluding skipped digits (i.e., digits which it wasn’t possible for us to set), we want to consider groups of numbers whose prefixes are complementary to each other (like we looked at 0 and 1 of the highest bit, earlier). For example, 10010 and 01101 are complementary to each other.
Within two complementary groups, pair the numbers of one group whose i-th bit is 0 with the numbers of the other whose i-th bit is 1, and vice versa; the smaller of the two sizes in each case is the largest number of pairs that preserves the previous results and also sets the i-th bit.
If the sum of these minimums over all complementary pairs of groups is at least 4n, it’s possible to set the i-th bit; otherwise, it isn’t.
If the i-th bit is set, split every group by the value of its i-th bit; otherwise keep the set of groups as it is.
Drop any group that becomes empty, so that the number of groups doesn’t become too large.
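
The greedy above can be sketched compactly as follows. This is a simplified variant, not the setter's or tester's code: instead of splitting groups incrementally, it rebuilds the group sizes for every candidate bit by keying each number on its bits at the already chosen positions plus the candidate position. The name `maxAndOfPairXors` and the parameter `pairsNeeded` (4n in this problem) are assumptions for illustration, and values are assumed to be below 2^30.

```cpp
#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Largest value of (p1 ^ q1) & (p2 ^ q2) & ... over `pairsNeeded` disjoint
// pairs chosen from `a`, built greedily from the highest bit down.
int maxAndOfPairXors(const std::vector<int>& a, std::size_t pairsNeeded) {
    int chosen = 0;  // bits of x fixed to 1 so far
    for (int bit = 29; bit >= 0; --bit) {
        const int cand = chosen | (1 << bit);
        // Group the numbers by their bits at the positions in `cand`;
        // skipped positions are ignored, and empty groups never appear.
        std::unordered_map<int, std::size_t> groupSize;
        for (int v : a) ++groupSize[v & cand];
        std::size_t pairs = 0;
        for (const auto& entry : groupSize) {
            const int key = entry.first;
            const int comp = key ^ cand;  // complementary key
            if (key < comp) {             // count each pair of groups once
                const auto it = groupSize.find(comp);
                if (it != groupSize.end()) {
                    pairs += std::min(entry.second, it->second);
                }
            }
        }
        if (pairs >= pairsNeeded) chosen = cand;  // this bit can be set
    }
    return chosen;
}
```

For n \geq 2, calling this on all 16n numbers with `pairsNeeded = 4n` would give the maximum x, which by the lemma is also the answer. Rebuilding the map for each bit costs about 30 passes over the input, in line with the stated complexity up to hashing overhead.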

All that’s left is to solve for n = 1.
There are 16 numbers - choose a subset of them to be the first 8. There are \binom{16}{8} = 12870 subsets of size 8.
Once a subset is fixed, x can be maximized using the same strategy as above in O(n) - that didn’t depend on n \geq 2 at all.
Meanwhile, of the remaining 8, there are f(1) = 105 ways to arrange them (f was defined in the proof of the first lemma), so use backtracking to go over all possibilities and find the minimum value of y.
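
The backtracking over pairings for the y half can be sketched like this. The names `enumeratePairings` and `minAndOfPairXors`, and the optional counter `pairingsVisited`, are hypothetical helpers introduced here; values are assumed to be below 2^30.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Pairs the first unused element with every later unused element and recurses.
// `acc` is the AND of the pair XORs chosen so far.
void enumeratePairings(const std::vector<int>& vals, std::vector<bool>& used,
                       int acc, int& best, long long& visited) {
    std::size_t i = 0;
    while (i < vals.size() && used[i]) ++i;
    if (i == vals.size()) {  // every element has a partner
        best = std::min(best, acc);
        ++visited;
        return;
    }
    used[i] = true;
    for (std::size_t j = i + 1; j < vals.size(); ++j) {
        if (used[j]) continue;
        used[j] = true;
        enumeratePairings(vals, used, acc & (vals[i] ^ vals[j]), best, visited);
        used[j] = false;
    }
    used[i] = false;
}

// Minimum AND-of-XORs over all pairings of `vals` (size assumed even).
// If `pairingsVisited` is non-null, it receives the number of pairings tried.
int minAndOfPairXors(const std::vector<int>& vals, long long* pairingsVisited = nullptr) {
    const int full = (1 << 30) - 1;  // identity for AND on 30-bit values
    std::vector<bool> used(vals.size(), false);
    int best = full;
    long long visited = 0;
    enumeratePairings(vals, used, full, best, visited);
    if (pairingsVisited != nullptr) *pairingsVisited = visited;
    return best;
}
```

Always pairing the first unused element is what removes the duplicate orderings, so for 8 values exactly f(1) = 105 pairings are visited. Running this for each of the 12870 subsets gives roughly 1.35 million pairings in total.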

TIME COMPLEXITY:

\mathcal{O}(30\cdot n)

SOLUTIONS:

Setter's Solution
#include<bits/stdc++.h>
using namespace std;

bool hasbit(int mask, int i) {
    return (mask >> i) & 1;
}

const int INF = 2e9;
const int maxN = 1e4 + 5;
const int logN = 30;

struct vector_pair {
    vector<int> x, y;
};

int n;
vector<int> a;

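// Backtracks over all ways to pair up the entries of vec; used partners are marked -INF.
// mask is the AND of the pair XORs so far; type 0 returns the maximum, type 1 the minimum.
int track(int i, int type, int mask, vector<int>& vec) {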
    if (i == vec.size()) return mask;

    if (vec[i] == -INF) return track(i + 1, type, mask, vec);

    int res = (type == 0 ? -INF : INF);

    for (int j = i + 1; j < vec.size(); j++) {
        if (vec[j] != -INF) {
            int old = vec[j];

            int nmask = mask & (vec[i] ^ vec[j]);

            vec[j] = -INF;

            int v = track(i + 1, type, nmask, vec);
            if (type == 0) res = max(res, v);
            else res = min(res, v);

            vec[j] = old;
        }
    }

    assert(abs(res) != INF);
    return res;
}

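// Greedily builds the largest achievable x, from the highest bit down.
// Each vector_pair holds two complementary groups of candidate numbers.
int getMaxX() {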
    int res = 0;

    vector<vector_pair> cur, nex;
    cur.push_back({a, a});

    int first_time = 1;

    for (int i = logN; i >= 0; i--) {
        nex.clear();

        int cnt = 0;
        for (vector_pair& p : cur) {
            vector<int> x0, x1;
            for (int& u : p.x) {
                if (!hasbit(u, i)) x0.push_back(u);
                else x1.push_back(u);
            }

            vector<int> y0, y1;
            for (int& u : p.y) {
                if (!hasbit(u, i)) y0.push_back(u);
                else y1.push_back(u);
            }

            cnt += min(x0.size(), y1.size());
            if (min(x0.size(), y1.size()) != 0) 
                nex.push_back({x0, y1});

            if (first_time) continue;

            cnt += min(x1.size(), y0.size());
            if (min(x1.size(), y0.size()) != 0) 
                nex.push_back({x1, y0});
        }

        if (cnt >= n * 4) {
            swap(cur, nex);

            first_time = 0;

            res |= (1 << i);
        }
    }

    return res;
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    cin >> n;
    for (int i = 1; i <= n * 16; i++) {
        int x;
        cin >> x;
        a.push_back(x);
    }

    if (n == 1) {
        int ans = -INF;

        for (int mask = 1; mask < (1 << 16); mask++) {
            if (__builtin_popcount(mask) == 8) {
                vector<int> x, y;
                for (int i = 0; i < 16; i++) {
                    if (hasbit(mask, i)) x.push_back(a[i]);
                    else y.push_back(a[i]);
                }

                int full = (1ll << 31) - 1;

                int maxX = track(0, 0, full, x);
                int minY = track(0, 1, full, y);
                ans = max(ans, maxX - minY);
            }
        }

        cout << ans;
    }
    else cout << getMaxX();
}
Tester's Solution
//By TheOneYouWant
#pragma GCC optimize ("-O2")
#include <bits/stdc++.h>
using namespace std;
#define fastio ios_base::sync_with_stdio(0);cin.tie(0)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define all(x) x.begin(),x.end()
#define forstl(i,v) for(auto &i: v)
#define forn(i,e) for(int i=0;i<e;++i)
#define forsn(i,s,e) for(int i=s;i<e;++i)
#define rforn(i,s) for(int i=s;i>=0;--i)
#define rforsn(i,s,e) for(int i=s;i>=e;--i)
#define bitcount(a) __builtin_popcount(a) // set bits (add ll)
#define ln '\n'
typedef long long ll;
typedef pair<int,int> p32;
typedef vector<int> v32; 
 
bool taken[8];
v32 ord;
ll mi, mx;
v32 num1, num2;
int cnt;
 
// Enumerates every pairing of 8 indices; tracks the minimum over num1 and the maximum over num2.
void backtrack(int pos){
    if(pos == 8){
        ll g = (1LL<<32)-1;
        ll g2 = (1LL<<32)-1;
        forn(i,4){
            g &= (num1[ord[i*2]]^num1[ord[i*2+1]]);
            g2 &= (num2[ord[i*2]]^num2[ord[i*2+1]]);
        }
        mi = min(mi, g);
        mx = max(mx, g2);
        // cnt++;
        return; 
    }
    if(taken[pos]) return backtrack(pos+1);
    forsn(i,pos+1,8){
        if(!taken[i]){
            taken[pos] = 1;
            taken[i] = 1;
            ord.pb(pos);
            ord.pb(i);
            backtrack(pos+1);
            taken[pos] = 0;
            taken[i] = 0;
            ord.pop_back();
            ord.pop_back();
        }
    }
}
 
int main(){
    fastio;
 
    int n;
    cin>>n;
    vector<int> a(16*n); // avoid a variable-length array, which is not standard C++
    forn(i,16*n){
        cin>>a[i];
    }
 
    if(n == 1){
 
        ll ans = 0;
        vector<int> v;
        forn(i,1<<16){
            if(bitcount(i) == 8) v.pb(i);
        }
        forstl(mask, v){
            mi = LLONG_MAX; // 1e9 is below 2^30 - 1, so it is not a safe upper sentinel
            mx = LLONG_MIN;
            num1.clear();
            num2.clear();
            forn(i,8) taken[i] = 0;
            forn(i,16){
                if((mask & (1<<i)) > 0){
                    num1.pb(a[i]);
                }
                else{
                    num2.pb(a[i]);
                }
            }
            cnt = 0;
            backtrack(0);   
            ans = max(ans, mx - mi);
        }
        cout<<ans<<ln;
 
        return 0;
    }
 
    map<int, v32> groups;
    map<int, int> comp;
    map<int, p32> num;
    forn(i,16*n){
        groups[1].pb(a[i]);
    } 
    comp[1] = 1;
    vector<int> curr_str;
    curr_str.pb(1);
 
    int ans = 0;
 
    rforn(i,29){
        sort(all(curr_str));
        int mask = ((1<<30) - 1) - ((1<<i) - 1);
        int pairs = 0;
        forstl(r, curr_str){
            int val = r * (1<<(i+1));
            int z = 0, o = 0;
            forstl(k, groups[r]){
                if((mask & k) == (val & k)){
                    z++;
                }
                else{
                    o++;
                }
            }
            num[r] = mp(o, z);
            if(comp[r] < r){
                pairs += min((int)z, num[comp[r]].fi);
                pairs += min((int)o, num[comp[r]].se);
            }
            if(comp[r] == r){
                pairs += min(z, o);
            }
        }
        bool dig = 0;
        if(pairs >= 4*n){
            dig = 1;
        }
        if(dig){
            ans += (1<<i);
        }
        // if dig, just add 1 to all the suffixes
        // else, add the 1's and 0's separately
        vector<int> new_str;
        forstl(r, curr_str){
            int n1 = 2*r + 1;
            int n2 = 2*r;
            int val = r * (1<<(i+1));
            int o = num[r].fi, z = num[r].se;
            if(!dig){
                comp[n1] = 2*comp[r] + 1;
                new_str.pb(n1);
                forstl(tt, groups[r]) groups[n1].pb(tt);
            }
            else{
                bool pres1 = 0, pres2 = 0;
                if(o > 0 && (num[comp[r]].se > 0)){
                    pres1 = 1;
                    comp[n1] = comp[r]*2;
                    new_str.pb(n1);
                } 
                if(z > 0 && (num[comp[r]].fi > 0)){
                    pres2 = 1;
                    comp[n2] = 2*comp[r] + 1;
                    new_str.pb(n2);
                }
                forstl(tt, groups[r]){
                    if((mask & tt) == (val & tt)){
                        if(pres2) groups[n2].pb(tt);
                    }
                    else{
                        if(pres1) groups[n1].pb(tt);
                    }
                }
            }
        }
        swap(curr_str, new_str);
    }
    cout<<ans<<ln;
 
    return 0;
}