LNGSUB - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3

Author: Shubham Anand Jain
Tester: Anay Karnik
Editorialist: Mohan Abhyas

DIFFICULTY:

EASY-MEDIUM

PREREQUISITES:

Suffix array, LCP array construction, minimum spanning tree

PROBLEM:

You are given N strings S_1, S_2, \ldots, S_N. Consider a complete undirected graph with N vertices (numbered 1 through N), in which the weight of an edge between vertices u and v is equal to the length of the longest common substring of S_u and S_v.

Find the maximum possible weight of a spanning tree of this graph.

EXPLANATION:

Let S be the concatenation of the strings S_i, with a separator character `{` (which does not occur in the input strings) between consecutive strings, i.e.,
S = S_1 + \{ + S_2 + \{ + \dots + \{ + S_N
sa = suffix array of string S
lcp = LCP array of S, i.e. lcp[i] = length of the longest common prefix of the suffixes starting at positions sa[i] and sa[i+1] (the i-th and (i+1)-th suffixes in sorted order).
Both arrays can be constructed in \mathcal{O}(|S| \log |S|) time (SA-IS together with Kasai's algorithm even achieves \mathcal{O}(|S|)).

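Claim: it suffices to look at suffixes that are adjacent in sa. Consider strings S_i and S_j, and let w(i,j) be the length of their longest common substring. Let p < q be the positions in sa of two suffixes, one starting in S_i and one in S_j, that realize it.
Proof:
Every suffix at a position r with p \le r \le q has a common prefix of length at least w(i,j) with the suffix at position p. This is because LCP(sa[p], sa[r]) = \min(lcp[p], \dots, lcp[r-1]) \ge \min(lcp[p], \dots, lcp[q-1]) \ge w(i,j). These w(i,j) characters contain no separator, so they lie inside the string that owns each such suffix.
Walk through positions p, p+1, \dots, q and record the owning string each time it changes: k_1 = i, k_2, \dots, k_m = j. Each change is a pair of adjacent suffixes from different strings. For each such pair, the algorithm below generates an edge (k_t, k_{t+1}) of weight \ge w(i,j).
If one of these changes is between S_i and S_j, the generated edge (i,j) has weight exactly w(i,j). Otherwise, the generated edges contain a path from i to j whose edges all have weight \ge w(i,j). Together with (i,j), this path closes a cycle on which (i,j) is a minimum-weight edge. By the cycle property, (i,j) can be dropped without decreasing the weight of the maximum spanning tree.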

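Iterate over the adjacent positions i, i+1 of the suffix array.
If the suffixes sa[i] and sa[i+1] start inside two different input strings (not at a separator), add an edge between those strings. Its weight is min(lcp[i], number of characters from sa[i] to the end of its own string). The clamp is needed because lcp[i] may continue past a separator `{`.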

Run any minimum spanning tree algorithm, adapted to maximum weight, on this graph. For example, use Kruskal's algorithm processing the edges in descending order of weight, or a standard MST algorithm on the negated weights. Only the at most |S| - 1 generated edges are considered, instead of all N(N-1)/2 pairs.

TIME COMPLEXITY:

\mathcal{O}\left(\sum |S_i| \log \sum |S_i|\right) per test case.

SOLUTIONS:

Setter's Solution
// By TheOneYouWant
#include<bits/stdc++.h>
using namespace std;
#define fastio ios_base::sync_with_stdio(0);cin.tie(0)

// Induced-sorting pass of SA-IS.
//   vec       - input string as integers; every value must be < val_range
//               (values are used directly as bucket indices).
//   val_range - number of buckets (alphabet size).
//   SA        - output; overwritten in place, must have vec.size() entries.
//   sl        - sl[i] == true marks position i as L-type (as computed in SA_IS).
//   lms_idx   - LMS positions, in the order they should be seeded into their
//               buckets.
// Seeds the LMS positions at the tails of their buckets, then induces the
// L-type suffixes (left-to-right scan) and the S-type suffixes (right-to-left
// scan) from them.
void induced_sort(const vector<int> &vec, int val_range, vector<int> &SA, const vector<bool> &sl, const vector<int> &lms_idx) {
    // After the prefix sums below:
    //   l[c] = number of values < c  -> first slot of bucket c
    //   r[c] = number of values <= c -> one past the last slot of bucket c
    vector<int> l(val_range, 0), r(val_range, 0);
    for (int c : vec) {
        if (c + 1 < val_range) ++l[c + 1];
        ++r[c];
    }
    partial_sum(l.begin(), l.end(), l.begin());
    partial_sum(r.begin(), r.end(), r.begin());
    fill(SA.begin(), SA.end(), -1);  // -1 marks an empty slot
    // Seed LMS positions at bucket tails; walking lms_idx backwards keeps
    // them in lms_idx order within each bucket.
    for (int i = lms_idx.size() - 1; i >= 0; --i)
        SA[--r[vec[lms_idx[i]]]] = lms_idx[i];
    // L-type pass: for each filled slot whose predecessor i-1 is L-type, put
    // i-1 at the head of its bucket. The range-for also visits slots written
    // earlier in this same scan (SA is never resized, so this is safe).
    for (int i : SA)
        if (i >= 1 && sl[i - 1]) {
            SA[l[vec[i - 1]]++] = i - 1;
        }
    // Recompute the bucket tails for the S-type pass.
    fill(r.begin(), r.end(), 0);
    for (int c : vec)
        ++r[c];
    partial_sum(r.begin(), r.end(), r.begin());
    // S-type pass (right to left): predecessors that are S-type go to the
    // tail of their bucket. i is re-read from SA[k] after every decrement, so
    // slots filled earlier in this pass are seen; slot 0 is never examined.
    for (int k = SA.size() - 1, i = SA[k]; k >= 1; --k, i = SA[k])
        if (i >= 1 && !sl[i - 1]) {
            SA[--r[vec[i - 1]]] = i - 1;
        }
}
 
// SA-IS suffix array construction (recursive).
//   vec       - integer string; every value must be < val_range.
//   val_range - alphabet size.
// Returns SA where SA[r] is the start position of the r-th smallest suffix.
// NOTE(review): assumes vec ends with a unique sentinel smaller than every
// other value (suffix_array appends one). sl[n-1] is forced to S-type, and the
// sentinel's name is written to SA[n-1] directly rather than via new_lms_idx[0].
// NOTE(review): with n == 1 there is no LMS position, SA[0] stays -1 after the
// first induced sort, and sl[SA[0]] below is read out of range.
vector<int> SA_IS(const vector<int> &vec, int val_range) {
    const int n = vec.size();
    vector<int> SA(n), lms_idx;
    // sl[i] == true: position i is L-type (vec[i] > vec[i+1], or equal and
    // i+1 is L-type). false: S-type.
    vector<bool> sl(n);
    sl[n - 1] = false;
    // Classify right to left; i+1 is an LMS position when it is S-type and
    // its left neighbour i is L-type.
    for (int i = n - 2; i >= 0; --i) {
        sl[i] = (vec[i] > vec[i + 1] || (vec[i] == vec[i + 1] && sl[i + 1]));
        if (sl[i] && !sl[i + 1]) lms_idx.push_back(i + 1);
    }
    // Positions were collected right to left; restore increasing text order.
    reverse(lms_idx.begin(), lms_idx.end());
    // First induced sort: brings the LMS substrings into sorted order.
    induced_sort(vec, val_range, SA, sl, lms_idx);
    vector<int> new_lms_idx(lms_idx.size()), lms_vec(lms_idx.size());
    // Collect the LMS positions in the order the induced sort produced.
    for (int i = 0, k = 0; i < n; ++i)
        if (!sl[SA[i]] && SA[i] >= 1 && sl[SA[i] - 1]) {
            new_lms_idx[k++] = SA[i];
        }
    // Name the LMS substrings: neighbours that compare equal share a name.
    // SA is reused as scratch storage here, indexed by text position
    // (SA[pos] = name of the LMS substring starting at pos).
    int cur = 0;
    SA[n - 1] = cur;
    for (size_t k = 1; k < new_lms_idx.size(); ++k) {
        int i = new_lms_idx[k - 1], j = new_lms_idx[k];
        if (vec[i] != vec[j]) {
            SA[j] = ++cur;
            continue;
        }
        // Walk both LMS substrings until a character mismatch or until one of
        // them reaches its terminating LMS position; they are equal only if
        // both end at the same offset.
        bool flag = false;
        for (int a = i + 1, b = j + 1;; ++a, ++b) {
            if (vec[a] != vec[b]) {
                flag = true;
                break;
            }
            if ((!sl[a] && sl[a - 1]) || (!sl[b] && sl[b - 1])) {
                flag = !((!sl[a] && sl[a - 1]) && (!sl[b] && sl[b - 1]));
                break;
            }
        }
        SA[j] = (flag ? ++cur : cur);
    }
    // Reduced string: names of the LMS substrings in text order.
    for (size_t i = 0; i < lms_idx.size(); ++i)
        lms_vec[i] = SA[lms_idx[i]];
    // Duplicate names: recursively sort the reduced string to get the true
    // order of the LMS suffixes. Otherwise new_lms_idx is already correct.
    if (cur + 1 < (int)lms_idx.size()) {
        auto lms_SA = SA_IS(lms_vec, cur + 1);
        for (size_t i = 0; i < lms_idx.size(); ++i) {
            new_lms_idx[i] = lms_idx[lms_SA[i]];
        }
    }
    // Final induced sort seeded with the correctly ordered LMS suffixes.
    induced_sort(vec, val_range, SA, sl, new_lms_idx);
    return SA;
}
// Suffix array of `s`: the start positions 0..|s|-1 ordered by the
// lexicographic order of their suffixes (built with SA-IS).
// Each character is mapped to (unsigned char)+1, so that any byte is a valid
// non-negative bucket index and the appended sentinel 0 is strictly smaller
// than every real character. The alphabet passed to SA_IS is widened beyond
// LIM when the input contains larger values.
// Returns an empty vector for an empty string (SA_IS needs at least one real
// character besides the sentinel).
vector<int> suffix_array(const string &s, const int LIM = 128) {
    if (s.empty()) return {};
    vector<int> vec;
    vec.reserve(s.size() + 1);
    int max_val = 0;
    for (char ch : s) {
        const int v = static_cast<unsigned char>(ch) + 1;
        vec.push_back(v);
        max_val = max(max_val, v);
    }
    vec.push_back(0);  // unique, smallest sentinel required by SA-IS
    auto ret = SA_IS(vec, max(LIM, max_val + 1));
    ret.erase(ret.begin());  // drop the sentinel suffix, which always ranks first
    return ret;
}

// Kasai's algorithm. Given the suffix array `sa` of `s`, returns lcp where
// lcp[i] is the length of the longest common prefix of the suffixes starting
// at sa[i] and sa[i+1]; lcp[n-1] is 0 because the last suffix has no
// successor. An empty string yields an empty array.
// Runs in O(n): the running match length k drops by at most one per step.
vector<int> LCP(const string &s, const vector<int> &sa) {
    const int n = static_cast<int>(s.size());
    vector<int> lcp(n, 0);
    if (n == 0) return lcp;  // avoid touching lcp[n - 1] on empty input

    // rnk is the inverse permutation of sa (named to avoid std::rank).
    vector<int> rnk(n);
    for (int i = 0; i < n; i++)
        rnk[sa[i]] = i;

    int k = 0;
    for (int i = 0; i < n; i++) {
        if (rnk[i] == n - 1) {  // last suffix in order: nothing to compare with
            k = 0;
            continue;
        }
        const int j = sa[rnk[i] + 1];
        while (i + k < n && j + k < n && s[i + k] == s[j + k])
            k++;
        lcp[rnk[i]] = k;
        if (k > 0) k--;  // suffix i+1 shares at least k-1 chars with its successor
    }
    return lcp;
}

// Capacity of the disjoint-set arrays below (valid indices 0..LIM-1).
constexpr int LIM = 1000005;

int link[LIM];  // link[x]: parent of x in the disjoint-set forest; a root points to itself
int sz[LIM];    // sz[r]: number of elements in the set whose root is r

// Returns the representative of x's set. Afterwards every node that was on
// the path from x to the representative points directly at it.
int find(int x){
    int rep = x;
    while(link[rep] != rep)
        rep = link[rep];
    while(link[x] != rep){
        const int nxt = link[x];
        link[x] = rep;
        x = nxt;
    }
    return rep;
}

void unite(int a, int b){
    a = link[a];
    b = link[b];
    if(a == b) return;
    if(sz[a] < sz[b]) swap(a,b);
    sz[a] += sz[b];
    link[b] = a;
}

int main(){
    fastio;

    int tests;
    cin >> tests;

    while(tests--){
        int n;
        cin >> n;

        string s[n];
        int pref[n] = {0};
        string combined;
        vector<int> which_str;

        for(int i = 0; i < n; i++){
            cin >> s[i];
            if(i>0) pref[i] += pref[i-1];
            pref[i]++;
            pref[i] += s[i].length();
            for(int j = combined.size(); j < combined.size() + s[i].length(); j++){
                which_str.push_back(i);
            }
            combined += s[i];
            if(i!=n-1){
                combined += '{';
                which_str.push_back(-1);
            }
            link[i] = i;
            sz[i] = 1;
        }
        vector<pair<int,pair<int,int>>> edges;

        vector<int> sa = suffix_array(combined);
        vector<int> lcp = LCP(combined, sa);

        for(int i = 0; i < (int)lcp.size() - 1; i++){
            int str1 = which_str[sa[i]];
            int str2 = which_str[sa[i+1]];
            if(str1==-1) continue;
            if(str2==-1) continue;
            if(str1==str2) continue;
            int wt = min(lcp[i], pref[str1] - 1 - sa[i]);
            edges.push_back({wt, {str1, str2}});
        }

        sort(edges.begin(), edges.end(), greater<pair<int,pair<int,int>>>());

        long long int span_wt = 0;

        for(int i = 0; i < edges.size(); i++){
            int wt = edges[i].first;
            int u, v;
            tie(u, v) = edges[i].second;
            if(find(u)==find(v)) continue;
            unite(u, v);
            span_wt += wt;
        }
        cout << span_wt << endl;
    }
    return 0;
}
Tester's Solution
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

#define int long long

// Capacity of the disjoint-set array below (valid indices 0..MAXN-1).
constexpr int MAXN = 1000006;

// Disjoint-set forest: a negative entry marks a root, otherwise par[x] is
// x's parent.
int par[MAXN];

// Returns the root of u's tree (roots hold a negative par value) and
// re-points every node on the way directly at that root.
int root(int u) {
  int top = u;
  while (par[top] >= 0)
    top = par[top];
  while (par[u] >= 0) {
    const int next = par[u];
    par[u] = top;
    u = next;
  }
  return top;
}

// Joins the sets containing u and v (union by size). A root r stores
// par[r] = -(size of its set), so the larger set has the smaller value.
// Does nothing if u and v are already in the same set.
void merge(int u, int v) {
  u = root(u);
  v = root(v);

  if(u == v)
    return;

  // Make u the larger root. std::swap replaces the chained XOR swap
  // `u ^= v ^= u ^= v`, which is only well-defined since C++17.
  if(par[u] > par[v])
    std::swap(u, v);

  par[u] += par[v];
  par[v] = u;
}

// Sorts all cyclic shifts of `s` by prefix doubling with counting sort
// (O(n log n)) and returns their start indices in lexicographic order.
// Characters are bucketed by their unsigned char value, one bucket per
// possible byte, so any character (including negative `char`s) is a valid
// index. When the caller appends a unique smallest character (as
// suffix_array_construction does with 0), the result is also a suffix array.
// Returns an empty vector for an empty string.
std::vector<int> sort_cyclic_shifts(std::string const& s) {
  int n = s.size();
  if (n == 0)
    return {};
  const int alphabet = 256;  // one bucket per unsigned char value

  std::vector<int> p(n), c(n), cnt(std::max(alphabet, n), 0);
  for (int i = 0; i < n; i++)
    cnt[static_cast<unsigned char>(s[i])]++;
  for (int i = 1; i < alphabet; i++)
    cnt[i] += cnt[i-1];
  for (int i = 0; i < n; i++)
    p[--cnt[static_cast<unsigned char>(s[i])]] = i;
  // c[i]: equivalence class of the shift starting at i with respect to its
  // first 2^h characters (initially h = 0, i.e. just the first character).
  c[p[0]] = 0;
  int classes = 1;
  for (int i = 1; i < n; i++) {
    if (s[p[i]] != s[p[i-1]])
      classes++;
    c[p[i]] = classes - 1;
  }

  std::vector<int> pn(n), cn(n);
  for (int h = 0; (1 << h) < n; ++h) {
    // p is ordered by the first 2^h characters. Moving every start back by
    // 2^h yields the shifts ordered by their second half; a stable counting
    // sort on the class of the first half then orders them by 2^(h+1) chars.
    for (int i = 0; i < n; i++) {
      pn[i] = p[i] - (1 << h);
      if (pn[i] < 0)
        pn[i] += n;
    }
    std::fill(cnt.begin(), cnt.begin() + classes, 0);
    for (int i = 0; i < n; i++)
      cnt[c[pn[i]]]++;
    for (int i = 1; i < classes; i++)
      cnt[i] += cnt[i-1];
    for (int i = n-1; i >= 0; i--)
      p[--cnt[c[pn[i]]]] = pn[i];
    // Recompute the classes for prefixes of length 2^(h+1).
    cn[p[0]] = 0;
    classes = 1;
    for (int i = 1; i < n; i++) {
      std::pair<int, int> cur = {c[p[i]], c[(p[i] + (1 << h)) % n]};
      std::pair<int, int> prev = {c[p[i-1]], c[(p[i-1] + (1 << h)) % n]};
      if (cur != prev)
        ++classes;
      cn[p[i]] = classes - 1;
    }
    c.swap(cn);
  }
  return p;
}

// Suffix array of `s` built from the sorted cyclic shifts of s + '\0'.
// The result has |s|+1 entries. When s contains no '\0', entry 0 is |s|
// (the sentinel suffix), and entries 1..|s| list the real suffixes in order.
std::vector<int> suffix_array_construction(std::string s) {
  // `s` is our own copy, so the sentinel never needs to be removed again.
  s += '\0';
  return sort_cyclic_shifts(s);
}

// Kasai-style LCP array for a suffix array `p` that still contains the
// sentinel entry at p[0] (as returned by suffix_array_construction, size
// |s|+1). Returns |s|-1 values. lcp[i] is the length of the common prefix of
// the suffixes starting at p[i+1] and p[i+2], where matching also stops at a
// character with value 1 (the separator the caller places between strings).
// An empty `s` yields an empty result.
std::vector<int> lcp_construction(std::string const& s, std::vector<int> const& p) {
  int n = s.size();
  if (n == 0)
    return {};  // would otherwise request a vector of size n - 1 = -1

  // rank[x]: position of suffix x among the real suffixes (sentinel skipped).
  std::vector<int> rank(n, 0);
  for (int i = 0; i < n; i++)
    rank[p[i+1]] = i;

  int k = 0;
  std::vector<int> lcp(n-1, 0);
  for (int i = 0; i < n; i++) {
    if (rank[i] == n - 1) {  // largest suffix: no successor to compare with
      k = 0;
      continue;
    }
    int j = p[rank[i] + 2];
    while (i + k < n && j + k < n && (s[i+k] == s[j+k] && s[i+k] != 1))
      k++;
    lcp[rank[i]] = k;
    if (k)
      k--;  // the next suffix keeps at least k-1 matched characters
  }
  return lcp;
}

// For each test case:
//   1. build the separator-joined string,
//   2. build its suffix array and LCP array,
//   3. for every string suffix, create an edge to the nearest preceding
//      suffix owned by a different string,
//   4. run Kruskal's algorithm (heaviest edge first) and print the maximum
//      spanning tree weight.
// All per-test buffers are std::vector. Stack VLAs of tlen elements
// (pos/last/max) could overflow the stack for total input length around 10^6.
signed main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(0);

  int t;
  std::cin >> t;

  while(t--) {
    int n;
    std::cin >> n;

    std::vector<std::string> s(n);
    std::vector<int> len(n);
    int tlen = 0;
    for(int i = 0; i < n; i++) {
      std::cin >> s[i];
      par[i] = -1;
      len[i] = s[i].length();
      tlen += len[i] + 1;
    }

    tlen--;  // n strings need only n-1 separators

    // cat: letters mapped to 2..27, separator 1 (0 is the sentinel that
    // suffix_array_construction appends). pos[x]: owning string of
    // position x, or -1 for a separator.
    std::string cat(tlen, 'a');
    std::vector<int> pos(tlen);
    int it = 0;

    for(int i = 0; i < n; i++) {
      for(int j = 0; j < len[i]; j++) {
        pos[it] = i;
        cat[it++] = s[i][j]-'a'+2;
      }
      if(i+1 < n) {
        pos[it] = -1;
        cat[it++] = 1;
      }
    }

    std::vector<int> sa = suffix_array_construction(cat);
    std::vector<int> lcp = lcp_construction(cat, sa);

    // sa[0] is the sentinel and the n-1 separator suffixes come next, so the
    // string suffixes are sa[n..tlen]; index i below refers to suffix sa[i+1].
    // prev_other[i]: index of the nearest earlier suffix owned by a different
    // string (-1 if none); best[i]: minimum lcp between the two.
    const int slots = std::max(tlen, n);
    std::vector<int> prev_other(slots, -1), best(slots, 0);
    std::vector<std::pair<int, std::pair<int, int> > > krus;

    for(int i = n; i < tlen; i++) {
      if(pos[sa[i+1]] == pos[sa[i]]) {
        prev_other[i] = prev_other[i-1];
        best[i] = std::min(best[i-1], lcp[i-1]);
      }
      else {
        prev_other[i] = i-1;
        best[i] = lcp[i-1];
      }

      if(prev_other[i] != -1) {
        krus.push_back({best[i], {pos[sa[i+1]], pos[sa[prev_other[i]+1]]}});
      }
    }

    // Kruskal for a maximum spanning tree: process edges by descending weight.
    int ans = 0;
    std::sort(krus.rbegin(), krus.rend());
    for(const auto& e : krus) {
      int x = root(e.second.first), y = root(e.second.second);
      if(x != y) {
        ans += e.first;
        merge(x, y);
      }
    }

    std::cout << ans << '\n';
  }

  return 0;
}

We can find the LCS (longest common substring) of two strings in linear time -

  1. CodeChef: Practical coding for everyone
  2. SSTORY - Editorial

and then simply build the graph and run Kruskal's algorithm with a DSU, and you are done :slight_smile:

2 Likes

Hey, I am trying to understand this O(n) method and I need to spend some time on it. But I have a doubt about whether it would be useful for this problem. Since we need all pairs of strings, won’t this lead to an O(n^2) solution, where n is the sum of the lengths of all strings? Or am I missing something here?

1 Like