ABCD - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Danny Boy
Tester: Anay Karnik
Editorialist: Mohan Abhyas

DIFFICULTY:

MEDIUM

PREREQUISITES:

None

PROBLEM:

There are four hidden positive integers A,B,C,D.
Your goal is to find the relationship between A\times B and C\times D.
That is, report whether A\times B > C\times D, A\times B = C\times D, or A\times B < C\times D.

To do that, you are allowed to ask the following query no more than q times:
query(a, b, c, d), where a, b, c, d are each an integer between -10^9 and 10^9.
Let X = a\times A + b\times B + c\times C + d\times D.
For each query we will tell you whether X > 0, X = 0, or X < 0.

It is guaranteed that the hidden integers are positive and do not exceed 10^9, i.e. 1 \le A, B, C, D \le 10^9.

EXPLANATION:

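Call a positive rational number nice if, in lowest terms, both its numerator and its denominator are at most 10^9.
Use the Stern–Brocot tree (equivalently, the continued-fraction form of the Euclidean algorithm) to find the nearest nice rational numbers below and above an arbitrary given rational number.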
References:
Page 47-50 of https://www.math.ru.nl/~bosma/Students/CF.pdf
https://cp-algorithms.com/others/stern_brocot_tree_farey_sequences.html

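Let x = D/A and y = B/C. Since A\times B < C\times D \iff B/C < D/A, comparing A\times B with C\times D is the same as comparing y with x. Both x and y are nice.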
st = 0/1, en = 1e9/1
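Given a nice fraction p/q, the query (-p, 0, 0, q) tells whether x < p/q, x = p/q or x > p/q: its value is qD - pA = qA\,(x - p/q), which has the sign of x - p/q. Similarly, the query (0, q, -p, 0) has value qB - pC = qC\,(y - p/q) and compares y with p/q.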
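Initially st < x, y \le en (both are at least 1/10^9 and at most 10^9).
Binary search on the interval (st, en) while x and y lie on the same side of every probe. Stop as soon as a probe equals one of them or separates them, because then the order of x and y is known.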
Binary search iteration:
mid = (st+en)/2
lo = nice rational number just below mid
hi = nice rational number just above mid
Since x and y are nice and no nice number lies strictly between lo and hi, neither x nor y lies in the open interval (lo, hi).
if(x,y > hi) st = hi
else if(x,y < lo) en = lo
else output the answer based on x < y, y < x, x=y

TIME COMPLEXITY:

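The interval initially has length 10^9. The search can stop once the interval is shorter than the minimum gap between two distinct nice rationals, which is at least 1/(10^9 \cdot 10^9) = 10^{-18}.
That gives about \log_2(10^{27}) \approx 90 iterations with 2 queries each, i.e. \mathcal{O}(2\log(10^{27})) queries per test case.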

SOLUTIONS:

Tester's Solution
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <utility>

// Every later `int` token becomes `long long` (at least 64 bits), so integer
// arithmetic in this file is done in long long; this is also why the entry
// point is spelled `signed main`.
#define int long long

// Bound on the hidden integers: a fraction is "nice" when both its numerator
// and denominator are at most MAX.
const int MAX = 1'000'000'000;
// true: talk to the interactive judge over stdin/stdout.
// false: local simulation. main() reads A, B, C, D for every test case,
// query()/res() evaluate against them, and the worst query count plus a
// correctness flag are printed at the end.
const int SUBMIT = true;

// Forward declaration of print(); the elaborated `struct frac` here also
// introduces the fraction type name before its definition below.
void print(struct frac &a);

// Euclid's algorithm, iterative form: returns gcd(a, b) for non-negative
// arguments, with gcd(0, b) == b.
int gcd(int a, int b) {
  while(a != 0) {
    const int rem = b % a;
    b = a;
    a = rem;
  }
  return b;
}

// A fraction a/b stored as a plain aggregate, so it can be brace-initialised
// as {numerator, denominator}.
struct frac
{
  int a;  // numerator
  int b;  // denominator

  // Divides numerator and denominator by gcd(a, b).
  void reduce() {
    const int g = gcd(a, b);
    a /= g;
    b /= g;
  }
};

// Set to 1 by guess() once the answer for the current test case has been
// reported; reset to 0 by main() at the start of every test case.
int solved;

// Decodes the judge's one-character reply: '>' -> 1, '=' -> 0, anything
// else -> -1.
int eval(char c) {
  switch(c) {
    case '>':
      return 1;
    case '=':
      return 0;
    default:
      return -1;
  }
}

// Hidden numbers; read from stdin and used only when SUBMIT is false.
int A, B, C, D;
// nq: queries asked in the current test case; max_q: largest nq seen so far.
int nq, max_q;

// Asks query(a, b, c, d), counts it in nq and returns
// sgn(a*A + b*B + c*C + d*D) as 1, 0 or -1.
// With SUBMIT set, the query is printed as "? a b c d" and the judge's
// one-character reply is decoded by eval(); otherwise the sign is computed
// directly from A, B, C, D.
int query(int a, int b, int c, int d) {
  nq++;
  if(SUBMIT) {
    // std::endl flushes, so the judge sees the query before we block on the reply.
    std::cout << "? " << a << " " << b << " " << c << " " << d << std::endl;
    // Named `reply` so it no longer shadows the coefficient parameter `c`;
    // initialised so a failed read yields a defined value instead of garbage.
    char reply = '\0';
    std::cin >> reply;
    return eval(reply);
  }
  else {
    // Assuming |coefficients| <= 1e9 (as the problem allows), each term is at
    // most 1e18 in magnitude and the sum fits in long long.
    int val = a*A + b*B + c*C + d*D;
    if(val > 0)
      return 1;
    if(val < 0)
      return -1;
    return 0;
  }
}

// Local-simulation verdict: main() sets it to 1, and res() clears it whenever
// a reported answer disagrees with the true comparison.
int cor;

// Reports the comparison of A*B with C*D: out > 0 means '>', out < 0 means
// '<', out == 0 means '='. With SUBMIT it prints "! <symbol>" and flushes;
// otherwise it checks `out` against the hidden values.
void res(int out) {
  if(!SUBMIT) {
    const int diff = A*B - C*D;
    const int expected = (diff > 0) - (diff < 0);
    if(expected != out)
      cor = false;
    return;
  }

  const char *symbol = out > 0 ? ">" : (out < 0 ? "<" : "=");
  std::cout << "! " << symbol << std::endl;
}

// Probes p = x.a/x.b (x.b > 0) against both hidden ratios u = A/D and
// v = C/B, using two queries in this order:
//   s1 = sgn(x.a*D - x.b*A) = sgn(p - u)
//   s2 = sgn(x.a*B - x.b*C) = sgn(p - v)
// A*B compared with C*D is the same as u compared with v. So if p equals u
// or v, or lies strictly between them, the answer is reported via res(),
// `solved` is set, and 0 is returned. Otherwise p is on the same side of
// both ratios, and that side is returned: 1 if p is above both, -1 if below.
// Once the case is solved, it does nothing and returns 0.
int guess(frac &x) {
  if(solved)
    return 0;

  const int s1 = query(-x.b, 0, 0, x.a);
  const int s2 = query(0, x.a, -x.b, 0);

  // Same nonzero sign: p does not separate u and v, keep searching.
  if(s1 != 0 && s1 == s2)
    return s1;

  // sgn(u - v) is determined by where p sits relative to u and v:
  //   p == u: sgn(u - v) = sgn(p - v) = s2.
  //   p == v: sgn(u - v) = -sgn(p - u) = -s1.
  //   p strictly between u and v: sgn(u - v) = s2 again.
  const int verdict = (s1 != 0 && s2 == 0) ? -s1 : s2;
  res(verdict);
  solved = 1;
  return 0;
}

// Returns {lo, hi}: the nearest "nice" fractions (numerator and denominator
// both <= MAX) with lo <= x <= hi.
//
// If x, after reduction, is itself nice, {x, x} is returned. Otherwise the
// function walks the Stern-Brocot tree in continued-fraction runs, keeping
// lo < x < hi as neighbours. It stops when the next mediant would exceed MAX;
// at that point no nice fraction lies strictly between lo and hi.
// Expects 0 <= x <= MAX with a positive denominator (main passes midpoints of
// its search interval).
//
// lErr = y.a*lo.b - lo.a*y.b and hErr = hi.a*y.b - y.a*hi.b are the scaled
// distances from lo and hi to x. They are updated together with lo and hi.
std::pair<frac, frac> bound(frac &x) {
  // Work on a reduced copy so the niceness test below is exact.
  frac y = x;
  y.reduce();

  // A nice x is its own nearest nice neighbour on both sides. Without this
  // check, the walk below lands exactly on x (its error drops to 0). It then
  // returns x together with a mediant that is NOT x's nice neighbour, e.g.
  // x = 1953125/2 gave {x, 2929688/3}. main() would then discard a range
  // that can still contain the hidden ratios, and could loop forever.
  if(y.a <= MAX && y.b <= MAX)
    return {y, y};

  frac lo = {y.a/y.b, 1}, hi = {y.a/y.b + (y.a%y.b != 0), 1};
  frac lo_f = lo, hi_f = hi;  // last lo / hi known to be nice
  int lErr = y.a - lo.a*y.b, hErr = hi.a*y.b - y.a;

  while(lErr && hErr) {
    if(lErr < hErr) {
      // Move hi towards x by `take` copies of lo (one continued-fraction
      // run), capped so hi stays nice.
      int take = hErr/lErr;
      if(lo.a) take = std::min(take, (MAX-hi.a)/lo.a);
      take = std::min(take, (MAX-hi.b)/lo.b);

      // Exact when take was not capped. If it was capped, lo + hi below
      // exceeds MAX and we break before these errors are used again.
      hErr %= lErr;
      hi.a += take*lo.a;
      hi.b += take*lo.b;

      hi_f = hi;

      // The mediant of lo and the new hi lies below x: it becomes the new lo.
      lo.a += hi.a;
      lo.b += hi.b;
      lErr -= hErr;

      if(lo.a > MAX || lo.b > MAX)
        break;

      lo_f = lo;
    }
    else {
      // Mirror image: move lo towards x by `take` copies of hi, capped so lo
      // stays nice; then the mediant of the new lo and hi becomes the new hi.
      int take = lErr/hErr;
      if(hi.a) take = std::min(take, (MAX-lo.a)/hi.a);
      take = std::min(take, (MAX-lo.b)/hi.b);

      lErr %= hErr;
      lo.a += take*hi.a;
      lo.b += take*hi.b;

      lo_f = lo;

      hi.a += lo.a;
      hi.b += lo.b;
      hErr -= lErr;

      if(hi.a > MAX || hi.b > MAX)
        break;

      hi_f = hi;
    }
  }

  return {lo_f, hi_f};
}

// Writes the fraction as "numerator denominator" on one line, then flushes.
void print(frac &a) {
  const int num = a.a, den = a.b;
  std::cout << num << " " << den << std::endl;
}

signed main() {
  cor = 1;
  std::ios::sync_with_stdio(false);
  std::cin.tie(0);

  int t, q;
  std::cin >> t >> q;

  while(t--) {
    solved = 0;
    nq = 0;
    if(!SUBMIT) {
      std::cin >> A >> B >> C >> D;
    }

    frac lo = {0, 1}, hi = {MAX, 1};
    guess(lo);

    while(!solved) {
      frac next;
      next.b = hi.b*lo.b*2;
      next.a = hi.b*lo.a + hi.a*lo.b;
      next.reduce();
      std::pair<frac, frac> mid = bound(next);

      int ret = guess(mid.second);

      if(ret == -1)
        lo = mid.second;
      else
        hi = mid.first;
    }

    max_q = std::max(max_q, nq);
  }

  if(!SUBMIT) {
    std::cout << max_q << std::endl;
    std::cout << cor << std::endl;
  }

  return 0;
}