AC-Search实现

实现了一遍用于多字符串匹配的 Aho-Corasick(AC)算法,输出各模式串在原串中的所有匹配区间。

build_trie()构造一个trie;build_longest_trans_target()为自动机构造无法继续前向匹配时的失败转移路径,即TrieNode::prefix字段。当自动机到达一个终止状态(terminal state,即TrieNode::pattern_end为true)时,需要构造出以原串当前位置结尾的所有匹配区间:做法是先把所有pattern串反转(reverse)后另外构造一棵trie(见build_suffix()),然后从当前位置向前逐字符遍历这棵树。如果有更好的匹配区间构造方法,请留言。

原理请参考: <Flexible Pattern Matching in Strings>, 3.2.2 Basic Aho-Corasick Algorithm.

Trie & reverse(patterns)之后的trie树和状态机构造代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// One node of a trie stored in "first child / next sibling" form.
// Every field starts out zeroed / null.
struct TrieNode {
    // Next node in the same sibling list (the builders keep each list sorted
    // in ascending this_char order).
    TrieNode* sibling = nullptr;
    // Head of the list of child nodes.
    TrieNode* children = nullptr;
    // Fallback node used when no child matches the next input character;
    // filled in by build_longest_trans_target(), nullptr means "restart from
    // the top-level list".
    TrieNode* prefix = nullptr;
    // Depth of the node: 1 for the first character of a pattern, 2 for the
    // second, and so on.
    int height = 0;
    // True when a pattern ends at this node.
    bool pattern_end = false;
    // Character on the edge leading into this node.
    char this_char = 0;

    TrieNode() {}
};

// Allocates a fresh, default-initialised trie node on the heap.
// The caller owns the node and hands it back through release_node().
TrieNode* get_node() {
    TrieNode* fresh = new TrieNode();
    return fresh;
}

// Destroys a single node; sibling and child nodes are left untouched.
void release_node(TrieNode* node) {
    if(node != nullptr) {
        delete node;
    }
}

// Releases every node reachable from `root` through sibling and children
// links. Uses an explicit stack instead of recursion; a null `root` is a no-op.
void free_trie(TrieNode* root) {
    vector<TrieNode*> pending;
    pending.push_back(root);
    while(!pending.empty()) {
        TrieNode* node = pending.back();
        pending.pop_back();
        if(node == nullptr)
            continue;

        // Read both links before the node is destroyed.
        TrieNode* next_sibling = node->sibling;
        TrieNode* first_child = node->children;
        release_node(node);

        // The child list is pushed last so it is handled before the sibling.
        pending.push_back(next_sibling);
        pending.push_back(first_child);
    }
}

// Builds the forward trie for `strs` and returns the head of its top-level
// sibling list (nullptr when `strs` is empty).
//
// Precondition: `strs` holds no duplicates and no empty strings; both cases
// trip the assertion below.
TrieNode* build_trie(const vector<string>& strs) {
    TrieNode* root = nullptr;
    for(const string& pattern : strs) {
        TrieNode** slot = &root;   // link through which the current list is reached
        TrieNode* last = nullptr;  // node of the most recently consumed character
        int depth = 0;
        for(const char ch : pattern) {
            ++depth;
            // Sibling lists are sorted ascending: skip every smaller character.
            while(*slot && (*slot)->this_char < ch)
                slot = &(*slot)->sibling;

            if(*slot && (*slot)->this_char == ch) {
                // The character already exists at this depth: reuse its node.
                last = *slot;
            } else {
                // New character: splice a node in front of the first larger one.
                last = get_node();
                last->this_char = ch;
                last->height = depth;
                last->sibling = *slot;
                *slot = last;
            }
            slot = &last->children;
        }
        // A set flag here would mean the same pattern was inserted twice.
        assert(last && last->pattern_end==false);
        last->pattern_end = true;
    }
    return root;
}

// (Never mind the name.) Builds a trie of the *reversed* patterns, so walking
// it while reading the haystack backwards enumerates the patterns that end at
// a given position. Returns the head of the top-level sibling list.
//
// Precondition: same as build_trie() - no duplicates and no empty strings.
TrieNode* build_suffix(const vector<string>& strs) {
    TrieNode* root = nullptr;
    for(const string& pattern : strs) {
        TrieNode** slot = &root;   // link through which the current list is reached
        TrieNode* last = nullptr;  // node of the most recently consumed character
        int depth = 0;
        // Consume the pattern from its last character to its first.
        for(string::const_reverse_iterator it = pattern.rbegin(); it != pattern.rend(); ++it) {
            const char ch = *it;
            ++depth;
            // Sibling lists are sorted ascending: skip every smaller character.
            while(*slot && (*slot)->this_char < ch)
                slot = &(*slot)->sibling;

            if(*slot && (*slot)->this_char == ch) {
                last = *slot;
            } else {
                last = get_node();
                last->this_char = ch;
                last->height = depth;
                last->sibling = *slot;
                *slot = last;
            }
            slot = &last->children;
        }
        // A set flag here would mean the same pattern was inserted twice.
        assert(last && last->pattern_end==false);
        last->pattern_end = true;
    }
    return root;
}

// Scans the sibling list starting at `list_head` for the node labelled
// `target`. Returns that node, or nullptr when the list has no such node
// (including when `list_head` itself is nullptr).
TrieNode* list_search(TrieNode* list_head, char target) {
    for(TrieNode* node = list_head; node != nullptr; node = node->sibling) {
        if(node->this_char == target)
            return node;
    }
    return nullptr;
}

