diff options
Diffstat (limited to '')
38 files changed, 4770 insertions, 0 deletions
diff --git a/modules/policy/README.rst b/modules/policy/README.rst new file mode 100644 index 0000000..a68ad4d --- /dev/null +++ b/modules/policy/README.rst @@ -0,0 +1,284 @@ +.. _mod-policy: + +Query policies +-------------- + +This module can block, rewrite, or alter inbound queries based on user-defined policies. + +Each policy *rule* has two parts: a *filter* and an *action*. A *filter* selects which queries will be affected by the policy, and *action* which modifies queries matching the associated filter. + +Typically a rule is defined as follows: ``filter(action(action parameters), filter parameters)``. For example, a filter can be ``suffix`` which matches queries whose suffix part is in specified set, and one of possible actions is ``DENY``, which denies resolution. These are combined together into ``policy.suffix(policy.DENY, {todname('badguy.example.')})``. The rule is effective when it is added into rule table using ``policy.add()``, please see `Policy examples`_. + +This module is enabled by default because it implements mandatory :rfc:`6761` logic. +When no rule applies to a query, built-in rules for `special-use <https://www.iana.org/assignments/special-use-domain-names/special-use-domain-names.xhtml>`_ and `locally-served <http://www.iana.org/assignments/locally-served-dns-zones>`_ domain names are applied. +These rules can be overriden by action ``PASS``, see `Policy examples`_ below. For debugging purposes you can also add ``modules.unload('policy')`` to your config to unload the module. + + +Filters +^^^^^^^ +A *filter* selects which queries will be affected by specified *action*. There are several policy filters available in the ``policy.`` table: + +* ``all(action)`` + - always applies the action +* ``pattern(action, pattern)`` + - applies the action if QNAME matches a `regular expression <http://lua-users.org/wiki/PatternsTutorial>`_ +* ``suffix(action, table)`` + - applies the action if QNAME suffix matches one of suffixes in the table (useful for "is domain in zone" rules), + uses `Aho-Corasick`_ string matching algorithm `from CloudFlare <https://github.com/cloudflare/lua-aho-corasick>`_ (BSD 3-clause) +* :any:`policy.suffix_common` +* ``rpz(default_action, path)`` + - implements a subset of RPZ_ in zonefile format. See below for details: :any:`policy.rpz`. +* custom filter function + +.. _mod-policy-actions: + +Actions +^^^^^^^ +An *action* is function which modifies DNS query, and is either of type *chain* or *non-chain*. So-called *chain* actions modify the query and allow other rules to evaluate and modify the same query. *Non-chain* actions have opposite behavior, i.e. modify the query and stop rule processing. + +Resolver comes with several actions available in the ``policy.`` table: + +**Non-chain actions** + +Following actions stop the policy matching on the query, i.e. other rules are not evaluated once rule with following actions matches: + +* ``PASS`` - let the query pass through; it's useful to make exceptions before wider rules +* ``DENY`` - reply NXDOMAIN authoritatively +* ``DENY_MSG(msg)`` - reply NXDOMAIN authoritatively and add explanatory message to additional section +* ``DROP`` - terminate query resolution and return SERVFAIL to the requestor +* ``REFUSE`` - terminate query resolution and return REFUSED to the requestor +* ``TC`` - set TC=1 if the request came through UDP, forcing client to retry with TCP +* ``FORWARD(ip)`` - resolve a query via forwarding to an IP while validating and caching locally; +* ``TLS_FORWARD({{ip, authentication}})`` - resolve a query via TLS connection forwarding to an IP while validating and caching locally; + the parameter can be a single IP (string) or a lua list of up to four IPs. +* ``STUB(ip)`` - similar to ``FORWARD(ip)`` but *without* attempting DNSSEC validation. + Each request may be either answered from cache or simply sent to one of the IPs with proxying back the answer. +* ``REROUTE({{subnet,target}, ...})`` - reroute addresses in response matching given subnet to given target, e.g. ``{'192.0.2.0/24', '127.0.0.0'}`` will rewrite '192.0.2.55' to '127.0.0.55', see :ref:`renumber module <mod-renumber>` for more information. + + +**Chain actions** + +Following actions allow to keep trying to match other rules, until a non-chain action is triggered: + +* ``MIRROR(ip)`` - mirror query to given IP and continue solving it (useful for partial snooping). +* ``QTRACE`` - pretty-print DNS response packets into the log for the query and its sub-queries. It's useful for debugging weird DNS servers. +* ``FLAGS(set, clear)`` - set and/or clear some flags for the query. There can be multiple flags to set/clear. You can just pass a single flag name (string) or a set of names. + + +Also, it is possible to write your own action (i.e. Lua function). It is possible to implement complex heuristics, e.g. to deflect `Slow drip DNS attacks <https://secure64.com/water-torture-slow-drip-dns-ddos-attack>`_ or gray-list resolution of misbehaving zones. + +.. warning:: The policy module currently only looks at whole DNS requests. The rules won't be re-applied e.g. when following CNAMEs. + +.. note:: The module (and ``kres``) expects domain names in wire format, not textual representation. So each label in name is prefixed with its length, e.g. "example.com" equals to ``"\7example\3com"``. You can use convenience function ``todname('example.com')`` for automatic conversion. + +Forwarding over TLS protocol (DNS-over-TLS) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Policy `TLS_FORWARD` allows you to forward queries using `Transport Layer Security`_ protocol, which hides the content of your queries from an attacker observing the network traffic. Further details about this protocol can be found in :rfc:`7858` and `IETF draft dprive-dtls-and-tls-profiles`_. + +Queries affected by `TLS_FORWARD` policy will always be resolved over TLS connection. Knot Resolver does not implement fallback to non-TLS connection, so if TLS connection cannot be established or authenticated according to the configuration, the resolution will fail. + +To test this feature you need to either :ref:`configure Knot Resolver as DNS-over-TLS server <tls-server-config>`, or pick some public DNS-over-TLS server. Please see `DNS Privacy Project`_ homepage for list of public servers. + +When multiple servers are specified, the one with the lowest round-trip time is used. + +CA+hostname authentication +~~~~~~~~~~~~~~~~~~~~~~~~~~ +Traditional PKI authentication requires server to present certificate with specified hostname, which is issued by one of trusted CAs. Example policy is: + +.. code-block:: lua + + policy.TLS_FORWARD({ + {'2001:DB8::d0c', hostname='res.example.com'}}) + +- `hostname` must exactly match hostname in server's certificate, i.e. in most cases it must not contain trailing dot (`res.example.com`). +- System CA certificate store will be used if no `ca_file` option is specified. +- Optional `ca_file` option can specify path to CA certificate (or certificate bundle) in `PEM format`_. + +TLS Examples +~~~~~~~~~~~~ + +.. code-block:: lua + + modules = { 'policy' } + -- forward all queries over TLS to the specified server + policy.add(policy.all(policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}}))) + -- for brevity, other TLS examples omit policy.add(policy.all()) + -- single server authenticated using its certificate pin_sha256 + policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}}) -- pin_sha256 is base64-encoded + -- single server authenticated using hostname and system-wide CA certificates + policy.TLS_FORWARD({{'192.0.2.1', hostname='res.example.com'}}) + -- single server using non-standard port + policy.TLS_FORWARD({{'192.0.2.1@443', pin_sha256='YQ=='}}) -- use @ or # to specify port + -- single server with multiple valid pins (e.g. anycast) + policy.TLS_FORWARD({{'192.0.2.1', pin_sha256={'YQ==', 'Wg=='}}) + -- multiple servers, each with own authenticator + policy.TLS_FORWARD({ -- please note that { here starts list of servers + {'192.0.2.1', pin_sha256='Wg=='}, + -- server must present certificate issued by specified CA and hostname must match + {'2001:DB8::d0c', hostname='res.example.com', ca_file='/etc/knot-resolver/tlsca.crt'} + }) + +.. _policy_examples: + +Policy examples +^^^^^^^^^^^^^^^ + +.. code-block:: lua + + -- Whitelist 'www[0-9].badboy.cz' + policy.add(policy.pattern(policy.PASS, '\4www[0-9]\6badboy\2cz')) + -- Block all names below badboy.cz + policy.add(policy.suffix(policy.DENY, {todname('badboy.cz.')})) + + -- Custom rule + local ffi = require('ffi') + local function genRR (state, req) + local answer = req.answer + local qry = req:current() + if qry.stype ~= kres.type.A then + return state + end + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\192\168\1\3') + return kres.DONE + end + policy.add(policy.suffix(genRR, { todname('my.example.cz.') })) + + -- Disallow ANY queries + policy.add(function (req, query) + if query.stype == kres.type.ANY then + return policy.DROP + end + end) + -- Enforce local RPZ + policy.add(policy.rpz(policy.DENY, 'blacklist.rpz')) + -- Forward all queries below 'company.se' to given resolver; + -- beware: typically this won't work due to DNSSEC - see "Replacing part..." below + policy.add(policy.suffix(policy.FORWARD('192.168.1.1'), {todname('company.se')})) + -- Forward reverse queries about the 192.168.1.1/24 space to .1 port 5353 + -- and do it directly without attempts to validate DNSSEC etc. + policy.add(policy.suffix(policy.STUB('192.168.1.1@5353'), {todname('1.168.192.in-addr.arpa')})) + -- Forward all queries matching pattern + policy.add(policy.pattern(policy.FORWARD('2001:DB8::1'), '\4bad[0-9]\2cz')) + -- Forward all queries (to public resolvers https://www.nic.cz/odvr) + policy.add(policy.all(policy.FORWARD({'2001:678:1::206', '193.29.206.206'}))) + -- Print all responses with matching suffix + policy.add(policy.suffix(policy.QTRACE, {todname('rhybar.cz.')})) + -- Print all responses + policy.add(policy.all(policy.QTRACE)) + -- Mirror all queries and retrieve information + local rule = policy.add(policy.all(policy.MIRROR('127.0.0.2'))) + -- Print information about the rule + print(string.format('id: %d, matched queries: %d', rule.id, rule.count) + -- Reroute all addresses found in answer from 192.0.2.0/24 to 127.0.0.x + -- this policy is enforced on answers, therefore 'postrule' + local rule = policy.add(policy.REROUTE({'192.0.2.0/24', '127.0.0.0'}), true) + -- Delete rule that we just created + policy.del(rule.id) + + +Replacing part of the DNS tree +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You may want to resolve most of the DNS namespace by usual means while letting some other resolver solve specific subtrees. +Such data would typically be rejected by DNSSEC validation starting from the ICANN root keys. Therefore, if you trust the resolver and your link to it, you can simply use the ``STUB`` action instead of ``FORWARD`` to avoid validation only for those subtrees. + +Another issue is caused by caching, because Knot Resolver only keeps a single cache for everything. +For example, if you add an alternative top-level domain while using the ICANN root zone for the rest, at some point the cache may obtain records proving that your top-level domain does not exist, and those records could then be used when the positive records fall out of cache. The easiest work-around is to disable reading from cache for those subtrees; the other resolver is often very close anyway. + + +.. code-block:: lua + :caption: Example configuration: graft DNS sub-trees ``faketldtest``, ``sld.example``, and ``internal.example.com`` into existing namespace + + extraTrees = policy.todnames({'faketldtest', 'sld.example', 'internal.example.com'}) + -- Beware: the rule order is important, as STUB is not a chain action. + policy.add(policy.suffix(policy.FLAGS({'NO_CACHE'}), extraTrees)) + policy.add(policy.suffix(policy.STUB({'2001:db8::1'}), extraTrees)) + + +Additional properties +^^^^^^^^^^^^^^^^^^^^^ + +Most properties (actions, filters) are described above. + +.. function:: policy.add(rule, postrule) + + :param rule: added rule, i.e. ``policy.pattern(policy.DENY, '[0-9]+\2cz')`` + :param postrule: boolean, if true the rule will be evaluated on answer instead of query + :return: rule description + + Add a new policy rule that is executed either or queries or answers, depending on the ``postrule`` parameter. You can then use the returned rule description to get information and unique identifier for the rule, as well as match count. + +.. function:: policy.del(id) + + :param id: identifier of a given rule + :return: boolean + + Remove a rule from policy list. + +.. function:: policy.suffix_common(action, suffix_table[, common_suffix]) + + :param action: action if the pattern matches QNAME + :param suffix_table: table of valid suffixes + :param common_suffix: common suffix of entries in suffix_table + + Like suffix match, but you can also provide a common suffix of all matches for faster processing (nil otherwise). + This function is faster for small suffix tables (in the order of "hundreds"). + +.. function:: policy.rpz(action, path) + + :param action: the default action for match in the zone; typically you want ``policy.DENY`` + :param path: path to zone file | database + + Enforce RPZ_ rules. This can be used in conjunction with published blocklist feeds. + The RPZ_ operation is well described in this `Jan-Piet Mens's post`_, + or the `Pro DNS and BIND`_ book. Here's compatibility table: + + .. csv-table:: + :header: "Policy Action", "RH Value", "Support" + + "``action`` is used", "``.``", "**yes**, if ``action`` is ``DENY``" + "``action`` is used ", "``*.``", "*partial* [#]_" + "``policy.PASS``", "``rpz-passthru.``", "**yes**" + "``policy.DROP``", "``rpz-drop.``", "**yes**" + "``policy.TC``", "``rpz-tcp-only.``", "**yes**" + "Modified", "anything", "no" + + .. [#] The specification for ``*.`` wants a ``NODATA`` answer. + For now, ``policy.DENY`` action doing ``NXDOMAIN`` is typically used instead. + + .. csv-table:: + :header: "Policy Trigger", "Support" + + "QNAME", "**yes**" + "CLIENT-IP", "*partial*, may be done with :ref:`views <mod-view>`" + "IP", "no" + "NSDNAME", "no" + "NS-IP", "no" + +.. function:: policy.todnames({name, ...}) + + :param: names table of domain names in textual format + + Returns table of domain names in wire format converted from strings. + + .. code-block:: lua + + -- Convert single name + assert(todname('example.com') == '\7example\3com\0') + -- Convert table of names + policy.todnames({'example.com', 'me.cz'}) + { '\7example\3com\0', '\2me\2cz\0' } + + +.. _`Aho-Corasick`: https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_string_matching_algorithm +.. _`@jgrahamc`: https://github.com/jgrahamc/aho-corasick-lua +.. _RPZ: https://dnsrpz.info/ +.. _`PEM format`: https://en.wikipedia.org/wiki/Privacy-enhanced_Electronic_Mail +.. _`Pro DNS and BIND`: http://www.zytrax.com/books/dns/ch7/rpz.html +.. _`Jan-Piet Mens's post`: http://jpmens.net/2011/04/26/how-to-configure-your-bind-resolvers-to-lie-using-response-policy-zones-rpz/ +.. _`Transport Layer Security`: https://en.wikipedia.org/wiki/Transport_Layer_Security +.. _`DNS Privacy Project`: https://dnsprivacy.org/ +.. _`IETF draft dprive-dtls-and-tls-profiles`: https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles diff --git a/modules/policy/lua-aho-corasick/.gitignore b/modules/policy/lua-aho-corasick/.gitignore new file mode 100644 index 0000000..04fd221 --- /dev/null +++ b/modules/policy/lua-aho-corasick/.gitignore @@ -0,0 +1,6 @@ +*.d +*.o +*.a +*.so +*_dep.txt +tests/testinput diff --git a/modules/policy/lua-aho-corasick/LICENSE b/modules/policy/lua-aho-corasick/LICENSE new file mode 100644 index 0000000..dd65f72 --- /dev/null +++ b/modules/policy/lua-aho-corasick/LICENSE @@ -0,0 +1,28 @@ + Copyright (c) 2014 CloudFlare, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of CloudFlare, Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/modules/policy/lua-aho-corasick/Makefile b/modules/policy/lua-aho-corasick/Makefile new file mode 100644 index 0000000..6471664 --- /dev/null +++ b/modules/policy/lua-aho-corasick/Makefile @@ -0,0 +1,134 @@ +OS := $(shell uname) + +ifeq ($(OS), Darwin) + SO_EXT := dylib +else + SO_EXT := so +endif + +############################################################################# +# +# Binaries we are going to build +# +############################################################################# +# +C_SO_NAME = libac.$(SO_EXT) +LUA_SO_NAME = ahocorasick.$(SO_EXT) +AR_NAME = libac.a + +############################################################################# +# +# Compile and link flags +# +############################################################################# +PREFIX ?= /usr/local +LUA_VERSION := 5.1 +LUA_INCLUDE_DIR := $(PREFIX)/include/lua$(LUA_VERSION) +SO_TARGET_DIR := $(PREFIX)/lib/lua/$(LUA_VERSION) +LUA_TARGET_DIR := $(PREFIX)/share/lua/$(LUA_VERSION) + +# Available directives: +# -DDEBUG : Turn on debugging support +# -DVERIFY : To verify if the slow-version and fast-version implementations +# get exactly the same result. Note -DVERIFY implies -DDEBUG. +# +COMMON_FLAGS = -O3 #-g -DVERIFY -msse2 -msse3 -msse4.1 +COMMON_FLAGS += -fvisibility=hidden -Wall $(CXXFLAGS) $(MY_CXXFLAGS) $(CPPFLAGS) + +SO_CXXFLAGS = $(COMMON_FLAGS) -fPIC +SO_LFLAGS = $(COMMON_FLAGS) $(LDFLAGS) +AR_CXXFLAGS = $(COMMON_FLAGS) + +# -DVERIFY implies -DDEBUG +ifneq ($(findstring -DVERIFY, $(COMMON_FLAGS)), ) +ifeq ($(findstring -DDEBUG, $(COMMON_FLAGS)), ) + COMMON_FLAGS += -DDEBUG +endif +endif + +AR = ar +AR_FLAGS = cru + +############################################################################# +# +# Divide source codes and objects into several categories +# +############################################################################# +# +SRC_COMMON := ac_fast.cxx ac_slow.cxx +LIBAC_SO_SRC := $(SRC_COMMON) ac.cxx # source for libac.so +LUA_SO_SRC := $(SRC_COMMON) ac_lua.cxx # source for ahocorasick.so +LIBAC_A_SRC := $(LIBAC_SO_SRC) # source for libac.a + +############################################################################# +# +# Make rules +# +############################################################################# +# +.PHONY = all clean test benchmark prepare +all : $(C_SO_NAME) $(LUA_SO_NAME) $(AR_NAME) + +-include c_so_dep.txt +-include lua_so_dep.txt +-include ar_dep.txt + +BUILD_SO_DIR := build_so +BUILD_AR_DIR := build_ar + +$(BUILD_SO_DIR) :; mkdir $@ +$(BUILD_AR_DIR) :; mkdir $@ + +$(BUILD_SO_DIR)/%.o : %.cxx | $(BUILD_SO_DIR) + $(CXX) $< -c $(SO_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ + +$(BUILD_AR_DIR)/%.o : %.cxx | $(BUILD_AR_DIR) + $(CXX) $< -c $(AR_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ + +ifneq ($(OS), Darwin) +$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared -Wl,-soname=$(C_SO_NAME) $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt + +$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared -Wl,-soname=$(LUA_SO_NAME) $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt + +else +$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt + +$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared $(SO_LFLAGS) -o $@ -Wl,-undefined,dynamic_lookup + cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt +endif + +$(AR_NAME) : $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.o}) + $(AR) $(AR_FLAGS) $@ $+ + cat $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.d}) > lua_so_dep.txt + +############################################################################# +# +# Misc +# +############################################################################# +# +test : $(C_SO_NAME) + $(MAKE) -C tests && \ + luajit tests/lua_test.lua && \ + luajit tests/load_ac_test.lua + +benchmark: $(C_SO_NAME) + $(MAKE) benchmark -C tests + +clean : + -rm -rf *.o *.d c_so_dep.txt lua_so_dep.txt ar_dep.txt $(TEST) \ + $(C_SO_NAME) $(LUA_SO_NAME) $(TEST) $(BUILD_SO_DIR) $(BUILD_AR_DIR) \ + $(AR_NAME) + make clean -C tests + +install: + install -D -m 755 $(C_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(C_SO_NAME) + install -D -m 755 $(LUA_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(LUA_SO_NAME) + install -D -m 664 load_ac.lua $(DESTDIR)/$(LUA_TARGET_DIR)/load_ac.lua diff --git a/modules/policy/lua-aho-corasick/README.md b/modules/policy/lua-aho-corasick/README.md new file mode 100644 index 0000000..b5cc406 --- /dev/null +++ b/modules/policy/lua-aho-corasick/README.md @@ -0,0 +1,40 @@ +aho-corasick-lua +================ + +C++ and Lua Implementation of the Aho-Corasick (AC) string matching algorithm +(http://dl.acm.org/citation.cfm?id=360855). + +We began with pure Lua implementation and realize the performance is not +satisfactory. So we switch to C/C++ implementation. + +There are two shared objects provied by this package: libac.so and ahocorasick.so +The former is a regular shared object which can be directly used by C/C++ +application, or by Lua via FFI; and the later is a Lua module. An example usage +is shown bellow: + +```lua +local ac = require "ahocorasick" +local dict = {"string1", "string", "etc"} +local acinst = ac.create(dict) +local r = ac.match(acinst, "mystring") +``` + +For efficiency reasons, the implementation is slightly different from the +standard AC algorithm in that it doesn't return a set of strings in the dictionary +that match the given string, instead it only returns one of them in case the string +matches. The functionality of our implementation can be (precisely) described by +following pseudo-c snippet. + +```C +string foo(input-string, dictionary) { + string ret = the-end-of-input-string; + for each string s in dictionary { + // find the first occurrence match sub-string. + ret = min(ret, strstr(input-string, s); + } + return ret; +} +``` + +It's pretty easy to get rid of this limitation, just to associate each state with +a spare bit-vector dipicting the set of strings recognized by that state. diff --git a/modules/policy/lua-aho-corasick/ac.cxx b/modules/policy/lua-aho-corasick/ac.cxx new file mode 100644 index 0000000..23fb3b5 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac.cxx @@ -0,0 +1,101 @@ +// Interface functions for libac.so +// +#include "ac_slow.hpp" +#include "ac_fast.hpp" +#include "ac.h" + +static inline ac_result_t +_match(buf_header_t* ac, const char* str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(ac->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match(buf, str, len); + + #ifdef VERIFY + { + Match_Result r2 = buf->slow_impl->Match(str, len); + if (r.match_begin != r2.begin) { + ASSERT(0); + } else { + ASSERT((r.match_begin < 0) || + (r.match_end == r2.end && + r.pattern_idx == r2.pattern_idx)); + } + } + #endif + return r; +} + +extern "C" int +ac_match2(ac_t* ac, const char* str, unsigned int len) { + ac_result_t r = _match((buf_header_t*)(void*)ac, str, len); + return r.match_begin; +} + +extern "C" ac_result_t +ac_match(ac_t* ac, const char* str, unsigned int len) { + return _match((buf_header_t*)(void*)ac, str, len); +} + +extern "C" ac_result_t +ac_match_longest_l(ac_t* ac, const char* str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(((buf_header_t*)ac)->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match_Longest_L(buf, str, len); + return r; +} + +class BufAlloc : public Buf_Allocator { +public: + virtual AC_Buffer* alloc(int sz) { + return (AC_Buffer*)(new unsigned char[sz]); + } + + // Do not de-allocate the buffer when the BufAlloc die. + virtual void free() {} + + static void myfree(AC_Buffer* buf) { + ASSERT(buf->hdr.magic_num == AC_MAGIC_NUM); + const char* b = (const char*)buf; + delete[] b; + } +}; + +extern "C" ac_t* +ac_create(const char** strv, unsigned int* strlenv, unsigned int v_len) { + if (v_len >= 65535) { + // TODO: Currently we use 16-bit to encode pattern-index (see the + // comment to AC_State::is_term), therefore we are not able to + // handle pattern set with more than 65535 entries. + return 0; + } + + ACS_Constructor *acc; +#ifdef VERIFY + acc = new ACS_Constructor; +#else + ACS_Constructor tmp; + acc = &tmp; +#endif + acc->Construct(strv, strlenv, v_len); + + BufAlloc ba; + AC_Converter cvt(*acc, ba); + AC_Buffer* buf = cvt.Convert(); + +#ifdef VERIFY + buf->slow_impl = acc; +#endif + return (ac_t*)(void*)buf; +} + +extern "C" void +ac_free(void* ac) { + AC_Buffer* buf = (AC_Buffer*)ac; +#ifdef VERIFY + delete buf->slow_impl; +#endif + + BufAlloc::myfree(buf); +} diff --git a/modules/policy/lua-aho-corasick/ac.h b/modules/policy/lua-aho-corasick/ac.h new file mode 100644 index 0000000..30bf447 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac.h @@ -0,0 +1,49 @@ +#ifndef AC_H +#define AC_H +#ifdef __cplusplus +extern "C" { +#endif + +#define AC_EXPORT __attribute__ ((visibility ("default"))) + +/* If the subject-string dosen't match any of the given patterns, "match_begin" + * should be a negative; otherwise the substring of the subject-string, + * starting from offset "match_begin" to "match_end" incusively, + * should exactly match the pattern specified by the 'pattern_idx' (i.e. + * the pattern is "pattern_v[pattern_idx]" where the "pattern_v" is the + * first acutal argument passing to ac_create()) + */ +typedef struct { + int match_begin; + int match_end; + int pattern_idx; +} ac_result_t; + +struct ac_t; + +/* Create an AC instance. "pattern_v" is a vector of patterns, the length of + * i-th pattern is specified by "pattern_len_v[i]"; the number of patterns + * is specified by "vect_len". + * + * Return the instance on success, or NUL otherwise. + */ +ac_t* ac_create(const char** pattern_v, unsigned int* pattern_len_v, + unsigned int vect_len) AC_EXPORT; + +ac_result_t ac_match(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +ac_result_t ac_match_longest_l(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +/* Similar to ac_match() except that it only returns match-begin. The rationale + * for this interface is that luajit has hard time in dealing with strcture- + * return-value. + */ +int ac_match2(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +void ac_free(void*) AC_EXPORT; + +#ifdef __cplusplus +} +#endif + +#endif /* AC_H */ diff --git a/modules/policy/lua-aho-corasick/ac_fast.cxx b/modules/policy/lua-aho-corasick/ac_fast.cxx new file mode 100644 index 0000000..9dbc2e6 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_fast.cxx @@ -0,0 +1,468 @@ +#include <algorithm> // for std::sort +#include "ac_slow.hpp" +#include "ac_fast.hpp" + +uint32 +AC_Converter::Calc_State_Sz(const ACS_State* s) const { + AC_State dummy; + uint32 sz = offsetof(AC_State, input_vect); + sz += s->Get_GotoNum() * sizeof(dummy.input_vect[0]); + + if (sz < sizeof(AC_State)) + sz = sizeof(AC_State); + + uint32 align = __alignof__(dummy); + sz = (sz + align - 1) & ~(align - 1); + return sz; +} + +AC_Buffer* +AC_Converter::Alloc_Buffer() { + const vector<ACS_State*>& all_states = _acs.Get_All_States(); + const ACS_State* root_state = _acs.Get_Root_State(); + uint32 root_fanout = root_state->Get_GotoNum(); + + // Step 1: Calculate the buffer size + AC_Ofst root_goto_ofst, states_ofst_ofst, first_state_ofst; + + // part 1 : buffer header + uint32 sz = root_goto_ofst = sizeof(AC_Buffer); + + // part 2: Root-node's goto function + if (likely(root_fanout != 255)) + sz += 256; + else + root_goto_ofst = 0; + + // part 3: mapping of state's relative position. + unsigned align = __alignof__(AC_Ofst); + sz = (sz + align - 1) & ~(align - 1); + states_ofst_ofst = sz; + + sz += sizeof(AC_Ofst) * all_states.size(); + + // part 4: state's contents + align = __alignof__(AC_State); + sz = (sz + align - 1) & ~(align - 1); + first_state_ofst = sz; + + uint32 state_sz = 0; + for (vector<ACS_State*>::const_iterator i = all_states.begin(), + e = all_states.end(); i != e; i++) { + state_sz += Calc_State_Sz(*i); + } + state_sz -= Calc_State_Sz(root_state); + + sz += state_sz; + + // Step 2: Allocate buffer, and populate header. + AC_Buffer* buf = _buf_alloc.alloc(sz); + + buf->hdr.magic_num = AC_MAGIC_NUM; + buf->hdr.impl_variant = IMPL_FAST_VARIANT; + buf->buf_len = sz; + buf->root_goto_ofst = root_goto_ofst; + buf->states_ofst_ofst = states_ofst_ofst; + buf->first_state_ofst = first_state_ofst; + buf->root_goto_num = root_fanout; + buf->state_num = _acs.Get_State_Num(); + return buf; +} + +void +AC_Converter::Populate_Root_Goto_Func(AC_Buffer* buf, + GotoVect& goto_vect) { + unsigned char *buf_base = (unsigned char*)(buf); + InputTy* root_gotos = (InputTy*)(buf_base + buf->root_goto_ofst); + const ACS_State* root_state = _acs.Get_Root_State(); + + root_state->Get_Sorted_Gotos(goto_vect); + + // Renumber the ID of root-node's immediate kids. + uint32 new_id = 1; + bool full_fantout = (goto_vect.size() == 255); + if (likely(!full_fantout)) + bzero(root_gotos, 256*sizeof(InputTy)); + + for (GotoVect::iterator i = goto_vect.begin(), e = goto_vect.end(); + i != e; i++, new_id++) { + InputTy c = i->first; + ACS_State* s = i->second; + _id_map[s->Get_ID()] = new_id; + + if (likely(!full_fantout)) + root_gotos[c] = new_id; + } +} + +AC_Buffer* +AC_Converter::Convert() { + // Step 1: Some preparation stuff. + GotoVect gotovect; + + _id_map.clear(); + _ofst_map.clear(); + _id_map.resize(_acs.Get_Next_Node_Id()); + _ofst_map.resize(_acs.Get_Next_Node_Id()); + + // Step 2: allocate buffer to accommodate the entire AC graph. + AC_Buffer* buf = Alloc_Buffer(); + unsigned char* buf_base = (unsigned char*)buf; + + // Step 3: Root node need special care. + Populate_Root_Goto_Func(buf, gotovect); + buf->root_goto_num = gotovect.size(); + _id_map[_acs.Get_Root_State()->Get_ID()] = 0; + + // Step 4: Converting the remaining states by BFSing the graph. + // First of all, enter root's immediate kids to the working list. + vector<const ACS_State*> wl; + State_ID id = 1; + for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); + i != e; i++, id++) { + ACS_State* s = i->second; + wl.push_back(s); + _id_map[s->Get_ID()] = id; + } + + AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); + AC_Ofst ofst = buf->first_state_ofst; + for (uint32 idx = 0; idx < wl.size(); idx++) { + const ACS_State* old_s = wl[idx]; + AC_State* new_s = (AC_State*)(buf_base + ofst); + + // This property should hold as we: + // - States are appended to worklist in the BFS order. + // - sibiling states are appended to worklist in the order of their + // corresponding input. + // + State_ID state_id = idx + 1; + ASSERT(_id_map[old_s->Get_ID()] == state_id); + + state_ofst_vect[state_id] = ofst; + + new_s->first_kid = wl.size() + 1; + new_s->depth = old_s->Get_Depth(); + new_s->is_term = old_s->is_Terminal() ? + old_s->get_Pattern_Idx() + 1 : 0; + + uint32 gotonum = old_s->Get_GotoNum(); + new_s->goto_num = gotonum; + + // Populate the "input" field + old_s->Get_Sorted_Gotos(gotovect); + uint32 input_idx = 0; + uint32 id = wl.size() + 1; + InputTy* input_vect = new_s->input_vect; + for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); + i != e; i++, id++, input_idx++) { + input_vect[input_idx] = i->first; + + ACS_State* kid = i->second; + _id_map[kid->Get_ID()] = id; + wl.push_back(kid); + } + + _ofst_map[old_s->Get_ID()] = ofst; + ofst += Calc_State_Sz(old_s); + } + + // This assertion might be useful to catch buffer overflow + ASSERT(ofst == buf->buf_len); + + // Populate the fail-link field. + for (vector<const ACS_State*>::iterator i = wl.begin(), e = wl.end(); + i != e; i++) { + const ACS_State* slow_s = *i; + State_ID fast_s_id = _id_map[slow_s->Get_ID()]; + AC_State* fast_s = (AC_State*)(buf_base + state_ofst_vect[fast_s_id]); + if (const ACS_State* fl = slow_s->Get_FailLink()) { + State_ID id = _id_map[fl->Get_ID()]; + fast_s->fail_link = id; + } else + fast_s->fail_link = 0; + } +#ifdef DEBUG + //dump_buffer(buf, stderr); +#endif + return buf; +} + +static inline AC_State* +Get_State_Addr(unsigned char* buf_base, AC_Ofst* StateOfstVect, uint32 state_id) { + ASSERT(state_id != 0 && "root node is handled in speical way"); + ASSERT(state_id < ((AC_Buffer*)buf_base)->state_num); + return (AC_State*)(buf_base + StateOfstVect[state_id]); +} + +// The performance of the binary search is critical to this work. +// +// Here we provide two versions of binary-search functions. +// The non-pristine version seems to consistently out-perform "pristine" one on +// bunch of benchmarks we tested. With the benchmark under tests/testinput/ +// +// The speedup is following on my laptop (core i7, ubuntu): +// +// benchmark was is +// ---------------------------------------- +// image.bin 2.3s 2.0s +// test.tar 6.7s 5.7s +// +// NOTE: As of I write this comment, we only measure the performance on about +// 10+ benchmarks. It's still too early to say which one works better. +// +#if !defined(BS_MULTI_VER) +static bool __attribute__((always_inline)) inline +Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { + if (vect_len <= 8) { + for (int i = 0; i < vect_len; i++) { + if (input_vect[i] == input) { + idx = i; + return true; + } + } + return false; + } + + // The "low" and "high" must be signed integers, as they could become -1. + // Also since they are signed integer, "(low + high)/2" is sightly more + // expensive than (low+high)>>1 or ((unsigned)(low + high))/2. + // + int low = 0, high = vect_len - 1; + while (low <= high) { + int mid = (low + high) >> 1; + InputTy mid_c = input_vect[mid]; + + if (input < mid_c) + high = mid - 1; + else if (input > mid_c) + low = mid + 1; + else { + idx = mid; + return true; + } + } + return false; +} + +#else + +/* Let us call this version "pristine" version. */ +static inline bool +Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { + int low = 0, high = vect_len - 1; + while (low <= high) { + int mid = (low + high) >> 1; + InputTy mid_c = input_vect[mid]; + + if (input < mid_c) + high = mid - 1; + else if (input > mid_c) + low = mid + 1; + else { + idx = mid; + return true; + } + } + return false; +} +#endif + +typedef enum { + // Look for the first match. e.g. pattern set = {"ab", "abc", "def"}, + // subject string "ababcdef". The first match would be "ab" at the + // beginning of the subject string. + MV_FIRST_MATCH, + + // Look for the left-most longest match. Follow above example; there are + // two longest matches, "abc" and "def", and the left-most longest match + // is "abc". + MV_LEFT_LONGEST, + + // Similar to the left-most longest match, except that it returns the + // *right* most longest match. Follow above example, the match would + // be "def". NYI. + MV_RIGHT_LONGEST, + + // Return all patterns that match that given subject string. NYI. + MV_ALL_MATCHES, +} MATCH_VARIANT; + +/* The Match_Tmpl is the template for vairants MV_FIRST_MATCH, MV_LEFT_LONGEST, + * MV_RIGHT_LONGEST (If we really really need MV_RIGHT_LONGEST variant, we are + * better off implementing it in a seprate function). + * + * The Match_Tmpl supports three variants at once "symbolically", once it's + * instanced to a particular variants, all the code irrelevant to the variants + * will be statically removed. So don't worry about the code like + * "if (variant == MV_XXXX)"; they will not incur any penalty. + * + * The drawback of using template is increased code size. Unfortunately, there + * is no silver bullet. + */ +template<MATCH_VARIANT variant> static ac_result_t +Match_Tmpl(AC_Buffer* buf, const char* str, uint32 len) { + unsigned char* buf_base = (unsigned char*)(buf); + unsigned char* root_goto = buf_base + buf->root_goto_ofst; + AC_Ofst* states_ofst_vect = (AC_Ofst* )(buf_base + buf->states_ofst_ofst); + + AC_State* state = 0; + uint32 idx = 0; + + // Skip leading chars that are not valid input of root-nodes. + if (likely(buf->root_goto_num != 255)) { + while(idx < len) { + unsigned char c = str[idx++]; + if (unsigned char kid_id = root_goto[c]) { + state = Get_State_Addr(buf_base, states_ofst_vect, kid_id); + break; + } + } + } else { + idx = 1; + state = Get_State_Addr(buf_base, states_ofst_vect, *str); + } + + ac_result_t r = {-1, -1}; + if (likely(state != 0)) { + if (unlikely(state->is_term)) { + /* Dictionary may have string of length 1 */ + r.match_begin = idx - state->depth; + r.match_end = idx - 1; + r.pattern_idx = state->is_term - 1; + + if (variant == MV_FIRST_MATCH) { + return r; + } + } + } + + while (idx < len) { + unsigned char c = str[idx]; + int res; + bool found; + found = Binary_Search_Input(state->input_vect, state->goto_num, c, res); + if (found) { + // The "t = goto(c, current_state)" is valid, advance to state "t". + uint32 kid = state->first_kid + res; + state = Get_State_Addr(buf_base, states_ofst_vect, kid); + idx++; + } else { + // Follow the fail-link. + State_ID fl = state->fail_link; + if (fl == 0) { + // fail-link is root-node, which implies the root-node dosen't + // have 255 valid transitions (otherwise, the fail-link should + // points to "goto(root, c)"), so we don't need speical handling + // as we did before this while-loop is entered. + // + while(idx < len) { + InputTy c = str[idx++]; + if (unsigned char kid_id = root_goto[c]) { + state = + Get_State_Addr(buf_base, states_ofst_vect, kid_id); + break; + } + } + } else { + state = Get_State_Addr(buf_base, states_ofst_vect, fl); + } + } + + // Check to see if the state is terminal state? + if (state->is_term) { + if (variant == MV_FIRST_MATCH) { + ac_result_t r; + r.match_begin = idx - state->depth; + r.match_end = idx - 1; + r.pattern_idx = state->is_term - 1; + return r; + } + + if (variant == MV_LEFT_LONGEST) { + int match_begin = idx - state->depth; + int match_end = idx - 1; + + if (r.match_begin == -1 || + match_end - match_begin > r.match_end - r.match_begin) { + r.match_begin = match_begin; + r.match_end = match_end; + r.pattern_idx = state->is_term - 1; + } + continue; + } + + ASSERT(false && "NYI"); + } + } + + return r; +} + +ac_result_t +Match(AC_Buffer* buf, const char* str, uint32 len) { + return Match_Tmpl<MV_FIRST_MATCH>(buf, str, len); +} + +ac_result_t +Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len) { + return Match_Tmpl<MV_LEFT_LONGEST>(buf, str, len); +} + +#ifdef DEBUG +void +AC_Converter::dump_buffer(AC_Buffer* buf, FILE* f) { + vector<AC_Ofst> state_ofst; + state_ofst.resize(_id_map.size()); + + fprintf(f, "Id maps between old/slow and new/fast graphs\n"); + int old_id = 0; + for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end(); + i != e; i++, old_id++) { + State_ID new_id = *i; + if (new_id != 0) { + fprintf(f, "%d -> %d, ", old_id, new_id); + } + } + fprintf(f, "\n"); + + int idx = 0; + for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end(); + i != e; i++, idx++) { + uint32 id = *i; + if (id == 0) continue; + state_ofst[id] = _ofst_map[idx]; + } + + unsigned char* buf_base = (unsigned char*)buf; + + // dump root goto-function. + fprintf(f, "root, fanout:%d goto {", buf->root_goto_num); + if (buf->root_goto_num != 255) { + unsigned char* root_goto = buf_base + buf->root_goto_ofst; + for (uint32 i = 0; i < 255; i++) { + if (root_goto[i] != 0) + fprintf(f, "%c->S:%d, ", (unsigned char)i, root_goto[i]); + } + } else { + fprintf(f, "full fanout\n"); + } + fprintf(f, "}\n"); + + // dump remaining states. + AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); + for (uint32 i = 1, e = buf->state_num; i < e; i++) { + AC_Ofst ofst = state_ofst_vect[i]; + ASSERT(ofst == state_ofst[i]); + fprintf(f, "S:%d, ofst:%d, goto={", i, ofst); + + AC_State* s = (AC_State*)(buf_base + ofst); + State_ID kid = s->first_kid; + for (uint32 k = 0, ke = s->goto_num; k < ke; k++, kid++) + fprintf(f, "%c->S:%d, ", s->input_vect[k], kid); + + fprintf(f, "}, fail-link = S:%d, %s\n", s->fail_link, + s->is_term ? "terminal" : ""); + } +} +#endif diff --git a/modules/policy/lua-aho-corasick/ac_fast.hpp b/modules/policy/lua-aho-corasick/ac_fast.hpp new file mode 100644 index 0000000..9ac557c --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_fast.hpp @@ -0,0 +1,124 @@ +#ifndef AC_FAST_H +#define AC_FAST_H + +#include <vector> +#include "ac.h" +#include "ac_slow.hpp" + +using namespace std; + +class ACS_Constructor; + +typedef uint32 AC_Ofst; +typedef uint32 State_ID; + +// The entire "fast" AC graph is converted from its "slow" version, and store +// in an consecutive trunk of memory or "buffer". Since the pointers in the +// fast AC graph are represented as offset relative to the base address of +// the buffer, this fast AC graph is position-independent, meaning cloning +// the fast graph is just to memcpy the entire buffer. +// +// The buffer is laid-out as following: +// +// 1. The buffer header. (i.e. the AC_Buffer content) +// 2. root-node's goto functions. It is represented as an array indiced by +// root-node's valid inputs, and the element is the ID of the corresponding +// transition state (aka kid). To save space, we used 8-bit to represent +// the IDs. ID of root's kids starts with 1. +// +// Root may have 255 valid inputs. In this speical case, i-th element +// stores value i -- i.e the i-th state. So, we don't need such array +// at all. On the other hand, 8-bit is insufficient to encode kids' ID. +// +// 3. An array indiced by state's id, and the element is the offset +// of correspoding state wrt the base address of the buffer. +// +// 4. the contents of states. +// +typedef struct { + buf_header_t hdr; // The header exposed to the user using this lib. +#ifdef VERIFY + ACS_Constructor* slow_impl; +#endif + uint32 buf_len; + AC_Ofst root_goto_ofst; // addr of root node's goto() function. + AC_Ofst states_ofst_ofst; // addr of state pointer vector (indiced by id) + AC_Ofst first_state_ofst; // addr of the first state in the buffer. + uint16 root_goto_num; // fan-out of root-node. + uint16 state_num; // number of states + + // Followed by the gut of the buffer: + // 1. map: root's-valid-input -> kid's id + // 2. map: state's ID -> offset of the state + // 3. states' content. +} AC_Buffer; + +// Depict the state of "fast" AC graph. +typedef struct { + // transition are sorted. For instance, state s1, has two transitions : + // goto(b) -> S_b, goto(a)->S_a. The inputs are sorted in the ascending + // order, and the target states are permuted accordingly. In this case, + // the inputs are sorted as : a, b, and the target states are permuted + // into S_a, S_b. So, S_a is the 1st kid, the ID of kids are consecutive, + // so we don't need to save all the target kids. + // + State_ID first_kid; + AC_Ofst fail_link; + short depth; // How far away from root. + unsigned short is_term; // Is terminal node. if is_term != 0, it encodes + // the value of "1 + pattern-index". + unsigned char goto_num; // The number of valid transition. + InputTy input_vect[1]; // Vector of valid input. Must be last field! +} AC_State; + +class Buf_Allocator { +public: + Buf_Allocator() : _buf(0) {} + virtual ~Buf_Allocator() { free(); } + + virtual AC_Buffer* alloc(int sz) = 0; + virtual void free() {}; +protected: + AC_Buffer* _buf; +}; + +// Convert slow-AC-graph into fast one. +class AC_Converter { +public: + AC_Converter(ACS_Constructor& acs, Buf_Allocator& ba) : + _acs(acs), _buf_alloc(ba) {} + AC_Buffer* Convert(); + +private: + // Return the size in byte needed to to save the specified state. + uint32 Calc_State_Sz(const ACS_State *) const; + + // In fast-AC-graph, the ID is bit trikcy. Given a state of slow-graph, + // this function is to return the ID of its counterpart in the fast-graph. + State_ID Get_Renumbered_Id(const ACS_State *s) const { + const vector<uint32> &m = _id_map; + return m[s->Get_ID()]; + } + + AC_Buffer* Alloc_Buffer(); + void Populate_Root_Goto_Func(AC_Buffer *, GotoVect&); + +#ifdef DEBUG + void dump_buffer(AC_Buffer*, FILE*); +#endif + +private: + ACS_Constructor& _acs; + Buf_Allocator& _buf_alloc; + + // map: ID of state in slow-graph -> ID of counterpart in fast-graph. + vector<uint32> _id_map; + + // map: ID of state in slow-graph -> offset of counterpart in fast-graph. + vector<AC_Ofst> _ofst_map; +}; + +ac_result_t Match(AC_Buffer* buf, const char* str, uint32 len); +ac_result_t Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len); + +#endif // AC_FAST_H diff --git a/modules/policy/lua-aho-corasick/ac_lua.cxx b/modules/policy/lua-aho-corasick/ac_lua.cxx new file mode 100644 index 0000000..ad7307e --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_lua.cxx @@ -0,0 +1,173 @@ +// Interface functions for libac.so +// +#include <vector> +#include <string> +#include "ac_slow.hpp" +#include "ac_fast.hpp" +#include "ac.h" // for the definition of ac_result_t +#include "ac_util.hpp" + +extern "C" { + #include <lua.h> + #include <lauxlib.h> +} + +#if defined(USE_SLOW_VER) +#error "Not going to implement it" +#endif + +using namespace std; +static const char* tname = "aho-corasick"; + +class BufAlloc : public Buf_Allocator { +public: + BufAlloc(lua_State* L) : _L(L) {} + virtual AC_Buffer* alloc(int sz) { + return (AC_Buffer*)lua_newuserdata (_L, sz); + } + + // Let GC to take care. + virtual void free() {} + +private: + lua_State* _L; +}; + +static bool +_create_helper(lua_State* L, const vector<const char*>& str_v, + const vector<unsigned int>& strlen_v) { + ASSERT(str_v.size() == strlen_v.size()); + + ACS_Constructor acc; + BufAlloc ba(L); + + // Step 1: construt the slow version. + unsigned int strnum = str_v.size(); + const char** str_vect = new const char*[strnum]; + unsigned int* strlen_vect = new unsigned int[strnum]; + + int idx = 0; + for (vector<const char*>::const_iterator i = str_v.begin(), e = str_v.end(); + i != e; i++) { + str_vect[idx++] = *i; + } + + idx = 0; + for (vector<unsigned int>::const_iterator i = strlen_v.begin(), + e = strlen_v.end(); i != e; i++) { + strlen_vect[idx++] = *i; + } + + acc.Construct(str_vect, strlen_vect, idx); + delete[] str_vect; + delete[] strlen_vect; + + // Step 2: convert to fast version + AC_Converter cvt(acc, ba); + return cvt.Convert() != 0; +} + +static ac_result_t +_match_helper(buf_header_t* ac, const char *str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(ac->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match(buf, str, len); + return r; +} + +// LUA sematic: +// input: array of strings +// output: userdata containing the AC-graph (i.e. the AC_Buffer). +// +static int +lac_create(lua_State* L) { + // The table of the array must be the 1st argument. + int input_tab = 1; + + luaL_checktype(L, input_tab, LUA_TTABLE); + + // Init the "iteartor". + lua_pushnil(L); + + vector<const char*> str_v; + vector<unsigned int> strlen_v; + + // Loop over the elements + while (lua_next(L, input_tab)) { + size_t str_len; + const char* s = luaL_checklstring(L, -1, &str_len); + str_v.push_back(s); + strlen_v.push_back(str_len); + + // remove the value, but keep the key as the iterator. + lua_pop(L, 1); + } + + // pop the nil value + lua_pop(L, 1); + + if (_create_helper(L, str_v, strlen_v)) { + // The AC graph, as a userdata is already pushed to the stack, hence 1. + return 1; + } + + return 0; +} + +// LUA input: +// arg1: the userdata, representing the AC graph, returned from l_create(). +// arg2: the string to be matched. +// +// LUA return: +// if match, return index range of the match; otherwise nil is returned. +// +static int +lac_match(lua_State* L) { + buf_header_t* ac = (buf_header_t*)lua_touserdata(L, 1); + if (!ac) { + luaL_checkudata(L, 1, tname); + return 0; + } + + size_t len; + const char* str; + #if LUA_VERSION_NUM >= 502 + str = luaL_tolstring(L, 2, &len); + #else + str = lua_tolstring(L, 2, &len); + #endif + if (!str) { + luaL_checkstring(L, 2); + return 0; + } + + ac_result_t r = _match_helper(ac, str, len); + if (r.match_begin != -1) { + lua_pushinteger(L, r.match_begin); + lua_pushinteger(L, r.match_end); + return 2; + } + + return 0; +} + +static const struct luaL_Reg lib_funcs[] = { + { "create", lac_create }, + { "match", lac_match }, + {0, 0} +}; + +extern "C" int AC_EXPORT +luaopen_ahocorasick(lua_State* L) { + luaL_newmetatable(L, tname); + +#if LUA_VERSION_NUM == 501 + luaL_register(L, tname, lib_funcs); +#elif LUA_VERSION_NUM >= 502 + luaL_newlib(L, lib_funcs); +#else + #error "Don't know how to do it right" +#endif + return 1; +} diff --git a/modules/policy/lua-aho-corasick/ac_slow.cxx b/modules/policy/lua-aho-corasick/ac_slow.cxx new file mode 100644 index 0000000..cb3957a --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_slow.cxx @@ -0,0 +1,318 @@ +#include <ctype.h> +#include <strings.h> // for bzero +#include <algorithm> +#include "ac_slow.hpp" +#include "ac.h" + +////////////////////////////////////////////////////////////////////////// +// +// Implementation of AhoCorasick_Slow +// +////////////////////////////////////////////////////////////////////////// +// +ACS_Constructor::ACS_Constructor() : _next_node_id(1) { + _root = new_state(); + _root_char = new InputTy[256]; + bzero((void*)_root_char, 256); + +#ifdef VERIFY + _pattern_buf = 0; +#endif +} + +ACS_Constructor::~ACS_Constructor() { + for (std::vector<ACS_State* >::iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + delete *i; + } + _all_states.clear(); + delete[] _root_char; + +#ifdef VERIFY + delete[] _pattern_buf; +#endif +} + +ACS_State* +ACS_Constructor::new_state() { + ACS_State* t = new ACS_State(_next_node_id++); + _all_states.push_back(t); + return t; +} + +void +ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len, + int pattern_idx) { + ACS_State* state = _root; + for (unsigned int i = 0; i < str_len; i++) { + const char c = str[i]; + ACS_State* new_s = state->Get_Goto(c); + if (!new_s) { + new_s = new_state(); + new_s->_depth = state->_depth + 1; + state->Set_Goto(c, new_s); + } + state = new_s; + } + state->_is_terminal = true; + state->set_Pattern_Idx(pattern_idx); +} + +void +ACS_Constructor::Propagate_faillink() { + ACS_State* r = _root; + std::vector<ACS_State*> wl; + + const ACS_Goto_Map& m = r->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) { + ACS_State* s = i->second; + s->_fail_link = r; + wl.push_back(s); + } + + // For any input c, make sure "goto(root, c)" is valid, which make the + // fail-link propagation lot easier. + ACS_Goto_Map goto_save = r->_goto_map; + for (uint32 i = 0; i <= 255; i++) { + ACS_State* s = r->Get_Goto(i); + if (!s) r->Set_Goto(i, r); + } + + for (uint32 i = 0; i < wl.size(); i++) { + ACS_State* s = wl[i]; + ACS_State* fl = s->_fail_link; + + const ACS_Goto_Map& tran_map = s->Get_Goto_Map(); + + for (ACS_Goto_Map::const_iterator ii = tran_map.begin(), + ee = tran_map.end(); ii != ee; ii++) { + InputTy c = ii->first; + ACS_State *tran = ii->second; + + ACS_State* tran_fl = 0; + for (ACS_State* fl_walk = fl; ;) { + if (ACS_State* t = fl_walk->Get_Goto(c)) { + tran_fl = t; + break; + } else { + fl_walk = fl_walk->Get_FailLink(); + } + } + + tran->_fail_link = tran_fl; + wl.push_back(tran); + } + } + + // Remove "goto(root, c) == root" transitions + r->_goto_map = goto_save; +} + +void +ACS_Constructor::Construct(const char** strv, unsigned int* strlenv, + uint32 strnum) { + Save_Patterns(strv, strlenv, strnum); + + for (uint32 i = 0; i < strnum; i++) { + Add_Pattern(strv[i], strlenv[i], i); + } + + Propagate_faillink(); + unsigned char* p = _root_char; + + const ACS_Goto_Map& m = _root->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); + i != e; i++) { + p[i->first] = 1; + } +} + +Match_Result +ACS_Constructor::MatchHelper(const char *str, uint32 len) const { + const ACS_State* root = _root; + const ACS_State* state = root; + + uint32 idx = 0; + while (idx < len) { + InputTy c = str[idx]; + idx++; + if (_root_char[c]) { + state = root->Get_Goto(c); + break; + } + } + + if (unlikely(state->is_Terminal())) { + // This could happen if the one of the pattern has only one char! + uint32 pos = idx - 1; + Match_Result r(pos - state->Get_Depth() + 1, pos, + state->get_Pattern_Idx()); + return r; + } + + while (idx < len) { + InputTy c = str[idx]; + ACS_State* gs = state->Get_Goto(c); + + if (!gs) { + ACS_State* fl = state->Get_FailLink(); + if (fl == root) { + while (idx < len) { + InputTy c = str[idx]; + idx++; + if (_root_char[c]) { + state = root->Get_Goto(c); + break; + } + } + } else { + state = fl; + } + } else { + idx ++; + state = gs; + } + + if (state->is_Terminal()) { + uint32 pos = idx - 1; + Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos, + state->get_Pattern_Idx()); + return r; + } + } + + return Match_Result(-1, -1, -1); +} + +#ifdef DEBUG +void +ACS_Constructor::dump_text(const char* txtfile) const { + FILE* f = fopen(txtfile, "w+"); + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State* s = *i; + + fprintf(f, "S%d goto:{", s->Get_ID()); + const ACS_Goto_Map& goto_func = s->Get_Goto_Map(); + + for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end(); + i != e; i++) { + InputTy input = i->first; + ACS_State* tran = i->second; + if (isprint(input)) + fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID()); + else + fprintf(f, "%#x -> S:%d,", input, tran->Get_ID()); + } + fprintf(f, "} "); + + if (s->_fail_link) { + fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID()); + } + + if (s->_is_terminal) { + fprintf(f, ", terminal"); + } + + fprintf(f, "\n"); + } + fclose(f); +} + +void +ACS_Constructor::dump_dot(const char *dotfile) const { + FILE* f = fopen(dotfile, "w+"); + const char* indent = " "; + + fprintf(f, "digraph G {\n"); + + // Emit node information + fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID()); + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State *s = *i; + if (s->_is_terminal) { + fprintf(f, "%s%d [shape=doublecircle];\n", indent, s->Get_ID()); + } + } + fprintf(f, "\n"); + + // Emit edge information + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State* s = *i; + uint32 id = s->Get_ID(); + + const ACS_Goto_Map& m = s->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end(); + ii != ee; ii++) { + InputTy input = ii->first; + ACS_State* tran = ii->second; + if (isalnum(input)) + fprintf(f, "%s%d -> %d [label=%c];\n", + indent, id, tran->Get_ID(), input); + else + fprintf(f, "%s%d -> %d [label=\"%#x\"];\n", + indent, id, tran->Get_ID(), input); + + } + + // Emit fail-link + ACS_State* fl = s->Get_FailLink(); + if (fl && fl != _root) { + fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n", + indent, id, fl->Get_ID()); + } + } + fprintf(f, "}\n"); + fclose(f); +} +#endif + +#ifdef VERIFY +void +ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r) + const { + if (r->begin >= 0) { + unsigned len = r->end - r->begin + 1; + int ptn_idx = r->pattern_idx; + + ASSERT(ptn_idx >= 0 && + len == get_ith_Pattern_Len(ptn_idx) && + memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0); + } +} + +void +ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv, + int pattern_num) { + // calculate the total size needed to save all patterns. + // + int buf_size = 0; + for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; } + + // HINT: patterns are delimited by '\0' in order to ease debugging. + buf_size += pattern_num; + ASSERT(_pattern_buf == 0); + _pattern_buf = new char[buf_size + 1]; + #define MAGIC_NUM 0x5a + _pattern_buf[buf_size] = MAGIC_NUM; + + int ofst = 0; + _pattern_lens.resize(pattern_num); + _pattern_vect.resize(pattern_num); + for (int i = 0; i < pattern_num; i++) { + int l = strlenv[i]; + _pattern_lens[i] = l; + _pattern_vect[i] = _pattern_buf + ofst; + + memcpy(_pattern_buf + ofst, strv[i], l); + ofst += l; + _pattern_buf[ofst++] = '\0'; + } + + ASSERT(_pattern_buf[buf_size] == MAGIC_NUM); + #undef MAGIC_NUM +} + +#endif diff --git a/modules/policy/lua-aho-corasick/ac_slow.hpp b/modules/policy/lua-aho-corasick/ac_slow.hpp new file mode 100644 index 0000000..030b95d --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_slow.hpp @@ -0,0 +1,158 @@ +#ifndef MY_AC_H +#define MY_AC_H + +#include <string.h> +#include <stdio.h> +#include <map> +#include <vector> +#include <algorithm> // for std::sort +#include "ac_util.hpp" + +// Forward decl. the acronym "ACS" stands for "Aho-Corasick Slow implementation" +class ACS_State; +class ACS_Constructor; +class AhoCorasick; + +using namespace std; + +typedef std::map<InputTy, ACS_State*> ACS_Goto_Map; + +class Match_Result { +public: + int begin; + int end; + int pattern_idx; + Match_Result(int b, int e, int p): begin(b), end(e), pattern_idx(p) {} +}; + +typedef pair<InputTy, ACS_State *> GotoPair; +typedef vector<GotoPair> GotoVect; + +// Sorting functor +class GotoSort { +public: + bool operator() (const GotoPair& g1, const GotoPair& g2) { + return g1.first < g2.first; + } +}; + +class ACS_State { +friend class ACS_Constructor; + +public: + ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0), + _is_terminal(false), _fail_link(0){} + ~ACS_State() {}; + + void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; } + ACS_State *Get_Goto(InputTy c) const { + ACS_Goto_Map::const_iterator iter = _goto_map.find(c); + return iter != _goto_map.end() ? (*iter).second : 0; + } + + // Return all transitions sorted in the ascending order of their input. + void Get_Sorted_Gotos(GotoVect& Gotos) const { + const ACS_Goto_Map& m = _goto_map; + Gotos.clear(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); + i != e; i++) { + Gotos.push_back(GotoPair(i->first, i->second)); + } + sort(Gotos.begin(), Gotos.end(), GotoSort()); + } + + ACS_State* Get_FailLink() const { return _fail_link; } + uint32 Get_GotoNum() const { return _goto_map.size(); } + uint32 Get_ID() const { return _id; } + uint32 Get_Depth() const { return _depth; } + const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; } + bool is_Terminal() const { return _is_terminal; } + int get_Pattern_Idx() const { + ASSERT(is_Terminal() && _pattern_idx >= 0); + return _pattern_idx; + } + +private: + void set_Pattern_Idx(int idx) { + ASSERT(is_Terminal()); + _pattern_idx = idx; + } + +private: + uint32 _id; + int _pattern_idx; + short _depth; + bool _is_terminal; + ACS_Goto_Map _goto_map; + ACS_State* _fail_link; +}; + +class ACS_Constructor { +public: + ACS_Constructor(); + ~ACS_Constructor(); + + void Construct(const char** strv, unsigned int* strlenv, + unsigned int strnum); + + Match_Result Match(const char* s, uint32 len) const { + Match_Result r = MatchHelper(s, len); + Verify_Result(s, &r); + return r; + } + + Match_Result Match(const char* s) const { return Match(s, strlen(s)); } + +#ifdef DEBUG + void dump_text(const char* = "ac.txt") const; + void dump_dot(const char* = "ac.dot") const; +#endif + const ACS_State *Get_Root_State() const { return _root; } + const vector<ACS_State*>& Get_All_States() const { + return _all_states; + } + + uint32 Get_Next_Node_Id() const { return _next_node_id; } + uint32 Get_State_Num() const { return _next_node_id - 1; } + +private: + void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx); + ACS_State* new_state(); + void Propagate_faillink(); + + Match_Result MatchHelper(const char*, uint32 len) const; + +#ifdef VERIFY + void Verify_Result(const char* subject, const Match_Result* r) const; + void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len); + const char* get_ith_Pattern(unsigned i) const { + ASSERT(i < _pattern_vect.size()); + return _pattern_vect.at(i); + } + unsigned get_ith_Pattern_Len(unsigned i) const { + ASSERT(i < _pattern_lens.size()); + return _pattern_lens.at(i); + } +#else + void Verify_Result(const char* subject, const Match_Result* r) const { + (void)subject; (void)r; + } + void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) { + (void)strv; (void)strlenv; + } +#endif + +private: + ACS_State* _root; + vector<ACS_State*> _all_states; + unsigned char* _root_char; + uint32 _next_node_id; + +#ifdef VERIFY + char* _pattern_buf; + vector<int> _pattern_lens; + vector<char*> _pattern_vect; +#endif +}; + +#endif diff --git a/modules/policy/lua-aho-corasick/ac_util.hpp b/modules/policy/lua-aho-corasick/ac_util.hpp new file mode 100644 index 0000000..56fd46c --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_util.hpp @@ -0,0 +1,69 @@ +/* + Copyright (c) 2014 CloudFlare, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of CloudFlare, Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#ifndef AC_UTIL_H +#define AC_UTIL_H + +#ifdef DEBUG +#include <stdio.h> // for fprintf +#include <stdlib.h> // for abort +#endif + +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long uint64; +typedef unsigned char InputTy; + +#ifdef DEBUG + // Usage examples: ASSERT(a > b), ASSERT(foo() && "Opps, foo() reutrn 0"); + #define ASSERT(c) if (!(c))\ + { fprintf(stderr, "%s:%d Assert: %s\n", __FILE__, __LINE__, #c); abort(); } +#else + #define ASSERT(c) ((void)0) +#endif + +#define likely(x) __builtin_expect((x),1) +#define unlikely(x) __builtin_expect((x),0) + +#ifndef offsetof +#define offsetof(st, m) ((size_t)(&((st *)0)->m)) +#endif + +typedef enum { + IMPL_SLOW_VARIANT = 1, + IMPL_FAST_VARIANT = 2, +} impl_var_t; + +#define AC_MAGIC_NUM 0x5a +typedef struct { + unsigned char magic_num; + unsigned char impl_variant; +} buf_header_t; + +#endif //AC_UTIL_H diff --git a/modules/policy/lua-aho-corasick/load_ac.lua b/modules/policy/lua-aho-corasick/load_ac.lua new file mode 100644 index 0000000..eb70446 --- /dev/null +++ b/modules/policy/lua-aho-corasick/load_ac.lua @@ -0,0 +1,90 @@ +-- Helper wrappring script for loading shared object libac.so (FFI interface) +-- from package.cpath instead of LD_LIBRARTY_PATH. +-- + +local ffi = require 'ffi' +ffi.cdef[[ + void* ac_create(const char** str_v, unsigned int* strlen_v, + unsigned int v_len); + int ac_match2(void*, const char *str, int len); + void ac_free(void*); +]] + +local _M = {} + +local string_gmatch = string.gmatch +local string_match = string.match + +local ac_lib = nil +local ac_create = nil +local ac_match = nil +local ac_free = nil + +--[[ Find shared object file package.cpath, obviating the need of setting + LD_LIBRARY_PATH +]] +local function find_shared_obj(cpath, so_name) + for k, v in string_gmatch(cpath, "[^;]+") do + local so_path = string_match(k, "(.*/)") + if so_path then + -- "so_path" could be nil. e.g, the dir path component is "." + so_path = so_path .. so_name + + -- Don't get me wrong, the only way to know if a file exist is + -- trying to open it. + local f = io.open(so_path) + if f ~= nil then + io.close(f) + return so_path + end + end + end +end + +function _M.load_ac_lib() + if ac_lib ~= nil then + return ac_lib + else + local so_path = find_shared_obj(package.cpath, "libac.so") + if so_path ~= nil then + ac_lib = ffi.load(so_path) + ac_create = ac_lib.ac_create + ac_match = ac_lib.ac_match2 + ac_free = ac_lib.ac_free + return ac_lib + end + end +end + +-- Create an Aho-Corasick instance, and return the instance if it was +-- successful. +function _M.create_ac(dict) + local strnum = #dict + if ac_lib == nil then + _M.load_ac_lib() + end + + local str_v = ffi.new("const char *[?]", strnum) + local strlen_v = ffi.new("unsigned int [?]", strnum) + + for i = 1, strnum do + local s = dict[i] + str_v[i - 1] = s + strlen_v[i - 1] = #s + end + + local ac = ac_create(str_v, strlen_v, strnum); + if ac ~= nil then + return ffi.gc(ac, ac_free) + end +end + +-- Return nil if str doesn't match the dictionary, else return non-nil. +function _M.match(ac, str) + local r = ac_match(ac, str, #str); + if r >= 0 then + return r + end +end + +return _M diff --git a/modules/policy/lua-aho-corasick/mytest.cxx b/modules/policy/lua-aho-corasick/mytest.cxx new file mode 100644 index 0000000..ef3dc87 --- /dev/null +++ b/modules/policy/lua-aho-corasick/mytest.cxx @@ -0,0 +1,200 @@ +#include <stdio.h> +#include <string.h> +#include <vector> +#include "ac.h" + +using namespace std; + +///////////////////////////////////////////////////////////////////////// +// +// Test using strings from input files +// +///////////////////////////////////////////////////////////////////////// +// +class BigFileTester { +public: + BigFileTester(const char* filepath); + +private: + void Genector +privaete: + const char* _msg; + int _msg_len; + int _key_num; // number of strings in dictionary + int _key_len_idx; +}; + +///////////////////////////////////////////////////////////////////////// +// +// Simple (yet maybe tricky) testings +// +///////////////////////////////////////////////////////////////////////// +// +typedef struct { + const char* str; + const char* match; +} StrPair; + +typedef struct { + const char* name; + const char** dict; + StrPair* strpairs; + int dict_len; + int strpair_num; +} TestingCase; + +class Tests { +public: + Tests(const char* name, + const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num) { + if (!_tests) + _tests = new vector<TestingCase>; + + TestingCase tc; + tc.name = name; + tc.dict = dict; + tc.strpairs = strpairs; + tc.dict_len = dict_len; + tc.strpair_num = strpair_num; + _tests->push_back(tc); + } + + static vector<TestingCase>* Get_Tests() { return _tests; } + static void Erase_Tests() { delete _tests; _tests = 0; } + +private: + static vector<TestingCase> *_tests; +}; + +vector<TestingCase>* Tests::_tests = 0; + +static void +simple_test(void) { + int total = 0; + int fail = 0; + + vector<TestingCase> *tests = Tests::Get_Tests(); + if (!tests) + return 0; + + for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end(); + i != e; i++) { + TestingCase& t = *i; + fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); + for (int i = 0, e = t.dict_len, need_break=0; i < e; i++) { + fprintf(stdout, "%s, ", t.dict[i]); + if (need_break++ == 16) { + fputs("\n ", stdout); + need_break = 0; + } + } + fputs("]\n", stdout); + + /* Create the dictionary */ + int dict_len = t.dict_len; + ac_t* ac = ac_create(t.dict, dict_len); + + for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { + const StrPair& sp = t.strpairs[ii]; + const char *str = sp.str; // the string to be matched + const char *match = sp.match; + + fprintf(stdout, "[%3d] Testing '%s' : ", total, str); + + int len = strlen(str); + ac_result_t r = ac_match(ac, str, len); + int m_b = r.match_begin; + int m_e = r.match_end; + + // The return value per se is insane. + if (m_b > m_e || + ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { + fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); + fail ++; + continue; + } + + // If the string is not supposed to match the dictionary. + if (!match) { + if (m_b != -1 || m_e != -1) { + fail ++; + fprintf(stdout, "Not Supposed to match (%d, %d) \n", + m_b, m_e); + } else + fputs("Pass\n", stdout); + continue; + } + + // The string or its substring is match the dict. + if (m_b >= len || m_b >= len) { + fail ++; + fprintf(stdout, + "Return value >= the length of the string (%d, %d)\n", + m_b, m_e); + continue; + } else { + int mlen = strlen(match); + if ((mlen != m_e - m_b + 1) || + strncmp(str + m_b, match, mlen)) { + fail ++; + fprintf(stdout, "Fail\n"); + } else + fprintf(stdout, "Pass\n"); + } + } + fputs("\n", stdout); + ac_free(ac); + } + + fprintf(stdout, "Total : %d, Fail %d\n", total, fail); + + return fail ? -1 : 0; +} + +int +main (int argc, char** argv) { + int res = simple_test(); + return res; +}; + +/* test 1*/ +const char *dict1[] = {"he", "she", "his", "her"}; +StrPair strpair1[] = { + {"he", "he"}, {"she", "she"}, {"his", "his"}, + {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, + {"shis2", "his"}, {"ahhe", "he"} +}; +Tests test1("test 1", + dict1, sizeof(dict1)/sizeof(dict1[0]), + strpair1, sizeof(strpair1)/sizeof(strpair1[0])); + +/* test 2*/ +const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ +StrPair strpair2[] = {{"The pot had a handle", 0}}; +Tests test2("test 2", dict2, 2, strpair2, 1); + +/* test 3*/ +const char *dict3[] = {"The"}; +StrPair strpair3[] = {{"The pot had a handle", "The"}}; +Tests test3("test 3", dict3, 1, strpair3, 1); + +/* test 4*/ +const char *dict4[] = {"pot"}; +StrPair strpair4[] = {{"The pot had a handle", "pot"}}; +Tests test4("test 4", dict4, 1, strpair4, 1); + +/* test 5*/ +const char *dict5[] = {"pot "}; +StrPair strpair5[] = {{"The pot had a handle", "pot "}}; +Tests test5("test 5", dict5, 1, strpair5, 1); + +/* test 6*/ +const char *dict6[] = {"ot h"}; +StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; +Tests test6("test 6", dict6, 1, strpair6, 1); + +/* test 7*/ +const char *dict7[] = {"andle"}; +StrPair strpair7[] = {{"The pot had a handle", "andle"}}; +Tests test7("test 7", dict7, 1, strpair7, 1); diff --git a/modules/policy/lua-aho-corasick/tests/Makefile b/modules/policy/lua-aho-corasick/tests/Makefile new file mode 100644 index 0000000..54fd90f --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/Makefile @@ -0,0 +1,65 @@ +OS := $(shell uname) +ifeq ($(OS), Darwin) + SO_EXT := dylib +else + SO_EXT := so +endif + +.PHONY = all clean test runtest benchmark + +PROGRAM = ac_test +BENCHMARK = ac_bench +all: runtest + +CXXFLAGS = -O3 -g -march=native -Wall -DDEBUG +MYCXXFLAGS = -MMD -I.. $(CXXFLAGS) +%.o : %.cxx + $(CXX) $< -c $(MYCXXFLAGS) + +-include dep.cxx +SRC = test_main.cxx ac_test_simple.cxx ac_test_aggr.cxx test_bigfile.cxx + +OBJ = ${SRC:.cxx=.o} + +-include test_dep.txt +-include bench_dep.txt + +$(PROGRAM) $(BENCHMARK) : testinput/text.tar testinput/image.bin +$(PROGRAM) : $(OBJ) ../libac.$(SO_EXT) + $(CXX) $(OBJ) -L.. -lac -o $@ + -cat *.d > test_dep.txt + +$(BENCHMARK) : ac_bench.o ../libac.$(SO_EXT) + $(CXX) ac_bench.o -L.. -lac -o $@ + -cat *.d > bench_dep.txt + +ifneq ($(OS), Darwin) +runtest:$(PROGRAM) + LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* + +benchmark:$(BENCHMARK) + LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./ac_bench + +else +runtest:$(PROGRAM) + DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* + +benchmark:$(BENCHMARK) + DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./ac_bench + +endif + +testinput/text.tar: + echo "download testing files (gcc tarball)..." + if [ ! -d testinput ] ; then mkdir testinput; fi + cd testinput && \ + curl ftp://ftp.gnu.org/gnu/gcc/gcc-1.42.tar.gz -o text.tar.gz 2>/dev/null \ + && gzip -d text.tar.gz + +testinput/image.bin: + echo "download testing files.." + if [ ! -d testinput ] ; then mkdir testinput; fi + curl http://www.3dvisionlive.com/sites/default/files/Curiosity_render_hiresb.jpg -o $@ 2>/dev/null + +clean: + -rm -f *.o *.d dep.txt $(PROGRAM) $(BENCHMARK) diff --git a/modules/policy/lua-aho-corasick/tests/ac_bench.cxx b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx new file mode 100644 index 0000000..421322c --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx @@ -0,0 +1,519 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <sys/time.h> +#include <time.h> +#include <fcntl.h> +#include <unistd.h> +#include <dirent.h> +#include <libgen.h> +#include <errno.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <getopt.h> + +#include <string> +#include <vector> +#include "ac.h" +#include "ac_util.hpp" + +using namespace std; + +static bool SomethingWrong = false; + +static int iteration = 300; +static string dict_dir; +static string obj_file_dir; +static bool print_help = false; +static int piece_size = 1024; + +class PatternSet { +public: + PatternSet(const char* filepath); + ~PatternSet() { Cleanup(); } + + int getPatternNum() const { return _pat_num; } + const char** getPatternVector() const { return _patterns; } + unsigned int* getPatternLenVector() const { return _pat_len; } + + const char* getErrMessage() const { return _errmsg; } + static bool isDictFile(const char* filepath) { + if (strncmp(basename(const_cast<char*>(filepath)), "dict", 4)) + return false; + return true; + } + +private: + bool ExtractPattern(const char* filepath); + void Cleanup(); + + const char** _patterns; + unsigned int* _pat_len; + char* _mmap; + int _fd; + size_t _mmap_size; + int _pat_num; + + const char* _errmsg; +}; + +bool +PatternSet::ExtractPattern(const char* filepath) { + if (!isDictFile(filepath)) + return false; + + struct stat filestat; + if (stat(filepath, &filestat)) { + _errmsg = "fail to call stat()"; + return false; + } + + if (filestat.st_size > 4096 * 1024) { + /* It dosen't seem to be a dictionary file*/ + _errmsg = "file too big?"; + return false; + } + + _fd = open(filepath, 0); + if (_fd == -1) { + _errmsg = "fail to open dictionary file"; + return false; + } + + _mmap_size = filestat.st_size; + _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, + MAP_PRIVATE, _fd, 0); + if (_mmap == MAP_FAILED) { + _errmsg = "fail to call mmap"; + return false; + } + + const char* pat = _mmap; + vector<const char*> pat_vect; + vector<unsigned> pat_len_vect; + + for (size_t i = 0, e = filestat.st_size; i < e; i++) { + if (_mmap[i] == '\r' || _mmap[i] == '\n') { + _mmap[i] = '\0'; + int len = _mmap + i - pat; + if (len > 0) { + pat_vect.push_back(pat); + pat_len_vect.push_back(len); + } + pat = _mmap + i + 1; + } + } + + ASSERT(pat_vect.size() == pat_len_vect.size()); + + int pat_num = pat_vect.size(); + if (pat_num > 0) { + const char** p = _patterns = new const char*[pat_num]; + int i = 0; + for (vector<const char*>::iterator iter = pat_vect.begin(), + iter_e = pat_vect.end(); iter != iter_e; ++iter) { + p[i++] = *iter; + } + + i = 0; + unsigned int* q = _pat_len = new unsigned int[pat_num]; + for (vector<unsigned>::iterator iter = pat_len_vect.begin(), + iter_e = pat_len_vect.end(); iter != iter_e; ++iter) { + q[i++] = *iter; + } + } + + _pat_num = pat_num; + if (pat_num <= 0) { + _errmsg = "no pattern at all"; + return false; + } + + return true; +} + +void +PatternSet::Cleanup() { + if (_mmap != MAP_FAILED) { + munmap(_mmap, _mmap_size); + _mmap = (char*)MAP_FAILED; + _mmap_size = 0; + } + + delete[] _patterns; + delete[] _pat_len; + if (_fd != -1) + close(_fd); + _pat_num = -1; +} + +PatternSet::PatternSet(const char* filepath) { + _patterns = 0; + _pat_len = 0; + _mmap = (char*)MAP_FAILED; + _mmap_size = 0; + _pat_num = -1; + _errmsg = ""; + + if (!ExtractPattern(filepath)) + Cleanup(); +} + +bool +getFilesUnderDir(vector<string>& files, const char* path) { + files.clear(); + + DIR* dir = opendir(path); + if (!dir) + return false; + + string path_dir = path; + path_dir += "/"; + + for (;;) { + struct dirent* entry = readdir(dir); + if (entry) { + string filepath = path_dir + entry->d_name; + struct stat file_stat; + if (stat(filepath.c_str(), &file_stat)) { + closedir(dir); + return false; + } + + if (S_ISREG(file_stat.st_mode)) + files.push_back(filepath); + + continue; + } + + if (errno) { + return false; + } + break; + } + closedir(dir); + return true; +} + +class Timer { +public: + Timer() { + my_clock_gettime(&_start); + _stop = _start; + _acc.tv_sec = 0; + _acc.tv_nsec = 0; + } + + const Timer& operator += (const Timer& that) { + time_t sec = _acc.tv_sec + that._acc.tv_sec; + long nsec = _acc.tv_nsec + that._acc.tv_nsec; + if (nsec > 1000000000) { + nsec -= 1000000000; + sec += 1; + } + _acc.tv_sec = sec; + _acc.tv_nsec = nsec; + return *this; + } + + // return duration in us + size_t getDuration() const { + return _acc.tv_sec * (size_t)1000000 + _acc.tv_nsec/1000; + } + + void Start(bool acc=true) { + my_clock_gettime(&_start); + } + + void Stop() { + my_clock_gettime(&_stop); + struct timespec t = CalcDuration(); + _acc = add_duration(_acc, t); + } + +private: + int my_clock_gettime(struct timespec* t) { +#ifdef __linux + return clock_gettime(CLOCK_PROCESS_CPUTIME_ID, t); +#else + struct timeval tv; + int rc = gettimeofday(&tv, 0); + t->tv_sec = tv.tv_sec; + t->tv_nsec = tv.tv_usec * 1000; + return rc; +#endif + } + + struct timespec add_duration(const struct timespec& dur1, + const struct timespec& dur2) { + time_t sec = dur1.tv_sec + dur2.tv_sec; + long nsec = dur1.tv_nsec + dur2.tv_nsec; + if (nsec > 1000000000) { + nsec -= 1000000000; + sec += 1; + } + timespec t; + t.tv_sec = sec; + t.tv_nsec = nsec; + + return t; + } + + struct timespec CalcDuration() const { + timespec diff; + if ((_stop.tv_nsec - _start.tv_nsec)<0) { + diff.tv_sec = _stop.tv_sec - _start.tv_sec - 1; + diff.tv_nsec = 1000000000 + _stop.tv_nsec - _start.tv_nsec; + } else { + diff.tv_sec = _stop.tv_sec - _start.tv_sec; + diff.tv_nsec = _stop.tv_nsec - _start.tv_nsec; + } + return diff; + } + + struct timespec _start; + struct timespec _stop; + struct timespec _acc; +}; + +class Benchmark { +public: + Benchmark(const PatternSet& pat_set, const char* infile): + _pat_set(pat_set), _infile(infile) { + _mmap = (char*)MAP_FAILED; + _file_sz = 0; + _fd = -1; + } + + ~Benchmark() { + if (_mmap != MAP_FAILED) + munmap(_mmap, _file_sz); + if (_fd != -1) + close(_fd); + } + + bool Run(int iteration); + const Timer& getTimer() const { return _timer; } + +private: + const PatternSet& _pat_set; + const char* _infile; + char* _mmap; + int _fd; + size_t _file_sz; // input file size + Timer _timer; +}; + +bool +Benchmark::Run(int iteration) { + if (_pat_set.getPatternNum() <= 0) { + SomethingWrong = true; + return false; + } + + if (_mmap == MAP_FAILED) { + struct stat filestat; + if (stat(_infile, &filestat)) { + SomethingWrong = true; + return false; + } + + if (!S_ISREG(filestat.st_mode)) { + SomethingWrong = true; + return false; + } + + _fd = open(_infile, 0); + if (_fd == -1) + return false; + + _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, + MAP_PRIVATE, _fd, 0); + + if (_mmap == MAP_FAILED) { + SomethingWrong = true; + return false; + } + + _file_sz = filestat.st_size; + } + + ac_t* ac = ac_create(_pat_set.getPatternVector(), + _pat_set.getPatternLenVector(), + _pat_set.getPatternNum()); + if (!ac) { + SomethingWrong = true; + return false; + } + + int piece_num = _file_sz/piece_size; + + _timer.Start(false); + + /* Stupid compiler may not be able to promote piece_size into register. + * Do it manually. + */ + int piece_sz = piece_size; + for (int i = 0; i < iteration; i++) { + size_t match_ofst = 0; + for (int piece_idx = 0; piece_idx < piece_num; piece_idx ++) { + ac_match2(ac, _mmap + match_ofst, piece_sz); + match_ofst += piece_sz; + } + if (match_ofst != _file_sz) + ac_match2(ac, _mmap + match_ofst, _file_sz - match_ofst); + } + _timer.Stop(); + return true; +} + +const char* short_opt = "hd:f:i:p:"; +const struct option long_opts[] = { + {"help", no_argument, 0, 'h'}, + {"iteration", required_argument, 0, 'i'}, + {"dictionary-dir", required_argument, 0, 'd'}, + {"obj-file-dir", required_argument, 0, 'f'}, + {"piece-size", required_argument, 0, 'p'}, +}; + +static void +PrintHelp(const char* prog_name) { + const char* msg = +"Usage %s [OPTIONS]\n" +" -d, --dictionary-dir : specify the dictionary directory (./dict by default)\n" +" -f, --obj-file-dir : specify the object file directory\n" +" (./testinput by default)\n" +" -i, --iteration : Run this many iteration for each pattern match\n" +" -p, --piece-size : The size of 'piece' in byte. The input file is\n" +" divided into pieces, and match function is working\n" +" on one piece at a time. The default size of piece\n" +" is 1k byte.\n"; + + fprintf(stdout, msg, prog_name); +} + +static bool +getOptions(int argc, char** argv) { + bool dict_dir_set = false; + bool objfile_dir_set = false; + int opt_index; + + while (1) { + if (print_help) break; + + int c = getopt_long(argc, argv, short_opt, long_opts, &opt_index); + + if (c == -1) break; + if (c == 0) { c = long_opts[opt_index].val; } + + switch(c) { + case 'h': + print_help = true; + break; + + case 'i': + iteration = atol(optarg); + break; + + case 'd': + dict_dir = optarg; + dict_dir_set = true; + break; + + case 'f': + obj_file_dir = optarg; + objfile_dir_set = true; + break; + + case 'p': + piece_size = atol(optarg); + break; + + case '?': + default: + return false; + } + } + + if (print_help) + return true; + + string basedir(dirname(argv[0])); + if (!dict_dir_set) + dict_dir = basedir + "/dict"; + + if (!objfile_dir_set) + obj_file_dir = basedir + "/testinput"; + + return true; +} + +int +main(int argc, char** argv) { + if (!getOptions(argc, argv)) + return -1; + + if (print_help) { + PrintHelp(argv[0]); + return 0; + } + +#ifndef __linux + fprintf(stdout, "\n!!!WARNING: On this OS, the execution time is measured" + " by gettimeofday(2) which is imprecise!!!\n\n"); +#endif + + fprintf(stdout, "Test with iteration = %d, piece size = %d, and", + iteration, piece_size); + fprintf(stdout, "\n dictionary dir = %s\n object file dir = %s\n\n", + dict_dir.c_str(), obj_file_dir.c_str()); + + vector<string> dict_files; + vector<string> input_files; + + if (!getFilesUnderDir(dict_files, dict_dir.c_str())) { + fprintf(stdout, "fail to find dictionary files\n"); + return -1; + } + + if (!getFilesUnderDir(input_files, obj_file_dir.c_str())) { + fprintf(stdout, "fail to find test input files\n"); + return -1; + } + + for (vector<string>::iterator diter = dict_files.begin(), + diter_e = dict_files.end(); diter != diter_e; ++diter) { + + const char* dict_name = diter->c_str(); + if (!PatternSet::isDictFile(dict_name)) + continue; + + PatternSet ps(dict_name); + if (ps.getPatternNum() <= 0) { + fprintf(stdout, "fail to open dictionary file %s : %s\n", + dict_name, ps.getErrMessage()); + SomethingWrong = true; + continue; + } + + fprintf(stdout, "Using dictionary %s\n", dict_name); + Timer timer; + for (vector<string>::iterator iter = input_files.begin(), + iter_e = input_files.end(); iter != iter_e; ++iter) { + fprintf(stdout, " testing %s ... ", iter->c_str()); + fflush(stdout); + Benchmark bm(ps, iter->c_str()); + bm.Run(iteration); + const Timer& t = bm.getTimer(); + timer += bm.getTimer(); + fprintf(stdout, "elapsed %.3f\n", t.getDuration() / 1000000.0); + } + + fprintf(stdout, + "\n==========================================================\n" + " Total Elapse %.3f\n\n", timer.getDuration() / 1000000.0); + } + + return SomethingWrong ? -1 : 0; +} diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx new file mode 100644 index 0000000..4ea02bc --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx @@ -0,0 +1,135 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + +namespace { +class ACBigFileTester : public BigFileTester { +public: + ACBigFileTester(const char* filepath) : BigFileTester(filepath){}; + +private: + virtual buf_header_t* PM_Create(const char** strv, uint32* strlenv, + uint32 vect_len) { + return (buf_header_t*)ac_create(strv, strlenv, vect_len); + } + + virtual void PM_Free(buf_header_t* PM) { ac_free(PM); } + virtual bool Run_Helper(buf_header_t* PM); +}; + +class ACTestAggressive: public ACTestBase { +public: + ACTestAggressive(const vector<const char*>& files, const char* banner) + : ACTestBase(banner), _files(files) {} + virtual bool Run(); + +private: + void PrintSummary(int total, int fail) { + fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); + fflush(stdout); + } + vector<const char*> _files; +}; + +} // end of anonymous namespace + +bool +ACBigFileTester::Run_Helper(buf_header_t* PM) { + int fail = 0; + // advance one chunk at a time. + int len = _msg_len; + int chunk_sz = _chunk_sz; + + vector<const char*> c_style_keys; + for (int i = 0, e = _keys.size(); i != e; i++) { + const char* key = _keys[i].first; + int len = _keys[i].second; + char *t = new char[len+1]; + memcpy(t, key, len); + t[len] = '\0'; + c_style_keys.push_back(t); + } + + for (int ofst = 0, chunk_idx = 0, chunk_num = _chunk_num; + chunk_idx < chunk_num; ofst += chunk_sz, chunk_idx++) { + const char* substring = _msg + ofst; + ac_result_t r = ac_match((ac_t*)(void*)PM, substring , len - ofst); + int m_b = r.match_begin; + int m_e = r.match_end; + + if (m_b < 0 || m_e < 0 || m_e <= m_b || m_e >= len) { + fprintf(stdout, "fail to find match substring[%d:%d])\n", + ofst, len - 1); + fail ++; + continue; + } + + const char* match_str = _msg + len; + int strstr_len = 0; + int key_idx = -1; + + for (int i = 0, e = c_style_keys.size(); i != e; i++) { + const char* key = c_style_keys[i]; + if (const char *m = strstr(substring, key)) { + if (m < match_str) { + match_str = m; + strstr_len = _keys[i].second; + key_idx = i; + } + } + } + ASSERT(key_idx != -1); + if ((match_str - substring != m_b)) { + fprintf(stdout, + "Fail to find match substring[%d:%d])," + " expected to find match at offset %d instead of %d\n", + ofst, len - 1, + (int)(match_str - _msg), ofst + m_b); + fprintf(stdout, "%d vs %d (key idx %d)\n", strstr_len, m_e - m_b + 1, key_idx); + PrintStr(stdout, match_str, strstr_len); + fprintf(stdout, "\n"); + PrintStr(stdout, _msg + ofst + m_b, + m_e - m_b + 1); + fprintf(stdout, "\n"); + fail ++; + } + } + for (vector<const char*>::iterator i = c_style_keys.begin(), + e = c_style_keys.end(); i != e; i++) { + delete[] *i; + } + + return fail == 0; +} + +bool +ACTestAggressive::Run() { + int fail = 0; + for (vector<const char*>::iterator i = _files.begin(), e = _files.end(); + i != e; i++) { + ACBigFileTester bft(*i); + if (!bft.Run()) + fail ++; + } + return fail == 0; +} + +bool +Run_AC_Aggressive_Test(const vector<const char*>& files) { + ACTestAggressive t(files, "AC Aggressive test"); + t.PrintBanner(); + return t.Run(); +} diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx new file mode 100644 index 0000000..fa2d7fd --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx @@ -0,0 +1,275 @@ +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + +namespace { +typedef struct { + const char* str; + const char* match; +} StrPair; + +typedef enum { + MV_FIRST_MATCH = 0, + MV_LEFT_LONGEST = 1, +} MatchVariant; + +typedef struct { + const char* name; + const char** dict; + StrPair* strpairs; + int dict_len; + int strpair_num; + MatchVariant match_variant; +} TestingCase; + +class Tests { +public: + Tests(const char* name, + const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num, + MatchVariant mv = MV_FIRST_MATCH) { + if (!_tests) + _tests = new vector<TestingCase>; + + TestingCase tc; + tc.name = name; + tc.dict = dict; + tc.strpairs = strpairs; + tc.dict_len = dict_len; + tc.strpair_num = strpair_num; + tc.match_variant = mv; + _tests->push_back(tc); + } + + static vector<TestingCase>* Get_Tests() { return _tests; } + static void Erase_Tests() { delete _tests; _tests = 0; } + +private: + static vector<TestingCase> *_tests; +}; + +class LeftLongestTests : public Tests { +public: + LeftLongestTests (const char* name, const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num): + Tests(name, dict, dict_len, strpairs, strpair_num, MV_LEFT_LONGEST) { + } +}; + +vector<TestingCase>* Tests::_tests = 0; + +class ACTestSimple: public ACTestBase { +public: + ACTestSimple(const char* banner) : ACTestBase(banner) {} + virtual bool Run(); + +private: + void PrintSummary(int total, int fail) { + fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); + fflush(stdout); + } +}; +} + +bool +ACTestSimple::Run() { + int total = 0; + int fail = 0; + + vector<TestingCase> *tests = Tests::Get_Tests(); + if (!tests) { + PrintSummary(0, 0); + return true; + } + + for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end(); + i != e; i++) { + TestingCase& t = *i; + int dict_len = t.dict_len; + unsigned int* strlen_v = new unsigned int[dict_len]; + + fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); + for (int i = 0, need_break=0; i < dict_len; i++) { + const char* s = t.dict[i]; + fprintf(stdout, "%s, ", s); + strlen_v[i] = strlen(s); + if (need_break++ == 16) { + fputs("\n ", stdout); + need_break = 0; + } + } + fputs("]\n", stdout); + + /* Create the dictionary */ + ac_t* ac = ac_create(t.dict, strlen_v, dict_len); + delete[] strlen_v; + + for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { + const StrPair& sp = t.strpairs[ii]; + const char *str = sp.str; // the string to be matched + const char *match = sp.match; + + fprintf(stdout, "[%3d] Testing '%s' : ", total, str); + + int len = strlen(str); + ac_result_t r; + if (t.match_variant == MV_FIRST_MATCH) + r = ac_match(ac, str, len); + else if (t.match_variant == MV_LEFT_LONGEST) + r = ac_match_longest_l(ac, str, len); + else { + ASSERT(false && "Unknown variant"); + } + + int m_b = r.match_begin; + int m_e = r.match_end; + + // The return value per se is insane. + if (m_b > m_e || + ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { + fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); + fail ++; + continue; + } + + // If the string is not supposed to match the dictionary. + if (!match) { + if (m_b != -1 || m_e != -1) { + fail ++; + fprintf(stdout, "Not Supposed to match (%d, %d) \n", + m_b, m_e); + } else + fputs("Pass\n", stdout); + continue; + } + + // The string or its substring is match the dict. + if (m_b >= len || m_b >= len) { + fail ++; + fprintf(stdout, + "Return value >= the length of the string (%d, %d)\n", + m_b, m_e); + continue; + } else { + int mlen = strlen(match); + if ((mlen != m_e - m_b + 1) || + strncmp(str + m_b, match, mlen)) { + fail ++; + fprintf(stdout, "Fail\n"); + } else + fprintf(stdout, "Pass\n"); + } + } + fputs("\n", stdout); + ac_free(ac); + } + + PrintSummary(total, fail); + return fail == 0; +} + +bool +Run_AC_Simple_Test() { + ACTestSimple t("AC Simple test"); + t.PrintBanner(); + return t.Run(); +} + +////////////////////////////////////////////////////////////////////////////// +// +// Testing cases for first-match variant (i.e. test ac_match()) +// +////////////////////////////////////////////////////////////////////////////// +// + +/* test 1*/ +const char *dict1[] = {"he", "she", "his", "her"}; +StrPair strpair1[] = { + {"he", "he"}, {"she", "she"}, {"his", "his"}, + {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, + {"shis2", "his"}, {"ahhe", "he"} +}; +Tests test1("test 1", + dict1, sizeof(dict1)/sizeof(dict1[0]), + strpair1, sizeof(strpair1)/sizeof(strpair1[0])); + +/* test 2*/ +const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ +StrPair strpair2[] = {{"The pot had a handle", 0}}; +Tests test2("test 2", dict2, 2, strpair2, 1); + +/* test 3*/ +const char *dict3[] = {"The"}; +StrPair strpair3[] = {{"The pot had a handle", "The"}}; +Tests test3("test 3", dict3, 1, strpair3, 1); + +/* test 4*/ +const char *dict4[] = {"pot"}; +StrPair strpair4[] = {{"The pot had a handle", "pot"}}; +Tests test4("test 4", dict4, 1, strpair4, 1); + +/* test 5*/ +const char *dict5[] = {"pot "}; +StrPair strpair5[] = {{"The pot had a handle", "pot "}}; +Tests test5("test 5", dict5, 1, strpair5, 1); + +/* test 6*/ +const char *dict6[] = {"ot h"}; +StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; +Tests test6("test 6", dict6, 1, strpair6, 1); + +/* test 7*/ +const char *dict7[] = {"andle"}; +StrPair strpair7[] = {{"The pot had a handle", "andle"}}; +Tests test7("test 7", dict7, 1, strpair7, 1); + +const char *dict8[] = {"aaab"}; +StrPair strpair8[] = {{"aaaaaaab", "aaab"}}; +Tests test8("test 8", dict8, 1, strpair8, 1); + +const char *dict9[] = {"haha", "z"}; +StrPair strpair9[] = {{"aaaaz", "z"}, {"z", "z"}}; +Tests test9("test 9", dict9, 2, strpair9, 2); + +/* test the case when input string dosen't contain even a single char + * of the pattern in dictionary. + */ +const char *dict10[] = {"abc"}; +StrPair strpair10[] = {{"cde", 0}}; +Tests test10("test 10", dict10, 1, strpair10, 1); + + +////////////////////////////////////////////////////////////////////////////// +// +// Testing cases for first longest match variant (i.e. +// test ac_match_longest_l()) +// +////////////////////////////////////////////////////////////////////////////// +// + +// This was actually first motivation for left-longest-match +const char *dict100[] = {"Mozilla", "Mozilla Mobile"}; +StrPair strpair100[] = {{"User Agent containing string Mozilla Mobile", "Mozilla Mobile"}}; +LeftLongestTests test100("l_test 100", dict100, 2, strpair100, 1); + +// Dict with single char is tricky +const char *dict101[] = {"a", "abc"}; +StrPair strpair101[] = {{"abcdef", "abc"}}; +LeftLongestTests test101("l_test 101", dict101, 2, strpair101, 1); + +// Testing case with partially overlapping patterns. The purpose is to +// check if the fail-link leading from terminal state is correct. +// +// The fail-link leading from terminal-state does not matter in +// match-first-occurrence variant, as it stop when a terminal is hit. +// +const char *dict102[] = {"abc", "bcdef"}; +StrPair strpair102[] = {{"abcdef", "bcdef"}}; +LeftLongestTests test102("l_test 102", dict102, 2, strpair102, 1); diff --git a/modules/policy/lua-aho-corasick/tests/dict/README.txt b/modules/policy/lua-aho-corasick/tests/dict/README.txt new file mode 100644 index 0000000..cd50b41 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/dict/README.txt @@ -0,0 +1 @@ +This directory contains pattern set of benchmark purpose. diff --git a/modules/policy/lua-aho-corasick/tests/dict/dict1.txt b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt new file mode 100644 index 0000000..94085a9 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt @@ -0,0 +1,11 @@ +false_return@ +forloop#haha +wtfprogram +mmaporunmap +ThIs?Module!IsEssential +struct rtlwtf +gettIMEOfdayWrong +edistribution_and_use_in_@source +Copyright~#@ +while {! +!%SQLinje diff --git a/modules/policy/lua-aho-corasick/tests/load_ac_test.lua b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua new file mode 100644 index 0000000..7fb7db9 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua @@ -0,0 +1,82 @@ +-- This script is to test load_ac.lua +-- +-- Some notes: +-- 1. The purpose of this script is not to check if the libac.so work +-- properly, it is to check if there are something stupid in load_ac.lua +-- +-- 2. There are bunch of collectgarbage() calls, the purpose is to make +-- sure the shared lib is not unloaded after GC. + +-- load_ac.lua looks up libac.so via package.cpath rather than LD_LIBRARY_PATH, +-- prepend (instead of appending) some insane paths here to see if it quit +-- prematurely. +-- +package.cpath = ".;./?.so;" .. package.cpath + +local ac = require "load_ac" + +local ac_create = ac.create_ac +local ac_match = ac.match +local string_fmt = string.format +local string_sub = string.sub + +local err_cnt = 0 +local function mytest(testname, dict, match, notmatch) + print(">Testing ", testname) + + io.write(string_fmt("Dictionary: ")); + for i=1, #dict do + io.write(string_fmt("%s, ", dict[i])) + end + print "" + + local ac_inst = ac_create(dict); + collectgarbage() + for i=1, #match do + local str = match[i] + io.write(string_fmt("Matching %s, ", str)) + local b = ac_match(ac_inst, str) + if b then + print "pass" + else + err_cnt = err_cnt + 1 + print "fail" + end + collectgarbage() + end + + if notmatch == nil then + return + end + + collectgarbage() + + for i = 1, #notmatch do + local str = notmatch[i] + io.write(string_fmt("*Matching %s, ", str)) + local r = ac_match(ac_inst, str) + if r then + err_cnt = err_cnt + 1 + print("fail") + else + print("succ") + end + collectgarbage() + end + ac_inst = nil + collectgarbage() +end + +print("") +print("====== Test to see if load_ac.lua works properly ========") + +mytest("test1", + {"he", "she", "his", "her", "str\0ing"}, + -- matching cases + { "he", "she", "his", "hers", "ahe", "shhe", "shis2", "ahhe", "str\0ing" }, + + -- not matching case + {"str\0", "str"} + ) + +os.exit((err_cnt == 0) and 0 or 1) diff --git a/modules/policy/lua-aho-corasick/tests/lua_test.lua b/modules/policy/lua-aho-corasick/tests/lua_test.lua new file mode 100644 index 0000000..cfe178f --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/lua_test.lua @@ -0,0 +1,67 @@ +-- This script is to test ahocorasick.so not libac.so +-- +local ac = require "ahocorasick" + +local ac_create = ac.create +local ac_match = ac.match +local string_fmt = string.format +local string_sub = string.sub + +local err_cnt = 0 +local function mytest(testname, dict, match, notmatch) + print(">Testing ", testname) + + io.write(string_fmt("Dictionary: ")); + for i=1, #dict do + io.write(string_fmt("%s, ", dict[i])) + end + print "" + + local ac_inst = ac_create(dict); + for i=1, #match do + local str = match[i][1] + local substr = match[i][2] + io.write(string_fmt("Matching %s, ", str)) + local b, e = ac_match(ac_inst, str) + if b and e and (string_sub(str, b+1, e+1) == substr) then + print "pass" + else + err_cnt = err_cnt + 1 + print "fail" + end + --print("gc is called") + collectgarbage() + end + + if notmatch == nil then + return + end + + for i = 1, #notmatch do + local str = notmatch[i] + io.write(string_fmt("*Matching %s, ", str)) + local r = ac_match(ac_inst, str) + if r then + err_cnt = err_cnt + 1 + print("fail") + else + print("succ") + end + collectgarbage() + end +end + +mytest("test1", + {"he", "she", "his", "her", "str\0ing"}, + -- matching cases + { {"he", "he"}, {"she", "she"}, {"his", "his"}, {"hers", "he"}, + {"ahe", "he"}, {"shhe", "he"}, {"shis2", "his"}, {"ahhe", "he"}, + {"str\0ing", "str\0ing"} + }, + + -- not matching case + {"str\0", "str"} + + ) + +os.exit((err_cnt == 0) and 0 or 1) diff --git a/modules/policy/lua-aho-corasick/tests/test_base.hpp b/modules/policy/lua-aho-corasick/tests/test_base.hpp new file mode 100644 index 0000000..7758371 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_base.hpp @@ -0,0 +1,60 @@ +#ifndef TEST_BASE_H +#define TEST_BASE_H + +#include <stdio.h> +#include <string> +#include <stdint.h> + +using namespace std; +class ACTestBase { +public: + ACTestBase(const char* name) :_banner(name) {} + virtual void PrintBanner() { + fprintf(stdout, "\n===== %s ====\n", _banner.c_str()); + } + + virtual bool Run() = 0; +private: + string _banner; +}; + +typedef std::pair<const char*, int> StrInfo; +class BigFileTester { +public: + BigFileTester(const char* filepath); + virtual ~BigFileTester() { Cleanup(); } + + bool Run(); + +protected: + virtual buf_header_t* PM_Create(const char** strv, uint32_t* strlenv, + uint32_t vect_len) = 0; + virtual void PM_Free(buf_header_t*) = 0; + virtual bool Run_Helper(buf_header_t* PM) = 0; + + // Return true if the '\0' is valid char of a string. + virtual bool Str_C_Style() { return true; } + + bool GenerateKeys(); + void Cleanup(); + void PrintStr(FILE*, const char* str, int len); + +protected: + const char* _filepath; + int _fd; + vector<StrInfo> _keys; + char* _msg; + int _msg_len; + int _key_num; // number of strings in dictionary + int _chunk_sz; + int _chunk_num; + + int _max_key_num; + int _key_min_len; + int _key_max_len; +}; + +extern bool Run_AC_Simple_Test(); +extern bool Run_AC_Aggressive_Test(const vector<const char*>& files); + +#endif diff --git a/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx new file mode 100644 index 0000000..f189d8d --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx @@ -0,0 +1,167 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +/////////////////////////////////////////////////////////////////////////// +// +// Implementation of BigFileTester +// +/////////////////////////////////////////////////////////////////////////// +// +BigFileTester::BigFileTester(const char* filepath) { + _filepath = filepath; + _fd = -1; + _msg = (char*)MAP_FAILED; + _msg_len = 0; + _key_num = 0; + _chunk_sz = 0; + _chunk_num = 0; + + _max_key_num = 100; + _key_min_len = 20; + _key_max_len = 80; +} + +void +BigFileTester::Cleanup() { + if (_msg != MAP_FAILED) { + munmap((void*)_msg, _msg_len); + _msg = (char*)MAP_FAILED; + _msg_len = 0; + } + + if (_fd != -1) { + close(_fd); + _fd = -1; + } +} + +bool +BigFileTester::GenerateKeys() { + int chunk_sz = 4096; + int max_key_num = _max_key_num; + int key_min_len = _key_min_len; + int key_max_len = _key_max_len; + + int t = _msg_len / chunk_sz; + int keynum = t > max_key_num ? max_key_num : t; + + if (keynum <= 4) { + // file is too small + return false; + } + chunk_sz = _msg_len / keynum; + _chunk_sz = chunk_sz; + + // For each chunck, "randomly" grab a sub-string searving + // as key. + int random_ofst[] = { 12, 30, 23, 15 }; + int rofstsz = sizeof(random_ofst)/sizeof(random_ofst[0]); + int ofst = 0; + const char* msg = _msg; + _chunk_num = keynum - 1; + for (int idx = 0, e = _chunk_num; idx < e; idx++) { + const char* key = msg + ofst + idx % rofstsz; + int key_len = key_min_len + idx % (key_max_len - key_min_len); + _keys.push_back(StrInfo(key, key_len)); + ofst += chunk_sz; + } + return true; +} + +bool +BigFileTester::Run() { + // Step 1: Bring the file into memory + fprintf(stdout, "Testing using file '%s'...\n", _filepath); + + int fd = _fd = ::open(_filepath, O_RDONLY); + if (fd == -1) { + perror("open"); + return false; + } + + struct stat sb; + if (fstat(fd, &sb) == -1) { + perror("fstat"); + return false; + } + + if (!S_ISREG (sb.st_mode)) { + fprintf(stderr, "%s is not regular file\n", _filepath); + return false; + } + + int ten_M = 1024 * 1024 * 10; + int map_sz = _msg_len = sb.st_size > ten_M ? ten_M : sb.st_size; + char* p = _msg = + (char*)mmap (0, map_sz, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); + if (p == MAP_FAILED) { + perror("mmap"); + return false; + } + + // Get rid of '\0' if we are picky at it. + if (Str_C_Style()) { + for (int i = 0; i < map_sz; i++) { if (!p[i]) p[i] = 'a'; } + p[map_sz - 1] = 0; + } + + // Step 2: "Fabricate" some keys from the file. + if (!GenerateKeys()) { + close(fd); + return false; + } + + // Step 3: Create PM instance + const char** keys = new const char*[_keys.size()]; + unsigned int* keylens = new unsigned int[_keys.size()]; + + int i = 0; + for (vector<StrInfo>::iterator si = _keys.begin(), se = _keys.end(); + si != se; si++, i++) { + const StrInfo& strinfo = *si; + keys[i] = strinfo.first; + keylens[i] = strinfo.second; + } + + buf_header_t* PM = PM_Create(keys, keylens, i); + delete[] keys; + delete[] keylens; + + // Step 4: Run testing + bool res = Run_Helper(PM); + PM_Free(PM); + + // Step 5: Clanup + munmap(p, map_sz); + _msg = (char*)MAP_FAILED; + close(fd); + _fd = -1; + + fprintf(stdout, "%s\n", res ? "succ" : "fail"); + return res; +} + +void +BigFileTester::PrintStr(FILE* f, const char* str, int len) { + fprintf(f, "{"); + for (int i = 0; i < len; i++) { + unsigned char c = str[i]; + if (isprint(c)) + fprintf(f, "'%c', ", c); + else + fprintf(f, "%#x, ", c); + } + fprintf(f, "}"); +}; diff --git a/modules/policy/lua-aho-corasick/tests/test_main.cxx b/modules/policy/lua-aho-corasick/tests/test_main.cxx new file mode 100644 index 0000000..b4f5225 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_main.cxx @@ -0,0 +1,33 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + + +///////////////////////////////////////////////////////////////////////// +// +// Simple (yet maybe tricky) testings +// +///////////////////////////////////////////////////////////////////////// +// +int +main (int argc, char** argv) { + bool succ = Run_AC_Simple_Test(); + + vector<const char*> files; + for (int i = 1; i < argc; i++) { files.push_back(argv[i]); } + succ = Run_AC_Aggressive_Test(files) && succ; + + return succ ? 0 : -1; +}; diff --git a/modules/policy/noipv6.test.integr/broken-ipv6.rpl b/modules/policy/noipv6.test.integr/broken-ipv6.rpl new file mode 100644 index 0000000..499974d --- /dev/null +++ b/modules/policy/noipv6.test.integr/broken-ipv6.rpl @@ -0,0 +1,46 @@ +; config options + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. +CONFIG_END + +SCENARIO_BEGIN Test that IPv6 is not used by kresd. + +RANGE_BEGIN 0 100 + ADDRESS ::1:2:3:4 +RANGE_END + +RANGE_BEGIN 0 100 + ADDRESS 1.2.3.4 + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +www.test.org A +SECTION ANSWER +www.test.org 3600 A 4.3.2.1 +ENTRY_END + +RANGE_END + + +STEP 10 QUERY +ENTRY_BEGIN +REPLY RD AD +SECTION QUESTION +www.test.org A +ENTRY_END + +STEP 20 CHECK_ANSWER +ENTRY_BEGIN +MATCH all answer +REPLY QR RD RA NOERROR +SECTION QUESTION +www.test.org A +SECTION ANSWER +www.test.org 3600 A 4.3.2.1 +SECTION AUTHORITY +SECTION ADDITIONAL +ENTRY_END + +SCENARIO_END diff --git a/modules/policy/noipv6.test.integr/deckard.yaml b/modules/policy/noipv6.test.integr/deckard.yaml new file mode 100644 index 0000000..c353b1c --- /dev/null +++ b/modules/policy/noipv6.test.integr/deckard.yaml @@ -0,0 +1,12 @@ +programs: +- name: kresd + binary: kresd + additional: + - -f + - "1" + templates: + - modules/policy/noipv6.test.integr/kresd_config.j2 + - tests/hints_zone.j2 + configs: + - config + - hints diff --git a/modules/policy/noipv6.test.integr/kresd_config.j2 b/modules/policy/noipv6.test.integr/kresd_config.j2 new file mode 100644 index 0000000..56cdb01 --- /dev/null +++ b/modules/policy/noipv6.test.integr/kresd_config.j2 @@ -0,0 +1,50 @@ +{% raw %} +net.ipv6 = false +policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' }))) + +-- Disable RFC8145 signaling, scenario doesn't provide expected answers +if ta_signal_query then + modules.unload('ta_signal_query') +end + +-- Disable RFC8109 priming, scenario doesn't provide expected answers +if priming then + modules.unload('priming') +end + +-- Disable this module because it make one priming query +if detect_time_skew then + modules.unload('detect_time_skew') +end + +_hint_root_file('hints') +cache.size = 2*MB +verbose(true) +{% endraw %} + +net = { '{{SELF_ADDR}}' } + + +{% if QMIN == "false" %} +option('NO_MINIMIZE', true) +{% else %} +option('NO_MINIMIZE', false) +{% endif %} + + +-- Self-checks on globals +assert(help() ~= nil) +assert(worker.id ~= nil) +-- Self-checks on facilities +assert(cache.count() == 0) +assert(cache.stats() ~= nil) +assert(cache.backends() ~= nil) +assert(worker.stats() ~= nil) +assert(net.interfaces() ~= nil) +-- Self-checks on loaded stuff +assert(net.list()['{{SELF_ADDR}}']) +assert(#modules.list() > 0) +-- Self-check timers +ev = event.recurrent(1 * sec, function (ev) return 1 end) +event.cancel(ev) +ev = event.after(0, function (ev) return 1 end) diff --git a/modules/policy/noipvx.test.integr/broken-ipvx.rpl b/modules/policy/noipvx.test.integr/broken-ipvx.rpl new file mode 100644 index 0000000..9bb9a8c --- /dev/null +++ b/modules/policy/noipvx.test.integr/broken-ipvx.rpl @@ -0,0 +1,34 @@ +; config options + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. +CONFIG_END + +SCENARIO_BEGIN Test that neither IPv6 nor IPv4 is used by kresd :-) + +RANGE_BEGIN 0 100 + ADDRESS ::1:2:3:4 +RANGE_END + +RANGE_BEGIN 0 100 + ADDRESS 1.2.3.4 +RANGE_END + + +STEP 10 QUERY +ENTRY_BEGIN +REPLY RD AD +SECTION QUESTION +www.test.org A +ENTRY_END + +STEP 20 CHECK_ANSWER +ENTRY_BEGIN +MATCH all answer +REPLY QR RD RA SERVFAIL +SECTION QUESTION +www.test.org A +SECTION ANSWER +SECTION AUTHORITY +SECTION ADDITIONAL +ENTRY_END + +SCENARIO_END diff --git a/modules/policy/noipvx.test.integr/deckard.yaml b/modules/policy/noipvx.test.integr/deckard.yaml new file mode 100644 index 0000000..a71a56e --- /dev/null +++ b/modules/policy/noipvx.test.integr/deckard.yaml @@ -0,0 +1,12 @@ +programs: +- name: kresd + binary: kresd + additional: + - -f + - "1" + templates: + - modules/policy/noipvx.test.integr/kresd_config.j2 + - tests/hints_zone.j2 + configs: + - config + - hints diff --git a/modules/policy/noipvx.test.integr/kresd_config.j2 b/modules/policy/noipvx.test.integr/kresd_config.j2 new file mode 100644 index 0000000..4b5b576 --- /dev/null +++ b/modules/policy/noipvx.test.integr/kresd_config.j2 @@ -0,0 +1,51 @@ +{% raw %} +net.ipv4 = false +net.ipv6 = false +policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' }))) + +-- Disable RFC8145 signaling, scenario doesn't provide expected answers +if ta_signal_query then + modules.unload('ta_signal_query') +end + +-- Disable RFC8109 priming, scenario doesn't provide expected answers +if priming then + modules.unload('priming') +end + +-- Disable this module because it make one priming query +if detect_time_skew then + modules.unload('detect_time_skew') +end + +_hint_root_file('hints') +cache.size = 2*MB +verbose(true) +{% endraw %} + +net = { '{{SELF_ADDR}}' } + + +{% if QMIN == "false" %} +option('NO_MINIMIZE', true) +{% else %} +option('NO_MINIMIZE', false) +{% endif %} + + +-- Self-checks on globals +assert(help() ~= nil) +assert(worker.id ~= nil) +-- Self-checks on facilities +assert(cache.count() == 0) +assert(cache.stats() ~= nil) +assert(cache.backends() ~= nil) +assert(worker.stats() ~= nil) +assert(net.interfaces() ~= nil) +-- Self-checks on loaded stuff +assert(net.list()['{{SELF_ADDR}}']) +assert(#modules.list() > 0) +-- Self-check timers +ev = event.recurrent(1 * sec, function (ev) return 1 end) +event.cancel(ev) +ev = event.after(0, function (ev) return 1 end) diff --git a/modules/policy/policy.lua b/modules/policy/policy.lua new file mode 100644 index 0000000..6e9492f --- /dev/null +++ b/modules/policy/policy.lua @@ -0,0 +1,756 @@ +local kres = require('kres') +local ffi = require('ffi') + +local todname = kres.str2dname -- not available during module load otherwise + +-- Counter of unique rules +local nextid = 0 +local function getruleid() + local newid = nextid + nextid = nextid + 1 + return newid +end + +-- Support for client sockets from inside policy actions +local socket_client = function () return error("missing luasocket, can't create socket client") end +local has_socket, socket = pcall(require, 'socket') +if has_socket then + socket_client = function (host, port) + local s, err, status + if host:find(':') then + s, err = socket.udp6() + else + s, err = socket.udp() + end + if not s then + return nil, err + end + status, err = s:setpeername(host, port) + if not status then + return nil, err + end + return s + end +end + +local function addr_split_port(target, default_port) + assert(default_port) + assert(type(default_port) == 'number') + local addr, port = target:match '([^@]*)@?(.*)' + local nport + if port ~= "" then + nport = tonumber(port) + if not nport or nport < 1 or nport > 65535 then + error('port "'.. port ..'" is not valid') + end + end + return addr, nport or default_port +end + +-- String address@port -> sockaddr. +local function addr2sock(target, default_port) + local addr, port = addr_split_port(target, default_port) + local sock = ffi.gc(ffi.C.kr_straddr_socket(addr, port), ffi.C.free); + if sock == nil then + error("target '"..target..'" is not a valid IP address') + end + return sock +end + +-- policy functions are defined below +local policy = {} + +function policy.PASS(state, _) + return state +end + +-- Mirror request elsewhere, and continue solving +function policy.MIRROR(target) + local addr, port = addr_split_port(target, 53) + local sink, err = socket_client(addr, port) + if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end + return function(state, req) + if state == kres.FAIL then return state end + local query = req.qsource.packet + if query ~= nil then + sink:send(ffi.string(query.wire, query.size)) + end + return -- Chain action to next + end +end + +-- Override the list of nameservers (forwarders) +local function set_nslist(qry, list) + local ns_i = 0 + for _, ns in ipairs(list) do + -- kr_nsrep_set() can return kr_error(ENOENT), it's OK + if ffi.C.kr_nsrep_set(qry, ns_i, ns) == 0 then + ns_i = ns_i + 1 + end + end + -- If less than maximum NSs, insert guard to terminate the list + if ns_i < 3 then + assert(ffi.C.kr_nsrep_set(qry, ns_i, nil) == 0); + end + if ns_i == 0 then + -- would use assert() but don't want to compose the message if not triggered + error('no usable address in NS set (check net.ipv4 and ' + .. 'net.ipv6 config):\n' .. table_print(list, 2)) + end +end + +-- Forward request, and solve as stub query +function policy.STUB(target) + local list = {} + if type(target) == 'table' then + for _, v in pairs(target) do + table.insert(list, addr2sock(v, 53)) + assert(#list <= 4, 'at most 4 STUB targets are supported') + end + else + table.insert(list, addr2sock(target, 53)) + end + return function(state, req) + local qry = req:current() + -- Switch mode to stub resolver, do not track origin zone cut since it's not real authority NS + qry.flags.STUB = true + qry.flags.ALWAYS_CUT = false + set_nslist(qry, list) + return state + end +end + +-- Forward request and all subrequests to upstream; validate answers +function policy.FORWARD(target) + local list = {} + if type(target) == 'table' then + for _, v in pairs(target) do + table.insert(list, addr2sock(v, 53)) + assert(#list <= 4, 'at most 4 FORWARD targets are supported') + end + else + table.insert(list, addr2sock(target, 53)) + end + return function(state, req) + local qry = req:current() + req.options.FORWARD = true + req.options.NO_MINIMIZE = true + qry.flags.FORWARD = true + qry.flags.ALWAYS_CUT = false + qry.flags.NO_MINIMIZE = true + qry.flags.AWAIT_CUT = true + set_nslist(qry, list) + return state + end +end + +-- object must be non-empty string or non-empty table of non-empty strings +local function is_nonempty_string_or_table(object) + if type(object) == 'string' then + return #object ~= 0 + elseif type(object) ~= 'table' or not next(object) then + return false + end + for _, val in pairs(object) do + if type(val) ~= 'string' or #val == 0 then + return false + end + end + return true +end + +local function insert_from_string_or_table(source, destination) + if type(source) == 'table' then + for _, v in pairs(source) do + table.insert(destination, v) + end + else + table.insert(destination, source) + end +end + +-- Check for allowed authentication types and return type for the current target +local function tls_forward_target_authtype(idx, target) + if (target.pin_sha256 and not (target.ca_file or target.hostname or target.insecure)) then + if not is_nonempty_string_or_table(target.pin_sha256) then + error('TLS_FORWARD target authentication is invalid at position ' + .. idx .. '; pin_sha256 must be string or list of strings') + end + return 'pin_sha256' + elseif (target.insecure and not (target.ca_file or target.hostname or target.pin_sha256)) then + return 'insecure' + elseif (target.hostname and not (target.insecure or target.pin_sha256)) then + if not (is_nonempty_string_or_table(target.hostname)) then + error('TLS_FORWARD target authentication is invalid at position ' + .. idx .. '; hostname must be string or list of strings') + end + -- if target.ca_file is empty, system CA will be used + return 'cert' + else + error('TLS_FORWARD authentication options at position ' .. idx + .. ' are invalid; specify one of: pin_sha256 / hostname [+ca_file] / insecure') + end +end + +local function tls_forward_target_check_syntax(idx, list_entry) + if type(list_entry) ~= 'table' then + error('TLS_FORWARD target must be a non-empty table (found ' + .. type(list_entry) .. ' at position ' .. idx .. ')') + end + if type(list_entry[1]) ~= 'string' then + error('TLS_FORWARD target must start with an IP address (found ' + .. type(list_entry[1]) .. ' at the beginning of target position ' .. idx .. ')') + end +end + +-- Forward request and all subrequests to upstream over TLS; validate answers +function policy.TLS_FORWARD(target) + local sockaddr_c_list = {} + local sockaddr_config = {} -- items: { string_addr=<addr string>, auth_type=<auth type> } + local ca_files = {} + local hostnames = {} + local pins = {} + if type(target) ~= 'table' or #target < 1 then + error('TLS_FORWARD argument must be a non-empty table') + end + for idx, upstream_list_entry in pairs(target) do + tls_forward_target_check_syntax(idx, upstream_list_entry) + local auth_type = tls_forward_target_authtype(idx, upstream_list_entry) + local string_addr = upstream_list_entry[1] + local sockaddr_c = addr2sock(string_addr, 853) + local sockaddr_lua = ffi.string(sockaddr_c, ffi.C.kr_sockaddr_len(sockaddr_c)) + if sockaddr_config[sockaddr_lua] then + error('TLS_FORWARD configuration cannot declare two configs for IP address ' .. string_addr) + end + table.insert(sockaddr_c_list, sockaddr_c) + sockaddr_config[sockaddr_lua] = {string_addr=string_addr, auth_type=auth_type} + if auth_type == 'cert' then + ca_files[sockaddr_lua] = {} + hostnames[sockaddr_lua] = {} + insert_from_string_or_table(upstream_list_entry.ca_file, ca_files[sockaddr_lua]) + insert_from_string_or_table(upstream_list_entry.hostname, hostnames[sockaddr_lua]) + elseif auth_type == 'pin_sha256' then + pins[sockaddr_lua] = {} + insert_from_string_or_table(upstream_list_entry.pin_sha256, pins[sockaddr_lua]) + elseif auth_type ~= 'insecure' then + -- insecure does nothing, user does not want authentication + assert(false, 'unsupported auth_type') + end + end + + -- Update the global table of authentication data only if all checks above passed + for sockaddr_lua, config in pairs(sockaddr_config) do + assert(#config.string_addr > 0) + if config.auth_type == 'insecure' then + net.tls_client(config.string_addr) + elseif config.auth_type == 'pin_sha256' then + assert(#pins[sockaddr_lua] > 0) + net.tls_client(config.string_addr, pins[sockaddr_lua]) + elseif config.auth_type == 'cert' then + assert(#hostnames[sockaddr_lua] > 0) + net.tls_client(config.string_addr, ca_files[sockaddr_lua], hostnames[sockaddr_lua]) + else + assert(false, 'unsupported auth_type') + end + end + + return function(state, req) + local qry = req:current() + req.options.FORWARD = true + req.options.NO_MINIMIZE = true + qry.flags.FORWARD = true + qry.flags.ALWAYS_CUT = false + qry.flags.NO_MINIMIZE = true + qry.flags.AWAIT_CUT = true + req.options.TCP = true + qry.flags.TCP = true + set_nslist(qry, sockaddr_c_list) + return state + end +end + +-- Rewrite records in packet +function policy.REROUTE(tbl, names) + -- Import renumbering rules + local ren = require('renumber') + local prefixes = {} + for from, to in pairs(tbl) do + table.insert(prefixes, names and ren.name(from, to) or ren.prefix(from, to)) + end + -- Return rule closure + return ren.rule(prefixes) +end + +-- Set and clear some query flags +function policy.FLAGS(opts_set, opts_clear) + return function(_, req) + local qry = req:current() + ffi.C.kr_qflags_set (qry.flags, kres.mk_qflags(opts_set or {})) + ffi.C.kr_qflags_clear(qry.flags, kres.mk_qflags(opts_clear or {})) + return nil -- chain rule + end +end + +local function mkauth_soa(answer, dname, mname) + if mname == nil then + mname = dname + end + return answer:put(dname, 10800, answer:qclass(), kres.type.SOA, + mname .. '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48') +end + +local dname_localhost = todname('localhost.') + +-- Rule for localhost. zone; see RFC6303, sec. 3 +local function localhost(_, req) + local qry = req:current() + local answer = req.answer + ffi.C.kr_pkt_make_auth_header(answer) + + local is_exact = ffi.C.knot_dname_is_equal(qry.sname, dname_localhost) + + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + if qry.stype == kres.type.AAAA then + answer:put(qry.sname, 900, answer:qclass(), kres.type.AAAA, + '\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1') + elseif qry.stype == kres.type.A then + answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\127\0\0\1') + elseif is_exact and qry.stype == kres.type.SOA then + mkauth_soa(answer, dname_localhost) + elseif is_exact and qry.stype == kres.type.NS then + answer:put(dname_localhost, 900, answer:qclass(), kres.type.NS, dname_localhost) + else + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, dname_localhost) + end + return kres.DONE +end + +local dname_rev4_localhost = todname('1.0.0.127.in-addr.arpa'); +local dname_rev4_localhost_apex = todname('127.in-addr.arpa'); + +-- Rule for reverse localhost. +-- Answer with locally served minimal 127.in-addr.arpa domain, only having +-- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone. +-- TODO: much of this would better be left to the hints module (or coordinated). +local function localhost_reversed(_, req) + local qry = req:current() + local answer = req.answer + + -- classify qry.sname: + local is_exact -- exact dname for localhost + local is_apex -- apex of a locally-served localhost zone + local is_nonterm -- empty non-terminal name + if ffi.C.knot_dname_in_bailiwick(qry.sname, todname('ip6.arpa.')) > 0 then + -- exact ::1 query (relying on the calling rule) + is_exact = true + is_apex = true + else + -- within 127.in-addr.arpa. + local labels = ffi.C.knot_dname_labels(qry.sname, nil) + if labels == 3 then + is_exact = false + is_apex = true + elseif labels == 4+2 and ffi.C.knot_dname_is_equal( + qry.sname, dname_rev4_localhost) then + is_exact = true + else + is_exact = false + is_apex = false + is_nonterm = ffi.C.knot_dname_in_bailiwick(dname_rev4_localhost, qry.sname) > 0 + end + end + + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + if is_exact and qry.stype == kres.type.PTR then + answer:put(qry.sname, 900, answer:qclass(), kres.type.PTR, dname_localhost) + elseif is_apex and qry.stype == kres.type.SOA then + mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost) + elseif is_apex and qry.stype == kres.type.NS then + answer:put(dname_rev4_localhost_apex, 900, answer:qclass(), kres.type.NS, + dname_localhost) + else + if not is_nonterm then + answer:rcode(kres.rcode.NXDOMAIN) + end + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost) + end + return kres.DONE +end + +-- All requests +function policy.all(action) + return function(_, _) return action end +end + +-- Requests which QNAME matches given zone list (i.e. suffix match) +function policy.suffix(action, zone_list) + local AC = require('ahocorasick') + local tree = AC.create(zone_list) + return function(_, query) + local match = AC.match(tree, query:name(), false) + if match ~= nil then + return action + end + return nil + end +end + +-- Check for common suffix first, then suffix match (specialized version of suffix match) +function policy.suffix_common(action, suffix_list, common_suffix) + local common_len = string.len(common_suffix) + local suffix_count = #suffix_list + return function(_, query) + -- Preliminary check + local qname = query:name() + if not string.find(qname, common_suffix, -common_len, true) then + return nil + end + -- String match + for i = 1, suffix_count do + local zone = suffix_list[i] + if string.find(qname, zone, -string.len(zone), true) then + return action + end + end + return nil + end +end + +-- Filter QNAME pattern +function policy.pattern(action, pattern) + return function(_, query) + if string.find(query:name(), pattern) then + return action + end + return nil + end +end + +local function rpz_parse(action, path) + local rules = {} + local action_map = { + -- RPZ Policy Actions + ['\0'] = action, + ['\1*\0'] = action, -- deviates from RPZ spec + ['\012rpz-passthru\0'] = policy.PASS, -- the grammar... + ['\008rpz-drop\0'] = policy.DROP, + ['\012rpz-tcp-only\0'] = policy.TC, + -- Policy triggers @NYI@ + } + local parser = require('zonefile').new() + if not parser:open(path) then error(string.format('failed to parse "%s"', path)) end + while parser:parse() do + local name = ffi.string(parser.r_owner, parser.r_owner_length) + local name_action = ffi.string(parser.r_data, parser.r_data_length) + rules[name] = action_map[name_action] + -- Warn when NYI + if #name > 1 and not action_map[name_action] then + print(string.format('[ rpz ] %s:%d: unsupported policy action', path, tonumber(parser.line_counter))) + end + end + return rules +end + +-- RPZ policy set +-- Create RPZ from zone file +function policy.rpz(action, path) + local rules = rpz_parse(action, path) + collectgarbage() + return function(_, query) + local label = query:name() + local rule = rules[label] + while rule == nil and string.len(label) > 0 do + label = string.sub(label, string.byte(label) + 2) + rule = rules['\1*'..label] + end + return rule + end +end + +function policy.DENY_MSG(msg) + if msg and (type(msg) ~= 'string' or #msg >= 255) then + error('DENY_MSG: optional msg must be string shorter than 256 characters') + end + + return function (_, req) + -- Write authority information + local answer = req.answer + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NXDOMAIN) + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, answer:qname()) + if msg then + answer:begin(kres.section.ADDITIONAL) + answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT, + string.char(#msg) .. msg) + + end + return kres.DONE + end +end +policy.DENY = policy.DENY_MSG() -- compatibility with < 2.0 + +function policy.DROP(_, _) + return kres.FAIL +end + +function policy.REFUSE(_, req) + local answer = req.answer + answer:rcode(kres.rcode.REFUSED) + answer:ad(false) + return kres.DONE +end + +function policy.TC(state, req) + local answer = req.answer + if answer.max_size ~= 65535 then + answer:tc(1) -- ^ Only UDP queries + answer:ad(false) + return kres.DONE + else + return state + end +end + +function policy.QTRACE(_, req) + local qry = req:current() + req.options.TRACE = true + qry.flags.TRACE = true + return -- this allows to continue iterating over policy list +end + +-- Evaluate packet in given rules to determine policy action +function policy.evaluate(rules, req, query, state) + for i = 1, #rules do + local rule = rules[i] + if not rule.suspended then + local action = rule.cb(req, query) + if action ~= nil then + rule.count = rule.count + 1 + local next_state = action(state, req) + if next_state then -- Not a chain rule, + return next_state -- stop on first match + end + end + end + end + return +end + +-- Top-down policy list walk until we hit a match +-- the caller is responsible for reordering policy list +-- from most specific to least specific. +-- Some rules may be chained, in this case they are evaluated +-- as a dependency chain, e.g. r1,r2,r3 -> r3(r2(r1(state))) +policy.layer = { + begin = function(state, req) + -- Don't act on "resolved" cases. + if bit.band(state, bit.bor(kres.FAIL, kres.DONE)) ~= 0 then return state end + + req = kres.request_t(req) + return policy.evaluate(policy.rules, req, req:current(), state) or + policy.evaluate(policy.special_names, req, req:current(), state) or + state + end, + finish = function(state, req) + -- Don't act on "resolved" cases. + if bit.band(state, bit.bor(kres.FAIL, kres.DONE)) ~= 0 then return state end + + req = kres.request_t(req) + return policy.evaluate(policy.postrules, req, req:current(), state) or state + end +} + +-- Add rule to policy list +function policy.add(rule, postrule) + -- Compatibility with 1.0.0 API + -- it will be dropped in 1.2.0 + if rule == policy then + rule = postrule + postrule = nil + end + -- End of compatibility shim + local desc = {id=getruleid(), cb=rule, count=0} + table.insert(postrule and policy.postrules or policy.rules, desc) + return desc +end + +-- Remove rule from a list +local function delrule(rules, id) + for i, r in ipairs(rules) do + if r.id == id then + table.remove(rules, i) + return true + end + end + return false +end + +-- Delete rule from policy list +function policy.del(id) + if not delrule(policy.rules, id) then + if not delrule(policy.postrules, id) then + return false + end + end + return true +end + +-- Convert list of string names to domain names +function policy.todnames(names) + for i, v in ipairs(names) do + names[i] = kres.str2dname(v) + end + return names +end + +-- RFC1918 Private, local, broadcast, test and special zones +-- Considerations: RFC6761, sec 6.1. +-- https://www.iana.org/assignments/locally-served-dns-zones +local private_zones = { + -- RFC6303 + '10.in-addr.arpa.', + '16.172.in-addr.arpa.', + '17.172.in-addr.arpa.', + '18.172.in-addr.arpa.', + '19.172.in-addr.arpa.', + '20.172.in-addr.arpa.', + '21.172.in-addr.arpa.', + '22.172.in-addr.arpa.', + '23.172.in-addr.arpa.', + '24.172.in-addr.arpa.', + '25.172.in-addr.arpa.', + '26.172.in-addr.arpa.', + '27.172.in-addr.arpa.', + '28.172.in-addr.arpa.', + '29.172.in-addr.arpa.', + '30.172.in-addr.arpa.', + '31.172.in-addr.arpa.', + '168.192.in-addr.arpa.', + '0.in-addr.arpa.', + '254.169.in-addr.arpa.', + '2.0.192.in-addr.arpa.', + '100.51.198.in-addr.arpa.', + '113.0.203.in-addr.arpa.', + '255.255.255.255.in-addr.arpa.', + -- RFC7793 + '64.100.in-addr.arpa.', + '65.100.in-addr.arpa.', + '66.100.in-addr.arpa.', + '67.100.in-addr.arpa.', + '68.100.in-addr.arpa.', + '69.100.in-addr.arpa.', + '70.100.in-addr.arpa.', + '71.100.in-addr.arpa.', + '72.100.in-addr.arpa.', + '73.100.in-addr.arpa.', + '74.100.in-addr.arpa.', + '75.100.in-addr.arpa.', + '76.100.in-addr.arpa.', + '77.100.in-addr.arpa.', + '78.100.in-addr.arpa.', + '79.100.in-addr.arpa.', + '80.100.in-addr.arpa.', + '81.100.in-addr.arpa.', + '82.100.in-addr.arpa.', + '83.100.in-addr.arpa.', + '84.100.in-addr.arpa.', + '85.100.in-addr.arpa.', + '86.100.in-addr.arpa.', + '87.100.in-addr.arpa.', + '88.100.in-addr.arpa.', + '89.100.in-addr.arpa.', + '90.100.in-addr.arpa.', + '91.100.in-addr.arpa.', + '92.100.in-addr.arpa.', + '93.100.in-addr.arpa.', + '94.100.in-addr.arpa.', + '95.100.in-addr.arpa.', + '96.100.in-addr.arpa.', + '97.100.in-addr.arpa.', + '98.100.in-addr.arpa.', + '99.100.in-addr.arpa.', + '100.100.in-addr.arpa.', + '101.100.in-addr.arpa.', + '102.100.in-addr.arpa.', + '103.100.in-addr.arpa.', + '104.100.in-addr.arpa.', + '105.100.in-addr.arpa.', + '106.100.in-addr.arpa.', + '107.100.in-addr.arpa.', + '108.100.in-addr.arpa.', + '109.100.in-addr.arpa.', + '110.100.in-addr.arpa.', + '111.100.in-addr.arpa.', + '112.100.in-addr.arpa.', + '113.100.in-addr.arpa.', + '114.100.in-addr.arpa.', + '115.100.in-addr.arpa.', + '116.100.in-addr.arpa.', + '117.100.in-addr.arpa.', + '118.100.in-addr.arpa.', + '119.100.in-addr.arpa.', + '120.100.in-addr.arpa.', + '121.100.in-addr.arpa.', + '122.100.in-addr.arpa.', + '123.100.in-addr.arpa.', + '124.100.in-addr.arpa.', + '125.100.in-addr.arpa.', + '126.100.in-addr.arpa.', + '127.100.in-addr.arpa.', + + -- RFC6303 + -- localhost_reversed handles ::1 + '0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.', + 'd.f.ip6.arpa.', + '8.e.f.ip6.arpa.', + '9.e.f.ip6.arpa.', + 'a.e.f.ip6.arpa.', + 'b.e.f.ip6.arpa.', + '8.b.d.0.1.0.0.2.ip6.arpa.', +} +policy.todnames(private_zones) + +-- @var Default rules +policy.rules = {} +policy.postrules = {} +policy.special_names = { + { + cb=policy.suffix_common(policy.DENY_MSG( + 'Blocking is mandated by standards, see references on ' + .. 'https://www.iana.org/assignments/' + .. 'locally-served-dns-zones/locally-served-dns-zones.xhtml'), + private_zones, todname('arpa.')), + count=0 + }, + { + cb=policy.suffix(policy.DENY_MSG( + 'Blocking is mandated by standards, see references on ' + .. 'https://www.iana.org/assignments/' + .. 'special-use-domain-names/special-use-domain-names.xhtml'), + { + todname('test.'), + todname('onion.'), + todname('invalid.'), + }), + count=0 + }, + { + cb=policy.suffix(localhost, {dname_localhost}), + count=0 + }, + { + cb=policy.suffix_common(localhost_reversed, { + todname('127.in-addr.arpa.'), + todname('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.')}, + todname('arpa.')), + count=0 + }, +} + +return policy diff --git a/modules/policy/policy.mk b/modules/policy/policy.mk new file mode 100644 index 0000000..98c9f88 --- /dev/null +++ b/modules/policy/policy.mk @@ -0,0 +1,15 @@ +AHOCORASICK_DIR = modules/policy/lua-aho-corasick/ + +policy_SOURCES := policy.lua +policy_DEPEND := $(AHOCORASICK_DIR)ahocorasick$(LIBEXT) +$(call make_lua_module,policy) + +policy-clean: + $(MAKE) -C $(AHOCORASICK_DIR) clean +$(AHOCORASICK_DIR)ahocorasick$(LIBEXT): $(AHOCORASICK_DIR)Makefile + $(MAKE) -C $(AHOCORASICK_DIR) ahocorasick$(LIBEXT) CXXFLAGS="$(lua_CFLAGS)" + +policy-install: ahocorasick-install +ahocorasick-install: $(AHOCORASICK_DIR)ahocorasick$(LIBEXT) $(DESTDIR)$(MODULEDIR) + $(INSTALL) -m 755 $(AHOCORASICK_DIR)ahocorasick$(LIBEXT) $(DESTDIR)$(MODULEDIR) + diff --git a/modules/policy/policy.test.lua b/modules/policy/policy.test.lua new file mode 100644 index 0000000..8780caa --- /dev/null +++ b/modules/policy/policy.test.lua @@ -0,0 +1,52 @@ +-- setup resolver +-- policy module should be loaded by default, do not load it explicitly + +-- test for default configuration +local function test_tls_forward() + boom(policy.TLS_FORWARD, {}, 'TLS_FORWARD without arguments') + boom(policy.TLS_FORWARD, {'1'}, 'TLS_FORWARD with non-table argument') + boom(policy.TLS_FORWARD, {{}}, 'TLS_FORWARD with empty table') + boom(policy.TLS_FORWARD, {{{}}}, 'TLS_FORWARD with empty target table') + boom(policy.TLS_FORWARD, {{{bleble=''}}}, 'TLS_FORWARD with invalid parameters in table') + + boom(policy.TLS_FORWARD, {{'1'}}, 'TLS_FORWARD with invalid IP address') + -- boom(policy.TLS_FORWARD, {{{'::1', bleble=''}}}, 'TLS_FORWARD with valid IP and invalid parameters') + boom(policy.TLS_FORWARD, {{{'127.0.0.1'}}}, 'TLS_FORWARD with missing auth parameters') + + ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true}}), 'TLS_FORWARD with no authentication') + boom(policy.TLS_FORWARD, {{{'100:dead::', insecure=true}, + {'100:DEAD:0::', insecure=true} + }}, 'TLS_FORWARD with duplicate IP addresses is not allowed') + ok(policy.TLS_FORWARD({{'100:dead::', insecure=true}, + {'100:dead::@443', insecure=true} + }), 'TLS_FORWARD with duplicate IP addresses but different ports is allowed') + ok(policy.TLS_FORWARD({{'100:dead::', insecure=true}, + {'100:beef::', insecure=true} + }), 'TLS_FORWARD with different IPv6 addresses is allowed') + ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true}, + {'127.0.0.2', insecure=true} + }), 'TLS_FORWARD with different IPv4 addresses is allowed') + + boom(policy.TLS_FORWARD, {{{'::1', pin_sha256=''}}}, 'TLS_FORWARD with empty pin_sha256') + -- boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='č'}}}, 'TLS_FORWARD with bad pin_sha256') + ok(policy.TLS_FORWARD({ + {'::1', pin_sha256='ZTNiMGM0NDI5OGZjMWMxNDlhZmJmNGM4OTk2ZmI5MjQyN2FlNDFlNDY0OWI5MzRjYTQ5NTk5MWI3ODUyYjg1NQ=='} + }), 'TLS_FORWARD with base64 pin_sha256') + ok(policy.TLS_FORWARD({ + {'::1', pin_sha256={ + 'ZTNiMGM0NDI5OGZjMWMxNDlhZmJmNGM4OTk2ZmI5MjQyN2FlNDFlNDY0OWI5MzRjYTQ5NTk5MWI3ODUyYjg1NQ==', + 'MTcwYWUzMGNjZDlmYmE2MzBhZjhjZGE2ODQxZTAwYzZiNjU3OWNlYzc3NmQ0MTllNzAyZTIwYzY5YzQ4OGZmOA==' + }}}), 'TLS_FORWARD with table of pins') + + -- ok(policy.TLS_FORWARD({{'::1', hostname='test.', ca_file='/tmp/ca.crt'}}), 'TLS_FORWARD with hostname + CA cert') + ok(policy.TLS_FORWARD({{'::1', hostname='test.'}}), 'TLS_FORWARD with just hostname (use system CA store)') + boom(policy.TLS_FORWARD, {{{'::1', ca_file='/tmp/ca.crt'}}}, 'TLS_FORWARD with just CA cert') + boom(policy.TLS_FORWARD, {{{'::1', hostname='', ca_file='/tmp/ca.crt'}}}, 'TLS_FORWARD with empty hostname + CA cert') + boom(policy.TLS_FORWARD, {{{'::1', hostname='test.', ca_file='/dev/a_file_which_surely_does_NOT_exist!'}}}, + 'TLS_FORWARD with hostname + unreadable CA cert') + +end + +return { + test_tls_forward +} diff --git a/modules/policy/test.integr/deckard.yaml b/modules/policy/test.integr/deckard.yaml new file mode 100644 index 0000000..be60e97 --- /dev/null +++ b/modules/policy/test.integr/deckard.yaml @@ -0,0 +1,12 @@ +programs: +- name: kresd + binary: kresd + additional: + - -f + - "1" + templates: + - modules/policy/test.integr/kresd_config.j2 + - tests/hints_zone.j2 + configs: + - config + - hints diff --git a/modules/policy/test.integr/kresd_config.j2 b/modules/policy/test.integr/kresd_config.j2 new file mode 100644 index 0000000..9274363 --- /dev/null +++ b/modules/policy/test.integr/kresd_config.j2 @@ -0,0 +1,49 @@ +{% raw %} +policy.add(policy.suffix(policy.REFUSE, {todname('refuse.example.com')})) + +-- Disable RFC8145 signaling, scenario doesn't provide expected answers +if ta_signal_query then + modules.unload('ta_signal_query') +end + +-- Disable RFC8109 priming, scenario doesn't provide expected answers +if priming then + modules.unload('priming') +end + +-- Disable this module because it make one priming query +if detect_time_skew then + modules.unload('detect_time_skew') +end + +_hint_root_file('hints') +cache.size = 2*MB +verbose(true) +{% endraw %} + +net = { '{{SELF_ADDR}}' } + + +{% if QMIN == "false" %} +option('NO_MINIMIZE', true) +{% else %} +option('NO_MINIMIZE', false) +{% endif %} + + +-- Self-checks on globals +assert(help() ~= nil) +assert(worker.id ~= nil) +-- Self-checks on facilities +assert(cache.count() == 0) +assert(cache.stats() ~= nil) +assert(cache.backends() ~= nil) +assert(worker.stats() ~= nil) +assert(net.interfaces() ~= nil) +-- Self-checks on loaded stuff +assert(net.list()['{{SELF_ADDR}}']) +assert(#modules.list() > 0) +-- Self-check timers +ev = event.recurrent(1 * sec, function (ev) return 1 end) +event.cancel(ev) +ev = event.after(0, function (ev) return 1 end) diff --git a/modules/policy/test.integr/refuse.rpl b/modules/policy/test.integr/refuse.rpl new file mode 100644 index 0000000..1d1372b --- /dev/null +++ b/modules/policy/test.integr/refuse.rpl @@ -0,0 +1,24 @@ +; config options + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. +CONFIG_END + +SCENARIO_BEGIN Test refuse policy + +STEP 10 QUERY +ENTRY_BEGIN +REPLY RD AD +SECTION QUESTION +www.refuse.example.com. IN A +ENTRY_END + +STEP 20 CHECK_ANSWER +ENTRY_BEGIN +MATCH all answer +; AD must not be set in the answer +REPLY QR RD RA REFUSED +SECTION QUESTION +www.refuse.example.com. IN A +SECTION ANSWER +ENTRY_END + +SCENARIO_END |