#!/usr/bin/env python3 # vim: et ts=4 sw=4 sts=4 # # import from obsolete JSON KASP to LMDB-beckended KASP database. # from __future__ import print_function import datetime import time import json import sys import re import glob import argparse import time import traceback import os import hashlib import importlib import codecs opt_force = False lmdb = None def lmdb_requirement(): global lmdb try: lmdb = importlib.import_module('lmdb') except ImportError: print("Error: unable to import module LMDB.") print("Probably you need to 'apt install python3-lmdb'.") sys.exit(10) # workarounding that python 2 doesn't have int.to_bytes() def to_bytes(n, length, endianness='big'): h = '%x' % n assert len(h) <= length * 2 s = ('0'*(len(h) % 2) + h).zfill(length * 2) if sys.version_info >= (3,0): sb = codecs.decode(s, 'hex') else: sb = s.decode('hex') return bytearray(sb) if endianness == 'big' else bytearray(sb[::-1]) def from_bytes(ba, endianness='big'): x = ba if endianness == 'big' else bytearray(s[::-1]) if sys.version_info >= (3,0): hx = codecs.encode(x, 'hex') else: hx = str(x).encode('hex') return int(hx, 16) # aka knot_dname_from_str_alloc() def str2dname(s): if s.endswith('.') is False: s += '.' res = bytearray(b"") nodes = s.lower().split('.') if nodes[-1] != "": nodes.append("") for node in nodes: res.append(len(node)) res.extend(bytearray(node.lower(), 'ascii')) return res def dname2str(dn): res = "" beg = 0 end = ord(dn[0]) + 1 while ord(dn[beg]) > 0: res += str(dn[beg+1:end]) + "." beg = end end = beg + ord(dn[beg]) + 1 return res # this is just helper for shuffling time def shuffle_unixtime(base_time, shuffle_years, shuffle_months): rsm = shuffle_months + 12 * shuffle_years dt = datetime.datetime.fromtimestamp(base_time) newmonth = (dt.month - 1 + rsm) % 12 + 1 # in python, % always returns [0, 11] sameyear = dt.month + rsm % 12 newyear = dt.year + rsm // 12 + (0 if sameyear in range(1, 13) else 1) # in python, (-1)//12 = -1 dt2 = dt.replace(month=newmonth, year=newyear) print(dt2.month, "/", dt2.year) ttuple = dt2.timetuple() return int(time.mktime(ttuple)) def timespec2unix(spec): if re.match(r"^\d+$", spec): return int(spec) now = int(time.time()) s = re.sub(r"^now", "t", spec) if s == "t": return now unitmap = { "" : 1, "mi" : 60, "h" : 3600, "d" : 86400 } unitmap_mo = { "mo" : 1, "y" : 12 } if re.match(r"^t[-+]\d+", s): unit = re.sub(r"^t[-+]\d+", "", s) cutend = len(s) if unit == "" else -len(unit) if unit in list(unitmap.keys()): return now + int(s[1:cutend]) * unitmap[unit] elif unit in list(unitmap_mo.keys()): return shuffle_unixtime(now, 0, int(s[1:cutend]) * unitmap_mo[unit]) else: print("Error in time unit specification") print("Error in time specification") class Keykey: '''Kasp DB key serialized (type, zone_name, key_id)''' def __init__(self, raw_bytearray): self.raw = bytearray(raw_bytearray) @classmethod def from_params(self, valtype, zone_name, key_id): selfraw = to_bytes(valtype, 1) if zone_name is not None: selfraw.extend(zone_name) if key_id is not None: selfraw.extend(bytearray(key_id.encode("ascii"))) selfraw.append(0) return Keykey(selfraw) def getRaw(self): return bytearray(self.raw) def getType(self): return self.raw[0] def __getSplit(self): x = self.raw.find(to_bytes(0, 1)) assert x > 0 return x + 1 def getZone(self): if self.getType() == 2: return None return str(self.raw[1:self.__getSplit()]) def getKeyid(self): if self.getType() != 1: return None return str(self.raw[self.__getSplit():]) class Keyparams: '''Serialized key parameters for kasp-db.''' def __init__(self, raw_bytearray): self.raw = bytearray(raw_bytearray) self.timers_dict = { "created" : [ 0, 20, 28 ], "publish" : [ 1, 28, 36 ], "ready" : [ 2, 36, 44 ], "active" : [ 3, 44, 52 ], "retire" : [ 4, 52, 60 ], "remove" : [ 5, 60, 68 ] } @classmethod def from_params(self, pubkey, keytag, algorithm, isksk, timers): assert len(timers) == 6 if sys.version_info >= (3,0): pk = codecs.decode(bytearray(pubkey, 'ascii'), "base64") else: pk = pubkey.decode("base64") selfraw = to_bytes(len(pk), 8) selfraw.extend(to_bytes(0, 8)) # zero length of unused-future selfraw.extend(to_bytes(int(keytag), 2)) selfraw.extend(to_bytes(int(algorithm), 1)) selfraw.extend(to_bytes((1 if isksk else 0), 1)) for t in timers: if t < 0: print("keytag=%i timers=(%i, %i, %i, %i, %i, %i)" % (keytag, timers[0], timers[1], timers[2], timers[3], timers[4], timers[5])) assert False selfraw.extend(to_bytes(t, 8)) selfraw.extend(pk) return Keyparams(selfraw) def _check(self): assert len(self.raw) >= 16 pkl = from_bytes(self.raw[0:8]) ufl = from_bytes(self.raw[8:16]) assert len(self.raw) == 68 + pkl + ufl assert self.raw[19] < 2 def getRaw(self): self._check() return bytearray(self.raw) def getAlgorithm(self): self._check() return int(self.raw[18]) def setAlgorithm(self, algorithm): self._check() self.raw[18] = to_bytes(algorithm, 1)[0] def isKSK(self): self._check() return 1 if self.raw[19] != 0 else 0 def setKSK(self, isksk): self._check() self.raw[11] = (b"\01" if isksk else b"\00")[0] def getKeytag(self): self._check() return from_bytes(self.raw[16:18]) def setKeytag(self, keytag): self._check() self.raw[16:18] = to_bytes(keytag, 2) def getTimers(self): self._check() res = [ 0, 0, 0, 0, 0, 0 ] for i, x, y in list(self.timers_dict.values()): res[i] = from_bytes(self.raw[x:y]) return res def getTimersString(self): self._check() res = "[" for ti in list(self.timers_dict.keys()): _, x, y = self.timers_dict[ti]; res += (" " if res == "[" else ", ") + ti + ": " + str(from_bytes(self.raw[x:y])) return res + " ]" def setTimers(self, timers): self._check() assert len(timers) == 5 for i, x, y in list(self.timers_dict.values()): self.raw[x:y] = to_bytes(timers[i], 8) def getPubKey(self): self._check() pkl = from_bytes(self.raw[0:8]) return self.raw[68:68+pkl].encode("base64") def getParams(self): return [ self.getPubKey(), self.getKeytag(), self.getAlgorithm(), self.isKSK(), self.getTimers() ]; def setByParamName(self, param_name, new_val): if param_name == "algorithm": self.setAlgorithm(int(new_val)) elif param_name == "isksk": if new_val in ("1", "True", "true", "on", "yes", "Yes"): self.setKSK(True) elif new_val in ("0", "False", "false", "off", "no", "No"): self.setKSK(False) else: print("Error: bad true/false value", new_val) elif param_name == "keytag": self.setKeytag(int(new_val)) elif param_name in list(self.timers_dict.keys()): _, x, y = self.timers_dict[param_name] self.raw[x:y] = to_bytes(timespec2unix(new_val), 8) else: print("Error: bad parameter", param_name) def computeDS(self, zone_str, digestalg): ds_raw = bytearray(str2dname(zone_str)) ds_raw.extend(to_bytes(257 if self.isKSK() else 256, 2)) ds_raw.extend(b"\x03") # protocol is always == 3 ds_raw.extend(self.raw[18:19]) # algorithm pkl = from_bytes(self.raw[0:8]) ds_raw.extend(self.raw[68:68+pkl]) # pubkey if digestalg == "sha1": ds_hash = hashlib.sha1(ds_raw).hexdigest() algno = " 1 " elif digestalg == "sha256": ds_hash = hashlib.sha256(ds_raw).hexdigest() algno = " 2 " elif digestalg == "sha384": ds_hash = hashlib.sha384(ds_raw).hexdigest() algno = " 4 " else: print("Error: bad DS digest algorith", ds_hash) return return zone_str + ' DS ' + str(self.getKeytag()) + ' ' + str(self.getAlgorithm()) + algno + ds_hash def isPublished(self, moment): tmrs = self.getTimers() if tmrs[self.timers_dict["publish"][0]] <= moment: if moment < tmrs[self.timers_dict["remove"][0]]: return True return False def isReady(self, moment): tmrs = self.getTimers() if tmrs[self.timers_dict["ready"][0]] <= moment: if moment < tmrs[self.timers_dict["ready"][0]]: return True return False def isActive(self, moment): tmrs = self.getTimers() if tmrs[self.timers_dict["active"][0]] <= moment: if moment < tmrs[self.timers_dict["retire"][0]]: return True return False def isRetired(self, moment): tmrs = self.getTimers() if tmrs[self.timers_dict["retire"][0]] <= moment: return True return False def isRemoved(self, moment): tmrs = self.getTimers() if tmrs[self.timers_dict["remove"][0]] <= moment: return True return False # static: just for use in following method def arr_ind2unix(arr, ind, defaul): try: ttuple = datetime.datetime.strptime(arr[ind], "%Y-%m-%dT%H:%M:%S+0000").timetuple() res = int(time.mktime(ttuple)) return res if res >= 0 else 0 except KeyError: return defaul def import_nsec3salt(keys, env, db_keys, zname): try: with lmdb.Transaction(env, db_keys, write=True) as txn_keys: dbk1 = Keykey.from_params(3, zname, None).getRaw() dbv1 = keys["nsec3_salt"] if dbv1 is None: return if sys.version_info >= (3,0): dbv1d = codecs.decode(bytearray(dbv1, 'ascii'), "base64") else: dbv1d = dbv1.decode("base64") txn_keys.put(dbk1, dbv1d, dupdata=False, overwrite=True) dbk2 = Keykey.from_params(4, zname, None).getRaw() dbv2 = to_bytes(arr_ind2unix(keys, "nsec3_salt_created", 0), 8) txn_keys.put(dbk2, dbv2, dupdata=False, overwrite=True) except (KeyError, AttributeError): pass # nsec3salt not configured or set to null, no problem # import single JSON zone config into open LMDB env def import_file(fname, env, db_keys): try: with open(fname) as f: keys = json.load(f) except ValueError: print("Warning: not imported ", fname) return False try: zname_str = re.sub(r'^zone_', '', re.sub(r'\.json$', '', re.sub(r'.*/', '', fname))) print("Importing zone", zname_str) zname = str2dname(zname_str) import_nsec3salt(keys, env, db_keys, zname) import_now = int(time.time()) for key in keys["keys"]: dbk3 = Keykey.from_params(1, zname, key["id"]).getRaw() infty = 0x00ffffffffffff00 # time infinity, this is year 142'715'360 dbv3 = Keyparams.from_params(key["public_key"], key["keytag"], key["algorithm"], key["ksk"], [ arr_ind2unix(key, "created", 0), arr_ind2unix(key, "publish", 0), arr_ind2unix(key, "active", 0), # taking active for ready arr_ind2unix(key, "active", 0), arr_ind2unix(key, "retire", infty), arr_ind2unix(key, "remove", infty) ]) if dbv3.isRemoved(import_now): continue with lmdb.Transaction(env, db_keys, write=True) as txn_keys: txn_keys.put(dbk3, dbv3.getRaw(), dupdata=False, overwrite=True) except (KeyError, KeyboardInterrupt, TypeError): print("Warning: not imported ", fname) return False return True def import_dir(dirname): print("Importing json key config in", dirname) if os.path.isfile(dirname + "/data.mdb"): print("Warning: LMDB key configuration in", dirname, "already exists.") if opt_force: print("...deleting it.") os.remove(dirname + "/data.mdb") os.remove(dirname + "/lock.mdb") else: print("If you want to delete it and import again, use 'force' option.") return False env = lmdb.open(dirname, max_dbs=2, map_size=500*1024*1024) db_keys = env.open_db(b"keys_db") something_imported = False for json_file in glob.glob(dirname + "/*.json"): something_imported = import_file(json_file, env, db_keys) or something_imported if not something_imported: print("Warning: nothing imported in", dirname) class VersionAction(argparse.Action): def __init__(self, option_strings, version=None, dest=argparse.SUPPRESS, default=argparse.SUPPRESS, help="show program's version number and exit"): super(VersionAction, self).__init__(option_strings=option_strings, dest=dest, default=default, nargs=0, help=help) self.version = version def __call__(self, parser, namespace, values, option_string=None): version = self.version if version is None: version = parser.version formatter = parser._get_formatter() formatter.add_text(version) sys.stdout.write(formatter.format_help()) sys.exit(0) def main(): global opt_force parser = argparse.ArgumentParser(description="Knot DNSSEC KASP converter (JSON to LMDB)", formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("-i", "--import", action="append", nargs="?", dest="importdir", help='''Import zone-key configuration from JSON. Syntax: -i (You can import multiple key_dirs at once by repeating this option.)''') parser.add_argument("-f", "--force", action="store_true", dest="force", help="Do stuff even if dangerous.") parser.add_argument("-V", "--version", action=VersionAction, version="knot KASP legacy JSON importer (debian support for Knot DNS), version 2.7.1") args = parser.parse_args() opt_force = args.force if args.importdir is not None: lmdb_requirement() if isinstance(args.importdir, (list, tuple)): importdir = args.importdir else: importdir = [args.importdir] for dirn in importdir: import_dir(dirn) if __name__ == "__main__": main()