summaryrefslogtreecommitdiffstats
path: root/tests/fix_faker.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/fix_faker.py')
-rw-r--r--tests/fix_faker.py868
1 files changed, 868 insertions, 0 deletions
diff --git a/tests/fix_faker.py b/tests/fix_faker.py
new file mode 100644
index 0000000..5289d8f
--- /dev/null
+++ b/tests/fix_faker.py
@@ -0,0 +1,868 @@
+import datetime as dt
+import importlib
+import ipaddress
+from math import isnan
+from uuid import UUID
+from random import choice, random, randrange
+from typing import Any, List, Set, Tuple, Union
+from decimal import Decimal
+from contextlib import contextmanager, asynccontextmanager
+
+import pytest
+
+import psycopg
+from psycopg import sql
+from psycopg.adapt import PyFormat
+from psycopg._compat import Deque
+from psycopg.types.range import Range
+from psycopg.types.json import Json, Jsonb
+from psycopg.types.numeric import Int4, Int8
+from psycopg.types.multirange import Multirange
+
+
+@pytest.fixture
+def faker(conn):
+ return Faker(conn)
+
+
+class Faker:
+ """
+ An object to generate random records.
+ """
+
+ json_max_level = 3
+ json_max_length = 10
+ str_max_length = 100
+ list_max_length = 20
+ tuple_max_length = 15
+
+ def __init__(self, connection):
+ self.conn = connection
+ self.format = PyFormat.BINARY
+ self.records = []
+
+ self._schema = None
+ self._types = None
+ self._types_names = None
+ self._makers = {}
+ self.table_name = sql.Identifier("fake_table")
+
+ @property
+ def schema(self):
+ if not self._schema:
+ self.schema = self.choose_schema()
+ return self._schema
+
+ @schema.setter
+ def schema(self, schema):
+ self._schema = schema
+ self._types_names = None
+
+ @property
+ def fields_names(self):
+ return [sql.Identifier(f"fld_{i}") for i in range(len(self.schema))]
+
+ @property
+ def types(self):
+ if not self._types:
+
+ def key(cls: type) -> str:
+ return cls.__name__
+
+ self._types = sorted(self.get_supported_types(), key=key)
+ return self._types
+
+ @property
+ def types_names_sql(self):
+ if self._types_names:
+ return self._types_names
+
+ record = self.make_record(nulls=0)
+ tx = psycopg.adapt.Transformer(self.conn)
+ types = [
+ self._get_type_name(tx, schema, value)
+ for schema, value in zip(self.schema, record)
+ ]
+ self._types_names = types
+ return types
+
+ @property
+ def types_names(self):
+ types = [t.as_string(self.conn).replace('"', "") for t in self.types_names_sql]
+ return types
+
+ def _get_type_name(self, tx, schema, value):
+ # Special case it as it is passed as unknown so is returned as text
+ if schema == (list, str):
+ return sql.SQL("text[]")
+
+ registry = self.conn.adapters.types
+ dumper = tx.get_dumper(value, self.format)
+ dumper.dump(value) # load the oid if it's dynamic (e.g. array)
+ info = registry.get(dumper.oid) or registry.get("text")
+ if dumper.oid == info.array_oid:
+ return sql.SQL("{}[]").format(sql.Identifier(info.name))
+ else:
+ return sql.Identifier(info.name)
+
+ @property
+ def drop_stmt(self):
+ return sql.SQL("drop table if exists {}").format(self.table_name)
+
+ @property
+ def create_stmt(self):
+ field_values = []
+ for name, type in zip(self.fields_names, self.types_names_sql):
+ field_values.append(sql.SQL("{} {}").format(name, type))
+
+ fields = sql.SQL(", ").join(field_values)
+ return sql.SQL("create table {table} (id serial primary key, {fields})").format(
+ table=self.table_name, fields=fields
+ )
+
+ @property
+ def insert_stmt(self):
+ phs = [sql.Placeholder(format=self.format) for i in range(len(self.schema))]
+ return sql.SQL("insert into {} ({}) values ({})").format(
+ self.table_name,
+ sql.SQL(", ").join(self.fields_names),
+ sql.SQL(", ").join(phs),
+ )
+
+ @property
+ def select_stmt(self):
+ fields = sql.SQL(", ").join(self.fields_names)
+ return sql.SQL("select {} from {} order by id").format(fields, self.table_name)
+
+ @contextmanager
+ def find_insert_problem(self, conn):
+ """Context manager to help finding a problematic value."""
+ try:
+ with conn.transaction():
+ yield
+ except psycopg.DatabaseError:
+ cur = conn.cursor()
+ # Repeat insert one field at time, until finding the wrong one
+ cur.execute(self.drop_stmt)
+ cur.execute(self.create_stmt)
+ for i, rec in enumerate(self.records):
+ for j, val in enumerate(rec):
+ try:
+ cur.execute(self._insert_field_stmt(j), (val,))
+ except psycopg.DatabaseError as e:
+ r = repr(val)
+ if len(r) > 200:
+ r = f"{r[:200]}... ({len(r)} chars)"
+ raise Exception(
+ f"value {r!r} at record {i} column0 {j} failed insert: {e}"
+ ) from None
+
+ # just in case, but hopefully we should have triggered the problem
+ raise
+
+ @asynccontextmanager
+ async def find_insert_problem_async(self, aconn):
+ try:
+ async with aconn.transaction():
+ yield
+ except psycopg.DatabaseError:
+ acur = aconn.cursor()
+ # Repeat insert one field at time, until finding the wrong one
+ await acur.execute(self.drop_stmt)
+ await acur.execute(self.create_stmt)
+ for i, rec in enumerate(self.records):
+ for j, val in enumerate(rec):
+ try:
+ await acur.execute(self._insert_field_stmt(j), (val,))
+ except psycopg.DatabaseError as e:
+ r = repr(val)
+ if len(r) > 200:
+ r = f"{r[:200]}... ({len(r)} chars)"
+ raise Exception(
+ f"value {r!r} at record {i} column0 {j} failed insert: {e}"
+ ) from None
+
+ # just in case, but hopefully we should have triggered the problem
+ raise
+
+ def _insert_field_stmt(self, i):
+ ph = sql.Placeholder(format=self.format)
+ return sql.SQL("insert into {} ({}) values ({})").format(
+ self.table_name, self.fields_names[i], ph
+ )
+
+ def choose_schema(self, ncols=20):
+ schema: List[Union[Tuple[type, ...], type]] = []
+ while len(schema) < ncols:
+ s = self.make_schema(choice(self.types))
+ if s is not None:
+ schema.append(s)
+ self.schema = schema
+ return schema
+
+ def make_records(self, nrecords):
+ self.records = [self.make_record(nulls=0.05) for i in range(nrecords)]
+
+ def make_record(self, nulls=0):
+ if not nulls:
+ return tuple(self.example(spec) for spec in self.schema)
+ else:
+ return tuple(
+ self.make(spec) if random() > nulls else None for spec in self.schema
+ )
+
+ def assert_record(self, got, want):
+ for spec, g, w in zip(self.schema, got, want):
+ if g is None and w is None:
+ continue
+ m = self.get_matcher(spec)
+ m(spec, g, w)
+
+ def get_supported_types(self) -> Set[type]:
+ dumpers = self.conn.adapters._dumpers[self.format]
+ rv = set()
+ for cls in dumpers.keys():
+ if isinstance(cls, str):
+ cls = deep_import(cls)
+ if issubclass(cls, Multirange) and self.conn.info.server_version < 140000:
+ continue
+
+ rv.add(cls)
+
+ # check all the types are handled
+ for cls in rv:
+ self.get_maker(cls)
+
+ return rv
+
+ def make_schema(self, cls: type) -> Union[Tuple[type, ...], type, None]:
+ """Create a schema spec from a Python type.
+
+ A schema specifies what Postgres type to generate when a Python type
+ maps to more than one (e.g. tuple -> composite, list -> array[],
+ datetime -> timestamp[tz]).
+
+ A schema for a type is represented by a tuple (type, ...) which the
+ matching make_*() method can interpret, or just type if the type
+ doesn't require further specification.
+
+ A `None` means that the type is not supported.
+ """
+ meth = self._get_method("schema", cls)
+ return meth(cls) if meth else cls
+
+ def get_maker(self, spec):
+ cls = spec if isinstance(spec, type) else spec[0]
+
+ try:
+ return self._makers[cls]
+ except KeyError:
+ pass
+
+ meth = self._get_method("make", cls)
+ if meth:
+ self._makers[cls] = meth
+ return meth
+ else:
+ raise NotImplementedError(f"cannot make fake objects of class {cls}")
+
+ def get_matcher(self, spec):
+ cls = spec if isinstance(spec, type) else spec[0]
+ meth = self._get_method("match", cls)
+ return meth if meth else self.match_any
+
+ def _get_method(self, prefix, cls):
+ name = cls.__name__
+ if cls.__module__ != "builtins":
+ name = f"{cls.__module__}.{name}"
+
+ parts = name.split(".")
+ for i in range(len(parts)):
+ mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
+ meth = getattr(self, mname, None)
+ if meth:
+ return meth
+
+ return None
+
+ def make(self, spec):
+ # spec can be a type or a tuple (type, options)
+ return self.get_maker(spec)(spec)
+
+ def example(self, spec):
+ # A good representative of the object - no degenerate case
+ cls = spec if isinstance(spec, type) else spec[0]
+ meth = self._get_method("example", cls)
+ if meth:
+ return meth(spec)
+ else:
+ return self.make(spec)
+
+ def match_any(self, spec, got, want):
+ assert got == want
+
+ # methods to generate samples of specific types
+
+ def make_Binary(self, spec):
+ return self.make_bytes(spec)
+
+ def match_Binary(self, spec, got, want):
+ return want.obj == got
+
+ def make_bool(self, spec):
+ return choice((True, False))
+
+ def make_bytearray(self, spec):
+ return self.make_bytes(spec)
+
+ def make_bytes(self, spec):
+ length = randrange(self.str_max_length)
+ return spec(bytes([randrange(256) for i in range(length)]))
+
+ def make_date(self, spec):
+ day = randrange(dt.date.max.toordinal())
+ return dt.date.fromordinal(day + 1)
+
+ def schema_datetime(self, cls):
+ return self.schema_time(cls)
+
+ def make_datetime(self, spec):
+ # Add a day because with timezone we might go BC
+ dtmin = dt.datetime.min + dt.timedelta(days=1)
+ delta = dt.datetime.max - dtmin
+ micros = randrange((delta.days + 1) * 24 * 60 * 60 * 1_000_000)
+ rv = dtmin + dt.timedelta(microseconds=micros)
+ if spec[1]:
+ rv = rv.replace(tzinfo=self._make_tz(spec))
+ return rv
+
+ def match_datetime(self, spec, got, want):
+ # Comparisons with different timezones is unreliable: certain pairs
+ # are reported different but their delta is 0
+ # https://bugs.python.org/issue45347
+ assert not (got - want)
+
+ def make_Decimal(self, spec):
+ if random() >= 0.99:
+ return Decimal(choice(self._decimal_special_values()))
+
+ sign = choice("+-")
+ num = choice(["0.zd", "d", "d.d"])
+ while "z" in num:
+ ndigits = randrange(1, 20)
+ num = num.replace("z", "0" * ndigits, 1)
+ while "d" in num:
+ ndigits = randrange(1, 20)
+ num = num.replace(
+ "d", "".join([str(randrange(10)) for i in range(ndigits)]), 1
+ )
+ expsign = choice(["e+", "e-", ""])
+ exp = randrange(20) if expsign else ""
+ rv = Decimal(f"{sign}{num}{expsign}{exp}")
+ return rv
+
+ def match_Decimal(self, spec, got, want):
+ if got is not None and got.is_nan():
+ assert want.is_nan()
+ else:
+ assert got == want
+
+ def _decimal_special_values(self):
+ values = ["NaN", "sNaN"]
+
+ if self.conn.info.vendor == "PostgreSQL":
+ if self.conn.info.server_version >= 140000:
+ values.extend(["Inf", "-Inf"])
+ elif self.conn.info.vendor == "CockroachDB":
+ if self.conn.info.server_version >= 220100:
+ values.extend(["Inf", "-Inf"])
+ else:
+ pytest.fail(f"unexpected vendor: {self.conn.info.vendor}")
+
+ return values
+
+ def schema_Enum(self, cls):
+ # TODO: can't fake those as we would need to create temporary types
+ return None
+
+ def make_Enum(self, spec):
+ return None
+
+ def make_float(self, spec, double=True):
+ if random() <= 0.99:
+ # These exponents should generate no inf
+ return float(
+ f"{choice('-+')}0.{randrange(1 << 53)}e{randrange(-310,309)}"
+ if double
+ else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
+ )
+ else:
+ return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan")))
+
+ def match_float(self, spec, got, want, approx=False, rel=None):
+ if got is not None and isnan(got):
+ assert isnan(want)
+ else:
+ if approx or self._server_rounds():
+ assert got == pytest.approx(want, rel=rel)
+ else:
+ assert got == want
+
+ def _server_rounds(self):
+ """Return True if the connected server perform float rounding"""
+ if self.conn.info.vendor == "CockroachDB":
+ return True
+ else:
+ # Versions older than 12 make some rounding. e.g. in Postgres 10.4
+ # select '-1.409006204063909e+112'::float8
+ # -> -1.40900620406391e+112
+ return self.conn.info.server_version < 120000
+
+ def make_Float4(self, spec):
+ return spec(self.make_float(spec, double=False))
+
+ def match_Float4(self, spec, got, want):
+ self.match_float(spec, got, want, approx=True, rel=1e-5)
+
+ def make_Float8(self, spec):
+ return spec(self.make_float(spec))
+
+ match_Float8 = match_float
+
+ def make_int(self, spec):
+ return randrange(-(1 << 90), 1 << 90)
+
+ def make_Int2(self, spec):
+ return spec(randrange(-(1 << 15), 1 << 15))
+
+ def make_Int4(self, spec):
+ return spec(randrange(-(1 << 31), 1 << 31))
+
+ def make_Int8(self, spec):
+ return spec(randrange(-(1 << 63), 1 << 63))
+
+ def make_IntNumeric(self, spec):
+ return spec(randrange(-(1 << 100), 1 << 100))
+
+ def make_IPv4Address(self, spec):
+ return ipaddress.IPv4Address(bytes(randrange(256) for _ in range(4)))
+
+ def make_IPv4Interface(self, spec):
+ prefix = randrange(32)
+ return ipaddress.IPv4Interface(
+ (bytes(randrange(256) for _ in range(4)), prefix)
+ )
+
+ def make_IPv4Network(self, spec):
+ return self.make_IPv4Interface(spec).network
+
+ def make_IPv6Address(self, spec):
+ return ipaddress.IPv6Address(bytes(randrange(256) for _ in range(16)))
+
+ def make_IPv6Interface(self, spec):
+ prefix = randrange(128)
+ return ipaddress.IPv6Interface(
+ (bytes(randrange(256) for _ in range(16)), prefix)
+ )
+
+ def make_IPv6Network(self, spec):
+ return self.make_IPv6Interface(spec).network
+
+ def make_Json(self, spec):
+ return spec(self._make_json())
+
+ def match_Json(self, spec, got, want):
+ if want is not None:
+ want = want.obj
+ assert got == want
+
+ def make_Jsonb(self, spec):
+ return spec(self._make_json())
+
+ def match_Jsonb(self, spec, got, want):
+ self.match_Json(spec, got, want)
+
+ def make_JsonFloat(self, spec):
+ # A float limited to what json accepts
+ # this exponent should generate no inf
+ return float(f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}")
+
+ def schema_list(self, cls):
+ while True:
+ scls = choice(self.types)
+ if scls is cls:
+ continue
+ if scls is float:
+ # TODO: float lists are currently adapted as decimal.
+ # There may be rounding errors or problems with inf.
+ continue
+
+ # CRDB doesn't support arrays of json
+ # https://github.com/cockroachdb/cockroach/issues/23468
+ if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb):
+ continue
+
+ schema = self.make_schema(scls)
+ if schema is not None:
+ break
+
+ return (cls, schema)
+
+ def make_list(self, spec):
+ # don't make empty lists because they regularly fail cast
+ length = randrange(1, self.list_max_length)
+ spec = spec[1]
+ while True:
+ rv = [self.make(spec) for i in range(length)]
+
+ # TODO multirange lists fail binary dump if the last element is
+ # empty and there is no type annotation. See xfail in
+ # test_multirange::test_dump_builtin_array
+ if rv and isinstance(rv[-1], Multirange) and not rv[-1]:
+ continue
+
+ return rv
+
+ def example_list(self, spec):
+ return [self.example(spec[1])]
+
+ def match_list(self, spec, got, want):
+ assert len(got) == len(want)
+ m = self.get_matcher(spec[1])
+ for g, w in zip(got, want):
+ m(spec[1], g, w)
+
+ def make_memoryview(self, spec):
+ return self.make_bytes(spec)
+
+ def schema_Multirange(self, cls):
+ return self.schema_Range(cls)
+
+ def make_Multirange(self, spec, length=None, **kwargs):
+ if length is None:
+ length = randrange(0, self.list_max_length)
+
+ def overlap(r1, r2):
+ l1, u1 = r1.lower, r1.upper
+ l2, u2 = r2.lower, r2.upper
+ if l1 is None and l2 is None:
+ return True
+ elif l1 is None:
+ l1 = l2
+ elif l2 is None:
+ l2 = l1
+
+ if u1 is None and u2 is None:
+ return True
+ elif u1 is None:
+ u1 = u2
+ elif u2 is None:
+ u2 = u1
+
+ return l1 <= u2 and l2 <= u1
+
+ out: List[Range[Any]] = []
+ for i in range(length):
+ r = self.make_Range((Range, spec[1]), **kwargs)
+ if r.isempty:
+ continue
+ for r2 in out:
+ if overlap(r, r2):
+ insert = False
+ break
+ else:
+ insert = True
+ if insert:
+ out.append(r) # alternatively, we could merge
+
+ return spec[0](sorted(out))
+
+ def example_Multirange(self, spec):
+ return self.make_Multirange(spec, length=1, empty_chance=0, no_bound_chance=0)
+
+ def make_Int4Multirange(self, spec):
+ return self.make_Multirange((spec, Int4))
+
+ def make_Int8Multirange(self, spec):
+ return self.make_Multirange((spec, Int8))
+
+ def make_NumericMultirange(self, spec):
+ return self.make_Multirange((spec, Decimal))
+
+ def make_DateMultirange(self, spec):
+ return self.make_Multirange((spec, dt.date))
+
+ def make_TimestampMultirange(self, spec):
+ return self.make_Multirange((spec, (dt.datetime, False)))
+
+ def make_TimestamptzMultirange(self, spec):
+ return self.make_Multirange((spec, (dt.datetime, True)))
+
+ def match_Multirange(self, spec, got, want):
+ assert len(got) == len(want)
+ for ig, iw in zip(got, want):
+ self.match_Range(spec, ig, iw)
+
+ def match_Int4Multirange(self, spec, got, want):
+ return self.match_Multirange((spec, Int4), got, want)
+
+ def match_Int8Multirange(self, spec, got, want):
+ return self.match_Multirange((spec, Int8), got, want)
+
+ def match_NumericMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, Decimal), got, want)
+
+ def match_DateMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, dt.date), got, want)
+
+ def match_TimestampMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, (dt.datetime, False)), got, want)
+
+ def match_TimestamptzMultirange(self, spec, got, want):
+ return self.match_Multirange((spec, (dt.datetime, True)), got, want)
+
+ def schema_NoneType(self, cls):
+ return None
+
+ def make_NoneType(self, spec):
+ return None
+
+ def make_Oid(self, spec):
+ return spec(randrange(1 << 32))
+
+ def schema_Range(self, cls):
+ subtypes = [
+ Decimal,
+ Int4,
+ Int8,
+ dt.date,
+ (dt.datetime, True),
+ (dt.datetime, False),
+ ]
+
+ return (cls, choice(subtypes))
+
+ def make_Range(self, spec, empty_chance=0.02, no_bound_chance=0.05):
+ # TODO: drop format check after fixing binary dumping of empty ranges
+ # (an array starting with an empty range will get the wrong type currently)
+ if (
+ random() < empty_chance
+ and spec[0] is Range
+ and self.format == PyFormat.TEXT
+ ):
+ return spec[0](empty=True)
+
+ while True:
+ bounds: List[Union[Any, None]] = []
+ while len(bounds) < 2:
+ if random() < no_bound_chance:
+ bounds.append(None)
+ continue
+
+ val = self.make(spec[1])
+ # NaN are allowed in a range, but comparison in Python get tricky.
+ if spec[1] is Decimal and val.is_nan():
+ continue
+
+ bounds.append(val)
+
+ if bounds[0] is not None and bounds[1] is not None:
+ if bounds[0] == bounds[1]:
+ # It would come out empty
+ continue
+
+ if bounds[0] > bounds[1]:
+ bounds.reverse()
+
+ # avoid generating ranges with no type info if dumping in binary
+ # TODO: lift this limitation after test_copy_in_empty xfail is fixed
+ if spec[0] is Range and self.format == PyFormat.BINARY:
+ if bounds[0] is bounds[1] is None:
+ continue
+
+ break
+
+ r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])"))
+ return r
+
+ def example_Range(self, spec):
+ return self.make_Range(spec, empty_chance=0, no_bound_chance=0)
+
+ def make_Int4Range(self, spec):
+ return self.make_Range((spec, Int4))
+
+ def make_Int8Range(self, spec):
+ return self.make_Range((spec, Int8))
+
+ def make_NumericRange(self, spec):
+ return self.make_Range((spec, Decimal))
+
+ def make_DateRange(self, spec):
+ return self.make_Range((spec, dt.date))
+
+ def make_TimestampRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, False)))
+
+ def make_TimestamptzRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, True)))
+
+ def match_Range(self, spec, got, want):
+ # normalise the bounds of unbounded ranges
+ if want.lower is None and want.lower_inc:
+ want = type(want)(want.lower, want.upper, "(" + want.bounds[1])
+ if want.upper is None and want.upper_inc:
+ want = type(want)(want.lower, want.upper, want.bounds[0] + ")")
+
+ # Normalise discrete ranges
+ unit: Union[dt.timedelta, int, None]
+ if spec[1] is dt.date:
+ unit = dt.timedelta(days=1)
+ elif type(spec[1]) is type and issubclass(spec[1], int):
+ unit = 1
+ else:
+ unit = None
+
+ if unit is not None:
+ if want.lower is not None and not want.lower_inc:
+ want = type(want)(want.lower + unit, want.upper, "[" + want.bounds[1])
+ if want.upper_inc:
+ want = type(want)(want.lower, want.upper + unit, want.bounds[0] + ")")
+
+ if spec[1] == (dt.datetime, True) and not want.isempty:
+ # work around https://bugs.python.org/issue45347
+ def fix_dt(x):
+ return x.astimezone(dt.timezone.utc) if x is not None else None
+
+ def fix_range(r):
+ return type(r)(fix_dt(r.lower), fix_dt(r.upper), r.bounds)
+
+ want = fix_range(want)
+ got = fix_range(got)
+
+ assert got == want
+
+ def match_Int4Range(self, spec, got, want):
+ return self.match_Range((spec, Int4), got, want)
+
+ def match_Int8Range(self, spec, got, want):
+ return self.match_Range((spec, Int8), got, want)
+
+ def match_NumericRange(self, spec, got, want):
+ return self.match_Range((spec, Decimal), got, want)
+
+ def match_DateRange(self, spec, got, want):
+ return self.match_Range((spec, dt.date), got, want)
+
+ def match_TimestampRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, False)), got, want)
+
+ def match_TimestamptzRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, True)), got, want)
+
+ def make_str(self, spec, length=0):
+ if not length:
+ length = randrange(self.str_max_length)
+
+ rv: List[int] = []
+ while len(rv) < length:
+ c = randrange(1, 128) if random() < 0.5 else randrange(1, 0x110000)
+ if not (0xD800 <= c <= 0xDBFF or 0xDC00 <= c <= 0xDFFF):
+ rv.append(c)
+
+ return "".join(map(chr, rv))
+
+ def schema_time(self, cls):
+ # Choose timezone yes/no
+ return (cls, choice([True, False]))
+
+ def make_time(self, spec):
+ val = randrange(24 * 60 * 60 * 1_000_000)
+ val, ms = divmod(val, 1_000_000)
+ val, s = divmod(val, 60)
+ h, m = divmod(val, 60)
+ tz = self._make_tz(spec) if spec[1] else None
+ return dt.time(h, m, s, ms, tz)
+
+ CRDB_TIMEDELTA_MAX = dt.timedelta(days=1281239)
+
+ def make_timedelta(self, spec):
+ if self.conn.info.vendor == "CockroachDB":
+ rng = [-self.CRDB_TIMEDELTA_MAX, self.CRDB_TIMEDELTA_MAX]
+ else:
+ rng = [dt.timedelta.min, dt.timedelta.max]
+
+ return choice(rng) * random()
+
+ def schema_tuple(self, cls):
+ # TODO: this is a complicated matter as it would involve creating
+ # temporary composite types.
+ # length = randrange(1, self.tuple_max_length)
+ # return (cls, self.make_random_schema(ncols=length))
+ return None
+
+ def make_tuple(self, spec):
+ return tuple(self.make(s) for s in spec[1])
+
+ def match_tuple(self, spec, got, want):
+ assert len(got) == len(want) == len(spec[1])
+ for g, w, s in zip(got, want, spec):
+ if g is None or w is None:
+ assert g is w
+ else:
+ m = self.get_matcher(s)
+ m(s, g, w)
+
+ def make_UUID(self, spec):
+ return UUID(bytes=bytes([randrange(256) for i in range(16)]))
+
+ def _make_json(self, container_chance=0.66):
+ rec_types = [list, dict]
+ scal_types = [type(None), int, JsonFloat, bool, str]
+ if random() < container_chance:
+ cls = choice(rec_types)
+ if cls is list:
+ return [
+ self._make_json(container_chance=container_chance / 2.0)
+ for i in range(randrange(self.json_max_length))
+ ]
+ elif cls is dict:
+ return {
+ self.make_str(str, 15): self._make_json(
+ container_chance=container_chance / 2.0
+ )
+ for i in range(randrange(self.json_max_length))
+ }
+ else:
+ assert False, f"unknown rec type: {cls}"
+
+ else:
+ cls = choice(scal_types) # type: ignore[assignment]
+ return self.make(cls)
+
+ def _make_tz(self, spec):
+ minutes = randrange(-12 * 60, 12 * 60 + 1)
+ return dt.timezone(dt.timedelta(minutes=minutes))
+
+
+class JsonFloat:
+ pass
+
+
+def deep_import(name):
+ parts = Deque(name.split("."))
+ seen = []
+ if not parts:
+ raise ValueError("name must be a dot-separated name")
+
+ seen.append(parts.popleft())
+ thing = importlib.import_module(seen[-1])
+ while parts:
+ attr = parts.popleft()
+ seen.append(attr)
+
+ if hasattr(thing, attr):
+ thing = getattr(thing, attr)
+ else:
+ thing = importlib.import_module(".".join(seen))
+
+ return thing