summaryrefslogtreecommitdiffstats
path: root/src/prompt_toolkit/cache.py
blob: 01dd1f79d6582442b0932a6a252585a5d85aa0a5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

from collections import deque
from functools import wraps
from typing import Any, Callable, Dict, Generic, Hashable, Tuple, TypeVar, cast

# Public API of this module; everything else (the TypeVars) is internal.
__all__ = [
    "SimpleCache",
    "FastDictCache",
    "memoized",
]

_T = TypeVar("_T", bound=Hashable)  # Cache key type (must be hashable).
_U = TypeVar("_U")  # Cached value type.


class SimpleCache(Generic[_T, _U]):
    """
    Very simple cache that discards the oldest item when the cache size is
    exceeded.

    Eviction is first-in-first-out: the entry that was inserted earliest is
    dropped, regardless of how recently it was read.

    :param maxsize: Maximum size of the cache. (Don't make it too big.)
    """

    def __init__(self, maxsize: int = 8) -> None:
        assert maxsize > 0

        # Cached values, keyed by cache key.
        self._data: dict[_T, _U] = {}
        # Keys in insertion order; the left end is the oldest entry.
        self._keys: deque[_T] = deque()
        self.maxsize: int = maxsize

    def get(self, key: _T, getter_func: Callable[[], _U]) -> _U:
        """
        Get object from the cache.
        If not found, call `getter_func` to resolve it, and put that on the top
        of the cache instead.

        :param key: Hashable cache key.
        :param getter_func: Zero-argument callable that produces the value on
            a cache miss.
        :returns: The cached value, or the freshly computed one.
        """
        # Look in cache first.
        try:
            return self._data[key]
        except KeyError:
            # Not found. Leave the handler before calling `getter_func`, so
            # that any exception it raises is not implicitly chained to this
            # (irrelevant) KeyError in tracebacks.
            pass

        value = getter_func()
        self._data[key] = value
        self._keys.append(key)

        # Remove the oldest key when the size is exceeded. Tolerate a key that
        # is no longer present in `_data`.
        if len(self._data) > self.maxsize:
            self._data.pop(self._keys.popleft(), None)

        return value

    def clear(self) -> None:
        "Clear cache."
        self._data = {}
        self._keys = deque()


_K = TypeVar("_K", bound=Tuple[Hashable, ...])  # Key: tuple of hashables.
_V = TypeVar("_V")  # Cached value type.


class FastDictCache(Dict[_K, _V]):
    """
    Fast, lightweight cache which keeps at most `size` items.
    It will discard the oldest items in the cache first.

    The cache is a dictionary, which doesn't keep track of access counts.
    It is perfect to cache little immutable objects which are not expensive to
    create, but where a dictionary lookup is still much faster than an object
    instantiation.

    Keys are tuples. On a miss, the tuple is unpacked as positional arguments
    to `get_value`, i.e. ``cache[(a, b)]`` calls ``get_value(a, b)``.

    :param get_value: Callable that's called in case of a missing key.
    :param size: Maximum number of entries kept (for entries added through
        lookups of missing keys).
    """

    # NOTE: This cache is used to cache `prompt_toolkit.layout.screen.Char` and
    #       `prompt_toolkit.Document`. Make sure to keep this really lightweight.
    #       Accessing the cache should stay faster than instantiating new
    #       objects.
    #       (Dictionary lookups are really fast.)
    #       SimpleCache is still required for cases where the cache key is not
    #       the same as the arguments given to the function that creates the
    #       value.)
    def __init__(self, get_value: Callable[..., _V], size: int = 1000000) -> None:
        assert size > 0

        # Keys in insertion order; the left end is the oldest entry.
        self._keys: deque[_K] = deque()
        self.get_value = get_value
        self.size = size

    def __missing__(self, key: _K) -> _V:
        # Compute first: if `get_value` raises, nothing has been evicted.
        result = self.get_value(*key)

        # Make room *before* inserting, so the new entry cannot push the
        # cache past `size` entries. Keys already removed from the dict by
        # other means (e.g. `del cache[k]` or `dict.clear()`) free no slot,
        # so keep popping until there is room or no tracked keys remain.
        keys = self._keys
        while len(self) >= self.size and keys:
            self.pop(keys.popleft(), None)

        self[key] = result
        keys.append(key)
        return result


_F = TypeVar("_F", bound=Callable[..., object])  # Type of the decorated object.


def memoized(maxsize: int = 1024) -> Callable[[_F], _F]:
    """
    Memoization decorator for immutable classes and pure functions.

    Every decorated object gets its own :class:`SimpleCache` holding at most
    `maxsize` results. The cache key is built from the positional arguments
    plus the keyword arguments sorted by name, so all arguments must be
    hashable.

    :param maxsize: Maximum number of results remembered per decorated object.
    :returns: A decorator that wraps its argument with the cache.
    """

    def decorator(obj: _F) -> _F:
        results: SimpleCache[Hashable, Any] = SimpleCache(maxsize=maxsize)

        @wraps(obj)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Sort keyword arguments by name so their call order doesn't matter.
            lookup_key = (args, tuple(sorted(kwargs.items())))
            return results.get(lookup_key, lambda: obj(*args, **kwargs))

        return cast(_F, wrapper)

    return decorator