import logging
import os
import asyncio
import concurrent.futures
from tempfile import NamedTemporaryFile
from threading import Thread
from contextlib import contextmanager
from io import StringIO
from shlex import quote
from typing import TYPE_CHECKING, Optional, List, Tuple, Dict, Iterator, TypeVar, Awaitable, Union
from orchestrator import OrchestratorError

try:
    import asyncssh
except ImportError:
    asyncssh = None  # type: ignore

if TYPE_CHECKING:
    from cephadm.module import CephadmOrchestrator
    from asyncssh.connection import SSHClientConnection

T = TypeVar('T')


logger = logging.getLogger(__name__)

# asyncssh log records are captured explicitly by SSHManager.redirect_log,
# so they are kept from propagating to the root logger.
asyncssh_logger = logging.getLogger('asyncssh')
asyncssh_logger.propagate = False


class HostConnectionError(OrchestratorError):
    """OrchestratorError that also records which host/address failed."""

    def __init__(self, message: str, hostname: str, addr: str) -> None:
        super().__init__(message)
        self.hostname = hostname
        self.addr = addr


# ssh_config used by _reconfig_ssh when no "ssh_config" value is stored.
DEFAULT_SSH_CONFIG = """
Host *
  User root
  StrictHostKeyChecking no
  UserKnownHostsFile /dev/null
  ConnectTimeout=30
"""


class EventLoopThread(Thread):
    """Thread that runs a private asyncio event loop until stopped.

    The thread is started from the constructor; coroutines are handed to
    the loop from other threads through get_result().
    """

    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        # Note: this runs in the constructing thread, so the new loop is
        # registered as that thread's current loop.
        asyncio.set_event_loop(self._loop)

        super().__init__(target=self._loop.run_forever)
        self.start()

    def get_result(self, coro: Awaitable[T], timeout: Optional[int] = None) -> T:
        """Run *coro* on the loop thread and block until it finishes.

        :param coro: awaitable to schedule on the event loop
        :param timeout: seconds to wait for the result, or None to wait
            without a limit
        :return: whatever *coro* returns
        :raises: the timeout error raised by Future.result() when *timeout*
            expires (after a cancellation attempt), or any exception
            raised by *coro* itself
        """
        # useful to note: This "run_coroutine_threadsafe" returns a
        # concurrent.futures.Future, rather than an asyncio.Future. They are
        # fairly similar but have a few differences, notably in our case
        # that the result function of a concurrent.futures.Future accepts
        # a timeout argument
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        try:
            return future.result(timeout)
        except (asyncio.TimeoutError, concurrent.futures.TimeoutError):
            # Future.result() raises concurrent.futures.TimeoutError, which
            # is a different class from asyncio.TimeoutError before
            # Python 3.11; catch both so the cancellation always happens.
            # try to cancel the task before raising the exception further up
            future.cancel()
            raise


