PROBLEM LINK:
Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4
Author: grayhathacker
Tester: airths
Editorialist: iceknight1093
DIFFICULTY:
TBD
PREREQUISITES:
Hashing or Z-function
PROBLEM:
You’re given a string S of length N. Count the number of ways to choose three strings P, Q, R such that:
- P + Q + R = S, and
- P + R = Q
+ denotes string concatenation.
EXPLANATION:
Let |S| denote the length of string S.
Looking at the lengths of the relevant strings, we see that |P| + |Q| + |R| = N and |P| + |R| = |Q|.
Substituting the second equation into the first gives 2|Q| = N, so |Q| = \frac{N}{2}.
In particular, if N is odd, no such Q can exist at all, and the answer is 0.
When N is even, the length of Q is fixed to be half the length of S.
Let’s fix index i, and check whether the substring of S of length \frac{N}{2} starting at index i can possibly be a valid Q.
Note that once this index i is fixed, P and R are also uniquely fixed since P+Q+R = S:
- P must be the prefix of S of length i-1 (using 1-based indexing), i.e., ending just before index i.
- R must be the suffix of S starting at index i+\frac{N}{2}, i.e., starting just after Q ends.
All that’s left is to check the second condition: whether P + R = Q.
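For this, observe that P + R = Q holds exactly when P is a prefix of Q and R is a suffix of Q. The condition is clearly necessary. It is also sufficient: since |P| + |R| = |Q|, the prefix P and the suffix R together cover all of Q without overlapping.
So, we only need to check whether the first |P| characters of Q equal P, and whether the last |R| characters of Q equal R.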
This is a standard exercise in string algorithms, with a variety of solutions: the simplest two of them are to use either hashing or the z-function.
- Prefix hashing offers a way to check for the equality of two substrings in constant time - the hash of a substring can be computed in constant time, after which the hashes can simply be compared.
- The z-algorithm when run on a string S returns an array Z, where Z_i is the length of the longest common prefix of S and S[i:].
Once this is known, checking whether P is a prefix of Q is simple: since Q starts at index i and |P| = i-1, just check whether Z_i \geq i-1 (in 1-based indexing).
Checking whether R is a suffix of Q can be done similarly: run the z-algorithm on the reverse of S to get information about common suffixes rather than common prefixes.
TIME COMPLEXITY:
\mathcal{O}(N) per testcase.
CODE:
Author's code (C++)
#include <bits/stdc++.h>
using namespace std;
// Turn every `int` in this program into a 64-bit integer (this is why main is declared int32_t).
#define int long long
// Large sentinel constant; it is not referenced anywhere else in this listing.
#define INF (int)1e18
// 64-bit Mersenne Twister seeded from the steady clock; Hash::init uses it to pick random bases.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
/*
 * Multi-base polynomial hash of a string. It compares any two substrings, read
 * forwards or backwards, in O(b) time after O(b * n) precomputation.
 *
 * After init(n, b, s) the string is addressed with 1-based positions 1..n.
 * With c(t) = s[t] - 'a' + 1 (assumes lowercase letters; TODO confirm against the
 * problem's alphabet) and B the k-th base, the tables hold (all values mod `mod`):
 *   pb[k][j] = B^j                               ib[k][j] = B^(-j)
 *   fw[k][j] = sum_{t <= j} c(t) * B^t           (forward prefix sums)
 *   bc[k][j] = sum_{t >= j} c(t) * B^(n + 1 - t) (backward suffix sums)
 */
struct Hash{
    int b, n; // b = number of hashes (independent bases), n = length of the hashed string
    const int mod = 1e9 + 7;
    vector<vector<int>> fw, bc, pb, ib;
    vector<int> bases;

    // Returns x^y mod `mod` using recursive squaring.
    inline int power(int x, int y){
        if (y == 0){
            return 1;
        }
        int v = power(x, y / 2);
        v = 1LL * v * v % mod;
        if (y & 1) return 1LL * v * x % mod;
        else return v;
    }

    // Picks `bb` random bases and precomputes every table for `str`, whose length is `nn`.
    inline void init(int nn, int bb, string str){
        n = nn;
        b = bb;
        fw = vector<vector<int>>(b, vector<int>(n + 2, 0));
        bc = vector<vector<int>>(b, vector<int>(n + 2, 0));
        pb = vector<vector<int>>(b, vector<int>(n + 2, 1));
        ib = vector<vector<int>>(b, vector<int>(n + 2, 1));
        bases = vector<int>(b);
        str = "0" + str; // pad so that the characters are 1-indexed
        // Draw each base from [2, mod / 10). Bases 0 and 1 are degenerate:
        // - with 0, every non-empty substring hashes to 0 and B^(n+1) has no inverse;
        // - with 1, the hash collapses to a character sum, so anagrams collide.
        for (auto &x : bases) x = 2 + RNG() % (mod / 10 - 2);
        for (int i = 0; i < b; i++){
            for (int j = 1; j <= n + 1; j++){
                pb[i][j] = 1LL * pb[i][j - 1] * bases[i] % mod;
            }
            // mod is prime, so B^-(n+1) = (B^(n+1))^(mod-2) by Fermat's little theorem.
            // Each smaller inverse power is obtained by multiplying back by B.
            ib[i][n + 1] = power(pb[i][n + 1], mod - 2);
            for (int j = n; j >= 1; j--){
                ib[i][j] = 1LL * ib[i][j + 1] * bases[i] % mod;
            }
            for (int j = 1; j <= n; j++){
                fw[i][j] = (fw[i][j - 1] + 1LL * (str[j] - 'a' + 1) * pb[i][j]) % mod;
            }
            for (int j = n; j >= 1; j--){
                bc[i][j] = (bc[i][j + 1] + 1LL * (str[j] - 'a' + 1) * pb[i][n + 1 - j]) % mod;
            }
        }
    }

    // Forward hash of s[l..r] under base i, normalised so that s[l] carries B^1:
    // sum_{t=l..r} c(t) * B^(t - l + 1), in [0, mod). Expects 1 <= l <= r <= n.
    inline int getfwhash(int l, int r, int i){
        int ans = fw[i][r] - fw[i][l - 1];
        ans = 1LL * ans * ib[i][l - 1] % mod;
        if (ans < 0) ans += mod;
        return ans;
    }

    // Backward hash of s[l..r] under base i, normalised so that s[r] carries B^1:
    // sum_{t=l..r} c(t) * B^(r - t + 1). This equals the forward hash of the reversed
    // substring. Expects 1 <= l <= r <= n; r = n + 1 would index bc[i][n + 2] and
    // ib[i][-1], both outside the tables.
    inline int getbchash(int l, int r, int i){
        int ans = bc[i][l] - bc[i][r + 1];
        ans = 1LL * ans * ib[i][n - r] % mod;
        if (ans < 0) ans += mod;
        return ans;
    }

    // Compares two substrings under every base. A pair (l, r) with l <= r denotes
    // s[l..r] read forwards; l > r denotes s[r..l] read backwards. Because of this
    // convention an empty range cannot be expressed, so callers must handle empty
    // ranges themselves.
    inline bool equal(int l1, int r1, int l2, int r2){
        for (int i = 0; i < b; i++){
            int v1, v2;
            if (l1 <= r1) v1 = getfwhash(l1, r1, i);
            else v1 = getbchash(r1, l1, i);
            if (l2 <= r2) v2 = getfwhash(l2, r2, i);
            else v2 = getbchash(r2, l2, i);
            if (v1 != v2) return false;
        }
        return true;
    }

    // Returns whether s[l..r] reads the same forwards and backwards (up to hash collisions).
    inline bool pal(int l, int r){
        return equal(l, r, r, l);
    }
};
/*
 * Reads one test case (a string s) and prints the number of ways to write
 * s = P + Q + R with P + R == Q.
 *
 * Since |Q| must be n/2, odd n has no solution. For each 1-based start i of Q
 * (1 <= i <= n/2 + 1), Q = s[i..e] with e = i + n/2 - 1, P = s[1..i-1] and
 * R = s[e+1..n]. The split is valid iff:
 *   - P equals the first i-1 characters of Q, and
 *   - R equals the last n-e characters of Q.
 */
void Solve()
{
    string s; cin >> s;
    int n = s.length();
    // Odd length: |Q| = n/2 is impossible. Check before hashing to skip the O(n) precomputation.
    if (n & 1){
        cout << 0 << "\n";
        return;
    }
    Hash h;
    h.init(n, 2, s);
    int half = n / 2;
    int ans = 0;
    for (int i = 1; i <= half + 1; i++){
        int e = i + half - 1; // Q = s[i..e]
        // Hash::equal interprets l > r as a reversed range, so empty ranges must be
        // handled here: an empty P (i == 1) or an empty R (e == n) matches trivially.
        // Passing the empty R to equal() would read bc[n + 2] and ib[-1] (out of bounds).
        bool prefixOk = (i == 1) || h.equal(1, i - 1, i, 2 * i - 2);
        bool suffixOk = (e == n) || h.equal(2 * i - 1, e, e + 1, n);
        if (prefixOk && suffixOk){
            ans += 1;
        }
    }
    cout << ans << "\n";
}
// Program entry point. Enables fast I/O, reads the number of test cases, runs
// Solve() once per test case, and finally reports the total wall-clock time on stderr.
int32_t main()
{
    const auto startTime = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int testCount = 1;
    cin >> testCount;
    while (testCount-- > 0) Solve();

    const auto finishTime = std::chrono::high_resolution_clock::now();
    const auto elapsedNs = std::chrono::duration_cast<std::chrono::nanoseconds>(finishTime - startTime);
    cerr << "Time measured: " << elapsedNs.count() * 1e-9 << " seconds.\n";
    return 0;
}
Tester's code (C++)
/*
*
* ^v^
*
*/
#include <iostream>
#include <numeric>
#include <set>
#include <cctype>
#include <iomanip>
#include <chrono>
#include <queue>
#include <string>
#include <vector>
#include <functional>
#include <tuple>
#include <map>
#include <bitset>
#include <algorithm>
#include <array>
#include <random>
#include <cassert>
using namespace std;
// 64-bit integer alias (used by the per-test solver scn).
using ll = long long int;
// Extended-precision floating-point alias (used for the local timing printout in main).
using ld = long double;
// Fast-I/O setup: disables C stdio synchronisation and unties cin; expanded at the top of main.
#define iamtefu ios_base::sync_with_stdio(false); cin.tie(0);
// Clock-seeded 32-bit Mersenne Twister; it is not referenced elsewhere in this listing.
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
void scn(){
// not necessarily distinct
// right down ytdm
string s; cin>>s;
int n = s.size();
if (s.size()&1){
cout<<0<<'\n';
return;
}
auto z_func=[n](const string &t)->vector <int>{
vector <int> ans(n, 0);
ans[0]=0;
int l = 0, r = 0;
for (int i=1; i<n; i++){
if (i<r){
ans[i] = min(ans[i-l], r-i);
}
while (i+ans[i]<n && t[i+ans[i]]==t[ans[i]]){
ans[i]++;
}
if (i+ans[i]>r){
r = i+ans[i];
l = i;
}
}
return ans;
};
// cout<<s<<'\n';
auto front = z_func(s);
reverse(s.begin(), s.end());
auto back = z_func(s);
// reverse(back.begin(), back.end());
int ans = 0;
// cout<<s<<'\n';
for (int i=0; i<n; i++){
if (n-n/2-i>=0){
ll matched = min(front[i], i) + min(back[n-n/2-i], n-n/2-i);
if (matched>=n/2){
ans++;
}
// cout<<i<<' ';
}
}
// cout<<'\n';
cout<<ans<<'\n';
}
// Program entry point. Enables fast I/O, reads the number of test cases and runs
// scn() once per test case. When compiled with `airths` defined, it also redirects
// I/O to local files and prints the elapsed time in milliseconds to stderr.
int main(){
    iamtefu;
#if defined(airths)
    const auto localStart = chrono::high_resolution_clock::now();
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif
    int testCases = 0;
    cin >> testCases;
    while (testCases--){
        scn();
    }
#if defined(airths)
    const auto localEnd = chrono::high_resolution_clock::now();
    ld elapsedMs = chrono::duration_cast<chrono::nanoseconds>(localEnd - localStart).count();
    elapsedMs *= 1e-6;
    cerr << "Time: " << setprecision(12) << elapsedMs << "ms\n";
#endif
    return 0;
}
Editorialist's code (Python)
# https://github.com/cheran-senthil/PyRival/blob/master/pyrival/strings/z_algorithm.py
def z_function(S):
    """
    Z Algorithm in O(n)
    :param S: text string (or any indexable sequence) to process
    :return: the Z array, where Z[i] = length of the longest common prefix of S[i:] and S.
             Z[0] is set to len(S); an empty input yields an empty list.
    """
    n = len(S)
    if n == 0:
        # Nothing to compute; the unconditional Z[0] = n below would raise IndexError.
        return []
    Z = [0] * n
    # [l, r) is the rightmost window found so far that matches a prefix of S.
    l = r = 0
    for i in range(1, n):
        z = Z[i - l]
        if i + z >= r:
            # The mirrored value may reach past the window, so extend by direct comparison.
            z = max(r - i, 0)
            while i + z < n and S[z] == S[i + z]:
                z += 1
            l, r = i, i + z
        Z[i] = z
    Z[0] = n
    return Z
# Driver. For each test case, count the 0-based start indices i of Q such that:
#   - P = s[:i] is a prefix of Q, and
#   - R = s[i + n//2:] is a suffix of Q.
for _ in range(int(input())):
    # strip() drops trailing whitespace or '\r', which would otherwise change len(s) and its parity.
    s = input().strip()
    n = len(s)
    if n % 2 == 1:
        print(0)
        continue
    half = n // 2
    # Zf[i]: longest common prefix of s and s[i:].
    Zf = z_function(s)
    # Zb[k]: longest common suffix of s and s[:k+1] (Z array of reversed s, re-reversed).
    Zb = z_function(s[::-1])[::-1]
    ans = 0
    for i in range(half + 1):
        # Q = s[i:i+half]: its first i characters must equal P and its last half-i must equal R.
        if Zf[i] >= i and Zb[i + half - 1] >= half - i:
            ans += 1
    print(ans)