GHOUDIES - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: khaab_2004
Tester: raysh07
Editorialist: iceknight1093

DIFFICULTY:

2850

PREREQUISITES:

DFS, Dynamic programming

PROBLEM:

You’re given a tree on N vertices, each edge of which has a character written on it.
You also have a string S.
Define \text{str}(u, v) to be the string obtained by following the edges from u to v in the tree.
Find the maximum possible value of \text{LCS}(\text{str}(u, v), S) across all pairs (u, v) of vertices.

EXPLANATION:

This task becomes much more approachable if you’re familiar with the longest common subsequence problem for two strings. If you’re unfamiliar with it or its solution, please read about it first, for example from here.

Let M = |S| be the length of S.

Recall that the classical longest common subsequence problem, dealing with two strings, is solved with dynamic programming in \mathcal{O}(N\cdot M) time.
If the input tree were a straight line, we would need to solve exactly this, so our task is clearly a harder version - and requires dynamic programming to solve, at that.
All we need to do is figure out the states and transitions.
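For reference, the classical two-string dynamic programming recalled above can be sketched as follows. The function name `lcs` and the 0-indexed strings are illustrative choices, not part of the problem.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Classical LCS: dp[i][j] is the LCS length of the first i characters of a
// and the first j characters of b.
int lcs(const std::string &a, const std::string &b) {
    const int n = static_cast<int>(a.size());
    const int m = static_cast<int>(b.size());
    std::vector<std::vector<int>> dp(n + 1, std::vector<int>(m + 1, 0));
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            // Drop the last character of a, or the last character of b.
            dp[i][j] = std::max(dp[i - 1][j], dp[i][j - 1]);
            // Or match the two last characters when they are equal.
            if (a[i - 1] == b[j - 1])
                dp[i][j] = std::max(dp[i][j], dp[i - 1][j - 1] + 1);
        }
    }
    return dp[n][m];
}
```

The tree version keeps the same three choices per cell; only the role of "the first i characters of a" is taken over by a vertex of the tree.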

Consider some path u\to v in the tree.
Root the tree at vertex 1, and let L denote the lowest common ancestor of u and v.
Note that the string \text{str}(u, v) is obtained by moving up from u till we reach L, then down from L till we reach v.
More generally, any path is the combination of one upward path and one downward path.

Observe that:

  • The upward string will have some common subsequence with a prefix of S.
  • The downward string will have some common subsequence with a suffix of S.

This should remind you of how we deal with prefixes (or suffixes, depending on your implementation) of the two strings in the classical LCS problem.
Indeed, we can use this to define our dynamic programming states.

Let \text{up}[u][i] denote the maximum length of a common subsequence such that:

  • The path we consider starts at some vertex inside the subtree of u and moves upward till it reaches u; and
  • We match the string of this path against only the first i characters of S.

It’s not too hard to come up with transitions for this.
Let v be a child of u, and c be the character on the edge joining them.
Then, we get:

  • If the edge between u and v is not matched to any character, we just get a length of \text{up}[v][i], since all the matching must come from the subtree of v itself.
  • Otherwise, the edge is matched to some character of S. There are two cases:
    • The edge is matched to S_i, which requires S_i = c. In this case, the best we can do is 1 + \text{up}[v][i-1], by matching the first i-1 characters in the subtree of v.
    • The edge is matched to an earlier character (in particular, whenever S_i \neq c). In this case, the best we can do is \text{up}[u][i-1], since the i-th character of S isn’t being matched anyway.
  • \text{up}[u][i] is then the maximum of these, across all children v of u.
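The transitions above, applied to a single child, can be sketched as follows. This is a minimal illustration rather than the setter's code: the helper name `mergeUp` and the row layout (`upU[i]` and `upV[i]` standing for \text{up}[u][i] and \text{up}[v][i], with i from 0 to M) are assumptions, and S is 0-indexed, so S_i is `s[i - 1]`.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: folds the finished row of child v into the row of u.
// c is the character on edge (u, v). Both rows have size M + 1.
void mergeUp(std::vector<int> &upU, const std::vector<int> &upV,
             char c, const std::string &s) {
    const int m = static_cast<int>(s.size());
    assert(static_cast<int>(upU.size()) == m + 1);
    assert(static_cast<int>(upV.size()) == m + 1);
    for (int i = 1; i <= m; ++i) {
        int cur = upV[i];                        // edge left unmatched
        if (s[i - 1] == c)
            cur = std::max(cur, upV[i - 1] + 1); // edge matched to S_i
        // upU[i - 1] covers the case where S_i is not used at all.
        upU[i] = std::max({upU[i], cur, upU[i - 1]});
    }
}
```

Calling `mergeUp` once per child, after that child's own row is complete, yields the maximum over all children.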

You may notice that this is, once again, very similar to the transitions of the classical LCS problem.
The complexity is \mathcal{O}(N\cdot M): there are N\cdot (M+1) states, and each of the N-1 edges contributes \mathcal{O}(1) work per value of i.

Similarly, one can define \text{down}[u][i] to be the LCS length of a path starting at u and going down into its subtree, matched only against the last i characters of S.
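A sketch of the corresponding downward transition is given below, under the same assumptions as before: the helper name `mergeDown` and the row layout are illustrative, and S is 0-indexed. The only real change is the index: the suffix of length i begins at `s[m - i]`, and the edge (u, v) is the first character of the downward string, so that is the character it can be matched to.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: folds the finished row of child v into the row of u.
// downU[i] / downV[i] stand for down[u][i] / down[v][i], i = 0..M.
void mergeDown(std::vector<int> &downU, const std::vector<int> &downV,
               char c, const std::string &s) {
    const int m = static_cast<int>(s.size());
    assert(static_cast<int>(downU.size()) == m + 1);
    assert(static_cast<int>(downV.size()) == m + 1);
    for (int i = 1; i <= m; ++i) {
        int cur = downV[i];                        // edge left unmatched
        if (s[m - i] == c)
            cur = std::max(cur, downV[i - 1] + 1); // edge matched to first char of suffix
        // downU[i - 1] covers the case where that character of S is unused.
        downU[i] = std::max({downU[i], cur, downU[i - 1]});
    }
}
```

Equivalently, the downward table for S coincides with the upward table for the reversed string, since reading a downward path backwards gives an upward one.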

Finally, to get the actual answer we need to combine upward paths with downward paths.
To do that, let’s fix u and look at \text{up}[u][i].
This is an upward path matching with the first i characters of S, so our best bet is to combine it with some downward path matching with the last M-i characters of S.

This, by definition, is \text{down}[u][M-i]. There’s one catch, though: \text{up} and \text{down} were computed by processing children of u, and we need to ensure that the upward and downward paths we pick come from different children; otherwise they would share an edge and would not form a simple path through u.

There are several ways to account for this. The simplest way is to update the answer as and when transitions are performed.
That is, suppose you’re processing a new child v of u.
Note that at this point, \text{up}[u][i] and \text{down}[u][i] represent only the correct values for all children considered so far; in particular, not including v.
So, for each i, compute how the up/down values of v extend to u across the edge, then pair the extended up value for i with \text{down}[u][M-i], and the extended down value for i with \text{up}[u][M-i], to update the overall answer.
After this is done entirely, use v to update the dp values for u.
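Putting the pieces together, the whole procedure can be sketched as below. This is a hedged illustration and not any of the attached solutions verbatim: the `Edge` struct, the function name `maxPathLcs`, 0-indexed vertices, and rooting at vertex 0 are assumptions made to keep the example self-contained.

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

struct Edge { int a, b; char c; };  // hypothetical input format, 0-indexed

int maxPathLcs(int n, const std::vector<Edge> &edges, const std::string &s) {
    const int m = static_cast<int>(s.size());
    std::vector<std::vector<std::pair<int, char>>> adj(n);
    for (const Edge &e : edges) {
        adj[e.a].push_back({e.b, e.c});
        adj[e.b].push_back({e.a, e.c});
    }
    std::vector<std::vector<int>> up(n, std::vector<int>(m + 1, 0)), down = up;
    int ans = 0;
    std::function<void(int, int)> dfs = [&](int u, int p) {
        for (auto [v, c] : adj[u]) {
            if (v == p) continue;
            dfs(v, u);
            // Values of v extended across the edge (u, v).
            std::vector<int> extUp(m + 1, 0), extDown(m + 1, 0);
            for (int i = 1; i <= m; ++i) {
                extUp[i] = up[v][i];
                if (c == s[i - 1]) extUp[i] = std::max(extUp[i], up[v][i - 1] + 1);
                extDown[i] = down[v][i];
                if (c == s[m - i]) extDown[i] = std::max(extDown[i], down[v][i - 1] + 1);
            }
            // Pair v only with earlier children: up[u], down[u] exclude v here.
            for (int i = 0; i <= m; ++i) {
                ans = std::max(ans, extUp[i] + down[u][m - i]);
                ans = std::max(ans, extDown[i] + up[u][m - i]);
            }
            // Only now fold v into the rows of u.
            for (int i = 1; i <= m; ++i) {
                up[u][i] = std::max({up[u][i], extUp[i], up[u][i - 1]});
                down[u][i] = std::max({down[u][i], extDown[i], down[u][i - 1]});
            }
        }
    };
    dfs(0, -1);
    return ans;
}
```

A single edge labelled `a` with S = `aa` is a useful sanity check: pairing \text{up}[u][1] with \text{down}[u][1] naively would report 2, while updating the answer before folding the child in gives the correct value 1.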

TIME COMPLEXITY:

\mathcal{O}(N\cdot |S|) per testcase.

CODE:

Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// #ifdef tabr
// #include "library/debug.cpp"
// #else
// #define debug(...)
// #endif

struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }

    string readOne() {
        // assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        // cerr << res << endl;
        return res;
    }

    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

int main() {
    #ifndef ONLINE_JUDGE
        freopen("input5.txt", "r", stdin);
        freopen("output5.txt", "w", stdout);
    #endif
    input_checker in;
    int tt = in.readInt(1, 500);
    in.readEoln();
    while (tt--) {
        // int a = in.readInt(1, 6);
        // in.readSpace();
        // int b = in.readInt(1, 6);
        // in.readEoln();
        int n = in.readInt(2 , 1e4);
        in.readEoln();
        // cout << n << endl;
        vector< vector <pair<int, char> >> adj(n);      
        for (int i = 0 ; i < n - 1 ; i++ ) {
            // cin >> a >> b ;
            // char c ; cin >> c ;
            int a = in.readInt(1 , n) ;
            in.readSpace() ;
            int b = in.readInt(1 , n) ;
            in.readSpace(); 

            // cout << a << " " << b << endl;
            string s = in.readString(1 , 2) ;
            in.readEoln();
            char c = s[0];
            // cout << c << endl;
            --a ; --b;
            adj[a].push_back({b , c}) ; adj[b].push_back({a , c});
        }
        string s = in.readString(1 , 1001);
        // // string s ; cin >> s ; 
        in.readEoln();
        // cout << 1 << endl;
        int m = (int)s.size() ;
        vector <vector<int>> prefix(n , vector <int> (m , 0)) ;
        vector <vector<int>> suffix = prefix ;
        vector <int> ans(n , 0);

        function <void (int , int)> dfs = [&] (int v , int p) {
            for (auto &node : adj[v]) {
                int u = node.first ; char c = node.second ;
                if (u == p) continue ;
                dfs (u , v) ;
                for (int i = 0; i <= m; i++) {
                    int pf1 = (i == 0 ? 0 : prefix[v][i - 1]) , sf1 = (i <= m - 1 ? suffix[u][i] : 0);
                    int pf2 = (i == 0 ? 0 : prefix[u][i - 1]) , sf2 = (i <= m - 1 ? suffix[v][i] : 0);

                    ans[v] = max ({ans[v] , pf1 + sf1 , pf2 + sf2 }) ;

                    if (i <= m - 1 && s[i] == c) {  
                        int pf1 = (i == 0 ? 0 : prefix[v][i - 1]) , sf1 = (i < m - 1 ? suffix[u][i + 1] : 0);
                        int pf2 = (i == 0 ? 0 : prefix[u][i - 1]) , sf2 = (i < m - 1 ? suffix[v][i + 1] : 0);
                        ans[v] = max({ans[v] , 1 + pf1 + sf1 , 1 + pf2 + sf2}) ;
                    }
                }
                for (int i = 0; i < m; i++) {
                    prefix[v][i] = max(prefix[v][i], prefix[u][i]);
                    if (s[i] == c) prefix[v][i] = max(prefix[v][i], 1 + (i == 0 ? 0 : prefix[u][i - 1]));
                    if (i > 0) {
                        prefix[v][i] = max(prefix[v][i], prefix[v][i - 1]);
                    }
                }

                for (int i = m - 1 ; i >= 0; i--) {
                    suffix[v][i] = max(suffix[v][i], suffix[u][i]);

                    if (s[i] == c) suffix[v][i] = max(suffix[v][i], 1 + (i + 1 < m ? suffix[u][i + 1] : 0));
                    if (i + 1 < m ) {
                        suffix[v][i] = max(suffix[v][i], suffix[v][i + 1]);
                    }
                }
            }

        };
        dfs (0 , -1 ) ;

        cout << *max_element(ans.begin() , ans.end()) << endl;
        
    }
    in.readEof();
    return 0;
}
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18

mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

struct input_checker {
    string buffer;
    int pos;
 
    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";
 
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }
 
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && buffer[now] != ' ' && buffer[now] != '\n') {
            now++;
        }
        return now;
    }
 
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }
 
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }
 
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }
 
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }
 
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }
 
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }
 
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

//input_checker inp;
const int N = 5e4 + 5;
const int M = 1e3 + 5;
int n, m, dp1[N][M], dp2[N][M], ans;
string s;
int sum_n = 0, sum_m = 0;
vector <pair<int, char>> adj[N];

void dfs(int u, int par){
    for (int i = 0; i <= m + 1; i++){
        dp1[u][i] = dp2[u][i] = 0;
    }

    for (auto [v, ch] : adj[u]){
        if (v != par){
            dfs(v, u);

            for (int i = 0; i <= m; i++){
                ans = max(ans, dp1[u][i] + dp2[v][i + 1]);
                ans = max(ans, dp1[v][i] + dp2[u][i + 1]);

                if (ch == s[i]){
                    ans = max(ans, dp1[u][i - 1] + dp2[v][i + 1] + 1);
                    ans = max(ans, dp1[v][i - 1] + dp2[u][i + 1] + 1);
                }
            }

            for (int i = 1; i <= m; i++){
                dp1[u][i] = max(dp1[u][i], dp1[v][i]);
                if (s[i] == ch){
                    dp1[u][i] = max(dp1[u][i], dp1[v][i - 1] + 1);
                }
                dp1[u][i] = max(dp1[u][i], dp1[u][i - 1]);
            }

            for (int i = m; i >= 1; i--){
                dp2[u][i] = max(dp2[u][i], dp2[v][i]);
                if (s[i] == ch){
                    dp2[u][i] = max(dp2[u][i], dp2[v][i + 1] + 1);
                }
                dp2[u][i] = max(dp2[u][i], dp2[u][i + 1]);
            }
        }
    }
}

void Solve() 
{
    // n = inp.readInt(1, (int)1e4); inp.readEoln();
    cin >> n;
    sum_n += n; ans = 0;
    assert(sum_n <= (int)1e4);
    
    for (int i = 1; i <= n; i++) adj[i].clear();

    for (int i = 1; i < n; i++){
        // int u = inp.readInt(1, n); inp.readSpace();
        // int v = inp.readInt(1, n); inp.readSpace();
        // string str = inp.readString(1, 1); inp.readEoln();
        // assert(str[0] >= 'a' && str[0] <= 'z');
        
        int u, v; string str; cin >> u >> v >> str;

        adj[u].push_back({v, str[0]});
        adj[v].push_back({u, str[0]});
    }

    cin >> s;
   // inp.readEoln();
    for (auto x : s){
        assert(x >= 'a' && x <= 'z');
    }
    
    m = s.length(); sum_m += m; assert(sum_m <= 1000);
    s = "0" + s + "0";

    dfs(1, -1);
   // return;

    cout << ans << "\n";
}

int32_t main() 
{
    auto begin = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t = 1;
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);

    // t = inp.readInt(1, 500);
    // inp.readEoln();
    
    cin >> t;
    for(int i = 1; i <= t; i++) 
    {
        //cout << "Case #" << i << ": ";
        Solve();
    }
    auto end = std::chrono::high_resolution_clock::now();
    auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (C++)
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include "bits/stdc++.h"
using namespace std;
using ll = long long int;
mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int main()
{
    ios::sync_with_stdio(false); cin.tie(0);

    int t; cin >> t;
    while (t--) {
        int n; cin >> n;
        vector adj(n, vector<pair<int, char>>());
        for (int i = 0; i < n-1; ++i) {
            int u, v; cin >> u >> v; --u ,--v;
            char c; cin >> c;
            adj[u].push_back({v, c});
            adj[v].push_back({u, c});
        }
        string s; cin >> s;
        int m = s.size();
        vector up(n, vector(m+1, 0));
        vector down(n, vector(m+1, 0));

        // up[u][i] -> longest common subsequence of a path from inside the subtree of u going up to u, considering the first i chars of s
        // down[u][i] -> longest common subsequence of a path from u going down into its subtree, considering the last i chars of s

        int ans = 0;
        auto dfs = [&] (const auto &self, int u, int p) -> void {
            for (auto [v, c] : adj[u]) {
                if (v == p) continue;
                self(self, v, u);
                for (int i = 1; i <= m; ++i) {
                    int cur = up[v][i];
                    if (c == s[i-1]) cur = max(cur, up[v][i-1] + 1);
                    ans = max(ans, cur + down[u][m-i]);

                    cur = down[v][i];
                    if (c == s[m-i]) cur = max(cur, down[v][i-1] + 1);
                    ans = max(ans, cur + up[u][m-i]);
                }
                for (int i = 1; i <= m; ++i) {
                    int cur = up[v][i];
                    if (c == s[i-1]) cur = max(cur, up[v][i-1] + 1);
                    up[u][i] = max(up[u][i], cur);

                    cur = down[v][i];
                    if (c == s[m-i]) cur = max(cur, down[v][i-1] + 1);
                    down[u][i] = max(down[u][i], cur);
                }

                for (int i = 1; i <= m; ++i) up[u][i] = max(up[u][i], up[u][i-1]), down[u][i] = max(down[u][i], down[u][i-1]);
            }
        };
        dfs(dfs, 0, 0);
        cout << ans << '\n';
    }
}

Why is 1 added here? I think it should be just up[u][i-1].

I am doing it this way:

  1. Fix u and i, and work out what each child contributes to the up and down values. If the maximum up and the maximum down are contributed by the same child, look at the second-best child for each of them.
    Now use (max_up, second_max_down) and (max_down, second_max_up).

Is it OK to do it this way?

Ah you’re right, looks like I copied over the expression and forgot to erase the 1 - should be fixed now.

If implemented properly, I believe it should work — but it also seems somewhat messy to implement, so you’ll need to be extra careful.
The method I suggested has the pro of being quite simple to implement (my implementation is attached to the post).

int cur = up[v][i];
if (c == s[i-1]) cur = max(cur, up[v][i-1] + 1);
ans = max(ans, cur + down[u][m-i]);

Here, while updating cur, you haven’t considered the case where the edge (u, v) is matched to some character and S[i] != c, i.e. when we take up[u][i-1]. Why is that?

To update the dp table, I do that later using the line

for (int i = 1; i <= m; ++i) up[u][i] = max(up[u][i], up[u][i-1]), down[u][i] = max(down[u][i], down[u][i-1]);

When updating ans, you don’t need to consider that case separately - if it was optimal, then it would’ve been considered when we paired up[u][i-1] and down[v][m-i+1] anyway.
