summaryrefslogtreecommitdiffstats
path: root/deluge/decorators.py
blob: 92e3ecf59bb9d2e70f9d4297943754f111cdac41 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#
# Copyright (C) 2010 John Garland <johnnybg+deluge@gmail.com>
#
# This file is part of Deluge and is licensed under GNU General Public License 3.0, or later, with
# the additional special exception to link portions of this program with the OpenSSL library.
# See LICENSE for more details.
#

import inspect
import re
import warnings
from functools import wraps
from typing import Any, Callable, Coroutine, TypeVar

from twisted.internet import defer


def proxy(proxy_func):
    """
    Build a decorator that routes every call of the decorated function
    through ``proxy_func``.

    The proxy is invoked as ``proxy_func(func, *args, **kwargs)`` and
    whatever it returns is handed back to the caller. The metadata of the
    decorated function (name, docstring, ...) is preserved on the result.

    :param proxy_func: the proxy function
    :type proxy_func: function
    :returns: a decorator taking the function to be proxied
    """

    def decorator(func):
        def forward(*call_args, **call_kwargs):
            # Hand the original callable and its arguments to the proxy
            return proxy_func(func, *call_args, **call_kwargs)

        return wraps(func)(forward)

    return decorator


def overrides(*args):
    """
    Decorator marking a class method as an override of a super class
    method.

    Bare usage::

        @overrides
        def funcname

    Here the single argument is the ``funcname`` function itself and the
    check is run straight away.

    Usage with explicit base classes::

        @overrides(BaseClass)
        def funcname

    Here the arguments are the base classes; a function is returned which
    is then called with the real function and runs the check.

    """
    # The stack has to be captured in this frame: the check indexes it
    # relative to this call.
    call_stack = inspect.stack()
    first = args[0]

    if not inspect.isfunction(first):
        # One or more classes were given; defer the check until the
        # decorated function is supplied.
        def ret_func(func, **kwargs):
            return _overrides(call_stack, func, explicit_base_classes=args)

        return ret_func

    return _overrides(call_stack, first)


def _overrides(stack, method, explicit_base_classes=None):
    """Check that ``method`` overrides an attribute of a super class.

    The class statement is read from the source line recorded in ``stack``
    and the base class names found there are resolved in the namespaces of
    the frame surrounding the class definition.

    :param stack: the stack captured inside ``overrides``:
        ``stack[1]`` is the class body, ``stack[2]`` the code containing the
        class statement.
    :param method: the function being decorated
    :param explicit_base_classes: the classes passed to ``@overrides(...)``,
        if any. When given, every one of them must define ``method`` and at
        least one must be a super class of the class being defined.
    :returns: ``method``, unchanged
    :raises ValueError: if no base class can be determined from the class
        statement (e.g. ``class Foo:``)
    :raises KeyError: if a base class name cannot be resolved
    :raises Exception: if the override is not found in the checked classes
    """
    # stack[0]=overrides, stack[1]=inside class def'n, stack[2]=outside class def'n
    outer_frame = stack[2][0]
    class_statement = stack[2][4][0]
    location = f'File: {stack[1][1]}:{stack[1][2]}'

    def lookup(name):
        # Resolve the name like the class statement would: locals first,
        # then module globals, then builtins.
        for namespace in (
            outer_frame.f_locals,
            outer_frame.f_globals,
            outer_frame.f_builtins,
        ):
            if name in namespace:
                return namespace[name]
        raise KeyError(name)

    def get_class(cls_name):
        first, *rest = cls_name.split('.')
        # obj is either a module or a class
        obj = lookup(first)
        for attr in rest:
            assert inspect.ismodule(obj) or inspect.isclass(obj)
            obj = getattr(obj, attr)
        return obj

    # Find all super classes
    m = re.search(r'class\s(.+)\((.+)\)\s*\:', class_statement)
    if m:
        class_name = m.group(1)
        # Handle multiple inheritance; keyword arguments of the class
        # statement (e.g. metaclass=...) are not base classes.
        base_classes = [s.strip() for s in m.group(2).split(',')]
        base_classes = [s for s in base_classes if s and '=' not in s]
    else:
        # No parenthesised base list at all, e.g. "class Foo:"
        name_match = re.search(r'class\s+(\w+)', class_statement)
        class_name = name_match.group(1) if name_match else class_statement.strip()
        base_classes = []

    if not base_classes:
        raise ValueError(
            'overrides decorator: unable to determine base class of class "%s"'
            % class_name
        )

    classes = {name: get_class(name) for name in base_classes}
    check_classes = base_classes

    if explicit_base_classes:
        # One or more base classes are explicitly given, check only those classes
        decorator_match = re.search(r'\s*@overrides\((.+)\)\s*', stack[1][4][0])
        if decorator_match:
            check_classes = [c.strip() for c in decorator_match.group(1).split(',')]
            for name in check_classes:
                classes[name] = get_class(name)
        else:
            # The decorator is not spelled "@overrides(...)" in the source
            # (aliased or module-qualified): use the class objects directly.
            check_classes = [cls.__name__ for cls in explicit_base_classes]
            classes.update(zip(check_classes, explicit_base_classes))

        # Verify that the explicit override class is one of base classes
        is_super_class = any(
            issubclass(classes[bc], classes[cc])
            for bc in base_classes
            for cc in check_classes
        )
        if not is_super_class:
            raise Exception(
                'Excplicit override class "%s" is not a super class of: %s'
                % (explicit_base_classes, class_name)
            )

        # Every explicitly named class must provide the method
        for cls in check_classes:
            if not hasattr(classes[cls], method.__name__):
                raise Exception(
                    'Function override "%s" not found in superclass: %s\n%s'
                    % (
                        method.__name__,
                        cls,
                        location,
                    )
                )

    if not any(hasattr(classes[cls], method.__name__) for cls in check_classes):
        raise Exception(
            'Function override "%s" not found in any superclass: %s\n%s'
            % (
                method.__name__,
                check_classes,
                location,
            )
        )
    return method


