From 19fcec84d8d7d21e796c7624e521b60d28ee21ed Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 20:45:59 +0200 Subject: Adding upstream version 16.2.11+ds. Signed-off-by: Daniel Baumann --- .../mgr/dashboard/controllers/_base_controller.py | 314 +++++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 src/pybind/mgr/dashboard/controllers/_base_controller.py (limited to 'src/pybind/mgr/dashboard/controllers/_base_controller.py') diff --git a/src/pybind/mgr/dashboard/controllers/_base_controller.py b/src/pybind/mgr/dashboard/controllers/_base_controller.py new file mode 100644 index 000000000..4fcc8442f --- /dev/null +++ b/src/pybind/mgr/dashboard/controllers/_base_controller.py @@ -0,0 +1,314 @@ +import inspect +import json +import logging +from functools import wraps +from typing import ClassVar, List, Optional, Type +from urllib.parse import unquote + +import cherrypy + +from ..plugins import PLUGIN_MANAGER +from ..services.auth import AuthManager, JwtManager +from ..tools import get_request_body_params +from ._helpers import _get_function_params +from ._version import APIVersion + +logger = logging.getLogger(__name__) + + +class BaseController: + """ + Base class for all controllers providing API endpoints. + """ + + _registry: ClassVar[List[Type['BaseController']]] = [] + _routed = False + + def __init_subclass__(cls, skip_registry: bool = False, **kwargs) -> None: + super().__init_subclass__(**kwargs) # type: ignore + if not skip_registry: + BaseController._registry.append(cls) + + @classmethod + def load_controllers(cls): + import importlib + from pathlib import Path + + path = Path(__file__).parent + logger.debug('Controller import path: %s', path) + modules = [ + f.stem for f in path.glob('*.py') if + not f.name.startswith('_') and f.is_file() and not f.is_symlink()] + logger.debug('Controller files found: %r', modules) + + for module in modules: + importlib.import_module(f'{__package__}.{module}') + + # pylint: disable=protected-access + controllers = [ + controller for controller in BaseController._registry if + controller._routed + ] + + for clist in PLUGIN_MANAGER.hook.get_controllers() or []: + controllers.extend(clist) + + return controllers + + class Endpoint: + """ + An instance of this class represents an endpoint. + """ + + def __init__(self, ctrl, func): + self.ctrl = ctrl + self.inst = None + self.func = func + + if not self.config['proxy']: + setattr(self.ctrl, func.__name__, self.function) + + @property + def config(self): + func = self.func + while not hasattr(func, '_endpoint'): + if hasattr(func, "__wrapped__"): + func = func.__wrapped__ + else: + return None + return func._endpoint # pylint: disable=protected-access + + @property + def function(self): + # pylint: disable=protected-access + return self.ctrl._request_wrapper(self.func, self.method, + self.config['json_response'], + self.config['xml'], + self.config['version']) + + @property + def method(self): + return self.config['method'] + + @property + def proxy(self): + return self.config['proxy'] + + @property + def url(self): + ctrl_path = self.ctrl.get_path() + if ctrl_path == "/": + ctrl_path = "" + if self.config['path'] is not None: + url = "{}{}".format(ctrl_path, self.config['path']) + else: + url = "{}/{}".format(ctrl_path, self.func.__name__) + + ctrl_path_params = self.ctrl.get_path_param_names( + self.config['path']) + path_params = [p['name'] for p in self.path_params + if p['name'] not in ctrl_path_params] + path_params = ["{{{}}}".format(p) for p in path_params] + if path_params: + url += "/{}".format("/".join(path_params)) + + return url + + @property + def action(self): + return self.func.__name__ + + @property + def path_params(self): + ctrl_path_params = self.ctrl.get_path_param_names( + self.config['path']) + func_params = _get_function_params(self.func) + + if self.method in ['GET', 'DELETE']: + assert self.config['path_params'] is None + + return [p for p in func_params if p['name'] in ctrl_path_params + or (p['name'] not in self.config['query_params'] + and p['required'])] + + # elif self.method in ['POST', 'PUT']: + return [p for p in func_params if p['name'] in ctrl_path_params + or p['name'] in self.config['path_params']] + + @property + def query_params(self): + if self.method in ['GET', 'DELETE']: + func_params = _get_function_params(self.func) + path_params = [p['name'] for p in self.path_params] + return [p for p in func_params if p['name'] not in path_params] + + # elif self.method in ['POST', 'PUT']: + func_params = _get_function_params(self.func) + return [p for p in func_params + if p['name'] in self.config['query_params']] + + @property + def body_params(self): + func_params = _get_function_params(self.func) + path_params = [p['name'] for p in self.path_params] + query_params = [p['name'] for p in self.query_params] + return [p for p in func_params + if p['name'] not in path_params + and p['name'] not in query_params] + + @property + def group(self): + return self.ctrl.__name__ + + @property + def is_api(self): + # changed from hasattr to getattr: some ui-based api inherit _api_endpoint + return getattr(self.ctrl, '_api_endpoint', False) + + @property + def is_secure(self): + return self.ctrl._cp_config['tools.authenticate.on'] # pylint: disable=protected-access + + def __repr__(self): + return "Endpoint({}, {}, {})".format(self.url, self.method, + self.action) + + def __init__(self): + logger.info('Initializing controller: %s -> %s', + self.__class__.__name__, self._cp_path_) # type: ignore + super().__init__() + + def _has_permissions(self, permissions, scope=None): + if not self._cp_config['tools.authenticate.on']: # type: ignore + raise Exception("Cannot verify permission in non secured " + "controllers") + + if not isinstance(permissions, list): + permissions = [permissions] + + if scope is None: + scope = getattr(self, '_security_scope', None) + if scope is None: + raise Exception("Cannot verify permissions without scope security" + " defined") + username = JwtManager.LOCAL_USER.username + return AuthManager.authorize(username, scope, permissions) + + @classmethod + def get_path_param_names(cls, path_extension=None): + if path_extension is None: + path_extension = "" + full_path = cls._cp_path_[1:] + path_extension # type: ignore + path_params = [] + for step in full_path.split('/'): + param = None + if not step: + continue + if step[0] == ':': + param = step[1:] + elif step[0] == '{' and step[-1] == '}': + param, _, _ = step[1:-1].partition(':') + if param: + path_params.append(param) + return path_params + + @classmethod + def get_path(cls): + return cls._cp_path_ # type: ignore + + @classmethod + def endpoints(cls): + """ + This method iterates over all the methods decorated with ``@endpoint`` + and creates an Endpoint object for each one of the methods. + + :return: A list of endpoint objects + :rtype: list[BaseController.Endpoint] + """ + result = [] + for _, func in inspect.getmembers(cls, predicate=callable): + if hasattr(func, '_endpoint'): + result.append(cls.Endpoint(cls, func)) + return result + + @staticmethod + def _request_wrapper(func, method, json_response, xml, # pylint: disable=unused-argument + version: Optional[APIVersion]): + # pylint: disable=too-many-branches + @wraps(func) + def inner(*args, **kwargs): + client_version = None + for key, value in kwargs.items(): + if isinstance(value, str): + kwargs[key] = unquote(value) + + # Process method arguments. + params = get_request_body_params(cherrypy.request) + kwargs.update(params) + + if version is not None: + try: + client_version = APIVersion.from_mime_type( + cherrypy.request.headers['Accept']) + except Exception: + raise cherrypy.HTTPError( + 415, "Unable to find version in request header") + + if version.supports(client_version): + ret = func(*args, **kwargs) + else: + raise cherrypy.HTTPError( + 415, + f"Incorrect version: endpoint is '{version!s}', " + f"client requested '{client_version!s}'" + ) + + else: + ret = func(*args, **kwargs) + if isinstance(ret, bytes): + ret = ret.decode('utf-8') + if xml: + if version: + cherrypy.response.headers['Content-Type'] = \ + 'application/vnd.ceph.api.v{}+xml'.format(version) + else: + cherrypy.response.headers['Content-Type'] = 'application/xml' + return ret.encode('utf8') + if json_response: + if version: + cherrypy.response.headers['Content-Type'] = \ + 'application/vnd.ceph.api.v{}+json'.format(version) + else: + cherrypy.response.headers['Content-Type'] = 'application/json' + ret = json.dumps(ret).encode('utf8') + return ret + return inner + + @property + def _request(self): + return self.Request(cherrypy.request) + + class Request(object): + def __init__(self, cherrypy_req): + self._creq = cherrypy_req + + @property + def scheme(self): + return self._creq.scheme + + @property + def host(self): + base = self._creq.base + base = base[len(self.scheme)+3:] + return base[:base.find(":")] if ":" in base else base + + @property + def port(self): + base = self._creq.base + base = base[len(self.scheme)+3:] + default_port = 443 if self.scheme == 'https' else 80 + return int(base[base.find(":")+1:]) if ":" in base else default_port + + @property + def path_info(self): + return self._creq.path_info -- cgit v1.2.3