How to check if a string is a concatenation from collection of words? (exact match)

Given a set of words (or patterns), how can we find out if a string is a concatenation of these words (repetition is allowed). For example:

W = {"12", "122"}    
s = "12122" -> match  
s = "1122" -> miss  
s = "1221212" -> match 

I came up with a crappy solution O(n^2) :frowning:

  bool slow_match(const std::string &s, const std::unordered_set<std::string> &dict) {
	std::vector<int> exists(s.size(), 0);
	for (int i = 0; i < s.size(); ++i) {
		if (dict.find(s.substr(0, i + 1)) != dict.end()) {
			exists[i] = i + 1;
		}
		for (int j = 0; j < i && exists[i] == 0; ++j) {
			if (exists[j] != 0 && dict.find(s.substr(j + 1, i - j)) != dict.end()) {
				exists[i] = i - j;
			}
		}
	}
	return exists[s.size() - 1];
}

So is there a faster way to check if a match occurs? Iā€™m trying to use Aho Corasick (from Codechef August Long Contest 2013 - Music & Lyrics), but as long as there is a match for any word, it matches the whole text! Any idea?
This is what I got so far:

#include <iostream>
#include <fstream>
#include <sstream>
#include <iomanip>
#include <cstdio>
#include <cstdlib>

#include <algorithm>
#include <functional>
#include <utility>
#include <memory>
#include <cassert>
#include <cctype>

#include <exception>
#include <stdexcept>
#include <string>
#include <cstring>
#include <limits>
#include <climits>
#include <numeric>
#include <cmath>

#include <vector>
#include <list>
#include <stack>
#include <deque>
#include <queue>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include <cstring>

const int MAX_STATES = 100;
const int CHARSET = 10;

class aho_corasick {
	// deterministics finite automaton
	int dfa[MAX_STATES][CHARSET];
	int fail[MAX_STATES];
	bool accepting[MAX_STATES];
	// current state
	int ct;

public:
	aho_corasick() {
		ct = 0;
		for (int s = 0; s < MAX_STATES; ++s) {
			accepting[s] = false;
			for (int c = 0; c < CHARSET; ++c) {
				dfa[s][c] = -1;
			}
		}
	}

	void build_fail_function() {
		// fail to root is root
		fail[0] = 0;
		for (int a = 0; a < CHARSET; ++a) {
			if (dfa[0][a] != -1) {
				// make all valid transition from root fall
				// back to root
				fail[dfa[0][a]] = 0;
			} else {
				dfa[0][a] = 0;
			}
		}
	}

	void build_dfa() {
		build_fail_function();
		std::queue<int> q;
		for (int a = 0; a < CHARSET; ++a) {
			if (dfa[0][a] > 0) {
				q.push(dfa[0][a]);
			}
		}
		while (!q.empty()) {
			int next_state = q.front();
			q.pop();
			// for all transitions starting from next_state
			for (int a = 0; a < CHARSET; ++a) {
				if (dfa[next_state][a] > 0) {
					fail[dfa[next_state][a]] = dfa[fail[next_state]][a];
					q.push(dfa[next_state][a]);
				} else {
					dfa[next_state][a] = dfa[fail[next_state]][a];
				}
			}
		}
	}

	int map(char c) const {
		return c - '0';
	}

	int new_state() {
		return ct++;
	}

	bool match(const std::string &text) const {
		int s = 0;
		int a;
		for (int i = 0, sz = text.size(); i < sz; ++i) {
			a = map(text[i]);
			s = dfa[s][a];
			// std::cout << "s = " << s << "\n";
		}
		return accepting[s];
	}

	void add_pattern(const std::string &pattern) {
		int s = 0;
		int a;
		for (int i = 0, sz = pattern.size(); i < sz; ++i) {
			a = map(pattern[i]);
			if (dfa[s][a] == -1) {
				dfa[s][a] = ct++;
				for (int b = 0; b < CHARSET; ++b) {
					dfa[ct][b] = -1;
				}
			}
			// nove to next state
			s = dfa[s][a];
			std::cout << "s: " << s << "\n";
		}
		accepting[s] = true;
	}
};

void test() {
	aho_corasick ac;
	ac.add_pattern("12");
	ac.add_pattern("122");
	ac.build_dfa();
	std::cout << ac.match("5100") << "\n";
	std::cout << ac.match("1266") << "\n";
}

int main() {
	test();
	return 0;
}
1 Like