summaryrefslogtreecommitdiffstats
path: root/src/pybind/mgr/cli_api/module.py
blob: 79b042eb0e9d6e2d8d458af8abfbf5142a72a46c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import concurrent.futures
import functools
import inspect
import logging
import time
import errno
from typing import Any, Callable, Dict, List

from mgr_module import MgrModule, HandleCommandResult, CLICommand, API

# Named after this module so records are attributed to it instead of the root
# logger; they still propagate to whatever handlers are attached to root.
logger = logging.getLogger(__name__)
# High-resolution clock used by CLI.benchmark to time individual calls.
get_time = time.perf_counter


def pretty_json(obj: Any) -> str:
    """
    Serialize ``obj`` to JSON text with sorted keys and 2-space indentation.

    :param obj: any value ``json.dumps`` can encode.
    :returns: the formatted JSON string.
    :raises TypeError: if ``obj`` (or a value nested inside it) is not JSON
        serializable.
    """
    import json
    return json.dumps(obj, sort_keys=True, indent=2)


class CephCommander:
    """
    Introspects a Python callable and derives the CephCommand descriptor used
    to register it as a ``mgr cli`` command (the descriptor format is
    described in src/mon/MonCommand.h).
    """

    def __init__(self, func: Callable):
        self.func = func
        sig = inspect.signature(func)
        self.signature = sig
        # Kept as an attribute; to_ceph_signature() below does not use it.
        self.params = sig.parameters

    def to_ceph_signature(self) -> Dict[str, str]:
        """
        Build the keyword arguments handed to ``CLICommand``.

        :returns: a dict with ``prefix`` set to ``mgr cli <function name>``
            and ``perm`` set to whatever ``API.perm`` holds for the function.
        """
        target = self.func
        command_prefix = 'mgr cli ' + target.__name__
        permission = API.perm.get(target)
        return dict(prefix=command_prefix, perm=permission)


class MgrAPIReflector(type):
    """
    Metaclass to register COMMANDS and Command Handlers via CLICommand
    decorator.

    For every public function found on a base class of the class being
    created (see ``is_public``), a ``mgr cli <name>`` command is registered
    through ``CLICommand`` and the resulting handler is stored on the new
    class as ``_cli_<name>``.
    """

    def __new__(cls, name, bases, dct):  # type: ignore
        klass = super().__new__(cls, name, bases, dct)
        # NOTE(review): ``cls`` is the metaclass here, so this attribute is
        # set on MgrAPIReflector itself rather than on ``klass``. Nothing in
        # this file reads it.
        cls.threaded_benchmark_runner = None
        for base in bases:
            # ``func_name`` (not ``name``) so the class-name argument above
            # is not shadowed.
            for func_name, func in inspect.getmembers(base, cls.is_public):
                # However not necessary (CLICommand uses a registry)
                # save functions to klass._cli_{n}() methods. This
                # can help on unit testing
                wrapper = cls.func_wrapper(func)
                command = CLICommand(**CephCommander(func).to_ceph_signature())(  # type: ignore
                    wrapper)
                setattr(klass, f'_cli_{func_name}', command)
        return klass

    @staticmethod
    def is_public(func: Callable) -> bool:
        """
        Return True when ``func`` is a plain function whose name does not
        start with an underscore and that ``API.expose`` marks truthy.
        """
        # bool() so the result matches the annotation instead of leaking
        # whatever value API.expose.get() returns.
        return bool(
            inspect.isfunction(func)
            and not func.__name__.startswith('_')
            and API.expose.get(func)
        )

    @staticmethod
    def func_wrapper(func: Callable) -> Callable:
        """
        Wrap ``func`` so its return value is rendered with ``pretty_json``
        and returned as the stdout of a ``HandleCommandResult``.
        """
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs) -> HandleCommandResult:  # type: ignore
            return HandleCommandResult(stdout=pretty_json(
                func(self, *args, **kwargs)))

        # Expose func's exact signature on the wrapper explicitly, rather than
        # relying on consumers following the __wrapped__ attribute set by
        # functools.wraps.
        signature = inspect.signature(func)
        wrapper.__signature__ = signature  # type: ignore
        return wrapper


class CLI(MgrModule, metaclass=MgrAPIReflector):
    """
    Manager module providing the ``mgr cli_benchmark`` command. Handlers for
    the functions exposed on its base classes are attached by the
    ``MgrAPIReflector`` metaclass.
    """

    @CLICommand('mgr cli_benchmark')
    def benchmark(self, iterations: int, threads: int, func_name: str,
                  func_args: List[str] = None) -> HandleCommandResult:  # type: ignore
        """Call a module function repeatedly from a thread pool and report avg/max/min call time in seconds."""
        # iterations: total number of calls (must be > 0)
        # threads:    thread pool size (must be > 0)
        # func_name:  attribute of this module to call
        # func_args:  positional arguments passed to every call
        # Raises BenchmarkException for non-positive counts; returns -EINVAL
        # when func_name cannot be found.
        func_args = () if func_args is None else func_args
        # Explicit comparison: a truthiness check let negative values through,
        # which then failed later with ZeroDivisionError (no results) or
        # ValueError (ThreadPoolExecutor max_workers).
        if iterations <= 0 or threads <= 0:
            raise BenchmarkException("Number of calls and number "
                                     "of parallel calls must be greater than 0")
        try:
            func = getattr(self, func_name)
        except AttributeError:
            # mgr command results carry negative errno values on failure.
            return HandleCommandResult(-errno.EINVAL,
                                       stderr="Could not find the public "
                                       "function you are requesting")

        def timer(*args: Any) -> float:
            # args receives the iteration index from executor.map; unused.
            time_start = get_time()
            func(*func_args)
            return get_time() - time_start

        with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
            results = list(executor.map(timer, range(iterations)))

        stats = {
            "avg": sum(results) / len(results),
            "max": max(results),
            "min": min(results),
        }
        return HandleCommandResult(stdout=pretty_json(stats))


class BenchmarkException(Exception):
    """Raised by ``CLI.benchmark`` when the requested call count or thread count is rejected."""