summaryrefslogtreecommitdiffstats
path: root/test/lib/ansible_test/_internal/core_ci.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib/ansible_test/_internal/core_ci.py')
-rw-r--r--test/lib/ansible_test/_internal/core_ci.py547
1 files changed, 547 insertions, 0 deletions
diff --git a/test/lib/ansible_test/_internal/core_ci.py b/test/lib/ansible_test/_internal/core_ci.py
new file mode 100644
index 0000000..d62b903
--- /dev/null
+++ b/test/lib/ansible_test/_internal/core_ci.py
@@ -0,0 +1,547 @@
+"""Access Ansible Core CI remote services."""
+from __future__ import annotations
+
+import abc
+import dataclasses
+import json
+import os
+import re
+import traceback
+import uuid
+import time
+import typing as t
+
+from .http import (
+ HttpClient,
+ HttpResponse,
+ HttpError,
+)
+
+from .io import (
+ make_dirs,
+ read_text_file,
+ write_json_file,
+ write_text_file,
+)
+
+from .util import (
+ ApplicationError,
+ display,
+ ANSIBLE_TEST_TARGET_ROOT,
+ mutex,
+)
+
+from .util_common import (
+ run_command,
+ ResultType,
+)
+
+from .config import (
+ EnvironmentConfig,
+)
+
+from .ci import (
+ get_ci_provider,
+)
+
+from .data import (
+ data_context,
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class Resource(metaclass=abc.ABCMeta):
+ """Base class for Ansible Core CI resources."""
+ @abc.abstractmethod
+ def as_tuple(self) -> tuple[str, str, str, str]:
+ """Return the resource as a tuple of platform, version, architecture and provider."""
+
+ @abc.abstractmethod
+ def get_label(self) -> str:
+ """Return a user-friendly label for this resource."""
+
+ @property
+ @abc.abstractmethod
+ def persist(self) -> bool:
+ """True if the resource is persistent, otherwise false."""
+
+
+@dataclasses.dataclass(frozen=True)
+class VmResource(Resource):
+ """Details needed to request a VM from Ansible Core CI."""
+ platform: str
+ version: str
+ architecture: str
+ provider: str
+ tag: str
+
+ def as_tuple(self) -> tuple[str, str, str, str]:
+ """Return the resource as a tuple of platform, version, architecture and provider."""
+ return self.platform, self.version, self.architecture, self.provider
+
+ def get_label(self) -> str:
+ """Return a user-friendly label for this resource."""
+ return f'{self.platform} {self.version} ({self.architecture}) [{self.tag}] @{self.provider}'
+
+ @property
+ def persist(self) -> bool:
+ """True if the resource is persistent, otherwise false."""
+ return True
+
+
+@dataclasses.dataclass(frozen=True)
+class CloudResource(Resource):
+ """Details needed to request cloud credentials from Ansible Core CI."""
+ platform: str
+
+ def as_tuple(self) -> tuple[str, str, str, str]:
+ """Return the resource as a tuple of platform, version, architecture and provider."""
+ return self.platform, '', '', self.platform
+
+ def get_label(self) -> str:
+ """Return a user-friendly label for this resource."""
+ return self.platform
+
+ @property
+ def persist(self) -> bool:
+ """True if the resource is persistent, otherwise false."""
+ return False
+
+
+class AnsibleCoreCI:
+ """Client for Ansible Core CI services."""
+ DEFAULT_ENDPOINT = 'https://ansible-core-ci.testing.ansible.com'
+
+ def __init__(
+ self,
+ args: EnvironmentConfig,
+ resource: Resource,
+ load: bool = True,
+ ) -> None:
+ self.args = args
+ self.resource = resource
+ self.platform, self.version, self.arch, self.provider = self.resource.as_tuple()
+ self.stage = args.remote_stage
+ self.client = HttpClient(args)
+ self.connection = None
+ self.instance_id = None
+ self.endpoint = None
+ self.default_endpoint = args.remote_endpoint or self.DEFAULT_ENDPOINT
+ self.retries = 3
+ self.ci_provider = get_ci_provider()
+ self.label = self.resource.get_label()
+
+ stripped_label = re.sub('[^A-Za-z0-9_.]+', '-', self.label).strip('-')
+
+ self.name = f"{stripped_label}-{self.stage}" # turn the label into something suitable for use as a filename
+
+ self.path = os.path.expanduser(f'~/.ansible/test/instances/{self.name}')
+ self.ssh_key = SshKey(args)
+
+ if self.resource.persist and load and self._load():
+ try:
+ display.info(f'Checking existing {self.label} instance using: {self._uri}', verbosity=1)
+
+ self.connection = self.get(always_raise_on=[404])
+
+ display.info(f'Loaded existing {self.label} instance.', verbosity=1)
+ except HttpError as ex:
+ if ex.status != 404:
+ raise
+
+ self._clear()
+
+ display.info(f'Cleared stale {self.label} instance.', verbosity=1)
+
+ self.instance_id = None
+ self.endpoint = None
+ elif not self.resource.persist:
+ self.instance_id = None
+ self.endpoint = None
+ self._clear()
+
+ if self.instance_id:
+ self.started: bool = True
+ else:
+ self.started = False
+ self.instance_id = str(uuid.uuid4())
+ self.endpoint = None
+
+ display.sensitive.add(self.instance_id)
+
+ if not self.endpoint:
+ self.endpoint = self.default_endpoint
+
+ @property
+ def available(self) -> bool:
+ """Return True if Ansible Core CI is supported."""
+ return self.ci_provider.supports_core_ci_auth()
+
+ def start(self) -> t.Optional[dict[str, t.Any]]:
+ """Start instance."""
+ if self.started:
+ display.info(f'Skipping started {self.label} instance.', verbosity=1)
+ return None
+
+ return self._start(self.ci_provider.prepare_core_ci_auth())
+
+ def stop(self) -> None:
+ """Stop instance."""
+ if not self.started:
+ display.info(f'Skipping invalid {self.label} instance.', verbosity=1)
+ return
+
+ response = self.client.delete(self._uri)
+
+ if response.status_code == 404:
+ self._clear()
+ display.info(f'Cleared invalid {self.label} instance.', verbosity=1)
+ return
+
+ if response.status_code == 200:
+ self._clear()
+ display.info(f'Stopped running {self.label} instance.', verbosity=1)
+ return
+
+ raise self._create_http_error(response)
+
+ def get(self, tries: int = 3, sleep: int = 15, always_raise_on: t.Optional[list[int]] = None) -> t.Optional[InstanceConnection]:
+ """Get instance connection information."""
+ if not self.started:
+ display.info(f'Skipping invalid {self.label} instance.', verbosity=1)
+ return None
+
+ if not always_raise_on:
+ always_raise_on = []
+
+ if self.connection and self.connection.running:
+ return self.connection
+
+ while True:
+ tries -= 1
+ response = self.client.get(self._uri)
+
+ if response.status_code == 200:
+ break
+
+ error = self._create_http_error(response)
+
+ if not tries or response.status_code in always_raise_on:
+ raise error
+
+ display.warning(f'{error}. Trying again after {sleep} seconds.')
+ time.sleep(sleep)
+
+ if self.args.explain:
+ self.connection = InstanceConnection(
+ running=True,
+ hostname='cloud.example.com',
+ port=12345,
+ username='root',
+ password='password' if self.platform == 'windows' else None,
+ )
+ else:
+ response_json = response.json()
+ status = response_json['status']
+ con = response_json.get('connection')
+
+ if con:
+ self.connection = InstanceConnection(
+ running=status == 'running',
+ hostname=con['hostname'],
+ port=int(con['port']),
+ username=con['username'],
+ password=con.get('password'),
+ response_json=response_json,
+ )
+ else:
+ self.connection = InstanceConnection(
+ running=status == 'running',
+ response_json=response_json,
+ )
+
+ if self.connection.password:
+ display.sensitive.add(str(self.connection.password))
+
+ status = 'running' if self.connection.running else 'starting'
+
+ display.info(f'The {self.label} instance is {status}.', verbosity=1)
+
+ return self.connection
+
+ def wait(self, iterations: t.Optional[int] = 90) -> None:
+ """Wait for the instance to become ready."""
+ for _iteration in range(1, iterations):
+ if self.get().running:
+ return
+ time.sleep(10)
+
+ raise ApplicationError(f'Timeout waiting for {self.label} instance.')
+
+ @property
+ def _uri(self) -> str:
+ return f'{self.endpoint}/{self.stage}/{self.provider}/{self.instance_id}'
+
+ def _start(self, auth) -> dict[str, t.Any]:
+ """Start instance."""
+ display.info(f'Initializing new {self.label} instance using: {self._uri}', verbosity=1)
+
+ if self.platform == 'windows':
+ winrm_config = read_text_file(os.path.join(ANSIBLE_TEST_TARGET_ROOT, 'setup', 'ConfigureRemotingForAnsible.ps1'))
+ else:
+ winrm_config = None
+
+ data = dict(
+ config=dict(
+ platform=self.platform,
+ version=self.version,
+ architecture=self.arch,
+ public_key=self.ssh_key.pub_contents,
+ winrm_config=winrm_config,
+ )
+ )
+
+ data.update(dict(auth=auth))
+
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+
+ response = self._start_endpoint(data, headers)
+
+ self.started = True
+ self._save()
+
+ display.info(f'Started {self.label} instance.', verbosity=1)
+
+ if self.args.explain:
+ return {}
+
+ return response.json()
+
+ def _start_endpoint(self, data: dict[str, t.Any], headers: dict[str, str]) -> HttpResponse:
+ tries = self.retries
+ sleep = 15
+
+ while True:
+ tries -= 1
+ response = self.client.put(self._uri, data=json.dumps(data), headers=headers)
+
+ if response.status_code == 200:
+ return response
+
+ error = self._create_http_error(response)
+
+ if response.status_code == 503:
+ raise error
+
+ if not tries:
+ raise error
+
+ display.warning(f'{error}. Trying again after {sleep} seconds.')
+ time.sleep(sleep)
+
+ def _clear(self) -> None:
+ """Clear instance information."""
+ try:
+ self.connection = None
+ os.remove(self.path)
+ except FileNotFoundError:
+ pass
+
+ def _load(self) -> bool:
+ """Load instance information."""
+ try:
+ data = read_text_file(self.path)
+ except FileNotFoundError:
+ return False
+
+ if not data.startswith('{'):
+ return False # legacy format
+
+ config = json.loads(data)
+
+ return self.load(config)
+
+ def load(self, config: dict[str, str]) -> bool:
+ """Load the instance from the provided dictionary."""
+ self.instance_id = str(config['instance_id'])
+ self.endpoint = config['endpoint']
+ self.started = True
+
+ display.sensitive.add(self.instance_id)
+
+ return True
+
+ def _save(self) -> None:
+ """Save instance information."""
+ if self.args.explain:
+ return
+
+ config = self.save()
+
+ write_json_file(self.path, config, create_directories=True)
+
+ def save(self) -> dict[str, str]:
+ """Save instance details and return as a dictionary."""
+ return dict(
+ label=self.resource.get_label(),
+ instance_id=self.instance_id,
+ endpoint=self.endpoint,
+ )
+
+ @staticmethod
+ def _create_http_error(response: HttpResponse) -> ApplicationError:
+ """Return an exception created from the given HTTP response."""
+ response_json = response.json()
+ stack_trace = ''
+
+ if 'message' in response_json:
+ message = response_json['message']
+ elif 'errorMessage' in response_json:
+ message = response_json['errorMessage'].strip()
+ if 'stackTrace' in response_json:
+ traceback_lines = response_json['stackTrace']
+
+ # AWS Lambda on Python 2.7 returns a list of tuples
+ # AWS Lambda on Python 3.7 returns a list of strings
+ if traceback_lines and isinstance(traceback_lines[0], list):
+ traceback_lines = traceback.format_list(traceback_lines)
+
+ trace = '\n'.join([x.rstrip() for x in traceback_lines])
+ stack_trace = f'\nTraceback (from remote server):\n{trace}'
+ else:
+ message = str(response_json)
+
+ return CoreHttpError(response.status_code, message, stack_trace)
+
+
+class CoreHttpError(HttpError):
+ """HTTP response as an error."""
+ def __init__(self, status: int, remote_message: str, remote_stack_trace: str) -> None:
+ super().__init__(status, f'{remote_message}{remote_stack_trace}')
+
+ self.remote_message = remote_message
+ self.remote_stack_trace = remote_stack_trace
+
+
+class SshKey:
+ """Container for SSH key used to connect to remote instances."""
+ KEY_TYPE = 'rsa' # RSA is used to maintain compatibility with paramiko and EC2
+ KEY_NAME = f'id_{KEY_TYPE}'
+ PUB_NAME = f'{KEY_NAME}.pub'
+
+ @mutex
+ def __init__(self, args: EnvironmentConfig) -> None:
+ key_pair = self.get_key_pair()
+
+ if not key_pair:
+ key_pair = self.generate_key_pair(args)
+
+ key, pub = key_pair
+ key_dst, pub_dst = self.get_in_tree_key_pair_paths()
+
+ def ssh_key_callback(files: list[tuple[str, str]]) -> None:
+ """
+ Add the SSH keys to the payload file list.
+ They are either outside the source tree or in the cache dir which is ignored by default.
+ """
+ files.append((key, os.path.relpath(key_dst, data_context().content.root)))
+ files.append((pub, os.path.relpath(pub_dst, data_context().content.root)))
+
+ data_context().register_payload_callback(ssh_key_callback)
+
+ self.key, self.pub = key, pub
+
+ if args.explain:
+ self.pub_contents = None
+ self.key_contents = None
+ else:
+ self.pub_contents = read_text_file(self.pub).strip()
+ self.key_contents = read_text_file(self.key).strip()
+
+ @staticmethod
+ def get_relative_in_tree_private_key_path() -> str:
+ """Return the ansible-test SSH private key path relative to the content tree."""
+ temp_dir = ResultType.TMP.relative_path
+
+ key = os.path.join(temp_dir, SshKey.KEY_NAME)
+
+ return key
+
+ def get_in_tree_key_pair_paths(self) -> t.Optional[tuple[str, str]]:
+ """Return the ansible-test SSH key pair paths from the content tree."""
+ temp_dir = ResultType.TMP.path
+
+ key = os.path.join(temp_dir, self.KEY_NAME)
+ pub = os.path.join(temp_dir, self.PUB_NAME)
+
+ return key, pub
+
+ def get_source_key_pair_paths(self) -> t.Optional[tuple[str, str]]:
+ """Return the ansible-test SSH key pair paths for the current user."""
+ base_dir = os.path.expanduser('~/.ansible/test/')
+
+ key = os.path.join(base_dir, self.KEY_NAME)
+ pub = os.path.join(base_dir, self.PUB_NAME)
+
+ return key, pub
+
+ def get_key_pair(self) -> t.Optional[tuple[str, str]]:
+ """Return the ansible-test SSH key pair paths if present, otherwise return None."""
+ key, pub = self.get_in_tree_key_pair_paths()
+
+ if os.path.isfile(key) and os.path.isfile(pub):
+ return key, pub
+
+ key, pub = self.get_source_key_pair_paths()
+
+ if os.path.isfile(key) and os.path.isfile(pub):
+ return key, pub
+
+ return None
+
+ def generate_key_pair(self, args: EnvironmentConfig) -> tuple[str, str]:
+ """Generate an SSH key pair for use by all ansible-test invocations for the current user."""
+ key, pub = self.get_source_key_pair_paths()
+
+ if not args.explain:
+ make_dirs(os.path.dirname(key))
+
+ if not os.path.isfile(key) or not os.path.isfile(pub):
+ run_command(args, ['ssh-keygen', '-m', 'PEM', '-q', '-t', self.KEY_TYPE, '-N', '', '-f', key], capture=True)
+
+ if args.explain:
+ return key, pub
+
+ # newer ssh-keygen PEM output (such as on RHEL 8.1) is not recognized by paramiko
+ key_contents = read_text_file(key)
+ key_contents = re.sub(r'(BEGIN|END) PRIVATE KEY', r'\1 RSA PRIVATE KEY', key_contents)
+
+ write_text_file(key, key_contents)
+
+ return key, pub
+
+
+class InstanceConnection:
+ """Container for remote instance status and connection details."""
+ def __init__(self,
+ running: bool,
+ hostname: t.Optional[str] = None,
+ port: t.Optional[int] = None,
+ username: t.Optional[str] = None,
+ password: t.Optional[str] = None,
+ response_json: t.Optional[dict[str, t.Any]] = None,
+ ) -> None:
+ self.running = running
+ self.hostname = hostname
+ self.port = port
+ self.username = username
+ self.password = password
+ self.response_json = response_json or {}
+
+ def __str__(self):
+ if self.password:
+ return f'{self.hostname}:{self.port} [{self.username}:{self.password}]'
+
+ return f'{self.hostname}:{self.port} [{self.username}]'