diff options
Diffstat (limited to 'selftest/target/dns_hub.py')
-rwxr-xr-x | selftest/target/dns_hub.py | 251 |
1 files changed, 251 insertions, 0 deletions
diff --git a/selftest/target/dns_hub.py b/selftest/target/dns_hub.py new file mode 100755 index 0000000..9f5e3dc --- /dev/null +++ b/selftest/target/dns_hub.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +# +# Unix SMB/CIFS implementation. +# Copyright (C) Volker Lendecke 2017 +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Used by selftest to proxy DNS queries to the correct testenv DC. +# See selftest/target/README for more details. +# Based on the EchoServer example from python docs + +import threading +import sys +import select +import socket +import collections +import time +from samba.dcerpc import dns +import samba.ndr as ndr + +if sys.version_info[0] < 3: + import SocketServer + sserver = SocketServer +else: + import socketserver + sserver = socketserver + +DNS_REQUEST_TIMEOUT = 10 + +# make sure the script dies immediately when hitting control-C, +# rather than raising KeyboardInterrupt. As we do all database +# operations using transactions, this is safe. +import signal +signal.signal(signal.SIGINT, signal.SIG_DFL) + +class DnsHandler(sserver.BaseRequestHandler): + dns_qtype_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_QTYPE_')) + def dns_qtype_string(self, qtype): + "Return a readable qtype code" + return self.dns_qtype_strings[qtype] + + dns_rcode_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_RCODE_')) + def dns_rcode_string(self, rcode): + "Return a readable error code" + return self.dns_rcode_strings[rcode] + + def dns_transaction_udp(self, packet, host): + "send a DNS query and read the reply" + s = None + flags = socket.AddressInfo.AI_NUMERICHOST + flags |= socket.AddressInfo.AI_NUMERICSERV + flags |= socket.AddressInfo.AI_PASSIVE + addr_info = socket.getaddrinfo(host, int(53), + type=socket.SocketKind.SOCK_DGRAM, + flags=flags) + assert len(addr_info) == 1 + try: + send_packet = ndr.ndr_pack(packet) + s = socket.socket(addr_info[0][0], addr_info[0][1], 0) + s.settimeout(DNS_REQUEST_TIMEOUT) + s.connect(addr_info[0][4]) + s.sendall(send_packet, 0) + recv_packet = s.recv(2048, 0) + return ndr.ndr_unpack(dns.name_packet, recv_packet) + except socket.error as err: + print("Error sending to host %s for name %s: %s\n" % + (host, packet.questions[0].name, err.errno)) + raise + finally: + if s is not None: + s.close() + + def get_pdc_ipv4_addr(self, lookup_name): + """Maps a DNS realm to the IPv4 address of the PDC for that testenv""" + + realm_to_ip_mappings = self.server.realm_to_ip_mappings + + # sort the realms so we find the longest-match first + testenv_realms = sorted(realm_to_ip_mappings.keys(), key=len) + testenv_realms.reverse() + + for realm in testenv_realms: + if lookup_name.endswith(realm): + # return the corresponding IP address for this realm's PDC + return realm_to_ip_mappings[realm] + + return None + + def forwarder(self, name): + lname = name.lower() + + # check for special cases used by tests (e.g. dns_forwarder.py) + if lname.endswith('an-address-that-will-not-resolve'): + return 'ignore' + if lname.endswith('dsfsdfs'): + return 'fail' + if lname.endswith("torture1", 0, len(lname)-2): + # CATCH TORTURE100, TORTURE101, ... + return 'torture' + if lname.endswith('_none_.example.com'): + return 'torture' + if lname.endswith('torturedom.samba.example.com'): + return 'torture' + + # return the testenv PDC matching the realm being requested + return self.get_pdc_ipv4_addr(lname) + + def handle(self): + start = time.monotonic() + data, sock = self.request + query = ndr.ndr_unpack(dns.name_packet, data) + name = query.questions[0].name + forwarder = self.forwarder(name) + response = None + + if forwarder == 'ignore': + return + elif forwarder == 'fail': + pass + elif forwarder in ['torture', None]: + response = query + response.operation |= dns.DNS_FLAG_REPLY + response.operation |= dns.DNS_FLAG_RECURSION_AVAIL + response.operation |= dns.DNS_RCODE_NXDOMAIN + else: + try: + response = self.dns_transaction_udp(query, forwarder) + except OSError as err: + print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" % + (forwarder, name, err)) + + if response is None: + response = query + response.operation |= dns.DNS_FLAG_REPLY + response.operation |= dns.DNS_FLAG_RECURSION_AVAIL + response.operation |= dns.DNS_RCODE_SERVFAIL + + send_packet = ndr.ndr_pack(response) + + end = time.monotonic() + tdiff = end - start + errcode = response.operation & dns.DNS_RCODE + if tdiff > (DNS_REQUEST_TIMEOUT/5): + debug = True + else: + debug = False + if debug: + print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" % + (forwarder, self.client_address, name, + self.dns_qtype_string(query.questions[0].question_type), + self.dns_rcode_string(errcode), response.operation, tdiff)) + + try: + sock.sendto(send_packet, self.client_address) + except socket.error as err: + print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" % + (self.client_address, name, tdiff, err)) + + +class server_thread(threading.Thread): + def __init__(self, server, name): + threading.Thread.__init__(self, name=name) + self.server = server + + def run(self): + print("dns_hub[%s]: before serve_forever()" % self.name) + self.server.serve_forever() + print("dns_hub[%s]: after serve_forever()" % self.name) + + def stop(self): + print("dns_hub[%s]: before shutdown()" % self.name) + self.server.shutdown() + print("dns_hub[%s]: after shutdown()" % self.name) + self.server.server_close() + +class UDPV4Server(sserver.UDPServer): + address_family = socket.AF_INET + +class UDPV6Server(sserver.UDPServer): + address_family = socket.AF_INET6 + +def main(): + if len(sys.argv) < 4: + print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]") + sys.exit(1) + + timeout = int(sys.argv[1]) * 1000 + timeout = min(timeout, 2**31 - 1) # poll with 32-bit int can't take more + # we pass in the listen addresses as a comma-separated string. + listenaddresses = sys.argv[2].split(',') + # we pass in the realm-to-IP mappings as a comma-separated key=value + # string. Convert this back into a dictionary that the DnsHandler can use + realm_mappings = collections.OrderedDict(kv.split('=') for kv in sys.argv[3].split(',')) + + def prepare_server_thread(listenaddress, realm_mappings): + + flags = socket.AddressInfo.AI_NUMERICHOST + flags |= socket.AddressInfo.AI_NUMERICSERV + flags |= socket.AddressInfo.AI_PASSIVE + addr_info = socket.getaddrinfo(listenaddress, int(53), + type=socket.SocketKind.SOCK_DGRAM, + flags=flags) + assert len(addr_info) == 1 + if addr_info[0][0] == socket.AddressFamily.AF_INET6: + server = UDPV6Server(addr_info[0][4], DnsHandler) + else: + server = UDPV4Server(addr_info[0][4], DnsHandler) + + # we pass in the realm-to-IP mappings as a comma-separated key=value + # string. Convert this back into a dictionary that the DnsHandler can use + server.realm_to_ip_mappings = realm_mappings + t = server_thread(server, name="UDP[%s]" % listenaddress) + return t + + print("dns_hub will proxy DNS requests for the following realms:") + for realm, ip in realm_mappings.items(): + print(" {0} ==> {1}".format(realm, ip)) + + print("dns_hub will listen on the following UDP addresses:") + threads = [] + for listenaddress in listenaddresses: + print(" %s" % listenaddress) + t = prepare_server_thread(listenaddress, realm_mappings) + threads.append(t) + + for t in threads: + t.start() + p = select.poll() + stdin = sys.stdin.fileno() + p.register(stdin, select.POLLIN) + p.poll(timeout) + print("dns_hub: after poll()") + for t in threads: + t.stop() + for t in threads: + t.join() + print("dns_hub: before exit()") + sys.exit(0) + +main() |