GRIDMEET - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Authors: shubham_grg
Testers: iceknight1093, tabr
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Sweep line, and either ternary search with a Fenwick/segment tree, or two pointers

PROBLEM:

N people are on a grid, the i-th at position (x_i, y_i).

Find the smallest total cost required to make K of them meet at a single point, where distance is measured as Manhattan distance.

EXPLANATION:

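First, note that there is an optimal meeting point, say (a, b), such that a is the x-coordinate of one of the people and b is the y-coordinate of one of the people (not necessarily the same person). Indeed, once the K people are chosen, the cost splits into an x-part \sum |x_i - a| and a y-part \sum |y_i - b|, and each part is minimized at a median of the chosen coordinates.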

This gives us N^2 potential meeting points, so we need to compute the best answer for all of them efficiently.

For a fixed meeting point (a, b), computing the answer in \mathcal{O}(N) or \mathcal{O}(N\log N) isn’t too hard: compute the distance of all N points from this one, and take the K smallest of them.
Doing this for all N^2 candidates costs at least \mathcal{O}(N^3) overall though, which is too slow.
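As a baseline, the brute force just described can be sketched as follows. This is a minimal sketch assuming C++17 and points stored as pairs; the names `Point`, `costAt` and `bruteForce` are illustrative and do not appear in the reference solutions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

using Point = std::pair<long long, long long>;

// Cost of gathering the K people closest to (a, b) at (a, b).
long long costAt(const std::vector<Point>& pts, long long a, long long b, int K) {
    assert(K >= 0 && K <= static_cast<int>(pts.size()));
    std::vector<long long> dist;
    dist.reserve(pts.size());
    for (const auto& [x, y] : pts)
        dist.push_back(std::llabs(x - a) + std::llabs(y - b));
    // Partially sort so that the K smallest distances occupy the first K slots.
    std::nth_element(dist.begin(), dist.begin() + K, dist.end());
    return std::accumulate(dist.begin(), dist.begin() + K, 0LL);
}

// Tries every (a, b) with a taken from the x-coordinates and b from the y-coordinates.
long long bruteForce(const std::vector<Point>& pts, int K) {
    long long best = std::numeric_limits<long long>::max();
    for (const auto& px : pts)
        for (const auto& py : pts)
            best = std::min(best, costAt(pts, px.first, py.second, K));
    return best;
}
```

A routine like this is mainly useful as a reference for stress-testing a faster solution on small inputs.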

Let’s fix the value of b and sweep the value of a from small to large.
This has the following effect as a varies:

  • The distance of (x_i, y_i) from (a, b) is |x_i - a| + |y_i - b|.
  • In particular, this is x_i - a + |y_i - b| if x_i \geq a and a-x_i + |y_i-b| otherwise.
    Note that the second part of the sum (|y_i - b|) remains constant as a changes.

So, once b is fixed, each point has a ‘base’ value |y_i - b|, and a ‘dynamic’ value, which is either x_i - a or a - x_i depending on the current value of a.

Let’s say a point is on the left if x_i \lt a, and on the right otherwise.
If we sweep a from low to high, every point starts out on the right, and then the points move to the left one by one.
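The decomposition can be made concrete by storing, for each point, two keys that do not depend on a. The editorialist's code further down stores the same two quantities (as `leftvals` and `rightvals`); the struct and function names in this sketch are hypothetical.

```cpp
#include <cassert>
#include <cstdlib>

// Keys of one point once b is fixed; neither depends on a.
struct SweepKeys {
    long long left;   // |y - b| - x : the distance is left + a while x <= a
    long long right;  // |y - b| + x : the distance is right - a while x >= a
};

SweepKeys makeKeys(long long x, long long y, long long b) {
    const long long base = std::llabs(y - b);
    return {base - x, base + x};
}

// Recovers the Manhattan distance from (x, y) to (a, b) using only the keys.
long long distanceAt(long long x, long long y, long long a, long long b) {
    const SweepKeys k = makeKeys(x, y, b);
    const long long d = (x <= a) ? k.left + a : k.right - a;
    assert(d == std::llabs(x - a) + std::llabs(y - b));  // sanity check of the identity
    return d;
}
```

Since the same a is added to every left key (and subtracted from every right key), the relative order of points within each side never changes during the sweep, so each side can be sorted once per value of b.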

Let L denote the multiset of values (i.e., base value + dynamic value) of the left points, and R denote the same for the right points.
Once a and b are fixed, L and R are uniquely determined, and our aim is then to find k_1 and k_2 such that:

  • k_1 + k_2 = K
  • S_1 + S_2 is minimized, where
    • S_1 is the sum of the k_1 smallest values in L
    • S_2 is the sum of the k_2 smallest values in R
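For a single fixed pair of multisets, the objective above can be evaluated directly by trying every split. The sketch below does so with prefix sums in \mathcal{O}(|L| \log |L| + |R| \log |R| + K); the function name `bestSplitLinear` is an assumption made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Minimum of S_1 + S_2 over all splits k_1 + k_2 = K.
long long bestSplitLinear(std::vector<long long> L, std::vector<long long> R, int K) {
    assert(K >= 0 && K <= static_cast<int>(L.size() + R.size()));
    std::sort(L.begin(), L.end());
    std::sort(R.begin(), R.end());
    // prefL[i] = sum of the i smallest values of L (and likewise for prefR).
    std::vector<long long> prefL(L.size() + 1, 0), prefR(R.size() + 1, 0);
    for (std::size_t i = 0; i < L.size(); ++i) prefL[i + 1] = prefL[i] + L[i];
    for (std::size_t i = 0; i < R.size(); ++i) prefR[i + 1] = prefR[i] + R[i];
    // Only splits with k_1 <= |L| and k_2 <= |R| are feasible.
    const int lo = std::max(0, K - static_cast<int>(R.size()));
    const int hi = std::min(K, static_cast<int>(L.size()));
    long long best = std::numeric_limits<long long>::max();
    for (int k1 = lo; k1 <= hi; ++k1)
        best = std::min(best, prefL[k1] + prefR[K - k1]);
    return best;
}
```

Rebuilding the prefix sums at every sweep position would be too slow, which is why the methods that follow maintain L and R incrementally.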

We have two objectives: maintain L and R appropriately, and find the optimal k_1 and k_2.
There are a couple of ways to do this.

Setter's method (two pointers)

Suppose we know the k_1 and k_2 values for a fixed a and b.
Let’s move the sweep line by one step, say to a'.

This moves a few points from R into L, so let’s do that.
Suppose c points were moved.

The main observation here is that after this movement, the new optimal k_1 value won’t change too much: in fact, it’ll change by at most \mathcal{O}(c).

This allows us to ‘adjust’ the k_1 and k_2 values by maintaining pointers to their current values, and moving them by \mathcal{O}(c) to compute the new optimum.
Concretely, L and R can be maintained as multisets (which stay sorted under insertion and deletion), with k_1 and k_2 kept as iterators into them.

The implementation of this can be found in the setter’s code below.
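The pointer adjustment itself can be illustrated on plain sorted arrays. This is a minimal sketch and not the setter's implementation: `adjustSplit` is a hypothetical helper that takes a starting guess for k_1 (for example, the previous optimum) and walks it to an optimum for the current L and R. It relies on the cost being convex in k_1, so that local moves reach a global optimum, and it assumes C++17 for `std::clamp`.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// L and R must be sorted in ascending order; k1 is a starting guess.
// Returns an optimal k_1 for the current L and R.
int adjustSplit(const std::vector<long long>& L, const std::vector<long long>& R,
                int K, int k1) {
    const int nl = static_cast<int>(L.size());
    const int nr = static_cast<int>(R.size());
    assert(K >= 0 && K <= nl + nr);
    // Clamp the guess into the feasible range [K - |R|, |L|].
    k1 = std::clamp(k1, std::max(0, K - nr), std::min(K, nl));
    int k2 = K - k1;
    // Take one more from L (and one fewer from R) while that lowers the total.
    while (k1 < nl && k2 > 0 && L[k1] < R[k2 - 1]) { ++k1; --k2; }
    // Symmetric move in the other direction.
    while (k1 > 0 && k2 < nr && R[k2] < L[k1 - 1]) { --k1; ++k2; }
    return k1;
}
```

The work done is proportional to how far the pointer has to travel, which is what makes a good starting guess valuable during the sweep.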

Editorialist's method (ternary search + fenwick)

Let f(k_1) denote the cost of taking the smallest k_1 elements from L and the smallest k_2 = K - k_1 elements from R.

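The sum of the k smallest elements of a sorted list is convex in k (each additional element is at least as large as the previous one), so f, being a sum of two such functions, is convex over the feasible values of k_1: it decreases, hits a minimum, then increases again.
This allows us to ternary search on f to find the minimum f(k_1).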
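The search can be sketched on static, already sorted arrays, with prefix sums standing in for the data structure discussed next. The names `argminConvex` and `bestSplitTernary` are assumptions made for this illustration; the loop has the same shape as the one in the editorialist's code, but is restricted here to the feasible range of k_1.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Returns an argument minimizing a convex function f on the integers [lo, hi].
template <class F>
int argminConvex(int lo, int hi, F f) {
    assert(lo <= hi);
    while (hi - lo > 2) {
        const int mid = lo + (hi - lo) / 2;
        if (f(mid) > f(mid + 1)) lo = mid;  // still decreasing: a minimum lies right of mid
        else hi = mid + 1;                  // a minimum lies at or left of mid
    }
    int best = lo;
    for (int i = lo + 1; i <= hi; ++i)
        if (f(i) < f(best)) best = i;
    return best;
}

// L and R must be sorted ascending; returns the minimum of f(k_1) over feasible k_1.
long long bestSplitTernary(const std::vector<long long>& L,
                           const std::vector<long long>& R, int K) {
    const int nl = static_cast<int>(L.size());
    const int nr = static_cast<int>(R.size());
    assert(K >= 0 && K <= nl + nr);
    std::vector<long long> prefL(nl + 1, 0), prefR(nr + 1, 0);
    for (int i = 0; i < nl; ++i) prefL[i + 1] = prefL[i] + L[i];
    for (int i = 0; i < nr; ++i) prefR[i + 1] = prefR[i] + R[i];
    auto f = [&](int k1) { return prefL[k1] + prefR[K - k1]; };
    const int lo = std::max(0, K - nr), hi = std::min(K, nl);
    return f(argminConvex(lo, hi, f));
}
```

Each call evaluates f only \mathcal{O}(\log K) times, so the total cost is dominated by how quickly a single f(k_1) can be computed.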

In particular, evaluating f(k_1) requires us to quickly compute the sum of the smallest k_1 elements in L and smallest k_2 elements in R.

So, we need to maintain L and R in a data structure that supports:

  • Elements can be quickly added/removed from the set
  • Query for the sum of the smallest k_1 elements of the set

This can be accomplished with a Fenwick tree, built on values.

First, precompute the ‘left’ value of every point, |y_i - b| - x_i (adding a to it gives the distance while the point is on the left), and put these values in sorted order.
Build a Fenwick tree on this sorted order, with each element initially zero.
Do the same thing for the ‘right’ values |y_i - b| + x_i (subtracting a gives the distance), but this time populate the Fenwick tree with the appropriate values as well.

Now,

  • Removing an element from the ‘right’ Fenwick tree just means setting the corresponding position to 0
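  • Adding an element to the ‘left’ Fenwick tree means setting the corresponding position to its value
  • Computing the sum of the smallest k_1 present values in a Fenwick tree is also doable in \mathcal{O}(\log N) with binary search on the tree; a second Fenwick tree storing counts (1 for present, 0 for absent) locates the k_1-th present element, since the stored values themselves can be zero or negative.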

This gives us a solution in \mathcal{O}(N^2 \log^2 N), since each position requires \mathcal{O}(\log^2 N) work and there are N^2 positions to consider.
Moving from one position to the next during the sweep is \mathcal{O}(\log N) work which is dwarfed by the ternary search.
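The Fenwick operations used by this method can be sketched as a small structure that pairs a count tree with a sum tree and descends both at once. This is an illustrative sketch, not the editorialist's code: the name `KSmallestSum` and its interface are assumptions, and it expects each rank (position in the presorted order) to hold at most one element.

```cpp
#include <cassert>
#include <vector>

struct KSmallestSum {
    int n;
    int top;                     // largest power of two not exceeding n (1 if n < 2)
    int present = 0;             // number of elements currently stored
    std::vector<int> cnt;        // Fenwick tree of presence counts, 1-indexed
    std::vector<long long> sum;  // Fenwick tree of stored values, 1-indexed

    explicit KSmallestSum(int size) : n(size), top(1), cnt(size + 1, 0), sum(size + 1, 0) {
        while (top * 2 <= n) top *= 2;
    }

    void update(int rank, int dc, long long dv) {
        for (int i = rank + 1; i <= n; i += i & -i) { cnt[i] += dc; sum[i] += dv; }
        present += dc;
    }
    // rank is the 0-based position of the value in the presorted order.
    void insert(int rank, long long value) { update(rank, 1, value); }
    void erase(int rank, long long value) { update(rank, -1, -value); }

    // Sum of the k smallest stored values, in O(log n).
    long long sumOfSmallest(int k) const {
        assert(k >= 0 && k <= present);
        int pos = 0;
        long long acc = 0;
        // Binary descent: extend the prefix while it holds at most k elements.
        for (int pw = top; pw > 0; pw >>= 1) {
            if (pos + pw <= n && cnt[pos + pw] <= k) {
                pos += pw;
                k -= cnt[pos];
                acc += sum[pos];
            }
        }
        return acc;
    }
};
```

With one such structure for the left side and one for the right, moving a point across the sweep line is an `erase` on one and an `insert` on the other, and f(k_1) is the sum of two `sumOfSmallest` queries plus the correction term in a.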

TIME COMPLEXITY:

\mathcal{O}(N^2 \log N) (two pointers) or \mathcal{O}(N^2\log^2 N) (ternary search with Fenwick trees) per test case.

CODE:

Setter's code (C++)
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll inf=1e16;

#ifdef ANI
#include "D:/DUSTBIN/local_inc.h"
#else
#define dbg(...) 0
#endif

void solve(ll &tot) {
	ll n,k;
	cin>>n>>k;
	ll nax=1e9;
	assert(n<=1000 && k<=n);
	tot+=n;
	map<ll,vector<ll>> xp,yp;
	vector<vector<ll>> a(n,vector<ll>(2));
	for(int i=0;i<n;i++) {
		ll x,y;
		cin>>x>>y;
		xp[x].push_back(i);
		yp[y].push_back(i);
		a[i]={x,y};
		assert(abs(x)<=nax && abs(y)<=nax);
	}
	vector<ll> xc;
	for(auto el:xp) xc.push_back(el.first);
	ll ans=inf;
	for(int i=0;i<xc.size();i++) {
		set<pair<ll,ll>> uu,dd,uh; vector<bool> dh(n,0);
		ll xx=xc[i],yy=yp.begin()->first;
		for(auto el:yp) {
			auto pts=el.second;
			ll yi=el.first;
			for(ll ii:pts) {
				ll xi=a[ii][0];
				dd.insert({abs(xx-xi)+abs(yy-yi),ii});
			}
		}
		ll cur=0,mv=0,dct=0,uct=0; // mv: how much we have moved down
		for(auto it=yp.begin();it!=yp.end();it++) {
			while(!uh.empty()) {
				auto it=uh.end(),jt=dd.begin();
				it--;
				if(jt==dd.end() or jt->first-mv>it->first+mv) break;
				cur-=it->first+mv;
				uu.insert(*it);
				uh.erase(it);
			}
			while(dct+uh.size() < k) {
				auto it=dd.begin(),jt=uu.begin();
				if(it==dd.end()) {
					cur+=jt->first+mv;
					uh.insert(*jt);
					uu.erase(jt);
				} else if(jt==uu.end()) {
					cur+=it->first-mv;
					dh[it->second]=1; dct++;
					dd.erase(it);
				} else {
					ll gd=it->first-mv,gu=jt->first+mv;
					if(gd<gu) {
						cur+=it->first-mv;
						dh[it->second]=1; dct++;
						dd.erase(it);
					} else {
						cur+=jt->first+mv;
						uh.insert(*jt);
						uu.erase(jt);
					}
				}
			}
			ans=min(ans,cur);
			for(ll ii:it->second) {
				ll xi=a[ii][0],cost=abs(xx-xi)+abs(yy-it->first);
				uu.insert({cost-2*mv,ii});
				if(dh[ii]) {
					dh[ii]=0;
					dct--;
					cur-=cost-mv;
				} else dd.erase({cost,ii});
			}
			auto jt=it; jt++;
			ll delta=0;
			if(jt!=yp.end()) {
				delta=jt->first-it->first;
				mv+=delta;
			}
			cur+=((ll)uh.size()-dct)*delta; // explicit cast: avoid mixing unsigned and signed arithmetic
		}
	}
	cout<<ans<<"\n";
}

int main() {
	ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
	int t;
	cin>>t;
	assert(t<=1000);
	ll tot=0;
	while(t--) {
		solve(tot);
	}	
	assert(tot<=1000);
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#ifdef tabr
#include "library/debug.cpp"
#else
#define debug(...)
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        string res;
        while (pos < (int) buffer.size() && buffer[pos] != ' ' && buffer[pos] != '\n') {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int min_len, int max_len, const string& pattern = "") {
        assert(min_len <= max_len);
        string res = readOne();
        assert(min_len <= (int) res.size());
        assert((int) res.size() <= max_len);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int min_val, int max_val) {
        assert(min_val <= max_val);
        int res = stoi(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    long long readLong(long long min_val, long long max_val) {
        assert(min_val <= max_val);
        long long res = stoll(readOne());
        assert(min_val <= res);
        assert(res <= max_val);
        return res;
    }

    vector<int> readInts(int size, int min_val, int max_val) {
        assert(min_val <= max_val);
        vector<int> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readInt(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    vector<long long> readLongs(int size, long long min_val, long long max_val) {
        assert(min_val <= max_val);
        vector<long long> res(size);
        for (int i = 0; i < size; i++) {
            res[i] = readLong(min_val, max_val);
            if (i != size - 1) {
                readSpace();
            }
        }
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

template <typename T>
struct fenwick {
    int n;
    vector<T> node;

    fenwick(int _n) : n(_n) {
        node.resize(n);
    }

    void add(int x, T v) {
        while (x < n) {
            node[x] += v;
            x |= (x + 1);
        }
    }

    T get(int x) {  // [0, x]
        T v = 0;
        while (x >= 0) {
            v += node[x];
            x = (x & (x + 1)) - 1;
        }
        return v;
    }

    T get(int x, int y) {  // [x, y]
        return (get(y) - (x ? get(x - 1) : 0));
    }

    int lower_bound(T v) {
        int x = 0;
        int h = 1;
        while (n >= (h << 1)) {
            h <<= 1;
        }
        for (int k = h; k > 0; k >>= 1) {
            if (x + k <= n && node[x + k - 1] < v) {
                v -= node[x + k - 1];
                x += k;
            }
        }
        return x;
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    input_checker in;
    int tt = in.readInt(1, 1000);
    in.readEoln();
    int sn = 0;
    while (tt--) {
        int n = in.readInt(1, 1000);
        in.readSpace();
        int k = in.readInt(1, n);
        in.readEoln();
        sn += n;
        vector<long long> x(n), y(n);
        for (int i = 0; i < n; i++) {
            x[i] = in.readInt(-1e9, 1e9);
            in.readSpace();
            y[i] = in.readInt(-1e9, 1e9);
            in.readEoln();
        }
        long long ans = 9e18;
        for (auto xx : x) {
            vector<pair<long long, long long>> p(n);
            for (int i = 0; i < n; i++) {
                p[i].first = y[i];
                p[i].second = abs(x[i] - xx);
            }
            sort(p.begin(), p.end());
            vector<long long> a(n), b(n);
            for (int i = 0; i < n; i++) {
                a[i] = p[i].second - p[i].first;
                b[i] = p[i].second + p[i].first;
            }
            vector<int> oa(n);
            iota(oa.begin(), oa.end(), 0);
            sort(oa.begin(), oa.end(), [&](int i, int j) {
                return a[i] < a[j];
            });
            vector<int> pa(n);
            for (int i = 0; i < n; i++) {
                pa[oa[i]] = i;
            }
            vector<int> ob(n);
            iota(ob.begin(), ob.end(), 0);
            sort(ob.begin(), ob.end(), [&](int i, int j) {
                return b[i] < b[j];
            });
            vector<int> pb(n);
            for (int i = 0; i < n; i++) {
                pb[ob[i]] = i;
            }
            fenwick<int> ca(n), cb(n);
            fenwick<long long> sa(n), sb(n);
            for (int i = 0; i < n; i++) {
                cb.add(pb[i], 1);
                sb.add(pb[i], b[i]);
            }
            auto ta = a, tb = b;
            sort(ta.begin(), ta.end());
            sort(tb.begin(), tb.end());
            for (int i = 0; i < n; i++) {
                auto GetC = [&](long long t) {
                    int res = 0;
                    {
                        int j = (int) (upper_bound(ta.begin(), ta.end(), t - p[i].first) - ta.begin());
                        if (j > 0) {
                            res += ca.get(j - 1);
                        }
                    }
                    {
                        int j = (int) (upper_bound(tb.begin(), tb.end(), t + p[i].first) - tb.begin());
                        if (j > 0) {
                            res += cb.get(j - 1);
                        }
                    }
                    return res;
                };
                auto GetS = [&](long long t) {
                    long long res = 0;
                    {
                        int j = (int) (upper_bound(ta.begin(), ta.end(), t - p[i].first) - ta.begin());
                        if (j > 0) {
                            res += sa.get(j - 1);
                            res += ca.get(j - 1) * p[i].first;
                        }
                    }
                    {
                        int j = (int) (upper_bound(tb.begin(), tb.end(), t + p[i].first) - tb.begin());
                        if (j > 0) {
                            res += sb.get(j - 1);
                            res -= cb.get(j - 1) * p[i].first;
                        }
                    }
                    return res;
                };
                cb.add(pb[i], -1);
                sb.add(pb[i], -b[i]);
                ca.add(pa[i], 1);
                sa.add(pa[i], a[i]);
                long long low = -1, high = 3e9;
                while (high - low > 1) {
                    long long mid = (high + low) >> 1;
                    if (GetC(mid) >= k) {
                        high = mid;
                    } else {
                        low = mid;
                    }
                }
                ans = min(ans, GetS(high) - (GetC(high) - k) * high);
            }
        }
        cout << ans << '\n';
    }
    assert(sn <= 1000);
    in.readEof();
    return 0;
}
Editorialist's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

struct FT {
	vector<ll> s;
	FT(int n) : s(n) {}
	void update(int pos, ll dif) { // a[pos] += dif
		for (; pos < (int)size(s); pos |= pos + 1) s[pos] += dif;
	}
	ll query(int pos) { // sum of values in [0, pos)
		ll res = 0;
		for (; pos > 0; pos &= pos - 1) res += s[pos-1];
		return res;
	}
	int lower_bound(ll sum) {// min pos st sum of [0, pos] >= sum
		// Returns n if no sum is >= sum, or -1 if empty sum is.
		if (sum <= 0) return -1;
		int pos = 0;
		for (int pw = 1 << 12; pw; pw >>= 1) {
			if (pos + pw <= (int)size(s) && s[pos + pw-1] < sum)
				pos += pw, sum -= s[pos-1];
		}
		return pos;
	}
	void reset() {
		s.assign(s.size(), 0);
	}
};

int main()
{
	ios::sync_with_stdio(false); cin.tie(0);

	int t; cin >> t;
	while (t--) {
		int n, k; cin >> n >> k;
		vector<array<ll, 2>> pts(n);
		set<int> ycoords;
		for (auto &[x, y] : pts) {
			cin >> x >> y;
			ycoords.insert(y);
		}
		sort(begin(pts), end(pts));

		ll ans = 1e18;
		FT lct(n), rct(n), lsum(n), rsum(n);
		for (auto Y : ycoords) {
			lct.reset(), rct.reset(), lsum.reset(), rsum.reset();

			vector<array<ll, 2>> leftvals, rightvals;
			for (int i = 0; i < n; ++i) {
				auto &[x, y] = pts[i];
				leftvals.push_back({abs(Y - y) - x, i});
				rightvals.push_back({abs(Y - y) + x, i});
			}
			sort(begin(leftvals), end(leftvals));
			sort(begin(rightvals), end(rightvals));

			vector<int> leftpos(n), rightpos(n);
			for (int i = 0; i < n; ++i) {
				leftpos[leftvals[i][1]] = i;
				rightpos[rightvals[i][1]] = i;
			}

			for (int i = 0; i < n; ++i) {
				rct.update(i, 1);
				rsum.update(i, rightvals[i][0]);
			}

			for (int i = 0; i < n; ++i) {
				int u = rightpos[i], v = leftpos[i];
				auto &[x, y] = pts[i];
				ll rval = abs(Y - y) + x, lval = abs(Y - y) - x;

				rct.update(u, -1); rsum.update(u, -rval);
				lct.update(v, 1); lsum.update(v, lval);

				if (i > 0 and x == pts[i-1][0]) continue;

				{
					auto f = [&] (int lt) {
						int rt = k - lt;
						int L = lct.lower_bound(lt), R = rct.lower_bound(rt);
						if (L == n or R == n) return (ll) 1e18;
						return lsum.query(L+1) + rsum.query(R+1) + 1LL*x*(lt - rt);
					};
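					// Clamp to the feasible range so the search never compares two infeasible (1e18) values.
					int lo = max(0, k - (n - i - 1)), hi = min(k, i + 1);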
					while (hi - lo > 2) {
						int mid = (lo + hi)/2;
						if (f(mid) > f(mid + 1)) lo = mid;
						else hi = mid+1;
					}
					for (int i = lo; i <= hi; ++i) ans = min(ans, f(i));
				}
			}
		}
		cout << ans << '\n';
	}
}