class SSHManager:
    """Keeps one SSH connection per host and runs remote operations on it.

    Each operation has an ``async`` implementation (leading underscore) and
    a blocking wrapper that drives it through ``mgr.wait_async``.
    """

    def __init__(self, mgr: "CephadmOrchestrator"):
        self.mgr: "CephadmOrchestrator" = mgr
        # host name -> open connection
        self.cons: Dict[str, "SSHClientConnection"] = {}

    async def _remote_connection(self,
                                 host: str,
                                 addr: Optional[str] = None,
                                 ) -> "SSHClientConnection":
        """Return the cached connection for *host*, opening one if needed.

        A new connection is opened when none is cached for *host* or when
        *host* is not in the inventory.

        :param host: host name used as the cache key
        :param addr: address to connect to; looked up in the inventory
            when omitted and the host is known there
        :raises OrchestratorError: if no address can be determined
        :raises HostConnectionError: if connecting fails (raised by
            redirect_log)
        """
        if not self.cons.get(host) or host not in self.mgr.inventory:
            if not addr and host in self.mgr.inventory:
                addr = self.mgr.inventory.get_addr(host)

            if not addr:
                raise OrchestratorError("host address is empty")

            assert self.mgr.ssh_user
            n = self.mgr.ssh_user + '@' + addr
            logger.debug("Opening connection to {} with ssh options '{}'".format(
                n, self.mgr._ssh_options))

            asyncssh.set_log_level('DEBUG')
            asyncssh.set_debug_level(3)

            # redirect_log turns any failure raised inside the block into a
            # HostConnectionError and marks the host offline.
            with self.redirect_log(host, addr):
                ssh_options = asyncssh.SSHClientConnectionOptions(
                    keepalive_interval=7, keepalive_count_max=3)
                conn = await asyncssh.connect(addr, username=self.mgr.ssh_user,
                                              client_keys=[self.mgr.tkey.name],
                                              known_hosts=None,
                                              config=[self.mgr.ssh_config_fname],
                                              preferred_auth=['publickey'],
                                              options=ssh_options)
            self.cons[host] = conn

        self.mgr.offline_hosts_remove(host)

        return self.cons[host]

    @contextmanager
    def redirect_log(self, host: str, addr: str) -> Iterator[None]:
        """Capture asyncssh log output while the body runs.

        Records of level INFO and above from the 'asyncssh' logger are
        collected in memory. Any exception escaping the body marks *host*
        offline and is re-raised as HostConnectionError; for non-OSError
        failures the captured log is appended to the message.
        """
        log_string = StringIO()
        ch = logging.StreamHandler(log_string)
        ch.setLevel(logging.INFO)
        asyncssh_logger.addHandler(ch)

        try:
            yield
        except OSError as e:
            self.mgr.offline_hosts.add(host)
            msg = f"Can't communicate with remote host `{addr}`, possibly because the host is not reachable or python3 is not installed on the host. {str(e)}"
            logger.exception(msg)
            raise HostConnectionError(msg, host, addr)
        except asyncssh.Error as e:
            self.mgr.offline_hosts.add(host)
            log_content = log_string.getvalue()
            msg = f'Failed to connect to {host} ({addr}). {str(e)}' + '\n' + f'Log: {log_content}'
            logger.debug(msg)
            raise HostConnectionError(msg, host, addr)
        except Exception as e:
            self.mgr.offline_hosts.add(host)
            log_content = log_string.getvalue()
            logger.exception(str(e))
            raise HostConnectionError(
                f'Failed to connect to {host} ({addr}): {repr(e)}' + '\n' f'Log: {log_content}', host, addr)
        finally:
            log_string.flush()
            asyncssh_logger.removeHandler(ch)

    def remote_connection(self,
                          host: str,
                          addr: Optional[str] = None,
                          ) -> "SSHClientConnection":
        """Blocking wrapper around _remote_connection()."""
        with self.mgr.async_timeout_handler(host, f'ssh {host} (addr {addr})'):
            return self.mgr.wait_async(self._remote_connection(host, addr))

    async def _execute_command(self,
                               host: str,
                               cmd_components: List[str],
                               stdin: Optional[str] = None,
                               addr: Optional[str] = None,
                               log_command: Optional[bool] = True,
                               ) -> Tuple[str, str, int]:
        """Run a command on *host* and return ``(stdout, stderr, returncode)``.

        Every component is shell-quoted and the command is prefixed with
        ``sudo`` when the ssh user is not root. A short ``true`` command is
        run first as a reachability check.

        :param host: host to run on
        :param cmd_components: argv-style list of command words
        :param stdin: optional text passed as the command's input
        :param addr: optional explicit address for the connection
        :param log_command: log the command line at debug level when true
        :return: stdout and stderr with trailing newlines stripped, and the
            return code (0 when the remote side reports none)
        :raises HostConnectionError: if running the command raises; the
            cached connection is dropped and the host is marked offline
        :raises OrchestratorError: if the output has an unexpected type
        """
        conn = await self._remote_connection(host, addr)
        sudo_prefix = "sudo " if self.mgr.ssh_user != 'root' else ""
        cmd = sudo_prefix + " ".join(quote(x) for x in cmd_components)
        try:
            address = addr or self.mgr.inventory.get_addr(host)
        except Exception:
            # the address is only used for error reporting; fall back to
            # the host name if the inventory lookup fails
            address = host
        if log_command:
            logger.debug(f'Running command: {cmd}')
        try:
            r = await conn.run(f'{sudo_prefix}true', check=True, timeout=5)  # host quick check
            r = await conn.run(cmd, input=stdin)
        # handle these Exceptions otherwise you might get a weird error like
        # TypeError: __init__() missing 1 required positional argument: 'reason' (due to the asyncssh error interacting with raise_if_exception)
        except asyncssh.ChannelOpenError as e:
            # SSH connection closed or broken, will create new connection next call
            logger.debug(f'Connection to {host} failed. {str(e)}')
            await self._reset_con(host)
            self.mgr.offline_hosts.add(host)
            raise HostConnectionError(f'Unable to reach remote host {host}. {str(e)}', host, address)
        except asyncssh.ProcessError as e:
            msg = f"Cannot execute the command '{cmd}' on the {host}. {str(e.stderr)}."
            logger.debug(msg)
            await self._reset_con(host)
            self.mgr.offline_hosts.add(host)
            raise HostConnectionError(msg, host, address)
        except Exception as e:
            msg = f"Generic error while executing command '{cmd}' on the host {host}. {str(e)}."
            logger.debug(msg)
            await self._reset_con(host)
            self.mgr.offline_hosts.add(host)
            raise HostConnectionError(msg, host, address)

        def _rstrip(v: Union[bytes, str, None]) -> str:
            # Normalise process output to str without trailing newlines.
            if not v:
                return ''
            if isinstance(v, str):
                return v.rstrip('\n')
            if isinstance(v, bytes):
                return v.decode().rstrip('\n')
            raise OrchestratorError(
                f'Unable to parse ssh output with type {type(v)} from remote host {host}')

        out = _rstrip(r.stdout)
        err = _rstrip(r.stderr)
        rc = r.returncode or 0

        return out, err, rc

    def execute_command(self,
                        host: str,
                        cmd: List[str],
                        stdin: Optional[str] = None,
                        addr: Optional[str] = None,
                        log_command: Optional[bool] = True
                        ) -> Tuple[str, str, int]:
        """Blocking wrapper around _execute_command()."""
        with self.mgr.async_timeout_handler(host, " ".join(cmd)):
            return self.mgr.wait_async(self._execute_command(host, cmd, stdin, addr, log_command))

    async def _check_execute_command(self,
                                     host: str,
                                     cmd: List[str],
                                     stdin: Optional[str] = None,
                                     addr: Optional[str] = None,
                                     log_command: Optional[bool] = True
                                     ) -> str:
        """Run a command and return its stdout, failing on non-zero exit.

        :raises OrchestratorError: if the command's return code is not 0;
            the message contains the command and its stderr
        """
        out, err, code = await self._execute_command(host, cmd, stdin, addr, log_command)
        if code != 0:
            msg = f'Command {cmd} failed. {err}'
            logger.debug(msg)
            raise OrchestratorError(msg)
        return out

    def check_execute_command(self,
                              host: str,
                              cmd: List[str],
                              stdin: Optional[str] = None,
                              addr: Optional[str] = None,
                              log_command: Optional[bool] = True,
                              ) -> str:
        """Blocking wrapper around _check_execute_command()."""
        with self.mgr.async_timeout_handler(host, " ".join(cmd)):
            return self.mgr.wait_async(self._check_execute_command(host, cmd, stdin, addr, log_command))

    async def _write_remote_file(self,
                                 host: str,
                                 path: str,
                                 content: bytes,
                                 mode: Optional[int] = None,
                                 uid: Optional[int] = None,
                                 gid: Optional[int] = None,
                                 addr: Optional[str] = None,
                                 ) -> None:
        """Write *content* to *path* on *host*.

        The data is uploaded over SFTP to a staging file
        ``/tmp/cephadm-<fsid><path>.new`` and then moved into place with
        ``mv``. Ownership and permissions are applied to the staging file
        only when *uid*, *gid* and *mode* are all given.

        :raises OrchestratorError: wrapping any failure along the way
        """
        try:
            cephadm_tmp_dir = f"/tmp/cephadm-{self.mgr._cluster_fsid}"
            dirname = os.path.dirname(path)
            await self._check_execute_command(host, ['mkdir', '-p', dirname], addr=addr)
            await self._check_execute_command(host, ['mkdir', '-p', cephadm_tmp_dir + dirname], addr=addr)
            tmp_path = cephadm_tmp_dir + path + '.new'
            await self._check_execute_command(host, ['touch', tmp_path], addr=addr)
            if self.mgr.ssh_user != 'root':
                # the staging area is created via sudo; hand it to the ssh
                # user so the SFTP upload (which runs as that user) can write
                assert self.mgr.ssh_user
                await self._check_execute_command(host, ['chown', '-R', self.mgr.ssh_user, cephadm_tmp_dir], addr=addr)
                await self._check_execute_command(host, ['chmod', str(644), tmp_path], addr=addr)
            with NamedTemporaryFile(prefix='cephadm-write-remote-file-') as f:
                os.fchmod(f.fileno(), 0o600)
                f.write(content)
                f.flush()
                conn = await self._remote_connection(host, addr)
                async with conn.start_sftp_client() as sftp:
                    await sftp.put(f.name, tmp_path)
            if uid is not None and gid is not None and mode is not None:
                # shlex quote takes str or byte object, not int
                await self._check_execute_command(host, ['chown', '-R', str(uid) + ':' + str(gid), tmp_path], addr=addr)
                await self._check_execute_command(host, ['chmod', oct(mode)[2:], tmp_path], addr=addr)
            await self._check_execute_command(host, ['mv', tmp_path, path], addr=addr)
        except Exception as e:
            msg = f"Unable to write {host}:{path}: {e}"
            logger.exception(msg)
            raise OrchestratorError(msg)

    def write_remote_file(self,
                          host: str,
                          path: str,
                          content: bytes,
                          mode: Optional[int] = None,
                          uid: Optional[int] = None,
                          gid: Optional[int] = None,
                          addr: Optional[str] = None,
                          ) -> None:
        """Blocking wrapper around _write_remote_file()."""
        with self.mgr.async_timeout_handler(host, f'writing file {path}'):
            self.mgr.wait_async(self._write_remote_file(
                host, path, content, mode, uid, gid, addr))

    async def _reset_con(self, host: str) -> None:
        """Close and forget the cached connection for *host*, if any."""
        conn = self.cons.get(host)
        if conn:
            logger.debug(f'_reset_con close {host}')
            conn.close()
            del self.cons[host]

    def reset_con(self, host: str) -> None:
        """Blocking wrapper around _reset_con()."""
        with self.mgr.async_timeout_handler(cmd=f'resetting ssh connection to {host}'):
            self.mgr.wait_async(self._reset_con(host))

    def _reset_cons(self) -> None:
        """Close every cached connection and empty the cache."""
        for host, conn in self.cons.items():
            logger.debug(f'_reset_cons close {host}')
            conn.close()
        self.cons = {}

    def _reconfig_ssh(self) -> None:
        """Rebuild the SSH config/identity temp files from the mgr store.

        Writes the ssh config and identity material to temporary files,
        records their names and the derived option string on the mgr, sets
        the ssh user according to ``mgr.mode``, and finally drops all
        cached connections so the new settings take effect.
        """
        temp_files = []  # type: list
        ssh_options = []  # type: List[str]

        # ssh_config
        self.mgr.ssh_config_fname = self.mgr.ssh_config_file
        ssh_config = self.mgr.get_store("ssh_config")
        if ssh_config is not None or self.mgr.ssh_config_fname is None:
            if not ssh_config:
                ssh_config = DEFAULT_SSH_CONFIG
            f = NamedTemporaryFile(prefix='cephadm-conf-')
            os.fchmod(f.fileno(), 0o600)
            f.write(ssh_config.encode('utf-8'))
            f.flush()  # make visible to other processes
            temp_files += [f]
            self.mgr.ssh_config_fname = f.name
        if self.mgr.ssh_config_fname:
            self.mgr.validate_ssh_config_fname(self.mgr.ssh_config_fname)
            ssh_options += ['-F', self.mgr.ssh_config_fname]
        self.mgr.ssh_config = ssh_config

        # identity
        ssh_key = self.mgr.get_store("ssh_identity_key")
        ssh_pub = self.mgr.get_store("ssh_identity_pub")
        ssh_cert = self.mgr.get_store("ssh_identity_cert")
        self.mgr.ssh_pub = ssh_pub
        self.mgr.ssh_key = ssh_key
        self.mgr.ssh_cert = ssh_cert
        if ssh_key:
            self.mgr.tkey = NamedTemporaryFile(prefix='cephadm-identity-')
            self.mgr.tkey.write(ssh_key.encode('utf-8'))
            os.fchmod(self.mgr.tkey.fileno(), 0o600)
            self.mgr.tkey.flush()  # make visible to other processes
            temp_files += [self.mgr.tkey]
            # the public key and certificate are written next to the
            # private key, as "<key>.pub" and "<key>-cert.pub"
            if ssh_pub:
                tpub = open(self.mgr.tkey.name + '.pub', 'w')
                os.fchmod(tpub.fileno(), 0o600)
                tpub.write(ssh_pub)
                tpub.flush()  # make visible to other processes
                temp_files += [tpub]
            if ssh_cert:
                tcert = open(self.mgr.tkey.name + '-cert.pub', 'w')
                os.fchmod(tcert.fileno(), 0o600)
                tcert.write(ssh_cert)
                tcert.flush()  # make visible to other processes
                temp_files += [tcert]
            ssh_options += ['-i', self.mgr.tkey.name]

        # keep the file objects referenced from the mgr; they are
        # deliberately left open here
        self.mgr._temp_files = temp_files
        if ssh_options:
            self.mgr._ssh_options = ' '.join(ssh_options)
        else:
            self.mgr._ssh_options = None

        if self.mgr.mode == 'root':
            self.mgr.ssh_user = self.mgr.get_store('ssh_user', default='root')
        elif self.mgr.mode == 'cephadm-package':
            self.mgr.ssh_user = 'cephadm'

        self._reset_cons()