summaryrefslogtreecommitdiffstats
path: root/modules/policy/lua-aho-corasick/ac_fast.cxx
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-06 00:55:53 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-06 00:55:53 +0000
commit3d0386f27ca66379acf50199e1d1298386eeeeb8 (patch)
treef87bd4a126b3a843858eb447e8fd5893c3ee3882 /modules/policy/lua-aho-corasick/ac_fast.cxx
parentInitial commit. (diff)
downloadknot-resolver-3d0386f27ca66379acf50199e1d1298386eeeeb8.tar.xz
knot-resolver-3d0386f27ca66379acf50199e1d1298386eeeeb8.zip
Adding upstream version 3.2.1.upstream/3.2.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'modules/policy/lua-aho-corasick/ac_fast.cxx')
-rw-r--r--modules/policy/lua-aho-corasick/ac_fast.cxx468
1 files changed, 468 insertions, 0 deletions
diff --git a/modules/policy/lua-aho-corasick/ac_fast.cxx b/modules/policy/lua-aho-corasick/ac_fast.cxx
new file mode 100644
index 0000000..9dbc2e6
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_fast.cxx
@@ -0,0 +1,468 @@
+#include <algorithm> // for std::sort
+#include "ac_slow.hpp"
+#include "ac_fast.hpp"
+
+uint32
+AC_Converter::Calc_State_Sz(const ACS_State* s) const {
+ AC_State dummy;
+ uint32 sz = offsetof(AC_State, input_vect);
+ sz += s->Get_GotoNum() * sizeof(dummy.input_vect[0]);
+
+ if (sz < sizeof(AC_State))
+ sz = sizeof(AC_State);
+
+ uint32 align = __alignof__(dummy);
+ sz = (sz + align - 1) & ~(align - 1);
+ return sz;
+}
+
+AC_Buffer*
+AC_Converter::Alloc_Buffer() {
+ const vector<ACS_State*>& all_states = _acs.Get_All_States();
+ const ACS_State* root_state = _acs.Get_Root_State();
+ uint32 root_fanout = root_state->Get_GotoNum();
+
+ // Step 1: Calculate the buffer size
+ AC_Ofst root_goto_ofst, states_ofst_ofst, first_state_ofst;
+
+ // part 1 : buffer header
+ uint32 sz = root_goto_ofst = sizeof(AC_Buffer);
+
+ // part 2: Root-node's goto function
+ if (likely(root_fanout != 255))
+ sz += 256;
+ else
+ root_goto_ofst = 0;
+
+ // part 3: mapping of state's relative position.
+ unsigned align = __alignof__(AC_Ofst);
+ sz = (sz + align - 1) & ~(align - 1);
+ states_ofst_ofst = sz;
+
+ sz += sizeof(AC_Ofst) * all_states.size();
+
+ // part 4: state's contents
+ align = __alignof__(AC_State);
+ sz = (sz + align - 1) & ~(align - 1);
+ first_state_ofst = sz;
+
+ uint32 state_sz = 0;
+ for (vector<ACS_State*>::const_iterator i = all_states.begin(),
+ e = all_states.end(); i != e; i++) {
+ state_sz += Calc_State_Sz(*i);
+ }
+ state_sz -= Calc_State_Sz(root_state);
+
+ sz += state_sz;
+
+ // Step 2: Allocate buffer, and populate header.
+ AC_Buffer* buf = _buf_alloc.alloc(sz);
+
+ buf->hdr.magic_num = AC_MAGIC_NUM;
+ buf->hdr.impl_variant = IMPL_FAST_VARIANT;
+ buf->buf_len = sz;
+ buf->root_goto_ofst = root_goto_ofst;
+ buf->states_ofst_ofst = states_ofst_ofst;
+ buf->first_state_ofst = first_state_ofst;
+ buf->root_goto_num = root_fanout;
+ buf->state_num = _acs.Get_State_Num();
+ return buf;
+}
+
+void
+AC_Converter::Populate_Root_Goto_Func(AC_Buffer* buf,
+ GotoVect& goto_vect) {
+ unsigned char *buf_base = (unsigned char*)(buf);
+ InputTy* root_gotos = (InputTy*)(buf_base + buf->root_goto_ofst);
+ const ACS_State* root_state = _acs.Get_Root_State();
+
+ root_state->Get_Sorted_Gotos(goto_vect);
+
+ // Renumber the ID of root-node's immediate kids.
+ uint32 new_id = 1;
+ bool full_fantout = (goto_vect.size() == 255);
+ if (likely(!full_fantout))
+ bzero(root_gotos, 256*sizeof(InputTy));
+
+ for (GotoVect::iterator i = goto_vect.begin(), e = goto_vect.end();
+ i != e; i++, new_id++) {
+ InputTy c = i->first;
+ ACS_State* s = i->second;
+ _id_map[s->Get_ID()] = new_id;
+
+ if (likely(!full_fantout))
+ root_gotos[c] = new_id;
+ }
+}
+
+AC_Buffer*
+AC_Converter::Convert() {
+ // Step 1: Some preparation stuff.
+ GotoVect gotovect;
+
+ _id_map.clear();
+ _ofst_map.clear();
+ _id_map.resize(_acs.Get_Next_Node_Id());
+ _ofst_map.resize(_acs.Get_Next_Node_Id());
+
+ // Step 2: allocate buffer to accommodate the entire AC graph.
+ AC_Buffer* buf = Alloc_Buffer();
+ unsigned char* buf_base = (unsigned char*)buf;
+
+ // Step 3: Root node need special care.
+ Populate_Root_Goto_Func(buf, gotovect);
+ buf->root_goto_num = gotovect.size();
+ _id_map[_acs.Get_Root_State()->Get_ID()] = 0;
+
+ // Step 4: Converting the remaining states by BFSing the graph.
+ // First of all, enter root's immediate kids to the working list.
+ vector<const ACS_State*> wl;
+ State_ID id = 1;
+ for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end();
+ i != e; i++, id++) {
+ ACS_State* s = i->second;
+ wl.push_back(s);
+ _id_map[s->Get_ID()] = id;
+ }
+
+ AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst);
+ AC_Ofst ofst = buf->first_state_ofst;
+ for (uint32 idx = 0; idx < wl.size(); idx++) {
+ const ACS_State* old_s = wl[idx];
+ AC_State* new_s = (AC_State*)(buf_base + ofst);
+
+ // This property should hold as we:
+ // - States are appended to worklist in the BFS order.
+ // - sibiling states are appended to worklist in the order of their
+ // corresponding input.
+ //
+ State_ID state_id = idx + 1;
+ ASSERT(_id_map[old_s->Get_ID()] == state_id);
+
+ state_ofst_vect[state_id] = ofst;
+
+ new_s->first_kid = wl.size() + 1;
+ new_s->depth = old_s->Get_Depth();
+ new_s->is_term = old_s->is_Terminal() ?
+ old_s->get_Pattern_Idx() + 1 : 0;
+
+ uint32 gotonum = old_s->Get_GotoNum();
+ new_s->goto_num = gotonum;
+
+ // Populate the "input" field
+ old_s->Get_Sorted_Gotos(gotovect);
+ uint32 input_idx = 0;
+ uint32 id = wl.size() + 1;
+ InputTy* input_vect = new_s->input_vect;
+ for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end();
+ i != e; i++, id++, input_idx++) {
+ input_vect[input_idx] = i->first;
+
+ ACS_State* kid = i->second;
+ _id_map[kid->Get_ID()] = id;
+ wl.push_back(kid);
+ }
+
+ _ofst_map[old_s->Get_ID()] = ofst;
+ ofst += Calc_State_Sz(old_s);
+ }
+
+ // This assertion might be useful to catch buffer overflow
+ ASSERT(ofst == buf->buf_len);
+
+ // Populate the fail-link field.
+ for (vector<const ACS_State*>::iterator i = wl.begin(), e = wl.end();
+ i != e; i++) {
+ const ACS_State* slow_s = *i;
+ State_ID fast_s_id = _id_map[slow_s->Get_ID()];
+ AC_State* fast_s = (AC_State*)(buf_base + state_ofst_vect[fast_s_id]);
+ if (const ACS_State* fl = slow_s->Get_FailLink()) {
+ State_ID id = _id_map[fl->Get_ID()];
+ fast_s->fail_link = id;
+ } else
+ fast_s->fail_link = 0;
+ }
+#ifdef DEBUG
+ //dump_buffer(buf, stderr);
+#endif
+ return buf;
+}
+
+static inline AC_State*
+Get_State_Addr(unsigned char* buf_base, AC_Ofst* StateOfstVect, uint32 state_id) {
+ ASSERT(state_id != 0 && "root node is handled in speical way");
+ ASSERT(state_id < ((AC_Buffer*)buf_base)->state_num);
+ return (AC_State*)(buf_base + StateOfstVect[state_id]);
+}
+
+// The performance of the binary search is critical to this work.
+//
+// Here we provide two versions of binary-search functions.
+// The non-pristine version seems to consistently out-perform "pristine" one on
+// bunch of benchmarks we tested. With the benchmark under tests/testinput/
+//
+// The speedup is following on my laptop (core i7, ubuntu):
+//
+// benchmark was is
+// ----------------------------------------
+// image.bin 2.3s 2.0s
+// test.tar 6.7s 5.7s
+//
+// NOTE: As of I write this comment, we only measure the performance on about
+// 10+ benchmarks. It's still too early to say which one works better.
+//
+#if !defined(BS_MULTI_VER)
+static bool __attribute__((always_inline)) inline
+Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) {
+ if (vect_len <= 8) {
+ for (int i = 0; i < vect_len; i++) {
+ if (input_vect[i] == input) {
+ idx = i;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ // The "low" and "high" must be signed integers, as they could become -1.
+ // Also since they are signed integer, "(low + high)/2" is sightly more
+ // expensive than (low+high)>>1 or ((unsigned)(low + high))/2.
+ //
+ int low = 0, high = vect_len - 1;
+ while (low <= high) {
+ int mid = (low + high) >> 1;
+ InputTy mid_c = input_vect[mid];
+
+ if (input < mid_c)
+ high = mid - 1;
+ else if (input > mid_c)
+ low = mid + 1;
+ else {
+ idx = mid;
+ return true;
+ }
+ }
+ return false;
+}
+
+#else
+
+/* Let us call this version "pristine" version. */
+static inline bool
+Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) {
+ int low = 0, high = vect_len - 1;
+ while (low <= high) {
+ int mid = (low + high) >> 1;
+ InputTy mid_c = input_vect[mid];
+
+ if (input < mid_c)
+ high = mid - 1;
+ else if (input > mid_c)
+ low = mid + 1;
+ else {
+ idx = mid;
+ return true;
+ }
+ }
+ return false;
+}
+#endif
+
+typedef enum {
+ // Look for the first match. e.g. pattern set = {"ab", "abc", "def"},
+ // subject string "ababcdef". The first match would be "ab" at the
+ // beginning of the subject string.
+ MV_FIRST_MATCH,
+
+ // Look for the left-most longest match. Follow above example; there are
+ // two longest matches, "abc" and "def", and the left-most longest match
+ // is "abc".
+ MV_LEFT_LONGEST,
+
+ // Similar to the left-most longest match, except that it returns the
+ // *right* most longest match. Follow above example, the match would
+ // be "def". NYI.
+ MV_RIGHT_LONGEST,
+
+ // Return all patterns that match that given subject string. NYI.
+ MV_ALL_MATCHES,
+} MATCH_VARIANT;
+
+/* The Match_Tmpl is the template for vairants MV_FIRST_MATCH, MV_LEFT_LONGEST,
+ * MV_RIGHT_LONGEST (If we really really need MV_RIGHT_LONGEST variant, we are
+ * better off implementing it in a seprate function).
+ *
+ * The Match_Tmpl supports three variants at once "symbolically", once it's
+ * instanced to a particular variants, all the code irrelevant to the variants
+ * will be statically removed. So don't worry about the code like
+ * "if (variant == MV_XXXX)"; they will not incur any penalty.
+ *
+ * The drawback of using template is increased code size. Unfortunately, there
+ * is no silver bullet.
+ */
+template<MATCH_VARIANT variant> static ac_result_t
+Match_Tmpl(AC_Buffer* buf, const char* str, uint32 len) {
+ unsigned char* buf_base = (unsigned char*)(buf);
+ unsigned char* root_goto = buf_base + buf->root_goto_ofst;
+ AC_Ofst* states_ofst_vect = (AC_Ofst* )(buf_base + buf->states_ofst_ofst);
+
+ AC_State* state = 0;
+ uint32 idx = 0;
+
+ // Skip leading chars that are not valid input of root-nodes.
+ if (likely(buf->root_goto_num != 255)) {
+ while(idx < len) {
+ unsigned char c = str[idx++];
+ if (unsigned char kid_id = root_goto[c]) {
+ state = Get_State_Addr(buf_base, states_ofst_vect, kid_id);
+ break;
+ }
+ }
+ } else {
+ idx = 1;
+ state = Get_State_Addr(buf_base, states_ofst_vect, *str);
+ }
+
+ ac_result_t r = {-1, -1};
+ if (likely(state != 0)) {
+ if (unlikely(state->is_term)) {
+ /* Dictionary may have string of length 1 */
+ r.match_begin = idx - state->depth;
+ r.match_end = idx - 1;
+ r.pattern_idx = state->is_term - 1;
+
+ if (variant == MV_FIRST_MATCH) {
+ return r;
+ }
+ }
+ }
+
+ while (idx < len) {
+ unsigned char c = str[idx];
+ int res;
+ bool found;
+ found = Binary_Search_Input(state->input_vect, state->goto_num, c, res);
+ if (found) {
+ // The "t = goto(c, current_state)" is valid, advance to state "t".
+ uint32 kid = state->first_kid + res;
+ state = Get_State_Addr(buf_base, states_ofst_vect, kid);
+ idx++;
+ } else {
+ // Follow the fail-link.
+ State_ID fl = state->fail_link;
+ if (fl == 0) {
+ // fail-link is root-node, which implies the root-node dosen't
+ // have 255 valid transitions (otherwise, the fail-link should
+ // points to "goto(root, c)"), so we don't need speical handling
+ // as we did before this while-loop is entered.
+ //
+ while(idx < len) {
+ InputTy c = str[idx++];
+ if (unsigned char kid_id = root_goto[c]) {
+ state =
+ Get_State_Addr(buf_base, states_ofst_vect, kid_id);
+ break;
+ }
+ }
+ } else {
+ state = Get_State_Addr(buf_base, states_ofst_vect, fl);
+ }
+ }
+
+ // Check to see if the state is terminal state?
+ if (state->is_term) {
+ if (variant == MV_FIRST_MATCH) {
+ ac_result_t r;
+ r.match_begin = idx - state->depth;
+ r.match_end = idx - 1;
+ r.pattern_idx = state->is_term - 1;
+ return r;
+ }
+
+ if (variant == MV_LEFT_LONGEST) {
+ int match_begin = idx - state->depth;
+ int match_end = idx - 1;
+
+ if (r.match_begin == -1 ||
+ match_end - match_begin > r.match_end - r.match_begin) {
+ r.match_begin = match_begin;
+ r.match_end = match_end;
+ r.pattern_idx = state->is_term - 1;
+ }
+ continue;
+ }
+
+ ASSERT(false && "NYI");
+ }
+ }
+
+ return r;
+}
+
+ac_result_t
+Match(AC_Buffer* buf, const char* str, uint32 len) {
+ return Match_Tmpl<MV_FIRST_MATCH>(buf, str, len);
+}
+
+ac_result_t
+Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len) {
+ return Match_Tmpl<MV_LEFT_LONGEST>(buf, str, len);
+}
+
+#ifdef DEBUG
+void
+AC_Converter::dump_buffer(AC_Buffer* buf, FILE* f) {
+ vector<AC_Ofst> state_ofst;
+ state_ofst.resize(_id_map.size());
+
+ fprintf(f, "Id maps between old/slow and new/fast graphs\n");
+ int old_id = 0;
+ for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end();
+ i != e; i++, old_id++) {
+ State_ID new_id = *i;
+ if (new_id != 0) {
+ fprintf(f, "%d -> %d, ", old_id, new_id);
+ }
+ }
+ fprintf(f, "\n");
+
+ int idx = 0;
+ for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end();
+ i != e; i++, idx++) {
+ uint32 id = *i;
+ if (id == 0) continue;
+ state_ofst[id] = _ofst_map[idx];
+ }
+
+ unsigned char* buf_base = (unsigned char*)buf;
+
+ // dump root goto-function.
+ fprintf(f, "root, fanout:%d goto {", buf->root_goto_num);
+ if (buf->root_goto_num != 255) {
+ unsigned char* root_goto = buf_base + buf->root_goto_ofst;
+ for (uint32 i = 0; i < 255; i++) {
+ if (root_goto[i] != 0)
+ fprintf(f, "%c->S:%d, ", (unsigned char)i, root_goto[i]);
+ }
+ } else {
+ fprintf(f, "full fanout\n");
+ }
+ fprintf(f, "}\n");
+
+ // dump remaining states.
+ AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst);
+ for (uint32 i = 1, e = buf->state_num; i < e; i++) {
+ AC_Ofst ofst = state_ofst_vect[i];
+ ASSERT(ofst == state_ofst[i]);
+ fprintf(f, "S:%d, ofst:%d, goto={", i, ofst);
+
+ AC_State* s = (AC_State*)(buf_base + ofst);
+ State_ID kid = s->first_kid;
+ for (uint32 k = 0, ke = s->goto_num; k < ke; k++, kid++)
+ fprintf(f, "%c->S:%d, ", s->input_vect[k], kid);
+
+ fprintf(f, "}, fail-link = S:%d, %s\n", s->fail_link,
+ s->is_term ? "terminal" : "");
+ }
+}
+#endif