Author: rivalq
Tester: rivalq

MEDIUM

# PREREQUISITES:

String suffix structures (suffix array with LCP array), disjoint set union (DSU).

# PROBLEM:

We are given a string S of length N. Count the number of unordered pairs of disjoint substrings P, Q of S such that:

• Both P and Q have the same length.

• P + Q is a palindrome, where + denotes the concatenation operator.

# QUICK EXPLANATION:

For each prefix p_i and suffix s_j of S with j \gt i, compute the length of the longest common prefix of rev(p_i) and s_j; the answer is the sum of these lengths. We can first ignore the condition j \gt i and compute the full sum with a suffix array. Overlapping (non-disjoint) pairs are then removed by subtracting the contribution of every palindrome.

# EXPLANATION:

Let T = rev(S). S_i and T_i denote the i^{th} suffix of S and T, respectively.

Let Count denote the number of pairs obtained if we ignore both the disjointness condition and the unordered-pair condition (i.e. we count ordered pairs (P, Q) that may overlap).
Let ans denote the answer to the original problem.

Count = 2 \cdot ans + (\text{number of non-disjoint ordered pairs})

Count = \sum_{i=1}^{n} \sum_{j=1}^{n} LCP(S_i, T_j)

where LCP(S_i,T_j) denotes the length of longest common prefix of S_i and T_j.

The above summation can be easily calculated using a suffix array.

Calculation of Count

Create the string Z = S + separator + T. The code below uses '#' as the separator and also appends a unique terminator '?'. Build the suffix array and the LCP array of Z.

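Let C_i denote the position (rank) of suffix i in the suffix array, and let LCP(r-1, r) be the longest common prefix of the suffixes at ranks r-1 and r. For two suffixes i and j with C_i \lt C_j, their LCP is \min(LCP(C_i, C_i+1), LCP(C_i+1, C_i+2), \dots, LCP(C_j-1, C_j)).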

There is a common trick to handle summation involving LCP.
Consider the suffix array as a chain. Consecutive suffixes in the suffix array are nodes, joined by an edge whose weight is their LCP. We need the sum, over all pairs consisting of one suffix of S and one suffix of T, of the minimum edge weight on the path between them. Sort the edges by weight in non-increasing order and add them to a DSU one by one. Because of the sorting, the current edge has the minimum weight on every path it completes, so its weight is the LCP of every pair of suffixes it newly connects.

Say we are connecting two components u and v with an edge of weight w. Let S_u and S_v be the numbers of suffixes of S in u and v, and let T_u and T_v be the numbers of suffixes of T in u and v. Then we add w \cdot (S_u \cdot T_v + T_u \cdot S_v) to Count.

Hence Count can be computed in O(N \log N).

A non-disjoint pair can occur only inside a palindrome: if P and Q overlap and P = rev(Q), their union is a palindrome.
For every center, compute the length of the longest palindrome around it using Manacher's algorithm (or binary search with hashing), and subtract its contribution from Count.

Let L_c be the length of the longest palindrome centered at c, and let k_c = \left\lceil{\frac{L_c}{2}}\right\rceil be the number of palindromes centered at c.

The j^{th} of these palindromes (j = 1, \dots, k_c) is the union of exactly 2j - 1 overlapping ordered pairs (P, Q) with P = rev(Q). So center c contributes 1 + 3 + \dots + (2k_c - 1) = k_c^2 non-disjoint cases, and for each c we subtract k_c^2 from Count. Finally, divide the result by 2 to get ans.

The setter's solution uses Manacher's algorithm to find the palindromes. Any algorithm that finds the longest palindrome around every center works just as well.

# SOLUTIONS:

Setter's Solution
// Jai shree ram

#include<bits/stdc++.h>
using namespace std;

#define rep(i,a,n)     for(int i=a;i<n;i++)
#define ll             long long
#define int            long long
#define all(v)         v.begin(),v.end()
#define endl           "\n"
#define x              first
#define y              second
#define pii            pair<int,int>

// Disjoint-set union over suffix-array positions, used by solve().
//   p[]   : parent pointer (a root points to itself)
//   sz[]  : component size, used for union by size
//   sz1[] : per-component counter; solve() sets it to 1 at positions where a suffix of S starts
//   sz2[] : per-component counter; solve() sets it to 1 at positions where a suffix of rev(S) starts
// sz1/sz2 are only meaningful at roots; merge() folds them into the new root.
const int maxn = 3e5+5;
int p[maxn];
int sz[maxn];
int sz1[maxn];
int sz2[maxn];

// Turns slots [0, n) (clamped to maxn) into singleton components with zero counters.
// Only the slots a test case actually uses are reset, so many small test cases
// no longer pay O(maxn) each. The default argument still resets everything.
void clear(int n=maxn){
    const int lim = n < maxn ? n : maxn;
    for(int i=0;i<lim;i++){
        p[i]=i;
        sz[i]=1;
        sz1[i]=0;
        sz2[i]=0;
    }
}

// Returns the representative of node's component, halving the path as it walks.
int root(int node){
    for(; p[node]!=node; node=p[node]){
        p[node]=p[p[node]];  // path halving: point at the grandparent
    }
    return node;
}

// Unites the components of a and b (no-op if already joined), attaching the
// smaller tree under the larger one and accumulating the counters.
void merge(int a,int b){
    int ra = root(a);
    int rb = root(b);
    if(ra == rb) return;
    if(sz[ra] < sz[rb]) swap(ra, rb);
    p[rb] = ra;
    sz[ra] += sz[rb];
    sz1[ra] += sz1[rb];
    sz2[ra] += sz2[rb];
}

// Suffix array built by prefix doubling over cyclic shifts, plus Kasai's LCP array.
//
// Precondition: the last character of `str` occurs nowhere else in it
// (solve() appends '?'). Then all cyclic shifts are distinct and their sorted
// order coincides with the lexicographic order of the suffixes. The unique
// character does NOT have to be the smallest one.
//
// After build():
//   p[r]   = start index of the r-th smallest suffix
//   c[i]   = rank of the suffix starting at i (inverse permutation of p)
//   lcp[r] = longest common prefix of suffixes p[r-1] and p[r]; lcp[0] = 0
//   len[r] = length of suffix p[r]
// T is the string type; U is unused and kept only for interface compatibility.
template<typename T,typename U>
struct suffix_array{
    T s;
    vector<int>p,c;
    vector<int>lcp,len;
    int n;
    suffix_array(T str){
        s=str;
        n=s.size();
        p.resize(n);
        c.resize(n);
        lcp.resize(n);
        len.resize(n);
    }
    // Stable counting sort of the indices in p by their class c[index].
    // Classes lie in [0, p.size()), so one bucket per possible class suffices.
    void count_sort(vector<int>&p,vector<int>&c){
        int n=p.size();
        vector<int>cnt(n);
        for(auto cls:c)cnt[cls]++;
        vector<int>p_new(n),pos(n);  // pos[0] is already 0
        for(int i=1;i<n;i++)pos[i]=pos[i-1]+cnt[i-1];
        for(auto idx:p){
            int cls=c[idx];
            p_new[pos[cls]]=idx;
            pos[cls]++;
        }
        p.swap(p_new);
    }
    void build(){
        if(n==0)return;  // nothing to sort; avoids indexing p[0]
        // Round 0: rank the shifts by their first character (ties by index).
        vector<pair<int,int>>a(n);
        for(int i=0;i<n;i++)a[i]={s[i],i};
        sort(a.begin(),a.end());
        for(int i=0;i<n;i++)p[i]=a[i].second;
        c[p[0]]=0;
        for(int i=1;i<n;i++){
            c[p[i]]=(a[i].first==a[i-1].first)?c[p[i-1]]:c[p[i-1]]+1;
        }
        // Doubling: classes of length-2^k shifts -> classes of length-2^(k+1) shifts.
        for(int k=0;(1<<k)<n;k++){
            const int half=1<<k;
            // p is ordered by the length-`half` prefix; moving every start back by
            // `half` orders the shifts by their second half, and a stable sort by
            // the first half's class then orders them by the (first, second) pair.
            for(int i=0;i<n;i++)p[i]=(p[i]-half+n)%n;
            count_sort(p,c);
            vector<int>c_new(n);
            c_new[p[0]]=0;
            for(int i=1;i<n;i++){
                pair<int,int>prev={c[p[i-1]],c[(p[i-1]+half)%n]};
                pair<int,int>cur={c[p[i]],c[(p[i]+half)%n]};
                c_new[p[i]]=c_new[p[i-1]]+(cur!=prev?1:0);
            }
            c.swap(c_new);
        }
        for(int i=0;i<n;i++)len[i]=n-p[i];
        // Kasai: walk suffixes by start index. The LCP with the rank-predecessor
        // drops by at most 1 when the start moves one step right, so h carries over.
        int h=0;
        for(int i=0;i<n;i++){
            const int rk=c[i];
            if(rk==0){h=0;continue;}  // the smallest suffix has no predecessor
            const int j=p[rk-1];
            while(i+h<n&&j+h<n&&s[i+h]==s[j+h])h++;
            lcp[rk]=h;
            if(h>0)h--;
        }
    }

};

// Manacher's algorithm.
//   d1[i] = number of odd-length palindromes centred at s[i]
//           (the longest one has length 2*d1[i]-1)
//   d2[i] = number of even-length palindromes whose right-middle character is s[i]
//           (the longest one has length 2*d2[i])
// d1 and d2 must already hold at least s.length() elements.
void manachar(vector<int>&d1,vector<int>&d2,string &s){
    const int len = s.length();

    // Odd lengths. [lo, hi] is the palindrome reaching furthest right so far.
    int lo = 0, hi = -1;
    for (int c = 0; c < len; c++) {
        int rad = 1;
        if (c <= hi) rad = min(d1[lo + hi - c], hi - c + 1);  // reuse the mirror
        while (c - rad >= 0 && c + rad < len && s[c - rad] == s[c + rad]) rad++;
        d1[c] = rad;
        if (c + rad - 1 > hi) {
            lo = c - rad + 1;
            hi = c + rad - 1;
        }
    }

    // Even lengths (centre lies between s[c-1] and s[c]).
    lo = 0;
    hi = -1;
    for (int c = 0; c < len; c++) {
        int rad = 0;
        if (c <= hi) rad = min(d2[lo + hi - c + 1], hi - c + 1);
        while (c - rad - 1 >= 0 && c + rad < len && s[c - rad - 1] == s[c + rad]) rad++;
        d2[c] = rad;
        if (c + rad - 1 > hi) {
            lo = c - rad;
            hi = c + rad - 1;
        }
    }
}

// nlogn

int solve(string &s){
string t = s;
string z = t;

int n = s.length();
reverse(z.begin(),z.end());
t = t +"#" + z + "?";

int m = t.length();

suffix_array<string,int> sa(t);
sa.build();

vector<pair<int,pii>>edges;

vector<int>d1(n),d2(n);
manachar(d1,d2,s);

for(int i = 2; i < m; i++){
edges.push_back({sa.lcp[i],{sa.p[i],sa.p[i-1]}});
}

clear(m);

for(int i = 0; i < n; i++){
sz1[i]=1;
sz2[i+n+1]=1;
}

sort(edges.begin(),edges.end());
reverse(edges.begin(),edges.end());

int ans = 0;

for(auto i: edges){
int u = root(i.y.x);
int v = root(i.y.y);
int w = i.x;
ans += w*(sz1[u]*sz2[v]+sz2[u]*sz1[v]);
merge(u,v);

}

for(int i = 0; i < n; i++){
int k = d1[i];
ans -= k*(k);
k = d2[i];
ans -= k*(k);
}
return ans/2;
}

// Reads the number of test cases, then one string per case, and prints
// solve(string) for each case on its own line.
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int tests=1;
    cin>>tests;
    for(;tests!=0;--tests){
        string str;
        cin>>str;
        cout<<solve(str)<<"\n";
    }
    return 0;
}


Time complexity: O(N \log N) per test case.

1 Like