XORDETECTIVE - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Contest Division 4

Setter: Anton Trygub
Tester: Harris Leung
Editorialist: Trung Dang

DIFFICULTY:

2839

PREREQUISITES:

XOR

PROBLEM:

There are two hidden non-negative integers A and B with 0 \leq A \lt B \lt 2^{29}. You were not able to determine them and were about to cry. Luckily, XOR The Detective decided to save your day!

XOR The Detective is a great interrogator. In one query, he can ask any integer X such that 0 \le X \lt 2^{30}, and learn the value of (A + X) \oplus (B + X). Here \oplus denotes the bitwise XOR operation.

Help XOR The Detective to determine both A and B in at most 30 queries.

EXPLANATION:

Our goal is to find the X for which B + X = 2^{29}. Since A < B, this X also gives A + X < 2^{29}, so A + X and B + X have no set bits in common. The answer to that query is then simply (A + X) + 2^{29}, from which we read off A = \text{answer} - 2^{29} - X and B = 2^{29} - X.

Let’s solve an easier case first: suppose bit 28 of B is 1 while bit 28 of A is 0. We can check this with a single query X = 0, which returns A \oplus B; the case holds exactly when bit 28 of that answer is set. How do we find the suitable X, which we will denote \hat{X}?

First, \hat{X} = 2^{29} - B \le 2^{28}, because B \ge 2^{28} by our assumption. Also, bit 28 of A is 0, so A < 2^{28` and hence A + X < 2^{29} for every 0 \le X \le 2^{28}. Therefore:

  • If 0 \le X < \hat{X}, then (A + X) \oplus (B + X) < 2^{29} (because B + X < B + \hat{X} = 2^{29}).
  • If \hat{X} \le X \le 2^{28}, then (A + X) \oplus (B + X) \ge 2^{29} (because B + X \ge B + \hat{X} = 2^{29}).

This gives a monotone condition, so we can binary search for \hat{X}. It is even simpler to build the value bit by bit. Start from 0 and go through bits 28 down to 0. For each bit, tentatively turn it on and keep it if the query with that value still returns less than 2^{29}, which means the value is still below \hat{X}. The value obtained this way is \hat{X} - 1, and every queried value stays within [0, 2^{28}]. This takes 29 queries, plus the initial query X = 0, for exactly 30 queries.

How do we generalize from our assumption? Query X = 0 and let P be the position of the highest set bit of the result A \oplus B (P = 28 is exactly the case above). Since A < B, bit P of B is 1, bit P of A is 0, and A and B agree on every bit above P. The idea is to handle only bits 0 through P first.

\hat{X} now means the smallest value such that the lowest P + 1 bits (bits 0 through P) of B + \hat{X} are all 0. It can be found with the same bit-by-bit search, now checking whether the answer is below 2^{P+1}. From then on, every query adds a multiple of 2^{P+1}, which never causes a carry out of those low bits of A + \hat{X} or B + \hat{X}. So we can ignore those bits and solve the problem recursively on the upper bits, with A + \hat{X} and B + \hat{X} as the new A and B; after the carry, they differ at some bit above P.

In this binary-lifting process each bit position is asked about at most once. The total is therefore at most 30 queries, including the initial query X = 0.

TIME COMPLEXITY:

Each test case uses at most 30 queries and O(30) = O(1) work.

SOLUTION:

Setter's Solution
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <bits/stdc++.h>
#pragma GCC target ("avx2")
#pragma GCC optimization ("O3")
#pragma GCC optimization ("unroll-loops")

using namespace __gnu_pbds;
using namespace std;

// Short type aliases. Note: `ld` is plain double here, not long double.
using ll = long long;
using ld = double;

// GNU policy-based red-black tree keyed by pair<int, int>, with order-statistics
// node updates (find_by_order / order_of_key). Not referenced by the visible solve().
typedef tree<
        pair<int, int>,
        null_type,
        less<pair<int, int>>,
        rb_tree_tag,
        tree_order_statistics_node_update>
        ordered_set;

// Macro shorthand: every identifier `mp` later in this program expands to make_pair.
#define mp make_pair

// Modulus used by the helpers below; inv() relies on it being prime (Fermat).
int MOD = 998244353;

// (a * b) mod MOD; the product is formed in 64 bits so it cannot overflow.
int mul(int a, int b) {
    const long long product = 1LL * a * b;
    return product % MOD;
}

// a + b with a single conditional subtraction of MOD
// (result lies in [0, MOD) when both inputs do).
int add(int a, int b) {
    const int total = a + b;
    return total >= MOD ? total - MOD : total;
}

// a - b modulo MOD, kept non-negative by adding MOD first.
int sub(int a, int b) {
    const int diff = a + MOD - b;
    return diff >= MOD ? diff - MOD : diff;
}

// a^deg mod MOD by square-and-multiply; deg <= 0 yields 1.
int po(int a, ll deg)
{
    int result = 1;
    int base = a;
    while (deg > 0) {
        if (deg & 1) result = mul(result, base);
        base = mul(base, base);
        deg >>= 1;
    }
    return result;
}

// Modular inverse via Fermat's little theorem: n^(MOD-2) mod MOD.
int inv(int n)
{
    return po(n, MOD - 2);
}


// Mersenne Twister seeded with the current time; not referenced by the visible solve().
mt19937 rnd(time(0));


const int LIM = 1000005;

// Factorials, inverse factorials and inverses of 1..LIM-1 modulo MOD.
// All entries stay 0 until init() is called.
vector<int> facs(LIM), invfacs(LIM), invs(LIM);

// Fills facs/invfacs/invs with a single modular exponentiation.
void init()
{
    facs[0] = 1;
    for (int k = 1; k < LIM; ++k)
        facs[k] = mul(facs[k - 1], k);

    // 1/(k-1)! = 1/k! * k, walking down from the largest index.
    invfacs[LIM - 1] = inv(facs[LIM - 1]);
    for (int k = LIM - 1; k > 0; --k)
        invfacs[k - 1] = mul(invfacs[k], k);

    // 1/k = (k-1)! / k!
    for (int k = 1; k < LIM; ++k)
        invs[k] = mul(facs[k - 1], invfacs[k]);
}

// Binomial coefficient n choose k modulo MOD; 0 when k > n or either is negative.
// Requires init() to have run and n < LIM.
int C(int n, int k)
{
    if (n < k || n < 0 || k < 0) return 0;
    const int invDenominator = mul(invfacs[k], invfacs[n - k]);
    return mul(facs[n], invDenominator);
}



// Disjoint-set union with union by size (no path compression).
struct DSU
{
    vector<int> sz;      // sz[r]: size of the set whose root is r (meaningful for roots)
    vector<int> parent;  // parent[v] == v exactly when v is a root

    // Turns v into a singleton set.
    void make_set(int v) {
        sz[v] = 1;
        parent[v] = v;
    }

    // Follows parent links up to the root of v's set.
    int find_set(int v) {
        while (parent[v] != v)
            v = parent[v];
        return v;
    }

    // Merges the sets containing a and b, hanging the smaller tree under the larger.
    void union_sets(int a, int b) {
        int rootA = find_set(a);
        int rootB = find_set(b);
        if (rootA == rootB) return;
        if (sz[rootA] < sz[rootB]) swap(rootA, rootB);
        parent[rootB] = rootA;
        sz[rootA] += sz[rootB];
    }

    // n singleton sets {0}, {1}, ..., {n-1}.
    DSU (int n)
    {
        parent.resize(n);
        sz.resize(n);
        for (int v = 0; v < n; ++v) make_set(v);
    }
};

// Debug printers: every element is followed by a separator, then endl (newline + flush).
void print(vector<int> a)
{
    for (size_t i = 0; i < a.size(); ++i)
        cout << a[i] << ' ';
    cout << endl;
}

// Booleans print as 1/0 (no boolalpha).
void print(vector<bool> a)
{
    for (size_t i = 0; i < a.size(); ++i)
        cout << static_cast<bool>(a[i]) << ' ';
    cout << endl;
}
/*
void print(vector<pair<ll, ll>> a)
{
    for (auto it: a) cout<<it.first<<' '<<it.second<<"| ";
    cout<<endl;
}*/

// Pairs print as "first second| ".
void print(vector<pair<int, int>> a)
{
    for (size_t i = 0; i < a.size(); ++i)
        cout << a[i].first << ' ' << a[i].second << "| ";
    cout << endl;
}

/*const int mod = 998244353;

template<int mod>
struct NTT {
    static constexpr int max_lev = __builtin_ctz(mod - 1);

    int prod[2][max_lev - 1];

    NTT() {
        int root = find_root();//(mod == 998244353) ? 31 : find_root();
        int rroot = power(root, mod - 2);
        vector<vector<int>> roots(2, vector<int>(max_lev - 1));
        roots[0][max_lev - 2] = root;
        roots[1][max_lev - 2] = rroot;
        for (int tp = 0; tp < 2; ++tp) {
            for (int i = max_lev - 3; i >= 0; --i) {
                roots[tp][i] = mul(roots[tp][i + 1], roots[tp][i + 1]);
            }
        }
        for (int tp = 0; tp < 2; ++tp) {
            int cur = 1;
            for (int i = 0; i < max_lev - 1; ++i) {
                prod[tp][i] = mul(cur, roots[tp][i]);
                cur = mul(cur, roots[tp ^ 1][i]);
            }
        }
    }

    template<bool inv>
    void fft(int *a, int lg) const {
        const int n = 1 << lg;
        int pos = max_lev - 1;
        for (int it = 0; it < lg; ++it) {
            const int h = inv ? lg - 1 - it : it;
            const int shift = (1 << (lg - h - 1));
            int coef = 1;
            for (int start = 0; start < (1 << h); ++start) {
                for (int i = start << (lg - h); i < (start << (lg - h)) + shift; ++i) {
                    if (!inv) {
                        const int y = mul(a[i + shift], coef);
                        a[i + shift] = a[i];
                        inc(a[i], y);
                        dec(a[i + shift], y);
                    } else {
                        const int y = mul(a[i] + mod - a[i + shift], coef);
                        inc(a[i], a[i + shift]);
                        a[i + shift] = y;
                    }
                }
                coef = mul(coef, prod[inv][__builtin_ctz(~start)]);
            }
        }
    }

    vector<int> product(vector<int> a, vector<int> b) const {
        if (a.empty() || b.empty()) {
            return {};
        }
        const int sz = a.size() + b.size() - 1;
        const int lg = 32 - __builtin_clz(sz - 1), n = 1 << lg;
        a.resize(n);
        b.resize(n);
        fft<false>(a.data(), lg);
        fft<false>(b.data(), lg);
        for (int i = 0; i < n; ++i) {
            a[i] = mul(a[i], b[i]);
        }
        fft<true>(a.data(), lg);
        a.resize(sz);
        const int rn = power(n, mod - 2);
        for (int &x : a) {
            x = mul(x, rn);
        }
        return a;
    }

private:
    static inline void inc(int &x, int y) {
        x += y;
        if (x >= mod) {
            x -= mod;
        }
    }

    static inline void dec(int &x, int y) {
        x -= y;
        if (x < 0) {
            x += mod;
        }
    }

    static inline int mul(int x, int y) {
        return (1LL * x * y) % mod;
    }

    static int power(int x, int y) {
        if (y == 0) {
            return 1;
        }
        if (y % 2 == 0) {
            return power(mul(x, x), y / 2);
        }
        return mul(x, power(x, y - 1));
    }

    static int find_root() {
        for (int root = 2; ; ++root) {
            if (power(root, (1 << max_lev)) == 1 && power(root, (1 << (max_lev - 1))) != 1) {
                return root;
            }
        }
    }
};

NTT<mod> ntt;
*/

// Sends the query "? x" (endl flushes, as the interactive judge requires)
// and returns the reply, (A + x) ^ (B + x).
int ask(int x)
{
    cout << "? " << x << endl;
    int reply;
    cin >> reply;
    return reply;
}

// 2^29: strictly greater than both hidden values (A < B < 2^29).
const int M = (1<<29);

// Solves one test case interactively using 1 + (28 - P) + P = 29 queries.
//
// X = A ^ B and P (`bit`) is its highest set bit. Since A < B, B has a 1 and
// A a 0 at bit P, and the two agree on every bit above P. `b` is built up
// into B one bit at a time; A is recovered at the end as B ^ X.
void solve()
{
    int q; cin>>q;  // per-test integer sent by the judge; not used here
    int X = ask(0);  // (A + 0) ^ (B + 0) = A ^ B

    int a = 0; int b = 0;
    int bit = 0;
    // bit = highest set bit of X (X != 0 because A != B).
    for (int i = 0; i<29; i++) if (X&(1<<i)) bit = i;

    b+=(1<<bit);

    // Bits above P (shared by A and B). b holds B's bits P..i-1, so with
    // x = 2^i - b:  B + x = (B's bits >= i) + 2^i + (B mod 2^P), while
    // A + x stays below (B's bits >= i) + 2^i. If bit i of B is 1, B + x
    // carries into bit i+1 and A + x does not, so bit i+1 of the answer
    // differs from bit i+1 of X (which is 0 above P); otherwise it does not.
    for (int i = bit+1; i<29; i++)
    {
        int x = (1<<i) - b;
        int res = ask(x);
        if ((res^X)&(1<<(i+1))) b+=(1<<i);
    }

    // Bits below P, high to low. b now holds every bit of B above i, so with
    // x = 2^29 - b - 2^i:  B + x = 2^29 + (B mod 2^(i+1)) - 2^i, which is at
    // least 2^29 exactly when bit i of B is 1, while A + x stays below 2^29
    // either way. Hence bit 29 of the answer equals bit i of B.
    for (int i = bit-1; i>=0; i--)
    {
        int x = M - b - (1<<i);
        int res = ask(x);
        if (res&M) b+=(1<<i);
    }
    a = b^X;
    cout<<"! "<<a<<' '<<b<<endl;
}


// Entry point: reads the number of test cases and runs the interactive solve() on each.
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(nullptr);  // untying cin is safe here: every query and answer line ends with endl (flush)

    int t; cin>>t;
    while (t--) solve();


}
Tester's Solution
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
#define fi first
#define se second
const int N = 200001;  // not referenced by the visible tester code
const int iu = 29;     // exclusive upper bit bound used by solve(): values are below 2^iu

// One interactive query: prints "? x", flushes via endl, and returns the judge's reply.
ll ask(ll x)
{
	cout << "? " << x << endl;
	ll reply;
	cin >> reply;
	return reply;
}
// Solves one test case using 1 + (28 - z) + z = 29 queries.
// king = A ^ B and z is its highest set bit: B has a 1 there, A a 0, and the
// two agree above z. Below z the code tracks which number owns each differing bit.
void solve(){
	int antonbaby;cin >> antonbaby;  // per-test integer sent by the judge; not used here
	ll king=ask(0);  // A ^ B
	int z=0;
	while(1<<(z+1)<=king) z++;  // z = highest set bit of king
	// b starts as B (bit z set) and a as A; below bit z they may be exchanged.
	ll a=0,b=(1<<z);
	ll love=(1<<z);
	// Bits above z. Invariant: love + (bits of b found so far) == 2^i, so
	// B + love = (B's bits >= i) + 2^i + (B mod 2^z) and A + love stays below
	// (B's bits >= i) + 2^i. If bit i of B (and of A) is 1, only B + love
	// carries into bit i+1, which flips bit i+1 of the answer relative to king.
	for(int i=z+1; i<iu ;i++){
		ll res=ask(love);
		int flag=((res^king)>>(i+1))&1;
		if(flag) a|=(1<<i),b|=(1<<i);
		else love|=(1<<i);
	}
	love=0;
	int last=z;  // lowest position found so far where the two numbers differ; b has the 1 there
	// Bits below z, high to low. love collects the positions (between i and z)
	// where both numbers have a 0, so after adding love every agreeing position
	// above i is 1-1 and a carry starting at bit i runs up to bit `last`.
	for(int i=z-1; i>=0 ;i--){
		if((king>>i)&1){
			// The numbers differ at i: the carry starts in the one with the 1.
			// It passes `last` (changing the answer above bit last) only if that
			// number also has a 1 at `last`, i.e. if it is the one stored in b.
			ll res=ask(love+(1<<i));
			int flag=(res^king)>>(last+1);
			if(!flag) swap(a,b);  // the carrier was a: swap so b is the number with bit i set
			b|=(1<<i);
			last=i;
		}
		else{
			// The numbers agree at i: the answer stays equal to king iff both bits
			// are 0 (no carry); with 1-1 the carry reaches a differing bit.
			ll res=ask(love+(1<<i));
			int flag=(res!=king);
			if(flag) a|=(1<<i),b|=(1<<i);
			else love|=(1<<i);
		}
	}
	if(a>b) swap(a,b);  // a and b may have been exchanged above; report A < B
	cout << "! " << a << ' ' << b << endl;
}
// Entry point: reads t and runs the interactive solve() t times.
// cin is untied from cout; ask() and solve() flush every line with endl.
int main(){
	ios::sync_with_stdio(false);cin.tie(0);
	int t;cin >> t;while(t--) solve();
}
Editorialist's Solution
#include <bits/stdc++.h>
using namespace std;

// Interactive query: prints "? u", flushes via endl, and returns the judge's reply.
int ask(int u)
{
    cout << "? " << u << endl;
    int response;
    cin >> response;
    return response;
}

// Editorialist's solution. Per test case it asks at most 30 queries: X = 0,
// then at most one per bit position 0..28.
int main() {
    int t; cin >> t;
    while (t--) {
        int q; cin >> q;  // per-test integer sent by the judge; not used here
        // add: total offset added to both hidden numbers so far.
        // sum: the judge's answer (A + add) ^ (B + add).
        // lst: bits 0..lst of B + add are already all 0 (-1 = none yet).
        int add, sum, lst;
        // Each phase handles bits lst+1..top, where top is the highest bit in
        // which A + add and B + add differ. It finds the smallest extra offset
        // (a multiple of 2^(lst+1)) that makes B + add a multiple of 2^(top+1).
        // The loop ends once bit 29 appears in the answer, i.e. B + add == 2^29.
        for (add = 0, sum = ask(add), lst = -1; __lg(sum) < 29; ) {
            int top = __lg(sum);
            // Binary lifting: keep each step while B + add has not yet carried
            // past bit top (answer still below 2^(top+1)).
            for (int i = top; i > lst; i--) {
                int ret = ask(add + (1 << i));
                if (ret < (1 << (top + 1))) {
                    add += (1 << i);
                } else {
                    // The last value stored here is the answer for the final
                    // `add` produced after this loop.
                    sum = ret;
                }
            }
            // add is the largest candidate still below the carry point; one more
            // step of 2^(lst+1) reaches it, making bits 0..top of B + add zero.
            add += (1 << (lst + 1)); lst = top;
        }
        // Now B + add == 2^29 and A + add < 2^29, so sum == (A + add) + 2^29.
        int b = 1 << (__lg(sum)), a = sum - b;
        b -= add; a -= add;
        cout << "! " << a << " " << b << endl;
    }
}

Is there a typo here, if X <=2^{29} how can A+X <2^{29} ?

Yea it’s X \le 2^{28} … thank you for spotting

My code gives WA for subtask 1. Am I missing something?

#include<bits/stdc++.h>
using namespace std;

// Brute force for small inputs (both hidden numbers below MX): precompute the
// answers to the queries X = 0..MX-1 for every pair a < b < MX, then ask the
// same MX queries for each test case and look the answer vector up.
int main() {
  cin.tie(0)->sync_with_stdio(0);

  const int MX = 128;
  map<vector<int>, pair<int, int>> mp;
  vector<int> q(MX);
  for (int a = 0; a < MX; a++) {
    for (int b = a + 1; b < MX; b++) {
      for (int x = 0; x < MX; x++) {
        q[x] = (a + x) ^ (b + x);
      }
      mp[q] = {a, b};
    }
  }
  int t;
  for (cin >> t; t; t--) {
    // Each test case starts with one extra integer (the reference solutions all
    // read it, e.g. `int q; cin >> q;`). Not reading it makes every later read
    // consume the wrong value.
    // NOTE(review): this brute force always asks MX queries; presumably that is
    // within the subtask-1 limit given by this integer — confirm with the statement.
    int queryLimit;
    cin >> queryLimit;
    for (int i = 0; i < MX; i++) {
      cout << "? " << i << endl;  // endl already flushes
      cin >> q[i];
    }
    assert(mp.count(q));
    auto [a, b] = mp[q];
    cout << "! " << a << ' ' << b << endl;
  }
  return 0;
}

can anyone explain the approach in little bit easy way or other method