SUBCOUNT - Editorial

Author: notsoloud
Tester: raysh_07
Editorialist: iceknight1093

TBD

PREREQUISITES:

String matching (for example with KMP or hashing)

PROBLEM:

You’re given a string S_0.
Create K new strings as follows:

• S_i = S_{i-1} + rev(S_{i-1}) for each 1 \leq i \leq K

Find the number of times S_0 appears as a substring of S_K.

EXPLANATION:

Note that S_1 = S_0 + rev(S_0) is a palindrome.
This means that rev(S_1) = S_1, so S_2 = S_1 + S_1 (and S_2 is also a palindrome).
Similarly, S_3 = S_2 + S_2 = S_1 + S_1 + S_1 + S_1.
More generally, it can be seen that for any i \geq 1, S_i will equal S_1 repeated 2^{i-1} times.

Since S_1 has length 2N, this means that if S_0 occurs as a substring starting at index j, it’ll also appear at all valid indices of the form j+2Nx for integer x.

This means it’s enough to consider instances of S_0 starting at the first 2N indices of the string!
That is, we can do the following:

• For each i = 1, 2, 3, \ldots, 2N, check if the length-N substring starting at i equals S_0.
• If it does, add to the answer the number of non-negative integers x such that
i + 2Nx + N-1 \leq |S_K| (that is, the number of starting indices of the form i+2Nx such that there exists a length-N substring starting at it).

The first part is a rather standard string problem: we have a text (the first 3N characters of S_2, that is, S_0 + rev(S_0) + S_0) and a pattern (S_0), and we’d like to find all positions where the pattern appears in the text.
This can be done in linear time in many ways: for example using hashing or the KMP algorithm.

The second part can be done with some simple math.
Recall that S_K equals S_1 repeated 2^{K-1} times.
So, for a starting index i,

• If i \leq N+1, this starting index will be valid in every copy of S_1, for 2^{K-1} in total.
• If i \gt N+1, this starting index will be valid in every copy of S_1, except for the last (since there aren’t enough characters to form a length-N substring).
This gives 2^{K-1} - 1.

So, find all valid starting indices 1 \leq i \leq 2N and add either 2^{K-1} or 2^{K-1}-1 to the answer for each of them, depending on their values.
Finding 2^{K-1} quickly can be done using binary exponentiation.

TIME COMPLEXITY:

\mathcal{O}(N + \log K) per testcase.

CODE:

Author's code (C++)
#include <iostream>
#include <string>
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <vector>
#include <utility>
#include <iomanip>
#include <sstream>
#include <bitset>
#include <cstdlib>
#include <iterator>
#include <algorithm>
#include <cstdio>
#include <cctype>
#include <cmath>
#include <math.h>
#include <ctime>
#include <cstring>
#include <unordered_set>
#include <unordered_map>
#include <cassert>
#define int long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;

// Size of the `vis` array declared on the next line.
const int N=500023;
// Not referenced by any code visible in this listing.
bool vis[N];
// Strict integer reader for input validation.
//
// Consumes characters from stdin up to and including the terminator `endd`,
// parses them as a base-10 integer and returns the value after checking that
// it lies in [l, r]. Any deviation from the expected format trips an assert:
//   * a '-' that appears after a digit,
//   * a leading zero on anything other than the literal "0" (including "-0"),
//   * more than 19 digits, or 19 digits whose first digit exceeds 1,
//   * any character that is neither '-', a digit, nor `endd`,
//   * a parsed value outside [l, r] (the bounds and value go to cerr first).
long long readInt(long long l,long long r,char endd){
    long long value=0;
    int digitCount=0;
    int leadingDigit=-1;          // stays -1 until the first digit is consumed
    bool negative=false;
    for(;;){
        const char ch=getchar();

        if(ch=='-'){
            // A sign is only legal before any digit has been read.
            assert(leadingDigit==-1);
            negative=true;
            continue;
        }

        const bool isDigit=(ch>='0' && ch<='9');
        if(isDigit){
            const int digit=ch-'0';
            value=value*10+digit;
            if(digitCount==0){
                leadingDigit=digit;
            }
            ++digitCount;
            // A leading zero is tolerated only for the one-digit, unsigned "0".
            assert(leadingDigit!=0 || digitCount==1);
            assert(leadingDigit!=0 || !negative);
            // Keep the magnitude small enough to fit in a long long.
            assert(digitCount<19 || (digitCount==19 && leadingDigit<=1));
            continue;
        }

        if(ch!=endd){
            // Unexpected character in the middle of a number.
            assert(false);
            continue;
        }

        // Terminator reached: finalise the sign and range-check the result.
        if(negative){
            value=-value;
        }
        if(value<l || value>r){
            std::cerr << l << ' ' << r << ' ' << value << '\n';
            assert(1 == 0);
        }
        return value;
    }
}
string ret="";
int cnt=0;
while(true){
char g=getchar();
assert(g!=-1);
if(g==endd){
break;
}
cnt++;
ret+=g;
}
assert(l<=cnt && cnt<=r);
return ret;
}
long long readIntSp(long long l,long long r){
}
long long readIntLn(long long l,long long r){
}
}
}

