MFSS - Editorial

PROBLEM LINK:

Contest Division 1
Contest Division 2
Contest Division 3
Practice

Setter: Sabbir Rahman (Abir)
Tester: Aryan
Editorialist: Taranpreet Singh

DIFFICULTY

Medium

PREREQUISITES

String Suffix Structures, Segment Tree

PROBLEM

Given an array A of N integers, compute the maximum possible score among all non-empty subarrays of the given array. The score of the subarray A_{l, r} is defined as
score(l, r) = (\sum_{i = l}^r A_i) * (\text{occurrences}), where \text{occurrences} is the number of times the subarray A_{l, r} occurs in A.

QUICK EXPLANATION

  • Build a suffix tree with the given array A as the text. Each path from the root corresponds to a distinct subarray, and the number of leaves in the subtree below it is the number of occurrences of that subarray (assuming the last value is unique in the array).
  • Then do a DFS on this tree, maintaining the sum of values on the path from the root to the current node and the number of leaves in the subtree of each node.
  • Since a single edge can carry multiple values, we need to answer queries q(l, r): the maximum sum of a non-empty prefix of the subarray A_{l, r}. This lets us process all subarrays ending on an edge at once, and it can be done with a segment tree.

EXPLANATION

Intuition

Counting the occurrences of a substring is a well-explored topic in string problems, so let’s treat our array as a string whose characters are integers, one per index.

We need to consider all subarrays of A, so let’s build a trie and insert all suffixes of A into it.

This structure is called a suffix trie, and it is slightly different from a suffix tree.

Throughout the editorial, I’ll use A = [-1,2,2,-1,2,4,-\infty] as the running example. I have appended -\infty at the end, which I’ll explain later.

The trie of all suffixes of this array is as follows. Ignore the node labels, only the edge labels matter; node R is the root.

Claim: Each node in the suffix trie represents a subarray of A, and each subarray of A is represented by exactly one node.
Proof: Each subarray A_{l, r} is a prefix of the l-th suffix of A, so it is spelled by the path from the root along that suffix. Since paths from the root of a trie are unique, different nodes represent different subarrays.

So, for each node, we need to compute the sum of its subarray (the sum of the edge labels on the path from the root to this node) and the number of times it appears in A.

Claim: The number of times a subarray appears in A is exactly the number of leaves in the subtree of its node in the suffix trie.
Proof: The leaves in the subtree of node u are exactly the suffixes of A that have the subarray represented by u as a prefix (thanks to the -\infty at the end, every suffix ends at its own leaf). Each such suffix corresponds to a distinct start position, and each start position gives exactly one occurrence.

For example, in the array above, the subarray [-1, 2] appears twice, the subarray [2] appears three times, and the subarray [2, -1] appears once.

Hence, we can build the suffix trie and, for each node, count the leaves in its subtree. The sum of each subarray can be computed during the DFS or with prefix sums, which solves the problem in time proportional to the number of nodes in the trie.

The number of nodes in a suffix trie can be as large as O(N^2), so this solution is too slow for the final subtask.

Faster intuition

Let cnt_u denote the number of leaves in the subtree of node u in the suffix trie.

Let’s merge every node that has only one child with that child, concatenating their edge labels. We get the following tree, which is the suffix tree of A.

The number of nodes in this tree is O(N): there are at most N+1 leaves (including the one for -\infty), and each internal node has at least two children, so there cannot be more than N internal nodes.

But now there is a new problem. Earlier, we could check all subarrays because each node corresponded to a subarray. Now, some subarrays end in the middle of an edge. For example, the subarrays [2,-1] and [-1,2,2,-1] end within an edge, so they would be missed if we applied the same idea again.

Let’s say a subarray A_{l, r} ends on edge e if, when we walk down the suffix tree from the root following the values A_l, \ldots, A_r, the last edge we traverse is e.

For example, for the subarrays [-1] and [-1,2], the last traversed edge is (0, 5), while for the subarrays [2,-1], [2,-1,2], [2,-1,2,4] and [2,-1,2,4,-\infty], the last traversed edge is the one from the node for [2] to the leaf of the suffix starting at index 3.

Claim: All subarrays ending on the same edge appear the same number of times in A.
Proof: WLOG, assume the edge label contains two values, say a and b. Then the suffix trie must contain nodes u, v and w such that there is an edge from u to v labeled a and an edge from v to w labeled b, and v has no other child.

Since w is the only child of v, cnt_v = cnt_w, so the subarray ending with a appears the same number of times as the subarray ending with b. The same argument applies along a label of any length.

This allows a new approach. For each edge from u to v, all subarrays ending on this edge appear cnt_v times in A, so it only remains to find the maximum sum among the subarrays ending on edge e.

Segment tree Time

Let’s assume we want to compute this for edge e = (u, v). Every subarray ending on e must have the values on the path from the root to u as its prefix, since only then can we reach this edge.

Considering the subarrays A_{l, r} ending on edge e for a fixed start index l (any one of the cnt_v occurrences works), we can see that their right ends form a consecutive segment.

For example, for edge (3, 4), the subarrays ending on that edge are A_{2,3}, A_{2,4}, A_{2,5}, A_{2,6} (plus A_{2,7}, which contains -\infty and never matters), and each of them appears cnt_v = 1 time in A.

Claim: The right ends of the subarrays ending on the same edge are consecutive values in some range.
Proof: Fix a start index l whose suffix lies in the subtree of v, and let |x| denote the number of values on the path from the root to node x. The subarrays starting at l that end on e are exactly those whose length lies in (|u|, |v|], that is, A_{l, r'} for l + |u| \le r' \le l + |v| - 1.

For edge (3, 4), the start index is l = 2, and the right ends of the subarrays form the range [3, 6] (ignoring position 7, which holds -\infty).

Let’s say that, for edge e, this range is [l', r]. This means that every subarray ending on edge e consists of the prefix A_{l, l'-1} followed by some non-empty prefix of the subarray A_{l', r}.

For edge (3, 4), l' = 3, so each subarray consists of A_{2,2} followed by some non-empty prefix of A_{3,6}.

Writing S_{l, r} = \sum_{i = l}^{r} A_i, for edge e we need to compute \displaystyle \max_{i = l'}^{r} S_{l, i} = S_{l, l'-1} + \max_{i = l'}^{r} S_{l', i}. The term S_{l, l'-1} can be computed during the DFS or with prefix sums.
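Plugging in the running example for edge (3, 4), with l = 2, l' = 3 and r = 6 (1-indexed, A = [-1,2,2,-1,2,4,-\infty]):

```latex
\max_{i=3}^{6} S_{2,i} = S_{2,2} + \max_{i=3}^{6} S_{3,i}
                       = 2 + \max(2,\ 1,\ 3,\ 7) = 9
```

Since cnt_v = 1 for this edge, its candidate score is 9. This is also the overall answer for [-1,2,2,-1,2,4]: the only repeated subarrays are [-1], [2] and [-1,2], with scores -2, 6 and 2, and no subarray has a sum above 9.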

Now we need to answer queries of the form: for an interval [l, r], find \max_{i = l}^{r} S_{l, i}, the largest non-empty prefix sum.

This is where the segment tree comes in. Suppose node x of the segment tree stores information about the interval [L_x, R_x]: max_x = \displaystyle\max_{i = L_x}^{R_x} S_{L_x, i}, the largest non-empty prefix sum in the range, and sum_x = \displaystyle\sum_{i = L_x}^{R_x} A_i.

If node x is the parent of nodes 2x and 2x+1 in the segment tree, then sum_x = sum_{2x} + sum_{2x+1} and max_x = \max(max_{2x}, sum_{2x} + max_{2x+1}).

Exercise: prove why the above works, given that every leaf node x has sum_x = max_x = A_{L_x}. (Hint: the best non-empty prefix of [L_x, R_x] either ends inside the range of node 2x, or covers that range completely and continues with a non-empty prefix of the range of node 2x+1.)

Concluding

Hence, the segment tree lets us compute the largest sum among the subarrays ending on each edge, and we also know how many times the subarrays ending on that edge occur. We multiply the two and take the maximum over all edges.

Why -\infty was added at the end

Since -\infty = -10^{13} does not appear in the array, appending it ensures that every suffix ends at an explicit leaf. For example, in the suffix tree of the string abcab, the suffixes ab and b do not get leaves of their own (they end in the middle of an edge, i.e. their leaves are implicit), which would not allow us to proceed as described.

Any subarray that includes this value occurs only once and has a hugely negative sum, so it never affects the maximum answer.

Resources

Suffix Tries and Suffix Trees
Suffix Tree Implementation

TIME COMPLEXITY

The time complexity is O(N*log(N)) per test case.

SOLUTIONS

Setter's Solution
#include <bits/stdc++.h>
using namespace std;

using ll = long long int;
mt19937 rng((unsigned)chrono::system_clock::now().time_since_epoch().count());

const int nmax = 2e5+10;

namespace SA {

    int rnk[nmax], suf[nmax];
    int key[nmax], grp[nmax];

    void quick_split(int L, int R)		//3-way quicksorts L ... R of suf
    {
        if((grp[L] = R) == L) return;		//base case, single element
        int l = L, r = R, p = key[suf[L + rng()%(R-L+1)]];	//pivot key

        for(int i = l; i<= r; i++)
            if(key[suf[i]] < p) swap(suf[i], suf[l++]);
            else if(key[suf[i]] > p) swap(suf[i--], suf[r--]);

        for(int i = l; i<=r; i++) grp[i] = r;		//updating group for pivots
        if(l > L) quick_split(L, l-1);
        if(r < R) quick_split(r+1, R);
    }

    void Build_Suf(ll str[], int n)	//str is string, n is length
    {
        for(int i = 0; i<n; i++) suf[i] = i, key[i] = str[i], grp[i] = n-1;

        for(int len = 1; len<n+n; len<<=1)				//prefix doubling loop
        {
            for(int i = 0, g = 0; i<n; i = g + 1)	//quick sort grouped subarrays
                if((g = grp[i]) != i) quick_split(i, g);

            for(int i = 0; i<n; i++) rnk[suf[i]] = i;				//recreate ranks
            for(int i = 0; i<n-len; i++) key[i] = grp[rnk[i+len]];	//find keys
            if(n >= len)
                key[n-len] = -INT_MAX;			//key of suffix n is lower than all other
                //for(int i = n-len; i<n; i++) key[i] = grp[rnk[i+len-n]];	//cyclic shift ver.
        }
    }

    //sparse table arrays here
    int lcp[nmax];
    void Kasai(ll str[], int n){		//call with string and strlen
        int prv = 0;
        for(int i = 0; i<n; i++){
            if(rnk[i] == 0) {prv = lcp[0] = 0; continue;}

            int j = suf[rnk[i]-1];
            while(i+prv<n && j+prv<n && str[i+prv] == str[j+prv]) prv++;
            lcp[rnk[i]] = prv;
            if(prv > 0) prv--;
        }
    }


    vector<int> adj[2*nmax];		//adjacency list
    int lb[2*nmax], rb[2*nmax], plen[2*nmax], nodecnt = 0;	//info about nodes
    //call with length n
    void SuffixTree(int n){
        vector<int> stk{++nodecnt};
        lb[nodecnt] = 0, rb[nodecnt] = n-1, plen[nodecnt] = 0, lcp[n] = 0;
        int last = -1;
        for(int i = 0, sf = 1; i+sf<=n; i+=sf, sf^=1){	//sf = suf len or lcp is being used
            int left = i-(sf^1), curlcp = (sf)? n-suf[i]: lcp[i];

            while(curlcp < plen[stk.back()]){
                rb[stk.back()] = i-(sf^1), left = lb[stk.back()];
                last = stk.back(), stk.pop_back();
                if(curlcp <= plen[stk.back()])
                    adj[stk.back()].push_back(last), last = -1;
            }
            if(curlcp > plen[stk.back()]){
                stk.push_back(++nodecnt);
                if(last != -1)
                    adj[nodecnt].push_back(last), last = -1;
                plen[nodecnt] = curlcp, lb[nodecnt] = left;
            }
        }
    }
    //check suffix node -> (adj[node].empty() || lb[node] != lb[adj[node][0]])
}

ll input[nmax];

const ll inf = 1e18;

namespace ST{
    const int lgn = 32 - __builtin_clz(nmax);
    ll sparse[nmax][lgn];

    void BuildSparse(int n)
    {
        for(int i = 0; i<n; i++) sparse[i][0] = input[i];		//changed if 1-indexed
        for(int k = 1, p = 1; k<lgn; k++, p <<= 1)
            for(int i = 0; i+p+p <= n; i++)					//changed if 1-indexed
                sparse[i][k] = max(sparse[i][k-1], sparse[i+p][k-1]);
    }

    ll query(int l, int r)
    {
        int len = r - l + 1;
        int k = 32 - __builtin_clz(len) - 1;
        return max(sparse[l][k], sparse[r-(1<<k)+1][k]);
    }
}

int n;

ll ans;

void dfs(int u){


    for(int v : SA::adj[u]){

        int st = SA::suf[SA::lb[v]];
        ll mss = ST::query(st+SA::plen[u], st+SA::plen[v]-1);

        mss -= (st == 0? 0 : input[st-1]);

        ll cand = mss * (SA::rb[v] - SA::lb[v]+1);
        if(cand > ans){
            ans = cand;
        }

        dfs(v);
    }

}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);


    int tc;
    cin>>tc;

    for(int cs = 1; cs <= tc; cs++){
        cin>>n;

        for(int i = 0; i<n; i++){
            cin>>input[i];
        }

        SA::Build_Suf(input, n);
        SA::Kasai(input, n);
        SA::SuffixTree(n);

        partial_sum(input, input+n, input);
        ST::BuildSparse(n);

        ans = -inf;
        dfs(1);

        cout<<ans<<"\n";


        for(int i = 1; i<=SA::nodecnt; i++){
            SA::adj[i].clear();
        }
        SA::nodecnt = 0;
    }

    return 0;
}
Tester's Solution
/* in the name of Anton */

/*
  Compete against Yourself.
  Author - Aryan (@aryanc403)
  Atcoder library - https://atcoder.github.io/ac-library/production/document_en/
*/

#ifdef ARYANC403
    #include <header.h>
#else
    #pragma GCC optimize ("Ofast")
    #pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx")
    //#pragma GCC optimize ("-ffloat-store")
    #include<bits/stdc++.h>
    #define dbg(args...) 42;
#endif

using namespace std;
#define fo(i,n)   for(i=0;i<(n);++i)
#define repA(i,j,n)   for(i=(j);i<=(n);++i)
#define repD(i,j,n)   for(i=(j);i>=(n);--i)
#define all(x) begin(x), end(x)
#define sz(x) ((lli)(x).size())
#define pb push_back
#define mp make_pair
#define X first
#define Y second
#define endl "\n"

typedef long long int lli;
typedef long double mytype;
typedef pair<lli,lli> ii;
typedef vector<ii> vii;
typedef vector<lli> vi;

const auto start_time = std::chrono::high_resolution_clock::now();
void aryanc403()
{
#ifdef ARYANC403
auto end_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end_time-start_time;
    cerr<<"Time Taken : "<<diff.count()<<"\n";
#endif
}

long long readInt(long long l, long long r, char endd) {
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true) {
        char g=getchar();
        if(g=='-') {
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g&&g<='9') {
            x*=10;
            x+=g-'0';
            if(cnt==0) {
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd) {
            if(is_neg) {
                x=-x;
            }
            assert(l<=x&&x<=r);
            return x;
        } else {
            assert(false);
        }
    }
}
string readString(int l, int r, char endd) {
    string ret="";
    int cnt=0;
    while(true) {
        char g=getchar();
        assert(g!=-1);
        if(g==endd) {
            break;
        }
        cnt++;
        ret+=g;
    }
    assert(l<=cnt&&cnt<=r);
    return ret;
}
long long readIntSp(long long l, long long r) {
    return readInt(l,r,' ');
}
long long readIntLn(long long l, long long r) {
    return readInt(l,r,'\n');
}
string readStringLn(int l, int r) {
    return readString(l,r,'\n');
}
string readStringSp(int l, int r) {
    return readString(l,r,' ');
}

void readEOF(){
    assert(getchar()==EOF);
}

vi readVectorInt(int n,lli l,lli r){
    vi a(n);
    for(int i=0;i<n-1;++i)
        a[i]=readIntSp(l,r);
    a[n-1]=readIntLn(l,r);
    return a;
}

const lli INF = 0xFFFFFFFFFFFFFL;

lli seed;
mt19937 rng(seed=chrono::steady_clock::now().time_since_epoch().count());
inline lli rnd(lli l=0,lli r=INF)
{return uniform_int_distribution<lli>(l,r)(rng);}

class CMP
{public:
bool operator()(ii a , ii b) //For min priority_queue .
{    return ! ( a.X < b.X || ( a.X==b.X && a.Y <= b.Y ));   }};

void add( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt==m.end())         m.insert({x,cnt});
    else                    jt->Y+=cnt;
}

void del( map<lli,lli> &m, lli x,lli cnt=1)
{
    auto jt=m.find(x);
    if(jt->Y<=cnt)            m.erase(jt);
    else                      jt->Y-=cnt;
}

bool cmp(const ii &a,const ii &b)
{
    return a.X<b.X||(a.X==b.X&&a.Y<b.Y);
}

const lli mod = 1000000007L;
// const lli maxN = 1000000007L;

#include <algorithm>
#include <cassert>
#include <vector>


#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace atcoder {

namespace internal {

int ceil_pow2(int n) {
    int x = 0;
    while ((1U << x) < (unsigned int)(n)) x++;
    return x;
}

int bsf(unsigned int n) {
#ifdef _MSC_VER
    unsigned long index;
    _BitScanForward(&index, n);
    return index;
#else
    return __builtin_ctz(n);
#endif
}

}  // namespace internal

}  // namespace atcoder


namespace atcoder {

template <class S, S (*op)(S, S), S (*e)()> struct segtree {
  public:
    segtree() : segtree(0) {}
    segtree(int n) : segtree(std::vector<S>(n, e())) {}
    segtree(const std::vector<S>& v) : _n(int(v.size())) {
        log = internal::ceil_pow2(_n);
        size = 1 << log;
        d = std::vector<S>(2 * size, e());
        for (int i = 0; i < _n; i++) d[size + i] = v[i];
        for (int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }

    void set(int p, S x) {
        assert(0 <= p && p < _n);
        p += size;
        d[p] = x;
        for (int i = 1; i <= log; i++) update(p >> i);
    }

    S get(int p) {
        assert(0 <= p && p < _n);
        return d[p + size];
    }

    S prod(int l, int r) {
        assert(0 <= l && l <= r && r <= _n);
        S sml = e(), smr = e();
        l += size;
        r += size;

        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }
        return op(sml, smr);
    }

    S all_prod() { return d[1]; }

    template <bool (*f)(S)> int max_right(int l) {
        return max_right(l, [](S x) { return f(x); });
    }
    template <class F> int max_right(int l, F f) {
        assert(0 <= l && l <= _n);
        assert(f(e()));
        if (l == _n) return _n;
        l += size;
        S sm = e();
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(op(sm, d[l]))) {
                while (l < size) {
                    l = (2 * l);
                    if (f(op(sm, d[l]))) {
                        sm = op(sm, d[l]);
                        l++;
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l]);
            l++;
        } while ((l & -l) != l);
        return _n;
    }

    template <bool (*f)(S)> int min_left(int r) {
        return min_left(r, [](S x) { return f(x); });
    }
    template <class F> int min_left(int r, F f) {
        assert(0 <= r && r <= _n);
        assert(f(e()));
        if (r == 0) return 0;
        r += size;
        S sm = e();
        do {
            r--;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(op(d[r], sm))) {
                while (r < size) {
                    r = (2 * r + 1);
                    if (f(op(d[r], sm))) {
                        sm = op(d[r], sm);
                        r--;
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while ((r & -r) != r);
        return 0;
    }

  private:
    int _n, size, log;
    std::vector<S> d;

    void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
};

}  // namespace atcoder

using namespace atcoder;

ii op(ii a, ii b) { return {a.X+b.X,max(a.Y, a.X+b.Y)}; }
ii e() { return {0,-INF}; }
segtree<ii, op, e> seg;

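// Sentinel appended to the array. -10^13 is far below any possible answer
// (which is at least -10^7 * 10^5 = -10^12); -10^9 was not low enough,
// see the discussion below the editorial.
const lli inf = -10000000000000LL;
vi s;
int n;

struct node {
    int l, r, par, link;
    map<lli,int> next;   // keyed by array value (lli, so the sentinel fits)

    node (int l=0, int r=0, int par=-1)
        : l(l), r(r), par(par), link(-1) {}
    int len()  {  return r - l;  }
    int &get (lli c) {
        if (!next.count(c))  next[c] = -1;
        return next[c];
    }
};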

vector<node> t;
int SZ;

struct state {
    int v, pos;
    state (int v, int pos) : v(v), pos(pos)  {}
};

state ptr (0, 0);

state go (state st, int l, int r) {
    while (l < r)
        if (st.pos == t[st.v].len()) {
            st = state (t[st.v].get( s[l] ), 0);
            if (st.v == -1)  return st;
        }
        else {
            if (s[ t[st.v].l + st.pos ] != s[l])
                return state (-1, -1);
            if (r-l < t[st.v].len() - st.pos)
                return state (st.v, st.pos + r-l);
            l += t[st.v].len() - st.pos;
            st.pos = t[st.v].len();
        }
    return st;
}

int split (state st) {
    if (st.pos == t[st.v].len())
        return st.v;
    if (st.pos == 0)
        return t[st.v].par;
    node v = t[st.v];
    int id = SZ++;
    t[id] = node (v.l, v.l+st.pos, v.par);
    t[v.par].get( s[v.l] ) = id;
    t[id].get( s[v.l+st.pos] ) = st.v;
    t[st.v].par = id;
    t[st.v].l += st.pos;
    return id;
}

int get_link (int v) {
    if (t[v].link != -1)  return t[v].link;
    if (t[v].par == -1)  return 0;
    int to = get_link (t[v].par);
    return t[v].link = split (go (state(to,t[to].len()), t[v].l + (t[v].par==0), t[v].r));
}

void tree_extend (int pos) {
    for(;;) {
        state nptr = go (ptr, pos, pos+1);
        if (nptr.v != -1) {
            ptr = nptr;
            return;
        }

        int mid = split (ptr);
        int leaf = SZ++;
        t[leaf] = node (pos, n, mid);
        t[mid].get( s[pos] ) = leaf;

        ptr.v = get_link (mid);
        ptr.pos = t[ptr.v].len();
        if (!mid)  break;
    }
}

void build_tree() {
    SZ = 1;
    for (int i=0; i<n; ++i)
        tree_extend (i);
}

    //priority_queue < ii , vector < ii > , CMP > pq;// min priority_queue .

lli ans;

lli dfs(lli u,lli sum){
    if(u==-1)
        return 0;
    lli cnt=0;
    if(t[u].r==n)
        cnt++;
    auto pd=seg.prod(t[u].l,t[u].r);
    pd.X+=sum;
    pd.Y+=sum;
    for(auto &cld:t[u].next){
        cnt+=dfs(cld.Y,pd.X);
    }

    if(u!=0)    // the root has an empty range, so pd.Y = -INF and cnt*pd.Y could overflow
        ans=max(ans,cnt*pd.Y);
    return cnt;
}

int main(void) {
    ios_base::sync_with_stdio(false);cin.tie(NULL);
    // freopen("txt.in", "r", stdin);
    // freopen("txt.out", "w", stdout);
// cout<<std::fixed<<std::setprecision(35);
lli T=readIntLn(1,2e5);
lli sumn=0;
while(T--)
{

    n=readIntLn(1,1e5);
    sumn+=n;
    assert(sumn<=3e5);
    s=readVectorInt(n,-1e7,1e7);
    s.pb(inf);
    n++;
    t.clear();t.resize(2*n+5);
    SZ=0;
    ptr=state(0,0);
    build_tree();
    ans=-INF;

    seg=segtree<ii, op, e>(n);

    for(int i=0;i<n;++i)
        seg.set(i,{s[i],s[i]});

    dfs(0,0);
    cout<<ans<<endl;
}   aryanc403();
    readEOF();
    return 0;
}
Editorialist's Solution
import java.util.*;
import java.io.*;
class MFSS{
    //SOLUTION BEGIN
    long INFIN = (long)1e13;
    int m;
    long[] sum, best;
    void pre() throws Exception{}
    void solve(int TC) throws Exception{
        int N = in.nextInt();
        long[] A = new long[N+1]; 
        for(int i = 0; i< N; i++)A[i] = in.nextLong();
        A[N] = -INFIN;
        SuffixTree st = new SuffixTree(1+N);
        for(int i = 0; i<= N; i++)st.addChar(A[i]);
        m = 1;
        while(m < A.length)m<<=1;
        sum = new long[m<<1];
        best = new long[m<<1];
        for(int i = 0; i< A.length; i++){
            sum[i+m] = A[i];
            best[i+m] = A[i];
        }
        for(int i = A.length; i< m; i++){
            sum[i+m] = -INFIN;
            best[i+m] = -INFIN;
        }
        for(int i = m-1; i>= 0; i--){
            sum[i] = sum[i<<1]+sum[i<<1|1];
            best[i] = Math.max(best[i<<1], sum[i<<1]+best[i<<1|1]);
        }
        st.dfs();
        pn(st.ans);
    }
    long[] merge(long[] a, long[] b){
        return new long[]{a[0]+b[0], Math.max(a[1], a[0]+b[1])};
    }
    long[] query(int l, int r){
        long[] le = new long[]{0, -INFIN}, ri = new long[]{0, -INFIN};
        for(l += m, r += m+1; l< r; l>>=1, r>>=1){
            if((l&1)==1){
                le = merge(le, new long[]{sum[l], best[l]});
                l++;
            }
            if((r&1) == 1){
                r--;
                ri = merge(new long[]{sum[r], best[r]}, ri);
            }
        }
        return merge(le, ri);
    }
    class SuffixTree {
        int INF = Integer.MAX_VALUE/4;
        long[] cur;
        int curLength;
        State active;
        Node root, sentinal;
        long ans = Long.MIN_VALUE;
        class Node{
            int le, ri;
            Node link;
            TreeMap<Long, Node> nxt;
            public Node(int l, int r){
                le = l;ri = r;
                link = null;
                nxt = new TreeMap<>();
            }
            int edgeLength(){return Math.min(ri, curLength-1)-le+1;}
            Node get(long ch){if(this == sentinal)return root;return nxt.getOrDefault(ch, null);}
            void put(long ch, Node node){nxt.put(ch, node);}
            public String toString(){
                StringBuilder a = new StringBuilder("");
                a.append("[le="+le+",ri="+ri+"]");
                return a.toString();
            }
        }
        class State{
            Node node;
            int K;
            public State(Node node, int le){
                this.node = node;
                this.K = le;
            }
            @Override
            public String toString(){
                return "[node="+node.toString()+",k="+K+"]";
            }
        }
        Object[] testAndSplit(State st, int P, long ch){
            if(st.K <= P){
                Node nxt = st.node.get(cur[st.K]);
                if(ch == cur[nxt.le+P-st.K+1])return new Object[]{true, st.node};
                else {
                    Node newNode = new Node(nxt.le, nxt.le+P-st.K);
                    st.node.put(cur[st.K], newNode);
                    newNode.put(cur[nxt.le+P-st.K+1], nxt);
                    nxt.le = nxt.le+P-st.K+1;
                    return new Object[]{false, newNode};
                }
            }else{
                if(st.node.get(ch) == null)return new Object[]{false, st.node};
                else return new Object[]{true, st.node};
            }
        }
        State canonize(State st, int p){
            if(st.node == sentinal){
                st.node = root;
                st.K++;
            }
            Node s = st.node;
            int k = st.K;
            if(p < k)return new State(s, k);
            Node s1 = s.get(cur[st.K]);
            int k1 = s1.le, p1 = s1.ri;
            while(p1-k1 <= p-k){
                k += p1-k1+1;
                s = s1;
                if(k <= p){
                    s1 = s.get(cur[k]);
                    k1 = s1.le;
                    p1 = s1.ri;
                }
            }
            return new State(s, k);
        }
        State update(State active, int i, long ch){
            Node oldr = root;
            Object[] end = testAndSplit(active, i-1, ch);
            boolean endpoint = (boolean)end[0];
            Node r = (Node)end[1];
            while(!endpoint){
                r.put(ch, new Node(i, INF));
                if(oldr != root)oldr.link = r;
                oldr = r;
                active = canonize(new State(active.node.link, active.K), i-1);
                end = testAndSplit(active, i-1, ch);
                endpoint = (boolean)end[0];
                r = (Node)end[1];
            }
            if(oldr != root)oldr.link = active.node;
//            if(oldr != (Node)end[1])oldr.link = active.node;
            return active;
        }
        public SuffixTree(int N){
            sentinal = new Node(0, -1);
            root = new Node(0, -1);
            root.link = sentinal;
            cur = new long[N];
            curLength = 0;
            active = new State(root, 0);
        }
        
        void addChar(long val){
            cur[curLength++] = val;
            active = update(active, curLength-1, val);
            active = canonize(active, curLength-1);
        }
        public void dfs(){
            dfs(root, 0);
        }
        private int dfs(Node node, long prefSum){
            int le = node.le, ri = Math.min(curLength-1, node.ri);
            long[] pair = query(le, ri);
            long sum = pair[0], best = pair[1];
            int nodeCount = node.ri > curLength?1:0;
            for(Node nxt:node.nxt.values())
                nodeCount += dfs(nxt, prefSum+sum);
            ans = Math.max(ans, (prefSum + best)*nodeCount);
            return nodeCount;
        }
        public void printTree(){
            printTree(root, "");
        }
        private void printTree(Node node, String dash){
            StringBuilder tmp = new StringBuilder();
            for(int i = node.le; i <= Math.min(node.ri, curLength-1); i++)tmp.append(cur[i]+" ");
            System.out.println(dash+tmp.toString()+"\t\t"+node.link);
//            System.out.println(dash+cur.substring(node.le,Math.min(cur.length()-1, node.ri)+1)+"\t\t"+node.link);
            for(Node nxt:node.nxt.values())printTree(nxt, dash+"\t");
        }
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    static boolean multipleTC = true;
    FastScanner in;PrintWriter out;
    void run() throws Exception{
        in = new FastScanner();//"in.txt");
        out = new PrintWriter(System.out);
        //Solution Credits: Taranpreet Singh
        int T = (multipleTC)?in.nextInt():1;
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        // new MFSS().run();
        new Thread(null, new Runnable() {public void run(){try{new MFSS().run();}catch(Exception e){e.printStackTrace();System.exit(1);}}}, "1", 1 << 26).start();
    }
    void pn(Object o){out.println(o);}
    class FastScanner implements AutoCloseable {
        private final java.io.InputStream in;
        private final byte[] buf = new byte[2048];
        private int ptr = 0;
        private int buflen = 0;

        public FastScanner(java.io.InputStream in) {
            this.in = in;
        }

        public FastScanner() {
            this(System.in);
        }

        private boolean hasNextByte() {
            if (ptr < buflen) return true;
            ptr = 0;
            try {
                buflen = in.read(buf);
            } catch (java.io.IOException e) {
                throw new RuntimeException(e);
            }
            return buflen > 0;
        }

        private int readByte() {
            return hasNextByte() ? buf[ptr++] : -1;
        }

        public boolean hasNext() {
            while (hasNextByte() && !(32 < buf[ptr] && buf[ptr] < 127)) ptr++;
            return hasNextByte();
        }

        private StringBuilder nextSequence() {
            if (!hasNext()) throw new java.util.NoSuchElementException();
            StringBuilder sb = new StringBuilder();
            for (int b = readByte(); 32 < b && b < 127; b = readByte()) {
                sb.appendCodePoint(b);
            }
            return sb;
        }

        public String next() {
            return nextSequence().toString();
        }

        public String next(int len) {
            return new String(nextChars(len));
        }

        public char nextChar() {
            if (!hasNextByte()) throw new java.util.NoSuchElementException();
            return (char) readByte();
        }

        public char[] nextChars() {
            StringBuilder sb = nextSequence();
            int l = sb.length();
            char[] dst = new char[l];
            sb.getChars(0, l, dst, 0);
            return dst;
        }
        public char[] nextChars(int len) {
            if (!hasNext()) throw new java.util.NoSuchElementException();
            char[] s = new char[len];
            int i = 0;
            int b = readByte();
            while (32 < b && b < 127 && i < len) {
                s[i++] = (char) b; b = readByte();
            }
            if (i != len) {
                throw new java.util.NoSuchElementException(
                    String.format("Next token has smaller length than expected.", len)
                );
            }
            return s;
        }
        public long nextLong() {
            if (!hasNext()) throw new java.util.NoSuchElementException();
            long n = 0;
            boolean minus = false;
            int b = readByte();
            if (b == '-') {
                minus = true;
                b = readByte();
            }
            if (b < '0' || '9' < b) throw new NumberFormatException();
            while (true) {
                if ('0' <= b && b <= '9') {
                    n = n * 10 + b - '0';
                } else if (b == -1 || !(32 < b && b < 127)) {
                    return minus ? -n : n;
                } else throw new NumberFormatException();
                b = readByte();
            }
        }
        public int nextInt() {
            return Math.toIntExact(nextLong());
        }
        public double nextDouble() {
            return Double.parseDouble(next());
        }
        public void close() {
            try {
                in.close();
            } catch (java.io.IOException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

Feel free to share your approach. Suggestions are welcomed as always. :slight_smile:


@taran_1407 is there any way to do this with a polynomial rolling hash for 100 points?


I doubt it. Hashing is just a way to compare strings; we’d still need to consider each subarray once, unless hashing is combined with something else.


Is -10^9 really low enough? Consider the subarray -10^7 repeated 10^5 times. If I understand correctly the maximum sum would be -10^7 * 10^5 = -10^12, lower than -10^9 and causing you trouble.


Thanks for noticing, I have notified the team and updated editorial as well. Sorry for the mistake.

This was a nice problem! Here’s another solution that bears similarity to the maximum histogram problem. After we build the suffix array and the LCP array on the string, we note that our desired answer is the area of a rectangle in the histogram formed by the LCP array, with the caveat that we can have negative height rectangles since our “height” is \sum A_i. We use a stack to maintain heights just like in the canonical maximum histogram algorithm, and when we process the base of a rectangle, the optimal height can be any amount such that it remains taller than the two bars on either side. So in the diagram below, the red rectangle is permitted to be any height in the range indicated by green, but if it gets any shorter it’ll start including the bars on the left and right and become wider.


The optimal height can be determined with a segment tree that queries for the maximum prefix of a range. So this solution also works in \mathcal O(N \log N).
Submission for reference


Can’t we just maintain max sum possible from root to node and multiply it with occurrences while doing dfs on suffix tree?