summaryrefslogtreecommitdiffstats
path: root/python/samba/emulate
diff options
context:
space:
mode:
Diffstat (limited to 'python/samba/emulate')
-rw-r--r--python/samba/emulate/__init__.py16
-rw-r--r--python/samba/emulate/traffic.py2415
-rw-r--r--python/samba/emulate/traffic_packets.py973
3 files changed, 3404 insertions, 0 deletions
diff --git a/python/samba/emulate/__init__.py b/python/samba/emulate/__init__.py
new file mode 100644
index 0000000..110e19d
--- /dev/null
+++ b/python/samba/emulate/__init__.py
@@ -0,0 +1,16 @@
+# Package level initialisation
+#
+# Copyright (C) Catalyst IT Ltd. 2017
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
diff --git a/python/samba/emulate/traffic.py b/python/samba/emulate/traffic.py
new file mode 100644
index 0000000..4811fe8
--- /dev/null
+++ b/python/samba/emulate/traffic.py
@@ -0,0 +1,2415 @@
+# -*- encoding: utf-8 -*-
+# Samba traffic replay and learning
+#
+# Copyright (C) Catalyst IT Ltd. 2017
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+import time
+import os
+import random
+import json
+import math
+import sys
+import signal
+from errno import ECHILD, ESRCH
+
+from collections import OrderedDict, Counter, defaultdict, namedtuple
+from dns.resolver import query as dns_query
+
+from samba.emulate import traffic_packets
+from samba.samdb import SamDB
+import ldb
+from ldb import LdbError
+from samba.dcerpc import ClientConnection
+from samba.dcerpc import security, drsuapi, lsa
+from samba.dcerpc import netlogon
+from samba.dcerpc.netlogon import netr_Authenticator
+from samba.dcerpc import srvsvc
+from samba.dcerpc import samr
+from samba.drs_utils import drs_DsBind
+import traceback
+from samba.credentials import Credentials, DONT_USE_KERBEROS, MUST_USE_KERBEROS
+from samba.auth import system_session
+from samba.dsdb import (
+ UF_NORMAL_ACCOUNT,
+ UF_SERVER_TRUST_ACCOUNT,
+ UF_TRUSTED_FOR_DELEGATION,
+ UF_WORKSTATION_TRUST_ACCOUNT
+)
+from samba.dcerpc.misc import SEC_CHAN_BDC
+from samba import gensec
+from samba import sd_utils
+from samba.common import get_string
+from samba.logger import get_samba_logger
+import bisect
+
+CURRENT_MODEL_VERSION = 2 # save as this
+REQUIRED_MODEL_VERSION = 2 # load accepts this or greater
+SLEEP_OVERHEAD = 3e-4
+
+# we don't use None, because it complicates [de]serialisation
+NON_PACKET = '-'
+
+CLIENT_CLUES = {
+ ('dns', '0'): 1.0, # query
+ ('smb', '0x72'): 1.0, # Negotiate protocol
+ ('ldap', '0'): 1.0, # bind
+ ('ldap', '3'): 1.0, # searchRequest
+ ('ldap', '2'): 1.0, # unbindRequest
+ ('cldap', '3'): 1.0,
+ ('dcerpc', '11'): 1.0, # bind
+ ('dcerpc', '14'): 1.0, # Alter_context
+ ('nbns', '0'): 1.0, # query
+}
+
+SERVER_CLUES = {
+ ('dns', '1'): 1.0, # response
+ ('ldap', '1'): 1.0, # bind response
+ ('ldap', '4'): 1.0, # search result
+ ('ldap', '5'): 1.0, # search done
+ ('cldap', '5'): 1.0,
+ ('dcerpc', '12'): 1.0, # bind_ack
+ ('dcerpc', '13'): 1.0, # bind_nak
+ ('dcerpc', '15'): 1.0, # Alter_context response
+}
+
+SKIPPED_PROTOCOLS = {"smb", "smb2", "browser", "smb_netlogon"}
+
+WAIT_SCALE = 10.0
+WAIT_THRESHOLD = (1.0 / WAIT_SCALE)
+NO_WAIT_LOG_TIME_RANGE = (-10, -3)
+
+# DEBUG_LEVEL can be changed by scripts with -d
+DEBUG_LEVEL = 0
+
+LOGGER = get_samba_logger(name=__name__)
+
+
+def debug(level, msg, *args):
+ """Print a formatted debug message to standard error.
+
+
+ :param level: The debug level, message will be printed if it is <= the
+ currently set debug level. The debug level can be set with
+ the -d option.
+ :param msg: The message to be logged, can contain C-Style format
+ specifiers
+ :param args: The parameters required by the format specifiers
+ """
+ if level <= DEBUG_LEVEL:
+ if not args:
+ print(msg, file=sys.stderr)
+ else:
+ print(msg % tuple(args), file=sys.stderr)
+
+
+def debug_lineno(*args):
+ """ Print an unformatted log message to stderr, containing the line number
+ """
+ tb = traceback.extract_stack(limit=2)
+ print((" %s:" "\033[01;33m"
+ "%s " "\033[00m" % (tb[0][2], tb[0][1])), end=' ',
+ file=sys.stderr)
+ for a in args:
+ print(a, file=sys.stderr)
+ print(file=sys.stderr)
+ sys.stderr.flush()
+
+
+def random_colour_print(seeds):
+ """Return a function that prints a coloured line to stderr. The colour
+ of the line depends on a sort of hash of the integer arguments."""
+ if seeds:
+ s = 214
+ for x in seeds:
+ s += 17
+ s *= x
+ s %= 214
+ prefix = "\033[38;5;%dm" % (18 + s)
+
+ def p(*args):
+ if DEBUG_LEVEL > 0:
+ for a in args:
+ print("%s%s\033[00m" % (prefix, a), file=sys.stderr)
+ else:
+ def p(*args):
+ if DEBUG_LEVEL > 0:
+ for a in args:
+ print(a, file=sys.stderr)
+
+ return p
+
+
+class FakePacketError(Exception):
+ pass
+
+
+class Packet(object):
+ """Details of a network packet"""
+ __slots__ = ('timestamp',
+ 'ip_protocol',
+ 'stream_number',
+ 'src',
+ 'dest',
+ 'protocol',
+ 'opcode',
+ 'desc',
+ 'extra',
+ 'endpoints')
+ def __init__(self, timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra):
+ self.timestamp = timestamp
+ self.ip_protocol = ip_protocol
+ self.stream_number = stream_number
+ self.src = src
+ self.dest = dest
+ self.protocol = protocol
+ self.opcode = opcode
+ self.desc = desc
+ self.extra = extra
+ if self.src < self.dest:
+ self.endpoints = (self.src, self.dest)
+ else:
+ self.endpoints = (self.dest, self.src)
+
+ @classmethod
+ def from_line(cls, line):
+ fields = line.rstrip('\n').split('\t')
+ (timestamp,
+ ip_protocol,
+ stream_number,
+ src,
+ dest,
+ protocol,
+ opcode,
+ desc) = fields[:8]
+ extra = fields[8:]
+
+ timestamp = float(timestamp)
+ src = int(src)
+ dest = int(dest)
+
+ return cls(timestamp, ip_protocol, stream_number, src, dest,
+ protocol, opcode, desc, extra)
+
+ def as_summary(self, time_offset=0.0):
+ """Format the packet as a traffic_summary line.
+ """
+ extra = '\t'.join(self.extra)
+ t = self.timestamp + time_offset
+ return (t, '%f\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' %
+ (t,
+ self.ip_protocol,
+ self.stream_number or '',
+ self.src,
+ self.dest,
+ self.protocol,
+ self.opcode,
+ self.desc,
+ extra))
+
+ def __str__(self):
+ return ("%.3f: %d -> %d; ip %s; strm %s; prot %s; op %s; desc %s %s" %
+ (self.timestamp, self.src, self.dest, self.ip_protocol or '-',
+ self.stream_number, self.protocol, self.opcode, self.desc,
+ ('«' + ' '.join(self.extra) + '»' if self.extra else '')))
+
+ def __repr__(self):
+ return "<Packet @%s>" % self
+
+ def copy(self):
+ return self.__class__(self.timestamp,
+ self.ip_protocol,
+ self.stream_number,
+ self.src,
+ self.dest,
+ self.protocol,
+ self.opcode,
+ self.desc,
+ self.extra)
+
+ def as_packet_type(self):
+ t = '%s:%s' % (self.protocol, self.opcode)
+ return t
+
+ def client_score(self):
+ """A positive number means we think it is a client; a negative number
+ means we think it is a server. Zero means no idea. range: -1 to 1.
+ """
+ key = (self.protocol, self.opcode)
+ if key in CLIENT_CLUES:
+ return CLIENT_CLUES[key]
+ if key in SERVER_CLUES:
+ return -SERVER_CLUES[key]
+ return 0.0
+
+ def play(self, conversation, context):
+ """Send the packet over the network, if required.
+
+ Some packets are ignored, i.e. for protocols not handled,
+ server response messages, or messages that are generated by the
+ protocol layer associated with other packets.
+ """
+ fn_name = 'packet_%s_%s' % (self.protocol, self.opcode)
+ try:
+ fn = getattr(traffic_packets, fn_name)
+
+ except AttributeError as e:
+ print("Conversation(%s) Missing handler %s" %
+ (conversation.conversation_id, fn_name),
+ file=sys.stderr)
+ return
+
+ # Don't display a message for kerberos packets, they're not directly
+ # generated they're used to indicate kerberos should be used
+ if self.protocol != "kerberos":
+ debug(2, "Conversation(%s) Calling handler %s" %
+ (conversation.conversation_id, fn_name))
+
+ start = time.time()
+ try:
+ if fn(self, conversation, context):
+ # Only collect timing data for functions that generate
+ # network traffic, or fail
+ end = time.time()
+ duration = end - start
+ print("%f\t%s\t%s\t%s\t%f\tTrue\t" %
+ (end, conversation.conversation_id, self.protocol,
+ self.opcode, duration))
+ except Exception as e:
+ end = time.time()
+ duration = end - start
+ print("%f\t%s\t%s\t%s\t%f\tFalse\t%s" %
+ (end, conversation.conversation_id, self.protocol,
+ self.opcode, duration, e))
+
+ def __cmp__(self, other):
+ return self.timestamp - other.timestamp
+
+ def is_really_a_packet(self, missing_packet_stats=None):
+ return is_a_real_packet(self.protocol, self.opcode)
+
+
+def is_a_real_packet(protocol, opcode):
+ """Is the packet one that can be ignored?
+
+ If so removing it will have no effect on the replay
+ """
+ if protocol in SKIPPED_PROTOCOLS:
+ # Ignore any packets for the protocols we're not interested in.
+ return False
+ if protocol == "ldap" and opcode == '':
+ # skip ldap continuation packets
+ return False
+
+ fn_name = 'packet_%s_%s' % (protocol, opcode)
+ fn = getattr(traffic_packets, fn_name, None)
+ if fn is None:
+ LOGGER.debug("missing packet %s" % fn_name, file=sys.stderr)
+ return False
+ if fn is traffic_packets.null_packet:
+ return False
+ return True
+
+
+def is_a_traffic_generating_packet(protocol, opcode):
+ """Return true if a packet generates traffic in its own right. Some of
+ these will generate traffic in certain contexts (e.g. ldap unbind
+ after a bind) but not if the conversation consists only of these packets.
+ """
+ if protocol == 'wait':
+ return False
+
+ if (protocol, opcode) in (
+ ('kerberos', ''),
+ ('ldap', '2'),
+ ('dcerpc', '15'),
+ ('dcerpc', '16')):
+ return False
+
+ return is_a_real_packet(protocol, opcode)
+
+
+class ReplayContext(object):
+ """State/Context for a conversation between an simulated client and a
+ server. Some of the context is shared amongst all conversations
+ and should be generated before the fork, while other context is
+ specific to a particular conversation and should be generated
+ *after* the fork, in generate_process_local_config().
+ """
+ def __init__(self,
+ server=None,
+ lp=None,
+ creds=None,
+ total_conversations=None,
+ badpassword_frequency=None,
+ prefer_kerberos=None,
+ tempdir=None,
+ statsdir=None,
+ ou=None,
+ base_dn=None,
+ domain=os.environ.get("DOMAIN"),
+ domain_sid=None,
+ instance_id=None):
+ self.server = server
+ self.netlogon_connection = None
+ self.creds = creds
+ self.lp = lp
+ if prefer_kerberos:
+ self.kerberos_state = MUST_USE_KERBEROS
+ else:
+ self.kerberos_state = DONT_USE_KERBEROS
+ self.ou = ou
+ self.base_dn = base_dn
+ self.domain = domain
+ self.statsdir = statsdir
+ self.global_tempdir = tempdir
+ self.domain_sid = domain_sid
+ self.realm = lp.get('realm')
+ self.instance_id = instance_id
+
+ # Bad password attempt controls
+ self.badpassword_frequency = badpassword_frequency
+ self.last_lsarpc_bad = False
+ self.last_lsarpc_named_bad = False
+ self.last_simple_bind_bad = False
+ self.last_bind_bad = False
+ self.last_srvsvc_bad = False
+ self.last_drsuapi_bad = False
+ self.last_netlogon_bad = False
+ self.last_samlogon_bad = False
+ self.total_conversations = total_conversations
+ self.generate_ldap_search_tables()
+
+ def generate_ldap_search_tables(self):
+ session = system_session()
+
+ db = SamDB(url="ldap://%s" % self.server,
+ session_info=session,
+ credentials=self.creds,
+ lp=self.lp)
+
+ res = db.search(db.domain_dn(),
+ scope=ldb.SCOPE_SUBTREE,
+ controls=["paged_results:1:1000"],
+ attrs=['dn'])
+
+ # find a list of dns for each pattern
+ # e.g. CN,CN,CN,DC,DC
+ dn_map = {}
+ attribute_clue_map = {
+ 'invocationId': []
+ }
+
+ for r in res:
+ dn = str(r.dn)
+ pattern = ','.join(x.lstrip()[:2] for x in dn.split(',')).upper()
+ dns = dn_map.setdefault(pattern, [])
+ dns.append(dn)
+ if dn.startswith('CN=NTDS Settings,'):
+ attribute_clue_map['invocationId'].append(dn)
+
+ # extend the map in case we are working with a different
+ # number of DC components.
+ # for k, v in self.dn_map.items():
+ # print >>sys.stderr, k, len(v)
+
+ for k in list(dn_map.keys()):
+ if k[-3:] != ',DC':
+ continue
+ p = k[:-3]
+ while p[-3:] == ',DC':
+ p = p[:-3]
+ for i in range(5):
+ p += ',DC'
+ if p != k and p in dn_map:
+ print('dn_map collision %s %s' % (k, p),
+ file=sys.stderr)
+ continue
+ dn_map[p] = dn_map[k]
+
+ self.dn_map = dn_map
+ self.attribute_clue_map = attribute_clue_map
+
+ # pre-populate DN-based search filters (it's simplest to generate them
+ # once, when the test starts). These are used by guess_search_filter()
+ # to avoid full-scans
+ self.search_filters = {}
+
+ # lookup all the GPO DNs
+ res = db.search(db.domain_dn(), scope=ldb.SCOPE_SUBTREE, attrs=['dn'],
+ expression='(objectclass=groupPolicyContainer)')
+ gpos_by_dn = "".join("(distinguishedName={0})".format(msg['dn']) for msg in res)
+
+ # a search for the 'gPCFileSysPath' attribute is probably a GPO search
+ # (as per the MS-GPOL spec) which searches for GPOs by DN
+ self.search_filters['gPCFileSysPath'] = "(|{0})".format(gpos_by_dn)
+
+ # likewise, a search for gpLink is probably the Domain SOM search part
+ # of the MS-GPOL, in which case it's looking up a few OUs by DN
+ ou_str = ""
+ for ou in ["Domain Controllers,", "traffic_replay,", ""]:
+ ou_str += "(distinguishedName={0}{1})".format(ou, db.domain_dn())
+ self.search_filters['gpLink'] = "(|{0})".format(ou_str)
+
+ # The CEP Web Service can query the AD DC to get pKICertificateTemplate
+ # objects (as per MS-WCCE)
+ self.search_filters['pKIExtendedKeyUsage'] = \
+ '(objectCategory=pKICertificateTemplate)'
+
+ # assume that anything querying the usnChanged is some kind of
+ # synchronization tool, e.g. AD Change Detection Connector
+ res = db.search('', scope=ldb.SCOPE_BASE, attrs=['highestCommittedUSN'])
+ self.search_filters['usnChanged'] = \
+ '(usnChanged>={0})'.format(res[0]['highestCommittedUSN'])
+
+ # The traffic_learner script doesn't preserve the LDAP search filter, and
+ # having no filter can result in a full DB scan. This is costly for a large
+ # DB, and not necessarily representative of real world traffic. As there
+ # several standard LDAP queries that get used by AD tools, we can apply
+ # some logic and guess what the search filter might have been originally.
+ def guess_search_filter(self, attrs, dn_sig, dn):
+
+ # there are some standard spec-based searches that query fairly unique
+ # attributes. Check if the search is likely one of these
+ for key in self.search_filters.keys():
+ if key in attrs:
+ return self.search_filters[key]
+
+ # if it's the top-level domain, assume we're looking up a single user,
+ # e.g. like powershell Get-ADUser or a similar tool
+ if dn_sig == 'DC,DC':
+ random_user_id = random.random() % self.total_conversations
+ account_name = user_name(self.instance_id, random_user_id)
+ return '(&(sAMAccountName=%s)(objectClass=user))' % account_name
+
+ # otherwise just return everything in the sub-tree
+ return '(objectClass=*)'
+
+ def generate_process_local_config(self, account, conversation):
+ self.ldap_connections = []
+ self.dcerpc_connections = []
+ self.lsarpc_connections = []
+ self.lsarpc_connections_named = []
+ self.drsuapi_connections = []
+ self.srvsvc_connections = []
+ self.samr_contexts = []
+ self.netbios_name = account.netbios_name
+ self.machinepass = account.machinepass
+ self.username = account.username
+ self.userpass = account.userpass
+
+ self.tempdir = mk_masked_dir(self.global_tempdir,
+ 'conversation-%d' %
+ conversation.conversation_id)
+
+ self.lp.set("private dir", self.tempdir)
+ self.lp.set("lock dir", self.tempdir)
+ self.lp.set("state directory", self.tempdir)
+ self.lp.set("tls verify peer", "no_check")
+
+ self.remoteAddress = "/root/ncalrpc_as_system"
+ self.samlogon_dn = ("cn=%s,%s" %
+ (self.netbios_name, self.ou))
+ self.user_dn = ("cn=%s,%s" %
+ (self.username, self.ou))
+
+ self.generate_machine_creds()
+ self.generate_user_creds()
+
+ def with_random_bad_credentials(self, f, good, bad, failed_last_time):
+ """Execute the supplied logon function, randomly choosing the
+ bad credentials.
+
+ Based on the frequency in badpassword_frequency randomly perform the
+ function with the supplied bad credentials.
+ If run with bad credentials, the function is re-run with the good
+ credentials.
+ failed_last_time is used to prevent consecutive bad credential
+ attempts. So the over all bad credential frequency will be lower
+ than that requested, but not significantly.
+ """
+ if not failed_last_time:
+ if (self.badpassword_frequency and
+ random.random() < self.badpassword_frequency):
+ try:
+ f(bad)
+ except Exception:
+ # Ignore any exceptions as the operation may fail
+ # as it's being performed with bad credentials
+ pass
+ failed_last_time = True
+ else:
+ failed_last_time = False
+
+ result = f(good)
+ return (result, failed_last_time)
+
+ def generate_user_creds(self):
+ """Generate the conversation specific user Credentials.
+
+ Each Conversation has an associated user account used to simulate
+ any non Administrative user traffic.
+
+ Generates user credentials with good and bad passwords and ldap
+ simple bind credentials with good and bad passwords.
+ """
+ self.user_creds = Credentials()
+ self.user_creds.guess(self.lp)
+ self.user_creds.set_workstation(self.netbios_name)
+ self.user_creds.set_password(self.userpass)
+ self.user_creds.set_username(self.username)
+ self.user_creds.set_domain(self.domain)
+ self.user_creds.set_kerberos_state(self.kerberos_state)
+
+ self.user_creds_bad = Credentials()
+ self.user_creds_bad.guess(self.lp)
+ self.user_creds_bad.set_workstation(self.netbios_name)
+ self.user_creds_bad.set_password(self.userpass[:-4])
+ self.user_creds_bad.set_username(self.username)
+ self.user_creds_bad.set_kerberos_state(self.kerberos_state)
+
+ # Credentials for ldap simple bind.
+ self.simple_bind_creds = Credentials()
+ self.simple_bind_creds.guess(self.lp)
+ self.simple_bind_creds.set_workstation(self.netbios_name)
+ self.simple_bind_creds.set_password(self.userpass)
+ self.simple_bind_creds.set_username(self.username)
+ self.simple_bind_creds.set_gensec_features(
+ self.simple_bind_creds.get_gensec_features() | gensec.FEATURE_SEAL)
+ self.simple_bind_creds.set_kerberos_state(self.kerberos_state)
+ self.simple_bind_creds.set_bind_dn(self.user_dn)
+
+ self.simple_bind_creds_bad = Credentials()
+ self.simple_bind_creds_bad.guess(self.lp)
+ self.simple_bind_creds_bad.set_workstation(self.netbios_name)
+ self.simple_bind_creds_bad.set_password(self.userpass[:-4])
+ self.simple_bind_creds_bad.set_username(self.username)
+ self.simple_bind_creds_bad.set_gensec_features(
+ self.simple_bind_creds_bad.get_gensec_features() |
+ gensec.FEATURE_SEAL)
+ self.simple_bind_creds_bad.set_kerberos_state(self.kerberos_state)
+ self.simple_bind_creds_bad.set_bind_dn(self.user_dn)
+
+ def generate_machine_creds(self):
+ """Generate the conversation specific machine Credentials.
+
+ Each Conversation has an associated machine account.
+
+ Generates machine credentials with good and bad passwords.
+ """
+
+ self.machine_creds = Credentials()
+ self.machine_creds.guess(self.lp)
+ self.machine_creds.set_workstation(self.netbios_name)
+ self.machine_creds.set_secure_channel_type(SEC_CHAN_BDC)
+ self.machine_creds.set_password(self.machinepass)
+ self.machine_creds.set_username(self.netbios_name + "$")
+ self.machine_creds.set_domain(self.domain)
+ self.machine_creds.set_kerberos_state(self.kerberos_state)
+
+ self.machine_creds_bad = Credentials()
+ self.machine_creds_bad.guess(self.lp)
+ self.machine_creds_bad.set_workstation(self.netbios_name)
+ self.machine_creds_bad.set_secure_channel_type(SEC_CHAN_BDC)
+ self.machine_creds_bad.set_password(self.machinepass[:-4])
+ self.machine_creds_bad.set_username(self.netbios_name + "$")
+ self.machine_creds_bad.set_kerberos_state(self.kerberos_state)
+
+ def get_matching_dn(self, pattern, attributes=None):
+ # If the pattern is an empty string, we assume ROOTDSE,
+ # Otherwise we try adding or removing DC suffixes, then
+ # shorter leading patterns until we hit one.
+ # e.g if there is no CN,CN,CN,CN,DC,DC
+ # we first try CN,CN,CN,CN,DC
+ # and CN,CN,CN,CN,DC,DC,DC
+ # then change to CN,CN,CN,DC,DC
+ # and as last resort we use the base_dn
+ attr_clue = self.attribute_clue_map.get(attributes)
+ if attr_clue:
+ return random.choice(attr_clue)
+
+ pattern = pattern.upper()
+ while pattern:
+ if pattern in self.dn_map:
+ return random.choice(self.dn_map[pattern])
+ # chop one off the front and try it all again.
+ pattern = pattern[3:]
+
+ return self.base_dn
+
+ def get_dcerpc_connection(self, new=False):
+ guid = '12345678-1234-abcd-ef00-01234567cffb' # RPC_NETLOGON UUID
+ if self.dcerpc_connections and not new:
+ return self.dcerpc_connections[-1]
+ c = ClientConnection("ncacn_ip_tcp:%s" % self.server,
+ (guid, 1), self.lp)
+ self.dcerpc_connections.append(c)
+ return c
+
+ def get_srvsvc_connection(self, new=False):
+ if self.srvsvc_connections and not new:
+ return self.srvsvc_connections[-1]
+
+ def connect(creds):
+ return srvsvc.srvsvc("ncacn_np:%s" % (self.server),
+ self.lp,
+ creds)
+
+ (c, self.last_srvsvc_bad) = \
+ self.with_random_bad_credentials(connect,
+ self.user_creds,
+ self.user_creds_bad,
+ self.last_srvsvc_bad)
+
+ self.srvsvc_connections.append(c)
+ return c
+
+ def get_lsarpc_connection(self, new=False):
+ if self.lsarpc_connections and not new:
+ return self.lsarpc_connections[-1]
+
+ def connect(creds):
+ binding_options = 'schannel,seal,sign'
+ return lsa.lsarpc("ncacn_ip_tcp:%s[%s]" %
+ (self.server, binding_options),
+ self.lp,
+ creds)
+
+ (c, self.last_lsarpc_bad) = \
+ self.with_random_bad_credentials(connect,
+ self.machine_creds,
+ self.machine_creds_bad,
+ self.last_lsarpc_bad)
+
+ self.lsarpc_connections.append(c)
+ return c
+
+ def get_lsarpc_named_pipe_connection(self, new=False):
+ if self.lsarpc_connections_named and not new:
+ return self.lsarpc_connections_named[-1]
+
+ def connect(creds):
+ return lsa.lsarpc("ncacn_np:%s" % (self.server),
+ self.lp,
+ creds)
+
+ (c, self.last_lsarpc_named_bad) = \
+ self.with_random_bad_credentials(connect,
+ self.machine_creds,
+ self.machine_creds_bad,
+ self.last_lsarpc_named_bad)
+
+ self.lsarpc_connections_named.append(c)
+ return c
+
+ def get_drsuapi_connection_pair(self, new=False, unbind=False):
+ """get a (drs, drs_handle) tuple"""
+ if self.drsuapi_connections and not new:
+ c = self.drsuapi_connections[-1]
+ return c
+
+ def connect(creds):
+ binding_options = 'seal'
+ binding_string = "ncacn_ip_tcp:%s[%s]" %\
+ (self.server, binding_options)
+ return drsuapi.drsuapi(binding_string, self.lp, creds)
+
+ (drs, self.last_drsuapi_bad) = \
+ self.with_random_bad_credentials(connect,
+ self.user_creds,
+ self.user_creds_bad,
+ self.last_drsuapi_bad)
+
+ (drs_handle, supported_extensions) = drs_DsBind(drs)
+ c = (drs, drs_handle)
+ self.drsuapi_connections.append(c)
+ return c
+
+ def get_ldap_connection(self, new=False, simple=False):
+ if self.ldap_connections and not new:
+ return self.ldap_connections[-1]
+
+ def simple_bind(creds):
+ """
+ To run simple bind against Windows, we need to run
+ following commands in PowerShell:
+
+ Install-windowsfeature ADCS-Cert-Authority
+ Install-AdcsCertificationAuthority -CAType EnterpriseRootCA
+ Restart-Computer
+
+ """
+ return SamDB('ldaps://%s' % self.server,
+ credentials=creds,
+ lp=self.lp)
+
+ def sasl_bind(creds):
+ return SamDB('ldap://%s' % self.server,
+ credentials=creds,
+ lp=self.lp)
+ if simple:
+ (samdb, self.last_simple_bind_bad) = \
+ self.with_random_bad_credentials(simple_bind,
+ self.simple_bind_creds,
+ self.simple_bind_creds_bad,
+ self.last_simple_bind_bad)
+ else:
+ (samdb, self.last_bind_bad) = \
+ self.with_random_bad_credentials(sasl_bind,
+ self.user_creds,
+ self.user_creds_bad,
+ self.last_bind_bad)
+
+ self.ldap_connections.append(samdb)
+ return samdb
+
+ def get_samr_context(self, new=False):
+ if not self.samr_contexts or new:
+ self.samr_contexts.append(
+ SamrContext(self.server, lp=self.lp, creds=self.creds))
+ return self.samr_contexts[-1]
+
+ def get_netlogon_connection(self):
+
+ if self.netlogon_connection:
+ return self.netlogon_connection
+
+ def connect(creds):
+ return netlogon.netlogon("ncacn_ip_tcp:%s[schannel,seal]" %
+ (self.server),
+ self.lp,
+ creds)
+ (c, self.last_netlogon_bad) = \
+ self.with_random_bad_credentials(connect,
+ self.machine_creds,
+ self.machine_creds_bad,
+ self.last_netlogon_bad)
+ self.netlogon_connection = c
+ return c
+
+ def guess_a_dns_lookup(self):
+ return (self.realm, 'A')
+
+ def get_authenticator(self):
+ auth = self.machine_creds.new_client_authenticator()
+ current = netr_Authenticator()
+ current.cred.data = [x if isinstance(x, int) else ord(x)
+ for x in auth["credential"]]
+ current.timestamp = auth["timestamp"]
+
+ subsequent = netr_Authenticator()
+ return (current, subsequent)
+
+ def write_stats(self, filename, **kwargs):
+ """Write arbitrary key/value pairs to a file in our stats directory in
+ order for them to be picked up later by another process working out
+ statistics."""
+ filename = os.path.join(self.statsdir, filename)
+ f = open(filename, 'w')
+ for k, v in kwargs.items():
+ print("%s: %s" % (k, v), file=f)
+ f.close()
+
+
+class SamrContext(object):
+ """State/Context associated with a samr connection.
+ """
+ def __init__(self, server, lp=None, creds=None):
+ self.connection = None
+ self.handle = None
+ self.domain_handle = None
+ self.domain_sid = None
+ self.group_handle = None
+ self.user_handle = None
+ self.rids = None
+ self.server = server
+ self.lp = lp
+ self.creds = creds
+
+ def get_connection(self):
+ if not self.connection:
+ self.connection = samr.samr(
+ "ncacn_ip_tcp:%s[seal]" % (self.server),
+ lp_ctx=self.lp,
+ credentials=self.creds)
+
+ return self.connection
+
+ def get_handle(self):
+ if not self.handle:
+ c = self.get_connection()
+ self.handle = c.Connect2(None, security.SEC_FLAG_MAXIMUM_ALLOWED)
+ return self.handle
+
+
+class Conversation(object):
+ """Details of a converation between a simulated client and a server."""
+ def __init__(self, start_time=None, endpoints=None, seq=(),
+ conversation_id=None):
+ self.start_time = start_time
+ self.endpoints = endpoints
+ self.packets = []
+ self.msg = random_colour_print(endpoints)
+ self.client_balance = 0.0
+ self.conversation_id = conversation_id
+ for p in seq:
+ self.add_short_packet(*p)
+
+ def __cmp__(self, other):
+ if self.start_time is None:
+ if other.start_time is None:
+ return 0
+ return -1
+ if other.start_time is None:
+ return 1
+ return self.start_time - other.start_time
+
+ def add_packet(self, packet):
+ """Add a packet object to this conversation, making a local copy with
+ a conversation-relative timestamp."""
+ p = packet.copy()
+
+ if self.start_time is None:
+ self.start_time = p.timestamp
+
+ if self.endpoints is None:
+ self.endpoints = p.endpoints
+
+ if p.endpoints != self.endpoints:
+ raise FakePacketError("Conversation endpoints %s don't match"
+ "packet endpoints %s" %
+ (self.endpoints, p.endpoints))
+
+ p.timestamp -= self.start_time
+
+ if p.src == p.endpoints[0]:
+ self.client_balance -= p.client_score()
+ else:
+ self.client_balance += p.client_score()
+
+ if p.is_really_a_packet():
+ self.packets.append(p)
+
+ def add_short_packet(self, timestamp, protocol, opcode, extra,
+ client=True, skip_unused_packets=True):
+ """Create a packet from a timestamp, and 'protocol:opcode' pair, and a
+ (possibly empty) list of extra data. If client is True, assume
+ this packet is from the client to the server.
+ """
+ if skip_unused_packets and not is_a_real_packet(protocol, opcode):
+ return
+
+ src, dest = self.guess_client_server()
+ if not client:
+ src, dest = dest, src
+ key = (protocol, opcode)
+ desc = OP_DESCRIPTIONS.get(key, '')
+ ip_protocol = IP_PROTOCOLS.get(protocol, '06')
+ packet = Packet(timestamp - self.start_time, ip_protocol,
+ '', src, dest,
+ protocol, opcode, desc, extra)
+ # XXX we're assuming the timestamp is already adjusted for
+ # this conversation?
+ # XXX should we adjust client balance for guessed packets?
+ if packet.src == packet.endpoints[0]:
+ self.client_balance -= packet.client_score()
+ else:
+ self.client_balance += packet.client_score()
+ if packet.is_really_a_packet():
+ self.packets.append(packet)
+
+ def __str__(self):
+ return ("<Conversation %s %s starting %.3f %d packets>" %
+ (self.conversation_id, self.endpoints, self.start_time,
+ len(self.packets)))
+
+ __repr__ = __str__
+
+ def __iter__(self):
+ return iter(self.packets)
+
+ def __len__(self):
+ return len(self.packets)
+
+ def get_duration(self):
+ if len(self.packets) < 2:
+ return 0
+ return self.packets[-1].timestamp - self.packets[0].timestamp
+
+ def replay_as_summary_lines(self):
+ return [p.as_summary(self.start_time) for p in self.packets]
+
+ def replay_with_delay(self, start, context=None, account=None):
+ """Replay the conversation at the right time.
+ (We're already in a fork)."""
+ # first we sleep until the first packet
+ t = self.start_time
+ now = time.time() - start
+ gap = t - now
+ sleep_time = gap - SLEEP_OVERHEAD
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+
+ miss = (time.time() - start) - t
+ self.msg("starting %s [miss %.3f]" % (self, miss))
+
+ max_gap = 0.0
+ max_sleep_miss = 0.0
+ # packet times are relative to conversation start
+ p_start = time.time()
+ for p in self.packets:
+ now = time.time() - p_start
+ gap = now - p.timestamp
+ if gap > max_gap:
+ max_gap = gap
+ if gap < 0:
+ sleep_time = -gap - SLEEP_OVERHEAD
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+ t = time.time() - p_start
+ if t - p.timestamp > max_sleep_miss:
+ max_sleep_miss = t - p.timestamp
+
+ p.play(self, context)
+
+ return max_gap, miss, max_sleep_miss
+
+ def guess_client_server(self, server_clue=None):
+ """Have a go at deciding who is the server and who is the client.
+ returns (client, server)
+ """
+ a, b = self.endpoints
+
+ if self.client_balance < 0:
+ return (a, b)
+
+ # in the absence of a clue, we will fall through to assuming
+ # the lowest number is the server (which is usually true).
+
+ if self.client_balance == 0 and server_clue == b:
+ return (a, b)
+
+ return (b, a)
+
+ def forget_packets_outside_window(self, s, e):
+ """Prune any packets outside the time window we're interested in
+
+ :param s: start of the window
+ :param e: end of the window
+ """
+ self.packets = [p for p in self.packets if s <= p.timestamp <= e]
+ self.start_time = self.packets[0].timestamp if self.packets else None
+
+ def renormalise_times(self, start_time):
+ """Adjust the packet start times relative to the new start time."""
+ for p in self.packets:
+ p.timestamp -= start_time
+
+ if self.start_time is not None:
+ self.start_time -= start_time
+
+
+class DnsHammer(Conversation):
+ """A lightweight conversation that generates a lot of dns:0 packets on
+ the fly"""
+
+ def __init__(self, dns_rate, duration, query_file=None):
+ n = int(dns_rate * duration)
+ self.times = [random.uniform(0, duration) for i in range(n)]
+ self.times.sort()
+ self.rate = dns_rate
+ self.duration = duration
+ self.start_time = 0
+ self.query_choices = self._get_query_choices(query_file=query_file)
+
+ def __str__(self):
+ return ("<DnsHammer %d packets over %.1fs (rate %.2f)>" %
+ (len(self.times), self.duration, self.rate))
+
+ def _get_query_choices(self, query_file=None):
+ """
+ Read dns query choices from a file, or return default
+
+ rname may contain format string like `{realm}`
+ realm can be fetched from context.realm
+ """
+
+ if query_file:
+ with open(query_file, 'r') as f:
+ text = f.read()
+ choices = []
+ for line in text.splitlines():
+ line = line.strip()
+ if line and not line.startswith('#'):
+ args = line.split(',')
+ assert len(args) == 4
+ choices.append(args)
+ return choices
+ else:
+ return [
+ (0, '{realm}', 'A', 'yes'),
+ (1, '{realm}', 'NS', 'yes'),
+ (2, '*.{realm}', 'A', 'no'),
+ (3, '*.{realm}', 'NS', 'no'),
+ (10, '_msdcs.{realm}', 'A', 'yes'),
+ (11, '_msdcs.{realm}', 'NS', 'yes'),
+ (20, 'nx.realm.com', 'A', 'no'),
+ (21, 'nx.realm.com', 'NS', 'no'),
+ (22, '*.nx.realm.com', 'A', 'no'),
+ (23, '*.nx.realm.com', 'NS', 'no'),
+ ]
+
+ def replay(self, context=None):
+ assert context
+ assert context.realm
+ start = time.time()
+ for t in self.times:
+ now = time.time() - start
+ gap = t - now
+ sleep_time = gap - SLEEP_OVERHEAD
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+
+ opcode, rname, rtype, exist = random.choice(self.query_choices)
+ rname = rname.format(realm=context.realm)
+ success = True
+ packet_start = time.time()
+ try:
+ answers = dns_query(rname, rtype)
+ if exist == 'yes' and not len(answers):
+ # expect answers but didn't get, fail
+ success = False
+ except Exception:
+ success = False
+ finally:
+ end = time.time()
+ duration = end - packet_start
+ print("%f\tDNS\tdns\t%s\t%f\t%s\t" % (end, opcode, duration, success))
+
+
+def ingest_summaries(files, dns_mode='count'):
+ """Load a summary traffic summary file and generated Converations from it.
+ """
+
+ dns_counts = defaultdict(int)
+ packets = []
+ for f in files:
+ if isinstance(f, str):
+ f = open(f)
+ print("Ingesting %s" % (f.name,), file=sys.stderr)
+ for line in f:
+ p = Packet.from_line(line)
+ if p.protocol == 'dns' and dns_mode != 'include':
+ dns_counts[p.opcode] += 1
+ else:
+ packets.append(p)
+
+ f.close()
+
+ if not packets:
+ return [], 0
+
+ start_time = min(p.timestamp for p in packets)
+ last_packet = max(p.timestamp for p in packets)
+
+ print("gathering packets into conversations", file=sys.stderr)
+ conversations = OrderedDict()
+ for i, p in enumerate(packets):
+ p.timestamp -= start_time
+ c = conversations.get(p.endpoints)
+ if c is None:
+ c = Conversation(conversation_id=(i + 2))
+ conversations[p.endpoints] = c
+ c.add_packet(p)
+
+ # We only care about conversations with actual traffic, so we
+ # filter out conversations with nothing to say. We do that here,
+ # rather than earlier, because those empty packets contain useful
+ # hints as to which end of the conversation was the client.
+ conversation_list = []
+ for c in conversations.values():
+ if len(c) != 0:
+ conversation_list.append(c)
+
+ # This is obviously not correct, as many conversations will appear
+ # to start roughly simultaneously at the beginning of the snapshot.
+ # To which we say: oh well, so be it.
+ duration = float(last_packet - start_time)
+ mean_interval = len(conversations) / duration
+
+ return conversation_list, mean_interval, duration, dns_counts
+
+
+def guess_server_address(conversations):
+ # we guess the most common address.
+ addresses = Counter()
+ for c in conversations:
+ addresses.update(c.endpoints)
+ if addresses:
+ return addresses.most_common(1)[0]
+
+
+def stringify_keys(x):
+ y = {}
+ for k, v in x.items():
+ k2 = '\t'.join(k)
+ y[k2] = v
+ return y
+
+
+def unstringify_keys(x):
+ y = {}
+ for k, v in x.items():
+ t = tuple(str(k).split('\t'))
+ y[t] = v
+ return y
+
+
+class TrafficModel(object):
+ def __init__(self, n=3):
+ self.ngrams = {}
+ self.query_details = {}
+ self.n = n
+ self.dns_opcounts = defaultdict(int)
+ self.cumulative_duration = 0.0
+ self.packet_rate = [0, 1]
+
+ def learn(self, conversations, dns_opcounts=None):
+ if dns_opcounts is None:
+ dns_opcounts = {}
+ prev = 0.0
+ cum_duration = 0.0
+ key = (NON_PACKET,) * (self.n - 1)
+
+ server = guess_server_address(conversations)
+
+ for k, v in dns_opcounts.items():
+ self.dns_opcounts[k] += v
+
+ if len(conversations) > 1:
+ first = conversations[0].start_time
+ total = 0
+ last = first + 0.1
+ for c in conversations:
+ total += len(c)
+ last = max(last, c.packets[-1].timestamp)
+
+ self.packet_rate[0] = total
+ self.packet_rate[1] = last - first
+
+ for c in conversations:
+ client, server = c.guess_client_server(server)
+ cum_duration += c.get_duration()
+ key = (NON_PACKET,) * (self.n - 1)
+ for p in c:
+ if p.src != client:
+ continue
+
+ elapsed = p.timestamp - prev
+ prev = p.timestamp
+ if elapsed > WAIT_THRESHOLD:
+ # add the wait as an extra state
+ wait = 'wait:%d' % (math.log(max(1.0,
+ elapsed * WAIT_SCALE)))
+ self.ngrams.setdefault(key, []).append(wait)
+ key = key[1:] + (wait,)
+
+ short_p = p.as_packet_type()
+ self.query_details.setdefault(short_p,
+ []).append(tuple(p.extra))
+ self.ngrams.setdefault(key, []).append(short_p)
+ key = key[1:] + (short_p,)
+
+ self.cumulative_duration += cum_duration
+ # add in the end
+ self.ngrams.setdefault(key, []).append(NON_PACKET)
+
+ def save(self, f):
+ ngrams = {}
+ for k, v in self.ngrams.items():
+ k = '\t'.join(k)
+ ngrams[k] = dict(Counter(v))
+
+ query_details = {}
+ for k, v in self.query_details.items():
+ query_details[k] = dict(Counter('\t'.join(x) if x else '-'
+ for x in v))
+
+ d = {
+ 'ngrams': ngrams,
+ 'query_details': query_details,
+ 'cumulative_duration': self.cumulative_duration,
+ 'packet_rate': self.packet_rate,
+ 'version': CURRENT_MODEL_VERSION
+ }
+ d['dns'] = self.dns_opcounts
+
+ if isinstance(f, str):
+ f = open(f, 'w')
+
+ json.dump(d, f, indent=2)
+
+ def load(self, f):
+ if isinstance(f, str):
+ f = open(f)
+
+ d = json.load(f)
+
+ try:
+ version = d["version"]
+ if version < REQUIRED_MODEL_VERSION:
+ raise ValueError("the model file is version %d; "
+ "version %d is required" %
+ (version, REQUIRED_MODEL_VERSION))
+ except KeyError:
+ raise ValueError("the model file lacks a version number; "
+ "version %d is required" %
+ (REQUIRED_MODEL_VERSION))
+
+ for k, v in d['ngrams'].items():
+ k = tuple(str(k).split('\t'))
+ values = self.ngrams.setdefault(k, [])
+ for p, count in v.items():
+ values.extend([str(p)] * count)
+ values.sort()
+
+ for k, v in d['query_details'].items():
+ values = self.query_details.setdefault(str(k), [])
+ for p, count in v.items():
+ if p == '-':
+ values.extend([()] * count)
+ else:
+ values.extend([tuple(str(p).split('\t'))] * count)
+ values.sort()
+
+ if 'dns' in d:
+ for k, v in d['dns'].items():
+ self.dns_opcounts[k] += v
+
+ self.cumulative_duration = d['cumulative_duration']
+ self.packet_rate = d['packet_rate']
+
+ def construct_conversation_sequence(self, timestamp=0.0,
+ hard_stop=None,
+ replay_speed=1,
+ ignore_before=0,
+ persistence=0):
+ """Construct an individual conversation packet sequence from the
+ model.
+ """
+ c = []
+ key = (NON_PACKET,) * (self.n - 1)
+ if ignore_before is None:
+ ignore_before = timestamp - 1
+
+ while True:
+ p = random.choice(self.ngrams.get(key, (NON_PACKET,)))
+ if p == NON_PACKET:
+ if timestamp < ignore_before:
+ break
+ if random.random() > persistence:
+ print("ending after %s (persistence %.1f)" % (key, persistence),
+ file=sys.stderr)
+ break
+
+ p = 'wait:%d' % random.randrange(5, 12)
+ print("trying %s instead of end" % p, file=sys.stderr)
+
+ if p in self.query_details:
+ extra = random.choice(self.query_details[p])
+ else:
+ extra = []
+
+ protocol, opcode = p.split(':', 1)
+ if protocol == 'wait':
+ log_wait_time = int(opcode) + random.random()
+ wait = math.exp(log_wait_time) / (WAIT_SCALE * replay_speed)
+ timestamp += wait
+ else:
+ log_wait = random.uniform(*NO_WAIT_LOG_TIME_RANGE)
+ wait = math.exp(log_wait) / replay_speed
+ timestamp += wait
+ if hard_stop is not None and timestamp > hard_stop:
+ break
+ if timestamp >= ignore_before:
+ c.append((timestamp, protocol, opcode, extra))
+
+ key = key[1:] + (p,)
+ if key[-2][:5] == 'wait:' and key[-1][:5] == 'wait:':
+ # two waits in a row can only be caused by "persistence"
+ # tricks, and will not result in any packets being found.
+ # Instead we pretend this is a fresh start.
+ key = (NON_PACKET,) * (self.n - 1)
+
+ return c
+
+ def scale_to_packet_rate(self, scale):
+ rate_n, rate_t = self.packet_rate
+ return scale * rate_n / rate_t
+
+ def packet_rate_to_scale(self, pps):
+ rate_n, rate_t = self.packet_rate
+ return pps * rate_t / rate_n
+
+ def generate_conversation_sequences(self, packet_rate, duration, replay_speed=1,
+ persistence=0):
+ """Generate a list of conversation descriptions from the model."""
+
+ # We run the simulation for ten times as long as our desired
+ # duration, and take the section at the end.
+ lead_in = 9 * duration
+ target_packets = int(packet_rate * duration)
+ conversations = []
+ n_packets = 0
+
+ while n_packets < target_packets:
+ start = random.uniform(-lead_in, duration)
+ c = self.construct_conversation_sequence(start,
+ hard_stop=duration,
+ replay_speed=replay_speed,
+ ignore_before=0,
+ persistence=persistence)
+ # will these "packets" generate actual traffic?
+ # some (e.g. ldap unbind) will not generate anything
+ # if the previous packets are not there, and if the
+ # conversation only has those it wastes a process doing nothing.
+ for timestamp, protocol, opcode, extra in c:
+ if is_a_traffic_generating_packet(protocol, opcode):
+ break
+ else:
+ continue
+
+ conversations.append(c)
+ n_packets += len(c)
+
+ scale = self.packet_rate_to_scale(packet_rate)
+ print(("we have %d packets (target %d) in %d conversations at %.1f/s "
+ "(scale %f)" % (n_packets, target_packets, len(conversations),
+ packet_rate, scale)),
+ file=sys.stderr)
+ conversations.sort() # sorts by first element == start time
+ return conversations
+
+
+def seq_to_conversations(seq, server=1, client=2):
+ conversations = []
+ for s in seq:
+ if s:
+ c = Conversation(s[0][0], (server, client), s)
+ client += 1
+ conversations.append(c)
+ return conversations
+
+
+IP_PROTOCOLS = {
+ 'dns': '11',
+ 'rpc_netlogon': '06',
+ 'kerberos': '06', # ratio 16248:258
+ 'smb': '06',
+ 'smb2': '06',
+ 'ldap': '06',
+ 'cldap': '11',
+ 'lsarpc': '06',
+ 'samr': '06',
+ 'dcerpc': '06',
+ 'epm': '06',
+ 'drsuapi': '06',
+ 'browser': '11',
+ 'smb_netlogon': '11',
+ 'srvsvc': '06',
+ 'nbns': '11',
+}
+
+OP_DESCRIPTIONS = {
+ ('browser', '0x01'): 'Host Announcement (0x01)',
+ ('browser', '0x02'): 'Request Announcement (0x02)',
+ ('browser', '0x08'): 'Browser Election Request (0x08)',
+ ('browser', '0x09'): 'Get Backup List Request (0x09)',
+ ('browser', '0x0c'): 'Domain/Workgroup Announcement (0x0c)',
+ ('browser', '0x0f'): 'Local Master Announcement (0x0f)',
+ ('cldap', '3'): 'searchRequest',
+ ('cldap', '5'): 'searchResDone',
+ ('dcerpc', '0'): 'Request',
+ ('dcerpc', '11'): 'Bind',
+ ('dcerpc', '12'): 'Bind_ack',
+ ('dcerpc', '13'): 'Bind_nak',
+ ('dcerpc', '14'): 'Alter_context',
+ ('dcerpc', '15'): 'Alter_context_resp',
+ ('dcerpc', '16'): 'AUTH3',
+ ('dcerpc', '2'): 'Response',
+ ('dns', '0'): 'query',
+ ('dns', '1'): 'response',
+ ('drsuapi', '0'): 'DsBind',
+ ('drsuapi', '12'): 'DsCrackNames',
+ ('drsuapi', '13'): 'DsWriteAccountSpn',
+ ('drsuapi', '1'): 'DsUnbind',
+ ('drsuapi', '2'): 'DsReplicaSync',
+ ('drsuapi', '3'): 'DsGetNCChanges',
+ ('drsuapi', '4'): 'DsReplicaUpdateRefs',
+ ('epm', '3'): 'Map',
+ ('kerberos', ''): '',
+ ('ldap', '0'): 'bindRequest',
+ ('ldap', '1'): 'bindResponse',
+ ('ldap', '2'): 'unbindRequest',
+ ('ldap', '3'): 'searchRequest',
+ ('ldap', '4'): 'searchResEntry',
+ ('ldap', '5'): 'searchResDone',
+ ('ldap', ''): '*** Unknown ***',
+ ('lsarpc', '14'): 'lsa_LookupNames',
+ ('lsarpc', '15'): 'lsa_LookupSids',
+ ('lsarpc', '39'): 'lsa_QueryTrustedDomainInfoBySid',
+ ('lsarpc', '40'): 'lsa_SetTrustedDomainInfo',
+ ('lsarpc', '6'): 'lsa_OpenPolicy',
+ ('lsarpc', '76'): 'lsa_LookupSids3',
+ ('lsarpc', '77'): 'lsa_LookupNames4',
+ ('nbns', '0'): 'query',
+ ('nbns', '1'): 'response',
+ ('rpc_netlogon', '21'): 'NetrLogonDummyRoutine1',
+ ('rpc_netlogon', '26'): 'NetrServerAuthenticate3',
+ ('rpc_netlogon', '29'): 'NetrLogonGetDomainInfo',
+ ('rpc_netlogon', '30'): 'NetrServerPasswordSet2',
+ ('rpc_netlogon', '39'): 'NetrLogonSamLogonEx',
+ ('rpc_netlogon', '40'): 'DsrEnumerateDomainTrusts',
+ ('rpc_netlogon', '45'): 'NetrLogonSamLogonWithFlags',
+ ('rpc_netlogon', '4'): 'NetrServerReqChallenge',
+ ('samr', '0',): 'Connect',
+ ('samr', '16'): 'GetAliasMembership',
+ ('samr', '17'): 'LookupNames',
+ ('samr', '18'): 'LookupRids',
+ ('samr', '19'): 'OpenGroup',
+ ('samr', '1'): 'Close',
+ ('samr', '25'): 'QueryGroupMember',
+ ('samr', '34'): 'OpenUser',
+ ('samr', '36'): 'QueryUserInfo',
+ ('samr', '39'): 'GetGroupsForUser',
+ ('samr', '3'): 'QuerySecurity',
+ ('samr', '5'): 'LookupDomain',
+ ('samr', '64'): 'Connect5',
+ ('samr', '6'): 'EnumDomains',
+ ('samr', '7'): 'OpenDomain',
+ ('samr', '8'): 'QueryDomainInfo',
+ ('smb', '0x04'): 'Close (0x04)',
+ ('smb', '0x24'): 'Locking AndX (0x24)',
+ ('smb', '0x2e'): 'Read AndX (0x2e)',
+ ('smb', '0x32'): 'Trans2 (0x32)',
+ ('smb', '0x71'): 'Tree Disconnect (0x71)',
+ ('smb', '0x72'): 'Negotiate Protocol (0x72)',
+ ('smb', '0x73'): 'Session Setup AndX (0x73)',
+ ('smb', '0x74'): 'Logoff AndX (0x74)',
+ ('smb', '0x75'): 'Tree Connect AndX (0x75)',
+ ('smb', '0xa2'): 'NT Create AndX (0xa2)',
+ ('smb2', '0'): 'NegotiateProtocol',
+ ('smb2', '11'): 'Ioctl',
+ ('smb2', '14'): 'Find',
+ ('smb2', '16'): 'GetInfo',
+ ('smb2', '18'): 'Break',
+ ('smb2', '1'): 'SessionSetup',
+ ('smb2', '2'): 'SessionLogoff',
+ ('smb2', '3'): 'TreeConnect',
+ ('smb2', '4'): 'TreeDisconnect',
+ ('smb2', '5'): 'Create',
+ ('smb2', '6'): 'Close',
+ ('smb2', '8'): 'Read',
+ ('smb_netlogon', '0x12'): 'SAM LOGON request from client (0x12)',
+ ('smb_netlogon', '0x17'): ('SAM Active Directory Response - '
+ 'user unknown (0x17)'),
+ ('srvsvc', '16'): 'NetShareGetInfo',
+ ('srvsvc', '21'): 'NetSrvGetInfo',
+}
+
+
+def expand_short_packet(p, timestamp, src, dest, extra):
+ protocol, opcode = p.split(':', 1)
+ desc = OP_DESCRIPTIONS.get((protocol, opcode), '')
+ ip_protocol = IP_PROTOCOLS.get(protocol, '06')
+
+ line = [timestamp, ip_protocol, '', src, dest, protocol, opcode, desc]
+ line.extend(extra)
+ return '\t'.join(line)
+
+
+def flushing_signal_handler(signal, frame):
+ """Signal handler closes standard out and error.
+
+ Triggered by a sigterm, ensures that the log messages are flushed
+ to disk and not lost.
+ """
+ sys.stderr.close()
+ sys.stdout.close()
+ os._exit(0)
+
+
+def replay_seq_in_fork(cs, start, context, account, client_id, server_id=1):
+ """Fork a new process and replay the conversation sequence."""
+ # We will need to reseed the random number generator or all the
+ # clients will end up using the same sequence of random
+ # numbers. random.randint() is mixed in so the initial seed will
+ # have an effect here.
+ seed = client_id * 1000 + random.randint(0, 999)
+
+ # flush our buffers so messages won't be written by both sides
+ sys.stdout.flush()
+ sys.stderr.flush()
+ pid = os.fork()
+ if pid != 0:
+ return pid
+
+ # we must never return, or we'll end up running parts of the
+ # parent's clean-up code. So we work in a try...finally, and
+ # try to print any exceptions.
+ try:
+ random.seed(seed)
+ endpoints = (server_id, client_id)
+ status = 0
+ t = cs[0][0]
+ c = Conversation(t, endpoints, seq=cs, conversation_id=client_id)
+ signal.signal(signal.SIGTERM, flushing_signal_handler)
+
+ context.generate_process_local_config(account, c)
+ sys.stdin.close()
+ os.close(0)
+ filename = os.path.join(context.statsdir, 'stats-conversation-%d' %
+ c.conversation_id)
+ f = open(filename, 'w')
+ try:
+ sys.stdout.close()
+ os.close(1)
+ except IOError as e:
+ LOGGER.info("stdout closing failed with %s" % e)
+
+ sys.stdout = f
+ now = time.time() - start
+ gap = t - now
+ sleep_time = gap - SLEEP_OVERHEAD
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+
+ max_lag, start_lag, max_sleep_miss = c.replay_with_delay(start=start,
+ context=context)
+ print("Maximum lag: %f" % max_lag)
+ print("Start lag: %f" % start_lag)
+ print("Max sleep miss: %f" % max_sleep_miss)
+
+ except Exception:
+ status = 1
+ print(("EXCEPTION in child PID %d, conversation %s" % (os.getpid(), c)),
+ file=sys.stderr)
+ traceback.print_exc(sys.stderr)
+ sys.stderr.flush()
+ finally:
+ sys.stderr.close()
+ sys.stdout.close()
+ os._exit(status)
+
+
+def dnshammer_in_fork(dns_rate, duration, context, query_file=None):
+ sys.stdout.flush()
+ sys.stderr.flush()
+ pid = os.fork()
+ if pid != 0:
+ return pid
+
+ sys.stdin.close()
+ os.close(0)
+
+ try:
+ sys.stdout.close()
+ os.close(1)
+ except IOError as e:
+ LOGGER.warn("stdout closing failed with %s" % e)
+ filename = os.path.join(context.statsdir, 'stats-dns')
+ sys.stdout = open(filename, 'w')
+
+ try:
+ status = 0
+ signal.signal(signal.SIGTERM, flushing_signal_handler)
+ hammer = DnsHammer(dns_rate, duration, query_file=query_file)
+ hammer.replay(context=context)
+ except Exception:
+ status = 1
+ print(("EXCEPTION in child PID %d, the DNS hammer" % (os.getpid())),
+ file=sys.stderr)
+ traceback.print_exc(sys.stderr)
+ finally:
+ sys.stderr.close()
+ sys.stdout.close()
+ os._exit(status)
+
+
+def replay(conversation_seq,
+ host=None,
+ creds=None,
+ lp=None,
+ accounts=None,
+ dns_rate=0,
+ dns_query_file=None,
+ duration=None,
+ latency_timeout=1.0,
+ stop_on_any_error=False,
+ **kwargs):
+
+ context = ReplayContext(server=host,
+ creds=creds,
+ lp=lp,
+ total_conversations=len(conversation_seq),
+ **kwargs)
+
+ if len(accounts) < len(conversation_seq):
+ raise ValueError(("we have %d accounts but %d conversations" %
+ (len(accounts), len(conversation_seq))))
+
+ # Set the process group so that the calling scripts are not killed
+ # when the forked child processes are killed.
+ os.setpgrp()
+
+ # we delay the start by a bit to allow all the forks to get up and
+ # running.
+ delay = len(conversation_seq) * 0.02
+ start = time.time() + delay
+
+ if duration is None:
+ # end slightly after the last packet of the last conversation
+ # to start. Conversations other than the last could still be
+ # going, but we don't care.
+ duration = conversation_seq[-1][-1][0] + latency_timeout
+
+ print("We will start in %.1f seconds" % delay,
+ file=sys.stderr)
+ print("We will stop after %.1f seconds" % (duration + delay),
+ file=sys.stderr)
+ print("runtime %.1f seconds" % duration,
+ file=sys.stderr)
+
+ # give one second grace for packets to finish before killing begins
+ end = start + duration + 1.0
+
+ LOGGER.info("Replaying traffic for %u conversations over %d seconds"
+ % (len(conversation_seq), duration))
+
+ context.write_stats('intentions',
+ Planned_conversations=len(conversation_seq),
+ Planned_packets=sum(len(x) for x in conversation_seq))
+
+ children = {}
+ try:
+ if dns_rate:
+ pid = dnshammer_in_fork(dns_rate, duration, context,
+ query_file=dns_query_file)
+ children[pid] = 1
+
+ for i, cs in enumerate(conversation_seq):
+ account = accounts[i]
+ client_id = i + 2
+ pid = replay_seq_in_fork(cs, start, context, account, client_id)
+ children[pid] = client_id
+
+ # HERE, we are past all the forks
+ t = time.time()
+ print("all forks done in %.1f seconds, waiting %.1f" %
+ (t - start + delay, t - start),
+ file=sys.stderr)
+
+ while time.time() < end and children:
+ time.sleep(0.003)
+ try:
+ pid, status = os.waitpid(-1, os.WNOHANG)
+ except OSError as e:
+ if e.errno != ECHILD: # no child processes
+ raise
+ break
+ if pid:
+ c = children.pop(pid, None)
+ if DEBUG_LEVEL > 0:
+ print(("process %d finished conversation %d;"
+ " %d to go" %
+ (pid, c, len(children))), file=sys.stderr)
+ if stop_on_any_error and status != 0:
+ break
+
+ except Exception:
+ print("EXCEPTION in parent", file=sys.stderr)
+ traceback.print_exc()
+ finally:
+ context.write_stats('unfinished',
+ Unfinished_conversations=len(children))
+
+ for s in (15, 15, 9):
+ print(("killing %d children with -%d" %
+ (len(children), s)), file=sys.stderr)
+ for pid in children:
+ try:
+ os.kill(pid, s)
+ except OSError as e:
+ if e.errno != ESRCH: # don't fail if it has already died
+ raise
+ time.sleep(0.5)
+ end = time.time() + 1
+ while children:
+ try:
+ pid, status = os.waitpid(-1, os.WNOHANG)
+ except OSError as e:
+ if e.errno != ECHILD:
+ raise
+ if pid != 0:
+ c = children.pop(pid, None)
+ if c is None:
+ print("children is %s, no pid found" % children)
+ sys.stderr.flush()
+ sys.stdout.flush()
+ os._exit(1)
+ print(("kill -%d %d KILLED conversation; "
+ "%d to go" %
+ (s, pid, len(children))),
+ file=sys.stderr)
+ if time.time() >= end:
+ break
+
+ if not children:
+ break
+ time.sleep(1)
+
+ if children:
+ print("%d children are missing" % len(children),
+ file=sys.stderr)
+
+ # there may be stragglers that were forked just as ^C was hit
+ # and don't appear in the list of children. We can get them
+ # with killpg, but that will also kill us, so this is^H^H would be
+ # goodbye, except we cheat and pretend to use ^C (SIG_INTERRUPT),
+ # so as not to have to fuss around writing signal handlers.
+ try:
+ os.killpg(0, 2)
+ except KeyboardInterrupt:
+ print("ignoring fake ^C", file=sys.stderr)
+
+
+def openLdb(host, creds, lp):
+ session = system_session()
+ ldb = SamDB(url="ldap://%s" % host,
+ session_info=session,
+ options=['modules:paged_searches'],
+ credentials=creds,
+ lp=lp)
+ return ldb
+
+
+def ou_name(ldb, instance_id):
+ """Generate an ou name from the instance id"""
+ return "ou=instance-%d,ou=traffic_replay,%s" % (instance_id,
+ ldb.domain_dn())
+
+
+def create_ou(ldb, instance_id):
+ """Create an ou, all created user and machine accounts will belong to it.
+
+ This allows all the created resources to be cleaned up easily.
+ """
+ ou = ou_name(ldb, instance_id)
+ try:
+ ldb.add({"dn": ou.split(',', 1)[1],
+ "objectclass": "organizationalunit"})
+ except LdbError as e:
+ (status, _) = e.args
+ # ignore already exists
+ if status != 68:
+ raise
+ try:
+ ldb.add({"dn": ou,
+ "objectclass": "organizationalunit"})
+ except LdbError as e:
+ (status, _) = e.args
+ # ignore already exists
+ if status != 68:
+ raise
+ return ou
+
+
+# ConversationAccounts holds details of the machine and user accounts
+# associated with a conversation.
+#
+# We use a named tuple to reduce shared memory usage.
+ConversationAccounts = namedtuple('ConversationAccounts',
+ ('netbios_name',
+ 'machinepass',
+ 'username',
+ 'userpass'))
+
+
+def generate_replay_accounts(ldb, instance_id, number, password):
+ """Generate a series of unique machine and user account names."""
+
+ accounts = []
+ for i in range(1, number + 1):
+ netbios_name = machine_name(instance_id, i)
+ username = user_name(instance_id, i)
+
+ account = ConversationAccounts(netbios_name, password, username,
+ password)
+ accounts.append(account)
+ return accounts
+
+
+def create_machine_account(ldb, instance_id, netbios_name, machinepass,
+ traffic_account=True):
+ """Create a machine account via ldap."""
+
+ ou = ou_name(ldb, instance_id)
+ dn = "cn=%s,%s" % (netbios_name, ou)
+ utf16pw = ('"%s"' % get_string(machinepass)).encode('utf-16-le')
+
+ if traffic_account:
+ # we set these bits for the machine account otherwise the replayed
+ # traffic throws up NT_STATUS_NO_TRUST_SAM_ACCOUNT errors
+ account_controls = str(UF_TRUSTED_FOR_DELEGATION |
+ UF_SERVER_TRUST_ACCOUNT)
+
+ else:
+ account_controls = str(UF_WORKSTATION_TRUST_ACCOUNT)
+
+ ldb.add({
+ "dn": dn,
+ "objectclass": "computer",
+ "sAMAccountName": "%s$" % netbios_name,
+ "userAccountControl": account_controls,
+ "unicodePwd": utf16pw})
+
+
+def create_user_account(ldb, instance_id, username, userpass):
+ """Create a user account via ldap."""
+ ou = ou_name(ldb, instance_id)
+ user_dn = "cn=%s,%s" % (username, ou)
+ utf16pw = ('"%s"' % get_string(userpass)).encode('utf-16-le')
+ ldb.add({
+ "dn": user_dn,
+ "objectclass": "user",
+ "sAMAccountName": username,
+ "userAccountControl": str(UF_NORMAL_ACCOUNT),
+ "unicodePwd": utf16pw
+ })
+
+ # grant user write permission to do things like write account SPN
+ sdutils = sd_utils.SDUtils(ldb)
+ sdutils.dacl_add_ace(user_dn, "(A;;WP;;;PS)")
+
+
+def create_group(ldb, instance_id, name):
+ """Create a group via ldap."""
+
+ ou = ou_name(ldb, instance_id)
+ dn = "cn=%s,%s" % (name, ou)
+ ldb.add({
+ "dn": dn,
+ "objectclass": "group",
+ "sAMAccountName": name,
+ })
+
+
+def user_name(instance_id, i):
+ """Generate a user name based in the instance id"""
+ return "STGU-%d-%d" % (instance_id, i)
+
+
+def search_objectclass(ldb, objectclass='user', attr='sAMAccountName'):
+ """Search objectclass, return attr in a set"""
+ objs = ldb.search(
+ expression="(objectClass={})".format(objectclass),
+ attrs=[attr]
+ )
+ return {str(obj[attr]) for obj in objs}
+
+
+def generate_users(ldb, instance_id, number, password):
+ """Add users to the server"""
+ existing_objects = search_objectclass(ldb, objectclass='user')
+ users = 0
+ for i in range(number, 0, -1):
+ name = user_name(instance_id, i)
+ if name not in existing_objects:
+ create_user_account(ldb, instance_id, name, password)
+ users += 1
+ if users % 50 == 0:
+ LOGGER.info("Created %u/%u users" % (users, number))
+
+ return users
+
+
+def machine_name(instance_id, i, traffic_account=True):
+ """Generate a machine account name from instance id."""
+ if traffic_account:
+ # traffic accounts correspond to a given user, and use different
+ # userAccountControl flags to ensure packets get processed correctly
+ # by the DC
+ return "STGM-%d-%d" % (instance_id, i)
+ else:
+ # Otherwise we're just generating computer accounts to simulate a
+ # semi-realistic network. These use the default computer
+ # userAccountControl flags, so we use a different account name so that
+ # we don't try to use them when generating packets
+ return "PC-%d-%d" % (instance_id, i)
+
+
+def generate_machine_accounts(ldb, instance_id, number, password,
+ traffic_account=True):
+ """Add machine accounts to the server"""
+ existing_objects = search_objectclass(ldb, objectclass='computer')
+ added = 0
+ for i in range(number, 0, -1):
+ name = machine_name(instance_id, i, traffic_account)
+ if name + "$" not in existing_objects:
+ create_machine_account(ldb, instance_id, name, password,
+ traffic_account)
+ added += 1
+ if added % 50 == 0:
+ LOGGER.info("Created %u/%u machine accounts" % (added, number))
+
+ return added
+
+
+def group_name(instance_id, i):
+ """Generate a group name from instance id."""
+ return "STGG-%d-%d" % (instance_id, i)
+
+
+def generate_groups(ldb, instance_id, number):
+ """Create the required number of groups on the server."""
+ existing_objects = search_objectclass(ldb, objectclass='group')
+ groups = 0
+ for i in range(number, 0, -1):
+ name = group_name(instance_id, i)
+ if name not in existing_objects:
+ create_group(ldb, instance_id, name)
+ groups += 1
+ if groups % 1000 == 0:
+ LOGGER.info("Created %u/%u groups" % (groups, number))
+
+ return groups
+
+
+def clean_up_accounts(ldb, instance_id):
+ """Remove the created accounts and groups from the server."""
+ ou = ou_name(ldb, instance_id)
+ try:
+ ldb.delete(ou, ["tree_delete:1"])
+ except LdbError as e:
+ (status, _) = e.args
+ # ignore does not exist
+ if status != 32:
+ raise
+
+
+def generate_users_and_groups(ldb, instance_id, password,
+ number_of_users, number_of_groups,
+ group_memberships, max_members,
+ machine_accounts, traffic_accounts=True):
+ """Generate the required users and groups, allocating the users to
+ those groups."""
+ memberships_added = 0
+ groups_added = 0
+ computers_added = 0
+
+ create_ou(ldb, instance_id)
+
+ LOGGER.info("Generating dummy user accounts")
+ users_added = generate_users(ldb, instance_id, number_of_users, password)
+
+ LOGGER.info("Generating dummy machine accounts")
+ computers_added = generate_machine_accounts(ldb, instance_id,
+ machine_accounts, password,
+ traffic_accounts)
+
+ if number_of_groups > 0:
+ LOGGER.info("Generating dummy groups")
+ groups_added = generate_groups(ldb, instance_id, number_of_groups)
+
+ if group_memberships > 0:
+ LOGGER.info("Assigning users to groups")
+ assignments = GroupAssignments(number_of_groups,
+ groups_added,
+ number_of_users,
+ users_added,
+ group_memberships,
+ max_members)
+ LOGGER.info("Adding users to groups")
+ add_users_to_groups(ldb, instance_id, assignments)
+ memberships_added = assignments.total()
+
+ if (groups_added > 0 and users_added == 0 and
+ number_of_groups != groups_added):
+ LOGGER.warning("The added groups will contain no members")
+
+ LOGGER.info("Added %d users (%d machines), %d groups and %d memberships" %
+ (users_added, computers_added, groups_added,
+ memberships_added))
+
+
+class GroupAssignments(object):
+ def __init__(self, number_of_groups, groups_added, number_of_users,
+ users_added, group_memberships, max_members):
+
+ self.count = 0
+ self.generate_group_distribution(number_of_groups)
+ self.generate_user_distribution(number_of_users, group_memberships)
+ self.max_members = max_members
+ self.assignments = defaultdict(list)
+ self.assign_groups(number_of_groups, groups_added, number_of_users,
+ users_added, group_memberships)
+
+ def cumulative_distribution(self, weights):
+ # make sure the probabilities conform to a cumulative distribution
+ # spread between 0.0 and 1.0. Dividing by the weighted total gives each
+ # probability a proportional share of 1.0. Higher probabilities get a
+ # bigger share, so are more likely to be picked. We use the cumulative
+ # value, so we can use random.random() as a simple index into the list
+ dist = []
+ total = sum(weights)
+ if total == 0:
+ return None
+
+ cumulative = 0.0
+ for probability in weights:
+ cumulative += probability
+ dist.append(cumulative / total)
+ return dist
+
+ def generate_user_distribution(self, num_users, num_memberships):
+ """Probability distribution of a user belonging to a group.
+ """
+ # Assign a weighted probability to each user. Use the Pareto
+ # Distribution so that some users are in a lot of groups, and the
+ # bulk of users are in only a few groups. If we're assigning a large
+ # number of group memberships, use a higher shape. This means slightly
+ # fewer outlying users that are in large numbers of groups. The aim is
+ # to have no users belonging to more than ~500 groups.
+ if num_memberships > 5000000:
+ shape = 3.0
+ elif num_memberships > 2000000:
+ shape = 2.5
+ elif num_memberships > 300000:
+ shape = 2.25
+ else:
+ shape = 1.75
+
+ weights = []
+ for x in range(1, num_users + 1):
+ p = random.paretovariate(shape)
+ weights.append(p)
+
+ # convert the weights to a cumulative distribution between 0.0 and 1.0
+ self.user_dist = self.cumulative_distribution(weights)
+
+ def generate_group_distribution(self, n):
+ """Probability distribution of a group containing a user."""
+
+ # Assign a weighted probability to each user. Probability decreases
+ # as the group-ID increases
+ weights = []
+ for x in range(1, n + 1):
+ p = 1 / (x**1.3)
+ weights.append(p)
+
+ # convert the weights to a cumulative distribution between 0.0 and 1.0
+ self.group_weights = weights
+ self.group_dist = self.cumulative_distribution(weights)
+
+ def generate_random_membership(self):
+ """Returns a randomly generated user-group membership"""
+
+ # the list items are cumulative distribution values between 0.0 and
+ # 1.0, which makes random() a handy way to index the list to get a
+ # weighted random user/group. (Here the user/group returned are
+ # zero-based array indexes)
+ user = bisect.bisect(self.user_dist, random.random())
+ group = bisect.bisect(self.group_dist, random.random())
+
+ return user, group
+
+ def users_in_group(self, group):
+ return self.assignments[group]
+
+ def get_groups(self):
+ return self.assignments.keys()
+
+ def cap_group_membership(self, group, max_members):
+ """Prevent the group's membership from exceeding the max specified"""
+ num_members = len(self.assignments[group])
+ if num_members >= max_members:
+ LOGGER.info("Group {0} has {1} members".format(group, num_members))
+
+ # remove this group and then recalculate the cumulative
+ # distribution, so this group is no longer selected
+ self.group_weights[group - 1] = 0
+ new_dist = self.cumulative_distribution(self.group_weights)
+ self.group_dist = new_dist
+
+ def add_assignment(self, user, group):
+ # the assignments are stored in a dictionary where key=group,
+ # value=list-of-users-in-group (indexing by group-ID allows us to
+ # optimize for DB membership writes)
+ if user not in self.assignments[group]:
+ self.assignments[group].append(user)
+ self.count += 1
+
+ # check if there'a cap on how big the groups can grow
+ if self.max_members:
+ self.cap_group_membership(group, self.max_members)
+
+ def assign_groups(self, number_of_groups, groups_added,
+ number_of_users, users_added, group_memberships):
+ """Allocate users to groups.
+
+ The intention is to have a few users that belong to most groups, while
+ the majority of users belong to a few groups.
+
+ A few groups will contain most users, with the remaining only having a
+ few users.
+ """
+
+ if group_memberships <= 0:
+ return
+
+ # Calculate the number of group menberships required
+ group_memberships = math.ceil(
+ float(group_memberships) *
+ (float(users_added) / float(number_of_users)))
+
+ if self.max_members:
+ group_memberships = min(group_memberships,
+ self.max_members * number_of_groups)
+
+ existing_users = number_of_users - users_added - 1
+ existing_groups = number_of_groups - groups_added - 1
+ while self.total() < group_memberships:
+ user, group = self.generate_random_membership()
+
+ if group > existing_groups or user > existing_users:
+ # the + 1 converts the array index to the corresponding
+ # group or user number
+ self.add_assignment(user + 1, group + 1)
+
+ def total(self):
+ return self.count
+
+
+def add_users_to_groups(db, instance_id, assignments):
+ """Takes the assignments of users to groups and applies them to the DB."""
+
+ total = assignments.total()
+ count = 0
+ added = 0
+
+ for group in assignments.get_groups():
+ users_in_group = assignments.users_in_group(group)
+ if len(users_in_group) == 0:
+ continue
+
+ # Split up the users into chunks, so we write no more than 1K at a
+ # time. (Minimizing the DB modifies is more efficient, but writing
+ # 10K+ users to a single group becomes inefficient memory-wise)
+ for chunk in range(0, len(users_in_group), 1000):
+ chunk_of_users = users_in_group[chunk:chunk + 1000]
+ add_group_members(db, instance_id, group, chunk_of_users)
+
+ added += len(chunk_of_users)
+ count += 1
+ if count % 50 == 0:
+ LOGGER.info("Added %u/%u memberships" % (added, total))
+
+def add_group_members(db, instance_id, group, users_in_group):
+ """Adds the given users to group specified."""
+
+ ou = ou_name(db, instance_id)
+
+ def build_dn(name):
+ return("cn=%s,%s" % (name, ou))
+
+ group_dn = build_dn(group_name(instance_id, group))
+ m = ldb.Message()
+ m.dn = ldb.Dn(db, group_dn)
+
+ for user in users_in_group:
+ user_dn = build_dn(user_name(instance_id, user))
+ idx = "member-" + str(user)
+ m[idx] = ldb.MessageElement(user_dn, ldb.FLAG_MOD_ADD, "member")
+
+ db.modify(m)
+
+
+def generate_stats(statsdir, timing_file):
+ """Generate and print the summary stats for a run."""
+ first = sys.float_info.max
+ last = 0
+ successful = 0
+ failed = 0
+ latencies = {}
+ failures = Counter()
+ unique_conversations = set()
+ if timing_file is not None:
+ tw = timing_file.write
+ else:
+ def tw(x):
+ pass
+
+ tw("time\tconv\tprotocol\ttype\tduration\tsuccessful\terror\n")
+
+ float_values = {
+ 'Maximum lag': 0,
+ 'Start lag': 0,
+ 'Max sleep miss': 0,
+ }
+ int_values = {
+ 'Planned_conversations': 0,
+ 'Planned_packets': 0,
+ 'Unfinished_conversations': 0,
+ }
+
+ for filename in os.listdir(statsdir):
+ path = os.path.join(statsdir, filename)
+ with open(path, 'r') as f:
+ for line in f:
+ try:
+ fields = line.rstrip('\n').split('\t')
+ conversation = fields[1]
+ protocol = fields[2]
+ packet_type = fields[3]
+ latency = float(fields[4])
+ t = float(fields[0])
+ first = min(t - latency, first)
+ last = max(t, last)
+
+ op = (protocol, packet_type)
+ latencies.setdefault(op, []).append(latency)
+ if fields[5] == 'True':
+ successful += 1
+ else:
+ failed += 1
+ failures[op] += 1
+
+ unique_conversations.add(conversation)
+
+ tw(line)
+ except (ValueError, IndexError):
+ if ':' in line:
+ k, v = line.split(':', 1)
+ if k in float_values:
+ float_values[k] = max(float(v),
+ float_values[k])
+ elif k in int_values:
+ int_values[k] = max(int(v),
+ int_values[k])
+ else:
+ print(line, file=sys.stderr)
+ else:
+ # not a valid line print and ignore
+ print(line, file=sys.stderr)
+
+ duration = last - first
+ if successful == 0:
+ success_rate = 0
+ else:
+ success_rate = successful / duration
+ if failed == 0:
+ failure_rate = 0
+ else:
+ failure_rate = failed / duration
+
+ conversations = len(unique_conversations)
+
+ print("Total conversations: %10d" % conversations)
+ print("Successful operations: %10d (%.3f per second)"
+ % (successful, success_rate))
+ print("Failed operations: %10d (%.3f per second)"
+ % (failed, failure_rate))
+
+ for k, v in sorted(float_values.items()):
+ print("%-28s %f" % (k.replace('_', ' ') + ':', v))
+ for k, v in sorted(int_values.items()):
+ print("%-28s %d" % (k.replace('_', ' ') + ':', v))
+
+ print("Protocol Op Code Description "
+ " Count Failed Mean Median "
+ "95% Range Max")
+
+ ops = {}
+ for proto, packet in latencies:
+ if proto not in ops:
+ ops[proto] = set()
+ ops[proto].add(packet)
+ protocols = sorted(ops.keys())
+
+ for protocol in protocols:
+ packet_types = sorted(ops[protocol], key=opcode_key)
+ for packet_type in packet_types:
+ op = (protocol, packet_type)
+ values = latencies[op]
+ values = sorted(values)
+ count = len(values)
+ failed = failures[op]
+ mean = sum(values) / count
+ median = calc_percentile(values, 0.50)
+ percentile = calc_percentile(values, 0.95)
+ rng = values[-1] - values[0]
+ maxv = values[-1]
+ desc = OP_DESCRIPTIONS.get(op, '')
+ print("%-12s %4s %-35s %12d %12d %12.6f "
+ "%12.6f %12.6f %12.6f %12.6f"
+ % (protocol,
+ packet_type,
+ desc,
+ count,
+ failed,
+ mean,
+ median,
+ percentile,
+ rng,
+ maxv))
+
+
+def opcode_key(v):
+ """Sort key for the operation code to ensure that it sorts numerically"""
+ try:
+ return "%03d" % int(v)
+ except ValueError:
+ return v
+
+
+def calc_percentile(values, percentile):
+ """Calculate the specified percentile from the list of values.
+
+ Assumes the list is sorted in ascending order.
+ """
+
+ if not values:
+ return 0
+ k = (len(values) - 1) * percentile
+ f = math.floor(k)
+ c = math.ceil(k)
+ if f == c:
+ return values[int(k)]
+ d0 = values[int(f)] * (c - k)
+ d1 = values[int(c)] * (k - f)
+ return d0 + d1
+
+
+def mk_masked_dir(*path):
+ """In a testenv we end up with 0777 directories that look an alarming
+ green colour with ls. Use umask to avoid that."""
+ # py3 os.mkdir can do this
+ d = os.path.join(*path)
+ mask = os.umask(0o077)
+ os.mkdir(d)
+ os.umask(mask)
+ return d
diff --git a/python/samba/emulate/traffic_packets.py b/python/samba/emulate/traffic_packets.py
new file mode 100644
index 0000000..95c7465
--- /dev/null
+++ b/python/samba/emulate/traffic_packets.py
@@ -0,0 +1,973 @@
+# Dispatch for various request types.
+#
+# Copyright (C) Catalyst IT Ltd. 2017
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+import os
+import ctypes
+import random
+
+from samba.net import Net
+from samba.dcerpc import security, drsuapi, nbt, lsa, netlogon, ntlmssp
+from samba.dcerpc.netlogon import netr_WorkstationInformation
+from samba.dcerpc.security import dom_sid
+from samba.netbios import Node
+from samba.ndr import ndr_pack
+from samba.credentials import (
+ CLI_CRED_NTLMv2_AUTH,
+ MUST_USE_KERBEROS,
+ DONT_USE_KERBEROS
+)
+from samba import NTSTATUSError
+from samba.ntstatus import (
+ NT_STATUS_OBJECT_NAME_NOT_FOUND,
+ NT_STATUS_NO_SUCH_DOMAIN
+)
+import samba
+import dns.resolver
+from ldb import SCOPE_BASE
+
+def uint32(v):
+ return ctypes.c_uint32(v).value
+
+
+def check_runtime_error(runtime, val):
+ if runtime is None:
+ return False
+
+ err32 = uint32(runtime.args[0])
+ if err32 == val:
+ return True
+
+ return False
+
+
+name_formats = [
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_FQDN_1779,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_NT4_ACCOUNT,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_DISPLAY,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_GUID,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_CANONICAL,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_USER_PRINCIPAL,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_CANONICAL_EX,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_SERVICE_PRINCIPAL,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_SID_OR_SID_HISTORY,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_DNS_DOMAIN,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_UPN_AND_ALTSECID,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_NT4_ACCOUNT_NAME_SANS_DOMAIN_EX,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_GLOBAL_CATALOG_SERVERS,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_UPN_FOR_LOGON,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_SERVERS_WITH_DCS_IN_SITE,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_STRING_SID_NAME,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_ALT_SECURITY_IDENTITIES_NAME,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_NCS,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_DOMAINS,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_MAP_SCHEMA_GUID,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_NT4_ACCOUNT_NAME_SANS_DOMAIN,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_ROLES,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_INFO_FOR_SERVER,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_SERVERS_FOR_DOMAIN_IN_SITE,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_DOMAINS_IN_SITE,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_SERVERS_IN_SITE,
+ drsuapi.DRSUAPI_DS_NAME_FORMAT_LIST_SITES,
+]
+
+
+def warning(message):
+ print("\033[37;41;1m" "Warning: %s" "\033[00m" % (message))
+
+###############################################################################
+#
+# Packet generation functions:
+#
+# All the packet generation functions have the following form:
+# packet_${protocol}_${opcode}(packet, conversation, context)
+#
+# The functions return true, if statistics should be collected for the packet
+# false, the packet has been ignored.
+#
+# Where:
+# protocol is the protocol, i.e. cldap, dcerpc, ...
+# opcode is the protocol op code i.e. type of the packet to be
+# generated.
+#
+# packet contains data about the captured/generated packet
+# provides any extra data needed to generate the packet
+#
+# conversation Details of the current client/server interaction
+#
+# context state data for the current interaction
+#
+#
+#
+# The following protocols are not currently handled:
+# smb
+# smb2
+# browser
+# smb_netlogon
+#
+# The following drsuapi replication packets are currently ignored:
+# DsReplicaSync
+# DsGetNCChanges
+# DsReplicaUpdateRefs
+
+
+# Packet generators that do NOTHING are assigned to the null_packet
+# function which allows the conversation generators to notice this and
+# avoid a whole lot of pointless work.
+def null_packet(packet, conversation, context):
+ return False
+
+
+def packet_cldap_3(packet, conversation, context):
+ # searchRequest
+ net = Net(creds=context.creds, lp=context.lp)
+ net.finddc(domain=context.lp.get('realm'),
+ flags=(nbt.NBT_SERVER_LDAP |
+ nbt.NBT_SERVER_DS |
+ nbt.NBT_SERVER_WRITABLE))
+ return True
+
+
+packet_cldap_5 = null_packet
+# searchResDone
+
+packet_dcerpc_0 = null_packet
+# Request
+# Can be ignored, it's the continuation of an existing conversation
+
+packet_dcerpc_2 = null_packet
+# Request
+# Server response, so should be ignored
+
+packet_dcerpc_3 = null_packet
+
+packet_dcerpc_11 = null_packet
+# Bind
+# creation of the rpc dcerpc connection is managed by the higher level
+# protocol drivers. So we ignore it when generating traffic
+
+
+packet_dcerpc_12 = null_packet
+# Bind_ack
+# Server response, so should be ignored
+
+
+packet_dcerpc_13 = null_packet
+# Bind_nak
+# Server response, so should be ignored
+
+
+packet_dcerpc_14 = null_packet
+# Alter_context
+# Generated as part of the connect process
+
+
+def packet_dcerpc_15(packet, conversation, context):
+ # Alter_context_resp
+ # This means it was GSSAPI/krb5 (probably)
+ # Check the kerberos_state and issue a diagnostic if kerberos not enabled
+ if context.user_creds.get_kerberos_state() == DONT_USE_KERBEROS:
+ warning("Kerberos disabled but have dcerpc Alter_context_resp "
+ "indicating Kerberos was used")
+ return False
+
+
+def packet_dcerpc_16(packet, conversation, context):
+ # AUTH3
+ # This means it was NTLMSSP
+ # Check the kerberos_state and issue a diagnostic if kerberos enabled
+ if context.user_creds.get_kerberos_state() == MUST_USE_KERBEROS:
+ warning("Kerberos enabled but have dcerpc AUTH3 "
+ "indicating NTLMSSP was used")
+ return False
+
+
+def packet_dns_0(packet, conversation, context):
+ # query
+ name, rtype = context.guess_a_dns_lookup()
+ dns.resolver.query(name, rtype)
+ return True
+
+
+packet_dns_1 = null_packet
+# response
+# Server response, so should be ignored
+
+
+def packet_drsuapi_0(packet, conversation, context):
+ # DsBind
+ context.get_drsuapi_connection_pair(True)
+ return True
+
+
+NAME_FORMATS = [getattr(drsuapi, _x) for _x in dir(drsuapi)
+ if 'NAME_FORMAT' in _x]
+
+
+def packet_drsuapi_12(packet, conversation, context):
+ # DsCrackNames
+ drs, handle = context.get_drsuapi_connection_pair()
+
+ names = drsuapi.DsNameString()
+ names.str = context.server
+
+ req = drsuapi.DsNameRequest1()
+ req.format_flags = 0
+ req.format_offered = 7
+ req.format_desired = random.choice(name_formats)
+ req.codepage = 1252
+ req.language = 1033 # German, I think
+ req.format_flags = 0
+ req.count = 1
+ req.names = [names]
+
+ (result, ctr) = drs.DsCrackNames(handle, 1, req)
+ return True
+
+
+def packet_drsuapi_13(packet, conversation, context):
+ # DsWriteAccountSpn
+ req = drsuapi.DsWriteAccountSpnRequest1()
+ req.operation = drsuapi.DRSUAPI_DS_SPN_OPERATION_REPLACE
+ req.unknown1 = 0 # Unused, must be 0
+ req.object_dn = context.user_dn
+ req.count = 1 # only 1 name
+ spn_name = drsuapi.DsNameString()
+ spn_name.str = 'foo/{}'.format(context.username)
+ req.spn_names = [spn_name]
+ (drs, handle) = context.get_drsuapi_connection_pair()
+ (level, res) = drs.DsWriteAccountSpn(handle, 1, req)
+ return True
+
+
+def packet_drsuapi_1(packet, conversation, context):
+ # DsUnbind
+ (drs, handle) = context.get_drsuapi_connection_pair()
+ drs.DsUnbind(handle)
+ del context.drsuapi_connections[-1]
+ return True
+
+
+packet_drsuapi_2 = null_packet
+# DsReplicaSync
+# This is between DCs, triggered on a DB change
+# Ignoring for now
+
+
+packet_drsuapi_3 = null_packet
+# DsGetNCChanges
+# This is between DCs, trigger with DB operation,
+# or DsReplicaSync between DCs.
+# Ignoring for now
+
+
+packet_drsuapi_4 = null_packet
+# DsReplicaUpdateRefs
+# Ignoring for now
+
+
+packet_epm_3 = null_packet
+# Map
+# Will be generated by higher level protocol calls
+
+
+def packet_kerberos_(packet, conversation, context):
+ # Use the presence of kerberos packets as a hint to enable kerberos
+ # for the rest of the conversation.
+ # i.e. kerberos packets are not explicitly generated.
+ context.user_creds.set_kerberos_state(MUST_USE_KERBEROS)
+ context.user_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
+ context.machine_creds.set_kerberos_state(MUST_USE_KERBEROS)
+ context.machine_creds_bad.set_kerberos_state(MUST_USE_KERBEROS)
+ context.creds.set_kerberos_state(MUST_USE_KERBEROS)
+ return False
+
+
+packet_ldap_ = null_packet
+# Unknown
+# The ldap payload was probably encrypted so just ignore it.
+
+
+def packet_ldap_0(packet, conversation, context):
+ # bindRequest
+ if packet.extra[5] == "simple":
+ # Perform a simple bind.
+ context.get_ldap_connection(new=True, simple=True)
+ else:
+ # Perform a sasl bind.
+ context.get_ldap_connection(new=True, simple=False)
+ return True
+
+
+packet_ldap_1 = null_packet
+# bindResponse
+# Server response ignored for traffic generation
+
+
+def packet_ldap_2(packet, conversation, context):
+ # unbindRequest
+ # pop the last one off -- most likely we're in a bind/unbind ping.
+ del context.ldap_connections[-1:]
+ return False
+
+
+def packet_ldap_3(packet, conversation, context):
+ # searchRequest
+
+ (scope, dn_sig, filter, attrs, extra, desc, oid) = packet.extra
+ if not scope:
+ scope = SCOPE_BASE
+
+ samdb = context.get_ldap_connection()
+ dn = context.get_matching_dn(dn_sig)
+
+ # try to guess the search expression (don't bother for base searches, as
+ # they're only looking up a single object)
+ if (filter is None or filter == '') and scope != SCOPE_BASE:
+ filter = context.guess_search_filter(attrs, dn_sig, dn)
+
+ samdb.search(dn,
+ expression=filter,
+ scope=int(scope),
+ attrs=attrs.split(','),
+ controls=["paged_results:1:1000"])
+ return True
+
+
+packet_ldap_4 = null_packet
+# searchResEntry
+# Server response ignored for traffic generation
+
+
+packet_ldap_5 = null_packet
+# Server response ignored for traffic generation
+
+packet_ldap_6 = null_packet
+
+packet_ldap_7 = null_packet
+
+packet_ldap_8 = null_packet
+
+packet_ldap_9 = null_packet
+
+packet_ldap_16 = null_packet
+
+packet_lsarpc_0 = null_packet
+# lsarClose
+
+packet_lsarpc_1 = null_packet
+# lsarDelete
+
+packet_lsarpc_2 = null_packet
+# lsarEnumeratePrivileges
+
+packet_lsarpc_3 = null_packet
+# LsarQuerySecurityObject
+
+packet_lsarpc_4 = null_packet
+# LsarSetSecurityObject
+
+packet_lsarpc_5 = null_packet
+# LsarChangePassword
+
+packet_lsarpc_6 = null_packet
+# lsa_OpenPolicy
+# We ignore this, but take it as a hint that the lsarpc handle should
+# be over a named pipe.
+#
+
+
+def packet_lsarpc_14(packet, conversation, context):
+ # lsa_LookupNames
+ c = context.get_lsarpc_named_pipe_connection()
+
+ objectAttr = lsa.ObjectAttribute()
+ pol_handle = c.OpenPolicy2(u'', objectAttr,
+ security.SEC_FLAG_MAXIMUM_ALLOWED)
+
+ sids = lsa.TransSidArray()
+ names = [lsa.String("This Organization"),
+ lsa.String("Digest Authentication")]
+ level = lsa.LSA_LOOKUP_NAMES_ALL
+ count = 0
+ c.LookupNames(pol_handle, names, sids, level, count)
+ return True
+
+
+def packet_lsarpc_15(packet, conversation, context):
+ # lsa_LookupSids
+ c = context.get_lsarpc_named_pipe_connection()
+
+ objectAttr = lsa.ObjectAttribute()
+ pol_handle = c.OpenPolicy2(u'', objectAttr,
+ security.SEC_FLAG_MAXIMUM_ALLOWED)
+
+ sids = lsa.SidArray()
+ sid = lsa.SidPtr()
+
+ x = dom_sid("S-1-5-7")
+ sid.sid = x
+ sids.sids = [sid]
+ sids.num_sids = 1
+ names = lsa.TransNameArray()
+ level = lsa.LSA_LOOKUP_NAMES_ALL
+ count = 0
+
+ c.LookupSids(pol_handle, sids, names, level, count)
+ return True
+
+
+def packet_lsarpc_39(packet, conversation, context):
+ # lsa_QueryTrustedDomainInfoBySid
+ # Samba does not support trusted domains, so this call is expected to fail
+ #
+ c = context.get_lsarpc_named_pipe_connection()
+
+ objectAttr = lsa.ObjectAttribute()
+
+ pol_handle = c.OpenPolicy2(u'', objectAttr,
+ security.SEC_FLAG_MAXIMUM_ALLOWED)
+
+ domsid = security.dom_sid(context.domain_sid)
+ level = 1
+ try:
+ c.QueryTrustedDomainInfoBySid(pol_handle, domsid, level)
+ except NTSTATUSError as error:
+ # Object Not found is the expected result from samba,
+ # while No Such Domain is the expected result from windows,
+ # anything else is a failure.
+ if not check_runtime_error(error, NT_STATUS_OBJECT_NAME_NOT_FOUND) \
+ and not check_runtime_error(error, NT_STATUS_NO_SUCH_DOMAIN):
+ raise
+ return True
+
+
+packet_lsarpc_40 = null_packet
+# lsa_SetTrustedDomainInfo
+# Not currently supported
+
+
+packet_lsarpc_43 = null_packet
+# LsaStorePrivateData
+
+
+packet_lsarpc_44 = null_packet
+# LsaRetrievePrivateData
+
+
+packet_lsarpc_68 = null_packet
+# LsarLookupNames3
+
+
+def packet_lsarpc_76(packet, conversation, context):
+ # lsa_LookupSids3
+ c = context.get_lsarpc_connection()
+ sids = lsa.SidArray()
+ sid = lsa.SidPtr()
+ # Need a set
+ x = dom_sid("S-1-5-7")
+ sid.sid = x
+ sids.sids = [sid]
+ sids.num_sids = 1
+ names = lsa.TransNameArray2()
+ level = lsa.LSA_LOOKUP_NAMES_ALL
+ count = 0
+ lookup_options = lsa.LSA_LOOKUP_OPTION_SEARCH_ISOLATED_NAMES
+ client_revision = lsa.LSA_CLIENT_REVISION_2
+ c.LookupSids3(sids, names, level, count, lookup_options, client_revision)
+ return True
+
+
+def packet_lsarpc_77(packet, conversation, context):
+ # lsa_LookupNames4
+ c = context.get_lsarpc_connection()
+ sids = lsa.TransSidArray3()
+ names = [lsa.String("This Organization"),
+ lsa.String("Digest Authentication")]
+ level = lsa.LSA_LOOKUP_NAMES_ALL
+ count = 0
+ lookup_options = lsa.LSA_LOOKUP_OPTION_SEARCH_ISOLATED_NAMES
+ client_revision = lsa.LSA_CLIENT_REVISION_2
+ c.LookupNames4(names, sids, level, count, lookup_options, client_revision)
+ return True
+
+
+def packet_nbns_0(packet, conversation, context):
+ # query
+ n = Node()
+ try:
+ n.query_name("ANAME", context.server, timeout=4, broadcast=False)
+ except:
+ pass
+ return True
+
+
+packet_nbns_1 = null_packet
+# response
+# Server response, not generated by the client
+
+
+packet_rpc_netlogon_0 = null_packet
+
+packet_rpc_netlogon_1 = null_packet
+
+packet_rpc_netlogon_4 = null_packet
+# NetrServerReqChallenge
+# generated by higher level protocol drivers
+# ignored for traffic generation
+
+packet_rpc_netlogon_14 = null_packet
+
+packet_rpc_netlogon_15 = null_packet
+
+packet_rpc_netlogon_21 = null_packet
+# NetrLogonDummyRoutine1
+# Used to determine security settings. Triggered from schannel setup
+# So no need for an explicit generator
+
+
+packet_rpc_netlogon_26 = null_packet
+# NetrServerAuthenticate3
+# Triggered from schannel set up, no need for an explicit generator
+
+
+def packet_rpc_netlogon_29(packet, conversation, context):
+ # NetrLogonGetDomainInfo [531]
+ c = context.get_netlogon_connection()
+ (auth, succ) = context.get_authenticator()
+ query = netr_WorkstationInformation()
+
+ c.netr_LogonGetDomainInfo(context.server,
+ context.netbios_name,
+ auth,
+ succ,
+ 2, # TODO are there other values?
+ query)
+ return True
+
+
+def packet_rpc_netlogon_30(packet, conversation, context):
+ # NetrServerPasswordSet2
+ c = context.get_netlogon_connection()
+ (auth, succ) = context.get_authenticator()
+ DATA_LEN = 512
+ # Set the new password to the existing password, this generates the same
+ # work load as a new value, and leaves the account password intact for
+ # subsequent runs
+ newpass = context.machine_creds.get_password().encode('utf-16-le')
+ pwd_len = len(newpass)
+ filler = [x if isinstance(x, int) else ord(x) for x in os.urandom(DATA_LEN - pwd_len)]
+ pwd = netlogon.netr_CryptPassword()
+ pwd.length = pwd_len
+ pwd.data = filler + [x if isinstance(x, int) else ord(x) for x in newpass]
+ context.machine_creds.encrypt_netr_crypt_password(pwd)
+ c.netr_ServerPasswordSet2(context.server,
+ # must ends with $, so use get_username instead
+ # of get_workstation here
+ context.machine_creds.get_username(),
+ context.machine_creds.get_secure_channel_type(),
+ context.netbios_name,
+ auth,
+ pwd)
+ return True
+
+
+packet_rpc_netlogon_34 = null_packet
+
+
+def packet_rpc_netlogon_39(packet, conversation, context):
+ # NetrLogonSamLogonEx [4331]
+ def connect(creds):
+ c = context.get_netlogon_connection()
+
+ # Disable Kerberos in cli creds to extract NTLM response
+ old_state = creds.get_kerberos_state()
+ creds.set_kerberos_state(DONT_USE_KERBEROS)
+
+ logon = samlogon_logon_info(context.domain,
+ context.netbios_name,
+ creds)
+ logon_level = netlogon.NetlogonNetworkTransitiveInformation
+ validation_level = netlogon.NetlogonValidationSamInfo4
+ netr_flags = 0
+ c.netr_LogonSamLogonEx(context.server,
+ context.machine_creds.get_workstation(),
+ logon_level,
+ logon,
+ validation_level,
+ netr_flags)
+
+ creds.set_kerberos_state(old_state)
+
+ context.last_samlogon_bad =\
+ context.with_random_bad_credentials(connect,
+ context.user_creds,
+ context.user_creds_bad,
+ context.last_samlogon_bad)
+ return True
+
+
+def samlogon_target(domain_name, computer_name):
+ target_info = ntlmssp.AV_PAIR_LIST()
+ target_info.count = 3
+ computername = ntlmssp.AV_PAIR()
+ computername.AvId = ntlmssp.MsvAvNbComputerName
+ computername.Value = computer_name
+
+ domainname = ntlmssp.AV_PAIR()
+ domainname.AvId = ntlmssp.MsvAvNbDomainName
+ domainname.Value = domain_name
+
+ eol = ntlmssp.AV_PAIR()
+ eol.AvId = ntlmssp.MsvAvEOL
+ target_info.pair = [domainname, computername, eol]
+
+ return ndr_pack(target_info)
+
+
+def samlogon_logon_info(domain_name, computer_name, creds):
+
+ target_info_blob = samlogon_target(domain_name, computer_name)
+
+ challenge = b"abcdefgh"
+ # User account under test
+ response = creds.get_ntlm_response(flags=CLI_CRED_NTLMv2_AUTH,
+ challenge=challenge,
+ target_info=target_info_blob)
+
+ logon = netlogon.netr_NetworkInfo()
+
+ logon.challenge = [x if isinstance(x, int) else ord(x) for x in challenge]
+ logon.nt = netlogon.netr_ChallengeResponse()
+ logon.nt.length = len(response["nt_response"])
+ logon.nt.data = [x if isinstance(x, int) else ord(x) for x in response["nt_response"]]
+
+ logon.identity_info = netlogon.netr_IdentityInfo()
+
+ (username, domain) = creds.get_ntlm_username_domain()
+ logon.identity_info.domain_name.string = domain
+ logon.identity_info.account_name.string = username
+ logon.identity_info.workstation.string = creds.get_workstation()
+
+ return logon
+
+
+def packet_rpc_netlogon_40(packet, conversation, context):
+ # DsrEnumerateDomainTrusts
+ c = context.get_netlogon_connection()
+ c.netr_DsrEnumerateDomainTrusts(
+ context.server,
+ netlogon.NETR_TRUST_FLAG_IN_FOREST |
+ netlogon.NETR_TRUST_FLAG_OUTBOUND |
+ netlogon.NETR_TRUST_FLAG_INBOUND)
+ return True
+
+
+def packet_rpc_netlogon_45(packet, conversation, context):
+ # NetrLogonSamLogonWithFlags [7]
+ def connect(creds):
+ c = context.get_netlogon_connection()
+ (auth, succ) = context.get_authenticator()
+
+ # Disable Kerberos in cli creds to extract NTLM response
+ old_state = creds.get_kerberos_state()
+ creds.set_kerberos_state(DONT_USE_KERBEROS)
+
+ logon = samlogon_logon_info(context.domain,
+ context.netbios_name,
+ creds)
+ logon_level = netlogon.NetlogonNetworkTransitiveInformation
+ validation_level = netlogon.NetlogonValidationSamInfo4
+ netr_flags = 0
+ c.netr_LogonSamLogonWithFlags(context.server,
+ context.machine_creds.get_workstation(),
+ auth,
+ succ,
+ logon_level,
+ logon,
+ validation_level,
+ netr_flags)
+
+ creds.set_kerberos_state(old_state)
+
+ context.last_samlogon_bad =\
+ context.with_random_bad_credentials(connect,
+ context.user_creds,
+ context.user_creds_bad,
+ context.last_samlogon_bad)
+ return True
+
+
+def packet_samr_0(packet, conversation, context):
+ # Open
+ c = context.get_samr_context()
+ c.get_handle()
+ return True
+
+
+def packet_samr_1(packet, conversation, context):
+ # Close
+ c = context.get_samr_context()
+ s = c.get_connection()
+ # close the last opened handle, may not always be accurate
+ # but will do for load simulation
+ if c.user_handle is not None:
+ s.Close(c.user_handle)
+ c.user_handle = None
+ elif c.group_handle is not None:
+ s.Close(c.group_handle)
+ c.group_handle = None
+ elif c.domain_handle is not None:
+ s.Close(c.domain_handle)
+ c.domain_handle = None
+ c.rids = None
+ elif c.handle is not None:
+ s.Close(c.handle)
+ c.handle = None
+ c.domain_sid = None
+ return True
+
+
+def packet_samr_3(packet, conversation, context):
+ # QuerySecurity
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.user_handle is None:
+ packet_samr_34(packet, conversation, context)
+ s.QuerySecurity(c.user_handle, 1)
+ return True
+
+
+def packet_samr_5(packet, conversation, context):
+ # LookupDomain
+ c = context.get_samr_context()
+ s = c.get_connection()
+ h = c.get_handle()
+ d = lsa.String()
+ d.string = context.domain
+ c.domain_sid = s.LookupDomain(h, d)
+ return True
+
+
+def packet_samr_6(packet, conversation, context):
+ # EnumDomains
+ c = context.get_samr_context()
+ s = c.get_connection()
+ h = c.get_handle()
+ s.EnumDomains(h, 0, 0)
+ return True
+
+
+def packet_samr_7(packet, conversation, context):
+ # OpenDomain
+ c = context.get_samr_context()
+ s = c.get_connection()
+ h = c.get_handle()
+ if c.domain_sid is None:
+ packet_samr_5(packet, conversation, context)
+
+ c.domain_handle = s.OpenDomain(h,
+ security.SEC_FLAG_MAXIMUM_ALLOWED,
+ c.domain_sid)
+ return True
+
+
+SAMR_QUERY_DOMAIN_INFO_LEVELS = [8, 12]
+
+
+def packet_samr_8(packet, conversation, context):
+ # QueryDomainInfo [228]
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.domain_handle is None:
+ packet_samr_7(packet, conversation, context)
+ level = random.choice(SAMR_QUERY_DOMAIN_INFO_LEVELS)
+ s.QueryDomainInfo(c.domain_handle, level)
+ return True
+
+
+packet_samr_14 = null_packet
+# CreateDomainAlias
+# Ignore these for now.
+
+
+def packet_samr_15(packet, conversation, context):
+ # EnumDomainAliases
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.domain_handle is None:
+ packet_samr_7(packet, conversation, context)
+
+ s.EnumDomainAliases(c.domain_handle, 100, 0)
+ return True
+
+
+def packet_samr_16(packet, conversation, context):
+ # GetAliasMembership
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.domain_handle is None:
+ packet_samr_7(packet, conversation, context)
+
+ sids = lsa.SidArray()
+ sid = lsa.SidPtr()
+ sid.sid = c.domain_sid
+ sids.sids = [sid]
+ s.GetAliasMembership(c.domain_handle, sids)
+ return True
+
+
+def packet_samr_17(packet, conversation, context):
+ # LookupNames
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.domain_handle is None:
+ packet_samr_7(packet, conversation, context)
+
+ name = lsa.String(context.username)
+ c.rids = s.LookupNames(c.domain_handle, [name])
+ return True
+
+
+def packet_samr_18(packet, conversation, context):
+ # LookupRids
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.rids is None:
+ packet_samr_17(packet, conversation, context)
+ rids = []
+ for r in c.rids:
+ for i in r.ids:
+ rids.append(i)
+ s.LookupRids(c.domain_handle, rids)
+ return True
+
+
+def packet_samr_19(packet, conversation, context):
+ # OpenGroup
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.domain_handle is None:
+ packet_samr_7(packet, conversation, context)
+
+ rid = 0x202 # Users I think.
+ c.group_handle = s.OpenGroup(c.domain_handle,
+ security.SEC_FLAG_MAXIMUM_ALLOWED,
+ rid)
+ return True
+
+
+def packet_samr_25(packet, conversation, context):
+ # QueryGroupMember
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.group_handle is None:
+ packet_samr_19(packet, conversation, context)
+ s.QueryGroupMember(c.group_handle)
+ return True
+
+
+def packet_samr_34(packet, conversation, context):
+ # OpenUser
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.rids is None:
+ packet_samr_17(packet, conversation, context)
+ c.user_handle = s.OpenUser(c.domain_handle,
+ security.SEC_FLAG_MAXIMUM_ALLOWED,
+ c.rids[0].ids[0])
+ return True
+
+
+def packet_samr_36(packet, conversation, context):
+ # QueryUserInfo
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.user_handle is None:
+ packet_samr_34(packet, conversation, context)
+ level = 1
+ s.QueryUserInfo(c.user_handle, level)
+ return True
+
+
+packet_samr_37 = null_packet
+
+
+def packet_samr_39(packet, conversation, context):
+ # GetGroupsForUser
+ c = context.get_samr_context()
+ s = c.get_connection()
+ if c.user_handle is None:
+ packet_samr_34(packet, conversation, context)
+ s.GetGroupsForUser(c.user_handle)
+ return True
+
+
+packet_samr_40 = null_packet
+
+packet_samr_44 = null_packet
+
+
+def packet_samr_57(packet, conversation, context):
+ # Connect2
+ c = context.get_samr_context()
+ c.get_handle()
+ return True
+
+
+def packet_samr_64(packet, conversation, context):
+ # Connect5
+ c = context.get_samr_context()
+ c.get_handle()
+ return True
+
+
+packet_samr_68 = null_packet
+
+
+def packet_srvsvc_16(packet, conversation, context):
+ # NetShareGetInfo
+ s = context.get_srvsvc_connection()
+ server_unc = "\\\\" + context.server
+ share_name = "IPC$"
+ level = 1
+ s.NetShareGetInfo(server_unc, share_name, level)
+ return True
+
+
+def packet_srvsvc_21(packet, conversation, context):
+ """NetSrvGetInfo
+
+ FIXME: Level changed from 102 to 101 here, to bypass Windows error.
+
+ Level 102 will cause WERR_ACCESS_DENIED error against Windows, because:
+
+ > If the level is 102 or 502, the Windows implementation checks whether
+ > the caller is a member of one of the groups previously mentioned or
+ > is a member of the Power Users local group.
+
+ It passed against Samba since this check is not implemented by Samba yet.
+
+ refer to:
+
+ https://msdn.microsoft.com/en-us/library/cc247297.aspx#Appendix_A_80
+
+ """
+ srvsvc = context.get_srvsvc_connection()
+ server_unc = "\\\\" + context.server
+ level = 101
+ srvsvc.NetSrvGetInfo(server_unc, level)
+ return True