summaryrefslogtreecommitdiffstats
path: root/test/lib/ansible_test/_internal/connections.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/lib/ansible_test/_internal/connections.py')
-rw-r--r--test/lib/ansible_test/_internal/connections.py258
1 files changed, 258 insertions, 0 deletions
diff --git a/test/lib/ansible_test/_internal/connections.py b/test/lib/ansible_test/_internal/connections.py
new file mode 100644
index 0000000..4823b1a
--- /dev/null
+++ b/test/lib/ansible_test/_internal/connections.py
@@ -0,0 +1,258 @@
+"""Connection abstraction for interacting with test hosts."""
+from __future__ import annotations
+
+import abc
+import shlex
+import tempfile
+import typing as t
+
+from .io import (
+ read_text_file,
+)
+
+from .config import (
+ EnvironmentConfig,
+)
+
+from .util import (
+ Display,
+ OutputStream,
+ SubprocessError,
+ retry,
+)
+
+from .util_common import (
+ run_command,
+)
+
+from .docker_util import (
+ DockerInspect,
+ docker_exec,
+ docker_inspect,
+ docker_network_disconnect,
+)
+
+from .ssh import (
+ SshConnectionDetail,
+ ssh_options_to_list,
+)
+
+from .become import (
+ Become,
+)
+
+
+class Connection(metaclass=abc.ABCMeta):
+ """Base class for connecting to a host."""
+ @abc.abstractmethod
+ def run(self,
+ command: list[str],
+ capture: bool,
+ interactive: bool = False,
+ data: t.Optional[str] = None,
+ stdin: t.Optional[t.IO[bytes]] = None,
+ stdout: t.Optional[t.IO[bytes]] = None,
+ output_stream: t.Optional[OutputStream] = None,
+ ) -> tuple[t.Optional[str], t.Optional[str]]:
+ """Run the specified command and return the result."""
+
+ def extract_archive(self,
+ chdir: str,
+ src: t.IO[bytes],
+ ):
+ """Extract the given archive file stream in the specified directory."""
+ tar_cmd = ['tar', 'oxzf', '-', '-C', chdir]
+
+ retry(lambda: self.run(tar_cmd, stdin=src, capture=True))
+
+ def create_archive(self,
+ chdir: str,
+ name: str,
+ dst: t.IO[bytes],
+ exclude: t.Optional[str] = None,
+ ):
+ """Create the specified archive file stream from the specified directory, including the given name and optionally excluding the given name."""
+ tar_cmd = ['tar', 'cf', '-', '-C', chdir]
+ gzip_cmd = ['gzip']
+
+ if exclude:
+ tar_cmd += ['--exclude', exclude]
+
+ tar_cmd.append(name)
+
+ # Using gzip to compress the archive allows this to work on all POSIX systems we support.
+ commands = [tar_cmd, gzip_cmd]
+
+ sh_cmd = ['sh', '-c', ' | '.join(shlex.join(command) for command in commands)]
+
+ retry(lambda: self.run(sh_cmd, stdout=dst, capture=True))
+
+
+class LocalConnection(Connection):
+ """Connect to localhost."""
+ def __init__(self, args: EnvironmentConfig) -> None:
+ self.args = args
+
+ def run(self,
+ command: list[str],
+ capture: bool,
+ interactive: bool = False,
+ data: t.Optional[str] = None,
+ stdin: t.Optional[t.IO[bytes]] = None,
+ stdout: t.Optional[t.IO[bytes]] = None,
+ output_stream: t.Optional[OutputStream] = None,
+ ) -> tuple[t.Optional[str], t.Optional[str]]:
+ """Run the specified command and return the result."""
+ return run_command(
+ args=self.args,
+ cmd=command,
+ capture=capture,
+ data=data,
+ stdin=stdin,
+ stdout=stdout,
+ interactive=interactive,
+ output_stream=output_stream,
+ )
+
+
+class SshConnection(Connection):
+ """Connect to a host using SSH."""
+ def __init__(self, args: EnvironmentConfig, settings: SshConnectionDetail, become: t.Optional[Become] = None) -> None:
+ self.args = args
+ self.settings = settings
+ self.become = become
+
+ self.options = ['-i', settings.identity_file]
+
+ ssh_options: dict[str, t.Union[int, str]] = dict(
+ BatchMode='yes',
+ StrictHostKeyChecking='no',
+ UserKnownHostsFile='/dev/null',
+ ServerAliveInterval=15,
+ ServerAliveCountMax=4,
+ )
+
+ ssh_options.update(settings.options)
+
+ self.options.extend(ssh_options_to_list(ssh_options))
+
+ def run(self,
+ command: list[str],
+ capture: bool,
+ interactive: bool = False,
+ data: t.Optional[str] = None,
+ stdin: t.Optional[t.IO[bytes]] = None,
+ stdout: t.Optional[t.IO[bytes]] = None,
+ output_stream: t.Optional[OutputStream] = None,
+ ) -> tuple[t.Optional[str], t.Optional[str]]:
+ """Run the specified command and return the result."""
+ options = list(self.options)
+
+ if self.become:
+ command = self.become.prepare_command(command)
+
+ options.append('-q')
+
+ if interactive:
+ options.append('-tt')
+
+ with tempfile.NamedTemporaryFile(prefix='ansible-test-ssh-debug-', suffix='.log') as ssh_logfile:
+ options.extend(['-vvv', '-E', ssh_logfile.name])
+
+ if self.settings.port:
+ options.extend(['-p', str(self.settings.port)])
+
+ options.append(f'{self.settings.user}@{self.settings.host}')
+ options.append(shlex.join(command))
+
+ def error_callback(ex: SubprocessError) -> None:
+ """Error handler."""
+ self.capture_log_details(ssh_logfile.name, ex)
+
+ return run_command(
+ args=self.args,
+ cmd=['ssh'] + options,
+ capture=capture,
+ data=data,
+ stdin=stdin,
+ stdout=stdout,
+ interactive=interactive,
+ output_stream=output_stream,
+ error_callback=error_callback,
+ )
+
+ @staticmethod
+ def capture_log_details(path: str, ex: SubprocessError) -> None:
+ """Read the specified SSH debug log and add relevant details to the provided exception."""
+ if ex.status != 255:
+ return
+
+ markers = [
+ 'debug1: Connection Established',
+ 'debug1: Authentication successful',
+ 'debug1: Entering interactive session',
+ 'debug1: Sending command',
+ 'debug2: PTY allocation request accepted',
+ 'debug2: exec request accepted',
+ ]
+
+ file_contents = read_text_file(path)
+ messages = []
+
+ for line in reversed(file_contents.splitlines()):
+ messages.append(line)
+
+ if any(line.startswith(marker) for marker in markers):
+ break
+
+ message = '\n'.join(reversed(messages))
+
+ ex.message += '>>> SSH Debug Output\n'
+ ex.message += '%s%s\n' % (message.strip(), Display.clear)
+
+
+class DockerConnection(Connection):
+ """Connect to a host using Docker."""
+ def __init__(self, args: EnvironmentConfig, container_id: str, user: t.Optional[str] = None) -> None:
+ self.args = args
+ self.container_id = container_id
+ self.user: t.Optional[str] = user
+
+ def run(self,
+ command: list[str],
+ capture: bool,
+ interactive: bool = False,
+ data: t.Optional[str] = None,
+ stdin: t.Optional[t.IO[bytes]] = None,
+ stdout: t.Optional[t.IO[bytes]] = None,
+ output_stream: t.Optional[OutputStream] = None,
+ ) -> tuple[t.Optional[str], t.Optional[str]]:
+ """Run the specified command and return the result."""
+ options = []
+
+ if self.user:
+ options.extend(['--user', self.user])
+
+ if interactive:
+ options.append('-it')
+
+ return docker_exec(
+ args=self.args,
+ container_id=self.container_id,
+ cmd=command,
+ options=options,
+ capture=capture,
+ data=data,
+ stdin=stdin,
+ stdout=stdout,
+ interactive=interactive,
+ output_stream=output_stream,
+ )
+
+ def inspect(self) -> DockerInspect:
+ """Inspect the container and return a DockerInspect instance with the results."""
+ return docker_inspect(self.args, self.container_id)
+
+ def disconnect_network(self, network: str) -> None:
+ """Disconnect the container from the specified network."""
+ docker_network_disconnect(self.args, self.container_id, network)