summaryrefslogtreecommitdiffstats
path: root/debian/kasp_json2lmdb
diff options
context:
space:
mode:
Diffstat (limited to 'debian/kasp_json2lmdb')
-rwxr-xr-xdebian/kasp_json2lmdb458
1 files changed, 458 insertions, 0 deletions
diff --git a/debian/kasp_json2lmdb b/debian/kasp_json2lmdb
new file mode 100755
index 0000000..f6aa785
--- /dev/null
+++ b/debian/kasp_json2lmdb
@@ -0,0 +1,458 @@
+#!/usr/bin/env python3
+# vim: et ts=4 sw=4 sts=4
+#
+# import from obsolete JSON KASP to LMDB-beckended KASP database.
+#
+
+from __future__ import print_function
+
+import datetime
+import time
+import json
+import sys
+import re
+import glob
+import argparse
+import time
+import traceback
+import os
+import hashlib
+import importlib
+import codecs
+
+opt_force = False
+lmdb = None
+
+def lmdb_requirement():
+ global lmdb
+
+ try:
+ lmdb = importlib.import_module('lmdb')
+ except ImportError:
+ print("Error: unable to import module LMDB.")
+ print("Probably you need to 'apt install python3-lmdb'.")
+ sys.exit(10)
+
+# workarounding that python 2 doesn't have int.to_bytes()
+def to_bytes(n, length, endianness='big'):
+ h = '%x' % n
+ assert len(h) <= length * 2
+ s = ('0'*(len(h) % 2) + h).zfill(length * 2)
+ if sys.version_info >= (3,0):
+ sb = codecs.decode(s, 'hex')
+ else:
+ sb = s.decode('hex')
+ return bytearray(sb) if endianness == 'big' else bytearray(sb[::-1])
+
+def from_bytes(ba, endianness='big'):
+ x = ba if endianness == 'big' else bytearray(s[::-1])
+ if sys.version_info >= (3,0):
+ hx = codecs.encode(x, 'hex')
+ else:
+ hx = str(x).encode('hex')
+ return int(hx, 16)
+
+# aka knot_dname_from_str_alloc()
+def str2dname(s):
+ if s.endswith('.') is False:
+ s += '.'
+ res = bytearray(b"")
+ nodes = s.lower().split('.')
+ if nodes[-1] != "":
+ nodes.append("")
+
+ for node in nodes:
+ res.append(len(node))
+ res.extend(bytearray(node.lower(), 'ascii'))
+
+ return res
+
+def dname2str(dn):
+ res = ""
+ beg = 0
+ end = ord(dn[0]) + 1
+ while ord(dn[beg]) > 0:
+ res += str(dn[beg+1:end]) + "."
+ beg = end
+ end = beg + ord(dn[beg]) + 1
+
+ return res
+
+# this is just helper for shuffling time
+def shuffle_unixtime(base_time, shuffle_years, shuffle_months):
+ rsm = shuffle_months + 12 * shuffle_years
+ dt = datetime.datetime.fromtimestamp(base_time)
+ newmonth = (dt.month - 1 + rsm) % 12 + 1 # in python, % always returns [0, 11]
+ sameyear = dt.month + rsm % 12
+ newyear = dt.year + rsm // 12 + (0 if sameyear in range(1, 13) else 1) # in python, (-1)//12 = -1
+ dt2 = dt.replace(month=newmonth, year=newyear)
+ print(dt2.month, "/", dt2.year)
+ ttuple = dt2.timetuple()
+ return int(time.mktime(ttuple))
+
+def timespec2unix(spec):
+ if re.match(r"^\d+$", spec):
+ return int(spec)
+
+ now = int(time.time())
+ s = re.sub(r"^now", "t", spec)
+ if s == "t":
+ return now
+
+ unitmap = { "" : 1, "mi" : 60, "h" : 3600, "d" : 86400 }
+ unitmap_mo = { "mo" : 1, "y" : 12 }
+
+ if re.match(r"^t[-+]\d+", s):
+ unit = re.sub(r"^t[-+]\d+", "", s)
+ cutend = len(s) if unit == "" else -len(unit)
+ if unit in list(unitmap.keys()):
+ return now + int(s[1:cutend]) * unitmap[unit]
+ elif unit in list(unitmap_mo.keys()):
+ return shuffle_unixtime(now, 0, int(s[1:cutend]) * unitmap_mo[unit])
+ else:
+ print("Error in time unit specification")
+
+ print("Error in time specification")
+
+class Keykey:
+ '''Kasp DB key serialized (type, zone_name, key_id)'''
+
+ def __init__(self, raw_bytearray):
+ self.raw = bytearray(raw_bytearray)
+
+ @classmethod
+ def from_params(self, valtype, zone_name, key_id):
+ selfraw = to_bytes(valtype, 1)
+ if zone_name is not None:
+ selfraw.extend(zone_name)
+ if key_id is not None:
+ selfraw.extend(bytearray(key_id.encode("ascii")))
+ selfraw.append(0)
+ return Keykey(selfraw)
+
+ def getRaw(self):
+ return bytearray(self.raw)
+
+ def getType(self):
+ return self.raw[0]
+
+ def __getSplit(self):
+ x = self.raw.find(to_bytes(0, 1))
+ assert x > 0
+ return x + 1
+
+ def getZone(self):
+ if self.getType() == 2:
+ return None
+ return str(self.raw[1:self.__getSplit()])
+
+ def getKeyid(self):
+ if self.getType() != 1:
+ return None
+ return str(self.raw[self.__getSplit():])
+
+class Keyparams:
+ '''Serialized key parameters for kasp-db.'''
+
+ def __init__(self, raw_bytearray):
+ self.raw = bytearray(raw_bytearray)
+ self.timers_dict = { "created" : [ 0, 20, 28 ],
+ "publish" : [ 1, 28, 36 ],
+ "ready" : [ 2, 36, 44 ],
+ "active" : [ 3, 44, 52 ],
+ "retire" : [ 4, 52, 60 ],
+ "remove" : [ 5, 60, 68 ] }
+
+ @classmethod
+ def from_params(self, pubkey, keytag, algorithm, isksk, timers):
+ assert len(timers) == 6
+ if sys.version_info >= (3,0):
+ pk = codecs.decode(bytearray(pubkey, 'ascii'), "base64")
+ else:
+ pk = pubkey.decode("base64")
+ selfraw = to_bytes(len(pk), 8)
+ selfraw.extend(to_bytes(0, 8)) # zero length of unused-future
+ selfraw.extend(to_bytes(int(keytag), 2))
+ selfraw.extend(to_bytes(int(algorithm), 1))
+ selfraw.extend(to_bytes((1 if isksk else 0), 1))
+ for t in timers:
+ if t < 0:
+ print("keytag=%i timers=(%i, %i, %i, %i, %i, %i)" % (keytag,
+ timers[0], timers[1], timers[2], timers[3], timers[4], timers[5]))
+ assert False
+ selfraw.extend(to_bytes(t, 8))
+ selfraw.extend(pk)
+ return Keyparams(selfraw)
+
+ def _check(self):
+ assert len(self.raw) >= 16
+ pkl = from_bytes(self.raw[0:8])
+ ufl = from_bytes(self.raw[8:16])
+ assert len(self.raw) == 68 + pkl + ufl
+ assert self.raw[19] < 2
+
+ def getRaw(self):
+ self._check()
+ return bytearray(self.raw)
+
+ def getAlgorithm(self):
+ self._check()
+ return int(self.raw[18])
+
+ def setAlgorithm(self, algorithm):
+ self._check()
+ self.raw[18] = to_bytes(algorithm, 1)[0]
+
+ def isKSK(self):
+ self._check()
+ return 1 if self.raw[19] != 0 else 0
+
+ def setKSK(self, isksk):
+ self._check()
+ self.raw[11] = (b"\01" if isksk else b"\00")[0]
+
+ def getKeytag(self):
+ self._check()
+ return from_bytes(self.raw[16:18])
+
+ def setKeytag(self, keytag):
+ self._check()
+ self.raw[16:18] = to_bytes(keytag, 2)
+
+ def getTimers(self):
+ self._check()
+ res = [ 0, 0, 0, 0, 0, 0 ]
+ for i, x, y in list(self.timers_dict.values()):
+ res[i] = from_bytes(self.raw[x:y])
+ return res
+
+ def getTimersString(self):
+ self._check()
+ res = "["
+ for ti in list(self.timers_dict.keys()):
+ _, x, y = self.timers_dict[ti];
+ res += (" " if res == "[" else ", ") + ti + ": " + str(from_bytes(self.raw[x:y]))
+ return res + " ]"
+
+ def setTimers(self, timers):
+ self._check()
+ assert len(timers) == 5
+ for i, x, y in list(self.timers_dict.values()):
+ self.raw[x:y] = to_bytes(timers[i], 8)
+
+ def getPubKey(self):
+ self._check()
+ pkl = from_bytes(self.raw[0:8])
+ return self.raw[68:68+pkl].encode("base64")
+
+ def getParams(self):
+ return [ self.getPubKey(), self.getKeytag(), self.getAlgorithm(),
+ self.isKSK(), self.getTimers() ];
+
+ def setByParamName(self, param_name, new_val):
+ if param_name == "algorithm":
+ self.setAlgorithm(int(new_val))
+ elif param_name == "isksk":
+ if new_val in ("1", "True", "true", "on", "yes", "Yes"):
+ self.setKSK(True)
+ elif new_val in ("0", "False", "false", "off", "no", "No"):
+ self.setKSK(False)
+ else:
+ print("Error: bad true/false value", new_val)
+ elif param_name == "keytag":
+ self.setKeytag(int(new_val))
+ elif param_name in list(self.timers_dict.keys()):
+ _, x, y = self.timers_dict[param_name]
+ self.raw[x:y] = to_bytes(timespec2unix(new_val), 8)
+ else:
+ print("Error: bad parameter", param_name)
+
+ def computeDS(self, zone_str, digestalg):
+ ds_raw = bytearray(str2dname(zone_str))
+ ds_raw.extend(to_bytes(257 if self.isKSK() else 256, 2))
+ ds_raw.extend(b"\x03") # protocol is always == 3
+ ds_raw.extend(self.raw[18:19]) # algorithm
+ pkl = from_bytes(self.raw[0:8])
+ ds_raw.extend(self.raw[68:68+pkl]) # pubkey
+ if digestalg == "sha1":
+ ds_hash = hashlib.sha1(ds_raw).hexdigest()
+ algno = " 1 "
+ elif digestalg == "sha256":
+ ds_hash = hashlib.sha256(ds_raw).hexdigest()
+ algno = " 2 "
+ elif digestalg == "sha384":
+ ds_hash = hashlib.sha384(ds_raw).hexdigest()
+ algno = " 4 "
+ else:
+ print("Error: bad DS digest algorith", ds_hash)
+ return
+ return zone_str + ' DS ' + str(self.getKeytag()) + ' ' + str(self.getAlgorithm()) + algno + ds_hash
+
+ def isPublished(self, moment):
+ tmrs = self.getTimers()
+ if tmrs[self.timers_dict["publish"][0]] <= moment:
+ if moment < tmrs[self.timers_dict["remove"][0]]:
+ return True
+ return False
+
+ def isReady(self, moment):
+ tmrs = self.getTimers()
+ if tmrs[self.timers_dict["ready"][0]] <= moment:
+ if moment < tmrs[self.timers_dict["ready"][0]]:
+ return True
+ return False
+
+ def isActive(self, moment):
+ tmrs = self.getTimers()
+ if tmrs[self.timers_dict["active"][0]] <= moment:
+ if moment < tmrs[self.timers_dict["retire"][0]]:
+ return True
+ return False
+
+ def isRetired(self, moment):
+ tmrs = self.getTimers()
+ if tmrs[self.timers_dict["retire"][0]] <= moment:
+ return True
+ return False
+
+ def isRemoved(self, moment):
+ tmrs = self.getTimers()
+ if tmrs[self.timers_dict["remove"][0]] <= moment:
+ return True
+ return False
+
+# static: just for use in following method
+def arr_ind2unix(arr, ind, defaul):
+ try:
+ ttuple = datetime.datetime.strptime(arr[ind], "%Y-%m-%dT%H:%M:%S+0000").timetuple()
+ res = int(time.mktime(ttuple))
+ return res if res >= 0 else 0
+ except KeyError:
+ return defaul
+
+def import_nsec3salt(keys, env, db_keys, zname):
+ try:
+ with lmdb.Transaction(env, db_keys, write=True) as txn_keys:
+ dbk1 = Keykey.from_params(3, zname, None).getRaw()
+ dbv1 = keys["nsec3_salt"]
+ if dbv1 is None:
+ return
+ if sys.version_info >= (3,0):
+ dbv1d = codecs.decode(bytearray(dbv1, 'ascii'), "base64")
+ else:
+ dbv1d = dbv1.decode("base64")
+ txn_keys.put(dbk1, dbv1d, dupdata=False, overwrite=True)
+
+ dbk2 = Keykey.from_params(4, zname, None).getRaw()
+ dbv2 = to_bytes(arr_ind2unix(keys, "nsec3_salt_created", 0), 8)
+ txn_keys.put(dbk2, dbv2, dupdata=False, overwrite=True)
+ except (KeyError, AttributeError):
+ pass # nsec3salt not configured or set to null, no problem
+
+# import single JSON zone config into open LMDB env
+def import_file(fname, env, db_keys):
+ try:
+ with open(fname) as f:
+ keys = json.load(f)
+
+ except ValueError:
+ print("Warning: not imported ", fname)
+ return False
+
+ try:
+ zname_str = re.sub(r'^zone_', '', re.sub(r'\.json$', '', re.sub(r'.*/', '', fname)))
+ print("Importing zone", zname_str)
+ zname = str2dname(zname_str)
+ import_nsec3salt(keys, env, db_keys, zname)
+
+ import_now = int(time.time())
+
+ for key in keys["keys"]:
+ dbk3 = Keykey.from_params(1, zname, key["id"]).getRaw()
+
+ infty = 0x00ffffffffffff00 # time infinity, this is year 142'715'360
+
+ dbv3 = Keyparams.from_params(key["public_key"], key["keytag"],
+ key["algorithm"], key["ksk"], [
+ arr_ind2unix(key, "created", 0),
+ arr_ind2unix(key, "publish", 0),
+ arr_ind2unix(key, "active", 0), # taking active for ready
+ arr_ind2unix(key, "active", 0),
+ arr_ind2unix(key, "retire", infty),
+ arr_ind2unix(key, "remove", infty)
+ ])
+
+ if dbv3.isRemoved(import_now):
+ continue
+
+ with lmdb.Transaction(env, db_keys, write=True) as txn_keys:
+ txn_keys.put(dbk3, dbv3.getRaw(), dupdata=False, overwrite=True)
+
+ except (KeyError, KeyboardInterrupt, TypeError):
+ print("Warning: not imported ", fname)
+ return False
+
+ return True
+
+def import_dir(dirname):
+ print("Importing json key config in", dirname)
+ if os.path.isfile(dirname + "/data.mdb"):
+ print("Warning: LMDB key configuration in", dirname, "already exists.")
+ if opt_force:
+ print("...deleting it.")
+ os.remove(dirname + "/data.mdb")
+ os.remove(dirname + "/lock.mdb")
+ else:
+ print("If you want to delete it and import again, use 'force' option.")
+ return False
+
+ env = lmdb.open(dirname, max_dbs=2, map_size=500*1024*1024)
+ db_keys = env.open_db(b"keys_db")
+ something_imported = False
+ for json_file in glob.glob(dirname + "/*.json"):
+ something_imported = import_file(json_file, env, db_keys) or something_imported
+
+ if not something_imported:
+ print("Warning: nothing imported in", dirname)
+
+class VersionAction(argparse.Action):
+ def __init__(self, option_strings, version=None, dest=argparse.SUPPRESS,
+ default=argparse.SUPPRESS, help="show program's version number and exit"):
+ super(VersionAction, self).__init__(option_strings=option_strings, dest=dest,
+ default=default, nargs=0, help=help)
+ self.version = version
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ version = self.version
+ if version is None:
+ version = parser.version
+ formatter = parser._get_formatter()
+ formatter.add_text(version)
+ sys.stdout.write(formatter.format_help())
+ sys.exit(0)
+
+def main():
+ global opt_force
+ parser = argparse.ArgumentParser(description="Knot DNSSEC KASP converter (JSON to LMDB)",
+ formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument("-i", "--import", action="append", nargs="?", dest="importdir",
+ help='''Import zone-key configuration from JSON.
+Syntax: -i <key_dir>
+(You can import multiple key_dirs at once by repeating this option.)''')
+ parser.add_argument("-f", "--force", action="store_true", dest="force", help="Do stuff even if dangerous.")
+ parser.add_argument("-V", "--version", action=VersionAction, version="knot KASP legacy JSON importer (debian support for Knot DNS), version 2.7.1")
+ args = parser.parse_args()
+ opt_force = args.force
+
+ if args.importdir is not None:
+ lmdb_requirement()
+ if isinstance(args.importdir, (list, tuple)):
+ importdir = args.importdir
+ else:
+ importdir = [args.importdir]
+
+ for dirn in importdir:
+ import_dir(dirn)
+
+if __name__ == "__main__":
+ main()