From 8a754e0858d922e955e71b253c139e071ecec432 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 28 Apr 2024 18:04:21 +0200 Subject: Adding upstream version 2.14.3. Signed-off-by: Daniel Baumann --- .../ansible/netcommon/plugins/action/cli_config.py | 40 + .../ansible/netcommon/plugins/action/net_base.py | 90 + .../ansible/netcommon/plugins/action/net_get.py | 199 ++ .../ansible/netcommon/plugins/action/net_put.py | 235 ++ .../ansible/netcommon/plugins/action/network.py | 209 ++ .../ansible/netcommon/plugins/become/enable.py | 42 + .../netcommon/plugins/connection/httpapi.py | 324 +++ .../netcommon/plugins/connection/netconf.py | 404 +++ .../netcommon/plugins/connection/network_cli.py | 1386 +++++++++++ .../netcommon/plugins/connection/persistent.py | 97 + .../plugins/doc_fragments/connection_persistent.py | 77 + .../netcommon/plugins/doc_fragments/netconf.py | 66 + .../plugins/doc_fragments/network_agnostic.py | 14 + .../ansible/netcommon/plugins/filter/ipaddr.py | 1186 +++++++++ .../ansible/netcommon/plugins/filter/network.py | 531 ++++ .../ansible/netcommon/plugins/httpapi/restconf.py | 91 + .../plugins/module_utils/compat/ipaddress.py | 2578 ++++++++++++++++++++ .../module_utils/network/common/cfg/base.py | 27 + .../plugins/module_utils/network/common/config.py | 473 ++++ .../module_utils/network/common/facts/facts.py | 162 ++ .../plugins/module_utils/network/common/netconf.py | 179 ++ .../plugins/module_utils/network/common/network.py | 275 +++ .../plugins/module_utils/network/common/parsing.py | 316 +++ .../plugins/module_utils/network/common/utils.py | 686 ++++++ .../module_utils/network/netconf/netconf.py | 147 ++ .../module_utils/network/restconf/restconf.py | 61 + .../netcommon/plugins/modules/cli_config.py | 444 ++++ .../ansible/netcommon/plugins/modules/net_get.py | 71 + .../ansible/netcommon/plugins/modules/net_put.py | 82 + .../ansible/netcommon/plugins/netconf/default.py | 70 + .../plugins/plugin_utils/connection_base.py | 185 ++ 31 files changed, 10747 insertions(+) create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py create mode 100644 test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py (limited to 'test/support/network-integration/collections/ansible_collections/ansible/netcommon') diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py new file mode 100644 index 0000000..089b339 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py @@ -0,0 +1,40 @@ +# +# Copyright 2018 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from ansible_collections.ansible.netcommon.plugins.action.network import ( + ActionModule as ActionNetworkModule, +) + + +class ActionModule(ActionNetworkModule): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + self._config_module = True + if self._play_context.connection.split(".")[-1] != "network_cli": + return { + "failed": True, + "msg": "Connection type %s is not valid for cli_config module" + % self._play_context.connection, + } + + return super(ActionModule, self).run(task_vars=task_vars) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py new file mode 100644 index 0000000..542dcfe --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py @@ -0,0 +1,90 @@ +# Copyright: (c) 2015, Ansible Inc, +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import copy + +from ansible.errors import AnsibleError +from ansible.plugins.action import ActionBase +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + result = {} + play_context = copy.deepcopy(self._play_context) + play_context.network_os = self._get_network_os(task_vars) + new_task = self._task.copy() + + module = self._get_implementation_module( + play_context.network_os, self._task.action + ) + if not module: + if self._task.args["fail_on_missing_module"]: + result["failed"] = True + else: + result["failed"] = False + + result["msg"] = ( + "Could not find implementation module %s for %s" + % (self._task.action, play_context.network_os) + ) + return result + + new_task.action = module + + action = self._shared_loader_obj.action_loader.get( + play_context.network_os, + task=new_task, + connection=self._connection, + play_context=play_context, + loader=self._loader, + templar=self._templar, + shared_loader_obj=self._shared_loader_obj, + ) + display.vvvv("Running implementation module %s" % module) + return action.run(task_vars=task_vars) + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host to use platform agnostic modules" + ) + + return network_os + + def _get_implementation_module(self, network_os, platform_agnostic_module): + module_name = ( + network_os.split(".")[-1] + + "_" + + platform_agnostic_module.partition("_")[2] + ) + if "." in network_os: + fqcn_module = ".".join(network_os.split(".")[0:-1]) + implementation_module = fqcn_module + "." + module_name + else: + implementation_module = module_name + + if implementation_module not in self._shared_loader_obj.module_loader: + implementation_module = None + + return implementation_module diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py new file mode 100644 index 0000000..40205a4 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py @@ -0,0 +1,199 @@ +# (c) 2018, Ansible Inc, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import re +import uuid +import hashlib + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible.plugins.action import ActionBase +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + socket_path = None + self._get_network_os(task_vars) + persistent_connection = self._play_context.connection.split(".")[-1] + + result = super(ActionModule, self).run(task_vars=task_vars) + + if persistent_connection != "network_cli": + # It is supported only with network_cli + result["failed"] = True + result["msg"] = ( + "connection type %s is not valid for net_get module," + " please use fully qualified name of network_cli connection type" + % self._play_context.connection + ) + return result + + try: + src = self._task.args["src"] + except KeyError as exc: + return { + "failed": True, + "msg": "missing required argument: %s" % exc, + } + + # Get destination file if specified + dest = self._task.args.get("dest") + + if dest is None: + dest = self._get_default_dest(src) + else: + dest = self._handle_dest_path(dest) + + # Get proto + proto = self._task.args.get("protocol") + if proto is None: + proto = "scp" + + if socket_path is None: + socket_path = self._connection.socket_path + + conn = Connection(socket_path) + sock_timeout = conn.get_option("persistent_command_timeout") + + try: + changed = self._handle_existing_file( + conn, src, dest, proto, sock_timeout + ) + if changed is False: + result["changed"] = changed + result["destination"] = dest + return result + except Exception as exc: + result["msg"] = ( + "Warning: %s idempotency check failed. Check dest" % exc + ) + + try: + conn.get_file( + source=src, destination=dest, proto=proto, timeout=sock_timeout + ) + except Exception as exc: + result["failed"] = True + result["msg"] = "Exception received: %s" % exc + + result["changed"] = changed + result["destination"] = dest + return result + + def _handle_dest_path(self, dest): + working_path = self._get_working_path() + + if os.path.isabs(dest) or urlsplit("dest").scheme: + dst = dest + else: + dst = self._loader.path_dwim_relative(working_path, "", dest) + + return dst + + def _get_src_filename_from_path(self, src_path): + filename_list = re.split("/|:", src_path) + return filename_list[-1] + + def _get_default_dest(self, src_path): + dest_path = self._get_working_path() + src_fname = self._get_src_filename_from_path(src_path) + filename = "%s/%s" % (dest_path, src_fname) + return filename + + def _handle_existing_file(self, conn, source, dest, proto, timeout): + """ + Determines whether the source and destination file match. + + :return: False if source and dest both exist and have matching sha1 sums, True otherwise. + """ + if not os.path.exists(dest): + return True + + cwd = self._loader.get_basedir() + filename = str(uuid.uuid4()) + tmp_dest_file = os.path.join(cwd, filename) + try: + conn.get_file( + source=source, + destination=tmp_dest_file, + proto=proto, + timeout=timeout, + ) + except ConnectionError as exc: + error = to_text(exc) + if error.endswith("No such file or directory"): + if os.path.exists(tmp_dest_file): + os.remove(tmp_dest_file) + return True + + try: + with open(tmp_dest_file, "r") as f: + new_content = f.read() + with open(dest, "r") as f: + old_content = f.read() + except (IOError, OSError): + os.remove(tmp_dest_file) + raise + + sha1 = hashlib.sha1() + old_content_b = to_bytes(old_content, errors="surrogate_or_strict") + sha1.update(old_content_b) + checksum_old = sha1.digest() + + sha1 = hashlib.sha1() + new_content_b = to_bytes(new_content, errors="surrogate_or_strict") + sha1.update(new_content_b) + checksum_new = sha1.digest() + os.remove(tmp_dest_file) + if checksum_old == checksum_new: + return False + return True + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py new file mode 100644 index 0000000..955329d --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py @@ -0,0 +1,235 @@ +# (c) 2018, Ansible Inc, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import uuid +import hashlib + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible.plugins.action import ActionBase +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + socket_path = None + network_os = self._get_network_os(task_vars).split(".")[-1] + persistent_connection = self._play_context.connection.split(".")[-1] + + result = super(ActionModule, self).run(task_vars=task_vars) + + if persistent_connection != "network_cli": + # It is supported only with network_cli + result["failed"] = True + result["msg"] = ( + "connection type %s is not valid for net_put module," + " please use fully qualified name of network_cli connection type" + % self._play_context.connection + ) + return result + + try: + src = self._task.args["src"] + except KeyError as exc: + return { + "failed": True, + "msg": "missing required argument: %s" % exc, + } + + src_file_path_name = src + + # Get destination file if specified + dest = self._task.args.get("dest") + + # Get proto + proto = self._task.args.get("protocol") + if proto is None: + proto = "scp" + + # Get mode if set + mode = self._task.args.get("mode") + if mode is None: + mode = "binary" + + if mode == "text": + try: + self._handle_template(convert_data=False) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + # Now src has resolved file write to disk in current diectory for scp + src = self._task.args.get("src") + filename = str(uuid.uuid4()) + cwd = self._loader.get_basedir() + output_file = os.path.join(cwd, filename) + try: + with open(output_file, "wb") as f: + f.write(to_bytes(src, encoding="utf-8")) + except Exception: + os.remove(output_file) + raise + else: + try: + output_file = self._get_binary_src_file(src) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + if socket_path is None: + socket_path = self._connection.socket_path + + conn = Connection(socket_path) + sock_timeout = conn.get_option("persistent_command_timeout") + + if dest is None: + dest = src_file_path_name + + try: + changed = self._handle_existing_file( + conn, output_file, dest, proto, sock_timeout + ) + if changed is False: + result["changed"] = changed + result["destination"] = dest + return result + except Exception as exc: + result["msg"] = ( + "Warning: %s idempotency check failed. Check dest" % exc + ) + + try: + conn.copy_file( + source=output_file, + destination=dest, + proto=proto, + timeout=sock_timeout, + ) + except Exception as exc: + if to_text(exc) == "No response from server": + if network_os == "iosxr": + # IOSXR sometimes closes socket prematurely after completion + # of file transfer + result[ + "msg" + ] = "Warning: iosxr scp server pre close issue. Please check dest" + else: + result["failed"] = True + result["msg"] = "Exception received: %s" % exc + + if mode == "text": + # Cleanup tmp file expanded wih ansible vars + os.remove(output_file) + + result["changed"] = changed + result["destination"] = dest + return result + + def _handle_existing_file(self, conn, source, dest, proto, timeout): + """ + Determines whether the source and destination file match. + + :return: False if source and dest both exist and have matching sha1 sums, True otherwise. + """ + cwd = self._loader.get_basedir() + filename = str(uuid.uuid4()) + tmp_source_file = os.path.join(cwd, filename) + try: + conn.get_file( + source=dest, + destination=tmp_source_file, + proto=proto, + timeout=timeout, + ) + except ConnectionError as exc: + error = to_text(exc) + if error.endswith("No such file or directory"): + if os.path.exists(tmp_source_file): + os.remove(tmp_source_file) + return True + + try: + with open(source, "r") as f: + new_content = f.read() + with open(tmp_source_file, "r") as f: + old_content = f.read() + except (IOError, OSError): + os.remove(tmp_source_file) + raise + + sha1 = hashlib.sha1() + old_content_b = to_bytes(old_content, errors="surrogate_or_strict") + sha1.update(old_content_b) + checksum_old = sha1.digest() + + sha1 = hashlib.sha1() + new_content_b = to_bytes(new_content, errors="surrogate_or_strict") + sha1.update(new_content_b) + checksum_new = sha1.digest() + os.remove(tmp_source_file) + if checksum_old == checksum_new: + return False + return True + + def _get_binary_src_file(self, src): + working_path = self._get_working_path() + + if os.path.isabs(src) or urlsplit("src").scheme: + source = src + else: + source = self._loader.path_dwim_relative( + working_path, "templates", src + ) + if not source: + source = self._loader.path_dwim_relative(working_path, src) + + if not os.path.exists(source): + raise ValueError("path specified in src not found") + + return source + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py new file mode 100644 index 0000000..5d05d33 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py @@ -0,0 +1,209 @@ +# +# (c) 2018 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import time +import re + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.plugins.action.normal import ActionModule as _ActionModule +from ansible.utils.display import Display + +display = Display() + +PRIVATE_KEYS_RE = re.compile("__.+__") + + +class ActionModule(_ActionModule): + def run(self, task_vars=None): + config_module = hasattr(self, "_config_module") and self._config_module + if config_module and self._task.args.get("src"): + try: + self._handle_src_option() + except AnsibleError as e: + return {"failed": True, "msg": e.message, "changed": False} + + result = super(ActionModule, self).run(task_vars=task_vars) + + if ( + config_module + and self._task.args.get("backup") + and not result.get("failed") + ): + self._handle_backup_option(result, task_vars) + + return result + + def _handle_backup_option(self, result, task_vars): + + filename = None + backup_path = None + try: + content = result["__backup__"] + except KeyError: + raise AnsibleError("Failed while reading configuration backup") + + backup_options = self._task.args.get("backup_options") + if backup_options: + filename = backup_options.get("filename") + backup_path = backup_options.get("dir_path") + + if not backup_path: + cwd = self._get_working_path() + backup_path = os.path.join(cwd, "backup") + if not filename: + tstamp = time.strftime( + "%Y-%m-%d@%H:%M:%S", time.localtime(time.time()) + ) + filename = "%s_config.%s" % ( + task_vars["inventory_hostname"], + tstamp, + ) + + dest = os.path.join(backup_path, filename) + backup_path = os.path.expanduser( + os.path.expandvars( + to_bytes(backup_path, errors="surrogate_or_strict") + ) + ) + + if not os.path.exists(backup_path): + os.makedirs(backup_path) + + new_task = self._task.copy() + for item in self._task.args: + if not item.startswith("_"): + new_task.args.pop(item, None) + + new_task.args.update(dict(content=content, dest=dest)) + copy_action = self._shared_loader_obj.action_loader.get( + "copy", + task=new_task, + connection=self._connection, + play_context=self._play_context, + loader=self._loader, + templar=self._templar, + shared_loader_obj=self._shared_loader_obj, + ) + copy_result = copy_action.run(task_vars=task_vars) + if copy_result.get("failed"): + result["failed"] = copy_result["failed"] + result["msg"] = copy_result.get("msg") + return + + result["backup_path"] = dest + if copy_result.get("changed", False): + result["changed"] = copy_result["changed"] + + if backup_options and backup_options.get("filename"): + result["date"] = time.strftime( + "%Y-%m-%d", + time.gmtime(os.stat(result["backup_path"]).st_ctime), + ) + result["time"] = time.strftime( + "%H:%M:%S", + time.gmtime(os.stat(result["backup_path"]).st_ctime), + ) + + else: + result["date"] = tstamp.split("@")[0] + result["time"] = tstamp.split("@")[1] + result["shortname"] = result["backup_path"][::-1].split(".", 1)[1][ + ::-1 + ] + result["filename"] = result["backup_path"].split("/")[-1] + + # strip out any keys that have two leading and two trailing + # underscore characters + for key in list(result.keys()): + if PRIVATE_KEYS_RE.match(key): + del result[key] + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _handle_src_option(self, convert_data=True): + src = self._task.args.get("src") + working_path = self._get_working_path() + + if os.path.isabs(src) or urlsplit("src").scheme: + source = src + else: + source = self._loader.path_dwim_relative( + working_path, "templates", src + ) + if not source: + source = self._loader.path_dwim_relative(working_path, src) + + if not os.path.exists(source): + raise AnsibleError("path specified in src not found") + + try: + with open(source, "r") as f: + template_data = to_text(f.read()) + except IOError as e: + raise AnsibleError( + "unable to load src file {0}, I/O error({1}): {2}".format( + source, e.errno, e.strerror + ) + ) + + # Create a template search path in the following order: + # [working_path, self_role_path, dependent_role_paths, dirname(source)] + searchpath = [working_path] + if self._task._role is not None: + searchpath.append(self._task._role._role_path) + if hasattr(self._task, "_block:"): + dep_chain = self._task._block.get_dep_chain() + if dep_chain is not None: + for role in dep_chain: + searchpath.append(role._role_path) + searchpath.append(os.path.dirname(source)) + with self._templar.set_temporary_context(searchpath=searchpath): + self._task.args["src"] = self._templar.template( + template_data, convert_data=convert_data + ) + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py new file mode 100644 index 0000000..33938fd --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """become: enable +short_description: Switch to elevated permissions on a network device +description: +- This become plugins allows elevated permissions on a remote network device. +author: ansible (@core) +options: + become_pass: + description: password + ini: + - section: enable_become_plugin + key: password + vars: + - name: ansible_become_password + - name: ansible_become_pass + - name: ansible_enable_pass + env: + - name: ANSIBLE_BECOME_PASS + - name: ANSIBLE_ENABLE_PASS +notes: +- enable is really implemented in the network connection handler and as such can only + be used with network connections. +- This plugin ignores the 'become_exe' and 'become_user' settings as it uses an API + and not an executable. +""" + +from ansible.plugins.become import BecomeBase + + +class BecomeModule(BecomeBase): + + name = "ansible.netcommon.enable" + + def build_become_command(self, cmd, shell): + # enable is implemented inside the network connection plugins + return cmd diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py new file mode 100644 index 0000000..b063ef0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py @@ -0,0 +1,324 @@ +# (c) 2018 Red Hat Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +connection: httpapi +short_description: Use httpapi to run command on network appliances +description: +- This connection plugin provides a connection to remote devices over a HTTP(S)-based + api. +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the HTTP(S) connection + to. + default: inventory_hostname + vars: + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the HTTP(S) connection. + - When unspecified, will pick 80 or 443 based on the value of use_ssl. + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_httpapi_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load the correct httpapi plugin to communicate with the remote device + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the API connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + needed for the device API. + vars: + - name: ansible_password + - name: ansible_httpapi_pass + - name: ansible_httpapi_password + use_ssl: + type: boolean + description: + - Whether to connect using SSL (HTTPS) or not (HTTP). + default: false + vars: + - name: ansible_httpapi_use_ssl + validate_certs: + type: boolean + description: + - Whether to validate SSL certificates + default: true + vars: + - name: ansible_httpapi_validate_certs + use_proxy: + type: boolean + description: + - Whether to use https_proxy for requests. + default: true + vars: + - name: ansible_httpapi_use_proxy + become: + type: boolean + description: + - The become option will instruct the CLI session to attempt privilege escalation + on platforms that support it. Normally this means transitioning from user mode + to C(enable) mode in the CLI session. If become is set to True and the remote + device does not support privilege escalation or the privilege has already been + elevated, then this option is silently ignored. + - Can be configured from the CLI via the C(--become) or C(-b) options. + default: false + ini: + - section: privilege_escalation + key: become + env: + - name: ANSIBLE_BECOME + vars: + - name: ansible_become + become_method: + description: + - This option allows the become method to be specified in for handling privilege + escalation. Typically the become_method value is set to C(enable) but could + be defined as other values. + default: sudo + ini: + - section: privilege_escalation + key: become_method + env: + - name: ANSIBLE_BECOME_METHOD + vars: + - name: ansible_become_method + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this option + as it could create a security vulnerability by logging sensitive information + in log file. + default: false + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" + +from io import BytesIO + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_bytes +from ansible.module_utils.six import PY3 +from ansible.module_utils.six.moves import cPickle +from ansible.module_utils.six.moves.urllib.error import HTTPError, URLError +from ansible.module_utils.urls import open_url +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import httpapi_loader +from ansible.plugins.connection import NetworkConnectionBase, ensure_connect + + +class Connection(NetworkConnectionBase): + """Network API connection""" + + transport = "ansible.netcommon.httpapi" + has_pipelining = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + + self._url = None + self._auth = None + + if self._network_os: + + self.httpapi = httpapi_loader.get(self._network_os, self) + if self.httpapi: + self._sub_plugin = { + "type": "httpapi", + "name": self.httpapi._load_name, + "obj": self.httpapi, + } + self.queue_message( + "vvvv", + "loaded API plugin %s from path %s for network_os %s" + % ( + self.httpapi._load_name, + self.httpapi._original_path, + self._network_os, + ), + ) + else: + raise AnsibleConnectionFailure( + "unable to load API plugin for network_os %s" + % self._network_os + ) + + else: + raise AnsibleConnectionFailure( + "Unable to automatically determine host network os. Please " + "manually configure ansible_network_os value for this host" + ) + self.queue_message("log", "network_os is set to %s" % self._network_os) + + def update_play_context(self, pc_data): + """Updates the play context information for the connection""" + pc_data = to_bytes(pc_data) + if PY3: + pc_data = cPickle.loads(pc_data, encoding="bytes") + else: + pc_data = cPickle.loads(pc_data) + play_context = PlayContext() + play_context.deserialize(pc_data) + + self.queue_message("vvvv", "updating play_context for connection") + if self._play_context.become ^ play_context.become: + self.set_become(play_context) + if play_context.become is True: + self.queue_message("vvvv", "authorizing connection") + else: + self.queue_message("vvvv", "deauthorizing connection") + + self._play_context = play_context + + def _connect(self): + if not self.connected: + protocol = "https" if self.get_option("use_ssl") else "http" + host = self.get_option("host") + port = self.get_option("port") or ( + 443 if protocol == "https" else 80 + ) + self._url = "%s://%s:%s" % (protocol, host, port) + + self.queue_message( + "vvv", + "ESTABLISH HTTP(S) CONNECTFOR USER: %s TO %s" + % (self._play_context.remote_user, self._url), + ) + self.httpapi.set_become(self._play_context) + self._connected = True + + self.httpapi.login( + self.get_option("remote_user"), self.get_option("password") + ) + + def close(self): + """ + Close the active session to the device + """ + # only close the connection if its connected. + if self._connected: + self.queue_message("vvvv", "closing http(s) connection to device") + self.logout() + + super(Connection, self).close() + + @ensure_connect + def send(self, path, data, **kwargs): + """ + Sends the command to the device over api + """ + url_kwargs = dict( + timeout=self.get_option("persistent_command_timeout"), + validate_certs=self.get_option("validate_certs"), + use_proxy=self.get_option("use_proxy"), + headers={}, + ) + url_kwargs.update(kwargs) + if self._auth: + # Avoid modifying passed-in headers + headers = dict(kwargs.get("headers", {})) + headers.update(self._auth) + url_kwargs["headers"] = headers + else: + url_kwargs["force_basic_auth"] = True + url_kwargs["url_username"] = self.get_option("remote_user") + url_kwargs["url_password"] = self.get_option("password") + + try: + url = self._url + path + self._log_messages( + "send url '%s' with data '%s' and kwargs '%s'" + % (url, data, url_kwargs) + ) + response = open_url(url, data=data, **url_kwargs) + except HTTPError as exc: + is_handled = self.handle_httperror(exc) + if is_handled is True: + return self.send(path, data, **kwargs) + elif is_handled is False: + raise + else: + response = is_handled + except URLError as exc: + raise AnsibleConnectionFailure( + "Could not connect to {0}: {1}".format( + self._url + path, exc.reason + ) + ) + + response_buffer = BytesIO() + resp_data = response.read() + self._log_messages("received response: '%s'" % resp_data) + response_buffer.write(resp_data) + + # Try to assign a new auth token if one is given + self._auth = self.update_auth(response, response_buffer) or self._auth + + response_buffer.seek(0) + + return response, response_buffer diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py new file mode 100644 index 0000000..1e2d3ca --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py @@ -0,0 +1,404 @@ +# (c) 2016 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +connection: netconf +short_description: Provides a persistent connection using the netconf protocol +description: +- This connection plugin provides a connection to remote devices over the SSH NETCONF + subsystem. This connection plugin is typically used by network devices for sending + and receiving RPC calls over NETCONF. +- Note this connection plugin requires ncclient to be installed on the local Ansible + controller. +requirements: +- ncclient +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the SSH connection + to. + default: inventory_hostname + vars: + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the SSH connection. + default: 830 + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load a device specific netconf plugin. If this option is not configured + (or set to C(auto)), then Ansible will attempt to guess the correct network_os + to use. If it can not guess a network_os correctly it will use C(default). + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the SSH connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + first establishing the SSH connection. + vars: + - name: ansible_password + - name: ansible_ssh_pass + - name: ansible_ssh_password + - name: ansible_netconf_password + private_key_file: + description: + - The private SSH key or certificate file used to authenticate to the remote device + when first establishing the SSH connection. + ini: + - section: defaults + key: private_key_file + env: + - name: ANSIBLE_PRIVATE_KEY_FILE + vars: + - name: ansible_private_key_file + look_for_keys: + default: true + description: + - Enables looking for ssh keys in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`). + env: + - name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS + ini: + - section: paramiko_connection + key: look_for_keys + type: boolean + host_key_checking: + description: Set this to "False" if you want to avoid host key checking by the + underlying tools Ansible uses to connect to the host + type: boolean + default: true + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + - name: ANSIBLE_NETCONF_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: paramiko_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + - name: ansible_netconf_host_key_checking + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + netconf_ssh_config: + description: + - This variable is used to enable bastion/jump host with netconf connection. If + set to True the bastion/jump host ssh settings should be present in ~/.ssh/config + file, alternatively it can be set to custom ssh configuration file path to read + the bastion/jump host settings. + ini: + - section: netconf_connection + key: ssh_config + version_added: '2.7' + env: + - name: ANSIBLE_NETCONF_SSH_CONFIG + vars: + - name: ansible_netconf_ssh_config + version_added: '2.7' + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this option + as it could create a security vulnerability by logging sensitive information + in log file. + default: false + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" + +import os +import logging +import json + +from ansible.errors import AnsibleConnectionFailure, AnsibleError +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.parsing.convert_bool import ( + BOOLEANS_TRUE, + BOOLEANS_FALSE, +) +from ansible.plugins.loader import netconf_loader +from ansible.plugins.connection import NetworkConnectionBase, ensure_connect + +try: + from ncclient import manager + from ncclient.operations import RPCError + from ncclient.transport.errors import SSHUnknownHostError + from ncclient.xml_ import to_ele, to_xml + + HAS_NCCLIENT = True + NCCLIENT_IMP_ERR = None +except ( + ImportError, + AttributeError, +) as err: # paramiko and gssapi are incompatible and raise AttributeError not ImportError + HAS_NCCLIENT = False + NCCLIENT_IMP_ERR = err + +logging.getLogger("ncclient").setLevel(logging.INFO) + + +class Connection(NetworkConnectionBase): + """NetConf connections""" + + transport = "ansible.netcommon.netconf" + has_pipelining = False + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + + # If network_os is not specified then set the network os to auto + # This will be used to trigger the use of guess_network_os when connecting. + self._network_os = self._network_os or "auto" + + self.netconf = netconf_loader.get(self._network_os, self) + if self.netconf: + self._sub_plugin = { + "type": "netconf", + "name": self.netconf._load_name, + "obj": self.netconf, + } + self.queue_message( + "vvvv", + "loaded netconf plugin %s from path %s for network_os %s" + % ( + self.netconf._load_name, + self.netconf._original_path, + self._network_os, + ), + ) + else: + self.netconf = netconf_loader.get("default", self) + self._sub_plugin = { + "type": "netconf", + "name": "default", + "obj": self.netconf, + } + self.queue_message( + "display", + "unable to load netconf plugin for network_os %s, falling back to default plugin" + % self._network_os, + ) + + self.queue_message("log", "network_os is set to %s" % self._network_os) + self._manager = None + self.key_filename = None + self._ssh_config = None + + def exec_command(self, cmd, in_data=None, sudoable=True): + """Sends the request to the node and returns the reply + The method accepts two forms of request. The first form is as a byte + string that represents xml string be send over netconf session. + The second form is a json-rpc (2.0) byte string. + """ + if self._manager: + # to_ele operates on native strings + request = to_ele(to_native(cmd, errors="surrogate_or_strict")) + + if request is None: + return "unable to parse request" + + try: + reply = self._manager.rpc(request) + except RPCError as exc: + error = self.internal_error( + data=to_text(to_xml(exc.xml), errors="surrogate_or_strict") + ) + return json.dumps(error) + + return reply.data_xml + else: + return super(Connection, self).exec_command(cmd, in_data, sudoable) + + @property + @ensure_connect + def manager(self): + return self._manager + + def _connect(self): + if not HAS_NCCLIENT: + raise AnsibleError( + "%s: %s" + % ( + missing_required_lib("ncclient"), + to_native(NCCLIENT_IMP_ERR), + ) + ) + + self.queue_message("log", "ssh connection done, starting ncclient") + + allow_agent = True + if self._play_context.password is not None: + allow_agent = False + setattr(self._play_context, "allow_agent", allow_agent) + + self.key_filename = ( + self._play_context.private_key_file + or self.get_option("private_key_file") + ) + if self.key_filename: + self.key_filename = str(os.path.expanduser(self.key_filename)) + + self._ssh_config = self.get_option("netconf_ssh_config") + if self._ssh_config in BOOLEANS_TRUE: + self._ssh_config = True + elif self._ssh_config in BOOLEANS_FALSE: + self._ssh_config = None + + # Try to guess the network_os if the network_os is set to auto + if self._network_os == "auto": + for cls in netconf_loader.all(class_only=True): + network_os = cls.guess_network_os(self) + if network_os: + self.queue_message( + "vvv", "discovered network_os %s" % network_os + ) + self._network_os = network_os + + # If we have tried to detect the network_os but were unable to i.e. network_os is still 'auto' + # then use default as the network_os + + if self._network_os == "auto": + # Network os not discovered. Set it to default + self.queue_message( + "vvv", + "Unable to discover network_os. Falling back to default.", + ) + self._network_os = "default" + try: + ncclient_device_handler = self.netconf.get_option( + "ncclient_device_handler" + ) + except KeyError: + ncclient_device_handler = "default" + self.queue_message( + "vvv", + "identified ncclient device handler: %s." + % ncclient_device_handler, + ) + device_params = {"name": ncclient_device_handler} + + try: + port = self._play_context.port or 830 + self.queue_message( + "vvv", + "ESTABLISH NETCONF SSH CONNECTION FOR USER: %s on PORT %s TO %s WITH SSH_CONFIG = %s" + % ( + self._play_context.remote_user, + port, + self._play_context.remote_addr, + self._ssh_config, + ), + ) + self._manager = manager.connect( + host=self._play_context.remote_addr, + port=port, + username=self._play_context.remote_user, + password=self._play_context.password, + key_filename=self.key_filename, + hostkey_verify=self.get_option("host_key_checking"), + look_for_keys=self.get_option("look_for_keys"), + device_params=device_params, + allow_agent=self._play_context.allow_agent, + timeout=self.get_option("persistent_connect_timeout"), + ssh_config=self._ssh_config, + ) + + self._manager._timeout = self.get_option( + "persistent_command_timeout" + ) + except SSHUnknownHostError as exc: + raise AnsibleConnectionFailure(to_native(exc)) + except ImportError: + raise AnsibleError( + "connection=netconf is not supported on {0}".format( + self._network_os + ) + ) + + if not self._manager.connected: + return 1, b"", b"not connected" + + self.queue_message( + "log", "ncclient manager object created successfully" + ) + + self._connected = True + + super(Connection, self)._connect() + + return ( + 0, + to_bytes(self._manager.session_id, errors="surrogate_or_strict"), + b"", + ) + + def close(self): + if self._manager: + self._manager.close_session() + super(Connection, self).close() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py new file mode 100644 index 0000000..fef4081 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py @@ -0,0 +1,1386 @@ +# (c) 2016 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """ +author: + - Ansible Networking Team (@ansible-network) +name: network_cli +short_description: Use network_cli to run command on network appliances +description: +- This connection plugin provides a connection to remote devices over the SSH and + implements a CLI shell. This connection plugin is typically used by network devices + for sending and receiving CLi commands to network devices. +version_added: 1.0.0 +requirements: +- ansible-pylibssh if using I(ssh_type=libssh) +extends_documentation_fragment: +- ansible.netcommon.connection_persistent +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the SSH connection + to. + default: inventory_hostname + vars: + - name: inventory_hostname + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the SSH connection. + default: 22 + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load the correct terminal and cliconf plugins to communicate with the remote + device. + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the SSH connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + first establishing the SSH connection. + vars: + - name: ansible_password + - name: ansible_ssh_pass + - name: ansible_ssh_password + private_key_file: + description: + - The private SSH key or certificate file used to authenticate to the remote device + when first establishing the SSH connection. + ini: + - section: defaults + key: private_key_file + env: + - name: ANSIBLE_PRIVATE_KEY_FILE + vars: + - name: ansible_private_key_file + become: + type: boolean + description: + - The become option will instruct the CLI session to attempt privilege escalation + on platforms that support it. Normally this means transitioning from user mode + to C(enable) mode in the CLI session. If become is set to True and the remote + device does not support privilege escalation or the privilege has already been + elevated, then this option is silently ignored. + - Can be configured from the CLI via the C(--become) or C(-b) options. + default: false + ini: + - section: privilege_escalation + key: become + env: + - name: ANSIBLE_BECOME + vars: + - name: ansible_become + become_errors: + type: str + description: + - This option determines how privilege escalation failures are handled when + I(become) is enabled. + - When set to C(ignore), the errors are silently ignored. + When set to C(warn), a warning message is displayed. + The default option C(fail), triggers a failure and halts execution. + vars: + - name: ansible_network_become_errors + default: fail + choices: ["ignore", "warn", "fail"] + terminal_errors: + type: str + description: + - This option determines how failures while setting terminal parameters + are handled. + - When set to C(ignore), the errors are silently ignored. + When set to C(warn), a warning message is displayed. + The default option C(fail), triggers a failure and halts execution. + vars: + - name: ansible_network_terminal_errors + default: fail + choices: ["ignore", "warn", "fail"] + version_added: 3.1.0 + become_method: + description: + - This option allows the become method to be specified in for handling privilege + escalation. Typically the become_method value is set to C(enable) but could + be defined as other values. + default: sudo + ini: + - section: privilege_escalation + key: become_method + env: + - name: ANSIBLE_BECOME_METHOD + vars: + - name: ansible_become_method + host_key_auto_add: + type: boolean + description: + - By default, Ansible will prompt the user before adding SSH keys to the known + hosts file. Since persistent connections such as network_cli run in background + processes, the user will never be prompted. By enabling this option, unknown + host keys will automatically be added to the known hosts file. + - Be sure to fully understand the security implications of enabling this option + on production systems as it could create a security vulnerability. + default: false + ini: + - section: paramiko_connection + key: host_key_auto_add + env: + - name: ANSIBLE_HOST_KEY_AUTO_ADD + persistent_buffer_read_timeout: + type: float + description: + - Configures, in seconds, the amount of time to wait for the data to be read from + Paramiko channel after the command prompt is matched. This timeout value ensures + that command prompt matched is correct and there is no more data left to be + received from remote host. + default: 0.1 + ini: + - section: persistent_connection + key: buffer_read_timeout + env: + - name: ANSIBLE_PERSISTENT_BUFFER_READ_TIMEOUT + vars: + - name: ansible_buffer_read_timeout + terminal_stdout_re: + type: list + elements: dict + description: + - A single regex pattern or a sequence of patterns along with optional flags to + match the command prompt from the received response chunk. This option accepts + C(pattern) and C(flags) keys. The value of C(pattern) is a python regex pattern + to match the response and the value of C(flags) is the value accepted by I(flags) + argument of I(re.compile) python method to control the way regex is matched + with the response, for example I('re.I'). + vars: + - name: ansible_terminal_stdout_re + terminal_stderr_re: + type: list + elements: dict + description: + - This option provides the regex pattern and optional flags to match the error + string from the received response chunk. This option accepts C(pattern) and + C(flags) keys. The value of C(pattern) is a python regex pattern to match the + response and the value of C(flags) is the value accepted by I(flags) argument + of I(re.compile) python method to control the way regex is matched with the + response, for example I('re.I'). + vars: + - name: ansible_terminal_stderr_re + terminal_initial_prompt: + type: list + elements: string + description: + - A single regex pattern or a sequence of patterns to evaluate the expected prompt + at the time of initial login to the remote host. + vars: + - name: ansible_terminal_initial_prompt + terminal_initial_answer: + type: list + elements: string + description: + - The answer to reply with if the C(terminal_initial_prompt) is matched. The value + can be a single answer or a list of answers for multiple terminal_initial_prompt. + In case the login menu has multiple prompts the sequence of the prompt and excepted + answer should be in same order and the value of I(terminal_prompt_checkall) + should be set to I(True) if all the values in C(terminal_initial_prompt) are + expected to be matched and set to I(False) if any one login prompt is to be + matched. + vars: + - name: ansible_terminal_initial_answer + terminal_initial_prompt_checkall: + type: boolean + description: + - By default the value is set to I(False) and any one of the prompts mentioned + in C(terminal_initial_prompt) option is matched it won't check for other prompts. + When set to I(True) it will check for all the prompts mentioned in C(terminal_initial_prompt) + option in the given order and all the prompts should be received from remote + host if not it will result in timeout. + default: false + vars: + - name: ansible_terminal_initial_prompt_checkall + terminal_inital_prompt_newline: + type: boolean + description: + - This boolean flag, that when set to I(True) will send newline in the response + if any of values in I(terminal_initial_prompt) is matched. + default: true + vars: + - name: ansible_terminal_initial_prompt_newline + network_cli_retries: + description: + - Number of attempts to connect to remote host. The delay time between the retires + increases after every attempt by power of 2 in seconds till either the maximum + attempts are exhausted or any of the C(persistent_command_timeout) or C(persistent_connect_timeout) + timers are triggered. + default: 3 + type: integer + env: + - name: ANSIBLE_NETWORK_CLI_RETRIES + ini: + - section: persistent_connection + key: network_cli_retries + vars: + - name: ansible_network_cli_retries + ssh_type: + description: + - The python package that will be used by the C(network_cli) connection plugin to create a SSH connection to remote host. + - I(libssh) will use the ansible-pylibssh package, which needs to be installed in order to work. + - I(paramiko) will instead use the paramiko package to manage the SSH connection. + - I(auto) will use ansible-pylibssh if that package is installed, otherwise will fallback to paramiko. + default: auto + choices: ["libssh", "paramiko", "auto"] + env: + - name: ANSIBLE_NETWORK_CLI_SSH_TYPE + ini: + - section: persistent_connection + key: ssh_type + vars: + - name: ansible_network_cli_ssh_type + host_key_checking: + description: 'Set this to "False" if you want to avoid host key checking by the underlying tools Ansible uses to connect to the host' + type: boolean + default: True + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: persistent_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + single_user_mode: + type: boolean + default: false + version_added: 2.0.0 + description: + - This option enables caching of data fetched from the target for re-use. + The cache is invalidated when the target device enters configuration mode. + - Applicable only for platforms where this has been implemented. + env: + - name: ANSIBLE_NETWORK_SINGLE_USER_MODE + vars: + - name: ansible_network_single_user_mode +""" + +import getpass +import json +import logging +import os +import re +import signal +import socket +import time +import traceback +from functools import wraps +from io import BytesIO + +from ansible.errors import AnsibleConnectionFailure, AnsibleError +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.six import PY3 +from ansible.module_utils.six.moves import cPickle +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import ( + cache_loader, + cliconf_loader, + connection_loader, + terminal_loader, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible_collections.ansible.netcommon.plugins.plugin_utils.connection_base import ( + NetworkConnectionBase, +) + +try: + from scp import SCPClient + + HAS_SCP = True +except ImportError: + HAS_SCP = False + +HAS_PYLIBSSH = False + + +def ensure_connect(func): + @wraps(func) + def wrapped(self, *args, **kwargs): + if not self._connected: + self._connect() + self.update_cli_prompt_context() + return func(self, *args, **kwargs) + + return wrapped + + +class AnsibleCmdRespRecv(Exception): + pass + + +class Connection(NetworkConnectionBase): + """CLI (shell) SSH connections on Paramiko""" + + transport = "ansible.netcommon.network_cli" + has_pipelining = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._ssh_shell = None + + self._matched_prompt = None + self._matched_cmd_prompt = None + self._matched_pattern = None + self._last_response = None + self._history = list() + self._command_response = None + self._last_recv_window = None + self._cache = None + + self._terminal = None + self.cliconf = None + + # Managing prompt context + self._check_prompt = False + + self._task_uuid = to_text(kwargs.get("task_uuid", "")) + self._ssh_type_conn = None + self._ssh_type = None + + self._single_user_mode = False + + if self._network_os: + self._terminal = terminal_loader.get(self._network_os, self) + if not self._terminal: + raise AnsibleConnectionFailure( + "network os %s is not supported" % self._network_os + ) + + self.cliconf = cliconf_loader.get(self._network_os, self) + if self.cliconf: + self._sub_plugin = { + "type": "cliconf", + "name": self.cliconf._load_name, + "obj": self.cliconf, + } + self.queue_message( + "vvvv", + "loaded cliconf plugin %s from path %s for network_os %s" + % ( + self.cliconf._load_name, + self.cliconf._original_path, + self._network_os, + ), + ) + else: + self.queue_message( + "vvvv", + "unable to load cliconf for network_os %s" + % self._network_os, + ) + else: + raise AnsibleConnectionFailure( + "Unable to automatically determine host network os. Please " + "manually configure ansible_network_os value for this host" + ) + self.queue_message("log", "network_os is set to %s" % self._network_os) + + @property + def ssh_type(self): + if self._ssh_type is None: + self._ssh_type = self.get_option("ssh_type") + self.queue_message( + "vvvv", "ssh type is set to %s" % self._ssh_type + ) + # Support autodetection of supported library + if self._ssh_type == "auto": + self.queue_message("vvvv", "autodetecting ssh_type") + if HAS_PYLIBSSH: + self._ssh_type = "libssh" + else: + self.queue_message( + "warning", + "ansible-pylibssh not installed, falling back to paramiko", + ) + self._ssh_type = "paramiko" + self.queue_message( + "vvvv", "ssh type is now set to %s" % self._ssh_type + ) + + if self._ssh_type not in ["paramiko", "libssh"]: + raise AnsibleConnectionFailure( + "Invalid value '%s' set for ssh_type option." + " Expected value is either 'libssh' or 'paramiko'" + % self._ssh_type + ) + + return self._ssh_type + + @property + def ssh_type_conn(self): + if self._ssh_type_conn is None: + if self.ssh_type == "libssh": + connection_plugin = "ansible.netcommon.libssh" + elif self.ssh_type == "paramiko": + # NOTE: This MUST be paramiko or things will break + connection_plugin = "paramiko" + else: + raise AnsibleConnectionFailure( + "Invalid value '%s' set for ssh_type option." + " Expected value is either 'libssh' or 'paramiko'" + % self._ssh_type + ) + + self._ssh_type_conn = connection_loader.get( + connection_plugin, self._play_context, "/dev/null" + ) + + return self._ssh_type_conn + + # To maintain backward compatibility + @property + def paramiko_conn(self): + return self.ssh_type_conn + + def _get_log_channel(self): + name = "p=%s u=%s | " % (os.getpid(), getpass.getuser()) + name += "%s [%s]" % (self.ssh_type, self._play_context.remote_addr) + return name + + @ensure_connect + def get_prompt(self): + """Returns the current prompt from the device""" + return self._matched_prompt + + def exec_command(self, cmd, in_data=None, sudoable=True): + # this try..except block is just to handle the transition to supporting + # network_cli as a toplevel connection. Once connection=local is gone, + # this block can be removed as well and all calls passed directly to + # the local connection + if self._ssh_shell: + try: + cmd = json.loads(to_text(cmd, errors="surrogate_or_strict")) + kwargs = { + "command": to_bytes( + cmd["command"], errors="surrogate_or_strict" + ) + } + for key in ( + "prompt", + "answer", + "sendonly", + "newline", + "prompt_retry_check", + ): + if cmd.get(key) is True or cmd.get(key) is False: + kwargs[key] = cmd[key] + elif cmd.get(key) is not None: + kwargs[key] = to_bytes( + cmd[key], errors="surrogate_or_strict" + ) + return self.send(**kwargs) + except ValueError: + cmd = to_bytes(cmd, errors="surrogate_or_strict") + return self.send(command=cmd) + + else: + return super(Connection, self).exec_command(cmd, in_data, sudoable) + + def get_options(self, hostvars=None): + options = super(Connection, self).get_options(hostvars=hostvars) + options.update(self.ssh_type_conn.get_options(hostvars=hostvars)) + return options + + def set_options(self, task_keys=None, var_options=None, direct=None): + super(Connection, self).set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + self.ssh_type_conn.set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + # Retain old look_for_keys behaviour, but only if not set + if not any( + [ + task_keys and ("look_for_keys" in task_keys), + var_options and ("look_for_keys" in var_options), + direct and ("look_for_keys" in direct), + ] + ): + look_for_keys = not bool( + self.get_option("password") + and not self.get_option("private_key_file") + ) + if not look_for_keys: + # This actually can't be overridden yet without changes in ansible-core + # TODO: Uncomment when appropriate + # self.queue_message( + # "warning", + # "Option look_for_keys has been implicitly set to {0} because " + # "it was not set explicitly. This is retained to maintain " + # "backwards compatibility with the old behavior. This behavior " + # "will be removed in some release after 2024-01-01".format( + # look_for_keys + # ), + # ) + self.ssh_type_conn.set_option("look_for_keys", look_for_keys) + + def update_play_context(self, pc_data): + """Updates the play context information for the connection""" + pc_data = to_bytes(pc_data) + if PY3: + pc_data = cPickle.loads(pc_data, encoding="bytes") + else: + pc_data = cPickle.loads(pc_data) + play_context = PlayContext() + play_context.deserialize(pc_data) + + self.queue_message("vvvv", "updating play_context for connection") + if self._play_context.become ^ play_context.become: + if play_context.become is True: + auth_pass = play_context.become_pass + self._on_become(become_pass=auth_pass) + self.queue_message("vvvv", "authorizing connection") + else: + self._terminal.on_unbecome() + self.queue_message("vvvv", "deauthorizing connection") + + self._play_context = play_context + if self._ssh_type_conn is not None: + # TODO: This works, but is not really ideal. We would rather use + # set_options, but then we need more custom handling in that + # method. + self._ssh_type_conn._play_context = play_context + + if hasattr(self, "reset_history"): + self.reset_history() + if hasattr(self, "disable_response_logging"): + self.disable_response_logging() + + self._single_user_mode = self.get_option("single_user_mode") + + def set_check_prompt(self, task_uuid): + self._check_prompt = task_uuid + + def update_cli_prompt_context(self): + # set cli prompt context at the start of new task run only + if self._check_prompt and self._task_uuid != self._check_prompt: + self._task_uuid, self._check_prompt = self._check_prompt, False + self.set_cli_prompt_context() + + def _connect(self): + """ + Connects to the remote device and starts the terminal + """ + if self._play_context.verbosity > 3: + logging.getLogger(self.ssh_type).setLevel(logging.DEBUG) + + self.queue_message( + "vvvv", "invoked shell using ssh_type: %s" % self.ssh_type + ) + + self._single_user_mode = self.get_option("single_user_mode") + + if not self.connected: + self.ssh_type_conn._set_log_channel(self._get_log_channel()) + self.ssh_type_conn.force_persistence = self.force_persistence + + command_timeout = self.get_option("persistent_command_timeout") + max_pause = min( + [ + self.get_option("persistent_connect_timeout"), + command_timeout, + ] + ) + retries = self.get_option("network_cli_retries") + total_pause = 0 + + for attempt in range(retries + 1): + try: + ssh = self.ssh_type_conn._connect() + break + except AnsibleError: + raise + except Exception as e: + pause = 2 ** (attempt + 1) + if attempt == retries or total_pause >= max_pause: + raise AnsibleConnectionFailure( + to_text(e, errors="surrogate_or_strict") + ) + else: + msg = ( + "network_cli_retry: attempt: %d, caught exception(%s), " + "pausing for %d seconds" + % ( + attempt + 1, + to_text(e, errors="surrogate_or_strict"), + pause, + ) + ) + + self.queue_message("vv", msg) + time.sleep(pause) + total_pause += pause + continue + + self.queue_message("vvvv", "ssh connection done, setting terminal") + self._connected = True + + self._ssh_shell = ssh.ssh.invoke_shell() + if self.ssh_type == "paramiko": + self._ssh_shell.settimeout(command_timeout) + + self.queue_message( + "vvvv", + "loaded terminal plugin for network_os %s" % self._network_os, + ) + + terminal_initial_prompt = ( + self.get_option("terminal_initial_prompt") + or self._terminal.terminal_initial_prompt + ) + terminal_initial_answer = ( + self.get_option("terminal_initial_answer") + or self._terminal.terminal_initial_answer + ) + newline = ( + self.get_option("terminal_inital_prompt_newline") + or self._terminal.terminal_inital_prompt_newline + ) + check_all = ( + self.get_option("terminal_initial_prompt_checkall") or False + ) + + self.receive( + prompts=terminal_initial_prompt, + answer=terminal_initial_answer, + newline=newline, + check_all=check_all, + ) + + if self._play_context.become: + self.queue_message("vvvv", "firing event: on_become") + auth_pass = self._play_context.become_pass + self._on_become(become_pass=auth_pass) + + self.queue_message("vvvv", "firing event: on_open_shell()") + self._on_open_shell() + + self.queue_message( + "vvvv", "ssh connection has completed successfully" + ) + + return self + + def _on_become(self, become_pass=None): + """ + Wraps terminal.on_become() to handle + privilege escalation failures based on user preference + """ + on_become_error = self.get_option("become_errors") + try: + self._terminal.on_become(passwd=become_pass) + except AnsibleConnectionFailure: + if on_become_error == "ignore": + pass + elif on_become_error == "warn": + self.queue_message( + "warning", "on_become: privilege escalation failed" + ) + else: + raise + + def _on_open_shell(self): + """ + Wraps terminal.on_open_shell() to handle + terminal setting failures based on user preference + """ + on_terminal_error = self.get_option("terminal_errors") + try: + self._terminal.on_open_shell() + except AnsibleConnectionFailure: + if on_terminal_error == "ignore": + pass + elif on_terminal_error == "warn": + self.queue_message( + "warning", + "on_open_shell: failed to set terminal parameters", + ) + else: + raise + + def close(self): + """ + Close the active connection to the device + """ + # only close the connection if its connected. + if self._connected: + self.queue_message("debug", "closing ssh connection to device") + if self._ssh_shell: + self.queue_message("debug", "firing event: on_close_shell()") + self._terminal.on_close_shell() + self._ssh_shell.close() + self._ssh_shell = None + self.queue_message("debug", "cli session is now closed") + + self.ssh_type_conn.close() + self._ssh_type_conn = None + self.queue_message( + "debug", "ssh connection has been closed successfully" + ) + super(Connection, self).close() + + def _read_post_command_prompt_match(self): + time.sleep(self.get_option("persistent_buffer_read_timeout")) + data = self._ssh_shell.read_bulk_response() + return data if data else None + + def receive_paramiko( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + + recv = BytesIO() + cache_socket_timeout = self.get_option("persistent_command_timeout") + self._ssh_shell.settimeout(cache_socket_timeout) + command_prompt_matched = False + handled = False + errored_response = None + + while True: + if command_prompt_matched: + try: + signal.signal( + signal.SIGALRM, self._handle_buffer_read_timeout + ) + signal.setitimer( + signal.ITIMER_REAL, self._buffer_read_timeout + ) + data = self._ssh_shell.recv(256) + signal.alarm(0) + self._log_messages( + "response-%s: %s" % (self._window_count + 1, data) + ) + # if data is still received on channel it indicates the prompt string + # is wrongly matched in between response chunks, continue to read + # remaining response. + command_prompt_matched = False + + # restart command_timeout timer + signal.signal(signal.SIGALRM, self._handle_command_timeout) + signal.alarm(self._command_timeout) + + except AnsibleCmdRespRecv: + # reset socket timeout to global timeout + return self._command_response + else: + data = self._ssh_shell.recv(256) + self._log_messages( + "response-%s: %s" % (self._window_count + 1, data) + ) + # when a channel stream is closed, received data will be empty + if not data: + break + + recv.write(data) + offset = recv.tell() - 256 if recv.tell() > 256 else 0 + recv.seek(offset) + + window = self._strip(recv.read()) + self._last_recv_window = window + self._window_count += 1 + + if prompts and not handled: + handled = self._handle_prompt( + window, prompts, answer, newline, False, check_all + ) + self._matched_prompt_window = self._window_count + elif ( + prompts + and handled + and prompt_retry_check + and self._matched_prompt_window + 1 == self._window_count + ): + # check again even when handled, if same prompt repeats in next window + # (like in the case of a wrong enable password, etc) indicates + # value of answer is wrong, report this as error. + if self._handle_prompt( + window, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + ): + raise AnsibleConnectionFailure( + "For matched prompt '%s', answer is not valid" + % self._matched_cmd_prompt + ) + + if self._find_error(window): + # We can't exit here, as we need to drain the buffer in case + # the error isn't fatal, and will be using the buffer again + errored_response = window + + if self._find_prompt(window): + if errored_response: + raise AnsibleConnectionFailure(errored_response) + self._last_response = recv.getvalue() + resp = self._strip(self._last_response) + self._command_response = self._sanitize( + resp, command, strip_prompt + ) + if self._buffer_read_timeout == 0.0: + # reset socket timeout to global timeout + return self._command_response + else: + command_prompt_matched = True + + def receive_libssh( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + self._command_response = resp = b"" + command_prompt_matched = False + handled = False + errored_response = None + + while True: + + if command_prompt_matched: + data = self._read_post_command_prompt_match() + if data: + command_prompt_matched = False + else: + return self._command_response + else: + try: + data = self._ssh_shell.read_bulk_response() + # TODO: Should be ConnectionError when pylibssh drops Python 2 support + except OSError: + # Socket has closed + break + + if not data: + continue + self._last_recv_window = self._strip(data) + resp += self._last_recv_window + self._window_count += 1 + + self._log_messages("response-%s: %s" % (self._window_count, data)) + + if prompts and not handled: + handled = self._handle_prompt( + resp, prompts, answer, newline, False, check_all + ) + self._matched_prompt_window = self._window_count + elif ( + prompts + and handled + and prompt_retry_check + and self._matched_prompt_window + 1 == self._window_count + ): + # check again even when handled, if same prompt repeats in next window + # (like in the case of a wrong enable password, etc) indicates + # value of answer is wrong, report this as error. + if self._handle_prompt( + resp, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + ): + raise AnsibleConnectionFailure( + "For matched prompt '%s', answer is not valid" + % self._matched_cmd_prompt + ) + + if self._find_error(resp): + # We can't exit here, as we need to drain the buffer in case + # the error isn't fatal, and will be using the buffer again + errored_response = resp + + if self._find_prompt(resp): + if errored_response: + raise AnsibleConnectionFailure(errored_response) + self._last_response = data + self._command_response += self._sanitize( + resp, command, strip_prompt + ) + command_prompt_matched = True + + def receive( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + """ + Handles receiving of output from command + """ + self._matched_prompt = None + self._matched_cmd_prompt = None + self._matched_prompt_window = 0 + self._window_count = 0 + + # set terminal regex values for command prompt and errors in response + self._terminal_stderr_re = self._get_terminal_std_re( + "terminal_stderr_re" + ) + self._terminal_stdout_re = self._get_terminal_std_re( + "terminal_stdout_re" + ) + + self._command_timeout = self.get_option("persistent_command_timeout") + self._validate_timeout_value( + self._command_timeout, "persistent_command_timeout" + ) + + self._buffer_read_timeout = self.get_option( + "persistent_buffer_read_timeout" + ) + self._validate_timeout_value( + self._buffer_read_timeout, "persistent_buffer_read_timeout" + ) + + self._log_messages("command: %s" % command) + if self.ssh_type == "libssh": + response = self.receive_libssh( + command, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + elif self.ssh_type == "paramiko": + response = self.receive_paramiko( + command, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + + return response + + @ensure_connect + def send( + self, + command, + prompt=None, + answer=None, + newline=True, + sendonly=False, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + """ + Sends the command to the device in the opened shell + """ + # try cache first + if (not prompt) and (self._single_user_mode): + out = self.get_cache().lookup(command) + if out: + self.queue_message( + "vvvv", "cache hit for command: %s" % command + ) + return out + + if check_all: + prompt_len = len(to_list(prompt)) + answer_len = len(to_list(answer)) + if prompt_len != answer_len: + raise AnsibleConnectionFailure( + "Number of prompts (%s) is not same as that of answers (%s)" + % (prompt_len, answer_len) + ) + try: + cmd = b"%s\r" % command + self._history.append(cmd) + self._ssh_shell.sendall(cmd) + self._log_messages("send command: %s" % cmd) + if sendonly: + return + response = self.receive( + command, + prompt, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + response = to_text(response, errors="surrogate_then_replace") + + if (not prompt) and (self._single_user_mode): + if self._needs_cache_invalidation(command): + # invalidate the existing cache + if self.get_cache().keys(): + self.queue_message( + "vvvv", "invalidating existing cache" + ) + self.get_cache().invalidate() + else: + # populate cache + self.queue_message( + "vvvv", "populating cache for command: %s" % command + ) + self.get_cache().populate(command, response) + + return response + except (socket.timeout, AttributeError): + self.queue_message("error", traceback.format_exc()) + raise AnsibleConnectionFailure( + "timeout value %s seconds reached while trying to send command: %s" + % (self._ssh_shell.gettimeout(), command.strip()) + ) + + def _handle_buffer_read_timeout(self, signum, frame): + self.queue_message( + "vvvv", + "Response received, triggered 'persistent_buffer_read_timeout' timer of %s seconds" + % self.get_option("persistent_buffer_read_timeout"), + ) + raise AnsibleCmdRespRecv() + + def _handle_command_timeout(self, signum, frame): + msg = ( + "command timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and Troubleshooting Guide." + % self.get_option("persistent_command_timeout") + ) + self.queue_message("log", msg) + raise AnsibleConnectionFailure(msg) + + def _strip(self, data): + """ + Removes ANSI codes from device response + """ + for regex in self._terminal.ansi_re: + data = regex.sub(b"", data) + return data + + def _handle_prompt( + self, + resp, + prompts, + answer, + newline, + prompt_retry_check=False, + check_all=False, + ): + """ + Matches the command prompt and responds + + :arg resp: Byte string containing the raw response from the remote + :arg prompts: Sequence of byte strings that we consider prompts for input + :arg answer: Sequence of Byte string to send back to the remote if we find a prompt. + A carriage return is automatically appended to this string. + :param prompt_retry_check: Bool value for trying to detect more prompts + :param check_all: Bool value to indicate if all the values in prompt sequence should be matched or any one of + given prompt. + :returns: True if a prompt was found in ``resp``. If check_all is True + will True only after all the prompt in the prompts list are matched. False otherwise. + """ + single_prompt = False + if not isinstance(prompts, list): + prompts = [prompts] + single_prompt = True + if not isinstance(answer, list): + answer = [answer] + try: + prompts_regex = [re.compile(to_bytes(r), re.I) for r in prompts] + except re.error as exc: + raise ConnectionError( + "Failed to compile one or more terminal prompt regexes: %s.\n" + "Prompts provided: %s" % (to_text(exc), prompts) + ) + for index, regex in enumerate(prompts_regex): + match = regex.search(resp) + if match: + self._matched_cmd_prompt = match.group() + self._log_messages( + "matched command prompt: %s" % self._matched_cmd_prompt + ) + + # if prompt_retry_check is enabled to check if same prompt is + # repeated don't send answer again. + if not prompt_retry_check: + prompt_answer = to_bytes( + answer[index] if len(answer) > index else answer[0] + ) + if newline: + prompt_answer += b"\r" + self._ssh_shell.sendall(prompt_answer) + self._log_messages( + "matched command prompt answer: %s" % prompt_answer + ) + if check_all and prompts and not single_prompt: + prompts.pop(0) + answer.pop(0) + return False + return True + return False + + def _sanitize(self, resp, command=None, strip_prompt=True): + """ + Removes elements from the response before returning to the caller + """ + cleaned = [] + for line in resp.splitlines(): + if command and line.strip() == command.strip(): + continue + + for prompt in self._matched_prompt.strip().splitlines(): + if prompt.strip() in line and strip_prompt: + break + else: + cleaned.append(line) + + return b"\n".join(cleaned).strip() + + def _find_error(self, response): + """Searches the buffered response for a matching error condition""" + for stderr_regex in self._terminal_stderr_re: + if stderr_regex.search(response): + self._log_messages( + "matched error regex (terminal_stderr_re) '%s' from response '%s'" + % (stderr_regex.pattern, response) + ) + + self._log_messages( + "matched stdout regex (terminal_stdout_re) '%s' from error response '%s'" + % (self._matched_pattern, response) + ) + return True + + return False + + def _find_prompt(self, response): + """Searches the buffered response for a matching command prompt""" + for stdout_regex in self._terminal_stdout_re: + match = stdout_regex.search(response) + if match: + self._matched_pattern = stdout_regex.pattern + self._matched_prompt = match.group() + self._log_messages( + "matched cli prompt '%s' with regex '%s' from response '%s'" + % (self._matched_prompt, self._matched_pattern, response) + ) + return True + + return False + + def _validate_timeout_value(self, timeout, timer_name): + if timeout < 0: + raise AnsibleConnectionFailure( + "'%s' timer value '%s' is invalid, value should be greater than or equal to zero." + % (timer_name, timeout) + ) + + def transport_test(self, connect_timeout): + """This method enables wait_for_connection to work. + + As it is used by wait_for_connection, it is called by that module's action plugin, + which is on the controller process, which means that nothing done on this instance + should impact the actual persistent connection... this check is for informational + purposes only and should be properly cleaned up. + """ + + # Force a fresh connect if for some reason we have connected before. + self.close() + self._connect() + self.close() + + def _get_terminal_std_re(self, option): + terminal_std_option = self.get_option(option) + terminal_std_re = [] + + if terminal_std_option: + for item in terminal_std_option: + if "pattern" not in item: + raise AnsibleConnectionFailure( + "'pattern' is a required key for option '%s'," + " received option value is %s" % (option, item) + ) + pattern = rb"%s" % to_bytes(item["pattern"]) + flag = item.get("flags", 0) + if flag: + flag = getattr(re, flag.split(".")[1]) + terminal_std_re.append(re.compile(pattern, flag)) + else: + # To maintain backward compatibility + terminal_std_re = getattr(self._terminal, option) + + return terminal_std_re + + def copy_file( + self, source=None, destination=None, proto="scp", timeout=30 + ): + """Copies file over scp/sftp to remote device + + :param source: Source file path + :param destination: Destination file path on remote device + :param proto: Protocol to be used for file transfer, + supported protocol: scp and sftp + :param timeout: Specifies the wait time to receive response from + remote host before triggering timeout exception + :return: None + """ + ssh = self.ssh_type_conn._connect_uncached() + if self.ssh_type == "libssh": + self.ssh_type_conn.put_file(source, destination, proto=proto) + elif self.ssh_type == "paramiko": + if proto == "scp": + if not HAS_SCP: + raise AnsibleError(missing_required_lib("scp")) + with SCPClient( + ssh.get_transport(), socket_timeout=timeout + ) as scp: + scp.put(source, destination) + elif proto == "sftp": + with ssh.open_sftp() as sftp: + sftp.put(source, destination) + else: + raise AnsibleError( + "Do not know how to do transfer file over protocol %s" + % proto + ) + else: + raise AnsibleError( + "Do not know how to do SCP with ssh_type %s" % self.ssh_type + ) + + def get_file(self, source=None, destination=None, proto="scp", timeout=30): + """Fetch file over scp/sftp from remote device + :param source: Source file path + :param destination: Destination file path + :param proto: Protocol to be used for file transfer, + supported protocol: scp and sftp + :param timeout: Specifies the wait time to receive response from + remote host before triggering timeout exception + :return: None + """ + """Fetch file over scp/sftp from remote device""" + ssh = self.ssh_type_conn._connect_uncached() + if self.ssh_type == "libssh": + self.ssh_type_conn.fetch_file(source, destination, proto=proto) + elif self.ssh_type == "paramiko": + if proto == "scp": + if not HAS_SCP: + raise AnsibleError(missing_required_lib("scp")) + try: + with SCPClient( + ssh.get_transport(), socket_timeout=timeout + ) as scp: + scp.get(source, destination) + except EOFError: + # This appears to be benign. + pass + elif proto == "sftp": + with ssh.open_sftp() as sftp: + sftp.get(source, destination) + else: + raise AnsibleError( + "Do not know how to do transfer file over protocol %s" + % proto + ) + else: + raise AnsibleError( + "Do not know how to do SCP with ssh_type %s" % self.ssh_type + ) + + def get_cache(self): + if not self._cache: + # TO-DO: support jsonfile or other modes of caching with + # a configurable option + self._cache = cache_loader.get("ansible.netcommon.memory") + return self._cache + + def _is_in_config_mode(self): + """ + Check if the target device is in config mode by comparing + the current prompt with the platform's `terminal_config_prompt`. + Returns False if `terminal_config_prompt` is not defined. + + :returns: A boolean indicating if the device is in config mode or not. + """ + cfg_mode = False + cur_prompt = to_text( + self.get_prompt(), errors="surrogate_then_replace" + ).strip() + cfg_prompt = getattr(self._terminal, "terminal_config_prompt", None) + if cfg_prompt and cfg_prompt.match(cur_prompt): + cfg_mode = True + return cfg_mode + + def _needs_cache_invalidation(self, command): + """ + This method determines if it is necessary to invalidate + the existing cache based on whether the device has entered + configuration mode or if the last command sent to the device + is potentially capable of making configuration changes. + + :param command: The last command sent to the target device. + :returns: A boolean indicating if cache invalidation is required or not. + """ + invalidate = False + cfg_cmds = [] + try: + # AnsiblePlugin base class in Ansible 2.9 does not have has_option() method. + # TO-DO: use has_option() when we drop 2.9 support. + cfg_cmds = self.cliconf.get_option("config_commands") + except AttributeError: + cfg_cmds = [] + if (self._is_in_config_mode()) or (to_text(command) in cfg_cmds): + invalidate = True + return invalidate diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py new file mode 100644 index 0000000..b29b487 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py @@ -0,0 +1,97 @@ +# 2017 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Core Team +connection: persistent +short_description: Use a persistent unix socket for connection +description: +- This is a helper plugin to allow making other connections persistent. +options: + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close + default: 10 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout +""" +from ansible.executor.task_executor import start_connection +from ansible.plugins.connection import ConnectionBase +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import Connection as SocketConnection +from ansible.utils.display import Display + +display = Display() + + +class Connection(ConnectionBase): + """ Local based connections """ + + transport = "ansible.netcommon.persistent" + has_pipelining = False + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._task_uuid = to_text(kwargs.get("task_uuid", "")) + + def _connect(self): + self._connected = True + return self + + def exec_command(self, cmd, in_data=None, sudoable=True): + display.vvvv( + "exec_command(), socket_path=%s" % self.socket_path, + host=self._play_context.remote_addr, + ) + connection = SocketConnection(self.socket_path) + out = connection.exec_command(cmd, in_data=in_data, sudoable=sudoable) + return 0, out, "" + + def put_file(self, in_path, out_path): + pass + + def fetch_file(self, in_path, out_path): + pass + + def close(self): + self._connected = False + + def run(self): + """Returns the path of the persistent connection socket. + + Attempts to ensure (within playcontext.timeout seconds) that the + socket path exists. If the path exists (or the timeout has expired), + returns the socket path. + """ + display.vvvv( + "starting connection from persistent connection plugin", + host=self._play_context.remote_addr, + ) + variables = { + "ansible_command_timeout": self.get_option( + "persistent_command_timeout" + ) + } + socket_path = start_connection( + self._play_context, variables, self._task_uuid + ) + display.vvvv( + "local domain socket path is %s" % socket_path, + host=self._play_context.remote_addr, + ) + setattr(self, "_socket_path", socket_path) + return socket_path diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py new file mode 100644 index 0000000..d572c30 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r""" +options: + import_modules: + type: boolean + description: + - Reduce CPU usage and network module execution time + by enabling direct execution. Instead of the module being packaged + and executed by the shell, it will be directly executed by the Ansible + control node using the same python interpreter as the Ansible process. + Note- Incompatible with C(asynchronous mode). + Note- Python 3 and Ansible 2.9.16 or greater required. + Note- With Ansible 2.9.x fully qualified modules names are required in tasks. + default: true + ini: + - section: ansible_network + key: import_modules + env: + - name: ANSIBLE_NETWORK_IMPORT_MODULES + vars: + - name: ansible_network_import_modules + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to + return from the remote device. If this timer is exceeded before the + command returns, the connection plugin will raise an exception and + close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this + option as it could create a security vulnerability by logging sensitive information in log file. + default: False + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py new file mode 100644 index 0000000..8789075 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: + host: + description: + - Specifies the DNS host name or address for connecting to the remote device over + the specified transport. The value of host is used as the destination address + for the transport. + type: str + required: true + port: + description: + - Specifies the port to use when building the connection to the remote device. The + port value will default to port 830. + type: int + default: 830 + username: + description: + - Configures the username to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value is + not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME) + will be used instead. + type: str + password: + description: + - Specifies the password to use to authenticate the connection to the remote device. This + value is used to authenticate the SSH session. If the value is not specified + in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) will + be used instead. + type: str + timeout: + description: + - Specifies the timeout in seconds for communicating with the network device for + either connecting or sending commands. If the timeout is exceeded before the + operation is completed, the module will error. + type: int + default: 10 + ssh_keyfile: + description: + - Specifies the SSH key to use to authenticate the connection to the remote device. This + value is the path to the key used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_SSH_KEYFILE) + will be used instead. + type: path + hostkey_verify: + description: + - If set to C(yes), the ssh host key of the device must match a ssh key present + on the host if set to C(no), the ssh host key of the device is not checked. + type: bool + default: true + look_for_keys: + description: + - Enables looking in the usual locations for the ssh keys (e.g. :file:`~/.ssh/id_*`) + type: bool + default: true +notes: +- For information on using netconf see the :ref:`Platform Options guide using Netconf` +- For more information on using Ansible to manage network devices see the :ref:`Ansible + Network Guide ` +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py new file mode 100644 index 0000000..ad65f6e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019 Ansible, Inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: {} +notes: +- This module is supported on C(ansible_network_os) network platforms. See the :ref:`Network + Platform Options ` for details. +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py new file mode 100644 index 0000000..6ae47a7 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py @@ -0,0 +1,1186 @@ +# (c) 2014, Maciej Delmanowski +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from functools import partial +import types + +try: + import netaddr +except ImportError: + # in this case, we'll make the filters return error messages (see bottom) + netaddr = None +else: + + class mac_linux(netaddr.mac_unix): + pass + + mac_linux.word_fmt = "%.2x" + +from ansible import errors + + +# ---- IP address and network query helpers ---- +def _empty_ipaddr_query(v, vtype): + # We don't have any query to process, so just check what type the user + # expects, and return the IP address in a correct format + if v: + if vtype == "address": + return str(v.ip) + elif vtype == "network": + return str(v) + + +def _first_last(v): + if v.size == 2: + first_usable = int(netaddr.IPAddress(v.first)) + last_usable = int(netaddr.IPAddress(v.last)) + return first_usable, last_usable + elif v.size > 1: + first_usable = int(netaddr.IPAddress(v.first + 1)) + last_usable = int(netaddr.IPAddress(v.last - 1)) + return first_usable, last_usable + + +def _6to4_query(v, vtype, value): + if v.version == 4: + + if v.size == 1: + ipconv = str(v.ip) + elif v.size > 1: + if v.ip != v.network: + ipconv = str(v.ip) + else: + ipconv = False + + if ipaddr(ipconv, "public"): + numbers = list(map(int, ipconv.split("."))) + + try: + return "2002:{:02x}{:02x}:{:02x}{:02x}::1/48".format(*numbers) + except Exception: + return False + + elif v.version == 6: + if vtype == "address": + if ipaddr(str(v), "2002::/16"): + return value + elif vtype == "network": + if v.ip != v.network: + if ipaddr(str(v.ip), "2002::/16"): + return value + else: + return False + + +def _ip_query(v): + if v.size == 1: + return str(v.ip) + if v.size > 1: + # /31 networks in netaddr have no broadcast address + if v.ip != v.network or not v.broadcast: + return str(v.ip) + + +def _gateway_query(v): + if v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _address_prefix_query(v): + if v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _bool_ipaddr_query(v): + if v: + return True + + +def _broadcast_query(v): + if v.size > 2: + return str(v.broadcast) + + +def _cidr_query(v): + return str(v) + + +def _cidr_lookup_query(v, iplist, value): + try: + if v in iplist: + return value + except Exception: + return False + + +def _first_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size == 2: + return str(netaddr.IPAddress(int(v.network))) + elif v.size > 1: + return str(netaddr.IPAddress(int(v.network) + 1)) + + +def _host_query(v): + if v.size == 1: + return str(v) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _hostmask_query(v): + return str(v.hostmask) + + +def _int_query(v, vtype): + if vtype == "address": + return int(v.ip) + elif vtype == "network": + return str(int(v.ip)) + "/" + str(int(v.prefixlen)) + + +def _ip_prefix_query(v): + if v.size == 2: + return str(v.ip) + "/" + str(v.prefixlen) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _ip_netmask_query(v): + if v.size == 2: + return str(v.ip) + " " + str(v.netmask) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + " " + str(v.netmask) + + +""" +def _ip_wildcard_query(v): + if v.size == 2: + return str(v.ip) + ' ' + str(v.hostmask) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + ' ' + str(v.hostmask) +""" + + +def _ipv4_query(v, value): + if v.version == 6: + try: + return str(v.ipv4()) + except Exception: + return False + else: + return value + + +def _ipv6_query(v, value): + if v.version == 4: + return str(v.ipv6()) + else: + return value + + +def _last_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + return str(netaddr.IPAddress(last_usable)) + + +def _link_local_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if v.version == 4: + if ipaddr(str(v_ip), "169.254.0.0/24"): + return value + + elif v.version == 6: + if ipaddr(str(v_ip), "fe80::/10"): + return value + + +def _loopback_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if v_ip.is_loopback(): + return value + + +def _multicast_query(v, value): + if v.is_multicast(): + return value + + +def _net_query(v): + if v.size > 1: + if v.ip == v.network: + return str(v.network) + "/" + str(v.prefixlen) + + +def _netmask_query(v): + return str(v.netmask) + + +def _network_query(v): + """Return the network of a given IP or subnet""" + return str(v.network) + + +def _network_id_query(v): + """Return the network of a given IP or subnet""" + return str(v.network) + + +def _network_netmask_query(v): + return str(v.network) + " " + str(v.netmask) + + +def _network_wildcard_query(v): + return str(v.network) + " " + str(v.hostmask) + + +def _next_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + next_ip = int(netaddr.IPAddress(int(v.ip) + 1)) + if next_ip >= first_usable and next_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) + 1)) + + +def _peer_query(v, vtype): + if vtype == "address": + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size == 2: + return str(netaddr.IPAddress(int(v.ip) ^ 1)) + if v.size == 4: + if int(v.ip) % 4 == 0: + raise errors.AnsibleFilterError( + "Network address of /30 has no peer" + ) + if int(v.ip) % 4 == 3: + raise errors.AnsibleFilterError( + "Broadcast address of /30 has no peer" + ) + return str(netaddr.IPAddress(int(v.ip) ^ 3)) + raise errors.AnsibleFilterError("Not a point-to-point network") + + +def _prefix_query(v): + return int(v.prefixlen) + + +def _previous_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + previous_ip = int(netaddr.IPAddress(int(v.ip) - 1)) + if previous_ip >= first_usable and previous_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) - 1)) + + +def _private_query(v, value): + if v.is_private(): + return value + + +def _public_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if ( + v_ip.is_unicast() + and not v_ip.is_private() + and not v_ip.is_loopback() + and not v_ip.is_netmask() + and not v_ip.is_hostmask() + ): + return value + + +def _range_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + first_usable = str(netaddr.IPAddress(first_usable)) + last_usable = str(netaddr.IPAddress(last_usable)) + return "{0}-{1}".format(first_usable, last_usable) + + +def _revdns_query(v): + v_ip = netaddr.IPAddress(str(v.ip)) + return v_ip.reverse_dns + + +def _size_query(v): + return v.size + + +def _size_usable_query(v): + if v.size == 1: + return 0 + elif v.size == 2: + return 2 + return v.size - 2 + + +def _subnet_query(v): + return str(v.cidr) + + +def _type_query(v): + if v.size == 1: + return "address" + if v.size > 1: + if v.ip != v.network: + return "address" + else: + return "network" + + +def _unicast_query(v, value): + if v.is_unicast(): + return value + + +def _version_query(v): + return v.version + + +def _wrap_query(v, vtype, value): + if v.version == 6: + if vtype == "address": + return "[" + str(v.ip) + "]" + elif vtype == "network": + return "[" + str(v.ip) + "]/" + str(v.prefixlen) + else: + return value + + +# ---- HWaddr query helpers ---- +def _bare_query(v): + v.dialect = netaddr.mac_bare + return str(v) + + +def _bool_hwaddr_query(v): + if v: + return True + + +def _int_hwaddr_query(v): + return int(v) + + +def _cisco_query(v): + v.dialect = netaddr.mac_cisco + return str(v) + + +def _empty_hwaddr_query(v, value): + if v: + return value + + +def _linux_query(v): + v.dialect = mac_linux + return str(v) + + +def _postgresql_query(v): + v.dialect = netaddr.mac_pgsql + return str(v) + + +def _unix_query(v): + v.dialect = netaddr.mac_unix + return str(v) + + +def _win_query(v): + v.dialect = netaddr.mac_eui48 + return str(v) + + +# ---- IP address and network filters ---- + +# Returns a minified list of subnets or a single subnet that spans all of +# the inputs. +def cidr_merge(value, action="merge"): + if not hasattr(value, "__iter__"): + raise errors.AnsibleFilterError( + "cidr_merge: expected iterable, got " + repr(value) + ) + + if action == "merge": + try: + return [str(ip) for ip in netaddr.cidr_merge(value)] + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + + elif action == "span": + # spanning_cidr needs at least two values + if len(value) == 0: + return None + elif len(value) == 1: + try: + return str(netaddr.IPNetwork(value[0])) + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + else: + try: + return str(netaddr.spanning_cidr(value)) + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + + else: + raise errors.AnsibleFilterError( + "cidr_merge: invalid action '%s'" % action + ) + + +def ipaddr(value, query="", version=False, alias="ipaddr"): + """ Check if string is an IP address or network and filter it """ + + query_func_extra_args = { + "": ("vtype",), + "6to4": ("vtype", "value"), + "cidr_lookup": ("iplist", "value"), + "first_usable": ("vtype",), + "int": ("vtype",), + "ipv4": ("value",), + "ipv6": ("value",), + "last_usable": ("vtype",), + "link-local": ("value",), + "loopback": ("value",), + "lo": ("value",), + "multicast": ("value",), + "next_usable": ("vtype",), + "peer": ("vtype",), + "previous_usable": ("vtype",), + "private": ("value",), + "public": ("value",), + "unicast": ("value",), + "range_usable": ("vtype",), + "wrap": ("vtype", "value"), + } + + query_func_map = { + "": _empty_ipaddr_query, + "6to4": _6to4_query, + "address": _ip_query, + "address/prefix": _address_prefix_query, # deprecate + "bool": _bool_ipaddr_query, + "broadcast": _broadcast_query, + "cidr": _cidr_query, + "cidr_lookup": _cidr_lookup_query, + "first_usable": _first_usable_query, + "gateway": _gateway_query, # deprecate + "gw": _gateway_query, # deprecate + "host": _host_query, + "host/prefix": _address_prefix_query, # deprecate + "hostmask": _hostmask_query, + "hostnet": _gateway_query, # deprecate + "int": _int_query, + "ip": _ip_query, + "ip/prefix": _ip_prefix_query, + "ip_netmask": _ip_netmask_query, + # 'ip_wildcard': _ip_wildcard_query, built then could not think of use case + "ipv4": _ipv4_query, + "ipv6": _ipv6_query, + "last_usable": _last_usable_query, + "link-local": _link_local_query, + "lo": _loopback_query, + "loopback": _loopback_query, + "multicast": _multicast_query, + "net": _net_query, + "next_usable": _next_usable_query, + "netmask": _netmask_query, + "network": _network_query, + "network_id": _network_id_query, + "network/prefix": _subnet_query, + "network_netmask": _network_netmask_query, + "network_wildcard": _network_wildcard_query, + "peer": _peer_query, + "prefix": _prefix_query, + "previous_usable": _previous_usable_query, + "private": _private_query, + "public": _public_query, + "range_usable": _range_usable_query, + "revdns": _revdns_query, + "router": _gateway_query, # deprecate + "size": _size_query, + "size_usable": _size_usable_query, + "subnet": _subnet_query, + "type": _type_query, + "unicast": _unicast_query, + "v4": _ipv4_query, + "v6": _ipv6_query, + "version": _version_query, + "wildcard": _hostmask_query, + "wrap": _wrap_query, + } + + vtype = None + + if not value: + return False + + elif value is True: + return False + + # Check if value is a list and parse each element + elif isinstance(value, (list, tuple, types.GeneratorType)): + + _ret = [] + for element in value: + if ipaddr(element, str(query), version): + _ret.append(ipaddr(element, str(query), version)) + + if _ret: + return _ret + else: + return list() + + # Check if value is a number and convert it to an IP address + elif str(value).isdigit(): + + # We don't know what IP version to assume, so let's check IPv4 first, + # then IPv6 + try: + if (not version) or (version and version == 4): + v = netaddr.IPNetwork("0.0.0.0/0") + v.value = int(value) + v.prefixlen = 32 + elif version and version == 6: + v = netaddr.IPNetwork("::/0") + v.value = int(value) + v.prefixlen = 128 + + # IPv4 didn't work the first time, so it definitely has to be IPv6 + except Exception: + try: + v = netaddr.IPNetwork("::/0") + v.value = int(value) + v.prefixlen = 128 + + # The value is too big for IPv6. Are you a nanobot? + except Exception: + return False + + # We got an IP address, let's mark it as such + value = str(v) + vtype = "address" + + # value has not been recognized, check if it's a valid IP string + else: + try: + v = netaddr.IPNetwork(value) + + # value is a valid IP string, check if user specified + # CIDR prefix or just an IP address, this will indicate default + # output format + try: + address, prefix = value.split("/") + vtype = "network" + except Exception: + vtype = "address" + + # value hasn't been recognized, maybe it's a numerical CIDR? + except Exception: + try: + address, prefix = value.split("/") + address.isdigit() + address = int(address) + prefix.isdigit() + prefix = int(prefix) + + # It's not numerical CIDR, give up + except Exception: + return False + + # It is something, so let's try and build a CIDR from the parts + try: + v = netaddr.IPNetwork("0.0.0.0/0") + v.value = address + v.prefixlen = prefix + + # It's not a valid IPv4 CIDR + except Exception: + try: + v = netaddr.IPNetwork("::/0") + v.value = address + v.prefixlen = prefix + + # It's not a valid IPv6 CIDR. Give up. + except Exception: + return False + + # We have a valid CIDR, so let's write it in correct format + value = str(v) + vtype = "network" + + # We have a query string but it's not in the known query types. Check if + # that string is a valid subnet, if so, we can check later if given IP + # address/network is inside that specific subnet + try: + # ?? 6to4 and link-local were True here before. Should they still? + if ( + query + and (query not in query_func_map or query == "cidr_lookup") + and not str(query).isdigit() + and ipaddr(query, "network") + ): + iplist = netaddr.IPSet([netaddr.IPNetwork(query)]) + query = "cidr_lookup" + except Exception: + pass + + # This code checks if value maches the IP version the user wants, ie. if + # it's any version ("ipaddr()"), IPv4 ("ipv4()") or IPv6 ("ipv6()") + # If version does not match, return False + if version and v.version != version: + return False + + extras = [] + for arg in query_func_extra_args.get(query, tuple()): + extras.append(locals()[arg]) + try: + return query_func_map[query](v, *extras) + except KeyError: + try: + float(query) + if v.size == 1: + if vtype == "address": + return str(v.ip) + elif vtype == "network": + return str(v) + + elif v.size > 1: + try: + return str(v[query]) + "/" + str(v.prefixlen) + except Exception: + return False + + else: + return value + + except Exception: + raise errors.AnsibleFilterError( + alias + ": unknown filter type: %s" % query + ) + + return False + + +def ipmath(value, amount): + try: + if "/" in value: + ip = netaddr.IPNetwork(value).ip + else: + ip = netaddr.IPAddress(value) + except (netaddr.AddrFormatError, ValueError): + msg = "You must pass a valid IP address; {0} is invalid".format(value) + raise errors.AnsibleFilterError(msg) + + if not isinstance(amount, int): + msg = ( + "You must pass an integer for arithmetic; " + "{0} is not a valid integer" + ).format(amount) + raise errors.AnsibleFilterError(msg) + + return str(ip + amount) + + +def ipwrap(value, query=""): + try: + if isinstance(value, (list, tuple, types.GeneratorType)): + _ret = [] + for element in value: + if ipaddr(element, query, version=False, alias="ipwrap"): + _ret.append(ipaddr(element, "wrap")) + else: + _ret.append(element) + + return _ret + else: + _ret = ipaddr(value, query, version=False, alias="ipwrap") + if _ret: + return ipaddr(_ret, "wrap") + else: + return value + + except Exception: + return value + + +def ipv4(value, query=""): + return ipaddr(value, query, version=4, alias="ipv4") + + +def ipv6(value, query=""): + return ipaddr(value, query, version=6, alias="ipv6") + + +# Split given subnet into smaller subnets or find out the biggest subnet of +# a given IP address with given CIDR prefix +# Usage: +# +# - address or address/prefix | ipsubnet +# returns CIDR subnet of a given input +# +# - address/prefix | ipsubnet(cidr) +# returns number of possible subnets for given CIDR prefix +# +# - address/prefix | ipsubnet(cidr, index) +# returns new subnet with given CIDR prefix +# +# - address | ipsubnet(cidr) +# returns biggest subnet with given CIDR prefix that address belongs to +# +# - address | ipsubnet(cidr, index) +# returns next indexed subnet which contains given address +# +# - address/prefix | ipsubnet(subnet/prefix) +# return the index of the subnet in the subnet +def ipsubnet(value, query="", index="x"): + """ Manipulate IPv4/IPv6 subnets """ + + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + value = netaddr.IPNetwork(v) + except Exception: + return False + query_string = str(query) + if not query: + return str(value) + + elif query_string.isdigit(): + vsize = ipaddr(v, "size") + query = int(query) + + try: + float(index) + index = int(index) + + if vsize > 1: + try: + return str(list(value.subnet(query))[index]) + except Exception: + return False + + elif vsize == 1: + try: + return str(value.supernet(query)[index]) + except Exception: + return False + + except Exception: + if vsize > 1: + try: + return str(len(list(value.subnet(query)))) + except Exception: + return False + + elif vsize == 1: + try: + return str(value.supernet(query)[0]) + except Exception: + return False + + elif query_string: + vtype = ipaddr(query, "type") + if vtype == "address": + v = ipaddr(query, "cidr") + elif vtype == "network": + v = ipaddr(query, "subnet") + else: + msg = "You must pass a valid subnet or IP address; {0} is invalid".format( + query_string + ) + raise errors.AnsibleFilterError(msg) + query = netaddr.IPNetwork(v) + for i, subnet in enumerate(query.subnet(value.prefixlen), 1): + if subnet == value: + return str(i) + msg = "{0} is not in the subnet {1}".format(value.cidr, query.cidr) + raise errors.AnsibleFilterError(msg) + return False + + +# Returns the nth host within a network described by value. +# Usage: +# +# - address or address/prefix | nthhost(nth) +# returns the nth host within the given network +def nthhost(value, query=""): + """ Get the nth host within a given network """ + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + value = netaddr.IPNetwork(v) + except Exception: + return False + + if not query: + return False + + try: + nth = int(query) + if value.size > nth: + return value[nth] + + except ValueError: + return False + + return False + + +# Returns the next nth usable ip within a network described by value. +def next_nth_usable(value, offset): + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + v = netaddr.IPNetwork(v) + except Exception: + return False + + if type(offset) != int: + raise errors.AnsibleFilterError("Must pass in an integer") + if v.size > 1: + first_usable, last_usable = _first_last(v) + nth_ip = int(netaddr.IPAddress(int(v.ip) + offset)) + if nth_ip >= first_usable and nth_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) + offset)) + + +# Returns the previous nth usable ip within a network described by value. +def previous_nth_usable(value, offset): + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + v = netaddr.IPNetwork(v) + except Exception: + return False + + if type(offset) != int: + raise errors.AnsibleFilterError("Must pass in an integer") + if v.size > 1: + first_usable, last_usable = _first_last(v) + nth_ip = int(netaddr.IPAddress(int(v.ip) - offset)) + if nth_ip >= first_usable and nth_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) - offset)) + + +def _range_checker(ip_check, first, last): + """ + Tests whether an ip address is within the bounds of the first and last address. + + :param ip_check: The ip to test if it is within first and last. + :param first: The first IP in the range to test against. + :param last: The last IP in the range to test against. + + :return: bool + """ + if ip_check >= first and ip_check <= last: + return True + else: + return False + + +def _address_normalizer(value): + """ + Used to validate an address or network type and return it in a consistent format. + This is being used for future use cases not currently available such as an address range. + + :param value: The string representation of an address or network. + + :return: The address or network in the normalized form. + """ + try: + vtype = ipaddr(value, "type") + if vtype == "address" or vtype == "network": + v = ipaddr(value, "subnet") + except Exception: + return False + + return v + + +def network_in_usable(value, test): + """ + Checks whether 'test' is a useable address or addresses in 'value' + + :param: value: The string representation of an address or network to test against. + :param test: The string representation of an address or network to validate if it is within the range of 'value'. + + :return: bool + """ + # normalize value and test variables into an ipaddr + v = _address_normalizer(value) + w = _address_normalizer(test) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + v_first = ipaddr(ipaddr(v, "first_usable") or ipaddr(v, "address"), "int") + v_last = ipaddr(ipaddr(v, "last_usable") or ipaddr(v, "address"), "int") + w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int") + w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int") + + if _range_checker(w_first, v_first, v_last) and _range_checker( + w_last, v_first, v_last + ): + return True + else: + return False + + +def network_in_network(value, test): + """ + Checks whether the 'test' address or addresses are in 'value', including broadcast and network + + :param: value: The network address or range to test against. + :param test: The address or network to validate if it is within the range of 'value'. + + :return: bool + """ + # normalize value and test variables into an ipaddr + v = _address_normalizer(value) + w = _address_normalizer(test) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + v_first = ipaddr(ipaddr(v, "network") or ipaddr(v, "address"), "int") + v_last = ipaddr(ipaddr(v, "broadcast") or ipaddr(v, "address"), "int") + w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int") + w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int") + + if _range_checker(w_first, v_first, v_last) and _range_checker( + w_last, v_first, v_last + ): + return True + else: + return False + + +def reduce_on_network(value, network): + """ + Reduces a list of addresses to only the addresses that match a given network. + + :param: value: The list of addresses to filter on. + :param: network: The network to validate against. + + :return: The reduced list of addresses. + """ + # normalize network variable into an ipaddr + n = _address_normalizer(network) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + n_first = ipaddr(ipaddr(n, "network") or ipaddr(n, "address"), "int") + n_last = ipaddr(ipaddr(n, "broadcast") or ipaddr(n, "address"), "int") + + # create an empty list to fill and return + r = [] + + for address in value: + # normalize address variables into an ipaddr + a = _address_normalizer(address) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + a_first = ipaddr(ipaddr(a, "network") or ipaddr(a, "address"), "int") + a_last = ipaddr(ipaddr(a, "broadcast") or ipaddr(a, "address"), "int") + + if _range_checker(a_first, n_first, n_last) and _range_checker( + a_last, n_first, n_last + ): + r.append(address) + + return r + + +# Returns the SLAAC address within a network for a given HW/MAC address. +# Usage: +# +# - prefix | slaac(mac) +def slaac(value, query=""): + """ Get the SLAAC address within given network """ + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + if ipaddr(value, "version") != 6: + return False + + value = netaddr.IPNetwork(v) + except Exception: + return False + + if not query: + return False + + try: + mac = hwaddr(query, alias="slaac") + + eui = netaddr.EUI(mac) + except Exception: + return False + + return eui.ipv6(value.network) + + +# ---- HWaddr / MAC address filters ---- +def hwaddr(value, query="", alias="hwaddr"): + """ Check if string is a HW/MAC address and filter it """ + + query_func_extra_args = {"": ("value",)} + + query_func_map = { + "": _empty_hwaddr_query, + "bare": _bare_query, + "bool": _bool_hwaddr_query, + "int": _int_hwaddr_query, + "cisco": _cisco_query, + "eui48": _win_query, + "linux": _linux_query, + "pgsql": _postgresql_query, + "postgresql": _postgresql_query, + "psql": _postgresql_query, + "unix": _unix_query, + "win": _win_query, + } + + try: + v = netaddr.EUI(value) + except Exception: + if query and query != "bool": + raise errors.AnsibleFilterError( + alias + ": not a hardware address: %s" % value + ) + + extras = [] + for arg in query_func_extra_args.get(query, tuple()): + extras.append(locals()[arg]) + try: + return query_func_map[query](v, *extras) + except KeyError: + raise errors.AnsibleFilterError( + alias + ": unknown filter type: %s" % query + ) + + return False + + +def macaddr(value, query=""): + return hwaddr(value, query, alias="macaddr") + + +def _need_netaddr(f_name, *args, **kwargs): + raise errors.AnsibleFilterError( + "The %s filter requires python's netaddr be " + "installed on the ansible controller" % f_name + ) + + +def ip4_hex(arg, delimiter=""): + """ Convert an IPv4 address to Hexadecimal notation """ + numbers = list(map(int, arg.split("."))) + return "{0:02x}{sep}{1:02x}{sep}{2:02x}{sep}{3:02x}".format( + *numbers, sep=delimiter + ) + + +# ---- Ansible filters ---- +class FilterModule(object): + """ IP address and network manipulation filters """ + + filter_map = { + # IP addresses and networks + "cidr_merge": cidr_merge, + "ipaddr": ipaddr, + "ipmath": ipmath, + "ipwrap": ipwrap, + "ip4_hex": ip4_hex, + "ipv4": ipv4, + "ipv6": ipv6, + "ipsubnet": ipsubnet, + "next_nth_usable": next_nth_usable, + "network_in_network": network_in_network, + "network_in_usable": network_in_usable, + "reduce_on_network": reduce_on_network, + "nthhost": nthhost, + "previous_nth_usable": previous_nth_usable, + "slaac": slaac, + # MAC / HW addresses + "hwaddr": hwaddr, + "macaddr": macaddr, + } + + def filters(self): + if netaddr: + return self.filter_map + else: + # Need to install python's netaddr for these filters to work + return dict( + (f, partial(_need_netaddr, f)) for f in self.filter_map + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py new file mode 100644 index 0000000..72d6c86 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py @@ -0,0 +1,531 @@ +# +# {c) 2017 Red Hat, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +# Make coding more python3-ish +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import re +import os +import traceback +import string + +from collections.abc import Mapping +from xml.etree.ElementTree import fromstring + +from ansible.module_utils._text import to_native, to_text +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + Template, +) +from ansible.module_utils.six import iteritems, string_types +from ansible.errors import AnsibleError, AnsibleFilterError +from ansible.utils.display import Display +from ansible.utils.encrypt import passlib_or_crypt, random_password + +try: + import yaml + + HAS_YAML = True +except ImportError: + HAS_YAML = False + +try: + import textfsm + + HAS_TEXTFSM = True +except ImportError: + HAS_TEXTFSM = False + +display = Display() + + +def re_matchall(regex, value): + objects = list() + for match in re.findall(regex.pattern, value, re.M): + obj = {} + if regex.groupindex: + for name, index in iteritems(regex.groupindex): + if len(regex.groupindex) == 1: + obj[name] = match + else: + obj[name] = match[index - 1] + objects.append(obj) + return objects + + +def re_search(regex, value): + obj = {} + match = regex.search(value, re.M) + if match: + items = list(match.groups()) + if regex.groupindex: + for name, index in iteritems(regex.groupindex): + obj[name] = items[index - 1] + return obj + + +def parse_cli(output, tmpl): + if not isinstance(output, string_types): + raise AnsibleError( + "parse_cli input should be a string, but was given a input of %s" + % (type(output)) + ) + + if not os.path.exists(tmpl): + raise AnsibleError("unable to locate parse_cli template: %s" % tmpl) + + try: + template = Template() + except ImportError as exc: + raise AnsibleError(to_native(exc)) + + with open(tmpl) as tmpl_fh: + tmpl_content = tmpl_fh.read() + + spec = yaml.safe_load(tmpl_content) + obj = {} + + for name, attrs in iteritems(spec["keys"]): + value = attrs["value"] + + try: + variables = spec.get("vars", {}) + value = template(value, variables) + except Exception: + pass + + if "start_block" in attrs and "end_block" in attrs: + start_block = re.compile(attrs["start_block"]) + end_block = re.compile(attrs["end_block"]) + + blocks = list() + lines = None + block_started = False + + for line in output.split("\n"): + match_start = start_block.match(line) + match_end = end_block.match(line) + + if match_start: + lines = list() + lines.append(line) + block_started = True + + elif match_end: + if lines: + lines.append(line) + blocks.append("\n".join(lines)) + block_started = False + + elif block_started: + if lines: + lines.append(line) + + regex_items = [re.compile(r) for r in attrs["items"]] + objects = list() + + for block in blocks: + if isinstance(value, Mapping) and "key" not in value: + items = list() + for regex in regex_items: + match = regex.search(block) + if match: + item_values = match.groupdict() + item_values["match"] = list(match.groups()) + items.append(item_values) + else: + items.append(None) + + obj = {} + for k, v in iteritems(value): + try: + obj[k] = template( + v, {"item": items}, fail_on_undefined=False + ) + except Exception: + obj[k] = None + objects.append(obj) + + elif isinstance(value, Mapping): + items = list() + for regex in regex_items: + match = regex.search(block) + if match: + item_values = match.groupdict() + item_values["match"] = list(match.groups()) + items.append(item_values) + else: + items.append(None) + + key = template(value["key"], {"item": items}) + values = dict( + [ + (k, template(v, {"item": items})) + for k, v in iteritems(value["values"]) + ] + ) + objects.append({key: values}) + + return objects + + elif "items" in attrs: + regexp = re.compile(attrs["items"]) + when = attrs.get("when") + conditional = ( + "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when + ) + + if isinstance(value, Mapping) and "key" not in value: + values = list() + + for item in re_matchall(regexp, output): + entry = {} + + for item_key, item_value in iteritems(value): + entry[item_key] = template(item_value, {"item": item}) + + if when: + if template(conditional, {"item": entry}): + values.append(entry) + else: + values.append(entry) + + obj[name] = values + + elif isinstance(value, Mapping): + values = dict() + + for item in re_matchall(regexp, output): + entry = {} + + for item_key, item_value in iteritems(value["values"]): + entry[item_key] = template(item_value, {"item": item}) + + key = template(value["key"], {"item": item}) + + if when: + if template( + conditional, {"item": {"key": key, "value": entry}} + ): + values[key] = entry + else: + values[key] = entry + + obj[name] = values + + else: + item = re_search(regexp, output) + obj[name] = template(value, {"item": item}) + + else: + obj[name] = value + + return obj + + +def parse_cli_textfsm(value, template): + if not HAS_TEXTFSM: + raise AnsibleError( + "parse_cli_textfsm filter requires TextFSM library to be installed" + ) + + if not isinstance(value, string_types): + raise AnsibleError( + "parse_cli_textfsm input should be a string, but was given a input of %s" + % (type(value)) + ) + + if not os.path.exists(template): + raise AnsibleError( + "unable to locate parse_cli_textfsm template: %s" % template + ) + + try: + template = open(template) + except IOError as exc: + raise AnsibleError(to_native(exc)) + + re_table = textfsm.TextFSM(template) + fsm_results = re_table.ParseText(value) + + results = list() + for item in fsm_results: + results.append(dict(zip(re_table.header, item))) + + return results + + +def _extract_param(template, root, attrs, value): + + key = None + when = attrs.get("when") + conditional = "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when + param_to_xpath_map = attrs["items"] + + if isinstance(value, Mapping): + key = value.get("key", None) + if key: + value = value["values"] + + entries = dict() if key else list() + + for element in root.findall(attrs["top"]): + entry = dict() + item_dict = dict() + for param, param_xpath in iteritems(param_to_xpath_map): + fields = None + try: + fields = element.findall(param_xpath) + except Exception: + display.warning( + "Failed to evaluate value of '%s' with XPath '%s'.\nUnexpected error: %s." + % (param, param_xpath, traceback.format_exc()) + ) + + tags = param_xpath.split("/") + + # check if xpath ends with attribute. + # If yes set attribute key/value dict to param value in case attribute matches + # else if it is a normal xpath assign matched element text value. + if len(tags) and tags[-1].endswith("]"): + if fields: + if len(fields) > 1: + item_dict[param] = [field.attrib for field in fields] + else: + item_dict[param] = fields[0].attrib + else: + item_dict[param] = {} + else: + if fields: + if len(fields) > 1: + item_dict[param] = [field.text for field in fields] + else: + item_dict[param] = fields[0].text + else: + item_dict[param] = None + + if isinstance(value, Mapping): + for item_key, item_value in iteritems(value): + entry[item_key] = template(item_value, {"item": item_dict}) + else: + entry = template(value, {"item": item_dict}) + + if key: + expanded_key = template(key, {"item": item_dict}) + if when: + if template( + conditional, + {"item": {"key": expanded_key, "value": entry}}, + ): + entries[expanded_key] = entry + else: + entries[expanded_key] = entry + else: + if when: + if template(conditional, {"item": entry}): + entries.append(entry) + else: + entries.append(entry) + + return entries + + +def parse_xml(output, tmpl): + if not os.path.exists(tmpl): + raise AnsibleError("unable to locate parse_xml template: %s" % tmpl) + + if not isinstance(output, string_types): + raise AnsibleError( + "parse_xml works on string input, but given input of : %s" + % type(output) + ) + + root = fromstring(output) + try: + template = Template() + except ImportError as exc: + raise AnsibleError(to_native(exc)) + + with open(tmpl) as tmpl_fh: + tmpl_content = tmpl_fh.read() + + spec = yaml.safe_load(tmpl_content) + obj = {} + + for name, attrs in iteritems(spec["keys"]): + value = attrs["value"] + + try: + variables = spec.get("vars", {}) + value = template(value, variables) + except Exception: + pass + + if "items" in attrs: + obj[name] = _extract_param(template, root, attrs, value) + else: + obj[name] = value + + return obj + + +def type5_pw(password, salt=None): + if not isinstance(password, string_types): + raise AnsibleFilterError( + "type5_pw password input should be a string, but was given a input of %s" + % (type(password).__name__) + ) + + salt_chars = u"".join( + (to_text(string.ascii_letters), to_text(string.digits), u"./") + ) + if salt is not None and not isinstance(salt, string_types): + raise AnsibleFilterError( + "type5_pw salt input should be a string, but was given a input of %s" + % (type(salt).__name__) + ) + elif not salt: + salt = random_password(length=4, chars=salt_chars) + elif not set(salt) <= set(salt_chars): + raise AnsibleFilterError( + "type5_pw salt used inproper characters, must be one of %s" + % (salt_chars) + ) + + encrypted_password = passlib_or_crypt(password, "md5_crypt", salt=salt) + + return encrypted_password + + +def hash_salt(password): + + split_password = password.split("$") + if len(split_password) != 4: + raise AnsibleFilterError( + "Could not parse salt out password correctly from {0}".format( + password + ) + ) + else: + return split_password[2] + + +def comp_type5( + unencrypted_password, encrypted_password, return_original=False +): + + salt = hash_salt(encrypted_password) + if type5_pw(unencrypted_password, salt) == encrypted_password: + if return_original is True: + return encrypted_password + else: + return True + return False + + +def vlan_parser(vlan_list, first_line_len=48, other_line_len=44): + + """ + Input: Unsorted list of vlan integers + Output: Sorted string list of integers according to IOS-like vlan list rules + + 1. Vlans are listed in ascending order + 2. Runs of 3 or more consecutive vlans are listed with a dash + 3. The first line of the list can be first_line_len characters long + 4. Subsequent list lines can be other_line_len characters + """ + + # Sort and remove duplicates + sorted_list = sorted(set(vlan_list)) + + if sorted_list[0] < 1 or sorted_list[-1] > 4094: + raise AnsibleFilterError("Valid VLAN range is 1-4094") + + parse_list = [] + idx = 0 + while idx < len(sorted_list): + start = idx + end = start + while end < len(sorted_list) - 1: + if sorted_list[end + 1] - sorted_list[end] == 1: + end += 1 + else: + break + + if start == end: + # Single VLAN + parse_list.append(str(sorted_list[idx])) + elif start + 1 == end: + # Run of 2 VLANs + parse_list.append(str(sorted_list[start])) + parse_list.append(str(sorted_list[end])) + else: + # Run of 3 or more VLANs + parse_list.append( + str(sorted_list[start]) + "-" + str(sorted_list[end]) + ) + idx = end + 1 + + line_count = 0 + result = [""] + for vlans in parse_list: + # First line (" switchport trunk allowed vlan ") + if line_count == 0: + if len(result[line_count] + vlans) > first_line_len: + result.append("") + line_count += 1 + result[line_count] += vlans + "," + else: + result[line_count] += vlans + "," + + # Subsequent lines (" switchport trunk allowed vlan add ") + else: + if len(result[line_count] + vlans) > other_line_len: + result.append("") + line_count += 1 + result[line_count] += vlans + "," + else: + result[line_count] += vlans + "," + + # Remove trailing orphan commas + for idx in range(0, len(result)): + result[idx] = result[idx].rstrip(",") + + # Sometimes text wraps to next line, but there are no remaining VLANs + if "" in result: + result.remove("") + + return result + + +class FilterModule(object): + """Filters for working with output from network devices""" + + filter_map = { + "parse_cli": parse_cli, + "parse_cli_textfsm": parse_cli_textfsm, + "parse_xml": parse_xml, + "type5_pw": type5_pw, + "hash_salt": hash_salt, + "comp_type5": comp_type5, + "vlan_parser": vlan_parser, + } + + def filters(self): + return self.filter_map diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py new file mode 100644 index 0000000..8afb3e5 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py @@ -0,0 +1,91 @@ +# Copyright (c) 2018 Cisco and/or its affiliates. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +httpapi: restconf +short_description: HttpApi Plugin for devices supporting Restconf API +description: +- This HttpApi plugin provides methods to connect to Restconf API endpoints. +options: + root_path: + type: str + description: + - Specifies the location of the Restconf root. + default: /restconf + vars: + - name: ansible_httpapi_restconf_root +""" + +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import ConnectionError +from ansible.module_utils.six.moves.urllib.error import HTTPError +from ansible.plugins.httpapi import HttpApiBase + + +CONTENT_TYPE = "application/yang-data+json" + + +class HttpApi(HttpApiBase): + def send_request(self, data, **message_kwargs): + if data: + data = json.dumps(data) + + path = "/".join( + [ + self.get_option("root_path").rstrip("/"), + message_kwargs.get("path", "").lstrip("/"), + ] + ) + + headers = { + "Content-Type": message_kwargs.get("content_type") or CONTENT_TYPE, + "Accept": message_kwargs.get("accept") or CONTENT_TYPE, + } + response, response_data = self.connection.send( + path, data, headers=headers, method=message_kwargs.get("method") + ) + + return handle_response(response, response_data) + + +def handle_response(response, response_data): + try: + response_data = json.loads(response_data.read()) + except ValueError: + response_data = response_data.read() + + if isinstance(response, HTTPError): + if response_data: + if "errors" in response_data: + errors = response_data["errors"]["error"] + error_text = "\n".join( + (error["error-message"] for error in errors) + ) + else: + error_text = response_data + + raise ConnectionError(error_text, code=response.code) + raise ConnectionError(to_text(response), code=response.code) + + return response_data diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py new file mode 100644 index 0000000..dc0a19f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py @@ -0,0 +1,2578 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file, and this file only, is based on +# Lib/ipaddress.py of cpython +# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved" +# are retained in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. + +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +from __future__ import unicode_literals + + +import itertools +import struct + + +# The following makes it easier for us to script updates of the bundled code and is not part of +# upstream +_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"} + +__version__ = "1.0.22" + +# Compatibility functions +_compat_int_types = (int,) +try: + _compat_int_types = (int, long) +except NameError: + pass +try: + _compat_str = unicode +except NameError: + _compat_str = str + assert bytes != str +if b"\0"[0] == 0: # Python 3 semantics + + def _compat_bytes_to_byte_vals(byt): + return byt + + +else: + + def _compat_bytes_to_byte_vals(byt): + return [struct.unpack(b"!B", b)[0] for b in byt] + + +try: + _compat_int_from_byte_vals = int.from_bytes +except AttributeError: + + def _compat_int_from_byte_vals(bytvals, endianess): + assert endianess == "big" + res = 0 + for bv in bytvals: + assert isinstance(bv, _compat_int_types) + res = (res << 8) + bv + return res + + +def _compat_to_bytes(intval, length, endianess): + assert isinstance(intval, _compat_int_types) + assert endianess == "big" + if length == 4: + if intval < 0 or intval >= 2 ** 32: + raise struct.error("integer out of range for 'I' format code") + return struct.pack(b"!I", intval) + elif length == 16: + if intval < 0 or intval >= 2 ** 128: + raise struct.error("integer out of range for 'QQ' format code") + return struct.pack(b"!QQ", intval >> 64, intval & 0xFFFFFFFFFFFFFFFF) + else: + raise NotImplementedError() + + +if hasattr(int, "bit_length"): + # Not int.bit_length , since that won't work in 2.7 where long exists + def _compat_bit_length(i): + return i.bit_length() + + +else: + + def _compat_bit_length(i): + for res in itertools.count(): + if i >> res == 0: + return res + + +def _compat_range(start, end, step=1): + assert step > 0 + i = start + while i < end: + yield i + i += step + + +class _TotalOrderingMixin(object): + __slots__ = () + + # Helper that derives the other comparison operations from + # __lt__ and __eq__ + # We avoid functools.total_ordering because it doesn't handle + # NotImplemented correctly yet (http://bugs.python.org/issue10042) + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + + def __lt__(self, other): + raise NotImplementedError + + def __le__(self, other): + less = self.__lt__(other) + if less is NotImplemented or not less: + return self.__eq__(other) + return less + + def __gt__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not (less or equal) + + def __ge__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + return not less + + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + "%r does not appear to be an IPv4 or IPv6 address. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" % address + ) + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 address" % address + ) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + "%r does not appear to be an IPv4 or IPv6 network. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" % address + ) + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 network" % address + ) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. + + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 interface" % address + ) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return _compat_to_bytes(address, 4, "big") + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return _compat_to_bytes(address, 16, "big") + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = _compat_str(address).split("/") + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) # pylint: disable=stop-iteration-return + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, _compat_bit_length(~number & (number - 1))) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. + + """ + if not ( + isinstance(first, _BaseAddress) and isinstance(last, _BaseAddress) + ): + raise TypeError("first and last must be IP addresses, not networks") + if first.version != last.version: + raise TypeError( + "%s and %s are not of the same version" % (first, last) + ) + if first > last: + raise ValueError("last IP address must be greater than first") + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError("unknown IP version") + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min( + _count_righthand_zero_bits(first_int, ip_bits), + _compat_bit_length(last_int - first_int + 1) - 1, + ) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, + # last.network_address <= net.network_address is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. + + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, ips[-1]) + ) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, ips[-1]) + ) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, nets[-1]) + ) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase(_TotalOrderingMixin): + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return _compat_str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = "%200s has no version specified" % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError( + msg % (address, self._max_prefixlen, self._version) + ) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = ( + "%r (len %d != %d) is not permitted as an IPv%d address. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" + ) + raise AddressValueError( + msg % (address, address_len, expected_len, self._version) + ) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. + + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits( + ip_int, cls._max_prefixlen + ) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = _compat_to_bytes(ip_int, byteslen, "big") + msg = "Netmask pattern %r mixes zeroes & ones" + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = "%r is not a valid netmask" % netmask_str + raise NetmaskValueError(msg) + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + def __reduce__(self): + return self.__class__, (_compat_str(self),) + + +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. + """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + try: + return self._ip == other._ip and self._version == other._version + except AttributeError: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseAddress): + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + if self._ip != other._ip: + return self._ip < other._ip + return False + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return _compat_str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + +class _BaseNetwork(_IPAddressBase): + + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + + """ + + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return "%s/%d" % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError("address out of range") + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError("address out of range") + return self._address_class(broadcast + n) + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseNetwork): + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __eq__(self, other): + try: + return ( + self._version == other._version + and self.network_address == other.network_address + and int(self.netmask) == int(other.netmask) + ) + except AttributeError: + return NotImplemented + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. + if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return ( + int(self.network_address) + <= int(other._ip) + <= int(self.broadcast_address) + ) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other + or ( + other.network_address in self + or (other.broadcast_address in self) + ) + ) + + @property + def broadcast_address(self): + x = self._cache.get("broadcast_address") + if x is None: + x = self._address_class( + int(self.network_address) | int(self.hostmask) + ) + self._cache["broadcast_address"] = x + return x + + @property + def hostmask(self): + x = self._cache.get("hostmask") + if x is None: + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) + self._cache["hostmask"] = x + return x + + @property + def with_prefixlen(self): + return "%s/%d" % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = "%200s has no associated address class" % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not other.subnet_of(self): + raise ValueError("%s not contained in %s" % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__( + "%s/%s" % (other.network_address, other.prefixlen) + ) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if other.subnet_of(s1): + yield s2 + s1, s2 = s1.subnets() + elif other.subnet_of(s2): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. + raise AssertionError( + "Error performing exclusion: " + "s1: %s s2: %s other: %s" % (s1, s2, other) + ) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. + raise AssertionError( + "Error performing exclusion: " + "s1: %s s2: %s other: %s" % (s1, s2, other) + ) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. + OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError("new prefix must be longer") + if prefixlen_diff != 1: + raise ValueError("cannot set prefixlen_diff and new_prefix") + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError("prefix length diff must be > 0") + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + "prefix length diff %d is invalid for netblock %s" + % (new_prefixlen, self) + ) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in _compat_range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError("new prefix must be shorter") + if prefixlen_diff != 1: + raise ValueError("cannot set prefixlen_diff and new_prefix") + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + "current prefixlen is %d, cannot have a prefixlen_diff of %d" + % (self.prefixlen, prefixlen_diff) + ) + return self.__class__( + ( + int(self.network_address) + & (int(self.netmask) << prefixlen_diff), + new_prefixlen, + ) + ) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return ( + self.network_address.is_multicast + and self.broadcast_address.is_multicast + ) + + @staticmethod + def _is_subnet_of(a, b): + try: + # Always false if one is v4 and the other is v6. + if a._version != b._version: + raise TypeError( + "%s and %s are not of the same version" % (a, b) + ) + return ( + b.network_address <= a.network_address + and b.broadcast_address >= a.broadcast_address + ) + except AttributeError: + raise TypeError( + "Unable to test subnet containment " + "between %s and %s" % (a, b) + ) + + def subnet_of(self, other): + """Return True if this network is a subnet of other.""" + return self._is_subnet_of(self, other) + + def supernet_of(self, other): + """Return True if this network is a supernet of other.""" + return self._is_subnet_of(other, self) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return ( + self.network_address.is_reserved + and self.broadcast_address.is_reserved + ) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return ( + self.network_address.is_link_local + and self.broadcast_address.is_link_local + ) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return ( + self.network_address.is_private + and self.broadcast_address.is_private + ) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return ( + self.network_address.is_unspecified + and self.broadcast_address.is_unspecified + ) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return ( + self.network_address.is_loopback + and self.broadcast_address.is_loopback + ) + + +class _BaseV4(object): + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2 ** IPV4LENGTH) - 1 + _DECIMAL_DIGITS = frozenset("0123456789") + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0]) + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return _compat_str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. + + """ + if not ip_str: + raise AddressValueError("Address cannot be empty") + + octets = ip_str.split(".") + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return _compat_int_from_byte_vals( + map(cls._parse_octet, octets), "big" + ) + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._DECIMAL_DIGITS.issuperset(octet_str): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + # Any octets that look like they *might* be written in octal, + # and which don't look exactly the same in both octal and + # decimal are rejected as ambiguous + if octet_int > 7 and octet_str[0] == "0": + msg = "Ambiguous (octal/decimal) value in %r not permitted" + raise ValueError(msg % octet_str) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return ".".join( + _compat_str( + struct.unpack(b"!B", b)[0] if isinstance(b, bytes) else b + ) + for b in _compat_to_bytes(ip_int, 4, "big") + ) + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split(".") + try: + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = _compat_str(self).split(".")[::-1] + return ".".join(reverse_octets) + ".in-addr.arpa" + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ("_ip", "__weakref__") + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, "big") + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if "/" in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in self._constants._reserved_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + return ( + self not in self._constants._public_network and not self.is_private + ) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self == self._constants._unspecified_address + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in self._constants._loopback_network + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. + + """ + return self in self._constants._linklocal_network + + +class IPv4Interface(IPv4Address): + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv4Address.__init__(self, address) + self.network = IPv4Network(self._ip) + self._prefixlen = self._max_prefixlen + return + + if isinstance(address, tuple): + IPv4Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + + self.network = IPv4Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv4Address.__init__(self, addr[0]) + + self.network = IPv4Network(address, strict=False) + self._prefixlen = self.network._prefixlen + + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + + def __str__(self): + return "%s/%d" % ( + self._string_from_ip_int(self._ip), + self.network.prefixlen, + ) + + def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return ( + self.network < other.network + or self.network == other.network + and address_less + ) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv4Address(self._ip) + + @property + def with_prefixlen(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask) + + +class IPv4Network(_BaseV4, _BaseNetwork): + + """This class represents and manipulates 32-bit IPv4 network + addresses.. + + Attributes: [examples for IPv4Network('192.0.2.0/27')] + .network_address: IPv4Address('192.0.2.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast_address: IPv4Address('192.0.2.32') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + + # Class to use when creating address objects + _address_class = IPv4Address + + def __init__(self, address, strict=True): + + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.0.2.0/24' + '192.0.2.0/255.255.255.0' + '192.0.0.2/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.0.2.1' + '192.0.2.1/255.255.255.255' + '192.0.2.1/32' + are also functionally equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.0.2.1') == IPv4Network(3221225985) + or, more generally + IPv4Interface(int(IPv4Interface('192.0.2.1'))) == + IPv4Interface('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict is True and a network address is not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (_compat_int_types, bytes)): + self.network_address = IPv4Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen + ) + # fixme: address/network test here. + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + # We weren't given an address[1] + arg = self._max_prefixlen + self.network_address = IPv4Address(address[0]) + self.netmask, self._prefixlen = self._make_netmask(arg) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError("%s has host bits set" % self) + else: + self.network_address = IPv4Address( + packed & int(self.netmask) + ) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if ( + IPv4Address(int(self.network_address) & int(self.netmask)) + != self.network_address + ): + raise ValueError("%s has host bits set" % self) + self.network_address = IPv4Address( + int(self.network_address) & int(self.netmask) + ) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry. + + """ + return ( + not ( + self.network_address in IPv4Network("100.64.0.0/10") + and self.broadcast_address in IPv4Network("100.64.0.0/10") + ) + and not self.is_private + ) + + +class _IPv4Constants(object): + + _linklocal_network = IPv4Network("169.254.0.0/16") + + _loopback_network = IPv4Network("127.0.0.0/8") + + _multicast_network = IPv4Network("224.0.0.0/4") + + _public_network = IPv4Network("100.64.0.0/10") + + _private_networks = [ + IPv4Network("0.0.0.0/8"), + IPv4Network("10.0.0.0/8"), + IPv4Network("127.0.0.0/8"), + IPv4Network("169.254.0.0/16"), + IPv4Network("172.16.0.0/12"), + IPv4Network("192.0.0.0/29"), + IPv4Network("192.0.0.170/31"), + IPv4Network("192.0.2.0/24"), + IPv4Network("192.168.0.0/16"), + IPv4Network("198.18.0.0/15"), + IPv4Network("198.51.100.0/24"), + IPv4Network("203.0.113.0/24"), + IPv4Network("240.0.0.0/4"), + IPv4Network("255.255.255.255/32"), + ] + + _reserved_network = IPv4Network("240.0.0.0/4") + + _unspecified_address = IPv4Address("0.0.0.0") + + +IPv4Address._constants = _IPv4Constants + + +class _BaseV6(object): + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2 ** IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset("0123456789ABCDEFabcdef") + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError("Address cannot be empty") + + parts = ip_str.split(":") + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if "." in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + parts.append("%x" % ((ipv4_int >> 16) & 0xFFFF)) + parts.append("%x" % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % ( + _max_parts - 1, + ip_str, + ) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in _compat_range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. The + # endpoints could still be empty, but _parse_hextet() will check + # for that. + if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == "0": + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = ( + best_doublecolon_start + best_doublecolon_len + ) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [""] + hextets[best_doublecolon_start:best_doublecolon_end] = [""] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [""] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. + + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError("IPv6 address is too large") + + hex_str = "%032x" % ip_int + hextets = ["%x" % int(hex_str[x : x + 4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ":".join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = _compat_str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = _compat_str(self.ip) + else: + ip_str = _compat_str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = "%032x" % ip_int + parts = [hex_str[x : x + 4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return "%s/%d" % (":".join(parts), self._prefixlen) + return ":".join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(":", "") + return ".".join(reverse_chars) + ".ip6.arpa" + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ("_ip", "__weakref__") + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, "big") + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if "/" in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001::/32) + + """ + if (self._ip >> 96) != 0x20010000: + return None + return ( + IPv4Address((self._ip >> 64) & 0xFFFFFFFF), + IPv4Address(~self._ip & 0xFFFFFFFF), + ) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + if (self._ip >> 112) != 0x2002: + return None + return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) + + +class IPv6Interface(IPv6Address): + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv6Address.__init__(self, address) + self.network = IPv6Network(self._ip) + self._prefixlen = self._max_prefixlen + return + if isinstance(address, tuple): + IPv6Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv6Address.__init__(self, addr[0]) + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self._prefixlen = self.network._prefixlen + self.hostmask = self.network.hostmask + + def __str__(self): + return "%s/%d" % ( + self._string_from_ip_int(self._ip), + self.network.prefixlen, + ) + + def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return ( + self.network < other.network + or self.network == other.network + and address_less + ) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv6Address(self._ip) + + @property + def with_prefixlen(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask) + + @property + def is_unspecified(self): + return self._ip == 0 and self.network.is_unspecified + + @property + def is_loopback(self): + return self._ip == 1 and self.network.is_loopback + + +class IPv6Network(_BaseV6, _BaseNetwork): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:db8::1000/124')] + .network_address: IPv6Address('2001:db8::1000') + .hostmask: IPv6Address('::f') + .broadcast_address: IPv6Address('2001:db8::100f') + .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') + .prefixlen: 124 + + """ + + # Class to use when creating address objects + _address_class = IPv6Address + + def __init__(self, address, strict=True): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the + IP and prefix/netmask. + '2001:db8::/128' + '2001:db8:0000:0000:0000:0000:0000:0000/128' + '2001:db8::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:db8::') == + IPv6Network(42540766411282592856903984951653826560) + or, more generally + IPv6Network(int(IPv6Network('2001:db8::'))) == + IPv6Network('2001:db8::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 2001:db8::1000/124 and not an + IP address on a network, eg, 2001:db8::1/124. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Efficient constructor from integer or packed address + if isinstance(address, (bytes, _compat_int_types)): + self.network_address = IPv6Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen + ) + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + self.network_address = IPv6Address(address[0]) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError("%s has host bits set" % self) + else: + self.network_address = IPv6Address( + packed & int(self.netmask) + ) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + + self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if ( + IPv6Address(int(self.network_address) & int(self.netmask)) + != self.network_address + ): + raise ValueError("%s has host bits set" % self) + self.network_address = IPv6Address( + int(self.network_address) & int(self.netmask) + ) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the + Subnet-Router anycast address. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast + 1): + yield self._address_class(x) + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return ( + self.network_address.is_site_local + and self.broadcast_address.is_site_local + ) + + +class _IPv6Constants(object): + + _linklocal_network = IPv6Network("fe80::/10") + + _multicast_network = IPv6Network("ff00::/8") + + _private_networks = [ + IPv6Network("::1/128"), + IPv6Network("::/128"), + IPv6Network("::ffff:0:0/96"), + IPv6Network("100::/64"), + IPv6Network("2001::/23"), + IPv6Network("2001:2::/48"), + IPv6Network("2001:db8::/32"), + IPv6Network("2001:10::/28"), + IPv6Network("fc00::/7"), + IPv6Network("fe80::/10"), + ] + + _reserved_networks = [ + IPv6Network("::/8"), + IPv6Network("100::/8"), + IPv6Network("200::/7"), + IPv6Network("400::/6"), + IPv6Network("800::/5"), + IPv6Network("1000::/4"), + IPv6Network("4000::/3"), + IPv6Network("6000::/3"), + IPv6Network("8000::/3"), + IPv6Network("A000::/3"), + IPv6Network("C000::/3"), + IPv6Network("E000::/4"), + IPv6Network("F000::/5"), + IPv6Network("F800::/6"), + IPv6Network("FE00::/9"), + ] + + _sitelocal_network = IPv6Network("fec0::/10") + + +IPv6Address._constants = _IPv6Constants diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py new file mode 100644 index 0000000..68608d1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py @@ -0,0 +1,27 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The base class for all resource modules +""" + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import ( + get_resource_connection, +) + + +class ConfigBase(object): + """ The base class for all resource modules + """ + + ACTION_STATES = ["merged", "replaced", "overridden", "deleted"] + + def __init__(self, module): + self._module = module + self.state = module.params["state"] + self._connection = None + + if self.state not in ["rendered", "parsed"]: + self._connection = get_resource_connection(module) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py new file mode 100644 index 0000000..bc458eb --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py @@ -0,0 +1,473 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import re +import hashlib + +from ansible.module_utils.six.moves import zip +from ansible.module_utils._text import to_bytes, to_native + +DEFAULT_COMMENT_TOKENS = ["#", "!", "/*", "*/", "echo"] + +DEFAULT_IGNORE_LINES_RE = set( + [ + re.compile(r"Using \d+ out of \d+ bytes"), + re.compile(r"Building configuration"), + re.compile(r"Current configuration : \d+ bytes"), + ] +) + + +try: + Pattern = re._pattern_type +except AttributeError: + Pattern = re.Pattern + + +class ConfigLine(object): + def __init__(self, raw): + self.text = str(raw).strip() + self.raw = raw + self._children = list() + self._parents = list() + + def __str__(self): + return self.raw + + def __eq__(self, other): + return self.line == other.line + + def __ne__(self, other): + return not self.__eq__(other) + + def __getitem__(self, key): + for item in self._children: + if item.text == key: + return item + raise KeyError(key) + + @property + def line(self): + line = self.parents + line.append(self.text) + return " ".join(line) + + @property + def children(self): + return _obj_to_text(self._children) + + @property + def child_objs(self): + return self._children + + @property + def parents(self): + return _obj_to_text(self._parents) + + @property + def path(self): + config = _obj_to_raw(self._parents) + config.append(self.raw) + return "\n".join(config) + + @property + def has_children(self): + return len(self._children) > 0 + + @property + def has_parents(self): + return len(self._parents) > 0 + + def add_child(self, obj): + if not isinstance(obj, ConfigLine): + raise AssertionError("child must be of type `ConfigLine`") + self._children.append(obj) + + +def ignore_line(text, tokens=None): + for item in tokens or DEFAULT_COMMENT_TOKENS: + if text.startswith(item): + return True + for regex in DEFAULT_IGNORE_LINES_RE: + if regex.match(text): + return True + + +def _obj_to_text(x): + return [o.text for o in x] + + +def _obj_to_raw(x): + return [o.raw for o in x] + + +def _obj_to_block(objects, visited=None): + items = list() + for o in objects: + if o not in items: + items.append(o) + for child in o._children: + if child not in items: + items.append(child) + return _obj_to_raw(items) + + +def dumps(objects, output="block", comments=False): + if output == "block": + items = _obj_to_block(objects) + elif output == "commands": + items = _obj_to_text(objects) + elif output == "raw": + items = _obj_to_raw(objects) + else: + raise TypeError("unknown value supplied for keyword output") + + if output == "block": + if comments: + for index, item in enumerate(items): + nextitem = index + 1 + if ( + nextitem < len(items) + and not item.startswith(" ") + and items[nextitem].startswith(" ") + ): + item = "!\n%s" % item + items[index] = item + items.append("!") + items.append("end") + + return "\n".join(items) + + +class NetworkConfig(object): + def __init__(self, indent=1, contents=None, ignore_lines=None): + self._indent = indent + self._items = list() + self._config_text = None + + if ignore_lines: + for item in ignore_lines: + if not isinstance(item, Pattern): + item = re.compile(item) + DEFAULT_IGNORE_LINES_RE.add(item) + + if contents: + self.load(contents) + + @property + def items(self): + return self._items + + @property + def config_text(self): + return self._config_text + + @property + def sha1(self): + sha1 = hashlib.sha1() + sha1.update(to_bytes(str(self), errors="surrogate_or_strict")) + return sha1.digest() + + def __getitem__(self, key): + for line in self: + if line.text == key: + return line + raise KeyError(key) + + def __iter__(self): + return iter(self._items) + + def __str__(self): + return "\n".join([c.raw for c in self.items]) + + def __len__(self): + return len(self._items) + + def load(self, s): + self._config_text = s + self._items = self.parse(s) + + def loadfp(self, fp): + with open(fp) as f: + return self.load(f.read()) + + def parse(self, lines, comment_tokens=None): + toplevel = re.compile(r"\S") + childline = re.compile(r"^\s*(.+)$") + entry_reg = re.compile(r"([{};])") + + ancestors = list() + config = list() + + indents = [0] + + for linenum, line in enumerate( + to_native(lines, errors="surrogate_or_strict").split("\n") + ): + text = entry_reg.sub("", line).strip() + + cfg = ConfigLine(line) + + if not text or ignore_line(text, comment_tokens): + continue + + # handle top level commands + if toplevel.match(line): + ancestors = [cfg] + indents = [0] + + # handle sub level commands + else: + match = childline.match(line) + line_indent = match.start(1) + + if line_indent < indents[-1]: + while indents[-1] > line_indent: + indents.pop() + + if line_indent > indents[-1]: + indents.append(line_indent) + + curlevel = len(indents) - 1 + parent_level = curlevel - 1 + + cfg._parents = ancestors[:curlevel] + + if curlevel > len(ancestors): + config.append(cfg) + continue + + for i in range(curlevel, len(ancestors)): + ancestors.pop() + + ancestors.append(cfg) + ancestors[parent_level].add_child(cfg) + + config.append(cfg) + + return config + + def get_object(self, path): + for item in self.items: + if item.text == path[-1]: + if item.parents == path[:-1]: + return item + + def get_block(self, path): + if not isinstance(path, list): + raise AssertionError("path argument must be a list object") + obj = self.get_object(path) + if not obj: + raise ValueError("path does not exist in config") + return self._expand_block(obj) + + def get_block_config(self, path): + block = self.get_block(path) + return dumps(block, "block") + + def _expand_block(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj._children: + if child in S: + continue + self._expand_block(child, S) + return S + + def _diff_line(self, other): + updates = list() + for item in self.items: + if item not in other: + updates.append(item) + return updates + + def _diff_strict(self, other): + updates = list() + # block extracted from other does not have all parents + # but the last one. In case of multiple parents we need + # to add additional parents. + if other and isinstance(other, list) and len(other) > 0: + start_other = other[0] + if start_other.parents: + for parent in start_other.parents: + other.insert(0, ConfigLine(parent)) + for index, line in enumerate(self.items): + try: + if str(line).strip() != str(other[index]).strip(): + updates.append(line) + except (AttributeError, IndexError): + updates.append(line) + return updates + + def _diff_exact(self, other): + updates = list() + if len(other) != len(self.items): + updates.extend(self.items) + else: + for ours, theirs in zip(self.items, other): + if ours != theirs: + updates.extend(self.items) + break + return updates + + def difference(self, other, match="line", path=None, replace=None): + """Perform a config diff against the another network config + + :param other: instance of NetworkConfig to diff against + :param match: type of diff to perform. valid values are 'line', + 'strict', 'exact' + :param path: context in the network config to filter the diff + :param replace: the method used to generate the replacement lines. + valid values are 'block', 'line' + + :returns: a string of lines that are different + """ + if path and match != "line": + try: + other = other.get_block(path) + except ValueError: + other = list() + else: + other = other.items + + # generate a list of ConfigLines that aren't in other + meth = getattr(self, "_diff_%s" % match) + updates = meth(other) + + if replace == "block": + parents = list() + for item in updates: + if not item.has_parents: + parents.append(item) + else: + for p in item._parents: + if p not in parents: + parents.append(p) + + updates = list() + for item in parents: + updates.extend(self._expand_block(item)) + + visited = set() + expanded = list() + + for item in updates: + for p in item._parents: + if p.line not in visited: + visited.add(p.line) + expanded.append(p) + expanded.append(item) + visited.add(item.line) + + return expanded + + def add(self, lines, parents=None): + ancestors = list() + offset = 0 + obj = None + + # global config command + if not parents: + for line in lines: + # handle ignore lines + if ignore_line(line): + continue + + item = ConfigLine(line) + item.raw = line + if item not in self.items: + self.items.append(item) + + else: + for index, p in enumerate(parents): + try: + i = index + 1 + obj = self.get_block(parents[:i])[0] + ancestors.append(obj) + + except ValueError: + # add parent to config + offset = index * self._indent + obj = ConfigLine(p) + obj.raw = p.rjust(len(p) + offset) + if ancestors: + obj._parents = list(ancestors) + ancestors[-1]._children.append(obj) + self.items.append(obj) + ancestors.append(obj) + + # add child objects + for line in lines: + # handle ignore lines + if ignore_line(line): + continue + + # check if child already exists + for child in ancestors[-1]._children: + if child.text == line: + break + else: + offset = len(parents) * self._indent + item = ConfigLine(line) + item.raw = line.rjust(len(line) + offset) + item._parents = ancestors + ancestors[-1]._children.append(item) + self.items.append(item) + + +class CustomNetworkConfig(NetworkConfig): + def items_text(self): + return [item.text for item in self.items] + + def expand_section(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj.child_objs: + if child in S: + continue + self.expand_section(child, S) + return S + + def to_block(self, section): + return "\n".join([item.raw for item in section]) + + def get_section(self, path): + try: + section = self.get_section_objects(path) + return self.to_block(section) + except ValueError: + return list() + + def get_section_objects(self, path): + if not isinstance(path, list): + path = [path] + obj = self.get_object(path) + if not obj: + raise ValueError("path does not exist in config") + return self.expand_section(obj) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py new file mode 100644 index 0000000..477d318 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py @@ -0,0 +1,162 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The facts base class +this contains methods common to all facts subsets +""" +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import ( + get_resource_connection, +) +from ansible.module_utils.six import iteritems + + +class FactsBase(object): + """ + The facts base class + """ + + def __init__(self, module): + self._module = module + self._warnings = [] + self._gather_subset = module.params.get("gather_subset") + self._gather_network_resources = module.params.get( + "gather_network_resources" + ) + self._connection = None + if module.params.get("state") not in ["rendered", "parsed"]: + self._connection = get_resource_connection(module) + + self.ansible_facts = {"ansible_network_resources": {}} + self.ansible_facts["ansible_net_gather_network_resources"] = list() + self.ansible_facts["ansible_net_gather_subset"] = list() + + if not self._gather_subset: + self._gather_subset = ["!config"] + if not self._gather_network_resources: + self._gather_network_resources = ["!all"] + + def gen_runable(self, subsets, valid_subsets, resource_facts=False): + """ Generate the runable subset + + :param module: The module instance + :param subsets: The provided subsets + :param valid_subsets: The valid subsets + :param resource_facts: A boolean flag + :rtype: list + :returns: The runable subsets + """ + runable_subsets = set() + exclude_subsets = set() + minimal_gather_subset = set() + if not resource_facts: + minimal_gather_subset = frozenset(["default"]) + + for subset in subsets: + if subset == "all": + runable_subsets.update(valid_subsets) + continue + if subset == "min" and minimal_gather_subset: + runable_subsets.update(minimal_gather_subset) + continue + if subset.startswith("!"): + subset = subset[1:] + if subset == "min": + exclude_subsets.update(minimal_gather_subset) + continue + if subset == "all": + exclude_subsets.update( + valid_subsets - minimal_gather_subset + ) + continue + exclude = True + else: + exclude = False + + if subset not in valid_subsets: + self._module.fail_json( + msg="Subset must be one of [%s], got %s" + % ( + ", ".join(sorted([item for item in valid_subsets])), + subset, + ) + ) + + if exclude: + exclude_subsets.add(subset) + else: + runable_subsets.add(subset) + + if not runable_subsets: + runable_subsets.update(valid_subsets) + runable_subsets.difference_update(exclude_subsets) + return runable_subsets + + def get_network_resources_facts( + self, facts_resource_obj_map, resource_facts_type=None, data=None + ): + """ + :param fact_resource_subsets: + :param data: previously collected configuration + :return: + """ + if not resource_facts_type: + resource_facts_type = self._gather_network_resources + + restorun_subsets = self.gen_runable( + resource_facts_type, + frozenset(facts_resource_obj_map.keys()), + resource_facts=True, + ) + if restorun_subsets: + self.ansible_facts["ansible_net_gather_network_resources"] = list( + restorun_subsets + ) + instances = list() + for key in restorun_subsets: + fact_cls_obj = facts_resource_obj_map.get(key) + if fact_cls_obj: + instances.append(fact_cls_obj(self._module)) + else: + self._warnings.extend( + [ + "network resource fact gathering for '%s' is not supported" + % key + ] + ) + + for inst in instances: + inst.populate_facts(self._connection, self.ansible_facts, data) + + def get_network_legacy_facts( + self, fact_legacy_obj_map, legacy_facts_type=None + ): + if not legacy_facts_type: + legacy_facts_type = self._gather_subset + + runable_subsets = self.gen_runable( + legacy_facts_type, frozenset(fact_legacy_obj_map.keys()) + ) + if runable_subsets: + facts = dict() + # default subset should always returned be with legacy facts subsets + if "default" not in runable_subsets: + runable_subsets.add("default") + self.ansible_facts["ansible_net_gather_subset"] = list( + runable_subsets + ) + + instances = list() + for key in runable_subsets: + instances.append(fact_legacy_obj_map[key](self._module)) + + for inst in instances: + inst.populate() + facts.update(inst.facts) + self._warnings.extend(inst.warnings) + + for key, value in iteritems(facts): + key = "ansible_net_%s" % key + self.ansible_facts[key] = value diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py new file mode 100644 index 0000000..53a91e8 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py @@ -0,0 +1,179 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import sys + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError + +try: + from ncclient.xml_ import NCElement, new_ele, sub_ele + + HAS_NCCLIENT = True +except (ImportError, AttributeError): + HAS_NCCLIENT = False + +try: + from lxml.etree import Element, fromstring, XMLSyntaxError +except ImportError: + from xml.etree.ElementTree import Element, fromstring + + if sys.version_info < (2, 7): + from xml.parsers.expat import ExpatError as XMLSyntaxError + else: + from xml.etree.ElementTree import ParseError as XMLSyntaxError + +NS_MAP = {"nc": "urn:ietf:params:xml:ns:netconf:base:1.0"} + + +def exec_rpc(module, *args, **kwargs): + connection = NetconfConnection(module._socket_path) + return connection.execute_rpc(*args, **kwargs) + + +class NetconfConnection(Connection): + def __init__(self, socket_path): + super(NetconfConnection, self).__init__(socket_path) + + def __rpc__(self, name, *args, **kwargs): + """Executes the json-rpc and returns the output received + from remote device. + :name: rpc method to be executed over connection plugin that implements jsonrpc 2.0 + :args: Ordered list of params passed as arguments to rpc method + :kwargs: Dict of valid key, value pairs passed as arguments to rpc method + + For usage refer the respective connection plugin docs. + """ + self.check_rc = kwargs.pop("check_rc", True) + self.ignore_warning = kwargs.pop("ignore_warning", True) + + response = self._exec_jsonrpc(name, *args, **kwargs) + if "error" in response: + rpc_error = response["error"].get("data") + return self.parse_rpc_error( + to_bytes(rpc_error, errors="surrogate_then_replace") + ) + + return fromstring( + to_bytes(response["result"], errors="surrogate_then_replace") + ) + + def parse_rpc_error(self, rpc_error): + if self.check_rc: + try: + error_root = fromstring(rpc_error) + root = Element("root") + root.append(error_root) + + error_list = root.findall(".//nc:rpc-error", NS_MAP) + if not error_list: + raise ConnectionError( + to_text(rpc_error, errors="surrogate_then_replace") + ) + + warnings = [] + for error in error_list: + message_ele = error.find("./nc:error-message", NS_MAP) + + if message_ele is None: + message_ele = error.find("./nc:error-info", NS_MAP) + + message = ( + message_ele.text if message_ele is not None else None + ) + + severity = error.find("./nc:error-severity", NS_MAP).text + + if ( + severity == "warning" + and self.ignore_warning + and message is not None + ): + warnings.append(message) + else: + raise ConnectionError( + to_text(rpc_error, errors="surrogate_then_replace") + ) + return warnings + except XMLSyntaxError: + raise ConnectionError(rpc_error) + + +def transform_reply(): + return b""" + + + + + + + + + + + + + + + + + + + + + """ + + +# Note: Workaround for ncclient 0.5.3 +def remove_namespaces(data): + if not HAS_NCCLIENT: + raise ImportError( + "ncclient is required but does not appear to be installed. " + "It can be installed using `pip install ncclient`" + ) + return NCElement(data, transform_reply()).data_xml + + +def build_root_xml_node(tag): + return new_ele(tag) + + +def build_child_xml_node(parent, tag, text=None, attrib=None): + element = sub_ele(parent, tag) + if text: + element.text = to_text(text) + if attrib: + element.attrib.update(attrib) + return element + + +def build_subtree(parent, path): + element = parent + for field in path.split("/"): + sub_element = build_child_xml_node(element, field) + element = sub_element + return element diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py new file mode 100644 index 0000000..555fc71 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py @@ -0,0 +1,275 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2015 Peter Sprygada, +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import traceback +import json + +from ansible.module_utils._text import to_text, to_native +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.basic import env_fallback +from ansible.module_utils.connection import Connection, ConnectionError +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import ( + NetconfConnection, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import ( + Cli, +) +from ansible.module_utils.six import iteritems + + +NET_TRANSPORT_ARGS = dict( + host=dict(required=True), + port=dict(type="int"), + username=dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])), + password=dict( + no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]) + ), + ssh_keyfile=dict( + fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path" + ), + authorize=dict( + default=False, + fallback=(env_fallback, ["ANSIBLE_NET_AUTHORIZE"]), + type="bool", + ), + auth_pass=dict( + no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_AUTH_PASS"]) + ), + provider=dict(type="dict", no_log=True), + transport=dict(choices=list()), + timeout=dict(default=10, type="int"), +) + +NET_CONNECTION_ARGS = dict() + +NET_CONNECTIONS = dict() + + +def _transitional_argument_spec(): + argument_spec = {} + for key, value in iteritems(NET_TRANSPORT_ARGS): + value["required"] = False + argument_spec[key] = value + return argument_spec + + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +class ModuleStub(object): + def __init__(self, argument_spec, fail_json): + self.params = dict() + for key, value in argument_spec.items(): + self.params[key] = value.get("default") + self.fail_json = fail_json + + +class NetworkError(Exception): + def __init__(self, msg, **kwargs): + super(NetworkError, self).__init__(msg) + self.kwargs = kwargs + + +class Config(object): + def __init__(self, connection): + self.connection = connection + + def __call__(self, commands, **kwargs): + lines = to_list(commands) + return self.connection.configure(lines, **kwargs) + + def load_config(self, commands, **kwargs): + commands = to_list(commands) + return self.connection.load_config(commands, **kwargs) + + def get_config(self, **kwargs): + return self.connection.get_config(**kwargs) + + def save_config(self): + return self.connection.save_config() + + +class NetworkModule(AnsibleModule): + def __init__(self, *args, **kwargs): + connect_on_load = kwargs.pop("connect_on_load", True) + + argument_spec = NET_TRANSPORT_ARGS.copy() + argument_spec["transport"]["choices"] = NET_CONNECTIONS.keys() + argument_spec.update(NET_CONNECTION_ARGS.copy()) + + if kwargs.get("argument_spec"): + argument_spec.update(kwargs["argument_spec"]) + kwargs["argument_spec"] = argument_spec + + super(NetworkModule, self).__init__(*args, **kwargs) + + self.connection = None + self._cli = None + self._config = None + + try: + transport = self.params["transport"] or "__default__" + cls = NET_CONNECTIONS[transport] + self.connection = cls() + except KeyError: + self.fail_json( + msg="Unknown transport or no default transport specified" + ) + except (TypeError, NetworkError) as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + if connect_on_load: + self.connect() + + @property + def cli(self): + if not self.connected: + self.connect() + if self._cli: + return self._cli + self._cli = Cli(self.connection) + return self._cli + + @property + def config(self): + if not self.connected: + self.connect() + if self._config: + return self._config + self._config = Config(self.connection) + return self._config + + @property + def connected(self): + return self.connection._connected + + def _load_params(self): + super(NetworkModule, self)._load_params() + provider = self.params.get("provider") or dict() + for key, value in provider.items(): + for args in [NET_TRANSPORT_ARGS, NET_CONNECTION_ARGS]: + if key in args: + if self.params.get(key) is None and value is not None: + self.params[key] = value + + def connect(self): + try: + if not self.connected: + self.connection.connect(self.params) + if self.params["authorize"]: + self.connection.authorize(self.params) + self.log( + "connected to %s:%s using %s" + % ( + self.params["host"], + self.params["port"], + self.params["transport"], + ) + ) + except NetworkError as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + def disconnect(self): + try: + if self.connected: + self.connection.disconnect() + self.log("disconnected from %s" % self.params["host"]) + except NetworkError as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + +def register_transport(transport, default=False): + def register(cls): + NET_CONNECTIONS[transport] = cls + if default: + NET_CONNECTIONS["__default__"] = cls + return cls + + return register + + +def add_argument(key, value): + NET_CONNECTION_ARGS[key] = value + + +def get_resource_connection(module): + if hasattr(module, "_connection"): + return module._connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api in ("cliconf", "nxapi", "eapi", "exosapi"): + module._connection = Connection(module._socket_path) + elif network_api == "netconf": + module._connection = NetconfConnection(module._socket_path) + elif network_api == "local": + # This isn't supported, but we shouldn't fail here. + # Set the connection to a fake connection so it fails sensibly. + module._connection = LocalResourceConnection(module) + else: + module.fail_json( + msg="Invalid connection type {0!s}".format(network_api) + ) + + return module._connection + + +def get_capabilities(module): + if hasattr(module, "capabilities"): + return module._capabilities + try: + capabilities = Connection(module._socket_path).get_capabilities() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + except AssertionError: + # No socket_path, connection most likely local. + return dict(network_api="local") + module._capabilities = json.loads(capabilities) + + return module._capabilities + + +class LocalResourceConnection: + def __init__(self, module): + self.module = module + + def get(self, *args, **kwargs): + self.module.fail_json( + msg="Network resource modules not supported over local connection." + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py new file mode 100644 index 0000000..2dd1de9 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py @@ -0,0 +1,316 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2015 Peter Sprygada, +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re +import shlex +import time + +from ansible.module_utils.parsing.convert_bool import ( + BOOLEANS_TRUE, + BOOLEANS_FALSE, +) +from ansible.module_utils.six import string_types, text_type +from ansible.module_utils.six.moves import zip + + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +class FailedConditionsError(Exception): + def __init__(self, msg, failed_conditions): + super(FailedConditionsError, self).__init__(msg) + self.failed_conditions = failed_conditions + + +class FailedConditionalError(Exception): + def __init__(self, msg, failed_conditional): + super(FailedConditionalError, self).__init__(msg) + self.failed_conditional = failed_conditional + + +class AddCommandError(Exception): + def __init__(self, msg, command): + super(AddCommandError, self).__init__(msg) + self.command = command + + +class AddConditionError(Exception): + def __init__(self, msg, condition): + super(AddConditionError, self).__init__(msg) + self.condition = condition + + +class Cli(object): + def __init__(self, connection): + self.connection = connection + self.default_output = connection.default_output or "text" + self._commands = list() + + @property + def commands(self): + return [str(c) for c in self._commands] + + def __call__(self, commands, output=None): + objects = list() + for cmd in to_list(commands): + objects.append(self.to_command(cmd, output)) + return self.connection.run_commands(objects) + + def to_command( + self, command, output=None, prompt=None, response=None, **kwargs + ): + output = output or self.default_output + if isinstance(command, Command): + return command + if isinstance(prompt, string_types): + prompt = re.compile(re.escape(prompt)) + return Command( + command, output, prompt=prompt, response=response, **kwargs + ) + + def add_commands(self, commands, output=None, **kwargs): + for cmd in commands: + self._commands.append(self.to_command(cmd, output, **kwargs)) + + def run_commands(self): + responses = self.connection.run_commands(self._commands) + for resp, cmd in zip(responses, self._commands): + cmd.response = resp + + # wipe out the commands list to avoid issues if additional + # commands are executed later + self._commands = list() + + return responses + + +class Command(object): + def __init__( + self, command, output=None, prompt=None, response=None, **kwargs + ): + + self.command = command + self.output = output + self.command_string = command + + self.prompt = prompt + self.response = response + + self.args = kwargs + + def __str__(self): + return self.command_string + + +class CommandRunner(object): + def __init__(self, module): + self.module = module + + self.items = list() + self.conditionals = set() + + self.commands = list() + + self.retries = 10 + self.interval = 1 + + self.match = "all" + + self._default_output = module.connection.default_output + + def add_command( + self, command, output=None, prompt=None, response=None, **kwargs + ): + if command in [str(c) for c in self.commands]: + raise AddCommandError( + "duplicated command detected", command=command + ) + cmd = self.module.cli.to_command( + command, output=output, prompt=prompt, response=response, **kwargs + ) + self.commands.append(cmd) + + def get_command(self, command, output=None): + for cmd in self.commands: + if cmd.command == command: + return cmd.response + raise ValueError("command '%s' not found" % command) + + def get_responses(self): + return [cmd.response for cmd in self.commands] + + def add_conditional(self, condition): + try: + self.conditionals.add(Conditional(condition)) + except AttributeError as exc: + raise AddConditionError(msg=str(exc), condition=condition) + + def run(self): + while self.retries > 0: + self.module.cli.add_commands(self.commands) + responses = self.module.cli.run_commands() + + for item in list(self.conditionals): + if item(responses): + if self.match == "any": + return item + self.conditionals.remove(item) + + if not self.conditionals: + break + + time.sleep(self.interval) + self.retries -= 1 + else: + failed_conditions = [item.raw for item in self.conditionals] + errmsg = ( + "One or more conditional statements have not been satisfied" + ) + raise FailedConditionsError(errmsg, failed_conditions) + + +class Conditional(object): + """Used in command modules to evaluate waitfor conditions + """ + + OPERATORS = { + "eq": ["eq", "=="], + "neq": ["neq", "ne", "!="], + "gt": ["gt", ">"], + "ge": ["ge", ">="], + "lt": ["lt", "<"], + "le": ["le", "<="], + "contains": ["contains"], + "matches": ["matches"], + } + + def __init__(self, conditional, encoding=None): + self.raw = conditional + self.negate = False + try: + components = shlex.split(conditional) + key, val = components[0], components[-1] + op_components = components[1:-1] + if "not" in op_components: + self.negate = True + op_components.pop(op_components.index("not")) + op = op_components[0] + + except ValueError: + raise ValueError("failed to parse conditional") + + self.key = key + self.func = self._func(op) + self.value = self._cast_value(val) + + def __call__(self, data): + value = self.get_value(dict(result=data)) + if not self.negate: + return self.func(value) + else: + return not self.func(value) + + def _cast_value(self, value): + if value in BOOLEANS_TRUE: + return True + elif value in BOOLEANS_FALSE: + return False + elif re.match(r"^\d+\.d+$", value): + return float(value) + elif re.match(r"^\d+$", value): + return int(value) + else: + return text_type(value) + + def _func(self, oper): + for func, operators in self.OPERATORS.items(): + if oper in operators: + return getattr(self, func) + raise AttributeError("unknown operator: %s" % oper) + + def get_value(self, result): + try: + return self.get_json(result) + except (IndexError, TypeError, AttributeError): + msg = "unable to apply conditional to result" + raise FailedConditionalError(msg, self.raw) + + def get_json(self, result): + string = re.sub(r"\[[\'|\"]", ".", self.key) + string = re.sub(r"[\'|\"]\]", ".", string) + parts = re.split(r"\.(?=[^\]]*(?:\[|$))", string) + for part in parts: + match = re.findall(r"\[(\S+?)\]", part) + if match: + key = part[: part.find("[")] + result = result[key] + for m in match: + try: + m = int(m) + except ValueError: + m = str(m) + result = result[m] + else: + result = result.get(part) + return result + + def number(self, value): + if "." in str(value): + return float(value) + else: + return int(value) + + def eq(self, value): + return value == self.value + + def neq(self, value): + return value != self.value + + def gt(self, value): + return self.number(value) > self.value + + def ge(self, value): + return self.number(value) >= self.value + + def lt(self, value): + return self.number(value) < self.value + + def le(self, value): + return self.number(value) <= self.value + + def contains(self, value): + return str(self.value) in value + + def matches(self, value): + match = re.search(self.value, value, re.M) + return match is not None diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py new file mode 100644 index 0000000..64eca15 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py @@ -0,0 +1,686 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Networking tools for network modules only + +import re +import ast +import operator +import socket +import json + +from itertools import chain + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils import basic +from ansible.module_utils.parsing.convert_bool import boolean + +# Backwards compatibility for 3rd party modules +# TODO(pabelanger): With move to ansible.netcommon, we should clean this code +# up and have modules import directly themself. +from ansible.module_utils.common.network import ( # noqa: F401 + to_bits, + is_netmask, + is_masklen, + to_netmask, + to_masklen, + to_subnet, + to_ipv6_network, + VALID_MASKS, +) + +try: + from jinja2 import Environment, StrictUndefined + from jinja2.exceptions import UndefinedError + + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False + + +OPERATORS = frozenset(["ge", "gt", "eq", "neq", "lt", "le"]) +ALIASES = frozenset( + [("min", "ge"), ("max", "le"), ("exactly", "eq"), ("neq", "ne")] +) + + +def to_list(val): + if isinstance(val, (list, tuple, set)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +def to_lines(stdout): + for item in stdout: + if isinstance(item, string_types): + item = to_text(item).split("\n") + yield item + + +def transform_commands(module): + transform = ComplexList( + dict( + command=dict(key=True), + output=dict(), + prompt=dict(type="list"), + answer=dict(type="list"), + newline=dict(type="bool", default=True), + sendonly=dict(type="bool", default=False), + check_all=dict(type="bool", default=False), + ), + module, + ) + + return transform(module.params["commands"]) + + +def sort_list(val): + if isinstance(val, list): + return sorted(val) + return val + + +class Entity(object): + """Transforms a dict to with an argument spec + + This class will take a dict and apply an Ansible argument spec to the + values. The resulting dict will contain all of the keys in the param + with appropriate values set. + + Example:: + + argument_spec = dict( + command=dict(key=True), + display=dict(default='text', choices=['text', 'json']), + validate=dict(type='bool') + ) + transform = Entity(module, argument_spec) + value = dict(command='foo') + result = transform(value) + print result + {'command': 'foo', 'display': 'text', 'validate': None} + + Supported argument spec: + * key - specifies how to map a single value to a dict + * read_from - read and apply the argument_spec from the module + * required - a value is required + * type - type of value (uses AnsibleModule type checker) + * fallback - implements fallback function + * choices - set of valid options + * default - default value + """ + + def __init__( + self, module, attrs=None, args=None, keys=None, from_argspec=False + ): + args = [] if args is None else args + + self._attributes = attrs or {} + self._module = module + + for arg in args: + self._attributes[arg] = dict() + if from_argspec: + self._attributes[arg]["read_from"] = arg + if keys and arg in keys: + self._attributes[arg]["key"] = True + + self.attr_names = frozenset(self._attributes.keys()) + + _has_key = False + + for name, attr in iteritems(self._attributes): + if attr.get("read_from"): + if attr["read_from"] not in self._module.argument_spec: + module.fail_json( + msg="argument %s does not exist" % attr["read_from"] + ) + spec = self._module.argument_spec.get(attr["read_from"]) + for key, value in iteritems(spec): + if key not in attr: + attr[key] = value + + if attr.get("key"): + if _has_key: + module.fail_json(msg="only one key value can be specified") + _has_key = True + attr["required"] = True + + def serialize(self): + return self._attributes + + def to_dict(self, value): + obj = {} + for name, attr in iteritems(self._attributes): + if attr.get("key"): + obj[name] = value + else: + obj[name] = attr.get("default") + return obj + + def __call__(self, value, strict=True): + if not isinstance(value, dict): + value = self.to_dict(value) + + if strict: + unknown = set(value).difference(self.attr_names) + if unknown: + self._module.fail_json( + msg="invalid keys: %s" % ",".join(unknown) + ) + + for name, attr in iteritems(self._attributes): + if value.get(name) is None: + value[name] = attr.get("default") + + if attr.get("fallback") and not value.get(name): + fallback = attr.get("fallback", (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + value[name] = fallback_strategy( + *fallback_args, **fallback_kwargs + ) + except basic.AnsibleFallbackNotFound: + continue + + if attr.get("required") and value.get(name) is None: + self._module.fail_json( + msg="missing required attribute %s" % name + ) + + if "choices" in attr: + if value[name] not in attr["choices"]: + self._module.fail_json( + msg="%s must be one of %s, got %s" + % (name, ", ".join(attr["choices"]), value[name]) + ) + + if value[name] is not None: + value_type = attr.get("type", "str") + type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[ + value_type + ] + type_checker(value[name]) + elif value.get(name): + value[name] = self._module.params[name] + + return value + + +class EntityCollection(Entity): + """Extends ```Entity``` to handle a list of dicts """ + + def __call__(self, iterable, strict=True): + if iterable is None: + iterable = [ + super(EntityCollection, self).__call__( + self._module.params, strict + ) + ] + + if not isinstance(iterable, (list, tuple)): + self._module.fail_json(msg="value must be an iterable") + + return [ + (super(EntityCollection, self).__call__(i, strict)) + for i in iterable + ] + + +# these two are for backwards compatibility and can be removed once all of the +# modules that use them are updated +class ComplexDict(Entity): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexDict, self).__init__(module, attrs, *args, **kwargs) + + +class ComplexList(EntityCollection): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexList, self).__init__(module, attrs, *args, **kwargs) + + +def dict_diff(base, comparable): + """ Generate a dict object of differences + + This function will compare two dict objects and return the difference + between them as a dict object. For scalar values, the key will reflect + the updated value. If the key does not exist in `comparable`, then then no + key will be returned. For lists, the value in comparable will wholly replace + the value in base for the key. For dicts, the returned value will only + return keys that are different. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + + :returns: new dict object with differences + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type ") + if not isinstance(comparable, dict): + if comparable is None: + comparable = dict() + else: + raise AssertionError("`comparable` must be of type ") + + updates = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + item = comparable.get(key) + if item is not None: + sub_diff = dict_diff(value, comparable[key]) + if sub_diff: + updates[key] = sub_diff + else: + comparable_value = comparable.get(key) + if comparable_value is not None: + if sort_list(base[key]) != sort_list(comparable_value): + updates[key] = comparable_value + + for key in set(comparable.keys()).difference(base.keys()): + updates[key] = comparable.get(key) + + return updates + + +def dict_merge(base, other): + """ Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. If the value is a list object, the two lists will + be combined and duplicate entries removed. + + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type ") + if not isinstance(other, dict): + raise AssertionError("`other` must be of type ") + + combined = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + if isinstance(other[key], Mapping): + combined[key] = dict_merge(value, other[key]) + else: + combined[key] = other[key] + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if item is not None: + try: + combined[key] = list(set(chain(value, item))) + except TypeError: + value.extend([i for i in item if i not in value]) + combined[key] = value + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) + + return combined + + +def param_list_to_dict(param_list, unique_key="name", remove_key=True): + """Rotates a list of dictionaries to be a dictionary of dictionaries. + + :param param_list: The aforementioned list of dictionaries + :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value + behind this key will be the key each dictionary can be found at in the new root dictionary + :param remove_key: If True, remove unique_key from the individual dictionaries before returning. + """ + param_dict = {} + for params in param_list: + params = params.copy() + if remove_key: + name = params.pop(unique_key) + else: + name = params.get(unique_key) + param_dict[name] = params + + return param_dict + + +def conditional(expr, val, cast=None): + match = re.match(r"^(.+)\((.+)\)$", str(expr), re.I) + if match: + op, arg = match.groups() + else: + op = "eq" + if " " in str(expr): + raise AssertionError("invalid expression: cannot contain spaces") + arg = expr + + if cast is None and val is not None: + arg = type(val)(arg) + elif callable(cast): + arg = cast(arg) + val = cast(val) + + op = next((oper for alias, oper in ALIASES if op == alias), op) + + if not hasattr(operator, op) and op not in OPERATORS: + raise ValueError("unknown operator: %s" % op) + + func = getattr(operator, op) + return func(val, arg) + + +def ternary(value, true_val, false_val): + """ value ? true_val : false_val """ + if value: + return true_val + else: + return false_val + + +def remove_default_spec(spec): + for item in spec: + if "default" in spec[item]: + del spec[item]["default"] + + +def validate_ip_address(address): + try: + socket.inet_aton(address) + except socket.error: + return False + return address.count(".") == 3 + + +def validate_ip_v6_address(address): + try: + socket.inet_pton(socket.AF_INET6, address) + except socket.error: + return False + return True + + +def validate_prefix(prefix): + if prefix and not 0 <= int(prefix) <= 32: + return False + return True + + +def load_provider(spec, args): + provider = args.get("provider") or {} + for key, value in iteritems(spec): + if key not in provider: + if "fallback" in value: + provider[key] = _fallback(value["fallback"]) + elif "default" in value: + provider[key] = value["default"] + else: + provider[key] = None + if "authorize" in provider: + # Coerce authorize to provider if a string has somehow snuck in. + provider["authorize"] = boolean(provider["authorize"] or False) + args["provider"] = provider + return provider + + +def _fallback(fallback): + strategy = fallback[0] + args = [] + kwargs = {} + + for item in fallback[1:]: + if isinstance(item, dict): + kwargs = item + else: + args = item + try: + return strategy(*args, **kwargs) + except basic.AnsibleFallbackNotFound: + pass + + +def generate_dict(spec): + """ + Generate dictionary which is in sync with argspec + + :param spec: A dictionary that is the argspec of the module + :rtype: A dictionary + :returns: A dictionary in sync with argspec with default value + """ + obj = {} + if not spec: + return obj + + for key, val in iteritems(spec): + if "default" in val: + dct = {key: val["default"]} + elif "type" in val and val["type"] == "dict": + dct = {key: generate_dict(val["options"])} + else: + dct = {key: None} + obj.update(dct) + return obj + + +def parse_conf_arg(cfg, arg): + """ + Parse config based on argument + + :param cfg: A text string which is a line of configuration. + :param arg: A text string which is to be matched. + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r"%s (.+)(\n|$)" % arg, cfg, re.M) + if match: + result = match.group(1).strip() + else: + result = None + return result + + +def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str="no"): + """ + Parse config based on command + + :param cfg: A text string which is a line of configuration. + :param cmd: A text string which is the command to be matched + :param res1: A text string to be returned if the command is present + :param res2: A text string to be returned if the negate command + is present + :param delete_str: A text string to identify the start of the + negate command + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r"\n\s+%s(\n|$)" % cmd, cfg) + if match: + return res1 + if res2 is not None: + match = re.search(r"\n\s+%s %s(\n|$)" % (delete_str, cmd), cfg) + if match: + return res2 + return None + + +def get_xml_conf_arg(cfg, path, data="text"): + """ + :param cfg: The top level configuration lxml Element tree object + :param path: The relative xpath w.r.t to top level element (cfg) + to be searched in the xml hierarchy + :param data: The type of data to be returned for the matched xml node. + Valid values are text, tag, attrib, with default as text. + :return: Returns the required type for the matched xml node or else None + """ + match = cfg.xpath(path) + if len(match): + if data == "tag": + result = getattr(match[0], "tag") + elif data == "attrib": + result = getattr(match[0], "attrib") + else: + result = getattr(match[0], "text") + else: + result = None + return result + + +def remove_empties(cfg_dict): + """ + Generate final config dictionary + + :param cfg_dict: A dictionary parsed in the facts system + :rtype: A dictionary + :returns: A dictionary by eliminating keys that have null values + """ + final_cfg = {} + if not cfg_dict: + return final_cfg + + for key, val in iteritems(cfg_dict): + dct = None + if isinstance(val, dict): + child_val = remove_empties(val) + if child_val: + dct = {key: child_val} + elif ( + isinstance(val, list) + and val + and all([isinstance(x, dict) for x in val]) + ): + child_val = [remove_empties(x) for x in val] + if child_val: + dct = {key: child_val} + elif val not in [None, [], {}, (), ""]: + dct = {key: val} + if dct: + final_cfg.update(dct) + return final_cfg + + +def validate_config(spec, data): + """ + Validate if the input data against the AnsibleModule spec format + :param spec: Ansible argument spec + :param data: Data to be validated + :return: + """ + params = basic._ANSIBLE_ARGS + basic._ANSIBLE_ARGS = to_bytes(json.dumps({"ANSIBLE_MODULE_ARGS": data})) + validated_data = basic.AnsibleModule(spec).params + basic._ANSIBLE_ARGS = params + return validated_data + + +def search_obj_in_list(name, lst, key="name"): + if not lst: + return None + else: + for item in lst: + if item.get(key) == name: + return item + + +class Template: + def __init__(self): + if not HAS_JINJA2: + raise ImportError( + "jinja2 is required but does not appear to be installed. " + "It can be installed using `pip install jinja2`" + ) + + self.env = Environment(undefined=StrictUndefined) + self.env.filters.update({"ternary": ternary}) + + def __call__(self, value, variables=None, fail_on_undefined=True): + variables = variables or {} + + if not self.contains_vars(value): + return value + + try: + value = self.env.from_string(value).render(variables) + except UndefinedError: + if not fail_on_undefined: + return None + raise + + if value: + try: + return ast.literal_eval(value) + except Exception: + return str(value) + else: + return None + + def contains_vars(self, data): + if isinstance(data, string_types): + for marker in ( + self.env.block_start_string, + self.env.variable_start_string, + self.env.comment_start_string, + ): + if marker in data: + return True + return False diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py new file mode 100644 index 0000000..1f03299 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py @@ -0,0 +1,147 @@ +# +# (c) 2018 Red Hat, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +import json + +from copy import deepcopy +from contextlib import contextmanager + +try: + from lxml.etree import fromstring, tostring +except ImportError: + from xml.etree.ElementTree import fromstring, tostring + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import ( + NetconfConnection, +) + + +IGNORE_XML_ATTRIBUTE = () + + +def get_connection(module): + if hasattr(module, "_netconf_connection"): + return module._netconf_connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api == "netconf": + module._netconf_connection = NetconfConnection(module._socket_path) + else: + module.fail_json(msg="Invalid connection type %s" % network_api) + + return module._netconf_connection + + +def get_capabilities(module): + if hasattr(module, "_netconf_capabilities"): + return module._netconf_capabilities + + capabilities = Connection(module._socket_path).get_capabilities() + module._netconf_capabilities = json.loads(capabilities) + return module._netconf_capabilities + + +def lock_configuration(module, target=None): + conn = get_connection(module) + return conn.lock(target=target) + + +def unlock_configuration(module, target=None): + conn = get_connection(module) + return conn.unlock(target=target) + + +@contextmanager +def locked_config(module, target=None): + try: + lock_configuration(module, target=target) + yield + finally: + unlock_configuration(module, target=target) + + +def get_config(module, source, filter=None, lock=False): + conn = get_connection(module) + try: + locked = False + if lock: + conn.lock(target=source) + locked = True + response = conn.get_config(source=source, filter=filter) + + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + finally: + if locked: + conn.unlock(target=source) + + return response + + +def get(module, filter, lock=False): + conn = get_connection(module) + try: + locked = False + if lock: + conn.lock(target="running") + locked = True + + response = conn.get(filter=filter) + + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + finally: + if locked: + conn.unlock(target="running") + + return response + + +def dispatch(module, request): + conn = get_connection(module) + try: + response = conn.dispatch(request) + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + return response + + +def sanitize_xml(data): + tree = fromstring( + to_bytes(deepcopy(data), errors="surrogate_then_replace") + ) + for element in tree.getiterator(): + # remove attributes + attribute = element.attrib + if attribute: + for key in list(attribute): + if key not in IGNORE_XML_ATTRIBUTE: + attribute.pop(key) + return to_text(tostring(tree), errors="surrogate_then_replace").strip() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py new file mode 100644 index 0000000..fba46be --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py @@ -0,0 +1,61 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2018 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from ansible.module_utils.connection import Connection + + +def get(module, path=None, content=None, fields=None, output="json"): + if path is None: + raise ValueError("path value must be provided") + if content: + path += "?" + "content=%s" % content + if fields: + path += "?" + "field=%s" % fields + + accept = None + if output == "xml": + accept = "application/yang-data+xml" + + connection = Connection(module._socket_path) + return connection.send_request( + None, path=path, method="GET", accept=accept + ) + + +def edit_config(module, path=None, content=None, method="GET", format="json"): + if path is None: + raise ValueError("path value must be provided") + + content_type = None + if format == "xml": + content_type = "application/yang-data+xml" + + connection = Connection(module._socket_path) + return connection.send_request( + content, path=path, method=method, content_type=content_type + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py new file mode 100644 index 0000000..c1384c1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py @@ -0,0 +1,444 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: cli_config +author: Trishna Guha (@trishnaguha) +notes: +- The commands will be returned only for platforms that do not support onbox diff. + The C(--diff) option with the playbook will return the difference in configuration + for devices that has support for onbox diff +short_description: Push text based configuration to network devices over network_cli +description: +- This module provides platform agnostic way of pushing text based configuration to + network devices over network_cli connection plugin. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + config: + description: + - The config to be pushed to the network device. This argument is mutually exclusive + with C(rollback) and either one of the option should be given as input. The + config should have indentation that the device uses. + type: str + commit: + description: + - The C(commit) argument instructs the module to push the configuration to the + device. This is mapped to module check mode. + type: bool + replace: + description: + - If the C(replace) argument is set to C(yes), it will replace the entire running-config + of the device with the C(config) argument value. For devices that support replacing + running configuration from file on device like NXOS/JUNOS, the C(replace) argument + takes path to the file on the device that will be used for replacing the entire + running-config. The value of C(config) option should be I(None) for such devices. + Nexus 9K devices only support replace. Use I(net_put) or I(nxos_file_copy) in + case of NXOS module to copy the flat file to remote device and then use set + the fullpath to this argument. + type: str + backup: + description: + - This argument will cause the module to create a full backup of the current running + config from the remote device before any changes are made. If the C(backup_options) + value is not given, the backup file is written to the C(backup) folder in the + playbook root directory or role root directory, if playbook is part of an ansible + role. If the directory does not exist, it is created. + type: bool + default: 'no' + rollback: + description: + - The C(rollback) argument instructs the module to rollback the current configuration + to the identifier specified in the argument. If the specified rollback identifier + does not exist on the remote device, the module will fail. To rollback to the + most recent commit, set the C(rollback) argument to 0. This option is mutually + exclusive with C(config). + commit_comment: + description: + - The C(commit_comment) argument specifies a text string to be used when committing + the configuration. If the C(commit) argument is set to False, this argument + is silently ignored. This argument is only valid for the platforms that support + commit operation with comment. + type: str + defaults: + description: + - The I(defaults) argument will influence how the running-config is collected + from the device. When the value is set to true, the command used to collect + the running-config is append with the all keyword. When the value is set to + false, the command is issued without the all keyword. + default: 'no' + type: bool + multiline_delimiter: + description: + - This argument is used when pushing a multiline configuration element to the + device. It specifies the character to use as the delimiting character. This + only applies to the configuration action. + type: str + diff_replace: + description: + - Instructs the module on the way to perform the configuration on the device. + If the C(diff_replace) argument is set to I(line) then the modified lines are + pushed to the device in configuration mode. If the argument is set to I(block) + then the entire command block is pushed to the device in configuration mode + if any line is not correct. Note that this parameter will be ignored if the + platform has onbox diff support. + choices: + - line + - block + - config + diff_match: + description: + - Instructs the module on the way to perform the matching of the set of commands + against the current device config. If C(diff_match) is set to I(line), commands + are matched line by line. If C(diff_match) is set to I(strict), command lines + are matched with respect to position. If C(diff_match) is set to I(exact), command + lines must be an equal match. Finally, if C(diff_match) is set to I(none), the + module will not attempt to compare the source configuration with the running + configuration on the remote device. Note that this parameter will be ignored + if the platform has onbox diff support. + choices: + - line + - strict + - exact + - none + diff_ignore_lines: + description: + - Use this argument to specify one or more lines that should be ignored during + the diff. This is used for lines in the configuration that are automatically + updated by the system. This argument takes a list of regular expressions or + exact line matches. Note that this parameter will be ignored if the platform + has onbox diff support. + backup_options: + description: + - This is a dict object containing configurable options related to backup file + path. The value of this option is read only when C(backup) is set to I(yes), + if C(backup) is set to I(no) this option will be silently ignored. + suboptions: + filename: + description: + - The filename to be used to store the backup configuration. If the filename + is not given it will be generated based on the hostname, current time and + date in format defined by _config.@ + dir_path: + description: + - This option provides the path ending with directory name in which the backup + configuration file will be stored. If the directory does not exist it will + be first created and the filename is either the value of C(filename) or + default filename as described in C(filename) options description. If the + path value is not given in that case a I(backup) directory will be created + in the current working directory and backup configuration will be copied + in C(filename) within I(backup) directory. + type: path + type: dict +""" + +EXAMPLES = """ +- name: configure device with config + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + +- name: multiline config + cli_config: + config: | + hostname foo + feature nxapi + +- name: configure device with config with defaults enabled + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + defaults: yes + +- name: Use diff_match + cli_config: + config: "{{ lookup('file', 'interface_config') }}" + diff_match: none + +- name: nxos replace config + cli_config: + replace: 'bootflash:nxoscfg' + +- name: junos replace config + cli_config: + replace: '/var/home/ansible/junos01.cfg' + +- name: commit with comment + cli_config: + config: set system host-name foo + commit_comment: this is a test + +- name: configurable backup path + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + backup: yes + backup_options: + filename: backup.cfg + dir_path: /home/user +""" + +RETURN = """ +commands: + description: The set of commands that will be pushed to the remote device + returned: always + type: list + sample: ['interface Loopback999', 'no shutdown'] +backup_path: + description: The full path to the backup file + returned: when backup is yes + type: str + sample: /playbooks/ansible/backup/hostname_config.2016-07-16@22:28:34 +""" + +import json + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.connection import Connection +from ansible.module_utils._text import to_text + + +def validate_args(module, device_operations): + """validate param if it is supported on the platform + """ + feature_list = [ + "replace", + "rollback", + "commit_comment", + "defaults", + "multiline_delimiter", + "diff_replace", + "diff_match", + "diff_ignore_lines", + ] + + for feature in feature_list: + if module.params[feature]: + supports_feature = device_operations.get("supports_%s" % feature) + if supports_feature is None: + module.fail_json( + "This platform does not specify whether %s is supported or not. " + "Please report an issue against this platform's cliconf plugin." + % feature + ) + elif not supports_feature: + module.fail_json( + msg="Option %s is not supported on this platform" % feature + ) + + +def run( + module, device_operations, connection, candidate, running, rollback_id +): + result = {} + resp = {} + config_diff = [] + banner_diff = {} + + replace = module.params["replace"] + commit_comment = module.params["commit_comment"] + multiline_delimiter = module.params["multiline_delimiter"] + diff_replace = module.params["diff_replace"] + diff_match = module.params["diff_match"] + diff_ignore_lines = module.params["diff_ignore_lines"] + + commit = not module.check_mode + + if replace in ("yes", "true", "True"): + replace = True + elif replace in ("no", "false", "False"): + replace = False + + if ( + replace is not None + and replace not in [True, False] + and candidate is not None + ): + module.fail_json( + msg="Replace value '%s' is a configuration file path already" + " present on the device. Hence 'replace' and 'config' options" + " are mutually exclusive" % replace + ) + + if rollback_id is not None: + resp = connection.rollback(rollback_id, commit) + if "diff" in resp: + result["changed"] = True + + elif device_operations.get("supports_onbox_diff"): + if diff_replace: + module.warn( + "diff_replace is ignored as the device supports onbox diff" + ) + if diff_match: + module.warn( + "diff_mattch is ignored as the device supports onbox diff" + ) + if diff_ignore_lines: + module.warn( + "diff_ignore_lines is ignored as the device supports onbox diff" + ) + + if candidate and not isinstance(candidate, list): + candidate = candidate.strip("\n").splitlines() + + kwargs = { + "candidate": candidate, + "commit": commit, + "replace": replace, + "comment": commit_comment, + } + resp = connection.edit_config(**kwargs) + + if "diff" in resp: + result["changed"] = True + + elif device_operations.get("supports_generate_diff"): + kwargs = {"candidate": candidate, "running": running} + if diff_match: + kwargs.update({"diff_match": diff_match}) + if diff_replace: + kwargs.update({"diff_replace": diff_replace}) + if diff_ignore_lines: + kwargs.update({"diff_ignore_lines": diff_ignore_lines}) + + diff_response = connection.get_diff(**kwargs) + + config_diff = diff_response.get("config_diff") + banner_diff = diff_response.get("banner_diff") + + if config_diff: + if isinstance(config_diff, list): + candidate = config_diff + else: + candidate = config_diff.splitlines() + + kwargs = { + "candidate": candidate, + "commit": commit, + "replace": replace, + "comment": commit_comment, + } + if commit: + connection.edit_config(**kwargs) + result["changed"] = True + result["commands"] = config_diff.split("\n") + + if banner_diff: + candidate = json.dumps(banner_diff) + + kwargs = {"candidate": candidate, "commit": commit} + if multiline_delimiter: + kwargs.update({"multiline_delimiter": multiline_delimiter}) + if commit: + connection.edit_banner(**kwargs) + result["changed"] = True + + if module._diff: + if "diff" in resp: + result["diff"] = {"prepared": resp["diff"]} + else: + diff = "" + if config_diff: + if isinstance(config_diff, list): + diff += "\n".join(config_diff) + else: + diff += config_diff + if banner_diff: + diff += json.dumps(banner_diff) + result["diff"] = {"prepared": diff} + + return result + + +def main(): + """main entry point for execution + """ + backup_spec = dict(filename=dict(), dir_path=dict(type="path")) + argument_spec = dict( + backup=dict(default=False, type="bool"), + backup_options=dict(type="dict", options=backup_spec), + config=dict(type="str"), + commit=dict(type="bool"), + replace=dict(type="str"), + rollback=dict(type="int"), + commit_comment=dict(type="str"), + defaults=dict(default=False, type="bool"), + multiline_delimiter=dict(type="str"), + diff_replace=dict(choices=["line", "block", "config"]), + diff_match=dict(choices=["line", "strict", "exact", "none"]), + diff_ignore_lines=dict(type="list"), + ) + + mutually_exclusive = [("config", "rollback")] + required_one_of = [["backup", "config", "rollback"]] + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=mutually_exclusive, + required_one_of=required_one_of, + supports_check_mode=True, + ) + + result = {"changed": False} + + connection = Connection(module._socket_path) + capabilities = module.from_json(connection.get_capabilities()) + + if capabilities: + device_operations = capabilities.get("device_operations", dict()) + validate_args(module, device_operations) + else: + device_operations = dict() + + if module.params["defaults"]: + if "get_default_flag" in capabilities.get("rpc"): + flags = connection.get_default_flag() + else: + flags = "all" + else: + flags = [] + + candidate = module.params["config"] + candidate = ( + to_text(candidate, errors="surrogate_then_replace") + if candidate + else None + ) + running = connection.get_config(flags=flags) + rollback_id = module.params["rollback"] + + if module.params["backup"]: + result["__backup__"] = running + + if candidate or rollback_id or module.params["replace"]: + try: + result.update( + run( + module, + device_operations, + connection, + candidate, + running, + rollback_id, + ) + ) + except Exception as exc: + module.fail_json(msg=to_text(exc)) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py new file mode 100644 index 0000000..f0910f5 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py @@ -0,0 +1,71 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: net_get +author: Deepak Agrawal (@dagrawal) +short_description: Copy a file from a network device to Ansible Controller +description: +- This module provides functionality to copy file from network device to ansible controller. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + src: + description: + - Specifies the source file. The path to the source file can either be the full + path on the network device or a relative path as per path supported by destination + network device. + required: true + protocol: + description: + - Protocol used to transfer file. + default: scp + choices: + - scp + - sftp + dest: + description: + - Specifies the destination file. The path to the destination file can either + be the full path on the Ansible control host or a relative path from the playbook + or role root directory. + default: + - Same filename as specified in I(src). The path will be playbook root or role + root directory if playbook is part of a role. +requirements: +- scp +notes: +- Some devices need specific configurations to be enabled before scp can work These + configuration should be pre-configured before using this module e.g ios - C(ip scp + server enable). +- User privilege to do scp on network device should be pre-configured e.g. ios - need + user privilege 15 by default for allowing scp. +- Default destination of source file. +""" + +EXAMPLES = """ +- name: copy file from the network device to Ansible controller + net_get: + src: running_cfg_ios1.txt + +- name: copy file from ios to common location at /tmp + net_get: + src: running_cfg_sw1.txt + dest : /tmp/ios1.txt +""" + +RETURN = """ +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py new file mode 100644 index 0000000..2fc4a98 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py @@ -0,0 +1,82 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: net_put +author: Deepak Agrawal (@dagrawal) +short_description: Copy a file from Ansible Controller to a network device +description: +- This module provides functionality to copy file from Ansible controller to network + devices. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + src: + description: + - Specifies the source file. The path to the source file can either be the full + path on the Ansible control host or a relative path from the playbook or role + root directory. + required: true + protocol: + description: + - Protocol used to transfer file. + default: scp + choices: + - scp + - sftp + dest: + description: + - Specifies the destination file. The path to destination file can either be the + full path or relative path as supported by network_os. + default: + - Filename from src and at default directory of user shell on network_os. + required: false + mode: + description: + - Set the file transfer mode. If mode is set to I(text) then I(src) file will + go through Jinja2 template engine to replace any vars if present in the src + file. If mode is set to I(binary) then file will be copied as it is to destination + device. + default: binary + choices: + - binary + - text +requirements: +- scp +notes: +- Some devices need specific configurations to be enabled before scp can work These + configuration should be pre-configured before using this module e.g ios - C(ip scp + server enable). +- User privilege to do scp on network device should be pre-configured e.g. ios - need + user privilege 15 by default for allowing scp. +- Default destination of source file. +""" + +EXAMPLES = """ +- name: copy file from ansible controller to a network device + net_put: + src: running_cfg_ios1.txt + +- name: copy file at root dir of flash in slot 3 of sw1(ios) + net_put: + src: running_cfg_sw1.txt + protocol: sftp + dest : flash3:/running_cfg_sw1.txt +""" + +RETURN = """ +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py new file mode 100644 index 0000000..e9332f2 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py @@ -0,0 +1,70 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +netconf: default +short_description: Use default netconf plugin to run standard netconf commands as + per RFC +description: +- This default plugin provides low level abstraction apis for sending and receiving + netconf commands as per Netconf RFC specification. +options: + ncclient_device_handler: + type: str + default: default + description: + - Specifies the ncclient device handler name for network os that support default + netconf implementation as per Netconf RFC specification. To identify the ncclient + device handler name refer ncclient library documentation. +""" +import json + +from ansible.module_utils._text import to_text +from ansible.plugins.netconf import NetconfBase + + +class Netconf(NetconfBase): + def get_text(self, ele, tag): + try: + return to_text( + ele.find(tag).text, errors="surrogate_then_replace" + ).strip() + except AttributeError: + pass + + def get_device_info(self): + device_info = dict() + device_info["network_os"] = "default" + return device_info + + def get_capabilities(self): + result = dict() + result["rpc"] = self.get_base_rpc() + result["network_api"] = "netconf" + result["device_info"] = self.get_device_info() + result["server_capabilities"] = [c for c in self.m.server_capabilities] + result["client_capabilities"] = [c for c in self.m.client_capabilities] + result["session_id"] = self.m.session_id + result["device_operations"] = self.get_device_operations( + result["server_capabilities"] + ) + return json.dumps(result) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py new file mode 100644 index 0000000..a38a775 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py @@ -0,0 +1,185 @@ +# (c) 2012-2014, Michael DeHaan +# (c) 2015 Toshio Kuratomi +# (c) 2017, Peter Sprygada +# (c) 2017 Ansible Project +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os + +from ansible import constants as C +from ansible.plugins.connection import ConnectionBase +from ansible.plugins.loader import connection_loader +from ansible.utils.display import Display +from ansible.utils.path import unfrackpath + +display = Display() + + +__all__ = ["NetworkConnectionBase"] + +BUFSIZE = 65536 + + +class NetworkConnectionBase(ConnectionBase): + """ + A base class for network-style connections. + """ + + force_persistence = True + # Do not use _remote_is_local in other connections + _remote_is_local = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(NetworkConnectionBase, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._messages = [] + self._conn_closed = False + + self._network_os = self._play_context.network_os + + self._local = connection_loader.get("local", play_context, "/dev/null") + self._local.set_options() + + self._sub_plugin = {} + self._cached_variables = (None, None, None) + + # reconstruct the socket_path and set instance values accordingly + self._ansible_playbook_pid = kwargs.get("ansible_playbook_pid") + self._update_connection_state() + + def __getattr__(self, name): + try: + return self.__dict__[name] + except KeyError: + if not name.startswith("_"): + plugin = self._sub_plugin.get("obj") + if plugin: + method = getattr(plugin, name, None) + if method is not None: + return method + raise AttributeError( + "'%s' object has no attribute '%s'" + % (self.__class__.__name__, name) + ) + + def exec_command(self, cmd, in_data=None, sudoable=True): + return self._local.exec_command(cmd, in_data, sudoable) + + def queue_message(self, level, message): + """ + Adds a message to the queue of messages waiting to be pushed back to the controller process. + + :arg level: A string which can either be the name of a method in display, or 'log'. When + the messages are returned to task_executor, a value of log will correspond to + ``display.display(message, log_only=True)``, while another value will call ``display.[level](message)`` + """ + self._messages.append((level, message)) + + def pop_messages(self): + messages, self._messages = self._messages, [] + return messages + + def put_file(self, in_path, out_path): + """Transfer a file from local to remote""" + return self._local.put_file(in_path, out_path) + + def fetch_file(self, in_path, out_path): + """Fetch a file from remote to local""" + return self._local.fetch_file(in_path, out_path) + + def reset(self): + """ + Reset the connection + """ + if self._socket_path: + self.queue_message( + "vvvv", + "resetting persistent connection for socket_path %s" + % self._socket_path, + ) + self.close() + self.queue_message("vvvv", "reset call on connection instance") + + def close(self): + self._conn_closed = True + if self._connected: + self._connected = False + + def get_options(self, hostvars=None): + options = super(NetworkConnectionBase, self).get_options( + hostvars=hostvars + ) + + if ( + self._sub_plugin.get("obj") + and self._sub_plugin.get("type") != "external" + ): + try: + options.update( + self._sub_plugin["obj"].get_options(hostvars=hostvars) + ) + except AttributeError: + pass + + return options + + def set_options(self, task_keys=None, var_options=None, direct=None): + super(NetworkConnectionBase, self).set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + if self.get_option("persistent_log_messages"): + warning = ( + "Persistent connection logging is enabled for %s. This will log ALL interactions" + % self._play_context.remote_addr + ) + logpath = getattr(C, "DEFAULT_LOG_PATH") + if logpath is not None: + warning += " to %s" % logpath + self.queue_message( + "warning", + "%s and WILL NOT redact sensitive configuration like passwords. USE WITH CAUTION!" + % warning, + ) + + if ( + self._sub_plugin.get("obj") + and self._sub_plugin.get("type") != "external" + ): + try: + self._sub_plugin["obj"].set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + except AttributeError: + pass + + def _update_connection_state(self): + """ + Reconstruct the connection socket_path and check if it exists + + If the socket path exists then the connection is active and set + both the _socket_path value to the path and the _connected value + to True. If the socket path doesn't exist, leave the socket path + value to None and the _connected value to False + """ + ssh = connection_loader.get("ssh", class_only=True) + control_path = ssh._create_control_path( + self._play_context.remote_addr, + self._play_context.port, + self._play_context.remote_user, + self._play_context.connection, + self._ansible_playbook_pid, + ) + + tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) + socket_path = unfrackpath(control_path % dict(directory=tmp_path)) + + if os.path.exists(socket_path): + self._connected = True + self._socket_path = socket_path + + def _log_messages(self, message): + if self.get_option("persistent_log_messages"): + self.queue_message("log", message) -- cgit v1.2.3