FOURARR - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: # @#@polarity@#@
Testers: Satyam, Abhinav Sharma
Editorialist: Nishank Suresh

DIFFICULTY:

To be calculated

PREREQUISITES:

Binary search/2-pointers, fast convolution using FFT, prefix sums

PROBLEM:

You have 4 arrays A, B, C, D and an integer K. Find the K-th smallest value of (A_x + B_y) \cdot (C_z + D_w) across all valid indices x, y, z, w.

EXPLANATION:

In many tasks asking for the K-th largest or smallest object of some kind, binary search should immediately come to mind as a possible solution.
Indeed, binary search does work in this problem — suppose we fix a value of X, then compute f(X): the number of index tuples (x, y, z, w) for which (A_x + B_y) \cdot (C_z + D_w) \leq X. Since f is non-decreasing in X, we are looking for the smallest X such that f(X) \geq K. The remainder of this editorial will detail how to compute f(X) given X.

The given expression factors nicely into two parts, (A_x + B_y) and (C_z + D_w). Note that each of these parts individually does not exceed 2 \cdot 10^5, since every array element is at most 10^5.
Let’s fix a value of A_x + B_y, say r. Then, C_z + D_w can take any value s such that r\cdot s \leq X, i.e., s \leq \lfloor \frac{X}{r} \rfloor (for r > 0).

Now, say we magically had two arrays P and Q, where P_r denotes the number of pairs (x, y) such that A_x + B_y = r, and Q_s denotes the number of pairs (z, w) such that C_z + D_w = s.

Then, note that \displaystyle f(X) = \sum_{r = 0}^{2 \cdot 10^5} \sum_{s = 0}^{\lfloor \frac{X}{r} \rfloor}P_r Q_s

(\lfloor \frac{X}{0} \rfloor isn’t defined, but for r = 0 every s works because 0 \cdot s = 0 \leq X; so just treat it as 2\cdot 10^5, the largest possible s, and the sum works out.)
Now notice that the second summation is really just the prefix sum of Q up to index \lfloor \frac{X}{r} \rfloor. So once the prefix sums of Q are precomputed, f(X) as a whole can be computed in linear time, provided we know P and Q.

Computing P and Q quickly, as it turns out, is a classical application of fast polynomial multiplication using FFT. For example, here is how one would compute P:

  • Consider the polynomial a(x) of degree at most 10^5, where the coefficient of x^i is the number of times the value i appears in A.
  • Similarly, consider the polynomial b(x) that encodes the frequency of elements of B.
  • P is then simply the product a(x) \cdot b(x): the coefficient of x^r in it is exactly P_r. This product can be computed in \mathcal{O}(N\log N) using FFT — a tutorial on this is linked above.

This brings us to the final solution:

  • Use FFT to compute the arrays P and Q as defined above.
  • Binary search over the value of the answer, X.
  • Use P and Q to compute f(X) in linear time, and then update the bounds of the binary search appropriately

Note that the division by r isn’t strictly necessary when computing f(X). Since \left\lfloor \frac{X}{r+1} \right\rfloor \leq \left\lfloor \frac{X}{r} \right\rfloor, f(X) can also be computed with a 2-pointer technique, iterating across P in increasing order and Q in decreasing order. The complexity doesn’t change much though, since it’s still dominated by the \mathcal{O}(N\log N) from FFT.

TIME COMPLEXITY:

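\mathcal{O}(N\log N + N\log M), where N = 2\cdot 10^5 and M = N^2. The FFT takes \mathcal{O}(N\log N), and each of the \mathcal{O}(\log M) binary search steps computes f(X) in \mathcal{O}(N).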

CODE:

Preparer
#include <bits/stdc++.h>
using namespace std;
 
 
/*
------------------------Input Checker----------------------------------
*/
 
/*
 * Input-checker helper: reads one integer token from stdin and validates it.
 * The token is an optional single leading '-' followed by at least one digit.
 * It must be terminated by exactly `endd` and its value must lie in [l, r].
 * Any violation fails an assert: empty token, lone or repeated '-', leading
 * zeros, "-0", more digits than a long long holds, foreign character, or EOF.
 */
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;          // number of digits consumed so far
    int fi=-1;          // value of the first digit, -1 until one is read
    bool is_neg=false;
    while(true){
        // Keep getchar()'s int result so EOF is never confused with a data byte.
        const int g=getchar();
        if(g=='-'){
            // Exactly one '-' is allowed, and only before the first digit.
            assert(!is_neg && cnt==0);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);          // no leading zeros
            assert(fi!=0 || is_neg==false);   // no "-0"
            // Check the length BEFORE accumulating, so x*10 can never overflow.
            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
            x=x*10+(g-'0');
        } else if(g==static_cast<unsigned char>(endd)){
            // A token with no digits (e.g. "" or "-") is malformed.
            assert(cnt>0);
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                std::cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
/*
 * Input-checker helper: reads characters from stdin up to the terminator
 * `endd`. The terminator is consumed but not returned. Asserts that EOF is not
 * reached first and that the resulting length lies in [l, r].
 */
string readString(int l,int r,char endd){
    string ret="";
    while(true){
        // int, not char: a plain (possibly unsigned) char cannot reliably hold EOF.
        const int g=getchar();
        assert(g!=EOF);
        if(g==static_cast<unsigned char>(endd)){
            break;
        }
        ret+=static_cast<char>(g);
    }
    const int cnt=static_cast<int>(ret.size());
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Convenience wrappers around readInt / readString. The "Sp" variants expect
// the token to be followed by a single space; the "Ln" variants expect a newline.
long long readIntSp(long long l,long long r){
    const char terminator=' ';
    return readInt(l,r,terminator);
}
long long readIntLn(long long l,long long r){
    const char terminator='\n';
    return readInt(l,r,terminator);
}
string readStringLn(int l,int r){
    const char terminator='\n';
    return readString(l,r,terminator);
}
string readStringSp(int l,int r){
    const char terminator=' ';
    return readString(l,r,terminator);
}
 
 
/*
------------------------Main code starts here----------------------------------
*/
 
// Limits carried over from the checker template (not referenced by the active code).
constexpr int MAX_T = 100000;
constexpr int MAX_N = 100000;
constexpr int MAX_SUM_LEN = 100000;

// Shorthand macros used throughout the solution.
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define ff first
#define ss second
#define mp make_pair
#define ll long long
#define rep(i,n) for(int i=0;i<n;i++)
#define rev(i,n) for(int i=n;i>=0;i--)
#define rep_a(i,a,n) for(int i=a;i<n;i++)
#define pb push_back

// Template bookkeeping counters; not used by the active code.
int sum_n = 0;
int sum_m = 0;
int sum = 0;
int max_n = 0;
int max_m = 0;
int yess = 0;
int nos = 0;
int total_ops = 0;
ll mod = 1000000007;

// Length of each value-frequency array (valid indices 0 .. sz-1).
int sz = 100001;

using ii = pair<ll,ll>;


using cd = complex<long double>;
// Full long-double precision; acos(-1.0) would only give a double-precision PI.
const long double PI = std::acos(-1.0L);

/*
 * In-place iterative radix-2 FFT over complex<long double>.
 * a.size() must be a power of two (poly_mul guarantees this).
 * inv == false: forward transform.
 * inv == true: inverse transform, including the final division by n.
 */
void fft(vector<cd> &a, bool inv){
    int n = a.size();
    int logn = 0;

    while((1ll<<logn) < n) logn++;

    // Bit-reversal permutation, so the butterflies below can work in place.
    for(int i=0; i<n; i++){
        int rev_i=0;
        for(int j=0; j<logn; j++){
            if((i>>j)&1) rev_i |= (1<<(logn-j-1));
        }
        if(i < rev_i) swap(a[i], a[rev_i]);
    }

    int k=2;
    long double ang;
    cd u,v;
    while(k<=n){
        ang = 2*PI/k*(inv?-1:1);
        cd wn(cos(ang), sin(ang));
        for(int i=0; i<n; i+=k){
            cd w(1.0);
            for(int j=i; j<i+k/2; j++){
                u = a[j];
                v = a[j+k/2]*w;
                a[j] = u+v;
                a[j+k/2] = u-v;
                w*=wn;
            }
        }
        k<<=1;
    }

    if(inv){
        for(int i=0; i<n; i++) a[i]/=n;
    }
}

/*
 * Returns the coefficients of p1(x) * p2(x), where p[i] is the coefficient of x^i.
 * They are computed with a floating-point FFT and rounded to the nearest
 * integer, so they are exact only while the accumulated rounding error stays
 * below 0.5.
 * The result has length n, the smallest power of two >= p1.size() + p2.size().
 * Entries beyond the true product degree are 0.
 */
vector<long long> poly_mul(const vector<int> &p1, const vector<int> &p2){
    size_t n=1;
    while(n < p1.size()+p2.size()) n*=2;

    // vector<cd>(n) value-initialises every entry to 0, which is the zero
    // padding. The buffers are not named `pb`, which is a macro in this file.
    vector<cd> fa(n), fb(n);
    copy(p1.begin(), p1.end(), fa.begin());
    copy(p2.begin(), p2.end(), fb.begin());

    fft(fa, false);
    fft(fb, false);
    for(size_t i=0; i<n; i++) fa[i] *= fb[i];
    fft(fa, true);

    vector<long long> ret(n);
    for(size_t i=0; i<n; i++) ret[i] = std::llround(real(fa[i]));
    return ret;
}


void solve()
{ 

    int sa = readIntSp(1,3e4);
    int sb = readIntSp(1,3e4);
    int sc = readIntSp(1,3e4);
    int sd = readIntSp(1,3e4);
    ll k = readIntLn(1, (ll)sa*sb*sc*sd);

    vector<int> a(sz,0), b(sz,0), c(sz,0), d(sz,0);

    int x;
    rep(i,sa){
        if(i<sa-1) x = readIntSp(0,1e6);
        else x = readIntLn(0,1e6);
        a[x]++;
    }
    rep(i,sb){
        if(i<sb-1) x = readIntSp(0,1e6);
        else x = readIntLn(0,1e6);
        b[x]++;
    }
    rep(i,sc){
        if(i<sc-1) x = readIntSp(0,1e6);
        else x = readIntLn(0,1e6);
        c[x]++;
    }
    rep(i,sd){
        if(i<sd-1) x = readIntSp(0,1e6);
        else x = readIntLn(0,1e6);
        d[x]++;
    }

    vector<long long> v1 = poly_mul(a,b), v2 = poly_mul(c,d);

    int n = v1.size();
    rep_a(i,1,n) v1[i]+=v1[i-1];

    ll lo = 0, hi = 5e10;

    while(lo<hi){
        ll mid = (lo+hi)>>1;
        ll cnt = 0;
        int p1 = 0, p2 = n-1;
        while(p2>=0){
            while(p1<n && p1*p2<=mid) p1++;
            cnt += v2[p2]*(p1>0?v1[p1-1]:0);
            p2--;
        }

        if(cnt<k) lo = mid+1;
        else hi = mid;
    }

    cout<<hi<<'\n';
}
 
signed main()
{
    // Local runs redirect I/O to files; on the judge ONLINE_JUDGE is defined.
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r" , stdin);
    freopen("output.txt", "w" , stdout);
    #endif
    fast;

    // Single test case (the multi-test header read `t = readIntLn(1,1e5);` is disabled).
    const int t = 1;
    for (int tc = 0; tc < t; ++tc) solve();

    // The checker requires that the whole input has been consumed.
    assert(getchar() == -1);

    cerr << "SUCCESS\n";
    cerr << "Tests : " << t << '\n';
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}
Editorialist

#include "bits/stdc++.h"
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
using namespace std;
using ll = long long;
// Random engine seeded from the current time; main below does not use it.
mt19937_64 rng(static_cast<unsigned long long>(
	chrono::high_resolution_clock::now().time_since_epoch().count()));

// https://judge.yosupo.jp/submission/69895
namespace ntt {

	/// Binary exponentiation: returns a^n under the associative operation `op`,
	/// whose identity element is `e`. Expects n >= 0.
	template <class T, class F = multiplies<T>>
	T power(T a, long long n, F op = multiplies<T>(), T e = {1}) {
		// assert(n >= 0);
		T res = e;
		while (n) {
			if (n & 1) res = op(res, a);
			if (n >>= 1) a = op(a, a);
		}
		return res;
	}

	constexpr int mod = int(1e9) + 7;
	constexpr int nttmod = 998'244'353;

	/// Residue modulo the odd prime P (< 2^30) in Montgomery representation.
	/// The stored value v is congruent to a * 2^32 (mod P), where a is the
	/// represented residue. Arithmetic keeps v only partially reduced; get()
	/// and the comparisons normalise it back to [0, P).
	template <std::uint32_t P>
	struct ModInt32 {
	   public:
		using i32 = std::int32_t;
		using u32 = std::uint32_t;
		using i64 = std::int64_t;
		using u64 = std::uint64_t;
		using m32 = ModInt32;
		using internal_value_type = u32;

	   private:
		u32 v;
		// Returns -P^{-1} mod 2^32: Newton iterations refine the inverse of
		// the odd number P modulo 2^32, which is then negated.
		static constexpr u32 get_r() {
			u32 iv = P;
			for (u32 i = 0; i != 4; ++i) iv *= 2U - P * iv;
			return -iv;
		}
		// r = -P^{-1} mod 2^32; r2 = 2^64 mod P (reduce(a * r2) puts a into
		// Montgomery form).
		static constexpr u32 r = get_r(), r2 = -u64(P) % P;
		static_assert((P & 1) == 1);
		static_assert(-r * P == 1);
		static_assert(P < (1 << 30));
		// Plain square-and-multiply modular power (used by get_pr).
		static constexpr u32 pow_mod(u32 x, u64 y) {
			u32 res = 1;
			for (; y != 0; y >>= 1, x = u64(x) * x % P)
				if (y & 1) res = u64(res) * x % P;
			return res;
		}
		// Montgomery reduction: returns a value congruent to x * 2^{-32} (mod P).
		static constexpr u32 reduce(u64 x) {
			return (x + u64(u32(x) * r) * P) >> 32;
		}
		// Subtracts P once if x >= P (maps [0, 2P) onto [0, P)).
		static constexpr u32 norm(u32 x) { return x - (P & -(x >= P)); }

	   public:
		// Smallest g in [2, P) with g^((P-1)/q) != 1 for every prime factor q
		// of P - 1, i.e. the smallest primitive root; 0 if none is found.
		static constexpr u32 get_pr() {
			u32 tmp[32] = {}, cnt = 0;
			const u64 phi = P - 1;
			u64 m = phi;
			for (u64 i = 2; i * i <= m; ++i)
				if (m % i == 0) {
					tmp[cnt++] = i;
					while (m % i == 0) m /= i;
				}
			if (m != 1) tmp[cnt++] = m;
			for (u64 res = 2; res != P; ++res) {
				bool flag = true;
				for (u32 i = 0; i != cnt && flag; ++i)
					flag &= pow_mod(res, phi / tmp[i]) != 1;
				if (flag) return res;
			}
			return 0;
		}
		constexpr ModInt32() : v(0) {}
		~ModInt32() = default;
		constexpr ModInt32(u32 _v) : v(reduce(u64(_v) * r2)) {}
		// The remainder is taken in signed 64-bit arithmetic. Writing `_v % P`
		// directly would convert a negative i32 to unsigned first and give the
		// wrong residue (e.g. -1 would not map to P - 1).
		constexpr ModInt32(i32 _v) : v(reduce(u64(i64(_v) % i64(P) + i64(P)) * r2)) {}
		constexpr ModInt32(u64 _v) : v(reduce((_v % P) * r2)) {}
		constexpr ModInt32(i64 _v) : v(reduce(u64(_v % i64(P) + i64(P)) * r2)) {}
		constexpr ModInt32(const m32& rhs) : v(rhs.v) {}
		// Canonical value in [0, P).
		constexpr u32 get() const { return norm(reduce(v)); }
		explicit constexpr operator u32() const { return get(); }
		explicit constexpr operator i32() const { return i32(get()); }
		constexpr m32& operator=(const m32& rhs) { return v = rhs.v, *this; }
		constexpr m32 operator-() const {
			m32 res;
			return res.v = (P << 1 & -(v != 0)) - v, res;
		}
		// Multiplicative inverse via Fermat's little theorem (P prime, *this != 0).
		constexpr m32 inv() const { return pow(P - 2); }
		constexpr m32& operator+=(const m32& rhs) {
			return v += rhs.v - (P << 1), v += P << 1 & -(v >> 31), *this;
		}
		constexpr m32& operator-=(const m32& rhs) {
			return v -= rhs.v, v += P << 1 & -(v >> 31), *this;
		}
		constexpr m32& operator*=(const m32& rhs) {
			return v = reduce(u64(v) * rhs.v), *this;
		}
		constexpr m32& operator/=(const m32& rhs) {
			return this->operator*=(rhs.inv());
		}
		friend m32 operator+(const m32& lhs, const m32& rhs) {
			return m32(lhs) += rhs;
		}
		friend m32 operator-(const m32& lhs, const m32& rhs) {
			return m32(lhs) -= rhs;
		}
		friend m32 operator*(const m32& lhs, const m32& rhs) {
			return m32(lhs) *= rhs;
		}
		friend m32 operator/(const m32& lhs, const m32& rhs) {
			return m32(lhs) /= rhs;
		}
		friend bool operator==(const m32& lhs, const m32& rhs) {
			return norm(lhs.v) == norm(rhs.v);
		}
		friend bool operator!=(const m32& lhs, const m32& rhs) {
			return norm(lhs.v) != norm(rhs.v);
		}
		friend std::istream& operator>>(std::istream& is, m32& rhs) {
			return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
		}
		friend std::ostream& operator<<(std::ostream& os, const m32& rhs) {
			return os << rhs.get();
		}
		/// Returns (*this)^y for any integer y. A negative y yields a power of
		/// the inverse.
		constexpr m32 pow(i64 y) const {
			// Assumes P is prime, so x^(P-1) == 1 for x != 0 and the exponent
			// can be reduced modulo P - 1.
			i64 rem = y % (P - 1);
			// Map a negative remainder to its non-negative equivalent
			// (x^-1 == x^(P-2)). Otherwise the shift loop below never reaches 0.
			if (rem < 0) rem += P - 1;
			if (y > 0 && rem == 0)
				y = P - 1;  // keeps 0^(k(P-1)) == 0 for k > 0
			else
				y = rem;
			m32 res(1), x(*this);
			for (; y != 0; y >>= 1, x *= x)
				if (y & 1) res *= x;
			return res;
		}
	};

	using mint = ModInt32<nttmod>;

	/// In-place number-theoretic transform modulo nttmod. a.size() must be a
	/// power of two (asserted).
	/// No bit-reversal pass is performed: the forward transform leaves values
	/// in a permuted order, and the inverse transform (inverse == true, which
	/// also divides by n) consumes that order. Pointwise products between two
	/// forward transforms are therefore valid.
	void ntt(vector<mint>& a, bool inverse) {
		// Twiddle-factor tables, built on the first call (detected by dw[0] == 0).
		static array<mint, 30> dw{}, idw{};
		if (dw[0] == 0) {
			// Smallest root that is a quadratic non-residue modulo nttmod.
			mint root = 2;
			while (power(root, (nttmod - 1) / 2) == 1) root += 1;
			for (int i = 0; i < 30; ++i)
				dw[i] = -power(root, (nttmod - 1) >> (i + 2)),
				idw[i] = 1 / dw[i];
		}
		int n = (int)a.size();
		assert((n & (n - 1)) == 0);
		if (not inverse) {
			for (int m = n; m >>= 1;) {
				mint w = 1;
				for (int s = 0, k = 0; s < n; s += 2 * m) {
					for (int i = s, j = s + m; i < s + m; ++i, ++j) {
						auto x = a[i], y = a[j] * w;
						a[i] = x + y;
						a[j] = x - y;
					}
					w *= dw[__builtin_ctz(++k)];
				}
			}
		} else {
			for (int m = 1; m < n; m *= 2) {
				mint w = 1;
				for (int s = 0, k = 0; s < n; s += 2 * m) {
					for (int i = s, j = s + m; i < s + m; ++i, ++j) {
						auto x = a[i], y = a[j];
						a[i] = x + y;
						a[j] = (x - y) * w;
					}
					w *= idw[__builtin_ctz(++k)];
				}
			}
			auto inv = 1 / mint(n);
			for (auto&& e : a) e *= inv;
		}
	}

	/// Polynomial product modulo nttmod: returns n + m - 1 coefficients, or an
	/// empty vector if either factor is empty. Uses schoolbook multiplication
	/// when either factor has fewer than 30 coefficients. Otherwise it uses an
	/// NTT of size equal to the smallest power of two >= n + m - 1.
	vector<mint> operator*(vector<mint> l, vector<mint> r) {
		if (l.empty() or r.empty()) return {};
		int n = (int)l.size(), m = (int)r.size(),
			sz = 1 << __lg(2 * (n + m - 1) - 1);
		if (min(n, m) < 30) {
			vector<mint> res(n + m - 1);
			for (int i = 0; i < n; ++i)
				for (int j = 0; j < m; ++j) res[i + j] += (l[i] * r[j]);
			return {begin(res), end(res)};
		}
		// Squaring: transform once and reuse it for both factors.
		bool eq = l == r;
		l.resize(sz), ntt(l, false);
		if (eq)
			r = l;
		else
			r.resize(sz), ntt(r, false);
		for (int i = 0; i < sz; ++i) l[i] *= r[i];
		ntt(l, true), l.resize(n + m - 1);
		return l;
	}
	/// l = l * r (same semantics as operator*); returns l.
	vector<mint>& operator*=(vector<mint>& l, vector<mint> r) {
		l = std::move(l) * std::move(r);
		return l;
	}
}  // namespace ntt

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	using mint = ntt::mint;
	const int MAXN = 1e5 + 10;
	ll A, B, C, D, k; cin >> A >> B >> C >> D >> k;
	vector<mint> a(MAXN), b(MAXN), c(MAXN), d(MAXN);
	for (int i = 0; i < A; ++i) {
		int x; cin >> x;
		a[x] += 1;
	}
	for (int i = 0; i < B; ++i) {
		int x; cin >> x;
		b[x] += 1;
	}
	for (int i = 0; i < C; ++i) {
		int x; cin >> x;
		c[x] += 1;
	}
	for (int i = 0; i < D; ++i) {
		int x; cin >> x;
		d[x] += 1;
	}
	auto res1 = a*b, res2 = c*d;
	for (int i = 1; i < res2.size(); ++i) res2[i] += res2[i-1];
	ll lo = 0, hi = 1e11;
	while (lo < hi) {
		ll mid = (lo + hi)/2;
		ll leq = 0;
		for (int i = 0; i < res1.size(); ++i) {
			ll id = res2.size()-1;
			if (i) id = min(id, mid/i);
			// leq += (res1[i] * res2[id]).get();
			leq += 1LL * res1[i].get() * res2[id].get();
		}
		if (leq < k) lo = mid+1;
		else hi = mid;
	}
	cout << lo;
}

Small edit to render the sum properly:

f(X) = \sum_{r = 0}^{2 \cdot 10^5} \sum_{s = 0}^{\lfloor \frac{X}{r} \rfloor}P_r Q_s

Good catch, I’ve fixed it.