diff options
Diffstat (limited to 'test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py')
-rw-r--r-- | test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py new file mode 100644 index 0000000..955329d --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py @@ -0,0 +1,235 @@ +# (c) 2018, Ansible Inc, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import uuid +import hashlib + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible.plugins.action import ActionBase +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + socket_path = None + network_os = self._get_network_os(task_vars).split(".")[-1] + persistent_connection = self._play_context.connection.split(".")[-1] + + result = super(ActionModule, self).run(task_vars=task_vars) + + if persistent_connection != "network_cli": + # It is supported only with network_cli + result["failed"] = True + result["msg"] = ( + "connection type %s is not valid for net_put module," + " please use fully qualified name of network_cli connection type" + % self._play_context.connection + ) + return result + + try: + src = self._task.args["src"] + except KeyError as exc: + return { + "failed": True, + "msg": "missing required argument: %s" % exc, + } + + src_file_path_name = src + + # Get destination file if specified + dest = self._task.args.get("dest") + + # Get proto + proto = self._task.args.get("protocol") + if proto is None: + proto = "scp" + + # Get mode if set + mode = self._task.args.get("mode") + if mode is None: + mode = "binary" + + if mode == "text": + try: + self._handle_template(convert_data=False) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + # Now src has resolved file write to disk in current diectory for scp + src = self._task.args.get("src") + filename = str(uuid.uuid4()) + cwd = self._loader.get_basedir() + output_file = os.path.join(cwd, filename) + try: + with open(output_file, "wb") as f: + f.write(to_bytes(src, encoding="utf-8")) + except Exception: + os.remove(output_file) + raise + else: + try: + output_file = self._get_binary_src_file(src) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + if socket_path is None: + socket_path = self._connection.socket_path + + conn = Connection(socket_path) + sock_timeout = conn.get_option("persistent_command_timeout") + + if dest is None: + dest = src_file_path_name + + try: + changed = self._handle_existing_file( + conn, output_file, dest, proto, sock_timeout + ) + if changed is False: + result["changed"] = changed + result["destination"] = dest + return result + except Exception as exc: + result["msg"] = ( + "Warning: %s idempotency check failed. Check dest" % exc + ) + + try: + conn.copy_file( + source=output_file, + destination=dest, + proto=proto, + timeout=sock_timeout, + ) + except Exception as exc: + if to_text(exc) == "No response from server": + if network_os == "iosxr": + # IOSXR sometimes closes socket prematurely after completion + # of file transfer + result[ + "msg" + ] = "Warning: iosxr scp server pre close issue. Please check dest" + else: + result["failed"] = True + result["msg"] = "Exception received: %s" % exc + + if mode == "text": + # Cleanup tmp file expanded wih ansible vars + os.remove(output_file) + + result["changed"] = changed + result["destination"] = dest + return result + + def _handle_existing_file(self, conn, source, dest, proto, timeout): + """ + Determines whether the source and destination file match. + + :return: False if source and dest both exist and have matching sha1 sums, True otherwise. + """ + cwd = self._loader.get_basedir() + filename = str(uuid.uuid4()) + tmp_source_file = os.path.join(cwd, filename) + try: + conn.get_file( + source=dest, + destination=tmp_source_file, + proto=proto, + timeout=timeout, + ) + except ConnectionError as exc: + error = to_text(exc) + if error.endswith("No such file or directory"): + if os.path.exists(tmp_source_file): + os.remove(tmp_source_file) + return True + + try: + with open(source, "r") as f: + new_content = f.read() + with open(tmp_source_file, "r") as f: + old_content = f.read() + except (IOError, OSError): + os.remove(tmp_source_file) + raise + + sha1 = hashlib.sha1() + old_content_b = to_bytes(old_content, errors="surrogate_or_strict") + sha1.update(old_content_b) + checksum_old = sha1.digest() + + sha1 = hashlib.sha1() + new_content_b = to_bytes(new_content, errors="surrogate_or_strict") + sha1.update(new_content_b) + checksum_new = sha1.digest() + os.remove(tmp_source_file) + if checksum_old == checksum_new: + return False + return True + + def _get_binary_src_file(self, src): + working_path = self._get_working_path() + + if os.path.isabs(src) or urlsplit("src").scheme: + source = src + else: + source = self._loader.path_dwim_relative( + working_path, "templates", src + ) + if not source: + source = self._loader.path_dwim_relative(working_path, src) + + if not os.path.exists(source): + raise ValueError("path specified in src not found") + + return source + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os |