summaryrefslogtreecommitdiffstats
path: root/third_party/python/mohawk/mohawk/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/python/mohawk/mohawk/util.py')
-rw-r--r--third_party/python/mohawk/mohawk/util.py267
1 files changed, 267 insertions, 0 deletions
diff --git a/third_party/python/mohawk/mohawk/util.py b/third_party/python/mohawk/mohawk/util.py
new file mode 100644
index 0000000000..46a28e94ce
--- /dev/null
+++ b/third_party/python/mohawk/mohawk/util.py
@@ -0,0 +1,267 @@
+from base64 import b64encode, urlsafe_b64encode
+import calendar
+import hashlib
+import hmac
+import logging
+import math
+import os
+import pprint
+import re
+import sys
+import time
+
+import six
+
+from .exc import (
+ BadHeaderValue,
+ HawkFail,
+ InvalidCredentials)
+
+
+HAWK_VER = 1
+log = logging.getLogger(__name__)
+allowable_header_keys = set(['id', 'ts', 'tsm', 'nonce', 'hash',
+ 'error', 'ext', 'mac', 'app', 'dlg'])
+
+
+def validate_credentials(creds):
+ if not hasattr(creds, '__getitem__'):
+ raise InvalidCredentials('credentials must be a dict-like object')
+ try:
+ creds['id']
+ creds['key']
+ creds['algorithm']
+ except KeyError:
+ etype, val, tb = sys.exc_info()
+ raise InvalidCredentials('{etype}: {val}'
+ .format(etype=etype, val=val))
+
+
+def random_string(length):
+ """Generates a random string for a given length."""
+ # this conservatively gets 8*length bits and then returns 6*length of
+ # them. Grabbing (6/8)*length bits could lose some entropy off the ends.
+ return urlsafe_b64encode(os.urandom(length))[:length]
+
+
+def calculate_payload_hash(payload, algorithm, content_type):
+ """Calculates a hash for a given payload."""
+ p_hash = hashlib.new(algorithm)
+
+ parts = []
+ parts.append('hawk.' + str(HAWK_VER) + '.payload\n')
+ parts.append(parse_content_type(content_type) + '\n')
+ parts.append(payload or '')
+ parts.append('\n')
+
+ for i, p in enumerate(parts):
+ # Make sure we are about to hash binary strings.
+ if not isinstance(p, six.binary_type):
+ p = p.encode('utf8')
+ p_hash.update(p)
+ parts[i] = p
+
+ log.debug('calculating payload hash from:\n{parts}'
+ .format(parts=pprint.pformat(parts)))
+
+ return b64encode(p_hash.digest())
+
+
+def calculate_mac(mac_type, resource, content_hash):
+ """Calculates a message authorization code (MAC)."""
+ normalized = normalize_string(mac_type, resource, content_hash)
+ log.debug(u'normalized resource for mac calc: {norm}'
+ .format(norm=normalized))
+ digestmod = getattr(hashlib, resource.credentials['algorithm'])
+
+ # Make sure we are about to hash binary strings.
+
+ if not isinstance(normalized, six.binary_type):
+ normalized = normalized.encode('utf8')
+ key = resource.credentials['key']
+ if not isinstance(key, six.binary_type):
+ key = key.encode('ascii')
+
+ result = hmac.new(key, normalized, digestmod)
+ return b64encode(result.digest())
+
+
+def calculate_ts_mac(ts, credentials):
+ """Calculates a message authorization code (MAC) for a timestamp."""
+ normalized = ('hawk.{hawk_ver}.ts\n{ts}\n'
+ .format(hawk_ver=HAWK_VER, ts=ts))
+ log.debug(u'normalized resource for ts mac calc: {norm}'
+ .format(norm=normalized))
+ digestmod = getattr(hashlib, credentials['algorithm'])
+
+ if not isinstance(normalized, six.binary_type):
+ normalized = normalized.encode('utf8')
+ key = credentials['key']
+ if not isinstance(key, six.binary_type):
+ key = key.encode('ascii')
+
+ result = hmac.new(key, normalized, digestmod)
+ return b64encode(result.digest())
+
+
+def normalize_string(mac_type, resource, content_hash):
+ """Serializes mac_type and resource into a HAWK string."""
+
+ normalized = [
+ 'hawk.' + str(HAWK_VER) + '.' + mac_type,
+ normalize_header_attr(resource.timestamp),
+ normalize_header_attr(resource.nonce),
+ normalize_header_attr(resource.method or ''),
+ normalize_header_attr(resource.name or ''),
+ normalize_header_attr(resource.host),
+ normalize_header_attr(resource.port),
+ normalize_header_attr(content_hash or '')
+ ]
+
+ # The blank lines are important. They follow what the Node Hawk lib does.
+
+ normalized.append(normalize_header_attr(resource.ext or ''))
+
+ if resource.app:
+ normalized.append(normalize_header_attr(resource.app))
+ normalized.append(normalize_header_attr(resource.dlg or ''))
+
+ # Add trailing new line.
+ normalized.append('')
+
+ normalized = '\n'.join(normalized)
+
+ return normalized
+
+
+def parse_content_type(content_type):
+ """Cleans up content_type."""
+ if content_type:
+ return content_type.split(';')[0].strip().lower()
+ else:
+ return ''
+
+
+def parse_authorization_header(auth_header):
+ """
+ Example Authorization header:
+
+ 'Hawk id="dh37fgj492je", ts="1367076201", nonce="NPHgnG", ext="and
+ welcome!", mac="CeWHy4d9kbLGhDlkyw2Nh3PJ7SDOdZDa267KH4ZaNMY="'
+ """
+ attributes = {}
+
+ # Make sure we have a unicode object for consistency.
+ if isinstance(auth_header, six.binary_type):
+ auth_header = auth_header.decode('utf8')
+
+ parts = auth_header.split(',')
+ auth_scheme_parts = parts[0].split(' ')
+ if 'hawk' != auth_scheme_parts[0].lower():
+ raise HawkFail("Unknown scheme '{scheme}' when parsing header"
+ .format(scheme=auth_scheme_parts[0].lower()))
+
+ # Replace 'Hawk key: value' with 'key: value'
+ # which matches the rest of parts
+ parts[0] = auth_scheme_parts[1]
+
+ for part in parts:
+ attr_parts = part.split('=')
+ key = attr_parts[0].strip()
+ if key not in allowable_header_keys:
+ raise HawkFail("Unknown Hawk key '{key}' when parsing header"
+ .format(key=key))
+
+ if len(attr_parts) > 2:
+ attr_parts[1] = '='.join(attr_parts[1:])
+
+ # Chop of quotation marks
+ value = attr_parts[1]
+
+ if attr_parts[1].find('"') == 0:
+ value = attr_parts[1][1:]
+
+ if value.find('"') > -1:
+ value = value[0:-1]
+
+ validate_header_attr(value, name=key)
+ value = unescape_header_attr(value)
+ attributes[key] = value
+
+ log.debug('parsed Hawk header: {header} into: \n{parsed}'
+ .format(header=auth_header, parsed=pprint.pformat(attributes)))
+ return attributes
+
+
+def strings_match(a, b):
+ # Constant time string comparision, mitigates side channel attacks.
+ if len(a) != len(b):
+ return False
+ result = 0
+
+ def byte_ints(buf):
+ for ch in buf:
+ # In Python 3, if we have a bytes object, iterating it will
+ # already get the integer value. In older pythons, we need
+ # to use ord().
+ if not isinstance(ch, int):
+ ch = ord(ch)
+ yield ch
+
+ for x, y in zip(byte_ints(a), byte_ints(b)):
+ result |= x ^ y
+ return result == 0
+
+
+def utc_now(offset_in_seconds=0.0):
+ # TODO: add support for SNTP server? See ntplib module.
+ return int(math.floor(calendar.timegm(time.gmtime()) +
+ float(offset_in_seconds)))
+
+
+# Allowed value characters:
+# !#$%&'()*+,-./:;<=>?@[]^_`{|}~ and space, a-z, A-Z, 0-9, \, "
+_header_attribute_chars = re.compile(
+ r"^[ a-zA-Z0-9_\!#\$%&'\(\)\*\+,\-\./\:;<\=>\?@\[\]\^`\{\|\}~\"\\]*$")
+
+
+def validate_header_attr(val, name=None):
+ if not _header_attribute_chars.match(val):
+ raise BadHeaderValue('header value name={name} value={val} '
+ 'contained an illegal character'
+ .format(name=name or '?', val=repr(val)))
+
+
+def escape_header_attr(val):
+
+ # Ensure we are working with Unicode for consistency.
+ if isinstance(val, six.binary_type):
+ val = val.decode('utf8')
+
+ # Escape quotes and slash like the hawk reference code.
+ val = val.replace('\\', '\\\\')
+ val = val.replace('"', '\\"')
+ val = val.replace('\n', '\\n')
+ return val
+
+
+def unescape_header_attr(val):
+ # Un-do the hawk escaping.
+ val = val.replace('\\n', '\n')
+ val = val.replace('\\\\', '\\').replace('\\"', '"')
+ return val
+
+
+def prepare_header_val(val):
+ val = escape_header_attr(val)
+ validate_header_attr(val)
+ return val
+
+
+def normalize_header_attr(val):
+ if not val:
+ val = ''
+
+ # Normalize like the hawk reference code.
+ val = escape_header_attr(val)
+ return val