#!/usr/bin/env python3

"""Unit tests for Ndiff.

All paths used here (ndiff.py, scripts/ndiff, test-scans/*.xml) are relative,
so the tests must be run from the directory that contains them.
"""

import importlib.util
import io
import os
import subprocess
import sys
import unittest

# Prevent loading PyXML
import xml
xml.__path__ = [x for x in xml.__path__ if "_xmlplus" not in x]
import xml.dom.minidom

# Load ndiff.py from the current directory as the module "ndiff" without
# writing its bytecode. This is the importlib equivalent of the former
# imp.load_source call (the imp module was removed in Python 3.12).
dont_write_bytecode = sys.dont_write_bytecode
sys.dont_write_bytecode = True
try:
    _spec = importlib.util.spec_from_file_location("ndiff", "ndiff.py")
    ndiff = importlib.util.module_from_spec(_spec)
    sys.modules["ndiff"] = ndiff
    _spec.loader.exec_module(ndiff)
finally:
    # Restore the caller's setting even if ndiff.py fails to load.
    sys.dont_write_bytecode = dont_write_bytecode
del dont_write_bytecode

# Copy every public (non-underscore) name from ndiff into this module so the
# tests can refer to Scan, Host, Port, ... directly.
for _name in dir(ndiff):
    if not _name.startswith("_"):
        globals()[_name] = getattr(ndiff, _name)


class scan_test(unittest.TestCase):
    """Test the Scan class."""
    def test_empty(self):
        scan = Scan()
        scan.load_from_file("test-scans/empty.xml")
        self.assertEqual(len(scan.hosts), 0)
        self.assertIsNotNone(scan.start_date)
        self.assertIsNotNone(scan.end_date)

    def test_single(self):
        scan = Scan()
        scan.load_from_file("test-scans/single.xml")
        self.assertEqual(len(scan.hosts), 1)

    def test_simple(self):
        """Test that the correct number of known ports is returned when there
        are no extraports."""
        scan = Scan()
        scan.load_from_file("test-scans/simple.xml")
        host = scan.hosts[0]
        self.assertEqual(len(host.ports), 2)

    def test_extraports(self):
        scan = Scan()
        scan.load_from_file("test-scans/single.xml")
        host = scan.hosts[0]
        self.assertEqual(len(host.ports), 5)
        self.assertEqual(list(host.extraports.items()), [("filtered", 95)])

    def test_extraports_multi(self):
        """Test that the correct number of known ports is returned when there
        are extraports in more than one state."""
        scan = Scan()
        scan.load_from_file("test-scans/complex.xml")
        host = scan.hosts[0]
        self.assertEqual(len(host.ports), 6)
        self.assertEqual(set(host.extraports.items()),
                {("filtered", 95), ("open|filtered", 99)})

    def test_nmaprun(self):
        """Test that nmaprun information is recorded."""
        scan = Scan()
        scan.load_from_file("test-scans/empty.xml")
        self.assertEqual(scan.scanner, "nmap")
        self.assertEqual(scan.version, "4.90RC2")
        self.assertEqual(scan.args, "nmap -oX empty.xml -p 1-100")

    def test_addresses(self):
        """Test that addresses are recorded."""
        scan = Scan()
        scan.load_from_file("test-scans/simple.xml")
        host = scan.hosts[0]
        self.assertEqual(host.addresses, [IPv4Address("64.13.134.52")])

    def test_hostname(self):
        """Test that hostnames are recorded."""
        scan = Scan()
        scan.load_from_file("test-scans/simple.xml")
        host = scan.hosts[0]
        self.assertEqual(host.hostnames, ["scanme.nmap.org"])

    def test_os(self):
        """Test that OS information is recorded."""
        scan = Scan()
        scan.load_from_file("test-scans/complex.xml")
        host = scan.hosts[0]
        self.assertGreater(len(host.os), 0)

    def test_script(self):
        """Test that script results are recorded."""
        scan = Scan()
        scan.load_from_file("test-scans/complex.xml")
        host = scan.hosts[0]
        self.assertGreater(len(host.script_results), 0)
        self.assertGreater(len(host.ports[(22, "tcp")].script_results), 0)

    # This test is commented out because Nmap XML doesn't store any
    # information about down hosts, not even the fact that they are down.
    # Recovering the list of scanned hosts to infer which ones are down would
    # involve parsing the targets out of the /nmaprun/@args attribute (which is
    # non-trivial) and possibly looking up their addresses.
    # def test_down_state(self):
    #     """Test that hosts that are not marked "up" are in the "down"
    #     state."""
    #     scan = Scan()
    #     scan.load_from_file("test-scans/down.xml")
    #     self.assertTrue(len(scan.hosts) == 1)
    #     host = scan.hosts[0]
    #     self.assertTrue(host.state == "down")


