CTC - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: shubham_grg
Tester: mexomerf
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

(Multi)sets or prefix maximums

PROBLEM:

You are given N chains, each of length M.
The j-th vertex of the i-th chain has value A_{i, j}.
Find the minimum possible diameter attainable when joining these chains together to form a tree.

EXPLANATION:

First, we introduce some notation.

  • (i, j) denotes the j-th vertex of the i-th chain.
  • L_{i, j} denotes the sum A_{i, 1} + A_{i, 2} + \ldots + A_{i, j}, that is, the sum of elements to the left of (and including) (i, j).
  • R_{i, j} denotes the sum A_{i, j} + A_{i, j+1} + \ldots + A_{i, M}, that is, the sum of elements to the right of (and including) (i, j).
  • M_{i, j} = \max(L_{i, j}, R_{i, j}).
    This is the longest path within this chain that starts at (i, j).
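To make these definitions concrete, here is a minimal sketch that computes every M_{i, j} of a single chain in one pass, using 0-indexed positions. The helper names `chainM` and `chainMin` are hypothetical and do not come from the codes below. The minimum over the chain is the quantity later called \text{mn}_i. The author's code stores it as `cost[i]` and the editorialist's code as `best[i]`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: for one chain a[0..m-1], returns M[j] = max(L_j, R_j),
// where L_j = a[0] + ... + a[j] and R_j = a[j] + ... + a[m-1] (0-indexed).
std::vector<long long> chainM(const std::vector<long long>& a) {
    long long total = 0;
    for (long long x : a) total += x;

    std::vector<long long> M(a.size());
    long long left = 0;
    for (std::size_t j = 0; j < a.size(); ++j) {
        left += a[j];                           // left == L_j
        long long right = total - left + a[j];  // right == R_j (includes a[j])
        M[j] = std::max(left, right);
    }
    return M;
}

// Hypothetical helper: mn for one chain, the smallest M value (assumes m >= 1).
long long chainMin(const std::vector<long long>& a) {
    std::vector<long long> M = chainM(a);
    return *std::min_element(M.begin(), M.end());
}
```

For the chain [3, 1, 2] this gives M values 6, 4, 6, so \text{mn} = 4.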

To start off, there are far too many ways to join the chains into a tree to try them all.
Instead, we need to make some observations about the structure of the optimal solution.

Intuitively, it seems like a good idea to avoid having too many chains joined together in, well, a chain.
That is, it doesn’t look great if we have chain 1 connected to chain 2, chain 2 connected to chain 3, 3 connected to 4, and so on.

This intuition can be formalized as follows:
Claim: There exists an optimal solution that’s “star-like”.
Formally, there exists some vertex (i, j) such that every chain k \neq i is connected by an edge directly to (i, j).

Proof

If N = 1 or N = 2 the statement is trivially true, so we work with N\geq 3.

Consider an arbitrary way of adding edges to form a tree. We’ll convert this to a star without making the answer worse.

Recall the quantity M_{i, j} defined at the start.
Let (i_1, j_1) and (i_2, j_2) be two vertices with the maximum M_{i, j} values, across all vertices that have edges connecting them to another chain.
Note that we require i_1 \neq i_2 here: (i_1, j_1) is the overall maximum, and (i_2, j_2) is the maximum among vertices of the other chains.
So, M_{i_1, j_1} \geq M_{i_2, j_2} henceforth.

Case 1: (i_1, j_1) and (i_2, j_2) are not directly connected with an edge.
Since we have a tree, there exists a path from (i_1, j_1) to (i_2, j_2). This path also includes at least one other vertex.
Choose any such vertex on this path, say (i_3, j_3).

Consider the following setup.
“Root” the tree at chain i_3.
Then, for every existing edge between two chains, detach it from its endpoint in the parent chain and attach it to (i_3, j_3) instead.

It can be seen that this doesn’t make the maximum path length worse:

  • Entire chains still exist as they were.
  • For paths between two chains other than i_3, the path must pass through (i_3, j_3).
    The maximum possible length is M_{i_1, j_1} + M_{i_2, j_2} + A_{i_3, j_3}, which is not worse than what we had earlier.
  • For paths that start within chain i_3 and end in a different chain, the maximum possible cost is M_{i_1, j_1} + M_{i_3, j_3}.
    Such a path (or worse) also existed in the original tree, so once again we aren’t worse.

So, connecting everything to (i_3, j_3) doesn’t make things worse.

Case 2: (i_1, j_1) and (i_2, j_2) are directly connected with an edge.
In this case, it can be shown that reattaching every edge to one of (i_1, j_1) or (i_2, j_2) will not make things worse (just as in the previous case, you can analyze all three types of paths separately).

This proves that it’s always possible to convert any configuration into a ‘star’ without making the answer worse, so it’s enough to look at stars to find the optimal answer.
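The claim can also be sanity-checked by brute force on tiny inputs. The sketch below is limited to N = 3 and uses hypothetical helpers `treeDiameter` and `bruteThreeChains`. With three chains, contracting each chain to a single node always gives a path with some chain b in the middle, so every tree is described by the two extra edges (a, x)-(b, y_1) and (b, y_2)-(c, z). The tree is star-like exactly when y_1 = y_2. The checker returns the best diameter over all trees and over star-like trees only. This is an empirical check on the inputs you give it, not a replacement for the proof.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <utility>
#include <vector>
using namespace std;
using ll = long long;

// Heaviest vertex-weighted path in a tree: DFS from every vertex (O(V^2), tiny inputs only).
ll treeDiameter(const vector<vector<int>>& adj, const vector<ll>& w) {
    ll best = 0;
    function<void(int, int, ll)> dfs = [&](int u, int parent, ll acc) {
        acc += w[u];
        best = max(best, acc);
        for (int v : adj[u])
            if (v != parent) dfs(v, u, acc);
    };
    for (int s = 0; s < (int)adj.size(); ++s) dfs(s, -1, 0);
    return best;
}

// Hypothetical checker for exactly three chains of equal length m.
// Returns {best diameter over all trees, best diameter over star-like trees}.
pair<ll, ll> bruteThreeChains(const vector<vector<ll>>& A) {
    int m = (int)A[0].size();
    auto id = [m](int i, int j) { return i * m + j; };
    ll bestAll = LLONG_MAX, bestStar = LLONG_MAX;
    for (int b = 0; b < 3; ++b) {  // b is the middle chain
        int a = (b + 1) % 3, c = (b + 2) % 3;
        // Extra edges: (a, x)-(b, y1) and (b, y2)-(c, z).
        for (int x = 0; x < m; ++x)
        for (int y1 = 0; y1 < m; ++y1)
        for (int y2 = 0; y2 < m; ++y2)
        for (int z = 0; z < m; ++z) {
            vector<vector<int>> adj(3 * m);
            vector<ll> w(3 * m);
            for (int i = 0; i < 3; ++i)
                for (int j = 0; j < m; ++j) {
                    w[id(i, j)] = A[i][j];
                    if (j + 1 < m) {  // edge inside chain i
                        adj[id(i, j)].push_back(id(i, j + 1));
                        adj[id(i, j + 1)].push_back(id(i, j));
                    }
                }
            adj[id(a, x)].push_back(id(b, y1));
            adj[id(b, y1)].push_back(id(a, x));
            adj[id(b, y2)].push_back(id(c, z));
            adj[id(c, z)].push_back(id(b, y2));
            ll d = treeDiameter(adj, w);
            bestAll = min(bestAll, d);
            if (y1 == y2) bestStar = min(bestStar, d);  // both chains hang off (b, y1)
        }
    }
    return {bestAll, bestStar};
}
```

For example, for the chains [1, 2], [3, 4], [5, 6] both values come out to 18.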


With this in hand, let’s try to find the answer.
Let’s fix the vertex (i, j) that’s the center of the star.

Then, since every other chain k will be connected to (i, j), any path in the resulting tree will be one of these types:

  1. Entirely within some chain.
    The maximum here is simply the largest chain sum. This doesn’t depend on (i, j).
  2. Starting in chain i and ending in chain k for some k \neq i.
    • Suppose (i, j) is connected to k at (k, x).
      The longest path of this form is then M_{i, j} + M_{k, x}.
    • The first term is constant since i and j are fixed; only the second term can vary.
      Clearly, it’s best to choose whichever x minimizes M_{k, x}.
    • The M_{k,x} values can all be precomputed, and we can store \text{mn}_k to be the minimum within chain k.
      Then, we’re looking for the maximum of \text{mn}_k across all k \neq i.
    • To find the maximum of all elements except one, either maintain prefix/suffix maximums, or use a multiset: delete \text{mn}_i, find the maximum of the remaining elements, and then insert \text{mn}_i back.
  3. Starting in chain k_1 and ending in chain k_2 for some distinct k_1, k_2 \neq i.
    • Suppose chains k_1 and k_2 are connected to (i, j) at vertices (k_1, x_1) and (k_2, x_2) respectively.
      The maximum cost of such a path is then M_{k_1, x_1} + M_{k_2, x_2} + A_{i, j}.
    • Once again, it’s clearly optimal to choose whichever x_1 and x_2 independently minimize these two maximums; so we just want the two maximum values of the \text{mn} array, excluding \text{mn}_i.
      This is essentially the same as the previous case, just that we need to find two maximums now rather than just one.
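Combining the three path types, the diameter of the star centered at (i, j) is the largest of the three bounds. Here is a minimal sketch, assuming the caller has already computed the largest chain sum, M_{i, j}, and the top two \text{mn} values over chains other than i. The function name `starDiameter` and its parameters are hypothetical. The answer is then the minimum of this value over all N \cdot M centers.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: diameter of the star centered at (i, j).
//   maxChainSum : largest total over all chains (type 1 paths)
//   Mij, Aij    : M_{i,j} and A_{i,j} for the center vertex
//   mx1, mx2    : largest and second-largest mn_k over chains k != i
//   hasSecond   : false when N = 2, since then no type 3 path exists
long long starDiameter(long long maxChainSum, long long Mij, long long Aij,
                       long long mx1, long long mx2, bool hasSecond) {
    long long d = std::max(maxChainSum, Mij + mx1);   // types 1 and 2
    if (hasSecond) d = std::max(d, Aij + mx1 + mx2);  // type 3
    return d;
}
```

The type-3 term uses A_{i, j} rather than M_{i, j}, so within a fixed chain the j minimizing M_{i, j} is not necessarily the best center. For the chain [1, 100, 1] (M values 102, 101, 102) with \text{mx}_1 = \text{mx}_2 = 1000, `starDiameter(1000, 101, 100, 1000, 1000, true)` for the middle vertex returns 2100, while `starDiameter(1000, 102, 1, 1000, 1000, true)` for the first vertex returns 2001.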

There are N\cdot M choices of (i, j), and each of them is processed in \mathcal{O}(\log N) time (if using a multiset), which is fast enough.
It’s possible to process each one in constant time using prefix/suffix maximums too, though you’ll need to store both the maximum and the second maximum, and do a bit of casework to combine prefixes and suffixes.
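One way to realize the prefix/suffix idea is to keep the two largest values of every prefix and every suffix of the \text{mn} array. The two largest values excluding index i then come from merging the summary of the prefix before i with the summary of the suffix after i. This is a sketch under stated assumptions, not the author’s or tester’s approach: the helper names are hypothetical, and instead of the casework it simply sorts the four candidates in each merge. `NEG` marks a missing value, which happens when N = 2 and no second maximum exists.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <functional>
#include <utility>
#include <vector>

// (largest, second largest); NEG marks "no value".
using Top2 = std::pair<long long, long long>;
const long long NEG = LLONG_MIN / 4;

// Merge two top-two summaries by sorting the four candidates.
Top2 mergeTop2(Top2 a, Top2 b) {
    long long v[4] = {a.first, a.second, b.first, b.second};
    std::sort(v, v + 4, std::greater<long long>());
    return {v[0], v[1]};
}

// For every i, the two largest values of mn excluding index i.
std::vector<Top2> topTwoExcluding(const std::vector<long long>& mn) {
    int n = static_cast<int>(mn.size());
    // pre[i] summarizes mn[0..i-1]; suf[i] summarizes mn[i..n-1].
    std::vector<Top2> pre(n + 1, {NEG, NEG}), suf(n + 1, {NEG, NEG});
    for (int i = 0; i < n; ++i) pre[i + 1] = mergeTop2(pre[i], {mn[i], NEG});
    for (int i = n - 1; i >= 0; --i) suf[i] = mergeTop2(suf[i + 1], {mn[i], NEG});
    std::vector<Top2> res(n);
    for (int i = 0; i < n; ++i) res[i] = mergeTop2(pre[i], suf[i + 1]);  // skips i
    return res;
}
```

For \text{mn} = [101, 1000, 1000], excluding index 0 leaves (1000, 1000), and excluding index 1 or 2 leaves (1000, 101).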

TIME COMPLEXITY:

\mathcal{O}(NM\log N) or \mathcal{O}(NM) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;

#define int                 long long int
#define fast                ios::sync_with_stdio(0),cin.tie(0), cout.tie(0);

void solve()
{
    int n, m; cin>>n>>m;
    vector<vector<int>>a(n, vector<int>(m));
    vector<int>cost(n), tot(n);

    int ans=0;
    for(int i=0; i<n; i++){
        for(int j=0; j<m; j++){
            cin>>a[i][j];
            tot[i]+=a[i][j];
        }
        ans=max(ans, tot[i]);
    }

    for(int i=0; i<n; i++)
    {
        int curr=0;
        cost[i]=tot[i];
        for(int j=0; j<m; j++)
        {
            curr+=a[i][j];
            cost[i]=min(cost[i], max(curr, tot[i]-curr+a[i][j]));
        }
    }

    if(n==2)
    {
        ans=max(ans, cost[0]+cost[1]);
        cout<<ans<<"\n";
        return;
    }

    multiset<int>s;

    for(int i=0; i<n; i++)
    {
        s.insert(cost[i]);
    }

    int ans2=1e18;
    for(int i=0; i<n; i++)
    {
        int curr=0;
        s.erase(s.find(cost[i]));
        int mx1=*s.rbegin();
        int mx2=*(++s.rbegin());
        for(int j=0; j<m; j++)
        {
            curr+=a[i][j];
            int mx=max(curr, tot[i]-curr+a[i][j]);
            int op1=mx1+mx2+a[i][j];
            int op2=mx1+mx;
            ans2=min(ans2, max(op1, op2));
        }
        s.insert(cost[i]);
    }

    cout<<max(ans, ans2)<<"\n";
}


signed main()
{
    fast
    int t=1;
    cin>>t;
    while(t--) solve();
}   
Tester's code (C++)
#include <bits/stdc++.h>
#define int long long
#define pii pair <int, int>
#define ff first
#define ss second
const long long INF = 2e17;
using namespace std;
long long readInt(long long l, long long r, char endd) {
    long long x = 0;
    int cnt = 0;
    int fi = -1;
    bool is_neg = false;
    while (true) {
        char g = getchar();
        if (g == '-') {
            assert(fi == -1);
            is_neg = true;
            continue;
        }
        if ('0' <= g && g <= '9') {
            x *= 10;
            x += g - '0';
            if (cnt == 0) {
                fi = g - '0';
            }
            cnt++;
            assert(fi != 0 || cnt == 1);
            assert(fi != 0 || is_neg == false);
 
            assert(!(cnt > 19 || (cnt == 19 && fi > 1)));
        }
        else if (g == endd) {
            assert(cnt > 0);
            if (is_neg) {
                x = -x;
            }
            assert(l <= x && x <= r);
            return x;
        }
        else {
            assert(false);
        }
    }
}
 
string readString(int l, int r, char endd) {
    string ret = "";
    int cnt = 0;
    while (true) {
        char g = getchar();
        assert(g != -1);
        if (g == endd) {
            break;
        }
        cnt++;
        ret += g;
    }
    assert(l <= cnt && cnt <= r);
    return ret;
}
int32_t main() {
	int t;
	//cin>>t;
	t = readInt(1, 100000, '\n');
	assert(t<=1e5 && t>=1);
	int mulsum=0;
	while(t--)
	{
	    int n, m;
	    //cin>>n>>m;
	    n = readInt(2, 100000, ' ');
	    m = readInt(1, 100000, '\n');
	    assert(n>=2 && n<=1e5);
	    assert(m>=1 && m<=1e5);
	    mulsum+=(n*m);
	    assert(mulsum<=5e5);
	    int a[n][m], sum[n];
	    memset(sum, 0, sizeof(sum));
	    pii sum1={-INF, 0}, sum2={-INF, 0};
	    pii half1={-INF, 0}, half2={-INF, 0}, half3={-INF, 0};
	    for(int i=0;i<n;i++)
	    {
	        for(int j=0;j<m;j++)
	        {
	            //cin>>a[i][j];
	            if(j != m - 1){
	                a[i][j] = readInt(1, 1000000000, ' ');
	            }else{
	                a[i][j] = readInt(1, 1000000000 ,'\n');
	            }
	            sum[i]+=a[i][j];
	        }
	        if(sum[i]>sum1.ff)
	        {
	            sum2=sum1;
	            sum1={sum[i], i};
	        }
	        else if(sum[i]>sum2.ff)
	        {
	            sum2={sum[i], i};
	        }
	        int half=sum[i], lef=0;
	        for(int j=0;j<m;j++)
	        {
	            lef+=a[i][j];
	            half=min(half, max(lef, sum[i]-lef+a[i][j]));
	        }
	        if(half>half1.ff)
	        {
	            half3=half2;
	            half2=half1;
	            half1={half, i};
	        }
	        else if(half>half2.ff)
	        {
	            half3=half2;
	            half2={half, i};
	        }
	        else if(half>half3.ff)
	        {
	            half3={half, i};
	        }
	        
	    }
	    int diam=INF;
	    for(int i=0;i<n;i++)
	    {
	        int lef=0;
	        for(int j=0;j<m;j++)
	        {
	            int tot=(sum1.ss!=i ? sum1.ff : sum2.ff);
	            int half=(half1.ss!=i ? half1.ff : half2.ff);
	            int twohalf=half1.ff+half2.ff;
	            if(i==half1.ss)
	            twohalf=half2.ff + half3.ff;
	            if(i==half2.ss)
	            twohalf=half1.ff + half3.ff;
	            diam=min(diam, max({sum[i], tot, max(lef+a[i][j], sum[i]-lef) + half, a[i][j] + twohalf}));
	            lef+=a[i][j];
	        }
	    }
	    cout<<diam<<"\n";
	}
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, m; cin >> n >> m;
        vector a(n, vector(m, 0));
        for (auto &v : a) for (auto &x : v)
            cin >> x;
        
        vector<ll> best(n, LLONG_MAX);
        ll ans1 = 0, ans2 = LLONG_MAX;
        for (int i = 0; i < n; ++i) {
            ll left = 0, right = accumulate(begin(a[i]), end(a[i]), 0ll);
            ans1 = max(ans1, right);
            for (int j = 0; j < m; ++j) {
                left += a[i][j];
                best[i] = min(best[i], max(left, right));
                right -= a[i][j];
            }
        }

        multiset<ll> bests(begin(best), end(best));
        for (int i = 0; i < n; ++i) {
            ll left = 0, right = accumulate(begin(a[i]), end(a[i]), 0ll);
            bests.erase(bests.find(best[i]));
            for (int j = 0; j < m; ++j) {
                left += a[i][j];

                ll cur = 0;
                // this branch and something else
                cur = max(left, right) + *bests.rbegin();
                // two branches meeting here
                if (n > 2) {
                    auto it = rbegin(bests);
                    cur = max(cur, a[i][j] + *it + *next(it));
                }
                ans2 = min(ans2, cur);
                right -= a[i][j];
            }
            bests.insert(best[i]);
        }
        cout << max(ans1, ans2) << '\n';
    }
}

Hi! Are there really N \cdot M choices of (i, j) to check? It seems like it’s only worth checking (i, \arg\min_j M_{i,j}) for each i.

If that isn’t the case, could you please show an example where this logic fails?