diff options
Diffstat (limited to '')
-rw-r--r-- | tests/deckard/pydnstest/scenario.py | 1058 |
1 files changed, 1058 insertions, 0 deletions
diff --git a/tests/deckard/pydnstest/scenario.py b/tests/deckard/pydnstest/scenario.py new file mode 100644 index 0000000..5e0661b --- /dev/null +++ b/tests/deckard/pydnstest/scenario.py @@ -0,0 +1,1058 @@ +# FIXME pylint: disable=too-many-lines +from abc import ABC +import binascii +import calendar +from datetime import datetime +import errno +import logging +import os +import posixpath +import random +import socket +import string +import struct +import time +from typing import Optional + +import dns.dnssec +import dns.message +import dns.name +import dns.rcode +import dns.rrset +import dns.tsigkeyring + +import pydnstest.augwrap +import pydnstest.matchpart + + +def str2bool(v): + """ Return conversion of JSON-ish string value to boolean. """ + return v.lower() in ('yes', 'true', 'on', '1') + + +# Global statistics +g_rtt = 0.0 +g_nqueries = 0 + + +def recvfrom_msg(stream, raw=False): + """ + Receive DNS message from TCP/UDP socket. + + Returns: + if raw == False: (DNS message object, peer address) + if raw == True: (blob, peer address) + """ + if stream.type & socket.SOCK_DGRAM: + data, addr = stream.recvfrom(4096) + elif stream.type & socket.SOCK_STREAM: + data = stream.recv(2) + if not data: + return None, None + msg_len = struct.unpack_from("!H", data)[0] + data = b"" + received = 0 + while received < msg_len: + next_chunk = stream.recv(4096) + if not next_chunk: + return None, None + data += next_chunk + received += len(next_chunk) + addr = stream.getpeername()[0] + else: + raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type) + if raw: + return data, addr + else: + msg = dns.message.from_wire(data, one_rr_per_rrset=True) + return msg, addr + + +def sendto_msg(stream, message, addr=None): + """ Send DNS/UDP/TCP message. """ + try: + if stream.type & socket.SOCK_DGRAM: + if addr is None: + stream.send(message) + else: + stream.sendto(message, addr) + elif stream.type & socket.SOCK_STREAM: + data = struct.pack("!H", len(message)) + message + stream.send(data) + else: + raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type) + except socket.error as ex: + if ex.errno != errno.ECONNREFUSED: # TODO Investigate how this can happen + raise + + +def replay_rrs(rrs, nqueries, destination, args=None): + """ Replay list of queries and report statistics. """ + if args is None: + args = [] + navail, queries = len(rrs), [] + chunksize = 16 + for i in range(nqueries if 'RAND' in args else navail): + rr = rrs[i % navail] + name = rr.name + if 'RAND' in args: + prefix = ''.join([random.choice(string.ascii_letters + string.digits) + for _ in range(8)]) + name = prefix + '.' + rr.name.to_text() + msg = dns.message.make_query(name, rr.rdtype, rr.rdclass) + if 'DO' in args: + msg.want_dnssec(True) + queries.append(msg.to_wire()) + # Make a UDP connected socket to the destination + family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET + sock = socket.socket(family, socket.SOCK_DGRAM) + sock.connect(destination) + sock.setblocking(False) + # Play the query set + # @NOTE: this is only good for relative low-speed replay + rcvbuf = bytearray('\x00' * 512) + nsent, nrcvd, nwait, navail = 0, 0, 0, len(queries) + fdset = [sock] + import select + while nsent - nwait < nqueries: + to_read, to_write, _ = select.select(fdset, fdset if nwait < chunksize else [], [], 0.5) + if to_write: + try: + while nsent < nqueries and nwait < chunksize: + sock.send(queries[nsent % navail]) + nwait += 1 + nsent += 1 + except socket.error: + pass # EINVAL + if to_read: + try: + while nwait > 0: + sock.recv_into(rcvbuf) + nwait -= 1 + nrcvd += 1 + except socket.error: + pass + if not to_write and not to_read: + nwait = 0 # Timeout, started dropping packets + break + return nsent, nrcvd + + +class DNSBlob(ABC): + def to_wire(self) -> bytes: + raise NotImplementedError + + def __str__(self) -> str: + return '<DNSBlob>' + + +class DNSMessage(DNSBlob): + def __init__(self, message: dns.message.Message) -> None: + assert message is not None + self.message = message + + def to_wire(self) -> bytes: + return self.message.to_wire(max_size=65535) + + def __str__(self) -> str: + return str(self.message) + + +class DNSReply(DNSMessage): + def __init__( + self, + message: dns.message.Message, + query: Optional[dns.message.Message] = None, + copy_id: bool = False, + copy_query: bool = False + ) -> None: + super().__init__(message) + if copy_id or copy_query: + if query is None: + raise ValueError("query must be provided to adjust copy_id/copy_query") + self.adjust_reply(query, copy_id, copy_query) + + def adjust_reply( + self, + query: dns.message.Message, + copy_id: bool = True, + copy_query: bool = True + ) -> None: + answer = dns.message.from_wire(self.message.to_wire(), + xfr=self.message.xfr, + one_rr_per_rrset=True) + answer.use_edns(query.edns, query.ednsflags, options=self.message.options) + if copy_id: + answer.id = query.id + # Copy letter-case if the template has QD + if answer.question: + answer.question[0].name = query.question[0].name + if copy_query: + answer.question = query.question + # Re-set, as the EDNS might have reset the ext-rcode + answer.set_rcode(self.message.rcode()) + + # sanity check: adjusted answer should be almost the same + assert len(answer.answer) == len(self.message.answer) + assert len(answer.authority) == len(self.message.authority) + assert len(answer.additional) == len(self.message.additional) + self.message = answer + + +class DNSReplyRaw(DNSBlob): + def __init__( + self, + wire: bytes, + query: Optional[dns.message.Message] = None, + copy_id: bool = False + ) -> None: + assert wire is not None + self.wire = wire + if copy_id: + self.adjust_reply(query, copy_id) + + def adjust_reply( + self, + query: dns.message.Message, + copy_id: bool = True + ) -> None: + if copy_id: + if len(self.wire) < 2: + raise ValueError( + 'wire data must contain at least 2 bytes to adjust query id') + raw_answer = bytearray(self.wire) + struct.pack_into('!H', raw_answer, 0, query.id) + self.wire = bytes(raw_answer) + + def to_wire(self) -> bytes: + return self.wire + + def __str__(self) -> str: + return '<DNSReplyRaw>' + + +class DNSReplyServfail(DNSMessage): + def __init__(self, query: dns.message.Message) -> None: + message = dns.message.make_response(query) + message.set_rcode(dns.rcode.SERVFAIL) + super().__init__(message) + + +class Entry: + """ + Data entry represents scripted message and extra metadata, + notably match criteria and reply adjustments. + """ + + # Globals + default_ttl = 3600 + default_cls = 'IN' + default_rc = 'NOERROR' + + def __init__(self, node): + """ Initialize data entry. """ + self.node = node + self.origin = '.' + self.message = dns.message.Message() + self.message.use_edns(edns=0, payload=4096) + self.fired = 0 + + # RAW + self.raw_data = None # type: Optional[bytes] + self.is_raw_data_entry = self.process_raw() + + # MATCH + self.match_fields = self.process_match() + + # FLAGS + self.process_reply_line() + + # ADJUST + self.adjust_fields = {m.value for m in node.match("/adjust")} + + # MANDATORY + try: + self.mandatory = list(node.match("/mandatory"))[0] + except (KeyError, IndexError): + self.mandatory = None + + # TSIG + self.process_tsig() + + # SECTIONS & RECORDS + self.sections = self.process_sections() + + def process_raw(self): + try: + self.raw_data = binascii.unhexlify(self.node["/raw"].value) + return True + except KeyError: + return False + + def process_match(self): + try: + self.node["/match_present"] + except KeyError: + return None + + fields = set(m.value for m in self.node.match("/match")) + + if 'all' in fields: + fields.remove("all") + fields |= set(["opcode", "qtype", "qname", "flags", + "rcode", "answer", "authority", "additional"]) + + if 'question' in fields: + fields.remove("question") + fields |= set(["qtype", "qname"]) + + return fields + + def process_reply_line(self): + """Extracts flags, rcode and opcode from given node and adjust dns message accordingly""" + self.fields = [f.value for f in self.node.match("/reply")] + if 'DO' in self.fields: + self.message.want_dnssec(True) + opcode = self.get_opcode(fields=self.fields) + rcode = self.get_rcode(fields=self.fields) + self.message.flags = self.get_flags(fields=self.fields) + if rcode is not None: + self.message.set_rcode(rcode) + if opcode is not None: + self.message.set_opcode(opcode) + + def process_tsig(self): + try: + tsig = list(self.node.match("/tsig"))[0] + tsig_keyname = tsig["/keyname"].value + tsig_secret = tsig["/secret"].value + keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret}) + self.message.use_tsig(keyring=keyring, keyname=tsig_keyname) + except (KeyError, IndexError): + pass + + def process_sections(self): + sections = set() + for section in self.node.match("/section/*"): + section_name = posixpath.basename(section.path) + sections.add(section_name) + for record in section.match("/record"): + owner = record['/domain'].value + if not owner.endswith("."): + owner += self.origin + try: + ttl = dns.ttl.from_text(record['/ttl'].value) + except KeyError: + ttl = self.default_ttl + try: + rdclass = dns.rdataclass.from_text(record['/class'].value) + except KeyError: + rdclass = dns.rdataclass.from_text(self.default_cls) + rdtype = dns.rdatatype.from_text(record['/type'].value) + rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype) + if section_name != "question": + rd = record['/data'].value.split() + if rd: + if rdtype == dns.rdatatype.DS: + rd[1] = str(dns.dnssec.algorithm_from_text(rd[1])) + rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join( + rd), origin=dns.name.from_text(self.origin), relativize=False) + rr.add(rd) + if section_name == 'question': + if rr.rdtype == dns.rdatatype.AXFR: + self.message.xfr = True + self.message.question.append(rr) + elif section_name == 'answer': + self.message.answer.append(rr) + elif section_name == 'authority': + self.message.authority.append(rr) + elif section_name == 'additional': + self.message.additional.append(rr) + return sections + + def __str__(self): + txt = 'ENTRY_BEGIN\n' + if not self.is_raw_data_entry: + txt += 'MATCH {0}\n'.format(' '.join(self.match_fields)) + txt += 'ADJUST {0}\n'.format(' '.join(self.adjust_fields)) + txt += 'REPLY {rcode} {flags}\n'.format( + rcode=dns.rcode.to_text(self.message.rcode()), + flags=' '.join([dns.flags.to_text(self.message.flags), + dns.flags.edns_to_text(self.message.ednsflags)]) + ) + for sect_name in ['question', 'answer', 'authority', 'additional']: + sect = getattr(self.message, sect_name) + if not sect: + continue + txt += 'SECTION {n}\n'.format(n=sect_name.upper()) + for rr in sect: + txt += str(rr) + txt += '\n' + if self.is_raw_data_entry: + txt += 'RAW\n' + if self.raw_data: + txt += binascii.hexlify(self.raw_data) + else: + txt += 'NULL' + txt += '\n' + txt += 'ENTRY_END\n' + return txt + + @classmethod + def get_flags(cls, fields): + """From `fields` extracts and returns flags""" + flags = [] + for code in fields: + try: + dns.flags.from_text(code) # throws KeyError on failure + flags.append(code) + except KeyError: + pass + return dns.flags.from_text(' '.join(flags)) + + @classmethod + def get_rcode(cls, fields): + """ + From `fields` extracts and returns rcode. + Throws `ValueError` if there are more then one rcodes + """ + rcodes = [] + for code in fields: + try: + rcodes.append(dns.rcode.from_text(code)) + except dns.rcode.UnknownRcode: + pass + if len(rcodes) > 1: + raise ValueError("Parse failed, too many rcode values.", rcodes) + if not rcodes: + return None + return rcodes[0] + + @classmethod + def get_opcode(cls, fields): + """ + From `fields` extracts and returns opcode. + Throws `ValueError` if there are more then one opcodes + """ + opcodes = [] + for code in fields: + try: + opcodes.append(dns.opcode.from_text(code)) + except dns.opcode.UnknownOpcode: + pass + if len(opcodes) > 1: + raise ValueError("Parse failed, too many opcode values.") + if not opcodes: + return None + return opcodes[0] + + def match(self, msg): + """ Compare scripted reply to given message based on match criteria. """ + for code in self.match_fields: + try: + pydnstest.matchpart.match_part(self.message, msg, code) + except pydnstest.matchpart.DataMismatch as ex: + errstr = '%s in the response:\n%s' % (str(ex), msg.to_text()) + # TODO: cisla radku + raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr)) + + def cmp_raw(self, raw_value): + assert self.is_raw_data_entry + expected = None + if self.raw_data is not None: + expected = binascii.hexlify(self.raw_data) + got = None + if raw_value is not None: + got = binascii.hexlify(raw_value) + if expected != got: + raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got)) + + def reply(self, query) -> Optional[DNSBlob]: + if 'do_not_answer' in self.adjust_fields: + return None + if self.is_raw_data_entry: + copy_id = 'raw_data' in self.adjust_fields + assert self.raw_data is not None + return DNSReplyRaw(self.raw_data, query, copy_id) + copy_id = 'copy_id' in self.adjust_fields + copy_query = 'copy_query' in self.adjust_fields + return DNSReply(self.message, query, copy_id, copy_query) + + def set_edns(self, fields): + """ Set EDNS version and bufsize. """ + version = 0 + bufsize = 4096 + if fields and fields[0].isdigit(): + version = int(fields.pop(0)) + if fields and fields[0].isdigit(): + bufsize = int(fields.pop(0)) + if bufsize == 0: + self.message.use_edns(False) + return + opts = [] + for v in fields: + k, v = tuple(v.split('=')) if '=' in v else (v, True) + if k.lower() == 'nsid': + opts.append(dns.edns.GenericOption(dns.edns.NSID, '' if v is True else v)) + if k.lower() == 'subnet': + net = v.split('/') + subnet_addr = net[0] + family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET + addr = socket.inet_pton(family, subnet_addr) + prefix = len(addr) * 8 + if len(net) > 1: + prefix = int(net[1]) + addr = addr[0: (prefix + 7) / 8] + if prefix % 8 != 0: # Mask the last byte + addr = addr[:-1] + chr(ord(addr[-1]) & 0xFF << (8 - prefix % 8)) + opts.append(dns.edns.GenericOption(8, struct.pack( + "!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + addr)) + self.message.use_edns(edns=version, payload=bufsize, options=opts) + + +class Range: + """ + Range represents a set of scripted queries valid for given step range. + """ + log = logging.getLogger('pydnstest.scenario.Range') + + def __init__(self, node): + """ Initialize reply range. """ + self.node = node + self.a = int(node['/from'].value) + self.b = int(node['/to'].value) + assert self.a <= self.b + + address = node["/address"].value + self.addresses = {address} if address is not None else set() + self.addresses |= {a.value for a in node.match("/address/*")} + self.stored = [Entry(n) for n in node.match("/entry")] + self.args = {} + self.received = 0 + self.sent = 0 + + def __del__(self): + self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d', + self.a, self.b, self.addresses, self.received, self.sent) + + def __str__(self): + txt = '\nRANGE_BEGIN {a} {b}\n'.format(a=self.a, b=self.b) + for addr in self.addresses: + txt += ' ADDRESS {0}\n'.format(addr) + + for entry in self.stored: + txt += '\n' + txt += str(entry) + txt += 'RANGE_END\n\n' + return txt + + def eligible(self, ident, address): + """ Return true if this range is eligible for fetching reply. """ + if self.a <= ident <= self.b: + return (None is address + or set() == self.addresses + or address in self.addresses) + return False + + def reply(self, query: dns.message.Message) -> Optional[DNSBlob]: + """Get answer for given query (adjusted if needed).""" + self.received += 1 + for candidate in self.stored: + try: + candidate.match(query) + resp = candidate.reply(query) + # Probabilistic loss + if 'LOSS' in self.args: + if random.random() < float(self.args['LOSS']): + return DNSReplyServfail(query) + self.sent += 1 + candidate.fired += 1 + return resp + except ValueError: + pass + return DNSReplyServfail(query) + + +class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods + """ + Prepent Step identification before each log message. + """ + def process(self, msg, kwargs): + return '[STEP %s %s] %s' % (self.extra['id'], self.extra['type'], msg), kwargs + + +class Step: + """ + Step represents one scripted action in a given moment, + each step has an order identifier, type and optionally data entry. + """ + require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY'] + + def __init__(self, node): + """ Initialize single scenario step. """ + self.node = node + self.id = int(node.value) + self.type = node["/type"].value + self.log = StepLogger(logging.getLogger('pydnstest.scenario.Step'), + {'id': self.id, 'type': self.type}) + try: + self.delay = int(node["/timestamp"].value) + except KeyError: + pass + self.data = [Entry(n) for n in node.match("/entry")] + self.queries = [] + self.has_data = self.type in Step.require_data + self.answer = None + self.raw_answer = None + self.repeat_if_fail = 0 + self.pause_if_fail = 0 + self.next_if_fail = -1 + + # TODO Parser currently can't parse CHECK_ANSWER args, player doesn't understand them anyway + # if type == 'CHECK_ANSWER': + # for arg in extra_args: + # param = arg.split('=') + # try: + # if param[0] == 'REPEAT': + # self.repeat_if_fail = int(param[1]) + # elif param[0] == 'PAUSE': + # self.pause_if_fail = float(param[1]) + # elif param[0] == 'NEXT': + # self.next_if_fail = int(param[1]) + # except Exception as e: + # raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e))) + + def __str__(self): + txt = '\nSTEP {i} {t}'.format(i=self.id, t=self.type) + if self.repeat_if_fail: + txt += ' REPEAT {v}'.format(v=self.repeat_if_fail) + elif self.pause_if_fail: + txt += ' PAUSE {v}'.format(v=self.pause_if_fail) + elif self.next_if_fail != -1: + txt += ' NEXT {v}'.format(v=self.next_if_fail) + # if self.args: + # txt += ' ' + # txt += ' '.join(self.args) + txt += '\n' + + for data in self.data: + # from IPython.core.debugger import Tracer + # Tracer()() + txt += str(data) + return txt + + def play(self, ctx): + """ Play one step from a scenario. """ + if self.type == 'QUERY': + self.log.info('') + self.log.debug(self.data[0].message.to_text()) + # Parse QUERY-specific parameters + choice, tcp, source = None, False, None + return self.__query(ctx, tcp=tcp, choice=choice, source=source) + elif self.type == 'CHECK_OUT_QUERY': # ignore + self.log.info('') + return None + elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER': + self.log.info('') + return self.__check_answer(ctx) + elif self.type == 'TIME_PASSES ELAPSE': + self.log.info('') + return self.__time_passes() + elif self.type == 'REPLY' or self.type == 'MOCK': + self.log.info('') + return None + # Parser currently doesn't support step types LOG, REPLAY and ASSERT. + # No test uses them. + # elif self.type == 'LOG': + # if not ctx.log: + # raise Exception('scenario has no log interface') + # return ctx.log.match(self.args) + # elif self.type == 'REPLAY': + # self.__replay(ctx) + # elif self.type == 'ASSERT': + # self.__assert(ctx) + else: + raise NotImplementedError('step %03d type %s unsupported' % (self.id, self.type)) + + def __check_answer(self, ctx): + """ Compare answer from previously resolved query. """ + if not self.data: + raise ValueError("response definition required") + expected = self.data[0] + if expected.is_raw_data_entry is True: + self.log.debug("raw answer: %s", ctx.last_raw_answer.to_text()) + expected.cmp_raw(ctx.last_raw_answer) + else: + if ctx.last_answer is None: + raise ValueError("no answer from preceding query") + self.log.debug("answer: %s", ctx.last_answer.to_text()) + expected.match(ctx.last_answer) + + # def __replay(self, ctx, chunksize=8): + # nqueries = len(self.queries) + # if len(self.args) > 0 and self.args[0].isdigit(): + # nqueries = int(self.args.pop(0)) + # destination = ctx.client[ctx.client.keys()[0]] + # self.log.info('replaying %d queries to %s@%d (%s)', + # nqueries, destination[0], destination[1], ' '.join(self.args)) + # if 'INTENSIFY' in os.environ: + # nqueries *= int(os.environ['INTENSIFY']) + # tstart = datetime.now() + # nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args) + # # Keep/print the statistics + # rtt = (datetime.now() - tstart).total_seconds() * 1000 + # pps = 1000 * nrcvd / rtt + # self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps) + # tag = None + # for arg in self.args: + # if arg.upper().startswith('PRINT'): + # _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay') + # if tag: + # self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d', + # tag.ljust(11), pps, rtt, nsent, nrcvd) + + def __query(self, ctx, tcp=False, choice=None, source=None): + """ + Send query and wait for an answer (if the query is not RAW). + + The received answer is stored in self.answer and ctx.last_answer. + """ + if not self.data: + raise ValueError("query definition required") + if self.data[0].is_raw_data_entry is True: + data_to_wire = self.data[0].raw_data + else: + # Don't use a message copy as the EDNS data portion is not copied. + data_to_wire = self.data[0].message.to_wire() + if choice is None or not choice: + choice = list(ctx.client.keys())[0] + if choice not in ctx.client: + raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice)) + # Create socket to test subject + sock = None + destination = ctx.client[choice] + family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET + sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if tcp: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + sock.settimeout(3) + if source: + sock.bind((source, 0)) + sock.connect(destination) + # Send query to client and wait for response + tstart = datetime.now() + while True: + try: + sendto_msg(sock, data_to_wire) + break + except OSError as ex: + # ENOBUFS, throttle sending + if ex.errno == errno.ENOBUFS: + time.sleep(0.1) + # Wait for a response for a reasonable time + answer = None + if not self.data[0].is_raw_data_entry: + while True: + if (datetime.now() - tstart).total_seconds() > 5: + raise RuntimeError("Server took too long to respond") + try: + answer, _ = recvfrom_msg(sock, True) + break + except OSError as ex: + if ex.errno == errno.ENOBUFS: + time.sleep(0.1) + # Track RTT + rtt = (datetime.now() - tstart).total_seconds() * 1000 + global g_rtt, g_nqueries + g_nqueries += 1 + g_rtt += rtt + # Remember last answer for checking later + self.raw_answer = answer + ctx.last_raw_answer = answer + if self.raw_answer is not None: + self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True) + else: + self.answer = None + ctx.last_answer = self.answer + + def __time_passes(self): + """ Modify system time. """ + file_old = os.environ["FAKETIME_TIMESTAMP_FILE"] + file_next = os.environ["FAKETIME_TIMESTAMP_FILE"] + ".next" + with open(file_old, 'r') as time_file: + line = time_file.readline().strip() + t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple()) + t += self.delay + with open(file_next, 'w') as time_file: + time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n") + time_file.flush() + os.replace(file_next, file_old) + + # def __assert(self, ctx): + # """ Assert that a passed expression evaluates to True. """ + # result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges}) + # # Evaluate subexpressions for clarity + # subexpr = [] + # for expr in self.args: + # try: + # ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges}) + # subexpr.append(str(ee)) + # except: + # subexpr.append(expr) + # assert result is True, '"%s" assertion fails (%s)' % ( + # ' '.join(self.args), ' '.join(subexpr)) + + +class Scenario: + log = logging.getLogger('pydnstest.scenatio.Scenario') + + def __init__(self, node, filename): + """ Initialize scenario with description. """ + self.node = node + self.info = node.value + self.file = filename + self.ranges = [Range(n) for n in node.match("/range")] + self.current_range = None + self.steps = [Step(n) for n in node.match("/step")] + self.current_step = None + self.client = {} + + def __str__(self): + txt = 'SCENARIO_BEGIN' + if self.info: + txt += ' {0}'.format(self.info) + txt += '\n' + for range_ in self.ranges: + txt += str(range_) + for step in self.steps: + txt += str(step) + txt += "\nSCENARIO_END" + return txt + + def reply(self, query: dns.message.Message, address=None) -> Optional[DNSBlob]: + """Generate answer packet for given query.""" + current_step_id = self.current_step.id + # Unknown address, select any match + # TODO: workaround until the server supports stub zones + all_addresses = set() # type: ignore + for rng in self.ranges: + all_addresses.update(rng.addresses) + if address not in all_addresses: + address = None + # Find current valid query response range + for rng in self.ranges: + if rng.eligible(current_step_id, address): + self.current_range = rng + return rng.reply(query) + # Find any prescripted one-shot replies + for step in self.steps: + if step.id < current_step_id or step.type != 'REPLY': + continue + try: + candidate = step.data[0] + candidate.match(query) + step.data.remove(candidate) + return candidate.reply(query) + except (IndexError, ValueError): + pass + return DNSReplyServfail(query) + + def play(self, paddr): + """ Play given scenario. """ + # Store test subject => address mapping + self.client = paddr + + step = None + i = 0 + while i < len(self.steps): + step = self.steps[i] + self.current_step = step + try: + step.play(self) + except ValueError as ex: + if step.repeat_if_fail > 0: + self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)", + step.id, ex, step.next_if_fail, step.repeat_if_fail) + step.repeat_if_fail -= 1 + if step.pause_if_fail > 0: + time.sleep(step.pause_if_fail) + if step.next_if_fail != -1: + next_steps = [j for j in range(len(self.steps)) if self.steps[ + j].id == step.next_if_fail] + if not next_steps: + raise ValueError('step %d: wrong NEXT value "%d"' % + (step.id, step.next_if_fail)) + next_step = next_steps[0] + if next_step < len(self.steps): + i = next_step + else: + raise ValueError('step %d: Can''t branch to NEXT value "%d"' % + (step.id, step.next_if_fail)) + continue + else: + raise ValueError('%s step %d %s' % (self.file, step.id, str(ex))) + i += 1 + + for r in self.ranges: + for e in r.stored: + if e.mandatory and e.fired == 0: + # TODO: cisla radku + raise ValueError('Mandatory section at %s not fired' % e.mandatory.span) + + +def get_next(file_in, skip_empty=True): + """ Return next token from the input stream. """ + while True: + line = file_in.readline() + if not line: + return False + quoted, escaped = False, False + for i, char in enumerate(line): + if char == '\\': + escaped = not escaped + if not escaped and char == '"': + quoted = not quoted + if char == ';' and not quoted: + line = line[0:i] + break + if char != '\\': + escaped = False + tokens = ' '.join(line.strip().split()).split() + if not tokens: + if skip_empty: + continue + else: + return '', [] + op = tokens.pop(0) + return op, tokens + + +def parse_config(scn_cfg, qmin, installdir): # FIXME: pylint: disable=too-many-statements + """ + Transform scene config (key, value) pairs into dict filled with defaults. + Returns tuple: + context dict: {Jinja2 variable: value} + trust anchor dict: {domain: [TA lines for particular domain]} + """ + # defaults + do_not_query_localhost = True + harden_glue = True + sockfamily = 0 # auto-select value for socket.getaddrinfo + trust_anchor_list = [] + trust_anchor_files = {} + negative_ta_list = [] + stub_addr = None + override_timestamp = None + + features = {} + feature_list_delimiter = ';' + feature_pair_delimiter = '=' + + for k, v in scn_cfg: + # Enable selectively for some tests + if k == 'do-not-query-localhost': + do_not_query_localhost = str2bool(v) + elif k == 'domain-insecure': + negative_ta_list.append(v) + elif k == 'harden-glue': + harden_glue = str2bool(v) + elif k == 'query-minimization': + qmin = str2bool(v) + elif k == 'trust-anchor': + trust_anchor = v.strip('"\'') + trust_anchor_list.append(trust_anchor) + domain = dns.name.from_text(trust_anchor.split()[0]).canonicalize() + if domain not in trust_anchor_files: + trust_anchor_files[domain] = [] + trust_anchor_files[domain].append(trust_anchor) + elif k == 'val-override-timestamp': + override_timestamp_str = v.strip('"\'') + override_timestamp = int(override_timestamp_str) + elif k == 'val-override-date': + override_date_str = v.strip('"\'') + ovr_yr = override_date_str[0:4] + ovr_mnt = override_date_str[4:6] + ovr_day = override_date_str[6:8] + ovr_hr = override_date_str[8:10] + ovr_min = override_date_str[10:12] + ovr_sec = override_date_str[12:] + override_date_str_arg = '{0} {1} {2} {3} {4} {5}'.format( + ovr_yr, ovr_mnt, ovr_day, ovr_hr, ovr_min, ovr_sec) + override_date = time.strptime(override_date_str_arg, "%Y %m %d %H %M %S") + override_timestamp = calendar.timegm(override_date) + elif k == 'stub-addr': + stub_addr = v.strip('"\'') + elif k == 'features': + feature_list = v.split(feature_list_delimiter) + try: + for f_item in feature_list: + if f_item.find(feature_pair_delimiter) != -1: + f_key, f_value = [x.strip() + for x + in f_item.split(feature_pair_delimiter, 1)] + else: + f_key = f_item.strip() + f_value = "" + features[f_key] = f_value + except KeyError as ex: + raise KeyError("can't parse features (%s) in config section (%s)" % (v, str(ex))) + elif k == 'feature-list': + try: + f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)] + if f_key not in features: + features[f_key] = [] + f_value = f_value.replace("{{INSTALL_DIR}}", installdir) + features[f_key].append(f_value) + except KeyError as ex: + raise KeyError("can't parse feature-list (%s) in config section (%s)" + % (v, str(ex))) + elif k == 'force-ipv6' and v.upper() == 'TRUE': + sockfamily = socket.AF_INET6 + else: + raise NotImplementedError('unsupported CONFIG key "%s"' % k) + + ctx = { + "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(), + "NEGATIVE_TRUST_ANCHORS": negative_ta_list, + "FEATURES": features, + "HARDEN_GLUE": str(harden_glue).lower(), + "INSTALL_DIR": installdir, + "QMIN": str(qmin).lower(), + "TRUST_ANCHORS": trust_anchor_list, + "TRUST_ANCHOR_FILES": trust_anchor_files.keys() + } + if stub_addr: + ctx['ROOT_ADDR'] = stub_addr + # determine and verify socket family for specified root address + gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0, + socket.IPPROTO_UDP, socket.AI_NUMERICHOST) + assert len(gai) == 1 + sockfamily = gai[0][0] + if not sockfamily: + sockfamily = socket.AF_INET # default to IPv4 + ctx['_SOCKET_FAMILY'] = sockfamily + if override_timestamp: + ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp + return (ctx, trust_anchor_files) + + +def parse_file(path): + """ Parse scenario from a file. """ + + aug = pydnstest.augwrap.AugeasWrapper( + confpath=path, lens='Deckard', loadpath=os.path.dirname(__file__)) + node = aug.tree + config = [] + for line in [c.value for c in node.match("/config/*")]: + if line: + if not line.startswith(';'): + if '#' in line: + line = line[0:line.index('#')] + # Break to key-value pairs + # e.g.: ['minimization', 'on'] + kv = [x.strip() for x in line.split(':', 1)] + if len(kv) >= 2: + config.append(kv) + scenario = Scenario(node["/scenario"], posixpath.basename(node.path)) + return scenario, config |