class host_test(unittest.TestCase):
    """Test the Host class."""
    def test_empty(self):
        h = Host()
        self.assertEqual(len(h.addresses), 0)
        self.assertEqual(len(h.hostnames), 0)
        self.assertEqual(len(h.ports), 0)
        self.assertEqual(len(h.extraports), 0)
        self.assertEqual(len(h.os), 0)

    def test_format_name(self):
        h = Host()
        self.assertIsInstance(h.format_name(), str)
        h.add_address(IPv4Address("127.0.0.1"))
        self.assertIn("127.0.0.1", h.format_name())
        h.add_address(IPv6Address("::1"))
        self.assertIn("127.0.0.1", h.format_name())
        self.assertIn("::1", h.format_name())
        h.add_hostname("localhost")
        self.assertIn("127.0.0.1", h.format_name())
        self.assertIn("::1", h.format_name())
        self.assertIn("localhost", h.format_name())

    def test_empty_get_port(self):
        h = Host()
        for num in (10, 100, 1000, 10000):
            for proto in ("tcp", "udp", "ip"):
                port = h.ports.get((num, proto))
                self.assertIsNone(port)

    def test_add_port(self):
        h = Host()
        spec = (10, "tcp")
        port = h.ports.get(spec)
        self.assertIsNone(port)
        h.add_port(Port(spec, "open"))
        self.assertEqual(len(h.ports), 1)
        port = h.ports[spec]
        self.assertEqual(port.state, "open")
        # Adding a port with an existing spec replaces the old one.
        h.add_port(Port(spec, "closed"))
        self.assertEqual(len(h.ports), 1)
        port = h.ports[spec]
        self.assertEqual(port.state, "closed")

        spec = (22, "tcp")
        port = h.ports.get(spec)
        self.assertIsNone(port)
        port = Port(spec)
        port.state = "open"
        port.service.name = "ssh"
        h.add_port(port)
        self.assertEqual(len(h.ports), 2)
        port = h.ports[spec]
        self.assertEqual(port.state, "open")
        self.assertEqual(port.service.name, "ssh")

    def test_extraports(self):
        h = Host()
        self.assertFalse(h.is_extraports("open"))
        self.assertFalse(h.is_extraports("closed"))
        self.assertFalse(h.is_extraports("filtered"))
        h.extraports["closed"] = 10
        self.assertFalse(h.is_extraports("open"))
        self.assertTrue(h.is_extraports("closed"))
        self.assertFalse(h.is_extraports("filtered"))
        h.extraports["filtered"] = 10
        self.assertFalse(h.is_extraports("open"))
        self.assertTrue(h.is_extraports("closed"))
        self.assertTrue(h.is_extraports("filtered"))
        del h.extraports["closed"]
        del h.extraports["filtered"]
        self.assertFalse(h.is_extraports("open"))
        self.assertFalse(h.is_extraports("closed"))
        self.assertFalse(h.is_extraports("filtered"))

    def test_parse(self):
        s = Scan()
        s.load_from_file("test-scans/single.xml")
        h = s.hosts[0]
        self.assertEqual(len(h.ports), 5)
        self.assertEqual(len(h.extraports), 1)
        self.assertEqual(list(h.extraports.keys())[0], "filtered")
        self.assertEqual(list(h.extraports.values())[0], 95)
        self.assertEqual(h.state, "up")


