DISTMULT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: raysh07 and iceknight1093
Tester: sushil2006
Editorialist: iceknight1093

DIFFICULTY:

Medium

PREREQUISITES:

Divide and conquer, BFS

PROBLEM:

You have a complete undirected graph on N vertices.
Vertex i has the label A_i (with 1 \leq A_i \leq 20), and the edge between vertices i and j has the following weight:

  • |i-j| if A_i = A_j
  • 2\cdot |i-j| otherwise

Answer Q queries of this form: given i and j, find the minimum weight of a path from i to j.
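A brute-force reference is handy for testing the faster ideas below. The sketch uses a hypothetical helper `brute_all_pairs` with 0-based positions. It runs Floyd-Warshall directly on the complete graph, so it is only usable for small N.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical reference solver: Floyd-Warshall on the full complete graph.
// Positions are 0-based. Runs in O(N^3), so it is only meant for small tests.
vector<vector<long long>> brute_all_pairs(const vector<int>& a) {
    int n = a.size();
    vector<vector<long long>> d(n, vector<long long>(n));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            long long gap = abs(i - j);
            // Weight |i-j| for equal labels, 2|i-j| otherwise.
            d[i][j] = (a[i] == a[j]) ? gap : 2 * gap;
        }
    }
    // Relax every pair through every intermediate vertex k.
    for (int k = 0; k < n; k++)
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                d[i][j] = min(d[i][j], d[i][k] + d[k][j]);
    return d;
}
```

For the labels 2 1 3 4 5 2, the entry for indices 1 and 5 is 7. The path steps left to index 0 and then takes the same-label edge to index 5, which beats the cost of 8 for walking right one step at a time.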

EXPLANATION:

To begin, observe that while we have \mathcal{O}(N^2) edges, most of them are fairly useless for our purposes.

In particular, suppose the move x \to y appears in some optimal path.
Then,

  • Suppose A_x = A_y.
    If there exists a z between x and y such that A_x = A_z, we could replace x \to y by x \to z \to y to still end up at y with the same cost.
  • Next, suppose A_x \neq A_y, so the cost is 2\cdot |x-y|.
    Then, assuming x \lt y, we can replace the x \to y move with the sequence
    x \to x+1 \to x+2 \to\ldots\to y. Each step of this sequence costs at most 2, so the total is at most 2(y-x), which is no more than what we started with.
    In fact, it might even have a lower cost, if some adjacent elements along the path happen to be equal.

This tells us that very few edges of the graph really matter: specifically,

  • Edges of the form (i, i+1), which can all be assumed to have a cost of 2.
  • For each i, the edges i \to L_i and i \to R_i with a cost of |i-L_i| and |R_i - i|, respectively.
    Here, L_i is the closest index to the left of i that has the label A_i, and R_i is the same but to the right.

This brings us down to only \mathcal{O}(N) edges.
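Here is a minimal sketch of building this reduced graph, assuming 0-based positions and labels in [1, 20]. The function name `build_reduced_graph` is hypothetical. Because the graph is undirected, linking each index to the nearest earlier index with the same label creates both the i \to L_i and the i \to R_i edges.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Sketch of the O(N) reduced graph (0-based positions; labels assumed in [1, 20]).
// adj[u] holds (neighbour, weight) pairs.
vector<vector<pair<int, int>>> build_reduced_graph(const vector<int>& a) {
    int n = a.size();
    vector<vector<pair<int, int>>> adj(n);
    vector<int> last(21, -1);  // last[v] = most recent index with label v
    for (int i = 0; i < n; i++) {
        // Adjacent edge (i-1, i), assumed to cost 2. If the two labels are
        // equal, the same-label edge below also links them with cost 1.
        if (i > 0) {
            adj[i].push_back({i - 1, 2});
            adj[i - 1].push_back({i, 2});
        }
        // Same-label edge to the nearest earlier occurrence j: this is the
        // edge i -> L_i, and seen from j it is the edge j -> R_j.
        if (last[a[i]] != -1) {
            int j = last[a[i]];
            adj[i].push_back({j, i - j});
            adj[j].push_back({i, i - j});
        }
        last[a[i]] = i;
    }
    return adj;
}
```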
With this, we’re already able to answer a single query in \mathcal{O}(N\log N) time by just directly applying Dijkstra’s algorithm on the reduced graph.
This can be further optimized to linear time. Walking one index at a time costs at most 2 per step, so every distance is trivially bounded by 2N. We can therefore group vertices by their distance (keeping a list for each distance value) and throw out the heap entirely.
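Below is a minimal sketch of this heap-free Dijkstra, a bucket queue (often called Dial's algorithm), on the reduced graph. It assumes 0-based positions and labels in [1, 20], and the function name `bucket_dijkstra` is hypothetical. It does not use the "assume cost 2" simplification: adjacent steps get their true weight (1 when the labels match), as in the tester's code.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Single-source distances on the reduced graph using a bucket queue instead of
// a heap. 0-based positions; labels assumed to lie in [1, 20].
vector<long long> bucket_dijkstra(const vector<int>& a, int s) {
    int n = a.size();
    // prv[i] / nxt[i]: nearest index to the left / right with label a[i], or -1.
    vector<int> prv(n, -1), nxt(n, -1), seen(21, -1);
    for (int i = 0; i < n; i++) { prv[i] = seen[a[i]]; seen[a[i]] = i; }
    fill(seen.begin(), seen.end(), -1);
    for (int i = n - 1; i >= 0; i--) { nxt[i] = seen[a[i]]; seen[a[i]] = i; }

    vector<long long> dist(n, LLONG_MAX);
    // A settled distance is at most 2(n-1) and one edge weighs at most
    // max(2, n-1), so every tentative distance fits in 3n buckets.
    vector<vector<int>> bucket(3 * n);
    dist[s] = 0;
    bucket[0].push_back(s);
    for (int d = 0; d < 3 * n; d++) {
        for (size_t k = 0; k < bucket[d].size(); k++) {
            int u = bucket[d][k];
            if (dist[u] != d) continue;  // stale entry, u was improved later
            for (int v : {prv[u], nxt[u], u - 1, u + 1}) {
                if (v < 0 || v >= n) continue;
                long long gap = abs(u - v);
                long long w = (a[u] == a[v]) ? gap : 2 * gap;
                if (d + w < dist[v]) {
                    dist[v] = d + w;
                    bucket[dist[v]].push_back(v);  // always a later bucket
                }
            }
        }
    }
    return dist;
}
```

Sizing the buckets at 3N instead of 2N is a deliberately generous choice. It avoids a bounds check on tentative distances while keeping the whole search linear.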

Running this once per query costs \mathcal{O}(NQ) in total, which is of course nowhere near fast enough.


Let’s try to bound the cost a bit.
Suppose we want to move from x to y, and without loss of generality x \lt y.
Let’s look at what happens if we only move rightward.

First, there will certainly be a base cost of (y-x), no matter what, since the gap between each i and i+1 must be crossed in some way, at a cost of at least 1.
Beyond this, every time we switch to a different label we incur an additional cost of 1 (recall from the start that we only ever switch labels when moving to an adjacent index).
So, we now need to figure out some optimal sequence of change of values.

Let F_i denote the first occurrence of label i at an index \geq x.
Without loss of generality, we can assume F_1 \lt F_2 \lt F_3\lt \ldots (if not, just relabel them; the answer doesn’t change).
Now, observe that, starting at x,

  • All indices in [F_1, F_2) don’t require any change of values (they’re all equal anyway).
  • All indices in [F_2, F_3) require at most one change - they’re all either 1 or 2, so we either don’t need to change at all, or change just once from F_2-1 to F_2.
  • All indices in [F_3, F_4) require at most two changes, by similar reasoning.
  • More generally, all indices in [F_k, F_{k+1}) require at most k-1 changes.

Given that all the labels are in [1, 20], this means we’re definitely able to reach y from x using no more than 19 changes.

So, by just moving rightwards, we know that x \to y has an answer of \leq (y-x) + 19.
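The rightward-only argument can be checked directly. The sketch below uses hypothetical helpers `rightward_cost` and `rightward_bound`, with 0-based positions and labels in [1, 20]. `rightward_cost` finds the cheapest rightward-only path with a simple DP over the reduced edges. `rightward_bound` computes (y-x) plus the number of distinct labels in [x, y] minus one, which is the same argument with 20 replaced by the number of labels that actually occur.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Cheapest path from x to y (0-based, x <= y) that only ever moves rightward,
// using the reduced edges i -> i+1 and i -> R_i. Its cost equals (y - x) plus
// the number of label changes along the way.
long long rightward_cost(const vector<int>& a, int x, int y) {
    int n = a.size();
    vector<int> prv(n, -1), seen(21, -1);
    for (int i = 0; i < n; i++) { prv[i] = seen[a[i]]; seen[a[i]] = i; }

    vector<long long> dp(n, LLONG_MAX);
    dp[x] = 0;
    for (int j = x + 1; j <= y; j++) {
        // Step from j-1: cost 1 if the labels match, 2 otherwise.
        dp[j] = dp[j - 1] + (a[j - 1] == a[j] ? 1 : 2);
        // Jump from the previous occurrence of a[j] (so j = R_{prv[j]}).
        if (prv[j] >= x) dp[j] = min(dp[j], dp[prv[j]] + (j - prv[j]));
    }
    return dp[y];
}

// Upper bound from the argument above: (y - x) + (#distinct labels in [x, y]) - 1.
long long rightward_bound(const vector<int>& a, int x, int y) {
    set<int> labels(a.begin() + x, a.begin() + y + 1);
    return (y - x) + (long long)labels.size() - 1;
}
```

For the labels 2 1 3 4 5 2 with x = 1 and y = 5, both functions return 8, while the true distance is 7. The extra saving comes from a leftward move, which is exactly the case considered next.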

Now, what if we don’t move only rightwards?
Well, each time we move i \to i-1, we’ll need to move i-1 \to i in the future anyway since the goal is to reach y.
So, each leftward move has an inherent cost of (at least) 2.
This means we really can’t have too many leftward moves. In particular, if we make even 10 leftward moves, the cost is at least (y-x) + 20. That is worse than the rightward-only bound of (y-x) + 19, so such a path can’t be optimal.

Consequently, an optimal path can’t get too far away from either x or y: it is always contained in [x-10, y+10]. Going beyond y+10 would also require at least 10 leftward moves to come back.
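The counting behind this bound can be written out as follows. Here P is an optimal path and k is the total number of unit steps P moves leftward, so a leftward jump of length g counts as g units. Each unit crossing costs at least 1, and every unit crossed leftward must later be crossed rightward again.

```latex
\begin{aligned}
\operatorname{cost}(P) &\ge (y-x) + 2k
  &&\text{(each leftward unit is crossed again rightward)}\\
\operatorname{cost}(P) &\le (y-x) + 19
  &&\text{(the rightward-only path is available)}\\
\Longrightarrow\; 2k &\le 19
  \;\Longrightarrow\; k \le 9 .
\end{aligned}
```

Reaching x-10, or reaching y+10 and coming back, would need k \geq 10, so P stays inside [x-10, y+10].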


We can now use the above ideas to obtain a complete solution.

We solve the queries offline, using divide-and-conquer.
Let \text{solve}(L, R) be the function that solves all queries whose endpoints lie in [L, R].

In standard fashion, if M = \text{midpoint}(L, R), we can call \text{solve}(L, M) and \text{solve}(M+1, R) to recursively solve for queries whose endpoints lie entirely in the left/right side.
This means we only need to worry about queries such that one endpoint is in the left half and the other is in the right half.
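The hypothetical helper `crossing_node` below makes the recursion concrete. It uses 1-based indices, matching the author's code, and walks down the tree of solve(1, N) using the same midpoint rule M = \lfloor (L+R)/2 \rfloor. It reports the node that handles a query (x, y) with x \leq y, namely the first node with x \leq M \lt y.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical helper: follow the recursion of solve(1, n), splitting at
// m = (l + r) / 2, and return {l, r, m} for the node where query (x, y),
// 1 <= x <= y <= n, is handled, i.e. the first node with x <= m < y.
// A query with x == y falls through to the leaf {x, x, x}.
array<int, 3> crossing_node(int n, int x, int y) {
    int l = 1, r = n;
    while (l < r) {
        int m = (l + r) / 2;
        if (y <= m) r = m;          // both endpoints in the left half
        else if (x > m) l = m + 1;  // both endpoints in the right half
        else return {l, r, m};      // x <= m < y: crosses the midpoint
    }
    return {l, r, l};
}
```

This costs \mathcal{O}(\log N) per query. Passing query lists down the recursion, as both reference codes do, achieves the same routing.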

To solve for these, observe that at some point the path must cross over from the left half to the right half.
Recalling the edges of the reduced graph, the first edge that does so is either (M, M+1) or a same-label edge i \to R_i with i \leq M \lt R_i. In both cases it lands on the leftmost occurrence of some label in the right half.

Let the indices of the leftmost occurrences of labels in the right half be i_1, i_2, i_3, \ldots.
If the optimal path x\to y passes through i_K, then, since the graph is undirected, its cost is simply d(i_K, x) + d(i_K, y), where d denotes shortest-path distance. The two halves of the path are independent of each other.
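Here is a minimal sketch of the pivot step, using 0-based positions and assuming labels in [1, 20]. Both function names are hypothetical. `right_half_pivots` collects the leftmost occurrence of each label in the right half, giving at most 20 pivots. `best_through_pivots` takes the per-pivot distance arrays and returns the minimum of d(i_K, x) + d(i_K, y). The author's code uses the mirror-image choice (the rightmost occurrence of each label in the left half), which works by the same argument.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Hypothetical helper: leftmost occurrence of each label in the right half
// [m+1, R] (0-based, inclusive). Labels are assumed to lie in [1, 20], so at
// most 20 pivots are returned.
vector<int> right_half_pivots(const vector<int>& a, int m, int R) {
    vector<bool> found(21, false);
    vector<int> pivots;
    for (int i = m + 1; i <= R; i++) {
        if (!found[a[i]]) {
            found[a[i]] = true;
            pivots.push_back(i);
        }
    }
    return pivots;
}

// Hypothetical helper: dist[k][v] is the distance from the k-th pivot to v.
// A crossing query (x, y) is answered by the cheapest route through any pivot.
long long best_through_pivots(const vector<vector<long long>>& dist, int x, int y) {
    long long best = LLONG_MAX;
    for (const auto& d : dist) best = min(best, d[x] + d[y]);
    return best;
}
```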

Now, if we fix the “pivot” i_K, we can in linear time compute the distances from it to all of [L, R], with the optimization of Dijkstra mentioned at the very beginning.
Once these distances to i_K are known, for each query that crosses the midpoint we can update the answer.

To ensure that the distances from i_K are computed correctly, we run this search on the extended interval [L-10, R+10] rather than on [L, R]. As noted previously, an optimal path between two vertices of [L, R] never needs to leave [L-10, R+10].

There are at most 20 distinct pivots because A_i \leq 20, so we end up doing \mathcal{O}(20\cdot (R-L+Q_0)) work here, with Q_0 being the number of queries that cross the midpoint.


Across the entirety of the divide-and-conquer, we thus end up doing \mathcal{O}(20\cdot Q + 20\cdot N\log N) work.
There’s also an additional term of N\cdot 20^2, caused by expanding each interval by 10 on either side. Each pivot’s search visits up to 20 extra vertices, and with up to 20 pivots this adds 20^2 work for each node of the divide-and-conquer tree (of which there are fewer than 2N).

Depending on how you figure out which segment each query gets solved in, there might be a further additive Q\log Q factor, though doing it in \mathcal{O}(Q) is possible too.

In the worst case we end up with \mathcal{O}(20\cdot Q + 20\cdot N\log N + N\cdot 20^2 + Q\log Q), which is still fast enough for the constraints we have.

TIME COMPLEXITY:

\mathcal{O}((N\log N + Q)\cdot M + N\cdot M^2 + Q\log Q) per testcase, where M = \max(A).

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

void Solve() 
{
    int n, q; cin >> n >> q;
    
    vector <int> a(n + 1);
    for (int i = 1; i <= n; i++){
        cin >> a[i];
    }
    
    vector <pair<int, int>> b(q);
    vector <int> ans(q, INF);
    
    for (int i = 0; i < q; i++){
        auto &[l, r] = b[i];
        cin >> l >> r;
        if (l > r){
            swap(l, r);
        }
        if (l == r){
            ans[i] = 0;
        }
    }
    
    vector<vector<pair<int, int>>> adj(n + 1);
    vector <int> last(21);
    for (int i = 1; i <= n; i++){
        if (last[a[i]]){
            int cost = (i - last[a[i]]);
            adj[i].push_back({last[a[i]], cost});
            adj[last[a[i]]].push_back({i, cost});
        }
        
        last[a[i]] = i;
        
        if (i >= 2){
            adj[i].push_back({i - 1, 2});
            adj[i - 1].push_back({i, 2});
        }
    }
    
    vector<vector<int>> c(2 * n + 1);
    vector <int> d(n + 1);
    
    auto solve = [&](auto self, int l, int r, vector <int> v){
        int m = (l + r) / 2;
        // solve for [l, r] for passing through m 
        
        if (l == r){
            return;
        }
        
        int L = max(1LL, l - 10);
        int R = min(n, r + 10);
        
        // solve for the range [L, R] 
        
        vector <int> found(21, 0);
        vector <int> imp;
        
        for (int i = m; i >= L; i--){
            if (!found[a[i]]){
                found[a[i]] = 1;
                imp.push_back(i);
            }
        }
        
        // optimal passes through one of imp 
        // run a linear dijkstra 
        
        int mx = 2 * (R - L);
        
        vector <int> to_solve, lf, rg;
        for (auto i : v){
            auto [l, r] = b[i];
            if (l <= m && m < r){
                to_solve.push_back(i);
            } else if (r <= m){
                lf.push_back(i);
            } else {
                rg.push_back(i);
            }
        }
        
        for (int s : imp){
            for (int i = 0; i <= mx; i++){
                c[i].clear();
            }
            
            for (int i = L; i <= R; i++){
                d[i] = INF;
            }
            
            d[s] = 0;
            c[0].push_back(s);
            
            for (int i = 0; i <= mx; i++){
                for (auto u : c[i]) if (d[u] == i){
                    for (auto [v, w] : adj[u]) if (L <= v && v <= R){
                        if (d[v] > d[u] + w){
                            d[v] = d[u] + w;
                            c[d[v]].push_back(v);
                        }
                    }
                }
            }
            
            for (auto i : to_solve){
                auto [l, r] = b[i];
                ans[i] = min(ans[i], d[l] + d[r]);
            }
        }
        
        self(self, l, m, lf);
        self(self, m + 1, r, rg);
    };  
    
    vector <int> v(q);
    iota(v.begin(), v.end(), 0);
    solve(solve, 1, n, v);
    
    for (auto x : ans){
        cout << x << " ";
    }
    cout << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T>
void amin(T &a, T b) {
    a = min(a,b);
}

template<typename T>
void amax(T &a, T b) {
    a = max(a,b);
}

#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) 42
#endif

/*



*/

const int MOD = 1e9 + 7;
const int N = 1e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

void solve(int test_case){
    ll n,q; cin >> n >> q;
    vector<ll> a(n+5);
    rep1(i,n) cin >> a[i];
    vector<pll> queries(q+5);
    rep1(i,q){
        ll l,r; cin >> l >> r;
        if(l > r) swap(l,r);
        queries[i] = {l,r};
    }

    vector<ll> lx(n+5), rx(n+5);
    vector<ll> prev_occ(25,-1);
    rep1(i,n){
        lx[i] = prev_occ[a[i]];
        prev_occ[a[i]] = i;
    }
    fill(all(prev_occ),-1);
    rev(i,n,1){
        rx[i] = prev_occ[a[i]];
        prev_occ[a[i]] = i;
    }
    
    vector<ll> ans(q+5,inf2);

    array<bool,25> came;
    vector<bool> vis(n+5);
    vector<ll> dis(n+5);
    vector<vector<ll>> here(2*n+5);

    auto go = [&](ll l, ll r, vector<array<ll,3>> &curr_queries, auto &&go) -> void{
        if(l == r){
            for(auto [u,v,id] : curr_queries){
                ans[id] = 0;
            }
            return;
        }

        ll mid = (l+r)>>1;
        vector<ll> guys;

        came.fill(0);
        for(int i = mid+1; i <= r; ++i){
            if(!came[a[i]]){
                came[a[i]] = 1;
                guys.pb(i);
            }
        }

        trav(s,guys){
            for(int i = l; i <= r; ++i){
                vis[i] = 0;
                dis[i] = inf2;
            }
            for(int i = 0; i <= 2*(r-l); ++i){
                here[i].clear();
            }

            here[0].pb(s);
            
            rep(d,2*(r-l)+1){
                trav(u,here[d]){
                    if(vis[u]) conts;
                    vis[u] = 1;
                    dis[u] = d;
                    
                    for(auto v : {lx[u],rx[u],u-1,u+1}){
                        if(v >= l and v <= r){
                            ll curr_dis = abs(u-v);
                            ll cost2 = d+curr_dis;
                            if(a[u] != a[v]) cost2 += curr_dis;

                            if(cost2 < dis[v]){
                                dis[v] = cost2;
                                here[cost2].pb(v);
                            }
                        }
                    }
                }
            }

            for(auto [u,v,id] : curr_queries){
                amin(ans[id],dis[u]+dis[v]);
            }
        }
        
        vector<array<ll,3>> left_queries, right_queries;
        for(auto [u,v,id] : curr_queries){
            if(v <= mid){
                left_queries.pb({u,v,id});
            }
            else if(u > mid){
                right_queries.pb({u,v,id});
            }
        }

        go(l,mid,left_queries,go);
        go(mid+1,r,right_queries,go);
    };

    vector<array<ll,3>> curr_queries;
    rep1(i,q) curr_queries.pb({queries[i].ff,queries[i].ss,i});

    go(1,n,curr_queries,go);
    
    rep1(i,q) cout << ans[i] << " ";
    cout << endl;
}

int main()
{
    fastio;

    int t = 1;
    cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}