ADVITIYA27 - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: adj_alt
Tester: apoorv_me
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

DFS, dynamic programming

PROBLEM:

There’s a tree with N vertices. M deliveries need to be made on it: the i-th starts at S_i and ends at D_i.
There are two delivery boys; each delivery can be taken up by either one.
On a given day, the time taken to traverse an edge of the tree is 2 the first time, and 1 for all subsequent traversals.
Find the minimum time needed for all M deliveries.

EXPLANATION:

To solve this task, it appears that we need to track where both drivers are at all times, since that is what lets us process transitions quickly.
To that end, a rather natural solution to start out with is as follows:

  • Let T(i, x, y) denote the minimum time necessary such that the first i deliveries have been made, the first driver is at node x, and the second is at node y.
  • Then, one of the drivers must make the (i+1)-th delivery, so T(i, x, y) updates either T(i+1, D_{i+1}, y) or T(i+1, x, D_{i+1}) depending on who moves.
  • For the first transition, you’ll need to compute the path length from x to S_{i+1}, then from S_{i+1} to D_{i+1}, and finally the overlap between these two paths (whose edges cost 1 on the second traversal, rather than 2).
    The second transition is similar.
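
To make the cost rule concrete, here is a minimal sketch that simulates one driver's move x -> S -> D edge by edge. It assumes, as the transitions above do, that the 2-then-1 rule applies within a single such trip. The function name `moveCost` and the 0-indexed adjacency list are illustrative choices, not taken from any of the reference solutions.

```cpp
#include <cassert>
#include <queue>
#include <vector>
using namespace std;

// Cost for a driver standing at x to walk x -> s -> d on a tree (0-indexed
// adjacency list), paying 2 for the first traversal of an edge and 1 for a repeat.
// Assumption: the rule is applied per trip, as in the editorial's cost formula.
int moveCost(const vector<vector<int>>& adj, int x, int s, int d) {
    int n = (int)adj.size();
    assert(0 <= x && x < n && 0 <= s && s < n && 0 <= d && d < n);
    vector<int> parent(n, -1);
    vector<char> seen(n, 0);
    queue<int> q;
    q.push(s);
    seen[s] = 1;
    while (!q.empty()) {  // BFS from s, so parent[] points towards s
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (!seen[v]) {
                seen[v] = 1;
                parent[v] = u;
                q.push(v);
            }
        }
    }
    vector<char> used(n, 0);  // used[v]: edge (v, parent[v]) was walked on the way to s
    int cost = 0;
    for (int u = x; u != s; u = parent[u]) {  // x -> s: every edge is new
        used[u] = 1;
        cost += 2;
    }
    // s -> d, enumerated from d back to s (same edge set, so same cost).
    for (int u = d; u != s; u = parent[u]) cost += used[u] ? 1 : 2;
    return cost;
}
```

This runs in O(N) per call, so it is only useful for checking small cases by hand, not as part of the final solution.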

Difficulty of computing path lengths and intersections aside, this solution is clearly too slow: it has M\cdot N^2 states, which is already about 1.25\cdot 10^{11} for N=M=5000.

However, observe that we don’t actually need most of these states.
Indeed, once i deliveries have been made, one of the drivers must be at D_i.
So, we only care about T(i, x, y) for those states with x = D_i or y = D_i.
This immediately brings us down to \mathcal{O}(N\cdot M) states, since we essentially eliminate one dimension!
As noted earlier, there are only two transitions from each state, so this is already fast enough, provided we can compute whatever the transitions need quickly enough.

Let’s look at a specific transition.
Suppose T(i, x, y) is being used to update T(i+1, x, D_{i+1}).
We need to know three things:

  • The distance from y to S_{i+1}.
  • The distance from S_{i+1} to D_{i+1}.
  • The common length between these two paths.

The overall cost then becomes 2\cdot d(y, S_{i+1}) + 2\cdot d(S_{i+1}, D_{i+1}) - C_y (with C_y being the common length).
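
As an aside, the common length C_y has a closed form in plain tree distances, and the tester's code further down uses the resulting expression. A short derivation, writing S = S_{i+1}, D = D_{i+1}, and m for the vertex where the paths from S to y and from S to D split (m is notation introduced only for this aside):

```latex
\begin{align*}
d(y, S) &= d(S, m) + d(m, y), \quad
d(S, D) = d(S, m) + d(m, D), \quad
d(y, D) = d(y, m) + d(m, D) \\
C_y &= d(S, m) = \frac{d(y, S) + d(S, D) - d(y, D)}{2} \\
2\, d(y, S) + 2\, d(S, D) - C_y &= \frac{3\, d(y, S) + 3\, d(S, D) + d(y, D)}{2}
\end{align*}
```

Evaluating this for every transition needs pairwise distances, either precomputed or obtained via LCA.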

While these can be “bashed” using LCA and some casework, the low constraints allow for a much more elegant solution.
Notice that if we fix i+1 (the delivery under consideration), the quantity d(S_{i+1}, D_{i+1}) is always a constant, independent of y.
Further, 2\cdot d(y, S_{i+1}) - C_y can be thought of as follows:

  • Mark every edge on the path from S_{i+1} to D_{i+1} with 1, and every other tree edge with 2.
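  • The above quantity is then the distance from S_{i+1} to y in this weighted tree: every edge on the path from y to S_{i+1} contributes 2, except the C_y shared edges, which contribute 1.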
    This distance can be found simultaneously for every y using a DFS from S_{i+1}, for instance.
  • In fact, under this model, d(S_{i+1}, D_{i+1}) is simply the distance from D_{i+1} to S_{i+1} in this weighted tree, since all the edges along this path have weight 1 anyway!

In other words, all we need to do is perform a single DFS starting from S_{i+1} on an appropriately weighted tree, after which the required values for each transition to i+1 are all obtained!
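
A minimal sketch of that single traversal is shown below. It is not lifted from any of the reference solutions: the name `weightedDistances`, the 0-indexed adjacency list, and the use of an explicit stack instead of recursion are all choices made here for illustration.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Returns w, where w[v] is the distance from s to v in the tree when edges on
// the s-d path weigh 1 and every other edge weighs 2.
vector<int> weightedDistances(const vector<vector<int>>& adj, int s, int d) {
    int n = (int)adj.size();
    assert(0 <= s && s < n && 0 <= d && d < n);
    vector<int> parent(n, -1), order;
    order.reserve(n);
    vector<int> stack = {s};
    while (!stack.empty()) {  // iterative DFS from s; parents precede children in order
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }
    vector<char> onPath(n, 0);  // onPath[v]: v lies on the s-d path
    for (int u = d; u != -1; u = parent[u]) onPath[u] = 1;
    vector<int> w(n, 0);
    for (int v : order) {
        // For v != s, the edge (parent[v], v) is on the s-d path exactly when v is.
        if (v != s) w[v] = w[parent[v]] + (onPath[v] ? 1 : 2);
    }
    return w;
}
```

With this array in hand, w[d] equals d(S_{i+1}, D_{i+1}), and w[y] + 2 * w[d] is the cost of the driver at y making the delivery.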

So, we have \mathcal{O}(N\cdot M) states, each with \mathcal{O}(1) transitions.
Further, to compute those transitions, we perform a DFS on the tree M times, which is also \mathcal{O}(N\cdot M) time.
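
Putting the pieces together, one layer of the DP might look like the sketch below. It assumes the reduced table is stored as dp[x] (one driver at the previous destination, the other at x) and that the weighted distances w from S_{i+1} have already been computed. `advanceLayer`, `prev`, and `dest` are names chosen here for illustration. Handling of the very first delivery is left out, since it depends on where the drivers start.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// One DP step under the reduced state described above.
//   dp[x] : minimum cost so far with one driver at `prev` (the previous
//           destination) and the other driver at x.
//   w[v]  : weighted distance from the new source to v (edges on the
//           source-destination path weigh 1, all others 2).
// Returns the same table after the new delivery, which ends at `dest`.
vector<long long> advanceLayer(const vector<long long>& dp, const vector<int>& w,
                               int prev, int dest) {
    const long long INF = (long long)4e18;
    int n = (int)dp.size();
    assert((int)w.size() == n);
    vector<long long> next(n, INF);
    long long delivery = 2LL * w[dest];  // 2 * d(source, dest)
    for (int x = 0; x < n; ++x) {
        if (dp[x] >= INF) continue;  // unreachable state
        // The driver at `prev` delivers; the other driver stays at x.
        next[x] = min(next[x], dp[x] + w[prev] + delivery);
        // The driver at x delivers; the other driver stays at `prev`.
        next[prev] = min(next[prev], dp[x] + w[x] + delivery);
    }
    return next;
}
```

Each call is O(N), so M calls plus M traversals stay within the \mathcal{O}(N\cdot M) bound.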

The final answer is, of course, the minimum value of T(M, x, y) across all states such that x = D_M or y = D_M.
The number of states can be further reduced by noting that they’re symmetric (T(i, x, y) = T(i, y, x)), but this (probably) isn’t needed to get AC, and doesn’t change the time complexity.

TIME COMPLEXITY:

\mathcal{O}(N\cdot M) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define INF (int)1e8

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

int n, m;
const int N = 5005;
vector <int> adj[N];
vector <pair<int, int>> g[N];
int dist[N];

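// Returns the distance from u to d if d lies in u's subtree (the tree is rooted
// at the vertex of the initial call), else INF. Along the way it builds the
// weighted graph g: edges on the path to d get weight 1, all others weight 2.
int dfs(int u, int d, int par){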
    int check = INF;
    if (u == d) check = 0;
    for (int v : adj[u]){
        if (v != par){
            int val = dfs(v, d, u);
            if (val != INF){
                g[u].push_back({v, 1});
            } else {
                g[u].push_back({v, 2});
            }
            
            check = min(check, val + 1);
        }
    }
    
    if (par != -1){
        if (check != INF){
            g[u].push_back({par, 1});
        } else {
            g[u].push_back({par, 2});
        }
    }
    return check;
}

void Solve() 
{
    cin >> n >> m;
    
    for (int i = 1; i <= n; i++){
        adj[i].clear();
    }
    
    for (int i = 1; i < n; i++){
        int u, v; cin >> u >> v;
        
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    
    vector <int> dp(n + 1, 0);
    int last = 1;
    
    for (int i = 1; i <= m; i++){
        int s, d; cin >> s >> d;
        
        vector <int> ndp(n + 1, INF);
        
        for (int j = 1; j <= n; j++){
            g[j].clear();
        }
        
        int val = dfs(s, d, -1);
        
        for (int j = 1; j <= n; j++){
            dist[j] = INF;
        }
        
        queue <int> q;
        q.push(s);
        dist[s] = 0;
        
        while (!q.empty()){
            int u = q.front(); q.pop();
            
            for (auto pi : g[u]){
                int v = pi.first;
                int w = pi.second;
                if (dist[v] == INF){
                    dist[v] = dist[u] + w;
                    q.push(v);
                }
            }
        }
        
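        // dp[j]: best cost so far with one driver at `last` (the previous destination)
        // and the other at j. Either the driver at `last` or the one at j delivers next.
        for (int j = 1; j <= n; j++){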
            ndp[j] = min(ndp[j], dp[j] + dist[last] + 2 * val);
            ndp[last] = min(ndp[last], dp[j] + dist[j] + 2 * val);

            if (i == 1) ndp[j] = 2 * val;
        }
        last = d;
        swap(dp, ndp);
    }
    
    int ans = INF;
    for (int i = 1; i <= n; i++){
        ans = min(ans, dp[i]);
    }
    
    cout << ans << "\n";
}

void main_() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return;
}

int32_t main() {
    main_();
    return 0;
}

Tester's code (C++)
#ifndef LOCAL
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,sse,sse2,sse3,sse4,popcnt,fma")
#endif

#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "../debug.h"
#else
#define dbg(...) "11-111"
#endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int32_t main() {
    ios_base::sync_with_stdio(0);   cin.tie(0);

    input_checker input;

    int T = input.readInt(1, 100);   input.readEoln();
    int NN = 0, NM = 0;
    while(T-- > 0) {
        int N = input.readInt(1, 5000); input.readSpace();
        int M = input.readInt(1, 5000); input.readEoln();

        NN += N, NM += M;

        vector<vector<int>> adj(N);
        for(int i = 1, u, v ; i < N ; ++i) {
            u = input.readInt(1, N); input.readSpace();
            v = input.readInt(1, N); input.readEoln();
            adj[u - 1].push_back(v - 1);
            adj[v - 1].push_back(u - 1);
        }

        vector<vector<int>> dis(N, vector<int>(N, N));

        auto bfs = [&](vector<int> &d, int src) {
            vector<int> que(1, src);    d[src] = 0;
            for(int i = 0 ; i < (int)que.size() ; ++i) {
                int nd = que[i];
                for(auto &u: adj[nd]) if(d[u] > d[nd] + 1) {
                    d[u] = d[nd] + 1;
                    que.push_back(u);
                }
            }
        };

        for(int i = 0 ; i < N ; ++i)
            bfs(dis[i], i);

        assert(*max_element(dis[0].begin(), dis[0].end()) < N);

        vector<vector<int64_t>> dp(M + 1, vector<int64_t>(M + 1, (int64_t)1e15));
        vector<int> S(M), D(M);
        for(int i = 0 ; i < M ; ++i) {
            S[i] = input.readInt(1, N); input.readSpace();
            D[i] = input.readInt(1, N); input.readEoln();
            --S[i], --D[i];
        }

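        // dp[i][j]: first i deliveries done, one driver at D[i - 1], the other at D[j - 1]
        // (j = 0 means the other driver has not made a delivery yet).
        // Moving a -> S -> D costs (3 d(a,S) + 3 d(S,D) + d(a,D)) / 2.
        dp[1][0] = 2 * dis[S[0]][D[0]];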
        for(int i = 1 ; i < M ; ++i) {
            dp[i + 1][i] = dp[i][0] + 2 * dis[S[i]][D[i]];
            for(int j = i - 1 ; j >= 0 ; --j) {
                if(j) {
                    dp[i + 1][i] = min(dp[i + 1][i],
                    dp[i][j] + (3 * dis[D[j - 1]][S[i]] + 3 * dis[S[i]][D[i]] + dis[D[j - 1]][D[i]]) / 2);
                }
                dp[i + 1][j] = min(dp[i + 1][j], dp[i][j] + (3 * dis[D[i - 1]][S[i]] + 3 * dis[S[i]][D[i]] + dis[D[i - 1]][D[i]]) / 2);
            }
        }
        cout << *min_element(dp[M].begin(), dp[M].end()) << '\n';
    }
    assert(NN <= 5000 && NM <= 5000);

    input.readEof();

    return 0;
}

Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n, q; cin >> n >> q;
        vector adj(n, vector<int>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v;
            adj[--u].push_back(--v);
            adj[v].push_back(u);
        }
        vector<int> dp(n), dp2(n);
        vector<int> dist(n), par(n), ord;
        int pdes = -1;
        for (int i = 0; i < q; ++i) {
            int src, des; cin >> src >> des;
            --src, --des;
            ord.clear();
            auto dfs = [&] (const auto &self, int u, int p) -> void {
                ord.push_back(u);
                for (int v : adj[u]) if (v != p) {
                    self(self, v, u);
                    par[v] = u;
                }
            };
            dfs(dfs, src, -1);
            // Populate distances
            {
                dist.assign(n, 0);
                int u = des;
                while (true) {
                    dist[u] = 1;
                    if (u == src) break;
                    u = par[u];
                }
                for (int v : ord) {
                    if (v == src) dist[v] = 0;
                    else {
                        if (dist[v]) dist[v] = 1 + dist[par[v]];
                        else dist[v] = 2 + dist[par[v]];
                    }
                }
            }
            dp2.assign(n, INT_MAX);
            for (int x = 0; x < n; ++x) {
                // Min cost such that other person is at x
                // other person was already at x, and remains there: pdes -> src -> des
                // for pdes, can go x -> src -> des too
                int add = pdes == -1 ? 0 : dist[pdes];
                dp2[x] = min(dp2[x], dp[x] + add + 2*dist[des]);
                if (pdes != -1) {
                    dp2[pdes] = min(dp2[pdes], dp[x] + dist[x] + 2*dist[des]);
                }
            }
            swap(dp, dp2);
            pdes = des;
        }
        cout << *min_element(begin(dp), end(dp)) << '\n';
    }
}

Elegant, my foot. I am not sure if you actually wanted to TLE O(log) LCA (why on earth would you do that?) or just didn’t test the constant factor well enough, but either way having to squeeze out a log factor is a needlessly frustrating competitive experience.


Sorry, it was primarily my decision to not allow the log factor, as I thought that log factor was bullshit on this problem (you can precompute all n^2 LCA values anyway if you want to use that solution, or even better all n^2 distance values, which is the entire reason you calculate LCA).


I ended up precomputing all LCA, but I don’t see how asking participants to do that is better than letting O(log) pass. It introduces zero new ideas and just makes the implementation more annoying (or library-heavy, neither of which is good).