class address_test(unittest.TestCase):
    """Test the Address class."""
    def test_ipv4_new(self):
        a = Address.new("ipv4", "127.0.0.1")
        self.assertEqual(a.type, "ipv4")

    def test_ipv6_new(self):
        a = Address.new("ipv6", "::1")
        self.assertEqual(a.type, "ipv6")

    def test_mac_new(self):
        a = Address.new("mac", "00:00:00:00:00:00")
        self.assertEqual(a.type, "mac")

    def test_unknown_new(self):
        self.assertRaises(ValueError, Address.new, "aaa", "")

    def test_compare(self):
        """Test that addresses with the same contents compare equal."""
        a = IPv4Address("127.0.0.1")
        self.assertEqual(a, a)
        b = IPv4Address("127.0.0.1")
        self.assertEqual(a, b)
        c = Address.new("ipv4", "127.0.0.1")
        self.assertEqual(a, c)
        self.assertEqual(b, c)

        d = IPv4Address("1.1.1.1")
        self.assertNotEqual(a, d)

        e = IPv6Address("::1")
        self.assertEqual(e, e)
        self.assertNotEqual(a, e)


class port_test(unittest.TestCase):
    """Test the Port class."""
    def test_spec_string(self):
        p = Port((10, "tcp"))
        self.assertEqual(p.spec_string(), "10/tcp")
        p = Port((100, "ip"))
        self.assertEqual(p.spec_string(), "100/ip")

    def test_state_string(self):
        p = Port((10, "tcp"))
        self.assertEqual(p.state_string(), "unknown")


class service_test(unittest.TestCase):
    """Test the Service class."""
    def test_compare(self):
        """Test that services with the same contents compare equal."""
        a = Service()
        a.name = "ftp"
        a.product = "FooBar FTP"
        a.version = "1.1.1"
        a.tunnel = "ssl"
        self.assertEqual(a, a)
        b = Service()
        b.name = "ftp"
        b.product = "FooBar FTP"
        b.version = "1.1.1"
        b.tunnel = "ssl"
        self.assertEqual(a, b)
        b.name = "http"
        self.assertNotEqual(a, b)
        c = Service()
        self.assertNotEqual(a, c)

    def test_tunnel(self):
        serv = Service()
        serv.name = "http"
        serv.tunnel = "ssl"
        self.assertEqual(serv.name_string(), "ssl/http")

    def test_version_string(self):
        serv = Service()
        serv.product = "FooBar"
        self.assertGreater(len(serv.version_string()), 0)
        serv = Service()
        serv.version = "1.2.3"
        self.assertGreater(len(serv.version_string()), 0)
        serv = Service()
        serv.extrainfo = "misconfigured"
        self.assertGreater(len(serv.version_string()), 0)
        serv = Service()
        serv.product = "FooBar"
        serv.version = "1.2.3"
        # Must match Nmap output.
        self.assertEqual(serv.version_string(),
                "%s %s" % (serv.product, serv.version))
        serv.extrainfo = "misconfigured"
        self.assertEqual(serv.version_string(),
                "%s %s (%s)" % (serv.product, serv.version, serv.extrainfo))


class ScanDiffSub(ScanDiff):
    """A subclass of ScanDiff that records diffs instead of printing them.

    The lists below are only filled in when output() is called.
    """
    def __init__(self, scan_a, scan_b, f=sys.stdout):
        ScanDiff.__init__(self, scan_a, scan_b, f)
        self.pre_script_result_diffs = []
        self.post_script_result_diffs = []
        self.host_diffs = []

    def output_beginning(self):
        pass

    def output_pre_scripts(self, pre_script_result_diffs):
        self.pre_script_result_diffs = pre_script_result_diffs

    def output_post_scripts(self, post_script_result_diffs):
        self.post_script_result_diffs = post_script_result_diffs

    def output_host_diff(self, h_diff):
        self.host_diffs.append(h_diff)

    def output_ending(self):
        pass