def deprecated(func):
    """This is a decorator which can be used to mark function as deprecated.

    It will result in a ``DeprecationWarning`` being emitted every time the
    function is used. The warning is attributed to the caller of the
    deprecated function.

    The warning filters configured by the application are left untouched:
    the filter that forces the warning to be shown is only active while the
    warning is emitted.

    :param func: the function to mark as deprecated
    :returns: a wrapper which warns and then calls ``func``
    """

    @wraps(func)
    def depr_func(*args, **kwargs):
        # catch_warnings restores the previous filter list on exit, so the
        # temporary 'always' filter does not leak into global state.
        with warnings.catch_warnings():
            warnings.simplefilter('always', DeprecationWarning)
            warnings.warn(
                f'Call to deprecated function {func.__name__}.',
                category=DeprecationWarning,
                stacklevel=2,
            )
        return func(*args, **kwargs)

    return depr_func


class CoroutineDeferred(defer.Deferred):
    """Wraps a coroutine in a Deferred.

    It will dynamically pass through the underlying coroutine without wrapping where appropriate.

    The ``awaited`` attribute tracks which of the two modes is in use:

    * ``None``: undecided, nothing has awaited or activated the object yet
    * ``True``: the object was awaited, the coroutine is used directly
    * ``False``: deferred mode, the coroutine was turned into a Deferred
      whose result is chained into this object
    """

    def __init__(self, coro: Coroutine):
        """Store ``coro`` and schedule ``activate`` with the reactor.

        :param coro: the coroutine to wrap
        """
        # Delay this import to make sure a reactor was installed first
        from twisted.internet import reactor

        super().__init__()
        self.coro = coro
        # Tri-state mode flag, see the class docstring
        self.awaited = None
        # Zero-delay call: gives the caller the chance to await the object
        # before activate() switches it to deferred mode.
        self.activate_deferred = reactor.callLater(0, self.activate)

    def __await__(self):
        """Await the coroutine directly unless already in deferred mode."""
        if self.awaited in [None, True]:
            self.awaited = True
            return self.coro.__await__()
        # Already in deferred mode
        return super().__await__()

    def activate(self):
        """If the result wasn't awaited before the next context switch, we turn it into a deferred."""
        # Does nothing once a mode has been chosen (awaited is True or False)
        if self.awaited is None:
            self.awaited = False
            try:
                d = defer.Deferred.fromCoroutine(self.coro)
            except AttributeError:
                # Fallback for Twisted <= 21.2 without fromCoroutine
                d = defer.ensureDeferred(self.coro)
            # Forward the result (or failure) of the coroutine to this object
            d.chainDeferred(self)

    def addCallbacks(self, *args, **kwargs):  # noqa: N802
        """Switch to deferred mode if still undecided, then add the callbacks.

        Arguments are passed on unchanged to ``Deferred.addCallbacks``.

        :raises AssertionError: if the object has already been awaited
        """
        assert not self.awaited, 'Cannot add callbacks to an already awaited coroutine.'
        # Adding callbacks means the object is used as a Deferred: activate
        # now instead of waiting for the scheduled call.
        self.activate()
        return super().addCallbacks(*args, **kwargs)


_RetT = TypeVar('_RetT')


def maybe_coroutine(
    f: Callable[..., Coroutine[Any, Any, _RetT]]
) -> 'Callable[..., defer.Deferred[_RetT]]':
    """Make a coroutine function callable like a normal Deferred-returning function.

    Each call of the returned wrapper invokes ``f`` with the given arguments
    and wraps the resulting coroutine in a ``CoroutineDeferred``.

    :param f: the coroutine function to wrap
    :returns: a wrapper carrying the metadata of ``f``
    """

    def wrapper(*args, **kwargs):
        coro = f(*args, **kwargs)
        # For quick testing that the CoroutineDeferred magic isn't at fault,
        # return defer.ensureDeferred(coro) here instead.
        return CoroutineDeferred(coro)

    return wraps(f)(wrapper)