BLAZE - Editorial

Topics Required : Suffix Automaton, DSU on trees

First, we build the suffix automaton of the given string s and take its tree of suffix links. Each node (state) of the automaton represents a set of substrings of s. Using a DP over the link tree built above, we can find, for every node, how many times its strings occur in s and which occurrence positions belong to that node.

To find the closest distance for each node (set of strings), we need the minimum distance between the starting points (equivalently, the ending points) of any two occurrences of these strings in s. This can be done with the small-to-large merging / DSU-on-trees technique on the same link tree.

Finally, once we have the closest distance for every node of the suffix automaton, we only need some arithmetic over the nodes to obtain the expected power.

Time complexity: $O(N \log^2 N)$

Solution :

#include "bits/stdc++.h"
#include "ext/pb_ds/assoc_container.hpp"
#include "ext/pb_ds/tree_policy.hpp"
using namespace std;
 
// GCC-specific pragmas: request O3 with loop unrolling and enable x86 SIMD
// instruction sets for the code that follows.
#pragma GCC optimize("O3","unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma GCC target("avx2")
 
// Turns off synchronization with C stdio and unties cin from cout.
// Expands to two statements, so use it only as a stand-alone line.
#define sync ios_base::sync_with_stdio(0); cin.tie(0);
// Expands to the begin/end iterator pair of container x.
#define all(x) x.begin(),x.end()
// Sorts a and drops duplicate elements in place (expands to two statements).
#define unq(a) sort(all(a));a.resize(unique(all(a)) - a.begin())
// Short spellings for the wide arithmetic types.
#define ll long long
#define ld long double
// Shorthands for push_back and for the members of a pair.
#define pb push_back
#define fi first
#define se second
// Makes every `endl` in this file a plain newline character instead of std::endl.
#define endl '\n'
//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Random generator seeded with the constant 1; the time-based seeding is kept
// commented out on the line above.
mt19937 rng(1);
// Pair of two ints.
using pii = pair<int , int>;
 
// Returns the larger of the two ints (y when they are equal).
int inline max(const int x, const int y){
    if (x > y) return x;
    return y;
}
// Returns the smaller of the two ints (y when they are equal).
int inline min(const int x, const int y){
    if (x < y) return x;
    return y;
}
// Returns the absolute value of x.
int inline abs(const int x){
    if (x < 0) return -x;
    return x;
}
 
// inf: sentinel stored in SAM::end_diff when no distance is known yet.
// mod: modulus for the answer arithmetic; inverses are taken as power(x, mod - 2).
const int inf = 1e9, mod = 998244353;
/// Suffix automaton over the symbols 0..25 ('a'..'z'), extended with
/// per-state occurrence statistics.
///
/// Usage: call extend() once per character, then endpos_size() exactly once.
/// Afterwards, for every state i >= 2:
///   - endpos[i]   is the number of end positions (occurrences) of its strings;
///   - end_diff[i] is the smallest distance between two of those end
///                 positions, or kInf when the state occurs only once.
///
/// State 0 is an auxiliary pre-root (len -1) whose 26 transitions all lead to
/// the real root, state 1; this lets the loops in extend() stop without any
/// explicit "fell off the root" checks.
struct SAM{
    /// One automaton state.
    struct node{
        int len,link,endp;  // longest string length, suffix link, 0-based index of
                            // the character that created the state (-1 for the two
                            // initial states and for clones)
        int next[26];       // transitions; 0 means "no transition"
        node():len(0),link(0),endp(-1),next{}{}
    };

    /// Sentinel for "no distance known". An enumerator (not a static data
    /// member) so that it can be passed by reference without an out-of-class
    /// definition.
    enum : int { kInf = 1000000000 };

    int last;                          // state of the whole string read so far
    std::vector<node> t;               // all states
    std::vector<std::vector<int>> e;   // children in the suffix-link tree
    // sz: subtree sizes in the link tree (filled by dfs).
    // end_diff: before endpos_size(), the 1-based end position of a non-clone
    //           state (kInf otherwise); after it, the minimal gap (see above).
    // endpos: before endpos_size(), 1 for non-clone states and 0 otherwise;
    //         after it, the occurrence count of the state.
    std::vector<int> sz, end_diff, endpos;
    // Per-state sets of end positions, in ascending (rht) and descending (lft)
    // order, so that upper_bound gives the right and left neighbour of a value.
    std::vector<std::set<int>> rht;
    std::vector<std::set<int, std::greater<int>>> lft;

    SAM():last(1){
        // State 0: pre-root, every symbol leads to the root.
        t.emplace_back();
        endpos.push_back(0); end_diff.push_back(kInf);
        t[0].len = -1;
        t[0].link = -1;
        for(int i = 0;i < 26; ++i){
            t[0].next[i] = 1;
        }
        // State 1: the root (empty string); its default link is 0.
        t.emplace_back();
        endpos.push_back(0); end_diff.push_back(kInf);
    }

    /// Inserts position x into the sets of state u and lowers end_diff[u] to
    /// the distance between x and its nearest smaller / larger neighbour.
    void merge_position(int u, int x){
        const auto below = lft[u].upper_bound(x);  // largest element < x
        const auto above = rht[u].upper_bound(x);  // smallest element > x
        if (below != lft[u].end()) end_diff[u] = std::min(end_diff[u], x - *below);
        if (above != rht[u].end()) end_diff[u] = std::min(end_diff[u], *above - x);
        lft[u].insert(x);
        rht[u].insert(x);
    }