class scan_diff_test(unittest.TestCase):
    """Test the ScanDiff class."""
    def setUp(self):
        # os.devnull is portable; a literal "/dev/null" is POSIX-only.
        self.blackhole = open(os.devnull, "w")

    def tearDown(self):
        self.blackhole.close()

    def test_self(self):
        scan = Scan()
        scan.load_from_file("test-scans/complex.xml")
        diff = ScanDiffText(scan, scan, self.blackhole)
        cost = diff.output()
        self.assertEqual(cost, 0)
        diff = ScanDiffXML(scan, scan, self.blackhole)
        cost = diff.output()
        self.assertEqual(cost, 0)

    def test_unknown_up(self):
        a = Scan()
        a.load_from_file("test-scans/empty.xml")
        b = Scan()
        b.load_from_file("test-scans/simple.xml")
        diff = ScanDiffSub(a, b, self.blackhole)
        diff.output()
        self.assertEqual(len(diff.pre_script_result_diffs), 0)
        self.assertEqual(len(diff.post_script_result_diffs), 0)
        self.assertEqual(len(diff.host_diffs), 1)
        h_diff = diff.host_diffs[0]
        self.assertIsNone(h_diff.host_a.state)
        self.assertEqual(h_diff.host_b.state, "up")

    def test_up_unknown(self):
        a = Scan()
        a.load_from_file("test-scans/simple.xml")
        b = Scan()
        b.load_from_file("test-scans/empty.xml")
        diff = ScanDiffSub(a, b, self.blackhole)
        diff.output()
        self.assertEqual(len(diff.pre_script_result_diffs), 0)
        self.assertEqual(len(diff.post_script_result_diffs), 0)
        self.assertEqual(len(diff.host_diffs), 1)
        h_diff = diff.host_diffs[0]
        self.assertEqual(h_diff.host_a.state, "up")
        self.assertIsNone(h_diff.host_b.state)

    def test_diff_is_effective(self):
        """Test that a scan diff is effective. This means that if the
        recommended changes are applied to the first scan the scans become the
        same."""
        PAIRS = (
            ("empty", "empty"),
            ("simple", "complex"),
            ("complex", "simple"),
            ("single", "os"),
            ("os", "single"),
            ("random-1", "simple"),
            ("simple", "random-1"),
        )
        for name_a, name_b in PAIRS:
            with self.subTest(a=name_a, b=name_b):
                a = Scan()
                a.load_from_file("test-scans/%s.xml" % name_a)
                b = Scan()
                b.load_from_file("test-scans/%s.xml" % name_b)
                # ScanDiffSub only records host diffs while output() runs;
                # without these calls host_diffs is always empty and the
                # assertion below would pass vacuously.
                diff = ScanDiffSub(a, b, self.blackhole)
                diff.output()
                scan_apply_diff(a, diff)
                diff = ScanDiffSub(a, b, self.blackhole)
                diff.output()
                self.assertEqual(diff.host_diffs, [])


