summaryrefslogtreecommitdiffstats
path: root/python/samba/gkdi.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 17:20:00 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 17:20:00 +0000
commit8daa83a594a2e98f39d764422bfbdbc62c9efd44 (patch)
tree4099e8021376c7d8c05bdf8503093d80e9c7bad0 /python/samba/gkdi.py
parentInitial commit. (diff)
downloadsamba-8daa83a594a2e98f39d764422bfbdbc62c9efd44.tar.xz
samba-8daa83a594a2e98f39d764422bfbdbc62c9efd44.zip
Adding upstream version 2:4.20.0+dfsg.upstream/2%4.20.0+dfsg
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'python/samba/gkdi.py')
-rw-r--r--python/samba/gkdi.py397
1 files changed, 397 insertions, 0 deletions
diff --git a/python/samba/gkdi.py b/python/samba/gkdi.py
new file mode 100644
index 0000000..4179263
--- /dev/null
+++ b/python/samba/gkdi.py
@@ -0,0 +1,397 @@
+# Unix SMB/CIFS implementation.
+# Copyright (C) Catalyst.Net Ltd 2023
+#
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+#
+
+"""Group Key Distribution Service module"""
+
+from enum import Enum
+from functools import total_ordering
+from typing import Optional, Tuple
+
+from cryptography.hazmat.primitives import hashes
+
+from samba import _glue
+from samba.dcerpc import gkdi, misc
+from samba.ndr import ndr_pack, ndr_unpack
+from samba.nt_time import NtTime, NtTimeDelta
+
+
+uint64_max: int = 2**64 - 1
+
+L1_KEY_ITERATION: int = _glue.GKDI_L1_KEY_ITERATION
+L2_KEY_ITERATION: int = _glue.GKDI_L2_KEY_ITERATION
+KEY_CYCLE_DURATION: NtTimeDelta = _glue.GKDI_KEY_CYCLE_DURATION
+MAX_CLOCK_SKEW: NtTimeDelta = _glue.GKDI_MAX_CLOCK_SKEW
+
+KEY_LEN_BYTES = 64
+
+
+class Algorithm(Enum):
+ SHA1 = "SHA1"
+ SHA256 = "SHA256"
+ SHA384 = "SHA384"
+ SHA512 = "SHA512"
+
+ def algorithm(self) -> hashes.HashAlgorithm:
+ if self is Algorithm.SHA1:
+ return hashes.SHA1()
+
+ if self is Algorithm.SHA256:
+ return hashes.SHA256()
+
+ if self is Algorithm.SHA384:
+ return hashes.SHA384()
+
+ if self is Algorithm.SHA512:
+ return hashes.SHA512()
+
+ raise RuntimeError("unknown hash algorithm {self}")
+
+ def __repr__(self) -> str:
+ return str(self)
+
+ @staticmethod
+ def from_kdf_parameters(kdf_param: Optional[bytes]) -> "Algorithm":
+ if not kdf_param:
+ return Algorithm.SHA256 # the default used by Windows.
+
+ kdf_parameters = ndr_unpack(gkdi.KdfParameters, kdf_param)
+ return Algorithm(kdf_parameters.hash_algorithm)
+
+
+class GkidType(Enum):
+ DEFAULT = object()
+ L0_SEED_KEY = object()
+ L1_SEED_KEY = object()
+ L2_SEED_KEY = object()
+
+ def description(self) -> str:
+ if self is GkidType.DEFAULT:
+ return "a default GKID"
+
+ if self is GkidType.L0_SEED_KEY:
+ return "an L0 seed key"
+
+ if self is GkidType.L1_SEED_KEY:
+ return "an L1 seed key"
+
+ if self is GkidType.L2_SEED_KEY:
+ return "an L2 seed key"
+
+ raise RuntimeError("unknown GKID type {self}")
+
+
+class InvalidDerivation(Exception):
+ pass
+
+
+class UndefinedStartTime(Exception):
+ pass
+
+
+@total_ordering
+class Gkid:
+ __slots__ = ["_l0_idx", "_l1_idx", "_l2_idx"]
+
+ max_l0_idx = 0x7FFF_FFFF
+
+ def __init__(self, l0_idx: int, l1_idx: int, l2_idx: int) -> None:
+ if not -1 <= l0_idx <= Gkid.max_l0_idx:
+ raise ValueError(f"L0 index {l0_idx} out of range")
+
+ if not -1 <= l1_idx < L1_KEY_ITERATION:
+ raise ValueError(f"L1 index {l1_idx} out of range")
+
+ if not -1 <= l2_idx < L2_KEY_ITERATION:
+ raise ValueError(f"L2 index {l2_idx} out of range")
+
+ if l0_idx == -1 and l1_idx != -1:
+ raise ValueError("invalid combination of negative and non‐negative indices")
+
+ if l1_idx == -1 and l2_idx != -1:
+ raise ValueError("invalid combination of negative and non‐negative indices")
+
+ self._l0_idx = l0_idx
+ self._l1_idx = l1_idx
+ self._l2_idx = l2_idx
+
+ @property
+ def l0_idx(self) -> int:
+ return self._l0_idx
+
+ @property
+ def l1_idx(self) -> int:
+ return self._l1_idx
+
+ @property
+ def l2_idx(self) -> int:
+ return self._l2_idx
+
+ def gkid_type(self) -> GkidType:
+ if self.l0_idx == -1:
+ return GkidType.DEFAULT
+
+ if self.l1_idx == -1:
+ return GkidType.L0_SEED_KEY
+
+ if self.l2_idx == -1:
+ return GkidType.L1_SEED_KEY
+
+ return GkidType.L2_SEED_KEY
+
+ def wrapped_l1_idx(self) -> int:
+ if self.l1_idx == -1:
+ return L1_KEY_ITERATION
+
+ return self.l1_idx
+
+ def wrapped_l2_idx(self) -> int:
+ if self.l2_idx == -1:
+ return L2_KEY_ITERATION
+
+ return self.l2_idx
+
+ def derive_l1_seed_key(self) -> "Gkid":
+ gkid_type = self.gkid_type()
+ if (
+ gkid_type is not GkidType.L0_SEED_KEY
+ and gkid_type is not GkidType.L1_SEED_KEY
+ ):
+ raise InvalidDerivation(
+ "Invalid attempt to derive an L1 seed key from"
+ f" {gkid_type.description()}"
+ )
+
+ if self.l1_idx == 0:
+ raise InvalidDerivation("No further derivation of L1 seed keys is possible")
+
+ return Gkid(self.l0_idx, self.wrapped_l1_idx() - 1, self.l2_idx)
+
+ def derive_l2_seed_key(self) -> "Gkid":
+ gkid_type = self.gkid_type()
+ if (
+ gkid_type is not GkidType.L1_SEED_KEY
+ and gkid_type is not GkidType.L2_SEED_KEY
+ ):
+ raise InvalidDerivation(
+ f"Attempt to derive an L2 seed key from {gkid_type.description()}"
+ )
+
+ if self.l2_idx == 0:
+ raise InvalidDerivation("No further derivation of L2 seed keys is possible")
+
+ return Gkid(self.l0_idx, self.l1_idx, self.wrapped_l2_idx() - 1)
+
+ def __str__(self) -> str:
+ return f"Gkid({self.l0_idx}, {self.l1_idx}, {self.l2_idx})"
+
+ def __repr__(self) -> str:
+ cls = type(self)
+ return (
+ f"{cls.__qualname__}({repr(self.l0_idx)}, {repr(self.l1_idx)},"
+ f" {repr(self.l2_idx)})"
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Gkid):
+ return NotImplemented
+
+ return (self.l0_idx, self.l1_idx, self.l2_idx) == (
+ other.l0_idx,
+ other.l1_idx,
+ other.l2_idx,
+ )
+
+ def __lt__(self, other: object) -> bool:
+ if not isinstance(other, Gkid):
+ return NotImplemented
+
+ def as_tuple(gkid: Gkid) -> Tuple[int, int, int]:
+ l0_idx, l1_idx, l2_idx = gkid.l0_idx, gkid.l1_idx, gkid.l2_idx
+
+ # DEFAULT is considered less than everything else, so that the
+ # lexical ordering requirement in [MS-GKDI] 3.1.4.1.3 (GetKey) makes
+ # sense.
+ if gkid.gkid_type() is not GkidType.DEFAULT:
+ # Use the wrapped indices so that L1 seed keys are considered
+ # greater than their children L2 seed keys, and L0 seed keys are
+ # considered greater than their children L1 seed keys.
+ l1_idx = gkid.wrapped_l1_idx()
+ l2_idx = gkid.wrapped_l2_idx()
+
+ return l0_idx, l1_idx, l2_idx
+
+ return as_tuple(self) < as_tuple(other)
+
+ def __hash__(self) -> int:
+ return hash((self.l0_idx, self.l1_idx, self.l2_idx))
+
+ @staticmethod
+ def default() -> "Gkid":
+ return Gkid(-1, -1, -1)
+
+ @staticmethod
+ def l0_seed_key(l0_idx: int) -> "Gkid":
+ return Gkid(l0_idx, -1, -1)
+
+ @staticmethod
+ def l1_seed_key(l0_idx: int, l1_idx: int) -> "Gkid":
+ return Gkid(l0_idx, l1_idx, -1)
+
+ @staticmethod
+ def from_nt_time(nt_time: NtTime) -> "Gkid":
+ l0 = nt_time // (L1_KEY_ITERATION * L2_KEY_ITERATION * KEY_CYCLE_DURATION)
+ l1 = (
+ nt_time
+ % (L1_KEY_ITERATION * L2_KEY_ITERATION * KEY_CYCLE_DURATION)
+ // (L2_KEY_ITERATION * KEY_CYCLE_DURATION)
+ )
+ l2 = nt_time % (L2_KEY_ITERATION * KEY_CYCLE_DURATION) // KEY_CYCLE_DURATION
+
+ return Gkid(l0, l1, l2)
+
+ def start_nt_time(self) -> NtTime:
+ gkid_type = self.gkid_type()
+ if gkid_type is not GkidType.L2_SEED_KEY:
+ raise UndefinedStartTime(
+ f"{gkid_type.description()} has no defined start time"
+ )
+
+ start_time = NtTime(
+ (
+ self.l0_idx * L1_KEY_ITERATION * L2_KEY_ITERATION
+ + self.l1_idx * L2_KEY_ITERATION
+ + self.l2_idx
+ )
+ * KEY_CYCLE_DURATION
+ )
+
+ if not 0 <= start_time <= uint64_max:
+ raise OverflowError(f"start time {start_time} out of range")
+
+ return start_time
+
+
+class SeedKeyPair:
+ __slots__ = ["l1_key", "l2_key", "gkid", "hash_algorithm", "root_key_id"]
+
+ def __init__(
+ self,
+ l1_key: Optional[bytes],
+ l2_key: Optional[bytes],
+ gkid: Gkid,
+ hash_algorithm: Algorithm,
+ root_key_id: misc.GUID,
+ ) -> None:
+ if l1_key is not None and len(l1_key) != KEY_LEN_BYTES:
+ raise ValueError(f"L1 key ({repr(l1_key)}) must be {KEY_LEN_BYTES} bytes")
+ if l2_key is not None and len(l2_key) != KEY_LEN_BYTES:
+ raise ValueError(f"L2 key ({repr(l2_key)}) must be {KEY_LEN_BYTES} bytes")
+
+ self.l1_key = l1_key
+ self.l2_key = l2_key
+ self.gkid = gkid
+ self.hash_algorithm = hash_algorithm
+ self.root_key_id = root_key_id
+
+ def __str__(self) -> str:
+ l1_key_hex = None if self.l1_key is None else self.l1_key.hex()
+ l2_key_hex = None if self.l2_key is None else self.l2_key.hex()
+
+ return (
+ f"SeedKeyPair(L1Key({l1_key_hex}), L2Key({l2_key_hex}), {self.gkid},"
+ f" {self.root_key_id}, {self.hash_algorithm})"
+ )
+
+ def __repr__(self) -> str:
+ cls = type(self)
+ return (
+ f"{cls.__qualname__}({repr(self.l1_key)}, {repr(self.l2_key)},"
+ f" {repr(self.gkid)}, {repr(self.hash_algorithm)},"
+ f" {repr(self.root_key_id)})"
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, SeedKeyPair):
+ return NotImplemented
+
+ return (
+ self.l1_key,
+ self.l2_key,
+ self.gkid,
+ self.hash_algorithm,
+ self.root_key_id,
+ ) == (
+ other.l1_key,
+ other.l2_key,
+ other.gkid,
+ other.hash_algorithm,
+ other.root_key_id,
+ )
+
+ def __hash__(self) -> int:
+ return hash((
+ self.l1_key,
+ self.l2_key,
+ self.gkid,
+ self.hash_algorithm,
+ ndr_pack(self.root_key_id),
+ ))
+
+
+class GroupKey:
+ __slots__ = ["gkid", "key", "hash_algorithm", "root_key_id"]
+
+ def __init__(
+ self, key: bytes, gkid: Gkid, hash_algorithm: Algorithm, root_key_id: misc.GUID
+ ) -> None:
+ if key is not None and len(key) != KEY_LEN_BYTES:
+ raise ValueError(f"Key ({repr(key)}) must be {KEY_LEN_BYTES} bytes")
+
+ self.key = key
+ self.gkid = gkid
+ self.hash_algorithm = hash_algorithm
+ self.root_key_id = root_key_id
+
+ def __str__(self) -> str:
+ return (
+ f"GroupKey(Key({self.key.hex()}), {self.gkid}, {self.hash_algorithm},"
+ f" {self.root_key_id})"
+ )
+
+ def __repr__(self) -> str:
+ cls = type(self)
+ return (
+ f"{cls.__qualname__}({repr(self.key)}, {repr(self.gkid)},"
+ f" {repr(self.hash_algorithm)}, {repr(self.root_key_id)})"
+ )
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, GroupKey):
+ return NotImplemented
+
+ return (self.key, self.gkid, self.hash_algorithm, self.root_key_id) == (
+ other.key,
+ other.gkid,
+ other.hash_algorithm,
+ other.root_key_id,
+ )
+
+ def __hash__(self) -> int:
+ return hash(
+ (self.key, self.gkid, self.hash_algorithm, ndr_pack(self.root_key_id))
+ )