diff options
Diffstat (limited to '')
-rwxr-xr-x | ndiff/ndifftest.py | 807 |
1 files changed, 807 insertions, 0 deletions
diff --git a/ndiff/ndifftest.py b/ndiff/ndifftest.py new file mode 100755 index 0000000..27fc525 --- /dev/null +++ b/ndiff/ndifftest.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python3 + +# Unit tests for Ndiff. + +import subprocess +import sys +import unittest + +# Prevent loading PyXML +import xml +xml.__path__ = [x for x in xml.__path__ if "_xmlplus" not in x] + +import xml.dom.minidom + +import imp +dont_write_bytecode = sys.dont_write_bytecode +sys.dont_write_bytecode = True +ndiff = imp.load_source("ndiff", "ndiff.py") +for x in dir(ndiff): + if not x.startswith("_"): + globals()[x] = getattr(ndiff, x) +sys.dont_write_bytecode = dont_write_bytecode +del dont_write_bytecode + +import io + + +class scan_test(unittest.TestCase): + """Test the Scan class.""" + def test_empty(self): + scan = Scan() + scan.load_from_file("test-scans/empty.xml") + self.assertEqual(len(scan.hosts), 0) + self.assertNotEqual(scan.start_date, None) + self.assertNotEqual(scan.end_date, None) + + def test_single(self): + scan = Scan() + scan.load_from_file("test-scans/single.xml") + self.assertEqual(len(scan.hosts), 1) + + def test_simple(self): + """Test that the correct number of known ports is returned when there + are no extraports.""" + scan = Scan() + scan.load_from_file("test-scans/simple.xml") + host = scan.hosts[0] + self.assertEqual(len(host.ports), 2) + + def test_extraports(self): + scan = Scan() + scan.load_from_file("test-scans/single.xml") + host = scan.hosts[0] + self.assertEqual(len(host.ports), 5) + self.assertEqual(list(host.extraports.items()), [("filtered", 95)]) + + def test_extraports_multi(self): + """Test that the correct number of known ports is returned when there + are extraports in more than one state.""" + scan = Scan() + scan.load_from_file("test-scans/complex.xml") + host = scan.hosts[0] + self.assertEqual(len(host.ports), 6) + self.assertEqual(set(host.extraports.items()), + set([("filtered", 95), ("open|filtered", 99)])) + + def test_nmaprun(self): + """Test that nmaprun information is recorded.""" + scan = Scan() + scan.load_from_file("test-scans/empty.xml") + self.assertEqual(scan.scanner, "nmap") + self.assertEqual(scan.version, "4.90RC2") + self.assertEqual(scan.args, "nmap -oX empty.xml -p 1-100") + + def test_addresses(self): + """Test that addresses are recorded.""" + scan = Scan() + scan.load_from_file("test-scans/simple.xml") + host = scan.hosts[0] + self.assertEqual(host.addresses, [IPv4Address("64.13.134.52")]) + + def test_hostname(self): + """Test that hostnames are recorded.""" + scan = Scan() + scan.load_from_file("test-scans/simple.xml") + host = scan.hosts[0] + self.assertEqual(host.hostnames, ["scanme.nmap.org"]) + + def test_os(self): + """Test that OS information is recorded.""" + scan = Scan() + scan.load_from_file("test-scans/complex.xml") + host = scan.hosts[0] + self.assertTrue(len(host.os) > 0) + + def test_script(self): + """Test that script results are recorded.""" + scan = Scan() + scan.load_from_file("test-scans/complex.xml") + host = scan.hosts[0] + self.assertTrue(len(host.script_results) > 0) + self.assertTrue(len(host.ports[(22, "tcp")].script_results) > 0) + +# This test is commented out because Nmap XML doesn't store any information +# about down hosts, not even the fact that they are down. Recovering the list +# of scanned hosts to infer which ones are down would involve parsing the +# targets out of the /nmaprun/@args attribute (which is non-trivial) and +# possibly looking up their addresses. +# def test_down_state(self): +# """Test that hosts that are not marked "up" are in the "down" +# state.""" +# scan = Scan() +# scan.load_from_file("test-scans/down.xml") +# self.assertTrue(len(scan.hosts) == 1) +# host = scan.hosts[0] +# self.assertTrue(host.state == "down") + + +class host_test(unittest.TestCase): + """Test the Host class.""" + def test_empty(self): + h = Host() + self.assertEqual(len(h.addresses), 0) + self.assertEqual(len(h.hostnames), 0) + self.assertEqual(len(h.ports), 0) + self.assertEqual(len(h.extraports), 0) + self.assertEqual(len(h.os), 0) + + def test_format_name(self): + h = Host() + self.assertTrue(isinstance(h.format_name(), str)) + h.add_address(IPv4Address("127.0.0.1")) + self.assertTrue("127.0.0.1" in h.format_name()) + h.add_address(IPv6Address("::1")) + self.assertTrue("127.0.0.1" in h.format_name()) + self.assertTrue("::1" in h.format_name()) + h.add_hostname("localhost") + self.assertTrue("127.0.0.1" in h.format_name()) + self.assertTrue("::1" in h.format_name()) + self.assertTrue("localhost" in h.format_name()) + + def test_empty_get_port(self): + h = Host() + for num in 10, 100, 1000, 10000: + for proto in ("tcp", "udp", "ip"): + port = h.ports.get((num, proto)) + self.assertEqual(port, None) + + def test_add_port(self): + h = Host() + spec = (10, "tcp") + port = h.ports.get(spec) + self.assertEqual(port, None) + h.add_port(Port(spec, "open")) + self.assertEqual(len(h.ports), 1) + port = h.ports[spec] + self.assertEqual(port.state, "open") + h.add_port(Port(spec, "closed")) + self.assertEqual(len(h.ports), 1) + port = h.ports[spec] + self.assertEqual(port.state, "closed") + + spec = (22, "tcp") + port = h.ports.get(spec) + self.assertEqual(port, None) + port = Port(spec) + port.state = "open" + port.service.name = "ssh" + h.add_port(port) + self.assertEqual(len(h.ports), 2) + port = h.ports[spec] + self.assertEqual(port.state, "open") + self.assertEqual(port.service.name, "ssh") + + def test_extraports(self): + h = Host() + self.assertFalse(h.is_extraports("open")) + self.assertFalse(h.is_extraports("closed")) + self.assertFalse(h.is_extraports("filtered")) + h.extraports["closed"] = 10 + self.assertFalse(h.is_extraports("open")) + self.assertTrue(h.is_extraports("closed")) + self.assertFalse(h.is_extraports("filtered")) + h.extraports["filtered"] = 10 + self.assertFalse(h.is_extraports("open")) + self.assertTrue(h.is_extraports("closed")) + self.assertTrue(h.is_extraports("filtered")) + del h.extraports["closed"] + del h.extraports["filtered"] + self.assertFalse(h.is_extraports("open")) + self.assertFalse(h.is_extraports("closed")) + self.assertFalse(h.is_extraports("filtered")) + + def test_parse(self): + s = Scan() + s.load_from_file("test-scans/single.xml") + h = s.hosts[0] + self.assertEqual(len(h.ports), 5) + self.assertEqual(len(h.extraports), 1) + self.assertEqual(list(h.extraports.keys())[0], "filtered") + self.assertEqual(list(h.extraports.values())[0], 95) + self.assertEqual(h.state, "up") + + +class address_test(unittest.TestCase): + """Test the Address class.""" + def test_ipv4_new(self): + a = Address.new("ipv4", "127.0.0.1") + self.assertEqual(a.type, "ipv4") + + def test_ipv6_new(self): + a = Address.new("ipv6", "::1") + self.assertEqual(a.type, "ipv6") + + def test_mac_new(self): + a = Address.new("mac", "00:00:00:00:00:00") + self.assertEqual(a.type, "mac") + + def test_unknown_new(self): + self.assertRaises(ValueError, Address.new, "aaa", "") + + def test_compare(self): + """Test that addresses with the same contents compare equal.""" + a = IPv4Address("127.0.0.1") + self.assertEqual(a, a) + b = IPv4Address("127.0.0.1") + self.assertEqual(a, b) + c = Address.new("ipv4", "127.0.0.1") + self.assertEqual(a, c) + self.assertEqual(b, c) + + d = IPv4Address("1.1.1.1") + self.assertNotEqual(a, d) + + e = IPv6Address("::1") + self.assertEqual(e, e) + self.assertNotEqual(a, e) + + +class port_test(unittest.TestCase): + """Test the Port class.""" + def test_spec_string(self): + p = Port((10, "tcp")) + self.assertEqual(p.spec_string(), "10/tcp") + p = Port((100, "ip")) + self.assertEqual(p.spec_string(), "100/ip") + + def test_state_string(self): + p = Port((10, "tcp")) + self.assertEqual(p.state_string(), "unknown") + + +class service_test(unittest.TestCase): + """Test the Service class.""" + def test_compare(self): + """Test that services with the same contents compare equal.""" + a = Service() + a.name = "ftp" + a.product = "FooBar FTP" + a.version = "1.1.1" + a.tunnel = "ssl" + self.assertEqual(a, a) + b = Service() + b.name = "ftp" + b.product = "FooBar FTP" + b.version = "1.1.1" + b.tunnel = "ssl" + self.assertEqual(a, b) + b.name = "http" + self.assertNotEqual(a, b) + c = Service() + self.assertNotEqual(a, c) + + def test_tunnel(self): + serv = Service() + serv.name = "http" + serv.tunnel = "ssl" + self.assertEqual(serv.name_string(), "ssl/http") + + def test_version_string(self): + serv = Service() + serv.product = "FooBar" + self.assertTrue(len(serv.version_string()) > 0) + serv = Service() + serv.version = "1.2.3" + self.assertTrue(len(serv.version_string()) > 0) + serv = Service() + serv.extrainfo = "misconfigured" + self.assertTrue(len(serv.version_string()) > 0) + serv = Service() + serv.product = "FooBar" + serv.version = "1.2.3" + # Must match Nmap output. + self.assertEqual(serv.version_string(), + "%s %s" % (serv.product, serv.version)) + serv.extrainfo = "misconfigured" + self.assertEqual(serv.version_string(), + "%s %s (%s)" % (serv.product, serv.version, serv.extrainfo)) + + +class ScanDiffSub(ScanDiff): + """A subclass of ScanDiff that counts diffs for testing.""" + def __init__(self, scan_a, scan_b, f=sys.stdout): + ScanDiff.__init__(self, scan_a, scan_b, f) + self.pre_script_result_diffs = [] + self.post_script_result_diffs = [] + self.host_diffs = [] + + def output_beginning(self): + pass + + def output_pre_scripts(self, pre_script_result_diffs): + self.pre_script_result_diffs = pre_script_result_diffs + + def output_post_scripts(self, post_script_result_diffs): + self.post_script_result_diffs = post_script_result_diffs + + def output_host_diff(self, h_diff): + self.host_diffs.append(h_diff) + + def output_ending(self): + pass + + +class scan_diff_test(unittest.TestCase): + """Test the ScanDiff class.""" + def setUp(self): + self.blackhole = open("/dev/null", "w") + + def tearDown(self): + self.blackhole.close() + + def test_self(self): + scan = Scan() + scan.load_from_file("test-scans/complex.xml") + diff = ScanDiffText(scan, scan, self.blackhole) + cost = diff.output() + self.assertEqual(cost, 0) + diff = ScanDiffXML(scan, scan, self.blackhole) + cost = diff.output() + self.assertEqual(cost, 0) + + def test_unknown_up(self): + a = Scan() + a.load_from_file("test-scans/empty.xml") + b = Scan() + b.load_from_file("test-scans/simple.xml") + diff = ScanDiffSub(a, b, self.blackhole) + diff.output() + self.assertEqual(len(diff.pre_script_result_diffs), 0) + self.assertEqual(len(diff.post_script_result_diffs), 0) + self.assertEqual(len(diff.host_diffs), 1) + h_diff = diff.host_diffs[0] + self.assertEqual(h_diff.host_a.state, None) + self.assertEqual(h_diff.host_b.state, "up") + + def test_up_unknown(self): + a = Scan() + a.load_from_file("test-scans/simple.xml") + b = Scan() + b.load_from_file("test-scans/empty.xml") + diff = ScanDiffSub(a, b, self.blackhole) + diff.output() + self.assertEqual(len(diff.pre_script_result_diffs), 0) + self.assertEqual(len(diff.post_script_result_diffs), 0) + self.assertEqual(len(diff.host_diffs), 1) + h_diff = diff.host_diffs[0] + self.assertEqual(h_diff.host_a.state, "up") + self.assertEqual(h_diff.host_b.state, None) + + def test_diff_is_effective(self): + """Test that a scan diff is effective. This means that if the + recommended changes are applied to the first scan the scans become the + same.""" + PAIRS = ( + ("empty", "empty"), + ("simple", "complex"), + ("complex", "simple"), + ("single", "os"), + ("os", "single"), + ("random-1", "simple"), + ("simple", "random-1"), + ) + for pair in PAIRS: + a = Scan() + a.load_from_file("test-scans/%s.xml" % pair[0]) + b = Scan() + b.load_from_file("test-scans/%s.xml" % pair[1]) + diff = ScanDiffSub(a, b) + scan_apply_diff(a, diff) + diff = ScanDiffSub(a, b) + self.assertEqual(diff.host_diffs, []) + + +class host_diff_test(unittest.TestCase): + """Test the HostDiff class.""" + def test_empty(self): + a = Host() + b = Host() + diff = HostDiff(a, b) + self.assertFalse(diff.id_changed) + self.assertFalse(diff.state_changed) + self.assertFalse(diff.os_changed) + self.assertFalse(diff.extraports_changed) + self.assertEqual(diff.cost, 0) + + def test_self(self): + h = Host() + h.add_port(Port((10, "tcp"), "open")) + h.add_port(Port((22, "tcp"), "closed")) + diff = HostDiff(h, h) + self.assertFalse(diff.id_changed) + self.assertFalse(diff.state_changed) + self.assertFalse(diff.os_changed) + self.assertFalse(diff.extraports_changed) + self.assertEqual(diff.cost, 0) + + def test_state_change(self): + a = Host() + b = Host() + a.state = "up" + b.state = "down" + diff = HostDiff(a, b) + self.assertTrue(diff.state_changed) + self.assertTrue(diff.cost > 0) + + def test_state_change_unknown(self): + a = Host() + b = Host() + a.state = "up" + diff = HostDiff(a, b) + self.assertTrue(diff.state_changed) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.state_changed) + self.assertTrue(diff.cost > 0) + + def test_address_change(self): + a = Host() + b = Host() + b.add_address(Address.new("ipv4", "127.0.0.1")) + diff = HostDiff(a, b) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + a.add_address(Address.new("ipv4", "1.1.1.1")) + diff = HostDiff(a, b) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + + def test_hostname_change(self): + a = Host() + b = Host() + b.add_hostname("host-1") + diff = HostDiff(a, b) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + a.add_address("host-2") + diff = HostDiff(a, b) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.id_changed) + self.assertTrue(diff.cost > 0) + + def test_port_state_change(self): + a = Host() + b = Host() + spec = (10, "tcp") + a.add_port(Port(spec, "open")) + b.add_port(Port(spec, "closed")) + diff = HostDiff(a, b) + self.assertTrue(len(diff.ports) > 0) + self.assertEqual(set(diff.ports), set(diff.port_diffs.keys())) + self.assertTrue(diff.cost > 0) + + def test_port_state_change_unknown(self): + a = Host() + b = Host() + b.add_port(Port((10, "tcp"), "open")) + diff = HostDiff(a, b) + self.assertTrue(len(diff.ports) > 0) + self.assertEqual(set(diff.ports), set(diff.port_diffs.keys())) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(len(diff.ports) > 0) + self.assertEqual(set(diff.ports), set(diff.port_diffs.keys())) + self.assertTrue(diff.cost > 0) + + def test_port_state_change_multi(self): + a = Host() + b = Host() + a.add_port(Port((10, "tcp"), "open")) + a.add_port(Port((20, "tcp"), "closed")) + a.add_port(Port((30, "tcp"), "open")) + b.add_port(Port((10, "tcp"), "open")) + b.add_port(Port((20, "tcp"), "open")) + b.add_port(Port((30, "tcp"), "open")) + diff = HostDiff(a, b) + self.assertTrue(diff.cost > 0) + + def test_os_change(self): + a = Host() + b = Host() + a.os.append("os-1") + diff = HostDiff(a, b) + self.assertTrue(diff.os_changed) + self.assertTrue(len(diff.os_diffs) > 0) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.os_changed) + self.assertTrue(len(diff.os_diffs) > 0) + self.assertTrue(diff.cost > 0) + b.os.append("os-2") + diff = HostDiff(a, b) + self.assertTrue(diff.os_changed) + self.assertTrue(len(diff.os_diffs) > 0) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.os_changed) + self.assertTrue(len(diff.os_diffs) > 0) + self.assertTrue(diff.cost > 0) + + def test_extraports_change(self): + a = Host() + b = Host() + a.extraports = {"open": 100} + diff = HostDiff(a, b) + self.assertTrue(diff.extraports_changed) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.extraports_changed) + self.assertTrue(diff.cost > 0) + b.extraports = {"closed": 100} + diff = HostDiff(a, b) + self.assertTrue(diff.extraports_changed) + self.assertTrue(diff.cost > 0) + diff = HostDiff(b, a) + self.assertTrue(diff.extraports_changed) + self.assertTrue(diff.cost > 0) + + def test_diff_is_effective(self): + """Test that a host diff is effective. + This means that if the recommended changes are applied to the first + host the hosts become the same.""" + a = Host() + b = Host() + + a.state = "up" + b.state = "down" + + a.add_port(Port((10, "tcp"), "open")) + a.add_port(Port((20, "tcp"), "closed")) + a.add_port(Port((40, "udp"), "open|filtered")) + b.add_port(Port((10, "tcp"), "open")) + b.add_port(Port((30, "tcp"), "open")) + a.add_port(Port((40, "udp"), "open")) + + a.add_hostname("a") + a.add_hostname("localhost") + b.add_hostname("b") + b.add_hostname("localhost") + b.add_hostname("b.example.com") + + b.add_address(Address.new("ipv4", "1.2.3.4")) + + a.os = ["os-1", "os-2"] + b.os = ["os-2", "os-3"] + + a.extraports = {"filtered": 99} + + diff = HostDiff(a, b) + host_apply_diff(a, diff) + diff = HostDiff(a, b) + + self.assertFalse(diff.id_changed) + self.assertFalse(diff.state_changed) + self.assertFalse(diff.os_changed) + self.assertFalse(diff.extraports_changed) + self.assertEqual(diff.cost, 0) + + +class port_diff_test(unittest.TestCase): + """Test the PortDiff class.""" + def test_equal(self): + spec = (10, "tcp") + a = Port(spec) + b = Port(spec) + diff = PortDiff(a, b) + self.assertEqual(diff.cost, 0) + + def test_self(self): + p = Port((10, "tcp")) + diff = PortDiff(p, p) + self.assertEqual(diff.cost, 0) + + def test_state_change(self): + spec = (10, "tcp") + a = Port(spec) + a.state = "open" + b = Port(spec) + b.state = "closed" + diff = PortDiff(a, b) + self.assertTrue(diff.cost > 0) + self.assertEqual(PortDiff(a, diff.port_a).cost, 0) + self.assertEqual(PortDiff(b, diff.port_b).cost, 0) + + def test_id_change(self): + a = Port((10, "tcp")) + a.state = "open" + b = Port((20, "tcp")) + b.state = "open" + diff = PortDiff(a, b) + self.assertTrue(diff.cost > 0) + self.assertEqual(PortDiff(a, diff.port_a).cost, 0) + self.assertEqual(PortDiff(b, diff.port_b).cost, 0) + + +class table_test(unittest.TestCase): + """Test the table class.""" + def test_empty(self): + t = Table("") + self.assertEqual(str(t), "") + t = Table("***") + self.assertEqual(str(t), "") + t = Table("* * *") + self.assertEqual(str(t), "") + + def test_none(self): + """Test that None is treated like an empty string when it is not at the + end of a row.""" + t = Table("* * *") + t.append((None, "a", "b")) + self.assertEqual(str(t), " a b") + t = Table("* * *") + t.append(("a", None, "b")) + self.assertEqual(str(t), "a b") + t = Table("* * *") + t.append((None, None, "a")) + self.assertEqual(str(t), " a") + + def test_prefix(self): + t = Table("<<<") + t.append(("a", "b", "c")) + self.assertEqual(str(t), "<<<abc") + + def test_padding(self): + t = Table("<<<*>>>*!!!") + t.append(()) + self.assertEqual(str(t), "<<<") + t = Table("<<<*>>>*!!!") + t.append(("a")) + self.assertEqual(str(t), "<<<a>>>") + t = Table("<<<*>>>*!!!") + t.append(("a", "b", "c", "d")) + self.assertEqual(str(t), "<<<a>>>b!!!cd") + + def test_append_raw(self): + """Test the append_raw method that inserts an unformatted row.""" + t = Table("<* * *>") + t.append(("1", "2", "3")) + t.append_raw(" row ") + self.assertEqual(str(t), "<1 2 3>\n row ") + t.append(("4", "5", "6")) + self.assertEqual(str(t), "<1 2 3>\n row \n<4 5 6>") + + def test_strip(self): + """Test that trailing whitespace is stripped.""" + t = Table("* * * ") + t.append(("a", "b", None)) + self.assertEqual(str(t), "a b") + t = Table("* * *") + t.append(("a", None, None)) + self.assertEqual(str(t), "a") + t = Table("* * *") + t.append(("a", "b", "")) + self.assertEqual(str(t), "a b") + t = Table("* * *") + t.append(("a", "", "")) + self.assertEqual(str(t), "a") + + def test_newline(self): + """Test that there is no trailing newline in a table.""" + t = Table("*") + self.assertFalse(str(t).endswith("\n")) + t.append(("a")) + self.assertFalse(str(t).endswith("\n")) + t.append(("b")) + self.assertFalse(str(t).endswith("\n")) + + +class scan_diff_xml_test(unittest.TestCase): + def setUp(self): + a = Scan() + a.load_from_file("test-scans/empty.xml") + b = Scan() + b.load_from_file("test-scans/simple.xml") + f = io.StringIO() + self.scan_diff = ScanDiffXML(a, b, f) + self.scan_diff.output() + self.xml = f.getvalue() + f.close() + + def test_well_formed(self): + try: + document = xml.dom.minidom.parseString(self.xml) + except Exception as e: + self.fail("Parsing XML diff output caused the exception: %s" + % str(e)) + + +def scan_apply_diff(scan, diff): + """Apply a scan diff to the given scan.""" + for h_diff in diff.host_diffs: + host = h_diff.host_a or h_diff.host_b + if host not in scan.hosts: + scan.hosts.append(host) + host_apply_diff(host, h_diff) + + +def host_apply_diff(host, diff): + """Apply a host diff to the given host.""" + if diff.state_changed: + host.state = diff.host_b.state + + if diff.id_changed: + host.addresses = diff.host_b.addresses[:] + host.hostnames = diff.host_b.hostnames[:] + + if diff.os_changed: + host.os = diff.host_b.os[:] + + if diff.extraports_changed: + for state in list(host.extraports.keys()): + for port in list(host.ports.values()): + if port.state == state: + del host.ports[port.spec] + host.extraports = diff.host_b.extraports.copy() + + for port in diff.ports: + port_b = diff.port_diffs[port].port_b + if port_b.state is None: + del host.ports[port.spec] + else: + host.ports[port.spec] = diff.port_diffs[port].port_b + + for sr_diff in diff.script_result_diffs: + sr_a = sr_diff.sr_a + sr_b = sr_diff.sr_b + if sr_a is None: + host.script_results.append(sr_b) + elif sr_b is None: + host.script_results.remove(sr_a) + else: + host.script_results[host.script_results.index(sr_a)] = sr_b + host.script_results.sort() + + +def call_quiet(args, **kwargs): + """Run a command with subprocess.call and hide its output.""" + return subprocess.call(args, stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, env={'PYTHONPATH': "."}, **kwargs) + + +class exit_code_test(unittest.TestCase): + NDIFF = "./scripts/ndiff" + + def test_exit_equal(self): + """Test that the exit code is 0 when the diff is empty.""" + for format in ("--text", "--xml"): + code = call_quiet([self.NDIFF, format, + "test-scans/simple.xml", "test-scans/simple.xml"]) + self.assertEqual(code, 0) + # Should be independent of verbosity. + for format in ("--text", "--xml"): + code = call_quiet([self.NDIFF, "-v", format, + "test-scans/simple.xml", "test-scans/simple.xml"]) + self.assertEqual(code, 0) + + def test_exit_different(self): + """Test that the exit code is 1 when the diff is not empty.""" + for format in ("--text", "--xml"): + code = call_quiet([self.NDIFF, format, + "test-scans/simple.xml", "test-scans/complex.xml"]) + self.assertEqual(code, 1) + + def test_exit_error(self): + """Test that the exit code is 2 when there is an error.""" + code = call_quiet([self.NDIFF]) + self.assertEqual(code, 2) + code = call_quiet([self.NDIFF, "test-scans/simple.xml"]) + self.assertEqual(code, 2) + code = call_quiet([self.NDIFF, "test-scans/simple.xml", + "test-scans/nonexistent.xml"]) + self.assertEqual(code, 2) + code = call_quiet([self.NDIFF, "--nothing"]) + self.assertEqual(code, 2) + +unittest.main() |