diff options
Diffstat (limited to 'tests/py/nft-test.py')
-rwxr-xr-x | tests/py/nft-test.py | 1586 |
1 files changed, 1586 insertions, 0 deletions
diff --git a/tests/py/nft-test.py b/tests/py/nft-test.py new file mode 100755 index 0000000..9a25503 --- /dev/null +++ b/tests/py/nft-test.py @@ -0,0 +1,1586 @@ +#!/usr/bin/env python +# +# (C) 2014 by Ana Rey Botello <anarey@gmail.com> +# +# Based on iptables-test.py: +# (C) 2012 by Pablo Neira Ayuso <pablo@netfilter.org>" +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. +# +# Thanks to the Outreach Program for Women (OPW) for sponsoring this test +# infrastructure. + +from __future__ import print_function +import sys +import os +import argparse +import signal +import json +import traceback +import tempfile + +TESTS_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(TESTS_PATH, '../../py/')) +os.environ['TZ'] = 'UTC-2' + +from nftables import Nftables + +TESTS_DIRECTORY = ["any", "arp", "bridge", "inet", "ip", "ip6", "netdev"] +LOGFILE = "/tmp/nftables-test.log" +log_file = None +table_list = [] +chain_list = [] +all_set = dict() +obj_list = [] +signal_received = 0 + + +class Colors: + if sys.stdout.isatty() and sys.stderr.isatty(): + HEADER = '\033[95m' + GREEN = '\033[92m' + YELLOW = '\033[93m' + RED = '\033[91m' + ENDC = '\033[0m' + else: + HEADER = '' + GREEN = '' + YELLOW = '' + RED = '' + ENDC = '' + + +class Chain: + """Class that represents a chain""" + + def __init__(self, name, config, lineno): + self.name = name + self.config = config + self.lineno = lineno + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __str__(self): + return "%s" % self.name + + +class Table: + """Class that represents a table""" + + def __init__(self, family, name, chains): + self.family = family + self.name = name + self.chains = chains + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __str__(self): + return "%s %s" % (self.family, self.name) + + +class Set: + """Class that represents a set""" + + def __init__(self, family, table, name, type, data, timeout, flags): + self.family = family + self.table = table + self.name = name + self.type = type + self.data = data + self.timeout = timeout + self.flags = flags + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + +class Obj: + """Class that represents an object""" + + def __init__(self, table, family, name, type, spcf): + self.table = table + self.family = family + self.name = name + self.type = type + self.spcf = spcf + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + +def print_msg(reason, errstr, filename=None, lineno=None, color=None): + ''' + Prints a message with nice colors, indicating file and line number. + ''' + color_errstr = "%s%s%s" % (color, errstr, Colors.ENDC) + if filename and lineno: + sys.stderr.write("%s: %s line %d: %s\n" % + (filename, color_errstr, lineno + 1, reason)) + else: + sys.stderr.write("%s %s\n" % (color_errstr, reason)) + sys.stderr.flush() # So that the message stay in the right place. + + +def print_error(reason, filename=None, lineno=None): + print_msg(reason, "ERROR:", filename, lineno, Colors.RED) + + +def print_warning(reason, filename=None, lineno=None): + print_msg(reason, "WARNING:", filename, lineno, Colors.YELLOW) + +def print_info(reason, filename=None, lineno=None): + print_msg(reason, "INFO:", filename, lineno, Colors.GREEN) + +def color_differences(rule, other, color): + rlen = len(rule) + olen = len(other) + out = "" + i = 0 + + # find equal part at start + for i in range(rlen): + if i >= olen or rule[i] != other[i]: + break + if i > 0: + out += rule[:i] + rule = rule[i:] + other = other[i:] + rlen = len(rule) + olen = len(other) + + # find equal part at end + for i in range(1, rlen + 1): + if i > olen or rule[rlen - i] != other[olen - i]: + i -= 1 + break + if rlen > i: + out += color + rule[:rlen - i] + Colors.ENDC + rule = rule[rlen - i:] + + out += rule + return out + +def print_differences_warning(filename, lineno, rule1, rule2, cmd): + colored_rule1 = color_differences(rule1, rule2, Colors.YELLOW) + colored_rule2 = color_differences(rule2, rule1, Colors.YELLOW) + reason = "'%s': '%s' mismatches '%s'" % (cmd, colored_rule1, colored_rule2) + print_warning(reason, filename, lineno) + + +def print_differences_error(filename, lineno, cmd): + reason = "'%s': Listing is broken." % cmd + print_error(reason, filename, lineno) + + +def table_exist(table, filename, lineno): + ''' + Exists a table. + ''' + cmd = "list table %s" % table + ret = execute_cmd(cmd, filename, lineno) + + return True if (ret == 0) else False + + +def table_flush(table, filename, lineno): + ''' + Flush a table. + ''' + cmd = "flush table %s" % table + execute_cmd(cmd, filename, lineno) + + return cmd + + +def table_create(table, filename, lineno): + ''' + Adds a table. + ''' + # We check if table exists. + if table_exist(table, filename, lineno): + reason = "Table %s already exists" % table + print_error(reason, filename, lineno) + return -1 + + table_list.append(table) + + # We add a new table + cmd = "add table %s" % table + ret = execute_cmd(cmd, filename, lineno) + + if ret != 0: + reason = "Cannot " + cmd + print_error(reason, filename, lineno) + table_list.remove(table) + return -1 + + # We check if table was added correctly. + if not table_exist(table, filename, lineno): + table_list.remove(table) + reason = "I have just added the table %s " \ + "but it does not exist. Giving up!" % table + print_error(reason, filename, lineno) + return -1 + + for table_chain in table.chains: + chain = chain_get_by_name(table_chain) + if chain is None: + reason = "The chain %s requested by table %s " \ + "does not exist." % (table_chain, table) + print_error(reason, filename, lineno) + else: + chain_create(chain, table, filename) + + return 0 + + +def table_delete(table, filename=None, lineno=None): + ''' + Deletes a table. + ''' + if not table_exist(table, filename, lineno): + reason = "Table %s does not exist but I added it before." % table + print_error(reason, filename, lineno) + return -1 + + cmd = "delete table %s" % table + ret = execute_cmd(cmd, filename, lineno) + if ret != 0: + reason = "%s: I cannot delete table %s. Giving up!" % (cmd, table) + print_error(reason, filename, lineno) + return -1 + + if table_exist(table, filename, lineno): + reason = "I have just deleted the table %s " \ + "but it still exists." % table + print_error(reason, filename, lineno) + return -1 + + return 0 + + +def chain_exist(chain, table, filename): + ''' + Checks a chain + ''' + cmd = "list chain %s %s" % (table, chain) + ret = execute_cmd(cmd, filename, chain.lineno) + + return True if (ret == 0) else False + + +def chain_create(chain, table, filename): + ''' + Adds a chain + ''' + if chain_exist(chain, table, filename): + reason = "This chain '%s' exists in %s. I cannot create " \ + "two chains with same name." % (chain, table) + print_error(reason, filename, chain.lineno) + return -1 + + cmd = "add chain %s %s" % (table, chain) + if chain.config: + cmd += " { %s; }" % chain.config + + ret = execute_cmd(cmd, filename, chain.lineno) + if ret != 0: + reason = "I cannot create the chain '%s'" % chain + print_error(reason, filename, chain.lineno) + return -1 + + if not chain_exist(chain, table, filename): + reason = "I have added the chain '%s' " \ + "but it does not exist in %s" % (chain, table) + print_error(reason, filename, chain.lineno) + return -1 + + return 0 + + +def chain_delete(chain, table, filename=None, lineno=None): + ''' + Flushes and deletes a chain. + ''' + if not chain_exist(chain, table, filename): + reason = "The chain %s does not exist in %s. " \ + "I cannot delete it." % (chain, table) + print_error(reason, filename, lineno) + return -1 + + cmd = "flush chain %s %s" % (table, chain) + ret = execute_cmd(cmd, filename, lineno) + if ret != 0: + reason = "I cannot " + cmd + print_error(reason, filename, lineno) + return -1 + + cmd = "delete chain %s %s" % (table, chain) + ret = execute_cmd(cmd, filename, lineno) + if ret != 0: + reason = "I cannot " + cmd + print_error(reason, filename, lineno) + return -1 + + if chain_exist(chain, table, filename): + reason = "The chain %s exists in %s. " \ + "I cannot delete this chain" % (chain, table) + print_error(reason, filename, lineno) + return -1 + + return 0 + + +def chain_get_by_name(name): + for chain in chain_list: + if chain.name == name: + break + else: + chain = None + + return chain + + +def set_add(s, test_result, filename, lineno): + ''' + Adds a set. + ''' + if not table_list: + reason = "Missing table to add rule" + print_error(reason, filename, lineno) + return -1 + + for table in table_list: + s.table = table.name + s.family = table.family + if _set_exist(s, filename, lineno): + reason = "Set %s already exists in %s" % (s.name, table) + print_error(reason, filename, lineno) + return -1 + + flags = s.flags + if flags != "": + flags = "flags %s; " % flags + + if s.data == "": + cmd = "add set %s %s { type %s;%s %s}" % (table, s.name, s.type, s.timeout, flags) + else: + cmd = "add map %s %s { type %s : %s;%s %s}" % (table, s.name, s.type, s.data, s.timeout, flags) + + ret = execute_cmd(cmd, filename, lineno) + + if (ret == 0 and test_result == "fail") or \ + (ret != 0 and test_result == "ok"): + reason = "%s: I cannot add the set %s" % (cmd, s.name) + print_error(reason, filename, lineno) + return -1 + + if not _set_exist(s, filename, lineno): + reason = "I have just added the set %s to " \ + "the table %s but it does not exist" % (s.name, table) + print_error(reason, filename, lineno) + return -1 + + return 0 + + +def map_add(s, test_result, filename, lineno): + ''' + Adds a map + ''' + if not table_list: + reason = "Missing table to add rule" + print_error(reason, filename, lineno) + return -1 + + for table in table_list: + s.table = table.name + s.family = table.family + if _map_exist(s, filename, lineno): + reason = "Map %s already exists in %s" % (s.name, table) + print_error(reason, filename, lineno) + return -1 + + flags = s.flags + if flags != "": + flags = "flags %s; " % flags + + cmd = "add map %s %s { type %s : %s;%s %s}" % (table, s.name, s.type, s.data, s.timeout, flags) + + ret = execute_cmd(cmd, filename, lineno) + + if (ret == 0 and test_result == "fail") or \ + (ret != 0 and test_result == "ok"): + reason = "%s: I cannot add the set %s" % (cmd, s.name) + print_error(reason, filename, lineno) + return -1 + + if not _map_exist(s, filename, lineno): + reason = "I have just added the set %s to " \ + "the table %s but it does not exist" % (s.name, table) + print_error(reason, filename, lineno) + return -1 + + +def set_add_elements(set_element, set_name, state, filename, lineno): + ''' + Adds elements to the set. + ''' + if not table_list: + reason = "Missing table to add rules" + print_error(reason, filename, lineno) + return -1 + + for table in table_list: + # Check if set exists. + if (not set_exist(set_name, table, filename, lineno) or + set_name not in all_set) and state == "ok": + reason = "I cannot add an element to the set %s " \ + "since it does not exist." % set_name + print_error(reason, filename, lineno) + return -1 + + element = ", ".join(set_element) + cmd = "add element %s %s { %s }" % (table, set_name, element) + ret = execute_cmd(cmd, filename, lineno) + + if (state == "fail" and ret == 0) or (state == "ok" and ret != 0): + if state == "fail": + test_state = "This rule should have failed." + else: + test_state = "This rule should not have failed." + + reason = cmd + ": " + test_state + print_error(reason, filename, lineno) + return -1 + + # Add element into all_set. + if ret == 0 and state == "ok": + for e in set_element: + all_set[set_name].add(e) + + return 0 + + +def set_delete_elements(set_element, set_name, table, filename=None, + lineno=None): + ''' + Deletes elements in a set. + ''' + for element in set_element: + cmd = "delete element %s %s { %s }" % (table, set_name, element) + ret = execute_cmd(cmd, filename, lineno) + if ret != 0: + reason = "I cannot delete element %s " \ + "from the set %s" % (element, set_name) + print_error(reason, filename, lineno) + return -1 + + return 0 + + +def set_delete(table, filename=None, lineno=None): + ''' + Deletes set and its content. + ''' + for set_name in all_set.keys(): + # Check if exists the set + if not set_exist(set_name, table, filename, lineno): + reason = "The set %s does not exist, " \ + "I cannot delete it" % set_name + print_error(reason, filename, lineno) + return -1 + + # We delete all elements in the set + set_delete_elements(all_set[set_name], set_name, table, filename, + lineno) + + # We delete the set. + cmd = "delete set %s %s" % (table, set_name) + ret = execute_cmd(cmd, filename, lineno) + + # Check if the set still exists after I deleted it. + if ret != 0 or set_exist(set_name, table, filename, lineno): + reason = "Cannot remove the set " + set_name + print_error(reason, filename, lineno) + return -1 + + return 0 + + +def set_exist(set_name, table, filename, lineno): + ''' + Check if the set exists. + ''' + cmd = "list set %s %s" % (table, set_name) + ret = execute_cmd(cmd, filename, lineno) + + return True if (ret == 0) else False + + +def _set_exist(s, filename, lineno): + ''' + Check if the set exists. + ''' + cmd = "list set %s %s %s" % (s.family, s.table, s.name) + ret = execute_cmd(cmd, filename, lineno) + + return True if (ret == 0) else False + + +def _map_exist(s, filename, lineno): + ''' + Check if the map exists. + ''' + cmd = "list map %s %s %s" % (s.family, s.table, s.name) + ret = execute_cmd(cmd, filename, lineno) + + return True if (ret == 0) else False + + +def set_check_element(rule1, rule2): + ''' + Check if element exists in anonymous sets. + ''' + pos1 = rule1.find("{") + pos2 = rule2.find("{") + + if (rule1[:pos1] != rule2[:pos2]): + return False + + end1 = rule1.find("}") + end2 = rule2.find("}") + + if (pos1 != -1) and (pos2 != -1) and (end1 != -1) and (end2 != -1): + list1 = (rule1[pos1 + 1:end1].replace(" ", "")).split(",") + list2 = (rule2[pos2 + 1:end2].replace(" ", "")).split(",") + list1.sort() + list2.sort() + if list1 != list2: + return False + + return rule1[end1:] == rule2[end2:] + + return False + + +def obj_add(o, test_result, filename, lineno): + ''' + Adds an object. + ''' + if not table_list: + reason = "Missing table to add rule" + print_error(reason, filename, lineno) + return -1 + + for table in table_list: + o.table = table.name + o.family = table.family + obj_handle = o.type + " " + o.name + if _obj_exist(o, filename, lineno): + reason = "The %s already exists in %s" % (obj_handle, table) + print_error(reason, filename, lineno) + return -1 + + cmd = "add %s %s %s %s" % (o.type, table, o.name, o.spcf) + ret = execute_cmd(cmd, filename, lineno) + + if (ret == 0 and test_result == "fail") or \ + (ret != 0 and test_result == "ok"): + reason = "%s: I cannot add the %s" % (cmd, obj_handle) + print_error(reason, filename, lineno) + return -1 + + exist = _obj_exist(o, filename, lineno) + + if exist: + if test_result == "ok": + return 0 + reason = "I added the %s to the table %s " \ + "but it should have failed" % (obj_handle, table) + print_error(reason, filename, lineno) + return -1 + + if test_result == "fail": + return 0 + + reason = "I have just added the %s to " \ + "the table %s but it does not exist" % (obj_handle, table) + print_error(reason, filename, lineno) + return -1 + +def obj_delete(table, filename=None, lineno=None): + ''' + Deletes object. + ''' + for o in obj_list: + obj_handle = o.type + " " + o.name + # Check if exists the obj + if not obj_exist(o, table, filename, lineno): + reason = "The %s does not exist, I cannot delete it" % obj_handle + print_error(reason, filename, lineno) + return -1 + + # We delete the object. + cmd = "delete %s %s %s" % (o.type, table, o.name) + ret = execute_cmd(cmd, filename, lineno) + + # Check if the object still exists after I deleted it. + if ret != 0 or obj_exist(o, table, filename, lineno): + reason = "Cannot remove the " + obj_handle + print_error(reason, filename, lineno) + return -1 + + return 0 + + +def obj_exist(o, table, filename, lineno): + ''' + Check if the object exists. + ''' + cmd = "list %s %s %s" % (o.type, table, o.name) + ret = execute_cmd(cmd, filename, lineno) + + return True if (ret == 0) else False + + +def _obj_exist(o, filename, lineno): + ''' + Check if the object exists. + ''' + cmd = "list %s %s %s %s" % (o.type, o.family, o.table, o.name) + ret = execute_cmd(cmd, filename, lineno) + + return True if (ret == 0) else False + + +def output_clean(pre_output, chain): + pos_chain = pre_output.find(chain.name) + if pos_chain == -1: + return "" + output_intermediate = pre_output[pos_chain:] + brace_start = output_intermediate.find("{") + brace_end = output_intermediate.find("}") + pre_rule = output_intermediate[brace_start:brace_end] + if pre_rule[1:].find("{") > -1: # this rule has a set. + set = pre_rule[1:].replace("\t", "").replace("\n", "").strip() + set = set.split(";")[2].strip() + "}" + remainder = output_clean(chain.name + " {;;" + output_intermediate[brace_end+1:], chain) + if len(remainder) <= 0: + return set + return set + " " + remainder + else: + rule = pre_rule.split(";")[2].replace("\t", "").replace("\n", "").\ + strip() + if len(rule) < 0: + return "" + return rule + + +def payload_check_elems_to_set(elems): + newset = set() + + for n, line in enumerate(elems.split('[end]')): + e = line.strip() + if e in newset: + print_error("duplicate", e, n) + return newset + + newset.add(e) + + return newset + + +def payload_check_set_elems(want, got): + if want.find('element') < 0 or want.find('[end]') < 0: + return 0 + + if got.find('element') < 0 or got.find('[end]') < 0: + return 0 + + set_want = payload_check_elems_to_set(want) + set_got = payload_check_elems_to_set(got) + + return set_want == set_got + + +def payload_check(payload_buffer, file, cmd): + file.seek(0, 0) + i = 0 + + if not payload_buffer: + return False + + for lineno, want_line in enumerate(payload_buffer): + line = file.readline() + + if want_line == line: + i += 1 + continue + + if want_line.find('[') < 0 and line.find('[') < 0: + continue + if want_line.find(']') < 0 and line.find(']') < 0: + continue + + if payload_check_set_elems(want_line, line): + continue + + print_differences_warning(file.name, lineno, want_line.strip(), + line.strip(), cmd) + return 0 + + return i > 0 + + +def json_dump_normalize(json_string, human_readable = False): + json_obj = json.loads(json_string) + + if human_readable: + return json.dumps(json_obj, sort_keys = True, + indent = 4, separators = (',', ': ')) + else: + return json.dumps(json_obj, sort_keys = True) + +def json_validate(json_string): + json_obj = json.loads(json_string) + try: + nftables.json_validate(json_obj) + except Exception: + print_error("schema validation failed for input '%s'" % json_string) + print_error(traceback.format_exc()) + +def rule_add(rule, filename, lineno, force_all_family_option, filename_path): + ''' + Adds a rule + ''' + # TODO Check if a rule is added correctly. + ret = warning = error = unit_tests = 0 + + if not table_list or not chain_list: + reason = "Missing table or chain to add rule." + print_error(reason, filename, lineno) + return [-1, warning, error, unit_tests] + + if rule[1].strip() == "ok": + payload_expected = None + payload_path = None + try: + payload_log = open("%s.payload" % filename_path) + payload_path = payload_log.name + payload_expected = payload_find_expected(payload_log, rule[0]) + except: + payload_log = None + + if enable_json_option: + try: + json_log = open("%s.json" % filename_path) + json_input = json_find_expected(json_log, rule[0]) + except: + json_input = None + + if not json_input: + print_error("did not find JSON equivalent for rule '%s'" + % rule[0]) + else: + try: + json_input = json_dump_normalize(json_input) + except ValueError: + reason = "Invalid JSON syntax in rule: %s" % json_input + print_error(reason) + return [-1, warning, error, unit_tests] + + try: + json_log = open("%s.json.output" % filename_path) + json_expected = json_find_expected(json_log, rule[0]) + except: + # will use json_input for comparison + json_expected = None + + if json_expected: + try: + json_expected = json_dump_normalize(json_expected) + except ValueError: + reason = "Invalid JSON syntax in expected output: %s" % json_expected + print_error(reason) + return [-1, warning, error, unit_tests] + + for table in table_list: + if rule[1].strip() == "ok": + table_payload_expected = None + try: + payload_log = open("%s.payload.%s" % (filename_path, table.family)) + payload_path = payload_log.name + table_payload_expected = payload_find_expected(payload_log, rule[0]) + except: + if not payload_log: + print_error("did not find any payload information", + filename_path) + elif not payload_expected: + print_error("did not find payload information for " + "rule '%s'" % rule[0], payload_log.name, 1) + if not table_payload_expected: + table_payload_expected = payload_expected + + for table_chain in table.chains: + chain = chain_get_by_name(table_chain) + unit_tests += 1 + table_flush(table, filename, lineno) + + payload_log = tempfile.TemporaryFile(mode="w+") + + # Add rule and check return code + cmd = "add rule %s %s %s" % (table, chain, rule[0]) + ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink") + + state = rule[1].rstrip() + if (ret in [0,134] and state == "fail") or (ret != 0 and state == "ok"): + if state == "fail": + test_state = "This rule should have failed." + else: + test_state = "This rule should not have failed." + reason = cmd + ": " + test_state + print_error(reason, filename, lineno) + ret = -1 + error += 1 + if not force_all_family_option: + return [ret, warning, error, unit_tests] + + if state == "fail" and ret != 0: + ret = 0 + continue + + if ret != 0: + continue + + # Check for matching payload + if state == "ok" and not payload_check(table_payload_expected, + payload_log, cmd): + error += 1 + + try: + gotf = open("%s.got" % payload_path) + gotf_payload_expected = payload_find_expected(gotf, rule[0]) + gotf.close() + except: + gotf_payload_expected = None + payload_log.seek(0, 0) + if not payload_check(gotf_payload_expected, payload_log, cmd): + gotf = open("%s.got" % payload_path, 'a') + payload_log.seek(0, 0) + gotf.write("# %s\n" % rule[0]) + while True: + line = payload_log.readline() + if line == "": + break + gotf.write(line) + gotf.close() + print_warning("Wrote payload for rule %s" % rule[0], + gotf.name, 1) + + # Check for matching ruleset listing + numeric_proto_old = nftables.set_numeric_proto_output(True) + stateless_old = nftables.set_stateless_output(True) + list_cmd = 'list table %s' % table + rc, pre_output, err = nftables.cmd(list_cmd) + nftables.set_numeric_proto_output(numeric_proto_old) + nftables.set_stateless_output(stateless_old) + + output = pre_output.split(";") + if len(output) < 2: + reason = cmd + ": Listing is broken." + print_error(reason, filename, lineno) + ret = -1 + error += 1 + if not force_all_family_option: + return [ret, warning, error, unit_tests] + continue + + rule_output = output_clean(pre_output, chain) + retest_output = False + if len(rule) == 3: + teoric_exit = rule[2] + retest_output = True + else: + teoric_exit = rule[0] + + if rule_output.rstrip() != teoric_exit.rstrip(): + if rule[0].find("{") != -1: # anonymous sets + if not set_check_element(teoric_exit.rstrip(), + rule_output.rstrip()): + warning += 1 + retest_output = True + print_differences_warning(filename, lineno, + teoric_exit.rstrip(), + rule_output, cmd) + if not force_all_family_option: + return [ret, warning, error, unit_tests] + else: + if len(rule_output) <= 0: + error += 1 + print_differences_error(filename, lineno, cmd) + if not force_all_family_option: + return [ret, warning, error, unit_tests] + + warning += 1 + retest_output = True + print_differences_warning(filename, lineno, + teoric_exit.rstrip(), + rule_output, cmd) + + if not force_all_family_option: + return [ret, warning, error, unit_tests] + + if retest_output: + table_flush(table, filename, lineno) + + # Add rule and check return code + cmd = "add rule %s %s %s" % (table, chain, rule_output.rstrip()) + ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink") + + if ret != 0: + test_state = "Replaying rule failed." + reason = cmd + ": " + test_state + print_warning(reason, filename, lineno) + ret = -1 + error += 1 + if not force_all_family_option: + return [ret, warning, error, unit_tests] + # Check for matching payload + elif not payload_check(table_payload_expected, + payload_log, cmd): + error += 1 + + if not enable_json_option: + continue + + # Generate JSON equivalent for rule if not found + if not json_input: + json_old = nftables.set_json_output(True) + rc, json_output, err = nftables.cmd(list_cmd) + nftables.set_json_output(json_old) + + json_output = json.loads(json_output) + for item in json_output["nftables"]: + if "rule" in item: + del(item["rule"]["handle"]) + json_output = item["rule"] + break + json_input = json.dumps(json_output["expr"], sort_keys = True) + + gotf = open("%s.json.got" % filename_path, 'a') + jdump = json_dump_normalize(json_input, True) + gotf.write("# %s\n%s\n\n" % (rule[0], jdump)) + gotf.close() + print_warning("Wrote JSON equivalent for rule %s" % rule[0], + gotf.name, 1) + + table_flush(table, filename, lineno) + payload_log = tempfile.TemporaryFile(mode="w+") + + # Add rule in JSON format + cmd = json.dumps({ "nftables": [{ "add": { "rule": { + "family": table.family, + "table": table.name, + "chain": chain.name, + "expr": json.loads(json_input), + }}}]}) + + if enable_json_schema: + json_validate(cmd) + + json_old = nftables.set_json_output(True) + ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink") + nftables.set_json_output(json_old) + + if ret != 0: + reason = "Failed to add JSON equivalent rule" + print_error(reason, filename, lineno) + continue + + # Check for matching payload + if not payload_check(table_payload_expected, payload_log, cmd): + error += 1 + gotf = open("%s.json.payload.got" % filename_path, 'a') + payload_log.seek(0, 0) + gotf.write("# %s\n" % rule[0]) + while True: + line = payload_log.readline() + if line == "": + break + gotf.write(line) + gotf.close() + print_warning("Wrote JSON payload for rule %s" % rule[0], + gotf.name, 1) + + # Check for matching ruleset listing + numeric_proto_old = nftables.set_numeric_proto_output(True) + stateless_old = nftables.set_stateless_output(True) + json_old = nftables.set_json_output(True) + rc, json_output, err = nftables.cmd(list_cmd) + nftables.set_json_output(json_old) + nftables.set_numeric_proto_output(numeric_proto_old) + nftables.set_stateless_output(stateless_old) + + if enable_json_schema: + json_validate(json_output) + + json_output = json.loads(json_output) + for item in json_output["nftables"]: + if "rule" in item: + del(item["rule"]["handle"]) + json_output = item["rule"] + break + json_output = json.dumps(json_output["expr"], sort_keys = True) + + if not json_expected and json_output != json_input: + print_differences_warning(filename, lineno, + json_input, json_output, cmd) + error += 1 + gotf = open("%s.json.output.got" % filename_path, 'a') + jdump = json_dump_normalize(json_output, True) + gotf.write("# %s\n%s\n\n" % (rule[0], jdump)) + gotf.close() + print_warning("Wrote JSON output for rule %s" % rule[0], + gotf.name, 1) + # prevent further warnings and .got file updates + json_expected = json_output + elif json_expected and json_output != json_expected: + print_differences_warning(filename, lineno, + json_expected, json_output, cmd) + error += 1 + + return [ret, warning, error, unit_tests] + + +def cleanup_on_exit(): + for table in table_list: + for table_chain in table.chains: + chain = chain_get_by_name(table_chain) + chain_delete(chain, table, "", "") + if all_set: + set_delete(table) + if obj_list: + obj_delete(table) + table_delete(table) + + +def signal_handler(signal, frame): + global signal_received + signal_received = 1 + + +def execute_cmd(cmd, filename, lineno, stdout_log=False, debug=False): + ''' + Executes a command, checks for segfaults and returns the command exit + code. + + :param cmd: string with the command to be executed + :param filename: name of the file tested (used for print_error purposes) + :param lineno: line number being tested (used for print_error purposes) + :param stdout_log: redirect stdout to this file instead of global log_file + :param debug: temporarily set these debug flags + ''' + global log_file + print("command: {}".format(cmd), file=log_file) + if debug_option: + print(cmd) + + log_file.flush() + + if debug: + debug_old = nftables.get_debug() + nftables.set_debug(debug) + + ret, out, err = nftables.cmd(cmd) + + if not stdout_log: + stdout_log = log_file + + stdout_log.write(out) + stdout_log.flush() + log_file.write(err) + log_file.flush() + + if debug: + nftables.set_debug(debug_old) + + return ret + + +def print_result(filename, tests, warning, error): + return str(filename) + ": " + str(tests) + " unit tests, " + str(error) + \ + " error, " + str(warning) + " warning" + + +def print_result_all(filename, tests, warning, error, unit_tests): + return str(filename) + ": " + str(tests) + " unit tests, " + \ + str(unit_tests) + " total test executed, " + str(error) + \ + " error, " + str(warning) + " warning" + + +def table_process(table_line, filename, lineno): + table_info = table_line.split(";") + table = Table(table_info[0], table_info[1], table_info[2].split(",")) + + return table_create(table, filename, lineno) + + +def chain_process(chain_line, lineno): + chain_info = chain_line.split(";") + chain_list.append(Chain(chain_info[0], chain_info[1], lineno)) + + return 0 + + +def set_process(set_line, filename, lineno): + test_result = set_line[1] + timeout="" + + tokens = set_line[0].split(" ") + set_name = tokens[0] + set_type = tokens[2] + set_data = "" + set_flags = "" + + i = 3 + while len(tokens) > i and tokens[i] == ".": + set_type += " . " + tokens[i+1] + i += 2 + + while len(tokens) > i and tokens[i] == ":": + set_data = tokens[i+1] + i += 2 + + if len(tokens) == i+2 and tokens[i] == "timeout": + timeout = "timeout " + tokens[i+1] + ";" + i += 2 + + if len(tokens) == i+2 and tokens[i] == "flags": + set_flags = tokens[i+1] + elif len(tokens) != i: + print_error(set_name + " bad flag: " + tokens[i], filename, lineno) + + s = Set("", "", set_name, set_type, set_data, timeout, set_flags) + + if set_data == "": + ret = set_add(s, test_result, filename, lineno) + else: + ret = map_add(s, test_result, filename, lineno) + + if ret == 0: + all_set[set_name] = set() + + return ret + + +def set_element_process(element_line, filename, lineno): + rule_state = element_line[1] + element_line = element_line[0] + space = element_line.find(" ") + set_name = element_line[:space] + set_element = element_line[space:].split(",") + + return set_add_elements(set_element, set_name, rule_state, filename, lineno) + + +def obj_process(obj_line, filename, lineno): + test_result = obj_line[1] + + tokens = obj_line[0].split(" ") + obj_name = tokens[0] + obj_type = tokens[2] + obj_spcf = "" + + if obj_type == "ct" and tokens[3] == "helper": + obj_type = "ct helper" + tokens[3] = "" + + if obj_type == "ct" and tokens[3] == "timeout": + obj_type = "ct timeout" + tokens[3] = "" + + if obj_type == "ct" and tokens[3] == "expectation": + obj_type = "ct expectation" + tokens[3] = "" + + if len(tokens) > 3: + obj_spcf = " ".join(tokens[3:]) + + o = Obj("", "", obj_name, obj_type, obj_spcf) + + ret = obj_add(o, test_result, filename, lineno) + if ret == 0: + obj_list.append(o) + + return ret + + +def payload_find_expected(payload_log, rule): + ''' + Find the netlink payload that should be generated by given rule in + payload_log + + :param payload_log: open file handle of the payload data + :param rule: nft rule we are going to add + ''' + found = 0 + payload_buffer = [] + + while True: + line = payload_log.readline() + if not line: + break + + if line[0] == "#": # rule start + rule_line = line.strip()[2:] + + if rule_line == rule.strip(): + found = 1 + continue + + if found == 1: + payload_buffer.append(line) + if line.isspace(): + return payload_buffer + + payload_log.seek(0, 0) + return payload_buffer + + +def json_find_expected(json_log, rule): + ''' + Find the corresponding JSON for given rule + + :param json_log: open file handle of the json data + :param rule: nft rule we are going to add + ''' + found = 0 + json_buffer = "" + + while True: + line = json_log.readline() + if not line: + break + + if line[0] == "#": # rule start + rule_line = line.strip()[2:] + + if rule_line == rule.strip(): + found = 1 + continue + + if found == 1: + json_buffer += line.rstrip("\n").strip() + if line.isspace(): + return json_buffer + + json_log.seek(0, 0) + return json_buffer + + +def run_test_file(filename, force_all_family_option, specific_file): + ''' + Runs a test file + + :param filename: name of the file with the test rules + ''' + filename_path = os.path.join(TESTS_PATH, filename) + f = open(filename_path) + tests = passed = total_unit_run = total_warning = total_error = 0 + + for lineno, line in enumerate(f): + sys.stdout.flush() + + if signal_received == 1: + print("\nSignal received. Cleaning up and Exitting...") + cleanup_on_exit() + sys.exit(0) + + if line.isspace(): + continue + + if line[0] == "#": # Command-line + continue + + if line[0] == '*': # Table + table_line = line.rstrip()[1:] + ret = table_process(table_line, filename, lineno) + if ret != 0: + break + continue + + if line[0] == ":": # Chain + chain_line = line.rstrip()[1:] + ret = chain_process(chain_line, lineno) + if ret != 0: + break + continue + + if line[0] == "!": # Adds this set + set_line = line.rstrip()[1:].split(";") + ret = set_process(set_line, filename, lineno) + tests += 1 + if ret == -1: + continue + passed += 1 + continue + + if line[0] == "?": # Adds elements in a set + element_line = line.rstrip()[1:].split(";") + ret = set_element_process(element_line, filename, lineno) + tests += 1 + if ret == -1: + continue + + passed += 1 + continue + + if line[0] == "%": # Adds this object + brace = line.rfind("}") + if brace < 0: + obj_line = line.rstrip()[1:].split(";") + else: + obj_line = (line[1:brace+1], line[brace+2:].rstrip()) + + ret = obj_process(obj_line, filename, lineno) + tests += 1 + if ret == -1: + continue + passed += 1 + continue + + # Rule + rule = line.split(';') # rule[1] Ok or FAIL + if len(rule) == 1 or len(rule) > 3 or rule[1].rstrip() \ + not in {"ok", "fail"}: + reason = "Skipping malformed rule test. (%s)" % line.rstrip('\n') + print_warning(reason, filename, lineno) + continue + + if line[0] == "-": # Run omitted lines + if need_fix_option: + rule[0] = rule[0].rstrip()[1:].strip() + else: + continue + elif need_fix_option: + continue + + result = rule_add(rule, filename, lineno, force_all_family_option, + filename_path) + tests += 1 + ret = result[0] + warning = result[1] + total_warning += warning + total_error += result[2] + total_unit_run += result[3] + + if ret != 0: + continue + + if warning == 0: # All ok. + passed += 1 + + # Delete rules, sets, chains and tables + for table in table_list: + # We delete chains + for table_chain in table.chains: + chain = chain_get_by_name(table_chain) + chain_delete(chain, table, filename, lineno) + + # We delete sets. + if all_set: + ret = set_delete(table, filename, lineno) + if ret != 0: + reason = "There is a problem when we delete a set" + print_error(reason, filename, lineno) + + # We delete tables. + table_delete(table, filename, lineno) + + if specific_file: + if force_all_family_option: + print(print_result_all(filename, tests, total_warning, total_error, + total_unit_run)) + else: + print(print_result(filename, tests, total_warning, total_error)) + else: + if tests == passed and tests > 0: + print(filename + ": " + Colors.GREEN + "OK" + Colors.ENDC) + + f.close() + del table_list[:] + del chain_list[:] + all_set.clear() + + return [tests, passed, total_warning, total_error, total_unit_run] + +def spawn_netns(): + # prefer unshare module + try: + import unshare + unshare.unshare(unshare.CLONE_NEWNET) + return True + except: + pass + + # sledgehammer style: + # - call ourselves prefixed by 'unshare -n' if found + # - pass extra --no-netns parameter to avoid another recursion + try: + import shutil + + unshare = shutil.which("unshare") + if unshare is None: + return False + + sys.argv.append("--no-netns") + if debug_option: + print("calling: ", [unshare, "-n", sys.executable] + sys.argv) + os.execv(unshare, [unshare, "-n", sys.executable] + sys.argv) + except: + pass + + return False + +def main(): + parser = argparse.ArgumentParser(description='Run nft tests') + + parser.add_argument('filenames', nargs='*', metavar='path/to/file.t', + help='Run only these tests') + + parser.add_argument('-d', '--debug', action='store_true', dest='debug', + help='enable debugging mode') + + parser.add_argument('-e', '--need-fix', action='store_true', + dest='need_fix_line', help='run rules that need a fix') + + parser.add_argument('-f', '--force-family', action='store_true', + dest='force_all_family', + help='keep testing all families on error') + + parser.add_argument('-H', '--host', action='store_true', + help='run tests against installed libnftables.so.1') + + parser.add_argument('-j', '--enable-json', action='store_true', + dest='enable_json', + help='test JSON functionality as well') + + parser.add_argument('-l', '--library', default=None, + help='path to libntables.so.1, overrides --host') + + parser.add_argument('-N', '--no-netns', action='store_true', + dest='no_netns', + help='Do not run in own network namespace') + + parser.add_argument('-s', '--schema', action='store_true', + dest='enable_schema', + help='verify json input/output against schema') + + parser.add_argument('-v', '--version', action='version', + version='1.0', + help='Print the version information') + + args = parser.parse_args() + global debug_option, need_fix_option, enable_json_option, enable_json_schema + debug_option = args.debug + need_fix_option = args.need_fix_line + force_all_family_option = args.force_all_family + enable_json_option = args.enable_json + enable_json_schema = args.enable_schema + specific_file = False + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + if os.getuid() != 0: + print("You need to be root to run this, sorry") + return + + if not args.no_netns and not spawn_netns(): + print_warning("cannot run in own namespace, connectivity might break") + + # Change working directory to repository root + os.chdir(TESTS_PATH + "/../..") + + check_lib_path = True + if args.library is None: + if args.host: + args.library = 'libnftables.so.1' + check_lib_path = False + else: + args.library = 'src/.libs/libnftables.so.1' + + if check_lib_path and not os.path.exists(args.library): + print("The nftables library at '%s' does not exist. " + "You need to build the project." % args.library) + return + + if args.enable_schema and not args.enable_json: + print_error("Option --schema requires option --json") + return + + global nftables + nftables = Nftables(sofile = args.library) + + test_files = files_ok = run_total = 0 + tests = passed = warnings = errors = 0 + global log_file + try: + log_file = open(LOGFILE, 'w') + print_info("Log will be available at %s" % LOGFILE) + except IOError: + print_error("Cannot open log file %s" % LOGFILE) + return + + file_list = [] + if args.filenames: + file_list = args.filenames + if len(args.filenames) == 1: + specific_file = True + else: + for directory in TESTS_DIRECTORY: + path = os.path.join(TESTS_PATH, directory) + for root, dirs, files in os.walk(path): + for f in files: + if f.endswith(".t"): + file_list.append(os.path.join(directory, f)) + + for filename in file_list: + result = run_test_file(filename, force_all_family_option, specific_file) + file_tests = result[0] + file_passed = result[1] + file_warnings = result[2] + file_errors = result[3] + file_unit_run = result[4] + + test_files += 1 + + if file_warnings == 0 and file_tests == file_passed: + files_ok += 1 + if file_tests: + tests += file_tests + passed += file_passed + errors += file_errors + warnings += file_warnings + if force_all_family_option: + run_total += file_unit_run + + if test_files == 0: + print("No test files to run") + else: + if not specific_file: + if force_all_family_option: + print("%d test files, %d files passed, %d unit tests, " % (test_files, files_ok, tests)) + print("%d total executed, %d error, %d warning" % (run_total, errors,warnings)) + else: + print("%d test files, %d files passed, %d unit tests, " % (test_files, files_ok, tests)) + print("%d error, %d warning" % (errors, warnings)) + +if __name__ == '__main__': + main() |