# FOURARR - Editorial

Author: polarity
Testers: Satyam, Abhinav Sharma
Editorialist: Nishank Suresh

# DIFFICULTY:

To be calculated

# PREREQUISITES:

Binary search/2-pointers, fast convolution using FFT, prefix sums

# PROBLEM:

You have 4 arrays A, B, C, D and an integer K. Find the K-th smallest value of (A_x + B_y) \cdot (C_z + D_w) across all valid indices x, y, z, w.

# EXPLANATION:

In many tasks asking for the K-th largest or smallest object of some kind, binary search should immediately come to mind as a possible solution.
Indeed, binary search does work in this problem. For a fixed value X, let f(X) be the number of index tuples (x, y, z, w) such that (A_x + B_y) \cdot (C_z + D_w) \leq X. Since f is non-decreasing in X, we are looking for the smallest X such that f(X) \geq K. The remainder of this editorial details how to compute f(X) for a given X.

The given expression factors nicely into two parts, (A_x + B_y) and (C_z + D_w). Since every array element is at most 10^5, each of these parts individually does not exceed 2 \cdot 10^5.
Let’s fix a value of A_x + B_y, say r. Then, C_z + D_w can take any value s such that r\cdot s \leq X, i.e., s \leq \lfloor \frac{X}{r} \rfloor (for r > 0).

Now, say we magically had two arrays P and Q, where P_r denotes the number of pairs (x, y) such that A_x + B_y = r, and Q_s denotes the number of pairs (z, w) such that C_z + D_w = s.

Then, note that \displaystyle f(X) = \sum_{r = 0}^{2 \cdot 10^5} \sum_{s = 0}^{\lfloor \frac{X}{r} \rfloor}P_r Q_s

(For r = 0, \lfloor \frac{X}{0} \rfloor is undefined, but every s satisfies 0 \cdot s = 0 \leq X, so simply take the upper limit of the inner sum to be 2\cdot 10^5 and the sum works out.)
Now notice that the inner summation is just P_r multiplied by the prefix sum of Q up to index \lfloor \frac{X}{r} \rfloor. So once P, Q and the prefix sums of Q are known, f(X) as a whole can be computed in linear time.

Computing P and Q quickly, as it turns out, is a classical application of fast polynomial multiplication using FFT. For example, here is how one would compute P:

• Consider the polynomial a(x) of degree at most 10^5, where the coefficient of x^i is the number of times the value i appears in A.
• Similarly, consider the polynomial b(x) that encodes the frequency of elements of B.
• P is then simply the coefficient list of the product a(x) \cdot b(x), which can be computed in \mathcal{O}(N\log N) using FFT (or NTT) — any standard tutorial on fast polynomial multiplication covers this.

This brings us to the final solution:

• Use FFT to compute the arrays P and Q as defined above.
• Binary search over the value of the answer, X.
• Use P and Q to compute f(X) in linear time, and then update the bounds of the binary search appropriately.

Note that the binary search isn’t strictly necessary — since \left\lfloor \frac{X}{r+1} \right\rfloor \leq \left\lfloor \frac{X}{r} \right\rfloor, a 2-pointer technique can be used, iterating across P in increasing order and Q in decreasing order. The complexity doesn’t change much though, since it’s still dominated by the \mathcal{O}(N\log N) from FFT.

# TIME COMPLEXITY:

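\mathcal{O}(N\log N + N\log M), where N = 2\cdot 10^5 and M = N^2. Computing P and Q with FFT takes \mathcal{O}(N\log N), and each of the \mathcal{O}(\log M) steps of the binary search over the answer costs \mathcal{O}(N).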

# CODE:

Preparer
#include <bits/stdc++.h>
using namespace std;

/*
------------------------Input Checker----------------------------------
*/

// Reads one integer token from stdin (via getchar). The token must be
// terminated by exactly the character `endd`, and its value must lie in
// [l, r]; any violation aborts through assert.
//
// Format rules enforced:
//   * at most one '-', and only before the first digit;
//   * at least one digit (an empty token or a lone "-" is rejected);
//   * no leading zeros, and no "-0";
//   * at most 19 digits, with 19 only if the first digit is 1, so the value
//     always fits in a long long;
//   * no other character may appear before the terminator.
long long readInt(long long l,long long r,char endd){
    long long x = 0;
    int cnt = 0;          // digits read so far
    int fi = -1;          // first digit, -1 until one has been read
    bool is_neg = false;
    while (true) {
        int g = getchar();  // int, so EOF cannot be mistaken for a real byte
        if (g == '-') {
            // A sign is allowed once, and only before any digit.
            assert(!is_neg && fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);           // no leading zeros
            assert(fi != 0 || is_neg == false);    // no "-0"
            // Length check happens *before* the multiply, so x never overflows.
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
            x = x * 10 + (g - '0');
        } else if (g == endd) {
            // Reject an empty token ("" or a lone "-").
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }

            if (!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Declaration only, so this group of helpers stands on its own; readInt is
// defined earlier in this file.
long long readInt(long long l, long long r, char endd);

// Reads raw characters up to (not including) `endd` and asserts that the
// token's length lies in [l, r]. Reaching EOF before `endd` is a format error.
std::string readString(int l, int r, char endd) {
    std::string ret = "";
    int cnt = 0;
    while (true) {
        // int, so EOF is detected even on platforms where char is unsigned.
        int g = getchar();
        assert(g != EOF);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += static_cast<char>(g);
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}

// Integer in [l, r] terminated by a single space.
long long readIntSp(long long l, long long r) {
    return readInt(l, r, ' ');
}

// Integer in [l, r] terminated by a newline.
long long readIntLn(long long l, long long r) {
    return readInt(l, r, '\n');
}

// String whose length is in [l, r], terminated by a newline.
std::string readStringLn(int l, int r) {
    return readString(l, r, '\n');
}

// String whose length is in [l, r], terminated by a single space.
std::string readStringSp(int l, int r) {
    return readString(l, r, ' ');
}

/*
------------------------Main code starts here----------------------------------
*/

// Limits carried over from the preparer's input-checker template.
const int MAX_T = 1e5;
const int MAX_N = 1e5;
const int MAX_SUM_LEN = 1e5;

// Shorthand macros. NOTE: `ff`, `ss`, `mp`, `ll`, `pb` and `fast` are
// object-like macros, so *any* later identifier spelled that way is rewritten
// by the preprocessor (e.g. a variable named `pb` becomes `push_back`).
// The loop macros substitute their arguments unparenthesised, so pass simple
// expressions only.
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back

// Template statistics counters; no live code in this listing reads them
// (they appear only in commented-out diagnostics in main()).
int sum_n = 0, sum_m = 0, sum = 0;
int max_n = 0, max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
// Not used by this listing.
ll mod = 1000000007;
// Length of the frequency arrays built in solve(). Array elements are used
// directly as indices into them, so every element must be < sz.
int sz = 100001;

// Not used by this listing.
using ii = pair<ll,ll>;

// Complex type used by fft(); long double rather than double for extra
// precision before the products are rounded back to integers.
using cd = complex<long double>;
// pi, computed as acos(-1).
const long double PI = acos(-1.0);

// In-place iterative radix-2 FFT over complex<long double>.
// Expects a.size() to be a power of two. With inv == false it computes the
// forward transform (twiddle angle +2*pi/len). With inv == true it uses the
// opposite angle and then divides every entry by n, so a forward transform
// followed by an inverse one restores the input up to rounding error.
void fft(vector<cd> &a, bool inv){
    const int n = a.size();

    // Number of bits needed to index n entries.
    int bits = 0;
    while ((1ll << bits) < n) ++bits;

    // Move every element to its bit-reversed index.
    for (int i = 0; i < n; ++i) {
        int mirrored = 0;
        for (int j = 0; j < bits; ++j) {
            if ((i >> j) & 1) mirrored |= (1 << (bits - j - 1));
        }
        if (i < mirrored) swap(a[i], a[mirrored]);
    }

    // Butterfly passes over blocks of length len = 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        const long double theta = 2*PI/len*(inv?-1:1);
        const cd step(cos(theta), sin(theta));
        const int half = len / 2;
        for (int start = 0; start < n; start += len) {
            cd w(1.0);
            for (int j = start; j < start + half; ++j) {
                const cd even = a[j];
                const cd odd = a[j + half] * w;
                a[j] = even + odd;
                a[j + half] = even - odd;
                w *= step;
            }
        }
    }

    if (inv) {
        for (cd &x : a) x /= n;
    }
}

// Multiplies two integer polynomials (coefficient vectors, lowest degree
// first) using fft(). Returns the product's coefficients rounded to the
// nearest integer.
//
// The result has the padded transform length -- the smallest power of two
// >= p1.size() + p2.size() -- so entries beyond the true degree are (rounded)
// zeros. The inputs are only read, hence the const references.
vector<long long> poly_mul(const vector<int> &p1, const vector<int> &p2){
    size_t n = 1;
    while (n < p1.size() + p2.size()) n *= 2;

    // The buffers are value-initialised to 0, so only the input coefficients
    // need copying. They are deliberately not named `pb`: that identifier is
    // a file-level macro for push_back.
    vector<cd> fa(n), fb(n);
    copy(p1.begin(), p1.end(), fa.begin());
    copy(p2.begin(), p2.end(), fb.begin());

    fft(fa, false);
    fft(fb, false);

    for (size_t i = 0; i < n; ++i) {
        fa[i] *= fb[i];
    }

    fft(fa, true);

    // The exact product is an integer; llround removes the FFT rounding noise.
    vector<long long> ret(n);
    for (size_t i = 0; i < n; ++i) {
        ret[i] = llround(real(fa[i]));
    }
    return ret;
}

void solve()
{

vector<int> a(sz,0), b(sz,0), c(sz,0), d(sz,0);

int x;
rep(i,sa){
a[x]++;
}
rep(i,sb){
b[x]++;
}
rep(i,sc){
c[x]++;
}
rep(i,sd){
d[x]++;
}

vector<long long> v1 = poly_mul(a,b), v2 = poly_mul(c,d);

int n = v1.size();
rep_a(i,1,n) v1[i]+=v1[i-1];

ll lo = 0, hi = 5e10;

while(lo<hi){
ll mid = (lo+hi)>>1;
ll cnt = 0;
int p1 = 0, p2 = n-1;
while(p2>=0){
while(p1<n && p1*p2<=mid) p1++;
cnt += v2[p2]*(p1>0?v1[p1-1]:0);
p2--;
}

if(cnt<k) lo = mid+1;
else hi = mid;
}

cout<<hi<<'\n';
}

// Driver for the preparer's solution. When ONLINE_JUDGE is not defined, it
// redirects stdin/stdout to input.txt/output.txt. After the single test case
// it asserts that the input has been fully consumed and prints diagnostics
// to stderr.
signed main()
{
#ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
#endif
    fast;

    int t = 1;
    int done = 0;
    while (done < t) {
        solve();
        ++done;
    }

    // Any trailing input, even whitespace, means the test file is malformed.
    assert(getchar() == -1);

    cerr << "SUCCESS\n"
         << "Tests : " << t << '\n';
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

Editorialist

#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long;  // the same type as `long long int`
// Clock-seeded 64-bit Mersenne Twister; unused in this listing.
std::mt19937_64 rng(std::chrono::high_resolution_clock::now().time_since_epoch().count());

// https://judge.yosupo.jp/submission/69895
namespace ntt {

// Binary exponentiation: combines n copies of `a` with `op`, starting from
// the identity `e` (a^n for the default multiplication). n must be >= 0.
template <class T, class F = std::multiplies<T>>
T power(T a, long long n, F op = std::multiplies<T>(), T e = {1}) {
    // assert(n >= 0);
    T res = e;
    while (n) {
        if (n & 1) res = op(res, a);
        if (n >>= 1) a = op(a, a);
    }
    return res;
}

// `mod` is not referenced elsewhere in this listing. `nttmod` is the prime
// 998244353 = 119 * 2^23 + 1, which the transform below works modulo.
constexpr int mod = int(1e9) + 7;
constexpr int nttmod = 998'244'353;

// Integer modulo the odd prime P (< 2^30), kept in Montgomery form: the
// member `v` stores x * 2^32 mod P, lazily kept in the range [0, 2P).
template <std::uint32_t P>
struct ModInt32 {
public:
    using i32 = std::int32_t;
    using u32 = std::uint32_t;
    using i64 = std::int64_t;
    using u64 = std::uint64_t;
    using m32 = ModInt32;
    using internal_value_type = u32;

private:
    u32 v;
    // -P^{-1} mod 2^32, computed by Newton iteration on the inverse of P.
    static constexpr u32 get_r() {
        u32 iv = P;
        for (u32 i = 0; i != 4; ++i) iv *= 2U - P * iv;
        return -iv;
    }
    // r2 = 2^64 mod P converts a plain integer into Montgomery form.
    static constexpr u32 r = get_r(), r2 = -u64(P) % P;
    static_assert((P & 1) == 1);
    static_assert(-r * P == 1);
    static_assert(P < (1 << 30));
    static constexpr u32 pow_mod(u32 x, u64 y) {
        u32 res = 1;
        for (; y != 0; y >>= 1, x = u64(x) * x % P)
            if (y & 1) res = u64(res) * x % P;
        return res;
    }
    // Montgomery reduction: returns a value congruent to x * 2^-32 (mod P).
    static constexpr u32 reduce(u64 x) {
        return (x + u64(u32(x) * r) * P) >> 32;
    }
    // Maps [0, 2P) onto [0, P).
    static constexpr u32 norm(u32 x) { return x - (P & -(x >= P)); }

public:
    // Smallest primitive root modulo P: factors P - 1, then tests candidates.
    static constexpr u32 get_pr() {
        u32 tmp[32] = {}, cnt = 0;
        const u64 phi = P - 1;
        u64 m = phi;
        for (u64 i = 2; i * i <= m; ++i)
            if (m % i == 0) {
                tmp[cnt++] = i;
                while (m % i == 0) m /= i;
            }
        if (m != 1) tmp[cnt++] = m;
        for (u64 res = 2; res != P; ++res) {
            bool flag = true;
            for (u32 i = 0; i != cnt && flag; ++i)
                flag &= pow_mod(res, phi / tmp[i]) != 1;
            if (flag) return res;
        }
        return 0;
    }
    constexpr ModInt32() : v(0) {}
    ~ModInt32() = default;
    constexpr ModInt32(u32 _v) : v(reduce(u64(_v) * r2)) {}
    // The remainder is taken in signed 64-bit arithmetic. With an int32 _v,
    // `_v % P` would first convert _v to unsigned (P is a uint32), so a
    // negative input such as -1 would land on the wrong residue.
    constexpr ModInt32(i32 _v) : v(reduce(u64(_v % i64(P) + P) * r2)) {}
    constexpr ModInt32(u64 _v) : v(reduce((_v % P) * r2)) {}
    constexpr ModInt32(i64 _v) : v(reduce(u64(_v % P + P) * r2)) {}
    constexpr ModInt32(const m32& rhs) : v(rhs.v) {}
    // Canonical representative in [0, P).
    constexpr u32 get() const { return norm(reduce(v)); }
    explicit constexpr operator u32() const { return get(); }
    explicit constexpr operator i32() const { return i32(get()); }
    constexpr m32& operator=(const m32& rhs) { return v = rhs.v, *this; }
    constexpr m32 operator-() const {
        m32 res;
        return res.v = (P << 1 & -(v != 0)) - v, res;
    }
    // Modular inverse by Fermat's little theorem (P must be prime).
    constexpr m32 inv() const { return pow(P - 2); }
    constexpr m32& operator+=(const m32& rhs) {
        return v += rhs.v - (P << 1), v += P << 1 & -(v >> 31), *this;
    }
    constexpr m32& operator-=(const m32& rhs) {
        return v -= rhs.v, v += P << 1 & -(v >> 31), *this;
    }
    constexpr m32& operator*=(const m32& rhs) {
        return v = reduce(u64(v) * rhs.v), *this;
    }
    constexpr m32& operator/=(const m32& rhs) {
        return this->operator*=(rhs.inv());
    }
    friend m32 operator+(const m32& lhs, const m32& rhs) {
        return m32(lhs) += rhs;
    }
    friend m32 operator-(const m32& lhs, const m32& rhs) {
        return m32(lhs) -= rhs;
    }
    friend m32 operator*(const m32& lhs, const m32& rhs) {
        return m32(lhs) *= rhs;
    }
    friend m32 operator/(const m32& lhs, const m32& rhs) {
        return m32(lhs) /= rhs;
    }
    friend bool operator==(const m32& lhs, const m32& rhs) {
        return norm(lhs.v) == norm(rhs.v);
    }
    friend bool operator!=(const m32& lhs, const m32& rhs) {
        return norm(lhs.v) != norm(rhs.v);
    }
    friend std::istream& operator>>(std::istream& is, m32& rhs) {
        return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
    }
    friend std::ostream& operator<<(std::ostream& os, const m32& rhs) {
        return os << rhs.get();
    }
    // this^y. The exponent is reduced modulo P - 1 (Fermat), so P must be
    // prime; a positive multiple of P - 1 is kept as P - 1, not 0.
    // A negative remainder is shifted into [0, P - 1): without that, the
    // shift loop below would never terminate for negative exponents.
    constexpr m32 pow(i64 y) const {
        // assumes P is a prime
        i64 rem = y % (P - 1);
        if (rem < 0) rem += P - 1;
        if (y > 0 && rem == 0)
            y = P - 1;
        else
            y = rem;
        m32 res(1), x(*this);
        for (; y != 0; y >>= 1, x *= x)
            if (y & 1) res *= x;
        return res;
    }
};

using mint = ModInt32<nttmod>;

// In-place NTT modulo 998244353; a.size() must be a power of two (asserted).
// There is no bit-reversal pass: the inverse (inverse == true) is written to
// consume the forward pass's output order and also divides by n. Transformed
// vectors are therefore meant to be multiplied pointwise and then inverted.
void ntt(std::vector<mint>& a, bool inverse) {
    // Twiddle tables, filled on the first call (dw[0] stays 0 until then).
    static std::array<mint, 30> dw{}, idw{};
    if (dw[0] == 0) {
        // Smallest quadratic non-residue by Euler's criterion; for this prime
        // its powers supply the required 2^k-th roots of unity.
        mint root = 2;
        while (power(root, (nttmod - 1) / 2) == 1) root += 1;
        for (int i = 0; i < 30; ++i)
            dw[i] = -power(root, (nttmod - 1) >> (i + 2)),
            idw[i] = 1 / dw[i];
    }
    int n = (int)a.size();
    assert((n & (n - 1)) == 0);
    if (not inverse) {
        for (int m = n; m >>= 1;) {
            mint w = 1;
            for (int s = 0, k = 0; s < n; s += 2 * m) {
                for (int i = s, j = s + m; i < s + m; ++i, ++j) {
                    auto x = a[i], y = a[j] * w;
                    a[i] = x + y;
                    a[j] = x - y;
                }
                w *= dw[__builtin_ctz(++k)];
            }
        }
    } else {
        for (int m = 1; m < n; m *= 2) {
            mint w = 1;
            for (int s = 0, k = 0; s < n; s += 2 * m) {
                for (int i = s, j = s + m; i < s + m; ++i, ++j) {
                    auto x = a[i], y = a[j];
                    a[i] = x + y;
                    a[j] = (x - y) * w;
                }
                w *= idw[__builtin_ctz(++k)];
            }
        }
        auto inv = 1 / mint(n);
        for (auto&& e : a) e *= inv;
    }
}

// Polynomial product (convolution) of coefficient vectors, modulo 998244353.
// Uses the schoolbook O(n*m) method when either operand has fewer than 30
// coefficients. Otherwise it uses an NTT of length sz, the smallest power of
// two >= n + m - 1. Squaring (l == r) skips the second forward transform.
std::vector<mint> operator*(std::vector<mint> l, std::vector<mint> r) {
    if (l.empty() or r.empty()) return {};
    int n = (int)l.size(), m = (int)r.size(),
        sz = 1 << std::__lg(2 * (n + m - 1) - 1);
    if (std::min(n, m) < 30) {
        std::vector<mint> res(n + m - 1);
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j) res[i + j] += (l[i] * r[j]);
        return res;
    }
    bool eq = l == r;
    l.resize(sz), ntt(l, false);
    if (eq)
        r = l;
    else
        r.resize(sz), ntt(r, false);
    for (int i = 0; i < sz; ++i) l[i] *= r[i];
    ntt(l, true), l.resize(n + m - 1);
    return l;
}

// In-place form of operator*: replaces l by the product l * r. It delegates
// to operator* instead of duplicating its body.
std::vector<mint>& operator*=(std::vector<mint>& l, std::vector<mint> r) {
    l = std::move(l) * std::move(r);
    return l;
}
}  // namespace ntt

int main()
{
ios::sync_with_stdio(false); cin.tie(0);

using mint = ntt::mint;
const int MAXN = 1e5 + 10;
ll A, B, C, D, k; cin >> A >> B >> C >> D >> k;
vector<mint> a(MAXN), b(MAXN), c(MAXN), d(MAXN);
for (int i = 0; i < A; ++i) {
int x; cin >> x;
a[x] += 1;
}
for (int i = 0; i < B; ++i) {
int x; cin >> x;
b[x] += 1;
}
for (int i = 0; i < C; ++i) {
int x; cin >> x;
c[x] += 1;
}
for (int i = 0; i < D; ++i) {
int x; cin >> x;
d[x] += 1;
}
auto res1 = a*b, res2 = c*d;
for (int i = 1; i < res2.size(); ++i) res2[i] += res2[i-1];
ll lo = 0, hi = 1e11;
while (lo < hi) {
ll mid = (lo + hi)/2;
ll leq = 0;
for (int i = 0; i < res1.size(); ++i) {
ll id = res2.size()-1;
if (i) id = min(id, mid/i);
// leq += (res1[i] * res2[id]).get();
leq += 1LL * res1[i].get() * res2[id].get();
}
if (leq < k) lo = mid+1;
else hi = mid;
}
cout << lo;
}


Small edit to render the sum properly:

f(X) = \sum_{r = 0}^{2 \cdot 10^5} \sum_{s = 0}^{\lfloor \frac{X}{r} \rfloor}P_r Q_s

Good catch, I’ve fixed it.