summaryrefslogtreecommitdiffstats
path: root/python/samba/tests/samba_tool/domain_models.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 17:20:00 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 17:20:00 +0000
commit8daa83a594a2e98f39d764422bfbdbc62c9efd44 (patch)
tree4099e8021376c7d8c05bdf8503093d80e9c7bad0 /python/samba/tests/samba_tool/domain_models.py
parentInitial commit. (diff)
downloadsamba-8daa83a594a2e98f39d764422bfbdbc62c9efd44.tar.xz
samba-8daa83a594a2e98f39d764422bfbdbc62c9efd44.zip
Adding upstream version 2:4.20.0+dfsg.upstream/2%4.20.0+dfsg
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'python/samba/tests/samba_tool/domain_models.py')
-rw-r--r--python/samba/tests/samba_tool/domain_models.py416
1 files changed, 416 insertions, 0 deletions
diff --git a/python/samba/tests/samba_tool/domain_models.py b/python/samba/tests/samba_tool/domain_models.py
new file mode 100644
index 0000000..e0f21fe
--- /dev/null
+++ b/python/samba/tests/samba_tool/domain_models.py
@@ -0,0 +1,416 @@
+# Unix SMB/CIFS implementation.
+#
+# Tests for domain models and fields
+#
+# Copyright (C) Catalyst.Net Ltd. 2023
+#
+# Written by Rob van der Linde <rob@catalyst.net.nz>
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+import os
+from datetime import datetime
+from xml.etree import ElementTree
+
+from ldb import FLAG_MOD_ADD, MessageElement, SCOPE_ONELEVEL
+from samba.dcerpc import security
+from samba.dcerpc.misc import GUID
+from samba.netcmd.domain.models import Group, User, fields
+from samba.netcmd.domain.models.auth_policy import StrongNTLMPolicy
+from samba.ndr import ndr_pack, ndr_unpack
+
+from .base import SambaToolCmdTest
+
+HOST = "ldap://{DC_SERVER}".format(**os.environ)
+CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
+
+
+class FieldTestMixin:
+ """Tests a model field to ensure it behaves correctly in both directions.
+
+ Use a mixin since TestCase can't be marked as abstract.
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ cls.samdb = cls.getSamDB("-H", HOST, CREDS)
+ super().setUpClass()
+
+ def get_users_dn(self):
+ """Returns Users DN."""
+ users_dn = self.samdb.get_root_basedn()
+ users_dn.add_child("CN=Users")
+ return users_dn
+
+ def test_to_db_value(self):
+ # Loop through each value and expected value combination.
+ # If the expected value is callable, treat it as a validation callback.
+ # NOTE: perhaps we should be using subtests for this.
+ for (value, expected) in self.to_db_value:
+ db_value = self.field.to_db_value(self.samdb, value, FLAG_MOD_ADD)
+ if callable(expected):
+ self.assertTrue(expected(db_value))
+ else:
+ self.assertEqual(db_value, expected)
+
+ def test_from_db_value(self):
+ # Loop through each value and expected value combination.
+ # NOTE: perhaps we should be using subtests for this.
+ for (db_value, expected) in self.from_db_value:
+ value = self.field.from_db_value(self.samdb, db_value)
+ self.assertEqual(value, expected)
+
+
+class IntegerFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.IntegerField("FieldName")
+
+ to_db_value = [
+ (10, MessageElement(b"10")),
+ ([1, 5, 10], MessageElement([b"1", b"5", b"10"])),
+ (None, None),
+ ]
+
+ from_db_value = [
+ (MessageElement(b"10"), 10),
+ (MessageElement([b"1", b"5", b"10"]), [1, 5, 10]),
+ (None, None),
+ ]
+
+
+class BinaryFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.BinaryField("FieldName")
+
+ to_db_value = [
+ (b"SAMBA", MessageElement(b"SAMBA")),
+ ([b"SAMBA", b"Developer"], MessageElement([b"SAMBA", b"Developer"])),
+ (None, None),
+ ]
+
+ from_db_value = [
+ (MessageElement(b"SAMBA"), b"SAMBA"),
+ (MessageElement([b"SAMBA", b"Developer"]), [b"SAMBA", b"Developer"]),
+ (None, None),
+ ]
+
+
+class StringFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.StringField("FieldName")
+
+ to_db_value = [
+ ("SAMBA", MessageElement(b"SAMBA")),
+ (["SAMBA", "Developer"], MessageElement([b"SAMBA", b"Developer"])),
+ (None, None),
+ ]
+
+ from_db_value = [
+ (MessageElement(b"SAMBA"), "SAMBA"),
+ (MessageElement([b"SAMBA", b"Developer"]), ["SAMBA", "Developer"]),
+ (None, None),
+ ]
+
+
+class BooleanFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.BooleanField("FieldName")
+
+ to_db_value = [
+ (True, MessageElement(b"TRUE")),
+ ([False, True], MessageElement([b"FALSE", b"TRUE"])),
+ (None, None),
+ ]
+
+ from_db_value = [
+ (MessageElement(b"TRUE"), True),
+ (MessageElement([b"FALSE", b"TRUE"]), [False, True]),
+ (None, None),
+ ]
+
+
+class EnumFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.EnumField("FieldName", StrongNTLMPolicy)
+
+ to_db_value = [
+ (StrongNTLMPolicy.OPTIONAL, MessageElement("1")),
+ ([StrongNTLMPolicy.REQUIRED, StrongNTLMPolicy.OPTIONAL],
+ MessageElement(["2", "1"])),
+ (None, None),
+ ]
+
+ from_db_value = [
+ (MessageElement("1"), StrongNTLMPolicy.OPTIONAL),
+ (MessageElement(["2", "1"]),
+ [StrongNTLMPolicy.REQUIRED, StrongNTLMPolicy.OPTIONAL]),
+ (None, None),
+ ]
+
+
+class DateTimeFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.DateTimeField("FieldName")
+
+ to_db_value = [
+ (datetime(2023, 1, 27, 22, 36, 41), MessageElement("20230127223641.0Z")),
+ ([datetime(2023, 1, 27, 22, 36, 41), datetime(2023, 1, 27, 22, 47, 50)],
+ MessageElement(["20230127223641.0Z", "20230127224750.0Z"])),
+ (None, None),
+ ]
+
+ from_db_value = [
+ (MessageElement("20230127223641.0Z"), datetime(2023, 1, 27, 22, 36, 41)),
+ (MessageElement(["20230127223641.0Z", "20230127224750.0Z"]),
+ [datetime(2023, 1, 27, 22, 36, 41), datetime(2023, 1, 27, 22, 47, 50)]),
+ (None, None),
+ ]
+
+
+class RelatedFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.RelatedField("FieldName", User)
+
+ @property
+ def to_db_value(self):
+ alice = User.get(self.samdb, username="alice")
+ joe = User.get(self.samdb, username="joe")
+ return [
+ (alice, MessageElement(str(alice.dn))),
+ ([joe, alice], MessageElement([str(joe.dn), str(alice.dn)])),
+ (None, None),
+ ]
+
+ @property
+ def from_db_value(self):
+ alice = User.get(self.samdb, username="alice")
+ joe = User.get(self.samdb, username="joe")
+ return [
+ (MessageElement(str(alice.dn)), alice),
+ (MessageElement([str(joe.dn), str(alice.dn)]), [joe, alice]),
+ (None, None),
+ ]
+
+
+class DnFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.DnField("FieldName")
+
+ @property
+ def to_db_value(self):
+ alice = User.get(self.samdb, username="alice")
+ joe = User.get(self.samdb, username="joe")
+ return [
+ (alice.dn, MessageElement(str(alice.dn))),
+ ([joe.dn, alice.dn], MessageElement([str(joe.dn), str(alice.dn)])),
+ (None, None),
+ ]
+
+ @property
+ def from_db_value(self):
+ alice = User.get(self.samdb, username="alice")
+ joe = User.get(self.samdb, username="joe")
+ return [
+ (MessageElement(str(alice.dn)), alice.dn),
+ (MessageElement([str(joe.dn), str(alice.dn)]), [joe.dn, alice.dn]),
+ (None, None),
+ ]
+
+
+class SIDFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.SIDField("FieldName")
+
+ @property
+ def to_db_value(self):
+ # Create a group for testing
+ group = Group(name="group1")
+ group.save(self.samdb)
+ self.addCleanup(group.delete, self.samdb)
+
+ # Get raw value to compare against
+ group_rec = self.samdb.search(Group.get_base_dn(self.samdb),
+ scope=SCOPE_ONELEVEL,
+ expression="(name=group1)",
+ attrs=["objectSid"])[0]
+ raw_sid = group_rec["objectSid"]
+
+ return [
+ (group.object_sid, raw_sid),
+ (None, None),
+ ]
+
+ @property
+ def from_db_value(self):
+ # Create a group for testing
+ group = Group(name="group1")
+ group.save(self.samdb)
+ self.addCleanup(group.delete, self.samdb)
+
+ # Get raw value to compare against
+ group_rec = self.samdb.search(Group.get_base_dn(self.samdb),
+ scope=SCOPE_ONELEVEL,
+ expression="(name=group1)",
+ attrs=["objectSid"])[0]
+ raw_sid = group_rec["objectSid"]
+
+ return [
+ (raw_sid, group.object_sid),
+ (None, None),
+ ]
+
+
+class GUIDFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.GUIDField("FieldName")
+
+ @property
+ def to_db_value(self):
+ users_dn = self.get_users_dn()
+
+ alice = self.samdb.search(users_dn,
+ scope=SCOPE_ONELEVEL,
+ expression="(sAMAccountName=alice)",
+ attrs=["objectGUID"])[0]
+
+ joe = self.samdb.search(users_dn,
+ scope=SCOPE_ONELEVEL,
+ expression="(sAMAccountName=joe)",
+ attrs=["objectGUID"])[0]
+
+ alice_guid = str(ndr_unpack(GUID, alice["objectGUID"][0]))
+ joe_guid = str(ndr_unpack(GUID, joe["objectGUID"][0]))
+
+ return [
+ (alice_guid, alice["objectGUID"]),
+ (
+ [joe_guid, alice_guid],
+ MessageElement([joe["objectGUID"][0], alice["objectGUID"][0]]),
+ ),
+ (None, None),
+ ]
+
+ @property
+ def from_db_value(self):
+ users_dn = self.get_users_dn()
+
+ alice = self.samdb.search(users_dn,
+ scope=SCOPE_ONELEVEL,
+ expression="(sAMAccountName=alice)",
+ attrs=["objectGUID"])[0]
+
+ joe = self.samdb.search(users_dn,
+ scope=SCOPE_ONELEVEL,
+ expression="(sAMAccountName=joe)",
+ attrs=["objectGUID"])[0]
+
+ alice_guid = str(ndr_unpack(GUID, alice["objectGUID"][0]))
+ joe_guid = str(ndr_unpack(GUID, joe["objectGUID"][0]))
+
+ return [
+ (alice["objectGUID"], alice_guid),
+ (
+ MessageElement([joe["objectGUID"][0], alice["objectGUID"][0]]),
+ [joe_guid, alice_guid],
+ ),
+ (None, None),
+ ]
+
+
+class SDDLFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.SDDLField("FieldName")
+
+ def setUp(self):
+ super().setUp()
+ self.domain_sid = security.dom_sid(self.samdb.get_domain_sid())
+
+ def encode(self, value):
+ return ndr_pack(security.descriptor.from_sddl(value, self.domain_sid))
+
+ @property
+ def to_db_value(self):
+ values = [
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AU)}))",
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AO)}))",
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;((Member_of {SID(AO)}) || (Member_of {SID(BO)})))",
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(%s)}))" % self.domain_sid,
+ ]
+ expected = [
+ (value, MessageElement(self.encode(value))) for value in values
+ ]
+ expected.append((None, None))
+ return expected
+
+ @property
+ def from_db_value(self):
+ values = [
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AU)}))",
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(AO)}))",
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;((Member_of {SID(AO)}) || (Member_of {SID(BO)})))",
+ "O:SYG:SYD:(XA;OICI;CR;;;WD;(Member_of {SID(%s)}))" % self.domain_sid,
+ ]
+ expected = [
+ (MessageElement(self.encode(value)), value) for value in values
+ ]
+ expected.append((None, None))
+ return expected
+
+
+class PossibleClaimValuesFieldTest(FieldTestMixin, SambaToolCmdTest):
+ field = fields.PossibleClaimValuesField("FieldName")
+
+ json_data = [{
+ "ValueGUID": "1c39ed4f-0b26-4536-b963-5959c8b1b676",
+ "ValueDisplayName": "Alice",
+ "ValueDescription": "Alice Description",
+ "Value": "alice",
+ }]
+
+ xml_data = "<?xml version='1.0' encoding='utf-16'?>" \
+ "<PossibleClaimValues xmlns:xsd='http://www.w3.org/2001/XMLSchema'" \
+ " xmlns:xsi='http://www.w3.org/2001/XMLSchema-instance'" \
+ " xmlns='http://schemas.microsoft.com/2010/08/ActiveDirectory/PossibleValues'>" \
+ "<StringList>" \
+ "<Item>" \
+ "<ValueGUID>1c39ed4f-0b26-4536-b963-5959c8b1b676</ValueGUID>" \
+ "<ValueDisplayName>Alice</ValueDisplayName>" \
+ "<ValueDescription>Alice Description</ValueDescription>" \
+ "<Value>alice</Value>" \
+ "</Item>" \
+ "</StringList>" \
+ "</PossibleClaimValues>"
+
+ def validate_xml(self, db_field):
+ """Callback that compares XML strings.
+
+ Tidying the HTMl output and adding consistent indentation was only
+ added to ETree in Python 3.9+ so generate a single line XML string.
+
+ This is just based on comparing the parsed XML, converted back
+ to a string, then comparing those strings.
+
+ So the expected xml_data string must have no spacing or indentation.
+
+ :param db_field: MessageElement value returned by field.to_db_field()
+ """
+ expected = ElementTree.fromstring(self.xml_data)
+ parsed = ElementTree.fromstring(str(db_field))
+ return ElementTree.tostring(parsed) == ElementTree.tostring(expected)
+
+ @property
+ def to_db_value(self):
+ return [
+ (self.json_data, self.validate_xml), # callback to validate XML
+ (self.json_data[0], self.validate_xml), # one item wrapped as list
+ ([], None), # empty list clears field
+ (None, None),
+ ]
+
+ @property
+ def from_db_value(self):
+ return [
+ (MessageElement(self.xml_data), self.json_data),
+ (None, None),
+ ]