CENS20C - Editorial

PROBLEM LINK:

Practice
Contest

Author & Editorialist: Aman Dwivedi
Tester: Jatin Nagpal

DIFFICULTY:

EASY

PREREQUISITES:

KMP Algorithm/Z Algorithm/String Hashing and Prefix Array

Problem

You are given a string S that forms a pyramid of infinite height, where row N consists of the first N characters of the infinite repetition SSS…. Answer Q queries; for each query, report the number of occurrences of string T in the given row of the pyramid.

EXPLANATION:

Initial thoughts

First, count the occurrences of T in S; call this c1. Then concatenate S with itself and count the occurrences of T in S + S; call this c2.

Is c2 = 2 * c1?
Not necessarily: occurrences of T that straddle the boundary between the two copies of S are extra, so c2 can exceed 2 * c1. Concatenating only once is enough because |T| ≤ |S|, so an occurrence of T can cross at most one boundary between consecutive copies of S.

Finding number of occurrences

First, find c1, the number of occurrences of T in S, using any pattern matching algorithm.

Then concatenate S with itself and find c2, the number of occurrences of T in S + S.

The extra occurrences created at the boundary, if any, number c3 = c2 - 2 * c1.

Answering Queries

Let N be the row number. The number of complete copies of S in row N is x = floor(N / |S|).

These x copies are joined at x - 1 boundaries, and each boundary adds c3 occurrences, so the complete copies contain x * c1 + (x - 1) * c3 occurrences.

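Some characters are still left over: the last r = N % |S| characters of the row, which form a prefix of S. Check whether they create extra occurrences, i.e. occurrences of T that end inside these r characters (such an occurrence may start in the last complete copy of S). These are exactly the occurrences in S + S that end at positions |S| to |S| + r - 1 (0-indexed), so they can be counted in O(1) per query with a prefix array precomputed once at the start.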

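Take care when x = 0, i.e. when the row is shorter than S. Then the row is just the first N characters of S, and the answer is the number of occurrences of T that end within those N characters. Handle this case separately.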

SOLUTIONS:

Using KMP Algorithm
#include<bits/stdc++.h>
using namespace std;
 
#define rep(i,n) for(int i=0;i<n;i++)
#define repa(i,a,n) for(int i=a;i<=n;i++)
#define repb(i,a,n) for(int i=a;i>=n;i--)
#define trav(a,x) for(auto a=x.begin();a!=x.end();a++)
#define all(x) x.begin(),x.end()
#define fst first
#define snd second
#define pb push_back
#define mp make_pair
typedef long double ld;
typedef pair <int,int> pii;
typedef vector <int> vi;
typedef long long ll;
 
void pre(){
 
}
 
vi prefix_function(string const &s){
  int n=s.size();
  vi pi(n);
  pi[0]=0;
 
  for(int i=1;i<n;i++){
    int j=pi[i-1];
 
    while(j>0 && s[i]!=s[j])
      j=pi[j-1];
 
    pi[i]=j+(s[i]==s[j]);
  }
 
  return pi;
}
 
void solve(){
  string s; cin>>s;
  string t; cin>>t;
  int c1=0,c2=0;
 
  string temp=t+'#'+s+s;
 
  vi pi=prefix_function(temp);
 
  int count[2*s.size()];
  int j=0;
 
  for(int i=t.size()+1;i<(temp.size());i++){
    if(i==t.size()+1){
      if(pi[i]==(t.size())){
        c1++;
        c2++;
        count[j]=1;
      }
      else count[j]=0;
    }
    else{
      if(pi[i]==(t.size())){
        if(i>(s.size()+t.size())){
          c2++;
        }
        else{
          c1++;
          c2++;
        }
        count[j]=count[j-1]+1;
      }
      else{
        count[j]=count[j-1];
      }
    }
    j++;
  }
 
  int con=(c2-(2*c1));
 
  int q; cin>>q;
 
  while(q--){
    ll n; cin>>n;
    ll ans=0;
    ll quot=n/(s.size());
 
    if(quot!=0){
      ll rem=n%(s.size());
      ans+=((quot*c1)+((quot-1)*con));
 
      if(rem!=0){
        ans+=(count[s.size()+rem-1]-count[s.size()-1]);
      }
    }
    else{
      ll rem=n%(s.size());
      ans+=(count[rem-1]);
    }
 
    cout<<ans<<endl;
  }
}
 
int main(){
  ios_base::sync_with_stdio(0); cin.tie(0);
  pre();
  int t; t=1;
  rep(i,t) solve();
  return 0;
}
Using Z Algorithm
#include <bits/stdc++.h>
using namespace std;
#define ff first
#define ss second
#define MP make_pair
#define PB push_back
#define ll long long
#define int long long
#define f(i,x,n) for(int i=x;i<n;i++)
#define ld long double
const int mod=1000000007;
const int INF=1e18;
 
vector<int> z_function(string s) {
    int n = (int) s.length();
    vector<int> z(n);
    for (int i = 1, l = 0, r = 0; i < n; ++i) {
        if (i <= r)
            z[i] = min (r - i + 1, z[i - l]);
        while (i + z[i] < n && s[z[i]] == s[i + z[i]])
            ++z[i];
        if (i + z[i] - 1 > r)
            l = i, r = i + z[i] - 1;
    }
    return z;
}
int pre[300005];
int32_t main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	string s,t;
	cin>>s>>t;
	int sl=s.length();
	int tl=t.length();
	string a=t+"#"+s+s;
	vector <int> b=z_function(a);
	for(int i=0;i<sl;i++) {
		if(b[i+tl+1]==tl) {
			pre[i+tl]++;
		}
	}
	for(int i=1;i<=2*sl;i++) {
		pre[i]+=pre[i-1];
	}
	int q;
	cin>>q;
	while(q--) {
		int in;
		cin>>in;
		if(in<=sl) {
			cout<<pre[in]<<'\n';
		}
		else {
			int ans=0;
			in-=sl+1;
			ans=in/sl;
			in-=ans*sl;
			ans*=pre[2*sl];
			in+=sl+1;
			ans+=pre[in]+pre[in-sl];
			cout<<ans<<'\n';
		}
	}
	return 0;
} 
Using String Hashing

#include <bits/stdc++.h>
using namespace std;
 
#define int long long
 
int A = 911382323, B = 972663749;
 
long long binpow(long long a, long long b, long long m) {
    a %= m;
    long long res = 1;
    while (b > 0) {
        if (b & 1)
            res = res * a % m;
        a = a * a % m;
        b >>= 1;
    }
    return res;
}
 
 
int32_t main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	string s, t;
	cin >> s >> t;
	int shash[2*s.size()], pref[2*s.size()];
	int i, j, q, a=1;
 
	int Apow[2*s.size()], modinv[2*s.size()+1];
	Apow[0] = 1;
	for(i=1; i<2*s.size(); i++){
		Apow[i] = Apow[i-1]*A;
		Apow[i] %= B;
	}
 
	// Apow has 2*|S| entries, so stop before index 2*|S| to stay in bounds
	for(i=0; i<2*s.size(); i++){
		modinv[i] = binpow(Apow[i], B-2, B);
		modinv[i] %= B;
	}
 
	for(i=0; i<2*s.size(); i++){
		shash[i] = s[i%s.size()]*a % B;
		a *= A;
		a %= B;
	}
 
	pref[0] = shash[0];
	for(i=1; i<2*s.size(); i++){
		pref[i] = pref[i-1]+shash[i];
		pref[i] %= B;
	}
 
	int thash = 0;
	a = 1;
	for(i=0; i<t.size(); i++){
		thash += t[i]*a % B;
		thash %= B;
		a *= A;
		a %= B;
	}
 
	int inside = 0, across = 0;
	int reminside[s.size()]={0}, remacross[s.size()]={0};
 
	if(pref[t.size()-1] == thash){
		inside++;
		reminside[t.size()-1] = inside;
	}
 
	for(i=t.size(); i<s.size(); i++){
		int sum = (pref[i] - pref[i-t.size()]+B)%B;
		sum *= modinv[i-t.size()+1];
		sum %= B;
 
		if(thash == sum){
			inside++;
		}
		reminside[i] = inside;
	}
 
	for(i=s.size()-t.size()+1; i<s.size(); i++){
		int sum = (pref[i+t.size()-1] - pref[i-1]+B)%B;
		sum *= modinv[i];
		sum %= B;
 
		if(thash == sum){
			across++;
		}
		remacross[(i+t.size())%s.size() - 1] = across;
	}
 
	
	if(t.size() != 1){
		for(i=t.size()-1; i<s.size(); i++){
			remacross[i] = remacross[i-1];
		}
	}
 
	cin >> q;
	while(q--){
		int n;
		cin >> n;
		int whole = n/s.size();
 
		int ans = inside*whole + across*(max(0ll, whole-1));
		if(n%s.size() == 0){
			cout << ans << "\n";
			continue;
		}
		ans += reminside[n%s.size()-1];
		if(whole == 0){
			cout << ans << "\n";
			continue;
		}
 
		ans += remacross[n%s.size()-1];
 
		cout << ans << "\n";
 
	}
}

Please format the code.

Really good questions this contest! Loved it!


Please format your code (hashing approach).

Thanks for pointing out. Done!


Very balanced contest! I have just shifted to CodeChef from HackerEarth 😅 and my first contest went really well. Good work, guys.


Don’t post useless comments! Search for that on Google.

Why useless? I have solved many hashing problems and still can't solve problems like this one, so I just asked for more practice problems. Is that wrong?


Easier implementation:
a_i = number of indices j ≤ i such that an occurrence of T starts in the current copy of S and ends at index j of the current copy of S.
b_i = number of indices j ≤ i such that an occurrence of T starts in the previous copy of S and ends at index j of the current copy of S.
The rest of the logic is the same.

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

const int MX=100005;
char s[MX],t[MX*3];
int pi[MX*3],a[MX],b[MX];

void compute(int len,char s[],int pi[])
{
    pi[0]=0;
    for(int i=1;i<len;i++){
        int j=pi[i-1];
        while(j){
            if(s[j]==s[i]){
                pi[i]=j+1;
                break;
            }
            j=pi[j-1];
        }
        if(j==0) pi[i]=(s[i]==s[0])?1:0;
    }
}

int main()
{
    int n,m,k,q;
    scanf("%s %s %d",s,t,&q);
    n=strlen(s),m=strlen(t);
    strcat(t,"#");
    strcat(t,s);
    strcat(t,s);
    compute(n+n+m+1,t,pi);
    for(int i=0,j=m+1;i<n;i++,j++) a[i]=(pi[j]==m);
    for(int i=0,j=n+m+1;i<m-1;i++,j++) b[i]=(pi[j]==m);
    for(int i=1;i<n;i++){
        a[i]+=a[i-1];
        b[i]+=b[i-1];
    }
    while(q--){
        scanf("%d",&k);
        ll x=k/n,r=k%n;
        if(x){
            ll ans=x*(ll)a[n-1]+(x-1ll)*(ll)b[n-1];
            if(r) ans+=(ll)(a[r-1]+b[r-1]);
            printf("%lld\n",ans);
        }
        else printf("%d\n",a[r-1]);
    }
    return 0;
}

Can anybody explain the prefix-array part of the KMP code more clearly?

I am curious how to approach this problem if the length of T can be greater than the length of S.

Let's say S has 2 occurrences of T, and when it is joined with another S there is one more, so S+S contains 5 occurrences of T in total.
So why do you multiply (x-1) directly by b[n-1] without subtracting the occurrences that come from the second S alone?

Also, I tried your approach using a single array. Can you tell me where it is wrong?

Code
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

const int MX=100005;
char s[MX],t[MX*3];
int pi[MX*3],a[MX],b[MX],c[MX*3];

void compute(int len,char s[],int pi[])
{
    pi[0]=0;
    for(int i=1;i<len;i++){
        int j=pi[i-1];
        while(j){
            if(s[j]==s[i]){
                pi[i]=j+1;
                break;
            }
            j=pi[j-1];
        }
        if(j==0) pi[i]=(s[i]==s[0])?1:0;
    }
}

int main()
{
    int n,m,k,q;
    scanf("%s %s %d",s,t,&q);
    n=strlen(s),m=strlen(t);
    strcat(t,"#");
    strcat(t,s);
    strcat(t,s);
    compute(n+n+m+1,t,pi);
    for(int i=0,j=m+1;i<n;i++,j++) a[i]=(pi[j]==m);
    for(int i=0,j=n+m+1;i<m-1;i++,j++) b[i]=(pi[j]==m);
    for(int i=1;i<n;i++){
        a[i]+=a[i-1];
        b[i]+=b[i-1];
    }
    //cout<<t<<"\n";
    for(int i=0;i<n+n+m+1;i++)
    {
        if(i!=0)
        c[i]+=c[i-1];
        
        if(pi[i]==m)
        c[i]++;
        
        //cout<<pi[i]<<" ";
    }
    while(q--){
        scanf("%d",&k);
        ll x=k/n,r=k%n;
        if(x){
            ll ans=x*(ll)c[m+1+n-1]+(x-1ll)*(ll)(c[n+m+1+n-1]-c[m+1+n-1]);
            if(r) ans+=(ll)(c[m+1+r-1]+c[n+m+1+r-1]-c[n+m]);
            printf("%lld\n",ans);
        }
        else printf("%d\n",a[r-1]);
    }
    return 0;
}

@everule1 @akshitm16 @ssjgz

I have only recently learned about the KMP algorithm and I did not understand how to do this part of the editorial:
“Still some characters will be left i.e. (x % |S|). Check whether these characters results in some extra occurrences, it can be found by maintaining a prefix array at start.”
This is my implementation, which simply runs KMP again to find the extra occurrences in the leftover part. It gives the correct answer on the sample test case but gets TLE when I submit. Can someone please explain why?

#include <bits/stdc++.h> 
using namespace std;

typedef vector <int> vi;

vi computeLPSArray(string pat, int M); 

int KMPSearch(string pat, string txt, int start) 
{ 
	int M = pat.size(); 
	int N = txt.size(); 

	vi lps = computeLPSArray(pat, M); 

	int i = 0;
	int j = 0;
	int count = 0;
	while (i < N) { 
		if (pat[j] == txt[i]) { 
			j++; 
			i++; 
		} 

		if (j == M && i>start) { 
			count+=1;
			j = lps[j - 1]; 
		} 

		else if (i < N && pat[j] != txt[i]) { 
			if (j != 0) 
				j = lps[j - 1]; 
			else
				i = i + 1; 
		} 
	}
	return count;
} 

vi computeLPSArray(string pat, int M) 
{ 
	int len = 0; 
	vi lps;
	lps.push_back(0); 

	int i = 1; 
	while (i < M) { 
		if (pat[i] == pat[len]) { 
			len++; 
			lps.push_back(len); 
			i++; 
		} 
		else // (pat[i] != pat[len]) 
		{ 
			if (len != 0) { 
				len = lps.at(len - 1); 
			} 
			else // if (len == 0) 
			{ 
				lps.push_back(0); 
				i++; 
			} 
		} 
	} 
	return lps;
} 
int main() 
{ 
    string pat, txt;
    cin >> txt >> pat;
    int q;
    cin >> q;
    while(q--){
        int line;
        cin >> line;
        if(line < pat.size()){
            cout << 0<<"\n";
            continue;
        }
    	int count1 = KMPSearch(pat, txt, 0);
    	int count2 = KMPSearch(pat, txt + txt, 0);
    	int count3 = count2 - 2*count1;
    	int x = line / txt.size();
    	int rem = line % txt.size();
    	int count4 = KMPSearch(pat, txt + txt.substr(0,rem), txt.size());
    	int ans = count3*(x-1) + count1*x + count4;
    	cout << ans << "\n";
    }
	return 0; 
}

@everule1 @akshitm16 @ssjgz

I didn’t look at the code, but it fails this simple case:

code
c
1
5

That TLE is due to the complexity of your code being O(Q · (|T| + |S|)). Try to find a more efficient solution. Also, why are you calculating count1 and count2 for every query when they don’t depend on the query at all? Fixing that alone will still give TLE, because you are still calculating count4 for each query, so the worst-case complexity stays the same.

There are two types of occurrences of T:

  • occurrences that start and end in the same copy of S;
  • occurrences that start in the previous copy of S and end in the current one (as you said, these appear when two copies of S are joined).

The array a counts only the first type of occurrences and the array b counts only the second type. So in your example, the occurrences that lie entirely inside the second S aren’t counted in the term (x-1)*b[n-1]; they are counted in the term x*a[n-1]. So I don’t need to subtract anything.

In your implementation, I think you tried to count the second type of occurrences using the term (x-1ll)*(ll)(c[n+m+1+n-1]-c[m+1+n-1]), but that’s wrong, because this term also counts some occurrences of the first type.

(x-1ll)*(ll)(c[n+m+1+m-1]-c[m+1+n-1])

What if I use this instead? I understand my first doubt now.

But count4 has to be calculated separately for each query, right, since it depends on the query?

Why use another KMP search to calculate count4?
count4 is the number of occurrences of T in S + (some prefix of S), right? You can do some pre-calculations for that.

I am using count4 to count any occurrence that starts in the previous S and ends in the current S. Could you please suggest a method? I have just learnt the KMP algorithm and can’t figure out how to do those pre-calculations 😅.