    /// Post-order walk of the link tree rooted at u. Accumulates endpos and
    /// sz, and merges the children's position sets small-to-large to obtain
    /// end_diff[u]. Recursive: the depth equals the depth of the link tree.
    void dfs(int u){
        sz[u] = 1;
        // Own end position of u; kInf for clones and the root. The sentinel is
        // inserted like a position, which is harmless: any distance involving
        // it is far larger than a distance between two real positions.
        const int w = end_diff[u];
        end_diff[u] = kInf;

        if (e[u].empty()){
            lft[u].insert(w); rht[u].insert(w);
            return;
        }
        for (int v : e[u]){
            dfs(v);
            endpos[u] += endpos[v];
            sz[u] += sz[v];
            end_diff[u] = std::min(end_diff[u], end_diff[v]);
        }
        // Bring the child with the largest subtree to the front. Only the
        // child order changes; sz itself must stay untouched so that the
        // comparison keeps seeing real subtree sizes.
        for (std::size_t i = 1; i < e[u].size(); ++i){
            if (sz[e[u][i]] > sz[e[u][0]]) std::swap(e[u][i], e[u][0]);
        }
        // Adopt the heavy child's sets in O(1) ...
        const int heavy = e[u][0];
        lft[u].swap(lft[heavy]);
        rht[u].swap(rht[heavy]);
        // ... and insert the positions of every light child one by one.
        for (std::size_t i = 1; i < e[u].size(); ++i){
            const int light = e[u][i];
            for (const int x : lft[light]) merge_position(u, x);
            // The light child's sets are not needed any more; release them.
            lft[light].clear();
            rht[light].clear();
        }
        merge_position(u, w);
    }

    /// Appends symbol c (0..25) to the string represented by the automaton.
    void extend(int c){
        int p = last;
        const int cur = static_cast<int>(t.size());
        t.emplace_back();
        endpos.push_back(1);
        end_diff.push_back(t[p].len + 1);  // 1-based end position of the new prefix

        t[cur].len = t[p].len+1;
        t[cur].endp = t[p].len;
        last = cur;
        // Add the missing c-transitions along the suffix links. The walk stops
        // at state 0 at the latest, because all of its transitions are set.
        while(t[p].next[c] == 0){
            t[p].next[c] = cur;
            p = t[p].link;
        }
        const int q = t[p].next[c];
        if(t[q].len == t[p].len+1){
            t[cur].link = q;
            return;
        }
        // q also represents longer strings: split off a clone r of length
        // len(p) + 1 and redirect the affected transitions to it.
        const int r = static_cast<int>(t.size());
        const node copy = t[q];  // copy first: push_back may reallocate t
        t.push_back(copy);
        endpos.push_back(0);     // a clone has no end position of its own
        end_diff.push_back(kInf);

        t[r].len = t[p].len+1;
        t[r].endp = -1;
        t[q].link = r;
        t[cur].link = r;
        while(t[p].next[c] == q){
            t[p].next[c] = r;
            p = t[p].link;
        }
    }

    /// Builds the suffix-link tree and fills endpos / end_diff / sz for all
    /// states. Call once, after the last extend(); it overwrites the
    /// per-state end positions stored in end_diff.
    void endpos_size(){
        e.assign(t.size(), std::vector<int>());
        sz.assign(t.size(), 0);
        lft.clear(); rht.clear();
        lft.resize(t.size());
        rht.resize(t.size());
        for(std::size_t i = 1;i < t.size(); ++i){
            e[t[i].link].push_back(static_cast<int>(i));
        }
        dfs(1);
    }
};
 
// Binary exponentiation: returns x raised to the power y, reduced modulo `mod`.
// A non-positive y yields 1, because the loop body never runs.
ll power(ll x, ll y){
  ll result = 1;
  // `base` holds x^(2^k) mod `mod` for the k-th bit of y being inspected.
  for (ll base = x % mod; y > 0; y >>= 1){
    if (y & 1) result = (result * base) % mod;
    base = (base * base) % mod;
  }
  return result;
}
 
int main(){
 
    #ifndef ONLINE_JUDGE 
        freopen("input.txt", "r", stdin);
        freopen("output.txt", "w", stdout);
    #endif
 
    sync
 
    int t = 1;
    cin >> t;
    while(t--){
      string s;
      cin >> s;
      int n = s.size();
      SAM aum;
      for (char c : s){
        aum.extend(c - 'a');
      }
      aum.endpos_size();
      int expex = 0, nw, inv = power(2, mod - 2);
      for (int szmn, szmx, l, r, i = 2; i < aum.t.size(); i++){
        if (aum.endpos[i] < 2) continue;
        szmn = aum.t[aum.t[i].link].len + 1;
        szmx = aum.t[i].len;
        int mn_df = aum.end_diff[i];
        //cout << " => " << aum.endpos[i] << " " << szmn << " " << szmx << " " << mn_df << endl;
        assert(mn_df >= 1 && mn_df <= n - 1);
        nw = (((szmx * 1ll * (szmx + 1)) % mod - ((szmn - 1) * 1ll * szmn) % mod) * 1ll * inv) % mod;
        nw += (mn_df * 1ll * (szmx - szmn + 1)) % mod;
        nw = (nw * 1ll * aum.endpos[i]) % mod;
        expex += nw; expex %= mod; 
      }
      if (expex < 0) expex += mod;
      expex = (expex * power((n * 1ll * (n + 1)) / 2, mod - 2)) % mod;
      cout << expex << endl;
    }
    cerr << "processor time: " << clock() / (double) CLOCKS_PER_SEC << "s    ";
    return 0;
}