class host_diff_test(unittest.TestCase):
    """Test the HostDiff class."""
    def test_empty(self):
        a = Host()
        b = Host()
        diff = HostDiff(a, b)
        self.assertFalse(diff.id_changed)
        self.assertFalse(diff.state_changed)
        self.assertFalse(diff.os_changed)
        self.assertFalse(diff.extraports_changed)
        self.assertEqual(diff.cost, 0)

    def test_self(self):
        h = Host()
        h.add_port(Port((10, "tcp"), "open"))
        h.add_port(Port((22, "tcp"), "closed"))
        diff = HostDiff(h, h)
        self.assertFalse(diff.id_changed)
        self.assertFalse(diff.state_changed)
        self.assertFalse(diff.os_changed)
        self.assertFalse(diff.extraports_changed)
        self.assertEqual(diff.cost, 0)

    def test_state_change(self):
        a = Host()
        b = Host()
        a.state = "up"
        b.state = "down"
        diff = HostDiff(a, b)
        self.assertTrue(diff.state_changed)
        self.assertGreater(diff.cost, 0)

    def test_state_change_unknown(self):
        a = Host()
        b = Host()
        a.state = "up"
        diff = HostDiff(a, b)
        self.assertTrue(diff.state_changed)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.state_changed)
        self.assertGreater(diff.cost, 0)

    def test_address_change(self):
        a = Host()
        b = Host()
        b.add_address(Address.new("ipv4", "127.0.0.1"))
        diff = HostDiff(a, b)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)
        a.add_address(Address.new("ipv4", "1.1.1.1"))
        diff = HostDiff(a, b)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)

    def test_hostname_change(self):
        a = Host()
        b = Host()
        b.add_hostname("host-1")
        diff = HostDiff(a, b)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)
        # A different hostname (not an address) on the other side.
        a.add_hostname("host-2")
        diff = HostDiff(a, b)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.id_changed)
        self.assertGreater(diff.cost, 0)

    def test_port_state_change(self):
        a = Host()
        b = Host()
        spec = (10, "tcp")
        a.add_port(Port(spec, "open"))
        b.add_port(Port(spec, "closed"))
        diff = HostDiff(a, b)
        self.assertGreater(len(diff.ports), 0)
        self.assertEqual(set(diff.ports), set(diff.port_diffs.keys()))
        self.assertGreater(diff.cost, 0)

    def test_port_state_change_unknown(self):
        a = Host()
        b = Host()
        b.add_port(Port((10, "tcp"), "open"))
        diff = HostDiff(a, b)
        self.assertGreater(len(diff.ports), 0)
        self.assertEqual(set(diff.ports), set(diff.port_diffs.keys()))
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertGreater(len(diff.ports), 0)
        self.assertEqual(set(diff.ports), set(diff.port_diffs.keys()))
        self.assertGreater(diff.cost, 0)

    def test_port_state_change_multi(self):
        a = Host()
        b = Host()
        a.add_port(Port((10, "tcp"), "open"))
        a.add_port(Port((20, "tcp"), "closed"))
        a.add_port(Port((30, "tcp"), "open"))
        b.add_port(Port((10, "tcp"), "open"))
        b.add_port(Port((20, "tcp"), "open"))
        b.add_port(Port((30, "tcp"), "open"))
        diff = HostDiff(a, b)
        self.assertGreater(diff.cost, 0)

    def test_os_change(self):
        a = Host()
        b = Host()
        a.os.append("os-1")
        diff = HostDiff(a, b)
        self.assertTrue(diff.os_changed)
        self.assertGreater(len(diff.os_diffs), 0)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.os_changed)
        self.assertGreater(len(diff.os_diffs), 0)
        self.assertGreater(diff.cost, 0)
        b.os.append("os-2")
        diff = HostDiff(a, b)
        self.assertTrue(diff.os_changed)
        self.assertGreater(len(diff.os_diffs), 0)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.os_changed)
        self.assertGreater(len(diff.os_diffs), 0)
        self.assertGreater(diff.cost, 0)

    def test_extraports_change(self):
        a = Host()
        b = Host()
        a.extraports = {"open": 100}
        diff = HostDiff(a, b)
        self.assertTrue(diff.extraports_changed)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.extraports_changed)
        self.assertGreater(diff.cost, 0)
        b.extraports = {"closed": 100}
        diff = HostDiff(a, b)
        self.assertTrue(diff.extraports_changed)
        self.assertGreater(diff.cost, 0)
        diff = HostDiff(b, a)
        self.assertTrue(diff.extraports_changed)
        self.assertGreater(diff.cost, 0)

    def test_diff_is_effective(self):
        """Test that a host diff is effective. This means that if the
        recommended changes are applied to the first host the hosts become the
        same."""
        a = Host()
        b = Host()

        a.state = "up"
        b.state = "down"

        a.add_port(Port((10, "tcp"), "open"))
        a.add_port(Port((20, "tcp"), "closed"))
        a.add_port(Port((40, "udp"), "open|filtered"))
        b.add_port(Port((10, "tcp"), "open"))
        b.add_port(Port((30, "tcp"), "open"))
        # The same spec on the other host, in a different state. (Adding it
        # to "a" again would just overwrite a's own 40/udp entry.)
        b.add_port(Port((40, "udp"), "open"))

        a.add_hostname("a")
        a.add_hostname("localhost")
        b.add_hostname("b")
        b.add_hostname("localhost")
        b.add_hostname("b.example.com")

        b.add_address(Address.new("ipv4", "1.2.3.4"))

        a.os = ["os-1", "os-2"]
        b.os = ["os-2", "os-3"]

        a.extraports = {"filtered": 99}

        diff = HostDiff(a, b)
        host_apply_diff(a, diff)
        diff = HostDiff(a, b)

        self.assertFalse(diff.id_changed)
        self.assertFalse(diff.state_changed)
        self.assertFalse(diff.os_changed)
        self.assertFalse(diff.extraports_changed)
        self.assertEqual(diff.cost, 0)