// Fills in TrieNode::prefix (the fallback taken when no child matches) for
// every node of the forward trie rooted at `root`, and marks a node as
// pattern_end when any node on its fallback chain ends a pattern.
//
// Top-level nodes keep prefix == nullptr, which ac_search() treats as
// "restart from the top-level list".
//
// The trie is visited breadth-first. A node's fallback always points to a
// strictly shallower node, and computing it walks the fallback chain of the
// parent's fallback; level order guarantees that all of those links (and their
// pattern_end flags) are already final when they are read. A depth-first
// order can read a still-unset link in a branch that has not been visited yet,
// which yields too-short fallbacks and missed matches.
void build_longest_trans_target(TrieNode* root) {
    assert(root!=nullptr);

    // FIFO work list: a vector plus a read cursor. Elements are only appended,
    // and nodes are copied out by value, so growth during the loop is safe.
    vector<TrieNode*> queue;
    for(TrieNode* top = root; top; top = top->sibling)
        queue.push_back(top);

    for(vector<TrieNode*>::size_type head = 0; head < queue.size(); ++head) {
        TrieNode* current = queue[head];
        for(TrieNode* child = current->children; child; child = child->sibling) {
            // Try to extend each fallback of the parent, longest first, with
            // the child's character.
            TrieNode* candidate = current->prefix;
            while(candidate && child->prefix==nullptr) {
                child->prefix = list_search(candidate->children, child->this_char);
                candidate = candidate->prefix;
            }
            // Nothing could be extended: fall back to a one-character match
            // in the top-level list, if there is one.
            if(child->prefix==nullptr)
                child->prefix = list_search(root, child->this_char);
            assert(child->prefix!=child);

            // The fallback node's flag already covers its whole chain, so a
            // single look is enough.
            if(child->prefix && child->prefix->pattern_end)
                child->pattern_end = true;

            queue.push_back(child);
        }
    }
}

// Dumps the trie in pre-order as "(char,height,pattern_end) " triples on a
// single line. `newline` is true only for the outermost call, which ends the
// line; recursive calls pass false.
void print_tree(TrieNode* root, bool newline=true) {
    for(TrieNode* node = root; node != nullptr; node = node->sibling) {
        cout << "(" << node->this_char << "," << node->height << ","
             << node->pattern_end << ") ";
        print_tree(node->children, false);
    }
    if(newline)
        cout << endl;
}

ac-search & test

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// Appends to `result` the half-open range [begin, pos+1) of every pattern
// that ends at haystack[pos].
//
// suffix   - top-level list of the reversed-pattern trie (see build_suffix()).
// haystack - text being searched.
// pos      - index of the last character of the match(es); must be a valid
//            index into `haystack`.
// result   - output vector; ranges are appended shortest match first.
//
// The haystack is read backwards from `pos` while descending the reversed
// trie; each node flagged pattern_end on the way is one occurrence.
void fill_pattern_occurrence(TrieNode* suffix, const string& haystack, int pos,
                             vector<pair<int, int>>& result) {
    TrieNode* level = suffix;
    for(int i=pos; i>=0 && level; --i) {
        TrieNode* hit = list_search(level, haystack[i]);
        // No reversed pattern continues with this character: stop here rather
        // than dereferencing a null node.
        if(hit==nullptr)
            break;
        if(hit->pattern_end) {
            // [i, pos+1) ~ [i, pos]
            result.push_back(make_pair(i, pos+1));
        }
        level = hit->children;
    }
}

// Runs the automaton over `haystack` and returns every pattern occurrence as a
// half-open range [begin_pos, end_pos).
//
// root   - top-level list of the forward trie; its prefix links and
//          pattern_end flags are read here, so it is expected to have been
//          processed by build_longest_trans_target().
// suffix - top-level list of the reversed-pattern trie (build_suffix()).
// Both must be non-null.
vector<pair<int, int>> ac_search(TrieNode* root, TrieNode* suffix, const string& haystack) {
    assert(root!=nullptr);
    assert(suffix!=nullptr);

    vector<pair<int, int>> matches;
    // Node reached after the previous character; nullptr means "at the start".
    TrieNode* state = nullptr;
    for(int pos=0; pos<haystack.size(); ++pos) {
        const char ch = haystack[pos];

        // Try the current node first, then each fallback in turn.
        TrieNode* next = nullptr;
        for(TrieNode* fallback = state; fallback && !next; fallback = fallback->prefix)
            next = list_search(fallback->children, ch);
        // Last resort: start a new match at the top level.
        if(!next)
            next = list_search(root, ch);

        // At least one pattern ends here; the reversed trie tells which ones.
        if(next && next->pattern_end)
            fill_pattern_occurrence(suffix, haystack, pos, matches);

        state = next;
    }
    return matches;
}

int main(int argc, char** argv) {
vector<string> patterns = {
"announce",
"annual",
"annually",
};
string haystack = "annual_announce";

TrieNode* prefix = build_trie(patterns);
build_longest_trans_target(prefix);
TrieNode* suffix = build_suffix(patterns);

print_tree(prefix);
print_tree(suffix);

auto result = ac_search(prefix, suffix, haystack);
for(auto& item : result) {
cout << item.first << ", " << item.second << endl;
}

free_trie(prefix);
free_trie(suffix);
return 0;
}

完整代码:github