FINDIAMETER - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: satyam_343
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

2770

PREREQUISITES:

Dijkstra’s algorithm

PROBLEM:

There’s a complete graph on N vertices, with the weight of the edge between i and j being \min(|i-j|, |A_i-A_j|).
Find the diameter of this graph.

EXPLANATION:

Computing the diameter of an arbitrary weighted graph is expensive in general: one typically runs Floyd-Warshall to compute all-pairs shortest paths in \mathcal{O}(N^3) time, then takes the maximum of them all.
That’s too slow here, so we need to use the structure of the weights to do better.
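As a reference point, here is a minimal brute-force sketch of that slow approach, useful mainly for checking a faster solution on small inputs. The function name `bruteDiameter` and the 0-indexed input are illustrative assumptions, not part of the original solutions.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical brute-force reference (0-indexed): builds the complete graph with
// w(i, j) = min(|i - j|, |A_i - A_j|), runs Floyd-Warshall in O(N^3),
// and returns the largest shortest-path distance. Only for small N.
long long bruteDiameter(const vector<long long>& a) {
    int n = (int)a.size();
    vector<vector<long long>> d(n, vector<long long>(n));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            d[i][j] = min<long long>(abs(i - j), llabs(a[i] - a[j]));
    for (int k = 0; k < n; ++k)
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                d[i][j] = min(d[i][j], d[i][k] + d[k][j]);
    long long best = 0;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            best = max(best, d[i][j]);
    return best;
}
```

For example, with A = [3, 1, 4, 1, 5] the farthest pair is vertices 1 and 5 (1-indexed), at distance 2.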

Let w(i, j) denote the weight of the edge between i and j.
Then, since w(i, j) = \min(|i-j|, |A_i-A_j|), there are two possibilities: either w(i, j) = |i-j|, or w(i, j) = |A_i - A_j|.

If w(i, j) = |i-j|, we can instead start at i and take one step at a time until we reach j, that is, move i \to (i+1) \to (i+2) \to\ldots\to j (assuming i \lt j; if i \gt j, step downwards instead). Each step costs at most 1, so this walk costs at most |i-j| = w(i, j).

If w(i, j) = |A_i - A_j|, we can say something similar.
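Suppose A_i \leq A_j (so w(i, j) = A_j - A_i), and let x_1, \ldots, x_k be vertices with A_i \leq A_{x_1} \leq A_{x_2} \leq \ldots\leq A_{x_k} \leq A_j.
Then, instead of moving directly from i to j, moving via i \to x_1 \to x_2\to\ldots\to x_k \to j costs no more: each step costs at most the difference of its endpoints' values, and these differences telescope to A_j - A_i.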
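Written out under the same assumption (A_i \leq A_j, so w(i, j) = A_j - A_i), the telescoping bound reads as follows; each edge weight is a minimum, so it is at most the value difference of its endpoints:

$$
w(i, x_1) + \sum_{t=1}^{k-1} w(x_t, x_{t+1}) + w(x_k, j) \;\leq\; (A_{x_1} - A_i) + \sum_{t=1}^{k-1} \left(A_{x_{t+1}} - A_{x_t}\right) + (A_j - A_{x_k}) \;=\; A_j - A_i \;=\; w(i, j)
$$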

Together, these observations tell us that it's enough to keep only edges between ‘close’ vertices of either type (adjacent indices, or adjacent values in sorted order); the length of every shortest path is preserved in such a graph.
That is, consider a graph on N vertices as follows:

  • For each i, add the edges (i, i-1) and (i, i+1), each with weight 1 (whenever those endpoints exist).
  • Let A_{x_1} \leq A_{x_2} \leq\ldots\leq A_{x_N} be the sorted order of vertex values.
    For each i, create the edge (x_{i-1}, x_i) with weight A_{x_i} - A_{x_{i-1}}, and the edge (x_{i+1}, x_i) with weight A_{x_{i+1}} - A_{x_{i}}.
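A minimal sketch of this construction, assuming 0-indexed vertices; `buildSparseGraph` is a hypothetical name, and the adjacency-list representation of `(neighbour, weight)` pairs mirrors what the author's and tester's code do, though it is not copied from them.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical helper (0-indexed): builds the sparse graph described above.
//  - index edges: i <-> i+1 with weight 1
//  - value edges: x[t] <-> x[t+1] with weight a[x[t+1]] - a[x[t]],
//    where x lists the indices sorted by value.
vector<vector<pair<int, long long>>> buildSparseGraph(const vector<long long>& a) {
    int n = (int)a.size();
    vector<vector<pair<int, long long>>> adj(n);
    auto addEdge = [&](int u, int v, long long w) {  // undirected edge
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    };
    for (int i = 0; i + 1 < n; ++i) addEdge(i, i + 1, 1);

    vector<int> x(n);
    iota(x.begin(), x.end(), 0);
    sort(x.begin(), x.end(), [&](int i, int j) { return a[i] < a[j]; });
    for (int t = 0; t + 1 < n; ++t) addEdge(x[t], x[t + 1], a[x[t + 1]] - a[x[t]]);
    return adj;
}
```

Each undirected edge is stored twice (once per endpoint), so for N vertices the lists hold 2\cdot 2(N-1) entries in total.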

Then, for any pair of vertices (u, v), the length of the shortest path from u to v in this new graph equals the length of the shortest path from u to v in the original graph.
However, the new graph only has at most 2N - 2 edges!

So, for a fixed source vertex u, finding the shortest path from u to all other vertices can be done using Dijkstra’s algorithm in \mathcal{O}(N\log N) time (it’s \mathcal{O}(E\log V), but E = \mathcal{O}(N) here).
So, shortest paths between all pairs of vertices can be found in \mathcal{O}(N^2\log N) time, and the maximum among them is the answer.
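A minimal sketch of this \mathcal{O}(N^2\log N) approach with a binary heap, similar in spirit to the tester's code but not taken from it. The name `diameterHeap` and 0-indexing are assumptions; the sparse graph is rebuilt inline so the block stands alone.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical sketch (0-indexed): sparse graph + heap-based Dijkstra from every
// source. O(N log N) per source, O(N^2 log N) overall.
long long diameterHeap(const vector<long long>& a) {
    int n = (int)a.size();
    vector<vector<pair<int, long long>>> adj(n);
    auto addEdge = [&](int u, int v, long long w) {
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    };
    for (int i = 0; i + 1 < n; ++i) addEdge(i, i + 1, 1);      // index edges
    vector<int> x(n);
    iota(x.begin(), x.end(), 0);
    sort(x.begin(), x.end(), [&](int i, int j) { return a[i] < a[j]; });
    for (int t = 0; t + 1 < n; ++t)                              // value edges
        addEdge(x[t], x[t + 1], a[x[t + 1]] - a[x[t]]);

    using State = pair<long long, int>;  // (distance, vertex)
    long long best = 0;
    for (int src = 0; src < n; ++src) {
        vector<long long> dist(n, LLONG_MAX);
        priority_queue<State, vector<State>, greater<State>> pq;
        dist[src] = 0;
        pq.push({0, src});
        while (!pq.empty()) {
            auto [d, v] = pq.top();
            pq.pop();
            if (d != dist[v]) continue;  // stale heap entry
            for (auto [u, w] : adj[v]) {
                if (d + w < dist[u]) {
                    dist[u] = d + w;
                    pq.push({dist[u], u});
                }
            }
        }
        // The graph is connected via index edges, so every dist is finite.
        best = max(best, *max_element(dist.begin(), dist.end()));
    }
    return best;
}
```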

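It’s also possible to implement Dijkstra’s algorithm to run in \mathcal{O}(N) time per source, using the fact that every distance is at most N-1 (walking along consecutive indices always achieves this), so tentative distances can be kept in N buckets instead of a heap.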
The author’s and editorialist’s code implement this, though it wasn’t required to get AC.
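Below is a minimal sketch of the bucket idea, assuming 0-indexed vertices; `diameterBuckets` is a hypothetical name. Like the author's code, it seeds each distance with the index-path bound |v - src|, and like the editorialist's code, it reads the four neighbours of a vertex directly from the index and sorted orders instead of storing an adjacency list. It is an illustration, not either of those implementations.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical sketch (0-indexed): bucket-queue Dijkstra on the sparse graph.
// All distances are < n, so bucket[d] holds vertices with tentative distance d.
// Each vertex has at most 4 edges, so one source costs O(N).
int diameterBuckets(const vector<long long>& a) {
    int n = (int)a.size();
    vector<int> x(n), pos(n);  // x: indices sorted by value; pos: inverse of x
    iota(x.begin(), x.end(), 0);
    sort(x.begin(), x.end(), [&](int i, int j) { return a[i] < a[j]; });
    for (int t = 0; t < n; ++t) pos[x[t]] = t;

    vector<vector<int>> bucket(n);
    vector<int> dist(n);
    auto relax = [&](int u, long long cand) {
        if (cand < dist[u]) {  // cand < dist[u] <= n-1, so it fits in a bucket
            dist[u] = (int)cand;
            bucket[dist[u]].push_back(u);
        }
    };

    int best = 0;
    for (int src = 0; src < n; ++src) {
        for (int v = 0; v < n; ++v) {
            dist[v] = abs(v - src);  // index-path upper bound
            bucket[dist[v]].push_back(v);
        }
        for (int d = 0; d < n; ++d) {
            // Zero-weight value edges can append to bucket[d] while we scan it,
            // so iterate by index rather than with iterators.
            for (size_t k = 0; k < bucket[d].size(); ++k) {
                int v = bucket[d][k];
                if (dist[v] != d) continue;  // stale: a shorter distance was found
                best = max(best, d);
                if (v > 0) relax(v - 1, d + 1LL);
                if (v + 1 < n) relax(v + 1, d + 1LL);
                int p = pos[v];
                if (p > 0) relax(x[p - 1], d + (a[v] - a[x[p - 1]]));
                if (p + 1 < n) relax(x[p + 1], d + (a[x[p + 1]] - a[v]));
            }
            bucket[d].clear();  // leaves all buckets empty for the next source
        }
    }
    return best;
}
```

Since a relaxation never produces a distance smaller than the bucket currently being scanned, each vertex is finalized the first time it is popped with a matching distance.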

TIME COMPLEXITY

\mathcal{O}(N^2) per testcase with the bucket-based Dijkstra, or \mathcal{O}(N^2\log N) with a binary heap.

CODE:

Author's code (C++)
#pragma GCC optimize("O3,unroll-loops")
#include <bits/stdc++.h>   
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
using namespace __gnu_pbds;   
using namespace std;
#define ll long long
const ll INF_MUL=1e15;
const ll INF_ADD=1e18;
#define pb push_back               
#define mp make_pair          
#define nline "\n"                           
#define f first                                          
#define s second                                             
#define pll pair<ll,ll> 
#define all(x) x.begin(),x.end()     
#define vl vector<ll>           
#define vvl vector<vector<ll>>    
#define vvvl vector<vector<vector<ll>>>          
#ifndef ONLINE_JUDGE    
#define debug(x) cerr<<#x<<" "; _print(x); cerr<<nline;
#else
#define debug(x);  
#endif     
void _print(ll x){cerr<<x;}  
void _print(char x){cerr<<x;}   
void _print(string x){cerr<<x;}    
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());   
template<class T,class V> void _print(pair<T,V> p) {cerr<<"{"; _print(p.first);cerr<<","; _print(p.second);cerr<<"}";}
template<class T>void _print(vector<T> v) {cerr<<" [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T>void _print(set<T> v) {cerr<<" [ "; for (T i:v){_print(i); cerr<<" ";}cerr<<"]";}
template<class T>void _print(multiset<T> v) {cerr<< " [ "; for (T i:v){_print(i);cerr<<" ";}cerr<<"]";}
template<class T,class V>void _print(map<T, V> v) {cerr<<" [ "; for(auto i:v) {_print(i);cerr<<" ";} cerr<<"]";} 
typedef tree<ll, null_type, less<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;
typedef tree<ll, null_type, less_equal<ll>, rb_tree_tag, tree_order_statistics_node_update> ordered_multiset;
typedef tree<pair<ll,ll>, null_type, less<pair<ll,ll>>, rb_tree_tag, tree_order_statistics_node_update> ordered_pset;
//--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------  
const ll MOD=998244353;
const ll MAX=500500;
void solve(){  
    ll n; cin>>n;
    vector<ll> a(n+5);
    vector<pair<ll,ll>> trav;
    for(ll i=1;i<=n;i++){
        cin>>a[i];
        trav.push_back({a[i],i});
    }
    vector<pair<ll,ll>> adj[n+5];
    sort(all(trav));
    for(ll i=1;i<=n-1;i++){
        adj[i].push_back({i+1,1});
        adj[i+1].push_back({i,1});
    }
    for(ll i=0;i<=n-2;i++){
        auto l=trav[i],r=trav[i+1];
        adj[l.s].push_back({r.s,r.f-l.f});
        adj[r.s].push_back({l.s,r.f-l.f});
    }
    ll ans=0;
    for(ll i=1;i<=n;i++){
        vector<ll> dist(n+5,INF_ADD);  
        vector<ll> track[n+5];
        vector<ll> len(n+5,0);
        for(ll j=1;j<=n;j++){  
            dist[j]=abs(i-j);
            track[dist[j]].push_back(j);  
            len[dist[j]]++;          
        }      
        vector<ll> visited(n+5,0);       
        for(ll j=0;j<=n-1;j++){  
            for(ll k=0;k<len[j];k++){    
                if(visited[track[j][k]]){    
                    continue;  
                } 
                ans=max(ans,j);
                ll node=track[j][k];  
                visited[node]=1;
                for(auto it:adj[node]){
                    if(dist[it.f]>dist[node]+it.s){
                        dist[it.f]=dist[node]+it.s;
                        track[dist[it.f]].push_back(it.f);
                        len[dist[it.f]]++;
                    }   
                }
            }  
        }   
    }  
    cout<<ans<<nline;      
    return;    
}                                          
int main()                                                                               
{     
    ios_base::sync_with_stdio(false);                         
    cin.tie(NULL);                              
    #ifndef ONLINE_JUDGE                 
    freopen("input.txt", "r", stdin);                                           
    freopen("output.txt", "w", stdout);      
    freopen("error.txt", "w", stderr);                        
    #endif     
    ll test_cases=1;               
    cin>>test_cases;
    while(test_cases--){
        solve();
    }
    cout<<fixed<<setprecision(10);
    cerr<<"Time:"<<1000*((double)clock())/(double)CLOCKS_PER_SEC<<"ms\n"; 
}  
Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
	string buffer;
	int pos;

	const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
	const string number = "0123456789";
	const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
	const string lower = "abcdefghijklmnopqrstuvwxyz";

	input_checker() {
		pos = 0;
		while (true) {
			int c = cin.get();
			if (c == -1) {
				break;
			}
			buffer.push_back((char) c);
		}
	}

	int nextDelimiter() {
		int now = pos;
		while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
			now++;
		}
		return now;
	}

	string readOne() {
		assert(pos < (int) buffer.size());
		int nxt = nextDelimiter();
		string res;
		while (pos < nxt) {
			res += buffer[pos];
			pos++;
		}
		return res;
	}

	string readString(int minl, int maxl, const string &pattern = "") {
		assert(minl <= maxl);
		string res = readOne();
		assert(minl <= (int) res.size());
		assert((int) res.size() <= maxl);
		for (int i = 0; i < (int) res.size(); i++) {
			assert(pattern.empty() || pattern.find(res[i]) != string::npos);
		}
		return res;
	}

	int readInt(int minv, int maxv) {
		assert(minv <= maxv);
		int res = stoi(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	long long readLong(long long minv, long long maxv) {
		assert(minv <= maxv);
		long long res = stoll(readOne());
		assert(minv <= res);
		assert(res <= maxv);
		return res;
	}

	auto readInts(int n, int minv, int maxv) {
		assert(n >= 0);
		vector<int> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readInt(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	auto readLongs(int n, long long minv, long long maxv) {
		assert(n >= 0);
		vector<long long> v(n);
		for (int i = 0; i < n; ++i) {
			v[i] = readLong(minv, maxv);
			if (i+1 < n) readSpace();
		}
		return v;
	}

	void readSpace() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == ' ');
		pos++;
	}

	void readEoln() {
		assert((int) buffer.size() > pos);
		assert(buffer[pos] == '\n');
		pos++;
	}

	void readEof() {
		assert((int) buffer.size() == pos);
	}
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;
    int T = input.readInt(1, 1000); input.readEoln();
    int sum_N = 0;

    while(T-- > 0) {
        int n = input.readInt(1, 5000); input.readEoln();
        sum_N += n;

        vector<int> a(n), ord(n);
        a = input.readInts(n, 1, (int)1e9);  input.readEoln();

        vector<vector<pair<int, int>>> adj(n);

        iota(ord.begin(), ord.end(), 0);
        sort(ord.begin(), ord.end(), [&](auto &i, auto &j) {
            return a[i] < a[j];
        });

        for(int i = 1 ; i < n ; i++) {
            adj[i].emplace_back(i - 1, 1);
            adj[i - 1].emplace_back(i, 1);
            adj[ord[i]].emplace_back(ord[i - 1], a[ord[i]] - a[ord[i - 1]]);
            adj[ord[i - 1]].emplace_back(ord[i], a[ord[i]] - a[ord[i - 1]]);
        }

        vector<int> dis(n, n), vis(n);
        priority_queue<pair<int, int>, vector<pair<int, int>>, greater<pair<int, int>>> pq;

        int res = 0;
        for(int src = 0 ; src < n ; ++src) {
            dis[src] = 0;
            pq.emplace(0, src);
            while(!pq.empty()) {
                auto nd = pq.top().second;  pq.pop();
                if(vis[nd])     continue;
                vis[nd] = 1;
                for(auto &[u, w]: adj[nd]) if(dis[u] > w + dis[nd]) {
                    dis[u] = w + dis[nd];
                    pq.emplace(dis[u], u);
                }
            }
            res = max(res, *max_element(dis.begin(), dis.end()));

            fill(vis.begin(), vis.end(), 0);
            fill(dis.begin(), dis.end(), n);
        }
        assert(res < n);
        cout << res << '\n';
    }

    input.readEof();
    assert(sum_N <= 5000);

    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<ll> a(n);
        for (ll &x : a) cin >> x;

        vector<int> ord(n), pos(n);
        iota(begin(ord), end(ord), 0);
        sort(begin(ord), end(ord), [&] (int i, int j) {
            return a[i] < a[j];
        });
        for (int i = 0; i < n; ++i) pos[ord[i]] = i;

        vector<int> buffer(10*n), link(10*n), head(n), tail(n);
        vector<int> mark(n), dist(n);

        int ans = 0;
        for (int u = 0; u < n; ++u) {
            for (int i = 0; i < n; ++i) {
                head[i] = tail[i] = -1;
                mark[i] = 0;
                dist[i] = n;
            }
            head[0] = tail[0] = 0;
            dist[u] = 0;
            buffer[0] = u;
            link[0] = -1;
            int ptr = 1;

            for (int d = 0; d < n; ++d) {
                int cur = head[d];
                while (cur != -1) {
                    int v = buffer[cur];
                    if (mark[v]) {
                        cur = link[cur];
                        continue;
                    }
                    mark[v] = 1;
                    ans = max(ans, d);

                    for (int dv : {-1, 1}) {
                        if (v + dv >= 0 and v + dv < n and dist[v + dv] > 1 + d) {
                            dist[v + dv] = 1 + d;

                            buffer[ptr] = v + dv;
                            if (head[1+d] == -1) head[1+d] = tail[1+d] = ptr;
                            else {
                                link[tail[1+d]] = ptr;
                                tail[1+d] = ptr;
                            }
                            link[ptr] = -1;
                            ++ptr;
                        }
                    }
                    int where = pos[v];
                    for (int dw : {-1, 1}) {
                        if (where + dw < 0 or where + dw >= n) continue;
                        int w = ord[where + dw];
                        int curd = d + abs(a[v] - a[w]);
                        if (dist[w] > curd) {
                            dist[w] = curd;

                            buffer[ptr] = w;
                            if (head[curd] == -1) head[curd] = tail[curd] = ptr;
                            else {
                                link[tail[curd]] = ptr;
                                tail[curd] = ptr;
                            }
                            link[ptr] = -1;
                            ++ptr;
                        }
                    }
                    cur = link[cur];
                }
            }
        }
        cout << ans << '\n';
    }
}

Why is the number of edges \mathcal{O}(N)?

Dijkstra’s algorithm in \mathcal{O}(N\log N) time (it’s \mathcal{O}(E\log V), but E = \mathcal{O}(N) here).

Because we only created 2\cdot (N-1) edges in total for our new graph, which is obviously \mathcal{O}(N).

  • N-1 edges of the form i \leftrightarrow (i+1)
  • N-1 edges of the form x_i \leftrightarrow x_{i+1} by considering the sorted array values to be
    A_{x_1} \leq A_{x_2} \leq\ldots\leq A_{x_N}
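To make the sorted order concrete: x is simply the list of indices arranged so that their values are non-decreasing (an "argsort"). A minimal sketch, assuming 0-indexed positions; `sortedOrder` is a hypothetical name, and `stable_sort` is used only so that ties keep a predictable order.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical helper (0-indexed): returns x with a[x[0]] <= a[x[1]] <= ... <= a[x[n-1]],
// i.e. the indices of a listed in increasing order of value.
vector<int> sortedOrder(const vector<long long>& a) {
    vector<int> x(a.size());
    iota(x.begin(), x.end(), 0);  // x = 0, 1, ..., n-1
    stable_sort(x.begin(), x.end(), [&](int i, int j) { return a[i] < a[j]; });
    return x;
}
```

For A = [5, 1, 4] (1-indexed), this gives x = (2, 3, 1), so the value edges are 2 \leftrightarrow 3 with weight 3 and 3 \leftrightarrow 1 with weight 1. Each vertex has at most two neighbours in this order, which is why only N-1 value edges are needed.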

That’s the part I don’t understand:
A_{x_1} \leq A_{x_2} \leq \ldots \leq A_{x_N}?
How do we get it, and how does it help improve the complexity?

The problem’s objective is to find the shortest path between each pair of vertices, then output the maximum among them - that’s what diameter is.
However, the input graph has too many edges for us to be able to do that fast enough.

See this section of the editorial:

We construct a new graph on N vertices with 2N-2 edges, such that the distance between u and v in this new graph equals the distance between u and v in the original graph.
Finding shortest paths in this new graph is much faster because it has a small number of edges.

The motivation for why the graph is constructed this way, and in particular why we sort the values, is detailed in the paragraphs above it.
