summaryrefslogtreecommitdiffstats
path: root/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py
blob: a38a775b93f6f6842f8b82228c675121b9d44ae3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com>
# (c) 2017, Peter Sprygada <psprygad@redhat.com>
# (c) 2017 Ansible Project
from __future__ import absolute_import, division, print_function

__metaclass__ = type

import os

from ansible import constants as C
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.loader import connection_loader
from ansible.utils.display import Display
from ansible.utils.path import unfrackpath

# Module-level Display instance. Not referenced by the code visible in this
# file; NOTE(review): presumably kept for importers/subclasses — confirm.
display = Display()


# Public API of this module: only the connection base class is exported.
__all__ = ["NetworkConnectionBase"]

# Buffer size constant (65536 bytes). Not used in this file's visible code;
# NOTE(review): presumably consumed by subclasses or importers — confirm.
BUFSIZE = 65536


class NetworkConnectionBase(ConnectionBase):
    """
    A base class for network-style connections.

    Command execution and file transfer are delegated to a ``local``
    connection plugin instance. Public attribute lookups that fail on this
    object fall back to the sub-plugin stored in ``self._sub_plugin["obj"]``
    when one is present. Messages destined for the controller are buffered
    with :meth:`queue_message` and drained with :meth:`pop_messages`.
    """

    force_persistence = True
    # Do not use _remote_is_local in other connections
    _remote_is_local = True

    def __init__(self, play_context, new_stdin, *args, **kwargs):
        super(NetworkConnectionBase, self).__init__(
            play_context, new_stdin, *args, **kwargs
        )
        # Buffered (level, message) tuples; see queue_message/pop_messages.
        self._messages = []
        self._conn_closed = False

        self._network_os = self._play_context.network_os

        # Local connection that exec_command/put_file/fetch_file delegate to.
        self._local = connection_loader.get("local", play_context, "/dev/null")
        self._local.set_options()

        # Sub-plugin descriptor. This class only reads its "obj" and "type"
        # keys; it never fills them in itself.
        self._sub_plugin = {}
        self._cached_variables = (None, None, None)

        # reconstruct the socket_path and set instance values accordingly
        self._ansible_playbook_pid = kwargs.get("ansible_playbook_pid")
        self._update_connection_state()

    def __getattr__(self, name):
        """
        Resolve attributes not found by normal lookup.

        Public (non-underscore) names are forwarded to the sub-plugin object
        in ``self._sub_plugin["obj"]`` if it is set and has a non-None
        attribute of that name. Underscore-prefixed names are never forwarded.

        :raises AttributeError: if the name cannot be resolved.
        """
        try:
            return self.__dict__[name]
        except KeyError:
            if not name.startswith("_"):
                # Use getattr with a default: if _sub_plugin has not been
                # assigned yet (object created via __new__, copy/unpickle, or
                # accessed before __init__ reaches that line), a plain
                # self._sub_plugin would re-enter __getattr__ and raise an
                # AttributeError naming "_sub_plugin" instead of `name`.
                sub_plugin = getattr(self, "_sub_plugin", None) or {}
                plugin = sub_plugin.get("obj")
                if plugin:
                    method = getattr(plugin, name, None)
                    if method is not None:
                        return method
            raise AttributeError(
                "'%s' object has no attribute '%s'"
                % (self.__class__.__name__, name)
            )

    def exec_command(self, cmd, in_data=None, sudoable=True):
        """Run ``cmd`` through the local connection plugin and return its result."""
        return self._local.exec_command(cmd, in_data, sudoable)

    def queue_message(self, level, message):
        """
        Adds a message to the queue of messages waiting to be pushed back to the controller process.

        :arg level: A string which can either be the name of a method in display, or 'log'. When
            the messages are returned to task_executor, a value of log will correspond to
            ``display.display(message, log_only=True)``, while another value will call ``display.[level](message)``
        """
        self._messages.append((level, message))

    def pop_messages(self):
        """Return all queued (level, message) tuples and clear the queue."""
        messages, self._messages = self._messages, []
        return messages

    def put_file(self, in_path, out_path):
        """Transfer a file from local to remote (delegated to the local plugin)"""
        return self._local.put_file(in_path, out_path)

    def fetch_file(self, in_path, out_path):
        """Fetch a file from remote to local (delegated to the local plugin)"""
        return self._local.fetch_file(in_path, out_path)

    def reset(self):
        """
        Reset the connection.

        If a persistent socket path is known, queue a verbose message and
        close the connection. A verbose message is always queued afterwards.
        """
        if self._socket_path:
            self.queue_message(
                "vvvv",
                "resetting persistent connection for socket_path %s"
                % self._socket_path,
            )
            self.close()
        self.queue_message("vvvv", "reset call on connection instance")

    def close(self):
        """Mark the connection as closed and no longer connected."""
        self._conn_closed = True
        self._connected = False

    def get_options(self, hostvars=None):
        """
        Return this plugin's options merged with the sub-plugin's options.

        Sub-plugin options are merged only when a sub-plugin object is set
        and its type is not "external"; they overwrite same-named keys.
        """
        options = super(NetworkConnectionBase, self).get_options(
            hostvars=hostvars
        )

        if (
            self._sub_plugin.get("obj")
            and self._sub_plugin.get("type") != "external"
        ):
            try:
                options.update(
                    self._sub_plugin["obj"].get_options(hostvars=hostvars)
                )
            except AttributeError:
                # Best-effort: an AttributeError (e.g. a sub-plugin without
                # get_options) leaves only this plugin's options.
                pass

        return options

    def set_options(self, task_keys=None, var_options=None, direct=None):
        """
        Set options on this plugin, then on the sub-plugin when applicable.

        When ``persistent_log_messages`` is enabled, a warning is queued
        because logged interactions are not redacted.
        """
        super(NetworkConnectionBase, self).set_options(
            task_keys=task_keys, var_options=var_options, direct=direct
        )
        if self.get_option("persistent_log_messages"):
            warning = (
                "Persistent connection logging is enabled for %s. This will log ALL interactions"
                % self._play_context.remote_addr
            )
            # Default to None so a constants module lacking DEFAULT_LOG_PATH
            # just omits the " to <path>" suffix instead of raising.
            logpath = getattr(C, "DEFAULT_LOG_PATH", None)
            if logpath is not None:
                warning += " to %s" % logpath
            self.queue_message(
                "warning",
                "%s and WILL NOT redact sensitive configuration like passwords. USE WITH CAUTION!"
                % warning,
            )

        if (
            self._sub_plugin.get("obj")
            and self._sub_plugin.get("type") != "external"
        ):
            try:
                self._sub_plugin["obj"].set_options(
                    task_keys=task_keys, var_options=var_options, direct=direct
                )
            except AttributeError:
                # Best-effort: an AttributeError (e.g. a sub-plugin without
                # set_options) is deliberately ignored.
                pass

    def _update_connection_state(self):
        """
        Reconstruct the connection socket_path and check if it exists

        The path is derived from the ssh plugin's control-path template and
        C.PERSISTENT_CONTROL_PATH_DIR. If that path exists, _socket_path is
        set to it and _connected to True. Otherwise neither attribute is
        assigned here; they keep whatever values they already had
        (NOTE(review): presumably None/False from ConnectionBase — confirm).
        """
        ssh = connection_loader.get("ssh", class_only=True)
        control_path = ssh._create_control_path(
            self._play_context.remote_addr,
            self._play_context.port,
            self._play_context.remote_user,
            self._play_context.connection,
            self._ansible_playbook_pid,
        )

        tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR)
        socket_path = unfrackpath(control_path % dict(directory=tmp_path))

        if os.path.exists(socket_path):
            self._connected = True
            self._socket_path = socket_path

    def _log_messages(self, message):
        """Queue ``message`` at level 'log' when persistent_log_messages is enabled."""
        if self.get_option("persistent_log_messages"):
            self.queue_message("log", message)