PROBLEM LINK:
Contest Division 1
Contest Division 2
Contest Division 3
Practice
Setter: Evgeny
Tester & Editorialist: Taranpreet Singh
DIFFICULTY
Medium-Hard
PREREQUISITES
Dijkstra's Algorithm and Lazy Segment Tree
PROBLEM
You are given a directed graph with N nodes and an edge from every node u to every other node v, where the cost of the directed edge from u to v is A_u for a given array A.
Exactly M modifications are then made to this graph; each modification sets the cost of the edge from u to v to c.
Find the sum of the lengths of the shortest paths over all ordered pairs of distinct vertices.
QUICK EXPLANATION
- Treating the modified edges as undirected, find the connected components formed by their endpoints.
- From each node of a component, run Dijkstra to find its shortest distance to every other node of the same component.
- To make Dijkstra fast, use a lazy segment tree that performs all relaxations sharing the same value as a single range operation.
- For each component, also add the external node v with minimum A_v, to handle shortest paths between two nodes of the component that leave the component.
EXPLANATION
Let’s start with a brute-force solution and then improve it.
The brute-force approach is to run the Floyd-Warshall all-pairs shortest path algorithm on the given graph and then sum the distances over all ordered pairs. This O(N^3) solution is sufficient for subtask 1.
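A minimal C++ sketch of this brute force, assuming 1-indexed nodes, non-negative costs and that each ordered pair (u, v) is modified at most once. The names bruteForce and Mod are illustrative and not taken from the reference solutions. It builds the full cost matrix, applies the modifications and runs Floyd-Warshall:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// One modification: the edge u -> v now costs c (1-indexed nodes).
struct Mod {
    int u, v;
    long long c;
};

// Brute force: build the N x N cost matrix, apply the modifications,
// run Floyd-Warshall, and sum dist(u, v) over ordered pairs u != v.
// A[0] is unused so that A[u] is the base cost of every edge leaving u.
long long bruteForce(const std::vector<long long>& A, const std::vector<Mod>& mods) {
    assert(!A.empty());
    const int n = static_cast<int>(A.size()) - 1;
    std::vector<std::vector<long long>> d(n + 1, std::vector<long long>(n + 1, 0));
    for (int u = 1; u <= n; ++u)
        for (int v = 1; v <= n; ++v)
            if (u != v) d[u][v] = A[u];  // original edge u -> v costs A_u
    for (const Mod& m : mods)
        if (m.u != m.v) d[m.u][m.v] = m.c;  // a self-loop never shortens a path
    for (int k = 1; k <= n; ++k)
        for (int i = 1; i <= n; ++i)
            for (int j = 1; j <= n; ++j)
                d[i][j] = std::min(d[i][j], d[i][k] + d[k][j]);
    long long total = 0;
    for (int u = 1; u <= n; ++u)
        for (int v = 1; v <= n; ++v)
            if (u != v) total += d[u][v];
    return total;
}
```

On the corner-case example discussed later (N = 3, A = [1, 1, 1], edge 1 → 2 set to 5), bruteForce({0, 1, 1, 1}, {{1, 2, 5}}) returns 7: five pairs at distance 1 plus dist(1, 2) = 2.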
Observation 1
Consider a node u none of whose outgoing edges is modified. Every path starting at u begins with an edge of cost A_u, so (as edge costs are non-negative) every path from u costs at least A_u. Since there is a direct edge from u to every other node with cost exactly A_u, dist(u, v) = A_u for all v ≠ u, and this node contributes (N-1)*A_u to the total cost.
Hence, we can add (N-1)*A_u to the answer for every such node directly, and only need to handle nodes that are endpoints of at least one modified edge.
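Observation 1 turns into a simple sum. A sketch in C++, assuming 0-indexed nodes and a hypothetical flag array hasModifiedOut that marks nodes with at least one modified outgoing edge:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Observation 1: if no modified edge leaves u, then dist(u, v) = A_u for all
// v != u, so u contributes (N - 1) * A_u. Nodes are 0-indexed.
long long untouchedContribution(const std::vector<long long>& A,
                                const std::vector<bool>& hasModifiedOut) {
    assert(A.size() == hasModifiedOut.size());
    const long long n = static_cast<long long>(A.size());
    long long sum = 0;
    for (std::size_t u = 0; u < A.size(); ++u)
        if (!hasModifiedOut[u]) sum += (n - 1) * A[u];
    return sum;
}
```

The reference solutions apply this shortcut only to nodes touched by no modified edge at all; a node with only incoming modified edges is simply handled inside its component, which is also correct.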
Subtask 3
By the above observation, we only need to compute the sum of distances over pairs starting at one of at most 2*M candidate nodes.
Let’s divide these nodes into connected components, treating the modified edges as undirected. This way, any path from a node inside one component to a node outside that component has to use at least one original (unmodified) edge.
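A sketch of the grouping step in C++ using a disjoint-set union, as the tester’s solution does (the setter uses DFS instead). Nodes are 0-indexed here, and the names DSU and modifiedComponents are illustrative. Only nodes touched by at least one modified edge are returned, each component as a sorted list:

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Disjoint-set union with path halving (iterative, so no deep recursion).
struct DSU {
    std::vector<int> parent;
    explicit DSU(int n) : parent(n) { std::iota(parent.begin(), parent.end(), 0); }
    int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];  // path halving
            x = parent[x];
        }
        return x;
    }
    void unite(int a, int b) { parent[find(a)] = find(b); }
};

// Components of the modified edges, treated as undirected.
std::vector<std::vector<int>> modifiedComponents(
        int n, const std::vector<std::pair<int, int>>& modifiedEdges) {
    DSU dsu(n);
    std::vector<bool> touched(n, false);
    for (const auto& [u, v] : modifiedEdges) {
        assert(0 <= u && u < n && 0 <= v && v < n);
        dsu.unite(u, v);
        touched[u] = touched[v] = true;
    }
    std::vector<std::vector<int>> byRoot(n);
    for (int x = 0; x < n; ++x)
        if (touched[x]) byRoot[dsu.find(x)].push_back(x);
    std::vector<std::vector<int>> comps;
    for (auto& c : byRoot)
        if (!c.empty()) comps.push_back(std::move(c));
    return comps;
}
```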
Observation 2
Let’s write a shortest path from u to v as u → x_1 → \ldots → x_k → v, where u is inside a component, v is outside that component, and x_0 = u.
Claim: We can choose this path so that all of x_1, \ldots, x_k lie in the same component as u, and the last edge x_k → v is an original edge.
Proof: Suppose some x_i lies outside the component, and take the first such i. Then the edge x_{i-1} → x_i leaves the component, so it must be an original edge of cost A_{x_{i-1}}. But the original edge x_{i-1} → v has the same cost A_{x_{i-1}}, so jumping directly to v does not increase the cost (edge costs are non-negative). Hence, all x_i can be assumed to lie within the same component as u.
The edge x_k → v cannot be a modified edge, since x_k is inside the component and v is outside. Hence, the edge x_k → v is an original edge, of cost A_{x_k}.
Using the above observation, once dist(u, w) is known for all pairs of nodes (u, w) within the same component, the distance from u to every node outside the component is the same value, min(dist(u, w) + A_w) over all w in the component (including w = u, with dist(u, u) = 0).
Hence, all we need is the distance between every pair of nodes within each component, which can be computed with the Floyd-Warshall algorithm.
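The combination step from Observation 2 can be sketched in C++ as follows, assuming distU holds dist(u, w) for every node w of u’s component (0 for w = u) and compA holds the matching A_w; both names are hypothetical. All n - C outside nodes share one distance, so they are added with a single multiplication, much like the out variable in the setter’s code:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Total distance from one source u to all other nodes, given
// distU[i] = dist(u, i-th node of u's component) (0 for u itself) and
// compA[i] = A of that node; n is the total number of nodes N.
long long sourceContribution(long long n, const std::vector<long long>& distU,
                             const std::vector<long long>& compA) {
    assert(!distU.empty() && distU.size() == compA.size());
    const long long c = static_cast<long long>(distU.size());
    assert(n >= c);
    long long inside = 0;  // sum of distances to nodes of the component
    long long toOutside = std::numeric_limits<long long>::max();
    for (std::size_t i = 0; i < distU.size(); ++i) {
        inside += distU[i];
        // leave the component from node i through an original edge of cost A_i
        toOutside = std::min(toOutside, distU[i] + compA[i]);
    }
    return inside + (n - c) * toOutside;  // every outside node costs toOutside
}
```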
Corner Case
Because of the modified edge costs, the shortest path from u to v, where u and v lie in the same component, may pass through a node w outside the component.
This happens because an update may have made the direct edge from u to v more expensive.
Consider the following example (N = 3, A = [1, 1, 1], and one modification setting the cost of edge 1 → 2 to 5):
3
1 1 1
1
1 2 5
Here dist(1, 2) = dist(1, 3) + dist(3, 2) = 1 + 1 = 2, which beats the modified direct edge of cost 5, even though node 3 lies outside the component {1, 2}.
An easy way to handle this is to observe that among all nodes w outside the component, it is sufficient to consider the one with minimum A_w. A path that leaves the component and comes back has the form u → \ldots → x → w → v, where x is the last internal node before leaving. Both x → w and w → v are original edges, so its cost is dist(u, x) + A_x + A_w. The only term depending on w is A_w, hence it is sufficient to add the single outside node with minimum A_w to the component before computing all-pairs shortest paths.
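A sketch of picking that extra node in C++, assuming 0-indexed nodes and a hypothetical membership mask inComponent. It is a plain O(N) scan for clarity; the setter and tester instead walk the nodes in increasing order of A (sorted once per test case), which finds the node after at most C + 1 steps:

```cpp
#include <cassert>
#include <vector>

// Returns the node outside the component with the smallest A (smallest index
// on ties), or -1 if every node is inside. Nodes are 0-indexed.
int cheapestOutsideNode(const std::vector<long long>& A,
                        const std::vector<bool>& inComponent) {
    assert(A.size() == inComponent.size());
    int best = -1;
    for (int w = 0; w < static_cast<int>(A.size()); ++w)
        if (!inComponent[w] && (best == -1 || A[w] < A[best])) best = w;
    return best;
}
```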
Subtask 4
So far, we relied on Floyd-Warshall to compute all-pairs shortest paths within components of O(M) nodes, which can have O(M^2) edges, so this takes O(M^3) time in the worst case.
Let’s try Dijkstra instead. Since at most 2*M nodes are endpoints of a modified edge, we need to run Dijkstra at most 2*M times. Each run, if implemented normally, takes time proportional to the number of edges, i.e. O(M^2), for O(M^3) in total, which is too slow.
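For reference, here is a plain O(C^2) Dijkstra on one component of C nodes in C++, assuming local indices 0..C-1 and a hypothetical layout where mods[u] lists the (target, weight) pairs of modified edges leaving u. Each processed node relaxes its outgoing edges one by one, which is exactly the cost the next section removes:

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <utility>
#include <vector>

// Plain O(C^2) Dijkstra on one component (local indices 0..C-1).
// Every edge u -> v costs A[u] unless mods[u] contains (v, weight).
std::vector<long long> denseDijkstra(
        const std::vector<long long>& A,
        const std::vector<std::vector<std::pair<int, long long>>>& mods, int src) {
    const int c = static_cast<int>(A.size());
    assert(mods.size() == A.size() && 0 <= src && src < c);
    const long long INF = std::numeric_limits<long long>::max() / 4;
    std::vector<long long> dist(c, INF);
    std::vector<bool> done(c, false);
    dist[src] = 0;
    for (int iter = 0; iter < c; ++iter) {
        int u = -1;  // unprocessed node with the smallest tentative distance
        for (int v = 0; v < c; ++v)
            if (!done[v] && (u == -1 || dist[v] < dist[u])) u = v;
        done[u] = true;
        std::vector<long long> w(c, A[u]);                  // default cost of u -> v
        for (const auto& [t, cost] : mods[u]) w[t] = cost;  // modified costs
        for (int v = 0; v < c; ++v)                         // relax every pending node
            if (!done[v]) dist[v] = std::min(dist[v], dist[u] + w[v]);
    }
    return dist;
}
```

On the corner-case example, with the outside node 3 added as local index 2, denseDijkstra({1, 1, 1}, {{{1, 5}}, {}, {}}, 0) returns {0, 2, 1}.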
Optimizing dijkstra
Among the O(M^2) edges within a component, at most M are modified. All remaining outgoing edges of a node u have the same weight A_u.
Revisiting Dijkstra’s algorithm, the operations it needs (when implemented with a min-heap) are:
- Finding the node with minimum distance
- Removing a node from the data structure
- Updating the distance to a node v to the minimum of its current value and the distance to node u plus the weight of edge u → v (called the relaxation operation)
If V and E denote the number of nodes and edges in the graph, the first two operations are performed O(V) times during Dijkstra, while the third is performed O(E) times. With V = O(M) and E = O(M^2), we only need to speed up the third operation.
For a fixed node u, most of the outgoing edges have weight A_u. Can we process many of these relaxations at once?
Segment Tree
Yes. Let’s fix an order of the nodes in the component and label them 0 to C-1, where C is the size of the component. Suppose we are processing node u, having already computed dist(source, u), and u has modified outgoing edges to nodes c_1, c_2, \ldots, c_k with weights w_1, w_2, \ldots, w_k, sorted by c_i.
The relaxations we need to perform are of two kinds:
- on each node c_i, with value dist(source, u) + w_i. Over a whole Dijkstra run there is at most one such relaxation per modified edge of the component, so at most M in total, and we perform these naively;
- on all nodes lying in any of the ranges [0, c_1-1], [c_1+1, c_2-1], \ldots, [c_{k-1}+1, c_k-1], [c_k+1, C-1], with value dist(source, u) + A_u.
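Computing these ranges is a single pass over the sorted targets. A C++ sketch, mirroring the loop in the tester’s dijkstra function (the name complementRanges is illustrative):

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Splits [0, C-1] into the maximal ranges avoiding the modified targets of u.
// targets must be sorted and distinct (local indices c_1 < ... < c_k).
// Each returned [l, r] receives one range update with X = dist(source, u) + A_u;
// the k targets themselves are relaxed individually with their own weights.
std::vector<std::pair<int, int>> complementRanges(int C, const std::vector<int>& targets) {
    std::vector<std::pair<int, int>> ranges;
    int from = 0;  // first index not yet covered
    for (int t : targets) {
        assert(from <= t && t < C);
        if (from < t) ranges.push_back({from, t - 1});
        from = t + 1;
    }
    if (from < C) ranges.push_back({from, C - 1});
    return ranges;
}
```

For C = 6 and targets {1, 4}, it yields [0, 0], [2, 3] and [5, 5], so node u costs at most k + 1 range updates plus k point updates.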
All relaxations of the second type use the same value X = dist(source, u) + A_u. For each range, this is equivalent to replacing dist(source, w) with min(dist(source, w), X) for every w in the range, and there are at most k+1 such ranges for node u.
This operation can be supported by a segment tree whose i-th leaf stores the current tentative distance from the source to the i-th node.
The segment tree needs to support the following operations:
- Find the position with minimum value among active positions (equivalent to finding the top element of the min-heap)
- Remove a position (equivalent to popping the top element). This can be done by marking the position inactive.
- For all w \in [l, r], replace dist(source, w) with min(X, dist(source, w))
These operations can be supported by a lazy segment tree that also tracks which leaves are active (initially all are active, and each pop operation marks one of them inactive).
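Below is a compact C++ sketch of such a lazy segment tree together with the Dijkstra loop that drives it. It is written for this explanation rather than copied from the setter’s or tester’s code, though it follows the same idea: each node keeps the minimum over its active leaves, one guaranteed-active leaf (similar to the setter’s free_good array), and a pending chmin tag. It assumes mods[u] is sorted by target, which both reference solutions arrange by sorting the adjacency lists; the names ChminSegTree and segDijkstra are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>
#include <utility>
#include <vector>

const long long INF = std::numeric_limits<long long>::max() / 4;

// Lazy segment tree over C leaves: range chmin, argmin over active leaves,
// and deactivating a leaf. mn/arg only describe active leaves.
struct ChminSegTree {
    int n;
    std::vector<long long> mn, tag;  // subtree minimum; pending chmin value
    std::vector<int> arg, alive;     // argmin leaf; some active leaf (-1: none)

    explicit ChminSegTree(int n_)
        : n(n_), mn(4 * n_, INF), tag(4 * n_, INF), arg(4 * n_, -1), alive(4 * n_, -1) {
        assert(n_ > 0);
        build(1, 0, n - 1);
    }
    void build(int i, int l, int r) {
        if (l == r) { arg[i] = alive[i] = l; return; }
        int m = (l + r) / 2;
        build(2 * i, l, m);
        build(2 * i + 1, m + 1, r);
        pull(i);
    }
    // Apply "value = min(value, x)" to every active leaf under node i.
    void apply(int i, long long x) {
        if (alive[i] == -1) return;  // dead subtree: nothing to update
        tag[i] = std::min(tag[i], x);
        // If x beats the old minimum, every active leaf below now equals x.
        if (x < mn[i]) { mn[i] = x; arg[i] = alive[i]; }
    }
    void push(int i) {
        if (tag[i] == INF) return;
        apply(2 * i, tag[i]);
        apply(2 * i + 1, tag[i]);
        tag[i] = INF;
    }
    void pull(int i) {
        const int L = 2 * i, R = 2 * i + 1;
        alive[i] = alive[L] != -1 ? alive[L] : alive[R];
        const bool left = alive[R] == -1 || (alive[L] != -1 && mn[L] <= mn[R]);
        mn[i] = left ? mn[L] : mn[R];
        arg[i] = left ? arg[L] : arg[R];
    }
    void chmin(int i, int l, int r, int ql, int qr, long long x) {
        if (qr < l || r < ql) return;
        if (ql <= l && r <= qr) { apply(i, x); return; }
        push(i);
        int m = (l + r) / 2;
        chmin(2 * i, l, m, ql, qr, x);
        chmin(2 * i + 1, m + 1, r, ql, qr, x);
        pull(i);
    }
    void erase(int i, int l, int r, int p) {
        if (l == r) { mn[i] = INF; arg[i] = alive[i] = -1; return; }
        push(i);
        int m = (l + r) / 2;
        if (p <= m) erase(2 * i, l, m, p); else erase(2 * i + 1, m + 1, r, p);
        pull(i);
    }
    void chmin(int l, int r, long long x) { chmin(1, 0, n - 1, l, r, x); }
    void erase(int p) { erase(1, 0, n - 1, p); }
};

// Dijkstra on one component (local indices 0..C-1). Every edge u -> v costs
// A[u] unless mods[u] contains (v, weight); mods[u] must be sorted by v.
std::vector<long long> segDijkstra(
        const std::vector<long long>& A,
        const std::vector<std::vector<std::pair<int, long long>>>& mods, int src) {
    const int c = static_cast<int>(A.size());
    assert(mods.size() == A.size() && 0 <= src && src < c);
    ChminSegTree st(c);
    st.chmin(src, src, 0);
    std::vector<long long> dist(c, INF);
    for (int it = 0; it < c; ++it) {
        const int u = st.arg[1];        // active node with minimum distance
        const long long du = st.mn[1];
        dist[u] = du;
        st.erase(u);                    // later updates to u are ignored
        int from = 0;                   // first index not yet relaxed for u
        for (const auto& [v, w] : mods[u]) {
            if (from < v) st.chmin(from, v - 1, du + A[u]);  // gap before v
            st.chmin(v, v, du + w);                          // modified edge
            from = v + 1;
        }
        if (from < c) st.chmin(from, c - 1, du + A[u]);     // trailing gap
    }
    return dist;
}
```

On a four-node component with A = [3, 1, 4, 1], a modified edge 0 → 2 of cost 0 and a modified edge 2 → 3 of cost 10, segDijkstra from node 0 returns {0, 3, 0, 3} and from node 2 returns {4, 4, 0, 5}. Keeping one guaranteed-active leaf per node is what keeps the argmin cheap under a range chmin: if the new value is below the old minimum, every active leaf in that subtree now holds it.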
Refer to my code for implementation details.
TIME COMPLEXITY
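The time complexity is O(N + M^2*log(M)) per test case. A component with C nodes (C = O(M), counting the added outside node) needs at most C Dijkstra runs, and each run performs O(C + m_C) segment tree operations of O(log M) each, where m_C is the number of modified edges inside the component. Since the sizes C and the counts m_C each sum to O(M) over all components, the total is O(M^2 log M). Both reference solutions additionally sort A, which adds O(N log N).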
SOLUTIONS
Setter's Solution
#include<bits/stdc++.h>
using namespace std;
int const maxn = 1e6 + 5, maxm = 2005;
int a[maxn], used[maxn], cur, numb[maxn];
vector < pair < int, int > > g[maxn];
vector < int > G[maxn];
vector < int > now;
vector < pair < int, int > > go[maxn];
int b[maxn], dist[maxn], inf = 1e9 + 7;
pair < int, int > imin[(1 << 13)];
int psh[(1 << 14)];
int good[(1 << 13)];
int free_good[(1 << 13)];
void dfs(int v) {
used[v] = cur;
now.push_back(v);
for (auto u : G[v]) {
if (used[u] == 0) dfs(u);
}
}
void build(int i, int l, int r) {
good[i] = r - l;
free_good[i] = l;
psh[i] = inf;
imin[i] = {inf, l};
if (r - l == 1) return;
int m = (r + l) / 2;
build(2 * i + 1, l, m);
build(2 * i + 2, m, r);
}
inline void push(int i, int l, int r) {
if (psh[i] == inf) return;
if (good[i]) {
if (psh[i] < imin[i].first) {
imin[i] = {psh[i], free_good[i]};
}
}
psh[2 * i + 1] = min(psh[2 * i + 1], psh[i]);
psh[2 * i + 2] = min(psh[2 * i + 2], psh[i]);
psh[i] = inf;
}
void update(int i, int l, int r, int lq, int rq, int x) {
push(i, l, r);
if (lq >= r || l >= rq) return;
if (lq <= l && r <= rq) {
psh[i] = x;
push(i, l, r);
return;
}
int m = (r + l) / 2;
update(2 * i + 1, l, m, lq, rq, x);
update(2 * i + 2, m, r, lq, rq, x);
imin[i] = min(imin[2 * i + 1], imin[2 * i + 2]);
}
void del(int i, int l, int r, int lq) {
push(i, l, r);
if (!(l <= lq && lq < r)) return;
if (r - l == 1) {
good[i]--;
imin[i] = {inf, l};
return;
}
int m = (r + l) / 2;
del(2 * i + 1, l, m, lq);
del(2 * i + 2, m, r, lq);
good[i]--;
if (good[2 * i + 1]) free_good[i] = free_good[2 * i + 1];
else free_good[i] = free_good[2 * i + 2];
if (good[i]) imin[i] = min(imin[2 * i + 1], imin[2 * i + 2]);
else imin[i] = {inf, l};
}
void solve(int v, int n) {
build(0, 1, n + 1);
update(0, 1, n + 1, v, v + 1, 0);
for (int i = 1; i <= n; ++i) {
pair < int, int > best = imin[0];
dist[best.second] = best.first;
del(0, 1, n + 1, best.second);
int was = 0;
for (auto key : go[best.second]) {
update(0, 1, n + 1, key.first, key.first + 1, key.second + best.first);
if (was + 1 != key.first) {
update(0, 1, n + 1, was + 1, key.first, best.first + b[best.second]);
}
was = key.first;
}
if (was != n) update(0, 1, n + 1, was + 1, n + 1, best.first + b[best.second]);
}
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t, n, m, u, v, c;
cin >> t;
while (t--) {
cin >> n;
cur = 0;
long long ans = 0;
vector < pair < int, int > > all_elem;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
all_elem.push_back({a[i], i});
g[i] = {};
G[i] = {};
used[i] = 0;
}
sort(all_elem.begin(), all_elem.end());
cin >> m;
for (int i = 1; i <= m; ++i) {
cin >> u >> v >> c;
G[u].push_back(v);
G[v].push_back(u);
g[u].push_back({v, c});
}
for (int i = 1; i <= n; ++i) {
sort(g[i].begin(), g[i].end());
}
for (int i = 1; i <= n; ++i) {
if (used[i] == 0) {
cur++, now = {};
dfs(i);
if ((int)now.size() == 1) {
ans += (long long)a[i] * (n - 1);
continue;
}
int bad = -1;
for (auto key : all_elem) {
if (used[key.second] != cur) {
bad = key.second;
now.push_back(bad);
break;
}
}
sort(now.begin(), now.end());
for (int j = 0; j < (int)now.size(); ++j) numb[now[j]] = j + 1;
for (int j = 1; j <= (int)now.size(); ++j) {
go[j] = {};
b[j] = a[now[j - 1]];
for (auto elem : g[now[j - 1]]) {
go[j].push_back({numb[elem.first], elem.second});
}
}
for (int j = 1; j <= (int)now.size(); ++j) {
if (now[j - 1] == bad) continue;
solve(j, (int)now.size());
int out = a[now[j - 1]];
for (int pos = 1; pos <= (int)now.size(); ++pos) {
ans += (long long)dist[pos];
out = min(out, dist[pos] + a[now[pos - 1]]);
}
ans += (long long)(n - (int)now.size()) * (long long)out;
}
}
}
cout << ans << '\n';
}
return 0;
}
Tester's Solution
import java.util.*;
import java.io.*;
class ALLGRAPH{
//SOLUTION BEGIN
void pre() throws Exception{}
void solve(int TC) throws Exception{
int N = ni();
int[] A = new int[N];
int[][] order = new int[N][];
for(int i = 0; i< N; i++){
A[i] = ni();
order[i] = new int[]{i, (int)A[i]};
}
Arrays.sort(order, (int[] i1, int[] i2) -> Integer.compare(i1[1], i2[1]));
int M = ni();
int[] from = new int[M], to = new int[M], w = new int[M];
int[] set = java.util.stream.IntStream.range(0, N).toArray();
for(int i = 0; i< M; i++){
from[i] = ni()-1;
to[i] = ni()-1;
w[i] = ni();
set[find(set, to[i])] = find(set, from[i]);
}
int[][][] g = makeS(N, M, from, to, false);
List<Integer>[] list = new ArrayList[N];
for(int i = 0; i< N; i++)if(find(set, i) == i)list[i] = new ArrayList<>();
for(int i = 0; i< N; i++)list[find(set, i)].add(i);
int[] index = new int[N];
Arrays.fill(index, -1);
long ans = 0;
for(int i = 0; i< N; i++){
if(list[i] == null)continue;
if(list[i].size() == 1){
ans += A[i]*(long)(N-1);
continue;
}
int smallestOutside = -1;
for(int u = 0; u< N; u++)
if(find(set, order[u][0]) != find(set, i)){
smallestOutside = order[u][0];
break;
}
if(smallestOutside != -1)list[i].add(smallestOutside);
Collections.sort(list[i]);
for(int j = 0; j< list[i].size(); j++)index[list[i].get(j)] = j;
int sz = list[i].size();
// List<int[]>[] graph = new ArrayList[sz];
int[][][] graph = new int[sz][][];
int[] reindexedA = new int[sz];
for(int j = 0; j< list[i].size(); j++){
int u = list[i].get(j);
reindexedA[j] = A[u];
if(u != smallestOutside){
graph[j] = new int[g[u].length][];
int p = 0;
for(int[] e:g[u])
graph[j][p++] = new int[]{index[e[0]], w[e[1]]};
Arrays.sort(graph[j], (int[] i1, int[] i2) -> Integer.compare(i1[0], i2[0]));
}else graph[j] = new int[0][];
}
for(int j = 0; j< sz; j++){
if(list[i].get(j) == smallestOutside)continue;
int[] d = dijkstra(sz, graph, reindexedA, j);
long others = Long.MAX_VALUE;
for(int k = 0; k< sz; k++){
ans += d[k];
others = Math.min(others, d[k]+A[list[i].get(k)]);
}
ans += (N-sz)*others;
}
}
pn(ans);
}
int[] dijkstra(int sz, int[][][] graph, int[] A, int src){
LazySegmentTree st = new LazySegmentTree(sz);
st.update(src, src, 0);
int[] dist = new int[sz];
for(int iteration = 1; iteration <= sz; iteration++){
long[] min = st.argmin();
int u = (int)min[0];
dist[u] = (int)min[1];
int from = 0;
for(int[] e:graph[u]){
if(from < e[0])st.update(from, e[0]-1, dist[u]+A[u]);
st.update(e[0], e[0], dist[u]+e[1]);
from = e[0]+1;
}
if(from < sz)st.update(from, sz-1, dist[u]+A[u]);
st.delete(u);
}
return dist;
}
class LazySegmentTree{
int m = 1;
long IINF = (long)1e17;
long[] min, lazy;
int[] argmin;
boolean[] alive;
public LazySegmentTree(int n){
while(m<n)m<<=1;
min = new long[m<<1];
lazy = new long[m<<1];
alive = new boolean[m<<1];
argmin = new int[m<<1];
Arrays.fill(min, IINF);
Arrays.fill(lazy, IINF);
Arrays.fill(argmin, -1);
for(int i = 0; i< n; i++){
alive[m+i] = true;
argmin[m+i] = i;
}
for(int i = m-1; i> 0; i--){
alive[i] = alive[i<<1]|| alive[i<<1|1];
if(alive[i<<1])argmin[i] = argmin[i<<1];
if(alive[i<<1|1])argmin[i] = argmin[i<<1|1];
}
}
private void push(int i){
if(!alive[i])return;
min[i] = Math.min(min[i], lazy[i]);
if(i < m){
lazy[i<<1] = Math.min(lazy[i<<1], lazy[i]);
lazy[i<<1|1] = Math.min(lazy[i<<1|1], lazy[i]);
}
}
private void updateNode(int i){
argmin[i] = -1;
min[i] = IINF;
alive[i] = false;
if(alive[i<<1] && min[i<<1] < min[i]){
min[i] = min[i<<1];
argmin[i] = argmin[i<<1];
alive[i] = true;
}
if(alive[i<<1|1] && min[i<<1|1] < min[i]){
min[i] = min[i<<1|1];
argmin[i] = argmin[i<<1|1];
alive[i] = true;
}
}
public void update(int l, int r, long x){u(l, r, 0, m-1, 1, x);}
public void delete(int p){d(p, 0, m-1, 1);}
public long[] argmin(){
return new long[]{argmin[1], min[1]};
}
private void u(int l, int r, int ll, int rr, int i, long x){
if(!alive[i])return;
push(i);
if(l == ll && r == rr){
lazy[i] = Math.min(lazy[i], x);
push(i);
return;
}
int mid = (ll+rr)>>1;
push(i<<1);
push(i<<1|1);
if(r <= mid)u(l, r, ll, mid, i<<1, x);
else if(l > mid)u(l, r, mid+1, rr, i<<1|1, x);
else{
u(l, mid, ll, mid, i<<1, x);
u(mid+1, r, mid+1, rr, i<<1|1, x);
}
updateNode(i);
}
private void d(int p, int ll, int rr, int i){
push(i);
if(ll == rr){
alive[i] = false;
argmin[i] = -1;
min[i] = IINF;
return;
}
int mid = (ll+rr)>>1;
push(i<<1);
push(i<<1|1);
if(p <= mid)d(p, ll, mid, i<<1);
else d(p, mid+1, rr, i<<1|1);
updateNode(i);
}
}
int find(int[] set, int u){return set[u] = set[u] == u?u:find(set, set[u]);}
int[][][] makeS(int n, int e, int[] from, int[] to, boolean f){
int[][][] g = new int[n][][];int[]cnt = new int[n];
for(int i = 0; i< e; i++){
cnt[from[i]]++;
if(f)cnt[to[i]]++;
}
for(int i = 0; i< n; i++)g[i] = new int[cnt[i]][];
for(int i = 0; i< e; i++){
g[from[i]][--cnt[from[i]]] = new int[]{to[i], i, 1};
if(f)g[to[i]][--cnt[to[i]]] = new int[]{from[i], i, -1};
}
return g;
}
static void dbg(Object... o){System.err.println(Arrays.deepToString(o));}
//SOLUTION END
void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
static boolean multipleTC = true;
FastReader in;PrintWriter out;
void run() throws Exception{
in = new FastReader();
out = new PrintWriter(System.out);
//Solution Credits: Taranpreet Singh
int T = (multipleTC)?ni():1;
pre();for(int t = 1; t<= T; t++)solve(t);
out.flush();
out.close();
}
public static void main(String[] args) throws Exception{
new ALLGRAPH().run();
}
int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
void p(Object o){out.print(o);}
void pn(Object o){out.println(o);}
void pni(Object o){out.println(o);out.flush();}
String n()throws Exception{return in.next();}
String nln()throws Exception{return in.nextLine();}
int ni()throws Exception{return Integer.parseInt(in.next());}
long nl()throws Exception{return Long.parseLong(in.next());}
double nd()throws Exception{return Double.parseDouble(in.next());}
class FastReader{
BufferedReader br;
StringTokenizer st;
public FastReader(){
br = new BufferedReader(new InputStreamReader(System.in));
}
public FastReader(String s) throws Exception{
br = new BufferedReader(new FileReader(s));
}
String next() throws Exception{
while (st == null || !st.hasMoreElements()){
try{
st = new StringTokenizer(br.readLine());
}catch (IOException e){
throw new Exception(e.toString());
}
}
return st.nextToken();
}
String nextLine() throws Exception{
String str = "";
try{
str = br.readLine();
}catch (IOException e){
throw new Exception(e.toString());
}
return str;
}
}
}
Feel free to share your approach. Suggestions are welcome as always.