summaryrefslogtreecommitdiffstats
path: root/src/pybind/mgr/dashboard/plugins/ttl_cache.py
blob: 78221547acc3a065cadb2c8df220dbf9c0198fef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
This is a minimal implementation of a TTL-ed lru_cache function.

Based on Python 3 functools and backports.functools_lru_cache.
"""

import os
from collections import OrderedDict
from functools import wraps
from threading import RLock
from time import time
from typing import Any, Dict

try:
    from typing import Tuple
except ImportError:
    pass  # For typing only


class TTLCache:
    """Size-bounded cache whose entries expire ``ttl`` seconds after being set.

    Entries are kept in an ``OrderedDict`` in write order: when a new key is
    added to a full cache, the entry written longest ago is evicted.  Lookups
    do not reorder entries.  Lookups, writes and ``clear`` are serialized by a
    re-entrant lock.  ``hits``, ``misses`` and ``expired`` are running
    counters reported by :meth:`info`.
    """

    class CachedValue:
        """A stored value together with the ``time()`` at which it was set."""

        def __init__(self, value, timestamp):
            self.value = value
            self.timestamp = timestamp

    def __init__(self, reference, ttl, maxsize=128):
        """Create an empty cache.

        :param reference: name of this cache, only used in :meth:`info`
        :param ttl: entry lifetime in seconds; with ``0`` every lookup misses
        :param maxsize: number of entries kept before the oldest is evicted;
            ``0`` disables storing altogether
        """
        self.reference = reference
        self.ttl: int = ttl
        self.maxsize = maxsize
        self.cache: OrderedDict[Tuple[Any], TTLCache.CachedValue] = OrderedDict()
        self.hits = 0
        self.misses = 0
        self.expired = 0
        self.rlock = RLock()

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        An entry whose age has reached ``ttl`` is removed and treated as
        missing.

        :raises KeyError: if ``key`` is absent or its entry has expired
        """
        with self.rlock:
            if key not in self.cache:
                self.misses += 1
                raise KeyError(f'"{key}" is not set')

            cached_value = self.cache[key]
            if time() - cached_value.timestamp >= self.ttl:
                # Expired entries are dropped lazily, on access.
                del self.cache[key]
                self.expired += 1
                self.misses += 1
                raise KeyError(f'"{key}" is not set')

            self.hits += 1
            return cached_value.value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` with a fresh timestamp.

        Overwriting an existing key never evicts another entry and moves the
        key to the newest position.  Adding a new key to a full cache evicts
        the oldest entry first.
        """
        with self.rlock:
            if key in self.cache:
                if time() - self.cache[key].timestamp >= self.ttl:
                    self.expired += 1
                # The size does not grow on overwrite, so nothing has to be
                # evicted.  Remove the old entry so that the assignment below
                # appends the key as the newest one, matching its timestamp.
                del self.cache[key]
            elif self.maxsize == 0:
                # Zero capacity: there is nothing to evict and nothing may be
                # stored (popitem() on the empty dict would raise KeyError).
                return
            elif len(self.cache) == self.maxsize:
                # Full: drop the entry that was written longest ago.
                self.cache.popitem(last=False)

            self.cache[key] = TTLCache.CachedValue(value, time())

    def clear(self):
        """Remove all entries; the hit/miss/expired counters are kept."""
        with self.rlock:
            self.cache.clear()

    def info(self) -> str:
        """Return a one-line summary of the counters and the current size."""
        return (f'cache={self.reference} hits={self.hits}, misses={self.misses},'
                f'expired={self.expired}, maxsize={self.maxsize}, currsize={len(self.cache)}')


class CacheManager:
    """Registry of named ``TTLCache`` instances, held at class level."""

    # reference name -> cache; one dict shared through the class
    caches: Dict[str, TTLCache] = {}

    @classmethod
    def get(cls, reference: str, ttl=30, maxsize=128):
        """Return the cache registered as ``reference``, creating it if needed.

        ``ttl`` and ``maxsize`` are only used when the cache is created; an
        already registered cache is returned unchanged.
        """
        try:
            return cls.caches[reference]
        except KeyError:
            created = cls.caches[reference] = TTLCache(reference, ttl, maxsize)
            return created


# Separates positional from keyword arguments inside a cache key, so that a
# positional ``('a', 1)`` tuple can never be mistaken for the keyword ``a=1``.
_KWD_MARK = (object(),)


def ttl_cache(ttl, maxsize=128, typed=False, label: str = ''):
    """Decorator factory caching a function's results for ``ttl`` seconds.

    The cache is obtained from ``CacheManager`` under ``label`` or, when no
    label is given, under the decorated function's ``__name__``; functions
    resolving to the same name therefore share one cache.  All arguments of
    the decorated function must be hashable.

    :param ttl: lifetime of a cached result in seconds
    :param maxsize: maximum number of cached results
    :param typed: must be ``False``; accepted for signature parity only
    :param label: optional name of the cache to use
    :raises NotImplementedError: if ``typed`` is not ``False``
    """
    if typed is not False:
        raise NotImplementedError("typed caching not supported")

    # disable caching while running unit tests
    if 'UNITTEST' in os.environ:
        ttl = 0

    def decorating_function(function):
        cache_name = label or function.__name__
        cache = CacheManager.get(cache_name, ttl, maxsize)

        @wraps(function)
        def wrapper(*args, **kwargs):
            # Calls without keyword arguments keep the plain ``args`` key.
            key = args
            if kwargs:
                key += _KWD_MARK + tuple(kwargs.items())
            try:
                return cache[key]
            except KeyError:
                pass
            # Computed outside the handler so that an exception raised by
            # ``function`` is not chained to the cache-miss KeyError.
            ret = function(*args, **kwargs)
            cache[key] = ret
            return ret

        return wrapper
    return decorating_function


def ttl_cache_invalidator(label: str):
    """Decorator factory that empties the cache ``label`` after each call.

    The decorated function runs first; if it raises, the cache is left
    untouched.  A label with no registered cache is ignored instead of
    registering a new cache with default settings, which a later
    ``ttl_cache(..., label=label)`` would otherwise inherit.

    :param label: name under which the cache is registered in ``CacheManager``
    """
    def decorating_function(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            ret = function(*args, **kwargs)
            cache = CacheManager.caches.get(label)
            if cache is not None:
                cache.clear()
            return ret
        return wrapper
    return decorating_function