// Running total of n over all test cases; solve() adds to it and main()
// asserts it against sumLimit after the last test case.
int sumN = 0;
// testLimit, nLimit and kLimit are not referenced by any code visible in this
// listing -- presumably the upper bounds passed to the input readers for the
// number of tests, N and K (TODO confirm).
int testLimit = 100000;
int nLimit = 1000000;
int kLimit = 1000000000;
// Upper bound that main() asserts for sumN.
int sumLimit = 2000000;

// Modular exponentiation by repeated squaring: returns x^y mod p.
//
//   x  base; may be negative or >= p, it is normalised into [0, p) first
//   y  exponent; values <= 0 yield 1 mod p
//   p  modulus, must be positive
//
// The result is always in [0, p). In particular it is 0 when p == 1 (the
// original returned 1 in that case). Intermediate products are formed in
// long long, so the function does not rely on the file-wide
// `#define int long long int` to avoid overflow.
int power(int x, int y, int p){
    long long base = x % p;
    if (base < 0) {
        base += p;
    }
    long long res = 1 % p;
    while (y > 0) {
        if (y & 1) {
            res = (res * base) % p;
        }
        y >>= 1;
        base = (base * base) % p;
    }
    return static_cast<int>(res);
}

// Fills Lps with the KMP failure table of txt: Lps[i] is the length of the
// longest proper prefix of txt[0..i] that is also a suffix of it.
//
//   txt  pattern to analyse (taken by const reference to avoid a copy)
//   Lps  output table; grown to txt.size() entries if it is shorter
//
// An empty txt leaves Lps untouched (the original wrote Lps[0] out of bounds
// in that case).
void lps_func(const std::string& txt, std::vector<int>& Lps){
    const int n = static_cast<int>(txt.size());
    if (n == 0) {
        return;
    }
    if (static_cast<int>(Lps.size()) < n) {
        Lps.resize(n);
    }

    Lps[0] = 0;
    int len = 0;            // length of the border currently being extended
    int i = 1;
    while (i < n) {
        if (txt[i] == txt[len]) {
            ++len;
            Lps[i] = len;
            ++i;
        } else if (len == 0) {
            Lps[i] = 0;
            ++i;
        } else {
            // Fall back to the next shorter border and retry the same i.
            len = Lps[len - 1];
        }
    }
}

// Counts the (possibly overlapping) occurrences of pattern in text using KMP.
//
// Returns 0 when the pattern is empty or longer than the text; the original
// indexed Lps[-1] for an empty pattern. Both strings are taken by const
// reference so that callers no longer pay for a copy.
int countSubstrings(const std::string& text, const std::string& pattern){
    const int n = static_cast<int>(text.length());
    const int m = static_cast<int>(pattern.length());
    if (m == 0 || m > n) {
        return 0;
    }

    std::vector<int> Lps(m);
    lps_func(pattern, Lps);

    int ans = 0;
    int j = 0;                       // number of pattern characters matched so far
    for (int i = 0; i < n; ++i) {
        // On a mismatch fall back along the failure links.
        while (j > 0 && pattern[j] != text[i]) {
            j = Lps[j - 1];
        }
        if (pattern[j] == text[i]) {
            ++j;
        }
        if (j == m) {
            ++ans;
            j = Lps[j - 1];          // keep going so overlapping matches are counted
        }
    }
    return ans;
}

// Returns a copy of s with its characters in reverse order.
std::string reverseString(std::string s){
    return std::string(s.rbegin(), s.rend());
}

// Not referenced by any code visible in this listing.
int maxAns = 0;

