Diffstat (limited to 'modules/policy')
42 files changed, 5986 insertions, 0 deletions
diff --git a/modules/policy/.packaging/test.config b/modules/policy/.packaging/test.config new file mode 100644 index 0000000..60c9ddc --- /dev/null +++ b/modules/policy/.packaging/test.config @@ -0,0 +1,4 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +modules.load('policy') +assert(policy) +quit() diff --git a/modules/policy/README.rst b/modules/policy/README.rst new file mode 100644 index 0000000..202aaba --- /dev/null +++ b/modules/policy/README.rst @@ -0,0 +1,774 @@ +.. SPDX-License-Identifier: GPL-3.0-or-later + +.. default-domain:: py +.. module:: policy + +.. _mod-policy: + + +Query policies +============== + +This module can block, rewrite, or alter inbound queries based on user-defined policies. It does not affect queries generated by the resolver itself, e.g. when following CNAME chains. + +Each policy *rule* has two parts: a *filter* and an *action*. A *filter* selects which queries will be affected by the policy, and an *action* modifies queries matching the associated filter. + +Typically a rule is defined as follows: ``filter(action(action parameters), filter parameters)``. For example, a filter can be ``suffix``, which matches queries whose suffix part is in the specified set, and one of the possible actions is :any:`policy.DENY`, which denies resolution. These are combined together into ``policy.suffix(policy.DENY, {todname('badguy.example.')})``. The rule takes effect when it is added to the rule table using ``policy.add()``; please see the examples below. + +This module is enabled by default because it implements mandatory :rfc:`6761` logic. +When no rule applies to a query, built-in rules for `special-use <https://www.iana.org/assignments/special-use-domain-names/special-use-domain-names.xhtml>`_ and `locally-served <http://www.iana.org/assignments/locally-served-dns-zones>`_ domain names are applied. +These rules can be overridden by action :any:`policy.PASS`. 
For debugging purposes you can also add ``modules.unload('policy')`` to your config to unload the module. + + +Filters +------- +A *filter* selects which queries will be affected by specified Actions_. There are several policy filters available in the ``policy.`` table: + +.. function:: all(action) + + Always applies the action. + +.. function:: pattern(action, pattern) + + Applies the action if query name matches a `Lua regular expression <http://lua-users.org/wiki/PatternsTutorial>`_. + +.. function:: suffix(action, suffix_table) + + Applies the action if query name suffix matches one of suffixes in the table (useful for "is domain in zone" rules). + + .. code-block:: lua + + policy.add(policy.suffix(policy.DENY, policy.todnames({'example.com', 'example.net'}))) + +.. note:: For speed this filter requires domain names in DNS wire format, not textual representation, so each label in the name must be prefixed with its length. Always use convenience function :func:`policy.todnames` for automatic conversion from strings! See the example above. + +.. _IDN: + +.. note:: Non-ASCII is not supported. + + Knot Resolver does not provide any convenience support for IDN. + Therefore everywhere (all configuration, logs, RPZ files) you need to deal with the + `xn\-\- forms <https://en.wikipedia.org/wiki/Internationalized_domain_name#Example_of_IDNA_encoding>`_ + of domain name labels, instead of directly using unicode characters. + +.. function:: domains(action, domain_table) + + Like :func:`policy.suffix` match, but the queried name must match exactly, not just its suffix. + +.. function:: suffix_common(action, suffix_table[, common_suffix]) + + :param action: action if the pattern matches query name + :param suffix_table: table of valid suffixes + :param common_suffix: common suffix of entries in suffix_table + + Like :func:`policy.suffix` match, but you can also provide a common suffix of all matches for faster processing (nil otherwise). 
+ This function is faster for small suffix tables (on the order of "hundreds"). + +.. :noindex: function:: rpz(default_action, path, [watch]) + + Implements a subset of `Response Policy Zone` (RPZ_) stored in zonefile format. See below for details: :func:`policy.rpz`. + +It is also possible to define a custom filter function with any name. + +.. function:: custom_filter(state, query) + + :param state: Request processing state :c:type:`kr_layer_state`, typically not used by filter function. + :param query: Incoming DNS query as :c:type:`kr_query` structure. + :return: An `action <#actions>`_ function or ``nil`` if filter did not match. + + Typically a filter function is generated by another function, which allows easy parametrization; this technique is called a `closure <https://www.lua.org/pil/6.1.html>`_. A practical example of such a filter generator is: + +.. code-block:: lua + + function match_query_type(action, target_qtype) + return function (state, query) + if query.stype == target_qtype then + -- filter matched the query, return action function + return action + else + -- filter did not match, continue with next filter + return nil + end + end + end + +This custom filter can be used like any other built-in filter. +For example this applies our custom filter and executes action :any:`policy.DENY` on all queries of type `HINFO`: + +.. code-block:: lua + + -- custom filter which matches HINFO queries, action is policy.DENY + policy.add(match_query_type(policy.DENY, kres.type.HINFO)) + + +.. _mod-policy-actions: + +Actions +------- +An *action* is a function which modifies a DNS request, and is either of type *chain* or *non-chain*: + + * `Non-chain actions`_ modify state of the request and stop rule processing. An example of such action is :ref:`forwarding`. + * `Chain actions`_ modify state of the request and allow other rules to evaluate and act on the same request. One such example is :func:`policy.MIRROR`. 
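+
+As a minimal sketch of how the two kinds of actions interact (the address
+``192.0.2.53`` and the domain ``blocked.example`` are placeholders chosen for
+illustration), assuming rules are evaluated in the order they were added:
+
+.. code-block:: lua
+
+   -- chain action: queries are mirrored and later rules are still evaluated
+   policy.add(policy.all(policy.MIRROR('127.0.0.2')))
+   -- non-chain action: a match ends rule processing for that request
+   policy.add(policy.suffix(policy.DENY, policy.todnames({'blocked.example'})))
+   -- only reached by requests which did not match the DENY rule above
+   policy.add(policy.all(policy.FORWARD('192.0.2.53')))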
+ +Non-chain actions +^^^^^^^^^^^^^^^^^ + +Following actions stop the policy matching on the query, i.e. other rules are not evaluated once rule with following actions matches: + +.. py:attribute:: PASS + + Let the query pass through; it's useful to make exceptions before wider rules. For example: + + More specific whitelist rule must precede generic blacklist rule: + + .. code-block:: lua + + -- Whitelist 'good.example.com' + policy.add(policy.pattern(policy.PASS, todname('good.example.com.'))) + -- Block all names below example.com + policy.add(policy.suffix(policy.DENY, {todname('example.com.')})) + +.. py:attribute:: DENY + + Deny existence of names matching filter, i.e. reply NXDOMAIN authoritatively. + +.. function:: DENY_MSG(message, [extended_error=kres.extended_error.BLOCKED]) + + Deny existence of a given domain and add explanatory message. NXDOMAIN reply + contains an additional explanatory message as TXT record in the additional + section. + + You may override the extended DNS error to provide the user with more + information. By default, ``BLOCKED`` is returned to indicate the domain is + blocked due to the internal policy of the operator. Other suitable error + codes are ``CENSORED`` (for externally imposed policy reasons) or + ``FILTERED`` (for blocking requested by the client). For more information, + please refer to :rfc:`8914`. + +.. py:attribute:: DROP + + Terminate query resolution and return SERVFAIL to the requestor. + +.. py:attribute:: REFUSE + + Terminate query resolution and return REFUSED to the requestor. + +.. py:attribute:: NO_ANSWER + + Terminate query resolution and do not return any answer to the requestor. + + .. warning:: During normal operation, an answer should always be returned. + Deliberate query drops are indistinguishable from packet loss and may + cause problems as described in :rfc:`8906`. Only use :any:`NO_ANSWER` + on very specific occasions, e.g. as a defense mechanism during an attack, + and prefer other actions (e.g. 
:any:`DROP` or :any:`REFUSE`) for normal + operation. + +.. py:attribute:: TC + + Force requestor to use TCP. It sets truncated bit (*TC*) in response to true if the request came through UDP, which will force standard-compliant clients to retry the request over TCP. + +.. function:: REROUTE({subnet = target, ...}) + + Reroute IP addresses in response matching given subnet to given target, e.g. ``{['192.0.2.0/24'] = '127.0.0.0'}`` will rewrite '192.0.2.55' to '127.0.0.55', see :ref:`renumber module <mod-renumber>` for more information. See :func:`policy.add` and do not forget to specify that this is *postrule*. Quick example: + + .. code-block:: lua + + -- this policy is enforced on answers + -- therefore we have to use 'postrule' + -- (the "true" at the end of policy.add) + policy.add(policy.all(policy.REROUTE({['192.0.2.0/24'] = '127.0.0.0'})), true) + +.. function:: ANSWER({ type = { rdata=data, [ttl=1] } }, [nodata=false]) + + Overwrite Resource Records in responses with specified values. + + * type + - RR type to be replaced, e.g. ``[kres.type.A]`` or `numeric value <https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4>`_. + * rdata + - RR data in DNS wire format, i.e. binary form specific for given RR type. Set of multiple RRs can be specified as table ``{ rdata1, rdata2, ... }``. Use helper function :func:`kres.str2ip` to generate wire format for A and AAAA records. Wire format for other record types can be generated with :func:`kres.parse_rdata`. + * ttl + - TTL in seconds. Default: 1 second. + * nodata + - If type requested by client is not configured in this policy: + + - ``true``: Return empty answer (`NODATA`). + - ``false``: Ignore this policy and continue processing other rules. + + Default: ``false``. + + .. 
code-block:: lua + + -- policy to change IPv4 address and TTL for example.com + policy.add( + policy.domains( + policy.ANSWER( + { [kres.type.A] = { rdata=kres.str2ip('192.0.2.7'), ttl=300 } } + ), { todname('example.com') })) + -- policy to generate two TXT records (specified in binary format) for example.net + policy.add( + policy.domains( + policy.ANSWER( + { [kres.type.TXT] = { rdata={'\005first', '\006second'}, ttl=5 } } + ), { todname('example.net') })) + + + .. function:: kres.parse_rdata({str, ...}) + + Parse string representation of RTYPE and RDATA into RDATA wire format. Expects + a table of string(s) and returns a table of wire data. + + .. code-block:: lua + + -- create wire format RDATA that can be passed to policy.ANSWER + kres.parse_rdata({'SVCB 1 resolver.example. alpn=dot'}) + kres.parse_rdata({ + 'SVCB 1 resolver.example. alpn=dot ipv4hint=192.0.2.1 ipv6hint=2001:db8::1', + 'SVCB 2 resolver.example. mandatory=key65380 alpn=h2 key65380=/dns-query{?dns}', + }) + +More complex non-chain actions are described in their own chapters, namely: + + * :ref:`forwarding` + * `Response Policy Zones`_ + +Chain actions +^^^^^^^^^^^^^ + +The following actions act on the request, and then processing continues until the first non-chain action (specified in the previous section) is triggered: + +.. function:: MIRROR(ip_address) + + Send a copy of incoming DNS queries to a given IP address using DNS-over-UDP and continue resolving them as usual. This is useful for sanity testing new versions of DNS resolvers. + + .. code-block:: lua + + policy.add(policy.all(policy.MIRROR('127.0.0.2'))) + +.. function:: FLAGS(set, clear) + + Set and/or clear some flags for the query. There can be multiple flags to set/clear. You can just pass a single flag name (string) or a set of names. Flag names correspond to :c:type:`kr_qflags` structure. Use only if you know what you are doing. + + +.. 
_mod-policy-logging: + +Actions for extra logging +^^^^^^^^^^^^^^^^^^^^^^^^^ + +These are also "chain" actions, i.e. they don't stop processing the policy rule list. +Similarly to other actions, they apply during the whole processing of the client's request, +i.e. including any sub-queries. + +The log lines from these policy actions are tagged by an extra ``[reqdbg]`` prefix, +and they are produced regardless of your :func:`log_level()` setting. +They are marked as ``debug`` level, so e.g. with the journalctl command you can use ``-p info`` to skip them. + +.. warning:: Beware of producing too many logs. + + These actions are not suitable for use on a large fraction of resolver's requests. + The extra logs have significant performance impact and might also overload your logging system + (or get rate-limited by it). + You can use `Filters`_ to further limit which requests this happens on. + +.. py:attribute:: DEBUG_ALWAYS + + Print debug-level logging for this request. + That also includes messages from client (:any:`REQTRACE`), upstream servers (:any:`QTRACE`), and stats about interesting records at the end. + + .. code-block:: lua + + -- debug requests that ask for flaky.example.net or below + policy.add(policy.suffix(policy.DEBUG_ALWAYS, + policy.todnames({'flaky.example.net'}))) + +.. py:attribute:: DEBUG_CACHE_MISS + + Same as :any:`DEBUG_ALWAYS` but only if the request required information which was not available locally, i.e. requests which forced resolver to ask upstream server(s). + Intended usage is for debugging problems with remote servers. + +.. py:function:: DEBUG_IF(test_function) + + :param test_function: Function with single argument of type :c:type:`kr_request` which returns ``true`` if debug logs for that request should be generated and ``false`` otherwise. + + Same as :any:`DEBUG_ALWAYS` but only logs if the test_function says so. + + .. note:: ``test_function`` is evaluated only when request is finished. 
+ As a result all debug logs for this request must be collected, + and at the end they get either printed or thrown away. + + Example usage which gathers verbose logs for all requests in subtree ``dnssec-failed.org.`` and prints debug logs for those finishing in a different state than ``kres.DONE`` (most importantly ``kres.FAIL``, see :c:type:`kr_layer_state`). + + .. code-block:: lua + + policy.add(policy.suffix( + policy.DEBUG_IF(function(req) + return (req.state ~= kres.DONE) + end), + policy.todnames({'dnssec-failed.org.'}))) + +.. py:attribute:: QTRACE + + Pretty-print DNS responses from upstream servers (or cache) into logs. + It's useful for debugging weird DNS servers. + + If you do not use ``QTRACE`` in combination with ``DEBUG*``, + you additionally need either ``log_groups({'iterat'})`` (possibly with other groups) + or ``log_level('debug')`` to see the output in logs. + +.. py:attribute:: REQTRACE + + Pretty-print DNS requests from clients into the verbose log. It's useful for debugging weird DNS clients. + It makes most sense together with :ref:`mod-view` (enabling per-client) + and probably with verbose logging of those requests (e.g. use :any:`DEBUG_ALWAYS` instead). + +.. py:attribute:: IPTRACE + + Log how the request arrived. + Most notably, this includes the client's IP address, so beware of privacy implications. + + .. code-block:: lua + + -- example usage in configuration + policy.add(policy.all(policy.IPTRACE)) + -- you might want to combine it with some other logs, e.g. + policy.add(policy.all(policy.DEBUG_ALWAYS)) + + .. code-block:: text + + -- example log lines from IPTRACE: + [reqdbg][policy][57517.00] request packet arrived from ::1#37931 to ::1#00853 (TCP + TLS) + [reqdbg][policy][65538.00] request packet arrived internally + + +Custom actions +^^^^^^^^^^^^^^ + +.. function:: custom_action(state, request) + + :param state: Request processing state :c:type:`kr_layer_state`. + :param request: Current DNS request as :c:type:`kr_request` structure. 
+ :return: Returning a new :c:type:`kr_layer_state` prevents evaluating other policy rules. Returning ``nil`` creates a `chain action <#actions>`_ and allows evaluation of other rules to continue. + + This is a real example of an action function: + +.. code-block:: lua + + -- Custom action which generates fake A record + local ffi = require('ffi') + local function fake_A_record(state, req) + local answer = req:ensure_answer() + if answer == nil then return nil end + local qry = req:current() + if qry.stype ~= kres.type.A then + return state + end + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\192\168\1\3') + return kres.DONE + end + +This custom action can be used like any other built-in action. +For example this applies our *fake A record action* and executes it on all queries in subtree ``example.net``: + +.. code-block:: lua + + policy.add(policy.suffix(fake_A_record, policy.todnames({'example.net'}))) + +The action function can implement arbitrary logic so it is possible to implement complex heuristics, e.g. to deflect `Slow drip DNS attacks <https://secure64.com/water-torture-slow-drip-dns-ddos-attack>`_ or gray-list resolution of misbehaving zones. + +.. warning:: The policy module currently only looks at whole DNS requests. The rules won't be re-applied e.g. when following CNAMEs. + +.. _forwarding: + +Forwarding +---------- + +The forwarding action alters behavior for cache-miss events. If information is missing in the local cache, the resolver will *forward* the query to *another DNS resolver* for resolution (instead of contacting authoritative servers directly). DNS answers from the remote resolver are then processed locally and sent back to the original client. 
+ +Actions :func:`policy.FORWARD`, :func:`policy.TLS_FORWARD` and :func:`policy.STUB` accept up to four IP addresses at once and the resolver will automatically select the IP address which statistically responds the fastest. + +.. function:: FORWARD(ip_address) + FORWARD({ ip_address, [ip_address, ...] }) + + Forward cache-miss queries to specified IP addresses (without encryption), DNSSEC validate received answers and cache them. Target IP addresses are expected to be DNS resolvers. + + .. code-block:: lua + + -- Forward all queries to public resolvers https://www.nic.cz/odvr + policy.add(policy.all( + policy.FORWARD( + {'2001:148f:fffe::1', '2001:148f:ffff::1', + '185.43.135.1', '193.17.47.1'}))) + + A variant which uses encrypted DNS-over-TLS transport is called :func:`policy.TLS_FORWARD`, please see section :ref:`tls-forwarding`. + +.. function:: STUB(ip_address) + STUB({ ip_address, [ip_address, ...] }) + + Similar to :func:`policy.FORWARD` but *without* attempting DNSSEC validation. + Each request may be either answered from cache or simply sent to one of the IPs, with the answer proxied back. + + This mode does not support encryption and should be used only for `Replacing part of the DNS tree`_. + Use :func:`policy.FORWARD` mode if possible. + + .. code-block:: lua + + -- Answers for reverse queries about the 192.168.1.0/24 subnet + -- are to be obtained from IP address 192.0.2.1 port 5353 + -- This disables DNSSEC validation! + policy.add(policy.suffix( + policy.STUB('192.0.2.1@5353'), + {todname('1.168.192.in-addr.arpa')})) + +.. note:: By default, forwarding targets must support + `EDNS <https://en.wikipedia.org/wiki/Extension_mechanisms_for_DNS>`_ and + `0x20 randomization <https://tools.ietf.org/html/draft-vixie-dnsext-dns0x20-00>`_. + See example in `Replacing part of the DNS tree`_. + +.. warning:: + Limiting forwarding actions by filters (e.g. :func:`policy.suffix`) may have unexpected consequences. 
+ Notably, forwarders can inject *any* records into your cache + even if you "restrict" them to an insignificant DNS subtree -- + except in cases where DNSSEC validation applies, of course. + + The behavior is probably best understood through the fact + that filters and actions are completely decoupled. + The forwarding actions have no clue about why they were executed, + e.g. that the user wanted to restrict the forwarder only to some subtree. + The action just selects some set of forwarders to process this whole request from the client, + and during that processing it might need some other "sub-queries" (e.g. for validation). + Some of those might not've passed the intended filter, + but policy rule-set only applies once per client's request. + +.. _tls-forwarding: + +Forwarding over TLS protocol (DNS-over-TLS) +------------------------------------------- +.. function:: TLS_FORWARD( { {ip_address, authentication}, [...] } ) + + Same as :func:`policy.FORWARD` but send query over DNS-over-TLS protocol (encrypted). + Each target IP address needs explicit configuration how to validate + TLS certificate so each IP address is configured by pair: + ``{ip_address, authentication}``. See sections below for more details. + + +Policy :func:`policy.TLS_FORWARD` allows you to forward queries using `Transport Layer Security`_ protocol, which hides the content of your queries from an attacker observing the network traffic. Further details about this protocol can be found in :rfc:`7858` and `IETF draft dprive-dtls-and-tls-profiles`_. + +Queries affected by :func:`policy.TLS_FORWARD` will always be resolved over TLS connection. Knot Resolver does not implement fallback to non-TLS connection, so if TLS connection cannot be established or authenticated according to the configuration, the resolution will fail. + +To test this feature you need to either :ref:`configure Knot Resolver as DNS-over-TLS server <tls-server-config>`, or pick some public DNS-over-TLS server. 
Please see `DNS Privacy Project`_ homepage for list of public servers. + +.. note:: Some public DNS-over-TLS providers may apply rate-limiting which + makes their service incompatible with Knot Resolver's TLS forwarding. + Notably, `Google Public DNS + <https://developers.google.com/speed/public-dns/docs/dns-over-tls>`_ doesn't + work as of 2019-07-10. + +When multiple servers are specified, the one with the lowest round-trip time is used. + +CA+hostname authentication +^^^^^^^^^^^^^^^^^^^^^^^^^^ +Traditional PKI authentication requires server to present certificate with specified hostname, which is issued by one of trusted CAs. Example policy is: + +.. code-block:: lua + + policy.TLS_FORWARD({ + {'2001:DB8::d0c', hostname='res.example.com'}}) + +- ``hostname`` must be a valid domain name matching server's certificate. It will also be sent to the server as SNI_. +- ``ca_file`` optionally contains a path to a CA certificate (or certificate bundle) in `PEM format`_. + If you omit that, the system CA certificate store will be used instead (usually sufficient). + A list of paths is also accepted, but all of them must be valid PEMs. + +Key-pinned authentication +^^^^^^^^^^^^^^^^^^^^^^^^^ +Instead of CAs, you can specify hashes of accepted certificates in ``pin_sha256``. +They are in the usual format -- base64 from sha256. +You may still specify ``hostname`` if you want SNI_ to be sent. + +.. _tls-examples: + +TLS Examples +^^^^^^^^^^^^ + +.. 
code-block:: lua + + modules = { 'policy' } + -- forward all queries over TLS to the specified server + policy.add(policy.all(policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}}))) + -- for brevity, other TLS examples omit policy.add(policy.all()) + -- single server authenticated using its certificate pin_sha256 + policy.TLS_FORWARD({{'192.0.2.1', pin_sha256='YQ=='}}) -- pin_sha256 is base64-encoded + -- single server authenticated using hostname and system-wide CA certificates + policy.TLS_FORWARD({{'192.0.2.1', hostname='res.example.com'}}) + -- single server using non-standard port + policy.TLS_FORWARD({{'192.0.2.1@443', pin_sha256='YQ=='}}) -- use @ or # to specify port + -- single server with multiple valid pins (e.g. anycast) + policy.TLS_FORWARD({{'192.0.2.1', pin_sha256={'YQ==', 'Wg=='}}}) + -- multiple servers, each with own authenticator + policy.TLS_FORWARD({ -- please note that { here starts list of servers + {'192.0.2.1', pin_sha256='Wg=='}, + -- server must present certificate issued by specified CA and hostname must match + {'2001:DB8::d0c', hostname='res.example.com', ca_file='/etc/knot-resolver/tlsca.crt'} + }) + +Forwarding to multiple targets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +With the use of the :func:`policy.slice` function, it is possible to split the +entire DNS namespace into distinct slices. When used in conjunction with +:func:`policy.TLS_FORWARD`, it's possible to forward different queries to +different targets. + +.. function:: slice(slice_func, action[, action[, ...]]) + + :param slice_func: slicing function that returns index based on query + :param action: action to be performed for the slice + + This function splits the entire domain space into multiple slices (determined + by the number of provided ``actions``). A ``slice_func`` is called to determine + which slice a query belongs to. The corresponding ``action`` is then executed. + + +.. 
function:: slice_randomize_psl(seed = os.time() / (3600 * 24 * 7)) + + :param seed: seed for random assignment + + The function initializes and returns a slicing function, which + deterministically assigns ``query`` to a slice based on the query name. + + It utilizes the `Public Suffix List`_ to ensure domains under the same + registrable domain end up in a single slice (see example below). + + ``seed`` can be used to re-shuffle the slicing algorithm when the slicing + function is initialized. By default, the assignment is re-shuffled after one + week (when the resolver restarts or reloads its config). To force a stable + distribution, pass a fixed value. To re-shuffle on every resolver restart, + use ``os.time()``. + + The following example demonstrates a distribution among 3 slices:: + + slice 1/3: + example.com + a.example.com + b.example.com + x.b.example.com + example3.com + + slice 2/3: + example2.co.uk + + slice 3/3: + example.co.uk + a.example.co.uk + +These two functions can be used together to forward queries for names +in different parts of DNS name space to different target servers: + +.. code-block:: lua + + policy.add(policy.slice( + policy.slice_randomize_psl(), + policy.TLS_FORWARD({{'192.0.2.1', hostname='res.example.com'}}), + policy.TLS_FORWARD({ + -- multiple servers can be specified for a single slice + -- the one with lowest round-trip time will be used + {'193.17.47.1', hostname='odvr.nic.cz'}, + {'185.43.135.1', hostname='odvr.nic.cz'}, + }) + )) + +.. note:: The privacy implications of using this feature aren't clear. Since + websites often make requests to multiple domains, these might be forwarded + to different targets. This could result in decreased privacy (e.g. when the + remote targets are both logging or otherwise processing your DNS traffic). 
+ The intended use-case is to use this feature with semi-trusted resolvers + which claim to do no logging (such as those listed on `dnsprivacy.org + <https://dnsprivacy.org/wiki/display/DP/DNS+Privacy+Test+Servers>`_), to + decrease the potential exposure of your DNS data to a malicious resolver + operator. + +.. _dns-graft: + +Replacing part of the DNS tree +------------------------------ + +The following procedure applies only to domains which have different content +publicly and internally. For example this applies to "your own" top-level domain +``example.`` which does not exist in the public (global) DNS namespace. + +Dealing with these internal-only domains requires extra configuration because +DNS was designed as a "single namespace" and local modifications like adding +your own TLD break this assumption. + +.. warning:: Use of internal names which are not delegated from the public DNS + *is causing technical problems* with caching and DNSSEC validation + and generally makes DNS operation more costly. + We recommend **against** using these non-delegated names. + +To make such an internal domain available in your resolver it is necessary to +*graft* your domain onto the public DNS namespace, +but *grafting* creates new issues: + +These *grafted* domains will be rejected by DNSSEC validation +because such domains are technically indistinguishable from a spoofing attack +against the public DNS. +Therefore, if you trust the remote resolver which hosts the internal-only domain, +and you trust your link to it, you need to use the :func:`policy.STUB` policy +instead of :func:`policy.FORWARD` to disable DNSSEC validation for those +*grafted* domains. + +.. code-block:: lua + :caption: Example configuration grafting domains onto public DNS namespace + + extraTrees = policy.todnames( + {'faketldtest.', + 'sld.example.', + 'internal.example.com.', + '2.0.192.in-addr.arpa.' 
-- this applies to reverse DNS tree as well + }) + -- Beware: the rule order is important, as policy.STUB is not a chain action. + -- Flags: for "dumb" targets disabling EDNS can help (below) as DNSSEC isn't + -- validated anyway; in some of those cases adding 'NO_0X20' can also help, + -- though it also lowers defenses against off-path attacks on communication + -- between the two servers. + -- With kresd <= 5.5.3 you also needed 'NO_CACHE' flag to avoid unintentional + -- NXDOMAINs that could sometimes happen due to aggressive DNSSEC caching. + policy.add(policy.suffix(policy.FLAGS({'NO_EDNS'}), extraTrees)) + policy.add(policy.suffix(policy.STUB({'2001:db8::1'}), extraTrees)) + +Response policy zones +--------------------- + .. warning:: + + There is no published Internet Standard for RPZ_ and implementations vary. + At the moment Knot Resolver supports limited subset of RPZ format and deviates + from implementation in BIND. Nevertheless it is good enough + for blocking large lists of spam or advertising domains. + + + + The RPZ file format is basically a DNS zone file with *very special* semantics. + For example: + + .. code-block:: none + + ; left hand side ; TTL and class ; right hand side + ; encodes RPZ trigger ; ignored ; encodes action + ; (i.e. filter) + blocked.domain.example 600 IN CNAME . ; block main domain + *.blocked.domain.example 600 IN CNAME . ; block subdomains + + The only "trigger" supported in Knot Resolver is query name, + i.e. left hand side must be a domain name which triggers the action specified + on the right hand side. + + Subset of possible RPZ actions is supported, namely: + + .. 
csv-table:: + :header: "RPZ Right Hand Side", "Knot Resolver Action", "BIND Compatibility" + + "``.``", "``action`` is used", "compatible if ``action`` is :any:`policy.DENY`" + "``*.``", ":func:`policy.ANSWER`", "yes" + "``rpz-passthru.``", ":any:`policy.PASS`", "yes" + "``rpz-tcp-only.``", ":any:`policy.TC`", "yes" + "``rpz-drop.``", ":any:`policy.DROP`", "no [#]_" + "fake A/AAAA", ":func:`policy.ANSWER`", "yes" + "fake CNAME", "not supported", "no" + + .. [#] Our :any:`policy.DROP` returns *SERVFAIL* answer (for historical reasons). + + + .. note:: + + To debug which domains are affected by RPZ (or other policy actions), you can enable the ``policy`` log group: + + .. code-block:: lua + + log_groups({'policy'}) + + See also :ref:`non-ASCII support note <IDN>`. + + +.. function:: rpz(action, path, [watch = true]) + + :param action: the default action for match in the zone; typically you want :any:`policy.DENY` + :param path: path to zone file + :param watch: boolean, if true, the file will be reloaded on file change + + Enforce RPZ_ rules. This can be used in conjunction with published blocklist feeds. + The RPZ_ operation is well described in this `Jan-Piet Mens's post`_, + or the `Pro DNS and BIND`_ book. + + For example, we can store the example snippet with domain ``blocked.domain.example`` + (above) into file ``/etc/knot-resolver/blocklist.rpz`` and configure resolver to + answer with *NXDOMAIN* plus the specified additional text to queries for this domain: + + .. code-block:: lua + + policy.add( + policy.rpz(policy.DENY_MSG('domain blocked by your resolver operator'), + '/etc/knot-resolver/blocklist.rpz', + true)) + + Resolver will reload RPZ file at run-time if the RPZ file changes. + Recommended RPZ update procedure is to store new blocklist in a new file + (*newblocklist.rpz*) and then rename the new file to the original file name + (*blocklist.rpz*). This avoids problems where resolver might attempt + to re-read an incomplete file. 
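+
+   As a rough sketch of that update procedure (plain Lua run outside the
+   resolver, not part of the module; the temporary file name and the two
+   records are only illustrative, and both files are assumed to live on the
+   same filesystem so that the rename replaces the old file in one step):
+
+   .. code-block:: lua
+
+      -- hypothetical paths; adjust to your installation
+      local target = '/etc/knot-resolver/blocklist.rpz'
+      local tmp = '/etc/knot-resolver/newblocklist.rpz'
+
+      -- write the complete new blocklist under the temporary name first
+      local f = assert(io.open(tmp, 'w'))
+      f:write('blocked.domain.example 600 IN CNAME .\n')
+      f:write('*.blocked.domain.example 600 IN CNAME .\n')
+      assert(f:close())
+
+      -- then move it over the watched file, so a partial file is never read
+      assert(os.rename(tmp, target))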
+ + + +Additional properties +--------------------- + +Most properties (actions, filters) are described above. + +.. function:: add(rule, postrule) + + :param rule: added rule, e.g. ``policy.pattern(policy.DENY, '[0-9]+\2cz')`` + :param postrule: boolean, if true the rule will be evaluated on answer instead of query + :return: rule description + + Add a new policy rule that is executed either on queries or answers, depending on the ``postrule`` parameter. You can then use the returned rule description to get information and unique identifier for the rule, as well as match count. + + .. code-block:: lua + + -- mirror all queries, keep handle so we can retrieve information later + local rule = policy.add(policy.all(policy.MIRROR('127.0.0.2'))) + -- we can print statistics about this rule any time later + print(string.format('id: %d, matched queries: %d', rule.id, rule.count)) + +.. function:: del(id) + + :param id: identifier of a given rule returned by :func:`policy.add` + :return: boolean ``true`` if rule was deleted, ``false`` otherwise + + Remove a rule from policy list. + +.. function:: todnames({name, ...}) + + :param names: table of domain names in textual format + + Returns table of domain names in wire format converted from strings. + + .. code-block:: lua + + -- Convert single name + assert(todname('example.com') == '\7example\3com\0') + -- Convert table of names + policy.todnames({'example.com', 'me.cz'}) + -- { '\7example\3com\0', '\2me\2cz\0' } + + +.. _RPZ: https://dnsrpz.info/ +.. _`PEM format`: https://en.wikipedia.org/wiki/Privacy-enhanced_Electronic_Mail +.. _`Pro DNS and BIND`: http://www.zytrax.com/books/dns/ch7/rpz.html +.. _`Jan-Piet Mens's post`: http://jpmens.net/2011/04/26/how-to-configure-your-bind-resolvers-to-lie-using-response-policy-zones-rpz/ +.. _`Transport Layer Security`: https://en.wikipedia.org/wiki/Transport_Layer_Security +.. _`DNS Privacy Project`: https://dnsprivacy.org/ +.. 
_`IETF draft dprive-dtls-and-tls-profiles`: https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles +.. _SNI: https://en.wikipedia.org/wiki/Server_Name_Indication +.. _`Public Suffix List`: https://publicsuffix.org diff --git a/modules/policy/lua-aho-corasick/LICENSE b/modules/policy/lua-aho-corasick/LICENSE new file mode 100644 index 0000000..dd65f72 --- /dev/null +++ b/modules/policy/lua-aho-corasick/LICENSE @@ -0,0 +1,28 @@ + Copyright (c) 2014 CloudFlare, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of CloudFlare, Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ diff --git a/modules/policy/lua-aho-corasick/Makefile b/modules/policy/lua-aho-corasick/Makefile new file mode 100644 index 0000000..6471664 --- /dev/null +++ b/modules/policy/lua-aho-corasick/Makefile @@ -0,0 +1,134 @@ +OS := $(shell uname) + +ifeq ($(OS), Darwin) + SO_EXT := dylib +else + SO_EXT := so +endif + +############################################################################# +# +# Binaries we are going to build +# +############################################################################# +# +C_SO_NAME = libac.$(SO_EXT) +LUA_SO_NAME = ahocorasick.$(SO_EXT) +AR_NAME = libac.a + +############################################################################# +# +# Compile and link flags +# +############################################################################# +PREFIX ?= /usr/local +LUA_VERSION := 5.1 +LUA_INCLUDE_DIR := $(PREFIX)/include/lua$(LUA_VERSION) +SO_TARGET_DIR := $(PREFIX)/lib/lua/$(LUA_VERSION) +LUA_TARGET_DIR := $(PREFIX)/share/lua/$(LUA_VERSION) + +# Available directives: +# -DDEBUG : Turn on debugging support +# -DVERIFY : To verify if the slow-version and fast-version implementations +# get exactly the same result. Note -DVERIFY implies -DDEBUG. 
+# +COMMON_FLAGS = -O3 #-g -DVERIFY -msse2 -msse3 -msse4.1 +COMMON_FLAGS += -fvisibility=hidden -Wall $(CXXFLAGS) $(MY_CXXFLAGS) $(CPPFLAGS) + +SO_CXXFLAGS = $(COMMON_FLAGS) -fPIC +SO_LFLAGS = $(COMMON_FLAGS) $(LDFLAGS) +AR_CXXFLAGS = $(COMMON_FLAGS) + +# -DVERIFY implies -DDEBUG +ifneq ($(findstring -DVERIFY, $(COMMON_FLAGS)), ) +ifeq ($(findstring -DDEBUG, $(COMMON_FLAGS)), ) + COMMON_FLAGS += -DDEBUG +endif +endif + +AR = ar +AR_FLAGS = cru + +############################################################################# +# +# Divide source codes and objects into several categories +# +############################################################################# +# +SRC_COMMON := ac_fast.cxx ac_slow.cxx +LIBAC_SO_SRC := $(SRC_COMMON) ac.cxx # source for libac.so +LUA_SO_SRC := $(SRC_COMMON) ac_lua.cxx # source for ahocorasick.so +LIBAC_A_SRC := $(LIBAC_SO_SRC) # source for libac.a + +############################################################################# +# +# Make rules +# +############################################################################# +# +.PHONY = all clean test benchmark prepare +all : $(C_SO_NAME) $(LUA_SO_NAME) $(AR_NAME) + +-include c_so_dep.txt +-include lua_so_dep.txt +-include ar_dep.txt + +BUILD_SO_DIR := build_so +BUILD_AR_DIR := build_ar + +$(BUILD_SO_DIR) :; mkdir $@ +$(BUILD_AR_DIR) :; mkdir $@ + +$(BUILD_SO_DIR)/%.o : %.cxx | $(BUILD_SO_DIR) + $(CXX) $< -c $(SO_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ + +$(BUILD_AR_DIR)/%.o : %.cxx | $(BUILD_AR_DIR) + $(CXX) $< -c $(AR_CXXFLAGS) -I$(LUA_INCLUDE_DIR) -MMD -o $@ + +ifneq ($(OS), Darwin) +$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared -Wl,-soname=$(C_SO_NAME) $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt + +$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared -Wl,-soname=$(LUA_SO_NAME) $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, 
${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt + +else +$(C_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared $(SO_LFLAGS) -o $@ + cat $(addprefix $(BUILD_SO_DIR)/, ${LIBAC_SO_SRC:.cxx=.d}) > c_so_dep.txt + +$(LUA_SO_NAME) : $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.o}) + $(CXX) $+ -shared $(SO_LFLAGS) -o $@ -Wl,-undefined,dynamic_lookup + cat $(addprefix $(BUILD_SO_DIR)/, ${LUA_SO_SRC:.cxx=.d}) > lua_so_dep.txt +endif + +$(AR_NAME) : $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.o}) + $(AR) $(AR_FLAGS) $@ $+ + cat $(addprefix $(BUILD_AR_DIR)/, ${LIBAC_A_SRC:.cxx=.d}) > lua_so_dep.txt + +############################################################################# +# +# Misc +# +############################################################################# +# +test : $(C_SO_NAME) + $(MAKE) -C tests && \ + luajit tests/lua_test.lua && \ + luajit tests/load_ac_test.lua + +benchmark: $(C_SO_NAME) + $(MAKE) benchmark -C tests + +clean : + -rm -rf *.o *.d c_so_dep.txt lua_so_dep.txt ar_dep.txt $(TEST) \ + $(C_SO_NAME) $(LUA_SO_NAME) $(TEST) $(BUILD_SO_DIR) $(BUILD_AR_DIR) \ + $(AR_NAME) + make clean -C tests + +install: + install -D -m 755 $(C_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(C_SO_NAME) + install -D -m 755 $(LUA_SO_NAME) $(DESTDIR)/$(SO_TARGET_DIR)/$(LUA_SO_NAME) + install -D -m 664 load_ac.lua $(DESTDIR)/$(LUA_TARGET_DIR)/load_ac.lua diff --git a/modules/policy/lua-aho-corasick/README.md b/modules/policy/lua-aho-corasick/README.md new file mode 100644 index 0000000..b5cc406 --- /dev/null +++ b/modules/policy/lua-aho-corasick/README.md @@ -0,0 +1,40 @@ +aho-corasick-lua +================ + +C++ and Lua Implementation of the Aho-Corasick (AC) string matching algorithm +(http://dl.acm.org/citation.cfm?id=360855). + +We began with pure Lua implementation and realize the performance is not +satisfactory. So we switch to C/C++ implementation. 
+ +There are two shared objects provided by this package: libac.so and ahocorasick.so +The former is a regular shared object which can be directly used by C/C++ +application, or by Lua via FFI; and the latter is a Lua module. An example usage +is shown below: + +```lua +local ac = require "ahocorasick" +local dict = {"string1", "string", "etc"} +local acinst = ac.create(dict) +local r = ac.match(acinst, "mystring") +``` + +For efficiency reasons, the implementation is slightly different from the +standard AC algorithm in that it doesn't return the set of strings in the dictionary +that match the given string; instead it returns only one of them when the string +matches. The functionality of our implementation can be (precisely) described by +the following pseudo-C snippet. + +```C +string foo(input-string, dictionary) { + string ret = the-end-of-input-string; + for each string s in dictionary { + // find the first occurrence match sub-string. + ret = min(ret, strstr(input-string, s)); + } + return ret; +} +``` + +It's pretty easy to get rid of this limitation: just associate each state with +a spare bit-vector depicting the set of strings recognized by that state. 
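The pseudo-C snippet above can be written as a small, naive C++ reference. This is only a sketch of the behaviour the snippet describes, not the library's algorithm: `first_match_offset` is a hypothetical helper that is not part of libac, it uses plain `std::string::find` instead of the AC automaton, and it returns an offset (or `std::string::npos`) rather than a pointer.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper, not part of libac: returns the smallest offset at
// which any dictionary entry occurs in `input`, or std::string::npos if
// none of the entries occurs.
std::size_t first_match_offset(const std::string& input,
                               const std::vector<std::string>& dictionary) {
    std::size_t best = std::string::npos;  // plays the role of "end of input"
    for (const std::string& s : dictionary) {
        if (s.empty()) {
            continue;  // assumption: empty patterns are ignored
        }
        const std::size_t pos = input.find(s);  // first occurrence of s
        if (pos < best) {
            best = pos;  // npos is the largest size_t value, so it never wins
        }
    }
    return best;
}
```

With the dictionary from the Lua example, `first_match_offset("mystring", {"string1", "string", "etc"})` yields offset 2, where `"string"` begins; the result says where a match starts but not which other entries also match.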
diff --git a/modules/policy/lua-aho-corasick/ac.cxx b/modules/policy/lua-aho-corasick/ac.cxx new file mode 100644 index 0000000..23fb3b5 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac.cxx @@ -0,0 +1,101 @@ +// Interface functions for libac.so +// +#include "ac_slow.hpp" +#include "ac_fast.hpp" +#include "ac.h" + +static inline ac_result_t +_match(buf_header_t* ac, const char* str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(ac->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match(buf, str, len); + + #ifdef VERIFY + { + Match_Result r2 = buf->slow_impl->Match(str, len); + if (r.match_begin != r2.begin) { + ASSERT(0); + } else { + ASSERT((r.match_begin < 0) || + (r.match_end == r2.end && + r.pattern_idx == r2.pattern_idx)); + } + } + #endif + return r; +} + +extern "C" int +ac_match2(ac_t* ac, const char* str, unsigned int len) { + ac_result_t r = _match((buf_header_t*)(void*)ac, str, len); + return r.match_begin; +} + +extern "C" ac_result_t +ac_match(ac_t* ac, const char* str, unsigned int len) { + return _match((buf_header_t*)(void*)ac, str, len); +} + +extern "C" ac_result_t +ac_match_longest_l(ac_t* ac, const char* str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(((buf_header_t*)ac)->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match_Longest_L(buf, str, len); + return r; +} + +class BufAlloc : public Buf_Allocator { +public: + virtual AC_Buffer* alloc(int sz) { + return (AC_Buffer*)(new unsigned char[sz]); + } + + // Do not de-allocate the buffer when the BufAlloc die. 
+ virtual void free() {} + + static void myfree(AC_Buffer* buf) { + ASSERT(buf->hdr.magic_num == AC_MAGIC_NUM); + const char* b = (const char*)buf; + delete[] b; + } +}; + +extern "C" ac_t* +ac_create(const char** strv, unsigned int* strlenv, unsigned int v_len) { + if (v_len >= 65535) { + // TODO: Currently we use 16-bit to encode pattern-index (see the + // comment to AC_State::is_term), therefore we are not able to + // handle pattern set with more than 65535 entries. + return 0; + } + + ACS_Constructor *acc; +#ifdef VERIFY + acc = new ACS_Constructor; +#else + ACS_Constructor tmp; + acc = &tmp; +#endif + acc->Construct(strv, strlenv, v_len); + + BufAlloc ba; + AC_Converter cvt(*acc, ba); + AC_Buffer* buf = cvt.Convert(); + +#ifdef VERIFY + buf->slow_impl = acc; +#endif + return (ac_t*)(void*)buf; +} + +extern "C" void +ac_free(void* ac) { + AC_Buffer* buf = (AC_Buffer*)ac; +#ifdef VERIFY + delete buf->slow_impl; +#endif + + BufAlloc::myfree(buf); +} diff --git a/modules/policy/lua-aho-corasick/ac.h b/modules/policy/lua-aho-corasick/ac.h new file mode 100644 index 0000000..30bf447 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac.h @@ -0,0 +1,49 @@ +#ifndef AC_H +#define AC_H +#ifdef __cplusplus +extern "C" { +#endif + +#define AC_EXPORT __attribute__ ((visibility ("default"))) + +/* If the subject-string dosen't match any of the given patterns, "match_begin" + * should be a negative; otherwise the substring of the subject-string, + * starting from offset "match_begin" to "match_end" incusively, + * should exactly match the pattern specified by the 'pattern_idx' (i.e. + * the pattern is "pattern_v[pattern_idx]" where the "pattern_v" is the + * first acutal argument passing to ac_create()) + */ +typedef struct { + int match_begin; + int match_end; + int pattern_idx; +} ac_result_t; + +struct ac_t; + +/* Create an AC instance. 
"pattern_v" is a vector of patterns, the length of + * i-th pattern is specified by "pattern_len_v[i]"; the number of patterns + * is specified by "vect_len". + * + * Return the instance on success, or NUL otherwise. + */ +ac_t* ac_create(const char** pattern_v, unsigned int* pattern_len_v, + unsigned int vect_len) AC_EXPORT; + +ac_result_t ac_match(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +ac_result_t ac_match_longest_l(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +/* Similar to ac_match() except that it only returns match-begin. The rationale + * for this interface is that luajit has hard time in dealing with strcture- + * return-value. + */ +int ac_match2(ac_t*, const char *str, unsigned int len) AC_EXPORT; + +void ac_free(void*) AC_EXPORT; + +#ifdef __cplusplus +} +#endif + +#endif /* AC_H */ diff --git a/modules/policy/lua-aho-corasick/ac_fast.cxx b/modules/policy/lua-aho-corasick/ac_fast.cxx new file mode 100644 index 0000000..9dbc2e6 --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_fast.cxx @@ -0,0 +1,468 @@ +#include <algorithm> // for std::sort +#include "ac_slow.hpp" +#include "ac_fast.hpp" + +uint32 +AC_Converter::Calc_State_Sz(const ACS_State* s) const { + AC_State dummy; + uint32 sz = offsetof(AC_State, input_vect); + sz += s->Get_GotoNum() * sizeof(dummy.input_vect[0]); + + if (sz < sizeof(AC_State)) + sz = sizeof(AC_State); + + uint32 align = __alignof__(dummy); + sz = (sz + align - 1) & ~(align - 1); + return sz; +} + +AC_Buffer* +AC_Converter::Alloc_Buffer() { + const vector<ACS_State*>& all_states = _acs.Get_All_States(); + const ACS_State* root_state = _acs.Get_Root_State(); + uint32 root_fanout = root_state->Get_GotoNum(); + + // Step 1: Calculate the buffer size + AC_Ofst root_goto_ofst, states_ofst_ofst, first_state_ofst; + + // part 1 : buffer header + uint32 sz = root_goto_ofst = sizeof(AC_Buffer); + + // part 2: Root-node's goto function + if (likely(root_fanout != 255)) + sz += 256; + else + root_goto_ofst = 
0; + + // part 3: mapping of state's relative position. + unsigned align = __alignof__(AC_Ofst); + sz = (sz + align - 1) & ~(align - 1); + states_ofst_ofst = sz; + + sz += sizeof(AC_Ofst) * all_states.size(); + + // part 4: state's contents + align = __alignof__(AC_State); + sz = (sz + align - 1) & ~(align - 1); + first_state_ofst = sz; + + uint32 state_sz = 0; + for (vector<ACS_State*>::const_iterator i = all_states.begin(), + e = all_states.end(); i != e; i++) { + state_sz += Calc_State_Sz(*i); + } + state_sz -= Calc_State_Sz(root_state); + + sz += state_sz; + + // Step 2: Allocate buffer, and populate header. + AC_Buffer* buf = _buf_alloc.alloc(sz); + + buf->hdr.magic_num = AC_MAGIC_NUM; + buf->hdr.impl_variant = IMPL_FAST_VARIANT; + buf->buf_len = sz; + buf->root_goto_ofst = root_goto_ofst; + buf->states_ofst_ofst = states_ofst_ofst; + buf->first_state_ofst = first_state_ofst; + buf->root_goto_num = root_fanout; + buf->state_num = _acs.Get_State_Num(); + return buf; +} + +void +AC_Converter::Populate_Root_Goto_Func(AC_Buffer* buf, + GotoVect& goto_vect) { + unsigned char *buf_base = (unsigned char*)(buf); + InputTy* root_gotos = (InputTy*)(buf_base + buf->root_goto_ofst); + const ACS_State* root_state = _acs.Get_Root_State(); + + root_state->Get_Sorted_Gotos(goto_vect); + + // Renumber the ID of root-node's immediate kids. + uint32 new_id = 1; + bool full_fantout = (goto_vect.size() == 255); + if (likely(!full_fantout)) + bzero(root_gotos, 256*sizeof(InputTy)); + + for (GotoVect::iterator i = goto_vect.begin(), e = goto_vect.end(); + i != e; i++, new_id++) { + InputTy c = i->first; + ACS_State* s = i->second; + _id_map[s->Get_ID()] = new_id; + + if (likely(!full_fantout)) + root_gotos[c] = new_id; + } +} + +AC_Buffer* +AC_Converter::Convert() { + // Step 1: Some preparation stuff. 
+ GotoVect gotovect; + + _id_map.clear(); + _ofst_map.clear(); + _id_map.resize(_acs.Get_Next_Node_Id()); + _ofst_map.resize(_acs.Get_Next_Node_Id()); + + // Step 2: allocate buffer to accommodate the entire AC graph. + AC_Buffer* buf = Alloc_Buffer(); + unsigned char* buf_base = (unsigned char*)buf; + + // Step 3: Root node need special care. + Populate_Root_Goto_Func(buf, gotovect); + buf->root_goto_num = gotovect.size(); + _id_map[_acs.Get_Root_State()->Get_ID()] = 0; + + // Step 4: Converting the remaining states by BFSing the graph. + // First of all, enter root's immediate kids to the working list. + vector<const ACS_State*> wl; + State_ID id = 1; + for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); + i != e; i++, id++) { + ACS_State* s = i->second; + wl.push_back(s); + _id_map[s->Get_ID()] = id; + } + + AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); + AC_Ofst ofst = buf->first_state_ofst; + for (uint32 idx = 0; idx < wl.size(); idx++) { + const ACS_State* old_s = wl[idx]; + AC_State* new_s = (AC_State*)(buf_base + ofst); + + // This property should hold as we: + // - States are appended to worklist in the BFS order. + // - sibiling states are appended to worklist in the order of their + // corresponding input. + // + State_ID state_id = idx + 1; + ASSERT(_id_map[old_s->Get_ID()] == state_id); + + state_ofst_vect[state_id] = ofst; + + new_s->first_kid = wl.size() + 1; + new_s->depth = old_s->Get_Depth(); + new_s->is_term = old_s->is_Terminal() ? 
+ old_s->get_Pattern_Idx() + 1 : 0; + + uint32 gotonum = old_s->Get_GotoNum(); + new_s->goto_num = gotonum; + + // Populate the "input" field + old_s->Get_Sorted_Gotos(gotovect); + uint32 input_idx = 0; + uint32 id = wl.size() + 1; + InputTy* input_vect = new_s->input_vect; + for (GotoVect::iterator i = gotovect.begin(), e = gotovect.end(); + i != e; i++, id++, input_idx++) { + input_vect[input_idx] = i->first; + + ACS_State* kid = i->second; + _id_map[kid->Get_ID()] = id; + wl.push_back(kid); + } + + _ofst_map[old_s->Get_ID()] = ofst; + ofst += Calc_State_Sz(old_s); + } + + // This assertion might be useful to catch buffer overflow + ASSERT(ofst == buf->buf_len); + + // Populate the fail-link field. + for (vector<const ACS_State*>::iterator i = wl.begin(), e = wl.end(); + i != e; i++) { + const ACS_State* slow_s = *i; + State_ID fast_s_id = _id_map[slow_s->Get_ID()]; + AC_State* fast_s = (AC_State*)(buf_base + state_ofst_vect[fast_s_id]); + if (const ACS_State* fl = slow_s->Get_FailLink()) { + State_ID id = _id_map[fl->Get_ID()]; + fast_s->fail_link = id; + } else + fast_s->fail_link = 0; + } +#ifdef DEBUG + //dump_buffer(buf, stderr); +#endif + return buf; +} + +static inline AC_State* +Get_State_Addr(unsigned char* buf_base, AC_Ofst* StateOfstVect, uint32 state_id) { + ASSERT(state_id != 0 && "root node is handled in speical way"); + ASSERT(state_id < ((AC_Buffer*)buf_base)->state_num); + return (AC_State*)(buf_base + StateOfstVect[state_id]); +} + +// The performance of the binary search is critical to this work. +// +// Here we provide two versions of binary-search functions. +// The non-pristine version seems to consistently out-perform "pristine" one on +// bunch of benchmarks we tested. 
With the benchmark under tests/testinput/ +// +// The speedup is following on my laptop (core i7, ubuntu): +// +// benchmark was is +// ---------------------------------------- +// image.bin 2.3s 2.0s +// test.tar 6.7s 5.7s +// +// NOTE: As of I write this comment, we only measure the performance on about +// 10+ benchmarks. It's still too early to say which one works better. +// +#if !defined(BS_MULTI_VER) +static bool __attribute__((always_inline)) inline +Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { + if (vect_len <= 8) { + for (int i = 0; i < vect_len; i++) { + if (input_vect[i] == input) { + idx = i; + return true; + } + } + return false; + } + + // The "low" and "high" must be signed integers, as they could become -1. + // Also since they are signed integer, "(low + high)/2" is sightly more + // expensive than (low+high)>>1 or ((unsigned)(low + high))/2. + // + int low = 0, high = vect_len - 1; + while (low <= high) { + int mid = (low + high) >> 1; + InputTy mid_c = input_vect[mid]; + + if (input < mid_c) + high = mid - 1; + else if (input > mid_c) + low = mid + 1; + else { + idx = mid; + return true; + } + } + return false; +} + +#else + +/* Let us call this version "pristine" version. */ +static inline bool +Binary_Search_Input(InputTy* input_vect, int vect_len, InputTy input, int& idx) { + int low = 0, high = vect_len - 1; + while (low <= high) { + int mid = (low + high) >> 1; + InputTy mid_c = input_vect[mid]; + + if (input < mid_c) + high = mid - 1; + else if (input > mid_c) + low = mid + 1; + else { + idx = mid; + return true; + } + } + return false; +} +#endif + +typedef enum { + // Look for the first match. e.g. pattern set = {"ab", "abc", "def"}, + // subject string "ababcdef". The first match would be "ab" at the + // beginning of the subject string. + MV_FIRST_MATCH, + + // Look for the left-most longest match. 
Follow above example; there are + // two longest matches, "abc" and "def", and the left-most longest match + // is "abc". + MV_LEFT_LONGEST, + + // Similar to the left-most longest match, except that it returns the + // *right* most longest match. Follow above example, the match would + // be "def". NYI. + MV_RIGHT_LONGEST, + + // Return all patterns that match that given subject string. NYI. + MV_ALL_MATCHES, +} MATCH_VARIANT; + +/* The Match_Tmpl is the template for vairants MV_FIRST_MATCH, MV_LEFT_LONGEST, + * MV_RIGHT_LONGEST (If we really really need MV_RIGHT_LONGEST variant, we are + * better off implementing it in a seprate function). + * + * The Match_Tmpl supports three variants at once "symbolically", once it's + * instanced to a particular variants, all the code irrelevant to the variants + * will be statically removed. So don't worry about the code like + * "if (variant == MV_XXXX)"; they will not incur any penalty. + * + * The drawback of using template is increased code size. Unfortunately, there + * is no silver bullet. + */ +template<MATCH_VARIANT variant> static ac_result_t +Match_Tmpl(AC_Buffer* buf, const char* str, uint32 len) { + unsigned char* buf_base = (unsigned char*)(buf); + unsigned char* root_goto = buf_base + buf->root_goto_ofst; + AC_Ofst* states_ofst_vect = (AC_Ofst* )(buf_base + buf->states_ofst_ofst); + + AC_State* state = 0; + uint32 idx = 0; + + // Skip leading chars that are not valid input of root-nodes. 
+ if (likely(buf->root_goto_num != 255)) { + while(idx < len) { + unsigned char c = str[idx++]; + if (unsigned char kid_id = root_goto[c]) { + state = Get_State_Addr(buf_base, states_ofst_vect, kid_id); + break; + } + } + } else { + idx = 1; + state = Get_State_Addr(buf_base, states_ofst_vect, *str); + } + + ac_result_t r = {-1, -1}; + if (likely(state != 0)) { + if (unlikely(state->is_term)) { + /* Dictionary may have string of length 1 */ + r.match_begin = idx - state->depth; + r.match_end = idx - 1; + r.pattern_idx = state->is_term - 1; + + if (variant == MV_FIRST_MATCH) { + return r; + } + } + } + + while (idx < len) { + unsigned char c = str[idx]; + int res; + bool found; + found = Binary_Search_Input(state->input_vect, state->goto_num, c, res); + if (found) { + // The "t = goto(c, current_state)" is valid, advance to state "t". + uint32 kid = state->first_kid + res; + state = Get_State_Addr(buf_base, states_ofst_vect, kid); + idx++; + } else { + // Follow the fail-link. + State_ID fl = state->fail_link; + if (fl == 0) { + // fail-link is root-node, which implies the root-node dosen't + // have 255 valid transitions (otherwise, the fail-link should + // points to "goto(root, c)"), so we don't need speical handling + // as we did before this while-loop is entered. + // + while(idx < len) { + InputTy c = str[idx++]; + if (unsigned char kid_id = root_goto[c]) { + state = + Get_State_Addr(buf_base, states_ofst_vect, kid_id); + break; + } + } + } else { + state = Get_State_Addr(buf_base, states_ofst_vect, fl); + } + } + + // Check to see if the state is terminal state? 
+ if (state->is_term) { + if (variant == MV_FIRST_MATCH) { + ac_result_t r; + r.match_begin = idx - state->depth; + r.match_end = idx - 1; + r.pattern_idx = state->is_term - 1; + return r; + } + + if (variant == MV_LEFT_LONGEST) { + int match_begin = idx - state->depth; + int match_end = idx - 1; + + if (r.match_begin == -1 || + match_end - match_begin > r.match_end - r.match_begin) { + r.match_begin = match_begin; + r.match_end = match_end; + r.pattern_idx = state->is_term - 1; + } + continue; + } + + ASSERT(false && "NYI"); + } + } + + return r; +} + +ac_result_t +Match(AC_Buffer* buf, const char* str, uint32 len) { + return Match_Tmpl<MV_FIRST_MATCH>(buf, str, len); +} + +ac_result_t +Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len) { + return Match_Tmpl<MV_LEFT_LONGEST>(buf, str, len); +} + +#ifdef DEBUG +void +AC_Converter::dump_buffer(AC_Buffer* buf, FILE* f) { + vector<AC_Ofst> state_ofst; + state_ofst.resize(_id_map.size()); + + fprintf(f, "Id maps between old/slow and new/fast graphs\n"); + int old_id = 0; + for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end(); + i != e; i++, old_id++) { + State_ID new_id = *i; + if (new_id != 0) { + fprintf(f, "%d -> %d, ", old_id, new_id); + } + } + fprintf(f, "\n"); + + int idx = 0; + for (vector<uint32>::iterator i = _id_map.begin(), e = _id_map.end(); + i != e; i++, idx++) { + uint32 id = *i; + if (id == 0) continue; + state_ofst[id] = _ofst_map[idx]; + } + + unsigned char* buf_base = (unsigned char*)buf; + + // dump root goto-function. + fprintf(f, "root, fanout:%d goto {", buf->root_goto_num); + if (buf->root_goto_num != 255) { + unsigned char* root_goto = buf_base + buf->root_goto_ofst; + for (uint32 i = 0; i < 255; i++) { + if (root_goto[i] != 0) + fprintf(f, "%c->S:%d, ", (unsigned char)i, root_goto[i]); + } + } else { + fprintf(f, "full fanout\n"); + } + fprintf(f, "}\n"); + + // dump remaining states. 
+ AC_Ofst* state_ofst_vect = (AC_Ofst*)(buf_base + buf->states_ofst_ofst); + for (uint32 i = 1, e = buf->state_num; i < e; i++) { + AC_Ofst ofst = state_ofst_vect[i]; + ASSERT(ofst == state_ofst[i]); + fprintf(f, "S:%d, ofst:%d, goto={", i, ofst); + + AC_State* s = (AC_State*)(buf_base + ofst); + State_ID kid = s->first_kid; + for (uint32 k = 0, ke = s->goto_num; k < ke; k++, kid++) + fprintf(f, "%c->S:%d, ", s->input_vect[k], kid); + + fprintf(f, "}, fail-link = S:%d, %s\n", s->fail_link, + s->is_term ? "terminal" : ""); + } +} +#endif diff --git a/modules/policy/lua-aho-corasick/ac_fast.hpp b/modules/policy/lua-aho-corasick/ac_fast.hpp new file mode 100644 index 0000000..9ac557c --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_fast.hpp @@ -0,0 +1,124 @@ +#ifndef AC_FAST_H +#define AC_FAST_H + +#include <vector> +#include "ac.h" +#include "ac_slow.hpp" + +using namespace std; + +class ACS_Constructor; + +typedef uint32 AC_Ofst; +typedef uint32 State_ID; + +// The entire "fast" AC graph is converted from its "slow" version, and store +// in an consecutive trunk of memory or "buffer". Since the pointers in the +// fast AC graph are represented as offset relative to the base address of +// the buffer, this fast AC graph is position-independent, meaning cloning +// the fast graph is just to memcpy the entire buffer. +// +// The buffer is laid-out as following: +// +// 1. The buffer header. (i.e. the AC_Buffer content) +// 2. root-node's goto functions. It is represented as an array indiced by +// root-node's valid inputs, and the element is the ID of the corresponding +// transition state (aka kid). To save space, we used 8-bit to represent +// the IDs. ID of root's kids starts with 1. +// +// Root may have 255 valid inputs. In this speical case, i-th element +// stores value i -- i.e the i-th state. So, we don't need such array +// at all. On the other hand, 8-bit is insufficient to encode kids' ID. +// +// 3. 
An array indiced by state's id, and the element is the offset +// of correspoding state wrt the base address of the buffer. +// +// 4. the contents of states. +// +typedef struct { + buf_header_t hdr; // The header exposed to the user using this lib. +#ifdef VERIFY + ACS_Constructor* slow_impl; +#endif + uint32 buf_len; + AC_Ofst root_goto_ofst; // addr of root node's goto() function. + AC_Ofst states_ofst_ofst; // addr of state pointer vector (indiced by id) + AC_Ofst first_state_ofst; // addr of the first state in the buffer. + uint16 root_goto_num; // fan-out of root-node. + uint16 state_num; // number of states + + // Followed by the gut of the buffer: + // 1. map: root's-valid-input -> kid's id + // 2. map: state's ID -> offset of the state + // 3. states' content. +} AC_Buffer; + +// Depict the state of "fast" AC graph. +typedef struct { + // transition are sorted. For instance, state s1, has two transitions : + // goto(b) -> S_b, goto(a)->S_a. The inputs are sorted in the ascending + // order, and the target states are permuted accordingly. In this case, + // the inputs are sorted as : a, b, and the target states are permuted + // into S_a, S_b. So, S_a is the 1st kid, the ID of kids are consecutive, + // so we don't need to save all the target kids. + // + State_ID first_kid; + AC_Ofst fail_link; + short depth; // How far away from root. + unsigned short is_term; // Is terminal node. if is_term != 0, it encodes + // the value of "1 + pattern-index". + unsigned char goto_num; // The number of valid transition. + InputTy input_vect[1]; // Vector of valid input. Must be last field! +} AC_State; + +class Buf_Allocator { +public: + Buf_Allocator() : _buf(0) {} + virtual ~Buf_Allocator() { free(); } + + virtual AC_Buffer* alloc(int sz) = 0; + virtual void free() {}; +protected: + AC_Buffer* _buf; +}; + +// Convert slow-AC-graph into fast one. 
+class AC_Converter { +public: + AC_Converter(ACS_Constructor& acs, Buf_Allocator& ba) : + _acs(acs), _buf_alloc(ba) {} + AC_Buffer* Convert(); + +private: + // Return the size in byte needed to to save the specified state. + uint32 Calc_State_Sz(const ACS_State *) const; + + // In fast-AC-graph, the ID is bit trikcy. Given a state of slow-graph, + // this function is to return the ID of its counterpart in the fast-graph. + State_ID Get_Renumbered_Id(const ACS_State *s) const { + const vector<uint32> &m = _id_map; + return m[s->Get_ID()]; + } + + AC_Buffer* Alloc_Buffer(); + void Populate_Root_Goto_Func(AC_Buffer *, GotoVect&); + +#ifdef DEBUG + void dump_buffer(AC_Buffer*, FILE*); +#endif + +private: + ACS_Constructor& _acs; + Buf_Allocator& _buf_alloc; + + // map: ID of state in slow-graph -> ID of counterpart in fast-graph. + vector<uint32> _id_map; + + // map: ID of state in slow-graph -> offset of counterpart in fast-graph. + vector<AC_Ofst> _ofst_map; +}; + +ac_result_t Match(AC_Buffer* buf, const char* str, uint32 len); +ac_result_t Match_Longest_L(AC_Buffer* buf, const char* str, uint32 len); + +#endif // AC_FAST_H diff --git a/modules/policy/lua-aho-corasick/ac_lua.cxx b/modules/policy/lua-aho-corasick/ac_lua.cxx new file mode 100644 index 0000000..ad7307e --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_lua.cxx @@ -0,0 +1,173 @@ +// Interface functions for libac.so +// +#include <vector> +#include <string> +#include "ac_slow.hpp" +#include "ac_fast.hpp" +#include "ac.h" // for the definition of ac_result_t +#include "ac_util.hpp" + +extern "C" { + #include <lua.h> + #include <lauxlib.h> +} + +#if defined(USE_SLOW_VER) +#error "Not going to implement it" +#endif + +using namespace std; +static const char* tname = "aho-corasick"; + +class BufAlloc : public Buf_Allocator { +public: + BufAlloc(lua_State* L) : _L(L) {} + virtual AC_Buffer* alloc(int sz) { + return (AC_Buffer*)lua_newuserdata (_L, sz); + } + + // Let GC to take care. 
+ virtual void free() {} + +private: + lua_State* _L; +}; + +static bool +_create_helper(lua_State* L, const vector<const char*>& str_v, + const vector<unsigned int>& strlen_v) { + ASSERT(str_v.size() == strlen_v.size()); + + ACS_Constructor acc; + BufAlloc ba(L); + + // Step 1: construt the slow version. + unsigned int strnum = str_v.size(); + const char** str_vect = new const char*[strnum]; + unsigned int* strlen_vect = new unsigned int[strnum]; + + int idx = 0; + for (vector<const char*>::const_iterator i = str_v.begin(), e = str_v.end(); + i != e; i++) { + str_vect[idx++] = *i; + } + + idx = 0; + for (vector<unsigned int>::const_iterator i = strlen_v.begin(), + e = strlen_v.end(); i != e; i++) { + strlen_vect[idx++] = *i; + } + + acc.Construct(str_vect, strlen_vect, idx); + delete[] str_vect; + delete[] strlen_vect; + + // Step 2: convert to fast version + AC_Converter cvt(acc, ba); + return cvt.Convert() != 0; +} + +static ac_result_t +_match_helper(buf_header_t* ac, const char *str, unsigned int len) { + AC_Buffer* buf = (AC_Buffer*)(void*)ac; + ASSERT(ac->magic_num == AC_MAGIC_NUM); + + ac_result_t r = Match(buf, str, len); + return r; +} + +// LUA sematic: +// input: array of strings +// output: userdata containing the AC-graph (i.e. the AC_Buffer). +// +static int +lac_create(lua_State* L) { + // The table of the array must be the 1st argument. + int input_tab = 1; + + luaL_checktype(L, input_tab, LUA_TTABLE); + + // Init the "iteartor". + lua_pushnil(L); + + vector<const char*> str_v; + vector<unsigned int> strlen_v; + + // Loop over the elements + while (lua_next(L, input_tab)) { + size_t str_len; + const char* s = luaL_checklstring(L, -1, &str_len); + str_v.push_back(s); + strlen_v.push_back(str_len); + + // remove the value, but keep the key as the iterator. + lua_pop(L, 1); + } + + // pop the nil value + lua_pop(L, 1); + + if (_create_helper(L, str_v, strlen_v)) { + // The AC graph, as a userdata is already pushed to the stack, hence 1. 
+ return 1; + } + + return 0; +} + +// Lua input: +// arg1: the userdata, representing the AC graph, returned from lac_create(). +// arg2: the string to be matched. +// +// Lua return: +// if match, return index range of the match; otherwise nil is returned. +// +static int +lac_match(lua_State* L) { + buf_header_t* ac = (buf_header_t*)lua_touserdata(L, 1); + if (!ac) { + luaL_checkudata(L, 1, tname); + return 0; + } + + size_t len; + const char* str; + #if LUA_VERSION_NUM >= 502 + str = luaL_tolstring(L, 2, &len); + #else + str = lua_tolstring(L, 2, &len); + #endif + if (!str) { + luaL_checkstring(L, 2); + return 0; + } + + ac_result_t r = _match_helper(ac, str, len); + if (r.match_begin != -1) { + lua_pushinteger(L, r.match_begin); + lua_pushinteger(L, r.match_end); + return 2; + } + + return 0; +} + +static const struct luaL_Reg lib_funcs[] = { + { "create", lac_create }, + { "match", lac_match }, + {0, 0} +}; + +extern "C" int AC_EXPORT +luaopen_ahocorasick(lua_State* L) { + luaL_newmetatable(L, tname); + +#if LUA_VERSION_NUM == 501 + luaL_register(L, tname, lib_funcs); +#elif LUA_VERSION_NUM >= 502 + luaL_newlib(L, lib_funcs); +#else + #error "Don't know how to do it right" +#endif + return 1; +} diff --git a/modules/policy/lua-aho-corasick/ac_slow.cxx b/modules/policy/lua-aho-corasick/ac_slow.cxx new file mode 100644 index 0000000..cb3957a --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_slow.cxx @@ -0,0 +1,318 @@ +#include <ctype.h> +#include <strings.h> // for bzero +#include <algorithm> +#include "ac_slow.hpp" +#include "ac.h" + +////////////////////////////////////////////////////////////////////////// +// +// Implementation of AhoCorasick_Slow +// +////////////////////////////////////////////////////////////////////////// +// +ACS_Constructor::ACS_Constructor() : _next_node_id(1) { + _root = new_state(); + _root_char = new InputTy[256]; + bzero((void*)_root_char, 256); + +#ifdef VERIFY + _pattern_buf = 0; +#endif +} + 
+ACS_Constructor::~ACS_Constructor() { + for (std::vector<ACS_State* >::iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + delete *i; + } + _all_states.clear(); + delete[] _root_char; + +#ifdef VERIFY + delete[] _pattern_buf; +#endif +} + +ACS_State* +ACS_Constructor::new_state() { + ACS_State* t = new ACS_State(_next_node_id++); + _all_states.push_back(t); + return t; +} + +void +ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len, + int pattern_idx) { + ACS_State* state = _root; + for (unsigned int i = 0; i < str_len; i++) { + const char c = str[i]; + ACS_State* new_s = state->Get_Goto(c); + if (!new_s) { + new_s = new_state(); + new_s->_depth = state->_depth + 1; + state->Set_Goto(c, new_s); + } + state = new_s; + } + state->_is_terminal = true; + state->set_Pattern_Idx(pattern_idx); +} + +void +ACS_Constructor::Propagate_faillink() { + ACS_State* r = _root; + std::vector<ACS_State*> wl; + + const ACS_Goto_Map& m = r->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) { + ACS_State* s = i->second; + s->_fail_link = r; + wl.push_back(s); + } + + // For any input c, make sure "goto(root, c)" is valid, which makes the + // fail-link propagation a lot easier.
+ ACS_Goto_Map goto_save = r->_goto_map; + for (uint32 i = 0; i <= 255; i++) { + ACS_State* s = r->Get_Goto(i); + if (!s) r->Set_Goto(i, r); + } + + for (uint32 i = 0; i < wl.size(); i++) { + ACS_State* s = wl[i]; + ACS_State* fl = s->_fail_link; + + const ACS_Goto_Map& tran_map = s->Get_Goto_Map(); + + for (ACS_Goto_Map::const_iterator ii = tran_map.begin(), + ee = tran_map.end(); ii != ee; ii++) { + InputTy c = ii->first; + ACS_State *tran = ii->second; + + ACS_State* tran_fl = 0; + for (ACS_State* fl_walk = fl; ;) { + if (ACS_State* t = fl_walk->Get_Goto(c)) { + tran_fl = t; + break; + } else { + fl_walk = fl_walk->Get_FailLink(); + } + } + + tran->_fail_link = tran_fl; + wl.push_back(tran); + } + } + + // Remove "goto(root, c) == root" transitions + r->_goto_map = goto_save; +} + +void +ACS_Constructor::Construct(const char** strv, unsigned int* strlenv, + uint32 strnum) { + Save_Patterns(strv, strlenv, strnum); + + for (uint32 i = 0; i < strnum; i++) { + Add_Pattern(strv[i], strlenv[i], i); + } + + Propagate_faillink(); + unsigned char* p = _root_char; + + const ACS_Goto_Map& m = _root->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); + i != e; i++) { + p[i->first] = 1; + } +} + +Match_Result +ACS_Constructor::MatchHelper(const char *str, uint32 len) const { + const ACS_State* root = _root; + const ACS_State* state = root; + + uint32 idx = 0; + while (idx < len) { + InputTy c = str[idx]; + idx++; + if (_root_char[c]) { + state = root->Get_Goto(c); + break; + } + } + + if (unlikely(state->is_Terminal())) { + // This could happen if one of the patterns has only one char!
+ uint32 pos = idx - 1; + Match_Result r(pos - state->Get_Depth() + 1, pos, + state->get_Pattern_Idx()); + return r; + } + + while (idx < len) { + InputTy c = str[idx]; + ACS_State* gs = state->Get_Goto(c); + + if (!gs) { + ACS_State* fl = state->Get_FailLink(); + if (fl == root) { + while (idx < len) { + InputTy c = str[idx]; + idx++; + if (_root_char[c]) { + state = root->Get_Goto(c); + break; + } + } + } else { + state = fl; + } + } else { + idx ++; + state = gs; + } + + if (state->is_Terminal()) { + uint32 pos = idx - 1; + Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos, + state->get_Pattern_Idx()); + return r; + } + } + + return Match_Result(-1, -1, -1); +} + +#ifdef DEBUG +void +ACS_Constructor::dump_text(const char* txtfile) const { + FILE* f = fopen(txtfile, "w+"); + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State* s = *i; + + fprintf(f, "S%d goto:{", s->Get_ID()); + const ACS_Goto_Map& goto_func = s->Get_Goto_Map(); + + for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end(); + i != e; i++) { + InputTy input = i->first; + ACS_State* tran = i->second; + if (isprint(input)) + fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID()); + else + fprintf(f, "%#x -> S:%d,", input, tran->Get_ID()); + } + fprintf(f, "} "); + + if (s->_fail_link) { + fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID()); + } + + if (s->_is_terminal) { + fprintf(f, ", terminal"); + } + + fprintf(f, "\n"); + } + fclose(f); +} + +void +ACS_Constructor::dump_dot(const char *dotfile) const { + FILE* f = fopen(dotfile, "w+"); + const char* indent = " "; + + fprintf(f, "digraph G {\n"); + + // Emit node information + fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID()); + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State *s = *i; + if (s->_is_terminal) { + fprintf(f, "%s%d [shape=doublecircle];\n", indent, 
s->Get_ID()); + } + } + fprintf(f, "\n"); + + // Emit edge information + for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(), + e = _all_states.end(); i != e; i++) { + ACS_State* s = *i; + uint32 id = s->Get_ID(); + + const ACS_Goto_Map& m = s->Get_Goto_Map(); + for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end(); + ii != ee; ii++) { + InputTy input = ii->first; + ACS_State* tran = ii->second; + if (isalnum(input)) + fprintf(f, "%s%d -> %d [label=%c];\n", + indent, id, tran->Get_ID(), input); + else + fprintf(f, "%s%d -> %d [label=\"%#x\"];\n", + indent, id, tran->Get_ID(), input); + + } + + // Emit fail-link + ACS_State* fl = s->Get_FailLink(); + if (fl && fl != _root) { + fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n", + indent, id, fl->Get_ID()); + } + } + fprintf(f, "}\n"); + fclose(f); +} +#endif + +#ifdef VERIFY +void +ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r) + const { + if (r->begin >= 0) { + unsigned len = r->end - r->begin + 1; + int ptn_idx = r->pattern_idx; + + ASSERT(ptn_idx >= 0 && + len == get_ith_Pattern_Len(ptn_idx) && + memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0); + } +} + +void +ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv, + int pattern_num) { + // calculate the total size needed to save all patterns. + // + int buf_size = 0; + for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; } + + // HINT: patterns are delimited by '\0' in order to ease debugging. 
+ buf_size += pattern_num; + ASSERT(_pattern_buf == 0); + _pattern_buf = new char[buf_size + 1]; + #define MAGIC_NUM 0x5a + _pattern_buf[buf_size] = MAGIC_NUM; + + int ofst = 0; + _pattern_lens.resize(pattern_num); + _pattern_vect.resize(pattern_num); + for (int i = 0; i < pattern_num; i++) { + int l = strlenv[i]; + _pattern_lens[i] = l; + _pattern_vect[i] = _pattern_buf + ofst; + + memcpy(_pattern_buf + ofst, strv[i], l); + ofst += l; + _pattern_buf[ofst++] = '\0'; + } + + ASSERT(_pattern_buf[buf_size] == MAGIC_NUM); + #undef MAGIC_NUM +} + +#endif diff --git a/modules/policy/lua-aho-corasick/ac_slow.hpp b/modules/policy/lua-aho-corasick/ac_slow.hpp new file mode 100644 index 0000000..030b95d --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_slow.hpp @@ -0,0 +1,158 @@ +#ifndef MY_AC_H +#define MY_AC_H + +#include <string.h> +#include <stdio.h> +#include <map> +#include <vector> +#include <algorithm> // for std::sort +#include "ac_util.hpp" + +// Forward decl. the acronym "ACS" stands for "Aho-Corasick Slow implementation" +class ACS_State; +class ACS_Constructor; +class AhoCorasick; + +using namespace std; + +typedef std::map<InputTy, ACS_State*> ACS_Goto_Map; + +class Match_Result { +public: + int begin; + int end; + int pattern_idx; + Match_Result(int b, int e, int p): begin(b), end(e), pattern_idx(p) {} +}; + +typedef pair<InputTy, ACS_State *> GotoPair; +typedef vector<GotoPair> GotoVect; + +// Sorting functor +class GotoSort { +public: + bool operator() (const GotoPair& g1, const GotoPair& g2) { + return g1.first < g2.first; + } +}; + +class ACS_State { +friend class ACS_Constructor; + +public: + ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0), + _is_terminal(false), _fail_link(0){} + ~ACS_State() {}; + + void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; } + ACS_State *Get_Goto(InputTy c) const { + ACS_Goto_Map::const_iterator iter = _goto_map.find(c); + return iter != _goto_map.end() ? 
(*iter).second : 0; + } + + // Return all transitions sorted in the ascending order of their input. + void Get_Sorted_Gotos(GotoVect& Gotos) const { + const ACS_Goto_Map& m = _goto_map; + Gotos.clear(); + for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); + i != e; i++) { + Gotos.push_back(GotoPair(i->first, i->second)); + } + sort(Gotos.begin(), Gotos.end(), GotoSort()); + } + + ACS_State* Get_FailLink() const { return _fail_link; } + uint32 Get_GotoNum() const { return _goto_map.size(); } + uint32 Get_ID() const { return _id; } + uint32 Get_Depth() const { return _depth; } + const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; } + bool is_Terminal() const { return _is_terminal; } + int get_Pattern_Idx() const { + ASSERT(is_Terminal() && _pattern_idx >= 0); + return _pattern_idx; + } + +private: + void set_Pattern_Idx(int idx) { + ASSERT(is_Terminal()); + _pattern_idx = idx; + } + +private: + uint32 _id; + int _pattern_idx; + short _depth; + bool _is_terminal; + ACS_Goto_Map _goto_map; + ACS_State* _fail_link; +}; + +class ACS_Constructor { +public: + ACS_Constructor(); + ~ACS_Constructor(); + + void Construct(const char** strv, unsigned int* strlenv, + unsigned int strnum); + + Match_Result Match(const char* s, uint32 len) const { + Match_Result r = MatchHelper(s, len); + Verify_Result(s, &r); + return r; + } + + Match_Result Match(const char* s) const { return Match(s, strlen(s)); } + +#ifdef DEBUG + void dump_text(const char* = "ac.txt") const; + void dump_dot(const char* = "ac.dot") const; +#endif + const ACS_State *Get_Root_State() const { return _root; } + const vector<ACS_State*>& Get_All_States() const { + return _all_states; + } + + uint32 Get_Next_Node_Id() const { return _next_node_id; } + uint32 Get_State_Num() const { return _next_node_id - 1; } + +private: + void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx); + ACS_State* new_state(); + void Propagate_faillink(); + + Match_Result MatchHelper(const char*, 
uint32 len) const; + +#ifdef VERIFY + void Verify_Result(const char* subject, const Match_Result* r) const; + void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len); + const char* get_ith_Pattern(unsigned i) const { + ASSERT(i < _pattern_vect.size()); + return _pattern_vect.at(i); + } + unsigned get_ith_Pattern_Len(unsigned i) const { + ASSERT(i < _pattern_lens.size()); + return _pattern_lens.at(i); + } +#else + void Verify_Result(const char* subject, const Match_Result* r) const { + (void)subject; (void)r; + } + void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) { + (void)strv; (void)strlenv; + } +#endif + +private: + ACS_State* _root; + vector<ACS_State*> _all_states; + unsigned char* _root_char; + uint32 _next_node_id; + +#ifdef VERIFY + char* _pattern_buf; + vector<int> _pattern_lens; + vector<char*> _pattern_vect; +#endif +}; + +#endif diff --git a/modules/policy/lua-aho-corasick/ac_util.hpp b/modules/policy/lua-aho-corasick/ac_util.hpp new file mode 100644 index 0000000..56fd46c --- /dev/null +++ b/modules/policy/lua-aho-corasick/ac_util.hpp @@ -0,0 +1,69 @@ +/* + Copyright (c) 2014 CloudFlare, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following disclaimer + in the documentation and/or other materials provided with the + distribution. + * Neither the name of CloudFlare, Inc. nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#ifndef AC_UTIL_H +#define AC_UTIL_H + +#ifdef DEBUG +#include <stdio.h> // for fprintf +#include <stdlib.h> // for abort +#endif + +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long uint64; +typedef unsigned char InputTy; + +#ifdef DEBUG + // Usage examples: ASSERT(a > b), ASSERT(foo() && "Oops, foo() returned 0"); + #define ASSERT(c) if (!(c))\ + { fprintf(stderr, "%s:%d Assert: %s\n", __FILE__, __LINE__, #c); abort(); } +#else + #define ASSERT(c) ((void)0) +#endif + +#define likely(x) __builtin_expect((x),1) +#define unlikely(x) __builtin_expect((x),0) + +#ifndef offsetof +#define offsetof(st, m) ((size_t)(&((st *)0)->m)) +#endif + +typedef enum { + IMPL_SLOW_VARIANT = 1, + IMPL_FAST_VARIANT = 2, +} impl_var_t; + +#define AC_MAGIC_NUM 0x5a +typedef struct { + unsigned char magic_num; + unsigned char impl_variant; +} buf_header_t; + +#endif //AC_UTIL_H diff --git a/modules/policy/lua-aho-corasick/load_ac.lua b/modules/policy/lua-aho-corasick/load_ac.lua new file mode 100644 index 0000000..eb70446 --- /dev/null +++ b/modules/policy/lua-aho-corasick/load_ac.lua @@ -0,0 +1,90 @@ +-- Helper wrapper script for loading shared object libac.so (FFI interface) +-- from 
package.cpath instead of LD_LIBRARY_PATH. +-- + +local ffi = require 'ffi' +ffi.cdef[[ + void* ac_create(const char** str_v, unsigned int* strlen_v, + unsigned int v_len); + int ac_match2(void*, const char *str, int len); + void ac_free(void*); +]] + +local _M = {} + +local string_gmatch = string.gmatch +local string_match = string.match + +local ac_lib = nil +local ac_create = nil +local ac_match = nil +local ac_free = nil + +--[[ Find the shared object file in package.cpath, obviating the need to set + LD_LIBRARY_PATH +]] +local function find_shared_obj(cpath, so_name) + for k, v in string_gmatch(cpath, "[^;]+") do + local so_path = string_match(k, "(.*/)") + if so_path then + -- "so_path" could be nil, e.g. when the dir path component is "." + so_path = so_path .. so_name + + -- Don't get me wrong, the only way to know if a file exists is + -- trying to open it. + local f = io.open(so_path) + if f ~= nil then + io.close(f) + return so_path + end + end + end +end + +function _M.load_ac_lib() + if ac_lib ~= nil then + return ac_lib + else + local so_path = find_shared_obj(package.cpath, "libac.so") + if so_path ~= nil then + ac_lib = ffi.load(so_path) + ac_create = ac_lib.ac_create + ac_match = ac_lib.ac_match2 + ac_free = ac_lib.ac_free + return ac_lib + end + end +end + +-- Create an Aho-Corasick instance, and return the instance if it was +-- successful. +function _M.create_ac(dict) + local strnum = #dict + if ac_lib == nil then + _M.load_ac_lib() + end + + local str_v = ffi.new("const char *[?]", strnum) + local strlen_v = ffi.new("unsigned int [?]", strnum) + + for i = 1, strnum do + local s = dict[i] + str_v[i - 1] = s + strlen_v[i - 1] = #s + end + + local ac = ac_create(str_v, strlen_v, strnum); + if ac ~= nil then + return ffi.gc(ac, ac_free) + end +end + +-- Return nil if str doesn't match the dictionary, else return non-nil.
+function _M.match(ac, str) + local r = ac_match(ac, str, #str); + if r >= 0 then + return r + end +end + +return _M diff --git a/modules/policy/lua-aho-corasick/mytest.cxx b/modules/policy/lua-aho-corasick/mytest.cxx new file mode 100644 index 0000000..ef3dc87 --- /dev/null +++ b/modules/policy/lua-aho-corasick/mytest.cxx @@ -0,0 +1,200 @@ +#include <stdio.h> +#include <string.h> +#include <vector> +#include "ac.h" + +using namespace std; + +///////////////////////////////////////////////////////////////////////// +// +// Test using strings from input files +// +///////////////////////////////////////////////////////////////////////// +// +class BigFileTester { +public: + BigFileTester(const char* filepath); + +private: + void Genector +privaete: + const char* _msg; + int _msg_len; + int _key_num; // number of strings in dictionary + int _key_len_idx; +}; + +///////////////////////////////////////////////////////////////////////// +// +// Simple (yet maybe tricky) testings +// +///////////////////////////////////////////////////////////////////////// +// +typedef struct { + const char* str; + const char* match; +} StrPair; + +typedef struct { + const char* name; + const char** dict; + StrPair* strpairs; + int dict_len; + int strpair_num; +} TestingCase; + +class Tests { +public: + Tests(const char* name, + const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num) { + if (!_tests) + _tests = new vector<TestingCase>; + + TestingCase tc; + tc.name = name; + tc.dict = dict; + tc.strpairs = strpairs; + tc.dict_len = dict_len; + tc.strpair_num = strpair_num; + _tests->push_back(tc); + } + + static vector<TestingCase>* Get_Tests() { return _tests; } + static void Erase_Tests() { delete _tests; _tests = 0; } + +private: + static vector<TestingCase> *_tests; +}; + +vector<TestingCase>* Tests::_tests = 0; + +static int +simple_test(void) { + int total = 0; + int fail = 0; + + vector<TestingCase> *tests = Tests::Get_Tests(); + if (!tests) + return 0; + + for 
(vector<TestingCase>::iterator i = tests->begin(), e = tests->end(); + i != e; i++) { + TestingCase& t = *i; + fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); + for (int i = 0, e = t.dict_len, need_break=0; i < e; i++) { + fprintf(stdout, "%s, ", t.dict[i]); + if (need_break++ == 16) { + fputs("\n ", stdout); + need_break = 0; + } + } + fputs("]\n", stdout); + + /* Create the dictionary */ + int dict_len = t.dict_len; + ac_t* ac = ac_create(t.dict, dict_len); + + for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { + const StrPair& sp = t.strpairs[ii]; + const char *str = sp.str; // the string to be matched + const char *match = sp.match; + + fprintf(stdout, "[%3d] Testing '%s' : ", total, str); + + int len = strlen(str); + ac_result_t r = ac_match(ac, str, len); + int m_b = r.match_begin; + int m_e = r.match_end; + + // The return value per se is insane. + if (m_b > m_e || + ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { + fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); + fail ++; + continue; + } + + // If the string is not supposed to match the dictionary. + if (!match) { + if (m_b != -1 || m_e != -1) { + fail ++; + fprintf(stdout, "Not Supposed to match (%d, %d) \n", + m_b, m_e); + } else + fputs("Pass\n", stdout); + continue; + } + + // The string or one of its substrings matches the dict. + if (m_b >= len || m_e >= len) { + fail ++; + fprintf(stdout, + "Return value >= the length of the string (%d, %d)\n", + m_b, m_e); + continue; + } else { + int mlen = strlen(match); + if ((mlen != m_e - m_b + 1) || + strncmp(str + m_b, match, mlen)) { + fail ++; + fprintf(stdout, "Fail\n"); + } else + fprintf(stdout, "Pass\n"); + } + } + fputs("\n", stdout); + ac_free(ac); + } + + fprintf(stdout, "Total : %d, Fail %d\n", total, fail); + + return fail ? 
-1 : 0; +} + +int +main (int argc, char** argv) { + int res = simple_test(); + return res; +}; + +/* test 1*/ +const char *dict1[] = {"he", "she", "his", "her"}; +StrPair strpair1[] = { + {"he", "he"}, {"she", "she"}, {"his", "his"}, + {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, + {"shis2", "his"}, {"ahhe", "he"} +}; +Tests test1("test 1", + dict1, sizeof(dict1)/sizeof(dict1[0]), + strpair1, sizeof(strpair1)/sizeof(strpair1[0])); + +/* test 2*/ +const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ +StrPair strpair2[] = {{"The pot had a handle", 0}}; +Tests test2("test 2", dict2, 2, strpair2, 1); + +/* test 3*/ +const char *dict3[] = {"The"}; +StrPair strpair3[] = {{"The pot had a handle", "The"}}; +Tests test3("test 3", dict3, 1, strpair3, 1); + +/* test 4*/ +const char *dict4[] = {"pot"}; +StrPair strpair4[] = {{"The pot had a handle", "pot"}}; +Tests test4("test 4", dict4, 1, strpair4, 1); + +/* test 5*/ +const char *dict5[] = {"pot "}; +StrPair strpair5[] = {{"The pot had a handle", "pot "}}; +Tests test5("test 5", dict5, 1, strpair5, 1); + +/* test 6*/ +const char *dict6[] = {"ot h"}; +StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; +Tests test6("test 6", dict6, 1, strpair6, 1); + +/* test 7*/ +const char *dict7[] = {"andle"}; +StrPair strpair7[] = {{"The pot had a handle", "andle"}}; +Tests test7("test 7", dict7, 1, strpair7, 1); diff --git a/modules/policy/lua-aho-corasick/tests/Makefile b/modules/policy/lua-aho-corasick/tests/Makefile new file mode 100644 index 0000000..54fd90f --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/Makefile @@ -0,0 +1,65 @@ +OS := $(shell uname) +ifeq ($(OS), Darwin) + SO_EXT := dylib +else + SO_EXT := so +endif + +.PHONY: all clean test runtest benchmark + +PROGRAM = ac_test +BENCHMARK = ac_bench +all: runtest + +CXXFLAGS = -O3 -g -march=native -Wall -DDEBUG +MYCXXFLAGS = -MMD -I.. 
$(CXXFLAGS) +%.o : %.cxx + $(CXX) $< -c $(MYCXXFLAGS) + +-include dep.cxx +SRC = test_main.cxx ac_test_simple.cxx ac_test_aggr.cxx test_bigfile.cxx + +OBJ = ${SRC:.cxx=.o} + +-include test_dep.txt +-include bench_dep.txt + +$(PROGRAM) $(BENCHMARK) : testinput/text.tar testinput/image.bin +$(PROGRAM) : $(OBJ) ../libac.$(SO_EXT) + $(CXX) $(OBJ) -L.. -lac -o $@ + -cat *.d > test_dep.txt + +$(BENCHMARK) : ac_bench.o ../libac.$(SO_EXT) + $(CXX) ac_bench.o -L.. -lac -o $@ + -cat *.d > bench_dep.txt + +ifneq ($(OS), Darwin) +runtest:$(PROGRAM) + LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* + +benchmark:$(BENCHMARK) + LD_LIBRARY_PATH=$(LD_LIBRARY_PATH):.. ./ac_bench + +else +runtest:$(PROGRAM) + DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./$(PROGRAM) testinput/* + +benchmark:$(BENCHMARK) + DYLD_LIBRARY_PATH=$(DYLD_LIBRARY_PATH):.. ./ac_bench + +endif + +testinput/text.tar: + echo "download testing files (gcc tarball)..." + if [ ! -d testinput ] ; then mkdir testinput; fi + cd testinput && \ + curl ftp://ftp.gnu.org/gnu/gcc/gcc-1.42.tar.gz -o text.tar.gz 2>/dev/null \ + && gzip -d text.tar.gz + +testinput/image.bin: + echo "download testing files.." + if [ ! 
-d testinput ] ; then mkdir testinput; fi + curl http://www.3dvisionlive.com/sites/default/files/Curiosity_render_hiresb.jpg -o $@ 2>/dev/null + +clean: + -rm -f *.o *.d dep.txt $(PROGRAM) $(BENCHMARK) diff --git a/modules/policy/lua-aho-corasick/tests/ac_bench.cxx b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx new file mode 100644 index 0000000..421322c --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_bench.cxx @@ -0,0 +1,519 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <sys/time.h> +#include <time.h> +#include <fcntl.h> +#include <unistd.h> +#include <dirent.h> +#include <libgen.h> +#include <errno.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <getopt.h> + +#include <string> +#include <vector> +#include "ac.h" +#include "ac_util.hpp" + +using namespace std; + +static bool SomethingWrong = false; + +static int iteration = 300; +static string dict_dir; +static string obj_file_dir; +static bool print_help = false; +static int piece_size = 1024; + +class PatternSet { +public: + PatternSet(const char* filepath); + ~PatternSet() { Cleanup(); } + + int getPatternNum() const { return _pat_num; } + const char** getPatternVector() const { return _patterns; } + unsigned int* getPatternLenVector() const { return _pat_len; } + + const char* getErrMessage() const { return _errmsg; } + static bool isDictFile(const char* filepath) { + if (strncmp(basename(const_cast<char*>(filepath)), "dict", 4)) + return false; + return true; + } + +private: + bool ExtractPattern(const char* filepath); + void Cleanup(); + + const char** _patterns; + unsigned int* _pat_len; + char* _mmap; + int _fd; + size_t _mmap_size; + int _pat_num; + + const char* _errmsg; +}; + +bool +PatternSet::ExtractPattern(const char* filepath) { + if (!isDictFile(filepath)) + return false; + + struct stat filestat; + if (stat(filepath, &filestat)) { + _errmsg = "fail to call stat()"; + return false; + } + + if (filestat.st_size > 4096 
* 1024) { + /* It doesn't seem to be a dictionary file */ + _errmsg = "file too big?"; + return false; + } + + _fd = open(filepath, 0); + if (_fd == -1) { + _errmsg = "fail to open dictionary file"; + return false; + } + + _mmap_size = filestat.st_size; + _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, + MAP_PRIVATE, _fd, 0); + if (_mmap == MAP_FAILED) { + _errmsg = "fail to call mmap"; + return false; + } + + const char* pat = _mmap; + vector<const char*> pat_vect; + vector<unsigned> pat_len_vect; + + for (size_t i = 0, e = filestat.st_size; i < e; i++) { + if (_mmap[i] == '\r' || _mmap[i] == '\n') { + _mmap[i] = '\0'; + int len = _mmap + i - pat; + if (len > 0) { + pat_vect.push_back(pat); + pat_len_vect.push_back(len); + } + pat = _mmap + i + 1; + } + } + + ASSERT(pat_vect.size() == pat_len_vect.size()); + + int pat_num = pat_vect.size(); + if (pat_num > 0) { + const char** p = _patterns = new const char*[pat_num]; + int i = 0; + for (vector<const char*>::iterator iter = pat_vect.begin(), + iter_e = pat_vect.end(); iter != iter_e; ++iter) { + p[i++] = *iter; + } + + i = 0; + unsigned int* q = _pat_len = new unsigned int[pat_num]; + for (vector<unsigned>::iterator iter = pat_len_vect.begin(), + iter_e = pat_len_vect.end(); iter != iter_e; ++iter) { + q[i++] = *iter; + } + } + + _pat_num = pat_num; + if (pat_num <= 0) { + _errmsg = "no pattern at all"; + return false; + } + + return true; +} + +void +PatternSet::Cleanup() { + if (_mmap != MAP_FAILED) { + munmap(_mmap, _mmap_size); + _mmap = (char*)MAP_FAILED; + _mmap_size = 0; + } + + delete[] _patterns; + delete[] _pat_len; + if (_fd != -1) + close(_fd); + _pat_num = -1; +} + +PatternSet::PatternSet(const char* filepath) { + _patterns = 0; + _pat_len = 0; + _mmap = (char*)MAP_FAILED; + _mmap_size = 0; + _pat_num = -1; + _errmsg = ""; + + if (!ExtractPattern(filepath)) + Cleanup(); +} + +bool +getFilesUnderDir(vector<string>& files, const char* path) { + files.clear(); + + DIR* dir = opendir(path); + 
if (!dir) + return false; + + string path_dir = path; + path_dir += "/"; + + for (;;) { + struct dirent* entry = readdir(dir); + if (entry) { + string filepath = path_dir + entry->d_name; + struct stat file_stat; + if (stat(filepath.c_str(), &file_stat)) { + closedir(dir); + return false; + } + + if (S_ISREG(file_stat.st_mode)) + files.push_back(filepath); + + continue; + } + + if (errno) { + return false; + } + break; + } + closedir(dir); + return true; +} + +class Timer { +public: + Timer() { + my_clock_gettime(&_start); + _stop = _start; + _acc.tv_sec = 0; + _acc.tv_nsec = 0; + } + + const Timer& operator += (const Timer& that) { + time_t sec = _acc.tv_sec + that._acc.tv_sec; + long nsec = _acc.tv_nsec + that._acc.tv_nsec; + if (nsec >= 1000000000) { + nsec -= 1000000000; + sec += 1; + } + _acc.tv_sec = sec; + _acc.tv_nsec = nsec; + return *this; + } + + // return duration in us + size_t getDuration() const { + return _acc.tv_sec * (size_t)1000000 + _acc.tv_nsec/1000; + } + + void Start(bool acc=true) { + my_clock_gettime(&_start); + } + + void Stop() { + my_clock_gettime(&_stop); + struct timespec t = CalcDuration(); + _acc = add_duration(_acc, t); + } + +private: + int my_clock_gettime(struct timespec* t) { +#ifdef __linux + return clock_gettime(CLOCK_PROCESS_CPUTIME_ID, t); +#else + struct timeval tv; + int rc = gettimeofday(&tv, 0); + t->tv_sec = tv.tv_sec; + t->tv_nsec = tv.tv_usec * 1000; + return rc; +#endif + } + + struct timespec add_duration(const struct timespec& dur1, + const struct timespec& dur2) { + time_t sec = dur1.tv_sec + dur2.tv_sec; + long nsec = dur1.tv_nsec + dur2.tv_nsec; + if (nsec >= 1000000000) { + nsec -= 1000000000; + sec += 1; + } + timespec t; + t.tv_sec = sec; + t.tv_nsec = nsec; + + return t; + } + + struct timespec CalcDuration() const { + timespec diff; + if ((_stop.tv_nsec - _start.tv_nsec)<0) { + diff.tv_sec = _stop.tv_sec - _start.tv_sec - 1; + diff.tv_nsec = 1000000000 + _stop.tv_nsec - _start.tv_nsec; + } else { + 
diff.tv_sec = _stop.tv_sec - _start.tv_sec; + diff.tv_nsec = _stop.tv_nsec - _start.tv_nsec; + } + return diff; + } + + struct timespec _start; + struct timespec _stop; + struct timespec _acc; +}; + +class Benchmark { +public: + Benchmark(const PatternSet& pat_set, const char* infile): + _pat_set(pat_set), _infile(infile) { + _mmap = (char*)MAP_FAILED; + _file_sz = 0; + _fd = -1; + } + + ~Benchmark() { + if (_mmap != MAP_FAILED) + munmap(_mmap, _file_sz); + if (_fd != -1) + close(_fd); + } + + bool Run(int iteration); + const Timer& getTimer() const { return _timer; } + +private: + const PatternSet& _pat_set; + const char* _infile; + char* _mmap; + int _fd; + size_t _file_sz; // input file size + Timer _timer; +}; + +bool +Benchmark::Run(int iteration) { + if (_pat_set.getPatternNum() <= 0) { + SomethingWrong = true; + return false; + } + + if (_mmap == MAP_FAILED) { + struct stat filestat; + if (stat(_infile, &filestat)) { + SomethingWrong = true; + return false; + } + + if (!S_ISREG(filestat.st_mode)) { + SomethingWrong = true; + return false; + } + + _fd = open(_infile, 0); + if (_fd == -1) + return false; + + _mmap = (char*)mmap(0, filestat.st_size, PROT_READ|PROT_WRITE, + MAP_PRIVATE, _fd, 0); + + if (_mmap == MAP_FAILED) { + SomethingWrong = true; + return false; + } + + _file_sz = filestat.st_size; + } + + ac_t* ac = ac_create(_pat_set.getPatternVector(), + _pat_set.getPatternLenVector(), + _pat_set.getPatternNum()); + if (!ac) { + SomethingWrong = true; + return false; + } + + int piece_num = _file_sz/piece_size; + + _timer.Start(false); + + /* Stupid compiler may not be able to promote piece_size into register. + * Do it manually. 
+ */ + int piece_sz = piece_size; + for (int i = 0; i < iteration; i++) { + size_t match_ofst = 0; + for (int piece_idx = 0; piece_idx < piece_num; piece_idx ++) { + ac_match2(ac, _mmap + match_ofst, piece_sz); + match_ofst += piece_sz; + } + if (match_ofst != _file_sz) + ac_match2(ac, _mmap + match_ofst, _file_sz - match_ofst); + } + _timer.Stop(); + return true; +} + +const char* short_opt = "hd:f:i:p:"; +const struct option long_opts[] = { + {"help", no_argument, 0, 'h'}, + {"iteration", required_argument, 0, 'i'}, + {"dictionary-dir", required_argument, 0, 'd'}, + {"obj-file-dir", required_argument, 0, 'f'}, + {"piece-size", required_argument, 0, 'p'}, +}; + +static void +PrintHelp(const char* prog_name) { + const char* msg = +"Usage %s [OPTIONS]\n" +" -d, --dictionary-dir : specify the dictionary directory (./dict by default)\n" +" -f, --obj-file-dir : specify the object file directory\n" +" (./testinput by default)\n" +" -i, --iteration : Run this many iteration for each pattern match\n" +" -p, --piece-size : The size of 'piece' in byte. The input file is\n" +" divided into pieces, and match function is working\n" +" on one piece at a time. 
The default size of piece\n" +" is 1k byte.\n"; + + fprintf(stdout, msg, prog_name); +} + +static bool +getOptions(int argc, char** argv) { + bool dict_dir_set = false; + bool objfile_dir_set = false; + int opt_index; + + while (1) { + if (print_help) break; + + int c = getopt_long(argc, argv, short_opt, long_opts, &opt_index); + + if (c == -1) break; + if (c == 0) { c = long_opts[opt_index].val; } + + switch(c) { + case 'h': + print_help = true; + break; + + case 'i': + iteration = atol(optarg); + break; + + case 'd': + dict_dir = optarg; + dict_dir_set = true; + break; + + case 'f': + obj_file_dir = optarg; + objfile_dir_set = true; + break; + + case 'p': + piece_size = atol(optarg); + break; + + case '?': + default: + return false; + } + } + + if (print_help) + return true; + + string basedir(dirname(argv[0])); + if (!dict_dir_set) + dict_dir = basedir + "/dict"; + + if (!objfile_dir_set) + obj_file_dir = basedir + "/testinput"; + + return true; +} + +int +main(int argc, char** argv) { + if (!getOptions(argc, argv)) + return -1; + + if (print_help) { + PrintHelp(argv[0]); + return 0; + } + +#ifndef __linux + fprintf(stdout, "\n!!!WARNING: On this OS, the execution time is measured" + " by gettimeofday(2) which is imprecise!!!\n\n"); +#endif + + fprintf(stdout, "Test with iteration = %d, piece size = %d, and", + iteration, piece_size); + fprintf(stdout, "\n dictionary dir = %s\n object file dir = %s\n\n", + dict_dir.c_str(), obj_file_dir.c_str()); + + vector<string> dict_files; + vector<string> input_files; + + if (!getFilesUnderDir(dict_files, dict_dir.c_str())) { + fprintf(stdout, "fail to find dictionary files\n"); + return -1; + } + + if (!getFilesUnderDir(input_files, obj_file_dir.c_str())) { + fprintf(stdout, "fail to find test input files\n"); + return -1; + } + + for (vector<string>::iterator diter = dict_files.begin(), + diter_e = dict_files.end(); diter != diter_e; ++diter) { + + const char* dict_name = diter->c_str(); + if 
(!PatternSet::isDictFile(dict_name)) + continue; + + PatternSet ps(dict_name); + if (ps.getPatternNum() <= 0) { + fprintf(stdout, "fail to open dictionary file %s : %s\n", + dict_name, ps.getErrMessage()); + SomethingWrong = true; + continue; + } + + fprintf(stdout, "Using dictionary %s\n", dict_name); + Timer timer; + for (vector<string>::iterator iter = input_files.begin(), + iter_e = input_files.end(); iter != iter_e; ++iter) { + fprintf(stdout, " testing %s ... ", iter->c_str()); + fflush(stdout); + Benchmark bm(ps, iter->c_str()); + bm.Run(iteration); + const Timer& t = bm.getTimer(); + timer += bm.getTimer(); + fprintf(stdout, "elapsed %.3f\n", t.getDuration() / 1000000.0); + } + + fprintf(stdout, + "\n==========================================================\n" + " Total Elapse %.3f\n\n", timer.getDuration() / 1000000.0); + } + + return SomethingWrong ? -1 : 0; +} diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx new file mode 100644 index 0000000..4ea02bc --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_test_aggr.cxx @@ -0,0 +1,135 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + +namespace { +class ACBigFileTester : public BigFileTester { +public: + ACBigFileTester(const char* filepath) : BigFileTester(filepath){}; + +private: + virtual buf_header_t* PM_Create(const char** strv, uint32* strlenv, + uint32 vect_len) { + return (buf_header_t*)ac_create(strv, strlenv, vect_len); + } + + virtual void PM_Free(buf_header_t* PM) { ac_free(PM); } + virtual bool Run_Helper(buf_header_t* PM); +}; + +class ACTestAggressive: public ACTestBase { +public: + ACTestAggressive(const vector<const char*>& files, const char* banner) + : 
ACTestBase(banner), _files(files) {} + virtual bool Run(); + +private: + void PrintSummary(int total, int fail) { + fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); + fflush(stdout); + } + vector<const char*> _files; +}; + +} // end of anonymous namespace + +bool +ACBigFileTester::Run_Helper(buf_header_t* PM) { + int fail = 0; + // advance one chunk at a time. + int len = _msg_len; + int chunk_sz = _chunk_sz; + + vector<const char*> c_style_keys; + for (int i = 0, e = _keys.size(); i != e; i++) { + const char* key = _keys[i].first; + int len = _keys[i].second; + char *t = new char[len+1]; + memcpy(t, key, len); + t[len] = '\0'; + c_style_keys.push_back(t); + } + + for (int ofst = 0, chunk_idx = 0, chunk_num = _chunk_num; + chunk_idx < chunk_num; ofst += chunk_sz, chunk_idx++) { + const char* substring = _msg + ofst; + ac_result_t r = ac_match((ac_t*)(void*)PM, substring , len - ofst); + int m_b = r.match_begin; + int m_e = r.match_end; + + if (m_b < 0 || m_e < 0 || m_e <= m_b || m_e >= len) { + fprintf(stdout, "fail to find match substring[%d:%d])\n", + ofst, len - 1); + fail ++; + continue; + } + + const char* match_str = _msg + len; + int strstr_len = 0; + int key_idx = -1; + + for (int i = 0, e = c_style_keys.size(); i != e; i++) { + const char* key = c_style_keys[i]; + if (const char *m = strstr(substring, key)) { + if (m < match_str) { + match_str = m; + strstr_len = _keys[i].second; + key_idx = i; + } + } + } + ASSERT(key_idx != -1); + if ((match_str - substring != m_b)) { + fprintf(stdout, + "Fail to find match substring[%d:%d])," + " expected to find match at offset %d instead of %d\n", + ofst, len - 1, + (int)(match_str - _msg), ofst + m_b); + fprintf(stdout, "%d vs %d (key idx %d)\n", strstr_len, m_e - m_b + 1, key_idx); + PrintStr(stdout, match_str, strstr_len); + fprintf(stdout, "\n"); + PrintStr(stdout, _msg + ofst + m_b, + m_e - m_b + 1); + fprintf(stdout, "\n"); + fail ++; + } + } + for (vector<const char*>::iterator i = 
c_style_keys.begin(), + e = c_style_keys.end(); i != e; i++) { + delete[] *i; + } + + return fail == 0; +} + +bool +ACTestAggressive::Run() { + int fail = 0; + for (vector<const char*>::iterator i = _files.begin(), e = _files.end(); + i != e; i++) { + ACBigFileTester bft(*i); + if (!bft.Run()) + fail ++; + } + return fail == 0; +} + +bool +Run_AC_Aggressive_Test(const vector<const char*>& files) { + ACTestAggressive t(files, "AC Aggressive test"); + t.PrintBanner(); + return t.Run(); +} diff --git a/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx new file mode 100644 index 0000000..fa2d7fd --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/ac_test_simple.cxx @@ -0,0 +1,275 @@ +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + +namespace { +typedef struct { + const char* str; + const char* match; +} StrPair; + +typedef enum { + MV_FIRST_MATCH = 0, + MV_LEFT_LONGEST = 1, +} MatchVariant; + +typedef struct { + const char* name; + const char** dict; + StrPair* strpairs; + int dict_len; + int strpair_num; + MatchVariant match_variant; +} TestingCase; + +class Tests { +public: + Tests(const char* name, + const char* dict[], int dict_len, + StrPair strpairs[], int strpair_num, + MatchVariant mv = MV_FIRST_MATCH) { + if (!_tests) + _tests = new vector<TestingCase>; + + TestingCase tc; + tc.name = name; + tc.dict = dict; + tc.strpairs = strpairs; + tc.dict_len = dict_len; + tc.strpair_num = strpair_num; + tc.match_variant = mv; + _tests->push_back(tc); + } + + static vector<TestingCase>* Get_Tests() { return _tests; } + static void Erase_Tests() { delete _tests; _tests = 0; } + +private: + static vector<TestingCase> *_tests; +}; + +class LeftLongestTests : public Tests { +public: + LeftLongestTests (const char* name, const char* dict[], int dict_len, + StrPair strpairs[], int 
strpair_num): + Tests(name, dict, dict_len, strpairs, strpair_num, MV_LEFT_LONGEST) { + } +}; + +vector<TestingCase>* Tests::_tests = 0; + +class ACTestSimple: public ACTestBase { +public: + ACTestSimple(const char* banner) : ACTestBase(banner) {} + virtual bool Run(); + +private: + void PrintSummary(int total, int fail) { + fprintf(stdout, "Test count : %d, fail: %d\n", total, fail); + fflush(stdout); + } +}; +} + +bool +ACTestSimple::Run() { + int total = 0; + int fail = 0; + + vector<TestingCase> *tests = Tests::Get_Tests(); + if (!tests) { + PrintSummary(0, 0); + return true; + } + + for (vector<TestingCase>::iterator i = tests->begin(), e = tests->end(); + i != e; i++) { + TestingCase& t = *i; + int dict_len = t.dict_len; + unsigned int* strlen_v = new unsigned int[dict_len]; + + fprintf(stdout, ">Testing %s\nDictionary:[ ", t.name); + for (int i = 0, need_break=0; i < dict_len; i++) { + const char* s = t.dict[i]; + fprintf(stdout, "%s, ", s); + strlen_v[i] = strlen(s); + if (need_break++ == 16) { + fputs("\n ", stdout); + need_break = 0; + } + } + fputs("]\n", stdout); + + /* Create the dictionary */ + ac_t* ac = ac_create(t.dict, strlen_v, dict_len); + delete[] strlen_v; + + for (int ii = 0, ee = t.strpair_num; ii < ee; ii++, total++) { + const StrPair& sp = t.strpairs[ii]; + const char *str = sp.str; // the string to be matched + const char *match = sp.match; + + fprintf(stdout, "[%3d] Testing '%s' : ", total, str); + + int len = strlen(str); + ac_result_t r; + if (t.match_variant == MV_FIRST_MATCH) + r = ac_match(ac, str, len); + else if (t.match_variant == MV_LEFT_LONGEST) + r = ac_match_longest_l(ac, str, len); + else { + ASSERT(false && "Unknown variant"); + } + + int m_b = r.match_begin; + int m_e = r.match_end; + + // The return value per se is insane. 
+ if (m_b > m_e || + ((m_b < 0 || m_e < 0) && (m_b != -1 || m_e != -1))) { + fprintf(stdout, "Insane return value (%d, %d)\n", m_b, m_e); + fail ++; + continue; + } + + // If the string is not supposed to match the dictionary. + if (!match) { + if (m_b != -1 || m_e != -1) { + fail ++; + fprintf(stdout, "Not Supposed to match (%d, %d) \n", + m_b, m_e); + } else + fputs("Pass\n", stdout); + continue; + } + + // The string or its substring is match the dict. + if (m_b >= len || m_b >= len) { + fail ++; + fprintf(stdout, + "Return value >= the length of the string (%d, %d)\n", + m_b, m_e); + continue; + } else { + int mlen = strlen(match); + if ((mlen != m_e - m_b + 1) || + strncmp(str + m_b, match, mlen)) { + fail ++; + fprintf(stdout, "Fail\n"); + } else + fprintf(stdout, "Pass\n"); + } + } + fputs("\n", stdout); + ac_free(ac); + } + + PrintSummary(total, fail); + return fail == 0; +} + +bool +Run_AC_Simple_Test() { + ACTestSimple t("AC Simple test"); + t.PrintBanner(); + return t.Run(); +} + +////////////////////////////////////////////////////////////////////////////// +// +// Testing cases for first-match variant (i.e. 
test ac_match()) +// +////////////////////////////////////////////////////////////////////////////// +// + +/* test 1*/ +const char *dict1[] = {"he", "she", "his", "her"}; +StrPair strpair1[] = { + {"he", "he"}, {"she", "she"}, {"his", "his"}, + {"hers", "he"}, {"ahe", "he"}, {"shhe", "he"}, + {"shis2", "his"}, {"ahhe", "he"} +}; +Tests test1("test 1", + dict1, sizeof(dict1)/sizeof(dict1[0]), + strpair1, sizeof(strpair1)/sizeof(strpair1[0])); + +/* test 2*/ +const char *dict2[] = {"poto", "poto"}; /* duplicated strings*/ +StrPair strpair2[] = {{"The pot had a handle", 0}}; +Tests test2("test 2", dict2, 2, strpair2, 1); + +/* test 3*/ +const char *dict3[] = {"The"}; +StrPair strpair3[] = {{"The pot had a handle", "The"}}; +Tests test3("test 3", dict3, 1, strpair3, 1); + +/* test 4*/ +const char *dict4[] = {"pot"}; +StrPair strpair4[] = {{"The pot had a handle", "pot"}}; +Tests test4("test 4", dict4, 1, strpair4, 1); + +/* test 5*/ +const char *dict5[] = {"pot "}; +StrPair strpair5[] = {{"The pot had a handle", "pot "}}; +Tests test5("test 5", dict5, 1, strpair5, 1); + +/* test 6*/ +const char *dict6[] = {"ot h"}; +StrPair strpair6[] = {{"The pot had a handle", "ot h"}}; +Tests test6("test 6", dict6, 1, strpair6, 1); + +/* test 7*/ +const char *dict7[] = {"andle"}; +StrPair strpair7[] = {{"The pot had a handle", "andle"}}; +Tests test7("test 7", dict7, 1, strpair7, 1); + +const char *dict8[] = {"aaab"}; +StrPair strpair8[] = {{"aaaaaaab", "aaab"}}; +Tests test8("test 8", dict8, 1, strpair8, 1); + +const char *dict9[] = {"haha", "z"}; +StrPair strpair9[] = {{"aaaaz", "z"}, {"z", "z"}}; +Tests test9("test 9", dict9, 2, strpair9, 2); + +/* test the case when input string dosen't contain even a single char + * of the pattern in dictionary. 
+ */ +const char *dict10[] = {"abc"}; +StrPair strpair10[] = {{"cde", 0}}; +Tests test10("test 10", dict10, 1, strpair10, 1); + + +////////////////////////////////////////////////////////////////////////////// +// +// Testing cases for first longest match variant (i.e. +// test ac_match_longest_l()) +// +////////////////////////////////////////////////////////////////////////////// +// + +// This was actually first motivation for left-longest-match +const char *dict100[] = {"Mozilla", "Mozilla Mobile"}; +StrPair strpair100[] = {{"User Agent containing string Mozilla Mobile", "Mozilla Mobile"}}; +LeftLongestTests test100("l_test 100", dict100, 2, strpair100, 1); + +// Dict with single char is tricky +const char *dict101[] = {"a", "abc"}; +StrPair strpair101[] = {{"abcdef", "abc"}}; +LeftLongestTests test101("l_test 101", dict101, 2, strpair101, 1); + +// Testing case with partially overlapping patterns. The purpose is to +// check if the fail-link leading from terminal state is correct. +// +// The fail-link leading from terminal-state does not matter in +// match-first-occurrence variant, as it stop when a terminal is hit. +// +const char *dict102[] = {"abc", "bcdef"}; +StrPair strpair102[] = {{"abcdef", "bcdef"}}; +LeftLongestTests test102("l_test 102", dict102, 2, strpair102, 1); diff --git a/modules/policy/lua-aho-corasick/tests/dict/README.txt b/modules/policy/lua-aho-corasick/tests/dict/README.txt new file mode 100644 index 0000000..cd50b41 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/dict/README.txt @@ -0,0 +1 @@ +This directory contains pattern set of benchmark purpose. 
diff --git a/modules/policy/lua-aho-corasick/tests/dict/dict1.txt b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt new file mode 100644 index 0000000..94085a9 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/dict/dict1.txt @@ -0,0 +1,11 @@ +false_return@ +forloop#haha +wtfprogram +mmaporunmap +ThIs?Module!IsEssential +struct rtlwtf +gettIMEOfdayWrong +edistribution_and_use_in_@source +Copyright~#@ +while {! +!%SQLinje diff --git a/modules/policy/lua-aho-corasick/tests/load_ac_test.lua b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua new file mode 100644 index 0000000..7fb7db9 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/load_ac_test.lua @@ -0,0 +1,82 @@ +-- This script is to test load_ac.lua +-- +-- Some notes: +-- 1. The purpose of this script is not to check if the libac.so work +-- properly, it is to check if there are something stupid in load_ac.lua +-- +-- 2. There are bunch of collectgarbage() calls, the purpose is to make +-- sure the shared lib is not unloaded after GC. + +-- load_ac.lua looks up libac.so via package.cpath rather than LD_LIBRARY_PATH, +-- prepend (instead of appending) some insane paths here to see if it quit +-- prematurely. +-- +package.cpath = ".;./?.so;" .. 
package.cpath + +local ac = require "load_ac" + +local ac_create = ac.create_ac +local ac_match = ac.match +local string_fmt = string.format +local string_sub = string.sub + +local err_cnt = 0 +local function mytest(testname, dict, match, notmatch) + print(">Testing ", testname) + + io.write(string_fmt("Dictionary: ")); + for i=1, #dict do + io.write(string_fmt("%s, ", dict[i])) + end + print "" + + local ac_inst = ac_create(dict); + collectgarbage() + for i=1, #match do + local str = match[i] + io.write(string_fmt("Matching %s, ", str)) + local b = ac_match(ac_inst, str) + if b then + print "pass" + else + err_cnt = err_cnt + 1 + print "fail" + end + collectgarbage() + end + + if notmatch == nil then + return + end + + collectgarbage() + + for i = 1, #notmatch do + local str = notmatch[i] + io.write(string_fmt("*Matching %s, ", str)) + local r = ac_match(ac_inst, str) + if r then + err_cnt = err_cnt + 1 + print("fail") + else + print("succ") + end + collectgarbage() + end + ac_inst = nil + collectgarbage() +end + +print("") +print("====== Test to see if load_ac.lua works properly ========") + +mytest("test1", + {"he", "she", "his", "her", "str\0ing"}, + -- matching cases + { "he", "she", "his", "hers", "ahe", "shhe", "shis2", "ahhe", "str\0ing" }, + + -- not matching case + {"str\0", "str"} + ) + +os.exit((err_cnt == 0) and 0 or 1) diff --git a/modules/policy/lua-aho-corasick/tests/lua_test.lua b/modules/policy/lua-aho-corasick/tests/lua_test.lua new file mode 100644 index 0000000..cfe178f --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/lua_test.lua @@ -0,0 +1,67 @@ +-- This script is to test ahocorasick.so not libac.so +-- +local ac = require "ahocorasick" + +local ac_create = ac.create +local ac_match = ac.match +local string_fmt = string.format +local string_sub = string.sub + +local err_cnt = 0 +local function mytest(testname, dict, match, notmatch) + print(">Testing ", testname) + + io.write(string_fmt("Dictionary: ")); + for i=1, #dict do + 
io.write(string_fmt("%s, ", dict[i])) + end + print "" + + local ac_inst = ac_create(dict); + for i=1, #match do + local str = match[i][1] + local substr = match[i][2] + io.write(string_fmt("Matching %s, ", str)) + local b, e = ac_match(ac_inst, str) + if b and e and (string_sub(str, b+1, e+1) == substr) then + print "pass" + else + err_cnt = err_cnt + 1 + print "fail" + end + --print("gc is called") + collectgarbage() + end + + if notmatch == nil then + return + end + + for i = 1, #notmatch do + local str = notmatch[i] + io.write(string_fmt("*Matching %s, ", str)) + local r = ac_match(ac_inst, str) + if r then + err_cnt = err_cnt + 1 + print("fail") + else + print("succ") + end + collectgarbage() + end +end + +mytest("test1", + {"he", "she", "his", "her", "str\0ing"}, + -- matching cases + { {"he", "he"}, {"she", "she"}, {"his", "his"}, {"hers", "he"}, + {"ahe", "he"}, {"shhe", "he"}, {"shis2", "his"}, {"ahhe", "he"}, + {"str\0ing", "str\0ing"} + }, + + -- not matching case + {"str\0", "str"} + + ) + +os.exit((err_cnt == 0) and 0 or 1) diff --git a/modules/policy/lua-aho-corasick/tests/test_base.hpp b/modules/policy/lua-aho-corasick/tests/test_base.hpp new file mode 100644 index 0000000..7758371 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_base.hpp @@ -0,0 +1,60 @@ +#ifndef TEST_BASE_H +#define TEST_BASE_H + +#include <stdio.h> +#include <string> +#include <stdint.h> + +using namespace std; +class ACTestBase { +public: + ACTestBase(const char* name) :_banner(name) {} + virtual void PrintBanner() { + fprintf(stdout, "\n===== %s ====\n", _banner.c_str()); + } + + virtual bool Run() = 0; +private: + string _banner; +}; + +typedef std::pair<const char*, int> StrInfo; +class BigFileTester { +public: + BigFileTester(const char* filepath); + virtual ~BigFileTester() { Cleanup(); } + + bool Run(); + +protected: + virtual buf_header_t* PM_Create(const char** strv, uint32_t* strlenv, + uint32_t vect_len) = 0; + virtual void PM_Free(buf_header_t*) = 0; + 
virtual bool Run_Helper(buf_header_t* PM) = 0; + + // Return true if the '\0' is valid char of a string. + virtual bool Str_C_Style() { return true; } + + bool GenerateKeys(); + void Cleanup(); + void PrintStr(FILE*, const char* str, int len); + +protected: + const char* _filepath; + int _fd; + vector<StrInfo> _keys; + char* _msg; + int _msg_len; + int _key_num; // number of strings in dictionary + int _chunk_sz; + int _chunk_num; + + int _max_key_num; + int _key_min_len; + int _key_max_len; +}; + +extern bool Run_AC_Simple_Test(); +extern bool Run_AC_Aggressive_Test(const vector<const char*>& files); + +#endif diff --git a/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx new file mode 100644 index 0000000..f189d8d --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_bigfile.cxx @@ -0,0 +1,167 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> + +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +/////////////////////////////////////////////////////////////////////////// +// +// Implementation of BigFileTester +// +/////////////////////////////////////////////////////////////////////////// +// +BigFileTester::BigFileTester(const char* filepath) { + _filepath = filepath; + _fd = -1; + _msg = (char*)MAP_FAILED; + _msg_len = 0; + _key_num = 0; + _chunk_sz = 0; + _chunk_num = 0; + + _max_key_num = 100; + _key_min_len = 20; + _key_max_len = 80; +} + +void +BigFileTester::Cleanup() { + if (_msg != MAP_FAILED) { + munmap((void*)_msg, _msg_len); + _msg = (char*)MAP_FAILED; + _msg_len = 0; + } + + if (_fd != -1) { + close(_fd); + _fd = -1; + } +} + +bool +BigFileTester::GenerateKeys() { + int chunk_sz = 4096; + int max_key_num = _max_key_num; + int key_min_len = _key_min_len; + int key_max_len = _key_max_len; + + int t = _msg_len / chunk_sz; + 
int keynum = t > max_key_num ? max_key_num : t; + + if (keynum <= 4) { + // file is too small + return false; + } + chunk_sz = _msg_len / keynum; + _chunk_sz = chunk_sz; + + // For each chunck, "randomly" grab a sub-string searving + // as key. + int random_ofst[] = { 12, 30, 23, 15 }; + int rofstsz = sizeof(random_ofst)/sizeof(random_ofst[0]); + int ofst = 0; + const char* msg = _msg; + _chunk_num = keynum - 1; + for (int idx = 0, e = _chunk_num; idx < e; idx++) { + const char* key = msg + ofst + idx % rofstsz; + int key_len = key_min_len + idx % (key_max_len - key_min_len); + _keys.push_back(StrInfo(key, key_len)); + ofst += chunk_sz; + } + return true; +} + +bool +BigFileTester::Run() { + // Step 1: Bring the file into memory + fprintf(stdout, "Testing using file '%s'...\n", _filepath); + + int fd = _fd = ::open(_filepath, O_RDONLY); + if (fd == -1) { + perror("open"); + return false; + } + + struct stat sb; + if (fstat(fd, &sb) == -1) { + perror("fstat"); + return false; + } + + if (!S_ISREG (sb.st_mode)) { + fprintf(stderr, "%s is not regular file\n", _filepath); + return false; + } + + int ten_M = 1024 * 1024 * 10; + int map_sz = _msg_len = sb.st_size > ten_M ? ten_M : sb.st_size; + char* p = _msg = + (char*)mmap (0, map_sz, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); + if (p == MAP_FAILED) { + perror("mmap"); + return false; + } + + // Get rid of '\0' if we are picky at it. + if (Str_C_Style()) { + for (int i = 0; i < map_sz; i++) { if (!p[i]) p[i] = 'a'; } + p[map_sz - 1] = 0; + } + + // Step 2: "Fabricate" some keys from the file. 
+ if (!GenerateKeys()) { + close(fd); + return false; + } + + // Step 3: Create PM instance + const char** keys = new const char*[_keys.size()]; + unsigned int* keylens = new unsigned int[_keys.size()]; + + int i = 0; + for (vector<StrInfo>::iterator si = _keys.begin(), se = _keys.end(); + si != se; si++, i++) { + const StrInfo& strinfo = *si; + keys[i] = strinfo.first; + keylens[i] = strinfo.second; + } + + buf_header_t* PM = PM_Create(keys, keylens, i); + delete[] keys; + delete[] keylens; + + // Step 4: Run testing + bool res = Run_Helper(PM); + PM_Free(PM); + + // Step 5: Clanup + munmap(p, map_sz); + _msg = (char*)MAP_FAILED; + close(fd); + _fd = -1; + + fprintf(stdout, "%s\n", res ? "succ" : "fail"); + return res; +} + +void +BigFileTester::PrintStr(FILE* f, const char* str, int len) { + fprintf(f, "{"); + for (int i = 0; i < len; i++) { + unsigned char c = str[i]; + if (isprint(c)) + fprintf(f, "'%c', ", c); + else + fprintf(f, "%#x, ", c); + } + fprintf(f, "}"); +}; diff --git a/modules/policy/lua-aho-corasick/tests/test_main.cxx b/modules/policy/lua-aho-corasick/tests/test_main.cxx new file mode 100644 index 0000000..b4f5225 --- /dev/null +++ b/modules/policy/lua-aho-corasick/tests/test_main.cxx @@ -0,0 +1,33 @@ +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/mman.h> +#include <fcntl.h> +#include <unistd.h> + +#include <stdio.h> +#include <string.h> +#include <vector> +#include <string> +#include "ac.h" +#include "ac_util.hpp" +#include "test_base.hpp" + +using namespace std; + + +///////////////////////////////////////////////////////////////////////// +// +// Simple (yet maybe tricky) testings +// +///////////////////////////////////////////////////////////////////////// +// +int +main (int argc, char** argv) { + bool succ = Run_AC_Simple_Test(); + + vector<const char*> files; + for (int i = 1; i < argc; i++) { files.push_back(argv[i]); } + succ = Run_AC_Aggressive_Test(files) && succ; + + return succ ? 
0 : -1; +}; diff --git a/modules/policy/meson.build b/modules/policy/meson.build new file mode 100644 index 0000000..37f1683 --- /dev/null +++ b/modules/policy/meson.build @@ -0,0 +1,50 @@ +# LUA module: policy +# SPDX-License-Identifier: GPL-3.0-or-later + +lua_mod_src += [ + files('policy.lua'), +] + +config_tests += [ + ['policy', files('policy.test.lua')], + ['policy.slice', files('policy.slice.test.lua')], + ['policy.rpz', files('policy.rpz.test.lua')], +] + +integr_tests += [ + ['policy', meson.current_source_dir() / 'test.integr'], + ['policy.noipv6', meson.current_source_dir() / 'noipv6.test.integr'], + ['policy.noipvx', meson.current_source_dir() / 'noipvx.test.integr'], +] + +# check git submodules were initialized +lua_ac_submodule = run_command(['test', '-r', + '@0@/lua-aho-corasick/ac_fast.cxx'.format(meson.current_source_dir())], + check: false) +if lua_ac_submodule.returncode() != 0 + error('run "git submodule update --init --recursive" to initialize git submodules') +endif + +# compile bundled lua-aho-corasick as shared module +lua_ac_src = files([ + 'lua-aho-corasick/ac_fast.cxx', + 'lua-aho-corasick/ac_lua.cxx', + 'lua-aho-corasick/ac_slow.cxx', +]) + +lua_ac_lib = shared_module( + 'ahocorasick', + lua_ac_src, + cpp_args: [ + '-fvisibility=hidden', + '-Wall', + '-fPIC', + ], + dependencies: [ + luajit_inc, + ], + include_directories: mod_inc_dir, + name_prefix: '', + install: true, + install_dir: lib_dir, +) diff --git a/modules/policy/noipv6.test.integr/broken-ipv6.rpl b/modules/policy/noipv6.test.integr/broken-ipv6.rpl new file mode 100644 index 0000000..cb0738d --- /dev/null +++ b/modules/policy/noipv6.test.integr/broken-ipv6.rpl @@ -0,0 +1,47 @@ +; config options +; SPDX-License-Identifier: GPL-3.0-or-later + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. +CONFIG_END + +SCENARIO_BEGIN Test that IPv6 is not used by kresd. 
+ +RANGE_BEGIN 0 100 + ADDRESS ::1:2:3:4 +RANGE_END + +RANGE_BEGIN 0 100 + ADDRESS 1.2.3.4 + +ENTRY_BEGIN +MATCH opcode qtype qname +ADJUST copy_id +REPLY QR NOERROR +SECTION QUESTION +www.test.org A +SECTION ANSWER +www.test.org 3600 A 4.3.2.1 +ENTRY_END + +RANGE_END + + +STEP 10 QUERY +ENTRY_BEGIN +REPLY RD AD +SECTION QUESTION +www.test.org A +ENTRY_END + +STEP 20 CHECK_ANSWER +ENTRY_BEGIN +MATCH all answer +REPLY QR RD RA NOERROR +SECTION QUESTION +www.test.org A +SECTION ANSWER +www.test.org 3600 A 4.3.2.1 +SECTION AUTHORITY +SECTION ADDITIONAL +ENTRY_END + +SCENARIO_END diff --git a/modules/policy/noipv6.test.integr/deckard.yaml b/modules/policy/noipv6.test.integr/deckard.yaml new file mode 100644 index 0000000..4c1b6f8 --- /dev/null +++ b/modules/policy/noipv6.test.integr/deckard.yaml @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +programs: +- name: kresd + binary: kresd + additional: + - --noninteractive + templates: + - modules/policy/noipv6.test.integr/kresd_config.j2 + - tests/integration/hints_zone.j2 + configs: + - config + - hints diff --git a/modules/policy/noipv6.test.integr/kresd_config.j2 b/modules/policy/noipv6.test.integr/kresd_config.j2 new file mode 100644 index 0000000..a897d17 --- /dev/null +++ b/modules/policy/noipv6.test.integr/kresd_config.j2 @@ -0,0 +1,59 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +{% raw %} +net.ipv6 = false +policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' }))) + +-- make sure DNSSEC is turned off for tests +trust_anchors.remove('.') + +-- Disable RFC5011 TA update +if ta_update then + modules.unload('ta_update') +end + +-- Disable RFC8145 signaling, scenario doesn't provide expected answers +if ta_signal_query then + modules.unload('ta_signal_query') +end + +-- Disable RFC8109 priming, scenario doesn't provide expected answers +if priming then + modules.unload('priming') +end + +-- Disable this module because it make one priming query +if detect_time_skew then + 
modules.unload('detect_time_skew') +end + +_hint_root_file('hints') +cache.size = 2*MB +log_level('debug') +{% endraw %} + +net = { '{{SELF_ADDR}}' } + + +{% if QMIN == "false" %} +option('NO_MINIMIZE', true) +{% else %} +option('NO_MINIMIZE', false) +{% endif %} + + +-- Self-checks on globals +assert(help() ~= nil) +assert(worker.id ~= nil) +-- Self-checks on facilities +assert(cache.count() == 0) +assert(cache.stats() ~= nil) +assert(cache.backends() ~= nil) +assert(worker.stats() ~= nil) +assert(net.interfaces() ~= nil) +-- Self-checks on loaded stuff +assert(net.list()[1].transport.ip == '{{SELF_ADDR}}') +assert(#modules.list() > 0) +-- Self-check timers +ev = event.recurrent(1 * sec, function (ev) return 1 end) +event.cancel(ev) +ev = event.after(0, function (ev) return 1 end) diff --git a/modules/policy/noipvx.test.integr/broken-ipvx.rpl b/modules/policy/noipvx.test.integr/broken-ipvx.rpl new file mode 100644 index 0000000..60ed618 --- /dev/null +++ b/modules/policy/noipvx.test.integr/broken-ipvx.rpl @@ -0,0 +1,35 @@ +; config options +; SPDX-License-Identifier: GPL-3.0-or-later + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. 
+CONFIG_END + +SCENARIO_BEGIN Test that neither IPv6 nor IPv4 is used by kresd :-) + +RANGE_BEGIN 0 100 + ADDRESS ::1:2:3:4 +RANGE_END + +RANGE_BEGIN 0 100 + ADDRESS 1.2.3.4 +RANGE_END + + +STEP 10 QUERY +ENTRY_BEGIN +REPLY RD AD +SECTION QUESTION +www.test.org A +ENTRY_END + +STEP 20 CHECK_ANSWER +ENTRY_BEGIN +MATCH all answer +REPLY QR RD RA SERVFAIL +SECTION QUESTION +www.test.org A +SECTION ANSWER +SECTION AUTHORITY +SECTION ADDITIONAL +ENTRY_END + +SCENARIO_END diff --git a/modules/policy/noipvx.test.integr/deckard.yaml b/modules/policy/noipvx.test.integr/deckard.yaml new file mode 100644 index 0000000..8178759 --- /dev/null +++ b/modules/policy/noipvx.test.integr/deckard.yaml @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +programs: +- name: kresd + binary: kresd + additional: + - --noninteractive + templates: + - modules/policy/noipvx.test.integr/kresd_config.j2 + - tests/integration/hints_zone.j2 + configs: + - config + - hints diff --git a/modules/policy/noipvx.test.integr/kresd_config.j2 b/modules/policy/noipvx.test.integr/kresd_config.j2 new file mode 100644 index 0000000..87873e8 --- /dev/null +++ b/modules/policy/noipvx.test.integr/kresd_config.j2 @@ -0,0 +1,60 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +{% raw %} +net.ipv4 = false +net.ipv6 = false +policy.add(policy.all(policy.STUB({ '::1:2:3:4', '1.2.3.4' }))) + +-- make sure DNSSEC is turned off for tests +trust_anchors.remove('.') + +-- Disable RFC5011 TA update +if ta_update then + modules.unload('ta_update') +end + +-- Disable RFC8145 signaling, scenario doesn't provide expected answers +if ta_signal_query then + modules.unload('ta_signal_query') +end + +-- Disable RFC8109 priming, scenario doesn't provide expected answers +if priming then + modules.unload('priming') +end + +-- Disable this module because it make one priming query +if detect_time_skew then + modules.unload('detect_time_skew') +end + +_hint_root_file('hints') +cache.size = 2*MB +log_level('debug') +{% endraw 
%} + +net = { '{{SELF_ADDR}}' } + + +{% if QMIN == "false" %} +option('NO_MINIMIZE', true) +{% else %} +option('NO_MINIMIZE', false) +{% endif %} + + +-- Self-checks on globals +assert(help() ~= nil) +assert(worker.id ~= nil) +-- Self-checks on facilities +assert(cache.count() == 0) +assert(cache.stats() ~= nil) +assert(cache.backends() ~= nil) +assert(worker.stats() ~= nil) +assert(net.interfaces() ~= nil) +-- Self-checks on loaded stuff +assert(net.list()[1].transport.ip == '{{SELF_ADDR}}') +assert(#modules.list() > 0) +-- Self-check timers +ev = event.recurrent(1 * sec, function (ev) return 1 end) +event.cancel(ev) +ev = event.after(0, function (ev) return 1 end) diff --git a/modules/policy/policy.lua b/modules/policy/policy.lua new file mode 100644 index 0000000..47e436f --- /dev/null +++ b/modules/policy/policy.lua @@ -0,0 +1,1109 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +local kres = require('kres') +local ffi = require('ffi') + +local LOG_GRP_POLICY_TAG = ffi.string(ffi.C.kr_log_grp2name(ffi.C.LOG_GRP_POLICY)) +local LOG_GRP_REQDBG_TAG = ffi.string(ffi.C.kr_log_grp2name(ffi.C.LOG_GRP_REQDBG)) + +local todname = kres.str2dname -- not available during module load otherwise + +-- Counter of unique rules +local nextid = 0 +local function getruleid() + local newid = nextid + nextid = nextid + 1 + return newid +end + +-- Support for client sockets from inside policy actions +local socket_client = function () + return error("missing lua-cqueues library, can't create socket client") +end +local has_socket, socket = pcall(require, 'cqueues.socket') +if has_socket then + socket_client = function (host, port) + local s, err, status + + s = socket.connect({ host = host, port = port, type = socket.SOCK_DGRAM }) + s:setmode('bn', 'bn') + status, err = pcall(s.connect, s) + + if not status then + return status, err + end + return s + end +end + +-- Split address and port from a combined string. 
+local function addr_split_port(target, default_port) + assert(default_port and type(default_port) == 'number') + local port = ffi.new('uint16_t[1]', default_port) + local addr = ffi.new('char[47]') -- INET6_ADDRSTRLEN + 1 + local ret = ffi.C.kr_straddr_split(target, addr, port) + if ret ~= 0 then + error('failed to parse address ' .. target) + end + return addr, tonumber(port[0]) +end + +-- String address@port -> sockaddr. +local function addr2sock(target, default_port) + local addr, port = addr_split_port(target, default_port) + local sock = ffi.gc(ffi.C.kr_straddr_socket(addr, port, nil), ffi.C.free); + if sock == nil then + error("target '"..target.."' is not a valid IP address") + end + return sock +end + +-- Debug logging for taken policy actions +local function log_policy_action(req, name) + if ffi.C.kr_log_is_debug_fun(ffi.C.LOG_GRP_POLICY, req) then + local qry = req:current() + ffi.C.kr_log_req1( + req, qry.uid, 2, ffi.C.LOG_GRP_POLICY, LOG_GRP_POLICY_TAG, + "%s applied for %s %s\n", + name, kres.dname2str(qry.sname), kres.tostring.type[qry.stype]) + end +end + +-- policy functions are defined below +local policy = {} + +function policy.PASS(state, _) + return state +end + +-- Mirror request elsewhere, and continue solving +function policy.MIRROR(target) + local addr, port = addr_split_port(target, 53) + local sink, err = socket_client(ffi.string(addr), port) + if not sink then panic('MIRROR target %s is not valid: %s', target, err) end + return function(state, req) + if state == kres.FAIL then return state end + local query = req.qsource.packet + if query ~= nil then + sink:send(ffi.string(query.wire, query.size), 1, tonumber(query.size)) + end + return -- Chain action to next + end +end + +-- Override the list of nameservers (forwarders) +local function set_nslist(req, list) + local ns_i = 0 + for _, ns in ipairs(list) do + if ffi.C.kr_forward_add_target(req, ns) == 0 then + ns_i = ns_i + 1 + end + end + if ns_i == 0 then + -- would use assert() but
don't want to compose the message if not triggered + error('no usable address in NS set (check net.ipv4 and ' + .. 'net.ipv6 config):\n' .. table_print(list, 2)) + end +end + +-- Forward request, and solve as stub query +function policy.STUB(target) + local list = {} + if type(target) == 'table' then + for _, v in pairs(target) do + table.insert(list, addr2sock(v, 53)) + end + else + table.insert(list, addr2sock(target, 53)) + end + return function(state, req) + local qry = req:current() + -- Switch mode to stub resolver, do not track origin zone cut since it's not real authority NS + qry.flags.STUB = true + qry.flags.ALWAYS_CUT = false + set_nslist(req, list) + return state + end +end + +-- Forward request and all subrequests to upstream; validate answers +function policy.FORWARD(target) + local list = {} + if type(target) == 'table' then + for _, v in pairs(target) do + table.insert(list, addr2sock(v, 53)) + end + else + table.insert(list, addr2sock(target, 53)) + end + return function(state, req) + local qry = req:current() + req.options.FORWARD = true + req.options.NO_MINIMIZE = true + qry.flags.FORWARD = true + qry.flags.ALWAYS_CUT = false + qry.flags.NO_MINIMIZE = true + qry.flags.AWAIT_CUT = true + set_nslist(req, list) + return state + end +end + +-- Forward request and all subrequests to upstream over TLS; validate answers +function policy.TLS_FORWARD(targets) + if type(targets) ~= 'table' or #targets < 1 then + error('TLS_FORWARD argument must be a non-empty table') + end + + local sockaddr_c_set = {} + local nslist = {} -- to persist in closure of the returned function + for idx, target in pairs(targets) do + if type(target) ~= 'table' or type(target[1]) ~= 'string' then + error(string.format('TLS_FORWARD configuration at position ' .. + '%d must be a table starting with an IP address', idx)) + end + -- Note: some functions have checks with error() calls inside. 
+ local sockaddr_c = addr2sock(target[1], 853) + + -- Refuse repeated addresses in the same set. + local sockaddr_lua = ffi.string(sockaddr_c, ffi.C.kr_sockaddr_len(sockaddr_c)) + if sockaddr_c_set[sockaddr_lua] then + error('TLS_FORWARD configuration cannot declare two configs for IP address ' + .. target[1]) + else + sockaddr_c_set[sockaddr_lua] = true; + end + + table.insert(nslist, sockaddr_c) + net.tls_client(target) + end + + return function(state, req) + local qry = req:current() + req.options.FORWARD = true + req.options.NO_MINIMIZE = true + qry.flags.FORWARD = true + qry.flags.ALWAYS_CUT = false + qry.flags.NO_MINIMIZE = true + qry.flags.AWAIT_CUT = true + req.options.TCP = true + qry.flags.TCP = true + set_nslist(req, nslist) + return state + end +end + +-- Rewrite records in packet +function policy.REROUTE(tbl, names) + -- Import renumbering rules + local ren = require('kres_modules.renumber') + local prefixes = {} + for from, to in pairs(tbl) do + local prefix = names and ren.name(from, to) or ren.prefix(from, to) + table.insert(prefixes, prefix) + end + -- Return rule closure + return ren.rule(prefixes) +end + +-- Set and clear some query flags +function policy.FLAGS(opts_set, opts_clear) + return function(_, req) + -- We assume to be running in the begin phase, so to truly apply + -- to the whole request we need to change both kr_request and kr_query. + local qry = req:current() + for _, flags in pairs({qry.flags, req.options}) do + ffi.C.kr_qflags_set (flags, kres.mk_qflags(opts_set or {})) + ffi.C.kr_qflags_clear(flags, kres.mk_qflags(opts_clear or {})) + end + return nil -- chain rule + end +end + +local function mkauth_soa(answer, dname, mname, ttl) + if mname == nil then + mname = dname + end + return answer:put(dname, ttl or 10800, answer:qclass(), kres.type.SOA, + mname .. 
'\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48') +end + +-- Create answer with passed arguments +function policy.ANSWER(rtable, nodata) + return function(_, req) + local qry = req:current() + local data = rtable[qry.stype] + if data == nil and nodata ~= true then + return nil + end + -- now we're certain we want to generate an answer + + local answer = req:ensure_answer() + if answer == nil then return nil end + ffi.C.kr_pkt_make_auth_header(answer) + local ttl = (data or {}).ttl or 1 + answer:rcode(kres.rcode.NOERROR) + req:set_extended_error(kres.extended_error.FORGED, "5DO5") + + if data == nil then -- want NODATA, i.e. just a SOA + answer:begin(kres.section.AUTHORITY) + local soa = rtable[kres.type.SOA] + if soa ~= nil then + answer:put(qry.sname, soa.ttl or ttl, qry.sclass, kres.type.SOA, + soa.rdata[1] or soa.rdata) + else + mkauth_soa(answer, kres.dname2wire(qry.sname), nil, ttl) + end + log_policy_action(req, 'ANSWER (nodata)') + else + answer:begin(kres.section.ANSWER) + if type(data.rdata) == 'table' then + for _, entry in ipairs(data.rdata) do + answer:put(qry.sname, ttl, qry.sclass, qry.stype, entry) + end + else + answer:put(qry.sname, ttl, qry.sclass, qry.stype, data.rdata) + end + log_policy_action(req, 'ANSWER (forged)') + end + return kres.DONE + end +end + +local dname_localhost = todname('localhost.') + +-- Rule for localhost. zone; see RFC6303, sec. 
3 +local function localhost(_, req) + local qry = req:current() + local answer = req:ensure_answer() + if answer == nil then return nil end + ffi.C.kr_pkt_make_auth_header(answer) + + local is_exact = ffi.C.knot_dname_is_equal(qry.sname, dname_localhost) + + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + if qry.stype == kres.type.AAAA then + answer:put(qry.sname, 900, answer:qclass(), kres.type.AAAA, + '\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1') + elseif qry.stype == kres.type.A then + answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\127\0\0\1') + elseif is_exact and qry.stype == kres.type.SOA then + mkauth_soa(answer, dname_localhost) + elseif is_exact and qry.stype == kres.type.NS then + answer:put(dname_localhost, 900, answer:qclass(), kres.type.NS, dname_localhost) + else + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, dname_localhost) + end + return kres.DONE +end + +local dname_rev4_localhost = todname('1.0.0.127.in-addr.arpa'); +local dname_rev4_localhost_apex = todname('127.in-addr.arpa'); + +-- Rule for reverse localhost. +-- Answer with locally served minimal 127.in-addr.arpa domain, only having +-- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone. +-- TODO: much of this would better be left to the hints module (or coordinated). +local function localhost_reversed(_, req) + local qry = req:current() + local answer = req:ensure_answer() + if answer == nil then return nil end + + -- classify qry.sname: + local is_exact -- exact dname for localhost + local is_apex -- apex of a locally-served localhost zone + local is_nonterm -- empty non-terminal name + if ffi.C.knot_dname_in_bailiwick(qry.sname, todname('ip6.arpa.')) > 0 then + -- exact ::1 query (relying on the calling rule) + is_exact = true + is_apex = true + else + -- within 127.in-addr.arpa. 
+ local labels = ffi.C.knot_dname_labels(qry.sname, nil) + if labels == 3 then + is_exact = false + is_apex = true + elseif labels == 4+2 and ffi.C.knot_dname_is_equal( + qry.sname, dname_rev4_localhost) then + is_exact = true + else + is_exact = false + is_apex = false + is_nonterm = ffi.C.knot_dname_in_bailiwick(dname_rev4_localhost, qry.sname) > 0 + end + end + + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + if is_exact and qry.stype == kres.type.PTR then + answer:put(qry.sname, 900, answer:qclass(), kres.type.PTR, dname_localhost) + elseif is_apex and qry.stype == kres.type.SOA then + mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost) + elseif is_apex and qry.stype == kres.type.NS then + answer:put(dname_rev4_localhost_apex, 900, answer:qclass(), kres.type.NS, + dname_localhost) + else + if not is_nonterm then + answer:rcode(kres.rcode.NXDOMAIN) + end + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost) + end + return kres.DONE +end + +-- All requests +function policy.all(action) + return function(_, _) return action end +end + +-- Requests whose QNAME is exactly the provided domain +function policy.domains(action, dname_list) + return function(_, query) + local qname = query:name() + for _, dname in ipairs(dname_list) do + if ffi.C.knot_dname_is_equal(qname, dname) then + return action + end + end + return nil + end +end + +-- Requests whose QNAME matches given zone list (i.e. 
suffix match) +function policy.suffix(action, zone_list) + local AC = require('ahocorasick') + local tree = AC.create(zone_list) + return function(_, query) + local match = AC.match(tree, query:name(), false) + if match ~= nil then + return action + end + return nil + end +end + +-- Check for common suffix first, then suffix match (specialized version of suffix match) +function policy.suffix_common(action, suffix_list, common_suffix) + local common_len = string.len(common_suffix) + local suffix_count = #suffix_list + return function(_, query) + -- Preliminary check + local qname = query:name() + if not string.find(qname, common_suffix, -common_len, true) then + return nil + end + -- String match + for i = 1, suffix_count do + local zone = suffix_list[i] + if string.find(qname, zone, -string.len(zone), true) then + return action + end + end + return nil + end +end + +-- Filter QNAME pattern +function policy.pattern(action, pattern) + return function(_, query) + if string.find(query:name(), pattern) then + return action + end + return nil + end +end + +local function rpz_parse(action, path) + local rules = {} + local new_actions = {} + local action_map = { + -- RPZ Policy Actions + ['\0'] = action, + ['\1*\0'] = policy.ANSWER({}, true), + ['\012rpz-passthru\0'] = policy.PASS, -- the grammar... + ['\008rpz-drop\0'] = policy.DROP, + ['\012rpz-tcp-only\0'] = policy.TC, + -- Policy triggers @NYI@ + } + -- RR types to be skipped; boolean denoting whether to throw a warning even for RPZ apex. + local rrtype_bad = { + [kres.type.DNAME] = true, + [kres.type.NS] = false, + [kres.type.DNSKEY] = true, + [kres.type.DS] = true, + [kres.type.RRSIG] = true, + [kres.type.NSEC] = true, + [kres.type.NSEC3] = true, + } + + -- We generally don't know what zone should be in the file; we try to detect it. + -- Fortunately, it's typical that SOA is the first record, even required for AXFR. 
+ local origin_soa = nil + local warned_soa, warned_bailiwick + + local parser = require('zonefile').new() + local ok, errstr = parser:open(path) + if not ok then + error(string.format('failed to parse "%s": %s', path, errstr or "unknown error")) + end + while true do + ok, errstr = parser:parse() + if errstr then + log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d: %s', + path, tonumber(parser.line_counter), errstr) + end + if not ok then break end + + local full_name = ffi.gc(ffi.C.knot_dname_copy(parser.r_owner, nil), ffi.C.free) + local rdata = ffi.string(parser.r_data, parser.r_data_length) + ffi.C.knot_dname_to_lower(full_name) + + local origin = origin_soa or parser.zone_origin + local prefix_labels = ffi.C.knot_dname_in_bailiwick(full_name, origin) + if prefix_labels < 0 then + if not warned_bailiwick then + warned_bailiwick = true + log_warn(ffi.C.LOG_GRP_POLICY, + 'RPZ %s:%d: RR owner "%s" outside the zone (ignored; reported once per file)', + path, tonumber(parser.line_counter), kres.dname2str(full_name)) + end + goto continue + end + + local bytes = ffi.C.knot_dname_size(full_name) - ffi.C.knot_dname_size(origin) + local name = ffi.string(full_name, bytes) .. 
'\0' + + if parser.r_type == kres.type.CNAME then + if action_map[rdata] then + rules[name] = action_map[rdata] + else + log_warn(ffi.C.LOG_GRP_POLICY, + 'RPZ %s:%d: CNAME with custom target in RPZ is not supported yet (ignored)', + path, tonumber(parser.line_counter)) + end + else + if #name then + local is_bad = rrtype_bad[parser.r_type] + + if parser.r_type == kres.type.SOA then + if origin_soa == nil then + origin_soa = ffi.gc(ffi.C.knot_dname_copy(parser.r_owner, nil), ffi.C.free) + goto continue -- we don't want to modify `new_actions` + else + is_bad = true -- maybe provide more info, but it seems rare + end + elseif origin_soa == nil and not warned_soa then + warned_soa = true + log_warn(ffi.C.LOG_GRP_POLICY, + 'RPZ %s:%d warning: SOA missing as the first record', + path, tonumber(parser.line_counter)) + end + + if is_bad == true or (is_bad == false and prefix_labels ~= 0) then + log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d warning: RR type %s is not allowed in RPZ (ignored)', + path, tonumber(parser.line_counter), kres.tostring.type[parser.r_type]) + elseif is_bad == nil then + if new_actions[name] == nil then new_actions[name] = {} end + local act = new_actions[name][parser.r_type] + if act == nil then + new_actions[name][parser.r_type] = { ttl=parser.r_ttl, rdata=rdata } + else -- multiple RRs: no reordering or deduplication + if type(act.rdata) ~= 'table' then + act.rdata = { act.rdata } + end + table.insert(act.rdata, rdata) + if parser.r_ttl ~= act.ttl then -- be conservative + log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d warning: different TTLs in a set (minimum taken)', + path, tonumber(parser.line_counter)) + act.ttl = math.min(act.ttl, parser.r_ttl) + end + end + else + assert(is_bad == false and prefix_labels == 0) + end + end + end + + ::continue:: + end + collectgarbage() + for qname, rrsets in pairs(new_actions) do + rules[qname] = policy.ANSWER(rrsets, true) + end + return rules +end + +-- Split path into dirname and basename (like the shell 
utilities) +local function get_dir_and_file(path) + local dir, file = string.match(path, "(.*)/([^/]+)") + + -- If regex doesn't match then path must be the file directly (i.e. doesn't contain '/') + -- This assumes that the file exists (rpz_parse() would fail if it doesn't) + if not dir and not file then + dir = '.' + file = path + end + + return dir, file +end + +-- RPZ policy set +-- Create RPZ from zone file and optionally watch the file for changes +function policy.rpz(action, path, watch) + local rules = rpz_parse(action, path) + + if watch ~= false then + local has_notify, notify = pcall(require, 'cqueues.notify') + if has_notify then + local bit = require('bit') + + local dir, file = get_dir_and_file(path) + local watcher = notify.opendir(dir) + watcher:add(file, bit.bxor(notify.CREATE, notify.MODIFY)) + + worker.coroutine(function () + for _, name in watcher:changes() do + -- Limit to changes on file we're interested in + -- Watcher will also fire for changes to the directory itself + if name == file then + -- If the file changes then reparse and replace the existing ruleset + log_info(ffi.C.LOG_GRP_POLICY, 'RPZ reloading: ' .. name) + rules = rpz_parse(action, path) + end + end + end) + elseif watch then -- explicitly requested and failed + error('[poli] lua-cqueues required to watch and reload RPZ file') + else + log_info(ffi.C.LOG_GRP_POLICY, 'lua-cqueues required to watch and reload RPZ file, continuing without watching') + end + end + + return function(_, query) + local label = query:name() + local rule = rules[label] + while rule == nil and string.len(label) > 0 do + label = string.sub(label, string.byte(label) + 2) + rule = rules['\1*'..label] + end + return rule + end +end + +-- Apply an action when query belongs to a slice (determined by slice_func()) +function policy.slice(slice_func, ...) 
+ local actions = {...} + if #actions <= 0 then + error('[poli] at least one action must be provided to policy.slice()') + end + + return function(_, query) + local index = slice_func(query, #actions) + return actions[index] + end +end + +-- Initializes slicing function that randomly assigns queries to a slice based on their registrable domain +function policy.slice_randomize_psl(seed) + local has_psl, psl_lib = pcall(require, 'psl') + if not has_psl then + error('[poli] lua-psl is required for policy.slice_randomize_psl()') + end + -- load psl + local has_latest, psl = pcall(psl_lib.latest) + if not has_latest then -- compatibility with lua-psl < 0.15 + psl = psl_lib.builtin() + end + + if seed == nil then + seed = os.time() / (3600 * 24 * 7) + end + seed = math.floor(seed) -- convert to int + + return function(query, length) + assert(length > 0) + + local domain = kres.dname2str(query:name()) + if domain == nil then -- invalid data: auto-select first action + return 1 + end + if domain:len() > 1 then --remove trailing dot + domain = domain:sub(0, -2) + end + + -- do psl lookup for registrable domain + local reg_domain = psl:registrable_domain(domain) + if reg_domain == nil then -- fallback to unreg. domain + reg_domain = psl:unregistrable_domain(domain) + if reg_domain == nil then -- shouldn't happen: safe fallback + return 1 + end + end + + local rand_seed = seed + -- create deterministic seed for pseudo-random slice assignment + for i = 1, #reg_domain do + rand_seed = rand_seed + reg_domain:byte(i) + end + + -- use linear congruential generator with values from ANSI C + rand_seed = rand_seed % 0x80000000 -- ensure seed is positive 32b int + local rand = (1103515245 * rand_seed + 12345) % 0x10000 + return 1 + rand % length + end +end + +-- Prepare for making an answer from scratch. (Return the packet for convenience.) 
+local function answer_clear(req) + -- If we're in postrules, previous resolving might have chosen some RRs + -- for inclusion in the answer, so we need to avoid those. + -- *_selected arrays are in mempool, so explicit deallocation is not necessary. + req.answ_selected.len = 0 + req.auth_selected.len = 0 + req.add_selected.len = 0 + + -- Let's be defensive and clear the answer, too. + local pkt = req:ensure_answer() + if pkt == nil then return nil end + pkt:clear_payload() + req:ensure_edns() + return pkt +end + +function policy.DENY_MSG(msg, extended_error) + if msg and (type(msg) ~= 'string' or #msg >= 255) then + error('DENY_MSG: optional msg must be string shorter than 256 characters') + end + if extended_error == nil then + extended_error = kres.extended_error.BLOCKED + end + local action_name = msg and 'DENY_MSG' or 'DENY' + + return function (_, req) + -- Write authority information + local answer = answer_clear(req) + if answer == nil then return nil end + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NXDOMAIN) + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, answer:qname()) + if msg then + answer:begin(kres.section.ADDITIONAL) + answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT, + string.char(#msg) .. 
msg) + + end + req:set_extended_error(extended_error, "CR36") + log_policy_action(req, action_name) + return kres.DONE + end +end + +local function free_cb(func) + func:free() +end + +local debug_logline_cb = ffi.cast('trace_log_f', function (_, msg) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + ffi.C.kr_log_fmt( + ffi.C.LOG_GRP_REQDBG, -- but the original [group] tag also remains in the string + LOG_DEBUG, + 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=policy.DEBUG_ALWAYS', -- no meaningful locations + '[%-6s]%s', LOG_GRP_REQDBG_TAG, msg) -- msg should end with newline already +end) +ffi.gc(debug_logline_cb, free_cb) + +-- LOG_DEBUG without log_trace and without code locations +local function log_notrace(req, fmt, ...) + ffi.C.kr_log_fmt(ffi.C.LOG_GRP_REQDBG, LOG_DEBUG, + 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=', -- no meaningful locations + '%s', string.format( -- convert in lua, as integers in C varargs would pass as double + '[%-6s][%-6s][%05u.00] ' .. fmt, + LOG_GRP_REQDBG_TAG, LOG_GRP_POLICY_TAG, req.uid, ...) + ) +end + +local debug_logfinish_cb = ffi.cast('trace_callback_f', function (req) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + log_notrace(req, 'following rrsets were marked as interesting:\n%s\n', + req:selected_tostring()) + if req.answer ~= nil then + log_notrace(req, 'answer packet:\n%s\n', req.answer) + else + log_notrace(req, 'answer packet DROPPED\n') + end +end) +ffi.gc(debug_logfinish_cb, free_cb) + +-- log request packet +function policy.REQTRACE(_, req) + log_notrace(req, 'request packet:\n%s', req.qsource.packet) +end + +-- log how the request arrived, notably the client's IP +function policy.IPTRACE(_, req) + if req.qsource.addr == nil then + log_notrace(req, 'request packet arrived internally\n') + else + -- stringify transport flags: struct kr_request_qsource_flags + local qf = req.qsource.flags + local qf_str = qf.tcp and 'TCP' or 'UDP' + if qf.tls then qf_str = qf_str .. 
' + TLS' end + if qf.http then qf_str = qf_str .. ' + HTTP' end + if qf.xdp then qf_str = qf_str .. ' + XDP' end + + log_notrace(req, 'request packet arrived from %s to %s (%s)\n', + req.qsource.addr, req.qsource.dst_addr, qf_str) + end + return nil -- chain rule +end + +function policy.DEBUG_ALWAYS(state, req) + policy.QTRACE(state, req) + req:trace_chain_callbacks(debug_logline_cb, debug_logfinish_cb) + policy.REQTRACE(state, req) +end + +local debug_stashlog_cb = ffi.cast('trace_log_f', function (req, msg) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + + -- stash messages for conditional logging in trace_finish + local stash = req:vars()['policy_debug_stash'] + table.insert(stash, ffi.string(msg)) +end) +ffi.gc(debug_stashlog_cb, free_cb) + +-- buffer debug logs and print them only if test() returns a truthy value +function policy.DEBUG_IF(test) + local debug_finish_cb = ffi.cast('trace_callback_f', function (cbreq) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + if test(cbreq) then + policy.REQTRACE(nil, cbreq) + debug_logfinish_cb(cbreq) -- unconditional version + + local stash = cbreq:vars()['policy_debug_stash'] + for _, line in ipairs(stash) do -- don't want one huge entry + ffi.C.kr_log_fmt(ffi.C.LOG_GRP_REQDBG, LOG_DEBUG, + 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=', -- no meaningful locations + '[%-6s]%s', LOG_GRP_REQDBG_TAG, line) + end + end + end) + ffi.gc(debug_finish_cb, function (func) func:free() end) + + return function (state, req) + req:vars()['policy_debug_stash'] = {} + policy.QTRACE(state, req) + req:trace_chain_callbacks(debug_stashlog_cb, debug_finish_cb) + return + end +end + +policy.DEBUG_CACHE_MISS = policy.DEBUG_IF( + function(req) + return not req:all_from_cache() + end +) + +policy.DENY = policy.DENY_MSG() -- compatibility with < 2.0 + +function policy.DROP(_, req) + local answer = answer_clear(req) + if answer == nil then return nil end +
req:set_extended_error(kres.extended_error.PROHIBITED, "U5KL") + log_policy_action(req, 'DROP') + return kres.FAIL +end + +function policy.NO_ANSWER(_, req) + req.options.NO_ANSWER = true + log_policy_action(req, 'NO_ANSWER') + return kres.FAIL +end + +function policy.REFUSE(_, req) + local answer = answer_clear(req) + if answer == nil then return nil end + answer:rcode(kres.rcode.REFUSED) + answer:ad(false) + req:set_extended_error(kres.extended_error.PROHIBITED, "EIM4") + log_policy_action(req, 'REFUSE') + return kres.DONE +end + +function policy.TC(state, req) + -- Avoid non-UDP queries + if req.qsource.addr == nil or req.qsource.flags.tcp then + return state + end + + local answer = answer_clear(req) + if answer == nil then return nil end + answer:tc(1) + answer:ad(false) + log_policy_action(req, 'TC') + return kres.DONE +end + +function policy.QTRACE(_, req) + local qry = req:current() + req.options.TRACE = true + qry.flags.TRACE = true + return -- this allows to continue iterating over policy list +end + +-- Evaluate packet in given rules to determine policy action +function policy.evaluate(rules, req, query, state) + for i = 1, #rules do + local rule = rules[i] + if not rule.suspended then + local action = rule.cb(req, query) + if action ~= nil then + rule.count = rule.count + 1 + local next_state = action(state, req) + if next_state then -- Not a chain rule, + return next_state -- stop on first match + end + end + end + end + return +end + +-- Add rule to policy list +function policy.add(rule, postrule) + -- Compatibility with 1.0.0 API + -- it will be dropped in 1.2.0 + if rule == policy then + rule = postrule + postrule = nil + end + -- End of compatibility shim + local desc = {id=getruleid(), cb=rule, count=0} + table.insert(postrule and policy.postrules or policy.rules, desc) + return desc +end + +-- Remove rule from a list +local function delrule(rules, id) + for i, r in ipairs(rules) do + if r.id == id then + table.remove(rules, i) + return true + end 
+ end + return false +end + +-- Delete rule from policy list +function policy.del(id) + if not delrule(policy.rules, id) then + if not delrule(policy.postrules, id) then + return false + end + end + return true +end + +-- Convert list of string names to domain names +function policy.todnames(names) + for i, v in ipairs(names) do + names[i] = kres.str2dname(v) + end + return names +end + +-- RFC1918 Private, local, broadcast, test and special zones +-- Considerations: RFC6761, sec 6.1. +-- https://www.iana.org/assignments/locally-served-dns-zones +local private_zones = { + -- RFC6303 + '10.in-addr.arpa.', + '16.172.in-addr.arpa.', + '17.172.in-addr.arpa.', + '18.172.in-addr.arpa.', + '19.172.in-addr.arpa.', + '20.172.in-addr.arpa.', + '21.172.in-addr.arpa.', + '22.172.in-addr.arpa.', + '23.172.in-addr.arpa.', + '24.172.in-addr.arpa.', + '25.172.in-addr.arpa.', + '26.172.in-addr.arpa.', + '27.172.in-addr.arpa.', + '28.172.in-addr.arpa.', + '29.172.in-addr.arpa.', + '30.172.in-addr.arpa.', + '31.172.in-addr.arpa.', + '168.192.in-addr.arpa.', + '0.in-addr.arpa.', + '254.169.in-addr.arpa.', + '2.0.192.in-addr.arpa.', + '100.51.198.in-addr.arpa.', + '113.0.203.in-addr.arpa.', + '255.255.255.255.in-addr.arpa.', + -- RFC7793 + '64.100.in-addr.arpa.', + '65.100.in-addr.arpa.', + '66.100.in-addr.arpa.', + '67.100.in-addr.arpa.', + '68.100.in-addr.arpa.', + '69.100.in-addr.arpa.', + '70.100.in-addr.arpa.', + '71.100.in-addr.arpa.', + '72.100.in-addr.arpa.', + '73.100.in-addr.arpa.', + '74.100.in-addr.arpa.', + '75.100.in-addr.arpa.', + '76.100.in-addr.arpa.', + '77.100.in-addr.arpa.', + '78.100.in-addr.arpa.', + '79.100.in-addr.arpa.', + '80.100.in-addr.arpa.', + '81.100.in-addr.arpa.', + '82.100.in-addr.arpa.', + '83.100.in-addr.arpa.', + '84.100.in-addr.arpa.', + '85.100.in-addr.arpa.', + '86.100.in-addr.arpa.', + '87.100.in-addr.arpa.', + '88.100.in-addr.arpa.', + '89.100.in-addr.arpa.', + '90.100.in-addr.arpa.', + '91.100.in-addr.arpa.', + '92.100.in-addr.arpa.', + 
'93.100.in-addr.arpa.', + '94.100.in-addr.arpa.', + '95.100.in-addr.arpa.', + '96.100.in-addr.arpa.', + '97.100.in-addr.arpa.', + '98.100.in-addr.arpa.', + '99.100.in-addr.arpa.', + '100.100.in-addr.arpa.', + '101.100.in-addr.arpa.', + '102.100.in-addr.arpa.', + '103.100.in-addr.arpa.', + '104.100.in-addr.arpa.', + '105.100.in-addr.arpa.', + '106.100.in-addr.arpa.', + '107.100.in-addr.arpa.', + '108.100.in-addr.arpa.', + '109.100.in-addr.arpa.', + '110.100.in-addr.arpa.', + '111.100.in-addr.arpa.', + '112.100.in-addr.arpa.', + '113.100.in-addr.arpa.', + '114.100.in-addr.arpa.', + '115.100.in-addr.arpa.', + '116.100.in-addr.arpa.', + '117.100.in-addr.arpa.', + '118.100.in-addr.arpa.', + '119.100.in-addr.arpa.', + '120.100.in-addr.arpa.', + '121.100.in-addr.arpa.', + '122.100.in-addr.arpa.', + '123.100.in-addr.arpa.', + '124.100.in-addr.arpa.', + '125.100.in-addr.arpa.', + '126.100.in-addr.arpa.', + '127.100.in-addr.arpa.', + + -- RFC6303 + -- localhost_reversed handles ::1 + '0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.', + 'd.f.ip6.arpa.', + '8.e.f.ip6.arpa.', + '9.e.f.ip6.arpa.', + 'a.e.f.ip6.arpa.', + 'b.e.f.ip6.arpa.', + '8.b.d.0.1.0.0.2.ip6.arpa.', + -- RFC8375 + 'home.arpa.', +} +policy.todnames(private_zones) + +-- @var Default rules +policy.rules = {} +policy.postrules = {} +policy.special_names = { + -- XXX: beware of special_names_optim() when modifying these filters + { + cb=policy.suffix_common(policy.DENY_MSG( + 'Blocking is mandated by standards, see references on ' + .. 'https://www.iana.org/assignments/' + .. 'locally-served-dns-zones/locally-served-dns-zones.xhtml', + kres.extended_error.NOTSUP), + private_zones, todname('arpa.')), + count=0 + }, + { + cb=policy.suffix(policy.DENY_MSG( + 'Blocking is mandated by standards, see references on ' + .. 'https://www.iana.org/assignments/' + .. 
'special-use-domain-names/special-use-domain-names.xhtml', + kres.extended_error.NOTSUP), + { + todname('test.'), + todname('onion.'), + todname('invalid.'), + todname('local.'), -- RFC 8375.4 + }), + count=0 + }, + { + cb=policy.suffix(localhost, {dname_localhost}), + count=0 + }, + { + cb=policy.suffix_common(localhost_reversed, { + todname('127.in-addr.arpa.'), + todname('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.')}, + todname('arpa.')), + count=0 + }, +} + +-- Return boolean; false = no special name may apply, true = some might apply. +-- The point is to *efficiently* filter almost all QNAMEs that do not apply. +local function special_names_optim(req, sname) + local qname_size = req.qsource.packet.qname_size + if qname_size < 9 then return true end -- don't want to special-case bad array access + local root = sname + qname_size - 1 + return + -- .a???. or .t???. + (root[-5] == 4 and (root[-4] == 97 or root[-4] == 116)) + -- .on???. or .in?????. or lo???. or *ost. + or (root[-6] == 5 and root[-5] == 111 and root[-4] == 110) + or (root[-8] == 7 and root[-7] == 105 and root[-6] == 110) + or (root[-6] == 5 and root[-5] == 108 and root[-4] == 111) + or (root[-3] == 111 and root[-2] == 115 and root[-1] == 116) +end + +-- Top-down policy list walk until we hit a match +-- the caller is responsible for reordering policy list +-- from most specific to least specific. +-- Some rules may be chained, in this case they are evaluated +-- as a dependency chain, e.g. r1,r2,r3 -> r3(r2(r1(state))) +policy.layer = { + begin = function(state, req) + -- Don't act on "finished" cases. 
+ if bit.band(state, bit.bor(kres.FAIL, kres.DONE)) ~= 0 then return state end + local qry = req:initial() -- same as :current() but more descriptive + return policy.evaluate(policy.rules, req, qry, state) + or (special_names_optim(req, qry.sname) + and policy.evaluate(policy.special_names, req, qry, state)) + or state + end, + finish = function(state, req) + -- Optimization for the typical case + if #policy.postrules == 0 then return state end + -- Don't act on failed cases. + if bit.band(state, kres.FAIL) ~= 0 then return state end + return policy.evaluate(policy.postrules, req, req:initial(), state) or state + end +} + +return policy diff --git a/modules/policy/policy.rpz.test.lua b/modules/policy/policy.rpz.test.lua new file mode 100644 index 0000000..94fb9ce --- /dev/null +++ b/modules/policy/policy.rpz.test.lua @@ -0,0 +1,65 @@ + +local function prepare_cache() + cache.open(100*MB) + cache.clear() + + local ffi = require('ffi') + local c = kres.context().cache + + local passthru_addr = '\127\0\0\9' + rr_passthru = kres.rrset(todname('rpzpassthru.'), kres.type.A, kres.class.IN, 2147483647) + assert(rr_passthru:add_rdata(passthru_addr, #passthru_addr)) + assert(c:insert(rr_passthru, nil, ffi.C.KR_RANK_SECURE + ffi.C.KR_RANK_AUTH)) + + c:commit() +end + +local check_answer = require('test_utils').check_answer + +local function test_rpz() + check_answer('"CNAME ." return NXDOMAIN', + 'nxdomain.', kres.type.A, kres.rcode.NXDOMAIN) + check_answer('"CNAME *." return NODATA', + 'nodata.', kres.type.A, kres.rcode.NOERROR, {}) + check_answer('"CNAME *. on wildcard" return NODATA', + 'nodata.nxdomain.', kres.type.A, kres.rcode.NOERROR, {}) + check_answer('"CNAME rpz-drop." 
be dropped', + 'rpzdrop.', kres.type.A, kres.rcode.SERVFAIL) + check_answer('"CNAME rpz-passthru" return A rrset', + 'rpzpassthru.', kres.type.A, kres.rcode.NOERROR, '127.0.0.9') + check_answer('"A 192.168.5.5" return local A rrset', + 'rra.', kres.type.A, kres.rcode.NOERROR, '192.168.5.5') + check_answer('"A 192.168.6.6" with suffixed zone name in owner return local A rrset', + 'rra-zonename-suffix.', kres.type.A, kres.rcode.NOERROR, '192.168.6.6') + check_answer('"A 192.168.7.7" with suffixed zone name in owner return local A rrset', + 'testdomain.rra.', kres.type.A, kres.rcode.NOERROR, '192.168.7.7') + check_answer('non existing AAAA on rra domain return NODATA', + 'rra.', kres.type.AAAA, kres.rcode.NOERROR, {}) + check_answer('"A 192.168.8.8" and domain with uppercase and lowercase letters', + 'case.sensitive.', kres.type.A, kres.rcode.NOERROR, '192.168.8.8') + check_answer('"A 192.168.8.8" and domain with uppercase and lowercase letters', + 'CASe.SENSItivE.', kres.type.A, kres.rcode.NOERROR, '192.168.8.8') + check_answer('two AAAA records', + 'two.records.', kres.type.AAAA, kres.rcode.NOERROR, + {'2001:db8::2', '2001:db8::1'}) +end + +local function test_rpz_soa() + check_answer('"CNAME ." return NXDOMAIN (SOA origin)', + 'nxdomain-fqdn.', kres.type.A, kres.rcode.NXDOMAIN) + check_answer('"CNAME *." 
return NODATA (SOA origin)', + 'nodata-fqdn.', kres.type.A, kres.rcode.NOERROR, {}) +end + +net.ipv4 = false +net.ipv6 = false + +prepare_cache() + +policy.add(policy.rpz(policy.DENY, 'policy.test.rpz')) +policy.add(policy.rpz(policy.DENY, 'policy.test.rpz.soa')) + +return { + test_rpz, + test_rpz_soa, +} diff --git a/modules/policy/policy.slice.test.lua b/modules/policy/policy.slice.test.lua new file mode 100644 index 0000000..89c1b05 --- /dev/null +++ b/modules/policy/policy.slice.test.lua @@ -0,0 +1,109 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +-- check lua-psl is available +local has_psl = pcall(require, 'psl') +if not has_psl then + os.exit(77) -- SKIP policy.slice +end + +-- unload modules which are not related to this test +if ta_update then + modules.unload('ta_update') +end +if ta_signal_query then + modules.unload('ta_signal_query') +end +if priming then + modules.unload('priming') +end +if detect_time_skew then + modules.unload('detect_time_skew') +end + +local kres = require('kres') + +local slice_queries = { + {}, + {}, + {}, +} + +local function sliceaction(index) + return function(_, req) + -- log query + local qry = req:current() + local name = kres.dname2str(qry:name()) + local count = slice_queries[index][name] + if not count then + count = 0 + end + slice_queries[index][name] = count + 1 + + -- refuse query + local answer = req:ensure_answer() + if answer == nil then return nil end + answer:rcode(kres.rcode.REFUSED) + answer:ad(false) + return kres.DONE + end +end + +-- configure slicing +policy.add(policy.slice( + policy.slice_randomize_psl(0), + sliceaction(1), + sliceaction(2), + sliceaction(3) +)) + +local function check_slice(desc, qname, qtype, expected_slice, expected_count) + callback = function() + count = slice_queries[expected_slice][qname] + qtype_str = kres.tostring.type[qtype] + same(count, expected_count, desc .. qname .. ' ' .. 
qtype_str) + end + resolve(qname, qtype, kres.class.IN, {}, callback) +end + +local function test_randomize_psl() + local desc = 'randomize_psl() same qname, different qtype (same slice): ' + check_slice(desc, 'example.com.', kres.type.A, 2, 1) + check_slice(desc, 'example.com.', kres.type.AAAA, 2, 2) + check_slice(desc, 'example.com.', kres.type.MX, 2, 3) + check_slice(desc, 'example.com.', kres.type.NS, 2, 4) + + desc = 'randomize_psl() subdomain in same slice: ' + check_slice(desc, 'a.example.com.', kres.type.A, 2, 1) + check_slice(desc, 'b.example.com.', kres.type.A, 2, 1) + check_slice(desc, 'c.example.com.', kres.type.A, 2, 1) + check_slice(desc, 'a.a.example.com.', kres.type.A, 2, 1) + check_slice(desc, 'a.a.a.example.com.', kres.type.A, 2, 1) + + desc = 'randomize_psl() different qnames in different slices: ' + check_slice(desc, 'example2.com.', kres.type.A, 1, 1) + check_slice(desc, 'example5.com.', kres.type.A, 3, 1) + + desc = 'randomize_psl() check unregistrable domains: ' + check_slice(desc, '.', kres.type.A, 3, 1) + check_slice(desc, 'com.', kres.type.A, 1, 1) + check_slice(desc, 'cz.', kres.type.A, 2, 1) + check_slice(desc, 'co.uk.', kres.type.A, 1, 1) + + desc = 'randomize_psl() check multi-level reg. domains: ' + check_slice(desc, 'example.co.uk.', kres.type.A, 3, 1) + check_slice(desc, 'a.example.co.uk.', kres.type.A, 3, 1) + check_slice(desc, 'b.example.co.uk.', kres.type.MX, 3, 1) + check_slice(desc, 'example2.co.uk.', kres.type.A, 2, 1) + + desc = 'randomize_psl() reg. 
domain - always ends up in slice: ' + check_slice(desc, 'fdsnnsdfvkdn.com.', kres.type.A, 3, 1) + check_slice(desc, 'bdfbd.cz.', kres.type.A, 1, 1) + check_slice(desc, 'nrojgvn.net.', kres.type.A, 1, 1) + check_slice(desc, 'jnojtnbv.engineer.', kres.type.A, 2, 1) + check_slice(desc, 'dfnjonfdsjg.gov.', kres.type.A, 1, 1) + check_slice(desc, 'okfjnosdfgjn.mil.', kres.type.A, 1, 1) + check_slice(desc, 'josdhnojn.test.', kres.type.A, 2, 1) +end + +return { + test_randomize_psl, +} diff --git a/modules/policy/policy.test.lua b/modules/policy/policy.test.lua new file mode 100644 index 0000000..69dda1f --- /dev/null +++ b/modules/policy/policy.test.lua @@ -0,0 +1,145 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +-- setup resolver +-- policy module should be loaded by default, do not load it explicitly + +-- do not attempt to contact outside world, operate only on cache +net.ipv4 = false +net.ipv6 = false +-- do not listen, test is driven by config code +env.KRESD_NO_LISTEN = true + +-- test for default configuration +local function test_tls_forward() + boom(policy.TLS_FORWARD, {}, 'TLS_FORWARD without arguments') + boom(policy.TLS_FORWARD, {'1'}, 'TLS_FORWARD with non-table argument') + boom(policy.TLS_FORWARD, {{}}, 'TLS_FORWARD with empty table') + boom(policy.TLS_FORWARD, {{{}}}, 'TLS_FORWARD with empty target table') + boom(policy.TLS_FORWARD, {{{bleble=''}}}, 'TLS_FORWARD with invalid parameters in table') + + boom(policy.TLS_FORWARD, {{'1'}}, 'TLS_FORWARD with invalid IP address') + boom(policy.TLS_FORWARD, {{{'::1', bleble=''}}}, 'TLS_FORWARD with valid IP and invalid parameters') + boom(policy.TLS_FORWARD, {{{'127.0.0.1'}}}, 'TLS_FORWARD with missing auth parameters') + + ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true}}), 'TLS_FORWARD with no authentication') + boom(policy.TLS_FORWARD, {{{'100:dead::', insecure=true}, + {'100:DEAD:0::', insecure=true} + }}, 'TLS_FORWARD with duplicate IP addresses is not allowed') + ok(policy.TLS_FORWARD({{'100:dead::2', 
insecure=true}, + {'100:dead::2@443', insecure=true} + }), 'TLS_FORWARD with duplicate IP addresses but different ports is allowed') + ok(policy.TLS_FORWARD({{'100:dead::3', insecure=true}, + {'100:beef::3', insecure=true} + }), 'TLS_FORWARD with different IPv6 addresses is allowed') + ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true}, + {'127.0.0.2', insecure=true} + }), 'TLS_FORWARD with different IPv4 addresses is allowed') + + boom(policy.TLS_FORWARD, {{{'::1', pin_sha256=''}}}, 'TLS_FORWARD with empty pin_sha256') + boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='č'}}}, 'TLS_FORWARD with bad pin_sha256') + boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='d161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}}, + 'TLS_FORWARD with bad pin_sha256 (short base64)') + boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='bbd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}}, + 'TLS_FORWARD with bad pin_sha256 (long base64)') + ok(policy.TLS_FORWARD({ + {'::1', pin_sha256='g1PpXsxqPchz2tH6w9kcvVXqzQ0QclhInFP2+VWOqic='} + }), 'TLS_FORWARD with base64 pin_sha256') + ok(policy.TLS_FORWARD({ + {'::1', pin_sha256={ + 'ev1xcdU++dY9BlcX0QoKeaUftvXQvNIz/PCss1Z/3ek=', + 'SgnqTFcvYduWX7+VUnlNFT1gwSNvQdZakH7blChIRbM=', + 'bd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg=', + }}}), 'TLS_FORWARD with a table of pins') + + -- ok(policy.TLS_FORWARD({{'::1', hostname='test.', ca_file='/tmp/ca.crt'}}), 'TLS_FORWARD with hostname + CA cert') + ok(policy.TLS_FORWARD({{'::1', hostname='test.'}}), + 'TLS_FORWARD with just hostname (use system CA store)') + boom(policy.TLS_FORWARD, {{{'::1', ca_file='/tmp/ca.crt'}}}, + 'TLS_FORWARD with just CA cert') + boom(policy.TLS_FORWARD, {{{'::1', hostname='', ca_file='/tmp/ca.crt'}}}, + 'TLS_FORWARD with empty hostname + CA cert') + boom(policy.TLS_FORWARD, { + {{'::1', hostname='test.', ca_file='/dev/a_file_which_surely_does_NOT_exist!'}} + }, 'TLS_FORWARD with hostname + unreadable CA cert') + +end + +local function test_slice() + boom(policy.slice, {function() end}, 
'policy.slice() without any action') + ok(policy.slice, {function() end, policy.FORWARD, policy.FORWARD}) +end + +local function mirror_parser(srv, cv, nqueries) + local ffi = require('ffi') + local test_end = 0 + local TIMEOUT = 5 -- seconds + + while true do + local input = srv:xread('*a', 'bn', TIMEOUT) + if not input then + cv:signal() + return false, 'mirror: timeout' + end + --print(#input, input) + -- convert query to knot_pkt_t + local wire = ffi.cast("void *", input) + local pkt = ffi.gc(ffi.C.knot_pkt_new(wire, #input, nil), ffi.C.knot_pkt_free) + if not pkt then + cv:signal() + return false, 'mirror: packet allocation error' + end + + local result = ffi.C.knot_pkt_parse(pkt, 0) + if result ~= 0 then + cv:signal() + return false, 'mirror: packet parse error' + end + --print(pkt) + test_end = test_end + 1 + + if test_end == nqueries then + cv:signal() + return true, 'packet mirror pass' + end + + end +end + +local function test_mirror() + local kluautil = require('kluautil') + local socket = require('cqueues.socket') + local cond = require('cqueues.condition') + local cv = cond.new() + local queries = {} + local srv = socket.listen({ + host = "127.0.0.1", + port = 36659, + type = socket.SOCK_DGRAM, + }) + -- binary mode, no buffering + srv:setmode('bn', 'bn') + + queries["bla.mujtest.cz."] = kres.type.AAAA + queries["bla.mujtest2.cz."] = kres.type.AAAA + + -- UDP server for test + worker.bg_worker.cq:wrap(function() + local err, msg = mirror_parser(srv, cv, kluautil.kr_table_len(queries)) + + ok(err, msg) + end) + + policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest.cz.'}))) + policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest2.cz.'}))) + + for name, rtype in pairs(queries) do + resolve(name, rtype) + end + + cv:wait() +end + +return { + test_tls_forward, + test_mirror, + test_slice, +} diff --git a/modules/policy/policy.test.rpz b/modules/policy/policy.test.rpz new file mode 100644 index 
0000000..d962e9f --- /dev/null +++ b/modules/policy/policy.test.rpz @@ -0,0 +1,18 @@ +$ORIGIN testdomain. +$TTL 30 +testdomain. SOA nonexistent.testdomain. testdomain. 1 12h 15m 3w 2h + NS nonexistent.testdomain. + +nxdomain CNAME . +nodata CNAME *. +*.nxdomain CNAME *. +rpzdrop CNAME rpz-drop. +rpzpassthru CNAME rpz-passthru. +rra A 192.168.5.5 +rra-zonename-suffix A 192.168.6.6 +testdomain.rra.testdomain. A 192.168.7.7 +CaSe.SeNSiTiVe A 192.168.8.8 + +two.records AAAA 2001:db8::2 +two.records AAAA 2001:db8::1 + diff --git a/modules/policy/policy.test.rpz.soa b/modules/policy/policy.test.rpz.soa new file mode 100644 index 0000000..ad18aa4 --- /dev/null +++ b/modules/policy/policy.test.rpz.soa @@ -0,0 +1,5 @@ +test2domain. SOA nonexistent.test2domain. test2domain. 1 12h 15m 3w 2h + NS nonexistent.test2domain. + +nxdomain-fqdn.test2domain. CNAME . +nodata-fqdn.test2domain. CNAME *. diff --git a/modules/policy/test.integr/deckard.yaml b/modules/policy/test.integr/deckard.yaml new file mode 100644 index 0000000..9c6cb70 --- /dev/null +++ b/modules/policy/test.integr/deckard.yaml @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +programs: +- name: kresd + binary: kresd + additional: + - --noninteractive + templates: + - modules/policy/test.integr/kresd_config.j2 + - tests/integration/hints_zone.j2 + configs: + - config + - hints diff --git a/modules/policy/test.integr/kresd_config.j2 b/modules/policy/test.integr/kresd_config.j2 new file mode 100644 index 0000000..668c792 --- /dev/null +++ b/modules/policy/test.integr/kresd_config.j2 @@ -0,0 +1,59 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +{% raw %} +policy.add(policy.domains(policy.DENY, {todname('example.com')})) +policy.add(policy.suffix(policy.REFUSE, {todname('refuse.example.com')})) + +-- make sure DNSSEC is turned off for tests +trust_anchors.remove('.') + +-- Disable RFC5011 TA update +if ta_update then + modules.unload('ta_update') +end + +-- Disable RFC8145 signaling, scenario doesn't 
provide expected answers +if ta_signal_query then + modules.unload('ta_signal_query') +end + +-- Disable RFC8109 priming, scenario doesn't provide expected answers +if priming then + modules.unload('priming') +end + +-- Disable this module because it makes one priming query +if detect_time_skew then + modules.unload('detect_time_skew') +end + +_hint_root_file('hints') +cache.size = 2*MB +log_level('debug') +{% endraw %} + +net = { '{{SELF_ADDR}}' } + + +{% if QMIN == "false" %} +option('NO_MINIMIZE', true) +{% else %} +option('NO_MINIMIZE', false) +{% endif %} + + +-- Self-checks on globals +assert(help() ~= nil) +assert(worker.id ~= nil) +-- Self-checks on facilities +assert(cache.count() == 0) +assert(cache.stats() ~= nil) +assert(cache.backends() ~= nil) +assert(worker.stats() ~= nil) +assert(net.interfaces() ~= nil) +-- Self-checks on loaded stuff +assert(net.list()[1].transport.ip == '{{SELF_ADDR}}') +assert(#modules.list() > 0) +-- Self-check timers +ev = event.recurrent(1 * sec, function (ev) return 1 end) +event.cancel(ev) +ev = event.after(0, function (ev) return 1 end) diff --git a/modules/policy/test.integr/refuse.rpl b/modules/policy/test.integr/refuse.rpl new file mode 100644 index 0000000..08f9942 --- /dev/null +++ b/modules/policy/test.integr/refuse.rpl @@ -0,0 +1,44 @@ +; SPDX-License-Identifier: GPL-3.0-or-later +; config options + stub-addr: 193.0.14.129 # K.ROOT-SERVERS.NET. +CONFIG_END + +SCENARIO_BEGIN Test refuse policy + +STEP 10 QUERY +ENTRY_BEGIN +REPLY RD AD +SECTION QUESTION +www.refuse.example.com. IN A +ENTRY_END + +STEP 20 CHECK_ANSWER +ENTRY_BEGIN +MATCH all answer +; AD must not be set in the answer +REPLY QR RD RA REFUSED +SECTION QUESTION +www.refuse.example.com. IN A +SECTION ANSWER +ENTRY_END + +STEP 30 QUERY +ENTRY_BEGIN +REPLY RD AD +SECTION QUESTION +example.com. IN A +ENTRY_END + +STEP 40 CHECK_ANSWER +ENTRY_BEGIN +MATCH all answer +REPLY QR RD AA RA NXDOMAIN +SECTION QUESTION +example.com. 
IN A +SECTION ANSWER +SECTION AUTHORITY +example.com. 10800 IN SOA example.com. nobody.invalid. 1 3600 1200 604800 10800 +ENTRY_END + + +SCENARIO_END