SUBCOUNT - Editorial

PROBLEM LINK:

Practice
Contest: Division 1
Contest: Division 2
Contest: Division 3
Contest: Division 4

Author: notsoloud
Tester: raysh_07
Editorialist: iceknight1093

DIFFICULTY:

TBD

PREREQUISITES:

String matching (for example with KMP or hashing)

PROBLEM:

You’re given a string S_0.
Create K new strings as follows:

  • S_i = S_{i-1} + rev(S_{i-1}) for each 1 \leq i \leq K

Find the number of times S_0 appears as a substring of S_K (occurrences may overlap). Since this number can be huge, it is reported modulo 10^9 + 7.

EXPLANATION:

Note that S_1 = S_0 + rev(S_0) is a palindrome.
This means that rev(S_1) = S_1, so S_2 = S_1 + S_1 (and S_2 is also a palindrome).
Similarly, S_3 = S_2 + S_2 = S_1 + S_1 + S_1 + S_1.
More generally, it can be seen that for any i \geq 1, S_i will equal S_1 repeated 2^{i-1} times.

Since S_1 has length 2N, this means that if S_0 occurs as a substring starting at index j, it’ll also appear at all valid indices of the form j+2Nx for integer x.

This means it’s enough to consider occurrences of S_0 that start at one of the first 2N indices of the string!
That is, we can do the following:

  • For each i = 1, 2, 3, \ldots, 2N, check if the length-N substring starting at i equals S_0.
  • If it does, add to the answer the number of non-negative integers x such that
    i + 2Nx + N-1 \leq |S_K| (that is, the number of starting indices of the form i+2Nx such that there exists a length-N substring starting at it).

The first part is a rather standard string problem: we have a string (the first 3N characters of S_2) and a pattern (S_0), and we’d like to find all positions where the pattern appears.
This can be done in linear time in many ways: for example using hashing or the KMP algorithm.

The second part can be done with some simple math.
Recall that S_K equals S_1 repeated 2^{K-1} times.
So, for a starting index i,

  • If i \leq N+1, this starting index will be valid in every copy of S_1, for 2^{K-1} in total.
  • If i \gt N+1, this starting index will be valid in every copy of S_1, except for the last (since there aren’t enough characters to form a length-N substring).
    This gives 2^{K-1} - 1.

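So, find all valid starting indices 1 \leq i \leq 2N and add either 2^{K-1} or 2^{K-1}-1 to the answer for each of them, depending on their values.
Finding 2^{K-1} quickly can be done using binary exponentiation.
The case K = 0 must be handled separately: no operation is performed, so S_K = S_0 and the answer is simply 1 (the argument above relies on S_K being made of copies of S_1, which needs K \geq 1).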

TIME COMPLEXITY:

\mathcal{O}(N + \log K) per testcase.

CODE:

Author's code (C++)
#include <iostream> 
#include <string> 
#include <set> 
#include <map> 
#include <stack> 
#include <queue> 
#include <vector> 
#include <utility> 
#include <iomanip> 
#include <sstream> 
#include <bitset> 
#include <cstdlib> 
#include <iterator> 
#include <algorithm> 
#include <cstdio> 
#include <cctype> 
#include <cmath> 
#include <math.h> 
#include <ctime> 
#include <cstring> 
#include <unordered_set> 
#include <unordered_map> 
#include <cassert>
#define int long long int
#define pb push_back
#define mp make_pair
#define mod 1000000007
#define vl vector <ll>
#define all(c) (c).begin(),(c).end()
using namespace std;

// Array bound for the scratch buffers below (not referenced by this solution).
constexpr int N = 500023;
bool vis[N];           // static storage, so every entry starts out false
vector <int> adj[N];   // one (initially empty) adjacency list per index
long long readInt(long long l,long long r,char endd){
    long long x=0;
    int cnt=0;
    int fi=-1;
    bool is_neg=false;
    while(true){
        char g=getchar();
        if(g=='-'){
            assert(fi==-1);
            is_neg=true;
            continue;
        }
        if('0'<=g && g<='9'){
            x*=10;
            x+=g-'0';
            if(cnt==0){
                fi=g-'0';
            }
            cnt++;
            assert(fi!=0 || cnt==1);
            assert(fi!=0 || is_neg==false);

            assert(!(cnt>19 || ( cnt==19 && fi>1) ));
        } else if(g==endd){
            if(is_neg){
                x= -x;
            }

            if(!(l <= x && x <= r))
            {
                cerr << l << ' ' << r << ' ' << x << '\n';
                assert(1 == 0);
            }

            return x;
        } else {
            assert(false);
        }
    }
}
// Reads characters from stdin up to (not including) the terminator `endd`,
// consuming the terminator, and asserts the token length is within [l, r].
// Hitting end-of-input before the terminator is an assertion failure.
string readString(int l,int r,char endd){
    string ret="";
    int cnt=0;
    while(true){
        // getchar() returns an int so EOF is distinguishable from every byte;
        // narrowing it to char first makes EOF undetectable where char is unsigned
        // (the old loop would then spin forever appending bytes).
        int g=getchar();
        assert(g!=EOF);
        if(g==endd){
            break;
        }
        cnt++;
        ret+=static_cast<char>(g);
    }
    assert(l<=cnt && cnt<=r);
    return ret;
}
// Convenience readers that fix the terminator expected after the token:
// a space ("...Sp") or a newline ("...Ln").
long long readIntSp(long long l,long long r){ return readInt(l, r, ' '); }
long long readIntLn(long long l,long long r){ return readInt(l, r, '\n'); }
string readStringLn(int l,int r){ return readString(l, r, '\n'); }
string readStringSp(int l,int r){ return readString(l, r, ' '); }

// Running total of n across all test cases; main checks it against sumLimit.
int sumN{0};
// Constraint bounds enforced while reading the input.
int testLimit{100000};
int nLimit{1000000};
int kLimit{1000000000};
int sumLimit{2000000};

// Binary exponentiation: returns x^y mod p (returns 1 when y <= 0).
int power(int x, int y, int p){
    int result = 1;
    int base = x % p;
    // Walk the exponent's bits from least to most significant.
    for(int e = y; e > 0; e >>= 1){
        if(e & 1){
            result = result * base % p;
        }
        base = base * base % p;
    }
    return result;
}

// Fills Lps with the KMP prefix function of txt: Lps[i] is the length of the
// longest proper prefix of txt[0..i] that is also a suffix of it.
// Lps must have at least txt.size() entries. An empty txt leaves Lps untouched
// (previously Lps[0] was written even when Lps was empty).
void lps_func(const string& txt, vector<int>&Lps){
    if(txt.empty()){
        return;
    }
    Lps[0] = 0;
    int len = 0;   // length of the currently matched prefix
    int i = 1;
    while(i < (int)txt.size()){
        if(txt[i] == txt[len]){
            len++;
            Lps[i] = len;
            i++;
        }
        else if(len == 0){
            Lps[i] = 0;
            i++;
        }
        else{
            // Fall back to the next-shorter border and retry the same i.
            len = Lps[len-1];
        }
    }
}

 
// Counts the (possibly overlapping) occurrences of `pattern` in `text` with KMP.
// Returns 0 when either string is empty: an empty pattern previously led to
// writing Lps[0] of an empty vector and reading Lps[-1].
int countSubstrings(const string& text,const string& pattern){
    const int n = text.length();
    const int m = pattern.length();
    if(n == 0 || m == 0)
        return 0;

    vector<int>Lps(m);
    lps_func(pattern,Lps);

    int i=0,j=0;   // i: position in text, j: length of the current partial match
    int ans = 0;
    while(i<n){
        if(pattern[j]==text[i]){i++;j++;} // extend the current match
        if (j == m) {
            ans++;
            j = Lps[j - 1];               // keep the border so overlapping matches count
        }
        else if (i < n && pattern[j] != text[i]) {
            if (j == 0)
                i++;                      // nothing matched: advance in text
            else
                j = Lps[j - 1];           // fall back to the longest usable border
        }
    }
    return ans;
}

// Returns a copy of s with its characters in reverse order.
string reverseString(string s){
    return string(s.rbegin(), s.rend());
}

int maxAns{};   // value-initialized to 0; not referenced by solve()

// Reads one test case (n, k, S_0) and prints how many times S_0 occurs in S_k,
// modulo `mod`. For k >= 1, S_k is S_1 = S_0 + rev(S_0) repeated 2^(k-1) times, so
// the answer is (matches inside one S_1) * 2^(k-1) + (matches across one seam
// between consecutive copies of S_1) * (2^(k-1) - 1).
void solve()
{
    int n = readIntSp(1, nLimit);
    int k = readIntLn(0, kLimit);
    sumN += n;
    string s = readStringLn(n, n);
    for(int i = 0; i<n; i++){
        // Every character must be a lowercase letter (the former `||` made this always true).
        assert(s[i] >= 'a' && s[i] <= 'z');
    }

    if(k == 0){
        // No operation applied: S_k = S_0, which contains itself exactly once.
        cout << 1;
        return;
    }

    const string r = reverseString(s);
    // Occurrences inside a single copy of S_1.
    int ans1 = countSubstrings(s + r, s);
    int ans = ans1 % mod;

    if(k > 1){
        // Occurrences straddling the seam rev(S_0) | S_0 between two copies of S_1:
        // they start inside rev(S_0) after its first character and end before the
        // last character of S_0, i.e. inside (r + s).substr(1, 2n - 2).
        int ansMid = countSubstrings((r + s).substr(1, 2*n-2), s);
        // 2^(k-1) mod p is never 0 since p is prime, so this stays non-negative.
        int copiesMinusOne = power(2, k-1, mod) - 1;
        ans = (ans + (copiesMinusOne*ans1)%mod)%mod;
        ans = (ans + (copiesMinusOne*ansMid)%mod)%mod;
    }

    cout << ans;
}

/*
1-aa'
2-aa'
3-aa'a'aaa'a'a
4-aa'a'aaa'a'aaa'a'aaa'a'a
5-aa'a'aaa'a'aaa'a'aaa'a'a



*/
// Entry point: validates and solves every test case, then checks that the input
// is fully consumed and that the sum of n respects sumLimit.
int32_t main()
{
    #ifndef ONLINE_JUDGE
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
    #endif
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);

    const int testCount = readIntLn(1, testLimit);
    for(int tc = 0; tc < testCount; ++tc){
        solve();
        cout << '\n';
    }

    cerr << sumN << '\n';
    // Nothing may follow the last test case.
    assert(getchar() == -1);
    assert(sumN <= sumLimit);
    cerr << "Time : " << 1000 * ((double)clock()) / (double)CLOCKS_PER_SEC << "ms\n";
}

/*
a a'
a a' a a' - 1 1
a a' a a' a a' a a' - 3 3




*/
Tester's code (C++)
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define INF (int)1e18
// Single-letter object-like macros (previously `#define f first` and
// `#define s second`) are deliberately avoided: the preprocessor would rewrite
// every identifier named `f` or `s` in the program (e.g. `string s` silently
// became `string second`), which makes errors and debugging confusing.

// 64-bit Mersenne Twister seeded from the steady clock (not used by the checker/solution below).
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

// Strict input validator: buffers the whole input up front and then consumes it
// token by token, asserting the exact whitespace layout and value ranges.
struct input_checker {
    string buffer;   // entire input
    int pos;         // index of the next unread character in buffer

    const string all = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
    const string number = "0123456789";
    const string upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    const string lower = "abcdefghijklmnopqrstuvwxyz";

    // Reads all of standard input into the buffer.
    input_checker() {
        pos = 0;
        while (true) {
            int c = cin.get();
            if (c == -1) {
                break;
            }
            buffer.push_back((char) c);
        }
    }

    // Validates an in-memory input instead of stdin (handy for testing the checker).
    explicit input_checker(string data) : buffer(std::move(data)), pos(0) {}

    // Index of the first whitespace character at or after pos (or buffer.size()).
    int nextDelimiter() {
        int now = pos;
        // isspace requires a value representable as unsigned char (or EOF);
        // a raw, possibly negative, char is undefined behaviour.
        while (now < (int) buffer.size() && !isspace((unsigned char) buffer[now])) {
            now++;
        }
        return now;
    }

    // Consumes and returns the maximal run of non-whitespace characters at pos.
    string readOne() {
        assert(pos < (int) buffer.size());
        int nxt = nextDelimiter();
        string res;
        while (pos < nxt) {
            res += buffer[pos];
            pos++;
        }
        return res;
    }

    // Reads a token of length in [minl, maxl] whose characters all belong to
    // `pattern` (any character is allowed when pattern is empty).
    string readString(int minl, int maxl, const string &pattern = "") {
        assert(minl <= maxl);
        string res = readOne();
        assert(minl <= (int) res.size());
        assert((int) res.size() <= maxl);
        for (int i = 0; i < (int) res.size(); i++) {
            assert(pattern.empty() || pattern.find(res[i]) != string::npos);
        }
        return res;
    }

    // Reads an integer in [minv, maxv]. Parsing goes through readLong because, with
    // this file's `#define int long long`, `int` is 64-bit and stoi would throw for
    // any value outside the 32-bit range.
    int readInt(int minv, int maxv) {
        return (int) readLong(minv, maxv);
    }

    // Reads a 64-bit integer in [minv, maxv]; the whole token must be numeric.
    long long readLong(long long minv, long long maxv) {
        assert(minv <= maxv);
        string token = readOne();
        assert(!token.empty());
        size_t used = 0;
        long long res = stoll(token, &used);
        // stoll silently ignores trailing characters ("12x" -> 12); reject them.
        assert(used == token.size());
        assert(minv <= res);
        assert(res <= maxv);
        return res;
    }

    // Reads n space-separated integers in [minv, maxv] (no trailing space).
    auto readInts(int n, int minv, int maxv) {
        assert(n >= 0);
        vector<int> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readInt(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Reads n space-separated 64-bit integers in [minv, maxv] (no trailing space).
    auto readLongs(int n, long long minv, long long maxv) {
        assert(n >= 0);
        vector<long long> v(n);
        for (int i = 0; i < n; ++i) {
            v[i] = readLong(minv, maxv);
            if (i+1 < n) readSpace();
        }
        return v;
    }

    // Consumes exactly one ' '.
    void readSpace() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == ' ');
        pos++;
    }

    // Consumes exactly one '\n'.
    void readEoln() {
        assert((int) buffer.size() > pos);
        assert(buffer[pos] == '\n');
        pos++;
    }

    // Asserts that the whole input has been consumed.
    void readEof() {
        assert((int) buffer.size() == pos);
    }
};

input_checker inp{};   // global validator; its default constructor buffers all of stdin

// Input limits: number of tests, maximum |S_0|, maximum K.
constexpr int T = 1e5;
constexpr int N = 1e6;
constexpr int K = 1e9;
// Hash modulus (prime) and polynomial base.
constexpr int mod = 1e9 + 7;
constexpr int B = 100;
// Running total of n over all test cases.
int sumn = 0;
// pb[i] = B^i and pib[i] = B^{-i} (mod `mod`), filled in by main.
int pb[3 * N], pib[3 * N];

// Recursive square-and-multiply: x^y modulo the global `mod`, with power(x, 0) == 1.
// Because y / 2 truncates toward zero and (y & 1) tests the low bit, a negative y
// behaves exactly like |y| (e.g. power(2, -1) == 2).
int power(int x, int y){
    if (!y) return 1;

    int half = power(x, y / 2);
    int squared = half * half % mod;
    return (y & 1) ? squared * x % mod : squared;
}

// Prefix polynomial hashes of s (lowercase letters mapped to 1..26):
// pref[i] = sum over j = 1..i of (s[j-1] - 'a' + 1) * pb[j]  (mod `mod`), pref[0] = 0.
// Taken by const reference: the callers pass strings of length up to 2n, and the
// previous by-value parameter copied each one.
vector <int> generate(const string& s){
    int n = s.length();
    vector <int> pref(n + 1, 0);
    for (int i = 1; i <= n; i++){
        pref[i] = (pref[i - 1] + (s[i - 1] - 'a' + 1) * pb[i]) % mod;
    }

    return pref;
}

// Reads one test case and prints the number of occurrences of S_0 in S_K modulo
// `mod`, matching windows by polynomial hashing. For K >= 1, S_K is S_1 = S_0 + rev(S_0)
// repeated 2^(K-1) times.
void Solve() 
{
    int n = inp.readInt(1, N); sumn += n; inp.readSpace();
    int k = inp.readInt(0, K); inp.readEoln();
    string s = inp.readString(n, n); inp.readEoln();
    for (auto x : s) assert(x >= 'a' && x <= 'z');

    // K = 0: nothing is appended, so S_K = S_0 and it contains S_0 exactly once.
    // Without this branch power(2, -1) evaluates to 2 and the count below is wrong.
    if (k == 0){
        cout << 1 << "\n";
        return;
    }

    string t = s;
    reverse(t.begin(), t.end());
    string a1 = s + t;   // S_1 = S_0 + rev(S_0)
    string a2 = t + s;   // the seam between two consecutive copies of S_1

    auto v1 = generate(s);
    auto v2 = generate(a1);
    auto v3 = generate(a2);

    // Hash of the length-n window starting at 1-indexed position i of the string
    // whose prefix hashes are `pref`, rescaled by pib[i - 1] so it is directly
    // comparable with v1[n], the hash of s itself.
    auto windowHash = [&](const vector<int>& pref, int i){
        int val = pref[i + n - 1] - pref[i - 1];
        if (val < 0) val += mod;
        return val * pib[i - 1] % mod;
    };

    int ok = power(2, k - 1);   // number of copies of S_1 in S_K
    int ans = 0;

    // Windows starting at 2..n of S_1 recur once per copy of S_1.
    for (int i = 2; i <= n; i++){
        if (windowHash(v2, i) == v1[n]){
            ans += ok;
        }
    }

    // Windows starting at 2..n of rev(S_0) + S_0 straddle a seam: once per seam.
    for (int i = 2; i <= n; i++){
        if (windowHash(v3, i) == v1[n]){
            ans += ok - 1;
        }
    }

    // Start n + 1 of S_1 is rev(S_0): a match iff S_0 is a palindrome.
    if (t == s){
        ans += ok;
    }
    // Start 1 of every copy is S_0 itself.
    ans += ok;
    ans %= mod;

    cout << ans << "\n";
}

// Entry point: precomputes powers of B and their inverses, solves each test case,
// verifies the input is fully consumed and reports wall-clock time on stderr.
int32_t main() 
{
    const auto startTime = std::chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    // freopen("in",  "r", stdin);
    // freopen("out", "w", stdout);

    const int testCount = inp.readInt(1, T);
    inp.readEoln();

    // pb[i] = B^i (mod p).
    pb[0] = 1;
    for (int i = 1; i < 3 * N; i++){
        pb[i] = pb[i - 1] * B % mod;
    }
    // pib[i] = B^{-i} (mod p): invert the largest power once via Fermat's little
    // theorem, then walk downwards using B^{-i} = B^{-(i+1)} * B.
    pib[3 * N - 1] = power(pb[3 * N - 1], mod - 2);
    for (int i = 3 * N - 2; i >= 0; i--){
        pib[i] = pib[i + 1] * B % mod;
    }

    for (int tc = 0; tc < testCount; tc++){
        Solve();
    }

    // The input must contain nothing after the last test case.
    inp.readEof();

    const auto endTime = std::chrono::high_resolution_clock::now();
    const auto elapsed = std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - startTime);
    cerr << "Time measured: " << elapsed.count() * 1e-9 << " seconds.\n"; 
    return 0;
}
Editorialist's code (Python)
mod = 1000000007  # answers are reported modulo this prime
def partial(s):
    """Return the KMP prefix function of s.

    Entry i is the length of the longest proper prefix of s[:i + 1] that is also
    a suffix of it. An empty input gives an empty list.
    """
    pi = [0] * len(s)
    border = 0
    for i in range(1, len(s)):
        ch = s[i]
        # Shrink the current border until it can be extended by ch (or is empty).
        while border > 0 and s[border] != ch:
            border = pi[border - 1]
        if s[border] == ch:
            border += 1
        pi[i] = border
    return pi


def match(s, pat):
    """Return the 0-indexed start of every (possibly overlapping) occurrence of pat in s."""
    pi = partial(pat)
    m = len(pi)
    starts = []
    matched = 0  # length of the current partial match of pat
    for i, ch in enumerate(s):
        while matched and pat[matched] != ch:
            matched = pi[matched - 1]
        if pat[matched] == ch:
            matched += 1
        if matched == m:
            starts.append(i + 1 - matched)
            # Keep the longest border so overlapping occurrences are still found.
            matched = pi[matched - 1]
    return starts

for _ in range(int(input())):
    n, k = map(int, input().split())
    s = input()
    if k == 0:
        # No operation is applied: S_K = S_0, which contains S_0 exactly once.
        # (pow(2, -1, mod) returns the modular inverse of 2 on Python 3.8+, not a count.)
        print(1)
        continue
    # For K >= 1, S_K is S_1 = s + rev(s) repeated 2^(K-1) times, so every occurrence is a
    # shift by a multiple of 2n of one starting in [0, 2n) of s + rev(s) + s.
    big = s + s[::-1] + s
    positions = match(big, s)
    ans = 0
    add = pow(2, k-1, mod)  # number of copies of S_1 in S_K
    for i in positions:
        if i <= n: ans += add           # fits inside a single copy of S_1
        elif i < 2*n: ans += add - 1    # straddles a seam between copies: once per seam
    print(ans % mod)
2 Likes

Can anyone please tell me why this solution is giving WA

I have used matrix exponentiation to solve the second part.

1 Like

I am unable to understand what is wrong with my solution.
I have counted the occurrences of S_0 in the string S_0 {\cdot} rev(S_0) {\cdot} S_0(0, N-1), where S_0(0, N-1) is the substring of S_0 having its first N-1 characters. These occurrences are repeated 2^{(K-1)} - 1 times. For the final instance of S_0 {\cdot} rev(S_0), I have separately counted the occurrences and added that to get the answer.
Any flaws pointed out in the approach/implementation are appreciated.