summaryrefslogtreecommitdiffstats
path: root/debian/lib/python/debian_linux/dataclasses_extra.py
blob: 49f24100386373b0f4b5c2d1597df2e291a63fa2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from __future__ import annotations

from dataclasses import (
    fields,
    is_dataclass,
    replace,
)
from typing import (
    Protocol,
    TypeVar,
    TYPE_CHECKING,
)

# Everything in this block exists only for static type checkers; none of
# these names is defined when the module is actually run.  Annotations
# elsewhere in this file can still mention them because
# "from __future__ import annotations" keeps annotations unevaluated.
if TYPE_CHECKING:
    from _typeshed import DataclassInstance as _DataclassInstance

    # Structural type: a dataclass instance that has a "name" attribute.
    class _HasName(Protocol, _DataclassInstance):
        name: str

    # Type variable bound to dataclass instances.
    _DataclassT = TypeVar('_DataclassT', bound=_DataclassInstance)
    # Type variable bound to dataclass instances with a "name" attribute.
    _HasNameT = TypeVar('_HasNameT', bound=_HasName)


def default(
    cls: type[_DataclassT],
    /,
) -> _DataclassT:
    '''
    Create an instance of "cls" from the defaults stored in field metadata

    Each field whose metadata contains a "default" key is passed to the
    constructor with that value.  All other fields are left to the
    constructor, so they fall back to their regular dataclass defaults.
    '''
    preset = {
        spec.name: spec.metadata['default']
        for spec in fields(cls)
        if 'default' in spec.metadata
    }
    return cls(**preset)


def merge(
    self: _DataclassT,
    other: _DataclassT | None, /,
) -> _DataclassT:
    '''
    Return "self" with the fields of "other" merged in

    If "other" is None, "self" itself is returned.  Otherwise the result
    is a new object built with dataclasses.replace().

    Fields with init=False are skipped.  Every other field is handled by
    the first rule that matches:

    * field is called "name": both values are asserted to be equal and
      the value of "self" is kept
    * annotation is the string 'bool': value of "other" wins
    * metadata "merge" is 'assoclist': lists are merged item by item,
      pairing items by their "name" attribute
    * default_factory is a dataclass: merged recursively
    * default_factory is a list subclass: concatenated, "self" first
    * default_factory is a dict subclass: united, "other" wins on
      conflicting keys
    * default is None: value of "other" wins unless it is None
    * annotation is the class bool: value of "other" wins

    Raises RuntimeError if no rule matches a field.
    '''
    if other is None:
        return self

    f = {}

    for field in fields(self):
        if not field.init:
            # replace() refuses values for fields that are not init
            # parameters.
            continue

        # Only a default_factory that is itself a class tells us which
        # container type the field holds; everything else is treated as
        # a plain object and falls through to the later rules.
        field_default_type = object
        if isinstance(field.default_factory, type):
            field_default_type = field.default_factory

        self_field = getattr(self, field.name)
        other_field = getattr(other, field.name)

        if field.name == 'name':
            assert self_field == other_field
        elif field.type == 'bool':
            # Annotations are strings when the defining module uses
            # "from __future__ import annotations".
            f[field.name] = other_field
        elif field.metadata.get('merge') == 'assoclist':
            f[field.name] = _merge_assoclist(self_field, other_field)
        elif is_dataclass(field_default_type):
            f[field.name] = merge(self_field, other_field)
        elif issubclass(field_default_type, list):
            f[field.name] = self_field + other_field
        elif issubclass(field_default_type, dict):
            f[field.name] = self_field | other_field
        elif field.default is None:
            # Optional value: keep ours unless "other" provides one.
            if other_field is not None:
                f[field.name] = other_field
        elif field.type is bool:
            # Same rule as the 'bool' branch, for dataclasses whose
            # annotations are real classes instead of strings.  Checked
            # last so it never pre-empts one of the rules above.
            f[field.name] = other_field
        else:
            raise RuntimeError(f'Unable to merge for type {field.type}')

    return replace(self, **f)


def merge_default(
    cls: type[_DataclassT],
    /,
    *others: _DataclassT,
) -> _DataclassT:
    '''
    Merge "others", in order, on top of the default instance of "cls"

    Starts from default(cls) and folds each further argument into the
    running result with merge().
    '''
    merged = default(cls)
    for overlay in others:
        merged = merge(merged, overlay)
    return merged


def _merge_assoclist(
    self_list: list[_HasNameT],
    other_list: list[_HasNameT],
    /,
) -> list[_HasNameT]:
    '''
    Merge two lists whose items each carry a "name" attribute

    If either list is empty, the other list object is returned as is.
    Otherwise the result follows the order of "self_list": an item whose
    name also occurs in "other_list" is merged with that item, all others
    are taken over unchanged.  Items of "other_list" that found no
    partner are appended at the end.
    '''
    if not self_list:
        return other_list
    if not other_list:
        return self_list

    # Items of "other_list" still waiting for a partner, keyed by name.
    pending = {entry.name: entry for entry in other_list}

    merged: list[_HasNameT] = []
    for entry in self_list:
        try:
            partner = pending.pop(entry.name)
        except KeyError:
            merged.append(entry)
        else:
            merged.append(merge(entry, partner))

    # Whatever is left was only present in "other_list".
    merged += pending.values()
    return merged