void solve()
{
sumN += n;
for(int i = 0; i<n; i++){
assert(s[i] >= 'a' || s[i] <= 'z');
}

if(k == 0){
cout << 1;
}
else{
//calculate ans for k = 1
int ans = 0;
string afterOneOp = s + reverseString(s);
int ans1 = countSubstrings(afterOneOp, s);
ans = (ans + ans1)%mod;

if(k > 1){
//calculate substring in s's
string midTwoOp = reverseString(s) + s;
int ansMid = countSubstrings(midTwoOp.substr(1, 2*n-2), s);
ans = (ans + ((power(2, k-1, mod)-1)*ans1)%mod)%mod;
ans = (ans + ((power(2, k-1, mod)-1)*ansMid)%mod)%mod;
}

cout << ans;
}
}

/*
1-aa'
2-aa'
3-aa'a'aaa'a'a
4-aa'a'aaa'a'aaa'a'aaa'a'a
5-aa'a'aaa'a'aaa'a'aaa'a'a

*/
int32_t main()
{
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
ios_base::sync_with_stdio(false);
cin.tie(NULL),cout.tie(NULL);
while(T--){
solve();
cout<<'\n';
}
cerr << sumN << '\n';
assert(getchar()==-1);
assert(sumN<=sumLimit);
cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

/*
a a'
a a' a a' - 1 1
a a' a a' a a' a a' - 3 3

*/

Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
#define f first
#define s second

// Strict tokenizer used to validate the input format.
//
// The constructor slurps all of cin into `buffer`; the read* members then
// consume it from `pos`, asserting on every deviation (wrong separator,
// out-of-range value, unexpected character, trailing data).
//
// NOTE(review): in this listing the headers of readOne / readSpace / readEoln /
// readEof, the `res` declarations in readString / readInt / readLong and the
// loop bodies of readInts / readLongs were missing. They are reconstructed
// from the surviving fragments -- confirm against the original source.
struct input_checker {
    string buffer;
    int pos;

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads everything available on cin into buffer.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Index of the next whitespace character at or after pos (or buffer end).
    int nextDelimiter() {
        int now = pos;
        while (now < (int) buffer.size() && !isspace(buffer[now])) {
            now++;
        }
        return now;
    }

    // Consumes and returns the token that starts at pos (may be empty when
    // pos already sits on whitespace).
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token whose length lies in [minl, maxl]; when pattern is
    // non-empty every character must occur in it.
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads an integer token in [minv, maxv]; stoi throws on a malformed token.
    int readInt(int minv, int maxv) {
        assert(minv <= maxv);
        int res = stoi(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads a 64-bit integer token in [minv, maxv].
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        long long res = stoll(readOne());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n integers in [minv, maxv] separated by single spaces.
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i + 1 < n) readSpace();
        }
        return v;
    }

    // Reads n 64-bit integers in [minv, maxv] separated by single spaces.
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i + 1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one space.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one newline.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole buffer has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

// Global input reader; its constructor consumes all of cin when this object
// is initialised.
input_checker inp;

// T and K are not referenced by any code visible in this listing --
// presumably the upper bounds for the test count and for K (TODO confirm).
// N is used below to size the power tables.
const int T = 1e5;
const int N = 1e6;
const int K = 1e9;
// Modulus used for both the polynomial hashes and the printed answer.
const int mod = 1e9 + 7;
// Base of the polynomial hash.
const int B = 100;
// Not referenced by any code visible in this listing.
int sumn = 0;
// Filled in main(): pb[i] = B^i mod `mod`, pib[i] = modular inverse of pb[i].
int pb[3 * N], pib[3 * N];

// Recursive square-and-multiply: x raised to the y-th power, reduced modulo
// the file-level constant `mod`.
int power(int x, int y){
    if (y == 0) {
        return 1;
    }

    const int half = power(x, y / 2);
    const int squared = (half * half) % mod;
    return (y & 1) ? (squared * x) % mod : squared;
}

// Builds the prefix polynomial hashes of s. Entry 0 is 0 and, for i >= 1,
// entry i = (entry i-1 + (s[i-1] - 'a' + 1) * pb[i]) mod `mod`.
std::vector<int> generate(std::string s){
    const int len = s.length();
    std::vector<int> prefixHash(len + 1, 0);

    int running = 0;
    for (int idx = 0; idx < len; ++idx){
        const int letterRank = s[idx] - 'a' + 1;
        running = (running + letterRank * pb[idx + 1]) % mod;
        prefixHash[idx + 1] = running;
    }

    return prefixHash;
}

void Solve()
{
for (auto x : s) assert(x >= 'a' && x <= 'z');

string t = s;
reverse(t.begin(), t.end());
string a1 = s + t;
string a2 = t + s;

int ans = 0;
auto v1 = generate(s);
auto v2 = generate(a1);
auto v3 = generate(a2);

int ok = power(2, k - 1);

for (int i = 2; i <= n; i++){
int val = v2[i + n - 1] - v2[i - 1];
if (val < 0) val += mod;
val *= pib[i - 1]; val %= mod;

if (val == v1[n]){
ans += ok;
}
}

for (int i = 2; i <= n; i++){
int val = v3[i + n - 1] - v3[i - 1];
if (val < 0) val += mod;
val *= pib[i - 1]; val %= mod;

if (val == v1[n]){
ans += ok - 1;
}
}

if (t == s){
ans += ok;
}
ans += ok;
ans %= mod;

cout << ans << "\n";
}

int32_t main()
{
auto begin = std::chrono::high_resolution_clock::now();
ios_base::sync_with_stdio(0);
cin.tie(0);
int t = 1;
// freopen("in",  "r", stdin);
// freopen("out", "w", stdout);

pb[0] = pib[0] = 1;
for (int i = 1; i < 3 * N; i++){
pb[i] = pb[i - 1] * B % mod;
// pib[i] = power(pb[i], mod - 2);
}
pib[3 * N - 1] = power(pb[3 * N - 1], mod - 2);
for (int i = 3 * N - 2; i >= 0; i--){
pib[i] = pib[i + 1] * B % mod;
}

for(int i = 1; i <= t; i++)
{
//cout << "Case #" << i << ": ";
Solve();
}

auto end = std::chrono::high_resolution_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(end - begin);
cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n";
return 0;
}

Editorialist's code (Python)
mod = 10**9 + 7
def partial(s):
    """Return the KMP prefix function of ``s``.

    ``pi[i]`` is the length of the longest proper prefix of ``s[:i + 1]`` that
    is also a suffix of it. An empty ``s`` yields an empty list.
    """
    g, pi = 0, [0] * len(s)
    for i in range(1, len(s)):
        # Fall back along the failure links until s[i] can extend the border.
        while g and (s[g] != s[i]):
            g = pi[g - 1]
        pi[i] = g = g + (s[g] == s[i])

    return pi

def match(s, pat):
    """Return the 0-based start indices of every occurrence of ``pat`` in ``s``.

    Overlapping occurrences are all reported, in increasing order. ``pat`` is
    expected to be non-empty.
    """
    pi = partial(pat)

    g, idx = 0, []
    for i in range(len(s)):
        # g is the number of characters of pat currently matched.
        while g and pat[g] != s[i]:
            g = pi[g - 1]
        g += pat[g] == s[i]
        if g == len(pi):
            idx.append(i + 1 - g)
            # Continue from the longest border so overlaps are found too.
            g = pi[g - 1]

    return idx

# Driver: one "n k" line and one string line per test case.
for _ in range(int(input())):
    n, k = map(int, input().split())
    s = input()
    # S_0 + rev(S_0) + S_0: one copy of S_1 followed by the start of the next.
    big = s + s[::-1] + s
    positions = match(big, s)
    # Number of copies of S_1 in S_K, modulo mod. This assignment was missing
    # from the listing, which left `add` undefined.
    add = pow(2, k - 1, mod)
    ans = 0
    for i in positions:
        if i <= n:
            # The window fits inside a single copy of S_1: present in every copy.
            ans += add
        elif i < 2 * n:
            # The window spills into the next copy: absent from the last copy.
            ans += add - 1
    print(ans % mod)

2 Likes

Can anyone please tell me why this solution is giving WA

I have used matrix exponentiation to solve the second part.

1 Like

I am unable to understand what is wrong with my solution.
I have counted the occurrences of S_0 in the string S_0 {\cdot} rev(S_0) {\cdot} S_0(0, N-1), where S_0(0, N-1) is the substring of S_0 having its first N-1 characters. These occurrences are repeated 2^{(K-1)} - 1 times. For the final instance of S_0 {\cdot} rev(S_0), I have separately counted the occurrences and added that to get the answer.
Any flaws pointed out in the approach/implementation are appreciated.