summaryrefslogtreecommitdiffstats
path: root/modules/policy
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 15:26:00 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 15:26:00 +0000
commit830407e88f9d40d954356c3754f2647f91d5c06a (patch)
treed6a0ece6feea91f3c656166dbaa884ef8a29740e /modules/policy
parentInitial commit. (diff)
downloadknot-resolver-upstream.tar.xz
knot-resolver-upstream.zip
Adding upstream version 5.6.0.upstream/5.6.0upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--modules/policy/.packaging/test.config4
-rw-r--r--modules/policy/README.rst774
-rw-r--r--modules/policy/lua-aho-corasick/LICENSE28
-rw-r--r--modules/policy/lua-aho-corasick/Makefile134
-rw-r--r--modules/policy/lua-aho-corasick/README.md40
-rw-r--r--modules/policy/lua-aho-corasick/ac.cxx101
-rw-r--r--modules/policy/lua-aho-corasick/ac.h49
-rw-r--r--modules/policy/lua-aho-corasick/ac_fast.cxx468
-rw-r--r--modules/policy/lua-aho-corasick/ac_fast.hpp124
-rw-r--r--modules/policy/lua-aho-corasick/ac_lua.cxx173
-rw-r--r--modules/policy/lua-aho-corasick/ac_slow.cxx318
-rw-r--r--modules/policy/lua-aho-corasick/ac_slow.hpp158
-rw-r--r--modules/policy/lua-aho-corasick/ac_util.hpp69
-rw-r--r--modules/policy/lua-aho-corasick/load_ac.lua90
-rw-r--r--modules/policy/lua-aho-corasick/mytest.cxx200
-rw-r--r--modules/policy/lua-aho-corasick/tests/Makefile65
-rw-r--r--modules/policy/lua-aho-corasick/tests/ac_bench.cxx519
-rw-r--r--modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx135
-rw-r--r--modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx275
-rw-r--r--modules/policy/lua-aho-corasick/tests/dict/README.txt1
-rw-r--r--modules/policy/lua-aho-corasick/tests/dict/dict1.txt11
-rw-r--r--modules/policy/lua-aho-corasick/tests/load_ac_test.lua82
-rw-r--r--modules/policy/lua-aho-corasick/tests/lua_test.lua67
-rw-r--r--modules/policy/lua-aho-corasick/tests/test_base.hpp60
-rw-r--r--modules/policy/lua-aho-corasick/tests/test_bigfile.cxx167
-rw-r--r--modules/policy/lua-aho-corasick/tests/test_main.cxx33
-rw-r--r--modules/policy/meson.build50
-rw-r--r--modules/policy/noipv6.test.integr/broken-ipv6.rpl47
-rw-r--r--modules/policy/noipv6.test.integr/deckard.yaml12
-rw-r--r--modules/policy/noipv6.test.integr/kresd_config.j259
-rw-r--r--modules/policy/noipvx.test.integr/broken-ipvx.rpl35
-rw-r--r--modules/policy/noipvx.test.integr/deckard.yaml12
-rw-r--r--modules/policy/noipvx.test.integr/kresd_config.j260
-rw-r--r--modules/policy/policy.lua1109
-rw-r--r--modules/policy/policy.rpz.test.lua65
-rw-r--r--modules/policy/policy.slice.test.lua109
-rw-r--r--modules/policy/policy.test.lua145
-rw-r--r--modules/policy/policy.test.rpz18
-rw-r--r--modules/policy/policy.test.rpz.soa5
-rw-r--r--modules/policy/test.integr/deckard.yaml12
-rw-r--r--modules/policy/test.integr/kresd_config.j259
-rw-r--r--modules/policy/test.integr/refuse.rpl44
42 files changed, 5986 insertions, 0 deletions
diff --git a/modules/policy/.packaging/test.config b/modules/policy/.packaging/test.config
new file mode 100644
index 0000000..60c9ddc
--- /dev/null
+++ b/modules/policy/.packaging/test.config
@@ -0,0 +1,4 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+modules.load('policy')
+assert(policy)
+quit()
diff --git a/modules/policy/README.rst b/modules/policy/README.rst
new file mode 100644
index 0000000..202aaba
--- /dev/null
+++ b/modules/policy/README.rst
@@ -0,0 +1,774 @@
+.. SPDX-License-Identifier: GPL-3.0-or-later
+
+.. default-domain:: py
+.. module:: policy
+
+.. _mod-policy:
+
+
+Query policies
+==============
+
+This module can block, rewrite, or alter inbound queries based on user-defined policies. It does not affect queries generated by the resolver itself, e.g. when following CNAME chains etc.
+
+Each policy *rule* has two parts: a *filter* and an *action*. A *filter* selects which queries will be affected by the policy, and *action* which modifies queries matching the associated filter.
+
+Typically a rule is defined as follows: ``filter(action(action parameters), filter parameters)``. For example, a filter can be ``suffix`` which matches queries whose suffix part is in specified set, and one of possible actions is :any:`policy.DENY`, which denies resolution. These are combined together into ``policy.suffix(policy.DENY, {todname('badguy.example.')})``. The rule is effective when it is added into rule table using ``policy.add()``, please see examples below.
+
+This module is enabled by default because it implements mandatory :rfc:`6761` logic.
+When no rule applies to a query, built-in rules for `special-use <https://www.iana.org/assignments/special-use-domain-names/special-use-domain-names.xhtml>`_ and `locally-served <http://www.iana.org/assignments/locally-served-dns-zones>`_ domain names are applied.
+These rules can be overridden by action :any:`policy.PASS`. For debugging purposes you can also add ``modules.unload('policy')`` to your config to unload the module.
+
+
+Filters
+-------
+A *filter* selects which queries will be affected by specified Actions_. There are several policy filters available in the ``policy.`` table:
+
+.. function:: all(action)
+
+ Always applies the action.
+
+.. function:: pattern(action, pattern)
+
+ Applies the action if query name matches a `Lua regular expression <http://lua-users.org/wiki/PatternsTutorial>`_.
+
+.. function:: suffix(action, suffix_table)
+
+ Applies the action if query name suffix matches one of suffixes in the table (useful for "is domain in zone" rules).
+
+ .. code-block:: lua
+
+ policy.add(policy.suffix(policy.DENY, policy.todnames({'example.com', 'example.net'})))
+
+.. note:: For speed this filter requires domain names in DNS wire format, not textual representation, so each label in the name must be prefixed with its length. Always use convenience function :func:`policy.todnames` for automatic conversion from strings! For example:
+
+.. _IDN:
+
+.. note:: Non-ASCII is not supported.
+
+ Knot Resolver does not provide any convenience support for IDN.
+ Therefore everywhere (all configuration, logs, RPZ files) you need to deal with the
+ `xn\-\- forms <https://en.wikipedia.org/wiki/Internationalized_domain_name#Example_of_IDNA_encoding>`_
+ of domain name labels, instead of directly using unicode characters.
+
+.. function:: domains(action, domain_table)
+
+ Like :func:`policy.suffix` match, but the queried name must match exactly, not just its suffix.
+
+.. function:: suffix_common(action, suffix_table[, common_suffix])
+
+ :param action: action if the pattern matches query name
+ :param suffix_table: table of valid suffixes
+ :param common_suffix: common suffix of entries in suffix_table
+
+ Like :func:`policy.suffix` match, but you can also provide a common suffix of all matches for faster processing (nil otherwise).
+ This function is faster for small suffix tables (in the order of "hundreds").
+
+.. :noindex: function:: rpz(default_action, path, [watch])
+
+ Implements a subset of `Response Policy Zone` (RPZ_) stored in zonefile format. See below for details: :func:`policy.rpz`.
+
+It is also possible to define custom filter function with any name.
+
+.. function:: custom_filter(state, query)
+
+ :param state: Request processing state :c:type:`kr_layer_state`, typically not used by filter function.
+ :param query: Incoming DNS query as :c:type:`kr_query` structure.
+ :return: An `action <#actions>`_ function or ``nil`` if filter did not match.
+
+ Typically filter function is generated by another function, which allows easy parametrization - this technique is called `closure <https://www.lua.org/pil/6.1.html>`_. An practical example of such filter generator is:
+
+.. code-block:: lua
+
+ function match_query_type(action, target_qtype)
+ return function (state, query)
+ if query.stype == target_qtype then
+ -- filter matched the query, return action function
+ return action
+ else
+ -- filter did not match, continue with next filter
+ return nil
+ end
+ end
+ end
+
+This custom filter can be used as any other built-in filter.
+For example this applies our custom filter and executes action :any:`policy.DENY` on all queries of type `HINFO`:
+
+.. code-block:: lua
+
+ -- custom filter which matches HINFO queries, action is policy.DENY
+ policy.add(match_query_type(policy.DENY, kres.type.HINFO))
+
+
+.. _mod-policy-actions:
+
+Actions
+-------
+An *action* is a function which modifies DNS request, and is either of type *chain* or *non-chain*:
+
+ * `Non-chain actions`_ modify state of the request and stop rule processing. An example of such action is :ref:`forwarding`.
+ * `Chain actions`_ modify state of the request and allow other rules to evaluate and act on the same request. One such example is :func:`policy.MIRROR`.
+
+Non-chain actions
+^^^^^^^^^^^^^^^^^
+
+Following actions stop the policy matching on the query, i.e. other rules are not evaluated once rule with following actions matches:
+
+.. py:attribute:: PASS
+
+ Let the query pass through; it's useful to make exceptions before wider rules. For example:
+
+ More specific whitelist rule must precede generic blacklist rule:
+
+ .. code-block:: lua
+
+ -- Whitelist 'good.example.com'
+ policy.add(policy.pattern(policy.PASS, todname('good.example.com.')))
+ -- Block all names below example.com
+ policy.add(policy.suffix(policy.DENY, {todname('example.com.')}))
+
+.. py:attribute:: DENY
+
+ Deny existence of names matching filter, i.e. reply NXDOMAIN authoritatively.
+
+.. function:: DENY_MSG(message, [extended_error=kres.extended_error.BLOCKED])
+
+ Deny existence of a given domain and add explanatory message. NXDOMAIN reply
+ contains an additional explanatory message as TXT record in the additional
+ section.
+
+ You may override the extended DNS error to provide the user with more
+ information. By default, ``BLOCKED`` is returned to indicate the domain is
+ blocked due to the internal policy of the operator. Other suitable error
+ codes are ``CENSORED`` (for externally imposed policy reasons) or
+ ``FILTERED`` (for blocking requested by the client). For more information,
+ please refer to :rfc:`8914`.
+
+.. py:attribute:: DROP
+
+ Terminate query resolution and return SERVFAIL to the requestor.
+
+.. py:attribute:: REFUSE
+
+ Terminate query resolution and return REFUSED to the requestor.
+
+.. py:attribute:: NO_ANSWER
+
+ Terminate query resolution and do not return any answer to the requestor.
+
+ .. warning:: During normal operation, an answer should always be returned.
+ Deliberate query drops are indistinguishable from packet loss and may
+ cause problems as described in :rfc:`8906`. Only use :any:`NO_ANSWER`
+ on very specific occasions, e.g. as a defense mechanism during an attack,
+ and prefer other actions (e.g. :any:`DROP` or :any:`REFUSE`) for normal
+ operation.
+
+.. py:attribute:: TC
+
+ Force requestor to use TCP. It sets truncated bit (*TC*) in response to true if the request came through UDP, which will force standard-compliant clients to retry the request over TCP.
+
+.. function:: REROUTE({subnet = target, ...})
+
+ Reroute IP addresses in response matching given subnet to given target, e.g. ``{['192.0.2.0/24'] = '127.0.0.0'}`` will rewrite '192.0.2.55' to '127.0.0.55', see :ref:`renumber module <mod-renumber>` for more information. See :func:`policy.add` and do not forget to specify that this is *postrule*. Quick example:
+
+ .. code-block:: lua
+
+ -- this policy is enforced on answers
+ -- therefore we have to use 'postrule'
+ -- (the "true" at the end of policy.add)
+ policy.add(policy.all(policy.REROUTE({['192.0.2.0/24'] = '127.0.0.0'})), true)
+
+.. function:: ANSWER({ type = { rdata=data, [ttl=1] } }, [nodata=false])
+
+ Overwrite Resource Records in responses with specified values.
+
+ * type
+ - RR type to be replaced, e.g. ``[kres.type.A]`` or `numeric value <https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4>`_.
+ * rdata
+ - RR data in DNS wire format, i.e. binary form specific for given RR type. Set of multiple RRs can be specified as table ``{ rdata1, rdata2, ... }``. Use helper function :func:`kres.str2ip` to generate wire format for A and AAAA records. Wire format for other record types can be generated with :func:`kres.parse_rdata`.
+ * ttl
+ - TTL in seconds. Default: 1 second.
+ * nodata
+ - If type requested by client is not configured in this policy:
+
+ - ``true``: Return empty answer (`NODATA`).
+ - ``false``: Ignore this policy and continue processing other rules.
+
+ Default: ``false``.
+
+ .. code-block:: lua
+
+ -- policy to change IPv4 address and TTL for example.com
+ policy.add(
+ policy.domains(
+ policy.ANSWER(
+ { [kres.type.A] = { rdata=kres.str2ip('192.0.2.7'), ttl=300 } }
+ ), { todname('example.com') }))
+ -- policy to generate two TXT records (specified in binary format) for example.net
+ policy.add(
+ policy.domains(
+ policy.ANSWER(
+ { [kres.type.TXT] = { rdata={'\005first', '\006second'}, ttl=5 } }
+ ), { todname('example.net') }))
+
+
+ .. function:: kres.parse_rdata({str, ...})
+
+ Parse string representation of RTYPE and RDATA into RDATA wire format. Expects
+ a table of string(s) and returns a table of wire data.
+
+ .. code-block:: lua
+
+ -- create wire format RDATA that can be passed to policy.ANSWER
+ kres.parse_rdata({'SVCB 1 resolver.example. alpn=dot'})
+ kres.parse_rdata({
+ 'SVCB 1 resolver.example. alpn=dot ipv4hint=192.0.2.1 ipv6hint=2001:db8::1',
+ 'SVCB 2 resolver.example. mandatory=key65380 alpn=h2 key65380=/dns-query{?dns}',
+ })
+
+More complex non-chain actions are described in their own chapters, namely:
+
+ * :ref:`forwarding`
+ * `Response Policy Zones`_
+
+Chain actions
+^^^^^^^^^^^^^
+
+Following actions act on request and then processing continue until first non-chain action (specified in the previous section) is triggered:
+
+.. function:: MIRROR(ip_address)
+
+ Send copy of incoming DNS queries to a given IP address using DNS-over-UDP and continue resolving them as usual. This is useful for sanity testing new versions of DNS resolvers.
+
+ .. code-block:: lua
+
+ policy.add(policy.all(policy.MIRROR('127.0.0.2')))
+
+.. function:: FLAGS(set, clear)
+
+ Set and/or clear some flags for the query. There can be multiple flags to set/clear. You can just pass a single flag name (string) or a set of names. Flag names correspond to :c:type:`kr_qflags` structure. Use only if you know what you are doing.
+
+
+.. _mod-policy-logging:
+
+Actions for extra logging
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+These are also "chain" actions, i.e. they don't stop processing the policy rule list.
+Similarly to other actions, they apply during whole processing of the client's request,
+i.e. including any sub-queries.
+
+The log lines from these policy actions are tagged by extra ``[reqdbg]`` prefix,
+and they are produced regardless of your :func:`log_level()` setting.
+They are marked as ``debug`` level, so e.g. with journalctl command you can use ``-p info`` to skip them.
+
+.. warning:: Beware of producing too much logs.
+
+ These actions are not suitable for use on a large fraction of resolver's requests.
+ The extra logs have significant performance impact and might also overload your logging system
+ (or get rate-limited by it).
+ You can use `Filters`_ to further limit on which requests this happens.
+
+.. py:attribute:: DEBUG_ALWAYS
+
+ Print debug-level logging for this request.
+ That also includes messages from client (:any:`REQTRACE`), upstream servers (:any:`QTRACE`), and stats about interesting records at the end.
+
+ .. code-block:: lua
+
+ -- debug requests that ask for flaky.example.net or below
+ policy.add(policy.suffix(policy.DEBUG_ALWAYS,
+ policy.todnames({'flaky.example.net'})))
+
+.. py:attribute:: DEBUG_CACHE_MISS
+
+ Same as :any:`DEBUG_ALWAYS` but only if the request required information which was not available locally, i.e. requests which forced resolver to ask upstream server(s).
+ Intended usage is for debugging problems with remote servers.
+
+.. py:function:: DEBUG_IF(test_function)
+
+ :param test_function: Function with single argument of type :c:type:`kr_request` which returns ``true`` if debug logs for that request should be generated and ``false`` otherwise.
+
+ Same as :any:`DEBUG_ALWAYS` but only logs if the test_function says so.
+
+ .. note:: ``test_function`` is evaluated only when request is finished.
+ As a result all debug logs this request must be collected,
+ and at the end they get either printed or thrown away.
+
+ Example usage which gathers verbose logs for all requests in subtree ``dnssec-failed.org.`` and prints debug logs for those finishing in a different state than ``kres.DONE`` (most importantly ``kres.FAIL``, see :c:type:`kr_layer_state`).
+
+ .. code-block:: lua
+
+ policy.add(policy.suffix(
+ policy.DEBUG_IF(function(req)
+ return (req.state ~= kres.DONE)
+ end),
+ policy.todnames({'dnssec-failed.org.'})))
+
+.. py:attribute:: QTRACE
+
+ Pretty-print DNS responses from upstream servers (or cache) into logs.
+ It's useful for debugging weird DNS servers.
+
+ If you do not use ``QTRACE`` in combination with ``DEBUG*``,
+ you additionally need either ``log_groups({'iterat'})`` (possibly with other groups)
+ or ``log_level('debug')`` to see the output in logs.
+
+.. py:attribute:: REQTRACE
+
+ Pretty-print DNS requests from clients into the verbose log. It's useful for debugging weird DNS clients.
+ It makes most sense together with :ref:`mod-view` (enabling per-client)
+ and probably with verbose logging those request (e.g. use :any:`DEBUG_ALWAYS` instead).
+
+.. py:attribute:: IPTRACE
+
+ Log how the request arrived.
+ Most notably, this includes the client's IP address, so beware of privacy implications.
+
+ .. code-block:: lua
+
+ -- example usage in configuration
+ policy.add(policy.all(policy.IPTRACE))
+ -- you might want to combine it with some other logs, e.g.
+ policy.add(policy.all(policy.DEBUG_ALWAYS))
+
+ .. code-block:: text
+
+ -- example log lines from IPTRACE:
+ [reqdbg][policy][57517.00] request packet arrived from ::1#37931 to ::1#00853 (TCP + TLS)
+ [reqdbg][policy][65538.00] request packet arrived internally
+
+
+Custom actions
+^^^^^^^^^^^^^^
+
+.. function:: custom_action(state, request)
+
+ :param state: Request processing state :c:type:`kr_layer_state`.
+ :param request: Current DNS request as :c:type:`kr_request` structure.
+ :return: Returning a new :c:type:`kr_layer_state` prevents evaluating other policy rules. Returning ``nil`` creates a `chain action <#actions>`_ and allows to continue evaluating other rules.
+
+ This is real example of an action function:
+
+.. code-block:: lua
+
+ -- Custom action which generates fake A record
+ local ffi = require('ffi')
+ local function fake_A_record(state, req)
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ local qry = req:current()
+ if qry.stype ~= kres.type.A then
+ return state
+ end
+ ffi.C.kr_pkt_make_auth_header(answer)
+ answer:rcode(kres.rcode.NOERROR)
+ answer:begin(kres.section.ANSWER)
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\192\168\1\3')
+ return kres.DONE
+ end
+
+This custom action can be used as any other built-in action.
+For example this applies our *fake A record action* and executes it on all queries in subtree ``example.net``:
+
+.. code-block:: lua
+
+ policy.add(policy.suffix(fake_A_record, policy.todnames({'example.net'})))
+
+The action function can implement arbitrary logic so it is possible to implement complex heuristics, e.g. to deflect `Slow drip DNS attacks <https://secure64.com/water-torture-slow-drip-dns-ddos-attack>`_ or gray-list resolution of misbehaving zones.
+
+.. warning:: The policy module currently only looks at whole DNS requests. The rules won't be re-applied e.g. when following CNAMEs.
+
+.. _forwarding:
+
+Forwarding
+----------
+
+Forwarding action alters behavior for cache-miss events. If an information is missing in the local cache the resolver will *forward* the query to *another DNS resolver* for resolution (instead of contacting authoritative servers directly). DNS answers from the remote resolver are then processed locally and sent back to the original client.
+
+Actions :func:`policy.FORWARD`, :func:`policy.TLS_FORWARD` and :func:`policy.STUB` accept up to four IP addresses at once and the resolver will automatically select IP address which statistically responds the fastest.
+
+.. function:: FORWARD(ip_address)
+ FORWARD({ ip_address, [ip_address, ...] })
+
+ Forward cache-miss queries to specified IP addresses (without encryption), DNSSEC validate received answers and cache them. Target IP addresses are expected to be DNS resolvers.
+
+ .. code-block:: lua
+
+ -- Forward all queries to public resolvers https://www.nic.cz/odvr
+ policy.add(policy.all(
+ policy.FORWARD(
+ {'2001:148f:fffe::1', '2001:148f:ffff::1',
+ '185.43.135.1', '193.14.47.1'})))
+
+ A variant which uses encrypted DNS-over-TLS transport is called :func:`policy.TLS_FORWARD`, please see section :ref:`tls-forwarding`.
+
+.. function:: STUB(ip_address)
+ STUB({ ip_address, [ip_address, ...] })
+
+ Similar to :func:`policy.FORWARD` but *without* attempting DNSSEC validation.
+ Each request may be either answered from cache or simply sent to one of the IPs with proxying back the answer.
+
+ This mode does not support encryption and should be used only for `Replacing part of the DNS tree`_.
+ Use :func:`policy.FORWARD` mode if possible.
+
+ .. code-block:: lua
+
+ -- Answers for reverse queries about the 192.168.1.0/24 subnet
+ -- are to be obtained from IP address 192.0.2.1 port 5353
+ -- This disables DNSSEC validation!
+ policy.add(policy.suffix(
+ policy.STUB('192.0.2.1@5353'),
+ {todname('1.168.192.in-addr.arpa')}))
+
+.. note:: By default, forwarding targets must support
+ `EDNS <https://en.wikipedia.org/wiki/Extension_mechanisms_for_DNS>`_ and
+ `0x20 randomization <https://tools.ietf.org/html/draft-vixie-dnsext-dns0x20-00>`_.
+ See example in `Replacing part of the DNS tree`_.
+
+.. warning::
+ Limiting forwarding actions by filters (e.g. :func:`policy.suffix`) may have unexpected consequences.
+ Notably, forwarders can inject *any* records into your cache
+ even if you "restrict" them to an insignificant DNS subtree --
+ except in cases where DNSSEC validation applies, of course.
+
+ The behavior is probably best understood through the fact
+ that filters and actions are completely decoupled.
+ The forwarding actions have no clue about why they were executed,
+ e.g. that the user wanted to restrict the forwarder only to some subtree.
+ The action just selects some set of forwarders to process this whole request from the client,
+ and during that processing it might need some other "sub-queries" (e.g. for validation).
+ Some of those might not've passed the intended filter,
+ but policy rule-set only applies once per client's request.
+
+.. _tls-forwarding:
+
+Forwarding over TLS protocol (DNS-over-TLS)
+-------------------------------------------
+.. function:: TLS_FORWARD( { {ip_address, authentication}, [...] } )
+
+ Same as :func:`policy.FORWARD` but send query over DNS-over-TLS protocol (encrypted).
+ Each target IP address needs explicit configuration how to validate
+ TLS certificate so each IP address is configured by pair:
+ ``{ip_address, authentication}``. See sections below for more details.
+
+
+Policy :func:`policy.TLS_FORWARD` allows you to forward queries using `Transport Layer Security`_ protocol, which hides the content of your queries from an attacker observing the network traffic. Further details about this protocol can be found in :rfc:`7858` and `IETF draft dprive-dtls-and-tls-profiles`_.
+
+Queries affected by :func:`policy.TLS_FORWARD` will always be resolved over TLS connection. Knot Resolver does not implement fallback to non-TLS connection, so if TLS connection cannot be established or authenticated according to the configuration, the resolution will fail.
+
+To test this feature you need to either :ref:`configure Knot Resolver as DNS-over-TLS server <tls-server-config>`, or pick some public DNS-over-TLS server. Please see `DNS Privacy Project`_ homepage for list of public servers.
+
+.. note:: Some public DNS-over-TLS providers may apply rate-limiting which
+ makes their service incompatible with Knot Resolver's TLS forwarding.
+ Notably, `Google Public DNS
+ <https://developers.google.com/speed/public-dns/docs/dns-over-tls>`_ doesn't
+ work as of 2019-07-10.
+
+When multiple servers are specified, the one with the lowest round-trip time is used.
+
+CA+hostname authentication
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+Traditional PKI authentication requires server to present certificate with specified hostname, which is issued by one of trusted CAs. Example policy is:
+
+.. code-block:: lua
+
+ policy.TLS_FORWARD({
+ {'2001:DB8::d0c', hostname='res.example.com'}})
+
+- ``hostname`` must be a valid domain name matching server's certificate. It will also be sent to the server as SNI_.
+- ``ca_file`` optionally contains a path to a CA certificate (or certificate bundle) in `PEM format`_.
+ If you omit that, the system CA certificate store will be used instead (usually sufficient).
+ A list of paths is also accepted, but all of them must be valid PEMs.
+
+Key-pinned authentication
+^^^^^^^^^^^^^^^^^^^^^^^^^
+Instead of CAs, you can specify hashes of accepted certificates in ``pin_sha256``.
+They are in the usual format -- base64 from sha256.
+You may still specify ``hostname`` if you want SNI_ to be sent.
+
+.. _tls-examples:
+
+TLS Examples
+^^^^^^^^^^^^
+
+.. code-block:: lua
+
+ modules = { 'policy' }
+ -- forward all queries over TLS to the specified server
+ policy.add(policy.all(policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}})))
+ -- for brevity, other TLS examples omit policy.add(policy.all())
+ -- single server authenticated using its certificate pin_sha256
+ policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}}) -- pin_sha256 is base64-encoded
+ -- single server authenticated using hostname and system-wide CA certificates
+ policy.TLS_FORWARD({{'192.0.2.1', hostname='res.example.com'}})
+ -- single server using non-standard port
+ policy.TLS_FORWARD({{'192.0.2.1@443', pin_sha256='YQ=='}}) -- use @ or # to specify port
+ -- single server with multiple valid pins (e.g. anycast)
+ policy.TLS_FORWARD({{'192.0.2.1', pin_sha256={'YQ==', 'Wg=='}})
+ -- multiple servers, each with own authenticator
+ policy.TLS_FORWARD({ -- please note that { here starts list of servers
+ {'192.0.2.1', pin_sha256='Wg=='},
+ -- server must present certificate issued by specified CA and hostname must match
+ {'2001:DB8::d0c', hostname='res.example.com', ca_file='/etc/knot-resolver/tlsca.crt'}
+ })
+
+Forwarding to multiple targets
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+With the use of :func:`policy.slice` function, it is possible to split the
+entire DNS namespace into distinct slices. When used in conjunction with
+:func:`policy.TLS_FORWARD`, it's possible to forward different queries to
+different targets.
+
+.. function:: slice(slice_func, action[, action[, ...])
+
+ :param slice_func: slicing function that returns index based on query
+ :param action: action to be performed for the slice
+
+ This function splits the entire domain space into multiple slices (determined
+ by the number of provided ``actions``). A ``slice_func`` is called to determine
+ which slice a query belongs to. The corresponding ``action`` is then executed.
+
+
+.. function:: slice_randomize_psl(seed = os.time() / (3600 * 24 * 7))
+
+ :param seed: seed for random assignment
+
+ The function initializes and returns a slicing function, which
+ deterministically assigns ``query`` to a slice based on the query name.
+
+ It utilizes the `Public Suffix List`_ to ensure domains under the same
+ registrable domain end up in a single slice. (see example below)
+
+ ``seed`` can be used to re-shuffle the slicing algorithm when the slicing
+ function is initialized. By default, the assignment is re-shuffled after one
+ week (when resolver restart / reloads config). To force a stable
+ distribution, pass a fixed value. To re-shuffle on every resolver restart,
+ use ``os.time()``.
+
+ The following example demonstrates a distribution among 3 slices::
+
+ slice 1/3:
+ example.com
+ a.example.com
+ b.example.com
+ x.b.example.com
+ example3.com
+
+ slice 2/3:
+ example2.co.uk
+
+ slice 3/3:
+ example.co.uk
+ a.example.co.uk
+
+These two functions can be used together to forward queries for names
+in different parts of DNS name space to different target servers:
+
+.. code-block:: lua
+
+ policy.add(policy.slice(
+ policy.slice_randomize_psl(),
+ policy.TLS_FORWARD({{'192.0.2.1', hostname='res.example.com'}}),
+ policy.TLS_FORWARD({
+ -- multiple servers can be specified for a single slice
+ -- the one with lowest round-trip time will be used
+ {'193.17.47.1', hostname='odvr.nic.cz'},
+ {'185.43.135.1', hostname='odvr.nic.cz'},
+ })
+ ))
+
+.. note:: The privacy implications of using this feature aren't clear. Since
+ websites often make requests to multiple domains, these might be forwarded
+ to different targets. This could result in decreased privacy (e.g. when the
+ remote targets are both logging or otherwise processing your DNS traffic).
+ The intended use-case is to use this feature with semi-trusted resolvers
+ which claim to do no logging (such as those listed on `dnsprivacy.org
+ <https://dnsprivacy.org/wiki/display/DP/DNS+Privacy+Test+Servers>`_), to
+ decrease the potential exposure of your DNS data to a malicious resolver
+ operator.
+
+.. _dns-graft:
+
+Replacing part of the DNS tree
+------------------------------
+
+Following procedure applies only to domains which have different content
+publicly and internally. For example this applies to "your own" top-level domain
+``example.`` which does not exist in the public (global) DNS namespace.
+
+Dealing with these internal-only domains requires extra configuration because
+DNS was designed as "single namespace" and local modifications like adding
+your own TLD break this assumption.
+
+.. warning:: Use of internal names which are not delegated from the public DNS
+ *is causing technical problems* with caching and DNSSEC validation
+ and generally makes DNS operation more costly.
+ We recommend **against** using these non-delegated names.
+
+To make such internal domain available in your resolver it is necessary to
+*graft* your domain onto the public DNS namespace,
+but *grafting* creates new issues:
+
+These *grafted* domains will be rejected by DNSSEC validation
+because such domains are technically indistinguishable from an spoofing attack
+against the public DNS.
+Therefore, if you trust the remote resolver which hosts the internal-only domain,
+and you trust your link to it, you need to use the :func:`policy.STUB` policy
+instead of :func:`policy.FORWARD` to disable DNSSEC validation for those
+*grafted* domains.
+
+.. code-block:: lua
+ :caption: Example configuration grafting domains onto public DNS namespace
+
+ extraTrees = policy.todnames(
+ {'faketldtest.',
+ 'sld.example.',
+ 'internal.example.com.',
+ '2.0.192.in-addr.arpa.' -- this applies to reverse DNS tree as well
+ })
+ -- Beware: the rule order is important, as policy.STUB is not a chain action.
+ -- Flags: for "dumb" targets disabling EDNS can help (below) as DNSSEC isn't
+ -- validated anyway; in some of those cases adding 'NO_0X20' can also help,
+ -- though it also lowers defenses against off-path attacks on communication
+ -- between the two servers.
+ -- With kresd <= 5.5.3 you also needed 'NO_CACHE' flag to avoid unintentional
+ -- NXDOMAINs that could sometimes happen due to aggressive DNSSEC caching.
+ policy.add(policy.suffix(policy.FLAGS({'NO_EDNS'}), extraTrees))
+ policy.add(policy.suffix(policy.STUB({'2001:db8::1'}), extraTrees))
+
+Response policy zones
+---------------------
+ .. warning::
+
+ There is no published Internet Standard for RPZ_ and implementations vary.
+ At the moment Knot Resolver supports limited subset of RPZ format and deviates
+ from implementation in BIND. Nevertheless it is good enough
+ for blocking large lists of spam or advertising domains.
+
+
+
+ The RPZ file format is basically a DNS zone file with *very special* semantics.
+ For example:
+
+ .. code-block:: none
+
+ ; left hand side ; TTL and class ; right hand side
+ ; encodes RPZ trigger ; ignored ; encodes action
+ ; (i.e. filter)
+ blocked.domain.example 600 IN CNAME . ; block main domain
+ *.blocked.domain.example 600 IN CNAME . ; block subdomains
+
+ The only "trigger" supported in Knot Resolver is query name,
+ i.e. left hand side must be a domain name which triggers the action specified
+ on the right hand side.
+
+ Subset of possible RPZ actions is supported, namely:
+
+ .. csv-table::
+ :header: "RPZ Right Hand Side", "Knot Resolver Action", "BIND Compatibility"
+
+ "``.``", "``action`` is used", "compatible if ``action`` is :any:`policy.DENY`"
+ "``*.``", ":func:`policy.ANSWER`", "yes"
+ "``rpz-passthru.``", ":any:`policy.PASS`", "yes"
+ "``rpz-tcp-only.``", ":any:`policy.TC`", "yes"
+ "``rpz-drop.``", ":any:`policy.DROP`", "no [#]_"
+ "fake A/AAAA", ":func:`policy.ANSWER`", "yes"
+ "fake CNAME", "not supported", "no"
+
+ .. [#] Our :any:`policy.DROP` returns *SERVFAIL* answer (for historical reasons).
+
+
+ .. note::
+
+ To debug which domains are affected by RPZ (or other policy actions), you can enable the ``policy`` log group:
+
+ .. code-block:: lua
+
+ log_groups({'policy'})
+
+ See also :ref:`non-ASCII support note <IDN>`.
+
+
+.. function:: rpz(action, path, [watch = true])
+
+ :param action: the default action for match in the zone; typically you want :any:`policy.DENY`
+ :param path: path to zone file
+ :param watch: boolean, if true, the file will be reloaded on file change
+
+ Enforce RPZ_ rules. This can be used in conjunction with published blocklist feeds.
+ The RPZ_ operation is well described in this `Jan-Piet Mens's post`_,
+ or the `Pro DNS and BIND`_ book.
+
+ For example, we can store the example snippet with domain ``blocked.domain.example``
+ (above) into file ``/etc/knot-resolver/blocklist.rpz`` and configure resolver to
+ answer with *NXDOMAIN* plus the specified additional text to queries for this domain:
+
+ .. code-block:: lua
+
+ policy.add(
+ policy.rpz(policy.DENY_MSG('domain blocked by your resolver operator'),
+ '/etc/knot-resolver/blocklist.rpz',
+ true))
+
+ Resolver will reload RPZ file at run-time if the RPZ file changes.
+ Recommended RPZ update procedure is to store new blocklist in a new file
+ (*newblocklist.rpz*) and then rename the new file to the original file name
+ (*blocklist.rpz*). This avoids problems where resolver might attempt
+ to re-read an incomplete file.
+
+
+
+Additional properties
+---------------------
+
+Most properties (actions, filters) are described above.
+
+.. function:: add(rule, postrule)
+
+ :param rule: added rule, i.e. ``policy.pattern(policy.DENY, '[0-9]+\2cz')``
+ :param postrule: boolean, if true the rule will be evaluated on answer instead of query
+ :return: rule description
+
+ Add a new policy rule that is executed either or queries or answers, depending on the ``postrule`` parameter. You can then use the returned rule description to get information and unique identifier for the rule, as well as match count.
+
+ .. code-block:: lua
+
+ -- mirror all queries, keep handle so we can retrieve information later
+ local rule = policy.add(policy.all(policy.MIRROR('127.0.0.2')))
+ -- we can print statistics about this rule any time later
+ print(string.format('id: %d, matched queries: %d', rule.id, rule.count)
+
+.. function:: del(id)
+
+ :param id: identifier of a given rule returned by :func:`policy.add`
+ :return: boolean ``true`` if rule was deleted, ``false`` otherwise
+
+ Remove a rule from policy list.
+
+.. function:: todnames({name, ...})
+
+ :param: names table of domain names in textual format
+
+ Returns table of domain names in wire format converted from strings.
+
+ .. code-block:: lua
+
+ -- Convert single name
+ assert(todname('example.com') == '\7example\3com\0')
+ -- Convert table of names
+ policy.todnames({'example.com', 'me.cz'})
+ { '\7example\3com\0', '\2me\2cz\0' }
+
+
+.. _RPZ: https://dnsrpz.info/
+.. _`PEM format`: https://en.wikipedia.org/wiki/Privacy-enhanced_Electronic_Mail
+.. _`Pro DNS and BIND`: http://www.zytrax.com/books/dns/ch7/rpz.html
+.. _`Jan-Piet Mens's post`: http://jpmens.net/2011/04/26/how-to-configure-your-bind-resolvers-to-lie-using-response-policy-zones-rpz/
+.. _`Transport Layer Security`: https://en.wikipedia.org/wiki/Transport_Layer_Security
+.. _`DNS Privacy Project`: https://dnsprivacy.org/
+.. _`IETF draft dprive-dtls-and-tls-profiles`: https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles
+.. _SNI: https://en.wikipedia.org/wiki/Server_Name_Indication
+.. _`Public Suffix List`: https://publicsuffix.org
diff --git a/modules/policy/lua-aho-corasick/LICENSE b/modules/policy/lua-aho-corasick/LICENSE
new file mode 100644
index 0000000..dd65f72
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/LICENSE
@@ -0,0 +1,28 @@
+ Copyright (c) 2014 CloudFlare, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of CloudFlare, Inc. nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
diff --git a/modules/policy/lua-aho-corasick/Makefile b/modules/policy/lua-aho-corasick/Makefile
new file mode 100644
index 0000000..6471664
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/Makefile
@@ -0,0 +1,134 @@
+OS := $(shell uname)
+
+ifeq ($(OS), Darwin)
+ SO_EXT := dylib
+else
+ SO_EXT := so
+endif
+
+#############################################################################
+#
+# Binaries we are going to build
+#
+#############################################################################
+#
+C_SO_NAME = libac.$(SO_EXT)
+LUA_SO_NAME = ahocorasick.$(SO_EXT)
+AR_NAME = libac.a
+
+#############################################################################
+#
+# Compile and link flags
+#
+#############################################################################
+PREFIX ?= /usr/local
+LUA_VERSION := 5.1
+LUA_INCLUDE_DIR := $(PREFIX)/include/lua$(LUA_VERSION)
+SO_TARGET_DIR := $(PREFIX)/lib/lua/$(LUA_VERSION)
+LUA_TARGET_DIR := $(PREFIX)/share/lua/$(LUA_VERSION)
+
+# Available directives:
+# -DDEBUG : Turn on debugging support
+# -DVERIFY : To verify if the slow-version and fast-version implementations
+# get exactly the same result. Note -DVERIFY implies -DDEBUG.
+#
+COMMON_FLAGS = -O3 #-g -DVERIFY -msse2 -msse3 -msse4.1
+COMMON_FLAGS += -fvisibility=hidden -Wall $(CXXFLAGS) $(MY_CXXFLAGS) $(CPPFLAGS)
+
+SO_CXXFLAGS = $(COMMON_FLAGS) -fPIC
+SO_LFLAGS = $(COMMON_FLAGS) $(LDFLAGS)
+AR_CXXFLAGS = $(COMMON_FLAGS)
+
+# -DVERIFY implies -DDEBUG
+ifneq ($(findstring -DVERIFY, $(COMMON_FLAGS)), )
+ifeq ($(findstring -DDEBUG, $(COMMON_FLAGS)), )
+ COMMON_FLAGS += -DDEBUG
+endif
+endif
+
+AR = ar
+AR_FLAGS = cru
+
+#############################################################################
+#
+# Divide source codes and objects into several categories
+#
+#############################################################################
+#
+SRC_COMMON := ac_fast.cxx ac_slow.cxx
+LIBAC_SO_SRC := $(SRC_COMMON) ac.cxx # source for libac.so
+LUA_SO_SRC := $(SRC_COMMON) ac_lua.cxx # source for ahocorasick.so
+LIBAC_A_SRC := $(LIBAC_SO_SRC) # source for libac.a
+
+#############################################################################
+#
+# Make rules
+#
+#############################################################################
+#
+.PHONY = all clean test benchmark prepare
+all : $(C_SO_NAME) $(LUA_SO_NAME) $(AR_NAME)
+
+-include c_so_dep.txt
+-include lua_so_dep.txt
+-include ar_dep.txt
+
+BUILD_SO_DIR := build_so
+BUILD_AR_DIR := build_ar
+
+$(BUILD_SO_DIR) :; mkdir $@
+$(BUILD_AR_DIR) :; mkdir $@
+
+$(BUILD_SO_DIR)/%.o : %.cxx | $(BUILD_SO_DIR)
+ $(CXX) $< -c $(SO_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@
+
+$(BUILD_AR_DIR)/%.o : %.cxx | $(BUILD_AR_DIR)
+ $(CXX) $< -c $(AR_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@
+
+ifneq ($(OS), Darwin)
+$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared -Wl,-soname=$(C_SO_NAME) $(SO_LFLAGS) -o $@
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt
+
+$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared -Wl,-soname=$(LUA_SO_NAME) $(SO_LFLAGS) -o $@
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt
+
+else
+$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared $(SO_LFLAGS) -o $@
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt
+
+$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared $(SO_LFLAGS) -o $@ -Wl,-undefined,dynamic_lookup
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt
+endif
+
+$(AR_NAME) : $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.o})
+ $(AR) $(AR_FLAGS) $@ $+
+ cat $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.d}) > lua_so_dep.txt
+
+#############################################################################
+#
+# Misc
+#
+#############################################################################
+#
+test : $(C_SO_NAME)
+ $(MAKE) -C tests && \
+ luajit tests/lua_test.lua && \
+ luajit tests/load_ac_test.lua
+
+benchmark: $(C_SO_NAME)
+ $(MAKE) benchmark -C tests
+
+clean :
+ -rm -rf *.o *.d c_so_dep.txt lua_so_dep.txt ar_dep.txt $(TEST) \
+ $(C_SO_NAME) $(LUA_SO_NAME) $(TEST) $(BUILD_SO_DIR) $(BUILD_AR_DIR) \
+ $(AR_NAME)
+ make clean -C tests
+
+install:
+ install -D -m 755 $(C_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(C_SO_NAME)
+ install -D -m 755 $(LUA_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(LUA_SO_NAME)
+ install -D -m 664 load_ac.lua $(DESTDIR)/$(LUA_TARGET_DIR)/load_ac.lua
diff --git a/modules/policy/lua-aho-corasick/README.md b/modules/policy/lua-aho-corasick/README.md
new file mode 100644
index 0000000..b5cc406
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/README.md
@@ -0,0 +1,40 @@
+aho-corasick-lua
+================
+
+C++ and Lua Implementation of the Aho-Corasick (AC) string matching algorithm
+(http://dl.acm.org/citation.cfm?id=360855).
+
+We began with pure Lua implementation and realize the performance is not
+satisfactory. So we switch to C/C++ implementation.
+
+There are two shared objects provied by this package: libac.so and ahocorasick.so
+The former is a regular shared object which can be directly used by C/C++
+application, or by Lua via FFI; and the later is a Lua module. An example usage
+is shown bellow:
+
+```lua
+local ac = require "ahocorasick"
+local dict = {"string1", "string", "etc"}
+local acinst = ac.create(dict)
+local r = ac.match(acinst, "mystring")
+```
+
+For efficiency reasons, the implementation is slightly different from the
+standard AC algorithm in that it doesn't return a set of strings in the dictionary
+that match the given string, instead it only returns one of them in case the string
+matches. The functionality of our implementation can be (precisely) described by
+following pseudo-c snippet.
+
+```C
+string foo(input-string, dictionary) {
+ string ret = the-end-of-input-string;
+ for each string s in dictionary {
+ // find the first occurrence match sub-string.
+ ret = min(ret, strstr(input-string, s);
+ }
+ return ret;
+}
+```
+
+It's pretty easy to get rid of this limitation, just to associate each state with
+a spare bit-vector dipicting the set of strings recognized by that state.
diff --git a/modules/policy/lua-aho-corasick/ac.cxx b/modules/policy/lua-aho-corasick/ac.cxx
new file mode 100644
index 0000000..23fb3b5
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac.cxx
@@ -0,0 +1,101 @@
+// Interface functions for libac.so
+//
+#include "ac_slow.hpp"
+#include "ac_fast.hpp"
+#include "ac.h"
+
+static inline ac_result_t
+_match(buf_header_t* ac, const char* str, unsigned int len) {
+ AC_Buffer* buf = (AC_Buffer*)(void*)ac;
+ ASSERT(ac->magic_num == AC_MAGIC_NUM);
+
+ ac_result_t r = Match(buf, str, len);
+
+ #ifdef VERIFY
+ {
+ Match_Result r2 = buf->slow_impl->Match(str, len);
+ if (r.match_begin != r2.begin) {
+ ASSERT(0);
+ } else {
+ ASSERT((r.match_begin < 0) ||
+ (r.match_end == r2.end &&
+ r.pattern_idx == r2.pattern_idx));
+ }
+ }
+ #endif
+ return r;
+}
+
+extern "C" int
+ac_match2(ac_t* ac, const char* str, unsigned int len) {
+ ac_result_t r = _match((buf_header_t*)(void*)ac, str, len);
+ return r.match_begin;
+}
+
+extern "C" ac_result_t
+ac_match(ac_t* ac, const char* str, unsigned int len) {
+ return _match((buf_header_t*)(void*)ac, str, len);
+}
+
+extern "C" ac_result_t
+ac_match_longest_l(ac_t* ac, const char* str, unsigned int len) {
+ AC_Buffer* buf = (AC_Buffer*)(void*)ac;
+ ASSERT(((buf_header_t*)ac)->magic_num == AC_MAGIC_NUM);
+
+ ac_result_t r = Match_Longest_L(buf, str, len);
+ return r;
+}
+
+class BufAlloc : public Buf_Allocator {
+public:
+ virtual AC_Buffer* alloc(int sz) {
+ return (AC_Buffer*)(new unsigned char[sz]);
+ }
+
+ // Do not de-allocate the buffer when the BufAlloc die.
+ virtual void free() {}
+
+ static void myfree(AC_Buffer* buf) {
+ ASSERT(buf->hdr.magic_num == AC_MAGIC_NUM);
+ const char* b = (const char*)buf;
+ delete[] b;
+ }
+};
+
+extern "C" ac_t*
+ac_create(const char** strv, unsigned int* strlenv, unsigned int v_len) {
+ if (v_len >= 65535) {
+ // TODO: Currently we use 16-bit to encode pattern-index (see the
+ // comment to AC_State::is_term), therefore we are not able to
+ // handle pattern set with more than 65535 entries.
+ return 0;
+ }
+
+ ACS_Constructor *acc;
+#ifdef VERIFY
+ acc = new ACS_Constructor;
+#else
+ ACS_Constructor tmp;
+ acc = &tmp;
+#endif
+ acc->Construct(strv, strlenv, v_len);
+
+ BufAlloc ba;
+ AC_Converter cvt(*acc, ba);
+ AC_Buffer* buf = cvt.Convert();
+
+#ifdef VERIFY
+ buf->slow_impl = acc;
+#endif
+ return (ac_t*)(void*)buf;
+}
+
+extern "C" void
+ac_free(void* ac) {
+ AC_Buffer* buf = (AC_Buffer*)ac;
+#ifdef VERIFY
+ delete buf->slow_impl;
+#endif
+
+ BufAlloc::myfree(buf);
+}
diff --git a/modules/policy/lua-aho-corasick/ac.h b/modules/policy/lua-aho-corasick/ac.h
new file mode 100644
index 0000000..30bf447
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac.h
@@ -0,0 +1,49 @@
+#ifndef AC_H
+#define AC_H
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define AC_EXPORT __attribute__ ((visibility ("default")))
+
+/* If the subject-string dosen't match any of the given patterns, "match_begin"
+ * should be a negative; otherwise the substring of the subject-string,
+ * starting from offset "match_begin" to "match_end" incusively,
+ * should exactly match the pattern specified by the 'pattern_idx' (i.e.
+ * the pattern is "pattern_v[pattern_idx]" where the "pattern_v" is the
+ * first acutal argument passing to ac_create())
+ */
+typedef struct {
+ int match_begin;
+ int match_end;
+ int pattern_idx;
+} ac_result_t;
+
+struct ac_t;
+
+/* Create an AC instance. "pattern_v" is a vector of patterns, the length of
+ * i-th pattern is specified by "pattern_len_v[i]"; the number of patterns
+ * is specified by "vect_len".
+ *
+ * Return the instance on success, or NUL otherwise.
+ */
+ac_t* ac_create(const char** pattern_v, unsigned int* pattern_len_v,
+ unsigned int vect_len) AC_EXPORT;
+
+ac_result_t ac_match(ac_t*, const char *str, unsigned int len) AC_EXPORT;
+
+ac_result_t ac_match_longest_l(ac_t*, const char *str, unsigned int len) AC_EXPORT;
+
+/* Similar to ac_match() except that it only returns match-begin. The rationale
+ * for this interface is that luajit has hard time in dealing with strcture-
+ * return-value.
+ */
+int ac_match2(ac_t*, const char *str, unsigned int len) AC_EXPORT;
+
+void ac_free(void*) AC_EXPORT;
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* AC_H */
diff --git a/modules/policy/lua-aho-corasick/ac_fast.cxx b/modules/policy/lua-aho-corasick/ac_fast.cxx
new file mode 100644
index 0000000..9dbc2e6
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_fast.cxx
@@ -0,0 +1,468 @@
+#include <algorithm> // for std::sort
+#include "ac_slow.hpp"
+#include "ac_fast.hpp"
+
+uint32
+AC_Converter::Calc_State_Sz(const ACS_State* s) const {
+ AC_State dummy;
+ uint32 sz = offsetof(AC_State, input_vect);
+ sz += s->Get_GotoNum() * sizeof(dummy.input_vect[0]);
+
+ if (sz < sizeof(AC_State))
+ sz = sizeof(AC_State);
+
+ uint32 align = __alignof__(dummy);
+ sz = (sz + align - 1) & ~(align - 1);
+ return sz;
+}
+
+AC_Buffer*
+AC_Converter::Alloc_Buffer() {
+ const vector<ACS_State*>& all_states = _acs.Get_All_States();
+ const ACS_State* root_state = _acs.Get_Root_State();
+ uint32 root_fanout = root_state->Get_GotoNum();
+
+ // Step 1: Calculate the buffer size
+ AC_Ofst root_goto_ofst, states_ofst_ofst, first_state_ofst;
+
+ // part 1 : buffer header
+ uint32 sz = root_goto_ofst = sizeof(AC_Buffer);
+
+ // part 2: Root-node's goto function
+ if (likely(root_fanout != 255))
+ sz += 256;
+ else
+ root_goto_ofst = 0;
+
+ // part 3: mapping of state's relative position.
+ unsigned align = __alignof__(AC_Ofst);
+ sz = (sz + align - 1) & ~(align - 1);
+ states_ofst_ofst = sz;
+
+ sz += sizeof(AC_Ofst) * all_states.size();
+
+ // part 4: state's contents
+ align = __alignof__(AC_State);
+ sz = (sz + align - 1) & ~(align - 1);
+ first_state_ofst = sz;
+
+ uint32 state_sz = 0;
+ for (vector<ACS_State*>::const_iterator i = all_states.begin(),
+ e = all_states.end(); i != e; i++) {
+ state_sz += Calc_State_Sz(*i);
+ }
+ state_sz -= Calc_State_Sz(root_state);
+
+ sz += state_sz;
+
+ // Step 2: Allocate buffer, and populate header.
+ AC_Buffer* buf = _buf_alloc.alloc(sz);
+
+ buf->hdr.magic_num = AC_MAGIC_NUM;
+ buf->hdr.impl_variant = IMPL_FAST_VARIANT;
+ buf->buf_len = sz;
+ buf->root_goto_ofst = root_goto_ofst;
+ buf->states_ofst_ofst = states_ofst_ofst;
+ buf->first_state_ofst = first_state_ofst;
+ buf->root_goto_num = root_fanout;
+ buf->state_num = _acs.Get_State_Num();
+ return buf;
+}
+
+void
+AC_Converter::Populate_Root_Goto_Func(AC_Buffer* buf,
+ GotoVect& goto_vect) {
+ unsigned char *buf_base = (unsigned char*)(buf);
+ InputTy* root_gotos = (InputTy*)(buf_base + buf->root_goto_ofst);
+ const ACS_State* root_state = _acs.Get_Root_State();
+
+ root_state->Get_Sorted_Gotos(goto_vect);
+
+ // Renumber the ID of root-node's immediate kids.
+ uint32 new_id = 1;
+ bool full_fantout = (goto_vect.size() == 255);
+ if (likely(!full_fantout))
+ bzero(root_gotos, 256*sizeof(InputTy));
+
+ for (GotoVect::iterator i = goto_vect.begin(), e = goto_vect.end();
+ i != e; i++, new_id++) {
+ InputTy c = i->first;
+ ACS_State* s = i->second;
+ _id_map[s->Get_ID()] = new_id;
+
+ if (likely(!full_fantout))
+ root_gotos[c] = new_id;
+ }
+}
+
+AC_Buffer*
+AC_Converter::Convert() {
+ // Step 1: Some preparation stuff.
+ GotoVect gotovect;
+
+ _id_map.clear();
+ _ofst_map.clear();
+ _id_map.resize(_acs.Get_Next_Node_Id());
+ _ofst_map.resize(_acs.Get_Next_Node_Id());
+
+ // Step 2: allocate buffer to accommodate the entire AC graph.
+ AC_Buffer* buf = Alloc_Buffer();
+ unsigned char* buf_base = (unsigned char*)buf;
+
+ // Step 3: Root node need special care.
+ Populate_Root_Goto_Func(buf, gotovect);
+ buf->root_goto_num = gotovect.size();
+ _id_map[_acs.Get_Root_State()->Get_ID()] = 0;
+
+ // Step 4: Converting the remaining states by BFSing the graph.
+ // First of all, enter root's immediate kids to the working list.
+ vector<const ACS_State*> wl;
+ State_ID id = 1;
+ for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end();
+ i != e; i++, id++) {
+ ACS_State* s = i->second;
+ wl.push_back(s);
+ _id_map[s->Get_ID()] = id;
+ }
+
+ AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst);
+ AC_Ofst ofst = buf->first_state_ofst;
+ for (uint32 idx = 0; idx < wl.size(); idx++) {
+ const ACS_State* old_s = wl[idx];
+ AC_State* new_s = (AC_State*)(buf_base + ofst);
+
+ // This property should hold as we:
+ // - States are appended to worklist in the BFS order.
+ // - sibiling states are appended to worklist in the order of their
+ // corresponding input.
+ //
+ State_ID state_id = idx + 1;
+ ASSERT(_id_map[old_s->Get_ID()] == state_id);
+
+ state_ofst_vect[state_id] = ofst;
+
+ new_s->first_kid = wl.size() + 1;
+ new_s->depth = old_s->Get_Depth();
+ new_s->is_term = old_s->is_Terminal() ?
+ old_s->get_Pattern_Idx() + 1 : 0;
+
+ uint32 gotonum = old_s->Get_GotoNum();
+ new_s->goto_num = gotonum;
+
+ // Populate the "input" field
+ old_s->Get_Sorted_Gotos(gotovect);
+ uint32 input_idx = 0;
+ uint32 id = wl.size() + 1;
+ InputTy* input_vect = new_s->input_vect;
+ for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end();
+ i != e; i++, id++, input_idx++) {
+ input_vect[input_idx] = i->first;
+
+ ACS_State* kid = i->second;
+ _id_map[kid->Get_ID()] = id;
+ wl.push_back(kid);
+ }
+
+ _ofst_map[old_s->Get_ID()] = ofst;
+ ofst += Calc_State_Sz(old_s);
+ }
+
+ // This assertion might be useful to catch buffer overflow
+ ASSERT(ofst == buf->buf_len);
+
+ // Populate the fail-link field.
+ for (vector<const ACS_State*>::iterator i = wl.begin(), e = wl.end();
+ i != e; i++) {
+ const ACS_State* slow_s = *i;
+ State_ID fast_s_id = _id_map[slow_s->Get_ID()];
+ AC_State* fast_s = (AC_State*)(buf_base + state_ofst_vect[fast_s_id]);
+ if (const ACS_State* fl = slow_s->Get_FailLink()) {
+ State_ID id = _id_map[fl->Get_ID()];
+ fast_s->fail_link = id;
+ } else
+ fast_s->fail_link = 0;
+ }
+#ifdef DEBUG
+ //dump_buffer(buf, stderr);
+#endif
+ return buf;
+}
+
+static inline AC_State*
+Get_State_Addr(unsigned char* buf_base, AC_Ofst* StateOfstVect, uint32 state_id) {
+ ASSERT(state_id != 0 && "root node is handled in speical way");
+ ASSERT(state_id < ((AC_Buffer*)buf_base)->state_num);
+ return (AC_State*)(buf_base + StateOfstVect[state_id]);
+}
+
+// The performance of the binary search is critical to this work.
+//
+// Here we provide two versions of binary-search functions.
+// The non-pristine version seems to consistently out-perform "pristine" one on
+// bunch of benchmarks we tested. With the benchmark under tests/testinput/
+//
+// The speedup is following on my laptop (core i7, ubuntu):
+//
+// benchmark was is
+// ----------------------------------------
+// image.bin 2.3s 2.0s
+// test.tar 6.7s 5.7s
+//
+// NOTE: As of I write this comment, we only measure the performance on about
+// 10+ benchmarks. It's still too early to say which one works better.
+//
+#if !defined(BS_MULTI_VER)
+static bool __attribute__((always_inline)) inline
+Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) {
+ if (vect_len <= 8) {
+ for (int i = 0; i < vect_len; i++) {
+ if (input_vect[i] == input) {
+ idx = i;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ // The "low" and "high" must be signed integers, as they could become -1.
+ // Also since they are signed integer, "(low + high)/2" is sightly more
+ // expensive than (low+high)>>1 or ((unsigned)(low + high))/2.
+ //
+ int low = 0, high = vect_len - 1;
+ while (low <= high) {
+ int mid = (low + high) >> 1;
+ InputTy mid_c = input_vect[mid];
+
+ if (input < mid_c)
+ high = mid - 1;
+ else if (input > mid_c)
+ low = mid + 1;
+ else {
+ idx = mid;
+ return true;
+ }
+ }
+ return false;
+}
+
+#else
+
+/* Let us call this version "pristine" version. */
+static inline bool
+Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) {
+ int low = 0, high = vect_len - 1;
+ while (low <= high) {
+ int mid = (low + high) >> 1;
+ InputTy mid_c = input_vect[mid];
+
+ if (input < mid_c)
+ high = mid - 1;
+ else if (input > mid_c)
+ low = mid + 1;
+ else {
+ idx = mid;
+ return true;
+ }
+ }
+ return false;
+}
+#endif
+
+typedef enum {
+ // Look for the first match. e.g. pattern set = {"ab", "abc", "def"},
+ // subject string "ababcdef". The first match would be "ab" at the
+ // beginning of the subject string.
+ MV_FIRST_MATCH,
+
+ // Look for the left-most longest match. Follow above example; there are
+ // two longest matches, "abc" and "def", and the left-most longest match
+ // is "abc".
+ MV_LEFT_LONGEST,
+
+ // Similar to the left-most longest match, except that it returns the
+ // *right* most longest match. Follow above example, the match would
+ // be "def". NYI.
+ MV_RIGHT_LONGEST,
+
+ // Return all patterns that match that given subject string. NYI.
+ MV_ALL_MATCHES,
+} MATCH_VARIANT;
+
+/* The Match_Tmpl is the template for vairants MV_FIRST_MATCH, MV_LEFT_LONGEST,
+ * MV_RIGHT_LONGEST (If we really really need MV_RIGHT_LONGEST variant, we are
+ * better off implementing it in a seprate function).
+ *
+ * The Match_Tmpl supports three variants at once "symbolically", once it's
+ * instanced to a particular variants, all the code irrelevant to the variants
+ * will be statically removed. So don't worry about the code like
+ * "if (variant == MV_XXXX)"; they will not incur any penalty.
+ *
+ * The drawback of using template is increased code size. Unfortunately, there
+ * is no silver bullet.
+ */
+template<MATCH_VARIANT variant> static ac_result_t
+Match_Tmpl(AC_Buffer* buf, const char* str, uint32 len) {
+ unsigned char* buf_base = (unsigned char*)(buf);
+ unsigned char* root_goto = buf_base + buf->root_goto_ofst;
+ AC_Ofst* states_ofst_vect = (AC_Ofst* )(buf_base + buf->states_ofst_ofst);
+
+ AC_State* state = 0;
+ uint32 idx = 0;
+
+ // Skip leading chars that are not valid input of root-nodes.
+ if (likely(buf->root_goto_num != 255)) {
+ while(idx < len) {
+ unsigned char c = str[idx++];
+ if (unsigned char kid_id = root_goto[c]) {
+ state = Get_State_Addr(buf_base, states_ofst_vect, kid_id);
+ break;
+ }
+ }
+ } else {
+ idx = 1;
+ state = Get_State_Addr(buf_base, states_ofst_vect, *str);
+ }
+
+ ac_result_t r = {-1, -1};
+ if (likely(state != 0)) {
+ if (unlikely(state->is_term)) {
+ /* Dictionary may have string of length 1 */
+ r.match_begin = idx - state->depth;
+ r.match_end = idx - 1;
+ r.pattern_idx = state->is_term - 1;
+
+ if (variant == MV_FIRST_MATCH) {
+ return r;
+ }
+ }
+ }
+
+ while (idx < len) {
+ unsigned char c = str[idx];
+ int res;
+ bool found;
+ found = Binary_Search_Input(state->input_vect, state->goto_num, c, res);
+ if (found) {
+ // The "t = goto(c, current_state)" is valid, advance to state "t".
+ uint32 kid = state->first_kid + res;
+ state = Get_State_Addr(buf_base, states_ofst_vect, kid);
+ idx++;
+ } else {
+ // Follow the fail-link.
+ State_ID fl = state->fail_link;
+ if (fl == 0) {
+ // fail-link is root-node, which implies the root-node dosen't
+ // have 255 valid transitions (otherwise, the fail-link should
+ // points to "goto(root, c)"), so we don't need speical handling
+ // as we did before this while-loop is entered.
+ //
+ while(idx < len) {
+ InputTy c = str[idx++];
+ if (unsigned char kid_id = root_goto[c]) {
+ state =
+ Get_State_Addr(buf_base, states_ofst_vect, kid_id);
+ break;
+ }
+ }
+ } else {
+ state = Get_State_Addr(buf_base, states_ofst_vect, fl);
+ }
+ }
+
+ // Check to see if the state is terminal state?
+ if (state->is_term) {
+ if (variant == MV_FIRST_MATCH) {
+ ac_result_t r;
+ r.match_begin = idx - state->depth;
+ r.match_end = idx - 1;
+ r.pattern_idx = state->is_term - 1;
+ return r;
+ }
+
+ if (variant == MV_LEFT_LONGEST) {
+ int match_begin = idx - state->depth;
+ int match_end = idx - 1;
+
+ if (r.match_begin == -1 ||
+ match_end - match_begin > r.match_end - r.match_begin) {
+ r.match_begin = match_begin;
+ r.match_end = match_end;
+ r.pattern_idx = state->is_term - 1;
+ }
+ continue;
+ }
+
+ ASSERT(false && "NYI");
+ }
+ }
+
+ return r;
+}
+
+ac_result_t
+Match(AC_Buffer* buf, const char* str, uint32 len) {
+ return Match_Tmpl<MV_FIRST_MATCH>(buf, str, len);
+}
+
+ac_result_t
+Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len) {
+ return Match_Tmpl<MV_LEFT_LONGEST>(buf, str, len);
+}
+
+#ifdef DEBUG
+void
+AC_Converter::dump_buffer(AC_Buffer* buf, FILE* f) {
+ vector<AC_Ofst> state_ofst;
+ state_ofst.resize(_id_map.size());
+
+ fprintf(f, "Id maps between old/slow and new/fast graphs\n");
+ int old_id = 0;
+ for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end();
+ i != e; i++, old_id++) {
+ State_ID new_id = *i;
+ if (new_id != 0) {
+ fprintf(f, "%d -> %d, ", old_id, new_id);
+ }
+ }
+ fprintf(f, "\n");
+
+ int idx = 0;
+ for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end();
+ i != e; i++, idx++) {
+ uint32 id = *i;
+ if (id == 0) continue;
+ state_ofst[id] = _ofst_map[idx];
+ }
+
+ unsigned char* buf_base = (unsigned char*)buf;
+
+ // dump root goto-function.
+ fprintf(f, "root, fanout:%d goto {", buf->root_goto_num);
+ if (buf->root_goto_num != 255) {
+ unsigned char* root_goto = buf_base + buf->root_goto_ofst;
+ for (uint32 i = 0; i < 255; i++) {
+ if (root_goto[i] != 0)
+ fprintf(f, "%c->S:%d, ", (unsigned char)i, root_goto[i]);
+ }
+ } else {
+ fprintf(f, "full fanout\n");
+ }
+ fprintf(f, "}\n");
+
+ // dump remaining states.
+ AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst);
+ for (uint32 i = 1, e = buf->state_num; i < e; i++) {
+ AC_Ofst ofst = state_ofst_vect[i];
+ ASSERT(ofst == state_ofst[i]);
+ fprintf(f, "S:%d, ofst:%d, goto={", i, ofst);
+
+ AC_State* s = (AC_State*)(buf_base + ofst);
+ State_ID kid = s->first_kid;
+ for (uint32 k = 0, ke = s->goto_num; k < ke; k++, kid++)
+ fprintf(f, "%c->S:%d, ", s->input_vect[k], kid);
+
+ fprintf(f, "}, fail-link = S:%d, %s\n", s->fail_link,
+ s->is_term ? "terminal" : "");
+ }
+}
+#endif
diff --git a/modules/policy/lua-aho-corasick/ac_fast.hpp b/modules/policy/lua-aho-corasick/ac_fast.hpp
new file mode 100644
index 0000000..9ac557c
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_fast.hpp
@@ -0,0 +1,124 @@
+#ifndef AC_FAST_H
+#define AC_FAST_H
+
+#include <vector>
+#include "ac.h"
+#include "ac_slow.hpp"
+
+using namespace std;
+
+class ACS_Constructor;
+
+typedef uint32 AC_Ofst;
+typedef uint32 State_ID;
+
+// The entire "fast" AC graph is converted from its "slow" version, and store
+// in an consecutive trunk of memory or "buffer". Since the pointers in the
+// fast AC graph are represented as offset relative to the base address of
+// the buffer, this fast AC graph is position-independent, meaning cloning
+// the fast graph is just to memcpy the entire buffer.
+//
+// The buffer is laid-out as following:
+//
+// 1. The buffer header. (i.e. the AC_Buffer content)
+// 2. root-node's goto functions. It is represented as an array indiced by
+// root-node's valid inputs, and the element is the ID of the corresponding
+// transition state (aka kid). To save space, we used 8-bit to represent
+// the IDs. ID of root's kids starts with 1.
+//
+// Root may have 255 valid inputs. In this speical case, i-th element
+// stores value i -- i.e the i-th state. So, we don't need such array
+// at all. On the other hand, 8-bit is insufficient to encode kids' ID.
+//
+// 3. An array indiced by state's id, and the element is the offset
+// of correspoding state wrt the base address of the buffer.
+//
+// 4. the contents of states.
+//
+typedef struct {
+ buf_header_t hdr; // The header exposed to the user using this lib.
+#ifdef VERIFY
+ ACS_Constructor* slow_impl;
+#endif
+ uint32 buf_len;
+ AC_Ofst root_goto_ofst; // addr of root node's goto() function.
+ AC_Ofst states_ofst_ofst; // addr of state pointer vector (indiced by id)
+ AC_Ofst first_state_ofst; // addr of the first state in the buffer.
+ uint16 root_goto_num; // fan-out of root-node.
+ uint16 state_num; // number of states
+
+ // Followed by the gut of the buffer:
+ // 1. map: root's-valid-input -> kid's id
+ // 2. map: state's ID -> offset of the state
+ // 3. states' content.
+} AC_Buffer;
+
+// Depict the state of "fast" AC graph.
+typedef struct {
+ // transition are sorted. For instance, state s1, has two transitions :
+ // goto(b) -> S_b, goto(a)->S_a. The inputs are sorted in the ascending
+ // order, and the target states are permuted accordingly. In this case,
+ // the inputs are sorted as : a, b, and the target states are permuted
+ // into S_a, S_b. So, S_a is the 1st kid, the ID of kids are consecutive,
+ // so we don't need to save all the target kids.
+ //
+ State_ID first_kid;
+ AC_Ofst fail_link;
+ short depth; // How far away from root.
+ unsigned short is_term; // Is terminal node. if is_term != 0, it encodes
+ // the value of "1 + pattern-index".
+ unsigned char goto_num; // The number of valid transition.
+ InputTy input_vect[1]; // Vector of valid input. Must be last field!
+} AC_State;
+
+class Buf_Allocator {
+public:
+ Buf_Allocator() : _buf(0) {}
+ virtual ~Buf_Allocator() { free(); }
+
+ virtual AC_Buffer* alloc(int sz) = 0;
+ virtual void free() {};
+protected:
+ AC_Buffer* _buf;
+};
+
+// Convert slow-AC-graph into fast one.
+class AC_Converter {
+public:
+ AC_Converter(ACS_Constructor& acs, Buf_Allocator& ba) :
+ _acs(acs), _buf_alloc(ba) {}
+ AC_Buffer* Convert();
+
+private:
+ // Return the size in byte needed to to save the specified state.
+ uint32 Calc_State_Sz(const ACS_State *) const;
+
+ // In fast-AC-graph, the ID is bit trikcy. Given a state of slow-graph,
+ // this function is to return the ID of its counterpart in the fast-graph.
+ State_ID Get_Renumbered_Id(const ACS_State *s) const {
+ const vector<uint32> &m = _id_map;
+ return m[s->Get_ID()];
+ }
+
+ AC_Buffer* Alloc_Buffer();
+ void Populate_Root_Goto_Func(AC_Buffer *, GotoVect&);
+
+#ifdef DEBUG
+ void dump_buffer(AC_Buffer*, FILE*);
+#endif
+
+private:
+ ACS_Constructor& _acs;
+ Buf_Allocator& _buf_alloc;
+
+ // map: ID of state in slow-graph -> ID of counterpart in fast-graph.
+ vector<uint32> _id_map;
+
+ // map: ID of state in slow-graph -> offset of counterpart in fast-graph.
+ vector<AC_Ofst> _ofst_map;
+};
+
+ac_result_t Match(AC_Buffer* buf, const char* str, uint32 len);
+ac_result_t Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len);
+
+#endif // AC_FAST_H
diff --git a/modules/policy/lua-aho-corasick/ac_lua.cxx b/modules/policy/lua-aho-corasick/ac_lua.cxx
new file mode 100644
index 0000000..ad7307e
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_lua.cxx
@@ -0,0 +1,173 @@
+// Interface functions for libac.so
+//
+#include <vector>
+#include <string>
+#include "ac_slow.hpp"
+#include "ac_fast.hpp"
+#include "ac.h" // for the definition of ac_result_t
+#include "ac_util.hpp"
+
+extern "C" {
+ #include <lua.h>
+ #include <lauxlib.h>
+}
+
+#if defined(USE_SLOW_VER)
+#error "Not going to implement it"
+#endif
+
+using namespace std;
+static const char* tname = "aho-corasick";
+
+class BufAlloc : public Buf_Allocator {
+public:
+ BufAlloc(lua_State* L) : _L(L) {}
+ virtual AC_Buffer* alloc(int sz) {
+ return (AC_Buffer*)lua_newuserdata (_L, sz);
+ }
+
+ // Let GC to take care.
+ virtual void free() {}
+
+private:
+ lua_State* _L;
+};
+
+static bool
+_create_helper(lua_State* L, const vector<const char*>& str_v,
+ const vector<unsigned int>& strlen_v) {
+ ASSERT(str_v.size() == strlen_v.size());
+
+ ACS_Constructor acc;
+ BufAlloc ba(L);
+
+ // Step 1: construt the slow version.
+ unsigned int strnum = str_v.size();
+ const char** str_vect = new const char*[strnum];
+ unsigned int* strlen_vect = new unsigned int[strnum];
+
+ int idx = 0;
+ for (vector<const char*>::const_iterator i = str_v.begin(), e = str_v.end();
+ i != e; i++) {
+ str_vect[idx++] = *i;
+ }
+
+ idx = 0;
+ for (vector<unsigned int>::const_iterator i = strlen_v.begin(),
+ e = strlen_v.end(); i != e; i++) {
+ strlen_vect[idx++] = *i;
+ }
+
+ acc.Construct(str_vect, strlen_vect, idx);
+ delete[] str_vect;
+ delete[] strlen_vect;
+
+ // Step 2: convert to fast version
+ AC_Converter cvt(acc, ba);
+ return cvt.Convert() != 0;
+}
+
+static ac_result_t
+_match_helper(buf_header_t* ac, const char *str, unsigned int len) {
+ AC_Buffer* buf = (AC_Buffer*)(void*)ac;
+ ASSERT(ac->magic_num == AC_MAGIC_NUM);
+
+ ac_result_t r = Match(buf, str, len);
+ return r;
+}
+
+// LUA sematic:
+// input: array of strings
+// output: userdata containing the AC-graph (i.e. the AC_Buffer).
+//
+static int
+lac_create(lua_State* L) {
+ // The table of the array must be the 1st argument.
+ int input_tab = 1;
+
+ luaL_checktype(L, input_tab, LUA_TTABLE);
+
+ // Init the "iteartor".
+ lua_pushnil(L);
+
+ vector<const char*> str_v;
+ vector<unsigned int> strlen_v;
+
+ // Loop over the elements
+ while (lua_next(L, input_tab)) {
+ size_t str_len;
+ const char* s = luaL_checklstring(L, -1, &str_len);
+ str_v.push_back(s);
+ strlen_v.push_back(str_len);
+
+ // remove the value, but keep the key as the iterator.
+ lua_pop(L, 1);
+ }
+
+ // pop the nil value
+ lua_pop(L, 1);
+
+ if (_create_helper(L, str_v, strlen_v)) {
+ // The AC graph, as a userdata is already pushed to the stack, hence 1.
+ return 1;
+ }
+
+ return 0;
+}
+
+// LUA input:
+// arg1: the userdata, representing the AC graph, returned from l_create().
+// arg2: the string to be matched.
+//
+// LUA return:
+// if match, return index range of the match; otherwise nil is returned.
+//
+static int
+lac_match(lua_State* L) {
+ buf_header_t* ac = (buf_header_t*)lua_touserdata(L, 1);
+ if (!ac) {
+ luaL_checkudata(L, 1, tname);
+ return 0;
+ }
+
+ size_t len;
+ const char* str;
+ #if LUA_VERSION_NUM >= 502
+ str = luaL_tolstring(L, 2, &len);
+ #else
+ str = lua_tolstring(L, 2, &len);
+ #endif
+ if (!str) {
+ luaL_checkstring(L, 2);
+ return 0;
+ }
+
+ ac_result_t r = _match_helper(ac, str, len);
+ if (r.match_begin != -1) {
+ lua_pushinteger(L, r.match_begin);
+ lua_pushinteger(L, r.match_end);
+ return 2;
+ }
+
+ return 0;
+}
+
+static const struct luaL_Reg lib_funcs[] = {
+ { "create", lac_create },
+ { "match", lac_match },
+ {0, 0}
+};
+
+extern "C" int AC_EXPORT
+luaopen_ahocorasick(lua_State* L) {
+ luaL_newmetatable(L, tname);
+
+#if LUA_VERSION_NUM == 501
+ luaL_register(L, tname, lib_funcs);
+#elif LUA_VERSION_NUM >= 502
+ luaL_newlib(L, lib_funcs);
+#else
+ #error "Don't know how to do it right"
+#endif
+ return 1;
+}
diff --git a/modules/policy/lua-aho-corasick/ac_slow.cxx b/modules/policy/lua-aho-corasick/ac_slow.cxx
new file mode 100644
index 0000000..cb3957a
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_slow.cxx
@@ -0,0 +1,318 @@
+#include <ctype.h>
+#include <strings.h> // for bzero
+#include <algorithm>
+#include "ac_slow.hpp"
+#include "ac.h"
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Implementation of AhoCorasick_Slow
+//
+//////////////////////////////////////////////////////////////////////////
+//
+ACS_Constructor::ACS_Constructor() : _next_node_id(1) {
+ _root = new_state();
+ _root_char = new InputTy[256];
+ bzero((void*)_root_char, 256);
+
+#ifdef VERIFY
+ _pattern_buf = 0;
+#endif
+}
+
+ACS_Constructor::~ACS_Constructor() {
+ for (std::vector<ACS_State* >::iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ delete *i;
+ }
+ _all_states.clear();
+ delete[] _root_char;
+
+#ifdef VERIFY
+ delete[] _pattern_buf;
+#endif
+}
+
+ACS_State*
+ACS_Constructor::new_state() {
+ ACS_State* t = new ACS_State(_next_node_id++);
+ _all_states.push_back(t);
+ return t;
+}
+
+void
+ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len,
+ int pattern_idx) {
+ ACS_State* state = _root;
+ for (unsigned int i = 0; i < str_len; i++) {
+ const char c = str[i];
+ ACS_State* new_s = state->Get_Goto(c);
+ if (!new_s) {
+ new_s = new_state();
+ new_s->_depth = state->_depth + 1;
+ state->Set_Goto(c, new_s);
+ }
+ state = new_s;
+ }
+ state->_is_terminal = true;
+ state->set_Pattern_Idx(pattern_idx);
+}
+
+void
+ACS_Constructor::Propagate_faillink() {
+ ACS_State* r = _root;
+ std::vector<ACS_State*> wl;
+
+ const ACS_Goto_Map& m = r->Get_Goto_Map();
+ for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) {
+ ACS_State* s = i->second;
+ s->_fail_link = r;
+ wl.push_back(s);
+ }
+
+ // For any input c, make sure "goto(root, c)" is valid, which make the
+ // fail-link propagation lot easier.
+ ACS_Goto_Map goto_save = r->_goto_map;
+ for (uint32 i = 0; i <= 255; i++) {
+ ACS_State* s = r->Get_Goto(i);
+ if (!s) r->Set_Goto(i, r);
+ }
+
+ for (uint32 i = 0; i < wl.size(); i++) {
+ ACS_State* s = wl[i];
+ ACS_State* fl = s->_fail_link;
+
+ const ACS_Goto_Map& tran_map = s->Get_Goto_Map();
+
+ for (ACS_Goto_Map::const_iterator ii = tran_map.begin(),
+ ee = tran_map.end(); ii != ee; ii++) {
+ InputTy c = ii->first;
+ ACS_State *tran = ii->second;
+
+ ACS_State* tran_fl = 0;
+ for (ACS_State* fl_walk = fl; ;) {
+ if (ACS_State* t = fl_walk->Get_Goto(c)) {
+ tran_fl = t;
+ break;
+ } else {
+ fl_walk = fl_walk->Get_FailLink();
+ }
+ }
+
+ tran->_fail_link = tran_fl;
+ wl.push_back(tran);
+ }
+ }
+
+ // Remove "goto(root, c) == root" transitions
+ r->_goto_map = goto_save;
+}
+
+void
+ACS_Constructor::Construct(const char** strv, unsigned int* strlenv,
+ uint32 strnum) {
+ Save_Patterns(strv, strlenv, strnum);
+
+ for (uint32 i = 0; i < strnum; i++) {
+ Add_Pattern(strv[i], strlenv[i], i);
+ }
+
+ Propagate_faillink();
+ unsigned char* p = _root_char;
+
+ const ACS_Goto_Map& m = _root->Get_Goto_Map();
+ for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end();
+ i != e; i++) {
+ p[i->first] = 1;
+ }
+}
+
+Match_Result
+ACS_Constructor::MatchHelper(const char *str, uint32 len) const {
+ const ACS_State* root = _root;
+ const ACS_State* state = root;
+
+ uint32 idx = 0;
+ while (idx < len) {
+ InputTy c = str[idx];
+ idx++;
+ if (_root_char[c]) {
+ state = root->Get_Goto(c);
+ break;
+ }
+ }
+
+ if (unlikely(state->is_Terminal())) {
+ // This could happen if the one of the pattern has only one char!
+ uint32 pos = idx - 1;
+ Match_Result r(pos - state->Get_Depth() + 1, pos,
+ state->get_Pattern_Idx());
+ return r;
+ }
+
+ while (idx < len) {
+ InputTy c = str[idx];
+ ACS_State* gs = state->Get_Goto(c);
+
+ if (!gs) {
+ ACS_State* fl = state->Get_FailLink();
+ if (fl == root) {
+ while (idx < len) {
+ InputTy c = str[idx];
+ idx++;
+ if (_root_char[c]) {
+ state = root->Get_Goto(c);
+ break;
+ }
+ }
+ } else {
+ state = fl;
+ }
+ } else {
+ idx ++;
+ state = gs;
+ }
+
+ if (state->is_Terminal()) {
+ uint32 pos = idx - 1;
+ Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos,
+ state->get_Pattern_Idx());
+ return r;
+ }
+ }
+
+ return Match_Result(-1, -1, -1);
+}
+
+#ifdef DEBUG
+void
+ACS_Constructor::dump_text(const char* txtfile) const {
+ FILE* f = fopen(txtfile, "w+");
+ for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ ACS_State* s = *i;
+
+ fprintf(f, "S%d goto:{", s->Get_ID());
+ const ACS_Goto_Map& goto_func = s->Get_Goto_Map();
+
+ for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end();
+ i != e; i++) {
+ InputTy input = i->first;
+ ACS_State* tran = i->second;
+ if (isprint(input))
+ fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID());
+ else
+ fprintf(f, "%#x -> S:%d,", input, tran->Get_ID());
+ }
+ fprintf(f, "} ");
+
+ if (s->_fail_link) {
+ fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID());
+ }
+
+ if (s->_is_terminal) {
+ fprintf(f, ", terminal");
+ }
+
+ fprintf(f, "\n");
+ }
+ fclose(f);
+}
+
+void
+ACS_Constructor::dump_dot(const char *dotfile) const {
+ FILE* f = fopen(dotfile, "w+");
+ const char* indent = " ";
+
+ fprintf(f, "digraph G {\n");
+
+ // Emit node information
+ fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID());
+ for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ ACS_State *s = *i;
+ if (s->_is_terminal) {
+ fprintf(f, "%s%d [shape=doublecircle];\n", indent, s->Get_ID());
+ }
+ }
+ fprintf(f, "\n");
+
+ // Emit edge information
+ for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ ACS_State* s = *i;
+ uint32 id = s->Get_ID();
+
+ const ACS_Goto_Map& m = s->Get_Goto_Map();
+ for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end();
+ ii != ee; ii++) {
+ InputTy input = ii->first;
+ ACS_State* tran = ii->second;
+ if (isalnum(input))
+ fprintf(f, "%s%d -> %d [label=%c];\n",
+ indent, id, tran->Get_ID(), input);
+ else
+ fprintf(f, "%s%d -> %d [label=\"%#x\"];\n",
+ indent, id, tran->Get_ID(), input);
+
+ }
+
+ // Emit fail-link
+ ACS_State* fl = s->Get_FailLink();
+ if (fl && fl != _root) {
+ fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n",
+ indent, id, fl->Get_ID());
+ }
+ }
+ fprintf(f, "}\n");
+ fclose(f);
+}
+#endif
+
+#ifdef VERIFY
+void
+ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r)
+ const {
+ if (r->begin >= 0) {
+ unsigned len = r->end - r->begin + 1;
+ int ptn_idx = r->pattern_idx;
+
+ ASSERT(ptn_idx >= 0 &&
+ len == get_ith_Pattern_Len(ptn_idx) &&
+ memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0);
+ }
+}
+
+void
+ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv,
+ int pattern_num) {
+ // calculate the total size needed to save all patterns.
+ //
+ int buf_size = 0;
+ for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; }
+
+ // HINT: patterns are delimited by '\0' in order to ease debugging.
+ buf_size += pattern_num;
+ ASSERT(_pattern_buf == 0);
+ _pattern_buf = new char[buf_size + 1];
+ #define MAGIC_NUM 0x5a
+ _pattern_buf[buf_size] = MAGIC_NUM;
+
+ int ofst = 0;
+ _pattern_lens.resize(pattern_num);
+ _pattern_vect.resize(pattern_num);
+ for (int i = 0; i < pattern_num; i++) {
+ int l = strlenv[i];
+ _pattern_lens[i] = l;
+ _pattern_vect[i] = _pattern_buf + ofst;
+
+ memcpy(_pattern_buf + ofst, strv[i], l);
+ ofst += l;
+ _pattern_buf[ofst++] = '\0';
+ }
+
+ ASSERT(_pattern_buf[buf_size] == MAGIC_NUM);
+ #undef MAGIC_NUM
+}
+
+#endif
diff --git a/modules/policy/lua-aho-corasick/ac_slow.hpp b/modules/policy/lua-aho-corasick/ac_slow.hpp
new file mode 100644
index 0000000..030b95d
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_slow.hpp
@@ -0,0 +1,158 @@
+#ifndef MY_AC_H
+#define MY_AC_H
+
+#include <string.h>
+#include <stdio.h>
+#include <map>
+#include <vector>
+#include <algorithm> // for std::sort
+#include "ac_util.hpp"
+
+// Forward decl. the acronym "ACS" stands for "Aho-Corasick Slow implementation"
+class ACS_State;
+class ACS_Constructor;
+class AhoCorasick;
+
+using namespace std;
+
+typedef std::map<InputTy, ACS_State*> ACS_Goto_Map;
+
+class Match_Result {
+public:
+ int begin;
+ int end;
+ int pattern_idx;
+ Match_Result(int b, int e, int p): begin(b), end(e), pattern_idx(p) {}
+};
+
+typedef pair<InputTy, ACS_State *> GotoPair;
+typedef vector<GotoPair> GotoVect;
+
+// Sorting functor
+class GotoSort {
+public:
+ bool operator() (const GotoPair& g1, const GotoPair& g2) {
+ return g1.first < g2.first;
+ }
+};
+
+class ACS_State {
+friend class ACS_Constructor;
+
+public:
+ ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0),
+ _is_terminal(false), _fail_link(0){}
+ ~ACS_State() {};
+
+ void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; }
+ ACS_State *Get_Goto(InputTy c) const {
+ ACS_Goto_Map::const_iterator iter = _goto_map.find(c);
+ return iter != _goto_map.end() ? (*iter).second : 0;
+ }
+
+ // Return all transitions sorted in the ascending order of their input.
+ void Get_Sorted_Gotos(GotoVect& Gotos) const {
+ const ACS_Goto_Map& m = _goto_map;
+ Gotos.clear();
+ for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end();
+ i != e; i++) {
+ Gotos.push_back(GotoPair(i->first, i->second));
+ }
+ sort(Gotos.begin(), Gotos.end(), GotoSort());
+ }
+
+ ACS_State* Get_FailLink() const { return _fail_link; }
+ uint32 Get_GotoNum() const { return _goto_map.size(); }
+ uint32 Get_ID() const { return _id; }
+ uint32 Get_Depth() const { return _depth; }
+ const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; }
+ bool is_Terminal() const { return _is_terminal; }
+ int get_Pattern_Idx() const {
+ ASSERT(is_Terminal() && _pattern_idx >= 0);
+ return _pattern_idx;
+ }
+
+private:
+ void set_Pattern_Idx(int idx) {
+ ASSERT(is_Terminal());
+ _pattern_idx = idx;
+ }
+
+private:
+ uint32 _id;
+ int _pattern_idx;
+ short _depth;
+ bool _is_terminal;
+ ACS_Goto_Map _goto_map;
+ ACS_State* _fail_link;
+};
+
+class ACS_Constructor {
+public:
+ ACS_Constructor();
+ ~ACS_Constructor();
+
+ void Construct(const char** strv, unsigned int* strlenv,
+ unsigned int strnum);
+
+ Match_Result Match(const char* s, uint32 len) const {
+ Match_Result r = MatchHelper(s, len);
+ Verify_Result(s, &r);
+ return r;
+ }
+
+ Match_Result Match(const char* s) const { return Match(s, strlen(s)); }
+
+#ifdef DEBUG
+ void dump_text(const char* = "ac.txt") const;
+ void dump_dot(const char* = "ac.dot") const;
+#endif
+ const ACS_State *Get_Root_State() const { return _root; }
+ const vector<ACS_State*>& Get_All_States() const {
+ return _all_states;
+ }
+
+ uint32 Get_Next_Node_Id() const { return _next_node_id; }
+ uint32 Get_State_Num() const { return _next_node_id - 1; }
+
+private:
+ void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx);
+ ACS_State* new_state();
+ void Propagate_faillink();
+
+ Match_Result MatchHelper(const char*, uint32 len) const;
+
+#ifdef VERIFY
+ void Verify_Result(const char* subject, const Match_Result* r) const;
+ void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len);
+ const char* get_ith_Pattern(unsigned i) const {
+ ASSERT(i < _pattern_vect.size());
+ return _pattern_vect.at(i);
+ }
+ unsigned get_ith_Pattern_Len(unsigned i) const {
+ ASSERT(i < _pattern_lens.size());
+ return _pattern_lens.at(i);
+ }
+#else
+ void Verify_Result(const char* subject, const Match_Result* r) const {
+ (void)subject; (void)r;
+ }
+ void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) {
+ (void)strv; (void)strlenv;
+ }
+#endif
+
+private:
+ ACS_State* _root;
+ vector<ACS_State*> _all_states;
+ unsigned char* _root_char;
+ uint32 _next_node_id;
+
+#ifdef VERIFY
+ char* _pattern_buf;
+ vector<int> _pattern_lens;
+ vector<char*> _pattern_vect;
+#endif
+};
+
+#endif
diff --git a/modules/policy/lua-aho-corasick/ac_util.hpp b/modules/policy/lua-aho-corasick/ac_util.hpp
new file mode 100644
index 0000000..56fd46c
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_util.hpp
@@ -0,0 +1,69 @@
+/*
+ Copyright (c) 2014 CloudFlare, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of CloudFlare, Inc. nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+#ifndef AC_UTIL_H
+#define AC_UTIL_H
+
+#ifdef DEBUG
+#include <stdio.h> // for fprintf
+#include <stdlib.h> // for abort
+#endif
+
+typedef unsigned short uint16;
+typedef unsigned int uint32;
+typedef unsigned long uint64;
+typedef unsigned char InputTy;
+
+#ifdef DEBUG
+ // Usage examples: ASSERT(a > b), ASSERT(foo() && "Opps, foo() reutrn 0");
+ #define ASSERT(c) if (!(c))\
+ { fprintf(stderr, "%s:%d Assert: %s\n", __FILE__, __LINE__, #c); abort(); }
+#else
+ #define ASSERT(c) ((void)0)
+#endif
+
+#define likely(x) __builtin_expect((x),1)
+#define unlikely(x) __builtin_expect((x),0)
+
+#ifndef offsetof
+#define offsetof(st, m) ((size_t)(&((st *)0)->m))
+#endif
+
+typedef enum {
+ IMPL_SLOW_VARIANT = 1,
+ IMPL_FAST_VARIANT = 2,
+} impl_var_t;
+
+#define AC_MAGIC_NUM 0x5a
+typedef struct {
+ unsigned char magic_num;
+ unsigned char impl_variant;
+} buf_header_t;
+
+#endif //AC_UTIL_H
diff --git a/modules/policy/lua-aho-corasick/load_ac.lua b/modules/policy/lua-aho-corasick/load_ac.lua
new file mode 100644
index 0000000..eb70446
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/load_ac.lua
@@ -0,0 +1,90 @@
+-- Helper wrappring script for loading shared object libac.so (FFI interface)
+-- from package.cpath instead of LD_LIBRARTY_PATH.
+--
+
+local ffi = require 'ffi'
+ffi.cdef[[
+ void* ac_create(const char** str_v, unsigned int* strlen_v,
+ unsigned int v_len);
+ int ac_match2(void*, const char *str, int len);
+ void ac_free(void*);
+]]
+
+local _M = {}
+
+local string_gmatch = string.gmatch
+local string_match = string.match
+
+local ac_lib = nil
+local ac_create = nil
+local ac_match = nil
+local ac_free = nil
+
+--[[ Find shared object file package.cpath, obviating the need of setting
+ LD_LIBRARY_PATH
+]]
+local function find_shared_obj(cpath, so_name)
+ for k, v in string_gmatch(cpath, "[^;]+") do
+ local so_path = string_match(k, "(.*/)")
+ if so_path then
+ -- "so_path" could be nil. e.g, the dir path component is "."
+ so_path = so_path .. so_name
+
+ -- Don't get me wrong, the only way to know if a file exist is
+ -- trying to open it.
+ local f = io.open(so_path)
+ if f ~= nil then
+ io.close(f)
+ return so_path
+ end
+ end
+ end
+end
+
+function _M.load_ac_lib()
+ if ac_lib ~= nil then
+ return ac_lib
+ else
+ local so_path = find_shared_obj(package.cpath, "libac.so")
+ if so_path ~= nil then
+ ac_lib = ffi.load(so_path)
+ ac_create = ac_lib.ac_create
+ ac_match = ac_lib.ac_match2
+ ac_free = ac_lib.ac_free
+ return ac_lib
+ end
+ end
+end
+
+-- Create an Aho-Corasick instance, and return the instance if it was
+-- successful.
+function _M.create_ac(dict)
+ local strnum = #dict
+ if ac_lib == nil then
+ _M.load_ac_lib()
+ end
+
+ local str_v = ffi.new("const char *[?]", strnum)
+ local strlen_v = ffi.new("unsigned int [?]", strnum)
+
+ for i = 1, strnum do
+ local s = dict[i]
+ str_v[i - 1] = s
+ strlen_v[i - 1] = #s
+ end
+
+ local ac = ac_create(str_v, strlen_v, strnum);
+ if ac ~= nil then
+ return ffi.gc(ac, ac_free)
+ end
+end
+
+-- Return nil if str doesn't match the dictionary, else return non-nil.
+function _M.match(ac, str)
+ local r = ac_match(ac, str, #str);
+ if r >= 0 then
+ return r
+ end
+end
+
+return _M
diff --git a/modules/policy/lua-aho-corasick/mytest.cxx b/modules/policy/lua-aho-corasick/mytest.cxx
new file mode 100644
index 0000000..ef3dc87
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/mytest.cxx
@@ -0,0 +1,200 @@
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include "ac.h"
+
+using namespace std;
+
+/////////////////////////////////////////////////////////////////////////
+//
+// Test using strings from input files
+//
+/////////////////////////////////////////////////////////////////////////
+//
+class BigFileTester {
+public:
+ BigFileTester(const char* filepath);
+
+private:
+ void Genector
+privaete:
+ const char* _msg;
+ int _msg_len;
+ int _key_num; // number of strings in dictionary
+ int _key_len_idx;
+};
+
+/////////////////////////////////////////////////////////////////////////
+//
+// Simple (yet maybe tricky) testings
+//
+/////////////////////////////////////////////////////////////////////////
+//
+typedef struct {
+ const char* str;
+ const char* match;
+} StrPair;
+
+typedef struct {
+ const char* name;
+ const char** dict;
+ StrPair* strpairs;
+ int dict_len;
+ int strpair_num;
+} TestingCase;
+
+class Tests {
+public:
+ Tests(const char* name,
+ const char* dict[], int dict_len,
+ StrPair strpairs[], int strpair_num) {
+ if (!_tests)
+ _tests = new vector<TestingCase>;
+
+ TestingCase tc;
+ tc.name = name;
+ tc.dict = dict;
+ tc.strpairs = strpairs;
+ tc.dict_len = dict_len;
+ tc.strpair_num = strpair_num;
+ _tests->push_back(tc);
+ }
+
+ static vector<TestingCase>* Get_Tests() { return _tests; }
+ static void Erase_Tests() { delete _tests; _tests = 0; }
+
+private:
+ static vector<TestingCase> *_tests;
+};
+
+vector<TestingCase>* Tests::_tests = 0;
+
+static void
+simple_test(void) {
+ int total = 0;
+ int fail = 0;
+
+ vector<TestingCase> *tests = Tests::Get_Tests();
+ if (!tests)
+ return 0;
+
+ for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end();
+ i != e; i++) {
+ TestingCase& t = *i;
+ fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name);
+ for (int i = 0, e = t.dict_len, need_break=0; i < e; i++) {
+ fprintf(stdout, "%s, ", t.dict[i]);
+ if (need_break++ == 16) {
+ fputs("\n ", stdout);
+ need_break = 0;
+ }
+ }
+ fputs("]\n", stdout);
+
+ /* Create the dictionary */
+ int dict_len = t.dict_len;
+ ac_t* ac = ac_create(t.dict, dict_len);
+
+ for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) {
+ const StrPair& sp = t.strpairs[ii];
+ const char *str = sp.str; // the string to be matched
+ const char *match = sp.match;
+
+ fprintf(stdout, "[%3d] Testing '%s' : ", total, str);
+
+ int len = strlen(str);
+ ac_result_t r = ac_match(ac, str, len);
+ int m_b = r.match_begin;
+ int m_e = r.match_end;
+
+ // The return value per se is insane.
+ if (m_b > m_e ||
+ ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) {
+ fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e);
+ fail ++;
+ continue;
+ }
+
+ // If the string is not supposed to match the dictionary.
+ if (!match) {
+ if (m_b != -1 || m_e != -1) {
+ fail ++;
+ fprintf(stdout, "Not Supposed to match (%d, %d) \n",
+ m_b, m_e);
+ } else
+ fputs("Pass\n", stdout);
+ continue;
+ }
+
+ // The string or its substring is match the dict.
+ if (m_b >= len || m_b >= len) {
+ fail ++;
+ fprintf(stdout,
+ "Return value >= the length of the string (%d, %d)\n",
+ m_b, m_e);
+ continue;
+ } else {
+ int mlen = strlen(match);
+ if ((mlen != m_e - m_b + 1) ||
+ strncmp(str + m_b, match, mlen)) {
+ fail ++;
+ fprintf(stdout, "Fail\n");
+ } else
+ fprintf(stdout, "Pass\n");
+ }
+ }
+ fputs("\n", stdout);
+ ac_free(ac);
+ }
+
+ fprintf(stdout, "Total : %d, Fail %d\n", total, fail);
+
+ return fail ? -1 : 0;
+}
+
+int
+main (int argc, char** argv) {
+ int res = simple_test();
+ return res;
+};
+
+/* test 1*/
+const char *dict1[] = {"he", "she", "his", "her"};
+StrPair strpair1[] = {
+ {"he", "he"}, {"she", "she"}, {"his", "his"},
+ {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"},
+ {"shis2", "his"}, {"ahhe", "he"}
+};
+Tests test1("test 1",
+ dict1, sizeof(dict1)/sizeof(dict1[0]),
+ strpair1, sizeof(strpair1)/sizeof(strpair1[0]));
+
+/* test 2*/
+const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/
+StrPair strpair2[] = {{"The pot had a handle", 0}};
+Tests test2("test 2", dict2, 2, strpair2, 1);
+
+/* test 3*/
+const char *dict3[] = {"The"};
+StrPair strpair3[] = {{"The pot had a handle", "The"}};
+Tests test3("test 3", dict3, 1, strpair3, 1);
+
+/* test 4*/
+const char *dict4[] = {"pot"};
+StrPair strpair4[] = {{"The pot had a handle", "pot"}};
+Tests test4("test 4", dict4, 1, strpair4, 1);
+
+/* test 5*/
+const char *dict5[] = {"pot "};
+StrPair strpair5[] = {{"The pot had a handle", "pot "}};
+Tests test5("test 5", dict5, 1, strpair5, 1);
+
+/* test 6*/
+const char *dict6[] = {"ot h"};
+StrPair strpair6[] = {{"The pot had a handle", "ot h"}};
+Tests test6("test 6", dict6, 1, strpair6, 1);
+
+/* test 7*/
+const char *dict7[] = {"andle"};
+StrPair strpair7[] = {{"The pot had a handle", "andle"}};
+Tests test7("test 7", dict7, 1, strpair7, 1);
diff --git a/modules/policy/lua-aho-corasick/tests/Makefile b/modules/policy/lua-aho-corasick/tests/Makefile
new file mode 100644
index 0000000..54fd90f
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/Makefile
@@ -0,0 +1,65 @@
+OS := $(shell uname)
+ifeq ($(OS), Darwin)
+ SO_EXT := dylib
+else
+ SO_EXT := so
+endif
+
+.PHONY = all clean test runtest benchmark
+
+PROGRAM = ac_test
+BENCHMARK = ac_bench
+all: runtest
+
+CXXFLAGS = -O3 -g -march=native -Wall -DDEBUG
+MYCXXFLAGS = -MMD -I.. $(CXXFLAGS)
+%.o : %.cxx
+ $(CXX) $< -c $(MYCXXFLAGS)
+
+-include dep.cxx
+SRC = test_main.cxx ac_test_simple.cxx ac_test_aggr.cxx test_bigfile.cxx
+
+OBJ = ${SRC:.cxx=.o}
+
+-include test_dep.txt
+-include bench_dep.txt
+
+$(PROGRAM) $(BENCHMARK) : testinput/text.tar testinput/image.bin
+$(PROGRAM) : $(OBJ) ../libac.$(SO_EXT)
+ $(CXX) $(OBJ) -L.. -lac -o $@
+ -cat *.d > test_dep.txt
+
+$(BENCHMARK) : ac_bench.o ../libac.$(SO_EXT)
+ $(CXX) ac_bench.o -L.. -lac -o $@
+ -cat *.d > bench_dep.txt
+
+ifneq ($(OS), Darwin)
+runtest:$(PROGRAM)
+ LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/*
+
+benchmark:$(BENCHMARK)
+ LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./ac_bench
+
+else
+runtest:$(PROGRAM)
+ DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/*
+
+benchmark:$(BENCHMARK)
+ DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./ac_bench
+
+endif
+
+testinput/text.tar:
+ echo "download testing files (gcc tarball)..."
+ if [ ! -d testinput ] ; then mkdir testinput; fi
+ cd testinput && \
+ curl ftp://ftp.gnu.org/gnu/gcc/gcc-1.42.tar.gz -o text.tar.gz 2>/dev/null \
+ && gzip -d text.tar.gz
+
+testinput/image.bin:
+ echo "download testing files.."
+ if [ ! -d testinput ] ; then mkdir testinput; fi
+ curl http://www.3dvisionlive.com/sites/default/files/Curiosity_render_hiresb.jpg -o $@ 2>/dev/null
+
+clean:
+ -rm -f *.o *.d dep.txt $(PROGRAM) $(BENCHMARK)
diff --git a/modules/policy/lua-aho-corasick/tests/ac_bench.cxx b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx
new file mode 100644
index 0000000..421322c
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx
@@ -0,0 +1,519 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <sys/time.h>
+#include <time.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <dirent.h>
+#include <libgen.h>
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <getopt.h>
+
+#include <string>
+#include <vector>
+#include "ac.h"
+#include "ac_util.hpp"
+
+using namespace std;
+
+static bool SomethingWrong = false;
+
+static int iteration = 300;
+static string dict_dir;
+static string obj_file_dir;
+static bool print_help = false;
+static int piece_size = 1024;
+
+class PatternSet {
+public:
+ PatternSet(const char* filepath);
+ ~PatternSet() { Cleanup(); }
+
+ int getPatternNum() const { return _pat_num; }
+ const char** getPatternVector() const { return _patterns; }
+ unsigned int* getPatternLenVector() const { return _pat_len; }
+
+ const char* getErrMessage() const { return _errmsg; }
+ static bool isDictFile(const char* filepath) {
+ if (strncmp(basename(const_cast<char*>(filepath)), "dict", 4))
+ return false;
+ return true;
+ }
+
+private:
+ bool ExtractPattern(const char* filepath);
+ void Cleanup();
+
+ const char** _patterns;
+ unsigned int* _pat_len;
+ char* _mmap;
+ int _fd;
+ size_t _mmap_size;
+ int _pat_num;
+
+ const char* _errmsg;
+};
+
+bool
+PatternSet::ExtractPattern(const char* filepath) {
+ if (!isDictFile(filepath))
+ return false;
+
+ struct stat filestat;
+ if (stat(filepath, &filestat)) {
+ _errmsg = "fail to call stat()";
+ return false;
+ }
+
+ if (filestat.st_size > 4096 * 1024) {
+ /* It dosen't seem to be a dictionary file*/
+ _errmsg = "file too big?";
+ return false;
+ }
+
+ _fd = open(filepath, 0);
+ if (_fd == -1) {
+ _errmsg = "fail to open dictionary file";
+ return false;
+ }
+
+ _mmap_size = filestat.st_size;
+ _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE,
+ MAP_PRIVATE, _fd, 0);
+ if (_mmap == MAP_FAILED) {
+ _errmsg = "fail to call mmap";
+ return false;
+ }
+
+ const char* pat = _mmap;
+ vector<const char*> pat_vect;
+ vector<unsigned> pat_len_vect;
+
+ for (size_t i = 0, e = filestat.st_size; i < e; i++) {
+ if (_mmap[i] == '\r' || _mmap[i] == '\n') {
+ _mmap[i] = '\0';
+ int len = _mmap + i - pat;
+ if (len > 0) {
+ pat_vect.push_back(pat);
+ pat_len_vect.push_back(len);
+ }
+ pat = _mmap + i + 1;
+ }
+ }
+
+ ASSERT(pat_vect.size() == pat_len_vect.size());
+
+ int pat_num = pat_vect.size();
+ if (pat_num > 0) {
+ const char** p = _patterns = new const char*[pat_num];
+ int i = 0;
+ for (vector<const char*>::iterator iter = pat_vect.begin(),
+ iter_e = pat_vect.end(); iter != iter_e; ++iter) {
+ p[i++] = *iter;
+ }
+
+ i = 0;
+ unsigned int* q = _pat_len = new unsigned int[pat_num];
+ for (vector<unsigned>::iterator iter = pat_len_vect.begin(),
+ iter_e = pat_len_vect.end(); iter != iter_e; ++iter) {
+ q[i++] = *iter;
+ }
+ }
+
+ _pat_num = pat_num;
+ if (pat_num <= 0) {
+ _errmsg = "no pattern at all";
+ return false;
+ }
+
+ return true;
+}
+
+void
+PatternSet::Cleanup() {
+ if (_mmap != MAP_FAILED) {
+ munmap(_mmap, _mmap_size);
+ _mmap = (char*)MAP_FAILED;
+ _mmap_size = 0;
+ }
+
+ delete[] _patterns;
+ delete[] _pat_len;
+ if (_fd != -1)
+ close(_fd);
+ _pat_num = -1;
+}
+
+PatternSet::PatternSet(const char* filepath) {
+ _patterns = 0;
+ _pat_len = 0;
+ _mmap = (char*)MAP_FAILED;
+ _mmap_size = 0;
+ _pat_num = -1;
+ _errmsg = "";
+
+ if (!ExtractPattern(filepath))
+ Cleanup();
+}
+
+bool
+getFilesUnderDir(vector<string>& files, const char* path) {
+ files.clear();
+
+ DIR* dir = opendir(path);
+ if (!dir)
+ return false;
+
+ string path_dir = path;
+ path_dir += "/";
+
+ for (;;) {
+ struct dirent* entry = readdir(dir);
+ if (entry) {
+ string filepath = path_dir + entry->d_name;
+ struct stat file_stat;
+ if (stat(filepath.c_str(), &file_stat)) {
+ closedir(dir);
+ return false;
+ }
+
+ if (S_ISREG(file_stat.st_mode))
+ files.push_back(filepath);
+
+ continue;
+ }
+
+ if (errno) {
+ return false;
+ }
+ break;
+ }
+ closedir(dir);
+ return true;
+}
+
+class Timer {
+public:
+ Timer() {
+ my_clock_gettime(&_start);
+ _stop = _start;
+ _acc.tv_sec = 0;
+ _acc.tv_nsec = 0;
+ }
+
+ const Timer& operator += (const Timer& that) {
+ time_t sec = _acc.tv_sec + that._acc.tv_sec;
+ long nsec = _acc.tv_nsec + that._acc.tv_nsec;
+ if (nsec > 1000000000) {
+ nsec -= 1000000000;
+ sec += 1;
+ }
+ _acc.tv_sec = sec;
+ _acc.tv_nsec = nsec;
+ return *this;
+ }
+
+ // return duration in us
+ size_t getDuration() const {
+ return _acc.tv_sec * (size_t)1000000 + _acc.tv_nsec/1000;
+ }
+
+ void Start(bool acc=true) {
+ my_clock_gettime(&_start);
+ }
+
+ void Stop() {
+ my_clock_gettime(&_stop);
+ struct timespec t = CalcDuration();
+ _acc = add_duration(_acc, t);
+ }
+
+private:
+ int my_clock_gettime(struct timespec* t) {
+#ifdef __linux
+ return clock_gettime(CLOCK_PROCESS_CPUTIME_ID, t);
+#else
+ struct timeval tv;
+ int rc = gettimeofday(&tv, 0);
+ t->tv_sec = tv.tv_sec;
+ t->tv_nsec = tv.tv_usec * 1000;
+ return rc;
+#endif
+ }
+
+ struct timespec add_duration(const struct timespec& dur1,
+ const struct timespec& dur2) {
+ time_t sec = dur1.tv_sec + dur2.tv_sec;
+ long nsec = dur1.tv_nsec + dur2.tv_nsec;
+ if (nsec > 1000000000) {
+ nsec -= 1000000000;
+ sec += 1;
+ }
+ timespec t;
+ t.tv_sec = sec;
+ t.tv_nsec = nsec;
+
+ return t;
+ }
+
+ struct timespec CalcDuration() const {
+ timespec diff;
+ if ((_stop.tv_nsec - _start.tv_nsec)<0) {
+ diff.tv_sec = _stop.tv_sec - _start.tv_sec - 1;
+ diff.tv_nsec = 1000000000 + _stop.tv_nsec - _start.tv_nsec;
+ } else {
+ diff.tv_sec = _stop.tv_sec - _start.tv_sec;
+ diff.tv_nsec = _stop.tv_nsec - _start.tv_nsec;
+ }
+ return diff;
+ }
+
+ struct timespec _start;
+ struct timespec _stop;
+ struct timespec _acc;
+};
+
+class Benchmark {
+public:
+ Benchmark(const PatternSet& pat_set, const char* infile):
+ _pat_set(pat_set), _infile(infile) {
+ _mmap = (char*)MAP_FAILED;
+ _file_sz = 0;
+ _fd = -1;
+ }
+
+ ~Benchmark() {
+ if (_mmap != MAP_FAILED)
+ munmap(_mmap, _file_sz);
+ if (_fd != -1)
+ close(_fd);
+ }
+
+ bool Run(int iteration);
+ const Timer& getTimer() const { return _timer; }
+
+private:
+ const PatternSet& _pat_set;
+ const char* _infile;
+ char* _mmap;
+ int _fd;
+ size_t _file_sz; // input file size
+ Timer _timer;
+};
+
+bool
+Benchmark::Run(int iteration) {
+ if (_pat_set.getPatternNum() <= 0) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ if (_mmap == MAP_FAILED) {
+ struct stat filestat;
+ if (stat(_infile, &filestat)) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ if (!S_ISREG(filestat.st_mode)) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ _fd = open(_infile, 0);
+ if (_fd == -1)
+ return false;
+
+ _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE,
+ MAP_PRIVATE, _fd, 0);
+
+ if (_mmap == MAP_FAILED) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ _file_sz = filestat.st_size;
+ }
+
+ ac_t* ac = ac_create(_pat_set.getPatternVector(),
+ _pat_set.getPatternLenVector(),
+ _pat_set.getPatternNum());
+ if (!ac) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ int piece_num = _file_sz/piece_size;
+
+ _timer.Start(false);
+
+ /* Stupid compiler may not be able to promote piece_size into register.
+ * Do it manually.
+ */
+ int piece_sz = piece_size;
+ for (int i = 0; i < iteration; i++) {
+ size_t match_ofst = 0;
+ for (int piece_idx = 0; piece_idx < piece_num; piece_idx ++) {
+ ac_match2(ac, _mmap + match_ofst, piece_sz);
+ match_ofst += piece_sz;
+ }
+ if (match_ofst != _file_sz)
+ ac_match2(ac, _mmap + match_ofst, _file_sz - match_ofst);
+ }
+ _timer.Stop();
+ return true;
+}
+
+const char* short_opt = "hd:f:i:p:";
+const struct option long_opts[] = {
+ {"help", no_argument, 0, 'h'},
+ {"iteration", required_argument, 0, 'i'},
+ {"dictionary-dir", required_argument, 0, 'd'},
+ {"obj-file-dir", required_argument, 0, 'f'},
+ {"piece-size", required_argument, 0, 'p'},
+};
+
+static void
+PrintHelp(const char* prog_name) {
+ const char* msg =
+"Usage %s [OPTIONS]\n"
+" -d, --dictionary-dir : specify the dictionary directory (./dict by default)\n"
+" -f, --obj-file-dir : specify the object file directory\n"
+" (./testinput by default)\n"
+" -i, --iteration : Run this many iteration for each pattern match\n"
+" -p, --piece-size : The size of 'piece' in byte. The input file is\n"
+" divided into pieces, and match function is working\n"
+" on one piece at a time. The default size of piece\n"
+" is 1k byte.\n";
+
+ fprintf(stdout, msg, prog_name);
+}
+
+static bool
+getOptions(int argc, char** argv) {
+ bool dict_dir_set = false;
+ bool objfile_dir_set = false;
+ int opt_index;
+
+ while (1) {
+ if (print_help) break;
+
+ int c = getopt_long(argc, argv, short_opt, long_opts, &opt_index);
+
+ if (c == -1) break;
+ if (c == 0) { c = long_opts[opt_index].val; }
+
+ switch(c) {
+ case 'h':
+ print_help = true;
+ break;
+
+ case 'i':
+ iteration = atol(optarg);
+ break;
+
+ case 'd':
+ dict_dir = optarg;
+ dict_dir_set = true;
+ break;
+
+ case 'f':
+ obj_file_dir = optarg;
+ objfile_dir_set = true;
+ break;
+
+ case 'p':
+ piece_size = atol(optarg);
+ break;
+
+ case '?':
+ default:
+ return false;
+ }
+ }
+
+ if (print_help)
+ return true;
+
+ string basedir(dirname(argv[0]));
+ if (!dict_dir_set)
+ dict_dir = basedir + "/dict";
+
+ if (!objfile_dir_set)
+ obj_file_dir = basedir + "/testinput";
+
+ return true;
+}
+
+int
+main(int argc, char** argv) {
+ if (!getOptions(argc, argv))
+ return -1;
+
+ if (print_help) {
+ PrintHelp(argv[0]);
+ return 0;
+ }
+
+#ifndef __linux
+ fprintf(stdout, "\n!!!WARNING: On this OS, the execution time is measured"
+ " by gettimeofday(2) which is imprecise!!!\n\n");
+#endif
+
+ fprintf(stdout, "Test with iteration = %d, piece size = %d, and",
+ iteration, piece_size);
+ fprintf(stdout, "\n dictionary dir = %s\n object file dir = %s\n\n",
+ dict_dir.c_str(), obj_file_dir.c_str());
+
+ vector<string> dict_files;
+ vector<string> input_files;
+
+ if (!getFilesUnderDir(dict_files, dict_dir.c_str())) {
+ fprintf(stdout, "fail to find dictionary files\n");
+ return -1;
+ }
+
+ if (!getFilesUnderDir(input_files, obj_file_dir.c_str())) {
+ fprintf(stdout, "fail to find test input files\n");
+ return -1;
+ }
+
+ for (vector<string>::iterator diter = dict_files.begin(),
+ diter_e = dict_files.end(); diter != diter_e; ++diter) {
+
+ const char* dict_name = diter->c_str();
+ if (!PatternSet::isDictFile(dict_name))
+ continue;
+
+ PatternSet ps(dict_name);
+ if (ps.getPatternNum() <= 0) {
+ fprintf(stdout, "fail to open dictionary file %s : %s\n",
+ dict_name, ps.getErrMessage());
+ SomethingWrong = true;
+ continue;
+ }
+
+ fprintf(stdout, "Using dictionary %s\n", dict_name);
+ Timer timer;
+ for (vector<string>::iterator iter = input_files.begin(),
+ iter_e = input_files.end(); iter != iter_e; ++iter) {
+ fprintf(stdout, " testing %s ... ", iter->c_str());
+ fflush(stdout);
+ Benchmark bm(ps, iter->c_str());
+ bm.Run(iteration);
+ const Timer& t = bm.getTimer();
+ timer += bm.getTimer();
+ fprintf(stdout, "elapsed %.3f\n", t.getDuration() / 1000000.0);
+ }
+
+ fprintf(stdout,
+ "\n==========================================================\n"
+ " Total Elapse %.3f\n\n", timer.getDuration() / 1000000.0);
+ }
+
+ return SomethingWrong ? -1 : 0;
+}
diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx
new file mode 100644
index 0000000..4ea02bc
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx
@@ -0,0 +1,135 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+using namespace std;
+
+namespace {
+class ACBigFileTester : public BigFileTester {
+public:
+ ACBigFileTester(const char* filepath) : BigFileTester(filepath){};
+
+private:
+ virtual buf_header_t* PM_Create(const char** strv, uint32* strlenv,
+ uint32 vect_len) {
+ return (buf_header_t*)ac_create(strv, strlenv, vect_len);
+ }
+
+ virtual void PM_Free(buf_header_t* PM) { ac_free(PM); }
+ virtual bool Run_Helper(buf_header_t* PM);
+};
+
+class ACTestAggressive: public ACTestBase {
+public:
+ ACTestAggressive(const vector<const char*>& files, const char* banner)
+ : ACTestBase(banner), _files(files) {}
+ virtual bool Run();
+
+private:
+ void PrintSummary(int total, int fail) {
+ fprintf(stdout, "Test count : %d, fail: %d\n", total, fail);
+ fflush(stdout);
+ }
+ vector<const char*> _files;
+};
+
+} // end of anonymous namespace
+
+bool
+ACBigFileTester::Run_Helper(buf_header_t* PM) {
+ int fail = 0;
+ // advance one chunk at a time.
+ int len = _msg_len;
+ int chunk_sz = _chunk_sz;
+
+ vector<const char*> c_style_keys;
+ for (int i = 0, e = _keys.size(); i != e; i++) {
+ const char* key = _keys[i].first;
+ int len = _keys[i].second;
+ char *t = new char[len+1];
+ memcpy(t, key, len);
+ t[len] = '\0';
+ c_style_keys.push_back(t);
+ }
+
+ for (int ofst = 0, chunk_idx = 0, chunk_num = _chunk_num;
+ chunk_idx < chunk_num; ofst += chunk_sz, chunk_idx++) {
+ const char* substring = _msg + ofst;
+ ac_result_t r = ac_match((ac_t*)(void*)PM, substring , len - ofst);
+ int m_b = r.match_begin;
+ int m_e = r.match_end;
+
+ if (m_b < 0 || m_e < 0 || m_e <= m_b || m_e >= len) {
+ fprintf(stdout, "fail to find match substring[%d:%d])\n",
+ ofst, len - 1);
+ fail ++;
+ continue;
+ }
+
+ const char* match_str = _msg + len;
+ int strstr_len = 0;
+ int key_idx = -1;
+
+ for (int i = 0, e = c_style_keys.size(); i != e; i++) {
+ const char* key = c_style_keys[i];
+ if (const char *m = strstr(substring, key)) {
+ if (m < match_str) {
+ match_str = m;
+ strstr_len = _keys[i].second;
+ key_idx = i;
+ }
+ }
+ }
+ ASSERT(key_idx != -1);
+ if ((match_str - substring != m_b)) {
+ fprintf(stdout,
+ "Fail to find match substring[%d:%d]),"
+ " expected to find match at offset %d instead of %d\n",
+ ofst, len - 1,
+ (int)(match_str - _msg), ofst + m_b);
+ fprintf(stdout, "%d vs %d (key idx %d)\n", strstr_len, m_e - m_b + 1, key_idx);
+ PrintStr(stdout, match_str, strstr_len);
+ fprintf(stdout, "\n");
+ PrintStr(stdout, _msg + ofst + m_b,
+ m_e - m_b + 1);
+ fprintf(stdout, "\n");
+ fail ++;
+ }
+ }
+ for (vector<const char*>::iterator i = c_style_keys.begin(),
+ e = c_style_keys.end(); i != e; i++) {
+ delete[] *i;
+ }
+
+ return fail == 0;
+}
+
+bool
+ACTestAggressive::Run() {
+ int fail = 0;
+ for (vector<const char*>::iterator i = _files.begin(), e = _files.end();
+ i != e; i++) {
+ ACBigFileTester bft(*i);
+ if (!bft.Run())
+ fail ++;
+ }
+ return fail == 0;
+}
+
+bool
+Run_AC_Aggressive_Test(const vector<const char*>& files) {
+ ACTestAggressive t(files, "AC Aggressive test");
+ t.PrintBanner();
+ return t.Run();
+}
diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx
new file mode 100644
index 0000000..fa2d7fd
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx
@@ -0,0 +1,275 @@
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+using namespace std;
+
+namespace {
+typedef struct {
+ const char* str;
+ const char* match;
+} StrPair;
+
+typedef enum {
+ MV_FIRST_MATCH = 0,
+ MV_LEFT_LONGEST = 1,
+} MatchVariant;
+
+typedef struct {
+ const char* name;
+ const char** dict;
+ StrPair* strpairs;
+ int dict_len;
+ int strpair_num;
+ MatchVariant match_variant;
+} TestingCase;
+
+class Tests {
+public:
+ Tests(const char* name,
+ const char* dict[], int dict_len,
+ StrPair strpairs[], int strpair_num,
+ MatchVariant mv = MV_FIRST_MATCH) {
+ if (!_tests)
+ _tests = new vector<TestingCase>;
+
+ TestingCase tc;
+ tc.name = name;
+ tc.dict = dict;
+ tc.strpairs = strpairs;
+ tc.dict_len = dict_len;
+ tc.strpair_num = strpair_num;
+ tc.match_variant = mv;
+ _tests->push_back(tc);
+ }
+
+ static vector<TestingCase>* Get_Tests() { return _tests; }
+ static void Erase_Tests() { delete _tests; _tests = 0; }
+
+private:
+ static vector<TestingCase> *_tests;
+};
+
+class LeftLongestTests : public Tests {
+public:
+ LeftLongestTests (const char* name, const char* dict[], int dict_len,
+ StrPair strpairs[], int strpair_num):
+ Tests(name, dict, dict_len, strpairs, strpair_num, MV_LEFT_LONGEST) {
+ }
+};
+
+vector<TestingCase>* Tests::_tests = 0;
+
+class ACTestSimple: public ACTestBase {
+public:
+ ACTestSimple(const char* banner) : ACTestBase(banner) {}
+ virtual bool Run();
+
+private:
+ void PrintSummary(int total, int fail) {
+ fprintf(stdout, "Test count : %d, fail: %d\n", total, fail);
+ fflush(stdout);
+ }
+};
+}
+
+bool
+ACTestSimple::Run() {
+ int total = 0;
+ int fail = 0;
+
+ vector<TestingCase> *tests = Tests::Get_Tests();
+ if (!tests) {
+ PrintSummary(0, 0);
+ return true;
+ }
+
+ for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end();
+ i != e; i++) {
+ TestingCase& t = *i;
+ int dict_len = t.dict_len;
+ unsigned int* strlen_v = new unsigned int[dict_len];
+
+ fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name);
+ for (int i = 0, need_break=0; i < dict_len; i++) {
+ const char* s = t.dict[i];
+ fprintf(stdout, "%s, ", s);
+ strlen_v[i] = strlen(s);
+ if (need_break++ == 16) {
+ fputs("\n ", stdout);
+ need_break = 0;
+ }
+ }
+ fputs("]\n", stdout);
+
+ /* Create the dictionary */
+ ac_t* ac = ac_create(t.dict, strlen_v, dict_len);
+ delete[] strlen_v;
+
+ for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) {
+ const StrPair& sp = t.strpairs[ii];
+ const char *str = sp.str; // the string to be matched
+ const char *match = sp.match;
+
+ fprintf(stdout, "[%3d] Testing '%s' : ", total, str);
+
+ int len = strlen(str);
+ ac_result_t r;
+ if (t.match_variant == MV_FIRST_MATCH)
+ r = ac_match(ac, str, len);
+ else if (t.match_variant == MV_LEFT_LONGEST)
+ r = ac_match_longest_l(ac, str, len);
+ else {
+ ASSERT(false && "Unknown variant");
+ }
+
+ int m_b = r.match_begin;
+ int m_e = r.match_end;
+
+ // The return value per se is insane.
+ if (m_b > m_e ||
+ ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) {
+ fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e);
+ fail ++;
+ continue;
+ }
+
+ // If the string is not supposed to match the dictionary.
+ if (!match) {
+ if (m_b != -1 || m_e != -1) {
+ fail ++;
+ fprintf(stdout, "Not Supposed to match (%d, %d) \n",
+ m_b, m_e);
+ } else
+ fputs("Pass\n", stdout);
+ continue;
+ }
+
+ // The string or its substring is match the dict.
+ if (m_b >= len || m_b >= len) {
+ fail ++;
+ fprintf(stdout,
+ "Return value >= the length of the string (%d, %d)\n",
+ m_b, m_e);
+ continue;
+ } else {
+ int mlen = strlen(match);
+ if ((mlen != m_e - m_b + 1) ||
+ strncmp(str + m_b, match, mlen)) {
+ fail ++;
+ fprintf(stdout, "Fail\n");
+ } else
+ fprintf(stdout, "Pass\n");
+ }
+ }
+ fputs("\n", stdout);
+ ac_free(ac);
+ }
+
+ PrintSummary(total, fail);
+ return fail == 0;
+}
+
+bool
+Run_AC_Simple_Test() {
+ ACTestSimple t("AC Simple test");
+ t.PrintBanner();
+ return t.Run();
+}
+
+//////////////////////////////////////////////////////////////////////////////
+//
+// Testing cases for first-match variant (i.e. test ac_match())
+//
+//////////////////////////////////////////////////////////////////////////////
+//
+
+/* test 1*/
+const char *dict1[] = {"he", "she", "his", "her"};
+StrPair strpair1[] = {
+ {"he", "he"}, {"she", "she"}, {"his", "his"},
+ {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"},
+ {"shis2", "his"}, {"ahhe", "he"}
+};
+Tests test1("test 1",
+ dict1, sizeof(dict1)/sizeof(dict1[0]),
+ strpair1, sizeof(strpair1)/sizeof(strpair1[0]));
+
+/* test 2*/
+const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/
+StrPair strpair2[] = {{"The pot had a handle", 0}};
+Tests test2("test 2", dict2, 2, strpair2, 1);
+
+/* test 3*/
+const char *dict3[] = {"The"};
+StrPair strpair3[] = {{"The pot had a handle", "The"}};
+Tests test3("test 3", dict3, 1, strpair3, 1);
+
+/* test 4*/
+const char *dict4[] = {"pot"};
+StrPair strpair4[] = {{"The pot had a handle", "pot"}};
+Tests test4("test 4", dict4, 1, strpair4, 1);
+
+/* test 5*/
+const char *dict5[] = {"pot "};
+StrPair strpair5[] = {{"The pot had a handle", "pot "}};
+Tests test5("test 5", dict5, 1, strpair5, 1);
+
+/* test 6*/
+const char *dict6[] = {"ot h"};
+StrPair strpair6[] = {{"The pot had a handle", "ot h"}};
+Tests test6("test 6", dict6, 1, strpair6, 1);
+
+/* test 7*/
+const char *dict7[] = {"andle"};
+StrPair strpair7[] = {{"The pot had a handle", "andle"}};
+Tests test7("test 7", dict7, 1, strpair7, 1);
+
+const char *dict8[] = {"aaab"};
+StrPair strpair8[] = {{"aaaaaaab", "aaab"}};
+Tests test8("test 8", dict8, 1, strpair8, 1);
+
+const char *dict9[] = {"haha", "z"};
+StrPair strpair9[] = {{"aaaaz", "z"}, {"z", "z"}};
+Tests test9("test 9", dict9, 2, strpair9, 2);
+
+/* test the case when input string dosen't contain even a single char
+ * of the pattern in dictionary.
+ */
+const char *dict10[] = {"abc"};
+StrPair strpair10[] = {{"cde", 0}};
+Tests test10("test 10", dict10, 1, strpair10, 1);
+
+
+//////////////////////////////////////////////////////////////////////////////
+//
+// Testing cases for first longest match variant (i.e.
+// test ac_match_longest_l())
+//
+//////////////////////////////////////////////////////////////////////////////
+//
+
+// This was actually first motivation for left-longest-match
+const char *dict100[] = {"Mozilla", "Mozilla Mobile"};
+StrPair strpair100[] = {{"User Agent containing string Mozilla Mobile", "Mozilla Mobile"}};
+LeftLongestTests test100("l_test 100", dict100, 2, strpair100, 1);
+
+// Dict with single char is tricky
+const char *dict101[] = {"a", "abc"};
+StrPair strpair101[] = {{"abcdef", "abc"}};
+LeftLongestTests test101("l_test 101", dict101, 2, strpair101, 1);
+
+// Testing case with partially overlapping patterns. The purpose is to
+// check if the fail-link leading from terminal state is correct.
+//
+// The fail-link leading from terminal-state does not matter in
+// match-first-occurrence variant, as it stop when a terminal is hit.
+//
+const char *dict102[] = {"abc", "bcdef"};
+StrPair strpair102[] = {{"abcdef", "bcdef"}};
+LeftLongestTests test102("l_test 102", dict102, 2, strpair102, 1);
diff --git a/modules/policy/lua-aho-corasick/tests/dict/README.txt b/modules/policy/lua-aho-corasick/tests/dict/README.txt
new file mode 100644
index 0000000..cd50b41
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/dict/README.txt
@@ -0,0 +1 @@
+This directory contains pattern set of benchmark purpose.
diff --git a/modules/policy/lua-aho-corasick/tests/dict/dict1.txt b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt
new file mode 100644
index 0000000..94085a9
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt
@@ -0,0 +1,11 @@
+false_return@
+forloop#haha
+wtfprogram
+mmaporunmap
+ThIs?Module!IsEssential
+struct rtlwtf
+gettIMEOfdayWrong
+edistribution_and_use_in_@source
+Copyright~#@
+while {!
+!%SQLinje
diff --git a/modules/policy/lua-aho-corasick/tests/load_ac_test.lua b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua
new file mode 100644
index 0000000..7fb7db9
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua
@@ -0,0 +1,82 @@
+-- This script is to test load_ac.lua
+--
+-- Some notes:
+-- 1. The purpose of this script is not to check if the libac.so work
+-- properly, it is to check if there are something stupid in load_ac.lua
+--
+-- 2. There are bunch of collectgarbage() calls, the purpose is to make
+-- sure the shared lib is not unloaded after GC.
+
+-- load_ac.lua looks up libac.so via package.cpath rather than LD_LIBRARY_PATH,
+-- prepend (instead of appending) some insane paths here to see if it quit
+-- prematurely.
+--
+package.cpath = ".;./?.so;" .. package.cpath
+
+local ac = require "load_ac"
+
+local ac_create = ac.create_ac
+local ac_match = ac.match
+local string_fmt = string.format
+local string_sub = string.sub
+
+local err_cnt = 0
+local function mytest(testname, dict, match, notmatch)
+ print(">Testing ", testname)
+
+ io.write(string_fmt("Dictionary: "));
+ for i=1, #dict do
+ io.write(string_fmt("%s, ", dict[i]))
+ end
+ print ""
+
+ local ac_inst = ac_create(dict);
+ collectgarbage()
+ for i=1, #match do
+ local str = match[i]
+ io.write(string_fmt("Matching %s, ", str))
+ local b = ac_match(ac_inst, str)
+ if b then
+ print "pass"
+ else
+ err_cnt = err_cnt + 1
+ print "fail"
+ end
+ collectgarbage()
+ end
+
+ if notmatch == nil then
+ return
+ end
+
+ collectgarbage()
+
+ for i = 1, #notmatch do
+ local str = notmatch[i]
+ io.write(string_fmt("*Matching %s, ", str))
+ local r = ac_match(ac_inst, str)
+ if r then
+ err_cnt = err_cnt + 1
+ print("fail")
+ else
+ print("succ")
+ end
+ collectgarbage()
+ end
+ ac_inst = nil
+ collectgarbage()
+end
+
+print("")
+print("====== Test to see if load_ac.lua works properly ========")
+
+mytest("test1",
+ {"he", "she", "his", "her", "str\0ing"},
+ -- matching cases
+ { "he", "she", "his", "hers", "ahe", "shhe", "shis2", "ahhe", "str\0ing" },
+
+ -- not matching case
+ {"str\0", "str"}
+ )
+
+os.exit((err_cnt == 0) and 0 or 1)
diff --git a/modules/policy/lua-aho-corasick/tests/lua_test.lua b/modules/policy/lua-aho-corasick/tests/lua_test.lua
new file mode 100644
index 0000000..cfe178f
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/lua_test.lua
@@ -0,0 +1,67 @@
+-- This script is to test ahocorasick.so not libac.so
+--
+local ac = require "ahocorasick"
+
+local ac_create = ac.create
+local ac_match = ac.match
+local string_fmt = string.format
+local string_sub = string.sub
+
+local err_cnt = 0
+local function mytest(testname, dict, match, notmatch)
+ print(">Testing ", testname)
+
+ io.write(string_fmt("Dictionary: "));
+ for i=1, #dict do
+ io.write(string_fmt("%s, ", dict[i]))
+ end
+ print ""
+
+ local ac_inst = ac_create(dict);
+ for i=1, #match do
+ local str = match[i][1]
+ local substr = match[i][2]
+ io.write(string_fmt("Matching %s, ", str))
+ local b, e = ac_match(ac_inst, str)
+ if b and e and (string_sub(str, b+1, e+1) == substr) then
+ print "pass"
+ else
+ err_cnt = err_cnt + 1
+ print "fail"
+ end
+ --print("gc is called")
+ collectgarbage()
+ end
+
+ if notmatch == nil then
+ return
+ end
+
+ for i = 1, #notmatch do
+ local str = notmatch[i]
+ io.write(string_fmt("*Matching %s, ", str))
+ local r = ac_match(ac_inst, str)
+ if r then
+ err_cnt = err_cnt + 1
+ print("fail")
+ else
+ print("succ")
+ end
+ collectgarbage()
+ end
+end
+
+mytest("test1",
+ {"he", "she", "his", "her", "str\0ing"},
+ -- matching cases
+ { {"he", "he"}, {"she", "she"}, {"his", "his"}, {"hers", "he"},
+ {"ahe", "he"}, {"shhe", "he"}, {"shis2", "his"}, {"ahhe", "he"},
+ {"str\0ing", "str\0ing"}
+ },
+
+ -- not matching case
+ {"str\0", "str"}
+
+ )
+
+os.exit((err_cnt == 0) and 0 or 1)
diff --git a/modules/policy/lua-aho-corasick/tests/test_base.hpp b/modules/policy/lua-aho-corasick/tests/test_base.hpp
new file mode 100644
index 0000000..7758371
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/test_base.hpp
@@ -0,0 +1,60 @@
+#ifndef TEST_BASE_H
+#define TEST_BASE_H
+
+#include <stdio.h>
+#include <string>
+#include <stdint.h>
+
+using namespace std;
+class ACTestBase {
+public:
+ ACTestBase(const char* name) :_banner(name) {}
+ virtual void PrintBanner() {
+ fprintf(stdout, "\n===== %s ====\n", _banner.c_str());
+ }
+
+ virtual bool Run() = 0;
+private:
+ string _banner;
+};
+
+typedef std::pair<const char*, int> StrInfo;
+class BigFileTester {
+public:
+ BigFileTester(const char* filepath);
+ virtual ~BigFileTester() { Cleanup(); }
+
+ bool Run();
+
+protected:
+ virtual buf_header_t* PM_Create(const char** strv, uint32_t* strlenv,
+ uint32_t vect_len) = 0;
+ virtual void PM_Free(buf_header_t*) = 0;
+ virtual bool Run_Helper(buf_header_t* PM) = 0;
+
+ // Return true if the '\0' is valid char of a string.
+ virtual bool Str_C_Style() { return true; }
+
+ bool GenerateKeys();
+ void Cleanup();
+ void PrintStr(FILE*, const char* str, int len);
+
+protected:
+ const char* _filepath;
+ int _fd;
+ vector<StrInfo> _keys;
+ char* _msg;
+ int _msg_len;
+ int _key_num; // number of strings in dictionary
+ int _chunk_sz;
+ int _chunk_num;
+
+ int _max_key_num;
+ int _key_min_len;
+ int _key_max_len;
+};
+
+extern bool Run_AC_Simple_Test();
+extern bool Run_AC_Aggressive_Test(const vector<const char*>& files);
+
+#endif
diff --git a/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx
new file mode 100644
index 0000000..f189d8d
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx
@@ -0,0 +1,167 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+///////////////////////////////////////////////////////////////////////////
+//
+// Implementation of BigFileTester
+//
+///////////////////////////////////////////////////////////////////////////
+//
+BigFileTester::BigFileTester(const char* filepath) {
+ _filepath = filepath;
+ _fd = -1;
+ _msg = (char*)MAP_FAILED;
+ _msg_len = 0;
+ _key_num = 0;
+ _chunk_sz = 0;
+ _chunk_num = 0;
+
+ _max_key_num = 100;
+ _key_min_len = 20;
+ _key_max_len = 80;
+}
+
+void
+BigFileTester::Cleanup() {
+ if (_msg != MAP_FAILED) {
+ munmap((void*)_msg, _msg_len);
+ _msg = (char*)MAP_FAILED;
+ _msg_len = 0;
+ }
+
+ if (_fd != -1) {
+ close(_fd);
+ _fd = -1;
+ }
+}
+
+bool
+BigFileTester::GenerateKeys() {
+ int chunk_sz = 4096;
+ int max_key_num = _max_key_num;
+ int key_min_len = _key_min_len;
+ int key_max_len = _key_max_len;
+
+ int t = _msg_len / chunk_sz;
+ int keynum = t > max_key_num ? max_key_num : t;
+
+ if (keynum <= 4) {
+ // file is too small
+ return false;
+ }
+ chunk_sz = _msg_len / keynum;
+ _chunk_sz = chunk_sz;
+
+ // For each chunck, "randomly" grab a sub-string searving
+ // as key.
+ int random_ofst[] = { 12, 30, 23, 15 };
+ int rofstsz = sizeof(random_ofst)/sizeof(random_ofst[0]);
+ int ofst = 0;
+ const char* msg = _msg;
+ _chunk_num = keynum - 1;
+ for (int idx = 0, e = _chunk_num; idx < e; idx++) {
+ const char* key = msg + ofst + idx % rofstsz;
+ int key_len = key_min_len + idx % (key_max_len - key_min_len);
+ _keys.push_back(StrInfo(key, key_len));
+ ofst += chunk_sz;
+ }
+ return true;
+}
+
+bool
+BigFileTester::Run() {
+ // Step 1: Bring the file into memory
+ fprintf(stdout, "Testing using file '%s'...\n", _filepath);
+
+ int fd = _fd = ::open(_filepath, O_RDONLY);
+ if (fd == -1) {
+ perror("open");
+ return false;
+ }
+
+ struct stat sb;
+ if (fstat(fd, &sb) == -1) {
+ perror("fstat");
+ return false;
+ }
+
+ if (!S_ISREG (sb.st_mode)) {
+ fprintf(stderr, "%s is not regular file\n", _filepath);
+ return false;
+ }
+
+ int ten_M = 1024 * 1024 * 10;
+ int map_sz = _msg_len = sb.st_size > ten_M ? ten_M : sb.st_size;
+ char* p = _msg =
+ (char*)mmap (0, map_sz, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0);
+ if (p == MAP_FAILED) {
+ perror("mmap");
+ return false;
+ }
+
+ // Get rid of '\0' if we are picky at it.
+ if (Str_C_Style()) {
+ for (int i = 0; i < map_sz; i++) { if (!p[i]) p[i] = 'a'; }
+ p[map_sz - 1] = 0;
+ }
+
+ // Step 2: "Fabricate" some keys from the file.
+ if (!GenerateKeys()) {
+ close(fd);
+ return false;
+ }
+
+ // Step 3: Create PM instance
+ const char** keys = new const char*[_keys.size()];
+ unsigned int* keylens = new unsigned int[_keys.size()];
+
+ int i = 0;
+ for (vector<StrInfo>::iterator si = _keys.begin(), se = _keys.end();
+ si != se; si++, i++) {
+ const StrInfo& strinfo = *si;
+ keys[i] = strinfo.first;
+ keylens[i] = strinfo.second;
+ }
+
+ buf_header_t* PM = PM_Create(keys, keylens, i);
+ delete[] keys;
+ delete[] keylens;
+
+ // Step 4: Run testing
+ bool res = Run_Helper(PM);
+ PM_Free(PM);
+
+ // Step 5: Clanup
+ munmap(p, map_sz);
+ _msg = (char*)MAP_FAILED;
+ close(fd);
+ _fd = -1;
+
+ fprintf(stdout, "%s\n", res ? "succ" : "fail");
+ return res;
+}
+
+void
+BigFileTester::PrintStr(FILE* f, const char* str, int len) {
+ fprintf(f, "{");
+ for (int i = 0; i < len; i++) {
+ unsigned char c = str[i];
+ if (isprint(c))
+ fprintf(f, "'%c', ", c);
+ else
+ fprintf(f, "%#x, ", c);
+ }
+ fprintf(f, "}");
+};
diff --git a/modules/policy/lua-aho-corasick/tests/test_main.cxx b/modules/policy/lua-aho-corasick/tests/test_main.cxx
new file mode 100644
index 0000000..b4f5225
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/test_main.cxx
@@ -0,0 +1,33 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+using namespace std;
+
+
+/////////////////////////////////////////////////////////////////////////
+//
+// Simple (yet maybe tricky) testings
+//
+/////////////////////////////////////////////////////////////////////////
+//
+int
+main (int argc, char** argv) {
+ bool succ = Run_AC_Simple_Test();
+
+ vector<const char*> files;
+ for (int i = 1; i < argc; i++) { files.push_back(argv[i]); }
+ succ = Run_AC_Aggressive_Test(files) && succ;
+
+ return succ ? 0 : -1;
+};
diff --git a/modules/policy/meson.build b/modules/policy/meson.build
new file mode 100644
index 0000000..37f1683
--- /dev/null
+++ b/modules/policy/meson.build
@@ -0,0 +1,50 @@
+# LUA module: policy
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+lua_mod_src += [
+ files('policy.lua'),
+]
+
+config_tests += [
+ ['policy', files('policy.test.lua')],
+ ['policy.slice', files('policy.slice.test.lua')],
+ ['policy.rpz', files('policy.rpz.test.lua')],
+]
+
+integr_tests += [
+ ['policy', meson.current_source_dir() / 'test.integr'],
+ ['policy.noipv6', meson.current_source_dir() / 'noipv6.test.integr'],
+ ['policy.noipvx', meson.current_source_dir() / 'noipvx.test.integr'],
+]
+
+# check git submodules were initialized
+lua_ac_submodule = run_command(['test', '-r',
+ '@0@/lua-aho-corasick/ac_fast.cxx'.format(meson.current_source_dir())],
+ check: false)
+if lua_ac_submodule.returncode() != 0
+ error('run "git submodule update --init --recursive" to initialize git submodules')
+endif
+
+# compile bundled lua-aho-corasick as shared module
+lua_ac_src = files([
+ 'lua-aho-corasick/ac_fast.cxx',
+ 'lua-aho-corasick/ac_lua.cxx',
+ 'lua-aho-corasick/ac_slow.cxx',
+])
+
+lua_ac_lib = shared_module(
+ 'ahocorasick',
+ lua_ac_src,
+ cpp_args: [
+ '-fvisibility=hidden',
+ '-Wall',
+ '-fPIC',
+ ],
+ dependencies: [
+ luajit_inc,
+ ],
+ include_directories: mod_inc_dir,
+ name_prefix: '',
+ install: true,
+ install_dir: lib_dir,
+)
diff --git a/modules/policy/noipv6.test.integr/broken-ipv6.rpl b/modules/policy/noipv6.test.integr/broken-ipv6.rpl
new file mode 100644
index 0000000..cb0738d
--- /dev/null
+++ b/modules/policy/noipv6.test.integr/broken-ipv6.rpl
@@ -0,0 +1,47 @@
+; config options
+; SPDX-License-Identifier: GPL-3.0-or-later
+ stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET.
+CONFIG_END
+
+SCENARIO_BEGIN Test that IPv6 is not used by kresd.
+
+RANGE_BEGIN 0 100
+ ADDRESS ::1:2:3:4
+RANGE_END
+
+RANGE_BEGIN 0 100
+ ADDRESS 1.2.3.4
+
+ENTRY_BEGIN
+MATCH opcode qtype qname
+ADJUST copy_id
+REPLY QR NOERROR
+SECTION QUESTION
+www.test.org A
+SECTION ANSWER
+www.test.org 3600 A 4.3.2.1
+ENTRY_END
+
+RANGE_END
+
+
+STEP 10 QUERY
+ENTRY_BEGIN
+REPLY RD AD
+SECTION QUESTION
+www.test.org A
+ENTRY_END
+
+STEP 20 CHECK_ANSWER
+ENTRY_BEGIN
+MATCH all answer
+REPLY QR RD RA NOERROR
+SECTION QUESTION
+www.test.org A
+SECTION ANSWER
+www.test.org 3600 A 4.3.2.1
+SECTION AUTHORITY
+SECTION ADDITIONAL
+ENTRY_END
+
+SCENARIO_END
diff --git a/modules/policy/noipv6.test.integr/deckard.yaml b/modules/policy/noipv6.test.integr/deckard.yaml
new file mode 100644
index 0000000..4c1b6f8
--- /dev/null
+++ b/modules/policy/noipv6.test.integr/deckard.yaml
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+programs:
+- name: kresd
+ binary: kresd
+ additional:
+ - --noninteractive
+ templates:
+ - modules/policy/noipv6.test.integr/kresd_config.j2
+ - tests/integration/hints_zone.j2
+ configs:
+ - config
+ - hints
diff --git a/modules/policy/noipv6.test.integr/kresd_config.j2 b/modules/policy/noipv6.test.integr/kresd_config.j2
new file mode 100644
index 0000000..a897d17
--- /dev/null
+++ b/modules/policy/noipv6.test.integr/kresd_config.j2
@@ -0,0 +1,59 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+{% raw %}
+net.ipv6 = false
+policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' })))
+
+-- make sure DNSSEC is turned off for tests
+trust_anchors.remove('.')
+
+-- Disable RFC5011 TA update
+if ta_update then
+ modules.unload('ta_update')
+end
+
+-- Disable RFC8145 signaling, scenario doesn't provide expected answers
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+
+-- Disable RFC8109 priming, scenario doesn't provide expected answers
+if priming then
+ modules.unload('priming')
+end
+
+-- Disable this module because it make one priming query
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+_hint_root_file('hints')
+cache.size = 2*MB
+log_level('debug')
+{% endraw %}
+
+net = { '{{SELF_ADDR}}' }
+
+
+{% if QMIN == "false" %}
+option('NO_MINIMIZE', true)
+{% else %}
+option('NO_MINIMIZE', false)
+{% endif %}
+
+
+-- Self-checks on globals
+assert(help() ~= nil)
+assert(worker.id ~= nil)
+-- Self-checks on facilities
+assert(cache.count() == 0)
+assert(cache.stats() ~= nil)
+assert(cache.backends() ~= nil)
+assert(worker.stats() ~= nil)
+assert(net.interfaces() ~= nil)
+-- Self-checks on loaded stuff
+assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
+assert(#modules.list() > 0)
+-- Self-check timers
+ev = event.recurrent(1 * sec, function (ev) return 1 end)
+event.cancel(ev)
+ev = event.after(0, function (ev) return 1 end)
diff --git a/modules/policy/noipvx.test.integr/broken-ipvx.rpl b/modules/policy/noipvx.test.integr/broken-ipvx.rpl
new file mode 100644
index 0000000..60ed618
--- /dev/null
+++ b/modules/policy/noipvx.test.integr/broken-ipvx.rpl
@@ -0,0 +1,35 @@
+; config options
+; SPDX-License-Identifier: GPL-3.0-or-later
+ stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET.
+CONFIG_END
+
+SCENARIO_BEGIN Test that neither IPv6 nor IPv4 is used by kresd :-)
+
+RANGE_BEGIN 0 100
+ ADDRESS ::1:2:3:4
+RANGE_END
+
+RANGE_BEGIN 0 100
+ ADDRESS 1.2.3.4
+RANGE_END
+
+
+STEP 10 QUERY
+ENTRY_BEGIN
+REPLY RD AD
+SECTION QUESTION
+www.test.org A
+ENTRY_END
+
+STEP 20 CHECK_ANSWER
+ENTRY_BEGIN
+MATCH all answer
+REPLY QR RD RA SERVFAIL
+SECTION QUESTION
+www.test.org A
+SECTION ANSWER
+SECTION AUTHORITY
+SECTION ADDITIONAL
+ENTRY_END
+
+SCENARIO_END
diff --git a/modules/policy/noipvx.test.integr/deckard.yaml b/modules/policy/noipvx.test.integr/deckard.yaml
new file mode 100644
index 0000000..8178759
--- /dev/null
+++ b/modules/policy/noipvx.test.integr/deckard.yaml
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+programs:
+- name: kresd
+ binary: kresd
+ additional:
+ - --noninteractive
+ templates:
+ - modules/policy/noipvx.test.integr/kresd_config.j2
+ - tests/integration/hints_zone.j2
+ configs:
+ - config
+ - hints
diff --git a/modules/policy/noipvx.test.integr/kresd_config.j2 b/modules/policy/noipvx.test.integr/kresd_config.j2
new file mode 100644
index 0000000..87873e8
--- /dev/null
+++ b/modules/policy/noipvx.test.integr/kresd_config.j2
@@ -0,0 +1,60 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+{% raw %}
+net.ipv4 = false
+net.ipv6 = false
+policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' })))
+
+-- make sure DNSSEC is turned off for tests
+trust_anchors.remove('.')
+
+-- Disable RFC5011 TA update
+if ta_update then
+ modules.unload('ta_update')
+end
+
+-- Disable RFC8145 signaling, scenario doesn't provide expected answers
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+
+-- Disable RFC8109 priming, scenario doesn't provide expected answers
+if priming then
+ modules.unload('priming')
+end
+
+-- Disable this module because it make one priming query
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+_hint_root_file('hints')
+cache.size = 2*MB
+log_level('debug')
+{% endraw %}
+
+net = { '{{SELF_ADDR}}' }
+
+
+{% if QMIN == "false" %}
+option('NO_MINIMIZE', true)
+{% else %}
+option('NO_MINIMIZE', false)
+{% endif %}
+
+
+-- Self-checks on globals
+assert(help() ~= nil)
+assert(worker.id ~= nil)
+-- Self-checks on facilities
+assert(cache.count() == 0)
+assert(cache.stats() ~= nil)
+assert(cache.backends() ~= nil)
+assert(worker.stats() ~= nil)
+assert(net.interfaces() ~= nil)
+-- Self-checks on loaded stuff
+assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
+assert(#modules.list() > 0)
+-- Self-check timers
+ev = event.recurrent(1 * sec, function (ev) return 1 end)
+event.cancel(ev)
+ev = event.after(0, function (ev) return 1 end)
diff --git a/modules/policy/policy.lua b/modules/policy/policy.lua
new file mode 100644
index 0000000..47e436f
--- /dev/null
+++ b/modules/policy/policy.lua
@@ -0,0 +1,1109 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+local kres = require('kres')
+local ffi = require('ffi')
+
+local LOG_GRP_POLICY_TAG = ffi.string(ffi.C.kr_log_grp2name(ffi.C.LOG_GRP_POLICY))
+local LOG_GRP_REQDBG_TAG = ffi.string(ffi.C.kr_log_grp2name(ffi.C.LOG_GRP_REQDBG))
+
+local todname = kres.str2dname -- not available during module load otherwise
+
+-- Counter of unique rules
+local nextid = 0
+local function getruleid()
+ local newid = nextid
+ nextid = nextid + 1
+ return newid
+end
+
+-- Support for client sockets from inside policy actions
+local socket_client = function ()
+ return error("missing lua-cqueues library, can't create socket client")
+end
+local has_socket, socket = pcall(require, 'cqueues.socket')
+if has_socket then
+ socket_client = function (host, port)
+ local s, err, status
+
+ s = socket.connect({ host = host, port = port, type = socket.SOCK_DGRAM })
+ s:setmode('bn', 'bn')
+ status, err = pcall(s.connect, s)
+
+ if not status then
+ return status, err
+ end
+ return s
+ end
+end
+
+-- Split address and port from a combined string.
+local function addr_split_port(target, default_port)
+ assert(default_port and type(default_port) == 'number')
+ local port = ffi.new('uint16_t[1]', default_port)
+ local addr = ffi.new('char[47]') -- INET6_ADDRSTRLEN + 1
+ local ret = ffi.C.kr_straddr_split(target, addr, port)
+ if ret ~= 0 then
+ error('failed to parse address ' .. target)
+ end
+ return addr, tonumber(port[0])
+end
+
+-- String address@port -> sockaddr.
+local function addr2sock(target, default_port)
+ local addr, port = addr_split_port(target, default_port)
+ local sock = ffi.gc(ffi.C.kr_straddr_socket(addr, port, nil), ffi.C.free);
+ if sock == nil then
+ error("target '"..target..'" is not a valid IP address')
+ end
+ return sock
+end
+
+-- Debug logging for taken policy actions
+local function log_policy_action(req, name)
+ if ffi.C.kr_log_is_debug_fun(ffi.C.LOG_GRP_POLICY, req) then
+ local qry = req:current()
+ ffi.C.kr_log_req1(
+ req, qry.uid, 2, ffi.C.LOG_GRP_POLICY, LOG_GRP_POLICY_TAG,
+ "%s applied for %s %s\n",
+ name, kres.dname2str(qry.sname), kres.tostring.type[qry.stype])
+ end
+end
+
+-- policy functions are defined below
+local policy = {}
+
+function policy.PASS(state, _)
+ return state
+end
+
+-- Mirror request elsewhere, and continue solving
+function policy.MIRROR(target)
+ local addr, port = addr_split_port(target, 53)
+ local sink, err = socket_client(ffi.string(addr), port)
+ if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end
+ return function(state, req)
+ if state == kres.FAIL then return state end
+ local query = req.qsource.packet
+ if query ~= nil then
+ sink:send(ffi.string(query.wire, query.size), 1, tonumber(query.size))
+ end
+ return -- Chain action to next
+ end
+end
+
+-- Override the list of nameservers (forwarders)
+local function set_nslist(req, list)
+ local ns_i = 0
+ for _, ns in ipairs(list) do
+ if ffi.C.kr_forward_add_target(req, ns) == 0 then
+ ns_i = ns_i + 1
+ end
+ end
+ if ns_i == 0 then
+ -- would use assert() but don't want to compose the message if not triggered
+ error('no usable address in NS set (check net.ipv4 and '
+ .. 'net.ipv6 config):\n' .. table_print(list, 2))
+ end
+end
+
+-- Forward request, and solve as stub query
+function policy.STUB(target)
+ local list = {}
+ if type(target) == 'table' then
+ for _, v in pairs(target) do
+ table.insert(list, addr2sock(v, 53))
+ end
+ else
+ table.insert(list, addr2sock(target, 53))
+ end
+ return function(state, req)
+ local qry = req:current()
+ -- Switch mode to stub resolver, do not track origin zone cut since it's not real authority NS
+ qry.flags.STUB = true
+ qry.flags.ALWAYS_CUT = false
+ set_nslist(req, list)
+ return state
+ end
+end
+
+-- Forward request and all subrequests to upstream; validate answers
+function policy.FORWARD(target)
+ local list = {}
+ if type(target) == 'table' then
+ for _, v in pairs(target) do
+ table.insert(list, addr2sock(v, 53))
+ end
+ else
+ table.insert(list, addr2sock(target, 53))
+ end
+ return function(state, req)
+ local qry = req:current()
+ req.options.FORWARD = true
+ req.options.NO_MINIMIZE = true
+ qry.flags.FORWARD = true
+ qry.flags.ALWAYS_CUT = false
+ qry.flags.NO_MINIMIZE = true
+ qry.flags.AWAIT_CUT = true
+ set_nslist(req, list)
+ return state
+ end
+end
+
+-- Forward request and all subrequests to upstream over TLS; validate answers
+function policy.TLS_FORWARD(targets)
+ if type(targets) ~= 'table' or #targets < 1 then
+ error('TLS_FORWARD argument must be a non-empty table')
+ end
+
+ local sockaddr_c_set = {}
+ local nslist = {} -- to persist in closure of the returned function
+ for idx, target in pairs(targets) do
+ if type(target) ~= 'table' or type(target[1]) ~= 'string' then
+ error(string.format('TLS_FORWARD configuration at position ' ..
+ '%d must be a table starting with an IP address', idx))
+ end
+ -- Note: some functions have checks with error() calls inside.
+ local sockaddr_c = addr2sock(target[1], 853)
+
+ -- Refuse repeated addresses in the same set.
+ local sockaddr_lua = ffi.string(sockaddr_c, ffi.C.kr_sockaddr_len(sockaddr_c))
+ if sockaddr_c_set[sockaddr_lua] then
+ error('TLS_FORWARD configuration cannot declare two configs for IP address '
+ .. target[1])
+ else
+ sockaddr_c_set[sockaddr_lua] = true;
+ end
+
+ table.insert(nslist, sockaddr_c)
+ net.tls_client(target)
+ end
+
+ return function(state, req)
+ local qry = req:current()
+ req.options.FORWARD = true
+ req.options.NO_MINIMIZE = true
+ qry.flags.FORWARD = true
+ qry.flags.ALWAYS_CUT = false
+ qry.flags.NO_MINIMIZE = true
+ qry.flags.AWAIT_CUT = true
+ req.options.TCP = true
+ qry.flags.TCP = true
+ set_nslist(req, nslist)
+ return state
+ end
+end
+
+-- Rewrite records in packet
+function policy.REROUTE(tbl, names)
+ -- Import renumbering rules
+ local ren = require('kres_modules.renumber')
+ local prefixes = {}
+ for from, to in pairs(tbl) do
+ local prefix = names and ren.name(from, to) or ren.prefix(from, to)
+ table.insert(prefixes, prefix)
+ end
+ -- Return rule closure
+ return ren.rule(prefixes)
+end
+
+-- Set and clear some query flags
+function policy.FLAGS(opts_set, opts_clear)
+ return function(_, req)
+ -- We assume to be running in the begin phase, so to truly apply
+ -- to the whole request we need to change both kr_request and kr_query.
+ local qry = req:current()
+ for _, flags in pairs({qry.flags, req.options}) do
+ ffi.C.kr_qflags_set (flags, kres.mk_qflags(opts_set or {}))
+ ffi.C.kr_qflags_clear(flags, kres.mk_qflags(opts_clear or {}))
+ end
+ return nil -- chain rule
+ end
+end
+
+local function mkauth_soa(answer, dname, mname, ttl)
+ if mname == nil then
+ mname = dname
+ end
+ return answer:put(dname, ttl or 10800, answer:qclass(), kres.type.SOA,
+ mname .. '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48')
+end
+
+-- Create answer with passed arguments
+function policy.ANSWER(rtable, nodata)
+ return function(_, req)
+ local qry = req:current()
+ local data = rtable[qry.stype]
+ if data == nil and nodata ~= true then
+ return nil
+ end
+ -- now we're certain we want to generate an answer
+
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ ffi.C.kr_pkt_make_auth_header(answer)
+ local ttl = (data or {}).ttl or 1
+ answer:rcode(kres.rcode.NOERROR)
+ req:set_extended_error(kres.extended_error.FORGED, "5DO5")
+
+ if data == nil then -- want NODATA, i.e. just a SOA
+ answer:begin(kres.section.AUTHORITY)
+ local soa = rtable[kres.type.SOA]
+ if soa ~= nil then
+ answer:put(qry.sname, soa.ttl or ttl, qry.sclass, kres.type.SOA,
+ soa.rdata[1] or soa.rdata)
+ else
+ mkauth_soa(answer, kres.dname2wire(qry.sname), nil, ttl)
+ end
+ log_policy_action(req, 'ANSWER (nodata)')
+ else
+ answer:begin(kres.section.ANSWER)
+ if type(data.rdata) == 'table' then
+ for _, entry in ipairs(data.rdata) do
+ answer:put(qry.sname, ttl, qry.sclass, qry.stype, entry)
+ end
+ else
+ answer:put(qry.sname, ttl, qry.sclass, qry.stype, data.rdata)
+ end
+ log_policy_action(req, 'ANSWER (forged)')
+ end
+ return kres.DONE
+ end
+end
+
+local dname_localhost = todname('localhost.')
+
+-- Rule for localhost. zone; see RFC6303, sec. 3
+local function localhost(_, req)
+ local qry = req:current()
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ ffi.C.kr_pkt_make_auth_header(answer)
+
+ local is_exact = ffi.C.knot_dname_is_equal(qry.sname, dname_localhost)
+
+ answer:rcode(kres.rcode.NOERROR)
+ answer:begin(kres.section.ANSWER)
+ if qry.stype == kres.type.AAAA then
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.AAAA,
+ '\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1')
+ elseif qry.stype == kres.type.A then
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\127\0\0\1')
+ elseif is_exact and qry.stype == kres.type.SOA then
+ mkauth_soa(answer, dname_localhost)
+ elseif is_exact and qry.stype == kres.type.NS then
+ answer:put(dname_localhost, 900, answer:qclass(), kres.type.NS, dname_localhost)
+ else
+ answer:begin(kres.section.AUTHORITY)
+ mkauth_soa(answer, dname_localhost)
+ end
+ return kres.DONE
+end
+
+local dname_rev4_localhost = todname('1.0.0.127.in-addr.arpa');
+local dname_rev4_localhost_apex = todname('127.in-addr.arpa');
+
+-- Rule for reverse localhost.
+-- Answer with locally served minimal 127.in-addr.arpa domain, only having
+-- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone.
+-- TODO: much of this would better be left to the hints module (or coordinated).
+local function localhost_reversed(_, req)
+ local qry = req:current()
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+
+ -- classify qry.sname:
+ local is_exact -- exact dname for localhost
+ local is_apex -- apex of a locally-served localhost zone
+ local is_nonterm -- empty non-terminal name
+ if ffi.C.knot_dname_in_bailiwick(qry.sname, todname('ip6.arpa.')) > 0 then
+ -- exact ::1 query (relying on the calling rule)
+ is_exact = true
+ is_apex = true
+ else
+ -- within 127.in-addr.arpa.
+ local labels = ffi.C.knot_dname_labels(qry.sname, nil)
+ if labels == 3 then
+ is_exact = false
+ is_apex = true
+ elseif labels == 4+2 and ffi.C.knot_dname_is_equal(
+ qry.sname, dname_rev4_localhost) then
+ is_exact = true
+ else
+ is_exact = false
+ is_apex = false
+ is_nonterm = ffi.C.knot_dname_in_bailiwick(dname_rev4_localhost, qry.sname) > 0
+ end
+ end
+
+ ffi.C.kr_pkt_make_auth_header(answer)
+ answer:rcode(kres.rcode.NOERROR)
+ answer:begin(kres.section.ANSWER)
+ if is_exact and qry.stype == kres.type.PTR then
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.PTR, dname_localhost)
+ elseif is_apex and qry.stype == kres.type.SOA then
+ mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost)
+ elseif is_apex and qry.stype == kres.type.NS then
+ answer:put(dname_rev4_localhost_apex, 900, answer:qclass(), kres.type.NS,
+ dname_localhost)
+ else
+ if not is_nonterm then
+ answer:rcode(kres.rcode.NXDOMAIN)
+ end
+ answer:begin(kres.section.AUTHORITY)
+ mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost)
+ end
+ return kres.DONE
+end
+
+-- All requests
+function policy.all(action)
+ return function(_, _) return action end
+end
+
+-- Requests whose QNAME is exactly the provided domain
+function policy.domains(action, dname_list)
+ return function(_, query)
+ local qname = query:name()
+ for _, dname in ipairs(dname_list) do
+ if ffi.C.knot_dname_is_equal(qname, dname) then
+ return action
+ end
+ end
+ return nil
+ end
+end
+
+-- Requests whose QNAME matches given zone list (i.e. suffix match)
+function policy.suffix(action, zone_list)
+ local AC = require('ahocorasick')
+ local tree = AC.create(zone_list)
+ return function(_, query)
+ local match = AC.match(tree, query:name(), false)
+ if match ~= nil then
+ return action
+ end
+ return nil
+ end
+end
+
+-- Check for common suffix first, then suffix match (specialized version of suffix match)
+function policy.suffix_common(action, suffix_list, common_suffix)
+ local common_len = string.len(common_suffix)
+ local suffix_count = #suffix_list
+ return function(_, query)
+ -- Preliminary check
+ local qname = query:name()
+ if not string.find(qname, common_suffix, -common_len, true) then
+ return nil
+ end
+ -- String match
+ for i = 1, suffix_count do
+ local zone = suffix_list[i]
+ if string.find(qname, zone, -string.len(zone), true) then
+ return action
+ end
+ end
+ return nil
+ end
+end
+
+-- Filter QNAME pattern
+function policy.pattern(action, pattern)
+ return function(_, query)
+ if string.find(query:name(), pattern) then
+ return action
+ end
+ return nil
+ end
+end
+
+local function rpz_parse(action, path)
+ local rules = {}
+ local new_actions = {}
+ local action_map = {
+ -- RPZ Policy Actions
+ ['\0'] = action,
+ ['\1*\0'] = policy.ANSWER({}, true),
+ ['\012rpz-passthru\0'] = policy.PASS, -- the grammar...
+ ['\008rpz-drop\0'] = policy.DROP,
+ ['\012rpz-tcp-only\0'] = policy.TC,
+ -- Policy triggers @NYI@
+ }
+ -- RR types to be skipped; boolean denoting whether to throw a warning even for RPZ apex.
+ local rrtype_bad = {
+ [kres.type.DNAME] = true,
+ [kres.type.NS] = false,
+ [kres.type.DNSKEY] = true,
+ [kres.type.DS] = true,
+ [kres.type.RRSIG] = true,
+ [kres.type.NSEC] = true,
+ [kres.type.NSEC3] = true,
+ }
+
+ -- We generally don't know what zone should be in the file; we try to detect it.
+ -- Fortunately, it's typical that SOA is the first record, even required for AXFR.
+ local origin_soa = nil
+ local warned_soa, warned_bailiwick
+
+ local parser = require('zonefile').new()
+ local ok, errstr = parser:open(path)
+ if not ok then
+ error(string.format('failed to parse "%s": %s', path, errstr or "unknown error"))
+ end
+ while true do
+ ok, errstr = parser:parse()
+ if errstr then
+ log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d: %s',
+ path, tonumber(parser.line_counter), errstr)
+ end
+ if not ok then break end
+
+ local full_name = ffi.gc(ffi.C.knot_dname_copy(parser.r_owner, nil), ffi.C.free)
+ local rdata = ffi.string(parser.r_data, parser.r_data_length)
+ ffi.C.knot_dname_to_lower(full_name)
+
+ local origin = origin_soa or parser.zone_origin
+ local prefix_labels = ffi.C.knot_dname_in_bailiwick(full_name, origin)
+ if prefix_labels < 0 then
+ if not warned_bailiwick then
+ warned_bailiwick = true
+ log_warn(ffi.C.LOG_GRP_POLICY,
+ 'RPZ %s:%d: RR owner "%s" outside the zone (ignored; reported once per file)',
+ path, tonumber(parser.line_counter), kres.dname2str(full_name))
+ end
+ goto continue
+ end
+
+ local bytes = ffi.C.knot_dname_size(full_name) - ffi.C.knot_dname_size(origin)
+ local name = ffi.string(full_name, bytes) .. '\0'
+
+ if parser.r_type == kres.type.CNAME then
+ if action_map[rdata] then
+ rules[name] = action_map[rdata]
+ else
+ log_warn(ffi.C.LOG_GRP_POLICY,
+ 'RPZ %s:%d: CNAME with custom target in RPZ is not supported yet (ignored)',
+ path, tonumber(parser.line_counter))
+ end
+ else
+ if #name then
+ local is_bad = rrtype_bad[parser.r_type]
+
+ if parser.r_type == kres.type.SOA then
+ if origin_soa == nil then
+ origin_soa = ffi.gc(ffi.C.knot_dname_copy(parser.r_owner, nil), ffi.C.free)
+ goto continue -- we don't want to modify `new_actions`
+ else
+ is_bad = true -- maybe provide more info, but it seems rare
+ end
+ elseif origin_soa == nil and not warned_soa then
+ warned_soa = true
+ log_warn(ffi.C.LOG_GRP_POLICY,
+ 'RPZ %s:%d warning: SOA missing as the first record',
+ path, tonumber(parser.line_counter))
+ end
+
+ if is_bad == true or (is_bad == false and prefix_labels ~= 0) then
+ log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d warning: RR type %s is not allowed in RPZ (ignored)',
+ path, tonumber(parser.line_counter), kres.tostring.type[parser.r_type])
+ elseif is_bad == nil then
+ if new_actions[name] == nil then new_actions[name] = {} end
+ local act = new_actions[name][parser.r_type]
+ if act == nil then
+ new_actions[name][parser.r_type] = { ttl=parser.r_ttl, rdata=rdata }
+ else -- multiple RRs: no reordering or deduplication
+ if type(act.rdata) ~= 'table' then
+ act.rdata = { act.rdata }
+ end
+ table.insert(act.rdata, rdata)
+ if parser.r_ttl ~= act.ttl then -- be conservative
+ log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d warning: different TTLs in a set (minimum taken)',
+ path, tonumber(parser.line_counter))
+ act.ttl = math.min(act.ttl, parser.r_ttl)
+ end
+ end
+ else
+ assert(is_bad == false and prefix_labels == 0)
+ end
+ end
+ end
+
+ ::continue::
+ end
+ collectgarbage()
+ for qname, rrsets in pairs(new_actions) do
+ rules[qname] = policy.ANSWER(rrsets, true)
+ end
+ return rules
+end
+
+-- Split path into dirname and basename (like the shell utilities)
+local function get_dir_and_file(path)
+ local dir, file = string.match(path, "(.*)/([^/]+)")
+
+ -- If regex doesn't match then path must be the file directly (i.e. doesn't contain '/')
+ -- This assumes that the file exists (rpz_parse() would fail if it doesn't)
+ if not dir and not file then
+ dir = '.'
+ file = path
+ end
+
+ return dir, file
+end
+
+-- RPZ policy set
+-- Create RPZ from zone file and optionally watch the file for changes
+function policy.rpz(action, path, watch)
+ local rules = rpz_parse(action, path)
+
+ if watch ~= false then
+ local has_notify, notify = pcall(require, 'cqueues.notify')
+ if has_notify then
+ local bit = require('bit')
+
+ local dir, file = get_dir_and_file(path)
+ local watcher = notify.opendir(dir)
+ watcher:add(file, bit.bxor(notify.CREATE, notify.MODIFY))
+
+ worker.coroutine(function ()
+ for _, name in watcher:changes() do
+ -- Limit to changes on file we're interested in
+ -- Watcher will also fire for changes to the directory itself
+ if name == file then
+ -- If the file changes then reparse and replace the existing ruleset
+ log_info(ffi.C.LOG_GRP_POLICY, 'RPZ reloading: ' .. name)
+ rules = rpz_parse(action, path)
+ end
+ end
+ end)
+ elseif watch then -- explicitly requested and failed
+ error('[poli] lua-cqueues required to watch and reload RPZ file')
+ else
+ log_info(ffi.C.LOG_GRP_POLICY, 'lua-cqueues required to watch and reload RPZ file, continuing without watching')
+ end
+ end
+
+ return function(_, query)
+ local label = query:name()
+ local rule = rules[label]
+ while rule == nil and string.len(label) > 0 do
+ label = string.sub(label, string.byte(label) + 2)
+ rule = rules['\1*'..label]
+ end
+ return rule
+ end
+end
+
+-- Apply an action when query belongs to a slice (determined by slice_func())
+function policy.slice(slice_func, ...)
+ local actions = {...}
+ if #actions <= 0 then
+ error('[poli] at least one action must be provided to policy.slice()')
+ end
+
+ return function(_, query)
+ local index = slice_func(query, #actions)
+ return actions[index]
+ end
+end
+
+-- Initializes slicing function that randomly assigns queries to a slice based on their registrable domain
+function policy.slice_randomize_psl(seed)
+ local has_psl, psl_lib = pcall(require, 'psl')
+ if not has_psl then
+ error('[poli] lua-psl is required for policy.slice_randomize_psl()')
+ end
+ -- load psl
+ local has_latest, psl = pcall(psl_lib.latest)
+ if not has_latest then -- compatibility with lua-psl < 0.15
+ psl = psl_lib.builtin()
+ end
+
+ if seed == nil then
+ seed = os.time() / (3600 * 24 * 7)
+ end
+ seed = math.floor(seed) -- convert to int
+
+ return function(query, length)
+ assert(length > 0)
+
+ local domain = kres.dname2str(query:name())
+ if domain == nil then -- invalid data: auto-select first action
+ return 1
+ end
+ if domain:len() > 1 then --remove trailing dot
+ domain = domain:sub(0, -2)
+ end
+
+ -- do psl lookup for registrable domain
+ local reg_domain = psl:registrable_domain(domain)
+ if reg_domain == nil then -- fallback to unreg. domain
+ reg_domain = psl:unregistrable_domain(domain)
+ if reg_domain == nil then -- shouldn't happen: safe fallback
+ return 1
+ end
+ end
+
+ local rand_seed = seed
+ -- create deterministic seed for pseudo-random slice assignment
+ for i = 1, #reg_domain do
+ rand_seed = rand_seed + reg_domain:byte(i)
+ end
+
+ -- use linear congruential generator with values from ANSI C
+ rand_seed = rand_seed % 0x80000000 -- ensure seed is positive 32b int
+ local rand = (1103515245 * rand_seed + 12345) % 0x10000
+ return 1 + rand % length
+ end
+end
+
+-- Prepare for making an answer from scratch. (Return the packet for convenience.)
+local function answer_clear(req)
+ -- If we're in postrules, previous resolving might have chosen some RRs
+ -- for inclusion in the answer, so we need to avoid those.
+ -- *_selected arrays are in mempool, so explicit deallocation is not necessary.
+ req.answ_selected.len = 0
+ req.auth_selected.len = 0
+ req.add_selected.len = 0
+
+ -- Let's be defensive and clear the answer, too.
+ local pkt = req:ensure_answer()
+ if pkt == nil then return nil end
+ pkt:clear_payload()
+ req:ensure_edns()
+ return pkt
+end
+
+function policy.DENY_MSG(msg, extended_error)
+ if msg and (type(msg) ~= 'string' or #msg >= 255) then
+ error('DENY_MSG: optional msg must be string shorter than 256 characters')
+ end
+ if extended_error == nil then
+ extended_error = kres.extended_error.BLOCKED
+ end
+ local action_name = msg and 'DENY_MSG' or 'DENY'
+
+ return function (_, req)
+ -- Write authority information
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ ffi.C.kr_pkt_make_auth_header(answer)
+ answer:rcode(kres.rcode.NXDOMAIN)
+ answer:begin(kres.section.AUTHORITY)
+ mkauth_soa(answer, answer:qname())
+ if msg then
+ answer:begin(kres.section.ADDITIONAL)
+ answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT,
+ string.char(#msg) .. msg)
+
+ end
+ req:set_extended_error(extended_error, "CR36")
+ log_policy_action(req, action_name)
+ return kres.DONE
+ end
+end
+
+local function free_cb(func)
+ func:free()
+end
+
+local debug_logline_cb = ffi.cast('trace_log_f', function (_, msg)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+ ffi.C.kr_log_fmt(
+ ffi.C.LOG_GRP_REQDBG, -- but the original [group] tag also remains in the string
+ LOG_DEBUG,
+ 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=policy.DEBUG_ALWAYS', -- no meaningful locations
+ '[%-6s]%s', LOG_GRP_REQDBG_TAG, msg) -- msg should end with newline already
+end)
+ffi.gc(debug_logline_cb, free_cb)
+
+-- LOG_DEBUG without log_trace and without code locations
+local function log_notrace(req, fmt, ...)
+ ffi.C.kr_log_fmt(ffi.C.LOG_GRP_REQDBG, LOG_DEBUG,
+ 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=', -- no meaningful locations
+ '%s', string.format( -- convert in lua, as integers in C varargs would pass as double
+ '[%-6s][%-6s][%05u.00] ' .. fmt,
+ LOG_GRP_REQDBG_TAG, LOG_GRP_POLICY_TAG, req.uid, ...)
+ )
+end
+
+local debug_logfinish_cb = ffi.cast('trace_callback_f', function (req)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+ log_notrace(req, 'following rrsets were marked as interesting:\n%s\n',
+ req:selected_tostring())
+ if req.answer ~= nil then
+ log_notrace(req, 'answer packet:\n%s\n', req.answer)
+ else
+ log_notrace(req, 'answer packet DROPPED\n')
+ end
+end)
+ffi.gc(debug_logfinish_cb, free_cb)
+
+-- log request packet
+function policy.REQTRACE(_, req)
+ log_notrace(req, 'request packet:\n%s', req.qsource.packet)
+end
+
+-- log how the request arrived, notably the client's IP
+function policy.IPTRACE(_, req)
+ if req.qsource.addr == nil then
+ log_notrace(req, 'request packet arrived internally\n')
+ else
+ -- stringify transport flags: struct kr_request_qsource_flags
+ local qf = req.qsource.flags
+ local qf_str = qf.tcp and 'TCP' or 'UDP'
+ if qf.tls then qf_str = qf_str .. ' + TLS' end
+ if qf.http then qf_str = qf_str .. ' + HTTP' end
+ if qf.xdp then qf_str = qf_str .. ' + XDP' end
+
+ log_notrace(req, 'request packet arrived from %s to %s (%s)\n',
+ req.qsource.addr, req.qsource.dst_addr, qf_str)
+ end
+ return nil -- chain rule
+end
+
+function policy.DEBUG_ALWAYS(state, req)
+ policy.QTRACE(state, req)
+ req:trace_chain_callbacks(debug_logline_cb, debug_logfinish_cb)
+ policy.REQTRACE(state, req)
+end
+
+local debug_stashlog_cb = ffi.cast('trace_log_f', function (req, msg)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+
+ -- stash messages for conditional logging in trace_finish
+ local stash = req:vars()['policy_debug_stash']
+ table.insert(stash, ffi.string(msg))
+end)
+ffi.gc(debug_stashlog_cb, free_cb)
+
+-- buffer debug logs and print then only if test() returns a truthy value
+function policy.DEBUG_IF(test)
+ local debug_finish_cb = ffi.cast('trace_callback_f', function (cbreq)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+ if test(cbreq) then
+ policy.REQTRACE(nil, cbreq)
+ debug_logfinish_cb(cbreq) -- unconditional version
+
+ local stash = cbreq:vars()['policy_debug_stash']
+ for _, line in ipairs(stash) do -- don't want one huge entry
+ ffi.C.kr_log_fmt(ffi.C.LOG_GRP_REQDBG, LOG_DEBUG,
+ 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=', -- no meaningful locations
+ '[%-6s]%s', LOG_GRP_REQDBG_TAG, line)
+ end
+ end
+ end)
+ ffi.gc(debug_finish_cb, function (func) func:free() end)
+
+ return function (state, req)
+ req:vars()['policy_debug_stash'] = {}
+ policy.QTRACE(state, req)
+ req:trace_chain_callbacks(debug_stashlog_cb, debug_finish_cb)
+ return
+ end
+end
+
+policy.DEBUG_CACHE_MISS = policy.DEBUG_IF(
+ function(req)
+ return not req:all_from_cache()
+ end
+)
+
+policy.DENY = policy.DENY_MSG() -- compatibility with < 2.0
+
+function policy.DROP(_, req)
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ req:set_extended_error(kres.extended_error.PROHIBITED, "U5KL")
+ log_policy_action(req, 'DROP')
+ return kres.FAIL
+end
+
+function policy.NO_ANSWER(_, req)
+ req.options.NO_ANSWER = true
+ log_policy_action(req, 'NO_ANSWER')
+ return kres.FAIL
+end
+
+function policy.REFUSE(_, req)
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ answer:rcode(kres.rcode.REFUSED)
+ answer:ad(false)
+ req:set_extended_error(kres.extended_error.PROHIBITED, "EIM4")
+ log_policy_action(req, 'REFUSE')
+ return kres.DONE
+end
+
+function policy.TC(state, req)
+ -- Avoid non-UDP queries
+ if req.qsource.addr == nil or req.qsource.flags.tcp then
+ return state
+ end
+
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ answer:tc(1)
+ answer:ad(false)
+ log_policy_action(req, 'TC')
+ return kres.DONE
+end
+
+function policy.QTRACE(_, req)
+ local qry = req:current()
+ req.options.TRACE = true
+ qry.flags.TRACE = true
+ return -- this allows to continue iterating over policy list
+end
+
+-- Evaluate packet in given rules to determine policy action
+function policy.evaluate(rules, req, query, state)
+ for i = 1, #rules do
+ local rule = rules[i]
+ if not rule.suspended then
+ local action = rule.cb(req, query)
+ if action ~= nil then
+ rule.count = rule.count + 1
+ local next_state = action(state, req)
+ if next_state then -- Not a chain rule,
+ return next_state -- stop on first match
+ end
+ end
+ end
+ end
+ return
+end
+
+-- Add rule to policy list
+function policy.add(rule, postrule)
+ -- Compatibility with 1.0.0 API
+ -- it will be dropped in 1.2.0
+ if rule == policy then
+ rule = postrule
+ postrule = nil
+ end
+ -- End of compatibility shim
+ local desc = {id=getruleid(), cb=rule, count=0}
+ table.insert(postrule and policy.postrules or policy.rules, desc)
+ return desc
+end
+
+-- Remove rule from a list
+local function delrule(rules, id)
+ for i, r in ipairs(rules) do
+ if r.id == id then
+ table.remove(rules, i)
+ return true
+ end
+ end
+ return false
+end
+
+-- Delete rule from policy list
+function policy.del(id)
+ if not delrule(policy.rules, id) then
+ if not delrule(policy.postrules, id) then
+ return false
+ end
+ end
+ return true
+end
+
+-- Convert list of string names to domain names
+function policy.todnames(names)
+ for i, v in ipairs(names) do
+ names[i] = kres.str2dname(v)
+ end
+ return names
+end
+
+-- RFC1918 Private, local, broadcast, test and special zones
+-- Considerations: RFC6761, sec 6.1.
+-- https://www.iana.org/assignments/locally-served-dns-zones
+local private_zones = {
+ -- RFC6303
+ '10.in-addr.arpa.',
+ '16.172.in-addr.arpa.',
+ '17.172.in-addr.arpa.',
+ '18.172.in-addr.arpa.',
+ '19.172.in-addr.arpa.',
+ '20.172.in-addr.arpa.',
+ '21.172.in-addr.arpa.',
+ '22.172.in-addr.arpa.',
+ '23.172.in-addr.arpa.',
+ '24.172.in-addr.arpa.',
+ '25.172.in-addr.arpa.',
+ '26.172.in-addr.arpa.',
+ '27.172.in-addr.arpa.',
+ '28.172.in-addr.arpa.',
+ '29.172.in-addr.arpa.',
+ '30.172.in-addr.arpa.',
+ '31.172.in-addr.arpa.',
+ '168.192.in-addr.arpa.',
+ '0.in-addr.arpa.',
+ '254.169.in-addr.arpa.',
+ '2.0.192.in-addr.arpa.',
+ '100.51.198.in-addr.arpa.',
+ '113.0.203.in-addr.arpa.',
+ '255.255.255.255.in-addr.arpa.',
+ -- RFC7793
+ '64.100.in-addr.arpa.',
+ '65.100.in-addr.arpa.',
+ '66.100.in-addr.arpa.',
+ '67.100.in-addr.arpa.',
+ '68.100.in-addr.arpa.',
+ '69.100.in-addr.arpa.',
+ '70.100.in-addr.arpa.',
+ '71.100.in-addr.arpa.',
+ '72.100.in-addr.arpa.',
+ '73.100.in-addr.arpa.',
+ '74.100.in-addr.arpa.',
+ '75.100.in-addr.arpa.',
+ '76.100.in-addr.arpa.',
+ '77.100.in-addr.arpa.',
+ '78.100.in-addr.arpa.',
+ '79.100.in-addr.arpa.',
+ '80.100.in-addr.arpa.',
+ '81.100.in-addr.arpa.',
+ '82.100.in-addr.arpa.',
+ '83.100.in-addr.arpa.',
+ '84.100.in-addr.arpa.',
+ '85.100.in-addr.arpa.',
+ '86.100.in-addr.arpa.',
+ '87.100.in-addr.arpa.',
+ '88.100.in-addr.arpa.',
+ '89.100.in-addr.arpa.',
+ '90.100.in-addr.arpa.',
+ '91.100.in-addr.arpa.',
+ '92.100.in-addr.arpa.',
+ '93.100.in-addr.arpa.',
+ '94.100.in-addr.arpa.',
+ '95.100.in-addr.arpa.',
+ '96.100.in-addr.arpa.',
+ '97.100.in-addr.arpa.',
+ '98.100.in-addr.arpa.',
+ '99.100.in-addr.arpa.',
+ '100.100.in-addr.arpa.',
+ '101.100.in-addr.arpa.',
+ '102.100.in-addr.arpa.',
+ '103.100.in-addr.arpa.',
+ '104.100.in-addr.arpa.',
+ '105.100.in-addr.arpa.',
+ '106.100.in-addr.arpa.',
+ '107.100.in-addr.arpa.',
+ '108.100.in-addr.arpa.',
+ '109.100.in-addr.arpa.',
+ '110.100.in-addr.arpa.',
+ '111.100.in-addr.arpa.',
+ '112.100.in-addr.arpa.',
+ '113.100.in-addr.arpa.',
+ '114.100.in-addr.arpa.',
+ '115.100.in-addr.arpa.',
+ '116.100.in-addr.arpa.',
+ '117.100.in-addr.arpa.',
+ '118.100.in-addr.arpa.',
+ '119.100.in-addr.arpa.',
+ '120.100.in-addr.arpa.',
+ '121.100.in-addr.arpa.',
+ '122.100.in-addr.arpa.',
+ '123.100.in-addr.arpa.',
+ '124.100.in-addr.arpa.',
+ '125.100.in-addr.arpa.',
+ '126.100.in-addr.arpa.',
+ '127.100.in-addr.arpa.',
+
+ -- RFC6303
+ -- localhost_reversed handles ::1
+ '0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.',
+ 'd.f.ip6.arpa.',
+ '8.e.f.ip6.arpa.',
+ '9.e.f.ip6.arpa.',
+ 'a.e.f.ip6.arpa.',
+ 'b.e.f.ip6.arpa.',
+ '8.b.d.0.1.0.0.2.ip6.arpa.',
+ -- RFC8375
+ 'home.arpa.',
+}
+policy.todnames(private_zones)
+
+-- @var Default rules
+policy.rules = {}
+policy.postrules = {}
+policy.special_names = {
+ -- XXX: beware of special_names_optim() when modifying these filters
+ {
+ cb=policy.suffix_common(policy.DENY_MSG(
+ 'Blocking is mandated by standards, see references on '
+ .. 'https://www.iana.org/assignments/'
+ .. 'locally-served-dns-zones/locally-served-dns-zones.xhtml',
+ kres.extended_error.NOTSUP),
+ private_zones, todname('arpa.')),
+ count=0
+ },
+ {
+ cb=policy.suffix(policy.DENY_MSG(
+ 'Blocking is mandated by standards, see references on '
+ .. 'https://www.iana.org/assignments/'
+ .. 'special-use-domain-names/special-use-domain-names.xhtml',
+ kres.extended_error.NOTSUP),
+ {
+ todname('test.'),
+ todname('onion.'),
+ todname('invalid.'),
+ todname('local.'), -- RFC 8375.4
+ }),
+ count=0
+ },
+ {
+ cb=policy.suffix(localhost, {dname_localhost}),
+ count=0
+ },
+ {
+ cb=policy.suffix_common(localhost_reversed, {
+ todname('127.in-addr.arpa.'),
+ todname('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.')},
+ todname('arpa.')),
+ count=0
+ },
+}
+
+-- Return boolean; false = no special name may apply, true = some might apply.
+-- The point is to *efficiently* filter almost all QNAMEs that do not apply.
+local function special_names_optim(req, sname)
+ local qname_size = req.qsource.packet.qname_size
+ if qname_size < 9 then return true end -- don't want to special-case bad array access
+ local root = sname + qname_size - 1
+ return
+ -- .a???. or .t???.
+ (root[-5] == 4 and (root[-4] == 97 or root[-4] == 116))
+ -- .on???. or .in?????. or lo???. or *ost.
+ or (root[-6] == 5 and root[-5] == 111 and root[-4] == 110)
+ or (root[-8] == 7 and root[-7] == 105 and root[-6] == 110)
+ or (root[-6] == 5 and root[-5] == 108 and root[-4] == 111)
+ or (root[-3] == 111 and root[-2] == 115 and root[-1] == 116)
+end
+
+-- Top-down policy list walk until we hit a match
+-- the caller is responsible for reordering policy list
+-- from most specific to least specific.
+-- Some rules may be chained, in this case they are evaluated
+-- as a dependency chain, e.g. r1,r2,r3 -> r3(r2(r1(state)))
+policy.layer = {
+ begin = function(state, req)
+ -- Don't act on "finished" cases.
+ if bit.band(state, bit.bor(kres.FAIL, kres.DONE)) ~= 0 then return state end
+ local qry = req:initial() -- same as :current() but more descriptive
+ return policy.evaluate(policy.rules, req, qry, state)
+ or (special_names_optim(req, qry.sname)
+ and policy.evaluate(policy.special_names, req, qry, state))
+ or state
+ end,
+ finish = function(state, req)
+ -- Optimization for the typical case
+ if #policy.postrules == 0 then return state end
+ -- Don't act on failed cases.
+ if bit.band(state, kres.FAIL) ~= 0 then return state end
+ return policy.evaluate(policy.postrules, req, req:initial(), state) or state
+ end
+}
+
+return policy
diff --git a/modules/policy/policy.rpz.test.lua b/modules/policy/policy.rpz.test.lua
new file mode 100644
index 0000000..94fb9ce
--- /dev/null
+++ b/modules/policy/policy.rpz.test.lua
@@ -0,0 +1,65 @@
+
+local function prepare_cache()
+ cache.open(100*MB)
+ cache.clear()
+
+ local ffi = require('ffi')
+ local c = kres.context().cache
+
+ local passthru_addr = '\127\0\0\9'
+ rr_passthru = kres.rrset(todname('rpzpassthru.'), kres.type.A, kres.class.IN, 2147483647)
+ assert(rr_passthru:add_rdata(passthru_addr, #passthru_addr))
+ assert(c:insert(rr_passthru, nil, ffi.C.KR_RANK_SECURE + ffi.C.KR_RANK_AUTH))
+
+ c:commit()
+end
+
+local check_answer = require('test_utils').check_answer
+
+local function test_rpz()
+ check_answer('"CNAME ." return NXDOMAIN',
+ 'nxdomain.', kres.type.A, kres.rcode.NXDOMAIN)
+ check_answer('"CNAME *." return NODATA',
+ 'nodata.', kres.type.A, kres.rcode.NOERROR, {})
+ check_answer('"CNAME *. on wildcard" return NODATA',
+ 'nodata.nxdomain.', kres.type.A, kres.rcode.NOERROR, {})
+ check_answer('"CNAME rpz-drop." be dropped',
+ 'rpzdrop.', kres.type.A, kres.rcode.SERVFAIL)
+ check_answer('"CNAME rpz-passthru" return A rrset',
+ 'rpzpassthru.', kres.type.A, kres.rcode.NOERROR, '127.0.0.9')
+ check_answer('"A 192.168.5.5" return local A rrset',
+ 'rra.', kres.type.A, kres.rcode.NOERROR, '192.168.5.5')
+ check_answer('"A 192.168.6.6" with suffixed zone name in owner return local A rrset',
+ 'rra-zonename-suffix.', kres.type.A, kres.rcode.NOERROR, '192.168.6.6')
+ check_answer('"A 192.168.7.7" with suffixed zone name in owner return local A rrset',
+ 'testdomain.rra.', kres.type.A, kres.rcode.NOERROR, '192.168.7.7')
+ check_answer('non existing AAAA on rra domain return NODATA',
+ 'rra.', kres.type.AAAA, kres.rcode.NOERROR, {})
+ check_answer('"A 192.168.8.8" and domain with uppercase and lowercase letters',
+ 'case.sensitive.', kres.type.A, kres.rcode.NOERROR, '192.168.8.8')
+ check_answer('"A 192.168.8.8" and domain with uppercase and lowercase letters',
+ 'CASe.SENSItivE.', kres.type.A, kres.rcode.NOERROR, '192.168.8.8')
+ check_answer('two AAAA records',
+ 'two.records.', kres.type.AAAA, kres.rcode.NOERROR,
+ {'2001:db8::2', '2001:db8::1'})
+end
+
+local function test_rpz_soa()
+ check_answer('"CNAME ." return NXDOMAIN (SOA origin)',
+ 'nxdomain-fqdn.', kres.type.A, kres.rcode.NXDOMAIN)
+ check_answer('"CNAME *." return NODATA (SOA origin)',
+ 'nodata-fqdn.', kres.type.A, kres.rcode.NOERROR, {})
+end
+
+net.ipv4 = false
+net.ipv6 = false
+
+prepare_cache()
+
+policy.add(policy.rpz(policy.DENY, 'policy.test.rpz'))
+policy.add(policy.rpz(policy.DENY, 'policy.test.rpz.soa'))
+
+return {
+ test_rpz,
+ test_rpz_soa,
+}
diff --git a/modules/policy/policy.slice.test.lua b/modules/policy/policy.slice.test.lua
new file mode 100644
index 0000000..89c1b05
--- /dev/null
+++ b/modules/policy/policy.slice.test.lua
@@ -0,0 +1,109 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+-- check lua-psl is available
+local has_psl = pcall(require, 'psl')
+if not has_psl then
+ os.exit(77) -- SKIP policy.slice
+end
+
+-- unload modules which are not related to this test
+if ta_update then
+ modules.unload('ta_update')
+end
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+if priming then
+ modules.unload('priming')
+end
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+local kres = require('kres')
+
+local slice_queries = {
+ {},
+ {},
+ {},
+}
+
+local function sliceaction(index)
+ return function(_, req)
+ -- log query
+ local qry = req:current()
+ local name = kres.dname2str(qry:name())
+ local count = slice_queries[index][name]
+ if not count then
+ count = 0
+ end
+ slice_queries[index][name] = count + 1
+
+ -- refuse query
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ answer:rcode(kres.rcode.REFUSED)
+ answer:ad(false)
+ return kres.DONE
+ end
+end
+
+-- configure slicing
+policy.add(policy.slice(
+ policy.slice_randomize_psl(0),
+ sliceaction(1),
+ sliceaction(2),
+ sliceaction(3)
+))
+
+local function check_slice(desc, qname, qtype, expected_slice, expected_count)
+ callback = function()
+ count = slice_queries[expected_slice][qname]
+ qtype_str = kres.tostring.type[qtype]
+ same(count, expected_count, desc .. qname .. ' ' .. qtype_str)
+ end
+ resolve(qname, qtype, kres.class.IN, {}, callback)
+end
+
+local function test_randomize_psl()
+ local desc = 'randomize_psl() same qname, different qtype (same slice): '
+ check_slice(desc, 'example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'example.com.', kres.type.AAAA, 2, 2)
+ check_slice(desc, 'example.com.', kres.type.MX, 2, 3)
+ check_slice(desc, 'example.com.', kres.type.NS, 2, 4)
+
+ desc = 'randomize_psl() subdomain in same slice: '
+ check_slice(desc, 'a.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'b.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'c.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'a.a.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'a.a.a.example.com.', kres.type.A, 2, 1)
+
+ desc = 'randomize_psl() different qnames in different slices: '
+ check_slice(desc, 'example2.com.', kres.type.A, 1, 1)
+ check_slice(desc, 'example5.com.', kres.type.A, 3, 1)
+
+ desc = 'randomize_psl() check unregistrable domains: '
+ check_slice(desc, '.', kres.type.A, 3, 1)
+ check_slice(desc, 'com.', kres.type.A, 1, 1)
+ check_slice(desc, 'cz.', kres.type.A, 2, 1)
+ check_slice(desc, 'co.uk.', kres.type.A, 1, 1)
+
+ desc = 'randomize_psl() check multi-level reg. domains: '
+ check_slice(desc, 'example.co.uk.', kres.type.A, 3, 1)
+ check_slice(desc, 'a.example.co.uk.', kres.type.A, 3, 1)
+ check_slice(desc, 'b.example.co.uk.', kres.type.MX, 3, 1)
+ check_slice(desc, 'example2.co.uk.', kres.type.A, 2, 1)
+
+ desc = 'randomize_psl() reg. domain - always ends up in slice: '
+ check_slice(desc, 'fdsnnsdfvkdn.com.', kres.type.A, 3, 1)
+ check_slice(desc, 'bdfbd.cz.', kres.type.A, 1, 1)
+ check_slice(desc, 'nrojgvn.net.', kres.type.A, 1, 1)
+ check_slice(desc, 'jnojtnbv.engineer.', kres.type.A, 2, 1)
+ check_slice(desc, 'dfnjonfdsjg.gov.', kres.type.A, 1, 1)
+ check_slice(desc, 'okfjnosdfgjn.mil.', kres.type.A, 1, 1)
+ check_slice(desc, 'josdhnojn.test.', kres.type.A, 2, 1)
+end
+
+return {
+ test_randomize_psl,
+}
diff --git a/modules/policy/policy.test.lua b/modules/policy/policy.test.lua
new file mode 100644
index 0000000..69dda1f
--- /dev/null
+++ b/modules/policy/policy.test.lua
@@ -0,0 +1,145 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+-- setup resolver
+-- policy module should be loaded by default, do not load it explicitly
+
+-- do not attempt to contact outside world, operate only on cache
+net.ipv4 = false
+net.ipv6 = false
+-- do not listen, test is driven by config code
+env.KRESD_NO_LISTEN = true
+
+-- test for default configuration
+local function test_tls_forward()
+ boom(policy.TLS_FORWARD, {}, 'TLS_FORWARD without arguments')
+ boom(policy.TLS_FORWARD, {'1'}, 'TLS_FORWARD with non-table argument')
+ boom(policy.TLS_FORWARD, {{}}, 'TLS_FORWARD with empty table')
+ boom(policy.TLS_FORWARD, {{{}}}, 'TLS_FORWARD with empty target table')
+ boom(policy.TLS_FORWARD, {{{bleble=''}}}, 'TLS_FORWARD with invalid parameters in table')
+
+ boom(policy.TLS_FORWARD, {{'1'}}, 'TLS_FORWARD with invalid IP address')
+ boom(policy.TLS_FORWARD, {{{'::1', bleble=''}}}, 'TLS_FORWARD with valid IP and invalid parameters')
+ boom(policy.TLS_FORWARD, {{{'127.0.0.1'}}}, 'TLS_FORWARD with missing auth parameters')
+
+ ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true}}), 'TLS_FORWARD with no authentication')
+ boom(policy.TLS_FORWARD, {{{'100:dead::', insecure=true},
+ {'100:DEAD:0::', insecure=true}
+ }}, 'TLS_FORWARD with duplicate IP addresses is not allowed')
+ ok(policy.TLS_FORWARD({{'100:dead::2', insecure=true},
+ {'100:dead::2@443', insecure=true}
+ }), 'TLS_FORWARD with duplicate IP addresses but different ports is allowed')
+ ok(policy.TLS_FORWARD({{'100:dead::3', insecure=true},
+ {'100:beef::3', insecure=true}
+ }), 'TLS_FORWARD with different IPv6 addresses is allowed')
+ ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true},
+ {'127.0.0.2', insecure=true}
+ }), 'TLS_FORWARD with different IPv4 addresses is allowed')
+
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256=''}}}, 'TLS_FORWARD with empty pin_sha256')
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='č'}}}, 'TLS_FORWARD with bad pin_sha256')
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='d161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}},
+ 'TLS_FORWARD with bad pin_sha256 (short base64)')
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='bbd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}},
+ 'TLS_FORWARD with bad pin_sha256 (long base64)')
+ ok(policy.TLS_FORWARD({
+ {'::1', pin_sha256='g1PpXsxqPchz2tH6w9kcvVXqzQ0QclhInFP2+VWOqic='}
+ }), 'TLS_FORWARD with base64 pin_sha256')
+ ok(policy.TLS_FORWARD({
+ {'::1', pin_sha256={
+ 'ev1xcdU++dY9BlcX0QoKeaUftvXQvNIz/PCss1Z/3ek=',
+ 'SgnqTFcvYduWX7+VUnlNFT1gwSNvQdZakH7blChIRbM=',
+ 'bd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg=',
+ }}}), 'TLS_FORWARD with a table of pins')
+
+ -- ok(policy.TLS_FORWARD({{'::1', hostname='test.', ca_file='/tmp/ca.crt'}}), 'TLS_FORWARD with hostname + CA cert')
+ ok(policy.TLS_FORWARD({{'::1', hostname='test.'}}),
+ 'TLS_FORWARD with just hostname (use system CA store)')
+ boom(policy.TLS_FORWARD, {{{'::1', ca_file='/tmp/ca.crt'}}},
+ 'TLS_FORWARD with just CA cert')
+ boom(policy.TLS_FORWARD, {{{'::1', hostname='', ca_file='/tmp/ca.crt'}}},
+ 'TLS_FORWARD with empty hostname + CA cert')
+ boom(policy.TLS_FORWARD, {
+ {{'::1', hostname='test.', ca_file='/dev/a_file_which_surely_does_NOT_exist!'}}
+ }, 'TLS_FORWARD with hostname + unreadable CA cert')
+
+end
+
+local function test_slice()
+ boom(policy.slice, {function() end}, 'policy.slice() without any action')
+ ok(policy.slice, {function() end, policy.FORWARD, policy.FORWARD})
+end
+
+local function mirror_parser(srv, cv, nqueries)
+ local ffi = require('ffi')
+ local test_end = 0
+ local TIMEOUT = 5 -- seconds
+
+ while true do
+ local input = srv:xread('*a', 'bn', TIMEOUT)
+ if not input then
+ cv:signal()
+ return false, 'mirror: timeout'
+ end
+ --print(#input, input)
+ -- convert query to knot_pkt_t
+ local wire = ffi.cast("void *", input)
+ local pkt = ffi.gc(ffi.C.knot_pkt_new(wire, #input, nil), ffi.C.knot_pkt_free)
+ if not pkt then
+ cv:signal()
+ return false, 'mirror: packet allocation error'
+ end
+
+ local result = ffi.C.knot_pkt_parse(pkt, 0)
+ if result ~= 0 then
+ cv:signal()
+ return false, 'mirror: packet parse error'
+ end
+ --print(pkt)
+ test_end = test_end + 1
+
+ if test_end == nqueries then
+ cv:signal()
+ return true, 'packet mirror pass'
+ end
+
+ end
+end
+
+local function test_mirror()
+ local kluautil = require('kluautil')
+ local socket = require('cqueues.socket')
+ local cond = require('cqueues.condition')
+ local cv = cond.new()
+ local queries = {}
+ local srv = socket.listen({
+ host = "127.0.0.1",
+ port = 36659,
+ type = socket.SOCK_DGRAM,
+ })
+ -- binary mode, no buffering
+ srv:setmode('bn', 'bn')
+
+ queries["bla.mujtest.cz."] = kres.type.AAAA
+ queries["bla.mujtest2.cz."] = kres.type.AAAA
+
+ -- UDP server for test
+ worker.bg_worker.cq:wrap(function()
+ local err, msg = mirror_parser(srv, cv, kluautil.kr_table_len(queries))
+
+ ok(err, msg)
+ end)
+
+ policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest.cz.'})))
+ policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest2.cz.'})))
+
+ for name, rtype in pairs(queries) do
+ resolve(name, rtype)
+ end
+
+ cv:wait()
+end
+
+return {
+ test_tls_forward,
+ test_mirror,
+ test_slice,
+}
diff --git a/modules/policy/policy.test.rpz b/modules/policy/policy.test.rpz
new file mode 100644
index 0000000..d962e9f
--- /dev/null
+++ b/modules/policy/policy.test.rpz
@@ -0,0 +1,18 @@
+$ORIGIN testdomain.
+$TTL 30
+testdomain. SOA nonexistent.testdomain. testdomain. 1 12h 15m 3w 2h
+ NS nonexistent.testdomain.
+
+nxdomain CNAME .
+nodata CNAME *.
+*.nxdomain CNAME *.
+rpzdrop CNAME rpz-drop.
+rpzpassthru CNAME rpz-passthru.
+rra A 192.168.5.5
+rra-zonename-suffix A 192.168.6.6
+testdomain.rra.testdomain. A 192.168.7.7
+CaSe.SeNSiTiVe A 192.168.8.8
+
+two.records AAAA 2001:db8::2
+two.records AAAA 2001:db8::1
+
diff --git a/modules/policy/policy.test.rpz.soa b/modules/policy/policy.test.rpz.soa
new file mode 100644
index 0000000..ad18aa4
--- /dev/null
+++ b/modules/policy/policy.test.rpz.soa
@@ -0,0 +1,5 @@
+test2domain. SOA nonexistent.test2domain. test2domain. 1 12h 15m 3w 2h
+ NS nonexistent.test2domain.
+
+nxdomain-fqdn.test2domain. CNAME .
+nodata-fqdn.test2domain. CNAME *.
diff --git a/modules/policy/test.integr/deckard.yaml b/modules/policy/test.integr/deckard.yaml
new file mode 100644
index 0000000..9c6cb70
--- /dev/null
+++ b/modules/policy/test.integr/deckard.yaml
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+programs:
+- name: kresd
+ binary: kresd
+ additional:
+ - --noninteractive
+ templates:
+ - modules/policy/test.integr/kresd_config.j2
+ - tests/integration/hints_zone.j2
+ configs:
+ - config
+ - hints
diff --git a/modules/policy/test.integr/kresd_config.j2 b/modules/policy/test.integr/kresd_config.j2
new file mode 100644
index 0000000..668c792
--- /dev/null
+++ b/modules/policy/test.integr/kresd_config.j2
@@ -0,0 +1,59 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+{% raw %}
+policy.add(policy.domains(policy.DENY, {todname('example.com')}))
+policy.add(policy.suffix(policy.REFUSE, {todname('refuse.example.com')}))
+
+-- make sure DNSSEC is turned off for tests
+trust_anchors.remove('.')
+
+-- Disable RFC5011 TA update
+if ta_update then
+ modules.unload('ta_update')
+end
+
+-- Disable RFC8145 signaling, scenario doesn't provide expected answers
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+
+-- Disable RFC8109 priming, scenario doesn't provide expected answers
+if priming then
+ modules.unload('priming')
+end
+
+-- Disable this module because it make one priming query
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+_hint_root_file('hints')
+cache.size = 2*MB
+log_level('debug')
+{% endraw %}
+
+net = { '{{SELF_ADDR}}' }
+
+
+{% if QMIN == "false" %}
+option('NO_MINIMIZE', true)
+{% else %}
+option('NO_MINIMIZE', false)
+{% endif %}
+
+
+-- Self-checks on globals
+assert(help() ~= nil)
+assert(worker.id ~= nil)
+-- Self-checks on facilities
+assert(cache.count() == 0)
+assert(cache.stats() ~= nil)
+assert(cache.backends() ~= nil)
+assert(worker.stats() ~= nil)
+assert(net.interfaces() ~= nil)
+-- Self-checks on loaded stuff
+assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
+assert(#modules.list() > 0)
+-- Self-check timers
+ev = event.recurrent(1 * sec, function (ev) return 1 end)
+event.cancel(ev)
+ev = event.after(0, function (ev) return 1 end)
diff --git a/modules/policy/test.integr/refuse.rpl b/modules/policy/test.integr/refuse.rpl
new file mode 100644
index 0000000..08f9942
--- /dev/null
+++ b/modules/policy/test.integr/refuse.rpl
@@ -0,0 +1,44 @@
+; SPDX-License-Identifier: GPL-3.0-or-later
+; config options
+ stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET.
+CONFIG_END
+
+SCENARIO_BEGIN Test refuse policy
+
+STEP 10 QUERY
+ENTRY_BEGIN
+REPLY RD AD
+SECTION QUESTION
+www.refuse.example.com. IN A
+ENTRY_END
+
+STEP 20 CHECK_ANSWER
+ENTRY_BEGIN
+MATCH all answer
+; AD must not be set in the answer
+REPLY QR RD RA REFUSED
+SECTION QUESTION
+www.refuse.example.com. IN A
+SECTION ANSWER
+ENTRY_END
+
+STEP 30 QUERY
+ENTRY_BEGIN
+REPLY RD AD
+SECTION QUESTION
+example.com. IN A
+ENTRY_END
+
+STEP 40 CHECK_ANSWER
+ENTRY_BEGIN
+MATCH all answer
+REPLY QR RD AA RA NXDOMAIN
+SECTION QUESTION
+example.com. IN A
+SECTION ANSWER
+SECTION AUTHORITY
+example.com. 10800 IN SOA example.com. nobody.invalid. 1 3600 1200 604800 10800
+ENTRY_END
+
+
+SCENARIO_END