MATPER - EDITORIAL

PROBLEM LINK:

Practice

Contest: Division 1

Contest: Division 2

Setter: Mohamed Anany

Tester: Encho Mishinev

Editorialist: Taranpreet Singh

DIFFICULTY:

Medium-Hard

PREREQUISITES:

KMP String Matching, Meet in the middle and Partial Sums.

PROBLEM:

Given a string S of length N and M patterns in the string array P, find the number of permutations of the strings in P that match S. A permutation p matches S if we can choose M ranges [l_i, r_i] (1 \leq i \leq M) such that 1 \leq l_1 \leq r_1 < l_2 \leq r_2 < \ldots < l_M \leq r_M \leq N and, for each valid i, the substring S_{l_i} S_{l_i+1} \ldots S_{r_i} equals P_{p[i]}.

QUICK EXPLANATION

  • Using KMP string matching, find all positions of all patterns in S. Then reverse S and all patterns and, again using KMP string matching, find the positions of all reversed patterns in the reversed string.

  • Now iterate over all bitmasks of M bits that have exactly M/2 bits set. For every permutation of these set bits, find the minimum position p such that all these M/2 patterns are completely covered, in that order and without overlap, by the prefix ending at position p. Using partial sums, we can then count, for every position p, the number of permutations of the M/2 set bits that match S within the prefix ending at p.

  • For the M-M/2 unset bits, try all their permutations. For each permutation, find the largest position p such that the patterns given by the unset bits occur in the suffix from p to the end of the string, in the given order and without overlapping. Every permutation of the first M/2 patterns that ends before position p is compatible with this permutation of the remaining M-M/2 patterns, so we just add val[p] to the answer, where val[p] denotes the number of permutations of the first M/2 patterns that end strictly before position p.

EXPLANATION

First things first, we need the positions where each pattern occurs in the string. We can find all match positions using the KMP algorithm (or any other string matching algorithm) and store them for each pattern. For our purpose, we preprocess them into a table NXT[i][j] giving the position of the first match of the i-th pattern at or after position j. Using the first such match is optimal: for a fixed pattern length, the earliest start also gives the earliest end, which leaves the most room for the patterns that follow.
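As a rough illustration, here is a C++ sketch of the NXT table for a single pattern (the function names are hypothetical and it assumes a non-empty pattern). It uses N as the "no match" sentinel and keeps an extra entry NXT[N] = N, so a lookup right after the last placed pattern never goes out of bounds. This is the same convention the setter's buildNext uses.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// KMP failure function: pi[i] = length of the longest proper border of pat[0..i].
std::vector<int> prefixFunction(const std::string& pat) {
    std::vector<int> pi(pat.size(), 0);
    for (std::size_t i = 1; i < pat.size(); ++i) {
        int k = pi[i - 1];
        while (k > 0 && pat[i] != pat[k]) k = pi[k - 1];
        if (pat[i] == pat[k]) ++k;
        pi[i] = k;
    }
    return pi;
}

// nxt[j] = smallest start >= j of an occurrence of pat in txt, or n = |txt|.
// The table has n + 1 entries so that nxt[n] (= n) is always a valid lookup.
std::vector<int> buildNextTable(const std::string& txt, const std::string& pat) {
    const int n = static_cast<int>(txt.size()), m = static_cast<int>(pat.size());
    std::vector<int> nxt(n + 1, n);
    const std::vector<int> pi = prefixFunction(pat);
    int k = 0;  // length of the currently matched prefix of pat (always < m here)
    for (int i = 0; i < n; ++i) {
        while (k > 0 && txt[i] != pat[k]) k = pi[k - 1];
        if (txt[i] == pat[k]) ++k;
        if (k == m) {                 // occurrence ends at i, so it starts at i - m + 1
            nxt[i - m + 1] = i - m + 1;
            k = pi[k - 1];            // continue, allowing overlapping occurrences
        }
    }
    // Suffix minimum turns "match starts here" into "first match at or after j".
    for (int j = n - 1; j >= 0; --j) nxt[j] = std::min(nxt[j], nxt[j + 1]);
    return nxt;
}
```

For example, buildNextTable("abababa", "aba") finds the overlapping occurrences at 0, 2 and 4 and yields {0, 2, 2, 4, 4, 7, 7, 7}.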

The naive solution is to iterate over all permutations of the patterns and check, for each one, whether it is matchable. This has complexity O(M! \cdot N) and will definitely time out for M = 14.
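The naive approach is still handy as a reference checker on small inputs. Here is a sketch under two assumptions: patterns are identified by index, so two equal strings at different indices count as different permutations (as in the setter's code), and std::string::find stands in for the NXT table. The function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <string>
#include <vector>

// Naive reference: try every order of the pattern indices and place each
// pattern at its first occurrence at or after the current position.
// Exponential in M, so only usable for small inputs.
long long countNaive(const std::string& s, const std::vector<std::string>& pats) {
    std::vector<int> order(pats.size());
    std::iota(order.begin(), order.end(), 0);  // start from the sorted order 0..M-1
    long long count = 0;
    do {
        std::size_t cur = 0;  // first position not used by already placed patterns
        bool ok = true;
        for (int idx : order) {
            const std::size_t pos = s.find(pats[idx], cur);
            if (pos == std::string::npos) { ok = false; break; }
            cur = pos + pats[idx].size();
        }
        if (ok) ++count;
    } while (std::next_permutation(order.begin(), order.end()));
    return count;
}
```

For instance, with S = "abcab" and P = {"ab", "c"} both orders fit, while with S = "abc" only "ab" followed by "c" does.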

Let us use the meet-in-the-middle trick here.

Let us split each permutation into two equal (or nearly equal) parts. The first M/2 patterns may form any subset of size M/2 of the M patterns, so we try every such subset. These M/2 patterns may appear in any order, and each order corresponds to a different permutation, so for each subset we iterate over all permutations of its elements. We do the same for the remaining M-M/2 patterns.
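To see that this split loses nothing, the following sketch (hypothetical function name) enumerates exactly the pairs described above and counts them. Each permutation corresponds to exactly one (subset, left order, right order) triple, so the count is \binom{M}{M/2} \cdot (M/2)! \cdot (M - M/2)! = M!.

```cpp
#include <algorithm>
#include <bitset>
#include <cassert>
#include <vector>

// Counts the (left order, right order) pairs produced by the split:
// every mask with m/2 bits set, every order of its set bits (left half)
// and every order of its unset bits (right half).
long long countSplits(int m) {
    long long pairs = 0;
    for (int mask = 0; mask < (1 << m); ++mask) {
        if (static_cast<int>(std::bitset<32>(mask).count()) != m / 2) continue;
        std::vector<int> left, right;  // both built in ascending order
        for (int i = 0; i < m; ++i) {
            if ((mask >> i) & 1) left.push_back(i);
            else right.push_back(i);
        }
        long long leftOrders = 0, rightOrders = 0;
        do { ++leftOrders; } while (std::next_permutation(left.begin(), left.end()));
        do { ++rightOrders; } while (std::next_permutation(right.begin(), right.end()));
        pairs += leftOrders * rightOrders;
    }
    return pairs;
}
```

For example, countSplits(5) counts \binom{5}{2} \cdot 2! \cdot 3! = 120 = 5! pairs.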

Now comes the interesting part.

Suppose we have a permutation of M/2 patterns (call it the left permutation; these are the patterns whose bits are set in the bitmask) and a permutation of M-M/2 patterns (call it the right permutation; these are the patterns whose bits are unset). How can we check whether a pair of left and right permutations forms a matchable permutation?

Let p be the first position in S such that all M/2 patterns of the left permutation occur in S[0, p], in order and without overlap. Similarly, let q be the last position in S such that all M-M/2 patterns of the right permutation occur in S[q, N-1], in order and without overlap.

The combination of these two permutations is valid if and only if p < q, since only then can all M patterns be placed in S, in the order of the current permutation, without overlap.
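A direct, unoptimized check of one (left, right) pair might look like the sketch below (names are hypothetical). It places the left half greedily from the front with std::string::find and the right half greedily from the back with std::string::rfind. Positions here are exclusive ends, so the condition p < q (with p the last index used by the left half) becomes end <= start.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Exclusive end of the left permutation when each pattern is placed at its
// first occurrence at or after the current position; -1 if it does not fit.
long long leftEnd(const std::string& s, const std::vector<std::string>& pats,
                  const std::vector<int>& order) {
    std::size_t cur = 0;
    for (int idx : order) {
        const std::size_t pos = s.find(pats[idx], cur);
        if (pos == std::string::npos) return -1;
        cur = pos + pats[idx].size();
    }
    return static_cast<long long>(cur);
}

// Start of the right permutation when it is placed from the back, each pattern
// at its last occurrence ending at or before the current limit; -1 if it does not fit.
long long rightStart(const std::string& s, const std::vector<std::string>& pats,
                     const std::vector<int>& order) {
    std::size_t limit = s.size();  // everything placed so far starts at or after limit
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const std::string& p = pats[*it];
        if (p.size() > limit) return -1;
        const std::size_t pos = s.rfind(p, limit - p.size());  // last start <= limit - |p|
        if (pos == std::string::npos) return -1;
        limit = pos;
    }
    return static_cast<long long>(limit);
}

// Valid iff both halves fit and the left half ends before the right half starts.
bool pairMatches(const std::string& s, const std::vector<std::string>& pats,
                 const std::vector<int>& left, const std::vector<int>& right) {
    const long long end = leftEnd(s, pats, left);
    const long long start = rightStart(s, pats, right);
    return end >= 0 && start >= 0 && end <= start;
}
```

With S = "abc" and P = {"ab", "c"}, the left half {"c"} ends at 3 while the right half {"ab"} must start at 0, so that pair is rejected.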

But checking every pair of permutations for every bitmask is essentially the same as iterating over all M! permutations, since each permutation is still checked individually.

However, we can use an observation here. If a right permutation is valid with a left permutation ending at position p, it is also valid with every left permutation ending at any position j \leq p.

Hence, for every bitmask with M/2 bits set, we iterate over all permutations of its M/2 patterns and, using a prefix-sum array, count for every position p the number of left permutations that end at or before p. Then, iterating over all right permutations, we can look up in O(1) the number of left permutations that can be paired with the current right permutation. Adding this number to the answer for every right permutation gives the final count.
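The counting step on its own, as a sketch with hypothetical names: given the exclusive end of every feasible left permutation and the start of every feasible right permutation (all values in 0..N), a histogram plus a prefix sum answers each right permutation with a single lookup.

```cpp
#include <cassert>
#include <vector>

// Counts pairs (i, j) with leftEnds[i] <= rightStarts[j], i.e. pairs where
// the left half ends (exclusively) no later than the right half starts.
long long countCompatiblePairs(int n, const std::vector<int>& leftEnds,
                               const std::vector<int>& rightStarts) {
    std::vector<long long> calc(n + 1, 0);
    for (int e : leftEnds) ++calc[e];                     // histogram of end positions
    for (int i = 1; i <= n; ++i) calc[i] += calc[i - 1];  // calc[p] = #left orders ending by p
    long long pairs = 0;
    for (int q : rightStarts) pairs += calc[q];           // O(1) per right order
    return pairs;
}
```

With N = 5, left ends {2, 3} and right starts {2, 3, 5}, the lookups give 1 + 2 + 2 = 5 compatible pairs.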

For ease of implementation, we build the NXT table as described, then reverse S and all patterns and build the same kind of table from the matches on the reversed strings. That way, working exactly as with the NXT table, we can easily find the rightmost position such that all patterns of the right permutation occur at or after it without overlap.
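Putting the pieces together, here is a compact C++ sketch of the whole count, modeled on the setter's and editorialist's solutions but written as a standalone function (all names are hypothetical). It assumes non-empty patterns and M small enough for an int bitmask. The left half uses the forward tables; the right half uses the tables built on the reversed strings, where an exclusive end rend on the reversed string means the right half covers S[N - rend, N).

```cpp
#include <algorithm>
#include <bitset>
#include <cassert>
#include <string>
#include <vector>

// nxt[j] = first start >= j of pat in txt, or n = |txt| (table has n + 1 entries).
std::vector<int> firstMatchTable(const std::string& txt, const std::string& pat) {
    const int n = static_cast<int>(txt.size()), m = static_cast<int>(pat.size());
    std::vector<int> pi(m, 0), nxt(n + 1, n);
    for (int i = 1, k = 0; i < m; ++i) {  // KMP failure function
        while (k > 0 && pat[i] != pat[k]) k = pi[k - 1];
        if (pat[i] == pat[k]) ++k;
        pi[i] = k;
    }
    for (int i = 0, k = 0; i < n; ++i) {  // KMP search
        while (k > 0 && txt[i] != pat[k]) k = pi[k - 1];
        if (txt[i] == pat[k]) ++k;
        if (k == m) { nxt[i - m + 1] = i - m + 1; k = pi[k - 1]; }
    }
    for (int j = n - 1; j >= 0; --j) nxt[j] = std::min(nxt[j], nxt[j + 1]);
    return nxt;
}

// Greedy placement of `order` using the given tables; exclusive end, or -1.
int place(const std::vector<int>& order, const std::vector<std::vector<int>>& nxt,
          const std::vector<std::string>& pats, int n) {
    int cur = 0;
    for (int idx : order) {
        if (nxt[idx][cur] == n) return -1;  // no occurrence at or after cur
        cur = nxt[idx][cur] + static_cast<int>(pats[idx].size());
    }
    return cur;
}

long long countMeetInTheMiddle(const std::string& s, const std::vector<std::string>& pats) {
    const int n = static_cast<int>(s.size()), m = static_cast<int>(pats.size());
    const std::string rs(s.rbegin(), s.rend());
    std::vector<std::vector<int>> fwd(m), rev(m);
    for (int i = 0; i < m; ++i) {
        fwd[i] = firstMatchTable(s, pats[i]);
        rev[i] = firstMatchTable(rs, std::string(pats[i].rbegin(), pats[i].rend()));
    }
    std::vector<long long> calc(n + 1);
    long long ans = 0;
    for (int mask = 0; mask < (1 << m); ++mask) {
        if (static_cast<int>(std::bitset<32>(mask).count()) != m / 2) continue;
        std::vector<int> left, right;
        for (int i = 0; i < m; ++i) {
            if ((mask >> i) & 1) left.push_back(i);
            else right.push_back(i);
        }
        // calc[e] = number of left orders whose greedy placement ends (exclusively) at e.
        std::fill(calc.begin(), calc.end(), 0);
        do {
            const int end = place(left, fwd, pats, n);
            if (end >= 0) ++calc[end];
        } while (std::next_permutation(left.begin(), left.end()));
        for (int i = 1; i <= n; ++i) calc[i] += calc[i - 1];  // now: ends at or before i
        // An order placed on the reversed string is a right order read backwards.
        // Since all orders are enumerated, every right order is counted once.
        do {
            const int rend = place(right, rev, pats, n);  // covers s[n - rend, n)
            if (rend >= 0) ans += calc[n - rend];
        } while (std::next_permutation(right.begin(), right.end()));
    }
    return ans;
}
```

On small inputs its result can be compared against a brute-force enumeration of all M! orders.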

Time Complexity

Time complexity is O(\binom{M}{M/2} \cdot (N + (M/2)! \cdot M) + M \cdot N + \sum_i |P_i|).
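To make the bound concrete, assume N = 10^5 and M = 14. These values are inferred from the setter's array sizes (N = 1e5 + 5, K = 14), not stated limits. Ignoring the smaller preprocessing term and all constant factors, compare the main term with the 14! permutations the naive approach has to enumerate:

```latex
\begin{aligned}
\binom{14}{7}\left(10^5 + 7!\cdot 14\right) &= 3432\cdot(100\,000 + 70\,560) = 585\,361\,920 \approx 5.9\cdot 10^8,\\
14! &= 87\,178\,291\,200 \approx 8.7\cdot 10^{10}.
\end{aligned}
```

So the meet-in-the-middle bound is roughly 150 times smaller than the naive enumeration's permutation count alone, and each of its steps is a cheap table lookup or array update.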

AUTHOR’S AND TESTER’S SOLUTIONS:

Setter’s solution

#include <bits/stdc++.h>
using namespace std;
 
#define pb push_back
 
using ll = long long;
using ii = pair<int, int>;
 
const int N = 1e5 + 5, K = 14;
string S, P[K];
int n, k;
int nxt[K][N], rnxt[K][N];
int dp[1005][1 << K];
 
vector<int> FAIL(string pat) {
 
	int m = pat.size();
	vector<int> F(m + 1);
	int i = 0, j = -1;
	F[0] = -1;
 
	while (i < m) {
		while (j >= 0 && pat[i] != pat[j])
			j = F[j];
		i++, j++;
		F[i] = j;
	}
 
	return F;
}
 
vector<int> KMP_Search(string txt, string pat) {
 
	vector<int> F = FAIL(pat);
	int i = 0, j = 0;
	int n = txt.size(), m = pat.size();
	vector<int> ret;
	while (i < n) {
		while (j >= 0 && txt[i] != pat[j])
			j = F[j];
		i++, j++;
		if (j == m) {
			ret.pb(i - j);
			j = F[j];
		}
	}
 
	return ret;
}
 
vector<vector<int>> Match, rMatch;
 
void buildNext(int * arr, vector<int> matches) {
	//first matching >= i
	for (int i = 0; i <= n; i++)
		arr[i] = n;
	for (auto x : matches)
		arr[x] = x;
	for (int i = n - 2; i >= 0; --i)
		arr[i] = min(arr[i], arr[i + 1]);
}
int calc[N];
int solve(int idx, int mask) {
	if (idx > n)
		return 0;
	if (mask == (1 << k) - 1)
		return 1;
	int &ret = dp[idx][mask];
	if (~ret)
		return ret;
	ret = 0;
	for (int j = 0; j < k; j++)
		if (mask >> j & 1 ^ 1) {
			ret += solve(nxt[j][idx] + P[j].size(), mask | (1 << j));
		}
	return ret;
}
void Stress1() {
	///solution for subtask 2
	///first brute force dp[index][mask];
	///use next[index]
	///can solve subtasks 1-2
	memset(dp, -1, sizeof dp);
	cerr << solve(0, 0) << '\n';
}
int getFirst(vector<int> & x) {
	if (x.empty())
		return n;
	return x[0];
}
int getLast(vector<int> & x) {
	if (x.empty())
		return -1;
	return x.back();
}
 
void Stress2() {
	///Solution for subtask 1
	if (k == 1) {
		cerr << !KMP_Search(S, P[0]).empty() << '\n';
	} else if (k == 2) {
		vector<int> x1 = KMP_Search(S, P[0]);
		vector<int> x2 = KMP_Search(S, P[1]);
		int ans = 0;
		if (getFirst(x1) + (int) P[0].size() <= getLast(x2))
			ans++;
		if (getFirst(x2) + (int) P[1].size() <= getLast(x1))
			ans++;
		cerr << ans << '\n';
	} else {
		vector<int> x[3];
		x[0] = KMP_Search(S, P[0]);
		x[1] = KMP_Search(S, P[1]);
		x[2] = KMP_Search(S, P[2]);
		int ans = 0;
		for (int i = 0; i < 3; i++) {
			for (int j = 0; j < 3; j++) {
				if (i == j)
					continue;
				for (auto v : x[3 - i - j]) {
//					cout << i << ' ' << getFirst(x[i]) << ' ' << j << ' ' << getLast(x[j]) << ' ' << v << '\n';
					if (getFirst(x[i]) + (int) P[i].size() <= v
							&& v + (int) P[3 - i - j].size() <= getLast(x[j])) {
						ans++;
						break;
					}
				}
			}
		}
		cerr << ans << '\n';
	}
}
int main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);
 
	cin >> n >> k;
	cin >> S;
 
	for (int i = 0; i < k; i++)
		cin >> P[i];
 
	///get matchings and build next array for them
	///where nxt[j][i] = next matching position for the j-th pattern which is more than or equal i
	for (int i = 0; i < k; i++) {
		Match.pb(KMP_Search(S, P[i]));
		buildNext(nxt[i], Match.back());
	}
 
	reverse(S.begin(), S.end());
 
	///do the same as above for reverse strings and patterns
	for (int i = 0; i < k; i++) {
		reverse(P[i].begin(), P[i].end());
		rMatch.pb(KMP_Search(S, P[i]));
		buildNext(rnxt[i], rMatch.back());
		reverse(P[i].begin(), P[i].end());
	}
 
	reverse(S.begin(), S.end());
 
	ll ans = 0;
	for (int mask = 0; mask < (1 << k); mask++)
		if (__builtin_popcount(mask) == k / 2) {
			///process the normal
			///get the indexes
			///brute force on all permutations and find the minimum suffix
			///needed to get all of these matched for each permutation
			///the use partial sum to pre-process the results
			vector<int> v;
			for (int i = 0; i < k; i++)
				if (mask >> i & 1) {
					v.pb(i);
				}
			memset(calc, 0, sizeof calc);
			do {
				int cur = 0;
				for (auto x : v) {
					if (rnxt[x][cur] == n)
						goto fin1;
					cur = rnxt[x][cur] + P[x].size();
				}
 
				///i have the last cur digits covered
				calc[cur]++;
 
				fin1: ;
			} while (next_permutation(v.begin(), v.end()));
 
			///partial sum
			for (int i = 1; i <= n; i++)
				calc[i] += calc[i - 1];
 
			///solve the flip
			v.clear();
			int flip = ((1 << k) - 1) ^ mask;
			for (int i = 0; i < k; i++)
				if (flip >> i & 1) {
					v.pb(i);
				}
			///get the indexes brute force on them find the minimum
			///prefix to cover these patterns find all suffixes in the previous calculation
			do {
				int cur = 0;
				for (auto x : v) {
					if (nxt[x][cur] == n)
						goto fin2;
					cur = nxt[x][cur] + P[x].size();
				}
				ans += calc[n - cur];
				fin2: ;
			} while (next_permutation(v.begin(), v.end()));
 
		}
 
	cout << ans << '\n';
 
	return 0;
}

Tester’s solution

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <map>
using namespace std;
typedef long long llong;
 
int n,m;
char s[100111];
char pattern[100111];
int L = 0;
 
int pLens[15];
 
int isMatch[100111];
int mKey = 1;
 
int nextMatch[15][100111];
 
///Knuth-Morris-Pratt
int F[100111];
void findMatches()
{
    int i;
    int k;
 
    //Failure function
    F[1] = 0;
    for (i=2;i<=L;i++)
    {
	k = F[i-1];
 
	while(k != 0 && pattern[k+1] != pattern[i])
	    k = F[k];
 
	if (k == 0)
	{
	    if (pattern[1] == pattern[i])
	        F[i] = 1;
	    else
	        F[i] = 0;
	}
	else
	    F[i] = k+1;
    }
 
    //Matching
    mKey++;
 
    k = 0;
    for (i=1;i<=n;i++)
    {
	while(k != 0 && pattern[k+1] != s[i])
	    k = F[k];
 
	if (k == 0)
	{
	    if (pattern[1] == s[i])
	        k = 1;
	    else
	        k = 0;
	}
	else
	    k++;
 
	if (k == L)
	{
	    isMatch[i - L + 1] = mKey;
 
	    k = F[k];
	}
    }
}
 
map< pair<int,int>, llong > mem;
map< pair<int,int>, llong >::iterator myit;
 
llong solve(int ind, int mask)
{
    if (mask == ((1<<m)-1))
	return 1LL;
 
    myit = mem.find(make_pair(ind,mask));
    if (myit != mem.end())
	return (*myit).second;
 
    int i;
    llong ans = 0;
 
    for (i=1;i<=m;i++)
    {
	if (nextMatch[i][ind] > n)
	    continue;
 
	if ( (mask&(1<<(i-1))) == 0 )
	{
	    ans += solve(nextMatch[i][ind] + pLens[i], mask | (1<<(i-1)));
	}
    }
 
    mem.insert(make_pair(make_pair(ind,mask),ans));
 
    return ans;
}
 
int main()
{
    int i,j;
 
    scanf("%d %d",&n,&m);
 
    scanf("%s",s+1);
 
    for (i=1;i<=m;i++)
    {
	scanf("%s",pattern+1);
	L = strlen(pattern+1);
 
	findMatches();
 
	nextMatch[i][n+1] = n+1;
	for (j=n;j>=1;j--)
	{
	    if (isMatch[j] == mKey)
	        nextMatch[i][j] = j;
	    else
	        nextMatch[i][j] = nextMatch[i][j+1];
	}
 
	pLens[i] = L;
    }
 
    printf("%lld\n",solve(1, 0));
 
    return 0;
}

Editorialist’s solution

import java.util.*;
import java.io.*;
import java.text.*;
//Solution Credits: Taranpreet Singh
public class Main{
    //SOLUTION BEGIN
    void pre() throws Exception{}
    long[] calc;
    void solve(int TC) throws Exception{
        int n = ni(), m = ni();
        String sf = n(), sr = new StringBuilder(sf).reverse().toString();
        int[][] nxtf = new int[m][n], nxtr = new int[m][n];int[] len = new int[m];
        calc = new long[n+1];
        for(int i = 0; i< m; i++){
            String x = n();len[i] = x.length();
            KMP(nxtf[i], sf, x);
            KMP(nxtr[i], sr, new StringBuilder(x).reverse().toString());
        }
        for(int mask = 0; mask< 1<<m; mask++){
            if(bit(mask)!=m/2)continue;
            int[] a = new int[bit(mask)];
            for(int i = 0, j = 0; i< m; i++)if(((mask>>i)&1)==1)a[j++] = i;
            Arrays.fill(calc, 0);
            permute(a, 0, nxtf,len, false);
            for(int i = 1; i< calc.length; i++)calc[i]+=calc[i-1];
            int[] b = new int[m-a.length];
            for(int i = 0, j = 0; i< m; i++)if(((mask>>i)&1)==0)b[j++] = i;
            permute(b, 0, nxtr,len, true);
        }
        pn(ans);
    }
    long ans = 0;
    void permute(int[] a, int pos,int[][] nxt,int[] len, boolean flag){
        if(pos==a.length){
            int cur = 0, n = nxt[0].length;
            for(int i = 0; i< a.length; i++){
                if(cur==n || nxt[a[i]][cur]==n)return;
                cur = nxt[a[i]][cur]+len[a[i]];
            }
            if(flag)ans+=calc[n-cur];
            else calc[cur]++;
        }else{
            for(int i = pos; i< a.length; i++){
                int tmp = a[i];
                a[i] = a[pos];
                a[pos] = tmp;
                permute(a,pos+1,nxt,len, flag);
                tmp = a[i];
                a[i] = a[pos];
                a[pos] = tmp;
            }
        }
    }
    void KMP(int[] nxt, String txt, String pat){ 
        int M = pat.length(); 
        int N = txt.length(); 
        Arrays.fill(nxt,N);
        int lps[] = new int[M]; 
        int i = 0,j = 0;
        computeLPSArray(pat, M, lps); 
        while (i < N) { 
            if (pat.charAt(j) == txt.charAt(i)) { 
                j++;i++; 
            } 
            if (j == M) { 
                nxt[i-j] = i-j;
                j = lps[j - 1]; 
            }else if (i < N && pat.charAt(j) != txt.charAt(i)) { 
                if (j != 0)j = lps[j - 1]; 
                else i++;
            } 
        } 
        for(i = N-2; i>= 0; i--)nxt[i] = Math.min(nxt[i], nxt[i+1]);
    } 
  
    void computeLPSArray(String pat, int M, int lps[]){ 
        int len = 0; 
        int i = 1; 
        lps[0] = 0;
        while (i < M) { 
            if (pat.charAt(i) == pat.charAt(len)) { 
                len++; 
                lps[i] = len; 
                i++; 
            }else{ 
                if(len != 0){len = lps[len - 1]; 
                }else{ 
                    lps[i] = len; 
                    i++; 
                } 
            } 
        } 
    }
    //SOLUTION END
    void hold(boolean b)throws Exception{if(!b)throw new Exception("Hold right there, Sparky!");}
    long mod = (long)1e9+7, IINF = (long)1e18;
    final int INF = (int)1e9, MX = (int)2e3+1;
    DecimalFormat df = new DecimalFormat("0.00000000000");
    double PI = 3.1415926535897932384626433832792884197169399375105820974944, eps = 1e-8;
    static boolean multipleTC = false, memory = false;
    FastReader in;PrintWriter out;
    void run() throws Exception{
        in = new FastReader();
        out = new PrintWriter(System.out);
        int T = (multipleTC)?ni():1;
        //Solution Credits: Taranpreet Singh
        pre();for(int t = 1; t<= T; t++)solve(t);
        out.flush();
        out.close();
    }
    public static void main(String[] args) throws Exception{
        if(memory)new Thread(null, new Runnable() {public void run(){try{new Main().run();}catch(Exception e){e.printStackTrace();}}}, "1", 1 << 28).start();
        else new Main().run();
    }
    long gcd(long a, long b){return (b==0)?a:gcd(b,a%b);}
    int gcd(int a, int b){return (b==0)?a:gcd(b,a%b);}
    int bit(long n){return (n==0)?0:(1+bit(n&(n-1)));}
    void p(Object o){out.print(o);}
    void pn(Object o){out.println(o);}
    void pni(Object o){out.println(o);out.flush();}
    String n()throws Exception{return in.next();}
    String nln()throws Exception{return in.nextLine();}
    int ni()throws Exception{return Integer.parseInt(in.next());}
    long nl()throws Exception{return Long.parseLong(in.next());}
    double nd()throws Exception{return Double.parseDouble(in.next());}
 
    class FastReader{
        BufferedReader br;
        StringTokenizer st;
        public FastReader(){
            br = new BufferedReader(new InputStreamReader(System.in));
        }
 
        public FastReader(String s) throws Exception{
            br = new BufferedReader(new FileReader(s));
        }
 
        String next() throws Exception{
            while (st == null || !st.hasMoreElements()){
                try{
                    st = new StringTokenizer(br.readLine());
                }catch (IOException  e){
                    throw new Exception(e.toString());
                }
            }
            return st.nextToken();
        }
 
        String nextLine() throws Exception{
            String str = "";
            try{   
                str = br.readLine();
            }catch (IOException e){
                throw new Exception(e.toString());
            }  
            return str;
        }
    }
} 

Feel free to share your approach if it differs. Suggestions are always welcome. :slight_smile:


I was able to solve it using a kind of brute-force recursion with memoization.
I'm not sure whether the tests just didn't include a strong enough case or whether the approach is really valid. Anyway, it ran on all test cases in 0.07 s.

  1. First, using KMP, find all occurrences of each pattern in the string and store those matches.

  2. The main function F(mask, L, total) finds the number of permutations of the subset of patterns in the main string starting at position L. Here: mask - specifies which patterns to include in the count, L - the starting position in the main string from which to start the search of permutations, total - the total length of the patterns specified by mask.

  3. The function iterates over the set bits of the mask. For each set bit i, it finds the first occurrence (starting at L) of the corresponding pattern (let's call it P_i). Then the function calls itself to count the number of permutations of the remaining patterns starting at the position right after the identified occurrence of P_i. When the mask is zero, the recursion stops with a permutation count of 1.

  4. By itself this solution would be O(M!), which is too slow. However, the following memoization makes it pass (AC):

  • Each call to the function produces a triplet (L, L_{act}, Count), where L is the starting position for the permutation search and L_{act} \geq L is the actual starting position that produced the permutation count Count.
  • For a given mask any starting position in the interval [L, L_{act}] will produce the same count Count of permutations of patterns specified by the mask.
  • For each mask, we can store these ranges in the map dp[mask] that consists of entries like: \{L: (L_{act}, Count)\}
  • With each new request for the same mask, we can check if L is included in one of the ranges. If not, then permutations are counted and the map is updated with one of the three outcomes: (a) The new entry \{L: (L_{act}, Count)\} is added to dp[mask], (b) The existing entry is updated with the new value of L or (c) The existing entry is updated with the new value of L_{act}.
  5. The final answer is F(mask = (1 << M) - 1, L = 0, total = \sum_i |P_i|).

Here is the solution: CodeChef: Practical coding for everyone


I tried the same approach as @shoom, but could only get 40 points. Can someone please tell me where I was lagging?
Solution : https://www.codechef.com/viewsolution/23525369

My solution was super simple. It just computes dp(L, mask) = how many permutations of the patterns contained in mask can be formed up to position L (inclusive). What we care about is dp(N, 2^M - 1). By itself, there would be way too many states. But we don’t care about most of them. We can implement the dp function recursively with memoization. We call dp(N, 2^M - 1). Within a recursive call dp(L, mask) we iterate over all patterns j in the mask and we assume j is the last pattern in the permutation. For a fixed j we need to add dp(L’, mask xor 2^j) to the result, where L’ is the position just before the rightmost match of pattern j up to position L (L’ can be obtained in O(1) time by using something similar to the NXT table from the editorial).
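A minimal C++ sketch of the top-down dp(L, mask) described above, written from the description rather than from the commenter's actual code (all names are hypothetical). Here L is an exclusive prefix length, so "the position just before the rightmost match" becomes the start of that match. For brevity it finds occurrences with std::string::find, which is worst-case O(N \cdot |P|); KMP would replace it at the real limits. As in the comment, no bound on the number of visited states is claimed.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// count(L, mask) = number of orders of the patterns in `mask` that fit,
// without overlap, in the prefix s[0, L).
class PrefixDp {
public:
    PrefixDp(const std::string& s, const std::vector<std::string>& pats)
        : m_(static_cast<int>(pats.size())),
          prevStart_(pats.size(), std::vector<int>(s.size() + 1, -1)) {
        const int n = static_cast<int>(s.size());
        for (int j = 0; j < m_; ++j) {
            const std::string& p = pats[j];
            // Simplification: std::string::find instead of KMP.
            for (std::size_t pos = s.find(p); pos != std::string::npos; pos = s.find(p, pos + 1))
                prevStart_[j][pos + p.size()] = static_cast<int>(pos);
            // prevStart_[j][L] = start of the rightmost occurrence of P_j ending by L, or -1.
            for (int L = 1; L <= n; ++L)
                prevStart_[j][L] = std::max(prevStart_[j][L], prevStart_[j][L - 1]);
        }
        memo_.reserve(1 << 16);  // pre-size the hash table to avoid early rehashing
    }

    long long count(int L, int mask) {
        if (mask == 0) return 1;  // nothing left to place
        const long long key = (static_cast<long long>(L) << m_) | mask;
        const auto it = memo_.find(key);
        if (it != memo_.end()) return it->second;
        long long total = 0;
        for (int j = 0; j < m_; ++j) {
            if (!((mask >> j) & 1)) continue;
            const int start = prevStart_[j][L];  // treat P_j as the last pattern
            if (start >= 0) total += count(start, mask ^ (1 << j));
        }
        memo_[key] = total;  // no iterator is held across the recursive calls
        return total;
    }

private:
    int m_;
    std::vector<std::vector<int>> prevStart_;
    std::unordered_map<long long, long long> memo_;
};

long long countPermutationsDp(const std::string& s, const std::vector<std::string>& pats) {
    PrefixDp dp(s, pats);
    return dp.count(static_cast<int>(s.size()), (1 << pats.size()) - 1);
}
```

The answer is dp(N, 2^M - 1), which countPermutationsDp evaluates.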

I don’t have a formal proof why there won’t be too many visited states, but I tried many types of test cases locally and this solution was always quite fast. OK, it took 1.7 seconds on the Codechef servers, but that’s well within the time limit. Moreover, I’m pretty sure an important chunk of this duration is caused by the unordered maps I used (I didn’t use the well known tricks to set some parameters which usually make unordered maps perform significantly faster than using them with their default params).


The setter’s solution gives a compilation error in C++14 (GCC 6.3).


This C++ code gets AC for 100 points.

The same approach implemented in Java gets TLE.

The logic in both languages follows the editorial.

You are memoizing positions. It is very unlikely to hit exactly the same position. My solution is to memoize ranges. Any starting position that falls into a memoized range would have the same result. In addition, if the requested starting position is outside the range, but the resulting count is the same, the range can be extended.

Yeah, I did the same thing, but I also made the unordered map a bit faster by reserving memory. It ran in 1.18 s. Here’s the submission: CodeChef: Practical coding for everyone.

Yes, quite a few solutions passed that way.

Extra characters were appearing due to LaTeX (this is my first time posting code snippets in an editorial). I’m trying to correct it now.

I hope it seems okay now. :slight_smile: