diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 16:04:21 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 16:04:21 +0000 |
commit | 8a754e0858d922e955e71b253c139e071ecec432 (patch) | |
tree | 527d16e74bfd1840c85efd675fdecad056c54107 /test/lib/ansible_test/_internal/core_ci.py | |
parent | Initial commit. (diff) | |
download | ansible-core-8a754e0858d922e955e71b253c139e071ecec432.tar.xz ansible-core-8a754e0858d922e955e71b253c139e071ecec432.zip |
Adding upstream version 2.14.3.upstream/2.14.3upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'test/lib/ansible_test/_internal/core_ci.py')
-rw-r--r-- | test/lib/ansible_test/_internal/core_ci.py | 547 |
1 files changed, 547 insertions, 0 deletions
diff --git a/test/lib/ansible_test/_internal/core_ci.py b/test/lib/ansible_test/_internal/core_ci.py new file mode 100644 index 0000000..d62b903 --- /dev/null +++ b/test/lib/ansible_test/_internal/core_ci.py @@ -0,0 +1,547 @@ +"""Access Ansible Core CI remote services.""" +from __future__ import annotations + +import abc +import dataclasses +import json +import os +import re +import traceback +import uuid +import time +import typing as t + +from .http import ( + HttpClient, + HttpResponse, + HttpError, +) + +from .io import ( + make_dirs, + read_text_file, + write_json_file, + write_text_file, +) + +from .util import ( + ApplicationError, + display, + ANSIBLE_TEST_TARGET_ROOT, + mutex, +) + +from .util_common import ( + run_command, + ResultType, +) + +from .config import ( + EnvironmentConfig, +) + +from .ci import ( + get_ci_provider, +) + +from .data import ( + data_context, +) + + +@dataclasses.dataclass(frozen=True) +class Resource(metaclass=abc.ABCMeta): + """Base class for Ansible Core CI resources.""" + @abc.abstractmethod + def as_tuple(self) -> tuple[str, str, str, str]: + """Return the resource as a tuple of platform, version, architecture and provider.""" + + @abc.abstractmethod + def get_label(self) -> str: + """Return a user-friendly label for this resource.""" + + @property + @abc.abstractmethod + def persist(self) -> bool: + """True if the resource is persistent, otherwise false.""" + + +@dataclasses.dataclass(frozen=True) +class VmResource(Resource): + """Details needed to request a VM from Ansible Core CI.""" + platform: str + version: str + architecture: str + provider: str + tag: str + + def as_tuple(self) -> tuple[str, str, str, str]: + """Return the resource as a tuple of platform, version, architecture and provider.""" + return self.platform, self.version, self.architecture, self.provider + + def get_label(self) -> str: + """Return a user-friendly label for this resource.""" + return f'{self.platform} {self.version} ({self.architecture}) [{self.tag}] @{self.provider}' + + @property + def persist(self) -> bool: + """True if the resource is persistent, otherwise false.""" + return True + + +@dataclasses.dataclass(frozen=True) +class CloudResource(Resource): + """Details needed to request cloud credentials from Ansible Core CI.""" + platform: str + + def as_tuple(self) -> tuple[str, str, str, str]: + """Return the resource as a tuple of platform, version, architecture and provider.""" + return self.platform, '', '', self.platform + + def get_label(self) -> str: + """Return a user-friendly label for this resource.""" + return self.platform + + @property + def persist(self) -> bool: + """True if the resource is persistent, otherwise false.""" + return False + + +class AnsibleCoreCI: + """Client for Ansible Core CI services.""" + DEFAULT_ENDPOINT = 'https://ansible-core-ci.testing.ansible.com' + + def __init__( + self, + args: EnvironmentConfig, + resource: Resource, + load: bool = True, + ) -> None: + self.args = args + self.resource = resource + self.platform, self.version, self.arch, self.provider = self.resource.as_tuple() + self.stage = args.remote_stage + self.client = HttpClient(args) + self.connection = None + self.instance_id = None + self.endpoint = None + self.default_endpoint = args.remote_endpoint or self.DEFAULT_ENDPOINT + self.retries = 3 + self.ci_provider = get_ci_provider() + self.label = self.resource.get_label() + + stripped_label = re.sub('[^A-Za-z0-9_.]+', '-', self.label).strip('-') + + self.name = f"{stripped_label}-{self.stage}" # turn the label into something suitable for use as a filename + + self.path = os.path.expanduser(f'~/.ansible/test/instances/{self.name}') + self.ssh_key = SshKey(args) + + if self.resource.persist and load and self._load(): + try: + display.info(f'Checking existing {self.label} instance using: {self._uri}', verbosity=1) + + self.connection = self.get(always_raise_on=[404]) + + display.info(f'Loaded existing {self.label} instance.', verbosity=1) + except HttpError as ex: + if ex.status != 404: + raise + + self._clear() + + display.info(f'Cleared stale {self.label} instance.', verbosity=1) + + self.instance_id = None + self.endpoint = None + elif not self.resource.persist: + self.instance_id = None + self.endpoint = None + self._clear() + + if self.instance_id: + self.started: bool = True + else: + self.started = False + self.instance_id = str(uuid.uuid4()) + self.endpoint = None + + display.sensitive.add(self.instance_id) + + if not self.endpoint: + self.endpoint = self.default_endpoint + + @property + def available(self) -> bool: + """Return True if Ansible Core CI is supported.""" + return self.ci_provider.supports_core_ci_auth() + + def start(self) -> t.Optional[dict[str, t.Any]]: + """Start instance.""" + if self.started: + display.info(f'Skipping started {self.label} instance.', verbosity=1) + return None + + return self._start(self.ci_provider.prepare_core_ci_auth()) + + def stop(self) -> None: + """Stop instance.""" + if not self.started: + display.info(f'Skipping invalid {self.label} instance.', verbosity=1) + return + + response = self.client.delete(self._uri) + + if response.status_code == 404: + self._clear() + display.info(f'Cleared invalid {self.label} instance.', verbosity=1) + return + + if response.status_code == 200: + self._clear() + display.info(f'Stopped running {self.label} instance.', verbosity=1) + return + + raise self._create_http_error(response) + + def get(self, tries: int = 3, sleep: int = 15, always_raise_on: t.Optional[list[int]] = None) -> t.Optional[InstanceConnection]: + """Get instance connection information.""" + if not self.started: + display.info(f'Skipping invalid {self.label} instance.', verbosity=1) + return None + + if not always_raise_on: + always_raise_on = [] + + if self.connection and self.connection.running: + return self.connection + + while True: + tries -= 1 + response = self.client.get(self._uri) + + if response.status_code == 200: + break + + error = self._create_http_error(response) + + if not tries or response.status_code in always_raise_on: + raise error + + display.warning(f'{error}. Trying again after {sleep} seconds.') + time.sleep(sleep) + + if self.args.explain: + self.connection = InstanceConnection( + running=True, + hostname='cloud.example.com', + port=12345, + username='root', + password='password' if self.platform == 'windows' else None, + ) + else: + response_json = response.json() + status = response_json['status'] + con = response_json.get('connection') + + if con: + self.connection = InstanceConnection( + running=status == 'running', + hostname=con['hostname'], + port=int(con['port']), + username=con['username'], + password=con.get('password'), + response_json=response_json, + ) + else: + self.connection = InstanceConnection( + running=status == 'running', + response_json=response_json, + ) + + if self.connection.password: + display.sensitive.add(str(self.connection.password)) + + status = 'running' if self.connection.running else 'starting' + + display.info(f'The {self.label} instance is {status}.', verbosity=1) + + return self.connection + + def wait(self, iterations: t.Optional[int] = 90) -> None: + """Wait for the instance to become ready.""" + for _iteration in range(1, iterations): + if self.get().running: + return + time.sleep(10) + + raise ApplicationError(f'Timeout waiting for {self.label} instance.') + + @property + def _uri(self) -> str: + return f'{self.endpoint}/{self.stage}/{self.provider}/{self.instance_id}' + + def _start(self, auth) -> dict[str, t.Any]: + """Start instance.""" + display.info(f'Initializing new {self.label} instance using: {self._uri}', verbosity=1) + + if self.platform == 'windows': + winrm_config = read_text_file(os.path.join(ANSIBLE_TEST_TARGET_ROOT, 'setup', 'ConfigureRemotingForAnsible.ps1')) + else: + winrm_config = None + + data = dict( + config=dict( + platform=self.platform, + version=self.version, + architecture=self.arch, + public_key=self.ssh_key.pub_contents, + winrm_config=winrm_config, + ) + ) + + data.update(dict(auth=auth)) + + headers = { + 'Content-Type': 'application/json', + } + + response = self._start_endpoint(data, headers) + + self.started = True + self._save() + + display.info(f'Started {self.label} instance.', verbosity=1) + + if self.args.explain: + return {} + + return response.json() + + def _start_endpoint(self, data: dict[str, t.Any], headers: dict[str, str]) -> HttpResponse: + tries = self.retries + sleep = 15 + + while True: + tries -= 1 + response = self.client.put(self._uri, data=json.dumps(data), headers=headers) + + if response.status_code == 200: + return response + + error = self._create_http_error(response) + + if response.status_code == 503: + raise error + + if not tries: + raise error + + display.warning(f'{error}. Trying again after {sleep} seconds.') + time.sleep(sleep) + + def _clear(self) -> None: + """Clear instance information.""" + try: + self.connection = None + os.remove(self.path) + except FileNotFoundError: + pass + + def _load(self) -> bool: + """Load instance information.""" + try: + data = read_text_file(self.path) + except FileNotFoundError: + return False + + if not data.startswith('{'): + return False # legacy format + + config = json.loads(data) + + return self.load(config) + + def load(self, config: dict[str, str]) -> bool: + """Load the instance from the provided dictionary.""" + self.instance_id = str(config['instance_id']) + self.endpoint = config['endpoint'] + self.started = True + + display.sensitive.add(self.instance_id) + + return True + + def _save(self) -> None: + """Save instance information.""" + if self.args.explain: + return + + config = self.save() + + write_json_file(self.path, config, create_directories=True) + + def save(self) -> dict[str, str]: + """Save instance details and return as a dictionary.""" + return dict( + label=self.resource.get_label(), + instance_id=self.instance_id, + endpoint=self.endpoint, + ) + + @staticmethod + def _create_http_error(response: HttpResponse) -> ApplicationError: + """Return an exception created from the given HTTP response.""" + response_json = response.json() + stack_trace = '' + + if 'message' in response_json: + message = response_json['message'] + elif 'errorMessage' in response_json: + message = response_json['errorMessage'].strip() + if 'stackTrace' in response_json: + traceback_lines = response_json['stackTrace'] + + # AWS Lambda on Python 2.7 returns a list of tuples + # AWS Lambda on Python 3.7 returns a list of strings + if traceback_lines and isinstance(traceback_lines[0], list): + traceback_lines = traceback.format_list(traceback_lines) + + trace = '\n'.join([x.rstrip() for x in traceback_lines]) + stack_trace = f'\nTraceback (from remote server):\n{trace}' + else: + message = str(response_json) + + return CoreHttpError(response.status_code, message, stack_trace) + + +class CoreHttpError(HttpError): + """HTTP response as an error.""" + def __init__(self, status: int, remote_message: str, remote_stack_trace: str) -> None: + super().__init__(status, f'{remote_message}{remote_stack_trace}') + + self.remote_message = remote_message + self.remote_stack_trace = remote_stack_trace + + +class SshKey: + """Container for SSH key used to connect to remote instances.""" + KEY_TYPE = 'rsa' # RSA is used to maintain compatibility with paramiko and EC2 + KEY_NAME = f'id_{KEY_TYPE}' + PUB_NAME = f'{KEY_NAME}.pub' + + @mutex + def __init__(self, args: EnvironmentConfig) -> None: + key_pair = self.get_key_pair() + + if not key_pair: + key_pair = self.generate_key_pair(args) + + key, pub = key_pair + key_dst, pub_dst = self.get_in_tree_key_pair_paths() + + def ssh_key_callback(files: list[tuple[str, str]]) -> None: + """ + Add the SSH keys to the payload file list. + They are either outside the source tree or in the cache dir which is ignored by default. + """ + files.append((key, os.path.relpath(key_dst, data_context().content.root))) + files.append((pub, os.path.relpath(pub_dst, data_context().content.root))) + + data_context().register_payload_callback(ssh_key_callback) + + self.key, self.pub = key, pub + + if args.explain: + self.pub_contents = None + self.key_contents = None + else: + self.pub_contents = read_text_file(self.pub).strip() + self.key_contents = read_text_file(self.key).strip() + + @staticmethod + def get_relative_in_tree_private_key_path() -> str: + """Return the ansible-test SSH private key path relative to the content tree.""" + temp_dir = ResultType.TMP.relative_path + + key = os.path.join(temp_dir, SshKey.KEY_NAME) + + return key + + def get_in_tree_key_pair_paths(self) -> t.Optional[tuple[str, str]]: + """Return the ansible-test SSH key pair paths from the content tree.""" + temp_dir = ResultType.TMP.path + + key = os.path.join(temp_dir, self.KEY_NAME) + pub = os.path.join(temp_dir, self.PUB_NAME) + + return key, pub + + def get_source_key_pair_paths(self) -> t.Optional[tuple[str, str]]: + """Return the ansible-test SSH key pair paths for the current user.""" + base_dir = os.path.expanduser('~/.ansible/test/') + + key = os.path.join(base_dir, self.KEY_NAME) + pub = os.path.join(base_dir, self.PUB_NAME) + + return key, pub + + def get_key_pair(self) -> t.Optional[tuple[str, str]]: + """Return the ansible-test SSH key pair paths if present, otherwise return None.""" + key, pub = self.get_in_tree_key_pair_paths() + + if os.path.isfile(key) and os.path.isfile(pub): + return key, pub + + key, pub = self.get_source_key_pair_paths() + + if os.path.isfile(key) and os.path.isfile(pub): + return key, pub + + return None + + def generate_key_pair(self, args: EnvironmentConfig) -> tuple[str, str]: + """Generate an SSH key pair for use by all ansible-test invocations for the current user.""" + key, pub = self.get_source_key_pair_paths() + + if not args.explain: + make_dirs(os.path.dirname(key)) + + if not os.path.isfile(key) or not os.path.isfile(pub): + run_command(args, ['ssh-keygen', '-m', 'PEM', '-q', '-t', self.KEY_TYPE, '-N', '', '-f', key], capture=True) + + if args.explain: + return key, pub + + # newer ssh-keygen PEM output (such as on RHEL 8.1) is not recognized by paramiko + key_contents = read_text_file(key) + key_contents = re.sub(r'(BEGIN|END) PRIVATE KEY', r'\1 RSA PRIVATE KEY', key_contents) + + write_text_file(key, key_contents) + + return key, pub + + +class InstanceConnection: + """Container for remote instance status and connection details.""" + def __init__(self, + running: bool, + hostname: t.Optional[str] = None, + port: t.Optional[int] = None, + username: t.Optional[str] = None, + password: t.Optional[str] = None, + response_json: t.Optional[dict[str, t.Any]] = None, + ) -> None: + self.running = running + self.hostname = hostname + self.port = port + self.username = username + self.password = password + self.response_json = response_json or {} + + def __str__(self): + if self.password: + return f'{self.hostname}:{self.port} [{self.username}:{self.password}]' + + return f'{self.hostname}:{self.port} [{self.username}]' |