diff options
Diffstat (limited to 'lib/ansible/module_utils')
136 files changed, 27414 insertions, 0 deletions
diff --git a/lib/ansible/module_utils/__init__.py b/lib/ansible/module_utils/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/__init__.py diff --git a/lib/ansible/module_utils/_text.py b/lib/ansible/module_utils/_text.py new file mode 100644 index 0000000..6cd7721 --- /dev/null +++ b/lib/ansible/module_utils/_text.py @@ -0,0 +1,15 @@ +# Copyright (c), Toshio Kuratomi <tkuratomi@ansible.com> 2016 +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +""" +.. warn:: Use ansible.module_utils.common.text.converters instead. +""" + +# Backwards compat for people still calling it from this package +import codecs + +from ansible.module_utils.six import PY3, text_type, binary_type + +from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text diff --git a/lib/ansible/module_utils/ansible_release.py b/lib/ansible/module_utils/ansible_release.py new file mode 100644 index 0000000..66a04b9 --- /dev/null +++ b/lib/ansible/module_utils/ansible_release.py @@ -0,0 +1,24 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +__version__ = '2.14.3' +__author__ = 'Ansible, Inc.' +__codename__ = "C'mon Everybody" diff --git a/lib/ansible/module_utils/api.py b/lib/ansible/module_utils/api.py new file mode 100644 index 0000000..e780ec6 --- /dev/null +++ b/lib/ansible/module_utils/api.py @@ -0,0 +1,166 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright: (c) 2015, Brian Coca, <bcoca@ansible.com> +# +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) +""" +This module adds shared support for generic api modules + +In order to use this module, include it as part of a custom +module as shown below. + +The 'api' module provides the following common argument specs: + + * rate limit spec + - rate: number of requests per time unit (int) + - rate_limit: time window in which the limit is applied in seconds + + * retry spec + - retries: number of attempts + - retry_pause: delay between attempts in seconds +""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import functools +import random +import sys +import time + + +def rate_limit_argument_spec(spec=None): + """Creates an argument spec for working with rate limiting""" + arg_spec = (dict( + rate=dict(type='int'), + rate_limit=dict(type='int'), + )) + if spec: + arg_spec.update(spec) + return arg_spec + + +def retry_argument_spec(spec=None): + """Creates an argument spec for working with retrying""" + arg_spec = (dict( + retries=dict(type='int'), + retry_pause=dict(type='float', default=1), + )) + if spec: + arg_spec.update(spec) + return arg_spec + + +def basic_auth_argument_spec(spec=None): + arg_spec = (dict( + api_username=dict(type='str'), + api_password=dict(type='str', no_log=True), + api_url=dict(type='str'), + validate_certs=dict(type='bool', default=True) + )) + if spec: + arg_spec.update(spec) + return arg_spec + + +def rate_limit(rate=None, rate_limit=None): + """rate limiting decorator""" + minrate = None + if rate is not None and rate_limit is not None: + minrate = float(rate_limit) / float(rate) + + def wrapper(f): + last = [0.0] + + def ratelimited(*args, **kwargs): + if sys.version_info >= (3, 8): + real_time = time.process_time + else: + real_time = time.clock + if minrate is not None: + elapsed = real_time() - last[0] + left = minrate - elapsed + if left > 0: + time.sleep(left) + last[0] = real_time() + ret = f(*args, **kwargs) + return ret + + return ratelimited + return wrapper + + +def retry(retries=None, retry_pause=1): + """Retry decorator""" + def wrapper(f): + + def retried(*args, **kwargs): + retry_count = 0 + if retries is not None: + ret = None + while True: + retry_count += 1 + if retry_count >= retries: + raise Exception("Retry limit exceeded: %d" % retries) + try: + ret = f(*args, **kwargs) + except Exception: + pass + if ret: + break + time.sleep(retry_pause) + return ret + + return retried + return wrapper + + +def generate_jittered_backoff(retries=10, delay_base=3, delay_threshold=60): + """The "Full Jitter" backoff strategy. + + Ref: https://www.awsarchitectureblog.com/2015/03/backoff.html + + :param retries: The number of delays to generate. + :param delay_base: The base time in seconds used to calculate the exponential backoff. + :param delay_threshold: The maximum time in seconds for any delay. + """ + for retry in range(0, retries): + yield random.randint(0, min(delay_threshold, delay_base * 2 ** retry)) + + +def retry_never(exception_or_result): + return False + + +def retry_with_delays_and_condition(backoff_iterator, should_retry_error=None): + """Generic retry decorator. + + :param backoff_iterator: An iterable of delays in seconds. + :param should_retry_error: A callable that takes an exception of the decorated function and decides whether to retry or not (returns a bool). + """ + if should_retry_error is None: + should_retry_error = retry_never + + def function_wrapper(function): + @functools.wraps(function) + def run_function(*args, **kwargs): + """This assumes the function has not already been called. + If backoff_iterator is empty, we should still run the function a single time with no delay. + """ + call_retryable_function = functools.partial(function, *args, **kwargs) + + for delay in backoff_iterator: + try: + return call_retryable_function() + except Exception as e: + if not should_retry_error(e): + raise + time.sleep(delay) + + # Only or final attempt + return call_retryable_function() + return run_function + return function_wrapper diff --git a/lib/ansible/module_utils/basic.py b/lib/ansible/module_utils/basic.py new file mode 100644 index 0000000..67be924 --- /dev/null +++ b/lib/ansible/module_utils/basic.py @@ -0,0 +1,2148 @@ +# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013 +# Copyright (c), Toshio Kuratomi <tkuratomi@ansible.com> 2016 +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +FILE_ATTRIBUTES = { + 'A': 'noatime', + 'a': 'append', + 'c': 'compressed', + 'C': 'nocow', + 'd': 'nodump', + 'D': 'dirsync', + 'e': 'extents', + 'E': 'encrypted', + 'h': 'blocksize', + 'i': 'immutable', + 'I': 'indexed', + 'j': 'journalled', + 'N': 'inline', + 's': 'zero', + 'S': 'synchronous', + 't': 'notail', + 'T': 'blockroot', + 'u': 'undelete', + 'X': 'compressedraw', + 'Z': 'compresseddirty', +} + +# Ansible modules can be written in any language. +# The functions available here can be used to do many common tasks, +# to simplify development of Python modules. + +import __main__ +import atexit +import errno +import datetime +import grp +import fcntl +import locale +import os +import pwd +import platform +import re +import select +import shlex +import shutil +import signal +import stat +import subprocess +import sys +import tempfile +import time +import traceback +import types + +from itertools import chain, repeat + +try: + import syslog + HAS_SYSLOG = True +except ImportError: + HAS_SYSLOG = False + +try: + from systemd import journal, daemon as systemd_daemon + # Makes sure that systemd.journal has method sendv() + # Double check that journal has method sendv (some packages don't) + # check if the system is running under systemd + has_journal = hasattr(journal, 'sendv') and systemd_daemon.booted() +except (ImportError, AttributeError): + # AttributeError would be caused from use of .booted() if wrong systemd + has_journal = False + +HAVE_SELINUX = False +try: + from ansible.module_utils.compat import selinux + HAVE_SELINUX = True +except ImportError: + pass + +# Python2 & 3 way to get NoneType +NoneType = type(None) + +from ansible.module_utils.compat import selectors + +from ._text import to_native, to_bytes, to_text +from ansible.module_utils.common.text.converters import ( + jsonify, + container_to_bytes as json_dict_unicode_to_bytes, + container_to_text as json_dict_bytes_to_unicode, +) + +from ansible.module_utils.common.arg_spec import ModuleArgumentSpecValidator + +from ansible.module_utils.common.text.formatters import ( + lenient_lowercase, + bytes_to_human, + human_to_bytes, + SIZE_RANGES, +) + +try: + from ansible.module_utils.common._json_compat import json +except ImportError as e: + print('\n{{"msg": "Error: ansible requires the stdlib json: {0}", "failed": true}}'.format(to_native(e))) + sys.exit(1) + + +AVAILABLE_HASH_ALGORITHMS = dict() +try: + import hashlib + + # python 2.7.9+ and 2.7.0+ + for attribute in ('available_algorithms', 'algorithms'): + algorithms = getattr(hashlib, attribute, None) + if algorithms: + break + if algorithms is None: + # python 2.5+ + algorithms = ('md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512') + for algorithm in algorithms: + AVAILABLE_HASH_ALGORITHMS[algorithm] = getattr(hashlib, algorithm) + + # we may have been able to import md5 but it could still not be available + try: + hashlib.md5() + except ValueError: + AVAILABLE_HASH_ALGORITHMS.pop('md5', None) +except Exception: + import sha + AVAILABLE_HASH_ALGORITHMS = {'sha1': sha.sha} + try: + import md5 + AVAILABLE_HASH_ALGORITHMS['md5'] = md5.md5 + except Exception: + pass + +from ansible.module_utils.common._collections_compat import ( + KeysView, + Mapping, MutableMapping, + Sequence, MutableSequence, + Set, MutableSet, +) +from ansible.module_utils.common.locale import get_best_parsable_locale +from ansible.module_utils.common.process import get_bin_path +from ansible.module_utils.common.file import ( + _PERM_BITS as PERM_BITS, + _EXEC_PERM_BITS as EXEC_PERM_BITS, + _DEFAULT_PERM as DEFAULT_PERM, + is_executable, + format_attributes, + get_flags_from_attributes, +) +from ansible.module_utils.common.sys_info import ( + get_distribution, + get_distribution_version, + get_platform_subclass, +) +from ansible.module_utils.pycompat24 import get_exception, literal_eval +from ansible.module_utils.common.parameters import ( + env_fallback, + remove_values, + sanitize_keys, + DEFAULT_TYPE_VALIDATORS, + PASS_VARS, + PASS_BOOLS, +) + +from ansible.module_utils.errors import AnsibleFallbackNotFound, AnsibleValidationErrorMultiple, UnsupportedError +from ansible.module_utils.six import ( + PY2, + PY3, + b, + binary_type, + integer_types, + iteritems, + string_types, + text_type, +) +from ansible.module_utils.six.moves import map, reduce, shlex_quote +from ansible.module_utils.common.validation import ( + check_missing_parameters, + safe_eval, +) +from ansible.module_utils.common._utils import get_all_subclasses as _get_all_subclasses +from ansible.module_utils.parsing.convert_bool import BOOLEANS, BOOLEANS_FALSE, BOOLEANS_TRUE, boolean +from ansible.module_utils.common.warnings import ( + deprecate, + get_deprecation_messages, + get_warning_messages, + warn, +) + +# Note: When getting Sequence from collections, it matches with strings. If +# this matters, make sure to check for strings before checking for sequencetype +SEQUENCETYPE = frozenset, KeysView, Sequence + +PASSWORD_MATCH = re.compile(r'^(?:.+[-_\s])?pass(?:[-_\s]?(?:word|phrase|wrd|wd)?)(?:[-_\s].+)?$', re.I) + +imap = map + +try: + # Python 2 + unicode # type: ignore[has-type] # pylint: disable=used-before-assignment +except NameError: + # Python 3 + unicode = text_type + +try: + # Python 2 + basestring # type: ignore[has-type] # pylint: disable=used-before-assignment +except NameError: + # Python 3 + basestring = string_types + +_literal_eval = literal_eval + +# End of deprecated names + +# Internal global holding passed in params. This is consulted in case +# multiple AnsibleModules are created. Otherwise each AnsibleModule would +# attempt to read from stdin. Other code should not use this directly as it +# is an internal implementation detail +_ANSIBLE_ARGS = None + + +FILE_COMMON_ARGUMENTS = dict( + # These are things we want. About setting metadata (mode, ownership, permissions in general) on + # created files (these are used by set_fs_attributes_if_different and included in + # load_file_common_arguments) + mode=dict(type='raw'), + owner=dict(type='str'), + group=dict(type='str'), + seuser=dict(type='str'), + serole=dict(type='str'), + selevel=dict(type='str'), + setype=dict(type='str'), + attributes=dict(type='str', aliases=['attr']), + unsafe_writes=dict(type='bool', default=False, fallback=(env_fallback, ['ANSIBLE_UNSAFE_WRITES'])), # should be available to any module using atomic_move +) + +PASSWD_ARG_RE = re.compile(r'^[-]{0,2}pass[-]?(word|wd)?') + +# Used for parsing symbolic file perms +MODE_OPERATOR_RE = re.compile(r'[+=-]') +USERS_RE = re.compile(r'[^ugo]') +PERMS_RE = re.compile(r'[^rwxXstugo]') + +# Used for determining if the system is running a new enough python version +# and should only restrict on our documented minimum versions +_PY3_MIN = sys.version_info >= (3, 5) +_PY2_MIN = (2, 7) <= sys.version_info < (3,) +_PY_MIN = _PY3_MIN or _PY2_MIN +if not _PY_MIN: + print( + '\n{"failed": true, ' + '"msg": "ansible-core requires a minimum of Python2 version 2.7 or Python3 version 3.5. Current version: %s"}' % ''.join(sys.version.splitlines()) + ) + sys.exit(1) + + +# +# Deprecated functions +# + +def get_platform(): + ''' + **Deprecated** Use :py:func:`platform.system` directly. + + :returns: Name of the platform the module is running on in a native string + + Returns a native string that labels the platform ("Linux", "Solaris", etc). Currently, this is + the result of calling :py:func:`platform.system`. + ''' + return platform.system() + +# End deprecated functions + + +# +# Compat shims +# + +def load_platform_subclass(cls, *args, **kwargs): + """**Deprecated**: Use ansible.module_utils.common.sys_info.get_platform_subclass instead""" + platform_cls = get_platform_subclass(cls) + return super(cls, platform_cls).__new__(platform_cls) + + +def get_all_subclasses(cls): + """**Deprecated**: Use ansible.module_utils.common._utils.get_all_subclasses instead""" + return list(_get_all_subclasses(cls)) + + +# End compat shims + + +def heuristic_log_sanitize(data, no_log_values=None): + ''' Remove strings that look like passwords from log messages ''' + # Currently filters: + # user:pass@foo/whatever and http://username:pass@wherever/foo + # This code has false positives and consumes parts of logs that are + # not passwds + + # begin: start of a passwd containing string + # end: end of a passwd containing string + # sep: char between user and passwd + # prev_begin: where in the overall string to start a search for + # a passwd + # sep_search_end: where in the string to end a search for the sep + data = to_native(data) + + output = [] + begin = len(data) + prev_begin = begin + sep = 1 + while sep: + # Find the potential end of a passwd + try: + end = data.rindex('@', 0, begin) + except ValueError: + # No passwd in the rest of the data + output.insert(0, data[0:begin]) + break + + # Search for the beginning of a passwd + sep = None + sep_search_end = end + while not sep: + # URL-style username+password + try: + begin = data.rindex('://', 0, sep_search_end) + except ValueError: + # No url style in the data, check for ssh style in the + # rest of the string + begin = 0 + # Search for separator + try: + sep = data.index(':', begin + 3, end) + except ValueError: + # No separator; choices: + if begin == 0: + # Searched the whole string so there's no password + # here. Return the remaining data + output.insert(0, data[0:prev_begin]) + break + # Search for a different beginning of the password field. + sep_search_end = begin + continue + if sep: + # Password was found; remove it. + output.insert(0, data[end:prev_begin]) + output.insert(0, '********') + output.insert(0, data[begin:sep + 1]) + prev_begin = begin + + output = ''.join(output) + if no_log_values: + output = remove_values(output, no_log_values) + return output + + +def _load_params(): + ''' read the modules parameters and store them globally. + + This function may be needed for certain very dynamic custom modules which + want to process the parameters that are being handed the module. Since + this is so closely tied to the implementation of modules we cannot + guarantee API stability for it (it may change between versions) however we + will try not to break it gratuitously. It is certainly more future-proof + to call this function and consume its outputs than to implement the logic + inside it as a copy in your own code. + ''' + global _ANSIBLE_ARGS + if _ANSIBLE_ARGS is not None: + buffer = _ANSIBLE_ARGS + else: + # debug overrides to read args from file or cmdline + + # Avoid tracebacks when locale is non-utf8 + # We control the args and we pass them as utf8 + if len(sys.argv) > 1: + if os.path.isfile(sys.argv[1]): + fd = open(sys.argv[1], 'rb') + buffer = fd.read() + fd.close() + else: + buffer = sys.argv[1] + if PY3: + buffer = buffer.encode('utf-8', errors='surrogateescape') + # default case, read from stdin + else: + if PY2: + buffer = sys.stdin.read() + else: + buffer = sys.stdin.buffer.read() + _ANSIBLE_ARGS = buffer + + try: + params = json.loads(buffer.decode('utf-8')) + except ValueError: + # This helper used too early for fail_json to work. + print('\n{"msg": "Error: Module unable to decode valid JSON on stdin. Unable to figure out what parameters were passed", "failed": true}') + sys.exit(1) + + if PY2: + params = json_dict_unicode_to_bytes(params) + + try: + return params['ANSIBLE_MODULE_ARGS'] + except KeyError: + # This helper does not have access to fail_json so we have to print + # json output on our own. + print('\n{"msg": "Error: Module unable to locate ANSIBLE_MODULE_ARGS in json data from stdin. Unable to figure out what parameters were passed", ' + '"failed": true}') + sys.exit(1) + + +def missing_required_lib(library, reason=None, url=None): + hostname = platform.node() + msg = "Failed to import the required Python library (%s) on %s's Python %s." % (library, hostname, sys.executable) + if reason: + msg += " This is required %s." % reason + if url: + msg += " See %s for more info." % url + + msg += (" Please read the module documentation and install it in the appropriate location." + " If the required library is installed, but Ansible is using the wrong Python interpreter," + " please consult the documentation on ansible_python_interpreter") + return msg + + +class AnsibleModule(object): + def __init__(self, argument_spec, bypass_checks=False, no_log=False, + mutually_exclusive=None, required_together=None, + required_one_of=None, add_file_common_args=False, + supports_check_mode=False, required_if=None, required_by=None): + + ''' + Common code for quickly building an ansible module in Python + (although you can write modules with anything that can return JSON). + + See :ref:`developing_modules_general` for a general introduction + and :ref:`developing_program_flow_modules` for more detailed explanation. + ''' + + self._name = os.path.basename(__file__) # initialize name until we can parse from options + self.argument_spec = argument_spec + self.supports_check_mode = supports_check_mode + self.check_mode = False + self.bypass_checks = bypass_checks + self.no_log = no_log + + self.mutually_exclusive = mutually_exclusive + self.required_together = required_together + self.required_one_of = required_one_of + self.required_if = required_if + self.required_by = required_by + self.cleanup_files = [] + self._debug = False + self._diff = False + self._socket_path = None + self._shell = None + self._syslog_facility = 'LOG_USER' + self._verbosity = 0 + # May be used to set modifications to the environment for any + # run_command invocation + self.run_command_environ_update = {} + self._clean = {} + self._string_conversion_action = '' + + self.aliases = {} + self._legal_inputs = [] + self._options_context = list() + self._tmpdir = None + + if add_file_common_args: + for k, v in FILE_COMMON_ARGUMENTS.items(): + if k not in self.argument_spec: + self.argument_spec[k] = v + + # Save parameter values that should never be logged + self.no_log_values = set() + + # check the locale as set by the current environment, and reset to + # a known valid (LANG=C) if it's an invalid/unavailable locale + self._check_locale() + + self._load_params() + self._set_internal_properties() + + self.validator = ModuleArgumentSpecValidator(self.argument_spec, + self.mutually_exclusive, + self.required_together, + self.required_one_of, + self.required_if, + self.required_by, + ) + + self.validation_result = self.validator.validate(self.params) + self.params.update(self.validation_result.validated_parameters) + self.no_log_values.update(self.validation_result._no_log_values) + self.aliases.update(self.validation_result._aliases) + + try: + error = self.validation_result.errors[0] + except IndexError: + error = None + + # Fail for validation errors, even in check mode + if error: + msg = self.validation_result.errors.msg + if isinstance(error, UnsupportedError): + msg = "Unsupported parameters for ({name}) {kind}: {msg}".format(name=self._name, kind='module', msg=msg) + + self.fail_json(msg=msg) + + if self.check_mode and not self.supports_check_mode: + self.exit_json(skipped=True, msg="remote module (%s) does not support check mode" % self._name) + + # This is for backwards compatibility only. + self._CHECK_ARGUMENT_TYPES_DISPATCHER = DEFAULT_TYPE_VALIDATORS + + if not self.no_log: + self._log_invocation() + + # selinux state caching + self._selinux_enabled = None + self._selinux_mls_enabled = None + self._selinux_initial_context = None + + # finally, make sure we're in a sane working dir + self._set_cwd() + + @property + def tmpdir(self): + # if _ansible_tmpdir was not set and we have a remote_tmp, + # the module needs to create it and clean it up once finished. + # otherwise we create our own module tmp dir from the system defaults + if self._tmpdir is None: + basedir = None + + if self._remote_tmp is not None: + basedir = os.path.expanduser(os.path.expandvars(self._remote_tmp)) + + if basedir is not None and not os.path.exists(basedir): + try: + os.makedirs(basedir, mode=0o700) + except (OSError, IOError) as e: + self.warn("Unable to use %s as temporary directory, " + "failing back to system: %s" % (basedir, to_native(e))) + basedir = None + else: + self.warn("Module remote_tmp %s did not exist and was " + "created with a mode of 0700, this may cause" + " issues when running as another user. To " + "avoid this, create the remote_tmp dir with " + "the correct permissions manually" % basedir) + + basefile = "ansible-moduletmp-%s-" % time.time() + try: + tmpdir = tempfile.mkdtemp(prefix=basefile, dir=basedir) + except (OSError, IOError) as e: + self.fail_json( + msg="Failed to create remote module tmp path at dir %s " + "with prefix %s: %s" % (basedir, basefile, to_native(e)) + ) + if not self._keep_remote_files: + atexit.register(shutil.rmtree, tmpdir) + self._tmpdir = tmpdir + + return self._tmpdir + + def warn(self, warning): + warn(warning) + self.log('[WARNING] %s' % warning) + + def deprecate(self, msg, version=None, date=None, collection_name=None): + if version is not None and date is not None: + raise AssertionError("implementation error -- version and date must not both be set") + deprecate(msg, version=version, date=date, collection_name=collection_name) + # For compatibility, we accept that neither version nor date is set, + # and treat that the same as if version would haven been set + if date is not None: + self.log('[DEPRECATION WARNING] %s %s' % (msg, date)) + else: + self.log('[DEPRECATION WARNING] %s %s' % (msg, version)) + + def load_file_common_arguments(self, params, path=None): + ''' + many modules deal with files, this encapsulates common + options that the file module accepts such that it is directly + available to all modules and they can share code. + + Allows to overwrite the path/dest module argument by providing path. + ''' + + if path is None: + path = params.get('path', params.get('dest', None)) + if path is None: + return {} + else: + path = os.path.expanduser(os.path.expandvars(path)) + + b_path = to_bytes(path, errors='surrogate_or_strict') + # if the path is a symlink, and we're following links, get + # the target of the link instead for testing + if params.get('follow', False) and os.path.islink(b_path): + b_path = os.path.realpath(b_path) + path = to_native(b_path) + + mode = params.get('mode', None) + owner = params.get('owner', None) + group = params.get('group', None) + + # selinux related options + seuser = params.get('seuser', None) + serole = params.get('serole', None) + setype = params.get('setype', None) + selevel = params.get('selevel', None) + secontext = [seuser, serole, setype] + + if self.selinux_mls_enabled(): + secontext.append(selevel) + + default_secontext = self.selinux_default_context(path) + for i in range(len(default_secontext)): + if i is not None and secontext[i] == '_default': + secontext[i] = default_secontext[i] + + attributes = params.get('attributes', None) + return dict( + path=path, mode=mode, owner=owner, group=group, + seuser=seuser, serole=serole, setype=setype, + selevel=selevel, secontext=secontext, attributes=attributes, + ) + + # Detect whether using selinux that is MLS-aware. + # While this means you can set the level/range with + # selinux.lsetfilecon(), it may or may not mean that you + # will get the selevel as part of the context returned + # by selinux.lgetfilecon(). + + def selinux_mls_enabled(self): + if self._selinux_mls_enabled is None: + self._selinux_mls_enabled = HAVE_SELINUX and selinux.is_selinux_mls_enabled() == 1 + + return self._selinux_mls_enabled + + def selinux_enabled(self): + if self._selinux_enabled is None: + self._selinux_enabled = HAVE_SELINUX and selinux.is_selinux_enabled() == 1 + + return self._selinux_enabled + + # Determine whether we need a placeholder for selevel/mls + def selinux_initial_context(self): + if self._selinux_initial_context is None: + self._selinux_initial_context = [None, None, None] + if self.selinux_mls_enabled(): + self._selinux_initial_context.append(None) + + return self._selinux_initial_context + + # If selinux fails to find a default, return an array of None + def selinux_default_context(self, path, mode=0): + context = self.selinux_initial_context() + if not self.selinux_enabled(): + return context + try: + ret = selinux.matchpathcon(to_native(path, errors='surrogate_or_strict'), mode) + except OSError: + return context + if ret[0] == -1: + return context + # Limit split to 4 because the selevel, the last in the list, + # may contain ':' characters + context = ret[1].split(':', 3) + return context + + def selinux_context(self, path): + context = self.selinux_initial_context() + if not self.selinux_enabled(): + return context + try: + ret = selinux.lgetfilecon_raw(to_native(path, errors='surrogate_or_strict')) + except OSError as e: + if e.errno == errno.ENOENT: + self.fail_json(path=path, msg='path %s does not exist' % path) + else: + self.fail_json(path=path, msg='failed to retrieve selinux context') + if ret[0] == -1: + return context + # Limit split to 4 because the selevel, the last in the list, + # may contain ':' characters + context = ret[1].split(':', 3) + return context + + def user_and_group(self, path, expand=True): + b_path = to_bytes(path, errors='surrogate_or_strict') + if expand: + b_path = os.path.expanduser(os.path.expandvars(b_path)) + st = os.lstat(b_path) + uid = st.st_uid + gid = st.st_gid + return (uid, gid) + + def find_mount_point(self, path): + ''' + Takes a path and returns it's mount point + + :param path: a string type with a filesystem path + :returns: the path to the mount point as a text type + ''' + + b_path = os.path.realpath(to_bytes(os.path.expanduser(os.path.expandvars(path)), errors='surrogate_or_strict')) + while not os.path.ismount(b_path): + b_path = os.path.dirname(b_path) + + return to_text(b_path, errors='surrogate_or_strict') + + def is_special_selinux_path(self, path): + """ + Returns a tuple containing (True, selinux_context) if the given path is on a + NFS or other 'special' fs mount point, otherwise the return will be (False, None). + """ + try: + f = open('/proc/mounts', 'r') + mount_data = f.readlines() + f.close() + except Exception: + return (False, None) + + path_mount_point = self.find_mount_point(path) + + for line in mount_data: + (device, mount_point, fstype, options, rest) = line.split(' ', 4) + if to_bytes(path_mount_point) == to_bytes(mount_point): + for fs in self._selinux_special_fs: + if fs in fstype: + special_context = self.selinux_context(path_mount_point) + return (True, special_context) + + return (False, None) + + def set_default_selinux_context(self, path, changed): + if not self.selinux_enabled(): + return changed + context = self.selinux_default_context(path) + return self.set_context_if_different(path, context, False) + + def set_context_if_different(self, path, context, changed, diff=None): + + if not self.selinux_enabled(): + return changed + + if self.check_file_absent_if_check_mode(path): + return True + + cur_context = self.selinux_context(path) + new_context = list(cur_context) + # Iterate over the current context instead of the + # argument context, which may have selevel. + + (is_special_se, sp_context) = self.is_special_selinux_path(path) + if is_special_se: + new_context = sp_context + else: + for i in range(len(cur_context)): + if len(context) > i: + if context[i] is not None and context[i] != cur_context[i]: + new_context[i] = context[i] + elif context[i] is None: + new_context[i] = cur_context[i] + + if cur_context != new_context: + if diff is not None: + if 'before' not in diff: + diff['before'] = {} + diff['before']['secontext'] = cur_context + if 'after' not in diff: + diff['after'] = {} + diff['after']['secontext'] = new_context + + try: + if self.check_mode: + return True + rc = selinux.lsetfilecon(to_native(path), ':'.join(new_context)) + except OSError as e: + self.fail_json(path=path, msg='invalid selinux context: %s' % to_native(e), + new_context=new_context, cur_context=cur_context, input_was=context) + if rc != 0: + self.fail_json(path=path, msg='set selinux context failed') + changed = True + return changed + + def set_owner_if_different(self, path, owner, changed, diff=None, expand=True): + + if owner is None: + return changed + + b_path = to_bytes(path, errors='surrogate_or_strict') + if expand: + b_path = os.path.expanduser(os.path.expandvars(b_path)) + + if self.check_file_absent_if_check_mode(b_path): + return True + + orig_uid, orig_gid = self.user_and_group(b_path, expand) + try: + uid = int(owner) + except ValueError: + try: + uid = pwd.getpwnam(owner).pw_uid + except KeyError: + path = to_text(b_path) + self.fail_json(path=path, msg='chown failed: failed to look up user %s' % owner) + + if orig_uid != uid: + if diff is not None: + if 'before' not in diff: + diff['before'] = {} + diff['before']['owner'] = orig_uid + if 'after' not in diff: + diff['after'] = {} + diff['after']['owner'] = uid + + if self.check_mode: + return True + try: + os.lchown(b_path, uid, -1) + except (IOError, OSError) as e: + path = to_text(b_path) + self.fail_json(path=path, msg='chown failed: %s' % (to_text(e))) + changed = True + return changed + + def set_group_if_different(self, path, group, changed, diff=None, expand=True): + + if group is None: + return changed + + b_path = to_bytes(path, errors='surrogate_or_strict') + if expand: + b_path = os.path.expanduser(os.path.expandvars(b_path)) + + if self.check_file_absent_if_check_mode(b_path): + return True + + orig_uid, orig_gid = self.user_and_group(b_path, expand) + try: + gid = int(group) + except ValueError: + try: + gid = grp.getgrnam(group).gr_gid + except KeyError: + path = to_text(b_path) + self.fail_json(path=path, msg='chgrp failed: failed to look up group %s' % group) + + if orig_gid != gid: + if diff is not None: + if 'before' not in diff: + diff['before'] = {} + diff['before']['group'] = orig_gid + if 'after' not in diff: + diff['after'] = {} + diff['after']['group'] = gid + + if self.check_mode: + return True + try: + os.lchown(b_path, -1, gid) + except OSError: + path = to_text(b_path) + self.fail_json(path=path, msg='chgrp failed') + changed = True + return changed + + def set_mode_if_different(self, path, mode, changed, diff=None, expand=True): + + if mode is None: + return changed + + b_path = to_bytes(path, errors='surrogate_or_strict') + if expand: + b_path = os.path.expanduser(os.path.expandvars(b_path)) + + if self.check_file_absent_if_check_mode(b_path): + return True + + path_stat = os.lstat(b_path) + + if not isinstance(mode, int): + try: + mode = int(mode, 8) + except Exception: + try: + mode = self._symbolic_mode_to_octal(path_stat, mode) + except Exception as e: + path = to_text(b_path) + self.fail_json(path=path, + msg="mode must be in octal or symbolic form", + details=to_native(e)) + + if mode != stat.S_IMODE(mode): + # prevent mode from having extra info orbeing invalid long number + path = to_text(b_path) + self.fail_json(path=path, msg="Invalid mode supplied, only permission info is allowed", details=mode) + + prev_mode = stat.S_IMODE(path_stat.st_mode) + + if prev_mode != mode: + + if diff is not None: + if 'before' not in diff: + diff['before'] = {} + diff['before']['mode'] = '0%03o' % prev_mode + if 'after' not in diff: + diff['after'] = {} + diff['after']['mode'] = '0%03o' % mode + + if self.check_mode: + return True + # FIXME: comparison against string above will cause this to be executed + # every time + try: + if hasattr(os, 'lchmod'): + os.lchmod(b_path, mode) + else: + if not os.path.islink(b_path): + os.chmod(b_path, mode) + else: + # Attempt to set the perms of the symlink but be + # careful not to change the perms of the underlying + # file while trying + underlying_stat = os.stat(b_path) + os.chmod(b_path, mode) + new_underlying_stat = os.stat(b_path) + if underlying_stat.st_mode != new_underlying_stat.st_mode: + os.chmod(b_path, stat.S_IMODE(underlying_stat.st_mode)) + except OSError as e: + if os.path.islink(b_path) and e.errno in ( + errno.EACCES, # can't access symlink in sticky directory (stat) + errno.EPERM, # can't set mode on symbolic links (chmod) + errno.EROFS, # can't set mode on read-only filesystem + ): + pass + elif e.errno in (errno.ENOENT, errno.ELOOP): # Can't set mode on broken symbolic links + pass + else: + raise + except Exception as e: + path = to_text(b_path) + self.fail_json(path=path, msg='chmod failed', details=to_native(e), + exception=traceback.format_exc()) + + path_stat = os.lstat(b_path) + new_mode = stat.S_IMODE(path_stat.st_mode) + + if new_mode != prev_mode: + changed = True + return changed + + def set_attributes_if_different(self, path, attributes, changed, diff=None, expand=True): + + if attributes is None: + return changed + + b_path = to_bytes(path, errors='surrogate_or_strict') + if expand: + b_path = os.path.expanduser(os.path.expandvars(b_path)) + + if self.check_file_absent_if_check_mode(b_path): + return True + + existing = self.get_file_attributes(b_path, include_version=False) + + attr_mod = '=' + if attributes.startswith(('-', '+')): + attr_mod = attributes[0] + attributes = attributes[1:] + + if existing.get('attr_flags', '') != attributes or attr_mod == '-': + attrcmd = self.get_bin_path('chattr') + if attrcmd: + attrcmd = [attrcmd, '%s%s' % (attr_mod, attributes), b_path] + changed = True + + if diff is not None: + if 'before' not in diff: + diff['before'] = {} + diff['before']['attributes'] = existing.get('attr_flags') + if 'after' not in diff: + diff['after'] = {} + diff['after']['attributes'] = '%s%s' % (attr_mod, attributes) + + if not self.check_mode: + try: + rc, out, err = self.run_command(attrcmd) + if rc != 0 or err: + raise Exception("Error while setting attributes: %s" % (out + err)) + except Exception as e: + self.fail_json(path=to_text(b_path), msg='chattr failed', + details=to_native(e), exception=traceback.format_exc()) + return changed + + def get_file_attributes(self, path, include_version=True): + output = {} + attrcmd = self.get_bin_path('lsattr', False) + if attrcmd: + flags = '-vd' if include_version else '-d' + attrcmd = [attrcmd, flags, path] + try: + rc, out, err = self.run_command(attrcmd) + if rc == 0: + res = out.split() + attr_flags_idx = 0 + if include_version: + attr_flags_idx = 1 + output['version'] = res[0].strip() + output['attr_flags'] = res[attr_flags_idx].replace('-', '').strip() + output['attributes'] = format_attributes(output['attr_flags']) + except Exception: + pass + return output + + @classmethod + def _symbolic_mode_to_octal(cls, path_stat, symbolic_mode): + """ + This enables symbolic chmod string parsing as stated in the chmod man-page + + This includes things like: "u=rw-x+X,g=r-x+X,o=r-x+X" + """ + + new_mode = stat.S_IMODE(path_stat.st_mode) + + # Now parse all symbolic modes + for mode in symbolic_mode.split(','): + # Per single mode. This always contains a '+', '-' or '=' + # Split it on that + permlist = MODE_OPERATOR_RE.split(mode) + + # And find all the operators + opers = MODE_OPERATOR_RE.findall(mode) + + # The user(s) where it's all about is the first element in the + # 'permlist' list. Take that and remove it from the list. + # An empty user or 'a' means 'all'. + users = permlist.pop(0) + use_umask = (users == '') + if users == 'a' or users == '': + users = 'ugo' + + # Check if there are illegal characters in the user list + # They can end up in 'users' because they are not split + if USERS_RE.match(users): + raise ValueError("bad symbolic permission for mode: %s" % mode) + + # Now we have two list of equal length, one contains the requested + # permissions and one with the corresponding operators. + for idx, perms in enumerate(permlist): + # Check if there are illegal characters in the permissions + if PERMS_RE.match(perms): + raise ValueError("bad symbolic permission for mode: %s" % mode) + + for user in users: + mode_to_apply = cls._get_octal_mode_from_symbolic_perms(path_stat, user, perms, use_umask) + new_mode = cls._apply_operation_to_mode(user, opers[idx], mode_to_apply, new_mode) + + return new_mode + + @staticmethod + def _apply_operation_to_mode(user, operator, mode_to_apply, current_mode): + if operator == '=': + if user == 'u': + mask = stat.S_IRWXU | stat.S_ISUID + elif user == 'g': + mask = stat.S_IRWXG | stat.S_ISGID + elif user == 'o': + mask = stat.S_IRWXO | stat.S_ISVTX + + # mask out u, g, or o permissions from current_mode and apply new permissions + inverse_mask = mask ^ PERM_BITS + new_mode = (current_mode & inverse_mask) | mode_to_apply + elif operator == '+': + new_mode = current_mode | mode_to_apply + elif operator == '-': + new_mode = current_mode - (current_mode & mode_to_apply) + return new_mode + + @staticmethod + def _get_octal_mode_from_symbolic_perms(path_stat, user, perms, use_umask): + prev_mode = stat.S_IMODE(path_stat.st_mode) + + is_directory = stat.S_ISDIR(path_stat.st_mode) + has_x_permissions = (prev_mode & EXEC_PERM_BITS) > 0 + apply_X_permission = is_directory or has_x_permissions + + # Get the umask, if the 'user' part is empty, the effect is as if (a) were + # given, but bits that are set in the umask are not affected. + # We also need the "reversed umask" for masking + umask = os.umask(0) + os.umask(umask) + rev_umask = umask ^ PERM_BITS + + # Permission bits constants documented at: + # https://docs.python.org/3/library/stat.html#stat.S_ISUID + if apply_X_permission: + X_perms = { + 'u': {'X': stat.S_IXUSR}, + 'g': {'X': stat.S_IXGRP}, + 'o': {'X': stat.S_IXOTH}, + } + else: + X_perms = { + 'u': {'X': 0}, + 'g': {'X': 0}, + 'o': {'X': 0}, + } + + user_perms_to_modes = { + 'u': { + 'r': rev_umask & stat.S_IRUSR if use_umask else stat.S_IRUSR, + 'w': rev_umask & stat.S_IWUSR if use_umask else stat.S_IWUSR, + 'x': rev_umask & stat.S_IXUSR if use_umask else stat.S_IXUSR, + 's': stat.S_ISUID, + 't': 0, + 'u': prev_mode & stat.S_IRWXU, + 'g': (prev_mode & stat.S_IRWXG) << 3, + 'o': (prev_mode & stat.S_IRWXO) << 6}, + 'g': { + 'r': rev_umask & stat.S_IRGRP if use_umask else stat.S_IRGRP, + 'w': rev_umask & stat.S_IWGRP if use_umask else stat.S_IWGRP, + 'x': rev_umask & stat.S_IXGRP if use_umask else stat.S_IXGRP, + 's': stat.S_ISGID, + 't': 0, + 'u': (prev_mode & stat.S_IRWXU) >> 3, + 'g': prev_mode & stat.S_IRWXG, + 'o': (prev_mode & stat.S_IRWXO) << 3}, + 'o': { + 'r': rev_umask & stat.S_IROTH if use_umask else stat.S_IROTH, + 'w': rev_umask & stat.S_IWOTH if use_umask else stat.S_IWOTH, + 'x': rev_umask & stat.S_IXOTH if use_umask else stat.S_IXOTH, + 's': 0, + 't': stat.S_ISVTX, + 'u': (prev_mode & stat.S_IRWXU) >> 6, + 'g': (prev_mode & stat.S_IRWXG) >> 3, + 'o': prev_mode & stat.S_IRWXO}, + } + + # Insert X_perms into user_perms_to_modes + for key, value in X_perms.items(): + user_perms_to_modes[key].update(value) + + def or_reduce(mode, perm): + return mode | user_perms_to_modes[user][perm] + + return reduce(or_reduce, perms, 0) + + def set_fs_attributes_if_different(self, file_args, changed, diff=None, expand=True): + # set modes owners and context as needed + changed = self.set_context_if_different( + file_args['path'], file_args['secontext'], changed, diff + ) + changed = self.set_owner_if_different( + file_args['path'], file_args['owner'], changed, diff, expand + ) + changed = self.set_group_if_different( + file_args['path'], file_args['group'], changed, diff, expand + ) + changed = self.set_mode_if_different( + file_args['path'], file_args['mode'], changed, diff, expand + ) + changed = self.set_attributes_if_different( + file_args['path'], file_args['attributes'], changed, diff, expand + ) + return changed + + def check_file_absent_if_check_mode(self, file_path): + return self.check_mode and not os.path.exists(file_path) + + def set_directory_attributes_if_different(self, file_args, changed, diff=None, expand=True): + return self.set_fs_attributes_if_different(file_args, changed, diff, expand) + + def set_file_attributes_if_different(self, file_args, changed, diff=None, expand=True): + return self.set_fs_attributes_if_different(file_args, changed, diff, expand) + + def add_path_info(self, kwargs): + ''' + for results that are files, supplement the info about the file + in the return path with stats about the file path. + ''' + + path = kwargs.get('path', kwargs.get('dest', None)) + if path is None: + return kwargs + b_path = to_bytes(path, errors='surrogate_or_strict') + if os.path.exists(b_path): + (uid, gid) = self.user_and_group(path) + kwargs['uid'] = uid + kwargs['gid'] = gid + try: + user = pwd.getpwuid(uid)[0] + except KeyError: + user = str(uid) + try: + group = grp.getgrgid(gid)[0] + except KeyError: + group = str(gid) + kwargs['owner'] = user + kwargs['group'] = group + st = os.lstat(b_path) + kwargs['mode'] = '0%03o' % stat.S_IMODE(st[stat.ST_MODE]) + # secontext not yet supported + if os.path.islink(b_path): + kwargs['state'] = 'link' + elif os.path.isdir(b_path): + kwargs['state'] = 'directory' + elif os.stat(b_path).st_nlink > 1: + kwargs['state'] = 'hard' + else: + kwargs['state'] = 'file' + if self.selinux_enabled(): + kwargs['secontext'] = ':'.join(self.selinux_context(path)) + kwargs['size'] = st[stat.ST_SIZE] + return kwargs + + def _check_locale(self): + ''' + Uses the locale module to test the currently set locale + (per the LANG and LC_CTYPE environment settings) + ''' + try: + # setting the locale to '' uses the default locale + # as it would be returned by locale.getdefaultlocale() + locale.setlocale(locale.LC_ALL, '') + except locale.Error: + # fallback to the 'best' locale, per the function + # final fallback is 'C', which may cause unicode issues + # but is preferable to simply failing on unknown locale + best_locale = get_best_parsable_locale(self) + + # need to set several since many tools choose to ignore documented precedence and scope + locale.setlocale(locale.LC_ALL, best_locale) + os.environ['LANG'] = best_locale + os.environ['LC_ALL'] = best_locale + os.environ['LC_MESSAGES'] = best_locale + except Exception as e: + self.fail_json(msg="An unknown error was encountered while attempting to validate the locale: %s" % + to_native(e), exception=traceback.format_exc()) + + def _set_internal_properties(self, argument_spec=None, module_parameters=None): + if argument_spec is None: + argument_spec = self.argument_spec + if module_parameters is None: + module_parameters = self.params + + for k in PASS_VARS: + # handle setting internal properties from internal ansible vars + param_key = '_ansible_%s' % k + if param_key in module_parameters: + if k in PASS_BOOLS: + setattr(self, PASS_VARS[k][0], self.boolean(module_parameters[param_key])) + else: + setattr(self, PASS_VARS[k][0], module_parameters[param_key]) + + # clean up internal top level params: + if param_key in self.params: + del self.params[param_key] + else: + # use defaults if not already set + if not hasattr(self, PASS_VARS[k][0]): + setattr(self, PASS_VARS[k][0], PASS_VARS[k][1]) + + def safe_eval(self, value, locals=None, include_exceptions=False): + return safe_eval(value, locals, include_exceptions) + + def _load_params(self): + ''' read the input and set the params attribute. + + This method is for backwards compatibility. The guts of the function + were moved out in 2.1 so that custom modules could read the parameters. + ''' + # debug overrides to read args from file or cmdline + self.params = _load_params() + + def _log_to_syslog(self, msg): + if HAS_SYSLOG: + try: + module = 'ansible-%s' % self._name + facility = getattr(syslog, self._syslog_facility, syslog.LOG_USER) + syslog.openlog(str(module), 0, facility) + syslog.syslog(syslog.LOG_INFO, msg) + except TypeError as e: + self.fail_json( + msg='Failed to log to syslog (%s). To proceed anyway, ' + 'disable syslog logging by setting no_target_syslog ' + 'to True in your Ansible config.' % to_native(e), + exception=traceback.format_exc(), + msg_to_log=msg, + ) + + def debug(self, msg): + if self._debug: + self.log('[debug] %s' % msg) + + def log(self, msg, log_args=None): + + if not self.no_log: + + if log_args is None: + log_args = dict() + + module = 'ansible-%s' % self._name + if isinstance(module, binary_type): + module = module.decode('utf-8', 'replace') + + # 6655 - allow for accented characters + if not isinstance(msg, (binary_type, text_type)): + raise TypeError("msg should be a string (got %s)" % type(msg)) + + # We want journal to always take text type + # syslog takes bytes on py2, text type on py3 + if isinstance(msg, binary_type): + journal_msg = remove_values(msg.decode('utf-8', 'replace'), self.no_log_values) + else: + # TODO: surrogateescape is a danger here on Py3 + journal_msg = remove_values(msg, self.no_log_values) + + if PY3: + syslog_msg = journal_msg + else: + syslog_msg = journal_msg.encode('utf-8', 'replace') + + if has_journal: + journal_args = [("MODULE", os.path.basename(__file__))] + for arg in log_args: + name, value = (arg.upper(), str(log_args[arg])) + if name in ( + 'PRIORITY', 'MESSAGE', 'MESSAGE_ID', + 'CODE_FILE', 'CODE_LINE', 'CODE_FUNC', + 'SYSLOG_FACILITY', 'SYSLOG_IDENTIFIER', + 'SYSLOG_PID', + ): + name = "_%s" % name + journal_args.append((name, value)) + + try: + if HAS_SYSLOG: + # If syslog_facility specified, it needs to convert + # from the facility name to the facility code, and + # set it as SYSLOG_FACILITY argument of journal.send() + facility = getattr(syslog, + self._syslog_facility, + syslog.LOG_USER) >> 3 + journal.send(MESSAGE=u"%s %s" % (module, journal_msg), + SYSLOG_FACILITY=facility, + **dict(journal_args)) + else: + journal.send(MESSAGE=u"%s %s" % (module, journal_msg), + **dict(journal_args)) + except IOError: + # fall back to syslog since logging to journal failed + self._log_to_syslog(syslog_msg) + else: + self._log_to_syslog(syslog_msg) + + def _log_invocation(self): + ''' log that ansible ran the module ''' + # TODO: generalize a separate log function and make log_invocation use it + # Sanitize possible password argument when logging. + log_args = dict() + + for param in self.params: + canon = self.aliases.get(param, param) + arg_opts = self.argument_spec.get(canon, {}) + no_log = arg_opts.get('no_log', None) + + # try to proactively capture password/passphrase fields + if no_log is None and PASSWORD_MATCH.search(param): + log_args[param] = 'NOT_LOGGING_PASSWORD' + self.warn('Module did not set no_log for %s' % param) + elif self.boolean(no_log): + log_args[param] = 'NOT_LOGGING_PARAMETER' + else: + param_val = self.params[param] + if not isinstance(param_val, (text_type, binary_type)): + param_val = str(param_val) + elif isinstance(param_val, text_type): + param_val = param_val.encode('utf-8') + log_args[param] = heuristic_log_sanitize(param_val, self.no_log_values) + + msg = ['%s=%s' % (to_native(arg), to_native(val)) for arg, val in log_args.items()] + if msg: + msg = 'Invoked with %s' % ' '.join(msg) + else: + msg = 'Invoked' + + self.log(msg, log_args=log_args) + + def _set_cwd(self): + try: + cwd = os.getcwd() + if not os.access(cwd, os.F_OK | os.R_OK): + raise Exception() + return cwd + except Exception: + # we don't have access to the cwd, probably because of sudo. + # Try and move to a neutral location to prevent errors + for cwd in [self.tmpdir, os.path.expandvars('$HOME'), tempfile.gettempdir()]: + try: + if os.access(cwd, os.F_OK | os.R_OK): + os.chdir(cwd) + return cwd + except Exception: + pass + # we won't error here, as it may *not* be a problem, + # and we don't want to break modules unnecessarily + return None + + def get_bin_path(self, arg, required=False, opt_dirs=None): + ''' + Find system executable in PATH. + + :param arg: The executable to find. + :param required: if executable is not found and required is ``True``, fail_json + :param opt_dirs: optional list of directories to search in addition to ``PATH`` + :returns: if found return full path; otherwise return None + ''' + + bin_path = None + try: + bin_path = get_bin_path(arg=arg, opt_dirs=opt_dirs) + except ValueError as e: + if required: + self.fail_json(msg=to_text(e)) + else: + return bin_path + + return bin_path + + def boolean(self, arg): + '''Convert the argument to a boolean''' + if arg is None: + return arg + + try: + return boolean(arg) + except TypeError as e: + self.fail_json(msg=to_native(e)) + + def jsonify(self, data): + try: + return jsonify(data) + except UnicodeError as e: + self.fail_json(msg=to_text(e)) + + def from_json(self, data): + return json.loads(data) + + def add_cleanup_file(self, path): + if path not in self.cleanup_files: + self.cleanup_files.append(path) + + def do_cleanup_files(self): + for path in self.cleanup_files: + self.cleanup(path) + + def _return_formatted(self, kwargs): + + self.add_path_info(kwargs) + + if 'invocation' not in kwargs: + kwargs['invocation'] = {'module_args': self.params} + + if 'warnings' in kwargs: + if isinstance(kwargs['warnings'], list): + for w in kwargs['warnings']: + self.warn(w) + else: + self.warn(kwargs['warnings']) + + warnings = get_warning_messages() + if warnings: + kwargs['warnings'] = warnings + + if 'deprecations' in kwargs: + if isinstance(kwargs['deprecations'], list): + for d in kwargs['deprecations']: + if isinstance(d, SEQUENCETYPE) and len(d) == 2: + self.deprecate(d[0], version=d[1]) + elif isinstance(d, Mapping): + self.deprecate(d['msg'], version=d.get('version'), date=d.get('date'), + collection_name=d.get('collection_name')) + else: + self.deprecate(d) # pylint: disable=ansible-deprecated-no-version + else: + self.deprecate(kwargs['deprecations']) # pylint: disable=ansible-deprecated-no-version + + deprecations = get_deprecation_messages() + if deprecations: + kwargs['deprecations'] = deprecations + + kwargs = remove_values(kwargs, self.no_log_values) + print('\n%s' % self.jsonify(kwargs)) + + def exit_json(self, **kwargs): + ''' return from the module, without error ''' + + self.do_cleanup_files() + self._return_formatted(kwargs) + sys.exit(0) + + def fail_json(self, msg, **kwargs): + ''' return from the module, with an error message ''' + + kwargs['failed'] = True + kwargs['msg'] = msg + + # Add traceback if debug or high verbosity and it is missing + # NOTE: Badly named as exception, it really always has been a traceback + if 'exception' not in kwargs and sys.exc_info()[2] and (self._debug or self._verbosity >= 3): + if PY2: + # On Python 2 this is the last (stack frame) exception and as such may be unrelated to the failure + kwargs['exception'] = 'WARNING: The below traceback may *not* be related to the actual failure.\n' +\ + ''.join(traceback.format_tb(sys.exc_info()[2])) + else: + kwargs['exception'] = ''.join(traceback.format_tb(sys.exc_info()[2])) + + self.do_cleanup_files() + self._return_formatted(kwargs) + sys.exit(1) + + def fail_on_missing_params(self, required_params=None): + if not required_params: + return + try: + check_missing_parameters(self.params, required_params) + except TypeError as e: + self.fail_json(msg=to_native(e)) + + def digest_from_file(self, filename, algorithm): + ''' Return hex digest of local file for a digest_method specified by name, or None if file is not present. ''' + b_filename = to_bytes(filename, errors='surrogate_or_strict') + + if not os.path.exists(b_filename): + return None + if os.path.isdir(b_filename): + self.fail_json(msg="attempted to take checksum of directory: %s" % filename) + + # preserve old behaviour where the third parameter was a hash algorithm object + if hasattr(algorithm, 'hexdigest'): + digest_method = algorithm + else: + try: + digest_method = AVAILABLE_HASH_ALGORITHMS[algorithm]() + except KeyError: + self.fail_json(msg="Could not hash file '%s' with algorithm '%s'. Available algorithms: %s" % + (filename, algorithm, ', '.join(AVAILABLE_HASH_ALGORITHMS))) + + blocksize = 64 * 1024 + infile = open(os.path.realpath(b_filename), 'rb') + block = infile.read(blocksize) + while block: + digest_method.update(block) + block = infile.read(blocksize) + infile.close() + return digest_method.hexdigest() + + def md5(self, filename): + ''' Return MD5 hex digest of local file using digest_from_file(). + + Do not use this function unless you have no other choice for: + 1) Optional backwards compatibility + 2) Compatibility with a third party protocol + + This function will not work on systems complying with FIPS-140-2. + + Most uses of this function can use the module.sha1 function instead. + ''' + if 'md5' not in AVAILABLE_HASH_ALGORITHMS: + raise ValueError('MD5 not available. Possibly running in FIPS mode') + return self.digest_from_file(filename, 'md5') + + def sha1(self, filename): + ''' Return SHA1 hex digest of local file using digest_from_file(). ''' + return self.digest_from_file(filename, 'sha1') + + def sha256(self, filename): + ''' Return SHA-256 hex digest of local file using digest_from_file(). ''' + return self.digest_from_file(filename, 'sha256') + + def backup_local(self, fn): + '''make a date-marked backup of the specified file, return True or False on success or failure''' + + backupdest = '' + if os.path.exists(fn): + # backups named basename.PID.YYYY-MM-DD@HH:MM:SS~ + ext = time.strftime("%Y-%m-%d@%H:%M:%S~", time.localtime(time.time())) + backupdest = '%s.%s.%s' % (fn, os.getpid(), ext) + + try: + self.preserved_copy(fn, backupdest) + except (shutil.Error, IOError) as e: + self.fail_json(msg='Could not make backup of %s to %s: %s' % (fn, backupdest, to_native(e))) + + return backupdest + + def cleanup(self, tmpfile): + if os.path.exists(tmpfile): + try: + os.unlink(tmpfile) + except OSError as e: + sys.stderr.write("could not cleanup %s: %s" % (tmpfile, to_native(e))) + + def preserved_copy(self, src, dest): + """Copy a file with preserved ownership, permissions and context""" + + # shutil.copy2(src, dst) + # Similar to shutil.copy(), but metadata is copied as well - in fact, + # this is just shutil.copy() followed by copystat(). This is similar + # to the Unix command cp -p. + # + # shutil.copystat(src, dst) + # Copy the permission bits, last access time, last modification time, + # and flags from src to dst. The file contents, owner, and group are + # unaffected. src and dst are path names given as strings. + + shutil.copy2(src, dest) + + # Set the context + if self.selinux_enabled(): + context = self.selinux_context(src) + self.set_context_if_different(dest, context, False) + + # chown it + try: + dest_stat = os.stat(src) + tmp_stat = os.stat(dest) + if dest_stat and (tmp_stat.st_uid != dest_stat.st_uid or tmp_stat.st_gid != dest_stat.st_gid): + os.chown(dest, dest_stat.st_uid, dest_stat.st_gid) + except OSError as e: + if e.errno != errno.EPERM: + raise + + # Set the attributes + current_attribs = self.get_file_attributes(src, include_version=False) + current_attribs = current_attribs.get('attr_flags', '') + self.set_attributes_if_different(dest, current_attribs, True) + + def atomic_move(self, src, dest, unsafe_writes=False): + '''atomically move src to dest, copying attributes from dest, returns true on success + it uses os.rename to ensure this as it is an atomic operation, rest of the function is + to work around limitations, corner cases and ensure selinux context is saved if possible''' + context = None + dest_stat = None + b_src = to_bytes(src, errors='surrogate_or_strict') + b_dest = to_bytes(dest, errors='surrogate_or_strict') + if os.path.exists(b_dest): + try: + dest_stat = os.stat(b_dest) + + # copy mode and ownership + os.chmod(b_src, dest_stat.st_mode & PERM_BITS) + os.chown(b_src, dest_stat.st_uid, dest_stat.st_gid) + + # try to copy flags if possible + if hasattr(os, 'chflags') and hasattr(dest_stat, 'st_flags'): + try: + os.chflags(b_src, dest_stat.st_flags) + except OSError as e: + for err in 'EOPNOTSUPP', 'ENOTSUP': + if hasattr(errno, err) and e.errno == getattr(errno, err): + break + else: + raise + except OSError as e: + if e.errno != errno.EPERM: + raise + if self.selinux_enabled(): + context = self.selinux_context(dest) + else: + if self.selinux_enabled(): + context = self.selinux_default_context(dest) + + creating = not os.path.exists(b_dest) + + try: + # Optimistically try a rename, solves some corner cases and can avoid useless work, throws exception if not atomic. + os.rename(b_src, b_dest) + except (IOError, OSError) as e: + if e.errno not in [errno.EPERM, errno.EXDEV, errno.EACCES, errno.ETXTBSY, errno.EBUSY]: + # only try workarounds for errno 18 (cross device), 1 (not permitted), 13 (permission denied) + # and 26 (text file busy) which happens on vagrant synced folders and other 'exotic' non posix file systems + self.fail_json(msg='Could not replace file: %s to %s: %s' % (src, dest, to_native(e)), exception=traceback.format_exc()) + else: + # Use bytes here. In the shippable CI, this fails with + # a UnicodeError with surrogateescape'd strings for an unknown + # reason (doesn't happen in a local Ubuntu16.04 VM) + b_dest_dir = os.path.dirname(b_dest) + b_suffix = os.path.basename(b_dest) + error_msg = None + tmp_dest_name = None + try: + tmp_dest_fd, tmp_dest_name = tempfile.mkstemp(prefix=b'.ansible_tmp', dir=b_dest_dir, suffix=b_suffix) + except (OSError, IOError) as e: + error_msg = 'The destination directory (%s) is not writable by the current user. Error was: %s' % (os.path.dirname(dest), to_native(e)) + except TypeError: + # We expect that this is happening because python3.4.x and + # below can't handle byte strings in mkstemp(). + # Traceback would end in something like: + # file = _os.path.join(dir, pre + name + suf) + # TypeError: can't concat bytes to str + error_msg = ('Failed creating tmp file for atomic move. This usually happens when using Python3 less than Python3.5. ' + 'Please use Python2.x or Python3.5 or greater.') + finally: + if error_msg: + if unsafe_writes: + self._unsafe_writes(b_src, b_dest) + else: + self.fail_json(msg=error_msg, exception=traceback.format_exc()) + + if tmp_dest_name: + b_tmp_dest_name = to_bytes(tmp_dest_name, errors='surrogate_or_strict') + + try: + try: + # close tmp file handle before file operations to prevent text file busy errors on vboxfs synced folders (windows host) + os.close(tmp_dest_fd) + # leaves tmp file behind when sudo and not root + try: + shutil.move(b_src, b_tmp_dest_name) + except OSError: + # cleanup will happen by 'rm' of tmpdir + # copy2 will preserve some metadata + shutil.copy2(b_src, b_tmp_dest_name) + + if self.selinux_enabled(): + self.set_context_if_different( + b_tmp_dest_name, context, False) + try: + tmp_stat = os.stat(b_tmp_dest_name) + if dest_stat and (tmp_stat.st_uid != dest_stat.st_uid or tmp_stat.st_gid != dest_stat.st_gid): + os.chown(b_tmp_dest_name, dest_stat.st_uid, dest_stat.st_gid) + except OSError as e: + if e.errno != errno.EPERM: + raise + try: + os.rename(b_tmp_dest_name, b_dest) + except (shutil.Error, OSError, IOError) as e: + if unsafe_writes and e.errno == errno.EBUSY: + self._unsafe_writes(b_tmp_dest_name, b_dest) + else: + self.fail_json(msg='Unable to make %s into to %s, failed final rename from %s: %s' % + (src, dest, b_tmp_dest_name, to_native(e)), exception=traceback.format_exc()) + except (shutil.Error, OSError, IOError) as e: + if unsafe_writes: + self._unsafe_writes(b_src, b_dest) + else: + self.fail_json(msg='Failed to replace file: %s to %s: %s' % (src, dest, to_native(e)), exception=traceback.format_exc()) + finally: + self.cleanup(b_tmp_dest_name) + + if creating: + # make sure the file has the correct permissions + # based on the current value of umask + umask = os.umask(0) + os.umask(umask) + os.chmod(b_dest, DEFAULT_PERM & ~umask) + try: + os.chown(b_dest, os.geteuid(), os.getegid()) + except OSError: + # We're okay with trying our best here. If the user is not + # root (or old Unices) they won't be able to chown. + pass + + if self.selinux_enabled(): + # rename might not preserve context + self.set_context_if_different(dest, context, False) + + def _unsafe_writes(self, src, dest): + # sadly there are some situations where we cannot ensure atomicity, but only if + # the user insists and we get the appropriate error we update the file unsafely + try: + out_dest = in_src = None + try: + out_dest = open(dest, 'wb') + in_src = open(src, 'rb') + shutil.copyfileobj(in_src, out_dest) + finally: # assuring closed files in 2.4 compatible way + if out_dest: + out_dest.close() + if in_src: + in_src.close() + except (shutil.Error, OSError, IOError) as e: + self.fail_json(msg='Could not write data to file (%s) from (%s): %s' % (dest, src, to_native(e)), + exception=traceback.format_exc()) + + def _clean_args(self, args): + + if not self._clean: + # create a printable version of the command for use in reporting later, + # which strips out things like passwords from the args list + to_clean_args = args + if PY2: + if isinstance(args, text_type): + to_clean_args = to_bytes(args) + else: + if isinstance(args, binary_type): + to_clean_args = to_text(args) + if isinstance(args, (text_type, binary_type)): + to_clean_args = shlex.split(to_clean_args) + + clean_args = [] + is_passwd = False + for arg in (to_native(a) for a in to_clean_args): + if is_passwd: + is_passwd = False + clean_args.append('********') + continue + if PASSWD_ARG_RE.match(arg): + sep_idx = arg.find('=') + if sep_idx > -1: + clean_args.append('%s=********' % arg[:sep_idx]) + continue + else: + is_passwd = True + arg = heuristic_log_sanitize(arg, self.no_log_values) + clean_args.append(arg) + self._clean = ' '.join(shlex_quote(arg) for arg in clean_args) + + return self._clean + + def _restore_signal_handlers(self): + # Reset SIGPIPE to SIG_DFL, otherwise in Python2.7 it gets ignored in subprocesses. + if PY2 and sys.platform != 'win32': + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + def run_command(self, args, check_rc=False, close_fds=True, executable=None, data=None, binary_data=False, path_prefix=None, cwd=None, + use_unsafe_shell=False, prompt_regex=None, environ_update=None, umask=None, encoding='utf-8', errors='surrogate_or_strict', + expand_user_and_vars=True, pass_fds=None, before_communicate_callback=None, ignore_invalid_cwd=True, handle_exceptions=True): + ''' + Execute a command, returns rc, stdout, and stderr. + + :arg args: is the command to run + * If args is a list, the command will be run with shell=False. + * If args is a string and use_unsafe_shell=False it will split args to a list and run with shell=False + * If args is a string and use_unsafe_shell=True it runs with shell=True. + :kw check_rc: Whether to call fail_json in case of non zero RC. + Default False + :kw close_fds: See documentation for subprocess.Popen(). Default True + :kw executable: See documentation for subprocess.Popen(). Default None + :kw data: If given, information to write to the stdin of the command + :kw binary_data: If False, append a newline to the data. Default False + :kw path_prefix: If given, additional path to find the command in. + This adds to the PATH environment variable so helper commands in + the same directory can also be found + :kw cwd: If given, working directory to run the command inside + :kw use_unsafe_shell: See `args` parameter. Default False + :kw prompt_regex: Regex string (not a compiled regex) which can be + used to detect prompts in the stdout which would otherwise cause + the execution to hang (especially if no input data is specified) + :kw environ_update: dictionary to *update* environ variables with + :kw umask: Umask to be used when running the command. Default None + :kw encoding: Since we return native strings, on python3 we need to + know the encoding to use to transform from bytes to text. If you + want to always get bytes back, use encoding=None. The default is + "utf-8". This does not affect transformation of strings given as + args. + :kw errors: Since we return native strings, on python3 we need to + transform stdout and stderr from bytes to text. If the bytes are + undecodable in the ``encoding`` specified, then use this error + handler to deal with them. The default is ``surrogate_or_strict`` + which means that the bytes will be decoded using the + surrogateescape error handler if available (available on all + python3 versions we support) otherwise a UnicodeError traceback + will be raised. This does not affect transformations of strings + given as args. + :kw expand_user_and_vars: When ``use_unsafe_shell=False`` this argument + dictates whether ``~`` is expanded in paths and environment variables + are expanded before running the command. When ``True`` a string such as + ``$SHELL`` will be expanded regardless of escaping. When ``False`` and + ``use_unsafe_shell=False`` no path or variable expansion will be done. + :kw pass_fds: When running on Python 3 this argument + dictates which file descriptors should be passed + to an underlying ``Popen`` constructor. On Python 2, this will + set ``close_fds`` to False. + :kw before_communicate_callback: This function will be called + after ``Popen`` object will be created + but before communicating to the process. + (``Popen`` object will be passed to callback as a first argument) + :kw ignore_invalid_cwd: This flag indicates whether an invalid ``cwd`` + (non-existent or not a directory) should be ignored or should raise + an exception. + :kw handle_exceptions: This flag indicates whether an exception will + be handled inline and issue a failed_json or if the caller should + handle it. + :returns: A 3-tuple of return code (integer), stdout (native string), + and stderr (native string). On python2, stdout and stderr are both + byte strings. On python3, stdout and stderr are text strings converted + according to the encoding and errors parameters. If you want byte + strings on python3, use encoding=None to turn decoding to text off. + ''' + # used by clean args later on + self._clean = None + + if not isinstance(args, (list, binary_type, text_type)): + msg = "Argument 'args' to run_command must be list or string" + self.fail_json(rc=257, cmd=args, msg=msg) + + shell = False + if use_unsafe_shell: + + # stringify args for unsafe/direct shell usage + if isinstance(args, list): + args = b" ".join([to_bytes(shlex_quote(x), errors='surrogate_or_strict') for x in args]) + else: + args = to_bytes(args, errors='surrogate_or_strict') + + # not set explicitly, check if set by controller + if executable: + executable = to_bytes(executable, errors='surrogate_or_strict') + args = [executable, b'-c', args] + elif self._shell not in (None, '/bin/sh'): + args = [to_bytes(self._shell, errors='surrogate_or_strict'), b'-c', args] + else: + shell = True + else: + # ensure args are a list + if isinstance(args, (binary_type, text_type)): + # On python2.6 and below, shlex has problems with text type + # On python3, shlex needs a text type. + if PY2: + args = to_bytes(args, errors='surrogate_or_strict') + elif PY3: + args = to_text(args, errors='surrogateescape') + args = shlex.split(args) + + # expand ``~`` in paths, and all environment vars + if expand_user_and_vars: + args = [to_bytes(os.path.expanduser(os.path.expandvars(x)), errors='surrogate_or_strict') for x in args if x is not None] + else: + args = [to_bytes(x, errors='surrogate_or_strict') for x in args if x is not None] + + prompt_re = None + if prompt_regex: + if isinstance(prompt_regex, text_type): + if PY3: + prompt_regex = to_bytes(prompt_regex, errors='surrogateescape') + elif PY2: + prompt_regex = to_bytes(prompt_regex, errors='surrogate_or_strict') + try: + prompt_re = re.compile(prompt_regex, re.MULTILINE) + except re.error: + self.fail_json(msg="invalid prompt regular expression given to run_command") + + rc = 0 + msg = None + st_in = None + + env = os.environ.copy() + # We can set this from both an attribute and per call + env.update(self.run_command_environ_update or {}) + env.update(environ_update or {}) + if path_prefix: + path = env.get('PATH', '') + if path: + env['PATH'] = "%s:%s" % (path_prefix, path) + else: + env['PATH'] = path_prefix + + # If using test-module.py and explode, the remote lib path will resemble: + # /tmp/test_module_scratch/debug_dir/ansible/module_utils/basic.py + # If using ansible or ansible-playbook with a remote system: + # /tmp/ansible_vmweLQ/ansible_modlib.zip/ansible/module_utils/basic.py + + # Clean out python paths set by ansiballz + if 'PYTHONPATH' in env: + pypaths = [x for x in env['PYTHONPATH'].split(':') + if x and + not x.endswith('/ansible_modlib.zip') and + not x.endswith('/debug_dir')] + if pypaths and any(pypaths): + env['PYTHONPATH'] = ':'.join(pypaths) + + if data: + st_in = subprocess.PIPE + + def preexec(): + self._restore_signal_handlers() + if umask: + os.umask(umask) + + kwargs = dict( + executable=executable, + shell=shell, + close_fds=close_fds, + stdin=st_in, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + preexec_fn=preexec, + env=env, + ) + if PY3 and pass_fds: + kwargs["pass_fds"] = pass_fds + elif PY2 and pass_fds: + kwargs['close_fds'] = False + + # make sure we're in the right working directory + if cwd: + cwd = to_bytes(os.path.abspath(os.path.expanduser(cwd)), errors='surrogate_or_strict') + if os.path.isdir(cwd): + kwargs['cwd'] = cwd + elif not ignore_invalid_cwd: + self.fail_json(msg="Provided cwd is not a valid directory: %s" % cwd) + + try: + if self._debug: + self.log('Executing: ' + self._clean_args(args)) + cmd = subprocess.Popen(args, **kwargs) + if before_communicate_callback: + before_communicate_callback(cmd) + + # the communication logic here is essentially taken from that + # of the _communicate() function in ssh.py + + stdout = b'' + stderr = b'' + try: + selector = selectors.DefaultSelector() + except (IOError, OSError): + # Failed to detect default selector for the given platform + # Select PollSelector which is supported by major platforms + selector = selectors.PollSelector() + + selector.register(cmd.stdout, selectors.EVENT_READ) + selector.register(cmd.stderr, selectors.EVENT_READ) + if os.name == 'posix': + fcntl.fcntl(cmd.stdout.fileno(), fcntl.F_SETFL, fcntl.fcntl(cmd.stdout.fileno(), fcntl.F_GETFL) | os.O_NONBLOCK) + fcntl.fcntl(cmd.stderr.fileno(), fcntl.F_SETFL, fcntl.fcntl(cmd.stderr.fileno(), fcntl.F_GETFL) | os.O_NONBLOCK) + + if data: + if not binary_data: + data += '\n' + if isinstance(data, text_type): + data = to_bytes(data) + cmd.stdin.write(data) + cmd.stdin.close() + + while True: + events = selector.select(1) + for key, event in events: + b_chunk = key.fileobj.read() + if b_chunk == b(''): + selector.unregister(key.fileobj) + if key.fileobj == cmd.stdout: + stdout += b_chunk + elif key.fileobj == cmd.stderr: + stderr += b_chunk + # if we're checking for prompts, do it now + if prompt_re: + if prompt_re.search(stdout) and not data: + if encoding: + stdout = to_native(stdout, encoding=encoding, errors=errors) + return (257, stdout, "A prompt was encountered while running a command, but no input data was specified") + # only break out if no pipes are left to read or + # the pipes are completely read and + # the process is terminated + if (not events or not selector.get_map()) and cmd.poll() is not None: + break + # No pipes are left to read but process is not yet terminated + # Only then it is safe to wait for the process to be finished + # NOTE: Actually cmd.poll() is always None here if no selectors are left + elif not selector.get_map() and cmd.poll() is None: + cmd.wait() + # The process is terminated. Since no pipes to read from are + # left, there is no need to call select() again. + break + + cmd.stdout.close() + cmd.stderr.close() + selector.close() + + rc = cmd.returncode + except (OSError, IOError) as e: + self.log("Error Executing CMD:%s Exception:%s" % (self._clean_args(args), to_native(e))) + if handle_exceptions: + self.fail_json(rc=e.errno, stdout=b'', stderr=b'', msg=to_native(e), cmd=self._clean_args(args)) + else: + raise e + except Exception as e: + self.log("Error Executing CMD:%s Exception:%s" % (self._clean_args(args), to_native(traceback.format_exc()))) + if handle_exceptions: + self.fail_json(rc=257, stdout=b'', stderr=b'', msg=to_native(e), exception=traceback.format_exc(), cmd=self._clean_args(args)) + else: + raise e + + if rc != 0 and check_rc: + msg = heuristic_log_sanitize(stderr.rstrip(), self.no_log_values) + self.fail_json(cmd=self._clean_args(args), rc=rc, stdout=stdout, stderr=stderr, msg=msg) + + if encoding is not None: + return (rc, to_native(stdout, encoding=encoding, errors=errors), + to_native(stderr, encoding=encoding, errors=errors)) + + return (rc, stdout, stderr) + + def append_to_file(self, filename, str): + filename = os.path.expandvars(os.path.expanduser(filename)) + fh = open(filename, 'a') + fh.write(str) + fh.close() + + def bytes_to_human(self, size): + return bytes_to_human(size) + + # for backwards compatibility + pretty_bytes = bytes_to_human + + def human_to_bytes(self, number, isbits=False): + return human_to_bytes(number, isbits) + + # + # Backwards compat + # + + # In 2.0, moved from inside the module to the toplevel + is_executable = is_executable + + @staticmethod + def get_buffer_size(fd): + try: + # 1032 == FZ_GETPIPE_SZ + buffer_size = fcntl.fcntl(fd, 1032) + except Exception: + try: + # not as exact as above, but should be good enough for most platforms that fail the previous call + buffer_size = select.PIPE_BUF + except Exception: + buffer_size = 9000 # use sane default JIC + + return buffer_size + + +def get_module_path(): + return os.path.dirname(os.path.realpath(__file__)) diff --git a/lib/ansible/module_utils/common/__init__.py b/lib/ansible/module_utils/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/common/__init__.py diff --git a/lib/ansible/module_utils/common/_collections_compat.py b/lib/ansible/module_utils/common/_collections_compat.py new file mode 100644 index 0000000..3412408 --- /dev/null +++ b/lib/ansible/module_utils/common/_collections_compat.py @@ -0,0 +1,46 @@ +# Copyright (c), Sviatoslav Sydorenko <ssydoren@redhat.com> 2018 +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) +"""Collections ABC import shim. + +This module is intended only for internal use. +It will go away once the bundled copy of six includes equivalent functionality. +Third parties should not use this. +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +try: + """Python 3.3+ branch.""" + from collections.abc import ( + MappingView, + ItemsView, + KeysView, + ValuesView, + Mapping, MutableMapping, + Sequence, MutableSequence, + Set, MutableSet, + Container, + Hashable, + Sized, + Callable, + Iterable, + Iterator, + ) +except ImportError: + """Use old lib location under 2.6-3.2.""" + from collections import ( # type: ignore[no-redef,attr-defined] # pylint: disable=deprecated-class + MappingView, + ItemsView, + KeysView, + ValuesView, + Mapping, MutableMapping, + Sequence, MutableSequence, + Set, MutableSet, + Container, + Hashable, + Sized, + Callable, + Iterable, + Iterator, + ) diff --git a/lib/ansible/module_utils/common/_json_compat.py b/lib/ansible/module_utils/common/_json_compat.py new file mode 100644 index 0000000..787af0f --- /dev/null +++ b/lib/ansible/module_utils/common/_json_compat.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import types +import json + +# Detect the python-json library which is incompatible +try: + if not isinstance(json.loads, types.FunctionType) or not isinstance(json.dumps, types.FunctionType): + raise ImportError('json.loads or json.dumps were not found in the imported json library.') +except AttributeError: + raise ImportError('python-json was detected, which is incompatible.') diff --git a/lib/ansible/module_utils/common/_utils.py b/lib/ansible/module_utils/common/_utils.py new file mode 100644 index 0000000..66df316 --- /dev/null +++ b/lib/ansible/module_utils/common/_utils.py @@ -0,0 +1,40 @@ +# Copyright (c) 2018, Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +""" +Modules in _utils are waiting to find a better home. If you need to use them, be prepared for them +to move to a different location in the future. +""" + + +def get_all_subclasses(cls): + ''' + Recursively search and find all subclasses of a given class + + :arg cls: A python class + :rtype: set + :returns: The set of python classes which are the subclasses of `cls`. + + In python, you can use a class's :py:meth:`__subclasses__` method to determine what subclasses + of a class exist. However, `__subclasses__` only goes one level deep. This function searches + each child class's `__subclasses__` method to find all of the descendent classes. It then + returns an iterable of the descendent classes. + ''' + # Retrieve direct subclasses + subclasses = set(cls.__subclasses__()) + to_visit = list(subclasses) + # Then visit all subclasses + while to_visit: + for sc in to_visit: + # The current class is now visited, so remove it from list + to_visit.remove(sc) + # Appending all subclasses to visit and keep a reference of available class + for ssc in sc.__subclasses__(): + if ssc not in subclasses: + to_visit.append(ssc) + subclasses.add(ssc) + return subclasses diff --git a/lib/ansible/module_utils/common/arg_spec.py b/lib/ansible/module_utils/common/arg_spec.py new file mode 100644 index 0000000..d9f716e --- /dev/null +++ b/lib/ansible/module_utils/common/arg_spec.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from copy import deepcopy + +from ansible.module_utils.common.parameters import ( + _ADDITIONAL_CHECKS, + _get_legal_inputs, + _get_unsupported_parameters, + _handle_aliases, + _list_deprecations, + _list_no_log_values, + _set_defaults, + _validate_argument_types, + _validate_argument_values, + _validate_sub_spec, + set_fallbacks, +) + +from ansible.module_utils.common.text.converters import to_native +from ansible.module_utils.common.warnings import deprecate, warn + +from ansible.module_utils.common.validation import ( + check_mutually_exclusive, + check_required_arguments, +) + +from ansible.module_utils.errors import ( + AliasError, + AnsibleValidationErrorMultiple, + DeprecationError, + MutuallyExclusiveError, + NoLogError, + RequiredDefaultError, + RequiredError, + UnsupportedError, +) + + +class ValidationResult: + """Result of argument spec validation. + + This is the object returned by :func:`ArgumentSpecValidator.validate() + <ansible.module_utils.common.arg_spec.ArgumentSpecValidator.validate()>` + containing the validated parameters and any errors. + """ + + def __init__(self, parameters): + """ + :arg parameters: Terms to be validated and coerced to the correct type. + :type parameters: dict + """ + self._no_log_values = set() + """:class:`set` of values marked as ``no_log`` in the argument spec. This + is a temporary holding place for these values and may move in the future. + """ + + self._unsupported_parameters = set() + self._supported_parameters = dict() + self._validated_parameters = deepcopy(parameters) + self._deprecations = [] + self._warnings = [] + self._aliases = {} + self.errors = AnsibleValidationErrorMultiple() + """ + :class:`~ansible.module_utils.errors.AnsibleValidationErrorMultiple` containing all + :class:`~ansible.module_utils.errors.AnsibleValidationError` objects if there were + any failures during validation. + """ + + @property + def validated_parameters(self): + """Validated and coerced parameters.""" + return self._validated_parameters + + @property + def unsupported_parameters(self): + """:class:`set` of unsupported parameter names.""" + return self._unsupported_parameters + + @property + def error_messages(self): + """:class:`list` of all error messages from each exception in :attr:`errors`.""" + return self.errors.messages + + +class ArgumentSpecValidator: + """Argument spec validation class + + Creates a validator based on the ``argument_spec`` that can be used to + validate a number of parameters using the :meth:`validate` method. + """ + + def __init__(self, argument_spec, + mutually_exclusive=None, + required_together=None, + required_one_of=None, + required_if=None, + required_by=None, + ): + + """ + :arg argument_spec: Specification of valid parameters and their type. May + include nested argument specs. + :type argument_spec: dict[str, dict] + + :kwarg mutually_exclusive: List or list of lists of terms that should not + be provided together. + :type mutually_exclusive: list[str] or list[list[str]] + + :kwarg required_together: List of lists of terms that are required together. + :type required_together: list[list[str]] + + :kwarg required_one_of: List of lists of terms, one of which in each list + is required. + :type required_one_of: list[list[str]] + + :kwarg required_if: List of lists of ``[parameter, value, [parameters]]`` where + one of ``[parameters]`` is required if ``parameter == value``. + :type required_if: list + + :kwarg required_by: Dictionary of parameter names that contain a list of + parameters required by each key in the dictionary. + :type required_by: dict[str, list[str]] + """ + + self._mutually_exclusive = mutually_exclusive + self._required_together = required_together + self._required_one_of = required_one_of + self._required_if = required_if + self._required_by = required_by + self._valid_parameter_names = set() + self.argument_spec = argument_spec + + for key in sorted(self.argument_spec.keys()): + aliases = self.argument_spec[key].get('aliases') + if aliases: + self._valid_parameter_names.update(["{key} ({aliases})".format(key=key, aliases=", ".join(sorted(aliases)))]) + else: + self._valid_parameter_names.update([key]) + + def validate(self, parameters, *args, **kwargs): + """Validate ``parameters`` against argument spec. + + Error messages in the :class:`ValidationResult` may contain no_log values and should be + sanitized with :func:`~ansible.module_utils.common.parameters.sanitize_keys` before logging or displaying. + + :arg parameters: Parameters to validate against the argument spec + :type parameters: dict[str, dict] + + :return: :class:`ValidationResult` containing validated parameters. + + :Simple Example: + + .. code-block:: text + + argument_spec = { + 'name': {'type': 'str'}, + 'age': {'type': 'int'}, + } + + parameters = { + 'name': 'bo', + 'age': '42', + } + + validator = ArgumentSpecValidator(argument_spec) + result = validator.validate(parameters) + + if result.error_messages: + sys.exit("Validation failed: {0}".format(", ".join(result.error_messages)) + + valid_params = result.validated_parameters + """ + + result = ValidationResult(parameters) + + result._no_log_values.update(set_fallbacks(self.argument_spec, result._validated_parameters)) + + alias_warnings = [] + alias_deprecations = [] + try: + result._aliases.update(_handle_aliases(self.argument_spec, result._validated_parameters, alias_warnings, alias_deprecations)) + except (TypeError, ValueError) as e: + result.errors.append(AliasError(to_native(e))) + + legal_inputs = _get_legal_inputs(self.argument_spec, result._validated_parameters, result._aliases) + + for option, alias in alias_warnings: + result._warnings.append({'option': option, 'alias': alias}) + + for deprecation in alias_deprecations: + result._deprecations.append({ + 'msg': "Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], + 'version': deprecation.get('version'), + 'date': deprecation.get('date'), + 'collection_name': deprecation.get('collection_name'), + }) + + try: + result._no_log_values.update(_list_no_log_values(self.argument_spec, result._validated_parameters)) + except TypeError as te: + result.errors.append(NoLogError(to_native(te))) + + try: + result._deprecations.extend(_list_deprecations(self.argument_spec, result._validated_parameters)) + except TypeError as te: + result.errors.append(DeprecationError(to_native(te))) + + try: + result._unsupported_parameters.update( + _get_unsupported_parameters( + self.argument_spec, + result._validated_parameters, + legal_inputs, + store_supported=result._supported_parameters, + ) + ) + except TypeError as te: + result.errors.append(RequiredDefaultError(to_native(te))) + except ValueError as ve: + result.errors.append(AliasError(to_native(ve))) + + try: + check_mutually_exclusive(self._mutually_exclusive, result._validated_parameters) + except TypeError as te: + result.errors.append(MutuallyExclusiveError(to_native(te))) + + result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters, False)) + + try: + check_required_arguments(self.argument_spec, result._validated_parameters) + except TypeError as e: + result.errors.append(RequiredError(to_native(e))) + + _validate_argument_types(self.argument_spec, result._validated_parameters, errors=result.errors) + _validate_argument_values(self.argument_spec, result._validated_parameters, errors=result.errors) + + for check in _ADDITIONAL_CHECKS: + try: + check['func'](getattr(self, "_{attr}".format(attr=check['attr'])), result._validated_parameters) + except TypeError as te: + result.errors.append(check['err'](to_native(te))) + + result._no_log_values.update(_set_defaults(self.argument_spec, result._validated_parameters)) + + alias_deprecations = [] + _validate_sub_spec(self.argument_spec, result._validated_parameters, + errors=result.errors, + no_log_values=result._no_log_values, + unsupported_parameters=result._unsupported_parameters, + supported_parameters=result._supported_parameters, + alias_deprecations=alias_deprecations,) + for deprecation in alias_deprecations: + result._deprecations.append({ + 'msg': "Alias '%s' is deprecated. See the module docs for more information" % deprecation['name'], + 'version': deprecation.get('version'), + 'date': deprecation.get('date'), + 'collection_name': deprecation.get('collection_name'), + }) + + if result._unsupported_parameters: + flattened_names = [] + for item in result._unsupported_parameters: + if isinstance(item, tuple): + flattened_names.append(".".join(item)) + else: + flattened_names.append(item) + + unsupported_string = ", ".join(sorted(list(flattened_names))) + supported_params = supported_aliases = [] + if result._supported_parameters.get(item): + supported_params = sorted(list(result._supported_parameters[item][0])) + supported_aliases = sorted(list(result._supported_parameters[item][1])) + supported_string = ", ".join(supported_params) + if supported_aliases: + aliases_string = ", ".join(supported_aliases) + supported_string += " (%s)" % aliases_string + + msg = "{0}. Supported parameters include: {1}.".format(unsupported_string, supported_string) + result.errors.append(UnsupportedError(msg)) + + return result + + +class ModuleArgumentSpecValidator(ArgumentSpecValidator): + """Argument spec validation class used by :class:`AnsibleModule`. + + This is not meant to be used outside of :class:`AnsibleModule`. Use + :class:`ArgumentSpecValidator` instead. + """ + + def __init__(self, *args, **kwargs): + super(ModuleArgumentSpecValidator, self).__init__(*args, **kwargs) + + def validate(self, parameters): + result = super(ModuleArgumentSpecValidator, self).validate(parameters) + + for d in result._deprecations: + deprecate(d['msg'], + version=d.get('version'), date=d.get('date'), + collection_name=d.get('collection_name')) + + for w in result._warnings: + warn('Both option {option} and its alias {alias} are set.'.format(option=w['option'], alias=w['alias'])) + + return result diff --git a/lib/ansible/module_utils/common/collections.py b/lib/ansible/module_utils/common/collections.py new file mode 100644 index 0000000..fdb9108 --- /dev/null +++ b/lib/ansible/module_utils/common/collections.py @@ -0,0 +1,112 @@ +# Copyright: (c) 2018, Sviatoslav Sydorenko <ssydoren@redhat.com> +# Copyright: (c) 2018, Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) +"""Collection of low-level utility functions.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +from ansible.module_utils.six import binary_type, text_type +from ansible.module_utils.common._collections_compat import Hashable, Mapping, MutableMapping, Sequence + + +class ImmutableDict(Hashable, Mapping): + """Dictionary that cannot be updated""" + def __init__(self, *args, **kwargs): + self._store = dict(*args, **kwargs) + + def __getitem__(self, key): + return self._store[key] + + def __iter__(self): + return self._store.__iter__() + + def __len__(self): + return self._store.__len__() + + def __hash__(self): + return hash(frozenset(self.items())) + + def __eq__(self, other): + try: + if self.__hash__() == hash(other): + return True + except TypeError: + pass + + return False + + def __repr__(self): + return 'ImmutableDict({0})'.format(repr(self._store)) + + def union(self, overriding_mapping): + """ + Create an ImmutableDict as a combination of the original and overriding_mapping + + :arg overriding_mapping: A Mapping of replacement and additional items + :return: A copy of the ImmutableDict with key-value pairs from the overriding_mapping added + + If any of the keys in overriding_mapping are already present in the original ImmutableDict, + the overriding_mapping item replaces the one in the original ImmutableDict. + """ + return ImmutableDict(self._store, **overriding_mapping) + + def difference(self, subtractive_iterable): + """ + Create an ImmutableDict as a combination of the original minus keys in subtractive_iterable + + :arg subtractive_iterable: Any iterable containing keys that should not be present in the + new ImmutableDict + :return: A copy of the ImmutableDict with keys from the subtractive_iterable removed + """ + remove_keys = frozenset(subtractive_iterable) + keys = (k for k in self._store.keys() if k not in remove_keys) + return ImmutableDict((k, self._store[k]) for k in keys) + + +def is_string(seq): + """Identify whether the input has a string-like type (inclding bytes).""" + # AnsibleVaultEncryptedUnicode inherits from Sequence, but is expected to be a string like object + return isinstance(seq, (text_type, binary_type)) or getattr(seq, '__ENCRYPTED__', False) + + +def is_iterable(seq, include_strings=False): + """Identify whether the input is an iterable.""" + if not include_strings and is_string(seq): + return False + + try: + iter(seq) + return True + except TypeError: + return False + + +def is_sequence(seq, include_strings=False): + """Identify whether the input is a sequence. + + Strings and bytes are not sequences here, + unless ``include_string`` is ``True``. + + Non-indexable things are never of a sequence type. + """ + if not include_strings and is_string(seq): + return False + + return isinstance(seq, Sequence) + + +def count(seq): + """Returns a dictionary with the number of appearances of each element of the iterable. + + Resembles the collections.Counter class functionality. It is meant to be used when the + code is run on Python 2.6.* where collections.Counter is not available. It should be + deprecated and replaced when support for Python < 2.7 is dropped. + """ + if not is_iterable(seq): + raise Exception('Argument provided is not an iterable') + counters = dict() + for elem in seq: + counters[elem] = counters.get(elem, 0) + 1 + return counters diff --git a/lib/ansible/module_utils/common/dict_transformations.py b/lib/ansible/module_utils/common/dict_transformations.py new file mode 100644 index 0000000..ffd0645 --- /dev/null +++ b/lib/ansible/module_utils/common/dict_transformations.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +import re +from copy import deepcopy + +from ansible.module_utils.common._collections_compat import MutableMapping + + +def camel_dict_to_snake_dict(camel_dict, reversible=False, ignore_list=()): + """ + reversible allows two way conversion of a camelized dict + such that snake_dict_to_camel_dict(camel_dict_to_snake_dict(x)) == x + + This is achieved through mapping e.g. HTTPEndpoint to h_t_t_p_endpoint + where the default would be simply http_endpoint, which gets turned into + HttpEndpoint if recamelized. + + ignore_list is used to avoid converting a sub-tree of a dict. This is + particularly important for tags, where keys are case-sensitive. We convert + the 'Tags' key but nothing below. + """ + + def value_is_list(camel_list): + + checked_list = [] + for item in camel_list: + if isinstance(item, dict): + checked_list.append(camel_dict_to_snake_dict(item, reversible)) + elif isinstance(item, list): + checked_list.append(value_is_list(item)) + else: + checked_list.append(item) + + return checked_list + + snake_dict = {} + for k, v in camel_dict.items(): + if isinstance(v, dict) and k not in ignore_list: + snake_dict[_camel_to_snake(k, reversible=reversible)] = camel_dict_to_snake_dict(v, reversible) + elif isinstance(v, list) and k not in ignore_list: + snake_dict[_camel_to_snake(k, reversible=reversible)] = value_is_list(v) + else: + snake_dict[_camel_to_snake(k, reversible=reversible)] = v + + return snake_dict + + +def snake_dict_to_camel_dict(snake_dict, capitalize_first=False): + """ + Perhaps unexpectedly, snake_dict_to_camel_dict returns dromedaryCase + rather than true CamelCase. Passing capitalize_first=True returns + CamelCase. The default remains False as that was the original implementation + """ + + def camelize(complex_type, capitalize_first=False): + if complex_type is None: + return + new_type = type(complex_type)() + if isinstance(complex_type, dict): + for key in complex_type: + new_type[_snake_to_camel(key, capitalize_first)] = camelize(complex_type[key], capitalize_first) + elif isinstance(complex_type, list): + for i in range(len(complex_type)): + new_type.append(camelize(complex_type[i], capitalize_first)) + else: + return complex_type + return new_type + + return camelize(snake_dict, capitalize_first) + + +def _snake_to_camel(snake, capitalize_first=False): + if capitalize_first: + return ''.join(x.capitalize() or '_' for x in snake.split('_')) + else: + return snake.split('_')[0] + ''.join(x.capitalize() or '_' for x in snake.split('_')[1:]) + + +def _camel_to_snake(name, reversible=False): + + def prepend_underscore_and_lower(m): + return '_' + m.group(0).lower() + + if reversible: + upper_pattern = r'[A-Z]' + else: + # Cope with pluralized abbreviations such as TargetGroupARNs + # that would otherwise be rendered target_group_ar_ns + upper_pattern = r'[A-Z]{3,}s$' + + s1 = re.sub(upper_pattern, prepend_underscore_and_lower, name) + # Handle when there was nothing before the plural_pattern + if s1.startswith("_") and not name.startswith("_"): + s1 = s1[1:] + if reversible: + return s1 + + # Remainder of solution seems to be https://stackoverflow.com/a/1176023 + first_cap_pattern = r'(.)([A-Z][a-z]+)' + all_cap_pattern = r'([a-z0-9])([A-Z]+)' + s2 = re.sub(first_cap_pattern, r'\1_\2', s1) + return re.sub(all_cap_pattern, r'\1_\2', s2).lower() + + +def dict_merge(a, b): + '''recursively merges dicts. not just simple a['key'] = b['key'], if + both a and b have a key whose value is a dict then dict_merge is called + on both values and the result stored in the returned dictionary.''' + if not isinstance(b, dict): + return b + result = deepcopy(a) + for k, v in b.items(): + if k in result and isinstance(result[k], dict): + result[k] = dict_merge(result[k], v) + else: + result[k] = deepcopy(v) + return result + + +def recursive_diff(dict1, dict2): + """Recursively diff two dictionaries + + Raises ``TypeError`` for incorrect argument type. + + :arg dict1: Dictionary to compare against. + :arg dict2: Dictionary to compare with ``dict1``. + :return: Tuple of dictionaries of differences or ``None`` if there are no differences. + """ + + if not all((isinstance(item, MutableMapping) for item in (dict1, dict2))): + raise TypeError("Unable to diff 'dict1' %s and 'dict2' %s. " + "Both must be a dictionary." % (type(dict1), type(dict2))) + + left = dict((k, v) for (k, v) in dict1.items() if k not in dict2) + right = dict((k, v) for (k, v) in dict2.items() if k not in dict1) + for k in (set(dict1.keys()) & set(dict2.keys())): + if isinstance(dict1[k], dict) and isinstance(dict2[k], dict): + result = recursive_diff(dict1[k], dict2[k]) + if result: + left[k] = result[0] + right[k] = result[1] + elif dict1[k] != dict2[k]: + left[k] = dict1[k] + right[k] = dict2[k] + if left or right: + return left, right + return None diff --git a/lib/ansible/module_utils/common/file.py b/lib/ansible/module_utils/common/file.py new file mode 100644 index 0000000..1e83660 --- /dev/null +++ b/lib/ansible/module_utils/common/file.py @@ -0,0 +1,205 @@ +# Copyright (c) 2018, Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import errno +import os +import stat +import re +import pwd +import grp +import time +import shutil +import traceback +import fcntl +import sys + +from contextlib import contextmanager +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.six import b, binary_type +from ansible.module_utils.common.warnings import deprecate + +try: + import selinux + HAVE_SELINUX = True +except ImportError: + HAVE_SELINUX = False + + +FILE_ATTRIBUTES = { + 'A': 'noatime', + 'a': 'append', + 'c': 'compressed', + 'C': 'nocow', + 'd': 'nodump', + 'D': 'dirsync', + 'e': 'extents', + 'E': 'encrypted', + 'h': 'blocksize', + 'i': 'immutable', + 'I': 'indexed', + 'j': 'journalled', + 'N': 'inline', + 's': 'zero', + 'S': 'synchronous', + 't': 'notail', + 'T': 'blockroot', + 'u': 'undelete', + 'X': 'compressedraw', + 'Z': 'compresseddirty', +} + + +# Used for parsing symbolic file perms +MODE_OPERATOR_RE = re.compile(r'[+=-]') +USERS_RE = re.compile(r'[^ugo]') +PERMS_RE = re.compile(r'[^rwxXstugo]') + + +_PERM_BITS = 0o7777 # file mode permission bits +_EXEC_PERM_BITS = 0o0111 # execute permission bits +_DEFAULT_PERM = 0o0666 # default file permission bits + + +def is_executable(path): + # This function's signature needs to be repeated + # as the first line of its docstring. + # This method is reused by the basic module, + # the repetition helps the basic module's html documentation come out right. + # http://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_docstring_signature + '''is_executable(path) + + is the given path executable? + + :arg path: The path of the file to check. + + Limitations: + + * Does not account for FSACLs. + * Most times we really want to know "Can the current user execute this + file". This function does not tell us that, only if any execute bit is set. + ''' + # These are all bitfields so first bitwise-or all the permissions we're + # looking for, then bitwise-and with the file's mode to determine if any + # execute bits are set. + return ((stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) & os.stat(path)[stat.ST_MODE]) + + +def format_attributes(attributes): + attribute_list = [FILE_ATTRIBUTES.get(attr) for attr in attributes if attr in FILE_ATTRIBUTES] + return attribute_list + + +def get_flags_from_attributes(attributes): + flags = [key for key, attr in FILE_ATTRIBUTES.items() if attr in attributes] + return ''.join(flags) + + +def get_file_arg_spec(): + arg_spec = dict( + mode=dict(type='raw'), + owner=dict(), + group=dict(), + seuser=dict(), + serole=dict(), + selevel=dict(), + setype=dict(), + attributes=dict(aliases=['attr']), + ) + return arg_spec + + +class LockTimeout(Exception): + pass + + +class FileLock: + ''' + Currently FileLock is implemented via fcntl.flock on a lock file, however this + behaviour may change in the future. Avoid mixing lock types fcntl.flock, + fcntl.lockf and module_utils.common.file.FileLock as it will certainly cause + unwanted and/or unexpected behaviour + ''' + def __init__(self): + deprecate("FileLock is not reliable and has never been used in core for that reason. There is no current alternative that works across POSIX targets", + version='2.16') + self.lockfd = None + + @contextmanager + def lock_file(self, path, tmpdir, lock_timeout=None): + ''' + Context for lock acquisition + ''' + try: + self.set_lock(path, tmpdir, lock_timeout) + yield + finally: + self.unlock() + + def set_lock(self, path, tmpdir, lock_timeout=None): + ''' + Create a lock file based on path with flock to prevent other processes + using given path. + Please note that currently file locking only works when it's executed by + the same user, I.E single user scenarios + + :kw path: Path (file) to lock + :kw tmpdir: Path where to place the temporary .lock file + :kw lock_timeout: + Wait n seconds for lock acquisition, fail if timeout is reached. + 0 = Do not wait, fail if lock cannot be acquired immediately, + Default is None, wait indefinitely until lock is released. + :returns: True + ''' + lock_path = os.path.join(tmpdir, 'ansible-{0}.lock'.format(os.path.basename(path))) + l_wait = 0.1 + r_exception = IOError + if sys.version_info[0] == 3: + r_exception = BlockingIOError + + self.lockfd = open(lock_path, 'w') + + if lock_timeout <= 0: + fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB) + os.chmod(lock_path, stat.S_IWRITE | stat.S_IREAD) + return True + + if lock_timeout: + e_secs = 0 + while e_secs < lock_timeout: + try: + fcntl.flock(self.lockfd, fcntl.LOCK_EX | fcntl.LOCK_NB) + os.chmod(lock_path, stat.S_IWRITE | stat.S_IREAD) + return True + except r_exception: + time.sleep(l_wait) + e_secs += l_wait + continue + + self.lockfd.close() + raise LockTimeout('{0} sec'.format(lock_timeout)) + + fcntl.flock(self.lockfd, fcntl.LOCK_EX) + os.chmod(lock_path, stat.S_IWRITE | stat.S_IREAD) + + return True + + def unlock(self): + ''' + Make sure lock file is available for everyone and Unlock the file descriptor + locked by set_lock + + :returns: True + ''' + if not self.lockfd: + return True + + try: + fcntl.flock(self.lockfd, fcntl.LOCK_UN) + self.lockfd.close() + except ValueError: # file wasn't opened, let context manager fail gracefully + pass + + return True diff --git a/lib/ansible/module_utils/common/json.py b/lib/ansible/module_utils/common/json.py new file mode 100644 index 0000000..727083c --- /dev/null +++ b/lib/ansible/module_utils/common/json.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + +import datetime + +from ansible.module_utils._text import to_text +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.common.collections import is_sequence + + +def _is_unsafe(value): + return getattr(value, '__UNSAFE__', False) and not getattr(value, '__ENCRYPTED__', False) + + +def _is_vault(value): + return getattr(value, '__ENCRYPTED__', False) + + +def _preprocess_unsafe_encode(value): + """Recursively preprocess a data structure converting instances of ``AnsibleUnsafe`` + into their JSON dict representations + + Used in ``AnsibleJSONEncoder.iterencode`` + """ + if _is_unsafe(value): + value = {'__ansible_unsafe': to_text(value, errors='surrogate_or_strict', nonstring='strict')} + elif is_sequence(value): + value = [_preprocess_unsafe_encode(v) for v in value] + elif isinstance(value, Mapping): + value = dict((k, _preprocess_unsafe_encode(v)) for k, v in value.items()) + + return value + + +def json_dump(structure): + return json.dumps(structure, cls=AnsibleJSONEncoder, sort_keys=True, indent=4) + + +class AnsibleJSONEncoder(json.JSONEncoder): + ''' + Simple encoder class to deal with JSON encoding of Ansible internal types + ''' + + def __init__(self, preprocess_unsafe=False, vault_to_text=False, **kwargs): + self._preprocess_unsafe = preprocess_unsafe + self._vault_to_text = vault_to_text + super(AnsibleJSONEncoder, self).__init__(**kwargs) + + # NOTE: ALWAYS inform AWS/Tower when new items get added as they consume them downstream via a callback + def default(self, o): + if getattr(o, '__ENCRYPTED__', False): + # vault object + if self._vault_to_text: + value = to_text(o, errors='surrogate_or_strict') + else: + value = {'__ansible_vault': to_text(o._ciphertext, errors='surrogate_or_strict', nonstring='strict')} + elif getattr(o, '__UNSAFE__', False): + # unsafe object, this will never be triggered, see ``AnsibleJSONEncoder.iterencode`` + value = {'__ansible_unsafe': to_text(o, errors='surrogate_or_strict', nonstring='strict')} + elif isinstance(o, Mapping): + # hostvars and other objects + value = dict(o) + elif isinstance(o, (datetime.date, datetime.datetime)): + # date object + value = o.isoformat() + else: + # use default encoder + value = super(AnsibleJSONEncoder, self).default(o) + return value + + def iterencode(self, o, **kwargs): + """Custom iterencode, primarily design to handle encoding ``AnsibleUnsafe`` + as the ``AnsibleUnsafe`` subclasses inherit from string types and + ``json.JSONEncoder`` does not support custom encoders for string types + """ + if self._preprocess_unsafe: + o = _preprocess_unsafe_encode(o) + + return super(AnsibleJSONEncoder, self).iterencode(o, **kwargs) diff --git a/lib/ansible/module_utils/common/locale.py b/lib/ansible/module_utils/common/locale.py new file mode 100644 index 0000000..a6068c8 --- /dev/null +++ b/lib/ansible/module_utils/common/locale.py @@ -0,0 +1,61 @@ +# Copyright (c), Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils._text import to_native + + +def get_best_parsable_locale(module, preferences=None, raise_on_locale=False): + ''' + Attempts to return the best possible locale for parsing output in English + useful for scraping output with i18n tools. When this raises an exception + and the caller wants to continue, it should use the 'C' locale. + + :param module: an AnsibleModule instance + :param preferences: A list of preferred locales, in order of preference + :param raise_on_locale: boolean that determines if we raise exception or not + due to locale CLI issues + :returns: The first matched preferred locale or 'C' which is the default + ''' + + found = 'C' # default posix, its ascii but always there + try: + locale = module.get_bin_path("locale") + if not locale: + # not using required=true as that forces fail_json + raise RuntimeWarning("Could not find 'locale' tool") + + available = [] + + if preferences is None: + # new POSIX standard or English cause those are messages core team expects + # yes, the last 2 are the same but some systems are weird + preferences = ['C.utf8', 'C.UTF-8', 'en_US.utf8', 'en_US.UTF-8', 'C', 'POSIX'] + + rc, out, err = module.run_command([locale, '-a']) + + if rc == 0: + if out: + available = out.strip().splitlines() + else: + raise RuntimeWarning("No output from locale, rc=%s: %s" % (rc, to_native(err))) + else: + raise RuntimeWarning("Unable to get locale information, rc=%s: %s" % (rc, to_native(err))) + + if available: + for pref in preferences: + if pref in available: + found = pref + break + + except RuntimeWarning as e: + if raise_on_locale: + raise + else: + module.debug('Failed to get locale information: %s' % to_native(e)) + + module.debug('Matched preferred locale to: %s' % found) + + return found diff --git a/lib/ansible/module_utils/common/network.py b/lib/ansible/module_utils/common/network.py new file mode 100644 index 0000000..c3874f8 --- /dev/null +++ b/lib/ansible/module_utils/common/network.py @@ -0,0 +1,161 @@ +# Copyright (c) 2016 Red Hat Inc +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +# General networking tools that may be used by all modules + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +from struct import pack +from socket import inet_ntoa + +from ansible.module_utils.six.moves import zip + + +VALID_MASKS = [2**8 - 2**i for i in range(0, 9)] + + +def is_netmask(val): + parts = str(val).split('.') + if not len(parts) == 4: + return False + for part in parts: + try: + if int(part) not in VALID_MASKS: + raise ValueError + except ValueError: + return False + return True + + +def is_masklen(val): + try: + return 0 <= int(val) <= 32 + except ValueError: + return False + + +def to_netmask(val): + """ converts a masklen to a netmask """ + if not is_masklen(val): + raise ValueError('invalid value for masklen') + + bits = 0 + for i in range(32 - int(val), 32): + bits |= (1 << i) + + return inet_ntoa(pack('>I', bits)) + + +def to_masklen(val): + """ converts a netmask to a masklen """ + if not is_netmask(val): + raise ValueError('invalid value for netmask: %s' % val) + + bits = list() + for x in val.split('.'): + octet = bin(int(x)).count('1') + bits.append(octet) + + return sum(bits) + + +def to_subnet(addr, mask, dotted_notation=False): + """ coverts an addr / mask pair to a subnet in cidr notation """ + try: + if not is_masklen(mask): + raise ValueError + cidr = int(mask) + mask = to_netmask(mask) + except ValueError: + cidr = to_masklen(mask) + + addr = addr.split('.') + mask = mask.split('.') + + network = list() + for s_addr, s_mask in zip(addr, mask): + network.append(str(int(s_addr) & int(s_mask))) + + if dotted_notation: + return '%s %s' % ('.'.join(network), to_netmask(cidr)) + return '%s/%s' % ('.'.join(network), cidr) + + +def to_ipv6_subnet(addr): + """ IPv6 addresses are eight groupings. The first four groupings (64 bits) comprise the subnet address. """ + + # https://tools.ietf.org/rfc/rfc2374.txt + + # Split by :: to identify omitted zeros + ipv6_prefix = addr.split('::')[0] + + # Get the first four groups, or as many as are found + :: + found_groups = [] + for group in ipv6_prefix.split(':'): + found_groups.append(group) + if len(found_groups) == 4: + break + if len(found_groups) < 4: + found_groups.append('::') + + # Concatenate network address parts + network_addr = '' + for group in found_groups: + if group != '::': + network_addr += str(group) + network_addr += str(':') + + # Ensure network address ends with :: + if not network_addr.endswith('::'): + network_addr += str(':') + return network_addr + + +def to_ipv6_network(addr): + """ IPv6 addresses are eight groupings. The first three groupings (48 bits) comprise the network address. """ + + # Split by :: to identify omitted zeros + ipv6_prefix = addr.split('::')[0] + + # Get the first three groups, or as many as are found + :: + found_groups = [] + for group in ipv6_prefix.split(':'): + found_groups.append(group) + if len(found_groups) == 3: + break + if len(found_groups) < 3: + found_groups.append('::') + + # Concatenate network address parts + network_addr = '' + for group in found_groups: + if group != '::': + network_addr += str(group) + network_addr += str(':') + + # Ensure network address ends with :: + if not network_addr.endswith('::'): + network_addr += str(':') + return network_addr + + +def to_bits(val): + """ converts a netmask to bits """ + bits = '' + for octet in val.split('.'): + bits += bin(int(octet))[2:].zfill(8) + return bits + + +def is_mac(mac_address): + """ + Validate MAC address for given string + Args: + mac_address: string to validate as MAC address + + Returns: (Boolean) True if string is valid MAC address, otherwise False + """ + mac_addr_regex = re.compile('[0-9a-f]{2}([-:])[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$') + return bool(mac_addr_regex.match(mac_address.lower())) diff --git a/lib/ansible/module_utils/common/parameters.py b/lib/ansible/module_utils/common/parameters.py new file mode 100644 index 0000000..059ca0a --- /dev/null +++ b/lib/ansible/module_utils/common/parameters.py @@ -0,0 +1,940 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import datetime +import os + +from collections import deque +from itertools import chain + +from ansible.module_utils.common.collections import is_iterable +from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text +from ansible.module_utils.common.text.formatters import lenient_lowercase +from ansible.module_utils.common.warnings import warn +from ansible.module_utils.errors import ( + AliasError, + AnsibleFallbackNotFound, + AnsibleValidationErrorMultiple, + ArgumentTypeError, + ArgumentValueError, + ElementError, + MutuallyExclusiveError, + NoLogError, + RequiredByError, + RequiredError, + RequiredIfError, + RequiredOneOfError, + RequiredTogetherError, + SubParameterTypeError, +) +from ansible.module_utils.parsing.convert_bool import BOOLEANS_FALSE, BOOLEANS_TRUE + +from ansible.module_utils.common._collections_compat import ( + KeysView, + Set, + Sequence, + Mapping, + MutableMapping, + MutableSet, + MutableSequence, +) + +from ansible.module_utils.six import ( + binary_type, + integer_types, + string_types, + text_type, + PY2, + PY3, +) + +from ansible.module_utils.common.validation import ( + check_mutually_exclusive, + check_required_arguments, + check_required_together, + check_required_one_of, + check_required_if, + check_required_by, + check_type_bits, + check_type_bool, + check_type_bytes, + check_type_dict, + check_type_float, + check_type_int, + check_type_jsonarg, + check_type_list, + check_type_path, + check_type_raw, + check_type_str, +) + +# Python2 & 3 way to get NoneType +NoneType = type(None) + +_ADDITIONAL_CHECKS = ( + {'func': check_required_together, 'attr': 'required_together', 'err': RequiredTogetherError}, + {'func': check_required_one_of, 'attr': 'required_one_of', 'err': RequiredOneOfError}, + {'func': check_required_if, 'attr': 'required_if', 'err': RequiredIfError}, + {'func': check_required_by, 'attr': 'required_by', 'err': RequiredByError}, +) + +# if adding boolean attribute, also add to PASS_BOOL +# some of this dupes defaults from controller config +PASS_VARS = { + 'check_mode': ('check_mode', False), + 'debug': ('_debug', False), + 'diff': ('_diff', False), + 'keep_remote_files': ('_keep_remote_files', False), + 'module_name': ('_name', None), + 'no_log': ('no_log', False), + 'remote_tmp': ('_remote_tmp', None), + 'selinux_special_fs': ('_selinux_special_fs', ['fuse', 'nfs', 'vboxsf', 'ramfs', '9p', 'vfat']), + 'shell_executable': ('_shell', '/bin/sh'), + 'socket': ('_socket_path', None), + 'string_conversion_action': ('_string_conversion_action', 'warn'), + 'syslog_facility': ('_syslog_facility', 'INFO'), + 'tmpdir': ('_tmpdir', None), + 'verbosity': ('_verbosity', 0), + 'version': ('ansible_version', '0.0'), +} + +PASS_BOOLS = ('check_mode', 'debug', 'diff', 'keep_remote_files', 'no_log') + +DEFAULT_TYPE_VALIDATORS = { + 'str': check_type_str, + 'list': check_type_list, + 'dict': check_type_dict, + 'bool': check_type_bool, + 'int': check_type_int, + 'float': check_type_float, + 'path': check_type_path, + 'raw': check_type_raw, + 'jsonarg': check_type_jsonarg, + 'json': check_type_jsonarg, + 'bytes': check_type_bytes, + 'bits': check_type_bits, +} + + +def _get_type_validator(wanted): + """Returns the callable used to validate a wanted type and the type name. + + :arg wanted: String or callable. If a string, get the corresponding + validation function from DEFAULT_TYPE_VALIDATORS. If callable, + get the name of the custom callable and return that for the type_checker. + + :returns: Tuple of callable function or None, and a string that is the name + of the wanted type. + """ + + # Use one of our builtin validators. + if not callable(wanted): + if wanted is None: + # Default type for parameters + wanted = 'str' + + type_checker = DEFAULT_TYPE_VALIDATORS.get(wanted) + + # Use the custom callable for validation. + else: + type_checker = wanted + wanted = getattr(wanted, '__name__', to_native(type(wanted))) + + return type_checker, wanted + + +def _get_legal_inputs(argument_spec, parameters, aliases=None): + if aliases is None: + aliases = _handle_aliases(argument_spec, parameters) + + return list(aliases.keys()) + list(argument_spec.keys()) + + +def _get_unsupported_parameters(argument_spec, parameters, legal_inputs=None, options_context=None, store_supported=None): + """Check keys in parameters against those provided in legal_inputs + to ensure they contain legal values. If legal_inputs are not supplied, + they will be generated using the argument_spec. + + :arg argument_spec: Dictionary of parameters, their type, and valid values. + :arg parameters: Dictionary of parameters. + :arg legal_inputs: List of valid key names property names. Overrides values + in argument_spec. + :arg options_context: List of parent keys for tracking the context of where + a parameter is defined. + + :returns: Set of unsupported parameters. Empty set if no unsupported parameters + are found. + """ + + if legal_inputs is None: + legal_inputs = _get_legal_inputs(argument_spec, parameters) + + unsupported_parameters = set() + for k in parameters.keys(): + if k not in legal_inputs: + context = k + if options_context: + context = tuple(options_context + [k]) + + unsupported_parameters.add(context) + + if store_supported is not None: + supported_aliases = _handle_aliases(argument_spec, parameters) + supported_params = [] + for option in legal_inputs: + if option in supported_aliases: + continue + supported_params.append(option) + + store_supported.update({context: (supported_params, supported_aliases)}) + + return unsupported_parameters + + +def _handle_aliases(argument_spec, parameters, alias_warnings=None, alias_deprecations=None): + """Process aliases from an argument_spec including warnings and deprecations. + + Modify ``parameters`` by adding a new key for each alias with the supplied + value from ``parameters``. + + If a list is provided to the alias_warnings parameter, it will be filled with tuples + (option, alias) in every case where both an option and its alias are specified. + + If a list is provided to alias_deprecations, it will be populated with dictionaries, + each containing deprecation information for each alias found in argument_spec. + + :param argument_spec: Dictionary of parameters, their type, and valid values. + :type argument_spec: dict + + :param parameters: Dictionary of parameters. + :type parameters: dict + + :param alias_warnings: + :type alias_warnings: list + + :param alias_deprecations: + :type alias_deprecations: list + """ + + aliases_results = {} # alias:canon + + for (k, v) in argument_spec.items(): + aliases = v.get('aliases', None) + default = v.get('default', None) + required = v.get('required', False) + + if alias_deprecations is not None: + for alias in argument_spec[k].get('deprecated_aliases', []): + if alias.get('name') in parameters: + alias_deprecations.append(alias) + + if default is not None and required: + # not alias specific but this is a good place to check this + raise ValueError("internal error: required and default are mutually exclusive for %s" % k) + + if aliases is None: + continue + + if not is_iterable(aliases) or isinstance(aliases, (binary_type, text_type)): + raise TypeError('internal error: aliases must be a list or tuple') + + for alias in aliases: + aliases_results[alias] = k + if alias in parameters: + if k in parameters and alias_warnings is not None: + alias_warnings.append((k, alias)) + parameters[k] = parameters[alias] + + return aliases_results + + +def _list_deprecations(argument_spec, parameters, prefix=''): + """Return a list of deprecations + + :arg argument_spec: An argument spec dictionary + :arg parameters: Dictionary of parameters + + :returns: List of dictionaries containing a message and version in which + the deprecated parameter will be removed, or an empty list. + + :Example return: + + .. code-block:: python + + [ + { + 'msg': "Param 'deptest' is deprecated. See the module docs for more information", + 'version': '2.9' + } + ] + """ + + deprecations = [] + for arg_name, arg_opts in argument_spec.items(): + if arg_name in parameters: + if prefix: + sub_prefix = '%s["%s"]' % (prefix, arg_name) + else: + sub_prefix = arg_name + if arg_opts.get('removed_at_date') is not None: + deprecations.append({ + 'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix, + 'date': arg_opts.get('removed_at_date'), + 'collection_name': arg_opts.get('removed_from_collection'), + }) + elif arg_opts.get('removed_in_version') is not None: + deprecations.append({ + 'msg': "Param '%s' is deprecated. See the module docs for more information" % sub_prefix, + 'version': arg_opts.get('removed_in_version'), + 'collection_name': arg_opts.get('removed_from_collection'), + }) + # Check sub-argument spec + sub_argument_spec = arg_opts.get('options') + if sub_argument_spec is not None: + sub_arguments = parameters[arg_name] + if isinstance(sub_arguments, Mapping): + sub_arguments = [sub_arguments] + if isinstance(sub_arguments, list): + for sub_params in sub_arguments: + if isinstance(sub_params, Mapping): + deprecations.extend(_list_deprecations(sub_argument_spec, sub_params, prefix=sub_prefix)) + + return deprecations + + +def _list_no_log_values(argument_spec, params): + """Return set of no log values + + :arg argument_spec: An argument spec dictionary + :arg params: Dictionary of all parameters + + :returns: :class:`set` of strings that should be hidden from output: + """ + + no_log_values = set() + for arg_name, arg_opts in argument_spec.items(): + if arg_opts.get('no_log', False): + # Find the value for the no_log'd param + no_log_object = params.get(arg_name, None) + + if no_log_object: + try: + no_log_values.update(_return_datastructure_name(no_log_object)) + except TypeError as e: + raise TypeError('Failed to convert "%s": %s' % (arg_name, to_native(e))) + + # Get no_log values from suboptions + sub_argument_spec = arg_opts.get('options') + if sub_argument_spec is not None: + wanted_type = arg_opts.get('type') + sub_parameters = params.get(arg_name) + + if sub_parameters is not None: + if wanted_type == 'dict' or (wanted_type == 'list' and arg_opts.get('elements', '') == 'dict'): + # Sub parameters can be a dict or list of dicts. Ensure parameters are always a list. + if not isinstance(sub_parameters, list): + sub_parameters = [sub_parameters] + + for sub_param in sub_parameters: + # Validate dict fields in case they came in as strings + + if isinstance(sub_param, string_types): + sub_param = check_type_dict(sub_param) + + if not isinstance(sub_param, Mapping): + raise TypeError("Value '{1}' in the sub parameter field '{0}' must by a {2}, " + "not '{1.__class__.__name__}'".format(arg_name, sub_param, wanted_type)) + + no_log_values.update(_list_no_log_values(sub_argument_spec, sub_param)) + + return no_log_values + + +def _return_datastructure_name(obj): + """ Return native stringified values from datastructures. + + For use with removing sensitive values pre-jsonification.""" + if isinstance(obj, (text_type, binary_type)): + if obj: + yield to_native(obj, errors='surrogate_or_strict') + return + elif isinstance(obj, Mapping): + for element in obj.items(): + for subelement in _return_datastructure_name(element[1]): + yield subelement + elif is_iterable(obj): + for element in obj: + for subelement in _return_datastructure_name(element): + yield subelement + elif obj is None or isinstance(obj, bool): + # This must come before int because bools are also ints + return + elif isinstance(obj, tuple(list(integer_types) + [float])): + yield to_native(obj, nonstring='simplerepr') + else: + raise TypeError('Unknown parameter type: %s' % (type(obj))) + + +def _remove_values_conditions(value, no_log_strings, deferred_removals): + """ + Helper function for :meth:`remove_values`. + + :arg value: The value to check for strings that need to be stripped + :arg no_log_strings: set of strings which must be stripped out of any values + :arg deferred_removals: List which holds information about nested + containers that have to be iterated for removals. It is passed into + this function so that more entries can be added to it if value is + a container type. The format of each entry is a 2-tuple where the first + element is the ``value`` parameter and the second value is a new + container to copy the elements of ``value`` into once iterated. + + :returns: if ``value`` is a scalar, returns ``value`` with two exceptions: + + 1. :class:`~datetime.datetime` objects which are changed into a string representation. + 2. objects which are in ``no_log_strings`` are replaced with a placeholder + so that no sensitive data is leaked. + + If ``value`` is a container type, returns a new empty container. + + ``deferred_removals`` is added to as a side-effect of this function. + + .. warning:: It is up to the caller to make sure the order in which value + is passed in is correct. For instance, higher level containers need + to be passed in before lower level containers. For example, given + ``{'level1': {'level2': 'level3': [True]} }`` first pass in the + dictionary for ``level1``, then the dict for ``level2``, and finally + the list for ``level3``. + """ + if isinstance(value, (text_type, binary_type)): + # Need native str type + native_str_value = value + if isinstance(value, text_type): + value_is_text = True + if PY2: + native_str_value = to_bytes(value, errors='surrogate_or_strict') + elif isinstance(value, binary_type): + value_is_text = False + if PY3: + native_str_value = to_text(value, errors='surrogate_or_strict') + + if native_str_value in no_log_strings: + return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' + for omit_me in no_log_strings: + native_str_value = native_str_value.replace(omit_me, '*' * 8) + + if value_is_text and isinstance(native_str_value, binary_type): + value = to_text(native_str_value, encoding='utf-8', errors='surrogate_then_replace') + elif not value_is_text and isinstance(native_str_value, text_type): + value = to_bytes(native_str_value, encoding='utf-8', errors='surrogate_then_replace') + else: + value = native_str_value + + elif isinstance(value, Sequence): + if isinstance(value, MutableSequence): + new_value = type(value)() + else: + new_value = [] # Need a mutable value + deferred_removals.append((value, new_value)) + value = new_value + + elif isinstance(value, Set): + if isinstance(value, MutableSet): + new_value = type(value)() + else: + new_value = set() # Need a mutable value + deferred_removals.append((value, new_value)) + value = new_value + + elif isinstance(value, Mapping): + if isinstance(value, MutableMapping): + new_value = type(value)() + else: + new_value = {} # Need a mutable value + deferred_removals.append((value, new_value)) + value = new_value + + elif isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): + stringy_value = to_native(value, encoding='utf-8', errors='surrogate_or_strict') + if stringy_value in no_log_strings: + return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' + for omit_me in no_log_strings: + if omit_me in stringy_value: + return 'VALUE_SPECIFIED_IN_NO_LOG_PARAMETER' + + elif isinstance(value, (datetime.datetime, datetime.date)): + value = value.isoformat() + else: + raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) + + return value + + +def _set_defaults(argument_spec, parameters, set_default=True): + """Set default values for parameters when no value is supplied. + + Modifies parameters directly. + + :arg argument_spec: Argument spec + :type argument_spec: dict + + :arg parameters: Parameters to evaluate + :type parameters: dict + + :kwarg set_default: Whether or not to set the default values + :type set_default: bool + + :returns: Set of strings that should not be logged. + :rtype: set + """ + + no_log_values = set() + for param, value in argument_spec.items(): + + # TODO: Change the default value from None to Sentinel to differentiate between + # user supplied None and a default value set by this function. + default = value.get('default', None) + + # This prevents setting defaults on required items on the 1st run, + # otherwise will set things without a default to None on the 2nd. + if param not in parameters and (default is not None or set_default): + # Make sure any default value for no_log fields are masked. + if value.get('no_log', False) and default: + no_log_values.add(default) + + parameters[param] = default + + return no_log_values + + +def _sanitize_keys_conditions(value, no_log_strings, ignore_keys, deferred_removals): + """ Helper method to :func:`sanitize_keys` to build ``deferred_removals`` and avoid deep recursion. """ + if isinstance(value, (text_type, binary_type)): + return value + + if isinstance(value, Sequence): + if isinstance(value, MutableSequence): + new_value = type(value)() + else: + new_value = [] # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value + + if isinstance(value, Set): + if isinstance(value, MutableSet): + new_value = type(value)() + else: + new_value = set() # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value + + if isinstance(value, Mapping): + if isinstance(value, MutableMapping): + new_value = type(value)() + else: + new_value = {} # Need a mutable value + deferred_removals.append((value, new_value)) + return new_value + + if isinstance(value, tuple(chain(integer_types, (float, bool, NoneType)))): + return value + + if isinstance(value, (datetime.datetime, datetime.date)): + return value + + raise TypeError('Value of unknown type: %s, %s' % (type(value), value)) + + +def _validate_elements(wanted_type, parameter, values, options_context=None, errors=None): + + if errors is None: + errors = AnsibleValidationErrorMultiple() + + type_checker, wanted_element_type = _get_type_validator(wanted_type) + validated_parameters = [] + # Get param name for strings so we can later display this value in a useful error message if needed + # Only pass 'kwargs' to our checkers and ignore custom callable checkers + kwargs = {} + if wanted_element_type == 'str' and isinstance(wanted_type, string_types): + if isinstance(parameter, string_types): + kwargs['param'] = parameter + elif isinstance(parameter, dict): + kwargs['param'] = list(parameter.keys())[0] + + for value in values: + try: + validated_parameters.append(type_checker(value, **kwargs)) + except (TypeError, ValueError) as e: + msg = "Elements value for option '%s'" % parameter + if options_context: + msg += " found in '%s'" % " -> ".join(options_context) + msg += " is of type %s and we were unable to convert to %s: %s" % (type(value), wanted_element_type, to_native(e)) + errors.append(ElementError(msg)) + return validated_parameters + + +def _validate_argument_types(argument_spec, parameters, prefix='', options_context=None, errors=None): + """Validate that parameter types match the type in the argument spec. + + Determine the appropriate type checker function and run each + parameter value through that function. All error messages from type checker + functions are returned. If any parameter fails to validate, it will not + be in the returned parameters. + + :arg argument_spec: Argument spec + :type argument_spec: dict + + :arg parameters: Parameters + :type parameters: dict + + :kwarg prefix: Name of the parent key that contains the spec. Used in the error message + :type prefix: str + + :kwarg options_context: List of contexts? + :type options_context: list + + :returns: Two item tuple containing validated and coerced parameters + and a list of any errors that were encountered. + :rtype: tuple + + """ + + if errors is None: + errors = AnsibleValidationErrorMultiple() + + for param, spec in argument_spec.items(): + if param not in parameters: + continue + + value = parameters[param] + if value is None: + continue + + wanted_type = spec.get('type') + type_checker, wanted_name = _get_type_validator(wanted_type) + # Get param name for strings so we can later display this value in a useful error message if needed + # Only pass 'kwargs' to our checkers and ignore custom callable checkers + kwargs = {} + if wanted_name == 'str' and isinstance(wanted_type, string_types): + kwargs['param'] = list(parameters.keys())[0] + + # Get the name of the parent key if this is a nested option + if prefix: + kwargs['prefix'] = prefix + + try: + parameters[param] = type_checker(value, **kwargs) + elements_wanted_type = spec.get('elements', None) + if elements_wanted_type: + elements = parameters[param] + if wanted_type != 'list' or not isinstance(elements, list): + msg = "Invalid type %s for option '%s'" % (wanted_name, elements) + if options_context: + msg += " found in '%s'." % " -> ".join(options_context) + msg += ", elements value check is supported only with 'list' type" + errors.append(ArgumentTypeError(msg)) + parameters[param] = _validate_elements(elements_wanted_type, param, elements, options_context, errors) + + except (TypeError, ValueError) as e: + msg = "argument '%s' is of type %s" % (param, type(value)) + if options_context: + msg += " found in '%s'." % " -> ".join(options_context) + msg += " and we were unable to convert to %s: %s" % (wanted_name, to_native(e)) + errors.append(ArgumentTypeError(msg)) + + +def _validate_argument_values(argument_spec, parameters, options_context=None, errors=None): + """Ensure all arguments have the requested values, and there are no stray arguments""" + + if errors is None: + errors = AnsibleValidationErrorMultiple() + + for param, spec in argument_spec.items(): + choices = spec.get('choices') + if choices is None: + continue + + if isinstance(choices, (frozenset, KeysView, Sequence)) and not isinstance(choices, (binary_type, text_type)): + if param in parameters: + # Allow one or more when type='list' param with choices + if isinstance(parameters[param], list): + diff_list = [item for item in parameters[param] if item not in choices] + if diff_list: + choices_str = ", ".join([to_native(c) for c in choices]) + diff_str = ", ".join(diff_list) + msg = "value of %s must be one or more of: %s. Got no match for: %s" % (param, choices_str, diff_str) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + errors.append(ArgumentValueError(msg)) + elif parameters[param] not in choices: + # PyYaml converts certain strings to bools. If we can unambiguously convert back, do so before checking + # the value. If we can't figure this out, module author is responsible. + if parameters[param] == 'False': + overlap = BOOLEANS_FALSE.intersection(choices) + if len(overlap) == 1: + # Extract from a set + (parameters[param],) = overlap + + if parameters[param] == 'True': + overlap = BOOLEANS_TRUE.intersection(choices) + if len(overlap) == 1: + (parameters[param],) = overlap + + if parameters[param] not in choices: + choices_str = ", ".join([to_native(c) for c in choices]) + msg = "value of %s must be one of: %s, got: %s" % (param, choices_str, parameters[param]) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + errors.append(ArgumentValueError(msg)) + else: + msg = "internal error: choices for argument %s are not iterable: %s" % (param, choices) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + errors.append(ArgumentTypeError(msg)) + + +def _validate_sub_spec( + argument_spec, + parameters, + prefix="", + options_context=None, + errors=None, + no_log_values=None, + unsupported_parameters=None, + supported_parameters=None, + alias_deprecations=None, +): + """Validate sub argument spec. + + This function is recursive. + """ + + if options_context is None: + options_context = [] + + if errors is None: + errors = AnsibleValidationErrorMultiple() + + if no_log_values is None: + no_log_values = set() + + if unsupported_parameters is None: + unsupported_parameters = set() + if supported_parameters is None: + supported_parameters = dict() + + for param, value in argument_spec.items(): + wanted = value.get('type') + if wanted == 'dict' or (wanted == 'list' and value.get('elements', '') == 'dict'): + sub_spec = value.get('options') + if value.get('apply_defaults', False): + if sub_spec is not None: + if parameters.get(param) is None: + parameters[param] = {} + else: + continue + elif sub_spec is None or param not in parameters or parameters[param] is None: + continue + + # Keep track of context for warning messages + options_context.append(param) + + # Make sure we can iterate over the elements + if not isinstance(parameters[param], Sequence) or isinstance(parameters[param], string_types): + elements = [parameters[param]] + else: + elements = parameters[param] + + for idx, sub_parameters in enumerate(elements): + no_log_values.update(set_fallbacks(sub_spec, sub_parameters)) + + if not isinstance(sub_parameters, dict): + errors.append(SubParameterTypeError("value of '%s' must be of type dict or list of dicts" % param)) + continue + + # Set prefix for warning messages + new_prefix = prefix + param + if wanted == 'list': + new_prefix += '[%d]' % idx + new_prefix += '.' + + alias_warnings = [] + alias_deprecations_sub = [] + try: + options_aliases = _handle_aliases(sub_spec, sub_parameters, alias_warnings, alias_deprecations_sub) + except (TypeError, ValueError) as e: + options_aliases = {} + errors.append(AliasError(to_native(e))) + + for option, alias in alias_warnings: + warn('Both option %s%s and its alias %s%s are set.' % (new_prefix, option, new_prefix, alias)) + + if alias_deprecations is not None: + for deprecation in alias_deprecations_sub: + alias_deprecations.append({ + 'name': '%s%s' % (new_prefix, deprecation['name']), + 'version': deprecation.get('version'), + 'date': deprecation.get('date'), + 'collection_name': deprecation.get('collection_name'), + }) + + try: + no_log_values.update(_list_no_log_values(sub_spec, sub_parameters)) + except TypeError as te: + errors.append(NoLogError(to_native(te))) + + legal_inputs = _get_legal_inputs(sub_spec, sub_parameters, options_aliases) + unsupported_parameters.update( + _get_unsupported_parameters( + sub_spec, + sub_parameters, + legal_inputs, + options_context, + store_supported=supported_parameters, + ) + ) + + try: + check_mutually_exclusive(value.get('mutually_exclusive'), sub_parameters, options_context) + except TypeError as e: + errors.append(MutuallyExclusiveError(to_native(e))) + + no_log_values.update(_set_defaults(sub_spec, sub_parameters, False)) + + try: + check_required_arguments(sub_spec, sub_parameters, options_context) + except TypeError as e: + errors.append(RequiredError(to_native(e))) + + _validate_argument_types(sub_spec, sub_parameters, new_prefix, options_context, errors=errors) + _validate_argument_values(sub_spec, sub_parameters, options_context, errors=errors) + + for check in _ADDITIONAL_CHECKS: + try: + check['func'](value.get(check['attr']), sub_parameters, options_context) + except TypeError as e: + errors.append(check['err'](to_native(e))) + + no_log_values.update(_set_defaults(sub_spec, sub_parameters)) + + # Handle nested specs + _validate_sub_spec( + sub_spec, sub_parameters, new_prefix, options_context, errors, no_log_values, + unsupported_parameters, supported_parameters, alias_deprecations) + + options_context.pop() + + +def env_fallback(*args, **kwargs): + """Load value from environment variable""" + + for arg in args: + if arg in os.environ: + return os.environ[arg] + raise AnsibleFallbackNotFound + + +def set_fallbacks(argument_spec, parameters): + no_log_values = set() + for param, value in argument_spec.items(): + fallback = value.get('fallback', (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if param not in parameters and fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + fallback_value = fallback_strategy(*fallback_args, **fallback_kwargs) + except AnsibleFallbackNotFound: + continue + else: + if value.get('no_log', False) and fallback_value: + no_log_values.add(fallback_value) + parameters[param] = fallback_value + + return no_log_values + + +def sanitize_keys(obj, no_log_strings, ignore_keys=frozenset()): + """Sanitize the keys in a container object by removing ``no_log`` values from key names. + + This is a companion function to the :func:`remove_values` function. Similar to that function, + we make use of ``deferred_removals`` to avoid hitting maximum recursion depth in cases of + large data structures. + + :arg obj: The container object to sanitize. Non-container objects are returned unmodified. + :arg no_log_strings: A set of string values we do not want logged. + :kwarg ignore_keys: A set of string values of keys to not sanitize. + + :returns: An object with sanitized keys. + """ + + deferred_removals = deque() + + no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] + new_value = _sanitize_keys_conditions(obj, no_log_strings, ignore_keys, deferred_removals) + + while deferred_removals: + old_data, new_data = deferred_removals.popleft() + + if isinstance(new_data, Mapping): + for old_key, old_elem in old_data.items(): + if old_key in ignore_keys or old_key.startswith('_ansible'): + new_data[old_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) + else: + # Sanitize the old key. We take advantage of the sanitizing code in + # _remove_values_conditions() rather than recreating it here. + new_key = _remove_values_conditions(old_key, no_log_strings, None) + new_data[new_key] = _sanitize_keys_conditions(old_elem, no_log_strings, ignore_keys, deferred_removals) + else: + for elem in old_data: + new_elem = _sanitize_keys_conditions(elem, no_log_strings, ignore_keys, deferred_removals) + if isinstance(new_data, MutableSequence): + new_data.append(new_elem) + elif isinstance(new_data, MutableSet): + new_data.add(new_elem) + else: + raise TypeError('Unknown container type encountered when removing private values from keys') + + return new_value + + +def remove_values(value, no_log_strings): + """Remove strings in ``no_log_strings`` from value. + + If value is a container type, then remove a lot more. + + Use of ``deferred_removals`` exists, rather than a pure recursive solution, + because of the potential to hit the maximum recursion depth when dealing with + large amounts of data (see `issue #24560 <https://github.com/ansible/ansible/issues/24560>`_). + """ + + deferred_removals = deque() + + no_log_strings = [to_native(s, errors='surrogate_or_strict') for s in no_log_strings] + new_value = _remove_values_conditions(value, no_log_strings, deferred_removals) + + while deferred_removals: + old_data, new_data = deferred_removals.popleft() + if isinstance(new_data, Mapping): + for old_key, old_elem in old_data.items(): + new_elem = _remove_values_conditions(old_elem, no_log_strings, deferred_removals) + new_data[old_key] = new_elem + else: + for elem in old_data: + new_elem = _remove_values_conditions(elem, no_log_strings, deferred_removals) + if isinstance(new_data, MutableSequence): + new_data.append(new_elem) + elif isinstance(new_data, MutableSet): + new_data.add(new_elem) + else: + raise TypeError('Unknown container type encountered when removing private values from output') + + return new_value diff --git a/lib/ansible/module_utils/common/process.py b/lib/ansible/module_utils/common/process.py new file mode 100644 index 0000000..97761a4 --- /dev/null +++ b/lib/ansible/module_utils/common/process.py @@ -0,0 +1,46 @@ +# Copyright (c) 2018, Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from ansible.module_utils.common.file import is_executable + + +def get_bin_path(arg, opt_dirs=None, required=None): + ''' + Find system executable in PATH. Raises ValueError if executable is not found. + Optional arguments: + - required: [Deprecated] Prior to 2.10, if executable is not found and required is true it raises an Exception. + In 2.10 and later, an Exception is always raised. This parameter will be removed in 2.14. + - opt_dirs: optional list of directories to search in addition to PATH + In addition to PATH and opt_dirs, this function also looks through /sbin, /usr/sbin and /usr/local/sbin. A lot of + modules, especially for gathering facts, depend on this behaviour. + If found return full path, otherwise raise ValueError. + ''' + opt_dirs = [] if opt_dirs is None else opt_dirs + + sbin_paths = ['/sbin', '/usr/sbin', '/usr/local/sbin'] + paths = [] + for d in opt_dirs: + if d is not None and os.path.exists(d): + paths.append(d) + paths += os.environ.get('PATH', '').split(os.pathsep) + bin_path = None + # mangle PATH to include /sbin dirs + for p in sbin_paths: + if p not in paths and os.path.exists(p): + paths.append(p) + for d in paths: + if not d: + continue + path = os.path.join(d, arg) + if os.path.exists(path) and not os.path.isdir(path) and is_executable(path): + bin_path = path + break + if bin_path is None: + raise ValueError('Failed to find required executable "%s" in paths: %s' % (arg, os.pathsep.join(paths))) + + return bin_path diff --git a/lib/ansible/module_utils/common/respawn.py b/lib/ansible/module_utils/common/respawn.py new file mode 100644 index 0000000..3bc526a --- /dev/null +++ b/lib/ansible/module_utils/common/respawn.py @@ -0,0 +1,98 @@ +# Copyright: (c) 2021, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import subprocess +import sys + +from ansible.module_utils.common.text.converters import to_bytes, to_native + + +def has_respawned(): + return hasattr(sys.modules['__main__'], '_respawned') + + +def respawn_module(interpreter_path): + """ + Respawn the currently-running Ansible Python module under the specified Python interpreter. + + Ansible modules that require libraries that are typically available only under well-known interpreters + (eg, ``yum``, ``apt``, ``dnf``) can use bespoke logic to determine the libraries they need are not + available, then call `respawn_module` to re-execute the current module under a different interpreter + and exit the current process when the new subprocess has completed. The respawned process inherits only + stdout/stderr from the current process. + + Only a single respawn is allowed. ``respawn_module`` will fail on nested respawns. Modules are encouraged + to call `has_respawned()` to defensively guide behavior before calling ``respawn_module``, and to ensure + that the target interpreter exists, as ``respawn_module`` will not fail gracefully. + + :arg interpreter_path: path to a Python interpreter to respawn the current module + """ + + if has_respawned(): + raise Exception('module has already been respawned') + + # FUTURE: we need a safe way to log that a respawn has occurred for forensic/debug purposes + payload = _create_payload() + stdin_read, stdin_write = os.pipe() + os.write(stdin_write, to_bytes(payload)) + os.close(stdin_write) + rc = subprocess.call([interpreter_path, '--'], stdin=stdin_read) + sys.exit(rc) # pylint: disable=ansible-bad-function + + +def probe_interpreters_for_module(interpreter_paths, module_name): + """ + Probes a supplied list of Python interpreters, returning the first one capable of + importing the named module. This is useful when attempting to locate a "system + Python" where OS-packaged utility modules are located. + + :arg interpreter_paths: iterable of paths to Python interpreters. The paths will be probed + in order, and the first path that exists and can successfully import the named module will + be returned (or ``None`` if probing fails for all supplied paths). + :arg module_name: fully-qualified Python module name to probe for (eg, ``selinux``) + """ + for interpreter_path in interpreter_paths: + if not os.path.exists(interpreter_path): + continue + try: + rc = subprocess.call([interpreter_path, '-c', 'import {0}'.format(module_name)]) + if rc == 0: + return interpreter_path + except Exception: + continue + + return None + + +def _create_payload(): + from ansible.module_utils import basic + smuggled_args = getattr(basic, '_ANSIBLE_ARGS') + if not smuggled_args: + raise Exception('unable to access ansible.module_utils.basic._ANSIBLE_ARGS (not launched by AnsiballZ?)') + module_fqn = sys.modules['__main__']._module_fqn + modlib_path = sys.modules['__main__']._modlib_path + respawn_code_template = ''' +import runpy +import sys + +module_fqn = '{module_fqn}' +modlib_path = '{modlib_path}' +smuggled_args = b"""{smuggled_args}""".strip() + + +if __name__ == '__main__': + sys.path.insert(0, modlib_path) + + from ansible.module_utils import basic + basic._ANSIBLE_ARGS = smuggled_args + + runpy.run_module(module_fqn, init_globals=dict(_respawned=True), run_name='__main__', alter_sys=True) + ''' + + respawn_code = respawn_code_template.format(module_fqn=module_fqn, modlib_path=modlib_path, smuggled_args=to_native(smuggled_args)) + + return respawn_code diff --git a/lib/ansible/module_utils/common/sys_info.py b/lib/ansible/module_utils/common/sys_info.py new file mode 100644 index 0000000..206b36c --- /dev/null +++ b/lib/ansible/module_utils/common/sys_info.py @@ -0,0 +1,157 @@ +# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013 +# Copyright (c), Toshio Kuratomi <tkuratomi@ansible.com> 2016 +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import platform + +from ansible.module_utils import distro +from ansible.module_utils.common._utils import get_all_subclasses + + +__all__ = ('get_distribution', 'get_distribution_version', 'get_platform_subclass') + + +def get_distribution(): + ''' + Return the name of the distribution the module is running on. + + :rtype: NativeString or None + :returns: Name of the distribution the module is running on + + This function attempts to determine what distribution the code is running + on and return a string representing that value. If the platform is Linux + and the distribution cannot be determined, it returns ``OtherLinux``. + ''' + distribution = distro.id().capitalize() + + if platform.system() == 'Linux': + if distribution == 'Amzn': + distribution = 'Amazon' + elif distribution == 'Rhel': + distribution = 'Redhat' + elif not distribution: + distribution = 'OtherLinux' + + return distribution + + +def get_distribution_version(): + ''' + Get the version of the distribution the code is running on + + :rtype: NativeString or None + :returns: A string representation of the version of the distribution. If it + cannot determine the version, it returns an empty string. If this is not run on + a Linux machine it returns None. + ''' + version = None + + needs_best_version = frozenset(( + u'centos', + u'debian', + )) + + version = distro.version() + distro_id = distro.id() + + if version is not None: + if distro_id in needs_best_version: + version_best = distro.version(best=True) + + # CentoOS maintainers believe only the major version is appropriate + # but Ansible users desire minor version information, e.g., 7.5. + # https://github.com/ansible/ansible/issues/50141#issuecomment-449452781 + if distro_id == u'centos': + version = u'.'.join(version_best.split(u'.')[:2]) + + # Debian does not include minor version in /etc/os-release. + # Bug report filed upstream requesting this be added to /etc/os-release + # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=931197 + if distro_id == u'debian': + version = version_best + + else: + version = u'' + + return version + + +def get_distribution_codename(): + ''' + Return the code name for this Linux Distribution + + :rtype: NativeString or None + :returns: A string representation of the distribution's codename or None if not a Linux distro + ''' + codename = None + if platform.system() == 'Linux': + # Until this gets merged and we update our bundled copy of distro: + # https://github.com/nir0s/distro/pull/230 + # Fixes Fedora 28+ not having a code name and Ubuntu Xenial Xerus needing to be "xenial" + os_release_info = distro.os_release_info() + codename = os_release_info.get('version_codename') + + if codename is None: + codename = os_release_info.get('ubuntu_codename') + + if codename is None and distro.id() == 'ubuntu': + lsb_release_info = distro.lsb_release_info() + codename = lsb_release_info.get('codename') + + if codename is None: + codename = distro.codename() + if codename == u'': + codename = None + + return codename + + +def get_platform_subclass(cls): + ''' + Finds a subclass implementing desired functionality on the platform the code is running on + + :arg cls: Class to find an appropriate subclass for + :returns: A class that implements the functionality on this platform + + Some Ansible modules have different implementations depending on the platform they run on. This + function is used to select between the various implementations and choose one. You can look at + the implementation of the Ansible :ref:`User module<user_module>` module for an example of how to use this. + + This function replaces ``basic.load_platform_subclass()``. When you port code, you need to + change the callers to be explicit about instantiating the class. For instance, code in the + Ansible User module changed from:: + + .. code-block:: python + + # Old + class User: + def __new__(cls, args, kwargs): + return load_platform_subclass(User, args, kwargs) + + # New + class User: + def __new__(cls, *args, **kwargs): + new_cls = get_platform_subclass(User) + return super(cls, new_cls).__new__(new_cls) + ''' + this_platform = platform.system() + distribution = get_distribution() + + subclass = None + + # get the most specific superclass for this platform + if distribution is not None: + for sc in get_all_subclasses(cls): + if sc.distribution is not None and sc.distribution == distribution and sc.platform == this_platform: + subclass = sc + if subclass is None: + for sc in get_all_subclasses(cls): + if sc.platform == this_platform and sc.distribution is None: + subclass = sc + if subclass is None: + subclass = cls + + return subclass diff --git a/lib/ansible/module_utils/common/text/__init__.py b/lib/ansible/module_utils/common/text/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/common/text/__init__.py diff --git a/lib/ansible/module_utils/common/text/converters.py b/lib/ansible/module_utils/common/text/converters.py new file mode 100644 index 0000000..5b25df4 --- /dev/null +++ b/lib/ansible/module_utils/common/text/converters.py @@ -0,0 +1,322 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# (c) 2016 Toshio Kuratomi <tkuratomi@ansible.com> +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import codecs +import datetime +import json + +from ansible.module_utils.common._collections_compat import Set +from ansible.module_utils.six import ( + PY3, + binary_type, + iteritems, + text_type, +) + +try: + codecs.lookup_error('surrogateescape') + HAS_SURROGATEESCAPE = True +except LookupError: + HAS_SURROGATEESCAPE = False + + +_COMPOSED_ERROR_HANDLERS = frozenset((None, 'surrogate_or_replace', + 'surrogate_or_strict', + 'surrogate_then_replace')) + + +def to_bytes(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): + """Make sure that a string is a byte string + + :arg obj: An object to make sure is a byte string. In most cases this + will be either a text string or a byte string. However, with + ``nonstring='simplerepr'``, this can be used as a traceback-free + version of ``str(obj)``. + :kwarg encoding: The encoding to use to transform from a text string to + a byte string. Defaults to using 'utf-8'. + :kwarg errors: The error handler to use if the text string is not + encodable using the specified encoding. Any valid `codecs error + handler <https://docs.python.org/3/library/codecs.html#codec-base-classes>`_ + may be specified. There are three additional error strategies + specifically aimed at helping people to port code. The first two are: + + :surrogate_or_strict: Will use ``surrogateescape`` if it is a valid + handler, otherwise it will use ``strict`` + :surrogate_or_replace: Will use ``surrogateescape`` if it is a valid + handler, otherwise it will use ``replace``. + + Because ``surrogateescape`` was added in Python3 this usually means that + Python3 will use ``surrogateescape`` and Python2 will use the fallback + error handler. Note that the code checks for ``surrogateescape`` when the + module is imported. If you have a backport of ``surrogateescape`` for + Python2, be sure to register the error handler prior to importing this + module. + + The last error handler is: + + :surrogate_then_replace: Will use ``surrogateescape`` if it is a valid + handler. If encoding with ``surrogateescape`` would traceback, + surrogates are first replaced with a replacement characters + and then the string is encoded using ``replace`` (which replaces + the rest of the nonencodable bytes). If ``surrogateescape`` is + not present it will simply use ``replace``. (Added in Ansible 2.3) + This strategy is designed to never traceback when it attempts + to encode a string. + + The default until Ansible-2.2 was ``surrogate_or_replace`` + From Ansible-2.3 onwards, the default is ``surrogate_then_replace``. + + :kwarg nonstring: The strategy to use if a nonstring is specified in + ``obj``. Default is 'simplerepr'. Valid values are: + + :simplerepr: The default. This takes the ``str`` of the object and + then returns the bytes version of that string. + :empty: Return an empty byte string + :passthru: Return the object passed in + :strict: Raise a :exc:`TypeError` + + :returns: Typically this returns a byte string. If a nonstring object is + passed in this may be a different type depending on the strategy + specified by nonstring. This will never return a text string. + + .. note:: If passed a byte string, this function does not check that the + string is valid in the specified encoding. If it's important that the + byte string is in the specified encoding do:: + + encoded_string = to_bytes(to_text(input_string, 'latin-1'), 'utf-8') + + .. version_changed:: 2.3 + + Added the ``surrogate_then_replace`` error handler and made it the default error handler. + """ + if isinstance(obj, binary_type): + return obj + + # We're given a text string + # If it has surrogates, we know because it will decode + original_errors = errors + if errors in _COMPOSED_ERROR_HANDLERS: + if HAS_SURROGATEESCAPE: + errors = 'surrogateescape' + elif errors == 'surrogate_or_strict': + errors = 'strict' + else: + errors = 'replace' + + if isinstance(obj, text_type): + try: + # Try this first as it's the fastest + return obj.encode(encoding, errors) + except UnicodeEncodeError: + if original_errors in (None, 'surrogate_then_replace'): + # We should only reach this if encoding was non-utf8 original_errors was + # surrogate_then_escape and errors was surrogateescape + + # Slow but works + return_string = obj.encode('utf-8', 'surrogateescape') + return_string = return_string.decode('utf-8', 'replace') + return return_string.encode(encoding, 'replace') + raise + + # Note: We do these last even though we have to call to_bytes again on the + # value because we're optimizing the common case + if nonstring == 'simplerepr': + try: + value = str(obj) + except UnicodeError: + try: + value = repr(obj) + except UnicodeError: + # Giving up + return to_bytes('') + elif nonstring == 'passthru': + return obj + elif nonstring == 'empty': + # python2.4 doesn't have b'' + return to_bytes('') + elif nonstring == 'strict': + raise TypeError('obj must be a string type') + else: + raise TypeError('Invalid value %s for to_bytes\' nonstring parameter' % nonstring) + + return to_bytes(value, encoding, errors) + + +def to_text(obj, encoding='utf-8', errors=None, nonstring='simplerepr'): + """Make sure that a string is a text string + + :arg obj: An object to make sure is a text string. In most cases this + will be either a text string or a byte string. However, with + ``nonstring='simplerepr'``, this can be used as a traceback-free + version of ``str(obj)``. + :kwarg encoding: The encoding to use to transform from a byte string to + a text string. Defaults to using 'utf-8'. + :kwarg errors: The error handler to use if the byte string is not + decodable using the specified encoding. Any valid `codecs error + handler <https://docs.python.org/3/library/codecs.html#codec-base-classes>`_ + may be specified. We support three additional error strategies + specifically aimed at helping people to port code: + + :surrogate_or_strict: Will use surrogateescape if it is a valid + handler, otherwise it will use strict + :surrogate_or_replace: Will use surrogateescape if it is a valid + handler, otherwise it will use replace. + :surrogate_then_replace: Does the same as surrogate_or_replace but + `was added for symmetry with the error handlers in + :func:`ansible.module_utils._text.to_bytes` (Added in Ansible 2.3) + + Because surrogateescape was added in Python3 this usually means that + Python3 will use `surrogateescape` and Python2 will use the fallback + error handler. Note that the code checks for surrogateescape when the + module is imported. If you have a backport of `surrogateescape` for + python2, be sure to register the error handler prior to importing this + module. + + The default until Ansible-2.2 was `surrogate_or_replace` + In Ansible-2.3 this defaults to `surrogate_then_replace` for symmetry + with :func:`ansible.module_utils._text.to_bytes` . + :kwarg nonstring: The strategy to use if a nonstring is specified in + ``obj``. Default is 'simplerepr'. Valid values are: + + :simplerepr: The default. This takes the ``str`` of the object and + then returns the text version of that string. + :empty: Return an empty text string + :passthru: Return the object passed in + :strict: Raise a :exc:`TypeError` + + :returns: Typically this returns a text string. If a nonstring object is + passed in this may be a different type depending on the strategy + specified by nonstring. This will never return a byte string. + From Ansible-2.3 onwards, the default is `surrogate_then_replace`. + + .. version_changed:: 2.3 + + Added the surrogate_then_replace error handler and made it the default error handler. + """ + if isinstance(obj, text_type): + return obj + + if errors in _COMPOSED_ERROR_HANDLERS: + if HAS_SURROGATEESCAPE: + errors = 'surrogateescape' + elif errors == 'surrogate_or_strict': + errors = 'strict' + else: + errors = 'replace' + + if isinstance(obj, binary_type): + # Note: We don't need special handling for surrogate_then_replace + # because all bytes will either be made into surrogates or are valid + # to decode. + return obj.decode(encoding, errors) + + # Note: We do these last even though we have to call to_text again on the + # value because we're optimizing the common case + if nonstring == 'simplerepr': + try: + value = str(obj) + except UnicodeError: + try: + value = repr(obj) + except UnicodeError: + # Giving up + return u'' + elif nonstring == 'passthru': + return obj + elif nonstring == 'empty': + return u'' + elif nonstring == 'strict': + raise TypeError('obj must be a string type') + else: + raise TypeError('Invalid value %s for to_text\'s nonstring parameter' % nonstring) + + return to_text(value, encoding, errors) + + +#: :py:func:`to_native` +#: Transform a variable into the native str type for the python version +#: +#: On Python2, this is an alias for +#: :func:`~ansible.module_utils.to_bytes`. On Python3 it is an alias for +#: :func:`~ansible.module_utils.to_text`. It makes it easier to +#: transform a variable into the native str type for the python version +#: the code is running on. Use this when constructing the message to +#: send to exceptions or when dealing with an API that needs to take +#: a native string. Example:: +#: +#: try: +#: 1//0 +#: except ZeroDivisionError as e: +#: raise MyException('Encountered and error: %s' % to_native(e)) +if PY3: + to_native = to_text +else: + to_native = to_bytes + + +def _json_encode_fallback(obj): + if isinstance(obj, Set): + return list(obj) + elif isinstance(obj, datetime.datetime): + return obj.isoformat() + raise TypeError("Cannot json serialize %s" % to_native(obj)) + + +def jsonify(data, **kwargs): + for encoding in ("utf-8", "latin-1"): + try: + return json.dumps(data, encoding=encoding, default=_json_encode_fallback, **kwargs) + # Old systems using old simplejson module does not support encoding keyword. + except TypeError: + try: + new_data = container_to_text(data, encoding=encoding) + except UnicodeDecodeError: + continue + return json.dumps(new_data, default=_json_encode_fallback, **kwargs) + except UnicodeDecodeError: + continue + raise UnicodeError('Invalid unicode encoding encountered') + + +def container_to_bytes(d, encoding='utf-8', errors='surrogate_or_strict'): + ''' Recursively convert dict keys and values to byte str + + Specialized for json return because this only handles, lists, tuples, + and dict container types (the containers that the json module returns) + ''' + + if isinstance(d, text_type): + return to_bytes(d, encoding=encoding, errors=errors) + elif isinstance(d, dict): + return dict(container_to_bytes(o, encoding, errors) for o in iteritems(d)) + elif isinstance(d, list): + return [container_to_bytes(o, encoding, errors) for o in d] + elif isinstance(d, tuple): + return tuple(container_to_bytes(o, encoding, errors) for o in d) + else: + return d + + +def container_to_text(d, encoding='utf-8', errors='surrogate_or_strict'): + """Recursively convert dict keys and values to text str + + Specialized for json return because this only handles, lists, tuples, + and dict container types (the containers that the json module returns) + """ + + if isinstance(d, binary_type): + # Warning, can traceback + return to_text(d, encoding=encoding, errors=errors) + elif isinstance(d, dict): + return dict(container_to_text(o, encoding, errors) for o in iteritems(d)) + elif isinstance(d, list): + return [container_to_text(o, encoding, errors) for o in d] + elif isinstance(d, tuple): + return tuple(container_to_text(o, encoding, errors) for o in d) + else: + return d diff --git a/lib/ansible/module_utils/common/text/formatters.py b/lib/ansible/module_utils/common/text/formatters.py new file mode 100644 index 0000000..94ca5a3 --- /dev/null +++ b/lib/ansible/module_utils/common/text/formatters.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import re + +from ansible.module_utils.six import iteritems + +SIZE_RANGES = { + 'Y': 1 << 80, + 'Z': 1 << 70, + 'E': 1 << 60, + 'P': 1 << 50, + 'T': 1 << 40, + 'G': 1 << 30, + 'M': 1 << 20, + 'K': 1 << 10, + 'B': 1, +} + + +def lenient_lowercase(lst): + """Lowercase elements of a list. + + If an element is not a string, pass it through untouched. + """ + lowered = [] + for value in lst: + try: + lowered.append(value.lower()) + except AttributeError: + lowered.append(value) + return lowered + + +def human_to_bytes(number, default_unit=None, isbits=False): + """Convert number in string format into bytes (ex: '2K' => 2048) or using unit argument. + + example: human_to_bytes('10M') <=> human_to_bytes(10, 'M'). + + When isbits is False (default), converts bytes from a human-readable format to integer. + example: human_to_bytes('1MB') returns 1048576 (int). + The function expects 'B' (uppercase) as a byte identifier passed + as a part of 'name' param string or 'unit', e.g. 'MB'/'KB'/etc. + (except when the identifier is single 'b', it is perceived as a byte identifier too). + if 'Mb'/'Kb'/... is passed, the ValueError will be rased. + + When isbits is True, converts bits from a human-readable format to integer. + example: human_to_bytes('1Mb', isbits=True) returns 8388608 (int) - + string bits representation was passed and return as a number or bits. + The function expects 'b' (lowercase) as a bit identifier, e.g. 'Mb'/'Kb'/etc. + if 'MB'/'KB'/... is passed, the ValueError will be rased. + """ + m = re.search(r'^\s*(\d*\.?\d*)\s*([A-Za-z]+)?', str(number), flags=re.IGNORECASE) + if m is None: + raise ValueError("human_to_bytes() can't interpret following string: %s" % str(number)) + try: + num = float(m.group(1)) + except Exception: + raise ValueError("human_to_bytes() can't interpret following number: %s (original input string: %s)" % (m.group(1), number)) + + unit = m.group(2) + if unit is None: + unit = default_unit + + if unit is None: + ''' No unit given, returning raw number ''' + return int(round(num)) + range_key = unit[0].upper() + try: + limit = SIZE_RANGES[range_key] + except Exception: + raise ValueError("human_to_bytes() failed to convert %s (unit = %s). The suffix must be one of %s" % (number, unit, ", ".join(SIZE_RANGES.keys()))) + + # default value + unit_class = 'B' + unit_class_name = 'byte' + # handling bits case + if isbits: + unit_class = 'b' + unit_class_name = 'bit' + # check unit value if more than one character (KB, MB) + if len(unit) > 1: + expect_message = 'expect %s%s or %s' % (range_key, unit_class, range_key) + if range_key == 'B': + expect_message = 'expect %s or %s' % (unit_class, unit_class_name) + + if unit_class_name in unit.lower(): + pass + elif unit[1] != unit_class: + raise ValueError("human_to_bytes() failed to convert %s. Value is not a valid string (%s)" % (number, expect_message)) + + return int(round(num * limit)) + + +def bytes_to_human(size, isbits=False, unit=None): + base = 'Bytes' + if isbits: + base = 'bits' + suffix = '' + + for suffix, limit in sorted(iteritems(SIZE_RANGES), key=lambda item: -item[1]): + if (unit is None and size >= limit) or unit is not None and unit.upper() == suffix[0]: + break + + if limit != 1: + suffix += base[0] + else: + suffix = base + + return '%.2f %s' % (size / limit, suffix) diff --git a/lib/ansible/module_utils/common/validation.py b/lib/ansible/module_utils/common/validation.py new file mode 100644 index 0000000..5a4cebb --- /dev/null +++ b/lib/ansible/module_utils/common/validation.py @@ -0,0 +1,578 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import os +import re + +from ast import literal_eval +from ansible.module_utils._text import to_native +from ansible.module_utils.common._json_compat import json +from ansible.module_utils.common.collections import is_iterable +from ansible.module_utils.common.text.converters import jsonify +from ansible.module_utils.common.text.formatters import human_to_bytes +from ansible.module_utils.parsing.convert_bool import boolean +from ansible.module_utils.six import ( + binary_type, + integer_types, + string_types, + text_type, +) + + +def count_terms(terms, parameters): + """Count the number of occurrences of a key in a given dictionary + + :arg terms: String or iterable of values to check + :arg parameters: Dictionary of parameters + + :returns: An integer that is the number of occurrences of the terms values + in the provided dictionary. + """ + + if not is_iterable(terms): + terms = [terms] + + return len(set(terms).intersection(parameters)) + + +def safe_eval(value, locals=None, include_exceptions=False): + # do not allow method calls to modules + if not isinstance(value, string_types): + # already templated to a datavaluestructure, perhaps? + if include_exceptions: + return (value, None) + return value + if re.search(r'\w\.\w+\(', value): + if include_exceptions: + return (value, None) + return value + # do not allow imports + if re.search(r'import \w+', value): + if include_exceptions: + return (value, None) + return value + try: + result = literal_eval(value) + if include_exceptions: + return (result, None) + else: + return result + except Exception as e: + if include_exceptions: + return (value, e) + return value + + +def check_mutually_exclusive(terms, parameters, options_context=None): + """Check mutually exclusive terms against argument parameters + + Accepts a single list or list of lists that are groups of terms that should be + mutually exclusive with one another + + :arg terms: List of mutually exclusive parameters + :arg parameters: Dictionary of parameters + :kwarg options_context: List of strings of parent key names if ``terms`` are + in a sub spec. + + :returns: Empty list or raises :class:`TypeError` if the check fails. + """ + + results = [] + if terms is None: + return results + + for check in terms: + count = count_terms(check, parameters) + if count > 1: + results.append(check) + + if results: + full_list = ['|'.join(check) for check in results] + msg = "parameters are mutually exclusive: %s" % ', '.join(full_list) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + raise TypeError(to_native(msg)) + + return results + + +def check_required_one_of(terms, parameters, options_context=None): + """Check each list of terms to ensure at least one exists in the given module + parameters + + Accepts a list of lists or tuples + + :arg terms: List of lists of terms to check. For each list of terms, at + least one is required. + :arg parameters: Dictionary of parameters + :kwarg options_context: List of strings of parent key names if ``terms`` are + in a sub spec. + + :returns: Empty list or raises :class:`TypeError` if the check fails. + """ + + results = [] + if terms is None: + return results + + for term in terms: + count = count_terms(term, parameters) + if count == 0: + results.append(term) + + if results: + for term in results: + msg = "one of the following is required: %s" % ', '.join(term) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + raise TypeError(to_native(msg)) + + return results + + +def check_required_together(terms, parameters, options_context=None): + """Check each list of terms to ensure every parameter in each list exists + in the given parameters. + + Accepts a list of lists or tuples. + + :arg terms: List of lists of terms to check. Each list should include + parameters that are all required when at least one is specified + in the parameters. + :arg parameters: Dictionary of parameters + :kwarg options_context: List of strings of parent key names if ``terms`` are + in a sub spec. + + :returns: Empty list or raises :class:`TypeError` if the check fails. + """ + + results = [] + if terms is None: + return results + + for term in terms: + counts = [count_terms(field, parameters) for field in term] + non_zero = [c for c in counts if c > 0] + if len(non_zero) > 0: + if 0 in counts: + results.append(term) + if results: + for term in results: + msg = "parameters are required together: %s" % ', '.join(term) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + raise TypeError(to_native(msg)) + + return results + + +def check_required_by(requirements, parameters, options_context=None): + """For each key in requirements, check the corresponding list to see if they + exist in parameters. + + Accepts a single string or list of values for each key. + + :arg requirements: Dictionary of requirements + :arg parameters: Dictionary of parameters + :kwarg options_context: List of strings of parent key names if ``requirements`` are + in a sub spec. + + :returns: Empty dictionary or raises :class:`TypeError` if the + """ + + result = {} + if requirements is None: + return result + + for (key, value) in requirements.items(): + if key not in parameters or parameters[key] is None: + continue + result[key] = [] + # Support strings (single-item lists) + if isinstance(value, string_types): + value = [value] + for required in value: + if required not in parameters or parameters[required] is None: + result[key].append(required) + + if result: + for key, missing in result.items(): + if len(missing) > 0: + msg = "missing parameter(s) required by '%s': %s" % (key, ', '.join(missing)) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + raise TypeError(to_native(msg)) + + return result + + +def check_required_arguments(argument_spec, parameters, options_context=None): + """Check all parameters in argument_spec and return a list of parameters + that are required but not present in parameters. + + Raises :class:`TypeError` if the check fails + + :arg argument_spec: Argument spec dictionary containing all parameters + and their specification + :arg parameters: Dictionary of parameters + :kwarg options_context: List of strings of parent key names if ``argument_spec`` are + in a sub spec. + + :returns: Empty list or raises :class:`TypeError` if the check fails. + """ + + missing = [] + if argument_spec is None: + return missing + + for (k, v) in argument_spec.items(): + required = v.get('required', False) + if required and k not in parameters: + missing.append(k) + + if missing: + msg = "missing required arguments: %s" % ", ".join(sorted(missing)) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + raise TypeError(to_native(msg)) + + return missing + + +def check_required_if(requirements, parameters, options_context=None): + """Check parameters that are conditionally required + + Raises :class:`TypeError` if the check fails + + :arg requirements: List of lists specifying a parameter, value, parameters + required when the given parameter is the specified value, and optionally + a boolean indicating any or all parameters are required. + + :Example: + + .. code-block:: python + + required_if=[ + ['state', 'present', ('path',), True], + ['someint', 99, ('bool_param', 'string_param')], + ] + + :arg parameters: Dictionary of parameters + + :returns: Empty list or raises :class:`TypeError` if the check fails. + The results attribute of the exception contains a list of dictionaries. + Each dictionary is the result of evaluating each item in requirements. + Each return dictionary contains the following keys: + + :key missing: List of parameters that are required but missing + :key requires: 'any' or 'all' + :key parameter: Parameter name that has the requirement + :key value: Original value of the parameter + :key requirements: Original required parameters + + :Example: + + .. code-block:: python + + [ + { + 'parameter': 'someint', + 'value': 99 + 'requirements': ('bool_param', 'string_param'), + 'missing': ['string_param'], + 'requires': 'all', + } + ] + + :kwarg options_context: List of strings of parent key names if ``requirements`` are + in a sub spec. + """ + results = [] + if requirements is None: + return results + + for req in requirements: + missing = {} + missing['missing'] = [] + max_missing_count = 0 + is_one_of = False + if len(req) == 4: + key, val, requirements, is_one_of = req + else: + key, val, requirements = req + + # is_one_of is True at least one requirement should be + # present, else all requirements should be present. + if is_one_of: + max_missing_count = len(requirements) + missing['requires'] = 'any' + else: + missing['requires'] = 'all' + + if key in parameters and parameters[key] == val: + for check in requirements: + count = count_terms(check, parameters) + if count == 0: + missing['missing'].append(check) + if len(missing['missing']) and len(missing['missing']) >= max_missing_count: + missing['parameter'] = key + missing['value'] = val + missing['requirements'] = requirements + results.append(missing) + + if results: + for missing in results: + msg = "%s is %s but %s of the following are missing: %s" % ( + missing['parameter'], missing['value'], missing['requires'], ', '.join(missing['missing'])) + if options_context: + msg = "{0} found in {1}".format(msg, " -> ".join(options_context)) + raise TypeError(to_native(msg)) + + return results + + +def check_missing_parameters(parameters, required_parameters=None): + """This is for checking for required params when we can not check via + argspec because we need more information than is simply given in the argspec. + + Raises :class:`TypeError` if any required parameters are missing + + :arg parameters: Dictionary of parameters + :arg required_parameters: List of parameters to look for in the given parameters. + + :returns: Empty list or raises :class:`TypeError` if the check fails. + """ + missing_params = [] + if required_parameters is None: + return missing_params + + for param in required_parameters: + if not parameters.get(param): + missing_params.append(param) + + if missing_params: + msg = "missing required arguments: %s" % ', '.join(missing_params) + raise TypeError(to_native(msg)) + + return missing_params + + +# FIXME: The param and prefix parameters here are coming from AnsibleModule._check_type_string() +# which is using those for the warning messaged based on string conversion warning settings. +# Not sure how to deal with that here since we don't have config state to query. +def check_type_str(value, allow_conversion=True, param=None, prefix=''): + """Verify that the value is a string or convert to a string. + + Since unexpected changes can sometimes happen when converting to a string, + ``allow_conversion`` controls whether or not the value will be converted or a + TypeError will be raised if the value is not a string and would be converted + + :arg value: Value to validate or convert to a string + :arg allow_conversion: Whether to convert the string and return it or raise + a TypeError + + :returns: Original value if it is a string, the value converted to a string + if allow_conversion=True, or raises a TypeError if allow_conversion=False. + """ + if isinstance(value, string_types): + return value + + if allow_conversion: + return to_native(value, errors='surrogate_or_strict') + + msg = "'{0!r}' is not a string and conversion is not allowed".format(value) + raise TypeError(to_native(msg)) + + +def check_type_list(value): + """Verify that the value is a list or convert to a list + + A comma separated string will be split into a list. Raises a :class:`TypeError` + if unable to convert to a list. + + :arg value: Value to validate or convert to a list + + :returns: Original value if it is already a list, single item list if a + float, int, or string without commas, or a multi-item list if a + comma-delimited string. + """ + if isinstance(value, list): + return value + + if isinstance(value, string_types): + return value.split(",") + elif isinstance(value, int) or isinstance(value, float): + return [str(value)] + + raise TypeError('%s cannot be converted to a list' % type(value)) + + +def check_type_dict(value): + """Verify that value is a dict or convert it to a dict and return it. + + Raises :class:`TypeError` if unable to convert to a dict + + :arg value: Dict or string to convert to a dict. Accepts ``k1=v2, k2=v2``. + + :returns: value converted to a dictionary + """ + if isinstance(value, dict): + return value + + if isinstance(value, string_types): + if value.startswith("{"): + try: + return json.loads(value) + except Exception: + (result, exc) = safe_eval(value, dict(), include_exceptions=True) + if exc is not None: + raise TypeError('unable to evaluate string as dictionary') + return result + elif '=' in value: + fields = [] + field_buffer = [] + in_quote = False + in_escape = False + for c in value.strip(): + if in_escape: + field_buffer.append(c) + in_escape = False + elif c == '\\': + in_escape = True + elif not in_quote and c in ('\'', '"'): + in_quote = c + elif in_quote and in_quote == c: + in_quote = False + elif not in_quote and c in (',', ' '): + field = ''.join(field_buffer) + if field: + fields.append(field) + field_buffer = [] + else: + field_buffer.append(c) + + field = ''.join(field_buffer) + if field: + fields.append(field) + return dict(x.split("=", 1) for x in fields) + else: + raise TypeError("dictionary requested, could not parse JSON or key=value") + + raise TypeError('%s cannot be converted to a dict' % type(value)) + + +def check_type_bool(value): + """Verify that the value is a bool or convert it to a bool and return it. + + Raises :class:`TypeError` if unable to convert to a bool + + :arg value: String, int, or float to convert to bool. Valid booleans include: + '1', 'on', 1, '0', 0, 'n', 'f', 'false', 'true', 'y', 't', 'yes', 'no', 'off' + + :returns: Boolean True or False + """ + if isinstance(value, bool): + return value + + if isinstance(value, string_types) or isinstance(value, (int, float)): + return boolean(value) + + raise TypeError('%s cannot be converted to a bool' % type(value)) + + +def check_type_int(value): + """Verify that the value is an integer and return it or convert the value + to an integer and return it + + Raises :class:`TypeError` if unable to convert to an int + + :arg value: String or int to convert of verify + + :return: int of given value + """ + if isinstance(value, integer_types): + return value + + if isinstance(value, string_types): + try: + return int(value) + except ValueError: + pass + + raise TypeError('%s cannot be converted to an int' % type(value)) + + +def check_type_float(value): + """Verify that value is a float or convert it to a float and return it + + Raises :class:`TypeError` if unable to convert to a float + + :arg value: float, int, str, or bytes to verify or convert and return. + + :returns: float of given value. + """ + if isinstance(value, float): + return value + + if isinstance(value, (binary_type, text_type, int)): + try: + return float(value) + except ValueError: + pass + + raise TypeError('%s cannot be converted to a float' % type(value)) + + +def check_type_path(value,): + """Verify the provided value is a string or convert it to a string, + then return the expanded path + """ + value = check_type_str(value) + return os.path.expanduser(os.path.expandvars(value)) + + +def check_type_raw(value): + """Returns the raw value""" + return value + + +def check_type_bytes(value): + """Convert a human-readable string value to bytes + + Raises :class:`TypeError` if unable to covert the value + """ + try: + return human_to_bytes(value) + except ValueError: + raise TypeError('%s cannot be converted to a Byte value' % type(value)) + + +def check_type_bits(value): + """Convert a human-readable string bits value to bits in integer. + + Example: ``check_type_bits('1Mb')`` returns integer 1048576. + + Raises :class:`TypeError` if unable to covert the value. + """ + try: + return human_to_bytes(value, isbits=True) + except ValueError: + raise TypeError('%s cannot be converted to a Bit value' % type(value)) + + +def check_type_jsonarg(value): + """Return a jsonified string. Sometimes the controller turns a json string + into a dict/list so transform it back into json here + + Raises :class:`TypeError` if unable to covert the value + + """ + if isinstance(value, (text_type, binary_type)): + return value.strip() + elif isinstance(value, (list, tuple, dict)): + return jsonify(value) + raise TypeError('%s cannot be converted to a json string' % type(value)) diff --git a/lib/ansible/module_utils/common/warnings.py b/lib/ansible/module_utils/common/warnings.py new file mode 100644 index 0000000..9423e6a --- /dev/null +++ b/lib/ansible/module_utils/common/warnings.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils.six import string_types + +_global_warnings = [] +_global_deprecations = [] + + +def warn(warning): + if isinstance(warning, string_types): + _global_warnings.append(warning) + else: + raise TypeError("warn requires a string not a %s" % type(warning)) + + +def deprecate(msg, version=None, date=None, collection_name=None): + if isinstance(msg, string_types): + # For compatibility, we accept that neither version nor date is set, + # and treat that the same as if version would haven been set + if date is not None: + _global_deprecations.append({'msg': msg, 'date': date, 'collection_name': collection_name}) + else: + _global_deprecations.append({'msg': msg, 'version': version, 'collection_name': collection_name}) + else: + raise TypeError("deprecate requires a string not a %s" % type(msg)) + + +def get_warning_messages(): + """Return a tuple of warning messages accumulated over this run""" + return tuple(_global_warnings) + + +def get_deprecation_messages(): + """Return a tuple of deprecations accumulated over this run""" + return tuple(_global_deprecations) diff --git a/lib/ansible/module_utils/common/yaml.py b/lib/ansible/module_utils/common/yaml.py new file mode 100644 index 0000000..e79cc09 --- /dev/null +++ b/lib/ansible/module_utils/common/yaml.py @@ -0,0 +1,48 @@ +# (c) 2020 Matt Martz <matt@sivel.net> +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +""" +This file provides ease of use shortcuts for loading and dumping YAML, +preferring the YAML compiled C extensions to reduce duplicated code. +""" + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from functools import partial as _partial + +HAS_LIBYAML = False + +try: + import yaml as _yaml +except ImportError: + HAS_YAML = False +else: + HAS_YAML = True + +if HAS_YAML: + try: + from yaml import CSafeLoader as SafeLoader + from yaml import CSafeDumper as SafeDumper + from yaml.cyaml import CParser as Parser + + HAS_LIBYAML = True + except (ImportError, AttributeError): + from yaml import SafeLoader # type: ignore[misc] + from yaml import SafeDumper # type: ignore[misc] + from yaml.parser import Parser # type: ignore[misc] + + yaml_load = _partial(_yaml.load, Loader=SafeLoader) + yaml_load_all = _partial(_yaml.load_all, Loader=SafeLoader) + + yaml_dump = _partial(_yaml.dump, Dumper=SafeDumper) + yaml_dump_all = _partial(_yaml.dump_all, Dumper=SafeDumper) +else: + SafeLoader = object # type: ignore[assignment,misc] + SafeDumper = object # type: ignore[assignment,misc] + Parser = object # type: ignore[assignment,misc] + + yaml_load = None # type: ignore[assignment] + yaml_load_all = None # type: ignore[assignment] + yaml_dump = None # type: ignore[assignment] + yaml_dump_all = None # type: ignore[assignment] diff --git a/lib/ansible/module_utils/compat/__init__.py b/lib/ansible/module_utils/compat/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/compat/__init__.py diff --git a/lib/ansible/module_utils/compat/_selectors2.py b/lib/ansible/module_utils/compat/_selectors2.py new file mode 100644 index 0000000..be44b4b --- /dev/null +++ b/lib/ansible/module_utils/compat/_selectors2.py @@ -0,0 +1,655 @@ +# This file is from the selectors2.py package. It backports the PSF Licensed +# selectors module from the Python-3.5 stdlib to older versions of Python. +# The author, Seth Michael Larson, dual licenses his modifications under the +# PSF License and MIT License: +# https://github.com/SethMichaelLarson/selectors2#license +# +# Copyright (c) 2016 Seth Michael Larson +# +# PSF License (see licenses/PSF-license.txt or https://opensource.org/licenses/Python-2.0) +# MIT License (see licenses/MIT-license.txt or https://opensource.org/licenses/MIT) +# + + +# Backport of selectors.py from Python 3.5+ to support Python < 3.4 +# Also has the behavior specified in PEP 475 which is to retry syscalls +# in the case of an EINTR error. This module is required because selectors34 +# does not follow this behavior and instead returns that no file descriptor +# events have occurred rather than retry the syscall. The decision to drop +# support for select.devpoll is made to maintain 100% test coverage. + +import errno +import math +import select +import socket +import sys +import time +from collections import namedtuple +from ansible.module_utils.common._collections_compat import Mapping + +try: + monotonic = time.monotonic +except (AttributeError, ImportError): # Python 3.3< + monotonic = time.time + +__author__ = 'Seth Michael Larson' +__email__ = 'sethmichaellarson@protonmail.com' +__version__ = '1.1.1' +__license__ = 'MIT' + +__all__ = [ + 'EVENT_READ', + 'EVENT_WRITE', + 'SelectorError', + 'SelectorKey', + 'DefaultSelector' +] + +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + +HAS_SELECT = True # Variable that shows whether the platform has a selector. +_SYSCALL_SENTINEL = object() # Sentinel in case a system call returns None. + + +class SelectorError(Exception): + def __init__(self, errcode): + super(SelectorError, self).__init__() + self.errno = errcode + + def __repr__(self): + return "<SelectorError errno={0}>".format(self.errno) + + def __str__(self): + return self.__repr__() + + +def _fileobj_to_fd(fileobj): + """ Return a file descriptor from a file object. If + given an integer will simply return that integer back. """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: {0!r}".format(fileobj)) + if fd < 0: + raise ValueError("Invalid file descriptor: {0}".format(fd)) + return fd + + +# Python 3.5 uses a more direct route to wrap system calls to increase speed. +if sys.version_info >= (3, 5): + def _syscall_wrapper(func, _, *args, **kwargs): + """ This is the short-circuit version of the below logic + because in Python 3.5+ all selectors restart system calls. """ + try: + return func(*args, **kwargs) + except (OSError, IOError, select.error) as e: + errcode = None + if hasattr(e, "errno"): + errcode = e.errno + elif hasattr(e, "args"): + errcode = e.args[0] + raise SelectorError(errcode) +else: + def _syscall_wrapper(func, recalc_timeout, *args, **kwargs): + """ Wrapper function for syscalls that could fail due to EINTR. + All functions should be retried if there is time left in the timeout + in accordance with PEP 475. """ + timeout = kwargs.get("timeout", None) + if timeout is None: + expires = None + recalc_timeout = False + else: + timeout = float(timeout) + if timeout < 0.0: # Timeout less than 0 treated as no timeout. + expires = None + else: + expires = monotonic() + timeout + + args = list(args) + if recalc_timeout and "timeout" not in kwargs: + raise ValueError( + "Timeout must be in args or kwargs to be recalculated") + + result = _SYSCALL_SENTINEL + while result is _SYSCALL_SENTINEL: + try: + result = func(*args, **kwargs) + # OSError is thrown by select.select + # IOError is thrown by select.epoll.poll + # select.error is thrown by select.poll.poll + # Aren't we thankful for Python 3.x rework for exceptions? + except (OSError, IOError, select.error) as e: + # select.error wasn't a subclass of OSError in the past. + errcode = None + if hasattr(e, "errno"): + errcode = e.errno + elif hasattr(e, "args"): + errcode = e.args[0] + + # Also test for the Windows equivalent of EINTR. + is_interrupt = (errcode == errno.EINTR or (hasattr(errno, "WSAEINTR") and + errcode == errno.WSAEINTR)) + + if is_interrupt: + if expires is not None: + current_time = monotonic() + if current_time > expires: + raise OSError(errno.ETIMEDOUT) + if recalc_timeout: + if "timeout" in kwargs: + kwargs["timeout"] = expires - current_time + continue + if errcode: + raise SelectorError(errcode) + else: + raise + return result + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) + + +class _SelectorMapping(Mapping): + """ Mapping of file objects to selector keys """ + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{0!r} is not registered.".format(fileobj)) + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(object): + """ Abstract Selector class + + A selector supports registering file objects to be monitored + for specific I/O events. + + A file object is a file descriptor or any object with a + `fileno()` method. An arbitrary object can be attached to the + file object which can be used for example to store context info, + a callback, etc. + + A selector can use various implementations (select(), poll(), epoll(), + and kqueue()) depending on the platform. The 'DefaultSelector' class uses + the most efficient implementation for the current platform. + """ + def __init__(self): + # Maps file descriptors to keys. + self._fd_to_key = {} + + # Read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """ Return a file descriptor from a file object. + This wraps _fileobj_to_fd() to do an exhaustive + search in case the object is invalid but we still + have it in our map. Used by unregister() so we can + unregister an object that was previously registered + even if it is closed. It is also used by _SelectorMapping + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + + # Search through all our mapped keys. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + """ Register a file object for a set of events to monitor. """ + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {0!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{0!r} (FD {1}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + """ Unregister a file object from being monitored. """ + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + # Getting the fileno of a closed socket on Windows errors with EBADF. + except socket.error as err: + if err.errno != errno.EBADF: + raise + else: + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + self._fd_to_key.pop(key.fd) + break + else: + raise KeyError("{0!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """ Change a registered file object monitored events and data. """ + # NOTE: Some subclasses optimize this operation even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + + return key + + def select(self, timeout=None): + """ Perform the actual selection until some monitored file objects + are ready or the timeout expires. """ + raise NotImplementedError() + + def close(self): + """ Close the selector. This must be called to ensure that all + underlying resources are freed. """ + self._fd_to_key.clear() + self._map = None + + def get_key(self, fileobj): + """ Return the key associated with a registered file object. """ + mapping = self.get_map() + if mapping is None: + raise RuntimeError("Selector is closed") + try: + return mapping[fileobj] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + def get_map(self): + """ Return a mapping of file objects to selector keys """ + return self._map + + def _key_from_fd(self, fd): + """ Return the key associated to a given file descriptor + Return None if it is not found. """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +# Almost all platforms have select.select() +if hasattr(select, "select"): + class SelectSelector(BaseSelector): + """ Select-based selector. """ + def __init__(self): + super(SelectSelector, self).__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super(SelectSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super(SelectSelector, self).unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def _select(self, r, w, timeout=None): + """ Wrapper for select.select because timeout is a positional arg """ + return select.select(r, w, [], timeout) + + def select(self, timeout=None): + # Selecting on empty lists on Windows errors out. + if not len(self._readers) and not len(self._writers): + return [] + + timeout = None if timeout is None else max(timeout, 0.0) + ready = [] + r, w, _ = _syscall_wrapper(self._select, True, self._readers, + self._writers, timeout=timeout) + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + __all__.append('SelectSelector') + + +if hasattr(select, "poll"): + class PollSelector(BaseSelector): + """ Poll-based selector """ + def __init__(self): + super(PollSelector, self).__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super(PollSelector, self).register(fileobj, events, data) + event_mask = 0 + if events & EVENT_READ: + event_mask |= select.POLLIN + if events & EVENT_WRITE: + event_mask |= select.POLLOUT + self._poll.register(key.fd, event_mask) + return key + + def unregister(self, fileobj): + key = super(PollSelector, self).unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def _wrap_poll(self, timeout=None): + """ Wrapper function for select.poll.poll() so that + _syscall_wrapper can work with only seconds. """ + if timeout is not None: + if timeout <= 0: + timeout = 0 + else: + # select.poll.poll() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + + result = self._poll.poll(timeout) + return result + + def select(self, timeout=None): + ready = [] + fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout) + for fd, event_mask in fd_events: + events = 0 + if event_mask & ~select.POLLIN: + events |= EVENT_WRITE + if event_mask & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + + return ready + + __all__.append('PollSelector') + +if hasattr(select, "epoll"): + class EpollSelector(BaseSelector): + """ Epoll-based selector """ + def __init__(self): + super(EpollSelector, self).__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(EpollSelector, self).register(fileobj, events, data) + events_mask = 0 + if events & EVENT_READ: + events_mask |= select.EPOLLIN + if events & EVENT_WRITE: + events_mask |= select.EPOLLOUT + _syscall_wrapper(self._epoll.register, False, key.fd, events_mask) + return key + + def unregister(self, fileobj): + key = super(EpollSelector, self).unregister(fileobj) + try: + _syscall_wrapper(self._epoll.unregister, False, key.fd) + except SelectorError: + # This can occur when the fd was closed since registry. + pass + return key + + def select(self, timeout=None): + if timeout is not None: + if timeout <= 0: + timeout = 0.0 + else: + # select.epoll.poll() has a resolution of 1 millisecond + # but luckily takes seconds so we don't need a wrapper + # like PollSelector. Just for better rounding. + timeout = math.ceil(timeout * 1e3) * 1e-3 + timeout = float(timeout) + else: + timeout = -1.0 # epoll.poll() must have a float. + + # We always want at least 1 to ensure that select can be called + # with no file descriptors registered. Otherwise will fail. + max_events = max(len(self._fd_to_key), 1) + + ready = [] + fd_events = _syscall_wrapper(self._epoll.poll, True, + timeout=timeout, + maxevents=max_events) + for fd, event_mask in fd_events: + events = 0 + if event_mask & ~select.EPOLLIN: + events |= EVENT_WRITE + if event_mask & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._epoll.close() + super(EpollSelector, self).close() + + __all__.append('EpollSelector') + + +if hasattr(select, "devpoll"): + class DevpollSelector(BaseSelector): + """Solaris /dev/poll selector.""" + + def __init__(self): + super(DevpollSelector, self).__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(DevpollSelector, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(DevpollSelector, self).unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def _wrap_poll(self, timeout=None): + """ Wrapper function for select.poll.poll() so that + _syscall_wrapper can work with only seconds. """ + if timeout is not None: + if timeout <= 0: + timeout = 0 + else: + # select.devpoll.poll() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + + result = self._devpoll.poll(timeout) + return result + + def select(self, timeout=None): + ready = [] + fd_events = _syscall_wrapper(self._wrap_poll, True, timeout=timeout) + for fd, event_mask in fd_events: + events = 0 + if event_mask & ~select.POLLIN: + events |= EVENT_WRITE + if event_mask & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + + return ready + + def close(self): + self._devpoll.close() + super(DevpollSelector, self).close() + + __all__.append('DevpollSelector') + + +if hasattr(select, "kqueue"): + class KqueueSelector(BaseSelector): + """ Kqueue / Kevent-based selector """ + def __init__(self): + super(KqueueSelector, self).__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super(KqueueSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + kevent = select.kevent(key.fd, + select.KQ_FILTER_READ, + select.KQ_EV_ADD) + + _syscall_wrapper(self._wrap_control, False, [kevent], 0, 0) + + if events & EVENT_WRITE: + kevent = select.kevent(key.fd, + select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + + _syscall_wrapper(self._wrap_control, False, [kevent], 0, 0) + + return key + + def unregister(self, fileobj): + key = super(KqueueSelector, self).unregister(fileobj) + if key.events & EVENT_READ: + kevent = select.kevent(key.fd, + select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + _syscall_wrapper(self._wrap_control, False, [kevent], 0, 0) + except SelectorError: + pass + if key.events & EVENT_WRITE: + kevent = select.kevent(key.fd, + select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + _syscall_wrapper(self._wrap_control, False, [kevent], 0, 0) + except SelectorError: + pass + + return key + + def select(self, timeout=None): + if timeout is not None: + timeout = max(timeout, 0) + + max_events = len(self._fd_to_key) * 2 + ready_fds = {} + + kevent_list = _syscall_wrapper(self._wrap_control, True, + None, max_events, timeout=timeout) + + for kevent in kevent_list: + fd = kevent.ident + event_mask = kevent.filter + events = 0 + if event_mask == select.KQ_FILTER_READ: + events |= EVENT_READ + if event_mask == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + if key.fd not in ready_fds: + ready_fds[key.fd] = (key, events & key.events) + else: + old_events = ready_fds[key.fd][1] + ready_fds[key.fd] = (key, (events | old_events) & key.events) + + return list(ready_fds.values()) + + def close(self): + self._kqueue.close() + super(KqueueSelector, self).close() + + def _wrap_control(self, changelist, max_events, timeout): + return self._kqueue.control(changelist, max_events, timeout) + + __all__.append('KqueueSelector') + + +# Choose the best implementation, roughly: +# kqueue == epoll == devpoll > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): # Platform-specific: Mac OS and BSD + DefaultSelector = KqueueSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'EpollSelector' in globals(): # Platform-specific: Linux + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): # Platform-specific: Linux + DefaultSelector = PollSelector +elif 'SelectSelector' in globals(): # Platform-specific: Windows + DefaultSelector = SelectSelector +else: # Platform-specific: AppEngine + def no_selector(_): + raise ValueError("Platform does not have a selector") + DefaultSelector = no_selector + HAS_SELECT = False diff --git a/lib/ansible/module_utils/compat/importlib.py b/lib/ansible/module_utils/compat/importlib.py new file mode 100644 index 0000000..0b7fb2c --- /dev/null +++ b/lib/ansible/module_utils/compat/importlib.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 Matt Martz <matt@sivel.net> +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import sys + +try: + from importlib import import_module +except ImportError: + # importlib.import_module returns the tail + # whereas __import__ returns the head + # compat to work like importlib.import_module + def import_module(name): # type: ignore[misc] + __import__(name) + return sys.modules[name] diff --git a/lib/ansible/module_utils/compat/paramiko.py b/lib/ansible/module_utils/compat/paramiko.py new file mode 100644 index 0000000..85478ea --- /dev/null +++ b/lib/ansible/module_utils/compat/paramiko.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import types +import warnings + +PARAMIKO_IMPORT_ERR = None + +try: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='Blowfish has been deprecated', category=UserWarning) + import paramiko +# paramiko and gssapi are incompatible and raise AttributeError not ImportError +# When running in FIPS mode, cryptography raises InternalError +# https://bugzilla.redhat.com/show_bug.cgi?id=1778939 +except Exception as err: + paramiko = None # type: types.ModuleType | None # type: ignore[no-redef] + PARAMIKO_IMPORT_ERR = err diff --git a/lib/ansible/module_utils/compat/selectors.py b/lib/ansible/module_utils/compat/selectors.py new file mode 100644 index 0000000..93ffc62 --- /dev/null +++ b/lib/ansible/module_utils/compat/selectors.py @@ -0,0 +1,57 @@ +# (c) 2014, 2017 Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +''' +Compat selectors library. Python-3.5 has this builtin. The selectors2 +package exists on pypi to backport the functionality as far as python-2.6. +''' +# The following makes it easier for us to script updates of the bundled code +_BUNDLED_METADATA = {"pypi_name": "selectors2", "version": "1.1.1", "version_constraints": ">1.0,<2.0"} + +# Added these bugfix commits from 2.1.0: +# * https://github.com/SethMichaelLarson/selectors2/commit/3bd74f2033363b606e1e849528ccaa76f5067590 +# Wrap kqueue.control so that timeout is a keyword arg +# * https://github.com/SethMichaelLarson/selectors2/commit/6f6a26f42086d8aab273b30be492beecb373646b +# Fix formatting of the kqueue.control patch for pylint +# * https://github.com/SethMichaelLarson/selectors2/commit/f0c2c6c66cfa7662bc52beaf4e2d65adfa25e189 +# Fix use of OSError exception for py3 and use the wrapper of kqueue.control so retries of +# interrupted syscalls work with kqueue + +import os.path +import sys +import types + +try: + # Python 3.4+ + import selectors as _system_selectors +except ImportError: + try: + # backport package installed in the system + import selectors2 as _system_selectors # type: ignore[no-redef] + except ImportError: + _system_selectors = None # type: types.ModuleType | None # type: ignore[no-redef] + +if _system_selectors: + selectors = _system_selectors +else: + # Our bundled copy + from ansible.module_utils.compat import _selectors2 as selectors # type: ignore[no-redef] +sys.modules['ansible.module_utils.compat.selectors'] = selectors diff --git a/lib/ansible/module_utils/compat/selinux.py b/lib/ansible/module_utils/compat/selinux.py new file mode 100644 index 0000000..7191713 --- /dev/null +++ b/lib/ansible/module_utils/compat/selinux.py @@ -0,0 +1,113 @@ +# Copyright: (c) 2021, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import sys + +from ansible.module_utils.common.text.converters import to_native, to_bytes +from ctypes import CDLL, c_char_p, c_int, byref, POINTER, get_errno + +try: + _selinux_lib = CDLL('libselinux.so.1', use_errno=True) +except OSError: + raise ImportError('unable to load libselinux.so') + + +def _module_setup(): + def _check_rc(rc): + if rc < 0: + errno = get_errno() + raise OSError(errno, os.strerror(errno)) + return rc + + binary_char_type = type(b'') + + class _to_char_p: + @classmethod + def from_param(cls, strvalue): + if strvalue is not None and not isinstance(strvalue, binary_char_type): + strvalue = to_bytes(strvalue) + + return strvalue + + # FIXME: swap restype to errcheck + + _funcmap = dict( + is_selinux_enabled={}, + is_selinux_mls_enabled={}, + lgetfilecon_raw=dict(argtypes=[_to_char_p, POINTER(c_char_p)], restype=_check_rc), + # NB: matchpathcon is deprecated and should be rewritten on selabel_lookup (but will be a PITA) + matchpathcon=dict(argtypes=[_to_char_p, c_int, POINTER(c_char_p)], restype=_check_rc), + security_policyvers={}, + selinux_getenforcemode=dict(argtypes=[POINTER(c_int)]), + security_getenforce={}, + lsetfilecon=dict(argtypes=[_to_char_p, _to_char_p], restype=_check_rc), + selinux_getpolicytype=dict(argtypes=[POINTER(c_char_p)], restype=_check_rc), + ) + + _thismod = sys.modules[__name__] + + for fname, cfg in _funcmap.items(): + fn = getattr(_selinux_lib, fname, None) + + if not fn: + raise ImportError('missing selinux function: {0}'.format(fname)) + + # all ctypes pointers share the same base type + base_ptr_type = type(POINTER(c_int)) + fn.argtypes = cfg.get('argtypes', None) + fn.restype = cfg.get('restype', c_int) + + # just patch simple directly callable functions directly onto the module + if not fn.argtypes or not any(argtype for argtype in fn.argtypes if type(argtype) == base_ptr_type): + setattr(_thismod, fname, fn) + continue + + # NB: this validation code must run after all the wrappers have been declared + unimplemented_funcs = set(_funcmap).difference(dir(_thismod)) + if unimplemented_funcs: + raise NotImplementedError('implementation is missing functions: {0}'.format(unimplemented_funcs)) + + +# begin wrapper function impls + +def selinux_getenforcemode(): + enforcemode = c_int() + rc = _selinux_lib.selinux_getenforcemode(byref(enforcemode)) + return [rc, enforcemode.value] + + +def selinux_getpolicytype(): + con = c_char_p() + try: + rc = _selinux_lib.selinux_getpolicytype(byref(con)) + return [rc, to_native(con.value)] + finally: + _selinux_lib.freecon(con) + + +def lgetfilecon_raw(path): + con = c_char_p() + try: + rc = _selinux_lib.lgetfilecon_raw(path, byref(con)) + return [rc, to_native(con.value)] + finally: + _selinux_lib.freecon(con) + + +def matchpathcon(path, mode): + con = c_char_p() + try: + rc = _selinux_lib.matchpathcon(path, mode, byref(con)) + return [rc, to_native(con.value)] + finally: + _selinux_lib.freecon(con) + + +_module_setup() +del _module_setup + +# end wrapper function impls diff --git a/lib/ansible/module_utils/compat/typing.py b/lib/ansible/module_utils/compat/typing.py new file mode 100644 index 0000000..27b25f7 --- /dev/null +++ b/lib/ansible/module_utils/compat/typing.py @@ -0,0 +1,25 @@ +"""Compatibility layer for the `typing` module, providing all Python versions access to the newest type-hinting features.""" +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +# pylint: disable=wildcard-import,unused-wildcard-import + +# catch *all* exceptions to prevent type annotation support module bugs causing runtime failures +# (eg, https://github.com/ansible/ansible/issues/77857) + +try: + from typing_extensions import * +except Exception: # pylint: disable=broad-except + pass + +try: + from typing import * # type: ignore[misc] +except Exception: # pylint: disable=broad-except + pass + + +try: + cast +except NameError: + def cast(typ, val): # type: ignore[no-redef] + return val diff --git a/lib/ansible/module_utils/compat/version.py b/lib/ansible/module_utils/compat/version.py new file mode 100644 index 0000000..f4db1ef --- /dev/null +++ b/lib/ansible/module_utils/compat/version.py @@ -0,0 +1,343 @@ +# Vendored copy of distutils/version.py from CPython 3.9.5 +# +# Implements multiple version numbering conventions for the +# Python Module Distribution Utilities. +# +# PSF License (see licenses/PSF-license.txt or https://opensource.org/licenses/Python-2.0) +# + +"""Provides classes to represent module version numbers (one class for +each style of version numbering). There are currently two such classes +implemented: StrictVersion and LooseVersion. + +Every version number class implements the following interface: + * the 'parse' method takes a string and parses it to some internal + representation; if the string is an invalid version number, + 'parse' raises a ValueError exception + * the class constructor takes an optional string argument which, + if supplied, is passed to 'parse' + * __str__ reconstructs the string that was passed to 'parse' (or + an equivalent string -- ie. one that will generate an equivalent + version number instance) + * __repr__ generates Python code to recreate the version number instance + * _cmp compares the current instance with either another instance + of the same class or a string (which will be parsed to an instance + of the same class, thus must follow the same rules) +""" + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +try: + RE_FLAGS = re.VERBOSE | re.ASCII # type: ignore[attr-defined] +except AttributeError: + RE_FLAGS = re.VERBOSE + + +class Version: + """Abstract base class for version numbering classes. Just provides + constructor (__init__) and reproducer (__repr__), because those + seem to be the same for all version numbering classes; and route + rich comparisons to _cmp. + """ + + def __init__(self, vstring=None): + if vstring: + self.parse(vstring) + + def __repr__(self): + return "%s ('%s')" % (self.__class__.__name__, str(self)) + + def __eq__(self, other): + c = self._cmp(other) + if c is NotImplemented: + return c + return c == 0 + + def __lt__(self, other): + c = self._cmp(other) + if c is NotImplemented: + return c + return c < 0 + + def __le__(self, other): + c = self._cmp(other) + if c is NotImplemented: + return c + return c <= 0 + + def __gt__(self, other): + c = self._cmp(other) + if c is NotImplemented: + return c + return c > 0 + + def __ge__(self, other): + c = self._cmp(other) + if c is NotImplemented: + return c + return c >= 0 + + +# Interface for version-number classes -- must be implemented +# by the following classes (the concrete ones -- Version should +# be treated as an abstract class). +# __init__ (string) - create and take same action as 'parse' +# (string parameter is optional) +# parse (string) - convert a string representation to whatever +# internal representation is appropriate for +# this style of version numbering +# __str__ (self) - convert back to a string; should be very similar +# (if not identical to) the string supplied to parse +# __repr__ (self) - generate Python code to recreate +# the instance +# _cmp (self, other) - compare two version numbers ('other' may +# be an unparsed version string, or another +# instance of your version class) + + +class StrictVersion(Version): + """Version numbering for anal retentives and software idealists. + Implements the standard interface for version number classes as + described above. A version number consists of two or three + dot-separated numeric components, with an optional "pre-release" tag + on the end. The pre-release tag consists of the letter 'a' or 'b' + followed by a number. If the numeric components of two version + numbers are equal, then one with a pre-release tag will always + be deemed earlier (lesser) than one without. + + The following are valid version numbers (shown in the order that + would be obtained by sorting according to the supplied cmp function): + + 0.4 0.4.0 (these two are equivalent) + 0.4.1 + 0.5a1 + 0.5b3 + 0.5 + 0.9.6 + 1.0 + 1.0.4a3 + 1.0.4b1 + 1.0.4 + + The following are examples of invalid version numbers: + + 1 + 2.7.2.2 + 1.3.a4 + 1.3pl1 + 1.3c4 + + The rationale for this version numbering system will be explained + in the distutils documentation. + """ + + version_re = re.compile(r'^(\d+) \. (\d+) (\. (\d+))? ([ab](\d+))?$', + RE_FLAGS) + + def parse(self, vstring): + match = self.version_re.match(vstring) + if not match: + raise ValueError("invalid version number '%s'" % vstring) + + (major, minor, patch, prerelease, prerelease_num) = \ + match.group(1, 2, 4, 5, 6) + + if patch: + self.version = tuple(map(int, [major, minor, patch])) + else: + self.version = tuple(map(int, [major, minor])) + (0,) + + if prerelease: + self.prerelease = (prerelease[0], int(prerelease_num)) + else: + self.prerelease = None + + def __str__(self): + if self.version[2] == 0: + vstring = '.'.join(map(str, self.version[0:2])) + else: + vstring = '.'.join(map(str, self.version)) + + if self.prerelease: + vstring = vstring + self.prerelease[0] + str(self.prerelease[1]) + + return vstring + + def _cmp(self, other): + if isinstance(other, str): + other = StrictVersion(other) + elif not isinstance(other, StrictVersion): + return NotImplemented + + if self.version != other.version: + # numeric versions don't match + # prerelease stuff doesn't matter + if self.version < other.version: + return -1 + else: + return 1 + + # have to compare prerelease + # case 1: neither has prerelease; they're equal + # case 2: self has prerelease, other doesn't; other is greater + # case 3: self doesn't have prerelease, other does: self is greater + # case 4: both have prerelease: must compare them! + + if (not self.prerelease and not other.prerelease): + return 0 + elif (self.prerelease and not other.prerelease): + return -1 + elif (not self.prerelease and other.prerelease): + return 1 + elif (self.prerelease and other.prerelease): + if self.prerelease == other.prerelease: + return 0 + elif self.prerelease < other.prerelease: + return -1 + else: + return 1 + else: + raise AssertionError("never get here") + +# end class StrictVersion + +# The rules according to Greg Stein: +# 1) a version number has 1 or more numbers separated by a period or by +# sequences of letters. If only periods, then these are compared +# left-to-right to determine an ordering. +# 2) sequences of letters are part of the tuple for comparison and are +# compared lexicographically +# 3) recognize the numeric components may have leading zeroes +# +# The LooseVersion class below implements these rules: a version number +# string is split up into a tuple of integer and string components, and +# comparison is a simple tuple comparison. This means that version +# numbers behave in a predictable and obvious way, but a way that might +# not necessarily be how people *want* version numbers to behave. There +# wouldn't be a problem if people could stick to purely numeric version +# numbers: just split on period and compare the numbers as tuples. +# However, people insist on putting letters into their version numbers; +# the most common purpose seems to be: +# - indicating a "pre-release" version +# ('alpha', 'beta', 'a', 'b', 'pre', 'p') +# - indicating a post-release patch ('p', 'pl', 'patch') +# but of course this can't cover all version number schemes, and there's +# no way to know what a programmer means without asking them. +# +# The problem is what to do with letters (and other non-numeric +# characters) in a version number. The current implementation does the +# obvious and predictable thing: keep them as strings and compare +# lexically within a tuple comparison. This has the desired effect if +# an appended letter sequence implies something "post-release": +# eg. "0.99" < "0.99pl14" < "1.0", and "5.001" < "5.001m" < "5.002". +# +# However, if letters in a version number imply a pre-release version, +# the "obvious" thing isn't correct. Eg. you would expect that +# "1.5.1" < "1.5.2a2" < "1.5.2", but under the tuple/lexical comparison +# implemented here, this just isn't so. +# +# Two possible solutions come to mind. The first is to tie the +# comparison algorithm to a particular set of semantic rules, as has +# been done in the StrictVersion class above. This works great as long +# as everyone can go along with bondage and discipline. Hopefully a +# (large) subset of Python module programmers will agree that the +# particular flavour of bondage and discipline provided by StrictVersion +# provides enough benefit to be worth using, and will submit their +# version numbering scheme to its domination. The free-thinking +# anarchists in the lot will never give in, though, and something needs +# to be done to accommodate them. +# +# Perhaps a "moderately strict" version class could be implemented that +# lets almost anything slide (syntactically), and makes some heuristic +# assumptions about non-digits in version number strings. This could +# sink into special-case-hell, though; if I was as talented and +# idiosyncratic as Larry Wall, I'd go ahead and implement a class that +# somehow knows that "1.2.1" < "1.2.2a2" < "1.2.2" < "1.2.2pl3", and is +# just as happy dealing with things like "2g6" and "1.13++". I don't +# think I'm smart enough to do it right though. +# +# In any case, I've coded the test suite for this module (see +# ../test/test_version.py) specifically to fail on things like comparing +# "1.2a2" and "1.2". That's not because the *code* is doing anything +# wrong, it's because the simple, obvious design doesn't match my +# complicated, hairy expectations for real-world version numbers. It +# would be a snap to fix the test suite to say, "Yep, LooseVersion does +# the Right Thing" (ie. the code matches the conception). But I'd rather +# have a conception that matches common notions about version numbers. + + +class LooseVersion(Version): + """Version numbering for anarchists and software realists. + Implements the standard interface for version number classes as + described above. A version number consists of a series of numbers, + separated by either periods or strings of letters. When comparing + version numbers, the numeric components will be compared + numerically, and the alphabetic components lexically. The following + are all valid version numbers, in no particular order: + + 1.5.1 + 1.5.2b2 + 161 + 3.10a + 8.02 + 3.4j + 1996.07.12 + 3.2.pl0 + 3.1.1.6 + 2g6 + 11g + 0.960923 + 2.2beta29 + 1.13++ + 5.5.kw + 2.0b1pl0 + + In fact, there is no such thing as an invalid version number under + this scheme; the rules for comparison are simple and predictable, + but may not always give the results you want (for some definition + of "want"). + """ + + component_re = re.compile(r'(\d+ | [a-z]+ | \.)', re.VERBOSE) + + def __init__(self, vstring=None): + if vstring: + self.parse(vstring) + + def parse(self, vstring): + # I've given up on thinking I can reconstruct the version string + # from the parsed tuple -- so I just store the string here for + # use by __str__ + self.vstring = vstring + components = [x for x in self.component_re.split(vstring) if x and x != '.'] + for i, obj in enumerate(components): + try: + components[i] = int(obj) + except ValueError: + pass + + self.version = components + + def __str__(self): + return self.vstring + + def __repr__(self): + return "LooseVersion ('%s')" % str(self) + + def _cmp(self, other): + if isinstance(other, str): + other = LooseVersion(other) + elif not isinstance(other, LooseVersion): + return NotImplemented + + if self.version == other.version: + return 0 + if self.version < other.version: + return -1 + if self.version > other.version: + return 1 + +# end class LooseVersion diff --git a/lib/ansible/module_utils/connection.py b/lib/ansible/module_utils/connection.py new file mode 100644 index 0000000..1396c1c --- /dev/null +++ b/lib/ansible/module_utils/connection.py @@ -0,0 +1,222 @@ +# +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import hashlib +import json +import socket +import struct +import traceback +import uuid + +from functools import partial +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.common.json import AnsibleJSONEncoder +from ansible.module_utils.six import iteritems +from ansible.module_utils.six.moves import cPickle + + +def write_to_file_descriptor(fd, obj): + """Handles making sure all data is properly written to file descriptor fd. + + In particular, that data is encoded in a character stream-friendly way and + that all data gets written before returning. + """ + # Need to force a protocol that is compatible with both py2 and py3. + # That would be protocol=2 or less. + # Also need to force a protocol that excludes certain control chars as + # stdin in this case is a pty and control chars will cause problems. + # that means only protocol=0 will work. + src = cPickle.dumps(obj, protocol=0) + + # raw \r characters will not survive pty round-trip + # They should be rehydrated on the receiving end + src = src.replace(b'\r', br'\r') + data_hash = to_bytes(hashlib.sha1(src).hexdigest()) + + os.write(fd, b'%d\n' % len(src)) + os.write(fd, src) + os.write(fd, b'%s\n' % data_hash) + + +def send_data(s, data): + packed_len = struct.pack('!Q', len(data)) + return s.sendall(packed_len + data) + + +def recv_data(s): + header_len = 8 # size of a packed unsigned long long + data = to_bytes("") + while len(data) < header_len: + d = s.recv(header_len - len(data)) + if not d: + return None + data += d + data_len = struct.unpack('!Q', data[:header_len])[0] + data = data[header_len:] + while len(data) < data_len: + d = s.recv(data_len - len(data)) + if not d: + return None + data += d + return data + + +def exec_command(module, command): + connection = Connection(module._socket_path) + try: + out = connection.exec_command(command) + except ConnectionError as exc: + code = getattr(exc, 'code', 1) + message = getattr(exc, 'err', exc) + return code, '', to_text(message, errors='surrogate_then_replace') + return 0, out, '' + + +def request_builder(method_, *args, **kwargs): + reqid = str(uuid.uuid4()) + req = {'jsonrpc': '2.0', 'method': method_, 'id': reqid} + req['params'] = (args, kwargs) + + return req + + +class ConnectionError(Exception): + + def __init__(self, message, *args, **kwargs): + super(ConnectionError, self).__init__(message) + for k, v in iteritems(kwargs): + setattr(self, k, v) + + +class Connection(object): + + def __init__(self, socket_path): + if socket_path is None: + raise AssertionError('socket_path must be a value') + self.socket_path = socket_path + + def __getattr__(self, name): + try: + return self.__dict__[name] + except KeyError: + if name.startswith('_'): + raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) + return partial(self.__rpc__, name) + + def _exec_jsonrpc(self, name, *args, **kwargs): + + req = request_builder(name, *args, **kwargs) + reqid = req['id'] + + if not os.path.exists(self.socket_path): + raise ConnectionError( + 'socket path %s does not exist or cannot be found. See Troubleshooting socket ' + 'path issues in the Network Debug and Troubleshooting Guide' % self.socket_path + ) + + try: + data = json.dumps(req, cls=AnsibleJSONEncoder, vault_to_text=True) + except TypeError as exc: + raise ConnectionError( + "Failed to encode some variables as JSON for communication with ansible-connection. " + "The original exception was: %s" % to_text(exc) + ) + + try: + out = self.send(data) + except socket.error as e: + raise ConnectionError( + 'unable to connect to socket %s. See Troubleshooting socket path issues ' + 'in the Network Debug and Troubleshooting Guide' % self.socket_path, + err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc() + ) + + try: + response = json.loads(out) + except ValueError: + # set_option(s) has sensitive info, and the details are unlikely to matter anyway + if name.startswith("set_option"): + raise ConnectionError( + "Unable to decode JSON from response to {0}. Received '{1}'.".format(name, out) + ) + params = [repr(arg) for arg in args] + ['{0}={1!r}'.format(k, v) for k, v in iteritems(kwargs)] + params = ', '.join(params) + raise ConnectionError( + "Unable to decode JSON from response to {0}({1}). Received '{2}'.".format(name, params, out) + ) + + if response['id'] != reqid: + raise ConnectionError('invalid json-rpc id received') + if "result_type" in response: + response["result"] = cPickle.loads(to_bytes(response["result"])) + + return response + + def __rpc__(self, name, *args, **kwargs): + """Executes the json-rpc and returns the output received + from remote device. + :name: rpc method to be executed over connection plugin that implements jsonrpc 2.0 + :args: Ordered list of params passed as arguments to rpc method + :kwargs: Dict of valid key, value pairs passed as arguments to rpc method + + For usage refer the respective connection plugin docs. + """ + + response = self._exec_jsonrpc(name, *args, **kwargs) + + if 'error' in response: + err = response.get('error') + msg = err.get('data') or err['message'] + code = err['code'] + raise ConnectionError(to_text(msg, errors='surrogate_then_replace'), code=code) + + return response['result'] + + def send(self, data): + try: + sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sf.connect(self.socket_path) + + send_data(sf, to_bytes(data)) + response = recv_data(sf) + + except socket.error as e: + sf.close() + raise ConnectionError( + 'unable to connect to socket %s. See the socket path issue category in ' + 'Network Debug and Troubleshooting Guide' % self.socket_path, + err=to_text(e, errors='surrogate_then_replace'), exception=traceback.format_exc() + ) + + sf.close() + + return to_text(response, errors='surrogate_or_strict') diff --git a/lib/ansible/module_utils/csharp/Ansible.AccessToken.cs b/lib/ansible/module_utils/csharp/Ansible.AccessToken.cs new file mode 100644 index 0000000..48c4a19 --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.AccessToken.cs @@ -0,0 +1,460 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; + +namespace Ansible.AccessToken +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential)] + public struct LUID_AND_ATTRIBUTES + { + public Luid Luid; + public UInt32 Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SID_AND_ATTRIBUTES + { + public IntPtr Sid; + public int Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_PRIVILEGES + { + public UInt32 PrivilegeCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public LUID_AND_ATTRIBUTES[] Privileges; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_USER + { + public SID_AND_ATTRIBUTES User; + } + + public enum TokenInformationClass : uint + { + TokenUser = 1, + TokenPrivileges = 3, + TokenStatistics = 10, + TokenElevationType = 18, + TokenLinkedToken = 19, + } + } + + internal class NativeMethods + { + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool CloseHandle( + IntPtr hObject); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool DuplicateTokenEx( + SafeNativeHandle hExistingToken, + TokenAccessLevels dwDesiredAccess, + IntPtr lpTokenAttributes, + SecurityImpersonationLevel ImpersonationLevel, + TokenType TokenType, + out SafeNativeHandle phNewToken); + + [DllImport("kernel32.dll")] + public static extern SafeNativeHandle GetCurrentProcess(); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool GetTokenInformation( + SafeNativeHandle TokenHandle, + NativeHelpers.TokenInformationClass TokenInformationClass, + SafeMemoryBuffer TokenInformation, + UInt32 TokenInformationLength, + out UInt32 ReturnLength); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool ImpersonateLoggedOnUser( + SafeNativeHandle hToken); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LogonUserW( + string lpszUsername, + string lpszDomain, + string lpszPassword, + LogonType dwLogonType, + LogonProvider dwLogonProvider, + out SafeNativeHandle phToken); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LookupPrivilegeNameW( + string lpSystemName, + ref Luid lpLuid, + StringBuilder lpName, + ref UInt32 cchName); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern SafeNativeHandle OpenProcess( + ProcessAccessFlags dwDesiredAccess, + bool bInheritHandle, + UInt32 dwProcessId); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool OpenProcessToken( + SafeNativeHandle ProcessHandle, + TokenAccessLevels DesiredAccess, + out SafeNativeHandle TokenHandle); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool RevertToSelf(); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + public enum LogonProvider + { + Default, + WinNT35, + WinNT40, + WinNT50, + } + + public enum LogonType + { + Interactive = 2, + Network = 3, + Batch = 4, + Service = 5, + Unlock = 7, + NetworkCleartext = 8, + NewCredentials = 9, + } + + [Flags] + public enum PrivilegeAttributes : uint + { + Disabled = 0x00000000, + EnabledByDefault = 0x00000001, + Enabled = 0x00000002, + Removed = 0x00000004, + UsedForAccess = 0x80000000, + } + + [Flags] + public enum ProcessAccessFlags : uint + { + Terminate = 0x00000001, + CreateThread = 0x00000002, + VmOperation = 0x00000008, + VmRead = 0x00000010, + VmWrite = 0x00000020, + DupHandle = 0x00000040, + CreateProcess = 0x00000080, + SetQuota = 0x00000100, + SetInformation = 0x00000200, + QueryInformation = 0x00000400, + SuspendResume = 0x00000800, + QueryLimitedInformation = 0x00001000, + Delete = 0x00010000, + ReadControl = 0x00020000, + WriteDac = 0x00040000, + WriteOwner = 0x00080000, + Synchronize = 0x00100000, + } + + public enum SecurityImpersonationLevel + { + Anonymous, + Identification, + Impersonation, + Delegation, + } + + public enum TokenElevationType + { + Default = 1, + Full, + Limited, + } + + public enum TokenType + { + Primary = 1, + Impersonation, + } + + [StructLayout(LayoutKind.Sequential)] + public struct Luid + { + public UInt32 LowPart; + public Int32 HighPart; + + public static explicit operator UInt64(Luid l) + { + return (UInt64)((UInt64)l.HighPart << 32) | (UInt64)l.LowPart; + } + } + + [StructLayout(LayoutKind.Sequential)] + public struct TokenStatistics + { + public Luid TokenId; + public Luid AuthenticationId; + public Int64 ExpirationTime; + public TokenType TokenType; + public SecurityImpersonationLevel ImpersonationLevel; + public UInt32 DynamicCharged; + public UInt32 DynamicAvailable; + public UInt32 GroupCount; + public UInt32 PrivilegeCount; + public Luid ModifiedId; + } + + public class PrivilegeInfo + { + public string Name; + public PrivilegeAttributes Attributes; + + internal PrivilegeInfo(NativeHelpers.LUID_AND_ATTRIBUTES la) + { + Name = TokenUtil.GetPrivilegeName(la.Luid); + Attributes = (PrivilegeAttributes)la.Attributes; + } + } + + public class SafeNativeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeNativeHandle() : base(true) { } + public SafeNativeHandle(IntPtr handle) : base(true) { this.handle = handle; } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return NativeMethods.CloseHandle(handle); + } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2} - 0x{2:X8})", message, base.Message, errorCode); + } + + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + public class TokenUtil + { + public static SafeNativeHandle DuplicateToken(SafeNativeHandle hToken, TokenAccessLevels access, + SecurityImpersonationLevel impersonationLevel, TokenType tokenType) + { + SafeNativeHandle dupToken; + if (!NativeMethods.DuplicateTokenEx(hToken, access, IntPtr.Zero, impersonationLevel, tokenType, out dupToken)) + throw new Win32Exception("Failed to duplicate token"); + return dupToken; + } + + public static SecurityIdentifier GetTokenUser(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenUser)) + { + NativeHelpers.TOKEN_USER tokenUser = (NativeHelpers.TOKEN_USER)Marshal.PtrToStructure( + tokenInfo.DangerousGetHandle(), + typeof(NativeHelpers.TOKEN_USER)); + return new SecurityIdentifier(tokenUser.User.Sid); + } + } + + public static List<PrivilegeInfo> GetTokenPrivileges(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenPrivileges)) + { + NativeHelpers.TOKEN_PRIVILEGES tokenPrivs = (NativeHelpers.TOKEN_PRIVILEGES)Marshal.PtrToStructure( + tokenInfo.DangerousGetHandle(), + typeof(NativeHelpers.TOKEN_PRIVILEGES)); + + NativeHelpers.LUID_AND_ATTRIBUTES[] luidAttrs = + new NativeHelpers.LUID_AND_ATTRIBUTES[tokenPrivs.PrivilegeCount]; + PtrToStructureArray(luidAttrs, IntPtr.Add(tokenInfo.DangerousGetHandle(), + Marshal.SizeOf(tokenPrivs.PrivilegeCount))); + + return luidAttrs.Select(la => new PrivilegeInfo(la)).ToList(); + } + } + + public static TokenStatistics GetTokenStatistics(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenStatistics)) + { + TokenStatistics tokenStats = (TokenStatistics)Marshal.PtrToStructure( + tokenInfo.DangerousGetHandle(), + typeof(TokenStatistics)); + return tokenStats; + } + } + + public static TokenElevationType GetTokenElevationType(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenElevationType)) + { + return (TokenElevationType)Marshal.ReadInt32(tokenInfo.DangerousGetHandle()); + } + } + + public static SafeNativeHandle GetTokenLinkedToken(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenLinkedToken)) + { + return new SafeNativeHandle(Marshal.ReadIntPtr(tokenInfo.DangerousGetHandle())); + } + } + + public static IEnumerable<SafeNativeHandle> EnumerateUserTokens(SecurityIdentifier sid, + TokenAccessLevels access = TokenAccessLevels.Query) + { + foreach (System.Diagnostics.Process process in System.Diagnostics.Process.GetProcesses()) + { + // We always need the Query access level so we can query the TokenUser + using (process) + using (SafeNativeHandle hToken = TryOpenAccessToken(process, access | TokenAccessLevels.Query)) + { + if (hToken == null) + continue; + + if (!sid.Equals(GetTokenUser(hToken))) + continue; + + yield return hToken; + } + } + } + + public static void ImpersonateToken(SafeNativeHandle hToken) + { + if (!NativeMethods.ImpersonateLoggedOnUser(hToken)) + throw new Win32Exception("Failed to impersonate token"); + } + + public static SafeNativeHandle LogonUser(string username, string domain, string password, LogonType logonType, + LogonProvider logonProvider) + { + SafeNativeHandle hToken; + if (!NativeMethods.LogonUserW(username, domain, password, logonType, logonProvider, out hToken)) + throw new Win32Exception(String.Format("Failed to logon {0}", + String.IsNullOrEmpty(domain) ? username : domain + "\\" + username)); + + return hToken; + } + + public static SafeNativeHandle OpenProcess() + { + return NativeMethods.GetCurrentProcess(); + } + + public static SafeNativeHandle OpenProcess(Int32 pid, ProcessAccessFlags access, bool inherit) + { + SafeNativeHandle hProcess = NativeMethods.OpenProcess(access, inherit, (UInt32)pid); + if (hProcess.IsInvalid) + throw new Win32Exception(String.Format("Failed to open process {0} with access {1}", + pid, access.ToString())); + + return hProcess; + } + + public static SafeNativeHandle OpenProcessToken(SafeNativeHandle hProcess, TokenAccessLevels access) + { + SafeNativeHandle hToken; + if (!NativeMethods.OpenProcessToken(hProcess, access, out hToken)) + throw new Win32Exception(String.Format("Failed to open process token with access {0}", + access.ToString())); + + return hToken; + } + + public static void RevertToSelf() + { + if (!NativeMethods.RevertToSelf()) + throw new Win32Exception("Failed to revert thread impersonation"); + } + + internal static string GetPrivilegeName(Luid luid) + { + UInt32 nameLen = 0; + NativeMethods.LookupPrivilegeNameW(null, ref luid, null, ref nameLen); + + StringBuilder name = new StringBuilder((int)(nameLen + 1)); + if (!NativeMethods.LookupPrivilegeNameW(null, ref luid, name, ref nameLen)) + throw new Win32Exception("LookupPrivilegeName() failed"); + + return name.ToString(); + } + + private static SafeMemoryBuffer GetTokenInformation(SafeNativeHandle hToken, + NativeHelpers.TokenInformationClass infoClass) + { + UInt32 tokenLength; + bool res = NativeMethods.GetTokenInformation(hToken, infoClass, new SafeMemoryBuffer(IntPtr.Zero), 0, + out tokenLength); + int errCode = Marshal.GetLastWin32Error(); + if (!res && errCode != 24 && errCode != 122) // ERROR_INSUFFICIENT_BUFFER, ERROR_BAD_LENGTH + throw new Win32Exception(errCode, String.Format("GetTokenInformation({0}) failed to get buffer length", + infoClass.ToString())); + + SafeMemoryBuffer tokenInfo = new SafeMemoryBuffer((int)tokenLength); + if (!NativeMethods.GetTokenInformation(hToken, infoClass, tokenInfo, tokenLength, out tokenLength)) + throw new Win32Exception(String.Format("GetTokenInformation({0}) failed", infoClass.ToString())); + + return tokenInfo; + } + + private static void PtrToStructureArray<T>(T[] array, IntPtr ptr) + { + IntPtr ptrOffset = ptr; + for (int i = 0; i < array.Length; i++, ptrOffset = IntPtr.Add(ptrOffset, Marshal.SizeOf(typeof(T)))) + array[i] = (T)Marshal.PtrToStructure(ptrOffset, typeof(T)); + } + + private static SafeNativeHandle TryOpenAccessToken(System.Diagnostics.Process process, TokenAccessLevels access) + { + try + { + using (SafeNativeHandle hProcess = OpenProcess(process.Id, ProcessAccessFlags.QueryInformation, false)) + return OpenProcessToken(hProcess, access); + } + catch (Win32Exception) + { + return null; + } + } + } +} diff --git a/lib/ansible/module_utils/csharp/Ansible.Basic.cs b/lib/ansible/module_utils/csharp/Ansible.Basic.cs new file mode 100644 index 0000000..c68281e --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Basic.cs @@ -0,0 +1,1489 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Management.Automation; +using System.Management.Automation.Runspaces; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Security.AccessControl; +using System.Security.Principal; +#if CORECLR +using Newtonsoft.Json; +#else +using System.Web.Script.Serialization; +#endif + +// Newtonsoft.Json may reference a different System.Runtime version (6.x) than loaded by PowerShell 7.3 (7.x). +// Ignore CS1701 so the code can be compiled when warnings are reported as errors. +//NoWarn -Name CS1701 -CLR Core + +// System.Diagnostics.EventLog.dll reference different versioned dlls that are +// loaded in PSCore, ignore CS1702 so the code will ignore this warning +//NoWarn -Name CS1702 -CLR Core + +//AssemblyReference -Type Newtonsoft.Json.JsonConvert -CLR Core +//AssemblyReference -Type System.Diagnostics.EventLog -CLR Core +//AssemblyReference -Type System.Security.AccessControl.NativeObjectSecurity -CLR Core +//AssemblyReference -Type System.Security.AccessControl.DirectorySecurity -CLR Core +//AssemblyReference -Type System.Security.Principal.IdentityReference -CLR Core + +//AssemblyReference -Name System.Web.Extensions.dll -CLR Framework + +namespace Ansible.Basic +{ + public class AnsibleModule + { + public delegate void ExitHandler(int rc); + public static ExitHandler Exit = new ExitHandler(ExitModule); + + public delegate void WriteLineHandler(string line); + public static WriteLineHandler WriteLine = new WriteLineHandler(WriteLineModule); + + public static bool _DebugArgSpec = false; + + private static List<string> BOOLEANS_TRUE = new List<string>() { "y", "yes", "on", "1", "true", "t", "1.0" }; + private static List<string> BOOLEANS_FALSE = new List<string>() { "n", "no", "off", "0", "false", "f", "0.0" }; + + private string remoteTmp = Path.GetTempPath(); + private string tmpdir = null; + private HashSet<string> noLogValues = new HashSet<string>(); + private List<string> optionsContext = new List<string>(); + private List<string> warnings = new List<string>(); + private List<Dictionary<string, string>> deprecations = new List<Dictionary<string, string>>(); + private List<string> cleanupFiles = new List<string>(); + + private Dictionary<string, string> passVars = new Dictionary<string, string>() + { + // null values means no mapping, not used in Ansible.Basic.AnsibleModule + { "check_mode", "CheckMode" }, + { "debug", "DebugMode" }, + { "diff", "DiffMode" }, + { "keep_remote_files", "KeepRemoteFiles" }, + { "module_name", "ModuleName" }, + { "no_log", "NoLog" }, + { "remote_tmp", "remoteTmp" }, + { "selinux_special_fs", null }, + { "shell_executable", null }, + { "socket", null }, + { "string_conversion_action", null }, + { "syslog_facility", null }, + { "tmpdir", "tmpdir" }, + { "verbosity", "Verbosity" }, + { "version", "AnsibleVersion" }, + }; + private List<string> passBools = new List<string>() { "check_mode", "debug", "diff", "keep_remote_files", "no_log" }; + private List<string> passInts = new List<string>() { "verbosity" }; + private Dictionary<string, List<object>> specDefaults = new Dictionary<string, List<object>>() + { + // key - (default, type) - null is freeform + { "apply_defaults", new List<object>() { false, typeof(bool) } }, + { "aliases", new List<object>() { typeof(List<string>), typeof(List<string>) } }, + { "choices", new List<object>() { typeof(List<object>), typeof(List<object>) } }, + { "default", new List<object>() { null, null } }, + { "deprecated_aliases", new List<object>() { typeof(List<Hashtable>), typeof(List<Hashtable>) } }, + { "elements", new List<object>() { null, null } }, + { "mutually_exclusive", new List<object>() { typeof(List<List<string>>), typeof(List<object>) } }, + { "no_log", new List<object>() { false, typeof(bool) } }, + { "options", new List<object>() { typeof(Hashtable), typeof(Hashtable) } }, + { "removed_in_version", new List<object>() { null, typeof(string) } }, + { "removed_at_date", new List<object>() { null, typeof(DateTime) } }, + { "removed_from_collection", new List<object>() { null, typeof(string) } }, + { "required", new List<object>() { false, typeof(bool) } }, + { "required_by", new List<object>() { typeof(Hashtable), typeof(Hashtable) } }, + { "required_if", new List<object>() { typeof(List<List<object>>), typeof(List<object>) } }, + { "required_one_of", new List<object>() { typeof(List<List<string>>), typeof(List<object>) } }, + { "required_together", new List<object>() { typeof(List<List<string>>), typeof(List<object>) } }, + { "supports_check_mode", new List<object>() { false, typeof(bool) } }, + { "type", new List<object>() { "str", null } }, + }; + private Dictionary<string, Delegate> optionTypes = new Dictionary<string, Delegate>() + { + { "bool", new Func<object, bool>(ParseBool) }, + { "dict", new Func<object, Dictionary<string, object>>(ParseDict) }, + { "float", new Func<object, float>(ParseFloat) }, + { "int", new Func<object, int>(ParseInt) }, + { "json", new Func<object, string>(ParseJson) }, + { "list", new Func<object, List<object>>(ParseList) }, + { "path", new Func<object, string>(ParsePath) }, + { "raw", new Func<object, object>(ParseRaw) }, + { "sid", new Func<object, SecurityIdentifier>(ParseSid) }, + { "str", new Func<object, string>(ParseStr) }, + }; + + public Dictionary<string, object> Diff = new Dictionary<string, object>(); + public IDictionary Params = null; + public Dictionary<string, object> Result = new Dictionary<string, object>() { { "changed", false } }; + + public bool CheckMode { get; private set; } + public bool DebugMode { get; private set; } + public bool DiffMode { get; private set; } + public bool KeepRemoteFiles { get; private set; } + public string ModuleName { get; private set; } + public bool NoLog { get; private set; } + public int Verbosity { get; private set; } + public string AnsibleVersion { get; private set; } + + public string Tmpdir + { + get + { + if (tmpdir == null) + { +#if WINDOWS + SecurityIdentifier user = WindowsIdentity.GetCurrent().User; + DirectorySecurity dirSecurity = new DirectorySecurity(); + dirSecurity.SetOwner(user); + dirSecurity.SetAccessRuleProtection(true, false); // disable inheritance rules + FileSystemAccessRule ace = new FileSystemAccessRule(user, FileSystemRights.FullControl, + InheritanceFlags.ContainerInherit | InheritanceFlags.ObjectInherit, + PropagationFlags.None, AccessControlType.Allow); + dirSecurity.AddAccessRule(ace); + + string baseDir = Path.GetFullPath(Environment.ExpandEnvironmentVariables(remoteTmp)); + if (!Directory.Exists(baseDir)) + { + string failedMsg = null; + try + { +#if CORECLR + DirectoryInfo createdDir = Directory.CreateDirectory(baseDir); + FileSystemAclExtensions.SetAccessControl(createdDir, dirSecurity); +#else + Directory.CreateDirectory(baseDir, dirSecurity); +#endif + } + catch (Exception e) + { + failedMsg = String.Format("Failed to create base tmpdir '{0}': {1}", baseDir, e.Message); + } + + if (failedMsg != null) + { + string envTmp = Path.GetTempPath(); + Warn(String.Format("Unable to use '{0}' as temporary directory, falling back to system tmp '{1}': {2}", baseDir, envTmp, failedMsg)); + baseDir = envTmp; + } + else + { + NTAccount currentUser = (NTAccount)user.Translate(typeof(NTAccount)); + string warnMsg = String.Format("Module remote_tmp {0} did not exist and was created with FullControl to {1}, ", baseDir, currentUser.ToString()); + warnMsg += "this may cause issues when running as another user. To avoid this, create the remote_tmp dir with the correct permissions manually"; + Warn(warnMsg); + } + } + + string dateTime = DateTime.Now.ToFileTime().ToString(); + string dirName = String.Format("ansible-moduletmp-{0}-{1}", dateTime, new Random().Next(0, int.MaxValue)); + string newTmpdir = Path.Combine(baseDir, dirName); +#if CORECLR + DirectoryInfo tmpdirInfo = Directory.CreateDirectory(newTmpdir); + FileSystemAclExtensions.SetAccessControl(tmpdirInfo, dirSecurity); +#else + Directory.CreateDirectory(newTmpdir, dirSecurity); +#endif + tmpdir = newTmpdir; + + if (!KeepRemoteFiles) + cleanupFiles.Add(tmpdir); +#else + throw new NotImplementedException("Tmpdir is only supported on Windows"); +#endif + } + return tmpdir; + } + } + + public AnsibleModule(string[] args, IDictionary argumentSpec, IDictionary[] fragments = null) + { + // NoLog is not set yet, we cannot rely on FailJson to sanitize the output + // Do the minimum amount to get this running before we actually parse the params + Dictionary<string, string> aliases = new Dictionary<string, string>(); + try + { + ValidateArgumentSpec(argumentSpec); + + // Merge the fragments if present into the main arg spec. + if (fragments != null) + { + foreach (IDictionary fragment in fragments) + { + ValidateArgumentSpec(fragment); + MergeFragmentSpec(argumentSpec, fragment); + } + } + + // Used by ansible-test to retrieve the module argument spec, not designed for public use. + if (_DebugArgSpec) + { + // Cannot call exit here because it will be caught with the catch (Exception e) below. Instead + // just throw a new exception with a specific message and the exception block will handle it. + ScriptBlock.Create("Set-Variable -Name ansibleTestArgSpec -Value $args[0] -Scope Global" + ).Invoke(argumentSpec); + throw new Exception("ansible-test validate-modules check"); + } + + // Now make sure all the metadata keys are set to their defaults, this must be done after we've + // potentially output the arg spec for ansible-test. + SetArgumentSpecDefaults(argumentSpec); + + Params = GetParams(args); + aliases = GetAliases(argumentSpec, Params); + SetNoLogValues(argumentSpec, Params); + } + catch (Exception e) + { + if (e.Message == "ansible-test validate-modules check") + Exit(0); + + Dictionary<string, object> result = new Dictionary<string, object> + { + { "failed", true }, + { "msg", String.Format("internal error: {0}", e.Message) }, + { "exception", e.ToString() } + }; + WriteLine(ToJson(result)); + Exit(1); + } + + // Initialise public properties to the defaults before we parse the actual inputs + CheckMode = false; + DebugMode = false; + DiffMode = false; + KeepRemoteFiles = false; + ModuleName = "undefined win module"; + NoLog = (bool)argumentSpec["no_log"]; + Verbosity = 0; + AppDomain.CurrentDomain.ProcessExit += CleanupFiles; + + List<string> legalInputs = passVars.Keys.Select(v => "_ansible_" + v).ToList(); + legalInputs.AddRange(((IDictionary)argumentSpec["options"]).Keys.Cast<string>().ToList()); + legalInputs.AddRange(aliases.Keys.Cast<string>().ToList()); + CheckArguments(argumentSpec, Params, legalInputs); + + // Set a Ansible friendly invocation value in the result object + Dictionary<string, object> invocation = new Dictionary<string, object>() { { "module_args", Params } }; + Result["invocation"] = RemoveNoLogValues(invocation, noLogValues); + + if (!NoLog) + LogEvent(String.Format("Invoked with:\r\n {0}", FormatLogData(Params, 2)), sanitise: false); + } + + public static AnsibleModule Create(string[] args, IDictionary argumentSpec, IDictionary[] fragments = null) + { + return new AnsibleModule(args, argumentSpec, fragments); + } + + public void Debug(string message) + { + if (DebugMode) + LogEvent(String.Format("[DEBUG] {0}", message)); + } + + public void Deprecate(string message, string version) + { + Deprecate(message, version, null); + } + + public void Deprecate(string message, string version, string collectionName) + { + deprecations.Add(new Dictionary<string, string>() { + { "msg", message }, { "version", version }, { "collection_name", collectionName } }); + LogEvent(String.Format("[DEPRECATION WARNING] {0} {1}", message, version)); + } + + public void Deprecate(string message, DateTime date) + { + Deprecate(message, date, null); + } + + public void Deprecate(string message, DateTime date, string collectionName) + { + string isoDate = date.ToString("yyyy-MM-dd"); + deprecations.Add(new Dictionary<string, string>() { + { "msg", message }, { "date", isoDate }, { "collection_name", collectionName } }); + LogEvent(String.Format("[DEPRECATION WARNING] {0} {1}", message, isoDate)); + } + + public void ExitJson() + { + WriteLine(GetFormattedResults(Result)); + CleanupFiles(null, null); + Exit(0); + } + + public void FailJson(string message) { FailJson(message, null, null); } + public void FailJson(string message, ErrorRecord psErrorRecord) { FailJson(message, psErrorRecord, null); } + public void FailJson(string message, Exception exception) { FailJson(message, null, exception); } + private void FailJson(string message, ErrorRecord psErrorRecord, Exception exception) + { + Result["failed"] = true; + Result["msg"] = RemoveNoLogValues(message, noLogValues); + + + if (!Result.ContainsKey("exception") && (Verbosity > 2 || DebugMode)) + { + if (psErrorRecord != null) + { + string traceback = String.Format("{0}\r\n{1}", psErrorRecord.ToString(), psErrorRecord.InvocationInfo.PositionMessage); + traceback += String.Format("\r\n + CategoryInfo : {0}", psErrorRecord.CategoryInfo.ToString()); + traceback += String.Format("\r\n + FullyQualifiedErrorId : {0}", psErrorRecord.FullyQualifiedErrorId.ToString()); + traceback += String.Format("\r\n\r\nScriptStackTrace:\r\n{0}", psErrorRecord.ScriptStackTrace); + Result["exception"] = traceback; + } + else if (exception != null) + Result["exception"] = exception.ToString(); + } + + WriteLine(GetFormattedResults(Result)); + CleanupFiles(null, null); + Exit(1); + } + + public void LogEvent(string message, EventLogEntryType logEntryType = EventLogEntryType.Information, bool sanitise = true) + { + if (NoLog) + return; + +#if WINDOWS + string logSource = "Ansible"; + bool logSourceExists = false; + try + { + logSourceExists = EventLog.SourceExists(logSource); + } + catch (System.Security.SecurityException) { } // non admin users may not have permission + + if (!logSourceExists) + { + try + { + EventLog.CreateEventSource(logSource, "Application"); + } + catch (System.Security.SecurityException) + { + // Cannot call Warn as that calls LogEvent and we get stuck in a loop + warnings.Add(String.Format("Access error when creating EventLog source {0}, logging to the Application source instead", logSource)); + logSource = "Application"; + } + } + if (sanitise) + message = (string)RemoveNoLogValues(message, noLogValues); + message = String.Format("{0} - {1}", ModuleName, message); + + using (EventLog eventLog = new EventLog("Application")) + { + eventLog.Source = logSource; + try + { + eventLog.WriteEntry(message, logEntryType, 0); + } + catch (System.InvalidOperationException) { } // Ignore permission errors on the Application event log + catch (System.Exception e) + { + // Cannot call Warn as that calls LogEvent and we get stuck in a loop + warnings.Add(String.Format("Unknown error when creating event log entry: {0}", e.Message)); + } + } +#else + // Windows Event Log is only available on Windows + return; +#endif + } + + public void Warn(string message) + { + warnings.Add(message); + LogEvent(String.Format("[WARNING] {0}", message), EventLogEntryType.Warning); + } + + public static object FromJson(string json) { return FromJson<object>(json); } + public static T FromJson<T>(string json) + { +#if CORECLR + return JsonConvert.DeserializeObject<T>(json); +#else + JavaScriptSerializer jss = new JavaScriptSerializer(); + jss.MaxJsonLength = int.MaxValue; + jss.RecursionLimit = int.MaxValue; + return jss.Deserialize<T>(json); +#endif + } + + public static string ToJson(object obj) + { + // Using PowerShell to serialize the JSON is preferable over the native .NET libraries as it handles + // PS Objects a lot better than the alternatives. In case we are debugging in Visual Studio we have a + // fallback to the other libraries as we won't be dealing with PowerShell objects there. + if (Runspace.DefaultRunspace != null) + { + PSObject rawOut = ScriptBlock.Create("ConvertTo-Json -InputObject $args[0] -Depth 99 -Compress").Invoke(obj)[0]; + return rawOut.BaseObject as string; + } + else + { +#if CORECLR + return JsonConvert.SerializeObject(obj); +#else + JavaScriptSerializer jss = new JavaScriptSerializer(); + jss.MaxJsonLength = int.MaxValue; + jss.RecursionLimit = int.MaxValue; + return jss.Serialize(obj); +#endif + } + } + + public static IDictionary GetParams(string[] args) + { + if (args.Length > 0) + { + string inputJson = File.ReadAllText(args[0]); + Dictionary<string, object> rawParams = FromJson<Dictionary<string, object>>(inputJson); + if (!rawParams.ContainsKey("ANSIBLE_MODULE_ARGS")) + throw new ArgumentException("Module was unable to get ANSIBLE_MODULE_ARGS value from the argument path json"); + return (IDictionary)rawParams["ANSIBLE_MODULE_ARGS"]; + } + else + { + // $complex_args is already a Hashtable, no need to waste time converting to a dictionary + PSObject rawArgs = ScriptBlock.Create("$complex_args").Invoke()[0]; + return rawArgs.BaseObject as Hashtable; + } + } + + public static bool ParseBool(object value) + { + if (value.GetType() == typeof(bool)) + return (bool)value; + + List<string> booleans = new List<string>(); + booleans.AddRange(BOOLEANS_TRUE); + booleans.AddRange(BOOLEANS_FALSE); + + string stringValue = ParseStr(value).ToLowerInvariant().Trim(); + if (BOOLEANS_TRUE.Contains(stringValue)) + return true; + else if (BOOLEANS_FALSE.Contains(stringValue)) + return false; + + string msg = String.Format("The value '{0}' is not a valid boolean. Valid booleans include: {1}", + stringValue, String.Join(", ", booleans)); + throw new ArgumentException(msg); + } + + public static Dictionary<string, object> ParseDict(object value) + { + Type valueType = value.GetType(); + if (valueType == typeof(Dictionary<string, object>)) + return (Dictionary<string, object>)value; + else if (value is IDictionary) + return ((IDictionary)value).Cast<DictionaryEntry>().ToDictionary(kvp => (string)kvp.Key, kvp => kvp.Value); + else if (valueType == typeof(string)) + { + string stringValue = (string)value; + if (stringValue.StartsWith("{") && stringValue.EndsWith("}")) + return FromJson<Dictionary<string, object>>((string)value); + else if (stringValue.IndexOfAny(new char[1] { '=' }) != -1) + { + List<string> fields = new List<string>(); + List<char> fieldBuffer = new List<char>(); + char? inQuote = null; + bool inEscape = false; + string field; + + foreach (char c in stringValue.ToCharArray()) + { + if (inEscape) + { + fieldBuffer.Add(c); + inEscape = false; + } + else if (c == '\\') + inEscape = true; + else if (inQuote == null && (c == '\'' || c == '"')) + inQuote = c; + else if (inQuote != null && c == inQuote) + inQuote = null; + else if (inQuote == null && (c == ',' || c == ' ')) + { + field = String.Join("", fieldBuffer); + if (field != "") + fields.Add(field); + fieldBuffer = new List<char>(); + } + else + fieldBuffer.Add(c); + } + + field = String.Join("", fieldBuffer); + if (field != "") + fields.Add(field); + + return fields.Distinct().Select(i => i.Split(new[] { '=' }, 2)).ToDictionary(i => i[0], i => i.Length > 1 ? (object)i[1] : null); + } + else + throw new ArgumentException("string cannot be converted to a dict, must either be a JSON string or in the key=value form"); + } + + throw new ArgumentException(String.Format("{0} cannot be converted to a dict", valueType.FullName)); + } + + public static float ParseFloat(object value) + { + if (value.GetType() == typeof(float)) + return (float)value; + + string valueStr = ParseStr(value); + return float.Parse(valueStr); + } + + public static int ParseInt(object value) + { + Type valueType = value.GetType(); + if (valueType == typeof(int)) + return (int)value; + else + return Int32.Parse(ParseStr(value)); + } + + public static string ParseJson(object value) + { + // mostly used to ensure a dict is a json string as it may + // have been converted on the controller side + Type valueType = value.GetType(); + if (value is IDictionary) + return ToJson(value); + else if (valueType == typeof(string)) + return (string)value; + else + throw new ArgumentException(String.Format("{0} cannot be converted to json", valueType.FullName)); + } + + public static List<object> ParseList(object value) + { + if (value == null) + return null; + + Type valueType = value.GetType(); + if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(List<>)) + return (List<object>)value; + else if (valueType == typeof(ArrayList)) + return ((ArrayList)value).Cast<object>().ToList(); + else if (valueType.IsArray) + return ((object[])value).ToList(); + else if (valueType == typeof(string)) + return ((string)value).Split(',').Select(s => s.Trim()).ToList<object>(); + else if (valueType == typeof(int)) + return new List<object>() { value }; + else + throw new ArgumentException(String.Format("{0} cannot be converted to a list", valueType.FullName)); + } + + public static string ParsePath(object value) + { + string stringValue = ParseStr(value); + + // do not validate, expand the env vars if it starts with \\?\ as + // it is a special path designed for the NT kernel to interpret + if (stringValue.StartsWith(@"\\?\")) + return stringValue; + + stringValue = Environment.ExpandEnvironmentVariables(stringValue); + if (stringValue.IndexOfAny(Path.GetInvalidPathChars()) != -1) + throw new ArgumentException("string value contains invalid path characters, cannot convert to path"); + + // will fire an exception if it contains any invalid chars + Path.GetFullPath(stringValue); + return stringValue; + } + + public static object ParseRaw(object value) { return value; } + + public static SecurityIdentifier ParseSid(object value) + { + string stringValue = ParseStr(value); + + try + { + return new SecurityIdentifier(stringValue); + } + catch (ArgumentException) { } // ignore failures string may not have been a SID + + NTAccount account = new NTAccount(stringValue); + return (SecurityIdentifier)account.Translate(typeof(SecurityIdentifier)); + } + + public static string ParseStr(object value) { return value.ToString(); } + + private void ValidateArgumentSpec(IDictionary argumentSpec) + { + Dictionary<string, object> changedValues = new Dictionary<string, object>(); + foreach (DictionaryEntry entry in argumentSpec) + { + string key = (string)entry.Key; + + // validate the key is a valid argument spec key + if (!specDefaults.ContainsKey(key)) + { + string msg = String.Format("argument spec entry contains an invalid key '{0}', valid keys: {1}", + key, String.Join(", ", specDefaults.Keys)); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + + // ensure the value is casted to the type we expect + Type optionType = null; + if (entry.Value != null) + optionType = (Type)specDefaults[key][1]; + if (optionType != null) + { + Type actualType = entry.Value.GetType(); + bool invalid = false; + if (optionType.IsGenericType && optionType.GetGenericTypeDefinition() == typeof(List<>)) + { + // verify the actual type is not just a single value of the list type + Type entryType = optionType.GetGenericArguments()[0]; + object[] arrayElementTypes = new object[] + { + null, // ArrayList does not have an ElementType + entryType, + typeof(object), // Hope the object is actually entryType or it can at least be casted. + }; + + bool isArray = entry.Value is IList && arrayElementTypes.Contains(actualType.GetElementType()); + if (actualType == entryType || isArray) + { + object rawArray; + if (isArray) + rawArray = entry.Value; + else + rawArray = new object[1] { entry.Value }; + + MethodInfo castMethod = typeof(Enumerable).GetMethod("Cast").MakeGenericMethod(entryType); + MethodInfo toListMethod = typeof(Enumerable).GetMethod("ToList").MakeGenericMethod(entryType); + + var enumerable = castMethod.Invoke(null, new object[1] { rawArray }); + var newList = toListMethod.Invoke(null, new object[1] { enumerable }); + changedValues.Add(key, newList); + } + else if (actualType != optionType && !(actualType == typeof(List<object>))) + invalid = true; + } + else + invalid = actualType != optionType; + + if (invalid) + { + string msg = String.Format("argument spec for '{0}' did not match expected type {1}: actual type {2}", + key, optionType.FullName, actualType.FullName); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + } + + // recursively validate the spec + if (key == "options" && entry.Value != null) + { + IDictionary optionsSpec = (IDictionary)entry.Value; + foreach (DictionaryEntry optionEntry in optionsSpec) + { + optionsContext.Add((string)optionEntry.Key); + IDictionary optionMeta = (IDictionary)optionEntry.Value; + ValidateArgumentSpec(optionMeta); + optionsContext.RemoveAt(optionsContext.Count - 1); + } + } + + // validate the type and elements key type values are known types + if (key == "type" || key == "elements" && entry.Value != null) + { + Type valueType = entry.Value.GetType(); + if (valueType == typeof(string)) + { + string typeValue = (string)entry.Value; + if (!optionTypes.ContainsKey(typeValue)) + { + string msg = String.Format("{0} '{1}' is unsupported", key, typeValue); + msg = String.Format("{0}. Valid types are: {1}", FormatOptionsContext(msg, " - "), String.Join(", ", optionTypes.Keys)); + throw new ArgumentException(msg); + } + } + else if (!(entry.Value is Delegate)) + { + string msg = String.Format("{0} must either be a string or delegate, was: {1}", key, valueType.FullName); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + } + } + + // Outside of the spec iterator, change the values that were casted above + foreach (KeyValuePair<string, object> changedValue in changedValues) + argumentSpec[changedValue.Key] = changedValue.Value; + } + + private void MergeFragmentSpec(IDictionary argumentSpec, IDictionary fragment) + { + foreach (DictionaryEntry fragmentEntry in fragment) + { + string fragmentKey = fragmentEntry.Key.ToString(); + + if (argumentSpec.Contains(fragmentKey)) + { + // We only want to add new list entries and merge dictionary new keys and values. Leave the other + // values as is in the argument spec as that takes priority over the fragment. + if (fragmentEntry.Value is IDictionary) + { + MergeFragmentSpec((IDictionary)argumentSpec[fragmentKey], (IDictionary)fragmentEntry.Value); + } + else if (fragmentEntry.Value is IList) + { + IList specValue = (IList)argumentSpec[fragmentKey]; + foreach (object fragmentValue in (IList)fragmentEntry.Value) + specValue.Add(fragmentValue); + } + } + else + argumentSpec[fragmentKey] = fragmentEntry.Value; + } + } + + private void SetArgumentSpecDefaults(IDictionary argumentSpec) + { + foreach (KeyValuePair<string, List<object>> metadataEntry in specDefaults) + { + List<object> defaults = metadataEntry.Value; + object defaultValue = defaults[0]; + if (defaultValue != null && defaultValue.GetType() == typeof(Type).GetType()) + defaultValue = Activator.CreateInstance((Type)defaultValue); + + if (!argumentSpec.Contains(metadataEntry.Key)) + argumentSpec[metadataEntry.Key] = defaultValue; + } + + // Recursively set the defaults for any inner options. + foreach (DictionaryEntry entry in argumentSpec) + { + if (entry.Value == null || entry.Key.ToString() != "options") + continue; + + IDictionary optionsSpec = (IDictionary)entry.Value; + foreach (DictionaryEntry optionEntry in optionsSpec) + { + optionsContext.Add((string)optionEntry.Key); + IDictionary optionMeta = (IDictionary)optionEntry.Value; + SetArgumentSpecDefaults(optionMeta); + optionsContext.RemoveAt(optionsContext.Count - 1); + } + } + } + + private Dictionary<string, string> GetAliases(IDictionary argumentSpec, IDictionary parameters) + { + Dictionary<string, string> aliasResults = new Dictionary<string, string>(); + + foreach (DictionaryEntry entry in (IDictionary)argumentSpec["options"]) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + List<string> aliases = (List<string>)v["aliases"]; + object defaultValue = v["default"]; + bool required = (bool)v["required"]; + + if (defaultValue != null && required) + throw new ArgumentException(String.Format("required and default are mutually exclusive for {0}", k)); + + foreach (string alias in aliases) + { + aliasResults.Add(alias, k); + if (parameters.Contains(alias)) + parameters[k] = parameters[alias]; + } + + List<Hashtable> deprecatedAliases = (List<Hashtable>)v["deprecated_aliases"]; + foreach (Hashtable depInfo in deprecatedAliases) + { + foreach (string keyName in new List<string> { "name" }) + { + if (!depInfo.ContainsKey(keyName)) + { + string msg = String.Format("{0} is required in a deprecated_aliases entry", keyName); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + } + if (!depInfo.ContainsKey("version") && !depInfo.ContainsKey("date")) + { + string msg = "One of version or date is required in a deprecated_aliases entry"; + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + if (depInfo.ContainsKey("version") && depInfo.ContainsKey("date")) + { + string msg = "Only one of version or date is allowed in a deprecated_aliases entry"; + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + if (depInfo.ContainsKey("date") && depInfo["date"].GetType() != typeof(DateTime)) + { + string msg = "A deprecated_aliases date must be a DateTime object"; + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + string collectionName = null; + if (depInfo.ContainsKey("collection_name")) + { + collectionName = (string)depInfo["collection_name"]; + } + string aliasName = (string)depInfo["name"]; + + if (parameters.Contains(aliasName)) + { + string msg = String.Format("Alias '{0}' is deprecated. See the module docs for more information", aliasName); + if (depInfo.ContainsKey("version")) + { + string depVersion = (string)depInfo["version"]; + Deprecate(FormatOptionsContext(msg, " - "), depVersion, collectionName); + } + if (depInfo.ContainsKey("date")) + { + DateTime depDate = (DateTime)depInfo["date"]; + Deprecate(FormatOptionsContext(msg, " - "), depDate, collectionName); + } + } + } + } + + return aliasResults; + } + + private void SetNoLogValues(IDictionary argumentSpec, IDictionary parameters) + { + foreach (DictionaryEntry entry in (IDictionary)argumentSpec["options"]) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + if ((bool)v["no_log"]) + { + object noLogObject = parameters.Contains(k) ? parameters[k] : null; + string noLogString = noLogObject == null ? "" : noLogObject.ToString(); + if (!String.IsNullOrEmpty(noLogString)) + noLogValues.Add(noLogString); + } + string collectionName = null; + if (v.ContainsKey("removed_from_collection")) + { + collectionName = (string)v["removed_from_collection"]; + } + + object removedInVersion = v["removed_in_version"]; + if (removedInVersion != null && parameters.Contains(k)) + Deprecate(String.Format("Param '{0}' is deprecated. See the module docs for more information", k), + removedInVersion.ToString(), collectionName); + + object removedAtDate = v["removed_at_date"]; + if (removedAtDate != null && parameters.Contains(k)) + Deprecate(String.Format("Param '{0}' is deprecated. See the module docs for more information", k), + (DateTime)removedAtDate, collectionName); + } + } + + private void CheckArguments(IDictionary spec, IDictionary param, List<string> legalInputs) + { + // initially parse the params and check for unsupported ones and set internal vars + CheckUnsupportedArguments(param, legalInputs); + + // Only run this check if we are at the root argument (optionsContext.Count == 0) + if (CheckMode && !(bool)spec["supports_check_mode"] && optionsContext.Count == 0) + { + Result["skipped"] = true; + Result["msg"] = String.Format("remote module ({0}) does not support check mode", ModuleName); + ExitJson(); + } + IDictionary optionSpec = (IDictionary)spec["options"]; + + CheckMutuallyExclusive(param, (IList)spec["mutually_exclusive"]); + CheckRequiredArguments(optionSpec, param); + + // set the parameter types based on the type spec value + foreach (DictionaryEntry entry in optionSpec) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + object value = param.Contains(k) ? param[k] : null; + if (value != null) + { + // convert the current value to the wanted type + Delegate typeConverter; + string type; + if (v["type"].GetType() == typeof(string)) + { + type = (string)v["type"]; + typeConverter = optionTypes[type]; + } + else + { + type = "delegate"; + typeConverter = (Delegate)v["type"]; + } + + try + { + value = typeConverter.DynamicInvoke(value); + param[k] = value; + } + catch (Exception e) + { + string msg = String.Format("argument for {0} is of type {1} and we were unable to convert to {2}: {3}", + k, value.GetType(), type, e.InnerException.Message); + FailJson(FormatOptionsContext(msg)); + } + + // ensure it matches the choices if there are choices set + List<string> choices = ((List<object>)v["choices"]).Select(x => x.ToString()).Cast<string>().ToList(); + if (choices.Count > 0) + { + List<string> values; + string choiceMsg; + if (type == "list") + { + values = ((List<object>)value).Select(x => x.ToString()).Cast<string>().ToList(); + choiceMsg = "one or more of"; + } + else + { + values = new List<string>() { value.ToString() }; + choiceMsg = "one of"; + } + + List<string> diffList = values.Except(choices, StringComparer.OrdinalIgnoreCase).ToList(); + List<string> caseDiffList = values.Except(choices).ToList(); + if (diffList.Count > 0) + { + string msg = String.Format("value of {0} must be {1}: {2}. Got no match for: {3}", + k, choiceMsg, String.Join(", ", choices), String.Join(", ", diffList)); + FailJson(FormatOptionsContext(msg)); + } + /* + For now we will just silently accept case insensitive choices, uncomment this if we want to add it back in + else if (caseDiffList.Count > 0) + { + // For backwards compatibility with Legacy.psm1 we need to be matching choices that are not case sensitive. + // We will warn the user it was case insensitive and tell them this will become case sensitive in the future. + string msg = String.Format( + "value of {0} was a case insensitive match of {1}: {2}. Checking of choices will be case sensitive in a future Ansible release. Case insensitive matches were: {3}", + k, choiceMsg, String.Join(", ", choices), String.Join(", ", caseDiffList.Select(x => RemoveNoLogValues(x, noLogValues))) + ); + Warn(FormatOptionsContext(msg)); + }*/ + } + } + } + + CheckRequiredTogether(param, (IList)spec["required_together"]); + CheckRequiredOneOf(param, (IList)spec["required_one_of"]); + CheckRequiredIf(param, (IList)spec["required_if"]); + CheckRequiredBy(param, (IDictionary)spec["required_by"]); + + // finally ensure all missing parameters are set to null and handle sub options + foreach (DictionaryEntry entry in optionSpec) + { + string k = (string)entry.Key; + IDictionary v = (IDictionary)entry.Value; + + if (!param.Contains(k)) + param[k] = null; + + CheckSubOption(param, k, v); + } + } + + private void CheckUnsupportedArguments(IDictionary param, List<string> legalInputs) + { + HashSet<string> unsupportedParameters = new HashSet<string>(); + HashSet<string> caseUnsupportedParameters = new HashSet<string>(); + List<string> removedParameters = new List<string>(); + + foreach (DictionaryEntry entry in param) + { + string paramKey = (string)entry.Key; + if (!legalInputs.Contains(paramKey, StringComparer.OrdinalIgnoreCase)) + unsupportedParameters.Add(paramKey); + else if (!legalInputs.Contains(paramKey)) + // For backwards compatibility we do not care about the case but we need to warn the users as this will + // change in a future Ansible release. + caseUnsupportedParameters.Add(paramKey); + else if (paramKey.StartsWith("_ansible_")) + { + removedParameters.Add(paramKey); + string key = paramKey.Replace("_ansible_", ""); + // skip setting NoLog if NoLog is already set to true (set by the module) + // or there's no mapping for this key + if ((key == "no_log" && NoLog == true) || (passVars[key] == null)) + continue; + + object value = entry.Value; + if (passBools.Contains(key)) + value = ParseBool(value); + else if (passInts.Contains(key)) + value = ParseInt(value); + + string propertyName = passVars[key]; + PropertyInfo property = typeof(AnsibleModule).GetProperty(propertyName); + FieldInfo field = typeof(AnsibleModule).GetField(propertyName, BindingFlags.NonPublic | BindingFlags.Instance); + if (property != null) + property.SetValue(this, value, null); + else if (field != null) + field.SetValue(this, value); + else + FailJson(String.Format("implementation error: unknown AnsibleModule property {0}", propertyName)); + } + } + foreach (string parameter in removedParameters) + param.Remove(parameter); + + if (unsupportedParameters.Count > 0) + { + legalInputs.RemoveAll(x => passVars.Keys.Contains(x.Replace("_ansible_", ""))); + string msg = String.Format("Unsupported parameters for ({0}) module: {1}", ModuleName, String.Join(", ", unsupportedParameters)); + msg = String.Format("{0}. Supported parameters include: {1}", FormatOptionsContext(msg), String.Join(", ", legalInputs)); + FailJson(msg); + } + + /* + // Uncomment when we want to start warning users around options that are not a case sensitive match to the spec + if (caseUnsupportedParameters.Count > 0) + { + legalInputs.RemoveAll(x => passVars.Keys.Contains(x.Replace("_ansible_", ""))); + string msg = String.Format("Parameters for ({0}) was a case insensitive match: {1}", ModuleName, String.Join(", ", caseUnsupportedParameters)); + msg = String.Format("{0}. Module options will become case sensitive in a future Ansible release. Supported parameters include: {1}", + FormatOptionsContext(msg), String.Join(", ", legalInputs)); + Warn(msg); + }*/ + + // Make sure we convert all the incorrect case params to the ones set by the module spec + foreach (string key in caseUnsupportedParameters) + { + string correctKey = legalInputs[legalInputs.FindIndex(s => s.Equals(key, StringComparison.OrdinalIgnoreCase))]; + object value = param[key]; + param.Remove(key); + param.Add(correctKey, value); + } + } + + private void CheckMutuallyExclusive(IDictionary param, IList mutuallyExclusive) + { + if (mutuallyExclusive == null) + return; + + foreach (object check in mutuallyExclusive) + { + List<string> mutualCheck = ((IList)check).Cast<string>().ToList(); + int count = 0; + foreach (string entry in mutualCheck) + if (param.Contains(entry)) + count++; + + if (count > 1) + { + string msg = String.Format("parameters are mutually exclusive: {0}", String.Join(", ", mutualCheck)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredArguments(IDictionary spec, IDictionary param) + { + List<string> missing = new List<string>(); + foreach (DictionaryEntry entry in spec) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + // set defaults for values not already set + object defaultValue = v["default"]; + if (defaultValue != null && !param.Contains(k)) + param[k] = defaultValue; + + // check required arguments + bool required = (bool)v["required"]; + if (required && !param.Contains(k)) + missing.Add(k); + } + if (missing.Count > 0) + { + string msg = String.Format("missing required arguments: {0}", String.Join(", ", missing)); + FailJson(FormatOptionsContext(msg)); + } + } + + private void CheckRequiredTogether(IDictionary param, IList requiredTogether) + { + if (requiredTogether == null) + return; + + foreach (object check in requiredTogether) + { + List<string> requiredCheck = ((IList)check).Cast<string>().ToList(); + List<bool> found = new List<bool>(); + foreach (string field in requiredCheck) + if (param.Contains(field)) + found.Add(true); + else + found.Add(false); + + if (found.Contains(true) && found.Contains(false)) + { + string msg = String.Format("parameters are required together: {0}", String.Join(", ", requiredCheck)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredOneOf(IDictionary param, IList requiredOneOf) + { + if (requiredOneOf == null) + return; + + foreach (object check in requiredOneOf) + { + List<string> requiredCheck = ((IList)check).Cast<string>().ToList(); + int count = 0; + foreach (string field in requiredCheck) + if (param.Contains(field)) + count++; + + if (count == 0) + { + string msg = String.Format("one of the following is required: {0}", String.Join(", ", requiredCheck)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredIf(IDictionary param, IList requiredIf) + { + if (requiredIf == null) + return; + + foreach (object check in requiredIf) + { + IList requiredCheck = (IList)check; + List<string> missing = new List<string>(); + List<string> missingFields = new List<string>(); + int maxMissingCount = 1; + bool oneRequired = false; + + if (requiredCheck.Count < 3 && requiredCheck.Count < 4) + FailJson(String.Format("internal error: invalid required_if value count of {0}, expecting 3 or 4 entries", requiredCheck.Count)); + else if (requiredCheck.Count == 4) + oneRequired = (bool)requiredCheck[3]; + + string key = (string)requiredCheck[0]; + object val = requiredCheck[1]; + IList requirements = (IList)requiredCheck[2]; + + if (ParseStr(param[key]) != ParseStr(val)) + continue; + + string term = "all"; + if (oneRequired) + { + maxMissingCount = requirements.Count; + term = "any"; + } + + foreach (string required in requirements.Cast<string>()) + if (!param.Contains(required)) + missing.Add(required); + + if (missing.Count >= maxMissingCount) + { + string msg = String.Format("{0} is {1} but {2} of the following are missing: {3}", + key, val.ToString(), term, String.Join(", ", missing)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredBy(IDictionary param, IDictionary requiredBy) + { + foreach (DictionaryEntry entry in requiredBy) + { + string key = (string)entry.Key; + if (!param.Contains(key)) + continue; + + List<string> missing = new List<string>(); + List<string> requires = ParseList(entry.Value).Cast<string>().ToList(); + foreach (string required in requires) + if (!param.Contains(required)) + missing.Add(required); + + if (missing.Count > 0) + { + string msg = String.Format("missing parameter(s) required by '{0}': {1}", key, String.Join(", ", missing)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckSubOption(IDictionary param, string key, IDictionary spec) + { + object value = param[key]; + + string type; + if (spec["type"].GetType() == typeof(string)) + type = (string)spec["type"]; + else + type = "delegate"; + + string elements = null; + Delegate typeConverter = null; + if (spec["elements"] != null && spec["elements"].GetType() == typeof(string)) + { + elements = (string)spec["elements"]; + typeConverter = optionTypes[elements]; + } + else if (spec["elements"] != null) + { + elements = "delegate"; + typeConverter = (Delegate)spec["elements"]; + } + + if (!(type == "dict" || (type == "list" && elements != null))) + // either not a dict, or list with the elements set, so continue + return; + else if (type == "list") + { + // cast each list element to the type specified + if (value == null) + return; + + List<object> newValue = new List<object>(); + foreach (object element in (List<object>)value) + { + if (elements == "dict") + newValue.Add(ParseSubSpec(spec, element, key)); + else + { + try + { + object newElement = typeConverter.DynamicInvoke(element); + newValue.Add(newElement); + } + catch (Exception e) + { + string msg = String.Format("argument for list entry {0} is of type {1} and we were unable to convert to {2}: {3}", + key, element.GetType(), elements, e.Message); + FailJson(FormatOptionsContext(msg)); + } + } + } + + param[key] = newValue; + } + else + param[key] = ParseSubSpec(spec, value, key); + } + + private object ParseSubSpec(IDictionary spec, object value, string context) + { + bool applyDefaults = (bool)spec["apply_defaults"]; + + // set entry to an empty dict if apply_defaults is set + IDictionary optionsSpec = (IDictionary)spec["options"]; + if (applyDefaults && optionsSpec.Keys.Count > 0 && value == null) + value = new Dictionary<string, object>(); + else if (optionsSpec.Keys.Count == 0 || value == null) + return value; + + optionsContext.Add(context); + Dictionary<string, object> newValue = (Dictionary<string, object>)ParseDict(value); + Dictionary<string, string> aliases = GetAliases(spec, newValue); + SetNoLogValues(spec, newValue); + + List<string> subLegalInputs = optionsSpec.Keys.Cast<string>().ToList(); + subLegalInputs.AddRange(aliases.Keys.Cast<string>().ToList()); + + CheckArguments(spec, newValue, subLegalInputs); + optionsContext.RemoveAt(optionsContext.Count - 1); + return newValue; + } + + private string GetFormattedResults(Dictionary<string, object> result) + { + if (!result.ContainsKey("invocation")) + result["invocation"] = new Dictionary<string, object>() { { "module_args", RemoveNoLogValues(Params, noLogValues) } }; + + if (warnings.Count > 0) + result["warnings"] = warnings; + + if (deprecations.Count > 0) + result["deprecations"] = deprecations; + + if (Diff.Count > 0 && DiffMode) + result["diff"] = Diff; + + return ToJson(result); + } + + private string FormatLogData(object data, int indentLevel) + { + if (data == null) + return "$null"; + + string msg = ""; + if (data is IList) + { + string newMsg = ""; + foreach (object value in (IList)data) + { + string entryValue = FormatLogData(value, indentLevel + 2); + newMsg += String.Format("\r\n{0}- {1}", new String(' ', indentLevel), entryValue); + } + msg += newMsg; + } + else if (data is IDictionary) + { + bool start = true; + foreach (DictionaryEntry entry in (IDictionary)data) + { + string newMsg = FormatLogData(entry.Value, indentLevel + 2); + if (!start) + msg += String.Format("\r\n{0}", new String(' ', indentLevel)); + msg += String.Format("{0}: {1}", (string)entry.Key, newMsg); + start = false; + } + } + else + msg = (string)RemoveNoLogValues(ParseStr(data), noLogValues); + + return msg; + } + + private object RemoveNoLogValues(object value, HashSet<string> noLogStrings) + { + Queue<Tuple<object, object>> deferredRemovals = new Queue<Tuple<object, object>>(); + object newValue = RemoveValueConditions(value, noLogStrings, deferredRemovals); + + while (deferredRemovals.Count > 0) + { + Tuple<object, object> data = deferredRemovals.Dequeue(); + object oldData = data.Item1; + object newData = data.Item2; + + if (oldData is IDictionary) + { + foreach (DictionaryEntry entry in (IDictionary)oldData) + { + object newElement = RemoveValueConditions(entry.Value, noLogStrings, deferredRemovals); + ((IDictionary)newData).Add((string)entry.Key, newElement); + } + } + else + { + foreach (object element in (IList)oldData) + { + object newElement = RemoveValueConditions(element, noLogStrings, deferredRemovals); + ((IList)newData).Add(newElement); + } + } + } + + return newValue; + } + + private object RemoveValueConditions(object value, HashSet<string> noLogStrings, Queue<Tuple<object, object>> deferredRemovals) + { + if (value == null) + return value; + + Type valueType = value.GetType(); + HashSet<Type> numericTypes = new HashSet<Type> + { + typeof(byte), typeof(sbyte), typeof(short), typeof(ushort), typeof(int), typeof(uint), + typeof(long), typeof(ulong), typeof(decimal), typeof(double), typeof(float) + }; + + if (numericTypes.Contains(valueType) || valueType == typeof(bool)) + { + string valueString = ParseStr(value); + if (noLogStrings.Contains(valueString)) + return "VALUE_SPECIFIED_IN_NO_LOG_PARAMETER"; + foreach (string omitMe in noLogStrings) + if (valueString.Contains(omitMe)) + return "VALUE_SPECIFIED_IN_NO_LOG_PARAMETER"; + } + else if (valueType == typeof(DateTime)) + value = ((DateTime)value).ToString("o"); + else if (value is IList) + { + List<object> newValue = new List<object>(); + deferredRemovals.Enqueue(new Tuple<object, object>((IList)value, newValue)); + value = newValue; + } + else if (value is IDictionary) + { + Hashtable newValue = new Hashtable(); + deferredRemovals.Enqueue(new Tuple<object, object>((IDictionary)value, newValue)); + value = newValue; + } + else + { + string stringValue = value.ToString(); + if (noLogStrings.Contains(stringValue)) + return "VALUE_SPECIFIED_IN_NO_LOG_PARAMETER"; + foreach (string omitMe in noLogStrings) + if (stringValue.Contains(omitMe)) + return (stringValue).Replace(omitMe, "********"); + value = stringValue; + } + return value; + } + + private void CleanupFiles(object s, EventArgs ev) + { + foreach (string path in cleanupFiles) + { + if (File.Exists(path)) + File.Delete(path); + else if (Directory.Exists(path)) + Directory.Delete(path, true); + } + cleanupFiles = new List<string>(); + } + + private string FormatOptionsContext(string msg, string prefix = " ") + { + if (optionsContext.Count > 0) + msg += String.Format("{0}found in {1}", prefix, String.Join(" -> ", optionsContext)); + return msg; + } + + [DllImport("kernel32.dll")] + private static extern IntPtr GetConsoleWindow(); + + private static void ExitModule(int rc) + { + // When running in a Runspace Environment.Exit will kill the entire + // process which is not what we want, detect if we are in a + // Runspace and call a ScriptBlock with exit instead. + if (Runspace.DefaultRunspace != null) + ScriptBlock.Create("Set-Variable -Name LASTEXITCODE -Value $args[0] -Scope Global; exit $args[0]").Invoke(rc); + else + { + // Used for local debugging in Visual Studio + if (System.Diagnostics.Debugger.IsAttached) + { + Console.WriteLine("Press enter to continue..."); + Console.ReadLine(); + } + Environment.Exit(rc); + } + } + + private static void WriteLineModule(string line) + { + Console.WriteLine(line); + } + } +} diff --git a/lib/ansible/module_utils/csharp/Ansible.Become.cs b/lib/ansible/module_utils/csharp/Ansible.Become.cs new file mode 100644 index 0000000..a6f645c --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Become.cs @@ -0,0 +1,655 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.AccessControl; +using System.Security.Principal; +using System.Text; +using Ansible.AccessToken; +using Ansible.Process; + +namespace Ansible.Become +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct KERB_S4U_LOGON + { + public UInt32 MessageType; + public UInt32 Flags; + public LSA_UNICODE_STRING ClientUpn; + public LSA_UNICODE_STRING ClientRealm; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + public struct LSA_STRING + { + public UInt16 Length; + public UInt16 MaximumLength; + [MarshalAs(UnmanagedType.LPStr)] public string Buffer; + + public static implicit operator string(LSA_STRING s) + { + return s.Buffer; + } + + public static implicit operator LSA_STRING(string s) + { + if (s == null) + s = ""; + + LSA_STRING lsaStr = new LSA_STRING + { + Buffer = s, + Length = (UInt16)s.Length, + MaximumLength = (UInt16)(s.Length + 1), + }; + return lsaStr; + } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct LSA_UNICODE_STRING + { + public UInt16 Length; + public UInt16 MaximumLength; + public IntPtr Buffer; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SECURITY_LOGON_SESSION_DATA + { + public UInt32 Size; + public Luid LogonId; + public LSA_UNICODE_STRING UserName; + public LSA_UNICODE_STRING LogonDomain; + public LSA_UNICODE_STRING AuthenticationPackage; + public SECURITY_LOGON_TYPE LogonType; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_SOURCE + { + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 8)] public char[] SourceName; + public Luid SourceIdentifier; + } + + public enum SECURITY_LOGON_TYPE + { + System = 0, // Used only by the System account + Interactive = 2, + Network, + Batch, + Service, + Proxy, + Unlock, + NetworkCleartext, + NewCredentials, + RemoteInteractive, + CachedInteractive, + CachedRemoteInteractive, + CachedUnlock + } + } + + internal class NativeMethods + { + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool AllocateLocallyUniqueId( + out Luid Luid); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool CreateProcessWithTokenW( + SafeNativeHandle hToken, + LogonFlags dwLogonFlags, + [MarshalAs(UnmanagedType.LPWStr)] string lpApplicationName, + StringBuilder lpCommandLine, + Process.NativeHelpers.ProcessCreationFlags dwCreationFlags, + Process.SafeMemoryBuffer lpEnvironment, + [MarshalAs(UnmanagedType.LPWStr)] string lpCurrentDirectory, + Process.NativeHelpers.STARTUPINFOEX lpStartupInfo, + out Process.NativeHelpers.PROCESS_INFORMATION lpProcessInformation); + + [DllImport("kernel32.dll")] + public static extern UInt32 GetCurrentThreadId(); + + [DllImport("user32.dll", SetLastError = true)] + public static extern NoopSafeHandle GetProcessWindowStation(); + + [DllImport("user32.dll", SetLastError = true)] + public static extern NoopSafeHandle GetThreadDesktop( + UInt32 dwThreadId); + + [DllImport("secur32.dll", SetLastError = true)] + public static extern UInt32 LsaDeregisterLogonProcess( + IntPtr LsaHandle); + + [DllImport("secur32.dll", SetLastError = true)] + public static extern UInt32 LsaFreeReturnBuffer( + IntPtr Buffer); + + [DllImport("secur32.dll", SetLastError = true)] + public static extern UInt32 LsaGetLogonSessionData( + ref Luid LogonId, + out SafeLsaMemoryBuffer ppLogonSessionData); + + [DllImport("secur32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern UInt32 LsaLogonUser( + SafeLsaHandle LsaHandle, + NativeHelpers.LSA_STRING OriginName, + LogonType LogonType, + UInt32 AuthenticationPackage, + IntPtr AuthenticationInformation, + UInt32 AuthenticationInformationLength, + IntPtr LocalGroups, + NativeHelpers.TOKEN_SOURCE SourceContext, + out SafeLsaMemoryBuffer ProfileBuffer, + out UInt32 ProfileBufferLength, + out Luid LogonId, + out SafeNativeHandle Token, + out IntPtr Quotas, + out UInt32 SubStatus); + + [DllImport("secur32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern UInt32 LsaLookupAuthenticationPackage( + SafeLsaHandle LsaHandle, + NativeHelpers.LSA_STRING PackageName, + out UInt32 AuthenticationPackage); + + [DllImport("advapi32.dll")] + public static extern UInt32 LsaNtStatusToWinError( + UInt32 Status); + + [DllImport("secur32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern UInt32 LsaRegisterLogonProcess( + NativeHelpers.LSA_STRING LogonProcessName, + out SafeLsaHandle LsaHandle, + out IntPtr SecurityMode); + } + + internal class SafeLsaHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeLsaHandle() : base(true) { } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + UInt32 res = NativeMethods.LsaDeregisterLogonProcess(handle); + return res == 0; + } + } + + internal class SafeLsaMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeLsaMemoryBuffer() : base(true) { } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + UInt32 res = NativeMethods.LsaFreeReturnBuffer(handle); + return res == 0; + } + } + + internal class NoopSafeHandle : SafeHandle + { + public NoopSafeHandle() : base(IntPtr.Zero, false) { } + public override bool IsInvalid { get { return false; } } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() { return true; } + } + + [Flags] + public enum LogonFlags + { + WithProfile = 0x00000001, + NetcredentialsOnly = 0x00000002 + } + + public class BecomeUtil + { + private static List<string> SERVICE_SIDS = new List<string>() + { + "S-1-5-18", // NT AUTHORITY\SYSTEM + "S-1-5-19", // NT AUTHORITY\LocalService + "S-1-5-20" // NT AUTHORITY\NetworkService + }; + private static int WINDOWS_STATION_ALL_ACCESS = 0x000F037F; + private static int DESKTOP_RIGHTS_ALL_ACCESS = 0x000F01FF; + + public static Result CreateProcessAsUser(string username, string password, string command) + { + return CreateProcessAsUser(username, password, LogonFlags.WithProfile, LogonType.Interactive, + null, command, null, null, ""); + } + + public static Result CreateProcessAsUser(string username, string password, LogonFlags logonFlags, LogonType logonType, + string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, + string stdin) + { + byte[] stdinBytes; + if (String.IsNullOrEmpty(stdin)) + stdinBytes = new byte[0]; + else + { + if (!stdin.EndsWith(Environment.NewLine)) + stdin += Environment.NewLine; + stdinBytes = new UTF8Encoding(false).GetBytes(stdin); + } + return CreateProcessAsUser(username, password, logonFlags, logonType, lpApplicationName, lpCommandLine, + lpCurrentDirectory, environment, stdinBytes); + } + + /// <summary> + /// Creates a process as another user account. This method will attempt to run as another user with the + /// highest possible permissions available. The main privilege required is the SeDebugPrivilege, without + /// this privilege you can only run as a local or domain user if the username and password is specified. + /// </summary> + /// <param name="username">The username of the runas user</param> + /// <param name="password">The password of the runas user</param> + /// <param name="logonFlags">LogonFlags to control how to logon a user when the password is specified</param> + /// <param name="logonType">Controls what type of logon is used, this only applies when the password is specified</param> + /// <param name="lpApplicationName">The name of the executable or batch file to executable</param> + /// <param name="lpCommandLine">The command line to execute, typically this includes lpApplication as the first argument</param> + /// <param name="lpCurrentDirectory">The full path to the current directory for the process, null will have the same cwd as the calling process</param> + /// <param name="environment">A dictionary of key/value pairs to define the new process environment</param> + /// <param name="stdin">Bytes sent to the stdin pipe</param> + /// <returns>Ansible.Process.Result object that contains the command output and return code</returns> + public static Result CreateProcessAsUser(string username, string password, LogonFlags logonFlags, LogonType logonType, + string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, byte[] stdin) + { + // While we use STARTUPINFOEX having EXTENDED_STARTUPINFO_PRESENT causes a parameter validation error + Process.NativeHelpers.ProcessCreationFlags creationFlags = Process.NativeHelpers.ProcessCreationFlags.CREATE_UNICODE_ENVIRONMENT; + Process.NativeHelpers.PROCESS_INFORMATION pi = new Process.NativeHelpers.PROCESS_INFORMATION(); + Process.NativeHelpers.STARTUPINFOEX si = new Process.NativeHelpers.STARTUPINFOEX(); + si.startupInfo.dwFlags = Process.NativeHelpers.StartupInfoFlags.USESTDHANDLES; + + SafeFileHandle stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinRead, stdinWrite; + ProcessUtil.CreateStdioPipes(si, out stdoutRead, out stdoutWrite, out stderrRead, out stderrWrite, + out stdinRead, out stdinWrite); + FileStream stdinStream = new FileStream(stdinWrite, FileAccess.Write); + + // $null from PowerShell ends up as an empty string, we need to convert back as an empty string doesn't + // make sense for these parameters + if (lpApplicationName == "") + lpApplicationName = null; + + if (lpCurrentDirectory == "") + lpCurrentDirectory = null; + + // A user may have 2 tokens, 1 limited and 1 elevated. GetUserTokens will return both token to ensure + // we don't close one of the pairs while the process is still running. If the process tries to retrieve + // one of the pairs and the token handle is closed then it will fail with ERROR_NO_SUCH_LOGON_SESSION. + List<SafeNativeHandle> userTokens = GetUserTokens(username, password, logonType); + try + { + using (Process.SafeMemoryBuffer lpEnvironment = ProcessUtil.CreateEnvironmentPointer(environment)) + { + bool launchSuccess = false; + StringBuilder commandLine = new StringBuilder(lpCommandLine); + foreach (SafeNativeHandle token in userTokens) + { + // GetUserTokens could return null if an elevated token could not be retrieved. + if (token == null) + continue; + + if (NativeMethods.CreateProcessWithTokenW(token, logonFlags, lpApplicationName, + commandLine, creationFlags, lpEnvironment, lpCurrentDirectory, si, out pi)) + { + launchSuccess = true; + break; + } + } + + if (!launchSuccess) + throw new Process.Win32Exception("CreateProcessWithTokenW() failed"); + } + return ProcessUtil.WaitProcess(stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinStream, stdin, + pi.hProcess); + } + finally + { + userTokens.Where(t => t != null).ToList().ForEach(t => t.Dispose()); + } + } + + private static List<SafeNativeHandle> GetUserTokens(string username, string password, LogonType logonType) + { + List<SafeNativeHandle> userTokens = new List<SafeNativeHandle>(); + + SafeNativeHandle systemToken = null; + bool impersonated = false; + string becomeSid = username; + if (logonType != LogonType.NewCredentials) + { + // If prefixed with .\, we are becoming a local account, strip the prefix + if (username.StartsWith(".\\")) + username = username.Substring(2); + + NTAccount account = new NTAccount(username); + becomeSid = ((SecurityIdentifier)account.Translate(typeof(SecurityIdentifier))).Value; + + // Grant access to the current Windows Station and Desktop to the become user + GrantAccessToWindowStationAndDesktop(account); + + // Try and impersonate a SYSTEM token, we need a SYSTEM token to either become a well known service + // account or have administrative rights on the become access token. + // If we ultimately are becoming the SYSTEM account we want the token with the most privileges available. + // https://github.com/ansible/ansible/issues/71453 + bool mostPrivileges = becomeSid == "S-1-5-18"; + systemToken = GetPrimaryTokenForUser(new SecurityIdentifier("S-1-5-18"), + new List<string>() { "SeTcbPrivilege" }, mostPrivileges); + if (systemToken != null) + { + try + { + TokenUtil.ImpersonateToken(systemToken); + impersonated = true; + } + catch (Process.Win32Exception) { } // We tried, just rely on current user's permissions. + } + } + + // We require impersonation if becoming a service sid or becoming a user without a password + if (!impersonated && (SERVICE_SIDS.Contains(becomeSid) || String.IsNullOrEmpty(password))) + throw new Exception("Failed to get token for NT AUTHORITY\\SYSTEM required for become as a service account or an account without a password"); + + try + { + if (becomeSid == "S-1-5-18") + userTokens.Add(systemToken); + // Cannot use String.IsEmptyOrNull() as an empty string is an account that doesn't have a pass. + // We only use S4U if no password was defined or it was null + else if (!SERVICE_SIDS.Contains(becomeSid) && password == null && logonType != LogonType.NewCredentials) + { + // If no password was specified, try and duplicate an existing token for that user or use S4U to + // generate one without network credentials + SecurityIdentifier sid = new SecurityIdentifier(becomeSid); + SafeNativeHandle becomeToken = GetPrimaryTokenForUser(sid); + if (becomeToken != null) + { + userTokens.Add(GetElevatedToken(becomeToken)); + userTokens.Add(becomeToken); + } + else + { + becomeToken = GetS4UTokenForUser(sid, logonType); + userTokens.Add(null); + userTokens.Add(becomeToken); + } + } + else + { + string domain = null; + switch (becomeSid) + { + case "S-1-5-19": + logonType = LogonType.Service; + domain = "NT AUTHORITY"; + username = "LocalService"; + break; + case "S-1-5-20": + logonType = LogonType.Service; + domain = "NT AUTHORITY"; + username = "NetworkService"; + break; + default: + // Trying to become a local or domain account + if (username.Contains(@"\")) + { + string[] userSplit = username.Split(new char[1] { '\\' }, 2); + domain = userSplit[0]; + username = userSplit[1]; + } + else if (!username.Contains("@")) + domain = "."; + break; + } + + SafeNativeHandle hToken = TokenUtil.LogonUser(username, domain, password, logonType, + LogonProvider.Default); + + // Get the elevated token for a local/domain accounts only + if (!SERVICE_SIDS.Contains(becomeSid)) + userTokens.Add(GetElevatedToken(hToken)); + userTokens.Add(hToken); + } + } + finally + { + if (impersonated) + TokenUtil.RevertToSelf(); + } + + return userTokens; + } + + private static SafeNativeHandle GetPrimaryTokenForUser(SecurityIdentifier sid, + List<string> requiredPrivileges = null, bool mostPrivileges = false) + { + // According to CreateProcessWithTokenW we require a token with + // TOKEN_QUERY, TOKEN_DUPLICATE and TOKEN_ASSIGN_PRIMARY + // Also add in TOKEN_IMPERSONATE so we can get an impersonated token + TokenAccessLevels dwAccess = TokenAccessLevels.Query | + TokenAccessLevels.Duplicate | + TokenAccessLevels.AssignPrimary | + TokenAccessLevels.Impersonate; + + SafeNativeHandle userToken = null; + int privilegeCount = 0; + + foreach (SafeNativeHandle hToken in TokenUtil.EnumerateUserTokens(sid, dwAccess)) + { + // Filter out any Network logon tokens, using become with that is useless when S4U + // can give us a Batch logon + NativeHelpers.SECURITY_LOGON_TYPE tokenLogonType = GetTokenLogonType(hToken); + if (tokenLogonType == NativeHelpers.SECURITY_LOGON_TYPE.Network) + continue; + + List<string> actualPrivileges = TokenUtil.GetTokenPrivileges(hToken).Select(x => x.Name).ToList(); + + // If the token has less or the same number of privileges than the current token, skip it. + if (mostPrivileges && privilegeCount >= actualPrivileges.Count) + continue; + + // Check that the required privileges are on the token + if (requiredPrivileges != null) + { + int missing = requiredPrivileges.Where(x => !actualPrivileges.Contains(x)).Count(); + if (missing > 0) + continue; + } + + // Duplicate the token to convert it to a primary token with the access level required. + try + { + userToken = TokenUtil.DuplicateToken(hToken, TokenAccessLevels.MaximumAllowed, + SecurityImpersonationLevel.Anonymous, TokenType.Primary); + privilegeCount = actualPrivileges.Count; + } + catch (Process.Win32Exception) + { + continue; + } + + // If we don't care about getting the token with the most privileges, escape the loop as we already + // have a token. + if (!mostPrivileges) + break; + } + + return userToken; + } + + private static SafeNativeHandle GetS4UTokenForUser(SecurityIdentifier sid, LogonType logonType) + { + NTAccount becomeAccount = (NTAccount)sid.Translate(typeof(NTAccount)); + string[] userSplit = becomeAccount.Value.Split(new char[1] { '\\' }, 2); + string domainName = userSplit[0]; + string username = userSplit[1]; + bool domainUser = domainName.ToLowerInvariant() != Environment.MachineName.ToLowerInvariant(); + + NativeHelpers.LSA_STRING logonProcessName = "ansible"; + SafeLsaHandle lsaHandle; + IntPtr securityMode; + UInt32 res = NativeMethods.LsaRegisterLogonProcess(logonProcessName, out lsaHandle, out securityMode); + if (res != 0) + throw new Process.Win32Exception((int)NativeMethods.LsaNtStatusToWinError(res), "LsaRegisterLogonProcess() failed"); + + using (lsaHandle) + { + NativeHelpers.LSA_STRING packageName = domainUser ? "Kerberos" : "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"; + UInt32 authPackage; + res = NativeMethods.LsaLookupAuthenticationPackage(lsaHandle, packageName, out authPackage); + if (res != 0) + throw new Process.Win32Exception((int)NativeMethods.LsaNtStatusToWinError(res), + String.Format("LsaLookupAuthenticationPackage({0}) failed", (string)packageName)); + + int usernameLength = username.Length * sizeof(char); + int domainLength = domainName.Length * sizeof(char); + int authInfoLength = (Marshal.SizeOf(typeof(NativeHelpers.KERB_S4U_LOGON)) + usernameLength + domainLength); + IntPtr authInfo = Marshal.AllocHGlobal((int)authInfoLength); + try + { + IntPtr usernamePtr = IntPtr.Add(authInfo, Marshal.SizeOf(typeof(NativeHelpers.KERB_S4U_LOGON))); + IntPtr domainPtr = IntPtr.Add(usernamePtr, usernameLength); + + // KERB_S4U_LOGON has the same structure as MSV1_0_S4U_LOGON (local accounts) + NativeHelpers.KERB_S4U_LOGON s4uLogon = new NativeHelpers.KERB_S4U_LOGON + { + MessageType = 12, // KerbS4ULogon + Flags = 0, + ClientUpn = new NativeHelpers.LSA_UNICODE_STRING + { + Length = (UInt16)usernameLength, + MaximumLength = (UInt16)usernameLength, + Buffer = usernamePtr, + }, + ClientRealm = new NativeHelpers.LSA_UNICODE_STRING + { + Length = (UInt16)domainLength, + MaximumLength = (UInt16)domainLength, + Buffer = domainPtr, + }, + }; + Marshal.StructureToPtr(s4uLogon, authInfo, false); + Marshal.Copy(username.ToCharArray(), 0, usernamePtr, username.Length); + Marshal.Copy(domainName.ToCharArray(), 0, domainPtr, domainName.Length); + + Luid sourceLuid; + if (!NativeMethods.AllocateLocallyUniqueId(out sourceLuid)) + throw new Process.Win32Exception("AllocateLocallyUniqueId() failed"); + + NativeHelpers.TOKEN_SOURCE tokenSource = new NativeHelpers.TOKEN_SOURCE + { + SourceName = "ansible\0".ToCharArray(), + SourceIdentifier = sourceLuid, + }; + + // Only Batch or Network will work with S4U, prefer Batch but use Network if asked + LogonType lsaLogonType = logonType == LogonType.Network + ? LogonType.Network + : LogonType.Batch; + SafeLsaMemoryBuffer profileBuffer; + UInt32 profileBufferLength; + Luid logonId; + SafeNativeHandle hToken; + IntPtr quotas; + UInt32 subStatus; + + res = NativeMethods.LsaLogonUser(lsaHandle, logonProcessName, lsaLogonType, authPackage, + authInfo, (UInt32)authInfoLength, IntPtr.Zero, tokenSource, out profileBuffer, out profileBufferLength, + out logonId, out hToken, out quotas, out subStatus); + if (res != 0) + throw new Process.Win32Exception((int)NativeMethods.LsaNtStatusToWinError(res), + String.Format("LsaLogonUser() failed with substatus {0}", subStatus)); + + profileBuffer.Dispose(); + return hToken; + } + finally + { + Marshal.FreeHGlobal(authInfo); + } + } + } + + private static SafeNativeHandle GetElevatedToken(SafeNativeHandle hToken) + { + TokenElevationType tet = TokenUtil.GetTokenElevationType(hToken); + // We already have the best token we can get, no linked token is really available. + if (tet != TokenElevationType.Limited) + return null; + + SafeNativeHandle linkedToken = TokenUtil.GetTokenLinkedToken(hToken); + TokenStatistics tokenStats = TokenUtil.GetTokenStatistics(linkedToken); + + // We can only use a token if it's a primary one (we had the SeTcbPrivilege set) + if (tokenStats.TokenType == TokenType.Primary) + return linkedToken; + else + return null; + } + + private static NativeHelpers.SECURITY_LOGON_TYPE GetTokenLogonType(SafeNativeHandle hToken) + { + TokenStatistics stats = TokenUtil.GetTokenStatistics(hToken); + + SafeLsaMemoryBuffer sessionDataPtr; + UInt32 res = NativeMethods.LsaGetLogonSessionData(ref stats.AuthenticationId, out sessionDataPtr); + if (res != 0) + // Default to Network, if we weren't able to get the actual type treat it as an error and assume + // we don't want to run a process with the token + return NativeHelpers.SECURITY_LOGON_TYPE.Network; + + using (sessionDataPtr) + { + NativeHelpers.SECURITY_LOGON_SESSION_DATA sessionData = (NativeHelpers.SECURITY_LOGON_SESSION_DATA)Marshal.PtrToStructure( + sessionDataPtr.DangerousGetHandle(), typeof(NativeHelpers.SECURITY_LOGON_SESSION_DATA)); + return sessionData.LogonType; + } + } + + private static void GrantAccessToWindowStationAndDesktop(IdentityReference account) + { + GrantAccess(account, NativeMethods.GetProcessWindowStation(), WINDOWS_STATION_ALL_ACCESS); + GrantAccess(account, NativeMethods.GetThreadDesktop(NativeMethods.GetCurrentThreadId()), DESKTOP_RIGHTS_ALL_ACCESS); + } + + private static void GrantAccess(IdentityReference account, NoopSafeHandle handle, int accessMask) + { + GenericSecurity security = new GenericSecurity(false, ResourceType.WindowObject, handle, AccessControlSections.Access); + security.AddAccessRule(new GenericAccessRule(account, accessMask, AccessControlType.Allow)); + security.Persist(handle, AccessControlSections.Access); + } + + private class GenericSecurity : NativeObjectSecurity + { + public GenericSecurity(bool isContainer, ResourceType resType, SafeHandle objectHandle, AccessControlSections sectionsRequested) + : base(isContainer, resType, objectHandle, sectionsRequested) { } + public new void Persist(SafeHandle handle, AccessControlSections includeSections) { base.Persist(handle, includeSections); } + public new void AddAccessRule(AccessRule rule) { base.AddAccessRule(rule); } + public override Type AccessRightType { get { throw new NotImplementedException(); } } + public override AccessRule AccessRuleFactory(System.Security.Principal.IdentityReference identityReference, int accessMask, bool isInherited, + InheritanceFlags inheritanceFlags, PropagationFlags propagationFlags, AccessControlType type) + { throw new NotImplementedException(); } + public override Type AccessRuleType { get { return typeof(AccessRule); } } + public override AuditRule AuditRuleFactory(System.Security.Principal.IdentityReference identityReference, int accessMask, bool isInherited, + InheritanceFlags inheritanceFlags, PropagationFlags propagationFlags, AuditFlags flags) + { throw new NotImplementedException(); } + public override Type AuditRuleType { get { return typeof(AuditRule); } } + } + + private class GenericAccessRule : AccessRule + { + public GenericAccessRule(IdentityReference identity, int accessMask, AccessControlType type) : + base(identity, accessMask, false, InheritanceFlags.None, PropagationFlags.None, type) + { } + } + } +} diff --git a/lib/ansible/module_utils/csharp/Ansible.Privilege.cs b/lib/ansible/module_utils/csharp/Ansible.Privilege.cs new file mode 100644 index 0000000..2c0b266 --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Privilege.cs @@ -0,0 +1,443 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; + +namespace Ansible.Privilege +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential)] + public struct LUID + { + public UInt32 LowPart; + public Int32 HighPart; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LUID_AND_ATTRIBUTES + { + public LUID Luid; + public PrivilegeAttributes Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_PRIVILEGES + { + public UInt32 PrivilegeCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public LUID_AND_ATTRIBUTES[] Privileges; + } + } + + internal class NativeMethods + { + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool AdjustTokenPrivileges( + SafeNativeHandle TokenHandle, + [MarshalAs(UnmanagedType.Bool)] bool DisableAllPrivileges, + SafeMemoryBuffer NewState, + UInt32 BufferLength, + SafeMemoryBuffer PreviousState, + out UInt32 ReturnLength); + + [DllImport("kernel32.dll")] + public static extern bool CloseHandle( + IntPtr hObject); + + [DllImport("kernel32")] + public static extern SafeWaitHandle GetCurrentProcess(); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool GetTokenInformation( + SafeNativeHandle TokenHandle, + UInt32 TokenInformationClass, + SafeMemoryBuffer TokenInformation, + UInt32 TokenInformationLength, + out UInt32 ReturnLength); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LookupPrivilegeName( + string lpSystemName, + ref NativeHelpers.LUID lpLuid, + StringBuilder lpName, + ref UInt32 cchName); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LookupPrivilegeValue( + string lpSystemName, + string lpName, + out NativeHelpers.LUID lpLuid); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool OpenProcessToken( + SafeHandle ProcessHandle, + TokenAccessLevels DesiredAccess, + out SafeNativeHandle TokenHandle); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + internal class SafeNativeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeNativeHandle() : base(true) { } + public SafeNativeHandle(IntPtr handle) : base(true) { this.handle = handle; } + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return NativeMethods.CloseHandle(handle); + } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + [Flags] + public enum PrivilegeAttributes : uint + { + Disabled = 0x00000000, + EnabledByDefault = 0x00000001, + Enabled = 0x00000002, + Removed = 0x00000004, + UsedForAccess = 0x80000000, + } + + public class PrivilegeEnabler : IDisposable + { + private SafeHandle process; + private Dictionary<string, bool?> previousState; + + /// <summary> + /// Temporarily enables the privileges specified and reverts once the class is disposed. + /// </summary> + /// <param name="strict">Whether to fail if any privilege failed to be enabled, if false then this will continue silently</param> + /// <param name="privileges">A list of privileges to enable</param> + public PrivilegeEnabler(bool strict, params string[] privileges) + { + if (privileges.Length > 0) + { + process = PrivilegeUtil.GetCurrentProcess(); + Dictionary<string, bool?> newState = new Dictionary<string, bool?>(); + for (int i = 0; i < privileges.Length; i++) + newState.Add(privileges[i], true); + try + { + previousState = PrivilegeUtil.SetTokenPrivileges(process, newState, strict); + } + catch (Win32Exception e) + { + throw new Win32Exception(e.NativeErrorCode, String.Format("Failed to enable privilege(s) {0}", String.Join(", ", privileges))); + } + } + } + + public void Dispose() + { + // disables any privileges that were enabled by this class + if (previousState != null) + PrivilegeUtil.SetTokenPrivileges(process, previousState); + GC.SuppressFinalize(this); + } + ~PrivilegeEnabler() { this.Dispose(); } + } + + public class PrivilegeUtil + { + private static readonly UInt32 TOKEN_PRIVILEGES = 3; + + /// <summary> + /// Checks if the specific privilege constant is a valid privilege name + /// </summary> + /// <param name="name">The privilege constant (Se*Privilege) is valid</param> + /// <returns>true if valid, else false</returns> + public static bool CheckPrivilegeName(string name) + { + NativeHelpers.LUID luid; + if (!NativeMethods.LookupPrivilegeValue(null, name, out luid)) + { + int errCode = Marshal.GetLastWin32Error(); + if (errCode != 1313) // ERROR_NO_SUCH_PRIVILEGE + throw new Win32Exception(errCode, String.Format("LookupPrivilegeValue({0}) failed", name)); + return false; + } + else + { + return true; + } + } + + /// <summary> + /// Disables the privilege specified + /// </summary> + /// <param name="token">The process token to that contains the privilege to disable</param> + /// <param name="privilege">The privilege constant to disable</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> DisablePrivilege(SafeHandle token, string privilege) + { + return SetTokenPrivileges(token, new Dictionary<string, bool?>() { { privilege, false } }); + } + + /// <summary> + /// Disables all the privileges + /// </summary> + /// <param name="token">The process token to that contains the privilege to disable</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> DisableAllPrivileges(SafeHandle token) + { + return AdjustTokenPrivileges(token, null, false); + } + + /// <summary> + /// Enables the privilege specified + /// </summary> + /// <param name="token">The process token to that contains the privilege to enable</param> + /// <param name="privilege">The privilege constant to enable</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> EnablePrivilege(SafeHandle token, string privilege) + { + return SetTokenPrivileges(token, new Dictionary<string, bool?>() { { privilege, true } }); + } + + /// <summary> + /// Get's the status of all the privileges on the token specified + /// </summary> + /// <param name="token">The process token to get the privilege status on</param> + /// <returns>Dictionary where the key is the privilege constant and the value is the PrivilegeAttributes flags</returns> + public static Dictionary<String, PrivilegeAttributes> GetAllPrivilegeInfo(SafeHandle token) + { + SafeNativeHandle hToken = null; + if (!NativeMethods.OpenProcessToken(token, TokenAccessLevels.Query, out hToken)) + throw new Win32Exception("OpenProcessToken() failed"); + + using (hToken) + { + UInt32 tokenLength = 0; + NativeMethods.GetTokenInformation(hToken, TOKEN_PRIVILEGES, new SafeMemoryBuffer(0), 0, out tokenLength); + + NativeHelpers.LUID_AND_ATTRIBUTES[] privileges; + using (SafeMemoryBuffer privilegesPtr = new SafeMemoryBuffer((int)tokenLength)) + { + if (!NativeMethods.GetTokenInformation(hToken, TOKEN_PRIVILEGES, privilegesPtr, tokenLength, out tokenLength)) + throw new Win32Exception("GetTokenInformation() for TOKEN_PRIVILEGES failed"); + + NativeHelpers.TOKEN_PRIVILEGES privilegeInfo = (NativeHelpers.TOKEN_PRIVILEGES)Marshal.PtrToStructure( + privilegesPtr.DangerousGetHandle(), typeof(NativeHelpers.TOKEN_PRIVILEGES)); + privileges = new NativeHelpers.LUID_AND_ATTRIBUTES[privilegeInfo.PrivilegeCount]; + PtrToStructureArray(privileges, IntPtr.Add(privilegesPtr.DangerousGetHandle(), Marshal.SizeOf(privilegeInfo.PrivilegeCount))); + } + + return privileges.ToDictionary(p => GetPrivilegeName(p.Luid), p => p.Attributes); + } + } + + /// <summary> + /// Get a handle to the current process for use with the methods above + /// </summary> + /// <returns>SafeWaitHandle handle of the current process token</returns> + public static SafeWaitHandle GetCurrentProcess() + { + return NativeMethods.GetCurrentProcess(); + } + + /// <summary> + /// Removes a privilege from the token. This operation is irreversible + /// </summary> + /// <param name="token">The process token to that contains the privilege to remove</param> + /// <param name="privilege">The privilege constant to remove</param> + public static void RemovePrivilege(SafeHandle token, string privilege) + { + SetTokenPrivileges(token, new Dictionary<string, bool?>() { { privilege, null } }); + } + + /// <summary> + /// Do a bulk set of multiple privileges + /// </summary> + /// <param name="token">The process token to use when setting the privilege state</param> + /// <param name="state">A dictionary that contains the privileges to set, the key is the constant name and the value can be; + /// true - enable the privilege + /// false - disable the privilege + /// null - remove the privilege (this cannot be reversed) + /// </param> + /// <param name="strict">When true, will fail if one privilege failed to be set, otherwise it will silently continue</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> SetTokenPrivileges(SafeHandle token, IDictionary state, bool strict = true) + { + NativeHelpers.LUID_AND_ATTRIBUTES[] privilegeAttr = new NativeHelpers.LUID_AND_ATTRIBUTES[state.Count]; + int i = 0; + + foreach (DictionaryEntry entry in state) + { + string key = (string)entry.Key; + NativeHelpers.LUID luid; + if (!NativeMethods.LookupPrivilegeValue(null, key, out luid)) + throw new Win32Exception(String.Format("LookupPrivilegeValue({0}) failed", key)); + + PrivilegeAttributes attributes; + switch ((bool?)entry.Value) + { + case true: + attributes = PrivilegeAttributes.Enabled; + break; + case false: + attributes = PrivilegeAttributes.Disabled; + break; + default: + attributes = PrivilegeAttributes.Removed; + break; + } + + privilegeAttr[i].Luid = luid; + privilegeAttr[i].Attributes = attributes; + i++; + } + + return AdjustTokenPrivileges(token, privilegeAttr, strict); + } + + private static Dictionary<string, bool?> AdjustTokenPrivileges(SafeHandle token, NativeHelpers.LUID_AND_ATTRIBUTES[] newState, bool strict) + { + bool disableAllPrivileges; + SafeMemoryBuffer newStatePtr; + NativeHelpers.LUID_AND_ATTRIBUTES[] oldStatePrivileges; + UInt32 returnLength; + + if (newState == null) + { + disableAllPrivileges = true; + newStatePtr = new SafeMemoryBuffer(0); + } + else + { + disableAllPrivileges = false; + + // Need to manually marshal the bytes requires for newState as the constant size + // of LUID_AND_ATTRIBUTES is set to 1 and can't be overridden at runtime, TOKEN_PRIVILEGES + // always contains at least 1 entry so we need to calculate the extra size if there are + // nore than 1 LUID_AND_ATTRIBUTES entry + int tokenPrivilegesSize = Marshal.SizeOf(typeof(NativeHelpers.TOKEN_PRIVILEGES)); + int luidAttrSize = 0; + if (newState.Length > 1) + luidAttrSize = Marshal.SizeOf(typeof(NativeHelpers.LUID_AND_ATTRIBUTES)) * (newState.Length - 1); + int totalSize = tokenPrivilegesSize + luidAttrSize; + byte[] newStateBytes = new byte[totalSize]; + + // get the first entry that includes the struct details + NativeHelpers.TOKEN_PRIVILEGES tokenPrivileges = new NativeHelpers.TOKEN_PRIVILEGES() + { + PrivilegeCount = (UInt32)newState.Length, + Privileges = new NativeHelpers.LUID_AND_ATTRIBUTES[1], + }; + if (newState.Length > 0) + tokenPrivileges.Privileges[0] = newState[0]; + int offset = StructureToBytes(tokenPrivileges, newStateBytes, 0); + + // copy the remaining LUID_AND_ATTRIBUTES (if any) + for (int i = 1; i < newState.Length; i++) + offset += StructureToBytes(newState[i], newStateBytes, offset); + + // finally create the pointer to the byte array we just created + newStatePtr = new SafeMemoryBuffer(newStateBytes.Length); + Marshal.Copy(newStateBytes, 0, newStatePtr.DangerousGetHandle(), newStateBytes.Length); + } + + using (newStatePtr) + { + SafeNativeHandle hToken; + if (!NativeMethods.OpenProcessToken(token, TokenAccessLevels.Query | TokenAccessLevels.AdjustPrivileges, out hToken)) + throw new Win32Exception("OpenProcessToken() failed with Query and AdjustPrivileges"); + + using (hToken) + { + if (!NativeMethods.AdjustTokenPrivileges(hToken, disableAllPrivileges, newStatePtr, 0, new SafeMemoryBuffer(0), out returnLength)) + { + int errCode = Marshal.GetLastWin32Error(); + if (errCode != 122) // ERROR_INSUFFICIENT_BUFFER + throw new Win32Exception(errCode, "AdjustTokenPrivileges() failed to get old state size"); + } + + using (SafeMemoryBuffer oldStatePtr = new SafeMemoryBuffer((int)returnLength)) + { + bool res = NativeMethods.AdjustTokenPrivileges(hToken, disableAllPrivileges, newStatePtr, returnLength, oldStatePtr, out returnLength); + int errCode = Marshal.GetLastWin32Error(); + + // even when res == true, ERROR_NOT_ALL_ASSIGNED may be set as the last error code + // fail if we are running with strict, otherwise ignore those privileges + if (!res || ((strict && errCode != 0) || (!strict && !(errCode == 0 || errCode == 0x00000514)))) + throw new Win32Exception(errCode, "AdjustTokenPrivileges() failed"); + + // Marshal the oldStatePtr to the struct + NativeHelpers.TOKEN_PRIVILEGES oldState = (NativeHelpers.TOKEN_PRIVILEGES)Marshal.PtrToStructure( + oldStatePtr.DangerousGetHandle(), typeof(NativeHelpers.TOKEN_PRIVILEGES)); + oldStatePrivileges = new NativeHelpers.LUID_AND_ATTRIBUTES[oldState.PrivilegeCount]; + PtrToStructureArray(oldStatePrivileges, IntPtr.Add(oldStatePtr.DangerousGetHandle(), Marshal.SizeOf(oldState.PrivilegeCount))); + } + } + } + + return oldStatePrivileges.ToDictionary(p => GetPrivilegeName(p.Luid), p => (bool?)p.Attributes.HasFlag(PrivilegeAttributes.Enabled)); + } + + private static string GetPrivilegeName(NativeHelpers.LUID luid) + { + UInt32 nameLen = 0; + NativeMethods.LookupPrivilegeName(null, ref luid, null, ref nameLen); + + StringBuilder name = new StringBuilder((int)(nameLen + 1)); + if (!NativeMethods.LookupPrivilegeName(null, ref luid, name, ref nameLen)) + throw new Win32Exception("LookupPrivilegeName() failed"); + + return name.ToString(); + } + + private static void PtrToStructureArray<T>(T[] array, IntPtr ptr) + { + IntPtr ptrOffset = ptr; + for (int i = 0; i < array.Length; i++, ptrOffset = IntPtr.Add(ptrOffset, Marshal.SizeOf(typeof(T)))) + array[i] = (T)Marshal.PtrToStructure(ptrOffset, typeof(T)); + } + + private static int StructureToBytes<T>(T structure, byte[] array, int offset) + { + int size = Marshal.SizeOf(structure); + using (SafeMemoryBuffer structPtr = new SafeMemoryBuffer(size)) + { + Marshal.StructureToPtr(structure, structPtr.DangerousGetHandle(), false); + Marshal.Copy(structPtr.DangerousGetHandle(), array, offset, size); + } + + return size; + } + } +} + diff --git a/lib/ansible/module_utils/csharp/Ansible.Process.cs b/lib/ansible/module_utils/csharp/Ansible.Process.cs new file mode 100644 index 0000000..f4c68f0 --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Process.cs @@ -0,0 +1,461 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections; +using System.IO; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; + +namespace Ansible.Process +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential)] + public class SECURITY_ATTRIBUTES + { + public UInt32 nLength; + public IntPtr lpSecurityDescriptor; + public bool bInheritHandle = false; + public SECURITY_ATTRIBUTES() + { + nLength = (UInt32)Marshal.SizeOf(this); + } + } + + [StructLayout(LayoutKind.Sequential)] + public class STARTUPINFO + { + public UInt32 cb; + public IntPtr lpReserved; + [MarshalAs(UnmanagedType.LPWStr)] public string lpDesktop; + [MarshalAs(UnmanagedType.LPWStr)] public string lpTitle; + public UInt32 dwX; + public UInt32 dwY; + public UInt32 dwXSize; + public UInt32 dwYSize; + public UInt32 dwXCountChars; + public UInt32 dwYCountChars; + public UInt32 dwFillAttribute; + public StartupInfoFlags dwFlags; + public UInt16 wShowWindow; + public UInt16 cbReserved2; + public IntPtr lpReserved2; + public SafeFileHandle hStdInput; + public SafeFileHandle hStdOutput; + public SafeFileHandle hStdError; + public STARTUPINFO() + { + cb = (UInt32)Marshal.SizeOf(this); + } + } + + [StructLayout(LayoutKind.Sequential)] + public class STARTUPINFOEX + { + public STARTUPINFO startupInfo; + public IntPtr lpAttributeList; + public STARTUPINFOEX() + { + startupInfo = new STARTUPINFO(); + startupInfo.cb = (UInt32)Marshal.SizeOf(this); + } + } + + [StructLayout(LayoutKind.Sequential)] + public struct PROCESS_INFORMATION + { + public IntPtr hProcess; + public IntPtr hThread; + public int dwProcessId; + public int dwThreadId; + } + + [Flags] + public enum ProcessCreationFlags : uint + { + CREATE_NEW_CONSOLE = 0x00000010, + CREATE_UNICODE_ENVIRONMENT = 0x00000400, + EXTENDED_STARTUPINFO_PRESENT = 0x00080000 + } + + [Flags] + public enum StartupInfoFlags : uint + { + USESTDHANDLES = 0x00000100 + } + + [Flags] + public enum HandleFlags : uint + { + None = 0, + INHERIT = 1 + } + } + + internal class NativeMethods + { + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool AllocConsole(); + + [DllImport("shell32.dll", SetLastError = true)] + public static extern SafeMemoryBuffer CommandLineToArgvW( + [MarshalAs(UnmanagedType.LPWStr)] string lpCmdLine, + out int pNumArgs); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool CreatePipe( + out SafeFileHandle hReadPipe, + out SafeFileHandle hWritePipe, + NativeHelpers.SECURITY_ATTRIBUTES lpPipeAttributes, + UInt32 nSize); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool CreateProcessW( + [MarshalAs(UnmanagedType.LPWStr)] string lpApplicationName, + StringBuilder lpCommandLine, + IntPtr lpProcessAttributes, + IntPtr lpThreadAttributes, + bool bInheritHandles, + NativeHelpers.ProcessCreationFlags dwCreationFlags, + SafeMemoryBuffer lpEnvironment, + [MarshalAs(UnmanagedType.LPWStr)] string lpCurrentDirectory, + NativeHelpers.STARTUPINFOEX lpStartupInfo, + out NativeHelpers.PROCESS_INFORMATION lpProcessInformation); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool FreeConsole(); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern IntPtr GetConsoleWindow(); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool GetExitCodeProcess( + SafeWaitHandle hProcess, + out UInt32 lpExitCode); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern uint SearchPathW( + [MarshalAs(UnmanagedType.LPWStr)] string lpPath, + [MarshalAs(UnmanagedType.LPWStr)] string lpFileName, + [MarshalAs(UnmanagedType.LPWStr)] string lpExtension, + UInt32 nBufferLength, + [MarshalAs(UnmanagedType.LPTStr)] StringBuilder lpBuffer, + out IntPtr lpFilePart); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool SetConsoleCP( + UInt32 wCodePageID); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool SetConsoleOutputCP( + UInt32 wCodePageID); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool SetHandleInformation( + SafeFileHandle hObject, + NativeHelpers.HandleFlags dwMask, + NativeHelpers.HandleFlags dwFlags); + + [DllImport("kernel32.dll")] + public static extern UInt32 WaitForSingleObject( + SafeWaitHandle hHandle, + UInt32 dwMilliseconds); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + public class Result + { + public string StandardOut { get; internal set; } + public string StandardError { get; internal set; } + public uint ExitCode { get; internal set; } + } + + public class ProcessUtil + { + /// <summary> + /// Parses a command line string into an argv array according to the Windows rules + /// </summary> + /// <param name="lpCommandLine">The command line to parse</param> + /// <returns>An array of arguments interpreted by Windows</returns> + public static string[] ParseCommandLine(string lpCommandLine) + { + int numArgs; + using (SafeMemoryBuffer buf = NativeMethods.CommandLineToArgvW(lpCommandLine, out numArgs)) + { + if (buf.IsInvalid) + throw new Win32Exception("Error parsing command line"); + IntPtr[] strptrs = new IntPtr[numArgs]; + Marshal.Copy(buf.DangerousGetHandle(), strptrs, 0, numArgs); + return strptrs.Select(s => Marshal.PtrToStringUni(s)).ToArray(); + } + } + + /// <summary> + /// Searches the path for the executable specified. Will throw a Win32Exception if the file is not found. + /// </summary> + /// <param name="lpFileName">The executable to search for</param> + /// <returns>The full path of the executable to search for</returns> + public static string SearchPath(string lpFileName) + { + StringBuilder sbOut = new StringBuilder(0); + IntPtr filePartOut = IntPtr.Zero; + UInt32 res = NativeMethods.SearchPathW(null, lpFileName, null, (UInt32)sbOut.Capacity, sbOut, out filePartOut); + if (res == 0) + { + int lastErr = Marshal.GetLastWin32Error(); + if (lastErr == 2) // ERROR_FILE_NOT_FOUND + throw new FileNotFoundException(String.Format("Could not find file '{0}'.", lpFileName)); + else + throw new Win32Exception(String.Format("SearchPathW({0}) failed to get buffer length", lpFileName)); + } + + sbOut.EnsureCapacity((int)res); + if (NativeMethods.SearchPathW(null, lpFileName, null, (UInt32)sbOut.Capacity, sbOut, out filePartOut) == 0) + throw new Win32Exception(String.Format("SearchPathW({0}) failed", lpFileName)); + + return sbOut.ToString(); + } + + public static Result CreateProcess(string command) + { + return CreateProcess(null, command, null, null, String.Empty); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment) + { + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, String.Empty); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, string stdin) + { + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdin, null); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, byte[] stdin) + { + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdin, null); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, string stdin, string outputEncoding) + { + byte[] stdinBytes; + if (String.IsNullOrEmpty(stdin)) + stdinBytes = new byte[0]; + else + { + if (!stdin.EndsWith(Environment.NewLine)) + stdin += Environment.NewLine; + stdinBytes = new UTF8Encoding(false).GetBytes(stdin); + } + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdinBytes, outputEncoding); + } + + /// <summary> + /// Creates a process based on the CreateProcess API call. + /// </summary> + /// <param name="lpApplicationName">The name of the executable or batch file to execute</param> + /// <param name="lpCommandLine">The command line to execute, typically this includes lpApplication as the first argument</param> + /// <param name="lpCurrentDirectory">The full path to the current directory for the process, null will have the same cwd as the calling process</param> + /// <param name="environment">A dictionary of key/value pairs to define the new process environment</param> + /// <param name="stdin">A byte array to send over the stdin pipe</param> + /// <param name="outputEncoding">The character encoding for decoding stdout/stderr output of the process.</param> + /// <returns>Result object that contains the command output and return code</returns> + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, byte[] stdin, string outputEncoding) + { + NativeHelpers.ProcessCreationFlags creationFlags = NativeHelpers.ProcessCreationFlags.CREATE_UNICODE_ENVIRONMENT | + NativeHelpers.ProcessCreationFlags.EXTENDED_STARTUPINFO_PRESENT; + NativeHelpers.PROCESS_INFORMATION pi = new NativeHelpers.PROCESS_INFORMATION(); + NativeHelpers.STARTUPINFOEX si = new NativeHelpers.STARTUPINFOEX(); + si.startupInfo.dwFlags = NativeHelpers.StartupInfoFlags.USESTDHANDLES; + + SafeFileHandle stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinRead, stdinWrite; + CreateStdioPipes(si, out stdoutRead, out stdoutWrite, out stderrRead, out stderrWrite, out stdinRead, + out stdinWrite); + FileStream stdinStream = new FileStream(stdinWrite, FileAccess.Write); + + // $null from PowerShell ends up as an empty string, we need to convert back as an empty string doesn't + // make sense for these parameters + if (lpApplicationName == "") + lpApplicationName = null; + + if (lpCurrentDirectory == "") + lpCurrentDirectory = null; + + using (SafeMemoryBuffer lpEnvironment = CreateEnvironmentPointer(environment)) + { + // Create console with utf-8 CP if no existing console is present + bool isConsole = false; + if (NativeMethods.GetConsoleWindow() == IntPtr.Zero) + { + isConsole = NativeMethods.AllocConsole(); + + // Set console input/output codepage to UTF-8 + NativeMethods.SetConsoleCP(65001); + NativeMethods.SetConsoleOutputCP(65001); + } + + try + { + StringBuilder commandLine = new StringBuilder(lpCommandLine); + if (!NativeMethods.CreateProcessW(lpApplicationName, commandLine, IntPtr.Zero, IntPtr.Zero, + true, creationFlags, lpEnvironment, lpCurrentDirectory, si, out pi)) + { + throw new Win32Exception("CreateProcessW() failed"); + } + } + finally + { + if (isConsole) + NativeMethods.FreeConsole(); + } + } + + return WaitProcess(stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinStream, stdin, pi.hProcess, + outputEncoding); + } + + internal static void CreateStdioPipes(NativeHelpers.STARTUPINFOEX si, out SafeFileHandle stdoutRead, + out SafeFileHandle stdoutWrite, out SafeFileHandle stderrRead, out SafeFileHandle stderrWrite, + out SafeFileHandle stdinRead, out SafeFileHandle stdinWrite) + { + NativeHelpers.SECURITY_ATTRIBUTES pipesec = new NativeHelpers.SECURITY_ATTRIBUTES(); + pipesec.bInheritHandle = true; + + if (!NativeMethods.CreatePipe(out stdoutRead, out stdoutWrite, pipesec, 0)) + throw new Win32Exception("STDOUT pipe setup failed"); + if (!NativeMethods.SetHandleInformation(stdoutRead, NativeHelpers.HandleFlags.INHERIT, 0)) + throw new Win32Exception("STDOUT pipe handle setup failed"); + + if (!NativeMethods.CreatePipe(out stderrRead, out stderrWrite, pipesec, 0)) + throw new Win32Exception("STDERR pipe setup failed"); + if (!NativeMethods.SetHandleInformation(stderrRead, NativeHelpers.HandleFlags.INHERIT, 0)) + throw new Win32Exception("STDERR pipe handle setup failed"); + + if (!NativeMethods.CreatePipe(out stdinRead, out stdinWrite, pipesec, 0)) + throw new Win32Exception("STDIN pipe setup failed"); + if (!NativeMethods.SetHandleInformation(stdinWrite, NativeHelpers.HandleFlags.INHERIT, 0)) + throw new Win32Exception("STDIN pipe handle setup failed"); + + si.startupInfo.hStdOutput = stdoutWrite; + si.startupInfo.hStdError = stderrWrite; + si.startupInfo.hStdInput = stdinRead; + } + + internal static SafeMemoryBuffer CreateEnvironmentPointer(IDictionary environment) + { + IntPtr lpEnvironment = IntPtr.Zero; + if (environment != null && environment.Count > 0) + { + StringBuilder environmentString = new StringBuilder(); + foreach (DictionaryEntry kv in environment) + environmentString.AppendFormat("{0}={1}\0", kv.Key, kv.Value); + environmentString.Append('\0'); + + lpEnvironment = Marshal.StringToHGlobalUni(environmentString.ToString()); + } + return new SafeMemoryBuffer(lpEnvironment); + } + + internal static Result WaitProcess(SafeFileHandle stdoutRead, SafeFileHandle stdoutWrite, SafeFileHandle stderrRead, + SafeFileHandle stderrWrite, FileStream stdinStream, byte[] stdin, IntPtr hProcess, string outputEncoding = null) + { + // Default to using UTF-8 as the output encoding, this should be a sane default for most scenarios. + outputEncoding = String.IsNullOrEmpty(outputEncoding) ? "utf-8" : outputEncoding; + Encoding encodingInstance = Encoding.GetEncoding(outputEncoding); + + FileStream stdoutFS = new FileStream(stdoutRead, FileAccess.Read, 4096); + StreamReader stdout = new StreamReader(stdoutFS, encodingInstance, true, 4096); + stdoutWrite.Close(); + + FileStream stderrFS = new FileStream(stderrRead, FileAccess.Read, 4096); + StreamReader stderr = new StreamReader(stderrFS, encodingInstance, true, 4096); + stderrWrite.Close(); + + stdinStream.Write(stdin, 0, stdin.Length); + stdinStream.Close(); + + string stdoutStr, stderrStr = null; + GetProcessOutput(stdout, stderr, out stdoutStr, out stderrStr); + UInt32 rc = GetProcessExitCode(hProcess); + + return new Result + { + StandardOut = stdoutStr, + StandardError = stderrStr, + ExitCode = rc + }; + } + + internal static void GetProcessOutput(StreamReader stdoutStream, StreamReader stderrStream, out string stdout, out string stderr) + { + var sowait = new EventWaitHandle(false, EventResetMode.ManualReset); + var sewait = new EventWaitHandle(false, EventResetMode.ManualReset); + string so = null, se = null; + ThreadPool.QueueUserWorkItem((s) => + { + so = stdoutStream.ReadToEnd(); + sowait.Set(); + }); + ThreadPool.QueueUserWorkItem((s) => + { + se = stderrStream.ReadToEnd(); + sewait.Set(); + }); + foreach (var wh in new WaitHandle[] { sowait, sewait }) + wh.WaitOne(); + stdout = so; + stderr = se; + } + + internal static UInt32 GetProcessExitCode(IntPtr processHandle) + { + SafeWaitHandle hProcess = new SafeWaitHandle(processHandle, true); + NativeMethods.WaitForSingleObject(hProcess, 0xFFFFFFFF); + + UInt32 exitCode; + if (!NativeMethods.GetExitCodeProcess(hProcess, out exitCode)) + throw new Win32Exception("GetExitCodeProcess() failed"); + return exitCode; + } + } +} diff --git a/lib/ansible/module_utils/csharp/__init__.py b/lib/ansible/module_utils/csharp/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/csharp/__init__.py diff --git a/lib/ansible/module_utils/distro/__init__.py b/lib/ansible/module_utils/distro/__init__.py new file mode 100644 index 0000000..b70f29c --- /dev/null +++ b/lib/ansible/module_utils/distro/__init__.py @@ -0,0 +1,56 @@ +# (c) 2018 Toshio Kuratomi <tkuratomi@ansible.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +''' +Compat distro library. +''' +# The following makes it easier for us to script updates of the bundled code +_BUNDLED_METADATA = {"pypi_name": "distro", "version": "1.6.0"} + +# The following additional changes have been made: +# * Remove optparse since it is not needed for our use. +# * A format string including {} has been changed to {0} (py2.6 compat) +# * Port two calls from subprocess.check_output to subprocess.Popen().communicate() (py2.6 compat) + + +import sys +import types + +try: + import distro as _system_distro +except ImportError: + _system_distro = None +else: + # There could be a 'distro' package/module that isn't what we expect, on the + # PYTHONPATH. Rather than erroring out in this case, just fall back to ours. + # We require more functions than distro.id(), but this is probably a decent + # test that we have something we can reasonably use. + if not hasattr(_system_distro, 'id') or \ + not isinstance(_system_distro.id, types.FunctionType): + _system_distro = None + +if _system_distro: + distro = _system_distro +else: + # Our bundled copy + from ansible.module_utils.distro import _distro as distro + +sys.modules['ansible.module_utils.distro'] = distro diff --git a/lib/ansible/module_utils/distro/_distro.py b/lib/ansible/module_utils/distro/_distro.py new file mode 100644 index 0000000..58e41d4 --- /dev/null +++ b/lib/ansible/module_utils/distro/_distro.py @@ -0,0 +1,1416 @@ +# Copyright 2015,2016,2017 Nir Cohen +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# A local copy of the license can be found in licenses/Apache-License.txt +# +# Modifications to this code have been made by Ansible Project + +""" +The ``distro`` package (``distro`` stands for Linux Distribution) provides +information about the Linux distribution it runs on, such as a reliable +machine-readable distro ID, or version information. + +It is the recommended replacement for Python's original +:py:func:`platform.linux_distribution` function, but it provides much more +functionality. An alternative implementation became necessary because Python +3.5 deprecated this function, and Python 3.8 removed it altogether. Its +predecessor function :py:func:`platform.dist` was already deprecated since +Python 2.6 and removed in Python 3.8. Still, there are many cases in which +access to OS distribution information is needed. See `Python issue 1322 +<https://bugs.python.org/issue1322>`_ for more information. +""" + +import logging +import os +import re +import shlex +import subprocess +import sys +import warnings + +__version__ = "1.6.0" + +# Use `if False` to avoid an ImportError on Python 2. After dropping Python 2 +# support, can use typing.TYPE_CHECKING instead. See: +# https://docs.python.org/3/library/typing.html#typing.TYPE_CHECKING +if False: # pragma: nocover + from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Sequence, + TextIO, + Tuple, + Type, + TypedDict, + Union, + ) + + VersionDict = TypedDict( + "VersionDict", {"major": str, "minor": str, "build_number": str} + ) + InfoDict = TypedDict( + "InfoDict", + { + "id": str, + "version": str, + "version_parts": VersionDict, + "like": str, + "codename": str, + }, + ) + + +_UNIXCONFDIR = os.environ.get("UNIXCONFDIR", "/etc") +_UNIXUSRLIBDIR = os.environ.get("UNIXUSRLIBDIR", "/usr/lib") +_OS_RELEASE_BASENAME = "os-release" + +#: Translation table for normalizing the "ID" attribute defined in os-release +#: files, for use by the :func:`distro.id` method. +#: +#: * Key: Value as defined in the os-release file, translated to lower case, +#: with blanks translated to underscores. +#: +#: * Value: Normalized value. +NORMALIZED_OS_ID = { + "ol": "oracle", # Oracle Linux + "opensuse-leap": "opensuse", # Newer versions of OpenSuSE report as opensuse-leap +} + +#: Translation table for normalizing the "Distributor ID" attribute returned by +#: the lsb_release command, for use by the :func:`distro.id` method. +#: +#: * Key: Value as returned by the lsb_release command, translated to lower +#: case, with blanks translated to underscores. +#: +#: * Value: Normalized value. +NORMALIZED_LSB_ID = { + "enterpriseenterpriseas": "oracle", # Oracle Enterprise Linux 4 + "enterpriseenterpriseserver": "oracle", # Oracle Linux 5 + "redhatenterpriseworkstation": "rhel", # RHEL 6, 7 Workstation + "redhatenterpriseserver": "rhel", # RHEL 6, 7 Server + "redhatenterprisecomputenode": "rhel", # RHEL 6 ComputeNode +} + +#: Translation table for normalizing the distro ID derived from the file name +#: of distro release files, for use by the :func:`distro.id` method. +#: +#: * Key: Value as derived from the file name of a distro release file, +#: translated to lower case, with blanks translated to underscores. +#: +#: * Value: Normalized value. +NORMALIZED_DISTRO_ID = { + "redhat": "rhel", # RHEL 6.x, 7.x +} + +# Pattern for content of distro release file (reversed) +_DISTRO_RELEASE_CONTENT_REVERSED_PATTERN = re.compile( + r"(?:[^)]*\)(.*)\()? *(?:STL )?([\d.+\-a-z]*\d) *(?:esaeler *)?(.+)" +) + +# Pattern for base file name of distro release file +_DISTRO_RELEASE_BASENAME_PATTERN = re.compile(r"(\w+)[-_](release|version)$") + +# Base file names to be ignored when searching for distro release file +_DISTRO_RELEASE_IGNORE_BASENAMES = ( + "debian_version", + "lsb-release", + "oem-release", + _OS_RELEASE_BASENAME, + "system-release", + "plesk-release", + "iredmail-release", +) + + +# +# Python 2.6 does not have subprocess.check_output so replicate it here +# +def _my_check_output(*popenargs, **kwargs): + r"""Run command with arguments and return its output as a byte string. + + If the exit code was non-zero it raises a CalledProcessError. The + CalledProcessError object will have the return code in the returncode + attribute and output in the output attribute. + + The arguments are the same as for the Popen constructor. Example: + + >>> check_output(["ls", "-l", "/dev/null"]) + 'crw-rw-rw- 1 root root 1, 3 Oct 18 2007 /dev/null\n' + + The stdout argument is not allowed as it is used internally. + To capture standard error in the result, use stderr=STDOUT. + + >>> check_output(["/bin/sh", "-c", + ... "ls -l non_existent_file ; exit 0"], + ... stderr=STDOUT) + 'ls: non_existent_file: No such file or directory\n' + + This is a backport of Python-2.7's check output to Python-2.6 + """ + if 'stdout' in kwargs: + raise ValueError( + 'stdout argument not allowed, it will be overridden.' + ) + process = subprocess.Popen( + stdout=subprocess.PIPE, *popenargs, **kwargs + ) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + # Deviation from Python-2.7: Python-2.6's CalledProcessError does not + # have an argument for the stdout so simply omit it. + raise subprocess.CalledProcessError(retcode, cmd) + return output + + +try: + _check_output = subprocess.check_output +except AttributeError: + _check_output = _my_check_output + + +def linux_distribution(full_distribution_name=True): + # type: (bool) -> Tuple[str, str, str] + """ + .. deprecated:: 1.6.0 + + :func:`distro.linux_distribution()` is deprecated. It should only be + used as a compatibility shim with Python's + :py:func:`platform.linux_distribution()`. Please use :func:`distro.id`, + :func:`distro.version` and :func:`distro.name` instead. + + Return information about the current OS distribution as a tuple + ``(id_name, version, codename)`` with items as follows: + + * ``id_name``: If *full_distribution_name* is false, the result of + :func:`distro.id`. Otherwise, the result of :func:`distro.name`. + + * ``version``: The result of :func:`distro.version`. + + * ``codename``: The result of :func:`distro.codename`. + + The interface of this function is compatible with the original + :py:func:`platform.linux_distribution` function, supporting a subset of + its parameters. + + The data it returns may not exactly be the same, because it uses more data + sources than the original function, and that may lead to different data if + the OS distribution is not consistent across multiple data sources it + provides (there are indeed such distributions ...). + + Another reason for differences is the fact that the :func:`distro.id` + method normalizes the distro ID string to a reliable machine-readable value + for a number of popular OS distributions. + """ + warnings.warn( + "distro.linux_distribution() is deprecated. It should only be used as a " + "compatibility shim with Python's platform.linux_distribution(). Please use " + "distro.id(), distro.version() and distro.name() instead.", + DeprecationWarning, + stacklevel=2, + ) + return _distro.linux_distribution(full_distribution_name) + + +def id(): + # type: () -> str + """ + Return the distro ID of the current distribution, as a + machine-readable string. + + For a number of OS distributions, the returned distro ID value is + *reliable*, in the sense that it is documented and that it does not change + across releases of the distribution. + + This package maintains the following reliable distro ID values: + + ============== ========================================= + Distro ID Distribution + ============== ========================================= + "ubuntu" Ubuntu + "debian" Debian + "rhel" RedHat Enterprise Linux + "centos" CentOS + "fedora" Fedora + "sles" SUSE Linux Enterprise Server + "opensuse" openSUSE + "amazon" Amazon Linux + "arch" Arch Linux + "cloudlinux" CloudLinux OS + "exherbo" Exherbo Linux + "gentoo" GenToo Linux + "ibm_powerkvm" IBM PowerKVM + "kvmibm" KVM for IBM z Systems + "linuxmint" Linux Mint + "mageia" Mageia + "mandriva" Mandriva Linux + "parallels" Parallels + "pidora" Pidora + "raspbian" Raspbian + "oracle" Oracle Linux (and Oracle Enterprise Linux) + "scientific" Scientific Linux + "slackware" Slackware + "xenserver" XenServer + "openbsd" OpenBSD + "netbsd" NetBSD + "freebsd" FreeBSD + "midnightbsd" MidnightBSD + ============== ========================================= + + If you have a need to get distros for reliable IDs added into this set, + or if you find that the :func:`distro.id` function returns a different + distro ID for one of the listed distros, please create an issue in the + `distro issue tracker`_. + + **Lookup hierarchy and transformations:** + + First, the ID is obtained from the following sources, in the specified + order. The first available and non-empty value is used: + + * the value of the "ID" attribute of the os-release file, + + * the value of the "Distributor ID" attribute returned by the lsb_release + command, + + * the first part of the file name of the distro release file, + + The so determined ID value then passes the following transformations, + before it is returned by this method: + + * it is translated to lower case, + + * blanks (which should not be there anyway) are translated to underscores, + + * a normalization of the ID is performed, based upon + `normalization tables`_. The purpose of this normalization is to ensure + that the ID is as reliable as possible, even across incompatible changes + in the OS distributions. A common reason for an incompatible change is + the addition of an os-release file, or the addition of the lsb_release + command, with ID values that differ from what was previously determined + from the distro release file name. + """ + return _distro.id() + + +def name(pretty=False): + # type: (bool) -> str + """ + Return the name of the current OS distribution, as a human-readable + string. + + If *pretty* is false, the name is returned without version or codename. + (e.g. "CentOS Linux") + + If *pretty* is true, the version and codename are appended. + (e.g. "CentOS Linux 7.1.1503 (Core)") + + **Lookup hierarchy:** + + The name is obtained from the following sources, in the specified order. + The first available and non-empty value is used: + + * If *pretty* is false: + + - the value of the "NAME" attribute of the os-release file, + + - the value of the "Distributor ID" attribute returned by the lsb_release + command, + + - the value of the "<name>" field of the distro release file. + + * If *pretty* is true: + + - the value of the "PRETTY_NAME" attribute of the os-release file, + + - the value of the "Description" attribute returned by the lsb_release + command, + + - the value of the "<name>" field of the distro release file, appended + with the value of the pretty version ("<version_id>" and "<codename>" + fields) of the distro release file, if available. + """ + return _distro.name(pretty) + + +def version(pretty=False, best=False): + # type: (bool, bool) -> str + """ + Return the version of the current OS distribution, as a human-readable + string. + + If *pretty* is false, the version is returned without codename (e.g. + "7.0"). + + If *pretty* is true, the codename in parenthesis is appended, if the + codename is non-empty (e.g. "7.0 (Maipo)"). + + Some distributions provide version numbers with different precisions in + the different sources of distribution information. Examining the different + sources in a fixed priority order does not always yield the most precise + version (e.g. for Debian 8.2, or CentOS 7.1). + + The *best* parameter can be used to control the approach for the returned + version: + + If *best* is false, the first non-empty version number in priority order of + the examined sources is returned. + + If *best* is true, the most precise version number out of all examined + sources is returned. + + **Lookup hierarchy:** + + In all cases, the version number is obtained from the following sources. + If *best* is false, this order represents the priority order: + + * the value of the "VERSION_ID" attribute of the os-release file, + * the value of the "Release" attribute returned by the lsb_release + command, + * the version number parsed from the "<version_id>" field of the first line + of the distro release file, + * the version number parsed from the "PRETTY_NAME" attribute of the + os-release file, if it follows the format of the distro release files. + * the version number parsed from the "Description" attribute returned by + the lsb_release command, if it follows the format of the distro release + files. + """ + return _distro.version(pretty, best) + + +def version_parts(best=False): + # type: (bool) -> Tuple[str, str, str] + """ + Return the version of the current OS distribution as a tuple + ``(major, minor, build_number)`` with items as follows: + + * ``major``: The result of :func:`distro.major_version`. + + * ``minor``: The result of :func:`distro.minor_version`. + + * ``build_number``: The result of :func:`distro.build_number`. + + For a description of the *best* parameter, see the :func:`distro.version` + method. + """ + return _distro.version_parts(best) + + +def major_version(best=False): + # type: (bool) -> str + """ + Return the major version of the current OS distribution, as a string, + if provided. + Otherwise, the empty string is returned. The major version is the first + part of the dot-separated version string. + + For a description of the *best* parameter, see the :func:`distro.version` + method. + """ + return _distro.major_version(best) + + +def minor_version(best=False): + # type: (bool) -> str + """ + Return the minor version of the current OS distribution, as a string, + if provided. + Otherwise, the empty string is returned. The minor version is the second + part of the dot-separated version string. + + For a description of the *best* parameter, see the :func:`distro.version` + method. + """ + return _distro.minor_version(best) + + +def build_number(best=False): + # type: (bool) -> str + """ + Return the build number of the current OS distribution, as a string, + if provided. + Otherwise, the empty string is returned. The build number is the third part + of the dot-separated version string. + + For a description of the *best* parameter, see the :func:`distro.version` + method. + """ + return _distro.build_number(best) + + +def like(): + # type: () -> str + """ + Return a space-separated list of distro IDs of distributions that are + closely related to the current OS distribution in regards to packaging + and programming interfaces, for example distributions the current + distribution is a derivative from. + + **Lookup hierarchy:** + + This information item is only provided by the os-release file. + For details, see the description of the "ID_LIKE" attribute in the + `os-release man page + <http://www.freedesktop.org/software/systemd/man/os-release.html>`_. + """ + return _distro.like() + + +def codename(): + # type: () -> str + """ + Return the codename for the release of the current OS distribution, + as a string. + + If the distribution does not have a codename, an empty string is returned. + + Note that the returned codename is not always really a codename. For + example, openSUSE returns "x86_64". This function does not handle such + cases in any special way and just returns the string it finds, if any. + + **Lookup hierarchy:** + + * the codename within the "VERSION" attribute of the os-release file, if + provided, + + * the value of the "Codename" attribute returned by the lsb_release + command, + + * the value of the "<codename>" field of the distro release file. + """ + return _distro.codename() + + +def info(pretty=False, best=False): + # type: (bool, bool) -> InfoDict + """ + Return certain machine-readable information items about the current OS + distribution in a dictionary, as shown in the following example: + + .. sourcecode:: python + + { + 'id': 'rhel', + 'version': '7.0', + 'version_parts': { + 'major': '7', + 'minor': '0', + 'build_number': '' + }, + 'like': 'fedora', + 'codename': 'Maipo' + } + + The dictionary structure and keys are always the same, regardless of which + information items are available in the underlying data sources. The values + for the various keys are as follows: + + * ``id``: The result of :func:`distro.id`. + + * ``version``: The result of :func:`distro.version`. + + * ``version_parts -> major``: The result of :func:`distro.major_version`. + + * ``version_parts -> minor``: The result of :func:`distro.minor_version`. + + * ``version_parts -> build_number``: The result of + :func:`distro.build_number`. + + * ``like``: The result of :func:`distro.like`. + + * ``codename``: The result of :func:`distro.codename`. + + For a description of the *pretty* and *best* parameters, see the + :func:`distro.version` method. + """ + return _distro.info(pretty, best) + + +def os_release_info(): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information items + from the os-release file data source of the current OS distribution. + + See `os-release file`_ for details about these information items. + """ + return _distro.os_release_info() + + +def lsb_release_info(): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information items + from the lsb_release command data source of the current OS distribution. + + See `lsb_release command output`_ for details about these information + items. + """ + return _distro.lsb_release_info() + + +def distro_release_info(): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information items + from the distro release file data source of the current OS distribution. + + See `distro release file`_ for details about these information items. + """ + return _distro.distro_release_info() + + +def uname_info(): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information items + from the distro release file data source of the current OS distribution. + """ + return _distro.uname_info() + + +def os_release_attr(attribute): + # type: (str) -> str + """ + Return a single named information item from the os-release file data source + of the current OS distribution. + + Parameters: + + * ``attribute`` (string): Key of the information item. + + Returns: + + * (string): Value of the information item, if the item exists. + The empty string, if the item does not exist. + + See `os-release file`_ for details about these information items. + """ + return _distro.os_release_attr(attribute) + + +def lsb_release_attr(attribute): + # type: (str) -> str + """ + Return a single named information item from the lsb_release command output + data source of the current OS distribution. + + Parameters: + + * ``attribute`` (string): Key of the information item. + + Returns: + + * (string): Value of the information item, if the item exists. + The empty string, if the item does not exist. + + See `lsb_release command output`_ for details about these information + items. + """ + return _distro.lsb_release_attr(attribute) + + +def distro_release_attr(attribute): + # type: (str) -> str + """ + Return a single named information item from the distro release file + data source of the current OS distribution. + + Parameters: + + * ``attribute`` (string): Key of the information item. + + Returns: + + * (string): Value of the information item, if the item exists. + The empty string, if the item does not exist. + + See `distro release file`_ for details about these information items. + """ + return _distro.distro_release_attr(attribute) + + +def uname_attr(attribute): + # type: (str) -> str + """ + Return a single named information item from the distro release file + data source of the current OS distribution. + + Parameters: + + * ``attribute`` (string): Key of the information item. + + Returns: + + * (string): Value of the information item, if the item exists. + The empty string, if the item does not exist. + """ + return _distro.uname_attr(attribute) + + +try: + from functools import cached_property +except ImportError: + # Python < 3.8 + class cached_property(object): # type: ignore + """A version of @property which caches the value. On access, it calls the + underlying function and sets the value in `__dict__` so future accesses + will not re-call the property. + """ + + def __init__(self, f): + # type: (Callable[[Any], Any]) -> None + self._fname = f.__name__ + self._f = f + + def __get__(self, obj, owner): + # type: (Any, Type[Any]) -> Any + assert obj is not None, "call {0} on an instance".format(self._fname) + ret = obj.__dict__[self._fname] = self._f(obj) + return ret + + +class LinuxDistribution(object): + """ + Provides information about a OS distribution. + + This package creates a private module-global instance of this class with + default initialization arguments, that is used by the + `consolidated accessor functions`_ and `single source accessor functions`_. + By using default initialization arguments, that module-global instance + returns data about the current OS distribution (i.e. the distro this + package runs on). + + Normally, it is not necessary to create additional instances of this class. + However, in situations where control is needed over the exact data sources + that are used, instances of this class can be created with a specific + distro release file, or a specific os-release file, or without invoking the + lsb_release command. + """ + + def __init__( + self, + include_lsb=True, + os_release_file="", + distro_release_file="", + include_uname=True, + root_dir=None, + ): + # type: (bool, str, str, bool, Optional[str]) -> None + """ + The initialization method of this class gathers information from the + available data sources, and stores that in private instance attributes. + Subsequent access to the information items uses these private instance + attributes, so that the data sources are read only once. + + Parameters: + + * ``include_lsb`` (bool): Controls whether the + `lsb_release command output`_ is included as a data source. + + If the lsb_release command is not available in the program execution + path, the data source for the lsb_release command will be empty. + + * ``os_release_file`` (string): The path name of the + `os-release file`_ that is to be used as a data source. + + An empty string (the default) will cause the default path name to + be used (see `os-release file`_ for details). + + If the specified or defaulted os-release file does not exist, the + data source for the os-release file will be empty. + + * ``distro_release_file`` (string): The path name of the + `distro release file`_ that is to be used as a data source. + + An empty string (the default) will cause a default search algorithm + to be used (see `distro release file`_ for details). + + If the specified distro release file does not exist, or if no default + distro release file can be found, the data source for the distro + release file will be empty. + + * ``include_uname`` (bool): Controls whether uname command output is + included as a data source. If the uname command is not available in + the program execution path the data source for the uname command will + be empty. + + * ``root_dir`` (string): The absolute path to the root directory to use + to find distro-related information files. + + Public instance attributes: + + * ``os_release_file`` (string): The path name of the + `os-release file`_ that is actually used as a data source. The + empty string if no distro release file is used as a data source. + + * ``distro_release_file`` (string): The path name of the + `distro release file`_ that is actually used as a data source. The + empty string if no distro release file is used as a data source. + + * ``include_lsb`` (bool): The result of the ``include_lsb`` parameter. + This controls whether the lsb information will be loaded. + + * ``include_uname`` (bool): The result of the ``include_uname`` + parameter. This controls whether the uname information will + be loaded. + + Raises: + + * :py:exc:`IOError`: Some I/O issue with an os-release file or distro + release file. + + * :py:exc:`subprocess.CalledProcessError`: The lsb_release command had + some issue (other than not being available in the program execution + path). + + * :py:exc:`UnicodeError`: A data source has unexpected characters or + uses an unexpected encoding. + """ + self.root_dir = root_dir + self.etc_dir = os.path.join(root_dir, "etc") if root_dir else _UNIXCONFDIR + self.usr_lib_dir = ( + os.path.join(root_dir, "usr/lib") if root_dir else _UNIXUSRLIBDIR + ) + + if os_release_file: + self.os_release_file = os_release_file + else: + etc_dir_os_release_file = os.path.join(self.etc_dir, _OS_RELEASE_BASENAME) + usr_lib_os_release_file = os.path.join( + self.usr_lib_dir, _OS_RELEASE_BASENAME + ) + + # NOTE: The idea is to respect order **and** have it set + # at all times for API backwards compatibility. + if os.path.isfile(etc_dir_os_release_file) or not os.path.isfile( + usr_lib_os_release_file + ): + self.os_release_file = etc_dir_os_release_file + else: + self.os_release_file = usr_lib_os_release_file + + self.distro_release_file = distro_release_file or "" # updated later + self.include_lsb = include_lsb + self.include_uname = include_uname + + def __repr__(self): + # type: () -> str + """Return repr of all info""" + return ( + "LinuxDistribution(" + "os_release_file={self.os_release_file!r}, " + "distro_release_file={self.distro_release_file!r}, " + "include_lsb={self.include_lsb!r}, " + "include_uname={self.include_uname!r}, " + "_os_release_info={self._os_release_info!r}, " + "_lsb_release_info={self._lsb_release_info!r}, " + "_distro_release_info={self._distro_release_info!r}, " + "_uname_info={self._uname_info!r})".format(self=self) + ) + + def linux_distribution(self, full_distribution_name=True): + # type: (bool) -> Tuple[str, str, str] + """ + Return information about the OS distribution that is compatible + with Python's :func:`platform.linux_distribution`, supporting a subset + of its parameters. + + For details, see :func:`distro.linux_distribution`. + """ + return ( + self.name() if full_distribution_name else self.id(), + self.version(), + self.codename(), + ) + + def id(self): + # type: () -> str + """Return the distro ID of the OS distribution, as a string. + + For details, see :func:`distro.id`. + """ + + def normalize(distro_id, table): + # type: (str, Dict[str, str]) -> str + distro_id = distro_id.lower().replace(" ", "_") + return table.get(distro_id, distro_id) + + distro_id = self.os_release_attr("id") + if distro_id: + return normalize(distro_id, NORMALIZED_OS_ID) + + distro_id = self.lsb_release_attr("distributor_id") + if distro_id: + return normalize(distro_id, NORMALIZED_LSB_ID) + + distro_id = self.distro_release_attr("id") + if distro_id: + return normalize(distro_id, NORMALIZED_DISTRO_ID) + + distro_id = self.uname_attr("id") + if distro_id: + return normalize(distro_id, NORMALIZED_DISTRO_ID) + + return "" + + def name(self, pretty=False): + # type: (bool) -> str + """ + Return the name of the OS distribution, as a string. + + For details, see :func:`distro.name`. + """ + name = ( + self.os_release_attr("name") + or self.lsb_release_attr("distributor_id") + or self.distro_release_attr("name") + or self.uname_attr("name") + ) + if pretty: + name = self.os_release_attr("pretty_name") or self.lsb_release_attr( + "description" + ) + if not name: + name = self.distro_release_attr("name") or self.uname_attr("name") + version = self.version(pretty=True) + if version: + name = name + " " + version + return name or "" + + def version(self, pretty=False, best=False): + # type: (bool, bool) -> str + """ + Return the version of the OS distribution, as a string. + + For details, see :func:`distro.version`. + """ + versions = [ + self.os_release_attr("version_id"), + self.lsb_release_attr("release"), + self.distro_release_attr("version_id"), + self._parse_distro_release_content(self.os_release_attr("pretty_name")).get( + "version_id", "" + ), + self._parse_distro_release_content( + self.lsb_release_attr("description") + ).get("version_id", ""), + self.uname_attr("release"), + ] + version = "" + if best: + # This algorithm uses the last version in priority order that has + # the best precision. If the versions are not in conflict, that + # does not matter; otherwise, using the last one instead of the + # first one might be considered a surprise. + for v in versions: + if v.count(".") > version.count(".") or version == "": + version = v + else: + for v in versions: + if v != "": + version = v + break + if pretty and version and self.codename(): + version = "{0} ({1})".format(version, self.codename()) + return version + + def version_parts(self, best=False): + # type: (bool) -> Tuple[str, str, str] + """ + Return the version of the OS distribution, as a tuple of version + numbers. + + For details, see :func:`distro.version_parts`. + """ + version_str = self.version(best=best) + if version_str: + version_regex = re.compile(r"(\d+)\.?(\d+)?\.?(\d+)?") + matches = version_regex.match(version_str) + if matches: + major, minor, build_number = matches.groups() + return major, minor or "", build_number or "" + return "", "", "" + + def major_version(self, best=False): + # type: (bool) -> str + """ + Return the major version number of the current distribution. + + For details, see :func:`distro.major_version`. + """ + return self.version_parts(best)[0] + + def minor_version(self, best=False): + # type: (bool) -> str + """ + Return the minor version number of the current distribution. + + For details, see :func:`distro.minor_version`. + """ + return self.version_parts(best)[1] + + def build_number(self, best=False): + # type: (bool) -> str + """ + Return the build number of the current distribution. + + For details, see :func:`distro.build_number`. + """ + return self.version_parts(best)[2] + + def like(self): + # type: () -> str + """ + Return the IDs of distributions that are like the OS distribution. + + For details, see :func:`distro.like`. + """ + return self.os_release_attr("id_like") or "" + + def codename(self): + # type: () -> str + """ + Return the codename of the OS distribution. + + For details, see :func:`distro.codename`. + """ + try: + # Handle os_release specially since distros might purposefully set + # this to empty string to have no codename + return self._os_release_info["codename"] + except KeyError: + return ( + self.lsb_release_attr("codename") + or self.distro_release_attr("codename") + or "" + ) + + def info(self, pretty=False, best=False): + # type: (bool, bool) -> InfoDict + """ + Return certain machine-readable information about the OS + distribution. + + For details, see :func:`distro.info`. + """ + return dict( + id=self.id(), + version=self.version(pretty, best), + version_parts=dict( + major=self.major_version(best), + minor=self.minor_version(best), + build_number=self.build_number(best), + ), + like=self.like(), + codename=self.codename(), + ) + + def os_release_info(self): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information + items from the os-release file data source of the OS distribution. + + For details, see :func:`distro.os_release_info`. + """ + return self._os_release_info + + def lsb_release_info(self): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information + items from the lsb_release command data source of the OS + distribution. + + For details, see :func:`distro.lsb_release_info`. + """ + return self._lsb_release_info + + def distro_release_info(self): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information + items from the distro release file data source of the OS + distribution. + + For details, see :func:`distro.distro_release_info`. + """ + return self._distro_release_info + + def uname_info(self): + # type: () -> Dict[str, str] + """ + Return a dictionary containing key-value pairs for the information + items from the uname command data source of the OS distribution. + + For details, see :func:`distro.uname_info`. + """ + return self._uname_info + + def os_release_attr(self, attribute): + # type: (str) -> str + """ + Return a single named information item from the os-release file data + source of the OS distribution. + + For details, see :func:`distro.os_release_attr`. + """ + return self._os_release_info.get(attribute, "") + + def lsb_release_attr(self, attribute): + # type: (str) -> str + """ + Return a single named information item from the lsb_release command + output data source of the OS distribution. + + For details, see :func:`distro.lsb_release_attr`. + """ + return self._lsb_release_info.get(attribute, "") + + def distro_release_attr(self, attribute): + # type: (str) -> str + """ + Return a single named information item from the distro release file + data source of the OS distribution. + + For details, see :func:`distro.distro_release_attr`. + """ + return self._distro_release_info.get(attribute, "") + + def uname_attr(self, attribute): + # type: (str) -> str + """ + Return a single named information item from the uname command + output data source of the OS distribution. + + For details, see :func:`distro.uname_attr`. + """ + return self._uname_info.get(attribute, "") + + @cached_property + def _os_release_info(self): + # type: () -> Dict[str, str] + """ + Get the information items from the specified os-release file. + + Returns: + A dictionary containing all information items. + """ + if os.path.isfile(self.os_release_file): + with open(self.os_release_file) as release_file: + return self._parse_os_release_content(release_file) + return {} + + @staticmethod + def _parse_os_release_content(lines): + # type: (TextIO) -> Dict[str, str] + """ + Parse the lines of an os-release file. + + Parameters: + + * lines: Iterable through the lines in the os-release file. + Each line must be a unicode string or a UTF-8 encoded byte + string. + + Returns: + A dictionary containing all information items. + """ + props = {} + lexer = shlex.shlex(lines, posix=True) + lexer.whitespace_split = True + + # The shlex module defines its `wordchars` variable using literals, + # making it dependent on the encoding of the Python source file. + # In Python 2.6 and 2.7, the shlex source file is encoded in + # 'iso-8859-1', and the `wordchars` variable is defined as a byte + # string. This causes a UnicodeDecodeError to be raised when the + # parsed content is a unicode object. The following fix resolves that + # (... but it should be fixed in shlex...): + if sys.version_info[0] == 2 and isinstance(lexer.wordchars, bytes): + lexer.wordchars = lexer.wordchars.decode("iso-8859-1") + + tokens = list(lexer) + for token in tokens: + # At this point, all shell-like parsing has been done (i.e. + # comments processed, quotes and backslash escape sequences + # processed, multi-line values assembled, trailing newlines + # stripped, etc.), so the tokens are now either: + # * variable assignments: var=value + # * commands or their arguments (not allowed in os-release) + if "=" in token: + k, v = token.split("=", 1) + props[k.lower()] = v + else: + # Ignore any tokens that are not variable assignments + pass + + if "version_codename" in props: + # os-release added a version_codename field. Use that in + # preference to anything else Note that some distros purposefully + # do not have code names. They should be setting + # version_codename="" + props["codename"] = props["version_codename"] + elif "ubuntu_codename" in props: + # Same as above but a non-standard field name used on older Ubuntus + props["codename"] = props["ubuntu_codename"] + elif "version" in props: + # If there is no version_codename, parse it from the version + match = re.search(r"(\(\D+\))|,(\s+)?\D+", props["version"]) + if match: + codename = match.group() + codename = codename.strip("()") + codename = codename.strip(",") + codename = codename.strip() + # codename appears within paranthese. + props["codename"] = codename + + return props + + @cached_property + def _lsb_release_info(self): + # type: () -> Dict[str, str] + """ + Get the information items from the lsb_release command output. + + Returns: + A dictionary containing all information items. + """ + if not self.include_lsb: + return {} + with open(os.devnull, "wb") as devnull: + try: + cmd = ("lsb_release", "-a") + stdout = _check_output(cmd, stderr=devnull) + # Command not found or lsb_release returned error + except (OSError, subprocess.CalledProcessError): + return {} + content = self._to_str(stdout).splitlines() + return self._parse_lsb_release_content(content) + + @staticmethod + def _parse_lsb_release_content(lines): + # type: (Iterable[str]) -> Dict[str, str] + """ + Parse the output of the lsb_release command. + + Parameters: + + * lines: Iterable through the lines of the lsb_release output. + Each line must be a unicode string or a UTF-8 encoded byte + string. + + Returns: + A dictionary containing all information items. + """ + props = {} + for line in lines: + kv = line.strip("\n").split(":", 1) + if len(kv) != 2: + # Ignore lines without colon. + continue + k, v = kv + props.update({k.replace(" ", "_").lower(): v.strip()}) + return props + + @cached_property + def _uname_info(self): + # type: () -> Dict[str, str] + with open(os.devnull, "wb") as devnull: + try: + cmd = ("uname", "-rs") + stdout = _check_output(cmd, stderr=devnull) + except OSError: + return {} + content = self._to_str(stdout).splitlines() + return self._parse_uname_content(content) + + @staticmethod + def _parse_uname_content(lines): + # type: (Sequence[str]) -> Dict[str, str] + props = {} + match = re.search(r"^([^\s]+)\s+([\d\.]+)", lines[0].strip()) + if match: + name, version = match.groups() + + # This is to prevent the Linux kernel version from + # appearing as the 'best' version on otherwise + # identifiable distributions. + if name == "Linux": + return {} + props["id"] = name.lower() + props["name"] = name + props["release"] = version + return props + + @staticmethod + def _to_str(text): + # type: (Union[bytes, str]) -> str + encoding = sys.getfilesystemencoding() + encoding = "utf-8" if encoding == "ascii" else encoding + + if sys.version_info[0] >= 3: + if isinstance(text, bytes): + return text.decode(encoding) + else: + if isinstance(text, unicode): # noqa pylint: disable=undefined-variable + return text.encode(encoding) + + return text + + @cached_property + def _distro_release_info(self): + # type: () -> Dict[str, str] + """ + Get the information items from the specified distro release file. + + Returns: + A dictionary containing all information items. + """ + if self.distro_release_file: + # If it was specified, we use it and parse what we can, even if + # its file name or content does not match the expected pattern. + distro_info = self._parse_distro_release_file(self.distro_release_file) + basename = os.path.basename(self.distro_release_file) + # The file name pattern for user-specified distro release files + # is somewhat more tolerant (compared to when searching for the + # file), because we want to use what was specified as best as + # possible. + match = _DISTRO_RELEASE_BASENAME_PATTERN.match(basename) + if "name" in distro_info and "cloudlinux" in distro_info["name"].lower(): + distro_info["id"] = "cloudlinux" + elif match: + distro_info["id"] = match.group(1) + return distro_info + else: + try: + basenames = os.listdir(self.etc_dir) + # We sort for repeatability in cases where there are multiple + # distro specific files; e.g. CentOS, Oracle, Enterprise all + # containing `redhat-release` on top of their own. + basenames.sort() + except OSError: + # This may occur when /etc is not readable but we can't be + # sure about the *-release files. Check common entries of + # /etc for information. If they turn out to not be there the + # error is handled in `_parse_distro_release_file()`. + basenames = [ + "SuSE-release", + "arch-release", + "base-release", + "centos-release", + "fedora-release", + "gentoo-release", + "mageia-release", + "mandrake-release", + "mandriva-release", + "mandrivalinux-release", + "manjaro-release", + "oracle-release", + "redhat-release", + "sl-release", + "slackware-version", + ] + for basename in basenames: + if basename in _DISTRO_RELEASE_IGNORE_BASENAMES: + continue + match = _DISTRO_RELEASE_BASENAME_PATTERN.match(basename) + if match: + filepath = os.path.join(self.etc_dir, basename) + distro_info = self._parse_distro_release_file(filepath) + if "name" in distro_info: + # The name is always present if the pattern matches + self.distro_release_file = filepath + distro_info["id"] = match.group(1) + if "cloudlinux" in distro_info["name"].lower(): + distro_info["id"] = "cloudlinux" + return distro_info + return {} + + def _parse_distro_release_file(self, filepath): + # type: (str) -> Dict[str, str] + """ + Parse a distro release file. + + Parameters: + + * filepath: Path name of the distro release file. + + Returns: + A dictionary containing all information items. + """ + try: + with open(filepath) as fp: + # Only parse the first line. For instance, on SLES there + # are multiple lines. We don't want them... + return self._parse_distro_release_content(fp.readline()) + except (OSError, IOError): + # Ignore not being able to read a specific, seemingly version + # related file. + # See https://github.com/python-distro/distro/issues/162 + return {} + + @staticmethod + def _parse_distro_release_content(line): + # type: (str) -> Dict[str, str] + """ + Parse a line from a distro release file. + + Parameters: + * line: Line from the distro release file. Must be a unicode string + or a UTF-8 encoded byte string. + + Returns: + A dictionary containing all information items. + """ + matches = _DISTRO_RELEASE_CONTENT_REVERSED_PATTERN.match(line.strip()[::-1]) + distro_info = {} + if matches: + # regexp ensures non-None + distro_info["name"] = matches.group(3)[::-1] + if matches.group(2): + distro_info["version_id"] = matches.group(2)[::-1] + if matches.group(1): + distro_info["codename"] = matches.group(1)[::-1] + elif line: + distro_info["name"] = line.strip() + return distro_info + + +_distro = LinuxDistribution() + + +def main(): + # type: () -> None + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler(sys.stdout)) + + dist = _distro + + logger.info("Name: %s", dist.name(pretty=True)) + distribution_version = dist.version(pretty=True) + logger.info("Version: %s", distribution_version) + distribution_codename = dist.codename() + logger.info("Codename: %s", distribution_codename) + + +if __name__ == "__main__": + main() diff --git a/lib/ansible/module_utils/errors.py b/lib/ansible/module_utils/errors.py new file mode 100644 index 0000000..cbbd86c --- /dev/null +++ b/lib/ansible/module_utils/errors.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2021 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +class AnsibleFallbackNotFound(Exception): + """Fallback validator was not found""" + + +class AnsibleValidationError(Exception): + """Single argument spec validation error""" + + def __init__(self, message): + super(AnsibleValidationError, self).__init__(message) + self.error_message = message + """The error message passed in when the exception was raised.""" + + @property + def msg(self): + """The error message passed in when the exception was raised.""" + return self.args[0] + + +class AnsibleValidationErrorMultiple(AnsibleValidationError): + """Multiple argument spec validation errors""" + + def __init__(self, errors=None): + self.errors = errors[:] if errors else [] + """:class:`list` of :class:`AnsibleValidationError` objects""" + + def __getitem__(self, key): + return self.errors[key] + + def __setitem__(self, key, value): + self.errors[key] = value + + def __delitem__(self, key): + del self.errors[key] + + @property + def msg(self): + """The first message from the first error in ``errors``.""" + return self.errors[0].args[0] + + @property + def messages(self): + """:class:`list` of each error message in ``errors``.""" + return [err.msg for err in self.errors] + + def append(self, error): + """Append a new error to ``self.errors``. + + Only :class:`AnsibleValidationError` should be added. + """ + + self.errors.append(error) + + def extend(self, errors): + """Append each item in ``errors`` to ``self.errors``. Only :class:`AnsibleValidationError` should be added.""" + self.errors.extend(errors) + + +class AliasError(AnsibleValidationError): + """Error handling aliases""" + + +class ArgumentTypeError(AnsibleValidationError): + """Error with parameter type""" + + +class ArgumentValueError(AnsibleValidationError): + """Error with parameter value""" + + +class DeprecationError(AnsibleValidationError): + """Error processing parameter deprecations""" + + +class ElementError(AnsibleValidationError): + """Error when validating elements""" + + +class MutuallyExclusiveError(AnsibleValidationError): + """Mutually exclusive parameters were supplied""" + + +class NoLogError(AnsibleValidationError): + """Error converting no_log values""" + + +class RequiredByError(AnsibleValidationError): + """Error with parameters that are required by other parameters""" + + +class RequiredDefaultError(AnsibleValidationError): + """A required parameter was assigned a default value""" + + +class RequiredError(AnsibleValidationError): + """Missing a required parameter""" + + +class RequiredIfError(AnsibleValidationError): + """Error with conditionally required parameters""" + + +class RequiredOneOfError(AnsibleValidationError): + """Error with parameters where at least one is required""" + + +class RequiredTogetherError(AnsibleValidationError): + """Error with parameters that are required together""" + + +class SubParameterTypeError(AnsibleValidationError): + """Incorrect type for subparameter""" + + +class UnsupportedError(AnsibleValidationError): + """Unsupported parameters were supplied""" diff --git a/lib/ansible/module_utils/facts/__init__.py b/lib/ansible/module_utils/facts/__init__.py new file mode 100644 index 0000000..96ab778 --- /dev/null +++ b/lib/ansible/module_utils/facts/__init__.py @@ -0,0 +1,34 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +# import from the compat api because 2.0-2.3 had a module_utils.facts.ansible_facts +# and get_all_facts in top level namespace +from ansible.module_utils.facts.compat import ansible_facts, get_all_facts # noqa diff --git a/lib/ansible/module_utils/facts/ansible_collector.py b/lib/ansible/module_utils/facts/ansible_collector.py new file mode 100644 index 0000000..e9bafe2 --- /dev/null +++ b/lib/ansible/module_utils/facts/ansible_collector.py @@ -0,0 +1,158 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import fnmatch +import sys + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts import timeout +from ansible.module_utils.facts import collector +from ansible.module_utils.common.collections import is_string + + +class AnsibleFactCollector(collector.BaseFactCollector): + '''A FactCollector that returns results under 'ansible_facts' top level key. + + If a namespace if provided, facts will be collected under that namespace. + For ex, a ansible.module_utils.facts.namespace.PrefixFactNamespace(prefix='ansible_') + + Has a 'from_gather_subset() constructor that populates collectors based on a + gather_subset specifier.''' + + def __init__(self, collectors=None, namespace=None, filter_spec=None): + + super(AnsibleFactCollector, self).__init__(collectors=collectors, + namespace=namespace) + + self.filter_spec = filter_spec + + def _filter(self, facts_dict, filter_spec): + # assume filter_spec='' or filter_spec=[] is equivalent to filter_spec='*' + if not filter_spec or filter_spec == '*': + return facts_dict + + if is_string(filter_spec): + filter_spec = [filter_spec] + + found = [] + for f in filter_spec: + for x, y in facts_dict.items(): + if not f or fnmatch.fnmatch(x, f): + found.append((x, y)) + elif not f.startswith(('ansible_', 'facter', 'ohai')): + # try to match with ansible_ prefix added when non empty + g = 'ansible_%s' % f + if fnmatch.fnmatch(x, g): + found.append((x, y)) + return found + + def collect(self, module=None, collected_facts=None): + collected_facts = collected_facts or {} + + facts_dict = {} + + for collector_obj in self.collectors: + info_dict = {} + + try: + + # Note: this collects with namespaces, so collected_facts also includes namespaces + info_dict = collector_obj.collect_with_namespace(module=module, + collected_facts=collected_facts) + except Exception as e: + sys.stderr.write(repr(e)) + sys.stderr.write('\n') + + # shallow copy of the new facts to pass to each collector in collected_facts so facts + # can reference other facts they depend on. + collected_facts.update(info_dict.copy()) + + # NOTE: If we want complicated fact dict merging, this is where it would hook in + facts_dict.update(self._filter(info_dict, self.filter_spec)) + + return facts_dict + + +class CollectorMetaDataCollector(collector.BaseFactCollector): + '''Collector that provides a facts with the gather_subset metadata.''' + + name = 'gather_subset' + _fact_ids = set() # type: t.Set[str] + + def __init__(self, collectors=None, namespace=None, gather_subset=None, module_setup=None): + super(CollectorMetaDataCollector, self).__init__(collectors, namespace) + self.gather_subset = gather_subset + self.module_setup = module_setup + + def collect(self, module=None, collected_facts=None): + meta_facts = {'gather_subset': self.gather_subset} + if self.module_setup: + meta_facts['module_setup'] = self.module_setup + return meta_facts + + +def get_ansible_collector(all_collector_classes, + namespace=None, + filter_spec=None, + gather_subset=None, + gather_timeout=None, + minimal_gather_subset=None): + + filter_spec = filter_spec or [] + gather_subset = gather_subset or ['all'] + gather_timeout = gather_timeout or timeout.DEFAULT_GATHER_TIMEOUT + minimal_gather_subset = minimal_gather_subset or frozenset() + + collector_classes = \ + collector.collector_classes_from_gather_subset( + all_collector_classes=all_collector_classes, + minimal_gather_subset=minimal_gather_subset, + gather_subset=gather_subset, + gather_timeout=gather_timeout) + + collectors = [] + for collector_class in collector_classes: + collector_obj = collector_class(namespace=namespace) + collectors.append(collector_obj) + + # Add a collector that knows what gather_subset we used so it it can provide a fact + collector_meta_data_collector = \ + CollectorMetaDataCollector(gather_subset=gather_subset, + module_setup=True) + collectors.append(collector_meta_data_collector) + + fact_collector = \ + AnsibleFactCollector(collectors=collectors, + filter_spec=filter_spec, + namespace=namespace) + + return fact_collector diff --git a/lib/ansible/module_utils/facts/collector.py b/lib/ansible/module_utils/facts/collector.py new file mode 100644 index 0000000..ac52fe8 --- /dev/null +++ b/lib/ansible/module_utils/facts/collector.py @@ -0,0 +1,402 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from collections import defaultdict + +import platform + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts import timeout + + +class CycleFoundInFactDeps(Exception): + '''Indicates there is a cycle in fact collector deps + + If collector-B requires collector-A, and collector-A requires + collector-B, that is a cycle. In that case, there is no ordering + that will satisfy B before A and A and before B. That will cause this + error to be raised. + ''' + pass + + +class UnresolvedFactDep(ValueError): + pass + + +class CollectorNotFoundError(KeyError): + pass + + +class BaseFactCollector: + _fact_ids = set() # type: t.Set[str] + + _platform = 'Generic' + name = None # type: str | None + required_facts = set() # type: t.Set[str] + + def __init__(self, collectors=None, namespace=None): + '''Base class for things that collect facts. + + 'collectors' is an optional list of other FactCollectors for composing.''' + self.collectors = collectors or [] + + # self.namespace is a object with a 'transform' method that transforms + # the name to indicate the namespace (ie, adds a prefix or suffix). + self.namespace = namespace + + self.fact_ids = set([self.name]) + self.fact_ids.update(self._fact_ids) + + @classmethod + def platform_match(cls, platform_info): + if platform_info.get('system', None) == cls._platform: + return cls + return None + + def _transform_name(self, key_name): + if self.namespace: + return self.namespace.transform(key_name) + return key_name + + def _transform_dict_keys(self, fact_dict): + '''update a dicts keys to use new names as transformed by self._transform_name''' + + for old_key in list(fact_dict.keys()): + new_key = self._transform_name(old_key) + # pop the item by old_key and replace it using new_key + fact_dict[new_key] = fact_dict.pop(old_key) + return fact_dict + + # TODO/MAYBE: rename to 'collect' and add 'collect_without_namespace' + def collect_with_namespace(self, module=None, collected_facts=None): + # collect, then transform the key names if needed + facts_dict = self.collect(module=module, collected_facts=collected_facts) + if self.namespace: + facts_dict = self._transform_dict_keys(facts_dict) + return facts_dict + + def collect(self, module=None, collected_facts=None): + '''do the fact collection + + 'collected_facts' is a object (a dict, likely) that holds all previously + facts. This is intended to be used if a FactCollector needs to reference + another fact (for ex, the system arch) and should not be modified (usually). + + Returns a dict of facts. + + ''' + facts_dict = {} + return facts_dict + + +def get_collector_names(valid_subsets=None, + minimal_gather_subset=None, + gather_subset=None, + aliases_map=None, + platform_info=None): + '''return a set of FactCollector names based on gather_subset spec. + + gather_subset is a spec describing which facts to gather. + valid_subsets is a frozenset of potential matches for gather_subset ('all', 'network') etc + minimal_gather_subsets is a frozenset of matches to always use, even for gather_subset='!all' + ''' + + # Retrieve module parameters + gather_subset = gather_subset or ['all'] + + # the list of everything that 'all' expands to + valid_subsets = valid_subsets or frozenset() + + # if provided, minimal_gather_subset is always added, even after all negations + minimal_gather_subset = minimal_gather_subset or frozenset() + + aliases_map = aliases_map or defaultdict(set) + + # Retrieve all facts elements + additional_subsets = set() + exclude_subsets = set() + + # total always starts with the min set, then + # adds of the additions in gather_subset, then + # excludes all of the excludes, then add any explicitly + # requested subsets. + gather_subset_with_min = ['min'] + gather_subset_with_min.extend(gather_subset) + + # subsets we mention in gather_subset explicitly, except for 'all'/'min' + explicitly_added = set() + + for subset in gather_subset_with_min: + subset_id = subset + if subset_id == 'min': + additional_subsets.update(minimal_gather_subset) + continue + if subset_id == 'all': + additional_subsets.update(valid_subsets) + continue + if subset_id.startswith('!'): + subset = subset[1:] + if subset == 'min': + exclude_subsets.update(minimal_gather_subset) + continue + if subset == 'all': + exclude_subsets.update(valid_subsets - minimal_gather_subset) + continue + exclude = True + else: + exclude = False + + if exclude: + # include 'devices', 'dmi' etc for '!hardware' + exclude_subsets.update(aliases_map.get(subset, set())) + exclude_subsets.add(subset) + else: + # NOTE: this only considers adding an unknown gather subsetup an error. Asking to + # exclude an unknown gather subset is ignored. + if subset_id not in valid_subsets: + raise TypeError("Bad subset '%s' given to Ansible. gather_subset options allowed: all, %s" % + (subset, ", ".join(sorted(valid_subsets)))) + + explicitly_added.add(subset) + additional_subsets.add(subset) + + if not additional_subsets: + additional_subsets.update(valid_subsets) + + additional_subsets.difference_update(exclude_subsets - explicitly_added) + + return additional_subsets + + +def find_collectors_for_platform(all_collector_classes, compat_platforms): + found_collectors = set() + found_collectors_names = set() + + # start from specific platform, then try generic + for compat_platform in compat_platforms: + platform_match = None + for all_collector_class in all_collector_classes: + + # ask the class if it is compatible with the platform info + platform_match = all_collector_class.platform_match(compat_platform) + + if not platform_match: + continue + + primary_name = all_collector_class.name + + if primary_name not in found_collectors_names: + found_collectors.add(all_collector_class) + found_collectors_names.add(all_collector_class.name) + + return found_collectors + + +def build_fact_id_to_collector_map(collectors_for_platform): + fact_id_to_collector_map = defaultdict(list) + aliases_map = defaultdict(set) + + for collector_class in collectors_for_platform: + primary_name = collector_class.name + + fact_id_to_collector_map[primary_name].append(collector_class) + + for fact_id in collector_class._fact_ids: + fact_id_to_collector_map[fact_id].append(collector_class) + aliases_map[primary_name].add(fact_id) + + return fact_id_to_collector_map, aliases_map + + +def select_collector_classes(collector_names, all_fact_subsets): + seen_collector_classes = set() + + selected_collector_classes = [] + + for collector_name in collector_names: + collector_classes = all_fact_subsets.get(collector_name, []) + for collector_class in collector_classes: + if collector_class not in seen_collector_classes: + selected_collector_classes.append(collector_class) + seen_collector_classes.add(collector_class) + + return selected_collector_classes + + +def _get_requires_by_collector_name(collector_name, all_fact_subsets): + required_facts = set() + + try: + collector_classes = all_fact_subsets[collector_name] + except KeyError: + raise CollectorNotFoundError('Fact collector "%s" not found' % collector_name) + for collector_class in collector_classes: + required_facts.update(collector_class.required_facts) + return required_facts + + +def find_unresolved_requires(collector_names, all_fact_subsets): + '''Find any collector names that have unresolved requires + + Returns a list of collector names that correspond to collector + classes whose .requires_facts() are not in collector_names. + ''' + unresolved = set() + + for collector_name in collector_names: + required_facts = _get_requires_by_collector_name(collector_name, all_fact_subsets) + for required_fact in required_facts: + if required_fact not in collector_names: + unresolved.add(required_fact) + + return unresolved + + +def resolve_requires(unresolved_requires, all_fact_subsets): + new_names = set() + failed = [] + for unresolved in unresolved_requires: + if unresolved in all_fact_subsets: + new_names.add(unresolved) + else: + failed.append(unresolved) + + if failed: + raise UnresolvedFactDep('unresolved fact dep %s' % ','.join(failed)) + return new_names + + +def build_dep_data(collector_names, all_fact_subsets): + dep_map = defaultdict(set) + for collector_name in collector_names: + collector_deps = set() + for collector in all_fact_subsets[collector_name]: + for dep in collector.required_facts: + collector_deps.add(dep) + dep_map[collector_name] = collector_deps + return dep_map + + +def tsort(dep_map): + sorted_list = [] + + unsorted_map = dep_map.copy() + + while unsorted_map: + acyclic = False + for node, edges in list(unsorted_map.items()): + for edge in edges: + if edge in unsorted_map: + break + else: + acyclic = True + del unsorted_map[node] + sorted_list.append((node, edges)) + + if not acyclic: + raise CycleFoundInFactDeps('Unable to tsort deps, there was a cycle in the graph. sorted=%s' % sorted_list) + + return sorted_list + + +def _solve_deps(collector_names, all_fact_subsets): + unresolved = collector_names.copy() + solutions = collector_names.copy() + + while True: + unresolved = find_unresolved_requires(solutions, all_fact_subsets) + if unresolved == set(): + break + + new_names = resolve_requires(unresolved, all_fact_subsets) + solutions.update(new_names) + + return solutions + + +def collector_classes_from_gather_subset(all_collector_classes=None, + valid_subsets=None, + minimal_gather_subset=None, + gather_subset=None, + gather_timeout=None, + platform_info=None): + '''return a list of collector classes that match the args''' + + # use gather_name etc to get the list of collectors + + all_collector_classes = all_collector_classes or [] + + minimal_gather_subset = minimal_gather_subset or frozenset() + + platform_info = platform_info or {'system': platform.system()} + + gather_timeout = gather_timeout or timeout.DEFAULT_GATHER_TIMEOUT + + # tweak the modules GATHER_TIMEOUT + timeout.GATHER_TIMEOUT = gather_timeout + + valid_subsets = valid_subsets or frozenset() + + # maps alias names like 'hardware' to the list of names that are part of hardware + # like 'devices' and 'dmi' + aliases_map = defaultdict(set) + + compat_platforms = [platform_info, {'system': 'Generic'}] + + collectors_for_platform = find_collectors_for_platform(all_collector_classes, compat_platforms) + + # all_facts_subsets maps the subset name ('hardware') to the class that provides it. + + # TODO: name collisions here? are there facts with the same name as a gather_subset (all, network, hardware, virtual, ohai, facter) + all_fact_subsets, aliases_map = build_fact_id_to_collector_map(collectors_for_platform) + + all_valid_subsets = frozenset(all_fact_subsets.keys()) + + # expand any fact_id/collectorname/gather_subset term ('all', 'env', etc) to the list of names that represents + collector_names = get_collector_names(valid_subsets=all_valid_subsets, + minimal_gather_subset=minimal_gather_subset, + gather_subset=gather_subset, + aliases_map=aliases_map, + platform_info=platform_info) + + complete_collector_names = _solve_deps(collector_names, all_fact_subsets) + + dep_map = build_dep_data(complete_collector_names, all_fact_subsets) + + ordered_deps = tsort(dep_map) + ordered_collector_names = [x[0] for x in ordered_deps] + + selected_collector_classes = select_collector_classes(ordered_collector_names, + all_fact_subsets) + + return selected_collector_classes diff --git a/lib/ansible/module_utils/facts/compat.py b/lib/ansible/module_utils/facts/compat.py new file mode 100644 index 0000000..a69fee3 --- /dev/null +++ b/lib/ansible/module_utils/facts/compat.py @@ -0,0 +1,87 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.namespace import PrefixFactNamespace +from ansible.module_utils.facts import default_collectors +from ansible.module_utils.facts import ansible_collector + + +def get_all_facts(module): + '''compat api for ansible 2.2/2.3 module_utils.facts.get_all_facts method + + Expects module to be an instance of AnsibleModule, with a 'gather_subset' param. + + returns a dict mapping the bare fact name ('default_ipv4' with no 'ansible_' namespace) to + the fact value.''' + + gather_subset = module.params['gather_subset'] + return ansible_facts(module, gather_subset=gather_subset) + + +def ansible_facts(module, gather_subset=None): + '''Compat api for ansible 2.0/2.2/2.3 module_utils.facts.ansible_facts method + + 2.3/2.3 expects a gather_subset arg. + 2.0/2.1 does not except a gather_subset arg + + So make gather_subsets an optional arg, defaulting to configured DEFAULT_GATHER_TIMEOUT + + 'module' should be an instance of an AnsibleModule. + + returns a dict mapping the bare fact name ('default_ipv4' with no 'ansible_' namespace) to + the fact value. + ''' + + gather_subset = gather_subset or module.params.get('gather_subset', ['all']) + gather_timeout = module.params.get('gather_timeout', 10) + filter_spec = module.params.get('filter', '*') + + minimal_gather_subset = frozenset(['apparmor', 'caps', 'cmdline', 'date_time', + 'distribution', 'dns', 'env', 'fips', 'local', + 'lsb', 'pkg_mgr', 'platform', 'python', 'selinux', + 'service_mgr', 'ssh_pub_keys', 'user']) + + all_collector_classes = default_collectors.collectors + + # don't add a prefix + namespace = PrefixFactNamespace(namespace_name='ansible', prefix='') + + fact_collector = \ + ansible_collector.get_ansible_collector(all_collector_classes=all_collector_classes, + namespace=namespace, + filter_spec=filter_spec, + gather_subset=gather_subset, + gather_timeout=gather_timeout, + minimal_gather_subset=minimal_gather_subset) + + facts_dict = fact_collector.collect(module=module) + + return facts_dict diff --git a/lib/ansible/module_utils/facts/default_collectors.py b/lib/ansible/module_utils/facts/default_collectors.py new file mode 100644 index 0000000..cf0ef23 --- /dev/null +++ b/lib/ansible/module_utils/facts/default_collectors.py @@ -0,0 +1,177 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + +from ansible.module_utils.facts.other.facter import FacterFactCollector +from ansible.module_utils.facts.other.ohai import OhaiFactCollector + +from ansible.module_utils.facts.system.apparmor import ApparmorFactCollector +from ansible.module_utils.facts.system.caps import SystemCapabilitiesFactCollector +from ansible.module_utils.facts.system.chroot import ChrootFactCollector +from ansible.module_utils.facts.system.cmdline import CmdLineFactCollector +from ansible.module_utils.facts.system.distribution import DistributionFactCollector +from ansible.module_utils.facts.system.date_time import DateTimeFactCollector +from ansible.module_utils.facts.system.env import EnvFactCollector +from ansible.module_utils.facts.system.dns import DnsFactCollector +from ansible.module_utils.facts.system.fips import FipsFactCollector +from ansible.module_utils.facts.system.loadavg import LoadAvgFactCollector +from ansible.module_utils.facts.system.local import LocalFactCollector +from ansible.module_utils.facts.system.lsb import LSBFactCollector +from ansible.module_utils.facts.system.pkg_mgr import PkgMgrFactCollector +from ansible.module_utils.facts.system.pkg_mgr import OpenBSDPkgMgrFactCollector +from ansible.module_utils.facts.system.platform import PlatformFactCollector +from ansible.module_utils.facts.system.python import PythonFactCollector +from ansible.module_utils.facts.system.selinux import SelinuxFactCollector +from ansible.module_utils.facts.system.service_mgr import ServiceMgrFactCollector +from ansible.module_utils.facts.system.ssh_pub_keys import SshPubKeyFactCollector +from ansible.module_utils.facts.system.user import UserFactCollector + +from ansible.module_utils.facts.hardware.base import HardwareCollector +from ansible.module_utils.facts.hardware.aix import AIXHardwareCollector +from ansible.module_utils.facts.hardware.darwin import DarwinHardwareCollector +from ansible.module_utils.facts.hardware.dragonfly import DragonFlyHardwareCollector +from ansible.module_utils.facts.hardware.freebsd import FreeBSDHardwareCollector +from ansible.module_utils.facts.hardware.hpux import HPUXHardwareCollector +from ansible.module_utils.facts.hardware.hurd import HurdHardwareCollector +from ansible.module_utils.facts.hardware.linux import LinuxHardwareCollector +from ansible.module_utils.facts.hardware.netbsd import NetBSDHardwareCollector +from ansible.module_utils.facts.hardware.openbsd import OpenBSDHardwareCollector +from ansible.module_utils.facts.hardware.sunos import SunOSHardwareCollector + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.aix import AIXNetworkCollector +from ansible.module_utils.facts.network.darwin import DarwinNetworkCollector +from ansible.module_utils.facts.network.dragonfly import DragonFlyNetworkCollector +from ansible.module_utils.facts.network.fc_wwn import FcWwnInitiatorFactCollector +from ansible.module_utils.facts.network.freebsd import FreeBSDNetworkCollector +from ansible.module_utils.facts.network.hpux import HPUXNetworkCollector +from ansible.module_utils.facts.network.hurd import HurdNetworkCollector +from ansible.module_utils.facts.network.linux import LinuxNetworkCollector +from ansible.module_utils.facts.network.iscsi import IscsiInitiatorNetworkCollector +from ansible.module_utils.facts.network.nvme import NvmeInitiatorNetworkCollector +from ansible.module_utils.facts.network.netbsd import NetBSDNetworkCollector +from ansible.module_utils.facts.network.openbsd import OpenBSDNetworkCollector +from ansible.module_utils.facts.network.sunos import SunOSNetworkCollector + +from ansible.module_utils.facts.virtual.base import VirtualCollector +from ansible.module_utils.facts.virtual.dragonfly import DragonFlyVirtualCollector +from ansible.module_utils.facts.virtual.freebsd import FreeBSDVirtualCollector +from ansible.module_utils.facts.virtual.hpux import HPUXVirtualCollector +from ansible.module_utils.facts.virtual.linux import LinuxVirtualCollector +from ansible.module_utils.facts.virtual.netbsd import NetBSDVirtualCollector +from ansible.module_utils.facts.virtual.openbsd import OpenBSDVirtualCollector +from ansible.module_utils.facts.virtual.sunos import SunOSVirtualCollector + +# these should always be first due to most other facts depending on them +_base = [ + PlatformFactCollector, + DistributionFactCollector, + LSBFactCollector +] # type: t.List[t.Type[BaseFactCollector]] + +# These restrict what is possible in others +_restrictive = [ + SelinuxFactCollector, + ApparmorFactCollector, + ChrootFactCollector, + FipsFactCollector +] # type: t.List[t.Type[BaseFactCollector]] + +# general info, not required but probably useful for other facts +_general = [ + PythonFactCollector, + SystemCapabilitiesFactCollector, + PkgMgrFactCollector, + OpenBSDPkgMgrFactCollector, + ServiceMgrFactCollector, + CmdLineFactCollector, + DateTimeFactCollector, + EnvFactCollector, + LoadAvgFactCollector, + SshPubKeyFactCollector, + UserFactCollector +] # type: t.List[t.Type[BaseFactCollector]] + +# virtual, this might also limit hardware/networking +_virtual = [ + VirtualCollector, + DragonFlyVirtualCollector, + FreeBSDVirtualCollector, + LinuxVirtualCollector, + OpenBSDVirtualCollector, + NetBSDVirtualCollector, + SunOSVirtualCollector, + HPUXVirtualCollector +] # type: t.List[t.Type[BaseFactCollector]] + +_hardware = [ + HardwareCollector, + AIXHardwareCollector, + DarwinHardwareCollector, + DragonFlyHardwareCollector, + FreeBSDHardwareCollector, + HPUXHardwareCollector, + HurdHardwareCollector, + LinuxHardwareCollector, + NetBSDHardwareCollector, + OpenBSDHardwareCollector, + SunOSHardwareCollector +] # type: t.List[t.Type[BaseFactCollector]] + +_network = [ + DnsFactCollector, + FcWwnInitiatorFactCollector, + NetworkCollector, + AIXNetworkCollector, + DarwinNetworkCollector, + DragonFlyNetworkCollector, + FreeBSDNetworkCollector, + HPUXNetworkCollector, + HurdNetworkCollector, + IscsiInitiatorNetworkCollector, + NvmeInitiatorNetworkCollector, + LinuxNetworkCollector, + NetBSDNetworkCollector, + OpenBSDNetworkCollector, + SunOSNetworkCollector +] # type: t.List[t.Type[BaseFactCollector]] + +# other fact sources +_extra_facts = [ + LocalFactCollector, + FacterFactCollector, + OhaiFactCollector +] # type: t.List[t.Type[BaseFactCollector]] + +# TODO: make config driven +collectors = _base + _restrictive + _general + _virtual + _hardware + _network + _extra_facts diff --git a/lib/ansible/module_utils/facts/hardware/__init__.py b/lib/ansible/module_utils/facts/hardware/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/__init__.py diff --git a/lib/ansible/module_utils/facts/hardware/aix.py b/lib/ansible/module_utils/facts/hardware/aix.py new file mode 100644 index 0000000..dc37394 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/aix.py @@ -0,0 +1,266 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector +from ansible.module_utils.facts.utils import get_mount_size + + +class AIXHardware(Hardware): + """ + AIX-specific subclass of Hardware. Defines memory and CPU facts: + - memfree_mb + - memtotal_mb + - swapfree_mb + - swaptotal_mb + - processor (a list) + - processor_count + - processor_cores + - processor_threads_per_core + - processor_vcpus + """ + platform = 'AIX' + + def populate(self, collected_facts=None): + hardware_facts = {} + + cpu_facts = self.get_cpu_facts() + memory_facts = self.get_memory_facts() + dmi_facts = self.get_dmi_facts() + vgs_facts = self.get_vgs_facts() + mount_facts = self.get_mount_facts() + devices_facts = self.get_device_facts() + + hardware_facts.update(cpu_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(dmi_facts) + hardware_facts.update(vgs_facts) + hardware_facts.update(mount_facts) + hardware_facts.update(devices_facts) + + return hardware_facts + + def get_cpu_facts(self): + cpu_facts = {} + cpu_facts['processor'] = [] + + # FIXME: not clear how to detect multi-sockets + cpu_facts['processor_count'] = 1 + rc, out, err = self.module.run_command( + "/usr/sbin/lsdev -Cc processor" + ) + if out: + i = 0 + for line in out.splitlines(): + + if 'Available' in line: + if i == 0: + data = line.split(' ') + cpudev = data[0] + + i += 1 + cpu_facts['processor_cores'] = int(i) + + rc, out, err = self.module.run_command( + "/usr/sbin/lsattr -El " + cpudev + " -a type" + ) + + data = out.split(' ') + cpu_facts['processor'] = [data[1]] + + cpu_facts['processor_threads_per_core'] = 1 + rc, out, err = self.module.run_command( + "/usr/sbin/lsattr -El " + cpudev + " -a smt_threads" + ) + if out: + data = out.split(' ') + cpu_facts['processor_threads_per_core'] = int(data[1]) + cpu_facts['processor_vcpus'] = ( + cpu_facts['processor_cores'] * cpu_facts['processor_threads_per_core'] + ) + + return cpu_facts + + def get_memory_facts(self): + memory_facts = {} + pagesize = 4096 + rc, out, err = self.module.run_command("/usr/bin/vmstat -v") + for line in out.splitlines(): + data = line.split() + if 'memory pages' in line: + pagecount = int(data[0]) + if 'free pages' in line: + freecount = int(data[0]) + memory_facts['memtotal_mb'] = pagesize * pagecount // 1024 // 1024 + memory_facts['memfree_mb'] = pagesize * freecount // 1024 // 1024 + # Get swapinfo. swapinfo output looks like: + # Device 1M-blocks Used Avail Capacity + # /dev/ada0p3 314368 0 314368 0% + # + rc, out, err = self.module.run_command("/usr/sbin/lsps -s") + if out: + lines = out.splitlines() + data = lines[1].split() + swaptotal_mb = int(data[0].rstrip('MB')) + percused = int(data[1].rstrip('%')) + memory_facts['swaptotal_mb'] = swaptotal_mb + memory_facts['swapfree_mb'] = int(swaptotal_mb * (100 - percused) / 100) + + return memory_facts + + def get_dmi_facts(self): + dmi_facts = {} + + rc, out, err = self.module.run_command("/usr/sbin/lsattr -El sys0 -a fwversion") + data = out.split() + dmi_facts['firmware_version'] = data[1].strip('IBM,') + lsconf_path = self.module.get_bin_path("lsconf") + if lsconf_path: + rc, out, err = self.module.run_command(lsconf_path) + if rc == 0 and out: + for line in out.splitlines(): + data = line.split(':') + if 'Machine Serial Number' in line: + dmi_facts['product_serial'] = data[1].strip() + if 'LPAR Info' in line: + dmi_facts['lpar_info'] = data[1].strip() + if 'System Model' in line: + dmi_facts['product_name'] = data[1].strip() + return dmi_facts + + def get_vgs_facts(self): + """ + Get vg and pv Facts + rootvg: + PV_NAME PV STATE TOTAL PPs FREE PPs FREE DISTRIBUTION + hdisk0 active 546 0 00..00..00..00..00 + hdisk1 active 546 113 00..00..00..21..92 + realsyncvg: + PV_NAME PV STATE TOTAL PPs FREE PPs FREE DISTRIBUTION + hdisk74 active 1999 6 00..00..00..00..06 + testvg: + PV_NAME PV STATE TOTAL PPs FREE PPs FREE DISTRIBUTION + hdisk105 active 999 838 200..39..199..200..200 + hdisk106 active 999 599 200..00..00..199..200 + """ + + vgs_facts = {} + lsvg_path = self.module.get_bin_path("lsvg") + xargs_path = self.module.get_bin_path("xargs") + cmd = "%s -o | %s %s -p" % (lsvg_path, xargs_path, lsvg_path) + if lsvg_path and xargs_path: + rc, out, err = self.module.run_command(cmd, use_unsafe_shell=True) + if rc == 0 and out: + vgs_facts['vgs'] = {} + for m in re.finditer(r'(\S+):\n.*FREE DISTRIBUTION(\n(\S+)\s+(\w+)\s+(\d+)\s+(\d+).*)+', out): + vgs_facts['vgs'][m.group(1)] = [] + pp_size = 0 + cmd = "%s %s" % (lsvg_path, m.group(1)) + rc, out, err = self.module.run_command(cmd) + if rc == 0 and out: + pp_size = re.search(r'PP SIZE:\s+(\d+\s+\S+)', out).group(1) + for n in re.finditer(r'(\S+)\s+(\w+)\s+(\d+)\s+(\d+).*', m.group(0)): + pv_info = {'pv_name': n.group(1), + 'pv_state': n.group(2), + 'total_pps': n.group(3), + 'free_pps': n.group(4), + 'pp_size': pp_size + } + vgs_facts['vgs'][m.group(1)].append(pv_info) + + return vgs_facts + + def get_mount_facts(self): + mount_facts = {} + + mount_facts['mounts'] = [] + + mounts = [] + + # AIX does not have mtab but mount command is only source of info (or to use + # api calls to get same info) + mount_path = self.module.get_bin_path('mount') + rc, mount_out, err = self.module.run_command(mount_path) + if mount_out: + for line in mount_out.split('\n'): + fields = line.split() + if len(fields) != 0 and fields[0] != 'node' and fields[0][0] != '-' and re.match('^/.*|^[a-zA-Z].*|^[0-9].*', fields[0]): + if re.match('^/', fields[0]): + # normal mount + mount = fields[1] + mount_info = {'mount': mount, + 'device': fields[0], + 'fstype': fields[2], + 'options': fields[6], + 'time': '%s %s %s' % (fields[3], fields[4], fields[5])} + mount_info.update(get_mount_size(mount)) + else: + # nfs or cifs based mount + # in case of nfs if no mount options are provided on command line + # add into fields empty string... + if len(fields) < 8: + fields.append("") + + mount_info = {'mount': fields[2], + 'device': '%s:%s' % (fields[0], fields[1]), + 'fstype': fields[3], + 'options': fields[7], + 'time': '%s %s %s' % (fields[4], fields[5], fields[6])} + + mounts.append(mount_info) + + mount_facts['mounts'] = mounts + + return mount_facts + + def get_device_facts(self): + device_facts = {} + device_facts['devices'] = {} + + lsdev_cmd = self.module.get_bin_path('lsdev', True) + lsattr_cmd = self.module.get_bin_path('lsattr', True) + rc, out_lsdev, err = self.module.run_command(lsdev_cmd) + + for line in out_lsdev.splitlines(): + field = line.split() + + device_attrs = {} + device_name = field[0] + device_state = field[1] + device_type = field[2:] + lsattr_cmd_args = [lsattr_cmd, '-E', '-l', device_name] + rc, out_lsattr, err = self.module.run_command(lsattr_cmd_args) + for attr in out_lsattr.splitlines(): + attr_fields = attr.split() + attr_name = attr_fields[0] + attr_parameter = attr_fields[1] + device_attrs[attr_name] = attr_parameter + + device_facts['devices'][device_name] = { + 'state': device_state, + 'type': ' '.join(device_type), + 'attributes': device_attrs + } + + return device_facts + + +class AIXHardwareCollector(HardwareCollector): + _platform = 'AIX' + _fact_class = AIXHardware diff --git a/lib/ansible/module_utils/facts/hardware/base.py b/lib/ansible/module_utils/facts/hardware/base.py new file mode 100644 index 0000000..846bb30 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/base.py @@ -0,0 +1,68 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class Hardware: + platform = 'Generic' + + # FIXME: remove load_on_init when we can + def __init__(self, module, load_on_init=False): + self.module = module + + def populate(self, collected_facts=None): + return {} + + +class HardwareCollector(BaseFactCollector): + name = 'hardware' + _fact_ids = set(['processor', + 'processor_cores', + 'processor_count', + # TODO: mounts isnt exactly hardware + 'mounts', + 'devices']) # type: t.Set[str] + _fact_class = Hardware + + def collect(self, module=None, collected_facts=None): + collected_facts = collected_facts or {} + if not module: + return {} + + # Network munges cached_facts by side effect, so give it a copy + facts_obj = self._fact_class(module) + + facts_dict = facts_obj.populate(collected_facts=collected_facts) + + return facts_dict diff --git a/lib/ansible/module_utils/facts/hardware/darwin.py b/lib/ansible/module_utils/facts/hardware/darwin.py new file mode 100644 index 0000000..d6a8e11 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/darwin.py @@ -0,0 +1,159 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import struct +import time + +from ansible.module_utils.common.process import get_bin_path +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector +from ansible.module_utils.facts.sysctl import get_sysctl + + +class DarwinHardware(Hardware): + """ + Darwin-specific subclass of Hardware. Defines memory and CPU facts: + - processor + - processor_cores + - memtotal_mb + - memfree_mb + - model + - osversion + - osrevision + - uptime_seconds + """ + platform = 'Darwin' + + def populate(self, collected_facts=None): + hardware_facts = {} + + self.sysctl = get_sysctl(self.module, ['hw', 'machdep', 'kern']) + mac_facts = self.get_mac_facts() + cpu_facts = self.get_cpu_facts() + memory_facts = self.get_memory_facts() + uptime_facts = self.get_uptime_facts() + + hardware_facts.update(mac_facts) + hardware_facts.update(cpu_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(uptime_facts) + + return hardware_facts + + def get_system_profile(self): + rc, out, err = self.module.run_command(["/usr/sbin/system_profiler", "SPHardwareDataType"]) + if rc != 0: + return dict() + system_profile = dict() + for line in out.splitlines(): + if ': ' in line: + (key, value) = line.split(': ', 1) + system_profile[key.strip()] = ' '.join(value.strip().split()) + return system_profile + + def get_mac_facts(self): + mac_facts = {} + rc, out, err = self.module.run_command("sysctl hw.model") + if rc == 0: + mac_facts['model'] = mac_facts['product_name'] = out.splitlines()[-1].split()[1] + mac_facts['osversion'] = self.sysctl['kern.osversion'] + mac_facts['osrevision'] = self.sysctl['kern.osrevision'] + + return mac_facts + + def get_cpu_facts(self): + cpu_facts = {} + if 'machdep.cpu.brand_string' in self.sysctl: # Intel + cpu_facts['processor'] = self.sysctl['machdep.cpu.brand_string'] + cpu_facts['processor_cores'] = self.sysctl['machdep.cpu.core_count'] + else: # PowerPC + system_profile = self.get_system_profile() + cpu_facts['processor'] = '%s @ %s' % (system_profile['Processor Name'], system_profile['Processor Speed']) + cpu_facts['processor_cores'] = self.sysctl['hw.physicalcpu'] + cpu_facts['processor_vcpus'] = self.sysctl.get('hw.logicalcpu') or self.sysctl.get('hw.ncpu') or '' + + return cpu_facts + + def get_memory_facts(self): + memory_facts = { + 'memtotal_mb': int(self.sysctl['hw.memsize']) // 1024 // 1024, + 'memfree_mb': 0, + } + + total_used = 0 + page_size = 4096 + try: + vm_stat_command = get_bin_path('vm_stat') + except ValueError: + return memory_facts + + rc, out, err = self.module.run_command(vm_stat_command) + if rc == 0: + # Free = Total - (Wired + active + inactive) + # Get a generator of tuples from the command output so we can later + # turn it into a dictionary + memory_stats = (line.rstrip('.').split(':', 1) for line in out.splitlines()) + + # Strip extra left spaces from the value + memory_stats = dict((k, v.lstrip()) for k, v in memory_stats) + + for k, v in memory_stats.items(): + try: + memory_stats[k] = int(v) + except ValueError: + # Most values convert cleanly to integer values but if the field does + # not convert to an integer, just leave it alone. + pass + + if memory_stats.get('Pages wired down'): + total_used += memory_stats['Pages wired down'] * page_size + if memory_stats.get('Pages active'): + total_used += memory_stats['Pages active'] * page_size + if memory_stats.get('Pages inactive'): + total_used += memory_stats['Pages inactive'] * page_size + + memory_facts['memfree_mb'] = memory_facts['memtotal_mb'] - (total_used // 1024 // 1024) + + return memory_facts + + def get_uptime_facts(self): + # On Darwin, the default format is annoying to parse. + # Use -b to get the raw value and decode it. + sysctl_cmd = self.module.get_bin_path('sysctl') + cmd = [sysctl_cmd, '-b', 'kern.boottime'] + + # We need to get raw bytes, not UTF-8. + rc, out, err = self.module.run_command(cmd, encoding=None) + + # kern.boottime returns seconds and microseconds as two 64-bits + # fields, but we are only interested in the first field. + struct_format = '@L' + struct_size = struct.calcsize(struct_format) + if rc != 0 or len(out) < struct_size: + return {} + + (kern_boottime, ) = struct.unpack(struct_format, out[:struct_size]) + + return { + 'uptime_seconds': int(time.time() - kern_boottime), + } + + +class DarwinHardwareCollector(HardwareCollector): + _fact_class = DarwinHardware + _platform = 'Darwin' diff --git a/lib/ansible/module_utils/facts/hardware/dragonfly.py b/lib/ansible/module_utils/facts/hardware/dragonfly.py new file mode 100644 index 0000000..ea24151 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/dragonfly.py @@ -0,0 +1,26 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.hardware.base import HardwareCollector +from ansible.module_utils.facts.hardware.freebsd import FreeBSDHardware + + +class DragonFlyHardwareCollector(HardwareCollector): + # Note: This uses the freebsd fact class, there is no dragonfly hardware fact class + _fact_class = FreeBSDHardware + _platform = 'DragonFly' diff --git a/lib/ansible/module_utils/facts/hardware/freebsd.py b/lib/ansible/module_utils/facts/hardware/freebsd.py new file mode 100644 index 0000000..cce2ab2 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/freebsd.py @@ -0,0 +1,241 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import json +import re +import struct +import time + +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector +from ansible.module_utils.facts.timeout import TimeoutError, timeout + +from ansible.module_utils.facts.utils import get_file_content, get_mount_size + + +class FreeBSDHardware(Hardware): + """ + FreeBSD-specific subclass of Hardware. Defines memory and CPU facts: + - memfree_mb + - memtotal_mb + - swapfree_mb + - swaptotal_mb + - processor (a list) + - processor_cores + - processor_count + - devices + - uptime_seconds + """ + platform = 'FreeBSD' + DMESG_BOOT = '/var/run/dmesg.boot' + + def populate(self, collected_facts=None): + hardware_facts = {} + + cpu_facts = self.get_cpu_facts() + memory_facts = self.get_memory_facts() + uptime_facts = self.get_uptime_facts() + dmi_facts = self.get_dmi_facts() + device_facts = self.get_device_facts() + + mount_facts = {} + try: + mount_facts = self.get_mount_facts() + except TimeoutError: + pass + + hardware_facts.update(cpu_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(uptime_facts) + hardware_facts.update(dmi_facts) + hardware_facts.update(device_facts) + hardware_facts.update(mount_facts) + + return hardware_facts + + def get_cpu_facts(self): + cpu_facts = {} + cpu_facts['processor'] = [] + sysctl = self.module.get_bin_path('sysctl') + if sysctl: + rc, out, err = self.module.run_command("%s -n hw.ncpu" % sysctl, check_rc=False) + cpu_facts['processor_count'] = out.strip() + + dmesg_boot = get_file_content(FreeBSDHardware.DMESG_BOOT) + if not dmesg_boot: + try: + rc, dmesg_boot, err = self.module.run_command(self.module.get_bin_path("dmesg"), check_rc=False) + except Exception: + dmesg_boot = '' + + for line in dmesg_boot.splitlines(): + if 'CPU:' in line: + cpu = re.sub(r'CPU:\s+', r"", line) + cpu_facts['processor'].append(cpu.strip()) + if 'Logical CPUs per core' in line: + cpu_facts['processor_cores'] = line.split()[4] + + return cpu_facts + + def get_memory_facts(self): + memory_facts = {} + + sysctl = self.module.get_bin_path('sysctl') + if sysctl: + rc, out, err = self.module.run_command("%s vm.stats" % sysctl, check_rc=False) + for line in out.splitlines(): + data = line.split() + if 'vm.stats.vm.v_page_size' in line: + pagesize = int(data[1]) + if 'vm.stats.vm.v_page_count' in line: + pagecount = int(data[1]) + if 'vm.stats.vm.v_free_count' in line: + freecount = int(data[1]) + memory_facts['memtotal_mb'] = pagesize * pagecount // 1024 // 1024 + memory_facts['memfree_mb'] = pagesize * freecount // 1024 // 1024 + + swapinfo = self.module.get_bin_path('swapinfo') + if swapinfo: + # Get swapinfo. swapinfo output looks like: + # Device 1M-blocks Used Avail Capacity + # /dev/ada0p3 314368 0 314368 0% + # + rc, out, err = self.module.run_command("%s -k" % swapinfo) + lines = out.splitlines() + if len(lines[-1]) == 0: + lines.pop() + data = lines[-1].split() + if data[0] != 'Device': + memory_facts['swaptotal_mb'] = int(data[1]) // 1024 + memory_facts['swapfree_mb'] = int(data[3]) // 1024 + + return memory_facts + + def get_uptime_facts(self): + # On FreeBSD, the default format is annoying to parse. + # Use -b to get the raw value and decode it. + sysctl_cmd = self.module.get_bin_path('sysctl') + cmd = [sysctl_cmd, '-b', 'kern.boottime'] + + # We need to get raw bytes, not UTF-8. + rc, out, err = self.module.run_command(cmd, encoding=None) + + # kern.boottime returns seconds and microseconds as two 64-bits + # fields, but we are only interested in the first field. + struct_format = '@L' + struct_size = struct.calcsize(struct_format) + if rc != 0 or len(out) < struct_size: + return {} + + (kern_boottime, ) = struct.unpack(struct_format, out[:struct_size]) + + return { + 'uptime_seconds': int(time.time() - kern_boottime), + } + + @timeout() + def get_mount_facts(self): + mount_facts = {} + + mount_facts['mounts'] = [] + fstab = get_file_content('/etc/fstab') + if fstab: + for line in fstab.splitlines(): + if line.startswith('#') or line.strip() == '': + continue + fields = re.sub(r'\s+', ' ', line).split() + mount_statvfs_info = get_mount_size(fields[1]) + mount_info = {'mount': fields[1], + 'device': fields[0], + 'fstype': fields[2], + 'options': fields[3]} + mount_info.update(mount_statvfs_info) + mount_facts['mounts'].append(mount_info) + + return mount_facts + + def get_device_facts(self): + device_facts = {} + + sysdir = '/dev' + device_facts['devices'] = {} + drives = re.compile(r'(ada?\d+|da\d+|a?cd\d+)') # TODO: rc, disks, err = self.module.run_command("/sbin/sysctl kern.disks") + slices = re.compile(r'(ada?\d+s\d+\w*|da\d+s\d+\w*)') + if os.path.isdir(sysdir): + dirlist = sorted(os.listdir(sysdir)) + for device in dirlist: + d = drives.match(device) + if d: + device_facts['devices'][d.group(1)] = [] + s = slices.match(device) + if s: + device_facts['devices'][d.group(1)].append(s.group(1)) + + return device_facts + + def get_dmi_facts(self): + ''' learn dmi facts from system + + Use dmidecode executable if available''' + + dmi_facts = {} + + # Fall back to using dmidecode, if available + dmi_bin = self.module.get_bin_path('dmidecode') + DMI_DICT = { + 'bios_date': 'bios-release-date', + 'bios_vendor': 'bios-vendor', + 'bios_version': 'bios-version', + 'board_asset_tag': 'baseboard-asset-tag', + 'board_name': 'baseboard-product-name', + 'board_serial': 'baseboard-serial-number', + 'board_vendor': 'baseboard-manufacturer', + 'board_version': 'baseboard-version', + 'chassis_asset_tag': 'chassis-asset-tag', + 'chassis_serial': 'chassis-serial-number', + 'chassis_vendor': 'chassis-manufacturer', + 'chassis_version': 'chassis-version', + 'form_factor': 'chassis-type', + 'product_name': 'system-product-name', + 'product_serial': 'system-serial-number', + 'product_uuid': 'system-uuid', + 'product_version': 'system-version', + 'system_vendor': 'system-manufacturer', + } + for (k, v) in DMI_DICT.items(): + if dmi_bin is not None: + (rc, out, err) = self.module.run_command('%s -s %s' % (dmi_bin, v)) + if rc == 0: + # Strip out commented lines (specific dmidecode output) + # FIXME: why add the fact and then test if it is json? + dmi_facts[k] = ''.join([line for line in out.splitlines() if not line.startswith('#')]) + try: + json.dumps(dmi_facts[k]) + except UnicodeDecodeError: + dmi_facts[k] = 'NA' + else: + dmi_facts[k] = 'NA' + else: + dmi_facts[k] = 'NA' + + return dmi_facts + + +class FreeBSDHardwareCollector(HardwareCollector): + _fact_class = FreeBSDHardware + _platform = 'FreeBSD' diff --git a/lib/ansible/module_utils/facts/hardware/hpux.py b/lib/ansible/module_utils/facts/hardware/hpux.py new file mode 100644 index 0000000..ae72ed8 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/hpux.py @@ -0,0 +1,165 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re + +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector + + +class HPUXHardware(Hardware): + """ + HP-UX-specific subclass of Hardware. Defines memory and CPU facts: + - memfree_mb + - memtotal_mb + - swapfree_mb + - swaptotal_mb + - processor + - processor_cores + - processor_count + - model + - firmware + """ + + platform = 'HP-UX' + + def populate(self, collected_facts=None): + hardware_facts = {} + + cpu_facts = self.get_cpu_facts(collected_facts=collected_facts) + memory_facts = self.get_memory_facts() + hw_facts = self.get_hw_facts() + + hardware_facts.update(cpu_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(hw_facts) + + return hardware_facts + + def get_cpu_facts(self, collected_facts=None): + cpu_facts = {} + collected_facts = collected_facts or {} + + if collected_facts.get('ansible_architecture') in ['9000/800', '9000/785']: + rc, out, err = self.module.run_command("ioscan -FkCprocessor | wc -l", use_unsafe_shell=True) + cpu_facts['processor_count'] = int(out.strip()) + # Working with machinfo mess + elif collected_facts.get('ansible_architecture') == 'ia64': + if collected_facts.get('ansible_distribution_version') == "B.11.23": + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep 'Number of CPUs'", use_unsafe_shell=True) + if out: + cpu_facts['processor_count'] = int(out.strip().split('=')[1]) + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep 'processor family'", use_unsafe_shell=True) + if out: + cpu_facts['processor'] = re.search('.*(Intel.*)', out).groups()[0].strip() + rc, out, err = self.module.run_command("ioscan -FkCprocessor | wc -l", use_unsafe_shell=True) + cpu_facts['processor_cores'] = int(out.strip()) + if collected_facts.get('ansible_distribution_version') == "B.11.31": + # if machinfo return cores strings release B.11.31 > 1204 + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep core | wc -l", use_unsafe_shell=True) + if out.strip() == '0': + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep Intel", use_unsafe_shell=True) + cpu_facts['processor_count'] = int(out.strip().split(" ")[0]) + # If hyperthreading is active divide cores by 2 + rc, out, err = self.module.run_command("/usr/sbin/psrset | grep LCPU", use_unsafe_shell=True) + data = re.sub(' +', ' ', out).strip().split(' ') + if len(data) == 1: + hyperthreading = 'OFF' + else: + hyperthreading = data[1] + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep logical", use_unsafe_shell=True) + data = out.strip().split(" ") + if hyperthreading == 'ON': + cpu_facts['processor_cores'] = int(data[0]) / 2 + else: + if len(data) == 1: + cpu_facts['processor_cores'] = cpu_facts['processor_count'] + else: + cpu_facts['processor_cores'] = int(data[0]) + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep Intel |cut -d' ' -f4-", use_unsafe_shell=True) + cpu_facts['processor'] = out.strip() + else: + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | egrep 'socket[s]?$' | tail -1", use_unsafe_shell=True) + cpu_facts['processor_count'] = int(out.strip().split(" ")[0]) + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep -e '[0-9] core' | tail -1", use_unsafe_shell=True) + cpu_facts['processor_cores'] = int(out.strip().split(" ")[0]) + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep Intel", use_unsafe_shell=True) + cpu_facts['processor'] = out.strip() + + return cpu_facts + + def get_memory_facts(self, collected_facts=None): + memory_facts = {} + collected_facts = collected_facts or {} + + pagesize = 4096 + rc, out, err = self.module.run_command("/usr/bin/vmstat | tail -1", use_unsafe_shell=True) + data = int(re.sub(' +', ' ', out).split(' ')[5].strip()) + memory_facts['memfree_mb'] = pagesize * data // 1024 // 1024 + if collected_facts.get('ansible_architecture') in ['9000/800', '9000/785']: + try: + rc, out, err = self.module.run_command("grep Physical /var/adm/syslog/syslog.log") + data = re.search('.*Physical: ([0-9]*) Kbytes.*', out).groups()[0].strip() + memory_facts['memtotal_mb'] = int(data) // 1024 + except AttributeError: + # For systems where memory details aren't sent to syslog or the log has rotated, use parsed + # adb output. Unfortunately /dev/kmem doesn't have world-read, so this only works as root. + if os.access("/dev/kmem", os.R_OK): + rc, out, err = self.module.run_command("echo 'phys_mem_pages/D' | adb -k /stand/vmunix /dev/kmem | tail -1 | awk '{print $2}'", + use_unsafe_shell=True) + if not err: + data = out + memory_facts['memtotal_mb'] = int(data) / 256 + else: + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo | grep Memory", use_unsafe_shell=True) + data = re.search(r'Memory[\ :=]*([0-9]*).*MB.*', out).groups()[0].strip() + memory_facts['memtotal_mb'] = int(data) + rc, out, err = self.module.run_command("/usr/sbin/swapinfo -m -d -f -q") + memory_facts['swaptotal_mb'] = int(out.strip()) + rc, out, err = self.module.run_command("/usr/sbin/swapinfo -m -d -f | egrep '^dev|^fs'", use_unsafe_shell=True) + swap = 0 + for line in out.strip().splitlines(): + swap += int(re.sub(' +', ' ', line).split(' ')[3].strip()) + memory_facts['swapfree_mb'] = swap + + return memory_facts + + def get_hw_facts(self, collected_facts=None): + hw_facts = {} + collected_facts = collected_facts or {} + + rc, out, err = self.module.run_command("model") + hw_facts['model'] = out.strip() + if collected_facts.get('ansible_architecture') == 'ia64': + separator = ':' + if collected_facts.get('ansible_distribution_version') == "B.11.23": + separator = '=' + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo |grep -i 'Firmware revision' | grep -v BMC", use_unsafe_shell=True) + hw_facts['firmware_version'] = out.split(separator)[1].strip() + rc, out, err = self.module.run_command("/usr/contrib/bin/machinfo |grep -i 'Machine serial number' ", use_unsafe_shell=True) + if rc == 0 and out: + hw_facts['product_serial'] = out.split(separator)[1].strip() + + return hw_facts + + +class HPUXHardwareCollector(HardwareCollector): + _fact_class = HPUXHardware + _platform = 'HP-UX' + + required_facts = set(['platform', 'distribution']) diff --git a/lib/ansible/module_utils/facts/hardware/hurd.py b/lib/ansible/module_utils/facts/hardware/hurd.py new file mode 100644 index 0000000..306e13c --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/hurd.py @@ -0,0 +1,53 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.timeout import TimeoutError +from ansible.module_utils.facts.hardware.base import HardwareCollector +from ansible.module_utils.facts.hardware.linux import LinuxHardware + + +class HurdHardware(LinuxHardware): + """ + GNU Hurd specific subclass of Hardware. Define memory and mount facts + based on procfs compatibility translator mimicking the interface of + the Linux kernel. + """ + + platform = 'GNU' + + def populate(self, collected_facts=None): + hardware_facts = {} + uptime_facts = self.get_uptime_facts() + memory_facts = self.get_memory_facts() + + mount_facts = {} + try: + mount_facts = self.get_mount_facts() + except TimeoutError: + pass + + hardware_facts.update(uptime_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(mount_facts) + + return hardware_facts + + +class HurdHardwareCollector(HardwareCollector): + _fact_class = HurdHardware + _platform = 'GNU' diff --git a/lib/ansible/module_utils/facts/hardware/linux.py b/lib/ansible/module_utils/facts/hardware/linux.py new file mode 100644 index 0000000..c0ca33d --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/linux.py @@ -0,0 +1,869 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import collections +import errno +import glob +import json +import os +import re +import sys +import time + +from multiprocessing import cpu_count +from multiprocessing.pool import ThreadPool + +from ansible.module_utils._text import to_text +from ansible.module_utils.common.locale import get_best_parsable_locale +from ansible.module_utils.common.process import get_bin_path +from ansible.module_utils.common.text.formatters import bytes_to_human +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector +from ansible.module_utils.facts.utils import get_file_content, get_file_lines, get_mount_size +from ansible.module_utils.six import iteritems + +# import this as a module to ensure we get the same module instance +from ansible.module_utils.facts import timeout + + +def get_partition_uuid(partname): + try: + uuids = os.listdir("/dev/disk/by-uuid") + except OSError: + return + + for uuid in uuids: + dev = os.path.realpath("/dev/disk/by-uuid/" + uuid) + if dev == ("/dev/" + partname): + return uuid + + return None + + +class LinuxHardware(Hardware): + """ + Linux-specific subclass of Hardware. Defines memory and CPU facts: + - memfree_mb + - memtotal_mb + - swapfree_mb + - swaptotal_mb + - processor (a list) + - processor_cores + - processor_count + + In addition, it also defines number of DMI facts and device facts. + """ + + platform = 'Linux' + + # Originally only had these four as toplevelfacts + ORIGINAL_MEMORY_FACTS = frozenset(('MemTotal', 'SwapTotal', 'MemFree', 'SwapFree')) + # Now we have all of these in a dict structure + MEMORY_FACTS = ORIGINAL_MEMORY_FACTS.union(('Buffers', 'Cached', 'SwapCached')) + + # regex used against findmnt output to detect bind mounts + BIND_MOUNT_RE = re.compile(r'.*\]') + + # regex used against mtab content to find entries that are bind mounts + MTAB_BIND_MOUNT_RE = re.compile(r'.*bind.*"') + + # regex used for replacing octal escape sequences + OCTAL_ESCAPE_RE = re.compile(r'\\[0-9]{3}') + + def populate(self, collected_facts=None): + hardware_facts = {} + locale = get_best_parsable_locale(self.module) + self.module.run_command_environ_update = {'LANG': locale, 'LC_ALL': locale, 'LC_NUMERIC': locale} + + cpu_facts = self.get_cpu_facts(collected_facts=collected_facts) + memory_facts = self.get_memory_facts() + dmi_facts = self.get_dmi_facts() + device_facts = self.get_device_facts() + uptime_facts = self.get_uptime_facts() + lvm_facts = self.get_lvm_facts() + + mount_facts = {} + try: + mount_facts = self.get_mount_facts() + except timeout.TimeoutError: + self.module.warn("No mount facts were gathered due to timeout.") + + hardware_facts.update(cpu_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(dmi_facts) + hardware_facts.update(device_facts) + hardware_facts.update(uptime_facts) + hardware_facts.update(lvm_facts) + hardware_facts.update(mount_facts) + + return hardware_facts + + def get_memory_facts(self): + memory_facts = {} + if not os.access("/proc/meminfo", os.R_OK): + return memory_facts + + memstats = {} + for line in get_file_lines("/proc/meminfo"): + data = line.split(":", 1) + key = data[0] + if key in self.ORIGINAL_MEMORY_FACTS: + val = data[1].strip().split(' ')[0] + memory_facts["%s_mb" % key.lower()] = int(val) // 1024 + + if key in self.MEMORY_FACTS: + val = data[1].strip().split(' ')[0] + memstats[key.lower()] = int(val) // 1024 + + if None not in (memstats.get('memtotal'), memstats.get('memfree')): + memstats['real:used'] = memstats['memtotal'] - memstats['memfree'] + if None not in (memstats.get('cached'), memstats.get('memfree'), memstats.get('buffers')): + memstats['nocache:free'] = memstats['cached'] + memstats['memfree'] + memstats['buffers'] + if None not in (memstats.get('memtotal'), memstats.get('nocache:free')): + memstats['nocache:used'] = memstats['memtotal'] - memstats['nocache:free'] + if None not in (memstats.get('swaptotal'), memstats.get('swapfree')): + memstats['swap:used'] = memstats['swaptotal'] - memstats['swapfree'] + + memory_facts['memory_mb'] = { + 'real': { + 'total': memstats.get('memtotal'), + 'used': memstats.get('real:used'), + 'free': memstats.get('memfree'), + }, + 'nocache': { + 'free': memstats.get('nocache:free'), + 'used': memstats.get('nocache:used'), + }, + 'swap': { + 'total': memstats.get('swaptotal'), + 'free': memstats.get('swapfree'), + 'used': memstats.get('swap:used'), + 'cached': memstats.get('swapcached'), + }, + } + + return memory_facts + + def get_cpu_facts(self, collected_facts=None): + cpu_facts = {} + collected_facts = collected_facts or {} + + i = 0 + vendor_id_occurrence = 0 + model_name_occurrence = 0 + processor_occurrence = 0 + physid = 0 + coreid = 0 + sockets = {} + cores = {} + + xen = False + xen_paravirt = False + try: + if os.path.exists('/proc/xen'): + xen = True + else: + for line in get_file_lines('/sys/hypervisor/type'): + if line.strip() == 'xen': + xen = True + # Only interested in the first line + break + except IOError: + pass + + if not os.access("/proc/cpuinfo", os.R_OK): + return cpu_facts + + cpu_facts['processor'] = [] + for line in get_file_lines('/proc/cpuinfo'): + data = line.split(":", 1) + key = data[0].strip() + + try: + val = data[1].strip() + except IndexError: + val = "" + + if xen: + if key == 'flags': + # Check for vme cpu flag, Xen paravirt does not expose this. + # Need to detect Xen paravirt because it exposes cpuinfo + # differently than Xen HVM or KVM and causes reporting of + # only a single cpu core. + if 'vme' not in val: + xen_paravirt = True + + # model name is for Intel arch, Processor (mind the uppercase P) + # works for some ARM devices, like the Sheevaplug. + # 'ncpus active' is SPARC attribute + if key in ['model name', 'Processor', 'vendor_id', 'cpu', 'Vendor', 'processor']: + if 'processor' not in cpu_facts: + cpu_facts['processor'] = [] + cpu_facts['processor'].append(val) + if key == 'vendor_id': + vendor_id_occurrence += 1 + if key == 'model name': + model_name_occurrence += 1 + if key == 'processor': + processor_occurrence += 1 + i += 1 + elif key == 'physical id': + physid = val + if physid not in sockets: + sockets[physid] = 1 + elif key == 'core id': + coreid = val + if coreid not in sockets: + cores[coreid] = 1 + elif key == 'cpu cores': + sockets[physid] = int(val) + elif key == 'siblings': + cores[coreid] = int(val) + elif key == '# processors': + cpu_facts['processor_cores'] = int(val) + elif key == 'ncpus active': + i = int(val) + + # Skip for platforms without vendor_id/model_name in cpuinfo (e.g ppc64le) + if vendor_id_occurrence > 0: + if vendor_id_occurrence == model_name_occurrence: + i = vendor_id_occurrence + + # The fields for ARM CPUs do not always include 'vendor_id' or 'model name', + # and sometimes includes both 'processor' and 'Processor'. + # The fields for Power CPUs include 'processor' and 'cpu'. + # Always use 'processor' count for ARM and Power systems + if collected_facts.get('ansible_architecture', '').startswith(('armv', 'aarch', 'ppc')): + i = processor_occurrence + + # FIXME + if collected_facts.get('ansible_architecture') != 's390x': + if xen_paravirt: + cpu_facts['processor_count'] = i + cpu_facts['processor_cores'] = i + cpu_facts['processor_threads_per_core'] = 1 + cpu_facts['processor_vcpus'] = i + else: + if sockets: + cpu_facts['processor_count'] = len(sockets) + else: + cpu_facts['processor_count'] = i + + socket_values = list(sockets.values()) + if socket_values and socket_values[0]: + cpu_facts['processor_cores'] = socket_values[0] + else: + cpu_facts['processor_cores'] = 1 + + core_values = list(cores.values()) + if core_values: + cpu_facts['processor_threads_per_core'] = core_values[0] // cpu_facts['processor_cores'] + else: + cpu_facts['processor_threads_per_core'] = 1 // cpu_facts['processor_cores'] + + cpu_facts['processor_vcpus'] = (cpu_facts['processor_threads_per_core'] * + cpu_facts['processor_count'] * cpu_facts['processor_cores']) + + # if the number of processors available to the module's + # thread cannot be determined, the processor count + # reported by /proc will be the default: + cpu_facts['processor_nproc'] = processor_occurrence + + try: + cpu_facts['processor_nproc'] = len( + os.sched_getaffinity(0) + ) + except AttributeError: + # In Python < 3.3, os.sched_getaffinity() is not available + try: + cmd = get_bin_path('nproc') + except ValueError: + pass + else: + rc, out, _err = self.module.run_command(cmd) + if rc == 0: + cpu_facts['processor_nproc'] = int(out) + + return cpu_facts + + def get_dmi_facts(self): + ''' learn dmi facts from system + + Try /sys first for dmi related facts. + If that is not available, fall back to dmidecode executable ''' + + dmi_facts = {} + + if os.path.exists('/sys/devices/virtual/dmi/id/product_name'): + # Use kernel DMI info, if available + + # DMI SPEC -- https://www.dmtf.org/sites/default/files/standards/documents/DSP0134_3.2.0.pdf + FORM_FACTOR = ["Unknown", "Other", "Unknown", "Desktop", + "Low Profile Desktop", "Pizza Box", "Mini Tower", "Tower", + "Portable", "Laptop", "Notebook", "Hand Held", "Docking Station", + "All In One", "Sub Notebook", "Space-saving", "Lunch Box", + "Main Server Chassis", "Expansion Chassis", "Sub Chassis", + "Bus Expansion Chassis", "Peripheral Chassis", "RAID Chassis", + "Rack Mount Chassis", "Sealed-case PC", "Multi-system", + "CompactPCI", "AdvancedTCA", "Blade", "Blade Enclosure", + "Tablet", "Convertible", "Detachable", "IoT Gateway", + "Embedded PC", "Mini PC", "Stick PC"] + + DMI_DICT = { + 'bios_date': '/sys/devices/virtual/dmi/id/bios_date', + 'bios_vendor': '/sys/devices/virtual/dmi/id/bios_vendor', + 'bios_version': '/sys/devices/virtual/dmi/id/bios_version', + 'board_asset_tag': '/sys/devices/virtual/dmi/id/board_asset_tag', + 'board_name': '/sys/devices/virtual/dmi/id/board_name', + 'board_serial': '/sys/devices/virtual/dmi/id/board_serial', + 'board_vendor': '/sys/devices/virtual/dmi/id/board_vendor', + 'board_version': '/sys/devices/virtual/dmi/id/board_version', + 'chassis_asset_tag': '/sys/devices/virtual/dmi/id/chassis_asset_tag', + 'chassis_serial': '/sys/devices/virtual/dmi/id/chassis_serial', + 'chassis_vendor': '/sys/devices/virtual/dmi/id/chassis_vendor', + 'chassis_version': '/sys/devices/virtual/dmi/id/chassis_version', + 'form_factor': '/sys/devices/virtual/dmi/id/chassis_type', + 'product_name': '/sys/devices/virtual/dmi/id/product_name', + 'product_serial': '/sys/devices/virtual/dmi/id/product_serial', + 'product_uuid': '/sys/devices/virtual/dmi/id/product_uuid', + 'product_version': '/sys/devices/virtual/dmi/id/product_version', + 'system_vendor': '/sys/devices/virtual/dmi/id/sys_vendor', + } + + for (key, path) in DMI_DICT.items(): + data = get_file_content(path) + if data is not None: + if key == 'form_factor': + try: + dmi_facts['form_factor'] = FORM_FACTOR[int(data)] + except IndexError: + dmi_facts['form_factor'] = 'unknown (%s)' % data + else: + dmi_facts[key] = data + else: + dmi_facts[key] = 'NA' + + else: + # Fall back to using dmidecode, if available + dmi_bin = self.module.get_bin_path('dmidecode') + DMI_DICT = { + 'bios_date': 'bios-release-date', + 'bios_vendor': 'bios-vendor', + 'bios_version': 'bios-version', + 'board_asset_tag': 'baseboard-asset-tag', + 'board_name': 'baseboard-product-name', + 'board_serial': 'baseboard-serial-number', + 'board_vendor': 'baseboard-manufacturer', + 'board_version': 'baseboard-version', + 'chassis_asset_tag': 'chassis-asset-tag', + 'chassis_serial': 'chassis-serial-number', + 'chassis_vendor': 'chassis-manufacturer', + 'chassis_version': 'chassis-version', + 'form_factor': 'chassis-type', + 'product_name': 'system-product-name', + 'product_serial': 'system-serial-number', + 'product_uuid': 'system-uuid', + 'product_version': 'system-version', + 'system_vendor': 'system-manufacturer', + } + for (k, v) in DMI_DICT.items(): + if dmi_bin is not None: + (rc, out, err) = self.module.run_command('%s -s %s' % (dmi_bin, v)) + if rc == 0: + # Strip out commented lines (specific dmidecode output) + thisvalue = ''.join([line for line in out.splitlines() if not line.startswith('#')]) + try: + json.dumps(thisvalue) + except UnicodeDecodeError: + thisvalue = "NA" + + dmi_facts[k] = thisvalue + else: + dmi_facts[k] = 'NA' + else: + dmi_facts[k] = 'NA' + + return dmi_facts + + def _run_lsblk(self, lsblk_path): + # call lsblk and collect all uuids + # --exclude 2 makes lsblk ignore floppy disks, which are slower to answer than typical timeouts + # this uses the linux major device number + # for details see https://www.kernel.org/doc/Documentation/devices.txt + args = ['--list', '--noheadings', '--paths', '--output', 'NAME,UUID', '--exclude', '2'] + cmd = [lsblk_path] + args + rc, out, err = self.module.run_command(cmd) + return rc, out, err + + def _lsblk_uuid(self): + uuids = {} + lsblk_path = self.module.get_bin_path("lsblk") + if not lsblk_path: + return uuids + + rc, out, err = self._run_lsblk(lsblk_path) + if rc != 0: + return uuids + + # each line will be in format: + # <devicename><some whitespace><uuid> + # /dev/sda1 32caaec3-ef40-4691-a3b6-438c3f9bc1c0 + for lsblk_line in out.splitlines(): + if not lsblk_line: + continue + + line = lsblk_line.strip() + fields = line.rsplit(None, 1) + + if len(fields) < 2: + continue + + device_name, uuid = fields[0].strip(), fields[1].strip() + if device_name in uuids: + continue + uuids[device_name] = uuid + + return uuids + + def _udevadm_uuid(self, device): + # fallback for versions of lsblk <= 2.23 that don't have --paths, see _run_lsblk() above + uuid = 'N/A' + + udevadm_path = self.module.get_bin_path('udevadm') + if not udevadm_path: + return uuid + + cmd = [udevadm_path, 'info', '--query', 'property', '--name', device] + rc, out, err = self.module.run_command(cmd) + if rc != 0: + return uuid + + # a snippet of the output of the udevadm command below will be: + # ... + # ID_FS_TYPE=ext4 + # ID_FS_USAGE=filesystem + # ID_FS_UUID=57b1a3e7-9019-4747-9809-7ec52bba9179 + # ... + m = re.search('ID_FS_UUID=(.*)\n', out) + if m: + uuid = m.group(1) + + return uuid + + def _run_findmnt(self, findmnt_path): + args = ['--list', '--noheadings', '--notruncate'] + cmd = [findmnt_path] + args + rc, out, err = self.module.run_command(cmd, errors='surrogate_then_replace') + return rc, out, err + + def _find_bind_mounts(self): + bind_mounts = set() + findmnt_path = self.module.get_bin_path("findmnt") + if not findmnt_path: + return bind_mounts + + rc, out, err = self._run_findmnt(findmnt_path) + if rc != 0: + return bind_mounts + + # find bind mounts, in case /etc/mtab is a symlink to /proc/mounts + for line in out.splitlines(): + fields = line.split() + # fields[0] is the TARGET, fields[1] is the SOURCE + if len(fields) < 2: + continue + + # bind mounts will have a [/directory_name] in the SOURCE column + if self.BIND_MOUNT_RE.match(fields[1]): + bind_mounts.add(fields[0]) + + return bind_mounts + + def _mtab_entries(self): + mtab_file = '/etc/mtab' + if not os.path.exists(mtab_file): + mtab_file = '/proc/mounts' + + mtab = get_file_content(mtab_file, '') + mtab_entries = [] + for line in mtab.splitlines(): + fields = line.split() + if len(fields) < 4: + continue + mtab_entries.append(fields) + return mtab_entries + + @staticmethod + def _replace_octal_escapes_helper(match): + # Convert to integer using base8 and then convert to character + return chr(int(match.group()[1:], 8)) + + def _replace_octal_escapes(self, value): + return self.OCTAL_ESCAPE_RE.sub(self._replace_octal_escapes_helper, value) + + def get_mount_info(self, mount, device, uuids): + + mount_size = get_mount_size(mount) + + # _udevadm_uuid is a fallback for versions of lsblk <= 2.23 that don't have --paths + # see _run_lsblk() above + # https://github.com/ansible/ansible/issues/36077 + uuid = uuids.get(device, self._udevadm_uuid(device)) + + return mount_size, uuid + + def get_mount_facts(self): + + mounts = [] + + # gather system lists + bind_mounts = self._find_bind_mounts() + uuids = self._lsblk_uuid() + mtab_entries = self._mtab_entries() + + # start threads to query each mount + results = {} + pool = ThreadPool(processes=min(len(mtab_entries), cpu_count())) + maxtime = globals().get('GATHER_TIMEOUT') or timeout.DEFAULT_GATHER_TIMEOUT + for fields in mtab_entries: + # Transform octal escape sequences + fields = [self._replace_octal_escapes(field) for field in fields] + + device, mount, fstype, options = fields[0], fields[1], fields[2], fields[3] + + if not device.startswith(('/', '\\')) and ':/' not in device or fstype == 'none': + continue + + mount_info = {'mount': mount, + 'device': device, + 'fstype': fstype, + 'options': options} + + if mount in bind_mounts: + # only add if not already there, we might have a plain /etc/mtab + if not self.MTAB_BIND_MOUNT_RE.match(options): + mount_info['options'] += ",bind" + + results[mount] = {'info': mount_info, + 'extra': pool.apply_async(self.get_mount_info, (mount, device, uuids)), + 'timelimit': time.time() + maxtime} + + pool.close() # done with new workers, start gc + + # wait for workers and get results + while results: + for mount in list(results): + done = False + res = results[mount]['extra'] + try: + if res.ready(): + done = True + if res.successful(): + mount_size, uuid = res.get() + if mount_size: + results[mount]['info'].update(mount_size) + results[mount]['info']['uuid'] = uuid or 'N/A' + else: + # failed, try to find out why, if 'res.successful' we know there are no exceptions + results[mount]['info']['note'] = 'Could not get extra information: %s.' % (to_text(res.get())) + + elif time.time() > results[mount]['timelimit']: + done = True + self.module.warn("Timeout exceeded when getting mount info for %s" % mount) + results[mount]['info']['note'] = 'Could not get extra information due to timeout' + except Exception as e: + import traceback + done = True + results[mount]['info'] = 'N/A' + self.module.warn("Error prevented getting extra info for mount %s: [%s] %s." % (mount, type(e), to_text(e))) + self.module.debug(traceback.format_exc()) + + if done: + # move results outside and make loop only handle pending + mounts.append(results[mount]['info']) + del results[mount] + + # avoid cpu churn, sleep between retrying for loop with remaining mounts + time.sleep(0.1) + + return {'mounts': mounts} + + def get_device_links(self, link_dir): + if not os.path.exists(link_dir): + return {} + try: + retval = collections.defaultdict(set) + for entry in os.listdir(link_dir): + try: + target = os.path.basename(os.readlink(os.path.join(link_dir, entry))) + retval[target].add(entry) + except OSError: + continue + return dict((k, list(sorted(v))) for (k, v) in iteritems(retval)) + except OSError: + return {} + + def get_all_device_owners(self): + try: + retval = collections.defaultdict(set) + for path in glob.glob('/sys/block/*/slaves/*'): + elements = path.split('/') + device = elements[3] + target = elements[5] + retval[target].add(device) + return dict((k, list(sorted(v))) for (k, v) in iteritems(retval)) + except OSError: + return {} + + def get_all_device_links(self): + return { + 'ids': self.get_device_links('/dev/disk/by-id'), + 'uuids': self.get_device_links('/dev/disk/by-uuid'), + 'labels': self.get_device_links('/dev/disk/by-label'), + 'masters': self.get_all_device_owners(), + } + + def get_holders(self, block_dev_dict, sysdir): + block_dev_dict['holders'] = [] + if os.path.isdir(sysdir + "/holders"): + for folder in os.listdir(sysdir + "/holders"): + if not folder.startswith("dm-"): + continue + name = get_file_content(sysdir + "/holders/" + folder + "/dm/name") + if name: + block_dev_dict['holders'].append(name) + else: + block_dev_dict['holders'].append(folder) + + def _get_sg_inq_serial(self, sg_inq, block): + device = "/dev/%s" % (block) + rc, drivedata, err = self.module.run_command([sg_inq, device]) + if rc == 0: + serial = re.search(r"(?:Unit serial|Serial) number:\s+(\w+)", drivedata) + if serial: + return serial.group(1) + + def get_device_facts(self): + device_facts = {} + + device_facts['devices'] = {} + lspci = self.module.get_bin_path('lspci') + if lspci: + rc, pcidata, err = self.module.run_command([lspci, '-D'], errors='surrogate_then_replace') + else: + pcidata = None + + try: + block_devs = os.listdir("/sys/block") + except OSError: + return device_facts + + devs_wwn = {} + try: + devs_by_id = os.listdir("/dev/disk/by-id") + except OSError: + pass + else: + for link_name in devs_by_id: + if link_name.startswith("wwn-"): + try: + wwn_link = os.readlink(os.path.join("/dev/disk/by-id", link_name)) + except OSError: + continue + devs_wwn[os.path.basename(wwn_link)] = link_name[4:] + + links = self.get_all_device_links() + device_facts['device_links'] = links + + for block in block_devs: + virtual = 1 + sysfs_no_links = 0 + try: + path = os.readlink(os.path.join("/sys/block/", block)) + except OSError: + e = sys.exc_info()[1] + if e.errno == errno.EINVAL: + path = block + sysfs_no_links = 1 + else: + continue + sysdir = os.path.join("/sys/block", path) + if sysfs_no_links == 1: + for folder in os.listdir(sysdir): + if "device" in folder: + virtual = 0 + break + d = {} + d['virtual'] = virtual + d['links'] = {} + for (link_type, link_values) in iteritems(links): + d['links'][link_type] = link_values.get(block, []) + diskname = os.path.basename(sysdir) + for key in ['vendor', 'model', 'sas_address', 'sas_device_handle']: + d[key] = get_file_content(sysdir + "/device/" + key) + + sg_inq = self.module.get_bin_path('sg_inq') + + # we can get NVMe device's serial number from /sys/block/<name>/device/serial + serial_path = "/sys/block/%s/device/serial" % (block) + + if sg_inq: + serial = self._get_sg_inq_serial(sg_inq, block) + if serial: + d['serial'] = serial + else: + serial = get_file_content(serial_path) + if serial: + d['serial'] = serial + + for key, test in [('removable', '/removable'), + ('support_discard', '/queue/discard_granularity'), + ]: + d[key] = get_file_content(sysdir + test) + + if diskname in devs_wwn: + d['wwn'] = devs_wwn[diskname] + + d['partitions'] = {} + for folder in os.listdir(sysdir): + m = re.search("(" + diskname + r"[p]?\d+)", folder) + if m: + part = {} + partname = m.group(1) + part_sysdir = sysdir + "/" + partname + + part['links'] = {} + for (link_type, link_values) in iteritems(links): + part['links'][link_type] = link_values.get(partname, []) + + part['start'] = get_file_content(part_sysdir + "/start", 0) + part['sectors'] = get_file_content(part_sysdir + "/size", 0) + + part['sectorsize'] = get_file_content(part_sysdir + "/queue/logical_block_size") + if not part['sectorsize']: + part['sectorsize'] = get_file_content(part_sysdir + "/queue/hw_sector_size", 512) + part['size'] = bytes_to_human((float(part['sectors']) * 512.0)) + part['uuid'] = get_partition_uuid(partname) + self.get_holders(part, part_sysdir) + + d['partitions'][partname] = part + + d['rotational'] = get_file_content(sysdir + "/queue/rotational") + d['scheduler_mode'] = "" + scheduler = get_file_content(sysdir + "/queue/scheduler") + if scheduler is not None: + m = re.match(r".*?(\[(.*)\])", scheduler) + if m: + d['scheduler_mode'] = m.group(2) + + d['sectors'] = get_file_content(sysdir + "/size") + if not d['sectors']: + d['sectors'] = 0 + d['sectorsize'] = get_file_content(sysdir + "/queue/logical_block_size") + if not d['sectorsize']: + d['sectorsize'] = get_file_content(sysdir + "/queue/hw_sector_size", 512) + d['size'] = bytes_to_human(float(d['sectors']) * 512.0) + + d['host'] = "" + + # domains are numbered (0 to ffff), bus (0 to ff), slot (0 to 1f), and function (0 to 7). + m = re.match(r".+/([a-f0-9]{4}:[a-f0-9]{2}:[0|1][a-f0-9]\.[0-7])/", sysdir) + if m and pcidata: + pciid = m.group(1) + did = re.escape(pciid) + m = re.search("^" + did + r"\s(.*)$", pcidata, re.MULTILINE) + if m: + d['host'] = m.group(1) + + self.get_holders(d, sysdir) + + device_facts['devices'][diskname] = d + + return device_facts + + def get_uptime_facts(self): + uptime_facts = {} + uptime_file_content = get_file_content('/proc/uptime') + if uptime_file_content: + uptime_seconds_string = uptime_file_content.split(' ')[0] + uptime_facts['uptime_seconds'] = int(float(uptime_seconds_string)) + + return uptime_facts + + def _find_mapper_device_name(self, dm_device): + dm_prefix = '/dev/dm-' + mapper_device = dm_device + if dm_device.startswith(dm_prefix): + dmsetup_cmd = self.module.get_bin_path('dmsetup', True) + mapper_prefix = '/dev/mapper/' + rc, dm_name, err = self.module.run_command("%s info -C --noheadings -o name %s" % (dmsetup_cmd, dm_device)) + if rc == 0: + mapper_device = mapper_prefix + dm_name.rstrip() + return mapper_device + + def get_lvm_facts(self): + """ Get LVM Facts if running as root and lvm utils are available """ + + lvm_facts = {'lvm': 'N/A'} + + if os.getuid() == 0 and self.module.get_bin_path('vgs'): + lvm_util_options = '--noheadings --nosuffix --units g --separator ,' + + vgs_path = self.module.get_bin_path('vgs') + # vgs fields: VG #PV #LV #SN Attr VSize VFree + vgs = {} + if vgs_path: + rc, vg_lines, err = self.module.run_command('%s %s' % (vgs_path, lvm_util_options)) + for vg_line in vg_lines.splitlines(): + items = vg_line.strip().split(',') + vgs[items[0]] = {'size_g': items[-2], + 'free_g': items[-1], + 'num_lvs': items[2], + 'num_pvs': items[1]} + + lvs_path = self.module.get_bin_path('lvs') + # lvs fields: + # LV VG Attr LSize Pool Origin Data% Move Log Copy% Convert + lvs = {} + if lvs_path: + rc, lv_lines, err = self.module.run_command('%s %s' % (lvs_path, lvm_util_options)) + for lv_line in lv_lines.splitlines(): + items = lv_line.strip().split(',') + lvs[items[0]] = {'size_g': items[3], 'vg': items[1]} + + pvs_path = self.module.get_bin_path('pvs') + # pvs fields: PV VG #Fmt #Attr PSize PFree + pvs = {} + if pvs_path: + rc, pv_lines, err = self.module.run_command('%s %s' % (pvs_path, lvm_util_options)) + for pv_line in pv_lines.splitlines(): + items = pv_line.strip().split(',') + pvs[self._find_mapper_device_name(items[0])] = { + 'size_g': items[4], + 'free_g': items[5], + 'vg': items[1]} + + lvm_facts['lvm'] = {'lvs': lvs, 'vgs': vgs, 'pvs': pvs} + + return lvm_facts + + +class LinuxHardwareCollector(HardwareCollector): + _platform = 'Linux' + _fact_class = LinuxHardware + + required_facts = set(['platform']) diff --git a/lib/ansible/module_utils/facts/hardware/netbsd.py b/lib/ansible/module_utils/facts/hardware/netbsd.py new file mode 100644 index 0000000..c6557aa --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/netbsd.py @@ -0,0 +1,184 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re +import time + +from ansible.module_utils.six.moves import reduce + +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector +from ansible.module_utils.facts.timeout import TimeoutError, timeout + +from ansible.module_utils.facts.utils import get_file_content, get_file_lines, get_mount_size +from ansible.module_utils.facts.sysctl import get_sysctl + + +class NetBSDHardware(Hardware): + """ + NetBSD-specific subclass of Hardware. Defines memory and CPU facts: + - memfree_mb + - memtotal_mb + - swapfree_mb + - swaptotal_mb + - processor (a list) + - processor_cores + - processor_count + - devices + - uptime_seconds + """ + platform = 'NetBSD' + MEMORY_FACTS = ['MemTotal', 'SwapTotal', 'MemFree', 'SwapFree'] + + def populate(self, collected_facts=None): + hardware_facts = {} + self.sysctl = get_sysctl(self.module, ['machdep']) + cpu_facts = self.get_cpu_facts() + memory_facts = self.get_memory_facts() + + mount_facts = {} + try: + mount_facts = self.get_mount_facts() + except TimeoutError: + pass + + dmi_facts = self.get_dmi_facts() + uptime_facts = self.get_uptime_facts() + + hardware_facts.update(cpu_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(mount_facts) + hardware_facts.update(dmi_facts) + hardware_facts.update(uptime_facts) + + return hardware_facts + + def get_cpu_facts(self): + cpu_facts = {} + + i = 0 + physid = 0 + sockets = {} + if not os.access("/proc/cpuinfo", os.R_OK): + return cpu_facts + cpu_facts['processor'] = [] + for line in get_file_lines("/proc/cpuinfo"): + data = line.split(":", 1) + key = data[0].strip() + # model name is for Intel arch, Processor (mind the uppercase P) + # works for some ARM devices, like the Sheevaplug. + if key == 'model name' or key == 'Processor': + if 'processor' not in cpu_facts: + cpu_facts['processor'] = [] + cpu_facts['processor'].append(data[1].strip()) + i += 1 + elif key == 'physical id': + physid = data[1].strip() + if physid not in sockets: + sockets[physid] = 1 + elif key == 'cpu cores': + sockets[physid] = int(data[1].strip()) + if len(sockets) > 0: + cpu_facts['processor_count'] = len(sockets) + cpu_facts['processor_cores'] = reduce(lambda x, y: x + y, sockets.values()) + else: + cpu_facts['processor_count'] = i + cpu_facts['processor_cores'] = 'NA' + + return cpu_facts + + def get_memory_facts(self): + memory_facts = {} + if not os.access("/proc/meminfo", os.R_OK): + return memory_facts + for line in get_file_lines("/proc/meminfo"): + data = line.split(":", 1) + key = data[0] + if key in NetBSDHardware.MEMORY_FACTS: + val = data[1].strip().split(' ')[0] + memory_facts["%s_mb" % key.lower()] = int(val) // 1024 + + return memory_facts + + @timeout() + def get_mount_facts(self): + mount_facts = {} + + mount_facts['mounts'] = [] + fstab = get_file_content('/etc/fstab') + + if not fstab: + return mount_facts + + for line in fstab.splitlines(): + if line.startswith('#') or line.strip() == '': + continue + fields = re.sub(r'\s+', ' ', line).split() + mount_statvfs_info = get_mount_size(fields[1]) + mount_info = {'mount': fields[1], + 'device': fields[0], + 'fstype': fields[2], + 'options': fields[3]} + mount_info.update(mount_statvfs_info) + mount_facts['mounts'].append(mount_info) + return mount_facts + + def get_dmi_facts(self): + dmi_facts = {} + # We don't use dmidecode(8) here because: + # - it would add dependency on an external package + # - dmidecode(8) can only be ran as root + # So instead we rely on sysctl(8) to provide us the information on a + # best-effort basis. As a bonus we also get facts on non-amd64/i386 + # platforms this way. + sysctl_to_dmi = { + 'machdep.dmi.system-product': 'product_name', + 'machdep.dmi.system-version': 'product_version', + 'machdep.dmi.system-uuid': 'product_uuid', + 'machdep.dmi.system-serial': 'product_serial', + 'machdep.dmi.system-vendor': 'system_vendor', + } + + for mib in sysctl_to_dmi: + if mib in self.sysctl: + dmi_facts[sysctl_to_dmi[mib]] = self.sysctl[mib] + + return dmi_facts + + def get_uptime_facts(self): + # On NetBSD, we need to call sysctl with -n to get this value as an int. + sysctl_cmd = self.module.get_bin_path('sysctl') + cmd = [sysctl_cmd, '-n', 'kern.boottime'] + + rc, out, err = self.module.run_command(cmd) + + if rc != 0: + return {} + + kern_boottime = out.strip() + if not kern_boottime.isdigit(): + return {} + + return { + 'uptime_seconds': int(time.time() - int(kern_boottime)), + } + + +class NetBSDHardwareCollector(HardwareCollector): + _fact_class = NetBSDHardware + _platform = 'NetBSD' diff --git a/lib/ansible/module_utils/facts/hardware/openbsd.py b/lib/ansible/module_utils/facts/hardware/openbsd.py new file mode 100644 index 0000000..3bcf8ce --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/openbsd.py @@ -0,0 +1,184 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import time + +from ansible.module_utils._text import to_text + +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector +from ansible.module_utils.facts import timeout + +from ansible.module_utils.facts.utils import get_file_content, get_mount_size +from ansible.module_utils.facts.sysctl import get_sysctl + + +class OpenBSDHardware(Hardware): + """ + OpenBSD-specific subclass of Hardware. Defines memory, CPU and device facts: + - memfree_mb + - memtotal_mb + - swapfree_mb + - swaptotal_mb + - processor (a list) + - processor_cores + - processor_count + - processor_speed + - uptime_seconds + + In addition, it also defines number of DMI facts and device facts. + """ + platform = 'OpenBSD' + + def populate(self, collected_facts=None): + hardware_facts = {} + self.sysctl = get_sysctl(self.module, ['hw']) + + hardware_facts.update(self.get_processor_facts()) + hardware_facts.update(self.get_memory_facts()) + hardware_facts.update(self.get_device_facts()) + hardware_facts.update(self.get_dmi_facts()) + hardware_facts.update(self.get_uptime_facts()) + + # storage devices notorioslly prone to hang/block so they are under a timeout + try: + hardware_facts.update(self.get_mount_facts()) + except timeout.TimeoutError: + pass + + return hardware_facts + + @timeout.timeout() + def get_mount_facts(self): + mount_facts = {} + + mount_facts['mounts'] = [] + fstab = get_file_content('/etc/fstab') + if fstab: + for line in fstab.splitlines(): + if line.startswith('#') or line.strip() == '': + continue + fields = re.sub(r'\s+', ' ', line).split() + if fields[1] == 'none' or fields[3] == 'xx': + continue + mount_statvfs_info = get_mount_size(fields[1]) + mount_info = {'mount': fields[1], + 'device': fields[0], + 'fstype': fields[2], + 'options': fields[3]} + mount_info.update(mount_statvfs_info) + mount_facts['mounts'].append(mount_info) + return mount_facts + + def get_memory_facts(self): + memory_facts = {} + # Get free memory. vmstat output looks like: + # procs memory page disks traps cpu + # r b w avm fre flt re pi po fr sr wd0 fd0 int sys cs us sy id + # 0 0 0 47512 28160 51 0 0 0 0 0 1 0 116 89 17 0 1 99 + rc, out, err = self.module.run_command("/usr/bin/vmstat") + if rc == 0: + memory_facts['memfree_mb'] = int(out.splitlines()[-1].split()[4]) // 1024 + memory_facts['memtotal_mb'] = int(self.sysctl['hw.usermem']) // 1024 // 1024 + + # Get swapctl info. swapctl output looks like: + # total: 69268 1K-blocks allocated, 0 used, 69268 available + # And for older OpenBSD: + # total: 69268k bytes allocated = 0k used, 69268k available + rc, out, err = self.module.run_command("/sbin/swapctl -sk") + if rc == 0: + swaptrans = {ord(u'k'): None, + ord(u'm'): None, + ord(u'g'): None} + data = to_text(out, errors='surrogate_or_strict').split() + memory_facts['swapfree_mb'] = int(data[-2].translate(swaptrans)) // 1024 + memory_facts['swaptotal_mb'] = int(data[1].translate(swaptrans)) // 1024 + + return memory_facts + + def get_uptime_facts(self): + # On openbsd, we need to call it with -n to get this value as an int. + sysctl_cmd = self.module.get_bin_path('sysctl') + cmd = [sysctl_cmd, '-n', 'kern.boottime'] + + rc, out, err = self.module.run_command(cmd) + + if rc != 0: + return {} + + kern_boottime = out.strip() + if not kern_boottime.isdigit(): + return {} + + return { + 'uptime_seconds': int(time.time() - int(kern_boottime)), + } + + def get_processor_facts(self): + cpu_facts = {} + processor = [] + for i in range(int(self.sysctl['hw.ncpuonline'])): + processor.append(self.sysctl['hw.model']) + + cpu_facts['processor'] = processor + # The following is partly a lie because there is no reliable way to + # determine the number of physical CPUs in the system. We can only + # query the number of logical CPUs, which hides the number of cores. + # On amd64/i386 we could try to inspect the smt/core/package lines in + # dmesg, however even those have proven to be unreliable. + # So take a shortcut and report the logical number of processors in + # 'processor_count' and 'processor_cores' and leave it at that. + cpu_facts['processor_count'] = self.sysctl['hw.ncpuonline'] + cpu_facts['processor_cores'] = self.sysctl['hw.ncpuonline'] + + return cpu_facts + + def get_device_facts(self): + device_facts = {} + devices = [] + devices.extend(self.sysctl['hw.disknames'].split(',')) + device_facts['devices'] = devices + + return device_facts + + def get_dmi_facts(self): + dmi_facts = {} + # We don't use dmidecode(8) here because: + # - it would add dependency on an external package + # - dmidecode(8) can only be ran as root + # So instead we rely on sysctl(8) to provide us the information on a + # best-effort basis. As a bonus we also get facts on non-amd64/i386 + # platforms this way. + sysctl_to_dmi = { + 'hw.product': 'product_name', + 'hw.version': 'product_version', + 'hw.uuid': 'product_uuid', + 'hw.serialno': 'product_serial', + 'hw.vendor': 'system_vendor', + } + + for mib in sysctl_to_dmi: + if mib in self.sysctl: + dmi_facts[sysctl_to_dmi[mib]] = self.sysctl[mib] + + return dmi_facts + + +class OpenBSDHardwareCollector(HardwareCollector): + _fact_class = OpenBSDHardware + _platform = 'OpenBSD' diff --git a/lib/ansible/module_utils/facts/hardware/sunos.py b/lib/ansible/module_utils/facts/hardware/sunos.py new file mode 100644 index 0000000..0a77db0 --- /dev/null +++ b/lib/ansible/module_utils/facts/hardware/sunos.py @@ -0,0 +1,286 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import time + +from ansible.module_utils.common.locale import get_best_parsable_locale +from ansible.module_utils.common.text.formatters import bytes_to_human +from ansible.module_utils.facts.utils import get_file_content, get_mount_size +from ansible.module_utils.facts.hardware.base import Hardware, HardwareCollector +from ansible.module_utils.facts import timeout +from ansible.module_utils.six.moves import reduce + + +class SunOSHardware(Hardware): + """ + In addition to the generic memory and cpu facts, this also sets + swap_reserved_mb and swap_allocated_mb that is available from *swap -s*. + """ + platform = 'SunOS' + + def populate(self, collected_facts=None): + hardware_facts = {} + + # FIXME: could pass to run_command(environ_update), but it also tweaks the env + # of the parent process instead of altering an env provided to Popen() + # Use C locale for hardware collection helpers to avoid locale specific number formatting (#24542) + locale = get_best_parsable_locale(self.module) + self.module.run_command_environ_update = {'LANG': locale, 'LC_ALL': locale, 'LC_NUMERIC': locale} + + cpu_facts = self.get_cpu_facts() + memory_facts = self.get_memory_facts() + dmi_facts = self.get_dmi_facts() + device_facts = self.get_device_facts() + uptime_facts = self.get_uptime_facts() + + mount_facts = {} + try: + mount_facts = self.get_mount_facts() + except timeout.TimeoutError: + pass + + hardware_facts.update(cpu_facts) + hardware_facts.update(memory_facts) + hardware_facts.update(dmi_facts) + hardware_facts.update(device_facts) + hardware_facts.update(uptime_facts) + hardware_facts.update(mount_facts) + + return hardware_facts + + def get_cpu_facts(self, collected_facts=None): + physid = 0 + sockets = {} + + cpu_facts = {} + collected_facts = collected_facts or {} + + rc, out, err = self.module.run_command("/usr/bin/kstat cpu_info") + + cpu_facts['processor'] = [] + + for line in out.splitlines(): + if len(line) < 1: + continue + + data = line.split(None, 1) + key = data[0].strip() + + # "brand" works on Solaris 10 & 11. "implementation" for Solaris 9. + if key == 'module:': + brand = '' + elif key == 'brand': + brand = data[1].strip() + elif key == 'clock_MHz': + clock_mhz = data[1].strip() + elif key == 'implementation': + processor = brand or data[1].strip() + # Add clock speed to description for SPARC CPU + # FIXME + if collected_facts.get('ansible_machine') != 'i86pc': + processor += " @ " + clock_mhz + "MHz" + if 'ansible_processor' not in collected_facts: + cpu_facts['processor'] = [] + cpu_facts['processor'].append(processor) + elif key == 'chip_id': + physid = data[1].strip() + if physid not in sockets: + sockets[physid] = 1 + else: + sockets[physid] += 1 + + # Counting cores on Solaris can be complicated. + # https://blogs.oracle.com/mandalika/entry/solaris_show_me_the_cpu + # Treat 'processor_count' as physical sockets and 'processor_cores' as + # virtual CPUs visisble to Solaris. Not a true count of cores for modern SPARC as + # these processors have: sockets -> cores -> threads/virtual CPU. + if len(sockets) > 0: + cpu_facts['processor_count'] = len(sockets) + cpu_facts['processor_cores'] = reduce(lambda x, y: x + y, sockets.values()) + else: + cpu_facts['processor_cores'] = 'NA' + cpu_facts['processor_count'] = len(cpu_facts['processor']) + + return cpu_facts + + def get_memory_facts(self): + memory_facts = {} + + rc, out, err = self.module.run_command(["/usr/sbin/prtconf"]) + + for line in out.splitlines(): + if 'Memory size' in line: + memory_facts['memtotal_mb'] = int(line.split()[2]) + + rc, out, err = self.module.run_command("/usr/sbin/swap -s") + + allocated = int(out.split()[1][:-1]) + reserved = int(out.split()[5][:-1]) + used = int(out.split()[8][:-1]) + free = int(out.split()[10][:-1]) + + memory_facts['swapfree_mb'] = free // 1024 + memory_facts['swaptotal_mb'] = (free + used) // 1024 + memory_facts['swap_allocated_mb'] = allocated // 1024 + memory_facts['swap_reserved_mb'] = reserved // 1024 + + return memory_facts + + @timeout.timeout() + def get_mount_facts(self): + mount_facts = {} + mount_facts['mounts'] = [] + + # For a detailed format description see mnttab(4) + # special mount_point fstype options time + fstab = get_file_content('/etc/mnttab') + + if fstab: + for line in fstab.splitlines(): + fields = line.split('\t') + mount_statvfs_info = get_mount_size(fields[1]) + mount_info = {'mount': fields[1], + 'device': fields[0], + 'fstype': fields[2], + 'options': fields[3], + 'time': fields[4]} + mount_info.update(mount_statvfs_info) + mount_facts['mounts'].append(mount_info) + + return mount_facts + + def get_dmi_facts(self): + dmi_facts = {} + + # On Solaris 8 the prtdiag wrapper is absent from /usr/sbin, + # but that's okay, because we know where to find the real thing: + rc, platform, err = self.module.run_command('/usr/bin/uname -i') + platform_sbin = '/usr/platform/' + platform.rstrip() + '/sbin' + + prtdiag_path = self.module.get_bin_path("prtdiag", opt_dirs=[platform_sbin]) + rc, out, err = self.module.run_command(prtdiag_path) + """ + rc returns 1 + """ + if out: + system_conf = out.split('\n')[0] + + # If you know of any other manufacturers whose names appear in + # the first line of prtdiag's output, please add them here: + vendors = [ + "Fujitsu", + "Oracle Corporation", + "QEMU", + "Sun Microsystems", + "VMware, Inc.", + ] + vendor_regexp = "|".join(map(re.escape, vendors)) + system_conf_regexp = (r'System Configuration:\s+' + + r'(' + vendor_regexp + r')\s+' + + r'(?:sun\w+\s+)?' + + r'(.+)') + + found = re.match(system_conf_regexp, system_conf) + if found: + dmi_facts['system_vendor'] = found.group(1) + dmi_facts['product_name'] = found.group(2) + + return dmi_facts + + def get_device_facts(self): + # Device facts are derived for sdderr kstats. This code does not use the + # full output, but rather queries for specific stats. + # Example output: + # sderr:0:sd0,err:Hard Errors 0 + # sderr:0:sd0,err:Illegal Request 6 + # sderr:0:sd0,err:Media Error 0 + # sderr:0:sd0,err:Predictive Failure Analysis 0 + # sderr:0:sd0,err:Product VBOX HARDDISK 9 + # sderr:0:sd0,err:Revision 1.0 + # sderr:0:sd0,err:Serial No VB0ad2ec4d-074a + # sderr:0:sd0,err:Size 53687091200 + # sderr:0:sd0,err:Soft Errors 0 + # sderr:0:sd0,err:Transport Errors 0 + # sderr:0:sd0,err:Vendor ATA + + device_facts = {} + device_facts['devices'] = {} + + disk_stats = { + 'Product': 'product', + 'Revision': 'revision', + 'Serial No': 'serial', + 'Size': 'size', + 'Vendor': 'vendor', + 'Hard Errors': 'hard_errors', + 'Soft Errors': 'soft_errors', + 'Transport Errors': 'transport_errors', + 'Media Error': 'media_errors', + 'Predictive Failure Analysis': 'predictive_failure_analysis', + 'Illegal Request': 'illegal_request', + } + + cmd = ['/usr/bin/kstat', '-p'] + + for ds in disk_stats: + cmd.append('sderr:::%s' % ds) + + d = {} + rc, out, err = self.module.run_command(cmd) + if rc != 0: + return device_facts + + sd_instances = frozenset(line.split(':')[1] for line in out.split('\n') if line.startswith('sderr')) + for instance in sd_instances: + lines = (line for line in out.split('\n') if ':' in line and line.split(':')[1] == instance) + for line in lines: + text, value = line.split('\t') + stat = text.split(':')[3] + + if stat == 'Size': + d[disk_stats.get(stat)] = bytes_to_human(float(value)) + else: + d[disk_stats.get(stat)] = value.rstrip() + + diskname = 'sd' + instance + device_facts['devices'][diskname] = d + d = {} + + return device_facts + + def get_uptime_facts(self): + uptime_facts = {} + # sample kstat output: + # unix:0:system_misc:boot_time 1548249689 + rc, out, err = self.module.run_command('/usr/bin/kstat -p unix:0:system_misc:boot_time') + + if rc != 0: + return + + # uptime = $current_time - $boot_time + uptime_facts['uptime_seconds'] = int(time.time() - int(out.split('\t')[1])) + + return uptime_facts + + +class SunOSHardwareCollector(HardwareCollector): + _fact_class = SunOSHardware + _platform = 'SunOS' + + required_facts = set(['platform']) diff --git a/lib/ansible/module_utils/facts/namespace.py b/lib/ansible/module_utils/facts/namespace.py new file mode 100644 index 0000000..2d6bf8a --- /dev/null +++ b/lib/ansible/module_utils/facts/namespace.py @@ -0,0 +1,51 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +class FactNamespace: + def __init__(self, namespace_name): + self.namespace_name = namespace_name + + def transform(self, name): + '''Take a text name, and transforms it as needed (add a namespace prefix, etc)''' + return name + + def _underscore(self, name): + return name.replace('-', '_') + + +class PrefixFactNamespace(FactNamespace): + def __init__(self, namespace_name, prefix=None): + super(PrefixFactNamespace, self).__init__(namespace_name) + self.prefix = prefix + + def transform(self, name): + new_name = self._underscore(name) + return '%s%s' % (self.prefix, new_name) diff --git a/lib/ansible/module_utils/facts/network/__init__.py b/lib/ansible/module_utils/facts/network/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/__init__.py diff --git a/lib/ansible/module_utils/facts/network/aix.py b/lib/ansible/module_utils/facts/network/aix.py new file mode 100644 index 0000000..e9c90c6 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/aix.py @@ -0,0 +1,145 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.generic_bsd import GenericBsdIfconfigNetwork + + +class AIXNetwork(GenericBsdIfconfigNetwork): + """ + This is the AIX Network Class. + It uses the GenericBsdIfconfigNetwork unchanged. + """ + platform = 'AIX' + + def get_default_interfaces(self, route_path): + interface = dict(v4={}, v6={}) + + netstat_path = self.module.get_bin_path('netstat') + + if netstat_path: + rc, out, err = self.module.run_command([netstat_path, '-nr']) + + lines = out.splitlines() + for line in lines: + words = line.split() + if len(words) > 1 and words[0] == 'default': + if '.' in words[1]: + interface['v4']['gateway'] = words[1] + interface['v4']['interface'] = words[5] + elif ':' in words[1]: + interface['v6']['gateway'] = words[1] + interface['v6']['interface'] = words[5] + + return interface['v4'], interface['v6'] + + # AIX 'ifconfig -a' does not have three words in the interface line + def get_interfaces_info(self, ifconfig_path, ifconfig_options='-a'): + interfaces = {} + current_if = {} + ips = dict( + all_ipv4_addresses=[], + all_ipv6_addresses=[], + ) + + uname_rc = None + uname_out = None + uname_err = None + uname_path = self.module.get_bin_path('uname') + if uname_path: + uname_rc, uname_out, uname_err = self.module.run_command([uname_path, '-W']) + + rc, out, err = self.module.run_command([ifconfig_path, ifconfig_options]) + + for line in out.splitlines(): + + if line: + words = line.split() + + # only this condition differs from GenericBsdIfconfigNetwork + if re.match(r'^\w*\d*:', line): + current_if = self.parse_interface_line(words) + interfaces[current_if['device']] = current_if + elif words[0].startswith('options='): + self.parse_options_line(words, current_if, ips) + elif words[0] == 'nd6': + self.parse_nd6_line(words, current_if, ips) + elif words[0] == 'ether': + self.parse_ether_line(words, current_if, ips) + elif words[0] == 'media:': + self.parse_media_line(words, current_if, ips) + elif words[0] == 'status:': + self.parse_status_line(words, current_if, ips) + elif words[0] == 'lladdr': + self.parse_lladdr_line(words, current_if, ips) + elif words[0] == 'inet': + self.parse_inet_line(words, current_if, ips) + elif words[0] == 'inet6': + self.parse_inet6_line(words, current_if, ips) + else: + self.parse_unknown_line(words, current_if, ips) + + # don't bother with wpars it does not work + # zero means not in wpar + if not uname_rc and uname_out.split()[0] == '0': + + if current_if['macaddress'] == 'unknown' and re.match('^en', current_if['device']): + entstat_path = self.module.get_bin_path('entstat') + if entstat_path: + rc, out, err = self.module.run_command([entstat_path, current_if['device']]) + if rc != 0: + break + for line in out.splitlines(): + if not line: + pass + buff = re.match('^Hardware Address: (.*)', line) + if buff: + current_if['macaddress'] = buff.group(1) + + buff = re.match('^Device Type:', line) + if buff and re.match('.*Ethernet', line): + current_if['type'] = 'ether' + + # device must have mtu attribute in ODM + if 'mtu' not in current_if: + lsattr_path = self.module.get_bin_path('lsattr') + if lsattr_path: + rc, out, err = self.module.run_command([lsattr_path, '-El', current_if['device']]) + if rc != 0: + break + for line in out.splitlines(): + if line: + words = line.split() + if words[0] == 'mtu': + current_if['mtu'] = words[1] + return interfaces, ips + + # AIX 'ifconfig -a' does not inform about MTU, so remove current_if['mtu'] here + def parse_interface_line(self, words): + device = words[0][0:-1] + current_if = {'device': device, 'ipv4': [], 'ipv6': [], 'type': 'unknown'} + current_if['flags'] = self.get_options(words[1]) + current_if['macaddress'] = 'unknown' # will be overwritten later + return current_if + + +class AIXNetworkCollector(NetworkCollector): + _fact_class = AIXNetwork + _platform = 'AIX' diff --git a/lib/ansible/module_utils/facts/network/base.py b/lib/ansible/module_utils/facts/network/base.py new file mode 100644 index 0000000..8243f06 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/base.py @@ -0,0 +1,72 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class Network: + """ + This is a generic Network subclass of Facts. This should be further + subclassed to implement per platform. If you subclass this, + you must define: + - interfaces (a list of interface names) + - interface_<name> dictionary of ipv4, ipv6, and mac address information. + + All subclasses MUST define platform. + """ + platform = 'Generic' + + # FIXME: remove load_on_init when we can + def __init__(self, module, load_on_init=False): + self.module = module + + # TODO: more or less abstract/NotImplemented + def populate(self, collected_facts=None): + return {} + + +class NetworkCollector(BaseFactCollector): + # MAYBE: we could try to build this based on the arch specific implementation of Network() or its kin + name = 'network' + _fact_class = Network + _fact_ids = set(['interfaces', + 'default_ipv4', + 'default_ipv6', + 'all_ipv4_addresses', + 'all_ipv6_addresses']) # type: t.Set[str] + + IPV6_SCOPE = {'0': 'global', + '10': 'host', + '20': 'link', + '40': 'admin', + '50': 'site', + '80': 'organization'} + + def collect(self, module=None, collected_facts=None): + collected_facts = collected_facts or {} + if not module: + return {} + + # Network munges cached_facts by side effect, so give it a copy + facts_obj = self._fact_class(module) + + facts_dict = facts_obj.populate(collected_facts=collected_facts) + + return facts_dict diff --git a/lib/ansible/module_utils/facts/network/darwin.py b/lib/ansible/module_utils/facts/network/darwin.py new file mode 100644 index 0000000..90117e5 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/darwin.py @@ -0,0 +1,49 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.generic_bsd import GenericBsdIfconfigNetwork + + +class DarwinNetwork(GenericBsdIfconfigNetwork): + """ + This is the Mac macOS Darwin Network Class. + It uses the GenericBsdIfconfigNetwork unchanged + """ + platform = 'Darwin' + + # media line is different to the default FreeBSD one + def parse_media_line(self, words, current_if, ips): + # not sure if this is useful - we also drop information + current_if['media'] = 'Unknown' # Mac does not give us this + current_if['media_select'] = words[1] + if len(words) > 2: + # MacOSX sets the media to '<unknown type>' for bridge interface + # and parsing splits this into two words; this if/else helps + if words[1] == '<unknown' and words[2] == 'type>': + current_if['media_select'] = 'Unknown' + current_if['media_type'] = 'unknown type' + else: + current_if['media_type'] = words[2][1:-1] + if len(words) > 3: + current_if['media_options'] = self.get_options(words[3]) + + +class DarwinNetworkCollector(NetworkCollector): + _fact_class = DarwinNetwork + _platform = 'Darwin' diff --git a/lib/ansible/module_utils/facts/network/dragonfly.py b/lib/ansible/module_utils/facts/network/dragonfly.py new file mode 100644 index 0000000..e43bbb2 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/dragonfly.py @@ -0,0 +1,33 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.generic_bsd import GenericBsdIfconfigNetwork + + +class DragonFlyNetwork(GenericBsdIfconfigNetwork): + """ + This is the DragonFly Network Class. + It uses the GenericBsdIfconfigNetwork unchanged. + """ + platform = 'DragonFly' + + +class DragonFlyNetworkCollector(NetworkCollector): + _fact_class = DragonFlyNetwork + _platform = 'DragonFly' diff --git a/lib/ansible/module_utils/facts/network/fc_wwn.py b/lib/ansible/module_utils/facts/network/fc_wwn.py new file mode 100644 index 0000000..86182f8 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/fc_wwn.py @@ -0,0 +1,111 @@ +# Fibre Channel WWN initiator related facts collection for ansible. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import sys +import glob + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_lines +from ansible.module_utils.facts.collector import BaseFactCollector + + +class FcWwnInitiatorFactCollector(BaseFactCollector): + name = 'fibre_channel_wwn' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + """ + Example contents /sys/class/fc_host/*/port_name: + + 0x21000014ff52a9bb + + """ + + fc_facts = {} + fc_facts['fibre_channel_wwn'] = [] + if sys.platform.startswith('linux'): + for fcfile in glob.glob('/sys/class/fc_host/*/port_name'): + for line in get_file_lines(fcfile): + fc_facts['fibre_channel_wwn'].append(line.rstrip()[2:]) + elif sys.platform.startswith('sunos'): + """ + on solaris 10 or solaris 11 should use `fcinfo hba-port` + TBD (not implemented): on solaris 9 use `prtconf -pv` + """ + cmd = module.get_bin_path('fcinfo') + if cmd: + cmd = cmd + " hba-port" + rc, fcinfo_out, err = module.run_command(cmd) + """ + # fcinfo hba-port | grep "Port WWN" + HBA Port WWN: 10000090fa1658de + """ + if rc == 0 and fcinfo_out: + for line in fcinfo_out.splitlines(): + if 'Port WWN' in line: + data = line.split(' ') + fc_facts['fibre_channel_wwn'].append(data[-1].rstrip()) + elif sys.platform.startswith('aix'): + cmd = module.get_bin_path('lsdev') + lscfg_cmd = module.get_bin_path('lscfg') + if cmd and lscfg_cmd: + # get list of available fibre-channel devices (fcs) + cmd = cmd + " -Cc adapter -l fcs*" + rc, lsdev_out, err = module.run_command(cmd) + if rc == 0 and lsdev_out: + for line in lsdev_out.splitlines(): + # if device is available (not in defined state), get its WWN + if 'Available' in line: + data = line.split(' ') + cmd = lscfg_cmd + " -vl %s" % data[0] + rc, lscfg_out, err = module.run_command(cmd) + # example output + # lscfg -vpl fcs3 | grep "Network Address" + # Network Address.............10000090FA551509 + if rc == 0 and lscfg_out: + for line in lscfg_out.splitlines(): + if 'Network Address' in line: + data = line.split('.') + fc_facts['fibre_channel_wwn'].append(data[-1].rstrip()) + elif sys.platform.startswith('hp-ux'): + cmd = module.get_bin_path('ioscan') + fcmsu_cmd = module.get_bin_path('fcmsutil', opt_dirs=['/opt/fcms/bin']) + # go ahead if we have both commands available + if cmd and fcmsu_cmd: + # ioscan / get list of available fibre-channel devices (fcd) + cmd = cmd + " -fnC FC" + rc, ioscan_out, err = module.run_command(cmd) + if rc == 0 and ioscan_out: + for line in ioscan_out.splitlines(): + line = line.strip() + if '/dev/fcd' in line: + dev = line.split(' ') + # get device information + cmd = fcmsu_cmd + " %s" % dev[0] + rc, fcmsutil_out, err = module.run_command(cmd) + # lookup the following line + # N_Port Port World Wide Name = 0x50060b00006975ec + if rc == 0 and fcmsutil_out: + for line in fcmsutil_out.splitlines(): + if 'N_Port Port World Wide Name' in line: + data = line.split('=') + fc_facts['fibre_channel_wwn'].append(data[-1].strip()) + return fc_facts diff --git a/lib/ansible/module_utils/facts/network/freebsd.py b/lib/ansible/module_utils/facts/network/freebsd.py new file mode 100644 index 0000000..36f6eec --- /dev/null +++ b/lib/ansible/module_utils/facts/network/freebsd.py @@ -0,0 +1,33 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.generic_bsd import GenericBsdIfconfigNetwork + + +class FreeBSDNetwork(GenericBsdIfconfigNetwork): + """ + This is the FreeBSD Network Class. + It uses the GenericBsdIfconfigNetwork unchanged. + """ + platform = 'FreeBSD' + + +class FreeBSDNetworkCollector(NetworkCollector): + _fact_class = FreeBSDNetwork + _platform = 'FreeBSD' diff --git a/lib/ansible/module_utils/facts/network/generic_bsd.py b/lib/ansible/module_utils/facts/network/generic_bsd.py new file mode 100644 index 0000000..8d640f2 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/generic_bsd.py @@ -0,0 +1,321 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import socket +import struct + +from ansible.module_utils.facts.network.base import Network + + +class GenericBsdIfconfigNetwork(Network): + """ + This is a generic BSD subclass of Network using the ifconfig command. + It defines + - interfaces (a list of interface names) + - interface_<name> dictionary of ipv4, ipv6, and mac address information. + - all_ipv4_addresses and all_ipv6_addresses: lists of all configured addresses. + """ + platform = 'Generic_BSD_Ifconfig' + + def populate(self, collected_facts=None): + network_facts = {} + ifconfig_path = self.module.get_bin_path('ifconfig') + + if ifconfig_path is None: + return network_facts + + route_path = self.module.get_bin_path('route') + + if route_path is None: + return network_facts + + default_ipv4, default_ipv6 = self.get_default_interfaces(route_path) + interfaces, ips = self.get_interfaces_info(ifconfig_path) + interfaces = self.detect_type_media(interfaces) + + self.merge_default_interface(default_ipv4, interfaces, 'ipv4') + self.merge_default_interface(default_ipv6, interfaces, 'ipv6') + network_facts['interfaces'] = sorted(list(interfaces.keys())) + + for iface in interfaces: + network_facts[iface] = interfaces[iface] + + network_facts['default_ipv4'] = default_ipv4 + network_facts['default_ipv6'] = default_ipv6 + network_facts['all_ipv4_addresses'] = ips['all_ipv4_addresses'] + network_facts['all_ipv6_addresses'] = ips['all_ipv6_addresses'] + + return network_facts + + def detect_type_media(self, interfaces): + for iface in interfaces: + if 'media' in interfaces[iface]: + if 'ether' in interfaces[iface]['media'].lower(): + interfaces[iface]['type'] = 'ether' + return interfaces + + def get_default_interfaces(self, route_path): + + # Use the commands: + # route -n get default + # route -n get -inet6 default + # to find out the default outgoing interface, address, and gateway + + command = dict(v4=[route_path, '-n', 'get', 'default'], + v6=[route_path, '-n', 'get', '-inet6', 'default']) + + interface = dict(v4={}, v6={}) + + for v in 'v4', 'v6': + + if v == 'v6' and not socket.has_ipv6: + continue + rc, out, err = self.module.run_command(command[v]) + if not out: + # v6 routing may result in + # RTNETLINK answers: Invalid argument + continue + for line in out.splitlines(): + words = line.strip().split(': ') + # Collect output from route command + if len(words) > 1: + if words[0] == 'interface': + interface[v]['interface'] = words[1] + if words[0] == 'gateway': + interface[v]['gateway'] = words[1] + # help pick the right interface address on OpenBSD + if words[0] == 'if address': + interface[v]['address'] = words[1] + # help pick the right interface address on NetBSD + if words[0] == 'local addr': + interface[v]['address'] = words[1] + + return interface['v4'], interface['v6'] + + def get_interfaces_info(self, ifconfig_path, ifconfig_options='-a'): + interfaces = {} + current_if = {} + ips = dict( + all_ipv4_addresses=[], + all_ipv6_addresses=[], + ) + # FreeBSD, DragonflyBSD, NetBSD, OpenBSD and macOS all implicitly add '-a' + # when running the command 'ifconfig'. + # Solaris must explicitly run the command 'ifconfig -a'. + rc, out, err = self.module.run_command([ifconfig_path, ifconfig_options]) + + for line in out.splitlines(): + + if line: + words = line.split() + + if words[0] == 'pass': + continue + elif re.match(r'^\S', line) and len(words) > 3: + current_if = self.parse_interface_line(words) + interfaces[current_if['device']] = current_if + elif words[0].startswith('options='): + self.parse_options_line(words, current_if, ips) + elif words[0] == 'nd6': + self.parse_nd6_line(words, current_if, ips) + elif words[0] == 'ether': + self.parse_ether_line(words, current_if, ips) + elif words[0] == 'media:': + self.parse_media_line(words, current_if, ips) + elif words[0] == 'status:': + self.parse_status_line(words, current_if, ips) + elif words[0] == 'lladdr': + self.parse_lladdr_line(words, current_if, ips) + elif words[0] == 'inet': + self.parse_inet_line(words, current_if, ips) + elif words[0] == 'inet6': + self.parse_inet6_line(words, current_if, ips) + elif words[0] == 'tunnel': + self.parse_tunnel_line(words, current_if, ips) + else: + self.parse_unknown_line(words, current_if, ips) + + return interfaces, ips + + def parse_interface_line(self, words): + device = words[0][0:-1] + current_if = {'device': device, 'ipv4': [], 'ipv6': [], 'type': 'unknown'} + current_if['flags'] = self.get_options(words[1]) + if 'LOOPBACK' in current_if['flags']: + current_if['type'] = 'loopback' + current_if['macaddress'] = 'unknown' # will be overwritten later + + if len(words) >= 5: # Newer FreeBSD versions + current_if['metric'] = words[3] + current_if['mtu'] = words[5] + else: + current_if['mtu'] = words[3] + + return current_if + + def parse_options_line(self, words, current_if, ips): + # Mac has options like this... + current_if['options'] = self.get_options(words[0]) + + def parse_nd6_line(self, words, current_if, ips): + # FreeBSD has options like this... + current_if['options'] = self.get_options(words[1]) + + def parse_ether_line(self, words, current_if, ips): + current_if['macaddress'] = words[1] + current_if['type'] = 'ether' + + def parse_media_line(self, words, current_if, ips): + # not sure if this is useful - we also drop information + current_if['media'] = words[1] + if len(words) > 2: + current_if['media_select'] = words[2] + if len(words) > 3: + current_if['media_type'] = words[3][1:] + if len(words) > 4: + current_if['media_options'] = self.get_options(words[4]) + + def parse_status_line(self, words, current_if, ips): + current_if['status'] = words[1] + + def parse_lladdr_line(self, words, current_if, ips): + current_if['lladdr'] = words[1] + + def parse_inet_line(self, words, current_if, ips): + # netbsd show aliases like this + # lo0: flags=8049<UP,LOOPBACK,RUNNING,MULTICAST> mtu 33184 + # inet 127.0.0.1 netmask 0xff000000 + # inet alias 127.1.1.1 netmask 0xff000000 + if words[1] == 'alias': + del words[1] + + address = {'address': words[1]} + # cidr style ip address (eg, 127.0.0.1/24) in inet line + # used in netbsd ifconfig -e output after 7.1 + if '/' in address['address']: + ip_address, cidr_mask = address['address'].split('/') + + address['address'] = ip_address + + netmask_length = int(cidr_mask) + netmask_bin = (1 << 32) - (1 << 32 >> int(netmask_length)) + address['netmask'] = socket.inet_ntoa(struct.pack('!L', netmask_bin)) + + if len(words) > 5: + address['broadcast'] = words[3] + + else: + # Don't just assume columns, use "netmask" as the index for the prior column + try: + netmask_idx = words.index('netmask') + 1 + except ValueError: + netmask_idx = 3 + + # deal with hex netmask + if re.match('([0-9a-f]){8}$', words[netmask_idx]): + netmask = '0x' + words[netmask_idx] + else: + netmask = words[netmask_idx] + + if netmask.startswith('0x'): + address['netmask'] = socket.inet_ntoa(struct.pack('!L', int(netmask, base=16))) + else: + # otherwise assume this is a dotted quad + address['netmask'] = netmask + # calculate the network + address_bin = struct.unpack('!L', socket.inet_aton(address['address']))[0] + netmask_bin = struct.unpack('!L', socket.inet_aton(address['netmask']))[0] + address['network'] = socket.inet_ntoa(struct.pack('!L', address_bin & netmask_bin)) + if 'broadcast' not in address: + # broadcast may be given or we need to calculate + try: + broadcast_idx = words.index('broadcast') + 1 + except ValueError: + address['broadcast'] = socket.inet_ntoa(struct.pack('!L', address_bin | (~netmask_bin & 0xffffffff))) + else: + address['broadcast'] = words[broadcast_idx] + + # add to our list of addresses + if not words[1].startswith('127.'): + ips['all_ipv4_addresses'].append(address['address']) + current_if['ipv4'].append(address) + + def parse_inet6_line(self, words, current_if, ips): + address = {'address': words[1]} + + # using cidr style addresses, ala NetBSD ifconfig post 7.1 + if '/' in address['address']: + ip_address, cidr_mask = address['address'].split('/') + + address['address'] = ip_address + address['prefix'] = cidr_mask + + if len(words) > 5: + address['scope'] = words[5] + else: + if (len(words) >= 4) and (words[2] == 'prefixlen'): + address['prefix'] = words[3] + if (len(words) >= 6) and (words[4] == 'scopeid'): + address['scope'] = words[5] + + localhost6 = ['::1', '::1/128', 'fe80::1%lo0'] + if address['address'] not in localhost6: + ips['all_ipv6_addresses'].append(address['address']) + current_if['ipv6'].append(address) + + def parse_tunnel_line(self, words, current_if, ips): + current_if['type'] = 'tunnel' + + def parse_unknown_line(self, words, current_if, ips): + # we are going to ignore unknown lines here - this may be + # a bad idea - but you can override it in your subclass + pass + + # TODO: these are module scope static function candidates + # (most of the class is really...) + def get_options(self, option_string): + start = option_string.find('<') + 1 + end = option_string.rfind('>') + if (start > 0) and (end > 0) and (end > start + 1): + option_csv = option_string[start:end] + return option_csv.split(',') + else: + return [] + + def merge_default_interface(self, defaults, interfaces, ip_type): + if 'interface' not in defaults: + return + if not defaults['interface'] in interfaces: + return + ifinfo = interfaces[defaults['interface']] + # copy all the interface values across except addresses + for item in ifinfo: + if item != 'ipv4' and item != 'ipv6': + defaults[item] = ifinfo[item] + + ipinfo = [] + if 'address' in defaults: + ipinfo = [x for x in ifinfo[ip_type] if x['address'] == defaults['address']] + + if len(ipinfo) == 0: + ipinfo = ifinfo[ip_type] + + if len(ipinfo) > 0: + for item in ipinfo[0]: + defaults[item] = ipinfo[0][item] diff --git a/lib/ansible/module_utils/facts/network/hpux.py b/lib/ansible/module_utils/facts/network/hpux.py new file mode 100644 index 0000000..add57be --- /dev/null +++ b/lib/ansible/module_utils/facts/network/hpux.py @@ -0,0 +1,82 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.network.base import Network, NetworkCollector + + +class HPUXNetwork(Network): + """ + HP-UX-specifig subclass of Network. Defines networking facts: + - default_interface + - interfaces (a list of interface names) + - interface_<name> dictionary of ipv4 address information. + """ + platform = 'HP-UX' + + def populate(self, collected_facts=None): + network_facts = {} + netstat_path = self.module.get_bin_path('netstat') + + if netstat_path is None: + return network_facts + + default_interfaces_facts = self.get_default_interfaces() + network_facts.update(default_interfaces_facts) + + interfaces = self.get_interfaces_info() + network_facts['interfaces'] = interfaces.keys() + for iface in interfaces: + network_facts[iface] = interfaces[iface] + + return network_facts + + def get_default_interfaces(self): + default_interfaces = {} + rc, out, err = self.module.run_command("/usr/bin/netstat -nr") + lines = out.splitlines() + for line in lines: + words = line.split() + if len(words) > 1: + if words[0] == 'default': + default_interfaces['default_interface'] = words[4] + default_interfaces['default_gateway'] = words[1] + + return default_interfaces + + def get_interfaces_info(self): + interfaces = {} + rc, out, err = self.module.run_command("/usr/bin/netstat -niw") + lines = out.splitlines() + for line in lines: + words = line.split() + for i in range(len(words) - 1): + if words[i][:3] == 'lan': + device = words[i] + interfaces[device] = {'device': device} + address = words[i + 3] + interfaces[device]['ipv4'] = {'address': address} + network = words[i + 2] + interfaces[device]['ipv4'] = {'network': network, + 'interface': device, + 'address': address} + return interfaces + + +class HPUXNetworkCollector(NetworkCollector): + _fact_class = HPUXNetwork + _platform = 'HP-UX' diff --git a/lib/ansible/module_utils/facts/network/hurd.py b/lib/ansible/module_utils/facts/network/hurd.py new file mode 100644 index 0000000..518df39 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/hurd.py @@ -0,0 +1,87 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from ansible.module_utils.facts.network.base import Network, NetworkCollector + + +class HurdPfinetNetwork(Network): + """ + This is a GNU Hurd specific subclass of Network. It use fsysopts to + get the ip address and support only pfinet. + """ + platform = 'GNU' + _socket_dir = '/servers/socket/' + + def assign_network_facts(self, network_facts, fsysopts_path, socket_path): + rc, out, err = self.module.run_command([fsysopts_path, '-L', socket_path]) + # FIXME: build up a interfaces datastructure, then assign into network_facts + network_facts['interfaces'] = [] + for i in out.split(): + if '=' in i and i.startswith('--'): + k, v = i.split('=', 1) + # remove '--' + k = k[2:] + if k == 'interface': + # remove /dev/ from /dev/eth0 + v = v[5:] + network_facts['interfaces'].append(v) + network_facts[v] = { + 'active': True, + 'device': v, + 'ipv4': {}, + 'ipv6': [], + } + current_if = v + elif k == 'address': + network_facts[current_if]['ipv4']['address'] = v + elif k == 'netmask': + network_facts[current_if]['ipv4']['netmask'] = v + elif k == 'address6': + address, prefix = v.split('/') + network_facts[current_if]['ipv6'].append({ + 'address': address, + 'prefix': prefix, + }) + return network_facts + + def populate(self, collected_facts=None): + network_facts = {} + + fsysopts_path = self.module.get_bin_path('fsysopts') + if fsysopts_path is None: + return network_facts + + socket_path = None + + for l in ('inet', 'inet6'): + link = os.path.join(self._socket_dir, l) + if os.path.exists(link): + socket_path = link + break + + if socket_path is None: + return network_facts + + return self.assign_network_facts(network_facts, fsysopts_path, socket_path) + + +class HurdNetworkCollector(NetworkCollector): + _platform = 'GNU' + _fact_class = HurdPfinetNetwork diff --git a/lib/ansible/module_utils/facts/network/iscsi.py b/lib/ansible/module_utils/facts/network/iscsi.py new file mode 100644 index 0000000..2bb9383 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/iscsi.py @@ -0,0 +1,115 @@ +# iSCSI initiator related facts collection for Ansible. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import sys +import subprocess + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.common.process import get_bin_path +from ansible.module_utils.facts.utils import get_file_content +from ansible.module_utils.facts.network.base import NetworkCollector + + +class IscsiInitiatorNetworkCollector(NetworkCollector): + name = 'iscsi' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + """ + Example of contents of /etc/iscsi/initiatorname.iscsi: + + ## DO NOT EDIT OR REMOVE THIS FILE! + ## If you remove this file, the iSCSI daemon will not start. + ## If you change the InitiatorName, existing access control lists + ## may reject this initiator. The InitiatorName must be unique + ## for each iSCSI initiator. Do NOT duplicate iSCSI InitiatorNames. + InitiatorName=iqn.1993-08.org.debian:01:44a42c8ddb8b + + Example of output from the AIX lsattr command: + + # lsattr -E -l iscsi0 + disc_filename /etc/iscsi/targets Configuration file False + disc_policy file Discovery Policy True + initiator_name iqn.localhost.hostid.7f000002 iSCSI Initiator Name True + isns_srvnames auto iSNS Servers IP Addresses True + isns_srvports iSNS Servers Port Numbers True + max_targets 16 Maximum Targets Allowed True + num_cmd_elems 200 Maximum number of commands to queue to driver True + + Example of output from the HP-UX iscsiutil command: + + #iscsiutil -l + Initiator Name : iqn.1986-03.com.hp:mcel_VMhost3.1f355cf6-e2db-11e0-a999-b44c0aef5537 + Initiator Alias : + + Authentication Method : None + CHAP Method : CHAP_UNI + Initiator CHAP Name : + CHAP Secret : + NAS Hostname : + NAS Secret : + Radius Server Hostname : + Header Digest : None, CRC32C (default) + Data Digest : None, CRC32C (default) + SLP Scope list for iSLPD : + """ + + iscsi_facts = {} + iscsi_facts['iscsi_iqn'] = "" + if sys.platform.startswith('linux') or sys.platform.startswith('sunos'): + for line in get_file_content('/etc/iscsi/initiatorname.iscsi', '').splitlines(): + if line.startswith('#') or line.startswith(';') or line.strip() == '': + continue + if line.startswith('InitiatorName='): + iscsi_facts['iscsi_iqn'] = line.split('=', 1)[1] + break + elif sys.platform.startswith('aix'): + try: + cmd = get_bin_path('lsattr') + except ValueError: + return iscsi_facts + + cmd += " -E -l iscsi0" + rc, out, err = module.run_command(cmd) + if rc == 0 and out: + line = self.findstr(out, 'initiator_name') + iscsi_facts['iscsi_iqn'] = line.split()[1].rstrip() + + elif sys.platform.startswith('hp-ux'): + # try to find it in the default PATH and opt_dirs + try: + cmd = get_bin_path('iscsiutil', opt_dirs=['/opt/iscsi/bin']) + except ValueError: + return iscsi_facts + + cmd += " -l" + rc, out, err = module.run_command(cmd) + if out: + line = self.findstr(out, 'Initiator Name') + iscsi_facts['iscsi_iqn'] = line.split(":", 1)[1].rstrip() + + return iscsi_facts + + def findstr(self, text, match): + for line in text.splitlines(): + if match in line: + found = line + return found diff --git a/lib/ansible/module_utils/facts/network/linux.py b/lib/ansible/module_utils/facts/network/linux.py new file mode 100644 index 0000000..b7ae976 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/linux.py @@ -0,0 +1,327 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import glob +import os +import re +import socket +import struct + +from ansible.module_utils.facts.network.base import Network, NetworkCollector + +from ansible.module_utils.facts.utils import get_file_content + + +class LinuxNetwork(Network): + """ + This is a Linux-specific subclass of Network. It defines + - interfaces (a list of interface names) + - interface_<name> dictionary of ipv4, ipv6, and mac address information. + - all_ipv4_addresses and all_ipv6_addresses: lists of all configured addresses. + - ipv4_address and ipv6_address: the first non-local address for each family. + """ + platform = 'Linux' + INTERFACE_TYPE = { + '1': 'ether', + '32': 'infiniband', + '512': 'ppp', + '772': 'loopback', + '65534': 'tunnel', + } + + def populate(self, collected_facts=None): + network_facts = {} + ip_path = self.module.get_bin_path('ip') + if ip_path is None: + return network_facts + default_ipv4, default_ipv6 = self.get_default_interfaces(ip_path, + collected_facts=collected_facts) + interfaces, ips = self.get_interfaces_info(ip_path, default_ipv4, default_ipv6) + network_facts['interfaces'] = interfaces.keys() + for iface in interfaces: + network_facts[iface] = interfaces[iface] + network_facts['default_ipv4'] = default_ipv4 + network_facts['default_ipv6'] = default_ipv6 + network_facts['all_ipv4_addresses'] = ips['all_ipv4_addresses'] + network_facts['all_ipv6_addresses'] = ips['all_ipv6_addresses'] + return network_facts + + def get_default_interfaces(self, ip_path, collected_facts=None): + collected_facts = collected_facts or {} + # Use the commands: + # ip -4 route get 8.8.8.8 -> Google public DNS + # ip -6 route get 2404:6800:400a:800::1012 -> ipv6.google.com + # to find out the default outgoing interface, address, and gateway + command = dict( + v4=[ip_path, '-4', 'route', 'get', '8.8.8.8'], + v6=[ip_path, '-6', 'route', 'get', '2404:6800:400a:800::1012'] + ) + interface = dict(v4={}, v6={}) + + for v in 'v4', 'v6': + if (v == 'v6' and collected_facts.get('ansible_os_family') == 'RedHat' and + collected_facts.get('ansible_distribution_version', '').startswith('4.')): + continue + if v == 'v6' and not socket.has_ipv6: + continue + rc, out, err = self.module.run_command(command[v], errors='surrogate_then_replace') + if not out: + # v6 routing may result in + # RTNETLINK answers: Invalid argument + continue + words = out.splitlines()[0].split() + # A valid output starts with the queried address on the first line + if len(words) > 0 and words[0] == command[v][-1]: + for i in range(len(words) - 1): + if words[i] == 'dev': + interface[v]['interface'] = words[i + 1] + elif words[i] == 'src': + interface[v]['address'] = words[i + 1] + elif words[i] == 'via' and words[i + 1] != command[v][-1]: + interface[v]['gateway'] = words[i + 1] + return interface['v4'], interface['v6'] + + def get_interfaces_info(self, ip_path, default_ipv4, default_ipv6): + interfaces = {} + ips = dict( + all_ipv4_addresses=[], + all_ipv6_addresses=[], + ) + + # FIXME: maybe split into smaller methods? + # FIXME: this is pretty much a constructor + + for path in glob.glob('/sys/class/net/*'): + if not os.path.isdir(path): + continue + device = os.path.basename(path) + interfaces[device] = {'device': device} + if os.path.exists(os.path.join(path, 'address')): + macaddress = get_file_content(os.path.join(path, 'address'), default='') + if macaddress and macaddress != '00:00:00:00:00:00': + interfaces[device]['macaddress'] = macaddress + if os.path.exists(os.path.join(path, 'mtu')): + interfaces[device]['mtu'] = int(get_file_content(os.path.join(path, 'mtu'))) + if os.path.exists(os.path.join(path, 'operstate')): + interfaces[device]['active'] = get_file_content(os.path.join(path, 'operstate')) != 'down' + if os.path.exists(os.path.join(path, 'device', 'driver', 'module')): + interfaces[device]['module'] = os.path.basename(os.path.realpath(os.path.join(path, 'device', 'driver', 'module'))) + if os.path.exists(os.path.join(path, 'type')): + _type = get_file_content(os.path.join(path, 'type')) + interfaces[device]['type'] = self.INTERFACE_TYPE.get(_type, 'unknown') + if os.path.exists(os.path.join(path, 'bridge')): + interfaces[device]['type'] = 'bridge' + interfaces[device]['interfaces'] = [os.path.basename(b) for b in glob.glob(os.path.join(path, 'brif', '*'))] + if os.path.exists(os.path.join(path, 'bridge', 'bridge_id')): + interfaces[device]['id'] = get_file_content(os.path.join(path, 'bridge', 'bridge_id'), default='') + if os.path.exists(os.path.join(path, 'bridge', 'stp_state')): + interfaces[device]['stp'] = get_file_content(os.path.join(path, 'bridge', 'stp_state')) == '1' + if os.path.exists(os.path.join(path, 'bonding')): + interfaces[device]['type'] = 'bonding' + interfaces[device]['slaves'] = get_file_content(os.path.join(path, 'bonding', 'slaves'), default='').split() + interfaces[device]['mode'] = get_file_content(os.path.join(path, 'bonding', 'mode'), default='').split()[0] + interfaces[device]['miimon'] = get_file_content(os.path.join(path, 'bonding', 'miimon'), default='').split()[0] + interfaces[device]['lacp_rate'] = get_file_content(os.path.join(path, 'bonding', 'lacp_rate'), default='').split()[0] + primary = get_file_content(os.path.join(path, 'bonding', 'primary')) + if primary: + interfaces[device]['primary'] = primary + path = os.path.join(path, 'bonding', 'all_slaves_active') + if os.path.exists(path): + interfaces[device]['all_slaves_active'] = get_file_content(path) == '1' + if os.path.exists(os.path.join(path, 'bonding_slave')): + interfaces[device]['perm_macaddress'] = get_file_content(os.path.join(path, 'bonding_slave', 'perm_hwaddr'), default='') + if os.path.exists(os.path.join(path, 'device')): + interfaces[device]['pciid'] = os.path.basename(os.readlink(os.path.join(path, 'device'))) + if os.path.exists(os.path.join(path, 'speed')): + speed = get_file_content(os.path.join(path, 'speed')) + if speed is not None: + interfaces[device]['speed'] = int(speed) + + # Check whether an interface is in promiscuous mode + if os.path.exists(os.path.join(path, 'flags')): + promisc_mode = False + # The second byte indicates whether the interface is in promiscuous mode. + # 1 = promisc + # 0 = no promisc + data = int(get_file_content(os.path.join(path, 'flags')), 16) + promisc_mode = (data & 0x0100 > 0) + interfaces[device]['promisc'] = promisc_mode + + # TODO: determine if this needs to be in a nested scope/closure + def parse_ip_output(output, secondary=False): + for line in output.splitlines(): + if not line: + continue + words = line.split() + broadcast = '' + if words[0] == 'inet': + if '/' in words[1]: + address, netmask_length = words[1].split('/') + if len(words) > 3: + if words[2] == 'brd': + broadcast = words[3] + else: + # pointopoint interfaces do not have a prefix + address = words[1] + netmask_length = "32" + address_bin = struct.unpack('!L', socket.inet_aton(address))[0] + netmask_bin = (1 << 32) - (1 << 32 >> int(netmask_length)) + netmask = socket.inet_ntoa(struct.pack('!L', netmask_bin)) + network = socket.inet_ntoa(struct.pack('!L', address_bin & netmask_bin)) + iface = words[-1] + # NOTE: device is ref to outside scope + # NOTE: interfaces is also ref to outside scope + if iface != device: + interfaces[iface] = {} + if not secondary and "ipv4" not in interfaces[iface]: + interfaces[iface]['ipv4'] = {'address': address, + 'broadcast': broadcast, + 'netmask': netmask, + 'network': network, + 'prefix': netmask_length, + } + else: + if "ipv4_secondaries" not in interfaces[iface]: + interfaces[iface]["ipv4_secondaries"] = [] + interfaces[iface]["ipv4_secondaries"].append({ + 'address': address, + 'broadcast': broadcast, + 'netmask': netmask, + 'network': network, + 'prefix': netmask_length, + }) + + # add this secondary IP to the main device + if secondary: + if "ipv4_secondaries" not in interfaces[device]: + interfaces[device]["ipv4_secondaries"] = [] + if device != iface: + interfaces[device]["ipv4_secondaries"].append({ + 'address': address, + 'broadcast': broadcast, + 'netmask': netmask, + 'network': network, + 'prefix': netmask_length, + }) + + # NOTE: default_ipv4 is ref to outside scope + # If this is the default address, update default_ipv4 + if 'address' in default_ipv4 and default_ipv4['address'] == address: + default_ipv4['broadcast'] = broadcast + default_ipv4['netmask'] = netmask + default_ipv4['network'] = network + default_ipv4['prefix'] = netmask_length + # NOTE: macaddress is ref from outside scope + default_ipv4['macaddress'] = macaddress + default_ipv4['mtu'] = interfaces[device]['mtu'] + default_ipv4['type'] = interfaces[device].get("type", "unknown") + default_ipv4['alias'] = words[-1] + if not address.startswith('127.'): + ips['all_ipv4_addresses'].append(address) + elif words[0] == 'inet6': + if 'peer' == words[2]: + address = words[1] + _, prefix = words[3].split('/') + scope = words[5] + else: + address, prefix = words[1].split('/') + scope = words[3] + if 'ipv6' not in interfaces[device]: + interfaces[device]['ipv6'] = [] + interfaces[device]['ipv6'].append({ + 'address': address, + 'prefix': prefix, + 'scope': scope + }) + # If this is the default address, update default_ipv6 + if 'address' in default_ipv6 and default_ipv6['address'] == address: + default_ipv6['prefix'] = prefix + default_ipv6['scope'] = scope + default_ipv6['macaddress'] = macaddress + default_ipv6['mtu'] = interfaces[device]['mtu'] + default_ipv6['type'] = interfaces[device].get("type", "unknown") + if not address == '::1': + ips['all_ipv6_addresses'].append(address) + + ip_path = self.module.get_bin_path("ip") + + args = [ip_path, 'addr', 'show', 'primary', 'dev', device] + rc, primary_data, stderr = self.module.run_command(args, errors='surrogate_then_replace') + if rc == 0: + parse_ip_output(primary_data) + else: + # possibly busybox, fallback to running without the "primary" arg + # https://github.com/ansible/ansible/issues/50871 + args = [ip_path, 'addr', 'show', 'dev', device] + rc, data, stderr = self.module.run_command(args, errors='surrogate_then_replace') + if rc == 0: + parse_ip_output(data) + + args = [ip_path, 'addr', 'show', 'secondary', 'dev', device] + rc, secondary_data, stderr = self.module.run_command(args, errors='surrogate_then_replace') + if rc == 0: + parse_ip_output(secondary_data, secondary=True) + + interfaces[device].update(self.get_ethtool_data(device)) + + # replace : by _ in interface name since they are hard to use in template + new_interfaces = {} + # i is a dict key (string) not an index int + for i in interfaces: + if ':' in i: + new_interfaces[i.replace(':', '_')] = interfaces[i] + else: + new_interfaces[i] = interfaces[i] + return new_interfaces, ips + + def get_ethtool_data(self, device): + + data = {} + ethtool_path = self.module.get_bin_path("ethtool") + # FIXME: exit early on falsey ethtool_path and un-indent + if ethtool_path: + args = [ethtool_path, '-k', device] + rc, stdout, stderr = self.module.run_command(args, errors='surrogate_then_replace') + # FIXME: exit early on falsey if we can + if rc == 0: + features = {} + for line in stdout.strip().splitlines(): + if not line or line.endswith(":"): + continue + key, value = line.split(": ") + if not value: + continue + features[key.strip().replace('-', '_')] = value.strip() + data['features'] = features + + args = [ethtool_path, '-T', device] + rc, stdout, stderr = self.module.run_command(args, errors='surrogate_then_replace') + if rc == 0: + data['timestamping'] = [m.lower() for m in re.findall(r'SOF_TIMESTAMPING_(\w+)', stdout)] + data['hw_timestamp_filters'] = [m.lower() for m in re.findall(r'HWTSTAMP_FILTER_(\w+)', stdout)] + m = re.search(r'PTP Hardware Clock: (\d+)', stdout) + if m: + data['phc_index'] = int(m.groups()[0]) + + return data + + +class LinuxNetworkCollector(NetworkCollector): + _platform = 'Linux' + _fact_class = LinuxNetwork + required_facts = set(['distribution', 'platform']) diff --git a/lib/ansible/module_utils/facts/network/netbsd.py b/lib/ansible/module_utils/facts/network/netbsd.py new file mode 100644 index 0000000..de8ceff --- /dev/null +++ b/lib/ansible/module_utils/facts/network/netbsd.py @@ -0,0 +1,48 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.generic_bsd import GenericBsdIfconfigNetwork + + +class NetBSDNetwork(GenericBsdIfconfigNetwork): + """ + This is the NetBSD Network Class. + It uses the GenericBsdIfconfigNetwork + """ + platform = 'NetBSD' + + def parse_media_line(self, words, current_if, ips): + # example of line: + # $ ifconfig + # ne0: flags=8863<UP,BROADCAST,NOTRAILERS,RUNNING,SIMPLEX,MULTICAST> mtu 1500 + # ec_capabilities=1<VLAN_MTU> + # ec_enabled=0 + # address: 00:20:91:45:00:78 + # media: Ethernet 10baseT full-duplex + # inet 192.168.156.29 netmask 0xffffff00 broadcast 192.168.156.255 + current_if['media'] = words[1] + if len(words) > 2: + current_if['media_type'] = words[2] + if len(words) > 3: + current_if['media_options'] = words[3].split(',') + + +class NetBSDNetworkCollector(NetworkCollector): + _fact_class = NetBSDNetwork + _platform = 'NetBSD' diff --git a/lib/ansible/module_utils/facts/network/nvme.py b/lib/ansible/module_utils/facts/network/nvme.py new file mode 100644 index 0000000..febd0ab --- /dev/null +++ b/lib/ansible/module_utils/facts/network/nvme.py @@ -0,0 +1,57 @@ +# NVMe initiator related facts collection for Ansible. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import sys +import subprocess + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_content +from ansible.module_utils.facts.network.base import NetworkCollector + + +class NvmeInitiatorNetworkCollector(NetworkCollector): + name = 'nvme' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + """ + Currently NVMe is only supported in some Linux distributions. + If NVMe is configured on the host then a file will have been created + during the NVMe driver installation. This file holds the unique NQN + of the host. + + Example of contents of /etc/nvme/hostnqn: + + # cat /etc/nvme/hostnqn + nqn.2014-08.org.nvmexpress:fc_lif:uuid:2cd61a74-17f9-4c22-b350-3020020c458d + + """ + + nvme_facts = {} + nvme_facts['hostnqn'] = "" + if sys.platform.startswith('linux'): + for line in get_file_content('/etc/nvme/hostnqn', '').splitlines(): + if line.startswith('#') or line.startswith(';') or line.strip() == '': + continue + if line.startswith('nqn.'): + nvme_facts['hostnqn'] = line + break + return nvme_facts diff --git a/lib/ansible/module_utils/facts/network/openbsd.py b/lib/ansible/module_utils/facts/network/openbsd.py new file mode 100644 index 0000000..9e11d82 --- /dev/null +++ b/lib/ansible/module_utils/facts/network/openbsd.py @@ -0,0 +1,42 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.generic_bsd import GenericBsdIfconfigNetwork + + +class OpenBSDNetwork(GenericBsdIfconfigNetwork): + """ + This is the OpenBSD Network Class. + It uses the GenericBsdIfconfigNetwork. + """ + platform = 'OpenBSD' + + # OpenBSD 'ifconfig -a' does not have information about aliases + def get_interfaces_info(self, ifconfig_path, ifconfig_options='-aA'): + return super(OpenBSDNetwork, self).get_interfaces_info(ifconfig_path, ifconfig_options) + + # Return macaddress instead of lladdr + def parse_lladdr_line(self, words, current_if, ips): + current_if['macaddress'] = words[1] + current_if['type'] = 'ether' + + +class OpenBSDNetworkCollector(NetworkCollector): + _fact_class = OpenBSDNetwork + _platform = 'OpenBSD' diff --git a/lib/ansible/module_utils/facts/network/sunos.py b/lib/ansible/module_utils/facts/network/sunos.py new file mode 100644 index 0000000..adba14c --- /dev/null +++ b/lib/ansible/module_utils/facts/network/sunos.py @@ -0,0 +1,116 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +from ansible.module_utils.facts.network.base import NetworkCollector +from ansible.module_utils.facts.network.generic_bsd import GenericBsdIfconfigNetwork + + +class SunOSNetwork(GenericBsdIfconfigNetwork): + """ + This is the SunOS Network Class. + It uses the GenericBsdIfconfigNetwork. + + Solaris can have different FLAGS and MTU for IPv4 and IPv6 on the same interface + so these facts have been moved inside the 'ipv4' and 'ipv6' lists. + """ + platform = 'SunOS' + + # Solaris 'ifconfig -a' will print interfaces twice, once for IPv4 and again for IPv6. + # MTU and FLAGS also may differ between IPv4 and IPv6 on the same interface. + # 'parse_interface_line()' checks for previously seen interfaces before defining + # 'current_if' so that IPv6 facts don't clobber IPv4 facts (or vice versa). + def get_interfaces_info(self, ifconfig_path): + interfaces = {} + current_if = {} + ips = dict( + all_ipv4_addresses=[], + all_ipv6_addresses=[], + ) + rc, out, err = self.module.run_command([ifconfig_path, '-a']) + + for line in out.splitlines(): + + if line: + words = line.split() + + if re.match(r'^\S', line) and len(words) > 3: + current_if = self.parse_interface_line(words, current_if, interfaces) + interfaces[current_if['device']] = current_if + elif words[0].startswith('options='): + self.parse_options_line(words, current_if, ips) + elif words[0] == 'nd6': + self.parse_nd6_line(words, current_if, ips) + elif words[0] == 'ether': + self.parse_ether_line(words, current_if, ips) + elif words[0] == 'media:': + self.parse_media_line(words, current_if, ips) + elif words[0] == 'status:': + self.parse_status_line(words, current_if, ips) + elif words[0] == 'lladdr': + self.parse_lladdr_line(words, current_if, ips) + elif words[0] == 'inet': + self.parse_inet_line(words, current_if, ips) + elif words[0] == 'inet6': + self.parse_inet6_line(words, current_if, ips) + else: + self.parse_unknown_line(words, current_if, ips) + + # 'parse_interface_line' and 'parse_inet*_line' leave two dicts in the + # ipv4/ipv6 lists which is ugly and hard to read. + # This quick hack merges the dictionaries. Purely cosmetic. + for iface in interfaces: + for v in 'ipv4', 'ipv6': + combined_facts = {} + for facts in interfaces[iface][v]: + combined_facts.update(facts) + if len(combined_facts.keys()) > 0: + interfaces[iface][v] = [combined_facts] + + return interfaces, ips + + def parse_interface_line(self, words, current_if, interfaces): + device = words[0][0:-1] + if device not in interfaces: + current_if = {'device': device, 'ipv4': [], 'ipv6': [], 'type': 'unknown'} + else: + current_if = interfaces[device] + flags = self.get_options(words[1]) + v = 'ipv4' + if 'IPv6' in flags: + v = 'ipv6' + if 'LOOPBACK' in flags: + current_if['type'] = 'loopback' + current_if[v].append({'flags': flags, 'mtu': words[3]}) + current_if['macaddress'] = 'unknown' # will be overwritten later + return current_if + + # Solaris displays single digit octets in MAC addresses e.g. 0:1:2:d:e:f + # Add leading zero to each octet where needed. + def parse_ether_line(self, words, current_if, ips): + macaddress = '' + for octet in words[1].split(':'): + octet = ('0' + octet)[-2:None] + macaddress += (octet + ':') + current_if['macaddress'] = macaddress[0:-1] + + +class SunOSNetworkCollector(NetworkCollector): + _fact_class = SunOSNetwork + _platform = 'SunOS' diff --git a/lib/ansible/module_utils/facts/other/__init__.py b/lib/ansible/module_utils/facts/other/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/facts/other/__init__.py diff --git a/lib/ansible/module_utils/facts/other/facter.py b/lib/ansible/module_utils/facts/other/facter.py new file mode 100644 index 0000000..3f83999 --- /dev/null +++ b/lib/ansible/module_utils/facts/other/facter.py @@ -0,0 +1,87 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.namespace import PrefixFactNamespace + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class FacterFactCollector(BaseFactCollector): + name = 'facter' + _fact_ids = set(['facter']) # type: t.Set[str] + + def __init__(self, collectors=None, namespace=None): + namespace = PrefixFactNamespace(namespace_name='facter', + prefix='facter_') + super(FacterFactCollector, self).__init__(collectors=collectors, + namespace=namespace) + + def find_facter(self, module): + facter_path = module.get_bin_path('facter', opt_dirs=['/opt/puppetlabs/bin']) + cfacter_path = module.get_bin_path('cfacter', opt_dirs=['/opt/puppetlabs/bin']) + + # Prefer to use cfacter if available + if cfacter_path is not None: + facter_path = cfacter_path + + return facter_path + + def run_facter(self, module, facter_path): + # if facter is installed, and we can use --json because + # ruby-json is ALSO installed, include facter data in the JSON + rc, out, err = module.run_command(facter_path + " --puppet --json") + return rc, out, err + + def get_facter_output(self, module): + facter_path = self.find_facter(module) + if not facter_path: + return None + + rc, out, err = self.run_facter(module, facter_path) + + if rc != 0: + return None + + return out + + def collect(self, module=None, collected_facts=None): + # Note that this mirrors previous facter behavior, where there isnt + # a 'ansible_facter' key in the main fact dict, but instead, 'facter_whatever' + # items are added to the main dict. + facter_dict = {} + + if not module: + return facter_dict + + facter_output = self.get_facter_output(module) + + # TODO: if we fail, should we add a empty facter key or nothing? + if facter_output is None: + return facter_dict + + try: + facter_dict = json.loads(facter_output) + except Exception: + # FIXME: maybe raise a FactCollectorError with some info attrs? + pass + + return facter_dict diff --git a/lib/ansible/module_utils/facts/other/ohai.py b/lib/ansible/module_utils/facts/other/ohai.py new file mode 100644 index 0000000..90c5539 --- /dev/null +++ b/lib/ansible/module_utils/facts/other/ohai.py @@ -0,0 +1,74 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.namespace import PrefixFactNamespace + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class OhaiFactCollector(BaseFactCollector): + '''This is a subclass of Facts for including information gathered from Ohai.''' + name = 'ohai' + _fact_ids = set() # type: t.Set[str] + + def __init__(self, collectors=None, namespace=None): + namespace = PrefixFactNamespace(namespace_name='ohai', + prefix='ohai_') + super(OhaiFactCollector, self).__init__(collectors=collectors, + namespace=namespace) + + def find_ohai(self, module): + ohai_path = module.get_bin_path('ohai') + return ohai_path + + def run_ohai(self, module, ohai_path,): + rc, out, err = module.run_command(ohai_path) + return rc, out, err + + def get_ohai_output(self, module): + ohai_path = self.find_ohai(module) + if not ohai_path: + return None + + rc, out, err = self.run_ohai(module, ohai_path) + if rc != 0: + return None + + return out + + def collect(self, module=None, collected_facts=None): + ohai_facts = {} + if not module: + return ohai_facts + + ohai_output = self.get_ohai_output(module) + + if ohai_output is None: + return ohai_facts + + try: + ohai_facts = json.loads(ohai_output) + except Exception: + # FIXME: useful error, logging, something... + pass + + return ohai_facts diff --git a/lib/ansible/module_utils/facts/packages.py b/lib/ansible/module_utils/facts/packages.py new file mode 100644 index 0000000..53f74a1 --- /dev/null +++ b/lib/ansible/module_utils/facts/packages.py @@ -0,0 +1,86 @@ +# (c) 2018, Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from abc import ABCMeta, abstractmethod + +from ansible.module_utils.six import with_metaclass +from ansible.module_utils.common.process import get_bin_path +from ansible.module_utils.common._utils import get_all_subclasses + + +def get_all_pkg_managers(): + + return {obj.__name__.lower(): obj for obj in get_all_subclasses(PkgMgr) if obj not in (CLIMgr, LibMgr)} + + +class PkgMgr(with_metaclass(ABCMeta, object)): # type: ignore[misc] + + @abstractmethod + def is_available(self): + # This method is supposed to return True/False if the package manager is currently installed/usable + # It can also 'prep' the required systems in the process of detecting availability + pass + + @abstractmethod + def list_installed(self): + # This method should return a list of installed packages, each list item will be passed to get_package_details + pass + + @abstractmethod + def get_package_details(self, package): + # This takes a 'package' item and returns a dictionary with the package information, name and version are minimal requirements + pass + + def get_packages(self): + # Take all of the above and return a dictionary of lists of dictionaries (package = list of installed versions) + + installed_packages = {} + for package in self.list_installed(): + package_details = self.get_package_details(package) + if 'source' not in package_details: + package_details['source'] = self.__class__.__name__.lower() + name = package_details['name'] + if name not in installed_packages: + installed_packages[name] = [package_details] + else: + installed_packages[name].append(package_details) + return installed_packages + + +class LibMgr(PkgMgr): + + LIB = None # type: str | None + + def __init__(self): + + self._lib = None + super(LibMgr, self).__init__() + + def is_available(self): + found = False + try: + self._lib = __import__(self.LIB) + found = True + except ImportError: + pass + return found + + +class CLIMgr(PkgMgr): + + CLI = None # type: str | None + + def __init__(self): + + self._cli = None + super(CLIMgr, self).__init__() + + def is_available(self): + try: + self._cli = get_bin_path(self.CLI) + except ValueError: + return False + return True diff --git a/lib/ansible/module_utils/facts/sysctl.py b/lib/ansible/module_utils/facts/sysctl.py new file mode 100644 index 0000000..2c55d77 --- /dev/null +++ b/lib/ansible/module_utils/facts/sysctl.py @@ -0,0 +1,62 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +from ansible.module_utils._text import to_text + + +def get_sysctl(module, prefixes): + sysctl_cmd = module.get_bin_path('sysctl') + cmd = [sysctl_cmd] + cmd.extend(prefixes) + + sysctl = dict() + + try: + rc, out, err = module.run_command(cmd) + except (IOError, OSError) as e: + module.warn('Unable to read sysctl: %s' % to_text(e)) + rc = 1 + + if rc == 0: + key = '' + value = '' + for line in out.splitlines(): + if not line.strip(): + continue + + if line.startswith(' '): + # handle multiline values, they will not have a starting key + # Add the newline back in so people can split on it to parse + # lines if they need to. + value += '\n' + line + continue + + if key: + sysctl[key] = value.strip() + + try: + (key, value) = re.split(r'\s?=\s?|: ', line, maxsplit=1) + except Exception as e: + module.warn('Unable to split sysctl line (%s): %s' % (to_text(line), to_text(e))) + + if key: + sysctl[key] = value.strip() + + return sysctl diff --git a/lib/ansible/module_utils/facts/system/__init__.py b/lib/ansible/module_utils/facts/system/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/__init__.py diff --git a/lib/ansible/module_utils/facts/system/apparmor.py b/lib/ansible/module_utils/facts/system/apparmor.py new file mode 100644 index 0000000..3b702f9 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/apparmor.py @@ -0,0 +1,41 @@ +# Collect facts related to apparmor +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class ApparmorFactCollector(BaseFactCollector): + name = 'apparmor' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + facts_dict = {} + apparmor_facts = {} + if os.path.exists('/sys/kernel/security/apparmor'): + apparmor_facts['status'] = 'enabled' + else: + apparmor_facts['status'] = 'disabled' + + facts_dict['apparmor'] = apparmor_facts + return facts_dict diff --git a/lib/ansible/module_utils/facts/system/caps.py b/lib/ansible/module_utils/facts/system/caps.py new file mode 100644 index 0000000..6a1e26d --- /dev/null +++ b/lib/ansible/module_utils/facts/system/caps.py @@ -0,0 +1,62 @@ +# Collect facts related to systems 'capabilities' via capsh +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils._text import to_text +from ansible.module_utils.facts.collector import BaseFactCollector + + +class SystemCapabilitiesFactCollector(BaseFactCollector): + name = 'caps' + _fact_ids = set(['system_capabilities', + 'system_capabilities_enforced']) # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + + rc = -1 + facts_dict = {'system_capabilities_enforced': 'N/A', + 'system_capabilities': 'N/A'} + if module: + capsh_path = module.get_bin_path('capsh') + if capsh_path: + # NOTE: -> get_caps_data()/parse_caps_data() for easier mocking -akl + try: + rc, out, err = module.run_command([capsh_path, "--print"], errors='surrogate_then_replace', handle_exceptions=False) + except (IOError, OSError) as e: + module.warn('Could not query system capabilities: %s' % str(e)) + + if rc == 0: + enforced_caps = [] + enforced = 'NA' + for line in out.splitlines(): + if len(line) < 1: + continue + if line.startswith('Current:'): + if line.split(':')[1].strip() == '=ep': + enforced = 'False' + else: + enforced = 'True' + enforced_caps = [i.strip() for i in line.split('=')[1].split(',')] + + facts_dict['system_capabilities_enforced'] = enforced + facts_dict['system_capabilities'] = enforced_caps + + return facts_dict diff --git a/lib/ansible/module_utils/facts/system/chroot.py b/lib/ansible/module_utils/facts/system/chroot.py new file mode 100644 index 0000000..94138a0 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/chroot.py @@ -0,0 +1,49 @@ +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +def is_chroot(module=None): + + is_chroot = None + + if os.environ.get('debian_chroot', False): + is_chroot = True + else: + my_root = os.stat('/') + try: + # check if my file system is the root one + proc_root = os.stat('/proc/1/root/.') + is_chroot = my_root.st_ino != proc_root.st_ino or my_root.st_dev != proc_root.st_dev + except Exception: + # I'm not root or no proc, fallback to checking it is inode #2 + fs_root_ino = 2 + + if module is not None: + stat_path = module.get_bin_path('stat') + if stat_path: + cmd = [stat_path, '-f', '--format=%T', '/'] + rc, out, err = module.run_command(cmd) + if 'btrfs' in out: + fs_root_ino = 256 + elif 'xfs' in out: + fs_root_ino = 128 + + is_chroot = (my_root.st_ino != fs_root_ino) + + return is_chroot + + +class ChrootFactCollector(BaseFactCollector): + name = 'chroot' + _fact_ids = set(['is_chroot']) # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + return {'is_chroot': is_chroot(module)} diff --git a/lib/ansible/module_utils/facts/system/cmdline.py b/lib/ansible/module_utils/facts/system/cmdline.py new file mode 100644 index 0000000..782186d --- /dev/null +++ b/lib/ansible/module_utils/facts/system/cmdline.py @@ -0,0 +1,81 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import shlex + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_content + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class CmdLineFactCollector(BaseFactCollector): + name = 'cmdline' + _fact_ids = set() # type: t.Set[str] + + def _get_proc_cmdline(self): + return get_file_content('/proc/cmdline') + + def _parse_proc_cmdline(self, data): + cmdline_dict = {} + try: + for piece in shlex.split(data, posix=False): + item = piece.split('=', 1) + if len(item) == 1: + cmdline_dict[item[0]] = True + else: + cmdline_dict[item[0]] = item[1] + except ValueError: + pass + + return cmdline_dict + + def _parse_proc_cmdline_facts(self, data): + cmdline_dict = {} + try: + for piece in shlex.split(data, posix=False): + item = piece.split('=', 1) + if len(item) == 1: + cmdline_dict[item[0]] = True + else: + if item[0] in cmdline_dict: + if isinstance(cmdline_dict[item[0]], list): + cmdline_dict[item[0]].append(item[1]) + else: + new_list = [cmdline_dict[item[0]], item[1]] + cmdline_dict[item[0]] = new_list + else: + cmdline_dict[item[0]] = item[1] + except ValueError: + pass + + return cmdline_dict + + def collect(self, module=None, collected_facts=None): + cmdline_facts = {} + + data = self._get_proc_cmdline() + + if not data: + return cmdline_facts + + cmdline_facts['cmdline'] = self._parse_proc_cmdline(data) + cmdline_facts['proc_cmdline'] = self._parse_proc_cmdline_facts(data) + + return cmdline_facts diff --git a/lib/ansible/module_utils/facts/system/date_time.py b/lib/ansible/module_utils/facts/system/date_time.py new file mode 100644 index 0000000..481bef4 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/date_time.py @@ -0,0 +1,70 @@ +# Data and time related facts collection for ansible. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import datetime +import time + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class DateTimeFactCollector(BaseFactCollector): + name = 'date_time' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + facts_dict = {} + date_time_facts = {} + + # Store the timestamp once, then get local and UTC versions from that + epoch_ts = time.time() + now = datetime.datetime.fromtimestamp(epoch_ts) + utcnow = datetime.datetime.utcfromtimestamp(epoch_ts) + + date_time_facts['year'] = now.strftime('%Y') + date_time_facts['month'] = now.strftime('%m') + date_time_facts['weekday'] = now.strftime('%A') + date_time_facts['weekday_number'] = now.strftime('%w') + date_time_facts['weeknumber'] = now.strftime('%W') + date_time_facts['day'] = now.strftime('%d') + date_time_facts['hour'] = now.strftime('%H') + date_time_facts['minute'] = now.strftime('%M') + date_time_facts['second'] = now.strftime('%S') + date_time_facts['epoch'] = now.strftime('%s') + # epoch returns float or string in some non-linux environments + if date_time_facts['epoch'] == '' or date_time_facts['epoch'][0] == '%': + date_time_facts['epoch'] = str(int(epoch_ts)) + # epoch_int always returns integer format of epoch + date_time_facts['epoch_int'] = str(int(now.strftime('%s'))) + if date_time_facts['epoch_int'] == '' or date_time_facts['epoch_int'][0] == '%': + date_time_facts['epoch_int'] = str(int(epoch_ts)) + date_time_facts['date'] = now.strftime('%Y-%m-%d') + date_time_facts['time'] = now.strftime('%H:%M:%S') + date_time_facts['iso8601_micro'] = utcnow.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + date_time_facts['iso8601'] = utcnow.strftime("%Y-%m-%dT%H:%M:%SZ") + date_time_facts['iso8601_basic'] = now.strftime("%Y%m%dT%H%M%S%f") + date_time_facts['iso8601_basic_short'] = now.strftime("%Y%m%dT%H%M%S") + date_time_facts['tz'] = time.strftime("%Z") + date_time_facts['tz_dst'] = time.tzname[1] + date_time_facts['tz_offset'] = time.strftime("%z") + + facts_dict['date_time'] = date_time_facts + return facts_dict diff --git a/lib/ansible/module_utils/facts/system/distribution.py b/lib/ansible/module_utils/facts/system/distribution.py new file mode 100644 index 0000000..dcb6e5a --- /dev/null +++ b/lib/ansible/module_utils/facts/system/distribution.py @@ -0,0 +1,726 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import platform +import re + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.common.sys_info import get_distribution, get_distribution_version, \ + get_distribution_codename +from ansible.module_utils.facts.utils import get_file_content, get_file_lines +from ansible.module_utils.facts.collector import BaseFactCollector + + +def get_uname(module, flags=('-v')): + if isinstance(flags, str): + flags = flags.split() + command = ['uname'] + command.extend(flags) + rc, out, err = module.run_command(command) + if rc == 0: + return out + return None + + +def _file_exists(path, allow_empty=False): + # not finding the file, exit early + if not os.path.exists(path): + return False + + # if just the path needs to exists (ie, it can be empty) we are done + if allow_empty: + return True + + # file exists but is empty and we dont allow_empty + if os.path.getsize(path) == 0: + return False + + # file exists with some content + return True + + +class DistributionFiles: + '''has-a various distro file parsers (os-release, etc) and logic for finding the right one.''' + # every distribution name mentioned here, must have one of + # - allowempty == True + # - be listed in SEARCH_STRING + # - have a function get_distribution_DISTNAME implemented + # keep names in sync with Conditionals page of docs + OSDIST_LIST = ( + {'path': '/etc/altlinux-release', 'name': 'Altlinux'}, + {'path': '/etc/oracle-release', 'name': 'OracleLinux'}, + {'path': '/etc/slackware-version', 'name': 'Slackware'}, + {'path': '/etc/centos-release', 'name': 'CentOS'}, + {'path': '/etc/redhat-release', 'name': 'RedHat'}, + {'path': '/etc/vmware-release', 'name': 'VMwareESX', 'allowempty': True}, + {'path': '/etc/openwrt_release', 'name': 'OpenWrt'}, + {'path': '/etc/os-release', 'name': 'Amazon'}, + {'path': '/etc/system-release', 'name': 'Amazon'}, + {'path': '/etc/alpine-release', 'name': 'Alpine'}, + {'path': '/etc/arch-release', 'name': 'Archlinux', 'allowempty': True}, + {'path': '/etc/os-release', 'name': 'Archlinux'}, + {'path': '/etc/os-release', 'name': 'SUSE'}, + {'path': '/etc/SuSE-release', 'name': 'SUSE'}, + {'path': '/etc/gentoo-release', 'name': 'Gentoo'}, + {'path': '/etc/os-release', 'name': 'Debian'}, + {'path': '/etc/lsb-release', 'name': 'Debian'}, + {'path': '/etc/lsb-release', 'name': 'Mandriva'}, + {'path': '/etc/sourcemage-release', 'name': 'SMGL'}, + {'path': '/usr/lib/os-release', 'name': 'ClearLinux'}, + {'path': '/etc/coreos/update.conf', 'name': 'Coreos'}, + {'path': '/etc/os-release', 'name': 'Flatcar'}, + {'path': '/etc/os-release', 'name': 'NA'}, + ) + + SEARCH_STRING = { + 'OracleLinux': 'Oracle Linux', + 'RedHat': 'Red Hat', + 'Altlinux': 'ALT', + 'SMGL': 'Source Mage GNU/Linux', + } + + # We can't include this in SEARCH_STRING because a name match on its keys + # causes a fallback to using the first whitespace separated item from the file content + # as the name. For os-release, that is in form 'NAME=Arch' + OS_RELEASE_ALIAS = { + 'Archlinux': 'Arch Linux' + } + + STRIP_QUOTES = r'\'\"\\' + + def __init__(self, module): + self.module = module + + def _get_file_content(self, path): + return get_file_content(path) + + def _get_dist_file_content(self, path, allow_empty=False): + # cant find that dist file or it is incorrectly empty + if not _file_exists(path, allow_empty=allow_empty): + return False, None + + data = self._get_file_content(path) + return True, data + + def _parse_dist_file(self, name, dist_file_content, path, collected_facts): + dist_file_dict = {} + dist_file_content = dist_file_content.strip(DistributionFiles.STRIP_QUOTES) + if name in self.SEARCH_STRING: + # look for the distribution string in the data and replace according to RELEASE_NAME_MAP + # only the distribution name is set, the version is assumed to be correct from distro.linux_distribution() + if self.SEARCH_STRING[name] in dist_file_content: + # this sets distribution=RedHat if 'Red Hat' shows up in data + dist_file_dict['distribution'] = name + dist_file_dict['distribution_file_search_string'] = self.SEARCH_STRING[name] + else: + # this sets distribution to what's in the data, e.g. CentOS, Scientific, ... + dist_file_dict['distribution'] = dist_file_content.split()[0] + + return True, dist_file_dict + + if name in self.OS_RELEASE_ALIAS: + if self.OS_RELEASE_ALIAS[name] in dist_file_content: + dist_file_dict['distribution'] = name + return True, dist_file_dict + return False, dist_file_dict + + # call a dedicated function for parsing the file content + # TODO: replace with a map or a class + try: + # FIXME: most of these dont actually look at the dist file contents, but random other stuff + distfunc_name = 'parse_distribution_file_' + name + distfunc = getattr(self, distfunc_name) + parsed, dist_file_dict = distfunc(name, dist_file_content, path, collected_facts) + return parsed, dist_file_dict + except AttributeError as exc: + self.module.debug('exc: %s' % exc) + # this should never happen, but if it does fail quietly and not with a traceback + return False, dist_file_dict + + return True, dist_file_dict + # to debug multiple matching release files, one can use: + # self.facts['distribution_debug'].append({path + ' ' + name: + # (parsed, + # self.facts['distribution'], + # self.facts['distribution_version'], + # self.facts['distribution_release'], + # )}) + + def _guess_distribution(self): + # try to find out which linux distribution this is + dist = (get_distribution(), get_distribution_version(), get_distribution_codename()) + distribution_guess = { + 'distribution': dist[0] or 'NA', + 'distribution_version': dist[1] or 'NA', + # distribution_release can be the empty string + 'distribution_release': 'NA' if dist[2] is None else dist[2] + } + + distribution_guess['distribution_major_version'] = distribution_guess['distribution_version'].split('.')[0] or 'NA' + return distribution_guess + + def process_dist_files(self): + # Try to handle the exceptions now ... + # self.facts['distribution_debug'] = [] + dist_file_facts = {} + + dist_guess = self._guess_distribution() + dist_file_facts.update(dist_guess) + + for ddict in self.OSDIST_LIST: + name = ddict['name'] + path = ddict['path'] + allow_empty = ddict.get('allowempty', False) + + has_dist_file, dist_file_content = self._get_dist_file_content(path, allow_empty=allow_empty) + + # but we allow_empty. For example, ArchLinux with an empty /etc/arch-release and a + # /etc/os-release with a different name + if has_dist_file and allow_empty: + dist_file_facts['distribution'] = name + dist_file_facts['distribution_file_path'] = path + dist_file_facts['distribution_file_variety'] = name + break + + if not has_dist_file: + # keep looking + continue + + parsed_dist_file, parsed_dist_file_facts = self._parse_dist_file(name, dist_file_content, path, dist_file_facts) + + # finally found the right os dist file and were able to parse it + if parsed_dist_file: + dist_file_facts['distribution'] = name + dist_file_facts['distribution_file_path'] = path + # distribution and file_variety are the same here, but distribution + # will be changed/mapped to a more specific name. + # ie, dist=Fedora, file_variety=RedHat + dist_file_facts['distribution_file_variety'] = name + dist_file_facts['distribution_file_parsed'] = parsed_dist_file + dist_file_facts.update(parsed_dist_file_facts) + break + + return dist_file_facts + + # TODO: FIXME: split distro file parsing into its own module or class + def parse_distribution_file_Slackware(self, name, data, path, collected_facts): + slackware_facts = {} + if 'Slackware' not in data: + return False, slackware_facts # TODO: remove + slackware_facts['distribution'] = name + version = re.findall(r'\w+[.]\w+\+?', data) + if version: + slackware_facts['distribution_version'] = version[0] + return True, slackware_facts + + def parse_distribution_file_Amazon(self, name, data, path, collected_facts): + amazon_facts = {} + if 'Amazon' not in data: + return False, amazon_facts + amazon_facts['distribution'] = 'Amazon' + if path == '/etc/os-release': + version = re.search(r"VERSION_ID=\"(.*)\"", data) + if version: + distribution_version = version.group(1) + amazon_facts['distribution_version'] = distribution_version + version_data = distribution_version.split(".") + if len(version_data) > 1: + major, minor = version_data + else: + major, minor = version_data[0], 'NA' + + amazon_facts['distribution_major_version'] = major + amazon_facts['distribution_minor_version'] = minor + else: + version = [n for n in data.split() if n.isdigit()] + version = version[0] if version else 'NA' + amazon_facts['distribution_version'] = version + + return True, amazon_facts + + def parse_distribution_file_OpenWrt(self, name, data, path, collected_facts): + openwrt_facts = {} + if 'OpenWrt' not in data: + return False, openwrt_facts # TODO: remove + openwrt_facts['distribution'] = name + version = re.search('DISTRIB_RELEASE="(.*)"', data) + if version: + openwrt_facts['distribution_version'] = version.groups()[0] + release = re.search('DISTRIB_CODENAME="(.*)"', data) + if release: + openwrt_facts['distribution_release'] = release.groups()[0] + return True, openwrt_facts + + def parse_distribution_file_Alpine(self, name, data, path, collected_facts): + alpine_facts = {} + alpine_facts['distribution'] = 'Alpine' + alpine_facts['distribution_version'] = data + return True, alpine_facts + + def parse_distribution_file_SUSE(self, name, data, path, collected_facts): + suse_facts = {} + if 'suse' not in data.lower(): + return False, suse_facts # TODO: remove if tested without this + if path == '/etc/os-release': + for line in data.splitlines(): + distribution = re.search("^NAME=(.*)", line) + if distribution: + suse_facts['distribution'] = distribution.group(1).strip('"') + # example pattern are 13.04 13.0 13 + distribution_version = re.search(r'^VERSION_ID="?([0-9]+\.?[0-9]*)"?', line) + if distribution_version: + suse_facts['distribution_version'] = distribution_version.group(1) + suse_facts['distribution_major_version'] = distribution_version.group(1).split('.')[0] + if 'open' in data.lower(): + release = re.search(r'^VERSION_ID="?[0-9]+\.?([0-9]*)"?', line) + if release: + suse_facts['distribution_release'] = release.groups()[0] + elif 'enterprise' in data.lower() and 'VERSION_ID' in line: + # SLES doesn't got funny release names + release = re.search(r'^VERSION_ID="?[0-9]+\.?([0-9]*)"?', line) + if release.group(1): + release = release.group(1) + else: + release = "0" # no minor number, so it is the first release + suse_facts['distribution_release'] = release + elif path == '/etc/SuSE-release': + if 'open' in data.lower(): + data = data.splitlines() + distdata = get_file_content(path).splitlines()[0] + suse_facts['distribution'] = distdata.split()[0] + for line in data: + release = re.search('CODENAME *= *([^\n]+)', line) + if release: + suse_facts['distribution_release'] = release.groups()[0].strip() + elif 'enterprise' in data.lower(): + lines = data.splitlines() + distribution = lines[0].split()[0] + if "Server" in data: + suse_facts['distribution'] = "SLES" + elif "Desktop" in data: + suse_facts['distribution'] = "SLED" + for line in lines: + release = re.search('PATCHLEVEL = ([0-9]+)', line) # SLES doesn't got funny release names + if release: + suse_facts['distribution_release'] = release.group(1) + suse_facts['distribution_version'] = collected_facts['distribution_version'] + '.' + release.group(1) + + # See https://www.suse.com/support/kb/doc/?id=000019341 for SLES for SAP + if os.path.islink('/etc/products.d/baseproduct') and os.path.realpath('/etc/products.d/baseproduct').endswith('SLES_SAP.prod'): + suse_facts['distribution'] = 'SLES_SAP' + + return True, suse_facts + + def parse_distribution_file_Debian(self, name, data, path, collected_facts): + debian_facts = {} + if 'Debian' in data or 'Raspbian' in data: + debian_facts['distribution'] = 'Debian' + release = re.search(r"PRETTY_NAME=[^(]+ \(?([^)]+?)\)", data) + if release: + debian_facts['distribution_release'] = release.groups()[0] + + # Last resort: try to find release from tzdata as either lsb is missing or this is very old debian + if collected_facts['distribution_release'] == 'NA' and 'Debian' in data: + dpkg_cmd = self.module.get_bin_path('dpkg') + if dpkg_cmd: + cmd = "%s --status tzdata|grep Provides|cut -f2 -d'-'" % dpkg_cmd + rc, out, err = self.module.run_command(cmd) + if rc == 0: + debian_facts['distribution_release'] = out.strip() + debian_version_path = '/etc/debian_version' + distdata = get_file_lines(debian_version_path) + for line in distdata: + m = re.search(r'(\d+)\.(\d+)', line.strip()) + if m: + debian_facts['distribution_minor_version'] = m.groups()[1] + elif 'Ubuntu' in data: + debian_facts['distribution'] = 'Ubuntu' + # nothing else to do, Ubuntu gets correct info from python functions + elif 'SteamOS' in data: + debian_facts['distribution'] = 'SteamOS' + # nothing else to do, SteamOS gets correct info from python functions + elif path in ('/etc/lsb-release', '/etc/os-release') and ('Kali' in data or 'Parrot' in data): + if 'Kali' in data: + # Kali does not provide /etc/lsb-release anymore + debian_facts['distribution'] = 'Kali' + elif 'Parrot' in data: + debian_facts['distribution'] = 'Parrot' + release = re.search('DISTRIB_RELEASE=(.*)', data) + if release: + debian_facts['distribution_release'] = release.groups()[0] + elif 'Devuan' in data: + debian_facts['distribution'] = 'Devuan' + release = re.search(r"PRETTY_NAME=\"?[^(\"]+ \(?([^) \"]+)\)?", data) + if release: + debian_facts['distribution_release'] = release.groups()[0] + version = re.search(r"VERSION_ID=\"(.*)\"", data) + if version: + debian_facts['distribution_version'] = version.group(1) + debian_facts['distribution_major_version'] = version.group(1) + elif 'Cumulus' in data: + debian_facts['distribution'] = 'Cumulus Linux' + version = re.search(r"VERSION_ID=(.*)", data) + if version: + major, _minor, _dummy_ver = version.group(1).split(".") + debian_facts['distribution_version'] = version.group(1) + debian_facts['distribution_major_version'] = major + + release = re.search(r'VERSION="(.*)"', data) + if release: + debian_facts['distribution_release'] = release.groups()[0] + elif "Mint" in data: + debian_facts['distribution'] = 'Linux Mint' + version = re.search(r"VERSION_ID=\"(.*)\"", data) + if version: + debian_facts['distribution_version'] = version.group(1) + debian_facts['distribution_major_version'] = version.group(1).split('.')[0] + elif 'UOS' in data or 'Uos' in data or 'uos' in data: + debian_facts['distribution'] = 'Uos' + release = re.search(r"VERSION_CODENAME=\"?([^\"]+)\"?", data) + if release: + debian_facts['distribution_release'] = release.groups()[0] + version = re.search(r"VERSION_ID=\"(.*)\"", data) + if version: + debian_facts['distribution_version'] = version.group(1) + debian_facts['distribution_major_version'] = version.group(1).split('.')[0] + elif 'Deepin' in data or 'deepin' in data: + debian_facts['distribution'] = 'Deepin' + release = re.search(r"VERSION_CODENAME=\"?([^\"]+)\"?", data) + if release: + debian_facts['distribution_release'] = release.groups()[0] + version = re.search(r"VERSION_ID=\"(.*)\"", data) + if version: + debian_facts['distribution_version'] = version.group(1) + debian_facts['distribution_major_version'] = version.group(1).split('.')[0] + else: + return False, debian_facts + + return True, debian_facts + + def parse_distribution_file_Mandriva(self, name, data, path, collected_facts): + mandriva_facts = {} + if 'Mandriva' in data: + mandriva_facts['distribution'] = 'Mandriva' + version = re.search('DISTRIB_RELEASE="(.*)"', data) + if version: + mandriva_facts['distribution_version'] = version.groups()[0] + release = re.search('DISTRIB_CODENAME="(.*)"', data) + if release: + mandriva_facts['distribution_release'] = release.groups()[0] + mandriva_facts['distribution'] = name + else: + return False, mandriva_facts + + return True, mandriva_facts + + def parse_distribution_file_NA(self, name, data, path, collected_facts): + na_facts = {} + for line in data.splitlines(): + distribution = re.search("^NAME=(.*)", line) + if distribution and name == 'NA': + na_facts['distribution'] = distribution.group(1).strip('"') + version = re.search("^VERSION=(.*)", line) + if version and collected_facts['distribution_version'] == 'NA': + na_facts['distribution_version'] = version.group(1).strip('"') + return True, na_facts + + def parse_distribution_file_Coreos(self, name, data, path, collected_facts): + coreos_facts = {} + # FIXME: pass in ro copy of facts for this kind of thing + distro = get_distribution() + + if distro.lower() == 'coreos': + if not data: + # include fix from #15230, #15228 + # TODO: verify this is ok for above bugs + return False, coreos_facts + release = re.search("^GROUP=(.*)", data) + if release: + coreos_facts['distribution_release'] = release.group(1).strip('"') + else: + return False, coreos_facts # TODO: remove if tested without this + + return True, coreos_facts + + def parse_distribution_file_Flatcar(self, name, data, path, collected_facts): + flatcar_facts = {} + distro = get_distribution() + + if distro.lower() != 'flatcar': + return False, flatcar_facts + + if not data: + return False, flatcar_facts + + version = re.search("VERSION=(.*)", data) + if version: + flatcar_facts['distribution_major_version'] = version.group(1).strip('"').split('.')[0] + flatcar_facts['distribution_version'] = version.group(1).strip('"') + + return True, flatcar_facts + + def parse_distribution_file_ClearLinux(self, name, data, path, collected_facts): + clear_facts = {} + if "clearlinux" not in name.lower(): + return False, clear_facts + + pname = re.search('NAME="(.*)"', data) + if pname: + if 'Clear Linux' not in pname.groups()[0]: + return False, clear_facts + clear_facts['distribution'] = pname.groups()[0] + version = re.search('VERSION_ID=(.*)', data) + if version: + clear_facts['distribution_major_version'] = version.groups()[0] + clear_facts['distribution_version'] = version.groups()[0] + release = re.search('ID=(.*)', data) + if release: + clear_facts['distribution_release'] = release.groups()[0] + return True, clear_facts + + def parse_distribution_file_CentOS(self, name, data, path, collected_facts): + centos_facts = {} + + if 'CentOS Stream' in data: + centos_facts['distribution_release'] = 'Stream' + return True, centos_facts + + if "TencentOS Server" in data: + centos_facts['distribution'] = 'TencentOS' + return True, centos_facts + + return False, centos_facts + + +class Distribution(object): + """ + This subclass of Facts fills the distribution, distribution_version and distribution_release variables + + To do so it checks the existence and content of typical files in /etc containing distribution information + + This is unit tested. Please extend the tests to cover all distributions if you have them available. + """ + + # keep keys in sync with Conditionals page of docs + OS_FAMILY_MAP = {'RedHat': ['RedHat', 'RHEL', 'Fedora', 'CentOS', 'Scientific', 'SLC', + 'Ascendos', 'CloudLinux', 'PSBM', 'OracleLinux', 'OVS', + 'OEL', 'Amazon', 'Virtuozzo', 'XenServer', 'Alibaba', + 'EulerOS', 'openEuler', 'AlmaLinux', 'Rocky', 'TencentOS', + 'EuroLinux', 'Kylin Linux Advanced Server'], + 'Debian': ['Debian', 'Ubuntu', 'Raspbian', 'Neon', 'KDE neon', + 'Linux Mint', 'SteamOS', 'Devuan', 'Kali', 'Cumulus Linux', + 'Pop!_OS', 'Parrot', 'Pardus GNU/Linux', 'Uos', 'Deepin', 'OSMC'], + 'Suse': ['SuSE', 'SLES', 'SLED', 'openSUSE', 'openSUSE Tumbleweed', + 'SLES_SAP', 'SUSE_LINUX', 'openSUSE Leap'], + 'Archlinux': ['Archlinux', 'Antergos', 'Manjaro'], + 'Mandrake': ['Mandrake', 'Mandriva'], + 'Solaris': ['Solaris', 'Nexenta', 'OmniOS', 'OpenIndiana', 'SmartOS'], + 'Slackware': ['Slackware'], + 'Altlinux': ['Altlinux'], + 'SGML': ['SGML'], + 'Gentoo': ['Gentoo', 'Funtoo'], + 'Alpine': ['Alpine'], + 'AIX': ['AIX'], + 'HP-UX': ['HPUX'], + 'Darwin': ['MacOSX'], + 'FreeBSD': ['FreeBSD', 'TrueOS'], + 'ClearLinux': ['Clear Linux OS', 'Clear Linux Mix'], + 'DragonFly': ['DragonflyBSD', 'DragonFlyBSD', 'Gentoo/DragonflyBSD', 'Gentoo/DragonFlyBSD'], + 'NetBSD': ['NetBSD'], } + + OS_FAMILY = {} + for family, names in OS_FAMILY_MAP.items(): + for name in names: + OS_FAMILY[name] = family + + def __init__(self, module): + self.module = module + + def get_distribution_facts(self): + distribution_facts = {} + + # The platform module provides information about the running + # system/distribution. Use this as a baseline and fix buggy systems + # afterwards + system = platform.system() + distribution_facts['distribution'] = system + distribution_facts['distribution_release'] = platform.release() + distribution_facts['distribution_version'] = platform.version() + + systems_implemented = ('AIX', 'HP-UX', 'Darwin', 'FreeBSD', 'OpenBSD', 'SunOS', 'DragonFly', 'NetBSD') + + if system in systems_implemented: + cleanedname = system.replace('-', '') + distfunc = getattr(self, 'get_distribution_' + cleanedname) + dist_func_facts = distfunc() + distribution_facts.update(dist_func_facts) + elif system == 'Linux': + + distribution_files = DistributionFiles(module=self.module) + + # linux_distribution_facts = LinuxDistribution(module).get_distribution_facts() + dist_file_facts = distribution_files.process_dist_files() + + distribution_facts.update(dist_file_facts) + + distro = distribution_facts['distribution'] + + # look for a os family alias for the 'distribution', if there isnt one, use 'distribution' + distribution_facts['os_family'] = self.OS_FAMILY.get(distro, None) or distro + + return distribution_facts + + def get_distribution_AIX(self): + aix_facts = {} + rc, out, err = self.module.run_command("/usr/bin/oslevel") + data = out.split('.') + aix_facts['distribution_major_version'] = data[0] + if len(data) > 1: + aix_facts['distribution_version'] = '%s.%s' % (data[0], data[1]) + aix_facts['distribution_release'] = data[1] + else: + aix_facts['distribution_version'] = data[0] + return aix_facts + + def get_distribution_HPUX(self): + hpux_facts = {} + rc, out, err = self.module.run_command(r"/usr/sbin/swlist |egrep 'HPUX.*OE.*[AB].[0-9]+\.[0-9]+'", use_unsafe_shell=True) + data = re.search(r'HPUX.*OE.*([AB].[0-9]+\.[0-9]+)\.([0-9]+).*', out) + if data: + hpux_facts['distribution_version'] = data.groups()[0] + hpux_facts['distribution_release'] = data.groups()[1] + return hpux_facts + + def get_distribution_Darwin(self): + darwin_facts = {} + darwin_facts['distribution'] = 'MacOSX' + rc, out, err = self.module.run_command("/usr/bin/sw_vers -productVersion") + data = out.split()[-1] + if data: + darwin_facts['distribution_major_version'] = data.split('.')[0] + darwin_facts['distribution_version'] = data + return darwin_facts + + def get_distribution_FreeBSD(self): + freebsd_facts = {} + freebsd_facts['distribution_release'] = platform.release() + data = re.search(r'(\d+)\.(\d+)-(RELEASE|STABLE|CURRENT|RC|PRERELEASE).*', freebsd_facts['distribution_release']) + if 'trueos' in platform.version(): + freebsd_facts['distribution'] = 'TrueOS' + if data: + freebsd_facts['distribution_major_version'] = data.group(1) + freebsd_facts['distribution_version'] = '%s.%s' % (data.group(1), data.group(2)) + return freebsd_facts + + def get_distribution_OpenBSD(self): + openbsd_facts = {} + openbsd_facts['distribution_version'] = platform.release() + rc, out, err = self.module.run_command("/sbin/sysctl -n kern.version") + match = re.match(r'OpenBSD\s[0-9]+.[0-9]+-(\S+)\s.*', out) + if match: + openbsd_facts['distribution_release'] = match.groups()[0] + else: + openbsd_facts['distribution_release'] = 'release' + return openbsd_facts + + def get_distribution_DragonFly(self): + dragonfly_facts = { + 'distribution_release': platform.release() + } + rc, out, dummy = self.module.run_command("/sbin/sysctl -n kern.version") + match = re.search(r'v(\d+)\.(\d+)\.(\d+)-(RELEASE|STABLE|CURRENT).*', out) + if match: + dragonfly_facts['distribution_major_version'] = match.group(1) + dragonfly_facts['distribution_version'] = '%s.%s.%s' % match.groups()[:3] + return dragonfly_facts + + def get_distribution_NetBSD(self): + netbsd_facts = {} + platform_release = platform.release() + netbsd_facts['distribution_release'] = platform_release + rc, out, dummy = self.module.run_command("/sbin/sysctl -n kern.version") + match = re.match(r'NetBSD\s(\d+)\.(\d+)\s\((GENERIC)\).*', out) + if match: + netbsd_facts['distribution_major_version'] = match.group(1) + netbsd_facts['distribution_version'] = '%s.%s' % match.groups()[:2] + else: + netbsd_facts['distribution_major_version'] = platform_release.split('.')[0] + netbsd_facts['distribution_version'] = platform_release + return netbsd_facts + + def get_distribution_SMGL(self): + smgl_facts = {} + smgl_facts['distribution'] = 'Source Mage GNU/Linux' + return smgl_facts + + def get_distribution_SunOS(self): + sunos_facts = {} + + data = get_file_content('/etc/release').splitlines()[0] + + if 'Solaris' in data: + # for solaris 10 uname_r will contain 5.10, for solaris 11 it will have 5.11 + uname_r = get_uname(self.module, flags=['-r']) + ora_prefix = '' + if 'Oracle Solaris' in data: + data = data.replace('Oracle ', '') + ora_prefix = 'Oracle ' + sunos_facts['distribution'] = data.split()[0] + sunos_facts['distribution_version'] = data.split()[1] + sunos_facts['distribution_release'] = ora_prefix + data + sunos_facts['distribution_major_version'] = uname_r.split('.')[1].rstrip() + return sunos_facts + + uname_v = get_uname(self.module, flags=['-v']) + distribution_version = None + + if 'SmartOS' in data: + sunos_facts['distribution'] = 'SmartOS' + if _file_exists('/etc/product'): + product_data = dict([l.split(': ', 1) for l in get_file_content('/etc/product').splitlines() if ': ' in l]) + if 'Image' in product_data: + distribution_version = product_data.get('Image').split()[-1] + elif 'OpenIndiana' in data: + sunos_facts['distribution'] = 'OpenIndiana' + elif 'OmniOS' in data: + sunos_facts['distribution'] = 'OmniOS' + distribution_version = data.split()[-1] + elif uname_v is not None and 'NexentaOS_' in uname_v: + sunos_facts['distribution'] = 'Nexenta' + distribution_version = data.split()[-1].lstrip('v') + + if sunos_facts.get('distribution', '') in ('SmartOS', 'OpenIndiana', 'OmniOS', 'Nexenta'): + sunos_facts['distribution_release'] = data.strip() + if distribution_version is not None: + sunos_facts['distribution_version'] = distribution_version + elif uname_v is not None: + sunos_facts['distribution_version'] = uname_v.splitlines()[0].strip() + return sunos_facts + + return sunos_facts + + +class DistributionFactCollector(BaseFactCollector): + name = 'distribution' + _fact_ids = set(['distribution_version', + 'distribution_release', + 'distribution_major_version', + 'os_family']) # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + collected_facts = collected_facts or {} + facts_dict = {} + if not module: + return facts_dict + + distribution = Distribution(module=module) + distro_facts = distribution.get_distribution_facts() + + return distro_facts diff --git a/lib/ansible/module_utils/facts/system/dns.py b/lib/ansible/module_utils/facts/system/dns.py new file mode 100644 index 0000000..d913f4a --- /dev/null +++ b/lib/ansible/module_utils/facts/system/dns.py @@ -0,0 +1,68 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_content + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class DnsFactCollector(BaseFactCollector): + name = 'dns' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + dns_facts = {} + + # TODO: flatten + dns_facts['dns'] = {} + + for line in get_file_content('/etc/resolv.conf', '').splitlines(): + if line.startswith('#') or line.startswith(';') or line.strip() == '': + continue + tokens = line.split() + if len(tokens) == 0: + continue + if tokens[0] == 'nameserver': + if 'nameservers' not in dns_facts['dns']: + dns_facts['dns']['nameservers'] = [] + for nameserver in tokens[1:]: + dns_facts['dns']['nameservers'].append(nameserver) + elif tokens[0] == 'domain': + if len(tokens) > 1: + dns_facts['dns']['domain'] = tokens[1] + elif tokens[0] == 'search': + dns_facts['dns']['search'] = [] + for suffix in tokens[1:]: + dns_facts['dns']['search'].append(suffix) + elif tokens[0] == 'sortlist': + dns_facts['dns']['sortlist'] = [] + for address in tokens[1:]: + dns_facts['dns']['sortlist'].append(address) + elif tokens[0] == 'options': + dns_facts['dns']['options'] = {} + if len(tokens) > 1: + for option in tokens[1:]: + option_tokens = option.split(':', 1) + if len(option_tokens) == 0: + continue + val = len(option_tokens) == 2 and option_tokens[1] or True + dns_facts['dns']['options'][option_tokens[0]] = val + + return dns_facts diff --git a/lib/ansible/module_utils/facts/system/env.py b/lib/ansible/module_utils/facts/system/env.py new file mode 100644 index 0000000..605443f --- /dev/null +++ b/lib/ansible/module_utils/facts/system/env.py @@ -0,0 +1,39 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.six import iteritems + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class EnvFactCollector(BaseFactCollector): + name = 'env' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + env_facts = {} + env_facts['env'] = {} + + for k, v in iteritems(os.environ): + env_facts['env'][k] = v + + return env_facts diff --git a/lib/ansible/module_utils/facts/system/fips.py b/lib/ansible/module_utils/facts/system/fips.py new file mode 100644 index 0000000..7e56610 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/fips.py @@ -0,0 +1,39 @@ +# Determine if a system is in 'fips' mode +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_content + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class FipsFactCollector(BaseFactCollector): + name = 'fips' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + # NOTE: this is populated even if it is not set + fips_facts = {} + fips_facts['fips'] = False + data = get_file_content('/proc/sys/crypto/fips_enabled') + if data and data == '1': + fips_facts['fips'] = True + return fips_facts diff --git a/lib/ansible/module_utils/facts/system/loadavg.py b/lib/ansible/module_utils/facts/system/loadavg.py new file mode 100644 index 0000000..8475f2a --- /dev/null +++ b/lib/ansible/module_utils/facts/system/loadavg.py @@ -0,0 +1,31 @@ +# (c) 2021 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class LoadAvgFactCollector(BaseFactCollector): + name = 'loadavg' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + facts = {} + try: + # (0.58, 0.82, 0.98) + loadavg = os.getloadavg() + facts['loadavg'] = { + '1m': loadavg[0], + '5m': loadavg[1], + '15m': loadavg[2] + } + except OSError: + pass + + return facts diff --git a/lib/ansible/module_utils/facts/system/local.py b/lib/ansible/module_utils/facts/system/local.py new file mode 100644 index 0000000..bacdbe0 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/local.py @@ -0,0 +1,113 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import glob +import json +import os +import stat + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils._text import to_text +from ansible.module_utils.facts.utils import get_file_content +from ansible.module_utils.facts.collector import BaseFactCollector +from ansible.module_utils.six.moves import configparser, StringIO + + +class LocalFactCollector(BaseFactCollector): + name = 'local' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + local_facts = {} + local_facts['local'] = {} + + if not module: + return local_facts + + fact_path = module.params.get('fact_path', None) + + if not fact_path or not os.path.exists(fact_path): + return local_facts + + local = {} + # go over .fact files, run executables, read rest, skip bad with warning and note + for fn in sorted(glob.glob(fact_path + '/*.fact')): + # use filename for key where it will sit under local facts + fact_base = os.path.basename(fn).replace('.fact', '') + failed = None + try: + executable_fact = stat.S_IXUSR & os.stat(fn)[stat.ST_MODE] + except OSError as e: + failed = 'Could not stat fact (%s): %s' % (fn, to_text(e)) + local[fact_base] = failed + module.warn(failed) + continue + if executable_fact: + try: + # run it + rc, out, err = module.run_command(fn) + if rc != 0: + failed = 'Failure executing fact script (%s), rc: %s, err: %s' % (fn, rc, err) + except (IOError, OSError) as e: + failed = 'Could not execute fact script (%s): %s' % (fn, to_text(e)) + + if failed is not None: + local[fact_base] = failed + module.warn(failed) + continue + else: + # ignores exceptions and returns empty + out = get_file_content(fn, default='') + + try: + # ensure we have unicode + out = to_text(out, errors='surrogate_or_strict') + except UnicodeError: + fact = 'error loading fact - output of running "%s" was not utf-8' % fn + local[fact_base] = fact + module.warn(fact) + continue + + # try to read it as json first + try: + fact = json.loads(out) + except ValueError: + # if that fails read it with ConfigParser + cp = configparser.ConfigParser() + try: + cp.readfp(StringIO(out)) + except configparser.Error: + fact = "error loading facts as JSON or ini - please check content: %s" % fn + module.warn(fact) + else: + fact = {} + for sect in cp.sections(): + if sect not in fact: + fact[sect] = {} + for opt in cp.options(sect): + val = cp.get(sect, opt) + fact[sect][opt] = val + except Exception as e: + fact = "Failed to convert (%s) to JSON: %s" % (fn, to_text(e)) + module.warn(fact) + + local[fact_base] = fact + + local_facts['local'] = local + return local_facts diff --git a/lib/ansible/module_utils/facts/system/lsb.py b/lib/ansible/module_utils/facts/system/lsb.py new file mode 100644 index 0000000..2dc1433 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/lsb.py @@ -0,0 +1,108 @@ +# Collect facts related to LSB (Linux Standard Base) +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_lines +from ansible.module_utils.facts.collector import BaseFactCollector + + +class LSBFactCollector(BaseFactCollector): + name = 'lsb' + _fact_ids = set() # type: t.Set[str] + STRIP_QUOTES = r'\'\"\\' + + def _lsb_release_bin(self, lsb_path, module): + lsb_facts = {} + + if not lsb_path: + return lsb_facts + + rc, out, err = module.run_command([lsb_path, "-a"], errors='surrogate_then_replace') + if rc != 0: + return lsb_facts + + for line in out.splitlines(): + if len(line) < 1 or ':' not in line: + continue + value = line.split(':', 1)[1].strip() + + if 'LSB Version:' in line: + lsb_facts['release'] = value + elif 'Distributor ID:' in line: + lsb_facts['id'] = value + elif 'Description:' in line: + lsb_facts['description'] = value + elif 'Release:' in line: + lsb_facts['release'] = value + elif 'Codename:' in line: + lsb_facts['codename'] = value + + return lsb_facts + + def _lsb_release_file(self, etc_lsb_release_location): + lsb_facts = {} + + if not os.path.exists(etc_lsb_release_location): + return lsb_facts + + for line in get_file_lines(etc_lsb_release_location): + value = line.split('=', 1)[1].strip() + + if 'DISTRIB_ID' in line: + lsb_facts['id'] = value + elif 'DISTRIB_RELEASE' in line: + lsb_facts['release'] = value + elif 'DISTRIB_DESCRIPTION' in line: + lsb_facts['description'] = value + elif 'DISTRIB_CODENAME' in line: + lsb_facts['codename'] = value + + return lsb_facts + + def collect(self, module=None, collected_facts=None): + facts_dict = {} + lsb_facts = {} + + if not module: + return facts_dict + + lsb_path = module.get_bin_path('lsb_release') + + # try the 'lsb_release' script first + if lsb_path: + lsb_facts = self._lsb_release_bin(lsb_path, + module=module) + + # no lsb_release, try looking in /etc/lsb-release + if not lsb_facts: + lsb_facts = self._lsb_release_file('/etc/lsb-release') + + if lsb_facts and 'release' in lsb_facts: + lsb_facts['major_release'] = lsb_facts['release'].split('.')[0] + + for k, v in lsb_facts.items(): + if v: + lsb_facts[k] = v.strip(LSBFactCollector.STRIP_QUOTES) + + facts_dict['lsb'] = lsb_facts + return facts_dict diff --git a/lib/ansible/module_utils/facts/system/pkg_mgr.py b/lib/ansible/module_utils/facts/system/pkg_mgr.py new file mode 100644 index 0000000..704ea20 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/pkg_mgr.py @@ -0,0 +1,165 @@ +# Collect facts related to the system package manager + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import subprocess + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + +# A list of dicts. If there is a platform with more than one +# package manager, put the preferred one last. If there is an +# ansible module, use that as the value for the 'name' key. +PKG_MGRS = [{'path': '/usr/bin/rpm-ostree', 'name': 'atomic_container'}, + {'path': '/usr/bin/yum', 'name': 'yum'}, + {'path': '/usr/bin/dnf', 'name': 'dnf'}, + {'path': '/usr/bin/apt-get', 'name': 'apt'}, + {'path': '/usr/bin/zypper', 'name': 'zypper'}, + {'path': '/usr/sbin/urpmi', 'name': 'urpmi'}, + {'path': '/usr/bin/pacman', 'name': 'pacman'}, + {'path': '/bin/opkg', 'name': 'opkg'}, + {'path': '/usr/pkg/bin/pkgin', 'name': 'pkgin'}, + {'path': '/opt/local/bin/pkgin', 'name': 'pkgin'}, + {'path': '/opt/tools/bin/pkgin', 'name': 'pkgin'}, + {'path': '/opt/local/bin/port', 'name': 'macports'}, + {'path': '/usr/local/bin/brew', 'name': 'homebrew'}, + {'path': '/opt/homebrew/bin/brew', 'name': 'homebrew'}, + {'path': '/sbin/apk', 'name': 'apk'}, + {'path': '/usr/sbin/pkg', 'name': 'pkgng'}, + {'path': '/usr/sbin/swlist', 'name': 'swdepot'}, + {'path': '/usr/bin/emerge', 'name': 'portage'}, + {'path': '/usr/sbin/pkgadd', 'name': 'svr4pkg'}, + {'path': '/usr/bin/pkg', 'name': 'pkg5'}, + {'path': '/usr/bin/xbps-install', 'name': 'xbps'}, + {'path': '/usr/local/sbin/pkg', 'name': 'pkgng'}, + {'path': '/usr/bin/swupd', 'name': 'swupd'}, + {'path': '/usr/sbin/sorcery', 'name': 'sorcery'}, + {'path': '/usr/bin/installp', 'name': 'installp'}, + {'path': '/QOpenSys/pkgs/bin/yum', 'name': 'yum'}, + ] + + +class OpenBSDPkgMgrFactCollector(BaseFactCollector): + name = 'pkg_mgr' + _fact_ids = set() # type: t.Set[str] + _platform = 'OpenBSD' + + def collect(self, module=None, collected_facts=None): + facts_dict = {} + + facts_dict['pkg_mgr'] = 'openbsd_pkg' + return facts_dict + + +# the fact ends up being 'pkg_mgr' so stick with that naming/spelling +class PkgMgrFactCollector(BaseFactCollector): + name = 'pkg_mgr' + _fact_ids = set() # type: t.Set[str] + _platform = 'Generic' + required_facts = set(['distribution']) + + def _pkg_mgr_exists(self, pkg_mgr_name): + for cur_pkg_mgr in [pkg_mgr for pkg_mgr in PKG_MGRS if pkg_mgr['name'] == pkg_mgr_name]: + if os.path.exists(cur_pkg_mgr['path']): + return pkg_mgr_name + + def _check_rh_versions(self, pkg_mgr_name, collected_facts): + if os.path.exists('/run/ostree-booted'): + return "atomic_container" + + if collected_facts['ansible_distribution'] == 'Fedora': + try: + if int(collected_facts['ansible_distribution_major_version']) < 23: + if self._pkg_mgr_exists('yum'): + pkg_mgr_name = 'yum' + + else: + if self._pkg_mgr_exists('dnf'): + pkg_mgr_name = 'dnf' + except ValueError: + # If there's some new magical Fedora version in the future, + # just default to dnf + pkg_mgr_name = 'dnf' + elif collected_facts['ansible_distribution'] == 'Amazon': + try: + if int(collected_facts['ansible_distribution_major_version']) < 2022: + if self._pkg_mgr_exists('yum'): + pkg_mgr_name = 'yum' + else: + if self._pkg_mgr_exists('dnf'): + pkg_mgr_name = 'dnf' + except ValueError: + pkg_mgr_name = 'dnf' + else: + # If it's not one of the above and it's Red Hat family of distros, assume + # RHEL or a clone. For versions of RHEL < 8 that Ansible supports, the + # vendor supported official package manager is 'yum' and in RHEL 8+ + # (as far as we know at the time of this writing) it is 'dnf'. + # If anyone wants to force a non-official package manager then they + # can define a provider to either the package or yum action plugins. + if int(collected_facts['ansible_distribution_major_version']) < 8: + pkg_mgr_name = 'yum' + else: + pkg_mgr_name = 'dnf' + return pkg_mgr_name + + def _check_apt_flavor(self, pkg_mgr_name): + # Check if '/usr/bin/apt' is APT-RPM or an ordinary (dpkg-based) APT. + # There's rpm package on Debian, so checking if /usr/bin/rpm exists + # is not enough. Instead ask RPM if /usr/bin/apt-get belongs to some + # RPM package. + rpm_query = '/usr/bin/rpm -q --whatprovides /usr/bin/apt-get'.split() + if os.path.exists('/usr/bin/rpm'): + with open(os.devnull, 'w') as null: + try: + subprocess.check_call(rpm_query, stdout=null, stderr=null) + pkg_mgr_name = 'apt_rpm' + except subprocess.CalledProcessError: + # No apt-get in RPM database. Looks like Debian/Ubuntu + # with rpm package installed + pkg_mgr_name = 'apt' + return pkg_mgr_name + + def pkg_mgrs(self, collected_facts): + # Filter out the /usr/bin/pkg because on Altlinux it is actually the + # perl-Package (not Solaris package manager). + # Since the pkg5 takes precedence over apt, this workaround + # is required to select the suitable package manager on Altlinux. + if collected_facts['ansible_os_family'] == 'Altlinux': + return filter(lambda pkg: pkg['path'] != '/usr/bin/pkg', PKG_MGRS) + else: + return PKG_MGRS + + def collect(self, module=None, collected_facts=None): + facts_dict = {} + collected_facts = collected_facts or {} + + pkg_mgr_name = 'unknown' + for pkg in self.pkg_mgrs(collected_facts): + if os.path.exists(pkg['path']): + pkg_mgr_name = pkg['name'] + + # Handle distro family defaults when more than one package manager is + # installed or available to the distro, the ansible_fact entry should be + # the default package manager officially supported by the distro. + if collected_facts['ansible_os_family'] == "RedHat": + pkg_mgr_name = self._check_rh_versions(pkg_mgr_name, collected_facts) + elif collected_facts['ansible_os_family'] == 'Debian' and pkg_mgr_name != 'apt': + # It's possible to install yum, dnf, zypper, rpm, etc inside of + # Debian. Doing so does not mean the system wants to use them. + pkg_mgr_name = 'apt' + elif collected_facts['ansible_os_family'] == 'Altlinux': + if pkg_mgr_name == 'apt': + pkg_mgr_name = 'apt_rpm' + + # Check if /usr/bin/apt-get is ordinary (dpkg-based) APT or APT-RPM + if pkg_mgr_name == 'apt': + pkg_mgr_name = self._check_apt_flavor(pkg_mgr_name) + + facts_dict['pkg_mgr'] = pkg_mgr_name + return facts_dict diff --git a/lib/ansible/module_utils/facts/system/platform.py b/lib/ansible/module_utils/facts/system/platform.py new file mode 100644 index 0000000..b947801 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/platform.py @@ -0,0 +1,99 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import socket +import platform + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_content + +from ansible.module_utils.facts.collector import BaseFactCollector + +# i86pc is a Solaris and derivatives-ism +SOLARIS_I86_RE_PATTERN = r'i([3456]86|86pc)' +solaris_i86_re = re.compile(SOLARIS_I86_RE_PATTERN) + + +class PlatformFactCollector(BaseFactCollector): + name = 'platform' + _fact_ids = set(['system', + 'kernel', + 'kernel_version', + 'machine', + 'python_version', + 'architecture', + 'machine_id']) # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + platform_facts = {} + # platform.system() can be Linux, Darwin, Java, or Windows + platform_facts['system'] = platform.system() + platform_facts['kernel'] = platform.release() + platform_facts['kernel_version'] = platform.version() + platform_facts['machine'] = platform.machine() + + platform_facts['python_version'] = platform.python_version() + + platform_facts['fqdn'] = socket.getfqdn() + platform_facts['hostname'] = platform.node().split('.')[0] + platform_facts['nodename'] = platform.node() + + platform_facts['domain'] = '.'.join(platform_facts['fqdn'].split('.')[1:]) + + arch_bits = platform.architecture()[0] + + platform_facts['userspace_bits'] = arch_bits.replace('bit', '') + if platform_facts['machine'] == 'x86_64': + platform_facts['architecture'] = platform_facts['machine'] + if platform_facts['userspace_bits'] == '64': + platform_facts['userspace_architecture'] = 'x86_64' + elif platform_facts['userspace_bits'] == '32': + platform_facts['userspace_architecture'] = 'i386' + elif solaris_i86_re.search(platform_facts['machine']): + platform_facts['architecture'] = 'i386' + if platform_facts['userspace_bits'] == '64': + platform_facts['userspace_architecture'] = 'x86_64' + elif platform_facts['userspace_bits'] == '32': + platform_facts['userspace_architecture'] = 'i386' + else: + platform_facts['architecture'] = platform_facts['machine'] + + if platform_facts['system'] == 'AIX': + # Attempt to use getconf to figure out architecture + # fall back to bootinfo if needed + getconf_bin = module.get_bin_path('getconf') + if getconf_bin: + rc, out, err = module.run_command([getconf_bin, 'MACHINE_ARCHITECTURE']) + data = out.splitlines() + platform_facts['architecture'] = data[0] + else: + bootinfo_bin = module.get_bin_path('bootinfo') + rc, out, err = module.run_command([bootinfo_bin, '-p']) + data = out.splitlines() + platform_facts['architecture'] = data[0] + elif platform_facts['system'] == 'OpenBSD': + platform_facts['architecture'] = platform.uname()[5] + + machine_id = get_file_content("/var/lib/dbus/machine-id") or get_file_content("/etc/machine-id") + if machine_id: + machine_id = machine_id.splitlines()[0] + platform_facts["machine_id"] = machine_id + + return platform_facts diff --git a/lib/ansible/module_utils/facts/system/python.py b/lib/ansible/module_utils/facts/system/python.py new file mode 100644 index 0000000..50b66dd --- /dev/null +++ b/lib/ansible/module_utils/facts/system/python.py @@ -0,0 +1,62 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import sys + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + +try: + # Check if we have SSLContext support + from ssl import create_default_context, SSLContext + del create_default_context + del SSLContext + HAS_SSLCONTEXT = True +except ImportError: + HAS_SSLCONTEXT = False + + +class PythonFactCollector(BaseFactCollector): + name = 'python' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + python_facts = {} + python_facts['python'] = { + 'version': { + 'major': sys.version_info[0], + 'minor': sys.version_info[1], + 'micro': sys.version_info[2], + 'releaselevel': sys.version_info[3], + 'serial': sys.version_info[4] + }, + 'version_info': list(sys.version_info), + 'executable': sys.executable, + 'has_sslcontext': HAS_SSLCONTEXT + } + + try: + python_facts['python']['type'] = sys.subversion[0] + except AttributeError: + try: + python_facts['python']['type'] = sys.implementation.name + except AttributeError: + python_facts['python']['type'] = None + + return python_facts diff --git a/lib/ansible/module_utils/facts/system/selinux.py b/lib/ansible/module_utils/facts/system/selinux.py new file mode 100644 index 0000000..5c6b012 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/selinux.py @@ -0,0 +1,93 @@ +# Collect facts related to selinux +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + +try: + from ansible.module_utils.compat import selinux + HAVE_SELINUX = True +except ImportError: + HAVE_SELINUX = False + +SELINUX_MODE_DICT = { + 1: 'enforcing', + 0: 'permissive', + -1: 'disabled' +} + + +class SelinuxFactCollector(BaseFactCollector): + name = 'selinux' + _fact_ids = set() # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + facts_dict = {} + selinux_facts = {} + + # If selinux library is missing, only set the status and selinux_python_present since + # there is no way to tell if SELinux is enabled or disabled on the system + # without the library. + if not HAVE_SELINUX: + selinux_facts['status'] = 'Missing selinux Python library' + facts_dict['selinux'] = selinux_facts + facts_dict['selinux_python_present'] = False + return facts_dict + + # Set a boolean for testing whether the Python library is present + facts_dict['selinux_python_present'] = True + + if not selinux.is_selinux_enabled(): + selinux_facts['status'] = 'disabled' + else: + selinux_facts['status'] = 'enabled' + + try: + selinux_facts['policyvers'] = selinux.security_policyvers() + except (AttributeError, OSError): + selinux_facts['policyvers'] = 'unknown' + + try: + (rc, configmode) = selinux.selinux_getenforcemode() + if rc == 0: + selinux_facts['config_mode'] = SELINUX_MODE_DICT.get(configmode, 'unknown') + else: + selinux_facts['config_mode'] = 'unknown' + except (AttributeError, OSError): + selinux_facts['config_mode'] = 'unknown' + + try: + mode = selinux.security_getenforce() + selinux_facts['mode'] = SELINUX_MODE_DICT.get(mode, 'unknown') + except (AttributeError, OSError): + selinux_facts['mode'] = 'unknown' + + try: + (rc, policytype) = selinux.selinux_getpolicytype() + if rc == 0: + selinux_facts['type'] = policytype + else: + selinux_facts['type'] = 'unknown' + except (AttributeError, OSError): + selinux_facts['type'] = 'unknown' + + facts_dict['selinux'] = selinux_facts + return facts_dict diff --git a/lib/ansible/module_utils/facts/system/service_mgr.py b/lib/ansible/module_utils/facts/system/service_mgr.py new file mode 100644 index 0000000..d862ac9 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/service_mgr.py @@ -0,0 +1,152 @@ +# Collect facts related to system service manager and init. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import platform +import re + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils._text import to_native + +from ansible.module_utils.facts.utils import get_file_content +from ansible.module_utils.facts.collector import BaseFactCollector + +# The distutils module is not shipped with SUNWPython on Solaris. +# It's in the SUNWPython-devel package which also contains development files +# that don't belong on production boxes. Since our Solaris code doesn't +# depend on LooseVersion, do not import it on Solaris. +if platform.system() != 'SunOS': + from ansible.module_utils.compat.version import LooseVersion + + +class ServiceMgrFactCollector(BaseFactCollector): + name = 'service_mgr' + _fact_ids = set() # type: t.Set[str] + required_facts = set(['platform', 'distribution']) + + @staticmethod + def is_systemd_managed(module): + # tools must be installed + if module.get_bin_path('systemctl'): + + # this should show if systemd is the boot init system, if checking init faild to mark as systemd + # these mirror systemd's own sd_boot test http://www.freedesktop.org/software/systemd/man/sd_booted.html + for canary in ["/run/systemd/system/", "/dev/.run/systemd/", "/dev/.systemd/"]: + if os.path.exists(canary): + return True + return False + + @staticmethod + def is_systemd_managed_offline(module): + # tools must be installed + if module.get_bin_path('systemctl'): + # check if /sbin/init is a symlink to systemd + # on SUSE, /sbin/init may be missing if systemd-sysvinit package is not installed. + if os.path.islink('/sbin/init') and os.path.basename(os.readlink('/sbin/init')) == 'systemd': + return True + return False + + def collect(self, module=None, collected_facts=None): + facts_dict = {} + + if not module: + return facts_dict + + collected_facts = collected_facts or {} + service_mgr_name = None + + # TODO: detect more custom init setups like bootscripts, dmd, s6, Epoch, etc + # also other OSs other than linux might need to check across several possible candidates + + # Mapping of proc_1 values to more useful names + proc_1_map = { + 'procd': 'openwrt_init', + 'runit-init': 'runit', + 'svscan': 'svc', + 'openrc-init': 'openrc', + } + + # try various forms of querying pid 1 + proc_1 = get_file_content('/proc/1/comm') + if proc_1 is None: + rc, proc_1, err = module.run_command("ps -p 1 -o comm|tail -n 1", use_unsafe_shell=True) + + # if command fails, or stdout is empty string or the output of the command starts with what looks like a PID, + # then the 'ps' command probably didn't work the way we wanted, probably because it's busybox + if rc != 0 or not proc_1.strip() or re.match(r' *[0-9]+ ', proc_1): + proc_1 = None + + # The ps command above may return "COMMAND" if the user cannot read /proc, e.g. with grsecurity + if proc_1 == "COMMAND\n": + proc_1 = None + + if proc_1 is None and os.path.islink('/sbin/init'): + proc_1 = os.readlink('/sbin/init') + + if proc_1 is not None: + proc_1 = os.path.basename(proc_1) + proc_1 = to_native(proc_1) + proc_1 = proc_1.strip() + + if proc_1 is not None and (proc_1 == 'init' or proc_1.endswith('sh')): + # many systems return init, so this cannot be trusted, if it ends in 'sh' it probalby is a shell in a container + proc_1 = None + + # if not init/None it should be an identifiable or custom init, so we are done! + if proc_1 is not None: + # Lookup proc_1 value in map and use proc_1 value itself if no match + service_mgr_name = proc_1_map.get(proc_1, proc_1) + + # start with the easy ones + elif collected_facts.get('ansible_distribution', None) == 'MacOSX': + # FIXME: find way to query executable, version matching is not ideal + if LooseVersion(platform.mac_ver()[0]) >= LooseVersion('10.4'): + service_mgr_name = 'launchd' + else: + service_mgr_name = 'systemstarter' + elif 'BSD' in collected_facts.get('ansible_system', '') or collected_facts.get('ansible_system') in ['Bitrig', 'DragonFly']: + # FIXME: we might want to break out to individual BSDs or 'rc' + service_mgr_name = 'bsdinit' + elif collected_facts.get('ansible_system') == 'AIX': + service_mgr_name = 'src' + elif collected_facts.get('ansible_system') == 'SunOS': + service_mgr_name = 'smf' + elif collected_facts.get('ansible_distribution') == 'OpenWrt': + service_mgr_name = 'openwrt_init' + elif collected_facts.get('ansible_system') == 'Linux': + # FIXME: mv is_systemd_managed + if self.is_systemd_managed(module=module): + service_mgr_name = 'systemd' + elif module.get_bin_path('initctl') and os.path.exists("/etc/init/"): + service_mgr_name = 'upstart' + elif os.path.exists('/sbin/openrc'): + service_mgr_name = 'openrc' + elif self.is_systemd_managed_offline(module=module): + service_mgr_name = 'systemd' + elif os.path.exists('/etc/init.d/'): + service_mgr_name = 'sysvinit' + + if not service_mgr_name: + # if we cannot detect, fallback to generic 'service' + service_mgr_name = 'service' + + facts_dict['service_mgr'] = service_mgr_name + return facts_dict diff --git a/lib/ansible/module_utils/facts/system/ssh_pub_keys.py b/lib/ansible/module_utils/facts/system/ssh_pub_keys.py new file mode 100644 index 0000000..85691c7 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/ssh_pub_keys.py @@ -0,0 +1,56 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.utils import get_file_content + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class SshPubKeyFactCollector(BaseFactCollector): + name = 'ssh_pub_keys' + _fact_ids = set(['ssh_host_pub_keys', + 'ssh_host_key_dsa_public', + 'ssh_host_key_rsa_public', + 'ssh_host_key_ecdsa_public', + 'ssh_host_key_ed25519_public']) # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + ssh_pub_key_facts = {} + algos = ('dsa', 'rsa', 'ecdsa', 'ed25519') + + # list of directories to check for ssh keys + # used in the order listed here, the first one with keys is used + keydirs = ['/etc/ssh', '/etc/openssh', '/etc'] + + for keydir in keydirs: + for algo in algos: + factname = 'ssh_host_key_%s_public' % algo + if factname in ssh_pub_key_facts: + # a previous keydir was already successful, stop looking + # for keys + return ssh_pub_key_facts + key_filename = '%s/ssh_host_%s_key.pub' % (keydir, algo) + keydata = get_file_content(key_filename) + if keydata is not None: + (keytype, key) = keydata.split()[0:2] + ssh_pub_key_facts[factname] = key + ssh_pub_key_facts[factname + '_keytype'] = keytype + + return ssh_pub_key_facts diff --git a/lib/ansible/module_utils/facts/system/user.py b/lib/ansible/module_utils/facts/system/user.py new file mode 100644 index 0000000..2efa993 --- /dev/null +++ b/lib/ansible/module_utils/facts/system/user.py @@ -0,0 +1,55 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import getpass +import os +import pwd + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class UserFactCollector(BaseFactCollector): + name = 'user' + _fact_ids = set(['user_id', 'user_uid', 'user_gid', + 'user_gecos', 'user_dir', 'user_shell', + 'real_user_id', 'effective_user_id', + 'effective_group_ids']) # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + user_facts = {} + + user_facts['user_id'] = getpass.getuser() + + try: + pwent = pwd.getpwnam(getpass.getuser()) + except KeyError: + pwent = pwd.getpwuid(os.getuid()) + + user_facts['user_uid'] = pwent.pw_uid + user_facts['user_gid'] = pwent.pw_gid + user_facts['user_gecos'] = pwent.pw_gecos + user_facts['user_dir'] = pwent.pw_dir + user_facts['user_shell'] = pwent.pw_shell + user_facts['real_user_id'] = os.getuid() + user_facts['effective_user_id'] = os.geteuid() + user_facts['real_group_id'] = os.getgid() + user_facts['effective_group_id'] = os.getgid() + + return user_facts diff --git a/lib/ansible/module_utils/facts/timeout.py b/lib/ansible/module_utils/facts/timeout.py new file mode 100644 index 0000000..ebb71cc --- /dev/null +++ b/lib/ansible/module_utils/facts/timeout.py @@ -0,0 +1,70 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import multiprocessing +import multiprocessing.pool as mp + +# timeout function to make sure some fact gathering +# steps do not exceed a time limit + +GATHER_TIMEOUT = None +DEFAULT_GATHER_TIMEOUT = 10 + + +class TimeoutError(Exception): + pass + + +def timeout(seconds=None, error_message="Timer expired"): + """ + Timeout decorator to expire after a set number of seconds. This raises an + ansible.module_utils.facts.TimeoutError if the timeout is hit before the + function completes. + """ + def decorator(func): + def wrapper(*args, **kwargs): + timeout_value = seconds + if timeout_value is None: + timeout_value = globals().get('GATHER_TIMEOUT') or DEFAULT_GATHER_TIMEOUT + + pool = mp.ThreadPool(processes=1) + res = pool.apply_async(func, args, kwargs) + pool.close() + try: + return res.get(timeout_value) + except multiprocessing.TimeoutError: + # This is an ansible.module_utils.common.facts.timeout.TimeoutError + raise TimeoutError('Timer expired after %s seconds' % timeout_value) + finally: + pool.terminate() + + return wrapper + + # If we were called as @timeout, then the first parameter will be the + # function we are to wrap instead of the number of seconds. Detect this + # and correct it by setting seconds to our default value and return the + # inner decorator function manually wrapped around the function + if callable(seconds): + func = seconds + seconds = None + return decorator(func) + + # If we were called as @timeout([...]) then python itself will take + # care of wrapping the inner decorator around the function + + return decorator diff --git a/lib/ansible/module_utils/facts/utils.py b/lib/ansible/module_utils/facts/utils.py new file mode 100644 index 0000000..a6027ab --- /dev/null +++ b/lib/ansible/module_utils/facts/utils.py @@ -0,0 +1,102 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import fcntl +import os + + +def get_file_content(path, default=None, strip=True): + ''' + Return the contents of a given file path + + :args path: path to file to return contents from + :args default: value to return if we could not read file + :args strip: controls if we strip whitespace from the result or not + + :returns: String with file contents (optionally stripped) or 'default' value + ''' + data = default + if os.path.exists(path) and os.access(path, os.R_OK): + datafile = None + try: + datafile = open(path) + try: + # try to not enter kernel 'block' mode, which prevents timeouts + fd = datafile.fileno() + flag = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flag | os.O_NONBLOCK) + except Exception: + pass # not required to operate, but would have been nice! + + # actually read the data + data = datafile.read() + + if strip: + data = data.strip() + + if len(data) == 0: + data = default + + except Exception: + # ignore errors as some jails/containers might have readable permissions but not allow reads + pass + finally: + if datafile is not None: + datafile.close() + + return data + + +def get_file_lines(path, strip=True, line_sep=None): + '''get list of lines from file''' + data = get_file_content(path, strip=strip) + if data: + if line_sep is None: + ret = data.splitlines() + else: + if len(line_sep) == 1: + ret = data.rstrip(line_sep).split(line_sep) + else: + ret = data.split(line_sep) + else: + ret = [] + return ret + + +def get_mount_size(mountpoint): + mount_size = {} + + try: + statvfs_result = os.statvfs(mountpoint) + mount_size['size_total'] = statvfs_result.f_frsize * statvfs_result.f_blocks + mount_size['size_available'] = statvfs_result.f_frsize * (statvfs_result.f_bavail) + + # Block total/available/used + mount_size['block_size'] = statvfs_result.f_bsize + mount_size['block_total'] = statvfs_result.f_blocks + mount_size['block_available'] = statvfs_result.f_bavail + mount_size['block_used'] = mount_size['block_total'] - mount_size['block_available'] + + # Inode total/available/used + mount_size['inode_total'] = statvfs_result.f_files + mount_size['inode_available'] = statvfs_result.f_favail + mount_size['inode_used'] = mount_size['inode_total'] - mount_size['inode_available'] + except OSError: + pass + + return mount_size diff --git a/lib/ansible/module_utils/facts/virtual/__init__.py b/lib/ansible/module_utils/facts/virtual/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/__init__.py diff --git a/lib/ansible/module_utils/facts/virtual/base.py b/lib/ansible/module_utils/facts/virtual/base.py new file mode 100644 index 0000000..67b59a5 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/base.py @@ -0,0 +1,80 @@ +# base classes for virtualization facts +# -*- coding: utf-8 -*- +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import ansible.module_utils.compat.typing as t + +from ansible.module_utils.facts.collector import BaseFactCollector + + +class Virtual: + """ + This is a generic Virtual subclass of Facts. This should be further + subclassed to implement per platform. If you subclass this, + you should define: + - virtualization_type + - virtualization_role + - container (e.g. solaris zones, freebsd jails, linux containers) + + All subclasses MUST define platform. + """ + platform = 'Generic' + + # FIXME: remove load_on_init if we can + def __init__(self, module, load_on_init=False): + self.module = module + + # FIXME: just here for existing tests cases till they are updated + def populate(self, collected_facts=None): + virtual_facts = self.get_virtual_facts() + + return virtual_facts + + def get_virtual_facts(self): + virtual_facts = { + 'virtualization_type': '', + 'virtualization_role': '', + 'virtualization_tech_guest': set(), + 'virtualization_tech_host': set(), + } + return virtual_facts + + +class VirtualCollector(BaseFactCollector): + name = 'virtual' + _fact_class = Virtual + _fact_ids = set([ + 'virtualization_type', + 'virtualization_role', + 'virtualization_tech_guest', + 'virtualization_tech_host', + ]) # type: t.Set[str] + + def collect(self, module=None, collected_facts=None): + collected_facts = collected_facts or {} + if not module: + return {} + + # Network munges cached_facts by side effect, so give it a copy + facts_obj = self._fact_class(module) + + facts_dict = facts_obj.populate(collected_facts=collected_facts) + + return facts_dict diff --git a/lib/ansible/module_utils/facts/virtual/dragonfly.py b/lib/ansible/module_utils/facts/virtual/dragonfly.py new file mode 100644 index 0000000..b176f8b --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/dragonfly.py @@ -0,0 +1,25 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.facts.virtual.freebsd import FreeBSDVirtual, VirtualCollector + + +class DragonFlyVirtualCollector(VirtualCollector): + # Note the _fact_class impl is actually the FreeBSDVirtual impl + _fact_class = FreeBSDVirtual + _platform = 'DragonFly' diff --git a/lib/ansible/module_utils/facts/virtual/freebsd.py b/lib/ansible/module_utils/facts/virtual/freebsd.py new file mode 100644 index 0000000..7062d01 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/freebsd.py @@ -0,0 +1,79 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from ansible.module_utils.facts.virtual.base import Virtual, VirtualCollector +from ansible.module_utils.facts.virtual.sysctl import VirtualSysctlDetectionMixin + + +class FreeBSDVirtual(Virtual, VirtualSysctlDetectionMixin): + """ + This is a FreeBSD-specific subclass of Virtual. It defines + - virtualization_type + - virtualization_role + """ + platform = 'FreeBSD' + + def get_virtual_facts(self): + virtual_facts = {} + host_tech = set() + guest_tech = set() + + # Set empty values as default + virtual_facts['virtualization_type'] = '' + virtual_facts['virtualization_role'] = '' + + if os.path.exists('/dev/xen/xenstore'): + guest_tech.add('xen') + virtual_facts['virtualization_type'] = 'xen' + virtual_facts['virtualization_role'] = 'guest' + + kern_vm_guest = self.detect_virt_product('kern.vm_guest') + guest_tech.update(kern_vm_guest['virtualization_tech_guest']) + host_tech.update(kern_vm_guest['virtualization_tech_host']) + + hw_hv_vendor = self.detect_virt_product('hw.hv_vendor') + guest_tech.update(hw_hv_vendor['virtualization_tech_guest']) + host_tech.update(hw_hv_vendor['virtualization_tech_host']) + + sec_jail_jailed = self.detect_virt_product('security.jail.jailed') + guest_tech.update(sec_jail_jailed['virtualization_tech_guest']) + host_tech.update(sec_jail_jailed['virtualization_tech_host']) + + if virtual_facts['virtualization_type'] == '': + sysctl = kern_vm_guest or hw_hv_vendor or sec_jail_jailed + # We call update here, then re-set virtualization_tech_host/guest + # later. + virtual_facts.update(sysctl) + + virtual_vendor_facts = self.detect_virt_vendor('hw.model') + guest_tech.update(virtual_vendor_facts['virtualization_tech_guest']) + host_tech.update(virtual_vendor_facts['virtualization_tech_host']) + + if virtual_facts['virtualization_type'] == '': + virtual_facts.update(virtual_vendor_facts) + + virtual_facts['virtualization_tech_guest'] = guest_tech + virtual_facts['virtualization_tech_host'] = host_tech + return virtual_facts + + +class FreeBSDVirtualCollector(VirtualCollector): + _fact_class = FreeBSDVirtual + _platform = 'FreeBSD' diff --git a/lib/ansible/module_utils/facts/virtual/hpux.py b/lib/ansible/module_utils/facts/virtual/hpux.py new file mode 100644 index 0000000..1057482 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/hpux.py @@ -0,0 +1,72 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re + +from ansible.module_utils.facts.virtual.base import Virtual, VirtualCollector + + +class HPUXVirtual(Virtual): + """ + This is a HP-UX specific subclass of Virtual. It defines + - virtualization_type + - virtualization_role + """ + platform = 'HP-UX' + + def get_virtual_facts(self): + virtual_facts = {} + host_tech = set() + guest_tech = set() + + if os.path.exists('/usr/sbin/vecheck'): + rc, out, err = self.module.run_command("/usr/sbin/vecheck") + if rc == 0: + guest_tech.add('HP vPar') + virtual_facts['virtualization_type'] = 'guest' + virtual_facts['virtualization_role'] = 'HP vPar' + if os.path.exists('/opt/hpvm/bin/hpvminfo'): + rc, out, err = self.module.run_command("/opt/hpvm/bin/hpvminfo") + if rc == 0 and re.match('.*Running.*HPVM vPar.*', out): + guest_tech.add('HPVM vPar') + virtual_facts['virtualization_type'] = 'guest' + virtual_facts['virtualization_role'] = 'HPVM vPar' + elif rc == 0 and re.match('.*Running.*HPVM guest.*', out): + guest_tech.add('HPVM IVM') + virtual_facts['virtualization_type'] = 'guest' + virtual_facts['virtualization_role'] = 'HPVM IVM' + elif rc == 0 and re.match('.*Running.*HPVM host.*', out): + guest_tech.add('HPVM') + virtual_facts['virtualization_type'] = 'host' + virtual_facts['virtualization_role'] = 'HPVM' + if os.path.exists('/usr/sbin/parstatus'): + rc, out, err = self.module.run_command("/usr/sbin/parstatus") + if rc == 0: + guest_tech.add('HP nPar') + virtual_facts['virtualization_type'] = 'guest' + virtual_facts['virtualization_role'] = 'HP nPar' + + virtual_facts['virtualization_tech_guest'] = guest_tech + virtual_facts['virtualization_tech_host'] = host_tech + return virtual_facts + + +class HPUXVirtualCollector(VirtualCollector): + _fact_class = HPUXVirtual + _platform = 'HP-UX' diff --git a/lib/ansible/module_utils/facts/virtual/linux.py b/lib/ansible/module_utils/facts/virtual/linux.py new file mode 100644 index 0000000..31fa061 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/linux.py @@ -0,0 +1,405 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import glob +import os +import re + +from ansible.module_utils.facts.virtual.base import Virtual, VirtualCollector +from ansible.module_utils.facts.utils import get_file_content, get_file_lines + + +class LinuxVirtual(Virtual): + """ + This is a Linux-specific subclass of Virtual. It defines + - virtualization_type + - virtualization_role + """ + platform = 'Linux' + + # For more information, check: http://people.redhat.com/~rjones/virt-what/ + def get_virtual_facts(self): + virtual_facts = {} + + # We want to maintain compatibility with the old "virtualization_type" + # and "virtualization_role" entries, so we need to track if we found + # them. We won't return them until the end, but if we found them early, + # we should avoid updating them again. + found_virt = False + + # But as we go along, we also want to track virt tech the new way. + host_tech = set() + guest_tech = set() + + # lxc/docker + if os.path.exists('/proc/1/cgroup'): + for line in get_file_lines('/proc/1/cgroup'): + if re.search(r'/docker(/|-[0-9a-f]+\.scope)', line): + guest_tech.add('docker') + if not found_virt: + virtual_facts['virtualization_type'] = 'docker' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + if re.search('/lxc/', line) or re.search('/machine.slice/machine-lxc', line): + guest_tech.add('lxc') + if not found_virt: + virtual_facts['virtualization_type'] = 'lxc' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + if re.search('/system.slice/containerd.service', line): + guest_tech.add('containerd') + if not found_virt: + virtual_facts['virtualization_type'] = 'containerd' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + # lxc does not always appear in cgroups anymore but sets 'container=lxc' environment var, requires root privs + if os.path.exists('/proc/1/environ'): + for line in get_file_lines('/proc/1/environ', line_sep='\x00'): + if re.search('container=lxc', line): + guest_tech.add('lxc') + if not found_virt: + virtual_facts['virtualization_type'] = 'lxc' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + if re.search('container=podman', line): + guest_tech.add('podman') + if not found_virt: + virtual_facts['virtualization_type'] = 'podman' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + if re.search('^container=.', line): + guest_tech.add('container') + if not found_virt: + virtual_facts['virtualization_type'] = 'container' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + if os.path.exists('/proc/vz') and not os.path.exists('/proc/lve'): + virtual_facts['virtualization_type'] = 'openvz' + if os.path.exists('/proc/bc'): + host_tech.add('openvz') + if not found_virt: + virtual_facts['virtualization_role'] = 'host' + else: + guest_tech.add('openvz') + if not found_virt: + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + systemd_container = get_file_content('/run/systemd/container') + if systemd_container: + guest_tech.add(systemd_container) + if not found_virt: + virtual_facts['virtualization_type'] = systemd_container + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + # If docker/containerd has a custom cgroup parent, checking /proc/1/cgroup (above) might fail. + # https://docs.docker.com/engine/reference/commandline/dockerd/#default-cgroup-parent + # Fallback to more rudimentary checks. + if os.path.exists('/.dockerenv') or os.path.exists('/.dockerinit'): + guest_tech.add('docker') + if not found_virt: + virtual_facts['virtualization_type'] = 'docker' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + # ensure 'container' guest_tech is appropriately set + if guest_tech.intersection(set(['docker', 'lxc', 'podman', 'openvz', 'containerd'])) or systemd_container: + guest_tech.add('container') + + if os.path.exists("/proc/xen"): + is_xen_host = False + try: + for line in get_file_lines('/proc/xen/capabilities'): + if "control_d" in line: + is_xen_host = True + except IOError: + pass + + if is_xen_host: + host_tech.add('xen') + if not found_virt: + virtual_facts['virtualization_type'] = 'xen' + virtual_facts['virtualization_role'] = 'host' + else: + if not found_virt: + virtual_facts['virtualization_type'] = 'xen' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + # assume guest for this block + if not found_virt: + virtual_facts['virtualization_role'] = 'guest' + + product_name = get_file_content('/sys/devices/virtual/dmi/id/product_name') + sys_vendor = get_file_content('/sys/devices/virtual/dmi/id/sys_vendor') + product_family = get_file_content('/sys/devices/virtual/dmi/id/product_family') + + if product_name in ('KVM', 'KVM Server', 'Bochs', 'AHV'): + guest_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + found_virt = True + + if sys_vendor == 'oVirt': + guest_tech.add('oVirt') + if not found_virt: + virtual_facts['virtualization_type'] = 'oVirt' + found_virt = True + + if sys_vendor == 'Red Hat': + if product_family == 'RHV': + guest_tech.add('RHV') + if not found_virt: + virtual_facts['virtualization_type'] = 'RHV' + found_virt = True + elif product_name == 'RHEV Hypervisor': + guest_tech.add('RHEV') + if not found_virt: + virtual_facts['virtualization_type'] = 'RHEV' + found_virt = True + + if product_name in ('VMware Virtual Platform', 'VMware7,1'): + guest_tech.add('VMware') + if not found_virt: + virtual_facts['virtualization_type'] = 'VMware' + found_virt = True + + if product_name in ('OpenStack Compute', 'OpenStack Nova'): + guest_tech.add('openstack') + if not found_virt: + virtual_facts['virtualization_type'] = 'openstack' + found_virt = True + + bios_vendor = get_file_content('/sys/devices/virtual/dmi/id/bios_vendor') + + if bios_vendor == 'Xen': + guest_tech.add('xen') + if not found_virt: + virtual_facts['virtualization_type'] = 'xen' + found_virt = True + + if bios_vendor == 'innotek GmbH': + guest_tech.add('virtualbox') + if not found_virt: + virtual_facts['virtualization_type'] = 'virtualbox' + found_virt = True + + if bios_vendor in ('Amazon EC2', 'DigitalOcean', 'Hetzner'): + guest_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + found_virt = True + + KVM_SYS_VENDORS = ('QEMU', 'Amazon EC2', 'DigitalOcean', 'Google', 'Scaleway', 'Nutanix') + if sys_vendor in KVM_SYS_VENDORS: + guest_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + found_virt = True + + if sys_vendor == 'KubeVirt': + guest_tech.add('KubeVirt') + if not found_virt: + virtual_facts['virtualization_type'] = 'KubeVirt' + found_virt = True + + # FIXME: This does also match hyperv + if sys_vendor == 'Microsoft Corporation': + guest_tech.add('VirtualPC') + if not found_virt: + virtual_facts['virtualization_type'] = 'VirtualPC' + found_virt = True + + if sys_vendor == 'Parallels Software International Inc.': + guest_tech.add('parallels') + if not found_virt: + virtual_facts['virtualization_type'] = 'parallels' + found_virt = True + + if sys_vendor == 'OpenStack Foundation': + guest_tech.add('openstack') + if not found_virt: + virtual_facts['virtualization_type'] = 'openstack' + found_virt = True + + # unassume guest + if not found_virt: + del virtual_facts['virtualization_role'] + + if os.path.exists('/proc/self/status'): + for line in get_file_lines('/proc/self/status'): + if re.match(r'^VxID:\s+\d+', line): + if not found_virt: + virtual_facts['virtualization_type'] = 'linux_vserver' + if re.match(r'^VxID:\s+0', line): + host_tech.add('linux_vserver') + if not found_virt: + virtual_facts['virtualization_role'] = 'host' + else: + guest_tech.add('linux_vserver') + if not found_virt: + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + if os.path.exists('/proc/cpuinfo'): + for line in get_file_lines('/proc/cpuinfo'): + if re.match('^model name.*QEMU Virtual CPU', line): + guest_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + elif re.match('^vendor_id.*User Mode Linux', line): + guest_tech.add('uml') + if not found_virt: + virtual_facts['virtualization_type'] = 'uml' + elif re.match('^model name.*UML', line): + guest_tech.add('uml') + if not found_virt: + virtual_facts['virtualization_type'] = 'uml' + elif re.match('^machine.*CHRP IBM pSeries .emulated by qemu.', line): + guest_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + elif re.match('^vendor_id.*PowerVM Lx86', line): + guest_tech.add('powervm_lx86') + if not found_virt: + virtual_facts['virtualization_type'] = 'powervm_lx86' + elif re.match('^vendor_id.*IBM/S390', line): + guest_tech.add('PR/SM') + if not found_virt: + virtual_facts['virtualization_type'] = 'PR/SM' + lscpu = self.module.get_bin_path('lscpu') + if lscpu: + rc, out, err = self.module.run_command(["lscpu"]) + if rc == 0: + for line in out.splitlines(): + data = line.split(":", 1) + key = data[0].strip() + if key == 'Hypervisor': + tech = data[1].strip() + guest_tech.add(tech) + if not found_virt: + virtual_facts['virtualization_type'] = tech + else: + guest_tech.add('ibm_systemz') + if not found_virt: + virtual_facts['virtualization_type'] = 'ibm_systemz' + else: + continue + if virtual_facts['virtualization_type'] == 'PR/SM': + if not found_virt: + virtual_facts['virtualization_role'] = 'LPAR' + else: + if not found_virt: + virtual_facts['virtualization_role'] = 'guest' + if not found_virt: + found_virt = True + + # Beware that we can have both kvm and virtualbox running on a single system + if os.path.exists("/proc/modules") and os.access('/proc/modules', os.R_OK): + modules = [] + for line in get_file_lines("/proc/modules"): + data = line.split(" ", 1) + modules.append(data[0]) + + if 'kvm' in modules: + host_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + virtual_facts['virtualization_role'] = 'host' + + if os.path.isdir('/rhev/'): + # Check whether this is a RHEV hypervisor (is vdsm running ?) + for f in glob.glob('/proc/[0-9]*/comm'): + try: + with open(f) as virt_fh: + comm_content = virt_fh.read().rstrip() + + if comm_content in ('vdsm', 'vdsmd'): + # We add both kvm and RHEV to host_tech in this case. + # It's accurate. RHEV uses KVM. + host_tech.add('RHEV') + if not found_virt: + virtual_facts['virtualization_type'] = 'RHEV' + break + except Exception: + pass + + found_virt = True + + if 'vboxdrv' in modules: + host_tech.add('virtualbox') + if not found_virt: + virtual_facts['virtualization_type'] = 'virtualbox' + virtual_facts['virtualization_role'] = 'host' + found_virt = True + + if 'virtio' in modules: + host_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + # In older Linux Kernel versions, /sys filesystem is not available + # dmidecode is the safest option to parse virtualization related values + dmi_bin = self.module.get_bin_path('dmidecode') + # We still want to continue even if dmidecode is not available + if dmi_bin is not None: + (rc, out, err) = self.module.run_command('%s -s system-product-name' % dmi_bin) + if rc == 0: + # Strip out commented lines (specific dmidecode output) + vendor_name = ''.join([line.strip() for line in out.splitlines() if not line.startswith('#')]) + if vendor_name.startswith('VMware'): + guest_tech.add('VMware') + if not found_virt: + virtual_facts['virtualization_type'] = 'VMware' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + if 'BHYVE' in out: + guest_tech.add('bhyve') + if not found_virt: + virtual_facts['virtualization_type'] = 'bhyve' + virtual_facts['virtualization_role'] = 'guest' + found_virt = True + + if os.path.exists('/dev/kvm'): + host_tech.add('kvm') + if not found_virt: + virtual_facts['virtualization_type'] = 'kvm' + virtual_facts['virtualization_role'] = 'host' + found_virt = True + + # If none of the above matches, return 'NA' for virtualization_type + # and virtualization_role. This allows for proper grouping. + if not found_virt: + virtual_facts['virtualization_type'] = 'NA' + virtual_facts['virtualization_role'] = 'NA' + found_virt = True + + virtual_facts['virtualization_tech_guest'] = guest_tech + virtual_facts['virtualization_tech_host'] = host_tech + return virtual_facts + + +class LinuxVirtualCollector(VirtualCollector): + _fact_class = LinuxVirtual + _platform = 'Linux' diff --git a/lib/ansible/module_utils/facts/virtual/netbsd.py b/lib/ansible/module_utils/facts/virtual/netbsd.py new file mode 100644 index 0000000..b4ef14e --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/netbsd.py @@ -0,0 +1,73 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from ansible.module_utils.facts.virtual.base import Virtual, VirtualCollector +from ansible.module_utils.facts.virtual.sysctl import VirtualSysctlDetectionMixin + + +class NetBSDVirtual(Virtual, VirtualSysctlDetectionMixin): + platform = 'NetBSD' + + def get_virtual_facts(self): + virtual_facts = {} + host_tech = set() + guest_tech = set() + + # Set empty values as default + virtual_facts['virtualization_type'] = '' + virtual_facts['virtualization_role'] = '' + + virtual_product_facts = self.detect_virt_product('machdep.dmi.system-product') + guest_tech.update(virtual_product_facts['virtualization_tech_guest']) + host_tech.update(virtual_product_facts['virtualization_tech_host']) + virtual_facts.update(virtual_product_facts) + + virtual_vendor_facts = self.detect_virt_vendor('machdep.dmi.system-vendor') + guest_tech.update(virtual_vendor_facts['virtualization_tech_guest']) + host_tech.update(virtual_vendor_facts['virtualization_tech_host']) + + if virtual_facts['virtualization_type'] == '': + virtual_facts.update(virtual_vendor_facts) + + # The above logic is tried first for backwards compatibility. If + # something above matches, use it. Otherwise if the result is still + # empty, try machdep.hypervisor. + virtual_vendor_facts = self.detect_virt_vendor('machdep.hypervisor') + guest_tech.update(virtual_vendor_facts['virtualization_tech_guest']) + host_tech.update(virtual_vendor_facts['virtualization_tech_host']) + + if virtual_facts['virtualization_type'] == '': + virtual_facts.update(virtual_vendor_facts) + + if os.path.exists('/dev/xencons'): + guest_tech.add('xen') + + if virtual_facts['virtualization_type'] == '': + virtual_facts['virtualization_type'] = 'xen' + virtual_facts['virtualization_role'] = 'guest' + + virtual_facts['virtualization_tech_guest'] = guest_tech + virtual_facts['virtualization_tech_host'] = host_tech + return virtual_facts + + +class NetBSDVirtualCollector(VirtualCollector): + _fact_class = NetBSDVirtual + _platform = 'NetBSD' diff --git a/lib/ansible/module_utils/facts/virtual/openbsd.py b/lib/ansible/module_utils/facts/virtual/openbsd.py new file mode 100644 index 0000000..c449028 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/openbsd.py @@ -0,0 +1,74 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + +from ansible.module_utils.facts.virtual.base import Virtual, VirtualCollector +from ansible.module_utils.facts.virtual.sysctl import VirtualSysctlDetectionMixin + +from ansible.module_utils.facts.utils import get_file_content + + +class OpenBSDVirtual(Virtual, VirtualSysctlDetectionMixin): + """ + This is a OpenBSD-specific subclass of Virtual. It defines + - virtualization_type + - virtualization_role + """ + platform = 'OpenBSD' + DMESG_BOOT = '/var/run/dmesg.boot' + + def get_virtual_facts(self): + virtual_facts = {} + host_tech = set() + guest_tech = set() + + # Set empty values as default + virtual_facts['virtualization_type'] = '' + virtual_facts['virtualization_role'] = '' + + virtual_product_facts = self.detect_virt_product('hw.product') + guest_tech.update(virtual_product_facts['virtualization_tech_guest']) + host_tech.update(virtual_product_facts['virtualization_tech_host']) + virtual_facts.update(virtual_product_facts) + + virtual_vendor_facts = self.detect_virt_vendor('hw.vendor') + guest_tech.update(virtual_vendor_facts['virtualization_tech_guest']) + host_tech.update(virtual_vendor_facts['virtualization_tech_host']) + + if virtual_facts['virtualization_type'] == '': + virtual_facts.update(virtual_vendor_facts) + + # Check the dmesg if vmm(4) attached, indicating the host is + # capable of virtualization. + dmesg_boot = get_file_content(OpenBSDVirtual.DMESG_BOOT) + for line in dmesg_boot.splitlines(): + match = re.match('^vmm0 at mainbus0: (SVM/RVI|VMX/EPT)$', line) + if match: + host_tech.add('vmm') + virtual_facts['virtualization_type'] = 'vmm' + virtual_facts['virtualization_role'] = 'host' + + virtual_facts['virtualization_tech_guest'] = guest_tech + virtual_facts['virtualization_tech_host'] = host_tech + return virtual_facts + + +class OpenBSDVirtualCollector(VirtualCollector): + _fact_class = OpenBSDVirtual + _platform = 'OpenBSD' diff --git a/lib/ansible/module_utils/facts/virtual/sunos.py b/lib/ansible/module_utils/facts/virtual/sunos.py new file mode 100644 index 0000000..1e92677 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/sunos.py @@ -0,0 +1,139 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os + +from ansible.module_utils.facts.virtual.base import Virtual, VirtualCollector + + +class SunOSVirtual(Virtual): + """ + This is a SunOS-specific subclass of Virtual. It defines + - virtualization_type + - virtualization_role + - container + """ + platform = 'SunOS' + + def get_virtual_facts(self): + virtual_facts = {} + host_tech = set() + guest_tech = set() + + # Check if it's a zone + zonename = self.module.get_bin_path('zonename') + if zonename: + rc, out, err = self.module.run_command(zonename) + if rc == 0: + if out.rstrip() == "global": + host_tech.add('zone') + else: + guest_tech.add('zone') + virtual_facts['container'] = 'zone' + + # Check if it's a branded zone (i.e. Solaris 8/9 zone) + if os.path.isdir('/.SUNWnative'): + guest_tech.add('zone') + virtual_facts['container'] = 'zone' + + # If it's a zone check if we can detect if our global zone is itself virtualized. + # Relies on the "guest tools" (e.g. vmware tools) to be installed + if 'container' in virtual_facts and virtual_facts['container'] == 'zone': + modinfo = self.module.get_bin_path('modinfo') + if modinfo: + rc, out, err = self.module.run_command(modinfo) + if rc == 0: + for line in out.splitlines(): + if 'VMware' in line: + guest_tech.add('vmware') + virtual_facts['virtualization_type'] = 'vmware' + virtual_facts['virtualization_role'] = 'guest' + if 'VirtualBox' in line: + guest_tech.add('virtualbox') + virtual_facts['virtualization_type'] = 'virtualbox' + virtual_facts['virtualization_role'] = 'guest' + + if os.path.exists('/proc/vz'): + guest_tech.add('virtuozzo') + virtual_facts['virtualization_type'] = 'virtuozzo' + virtual_facts['virtualization_role'] = 'guest' + + # Detect domaining on Sparc hardware + virtinfo = self.module.get_bin_path('virtinfo') + if virtinfo: + # The output of virtinfo is different whether we are on a machine with logical + # domains ('LDoms') on a T-series or domains ('Domains') on a M-series. Try LDoms first. + rc, out, err = self.module.run_command("/usr/sbin/virtinfo -p") + # The output contains multiple lines with different keys like this: + # DOMAINROLE|impl=LDoms|control=false|io=false|service=false|root=false + # The output may also be not formatted and the returncode is set to 0 regardless of the error condition: + # virtinfo can only be run from the global zone + if rc == 0: + try: + for line in out.splitlines(): + fields = line.split('|') + if fields[0] == 'DOMAINROLE' and fields[1] == 'impl=LDoms': + guest_tech.add('ldom') + virtual_facts['virtualization_type'] = 'ldom' + virtual_facts['virtualization_role'] = 'guest' + hostfeatures = [] + for field in fields[2:]: + arg = field.split('=') + if arg[1] == 'true': + hostfeatures.append(arg[0]) + if len(hostfeatures) > 0: + virtual_facts['virtualization_role'] = 'host (' + ','.join(hostfeatures) + ')' + except ValueError: + pass + + else: + smbios = self.module.get_bin_path('smbios') + if not smbios: + return + rc, out, err = self.module.run_command(smbios) + if rc == 0: + for line in out.splitlines(): + if 'VMware' in line: + guest_tech.add('vmware') + virtual_facts['virtualization_type'] = 'vmware' + virtual_facts['virtualization_role'] = 'guest' + elif 'Parallels' in line: + guest_tech.add('parallels') + virtual_facts['virtualization_type'] = 'parallels' + virtual_facts['virtualization_role'] = 'guest' + elif 'VirtualBox' in line: + guest_tech.add('virtualbox') + virtual_facts['virtualization_type'] = 'virtualbox' + virtual_facts['virtualization_role'] = 'guest' + elif 'HVM domU' in line: + guest_tech.add('xen') + virtual_facts['virtualization_type'] = 'xen' + virtual_facts['virtualization_role'] = 'guest' + elif 'KVM' in line: + guest_tech.add('kvm') + virtual_facts['virtualization_type'] = 'kvm' + virtual_facts['virtualization_role'] = 'guest' + + virtual_facts['virtualization_tech_guest'] = guest_tech + virtual_facts['virtualization_tech_host'] = host_tech + return virtual_facts + + +class SunOSVirtualCollector(VirtualCollector): + _fact_class = SunOSVirtual + _platform = 'SunOS' diff --git a/lib/ansible/module_utils/facts/virtual/sysctl.py b/lib/ansible/module_utils/facts/virtual/sysctl.py new file mode 100644 index 0000000..1c7b2b3 --- /dev/null +++ b/lib/ansible/module_utils/facts/virtual/sysctl.py @@ -0,0 +1,112 @@ +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re + + +class VirtualSysctlDetectionMixin(object): + def detect_sysctl(self): + self.sysctl_path = self.module.get_bin_path('sysctl') + + def detect_virt_product(self, key): + virtual_product_facts = {} + host_tech = set() + guest_tech = set() + + # We do similar to what we do in linux.py -- We want to allow multiple + # virt techs to show up, but maintain compatibility, so we have to track + # when we would have stopped, even though now we go through everything. + found_virt = False + + self.detect_sysctl() + if self.sysctl_path: + rc, out, err = self.module.run_command("%s -n %s" % (self.sysctl_path, key)) + if rc == 0: + if re.match('(KVM|kvm|Bochs|SmartDC).*', out): + guest_tech.add('kvm') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'kvm' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + if re.match('.*VMware.*', out): + guest_tech.add('VMware') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'VMware' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + if out.rstrip() == 'VirtualBox': + guest_tech.add('virtualbox') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'virtualbox' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + if re.match('(HVM domU|XenPVH|XenPV|XenPVHVM).*', out): + guest_tech.add('xen') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'xen' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + if out.rstrip() == 'Hyper-V': + guest_tech.add('Hyper-V') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'Hyper-V' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + if out.rstrip() == 'Parallels': + guest_tech.add('parallels') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'parallels' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + if out.rstrip() == 'RHEV Hypervisor': + guest_tech.add('RHEV') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'RHEV' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + if (key == 'security.jail.jailed') and (out.rstrip() == '1'): + guest_tech.add('jails') + if not found_virt: + virtual_product_facts['virtualization_type'] = 'jails' + virtual_product_facts['virtualization_role'] = 'guest' + found_virt = True + + virtual_product_facts['virtualization_tech_guest'] = guest_tech + virtual_product_facts['virtualization_tech_host'] = host_tech + return virtual_product_facts + + def detect_virt_vendor(self, key): + virtual_vendor_facts = {} + host_tech = set() + guest_tech = set() + self.detect_sysctl() + if self.sysctl_path: + rc, out, err = self.module.run_command("%s -n %s" % (self.sysctl_path, key)) + if rc == 0: + if out.rstrip() == 'QEMU': + guest_tech.add('kvm') + virtual_vendor_facts['virtualization_type'] = 'kvm' + virtual_vendor_facts['virtualization_role'] = 'guest' + if out.rstrip() == 'OpenBSD': + guest_tech.add('vmm') + virtual_vendor_facts['virtualization_type'] = 'vmm' + virtual_vendor_facts['virtualization_role'] = 'guest' + + virtual_vendor_facts['virtualization_tech_guest'] = guest_tech + virtual_vendor_facts['virtualization_tech_host'] = host_tech + return virtual_vendor_facts diff --git a/lib/ansible/module_utils/json_utils.py b/lib/ansible/module_utils/json_utils.py new file mode 100644 index 0000000..0e95aa6 --- /dev/null +++ b/lib/ansible/module_utils/json_utils.py @@ -0,0 +1,79 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import json + + +# NB: a copy of this function exists in ../../modules/core/async_wrapper.py. Ensure any +# changes are propagated there. +def _filter_non_json_lines(data, objects_only=False): + ''' + Used to filter unrelated output around module JSON output, like messages from + tcagetattr, or where dropbear spews MOTD on every single command (which is nuts). + + Filters leading lines before first line-starting occurrence of '{' or '[', and filter all + trailing lines after matching close character (working from the bottom of output). + ''' + warnings = [] + + # Filter initial junk + lines = data.splitlines() + + for start, line in enumerate(lines): + line = line.strip() + if line.startswith(u'{'): + endchar = u'}' + break + elif not objects_only and line.startswith(u'['): + endchar = u']' + break + else: + raise ValueError('No start of json char found') + + # Filter trailing junk + lines = lines[start:] + + for reverse_end_offset, line in enumerate(reversed(lines)): + if line.strip().endswith(endchar): + break + else: + raise ValueError('No end of json char found') + + if reverse_end_offset > 0: + # Trailing junk is uncommon and can point to things the user might + # want to change. So print a warning if we find any + trailing_junk = lines[len(lines) - reverse_end_offset:] + for line in trailing_junk: + if line.strip(): + warnings.append('Module invocation had junk after the JSON data: %s' % '\n'.join(trailing_junk)) + break + + lines = lines[:(len(lines) - reverse_end_offset)] + + return ('\n'.join(lines), warnings) diff --git a/lib/ansible/module_utils/parsing/__init__.py b/lib/ansible/module_utils/parsing/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/parsing/__init__.py diff --git a/lib/ansible/module_utils/parsing/convert_bool.py b/lib/ansible/module_utils/parsing/convert_bool.py new file mode 100644 index 0000000..7eea875 --- /dev/null +++ b/lib/ansible/module_utils/parsing/convert_bool.py @@ -0,0 +1,29 @@ +# Copyright: 2017, Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause ) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.module_utils.six import binary_type, text_type +from ansible.module_utils._text import to_text + + +BOOLEANS_TRUE = frozenset(('y', 'yes', 'on', '1', 'true', 't', 1, 1.0, True)) +BOOLEANS_FALSE = frozenset(('n', 'no', 'off', '0', 'false', 'f', 0, 0.0, False)) +BOOLEANS = BOOLEANS_TRUE.union(BOOLEANS_FALSE) + + +def boolean(value, strict=True): + if isinstance(value, bool): + return value + + normalized_value = value + if isinstance(value, (text_type, binary_type)): + normalized_value = to_text(value, errors='surrogate_or_strict').lower().strip() + + if normalized_value in BOOLEANS_TRUE: + return True + elif normalized_value in BOOLEANS_FALSE or not strict: + return False + + raise TypeError("The value '%s' is not a valid boolean. Valid booleans include: %s" % (to_text(value), ', '.join(repr(i) for i in BOOLEANS))) diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.AddType.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.AddType.psm1 new file mode 100644 index 0000000..6dc2917 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.AddType.psm1 @@ -0,0 +1,398 @@ +# Copyright (c) 2018 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +Function Add-CSharpType { + <# + .SYNOPSIS + Compiles one or more C# scripts similar to Add-Type. This exposes + more configuration options that are useable within Ansible and it + also allows multiple C# sources to be compiled together. + + .PARAMETER References + [String[]] A collection of C# scripts to compile together. + + .PARAMETER IgnoreWarnings + [Switch] Whether to compile code that contains compiler warnings, by + default warnings will cause a compiler error. + + .PARAMETER PassThru + [Switch] Whether to return the loaded Assembly + + .PARAMETER AnsibleModule + [Ansible.Basic.AnsibleModule] used to derive the TempPath and Debug values. + TempPath is set to the Tmpdir property of the class + IncludeDebugInfo is set when the Ansible verbosity is >= 3 + + .PARAMETER TempPath + [String] The temporary directory in which the dynamic assembly is + compiled to. This file is deleted once compilation is complete. + Cannot be used when AnsibleModule is set. This is a no-op when + running on PSCore. + + .PARAMETER IncludeDebugInfo + [Switch] Whether to include debug information in the compiled + assembly. Cannot be used when AnsibleModule is set. This is a no-op + when running on PSCore. + + .PARAMETER CompileSymbols + [String[]] A list of symbols to be defined during compile time. These are + added to the existing symbols, 'CORECLR', 'WINDOWS', 'UNIX' that are set + conditionalls in this cmdlet. + + .NOTES + The following features were added to control the compiling options from the + code itself. + + * Predefined compiler SYMBOLS + + * CORECLR - Added when running on PowerShell Core. + * WINDOWS - Added when running on Windows. + * UNIX - Added when running on non-Windows. + * X86 - Added when running on a 32-bit process (Ansible 2.10+) + * AMD64 - Added when running on a 64-bit process (Ansible 2.10+) + + * Ignore compiler warnings inline with the following comment inline + + //NoWarn -Name <rule code> [-CLR Core|Framework] + + * Specify custom assembly references inline + + //AssemblyReference -Name Dll.Location.dll [-CLR Core|Framework] + + # Added in Ansible 2.10 + //AssemblyReference -Type System.Type.Name [-CLR Core|Framework] + + * Create automatic type accelerators to simplify long namespace names (Ansible 2.9+) + + //TypeAccelerator -Name <AcceleratorName> -TypeName <Name of compiled type> + #> + param( + [Parameter(Mandatory = $true)][AllowEmptyCollection()][String[]]$References, + [Switch]$IgnoreWarnings, + [Switch]$PassThru, + [Parameter(Mandatory = $true, ParameterSetName = "Module")][Object]$AnsibleModule, + [Parameter(ParameterSetName = "Manual")][String]$TempPath = $env:TMP, + [Parameter(ParameterSetName = "Manual")][Switch]$IncludeDebugInfo, + [String[]]$CompileSymbols = @() + ) + if ($null -eq $References -or $References.Length -eq 0) { + return + } + + # define special symbols CORECLR, WINDOWS, UNIX if required + # the Is* variables are defined on PSCore, if absent we assume an + # older version of PowerShell under .NET Framework and Windows + $defined_symbols = [System.Collections.ArrayList]$CompileSymbols + + if ([System.IntPtr]::Size -eq 4) { + $defined_symbols.Add('X86') > $null + } + else { + $defined_symbols.Add('AMD64') > $null + } + + $is_coreclr = Get-Variable -Name IsCoreCLR -ErrorAction SilentlyContinue + if ($null -ne $is_coreclr) { + if ($is_coreclr.Value) { + $defined_symbols.Add("CORECLR") > $null + } + } + $is_windows = Get-Variable -Name IsWindows -ErrorAction SilentlyContinue + if ($null -ne $is_windows) { + if ($is_windows.Value) { + $defined_symbols.Add("WINDOWS") > $null + } + else { + $defined_symbols.Add("UNIX") > $null + } + } + else { + $defined_symbols.Add("WINDOWS") > $null + } + + # Store any TypeAccelerators shortcuts the util wants us to set + $type_accelerators = [System.Collections.Generic.List`1[Hashtable]]@() + + # pattern used to find referenced assemblies in the code + $assembly_pattern = [Regex]"//\s*AssemblyReference\s+-(?<Parameter>(Name)|(Type))\s+(?<Name>[\w.]*)(\s+-CLR\s+(?<CLR>Core|Framework))?" + $no_warn_pattern = [Regex]"//\s*NoWarn\s+-Name\s+(?<Name>[\w\d]*)(\s+-CLR\s+(?<CLR>Core|Framework))?" + $type_pattern = [Regex]"//\s*TypeAccelerator\s+-Name\s+(?<Name>[\w.]*)\s+-TypeName\s+(?<TypeName>[\w.]*)" + + # PSCore vs PSDesktop use different methods to compile the code, + # PSCore uses Roslyn and can compile the code purely in memory + # without touching the disk while PSDesktop uses CodeDom and csc.exe + # to compile the code. We branch out here and run each + # distribution's method to add our C# code. + if ($is_coreclr) { + # compile the code using Roslyn on PSCore + + # Include the default assemblies using the logic in Add-Type + # https://github.com/PowerShell/PowerShell/blob/master/src/Microsoft.PowerShell.Commands.Utility/commands/utility/AddType.cs + $assemblies = [System.Collections.Generic.HashSet`1[Microsoft.CodeAnalysis.MetadataReference]]@( + [Microsoft.CodeAnalysis.CompilationReference]::CreateFromFile(([System.Reflection.Assembly]::GetAssembly([PSObject])).Location) + ) + $netcore_app_ref_folder = [System.IO.Path]::Combine([System.IO.Path]::GetDirectoryName([PSObject].Assembly.Location), "ref") + $lib_assembly_location = [System.IO.Path]::GetDirectoryName([object].Assembly.Location) + foreach ($file in [System.IO.Directory]::EnumerateFiles($netcore_app_ref_folder, "*.dll", [System.IO.SearchOption]::TopDirectoryOnly)) { + $assemblies.Add([Microsoft.CodeAnalysis.MetadataReference]::CreateFromFile($file)) > $null + } + + # loop through the references, parse as a SyntaxTree and get + # referenced assemblies + $ignore_warnings = New-Object -TypeName 'System.Collections.Generic.Dictionary`2[[String], [Microsoft.CodeAnalysis.ReportDiagnostic]]' + $parse_options = ([Microsoft.CodeAnalysis.CSharp.CSharpParseOptions]::Default).WithPreprocessorSymbols($defined_symbols) + $syntax_trees = [System.Collections.Generic.List`1[Microsoft.CodeAnalysis.SyntaxTree]]@() + foreach ($reference in $References) { + # scan through code and add any assemblies that match + # //AssemblyReference -Name ... [-CLR Core] + # //NoWarn -Name ... [-CLR Core] + # //TypeAccelerator -Name ... -TypeName ... + $assembly_matches = $assembly_pattern.Matches($reference) + foreach ($match in $assembly_matches) { + $clr = $match.Groups["CLR"].Value + if ($clr -and $clr -ne "Core") { + continue + } + + $parameter_type = $match.Groups["Parameter"].Value + $assembly_path = $match.Groups["Name"].Value + if ($parameter_type -eq "Type") { + $assembly_path = ([Type]$assembly_path).Assembly.Location + } + else { + if (-not ([System.IO.Path]::IsPathRooted($assembly_path))) { + $assembly_path = Join-Path -Path $lib_assembly_location -ChildPath $assembly_path + } + } + $assemblies.Add([Microsoft.CodeAnalysis.MetadataReference]::CreateFromFile($assembly_path)) > $null + } + $warn_matches = $no_warn_pattern.Matches($reference) + foreach ($match in $warn_matches) { + $clr = $match.Groups["CLR"].Value + if ($clr -and $clr -ne "Core") { + continue + } + $ignore_warnings.Add($match.Groups["Name"], [Microsoft.CodeAnalysis.ReportDiagnostic]::Suppress) + } + $syntax_trees.Add([Microsoft.CodeAnalysis.CSharp.CSharpSyntaxTree]::ParseText($reference, $parse_options)) > $null + + $type_matches = $type_pattern.Matches($reference) + foreach ($match in $type_matches) { + $type_accelerators.Add(@{Name = $match.Groups["Name"].Value; TypeName = $match.Groups["TypeName"].Value }) + } + } + + # Release seems to contain the correct line numbers compared to + # debug,may need to keep a closer eye on this in the future + $compiler_options = (New-Object -TypeName Microsoft.CodeAnalysis.CSharp.CSharpCompilationOptions -ArgumentList @( + [Microsoft.CodeAnalysis.OutputKind]::DynamicallyLinkedLibrary + )).WithOptimizationLevel([Microsoft.CodeAnalysis.OptimizationLevel]::Release) + + # set warnings to error out if IgnoreWarnings is not set + if (-not $IgnoreWarnings.IsPresent) { + $compiler_options = $compiler_options.WithGeneralDiagnosticOption([Microsoft.CodeAnalysis.ReportDiagnostic]::Error) + $compiler_options = $compiler_options.WithSpecificDiagnosticOptions($ignore_warnings) + } + + # create compilation object + $compilation = [Microsoft.CodeAnalysis.CSharp.CSharpCompilation]::Create( + [System.Guid]::NewGuid().ToString(), + $syntax_trees, + $assemblies, + $compiler_options + ) + + # Load the compiled code and pdb info, we do this so we can + # include line number in a stracktrace + $code_ms = New-Object -TypeName System.IO.MemoryStream + $pdb_ms = New-Object -TypeName System.IO.MemoryStream + try { + $emit_result = $compilation.Emit($code_ms, $pdb_ms) + if (-not $emit_result.Success) { + $errors = [System.Collections.ArrayList]@() + + foreach ($e in $emit_result.Diagnostics) { + # builds the error msg, based on logic in Add-Type + # https://github.com/PowerShell/PowerShell/blob/master/src/Microsoft.PowerShell.Commands.Utility/commands/utility/AddType.cs#L1239 + if ($null -eq $e.Location.SourceTree) { + $errors.Add($e.ToString()) > $null + continue + } + + $cancel_token = New-Object -TypeName System.Threading.CancellationToken -ArgumentList $false + $text_lines = $e.Location.SourceTree.GetText($cancel_token).Lines + $line_span = $e.Location.GetLineSpan() + + $diagnostic_message = $e.ToString() + $error_line_string = $text_lines[$line_span.StartLinePosition.Line].ToString() + $error_position = $line_span.StartLinePosition.Character + + $sb = New-Object -TypeName System.Text.StringBuilder -ArgumentList ($diagnostic_message.Length + $error_line_string.Length * 2 + 4) + $sb.AppendLine($diagnostic_message) + $sb.AppendLine($error_line_string) + + for ($i = 0; $i -lt $error_line_string.Length; $i++) { + if ([System.Char]::IsWhiteSpace($error_line_string[$i])) { + continue + } + $sb.Append($error_line_string, 0, $i) + $sb.Append(' ', [Math]::Max(0, $error_position - $i)) + $sb.Append("^") + break + } + + $errors.Add($sb.ToString()) > $null + } + + throw [InvalidOperationException]"Failed to compile C# code:`r`n$($errors -join "`r`n")" + } + + $code_ms.Seek(0, [System.IO.SeekOrigin]::Begin) > $null + $pdb_ms.Seek(0, [System.IO.SeekOrigin]::Begin) > $null + $compiled_assembly = [System.Runtime.Loader.AssemblyLoadContext]::Default.LoadFromStream($code_ms, $pdb_ms) + } + finally { + $code_ms.Close() + $pdb_ms.Close() + } + } + else { + # compile the code using CodeDom on PSDesktop + + # configure compile options based on input + if ($PSCmdlet.ParameterSetName -eq "Module") { + $temp_path = $AnsibleModule.Tmpdir + $include_debug = $AnsibleModule.Verbosity -ge 3 + } + else { + $temp_path = $TempPath + $include_debug = $IncludeDebugInfo.IsPresent + } + $compiler_options = [System.Collections.ArrayList]@("/optimize") + if ($defined_symbols.Count -gt 0) { + $compiler_options.Add("/define:" + ([String]::Join(";", $defined_symbols.ToArray()))) > $null + } + + $compile_parameters = New-Object -TypeName System.CodeDom.Compiler.CompilerParameters + $compile_parameters.GenerateExecutable = $false + $compile_parameters.GenerateInMemory = $true + $compile_parameters.TreatWarningsAsErrors = (-not $IgnoreWarnings.IsPresent) + $compile_parameters.IncludeDebugInformation = $include_debug + $compile_parameters.TempFiles = (New-Object -TypeName System.CodeDom.Compiler.TempFileCollection -ArgumentList $temp_path, $false) + + # Add-Type automatically references System.dll, System.Core.dll, + # and System.Management.Automation.dll which we replicate here + $assemblies = [System.Collections.Generic.HashSet`1[String]]@( + "System.dll", + "System.Core.dll", + ([System.Reflection.Assembly]::GetAssembly([PSObject])).Location + ) + + # create a code snippet for each reference and check if we need + # to reference any extra assemblies + $ignore_warnings = [System.Collections.ArrayList]@() + $compile_units = [System.Collections.Generic.List`1[System.CodeDom.CodeSnippetCompileUnit]]@() + foreach ($reference in $References) { + # scan through code and add any assemblies that match + # //AssemblyReference -Name ... [-CLR Framework] + # //NoWarn -Name ... [-CLR Framework] + # //TypeAccelerator -Name ... -TypeName ... + $assembly_matches = $assembly_pattern.Matches($reference) + foreach ($match in $assembly_matches) { + $clr = $match.Groups["CLR"].Value + if ($clr -and $clr -ne "Framework") { + continue + } + + $parameter_type = $match.Groups["Parameter"].Value + $assembly_path = $match.Groups["Name"].Value + if ($parameter_type -eq "Type") { + $assembly_path = ([Type]$assembly_path).Assembly.Location + } + $assemblies.Add($assembly_path) > $null + } + $warn_matches = $no_warn_pattern.Matches($reference) + foreach ($match in $warn_matches) { + $clr = $match.Groups["CLR"].Value + if ($clr -and $clr -ne "Framework") { + continue + } + $warning_id = $match.Groups["Name"].Value + # /nowarn should only contain the numeric part + if ($warning_id.StartsWith("CS")) { + $warning_id = $warning_id.Substring(2) + } + $ignore_warnings.Add($warning_id) > $null + } + $compile_units.Add((New-Object -TypeName System.CodeDom.CodeSnippetCompileUnit -ArgumentList $reference)) > $null + + $type_matches = $type_pattern.Matches($reference) + foreach ($match in $type_matches) { + $type_accelerators.Add(@{Name = $match.Groups["Name"].Value; TypeName = $match.Groups["TypeName"].Value }) + } + } + if ($ignore_warnings.Count -gt 0) { + $compiler_options.Add("/nowarn:" + ([String]::Join(",", $ignore_warnings.ToArray()))) > $null + } + $compile_parameters.ReferencedAssemblies.AddRange($assemblies) + $compile_parameters.CompilerOptions = [String]::Join(" ", $compiler_options.ToArray()) + + # compile the code together and check for errors + $provider = New-Object -TypeName Microsoft.CSharp.CSharpCodeProvider + + # This calls csc.exe which can take compiler options from environment variables. Currently these env vars + # are known to have problems so they are unset: + # LIB - additional library paths will fail the compilation if they are invalid + $originalEnv = @{} + try { + 'LIB' | ForEach-Object -Process { + $value = Get-Item -LiteralPath "Env:\$_" -ErrorAction SilentlyContinue + if ($value) { + $originalEnv[$_] = $value + Remove-Item -LiteralPath "Env:\$_" + } + } + + $compile = $provider.CompileAssemblyFromDom($compile_parameters, $compile_units) + } + finally { + foreach ($kvp in $originalEnv.GetEnumerator()) { + [System.Environment]::SetEnvironmentVariable($kvp.Key, $kvp.Value, "Process") + } + } + + if ($compile.Errors.HasErrors) { + $msg = "Failed to compile C# code: " + foreach ($e in $compile.Errors) { + $msg += "`r`n" + $e.ToString() + } + throw [InvalidOperationException]$msg + } + $compiled_assembly = $compile.CompiledAssembly + } + + $type_accelerator = [PSObject].Assembly.GetType("System.Management.Automation.TypeAccelerators") + foreach ($accelerator in $type_accelerators) { + $type_name = $accelerator.TypeName + $found = $false + + foreach ($assembly_type in $compiled_assembly.GetTypes()) { + if ($assembly_type.Name -eq $type_name) { + $type_accelerator::Add($accelerator.Name, $assembly_type) + $found = $true + break + } + } + if (-not $found) { + throw "Failed to find compiled class '$type_name' for custom TypeAccelerator." + } + } + + # return the compiled assembly if PassThru is set. + if ($PassThru) { + return $compiled_assembly + } +} + +Export-ModuleMember -Function Add-CSharpType + diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.ArgvParser.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.ArgvParser.psm1 new file mode 100644 index 0000000..53d6870 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.ArgvParser.psm1 @@ -0,0 +1,78 @@ +# Copyright (c) 2017 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +# The rules used in these functions are derived from the below +# https://docs.microsoft.com/en-us/cpp/cpp/parsing-cpp-command-line-arguments +# https://blogs.msdn.microsoft.com/twistylittlepassagesallalike/2011/04/23/everyone-quotes-command-line-arguments-the-wrong-way/ + +Function Escape-Argument($argument, $force_quote = $false) { + # this converts a single argument to an escaped version, use Join-Arguments + # instead of this function as this only escapes a single string. + + # check if argument contains a space, \n, \t, \v or " + if ($force_quote -eq $false -and $argument.Length -gt 0 -and $argument -notmatch "[ \n\t\v`"]") { + # argument does not need escaping (and we don't want to force it), + # return as is + return $argument + } + else { + # we need to quote the arg so start with " + $new_argument = '"' + + for ($i = 0; $i -lt $argument.Length; $i++) { + $num_backslashes = 0 + + # get the number of \ from current char until end or not a \ + while ($i -ne ($argument.Length - 1) -and $argument[$i] -eq "\") { + $num_backslashes++ + $i++ + } + + $current_char = $argument[$i] + if ($i -eq ($argument.Length - 1) -and $current_char -eq "\") { + # We are at the end of the string so we need to add the same \ + # * 2 as the end char would be a " + $new_argument += ("\" * ($num_backslashes + 1) * 2) + } + elseif ($current_char -eq '"') { + # we have a inline ", we need to add the existing \ but * by 2 + # plus another 1 + $new_argument += ("\" * (($num_backslashes * 2) + 1)) + $new_argument += $current_char + } + else { + # normal character so no need to escape the \ we have counted + $new_argument += ("\" * $num_backslashes) + $new_argument += $current_char + } + } + + # we need to close the special arg with a " + $new_argument += '"' + return $new_argument + } +} + +Function Argv-ToString($arguments, $force_quote = $false) { + # Takes in a list of un escaped arguments and convert it to a single string + # that can be used when starting a new process. It will escape the + # characters as necessary in the list. + # While there is a CommandLineToArgvW function there is a no + # ArgvToCommandLineW that we can call to convert a list to an escaped + # string. + # You can also pass in force_quote so that each argument is quoted even + # when not necessary, by default only arguments with certain characters are + # quoted. + # TODO: add in another switch which will escape the args for cmd.exe + + $escaped_arguments = @() + foreach ($argument in $arguments) { + $escaped_argument = Escape-Argument -argument $argument -force_quote $force_quote + $escaped_arguments += $escaped_argument + } + + return ($escaped_arguments -join ' ') +} + +# this line must stay at the bottom to ensure all defined module parts are exported +Export-ModuleMember -Alias * -Function * -Cmdlet * diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.Backup.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.Backup.psm1 new file mode 100644 index 0000000..ca4f5ba --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.Backup.psm1 @@ -0,0 +1,34 @@ +# Copyright (c): 2018, Dag Wieers (@dagwieers) <dag@wieers.com> +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +Function Backup-File { + <# + .SYNOPSIS + Helper function to make a backup of a file. + .EXAMPLE + Backup-File -path $path -WhatIf:$check_mode +#> + [CmdletBinding(SupportsShouldProcess = $true)] + + Param ( + [Parameter(Mandatory = $true, ValueFromPipeline = $true)] + [string] $path + ) + + Process { + $backup_path = $null + if (Test-Path -LiteralPath $path -PathType Leaf) { + $backup_path = "$path.$pid." + [DateTime]::Now.ToString("yyyyMMdd-HHmmss") + ".bak"; + Try { + Copy-Item -LiteralPath $path -Destination $backup_path + } + Catch { + throw "Failed to create backup file '$backup_path' from '$path'. ($($_.Exception.Message))" + } + } + return $backup_path + } +} + +# This line must stay at the bottom to ensure all defined module parts are exported +Export-ModuleMember -Function Backup-File diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.CamelConversion.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.CamelConversion.psm1 new file mode 100644 index 0000000..9b86f84 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.CamelConversion.psm1 @@ -0,0 +1,69 @@ +# Copyright (c) 2017 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +# used by Convert-DictToSnakeCase to convert a string in camelCase +# format to snake_case +Function Convert-StringToSnakeCase($string) { + # cope with pluralized abbreaviations such as TargetGroupARNs + if ($string -cmatch "[A-Z]{3,}s") { + $replacement_string = $string -creplace $matches[0], "_$($matches[0].ToLower())" + + # handle when there was nothing before the plural pattern + if ($replacement_string.StartsWith("_") -and -not $string.StartsWith("_")) { + $replacement_string = $replacement_string.Substring(1) + } + $string = $replacement_string + } + $string = $string -creplace "(.)([A-Z][a-z]+)", '$1_$2' + $string = $string -creplace "([a-z0-9])([A-Z])", '$1_$2' + $string = $string.ToLower() + + return $string +} + +# used by Convert-DictToSnakeCase to covert list entries from camelCase +# to snake_case +Function Convert-ListToSnakeCase($list) { + $snake_list = [System.Collections.ArrayList]@() + foreach ($value in $list) { + if ($value -is [Hashtable]) { + $new_value = Convert-DictToSnakeCase -dict $value + } + elseif ($value -is [Array] -or $value -is [System.Collections.ArrayList]) { + $new_value = Convert-ListToSnakeCase -list $value + } + else { + $new_value = $value + } + [void]$snake_list.Add($new_value) + } + + return , $snake_list +} + +# converts a dict/hashtable keys from camelCase to snake_case +# this is to keep the return values consistent with the Ansible +# way of working. +Function Convert-DictToSnakeCase($dict) { + $snake_dict = @{} + foreach ($dict_entry in $dict.GetEnumerator()) { + $key = $dict_entry.Key + $snake_key = Convert-StringToSnakeCase -string $key + + $value = $dict_entry.Value + if ($value -is [Hashtable]) { + $snake_dict.$snake_key = Convert-DictToSnakeCase -dict $value + } + elseif ($value -is [Array] -or $value -is [System.Collections.ArrayList]) { + $snake_dict.$snake_key = Convert-ListToSnakeCase -list $value + } + else { + $snake_dict.$snake_key = $value + } + } + + return , $snake_dict +} + +# this line must stay at the bottom to ensure all defined module parts are exported +Export-ModuleMember -Alias * -Function * -Cmdlet * diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.CommandUtil.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.CommandUtil.psm1 new file mode 100644 index 0000000..56b5d39 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.CommandUtil.psm1 @@ -0,0 +1,107 @@ +# Copyright (c) 2017 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +#AnsibleRequires -CSharpUtil Ansible.Process + +Function Get-ExecutablePath { + <# + .SYNOPSIS + Get's the full path to an executable, will search the directory specified or ones in the PATH env var. + + .PARAMETER executable + [String]The executable to search for. + + .PARAMETER directory + [String] If set, the directory to search in. + + .OUTPUT + [String] The full path the executable specified. + #> + Param( + [String]$executable, + [String]$directory = $null + ) + + # we need to add .exe if it doesn't have an extension already + if (-not [System.IO.Path]::HasExtension($executable)) { + $executable = "$($executable).exe" + } + $full_path = [System.IO.Path]::GetFullPath($executable) + + if ($full_path -ne $executable -and $directory -ne $null) { + $file = Get-Item -LiteralPath "$directory\$executable" -Force -ErrorAction SilentlyContinue + } + else { + $file = Get-Item -LiteralPath $executable -Force -ErrorAction SilentlyContinue + } + + if ($null -ne $file) { + $executable_path = $file.FullName + } + else { + $executable_path = [Ansible.Process.ProcessUtil]::SearchPath($executable) + } + return $executable_path +} + +Function Run-Command { + <# + .SYNOPSIS + Run a command with the CreateProcess API and return the stdout/stderr and return code. + + .PARAMETER command + The full command, including the executable, to run. + + .PARAMETER working_directory + The working directory to set on the new process, will default to the current working dir. + + .PARAMETER stdin + A string to sent over the stdin pipe to the new process. + + .PARAMETER environment + A hashtable of key/value pairs to run with the command. If set, it will replace all other env vars. + + .PARAMETER output_encoding_override + The character encoding name for decoding stdout/stderr output of the process. + + .OUTPUT + [Hashtable] + [String]executable - The full path to the executable that was run + [String]stdout - The stdout stream of the process + [String]stderr - The stderr stream of the process + [Int32]rc - The return code of the process + #> + Param( + [string]$command, + [string]$working_directory = $null, + [string]$stdin = "", + [hashtable]$environment = @{}, + [string]$output_encoding_override = $null + ) + + # need to validate the working directory if it is set + if ($working_directory) { + # validate working directory is a valid path + if (-not (Test-Path -LiteralPath $working_directory)) { + throw "invalid working directory path '$working_directory'" + } + } + + # lpApplicationName needs to be the full path to an executable, we do this + # by getting the executable as the first arg and then getting the full path + $arguments = [Ansible.Process.ProcessUtil]::ParseCommandLine($command) + $executable = Get-ExecutablePath -executable $arguments[0] -directory $working_directory + + # run the command and get the results + $command_result = [Ansible.Process.ProcessUtil]::CreateProcess($executable, $command, $working_directory, $environment, $stdin, $output_encoding_override) + + return , @{ + executable = $executable + stdout = $command_result.StandardOut + stderr = $command_result.StandardError + rc = $command_result.ExitCode + } +} + +# this line must stay at the bottom to ensure all defined module parts are exported +Export-ModuleMember -Function Get-ExecutablePath, Run-Command diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.FileUtil.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.FileUtil.psm1 new file mode 100644 index 0000000..cd614d4 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.FileUtil.psm1 @@ -0,0 +1,66 @@ +# Copyright (c) 2017 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +<# +Test-Path/Get-Item cannot find/return info on files that are locked like +C:\pagefile.sys. These 2 functions are designed to work with these files and +provide similar functionality with the normal cmdlets with as minimal overhead +as possible. They work by using Get-ChildItem with a filter and return the +result from that. +#> + +Function Test-AnsiblePath { + [CmdletBinding()] + Param( + [Parameter(Mandatory = $true)][string]$Path + ) + # Replacement for Test-Path + try { + $file_attributes = [System.IO.File]::GetAttributes($Path) + } + catch [System.IO.FileNotFoundException], [System.IO.DirectoryNotFoundException] { + return $false + } + catch [NotSupportedException] { + # When testing a path like Cert:\LocalMachine\My, System.IO.File will + # not work, we just revert back to using Test-Path for this + return Test-Path -Path $Path + } + + if ([Int32]$file_attributes -eq -1) { + return $false + } + else { + return $true + } +} + +Function Get-AnsibleItem { + [CmdletBinding()] + Param( + [Parameter(Mandatory = $true)][string]$Path + ) + # Replacement for Get-Item + try { + $file_attributes = [System.IO.File]::GetAttributes($Path) + } + catch { + # if -ErrorAction SilentlyCotinue is set on the cmdlet and we failed to + # get the attributes, just return $null, otherwise throw the error + if ($ErrorActionPreference -ne "SilentlyContinue") { + throw $_ + } + return $null + } + if ([Int32]$file_attributes -eq -1) { + throw New-Object -TypeName System.Management.Automation.ItemNotFoundException -ArgumentList "Cannot find path '$Path' because it does not exist." + } + elseif ($file_attributes.HasFlag([System.IO.FileAttributes]::Directory)) { + return New-Object -TypeName System.IO.DirectoryInfo -ArgumentList $Path + } + else { + return New-Object -TypeName System.IO.FileInfo -ArgumentList $Path + } +} + +Export-ModuleMember -Function Test-AnsiblePath, Get-AnsibleItem diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.Legacy.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.Legacy.psm1 new file mode 100644 index 0000000..f0cb440 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.Legacy.psm1 @@ -0,0 +1,390 @@ +# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2014, and others +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +Set-StrictMode -Version 2.0 +$ErrorActionPreference = "Stop" + +Function Set-Attr($obj, $name, $value) { + <# + .SYNOPSIS + Helper function to set an "attribute" on a psobject instance in PowerShell. + This is a convenience to make adding Members to the object easier and + slightly more pythonic + .EXAMPLE + Set-Attr $result "changed" $true +#> + + # If the provided $obj is undefined, define one to be nice + If (-not $obj.GetType) { + $obj = @{ } + } + + Try { + $obj.$name = $value + } + Catch { + $obj | Add-Member -Force -MemberType NoteProperty -Name $name -Value $value + } +} + +Function Exit-Json($obj) { + <# + .SYNOPSIS + Helper function to convert a PowerShell object to JSON and output it, exiting + the script + .EXAMPLE + Exit-Json $result +#> + + # If the provided $obj is undefined, define one to be nice + If (-not $obj.GetType) { + $obj = @{ } + } + + if (-not $obj.ContainsKey('changed')) { + Set-Attr -obj $obj -name "changed" -value $false + } + + Write-Output $obj | ConvertTo-Json -Compress -Depth 99 + Exit +} + +Function Fail-Json($obj, $message = $null) { + <# + .SYNOPSIS + Helper function to add the "msg" property and "failed" property, convert the + PowerShell Hashtable to JSON and output it, exiting the script + .EXAMPLE + Fail-Json $result "This is the failure message" +#> + + if ($obj -is [hashtable] -or $obj -is [psobject]) { + # Nothing to do + } + elseif ($obj -is [string] -and $null -eq $message) { + # If we weren't given 2 args, and the only arg was a string, + # create a new Hashtable and use the arg as the failure message + $message = $obj + $obj = @{ } + } + else { + # If the first argument is undefined or a different type, + # make it a Hashtable + $obj = @{ } + } + + # Still using Set-Attr for PSObject compatibility + Set-Attr -obj $obj -name "msg" -value $message + Set-Attr -obj $obj -name "failed" -value $true + + if (-not $obj.ContainsKey('changed')) { + Set-Attr -obj $obj -name "changed" -value $false + } + + Write-Output $obj | ConvertTo-Json -Compress -Depth 99 + Exit 1 +} + +Function Add-Warning($obj, $message) { + <# + .SYNOPSIS + Helper function to add warnings, even if the warnings attribute was + not already set up. This is a convenience for the module developer + so they do not have to check for the attribute prior to adding. +#> + + if (-not $obj.ContainsKey("warnings")) { + $obj.warnings = @() + } + elseif ($obj.warnings -isnot [array]) { + throw "Add-Warning: warnings attribute is not an array" + } + + $obj.warnings += $message +} + +Function Add-DeprecationWarning($obj, $message, $version = $null) { + <# + .SYNOPSIS + Helper function to add deprecations, even if the deprecations attribute was + not already set up. This is a convenience for the module developer + so they do not have to check for the attribute prior to adding. +#> + if (-not $obj.ContainsKey("deprecations")) { + $obj.deprecations = @() + } + elseif ($obj.deprecations -isnot [array]) { + throw "Add-DeprecationWarning: deprecations attribute is not a list" + } + + $obj.deprecations += @{ + msg = $message + version = $version + } +} + +Function Expand-Environment($value) { + <# + .SYNOPSIS + Helper function to expand environment variables in values. By default + it turns any type to a string, but we ensure $null remains $null. +#> + if ($null -ne $value) { + [System.Environment]::ExpandEnvironmentVariables($value) + } + else { + $value + } +} + +Function Get-AnsibleParam { + <# + .SYNOPSIS + Helper function to get an "attribute" from a psobject instance in PowerShell. + This is a convenience to make getting Members from an object easier and + slightly more pythonic + .EXAMPLE + $attr = Get-AnsibleParam $response "code" -default "1" + .EXAMPLE + Get-AnsibleParam -obj $params -name "State" -default "Present" -ValidateSet "Present","Absent" -resultobj $resultobj -failifempty $true + Get-AnsibleParam also supports Parameter validation to save you from coding that manually + Note that if you use the failifempty option, you do need to specify resultobject as well. +#> + param ( + $obj, + $name, + $default = $null, + $resultobj = @{}, + $failifempty = $false, + $emptyattributefailmessage, + $ValidateSet, + $ValidateSetErrorMessage, + $type = $null, + $aliases = @() + ) + # Check if the provided Member $name or aliases exist in $obj and return it or the default. + try { + + $found = $null + # First try to find preferred parameter $name + $aliases = @($name) + $aliases + + # Iterate over aliases to find acceptable Member $name + foreach ($alias in $aliases) { + if ($obj.ContainsKey($alias)) { + $found = $alias + break + } + } + + if ($null -eq $found) { + throw + } + $name = $found + + if ($ValidateSet) { + + if ($ValidateSet -contains ($obj.$name)) { + $value = $obj.$name + } + else { + if ($null -eq $ValidateSetErrorMessage) { + #Auto-generated error should be sufficient in most use cases + $ValidateSetErrorMessage = "Get-AnsibleParam: Argument $name needs to be one of $($ValidateSet -join ",") but was $($obj.$name)." + } + Fail-Json -obj $resultobj -message $ValidateSetErrorMessage + } + } + else { + $value = $obj.$name + } + } + catch { + if ($failifempty -eq $false) { + $value = $default + } + else { + if (-not $emptyattributefailmessage) { + $emptyattributefailmessage = "Get-AnsibleParam: Missing required argument: $name" + } + Fail-Json -obj $resultobj -message $emptyattributefailmessage + } + } + + # If $null -eq $value, the parameter was unspecified by the user (deliberately or not) + # Please leave $null-values intact, modules need to know if a parameter was specified + if ($null -eq $value) { + return $null + } + + if ($type -eq "path") { + # Expand environment variables on path-type + $value = Expand-Environment($value) + # Test if a valid path is provided + if (-not (Test-Path -IsValid $value)) { + $path_invalid = $true + # could still be a valid-shaped path with a nonexistent drive letter + if ($value -match "^\w:") { + # rewrite path with a valid drive letter and recheck the shape- this might still fail, eg, a nonexistent non-filesystem PS path + if (Test-Path -IsValid $(@(Get-PSDrive -PSProvider Filesystem)[0].Name + $value.Substring(1))) { + $path_invalid = $false + } + } + if ($path_invalid) { + Fail-Json -obj $resultobj -message "Get-AnsibleParam: Parameter '$name' has an invalid path '$value' specified." + } + } + } + elseif ($type -eq "str") { + # Convert str types to real Powershell strings + $value = $value.ToString() + } + elseif ($type -eq "bool") { + # Convert boolean types to real Powershell booleans + $value = $value | ConvertTo-Bool + } + elseif ($type -eq "int") { + # Convert int types to real Powershell integers + $value = $value -as [int] + } + elseif ($type -eq "float") { + # Convert float types to real Powershell floats + $value = $value -as [float] + } + elseif ($type -eq "list") { + if ($value -is [array]) { + # Nothing to do + } + elseif ($value -is [string]) { + # Convert string type to real Powershell array + $value = $value.Split(",").Trim() + } + elseif ($value -is [int]) { + $value = @($value) + } + else { + Fail-Json -obj $resultobj -message "Get-AnsibleParam: Parameter '$name' is not a YAML list." + } + # , is not a typo, forces it to return as a list when it is empty or only has 1 entry + return , $value + } + + return $value +} + +#Alias Get-attr-->Get-AnsibleParam for backwards compat. Only add when needed to ease debugging of scripts +If (-not(Get-Alias -Name "Get-attr" -ErrorAction SilentlyContinue)) { + New-Alias -Name Get-attr -Value Get-AnsibleParam +} + +Function ConvertTo-Bool { + <# + .SYNOPSIS + Helper filter/pipeline function to convert a value to boolean following current + Ansible practices + .EXAMPLE + $is_true = "true" | ConvertTo-Bool +#> + param( + [parameter(valuefrompipeline = $true)] + $obj + ) + + process { + $boolean_strings = "yes", "on", "1", "true", 1 + $obj_string = [string]$obj + + if (($obj -is [boolean] -and $obj) -or $boolean_strings -contains $obj_string.ToLower()) { + return $true + } + else { + return $false + } + } +} + +Function Parse-Args { + <# + .SYNOPSIS + Helper function to parse Ansible JSON arguments from a "file" passed as + the single argument to the module. + .EXAMPLE + $params = Parse-Args $args +#> + [Diagnostics.CodeAnalysis.SuppressMessageAttribute("PSUseSingularNouns", "", Justification = "Cannot change the name now")] + param ($arguments, $supports_check_mode = $false) + + $params = New-Object psobject + If ($arguments.Length -gt 0) { + $params = Get-Content $arguments[0] | ConvertFrom-Json + } + Else { + $params = $complex_args + } + $check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false + If ($check_mode -and -not $supports_check_mode) { + Exit-Json @{ + skipped = $true + changed = $false + msg = "remote module does not support check mode" + } + } + return $params +} + + +Function Get-FileChecksum($path, $algorithm = 'sha1') { + <# + .SYNOPSIS + Helper function to calculate a hash of a file in a way which PowerShell 3 + and above can handle +#> + If (Test-Path -LiteralPath $path -PathType Leaf) { + switch ($algorithm) { + 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider } + 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider } + 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider } + 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider } + 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider } + default { Fail-Json @{} "Unsupported hash algorithm supplied '$algorithm'" } + } + + If ($PSVersionTable.PSVersion.Major -ge 4) { + $raw_hash = Get-FileHash -LiteralPath $path -Algorithm $algorithm + $hash = $raw_hash.Hash.ToLower() + } + Else { + $fp = [System.IO.File]::Open($path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, [System.IO.FileShare]::ReadWrite); + $hash = [System.BitConverter]::ToString($sp.ComputeHash($fp)).Replace("-", "").ToLower(); + $fp.Dispose(); + } + } + ElseIf (Test-Path -LiteralPath $path -PathType Container) { + $hash = "3"; + } + Else { + $hash = "1"; + } + return $hash +} + +Function Get-PendingRebootStatus { + <# + .SYNOPSIS + Check if reboot is required, if so notify CA. + Function returns true if computer has a pending reboot +#> + $featureData = Invoke-CimMethod -EA Ignore -Name GetServerFeature -Namespace root\microsoft\windows\servermanager -Class MSFT_ServerManagerTasks + $regData = Get-ItemProperty "HKLM:\SYSTEM\CurrentControlSet\Control\Session Manager" "PendingFileRenameOperations" -EA Ignore + $CBSRebootStatus = Get-ChildItem "HKLM:\\SOFTWARE\Microsoft\Windows\CurrentVersion\Component Based Servicing" -ErrorAction SilentlyContinue | + Where-Object { $_.PSChildName -eq "RebootPending" } + if (($featureData -and $featureData.RequiresReboot) -or $regData -or $CBSRebootStatus) { + return $True + } + else { + return $False + } +} + +# this line must stay at the bottom to ensure all defined module parts are exported +Export-ModuleMember -Alias * -Function * -Cmdlet * diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.LinkUtil.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.LinkUtil.psm1 new file mode 100644 index 0000000..1a251f6 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.LinkUtil.psm1 @@ -0,0 +1,464 @@ +# Copyright (c) 2017 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +#Requires -Module Ansible.ModuleUtils.PrivilegeUtil + +Function Load-LinkUtils { + [Diagnostics.CodeAnalysis.SuppressMessageAttribute("PSUseSingularNouns", "", Justification = "Cannot change the name now")] + param () + + $link_util = @' +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.IO; +using System.Runtime.InteropServices; +using System.Text; + +namespace Ansible +{ + public enum LinkType + { + SymbolicLink, + JunctionPoint, + HardLink + } + + public class LinkUtilWin32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + + public LinkUtilWin32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + + public LinkUtilWin32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + + public override string Message { get { return _msg; } } + public static explicit operator LinkUtilWin32Exception(string message) { return new LinkUtilWin32Exception(message); } + } + + public class LinkInfo + { + public LinkType Type { get; internal set; } + public string PrintName { get; internal set; } + public string SubstituteName { get; internal set; } + public string AbsolutePath { get; internal set; } + public string TargetPath { get; internal set; } + public string[] HardTargets { get; internal set; } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct REPARSE_DATA_BUFFER + { + public UInt32 ReparseTag; + public UInt16 ReparseDataLength; + public UInt16 Reserved; + public UInt16 SubstituteNameOffset; + public UInt16 SubstituteNameLength; + public UInt16 PrintNameOffset; + public UInt16 PrintNameLength; + + [MarshalAs(UnmanagedType.ByValArray, SizeConst = LinkUtil.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)] + public char[] PathBuffer; + } + + public class LinkUtil + { + public const int MAXIMUM_REPARSE_DATA_BUFFER_SIZE = 1024 * 16; + + private const UInt32 FILE_FLAG_BACKUP_SEMANTICS = 0x02000000; + private const UInt32 FILE_FLAG_OPEN_REPARSE_POINT = 0x00200000; + + private const UInt32 FSCTL_GET_REPARSE_POINT = 0x000900A8; + private const UInt32 FSCTL_SET_REPARSE_POINT = 0x000900A4; + private const UInt32 FILE_DEVICE_FILE_SYSTEM = 0x00090000; + + private const UInt32 IO_REPARSE_TAG_MOUNT_POINT = 0xA0000003; + private const UInt32 IO_REPARSE_TAG_SYMLINK = 0xA000000C; + + private const UInt32 SYMLINK_FLAG_RELATIVE = 0x00000001; + + private const Int64 INVALID_HANDLE_VALUE = -1; + + private const UInt32 SIZE_OF_WCHAR = 2; + + private const UInt32 SYMBOLIC_LINK_FLAG_FILE = 0x00000000; + private const UInt32 SYMBOLIC_LINK_FLAG_DIRECTORY = 0x00000001; + + [DllImport("kernel32.dll", CharSet = CharSet.Auto)] + private static extern SafeFileHandle CreateFile( + string lpFileName, + [MarshalAs(UnmanagedType.U4)] FileAccess dwDesiredAccess, + [MarshalAs(UnmanagedType.U4)] FileShare dwShareMode, + IntPtr lpSecurityAttributes, + [MarshalAs(UnmanagedType.U4)] FileMode dwCreationDisposition, + UInt32 dwFlagsAndAttributes, + IntPtr hTemplateFile); + + // Used by GetReparsePointInfo() + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool DeviceIoControl( + SafeFileHandle hDevice, + UInt32 dwIoControlCode, + IntPtr lpInBuffer, + UInt32 nInBufferSize, + out REPARSE_DATA_BUFFER lpOutBuffer, + UInt32 nOutBufferSize, + out UInt32 lpBytesReturned, + IntPtr lpOverlapped); + + // Used by CreateJunctionPoint() + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool DeviceIoControl( + SafeFileHandle hDevice, + UInt32 dwIoControlCode, + REPARSE_DATA_BUFFER lpInBuffer, + UInt32 nInBufferSize, + IntPtr lpOutBuffer, + UInt32 nOutBufferSize, + out UInt32 lpBytesReturned, + IntPtr lpOverlapped); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool GetVolumePathName( + string lpszFileName, + StringBuilder lpszVolumePathName, + ref UInt32 cchBufferLength); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern IntPtr FindFirstFileNameW( + string lpFileName, + UInt32 dwFlags, + ref UInt32 StringLength, + StringBuilder LinkName); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool FindNextFileNameW( + IntPtr hFindStream, + ref UInt32 StringLength, + StringBuilder LinkName); + + [DllImport("kernel32.dll", SetLastError = true)] + private static extern bool FindClose( + IntPtr hFindFile); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool RemoveDirectory( + string lpPathName); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool DeleteFile( + string lpFileName); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool CreateSymbolicLink( + string lpSymlinkFileName, + string lpTargetFileName, + UInt32 dwFlags); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Auto)] + private static extern bool CreateHardLink( + string lpFileName, + string lpExistingFileName, + IntPtr lpSecurityAttributes); + + public static LinkInfo GetLinkInfo(string linkPath) + { + FileAttributes attr = File.GetAttributes(linkPath); + if (attr.HasFlag(FileAttributes.ReparsePoint)) + return GetReparsePointInfo(linkPath); + + if (!attr.HasFlag(FileAttributes.Directory)) + return GetHardLinkInfo(linkPath); + + return null; + } + + public static void DeleteLink(string linkPath) + { + bool success; + FileAttributes attr = File.GetAttributes(linkPath); + if (attr.HasFlag(FileAttributes.Directory)) + { + success = RemoveDirectory(linkPath); + } + else + { + success = DeleteFile(linkPath); + } + + if (!success) + throw new LinkUtilWin32Exception(String.Format("Failed to delete link at {0}", linkPath)); + } + + public static void CreateLink(string linkPath, String linkTarget, LinkType linkType) + { + switch (linkType) + { + case LinkType.SymbolicLink: + UInt32 linkFlags; + FileAttributes attr = File.GetAttributes(linkTarget); + if (attr.HasFlag(FileAttributes.Directory)) + linkFlags = SYMBOLIC_LINK_FLAG_DIRECTORY; + else + linkFlags = SYMBOLIC_LINK_FLAG_FILE; + + if (!CreateSymbolicLink(linkPath, linkTarget, linkFlags)) + throw new LinkUtilWin32Exception(String.Format("CreateSymbolicLink({0}, {1}, {2}) failed", linkPath, linkTarget, linkFlags)); + break; + case LinkType.JunctionPoint: + CreateJunctionPoint(linkPath, linkTarget); + break; + case LinkType.HardLink: + if (!CreateHardLink(linkPath, linkTarget, IntPtr.Zero)) + throw new LinkUtilWin32Exception(String.Format("CreateHardLink({0}, {1}) failed", linkPath, linkTarget)); + break; + } + } + + private static LinkInfo GetHardLinkInfo(string linkPath) + { + UInt32 maxPath = 260; + List<string> result = new List<string>(); + + StringBuilder sb = new StringBuilder((int)maxPath); + UInt32 stringLength = maxPath; + if (!GetVolumePathName(linkPath, sb, ref stringLength)) + throw new LinkUtilWin32Exception("GetVolumePathName() failed"); + string volume = sb.ToString(); + + stringLength = maxPath; + IntPtr findHandle = FindFirstFileNameW(linkPath, 0, ref stringLength, sb); + if (findHandle.ToInt64() != INVALID_HANDLE_VALUE) + { + try + { + do + { + string hardLinkPath = sb.ToString(); + if (hardLinkPath.StartsWith("\\")) + hardLinkPath = hardLinkPath.Substring(1, hardLinkPath.Length - 1); + + result.Add(Path.Combine(volume, hardLinkPath)); + stringLength = maxPath; + + } while (FindNextFileNameW(findHandle, ref stringLength, sb)); + } + finally + { + FindClose(findHandle); + } + } + + if (result.Count > 1) + return new LinkInfo + { + Type = LinkType.HardLink, + HardTargets = result.ToArray() + }; + + return null; + } + + private static LinkInfo GetReparsePointInfo(string linkPath) + { + SafeFileHandle fileHandle = CreateFile( + linkPath, + FileAccess.Read, + FileShare.None, + IntPtr.Zero, + FileMode.Open, + FILE_FLAG_OPEN_REPARSE_POINT | FILE_FLAG_BACKUP_SEMANTICS, + IntPtr.Zero); + + if (fileHandle.IsInvalid) + throw new LinkUtilWin32Exception(String.Format("CreateFile({0}) failed", linkPath)); + + REPARSE_DATA_BUFFER buffer = new REPARSE_DATA_BUFFER(); + UInt32 bytesReturned; + try + { + if (!DeviceIoControl( + fileHandle, + FSCTL_GET_REPARSE_POINT, + IntPtr.Zero, + 0, + out buffer, + MAXIMUM_REPARSE_DATA_BUFFER_SIZE, + out bytesReturned, + IntPtr.Zero)) + throw new LinkUtilWin32Exception(String.Format("DeviceIoControl() failed for file at {0}", linkPath)); + } + finally + { + fileHandle.Dispose(); + } + + bool isRelative = false; + int pathOffset = 0; + LinkType linkType; + if (buffer.ReparseTag == IO_REPARSE_TAG_SYMLINK) + { + UInt32 bufferFlags = Convert.ToUInt32(buffer.PathBuffer[0]) + Convert.ToUInt32(buffer.PathBuffer[1]); + if (bufferFlags == SYMLINK_FLAG_RELATIVE) + isRelative = true; + pathOffset = 2; + linkType = LinkType.SymbolicLink; + } + else if (buffer.ReparseTag == IO_REPARSE_TAG_MOUNT_POINT) + { + linkType = LinkType.JunctionPoint; + } + else + { + string errorMessage = String.Format("Invalid Reparse Tag: {0}", buffer.ReparseTag.ToString()); + throw new Exception(errorMessage); + } + + string printName = new string(buffer.PathBuffer, + (int)(buffer.PrintNameOffset / SIZE_OF_WCHAR) + pathOffset, + (int)(buffer.PrintNameLength / SIZE_OF_WCHAR)); + string substituteName = new string(buffer.PathBuffer, + (int)(buffer.SubstituteNameOffset / SIZE_OF_WCHAR) + pathOffset, + (int)(buffer.SubstituteNameLength / SIZE_OF_WCHAR)); + + // TODO: should we check for \?\UNC\server for convert it to the NT style \\server path + // Remove the leading Windows object directory \?\ from the path if present + string targetPath = substituteName; + if (targetPath.StartsWith("\\??\\")) + targetPath = targetPath.Substring(4, targetPath.Length - 4); + + string absolutePath = targetPath; + if (isRelative) + absolutePath = Path.GetFullPath(Path.Combine(new FileInfo(linkPath).Directory.FullName, targetPath)); + + return new LinkInfo + { + Type = linkType, + PrintName = printName, + SubstituteName = substituteName, + AbsolutePath = absolutePath, + TargetPath = targetPath + }; + } + + private static void CreateJunctionPoint(string linkPath, string linkTarget) + { + // We need to create the link as a dir beforehand + Directory.CreateDirectory(linkPath); + SafeFileHandle fileHandle = CreateFile( + linkPath, + FileAccess.Write, + FileShare.Read | FileShare.Write | FileShare.None, + IntPtr.Zero, + FileMode.Open, + FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT, + IntPtr.Zero); + + if (fileHandle.IsInvalid) + throw new LinkUtilWin32Exception(String.Format("CreateFile({0}) failed", linkPath)); + + try + { + string substituteName = "\\??\\" + Path.GetFullPath(linkTarget); + string printName = linkTarget; + + REPARSE_DATA_BUFFER buffer = new REPARSE_DATA_BUFFER(); + buffer.SubstituteNameOffset = 0; + buffer.SubstituteNameLength = (UInt16)(substituteName.Length * SIZE_OF_WCHAR); + buffer.PrintNameOffset = (UInt16)(buffer.SubstituteNameLength + 2); + buffer.PrintNameLength = (UInt16)(printName.Length * SIZE_OF_WCHAR); + + buffer.ReparseTag = IO_REPARSE_TAG_MOUNT_POINT; + buffer.ReparseDataLength = (UInt16)(buffer.SubstituteNameLength + buffer.PrintNameLength + 12); + buffer.PathBuffer = new char[MAXIMUM_REPARSE_DATA_BUFFER_SIZE]; + + byte[] unicodeBytes = Encoding.Unicode.GetBytes(substituteName + "\0" + printName); + char[] pathBuffer = Encoding.Unicode.GetChars(unicodeBytes); + Array.Copy(pathBuffer, buffer.PathBuffer, pathBuffer.Length); + + UInt32 bytesReturned; + if (!DeviceIoControl( + fileHandle, + FSCTL_SET_REPARSE_POINT, + buffer, + (UInt32)(buffer.ReparseDataLength + 8), + IntPtr.Zero, 0, + out bytesReturned, + IntPtr.Zero)) + throw new LinkUtilWin32Exception(String.Format("DeviceIoControl() failed to create junction point at {0} to {1}", linkPath, linkTarget)); + } + finally + { + fileHandle.Dispose(); + } + } + } +} +'@ + + # FUTURE: find a better way to get the _ansible_remote_tmp variable + $original_tmp = $env:TMP + $original_lib = $env:LIB + + $remote_tmp = $original_tmp + $module_params = Get-Variable -Name complex_args -ErrorAction SilentlyContinue + if ($module_params) { + if ($module_params.Value.ContainsKey("_ansible_remote_tmp") ) { + $remote_tmp = $module_params.Value["_ansible_remote_tmp"] + $remote_tmp = [System.Environment]::ExpandEnvironmentVariables($remote_tmp) + } + } + + $env:TMP = $remote_tmp + $env:LIB = $null + Add-Type -TypeDefinition $link_util + $env:TMP = $original_tmp + $env:LIB = $original_lib + + # enable the SeBackupPrivilege if it is disabled + $state = Get-AnsiblePrivilege -Name SeBackupPrivilege + if ($state -eq $false) { + Set-AnsiblePrivilege -Name SeBackupPrivilege -Value $true + } +} + +Function Get-Link($link_path) { + $link_info = [Ansible.LinkUtil]::GetLinkInfo($link_path) + return $link_info +} + +Function Remove-Link($link_path) { + [Ansible.LinkUtil]::DeleteLink($link_path) +} + +Function New-Link($link_path, $link_target, $link_type) { + if (-not (Test-Path -LiteralPath $link_target)) { + throw "link_target '$link_target' does not exist, cannot create link" + } + + switch ($link_type) { + "link" { + $type = [Ansible.LinkType]::SymbolicLink + } + "junction" { + if (Test-Path -LiteralPath $link_target -PathType Leaf) { + throw "cannot set the target for a junction point to a file" + } + $type = [Ansible.LinkType]::JunctionPoint + } + "hard" { + if (Test-Path -LiteralPath $link_target -PathType Container) { + throw "cannot set the target for a hard link to a directory" + } + $type = [Ansible.LinkType]::HardLink + } + default { throw "invalid link_type option $($link_type): expecting link, junction, hard" } + } + [Ansible.LinkUtil]::CreateLink($link_path, $link_target, $type) +} + +# this line must stay at the bottom to ensure all defined module parts are exported +Export-ModuleMember -Alias * -Function * -Cmdlet * diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.PrivilegeUtil.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.PrivilegeUtil.psm1 new file mode 100644 index 0000000..78f0d64 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.PrivilegeUtil.psm1 @@ -0,0 +1,83 @@ +# Copyright (c) 2018 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +#AnsibleRequires -CSharpUtil Ansible.Privilege + +Function Get-AnsiblePrivilege { + <# + .SYNOPSIS + Get the status of a privilege for the current process. This returns + $true - the privilege is enabled + $false - the privilege is disabled + $null - the privilege is removed from the token + + If Name is not a valid privilege name, this will throw an + ArgumentException. + + .EXAMPLE + Get-AnsiblePrivilege -Name SeDebugPrivilege + #> + [CmdletBinding()] + param( + [Parameter(Mandatory = $true)][String]$Name + ) + + if (-not [Ansible.Privilege.PrivilegeUtil]::CheckPrivilegeName($Name)) { + throw [System.ArgumentException] "Invalid privilege name '$Name'" + } + + $process_token = [Ansible.Privilege.PrivilegeUtil]::GetCurrentProcess() + $privilege_info = [Ansible.Privilege.PrivilegeUtil]::GetAllPrivilegeInfo($process_token) + if ($privilege_info.ContainsKey($Name)) { + $status = $privilege_info.$Name + return $status.HasFlag([Ansible.Privilege.PrivilegeAttributes]::Enabled) + } + else { + return $null + } +} + +Function Set-AnsiblePrivilege { + <# + .SYNOPSIS + Enables/Disables a privilege on the current process' token. If a privilege + has been removed from the process token, this will throw an + InvalidOperationException. + + .EXAMPLE + # enable a privilege + Set-AnsiblePrivilege -Name SeCreateSymbolicLinkPrivilege -Value $true + + # disable a privilege + Set-AnsiblePrivilege -Name SeCreateSymbolicLinkPrivilege -Value $false + #> + [CmdletBinding(SupportsShouldProcess)] + param( + [Parameter(Mandatory = $true)][String]$Name, + [Parameter(Mandatory = $true)][bool]$Value + ) + + $action = switch ($Value) { + $true { "Enable" } + $false { "Disable" } + } + + $current_state = Get-AnsiblePrivilege -Name $Name + if ($current_state -eq $Value) { + return # no change needs to occur + } + elseif ($null -eq $current_state) { + # once a privilege is removed from a token we cannot do anything with it + throw [System.InvalidOperationException] "Cannot $($action.ToLower()) the privilege '$Name' as it has been removed from the token" + } + + $process_token = [Ansible.Privilege.PrivilegeUtil]::GetCurrentProcess() + if ($PSCmdlet.ShouldProcess($Name, "$action the privilege $Name")) { + $new_state = New-Object -TypeName 'System.Collections.Generic.Dictionary`2[[System.String], [System.Nullable`1[System.Boolean]]]' + $new_state.Add($Name, $Value) + [Ansible.Privilege.PrivilegeUtil]::SetTokenPrivileges($process_token, $new_state) > $null + } +} + +Export-ModuleMember -Function Get-AnsiblePrivilege, Set-AnsiblePrivilege + diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.SID.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.SID.psm1 new file mode 100644 index 0000000..d1f4b62 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.SID.psm1 @@ -0,0 +1,99 @@ +# Copyright (c) 2017 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +Function Convert-FromSID($sid) { + # Converts a SID to a Down-Level Logon name in the form of DOMAIN\UserName + # If the SID is for a local user or group then DOMAIN would be the server + # name. + + $account_object = New-Object System.Security.Principal.SecurityIdentifier($sid) + try { + $nt_account = $account_object.Translate([System.Security.Principal.NTAccount]) + } + catch { + Fail-Json -obj @{} -message "failed to convert sid '$sid' to a logon name: $($_.Exception.Message)" + } + + return $nt_account.Value +} + +Function Convert-ToSID { + [Diagnostics.CodeAnalysis.SuppressMessageAttribute("PSAvoidUsingEmptyCatchBlock", "", + Justification = "We don't care if converting to a SID fails, just that it failed or not")] + param($account_name) + # Converts an account name to a SID, it can take in the following forms + # SID: Will just return the SID value that was passed in + # UPN: + # principal@domain (Domain users only) + # Down-Level Login Name + # DOMAIN\principal (Domain) + # SERVERNAME\principal (Local) + # .\principal (Local) + # NT AUTHORITY\SYSTEM (Local Service Accounts) + # Login Name + # principal (Local/Local Service Accounts) + + try { + $sid = New-Object -TypeName System.Security.Principal.SecurityIdentifier -ArgumentList $account_name + return $sid.Value + } + catch {} + + if ($account_name -like "*\*") { + $account_name_split = $account_name -split "\\" + if ($account_name_split[0] -eq ".") { + $domain = $env:COMPUTERNAME + } + else { + $domain = $account_name_split[0] + } + $username = $account_name_split[1] + } + else { + $domain = $null + $username = $account_name + } + + if ($domain) { + # searching for a local group with the servername prefixed will fail, + # need to check for this situation and only use NTAccount(String) + if ($domain -eq $env:COMPUTERNAME) { + $adsi = [ADSI]("WinNT://$env:COMPUTERNAME,computer") + $group = $adsi.psbase.children | Where-Object { $_.schemaClassName -eq "group" -and $_.Name -eq $username } + } + else { + $group = $null + } + if ($group) { + $account = New-Object System.Security.Principal.NTAccount($username) + } + else { + $account = New-Object System.Security.Principal.NTAccount($domain, $username) + } + } + else { + # when in a domain NTAccount(String) will favour domain lookups check + # if username is a local user and explicitly search on the localhost for + # that account + $adsi = [ADSI]("WinNT://$env:COMPUTERNAME,computer") + $user = $adsi.psbase.children | Where-Object { $_.schemaClassName -eq "user" -and $_.Name -eq $username } + if ($user) { + $account = New-Object System.Security.Principal.NTAccount($env:COMPUTERNAME, $username) + } + else { + $account = New-Object System.Security.Principal.NTAccount($username) + } + } + + try { + $account_sid = $account.Translate([System.Security.Principal.SecurityIdentifier]) + } + catch { + Fail-Json @{} "account_name $account_name is not a valid account, cannot get SID: $($_.Exception.Message)" + } + + return $account_sid.Value +} + +# this line must stay at the bottom to ensure all defined module parts are exported +Export-ModuleMember -Alias * -Function * -Cmdlet * diff --git a/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.WebRequest.psm1 b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.WebRequest.psm1 new file mode 100644 index 0000000..b59ba72 --- /dev/null +++ b/lib/ansible/module_utils/powershell/Ansible.ModuleUtils.WebRequest.psm1 @@ -0,0 +1,530 @@ +# Copyright (c) 2019 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +Function Get-AnsibleWebRequest { + <# + .SYNOPSIS + Creates a System.Net.WebRequest object based on common URL module options in Ansible. + + .DESCRIPTION + Will create a WebRequest based on common input options within Ansible. This can be used manually or with + Invoke-WithWebRequest. + + .PARAMETER Uri + The URI to create the web request for. + + .PARAMETER Method + The protocol method to use, if omitted, will use the default value for the URI protocol specified. + + .PARAMETER FollowRedirects + Whether to follow redirect reponses. This is only valid when using a HTTP URI. + all - Will follow all redirects + none - Will follow no redirects + safe - Will only follow redirects when GET or HEAD is used as the Method + + .PARAMETER Headers + A hashtable or dictionary of header values to set on the request. This is only valid for a HTTP URI. + + .PARAMETER HttpAgent + A string to set for the 'User-Agent' header. This is only valid for a HTTP URI. + + .PARAMETER MaximumRedirection + The maximum number of redirections that will be followed. This is only valid for a HTTP URI. + + .PARAMETER Timeout + The timeout in seconds that defines how long to wait until the request times out. + + .PARAMETER ValidateCerts + Whether to validate SSL certificates, default to True. + + .PARAMETER ClientCert + The path to PFX file to use for X509 authentication. This is only valid for a HTTP URI. This path can either + be a filesystem path (C:\folder\cert.pfx) or a PSPath to a credential (Cert:\CurrentUser\My\<thumbprint>). + + .PARAMETER ClientCertPassword + The password for the PFX certificate if required. This is only valid for a HTTP URI. + + .PARAMETER ForceBasicAuth + Whether to set the Basic auth header on the first request instead of when required. This is only valid for a + HTTP URI. + + .PARAMETER UrlUsername + The username to use for authenticating with the target. + + .PARAMETER UrlPassword + The password to use for authenticating with the target. + + .PARAMETER UseDefaultCredential + Whether to use the current user's credentials if available. This will only work when using Become, using SSH with + password auth, or WinRM with CredSSP or Kerberos with credential delegation. + + .PARAMETER UseProxy + Whether to use the default proxy defined in IE (WinINet) for the user or set no proxy at all. This should not + be set to True when ProxyUrl is also defined. + + .PARAMETER ProxyUrl + An explicit proxy server to use for the request instead of relying on the default proxy in IE. This is only + valid for a HTTP URI. + + .PARAMETER ProxyUsername + An optional username to use for proxy authentication. + + .PARAMETER ProxyPassword + The password for ProxyUsername. + + .PARAMETER ProxyUseDefaultCredential + Whether to use the current user's credentials for proxy authentication if available. This will only work when + using Become, using SSH with password auth, or WinRM with CredSSP or Kerberos with credential delegation. + + .PARAMETER Module + The AnsibleBasic module that can be used as a backup parameter source or a way to return warnings back to the + Ansible controller. + + .EXAMPLE + $spec = @{ + options = @{} + } + $module = Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWebRequestSpec)) + + $web_request = Get-AnsibleWebRequest -Module $module + #> + [CmdletBinding()] + [OutputType([System.Net.WebRequest])] + Param ( + [Alias("url")] + [System.Uri] + $Uri, + + [System.String] + $Method, + + [Alias("follow_redirects")] + [ValidateSet("all", "none", "safe")] + [System.String] + $FollowRedirects = "safe", + + [System.Collections.IDictionary] + $Headers, + + [Alias("http_agent")] + [System.String] + $HttpAgent = "ansible-httpget", + + [Alias("maximum_redirection")] + [System.Int32] + $MaximumRedirection = 50, + + [System.Int32] + $Timeout = 30, + + [Alias("validate_certs")] + [System.Boolean] + $ValidateCerts = $true, + + # Credential params + [Alias("client_cert")] + [System.String] + $ClientCert, + + [Alias("client_cert_password")] + [System.String] + $ClientCertPassword, + + [Alias("force_basic_auth")] + [Switch] + $ForceBasicAuth, + + [Alias("url_username")] + [System.String] + $UrlUsername, + + [Alias("url_password")] + [System.String] + $UrlPassword, + + [Alias("use_default_credential")] + [Switch] + $UseDefaultCredential, + + # Proxy params + [Alias("use_proxy")] + [System.Boolean] + $UseProxy = $true, + + [Alias("proxy_url")] + [System.String] + $ProxyUrl, + + [Alias("proxy_username")] + [System.String] + $ProxyUsername, + + [Alias("proxy_password")] + [System.String] + $ProxyPassword, + + [Alias("proxy_use_default_credential")] + [Switch] + $ProxyUseDefaultCredential, + + [ValidateScript({ $_.GetType().FullName -eq 'Ansible.Basic.AnsibleModule' })] + [System.Object] + $Module + ) + + # Set module options for parameters unless they were explicitly passed in. + if ($Module) { + foreach ($param in $PSCmdlet.MyInvocation.MyCommand.Parameters.GetEnumerator()) { + if ($PSBoundParameters.ContainsKey($param.Key)) { + # Was set explicitly we want to use that value + continue + } + + foreach ($alias in @($Param.Key) + $param.Value.Aliases) { + if ($Module.Params.ContainsKey($alias)) { + $var_value = $Module.Params.$alias -as $param.Value.ParameterType + Set-Variable -Name $param.Key -Value $var_value + break + } + } + } + } + + # Disable certificate validation if requested + # FUTURE: set this on ServerCertificateValidationCallback of the HttpWebRequest once .NET 4.5 is the minimum + if (-not $ValidateCerts) { + [System.Net.ServicePointManager]::ServerCertificateValidationCallback = { $true } + } + + # Enable TLS1.1/TLS1.2 if they're available but disabled (eg. .NET 4.5) + $security_protocols = [System.Net.ServicePointManager]::SecurityProtocol -bor [System.Net.SecurityProtocolType]::SystemDefault + if ([System.Net.SecurityProtocolType].GetMember("Tls11").Count -gt 0) { + $security_protocols = $security_protocols -bor [System.Net.SecurityProtocolType]::Tls11 + } + if ([System.Net.SecurityProtocolType].GetMember("Tls12").Count -gt 0) { + $security_protocols = $security_protocols -bor [System.Net.SecurityProtocolType]::Tls12 + } + [System.Net.ServicePointManager]::SecurityProtocol = $security_protocols + + $web_request = [System.Net.WebRequest]::Create($Uri) + if ($Method) { + $web_request.Method = $Method + } + $web_request.Timeout = $Timeout * 1000 + + if ($UseDefaultCredential -and $web_request -is [System.Net.HttpWebRequest]) { + $web_request.UseDefaultCredentials = $true + } + elseif ($UrlUsername) { + if ($ForceBasicAuth) { + $auth_value = [System.Convert]::ToBase64String([System.Text.Encoding]::ASCII.GetBytes(("{0}:{1}" -f $UrlUsername, $UrlPassword))) + $web_request.Headers.Add("Authorization", "Basic $auth_value") + } + else { + $credential = New-Object -TypeName System.Net.NetworkCredential -ArgumentList $UrlUsername, $UrlPassword + $web_request.Credentials = $credential + } + } + + if ($ClientCert) { + # Expecting either a filepath or PSPath (Cert:\CurrentUser\My\<thumbprint>) + $cert = Get-Item -LiteralPath $ClientCert -ErrorAction SilentlyContinue + if ($null -eq $cert) { + Write-Error -Message "Client certificate '$ClientCert' does not exist" -Category ObjectNotFound + return + } + + $crypto_ns = 'System.Security.Cryptography.X509Certificates' + if ($cert.PSProvider.Name -ne 'Certificate') { + try { + $cert = New-Object -TypeName "$crypto_ns.X509Certificate2" -ArgumentList @( + $ClientCert, $ClientCertPassword + ) + } + catch [System.Security.Cryptography.CryptographicException] { + Write-Error -Message "Failed to read client certificate at '$ClientCert'" -Exception $_.Exception -Category SecurityError + return + } + } + $web_request.ClientCertificates = New-Object -TypeName "$crypto_ns.X509Certificate2Collection" -ArgumentList @( + $cert + ) + } + + if (-not $UseProxy) { + $proxy = $null + } + elseif ($ProxyUrl) { + $proxy = New-Object -TypeName System.Net.WebProxy -ArgumentList $ProxyUrl, $true + } + else { + $proxy = $web_request.Proxy + } + + # $web_request.Proxy may return $null for a FTP web request. We only set the credentials if we have an actual + # proxy to work with, otherwise just ignore the credentials property. + if ($null -ne $proxy) { + if ($ProxyUseDefaultCredential) { + # Weird hack, $web_request.Proxy returns an IWebProxy object which only guarantees the Credentials + # property. We cannot set UseDefaultCredentials so we just set the Credentials to the + # DefaultCredentials in the CredentialCache which does the same thing. + $proxy.Credentials = [System.Net.CredentialCache]::DefaultCredentials + } + elseif ($ProxyUsername) { + $proxy.Credentials = New-Object -TypeName System.Net.NetworkCredential -ArgumentList @( + $ProxyUsername, $ProxyPassword + ) + } + else { + $proxy.Credentials = $null + } + } + + $web_request.Proxy = $proxy + + # Some parameters only apply when dealing with a HttpWebRequest + if ($web_request -is [System.Net.HttpWebRequest]) { + if ($Headers) { + foreach ($header in $Headers.GetEnumerator()) { + switch ($header.Key) { + Accept { $web_request.Accept = $header.Value } + Connection { $web_request.Connection = $header.Value } + Content-Length { $web_request.ContentLength = $header.Value } + Content-Type { $web_request.ContentType = $header.Value } + Expect { $web_request.Expect = $header.Value } + Date { $web_request.Date = $header.Value } + Host { $web_request.Host = $header.Value } + If-Modified-Since { $web_request.IfModifiedSince = $header.Value } + Range { $web_request.AddRange($header.Value) } + Referer { $web_request.Referer = $header.Value } + Transfer-Encoding { + $web_request.SendChunked = $true + $web_request.TransferEncoding = $header.Value + } + User-Agent { continue } + default { $web_request.Headers.Add($header.Key, $header.Value) } + } + } + } + + # For backwards compatibility we need to support setting the User-Agent if the header was set in the task. + # We just need to make sure that if an explicit http_agent module was set then that takes priority. + if ($Headers -and $Headers.ContainsKey("User-Agent")) { + if ($HttpAgent -eq $ansible_web_request_options.http_agent.default) { + $HttpAgent = $Headers['User-Agent'] + } + elseif ($null -ne $Module) { + $Module.Warn("The 'User-Agent' header and the 'http_agent' was set, using the 'http_agent' for web request") + } + } + $web_request.UserAgent = $HttpAgent + + switch ($FollowRedirects) { + none { $web_request.AllowAutoRedirect = $false } + safe { + if ($web_request.Method -in @("GET", "HEAD")) { + $web_request.AllowAutoRedirect = $true + } + else { + $web_request.AllowAutoRedirect = $false + } + } + all { $web_request.AllowAutoRedirect = $true } + } + + if ($MaximumRedirection -eq 0) { + $web_request.AllowAutoRedirect = $false + } + else { + $web_request.MaximumAutomaticRedirections = $MaximumRedirection + } + } + + return $web_request +} + +Function Invoke-WithWebRequest { + <# + .SYNOPSIS + Invokes a ScriptBlock with the WebRequest. + + .DESCRIPTION + Invokes the ScriptBlock and handle extra information like accessing the response stream, closing those streams + safely as well as setting common module return values. + + .PARAMETER Module + The Ansible.Basic module to set the return values for. This will set the following return values; + elapsed - The total time, in seconds, that it took to send the web request and process the response + msg - The human readable description of the response status code + status_code - An int that is the response status code + + .PARAMETER Request + The System.Net.WebRequest to call. This can either be manually crafted or created with Get-AnsibleWebRequest. + + .PARAMETER Script + The ScriptBlock to invoke during the web request. This ScriptBlock should take in the params + Param ([System.Net.WebResponse]$Response, [System.IO.Stream]$Stream) + + This scriptblock should manage the response based on what it need to do. + + .PARAMETER Body + An optional Stream to send to the target during the request. + + .PARAMETER IgnoreBadResponse + By default a WebException will be raised for a non 2xx status code and the Script will not be invoked. This + parameter can be set to process all responses regardless of the status code. + + .EXAMPLE Basic module that downloads a file + $spec = @{ + options = @{ + path = @{ type = "path"; required = $true } + } + } + $module = Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWebRequestSpec)) + + $web_request = Get-AnsibleWebRequest -Module $module + + Invoke-WithWebRequest -Module $module -Request $web_request -Script { + Param ([System.Net.WebResponse]$Response, [System.IO.Stream]$Stream) + + $fs = [System.IO.File]::Create($module.Params.path) + try { + $Stream.CopyTo($fs) + $fs.Flush() + } finally { + $fs.Dispose() + } + } + #> + [CmdletBinding()] + param ( + [Parameter(Mandatory = $true)] + [System.Object] + [ValidateScript({ $_.GetType().FullName -eq 'Ansible.Basic.AnsibleModule' })] + $Module, + + [Parameter(Mandatory = $true)] + [System.Net.WebRequest] + $Request, + + [Parameter(Mandatory = $true)] + [ScriptBlock] + $Script, + + [AllowNull()] + [System.IO.Stream] + $Body, + + [Switch] + $IgnoreBadResponse + ) + + $start = Get-Date + if ($null -ne $Body) { + $request_st = $Request.GetRequestStream() + try { + $Body.CopyTo($request_st) + $request_st.Flush() + } + finally { + $request_st.Close() + } + } + + try { + try { + $web_response = $Request.GetResponse() + } + catch [System.Net.WebException] { + # A WebResponse with a status code not in the 200 range will raise a WebException. We check if the + # exception raised contains the actual response and continue on if IgnoreBadResponse is set. We also + # make sure we set the status_code return value on the Module object if possible + + if ($_.Exception.PSObject.Properties.Name -match "Response") { + $web_response = $_.Exception.Response + + if (-not $IgnoreBadResponse -or $null -eq $web_response) { + $Module.Result.msg = $_.Exception.StatusDescription + $Module.Result.status_code = $_.Exception.Response.StatusCode + throw $_ + } + } + else { + throw $_ + } + } + + if ($Request.RequestUri.IsFile) { + # A FileWebResponse won't have these properties set + $Module.Result.msg = "OK" + $Module.Result.status_code = 200 + } + else { + $Module.Result.msg = $web_response.StatusDescription + $Module.Result.status_code = $web_response.StatusCode + } + + $response_stream = $web_response.GetResponseStream() + try { + # Invoke the ScriptBlock and pass in WebResponse and ResponseStream + &$Script -Response $web_response -Stream $response_stream + } + finally { + $response_stream.Dispose() + } + } + finally { + if ($web_response) { + $web_response.Close() + } + $Module.Result.elapsed = ((Get-date) - $start).TotalSeconds + } +} + +Function Get-AnsibleWebRequestSpec { + <# + .SYNOPSIS + Used by modules to get the argument spec fragment for AnsibleModule. + + .EXAMPLES + $spec = @{ + options = @{} + } + $module = [Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWebRequestSpec)) + #> + @{ options = $ansible_web_request_options } +} + +# See lib/ansible/plugins/doc_fragments/url_windows.py +# Kept here for backwards compat as this variable was added in Ansible 2.9. Ultimately this util should be removed +# once the deprecation period has been added. +$ansible_web_request_options = @{ + method = @{ type = "str" } + follow_redirects = @{ type = "str"; choices = @("all", "none", "safe"); default = "safe" } + headers = @{ type = "dict" } + http_agent = @{ type = "str"; default = "ansible-httpget" } + maximum_redirection = @{ type = "int"; default = 50 } + timeout = @{ type = "int"; default = 30 } # Was defaulted to 10 in win_get_url but 30 in win_uri so we use 30 + validate_certs = @{ type = "bool"; default = $true } + + # Credential options + client_cert = @{ type = "str" } + client_cert_password = @{ type = "str"; no_log = $true } + force_basic_auth = @{ type = "bool"; default = $false } + url_username = @{ type = "str" } + url_password = @{ type = "str"; no_log = $true } + use_default_credential = @{ type = "bool"; default = $false } + + # Proxy options + use_proxy = @{ type = "bool"; default = $true } + proxy_url = @{ type = "str" } + proxy_username = @{ type = "str" } + proxy_password = @{ type = "str"; no_log = $true } + proxy_use_default_credential = @{ type = "bool"; default = $false } +} + +$export_members = @{ + Function = "Get-AnsibleWebRequest", "Get-AnsibleWebRequestSpec", "Invoke-WithWebRequest" + Variable = "ansible_web_request_options" +} +Export-ModuleMember @export_members diff --git a/lib/ansible/module_utils/powershell/__init__.py b/lib/ansible/module_utils/powershell/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/powershell/__init__.py diff --git a/lib/ansible/module_utils/pycompat24.py b/lib/ansible/module_utils/pycompat24.py new file mode 100644 index 0000000..c398427 --- /dev/null +++ b/lib/ansible/module_utils/pycompat24.py @@ -0,0 +1,91 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2016, Toshio Kuratomi <tkuratomi@ansible.com> +# Copyright (c) 2015, Marius Gedminas +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import sys + + +def get_exception(): + """Get the current exception. + + This code needs to work on Python 2.4 through 3.x, so we cannot use + "except Exception, e:" (SyntaxError on Python 3.x) nor + "except Exception as e:" (SyntaxError on Python 2.4-2.5). + Instead we must use :: + + except Exception: + e = get_exception() + + """ + return sys.exc_info()[1] + + +try: + # Python 2.6+ + from ast import literal_eval +except ImportError: + # a replacement for literal_eval that works with python 2.4. from: + # https://mail.python.org/pipermail/python-list/2009-September/551880.html + # which is essentially a cut/paste from an earlier (2.6) version of python's + # ast.py + from compiler import ast, parse + from ansible.module_utils.six import binary_type, integer_types, string_types, text_type + + def literal_eval(node_or_string): # type: ignore[misc] + """ + Safely evaluate an expression node or a string containing a Python + expression. The string or node provided may only consist of the following + Python literal structures: strings, numbers, tuples, lists, dicts, booleans, + and None. + """ + _safe_names = {'None': None, 'True': True, 'False': False} + if isinstance(node_or_string, string_types): + node_or_string = parse(node_or_string, mode='eval') + if isinstance(node_or_string, ast.Expression): + node_or_string = node_or_string.node + + def _convert(node): + if isinstance(node, ast.Const) and isinstance(node.value, (text_type, binary_type, float, complex) + integer_types): + return node.value + elif isinstance(node, ast.Tuple): + return tuple(map(_convert, node.nodes)) + elif isinstance(node, ast.List): + return list(map(_convert, node.nodes)) + elif isinstance(node, ast.Dict): + return dict((_convert(k), _convert(v)) for k, v in node.items()) + elif isinstance(node, ast.Name): + if node.name in _safe_names: + return _safe_names[node.name] + elif isinstance(node, ast.UnarySub): + return -_convert(node.expr) # pylint: disable=invalid-unary-operand-type + raise ValueError('malformed string') + return _convert(node_or_string) + +__all__ = ('get_exception', 'literal_eval') diff --git a/lib/ansible/module_utils/service.py b/lib/ansible/module_utils/service.py new file mode 100644 index 0000000..d2cecd4 --- /dev/null +++ b/lib/ansible/module_utils/service.py @@ -0,0 +1,274 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) Ansible Inc, 2016 +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import glob +import os +import pickle +import platform +import select +import shlex +import subprocess +import traceback + +from ansible.module_utils.six import PY2, b +from ansible.module_utils._text import to_bytes, to_text + + +def sysv_is_enabled(name, runlevel=None): + ''' + This function will check if the service name supplied + is enabled in any of the sysv runlevels + + :arg name: name of the service to test for + :kw runlevel: runlevel to check (default: None) + ''' + if runlevel: + if not os.path.isdir('/etc/rc0.d/'): + return bool(glob.glob('/etc/init.d/rc%s.d/S??%s' % (runlevel, name))) + return bool(glob.glob('/etc/rc%s.d/S??%s' % (runlevel, name))) + else: + if not os.path.isdir('/etc/rc0.d/'): + return bool(glob.glob('/etc/init.d/rc?.d/S??%s' % name)) + return bool(glob.glob('/etc/rc?.d/S??%s' % name)) + + +def get_sysv_script(name): + ''' + This function will return the expected path for an init script + corresponding to the service name supplied. + + :arg name: name or path of the service to test for + ''' + if name.startswith('/'): + result = name + else: + result = '/etc/init.d/%s' % name + + return result + + +def sysv_exists(name): + ''' + This function will return True or False depending on + the existence of an init script corresponding to the service name supplied. + + :arg name: name of the service to test for + ''' + return os.path.exists(get_sysv_script(name)) + + +def get_ps(module, pattern): + ''' + Last resort to find a service by trying to match pattern to programs in memory + ''' + found = False + if platform.system() == 'SunOS': + flags = '-ef' + else: + flags = 'auxww' + psbin = module.get_bin_path('ps', True) + + (rc, psout, pserr) = module.run_command([psbin, flags]) + if rc == 0: + for line in psout.splitlines(): + if pattern in line: + # FIXME: should add logic to prevent matching 'self', though that should be extremely rare + found = True + break + return found + + +def fail_if_missing(module, found, service, msg=''): + ''' + This function will return an error or exit gracefully depending on check mode status + and if the service is missing or not. + + :arg module: is an AnsibleModule object, used for it's utility methods + :arg found: boolean indicating if services was found or not + :arg service: name of service + :kw msg: extra info to append to error/success msg when missing + ''' + if not found: + module.fail_json(msg='Could not find the requested service %s: %s' % (service, msg)) + + +def fork_process(): + ''' + This function performs the double fork process to detach from the + parent process and execute. + ''' + pid = os.fork() + + if pid == 0: + # Set stdin/stdout/stderr to /dev/null + fd = os.open(os.devnull, os.O_RDWR) + + # clone stdin/out/err + for num in range(3): + if fd != num: + os.dup2(fd, num) + + # close otherwise + if fd not in range(3): + os.close(fd) + + # Make us a daemon + pid = os.fork() + + # end if not in child + if pid > 0: + os._exit(0) + + # get new process session and detach + sid = os.setsid() + if sid == -1: + raise Exception("Unable to detach session while daemonizing") + + # avoid possible problems with cwd being removed + os.chdir("/") + + pid = os.fork() + if pid > 0: + os._exit(0) + + return pid + + +def daemonize(module, cmd): + ''' + Execute a command while detaching as a daemon, returns rc, stdout, and stderr. + + :arg module: is an AnsibleModule object, used for it's utility methods + :arg cmd: is a list or string representing the command and options to run + + This is complex because daemonization is hard for people. + What we do is daemonize a part of this module, the daemon runs the command, + picks up the return code and output, and returns it to the main process. + ''' + + # init some vars + chunk = 4096 # FIXME: pass in as arg? + errors = 'surrogate_or_strict' + + # start it! + try: + pipe = os.pipe() + pid = fork_process() + except OSError: + module.fail_json(msg="Error while attempting to fork: %s", exception=traceback.format_exc()) + except Exception as exc: + module.fail_json(msg=to_text(exc), exception=traceback.format_exc()) + + # we don't do any locking as this should be a unique module/process + if pid == 0: + os.close(pipe[0]) + + # if command is string deal with py2 vs py3 conversions for shlex + if not isinstance(cmd, list): + if PY2: + cmd = shlex.split(to_bytes(cmd, errors=errors)) + else: + cmd = shlex.split(to_text(cmd, errors=errors)) + + # make sure we always use byte strings + run_cmd = [] + for c in cmd: + run_cmd.append(to_bytes(c, errors=errors)) + + # execute the command in forked process + p = subprocess.Popen(run_cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=lambda: os.close(pipe[1])) + fds = [p.stdout, p.stderr] + + # loop reading output till its done + output = {p.stdout: b(""), p.stderr: b("")} + while fds: + rfd, wfd, efd = select.select(fds, [], fds, 1) + if (rfd + wfd + efd) or p.poll(): + for out in list(fds): + if out in rfd: + data = os.read(out.fileno(), chunk) + if not data: + fds.remove(out) + output[out] += b(data) + + # even after fds close, we might want to wait for pid to die + p.wait() + + # Return a pickled data of parent + return_data = pickle.dumps([p.returncode, to_text(output[p.stdout]), to_text(output[p.stderr])], protocol=pickle.HIGHEST_PROTOCOL) + os.write(pipe[1], to_bytes(return_data, errors=errors)) + + # clean up + os.close(pipe[1]) + os._exit(0) + + elif pid == -1: + module.fail_json(msg="Unable to fork, no exception thrown, probably due to lack of resources, check logs.") + + else: + # in parent + os.close(pipe[1]) + os.waitpid(pid, 0) + + # Grab response data after child finishes + return_data = b("") + while True: + rfd, wfd, efd = select.select([pipe[0]], [], [pipe[0]]) + if pipe[0] in rfd: + data = os.read(pipe[0], chunk) + if not data: + break + return_data += b(data) + + # Note: no need to specify encoding on py3 as this module sends the + # pickle to itself (thus same python interpreter so we aren't mixing + # py2 and py3) + return pickle.loads(to_bytes(return_data, errors=errors)) + + +def check_ps(module, pattern): + + # Set ps flags + if platform.system() == 'SunOS': + psflags = '-ef' + else: + psflags = 'auxww' + + # Find ps binary + psbin = module.get_bin_path('ps', True) + + (rc, out, err) = module.run_command('%s %s' % (psbin, psflags)) + # If rc is 0, set running as appropriate + if rc == 0: + for line in out.split('\n'): + if pattern in line: + return True + return False diff --git a/lib/ansible/module_utils/six/__init__.py b/lib/ansible/module_utils/six/__init__.py new file mode 100644 index 0000000..f2d41c8 --- /dev/null +++ b/lib/ansible/module_utils/six/__init__.py @@ -0,0 +1,1009 @@ +# This code is strewn with things that are not defined on Python3 (unicode, +# long, etc) but they are all shielded by version checks. This is also an +# upstream vendored file that we're not going to modify on our own +# pylint: disable=undefined-variable +# +# Copyright (c) 2010-2020 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Utilities for writing code that runs on Python 2 and 3""" + +from __future__ import absolute_import + +import functools +import itertools +import operator +import sys +import types + +# The following makes it easier for us to script updates of the bundled code. It is not part of +# upstream six +_BUNDLED_METADATA = {"pypi_name": "six", "version": "1.16.0"} + +__author__ = "Benjamin Peterson <benjamin@python.org>" +__version__ = "1.16.0" + + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +PY34 = sys.version_info[0:2] >= (3, 4) + +if PY3: + string_types = str, + integer_types = int, + class_types = type, + text_type = str + binary_type = bytes + + MAXSIZE = sys.maxsize +else: + string_types = basestring, + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + + if sys.platform.startswith("java"): + # Jython always uses 32 bits. + MAXSIZE = int((1 << 31) - 1) + else: + # It's possible to have sizeof(long) != sizeof(Py_ssize_t). + class X(object): + + def __len__(self): + return 1 << 31 + try: + len(X()) + except OverflowError: + # 32-bit + MAXSIZE = int((1 << 31) - 1) + else: + # 64-bit + MAXSIZE = int((1 << 63) - 1) + del X + +if PY34: + from importlib.util import spec_from_loader +else: + spec_from_loader = None + + +def _add_doc(func, doc): + """Add documentation to a function.""" + func.__doc__ = doc + + +def _import_module(name): + """Import module, returning the module after the last dot.""" + __import__(name) + return sys.modules[name] + + +class _LazyDescr(object): + + def __init__(self, name): + self.name = name + + def __get__(self, obj, tp): + result = self._resolve() + setattr(obj, self.name, result) # Invokes __set__. + try: + # This is a bit ugly, but it avoids running this again by + # removing this descriptor. + delattr(obj.__class__, self.name) + except AttributeError: + pass + return result + + +class MovedModule(_LazyDescr): + + def __init__(self, name, old, new=None): + super(MovedModule, self).__init__(name) + if PY3: + if new is None: + new = name + self.mod = new + else: + self.mod = old + + def _resolve(self): + return _import_module(self.mod) + + def __getattr__(self, attr): + _module = self._resolve() + value = getattr(_module, attr) + setattr(self, attr, value) + return value + + +class _LazyModule(types.ModuleType): + + def __init__(self, name): + super(_LazyModule, self).__init__(name) + self.__doc__ = self.__class__.__doc__ + + def __dir__(self): + attrs = ["__doc__", "__name__"] + attrs += [attr.name for attr in self._moved_attributes] + return attrs + + # Subclasses should override this + _moved_attributes = [] + + +class MovedAttribute(_LazyDescr): + + def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): + super(MovedAttribute, self).__init__(name) + if PY3: + if new_mod is None: + new_mod = name + self.mod = new_mod + if new_attr is None: + if old_attr is None: + new_attr = name + else: + new_attr = old_attr + self.attr = new_attr + else: + self.mod = old_mod + if old_attr is None: + old_attr = name + self.attr = old_attr + + def _resolve(self): + module = _import_module(self.mod) + return getattr(module, self.attr) + + +class _SixMetaPathImporter(object): + + """ + A meta path importer to import six.moves and its submodules. + + This class implements a PEP302 finder and loader. It should be compatible + with Python 2.5 and all existing versions of Python3 + """ + + def __init__(self, six_module_name): + self.name = six_module_name + self.known_modules = {} + + def _add_module(self, mod, *fullnames): + for fullname in fullnames: + self.known_modules[self.name + "." + fullname] = mod + + def _get_module(self, fullname): + return self.known_modules[self.name + "." + fullname] + + def find_module(self, fullname, path=None): + if fullname in self.known_modules: + return self + return None + + def find_spec(self, fullname, path, target=None): + if fullname in self.known_modules: + return spec_from_loader(fullname, self) + return None + + def __get_module(self, fullname): + try: + return self.known_modules[fullname] + except KeyError: + raise ImportError("This loader does not know module " + fullname) + + def load_module(self, fullname): + try: + # in case of a reload + return sys.modules[fullname] + except KeyError: + pass + mod = self.__get_module(fullname) + if isinstance(mod, MovedModule): + mod = mod._resolve() + else: + mod.__loader__ = self + sys.modules[fullname] = mod + return mod + + def is_package(self, fullname): + """ + Return true, if the named module is a package. + + We need this method to get correct spec objects with + Python 3.4 (see PEP451) + """ + return hasattr(self.__get_module(fullname), "__path__") + + def get_code(self, fullname): + """Return None + + Required, if is_package is implemented""" + self.__get_module(fullname) # eventually raises ImportError + return None + get_source = get_code # same as get_code + + def create_module(self, spec): + return self.load_module(spec.name) + + def exec_module(self, module): + pass + + +_importer = _SixMetaPathImporter(__name__) + + +class _MovedItems(_LazyModule): + + """Lazy loading of moved objects""" + __path__ = [] # mark as package + + +_moved_attributes = [ + MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), + MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), + MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), + MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), + MovedAttribute("intern", "__builtin__", "sys"), + MovedAttribute("map", "itertools", "builtins", "imap", "map"), + MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), + MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), + MovedAttribute("getoutput", "commands", "subprocess"), + MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), + MovedAttribute("reduce", "__builtin__", "functools"), + MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), + MovedAttribute("StringIO", "StringIO", "io"), + MovedAttribute("UserDict", "UserDict", "collections"), + MovedAttribute("UserList", "UserList", "collections"), + MovedAttribute("UserString", "UserString", "collections"), + MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), + MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), + MovedModule("builtins", "__builtin__"), + MovedModule("configparser", "ConfigParser"), + MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"), + MovedModule("copyreg", "copy_reg"), + MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), + MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"), + MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"), + MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), + MovedModule("http_cookies", "Cookie", "http.cookies"), + MovedModule("html_entities", "htmlentitydefs", "html.entities"), + MovedModule("html_parser", "HTMLParser", "html.parser"), + MovedModule("http_client", "httplib", "http.client"), + MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), + MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), + MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), + MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), + MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), + MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), + MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), + MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), + MovedModule("cPickle", "cPickle", "pickle"), + MovedModule("queue", "Queue"), + MovedModule("reprlib", "repr"), + MovedModule("socketserver", "SocketServer"), + MovedModule("_thread", "thread", "_thread"), + MovedModule("tkinter", "Tkinter"), + MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), + MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), + MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), + MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), + MovedModule("tkinter_tix", "Tix", "tkinter.tix"), + MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), + MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), + MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), + MovedModule("tkinter_colorchooser", "tkColorChooser", + "tkinter.colorchooser"), + MovedModule("tkinter_commondialog", "tkCommonDialog", + "tkinter.commondialog"), + MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), + MovedModule("tkinter_font", "tkFont", "tkinter.font"), + MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), + MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", + "tkinter.simpledialog"), + MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), + MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), + MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), + MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), + MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), + MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), +] +# Add windows specific modules. +if sys.platform == "win32": + _moved_attributes += [ + MovedModule("winreg", "_winreg"), + ] + +for attr in _moved_attributes: + setattr(_MovedItems, attr.name, attr) + if isinstance(attr, MovedModule): + _importer._add_module(attr, "moves." + attr.name) +del attr + +_MovedItems._moved_attributes = _moved_attributes + +moves = _MovedItems(__name__ + ".moves") +_importer._add_module(moves, "moves") + + +class Module_six_moves_urllib_parse(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_parse""" + + +_urllib_parse_moved_attributes = [ + MovedAttribute("ParseResult", "urlparse", "urllib.parse"), + MovedAttribute("SplitResult", "urlparse", "urllib.parse"), + MovedAttribute("parse_qs", "urlparse", "urllib.parse"), + MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), + MovedAttribute("urldefrag", "urlparse", "urllib.parse"), + MovedAttribute("urljoin", "urlparse", "urllib.parse"), + MovedAttribute("urlparse", "urlparse", "urllib.parse"), + MovedAttribute("urlsplit", "urlparse", "urllib.parse"), + MovedAttribute("urlunparse", "urlparse", "urllib.parse"), + MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), + MovedAttribute("quote", "urllib", "urllib.parse"), + MovedAttribute("quote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote", "urllib", "urllib.parse"), + MovedAttribute("unquote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), + MovedAttribute("urlencode", "urllib", "urllib.parse"), + MovedAttribute("splitquery", "urllib", "urllib.parse"), + MovedAttribute("splittag", "urllib", "urllib.parse"), + MovedAttribute("splituser", "urllib", "urllib.parse"), + MovedAttribute("splitvalue", "urllib", "urllib.parse"), + MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), + MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), + MovedAttribute("uses_params", "urlparse", "urllib.parse"), + MovedAttribute("uses_query", "urlparse", "urllib.parse"), + MovedAttribute("uses_relative", "urlparse", "urllib.parse"), +] +for attr in _urllib_parse_moved_attributes: + setattr(Module_six_moves_urllib_parse, attr.name, attr) +del attr + +Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes + +_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), + "moves.urllib_parse", "moves.urllib.parse") + + +class Module_six_moves_urllib_error(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_error""" + + +_urllib_error_moved_attributes = [ + MovedAttribute("URLError", "urllib2", "urllib.error"), + MovedAttribute("HTTPError", "urllib2", "urllib.error"), + MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), +] +for attr in _urllib_error_moved_attributes: + setattr(Module_six_moves_urllib_error, attr.name, attr) +del attr + +Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes + +_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), + "moves.urllib_error", "moves.urllib.error") + + +class Module_six_moves_urllib_request(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_request""" + + +_urllib_request_moved_attributes = [ + MovedAttribute("urlopen", "urllib2", "urllib.request"), + MovedAttribute("install_opener", "urllib2", "urllib.request"), + MovedAttribute("build_opener", "urllib2", "urllib.request"), + MovedAttribute("pathname2url", "urllib", "urllib.request"), + MovedAttribute("url2pathname", "urllib", "urllib.request"), + MovedAttribute("getproxies", "urllib", "urllib.request"), + MovedAttribute("Request", "urllib2", "urllib.request"), + MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), + MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), + MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), + MovedAttribute("BaseHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), + MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), + MovedAttribute("FileHandler", "urllib2", "urllib.request"), + MovedAttribute("FTPHandler", "urllib2", "urllib.request"), + MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), + MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), + MovedAttribute("urlretrieve", "urllib", "urllib.request"), + MovedAttribute("urlcleanup", "urllib", "urllib.request"), + MovedAttribute("URLopener", "urllib", "urllib.request"), + MovedAttribute("FancyURLopener", "urllib", "urllib.request"), + MovedAttribute("proxy_bypass", "urllib", "urllib.request"), + MovedAttribute("parse_http_list", "urllib2", "urllib.request"), + MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), +] +for attr in _urllib_request_moved_attributes: + setattr(Module_six_moves_urllib_request, attr.name, attr) +del attr + +Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes + +_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), + "moves.urllib_request", "moves.urllib.request") + + +class Module_six_moves_urllib_response(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_response""" + + +_urllib_response_moved_attributes = [ + MovedAttribute("addbase", "urllib", "urllib.response"), + MovedAttribute("addclosehook", "urllib", "urllib.response"), + MovedAttribute("addinfo", "urllib", "urllib.response"), + MovedAttribute("addinfourl", "urllib", "urllib.response"), +] +for attr in _urllib_response_moved_attributes: + setattr(Module_six_moves_urllib_response, attr.name, attr) +del attr + +Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes + +_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), + "moves.urllib_response", "moves.urllib.response") + + +class Module_six_moves_urllib_robotparser(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_robotparser""" + + +_urllib_robotparser_moved_attributes = [ + MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), +] +for attr in _urllib_robotparser_moved_attributes: + setattr(Module_six_moves_urllib_robotparser, attr.name, attr) +del attr + +Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes + +_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), + "moves.urllib_robotparser", "moves.urllib.robotparser") + + +class Module_six_moves_urllib(types.ModuleType): + + """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" + __path__ = [] # mark as package + parse = _importer._get_module("moves.urllib_parse") + error = _importer._get_module("moves.urllib_error") + request = _importer._get_module("moves.urllib_request") + response = _importer._get_module("moves.urllib_response") + robotparser = _importer._get_module("moves.urllib_robotparser") + + def __dir__(self): + return ['parse', 'error', 'request', 'response', 'robotparser'] + + +_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), + "moves.urllib") + + +def add_move(move): + """Add an item to six.moves.""" + setattr(_MovedItems, move.name, move) + + +def remove_move(name): + """Remove item from six.moves.""" + try: + delattr(_MovedItems, name) + except AttributeError: + try: + del moves.__dict__[name] + except KeyError: + raise AttributeError("no such move, %r" % (name,)) + + +if PY3: + _meth_func = "__func__" + _meth_self = "__self__" + + _func_closure = "__closure__" + _func_code = "__code__" + _func_defaults = "__defaults__" + _func_globals = "__globals__" +else: + _meth_func = "im_func" + _meth_self = "im_self" + + _func_closure = "func_closure" + _func_code = "func_code" + _func_defaults = "func_defaults" + _func_globals = "func_globals" + + +try: + advance_iterator = next +except NameError: + def advance_iterator(it): + return it.next() +next = advance_iterator + + +try: + callable = callable +except NameError: + def callable(obj): + return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) + + +if PY3: + def get_unbound_function(unbound): + return unbound + + create_bound_method = types.MethodType + + def create_unbound_method(func, cls): + return func + + Iterator = object +else: + def get_unbound_function(unbound): + return unbound.im_func + + def create_bound_method(func, obj): + return types.MethodType(func, obj, obj.__class__) + + def create_unbound_method(func, cls): + return types.MethodType(func, None, cls) + + class Iterator(object): + + def next(self): + return type(self).__next__(self) + + callable = callable +_add_doc(get_unbound_function, + """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_closure = operator.attrgetter(_func_closure) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = operator.attrgetter(_func_defaults) +get_function_globals = operator.attrgetter(_func_globals) + + +if PY3: + def iterkeys(d, **kw): + return iter(d.keys(**kw)) + + def itervalues(d, **kw): + return iter(d.values(**kw)) + + def iteritems(d, **kw): + return iter(d.items(**kw)) + + def iterlists(d, **kw): + return iter(d.lists(**kw)) + + viewkeys = operator.methodcaller("keys") + + viewvalues = operator.methodcaller("values") + + viewitems = operator.methodcaller("items") +else: + def iterkeys(d, **kw): + return d.iterkeys(**kw) + + def itervalues(d, **kw): + return d.itervalues(**kw) + + def iteritems(d, **kw): + return d.iteritems(**kw) + + def iterlists(d, **kw): + return d.iterlists(**kw) + + viewkeys = operator.methodcaller("viewkeys") + + viewvalues = operator.methodcaller("viewvalues") + + viewitems = operator.methodcaller("viewitems") + +_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") +_add_doc(itervalues, "Return an iterator over the values of a dictionary.") +_add_doc(iteritems, + "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc(iterlists, + "Return an iterator over the (key, [values]) pairs of a dictionary.") + + +if PY3: + def b(s): + return s.encode("latin-1") + + def u(s): + return s + unichr = chr + import struct + int2byte = struct.Struct(">B").pack + del struct + byte2int = operator.itemgetter(0) + indexbytes = operator.getitem + iterbytes = iter + import io + StringIO = io.StringIO + BytesIO = io.BytesIO + del io + _assertCountEqual = "assertCountEqual" + if sys.version_info[1] <= 1: + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" + _assertNotRegex = "assertNotRegexpMatches" + else: + _assertRaisesRegex = "assertRaisesRegex" + _assertRegex = "assertRegex" + _assertNotRegex = "assertNotRegex" +else: + def b(s): + return s + # Workaround for standalone backslash + + def u(s): + return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") + unichr = unichr + int2byte = chr + + def byte2int(bs): + return ord(bs[0]) + + def indexbytes(buf, i): + return ord(buf[i]) + iterbytes = functools.partial(itertools.imap, ord) + import StringIO + StringIO = BytesIO = StringIO.StringIO + _assertCountEqual = "assertItemsEqual" + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" + _assertNotRegex = "assertNotRegexpMatches" +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +def assertCountEqual(self, *args, **kwargs): + return getattr(self, _assertCountEqual)(*args, **kwargs) + + +def assertRaisesRegex(self, *args, **kwargs): + return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +def assertRegex(self, *args, **kwargs): + return getattr(self, _assertRegex)(*args, **kwargs) + + +def assertNotRegex(self, *args, **kwargs): + return getattr(self, _assertNotRegex)(*args, **kwargs) + + +if PY3: + exec_ = getattr(moves.builtins, "exec") + + def reraise(tp, value, tb=None): + try: + if value is None: + value = tp() + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + finally: + value = None + tb = None + +else: + def exec_(_code_, _globs_=None, _locs_=None): + """Execute code in a namespace.""" + if _globs_ is None: + frame = sys._getframe(1) + _globs_ = frame.f_globals + if _locs_ is None: + _locs_ = frame.f_locals + del frame + elif _locs_ is None: + _locs_ = _globs_ + exec("""exec _code_ in _globs_, _locs_""") + + exec_("""def reraise(tp, value, tb=None): + try: + raise tp, value, tb + finally: + tb = None +""") + + +if sys.version_info[:2] > (3,): + exec_("""def raise_from(value, from_value): + try: + raise value from from_value + finally: + value = None +""") +else: + def raise_from(value, from_value): + raise value + + +print_ = getattr(moves.builtins, "print", None) +if print_ is None: + def print_(*args, **kwargs): + """The new-style print function for Python 2.4 and 2.5.""" + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + + def write(data): + if not isinstance(data, basestring): + data = str(data) + # If the file has an encoding, encode unicode with it. + if (isinstance(fp, file) and + isinstance(data, unicode) and + fp.encoding is not None): + errors = getattr(fp, "errors", None) + if errors is None: + errors = "strict" + data = data.encode(fp.encoding, errors) + fp.write(data) + want_unicode = False + sep = kwargs.pop("sep", None) + if sep is not None: + if isinstance(sep, unicode): + want_unicode = True + elif not isinstance(sep, str): + raise TypeError("sep must be None or a string") + end = kwargs.pop("end", None) + if end is not None: + if isinstance(end, unicode): + want_unicode = True + elif not isinstance(end, str): + raise TypeError("end must be None or a string") + if kwargs: + raise TypeError("invalid keyword arguments to print()") + if not want_unicode: + for arg in args: + if isinstance(arg, unicode): + want_unicode = True + break + if want_unicode: + newline = unicode("\n") + space = unicode(" ") + else: + newline = "\n" + space = " " + if sep is None: + sep = space + if end is None: + end = newline + for i, arg in enumerate(args): + if i: + write(sep) + write(arg) + write(end) +if sys.version_info[:2] < (3, 3): + _print = print_ + + def print_(*args, **kwargs): + fp = kwargs.get("file", sys.stdout) + flush = kwargs.pop("flush", False) + _print(*args, **kwargs) + if flush and fp is not None: + fp.flush() + +_add_doc(reraise, """Reraise an exception.""") + +if sys.version_info[0:2] < (3, 4): + # This does exactly the same what the :func:`py3:functools.update_wrapper` + # function does on Python versions after 3.2. It sets the ``__wrapped__`` + # attribute on ``wrapper`` object and it doesn't raise an error if any of + # the attributes mentioned in ``assigned`` and ``updated`` are missing on + # ``wrapped`` object. + def _update_wrapper(wrapper, wrapped, + assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES): + for attr in assigned: + try: + value = getattr(wrapped, attr) + except AttributeError: + continue + else: + setattr(wrapper, attr, value) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, {})) + wrapper.__wrapped__ = wrapped + return wrapper + _update_wrapper.__doc__ = functools.update_wrapper.__doc__ + + def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES): + return functools.partial(_update_wrapper, wrapped=wrapped, + assigned=assigned, updated=updated) + wraps.__doc__ = functools.wraps.__doc__ + +else: + wraps = functools.wraps + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(type): + + def __new__(cls, name, this_bases, d): + if sys.version_info[:2] >= (3, 7): + # This version introduced PEP 560 that requires a bit + # of extra care (we mimic what is done by __build_class__). + resolved_bases = types.resolve_bases(bases) + if resolved_bases is not bases: + d['__orig_bases__'] = bases + else: + resolved_bases = bases + return meta(name, resolved_bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + return type.__new__(metaclass, 'temporary_class', (), {}) + + +def add_metaclass(metaclass): + """Class decorator for creating a class with a metaclass.""" + def wrapper(cls): + orig_vars = cls.__dict__.copy() + slots = orig_vars.get('__slots__') + if slots is not None: + if isinstance(slots, str): + slots = [slots] + for slots_var in slots: + orig_vars.pop(slots_var) + orig_vars.pop('__dict__', None) + orig_vars.pop('__weakref__', None) + if hasattr(cls, '__qualname__'): + orig_vars['__qualname__'] = cls.__qualname__ + return metaclass(cls.__name__, cls.__bases__, orig_vars) + return wrapper + + +def ensure_binary(s, encoding='utf-8', errors='strict'): + """Coerce **s** to six.binary_type. + + For Python 2: + - `unicode` -> encoded to `str` + - `str` -> `str` + + For Python 3: + - `str` -> encoded to `bytes` + - `bytes` -> `bytes` + """ + if isinstance(s, binary_type): + return s + if isinstance(s, text_type): + return s.encode(encoding, errors) + raise TypeError("not expecting type '%s'" % type(s)) + + +def ensure_str(s, encoding='utf-8', errors='strict'): + """Coerce *s* to `str`. + + For Python 2: + - `unicode` -> encoded to `str` + - `str` -> `str` + + For Python 3: + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + # Optimization: Fast return for the common case. + if type(s) is str: + return s + if PY2 and isinstance(s, text_type): + return s.encode(encoding, errors) + elif PY3 and isinstance(s, binary_type): + return s.decode(encoding, errors) + elif not isinstance(s, (text_type, binary_type)): + raise TypeError("not expecting type '%s'" % type(s)) + return s + + +def ensure_text(s, encoding='utf-8', errors='strict'): + """Coerce *s* to six.text_type. + + For Python 2: + - `unicode` -> `unicode` + - `str` -> `unicode` + + For Python 3: + - `str` -> `str` + - `bytes` -> decoded to `str` + """ + if isinstance(s, binary_type): + return s.decode(encoding, errors) + elif isinstance(s, text_type): + return s + else: + raise TypeError("not expecting type '%s'" % type(s)) + + +def python_2_unicode_compatible(klass): + """ + A class decorator that defines __unicode__ and __str__ methods under Python 2. + Under Python 3 it does nothing. + + To support Python 2 and 3 with a single code base, define a __str__ method + returning text and apply this decorator to the class. + """ + if PY2: + if '__str__' not in klass.__dict__: + raise ValueError("@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % + klass.__name__) + klass.__unicode__ = klass.__str__ + klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + return klass + + +# Complete the moves implementation. +# This code is at the end of this module to speed up module loading. +# Turn this module into a package. +__path__ = [] # required for PEP 302 and PEP 451 +__package__ = __name__ # see PEP 366 @ReservedAssignment +if globals().get("__spec__") is not None: + __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable +# Remove other six meta path importers, since they cause problems. This can +# happen if six is removed from sys.modules and then reloaded. (Setuptools does +# this for some reason.) +if sys.meta_path: + for i, importer in enumerate(sys.meta_path): + # Here's some real nastiness: Another "instance" of the six module might + # be floating around. Therefore, we can't use isinstance() to check for + # the six meta path importer, since the other six instance will have + # inserted an importer with different class. + if (type(importer).__name__ == "_SixMetaPathImporter" and + importer.name == __name__): + del sys.meta_path[i] + break + del i, importer +# Finally, add the importer to the meta path import hook. +sys.meta_path.append(_importer) diff --git a/lib/ansible/module_utils/splitter.py b/lib/ansible/module_utils/splitter.py new file mode 100644 index 0000000..c170b1c --- /dev/null +++ b/lib/ansible/module_utils/splitter.py @@ -0,0 +1,219 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013 +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +def _get_quote_state(token, quote_char): + ''' + the goal of this block is to determine if the quoted string + is unterminated in which case it needs to be put back together + ''' + # the char before the current one, used to see if + # the current character is escaped + prev_char = None + for idx, cur_char in enumerate(token): + if idx > 0: + prev_char = token[idx - 1] + if cur_char in '"\'' and prev_char != '\\': + if quote_char: + if cur_char == quote_char: + quote_char = None + else: + quote_char = cur_char + return quote_char + + +def _count_jinja2_blocks(token, cur_depth, open_token, close_token): + ''' + this function counts the number of opening/closing blocks for a + given opening/closing type and adjusts the current depth for that + block based on the difference + ''' + num_open = token.count(open_token) + num_close = token.count(close_token) + if num_open != num_close: + cur_depth += (num_open - num_close) + if cur_depth < 0: + cur_depth = 0 + return cur_depth + + +def split_args(args): + ''' + Splits args on whitespace, but intelligently reassembles + those that may have been split over a jinja2 block or quotes. + + When used in a remote module, we won't ever have to be concerned about + jinja2 blocks, however this function is/will be used in the + core portions as well before the args are templated. + + example input: a=b c="foo bar" + example output: ['a=b', 'c="foo bar"'] + + Basically this is a variation shlex that has some more intelligence for + how Ansible needs to use it. + ''' + + # the list of params parsed out of the arg string + # this is going to be the result value when we are donei + params = [] + + # here we encode the args, so we have a uniform charset to + # work with, and split on white space + args = args.strip() + try: + args = args.encode('utf-8') + do_decode = True + except UnicodeDecodeError: + do_decode = False + items = args.split('\n') + + # iterate over the tokens, and reassemble any that may have been + # split on a space inside a jinja2 block. + # ex if tokens are "{{", "foo", "}}" these go together + + # These variables are used + # to keep track of the state of the parsing, since blocks and quotes + # may be nested within each other. + + quote_char = None + inside_quotes = False + print_depth = 0 # used to count nested jinja2 {{ }} blocks + block_depth = 0 # used to count nested jinja2 {% %} blocks + comment_depth = 0 # used to count nested jinja2 {# #} blocks + + # now we loop over each split chunk, coalescing tokens if the white space + # split occurred within quotes or a jinja2 block of some kind + for itemidx, item in enumerate(items): + + # we split on spaces and newlines separately, so that we + # can tell which character we split on for reassembly + # inside quotation characters + tokens = item.strip().split(' ') + + line_continuation = False + for idx, token in enumerate(tokens): + + # if we hit a line continuation character, but + # we're not inside quotes, ignore it and continue + # on to the next token while setting a flag + if token == '\\' and not inside_quotes: + line_continuation = True + continue + + # store the previous quoting state for checking later + was_inside_quotes = inside_quotes + quote_char = _get_quote_state(token, quote_char) + inside_quotes = quote_char is not None + + # multiple conditions may append a token to the list of params, + # so we keep track with this flag to make sure it only happens once + # append means add to the end of the list, don't append means concatenate + # it to the end of the last token + appended = False + + # if we're inside quotes now, but weren't before, append the token + # to the end of the list, since we'll tack on more to it later + # otherwise, if we're inside any jinja2 block, inside quotes, or we were + # inside quotes (but aren't now) concat this token to the last param + if inside_quotes and not was_inside_quotes: + params.append(token) + appended = True + elif print_depth or block_depth or comment_depth or inside_quotes or was_inside_quotes: + if idx == 0 and not inside_quotes and was_inside_quotes: + params[-1] = "%s%s" % (params[-1], token) + elif len(tokens) > 1: + spacer = '' + if idx > 0: + spacer = ' ' + params[-1] = "%s%s%s" % (params[-1], spacer, token) + else: + spacer = '' + if not params[-1].endswith('\n') and idx == 0: + spacer = '\n' + params[-1] = "%s%s%s" % (params[-1], spacer, token) + appended = True + + # if the number of paired block tags is not the same, the depth has changed, so we calculate that here + # and may append the current token to the params (if we haven't previously done so) + prev_print_depth = print_depth + print_depth = _count_jinja2_blocks(token, print_depth, "{{", "}}") + if print_depth != prev_print_depth and not appended: + params.append(token) + appended = True + + prev_block_depth = block_depth + block_depth = _count_jinja2_blocks(token, block_depth, "{%", "%}") + if block_depth != prev_block_depth and not appended: + params.append(token) + appended = True + + prev_comment_depth = comment_depth + comment_depth = _count_jinja2_blocks(token, comment_depth, "{#", "#}") + if comment_depth != prev_comment_depth and not appended: + params.append(token) + appended = True + + # finally, if we're at zero depth for all blocks and not inside quotes, and have not + # yet appended anything to the list of params, we do so now + if not (print_depth or block_depth or comment_depth) and not inside_quotes and not appended and token != '': + params.append(token) + + # if this was the last token in the list, and we have more than + # one item (meaning we split on newlines), add a newline back here + # to preserve the original structure + if len(items) > 1 and itemidx != len(items) - 1 and not line_continuation: + if not params[-1].endswith('\n') or item == '': + params[-1] += '\n' + + # always clear the line continuation flag + line_continuation = False + + # If we're done and things are not at zero depth or we're still inside quotes, + # raise an error to indicate that the args were unbalanced + if print_depth or block_depth or comment_depth or inside_quotes: + raise Exception("error while splitting arguments, either an unbalanced jinja2 block or quotes") + + # finally, we decode each param back to the unicode it was in the arg string + if do_decode: + params = [x.decode('utf-8') for x in params] + + return params + + +def is_quoted(data): + return len(data) > 0 and (data[0] == '"' and data[-1] == '"' or data[0] == "'" and data[-1] == "'") + + +def unquote(data): + ''' removes first and last quotes from a string, if the string starts and ends with the same quotes ''' + if is_quoted(data): + return data[1:-1] + return data diff --git a/lib/ansible/module_utils/urls.py b/lib/ansible/module_utils/urls.py new file mode 100644 index 0000000..542f89b --- /dev/null +++ b/lib/ansible/module_utils/urls.py @@ -0,0 +1,2070 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013 +# Copyright (c), Toshio Kuratomi <tkuratomi@ansible.com>, 2015 +# +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) +# +# The match_hostname function and supporting code is under the terms and +# conditions of the Python Software Foundation License. They were taken from +# the Python3 standard library and adapted for use in Python2. See comments in the +# source for which code precisely is under this License. +# +# PSF License (see licenses/PSF-license.txt or https://opensource.org/licenses/Python-2.0) + + +''' +The **urls** utils module offers a replacement for the urllib2 python library. + +urllib2 is the python stdlib way to retrieve files from the Internet but it +lacks some security features (around verifying SSL certificates) that users +should care about in most situations. Using the functions in this module corrects +deficiencies in the urllib2 module wherever possible. + +There are also third-party libraries (for instance, requests) which can be used +to replace urllib2 with a more secure library. However, all third party libraries +require that the library be installed on the managed machine. That is an extra step +for users making use of a module. If possible, avoid third party libraries by using +this code instead. +''' + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import atexit +import base64 +import email.mime.multipart +import email.mime.nonmultipart +import email.mime.application +import email.parser +import email.utils +import functools +import io +import mimetypes +import netrc +import os +import platform +import re +import socket +import sys +import tempfile +import traceback +import types + +from contextlib import contextmanager + +try: + import gzip + HAS_GZIP = True + GZIP_IMP_ERR = None +except ImportError: + HAS_GZIP = False + GZIP_IMP_ERR = traceback.format_exc() + GzipFile = object +else: + GzipFile = gzip.GzipFile # type: ignore[assignment,misc] + +try: + import email.policy +except ImportError: + # Py2 + import email.generator + +try: + import httplib +except ImportError: + # Python 3 + import http.client as httplib # type: ignore[no-redef] + +import ansible.module_utils.compat.typing as t +import ansible.module_utils.six.moves.http_cookiejar as cookiejar +import ansible.module_utils.six.moves.urllib.error as urllib_error + +from ansible.module_utils.common.collections import Mapping, is_sequence +from ansible.module_utils.six import PY2, PY3, string_types +from ansible.module_utils.six.moves import cStringIO +from ansible.module_utils.basic import get_distribution, missing_required_lib +from ansible.module_utils._text import to_bytes, to_native, to_text + +try: + # python3 + import urllib.request as urllib_request + from urllib.request import AbstractHTTPHandler, BaseHandler +except ImportError: + # python2 + import urllib2 as urllib_request # type: ignore[no-redef] + from urllib2 import AbstractHTTPHandler, BaseHandler # type: ignore[no-redef] + +urllib_request.HTTPRedirectHandler.http_error_308 = urllib_request.HTTPRedirectHandler.http_error_307 # type: ignore[attr-defined] + +try: + from ansible.module_utils.six.moves.urllib.parse import urlparse, urlunparse, unquote + HAS_URLPARSE = True +except Exception: + HAS_URLPARSE = False + +try: + import ssl + HAS_SSL = True +except Exception: + HAS_SSL = False + +try: + # SNI Handling needs python2.7.9's SSLContext + from ssl import create_default_context, SSLContext + HAS_SSLCONTEXT = True +except ImportError: + HAS_SSLCONTEXT = False + +# SNI Handling for python < 2.7.9 with urllib3 support +HAS_URLLIB3_PYOPENSSLCONTEXT = False +HAS_URLLIB3_SSL_WRAP_SOCKET = False +if not HAS_SSLCONTEXT: + try: + # urllib3>=1.15 + try: + from urllib3.contrib.pyopenssl import PyOpenSSLContext + except Exception: + from requests.packages.urllib3.contrib.pyopenssl import PyOpenSSLContext + HAS_URLLIB3_PYOPENSSLCONTEXT = True + except Exception: + # urllib3<1.15,>=1.6 + try: + try: + from urllib3.contrib.pyopenssl import ssl_wrap_socket + except Exception: + from requests.packages.urllib3.contrib.pyopenssl import ssl_wrap_socket + HAS_URLLIB3_SSL_WRAP_SOCKET = True + except Exception: + pass + +# Select a protocol that includes all secure tls protocols +# Exclude insecure ssl protocols if possible + +if HAS_SSL: + # If we can't find extra tls methods, ssl.PROTOCOL_TLSv1 is sufficient + PROTOCOL = ssl.PROTOCOL_TLSv1 +if not HAS_SSLCONTEXT and HAS_SSL: + try: + import ctypes + import ctypes.util + except ImportError: + # python 2.4 (likely rhel5 which doesn't have tls1.1 support in its openssl) + pass + else: + libssl_name = ctypes.util.find_library('ssl') + libssl = ctypes.CDLL(libssl_name) + for method in ('TLSv1_1_method', 'TLSv1_2_method'): + try: + libssl[method] + # Found something - we'll let openssl autonegotiate and hope + # the server has disabled sslv2 and 3. best we can do. + PROTOCOL = ssl.PROTOCOL_SSLv23 + break + except AttributeError: + pass + del libssl + + +# The following makes it easier for us to script updates of the bundled backports.ssl_match_hostname +# The bundled backports.ssl_match_hostname should really be moved into its own file for processing +_BUNDLED_METADATA = {"pypi_name": "backports.ssl_match_hostname", "version": "3.7.0.1"} + +LOADED_VERIFY_LOCATIONS = set() # type: t.Set[str] + +HAS_MATCH_HOSTNAME = True +try: + from ssl import match_hostname, CertificateError +except ImportError: + try: + from backports.ssl_match_hostname import match_hostname, CertificateError # type: ignore[misc] + except ImportError: + HAS_MATCH_HOSTNAME = False + +HAS_CRYPTOGRAPHY = True +try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import hashes + from cryptography.exceptions import UnsupportedAlgorithm +except ImportError: + HAS_CRYPTOGRAPHY = False + +# Old import for GSSAPI authentication, this is not used in urls.py but kept for backwards compatibility. +try: + import urllib_gssapi + HAS_GSSAPI = True +except ImportError: + HAS_GSSAPI = False + +GSSAPI_IMP_ERR = None +try: + import gssapi + + class HTTPGSSAPIAuthHandler(BaseHandler): + """ Handles Negotiate/Kerberos support through the gssapi library. """ + + AUTH_HEADER_PATTERN = re.compile(r'(?:.*)\s*(Negotiate|Kerberos)\s*([^,]*),?', re.I) + handler_order = 480 # Handle before Digest authentication + + def __init__(self, username=None, password=None): + self.username = username + self.password = password + self._context = None + + def get_auth_value(self, headers): + auth_match = self.AUTH_HEADER_PATTERN.search(headers.get('www-authenticate', '')) + if auth_match: + return auth_match.group(1), base64.b64decode(auth_match.group(2)) + + def http_error_401(self, req, fp, code, msg, headers): + # If we've already attempted the auth and we've reached this again then there was a failure. + if self._context: + return + + parsed = generic_urlparse(urlparse(req.get_full_url())) + + auth_header = self.get_auth_value(headers) + if not auth_header: + return + auth_protocol, in_token = auth_header + + username = None + if self.username: + username = gssapi.Name(self.username, name_type=gssapi.NameType.user) + + if username and self.password: + if not hasattr(gssapi.raw, 'acquire_cred_with_password'): + raise NotImplementedError("Platform GSSAPI library does not support " + "gss_acquire_cred_with_password, cannot acquire GSSAPI credential with " + "explicit username and password.") + + b_password = to_bytes(self.password, errors='surrogate_or_strict') + cred = gssapi.raw.acquire_cred_with_password(username, b_password, usage='initiate').creds + + else: + cred = gssapi.Credentials(name=username, usage='initiate') + + # Get the peer certificate for the channel binding token if possible (HTTPS). A bug on macOS causes the + # authentication to fail when the CBT is present. Just skip that platform. + cbt = None + cert = getpeercert(fp, True) + if cert and platform.system() != 'Darwin': + cert_hash = get_channel_binding_cert_hash(cert) + if cert_hash: + cbt = gssapi.raw.ChannelBindings(application_data=b"tls-server-end-point:" + cert_hash) + + # TODO: We could add another option that is set to include the port in the SPN if desired in the future. + target = gssapi.Name("HTTP@%s" % parsed['hostname'], gssapi.NameType.hostbased_service) + self._context = gssapi.SecurityContext(usage="initiate", name=target, creds=cred, channel_bindings=cbt) + + resp = None + while not self._context.complete: + out_token = self._context.step(in_token) + if not out_token: + break + + auth_header = '%s %s' % (auth_protocol, to_native(base64.b64encode(out_token))) + req.add_unredirected_header('Authorization', auth_header) + resp = self.parent.open(req) + + # The response could contain a token that the client uses to validate the server + auth_header = self.get_auth_value(resp.headers) + if not auth_header: + break + in_token = auth_header[1] + + return resp + +except ImportError: + GSSAPI_IMP_ERR = traceback.format_exc() + HTTPGSSAPIAuthHandler = None # type: types.ModuleType | None # type: ignore[no-redef] + +if not HAS_MATCH_HOSTNAME: + # The following block of code is under the terms and conditions of the + # Python Software Foundation License + + """The match_hostname() function from Python 3.4, essential when using SSL.""" + + try: + # Divergence: Python-3.7+'s _ssl has this exception type but older Pythons do not + from _ssl import SSLCertVerificationError + CertificateError = SSLCertVerificationError # type: ignore[misc] + except ImportError: + class CertificateError(ValueError): # type: ignore[no-redef] + pass + + def _dnsname_match(dn, hostname): + """Matching according to RFC 6125, section 6.4.3 + + - Hostnames are compared lower case. + - For IDNA, both dn and hostname must be encoded as IDN A-label (ACE). + - Partial wildcards like 'www*.example.org', multiple wildcards, sole + wildcard or wildcards in labels other then the left-most label are not + supported and a CertificateError is raised. + - A wildcard must match at least one character. + """ + if not dn: + return False + + wildcards = dn.count('*') + # speed up common case w/o wildcards + if not wildcards: + return dn.lower() == hostname.lower() + + if wildcards > 1: + # Divergence .format() to percent formatting for Python < 2.6 + raise CertificateError( + "too many wildcards in certificate DNS name: %s" % repr(dn)) + + dn_leftmost, sep, dn_remainder = dn.partition('.') + + if '*' in dn_remainder: + # Only match wildcard in leftmost segment. + # Divergence .format() to percent formatting for Python < 2.6 + raise CertificateError( + "wildcard can only be present in the leftmost label: " + "%s." % repr(dn)) + + if not sep: + # no right side + # Divergence .format() to percent formatting for Python < 2.6 + raise CertificateError( + "sole wildcard without additional labels are not support: " + "%s." % repr(dn)) + + if dn_leftmost != '*': + # no partial wildcard matching + # Divergence .format() to percent formatting for Python < 2.6 + raise CertificateError( + "partial wildcards in leftmost label are not supported: " + "%s." % repr(dn)) + + hostname_leftmost, sep, hostname_remainder = hostname.partition('.') + if not hostname_leftmost or not sep: + # wildcard must match at least one char + return False + return dn_remainder.lower() == hostname_remainder.lower() + + def _inet_paton(ipname): + """Try to convert an IP address to packed binary form + + Supports IPv4 addresses on all platforms and IPv6 on platforms with IPv6 + support. + """ + # inet_aton() also accepts strings like '1' + # Divergence: We make sure we have native string type for all python versions + try: + b_ipname = to_bytes(ipname, errors='strict') + except UnicodeError: + raise ValueError("%s must be an all-ascii string." % repr(ipname)) + + # Set ipname in native string format + if sys.version_info < (3,): + n_ipname = b_ipname + else: + n_ipname = ipname + + if n_ipname.count('.') == 3: + try: + return socket.inet_aton(n_ipname) + # Divergence: OSError on late python3. socket.error earlier. + # Null bytes generate ValueError on python3(we want to raise + # ValueError anyway), TypeError # earlier + except (OSError, socket.error, TypeError): + pass + + try: + return socket.inet_pton(socket.AF_INET6, n_ipname) + # Divergence: OSError on late python3. socket.error earlier. + # Null bytes generate ValueError on python3(we want to raise + # ValueError anyway), TypeError # earlier + except (OSError, socket.error, TypeError): + # Divergence .format() to percent formatting for Python < 2.6 + raise ValueError("%s is neither an IPv4 nor an IP6 " + "address." % repr(ipname)) + except AttributeError: + # AF_INET6 not available + pass + + # Divergence .format() to percent formatting for Python < 2.6 + raise ValueError("%s is not an IPv4 address." % repr(ipname)) + + def _ipaddress_match(ipname, host_ip): + """Exact matching of IP addresses. + + RFC 6125 explicitly doesn't define an algorithm for this + (section 1.7.2 - "Out of Scope"). + """ + # OpenSSL may add a trailing newline to a subjectAltName's IP address + ip = _inet_paton(ipname.rstrip()) + return ip == host_ip + + def match_hostname(cert, hostname): # type: ignore[misc] + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 + rules are followed. + + The function matches IP addresses rather than dNSNames if hostname is a + valid ipaddress string. IPv4 addresses are supported on all platforms. + IPv6 addresses are supported on platforms with IPv6 support (AF_INET6 + and inet_pton). + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError("empty or no certificate, match_hostname needs a " + "SSL socket or SSL context with either " + "CERT_OPTIONAL or CERT_REQUIRED") + try: + # Divergence: Deal with hostname as bytes + host_ip = _inet_paton(to_text(hostname, errors='strict')) + except UnicodeError: + # Divergence: Deal with hostname as byte strings. + # IP addresses should be all ascii, so we consider it not + # an IP address if this fails + host_ip = None + except ValueError: + # Not an IP address (common case) + host_ip = None + dnsnames = [] + san = cert.get('subjectAltName', ()) + for key, value in san: + if key == 'DNS': + if host_ip is None and _dnsname_match(value, hostname): + return + dnsnames.append(value) + elif key == 'IP Address': + if host_ip is not None and _ipaddress_match(value, host_ip): + return + dnsnames.append(value) + if not dnsnames: + # The subject is only checked when there is no dNSName entry + # in subjectAltName + for sub in cert.get('subject', ()): + for key, value in sub: + # XXX according to RFC 2818, the most specific Common Name + # must be used. + if key == 'commonName': + if _dnsname_match(value, hostname): + return + dnsnames.append(value) + if len(dnsnames) > 1: + raise CertificateError("hostname %r doesn't match either of %s" % (hostname, ', '.join(map(repr, dnsnames)))) + elif len(dnsnames) == 1: + raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0])) + else: + raise CertificateError("no appropriate commonName or subjectAltName fields were found") + + # End of Python Software Foundation Licensed code + + HAS_MATCH_HOSTNAME = True + + +# This is a dummy cacert provided for macOS since you need at least 1 +# ca cert, regardless of validity, for Python on macOS to use the +# keychain functionality in OpenSSL for validating SSL certificates. +# See: http://mercurial.selenic.com/wiki/CACertificates#Mac_OS_X_10.6_and_higher +b_DUMMY_CA_CERT = b"""-----BEGIN CERTIFICATE----- +MIICvDCCAiWgAwIBAgIJAO8E12S7/qEpMA0GCSqGSIb3DQEBBQUAMEkxCzAJBgNV +BAYTAlVTMRcwFQYDVQQIEw5Ob3J0aCBDYXJvbGluYTEPMA0GA1UEBxMGRHVyaGFt +MRAwDgYDVQQKEwdBbnNpYmxlMB4XDTE0MDMxODIyMDAyMloXDTI0MDMxNTIyMDAy +MlowSTELMAkGA1UEBhMCVVMxFzAVBgNVBAgTDk5vcnRoIENhcm9saW5hMQ8wDQYD +VQQHEwZEdXJoYW0xEDAOBgNVBAoTB0Fuc2libGUwgZ8wDQYJKoZIhvcNAQEBBQAD +gY0AMIGJAoGBANtvpPq3IlNlRbCHhZAcP6WCzhc5RbsDqyh1zrkmLi0GwcQ3z/r9 +gaWfQBYhHpobK2Tiq11TfraHeNB3/VfNImjZcGpN8Fl3MWwu7LfVkJy3gNNnxkA1 +4Go0/LmIvRFHhbzgfuo9NFgjPmmab9eqXJceqZIlz2C8xA7EeG7ku0+vAgMBAAGj +gaswgagwHQYDVR0OBBYEFPnN1nPRqNDXGlCqCvdZchRNi/FaMHkGA1UdIwRyMHCA +FPnN1nPRqNDXGlCqCvdZchRNi/FaoU2kSzBJMQswCQYDVQQGEwJVUzEXMBUGA1UE +CBMOTm9ydGggQ2Fyb2xpbmExDzANBgNVBAcTBkR1cmhhbTEQMA4GA1UEChMHQW5z +aWJsZYIJAO8E12S7/qEpMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEA +MUB80IR6knq9K/tY+hvPsZer6eFMzO3JGkRFBh2kn6JdMDnhYGX7AXVHGflrwNQH +qFy+aenWXsC0ZvrikFxbQnX8GVtDADtVznxOi7XzFw7JOxdsVrpXgSN0eh0aMzvV +zKPZsZ2miVGclicJHzm5q080b1p/sZtuKIEZk6vZqEg= +-----END CERTIFICATE----- +""" + +b_PEM_CERT_RE = re.compile( + br'^-----BEGIN CERTIFICATE-----\n.+?-----END CERTIFICATE-----$', + flags=re.M | re.S +) + +# +# Exceptions +# + + +class ConnectionError(Exception): + """Failed to connect to the server""" + pass + + +class ProxyError(ConnectionError): + """Failure to connect because of a proxy""" + pass + + +class SSLValidationError(ConnectionError): + """Failure to connect due to SSL validation failing""" + pass + + +class NoSSLError(SSLValidationError): + """Needed to connect to an HTTPS url but no ssl library available to verify the certificate""" + pass + + +class MissingModuleError(Exception): + """Failed to import 3rd party module required by the caller""" + def __init__(self, message, import_traceback, module=None): + super(MissingModuleError, self).__init__(message) + self.import_traceback = import_traceback + self.module = module + + +# Some environments (Google Compute Engine's CoreOS deploys) do not compile +# against openssl and thus do not have any HTTPS support. +CustomHTTPSConnection = None +CustomHTTPSHandler = None +HTTPSClientAuthHandler = None +UnixHTTPSConnection = None +if hasattr(httplib, 'HTTPSConnection') and hasattr(urllib_request, 'HTTPSHandler'): + class CustomHTTPSConnection(httplib.HTTPSConnection): # type: ignore[no-redef] + def __init__(self, *args, **kwargs): + httplib.HTTPSConnection.__init__(self, *args, **kwargs) + self.context = None + if HAS_SSLCONTEXT: + self.context = self._context + elif HAS_URLLIB3_PYOPENSSLCONTEXT: + self.context = self._context = PyOpenSSLContext(PROTOCOL) + if self.context and self.cert_file: + self.context.load_cert_chain(self.cert_file, self.key_file) + + def connect(self): + "Connect to a host on a given (SSL) port." + + if hasattr(self, 'source_address'): + sock = socket.create_connection((self.host, self.port), self.timeout, self.source_address) + else: + sock = socket.create_connection((self.host, self.port), self.timeout) + + server_hostname = self.host + # Note: self._tunnel_host is not available on py < 2.6 but this code + # isn't used on py < 2.6 (lack of create_connection) + if self._tunnel_host: + self.sock = sock + self._tunnel() + server_hostname = self._tunnel_host + + if HAS_SSLCONTEXT or HAS_URLLIB3_PYOPENSSLCONTEXT: + self.sock = self.context.wrap_socket(sock, server_hostname=server_hostname) + elif HAS_URLLIB3_SSL_WRAP_SOCKET: + self.sock = ssl_wrap_socket(sock, keyfile=self.key_file, cert_reqs=ssl.CERT_NONE, # pylint: disable=used-before-assignment + certfile=self.cert_file, ssl_version=PROTOCOL, server_hostname=server_hostname) + else: + self.sock = ssl.wrap_socket(sock, keyfile=self.key_file, certfile=self.cert_file, ssl_version=PROTOCOL) + + class CustomHTTPSHandler(urllib_request.HTTPSHandler): # type: ignore[no-redef] + + def https_open(self, req): + kwargs = {} + if HAS_SSLCONTEXT: + kwargs['context'] = self._context + return self.do_open( + functools.partial( + CustomHTTPSConnection, + **kwargs + ), + req + ) + + https_request = AbstractHTTPHandler.do_request_ + + class HTTPSClientAuthHandler(urllib_request.HTTPSHandler): # type: ignore[no-redef] + '''Handles client authentication via cert/key + + This is a fairly lightweight extension on HTTPSHandler, and can be used + in place of HTTPSHandler + ''' + + def __init__(self, client_cert=None, client_key=None, unix_socket=None, **kwargs): + urllib_request.HTTPSHandler.__init__(self, **kwargs) + self.client_cert = client_cert + self.client_key = client_key + self._unix_socket = unix_socket + + def https_open(self, req): + return self.do_open(self._build_https_connection, req) + + def _build_https_connection(self, host, **kwargs): + kwargs.update({ + 'cert_file': self.client_cert, + 'key_file': self.client_key, + }) + try: + kwargs['context'] = self._context + except AttributeError: + pass + if self._unix_socket: + return UnixHTTPSConnection(self._unix_socket)(host, **kwargs) + if not HAS_SSLCONTEXT: + return CustomHTTPSConnection(host, **kwargs) + return httplib.HTTPSConnection(host, **kwargs) + + @contextmanager + def unix_socket_patch_httpconnection_connect(): + '''Monkey patch ``httplib.HTTPConnection.connect`` to be ``UnixHTTPConnection.connect`` + so that when calling ``super(UnixHTTPSConnection, self).connect()`` we get the + correct behavior of creating self.sock for the unix socket + ''' + _connect = httplib.HTTPConnection.connect + httplib.HTTPConnection.connect = UnixHTTPConnection.connect + yield + httplib.HTTPConnection.connect = _connect + + class UnixHTTPSConnection(httplib.HTTPSConnection): # type: ignore[no-redef] + def __init__(self, unix_socket): + self._unix_socket = unix_socket + + def connect(self): + # This method exists simply to ensure we monkeypatch + # httplib.HTTPConnection.connect to call UnixHTTPConnection.connect + with unix_socket_patch_httpconnection_connect(): + # Disable pylint check for the super() call. It complains about UnixHTTPSConnection + # being a NoneType because of the initial definition above, but it won't actually + # be a NoneType when this code runs + # pylint: disable=bad-super-call + super(UnixHTTPSConnection, self).connect() + + def __call__(self, *args, **kwargs): + httplib.HTTPSConnection.__init__(self, *args, **kwargs) + return self + + +class UnixHTTPConnection(httplib.HTTPConnection): + '''Handles http requests to a unix socket file''' + + def __init__(self, unix_socket): + self._unix_socket = unix_socket + + def connect(self): + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + self.sock.connect(self._unix_socket) + except OSError as e: + raise OSError('Invalid Socket File (%s): %s' % (self._unix_socket, e)) + if self.timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + self.sock.settimeout(self.timeout) + + def __call__(self, *args, **kwargs): + httplib.HTTPConnection.__init__(self, *args, **kwargs) + return self + + +class UnixHTTPHandler(urllib_request.HTTPHandler): + '''Handler for Unix urls''' + + def __init__(self, unix_socket, **kwargs): + urllib_request.HTTPHandler.__init__(self, **kwargs) + self._unix_socket = unix_socket + + def http_open(self, req): + return self.do_open(UnixHTTPConnection(self._unix_socket), req) + + +class ParseResultDottedDict(dict): + ''' + A dict that acts similarly to the ParseResult named tuple from urllib + ''' + def __init__(self, *args, **kwargs): + super(ParseResultDottedDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + def as_list(self): + ''' + Generate a list from this dict, that looks like the ParseResult named tuple + ''' + return [self.get(k, None) for k in ('scheme', 'netloc', 'path', 'params', 'query', 'fragment')] + + +def generic_urlparse(parts): + ''' + Returns a dictionary of url parts as parsed by urlparse, + but accounts for the fact that older versions of that + library do not support named attributes (ie. .netloc) + ''' + generic_parts = ParseResultDottedDict() + if hasattr(parts, 'netloc'): + # urlparse is newer, just read the fields straight + # from the parts object + generic_parts['scheme'] = parts.scheme + generic_parts['netloc'] = parts.netloc + generic_parts['path'] = parts.path + generic_parts['params'] = parts.params + generic_parts['query'] = parts.query + generic_parts['fragment'] = parts.fragment + generic_parts['username'] = parts.username + generic_parts['password'] = parts.password + hostname = parts.hostname + if hostname and hostname[0] == '[' and '[' in parts.netloc and ']' in parts.netloc: + # Py2.6 doesn't parse IPv6 addresses correctly + hostname = parts.netloc.split(']')[0][1:].lower() + generic_parts['hostname'] = hostname + + try: + port = parts.port + except ValueError: + # Py2.6 doesn't parse IPv6 addresses correctly + netloc = parts.netloc.split('@')[-1].split(']')[-1] + if ':' in netloc: + port = netloc.split(':')[1] + if port: + port = int(port) + else: + port = None + generic_parts['port'] = port + else: + # we have to use indexes, and then parse out + # the other parts not supported by indexing + generic_parts['scheme'] = parts[0] + generic_parts['netloc'] = parts[1] + generic_parts['path'] = parts[2] + generic_parts['params'] = parts[3] + generic_parts['query'] = parts[4] + generic_parts['fragment'] = parts[5] + # get the username, password, etc. + try: + netloc_re = re.compile(r'^((?:\w)+(?::(?:\w)+)?@)?([A-Za-z0-9.-]+)(:\d+)?$') + match = netloc_re.match(parts[1]) + auth = match.group(1) + hostname = match.group(2) + port = match.group(3) + if port: + # the capture group for the port will include the ':', + # so remove it and convert the port to an integer + port = int(port[1:]) + if auth: + # the capture group above includes the @, so remove it + # and then split it up based on the first ':' found + auth = auth[:-1] + username, password = auth.split(':', 1) + else: + username = password = None + generic_parts['username'] = username + generic_parts['password'] = password + generic_parts['hostname'] = hostname + generic_parts['port'] = port + except Exception: + generic_parts['username'] = None + generic_parts['password'] = None + generic_parts['hostname'] = parts[1] + generic_parts['port'] = None + return generic_parts + + +def extract_pem_certs(b_data): + for match in b_PEM_CERT_RE.finditer(b_data): + yield match.group(0) + + +def get_response_filename(response): + url = response.geturl() + path = urlparse(url)[2] + filename = os.path.basename(path.rstrip('/')) or None + if filename: + filename = unquote(filename) + + return response.headers.get_param('filename', header='content-disposition') or filename + + +def parse_content_type(response): + if PY2: + get_type = response.headers.gettype + get_param = response.headers.getparam + else: + get_type = response.headers.get_content_type + get_param = response.headers.get_param + + content_type = (get_type() or 'application/octet-stream').split(',')[0] + main_type, sub_type = content_type.split('/') + charset = (get_param('charset') or 'utf-8').split(',')[0] + return content_type, main_type, sub_type, charset + + +class GzipDecodedReader(GzipFile): + """A file-like object to decode a response encoded with the gzip + method, as described in RFC 1952. + + Largely copied from ``xmlrpclib``/``xmlrpc.client`` + """ + def __init__(self, fp): + if not HAS_GZIP: + raise MissingModuleError(self.missing_gzip_error(), import_traceback=GZIP_IMP_ERR) + + if PY3: + self._io = fp + else: + # Py2 ``HTTPResponse``/``addinfourl`` doesn't support all of the file object + # functionality GzipFile requires + self._io = io.BytesIO() + for block in iter(functools.partial(fp.read, 65536), b''): + self._io.write(block) + self._io.seek(0) + fp.close() + gzip.GzipFile.__init__(self, mode='rb', fileobj=self._io) # pylint: disable=non-parent-init-called + + def close(self): + try: + gzip.GzipFile.close(self) + finally: + self._io.close() + + @staticmethod + def missing_gzip_error(): + return missing_required_lib( + 'gzip', + reason='to decompress gzip encoded responses. ' + 'Set "decompress" to False, to prevent attempting auto decompression' + ) + + +class RequestWithMethod(urllib_request.Request): + ''' + Workaround for using DELETE/PUT/etc with urllib2 + Originally contained in library/net_infrastructure/dnsmadeeasy + ''' + + def __init__(self, url, method, data=None, headers=None, origin_req_host=None, unverifiable=True): + if headers is None: + headers = {} + self._method = method.upper() + urllib_request.Request.__init__(self, url, data, headers, origin_req_host, unverifiable) + + def get_method(self): + if self._method: + return self._method + else: + return urllib_request.Request.get_method(self) + + +def RedirectHandlerFactory(follow_redirects=None, validate_certs=True, ca_path=None, ciphers=None): + """This is a class factory that closes over the value of + ``follow_redirects`` so that the RedirectHandler class has access to + that value without having to use globals, and potentially cause problems + where ``open_url`` or ``fetch_url`` are used multiple times in a module. + """ + + class RedirectHandler(urllib_request.HTTPRedirectHandler): + """This is an implementation of a RedirectHandler to match the + functionality provided by httplib2. It will utilize the value of + ``follow_redirects`` that is passed into ``RedirectHandlerFactory`` + to determine how redirects should be handled in urllib2. + """ + + def redirect_request(self, req, fp, code, msg, hdrs, newurl): + if not any((HAS_SSLCONTEXT, HAS_URLLIB3_PYOPENSSLCONTEXT)): + handler = maybe_add_ssl_handler(newurl, validate_certs, ca_path=ca_path, ciphers=ciphers) + if handler: + urllib_request._opener.add_handler(handler) + + # Preserve urllib2 compatibility + if follow_redirects == 'urllib2': + return urllib_request.HTTPRedirectHandler.redirect_request(self, req, fp, code, msg, hdrs, newurl) + + # Handle disabled redirects + elif follow_redirects in ['no', 'none', False]: + raise urllib_error.HTTPError(newurl, code, msg, hdrs, fp) + + method = req.get_method() + + # Handle non-redirect HTTP status or invalid follow_redirects + if follow_redirects in ['all', 'yes', True]: + if code < 300 or code >= 400: + raise urllib_error.HTTPError(req.get_full_url(), code, msg, hdrs, fp) + elif follow_redirects == 'safe': + if code < 300 or code >= 400 or method not in ('GET', 'HEAD'): + raise urllib_error.HTTPError(req.get_full_url(), code, msg, hdrs, fp) + else: + raise urllib_error.HTTPError(req.get_full_url(), code, msg, hdrs, fp) + + try: + # Python 2-3.3 + data = req.get_data() + origin_req_host = req.get_origin_req_host() + except AttributeError: + # Python 3.4+ + data = req.data + origin_req_host = req.origin_req_host + + # Be conciliant with URIs containing a space + newurl = newurl.replace(' ', '%20') + + # Support redirect with payload and original headers + if code in (307, 308): + # Preserve payload and headers + headers = req.headers + else: + # Do not preserve payload and filter headers + data = None + headers = dict((k, v) for k, v in req.headers.items() + if k.lower() not in ("content-length", "content-type", "transfer-encoding")) + + # http://tools.ietf.org/html/rfc7231#section-6.4.4 + if code == 303 and method != 'HEAD': + method = 'GET' + + # Do what the browsers do, despite standards... + # First, turn 302s into GETs. + if code == 302 and method != 'HEAD': + method = 'GET' + + # Second, if a POST is responded to with a 301, turn it into a GET. + if code == 301 and method == 'POST': + method = 'GET' + + return RequestWithMethod(newurl, + method=method, + headers=headers, + data=data, + origin_req_host=origin_req_host, + unverifiable=True, + ) + + return RedirectHandler + + +def build_ssl_validation_error(hostname, port, paths, exc=None): + '''Inteligently build out the SSLValidationError based on what support + you have installed + ''' + + msg = [ + ('Failed to validate the SSL certificate for %s:%s.' + ' Make sure your managed systems have a valid CA' + ' certificate installed.') + ] + if not HAS_SSLCONTEXT: + msg.append('If the website serving the url uses SNI you need' + ' python >= 2.7.9 on your managed machine') + msg.append(' (the python executable used (%s) is version: %s)' % + (sys.executable, ''.join(sys.version.splitlines()))) + if not HAS_URLLIB3_PYOPENSSLCONTEXT and not HAS_URLLIB3_SSL_WRAP_SOCKET: + msg.append('or you can install the `urllib3`, `pyOpenSSL`,' + ' `ndg-httpsclient`, and `pyasn1` python modules') + + msg.append('to perform SNI verification in python >= 2.6.') + + msg.append('You can use validate_certs=False if you do' + ' not need to confirm the servers identity but this is' + ' unsafe and not recommended.' + ' Paths checked for this platform: %s.') + + if exc: + msg.append('The exception msg was: %s.' % to_native(exc)) + + raise SSLValidationError(' '.join(msg) % (hostname, port, ", ".join(paths))) + + +def atexit_remove_file(filename): + if os.path.exists(filename): + try: + os.unlink(filename) + except Exception: + # just ignore if we cannot delete, things should be ok + pass + + +def make_context(cafile=None, cadata=None, ciphers=None, validate_certs=True): + if ciphers is None: + ciphers = [] + + if not is_sequence(ciphers): + raise TypeError('Ciphers must be a list. Got %s.' % ciphers.__class__.__name__) + + if HAS_SSLCONTEXT: + context = create_default_context(cafile=cafile) + elif HAS_URLLIB3_PYOPENSSLCONTEXT: + context = PyOpenSSLContext(PROTOCOL) + else: + raise NotImplementedError('Host libraries are too old to support creating an sslcontext') + + if not validate_certs: + if ssl.OP_NO_SSLv2: + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + if validate_certs and any((cafile, cadata)): + context.load_verify_locations(cafile=cafile, cadata=cadata) + + if ciphers: + context.set_ciphers(':'.join(map(to_native, ciphers))) + + return context + + +def get_ca_certs(cafile=None): + # tries to find a valid CA cert in one of the + # standard locations for the current distribution + + cadata = bytearray() + paths_checked = [] + + if cafile: + paths_checked = [cafile] + with open(to_bytes(cafile, errors='surrogate_or_strict'), 'rb') as f: + if HAS_SSLCONTEXT: + for b_pem in extract_pem_certs(f.read()): + cadata.extend( + ssl.PEM_cert_to_DER_cert( + to_native(b_pem, errors='surrogate_or_strict') + ) + ) + return cafile, cadata, paths_checked + + if not HAS_SSLCONTEXT: + paths_checked.append('/etc/ssl/certs') + + system = to_text(platform.system(), errors='surrogate_or_strict') + # build a list of paths to check for .crt/.pem files + # based on the platform type + if system == u'Linux': + paths_checked.append('/etc/pki/ca-trust/extracted/pem') + paths_checked.append('/etc/pki/tls/certs') + paths_checked.append('/usr/share/ca-certificates/cacert.org') + elif system == u'FreeBSD': + paths_checked.append('/usr/local/share/certs') + elif system == u'OpenBSD': + paths_checked.append('/etc/ssl') + elif system == u'NetBSD': + paths_checked.append('/etc/openssl/certs') + elif system == u'SunOS': + paths_checked.append('/opt/local/etc/openssl/certs') + elif system == u'AIX': + paths_checked.append('/var/ssl/certs') + paths_checked.append('/opt/freeware/etc/ssl/certs') + + # fall back to a user-deployed cert in a standard + # location if the OS platform one is not available + paths_checked.append('/etc/ansible') + + tmp_path = None + if not HAS_SSLCONTEXT: + tmp_fd, tmp_path = tempfile.mkstemp() + atexit.register(atexit_remove_file, tmp_path) + + # Write the dummy ca cert if we are running on macOS + if system == u'Darwin': + if HAS_SSLCONTEXT: + cadata.extend( + ssl.PEM_cert_to_DER_cert( + to_native(b_DUMMY_CA_CERT, errors='surrogate_or_strict') + ) + ) + else: + os.write(tmp_fd, b_DUMMY_CA_CERT) + # Default Homebrew path for OpenSSL certs + paths_checked.append('/usr/local/etc/openssl') + + # for all of the paths, find any .crt or .pem files + # and compile them into single temp file for use + # in the ssl check to speed up the test + for path in paths_checked: + if not os.path.isdir(path): + continue + + dir_contents = os.listdir(path) + for f in dir_contents: + full_path = os.path.join(path, f) + if os.path.isfile(full_path) and os.path.splitext(f)[1] in ('.crt', '.pem'): + try: + if full_path not in LOADED_VERIFY_LOCATIONS: + with open(full_path, 'rb') as cert_file: + b_cert = cert_file.read() + if HAS_SSLCONTEXT: + try: + for b_pem in extract_pem_certs(b_cert): + cadata.extend( + ssl.PEM_cert_to_DER_cert( + to_native(b_pem, errors='surrogate_or_strict') + ) + ) + except Exception: + continue + else: + os.write(tmp_fd, b_cert) + os.write(tmp_fd, b'\n') + except (OSError, IOError): + pass + + if HAS_SSLCONTEXT: + default_verify_paths = ssl.get_default_verify_paths() + paths_checked[:0] = [default_verify_paths.capath] + else: + os.close(tmp_fd) + + return (tmp_path, cadata, paths_checked) + + +class SSLValidationHandler(urllib_request.BaseHandler): + ''' + A custom handler class for SSL validation. + + Based on: + http://stackoverflow.com/questions/1087227/validate-ssl-certificates-with-python + http://techknack.net/python-urllib2-handlers/ + ''' + CONNECT_COMMAND = "CONNECT %s:%s HTTP/1.0\r\n" + + def __init__(self, hostname, port, ca_path=None, ciphers=None, validate_certs=True): + self.hostname = hostname + self.port = port + self.ca_path = ca_path + self.ciphers = ciphers + self.validate_certs = validate_certs + + def get_ca_certs(self): + return get_ca_certs(self.ca_path) + + def validate_proxy_response(self, response, valid_codes=None): + ''' + make sure we get back a valid code from the proxy + ''' + valid_codes = [200] if valid_codes is None else valid_codes + + try: + (http_version, resp_code, msg) = re.match(br'(HTTP/\d\.\d) (\d\d\d) (.*)', response).groups() + if int(resp_code) not in valid_codes: + raise Exception + except Exception: + raise ProxyError('Connection to proxy failed') + + def detect_no_proxy(self, url): + ''' + Detect if the 'no_proxy' environment variable is set and honor those locations. + ''' + env_no_proxy = os.environ.get('no_proxy') + if env_no_proxy: + env_no_proxy = env_no_proxy.split(',') + netloc = urlparse(url).netloc + + for host in env_no_proxy: + if netloc.endswith(host) or netloc.split(':')[0].endswith(host): + # Our requested URL matches something in no_proxy, so don't + # use the proxy for this + return False + return True + + def make_context(self, cafile, cadata, ciphers=None, validate_certs=True): + cafile = self.ca_path or cafile + if self.ca_path: + cadata = None + else: + cadata = cadata or None + + return make_context(cafile=cafile, cadata=cadata, ciphers=ciphers, validate_certs=validate_certs) + + def http_request(self, req): + tmp_ca_cert_path, cadata, paths_checked = self.get_ca_certs() + + # Detect if 'no_proxy' environment variable is set and if our URL is included + use_proxy = self.detect_no_proxy(req.get_full_url()) + https_proxy = os.environ.get('https_proxy') + + context = None + try: + context = self.make_context(tmp_ca_cert_path, cadata, ciphers=self.ciphers, validate_certs=self.validate_certs) + except NotImplementedError: + # We'll make do with no context below + pass + + try: + if use_proxy and https_proxy: + proxy_parts = generic_urlparse(urlparse(https_proxy)) + port = proxy_parts.get('port') or 443 + proxy_hostname = proxy_parts.get('hostname', None) + if proxy_hostname is None or proxy_parts.get('scheme') == '': + raise ProxyError("Failed to parse https_proxy environment variable." + " Please make sure you export https proxy as 'https_proxy=<SCHEME>://<IP_ADDRESS>:<PORT>'") + + s = socket.create_connection((proxy_hostname, port)) + if proxy_parts.get('scheme') == 'http': + s.sendall(to_bytes(self.CONNECT_COMMAND % (self.hostname, self.port), errors='surrogate_or_strict')) + if proxy_parts.get('username'): + credentials = "%s:%s" % (proxy_parts.get('username', ''), proxy_parts.get('password', '')) + s.sendall(b'Proxy-Authorization: Basic %s\r\n' % base64.b64encode(to_bytes(credentials, errors='surrogate_or_strict')).strip()) + s.sendall(b'\r\n') + connect_result = b"" + while connect_result.find(b"\r\n\r\n") <= 0: + connect_result += s.recv(4096) + # 128 kilobytes of headers should be enough for everyone. + if len(connect_result) > 131072: + raise ProxyError('Proxy sent too verbose headers. Only 128KiB allowed.') + self.validate_proxy_response(connect_result) + if context: + ssl_s = context.wrap_socket(s, server_hostname=self.hostname) + elif HAS_URLLIB3_SSL_WRAP_SOCKET: + ssl_s = ssl_wrap_socket(s, ca_certs=tmp_ca_cert_path, cert_reqs=ssl.CERT_REQUIRED, ssl_version=PROTOCOL, server_hostname=self.hostname) + else: + ssl_s = ssl.wrap_socket(s, ca_certs=tmp_ca_cert_path, cert_reqs=ssl.CERT_REQUIRED, ssl_version=PROTOCOL) + match_hostname(ssl_s.getpeercert(), self.hostname) + else: + raise ProxyError('Unsupported proxy scheme: %s. Currently ansible only supports HTTP proxies.' % proxy_parts.get('scheme')) + else: + s = socket.create_connection((self.hostname, self.port)) + if context: + ssl_s = context.wrap_socket(s, server_hostname=self.hostname) + elif HAS_URLLIB3_SSL_WRAP_SOCKET: + ssl_s = ssl_wrap_socket(s, ca_certs=tmp_ca_cert_path, cert_reqs=ssl.CERT_REQUIRED, ssl_version=PROTOCOL, server_hostname=self.hostname) + else: + ssl_s = ssl.wrap_socket(s, ca_certs=tmp_ca_cert_path, cert_reqs=ssl.CERT_REQUIRED, ssl_version=PROTOCOL) + match_hostname(ssl_s.getpeercert(), self.hostname) + # close the ssl connection + # ssl_s.unwrap() + s.close() + except (ssl.SSLError, CertificateError) as e: + build_ssl_validation_error(self.hostname, self.port, paths_checked, e) + except socket.error as e: + raise ConnectionError('Failed to connect to %s at port %s: %s' % (self.hostname, self.port, to_native(e))) + + return req + + https_request = http_request + + +def maybe_add_ssl_handler(url, validate_certs, ca_path=None, ciphers=None): + parsed = generic_urlparse(urlparse(url)) + if parsed.scheme == 'https' and validate_certs: + if not HAS_SSL: + raise NoSSLError('SSL validation is not available in your version of python. You can use validate_certs=False,' + ' however this is unsafe and not recommended') + + # create the SSL validation handler + return SSLValidationHandler(parsed.hostname, parsed.port or 443, ca_path=ca_path, ciphers=ciphers, validate_certs=validate_certs) + + +def getpeercert(response, binary_form=False): + """ Attempt to get the peer certificate of the response from urlopen. """ + # The response from urllib2.open() is different across Python 2 and 3 + if PY3: + socket = response.fp.raw._sock + else: + socket = response.fp._sock.fp._sock + + try: + return socket.getpeercert(binary_form) + except AttributeError: + pass # Not HTTPS + + +def get_channel_binding_cert_hash(certificate_der): + """ Gets the channel binding app data for a TLS connection using the peer cert. """ + if not HAS_CRYPTOGRAPHY: + return + + # Logic documented in RFC 5929 section 4 https://tools.ietf.org/html/rfc5929#section-4 + cert = x509.load_der_x509_certificate(certificate_der, default_backend()) + + hash_algorithm = None + try: + hash_algorithm = cert.signature_hash_algorithm + except UnsupportedAlgorithm: + pass + + # If the signature hash algorithm is unknown/unsupported or md5/sha1 we must use SHA256. + if not hash_algorithm or hash_algorithm.name in ['md5', 'sha1']: + hash_algorithm = hashes.SHA256() + + digest = hashes.Hash(hash_algorithm, default_backend()) + digest.update(certificate_der) + return digest.finalize() + + +def rfc2822_date_string(timetuple, zone='-0000'): + """Accepts a timetuple and optional zone which defaults to ``-0000`` + and returns a date string as specified by RFC 2822, e.g.: + + Fri, 09 Nov 2001 01:08:47 -0000 + + Copied from email.utils.formatdate and modified for separate use + """ + return '%s, %02d %s %04d %02d:%02d:%02d %s' % ( + ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'][timetuple[6]], + timetuple[2], + ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'][timetuple[1] - 1], + timetuple[0], timetuple[3], timetuple[4], timetuple[5], + zone) + + +class Request: + def __init__(self, headers=None, use_proxy=True, force=False, timeout=10, validate_certs=True, + url_username=None, url_password=None, http_agent=None, force_basic_auth=False, + follow_redirects='urllib2', client_cert=None, client_key=None, cookies=None, unix_socket=None, + ca_path=None, unredirected_headers=None, decompress=True, ciphers=None, use_netrc=True): + """This class works somewhat similarly to the ``Session`` class of from requests + by defining a cookiejar that an be used across requests as well as cascaded defaults that + can apply to repeated requests + + For documentation of params, see ``Request.open`` + + >>> from ansible.module_utils.urls import Request + >>> r = Request() + >>> r.open('GET', 'http://httpbin.org/cookies/set?k1=v1').read() + '{\n "cookies": {\n "k1": "v1"\n }\n}\n' + >>> r = Request(url_username='user', url_password='passwd') + >>> r.open('GET', 'http://httpbin.org/basic-auth/user/passwd').read() + '{\n "authenticated": true, \n "user": "user"\n}\n' + >>> r = Request(headers=dict(foo='bar')) + >>> r.open('GET', 'http://httpbin.org/get', headers=dict(baz='qux')).read() + + """ + + self.headers = headers or {} + if not isinstance(self.headers, dict): + raise ValueError("headers must be a dict: %r" % self.headers) + self.use_proxy = use_proxy + self.force = force + self.timeout = timeout + self.validate_certs = validate_certs + self.url_username = url_username + self.url_password = url_password + self.http_agent = http_agent + self.force_basic_auth = force_basic_auth + self.follow_redirects = follow_redirects + self.client_cert = client_cert + self.client_key = client_key + self.unix_socket = unix_socket + self.ca_path = ca_path + self.unredirected_headers = unredirected_headers + self.decompress = decompress + self.ciphers = ciphers + self.use_netrc = use_netrc + if isinstance(cookies, cookiejar.CookieJar): + self.cookies = cookies + else: + self.cookies = cookiejar.CookieJar() + + def _fallback(self, value, fallback): + if value is None: + return fallback + return value + + def open(self, method, url, data=None, headers=None, use_proxy=None, + force=None, last_mod_time=None, timeout=None, validate_certs=None, + url_username=None, url_password=None, http_agent=None, + force_basic_auth=None, follow_redirects=None, + client_cert=None, client_key=None, cookies=None, use_gssapi=False, + unix_socket=None, ca_path=None, unredirected_headers=None, decompress=None, + ciphers=None, use_netrc=None): + """ + Sends a request via HTTP(S) or FTP using urllib2 (Python2) or urllib (Python3) + + Does not require the module environment + + Returns :class:`HTTPResponse` object. + + :arg method: method for the request + :arg url: URL to request + + :kwarg data: (optional) bytes, or file-like object to send + in the body of the request + :kwarg headers: (optional) Dictionary of HTTP Headers to send with the + request + :kwarg use_proxy: (optional) Boolean of whether or not to use proxy + :kwarg force: (optional) Boolean of whether or not to set `cache-control: no-cache` header + :kwarg last_mod_time: (optional) Datetime object to use when setting If-Modified-Since header + :kwarg timeout: (optional) How long to wait for the server to send + data before giving up, as a float + :kwarg validate_certs: (optional) Booleani that controls whether we verify + the server's TLS certificate + :kwarg url_username: (optional) String of the user to use when authenticating + :kwarg url_password: (optional) String of the password to use when authenticating + :kwarg http_agent: (optional) String of the User-Agent to use in the request + :kwarg force_basic_auth: (optional) Boolean determining if auth header should be sent in the initial request + :kwarg follow_redirects: (optional) String of urllib2, all/yes, safe, none to determine how redirects are + followed, see RedirectHandlerFactory for more information + :kwarg client_cert: (optional) PEM formatted certificate chain file to be used for SSL client authentication. + This file can also include the key as well, and if the key is included, client_key is not required + :kwarg client_key: (optional) PEM formatted file that contains your private key to be used for SSL client + authentication. If client_cert contains both the certificate and key, this option is not required + :kwarg cookies: (optional) CookieJar object to send with the + request + :kwarg use_gssapi: (optional) Use GSSAPI handler of requests. + :kwarg unix_socket: (optional) String of file system path to unix socket file to use when establishing + connection to the provided url + :kwarg ca_path: (optional) String of file system path to CA cert bundle to use + :kwarg unredirected_headers: (optional) A list of headers to not attach on a redirected request + :kwarg decompress: (optional) Whether to attempt to decompress gzip content-encoded responses + :kwarg ciphers: (optional) List of ciphers to use + :kwarg use_netrc: (optional) Boolean determining whether to use credentials from ~/.netrc file + :returns: HTTPResponse. Added in Ansible 2.9 + """ + + method = method.upper() + + if headers is None: + headers = {} + elif not isinstance(headers, dict): + raise ValueError("headers must be a dict") + headers = dict(self.headers, **headers) + + use_proxy = self._fallback(use_proxy, self.use_proxy) + force = self._fallback(force, self.force) + timeout = self._fallback(timeout, self.timeout) + validate_certs = self._fallback(validate_certs, self.validate_certs) + url_username = self._fallback(url_username, self.url_username) + url_password = self._fallback(url_password, self.url_password) + http_agent = self._fallback(http_agent, self.http_agent) + force_basic_auth = self._fallback(force_basic_auth, self.force_basic_auth) + follow_redirects = self._fallback(follow_redirects, self.follow_redirects) + client_cert = self._fallback(client_cert, self.client_cert) + client_key = self._fallback(client_key, self.client_key) + cookies = self._fallback(cookies, self.cookies) + unix_socket = self._fallback(unix_socket, self.unix_socket) + ca_path = self._fallback(ca_path, self.ca_path) + unredirected_headers = self._fallback(unredirected_headers, self.unredirected_headers) + decompress = self._fallback(decompress, self.decompress) + ciphers = self._fallback(ciphers, self.ciphers) + use_netrc = self._fallback(use_netrc, self.use_netrc) + + handlers = [] + + if unix_socket: + handlers.append(UnixHTTPHandler(unix_socket)) + + parsed = generic_urlparse(urlparse(url)) + if parsed.scheme != 'ftp': + username = url_username + password = url_password + + if username: + netloc = parsed.netloc + elif '@' in parsed.netloc: + credentials, netloc = parsed.netloc.split('@', 1) + if ':' in credentials: + username, password = credentials.split(':', 1) + else: + username = credentials + password = '' + + parsed_list = parsed.as_list() + parsed_list[1] = netloc + + # reconstruct url without credentials + url = urlunparse(parsed_list) + + if use_gssapi: + if HTTPGSSAPIAuthHandler: + handlers.append(HTTPGSSAPIAuthHandler(username, password)) + else: + imp_err_msg = missing_required_lib('gssapi', reason='for use_gssapi=True', + url='https://pypi.org/project/gssapi/') + raise MissingModuleError(imp_err_msg, import_traceback=GSSAPI_IMP_ERR) + + elif username and not force_basic_auth: + passman = urllib_request.HTTPPasswordMgrWithDefaultRealm() + + # this creates a password manager + passman.add_password(None, netloc, username, password) + + # because we have put None at the start it will always + # use this username/password combination for urls + # for which `theurl` is a super-url + authhandler = urllib_request.HTTPBasicAuthHandler(passman) + digest_authhandler = urllib_request.HTTPDigestAuthHandler(passman) + + # create the AuthHandler + handlers.append(authhandler) + handlers.append(digest_authhandler) + + elif username and force_basic_auth: + headers["Authorization"] = basic_auth_header(username, password) + + elif use_netrc: + try: + rc = netrc.netrc(os.environ.get('NETRC')) + login = rc.authenticators(parsed.hostname) + except IOError: + login = None + + if login: + username, _, password = login + if username and password: + headers["Authorization"] = basic_auth_header(username, password) + + if not use_proxy: + proxyhandler = urllib_request.ProxyHandler({}) + handlers.append(proxyhandler) + + if not any((HAS_SSLCONTEXT, HAS_URLLIB3_PYOPENSSLCONTEXT)): + ssl_handler = maybe_add_ssl_handler(url, validate_certs, ca_path=ca_path, ciphers=ciphers) + if ssl_handler: + handlers.append(ssl_handler) + else: + tmp_ca_path, cadata, paths_checked = get_ca_certs(ca_path) + context = make_context( + cafile=tmp_ca_path, + cadata=cadata, + ciphers=ciphers, + validate_certs=validate_certs, + ) + handlers.append(HTTPSClientAuthHandler(client_cert=client_cert, + client_key=client_key, + unix_socket=unix_socket, + context=context)) + + handlers.append(RedirectHandlerFactory(follow_redirects, validate_certs, ca_path=ca_path, ciphers=ciphers)) + + # add some nicer cookie handling + if cookies is not None: + handlers.append(urllib_request.HTTPCookieProcessor(cookies)) + + opener = urllib_request.build_opener(*handlers) + urllib_request.install_opener(opener) + + data = to_bytes(data, nonstring='passthru') + request = RequestWithMethod(url, method, data) + + # add the custom agent header, to help prevent issues + # with sites that block the default urllib agent string + if http_agent: + request.add_header('User-agent', http_agent) + + # Cache control + # Either we directly force a cache refresh + if force: + request.add_header('cache-control', 'no-cache') + # or we do it if the original is more recent than our copy + elif last_mod_time: + tstamp = rfc2822_date_string(last_mod_time.timetuple(), 'GMT') + request.add_header('If-Modified-Since', tstamp) + + # user defined headers now, which may override things we've set above + unredirected_headers = [h.lower() for h in (unredirected_headers or [])] + for header in headers: + if header.lower() in unredirected_headers: + request.add_unredirected_header(header, headers[header]) + else: + request.add_header(header, headers[header]) + + r = urllib_request.urlopen(request, None, timeout) + if decompress and r.headers.get('content-encoding', '').lower() == 'gzip': + fp = GzipDecodedReader(r.fp) + if PY3: + r.fp = fp + # Content-Length does not match gzip decoded length + # Prevent ``r.read`` from stopping at Content-Length + r.length = None + else: + # Py2 maps ``r.read`` to ``fp.read``, create new ``addinfourl`` + # object to compensate + msg = r.msg + r = urllib_request.addinfourl( + fp, + r.info(), + r.geturl(), + r.getcode() + ) + r.msg = msg + return r + + def get(self, url, **kwargs): + r"""Sends a GET request. Returns :class:`HTTPResponse` object. + + :arg url: URL to request + :kwarg \*\*kwargs: Optional arguments that ``open`` takes. + :returns: HTTPResponse + """ + + return self.open('GET', url, **kwargs) + + def options(self, url, **kwargs): + r"""Sends a OPTIONS request. Returns :class:`HTTPResponse` object. + + :arg url: URL to request + :kwarg \*\*kwargs: Optional arguments that ``open`` takes. + :returns: HTTPResponse + """ + + return self.open('OPTIONS', url, **kwargs) + + def head(self, url, **kwargs): + r"""Sends a HEAD request. Returns :class:`HTTPResponse` object. + + :arg url: URL to request + :kwarg \*\*kwargs: Optional arguments that ``open`` takes. + :returns: HTTPResponse + """ + + return self.open('HEAD', url, **kwargs) + + def post(self, url, data=None, **kwargs): + r"""Sends a POST request. Returns :class:`HTTPResponse` object. + + :arg url: URL to request. + :kwarg data: (optional) bytes, or file-like object to send in the body of the request. + :kwarg \*\*kwargs: Optional arguments that ``open`` takes. + :returns: HTTPResponse + """ + + return self.open('POST', url, data=data, **kwargs) + + def put(self, url, data=None, **kwargs): + r"""Sends a PUT request. Returns :class:`HTTPResponse` object. + + :arg url: URL to request. + :kwarg data: (optional) bytes, or file-like object to send in the body of the request. + :kwarg \*\*kwargs: Optional arguments that ``open`` takes. + :returns: HTTPResponse + """ + + return self.open('PUT', url, data=data, **kwargs) + + def patch(self, url, data=None, **kwargs): + r"""Sends a PATCH request. Returns :class:`HTTPResponse` object. + + :arg url: URL to request. + :kwarg data: (optional) bytes, or file-like object to send in the body of the request. + :kwarg \*\*kwargs: Optional arguments that ``open`` takes. + :returns: HTTPResponse + """ + + return self.open('PATCH', url, data=data, **kwargs) + + def delete(self, url, **kwargs): + r"""Sends a DELETE request. Returns :class:`HTTPResponse` object. + + :arg url: URL to request + :kwargs \*\*kwargs: Optional arguments that ``open`` takes. + :returns: HTTPResponse + """ + + return self.open('DELETE', url, **kwargs) + + +def open_url(url, data=None, headers=None, method=None, use_proxy=True, + force=False, last_mod_time=None, timeout=10, validate_certs=True, + url_username=None, url_password=None, http_agent=None, + force_basic_auth=False, follow_redirects='urllib2', + client_cert=None, client_key=None, cookies=None, + use_gssapi=False, unix_socket=None, ca_path=None, + unredirected_headers=None, decompress=True, ciphers=None, use_netrc=True): + ''' + Sends a request via HTTP(S) or FTP using urllib2 (Python2) or urllib (Python3) + + Does not require the module environment + ''' + method = method or ('POST' if data else 'GET') + return Request().open(method, url, data=data, headers=headers, use_proxy=use_proxy, + force=force, last_mod_time=last_mod_time, timeout=timeout, validate_certs=validate_certs, + url_username=url_username, url_password=url_password, http_agent=http_agent, + force_basic_auth=force_basic_auth, follow_redirects=follow_redirects, + client_cert=client_cert, client_key=client_key, cookies=cookies, + use_gssapi=use_gssapi, unix_socket=unix_socket, ca_path=ca_path, + unredirected_headers=unredirected_headers, decompress=decompress, ciphers=ciphers, use_netrc=use_netrc) + + +def prepare_multipart(fields): + """Takes a mapping, and prepares a multipart/form-data body + + :arg fields: Mapping + :returns: tuple of (content_type, body) where ``content_type`` is + the ``multipart/form-data`` ``Content-Type`` header including + ``boundary`` and ``body`` is the prepared bytestring body + + Payload content from a file will be base64 encoded and will include + the appropriate ``Content-Transfer-Encoding`` and ``Content-Type`` + headers. + + Example: + { + "file1": { + "filename": "/bin/true", + "mime_type": "application/octet-stream" + }, + "file2": { + "content": "text based file content", + "filename": "fake.txt", + "mime_type": "text/plain", + }, + "text_form_field": "value" + } + """ + + if not isinstance(fields, Mapping): + raise TypeError( + 'Mapping is required, cannot be type %s' % fields.__class__.__name__ + ) + + m = email.mime.multipart.MIMEMultipart('form-data') + for field, value in sorted(fields.items()): + if isinstance(value, string_types): + main_type = 'text' + sub_type = 'plain' + content = value + filename = None + elif isinstance(value, Mapping): + filename = value.get('filename') + content = value.get('content') + if not any((filename, content)): + raise ValueError('at least one of filename or content must be provided') + + mime = value.get('mime_type') + if not mime: + try: + mime = mimetypes.guess_type(filename or '', strict=False)[0] or 'application/octet-stream' + except Exception: + mime = 'application/octet-stream' + main_type, sep, sub_type = mime.partition('/') + else: + raise TypeError( + 'value must be a string, or mapping, cannot be type %s' % value.__class__.__name__ + ) + + if not content and filename: + with open(to_bytes(filename, errors='surrogate_or_strict'), 'rb') as f: + part = email.mime.application.MIMEApplication(f.read()) + del part['Content-Type'] + part.add_header('Content-Type', '%s/%s' % (main_type, sub_type)) + else: + part = email.mime.nonmultipart.MIMENonMultipart(main_type, sub_type) + part.set_payload(to_bytes(content)) + + part.add_header('Content-Disposition', 'form-data') + del part['MIME-Version'] + part.set_param( + 'name', + field, + header='Content-Disposition' + ) + if filename: + part.set_param( + 'filename', + to_native(os.path.basename(filename)), + header='Content-Disposition' + ) + + m.attach(part) + + if PY3: + # Ensure headers are not split over multiple lines + # The HTTP policy also uses CRLF by default + b_data = m.as_bytes(policy=email.policy.HTTP) + else: + # Py2 + # We cannot just call ``as_string`` since it provides no way + # to specify ``maxheaderlen`` + fp = cStringIO() # cStringIO seems to be required here + # Ensure headers are not split over multiple lines + g = email.generator.Generator(fp, maxheaderlen=0) + g.flatten(m) + # ``fix_eols`` switches from ``\n`` to ``\r\n`` + b_data = email.utils.fix_eols(fp.getvalue()) + del m + + headers, sep, b_content = b_data.partition(b'\r\n\r\n') + del b_data + + if PY3: + parser = email.parser.BytesHeaderParser().parsebytes + else: + # Py2 + parser = email.parser.HeaderParser().parsestr + + return ( + parser(headers)['content-type'], # Message converts to native strings + b_content + ) + + +# +# Module-related functions +# + + +def basic_auth_header(username, password): + """Takes a username and password and returns a byte string suitable for + using as value of an Authorization header to do basic auth. + """ + if password is None: + password = '' + return b"Basic %s" % base64.b64encode(to_bytes("%s:%s" % (username, password), errors='surrogate_or_strict')) + + +def url_argument_spec(): + ''' + Creates an argument spec that can be used with any module + that will be requesting content via urllib/urllib2 + ''' + return dict( + url=dict(type='str'), + force=dict(type='bool', default=False), + http_agent=dict(type='str', default='ansible-httpget'), + use_proxy=dict(type='bool', default=True), + validate_certs=dict(type='bool', default=True), + url_username=dict(type='str'), + url_password=dict(type='str', no_log=True), + force_basic_auth=dict(type='bool', default=False), + client_cert=dict(type='path'), + client_key=dict(type='path'), + use_gssapi=dict(type='bool', default=False), + ) + + +def fetch_url(module, url, data=None, headers=None, method=None, + use_proxy=None, force=False, last_mod_time=None, timeout=10, + use_gssapi=False, unix_socket=None, ca_path=None, cookies=None, unredirected_headers=None, + decompress=True, ciphers=None, use_netrc=True): + """Sends a request via HTTP(S) or FTP (needs the module as parameter) + + :arg module: The AnsibleModule (used to get username, password etc. (s.b.). + :arg url: The url to use. + + :kwarg data: The data to be sent (in case of POST/PUT). + :kwarg headers: A dict with the request headers. + :kwarg method: "POST", "PUT", etc. + :kwarg use_proxy: (optional) whether or not to use proxy (Default: True) + :kwarg boolean force: If True: Do not get a cached copy (Default: False) + :kwarg last_mod_time: Default: None + :kwarg int timeout: Default: 10 + :kwarg boolean use_gssapi: Default: False + :kwarg unix_socket: (optional) String of file system path to unix socket file to use when establishing + connection to the provided url + :kwarg ca_path: (optional) String of file system path to CA cert bundle to use + :kwarg cookies: (optional) CookieJar object to send with the request + :kwarg unredirected_headers: (optional) A list of headers to not attach on a redirected request + :kwarg decompress: (optional) Whether to attempt to decompress gzip content-encoded responses + :kwarg cipher: (optional) List of ciphers to use + :kwarg boolean use_netrc: (optional) If False: Ignores login and password in ~/.netrc file (Default: True) + + :returns: A tuple of (**response**, **info**). Use ``response.read()`` to read the data. + The **info** contains the 'status' and other meta data. When a HttpError (status >= 400) + occurred then ``info['body']`` contains the error response data:: + + Example:: + + data={...} + resp, info = fetch_url(module, + "http://example.com", + data=module.jsonify(data), + headers={'Content-type': 'application/json'}, + method="POST") + status_code = info["status"] + body = resp.read() + if status_code >= 400 : + body = info['body'] + """ + + if not HAS_URLPARSE: + module.fail_json(msg='urlparse is not installed') + + if not HAS_GZIP and decompress is True: + decompress = False + module.deprecate( + '%s. "decompress" has been automatically disabled to prevent a failure' % GzipDecodedReader.missing_gzip_error(), + version='2.16' + ) + + # ensure we use proper tempdir + old_tempdir = tempfile.tempdir + tempfile.tempdir = module.tmpdir + + # Get validate_certs from the module params + validate_certs = module.params.get('validate_certs', True) + + if use_proxy is None: + use_proxy = module.params.get('use_proxy', True) + + username = module.params.get('url_username', '') + password = module.params.get('url_password', '') + http_agent = module.params.get('http_agent', 'ansible-httpget') + force_basic_auth = module.params.get('force_basic_auth', '') + + follow_redirects = module.params.get('follow_redirects', 'urllib2') + + client_cert = module.params.get('client_cert') + client_key = module.params.get('client_key') + use_gssapi = module.params.get('use_gssapi', use_gssapi) + + if not isinstance(cookies, cookiejar.CookieJar): + cookies = cookiejar.LWPCookieJar() + + r = None + info = dict(url=url, status=-1) + try: + r = open_url(url, data=data, headers=headers, method=method, + use_proxy=use_proxy, force=force, last_mod_time=last_mod_time, timeout=timeout, + validate_certs=validate_certs, url_username=username, + url_password=password, http_agent=http_agent, force_basic_auth=force_basic_auth, + follow_redirects=follow_redirects, client_cert=client_cert, + client_key=client_key, cookies=cookies, use_gssapi=use_gssapi, + unix_socket=unix_socket, ca_path=ca_path, unredirected_headers=unredirected_headers, + decompress=decompress, ciphers=ciphers, use_netrc=use_netrc) + # Lowercase keys, to conform to py2 behavior, so that py3 and py2 are predictable + info.update(dict((k.lower(), v) for k, v in r.info().items())) + + # Don't be lossy, append header values for duplicate headers + # In Py2 there is nothing that needs done, py2 does this for us + if PY3: + temp_headers = {} + for name, value in r.headers.items(): + # The same as above, lower case keys to match py2 behavior, and create more consistent results + name = name.lower() + if name in temp_headers: + temp_headers[name] = ', '.join((temp_headers[name], value)) + else: + temp_headers[name] = value + info.update(temp_headers) + + # parse the cookies into a nice dictionary + cookie_list = [] + cookie_dict = dict() + # Python sorts cookies in order of most specific (ie. longest) path first. See ``CookieJar._cookie_attrs`` + # Cookies with the same path are reversed from response order. + # This code makes no assumptions about that, and accepts the order given by python + for cookie in cookies: + cookie_dict[cookie.name] = cookie.value + cookie_list.append((cookie.name, cookie.value)) + info['cookies_string'] = '; '.join('%s=%s' % c for c in cookie_list) + + info['cookies'] = cookie_dict + # finally update the result with a message about the fetch + info.update(dict(msg="OK (%s bytes)" % r.headers.get('Content-Length', 'unknown'), url=r.geturl(), status=r.code)) + except NoSSLError as e: + distribution = get_distribution() + if distribution is not None and distribution.lower() == 'redhat': + module.fail_json(msg='%s. You can also install python-ssl from EPEL' % to_native(e), **info) + else: + module.fail_json(msg='%s' % to_native(e), **info) + except (ConnectionError, ValueError) as e: + module.fail_json(msg=to_native(e), **info) + except MissingModuleError as e: + module.fail_json(msg=to_text(e), exception=e.import_traceback) + except urllib_error.HTTPError as e: + r = e + try: + if e.fp is None: + # Certain HTTPError objects may not have the ability to call ``.read()`` on Python 3 + # This is not handled gracefully in Python 3, and instead an exception is raised from + # tempfile, due to ``urllib.response.addinfourl`` not being initialized + raise AttributeError + body = e.read() + except AttributeError: + body = '' + else: + e.close() + + # Try to add exception info to the output but don't fail if we can't + try: + # Lowercase keys, to conform to py2 behavior, so that py3 and py2 are predictable + info.update(dict((k.lower(), v) for k, v in e.info().items())) + except Exception: + pass + + info.update({'msg': to_native(e), 'body': body, 'status': e.code}) + + except urllib_error.URLError as e: + code = int(getattr(e, 'code', -1)) + info.update(dict(msg="Request failed: %s" % to_native(e), status=code)) + except socket.error as e: + info.update(dict(msg="Connection failure: %s" % to_native(e), status=-1)) + except httplib.BadStatusLine as e: + info.update(dict(msg="Connection failure: connection was closed before a valid response was received: %s" % to_native(e.line), status=-1)) + except Exception as e: + info.update(dict(msg="An unknown error occurred: %s" % to_native(e), status=-1), + exception=traceback.format_exc()) + finally: + tempfile.tempdir = old_tempdir + + return r, info + + +def _suffixes(name): + """A list of the final component's suffixes, if any.""" + if name.endswith('.'): + return [] + name = name.lstrip('.') + return ['.' + s for s in name.split('.')[1:]] + + +def _split_multiext(name, min=3, max=4, count=2): + """Split a multi-part extension from a file name. + + Returns '([name minus extension], extension)'. + + Define the valid extension length (including the '.') with 'min' and 'max', + 'count' sets the number of extensions, counting from the end, to evaluate. + Evaluation stops on the first file extension that is outside the min and max range. + + If no valid extensions are found, the original ``name`` is returned + and ``extension`` is empty. + + :arg name: File name or path. + :kwarg min: Minimum length of a valid file extension. + :kwarg max: Maximum length of a valid file extension. + :kwarg count: Number of suffixes from the end to evaluate. + + """ + extension = '' + for i, sfx in enumerate(reversed(_suffixes(name))): + if i >= count: + break + + if min <= len(sfx) <= max: + extension = '%s%s' % (sfx, extension) + name = name.rstrip(sfx) + else: + # Stop on the first invalid extension + break + + return name, extension + + +def fetch_file(module, url, data=None, headers=None, method=None, + use_proxy=True, force=False, last_mod_time=None, timeout=10, + unredirected_headers=None, decompress=True, ciphers=None): + '''Download and save a file via HTTP(S) or FTP (needs the module as parameter). + This is basically a wrapper around fetch_url(). + + :arg module: The AnsibleModule (used to get username, password etc. (s.b.). + :arg url: The url to use. + + :kwarg data: The data to be sent (in case of POST/PUT). + :kwarg headers: A dict with the request headers. + :kwarg method: "POST", "PUT", etc. + :kwarg boolean use_proxy: Default: True + :kwarg boolean force: If True: Do not get a cached copy (Default: False) + :kwarg last_mod_time: Default: None + :kwarg int timeout: Default: 10 + :kwarg unredirected_headers: (optional) A list of headers to not attach on a redirected request + :kwarg decompress: (optional) Whether to attempt to decompress gzip content-encoded responses + :kwarg ciphers: (optional) List of ciphers to use + + :returns: A string, the path to the downloaded file. + ''' + # download file + bufsize = 65536 + parts = urlparse(url) + file_prefix, file_ext = _split_multiext(os.path.basename(parts.path), count=2) + fetch_temp_file = tempfile.NamedTemporaryFile(dir=module.tmpdir, prefix=file_prefix, suffix=file_ext, delete=False) + module.add_cleanup_file(fetch_temp_file.name) + try: + rsp, info = fetch_url(module, url, data, headers, method, use_proxy, force, last_mod_time, timeout, + unredirected_headers=unredirected_headers, decompress=decompress, ciphers=ciphers) + if not rsp: + module.fail_json(msg="Failure downloading %s, %s" % (url, info['msg'])) + data = rsp.read(bufsize) + while data: + fetch_temp_file.write(data) + data = rsp.read(bufsize) + fetch_temp_file.close() + except Exception as e: + module.fail_json(msg="Failure downloading %s, %s" % (url, to_native(e))) + return fetch_temp_file.name diff --git a/lib/ansible/module_utils/yumdnf.py b/lib/ansible/module_utils/yumdnf.py new file mode 100644 index 0000000..e265a2d --- /dev/null +++ b/lib/ansible/module_utils/yumdnf.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +# +# # Copyright: (c) 2012, Red Hat, Inc +# Written by Seth Vidal <skvidal at fedoraproject.org> +# Contributing Authors: +# - Ansible Core Team +# - Eduard Snesarev (@verm666) +# - Berend De Schouwer (@berenddeschouwer) +# - Abhijeet Kasurde (@Akasurde) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import time +import glob +import tempfile +from abc import ABCMeta, abstractmethod + +from ansible.module_utils._text import to_native +from ansible.module_utils.six import with_metaclass + +yumdnf_argument_spec = dict( + argument_spec=dict( + allow_downgrade=dict(type='bool', default=False), + autoremove=dict(type='bool', default=False), + bugfix=dict(required=False, type='bool', default=False), + cacheonly=dict(type='bool', default=False), + conf_file=dict(type='str'), + disable_excludes=dict(type='str', default=None), + disable_gpg_check=dict(type='bool', default=False), + disable_plugin=dict(type='list', elements='str', default=[]), + disablerepo=dict(type='list', elements='str', default=[]), + download_only=dict(type='bool', default=False), + download_dir=dict(type='str', default=None), + enable_plugin=dict(type='list', elements='str', default=[]), + enablerepo=dict(type='list', elements='str', default=[]), + exclude=dict(type='list', elements='str', default=[]), + installroot=dict(type='str', default="/"), + install_repoquery=dict(type='bool', default=True), + install_weak_deps=dict(type='bool', default=True), + list=dict(type='str'), + name=dict(type='list', elements='str', aliases=['pkg'], default=[]), + releasever=dict(default=None), + security=dict(type='bool', default=False), + skip_broken=dict(type='bool', default=False), + # removed==absent, installed==present, these are accepted as aliases + state=dict(type='str', default=None, choices=['absent', 'installed', 'latest', 'present', 'removed']), + update_cache=dict(type='bool', default=False, aliases=['expire-cache']), + update_only=dict(required=False, default="no", type='bool'), + validate_certs=dict(type='bool', default=True), + sslverify=dict(type='bool', default=True), + lock_timeout=dict(type='int', default=30), + ), + required_one_of=[['name', 'list', 'update_cache']], + mutually_exclusive=[['name', 'list']], + supports_check_mode=True, +) + + +class YumDnf(with_metaclass(ABCMeta, object)): # type: ignore[misc] + """ + Abstract class that handles the population of instance variables that should + be identical between both YUM and DNF modules because of the feature parity + and shared argument spec + """ + + def __init__(self, module): + + self.module = module + + self.allow_downgrade = self.module.params['allow_downgrade'] + self.autoremove = self.module.params['autoremove'] + self.bugfix = self.module.params['bugfix'] + self.cacheonly = self.module.params['cacheonly'] + self.conf_file = self.module.params['conf_file'] + self.disable_excludes = self.module.params['disable_excludes'] + self.disable_gpg_check = self.module.params['disable_gpg_check'] + self.disable_plugin = self.module.params['disable_plugin'] + self.disablerepo = self.module.params.get('disablerepo', []) + self.download_only = self.module.params['download_only'] + self.download_dir = self.module.params['download_dir'] + self.enable_plugin = self.module.params['enable_plugin'] + self.enablerepo = self.module.params.get('enablerepo', []) + self.exclude = self.module.params['exclude'] + self.installroot = self.module.params['installroot'] + self.install_repoquery = self.module.params['install_repoquery'] + self.install_weak_deps = self.module.params['install_weak_deps'] + self.list = self.module.params['list'] + self.names = [p.strip() for p in self.module.params['name']] + self.releasever = self.module.params['releasever'] + self.security = self.module.params['security'] + self.skip_broken = self.module.params['skip_broken'] + self.state = self.module.params['state'] + self.update_only = self.module.params['update_only'] + self.update_cache = self.module.params['update_cache'] + self.validate_certs = self.module.params['validate_certs'] + self.sslverify = self.module.params['sslverify'] + self.lock_timeout = self.module.params['lock_timeout'] + + # It's possible someone passed a comma separated string since it used + # to be a string type, so we should handle that + self.names = self.listify_comma_sep_strings_in_list(self.names) + self.disablerepo = self.listify_comma_sep_strings_in_list(self.disablerepo) + self.enablerepo = self.listify_comma_sep_strings_in_list(self.enablerepo) + self.exclude = self.listify_comma_sep_strings_in_list(self.exclude) + + # Fail if someone passed a space separated string + # https://github.com/ansible/ansible/issues/46301 + for name in self.names: + if ' ' in name and not any(spec in name for spec in ['@', '>', '<', '=']): + module.fail_json( + msg='It appears that a space separated string of packages was passed in ' + 'as an argument. To operate on several packages, pass a comma separated ' + 'string of packages or a list of packages.' + ) + + # Sanity checking for autoremove + if self.state is None: + if self.autoremove: + self.state = "absent" + else: + self.state = "present" + + if self.autoremove and (self.state != "absent"): + self.module.fail_json( + msg="Autoremove should be used alone or with state=absent", + results=[], + ) + + # This should really be redefined by both the yum and dnf module but a + # default isn't a bad idea + self.lockfile = '/var/run/yum.pid' + + @abstractmethod + def is_lockfile_pid_valid(self): + return + + def _is_lockfile_present(self): + return (os.path.isfile(self.lockfile) or glob.glob(self.lockfile)) and self.is_lockfile_pid_valid() + + def wait_for_lock(self): + '''Poll until the lock is removed if timeout is a positive number''' + + if not self._is_lockfile_present(): + return + + if self.lock_timeout > 0: + for iteration in range(0, self.lock_timeout): + time.sleep(1) + if not self._is_lockfile_present(): + return + + self.module.fail_json(msg='{0} lockfile is held by another process'.format(self.pkg_mgr_name)) + + def listify_comma_sep_strings_in_list(self, some_list): + """ + method to accept a list of strings as the parameter, find any strings + in that list that are comma separated, remove them from the list and add + their comma separated elements to the original list + """ + new_list = [] + remove_from_original_list = [] + for element in some_list: + if ',' in element: + remove_from_original_list.append(element) + new_list.extend([e.strip() for e in element.split(',')]) + + for element in remove_from_original_list: + some_list.remove(element) + + some_list.extend(new_list) + + if some_list == [""]: + return [] + + return some_list + + @abstractmethod + def run(self): + raise NotImplementedError |