class port_diff_test(unittest.TestCase):
    """Test the PortDiff class."""
    def test_equal(self):
        spec = (10, "tcp")
        a = Port(spec)
        b = Port(spec)
        diff = PortDiff(a, b)
        self.assertEqual(diff.cost, 0)

    def test_self(self):
        p = Port((10, "tcp"))
        diff = PortDiff(p, p)
        self.assertEqual(diff.cost, 0)

    def test_state_change(self):
        spec = (10, "tcp")
        a = Port(spec)
        a.state = "open"
        b = Port(spec)
        b.state = "closed"
        diff = PortDiff(a, b)
        self.assertGreater(diff.cost, 0)
        self.assertEqual(PortDiff(a, diff.port_a).cost, 0)
        self.assertEqual(PortDiff(b, diff.port_b).cost, 0)

    def test_id_change(self):
        a = Port((10, "tcp"))
        a.state = "open"
        b = Port((20, "tcp"))
        b.state = "open"
        diff = PortDiff(a, b)
        self.assertGreater(diff.cost, 0)
        self.assertEqual(PortDiff(a, diff.port_a).cost, 0)
        self.assertEqual(PortDiff(b, diff.port_b).cost, 0)


class table_test(unittest.TestCase):
    """Test the table class."""
    def test_empty(self):
        t = Table("")
        self.assertEqual(str(t), "")
        t = Table("***")
        self.assertEqual(str(t), "")
        t = Table("* * *")
        self.assertEqual(str(t), "")

    def test_none(self):
        """Test that None is treated like an empty string when it is not at
        the end of a row."""
        t = Table("* * *")
        t.append((None, "a", "b"))
        self.assertEqual(str(t), " a b")
        t = Table("* * *")
        t.append(("a", None, "b"))
        self.assertEqual(str(t), "a  b")
        t = Table("* * *")
        t.append((None, None, "a"))
        self.assertEqual(str(t), "  a")

    def test_prefix(self):
        """Test that template text before the first "*" is a row prefix, text
        after each "*" follows that column, and values beyond the template's
        fields are appended without separators."""
        t = Table("<<<")
        t.append(("a", "b", "c"))
        self.assertEqual(str(t), "<<<abc")
        t = Table("<<<*>>>*!!!")
        t.append(())
        self.assertEqual(str(t), "<<<")
        t = Table("<<<*>>>*!!!")
        t.append(("a",))
        self.assertEqual(str(t), "<<<a>>>")
        t = Table("<<<*>>>*!!!")
        t.append(("a", "b", "c", "d"))
        self.assertEqual(str(t), "<<<a>>>b!!!cd")

    def test_append_raw(self):
        """Test the append_raw method that inserts an unformatted row."""
        t = Table("<* * *>")
        t.append(("1", "2", "3"))
        t.append_raw(" row ")
        self.assertEqual(str(t), "<1 2 3>\n row ")
        t.append(("4", "5", "6"))
        self.assertEqual(str(t), "<1 2 3>\n row \n<4 5 6>")

    def test_strip(self):
        """Test that trailing whitespace is stripped."""
        t = Table("* * * ")
        t.append(("a", "b", None))
        self.assertEqual(str(t), "a b")
        t = Table("* * *")
        t.append(("a", None, None))
        self.assertEqual(str(t), "a")
        t = Table("* * *")
        t.append(("a", "b", ""))
        self.assertEqual(str(t), "a b")
        t = Table("* * *")
        t.append(("a", "", ""))
        self.assertEqual(str(t), "a")

    def test_newline(self):
        """Test that there is no trailing newline in a table."""
        t = Table("*")
        self.assertFalse(str(t).endswith("\n"))
        t.append(("a",))
        self.assertFalse(str(t).endswith("\n"))
        t.append(("b",))
        self.assertFalse(str(t).endswith("\n"))


class scan_diff_xml_test(unittest.TestCase):
    """Test the XML output of ScanDiffXML."""
    def setUp(self):
        a = Scan()
        a.load_from_file("test-scans/empty.xml")
        b = Scan()
        b.load_from_file("test-scans/simple.xml")
        f = io.StringIO()
        self.scan_diff = ScanDiffXML(a, b, f)
        self.scan_diff.output()
        self.xml = f.getvalue()
        f.close()

    def test_well_formed(self):
        try:
            xml.dom.minidom.parseString(self.xml)
        except Exception as e:
            self.fail("Parsing XML diff output caused the exception: %s"
                    % str(e))


