summaryrefslogtreecommitdiffstats
path: root/myst_parser/config/dc_validators.py
blob: 765cfb93f515c0f75774ea1975e22398e4f3dc61 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Validators for dataclasses, mirroring those of https://github.com/python-attrs/attrs."""
from __future__ import annotations

import dataclasses as dc
from typing import Any, Sequence

from typing_extensions import Protocol


def validate_field(inst: Any, field: dc.Field, value: Any) -> None:
    """Run the validator(s) registered in ``field.metadata["validator"]``.

    The metadata entry may be a single callable or a list of callables.
    Each one is invoked as ``validator(inst, field, value)`` and is expected
    to raise an exception if *value* is invalid. Fields whose metadata has
    no ``"validator"`` entry are silently skipped.
    """
    metadata = field.metadata
    if "validator" in metadata:
        spec = metadata["validator"]
        # Normalise to a sequence so a single validator and a list of
        # validators share one code path.
        chain = spec if isinstance(spec, list) else [spec]
        for check in chain:
            check(inst, field, value)


def validate_fields(inst: Any) -> None:
    """Run every field-level validator declared on the dataclass *inst*.

    Meant to be called from the dataclass's ``__post_init__``. For each
    field returned by ``dataclasses.fields(inst)``, the current attribute
    value is handed to `validate_field`. That function looks up
    ``field.metadata["validator"]``: a callable, or a list of callables,
    taking ``(inst, field, value)`` and raising if the value is invalid.
    The first failure propagates to the caller.
    """
    # Lazily pair each field with its current value; each pair is consumed
    # (and validated) before the next attribute is read.
    checks = ((fld, getattr(inst, fld.name)) for fld in dc.fields(inst))
    for fld, current in checks:
        validate_field(inst, fld, current)


class ValidatorType(Protocol):
    """Structural type of the validator callables built in this module.

    A validator receives the instance being validated, the dataclass field,
    and the value to check. It also takes an optional ``suffix``, which is
    appended to the field name in error messages; the ``deep_*`` validators
    use it to point at a nested element. A validator returns ``None`` and
    raises an exception if the value is invalid.
    """

    def __call__(
        self, inst: Any, field: dc.Field, value: Any, suffix: str = ""
    ) -> None:
        # ``inst`` is the dataclass instance being validated (see
        # ``validate_fields``), so it is typed ``Any`` rather than a concrete type.
        ...


def instance_of(type: type[Any] | tuple[type[Any], ...]) -> ValidatorType:
    """
    A validator that raises a `TypeError` if the initializer is called
    with a wrong type for this particular attribute (checks are performed using
    `isinstance` therefore it's also valid to pass a tuple of types).

    :param type: The type (or tuple of types) to check for.
    :return: A validator callable with the ``(inst, field, value, suffix="")``
        signature.
    """

    def _validator(inst, field, value, suffix=""):
        """Raise `TypeError` unless ``value`` is an instance of ``type``.

        ``suffix`` is appended to the field name in the error message, so
        that nested (``deep_*``) validators can point at the offending element.
        """
        if not isinstance(value, type):
            # The parameter ``type`` shadows the builtin in this scope, so the
            # value's class is read via ``__class__`` instead of ``type(value)``.
            raise TypeError(
                f"'{field.name}{suffix}' must be of type {type!r} "
                f"(got {value!r} that is a {value.__class__!r})."
            )

    return _validator


def optional(validator: ValidatorType) -> ValidatorType:
    """
    Wrap *validator* so that ``None`` is accepted as well.

    ``None`` short-circuits without calling the wrapped validator. Any other
    value is forwarded to it unchanged, together with the current ``suffix``.
    """

    def _validator(inst, field, value, suffix=""):
        if value is not None:
            validator(inst, field, value, suffix=suffix)

    return _validator


def is_callable(inst, field, value, suffix=""):
    """
    Validator that accepts only callable values.

    Raises `TypeError` (naming ``field.name`` plus ``suffix``) when
    ``callable(value)`` is false; otherwise returns ``None``.
    """
    if callable(value):
        return
    raise TypeError(
        f"'{field.name}{suffix}' must be callable "
        f"(got {value!r} that is a {value.__class__!r})."
    )


def in_(options: Sequence) -> ValidatorType:
    """
    Build a validator that raises `ValueError` if the value is not one of
    *options*. Membership is tested with ``value in options``.

    A `TypeError` from the membership test itself (e.g. ``1 in "abc"``) is
    treated as "not a member", so it also surfaces as a `ValueError`.

    :param options: Allowed options.
    """

    def _is_member(candidate) -> bool:
        try:
            return candidate in options
        except TypeError:  # e.g. `1 in "abc"`
            return False

    def _validator(inst, field, value, suffix=""):
        if not _is_member(value):
            raise ValueError(
                f"'{field.name}{suffix}' must be in {options!r} (got {value!r})"
            )

    return _validator


def deep_iterable(
    member_validator: ValidatorType, iterable_validator: ValidatorType | None = None
) -> ValidatorType:
    """
    Build a validator that checks an iterable and each of its members.

    The container itself is passed to *iterable_validator* (if given) with
    the current suffix. Each member is then passed to *member_validator*
    with ``[<index>]`` appended to the suffix.

    :param member_validator: Validator to apply to iterable members
    :param iterable_validator: Validator to apply to iterable itself
    """

    def _validator(inst, field, value, suffix=""):
        # The container-level check runs before iteration starts.
        if iterable_validator is not None:
            iterable_validator(inst, field, value, suffix=suffix)

        for position, item in enumerate(value):
            item_suffix = f"{suffix}[{position}]"
            member_validator(inst, field, item, suffix=item_suffix)

    return _validator


def deep_mapping(
    key_validator: ValidatorType,
    value_validator: ValidatorType,
    mapping_validator: ValidatorType | None = None,
) -> ValidatorType:
    """
    A validator that performs deep validation of a dictionary.

    The mapping itself is passed to *mapping_validator* (if given) with the
    current suffix. Each key and its value are then validated with
    ``[<repr(key)>]`` appended to the suffix, so error messages point at the
    offending entry.

    :param key_validator: Validator to apply to dictionary keys
    :param value_validator: Validator to apply to dictionary values
    :param mapping_validator: Validator to apply to top-level mapping attribute (optional)
    """

    def _validator(inst, field: dc.Field, value, suffix=""):
        if mapping_validator is not None:
            # Forward the suffix (as deep_iterable does) so that a mapping
            # nested inside another deep validator reports its full path.
            mapping_validator(inst, field, value, suffix=suffix)

        for key in value:
            entry_suffix = f"{suffix}[{key!r}]"
            key_validator(inst, field, key, suffix=entry_suffix)
            value_validator(inst, field, value[key], suffix=entry_suffix)

    return _validator