MEX_PATH - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: wuhudsm
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

Dijkstra’s algorithm

PROBLEM:

You’re given an array A of length N. Consider the complete undirected weighted graph on N vertices, where the weight of the edge (i, j) with i \lt j is \text{MEX}(\{A_i, A_{i+1}, \ldots, A_j\}).

Find the shortest path from 1 to N.

EXPLANATION:

There are \frac{N(N-1)}{2} edges, but it turns out that most of them are useless!

Specifically, consider an edge (i, j) with i \lt j that’s used in a solution, and let w = \text{MEX}(A[i\ldots j]) be its weight.
Then,

  • If i\gt 1 and A_{i-1} \neq w, we also have \text{MEX}(A[i-1\ldots j]) = w, since adding a value other than w to the range cannot make w appear.
    So, we could use the edge (i-1, j) instead: if the path previously used an edge (k, i) with k \lt i-1, use (k, i-1), whose range is contained in [k, i] and so has no higher cost (and if k = i-1, simply drop that edge).
  • Similarly, if j\lt N and A_{j+1} \neq w, we can instead use the edge (i, j+1).
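To make the exchange argument concrete, here is a small brute-force check, written 0-indexed (so `a[l-1]` plays the role of A_{i-1}). The helper names `mexOfRange` and `extensionClaimHolds` are hypothetical, and the sketch assumes all values are non-negative; it recomputes every range from scratch, so it is only meant for tiny arrays.

```cpp
#include <cassert>
#include <vector>

// Returns MEX(a[l..r]) (0-indexed, inclusive) by direct counting.
// Assumes a[x] >= 0; values larger than the range length cannot affect the MEX.
int mexOfRange(const std::vector<int>& a, int l, int r) {
    std::vector<bool> seen(r - l + 2, false);
    for (int x = l; x <= r; ++x)
        if (a[x] < (int)seen.size()) seen[a[x]] = true;
    int m = 0;
    while (m < (int)seen.size() && seen[m]) ++m;
    return m;
}

// Checks the claim for every range [l, r]: if the neighbouring element is not
// w = MEX(a[l..r]), extending the range by that element keeps the MEX equal to w.
bool extensionClaimHolds(const std::vector<int>& a) {
    int n = a.size();
    for (int l = 0; l < n; ++l)
        for (int r = l; r < n; ++r) {
            int w = mexOfRange(a, l, r);
            if (l > 0 && a[l - 1] != w && mexOfRange(a, l - 1, r) != w) return false;
            if (r + 1 < n && a[r + 1] != w && mexOfRange(a, l, r + 1) != w) return false;
        }
    return true;
}
```

For example, with `a = {1, 0, 2, 0, 3}`, `mexOfRange(a, 0, 2)` is 3 and `extensionClaimHolds(a)` is true.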

This means we only need to consider edges (i, j) such that:

  • i = 1 or j = N; or
  • Let w = \text{MEX}(A[i\ldots j]). Then, A_{i-1} = w or A_{j+1} = w.

The first case gives us 2N-3 edges in total.
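The first-case edges only need prefix and suffix mex values, and both can be computed in linear time because the mex of a growing range never decreases. A minimal sketch, 0-indexed and with hypothetical helper names, assuming 0 \leq A_i \leq N as the array sizes in the reference codes suggest:

```cpp
#include <cassert>
#include <vector>

// pre[j] = MEX(a[0..j]); these are the weights of the "i = 1" edges.
// Assumes 0 <= a[x] <= n.
std::vector<int> prefixMex(const std::vector<int>& a) {
    int n = a.size();
    std::vector<bool> seen(n + 2, false);
    std::vector<int> pre(n);
    int mex = 0;
    for (int j = 0; j < n; ++j) {
        seen[a[j]] = true;
        while (seen[mex]) ++mex;  // mex only increases, so O(n) work in total
        pre[j] = mex;
    }
    return pre;
}

// suf[i] = MEX(a[i..n-1]); these are the weights of the "j = N" edges.
// Computed as the prefix MEX of the reversed array, reversed back.
std::vector<int> suffixMex(const std::vector<int>& a) {
    std::vector<int> rev(a.rbegin(), a.rend());
    std::vector<int> suf = prefixMex(rev);
    return std::vector<int>(suf.rbegin(), suf.rend());
}
```

With `a = {1, 0, 2, 0, 3}` this gives prefix values `{0, 2, 3, 3, 4}` and suffix values `{4, 1, 1, 1, 0}`.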
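As for the second case: if A_{j+1} = w but i \gt 1 and A_{i-1} \neq w, the edge can still be extended to the left as above until i = 1 (the first case) or A_{i-1} = w, so it’s enough to handle edges with A_{i-1} = w.
Let’s fix the left endpoint i of such an edge, and let w = A_{i-1}.
Let k \geq i be the index of the next occurrence of w (or k = N+1 if w doesn’t occur again).
We then need to consider some edges (i, j) such that i \lt j \lt k, since for any j \geq k the range [i, j] contains w, so its mex can’t be w.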
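The next occurrence of each value can be found with a single right-to-left scan. In this hypothetical helper (0-indexed), `a[p]` plays the role of A_{i-1}, so the left endpoint is p+1 and the useful right endpoints lie strictly below `nxt[p]`; `nxt[p] = n` means the value never occurs again. It assumes 0 \leq A_i \leq N.

```cpp
#include <cassert>
#include <vector>

// nxt[p] = smallest q > p with a[q] == a[p], or n if there is none.
// Assumes 0 <= a[x] <= n.
std::vector<int> nextSameValue(const std::vector<int>& a) {
    int n = a.size();
    std::vector<int> nxt(n);
    std::vector<int> seenAt(n + 1, n);  // seenAt[v] = nearest index > p holding v
    for (int p = n - 1; p >= 0; --p) {
        nxt[p] = seenAt[a[p]];
        seenAt[a[p]] = p;
    }
    return nxt;
}
```

For `a = {1, 0, 2, 0, 3}` the result is `{5, 3, 5, 5, 5}`: only the 0 at index 1 reappears, at index 3.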

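In fact, it’s enough to consider just the single edge i \to (k-1) with weight w, along with “backward” edges x \to (x-1) with weight 0 for every 1 \lt x \leq N.
This allows us to move from i to any index before k with cost w.
Note that the backward edges don’t make anything artificially cheap: they only encode the fact that if index i can be reached with cost x, then every index \lt i can also be reached with cost at most x.
To see why, take a shortest path to i and a target 1 \lt i' \lt i: the first vertex on the path that is \geq i' is entered by some edge (u, v) with u \lt i' \leq v, and the edge (u, i') covers a subrange of [u, v], so it costs no more.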

You might notice that the mex of the range [i, k-1] need not actually be w. That’s fine: it is at most w (since w isn’t in the range), and if it’s strictly less than w, then A_{i-1} = w differs from the true weight, so by the argument above such an edge can be extended to the left and is never needed in an optimal solution. Considering it with a higher cost therefore can’t affect the answer.

This second case also gives us \leq 2N-3 more edges: N-1 backward edges with weight 0, and (at most) N-2 forward edges, one for each 1 \lt i \lt N.
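A sketch of the second-case edges, built 0-indexed from consecutive occurrences of each value. `Edge` and `secondCaseEdges` are hypothetical names, and it assumes 0 \leq A_i \leq N. Unlike the reference codes, this version skips forward edges that would not move right (self-loops, or backward moves already dominated by the 0-weight edges), which makes the 2N-3 bound easy to check.

```cpp
#include <cassert>
#include <vector>

// A directed edge u -> v with weight w (0-indexed vertices).
struct Edge { int u, v, w; };

// Forward edges (p+1) -> (q-1) with weight a[p] for consecutive occurrences
// p < q of the same value, plus backward edges x -> x-1 with weight 0.
std::vector<Edge> secondCaseEdges(const std::vector<int>& a) {
    int n = a.size();
    std::vector<Edge> edges;
    std::vector<int> last(n + 1, -1);  // last[v] = latest index holding v so far
    for (int q = 0; q < n; ++q) {
        int p = last[a[q]];
        if (p != -1 && p + 1 < q - 1) edges.push_back({p + 1, q - 1, a[q]});
        last[a[q]] = q;
    }
    for (int x = 1; x < n; ++x) edges.push_back({x, x - 1, 0});
    return edges;
}
```

For `a = {0, 1, 2, 0, 1}` this yields two forward edges, 1 \to 2 with weight 0 and 2 \to 3 with weight 1 (0-indexed), followed by four backward edges: six in total, within the 2N-3 = 7 bound.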

We now have \mathcal{O}(N) edges to deal with on a graph with N vertices, so directly finding the shortest path using Dijkstra’s algorithm works quickly.
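Putting the pieces together, here is a compact sketch of the whole reduced-graph approach as one function, 0-indexed. The name `shortestMexPath` is hypothetical; it assumes 0 \leq A_i \leq N and uses `std::priority_queue` with `std::greater` as a min-heap, where the reference codes use a negated max-heap or a `std::set`.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Shortest 0 -> n-1 path on the reduced O(n)-edge graph. Assumes 0 <= a[x] <= n.
long long shortestMexPath(const std::vector<int>& a) {
    int n = a.size();
    std::vector<std::vector<std::pair<int, int>>> adj(n);  // adj[u] = {(v, w)}

    // Prefix edges 0 -> j and suffix edges i -> n-1, weights via incremental MEX.
    std::vector<bool> seen(n + 2, false);
    for (int j = 0, mex = 0; j < n; ++j) {
        seen[a[j]] = true;
        while (seen[mex]) ++mex;
        adj[0].push_back({j, mex});
    }
    seen.assign(n + 2, false);
    for (int i = n - 1, mex = 0; i >= 0; --i) {
        seen[a[i]] = true;
        while (seen[mex]) ++mex;
        adj[i].push_back({n - 1, mex});
    }
    // Forward edges between consecutive occurrences, then 0-weight backward edges.
    std::vector<int> last(n + 1, -1);
    for (int q = 0; q < n; ++q) {
        int p = last[a[q]];
        if (p != -1 && p + 1 < q - 1) adj[p + 1].push_back({q - 1, a[q]});
        last[a[q]] = q;
    }
    for (int x = 1; x < n; ++x) adj[x].push_back({x - 1, 0});

    // Dijkstra with a min-heap of (distance, vertex).
    const long long INF = (long long)4e18;
    std::vector<long long> dist(n, INF);
    using State = std::pair<long long, int>;
    std::priority_queue<State, std::vector<State>, std::greater<State>> pq;
    dist[0] = 0;
    pq.push({0, 0});
    while (!pq.empty()) {
        auto [d, u] = pq.top();
        pq.pop();
        if (d != dist[u]) continue;  // stale heap entry
        for (auto [v, w] : adj[u])
            if (d + w < dist[v]) {
                dist[v] = d + w;
                pq.push({dist[v], v});
            }
    }
    return dist[n - 1];
}
```

On A = [1, 0, 2, 0, 3] this returns 3, matching the path 1 \to 2 \to 5 with costs 2 + 1 (1-indexed).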

TIME COMPLEXITY:

\mathcal{O}(N\log N) per testcase.

CODE:

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n; cin >> n;
    
    vector <int> a(n);
    for (auto &x : a) cin >> x;
    
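    // b[u] holds {v, w}: a directed edge u -> v with weight w
    vector<vector<pair<int, int>>> b(n);
    // Forward edges (last+1) -> (i-1) between consecutive occurrences of a[i]
    vector<int> last(n + 1, -1);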
    for (int i = 0; i < n; i++){
        if (last[a[i]] != -1){
            b[last[a[i]] + 1].push_back({i - 1, a[i]});
        }
        last[a[i]] = i;
    }
    
    // Prefix edges 0 -> i with weight MEX(a[0..i])
    set <int> st;
    int mex = 0;
    for (int i = 0; i < n; i++){
        st.insert(a[i]);
        while (st.count(mex)) mex++;
        
        b[0].push_back({i, mex});
    }
    
    // Backward edges i -> i-1 with weight 0
    for (int i = 1; i < n; i++){
        b[i].push_back({i - 1, 0});
    }
    
    // Suffix edges i -> n-1 with weight MEX(a[i..n-1])
    st.clear();
    mex = 0;
    for (int i = n - 1; i >= 0; i--){
        st.insert(a[i]);
        while (st.count(mex)) mex++;
        b[i].push_back({n - 1, mex});
    }
    
    // Dijkstra; the max-heap stores negated distances
    vector <int> dp(n, INF);
    dp[0] = 0;
    priority_queue <pair<int, int>> pq;
    
    pq.push({0, 0});
    
    while (!pq.empty()){
        auto pi = pq.top(); pq.pop();
        
        int u = pi.second;
        int ds = -pi.first;
        
        if (dp[u] != ds) continue;
        
        for (auto [v, w] : b[u]){
            if (dp[v] > dp[u] + w){
                dp[v] = dp[u] + w;
                pq.push({-dp[v], v});
            }
        }
    }
    
    cout << dp[n - 1] << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 RNG(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector<int> a(n);
        for (int &x : a) cin >> x;

        // adj[u] holds {v, w}: a directed edge u -> v with weight w
        vector adj(n, vector<array<int, 2>>());
        vector<int> mark(n+2);
        // Prefix edges 0 -> i with weight MEX(a[0..i])
        int mex = 0;
        for (int i = 0; i < n; ++i) {
            mark[a[i]] = 1;
            while (mark[mex]) ++mex;
            adj[0].push_back({i, mex});
        }
        // Suffix edges i -> n-1 with weight MEX(a[i..n-1])
        mex = 0;
        mark.assign(n+2, 0);
        for (int i = n-1; i >= 0; --i) {
            mark[a[i]] = 1;
            while (mark[mex]) ++mex;
            adj[i].push_back({n-1, mex});
        }
        
        // Forward edges (prv+1) -> (i-1) between consecutive occurrences of a[i]
        vector<int> prv(n+1, -1);
        prv[a[0]] = 0;
        for (int i = 1; i < n; ++i) {
            if (prv[a[i]] != -1) {
                int L = prv[a[i]] + 1;
                adj[L].push_back({i-1, a[i]});
            }
            prv[a[i]] = i;
        }
        // Backward edges i -> i-1 with weight 0
        for (int i = 1; i < n; ++i)
            adj[i].push_back({i-1, 0});

        // Dijkstra; n+1 acts as infinity since every distance is at most MEX(a) <= n
        vector<int> dist(n, n+1);
        dist[0] = 0;
        set<array<int, 2>> st = {{0, 0}};
        while (!st.empty()) {
            auto [d, u] = *st.begin();
            st.erase({d, u});
            for (auto [v, w] : adj[u]) {
                if (dist[v] > dist[u] + w) {
                    st.erase({dist[v], v});
                    dist[v] = dist[u] + w;
                    st.insert({dist[v], v});
                }
            }
        }
        cout << dist[n-1] << '\n';

    }
}