def scan_apply_diff(scan, diff):
    """Apply a scan diff to the given scan.

    diff must already have been output() so that its host_diffs are
    populated. Hosts that only exist in the diff are appended to scan.hosts.
    """
    for h_diff in diff.host_diffs:
        host = h_diff.host_a or h_diff.host_b
        if host not in scan.hosts:
            scan.hosts.append(host)
        host_apply_diff(host, h_diff)


def host_apply_diff(host, diff):
    """Apply a host diff to the given host, mutating it in place so that it
    matches diff.host_b in every field the diff reports as changed."""
    if diff.state_changed:
        host.state = diff.host_b.state

    if diff.id_changed:
        host.addresses = diff.host_b.addresses[:]
        host.hostnames = diff.host_b.hostnames[:]

    if diff.os_changed:
        host.os = diff.host_b.os[:]

    if diff.extraports_changed:
        # Drop listed ports whose state is one of the old extraports states
        # before replacing the extraports summary itself.
        for state in list(host.extraports.keys()):
            for port in list(host.ports.values()):
                if port.state == state:
                    del host.ports[port.spec]
        host.extraports = diff.host_b.extraports.copy()

    for port in diff.ports:
        port_b = diff.port_diffs[port].port_b
        if port_b.state is None:
            # The port is absent from host_b, so remove it.
            del host.ports[port.spec]
        else:
            host.ports[port.spec] = port_b

    for sr_diff in diff.script_result_diffs:
        sr_a = sr_diff.sr_a
        sr_b = sr_diff.sr_b
        if sr_a is None:
            host.script_results.append(sr_b)
        elif sr_b is None:
            host.script_results.remove(sr_a)
        else:
            host.script_results[host.script_results.index(sr_a)] = sr_b
    host.script_results.sort()


def call_quiet(args, **kwargs):
    """Run a command with subprocess.call and hide its output.

    Output is discarded to os.devnull rather than sent to a PIPE:
    subprocess.call never reads a pipe, so a child that writes enough output
    would block forever. The child inherits the current environment (so PATH
    is still available to "#!/usr/bin/env") with PYTHONPATH set to ".",
    presumably so that scripts/ndiff imports the ndiff.py under test.

    Returns the child's exit code.
    """
    env = dict(os.environ)
    env["PYTHONPATH"] = "."
    return subprocess.call(args, stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL, env=env, **kwargs)


class exit_code_test(unittest.TestCase):
    """Test the exit codes of the ndiff command-line script."""
    NDIFF = "./scripts/ndiff"

    def test_exit_equal(self):
        """Test that the exit code is 0 when the diff is empty."""
        for format in ("--text", "--xml"):
            code = call_quiet([self.NDIFF, format,
                "test-scans/simple.xml", "test-scans/simple.xml"])
            self.assertEqual(code, 0)
        # Should be independent of verbosity.
        for format in ("--text", "--xml"):
            code = call_quiet([self.NDIFF, "-v", format,
                "test-scans/simple.xml", "test-scans/simple.xml"])
            self.assertEqual(code, 0)

    def test_exit_different(self):
        """Test that the exit code is 1 when the diff is not empty."""
        for format in ("--text", "--xml"):
            code = call_quiet([self.NDIFF, format,
                "test-scans/simple.xml", "test-scans/complex.xml"])
            self.assertEqual(code, 1)

    def test_exit_error(self):
        """Test that the exit code is 2 when there is an error."""
        code = call_quiet([self.NDIFF])
        self.assertEqual(code, 2)
        code = call_quiet([self.NDIFF, "test-scans/simple.xml"])
        self.assertEqual(code, 2)
        code = call_quiet([self.NDIFF, "test-scans/simple.xml",
            "test-scans/nonexistent.xml"])
        self.assertEqual(code, 2)
        code = call_quiet([self.NDIFF, "--nothing"])
        self.assertEqual(code, 2)


if __name__ == "__main__":
    unittest.main()