diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-14 08:55:51 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-14 08:55:51 +0000 |
commit | b4ea5a3722487b0d46f197158c46229405b1048f (patch) | |
tree | 301b9af97ef5b1b5f72d6e5ef32aba3b93b73c39 /pydantic_extra_types | |
parent | Initial commit. (diff) | |
download | pydantic-extra-types-b4ea5a3722487b0d46f197158c46229405b1048f.tar.xz pydantic-extra-types-b4ea5a3722487b0d46f197158c46229405b1048f.zip |
Adding upstream version 2.6.0.upstream/2.6.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pydantic_extra_types')
-rw-r--r-- | pydantic_extra_types/__init__.py | 1 | ||||
-rw-r--r-- | pydantic_extra_types/color.py | 636 | ||||
-rw-r--r-- | pydantic_extra_types/coordinate.py | 145 | ||||
-rw-r--r-- | pydantic_extra_types/country.py | 281 | ||||
-rw-r--r-- | pydantic_extra_types/currency_code.py | 179 | ||||
-rw-r--r-- | pydantic_extra_types/isbn.py | 152 | ||||
-rw-r--r-- | pydantic_extra_types/language_code.py | 182 | ||||
-rw-r--r-- | pydantic_extra_types/mac_address.py | 125 | ||||
-rw-r--r-- | pydantic_extra_types/payment.py | 199 | ||||
-rw-r--r-- | pydantic_extra_types/pendulum_dt.py | 74 | ||||
-rw-r--r-- | pydantic_extra_types/phone_numbers.py | 68 | ||||
-rw-r--r-- | pydantic_extra_types/py.typed | 0 | ||||
-rw-r--r-- | pydantic_extra_types/routing_number.py | 89 | ||||
-rw-r--r-- | pydantic_extra_types/ulid.py | 62 |
14 files changed, 2193 insertions, 0 deletions
diff --git a/pydantic_extra_types/__init__.py b/pydantic_extra_types/__init__.py new file mode 100644 index 0000000..f0e5e1e --- /dev/null +++ b/pydantic_extra_types/__init__.py @@ -0,0 +1 @@ +__version__ = '2.6.0' diff --git a/pydantic_extra_types/color.py b/pydantic_extra_types/color.py new file mode 100644 index 0000000..34aa441 --- /dev/null +++ b/pydantic_extra_types/color.py @@ -0,0 +1,636 @@ +""" +Color definitions are used as per the CSS3 +[CSS Color Module Level 3](http://www.w3.org/TR/css3-color/#svg-color) specification. + +A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`. + +In these cases the _last_ color when sorted alphabetically takes preferences, +eg. `Color((0, 255, 255)).as_named() == 'cyan'` because "cyan" comes after "aqua". +""" +from __future__ import annotations + +import math +import re +from colorsys import hls_to_rgb, rgb_to_hls +from typing import Any, Callable, Literal, Tuple, Union, cast + +from pydantic import GetJsonSchemaHandler +from pydantic._internal import _repr +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import CoreSchema, PydanticCustomError, core_schema + +ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]] +ColorType = Union[ColorTuple, str, 'Color'] +HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]] + + +class RGBA: + """ + Internal use only as a representation of a color. + """ + + __slots__ = 'r', 'g', 'b', 'alpha', '_tuple' + + def __init__(self, r: float, g: float, b: float, alpha: float | None): + self.r = r + self.g = g + self.b = b + self.alpha = alpha + + self._tuple: tuple[float, float, float, float | None] = (r, g, b, alpha) + + def __getitem__(self, item: Any) -> Any: + return self._tuple[item] + + +# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached +_r_255 = r'(\d{1,3}(?:\.\d+)?)' +_r_comma = r'\s*,\s*' +_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)' +_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?' +_r_sl = r'(\d{1,3}(?:\.\d+)?)%' +r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*' +r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*' +# CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%) +r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*' +# CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%) +r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*' +# CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%) +r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*' +# CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%) +r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*' + +# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used +repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'} +rads = 2 * math.pi + + +class Color(_repr.Representation): + """ + Represents a color. + """ + + __slots__ = '_original', '_rgba' + + def __init__(self, value: ColorType) -> None: + self._rgba: RGBA + self._original: ColorType + if isinstance(value, (tuple, list)): + self._rgba = parse_tuple(value) + elif isinstance(value, str): + self._rgba = parse_str(value) + elif isinstance(value, Color): + self._rgba = value._rgba + value = value._original + else: + raise PydanticCustomError( + 'color_error', + 'value is not a valid color: value must be a tuple, list or string', + ) + + # if we've got here value must be a valid color + self._original = value + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema: dict[str, Any] = {} + field_schema.update(type='string', format='color') + return field_schema + + def original(self) -> ColorType: + """ + Original value passed to `Color`. + """ + return self._original + + def as_named(self, *, fallback: bool = False) -> str: + """ + Returns the name of the color if it can be found in `COLORS_BY_VALUE` dictionary, + otherwise returns the hexadecimal representation of the color or raises `ValueError`. + + Args: + fallback: If True, falls back to returning the hexadecimal representation of + the color instead of raising a ValueError when no named color is found. + + Returns: + The name of the color, or the hexadecimal representation of the color. + + Raises: + ValueError: When no named color is found and fallback is `False`. + """ + if self._rgba.alpha is None: + rgb = cast(Tuple[int, int, int], self.as_rgb_tuple()) + try: + return COLORS_BY_VALUE[rgb] + except KeyError as e: + if fallback: + return self.as_hex() + else: + raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e + else: + return self.as_hex() + + def as_hex(self, format: Literal['short', 'long'] = 'short') -> str: + """Returns the hexadecimal representation of the color. + + Hex string representing the color can be 3, 4, 6, or 8 characters depending on whether the string + a "short" representation of the color is possible and whether there's an alpha channel. + + Returns: + The hexadecimal representation of the color. + """ + values = [float_to_255(c) for c in self._rgba[:3]] + if self._rgba.alpha is not None: + values.append(float_to_255(self._rgba.alpha)) + + as_hex = ''.join(f'{v:02x}' for v in values) + if format == 'short' and all(c in repeat_colors for c in values): + as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2)) + return '#' + as_hex + + def as_rgb(self) -> str: + """ + Color as an `rgb(<r>, <g>, <b>)` or `rgba(<r>, <g>, <b>, <a>)` string. + """ + if self._rgba.alpha is None: + return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})' + else: + return ( + f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, ' + f'{round(self._alpha_float(), 2)})' + ) + + def as_rgb_tuple(self, *, alpha: bool | None = None) -> ColorTuple: + """ + Returns the color as an RGB or RGBA tuple. + + Args: + alpha: Whether to include the alpha channel. There are three options for this input: + + - `None` (default): Include alpha only if it's set. (e.g. not `None`) + - `True`: Always include alpha. + - `False`: Always omit alpha. + + Returns: + A tuple that contains the values of the red, green, and blue channels in the range 0 to 255. + If alpha is included, it is in the range 0 to 1. + """ + r, g, b = (float_to_255(c) for c in self._rgba[:3]) + if alpha is None: + if self._rgba.alpha is None: + return r, g, b + else: + return r, g, b, self._alpha_float() + elif alpha: + return r, g, b, self._alpha_float() + else: + # alpha is False + return r, g, b + + def as_hsl(self) -> str: + """ + Color as an `hsl(<h>, <s>, <l>)` or `hsl(<h>, <s>, <l>, <a>)` string. + """ + if self._rgba.alpha is None: + h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore + return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})' + else: + h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore + return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})' + + def as_hsl_tuple(self, *, alpha: bool | None = None) -> HslColorTuple: + """ + Returns the color as an HSL or HSLA tuple. + + Args: + alpha: Whether to include the alpha channel. + + - `None` (default): Include the alpha channel only if it's set (e.g. not `None`). + - `True`: Always include alpha. + - `False`: Always omit alpha. + + Returns: + The color as a tuple of hue, saturation, lightness, and alpha (if included). + All elements are in the range 0 to 1. + + Note: + This is HSL as used in HTML and most other places, not HLS as used in Python's `colorsys`. + """ + h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b) + if alpha is None: + if self._rgba.alpha is None: + return h, s, l + else: + return h, s, l, self._alpha_float() + if alpha: + return h, s, l, self._alpha_float() + else: + # alpha is False + return h, s, l + + def _alpha_float(self) -> float: + return 1 if self._rgba.alpha is None else self._rgba.alpha + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: Callable[[Any], CoreSchema] + ) -> core_schema.CoreSchema: + return core_schema.with_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) + + @classmethod + def _validate(cls, __input_value: Any, _: Any) -> Color: + return cls(__input_value) + + def __str__(self) -> str: + return self.as_named(fallback=True) + + def __repr_args__(self) -> _repr.ReprArgs: + return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] + + def __eq__(self, other: Any) -> bool: + return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple() + + def __hash__(self) -> int: + return hash(self.as_rgb_tuple()) + + +def parse_tuple(value: tuple[Any, ...]) -> RGBA: + """Parse a tuple or list to get RGBA values. + + Args: + value: A tuple or list. + + Returns: + An `RGBA` tuple parsed from the input tuple. + + Raises: + PydanticCustomError: If tuple is not valid. + """ + if len(value) == 3: + r, g, b = (parse_color_value(v) for v in value) + return RGBA(r, g, b, None) + elif len(value) == 4: + r, g, b = (parse_color_value(v) for v in value[:3]) + return RGBA(r, g, b, parse_float_alpha(value[3])) + else: + raise PydanticCustomError('color_error', 'value is not a valid color: tuples must have length 3 or 4') + + +def parse_str(value: str) -> RGBA: + """ + Parse a string representing a color to an RGBA tuple. + + Possible formats for the input string include: + + * named color, see `COLORS_BY_NAME` + * hex short eg. `<prefix>fff` (prefix can be `#`, `0x` or nothing) + * hex long eg. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing) + * `rgb(<r>, <g>, <b>)` + * `rgba(<r>, <g>, <b>, <a>)` + * `transparent` + + Args: + value: A string representing a color. + + Returns: + An `RGBA` tuple parsed from the input string. + + Raises: + ValueError: If the input string cannot be parsed to an RGBA tuple. + """ + value_lower = value.lower() + try: + r, g, b = COLORS_BY_NAME[value_lower] + except KeyError: + pass + else: + return ints_to_rgba(r, g, b, None) + + m = re.fullmatch(r_hex_short, value_lower) + if m: + *rgb, a = m.groups() + r, g, b = (int(v * 2, 16) for v in rgb) + if a: + alpha: float | None = int(a * 2, 16) / 255 + else: + alpha = None + return ints_to_rgba(r, g, b, alpha) + + m = re.fullmatch(r_hex_long, value_lower) + if m: + *rgb, a = m.groups() + r, g, b = (int(v, 16) for v in rgb) + if a: + alpha = int(a, 16) / 255 + else: + alpha = None + return ints_to_rgba(r, g, b, alpha) + + m = re.fullmatch(r_rgb, value_lower) or re.fullmatch(r_rgb_v4_style, value_lower) + if m: + return ints_to_rgba(*m.groups()) # type: ignore + + m = re.fullmatch(r_hsl, value_lower) or re.fullmatch(r_hsl_v4_style, value_lower) + if m: + return parse_hsl(*m.groups()) # type: ignore + + if value_lower == 'transparent': + return RGBA(0, 0, 0, 0) + + raise PydanticCustomError( + 'color_error', + 'value is not a valid color: string not recognised as a valid color', + ) + + +def ints_to_rgba( + r: int | str, + g: int | str, + b: int | str, + alpha: float | None = None, +) -> RGBA: + """ + Converts integer or string values for RGB color and an optional alpha value to an `RGBA` object. + + Args: + r: An integer or string representing the red color value. + g: An integer or string representing the green color value. + b: An integer or string representing the blue color value. + alpha: A float representing the alpha value. Defaults to None. + + Returns: + An instance of the `RGBA` class with the corresponding color and alpha values. + """ + return RGBA( + parse_color_value(r), + parse_color_value(g), + parse_color_value(b), + parse_float_alpha(alpha), + ) + + +def parse_color_value(value: int | str, max_val: int = 255) -> float: + """ + Parse the color value provided and return a number between 0 and 1. + + Args: + value: An integer or string color value. + max_val: Maximum range value. Defaults to 255. + + Raises: + PydanticCustomError: If the value is not a valid color. + + Returns: + A number between 0 and 1. + """ + try: + color = float(value) + except ValueError: + raise PydanticCustomError( + 'color_error', + 'value is not a valid color: color values must be a valid number', + ) + if 0 <= color <= max_val: + return color / max_val + else: + raise PydanticCustomError( + 'color_error', + 'value is not a valid color: color values must be in the range 0 to {max_val}', + {'max_val': max_val}, + ) + + +def parse_float_alpha(value: None | str | float | int) -> float | None: + """ + Parse an alpha value checking it's a valid float in the range 0 to 1. + + Args: + value: The input value to parse. + + Returns: + The parsed value as a float, or `None` if the value was None or equal 1. + + Raises: + PydanticCustomError: If the input value cannot be successfully parsed as a float in the expected range. + """ + if value is None: + return None + try: + if isinstance(value, str) and value.endswith('%'): + alpha = float(value[:-1]) / 100 + else: + alpha = float(value) + except ValueError: + raise PydanticCustomError( + 'color_error', + 'value is not a valid color: alpha values must be a valid float', + ) + + if math.isclose(alpha, 1): + return None + elif 0 <= alpha <= 1: + return alpha + else: + raise PydanticCustomError( + 'color_error', + 'value is not a valid color: alpha values must be in the range 0 to 1', + ) + + +def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: float | None = None) -> RGBA: + """ + Parse raw hue, saturation, lightness, and alpha values and convert to RGBA. + + Args: + h: The hue value. + h_units: The unit for hue value. + sat: The saturation value. + light: The lightness value. + alpha: Alpha value. + + Returns: + An instance of `RGBA`. + """ + s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100) + + h_value = float(h) + if h_units in {None, 'deg'}: + h_value = h_value % 360 / 360 + elif h_units == 'rad': + h_value = h_value % rads / rads + else: + # turns + h_value = h_value % 1 + + r, g, b = hls_to_rgb(h_value, l_value, s_value) + return RGBA(r, g, b, parse_float_alpha(alpha)) + + +def float_to_255(c: float) -> int: + """ + Converts a float value between 0 and 1 (inclusive) to an integer between 0 and 255 (inclusive). + + Args: + c: The float value to be converted. Must be between 0 and 1 (inclusive). + + Returns: + The integer equivalent of the given float value rounded to the nearest whole number. + """ + return round(c * 255) + + +COLORS_BY_NAME = { + 'aliceblue': (240, 248, 255), + 'antiquewhite': (250, 235, 215), + 'aqua': (0, 255, 255), + 'aquamarine': (127, 255, 212), + 'azure': (240, 255, 255), + 'beige': (245, 245, 220), + 'bisque': (255, 228, 196), + 'black': (0, 0, 0), + 'blanchedalmond': (255, 235, 205), + 'blue': (0, 0, 255), + 'blueviolet': (138, 43, 226), + 'brown': (165, 42, 42), + 'burlywood': (222, 184, 135), + 'cadetblue': (95, 158, 160), + 'chartreuse': (127, 255, 0), + 'chocolate': (210, 105, 30), + 'coral': (255, 127, 80), + 'cornflowerblue': (100, 149, 237), + 'cornsilk': (255, 248, 220), + 'crimson': (220, 20, 60), + 'cyan': (0, 255, 255), + 'darkblue': (0, 0, 139), + 'darkcyan': (0, 139, 139), + 'darkgoldenrod': (184, 134, 11), + 'darkgray': (169, 169, 169), + 'darkgreen': (0, 100, 0), + 'darkgrey': (169, 169, 169), + 'darkkhaki': (189, 183, 107), + 'darkmagenta': (139, 0, 139), + 'darkolivegreen': (85, 107, 47), + 'darkorange': (255, 140, 0), + 'darkorchid': (153, 50, 204), + 'darkred': (139, 0, 0), + 'darksalmon': (233, 150, 122), + 'darkseagreen': (143, 188, 143), + 'darkslateblue': (72, 61, 139), + 'darkslategray': (47, 79, 79), + 'darkslategrey': (47, 79, 79), + 'darkturquoise': (0, 206, 209), + 'darkviolet': (148, 0, 211), + 'deeppink': (255, 20, 147), + 'deepskyblue': (0, 191, 255), + 'dimgray': (105, 105, 105), + 'dimgrey': (105, 105, 105), + 'dodgerblue': (30, 144, 255), + 'firebrick': (178, 34, 34), + 'floralwhite': (255, 250, 240), + 'forestgreen': (34, 139, 34), + 'fuchsia': (255, 0, 255), + 'gainsboro': (220, 220, 220), + 'ghostwhite': (248, 248, 255), + 'gold': (255, 215, 0), + 'goldenrod': (218, 165, 32), + 'gray': (128, 128, 128), + 'green': (0, 128, 0), + 'greenyellow': (173, 255, 47), + 'grey': (128, 128, 128), + 'honeydew': (240, 255, 240), + 'hotpink': (255, 105, 180), + 'indianred': (205, 92, 92), + 'indigo': (75, 0, 130), + 'ivory': (255, 255, 240), + 'khaki': (240, 230, 140), + 'lavender': (230, 230, 250), + 'lavenderblush': (255, 240, 245), + 'lawngreen': (124, 252, 0), + 'lemonchiffon': (255, 250, 205), + 'lightblue': (173, 216, 230), + 'lightcoral': (240, 128, 128), + 'lightcyan': (224, 255, 255), + 'lightgoldenrodyellow': (250, 250, 210), + 'lightgray': (211, 211, 211), + 'lightgreen': (144, 238, 144), + 'lightgrey': (211, 211, 211), + 'lightpink': (255, 182, 193), + 'lightsalmon': (255, 160, 122), + 'lightseagreen': (32, 178, 170), + 'lightskyblue': (135, 206, 250), + 'lightslategray': (119, 136, 153), + 'lightslategrey': (119, 136, 153), + 'lightsteelblue': (176, 196, 222), + 'lightyellow': (255, 255, 224), + 'lime': (0, 255, 0), + 'limegreen': (50, 205, 50), + 'linen': (250, 240, 230), + 'magenta': (255, 0, 255), + 'maroon': (128, 0, 0), + 'mediumaquamarine': (102, 205, 170), + 'mediumblue': (0, 0, 205), + 'mediumorchid': (186, 85, 211), + 'mediumpurple': (147, 112, 219), + 'mediumseagreen': (60, 179, 113), + 'mediumslateblue': (123, 104, 238), + 'mediumspringgreen': (0, 250, 154), + 'mediumturquoise': (72, 209, 204), + 'mediumvioletred': (199, 21, 133), + 'midnightblue': (25, 25, 112), + 'mintcream': (245, 255, 250), + 'mistyrose': (255, 228, 225), + 'moccasin': (255, 228, 181), + 'navajowhite': (255, 222, 173), + 'navy': (0, 0, 128), + 'oldlace': (253, 245, 230), + 'olive': (128, 128, 0), + 'olivedrab': (107, 142, 35), + 'orange': (255, 165, 0), + 'orangered': (255, 69, 0), + 'orchid': (218, 112, 214), + 'palegoldenrod': (238, 232, 170), + 'palegreen': (152, 251, 152), + 'paleturquoise': (175, 238, 238), + 'palevioletred': (219, 112, 147), + 'papayawhip': (255, 239, 213), + 'peachpuff': (255, 218, 185), + 'peru': (205, 133, 63), + 'pink': (255, 192, 203), + 'plum': (221, 160, 221), + 'powderblue': (176, 224, 230), + 'purple': (128, 0, 128), + 'red': (255, 0, 0), + 'rosybrown': (188, 143, 143), + 'royalblue': (65, 105, 225), + 'saddlebrown': (139, 69, 19), + 'salmon': (250, 128, 114), + 'sandybrown': (244, 164, 96), + 'seagreen': (46, 139, 87), + 'seashell': (255, 245, 238), + 'sienna': (160, 82, 45), + 'silver': (192, 192, 192), + 'skyblue': (135, 206, 235), + 'slateblue': (106, 90, 205), + 'slategray': (112, 128, 144), + 'slategrey': (112, 128, 144), + 'snow': (255, 250, 250), + 'springgreen': (0, 255, 127), + 'steelblue': (70, 130, 180), + 'tan': (210, 180, 140), + 'teal': (0, 128, 128), + 'thistle': (216, 191, 216), + 'tomato': (255, 99, 71), + 'turquoise': (64, 224, 208), + 'violet': (238, 130, 238), + 'wheat': (245, 222, 179), + 'white': (255, 255, 255), + 'whitesmoke': (245, 245, 245), + 'yellow': (255, 255, 0), + 'yellowgreen': (154, 205, 50), +} + +COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()} diff --git a/pydantic_extra_types/coordinate.py b/pydantic_extra_types/coordinate.py new file mode 100644 index 0000000..df470d5 --- /dev/null +++ b/pydantic_extra_types/coordinate.py @@ -0,0 +1,145 @@ +""" +The `pydantic_extra_types.coordinate` module provides the [`Latitude`][pydantic_extra_types.coordinate.Latitude], +[`Longitude`][pydantic_extra_types.coordinate.Longitude], and +[`Coordinate`][pydantic_extra_types.coordinate.Coordinate] data types. +""" +from dataclasses import dataclass +from typing import Any, ClassVar, Tuple, Type + +from pydantic import GetCoreSchemaHandler +from pydantic._internal import _repr +from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema + + +class Latitude(float): + """Latitude value should be between -90 and 90, inclusive. + + ```py + from pydantic import BaseModel + from pydantic_extra_types.coordinate import Latitude + + class Location(BaseModel): + latitude: Latitude + + location = Location(latitude=41.40338) + print(location) + #> latitude=41.40338 + ``` + """ + + min: ClassVar[float] = -90.00 + max: ClassVar[float] = 90.00 + + @classmethod + def __get_pydantic_core_schema__(cls, source: Type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.float_schema(ge=cls.min, le=cls.max) + + +class Longitude(float): + """Longitude value should be between -180 and 180, inclusive. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.coordinate import Longitude + + class Location(BaseModel): + longitude: Longitude + + location = Location(longitude=2.17403) + print(location) + #> longitude=2.17403 + ``` + """ + + min: ClassVar[float] = -180.00 + max: ClassVar[float] = 180.00 + + @classmethod + def __get_pydantic_core_schema__(cls, source: Type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.float_schema(ge=cls.min, le=cls.max) + + +@dataclass +class Coordinate(_repr.Representation): + """Coordinate parses Latitude and Longitude. + + You can use the `Coordinate` data type for storing coordinates. Coordinates can be + defined using one of the following formats: + + 1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)`. + 2. `Coordinate` instance: `Coordinate(latitude=Latitude, longitude=Longitude)`. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.coordinate import Coordinate + + class Location(BaseModel): + coordinate: Coordinate + + location = Location(coordinate=(41.40338, 2.17403)) + #> coordinate=Coordinate(latitude=41.40338, longitude=2.17403) + ``` + """ + + _NULL_ISLAND: ClassVar[Tuple[float, float]] = (0.0, 0.0) + + latitude: Latitude + longitude: Longitude + + @classmethod + def __get_pydantic_core_schema__(cls, source: Type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + schema_chain = [ + core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()), + core_schema.no_info_wrap_validator_function( + cls._parse_tuple, + handler.generate_schema(Tuple[float, float]), + ), + handler(source), + ] + + chain_length = len(schema_chain) + chain_schemas = [core_schema.chain_schema(schema_chain[x:]) for x in range(chain_length - 1, -1, -1)] + return core_schema.no_info_wrap_validator_function( + cls._parse_args, + core_schema.union_schema(chain_schemas), # type: ignore[arg-type] + ) + + @classmethod + def _parse_args(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: + if isinstance(value, ArgsKwargs) and not value.kwargs: + n_args = len(value.args) + if n_args == 0: + value = cls._NULL_ISLAND + elif n_args == 1: + value = value.args[0] + return handler(value) + + @classmethod + def _parse_str(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: + if not isinstance(value, str): + return value + try: + value = tuple(float(x) for x in value.split(',')) + except ValueError: + raise PydanticCustomError( + 'coordinate_error', + 'value is not a valid coordinate: string is not recognized as a valid coordinate', + ) + return ArgsKwargs(args=value) + + @classmethod + def _parse_tuple(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: + if not isinstance(value, tuple): + return value + return ArgsKwargs(args=handler(value)) + + def __str__(self) -> str: + return f'{self.latitude},{self.longitude}' + + def __eq__(self, other: Any) -> bool: + return isinstance(other, Coordinate) and self.latitude == other.latitude and self.longitude == other.longitude + + def __hash__(self) -> int: + return hash((self.latitude, self.longitude)) diff --git a/pydantic_extra_types/country.py b/pydantic_extra_types/country.py new file mode 100644 index 0000000..a6d26e2 --- /dev/null +++ b/pydantic_extra_types/country.py @@ -0,0 +1,281 @@ +""" +Country definitions that are based on the [ISO 3166](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). +""" +from __future__ import annotations + +from dataclasses import dataclass +from functools import lru_cache +from typing import Any + +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + +try: + import pycountry +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + 'The `country` module requires "pycountry" to be installed. You can install it with "pip install pycountry".' + ) + + +@dataclass +class CountryInfo: + alpha2: str + alpha3: str + numeric_code: str + short_name: str + + +@lru_cache +def _countries() -> list[CountryInfo]: + return [ + CountryInfo( + alpha2=country.alpha_2, + alpha3=country.alpha_3, + numeric_code=country.numeric, + short_name=country.name, + ) + for country in pycountry.countries + ] + + +@lru_cache +def _index_by_alpha2() -> dict[str, CountryInfo]: + return {country.alpha2: country for country in _countries()} + + +@lru_cache +def _index_by_alpha3() -> dict[str, CountryInfo]: + return {country.alpha3: country for country in _countries()} + + +@lru_cache +def _index_by_numeric_code() -> dict[str, CountryInfo]: + return {country.numeric_code: country for country in _countries()} + + +@lru_cache +def _index_by_short_name() -> dict[str, CountryInfo]: + return {country.short_name: country for country in _countries()} + + +class CountryAlpha2(str): + """CountryAlpha2 parses country codes in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) + format. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.country import CountryAlpha2 + + class Product(BaseModel): + made_in: CountryAlpha2 + + product = Product(made_in='ES') + print(product) + #> made_in='ES' + ``` + """ + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryAlpha2: + if __input_value not in _index_by_alpha2(): + raise PydanticCustomError('country_alpha2', 'Invalid country alpha2 code') + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(to_upper=True), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + json_schema = handler(schema) + json_schema.update({'pattern': r'^\w{2}$'}) + return json_schema + + @property + def alpha3(self) -> str: + """The country code in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3) format.""" + return _index_by_alpha2()[self].alpha3 + + @property + def numeric_code(self) -> str: + """The country code in the [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format.""" + return _index_by_alpha2()[self].numeric_code + + @property + def short_name(self) -> str: + """The country short name.""" + return _index_by_alpha2()[self].short_name + + +class CountryAlpha3(str): + """CountryAlpha3 parses country codes in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3) + format. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.country import CountryAlpha3 + + class Product(BaseModel): + made_in: CountryAlpha3 + + product = Product(made_in="USA") + print(product) + #> made_in='USA' + ``` + """ + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryAlpha3: + if __input_value not in _index_by_alpha3(): + raise PydanticCustomError('country_alpha3', 'Invalid country alpha3 code') + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(to_upper=True), + serialization=core_schema.to_string_ser_schema(), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + json_schema = handler(schema) + json_schema.update({'pattern': r'^\w{3}$'}) + return json_schema + + @property + def alpha2(self) -> str: + """The country code in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) format.""" + return _index_by_alpha3()[self].alpha2 + + @property + def numeric_code(self) -> str: + """The country code in the [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format.""" + return _index_by_alpha3()[self].numeric_code + + @property + def short_name(self) -> str: + """The country short name.""" + return _index_by_alpha3()[self].short_name + + +class CountryNumericCode(str): + """CountryNumericCode parses country codes in the + [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.country import CountryNumericCode + + class Product(BaseModel): + made_in: CountryNumericCode + + product = Product(made_in="840") + print(product) + #> made_in='840' + ``` + """ + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryNumericCode: + if __input_value not in _index_by_numeric_code(): + raise PydanticCustomError('country_numeric_code', 'Invalid country numeric code') + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(to_upper=True), + serialization=core_schema.to_string_ser_schema(), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + json_schema = handler(schema) + json_schema.update({'pattern': r'^[0-9]{3}$'}) + return json_schema + + @property + def alpha2(self) -> str: + """The country code in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) format.""" + return _index_by_numeric_code()[self].alpha2 + + @property + def alpha3(self) -> str: + """The country code in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3) format.""" + return _index_by_numeric_code()[self].alpha3 + + @property + def short_name(self) -> str: + """The country short name.""" + return _index_by_numeric_code()[self].short_name + + +class CountryShortName(str): + """CountryShortName parses country codes in the short name format. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.country import CountryShortName + + class Product(BaseModel): + made_in: CountryShortName + + product = Product(made_in="United States") + print(product) + #> made_in='United States' + ``` + """ + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> CountryShortName: + if __input_value not in _index_by_short_name(): + raise PydanticCustomError('country_short_name', 'Invalid country short name') + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(), + serialization=core_schema.to_string_ser_schema(), + ) + + @property + def alpha2(self) -> str: + """The country code in the [ISO 3166-1 alpha-2](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) format.""" + return _index_by_short_name()[self].alpha2 + + @property + def alpha3(self) -> str: + """The country code in the [ISO 3166-1 alpha-3](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-3) format.""" + return _index_by_short_name()[self].alpha3 + + @property + def numeric_code(self) -> str: + """The country code in the [ISO 3166-1 numeric](https://en.wikipedia.org/wiki/ISO_3166-1_numeric) format.""" + return _index_by_short_name()[self].numeric_code diff --git a/pydantic_extra_types/currency_code.py b/pydantic_extra_types/currency_code.py new file mode 100644 index 0000000..c19d9bf --- /dev/null +++ b/pydantic_extra_types/currency_code.py @@ -0,0 +1,179 @@ +""" +Currency definitions that are based on the [ISO4217](https://en.wikipedia.org/wiki/ISO_4217). +""" +from __future__ import annotations + +from typing import Any + +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + +try: + import pycountry +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + 'The `currency_code` module requires "pycountry" to be installed. You can install it with "pip install ' + 'pycountry".' + ) + +# List of codes that should not be usually used within regular transactions +_CODES_FOR_BONDS_METAL_TESTING = { + 'XTS', # testing + 'XAU', # gold + 'XAG', # silver + 'XPD', # palladium + 'XPT', # platinum + 'XBA', # Bond Markets Unit European Composite Unit (EURCO) + 'XBB', # Bond Markets Unit European Monetary Unit (E.M.U.-6) + 'XBC', # Bond Markets Unit European Unit of Account 9 (E.U.A.-9) + 'XBD', # Bond Markets Unit European Unit of Account 17 (E.U.A.-17) + 'XXX', # no currency + 'XDR', # SDR (Special Drawing Right) +} + + +class ISO4217(str): + """ISO4217 parses Currency in the [ISO 4217](https://en.wikipedia.org/wiki/ISO_4217) format. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.currency_code import ISO4217 + + class Currency(BaseModel): + alpha_3: ISO4217 + + currency = Currency(alpha_3='AED') + print(currency) + # > alpha_3='AED' + ``` + """ + + allowed_countries_list = [country.alpha_3 for country in pycountry.currencies] + allowed_currencies = set(allowed_countries_list) + + @classmethod + def _validate(cls, currency_code: str, _: core_schema.ValidationInfo) -> str: + """ + Validate a ISO 4217 language code from the provided str value. + + Args: + currency_code: The str value to be validated. + _: The Pydantic ValidationInfo. + + Returns: + The validated ISO 4217 currency code. + + Raises: + PydanticCustomError: If the ISO 4217 currency code is not valid. + """ + if currency_code not in cls.allowed_currencies: + raise PydanticCustomError( + 'ISO4217', 'Invalid ISO 4217 currency code. See https://en.wikipedia.org/wiki/ISO_4217' + ) + return currency_code + + @classmethod + def __get_pydantic_core_schema__(cls, _: type[Any], __: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=3, max_length=3), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + json_schema = handler(schema) + json_schema.update({'enum': cls.allowed_countries_list}) + return json_schema + + +class Currency(str): + """Currency parses currency subset of the [ISO 4217](https://en.wikipedia.org/wiki/ISO_4217) format. + It excludes bonds testing codes and precious metals. + ```py + from pydantic import BaseModel + + from pydantic_extra_types.currency_code import Currency + + class currency(BaseModel): + alpha_3: Currency + + cur = currency(alpha_3='AED') + print(cur) + # > alpha_3='AED' + ``` + """ + + allowed_countries_list = list( + filter(lambda x: x not in _CODES_FOR_BONDS_METAL_TESTING, ISO4217.allowed_countries_list) + ) + allowed_currencies = set(allowed_countries_list) + + @classmethod + def _validate(cls, currency_symbol: str, _: core_schema.ValidationInfo) -> str: + """ + Validate a subset of the [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format. + It excludes bonds testing codes and precious metals. + + Args: + currency_symbol: The str value to be validated. + _: The Pydantic ValidationInfo. + + Returns: + The validated ISO 4217 currency code. + + Raises: + PydanticCustomError: If the ISO 4217 currency code is not valid or is bond, precious metal or testing code. + """ + if currency_symbol not in cls.allowed_currencies: + raise PydanticCustomError( + 'InvalidCurrency', + 'Invalid currency code.' + ' See https://en.wikipedia.org/wiki/ISO_4217. ' + 'Bonds, testing and precious metals codes are not allowed.', + ) + return currency_symbol + + @classmethod + def __get_pydantic_core_schema__(cls, _: type[Any], __: GetCoreSchemaHandler) -> core_schema.CoreSchema: + """ + Return a Pydantic CoreSchema with the currency subset of the + [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format. + It excludes bonds testing codes and precious metals. + + Args: + _: The source type. + __: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the subset of the currency subset of the + [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format. + It excludes bonds testing codes and precious metals. + """ + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=3, max_length=3), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + """ + Return a Pydantic JSON Schema with subset of the [ISO4217](https://en.wikipedia.org/wiki/ISO_4217) format. + Excluding bonds testing codes and precious metals. + + Args: + schema: The Pydantic CoreSchema. + handler: The handler to get the JSON Schema. + + Returns: + A Pydantic JSON Schema with the subset of the ISO4217 currency code validation. without bonds testing codes + and precious metals. + + """ + json_schema = handler(schema) + json_schema.update({'enum': cls.allowed_countries_list}) + return json_schema diff --git a/pydantic_extra_types/isbn.py b/pydantic_extra_types/isbn.py new file mode 100644 index 0000000..df573c6 --- /dev/null +++ b/pydantic_extra_types/isbn.py @@ -0,0 +1,152 @@ +""" +The `pydantic_extra_types.isbn` module provides functionality to recieve and validate ISBN. + +ISBN (International Standard Book Number) is a numeric commercial book identifier which is intended to be unique. This module provides a ISBN type for Pydantic models. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import GetCoreSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +def isbn10_digit_calc(isbn: str) -> str: + """Calc a ISBN-10 last digit from the provided str value. More information of validation algorithm on [Wikipedia](https://en.wikipedia.org/wiki/ISBN#Check_digits) + + Args: + isbn: The str value representing the ISBN in 10 digits. + + Returns: + The calculated last digit of the ISBN-10 value. + """ + total = sum(int(digit) * (10 - idx) for idx, digit in enumerate(isbn[:9])) + + for check_digit in range(1, 11): + if (total + check_digit) % 11 == 0: + valid_check_digit = 'X' if check_digit == 10 else str(check_digit) + + return valid_check_digit + + +def isbn13_digit_calc(isbn: str) -> str: + """Calc a ISBN-13 last digit from the provided str value. More information of validation algorithm on [Wikipedia](https://en.wikipedia.org/wiki/ISBN#Check_digits) + + Args: + isbn: The str value representing the ISBN in 13 digits. + + Returns: + The calculated last digit of the ISBN-13 value. + """ + total = sum(int(digit) * (1 if idx % 2 == 0 else 3) for idx, digit in enumerate(isbn[:12])) + + check_digit = (10 - (total % 10)) % 10 + + return str(check_digit) + + +class ISBN(str): + """Represents a ISBN and provides methods for conversion, validation, and serialization. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.isbn import ISBN + + + class Book(BaseModel): + isbn: ISBN + + book = Book(isbn="8537809667") + print(book) + #> isbn='9788537809662' + ``` + """ + + @classmethod + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + """ + Return a Pydantic CoreSchema with the ISBN validation. + + Args: + source: The source type to be converted. + handler: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the ISBN validation. + + """ + return core_schema.with_info_before_validator_function( + cls._validate, + core_schema.str_schema(), + ) + + @classmethod + def _validate(cls, __input_value: str, _: Any) -> str: + """ + Validate a ISBN from the provided str value. + + Args: + __input_value: The str value to be validated. + _: The source type to be converted. + + Returns: + The validated ISBN. + + Raises: + PydanticCustomError: If the ISBN is not valid. + """ + cls.validate_isbn_format(__input_value) + + return cls.convert_isbn10_to_isbn13(__input_value) + + @staticmethod + def validate_isbn_format(value: str) -> None: + """Validate a ISBN format from the provided str value. + + Args: + value: The str value representing the ISBN in 10 or 13 digits. + + Raises: + PydanticCustomError: If the ISBN is not valid. + """ + + isbn_length = len(value) + + if isbn_length not in (10, 13): + raise PydanticCustomError('isbn_length', f'Length for ISBN must be 10 or 13 digits, not {isbn_length}') + + if isbn_length == 10: + if not value[:-1].isdigit() or ((value[-1] != 'X') and (not value[-1].isdigit())): + raise PydanticCustomError('isbn10_invalid_characters', 'First 9 digits of ISBN-10 must be integers') + if isbn10_digit_calc(value) != value[-1]: + raise PydanticCustomError('isbn_invalid_digit_check_isbn10', 'Provided digit is invalid for given ISBN') + + if isbn_length == 13: + if not value.isdigit(): + raise PydanticCustomError('isbn13_invalid_characters', 'All digits of ISBN-13 must be integers') + if value[:3] not in ('978', '979'): + raise PydanticCustomError( + 'isbn_invalid_early_characters', 'The first 3 digits of ISBN-13 must be 978 or 979' + ) + if isbn13_digit_calc(value) != value[-1]: + raise PydanticCustomError('isbn_invalid_digit_check_isbn13', 'Provided digit is invalid for given ISBN') + + @staticmethod + def convert_isbn10_to_isbn13(value: str) -> str: + """Convert an ISBN-10 to ISBN-13. + + Args: + value: The ISBN-10 value to be converted. + + Returns: + The converted ISBN or the original value if no conversion is necessary. + """ + + if len(value) == 10: + base_isbn = f'978{value[:-1]}' + isbn13_digit = isbn13_digit_calc(base_isbn) + return ISBN(f'{base_isbn}{isbn13_digit}') + + return ISBN(value) diff --git a/pydantic_extra_types/language_code.py b/pydantic_extra_types/language_code.py new file mode 100644 index 0000000..117e877 --- /dev/null +++ b/pydantic_extra_types/language_code.py @@ -0,0 +1,182 @@ +""" +Language definitions that are based on the [ISO 639-3](https://en.wikipedia.org/wiki/ISO_639-3) & [ISO 639-5](https://en.wikipedia.org/wiki/ISO_639-5). +""" +from __future__ import annotations + +from typing import Any + +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + +try: + import pycountry +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + 'The `language_code` module requires "pycountry" to be installed.' + ' You can install it with "pip install pycountry".' + ) + + +class ISO639_3(str): + """ISO639_3 parses Language in the [ISO 639-3 alpha-3](https://en.wikipedia.org/wiki/ISO_639-3_alpha-3) + format. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.language_code import ISO639_3 + + class Language(BaseModel): + alpha_3: ISO639_3 + + lang = Language(alpha_3='ssr') + print(lang) + # > alpha_3='ssr' + ``` + """ + + allowed_values_list = [lang.alpha_3 for lang in pycountry.languages] + allowed_values = set(allowed_values_list) + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> ISO639_3: + """ + Validate a ISO 639-3 language code from the provided str value. + + Args: + __input_value: The str value to be validated. + _: The Pydantic ValidationInfo. + + Returns: + The validated ISO 639-3 language code. + + Raises: + PydanticCustomError: If the ISO 639-3 language code is not valid. + """ + if __input_value not in cls.allowed_values: + raise PydanticCustomError( + 'ISO649_3', 'Invalid ISO 639-3 language code. See https://en.wikipedia.org/wiki/ISO_639-3' + ) + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, _: type[Any], __: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + """ + Return a Pydantic CoreSchema with the ISO 639-3 language code validation. + + Args: + _: The source type. + __: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the ISO 639-3 language code validation. + + """ + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=3, max_length=3), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + """ + Return a Pydantic JSON Schema with the ISO 639-3 language code validation. + + Args: + schema: The Pydantic CoreSchema. + handler: The handler to get the JSON Schema. + + Returns: + A Pydantic JSON Schema with the ISO 639-3 language code validation. + + """ + json_schema = handler(schema) + json_schema.update({'enum': cls.allowed_values_list}) + return json_schema + + +class ISO639_5(str): + """ISO639_5 parses Language in the [ISO 639-5 alpha-3](https://en.wikipedia.org/wiki/ISO_639-5_alpha-3) + format. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.language_code import ISO639_5 + + class Language(BaseModel): + alpha_3: ISO639_5 + + lang = Language(alpha_3='gem') + print(lang) + # > alpha_3='gem' + ``` + """ + + allowed_values_list = [lang.alpha_3 for lang in pycountry.language_families] + allowed_values_list.sort() + allowed_values = set(allowed_values_list) + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> ISO639_5: + """ + Validate a ISO 639-5 language code from the provided str value. + + Args: + __input_value: The str value to be validated. + _: The Pydantic ValidationInfo. + + Returns: + The validated ISO 639-3 language code. + + Raises: + PydanticCustomError: If the ISO 639-5 language code is not valid. + """ + if __input_value not in cls.allowed_values: + raise PydanticCustomError( + 'ISO649_5', 'Invalid ISO 639-5 language code. See https://en.wikipedia.org/wiki/ISO_639-5' + ) + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, _: type[Any], __: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + """ + Return a Pydantic CoreSchema with the ISO 639-5 language code validation. + + Args: + _: The source type. + __: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the ISO 639-5 language code validation. + + """ + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=3, max_length=3), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + """ + Return a Pydantic JSON Schema with the ISO 639-5 language code validation. + + Args: + schema: The Pydantic CoreSchema. + handler: The handler to get the JSON Schema. + + Returns: + A Pydantic JSON Schema with the ISO 639-5 language code validation. + + """ + json_schema = handler(schema) + json_schema.update({'enum': cls.allowed_values_list}) + return json_schema diff --git a/pydantic_extra_types/mac_address.py b/pydantic_extra_types/mac_address.py new file mode 100644 index 0000000..9be1557 --- /dev/null +++ b/pydantic_extra_types/mac_address.py @@ -0,0 +1,125 @@ +""" +The MAC address module provides functionality to parse and validate MAC addresses in different +formats, such as IEEE 802 MAC-48, EUI-48, EUI-64, or a 20-octet format. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import GetCoreSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +class MacAddress(str): + """Represents a MAC address and provides methods for conversion, validation, and serialization. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.mac_address import MacAddress + + + class Network(BaseModel): + mac_address: MacAddress + + + network = Network(mac_address="00:00:5e:00:53:01") + print(network) + #> mac_address='00:00:5e:00:53:01' + ``` + """ + + @classmethod + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + """ + Return a Pydantic CoreSchema with the MAC address validation. + + Args: + source: The source type to be converted. + handler: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the MAC address validation. + + """ + return core_schema.with_info_before_validator_function( + cls._validate, + core_schema.str_schema(), + ) + + @classmethod + def _validate(cls, __input_value: str, _: Any) -> str: + """ + Validate a MAC Address from the provided str value. + + Args: + __input_value: The str value to be validated. + _: The source type to be converted. + + Returns: + str: The parsed MAC address. + + """ + return cls.validate_mac_address(__input_value.encode()) + + @staticmethod + def validate_mac_address(value: bytes) -> str: + """ + Validate a MAC Address from the provided byte value. + """ + if len(value) < 14: + raise PydanticCustomError( + 'mac_address_len', + 'Length for a {mac_address} MAC address must be {required_length}', + {'mac_address': value.decode(), 'required_length': 14}, + ) + + if value[2] in [ord(':'), ord('-')]: + if (len(value) + 1) % 3 != 0: + raise PydanticCustomError( + 'mac_address_format', 'Must have the format xx:xx:xx:xx:xx:xx or xx-xx-xx-xx-xx-xx' + ) + n = (len(value) + 1) // 3 + if n not in (6, 8, 20): + raise PydanticCustomError( + 'mac_address_format', + 'Length for a {mac_address} MAC address must be {required_length}', + {'mac_address': value.decode(), 'required_length': (6, 8, 20)}, + ) + mac_address = bytearray(n) + x = 0 + for i in range(n): + try: + byte_value = int(value[x : x + 2], 16) + mac_address[i] = byte_value + x += 3 + except ValueError as e: + raise PydanticCustomError('mac_address_format', 'Unrecognized format') from e + + elif value[4] == ord('.'): + if (len(value) + 1) % 5 != 0: + raise PydanticCustomError('mac_address_format', 'Must have the format xx.xx.xx.xx.xx.xx') + n = 2 * (len(value) + 1) // 5 + if n not in (6, 8, 20): + raise PydanticCustomError( + 'mac_address_format', + 'Length for a {mac_address} MAC address must be {required_length}', + {'mac_address': value.decode(), 'required_length': (6, 8, 20)}, + ) + mac_address = bytearray(n) + x = 0 + for i in range(0, n, 2): + try: + byte_value = int(value[x : x + 2], 16) + mac_address[i] = byte_value + byte_value = int(value[x + 2 : x + 4], 16) + mac_address[i + 1] = byte_value + x += 5 + except ValueError as e: + raise PydanticCustomError('mac_address_format', 'Unrecognized format') from e + + else: + raise PydanticCustomError('mac_address_format', 'Unrecognized format') + + return ':'.join(f'{b:02x}' for b in mac_address) diff --git a/pydantic_extra_types/payment.py b/pydantic_extra_types/payment.py new file mode 100644 index 0000000..e3c040b --- /dev/null +++ b/pydantic_extra_types/payment.py @@ -0,0 +1,199 @@ +""" +The `pydantic_extra_types.payment` module provides the +[`PaymentCardNumber`][pydantic_extra_types.payment.PaymentCardNumber] data type. +""" + +from __future__ import annotations + +from enum import Enum +from typing import Any, ClassVar + +from pydantic import GetCoreSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +class PaymentCardBrand(str, Enum): + """Payment card brands supported by the [`PaymentCardNumber`][pydantic_extra_types.payment.PaymentCardNumber].""" + + amex = 'American Express' + mastercard = 'Mastercard' + visa = 'Visa' + mir = 'Mir' + maestro = 'Maestro' + discover = 'Discover' + verve = 'Verve' + dankort = 'Dankort' + troy = 'Troy' + unionpay = 'UnionPay' + jcb = 'JCB' + other = 'other' + + def __str__(self) -> str: + return self.value + + +class PaymentCardNumber(str): + """A [payment card number](https://en.wikipedia.org/wiki/Payment_card_number).""" + + strip_whitespace: ClassVar[bool] = True + """Whether to strip whitespace from the input value.""" + min_length: ClassVar[int] = 12 + """The minimum length of the card number.""" + max_length: ClassVar[int] = 19 + """The maximum length of the card number.""" + bin: str + """The first 6 digits of the card number.""" + last4: str + """The last 4 digits of the card number.""" + brand: PaymentCardBrand + """The brand of the card.""" + + def __init__(self, card_number: str): + self.validate_digits(card_number) + + card_number = self.validate_luhn_check_digit(card_number) + + self.bin = card_number[:6] + self.last4 = card_number[-4:] + self.brand = self.validate_brand(card_number) + + @classmethod + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + cls.validate, + core_schema.str_schema( + min_length=cls.min_length, max_length=cls.max_length, strip_whitespace=cls.strip_whitespace + ), + ) + + @classmethod + def validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> PaymentCardNumber: + """Validate the `PaymentCardNumber` instance. + + Args: + __input_value: The input value to validate. + _: The validation info. + + Returns: + The validated `PaymentCardNumber` instance. + """ + return cls(__input_value) + + @property + def masked(self) -> str: + """The masked card number.""" + num_masked = len(self) - 10 # len(bin) + len(last4) == 10 + return f'{self.bin}{"*" * num_masked}{self.last4}' + + @classmethod + def validate_digits(cls, card_number: str) -> None: + """Validate that the card number is all digits. + + Args: + card_number: The card number to validate. + + Raises: + PydanticCustomError: If the card number is not all digits. + """ + if not card_number or not all('0' <= c <= '9' for c in card_number): + raise PydanticCustomError('payment_card_number_digits', 'Card number is not all digits') + + @classmethod + def validate_luhn_check_digit(cls, card_number: str) -> str: + """Validate the payment card number. + Based on the [Luhn algorithm](https://en.wikipedia.org/wiki/Luhn_algorithm). + + Args: + card_number: The card number to validate. + + Returns: + The validated card number. + + Raises: + PydanticCustomError: If the card number is not valid. + """ + sum_ = int(card_number[-1]) + length = len(card_number) + parity = length % 2 + for i in range(length - 1): + digit = int(card_number[i]) + if i % 2 == parity: + digit *= 2 + if digit > 9: + digit -= 9 + sum_ += digit + valid = sum_ % 10 == 0 + if not valid: + raise PydanticCustomError('payment_card_number_luhn', 'Card number is not luhn valid') + return card_number + + @staticmethod + def validate_brand(card_number: str) -> PaymentCardBrand: + """Validate length based on + [BIN](https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN)) + for major brands. + + Args: + card_number: The card number to validate. + + Returns: + The validated card brand. + + Raises: + PydanticCustomError: If the card number is not valid. + """ + brand = PaymentCardBrand.other + + if card_number[0] == '4': + brand = PaymentCardBrand.visa + required_length = [13, 16, 19] + elif 51 <= int(card_number[:2]) <= 55: + brand = PaymentCardBrand.mastercard + required_length = [16] + elif card_number[:2] in {'34', '37'}: + brand = PaymentCardBrand.amex + required_length = [15] + elif 2200 <= int(card_number[:4]) <= 2204: + brand = PaymentCardBrand.mir + required_length = list(range(16, 20)) + elif card_number[:4] in {'5018', '5020', '5038', '5893', '6304', '6759', '6761', '6762', '6763'} or card_number[ + :6 + ] in ( + '676770', + '676774', + ): + brand = PaymentCardBrand.maestro + required_length = list(range(12, 20)) + elif card_number.startswith('65') or 644 <= int(card_number[:3]) <= 649 or card_number.startswith('6011'): + brand = PaymentCardBrand.discover + required_length = list(range(16, 20)) + elif ( + 506099 <= int(card_number[:6]) <= 506198 + or 650002 <= int(card_number[:6]) <= 650027 + or 507865 <= int(card_number[:6]) <= 507964 + ): + brand = PaymentCardBrand.verve + required_length = [16, 18, 19] + elif card_number[:4] in {'5019', '4571'}: + brand = PaymentCardBrand.dankort + required_length = [16] + elif card_number.startswith('9792'): + brand = PaymentCardBrand.troy + required_length = [16] + elif card_number[:2] in {'62', '81'}: + brand = PaymentCardBrand.unionpay + required_length = [16, 19] + elif 3528 <= int(card_number[:4]) <= 3589: + brand = PaymentCardBrand.jcb + required_length = [16, 19] + + valid = len(card_number) in required_length if brand != PaymentCardBrand.other else True + + if not valid: + raise PydanticCustomError( + 'payment_card_number_brand', + f'Length for a {brand} card must be {" or ".join(map(str, required_length))}', + {'brand': brand, 'required_length': required_length}, + ) + + return brand diff --git a/pydantic_extra_types/pendulum_dt.py b/pydantic_extra_types/pendulum_dt.py new file mode 100644 index 0000000..f507779 --- /dev/null +++ b/pydantic_extra_types/pendulum_dt.py @@ -0,0 +1,74 @@ +""" +Native Pendulum DateTime object implementation. This is a copy of the Pendulum DateTime object, but with a Pydantic +CoreSchema implementation. This allows Pydantic to validate the DateTime object. +""" + +try: + from pendulum import DateTime as _DateTime + from pendulum import parse +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + 'The `pendulum_dt` module requires "pendulum" to be installed. You can install it with "pip install pendulum".' + ) +from typing import Any, List, Type + +from pydantic import GetCoreSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +class DateTime(_DateTime): + """ + A `pendulum.DateTime` object. At runtime, this type decomposes into pendulum.DateTime automatically. + This type exists because Pydantic throws a fit on unknown types. + + ```python + from pydantic import BaseModel + from pydantic_extra_types.pendulum_dt import DateTime + + class test_model(BaseModel): + dt: DateTime + + print(test_model(dt='2021-01-01T00:00:00+00:00')) + + #> test_model(dt=DateTime(2021, 1, 1, 0, 0, 0, tzinfo=FixedTimezone(0, name="+00:00"))) + ``` + """ + + __slots__: List[str] = [] + + @classmethod + def __get_pydantic_core_schema__(cls, source: Type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + """ + Return a Pydantic CoreSchema with the Datetime validation + + Args: + source: The source type to be converted. + handler: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the Datetime validation. + """ + return core_schema.no_info_wrap_validator_function(cls._validate, core_schema.datetime_schema()) + + @classmethod + def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: + """ + Validate the datetime object and return it. + + Args: + value: The value to validate. + handler: The handler to get the CoreSchema. + + Returns: + The validated value or raises a PydanticCustomError. + """ + # if we are passed an existing instance, pass it straight through. + if isinstance(value, _DateTime): + return handler(value) + + # otherwise, parse it. + try: + data = parse(value) + except Exception as exc: + raise PydanticCustomError('value_error', 'value is not a valid timestamp') from exc + return handler(data) diff --git a/pydantic_extra_types/phone_numbers.py b/pydantic_extra_types/phone_numbers.py new file mode 100644 index 0000000..7acaa89 --- /dev/null +++ b/pydantic_extra_types/phone_numbers.py @@ -0,0 +1,68 @@ +""" +The `pydantic_extra_types.phone_numbers` module provides the +[`PhoneNumber`][pydantic_extra_types.phone_numbers.PhoneNumber] data type. + +This class depends on the [phonenumbers] package, which is a Python port of Google's [libphonenumber]. +""" +from __future__ import annotations + +from typing import Any, Callable, ClassVar, Generator + +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + +try: + import phonenumbers +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + '`PhoneNumber` requires "phonenumbers" to be installed. You can install it with "pip install phonenumbers"' + ) + +GeneratorCallableStr = Generator[Callable[..., str], None, None] + + +class PhoneNumber(str): + """ + A wrapper around [phonenumbers](https://pypi.org/project/phonenumbers/) package, which + is a Python port of Google's [libphonenumber](https://github.com/google/libphonenumber/). + """ + + supported_regions: list[str] = sorted(phonenumbers.SUPPORTED_REGIONS) + """The supported regions.""" + supported_formats: list[str] = sorted([f for f in phonenumbers.PhoneNumberFormat.__dict__.keys() if f.isupper()]) + """The supported phone number formats.""" + + default_region_code: ClassVar[str | None] = None + """The default region code to use when parsing phone numbers without an international prefix.""" + phone_format: str = 'RFC3966' + """The format of the phone number.""" + min_length: int = 7 + """The minimum length of the phone number.""" + max_length: int = 64 + """The maximum length of the phone number.""" + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + json_schema = handler(schema) + json_schema.update({'format': 'phone'}) + return json_schema + + @classmethod + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=cls.min_length, max_length=cls.max_length), + ) + + @classmethod + def _validate(cls, phone_number: str, _: core_schema.ValidationInfo) -> str: + try: + parsed_number = phonenumbers.parse(phone_number, cls.default_region_code) + except phonenumbers.phonenumberutil.NumberParseException as exc: + raise PydanticCustomError('value_error', 'value is not a valid phone number') from exc + if not phonenumbers.is_valid_number(parsed_number): + raise PydanticCustomError('value_error', 'value is not a valid phone number') + + return phonenumbers.format_number(parsed_number, getattr(phonenumbers.PhoneNumberFormat, cls.phone_format)) diff --git a/pydantic_extra_types/py.typed b/pydantic_extra_types/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/pydantic_extra_types/py.typed diff --git a/pydantic_extra_types/routing_number.py b/pydantic_extra_types/routing_number.py new file mode 100644 index 0000000..22ea6e8 --- /dev/null +++ b/pydantic_extra_types/routing_number.py @@ -0,0 +1,89 @@ +""" +The `pydantic_extra_types.routing_number` module provides the +[`ABARoutingNumber`][pydantic_extra_types.routing_number.ABARoutingNumber] data type. +""" +from typing import Any, ClassVar, Type + +from pydantic import GetCoreSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +class ABARoutingNumber(str): + """The `ABARoutingNumber` data type is a string of 9 digits representing an ABA routing transit number. + + The algorithm used to validate the routing number is described in the + [ABA routing transit number](https://en.wikipedia.org/wiki/ABA_routing_transit_number#Check_digit) + Wikipedia article. + + ```py + from pydantic import BaseModel + + from pydantic_extra_types.routing_number import ABARoutingNumber + + class BankAccount(BaseModel): + routing_number: ABARoutingNumber + + account = BankAccount(routing_number='122105155') + print(account) + #> routing_number='122105155' + ``` + """ + + strip_whitespace: ClassVar[bool] = True + min_length: ClassVar[int] = 9 + max_length: ClassVar[int] = 9 + + def __init__(self, routing_number: str): + self._validate_digits(routing_number) + self._routing_number = self._validate_routing_number(routing_number) + + @classmethod + def __get_pydantic_core_schema__( + cls, source: Type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema( + min_length=cls.min_length, + max_length=cls.max_length, + strip_whitespace=cls.strip_whitespace, + strict=False, + ), + ) + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> 'ABARoutingNumber': + return cls(__input_value) + + @classmethod + def _validate_digits(cls, routing_number: str) -> None: + """Check that the routing number is all digits. + + Args: + routing_number: The routing number to validate. + + Raises: + PydanticCustomError: If the routing number is not all digits. + """ + if not routing_number.isdigit(): + raise PydanticCustomError('aba_routing_number', 'routing number is not all digits') + + @classmethod + def _validate_routing_number(cls, routing_number: str) -> str: + """Check [digit algorithm](https://en.wikipedia.org/wiki/ABA_routing_transit_number#Check_digit) for + [ABA routing transit number](https://www.routingnumber.com/). + + Args: + routing_number: The routing number to validate. + + Raises: + PydanticCustomError: If the routing number is incorrect. + """ + checksum = ( + 3 * (sum(map(int, [routing_number[0], routing_number[3], routing_number[6]]))) + + 7 * (sum(map(int, [routing_number[1], routing_number[4], routing_number[7]]))) + + sum(map(int, [routing_number[2], routing_number[5], routing_number[8]])) + ) + if checksum % 10 != 0: + raise PydanticCustomError('aba_routing_number', 'Incorrect ABA routing transit number') + return routing_number diff --git a/pydantic_extra_types/ulid.py b/pydantic_extra_types/ulid.py new file mode 100644 index 0000000..d2bf650 --- /dev/null +++ b/pydantic_extra_types/ulid.py @@ -0,0 +1,62 @@ +""" +The `pydantic_extra_types.ULID` module provides the [`ULID`] data type. + +This class depends on the [python-ulid] package, which is a validate by the [ULID-spec](https://github.com/ulid/spec#implementations-in-other-languages). +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Union + +from pydantic import GetCoreSchemaHandler +from pydantic._internal import _repr +from pydantic_core import PydanticCustomError, core_schema + +try: + from ulid import ULID as _ULID +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + 'The `ulid` module requires "python-ulid" to be installed. You can install it with "pip install python-ulid".' + ) + +UlidType = Union[str, bytes, int] + + +@dataclass +class ULID(_repr.Representation): + """ + A wrapper around [python-ulid](https://pypi.org/project/python-ulid/) package, which + is a validate by the [ULID-spec](https://github.com/ulid/spec#implementations-in-other-languages). + """ + + ulid: _ULID + + @classmethod + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.no_info_wrap_validator_function( + cls._validate_ulid, + core_schema.union_schema( + [ + core_schema.is_instance_schema(_ULID), + core_schema.int_schema(), + core_schema.bytes_schema(), + core_schema.str_schema(), + ] + ), + ) + + @classmethod + def _validate_ulid(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: + ulid: _ULID + try: + if isinstance(value, int): + ulid = _ULID.from_int(value) + elif isinstance(value, str): + ulid = _ULID.from_str(value) + elif isinstance(value, _ULID): + ulid = value + else: + ulid = _ULID.from_bytes(value) + except ValueError: + raise PydanticCustomError('ulid_format', 'Unrecognized format') + return handler(ulid) |