diff options
Diffstat (limited to 'modules/policy/lua-aho-corasick')
24 files changed, 3367 insertions, 0 deletions
diff --git a/modules/policy/lua-aho-corasick/LICENSE b/modules/policy/lua-aho-corasick/LICENSE new file mode 100644 index 0000000..dd65f72 --- /dev/null +++ b/modules/policy/lua-aho-corasick/LICENSE @@ -0,0 +1,28 @@ + Copyright (c) 2014 CloudFlare, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of CloudFlare, Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/modules/policy/lua-aho-corasick/Makefile b/modules/policy/lua-aho-corasick/Makefile new file mode 100644 index 0000000..6471664 --- /dev/null +++ b/modules/policy/lua-aho-corasick/Makefile @@ -0,0 +1,134 @@ +OS := $(shell uname) + +ifeq ($(OS), Darwin) + SO_EXT := dylib +else + SO_EXT := so +endif + +############################################################################# +# +# Binaries we are going to build +# +############################################################################# +# +C_SO_NAME = libac.$(SO_EXT) +LUA_SO_NAME = ahocorasick.$(SO_EXT) +AR_NAME = libac.a + +############################################################################# +# +# Compile and link flags +# +############################################################################# +PREFIX ?= /usr/local +LUA_VERSION := 5.1 +LUA_INCLUDE_DIR := $(PREFIX)/include/lua$(LUA_VERSION) +SO_TARGET_DIR := $(PREFIX)/lib/lua/$(LUA_VERSION) +LUA_TARGET_DIR := $(PREFIX)/share/lua/$(LUA_VERSION) + +# Available directives: +# -DDEBUG : Turn on debugging support +# -DVERIFY : To verify if the slow-version and fast-version implementations +# get exactly the same result. Note -DVERIFY implies -DDEBUG. +# +COMMON_FLAGS = -O3 #-g -DVERIFY -msse2 -msse3 -msse4.1 +COMMON_FLAGS += -fvisibility=hidden -Wall $(CXXFLAGS) $(MY_CXXFLAGS) $(CPPFLAGS) + +SO_CXXFLAGS = $(COMMON_FLAGS) -fPIC +SO_LFLAGS = $(COMMON_FLAGS) $(LDFLAGS) +AR_CXXFLAGS = $(COMMON_FLAGS) + +# -DVERIFY implies -DDEBUG +ifneq ($(findstring -DVERIFY, $(COMMON_FLAGS)), ) +ifeq ($(findstring -DDEBUG, $(COMMON_FLAGS)), ) + COMMON_FLAGS += -DDEBUG +endif +endif + +AR = ar +AR_FLAGS = cru + +############################################################################# +# +# Divide source codes and objects into several categories +# +############################################################################# +# +SRC_COMMON := ac_fast.cxx ac_slow.cxx +LIBAC_SO_SRC := $(SRC_COMMON) ac.cxx # source for libac.so +LUA_SO_SRC := $(SRC_COMMON) ac_lua.cxx # source for ahocorasick.so +LIBAC_A_SRC := $(LIBAC_SO_SRC) # source for libac.a + +############################################################################# +# +# Make rules +# +############################################################################# +# +.PHONY = all clean test benchmark prepare +all : $(C_SO_NAME) $(LUA_SO_NAME) $(AR_NAME) + +-include c_so_dep.txt +-include lua_so_dep.txt +-include ar_dep.txt + +BUILD_SO_DIR := build_so +BUILD_AR_DIR := build_ar + +$(BUILD_SO_DIR) :; mkdir $@ +$(BUILD_AR_DIR) :; mkdir $@ + +$(BUILD_SO_DIR)/%.o : %.cxx | $(BUILD_SO_DIR) + $(CXX) $< -c $(SO_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ + +$(BUILD_AR_DIR)/%.o : %.cxx | $(BUILD_AR_DIR) + $(CXX) $< -c $(AR_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ + +ifneq ($(OS), Darwin) +$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared -Wl,-soname=$(C_SO_NAME) $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt + +$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared -Wl,-soname=$(LUA_SO_NAME) $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt + +else +$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt + +$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared $(SO_LFLAGS) -o $@ -Wl,-undefined,dynamic_lookup + cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt +endif + +$(AR_NAME) : $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.o}) + $(AR) $(AR_FLAGS) $@ $+ + cat $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.d}) > lua_so_dep.txt + +############################################################################# +# +# Misc +# +############################################################################# +# +test : $(C_SO_NAME) + $(MAKE) -C tests && \ + luajit tests/lua_test.lua && \ + luajit tests/load_ac_test.lua + +benchmark: $(C_SO_NAME) + $(MAKE) benchmark -C tests + +clean : + -rm -rf *.o *.d c_so_dep.txt lua_so_dep.txt ar_dep.txt $(TEST) \ + $(C_SO_NAME) $(LUA_SO_NAME) $(TEST) $(BUILD_SO_DIR) $(BUILD_AR_DIR) \ + $(AR_NAME) + make clean -C tests + +install: + install -D -m 755 $(C_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(C_SO_NAME) + install -D -m 755 $(LUA_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(LUA_SO_NAME) + install -D -m 664 load_ac.lua $(DESTDIR)/$(LUA_TARGET_DIR)/load_ac.lua diff --git a/modules/policy/lua-aho-corasick/README.md b/modules/policy/lua-aho-corasick/README.md new file mode 100644 index 0000000..b5cc406 --- /dev/null +++ b/modules/policy/lua-aho-corasick/README.md @@ -0,0 +1,40 @@ +aho-corasick-lua +================ + +C++ and Lua Implementation of the Aho-Corasick (AC) string matching algorithm +(http://dl.acm.org/citation.cfm?id=360855). + +We began with pure Lua implementation and realize the performance is not +satisfactory. So we switch to C/C++ implementation. + +There are two shared objects provied by this package: libac.so and ahocorasick.so +The former is a regular shared object which can be directly used by C/C++ +application, or by Lua via FFI; and the later is a Lua module. An example usage +is shown bellow: + +```lua +local ac = require "ahocorasick" +local dict = {"string1", "string", "etc"} +local acinst = ac.create(dict) +local r = ac.match(acinst, "mystring") +``` + +For efficiency reasons, the implementation is slightly different from the +standard AC algorithm in that it doesn't return a set of strings in the dictionary +that match the given string, instead it only returns one of them in case the string +matches. The functionality of our implementation can be (precisely) described by +following pseudo-c snippet. + +```C +string foo(input-string, dictionary) { + string ret = the-end-of-input-string; + for each string s in dictionary { + // find the first occurrence match sub-string. + ret = min(ret, strstr(input-string, s); + } + return ret; +} +``` + +It's pretty easy to get rid of this limitation, just to associate each state with +a spare bit-vector dipicting the set of strings recognized by that state. diff --git a/modules/policy/lua-aho-corasick/ac.cxx b/modules/policy/lua-aho-corasick/ac.cxx new file mode 100644 index 0000000..23fb3b5 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac.cxx @@ -0,0 +1,101 @@ +// Interface functions for libac.so +// +#include "ac_slow.hpp" +#include "ac_fast.hpp" +#include "ac.h" + +static inline ac_result_t +_match(buf_header_t* ac, const char* str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(ac->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match(buf, str, len); + + #ifdef VERIFY + { + Match_Result r2 = buf->slow_impl->Match(str, len); + if (r.match_begin != r2.begin) { + ASSERT(0); + } else { + ASSERT((r.match_begin < 0) || + (r.match_end == r2.end && + r.pattern_idx == r2.pattern_idx)); + } + } + #endif + return r; +} + +extern "C" int +ac_match2(ac_t* ac, const char* str, unsigned int len) { + ac_result_t r = _match((buf_header_t*)(void*)ac, str, len); + return r.match_begin; +} + +extern "C" ac_result_t +ac_match(ac_t* ac, const char* str, unsigned int len) { + return _match((buf_header_t*)(void*)ac, str, len); +} + +extern "C" ac_result_t +ac_match_longest_l(ac_t* ac, const char* str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(((buf_header_t*)ac)->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match_Longest_L(buf, str, len); + return r; +} + +class BufAlloc : public Buf_Allocator { +public: + virtual AC_Buffer* alloc(int sz) { + return (AC_Buffer*)(new unsigned char[sz]); + } + + // Do not de-allocate the buffer when the BufAlloc die. + virtual void free() {} + + static void myfree(AC_Buffer* buf) { + ASSERT(buf->hdr.magic_num == AC_MAGIC_NUM); + const char* b = (const char*)buf; + delete[] b; + } +}; + +extern "C" ac_t* +ac_create(const char** strv, unsigned int* strlenv, unsigned int v_len) { + if (v_len >= 65535) { + // TODO: Currently we use 16-bit to encode pattern-index (see the + // comment to AC_State::is_term), therefore we are not able to + // handle pattern set with more than 65535 entries. + return 0; + } + + ACS_Constructor *acc; +#ifdef VERIFY + acc = new ACS_Constructor; +#else + ACS_Constructor tmp; + acc = &tmp; +#endif + acc->Construct(strv, strlenv, v_len); + + BufAlloc ba; + AC_Converter cvt(*acc, ba); + AC_Buffer* buf = cvt.Convert(); + +#ifdef VERIFY + buf->slow_impl = acc; +#endif + return (ac_t*)(void*)buf; +} + +extern "C" void +ac_free(void* ac) { + AC_Buffer* buf = (AC_Buffer*)ac; +#ifdef VERIFY + delete buf->slow_impl; +#endif + + BufAlloc::myfree(buf); +} diff --git a/modules/policy/lua-aho-corasick/ac.h b/modules/policy/lua-aho-corasick/ac.h new file mode 100644 index 0000000..30bf447 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac.h @@ -0,0 +1,49 @@ +#ifndef AC_H +#define AC_H +#ifdef __cplusplus +extern "C" { +#endif + +#define AC_EXPORT __attribute__ ((visibility ("default"))) + +/* If the subject-string dosen't match any of the given patterns, "match_begin" + * should be a negative; otherwise the substring of the subject-string, + * starting from offset "match_begin" to "match_end" incusively, + * should exactly match the pattern specified by the 'pattern_idx' (i.e. + * the pattern is "pattern_v[pattern_idx]" where the "pattern_v" is the + * first acutal argument passing to ac_create()) + */ +typedef struct { + int match_begin; + int match_end; + int pattern_idx; +} ac_result_t; + +struct ac_t; + +/* Create an AC instance. "pattern_v" is a vector of patterns, the length of + * i-th pattern is specified by "pattern_len_v[i]"; the number of patterns + * is specified by "vect_len". + * + * Return the instance on success, or NUL otherwise. + */ +ac_t* ac_create(const char** pattern_v, unsigned int* pattern_len_v, + unsigned int vect_len) AC_EXPORT; + +ac_result_t ac_match(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +ac_result_t ac_match_longest_l(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +/* Similar to ac_match() except that it only returns match-begin. The rationale + * for this interface is that luajit has hard time in dealing with strcture- + * return-value. + */ +int ac_match2(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +void ac_free(void*) AC_EXPORT; + +#ifdef __cplusplus +} +#endif + +#endif /* AC_H */ diff --git a/modules/policy/lua-aho-corasick/ac_fast.cxx b/modules/policy/lua-aho-corasick/ac_fast.cxx new file mode 100644 index 0000000..9dbc2e6 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_fast.cxx @@ -0,0 +1,468 @@ +#include <algorithm> // for std::sort +#include "ac_slow.hpp" +#include "ac_fast.hpp" + +uint32 +AC_Converter::Calc_State_Sz(const ACS_State* s) const { + AC_State dummy; + uint32 sz = offsetof(AC_State, input_vect); + sz += s->Get_GotoNum() * sizeof(dummy.input_vect[0]); + + if (sz < sizeof(AC_State)) + sz = sizeof(AC_State); + + uint32 align = __alignof__(dummy); + sz = (sz + align - 1) & ~(align - 1); + return sz; +} + +AC_Buffer* +AC_Converter::Alloc_Buffer() { + const vector<ACS_State*>& all_states = _acs.Get_All_States(); + const ACS_State* root_state = _acs.Get_Root_State(); + uint32 root_fanout = root_state->Get_GotoNum(); + + // Step 1: Calculate the buffer size + AC_Ofst root_goto_ofst, states_ofst_ofst, first_state_ofst; + + // part 1 : buffer header + uint32 sz = root_goto_ofst = sizeof(AC_Buffer); + + // part 2: Root-node's goto function + if (likely(root_fanout != 255)) + sz += 256; + else + root_goto_ofst = 0; + + // part 3: mapping of state's relative position. + unsigned align = __alignof__(AC_Ofst); + sz = (sz + align - 1) & ~(align - 1); + states_ofst_ofst = sz; + + sz += sizeof(AC_Ofst) * all_states.size(); + + // part 4: state's contents + align = __alignof__(AC_State); + sz = (sz + align - 1) & ~(align - 1); + first_state_ofst = sz; + + uint32 state_sz = 0; + for (vector<ACS_State*>::const_iterator i = all_states.begin(), + e = all_states.end(); i != e; i++) { + state_sz += Calc_State_Sz(*i); + } + state_sz -= Calc_State_Sz(root_state); + + sz += state_sz; + + // Step 2: Allocate buffer, and populate header. + AC_Buffer* buf = _buf_alloc.alloc(sz); + + buf->hdr.magic_num = AC_MAGIC_NUM; + buf->hdr.impl_variant = IMPL_FAST_VARIANT; + buf->buf_len = sz; + buf->root_goto_ofst = root_goto_ofst; + buf->states_ofst_ofst = states_ofst_ofst; + buf->first_state_ofst = first_state_ofst; + buf->root_goto_num = root_fanout; + buf->state_num = _acs.Get_State_Num(); + return buf; +} + +void +AC_Converter::Populate_Root_Goto_Func(AC_Buffer* buf, + GotoVect& goto_vect) { + unsigned char *buf_base = (unsigned char*)(buf); + InputTy* root_gotos = (InputTy*)(buf_base + buf->root_goto_ofst); + const ACS_State* root_state = _acs.Get_Root_State(); + + root_state->Get_Sorted_Gotos(goto_vect); + + // Renumber the ID of root-node's immediate kids. + uint32 new_id = 1; + bool full_fantout = (goto_vect.size() == 255); + if (likely(!full_fantout)) + bzero(root_gotos, 256*sizeof(InputTy)); + + for (GotoVect::iterator i = goto_vect.begin(), e = goto_vect.end(); + i != e; i++, new_id++) { + InputTy c = i->first; + ACS_State* s = i->second; + _id_map[s->Get_ID()] = new_id; + + if (likely(!full_fantout)) + root_gotos[c] = new_id; + } +} + +AC_Buffer* +AC_Converter::Convert() { + // Step 1: Some preparation stuff. + GotoVect gotovect; + + _id_map.clear(); + _ofst_map.clear(); + _id_map.resize(_acs.Get_Next_Node_Id()); + _ofst_map.resize(_acs.Get_Next_Node_Id()); + + // Step 2: allocate buffer to accommodate the entire AC graph. + AC_Buffer* buf = Alloc_Buffer(); + unsigned char* buf_base = (unsigned char*)buf; + + // Step 3: Root node need special care. + Populate_Root_Goto_Func(buf, gotovect); + buf->root_goto_num = gotovect.size(); + _id_map[_acs.Get_Root_State()->Get_ID()] = 0; + + // Step 4: Converting the remaining states by BFSing the graph. + // First of all, enter root's immediate kids to the working list. + vector<const ACS_State*> wl; + State_ID id = 1; + for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); + i != e; i++, id++) { + ACS_State* s = i->second; + wl.push_back(s); + _id_map[s->Get_ID()] = id; + } + + AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); + AC_Ofst ofst = buf->first_state_ofst; + for (uint32 idx = 0; idx < wl.size(); idx++) { + const ACS_State* old_s = wl[idx]; + AC_State* new_s = (AC_State*)(buf_base + ofst); + + // This property should hold as we: + // - States are appended to worklist in the BFS order. + // - sibiling states are appended to worklist in the order of their + // corresponding input. + // + State_ID state_id = idx + 1; + ASSERT(_id_map[old_s->Get_ID()] == state_id); + + state_ofst_vect[state_id] = ofst; + + new_s->first_kid = wl.size() + 1; + new_s->depth = old_s->Get_Depth(); + new_s->is_term = old_s->is_Terminal() ? + old_s->get_Pattern_Idx() + 1 : 0; + + uint32 gotonum = old_s->Get_GotoNum(); + new_s->goto_num = gotonum; + + // Populate the "input" field + old_s->Get_Sorted_Gotos(gotovect); + uint32 input_idx = 0; + uint32 id = wl.size() + 1; + InputTy* input_vect = new_s->input_vect; + for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); + i != e; i++, id++, input_idx++) { + input_vect[input_idx] = i->first; + + ACS_State* kid = i->second; + _id_map[kid->Get_ID()] = id; + wl.push_back(kid); + } + + _ofst_map[old_s->Get_ID()] = ofst; + ofst += Calc_State_Sz(old_s); + } + + // This assertion might be useful to catch buffer overflow + ASSERT(ofst == buf->buf_len); + + // Populate the fail-link field. + for (vector<const ACS_State*>::iterator i = wl.begin(), e = wl.end(); + i != e; i++) { + const ACS_State* slow_s = *i; + State_ID fast_s_id = _id_map[slow_s->Get_ID()]; + AC_State* fast_s = (AC_State*)(buf_base + state_ofst_vect[fast_s_id]); + if (const ACS_State* fl = slow_s->Get_FailLink()) { + State_ID id = _id_map[fl->Get_ID()]; + fast_s->fail_link = id; + } else + fast_s->fail_link = 0; + } +#ifdef DEBUG + //dump_buffer(buf, stderr); +#endif + return buf; +} + +static inline AC_State* +Get_State_Addr(unsigned char* buf_base, AC_Ofst* StateOfstVect, uint32 state_id) { + ASSERT(state_id != 0 && "root node is handled in speical way"); + ASSERT(state_id < ((AC_Buffer*)buf_base)->state_num); + return (AC_State*)(buf_base + StateOfstVect[state_id]); +} + +// The performance of the binary search is critical to this work. +// +// Here we provide two versions of binary-search functions. +// The non-pristine version seems to consistently out-perform "pristine" one on +// bunch of benchmarks we tested. With the benchmark under tests/testinput/ +// +// The speedup is following on my laptop (core i7, ubuntu): +// +// benchmark was is +// ---------------------------------------- +// image.bin 2.3s 2.0s +// test.tar 6.7s 5.7s +// +// NOTE: As of I write this comment, we only measure the performance on about +// 10+ benchmarks. It's still too early to say which one works better. +// +#if !defined(BS_MULTI_VER) +static bool __attribute__((always_inline)) inline +Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { + if (vect_len <= 8) { + for (int i = 0; i < vect_len; i++) { + if (input_vect[i] == input) { + idx = i; + return true; + } + } + return false; + } + + // The "low" and "high" must be signed integers, as they could become -1. + // Also since they are signed integer, "(low + high)/2" is sightly more + // expensive than (low+high)>>1 or ((unsigned)(low + high))/2. + // + int low = 0, high = vect_len - 1; + while (low <= high) { + int mid = (low + high) >> 1; + InputTy mid_c = input_vect[mid]; + + if (input < mid_c) + high = mid - 1; + else if (input > mid_c) + low = mid + 1; + else { + idx = mid; + return true; + } + } + return false; +} + +#else + +/* Let us call this version "pristine" version. */ +static inline bool +Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { + int low = 0, high = vect_len - 1; + while (low <= high) { + int mid = (low + high) >> 1; + InputTy mid_c = input_vect[mid]; + + if (input < mid_c) + high = mid - 1; + else if (input > mid_c) + low = mid + 1; + else { + idx = mid; + return true; + } + } + return false; +} +#endif + +typedef enum { + // Look for the first match. e.g. pattern set = {"ab", "abc", "def"}, + // subject string "ababcdef". The first match would be "ab" at the + // beginning of the subject string. + MV_FIRST_MATCH, + + // Look for the left-most longest match. Follow above example; there are + // two longest matches, "abc" and "def", and the left-most longest match + // is "abc". + MV_LEFT_LONGEST, + + // Similar to the left-most longest match, except that it returns the + // *right* most longest match. Follow above example, the match would + // be "def". NYI. + MV_RIGHT_LONGEST, + + // Return all patterns that match that given subject string. NYI. + MV_ALL_MATCHES, +} MATCH_VARIANT; + +/* The Match_Tmpl is the template for vairants MV_FIRST_MATCH, MV_LEFT_LONGEST, + * MV_RIGHT_LONGEST (If we really really need MV_RIGHT_LONGEST variant, we are + * better off implementing it in a seprate function). + * + * The Match_Tmpl supports three variants at once "symbolically", once it's + * instanced to a particular variants, all the code irrelevant to the variants + * will be statically removed. So don't worry about the code like + * "if (variant == MV_XXXX)"; they will not incur any penalty. + * + * The drawback of using template is increased code size. Unfortunately, there + * is no silver bullet. + */ +template<MATCH_VARIANT variant> static ac_result_t +Match_Tmpl(AC_Buffer* buf, const char* str, uint32 len) { + unsigned char* buf_base = (unsigned char*)(buf); + unsigned char* root_goto = buf_base + buf->root_goto_ofst; + AC_Ofst* states_ofst_vect = (AC_Ofst* )(buf_base + buf->states_ofst_ofst); + + AC_State* state = 0; + uint32 idx = 0; + + // Skip leading chars that are not valid input of root-nodes. + if (likely(buf->root_goto_num != 255)) { + while(idx < len) { + unsigned char c = str[idx++]; + if (unsigned char kid_id = root_goto[c]) { + state = Get_State_Addr(buf_base, states_ofst_vect, kid_id); + break; + } + } + } else { + idx = 1; + state = Get_State_Addr(buf_base, states_ofst_vect, *str); + } + + ac_result_t r = {-1, -1}; + if (likely(state != 0)) { + if (unlikely(state->is_term)) { + /* Dictionary may have string of length 1 */ + r.match_begin = idx - state->depth; + r.match_end = idx - 1; + r.pattern_idx = state->is_term - 1; + + if (variant == MV_FIRST_MATCH) { + return r; + } + } + } + + while (idx < len) { + unsigned char c = str[idx]; + int res; + bool found; + found = Binary_Search_Input(state->input_vect, state->goto_num, c, res); + if (found) { + // The "t = goto(c, current_state)" is valid, advance to state "t". + uint32 kid = state->first_kid + res; + state = Get_State_Addr(buf_base, states_ofst_vect, kid); + idx++; + } else { + // Follow the fail-link. + State_ID fl = state->fail_link; + if (fl == 0) { + // fail-link is root-node, which implies the root-node dosen't + // have 255 valid transitions (otherwise, the fail-link should + // points to "goto(root, c)"), so we don't need speical handling + // as we did before this while-loop is entered. + // + while(idx < len) { + InputTy c = str[idx++]; + if (unsigned char kid_id = root_goto[c]) { + state = + Get_State_Addr(buf_base, states_ofst_vect, kid_id); + break; + } + } + } else { + state = Get_State_Addr(buf_base, states_ofst_vect, fl); + } + } + + // Check to see if the state is terminal state? + if (state->is_term) { + if (variant == MV_FIRST_MATCH) { + ac_result_t r; + r.match_begin = idx - state->depth; + r.match_end = idx - 1; + r.pattern_idx = state->is_term - 1; + return r; + } + + if (variant == MV_LEFT_LONGEST) { + int match_begin = idx - state->depth; + int match_end = idx - 1; + + if (r.match_begin == -1 || + match_end - match_begin > r.match_end - r.match_begin) { + r.match_begin = match_begin; + r.match_end = match_end; + r.pattern_idx = state->is_term - 1; + } + continue; + } + + ASSERT(false && "NYI"); + } + } + + return r; +} + +ac_result_t +Match(AC_Buffer* buf, const char* str, uint32 len) { + return Match_Tmpl<MV_FIRST_MATCH>(buf, str, len); +} + +ac_result_t +Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len) { + return Match_Tmpl<MV_LEFT_LONGEST>(buf, str, len); +} + +#ifdef DEBUG +void +AC_Converter::dump_buffer(AC_Buffer* buf, FILE* f) { + vector<AC_Ofst> state_ofst; + state_ofst.resize(_id_map.size()); + + fprintf(f, "Id maps between old/slow and new/fast graphs\n"); + int old_id = 0; + for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end(); + i != e; i++, old_id++) { + State_ID new_id = *i; + if (new_id != 0) { + fprintf(f, "%d -> %d, ", old_id, new_id); + } + } + fprintf(f, "\n"); + + int idx = 0; + for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end(); + i != e; i++, idx++) { + uint32 id = *i; + if (id == 0) continue; + state_ofst[id] = _ofst_map[idx]; + } + + unsigned char* buf_base = (unsigned char*)buf; + + // dump root goto-function. + fprintf(f, "root, fanout:%d goto {", buf->root_goto_num); + if (buf->root_goto_num != 255) { + unsigned char* root_goto = buf_base + buf->root_goto_ofst; + for (uint32 i = 0; i < 255; i++) { + if (root_goto[i] != 0) + fprintf(f, "%c->S:%d, ", (unsigned char)i, root_goto[i]); + } + } else { + fprintf(f, "full fanout\n"); + } + fprintf(f, "}\n"); + + // dump remaining states. + AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); + for (uint32 i = 1, e = buf->state_num; i < e; i++) { + AC_Ofst ofst = state_ofst_vect[i]; + ASSERT(ofst == state_ofst[i]); + fprintf(f, "S:%d, ofst:%d, goto={", i, ofst); + + AC_State* s = (AC_State*)(buf_base + ofst); + State_ID kid = s->first_kid; + for (uint32 k = 0, ke = s->goto_num; k < ke; k++, kid++) + fprintf(f, "%c->S:%d, ", s->input_vect[k], kid); + + fprintf(f, "}, fail-link = S:%d, %s\n", s->fail_link, + s->is_term ? "terminal" : ""); + } +} +#endif diff --git a/modules/policy/lua-aho-corasick/ac_fast.hpp b/modules/policy/lua-aho-corasick/ac_fast.hpp new file mode 100644 index 0000000..9ac557c --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_fast.hpp @@ -0,0 +1,124 @@ +#ifndef AC_FAST_H +#define AC_FAST_H + +#include <vector> +#include "ac.h" +#include "ac_slow.hpp" + +using namespace std; + +class ACS_Constructor; + +typedef uint32 AC_Ofst; +typedef uint32 State_ID; + +// The entire "fast" AC graph is converted from its "slow" version, and store +// in an consecutive trunk of memory or "buffer". Since the pointers in the +// fast AC graph are represented as offset relative to the base address of +// the buffer, this fast AC graph is position-independent, meaning cloning +// the fast graph is just to memcpy the entire buffer. +// +// The buffer is laid-out as following: +// +// 1. The buffer header. (i.e. the AC_Buffer content) +// 2. root-node's goto functions. It is represented as an array indiced by +// root-node's valid inputs, and the element is the ID of the corresponding +// transition state (aka kid). To save space, we used 8-bit to represent +// the IDs. ID of root's kids starts with 1. +// +// Root may have 255 valid inputs. In this speical case, i-th element +// stores value i -- i.e the i-th state. So, we don't need such array +// at all. On the other hand, 8-bit is insufficient to encode kids' ID. +// +// 3. An array indiced by state's id, and the element is the offset +// of correspoding state wrt the base address of the buffer. +// +// 4. the contents of states. +// +typedef struct { + buf_header_t hdr; // The header exposed to the user using this lib. +#ifdef VERIFY + ACS_Constructor* slow_impl; +#endif + uint32 buf_len; + AC_Ofst root_goto_ofst; // addr of root node's goto() function. + AC_Ofst states_ofst_ofst; // addr of state pointer vector (indiced by id) + AC_Ofst first_state_ofst; // addr of the first state in the buffer. + uint16 root_goto_num; // fan-out of root-node. + uint16 state_num; // number of states + + // Followed by the gut of the buffer: + // 1. map: root's-valid-input -> kid's id + // 2. map: state's ID -> offset of the state + // 3. states' content. +} AC_Buffer; + +// Depict the state of "fast" AC graph. +typedef struct { + // transition are sorted. For instance, state s1, has two transitions : + // goto(b) -> S_b, goto(a)->S_a. The inputs are sorted in the ascending + // order, and the target states are permuted accordingly. In this case, + // the inputs are sorted as : a, b, and the target states are permuted + // into S_a, S_b. So, S_a is the 1st kid, the ID of kids are consecutive, + // so we don't need to save all the target kids. + // + State_ID first_kid; + AC_Ofst fail_link; + short depth; // How far away from root. + unsigned short is_term; // Is terminal node. if is_term != 0, it encodes + // the value of "1 + pattern-index". + unsigned char goto_num; // The number of valid transition. + InputTy input_vect[1]; // Vector of valid input. Must be last field! +} AC_State; + +class Buf_Allocator { +public: + Buf_Allocator() : _buf(0) {} + virtual ~Buf_Allocator() { free(); } + + virtual AC_Buffer* alloc(int sz) = 0; + virtual void free() {}; +protected: + AC_Buffer* _buf; +}; + +// Convert slow-AC-graph into fast one. +class AC_Converter { +public: + AC_Converter(ACS_Constructor& acs, Buf_Allocator& ba) : + _acs(acs), _buf_alloc(ba) {} + AC_Buffer* Convert(); + +private: + // Return the size in byte needed to to save the specified state. + uint32 Calc_State_Sz(const ACS_State *) const; + + // In fast-AC-graph, the ID is bit trikcy. Given a state of slow-graph, + // this function is to return the ID of its counterpart in the fast-graph. + State_ID Get_Renumbered_Id(const ACS_State *s) const { + const vector<uint32> &m = _id_map; + return m[s->Get_ID()]; + } + + AC_Buffer* Alloc_Buffer(); + void Populate_Root_Goto_Func(AC_Buffer *, GotoVect&); + +#ifdef DEBUG + void dump_buffer(AC_Buffer*, FILE*); +#endif + +private: + ACS_Constructor& _acs; + Buf_Allocator& _buf_alloc; + + // map: ID of state in slow-graph -> ID of counterpart in fast-graph. + vector<uint32> _id_map; + + // map: ID of state in slow-graph -> offset of counterpart in fast-graph. + vector<AC_Ofst> _ofst_map; +}; + +ac_result_t Match(AC_Buffer* buf, const char* str, uint32 len); +ac_result_t Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len); + +#endif // AC_FAST_H diff --git a/modules/policy/lua-aho-corasick/ac_lua.cxx b/modules/policy/lua-aho-corasick/ac_lua.cxx new file mode 100644 index 0000000..ad7307e --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_lua.cxx @@ -0,0 +1,173 @@ +// Interface functions for libac.so +// +#include <vector> +#include <string> +#include "ac_slow.hpp" +#include "ac_fast.hpp" +#include "ac.h" // for the definition of ac_result_t +#include "ac_util.hpp" + +extern "C" { + #include <lua.h> + #include <lauxlib.h> +} + +#if defined(USE_SLOW_VER) +#error "Not going to implement it" +#endif + +using namespace std; +static const char* tname = "aho-corasick"; + +class BufAlloc : public Buf_Allocator { +public: + BufAlloc(lua_State* L) : _L(L) {} + virtual AC_Buffer* alloc(int sz) { + return (AC_Buffer*)lua_newuserdata (_L, sz); + } + + // Let GC to take care. + virtual void free() {} + +private: + lua_State* _L; +}; + +static bool +_create_helper(lua_State* L, const vector<const char*>& str_v, + const vector<unsigned int>& strlen_v) { + ASSERT(str_v.size() == strlen_v.size()); + + ACS_Constructor acc; + BufAlloc ba(L); + + // Step 1: construt the slow version. + unsigned int strnum = str_v.size(); + const char** str_vect = new const char*[strnum]; + unsigned int* strlen_vect = new unsigned int[strnum]; + + int idx = 0; + for (vector<const char*>::const_iterator i = str_v.begin(), e = str_v.end(); + i != e; i++) { + str_vect[idx++] = *i; + } + + idx = 0; + for (vector<unsigned int>::const_iterator i = strlen_v.begin(), + e = strlen_v.end(); i != e; i++) { + strlen_vect[idx++] = *i; + } + + acc.Construct(str_vect, strlen_vect, idx); + delete[] str_vect; + delete[] strlen_vect; + + // Step 2: convert to fast version + AC_Converter cvt(acc, ba); + return cvt.Convert() != 0; +} + +static ac_result_t +_match_helper(buf_header_t* ac, const char *str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(ac->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match(buf, str, len); + return r; +} + +// LUA sematic: +// input: array of strings +// output: userdata containing the AC-graph (i.e. the AC_Buffer). +// +static int +lac_create(lua_State* L) { + // The table of the array must be the 1st argument. + int input_tab = 1; + + luaL_checktype(L, input_tab, LUA_TTABLE); + + // Init the "iteartor". + lua_pushnil(L); + + vector<const char*> str_v; + vector<unsigned int> strlen_v; + + // Loop over the elements + while (lua_next(L, input_tab)) { + size_t str_len; + const char* s = luaL_checklstring(L, -1, &str_len); + str_v.push_back(s); + strlen_v.push_back(str_len); + + // remove the value, but keep the key as the iterator. + lua_pop(L, 1); + } + + // pop the nil value + lua_pop(L, 1); + + if (_create_helper(L, str_v, strlen_v)) { + // The AC graph, as a userdata is already pushed to the stack, hence 1. + return 1; + } + + return 0; +} + +// LUA input: +// arg1: the userdata, representing the AC graph, returned from l_create(). +// arg2: the string to be matched. +// +// LUA return: +// if match, return index range of the match; otherwise nil is returned. +// +static int +lac_match(lua_State* L) { + buf_header_t* ac = (buf_header_t*)lua_touserdata(L, 1); + if (!ac) { + luaL_checkudata(L, 1, tname); + return 0; + } + + size_t len; + const char* str; + #if LUA_VERSION_NUM >= 502 + str = luaL_tolstring(L, 2, &len); + #else + str = lua_tolstring(L, 2, &len); + #endif + if (!str) { + luaL_checkstring(L, 2); + return 0; + } + + ac_result_t r = _match_helper(ac, str, len); + if (r.match_begin != -1) { + lua_pushinteger(L, r.match_begin); + lua_pushinteger(L, r.match_end); + return 2; + } + + return 0; +} + +static const struct luaL_Reg lib_funcs[] = { + { "create", lac_create }, + { "match", lac_match }, + {0, 0} +}; + +extern "C" int AC_EXPORT +luaopen_ahocorasick(lua_State* L) { + luaL_newmetatable(L, tname); + +#if LUA_VERSION_NUM == 501 + luaL_register(L, tname, lib_funcs); +#elif LUA_VERSION_NUM >= 502 + luaL_newlib(L, lib_funcs); +#else + #error "Don't know how to do it right" +#endif + return 1; +} diff --git a/modules/policy/lua-aho-corasick/ac_slow.cxx b/modules/policy/lua-aho-corasick/ac_slow.cxx new file mode 100644 index 0000000..cb3957a --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_slow.cxx @@ -0,0 +1,318 @@ +#include <ctype.h> +#include <strings.h> // for bzero +#include <algorithm> +#include "ac_slow.hpp" +#include "ac.h" + +////////////////////////////////////////////////////////////////////////// +// +// Implementation of AhoCorasick_Slow +// +////////////////////////////////////////////////////////////////////////// +// +ACS_Constructor::ACS_Constructor() : _next_node_id(1) { + _root = new_state(); + _root_char = new InputTy[256]; + bzero((void*)_root_char, 256); + +#ifdef VERIFY + _pattern_buf = 0; +#endif +} + +ACS_Constructor::~ACS_Constructor() { + for (std::vector<ACS_State* >::iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + delete *i; + } + _all_states.clear(); + delete[] _root_char; + +#ifdef VERIFY + delete[] _pattern_buf; +#endif +} + +ACS_State* +ACS_Constructor::new_state() { + ACS_State* t = new ACS_State(_next_node_id++); + _all_states.push_back(t); + return t; +} + +void +ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len, + int pattern_idx) { + ACS_State* state = _root; + for (unsigned int i = 0; i < str_len; i++) { + const char c = str[i]; + ACS_State* new_s = state->Get_Goto(c); + if (!new_s) { + new_s = new_state(); + new_s->_depth = state->_depth + 1; + state->Set_Goto(c, new_s); + } + state = new_s; + } + state->_is_terminal = true; + state->set_Pattern_Idx(pattern_idx); +} + +void +ACS_Constructor::Propagate_faillink() { + ACS_State* r = _root; + std::vector<ACS_State*> wl; + + const ACS_Goto_Map& m = r->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) { + ACS_State* s = i->second; + s->_fail_link = r; + wl.push_back(s); + } + + // For any input c, make sure "goto(root, c)" is valid, which make the + // fail-link propagation lot easier. + ACS_Goto_Map goto_save = r->_goto_map; + for (uint32 i = 0; i <= 255; i++) { + ACS_State* s = r->Get_Goto(i); + if (!s) r->Set_Goto(i, r); + } + + for (uint32 i = 0; i < wl.size(); i++) { + ACS_State* s = wl[i]; + ACS_State* fl = s->_fail_link; + + const ACS_Goto_Map& tran_map = s->Get_Goto_Map(); + + for (ACS_Goto_Map::const_iterator ii = tran_map.begin(), + ee = tran_map.end(); ii != ee; ii++) { + InputTy c = ii->first; + ACS_State *tran = ii->second; + + ACS_State* tran_fl = 0; + for (ACS_State* fl_walk = fl; ;) { + if (ACS_State* t = fl_walk->Get_Goto(c)) { + tran_fl = t; + break; + } else { + fl_walk = fl_walk->Get_FailLink(); + } + } + + tran->_fail_link = tran_fl; + wl.push_back(tran); + } + } + + // Remove "goto(root, c) == root" transitions + r->_goto_map = goto_save; +} + +void +ACS_Constructor::Construct(const char** strv, unsigned int* strlenv, + uint32 strnum) { + Save_Patterns(strv, strlenv, strnum); + + for (uint32 i = 0; i < strnum; i++) { + Add_Pattern(strv[i], strlenv[i], i); + } + + Propagate_faillink(); + unsigned char* p = _root_char; + + const ACS_Goto_Map& m = _root->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); + i != e; i++) { + p[i->first] = 1; + } +} + +Match_Result +ACS_Constructor::MatchHelper(const char *str, uint32 len) const { + const ACS_State* root = _root; + const ACS_State* state = root; + + uint32 idx = 0; + while (idx < len) { + InputTy c = str[idx]; + idx++; + if (_root_char[c]) { + state = root->Get_Goto(c); + break; + } + } + + if (unlikely(state->is_Terminal())) { + // This could happen if the one of the pattern has only one char! + uint32 pos = idx - 1; + Match_Result r(pos - state->Get_Depth() + 1, pos, + state->get_Pattern_Idx()); + return r; + } + + while (idx < len) { + InputTy c = str[idx]; + ACS_State* gs = state->Get_Goto(c); + + if (!gs) { + ACS_State* fl = state->Get_FailLink(); + if (fl == root) { + while (idx < len) { + InputTy c = str[idx]; + idx++; + if (_root_char[c]) { + state = root->Get_Goto(c); + break; + } + } + } else { + state = fl; + } + } else { + idx ++; + state = gs; + } + + if (state->is_Terminal()) { + uint32 pos = idx - 1; + Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos, + state->get_Pattern_Idx()); + return r; + } + } + + return Match_Result(-1, -1, -1); +} + +#ifdef DEBUG +void +ACS_Constructor::dump_text(const char* txtfile) const { + FILE* f = fopen(txtfile, "w+"); + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State* s = *i; + + fprintf(f, "S%d goto:{", s->Get_ID()); + const ACS_Goto_Map& goto_func = s->Get_Goto_Map(); + + for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end(); + i != e; i++) { + InputTy input = i->first; + ACS_State* tran = i->second; + if (isprint(input)) + fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID()); + else + fprintf(f, "%#x -> S:%d,", input, tran->Get_ID()); + } + fprintf(f, "} "); + + if (s->_fail_link) { + fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID()); + } + + if (s->_is_terminal) { + fprintf(f, ", terminal"); + } + + fprintf(f, "\n"); + } + fclose(f); +} + +void +ACS_Constructor::dump_dot(const char *dotfile) const { + FILE* f = fopen(dotfile, "w+"); + const char* indent = " "; + + fprintf(f, "digraph G {\n"); + + // Emit node information + fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID()); + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State *s = *i; + if (s->_is_terminal) { + fprintf(f, "%s%d [shape=doublecircle];\n", indent, s->Get_ID()); + } + } + fprintf(f, "\n"); + + // Emit edge information + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State* s = *i; + uint32 id = s->Get_ID(); + + const ACS_Goto_Map& m = s->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end(); + ii != ee; ii++) { + InputTy input = ii->first; + ACS_State* tran = ii->second; + if (isalnum(input)) + fprintf(f, "%s%d -> %d [label=%c];\n", + indent, id, tran->Get_ID(), input); + else + fprintf(f, "%s%d -> %d [label=\"%#x\"];\n", + indent, id, tran->Get_ID(), input); + + } + + // Emit fail-link + ACS_State* fl = s->Get_FailLink(); + if (fl && fl != _root) { + fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n", + indent, id, fl->Get_ID()); + } + } + fprintf(f, "}\n"); + fclose(f); +} +#endif + +#ifdef VERIFY +void +ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r) + const { + if (r->begin >= 0) { + unsigned len = r->end - r->begin + 1; + int ptn_idx = r->pattern_idx; + + ASSERT(ptn_idx >= 0 && + len == get_ith_Pattern_Len(ptn_idx) && + memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0); + } +} + +void +ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv, + int pattern_num) { + // calculate the total size needed to save all patterns. + // + int buf_size = 0; + for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; } + + // HINT: patterns are delimited by '\0' in order to ease debugging. + buf_size += pattern_num; + ASSERT(_pattern_buf == 0); + _pattern_buf = new char[buf_size + 1]; + #define MAGIC_NUM 0x5a + _pattern_buf[buf_size] = MAGIC_NUM; + + int ofst = 0; + _pattern_lens.resize(pattern_num); + _pattern_vect.resize(pattern_num); + for (int i = 0; i < pattern_num; i++) { + int l = strlenv[i]; + _pattern_lens[i] = l; + _pattern_vect[i] = _pattern_buf + ofst; + + memcpy(_pattern_buf + ofst, strv[i], l); + ofst += l; + _pattern_buf[ofst++] = '\0'; + } + + ASSERT(_pattern_buf[buf_size] == MAGIC_NUM); + #undef MAGIC_NUM +} + +#endif diff --git a/modules/policy/lua-aho-corasick/ac_slow.hpp b/modules/policy/lua-aho-corasick/ac_slow.hpp new file mode 100644 index 0000000..030b95d --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_slow.hpp @@ -0,0 +1,158 @@ +#ifndef MY_AC_H +#define MY_AC_H + +#include <string.h> +#include <stdio.h> +#include <map> +#include <vector> +#include <algorithm> // for std::sort +#include "ac_util.hpp" + +// Forward decl. the acronym "ACS" stands for "Aho-Corasick Slow implementation" +class ACS_State; +class ACS_Constructor; +class AhoCorasick; + +using namespace std; + +typedef std::map<InputTy, ACS_State*> ACS_Goto_Map; + +class Match_Result { +public: + int begin; + int end; + int pattern_idx; + Match_Result(int b, int e, int p): begin(b), end(e), pattern_idx(p) {} +}; + +typedef pair<InputTy, ACS_State *> GotoPair; +typedef vector<GotoPair> GotoVect; + +// Sorting functor +class GotoSort { +public: + bool operator() (const GotoPair& g1, const GotoPair& g2) { + return g1.first < g2.first; + } +}; + +class ACS_State { +friend class ACS_Constructor; + +public: + ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0), + _is_terminal(false), _fail_link(0){} + ~ACS_State() {}; + + void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; } + ACS_State *Get_Goto(InputTy c) const { + ACS_Goto_Map::const_iterator iter = _goto_map.find(c); + return iter != _goto_map.end() ? (*iter).second : 0; + } + + // Return all transitions sorted in the ascending order of their input. + void Get_Sorted_Gotos(GotoVect& Gotos) const { + const ACS_Goto_Map& m = _goto_map; + Gotos.clear(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); + i != e; i++) { + Gotos.push_back(GotoPair(i->first, i->second)); + } + sort(Gotos.begin(), Gotos.end(), GotoSort()); + } + + ACS_State* Get_FailLink() const { return _fail_link; } + uint32 Get_GotoNum() const { return _goto_map.size(); } + uint32 Get_ID() const { return _id; } + uint32 Get_Depth() const { return _depth; } + const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; } + bool is_Terminal() const { return _is_terminal; } + int get_Pattern_Idx() const { + ASSERT(is_Terminal() && _pattern_idx >= 0); + return _pattern_idx; + } + +private: + void set_Pattern_Idx(int idx) { + ASSERT(is_Terminal()); + _pattern_idx = idx; + } + +private: + uint32 _id; + int _pattern_idx; + short _depth; + bool _is_terminal; + ACS_Goto_Map _goto_map; + ACS_State* _fail_link; +}; + +class ACS_Constructor { +public: + ACS_Constructor(); + ~ACS_Constructor(); + + void Construct(const char** strv, unsigned int* strlenv, + unsigned int strnum); + + Match_Result Match(const char* s, uint32 len) const { + Match_Result r = MatchHelper(s, len); + Verify_Result(s, &r); + return r; + } + + Match_Result Match(const char* s) const { return Match(s, strlen(s)); } + +#ifdef DEBUG + void dump_text(const char* = "ac.txt") const; + void dump_dot(const char* = "ac.dot") const; +#endif + const ACS_State *Get_Root_State() const { return _root; } + const vector<ACS_State*>& Get_All_States() const { + return _all_states; + } + + uint32 Get_Next_Node_Id() const { return _next_node_id; } + uint32 Get_State_Num() const { return _next_node_id - 1; } + +private: + void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx); + ACS_State* new_state(); + void Propagate_faillink(); + + Match_Result MatchHelper(const char*, uint32 len) const; + +#ifdef VERIFY + void Verify_Result(const char* subject, const Match_Result* r) const; + void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len); + const char* get_ith_Pattern(unsigned i) const { + ASSERT(i < _pattern_vect.size()); + return _pattern_vect.at(i); + } + unsigned get_ith_Pattern_Len(unsigned i) const { + ASSERT(i < _pattern_lens.size()); + return _pattern_lens.at(i); + } +#else + void Verify_Result(const char* subject, const Match_Result* r) const { + (void)subject; (void)r; + } + void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) { + (void)strv; (void)strlenv; + } +#endif + +private: + ACS_State* _root; + vector<ACS_State*> _all_states; + unsigned char* _root_char; + uint32 _next_node_id; + +#ifdef VERIFY + char* _pattern_buf; + vector<int> _pattern_lens; + vector<char*> _pattern_vect; +#endif +}; + +#endif diff --git a/modules/policy/lua-aho-corasick/ac_util.hpp b/modules/policy/lua-aho-corasick/ac_util.hpp new file mode 100644 index 0000000..56fd46c --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_util.hpp @@ -0,0 +1,69 @@ +/* + Copyright (c) 2014 CloudFlare, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of CloudFlare, Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#ifndef AC_UTIL_H +#define AC_UTIL_H + +#ifdef DEBUG +#include <stdio.h> // for fprintf +#include <stdlib.h> // for abort +#endif + +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long uint64; +typedef unsigned char InputTy; + +#ifdef DEBUG + // Usage examples: ASSERT(a > b), ASSERT(foo() && "Opps, foo() reutrn 0"); + #define ASSERT(c) if (!(c))\ + { fprintf(stderr, "%s:%d Assert: %s\n", __FILE__, __LINE__, #c); abort(); } +#else + #define ASSERT(c) ((void)0) +#endif + +#define likely(x) __builtin_expect((x),1) +#define unlikely(x) __builtin_expect((x),0) + +#ifndef offsetof +#define offsetof(st, m) ((size_t)(&((st *)0)->m)) +#endif + +typedef enum { + IMPL_SLOW_VARIANT = 1, + IMPL_FAST_VARIANT = 2, +} impl_var_t; + +#define AC_MAGIC_NUM 0x5a +typedef struct { + unsigned char magic_num; + unsigned char impl_variant; +} buf_header_t; + +#endif //AC_UTIL_H diff --git a/modules/policy/lua-aho-corasick/load_ac.lua b/modules/policy/lua-aho-corasick/load_ac.lua new file mode 100644 index 0000000..eb70446 --- /dev/null +++ b/modules/policy/lua-aho-corasick/load_ac.lua @@ -0,0 +1,90 @@ +-- Helper wrappring script for loading shared object libac.so (FFI interface) +-- from package.cpath instead of LD_LIBRARTY_PATH. +-- + +local ffi = require 'ffi' +ffi.cdef[[ + void* ac_create(const char** str_v, unsigned int* strlen_v, + unsigned int v_len); + int ac_match2(void*, const char *str, int len); + void ac_free(void*); +]] + +local _M = {} + +local string_gmatch = string.gmatch +local string_match = string.match + +local ac_lib = nil +local ac_create = nil +local ac_match = nil +local ac_free = nil + +--[[ Find shared object file package.cpath, obviating the need of setting + LD_LIBRARY_PATH +]] +local function find_shared_obj(cpath, so_name) + for k, v in string_gmatch(cpath, "[^;]+") do + local so_path = string_match(k, "(.*/)") + if so_path then + -- "so_path" could be nil. e.g, the dir path component is "." + so_path = so_path .. so_name + + -- Don't get me wrong, the only way to know if a file exist is + -- trying to open it. + local f = io.open(so_path) + if f ~= nil then + io.close(f) + return so_path + end + end + end +end + +function _M.load_ac_lib() + if ac_lib ~= nil then + return ac_lib + else + local so_path = find_shared_obj(package.cpath, "libac.so") + if so_path ~= nil then + ac_lib = ffi.load(so_path) + ac_create = ac_lib.ac_create + ac_match = ac_lib.ac_match2 + ac_free = ac_lib.ac_free + return ac_lib + end + end +end + +-- Create an Aho-Corasick instance, and return the instance if it was +-- successful. +function _M.create_ac(dict) + local strnum = #dict + if ac_lib == nil then + _M.load_ac_lib() + end + + local str_v = ffi.new("const char *[?]", strnum) + local strlen_v = ffi.new("unsigned int [?]", strnum) + + for i = 1, strnum do + local s = dict[i] + str_v[i - 1] = s + strlen_v[i - 1] = #s + end + + local ac = ac_create(str_v, strlen_v, strnum); + if ac ~= nil then + return ffi.gc(ac, ac_free) + end +end + +-- Return nil if str doesn't match the dictionary, else return non-nil. +function _M.match(ac, str) + local r = ac_match(ac, str, #str); + if r >= 0 then + return r + end +end + +return _M diff --git a/modules/policy/lua-aho-corasick/mytest.cxx b/modules/policy/lua-aho-corasick/mytest.cxx new file mode 100644 index 0000000..ef3dc87 --- /dev/null +++ b/modules/policy/lua-aho-corasick/mytest.cxx @@ -0,0 +1,200 @@ +#include <stdio.h> +#include <string.h> +#include <vector> +#include "ac.h" + +using namespace std; + +///////////////////////////////////////////////////////////////////////// +// +// Test using strings from input files +// +///////////////////////////////////////////////////////////////////////// +// +class BigFileTester { +public: + BigFileTester(const char* filepath); + +private: + void Genector +privaete: + const char* _msg; + int _msg_len; + int _key_num; // number of strings in dictionary + int _key_len_idx; +}; + +///////////////////////////////////////////////////////////////////////// +// +// Simple (yet maybe tricky) testings +// +///////////////////////////////////////////////////////////////////////// +// +typedef struct { + const char* str; + const char* match; +} StrPair; + +typedef struct { + const char* name; + const char** dict; + StrPair* strpairs; + int dict_len; + int strpair_num; +} TestingCase; + +class Tests { +public: + Tests(const char* name, + const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num) { + if (!_tests) + _tests = new vector<TestingCase>; + + TestingCase tc; + tc.name = name; + tc.dict = dict; + tc.strpairs = strpairs; + tc.dict_len = dict_len; + tc.strpair_num = strpair_num; + _tests->push_back(tc); + } + + static vector<TestingCase>* Get_Tests() { return _tests; } + static void Erase_Tests() { delete _tests; _tests = 0; } + +private: + static vector<TestingCase> *_tests; +}; + +vector<TestingCase>* Tests::_tests = 0; + +static void +simple_test(void) { + int total = 0; + int fail = 0; + + vector<TestingCase> *tests = Tests::Get_Tests(); + if (!tests) + return 0; + + for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end(); + i != e; i++) { + TestingCase& t = *i; + fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); + for (int i = 0, e = t.dict_len, need_break=0; i < e; i++) { + fprintf(stdout, "%s, ", t.dict[i]); + if (need_break++ == 16) { + fputs("\n ", stdout); + need_break = 0; + } + } + fputs("]\n", stdout); + + /* Create the dictionary */ + int dict_len = t.dict_len; + ac_t* ac = ac_create(t.dict, dict_len); + + for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { + const StrPair& sp = t.strpairs[ii]; + const char *str = sp.str; // the string to be matched + const char *match = sp.match; + + fprintf(stdout, "[%3d] Testing '%s' : ", total, str); + + int len = strlen(str); + ac_result_t r = ac_match(ac, str, len); + int m_b = r.match_begin; + int m_e = r.match_end; + + // The return value per se is insane. + if (m_b > m_e || + ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { + fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); + fail ++; + continue; + } + + // If the string is not supposed to match the dictionary. + if (!match) { + if (m_b != -1 || m_e != -1) { + fail ++; + fprintf(stdout, "Not Supposed to match (%d, %d) \n", + m_b, m_e); + } else + fputs("Pass\n", stdout); + continue; + } + + // The string or its substring is match the dict. + if (m_b >= len || m_b >= len) { + fail ++; + fprintf(stdout, + "Return value >= the length of the string (%d, %d)\n", + m_b, m_e); + continue; + } else { + int mlen = strlen(match); + if ((mlen != m_e - m_b + 1) || + strncmp(str + m_b, match, mlen)) { + fail ++; + fprintf(stdout, "Fail\n"); + } else + fprintf(stdout, "Pass\n"); + } + } + fputs("\n", stdout); + ac_free(ac); + } + + fprintf(stdout, "Total : %d, Fail %d\n", total, fail); + + return fail ? -1 : 0; +} + +int +main (int argc, char** argv) { + int res = simple_test(); + return res; +}; + +/* test 1*/ +const char *dict1[] = {"he", "she", "his", "her"}; +StrPair strpair1[] = { + {"he", "he"}, {"she", "she"}, {"his", "his"}, + {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, + {"shis2", "his"}, {"ahhe", "he"} +}; +Tests test1("test 1", + dict1, sizeof(dict1)/sizeof(dict1[0]), + strpair1, sizeof(strpair1)/sizeof(strpair1[0])); + +/* test 2*/ +const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ +StrPair strpair2[] = {{"The pot had a handle", 0}}; +Tests test2("test 2", dict2, 2, strpair2, 1); + +/* test 3*/ +const char *dict3[] = {"The"}; +StrPair strpair3[] = {{"The pot had a handle", "The"}}; +Tests test3("test 3", dict3, 1, strpair3, 1); + +/* test 4*/ +const char *dict4[] = {"pot"}; +StrPair strpair4[] = {{"The pot had a handle", "pot"}}; +Tests test4("test 4", dict4, 1, strpair4, 1); + +/* test 5*/ +const char *dict5[] = {"pot "}; +StrPair strpair5[] = {{"The pot had a handle", "pot "}}; +Tests test5("test 5", dict5, 1, strpair5, 1); + +/* test 6*/ +const char *dict6[] = {"ot h"}; +StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; +Tests test6("test 6", dict6, 1, strpair6, 1); + +/* test 7*/ +const char *dict7[] = {"andle"}; +StrPair strpair7[] = {{"The pot had a handle", "andle"}}; +Tests test7("test 7", dict7, 1, strpair7, 1); diff --git a/modules/policy/lua-aho-corasick/tests/Makefile b/modules/policy/lua-aho-corasick/tests/Makefile new file mode 100644 index 0000000..54fd90f --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/Makefile @@ -0,0 +1,65 @@ +OS := $(shell uname) +ifeq ($(OS), Darwin) + SO_EXT := dylib +else + SO_EXT := so +endif + +.PHONY = all clean test runtest benchmark + +PROGRAM = ac_test +BENCHMARK = ac_bench +all: runtest + +CXXFLAGS = -O3 -g -march=native -Wall -DDEBUG +MYCXXFLAGS = -MMD -I.. $(CXXFLAGS) +%.o : %.cxx + $(CXX) $< -c $(MYCXXFLAGS) + +-include dep.cxx +SRC = test_main.cxx ac_test_simple.cxx ac_test_aggr.cxx test_bigfile.cxx + +OBJ = ${SRC:.cxx=.o} + +-include test_dep.txt +-include bench_dep.txt + +$(PROGRAM) $(BENCHMARK) : testinput/text.tar testinput/image.bin +$(PROGRAM) : $(OBJ) ../libac.$(SO_EXT) + $(CXX) $(OBJ) -L.. -lac -o $@ + -cat *.d > test_dep.txt + +$(BENCHMARK) : ac_bench.o ../libac.$(SO_EXT) + $(CXX) ac_bench.o -L.. -lac -o $@ + -cat *.d > bench_dep.txt + +ifneq ($(OS), Darwin) +runtest:$(PROGRAM) + LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* + +benchmark:$(BENCHMARK) + LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./ac_bench + +else +runtest:$(PROGRAM) + DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* + +benchmark:$(BENCHMARK) + DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./ac_bench + +endif + +testinput/text.tar: + echo "download testing files (gcc tarball)..." + if [ ! -d testinput ] ; then mkdir testinput; fi + cd testinput && \ + curl ftp://ftp.gnu.org/gnu/gcc/gcc-1.42.tar.gz -o text.tar.gz 2>/dev/null \ + && gzip -d text.tar.gz + +testinput/image.bin: + echo "download testing files.." + if [ ! -d testinput ] ; then mkdir testinput; fi + curl http://www.3dvisionlive.com/sites/default/files/Curiosity_render_hiresb.jpg -o $@ 2>/dev/null + +clean: + -rm -f *.o *.d dep.txt $(PROGRAM) $(BENCHMARK) diff --git a/modules/policy/lua-aho-corasick/tests/ac_bench.cxx b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx new file mode 100644 index 0000000..421322c --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx @@ -0,0 +1,519 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <sys/time.h> +#include <time.h> +#include <fcntl.h> +#include <unistd.h> +#include <dirent.h> +#include <libgen.h> +#include <errno.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <getopt.h> + +#include <string> +#include <vector> +#include "ac.h" +#include "ac_util.hpp" + +using namespace std; + +static bool SomethingWrong = false; + +static int iteration = 300; +static string dict_dir; +static string obj_file_dir; +static bool print_help = false; +static int piece_size = 1024; + +class PatternSet { +public: + PatternSet(const char* filepath); + ~PatternSet() { Cleanup(); } + + int getPatternNum() const { return _pat_num; } + const char** getPatternVector() const { return _patterns; } + unsigned int* getPatternLenVector() const { return _pat_len; } + + const char* getErrMessage() const { return _errmsg; } + static bool isDictFile(const char* filepath) { + if (strncmp(basename(const_cast<char*>(filepath)), "dict", 4)) + return false; + return true; + } + +private: + bool ExtractPattern(const char* filepath); + void Cleanup(); + + const char** _patterns; + unsigned int* _pat_len; + char* _mmap; + int _fd; + size_t _mmap_size; + int _pat_num; + + const char* _errmsg; +}; + +bool +PatternSet::ExtractPattern(const char* filepath) { + if (!isDictFile(filepath)) + return false; + + struct stat filestat; + if (stat(filepath, &filestat)) { + _errmsg = "fail to call stat()"; + return false; + } + + if (filestat.st_size > 4096 * 1024) { + /* It dosen't seem to be a dictionary file*/ + _errmsg = "file too big?"; + return false; + } + + _fd = open(filepath, 0); + if (_fd == -1) { + _errmsg = "fail to open dictionary file"; + return false; + } + + _mmap_size = filestat.st_size; + _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, + MAP_PRIVATE, _fd, 0); + if (_mmap == MAP_FAILED) { + _errmsg = "fail to call mmap"; + return false; + } + + const char* pat = _mmap; + vector<const char*> pat_vect; + vector<unsigned> pat_len_vect; + + for (size_t i = 0, e = filestat.st_size; i < e; i++) { + if (_mmap[i] == '\r' || _mmap[i] == '\n') { + _mmap[i] = '\0'; + int len = _mmap + i - pat; + if (len > 0) { + pat_vect.push_back(pat); + pat_len_vect.push_back(len); + } + pat = _mmap + i + 1; + } + } + + ASSERT(pat_vect.size() == pat_len_vect.size()); + + int pat_num = pat_vect.size(); + if (pat_num > 0) { + const char** p = _patterns = new const char*[pat_num]; + int i = 0; + for (vector<const char*>::iterator iter = pat_vect.begin(), + iter_e = pat_vect.end(); iter != iter_e; ++iter) { + p[i++] = *iter; + } + + i = 0; + unsigned int* q = _pat_len = new unsigned int[pat_num]; + for (vector<unsigned>::iterator iter = pat_len_vect.begin(), + iter_e = pat_len_vect.end(); iter != iter_e; ++iter) { + q[i++] = *iter; + } + } + + _pat_num = pat_num; + if (pat_num <= 0) { + _errmsg = "no pattern at all"; + return false; + } + + return true; +} + +void +PatternSet::Cleanup() { + if (_mmap != MAP_FAILED) { + munmap(_mmap, _mmap_size); + _mmap = (char*)MAP_FAILED; + _mmap_size = 0; + } + + delete[] _patterns; + delete[] _pat_len; + if (_fd != -1) + close(_fd); + _pat_num = -1; +} + +PatternSet::PatternSet(const char* filepath) { + _patterns = 0; + _pat_len = 0; + _mmap = (char*)MAP_FAILED; + _mmap_size = 0; + _pat_num = -1; + _errmsg = ""; + + if (!ExtractPattern(filepath)) + Cleanup(); +} + +bool +getFilesUnderDir(vector<string>& files, const char* path) { + files.clear(); + + DIR* dir = opendir(path); + if (!dir) + return false; + + string path_dir = path; + path_dir += "/"; + + for (;;) { + struct dirent* entry = readdir(dir); + if (entry) { + string filepath = path_dir + entry->d_name; + struct stat file_stat; + if (stat(filepath.c_str(), &file_stat)) { + closedir(dir); + return false; + } + + if (S_ISREG(file_stat.st_mode)) + files.push_back(filepath); + + continue; + } + + if (errno) { + return false; + } + break; + } + closedir(dir); + return true; +} + +class Timer { +public: + Timer() { + my_clock_gettime(&_start); + _stop = _start; + _acc.tv_sec = 0; + _acc.tv_nsec = 0; + } + + const Timer& operator += (const Timer& that) { + time_t sec = _acc.tv_sec + that._acc.tv_sec; + long nsec = _acc.tv_nsec + that._acc.tv_nsec; + if (nsec > 1000000000) { + nsec -= 1000000000; + sec += 1; + } + _acc.tv_sec = sec; + _acc.tv_nsec = nsec; + return *this; + } + + // return duration in us + size_t getDuration() const { + return _acc.tv_sec * (size_t)1000000 + _acc.tv_nsec/1000; + } + + void Start(bool acc=true) { + my_clock_gettime(&_start); + } + + void Stop() { + my_clock_gettime(&_stop); + struct timespec t = CalcDuration(); + _acc = add_duration(_acc, t); + } + +private: + int my_clock_gettime(struct timespec* t) { +#ifdef __linux + return clock_gettime(CLOCK_PROCESS_CPUTIME_ID, t); +#else + struct timeval tv; + int rc = gettimeofday(&tv, 0); + t->tv_sec = tv.tv_sec; + t->tv_nsec = tv.tv_usec * 1000; + return rc; +#endif + } + + struct timespec add_duration(const struct timespec& dur1, + const struct timespec& dur2) { + time_t sec = dur1.tv_sec + dur2.tv_sec; + long nsec = dur1.tv_nsec + dur2.tv_nsec; + if (nsec > 1000000000) { + nsec -= 1000000000; + sec += 1; + } + timespec t; + t.tv_sec = sec; + t.tv_nsec = nsec; + + return t; + } + + struct timespec CalcDuration() const { + timespec diff; + if ((_stop.tv_nsec - _start.tv_nsec)<0) { + diff.tv_sec = _stop.tv_sec - _start.tv_sec - 1; + diff.tv_nsec = 1000000000 + _stop.tv_nsec - _start.tv_nsec; + } else { + diff.tv_sec = _stop.tv_sec - _start.tv_sec; + diff.tv_nsec = _stop.tv_nsec - _start.tv_nsec; + } + return diff; + } + + struct timespec _start; + struct timespec _stop; + struct timespec _acc; +}; + +class Benchmark { +public: + Benchmark(const PatternSet& pat_set, const char* infile): + _pat_set(pat_set), _infile(infile) { + _mmap = (char*)MAP_FAILED; + _file_sz = 0; + _fd = -1; + } + + ~Benchmark() { + if (_mmap != MAP_FAILED) + munmap(_mmap, _file_sz); + if (_fd != -1) + close(_fd); + } + + bool Run(int iteration); + const Timer& getTimer() const { return _timer; } + +private: + const PatternSet& _pat_set; + const char* _infile; + char* _mmap; + int _fd; + size_t _file_sz; // input file size + Timer _timer; +}; + +bool +Benchmark::Run(int iteration) { + if (_pat_set.getPatternNum() <= 0) { + SomethingWrong = true; + return false; + } + + if (_mmap == MAP_FAILED) { + struct stat filestat; + if (stat(_infile, &filestat)) { + SomethingWrong = true; + return false; + } + + if (!S_ISREG(filestat.st_mode)) { + SomethingWrong = true; + return false; + } + + _fd = open(_infile, 0); + if (_fd == -1) + return false; + + _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, + MAP_PRIVATE, _fd, 0); + + if (_mmap == MAP_FAILED) { + SomethingWrong = true; + return false; + } + + _file_sz = filestat.st_size; + } + + ac_t* ac = ac_create(_pat_set.getPatternVector(), + _pat_set.getPatternLenVector(), + _pat_set.getPatternNum()); + if (!ac) { + SomethingWrong = true; + return false; + } + + int piece_num = _file_sz/piece_size; + + _timer.Start(false); + + /* Stupid compiler may not be able to promote piece_size into register. + * Do it manually. + */ + int piece_sz = piece_size; + for (int i = 0; i < iteration; i++) { + size_t match_ofst = 0; + for (int piece_idx = 0; piece_idx < piece_num; piece_idx ++) { + ac_match2(ac, _mmap + match_ofst, piece_sz); + match_ofst += piece_sz; + } + if (match_ofst != _file_sz) + ac_match2(ac, _mmap + match_ofst, _file_sz - match_ofst); + } + _timer.Stop(); + return true; +} + +const char* short_opt = "hd:f:i:p:"; +const struct option long_opts[] = { + {"help", no_argument, 0, 'h'}, + {"iteration", required_argument, 0, 'i'}, + {"dictionary-dir", required_argument, 0, 'd'}, + {"obj-file-dir", required_argument, 0, 'f'}, + {"piece-size", required_argument, 0, 'p'}, +}; + +static void +PrintHelp(const char* prog_name) { + const char* msg = +"Usage %s [OPTIONS]\n" +" -d, --dictionary-dir : specify the dictionary directory (./dict by default)\n" +" -f, --obj-file-dir : specify the object file directory\n" +" (./testinput by default)\n" +" -i, --iteration : Run this many iteration for each pattern match\n" +" -p, --piece-size : The size of 'piece' in byte. The input file is\n" +" divided into pieces, and match function is working\n" +" on one piece at a time. The default size of piece\n" +" is 1k byte.\n"; + + fprintf(stdout, msg, prog_name); +} + +static bool +getOptions(int argc, char** argv) { + bool dict_dir_set = false; + bool objfile_dir_set = false; + int opt_index; + + while (1) { + if (print_help) break; + + int c = getopt_long(argc, argv, short_opt, long_opts, &opt_index); + + if (c == -1) break; + if (c == 0) { c = long_opts[opt_index].val; } + + switch(c) { + case 'h': + print_help = true; + break; + + case 'i': + iteration = atol(optarg); + break; + + case 'd': + dict_dir = optarg; + dict_dir_set = true; + break; + + case 'f': + obj_file_dir = optarg; + objfile_dir_set = true; + break; + + case 'p': + piece_size = atol(optarg); + break; + + case '?': + default: + return false; + } + } + + if (print_help) + return true; + + string basedir(dirname(argv[0])); + if (!dict_dir_set) + dict_dir = basedir + "/dict"; + + if (!objfile_dir_set) + obj_file_dir = basedir + "/testinput"; + + return true; +} + +int +main(int argc, char** argv) { + if (!getOptions(argc, argv)) + return -1; + + if (print_help) { + PrintHelp(argv[0]); + return 0; + } + +#ifndef __linux + fprintf(stdout, "\n!!!WARNING: On this OS, the execution time is measured" + " by gettimeofday(2) which is imprecise!!!\n\n"); +#endif + + fprintf(stdout, "Test with iteration = %d, piece size = %d, and", + iteration, piece_size); + fprintf(stdout, "\n dictionary dir = %s\n object file dir = %s\n\n", + dict_dir.c_str(), obj_file_dir.c_str()); + + vector<string> dict_files; + vector<string> input_files; + + if (!getFilesUnderDir(dict_files, dict_dir.c_str())) { + fprintf(stdout, "fail to find dictionary files\n"); + return -1; + } + + if (!getFilesUnderDir(input_files, obj_file_dir.c_str())) { + fprintf(stdout, "fail to find test input files\n"); + return -1; + } + + for (vector<string>::iterator diter = dict_files.begin(), + diter_e = dict_files.end(); diter != diter_e; ++diter) { + + const char* dict_name = diter->c_str(); + if (!PatternSet::isDictFile(dict_name)) + continue; + + PatternSet ps(dict_name); + if (ps.getPatternNum() <= 0) { + fprintf(stdout, "fail to open dictionary file %s : %s\n", + dict_name, ps.getErrMessage()); + SomethingWrong = true; + continue; + } + + fprintf(stdout, "Using dictionary %s\n", dict_name); + Timer timer; + for (vector<string>::iterator iter = input_files.begin(), + iter_e = input_files.end(); iter != iter_e; ++iter) { + fprintf(stdout, " testing %s ... ", iter->c_str()); + fflush(stdout); + Benchmark bm(ps, iter->c_str()); + bm.Run(iteration); + const Timer& t = bm.getTimer(); + timer += bm.getTimer(); + fprintf(stdout, "elapsed %.3f\n", t.getDuration() / 1000000.0); + } + + fprintf(stdout, + "\n==========================================================\n" + " Total Elapse %.3f\n\n", timer.getDuration() / 1000000.0); + } + + return SomethingWrong ? -1 : 0; +} diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx new file mode 100644 index 0000000..4ea02bc --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx @@ -0,0 +1,135 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + +namespace { +class ACBigFileTester : public BigFileTester { +public: + ACBigFileTester(const char* filepath) : BigFileTester(filepath){}; + +private: + virtual buf_header_t* PM_Create(const char** strv, uint32* strlenv, + uint32 vect_len) { + return (buf_header_t*)ac_create(strv, strlenv, vect_len); + } + + virtual void PM_Free(buf_header_t* PM) { ac_free(PM); } + virtual bool Run_Helper(buf_header_t* PM); +}; + +class ACTestAggressive: public ACTestBase { +public: + ACTestAggressive(const vector<const char*>& files, const char* banner) + : ACTestBase(banner), _files(files) {} + virtual bool Run(); + +private: + void PrintSummary(int total, int fail) { + fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); + fflush(stdout); + } + vector<const char*> _files; +}; + +} // end of anonymous namespace + +bool +ACBigFileTester::Run_Helper(buf_header_t* PM) { + int fail = 0; + // advance one chunk at a time. + int len = _msg_len; + int chunk_sz = _chunk_sz; + + vector<const char*> c_style_keys; + for (int i = 0, e = _keys.size(); i != e; i++) { + const char* key = _keys[i].first; + int len = _keys[i].second; + char *t = new char[len+1]; + memcpy(t, key, len); + t[len] = '\0'; + c_style_keys.push_back(t); + } + + for (int ofst = 0, chunk_idx = 0, chunk_num = _chunk_num; + chunk_idx < chunk_num; ofst += chunk_sz, chunk_idx++) { + const char* substring = _msg + ofst; + ac_result_t r = ac_match((ac_t*)(void*)PM, substring , len - ofst); + int m_b = r.match_begin; + int m_e = r.match_end; + + if (m_b < 0 || m_e < 0 || m_e <= m_b || m_e >= len) { + fprintf(stdout, "fail to find match substring[%d:%d])\n", + ofst, len - 1); + fail ++; + continue; + } + + const char* match_str = _msg + len; + int strstr_len = 0; + int key_idx = -1; + + for (int i = 0, e = c_style_keys.size(); i != e; i++) { + const char* key = c_style_keys[i]; + if (const char *m = strstr(substring, key)) { + if (m < match_str) { + match_str = m; + strstr_len = _keys[i].second; + key_idx = i; + } + } + } + ASSERT(key_idx != -1); + if ((match_str - substring != m_b)) { + fprintf(stdout, + "Fail to find match substring[%d:%d])," + " expected to find match at offset %d instead of %d\n", + ofst, len - 1, + (int)(match_str - _msg), ofst + m_b); + fprintf(stdout, "%d vs %d (key idx %d)\n", strstr_len, m_e - m_b + 1, key_idx); + PrintStr(stdout, match_str, strstr_len); + fprintf(stdout, "\n"); + PrintStr(stdout, _msg + ofst + m_b, + m_e - m_b + 1); + fprintf(stdout, "\n"); + fail ++; + } + } + for (vector<const char*>::iterator i = c_style_keys.begin(), + e = c_style_keys.end(); i != e; i++) { + delete[] *i; + } + + return fail == 0; +} + +bool +ACTestAggressive::Run() { + int fail = 0; + for (vector<const char*>::iterator i = _files.begin(), e = _files.end(); + i != e; i++) { + ACBigFileTester bft(*i); + if (!bft.Run()) + fail ++; + } + return fail == 0; +} + +bool +Run_AC_Aggressive_Test(const vector<const char*>& files) { + ACTestAggressive t(files, "AC Aggressive test"); + t.PrintBanner(); + return t.Run(); +} diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx new file mode 100644 index 0000000..fa2d7fd --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx @@ -0,0 +1,275 @@ +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + +namespace { +typedef struct { + const char* str; + const char* match; +} StrPair; + +typedef enum { + MV_FIRST_MATCH = 0, + MV_LEFT_LONGEST = 1, +} MatchVariant; + +typedef struct { + const char* name; + const char** dict; + StrPair* strpairs; + int dict_len; + int strpair_num; + MatchVariant match_variant; +} TestingCase; + +class Tests { +public: + Tests(const char* name, + const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num, + MatchVariant mv = MV_FIRST_MATCH) { + if (!_tests) + _tests = new vector<TestingCase>; + + TestingCase tc; + tc.name = name; + tc.dict = dict; + tc.strpairs = strpairs; + tc.dict_len = dict_len; + tc.strpair_num = strpair_num; + tc.match_variant = mv; + _tests->push_back(tc); + } + + static vector<TestingCase>* Get_Tests() { return _tests; } + static void Erase_Tests() { delete _tests; _tests = 0; } + +private: + static vector<TestingCase> *_tests; +}; + +class LeftLongestTests : public Tests { +public: + LeftLongestTests (const char* name, const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num): + Tests(name, dict, dict_len, strpairs, strpair_num, MV_LEFT_LONGEST) { + } +}; + +vector<TestingCase>* Tests::_tests = 0; + +class ACTestSimple: public ACTestBase { +public: + ACTestSimple(const char* banner) : ACTestBase(banner) {} + virtual bool Run(); + +private: + void PrintSummary(int total, int fail) { + fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); + fflush(stdout); + } +}; +} + +bool +ACTestSimple::Run() { + int total = 0; + int fail = 0; + + vector<TestingCase> *tests = Tests::Get_Tests(); + if (!tests) { + PrintSummary(0, 0); + return true; + } + + for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end(); + i != e; i++) { + TestingCase& t = *i; + int dict_len = t.dict_len; + unsigned int* strlen_v = new unsigned int[dict_len]; + + fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); + for (int i = 0, need_break=0; i < dict_len; i++) { + const char* s = t.dict[i]; + fprintf(stdout, "%s, ", s); + strlen_v[i] = strlen(s); + if (need_break++ == 16) { + fputs("\n ", stdout); + need_break = 0; + } + } + fputs("]\n", stdout); + + /* Create the dictionary */ + ac_t* ac = ac_create(t.dict, strlen_v, dict_len); + delete[] strlen_v; + + for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { + const StrPair& sp = t.strpairs[ii]; + const char *str = sp.str; // the string to be matched + const char *match = sp.match; + + fprintf(stdout, "[%3d] Testing '%s' : ", total, str); + + int len = strlen(str); + ac_result_t r; + if (t.match_variant == MV_FIRST_MATCH) + r = ac_match(ac, str, len); + else if (t.match_variant == MV_LEFT_LONGEST) + r = ac_match_longest_l(ac, str, len); + else { + ASSERT(false && "Unknown variant"); + } + + int m_b = r.match_begin; + int m_e = r.match_end; + + // The return value per se is insane. + if (m_b > m_e || + ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { + fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); + fail ++; + continue; + } + + // If the string is not supposed to match the dictionary. + if (!match) { + if (m_b != -1 || m_e != -1) { + fail ++; + fprintf(stdout, "Not Supposed to match (%d, %d) \n", + m_b, m_e); + } else + fputs("Pass\n", stdout); + continue; + } + + // The string or its substring is match the dict. + if (m_b >= len || m_b >= len) { + fail ++; + fprintf(stdout, + "Return value >= the length of the string (%d, %d)\n", + m_b, m_e); + continue; + } else { + int mlen = strlen(match); + if ((mlen != m_e - m_b + 1) || + strncmp(str + m_b, match, mlen)) { + fail ++; + fprintf(stdout, "Fail\n"); + } else + fprintf(stdout, "Pass\n"); + } + } + fputs("\n", stdout); + ac_free(ac); + } + + PrintSummary(total, fail); + return fail == 0; +} + +bool +Run_AC_Simple_Test() { + ACTestSimple t("AC Simple test"); + t.PrintBanner(); + return t.Run(); +} + +////////////////////////////////////////////////////////////////////////////// +// +// Testing cases for first-match variant (i.e. test ac_match()) +// +////////////////////////////////////////////////////////////////////////////// +// + +/* test 1*/ +const char *dict1[] = {"he", "she", "his", "her"}; +StrPair strpair1[] = { + {"he", "he"}, {"she", "she"}, {"his", "his"}, + {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, + {"shis2", "his"}, {"ahhe", "he"} +}; +Tests test1("test 1", + dict1, sizeof(dict1)/sizeof(dict1[0]), + strpair1, sizeof(strpair1)/sizeof(strpair1[0])); + +/* test 2*/ +const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ +StrPair strpair2[] = {{"The pot had a handle", 0}}; +Tests test2("test 2", dict2, 2, strpair2, 1); + +/* test 3*/ +const char *dict3[] = {"The"}; +StrPair strpair3[] = {{"The pot had a handle", "The"}}; +Tests test3("test 3", dict3, 1, strpair3, 1); + +/* test 4*/ +const char *dict4[] = {"pot"}; +StrPair strpair4[] = {{"The pot had a handle", "pot"}}; +Tests test4("test 4", dict4, 1, strpair4, 1); + +/* test 5*/ +const char *dict5[] = {"pot "}; +StrPair strpair5[] = {{"The pot had a handle", "pot "}}; +Tests test5("test 5", dict5, 1, strpair5, 1); + +/* test 6*/ +const char *dict6[] = {"ot h"}; +StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; +Tests test6("test 6", dict6, 1, strpair6, 1); + +/* test 7*/ +const char *dict7[] = {"andle"}; +StrPair strpair7[] = {{"The pot had a handle", "andle"}}; +Tests test7("test 7", dict7, 1, strpair7, 1); + +const char *dict8[] = {"aaab"}; +StrPair strpair8[] = {{"aaaaaaab", "aaab"}}; +Tests test8("test 8", dict8, 1, strpair8, 1); + +const char *dict9[] = {"haha", "z"}; +StrPair strpair9[] = {{"aaaaz", "z"}, {"z", "z"}}; +Tests test9("test 9", dict9, 2, strpair9, 2); + +/* test the case when input string dosen't contain even a single char + * of the pattern in dictionary. + */ +const char *dict10[] = {"abc"}; +StrPair strpair10[] = {{"cde", 0}}; +Tests test10("test 10", dict10, 1, strpair10, 1); + + +////////////////////////////////////////////////////////////////////////////// +// +// Testing cases for first longest match variant (i.e. +// test ac_match_longest_l()) +// +////////////////////////////////////////////////////////////////////////////// +// + +// This was actually first motivation for left-longest-match +const char *dict100[] = {"Mozilla", "Mozilla Mobile"}; +StrPair strpair100[] = {{"User Agent containing string Mozilla Mobile", "Mozilla Mobile"}}; +LeftLongestTests test100("l_test 100", dict100, 2, strpair100, 1); + +// Dict with single char is tricky +const char *dict101[] = {"a", "abc"}; +StrPair strpair101[] = {{"abcdef", "abc"}}; +LeftLongestTests test101("l_test 101", dict101, 2, strpair101, 1); + +// Testing case with partially overlapping patterns. The purpose is to +// check if the fail-link leading from terminal state is correct. +// +// The fail-link leading from terminal-state does not matter in +// match-first-occurrence variant, as it stop when a terminal is hit. +// +const char *dict102[] = {"abc", "bcdef"}; +StrPair strpair102[] = {{"abcdef", "bcdef"}}; +LeftLongestTests test102("l_test 102", dict102, 2, strpair102, 1); diff --git a/modules/policy/lua-aho-corasick/tests/dict/README.txt b/modules/policy/lua-aho-corasick/tests/dict/README.txt new file mode 100644 index 0000000..cd50b41 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/dict/README.txt @@ -0,0 +1 @@ +This directory contains pattern set of benchmark purpose. diff --git a/modules/policy/lua-aho-corasick/tests/dict/dict1.txt b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt new file mode 100644 index 0000000..94085a9 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt @@ -0,0 +1,11 @@ +false_return@ +forloop#haha +wtfprogram +mmaporunmap +ThIs?Module!IsEssential +struct rtlwtf +gettIMEOfdayWrong +edistribution_and_use_in_@source +Copyright~#@ +while {! +!%SQLinje diff --git a/modules/policy/lua-aho-corasick/tests/load_ac_test.lua b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua new file mode 100644 index 0000000..7fb7db9 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua @@ -0,0 +1,82 @@ +-- This script is to test load_ac.lua +-- +-- Some notes: +-- 1. The purpose of this script is not to check if the libac.so work +-- properly, it is to check if there are something stupid in load_ac.lua +-- +-- 2. There are bunch of collectgarbage() calls, the purpose is to make +-- sure the shared lib is not unloaded after GC. + +-- load_ac.lua looks up libac.so via package.cpath rather than LD_LIBRARY_PATH, +-- prepend (instead of appending) some insane paths here to see if it quit +-- prematurely. +-- +package.cpath = ".;./?.so;" .. package.cpath + +local ac = require "load_ac" + +local ac_create = ac.create_ac +local ac_match = ac.match +local string_fmt = string.format +local string_sub = string.sub + +local err_cnt = 0 +local function mytest(testname, dict, match, notmatch) + print(">Testing ", testname) + + io.write(string_fmt("Dictionary: ")); + for i=1, #dict do + io.write(string_fmt("%s, ", dict[i])) + end + print "" + + local ac_inst = ac_create(dict); + collectgarbage() + for i=1, #match do + local str = match[i] + io.write(string_fmt("Matching %s, ", str)) + local b = ac_match(ac_inst, str) + if b then + print "pass" + else + err_cnt = err_cnt + 1 + print "fail" + end + collectgarbage() + end + + if notmatch == nil then + return + end + + collectgarbage() + + for i = 1, #notmatch do + local str = notmatch[i] + io.write(string_fmt("*Matching %s, ", str)) + local r = ac_match(ac_inst, str) + if r then + err_cnt = err_cnt + 1 + print("fail") + else + print("succ") + end + collectgarbage() + end + ac_inst = nil + collectgarbage() +end + +print("") +print("====== Test to see if load_ac.lua works properly ========") + +mytest("test1", + {"he", "she", "his", "her", "str\0ing"}, + -- matching cases + { "he", "she", "his", "hers", "ahe", "shhe", "shis2", "ahhe", "str\0ing" }, + + -- not matching case + {"str\0", "str"} + ) + +os.exit((err_cnt == 0) and 0 or 1) diff --git a/modules/policy/lua-aho-corasick/tests/lua_test.lua b/modules/policy/lua-aho-corasick/tests/lua_test.lua new file mode 100644 index 0000000..cfe178f --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/lua_test.lua @@ -0,0 +1,67 @@ +-- This script is to test ahocorasick.so not libac.so +-- +local ac = require "ahocorasick" + +local ac_create = ac.create +local ac_match = ac.match +local string_fmt = string.format +local string_sub = string.sub + +local err_cnt = 0 +local function mytest(testname, dict, match, notmatch) + print(">Testing ", testname) + + io.write(string_fmt("Dictionary: ")); + for i=1, #dict do + io.write(string_fmt("%s, ", dict[i])) + end + print "" + + local ac_inst = ac_create(dict); + for i=1, #match do + local str = match[i][1] + local substr = match[i][2] + io.write(string_fmt("Matching %s, ", str)) + local b, e = ac_match(ac_inst, str) + if b and e and (string_sub(str, b+1, e+1) == substr) then + print "pass" + else + err_cnt = err_cnt + 1 + print "fail" + end + --print("gc is called") + collectgarbage() + end + + if notmatch == nil then + return + end + + for i = 1, #notmatch do + local str = notmatch[i] + io.write(string_fmt("*Matching %s, ", str)) + local r = ac_match(ac_inst, str) + if r then + err_cnt = err_cnt + 1 + print("fail") + else + print("succ") + end + collectgarbage() + end +end + +mytest("test1", + {"he", "she", "his", "her", "str\0ing"}, + -- matching cases + { {"he", "he"}, {"she", "she"}, {"his", "his"}, {"hers", "he"}, + {"ahe", "he"}, {"shhe", "he"}, {"shis2", "his"}, {"ahhe", "he"}, + {"str\0ing", "str\0ing"} + }, + + -- not matching case + {"str\0", "str"} + + ) + +os.exit((err_cnt == 0) and 0 or 1) diff --git a/modules/policy/lua-aho-corasick/tests/test_base.hpp b/modules/policy/lua-aho-corasick/tests/test_base.hpp new file mode 100644 index 0000000..7758371 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_base.hpp @@ -0,0 +1,60 @@ +#ifndef TEST_BASE_H +#define TEST_BASE_H + +#include <stdio.h> +#include <string> +#include <stdint.h> + +using namespace std; +class ACTestBase { +public: + ACTestBase(const char* name) :_banner(name) {} + virtual void PrintBanner() { + fprintf(stdout, "\n===== %s ====\n", _banner.c_str()); + } + + virtual bool Run() = 0; +private: + string _banner; +}; + +typedef std::pair<const char*, int> StrInfo; +class BigFileTester { +public: + BigFileTester(const char* filepath); + virtual ~BigFileTester() { Cleanup(); } + + bool Run(); + +protected: + virtual buf_header_t* PM_Create(const char** strv, uint32_t* strlenv, + uint32_t vect_len) = 0; + virtual void PM_Free(buf_header_t*) = 0; + virtual bool Run_Helper(buf_header_t* PM) = 0; + + // Return true if the '\0' is valid char of a string. + virtual bool Str_C_Style() { return true; } + + bool GenerateKeys(); + void Cleanup(); + void PrintStr(FILE*, const char* str, int len); + +protected: + const char* _filepath; + int _fd; + vector<StrInfo> _keys; + char* _msg; + int _msg_len; + int _key_num; // number of strings in dictionary + int _chunk_sz; + int _chunk_num; + + int _max_key_num; + int _key_min_len; + int _key_max_len; +}; + +extern bool Run_AC_Simple_Test(); +extern bool Run_AC_Aggressive_Test(const vector<const char*>& files); + +#endif diff --git a/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx new file mode 100644 index 0000000..f189d8d --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx @@ -0,0 +1,167 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +/////////////////////////////////////////////////////////////////////////// +// +// Implementation of BigFileTester +// +/////////////////////////////////////////////////////////////////////////// +// +BigFileTester::BigFileTester(const char* filepath) { + _filepath = filepath; + _fd = -1; + _msg = (char*)MAP_FAILED; + _msg_len = 0; + _key_num = 0; + _chunk_sz = 0; + _chunk_num = 0; + + _max_key_num = 100; + _key_min_len = 20; + _key_max_len = 80; +} + +void +BigFileTester::Cleanup() { + if (_msg != MAP_FAILED) { + munmap((void*)_msg, _msg_len); + _msg = (char*)MAP_FAILED; + _msg_len = 0; + } + + if (_fd != -1) { + close(_fd); + _fd = -1; + } +} + +bool +BigFileTester::GenerateKeys() { + int chunk_sz = 4096; + int max_key_num = _max_key_num; + int key_min_len = _key_min_len; + int key_max_len = _key_max_len; + + int t = _msg_len / chunk_sz; + int keynum = t > max_key_num ? max_key_num : t; + + if (keynum <= 4) { + // file is too small + return false; + } + chunk_sz = _msg_len / keynum; + _chunk_sz = chunk_sz; + + // For each chunck, "randomly" grab a sub-string searving + // as key. + int random_ofst[] = { 12, 30, 23, 15 }; + int rofstsz = sizeof(random_ofst)/sizeof(random_ofst[0]); + int ofst = 0; + const char* msg = _msg; + _chunk_num = keynum - 1; + for (int idx = 0, e = _chunk_num; idx < e; idx++) { + const char* key = msg + ofst + idx % rofstsz; + int key_len = key_min_len + idx % (key_max_len - key_min_len); + _keys.push_back(StrInfo(key, key_len)); + ofst += chunk_sz; + } + return true; +} + +bool +BigFileTester::Run() { + // Step 1: Bring the file into memory + fprintf(stdout, "Testing using file '%s'...\n", _filepath); + + int fd = _fd = ::open(_filepath, O_RDONLY); + if (fd == -1) { + perror("open"); + return false; + } + + struct stat sb; + if (fstat(fd, &sb) == -1) { + perror("fstat"); + return false; + } + + if (!S_ISREG (sb.st_mode)) { + fprintf(stderr, "%s is not regular file\n", _filepath); + return false; + } + + int ten_M = 1024 * 1024 * 10; + int map_sz = _msg_len = sb.st_size > ten_M ? ten_M : sb.st_size; + char* p = _msg = + (char*)mmap (0, map_sz, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); + if (p == MAP_FAILED) { + perror("mmap"); + return false; + } + + // Get rid of '\0' if we are picky at it. + if (Str_C_Style()) { + for (int i = 0; i < map_sz; i++) { if (!p[i]) p[i] = 'a'; } + p[map_sz - 1] = 0; + } + + // Step 2: "Fabricate" some keys from the file. + if (!GenerateKeys()) { + close(fd); + return false; + } + + // Step 3: Create PM instance + const char** keys = new const char*[_keys.size()]; + unsigned int* keylens = new unsigned int[_keys.size()]; + + int i = 0; + for (vector<StrInfo>::iterator si = _keys.begin(), se = _keys.end(); + si != se; si++, i++) { + const StrInfo& strinfo = *si; + keys[i] = strinfo.first; + keylens[i] = strinfo.second; + } + + buf_header_t* PM = PM_Create(keys, keylens, i); + delete[] keys; + delete[] keylens; + + // Step 4: Run testing + bool res = Run_Helper(PM); + PM_Free(PM); + + // Step 5: Clanup + munmap(p, map_sz); + _msg = (char*)MAP_FAILED; + close(fd); + _fd = -1; + + fprintf(stdout, "%s\n", res ? "succ" : "fail"); + return res; +} + +void +BigFileTester::PrintStr(FILE* f, const char* str, int len) { + fprintf(f, "{"); + for (int i = 0; i < len; i++) { + unsigned char c = str[i]; + if (isprint(c)) + fprintf(f, "'%c', ", c); + else + fprintf(f, "%#x, ", c); + } + fprintf(f, "}"); +}; diff --git a/modules/policy/lua-aho-corasick/tests/test_main.cxx b/modules/policy/lua-aho-corasick/tests/test_main.cxx new file mode 100644 index 0000000..b4f5225 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_main.cxx @@ -0,0 +1,33 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + + +///////////////////////////////////////////////////////////////////////// +// +// Simple (yet maybe tricky) testings +// +///////////////////////////////////////////////////////////////////////// +// +int +main (int argc, char** argv) { + bool succ = Run_AC_Simple_Test(); + + vector<const char*> files; + for (int i = 1; i < argc; i++) { files.push_back(argv[i]); } + succ = Run_AC_Aggressive_Test(files) && succ; + + return succ ? 0 : -1; +}; |