summaryrefslogtreecommitdiffstats
path: root/modules/policy
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 10:41:58 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 10:41:58 +0000
commit1852910ef0fd7393da62b88aee66ee092208748e (patch)
treead3b659dbbe622b58a5bda4fe0b5e1d80eee9277 /modules/policy
parentInitial commit. (diff)
downloadknot-resolver-1852910ef0fd7393da62b88aee66ee092208748e.tar.xz
knot-resolver-1852910ef0fd7393da62b88aee66ee092208748e.zip
Adding upstream version 5.3.1.upstream/5.3.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'modules/policy')
-rw-r--r--modules/policy/.packaging/test.config4
-rw-r--r--modules/policy/README.rst665
-rw-r--r--modules/policy/lua-aho-corasick/LICENSE28
-rw-r--r--modules/policy/lua-aho-corasick/Makefile134
-rw-r--r--modules/policy/lua-aho-corasick/README.md40
-rw-r--r--modules/policy/lua-aho-corasick/ac.cxx101
-rw-r--r--modules/policy/lua-aho-corasick/ac.h49
-rw-r--r--modules/policy/lua-aho-corasick/ac_fast.cxx468
-rw-r--r--modules/policy/lua-aho-corasick/ac_fast.hpp124
-rw-r--r--modules/policy/lua-aho-corasick/ac_lua.cxx173
-rw-r--r--modules/policy/lua-aho-corasick/ac_slow.cxx318
-rw-r--r--modules/policy/lua-aho-corasick/ac_slow.hpp158
-rw-r--r--modules/policy/lua-aho-corasick/ac_util.hpp69
-rw-r--r--modules/policy/lua-aho-corasick/load_ac.lua90
-rw-r--r--modules/policy/lua-aho-corasick/mytest.cxx200
-rw-r--r--modules/policy/lua-aho-corasick/tests/Makefile65
-rw-r--r--modules/policy/lua-aho-corasick/tests/ac_bench.cxx519
-rw-r--r--modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx135
-rw-r--r--modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx275
-rw-r--r--modules/policy/lua-aho-corasick/tests/dict/README.txt1
-rw-r--r--modules/policy/lua-aho-corasick/tests/dict/dict1.txt11
-rw-r--r--modules/policy/lua-aho-corasick/tests/load_ac_test.lua82
-rw-r--r--modules/policy/lua-aho-corasick/tests/lua_test.lua67
-rw-r--r--modules/policy/lua-aho-corasick/tests/test_base.hpp60
-rw-r--r--modules/policy/lua-aho-corasick/tests/test_bigfile.cxx167
-rw-r--r--modules/policy/lua-aho-corasick/tests/test_main.cxx33
-rw-r--r--modules/policy/meson.build49
-rw-r--r--modules/policy/noipv6.test.integr/broken-ipv6.rpl47
-rw-r--r--modules/policy/noipv6.test.integr/deckard.yaml12
-rw-r--r--modules/policy/noipv6.test.integr/kresd_config.j259
-rw-r--r--modules/policy/noipvx.test.integr/broken-ipvx.rpl35
-rw-r--r--modules/policy/noipvx.test.integr/deckard.yaml12
-rw-r--r--modules/policy/noipvx.test.integr/kresd_config.j260
-rw-r--r--modules/policy/policy.lua998
-rw-r--r--modules/policy/policy.rpz.test.lua56
-rw-r--r--modules/policy/policy.slice.test.lua109
-rw-r--r--modules/policy/policy.test.lua145
-rw-r--r--modules/policy/policy.test.rpz18
-rw-r--r--modules/policy/test.integr/deckard.yaml12
-rw-r--r--modules/policy/test.integr/kresd_config.j258
-rw-r--r--modules/policy/test.integr/refuse.rpl25
41 files changed, 5731 insertions, 0 deletions
diff --git a/modules/policy/.packaging/test.config b/modules/policy/.packaging/test.config
new file mode 100644
index 0000000..60c9ddc
--- /dev/null
+++ b/modules/policy/.packaging/test.config
@@ -0,0 +1,4 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+modules.load('policy')
+assert(policy)
+quit()
diff --git a/modules/policy/README.rst b/modules/policy/README.rst
new file mode 100644
index 0000000..83ca772
--- /dev/null
+++ b/modules/policy/README.rst
@@ -0,0 +1,665 @@
+.. SPDX-License-Identifier: GPL-3.0-or-later
+
+.. default-domain:: py
+.. module:: policy
+
+.. _mod-policy:
+
+
+Query policies
+==============
+
+This module can block, rewrite, or alter inbound queries based on user-defined policies. It does not affect queries generated by the resolver itself, e.g. when following CNAME chains etc.
+
+Each policy *rule* has two parts: a *filter* and an *action*. A *filter* selects which queries will be affected by the policy, and *action* which modifies queries matching the associated filter.
+
+Typically a rule is defined as follows: ``filter(action(action parameters), filter parameters)``. For example, a filter can be ``suffix`` which matches queries whose suffix part is in specified set, and one of possible actions is :any:`policy.DENY`, which denies resolution. These are combined together into ``policy.suffix(policy.DENY, {todname('badguy.example.')})``. The rule is effective when it is added into rule table using ``policy.add()``, please see examples below.
+
+This module is enabled by default because it implements mandatory :rfc:`6761` logic.
+When no rule applies to a query, built-in rules for `special-use <https://www.iana.org/assignments/special-use-domain-names/special-use-domain-names.xhtml>`_ and `locally-served <http://www.iana.org/assignments/locally-served-dns-zones>`_ domain names are applied.
+These rules can be overriden by action :any:`policy.PASS`. For debugging purposes you can also add ``modules.unload('policy')`` to your config to unload the module.
+
+
+Filters
+-------
+A *filter* selects which queries will be affected by specified Actions_. There are several policy filters available in the ``policy.`` table:
+
+.. function:: all(action)
+
+ Always applies the action.
+
+.. function:: pattern(action, pattern)
+
+ Applies the action if query name matches a `Lua regular expression <http://lua-users.org/wiki/PatternsTutorial>`_.
+
+.. function:: suffix(action, suffix_table)
+
+ Applies the action if query name suffix matches one of suffixes in the table (useful for "is domain in zone" rules).
+
+.. note:: For speed this filter requires domain names in DNS wire format, not textual representation, so each label in the name must be prefixed with its length. Always use convenience function :func:`policy.todnames` for automatic conversion from strings! For example:
+
+ .. code-block:: lua
+
+ policy.suffix(policy.DENY, policy.todnames({'example.com', 'example.net'}))
+
+.. function:: suffix_common(action, suffix_table[, common_suffix])
+
+ :param action: action if the pattern matches query name
+ :param suffix_table: table of valid suffixes
+ :param common_suffix: common suffix of entries in suffix_table
+
+ Like :func:`policy.suffix` match, but you can also provide a common suffix of all matches for faster processing (nil otherwise).
+ This function is faster for small suffix tables (in the order of "hundreds").
+
+.. :noindex: function:: rpz(default_action, path, [watch])
+
+ Implements a subset of `Response Policy Zone` (RPZ_) stored in zonefile format. See below for details: :func:`policy.rpz`.
+
+It is also possible to define custom filter function with any name.
+
+.. function:: custom_filter(state, query)
+
+ :param state: Request processing state :c:type:`kr_layer_state`, typically not used by filter function.
+ :param query: Incoming DNS query as :c:type:`kr_query` structure.
+ :return: An `action <#actions>`_ function or ``nil`` if filter did not match.
+
+ Typically filter function is generated by another function, which allows easy parametrization - this technique is called `closure <https://www.lua.org/pil/6.1.html>`_. An practical example of such filter generator is:
+
+.. code-block:: lua
+
+ function match_query_type(action, target_qtype)
+ return function (state, query)
+ if query.stype == target_qtype then
+ -- filter matched the query, return action function
+ return action
+ else
+ -- filter did not match, continue with next filter
+ return nil
+ end
+ end
+ end
+
+This custom filter can be used as any other built-in filter.
+For example this applies our custom filter and executes action :any:`policy.DENY` on all queries of type `HINFO`:
+
+.. code-block:: lua
+
+ -- custom filter which matches HINFO queries, action is policy.DENY
+ policy.add(match_query_type(policy.DENY, kres.type.HINFO))
+
+
+.. _mod-policy-actions:
+
+Actions
+-------
+An *action* is a function which modifies DNS request, and is either of type *chain* or *non-chain*:
+
+ * `Non-chain actions`_ modify state of the request and stop rule processing. An example of such action is :ref:`forwarding`.
+ * `Chain actions`_ modify state of the request and allow other rules to evaluate and act on the same request. One such example is :func:`policy.MIRROR`.
+
+Non-chain actions
+^^^^^^^^^^^^^^^^^
+
+Following actions stop the policy matching on the query, i.e. other rules are not evaluated once rule with following actions matches:
+
+.. py:attribute:: PASS
+
+ Let the query pass through; it's useful to make exceptions before wider rules. For example:
+
+ More specific whitelist rule must preceede generic blacklist rule:
+
+ .. code-block:: lua
+
+ -- Whitelist 'good.example.com'
+ policy.add(policy.pattern(policy.PASS, todname('good.example.com.')))
+ -- Block all names below example.com
+ policy.add(policy.suffix(policy.DENY, {todname('example.com.')}))
+
+.. py:attribute:: DENY
+
+ Deny existence of names matching filter, i.e. reply NXDOMAIN authoritatively.
+
+.. function:: DENY_MSG(message)
+
+ Deny existence of a given domain and add explanatory message. NXDOMAIN reply contains an additional explanatory message as TXT record in the additional section.
+
+.. py:attribute:: DROP
+
+ Terminate query resolution and return SERVFAIL to the requestor.
+
+.. py:attribute:: REFUSE
+
+ Terminate query resolution and return REFUSED to the requestor.
+
+.. py:attribute:: TC
+
+ Force requestor to use TCP. It sets truncated bit (*TC*) in response to true if the request came through UDP, which will force standard-compliant clients to retry the request over TCP.
+
+.. function:: REROUTE({{subnet,target}, ...})
+
+ Reroute IP addresses in response matching given subnet to given target, e.g. ``{'192.0.2.0/24', '127.0.0.0'}`` will rewrite '192.0.2.55' to '127.0.0.55', see :ref:`renumber module <mod-renumber>` for more information. See :func:`policy.add` and do not forget to specify that this is *postrule*. Quick example:
+
+ .. code-block:: lua
+
+ -- this policy is enforced on answers
+ -- therefore we have to use 'postrule'
+ -- (the "true" at the end of policy.add)
+ policy.add(policy.REROUTE({'192.0.2.0/24', '127.0.0.0'}), true)
+
+.. function:: ANSWER({ type = { rdata=data, [ttl=1] } }, [nodata=false])
+
+ Overwrite Resource Records in responses with specified values.
+
+ * type
+ - RR type to be replaced, e.g. ``[kres.type.A]`` or `numeric value <https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4>`_.
+ * rdata
+ - RR data in DNS wire format, i.e. binary form specific for given RR type. Set of multiple RRs can be specified as table ``{ rdata1, rdata2, ... }``. Use helper function :func:`kres.str2ip` to generate wire format for A and AAAA records.
+ * ttl
+ - TTL in seconds. Default: 1 second.
+ * nodata
+ - If type requested by client is not configured in this policy:
+
+ - ``true``: Return empty answer (`NODATA`).
+ - ``false``: Ignore this policy and continue processing other rules.
+
+ Default: ``false``.
+
+ .. code-block:: lua
+
+ -- policy to change IPv4 address and TTL for example.com
+ policy.add(
+ policy.suffix(
+ policy.ANSWER(
+ { [kres.type.A] = { rdata=kres.str2ip('192.0.2.7'), ttl=300 } }
+ ), { todname('example.com') }))
+ -- policy to generate two TXT records (specified in binary format) for example.net
+ policy.add(
+ policy.suffix(
+ policy.ANSWER(
+ { [kres.type.TXT] = { rdata={'\005first', '\006second'}, ttl=5 } }
+ ), { todname('example.net') }))
+
+More complex non-chain actions are described in their own chapters, namely:
+
+ * :ref:`forwarding`
+ * `Response Policy Zones`_
+
+Chain actions
+^^^^^^^^^^^^^
+
+Following actions act on request and then processing continue until first non-chain action (specified in the previous section) is triggered:
+
+.. function:: MIRROR(ip_address)
+
+ Send copy of incoming DNS queries to a given IP address using DNS-over-UDP and continue resolving them as usual. This is useful for sanity testing new versions of DNS resolvers.
+
+ .. code-block:: lua
+
+ policy.add(policy.all(policy.MIRROR('127.0.0.2')))
+
+.. function:: FLAGS(set, clear)
+
+ Set and/or clear some flags for the query. There can be multiple flags to set/clear. You can just pass a single flag name (string) or a set of names. Flag names correspond to :c:type:`kr_qflags` structure. Use only if you know what you are doing.
+
+.. py:attribute:: QTRACE
+
+ Pretty-print DNS response packets from authoritative servers into the verbose log for the query and its sub-queries. It's useful for debugging weird DNS servers. Verbose logging must be enabled using :func:`verbose` for this policy to be effective.
+
+ .. code-block:: lua
+
+ -- log answers from all authoritative servers involved in resolving
+ -- requests for example.net. and its subdomains
+ policy.add(policy.suffix(policy.QTRACE, policy.todnames({'example.net'})))
+
+.. py:attribute:: REQTRACE
+
+ Pretty-print DNS requests from clients into the verbose log. It's useful for debugging weird DNS clients. Verbose logging must be enabled using :func:`verbose` for this policy to be effective. It makes most sense together with :ref:`mod-view`.
+
+.. py:attribute:: DEBUG_ALWAYS
+
+ Enable extra verbose logging for all requests, including cache hits. See caveats for :func:`policy.DEBUG_IF`.
+
+.. py:data:: DEBUG_CACHE_MISS
+
+ Enable extra verbose logging but print logs only for requests which required information which was not available locally (i.e. requests which forced resolver to communicate over network). Intended usage is for debugging problems with remote servers. This action typically produces less logs than :any:`policy.DEBUG_ALWAYS` but all caveats from :func:`policy.DEBUG_IF` apply as well.
+
+ .. code-block:: lua
+
+ policy.add(policy.suffix(
+ policy.DEBUG_CACHE_MISS,
+ policy.todnames({'example.com.'})))
+
+.. py:function:: DEBUG_IF(test_function)
+
+ :param test_function: Function with single argument of type :c:type:`kr_request` which returns ``true`` if verbose logs for a given request should be printed and ``false`` otherwise.
+
+ Enable extra verbose logging but print logs only for requests which match condition specified by ``test_function``. This allows to fine-tune which requests should be printed.
+
+ .. warning:: Verbose logging has significant performance impact on resolver and might also overload you logging system because one request can easily generate tens of kilobytes of logs. Always use appropriate `Filters`_ to limit number of requests triggering this action to a minimum!
+
+ .. note:: ``test_function`` is evaluated only when request is finished. As a result verbose logs for all requests must be collected until request is finished because it is not possible to know beforehand how ``test_function`` at the end evaluates given request. When a request is finalized logs are either printed or thrown away.
+
+ Example usage which gathers verbose logs for all requests in subtree ``dnssec-failed.org.`` and prints verbose logs for all requests finished with states different than ``kres.DONE`` (most importantly ``kres.FAIL``, see :c:type:`kr_layer_state`).
+
+ .. code-block:: lua
+
+ policy.add(policy.suffix(
+ policy.DEBUG_IF(function(req)
+ return (req.state ~= kres.DONE)
+ end),
+ policy.todnames({'dnssec-failed.org.'})))
+
+
+Custom actions
+^^^^^^^^^^^^^^
+
+.. function:: custom_action(state, request)
+
+ :param state: Request processing state :c:type:`kr_layer_state`.
+ :param request: Current DNS request as :c:type:`kr_request` structure.
+ :return: Returning a new :c:type:`kr_layer_state` prevents evaluating other policy rules. Returning ``nil`` creates a `chain action <#actions>`_ and allows to continue evaluating other rules.
+
+ This is real example of an action function:
+
+.. code-block:: lua
+
+ -- Custom action which generates fake A record
+ local ffi = require('ffi')
+ local function fake_A_record(state, req)
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ local qry = req:current()
+ if qry.stype ~= kres.type.A then
+ return state
+ end
+ ffi.C.kr_pkt_make_auth_header(answer)
+ answer:rcode(kres.rcode.NOERROR)
+ answer:begin(kres.section.ANSWER)
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\192\168\1\3')
+ return kres.DONE
+ end
+
+This custom action can be used as any other built-in action.
+For example this applies our *fake A record action* and executes it on all queries in subtree ``example.net``:
+
+.. code-block:: lua
+
+ policy.add(policy.suffix(fake_A_record, policy.todnames({'example.net'})))
+
+The action function can implement arbitrary logic so it is possible to implement complex heuristics, e.g. to deflect `Slow drip DNS attacks <https://secure64.com/water-torture-slow-drip-dns-ddos-attack>`_ or gray-list resolution of misbehaving zones.
+
+.. warning:: The policy module currently only looks at whole DNS requests. The rules won't be re-applied e.g. when following CNAMEs.
+
+.. _forwarding:
+
+Forwarding
+----------
+
+Forwarding action alters behavior for cache-miss events. If an information is missing in the local cache the resolver will *forward* the query to *another DNS resolver* for resolution (instead of contacting authoritative servers directly). DNS answers from the remote resolver are then processed locally and sent back to the original client.
+
+Actions :func:`policy.FORWARD`, :func:`policy.TLS_FORWARD` and :func:`policy.STUB` accept up to four IP addresses at once and the resolver will automatically select IP address which statistically responds the fastest.
+
+.. function:: FORWARD(ip_address)
+ FORWARD({ ip_address, [ip_address, ...] })
+
+ Forward cache-miss queries to specified IP addresses (without encryption), DNSSEC validate received answers and cache them. Target IP addresses are expected to be DNS resolvers.
+
+ .. code-block:: lua
+
+ -- Forward all queries to public resolvers https://www.nic.cz/odvr
+ policy.add(policy.all(
+ policy.FORWARD(
+ {'2001:148f:fffe::1', '2001:148f:ffff::1',
+ '185.43.135.1', '193.14.47.1'})))
+
+ A variant which uses encrypted DNS-over-TLS transport is called :func:`policy.TLS_FORWARD`, please see section :ref:`tls-forwarding`.
+
+.. function:: STUB(ip_address)
+ STUB({ ip_address, [ip_address, ...] })
+
+ Similar to :func:`policy.FORWARD` but *without* attempting DNSSEC validation.
+ Each request may be either answered from cache or simply sent to one of the IPs with proxying back the answer.
+
+ This mode does not support encryption and should be used only for `Replacing part of the DNS tree`_.
+ Use :func:`policy.FORWARD` mode if possible.
+
+ .. code-block:: lua
+
+ -- Answers for reverse queries about the 192.168.1.0/24 subnet
+ -- are to be obtained from IP address 192.0.2.1 port 5353
+ -- This disables DNSSEC validation!
+ policy.add(policy.suffix(
+ policy.STUB('192.0.2.1@5353'),
+ {todname('1.168.192.in-addr.arpa')}))
+
+.. note:: Forwarding targets must support
+ `EDNS <https://en.wikipedia.org/wiki/Extension_mechanisms_for_DNS>`_ and
+ `0x20 randomization <https://tools.ietf.org/html/draft-vixie-dnsext-dns0x20-00>`_.
+
+
+.. _tls-forwarding:
+
+Forwarding over TLS protocol (DNS-over-TLS)
+-------------------------------------------
+.. function:: TLS_FORWARD( { {ip_address, authentication}, [...] } )
+
+ Same as :func:`policy.FORWARD` but send query over DNS-over-TLS protocol (encrypted).
+ Each target IP address needs explicit configuration how to validate
+ TLS certificate so each IP address is configured by pair:
+ ``{ip_address, authentication}``. See sections below for more details.
+
+
+Policy :func:`policy.TLS_FORWARD` allows you to forward queries using `Transport Layer Security`_ protocol, which hides the content of your queries from an attacker observing the network traffic. Further details about this protocol can be found in :rfc:`7858` and `IETF draft dprive-dtls-and-tls-profiles`_.
+
+Queries affected by :func:`policy.TLS_FORWARD` will always be resolved over TLS connection. Knot Resolver does not implement fallback to non-TLS connection, so if TLS connection cannot be established or authenticated according to the configuration, the resolution will fail.
+
+To test this feature you need to either :ref:`configure Knot Resolver as DNS-over-TLS server <tls-server-config>`, or pick some public DNS-over-TLS server. Please see `DNS Privacy Project`_ homepage for list of public servers.
+
+.. note:: Some public DNS-over-TLS providers may apply rate-limiting which
+ makes their service incompatible with Knot Resolver's TLS forwarding.
+ Notably, `Google Public DNS
+ <https://developers.google.com/speed/public-dns/docs/dns-over-tls>`_ doesn't
+ work as of 2019-07-10.
+
+When multiple servers are specified, the one with the lowest round-trip time is used.
+
+CA+hostname authentication
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+Traditional PKI authentication requires server to present certificate with specified hostname, which is issued by one of trusted CAs. Example policy is:
+
+.. code-block:: lua
+
+ policy.TLS_FORWARD({
+ {'2001:DB8::d0c', hostname='res.example.com'}})
+
+- ``hostname`` must be a valid domain name matching server's certificate. It will also be sent to the server as SNI_.
+- ``ca_file`` optionally contains a path to a CA certificate (or certificate bundle) in `PEM format`_.
+ If you omit that, the system CA certificate store will be used instead (usually sufficient).
+ A list of paths is also accepted, but all of them must be valid PEMs.
+
+Key-pinned authentication
+^^^^^^^^^^^^^^^^^^^^^^^^^
+Instead of CAs, you can specify hashes of accepted certificates in ``pin_sha256``.
+They are in the usual format -- base64 from sha256.
+You may still specify ``hostname`` if you want SNI_ to be sent.
+
+.. _tls-examples:
+
+TLS Examples
+^^^^^^^^^^^^
+
+.. code-block:: lua
+
+ modules = { 'policy' }
+ -- forward all queries over TLS to the specified server
+ policy.add(policy.all(policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}})))
+ -- for brevity, other TLS examples omit policy.add(policy.all())
+ -- single server authenticated using its certificate pin_sha256
+ policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}}) -- pin_sha256 is base64-encoded
+ -- single server authenticated using hostname and system-wide CA certificates
+ policy.TLS_FORWARD({{'192.0.2.1', hostname='res.example.com'}})
+ -- single server using non-standard port
+ policy.TLS_FORWARD({{'192.0.2.1@443', pin_sha256='YQ=='}}) -- use @ or # to specify port
+ -- single server with multiple valid pins (e.g. anycast)
+ policy.TLS_FORWARD({{'192.0.2.1', pin_sha256={'YQ==', 'Wg=='}})
+ -- multiple servers, each with own authenticator
+ policy.TLS_FORWARD({ -- please note that { here starts list of servers
+ {'192.0.2.1', pin_sha256='Wg=='},
+ -- server must present certificate issued by specified CA and hostname must match
+ {'2001:DB8::d0c', hostname='res.example.com', ca_file='/etc/knot-resolver/tlsca.crt'}
+ })
+
+Forwarding to multiple targets
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+With the use of :func:`policy.slice` function, it is possible to split the
+entire DNS namespace into distinct slices. When used in conjuction with
+:func:`policy.TLS_FORWARD`, it's possible to forward different queries to
+different targets.
+
+.. function:: slice(slice_func, action[, action[, ...])
+
+ :param slice_func: slicing function that returns index based on query
+ :param action: action to be performed for the slice
+
+ This function splits the entire domain space into multiple slices (determined
+ by the number of provided ``actions``). A ``slice_func`` is called to determine
+ which slice a query belongs to. The corresponding ``action`` is then executed.
+
+
+.. function:: slice_randomize_psl(seed = os.time() / (3600 * 24 * 7))
+
+ :param seed: seed for random assignment
+
+ The function initializes and returns a slicing function, which
+ deterministically assigns ``query`` to a slice based on the query name.
+
+ It utilizes the `Public Suffix List`_ to ensure domains under the same
+ registrable domain end up in a single slice. (see example below)
+
+ ``seed`` can be used to re-shuffle the slicing algorhitm when the slicing
+ function is initialized. By default, the assigment is re-shuffled after one
+ week (when resolver restart / reloads config). To force a stable
+ distribution, pass a fixed value. To re-shuffle on every resolver restart,
+ use ``os.time()``.
+
+ The following example demonstrates a distribution among 3 slices::
+
+ slice 1/3:
+ example.com
+ a.example.com
+ b.example.com
+ x.b.example.com
+ example3.com
+
+ slice 2/3:
+ example2.co.uk
+
+ slice 3/3:
+ example.co.uk
+ a.example.co.uk
+
+These two functions can be used together to forward queries for names
+in different parts of DNS name space to different target servers:
+
+.. code-block:: lua
+
+ policy.add(policy.slice(
+ policy.slice_randomize_psl(),
+ policy.TLS_FORWARD({{'192.0.2.1', hostname='res.example.com'}}),
+ policy.TLS_FORWARD({
+ -- multiple servers can be specified for a single slice
+ -- the one with lowest round-trip time will be used
+ {'193.17.47.1', hostname='odvr.nic.cz'},
+ {'185.43.135.1', hostname='odvr.nic.cz'},
+ })
+ ))
+
+.. note:: The privacy implications of using this feature aren't clear. Since
+ websites often make requests to multiple domains, these might be forwarded
+ to different targets. This could result in decreased privacy (e.g. when the
+ remote targets are both logging or otherwise processing your DNS traffic).
+ The intended use-case is to use this feature with semi-trusted resolvers
+ which claim to do no logging (such as those listed on `dnsprivacy.org
+ <https://dnsprivacy.org/wiki/display/DP/DNS+Privacy+Test+Servers>`_), to
+ decrease the potential exposure of your DNS data to a malicious resolver
+ operator.
+
+.. _dns-graft:
+
+Replacing part of the DNS tree
+------------------------------
+
+Following procedure applies only to domains which have different content
+publicly and internally. For example this applies to "your own" top-level domain
+``example.`` which does not exist in the public (global) DNS namespace.
+
+Dealing with these internal-only domains requires extra configuration because
+DNS was designed as "single namespace" and local modifications like adding
+your own TLD break this assumption.
+
+.. warning:: Use of internal names which are not delegated from the public DNS
+ *is causing technical problems* with caching and DNSSEC validation
+ and generally makes DNS operation more costly.
+ We recommend **against** using these non-delegated names.
+
+To make such internal domain available in your resolver it is necessary to
+*graft* your domain onto the public DNS namespace,
+but *grafting* creates new issues:
+
+These *grafted* domains will be rejected by DNSSEC validation
+because such domains are technically indistinguishable from an spoofing attack
+against the public DNS.
+Therefore, if you trust the remote resolver which hosts the internal-only domain,
+and you trust your link to it, you need to use the :func:`policy.STUB` policy
+instead of :func:`policy.FORWARD` to disable DNSSEC validation for those
+*grafted* domains.
+
+Secondly, after disabling DNSSEC validation you have to solve another issue
+caused by grafting. For example, if you grafted your own top-level domain
+``example.`` onto the public DNS namespace, at some point the root server might
+send proof-of-nonexistence proving e.g. that there are no other top-level
+domain in between names ``events.`` and ``exchange.``, effectivelly proving
+non-existence of ``example.``.
+
+These proofs-of-nonexistence protect public DNS from spoofing but break
+*grafted* domains because proofs will be latter used by resolver
+(when the positive records for the grafted domain timeout from cache),
+effectivelly making grafted domain unavailable.
+The easiest work-around is to disable reading from cache for grafted domains.
+
+.. code-block:: lua
+ :caption: Example configuration grafting domains onto public DNS namespace
+
+ extraTrees = policy.todnames(
+ {'faketldtest.',
+ 'sld.example.',
+ 'internal.example.com.',
+ '2.0.192.in-addr.arpa.' -- this applies to reverse DNS tree as well
+ })
+ -- Beware: the rule order is important, as policy.STUB is not a chain action.
+ policy.add(policy.suffix(policy.FLAGS({'NO_CACHE'}), extraTrees))
+ policy.add(policy.suffix(policy.STUB({'2001:db8::1'}), extraTrees))
+
+Response policy zones
+---------------------
+ .. warning::
+
+ There is no published Internet Standard for RPZ_ and implementations vary.
+ At the moment Knot Resolver supports limited subset of RPZ format and deviates
+ from implementation in BIND. Nevertheless it is good enough
+ for blocking large lists of spam or advertising domains.
+
+
+
+ The RPZ file format is basically a DNS zone file with *very special* semantics.
+ For example:
+
+ .. code-block:: none
+
+ ; left hand side ; TTL and class ; right hand side
+ ; encodes RPZ trigger ; ignored ; encodes action
+ ; (i.e. filter)
+ blocked.domain.example 600 IN CNAME . ; block main domain
+ *.blocked.domain.example 600 IN CNAME . ; block subdomains
+
+ The only "trigger" supported in Knot Resolver is query name,
+ i.e. left hand side must be a domain name which triggers the action specified
+ on the right hand side.
+
+ Subset of possible RPZ actions is supported, namely:
+
+ .. csv-table::
+ :header: "RPZ Right Hand Side", "Knot Resolver Action", "BIND Compatibility"
+
+ "``.``", "``action`` is used", "compatible if ``action`` is :any:`policy.DENY`"
+ "``*.``", ":func:`policy.ANSWER`", "yes"
+ "``rpz-passthru.``", ":any:`policy.PASS`", "yes"
+ "``rpz-tcp-only.``", ":any:`policy.TC`", "yes"
+ "``rpz-drop.``", ":any:`policy.DROP`", "no [#]_"
+ "fake A/AAAA", ":func:`policy.ANSWER`", "yes"
+ "fake CNAME", "not supported", "no"
+
+ .. [#] Our :any:`policy.DROP` returns *SERVFAIL* answer (for historical reasons).
+
+
+.. function:: rpz(action, path, [watch = true])
+
+ :param action: the default action for match in the zone; typically you want :any:`policy.DENY`
+ :param path: path to zone file
+ :param watch: boolean, if true, the file will be reloaded on file change
+
+ Enforce RPZ_ rules. This can be used in conjunction with published blocklist feeds.
+ The RPZ_ operation is well described in this `Jan-Piet Mens's post`_,
+ or the `Pro DNS and BIND`_ book.
+
+ For example, we can store the example snippet with domain ``blocked.domain.example``
+ (above) into file ``/etc/knot-resolver/blocklist.rpz`` and configure resolver to
+ answer with *NXDOMAIN* plus the specified additional text to queries for this domain:
+
+ .. code-block:: lua
+
+ policy.add(
+ policy.rpz(policy.DENY_MSG('domain blocked by your resolver operator'),
+ '/etc/knot-resolver/blocklist.rpz',
+ true))
+
+ Resolver will reload RPZ file at run-time if the RPZ file changes.
+ Recommended RPZ update procedure is to store new blocklist in a new file
+ (*newblocklist.rpz*) and then rename the new file to the original file name
+ (*blocklist.rpz*). This avoids problems where resolver might attempt
+ to re-read an incomplete file.
+
+
+
+Additional properties
+---------------------
+
+Most properties (actions, filters) are described above.
+
+.. function:: add(rule, postrule)
+
+ :param rule: added rule, i.e. ``policy.pattern(policy.DENY, '[0-9]+\2cz')``
+ :param postrule: boolean, if true the rule will be evaluated on answer instead of query
+ :return: rule description
+
+ Add a new policy rule that is executed either or queries or answers, depending on the ``postrule`` parameter. You can then use the returned rule description to get information and unique identifier for the rule, as well as match count.
+
+ .. code-block:: lua
+
+ -- mirror all queriesm, keep handle so we can retrieve information later
+ local rule = policy.add(policy.all(policy.MIRROR('127.0.0.2')))
+ -- we can print statistics about this rule any time later
+ print(string.format('id: %d, matched queries: %d', rule.id, rule.count)
+
+.. function:: del(id)
+
+ :param id: identifier of a given rule returned by :func:`policy.add`
+ :return: boolean ``true`` if rule was deleted, ``false`` otherwise
+
+ Remove a rule from policy list.
+
+.. function:: todnames({name, ...})
+
+ :param: names table of domain names in textual format
+
+ Returns table of domain names in wire format converted from strings.
+
+ .. code-block:: lua
+
+ -- Convert single name
+ assert(todname('example.com') == '\7example\3com\0')
+ -- Convert table of names
+ policy.todnames({'example.com', 'me.cz'})
+ { '\7example\3com\0', '\2me\2cz\0' }
+
+
+.. _RPZ: https://dnsrpz.info/
+.. _`PEM format`: https://en.wikipedia.org/wiki/Privacy-enhanced_Electronic_Mail
+.. _`Pro DNS and BIND`: http://www.zytrax.com/books/dns/ch7/rpz.html
+.. _`Jan-Piet Mens's post`: http://jpmens.net/2011/04/26/how-to-configure-your-bind-resolvers-to-lie-using-response-policy-zones-rpz/
+.. _`Transport Layer Security`: https://en.wikipedia.org/wiki/Transport_Layer_Security
+.. _`DNS Privacy Project`: https://dnsprivacy.org/
+.. _`IETF draft dprive-dtls-and-tls-profiles`: https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles
+.. _SNI: https://en.wikipedia.org/wiki/Server_Name_Indication
+.. _`Public Suffix List`: https://publicsuffix.org
diff --git a/modules/policy/lua-aho-corasick/LICENSE b/modules/policy/lua-aho-corasick/LICENSE
new file mode 100644
index 0000000..dd65f72
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/LICENSE
@@ -0,0 +1,28 @@
+ Copyright (c) 2014 CloudFlare, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of CloudFlare, Inc. nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
diff --git a/modules/policy/lua-aho-corasick/Makefile b/modules/policy/lua-aho-corasick/Makefile
new file mode 100644
index 0000000..6471664
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/Makefile
@@ -0,0 +1,134 @@
+OS := $(shell uname)
+
+ifeq ($(OS), Darwin)
+ SO_EXT := dylib
+else
+ SO_EXT := so
+endif
+
+#############################################################################
+#
+# Binaries we are going to build
+#
+#############################################################################
+#
+C_SO_NAME = libac.$(SO_EXT)
+LUA_SO_NAME = ahocorasick.$(SO_EXT)
+AR_NAME = libac.a
+
+#############################################################################
+#
+# Compile and link flags
+#
+#############################################################################
+PREFIX ?= /usr/local
+LUA_VERSION := 5.1
+LUA_INCLUDE_DIR := $(PREFIX)/include/lua$(LUA_VERSION)
+SO_TARGET_DIR := $(PREFIX)/lib/lua/$(LUA_VERSION)
+LUA_TARGET_DIR := $(PREFIX)/share/lua/$(LUA_VERSION)
+
+# Available directives:
+# -DDEBUG : Turn on debugging support
+# -DVERIFY : To verify if the slow-version and fast-version implementations
+# get exactly the same result. Note -DVERIFY implies -DDEBUG.
+#
+COMMON_FLAGS = -O3 #-g -DVERIFY -msse2 -msse3 -msse4.1
+COMMON_FLAGS += -fvisibility=hidden -Wall $(CXXFLAGS) $(MY_CXXFLAGS) $(CPPFLAGS)
+
+SO_CXXFLAGS = $(COMMON_FLAGS) -fPIC
+SO_LFLAGS = $(COMMON_FLAGS) $(LDFLAGS)
+AR_CXXFLAGS = $(COMMON_FLAGS)
+
+# -DVERIFY implies -DDEBUG
+ifneq ($(findstring -DVERIFY, $(COMMON_FLAGS)), )
+ifeq ($(findstring -DDEBUG, $(COMMON_FLAGS)), )
+ COMMON_FLAGS += -DDEBUG
+endif
+endif
+
+AR = ar
+AR_FLAGS = cru
+
+#############################################################################
+#
+# Divide source codes and objects into several categories
+#
+#############################################################################
+#
+SRC_COMMON := ac_fast.cxx ac_slow.cxx
+LIBAC_SO_SRC := $(SRC_COMMON) ac.cxx # source for libac.so
+LUA_SO_SRC := $(SRC_COMMON) ac_lua.cxx # source for ahocorasick.so
+LIBAC_A_SRC := $(LIBAC_SO_SRC) # source for libac.a
+
+#############################################################################
+#
+# Make rules
+#
+#############################################################################
+#
+.PHONY = all clean test benchmark prepare
+all : $(C_SO_NAME) $(LUA_SO_NAME) $(AR_NAME)
+
+-include c_so_dep.txt
+-include lua_so_dep.txt
+-include ar_dep.txt
+
+BUILD_SO_DIR := build_so
+BUILD_AR_DIR := build_ar
+
+$(BUILD_SO_DIR) :; mkdir $@
+$(BUILD_AR_DIR) :; mkdir $@
+
+$(BUILD_SO_DIR)/%.o : %.cxx | $(BUILD_SO_DIR)
+ $(CXX) $< -c $(SO_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@
+
+$(BUILD_AR_DIR)/%.o : %.cxx | $(BUILD_AR_DIR)
+ $(CXX) $< -c $(AR_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@
+
+ifneq ($(OS), Darwin)
+$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared -Wl,-soname=$(C_SO_NAME) $(SO_LFLAGS) -o $@
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt
+
+$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared -Wl,-soname=$(LUA_SO_NAME) $(SO_LFLAGS) -o $@
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt
+
+else
+$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared $(SO_LFLAGS) -o $@
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt
+
+$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o})
+ $(CXX) $+ -shared $(SO_LFLAGS) -o $@ -Wl,-undefined,dynamic_lookup
+ cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt
+endif
+
+$(AR_NAME) : $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.o})
+ $(AR) $(AR_FLAGS) $@ $+
+ cat $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.d}) > lua_so_dep.txt
+
+#############################################################################
+#
+# Misc
+#
+#############################################################################
+#
+test : $(C_SO_NAME)
+ $(MAKE) -C tests && \
+ luajit tests/lua_test.lua && \
+ luajit tests/load_ac_test.lua
+
+benchmark: $(C_SO_NAME)
+ $(MAKE) benchmark -C tests
+
+clean :
+ -rm -rf *.o *.d c_so_dep.txt lua_so_dep.txt ar_dep.txt $(TEST) \
+ $(C_SO_NAME) $(LUA_SO_NAME) $(TEST) $(BUILD_SO_DIR) $(BUILD_AR_DIR) \
+ $(AR_NAME)
+ make clean -C tests
+
+install:
+ install -D -m 755 $(C_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(C_SO_NAME)
+ install -D -m 755 $(LUA_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(LUA_SO_NAME)
+ install -D -m 664 load_ac.lua $(DESTDIR)/$(LUA_TARGET_DIR)/load_ac.lua
diff --git a/modules/policy/lua-aho-corasick/README.md b/modules/policy/lua-aho-corasick/README.md
new file mode 100644
index 0000000..b5cc406
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/README.md
@@ -0,0 +1,40 @@
+aho-corasick-lua
+================
+
+C++ and Lua Implementation of the Aho-Corasick (AC) string matching algorithm
+(http://dl.acm.org/citation.cfm?id=360855).
+
+We began with pure Lua implementation and realize the performance is not
+satisfactory. So we switch to C/C++ implementation.
+
+There are two shared objects provied by this package: libac.so and ahocorasick.so
+The former is a regular shared object which can be directly used by C/C++
+application, or by Lua via FFI; and the later is a Lua module. An example usage
+is shown bellow:
+
+```lua
+local ac = require "ahocorasick"
+local dict = {"string1", "string", "etc"}
+local acinst = ac.create(dict)
+local r = ac.match(acinst, "mystring")
+```
+
+For efficiency reasons, the implementation is slightly different from the
+standard AC algorithm in that it doesn't return a set of strings in the dictionary
+that match the given string, instead it only returns one of them in case the string
+matches. The functionality of our implementation can be (precisely) described by
+following pseudo-c snippet.
+
+```C
+string foo(input-string, dictionary) {
+ string ret = the-end-of-input-string;
+ for each string s in dictionary {
+ // find the first occurrence match sub-string.
+ ret = min(ret, strstr(input-string, s);
+ }
+ return ret;
+}
+```
+
+It's pretty easy to get rid of this limitation, just to associate each state with
+a spare bit-vector dipicting the set of strings recognized by that state.
diff --git a/modules/policy/lua-aho-corasick/ac.cxx b/modules/policy/lua-aho-corasick/ac.cxx
new file mode 100644
index 0000000..23fb3b5
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac.cxx
@@ -0,0 +1,101 @@
+// Interface functions for libac.so
+//
+#include "ac_slow.hpp"
+#include "ac_fast.hpp"
+#include "ac.h"
+
+static inline ac_result_t
+_match(buf_header_t* ac, const char* str, unsigned int len) {
+ AC_Buffer* buf = (AC_Buffer*)(void*)ac;
+ ASSERT(ac->magic_num == AC_MAGIC_NUM);
+
+ ac_result_t r = Match(buf, str, len);
+
+ #ifdef VERIFY
+ {
+ Match_Result r2 = buf->slow_impl->Match(str, len);
+ if (r.match_begin != r2.begin) {
+ ASSERT(0);
+ } else {
+ ASSERT((r.match_begin < 0) ||
+ (r.match_end == r2.end &&
+ r.pattern_idx == r2.pattern_idx));
+ }
+ }
+ #endif
+ return r;
+}
+
+extern "C" int
+ac_match2(ac_t* ac, const char* str, unsigned int len) {
+ ac_result_t r = _match((buf_header_t*)(void*)ac, str, len);
+ return r.match_begin;
+}
+
+extern "C" ac_result_t
+ac_match(ac_t* ac, const char* str, unsigned int len) {
+ return _match((buf_header_t*)(void*)ac, str, len);
+}
+
+extern "C" ac_result_t
+ac_match_longest_l(ac_t* ac, const char* str, unsigned int len) {
+ AC_Buffer* buf = (AC_Buffer*)(void*)ac;
+ ASSERT(((buf_header_t*)ac)->magic_num == AC_MAGIC_NUM);
+
+ ac_result_t r = Match_Longest_L(buf, str, len);
+ return r;
+}
+
+class BufAlloc : public Buf_Allocator {
+public:
+ virtual AC_Buffer* alloc(int sz) {
+ return (AC_Buffer*)(new unsigned char[sz]);
+ }
+
+ // Do not de-allocate the buffer when the BufAlloc die.
+ virtual void free() {}
+
+ static void myfree(AC_Buffer* buf) {
+ ASSERT(buf->hdr.magic_num == AC_MAGIC_NUM);
+ const char* b = (const char*)buf;
+ delete[] b;
+ }
+};
+
+extern "C" ac_t*
+ac_create(const char** strv, unsigned int* strlenv, unsigned int v_len) {
+ if (v_len >= 65535) {
+ // TODO: Currently we use 16-bit to encode pattern-index (see the
+ // comment to AC_State::is_term), therefore we are not able to
+ // handle pattern set with more than 65535 entries.
+ return 0;
+ }
+
+ ACS_Constructor *acc;
+#ifdef VERIFY
+ acc = new ACS_Constructor;
+#else
+ ACS_Constructor tmp;
+ acc = &tmp;
+#endif
+ acc->Construct(strv, strlenv, v_len);
+
+ BufAlloc ba;
+ AC_Converter cvt(*acc, ba);
+ AC_Buffer* buf = cvt.Convert();
+
+#ifdef VERIFY
+ buf->slow_impl = acc;
+#endif
+ return (ac_t*)(void*)buf;
+}
+
+extern "C" void
+ac_free(void* ac) {
+ AC_Buffer* buf = (AC_Buffer*)ac;
+#ifdef VERIFY
+ delete buf->slow_impl;
+#endif
+
+ BufAlloc::myfree(buf);
+}
diff --git a/modules/policy/lua-aho-corasick/ac.h b/modules/policy/lua-aho-corasick/ac.h
new file mode 100644
index 0000000..30bf447
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac.h
@@ -0,0 +1,49 @@
+#ifndef AC_H
+#define AC_H
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define AC_EXPORT __attribute__ ((visibility ("default")))
+
+/* If the subject-string dosen't match any of the given patterns, "match_begin"
+ * should be a negative; otherwise the substring of the subject-string,
+ * starting from offset "match_begin" to "match_end" incusively,
+ * should exactly match the pattern specified by the 'pattern_idx' (i.e.
+ * the pattern is "pattern_v[pattern_idx]" where the "pattern_v" is the
+ * first acutal argument passing to ac_create())
+ */
+typedef struct {
+ int match_begin;
+ int match_end;
+ int pattern_idx;
+} ac_result_t;
+
+struct ac_t;
+
+/* Create an AC instance. "pattern_v" is a vector of patterns, the length of
+ * i-th pattern is specified by "pattern_len_v[i]"; the number of patterns
+ * is specified by "vect_len".
+ *
+ * Return the instance on success, or NUL otherwise.
+ */
+ac_t* ac_create(const char** pattern_v, unsigned int* pattern_len_v,
+ unsigned int vect_len) AC_EXPORT;
+
+ac_result_t ac_match(ac_t*, const char *str, unsigned int len) AC_EXPORT;
+
+ac_result_t ac_match_longest_l(ac_t*, const char *str, unsigned int len) AC_EXPORT;
+
+/* Similar to ac_match() except that it only returns match-begin. The rationale
+ * for this interface is that luajit has hard time in dealing with strcture-
+ * return-value.
+ */
+int ac_match2(ac_t*, const char *str, unsigned int len) AC_EXPORT;
+
+void ac_free(void*) AC_EXPORT;
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* AC_H */
diff --git a/modules/policy/lua-aho-corasick/ac_fast.cxx b/modules/policy/lua-aho-corasick/ac_fast.cxx
new file mode 100644
index 0000000..9dbc2e6
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_fast.cxx
@@ -0,0 +1,468 @@
+#include <algorithm> // for std::sort
+#include "ac_slow.hpp"
+#include "ac_fast.hpp"
+
+uint32
+AC_Converter::Calc_State_Sz(const ACS_State* s) const {
+ AC_State dummy;
+ uint32 sz = offsetof(AC_State, input_vect);
+ sz += s->Get_GotoNum() * sizeof(dummy.input_vect[0]);
+
+ if (sz < sizeof(AC_State))
+ sz = sizeof(AC_State);
+
+ uint32 align = __alignof__(dummy);
+ sz = (sz + align - 1) & ~(align - 1);
+ return sz;
+}
+
+AC_Buffer*
+AC_Converter::Alloc_Buffer() {
+ const vector<ACS_State*>& all_states = _acs.Get_All_States();
+ const ACS_State* root_state = _acs.Get_Root_State();
+ uint32 root_fanout = root_state->Get_GotoNum();
+
+ // Step 1: Calculate the buffer size
+ AC_Ofst root_goto_ofst, states_ofst_ofst, first_state_ofst;
+
+ // part 1 : buffer header
+ uint32 sz = root_goto_ofst = sizeof(AC_Buffer);
+
+ // part 2: Root-node's goto function
+ if (likely(root_fanout != 255))
+ sz += 256;
+ else
+ root_goto_ofst = 0;
+
+ // part 3: mapping of state's relative position.
+ unsigned align = __alignof__(AC_Ofst);
+ sz = (sz + align - 1) & ~(align - 1);
+ states_ofst_ofst = sz;
+
+ sz += sizeof(AC_Ofst) * all_states.size();
+
+ // part 4: state's contents
+ align = __alignof__(AC_State);
+ sz = (sz + align - 1) & ~(align - 1);
+ first_state_ofst = sz;
+
+ uint32 state_sz = 0;
+ for (vector<ACS_State*>::const_iterator i = all_states.begin(),
+ e = all_states.end(); i != e; i++) {
+ state_sz += Calc_State_Sz(*i);
+ }
+ state_sz -= Calc_State_Sz(root_state);
+
+ sz += state_sz;
+
+ // Step 2: Allocate buffer, and populate header.
+ AC_Buffer* buf = _buf_alloc.alloc(sz);
+
+ buf->hdr.magic_num = AC_MAGIC_NUM;
+ buf->hdr.impl_variant = IMPL_FAST_VARIANT;
+ buf->buf_len = sz;
+ buf->root_goto_ofst = root_goto_ofst;
+ buf->states_ofst_ofst = states_ofst_ofst;
+ buf->first_state_ofst = first_state_ofst;
+ buf->root_goto_num = root_fanout;
+ buf->state_num = _acs.Get_State_Num();
+ return buf;
+}
+
+void
+AC_Converter::Populate_Root_Goto_Func(AC_Buffer* buf,
+ GotoVect& goto_vect) {
+ unsigned char *buf_base = (unsigned char*)(buf);
+ InputTy* root_gotos = (InputTy*)(buf_base + buf->root_goto_ofst);
+ const ACS_State* root_state = _acs.Get_Root_State();
+
+ root_state->Get_Sorted_Gotos(goto_vect);
+
+ // Renumber the ID of root-node's immediate kids.
+ uint32 new_id = 1;
+ bool full_fantout = (goto_vect.size() == 255);
+ if (likely(!full_fantout))
+ bzero(root_gotos, 256*sizeof(InputTy));
+
+ for (GotoVect::iterator i = goto_vect.begin(), e = goto_vect.end();
+ i != e; i++, new_id++) {
+ InputTy c = i->first;
+ ACS_State* s = i->second;
+ _id_map[s->Get_ID()] = new_id;
+
+ if (likely(!full_fantout))
+ root_gotos[c] = new_id;
+ }
+}
+
+AC_Buffer*
+AC_Converter::Convert() {
+ // Step 1: Some preparation stuff.
+ GotoVect gotovect;
+
+ _id_map.clear();
+ _ofst_map.clear();
+ _id_map.resize(_acs.Get_Next_Node_Id());
+ _ofst_map.resize(_acs.Get_Next_Node_Id());
+
+ // Step 2: allocate buffer to accommodate the entire AC graph.
+ AC_Buffer* buf = Alloc_Buffer();
+ unsigned char* buf_base = (unsigned char*)buf;
+
+ // Step 3: Root node need special care.
+ Populate_Root_Goto_Func(buf, gotovect);
+ buf->root_goto_num = gotovect.size();
+ _id_map[_acs.Get_Root_State()->Get_ID()] = 0;
+
+ // Step 4: Converting the remaining states by BFSing the graph.
+ // First of all, enter root's immediate kids to the working list.
+ vector<const ACS_State*> wl;
+ State_ID id = 1;
+ for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end();
+ i != e; i++, id++) {
+ ACS_State* s = i->second;
+ wl.push_back(s);
+ _id_map[s->Get_ID()] = id;
+ }
+
+ AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst);
+ AC_Ofst ofst = buf->first_state_ofst;
+ for (uint32 idx = 0; idx < wl.size(); idx++) {
+ const ACS_State* old_s = wl[idx];
+ AC_State* new_s = (AC_State*)(buf_base + ofst);
+
+ // This property should hold as we:
+ // - States are appended to worklist in the BFS order.
+ // - sibiling states are appended to worklist in the order of their
+ // corresponding input.
+ //
+ State_ID state_id = idx + 1;
+ ASSERT(_id_map[old_s->Get_ID()] == state_id);
+
+ state_ofst_vect[state_id] = ofst;
+
+ new_s->first_kid = wl.size() + 1;
+ new_s->depth = old_s->Get_Depth();
+ new_s->is_term = old_s->is_Terminal() ?
+ old_s->get_Pattern_Idx() + 1 : 0;
+
+ uint32 gotonum = old_s->Get_GotoNum();
+ new_s->goto_num = gotonum;
+
+ // Populate the "input" field
+ old_s->Get_Sorted_Gotos(gotovect);
+ uint32 input_idx = 0;
+ uint32 id = wl.size() + 1;
+ InputTy* input_vect = new_s->input_vect;
+ for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end();
+ i != e; i++, id++, input_idx++) {
+ input_vect[input_idx] = i->first;
+
+ ACS_State* kid = i->second;
+ _id_map[kid->Get_ID()] = id;
+ wl.push_back(kid);
+ }
+
+ _ofst_map[old_s->Get_ID()] = ofst;
+ ofst += Calc_State_Sz(old_s);
+ }
+
+ // This assertion might be useful to catch buffer overflow
+ ASSERT(ofst == buf->buf_len);
+
+ // Populate the fail-link field.
+ for (vector<const ACS_State*>::iterator i = wl.begin(), e = wl.end();
+ i != e; i++) {
+ const ACS_State* slow_s = *i;
+ State_ID fast_s_id = _id_map[slow_s->Get_ID()];
+ AC_State* fast_s = (AC_State*)(buf_base + state_ofst_vect[fast_s_id]);
+ if (const ACS_State* fl = slow_s->Get_FailLink()) {
+ State_ID id = _id_map[fl->Get_ID()];
+ fast_s->fail_link = id;
+ } else
+ fast_s->fail_link = 0;
+ }
+#ifdef DEBUG
+ //dump_buffer(buf, stderr);
+#endif
+ return buf;
+}
+
+static inline AC_State*
+Get_State_Addr(unsigned char* buf_base, AC_Ofst* StateOfstVect, uint32 state_id) {
+ ASSERT(state_id != 0 && "root node is handled in speical way");
+ ASSERT(state_id < ((AC_Buffer*)buf_base)->state_num);
+ return (AC_State*)(buf_base + StateOfstVect[state_id]);
+}
+
+// The performance of the binary search is critical to this work.
+//
+// Here we provide two versions of binary-search functions.
+// The non-pristine version seems to consistently out-perform "pristine" one on
+// bunch of benchmarks we tested. With the benchmark under tests/testinput/
+//
+// The speedup is following on my laptop (core i7, ubuntu):
+//
+// benchmark was is
+// ----------------------------------------
+// image.bin 2.3s 2.0s
+// test.tar 6.7s 5.7s
+//
+// NOTE: As of I write this comment, we only measure the performance on about
+// 10+ benchmarks. It's still too early to say which one works better.
+//
+#if !defined(BS_MULTI_VER)
+static bool __attribute__((always_inline)) inline
+Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) {
+ if (vect_len <= 8) {
+ for (int i = 0; i < vect_len; i++) {
+ if (input_vect[i] == input) {
+ idx = i;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ // The "low" and "high" must be signed integers, as they could become -1.
+ // Also since they are signed integer, "(low + high)/2" is sightly more
+ // expensive than (low+high)>>1 or ((unsigned)(low + high))/2.
+ //
+ int low = 0, high = vect_len - 1;
+ while (low <= high) {
+ int mid = (low + high) >> 1;
+ InputTy mid_c = input_vect[mid];
+
+ if (input < mid_c)
+ high = mid - 1;
+ else if (input > mid_c)
+ low = mid + 1;
+ else {
+ idx = mid;
+ return true;
+ }
+ }
+ return false;
+}
+
+#else
+
+/* Let us call this version "pristine" version. */
+static inline bool
+Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) {
+ int low = 0, high = vect_len - 1;
+ while (low <= high) {
+ int mid = (low + high) >> 1;
+ InputTy mid_c = input_vect[mid];
+
+ if (input < mid_c)
+ high = mid - 1;
+ else if (input > mid_c)
+ low = mid + 1;
+ else {
+ idx = mid;
+ return true;
+ }
+ }
+ return false;
+}
+#endif
+
+typedef enum {
+ // Look for the first match. e.g. pattern set = {"ab", "abc", "def"},
+ // subject string "ababcdef". The first match would be "ab" at the
+ // beginning of the subject string.
+ MV_FIRST_MATCH,
+
+ // Look for the left-most longest match. Follow above example; there are
+ // two longest matches, "abc" and "def", and the left-most longest match
+ // is "abc".
+ MV_LEFT_LONGEST,
+
+ // Similar to the left-most longest match, except that it returns the
+ // *right* most longest match. Follow above example, the match would
+ // be "def". NYI.
+ MV_RIGHT_LONGEST,
+
+ // Return all patterns that match that given subject string. NYI.
+ MV_ALL_MATCHES,
+} MATCH_VARIANT;
+
+/* The Match_Tmpl is the template for vairants MV_FIRST_MATCH, MV_LEFT_LONGEST,
+ * MV_RIGHT_LONGEST (If we really really need MV_RIGHT_LONGEST variant, we are
+ * better off implementing it in a seprate function).
+ *
+ * The Match_Tmpl supports three variants at once "symbolically", once it's
+ * instanced to a particular variants, all the code irrelevant to the variants
+ * will be statically removed. So don't worry about the code like
+ * "if (variant == MV_XXXX)"; they will not incur any penalty.
+ *
+ * The drawback of using template is increased code size. Unfortunately, there
+ * is no silver bullet.
+ */
+template<MATCH_VARIANT variant> static ac_result_t
+Match_Tmpl(AC_Buffer* buf, const char* str, uint32 len) {
+ unsigned char* buf_base = (unsigned char*)(buf);
+ unsigned char* root_goto = buf_base + buf->root_goto_ofst;
+ AC_Ofst* states_ofst_vect = (AC_Ofst* )(buf_base + buf->states_ofst_ofst);
+
+ AC_State* state = 0;
+ uint32 idx = 0;
+
+ // Skip leading chars that are not valid input of root-nodes.
+ if (likely(buf->root_goto_num != 255)) {
+ while(idx < len) {
+ unsigned char c = str[idx++];
+ if (unsigned char kid_id = root_goto[c]) {
+ state = Get_State_Addr(buf_base, states_ofst_vect, kid_id);
+ break;
+ }
+ }
+ } else {
+ idx = 1;
+ state = Get_State_Addr(buf_base, states_ofst_vect, *str);
+ }
+
+ ac_result_t r = {-1, -1};
+ if (likely(state != 0)) {
+ if (unlikely(state->is_term)) {
+ /* Dictionary may have string of length 1 */
+ r.match_begin = idx - state->depth;
+ r.match_end = idx - 1;
+ r.pattern_idx = state->is_term - 1;
+
+ if (variant == MV_FIRST_MATCH) {
+ return r;
+ }
+ }
+ }
+
+ while (idx < len) {
+ unsigned char c = str[idx];
+ int res;
+ bool found;
+ found = Binary_Search_Input(state->input_vect, state->goto_num, c, res);
+ if (found) {
+ // The "t = goto(c, current_state)" is valid, advance to state "t".
+ uint32 kid = state->first_kid + res;
+ state = Get_State_Addr(buf_base, states_ofst_vect, kid);
+ idx++;
+ } else {
+ // Follow the fail-link.
+ State_ID fl = state->fail_link;
+ if (fl == 0) {
+ // fail-link is root-node, which implies the root-node dosen't
+ // have 255 valid transitions (otherwise, the fail-link should
+ // points to "goto(root, c)"), so we don't need speical handling
+ // as we did before this while-loop is entered.
+ //
+ while(idx < len) {
+ InputTy c = str[idx++];
+ if (unsigned char kid_id = root_goto[c]) {
+ state =
+ Get_State_Addr(buf_base, states_ofst_vect, kid_id);
+ break;
+ }
+ }
+ } else {
+ state = Get_State_Addr(buf_base, states_ofst_vect, fl);
+ }
+ }
+
+ // Check to see if the state is terminal state?
+ if (state->is_term) {
+ if (variant == MV_FIRST_MATCH) {
+ ac_result_t r;
+ r.match_begin = idx - state->depth;
+ r.match_end = idx - 1;
+ r.pattern_idx = state->is_term - 1;
+ return r;
+ }
+
+ if (variant == MV_LEFT_LONGEST) {
+ int match_begin = idx - state->depth;
+ int match_end = idx - 1;
+
+ if (r.match_begin == -1 ||
+ match_end - match_begin > r.match_end - r.match_begin) {
+ r.match_begin = match_begin;
+ r.match_end = match_end;
+ r.pattern_idx = state->is_term - 1;
+ }
+ continue;
+ }
+
+ ASSERT(false && "NYI");
+ }
+ }
+
+ return r;
+}
+
+ac_result_t
+Match(AC_Buffer* buf, const char* str, uint32 len) {
+ return Match_Tmpl<MV_FIRST_MATCH>(buf, str, len);
+}
+
+ac_result_t
+Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len) {
+ return Match_Tmpl<MV_LEFT_LONGEST>(buf, str, len);
+}
+
+#ifdef DEBUG
+void
+AC_Converter::dump_buffer(AC_Buffer* buf, FILE* f) {
+ vector<AC_Ofst> state_ofst;
+ state_ofst.resize(_id_map.size());
+
+ fprintf(f, "Id maps between old/slow and new/fast graphs\n");
+ int old_id = 0;
+ for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end();
+ i != e; i++, old_id++) {
+ State_ID new_id = *i;
+ if (new_id != 0) {
+ fprintf(f, "%d -> %d, ", old_id, new_id);
+ }
+ }
+ fprintf(f, "\n");
+
+ int idx = 0;
+ for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end();
+ i != e; i++, idx++) {
+ uint32 id = *i;
+ if (id == 0) continue;
+ state_ofst[id] = _ofst_map[idx];
+ }
+
+ unsigned char* buf_base = (unsigned char*)buf;
+
+ // dump root goto-function.
+ fprintf(f, "root, fanout:%d goto {", buf->root_goto_num);
+ if (buf->root_goto_num != 255) {
+ unsigned char* root_goto = buf_base + buf->root_goto_ofst;
+ for (uint32 i = 0; i < 255; i++) {
+ if (root_goto[i] != 0)
+ fprintf(f, "%c->S:%d, ", (unsigned char)i, root_goto[i]);
+ }
+ } else {
+ fprintf(f, "full fanout\n");
+ }
+ fprintf(f, "}\n");
+
+ // dump remaining states.
+ AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst);
+ for (uint32 i = 1, e = buf->state_num; i < e; i++) {
+ AC_Ofst ofst = state_ofst_vect[i];
+ ASSERT(ofst == state_ofst[i]);
+ fprintf(f, "S:%d, ofst:%d, goto={", i, ofst);
+
+ AC_State* s = (AC_State*)(buf_base + ofst);
+ State_ID kid = s->first_kid;
+ for (uint32 k = 0, ke = s->goto_num; k < ke; k++, kid++)
+ fprintf(f, "%c->S:%d, ", s->input_vect[k], kid);
+
+ fprintf(f, "}, fail-link = S:%d, %s\n", s->fail_link,
+ s->is_term ? "terminal" : "");
+ }
+}
+#endif
diff --git a/modules/policy/lua-aho-corasick/ac_fast.hpp b/modules/policy/lua-aho-corasick/ac_fast.hpp
new file mode 100644
index 0000000..9ac557c
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_fast.hpp
@@ -0,0 +1,124 @@
+#ifndef AC_FAST_H
+#define AC_FAST_H
+
+#include <vector>
+#include "ac.h"
+#include "ac_slow.hpp"
+
+using namespace std;
+
+class ACS_Constructor;
+
+typedef uint32 AC_Ofst;
+typedef uint32 State_ID;
+
+// The entire "fast" AC graph is converted from its "slow" version, and store
+// in an consecutive trunk of memory or "buffer". Since the pointers in the
+// fast AC graph are represented as offset relative to the base address of
+// the buffer, this fast AC graph is position-independent, meaning cloning
+// the fast graph is just to memcpy the entire buffer.
+//
+// The buffer is laid-out as following:
+//
+// 1. The buffer header. (i.e. the AC_Buffer content)
+// 2. root-node's goto functions. It is represented as an array indiced by
+// root-node's valid inputs, and the element is the ID of the corresponding
+// transition state (aka kid). To save space, we used 8-bit to represent
+// the IDs. ID of root's kids starts with 1.
+//
+// Root may have 255 valid inputs. In this speical case, i-th element
+// stores value i -- i.e the i-th state. So, we don't need such array
+// at all. On the other hand, 8-bit is insufficient to encode kids' ID.
+//
+// 3. An array indiced by state's id, and the element is the offset
+// of correspoding state wrt the base address of the buffer.
+//
+// 4. the contents of states.
+//
+typedef struct {
+ buf_header_t hdr; // The header exposed to the user using this lib.
+#ifdef VERIFY
+ ACS_Constructor* slow_impl;
+#endif
+ uint32 buf_len;
+ AC_Ofst root_goto_ofst; // addr of root node's goto() function.
+ AC_Ofst states_ofst_ofst; // addr of state pointer vector (indiced by id)
+ AC_Ofst first_state_ofst; // addr of the first state in the buffer.
+ uint16 root_goto_num; // fan-out of root-node.
+ uint16 state_num; // number of states
+
+ // Followed by the gut of the buffer:
+ // 1. map: root's-valid-input -> kid's id
+ // 2. map: state's ID -> offset of the state
+ // 3. states' content.
+} AC_Buffer;
+
+// Depict the state of "fast" AC graph.
+typedef struct {
+ // transition are sorted. For instance, state s1, has two transitions :
+ // goto(b) -> S_b, goto(a)->S_a. The inputs are sorted in the ascending
+ // order, and the target states are permuted accordingly. In this case,
+ // the inputs are sorted as : a, b, and the target states are permuted
+ // into S_a, S_b. So, S_a is the 1st kid, the ID of kids are consecutive,
+ // so we don't need to save all the target kids.
+ //
+ State_ID first_kid;
+ AC_Ofst fail_link;
+ short depth; // How far away from root.
+ unsigned short is_term; // Is terminal node. if is_term != 0, it encodes
+ // the value of "1 + pattern-index".
+ unsigned char goto_num; // The number of valid transition.
+ InputTy input_vect[1]; // Vector of valid input. Must be last field!
+} AC_State;
+
+class Buf_Allocator {
+public:
+ Buf_Allocator() : _buf(0) {}
+ virtual ~Buf_Allocator() { free(); }
+
+ virtual AC_Buffer* alloc(int sz) = 0;
+ virtual void free() {};
+protected:
+ AC_Buffer* _buf;
+};
+
+// Convert slow-AC-graph into fast one.
+class AC_Converter {
+public:
+ AC_Converter(ACS_Constructor& acs, Buf_Allocator& ba) :
+ _acs(acs), _buf_alloc(ba) {}
+ AC_Buffer* Convert();
+
+private:
+ // Return the size in byte needed to to save the specified state.
+ uint32 Calc_State_Sz(const ACS_State *) const;
+
+ // In fast-AC-graph, the ID is bit trikcy. Given a state of slow-graph,
+ // this function is to return the ID of its counterpart in the fast-graph.
+ State_ID Get_Renumbered_Id(const ACS_State *s) const {
+ const vector<uint32> &m = _id_map;
+ return m[s->Get_ID()];
+ }
+
+ AC_Buffer* Alloc_Buffer();
+ void Populate_Root_Goto_Func(AC_Buffer *, GotoVect&);
+
+#ifdef DEBUG
+ void dump_buffer(AC_Buffer*, FILE*);
+#endif
+
+private:
+ ACS_Constructor& _acs;
+ Buf_Allocator& _buf_alloc;
+
+ // map: ID of state in slow-graph -> ID of counterpart in fast-graph.
+ vector<uint32> _id_map;
+
+ // map: ID of state in slow-graph -> offset of counterpart in fast-graph.
+ vector<AC_Ofst> _ofst_map;
+};
+
+ac_result_t Match(AC_Buffer* buf, const char* str, uint32 len);
+ac_result_t Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len);
+
+#endif // AC_FAST_H
diff --git a/modules/policy/lua-aho-corasick/ac_lua.cxx b/modules/policy/lua-aho-corasick/ac_lua.cxx
new file mode 100644
index 0000000..ad7307e
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_lua.cxx
@@ -0,0 +1,173 @@
+// Interface functions for libac.so
+//
+#include <vector>
+#include <string>
+#include "ac_slow.hpp"
+#include "ac_fast.hpp"
+#include "ac.h" // for the definition of ac_result_t
+#include "ac_util.hpp"
+
+extern "C" {
+ #include <lua.h>
+ #include <lauxlib.h>
+}
+
+#if defined(USE_SLOW_VER)
+#error "Not going to implement it"
+#endif
+
+using namespace std;
+static const char* tname = "aho-corasick";
+
+class BufAlloc : public Buf_Allocator {
+public:
+ BufAlloc(lua_State* L) : _L(L) {}
+ virtual AC_Buffer* alloc(int sz) {
+ return (AC_Buffer*)lua_newuserdata (_L, sz);
+ }
+
+ // Let GC to take care.
+ virtual void free() {}
+
+private:
+ lua_State* _L;
+};
+
+static bool
+_create_helper(lua_State* L, const vector<const char*>& str_v,
+ const vector<unsigned int>& strlen_v) {
+ ASSERT(str_v.size() == strlen_v.size());
+
+ ACS_Constructor acc;
+ BufAlloc ba(L);
+
+ // Step 1: construt the slow version.
+ unsigned int strnum = str_v.size();
+ const char** str_vect = new const char*[strnum];
+ unsigned int* strlen_vect = new unsigned int[strnum];
+
+ int idx = 0;
+ for (vector<const char*>::const_iterator i = str_v.begin(), e = str_v.end();
+ i != e; i++) {
+ str_vect[idx++] = *i;
+ }
+
+ idx = 0;
+ for (vector<unsigned int>::const_iterator i = strlen_v.begin(),
+ e = strlen_v.end(); i != e; i++) {
+ strlen_vect[idx++] = *i;
+ }
+
+ acc.Construct(str_vect, strlen_vect, idx);
+ delete[] str_vect;
+ delete[] strlen_vect;
+
+ // Step 2: convert to fast version
+ AC_Converter cvt(acc, ba);
+ return cvt.Convert() != 0;
+}
+
+static ac_result_t
+_match_helper(buf_header_t* ac, const char *str, unsigned int len) {
+ AC_Buffer* buf = (AC_Buffer*)(void*)ac;
+ ASSERT(ac->magic_num == AC_MAGIC_NUM);
+
+ ac_result_t r = Match(buf, str, len);
+ return r;
+}
+
+// LUA sematic:
+// input: array of strings
+// output: userdata containing the AC-graph (i.e. the AC_Buffer).
+//
+static int
+lac_create(lua_State* L) {
+ // The table of the array must be the 1st argument.
+ int input_tab = 1;
+
+ luaL_checktype(L, input_tab, LUA_TTABLE);
+
+ // Init the "iteartor".
+ lua_pushnil(L);
+
+ vector<const char*> str_v;
+ vector<unsigned int> strlen_v;
+
+ // Loop over the elements
+ while (lua_next(L, input_tab)) {
+ size_t str_len;
+ const char* s = luaL_checklstring(L, -1, &str_len);
+ str_v.push_back(s);
+ strlen_v.push_back(str_len);
+
+ // remove the value, but keep the key as the iterator.
+ lua_pop(L, 1);
+ }
+
+ // pop the nil value
+ lua_pop(L, 1);
+
+ if (_create_helper(L, str_v, strlen_v)) {
+ // The AC graph, as a userdata is already pushed to the stack, hence 1.
+ return 1;
+ }
+
+ return 0;
+}
+
+// LUA input:
+// arg1: the userdata, representing the AC graph, returned from l_create().
+// arg2: the string to be matched.
+//
+// LUA return:
+// if match, return index range of the match; otherwise nil is returned.
+//
+static int
+lac_match(lua_State* L) {
+ buf_header_t* ac = (buf_header_t*)lua_touserdata(L, 1);
+ if (!ac) {
+ luaL_checkudata(L, 1, tname);
+ return 0;
+ }
+
+ size_t len;
+ const char* str;
+ #if LUA_VERSION_NUM >= 502
+ str = luaL_tolstring(L, 2, &len);
+ #else
+ str = lua_tolstring(L, 2, &len);
+ #endif
+ if (!str) {
+ luaL_checkstring(L, 2);
+ return 0;
+ }
+
+ ac_result_t r = _match_helper(ac, str, len);
+ if (r.match_begin != -1) {
+ lua_pushinteger(L, r.match_begin);
+ lua_pushinteger(L, r.match_end);
+ return 2;
+ }
+
+ return 0;
+}
+
+static const struct luaL_Reg lib_funcs[] = {
+ { "create", lac_create },
+ { "match", lac_match },
+ {0, 0}
+};
+
+extern "C" int AC_EXPORT
+luaopen_ahocorasick(lua_State* L) {
+ luaL_newmetatable(L, tname);
+
+#if LUA_VERSION_NUM == 501
+ luaL_register(L, tname, lib_funcs);
+#elif LUA_VERSION_NUM >= 502
+ luaL_newlib(L, lib_funcs);
+#else
+ #error "Don't know how to do it right"
+#endif
+ return 1;
+}
diff --git a/modules/policy/lua-aho-corasick/ac_slow.cxx b/modules/policy/lua-aho-corasick/ac_slow.cxx
new file mode 100644
index 0000000..cb3957a
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_slow.cxx
@@ -0,0 +1,318 @@
+#include <ctype.h>
+#include <strings.h> // for bzero
+#include <algorithm>
+#include "ac_slow.hpp"
+#include "ac.h"
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Implementation of AhoCorasick_Slow
+//
+//////////////////////////////////////////////////////////////////////////
+//
+ACS_Constructor::ACS_Constructor() : _next_node_id(1) {
+ _root = new_state();
+ _root_char = new InputTy[256];
+ bzero((void*)_root_char, 256);
+
+#ifdef VERIFY
+ _pattern_buf = 0;
+#endif
+}
+
+ACS_Constructor::~ACS_Constructor() {
+ for (std::vector<ACS_State* >::iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ delete *i;
+ }
+ _all_states.clear();
+ delete[] _root_char;
+
+#ifdef VERIFY
+ delete[] _pattern_buf;
+#endif
+}
+
+ACS_State*
+ACS_Constructor::new_state() {
+ ACS_State* t = new ACS_State(_next_node_id++);
+ _all_states.push_back(t);
+ return t;
+}
+
+void
+ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len,
+ int pattern_idx) {
+ ACS_State* state = _root;
+ for (unsigned int i = 0; i < str_len; i++) {
+ const char c = str[i];
+ ACS_State* new_s = state->Get_Goto(c);
+ if (!new_s) {
+ new_s = new_state();
+ new_s->_depth = state->_depth + 1;
+ state->Set_Goto(c, new_s);
+ }
+ state = new_s;
+ }
+ state->_is_terminal = true;
+ state->set_Pattern_Idx(pattern_idx);
+}
+
+void
+ACS_Constructor::Propagate_faillink() {
+ ACS_State* r = _root;
+ std::vector<ACS_State*> wl;
+
+ const ACS_Goto_Map& m = r->Get_Goto_Map();
+ for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) {
+ ACS_State* s = i->second;
+ s->_fail_link = r;
+ wl.push_back(s);
+ }
+
+ // For any input c, make sure "goto(root, c)" is valid, which make the
+ // fail-link propagation lot easier.
+ ACS_Goto_Map goto_save = r->_goto_map;
+ for (uint32 i = 0; i <= 255; i++) {
+ ACS_State* s = r->Get_Goto(i);
+ if (!s) r->Set_Goto(i, r);
+ }
+
+ for (uint32 i = 0; i < wl.size(); i++) {
+ ACS_State* s = wl[i];
+ ACS_State* fl = s->_fail_link;
+
+ const ACS_Goto_Map& tran_map = s->Get_Goto_Map();
+
+ for (ACS_Goto_Map::const_iterator ii = tran_map.begin(),
+ ee = tran_map.end(); ii != ee; ii++) {
+ InputTy c = ii->first;
+ ACS_State *tran = ii->second;
+
+ ACS_State* tran_fl = 0;
+ for (ACS_State* fl_walk = fl; ;) {
+ if (ACS_State* t = fl_walk->Get_Goto(c)) {
+ tran_fl = t;
+ break;
+ } else {
+ fl_walk = fl_walk->Get_FailLink();
+ }
+ }
+
+ tran->_fail_link = tran_fl;
+ wl.push_back(tran);
+ }
+ }
+
+ // Remove "goto(root, c) == root" transitions
+ r->_goto_map = goto_save;
+}
+
+void
+ACS_Constructor::Construct(const char** strv, unsigned int* strlenv,
+ uint32 strnum) {
+ Save_Patterns(strv, strlenv, strnum);
+
+ for (uint32 i = 0; i < strnum; i++) {
+ Add_Pattern(strv[i], strlenv[i], i);
+ }
+
+ Propagate_faillink();
+ unsigned char* p = _root_char;
+
+ const ACS_Goto_Map& m = _root->Get_Goto_Map();
+ for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end();
+ i != e; i++) {
+ p[i->first] = 1;
+ }
+}
+
+Match_Result
+ACS_Constructor::MatchHelper(const char *str, uint32 len) const {
+ const ACS_State* root = _root;
+ const ACS_State* state = root;
+
+ uint32 idx = 0;
+ while (idx < len) {
+ InputTy c = str[idx];
+ idx++;
+ if (_root_char[c]) {
+ state = root->Get_Goto(c);
+ break;
+ }
+ }
+
+ if (unlikely(state->is_Terminal())) {
+ // This could happen if the one of the pattern has only one char!
+ uint32 pos = idx - 1;
+ Match_Result r(pos - state->Get_Depth() + 1, pos,
+ state->get_Pattern_Idx());
+ return r;
+ }
+
+ while (idx < len) {
+ InputTy c = str[idx];
+ ACS_State* gs = state->Get_Goto(c);
+
+ if (!gs) {
+ ACS_State* fl = state->Get_FailLink();
+ if (fl == root) {
+ while (idx < len) {
+ InputTy c = str[idx];
+ idx++;
+ if (_root_char[c]) {
+ state = root->Get_Goto(c);
+ break;
+ }
+ }
+ } else {
+ state = fl;
+ }
+ } else {
+ idx ++;
+ state = gs;
+ }
+
+ if (state->is_Terminal()) {
+ uint32 pos = idx - 1;
+ Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos,
+ state->get_Pattern_Idx());
+ return r;
+ }
+ }
+
+ return Match_Result(-1, -1, -1);
+}
+
+#ifdef DEBUG
+void
+ACS_Constructor::dump_text(const char* txtfile) const {
+ FILE* f = fopen(txtfile, "w+");
+ for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ ACS_State* s = *i;
+
+ fprintf(f, "S%d goto:{", s->Get_ID());
+ const ACS_Goto_Map& goto_func = s->Get_Goto_Map();
+
+ for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end();
+ i != e; i++) {
+ InputTy input = i->first;
+ ACS_State* tran = i->second;
+ if (isprint(input))
+ fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID());
+ else
+ fprintf(f, "%#x -> S:%d,", input, tran->Get_ID());
+ }
+ fprintf(f, "} ");
+
+ if (s->_fail_link) {
+ fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID());
+ }
+
+ if (s->_is_terminal) {
+ fprintf(f, ", terminal");
+ }
+
+ fprintf(f, "\n");
+ }
+ fclose(f);
+}
+
+void
+ACS_Constructor::dump_dot(const char *dotfile) const {
+ FILE* f = fopen(dotfile, "w+");
+ const char* indent = " ";
+
+ fprintf(f, "digraph G {\n");
+
+ // Emit node information
+ fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID());
+ for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ ACS_State *s = *i;
+ if (s->_is_terminal) {
+ fprintf(f, "%s%d [shape=doublecircle];\n", indent, s->Get_ID());
+ }
+ }
+ fprintf(f, "\n");
+
+ // Emit edge information
+ for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
+ e = _all_states.end(); i != e; i++) {
+ ACS_State* s = *i;
+ uint32 id = s->Get_ID();
+
+ const ACS_Goto_Map& m = s->Get_Goto_Map();
+ for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end();
+ ii != ee; ii++) {
+ InputTy input = ii->first;
+ ACS_State* tran = ii->second;
+ if (isalnum(input))
+ fprintf(f, "%s%d -> %d [label=%c];\n",
+ indent, id, tran->Get_ID(), input);
+ else
+ fprintf(f, "%s%d -> %d [label=\"%#x\"];\n",
+ indent, id, tran->Get_ID(), input);
+
+ }
+
+ // Emit fail-link
+ ACS_State* fl = s->Get_FailLink();
+ if (fl && fl != _root) {
+ fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n",
+ indent, id, fl->Get_ID());
+ }
+ }
+ fprintf(f, "}\n");
+ fclose(f);
+}
+#endif
+
+#ifdef VERIFY
+void
+ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r)
+ const {
+ if (r->begin >= 0) {
+ unsigned len = r->end - r->begin + 1;
+ int ptn_idx = r->pattern_idx;
+
+ ASSERT(ptn_idx >= 0 &&
+ len == get_ith_Pattern_Len(ptn_idx) &&
+ memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0);
+ }
+}
+
+void
+ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv,
+ int pattern_num) {
+ // calculate the total size needed to save all patterns.
+ //
+ int buf_size = 0;
+ for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; }
+
+ // HINT: patterns are delimited by '\0' in order to ease debugging.
+ buf_size += pattern_num;
+ ASSERT(_pattern_buf == 0);
+ _pattern_buf = new char[buf_size + 1];
+ #define MAGIC_NUM 0x5a
+ _pattern_buf[buf_size] = MAGIC_NUM;
+
+ int ofst = 0;
+ _pattern_lens.resize(pattern_num);
+ _pattern_vect.resize(pattern_num);
+ for (int i = 0; i < pattern_num; i++) {
+ int l = strlenv[i];
+ _pattern_lens[i] = l;
+ _pattern_vect[i] = _pattern_buf + ofst;
+
+ memcpy(_pattern_buf + ofst, strv[i], l);
+ ofst += l;
+ _pattern_buf[ofst++] = '\0';
+ }
+
+ ASSERT(_pattern_buf[buf_size] == MAGIC_NUM);
+ #undef MAGIC_NUM
+}
+
+#endif
diff --git a/modules/policy/lua-aho-corasick/ac_slow.hpp b/modules/policy/lua-aho-corasick/ac_slow.hpp
new file mode 100644
index 0000000..030b95d
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_slow.hpp
@@ -0,0 +1,158 @@
+#ifndef MY_AC_H
+#define MY_AC_H
+
+#include <string.h>
+#include <stdio.h>
+#include <map>
+#include <vector>
+#include <algorithm> // for std::sort
+#include "ac_util.hpp"
+
+// Forward decl. the acronym "ACS" stands for "Aho-Corasick Slow implementation"
+class ACS_State;
+class ACS_Constructor;
+class AhoCorasick;
+
+using namespace std;
+
+typedef std::map<InputTy, ACS_State*> ACS_Goto_Map;
+
+class Match_Result {
+public:
+ int begin;
+ int end;
+ int pattern_idx;
+ Match_Result(int b, int e, int p): begin(b), end(e), pattern_idx(p) {}
+};
+
+typedef pair<InputTy, ACS_State *> GotoPair;
+typedef vector<GotoPair> GotoVect;
+
+// Sorting functor
+class GotoSort {
+public:
+ bool operator() (const GotoPair& g1, const GotoPair& g2) {
+ return g1.first < g2.first;
+ }
+};
+
+class ACS_State {
+friend class ACS_Constructor;
+
+public:
+ ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0),
+ _is_terminal(false), _fail_link(0){}
+ ~ACS_State() {};
+
+ void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; }
+ ACS_State *Get_Goto(InputTy c) const {
+ ACS_Goto_Map::const_iterator iter = _goto_map.find(c);
+ return iter != _goto_map.end() ? (*iter).second : 0;
+ }
+
+ // Return all transitions sorted in the ascending order of their input.
+ void Get_Sorted_Gotos(GotoVect& Gotos) const {
+ const ACS_Goto_Map& m = _goto_map;
+ Gotos.clear();
+ for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end();
+ i != e; i++) {
+ Gotos.push_back(GotoPair(i->first, i->second));
+ }
+ sort(Gotos.begin(), Gotos.end(), GotoSort());
+ }
+
+ ACS_State* Get_FailLink() const { return _fail_link; }
+ uint32 Get_GotoNum() const { return _goto_map.size(); }
+ uint32 Get_ID() const { return _id; }
+ uint32 Get_Depth() const { return _depth; }
+ const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; }
+ bool is_Terminal() const { return _is_terminal; }
+ int get_Pattern_Idx() const {
+ ASSERT(is_Terminal() && _pattern_idx >= 0);
+ return _pattern_idx;
+ }
+
+private:
+ void set_Pattern_Idx(int idx) {
+ ASSERT(is_Terminal());
+ _pattern_idx = idx;
+ }
+
+private:
+ uint32 _id;
+ int _pattern_idx;
+ short _depth;
+ bool _is_terminal;
+ ACS_Goto_Map _goto_map;
+ ACS_State* _fail_link;
+};
+
+class ACS_Constructor {
+public:
+ ACS_Constructor();
+ ~ACS_Constructor();
+
+ void Construct(const char** strv, unsigned int* strlenv,
+ unsigned int strnum);
+
+ Match_Result Match(const char* s, uint32 len) const {
+ Match_Result r = MatchHelper(s, len);
+ Verify_Result(s, &r);
+ return r;
+ }
+
+ Match_Result Match(const char* s) const { return Match(s, strlen(s)); }
+
+#ifdef DEBUG
+ void dump_text(const char* = "ac.txt") const;
+ void dump_dot(const char* = "ac.dot") const;
+#endif
+ const ACS_State *Get_Root_State() const { return _root; }
+ const vector<ACS_State*>& Get_All_States() const {
+ return _all_states;
+ }
+
+ uint32 Get_Next_Node_Id() const { return _next_node_id; }
+ uint32 Get_State_Num() const { return _next_node_id - 1; }
+
+private:
+ void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx);
+ ACS_State* new_state();
+ void Propagate_faillink();
+
+ Match_Result MatchHelper(const char*, uint32 len) const;
+
+#ifdef VERIFY
+ void Verify_Result(const char* subject, const Match_Result* r) const;
+ void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len);
+ const char* get_ith_Pattern(unsigned i) const {
+ ASSERT(i < _pattern_vect.size());
+ return _pattern_vect.at(i);
+ }
+ unsigned get_ith_Pattern_Len(unsigned i) const {
+ ASSERT(i < _pattern_lens.size());
+ return _pattern_lens.at(i);
+ }
+#else
+ void Verify_Result(const char* subject, const Match_Result* r) const {
+ (void)subject; (void)r;
+ }
+ void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) {
+ (void)strv; (void)strlenv;
+ }
+#endif
+
+private:
+ ACS_State* _root;
+ vector<ACS_State*> _all_states;
+ unsigned char* _root_char;
+ uint32 _next_node_id;
+
+#ifdef VERIFY
+ char* _pattern_buf;
+ vector<int> _pattern_lens;
+ vector<char*> _pattern_vect;
+#endif
+};
+
+#endif
diff --git a/modules/policy/lua-aho-corasick/ac_util.hpp b/modules/policy/lua-aho-corasick/ac_util.hpp
new file mode 100644
index 0000000..56fd46c
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/ac_util.hpp
@@ -0,0 +1,69 @@
+/*
+ Copyright (c) 2014 CloudFlare, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following disclaimer
+ in the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of CloudFlare, Inc. nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+#ifndef AC_UTIL_H
+#define AC_UTIL_H
+
+#ifdef DEBUG
+#include <stdio.h> // for fprintf
+#include <stdlib.h> // for abort
+#endif
+
+typedef unsigned short uint16;
+typedef unsigned int uint32;
+typedef unsigned long uint64;
+typedef unsigned char InputTy;
+
+#ifdef DEBUG
+ // Usage examples: ASSERT(a > b), ASSERT(foo() && "Opps, foo() reutrn 0");
+ #define ASSERT(c) if (!(c))\
+ { fprintf(stderr, "%s:%d Assert: %s\n", __FILE__, __LINE__, #c); abort(); }
+#else
+ #define ASSERT(c) ((void)0)
+#endif
+
+#define likely(x) __builtin_expect((x),1)
+#define unlikely(x) __builtin_expect((x),0)
+
+#ifndef offsetof
+#define offsetof(st, m) ((size_t)(&((st *)0)->m))
+#endif
+
+typedef enum {
+ IMPL_SLOW_VARIANT = 1,
+ IMPL_FAST_VARIANT = 2,
+} impl_var_t;
+
+#define AC_MAGIC_NUM 0x5a
+typedef struct {
+ unsigned char magic_num;
+ unsigned char impl_variant;
+} buf_header_t;
+
+#endif //AC_UTIL_H
diff --git a/modules/policy/lua-aho-corasick/load_ac.lua b/modules/policy/lua-aho-corasick/load_ac.lua
new file mode 100644
index 0000000..eb70446
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/load_ac.lua
@@ -0,0 +1,90 @@
+-- Helper wrappring script for loading shared object libac.so (FFI interface)
+-- from package.cpath instead of LD_LIBRARTY_PATH.
+--
+
+local ffi = require 'ffi'
+ffi.cdef[[
+ void* ac_create(const char** str_v, unsigned int* strlen_v,
+ unsigned int v_len);
+ int ac_match2(void*, const char *str, int len);
+ void ac_free(void*);
+]]
+
+local _M = {}
+
+local string_gmatch = string.gmatch
+local string_match = string.match
+
+local ac_lib = nil
+local ac_create = nil
+local ac_match = nil
+local ac_free = nil
+
+--[[ Find shared object file package.cpath, obviating the need of setting
+ LD_LIBRARY_PATH
+]]
+local function find_shared_obj(cpath, so_name)
+ for k, v in string_gmatch(cpath, "[^;]+") do
+ local so_path = string_match(k, "(.*/)")
+ if so_path then
+ -- "so_path" could be nil. e.g, the dir path component is "."
+ so_path = so_path .. so_name
+
+ -- Don't get me wrong, the only way to know if a file exist is
+ -- trying to open it.
+ local f = io.open(so_path)
+ if f ~= nil then
+ io.close(f)
+ return so_path
+ end
+ end
+ end
+end
+
+function _M.load_ac_lib()
+ if ac_lib ~= nil then
+ return ac_lib
+ else
+ local so_path = find_shared_obj(package.cpath, "libac.so")
+ if so_path ~= nil then
+ ac_lib = ffi.load(so_path)
+ ac_create = ac_lib.ac_create
+ ac_match = ac_lib.ac_match2
+ ac_free = ac_lib.ac_free
+ return ac_lib
+ end
+ end
+end
+
+-- Create an Aho-Corasick instance, and return the instance if it was
+-- successful.
+function _M.create_ac(dict)
+ local strnum = #dict
+ if ac_lib == nil then
+ _M.load_ac_lib()
+ end
+
+ local str_v = ffi.new("const char *[?]", strnum)
+ local strlen_v = ffi.new("unsigned int [?]", strnum)
+
+ for i = 1, strnum do
+ local s = dict[i]
+ str_v[i - 1] = s
+ strlen_v[i - 1] = #s
+ end
+
+ local ac = ac_create(str_v, strlen_v, strnum);
+ if ac ~= nil then
+ return ffi.gc(ac, ac_free)
+ end
+end
+
+-- Return nil if str doesn't match the dictionary, else return non-nil.
+function _M.match(ac, str)
+ local r = ac_match(ac, str, #str);
+ if r >= 0 then
+ return r
+ end
+end
+
+return _M
diff --git a/modules/policy/lua-aho-corasick/mytest.cxx b/modules/policy/lua-aho-corasick/mytest.cxx
new file mode 100644
index 0000000..ef3dc87
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/mytest.cxx
@@ -0,0 +1,200 @@
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include "ac.h"
+
+using namespace std;
+
+/////////////////////////////////////////////////////////////////////////
+//
+// Test using strings from input files
+//
+/////////////////////////////////////////////////////////////////////////
+//
+class BigFileTester {
+public:
+ BigFileTester(const char* filepath);
+
+private:
+ void Genector
+privaete:
+ const char* _msg;
+ int _msg_len;
+ int _key_num; // number of strings in dictionary
+ int _key_len_idx;
+};
+
+/////////////////////////////////////////////////////////////////////////
+//
+// Simple (yet maybe tricky) testings
+//
+/////////////////////////////////////////////////////////////////////////
+//
+typedef struct {
+ const char* str;
+ const char* match;
+} StrPair;
+
+typedef struct {
+ const char* name;
+ const char** dict;
+ StrPair* strpairs;
+ int dict_len;
+ int strpair_num;
+} TestingCase;
+
+class Tests {
+public:
+ Tests(const char* name,
+ const char* dict[], int dict_len,
+ StrPair strpairs[], int strpair_num) {
+ if (!_tests)
+ _tests = new vector<TestingCase>;
+
+ TestingCase tc;
+ tc.name = name;
+ tc.dict = dict;
+ tc.strpairs = strpairs;
+ tc.dict_len = dict_len;
+ tc.strpair_num = strpair_num;
+ _tests->push_back(tc);
+ }
+
+ static vector<TestingCase>* Get_Tests() { return _tests; }
+ static void Erase_Tests() { delete _tests; _tests = 0; }
+
+private:
+ static vector<TestingCase> *_tests;
+};
+
+vector<TestingCase>* Tests::_tests = 0;
+
+static void
+simple_test(void) {
+ int total = 0;
+ int fail = 0;
+
+ vector<TestingCase> *tests = Tests::Get_Tests();
+ if (!tests)
+ return 0;
+
+ for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end();
+ i != e; i++) {
+ TestingCase& t = *i;
+ fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name);
+ for (int i = 0, e = t.dict_len, need_break=0; i < e; i++) {
+ fprintf(stdout, "%s, ", t.dict[i]);
+ if (need_break++ == 16) {
+ fputs("\n ", stdout);
+ need_break = 0;
+ }
+ }
+ fputs("]\n", stdout);
+
+ /* Create the dictionary */
+ int dict_len = t.dict_len;
+ ac_t* ac = ac_create(t.dict, dict_len);
+
+ for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) {
+ const StrPair& sp = t.strpairs[ii];
+ const char *str = sp.str; // the string to be matched
+ const char *match = sp.match;
+
+ fprintf(stdout, "[%3d] Testing '%s' : ", total, str);
+
+ int len = strlen(str);
+ ac_result_t r = ac_match(ac, str, len);
+ int m_b = r.match_begin;
+ int m_e = r.match_end;
+
+ // The return value per se is insane.
+ if (m_b > m_e ||
+ ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) {
+ fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e);
+ fail ++;
+ continue;
+ }
+
+ // If the string is not supposed to match the dictionary.
+ if (!match) {
+ if (m_b != -1 || m_e != -1) {
+ fail ++;
+ fprintf(stdout, "Not Supposed to match (%d, %d) \n",
+ m_b, m_e);
+ } else
+ fputs("Pass\n", stdout);
+ continue;
+ }
+
+ // The string or its substring is match the dict.
+ if (m_b >= len || m_b >= len) {
+ fail ++;
+ fprintf(stdout,
+ "Return value >= the length of the string (%d, %d)\n",
+ m_b, m_e);
+ continue;
+ } else {
+ int mlen = strlen(match);
+ if ((mlen != m_e - m_b + 1) ||
+ strncmp(str + m_b, match, mlen)) {
+ fail ++;
+ fprintf(stdout, "Fail\n");
+ } else
+ fprintf(stdout, "Pass\n");
+ }
+ }
+ fputs("\n", stdout);
+ ac_free(ac);
+ }
+
+ fprintf(stdout, "Total : %d, Fail %d\n", total, fail);
+
+ return fail ? -1 : 0;
+}
+
+int
+main (int argc, char** argv) {
+ int res = simple_test();
+ return res;
+};
+
+/* test 1*/
+const char *dict1[] = {"he", "she", "his", "her"};
+StrPair strpair1[] = {
+ {"he", "he"}, {"she", "she"}, {"his", "his"},
+ {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"},
+ {"shis2", "his"}, {"ahhe", "he"}
+};
+Tests test1("test 1",
+ dict1, sizeof(dict1)/sizeof(dict1[0]),
+ strpair1, sizeof(strpair1)/sizeof(strpair1[0]));
+
+/* test 2*/
+const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/
+StrPair strpair2[] = {{"The pot had a handle", 0}};
+Tests test2("test 2", dict2, 2, strpair2, 1);
+
+/* test 3*/
+const char *dict3[] = {"The"};
+StrPair strpair3[] = {{"The pot had a handle", "The"}};
+Tests test3("test 3", dict3, 1, strpair3, 1);
+
+/* test 4*/
+const char *dict4[] = {"pot"};
+StrPair strpair4[] = {{"The pot had a handle", "pot"}};
+Tests test4("test 4", dict4, 1, strpair4, 1);
+
+/* test 5*/
+const char *dict5[] = {"pot "};
+StrPair strpair5[] = {{"The pot had a handle", "pot "}};
+Tests test5("test 5", dict5, 1, strpair5, 1);
+
+/* test 6*/
+const char *dict6[] = {"ot h"};
+StrPair strpair6[] = {{"The pot had a handle", "ot h"}};
+Tests test6("test 6", dict6, 1, strpair6, 1);
+
+/* test 7*/
+const char *dict7[] = {"andle"};
+StrPair strpair7[] = {{"The pot had a handle", "andle"}};
+Tests test7("test 7", dict7, 1, strpair7, 1);
diff --git a/modules/policy/lua-aho-corasick/tests/Makefile b/modules/policy/lua-aho-corasick/tests/Makefile
new file mode 100644
index 0000000..54fd90f
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/Makefile
@@ -0,0 +1,65 @@
+OS := $(shell uname)
+ifeq ($(OS), Darwin)
+ SO_EXT := dylib
+else
+ SO_EXT := so
+endif
+
+.PHONY = all clean test runtest benchmark
+
+PROGRAM = ac_test
+BENCHMARK = ac_bench
+all: runtest
+
+CXXFLAGS = -O3 -g -march=native -Wall -DDEBUG
+MYCXXFLAGS = -MMD -I.. $(CXXFLAGS)
+%.o : %.cxx
+ $(CXX) $< -c $(MYCXXFLAGS)
+
+-include dep.cxx
+SRC = test_main.cxx ac_test_simple.cxx ac_test_aggr.cxx test_bigfile.cxx
+
+OBJ = ${SRC:.cxx=.o}
+
+-include test_dep.txt
+-include bench_dep.txt
+
+$(PROGRAM) $(BENCHMARK) : testinput/text.tar testinput/image.bin
+$(PROGRAM) : $(OBJ) ../libac.$(SO_EXT)
+ $(CXX) $(OBJ) -L.. -lac -o $@
+ -cat *.d > test_dep.txt
+
+$(BENCHMARK) : ac_bench.o ../libac.$(SO_EXT)
+ $(CXX) ac_bench.o -L.. -lac -o $@
+ -cat *.d > bench_dep.txt
+
+ifneq ($(OS), Darwin)
+runtest:$(PROGRAM)
+ LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/*
+
+benchmark:$(BENCHMARK)
+ LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./ac_bench
+
+else
+runtest:$(PROGRAM)
+ DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/*
+
+benchmark:$(BENCHMARK)
+ DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./ac_bench
+
+endif
+
+testinput/text.tar:
+ echo "download testing files (gcc tarball)..."
+ if [ ! -d testinput ] ; then mkdir testinput; fi
+ cd testinput && \
+ curl ftp://ftp.gnu.org/gnu/gcc/gcc-1.42.tar.gz -o text.tar.gz 2>/dev/null \
+ && gzip -d text.tar.gz
+
+testinput/image.bin:
+ echo "download testing files.."
+ if [ ! -d testinput ] ; then mkdir testinput; fi
+ curl http://www.3dvisionlive.com/sites/default/files/Curiosity_render_hiresb.jpg -o $@ 2>/dev/null
+
+clean:
+ -rm -f *.o *.d dep.txt $(PROGRAM) $(BENCHMARK)
diff --git a/modules/policy/lua-aho-corasick/tests/ac_bench.cxx b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx
new file mode 100644
index 0000000..421322c
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx
@@ -0,0 +1,519 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <sys/time.h>
+#include <time.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <dirent.h>
+#include <libgen.h>
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <getopt.h>
+
+#include <string>
+#include <vector>
+#include "ac.h"
+#include "ac_util.hpp"
+
+using namespace std;
+
+static bool SomethingWrong = false;
+
+static int iteration = 300;
+static string dict_dir;
+static string obj_file_dir;
+static bool print_help = false;
+static int piece_size = 1024;
+
+class PatternSet {
+public:
+ PatternSet(const char* filepath);
+ ~PatternSet() { Cleanup(); }
+
+ int getPatternNum() const { return _pat_num; }
+ const char** getPatternVector() const { return _patterns; }
+ unsigned int* getPatternLenVector() const { return _pat_len; }
+
+ const char* getErrMessage() const { return _errmsg; }
+ static bool isDictFile(const char* filepath) {
+ if (strncmp(basename(const_cast<char*>(filepath)), "dict", 4))
+ return false;
+ return true;
+ }
+
+private:
+ bool ExtractPattern(const char* filepath);
+ void Cleanup();
+
+ const char** _patterns;
+ unsigned int* _pat_len;
+ char* _mmap;
+ int _fd;
+ size_t _mmap_size;
+ int _pat_num;
+
+ const char* _errmsg;
+};
+
+bool
+PatternSet::ExtractPattern(const char* filepath) {
+ if (!isDictFile(filepath))
+ return false;
+
+ struct stat filestat;
+ if (stat(filepath, &filestat)) {
+ _errmsg = "fail to call stat()";
+ return false;
+ }
+
+ if (filestat.st_size > 4096 * 1024) {
+ /* It dosen't seem to be a dictionary file*/
+ _errmsg = "file too big?";
+ return false;
+ }
+
+ _fd = open(filepath, 0);
+ if (_fd == -1) {
+ _errmsg = "fail to open dictionary file";
+ return false;
+ }
+
+ _mmap_size = filestat.st_size;
+ _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE,
+ MAP_PRIVATE, _fd, 0);
+ if (_mmap == MAP_FAILED) {
+ _errmsg = "fail to call mmap";
+ return false;
+ }
+
+ const char* pat = _mmap;
+ vector<const char*> pat_vect;
+ vector<unsigned> pat_len_vect;
+
+ for (size_t i = 0, e = filestat.st_size; i < e; i++) {
+ if (_mmap[i] == '\r' || _mmap[i] == '\n') {
+ _mmap[i] = '\0';
+ int len = _mmap + i - pat;
+ if (len > 0) {
+ pat_vect.push_back(pat);
+ pat_len_vect.push_back(len);
+ }
+ pat = _mmap + i + 1;
+ }
+ }
+
+ ASSERT(pat_vect.size() == pat_len_vect.size());
+
+ int pat_num = pat_vect.size();
+ if (pat_num > 0) {
+ const char** p = _patterns = new const char*[pat_num];
+ int i = 0;
+ for (vector<const char*>::iterator iter = pat_vect.begin(),
+ iter_e = pat_vect.end(); iter != iter_e; ++iter) {
+ p[i++] = *iter;
+ }
+
+ i = 0;
+ unsigned int* q = _pat_len = new unsigned int[pat_num];
+ for (vector<unsigned>::iterator iter = pat_len_vect.begin(),
+ iter_e = pat_len_vect.end(); iter != iter_e; ++iter) {
+ q[i++] = *iter;
+ }
+ }
+
+ _pat_num = pat_num;
+ if (pat_num <= 0) {
+ _errmsg = "no pattern at all";
+ return false;
+ }
+
+ return true;
+}
+
+void
+PatternSet::Cleanup() {
+ if (_mmap != MAP_FAILED) {
+ munmap(_mmap, _mmap_size);
+ _mmap = (char*)MAP_FAILED;
+ _mmap_size = 0;
+ }
+
+ delete[] _patterns;
+ delete[] _pat_len;
+ if (_fd != -1)
+ close(_fd);
+ _pat_num = -1;
+}
+
+PatternSet::PatternSet(const char* filepath) {
+ _patterns = 0;
+ _pat_len = 0;
+ _mmap = (char*)MAP_FAILED;
+ _mmap_size = 0;
+ _pat_num = -1;
+ _errmsg = "";
+
+ if (!ExtractPattern(filepath))
+ Cleanup();
+}
+
+bool
+getFilesUnderDir(vector<string>& files, const char* path) {
+ files.clear();
+
+ DIR* dir = opendir(path);
+ if (!dir)
+ return false;
+
+ string path_dir = path;
+ path_dir += "/";
+
+ for (;;) {
+ struct dirent* entry = readdir(dir);
+ if (entry) {
+ string filepath = path_dir + entry->d_name;
+ struct stat file_stat;
+ if (stat(filepath.c_str(), &file_stat)) {
+ closedir(dir);
+ return false;
+ }
+
+ if (S_ISREG(file_stat.st_mode))
+ files.push_back(filepath);
+
+ continue;
+ }
+
+ if (errno) {
+ return false;
+ }
+ break;
+ }
+ closedir(dir);
+ return true;
+}
+
+class Timer {
+public:
+ Timer() {
+ my_clock_gettime(&_start);
+ _stop = _start;
+ _acc.tv_sec = 0;
+ _acc.tv_nsec = 0;
+ }
+
+ const Timer& operator += (const Timer& that) {
+ time_t sec = _acc.tv_sec + that._acc.tv_sec;
+ long nsec = _acc.tv_nsec + that._acc.tv_nsec;
+ if (nsec > 1000000000) {
+ nsec -= 1000000000;
+ sec += 1;
+ }
+ _acc.tv_sec = sec;
+ _acc.tv_nsec = nsec;
+ return *this;
+ }
+
+ // return duration in us
+ size_t getDuration() const {
+ return _acc.tv_sec * (size_t)1000000 + _acc.tv_nsec/1000;
+ }
+
+ void Start(bool acc=true) {
+ my_clock_gettime(&_start);
+ }
+
+ void Stop() {
+ my_clock_gettime(&_stop);
+ struct timespec t = CalcDuration();
+ _acc = add_duration(_acc, t);
+ }
+
+private:
+ int my_clock_gettime(struct timespec* t) {
+#ifdef __linux
+ return clock_gettime(CLOCK_PROCESS_CPUTIME_ID, t);
+#else
+ struct timeval tv;
+ int rc = gettimeofday(&tv, 0);
+ t->tv_sec = tv.tv_sec;
+ t->tv_nsec = tv.tv_usec * 1000;
+ return rc;
+#endif
+ }
+
+ struct timespec add_duration(const struct timespec& dur1,
+ const struct timespec& dur2) {
+ time_t sec = dur1.tv_sec + dur2.tv_sec;
+ long nsec = dur1.tv_nsec + dur2.tv_nsec;
+ if (nsec > 1000000000) {
+ nsec -= 1000000000;
+ sec += 1;
+ }
+ timespec t;
+ t.tv_sec = sec;
+ t.tv_nsec = nsec;
+
+ return t;
+ }
+
+ struct timespec CalcDuration() const {
+ timespec diff;
+ if ((_stop.tv_nsec - _start.tv_nsec)<0) {
+ diff.tv_sec = _stop.tv_sec - _start.tv_sec - 1;
+ diff.tv_nsec = 1000000000 + _stop.tv_nsec - _start.tv_nsec;
+ } else {
+ diff.tv_sec = _stop.tv_sec - _start.tv_sec;
+ diff.tv_nsec = _stop.tv_nsec - _start.tv_nsec;
+ }
+ return diff;
+ }
+
+ struct timespec _start;
+ struct timespec _stop;
+ struct timespec _acc;
+};
+
+class Benchmark {
+public:
+ Benchmark(const PatternSet& pat_set, const char* infile):
+ _pat_set(pat_set), _infile(infile) {
+ _mmap = (char*)MAP_FAILED;
+ _file_sz = 0;
+ _fd = -1;
+ }
+
+ ~Benchmark() {
+ if (_mmap != MAP_FAILED)
+ munmap(_mmap, _file_sz);
+ if (_fd != -1)
+ close(_fd);
+ }
+
+ bool Run(int iteration);
+ const Timer& getTimer() const { return _timer; }
+
+private:
+ const PatternSet& _pat_set;
+ const char* _infile;
+ char* _mmap;
+ int _fd;
+ size_t _file_sz; // input file size
+ Timer _timer;
+};
+
+bool
+Benchmark::Run(int iteration) {
+ if (_pat_set.getPatternNum() <= 0) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ if (_mmap == MAP_FAILED) {
+ struct stat filestat;
+ if (stat(_infile, &filestat)) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ if (!S_ISREG(filestat.st_mode)) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ _fd = open(_infile, 0);
+ if (_fd == -1)
+ return false;
+
+ _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE,
+ MAP_PRIVATE, _fd, 0);
+
+ if (_mmap == MAP_FAILED) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ _file_sz = filestat.st_size;
+ }
+
+ ac_t* ac = ac_create(_pat_set.getPatternVector(),
+ _pat_set.getPatternLenVector(),
+ _pat_set.getPatternNum());
+ if (!ac) {
+ SomethingWrong = true;
+ return false;
+ }
+
+ int piece_num = _file_sz/piece_size;
+
+ _timer.Start(false);
+
+ /* Stupid compiler may not be able to promote piece_size into register.
+ * Do it manually.
+ */
+ int piece_sz = piece_size;
+ for (int i = 0; i < iteration; i++) {
+ size_t match_ofst = 0;
+ for (int piece_idx = 0; piece_idx < piece_num; piece_idx ++) {
+ ac_match2(ac, _mmap + match_ofst, piece_sz);
+ match_ofst += piece_sz;
+ }
+ if (match_ofst != _file_sz)
+ ac_match2(ac, _mmap + match_ofst, _file_sz - match_ofst);
+ }
+ _timer.Stop();
+ return true;
+}
+
+const char* short_opt = "hd:f:i:p:";
+const struct option long_opts[] = {
+ {"help", no_argument, 0, 'h'},
+ {"iteration", required_argument, 0, 'i'},
+ {"dictionary-dir", required_argument, 0, 'd'},
+ {"obj-file-dir", required_argument, 0, 'f'},
+ {"piece-size", required_argument, 0, 'p'},
+};
+
+static void
+PrintHelp(const char* prog_name) {
+ const char* msg =
+"Usage %s [OPTIONS]\n"
+" -d, --dictionary-dir : specify the dictionary directory (./dict by default)\n"
+" -f, --obj-file-dir : specify the object file directory\n"
+" (./testinput by default)\n"
+" -i, --iteration : Run this many iteration for each pattern match\n"
+" -p, --piece-size : The size of 'piece' in byte. The input file is\n"
+" divided into pieces, and match function is working\n"
+" on one piece at a time. The default size of piece\n"
+" is 1k byte.\n";
+
+ fprintf(stdout, msg, prog_name);
+}
+
+static bool
+getOptions(int argc, char** argv) {
+ bool dict_dir_set = false;
+ bool objfile_dir_set = false;
+ int opt_index;
+
+ while (1) {
+ if (print_help) break;
+
+ int c = getopt_long(argc, argv, short_opt, long_opts, &opt_index);
+
+ if (c == -1) break;
+ if (c == 0) { c = long_opts[opt_index].val; }
+
+ switch(c) {
+ case 'h':
+ print_help = true;
+ break;
+
+ case 'i':
+ iteration = atol(optarg);
+ break;
+
+ case 'd':
+ dict_dir = optarg;
+ dict_dir_set = true;
+ break;
+
+ case 'f':
+ obj_file_dir = optarg;
+ objfile_dir_set = true;
+ break;
+
+ case 'p':
+ piece_size = atol(optarg);
+ break;
+
+ case '?':
+ default:
+ return false;
+ }
+ }
+
+ if (print_help)
+ return true;
+
+ string basedir(dirname(argv[0]));
+ if (!dict_dir_set)
+ dict_dir = basedir + "/dict";
+
+ if (!objfile_dir_set)
+ obj_file_dir = basedir + "/testinput";
+
+ return true;
+}
+
+int
+main(int argc, char** argv) {
+ if (!getOptions(argc, argv))
+ return -1;
+
+ if (print_help) {
+ PrintHelp(argv[0]);
+ return 0;
+ }
+
+#ifndef __linux
+ fprintf(stdout, "\n!!!WARNING: On this OS, the execution time is measured"
+ " by gettimeofday(2) which is imprecise!!!\n\n");
+#endif
+
+ fprintf(stdout, "Test with iteration = %d, piece size = %d, and",
+ iteration, piece_size);
+ fprintf(stdout, "\n dictionary dir = %s\n object file dir = %s\n\n",
+ dict_dir.c_str(), obj_file_dir.c_str());
+
+ vector<string> dict_files;
+ vector<string> input_files;
+
+ if (!getFilesUnderDir(dict_files, dict_dir.c_str())) {
+ fprintf(stdout, "fail to find dictionary files\n");
+ return -1;
+ }
+
+ if (!getFilesUnderDir(input_files, obj_file_dir.c_str())) {
+ fprintf(stdout, "fail to find test input files\n");
+ return -1;
+ }
+
+ for (vector<string>::iterator diter = dict_files.begin(),
+ diter_e = dict_files.end(); diter != diter_e; ++diter) {
+
+ const char* dict_name = diter->c_str();
+ if (!PatternSet::isDictFile(dict_name))
+ continue;
+
+ PatternSet ps(dict_name);
+ if (ps.getPatternNum() <= 0) {
+ fprintf(stdout, "fail to open dictionary file %s : %s\n",
+ dict_name, ps.getErrMessage());
+ SomethingWrong = true;
+ continue;
+ }
+
+ fprintf(stdout, "Using dictionary %s\n", dict_name);
+ Timer timer;
+ for (vector<string>::iterator iter = input_files.begin(),
+ iter_e = input_files.end(); iter != iter_e; ++iter) {
+ fprintf(stdout, " testing %s ... ", iter->c_str());
+ fflush(stdout);
+ Benchmark bm(ps, iter->c_str());
+ bm.Run(iteration);
+ const Timer& t = bm.getTimer();
+ timer += bm.getTimer();
+ fprintf(stdout, "elapsed %.3f\n", t.getDuration() / 1000000.0);
+ }
+
+ fprintf(stdout,
+ "\n==========================================================\n"
+ " Total Elapse %.3f\n\n", timer.getDuration() / 1000000.0);
+ }
+
+ return SomethingWrong ? -1 : 0;
+}
diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx
new file mode 100644
index 0000000..4ea02bc
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx
@@ -0,0 +1,135 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+using namespace std;
+
+namespace {
+class ACBigFileTester : public BigFileTester {
+public:
+ ACBigFileTester(const char* filepath) : BigFileTester(filepath){};
+
+private:
+ virtual buf_header_t* PM_Create(const char** strv, uint32* strlenv,
+ uint32 vect_len) {
+ return (buf_header_t*)ac_create(strv, strlenv, vect_len);
+ }
+
+ virtual void PM_Free(buf_header_t* PM) { ac_free(PM); }
+ virtual bool Run_Helper(buf_header_t* PM);
+};
+
+class ACTestAggressive: public ACTestBase {
+public:
+ ACTestAggressive(const vector<const char*>& files, const char* banner)
+ : ACTestBase(banner), _files(files) {}
+ virtual bool Run();
+
+private:
+ void PrintSummary(int total, int fail) {
+ fprintf(stdout, "Test count : %d, fail: %d\n", total, fail);
+ fflush(stdout);
+ }
+ vector<const char*> _files;
+};
+
+} // end of anonymous namespace
+
+bool
+ACBigFileTester::Run_Helper(buf_header_t* PM) {
+ int fail = 0;
+ // advance one chunk at a time.
+ int len = _msg_len;
+ int chunk_sz = _chunk_sz;
+
+ vector<const char*> c_style_keys;
+ for (int i = 0, e = _keys.size(); i != e; i++) {
+ const char* key = _keys[i].first;
+ int len = _keys[i].second;
+ char *t = new char[len+1];
+ memcpy(t, key, len);
+ t[len] = '\0';
+ c_style_keys.push_back(t);
+ }
+
+ for (int ofst = 0, chunk_idx = 0, chunk_num = _chunk_num;
+ chunk_idx < chunk_num; ofst += chunk_sz, chunk_idx++) {
+ const char* substring = _msg + ofst;
+ ac_result_t r = ac_match((ac_t*)(void*)PM, substring , len - ofst);
+ int m_b = r.match_begin;
+ int m_e = r.match_end;
+
+ if (m_b < 0 || m_e < 0 || m_e <= m_b || m_e >= len) {
+ fprintf(stdout, "fail to find match substring[%d:%d])\n",
+ ofst, len - 1);
+ fail ++;
+ continue;
+ }
+
+ const char* match_str = _msg + len;
+ int strstr_len = 0;
+ int key_idx = -1;
+
+ for (int i = 0, e = c_style_keys.size(); i != e; i++) {
+ const char* key = c_style_keys[i];
+ if (const char *m = strstr(substring, key)) {
+ if (m < match_str) {
+ match_str = m;
+ strstr_len = _keys[i].second;
+ key_idx = i;
+ }
+ }
+ }
+ ASSERT(key_idx != -1);
+ if ((match_str - substring != m_b)) {
+ fprintf(stdout,
+ "Fail to find match substring[%d:%d]),"
+ " expected to find match at offset %d instead of %d\n",
+ ofst, len - 1,
+ (int)(match_str - _msg), ofst + m_b);
+ fprintf(stdout, "%d vs %d (key idx %d)\n", strstr_len, m_e - m_b + 1, key_idx);
+ PrintStr(stdout, match_str, strstr_len);
+ fprintf(stdout, "\n");
+ PrintStr(stdout, _msg + ofst + m_b,
+ m_e - m_b + 1);
+ fprintf(stdout, "\n");
+ fail ++;
+ }
+ }
+ for (vector<const char*>::iterator i = c_style_keys.begin(),
+ e = c_style_keys.end(); i != e; i++) {
+ delete[] *i;
+ }
+
+ return fail == 0;
+}
+
+bool
+ACTestAggressive::Run() {
+ int fail = 0;
+ for (vector<const char*>::iterator i = _files.begin(), e = _files.end();
+ i != e; i++) {
+ ACBigFileTester bft(*i);
+ if (!bft.Run())
+ fail ++;
+ }
+ return fail == 0;
+}
+
+bool
+Run_AC_Aggressive_Test(const vector<const char*>& files) {
+ ACTestAggressive t(files, "AC Aggressive test");
+ t.PrintBanner();
+ return t.Run();
+}
diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx
new file mode 100644
index 0000000..fa2d7fd
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx
@@ -0,0 +1,275 @@
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+using namespace std;
+
+namespace {
+typedef struct {
+ const char* str;
+ const char* match;
+} StrPair;
+
+typedef enum {
+ MV_FIRST_MATCH = 0,
+ MV_LEFT_LONGEST = 1,
+} MatchVariant;
+
+typedef struct {
+ const char* name;
+ const char** dict;
+ StrPair* strpairs;
+ int dict_len;
+ int strpair_num;
+ MatchVariant match_variant;
+} TestingCase;
+
+class Tests {
+public:
+ Tests(const char* name,
+ const char* dict[], int dict_len,
+ StrPair strpairs[], int strpair_num,
+ MatchVariant mv = MV_FIRST_MATCH) {
+ if (!_tests)
+ _tests = new vector<TestingCase>;
+
+ TestingCase tc;
+ tc.name = name;
+ tc.dict = dict;
+ tc.strpairs = strpairs;
+ tc.dict_len = dict_len;
+ tc.strpair_num = strpair_num;
+ tc.match_variant = mv;
+ _tests->push_back(tc);
+ }
+
+ static vector<TestingCase>* Get_Tests() { return _tests; }
+ static void Erase_Tests() { delete _tests; _tests = 0; }
+
+private:
+ static vector<TestingCase> *_tests;
+};
+
+class LeftLongestTests : public Tests {
+public:
+ LeftLongestTests (const char* name, const char* dict[], int dict_len,
+ StrPair strpairs[], int strpair_num):
+ Tests(name, dict, dict_len, strpairs, strpair_num, MV_LEFT_LONGEST) {
+ }
+};
+
+vector<TestingCase>* Tests::_tests = 0;
+
+class ACTestSimple: public ACTestBase {
+public:
+ ACTestSimple(const char* banner) : ACTestBase(banner) {}
+ virtual bool Run();
+
+private:
+ void PrintSummary(int total, int fail) {
+ fprintf(stdout, "Test count : %d, fail: %d\n", total, fail);
+ fflush(stdout);
+ }
+};
+}
+
+bool
+ACTestSimple::Run() {
+ int total = 0;
+ int fail = 0;
+
+ vector<TestingCase> *tests = Tests::Get_Tests();
+ if (!tests) {
+ PrintSummary(0, 0);
+ return true;
+ }
+
+ for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end();
+ i != e; i++) {
+ TestingCase& t = *i;
+ int dict_len = t.dict_len;
+ unsigned int* strlen_v = new unsigned int[dict_len];
+
+ fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name);
+ for (int i = 0, need_break=0; i < dict_len; i++) {
+ const char* s = t.dict[i];
+ fprintf(stdout, "%s, ", s);
+ strlen_v[i] = strlen(s);
+ if (need_break++ == 16) {
+ fputs("\n ", stdout);
+ need_break = 0;
+ }
+ }
+ fputs("]\n", stdout);
+
+ /* Create the dictionary */
+ ac_t* ac = ac_create(t.dict, strlen_v, dict_len);
+ delete[] strlen_v;
+
+ for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) {
+ const StrPair& sp = t.strpairs[ii];
+ const char *str = sp.str; // the string to be matched
+ const char *match = sp.match;
+
+ fprintf(stdout, "[%3d] Testing '%s' : ", total, str);
+
+ int len = strlen(str);
+ ac_result_t r;
+ if (t.match_variant == MV_FIRST_MATCH)
+ r = ac_match(ac, str, len);
+ else if (t.match_variant == MV_LEFT_LONGEST)
+ r = ac_match_longest_l(ac, str, len);
+ else {
+ ASSERT(false && "Unknown variant");
+ }
+
+ int m_b = r.match_begin;
+ int m_e = r.match_end;
+
+ // The return value per se is insane.
+ if (m_b > m_e ||
+ ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) {
+ fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e);
+ fail ++;
+ continue;
+ }
+
+ // If the string is not supposed to match the dictionary.
+ if (!match) {
+ if (m_b != -1 || m_e != -1) {
+ fail ++;
+ fprintf(stdout, "Not Supposed to match (%d, %d) \n",
+ m_b, m_e);
+ } else
+ fputs("Pass\n", stdout);
+ continue;
+ }
+
+ // The string or its substring is match the dict.
+ if (m_b >= len || m_b >= len) {
+ fail ++;
+ fprintf(stdout,
+ "Return value >= the length of the string (%d, %d)\n",
+ m_b, m_e);
+ continue;
+ } else {
+ int mlen = strlen(match);
+ if ((mlen != m_e - m_b + 1) ||
+ strncmp(str + m_b, match, mlen)) {
+ fail ++;
+ fprintf(stdout, "Fail\n");
+ } else
+ fprintf(stdout, "Pass\n");
+ }
+ }
+ fputs("\n", stdout);
+ ac_free(ac);
+ }
+
+ PrintSummary(total, fail);
+ return fail == 0;
+}
+
+bool
+Run_AC_Simple_Test() {
+ ACTestSimple t("AC Simple test");
+ t.PrintBanner();
+ return t.Run();
+}
+
+//////////////////////////////////////////////////////////////////////////////
+//
+// Testing cases for first-match variant (i.e. test ac_match())
+//
+//////////////////////////////////////////////////////////////////////////////
+//
+
+/* test 1*/
+const char *dict1[] = {"he", "she", "his", "her"};
+StrPair strpair1[] = {
+ {"he", "he"}, {"she", "she"}, {"his", "his"},
+ {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"},
+ {"shis2", "his"}, {"ahhe", "he"}
+};
+Tests test1("test 1",
+ dict1, sizeof(dict1)/sizeof(dict1[0]),
+ strpair1, sizeof(strpair1)/sizeof(strpair1[0]));
+
+/* test 2*/
+const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/
+StrPair strpair2[] = {{"The pot had a handle", 0}};
+Tests test2("test 2", dict2, 2, strpair2, 1);
+
+/* test 3*/
+const char *dict3[] = {"The"};
+StrPair strpair3[] = {{"The pot had a handle", "The"}};
+Tests test3("test 3", dict3, 1, strpair3, 1);
+
+/* test 4*/
+const char *dict4[] = {"pot"};
+StrPair strpair4[] = {{"The pot had a handle", "pot"}};
+Tests test4("test 4", dict4, 1, strpair4, 1);
+
+/* test 5*/
+const char *dict5[] = {"pot "};
+StrPair strpair5[] = {{"The pot had a handle", "pot "}};
+Tests test5("test 5", dict5, 1, strpair5, 1);
+
+/* test 6*/
+const char *dict6[] = {"ot h"};
+StrPair strpair6[] = {{"The pot had a handle", "ot h"}};
+Tests test6("test 6", dict6, 1, strpair6, 1);
+
+/* test 7*/
+const char *dict7[] = {"andle"};
+StrPair strpair7[] = {{"The pot had a handle", "andle"}};
+Tests test7("test 7", dict7, 1, strpair7, 1);
+
+const char *dict8[] = {"aaab"};
+StrPair strpair8[] = {{"aaaaaaab", "aaab"}};
+Tests test8("test 8", dict8, 1, strpair8, 1);
+
+const char *dict9[] = {"haha", "z"};
+StrPair strpair9[] = {{"aaaaz", "z"}, {"z", "z"}};
+Tests test9("test 9", dict9, 2, strpair9, 2);
+
+/* test the case when input string dosen't contain even a single char
+ * of the pattern in dictionary.
+ */
+const char *dict10[] = {"abc"};
+StrPair strpair10[] = {{"cde", 0}};
+Tests test10("test 10", dict10, 1, strpair10, 1);
+
+
+//////////////////////////////////////////////////////////////////////////////
+//
+// Testing cases for first longest match variant (i.e.
+// test ac_match_longest_l())
+//
+//////////////////////////////////////////////////////////////////////////////
+//
+
+// This was actually first motivation for left-longest-match
+const char *dict100[] = {"Mozilla", "Mozilla Mobile"};
+StrPair strpair100[] = {{"User Agent containing string Mozilla Mobile", "Mozilla Mobile"}};
+LeftLongestTests test100("l_test 100", dict100, 2, strpair100, 1);
+
+// Dict with single char is tricky
+const char *dict101[] = {"a", "abc"};
+StrPair strpair101[] = {{"abcdef", "abc"}};
+LeftLongestTests test101("l_test 101", dict101, 2, strpair101, 1);
+
+// Testing case with partially overlapping patterns. The purpose is to
+// check if the fail-link leading from terminal state is correct.
+//
+// The fail-link leading from terminal-state does not matter in
+// match-first-occurrence variant, as it stop when a terminal is hit.
+//
+const char *dict102[] = {"abc", "bcdef"};
+StrPair strpair102[] = {{"abcdef", "bcdef"}};
+LeftLongestTests test102("l_test 102", dict102, 2, strpair102, 1);
diff --git a/modules/policy/lua-aho-corasick/tests/dict/README.txt b/modules/policy/lua-aho-corasick/tests/dict/README.txt
new file mode 100644
index 0000000..cd50b41
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/dict/README.txt
@@ -0,0 +1 @@
+This directory contains pattern set of benchmark purpose.
diff --git a/modules/policy/lua-aho-corasick/tests/dict/dict1.txt b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt
new file mode 100644
index 0000000..94085a9
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt
@@ -0,0 +1,11 @@
+false_return@
+forloop#haha
+wtfprogram
+mmaporunmap
+ThIs?Module!IsEssential
+struct rtlwtf
+gettIMEOfdayWrong
+edistribution_and_use_in_@source
+Copyright~#@
+while {!
+!%SQLinje
diff --git a/modules/policy/lua-aho-corasick/tests/load_ac_test.lua b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua
new file mode 100644
index 0000000..7fb7db9
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua
@@ -0,0 +1,82 @@
+-- This script is to test load_ac.lua
+--
+-- Some notes:
+-- 1. The purpose of this script is not to check if the libac.so work
+-- properly, it is to check if there are something stupid in load_ac.lua
+--
+-- 2. There are bunch of collectgarbage() calls, the purpose is to make
+-- sure the shared lib is not unloaded after GC.
+
+-- load_ac.lua looks up libac.so via package.cpath rather than LD_LIBRARY_PATH,
+-- prepend (instead of appending) some insane paths here to see if it quit
+-- prematurely.
+--
+package.cpath = ".;./?.so;" .. package.cpath
+
+local ac = require "load_ac"
+
+local ac_create = ac.create_ac
+local ac_match = ac.match
+local string_fmt = string.format
+local string_sub = string.sub
+
+local err_cnt = 0
+local function mytest(testname, dict, match, notmatch)
+ print(">Testing ", testname)
+
+ io.write(string_fmt("Dictionary: "));
+ for i=1, #dict do
+ io.write(string_fmt("%s, ", dict[i]))
+ end
+ print ""
+
+ local ac_inst = ac_create(dict);
+ collectgarbage()
+ for i=1, #match do
+ local str = match[i]
+ io.write(string_fmt("Matching %s, ", str))
+ local b = ac_match(ac_inst, str)
+ if b then
+ print "pass"
+ else
+ err_cnt = err_cnt + 1
+ print "fail"
+ end
+ collectgarbage()
+ end
+
+ if notmatch == nil then
+ return
+ end
+
+ collectgarbage()
+
+ for i = 1, #notmatch do
+ local str = notmatch[i]
+ io.write(string_fmt("*Matching %s, ", str))
+ local r = ac_match(ac_inst, str)
+ if r then
+ err_cnt = err_cnt + 1
+ print("fail")
+ else
+ print("succ")
+ end
+ collectgarbage()
+ end
+ ac_inst = nil
+ collectgarbage()
+end
+
+print("")
+print("====== Test to see if load_ac.lua works properly ========")
+
+mytest("test1",
+ {"he", "she", "his", "her", "str\0ing"},
+ -- matching cases
+ { "he", "she", "his", "hers", "ahe", "shhe", "shis2", "ahhe", "str\0ing" },
+
+ -- not matching case
+ {"str\0", "str"}
+ )
+
+os.exit((err_cnt == 0) and 0 or 1)
diff --git a/modules/policy/lua-aho-corasick/tests/lua_test.lua b/modules/policy/lua-aho-corasick/tests/lua_test.lua
new file mode 100644
index 0000000..cfe178f
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/lua_test.lua
@@ -0,0 +1,67 @@
+-- This script is to test ahocorasick.so not libac.so
+--
+local ac = require "ahocorasick"
+
+local ac_create = ac.create
+local ac_match = ac.match
+local string_fmt = string.format
+local string_sub = string.sub
+
+local err_cnt = 0
+local function mytest(testname, dict, match, notmatch)
+ print(">Testing ", testname)
+
+ io.write(string_fmt("Dictionary: "));
+ for i=1, #dict do
+ io.write(string_fmt("%s, ", dict[i]))
+ end
+ print ""
+
+ local ac_inst = ac_create(dict);
+ for i=1, #match do
+ local str = match[i][1]
+ local substr = match[i][2]
+ io.write(string_fmt("Matching %s, ", str))
+ local b, e = ac_match(ac_inst, str)
+ if b and e and (string_sub(str, b+1, e+1) == substr) then
+ print "pass"
+ else
+ err_cnt = err_cnt + 1
+ print "fail"
+ end
+ --print("gc is called")
+ collectgarbage()
+ end
+
+ if notmatch == nil then
+ return
+ end
+
+ for i = 1, #notmatch do
+ local str = notmatch[i]
+ io.write(string_fmt("*Matching %s, ", str))
+ local r = ac_match(ac_inst, str)
+ if r then
+ err_cnt = err_cnt + 1
+ print("fail")
+ else
+ print("succ")
+ end
+ collectgarbage()
+ end
+end
+
+mytest("test1",
+ {"he", "she", "his", "her", "str\0ing"},
+ -- matching cases
+ { {"he", "he"}, {"she", "she"}, {"his", "his"}, {"hers", "he"},
+ {"ahe", "he"}, {"shhe", "he"}, {"shis2", "his"}, {"ahhe", "he"},
+ {"str\0ing", "str\0ing"}
+ },
+
+ -- not matching case
+ {"str\0", "str"}
+
+ )
+
+os.exit((err_cnt == 0) and 0 or 1)
diff --git a/modules/policy/lua-aho-corasick/tests/test_base.hpp b/modules/policy/lua-aho-corasick/tests/test_base.hpp
new file mode 100644
index 0000000..7758371
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/test_base.hpp
@@ -0,0 +1,60 @@
+#ifndef TEST_BASE_H
+#define TEST_BASE_H
+
+#include <stdio.h>
+#include <string>
+#include <stdint.h>
+
+using namespace std;
+class ACTestBase {
+public:
+ ACTestBase(const char* name) :_banner(name) {}
+ virtual void PrintBanner() {
+ fprintf(stdout, "\n===== %s ====\n", _banner.c_str());
+ }
+
+ virtual bool Run() = 0;
+private:
+ string _banner;
+};
+
+typedef std::pair<const char*, int> StrInfo;
+class BigFileTester {
+public:
+ BigFileTester(const char* filepath);
+ virtual ~BigFileTester() { Cleanup(); }
+
+ bool Run();
+
+protected:
+ virtual buf_header_t* PM_Create(const char** strv, uint32_t* strlenv,
+ uint32_t vect_len) = 0;
+ virtual void PM_Free(buf_header_t*) = 0;
+ virtual bool Run_Helper(buf_header_t* PM) = 0;
+
+ // Return true if the '\0' is valid char of a string.
+ virtual bool Str_C_Style() { return true; }
+
+ bool GenerateKeys();
+ void Cleanup();
+ void PrintStr(FILE*, const char* str, int len);
+
+protected:
+ const char* _filepath;
+ int _fd;
+ vector<StrInfo> _keys;
+ char* _msg;
+ int _msg_len;
+ int _key_num; // number of strings in dictionary
+ int _chunk_sz;
+ int _chunk_num;
+
+ int _max_key_num;
+ int _key_min_len;
+ int _key_max_len;
+};
+
+extern bool Run_AC_Simple_Test();
+extern bool Run_AC_Aggressive_Test(const vector<const char*>& files);
+
+#endif
diff --git a/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx
new file mode 100644
index 0000000..f189d8d
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx
@@ -0,0 +1,167 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+///////////////////////////////////////////////////////////////////////////
+//
+// Implementation of BigFileTester
+//
+///////////////////////////////////////////////////////////////////////////
+//
+BigFileTester::BigFileTester(const char* filepath) {
+ _filepath = filepath;
+ _fd = -1;
+ _msg = (char*)MAP_FAILED;
+ _msg_len = 0;
+ _key_num = 0;
+ _chunk_sz = 0;
+ _chunk_num = 0;
+
+ _max_key_num = 100;
+ _key_min_len = 20;
+ _key_max_len = 80;
+}
+
+void
+BigFileTester::Cleanup() {
+ if (_msg != MAP_FAILED) {
+ munmap((void*)_msg, _msg_len);
+ _msg = (char*)MAP_FAILED;
+ _msg_len = 0;
+ }
+
+ if (_fd != -1) {
+ close(_fd);
+ _fd = -1;
+ }
+}
+
+bool
+BigFileTester::GenerateKeys() {
+ int chunk_sz = 4096;
+ int max_key_num = _max_key_num;
+ int key_min_len = _key_min_len;
+ int key_max_len = _key_max_len;
+
+ int t = _msg_len / chunk_sz;
+ int keynum = t > max_key_num ? max_key_num : t;
+
+ if (keynum <= 4) {
+ // file is too small
+ return false;
+ }
+ chunk_sz = _msg_len / keynum;
+ _chunk_sz = chunk_sz;
+
+ // For each chunck, "randomly" grab a sub-string searving
+ // as key.
+ int random_ofst[] = { 12, 30, 23, 15 };
+ int rofstsz = sizeof(random_ofst)/sizeof(random_ofst[0]);
+ int ofst = 0;
+ const char* msg = _msg;
+ _chunk_num = keynum - 1;
+ for (int idx = 0, e = _chunk_num; idx < e; idx++) {
+ const char* key = msg + ofst + idx % rofstsz;
+ int key_len = key_min_len + idx % (key_max_len - key_min_len);
+ _keys.push_back(StrInfo(key, key_len));
+ ofst += chunk_sz;
+ }
+ return true;
+}
+
+bool
+BigFileTester::Run() {
+ // Step 1: Bring the file into memory
+ fprintf(stdout, "Testing using file '%s'...\n", _filepath);
+
+ int fd = _fd = ::open(_filepath, O_RDONLY);
+ if (fd == -1) {
+ perror("open");
+ return false;
+ }
+
+ struct stat sb;
+ if (fstat(fd, &sb) == -1) {
+ perror("fstat");
+ return false;
+ }
+
+ if (!S_ISREG (sb.st_mode)) {
+ fprintf(stderr, "%s is not regular file\n", _filepath);
+ return false;
+ }
+
+ int ten_M = 1024 * 1024 * 10;
+ int map_sz = _msg_len = sb.st_size > ten_M ? ten_M : sb.st_size;
+ char* p = _msg =
+ (char*)mmap (0, map_sz, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0);
+ if (p == MAP_FAILED) {
+ perror("mmap");
+ return false;
+ }
+
+ // Get rid of '\0' if we are picky at it.
+ if (Str_C_Style()) {
+ for (int i = 0; i < map_sz; i++) { if (!p[i]) p[i] = 'a'; }
+ p[map_sz - 1] = 0;
+ }
+
+ // Step 2: "Fabricate" some keys from the file.
+ if (!GenerateKeys()) {
+ close(fd);
+ return false;
+ }
+
+ // Step 3: Create PM instance
+ const char** keys = new const char*[_keys.size()];
+ unsigned int* keylens = new unsigned int[_keys.size()];
+
+ int i = 0;
+ for (vector<StrInfo>::iterator si = _keys.begin(), se = _keys.end();
+ si != se; si++, i++) {
+ const StrInfo& strinfo = *si;
+ keys[i] = strinfo.first;
+ keylens[i] = strinfo.second;
+ }
+
+ buf_header_t* PM = PM_Create(keys, keylens, i);
+ delete[] keys;
+ delete[] keylens;
+
+ // Step 4: Run testing
+ bool res = Run_Helper(PM);
+ PM_Free(PM);
+
+ // Step 5: Clanup
+ munmap(p, map_sz);
+ _msg = (char*)MAP_FAILED;
+ close(fd);
+ _fd = -1;
+
+ fprintf(stdout, "%s\n", res ? "succ" : "fail");
+ return res;
+}
+
+void
+BigFileTester::PrintStr(FILE* f, const char* str, int len) {
+ fprintf(f, "{");
+ for (int i = 0; i < len; i++) {
+ unsigned char c = str[i];
+ if (isprint(c))
+ fprintf(f, "'%c', ", c);
+ else
+ fprintf(f, "%#x, ", c);
+ }
+ fprintf(f, "}");
+};
diff --git a/modules/policy/lua-aho-corasick/tests/test_main.cxx b/modules/policy/lua-aho-corasick/tests/test_main.cxx
new file mode 100644
index 0000000..b4f5225
--- /dev/null
+++ b/modules/policy/lua-aho-corasick/tests/test_main.cxx
@@ -0,0 +1,33 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <sys/mman.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+#include <stdio.h>
+#include <string.h>
+#include <vector>
+#include <string>
+#include "ac.h"
+#include "ac_util.hpp"
+#include "test_base.hpp"
+
+using namespace std;
+
+
+/////////////////////////////////////////////////////////////////////////
+//
+// Simple (yet maybe tricky) testings
+//
+/////////////////////////////////////////////////////////////////////////
+//
+int
+main (int argc, char** argv) {
+ bool succ = Run_AC_Simple_Test();
+
+ vector<const char*> files;
+ for (int i = 1; i < argc; i++) { files.push_back(argv[i]); }
+ succ = Run_AC_Aggressive_Test(files) && succ;
+
+ return succ ? 0 : -1;
+};
diff --git a/modules/policy/meson.build b/modules/policy/meson.build
new file mode 100644
index 0000000..fd51b0f
--- /dev/null
+++ b/modules/policy/meson.build
@@ -0,0 +1,49 @@
+# LUA module: policy
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+lua_mod_src += [
+ files('policy.lua'),
+]
+
+config_tests += [
+ ['policy', files('policy.test.lua')],
+ ['policy.slice', files('policy.slice.test.lua')],
+ ['policy.rpz', files('policy.rpz.test.lua')],
+]
+
+integr_tests += [
+ ['policy', meson.current_source_dir() / 'test.integr'],
+ ['policy.noipv6', meson.current_source_dir() / 'noipv6.test.integr'],
+ ['policy.noipvx', meson.current_source_dir() / 'noipvx.test.integr'],
+]
+
+# check git submodules were initialized
+lua_ac_submodule = run_command(['test', '-r',
+ '@0@/lua-aho-corasick/ac_fast.cxx'.format(meson.current_source_dir())])
+if lua_ac_submodule.returncode() != 0
+ error('run "git submodule update --init --recursive" to initialize git submodules')
+endif
+
+# compile bundled lua-aho-corasick as shared module
+lua_ac_src = files([
+ 'lua-aho-corasick/ac_fast.cxx',
+ 'lua-aho-corasick/ac_lua.cxx',
+ 'lua-aho-corasick/ac_slow.cxx',
+])
+
+lua_ac_lib = shared_module(
+ 'ahocorasick',
+ lua_ac_src,
+ cpp_args: [
+ '-fvisibility=hidden',
+ '-Wall',
+ '-fPIC',
+ ],
+ dependencies: [
+ luajit_inc,
+ ],
+ include_directories: mod_inc_dir,
+ name_prefix: '',
+ install: true,
+ install_dir: lib_dir,
+)
diff --git a/modules/policy/noipv6.test.integr/broken-ipv6.rpl b/modules/policy/noipv6.test.integr/broken-ipv6.rpl
new file mode 100644
index 0000000..cb0738d
--- /dev/null
+++ b/modules/policy/noipv6.test.integr/broken-ipv6.rpl
@@ -0,0 +1,47 @@
+; config options
+; SPDX-License-Identifier: GPL-3.0-or-later
+ stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET.
+CONFIG_END
+
+SCENARIO_BEGIN Test that IPv6 is not used by kresd.
+
+RANGE_BEGIN 0 100
+ ADDRESS ::1:2:3:4
+RANGE_END
+
+RANGE_BEGIN 0 100
+ ADDRESS 1.2.3.4
+
+ENTRY_BEGIN
+MATCH opcode qtype qname
+ADJUST copy_id
+REPLY QR NOERROR
+SECTION QUESTION
+www.test.org A
+SECTION ANSWER
+www.test.org 3600 A 4.3.2.1
+ENTRY_END
+
+RANGE_END
+
+
+STEP 10 QUERY
+ENTRY_BEGIN
+REPLY RD AD
+SECTION QUESTION
+www.test.org A
+ENTRY_END
+
+STEP 20 CHECK_ANSWER
+ENTRY_BEGIN
+MATCH all answer
+REPLY QR RD RA NOERROR
+SECTION QUESTION
+www.test.org A
+SECTION ANSWER
+www.test.org 3600 A 4.3.2.1
+SECTION AUTHORITY
+SECTION ADDITIONAL
+ENTRY_END
+
+SCENARIO_END
diff --git a/modules/policy/noipv6.test.integr/deckard.yaml b/modules/policy/noipv6.test.integr/deckard.yaml
new file mode 100644
index 0000000..4c1b6f8
--- /dev/null
+++ b/modules/policy/noipv6.test.integr/deckard.yaml
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+programs:
+- name: kresd
+ binary: kresd
+ additional:
+ - --noninteractive
+ templates:
+ - modules/policy/noipv6.test.integr/kresd_config.j2
+ - tests/integration/hints_zone.j2
+ configs:
+ - config
+ - hints
diff --git a/modules/policy/noipv6.test.integr/kresd_config.j2 b/modules/policy/noipv6.test.integr/kresd_config.j2
new file mode 100644
index 0000000..8e8c263
--- /dev/null
+++ b/modules/policy/noipv6.test.integr/kresd_config.j2
@@ -0,0 +1,59 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+{% raw %}
+net.ipv6 = false
+policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' })))
+
+-- make sure DNSSEC is turned off for tests
+trust_anchors.remove('.')
+
+-- Disable RFC5011 TA update
+if ta_update then
+ modules.unload('ta_update')
+end
+
+-- Disable RFC8145 signaling, scenario doesn't provide expected answers
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+
+-- Disable RFC8109 priming, scenario doesn't provide expected answers
+if priming then
+ modules.unload('priming')
+end
+
+-- Disable this module because it make one priming query
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+_hint_root_file('hints')
+cache.size = 2*MB
+verbose(true)
+{% endraw %}
+
+net = { '{{SELF_ADDR}}' }
+
+
+{% if QMIN == "false" %}
+option('NO_MINIMIZE', true)
+{% else %}
+option('NO_MINIMIZE', false)
+{% endif %}
+
+
+-- Self-checks on globals
+assert(help() ~= nil)
+assert(worker.id ~= nil)
+-- Self-checks on facilities
+assert(cache.count() == 0)
+assert(cache.stats() ~= nil)
+assert(cache.backends() ~= nil)
+assert(worker.stats() ~= nil)
+assert(net.interfaces() ~= nil)
+-- Self-checks on loaded stuff
+assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
+assert(#modules.list() > 0)
+-- Self-check timers
+ev = event.recurrent(1 * sec, function (ev) return 1 end)
+event.cancel(ev)
+ev = event.after(0, function (ev) return 1 end)
diff --git a/modules/policy/noipvx.test.integr/broken-ipvx.rpl b/modules/policy/noipvx.test.integr/broken-ipvx.rpl
new file mode 100644
index 0000000..60ed618
--- /dev/null
+++ b/modules/policy/noipvx.test.integr/broken-ipvx.rpl
@@ -0,0 +1,35 @@
+; config options
+; SPDX-License-Identifier: GPL-3.0-or-later
+ stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET.
+CONFIG_END
+
+SCENARIO_BEGIN Test that neither IPv6 nor IPv4 is used by kresd :-)
+
+RANGE_BEGIN 0 100
+ ADDRESS ::1:2:3:4
+RANGE_END
+
+RANGE_BEGIN 0 100
+ ADDRESS 1.2.3.4
+RANGE_END
+
+
+STEP 10 QUERY
+ENTRY_BEGIN
+REPLY RD AD
+SECTION QUESTION
+www.test.org A
+ENTRY_END
+
+STEP 20 CHECK_ANSWER
+ENTRY_BEGIN
+MATCH all answer
+REPLY QR RD RA SERVFAIL
+SECTION QUESTION
+www.test.org A
+SECTION ANSWER
+SECTION AUTHORITY
+SECTION ADDITIONAL
+ENTRY_END
+
+SCENARIO_END
diff --git a/modules/policy/noipvx.test.integr/deckard.yaml b/modules/policy/noipvx.test.integr/deckard.yaml
new file mode 100644
index 0000000..8178759
--- /dev/null
+++ b/modules/policy/noipvx.test.integr/deckard.yaml
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+programs:
+- name: kresd
+ binary: kresd
+ additional:
+ - --noninteractive
+ templates:
+ - modules/policy/noipvx.test.integr/kresd_config.j2
+ - tests/integration/hints_zone.j2
+ configs:
+ - config
+ - hints
diff --git a/modules/policy/noipvx.test.integr/kresd_config.j2 b/modules/policy/noipvx.test.integr/kresd_config.j2
new file mode 100644
index 0000000..771927b
--- /dev/null
+++ b/modules/policy/noipvx.test.integr/kresd_config.j2
@@ -0,0 +1,60 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+{% raw %}
+net.ipv4 = false
+net.ipv6 = false
+policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' })))
+
+-- make sure DNSSEC is turned off for tests
+trust_anchors.remove('.')
+
+-- Disable RFC5011 TA update
+if ta_update then
+ modules.unload('ta_update')
+end
+
+-- Disable RFC8145 signaling, scenario doesn't provide expected answers
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+
+-- Disable RFC8109 priming, scenario doesn't provide expected answers
+if priming then
+ modules.unload('priming')
+end
+
+-- Disable this module because it make one priming query
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+_hint_root_file('hints')
+cache.size = 2*MB
+verbose(true)
+{% endraw %}
+
+net = { '{{SELF_ADDR}}' }
+
+
+{% if QMIN == "false" %}
+option('NO_MINIMIZE', true)
+{% else %}
+option('NO_MINIMIZE', false)
+{% endif %}
+
+
+-- Self-checks on globals
+assert(help() ~= nil)
+assert(worker.id ~= nil)
+-- Self-checks on facilities
+assert(cache.count() == 0)
+assert(cache.stats() ~= nil)
+assert(cache.backends() ~= nil)
+assert(worker.stats() ~= nil)
+assert(net.interfaces() ~= nil)
+-- Self-checks on loaded stuff
+assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
+assert(#modules.list() > 0)
+-- Self-check timers
+ev = event.recurrent(1 * sec, function (ev) return 1 end)
+event.cancel(ev)
+ev = event.after(0, function (ev) return 1 end)
diff --git a/modules/policy/policy.lua b/modules/policy/policy.lua
new file mode 100644
index 0000000..e7c0cdb
--- /dev/null
+++ b/modules/policy/policy.lua
@@ -0,0 +1,998 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+local kres = require('kres')
+local ffi = require('ffi')
+
+local todname = kres.str2dname -- not available during module load otherwise
+
+-- Counter of unique rules
+local nextid = 0
+local function getruleid()
+ local newid = nextid
+ nextid = nextid + 1
+ return newid
+end
+
+-- Support for client sockets from inside policy actions
+local socket_client = function ()
+ return error("missing lua-cqueues library, can't create socket client")
+end
+local has_socket, socket = pcall(require, 'cqueues.socket')
+if has_socket then
+ socket_client = function (host, port)
+ local s, err, status
+
+ s = socket.connect({ host = host, port = port, type = socket.SOCK_DGRAM })
+ s:setmode('bn', 'bn')
+ status, err = pcall(s.connect, s)
+
+ if not status then
+ return status, err
+ end
+ return s
+ end
+end
+
+-- Split address and port from a combined string.
+local function addr_split_port(target, default_port)
+ assert(default_port and type(default_port) == 'number')
+ local port = ffi.new('uint16_t[1]', default_port)
+ local addr = ffi.new('char[47]') -- INET6_ADDRSTRLEN + 1
+ local ret = ffi.C.kr_straddr_split(target, addr, port)
+ if ret ~= 0 then
+ error('failed to parse address ' .. target)
+ end
+ return addr, tonumber(port[0])
+end
+
+-- String address@port -> sockaddr.
+local function addr2sock(target, default_port)
+ local addr, port = addr_split_port(target, default_port)
+ local sock = ffi.gc(ffi.C.kr_straddr_socket(addr, port, nil), ffi.C.free);
+ if sock == nil then
+ error("target '"..target..'" is not a valid IP address')
+ end
+ return sock
+end
+
+-- policy functions are defined below
+local policy = {}
+
+function policy.PASS(state, _)
+ return state
+end
+
+-- Mirror request elsewhere, and continue solving
+function policy.MIRROR(target)
+ local addr, port = addr_split_port(target, 53)
+ local sink, err = socket_client(ffi.string(addr), port)
+ if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end
+ return function(state, req)
+ if state == kres.FAIL then return state end
+ local query = req.qsource.packet
+ if query ~= nil then
+ sink:send(ffi.string(query.wire, query.size), 1, tonumber(query.size))
+ end
+ return -- Chain action to next
+ end
+end
+
+-- Override the list of nameservers (forwarders)
+local function set_nslist(req, list)
+ local ns_i = 0
+ for _, ns in ipairs(list) do
+ if ffi.C.kr_forward_add_target(req, ns) == 0 then
+ ns_i = ns_i + 1
+ end
+ end
+ if ns_i == 0 then
+ -- would use assert() but don't want to compose the message if not triggered
+ error('no usable address in NS set (check net.ipv4 and '
+ .. 'net.ipv6 config):\n' .. table_print(list, 2))
+ end
+end
+
+-- Forward request, and solve as stub query
+function policy.STUB(target)
+ local list = {}
+ if type(target) == 'table' then
+ for _, v in pairs(target) do
+ table.insert(list, addr2sock(v, 53))
+ end
+ else
+ table.insert(list, addr2sock(target, 53))
+ end
+ return function(state, req)
+ local qry = req:current()
+ -- Switch mode to stub resolver, do not track origin zone cut since it's not real authority NS
+ qry.flags.STUB = true
+ qry.flags.ALWAYS_CUT = false
+ set_nslist(req, list)
+ return state
+ end
+end
+
+-- Forward request and all subrequests to upstream; validate answers
+function policy.FORWARD(target)
+ local list = {}
+ if type(target) == 'table' then
+ for _, v in pairs(target) do
+ table.insert(list, addr2sock(v, 53))
+ end
+ else
+ table.insert(list, addr2sock(target, 53))
+ end
+ return function(state, req)
+ local qry = req:current()
+ req.options.FORWARD = true
+ req.options.NO_MINIMIZE = true
+ qry.flags.FORWARD = true
+ qry.flags.ALWAYS_CUT = false
+ qry.flags.NO_MINIMIZE = true
+ qry.flags.AWAIT_CUT = true
+ set_nslist(req, list)
+ return state
+ end
+end
+
+-- Forward request and all subrequests to upstream over TLS; validate answers
+function policy.TLS_FORWARD(targets)
+ if type(targets) ~= 'table' or #targets < 1 then
+ error('TLS_FORWARD argument must be a non-empty table')
+ end
+
+ local sockaddr_c_set = {}
+ local nslist = {} -- to persist in closure of the returned function
+ for idx, target in pairs(targets) do
+ if type(target) ~= 'table' or type(target[1]) ~= 'string' then
+ error(string.format('TLS_FORWARD configuration at position ' ..
+ '%d must be a table starting with an IP address', idx))
+ end
+ -- Note: some functions have checks with error() calls inside.
+ local sockaddr_c = addr2sock(target[1], 853)
+
+ -- Refuse repeated addresses in the same set.
+ local sockaddr_lua = ffi.string(sockaddr_c, ffi.C.kr_sockaddr_len(sockaddr_c))
+ if sockaddr_c_set[sockaddr_lua] then
+ error('TLS_FORWARD configuration cannot declare two configs for IP address '
+ .. target[1])
+ else
+ sockaddr_c_set[sockaddr_lua] = true;
+ end
+
+ table.insert(nslist, sockaddr_c)
+ net.tls_client(target)
+ end
+
+ return function(state, req)
+ local qry = req:current()
+ req.options.FORWARD = true
+ req.options.NO_MINIMIZE = true
+ qry.flags.FORWARD = true
+ qry.flags.ALWAYS_CUT = false
+ qry.flags.NO_MINIMIZE = true
+ qry.flags.AWAIT_CUT = true
+ req.options.TCP = true
+ qry.flags.TCP = true
+ set_nslist(req, nslist)
+ return state
+ end
+end
+
+-- Rewrite records in packet
+function policy.REROUTE(tbl, names)
+ -- Import renumbering rules
+ local ren = require('kres_modules.renumber')
+ local prefixes = {}
+ for from, to in pairs(tbl) do
+ table.insert(prefixes, names and ren.name(from, to) or ren.prefix(from, to))
+ end
+ -- Return rule closure
+ return ren.rule(prefixes)
+end
+
+-- Set and clear some query flags
+function policy.FLAGS(opts_set, opts_clear)
+ return function(_, req)
+ local qry = req:current()
+ ffi.C.kr_qflags_set (qry.flags, kres.mk_qflags(opts_set or {}))
+ ffi.C.kr_qflags_clear(qry.flags, kres.mk_qflags(opts_clear or {}))
+ return nil -- chain rule
+ end
+end
+
+local function mkauth_soa(answer, dname, mname, ttl)
+ if mname == nil then
+ mname = dname
+ end
+ return answer:put(dname, ttl or 10800, answer:qclass(), kres.type.SOA,
+ mname .. '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48')
+end
+
+-- Create answer with passed arguments
+function policy.ANSWER(rtable, nodata)
+ return function(_, req)
+ local qry = req:current()
+ local data = rtable[qry.stype]
+ if data == nil and nodata ~= true then
+ return nil
+ end
+ -- now we're certain we want to generate an answer
+
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ ffi.C.kr_pkt_make_auth_header(answer)
+ local ttl = (data or {}).ttl or 1
+ answer:rcode(kres.rcode.NOERROR)
+
+ if data == nil then -- want NODATA, i.e. just a SOA
+ answer:begin(kres.section.AUTHORITY)
+ local soa = rtable[kres.type.SOA]
+ if soa ~= nil then
+ answer:put(qry.sname, soa.ttl or ttl, qry.sclass, kres.type.SOA,
+ soa.rdata[1] or soa.rdata)
+ else
+ mkauth_soa(answer, kres.dname2wire(qry.sname), nil, ttl)
+ end
+ else
+ answer:begin(kres.section.ANSWER)
+ if type(data.rdata) == 'table' then
+ for _, rdato in ipairs(data.rdata) do
+ answer:put(qry.sname, ttl, qry.sclass, qry.stype, rdato)
+ end
+ else
+ answer:put(qry.sname, ttl, qry.sclass, qry.stype, data.rdata)
+ end
+ end
+ return kres.DONE
+ end
+end
+
+local dname_localhost = todname('localhost.')
+
+-- Rule for localhost. zone; see RFC6303, sec. 3
+local function localhost(_, req)
+ local qry = req:current()
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ ffi.C.kr_pkt_make_auth_header(answer)
+
+ local is_exact = ffi.C.knot_dname_is_equal(qry.sname, dname_localhost)
+
+ answer:rcode(kres.rcode.NOERROR)
+ answer:begin(kres.section.ANSWER)
+ if qry.stype == kres.type.AAAA then
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.AAAA,
+ '\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1')
+ elseif qry.stype == kres.type.A then
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\127\0\0\1')
+ elseif is_exact and qry.stype == kres.type.SOA then
+ mkauth_soa(answer, dname_localhost)
+ elseif is_exact and qry.stype == kres.type.NS then
+ answer:put(dname_localhost, 900, answer:qclass(), kres.type.NS, dname_localhost)
+ else
+ answer:begin(kres.section.AUTHORITY)
+ mkauth_soa(answer, dname_localhost)
+ end
+ return kres.DONE
+end
+
+local dname_rev4_localhost = todname('1.0.0.127.in-addr.arpa');
+local dname_rev4_localhost_apex = todname('127.in-addr.arpa');
+
+-- Rule for reverse localhost.
+-- Answer with locally served minimal 127.in-addr.arpa domain, only having
+-- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone.
+-- TODO: much of this would better be left to the hints module (or coordinated).
+local function localhost_reversed(_, req)
+ local qry = req:current()
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+
+ -- classify qry.sname:
+ local is_exact -- exact dname for localhost
+ local is_apex -- apex of a locally-served localhost zone
+ local is_nonterm -- empty non-terminal name
+ if ffi.C.knot_dname_in_bailiwick(qry.sname, todname('ip6.arpa.')) > 0 then
+ -- exact ::1 query (relying on the calling rule)
+ is_exact = true
+ is_apex = true
+ else
+ -- within 127.in-addr.arpa.
+ local labels = ffi.C.knot_dname_labels(qry.sname, nil)
+ if labels == 3 then
+ is_exact = false
+ is_apex = true
+ elseif labels == 4+2 and ffi.C.knot_dname_is_equal(
+ qry.sname, dname_rev4_localhost) then
+ is_exact = true
+ else
+ is_exact = false
+ is_apex = false
+ is_nonterm = ffi.C.knot_dname_in_bailiwick(dname_rev4_localhost, qry.sname) > 0
+ end
+ end
+
+ ffi.C.kr_pkt_make_auth_header(answer)
+ answer:rcode(kres.rcode.NOERROR)
+ answer:begin(kres.section.ANSWER)
+ if is_exact and qry.stype == kres.type.PTR then
+ answer:put(qry.sname, 900, answer:qclass(), kres.type.PTR, dname_localhost)
+ elseif is_apex and qry.stype == kres.type.SOA then
+ mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost)
+ elseif is_apex and qry.stype == kres.type.NS then
+ answer:put(dname_rev4_localhost_apex, 900, answer:qclass(), kres.type.NS,
+ dname_localhost)
+ else
+ if not is_nonterm then
+ answer:rcode(kres.rcode.NXDOMAIN)
+ end
+ answer:begin(kres.section.AUTHORITY)
+ mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost)
+ end
+ return kres.DONE
+end
+
+-- All requests
+function policy.all(action)
+ return function(_, _) return action end
+end
+
+-- Requests which QNAME matches given zone list (i.e. suffix match)
+function policy.suffix(action, zone_list)
+ local AC = require('ahocorasick')
+ local tree = AC.create(zone_list)
+ return function(_, query)
+ local match = AC.match(tree, query:name(), false)
+ if match ~= nil then
+ return action
+ end
+ return nil
+ end
+end
+
+-- Check for common suffix first, then suffix match (specialized version of suffix match)
+function policy.suffix_common(action, suffix_list, common_suffix)
+ local common_len = string.len(common_suffix)
+ local suffix_count = #suffix_list
+ return function(_, query)
+ -- Preliminary check
+ local qname = query:name()
+ if not string.find(qname, common_suffix, -common_len, true) then
+ return nil
+ end
+ -- String match
+ for i = 1, suffix_count do
+ local zone = suffix_list[i]
+ if string.find(qname, zone, -string.len(zone), true) then
+ return action
+ end
+ end
+ return nil
+ end
+end
+
+-- Filter QNAME pattern
+function policy.pattern(action, pattern)
+ return function(_, query)
+ if string.find(query:name(), pattern) then
+ return action
+ end
+ return nil
+ end
+end
+
+local function rpz_parse(action, path)
+ local rules = {}
+ local new_actions = {}
+ local action_map = {
+ -- RPZ Policy Actions
+ ['\0'] = action,
+ ['\1*\0'] = policy.ANSWER({}, true),
+ ['\012rpz-passthru\0'] = policy.PASS, -- the grammar...
+ ['\008rpz-drop\0'] = policy.DROP,
+ ['\012rpz-tcp-only\0'] = policy.TC,
+ -- Policy triggers @NYI@
+ }
+ -- RR types to be skipped; boolean denoting whether to throw a warning even for RPZ apex.
+ local rrtype_bad = {
+ [kres.type.DNAME] = true,
+ [kres.type.NS] = false,
+ [kres.type.SOA] = false,
+ [kres.type.DNSKEY] = true,
+ [kres.type.DS] = true,
+ [kres.type.RRSIG] = true,
+ [kres.type.NSEC] = true,
+ [kres.type.NSEC3] = true,
+ }
+ local parser = require('zonefile').new()
+ local ok, errstr = parser:open(path)
+ if not ok then
+ error(string.format('failed to parse "%s": %s', path, errstr or "unknown error"))
+ end
+ while true do
+ ok, errstr = parser:parse()
+ if errstr then
+ log('[poli] RPZ %s:%d: %s', path, tonumber(parser.line_counter), errstr)
+ end
+ if not ok then break end
+
+ local full_name = ffi.gc(ffi.C.knot_dname_copy(parser.r_owner, nil), ffi.C.free)
+ local rdata = ffi.string(parser.r_data, parser.r_data_length)
+ ffi.C.knot_dname_to_lower(full_name)
+
+ local prefix_labels = ffi.C.knot_dname_in_bailiwick(full_name, parser.zone_origin)
+ if prefix_labels < 0 then
+ log('[poli] RPZ %s:%d: RR owner "%s" outside the zone (ignored)',
+ path, tonumber(parser.line_counter), kres.dname2str(full_name))
+ goto continue
+ end
+
+ local bytes = ffi.C.knot_dname_size(full_name) - ffi.C.knot_dname_size(parser.zone_origin)
+ local name = ffi.string(full_name, bytes) .. '\0'
+
+ if parser.r_type == kres.type.CNAME then
+ if action_map[rdata] then
+ rules[name] = action_map[rdata]
+ else
+ log('[poli] RPZ %s:%d: CNAME with custom target in RPZ is not supported yet (ignored)',
+ path, tonumber(parser.line_counter))
+ end
+ else
+ if #name then
+ local is_bad = rrtype_bad[parser.r_type]
+ if is_bad == true or (is_bad == false and prefix_labels ~= 0) then
+ log('[poli] RPZ %s:%d warning: RR type %s is not allowed in RPZ (ignored)',
+ path, tonumber(parser.line_counter), kres.tostring.type[parser.r_type])
+ elseif is_bad == nil then
+ if new_actions[name] == nil then new_actions[name] = {} end
+ local act = new_actions[name][parser.r_type]
+ if act == nil then
+ new_actions[name][parser.r_type] = { ttl=parser.r_ttl, rdata=rdata }
+ else -- mutiple RRs: no reordering or deduplication
+ if type(act.rdata) ~= 'table' then
+ act.rdata = { act.rdata }
+ end
+ table.insert(act.rdata, rdata)
+ if parser.r_ttl ~= act.ttl then -- be conservative
+ log('[poli] RPZ %s:%d warning: different TTLs in a set (minimum taken)',
+ path, tonumber(parser.line_counter))
+ act.ttl = math.min(act.ttl, parser.r_ttl)
+ end
+ end
+ else
+ assert(is_bad == false and prefix_labels == 0)
+ end
+ end
+ end
+
+ ::continue::
+ end
+ collectgarbage()
+ for qname, rrsets in pairs(new_actions) do
+ rules[qname] = policy.ANSWER(rrsets, true)
+ end
+ return rules
+end
+
+-- Split path into dirname and basename (like the shell utilities)
+local function get_dir_and_file(path)
+ local dir, file = string.match(path, "(.*)/([^/]+)")
+
+ -- If regex doesn't match then path must be the file directly (i.e. doesn't contain '/')
+ -- This assumes that the file exists (rpz_parse() would fail if it doesn't)
+ if not dir and not file then
+ dir = '.'
+ file = path
+ end
+
+ return dir, file
+end
+
+-- RPZ policy set
+-- Create RPZ from zone file and optionally watch the file for changes
+function policy.rpz(action, path, watch)
+ local rules = rpz_parse(action, path)
+
+ if watch ~= false then
+ local has_notify, notify = pcall(require, 'cqueues.notify')
+ if has_notify then
+ local bit = require('bit')
+
+ local dir, file = get_dir_and_file(path)
+ local watcher = notify.opendir(dir)
+ watcher:add(file, bit.bxor(notify.CREATE, notify.MODIFY))
+
+ worker.coroutine(function ()
+ for _, name in watcher:changes() do
+ -- Limit to changes on file we're interested in
+ -- Watcher will also fire for changes to the directory itself
+ if name == file then
+ -- If the file changes then reparse and replace the existing ruleset
+ if verbose() then
+ log('[poli] RPZ reloading: ' .. name)
+ end
+ rules = rpz_parse(action, path)
+ end
+ end
+ end)
+ elseif watch then -- explicitly requested and failed
+ error('[poli] lua-cqueues required to watch and reload RPZ file')
+ elseif verbose() then
+ log('[poli] lua-cqueues required to watch and reload RPZ file, continuing without watching')
+ end
+ end
+
+ return function(_, query)
+ local label = query:name()
+ local rule = rules[label]
+ while rule == nil and string.len(label) > 0 do
+ label = string.sub(label, string.byte(label) + 2)
+ rule = rules['\1*'..label]
+ end
+ return rule
+ end
+end
+
+-- Apply an action when query belongs to a slice (determined by slice_func())
+function policy.slice(slice_func, ...)
+ local actions = {...}
+ if #actions <= 0 then
+ error('[poli] at least one action must be provided to policy.slice()')
+ end
+
+ return function(_, query)
+ local index = slice_func(query, #actions)
+ return actions[index]
+ end
+end
+
+-- Initializes slicing function that randomly assigns queries to a slice based on their registrable domain
+function policy.slice_randomize_psl(seed)
+ local has_psl, psl_lib = pcall(require, 'psl')
+ if not has_psl then
+ error('[poli] lua-psl is required for policy.slice_randomize_psl()')
+ end
+ -- load psl
+ local has_latest, psl = pcall(psl_lib.latest)
+ if not has_latest then -- compatiblity with lua-psl < 0.15
+ psl = psl_lib.builtin()
+ end
+
+ if seed == nil then
+ seed = os.time() / (3600 * 24 * 7)
+ end
+ seed = math.floor(seed) -- convert to int
+
+ return function(query, length)
+ assert(length > 0)
+
+ local domain = kres.dname2str(query:name())
+ if domain == nil then -- invalid data: auto-select first action
+ return 1
+ end
+ if domain:len() > 1 then --remove trailing dot
+ domain = domain:sub(0, -2)
+ end
+
+ -- do psl lookup for registrable domain
+ local reg_domain = psl:registrable_domain(domain)
+ if reg_domain == nil then -- fallback to unreg. domain
+ reg_domain = psl:unregistrable_domain(domain)
+ if reg_domain == nil then -- shouldn't happen: safe fallback
+ return 1
+ end
+ end
+
+ local rand_seed = seed
+ -- create deterministic seed for pseudo-random slice assignment
+ for i = 1, #reg_domain do
+ rand_seed = rand_seed + reg_domain:byte(i)
+ end
+
+ -- use lineral congruential generator with values from ANSI C
+ rand_seed = rand_seed % 0x80000000 -- ensure seed is positive 32b int
+ local rand = (1103515245 * rand_seed + 12345) % 0x10000
+ return 1 + rand % length
+ end
+end
+
+-- Prepare for making an answer from scratch. (Return the packet for convenience.)
+local function answer_clear(req)
+ -- If we're in postrules, previous resolving might have chosen some RRs
+ -- for inclusion in the answer, so we need to avoid those.
+ -- *_selected arrays are in mempool, so explicit deallocation is not necessary.
+ req.answ_selected.len = 0
+ req.auth_selected.len = 0
+ req.add_selected.len = 0
+
+ -- Let's be defensive and clear the answer, too.
+ local pkt = req:ensure_answer()
+ if pkt == nil then return nil end
+ pkt:clear_payload()
+ return pkt
+end
+
+function policy.DENY_MSG(msg)
+ if msg and (type(msg) ~= 'string' or #msg >= 255) then
+ error('DENY_MSG: optional msg must be string shorter than 256 characters')
+ end
+
+ return function (_, req)
+ -- Write authority information
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ ffi.C.kr_pkt_make_auth_header(answer)
+ answer:rcode(kres.rcode.NXDOMAIN)
+ answer:begin(kres.section.AUTHORITY)
+ mkauth_soa(answer, answer:qname())
+ if msg then
+ answer:begin(kres.section.ADDITIONAL)
+ answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT,
+ string.char(#msg) .. msg)
+
+ end
+ return kres.DONE
+ end
+end
+
+local function free_cb(func)
+ func:free()
+end
+
+local debug_logline_cb = ffi.cast('trace_log_f', function (_, msg)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+ -- msg typically ends with newline
+ io.write(ffi.string(msg))
+end)
+ffi.gc(debug_logline_cb, free_cb)
+
+local debug_logfinish_cb = ffi.cast('trace_callback_f', function (req)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+ ffi.C.kr_log_req(req, 0, 0, 'dbg',
+ 'following rrsets were marked as interesting:\n' ..
+ req:selected_tostring())
+ if req.answer ~= nil then
+ ffi.C.kr_log_req(req, 0, 0, 'dbg',
+ 'answer packet:\n' ..
+ tostring(req.answer))
+ else
+ ffi.C.kr_log_req(req, 0, 0, 'dbg',
+ 'answer packet DROPPED\n')
+ end
+end)
+ffi.gc(debug_logfinish_cb, free_cb)
+
+-- log request packet
+function policy.REQTRACE(_, req)
+ ffi.C.kr_log_req(req, 0, 0, 'dbg', 'request packet:\n%s',
+ tostring(req.qsource.packet))
+end
+
+function policy.DEBUG_ALWAYS(state, req)
+ policy.QTRACE(state, req)
+ req:trace_chain_callbacks(debug_logline_cb, debug_logfinish_cb)
+ policy.REQTRACE(state, req)
+end
+
+local debug_stashlog_cb = ffi.cast('trace_log_f', function (req, msg)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+
+ -- stash messages for conditional logging in trace_finish
+ local stash = req:vars()['policy_debug_stash']
+ table.insert(stash, ffi.string(msg))
+end)
+ffi.gc(debug_stashlog_cb, free_cb)
+
+-- buffer verbose logs and print then only if test() returns a truthy value
+function policy.DEBUG_IF(test)
+ local debug_finish_cb = ffi.cast('trace_callback_f', function (cbreq)
+ jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed
+ if test(cbreq) then
+ debug_logfinish_cb(cbreq) -- unconditional version
+ local stash = cbreq:vars()['policy_debug_stash']
+ io.write(table.concat(stash, ''))
+ end
+ end)
+ ffi.gc(debug_finish_cb, function (func) func:free() end)
+
+ return function (state, req)
+ req:vars()['policy_debug_stash'] = {}
+ policy.QTRACE(state, req)
+ req:trace_chain_callbacks(debug_stashlog_cb, debug_finish_cb)
+ policy.REQTRACE(state, req)
+ return
+ end
+end
+
+policy.DEBUG_CACHE_MISS = policy.DEBUG_IF(
+ function(req)
+ return not req:all_from_cache()
+ end
+)
+
+policy.DENY = policy.DENY_MSG() -- compatibility with < 2.0
+
+function policy.DROP(_, req)
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ return kres.FAIL
+end
+
+function policy.REFUSE(_, req)
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ answer:rcode(kres.rcode.REFUSED)
+ answer:ad(false)
+ return kres.DONE
+end
+
+function policy.TC(state, req)
+ -- Avoid non-UDP queries
+ if req.qsource.addr == nil or req.qsource.flags.tcp then
+ return state
+ end
+
+ local answer = answer_clear(req)
+ if answer == nil then return nil end
+ answer:tc(1)
+ answer:ad(false)
+ return kres.DONE
+end
+
+function policy.QTRACE(_, req)
+ local qry = req:current()
+ req.options.TRACE = true
+ qry.flags.TRACE = true
+ return -- this allows to continue iterating over policy list
+end
+
+-- Evaluate packet in given rules to determine policy action
+function policy.evaluate(rules, req, query, state)
+ for i = 1, #rules do
+ local rule = rules[i]
+ if not rule.suspended then
+ local action = rule.cb(req, query)
+ if action ~= nil then
+ rule.count = rule.count + 1
+ local next_state = action(state, req)
+ if next_state then -- Not a chain rule,
+ return next_state -- stop on first match
+ end
+ end
+ end
+ end
+ return
+end
+
+-- Add rule to policy list
+function policy.add(rule, postrule)
+ -- Compatibility with 1.0.0 API
+ -- it will be dropped in 1.2.0
+ if rule == policy then
+ rule = postrule
+ postrule = nil
+ end
+ -- End of compatibility shim
+ local desc = {id=getruleid(), cb=rule, count=0}
+ table.insert(postrule and policy.postrules or policy.rules, desc)
+ return desc
+end
+
+-- Remove rule from a list
+local function delrule(rules, id)
+ for i, r in ipairs(rules) do
+ if r.id == id then
+ table.remove(rules, i)
+ return true
+ end
+ end
+ return false
+end
+
+-- Delete rule from policy list
+function policy.del(id)
+ if not delrule(policy.rules, id) then
+ if not delrule(policy.postrules, id) then
+ return false
+ end
+ end
+ return true
+end
+
+-- Convert list of string names to domain names
+function policy.todnames(names)
+ for i, v in ipairs(names) do
+ names[i] = kres.str2dname(v)
+ end
+ return names
+end
+
+-- RFC1918 Private, local, broadcast, test and special zones
+-- Considerations: RFC6761, sec 6.1.
+-- https://www.iana.org/assignments/locally-served-dns-zones
+local private_zones = {
+ -- RFC6303
+ '10.in-addr.arpa.',
+ '16.172.in-addr.arpa.',
+ '17.172.in-addr.arpa.',
+ '18.172.in-addr.arpa.',
+ '19.172.in-addr.arpa.',
+ '20.172.in-addr.arpa.',
+ '21.172.in-addr.arpa.',
+ '22.172.in-addr.arpa.',
+ '23.172.in-addr.arpa.',
+ '24.172.in-addr.arpa.',
+ '25.172.in-addr.arpa.',
+ '26.172.in-addr.arpa.',
+ '27.172.in-addr.arpa.',
+ '28.172.in-addr.arpa.',
+ '29.172.in-addr.arpa.',
+ '30.172.in-addr.arpa.',
+ '31.172.in-addr.arpa.',
+ '168.192.in-addr.arpa.',
+ '0.in-addr.arpa.',
+ '254.169.in-addr.arpa.',
+ '2.0.192.in-addr.arpa.',
+ '100.51.198.in-addr.arpa.',
+ '113.0.203.in-addr.arpa.',
+ '255.255.255.255.in-addr.arpa.',
+ -- RFC7793
+ '64.100.in-addr.arpa.',
+ '65.100.in-addr.arpa.',
+ '66.100.in-addr.arpa.',
+ '67.100.in-addr.arpa.',
+ '68.100.in-addr.arpa.',
+ '69.100.in-addr.arpa.',
+ '70.100.in-addr.arpa.',
+ '71.100.in-addr.arpa.',
+ '72.100.in-addr.arpa.',
+ '73.100.in-addr.arpa.',
+ '74.100.in-addr.arpa.',
+ '75.100.in-addr.arpa.',
+ '76.100.in-addr.arpa.',
+ '77.100.in-addr.arpa.',
+ '78.100.in-addr.arpa.',
+ '79.100.in-addr.arpa.',
+ '80.100.in-addr.arpa.',
+ '81.100.in-addr.arpa.',
+ '82.100.in-addr.arpa.',
+ '83.100.in-addr.arpa.',
+ '84.100.in-addr.arpa.',
+ '85.100.in-addr.arpa.',
+ '86.100.in-addr.arpa.',
+ '87.100.in-addr.arpa.',
+ '88.100.in-addr.arpa.',
+ '89.100.in-addr.arpa.',
+ '90.100.in-addr.arpa.',
+ '91.100.in-addr.arpa.',
+ '92.100.in-addr.arpa.',
+ '93.100.in-addr.arpa.',
+ '94.100.in-addr.arpa.',
+ '95.100.in-addr.arpa.',
+ '96.100.in-addr.arpa.',
+ '97.100.in-addr.arpa.',
+ '98.100.in-addr.arpa.',
+ '99.100.in-addr.arpa.',
+ '100.100.in-addr.arpa.',
+ '101.100.in-addr.arpa.',
+ '102.100.in-addr.arpa.',
+ '103.100.in-addr.arpa.',
+ '104.100.in-addr.arpa.',
+ '105.100.in-addr.arpa.',
+ '106.100.in-addr.arpa.',
+ '107.100.in-addr.arpa.',
+ '108.100.in-addr.arpa.',
+ '109.100.in-addr.arpa.',
+ '110.100.in-addr.arpa.',
+ '111.100.in-addr.arpa.',
+ '112.100.in-addr.arpa.',
+ '113.100.in-addr.arpa.',
+ '114.100.in-addr.arpa.',
+ '115.100.in-addr.arpa.',
+ '116.100.in-addr.arpa.',
+ '117.100.in-addr.arpa.',
+ '118.100.in-addr.arpa.',
+ '119.100.in-addr.arpa.',
+ '120.100.in-addr.arpa.',
+ '121.100.in-addr.arpa.',
+ '122.100.in-addr.arpa.',
+ '123.100.in-addr.arpa.',
+ '124.100.in-addr.arpa.',
+ '125.100.in-addr.arpa.',
+ '126.100.in-addr.arpa.',
+ '127.100.in-addr.arpa.',
+
+ -- RFC6303
+ -- localhost_reversed handles ::1
+ '0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.',
+ 'd.f.ip6.arpa.',
+ '8.e.f.ip6.arpa.',
+ '9.e.f.ip6.arpa.',
+ 'a.e.f.ip6.arpa.',
+ 'b.e.f.ip6.arpa.',
+ '8.b.d.0.1.0.0.2.ip6.arpa.',
+ -- RFC8375
+ 'home.arpa.',
+}
+policy.todnames(private_zones)
+
+-- @var Default rules
+policy.rules = {}
+policy.postrules = {}
+policy.special_names = {
+ -- XXX: beware of special_names_optim() when modifying these filters
+ {
+ cb=policy.suffix_common(policy.DENY_MSG(
+ 'Blocking is mandated by standards, see references on '
+ .. 'https://www.iana.org/assignments/'
+ .. 'locally-served-dns-zones/locally-served-dns-zones.xhtml'),
+ private_zones, todname('arpa.')),
+ count=0
+ },
+ {
+ cb=policy.suffix(policy.DENY_MSG(
+ 'Blocking is mandated by standards, see references on '
+ .. 'https://www.iana.org/assignments/'
+ .. 'special-use-domain-names/special-use-domain-names.xhtml'),
+ {
+ todname('test.'),
+ todname('onion.'),
+ todname('invalid.'),
+ todname('local.'), -- RFC 8375.4
+ }),
+ count=0
+ },
+ {
+ cb=policy.suffix(localhost, {dname_localhost}),
+ count=0
+ },
+ {
+ cb=policy.suffix_common(localhost_reversed, {
+ todname('127.in-addr.arpa.'),
+ todname('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.')},
+ todname('arpa.')),
+ count=0
+ },
+}
+
+-- Return boolean; false = no special name may apply, true = some might apply.
+-- The point is to *efficiently* filter almost all QNAMEs that do not apply.
+local function special_names_optim(req, sname)
+ local qname_size = req.qsource.packet.qname_size
+ if qname_size < 9 then return true end -- don't want to special-case bad array access
+ local root = sname + qname_size - 1
+ return
+ -- .a???. or .t???.
+ (root[-5] == 4 and (root[-4] == 97 or root[-4] == 116))
+ -- .on???. or .in?????. or lo???. or *ost.
+ or (root[-6] == 5 and root[-5] == 111 and root[-4] == 110)
+ or (root[-8] == 7 and root[-7] == 105 and root[-6] == 110)
+ or (root[-6] == 5 and root[-5] == 108 and root[-4] == 111)
+ or (root[-3] == 111 and root[-2] == 115 and root[-1] == 116)
+end
+
+-- Top-down policy list walk until we hit a match
+-- the caller is responsible for reordering policy list
+-- from most specific to least specific.
+-- Some rules may be chained, in this case they are evaluated
+-- as a dependency chain, e.g. r1,r2,r3 -> r3(r2(r1(state)))
+policy.layer = {
+ begin = function(state, req)
+ -- Don't act on "finished" cases.
+ if bit.band(state, bit.bor(kres.FAIL, kres.DONE)) ~= 0 then return state end
+ local qry = req:initial() -- same as :current() but more descriptive
+ return policy.evaluate(policy.rules, req, qry, state)
+ or (special_names_optim(req, qry.sname)
+ and policy.evaluate(policy.special_names, req, qry, state))
+ or state
+ end,
+ finish = function(state, req)
+ -- Optimization for the typical case
+ if #policy.postrules == 0 then return state end
+ -- Don't act on failed cases.
+ if bit.band(state, kres.FAIL) ~= 0 then return state end
+ return policy.evaluate(policy.postrules, req, req:initial(), state) or state
+ end
+}
+
+return policy
diff --git a/modules/policy/policy.rpz.test.lua b/modules/policy/policy.rpz.test.lua
new file mode 100644
index 0000000..047b27f
--- /dev/null
+++ b/modules/policy/policy.rpz.test.lua
@@ -0,0 +1,56 @@
+
+local function prepare_cache()
+ cache.open(100*MB)
+ cache.clear()
+
+ local ffi = require('ffi')
+ local c = kres.context().cache
+
+ local passthru_addr = '\127\0\0\9'
+ rr_passthru = kres.rrset(todname('rpzpassthru.'), kres.type.A, kres.class.IN, 3600999999)
+ assert(rr_passthru:add_rdata(passthru_addr, #passthru_addr))
+ assert(c:insert(rr_passthru, nil, ffi.C.KR_RANK_SECURE + ffi.C.KR_RANK_AUTH))
+
+ c:commit()
+end
+
+local check_answer = require('test_utils').check_answer
+
+local function test_rpz()
+ check_answer('"CNAME ." return NXDOMAIN',
+ 'nxdomain.', kres.type.A, kres.rcode.NXDOMAIN)
+ check_answer('"CNAME *." return NODATA',
+ 'nodata.', kres.type.A, kres.rcode.NOERROR, {})
+ check_answer('"CNAME *. on wildcard" return NODATA',
+ 'nodata.nxdomain.', kres.type.A, kres.rcode.NOERROR, {})
+ check_answer('"CNAME rpz-drop." be dropped',
+ 'rpzdrop.', kres.type.A, kres.rcode.SERVFAIL)
+ check_answer('"CNAME rpz-passthru" return A rrset',
+ 'rpzpassthru.', kres.type.A, kres.rcode.NOERROR, '127.0.0.9')
+ check_answer('"A 192.168.5.5" return local A rrset',
+ 'rra.', kres.type.A, kres.rcode.NOERROR, '192.168.5.5')
+ check_answer('"A 192.168.6.6" with suffixed zone name in owner return local A rrset',
+ 'rra-zonename-suffix.', kres.type.A, kres.rcode.NOERROR, '192.168.6.6')
+ check_answer('"A 192.168.7.7" with suffixed zone name in owner return local A rrset',
+ 'testdomain.rra.', kres.type.A, kres.rcode.NOERROR, '192.168.7.7')
+ check_answer('non existing AAAA on rra domain return NODATA',
+ 'rra.', kres.type.AAAA, kres.rcode.NOERROR, {})
+ check_answer('"A 192.168.8.8" and domain with uppercase and lowercase letters',
+ 'case.sensitive.', kres.type.A, kres.rcode.NOERROR, '192.168.8.8')
+ check_answer('"A 192.168.8.8" and domain with uppercase and lowercase letters',
+ 'CASe.SENSItivE.', kres.type.A, kres.rcode.NOERROR, '192.168.8.8')
+ check_answer('two AAAA records',
+ 'two.records.', kres.type.AAAA, kres.rcode.NOERROR,
+ {'2001:db8::2', '2001:db8::1'})
+end
+
+net.ipv4 = false
+net.ipv6 = false
+
+prepare_cache()
+
+policy.add(policy.rpz(policy.DENY, 'policy.test.rpz'))
+
+return {
+ test_rpz,
+}
diff --git a/modules/policy/policy.slice.test.lua b/modules/policy/policy.slice.test.lua
new file mode 100644
index 0000000..89c1b05
--- /dev/null
+++ b/modules/policy/policy.slice.test.lua
@@ -0,0 +1,109 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+-- check lua-psl is available
+local has_psl = pcall(require, 'psl')
+if not has_psl then
+ os.exit(77) -- SKIP policy.slice
+end
+
+-- unload modules which are not related to this test
+if ta_update then
+ modules.unload('ta_update')
+end
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+if priming then
+ modules.unload('priming')
+end
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+local kres = require('kres')
+
+local slice_queries = {
+ {},
+ {},
+ {},
+}
+
+local function sliceaction(index)
+ return function(_, req)
+ -- log query
+ local qry = req:current()
+ local name = kres.dname2str(qry:name())
+ local count = slice_queries[index][name]
+ if not count then
+ count = 0
+ end
+ slice_queries[index][name] = count + 1
+
+ -- refuse query
+ local answer = req:ensure_answer()
+ if answer == nil then return nil end
+ answer:rcode(kres.rcode.REFUSED)
+ answer:ad(false)
+ return kres.DONE
+ end
+end
+
+-- configure slicing
+policy.add(policy.slice(
+ policy.slice_randomize_psl(0),
+ sliceaction(1),
+ sliceaction(2),
+ sliceaction(3)
+))
+
+local function check_slice(desc, qname, qtype, expected_slice, expected_count)
+ callback = function()
+ count = slice_queries[expected_slice][qname]
+ qtype_str = kres.tostring.type[qtype]
+ same(count, expected_count, desc .. qname .. ' ' .. qtype_str)
+ end
+ resolve(qname, qtype, kres.class.IN, {}, callback)
+end
+
+local function test_randomize_psl()
+ local desc = 'randomize_psl() same qname, different qtype (same slice): '
+ check_slice(desc, 'example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'example.com.', kres.type.AAAA, 2, 2)
+ check_slice(desc, 'example.com.', kres.type.MX, 2, 3)
+ check_slice(desc, 'example.com.', kres.type.NS, 2, 4)
+
+ desc = 'randomize_psl() subdomain in same slice: '
+ check_slice(desc, 'a.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'b.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'c.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'a.a.example.com.', kres.type.A, 2, 1)
+ check_slice(desc, 'a.a.a.example.com.', kres.type.A, 2, 1)
+
+ desc = 'randomize_psl() different qnames in different slices: '
+ check_slice(desc, 'example2.com.', kres.type.A, 1, 1)
+ check_slice(desc, 'example5.com.', kres.type.A, 3, 1)
+
+ desc = 'randomize_psl() check unregistrable domains: '
+ check_slice(desc, '.', kres.type.A, 3, 1)
+ check_slice(desc, 'com.', kres.type.A, 1, 1)
+ check_slice(desc, 'cz.', kres.type.A, 2, 1)
+ check_slice(desc, 'co.uk.', kres.type.A, 1, 1)
+
+ desc = 'randomize_psl() check multi-level reg. domains: '
+ check_slice(desc, 'example.co.uk.', kres.type.A, 3, 1)
+ check_slice(desc, 'a.example.co.uk.', kres.type.A, 3, 1)
+ check_slice(desc, 'b.example.co.uk.', kres.type.MX, 3, 1)
+ check_slice(desc, 'example2.co.uk.', kres.type.A, 2, 1)
+
+ desc = 'randomize_psl() reg. domain - always ends up in slice: '
+ check_slice(desc, 'fdsnnsdfvkdn.com.', kres.type.A, 3, 1)
+ check_slice(desc, 'bdfbd.cz.', kres.type.A, 1, 1)
+ check_slice(desc, 'nrojgvn.net.', kres.type.A, 1, 1)
+ check_slice(desc, 'jnojtnbv.engineer.', kres.type.A, 2, 1)
+ check_slice(desc, 'dfnjonfdsjg.gov.', kres.type.A, 1, 1)
+ check_slice(desc, 'okfjnosdfgjn.mil.', kres.type.A, 1, 1)
+ check_slice(desc, 'josdhnojn.test.', kres.type.A, 2, 1)
+end
+
+return {
+ test_randomize_psl,
+}
diff --git a/modules/policy/policy.test.lua b/modules/policy/policy.test.lua
new file mode 100644
index 0000000..69dda1f
--- /dev/null
+++ b/modules/policy/policy.test.lua
@@ -0,0 +1,145 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+-- setup resolver
+-- policy module should be loaded by default, do not load it explicitly
+
+-- do not attempt to contact outside world, operate only on cache
+net.ipv4 = false
+net.ipv6 = false
+-- do not listen, test is driven by config code
+env.KRESD_NO_LISTEN = true
+
+-- test for default configuration
+local function test_tls_forward()
+ boom(policy.TLS_FORWARD, {}, 'TLS_FORWARD without arguments')
+ boom(policy.TLS_FORWARD, {'1'}, 'TLS_FORWARD with non-table argument')
+ boom(policy.TLS_FORWARD, {{}}, 'TLS_FORWARD with empty table')
+ boom(policy.TLS_FORWARD, {{{}}}, 'TLS_FORWARD with empty target table')
+ boom(policy.TLS_FORWARD, {{{bleble=''}}}, 'TLS_FORWARD with invalid parameters in table')
+
+ boom(policy.TLS_FORWARD, {{'1'}}, 'TLS_FORWARD with invalid IP address')
+ boom(policy.TLS_FORWARD, {{{'::1', bleble=''}}}, 'TLS_FORWARD with valid IP and invalid parameters')
+ boom(policy.TLS_FORWARD, {{{'127.0.0.1'}}}, 'TLS_FORWARD with missing auth parameters')
+
+ ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true}}), 'TLS_FORWARD with no authentication')
+ boom(policy.TLS_FORWARD, {{{'100:dead::', insecure=true},
+ {'100:DEAD:0::', insecure=true}
+ }}, 'TLS_FORWARD with duplicate IP addresses is not allowed')
+ ok(policy.TLS_FORWARD({{'100:dead::2', insecure=true},
+ {'100:dead::2@443', insecure=true}
+ }), 'TLS_FORWARD with duplicate IP addresses but different ports is allowed')
+ ok(policy.TLS_FORWARD({{'100:dead::3', insecure=true},
+ {'100:beef::3', insecure=true}
+ }), 'TLS_FORWARD with different IPv6 addresses is allowed')
+ ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true},
+ {'127.0.0.2', insecure=true}
+ }), 'TLS_FORWARD with different IPv4 addresses is allowed')
+
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256=''}}}, 'TLS_FORWARD with empty pin_sha256')
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='č'}}}, 'TLS_FORWARD with bad pin_sha256')
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='d161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}},
+ 'TLS_FORWARD with bad pin_sha256 (short base64)')
+ boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='bbd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}},
+ 'TLS_FORWARD with bad pin_sha256 (long base64)')
+ ok(policy.TLS_FORWARD({
+ {'::1', pin_sha256='g1PpXsxqPchz2tH6w9kcvVXqzQ0QclhInFP2+VWOqic='}
+ }), 'TLS_FORWARD with base64 pin_sha256')
+ ok(policy.TLS_FORWARD({
+ {'::1', pin_sha256={
+ 'ev1xcdU++dY9BlcX0QoKeaUftvXQvNIz/PCss1Z/3ek=',
+ 'SgnqTFcvYduWX7+VUnlNFT1gwSNvQdZakH7blChIRbM=',
+ 'bd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg=',
+ }}}), 'TLS_FORWARD with a table of pins')
+
+ -- ok(policy.TLS_FORWARD({{'::1', hostname='test.', ca_file='/tmp/ca.crt'}}), 'TLS_FORWARD with hostname + CA cert')
+ ok(policy.TLS_FORWARD({{'::1', hostname='test.'}}),
+ 'TLS_FORWARD with just hostname (use system CA store)')
+ boom(policy.TLS_FORWARD, {{{'::1', ca_file='/tmp/ca.crt'}}},
+ 'TLS_FORWARD with just CA cert')
+ boom(policy.TLS_FORWARD, {{{'::1', hostname='', ca_file='/tmp/ca.crt'}}},
+ 'TLS_FORWARD with empty hostname + CA cert')
+ boom(policy.TLS_FORWARD, {
+ {{'::1', hostname='test.', ca_file='/dev/a_file_which_surely_does_NOT_exist!'}}
+ }, 'TLS_FORWARD with hostname + unreadable CA cert')
+
+end
+
+local function test_slice()
+ boom(policy.slice, {function() end}, 'policy.slice() without any action')
+ ok(policy.slice, {function() end, policy.FORWARD, policy.FORWARD})
+end
+
+local function mirror_parser(srv, cv, nqueries)
+ local ffi = require('ffi')
+ local test_end = 0
+ local TIMEOUT = 5 -- seconds
+
+ while true do
+ local input = srv:xread('*a', 'bn', TIMEOUT)
+ if not input then
+ cv:signal()
+ return false, 'mirror: timeout'
+ end
+ --print(#input, input)
+ -- convert query to knot_pkt_t
+ local wire = ffi.cast("void *", input)
+ local pkt = ffi.gc(ffi.C.knot_pkt_new(wire, #input, nil), ffi.C.knot_pkt_free)
+ if not pkt then
+ cv:signal()
+ return false, 'mirror: packet allocation error'
+ end
+
+ local result = ffi.C.knot_pkt_parse(pkt, 0)
+ if result ~= 0 then
+ cv:signal()
+ return false, 'mirror: packet parse error'
+ end
+ --print(pkt)
+ test_end = test_end + 1
+
+ if test_end == nqueries then
+ cv:signal()
+ return true, 'packet mirror pass'
+ end
+
+ end
+end
+
+local function test_mirror()
+ local kluautil = require('kluautil')
+ local socket = require('cqueues.socket')
+ local cond = require('cqueues.condition')
+ local cv = cond.new()
+ local queries = {}
+ local srv = socket.listen({
+ host = "127.0.0.1",
+ port = 36659,
+ type = socket.SOCK_DGRAM,
+ })
+ -- binary mode, no buffering
+ srv:setmode('bn', 'bn')
+
+ queries["bla.mujtest.cz."] = kres.type.AAAA
+ queries["bla.mujtest2.cz."] = kres.type.AAAA
+
+ -- UDP server for test
+ worker.bg_worker.cq:wrap(function()
+ local err, msg = mirror_parser(srv, cv, kluautil.kr_table_len(queries))
+
+ ok(err, msg)
+ end)
+
+ policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest.cz.'})))
+ policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest2.cz.'})))
+
+ for name, rtype in pairs(queries) do
+ resolve(name, rtype)
+ end
+
+ cv:wait()
+end
+
+return {
+ test_tls_forward,
+ test_mirror,
+ test_slice,
+}
diff --git a/modules/policy/policy.test.rpz b/modules/policy/policy.test.rpz
new file mode 100644
index 0000000..80b7106
--- /dev/null
+++ b/modules/policy/policy.test.rpz
@@ -0,0 +1,18 @@
+$ORIGIN testdomain.
+$TTL 30
+testdomain. SOA nonexistent.testdomain. testdomain. 1 12h 15m 3w 2h
+ NS nonexistent.testdomain.
+
+nxdomain CNAME .
+nodata CNAME *.
+*.nxdomain CNAME *.
+rpzdrop CNAME rpz-drop.
+rpzpassthru CNAME rpz-passthru.
+rra A 192.168.5.5
+rra-zonename-suffix A 192.168.6.6
+testdomain.rra.testdomain. A 192.168.7.7
+CaSe.SeNSiTiVe A 192.168.8.8
+
+two.records AAAA 2001:db8::2
+two.records AAAA 2001:db8::1
+
diff --git a/modules/policy/test.integr/deckard.yaml b/modules/policy/test.integr/deckard.yaml
new file mode 100644
index 0000000..9c6cb70
--- /dev/null
+++ b/modules/policy/test.integr/deckard.yaml
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+programs:
+- name: kresd
+ binary: kresd
+ additional:
+ - --noninteractive
+ templates:
+ - modules/policy/test.integr/kresd_config.j2
+ - tests/integration/hints_zone.j2
+ configs:
+ - config
+ - hints
diff --git a/modules/policy/test.integr/kresd_config.j2 b/modules/policy/test.integr/kresd_config.j2
new file mode 100644
index 0000000..25cf0ee
--- /dev/null
+++ b/modules/policy/test.integr/kresd_config.j2
@@ -0,0 +1,58 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+{% raw %}
+policy.add(policy.suffix(policy.REFUSE, {todname('refuse.example.com')}))
+
+-- make sure DNSSEC is turned off for tests
+trust_anchors.remove('.')
+
+-- Disable RFC5011 TA update
+if ta_update then
+ modules.unload('ta_update')
+end
+
+-- Disable RFC8145 signaling, scenario doesn't provide expected answers
+if ta_signal_query then
+ modules.unload('ta_signal_query')
+end
+
+-- Disable RFC8109 priming, scenario doesn't provide expected answers
+if priming then
+ modules.unload('priming')
+end
+
+-- Disable this module because it make one priming query
+if detect_time_skew then
+ modules.unload('detect_time_skew')
+end
+
+_hint_root_file('hints')
+cache.size = 2*MB
+verbose(true)
+{% endraw %}
+
+net = { '{{SELF_ADDR}}' }
+
+
+{% if QMIN == "false" %}
+option('NO_MINIMIZE', true)
+{% else %}
+option('NO_MINIMIZE', false)
+{% endif %}
+
+
+-- Self-checks on globals
+assert(help() ~= nil)
+assert(worker.id ~= nil)
+-- Self-checks on facilities
+assert(cache.count() == 0)
+assert(cache.stats() ~= nil)
+assert(cache.backends() ~= nil)
+assert(worker.stats() ~= nil)
+assert(net.interfaces() ~= nil)
+-- Self-checks on loaded stuff
+assert(net.list()[1].transport.ip == '{{SELF_ADDR}}')
+assert(#modules.list() > 0)
+-- Self-check timers
+ev = event.recurrent(1 * sec, function (ev) return 1 end)
+event.cancel(ev)
+ev = event.after(0, function (ev) return 1 end)
diff --git a/modules/policy/test.integr/refuse.rpl b/modules/policy/test.integr/refuse.rpl
new file mode 100644
index 0000000..ee8daa4
--- /dev/null
+++ b/modules/policy/test.integr/refuse.rpl
@@ -0,0 +1,25 @@
+; SPDX-License-Identifier: GPL-3.0-or-later
+; config options
+ stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET.
+CONFIG_END
+
+SCENARIO_BEGIN Test refuse policy
+
+STEP 10 QUERY
+ENTRY_BEGIN
+REPLY RD AD
+SECTION QUESTION
+www.refuse.example.com. IN A
+ENTRY_END
+
+STEP 20 CHECK_ANSWER
+ENTRY_BEGIN
+MATCH all answer
+; AD must not be set in the answer
+REPLY QR RD RA REFUSED
+SECTION QUESTION
+www.refuse.example.com. IN A
+SECTION ANSWER
+ENTRY_END
+
+SCENARIO_END