MAXREV - Editorial

Maximize the Revenue - EDITORIAL:



Author: karthikeya619

Editorialist: karthikeya619




Graphs, MST, Greedy, Binary Search, Dynamic Programming, Alien Trick


Given a graph G with N nodes and M undirected edges, where edge i (1 \leq i \leq M) carries two weights C_i and R_i and all C_i's are distinct. First find the set of edges that keeps all the nodes reachable while minimizing \Sigma C_i over those edges, i.e. the minimum spanning tree. Then find a matching of size K in this tree that maximizes \Sigma R_i over the edges selected in the matching.


Find the MST of the given graph using the C_i's. Since all C_i's are distinct, the graph has only one MST. Without the constraint on the size of the matching, the maximum-weight matching of a tree can be found with a simple DFS. Now add an extra weight \lambda to every edge. At \lambda = 0 the DFS returns the unconstrained maximum-weight matching, and as \lambda decreases towards -\inf fewer edges are worth taking, until the best matching is empty (size 0). Binary search on \lambda to find the maximum-weight matching of size K.


The first part of the problem is straightforward: find the MST of the given graph using the C_i's as the edge weights. Since the C_i's are distinct, there is exactly one MST for the given graph. The second part of the question is to find the heaviest matching of size K in this tree.
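The MST step can be sketched as follows. This is a minimal illustration rather than the setter's code: the function name `kruskal_mst` is reused for convenience, and the tuple layout `(c, r, u, v)` with 0-indexed nodes is an assumption made for this example.

```python
def kruskal_mst(n, edges):
    """Return the MST edges as (u, v, r) tuples.

    edges: list of (c, r, u, v) with 0-indexed endpoints (assumed layout).
    The tree is chosen by c only; r is carried along for the matching step.
    """
    parent = list(range(n))

    def find(x):
        # Union-find lookup with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    tree = []
    for c, r, u, v in sorted(edges):  # ascending by c, which is unique
        ru, rv = find(u), find(v)
        if ru != rv:                  # edge joins two components: keep it
            parent[rv] = ru
            tree.append((u, v, r))
    return tree
```

Because the C_i's are distinct, the sort order is fully determined by `c`, so the result does not depend on how ties would be broken.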

A matching is a subset of the edges of a graph such that no two edges are adjacent (i.e. no two edges in the set share a vertex). Assume for now that there is no constraint on the size, so the task is just to find the heaviest matching in the given tree.

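This problem can be solved by a DFS. Root the tree at any node and, for each node u, find the heaviest matching in its subtree in two cases: u is left unmatched, and u is allowed to be matched to one of its children. When u is unmatched, the value is the sum of the best values of its children. When u is matched to a child v over an edge of weight w, v's best value is replaced by v's unmatched value plus w, so we pick the child for which this gain is largest (or no child at all if no gain is positive). The values of every node are thus computed from the results of its children.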
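A small sketch of this tree DP, with no size constraint. The names `max_weight_matching`, `free` and `best` are made up for this example, and the traversal is iterative (instead of a recursive DFS) only to avoid Python's recursion limit on deep trees.

```python
def max_weight_matching(n, tree_edges, root=0):
    """Weight of the heaviest matching in a tree.

    tree_edges: list of (u, v, w) with 0-indexed endpoints (assumed layout).
    """
    adj = [[] for _ in range(n)]
    for u, v, w in tree_edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    # Iterative DFS: every parent appears before its children in `order`.
    parent = [-1] * n
    seen = [False] * n
    seen[root] = True
    order, stack = [], [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v, _ in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    free = [0] * n  # best value in u's subtree with u left unmatched
    best = [0] * n  # best value in u's subtree, u matched to a child or not
    for u in reversed(order):  # children are processed before parents
        total, gain = 0, 0
        for v, w in adj[u]:
            if v != parent[u]:
                total += best[v]
                # Matching u-v swaps best[v] for free[v] and adds w.
                gain = max(gain, free[v] - best[v] + w)
        free[u] = total
        best[u] = total + gain
    return best[root]
```

For the path 0-1-2-3 with weights 1, 100, 1 this returns 100: taking the middle edge alone beats taking the two outer edges.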

Now let us suppose we add a constant \lambda to the weight of each edge and solve the subproblem
discussed earlier which doesn’t have any constraint on the size of matching.

At \lambda = 0 this is just the unconstrained maximum-weight matching, and when \lambda is very low (-\inf) every edge has negative weight, so the best matching is empty and its size is 0. The size of the optimal matching never decreases as we increase \lambda, hence we can binary search on \lambda to find the heaviest matching of size K. The solutions below work with the negated value: they subtract a penalty between 0 and 10^6 from every edge, which corresponds to \lambda between -10^6 and 0.
This is inspired by the trick from the Aliens problem of IOI 2016. The main thing to check before applying this trick is that the total weight of the best matching is concave w.r.t. the size of the matching. The proof of concavity is left as an exercise to the reader :blush:.
Hint: Model it as a min-cost max-flow problem by negating all the costs.
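The way the final answer is read off after the binary search can be written out as below. The symbols f, g and p are introduced here for illustration only and do not appear in the original text; the last line is meant to mirror the expression `dp[1][0].first + s * k` printed by the setter's code, where the code's `s` plays the role of p.

```latex
f(j) = \max \Bigl\{ \sum_{i \in S} R_i \;:\; S \text{ is a matching with } |S| = j \Bigr\},
\qquad
g(\lambda) = \max_{j} \bigl( f(j) + \lambda j \bigr)

% If f is concave and the maximum defining g(\lambda) is attained at j = K:
f(K) = g(\lambda) - \lambda K

% With a penalty p = -\lambda \ge 0 subtracted from every edge:
f(K) = g(-p) + pK
```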

Please refer to this link for more details on the Alien Trick.

Alien Trick
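Putting the penalty and the binary search together on a tree might look like the sketch below. It follows the same idea as the solutions in this post (pairs of weight and edge count, ties broken towards more edges, largest penalty that still yields at least K edges), but the function name `max_weight_k_matching`, the `(u, v, w)` edge layout, and the choice to return `None` when the unpenalized optimum has fewer than K edges are assumptions of this example.

```python
def max_weight_k_matching(n, tree_edges, k):
    """Heaviest matching with exactly k edges in a tree, via a penalty search.

    tree_edges: list of (u, v, w), 0-indexed endpoints, integer w >= 1.
    Returns None when the unpenalized optimum already has fewer than k edges.
    """
    adj = [[] for _ in range(n)]
    for u, v, w in tree_edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    # Iterative DFS from node 0: parents come before children in `order`.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order, stack = [], [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v, _ in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    def solve(penalty):
        # Values are (total weight, edge count) pairs; tuple comparison
        # breaks weight ties in favour of more edges.
        free = [(0, 0)] * n
        best = [(0, 0)] * n
        for u in reversed(order):
            tw = tc = 0
            gain = (0, 0)
            for v, w in adj[u]:
                if v != parent[u]:
                    tw += best[v][0]
                    tc += best[v][1]
                    cand = (free[v][0] - best[v][0] + w - penalty,
                            free[v][1] - best[v][1] + 1)
                    gain = max(gain, cand)
            free[u] = (tw, tc)
            best[u] = (tw + gain[0], tc + gain[1])
        return best[0]

    if solve(0)[1] < k:
        return None
    lo, hi = 0, max((w for _, _, w in tree_edges), default=0)
    while lo < hi:  # largest penalty that still keeps at least k edges
        mid = lo + (hi - lo + 1) // 2
        if solve(mid)[1] >= k:
            lo = mid
        else:
            hi = mid - 1
    # Add back the penalty paid by the k chosen edges.
    return solve(lo)[0] + lo * k
```

For the path 0-1-2-3 with weights 5, 1, 5, `max_weight_k_matching(4, [(0, 1, 5), (1, 2, 1), (2, 3, 5)], 2)` returns 10 and the same call with `k = 1` returns 5.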


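Kruskal's algorithm takes O(M log M), because the edges have to be sorted first (the union-find operations on top of that are near-linear). Each DFS inside the binary search is O(N), so the search takes O(N log(max(R_i))). So the overall complexity of this solution is
O(M log M + N log(max(R_i)))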


Setter's Solution
#include <bits/stdc++.h>
using namespace std;
using lint = long long;
using pi = pair<lint, int>;
using vi = vector<int>;
using Wedges = vector<vector<int>>;
#define FOR(i,k,m) for(lint i=k; i<m ;i++)

const int MAXN = 500005;
int n, m, k;
map<pi, int> new_weights;
vector<pi> gph[MAXN];
pi dp[MAXN][2];

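class dsu{
    vi p;
    int n;
    int cmp;
public:
    // nodes are 1-indexed, so allocate n + 1 slots
    dsu(int _n): p(_n + 1), n(_n), cmp(_n){
        iota(p.begin(), p.end(), 0);
    }
    // representative of x, with path compression
    inline int find(int x){
        return (x==p[x]? x : (p[x]=find(p[x])));
    }
    // merge the components of x and y; returns 1 if they were different
    inline int unite(int x, int y){
        x = find(x);
        y = find(y);
        if (x!=y){
            p[y] = x;
            cmp--;
            return 1;
        }
        return 0;
    }
};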

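// edg[i] = {C weight, u, v}; returns the MST edges as (u, v) pairs
vector<pi> kruskal_mst(Wedges &edg, int n, int m){
    lint ans=0;
    dsu d(n);
    int u,v;
    lint w;
    vector<pi> mst;
    sort(edg.begin(), edg.end());
    FOR(i,0,m){
        u = edg[i][1]; v = edg[i][2]; w = edg[i][0];
        if (d.unite(u,v)){
            ans += w;
            mst.push_back(pi(u, v));
        }
    }
    return mst;
}
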
// dp[x][1]: best (weight, count) in x's subtree with x unmatched
// dp[x][0]: best (weight, count) in x's subtree, x matched to a child or not
// m is the penalty subtracted from every edge
void dfs(int x, int p, lint m){
    pi one_up(0, 0), sum(0, 0);
    for(auto &i : gph[x]){
        int v, c; tie(c, v) = i;
        if(v == p) continue;
        dfs(v, x, m);
        sum.first += dp[v][0].first;
        sum.second += dp[v][0].second;
        one_up = max(one_up, pi(dp[v][1].first - dp[v][0].first + c - m, dp[v][1].second - dp[v][0].second + 1));
    }
    dp[x][1] = sum;
    dp[x][0] = pi(sum.first + one_up.first, sum.second + one_up.second);
}

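int main(){
    scanf("%d %d %d",&n, &m ,&k);
    Wedges graph; 
    for(int i=0; i<m; i++){
        int u, v, w1, w2;
        scanf("%d %d %d %d",&u,&v,&w1,&w2);
        graph.push_back({w1, u, v});
        new_weights[pi(u, v)] = w2;
    }

    vector<pi> mst = kruskal_mst(graph,n,m);
    for(auto i: mst){
        int u = i.first; int v = i.second;
        int w = new_weights[pi(u, v)];
        gph[u].push_back(pi(w, v));
        gph[v].push_back(pi(w, u));
    }

    dfs(1, 0, 0);

    if(dp[1][0].second < k){
        return 0;
    }

    lint s = 0, e = 1e6;

    // largest penalty for which the optimum still uses at least k edges
    while(s != e){
        lint m = s + (e - s + 1) / 2;
        dfs(1, 0, m);
        if(dp[1][0].second >= k) s = m;
        else e = m - 1;
    }

    dfs(1, 0, s);

    // add back the penalty paid by the k selected edges
    cout << dp[1][0].first + s * k << endl;
}
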
Tester's Solution
import sys
input = sys.stdin.readline
from sys import stdin, stdout
from collections import defaultdict, Counter
M = 10**9+7

class dsu:

    def __init__(self,n):
        self.n = n
        self.p = [i for i in range(n)]

    def find(self,x):
        if x==self.p[x]:
            return x
        else:
            self.p[x] = self.find(self.p[x])
            return self.p[x]

    def unite(self,x,y):
        x = self.find(x)
        y = self.find(y)
        if x!=y:
            self.p[y] = x
            return 1
        return 0

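def dfs(adj,u,par,dp,lambd):
    # dp[u][1]: best [weight, count] with u unmatched
    # dp[u][0]: best [weight, count] with u matched to a child or not
    su = [0,0]
    edg = (0,0)
    for v,w in adj[u]:
        if v!=par:
            dfs(adj,v,u,dp,lambd)
            su[0] += dp[v][0][0]
            su[1] += dp[v][0][1]
            edg = max(edg, (dp[v][1][0]-dp[v][0][0] + w - lambd, dp[v][1][1]-dp[v][0][1]+1))

    dp[u][1] = su
    dp[u][0] = [edg[0]+su[0],su[1]+edg[1]]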

def main():

    n,m,k = [int(s) for s in input().split()]
    assert 1<=n and n<=5*10**5
    assert 1<=m and m<=5*10**5
    assert 1<=k and k<=n-1
    edges = []

    for i in range(m):
        u,v,w1,w2 = [int(s) for s in input().split()]
        assert 1<=w1 and w1<=10**6
        assert 1<=w2 and w2<=10**6
        edges.append((w1,u-1,v-1,w2))

    edges.sort()
    adj = [[] for i in range(n)]
    d = dsu(n)

    for edg in edges:
        if d.unite(edg[1],edg[2]):
            adj[edg[1]].append((edg[2],edg[3]))
            adj[edg[2]].append((edg[1],edg[3]))
    lo = 0
    hi = 10**6
    dp = [[[0,0],[0,0]] for i in range(n)]
    dfs(adj,0,-1,dp,0)
    if(dp[0][0][1] <k):
        return
    while lo < hi:
        mid = lo + (hi-lo+1)//2
        dfs(adj,0,-1,dp,mid)
        if dp[0][0][1] >= k:
            lo = mid
        else:
            hi = mid-1
    dfs(adj,0,-1,dp,lo)
    print(dp[0][0][0] + lo*k)

if __name__== '__main__':
    main()