summaryrefslogtreecommitdiffstats
path: root/crmsh/user_of_host.py
blob: 4551dce9787d8f1c41213f531e2502bb150f8767 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import logging
import socket
import subprocess
import typing

from . import config
from . import constants
from . import userdir
from .pyshim import cache


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class UserNotFoundError(ValueError):
    """Raised when no user can be determined for a host."""


class UserOfHost:
    """Work out which user account to use locally and on a given host.

    Users are looked up in the ``hosts`` option of the ``core`` configuration
    section, whose items are either ``user@node`` or a bare ``node``.
    Successful lookups are memoized on the instance.
    """

    @staticmethod
    def instance():
        """Return the module-level ``_user_of_host_instance`` object."""
        return _user_of_host_instance

    @staticmethod
    @cache
    def this_node():
        """Return ``socket.gethostname()``; the result is cached by ``cache``."""
        return socket.gethostname()

    def __init__(self):
        # host -> user found in the configuration
        self._user_cache = dict()
        # host -> (local_user, remote_user) chosen by the fallback logic
        self._user_pair_cache = dict()

    def user_of(self, host):
        """Return the user configured for ``host``.

        Raises:
            UserNotFoundError: the configuration has no entry matching
                ``host``. Failed lookups are not cached.
        """
        cached = self._user_cache.get(host)
        if cached is not None:
            return cached
        user = self._get_user_of_host_from_config(host)
        if user is None:
            # Name the host so the failure can be diagnosed from the message.
            raise UserNotFoundError('No user found for host {}'.format(host))
        self._user_cache[host] = user
        return user

    def user_pair_for_ssh(self, host: str) -> typing.Tuple[str, str]:
        """Return (local_user, remote_user) pair for ssh connection.

        The configured users are preferred. When the configuration lookup
        fails, a previously cached fallback pair is reused; otherwise the
        local user is used on both sides if it is known, else the pair is
        guessed by probing ``host`` over ssh.

        Raises:
            UserNotFoundError: neither the configuration nor the ssh probe
                yields a usable pair.
        """
        local_user = None
        try:
            # 'root' is used locally whenever use_ssh_agent() is truthy.
            local_user = 'root' if self.use_ssh_agent() else self.user_of(self.this_node())
            return local_user, self.user_of(host)
        except UserNotFoundError:
            cached = self._user_pair_cache.get(host)
            if cached is not None:
                return cached
            if local_user is not None:
                # Only the remote lookup failed: reuse the local user remotely.
                pair = (local_user, local_user)
            else:
                pair = self._guess_user_for_ssh(host)
                if pair is None:
                    raise UserNotFoundError('No user found for host {}'.format(host))
            self._user_pair_cache[host] = pair
            return pair

    @staticmethod
    def use_ssh_agent() -> bool:
        """Return the value of the ``core``/``no_generating_ssh_key`` option."""
        return config.get_option('core', 'no_generating_ssh_key')

    @staticmethod
    def _split_host_entry(item: str) -> typing.Tuple[typing.Optional[str], str]:
        """Split a ``hosts`` item into ``(user, node)``.

        The split happens at the last ``@`` so that user names which
        themselves contain ``@`` (``user@domain@node``) do not raise.
        ``user`` is ``None`` when the item contains no ``@``.
        """
        user, sep, node = item.rpartition('@')
        return (user if sep else None), node

    @staticmethod
    def _get_user_of_host_from_config(host):
        """Return the user configured for ``host``, or ``None`` if not listed.

        ``host`` matches an entry when the entry's node equals ``host``, its
        canonical name, or one of its aliases as reported by
        ``socket.gethostbyaddr``. Entries without an explicit user map to
        ``userdir.getuser()``.
        """
        hosts = config.get_option('core', 'hosts')
        if hosts == ['']:
            # Nothing is configured; skip the name-service lookup entirely.
            return None
        try:
            canonical, aliases, _ = socket.gethostbyaddr(host)
            aliases = set(aliases)
            aliases.add(canonical)
            aliases.add(host)
        except (socket.herror, socket.gaierror):
            # Resolution failed: fall back to matching the literal name only.
            aliases = {host}
        for item in hosts:
            user, node = UserOfHost._split_host_entry(item)
            if node in aliases:
                return userdir.getuser() if user is None else user
        logger.debug('Failed to get the user of host %s (aliases: %s). Known hosts are %s', host, aliases, hosts)
        return None

    @staticmethod
    def _guess_user_for_ssh(host: str) -> typing.Optional[typing.Tuple[str, str]]:
        """Probe ``host`` over ssh and return ``(user, user)`` on success.

        Runs ``true`` on ``host`` with ``BatchMode=yes`` (through ``sudo``
        when ``userdir.get_sudoer()`` is truthy), discarding all output.
        Returns ``None`` when the command exits with a non-zero status.
        """
        args = ['ssh']
        args.extend(constants.SSH_OPTION_ARGS)
        args.extend(['-o', 'BatchMode=yes', host])
        if userdir.get_sudoer():
            args.append('sudo')
        args.append('true')
        rc = subprocess.call(
            args,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        if rc != 0:
            return None
        user = userdir.getuser()
        return user, user


# Shared instance created at import time; returned by instance() and
# UserOfHost.instance().
_user_of_host_instance = UserOfHost()


def instance():
    """Return the shared module-level ``UserOfHost`` object."""
    shared = _user_of_host_instance
    return shared