diff options
Diffstat (limited to 'test/support')
135 files changed, 36289 insertions, 0 deletions
diff --git a/test/support/integration/plugins/filter/json_query.py b/test/support/integration/plugins/filter/json_query.py new file mode 100644 index 0000000..d1da71b --- /dev/null +++ b/test/support/integration/plugins/filter/json_query.py @@ -0,0 +1,53 @@ +# (c) 2015, Filipe Niero Felisbino <filipenf@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.errors import AnsibleError, AnsibleFilterError + +try: + import jmespath + HAS_LIB = True +except ImportError: + HAS_LIB = False + + +def json_query(data, expr): + '''Query data using jmespath query language ( http://jmespath.org ). Example: + - debug: msg="{{ instance | json_query(tagged_instances[*].block_device_mapping.*.volume_id') }}" + ''' + if not HAS_LIB: + raise AnsibleError('You need to install "jmespath" prior to running ' + 'json_query filter') + + try: + return jmespath.search(expr, data) + except jmespath.exceptions.JMESPathError as e: + raise AnsibleFilterError('JMESPathError in json_query filter plugin:\n%s' % e) + except Exception as e: + # For older jmespath, we can get ValueError and TypeError without much info. + raise AnsibleFilterError('Error in jmespath.search in json_query filter plugin:\n%s' % e) + + +class FilterModule(object): + ''' Query filter ''' + + def filters(self): + return { + 'json_query': json_query + } diff --git a/test/support/integration/plugins/module_utils/compat/__init__.py b/test/support/integration/plugins/module_utils/compat/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/support/integration/plugins/module_utils/compat/__init__.py diff --git a/test/support/integration/plugins/module_utils/compat/ipaddress.py b/test/support/integration/plugins/module_utils/compat/ipaddress.py new file mode 100644 index 0000000..c46ad72 --- /dev/null +++ b/test/support/integration/plugins/module_utils/compat/ipaddress.py @@ -0,0 +1,2476 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file, and this file only, is based on +# Lib/ipaddress.py of cpython +# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved" +# are retained in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. + +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +from __future__ import unicode_literals + + +import itertools +import struct + + +# The following makes it easier for us to script updates of the bundled code and is not part of +# upstream +_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"} + +__version__ = '1.0.22' + +# Compatibility functions +_compat_int_types = (int,) +try: + _compat_int_types = (int, long) +except NameError: + pass +try: + _compat_str = unicode +except NameError: + _compat_str = str + assert bytes != str +if b'\0'[0] == 0: # Python 3 semantics + def _compat_bytes_to_byte_vals(byt): + return byt +else: + def _compat_bytes_to_byte_vals(byt): + return [struct.unpack(b'!B', b)[0] for b in byt] +try: + _compat_int_from_byte_vals = int.from_bytes +except AttributeError: + def _compat_int_from_byte_vals(bytvals, endianess): + assert endianess == 'big' + res = 0 + for bv in bytvals: + assert isinstance(bv, _compat_int_types) + res = (res << 8) + bv + return res + + +def _compat_to_bytes(intval, length, endianess): + assert isinstance(intval, _compat_int_types) + assert endianess == 'big' + if length == 4: + if intval < 0 or intval >= 2 ** 32: + raise struct.error("integer out of range for 'I' format code") + return struct.pack(b'!I', intval) + elif length == 16: + if intval < 0 or intval >= 2 ** 128: + raise struct.error("integer out of range for 'QQ' format code") + return struct.pack(b'!QQ', intval >> 64, intval & 0xffffffffffffffff) + else: + raise NotImplementedError() + + +if hasattr(int, 'bit_length'): + # Not int.bit_length , since that won't work in 2.7 where long exists + def _compat_bit_length(i): + return i.bit_length() +else: + def _compat_bit_length(i): + for res in itertools.count(): + if i >> res == 0: + return res + + +def _compat_range(start, end, step=1): + assert step > 0 + i = start + while i < end: + yield i + i += step + + +class _TotalOrderingMixin(object): + __slots__ = () + + # Helper that derives the other comparison operations from + # __lt__ and __eq__ + # We avoid functools.total_ordering because it doesn't handle + # NotImplemented correctly yet (http://bugs.python.org/issue10042) + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + + def __lt__(self, other): + raise NotImplementedError + + def __le__(self, other): + less = self.__lt__(other) + if less is NotImplemented or not less: + return self.__eq__(other) + return less + + def __gt__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not (less or equal) + + def __ge__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + return not less + + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + '%r does not appear to be an IPv4 or IPv6 address. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?' % address) + + raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % + address) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + '%r does not appear to be an IPv4 or IPv6 network. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?' % address) + + raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % + address) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. + + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' % + address) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return _compat_to_bytes(address, 4, 'big') + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return _compat_to_bytes(address, 16, 'big') + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = _compat_str(address).split('/') + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) # pylint: disable=stop-iteration-return + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, _compat_bit_length(~number & (number - 1))) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. + + """ + if (not (isinstance(first, _BaseAddress) and + isinstance(last, _BaseAddress))): + raise TypeError('first and last must be IP addresses, not networks') + if first.version != last.version: + raise TypeError("%s and %s are not of the same version" % ( + first, last)) + if first > last: + raise ValueError('last IP address must be greater than first') + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError('unknown IP version') + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min(_count_righthand_zero_bits(first_int, ip_bits), + _compat_bit_length(last_int - first_int + 1) - 1) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, + # last.network_address <= net.network_address is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. + + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, nets[-1])) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase(_TotalOrderingMixin): + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return _compat_str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = '%200s has no version specified' % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._max_prefixlen, + self._version)) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = ( + '%r (len %d != %d) is not permitted as an IPv%d address. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?') + raise AddressValueError(msg % (address, address_len, + expected_len, self._version)) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. + + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits(ip_int, + cls._max_prefixlen) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = _compat_to_bytes(ip_int, byteslen, 'big') + msg = 'Netmask pattern %r mixes zeroes & ones' + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = '%r is not a valid netmask' % netmask_str + raise NetmaskValueError(msg) + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + def __reduce__(self): + return self.__class__, (_compat_str(self),) + + +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. + """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + try: + return (self._ip == other._ip and + self._version == other._version) + except AttributeError: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseAddress): + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self._ip != other._ip: + return self._ip < other._ip + return False + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return _compat_str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + +class _BaseNetwork(_IPAddressBase): + + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + + """ + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return '%s/%d' % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError('address out of range') + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError('address out of range') + return self._address_class(broadcast + n) + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseNetwork): + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __eq__(self, other): + try: + return (self._version == other._version and + self.network_address == other.network_address and + int(self.netmask) == int(other.netmask)) + except AttributeError: + return NotImplemented + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. + if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return (int(self.network_address) <= int(other._ip) <= + int(self.broadcast_address)) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other or ( + other.network_address in self or ( + other.broadcast_address in self))) + + @property + def broadcast_address(self): + x = self._cache.get('broadcast_address') + if x is None: + x = self._address_class(int(self.network_address) | + int(self.hostmask)) + self._cache['broadcast_address'] = x + return x + + @property + def hostmask(self): + x = self._cache.get('hostmask') + if x is None: + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) + self._cache['hostmask'] = x + return x + + @property + def with_prefixlen(self): + return '%s/%d' % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = '%200s has no associated address class' % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError("%s and %s are not of the same version" % ( + self, other)) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not other.subnet_of(self): + raise ValueError('%s not contained in %s' % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__('%s/%s' % (other.network_address, + other.prefixlen)) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if other.subnet_of(s1): + yield s2 + s1, s2 = s1.subnets() + elif other.subnet_of(s2): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. + OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError('new prefix must be longer') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError('prefix length diff must be > 0') + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + 'prefix length diff %d is invalid for netblock %s' % ( + new_prefixlen, self)) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in _compat_range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError('new prefix must be shorter') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % + (self.prefixlen, prefixlen_diff)) + return self.__class__(( + int(self.network_address) & (int(self.netmask) << prefixlen_diff), + new_prefixlen)) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return (self.network_address.is_multicast and + self.broadcast_address.is_multicast) + + @staticmethod + def _is_subnet_of(a, b): + try: + # Always false if one is v4 and the other is v6. + if a._version != b._version: + raise TypeError("%s and %s are not of the same version" % (a, b)) + return (b.network_address <= a.network_address and + b.broadcast_address >= a.broadcast_address) + except AttributeError: + raise TypeError("Unable to test subnet containment " + "between %s and %s" % (a, b)) + + def subnet_of(self, other): + """Return True if this network is a subnet of other.""" + return self._is_subnet_of(self, other) + + def supernet_of(self, other): + """Return True if this network is a supernet of other.""" + return self._is_subnet_of(other, self) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return (self.network_address.is_reserved and + self.broadcast_address.is_reserved) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return (self.network_address.is_link_local and + self.broadcast_address.is_link_local) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return (self.network_address.is_private and + self.broadcast_address.is_private) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return (self.network_address.is_unspecified and + self.broadcast_address.is_unspecified) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return (self.network_address.is_loopback and + self.broadcast_address.is_loopback) + + +class _BaseV4(object): + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2 ** IPV4LENGTH) - 1 + _DECIMAL_DIGITS = frozenset('0123456789') + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0]) + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return _compat_str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + octets = ip_str.split('.') + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return _compat_int_from_byte_vals( + map(cls._parse_octet, octets), 'big') + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._DECIMAL_DIGITS.issuperset(octet_str): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + # Any octets that look like they *might* be written in octal, + # and which don't look exactly the same in both octal and + # decimal are rejected as ambiguous + if octet_int > 7 and octet_str[0] == '0': + msg = "Ambiguous (octal/decimal) value in %r not permitted" + raise ValueError(msg % octet_str) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return '.'.join(_compat_str(struct.unpack(b'!B', b)[0] + if isinstance(b, bytes) + else b) + for b in _compat_to_bytes(ip_int, 4, 'big')) + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split('.') + try: + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = _compat_str(self).split('.')[::-1] + return '.'.join(reverse_octets) + '.in-addr.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in self._constants._reserved_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + return ( + self not in self._constants._public_network and + not self.is_private) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self == self._constants._unspecified_address + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in self._constants._loopback_network + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. + + """ + return self in self._constants._linklocal_network + + +class IPv4Interface(IPv4Address): + + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv4Address.__init__(self, address) + self.network = IPv4Network(self._ip) + self._prefixlen = self._max_prefixlen + return + + if isinstance(address, tuple): + IPv4Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + + self.network = IPv4Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv4Address.__init__(self, addr[0]) + + self.network = IPv4Network(address, strict=False) + self._prefixlen = self.network._prefixlen + + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv4Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + +class IPv4Network(_BaseV4, _BaseNetwork): + + """This class represents and manipulates 32-bit IPv4 network + addresses.. + + Attributes: [examples for IPv4Network('192.0.2.0/27')] + .network_address: IPv4Address('192.0.2.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast_address: IPv4Address('192.0.2.32') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + # Class to use when creating address objects + _address_class = IPv4Address + + def __init__(self, address, strict=True): + + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.0.2.0/24' + '192.0.2.0/255.255.255.0' + '192.0.0.2/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.0.2.1' + '192.0.2.1/255.255.255.255' + '192.0.2.1/32' + are also functionally equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.0.2.1') == IPv4Network(3221225985) + or, more generally + IPv4Interface(int(IPv4Interface('192.0.2.1'))) == + IPv4Interface('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict is True and a network address is not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (_compat_int_types, bytes)): + self.network_address = IPv4Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen) + # fixme: address/network test here. + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + # We weren't given an address[1] + arg = self._max_prefixlen + self.network_address = IPv4Address(address[0]) + self.netmask, self._prefixlen = self._make_netmask(arg) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv4Address(packed & + int(self.netmask)) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if (IPv4Address(int(self.network_address) & int(self.netmask)) != + self.network_address): + raise ValueError('%s has host bits set' % self) + self.network_address = IPv4Address(int(self.network_address) & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry. + + """ + return (not (self.network_address in IPv4Network('100.64.0.0/10') and + self.broadcast_address in IPv4Network('100.64.0.0/10')) and + not self.is_private) + + +class _IPv4Constants(object): + + _linklocal_network = IPv4Network('169.254.0.0/16') + + _loopback_network = IPv4Network('127.0.0.0/8') + + _multicast_network = IPv4Network('224.0.0.0/4') + + _public_network = IPv4Network('100.64.0.0/10') + + _private_networks = [ + IPv4Network('0.0.0.0/8'), + IPv4Network('10.0.0.0/8'), + IPv4Network('127.0.0.0/8'), + IPv4Network('169.254.0.0/16'), + IPv4Network('172.16.0.0/12'), + IPv4Network('192.0.0.0/29'), + IPv4Network('192.0.0.170/31'), + IPv4Network('192.0.2.0/24'), + IPv4Network('192.168.0.0/16'), + IPv4Network('198.18.0.0/15'), + IPv4Network('198.51.100.0/24'), + IPv4Network('203.0.113.0/24'), + IPv4Network('240.0.0.0/4'), + IPv4Network('255.255.255.255/32'), + ] + + _reserved_network = IPv4Network('240.0.0.0/4') + + _unspecified_address = IPv4Address('0.0.0.0') + + +IPv4Address._constants = _IPv4Constants + + +class _BaseV6(object): + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2 ** IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef') + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + parts = ip_str.split(':') + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if '.' in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF)) + parts.append('%x' % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % ( + _max_parts - 1, ip_str) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in _compat_range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. The + # endpoints could still be empty, but _parse_hextet() will check + # for that. + if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == '0': + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = (best_doublecolon_start + + best_doublecolon_len) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [''] + hextets[best_doublecolon_start:best_doublecolon_end] = [''] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [''] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. + + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError('IPv6 address is too large') + + hex_str = '%032x' % ip_int + hextets = ['%x' % int(hex_str[x:x + 4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ':'.join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = _compat_str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = _compat_str(self.ip) + else: + ip_str = _compat_str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = '%032x' % ip_int + parts = [hex_str[x:x + 4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return '%s/%d' % (':'.join(parts), self._prefixlen) + return ':'.join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(':', '') + return '.'.join(reverse_chars) + '.ip6.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001::/32) + + """ + if (self._ip >> 96) != 0x20010000: + return None + return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF), + IPv4Address(~self._ip & 0xFFFFFFFF)) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + if (self._ip >> 112) != 0x2002: + return None + return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) + + +class IPv6Interface(IPv6Address): + + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv6Address.__init__(self, address) + self.network = IPv6Network(self._ip) + self._prefixlen = self._max_prefixlen + return + if isinstance(address, tuple): + IPv6Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv6Address.__init__(self, addr[0]) + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self._prefixlen = self.network._prefixlen + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv6Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + @property + def is_unspecified(self): + return self._ip == 0 and self.network.is_unspecified + + @property + def is_loopback(self): + return self._ip == 1 and self.network.is_loopback + + +class IPv6Network(_BaseV6, _BaseNetwork): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:db8::1000/124')] + .network_address: IPv6Address('2001:db8::1000') + .hostmask: IPv6Address('::f') + .broadcast_address: IPv6Address('2001:db8::100f') + .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') + .prefixlen: 124 + + """ + + # Class to use when creating address objects + _address_class = IPv6Address + + def __init__(self, address, strict=True): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the + IP and prefix/netmask. + '2001:db8::/128' + '2001:db8:0000:0000:0000:0000:0000:0000/128' + '2001:db8::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:db8::') == + IPv6Network(42540766411282592856903984951653826560) + or, more generally + IPv6Network(int(IPv6Network('2001:db8::'))) == + IPv6Network('2001:db8::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 2001:db8::1000/124 and not an + IP address on a network, eg, 2001:db8::1/124. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Efficient constructor from integer or packed address + if isinstance(address, (bytes, _compat_int_types)): + self.network_address = IPv6Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen) + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + self.network_address = IPv6Address(address[0]) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv6Address(packed & + int(self.netmask)) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + + self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if (IPv6Address(int(self.network_address) & int(self.netmask)) != + self.network_address): + raise ValueError('%s has host bits set' % self) + self.network_address = IPv6Address(int(self.network_address) & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the + Subnet-Router anycast address. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast + 1): + yield self._address_class(x) + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return (self.network_address.is_site_local and + self.broadcast_address.is_site_local) + + +class _IPv6Constants(object): + + _linklocal_network = IPv6Network('fe80::/10') + + _multicast_network = IPv6Network('ff00::/8') + + _private_networks = [ + IPv6Network('::1/128'), + IPv6Network('::/128'), + IPv6Network('::ffff:0:0/96'), + IPv6Network('100::/64'), + IPv6Network('2001::/23'), + IPv6Network('2001:2::/48'), + IPv6Network('2001:db8::/32'), + IPv6Network('2001:10::/28'), + IPv6Network('fc00::/7'), + IPv6Network('fe80::/10'), + ] + + _reserved_networks = [ + IPv6Network('::/8'), IPv6Network('100::/8'), + IPv6Network('200::/7'), IPv6Network('400::/6'), + IPv6Network('800::/5'), IPv6Network('1000::/4'), + IPv6Network('4000::/3'), IPv6Network('6000::/3'), + IPv6Network('8000::/3'), IPv6Network('A000::/3'), + IPv6Network('C000::/3'), IPv6Network('E000::/4'), + IPv6Network('F000::/5'), IPv6Network('F800::/6'), + IPv6Network('FE00::/9'), + ] + + _sitelocal_network = IPv6Network('fec0::/10') + + +IPv6Address._constants = _IPv6Constants diff --git a/test/support/integration/plugins/module_utils/net_tools/__init__.py b/test/support/integration/plugins/module_utils/net_tools/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/support/integration/plugins/module_utils/net_tools/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/__init__.py b/test/support/integration/plugins/module_utils/network/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/common/__init__.py b/test/support/integration/plugins/module_utils/network/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/common/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/common/utils.py b/test/support/integration/plugins/module_utils/network/common/utils.py new file mode 100644 index 0000000..8031738 --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/common/utils.py @@ -0,0 +1,643 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Networking tools for network modules only + +import re +import ast +import operator +import socket +import json + +from itertools import chain + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils import basic +from ansible.module_utils.parsing.convert_bool import boolean + +# Backwards compatibility for 3rd party modules +# TODO(pabelanger): With move to ansible.netcommon, we should clean this code +# up and have modules import directly themself. +from ansible.module_utils.common.network import ( # noqa: F401 + to_bits, is_netmask, is_masklen, to_netmask, to_masklen, to_subnet, to_ipv6_network, VALID_MASKS +) + +try: + from jinja2 import Environment, StrictUndefined + from jinja2.exceptions import UndefinedError + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False + + +OPERATORS = frozenset(['ge', 'gt', 'eq', 'neq', 'lt', 'le']) +ALIASES = frozenset([('min', 'ge'), ('max', 'le'), ('exactly', 'eq'), ('neq', 'ne')]) + + +def to_list(val): + if isinstance(val, (list, tuple, set)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +def to_lines(stdout): + for item in stdout: + if isinstance(item, string_types): + item = to_text(item).split('\n') + yield item + + +def transform_commands(module): + transform = ComplexList(dict( + command=dict(key=True), + output=dict(), + prompt=dict(type='list'), + answer=dict(type='list'), + newline=dict(type='bool', default=True), + sendonly=dict(type='bool', default=False), + check_all=dict(type='bool', default=False), + ), module) + + return transform(module.params['commands']) + + +def sort_list(val): + if isinstance(val, list): + return sorted(val) + return val + + +class Entity(object): + """Transforms a dict to with an argument spec + + This class will take a dict and apply an Ansible argument spec to the + values. The resulting dict will contain all of the keys in the param + with appropriate values set. + + Example:: + + argument_spec = dict( + command=dict(key=True), + display=dict(default='text', choices=['text', 'json']), + validate=dict(type='bool') + ) + transform = Entity(module, argument_spec) + value = dict(command='foo') + result = transform(value) + print result + {'command': 'foo', 'display': 'text', 'validate': None} + + Supported argument spec: + * key - specifies how to map a single value to a dict + * read_from - read and apply the argument_spec from the module + * required - a value is required + * type - type of value (uses AnsibleModule type checker) + * fallback - implements fallback function + * choices - set of valid options + * default - default value + """ + + def __init__(self, module, attrs=None, args=None, keys=None, from_argspec=False): + args = [] if args is None else args + + self._attributes = attrs or {} + self._module = module + + for arg in args: + self._attributes[arg] = dict() + if from_argspec: + self._attributes[arg]['read_from'] = arg + if keys and arg in keys: + self._attributes[arg]['key'] = True + + self.attr_names = frozenset(self._attributes.keys()) + + _has_key = False + + for name, attr in iteritems(self._attributes): + if attr.get('read_from'): + if attr['read_from'] not in self._module.argument_spec: + module.fail_json(msg='argument %s does not exist' % attr['read_from']) + spec = self._module.argument_spec.get(attr['read_from']) + for key, value in iteritems(spec): + if key not in attr: + attr[key] = value + + if attr.get('key'): + if _has_key: + module.fail_json(msg='only one key value can be specified') + _has_key = True + attr['required'] = True + + def serialize(self): + return self._attributes + + def to_dict(self, value): + obj = {} + for name, attr in iteritems(self._attributes): + if attr.get('key'): + obj[name] = value + else: + obj[name] = attr.get('default') + return obj + + def __call__(self, value, strict=True): + if not isinstance(value, dict): + value = self.to_dict(value) + + if strict: + unknown = set(value).difference(self.attr_names) + if unknown: + self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown)) + + for name, attr in iteritems(self._attributes): + if value.get(name) is None: + value[name] = attr.get('default') + + if attr.get('fallback') and not value.get(name): + fallback = attr.get('fallback', (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + value[name] = fallback_strategy(*fallback_args, **fallback_kwargs) + except basic.AnsibleFallbackNotFound: + continue + + if attr.get('required') and value.get(name) is None: + self._module.fail_json(msg='missing required attribute %s' % name) + + if 'choices' in attr: + if value[name] not in attr['choices']: + self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name])) + + if value[name] is not None: + value_type = attr.get('type', 'str') + type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type] + type_checker(value[name]) + elif value.get(name): + value[name] = self._module.params[name] + + return value + + +class EntityCollection(Entity): + """Extends ```Entity``` to handle a list of dicts """ + + def __call__(self, iterable, strict=True): + if iterable is None: + iterable = [super(EntityCollection, self).__call__(self._module.params, strict)] + + if not isinstance(iterable, (list, tuple)): + self._module.fail_json(msg='value must be an iterable') + + return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable] + + +# these two are for backwards compatibility and can be removed once all of the +# modules that use them are updated +class ComplexDict(Entity): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexDict, self).__init__(module, attrs, *args, **kwargs) + + +class ComplexList(EntityCollection): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexList, self).__init__(module, attrs, *args, **kwargs) + + +def dict_diff(base, comparable): + """ Generate a dict object of differences + + This function will compare two dict objects and return the difference + between them as a dict object. For scalar values, the key will reflect + the updated value. If the key does not exist in `comparable`, then then no + key will be returned. For lists, the value in comparable will wholly replace + the value in base for the key. For dicts, the returned value will only + return keys that are different. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + + :returns: new dict object with differences + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(comparable, dict): + if comparable is None: + comparable = dict() + else: + raise AssertionError("`comparable` must be of type <dict>") + + updates = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + item = comparable.get(key) + if item is not None: + sub_diff = dict_diff(value, comparable[key]) + if sub_diff: + updates[key] = sub_diff + else: + comparable_value = comparable.get(key) + if comparable_value is not None: + if sort_list(base[key]) != sort_list(comparable_value): + updates[key] = comparable_value + + for key in set(comparable.keys()).difference(base.keys()): + updates[key] = comparable.get(key) + + return updates + + +def dict_merge(base, other): + """ Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. If the value is a list object, the two lists will + be combined and duplicate entries removed. + + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(other, dict): + raise AssertionError("`other` must be of type <dict>") + + combined = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + if isinstance(other[key], Mapping): + combined[key] = dict_merge(value, other[key]) + else: + combined[key] = other[key] + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if item is not None: + try: + combined[key] = list(set(chain(value, item))) + except TypeError: + value.extend([i for i in item if i not in value]) + combined[key] = value + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) + + return combined + + +def param_list_to_dict(param_list, unique_key="name", remove_key=True): + """Rotates a list of dictionaries to be a dictionary of dictionaries. + + :param param_list: The aforementioned list of dictionaries + :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value + behind this key will be the key each dictionary can be found at in the new root dictionary + :param remove_key: If True, remove unique_key from the individual dictionaries before returning. + """ + param_dict = {} + for params in param_list: + params = params.copy() + if remove_key: + name = params.pop(unique_key) + else: + name = params.get(unique_key) + param_dict[name] = params + + return param_dict + + +def conditional(expr, val, cast=None): + match = re.match(r'^(.+)\((.+)\)$', str(expr), re.I) + if match: + op, arg = match.groups() + else: + op = 'eq' + if ' ' in str(expr): + raise AssertionError('invalid expression: cannot contain spaces') + arg = expr + + if cast is None and val is not None: + arg = type(val)(arg) + elif callable(cast): + arg = cast(arg) + val = cast(val) + + op = next((oper for alias, oper in ALIASES if op == alias), op) + + if not hasattr(operator, op) and op not in OPERATORS: + raise ValueError('unknown operator: %s' % op) + + func = getattr(operator, op) + return func(val, arg) + + +def ternary(value, true_val, false_val): + ''' value ? true_val : false_val ''' + if value: + return true_val + else: + return false_val + + +def remove_default_spec(spec): + for item in spec: + if 'default' in spec[item]: + del spec[item]['default'] + + +def validate_ip_address(address): + try: + socket.inet_aton(address) + except socket.error: + return False + return address.count('.') == 3 + + +def validate_ip_v6_address(address): + try: + socket.inet_pton(socket.AF_INET6, address) + except socket.error: + return False + return True + + +def validate_prefix(prefix): + if prefix and not 0 <= int(prefix) <= 32: + return False + return True + + +def load_provider(spec, args): + provider = args.get('provider') or {} + for key, value in iteritems(spec): + if key not in provider: + if 'fallback' in value: + provider[key] = _fallback(value['fallback']) + elif 'default' in value: + provider[key] = value['default'] + else: + provider[key] = None + if 'authorize' in provider: + # Coerce authorize to provider if a string has somehow snuck in. + provider['authorize'] = boolean(provider['authorize'] or False) + args['provider'] = provider + return provider + + +def _fallback(fallback): + strategy = fallback[0] + args = [] + kwargs = {} + + for item in fallback[1:]: + if isinstance(item, dict): + kwargs = item + else: + args = item + try: + return strategy(*args, **kwargs) + except basic.AnsibleFallbackNotFound: + pass + + +def generate_dict(spec): + """ + Generate dictionary which is in sync with argspec + + :param spec: A dictionary that is the argspec of the module + :rtype: A dictionary + :returns: A dictionary in sync with argspec with default value + """ + obj = {} + if not spec: + return obj + + for key, val in iteritems(spec): + if 'default' in val: + dct = {key: val['default']} + elif 'type' in val and val['type'] == 'dict': + dct = {key: generate_dict(val['options'])} + else: + dct = {key: None} + obj.update(dct) + return obj + + +def parse_conf_arg(cfg, arg): + """ + Parse config based on argument + + :param cfg: A text string which is a line of configuration. + :param arg: A text string which is to be matched. + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r'%s (.+)(\n|$)' % arg, cfg, re.M) + if match: + result = match.group(1).strip() + else: + result = None + return result + + +def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str='no'): + """ + Parse config based on command + + :param cfg: A text string which is a line of configuration. + :param cmd: A text string which is the command to be matched + :param res1: A text string to be returned if the command is present + :param res2: A text string to be returned if the negate command + is present + :param delete_str: A text string to identify the start of the + negate command + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r'\n\s+%s(\n|$)' % cmd, cfg) + if match: + return res1 + if res2 is not None: + match = re.search(r'\n\s+%s %s(\n|$)' % (delete_str, cmd), cfg) + if match: + return res2 + return None + + +def get_xml_conf_arg(cfg, path, data='text'): + """ + :param cfg: The top level configuration lxml Element tree object + :param path: The relative xpath w.r.t to top level element (cfg) + to be searched in the xml hierarchy + :param data: The type of data to be returned for the matched xml node. + Valid values are text, tag, attrib, with default as text. + :return: Returns the required type for the matched xml node or else None + """ + match = cfg.xpath(path) + if len(match): + if data == 'tag': + result = getattr(match[0], 'tag') + elif data == 'attrib': + result = getattr(match[0], 'attrib') + else: + result = getattr(match[0], 'text') + else: + result = None + return result + + +def remove_empties(cfg_dict): + """ + Generate final config dictionary + + :param cfg_dict: A dictionary parsed in the facts system + :rtype: A dictionary + :returns: A dictionary by eliminating keys that have null values + """ + final_cfg = {} + if not cfg_dict: + return final_cfg + + for key, val in iteritems(cfg_dict): + dct = None + if isinstance(val, dict): + child_val = remove_empties(val) + if child_val: + dct = {key: child_val} + elif (isinstance(val, list) and val + and all([isinstance(x, dict) for x in val])): + child_val = [remove_empties(x) for x in val] + if child_val: + dct = {key: child_val} + elif val not in [None, [], {}, (), '']: + dct = {key: val} + if dct: + final_cfg.update(dct) + return final_cfg + + +def validate_config(spec, data): + """ + Validate if the input data against the AnsibleModule spec format + :param spec: Ansible argument spec + :param data: Data to be validated + :return: + """ + params = basic._ANSIBLE_ARGS + basic._ANSIBLE_ARGS = to_bytes(json.dumps({'ANSIBLE_MODULE_ARGS': data})) + validated_data = basic.AnsibleModule(spec).params + basic._ANSIBLE_ARGS = params + return validated_data + + +def search_obj_in_list(name, lst, key='name'): + if not lst: + return None + else: + for item in lst: + if item.get(key) == name: + return item + + +class Template: + + def __init__(self): + if not HAS_JINJA2: + raise ImportError("jinja2 is required but does not appear to be installed. " + "It can be installed using `pip install jinja2`") + + self.env = Environment(undefined=StrictUndefined) + self.env.filters.update({'ternary': ternary}) + + def __call__(self, value, variables=None, fail_on_undefined=True): + variables = variables or {} + + if not self.contains_vars(value): + return value + + try: + value = self.env.from_string(value).render(variables) + except UndefinedError: + if not fail_on_undefined: + return None + raise + + if value: + try: + return ast.literal_eval(value) + except Exception: + return str(value) + else: + return None + + def contains_vars(self, data): + if isinstance(data, string_types): + for marker in (self.env.block_start_string, self.env.variable_start_string, self.env.comment_start_string): + if marker in data: + return True + return False diff --git a/test/support/integration/plugins/modules/htpasswd.py b/test/support/integration/plugins/modules/htpasswd.py new file mode 100644 index 0000000..2c55a6b --- /dev/null +++ b/test/support/integration/plugins/modules/htpasswd.py @@ -0,0 +1,275 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, Nimbis Services, Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = """ +module: htpasswd +version_added: "1.3" +short_description: manage user files for basic authentication +description: + - Add and remove username/password entries in a password file using htpasswd. + - This is used by web servers such as Apache and Nginx for basic authentication. +options: + path: + required: true + aliases: [ dest, destfile ] + description: + - Path to the file that contains the usernames and passwords + name: + required: true + aliases: [ username ] + description: + - User name to add or remove + password: + required: false + description: + - Password associated with user. + - Must be specified if user does not exist yet. + crypt_scheme: + required: false + choices: ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"] + default: "apr_md5_crypt" + description: + - Encryption scheme to be used. As well as the four choices listed + here, you can also use any other hash supported by passlib, such as + md5_crypt and sha256_crypt, which are linux passwd hashes. If you + do so the password file will not be compatible with Apache or Nginx + state: + required: false + choices: [ present, absent ] + default: "present" + description: + - Whether the user entry should be present or not + create: + required: false + type: bool + default: "yes" + description: + - Used with C(state=present). If specified, the file will be created + if it does not already exist. If set to "no", will fail if the + file does not exist +notes: + - "This module depends on the I(passlib) Python library, which needs to be installed on all target systems." + - "On Debian, Ubuntu, or Fedora: install I(python-passlib)." + - "On RHEL or CentOS: Enable EPEL, then install I(python-passlib)." +requirements: [ passlib>=1.6 ] +author: "Ansible Core Team" +extends_documentation_fragment: files +""" + +EXAMPLES = """ +# Add a user to a password file and ensure permissions are set +- htpasswd: + path: /etc/nginx/passwdfile + name: janedoe + password: '9s36?;fyNp' + owner: root + group: www-data + mode: 0640 + +# Remove a user from a password file +- htpasswd: + path: /etc/apache2/passwdfile + name: foobar + state: absent + +# Add a user to a password file suitable for use by libpam-pwdfile +- htpasswd: + path: /etc/mail/passwords + name: alex + password: oedu2eGh + crypt_scheme: md5_crypt +""" + + +import os +import tempfile +import traceback +from ansible.module_utils.compat.version import LooseVersion +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native + +PASSLIB_IMP_ERR = None +try: + from passlib.apache import HtpasswdFile, htpasswd_context + from passlib.context import CryptContext + import passlib +except ImportError: + PASSLIB_IMP_ERR = traceback.format_exc() + passlib_installed = False +else: + passlib_installed = True + +apache_hashes = ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"] + + +def create_missing_directories(dest): + destpath = os.path.dirname(dest) + if not os.path.exists(destpath): + os.makedirs(destpath) + + +def present(dest, username, password, crypt_scheme, create, check_mode): + """ Ensures user is present + + Returns (msg, changed) """ + if crypt_scheme in apache_hashes: + context = htpasswd_context + else: + context = CryptContext(schemes=[crypt_scheme] + apache_hashes) + if not os.path.exists(dest): + if not create: + raise ValueError('Destination %s does not exist' % dest) + if check_mode: + return ("Create %s" % dest, True) + create_missing_directories(dest) + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=True, default_scheme=crypt_scheme, context=context) + else: + ht = HtpasswdFile(dest, autoload=False, default=crypt_scheme, context=context) + if getattr(ht, 'set_password', None): + ht.set_password(username, password) + else: + ht.update(username, password) + ht.save() + return ("Created %s and added %s" % (dest, username), True) + else: + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=False, default_scheme=crypt_scheme, context=context) + else: + ht = HtpasswdFile(dest, default=crypt_scheme, context=context) + + found = None + if getattr(ht, 'check_password', None): + found = ht.check_password(username, password) + else: + found = ht.verify(username, password) + + if found: + return ("%s already present" % username, False) + else: + if not check_mode: + if getattr(ht, 'set_password', None): + ht.set_password(username, password) + else: + ht.update(username, password) + ht.save() + return ("Add/update %s" % username, True) + + +def absent(dest, username, check_mode): + """ Ensures user is absent + + Returns (msg, changed) """ + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=False) + else: + ht = HtpasswdFile(dest) + + if username not in ht.users(): + return ("%s not present" % username, False) + else: + if not check_mode: + ht.delete(username) + ht.save() + return ("Remove %s" % username, True) + + +def check_file_attrs(module, changed, message): + + file_args = module.load_file_common_arguments(module.params) + if module.set_fs_attributes_if_different(file_args, False): + + if changed: + message += " and " + changed = True + message += "ownership, perms or SE linux context changed" + + return message, changed + + +def main(): + arg_spec = dict( + path=dict(required=True, aliases=["dest", "destfile"]), + name=dict(required=True, aliases=["username"]), + password=dict(required=False, default=None, no_log=True), + crypt_scheme=dict(required=False, default="apr_md5_crypt"), + state=dict(required=False, default="present"), + create=dict(type='bool', default='yes'), + + ) + module = AnsibleModule(argument_spec=arg_spec, + add_file_common_args=True, + supports_check_mode=True) + + path = module.params['path'] + username = module.params['name'] + password = module.params['password'] + crypt_scheme = module.params['crypt_scheme'] + state = module.params['state'] + create = module.params['create'] + check_mode = module.check_mode + + if not passlib_installed: + module.fail_json(msg=missing_required_lib("passlib"), exception=PASSLIB_IMP_ERR) + + # Check file for blank lines in effort to avoid "need more than 1 value to unpack" error. + try: + f = open(path, "r") + except IOError: + # No preexisting file to remove blank lines from + f = None + else: + try: + lines = f.readlines() + finally: + f.close() + + # If the file gets edited, it returns true, so only edit the file if it has blank lines + strip = False + for line in lines: + if not line.strip(): + strip = True + break + + if strip: + # If check mode, create a temporary file + if check_mode: + temp = tempfile.NamedTemporaryFile() + path = temp.name + f = open(path, "w") + try: + [f.write(line) for line in lines if line.strip()] + finally: + f.close() + + try: + if state == 'present': + (msg, changed) = present(path, username, password, crypt_scheme, create, check_mode) + elif state == 'absent': + if not os.path.exists(path): + module.exit_json(msg="%s not present" % username, + warnings="%s does not exist" % path, changed=False) + (msg, changed) = absent(path, username, check_mode) + else: + module.fail_json(msg="Invalid state: %s" % state) + + check_file_attrs(module, changed, msg) + module.exit_json(msg=msg, changed=changed) + except Exception as e: + module.fail_json(msg=to_native(e)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/pkgng.py b/test/support/integration/plugins/modules/pkgng.py new file mode 100644 index 0000000..1136347 --- /dev/null +++ b/test/support/integration/plugins/modules/pkgng.py @@ -0,0 +1,406 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, bleader +# Written by bleader <bleader@ratonland.org> +# Based on pkgin module written by Shaun Zinck <shaun.zinck at gmail.com> +# that was based on pacman module written by Afterburn <https://github.com/afterburn> +# that was based on apt module written by Matthew Williams <matthew@flowroute.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: pkgng +short_description: Package manager for FreeBSD >= 9.0 +description: + - Manage binary packages for FreeBSD using 'pkgng' which is available in versions after 9.0. +version_added: "1.2" +options: + name: + description: + - Name or list of names of packages to install/remove. + required: true + state: + description: + - State of the package. + - 'Note: "latest" added in 2.7' + choices: [ 'present', 'latest', 'absent' ] + required: false + default: present + cached: + description: + - Use local package base instead of fetching an updated one. + type: bool + required: false + default: no + annotation: + description: + - A comma-separated list of keyvalue-pairs of the form + C(<+/-/:><key>[=<value>]). A C(+) denotes adding an annotation, a + C(-) denotes removing an annotation, and C(:) denotes modifying an + annotation. + If setting or modifying annotations, a value must be provided. + required: false + version_added: "1.6" + pkgsite: + description: + - For pkgng versions before 1.1.4, specify packagesite to use + for downloading packages. If not specified, use settings from + C(/usr/local/etc/pkg.conf). + - For newer pkgng versions, specify a the name of a repository + configured in C(/usr/local/etc/pkg/repos). + required: false + rootdir: + description: + - For pkgng versions 1.5 and later, pkg will install all packages + within the specified root directory. + - Can not be used together with I(chroot) or I(jail) options. + required: false + chroot: + version_added: "2.1" + description: + - Pkg will chroot in the specified environment. + - Can not be used together with I(rootdir) or I(jail) options. + required: false + jail: + version_added: "2.4" + description: + - Pkg will execute in the given jail name or id. + - Can not be used together with I(chroot) or I(rootdir) options. + autoremove: + version_added: "2.2" + description: + - Remove automatically installed packages which are no longer needed. + required: false + type: bool + default: no +author: "bleader (@bleader)" +notes: + - When using pkgsite, be careful that already in cache packages won't be downloaded again. + - When used with a `loop:` each package will be processed individually, + it is much more efficient to pass the list directly to the `name` option. +''' + +EXAMPLES = ''' +- name: Install package foo + pkgng: + name: foo + state: present + +- name: Annotate package foo and bar + pkgng: + name: foo,bar + annotation: '+test1=baz,-test2,:test3=foobar' + +- name: Remove packages foo and bar + pkgng: + name: foo,bar + state: absent + +# "latest" support added in 2.7 +- name: Upgrade package baz + pkgng: + name: baz + state: latest +''' + + +import re +from ansible.module_utils.basic import AnsibleModule + + +def query_package(module, pkgng_path, name, dir_arg): + + rc, out, err = module.run_command("%s %s info -g -e %s" % (pkgng_path, dir_arg, name)) + + if rc == 0: + return True + + return False + + +def query_update(module, pkgng_path, name, dir_arg, old_pkgng, pkgsite): + + # Check to see if a package upgrade is available. + # rc = 0, no updates available or package not installed + # rc = 1, updates available + if old_pkgng: + rc, out, err = module.run_command("%s %s upgrade -g -n %s" % (pkgsite, pkgng_path, name)) + else: + rc, out, err = module.run_command("%s %s upgrade %s -g -n %s" % (pkgng_path, dir_arg, pkgsite, name)) + + if rc == 1: + return True + + return False + + +def pkgng_older_than(module, pkgng_path, compare_version): + + rc, out, err = module.run_command("%s -v" % pkgng_path) + version = [int(x) for x in re.split(r'[\._]', out)] + + i = 0 + new_pkgng = True + while compare_version[i] == version[i]: + i += 1 + if i == min(len(compare_version), len(version)): + break + else: + if compare_version[i] > version[i]: + new_pkgng = False + return not new_pkgng + + +def remove_packages(module, pkgng_path, packages, dir_arg): + + remove_c = 0 + # Using a for loop in case of error, we can report the package that failed + for package in packages: + # Query the package first, to see if we even need to remove + if not query_package(module, pkgng_path, package, dir_arg): + continue + + if not module.check_mode: + rc, out, err = module.run_command("%s %s delete -y %s" % (pkgng_path, dir_arg, package)) + + if not module.check_mode and query_package(module, pkgng_path, package, dir_arg): + module.fail_json(msg="failed to remove %s: %s" % (package, out)) + + remove_c += 1 + + if remove_c > 0: + + return (True, "removed %s package(s)" % remove_c) + + return (False, "package(s) already absent") + + +def install_packages(module, pkgng_path, packages, cached, pkgsite, dir_arg, state): + + install_c = 0 + + # as of pkg-1.1.4, PACKAGESITE is deprecated in favor of repository definitions + # in /usr/local/etc/pkg/repos + old_pkgng = pkgng_older_than(module, pkgng_path, [1, 1, 4]) + if pkgsite != "": + if old_pkgng: + pkgsite = "PACKAGESITE=%s" % (pkgsite) + else: + pkgsite = "-r %s" % (pkgsite) + + # This environment variable skips mid-install prompts, + # setting them to their default values. + batch_var = 'env BATCH=yes' + + if not module.check_mode and not cached: + if old_pkgng: + rc, out, err = module.run_command("%s %s update" % (pkgsite, pkgng_path)) + else: + rc, out, err = module.run_command("%s %s update" % (pkgng_path, dir_arg)) + if rc != 0: + module.fail_json(msg="Could not update catalogue [%d]: %s %s" % (rc, out, err)) + + for package in packages: + already_installed = query_package(module, pkgng_path, package, dir_arg) + if already_installed and state == "present": + continue + + update_available = query_update(module, pkgng_path, package, dir_arg, old_pkgng, pkgsite) + if not update_available and already_installed and state == "latest": + continue + + if not module.check_mode: + if already_installed: + action = "upgrade" + else: + action = "install" + if old_pkgng: + rc, out, err = module.run_command("%s %s %s %s -g -U -y %s" % (batch_var, pkgsite, pkgng_path, action, package)) + else: + rc, out, err = module.run_command("%s %s %s %s %s -g -U -y %s" % (batch_var, pkgng_path, dir_arg, action, pkgsite, package)) + + if not module.check_mode and not query_package(module, pkgng_path, package, dir_arg): + module.fail_json(msg="failed to %s %s: %s" % (action, package, out), stderr=err) + + install_c += 1 + + if install_c > 0: + return (True, "added %s package(s)" % (install_c)) + + return (False, "package(s) already %s" % (state)) + + +def annotation_query(module, pkgng_path, package, tag, dir_arg): + rc, out, err = module.run_command("%s %s info -g -A %s" % (pkgng_path, dir_arg, package)) + match = re.search(r'^\s*(?P<tag>%s)\s*:\s*(?P<value>\w+)' % tag, out, flags=re.MULTILINE) + if match: + return match.group('value') + return False + + +def annotation_add(module, pkgng_path, package, tag, value, dir_arg): + _value = annotation_query(module, pkgng_path, package, tag, dir_arg) + if not _value: + # Annotation does not exist, add it. + rc, out, err = module.run_command('%s %s annotate -y -A %s %s "%s"' + % (pkgng_path, dir_arg, package, tag, value)) + if rc != 0: + module.fail_json(msg="could not annotate %s: %s" + % (package, out), stderr=err) + return True + elif _value != value: + # Annotation exists, but value differs + module.fail_json( + mgs="failed to annotate %s, because %s is already set to %s, but should be set to %s" + % (package, tag, _value, value)) + return False + else: + # Annotation exists, nothing to do + return False + + +def annotation_delete(module, pkgng_path, package, tag, value, dir_arg): + _value = annotation_query(module, pkgng_path, package, tag, dir_arg) + if _value: + rc, out, err = module.run_command('%s %s annotate -y -D %s %s' + % (pkgng_path, dir_arg, package, tag)) + if rc != 0: + module.fail_json(msg="could not delete annotation to %s: %s" + % (package, out), stderr=err) + return True + return False + + +def annotation_modify(module, pkgng_path, package, tag, value, dir_arg): + _value = annotation_query(module, pkgng_path, package, tag, dir_arg) + if not value: + # No such tag + module.fail_json(msg="could not change annotation to %s: tag %s does not exist" + % (package, tag)) + elif _value == value: + # No change in value + return False + else: + rc, out, err = module.run_command('%s %s annotate -y -M %s %s "%s"' + % (pkgng_path, dir_arg, package, tag, value)) + if rc != 0: + module.fail_json(msg="could not change annotation annotation to %s: %s" + % (package, out), stderr=err) + return True + + +def annotate_packages(module, pkgng_path, packages, annotation, dir_arg): + annotate_c = 0 + annotations = map(lambda _annotation: + re.match(r'(?P<operation>[\+-:])(?P<tag>\w+)(=(?P<value>\w+))?', + _annotation).groupdict(), + re.split(r',', annotation)) + + operation = { + '+': annotation_add, + '-': annotation_delete, + ':': annotation_modify + } + + for package in packages: + for _annotation in annotations: + if operation[_annotation['operation']](module, pkgng_path, package, _annotation['tag'], _annotation['value']): + annotate_c += 1 + + if annotate_c > 0: + return (True, "added %s annotations." % annotate_c) + return (False, "changed no annotations") + + +def autoremove_packages(module, pkgng_path, dir_arg): + rc, out, err = module.run_command("%s %s autoremove -n" % (pkgng_path, dir_arg)) + + autoremove_c = 0 + + match = re.search('^Deinstallation has been requested for the following ([0-9]+) packages', out, re.MULTILINE) + if match: + autoremove_c = int(match.group(1)) + + if autoremove_c == 0: + return False, "no package(s) to autoremove" + + if not module.check_mode: + rc, out, err = module.run_command("%s %s autoremove -y" % (pkgng_path, dir_arg)) + + return True, "autoremoved %d package(s)" % (autoremove_c) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + state=dict(default="present", choices=["present", "latest", "absent"], required=False), + name=dict(aliases=["pkg"], required=True, type='list'), + cached=dict(default=False, type='bool'), + annotation=dict(default="", required=False), + pkgsite=dict(default="", required=False), + rootdir=dict(default="", required=False, type='path'), + chroot=dict(default="", required=False, type='path'), + jail=dict(default="", required=False, type='str'), + autoremove=dict(default=False, type='bool')), + supports_check_mode=True, + mutually_exclusive=[["rootdir", "chroot", "jail"]]) + + pkgng_path = module.get_bin_path('pkg', True) + + p = module.params + + pkgs = p["name"] + + changed = False + msgs = [] + dir_arg = "" + + if p["rootdir"] != "": + old_pkgng = pkgng_older_than(module, pkgng_path, [1, 5, 0]) + if old_pkgng: + module.fail_json(msg="To use option 'rootdir' pkg version must be 1.5 or greater") + else: + dir_arg = "--rootdir %s" % (p["rootdir"]) + + if p["chroot"] != "": + dir_arg = '--chroot %s' % (p["chroot"]) + + if p["jail"] != "": + dir_arg = '--jail %s' % (p["jail"]) + + if p["state"] in ("present", "latest"): + _changed, _msg = install_packages(module, pkgng_path, pkgs, p["cached"], p["pkgsite"], dir_arg, p["state"]) + changed = changed or _changed + msgs.append(_msg) + + elif p["state"] == "absent": + _changed, _msg = remove_packages(module, pkgng_path, pkgs, dir_arg) + changed = changed or _changed + msgs.append(_msg) + + if p["autoremove"]: + _changed, _msg = autoremove_packages(module, pkgng_path, dir_arg) + changed = changed or _changed + msgs.append(_msg) + + if p["annotation"]: + _changed, _msg = annotate_packages(module, pkgng_path, pkgs, p["annotation"], dir_arg) + changed = changed or _changed + msgs.append(_msg) + + module.exit_json(changed=changed, msg=", ".join(msgs)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/sefcontext.py b/test/support/integration/plugins/modules/sefcontext.py new file mode 100644 index 0000000..5574abc --- /dev/null +++ b/test/support/integration/plugins/modules/sefcontext.py @@ -0,0 +1,310 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: sefcontext +short_description: Manages SELinux file context mapping definitions +description: +- Manages SELinux file context mapping definitions. +- Similar to the C(semanage fcontext) command. +version_added: '2.2' +options: + target: + description: + - Target path (expression). + type: str + required: yes + aliases: [ path ] + ftype: + description: + - The file type that should have SELinux contexts applied. + - "The following file type options are available:" + - C(a) for all files, + - C(b) for block devices, + - C(c) for character devices, + - C(d) for directories, + - C(f) for regular files, + - C(l) for symbolic links, + - C(p) for named pipes, + - C(s) for socket files. + type: str + choices: [ a, b, c, d, f, l, p, s ] + default: a + setype: + description: + - SELinux type for the specified target. + type: str + required: yes + seuser: + description: + - SELinux user for the specified target. + type: str + selevel: + description: + - SELinux range for the specified target. + type: str + aliases: [ serange ] + state: + description: + - Whether the SELinux file context must be C(absent) or C(present). + type: str + choices: [ absent, present ] + default: present + reload: + description: + - Reload SELinux policy after commit. + - Note that this does not apply SELinux file contexts to existing files. + type: bool + default: yes + ignore_selinux_state: + description: + - Useful for scenarios (chrooted environment) that you can't get the real SELinux state. + type: bool + default: no + version_added: '2.8' +notes: +- The changes are persistent across reboots. +- The M(sefcontext) module does not modify existing files to the new + SELinux context(s), so it is advisable to first create the SELinux + file contexts before creating files, or run C(restorecon) manually + for the existing files that require the new SELinux file contexts. +- Not applying SELinux fcontexts to existing files is a deliberate + decision as it would be unclear what reported changes would entail + to, and there's no guarantee that applying SELinux fcontext does + not pick up other unrelated prior changes. +requirements: +- libselinux-python +- policycoreutils-python +author: +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r''' +- name: Allow apache to modify files in /srv/git_repos + sefcontext: + target: '/srv/git_repos(/.*)?' + setype: httpd_git_rw_content_t + state: present + +- name: Apply new SELinux file context to filesystem + command: restorecon -irv /srv/git_repos +''' + +RETURN = r''' +# Default return values +''' + +import os +import subprocess +import traceback + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.common.respawn import has_respawned, probe_interpreters_for_module, respawn_module +from ansible.module_utils._text import to_native + +SELINUX_IMP_ERR = None +try: + import selinux + HAVE_SELINUX = True +except ImportError: + SELINUX_IMP_ERR = traceback.format_exc() + HAVE_SELINUX = False + +SEOBJECT_IMP_ERR = None +try: + import seobject + HAVE_SEOBJECT = True +except ImportError: + SEOBJECT_IMP_ERR = traceback.format_exc() + HAVE_SEOBJECT = False + +# Add missing entries (backward compatible) +if HAVE_SEOBJECT: + seobject.file_types.update( + a=seobject.SEMANAGE_FCONTEXT_ALL, + b=seobject.SEMANAGE_FCONTEXT_BLOCK, + c=seobject.SEMANAGE_FCONTEXT_CHAR, + d=seobject.SEMANAGE_FCONTEXT_DIR, + f=seobject.SEMANAGE_FCONTEXT_REG, + l=seobject.SEMANAGE_FCONTEXT_LINK, + p=seobject.SEMANAGE_FCONTEXT_PIPE, + s=seobject.SEMANAGE_FCONTEXT_SOCK, + ) + +# Make backward compatible +option_to_file_type_str = dict( + a='all files', + b='block device', + c='character device', + d='directory', + f='regular file', + l='symbolic link', + p='named pipe', + s='socket', +) + + +def get_runtime_status(ignore_selinux_state=False): + return True if ignore_selinux_state is True else selinux.is_selinux_enabled() + + +def semanage_fcontext_exists(sefcontext, target, ftype): + ''' Get the SELinux file context mapping definition from policy. Return None if it does not exist. ''' + + # Beware that records comprise of a string representation of the file_type + record = (target, option_to_file_type_str[ftype]) + records = sefcontext.get_all() + try: + return records[record] + except KeyError: + return None + + +def semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser, sestore=''): + ''' Add or modify SELinux file context mapping definition to the policy. ''' + + changed = False + prepared_diff = '' + + try: + sefcontext = seobject.fcontextRecords(sestore) + sefcontext.set_reload(do_reload) + exists = semanage_fcontext_exists(sefcontext, target, ftype) + if exists: + # Modify existing entry + orig_seuser, orig_serole, orig_setype, orig_serange = exists + + if seuser is None: + seuser = orig_seuser + if serange is None: + serange = orig_serange + + if setype != orig_setype or seuser != orig_seuser or serange != orig_serange: + if not module.check_mode: + sefcontext.modify(target, setype, ftype, serange, seuser) + changed = True + + if module._diff: + prepared_diff += '# Change to semanage file context mappings\n' + prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, orig_seuser, orig_serole, orig_setype, orig_serange) + prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, orig_serole, setype, serange) + else: + # Add missing entry + if seuser is None: + seuser = 'system_u' + if serange is None: + serange = 's0' + + if not module.check_mode: + sefcontext.add(target, setype, ftype, serange, seuser) + changed = True + + if module._diff: + prepared_diff += '# Addition to semanage file context mappings\n' + prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, 'object_r', setype, serange) + + except Exception as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e))) + + if module._diff and prepared_diff: + result['diff'] = dict(prepared=prepared_diff) + + module.exit_json(changed=changed, seuser=seuser, serange=serange, **result) + + +def semanage_fcontext_delete(module, result, target, ftype, do_reload, sestore=''): + ''' Delete SELinux file context mapping definition from the policy. ''' + + changed = False + prepared_diff = '' + + try: + sefcontext = seobject.fcontextRecords(sestore) + sefcontext.set_reload(do_reload) + exists = semanage_fcontext_exists(sefcontext, target, ftype) + if exists: + # Remove existing entry + orig_seuser, orig_serole, orig_setype, orig_serange = exists + + if not module.check_mode: + sefcontext.delete(target, ftype) + changed = True + + if module._diff: + prepared_diff += '# Deletion to semanage file context mappings\n' + prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, exists[0], exists[1], exists[2], exists[3]) + + except Exception as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e))) + + if module._diff and prepared_diff: + result['diff'] = dict(prepared=prepared_diff) + + module.exit_json(changed=changed, **result) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + ignore_selinux_state=dict(type='bool', default=False), + target=dict(type='str', required=True, aliases=['path']), + ftype=dict(type='str', default='a', choices=option_to_file_type_str.keys()), + setype=dict(type='str', required=True), + seuser=dict(type='str'), + selevel=dict(type='str', aliases=['serange']), + state=dict(type='str', default='present', choices=['absent', 'present']), + reload=dict(type='bool', default=True), + ), + supports_check_mode=True, + ) + + if not HAVE_SELINUX or not HAVE_SEOBJECT and not has_respawned(): + system_interpreters = [ + '/usr/libexec/platform-python', + '/usr/bin/python3', + '/usr/bin/python2', + ] + # policycoreutils-python depends on libselinux-python + interpreter = probe_interpreters_for_module(system_interpreters, 'seobject') + if interpreter: + respawn_module(interpreter) + + if not HAVE_SELINUX or not HAVE_SEOBJECT: + module.fail_json(msg=missing_required_lib("policycoreutils-python(3)"), exception=SELINUX_IMP_ERR) + + ignore_selinux_state = module.params['ignore_selinux_state'] + + if not get_runtime_status(ignore_selinux_state): + module.fail_json(msg="SELinux is disabled on this host.") + + target = module.params['target'] + ftype = module.params['ftype'] + setype = module.params['setype'] + seuser = module.params['seuser'] + serange = module.params['selevel'] + state = module.params['state'] + do_reload = module.params['reload'] + + result = dict(target=target, ftype=ftype, setype=setype, state=state) + + if state == 'present': + semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser) + elif state == 'absent': + semanage_fcontext_delete(module, result, target, ftype, do_reload) + else: + module.fail_json(msg='Invalid value of argument "state": {0}'.format(state)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/timezone.py b/test/support/integration/plugins/modules/timezone.py new file mode 100644 index 0000000..b7439a1 --- /dev/null +++ b/test/support/integration/plugins/modules/timezone.py @@ -0,0 +1,909 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Shinichi TAMURA (@tmshn) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: timezone +short_description: Configure timezone setting +description: + - This module configures the timezone setting, both of the system clock and of the hardware clock. If you want to set up the NTP, use M(service) module. + - It is recommended to restart C(crond) after changing the timezone, otherwise the jobs may run at the wrong time. + - Several different tools are used depending on the OS/Distribution involved. + For Linux it can use C(timedatectl) or edit C(/etc/sysconfig/clock) or C(/etc/timezone) and C(hwclock). + On SmartOS, C(sm-set-timezone), for macOS, C(systemsetup), for BSD, C(/etc/localtime) is modified. + On AIX, C(chtz) is used. + - As of Ansible 2.3 support was added for SmartOS and BSDs. + - As of Ansible 2.4 support was added for macOS. + - As of Ansible 2.9 support was added for AIX 6.1+ + - Windows and HPUX are not supported, please let us know if you find any other OS/distro in which this fails. +version_added: "2.2" +options: + name: + description: + - Name of the timezone for the system clock. + - Default is to keep current setting. + - B(At least one of name and hwclock are required.) + type: str + hwclock: + description: + - Whether the hardware clock is in UTC or in local timezone. + - Default is to keep current setting. + - Note that this option is recommended not to change and may fail + to configure, especially on virtual environments such as AWS. + - B(At least one of name and hwclock are required.) + - I(Only used on Linux.) + type: str + aliases: [ rtc ] + choices: [ local, UTC ] +notes: + - On SmartOS the C(sm-set-timezone) utility (part of the smtools package) is required to set the zone timezone + - On AIX only Olson/tz database timezones are useable (POSIX is not supported). + - An OS reboot is also required on AIX for the new timezone setting to take effect. +author: + - Shinichi TAMURA (@tmshn) + - Jasper Lievisse Adriaanse (@jasperla) + - Indrajit Raychaudhuri (@indrajitr) +''' + +RETURN = r''' +diff: + description: The differences about the given arguments. + returned: success + type: complex + contains: + before: + description: The values before change + type: dict + after: + description: The values after change + type: dict +''' + +EXAMPLES = r''' +- name: Set timezone to Asia/Tokyo + timezone: + name: Asia/Tokyo +''' + +import errno +import os +import platform +import random +import re +import string +import filecmp + +from ansible.module_utils.basic import AnsibleModule, get_distribution +from ansible.module_utils.six import iteritems + + +class Timezone(object): + """This is a generic Timezone manipulation class that is subclassed based on platform. + + A subclass may wish to override the following action methods: + - get(key, phase) ... get the value from the system at `phase` + - set(key, value) ... set the value to the current system + """ + + def __new__(cls, module): + """Return the platform-specific subclass. + + It does not use load_platform_subclass() because it needs to judge based + on whether the `timedatectl` command exists and is available. + + Args: + module: The AnsibleModule. + """ + if platform.system() == 'Linux': + timedatectl = module.get_bin_path('timedatectl') + if timedatectl is not None: + rc, stdout, stderr = module.run_command(timedatectl) + if rc == 0: + return super(Timezone, SystemdTimezone).__new__(SystemdTimezone) + else: + module.warn('timedatectl command was found but not usable: %s. using other method.' % stderr) + return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone) + else: + return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone) + elif re.match('^joyent_.*Z', platform.version()): + # platform.system() returns SunOS, which is too broad. So look at the + # platform version instead. However we have to ensure that we're not + # running in the global zone where changing the timezone has no effect. + zonename_cmd = module.get_bin_path('zonename') + if zonename_cmd is not None: + (rc, stdout, _) = module.run_command(zonename_cmd) + if rc == 0 and stdout.strip() == 'global': + module.fail_json(msg='Adjusting timezone is not supported in Global Zone') + + return super(Timezone, SmartOSTimezone).__new__(SmartOSTimezone) + elif platform.system() == 'Darwin': + return super(Timezone, DarwinTimezone).__new__(DarwinTimezone) + elif re.match('^(Free|Net|Open)BSD', platform.platform()): + return super(Timezone, BSDTimezone).__new__(BSDTimezone) + elif platform.system() == 'AIX': + AIXoslevel = int(platform.version() + platform.release()) + if AIXoslevel >= 61: + return super(Timezone, AIXTimezone).__new__(AIXTimezone) + else: + module.fail_json(msg='AIX os level must be >= 61 for timezone module (Target: %s).' % AIXoslevel) + else: + # Not supported yet + return super(Timezone, Timezone).__new__(Timezone) + + def __init__(self, module): + """Initialize of the class. + + Args: + module: The AnsibleModule. + """ + super(Timezone, self).__init__() + self.msg = [] + # `self.value` holds the values for each params on each phases. + # Initially there's only info of "planned" phase, but the + # `self.check()` function will fill out it. + self.value = dict() + for key in module.argument_spec: + value = module.params[key] + if value is not None: + self.value[key] = dict(planned=value) + self.module = module + + def abort(self, msg): + """Abort the process with error message. + + This is just the wrapper of module.fail_json(). + + Args: + msg: The error message. + """ + error_msg = ['Error message:', msg] + if len(self.msg) > 0: + error_msg.append('Other message(s):') + error_msg.extend(self.msg) + self.module.fail_json(msg='\n'.join(error_msg)) + + def execute(self, *commands, **kwargs): + """Execute the shell command. + + This is just the wrapper of module.run_command(). + + Args: + *commands: The command to execute. + It will be concatenated with single space. + **kwargs: Only 'log' key is checked. + If kwargs['log'] is true, record the command to self.msg. + + Returns: + stdout: Standard output of the command. + """ + command = ' '.join(commands) + (rc, stdout, stderr) = self.module.run_command(command, check_rc=True) + if kwargs.get('log', False): + self.msg.append('executed `%s`' % command) + return stdout + + def diff(self, phase1='before', phase2='after'): + """Calculate the difference between given 2 phases. + + Args: + phase1, phase2: The names of phase to compare. + + Returns: + diff: The difference of value between phase1 and phase2. + This is in the format which can be used with the + `--diff` option of ansible-playbook. + """ + diff = {phase1: {}, phase2: {}} + for key, value in iteritems(self.value): + diff[phase1][key] = value[phase1] + diff[phase2][key] = value[phase2] + return diff + + def check(self, phase): + """Check the state in given phase and set it to `self.value`. + + Args: + phase: The name of the phase to check. + + Returns: + NO RETURN VALUE + """ + if phase == 'planned': + return + for key, value in iteritems(self.value): + value[phase] = self.get(key, phase) + + def change(self): + """Make the changes effect based on `self.value`.""" + for key, value in iteritems(self.value): + if value['before'] != value['planned']: + self.set(key, value['planned']) + + # =========================================== + # Platform specific methods (must be replaced by subclass). + + def get(self, key, phase): + """Get the value for the key at the given phase. + + Called from self.check(). + + Args: + key: The key to get the value + phase: The phase to get the value + + Return: + value: The value for the key at the given phase. + """ + self.abort('get(key, phase) is not implemented on target platform') + + def set(self, key, value): + """Set the value for the key (of course, for the phase 'after'). + + Called from self.change(). + + Args: + key: Key to set the value + value: Value to set + """ + self.abort('set(key, value) is not implemented on target platform') + + def _verify_timezone(self): + tz = self.value['name']['planned'] + tzfile = '/usr/share/zoneinfo/%s' % tz + if not os.path.isfile(tzfile): + self.abort('given timezone "%s" is not available' % tz) + return tzfile + + +class SystemdTimezone(Timezone): + """This is a Timezone manipulation class for systemd-powered Linux. + + It uses the `timedatectl` command to check/set all arguments. + """ + + regexps = dict( + hwclock=re.compile(r'^\s*RTC in local TZ\s*:\s*([^\s]+)', re.MULTILINE), + name=re.compile(r'^\s*Time ?zone\s*:\s*([^\s]+)', re.MULTILINE) + ) + + subcmds = dict( + hwclock='set-local-rtc', + name='set-timezone' + ) + + def __init__(self, module): + super(SystemdTimezone, self).__init__(module) + self.timedatectl = module.get_bin_path('timedatectl', required=True) + self.status = dict() + # Validate given timezone + if 'name' in self.value: + self._verify_timezone() + + def _get_status(self, phase): + if phase not in self.status: + self.status[phase] = self.execute(self.timedatectl, 'status') + return self.status[phase] + + def get(self, key, phase): + status = self._get_status(phase) + value = self.regexps[key].search(status).group(1) + if key == 'hwclock': + # For key='hwclock'; convert yes/no -> local/UTC + if self.module.boolean(value): + value = 'local' + else: + value = 'UTC' + return value + + def set(self, key, value): + # For key='hwclock'; convert UTC/local -> yes/no + if key == 'hwclock': + if value == 'local': + value = 'yes' + else: + value = 'no' + self.execute(self.timedatectl, self.subcmds[key], value, log=True) + + +class NosystemdTimezone(Timezone): + """This is a Timezone manipulation class for non systemd-powered Linux. + + For timezone setting, it edits the following file and reflect changes: + - /etc/sysconfig/clock ... RHEL/CentOS + - /etc/timezone ... Debian/Ubuntu + For hwclock setting, it executes `hwclock --systohc` command with the + '--utc' or '--localtime' option. + """ + + conf_files = dict( + name=None, # To be set in __init__ + hwclock=None, # To be set in __init__ + adjtime='/etc/adjtime' + ) + + # It's fine if all tree config files don't exist + allow_no_file = dict( + name=True, + hwclock=True, + adjtime=True + ) + + regexps = dict( + name=None, # To be set in __init__ + hwclock=re.compile(r'^UTC\s*=\s*([^\s]+)', re.MULTILINE), + adjtime=re.compile(r'^(UTC|LOCAL)$', re.MULTILINE) + ) + + dist_regexps = dict( + SuSE=re.compile(r'^TIMEZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE), + redhat=re.compile(r'^ZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE) + ) + + dist_tzline_format = dict( + SuSE='TIMEZONE="%s"\n', + redhat='ZONE="%s"\n' + ) + + def __init__(self, module): + super(NosystemdTimezone, self).__init__(module) + # Validate given timezone + if 'name' in self.value: + tzfile = self._verify_timezone() + # `--remove-destination` is needed if /etc/localtime is a symlink so + # that it overwrites it instead of following it. + self.update_timezone = ['%s --remove-destination %s /etc/localtime' % (self.module.get_bin_path('cp', required=True), tzfile)] + self.update_hwclock = self.module.get_bin_path('hwclock', required=True) + # Distribution-specific configurations + if self.module.get_bin_path('dpkg-reconfigure') is not None: + # Debian/Ubuntu + if 'name' in self.value: + self.update_timezone = ['%s -sf %s /etc/localtime' % (self.module.get_bin_path('ln', required=True), tzfile), + '%s --frontend noninteractive tzdata' % self.module.get_bin_path('dpkg-reconfigure', required=True)] + self.conf_files['name'] = '/etc/timezone' + self.conf_files['hwclock'] = '/etc/default/rcS' + self.regexps['name'] = re.compile(r'^([^\s]+)', re.MULTILINE) + self.tzline_format = '%s\n' + else: + # RHEL/CentOS/SUSE + if self.module.get_bin_path('tzdata-update') is not None: + # tzdata-update cannot update the timezone if /etc/localtime is + # a symlink so we have to use cp to update the time zone which + # was set above. + if not os.path.islink('/etc/localtime'): + self.update_timezone = [self.module.get_bin_path('tzdata-update', required=True)] + # else: + # self.update_timezone = 'cp --remove-destination ...' <- configured above + self.conf_files['name'] = '/etc/sysconfig/clock' + self.conf_files['hwclock'] = '/etc/sysconfig/clock' + try: + f = open(self.conf_files['name'], 'r') + except IOError as err: + if self._allow_ioerror(err, 'name'): + # If the config file doesn't exist detect the distribution and set regexps. + distribution = get_distribution() + if distribution == 'SuSE': + # For SUSE + self.regexps['name'] = self.dist_regexps['SuSE'] + self.tzline_format = self.dist_tzline_format['SuSE'] + else: + # For RHEL/CentOS + self.regexps['name'] = self.dist_regexps['redhat'] + self.tzline_format = self.dist_tzline_format['redhat'] + else: + self.abort('could not read configuration file "%s"' % self.conf_files['name']) + else: + # The key for timezone might be `ZONE` or `TIMEZONE` + # (the former is used in RHEL/CentOS and the latter is used in SUSE linux). + # So check the content of /etc/sysconfig/clock and decide which key to use. + sysconfig_clock = f.read() + f.close() + if re.search(r'^TIMEZONE\s*=', sysconfig_clock, re.MULTILINE): + # For SUSE + self.regexps['name'] = self.dist_regexps['SuSE'] + self.tzline_format = self.dist_tzline_format['SuSE'] + else: + # For RHEL/CentOS + self.regexps['name'] = self.dist_regexps['redhat'] + self.tzline_format = self.dist_tzline_format['redhat'] + + def _allow_ioerror(self, err, key): + # In some cases, even if the target file does not exist, + # simply creating it may solve the problem. + # In such cases, we should continue the configuration rather than aborting. + if err.errno != errno.ENOENT: + # If the error is not ENOENT ("No such file or directory"), + # (e.g., permission error, etc), we should abort. + return False + return self.allow_no_file.get(key, False) + + def _edit_file(self, filename, regexp, value, key): + """Replace the first matched line with given `value`. + + If `regexp` matched more than once, other than the first line will be deleted. + + Args: + filename: The name of the file to edit. + regexp: The regular expression to search with. + value: The line which will be inserted. + key: For what key the file is being editted. + """ + # Read the file + try: + file = open(filename, 'r') + except IOError as err: + if self._allow_ioerror(err, key): + lines = [] + else: + self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename)) + else: + lines = file.readlines() + file.close() + # Find the all matched lines + matched_indices = [] + for i, line in enumerate(lines): + if regexp.search(line): + matched_indices.append(i) + if len(matched_indices) > 0: + insert_line = matched_indices[0] + else: + insert_line = 0 + # Remove all matched lines + for i in matched_indices[::-1]: + del lines[i] + # ...and insert the value + lines.insert(insert_line, value) + # Write the changes + try: + file = open(filename, 'w') + except IOError: + self.abort('tried to configure %s using a file "%s", but could not write to it' % (key, filename)) + else: + file.writelines(lines) + file.close() + self.msg.append('Added 1 line and deleted %s line(s) on %s' % (len(matched_indices), filename)) + + def _get_value_from_config(self, key, phase): + filename = self.conf_files[key] + try: + file = open(filename, mode='r') + except IOError as err: + if self._allow_ioerror(err, key): + if key == 'hwclock': + return 'n/a' + elif key == 'adjtime': + return 'UTC' + elif key == 'name': + return 'n/a' + else: + self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename)) + else: + status = file.read() + file.close() + try: + value = self.regexps[key].search(status).group(1) + except AttributeError: + if key == 'hwclock': + # If we cannot find UTC in the config that's fine. + return 'n/a' + elif key == 'adjtime': + # If we cannot find UTC/LOCAL in /etc/cannot that means UTC + # will be used by default. + return 'UTC' + elif key == 'name': + if phase == 'before': + # In 'before' phase UTC/LOCAL doesn't need to be set in + # the timezone config file, so we ignore this error. + return 'n/a' + else: + self.abort('tried to configure %s using a file "%s", but could not find a valid value in it' % (key, filename)) + else: + if key == 'hwclock': + # convert yes/no -> UTC/local + if self.module.boolean(value): + value = 'UTC' + else: + value = 'local' + elif key == 'adjtime': + # convert LOCAL -> local + if value != 'UTC': + value = value.lower() + return value + + def get(self, key, phase): + planned = self.value[key]['planned'] + if key == 'hwclock': + value = self._get_value_from_config(key, phase) + if value == planned: + # If the value in the config file is the same as the 'planned' + # value, we need to check /etc/adjtime. + value = self._get_value_from_config('adjtime', phase) + elif key == 'name': + value = self._get_value_from_config(key, phase) + if value == planned: + # If the planned values is the same as the one in the config file + # we need to check if /etc/localtime is also set to the 'planned' zone. + if os.path.islink('/etc/localtime'): + # If /etc/localtime is a symlink and is not set to the TZ we 'planned' + # to set, we need to return the TZ which the symlink points to. + if os.path.exists('/etc/localtime'): + # We use readlink() because on some distros zone files are symlinks + # to other zone files, so it's hard to get which TZ is actually set + # if we follow the symlink. + path = os.readlink('/etc/localtime') + linktz = re.search(r'/usr/share/zoneinfo/(.*)', path, re.MULTILINE) + if linktz: + valuelink = linktz.group(1) + if valuelink != planned: + value = valuelink + else: + # Set current TZ to 'n/a' if the symlink points to a path + # which isn't a zone file. + value = 'n/a' + else: + # Set current TZ to 'n/a' if the symlink to the zone file is broken. + value = 'n/a' + else: + # If /etc/localtime is not a symlink best we can do is compare it with + # the 'planned' zone info file and return 'n/a' if they are different. + try: + if not filecmp.cmp('/etc/localtime', '/usr/share/zoneinfo/' + planned): + return 'n/a' + except Exception: + return 'n/a' + else: + self.abort('unknown parameter "%s"' % key) + return value + + def set_timezone(self, value): + self._edit_file(filename=self.conf_files['name'], + regexp=self.regexps['name'], + value=self.tzline_format % value, + key='name') + for cmd in self.update_timezone: + self.execute(cmd) + + def set_hwclock(self, value): + if value == 'local': + option = '--localtime' + utc = 'no' + else: + option = '--utc' + utc = 'yes' + if self.conf_files['hwclock'] is not None: + self._edit_file(filename=self.conf_files['hwclock'], + regexp=self.regexps['hwclock'], + value='UTC=%s\n' % utc, + key='hwclock') + self.execute(self.update_hwclock, '--systohc', option, log=True) + + def set(self, key, value): + if key == 'name': + self.set_timezone(value) + elif key == 'hwclock': + self.set_hwclock(value) + else: + self.abort('unknown parameter "%s"' % key) + + +class SmartOSTimezone(Timezone): + """This is a Timezone manipulation class for SmartOS instances. + + It uses the C(sm-set-timezone) utility to set the timezone, and + inspects C(/etc/default/init) to determine the current timezone. + + NB: A zone needs to be rebooted in order for the change to be + activated. + """ + + def __init__(self, module): + super(SmartOSTimezone, self).__init__(module) + self.settimezone = self.module.get_bin_path('sm-set-timezone', required=False) + if not self.settimezone: + module.fail_json(msg='sm-set-timezone not found. Make sure the smtools package is installed.') + + def get(self, key, phase): + """Lookup the current timezone name in `/etc/default/init`. If anything else + is requested, or if the TZ field is not set we fail. + """ + if key == 'name': + try: + f = open('/etc/default/init', 'r') + for line in f: + m = re.match('^TZ=(.*)$', line.strip()) + if m: + return m.groups()[0] + except Exception: + self.module.fail_json(msg='Failed to read /etc/default/init') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + """Set the requested timezone through sm-set-timezone, an invalid timezone name + will be rejected and we have no further input validation to perform. + """ + if key == 'name': + cmd = 'sm-set-timezone %s' % value + + (rc, stdout, stderr) = self.module.run_command(cmd) + + if rc != 0: + self.module.fail_json(msg=stderr) + + # sm-set-timezone knows no state and will always set the timezone. + # XXX: https://github.com/joyent/smtools/pull/2 + m = re.match(r'^\* Changed (to)? timezone (to)? (%s).*' % value, stdout.splitlines()[1]) + if not (m and m.groups()[-1] == value): + self.module.fail_json(msg='Failed to set timezone') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class DarwinTimezone(Timezone): + """This is the timezone implementation for Darwin which, unlike other *BSD + implementations, uses the `systemsetup` command on Darwin to check/set + the timezone. + """ + + regexps = dict( + name=re.compile(r'^\s*Time ?Zone\s*:\s*([^\s]+)', re.MULTILINE) + ) + + def __init__(self, module): + super(DarwinTimezone, self).__init__(module) + self.systemsetup = module.get_bin_path('systemsetup', required=True) + self.status = dict() + # Validate given timezone + if 'name' in self.value: + self._verify_timezone() + + def _get_current_timezone(self, phase): + """Lookup the current timezone via `systemsetup -gettimezone`.""" + if phase not in self.status: + self.status[phase] = self.execute(self.systemsetup, '-gettimezone') + return self.status[phase] + + def _verify_timezone(self): + tz = self.value['name']['planned'] + # Lookup the list of supported timezones via `systemsetup -listtimezones`. + # Note: Skip the first line that contains the label 'Time Zones:' + out = self.execute(self.systemsetup, '-listtimezones').splitlines()[1:] + tz_list = list(map(lambda x: x.strip(), out)) + if tz not in tz_list: + self.abort('given timezone "%s" is not available' % tz) + return tz + + def get(self, key, phase): + if key == 'name': + status = self._get_current_timezone(phase) + value = self.regexps[key].search(status).group(1) + return value + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + if key == 'name': + self.execute(self.systemsetup, '-settimezone', value, log=True) + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class BSDTimezone(Timezone): + """This is the timezone implementation for *BSD which works simply through + updating the `/etc/localtime` symlink to point to a valid timezone name under + `/usr/share/zoneinfo`. + """ + + def __init__(self, module): + super(BSDTimezone, self).__init__(module) + + def __get_timezone(self): + zoneinfo_dir = '/usr/share/zoneinfo/' + localtime_file = '/etc/localtime' + + # Strategy 1: + # If /etc/localtime does not exist, assum the timezone is UTC. + if not os.path.exists(localtime_file): + self.module.warn('Could not read /etc/localtime. Assuming UTC.') + return 'UTC' + + # Strategy 2: + # Follow symlink of /etc/localtime + zoneinfo_file = localtime_file + while not zoneinfo_file.startswith(zoneinfo_dir): + try: + zoneinfo_file = os.readlink(localtime_file) + except OSError: + # OSError means "end of symlink chain" or broken link. + break + else: + return zoneinfo_file.replace(zoneinfo_dir, '') + + # Strategy 3: + # (If /etc/localtime is not symlinked) + # Check all files in /usr/share/zoneinfo and return first non-link match. + for dname, _, fnames in sorted(os.walk(zoneinfo_dir)): + for fname in sorted(fnames): + zoneinfo_file = os.path.join(dname, fname) + if not os.path.islink(zoneinfo_file) and filecmp.cmp(zoneinfo_file, localtime_file): + return zoneinfo_file.replace(zoneinfo_dir, '') + + # Strategy 4: + # As a fall-back, return 'UTC' as default assumption. + self.module.warn('Could not identify timezone name from /etc/localtime. Assuming UTC.') + return 'UTC' + + def get(self, key, phase): + """Lookup the current timezone by resolving `/etc/localtime`.""" + if key == 'name': + return self.__get_timezone() + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + if key == 'name': + # First determine if the requested timezone is valid by looking in + # the zoneinfo directory. + zonefile = '/usr/share/zoneinfo/' + value + try: + if not os.path.isfile(zonefile): + self.module.fail_json(msg='%s is not a recognized timezone' % value) + except Exception: + self.module.fail_json(msg='Failed to stat %s' % zonefile) + + # Now (somewhat) atomically update the symlink by creating a new + # symlink and move it into place. Otherwise we have to remove the + # original symlink and create the new symlink, however that would + # create a race condition in case another process tries to read + # /etc/localtime between removal and creation. + suffix = "".join([random.choice(string.ascii_letters + string.digits) for x in range(0, 10)]) + new_localtime = '/etc/localtime.' + suffix + + try: + os.symlink(zonefile, new_localtime) + os.rename(new_localtime, '/etc/localtime') + except Exception: + os.remove(new_localtime) + self.module.fail_json(msg='Could not update /etc/localtime') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class AIXTimezone(Timezone): + """This is a Timezone manipulation class for AIX instances. + + It uses the C(chtz) utility to set the timezone, and + inspects C(/etc/environment) to determine the current timezone. + + While AIX time zones can be set using two formats (POSIX and + Olson) the prefered method is Olson. + See the following article for more information: + https://developer.ibm.com/articles/au-aix-posix/ + + NB: AIX needs to be rebooted in order for the change to be + activated. + """ + + def __init__(self, module): + super(AIXTimezone, self).__init__(module) + self.settimezone = self.module.get_bin_path('chtz', required=True) + + def __get_timezone(self): + """ Return the current value of TZ= in /etc/environment """ + try: + f = open('/etc/environment', 'r') + etcenvironment = f.read() + f.close() + except Exception: + self.module.fail_json(msg='Issue reading contents of /etc/environment') + + match = re.search(r'^TZ=(.*)$', etcenvironment, re.MULTILINE) + if match: + return match.group(1) + else: + return None + + def get(self, key, phase): + """Lookup the current timezone name in `/etc/environment`. If anything else + is requested, or if the TZ field is not set we fail. + """ + if key == 'name': + return self.__get_timezone() + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + """Set the requested timezone through chtz, an invalid timezone name + will be rejected and we have no further input validation to perform. + """ + if key == 'name': + # chtz seems to always return 0 on AIX 7.2, even for invalid timezone values. + # It will only return non-zero if the chtz command itself fails, it does not check for + # valid timezones. We need to perform a basic check to confirm that the timezone + # definition exists in /usr/share/lib/zoneinfo + # This does mean that we can only support Olson for now. The below commented out regex + # detects Olson date formats, so in the future we could detect Posix or Olson and + # act accordingly. + + # regex_olson = re.compile('^([a-z0-9_\-\+]+\/?)+$', re.IGNORECASE) + # if not regex_olson.match(value): + # msg = 'Supplied timezone (%s) does not appear to a be valid Olson string' % value + # self.module.fail_json(msg=msg) + + # First determine if the requested timezone is valid by looking in the zoneinfo + # directory. + zonefile = '/usr/share/lib/zoneinfo/' + value + try: + if not os.path.isfile(zonefile): + self.module.fail_json(msg='%s is not a recognized timezone.' % value) + except Exception: + self.module.fail_json(msg='Failed to check %s.' % zonefile) + + # Now set the TZ using chtz + cmd = 'chtz %s' % value + (rc, stdout, stderr) = self.module.run_command(cmd) + + if rc != 0: + self.module.fail_json(msg=stderr) + + # The best condition check we can do is to check the value of TZ after making the + # change. + TZ = self.__get_timezone() + if TZ != value: + msg = 'TZ value does not match post-change (Actual: %s, Expected: %s).' % (TZ, value) + self.module.fail_json(msg=msg) + + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +def main(): + # Construct 'module' and 'tz' + module = AnsibleModule( + argument_spec=dict( + hwclock=dict(type='str', choices=['local', 'UTC'], aliases=['rtc']), + name=dict(type='str'), + ), + required_one_of=[ + ['hwclock', 'name'] + ], + supports_check_mode=True, + ) + tz = Timezone(module) + + # Check the current state + tz.check(phase='before') + if module.check_mode: + diff = tz.diff('before', 'planned') + # In check mode, 'planned' state is treated as 'after' state + diff['after'] = diff.pop('planned') + else: + # Make change + tz.change() + # Check the current state + tz.check(phase='after') + # Examine if the current state matches planned state + (after, planned) = tz.diff('after', 'planned').values() + if after != planned: + tz.abort('still not desired state, though changes have made - ' + 'planned: %s, after: %s' % (str(planned), str(after))) + diff = tz.diff('before', 'after') + + changed = (diff['before'] != diff['after']) + if len(tz.msg) > 0: + module.exit_json(changed=changed, diff=diff, msg='\n'.join(tz.msg)) + else: + module.exit_json(changed=changed, diff=diff) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/zypper.py b/test/support/integration/plugins/modules/zypper.py new file mode 100644 index 0000000..bfb3181 --- /dev/null +++ b/test/support/integration/plugins/modules/zypper.py @@ -0,0 +1,540 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, Patrick Callahan <pmc@patrickcallahan.com> +# based on +# openbsd_pkg +# (c) 2013 +# Patrik Lundin <patrik.lundin.swe@gmail.com> +# +# yum +# (c) 2012, Red Hat, Inc +# Written by Seth Vidal <skvidal at fedoraproject.org> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: zypper +author: + - "Patrick Callahan (@dirtyharrycallahan)" + - "Alexander Gubin (@alxgu)" + - "Thomas O'Donnell (@andytom)" + - "Robin Roth (@robinro)" + - "Andrii Radyk (@AnderEnder)" +version_added: "1.2" +short_description: Manage packages on SUSE and openSUSE +description: + - Manage packages on SUSE and openSUSE using the zypper and rpm tools. +options: + name: + description: + - Package name C(name) or package specifier or a list of either. + - Can include a version like C(name=1.0), C(name>3.4) or C(name<=2.7). If a version is given, C(oldpackage) is implied and zypper is allowed to + update the package within the version range given. + - You can also pass a url or a local path to a rpm file. + - When using state=latest, this can be '*', which updates all installed packages. + required: true + aliases: [ 'pkg' ] + state: + description: + - C(present) will make sure the package is installed. + C(latest) will make sure the latest version of the package is installed. + C(absent) will make sure the specified package is not installed. + C(dist-upgrade) will make sure the latest version of all installed packages from all enabled repositories is installed. + - When using C(dist-upgrade), I(name) should be C('*'). + required: false + choices: [ present, latest, absent, dist-upgrade ] + default: "present" + type: + description: + - The type of package to be operated on. + required: false + choices: [ package, patch, pattern, product, srcpackage, application ] + default: "package" + version_added: "2.0" + extra_args_precommand: + version_added: "2.6" + required: false + description: + - Add additional global target options to C(zypper). + - Options should be supplied in a single line as if given in the command line. + disable_gpg_check: + description: + - Whether to disable to GPG signature checking of the package + signature being installed. Has an effect only if state is + I(present) or I(latest). + required: false + default: "no" + type: bool + disable_recommends: + version_added: "1.8" + description: + - Corresponds to the C(--no-recommends) option for I(zypper). Default behavior (C(yes)) modifies zypper's default behavior; C(no) does + install recommended packages. + required: false + default: "yes" + type: bool + force: + version_added: "2.2" + description: + - Adds C(--force) option to I(zypper). Allows to downgrade packages and change vendor or architecture. + required: false + default: "no" + type: bool + force_resolution: + version_added: "2.10" + description: + - Adds C(--force-resolution) option to I(zypper). Allows to (un)install packages with conflicting requirements (resolver will choose a solution). + required: false + default: "no" + type: bool + update_cache: + version_added: "2.2" + description: + - Run the equivalent of C(zypper refresh) before the operation. Disabled in check mode. + required: false + default: "no" + type: bool + aliases: [ "refresh" ] + oldpackage: + version_added: "2.2" + description: + - Adds C(--oldpackage) option to I(zypper). Allows to downgrade packages with less side-effects than force. This is implied as soon as a + version is specified as part of the package name. + required: false + default: "no" + type: bool + extra_args: + version_added: "2.4" + required: false + description: + - Add additional options to C(zypper) command. + - Options should be supplied in a single line as if given in the command line. +notes: + - When used with a `loop:` each package will be processed individually, + it is much more efficient to pass the list directly to the `name` option. +# informational: requirements for nodes +requirements: + - "zypper >= 1.0 # included in openSUSE >= 11.1 or SUSE Linux Enterprise Server/Desktop >= 11.0" + - python-xml + - rpm +''' + +EXAMPLES = ''' +# Install "nmap" +- zypper: + name: nmap + state: present + +# Install apache2 with recommended packages +- zypper: + name: apache2 + state: present + disable_recommends: no + +# Apply a given patch +- zypper: + name: openSUSE-2016-128 + state: present + type: patch + +# Remove the "nmap" package +- zypper: + name: nmap + state: absent + +# Install the nginx rpm from a remote repo +- zypper: + name: 'http://nginx.org/packages/sles/12/x86_64/RPMS/nginx-1.8.0-1.sles12.ngx.x86_64.rpm' + state: present + +# Install local rpm file +- zypper: + name: /tmp/fancy-software.rpm + state: present + +# Update all packages +- zypper: + name: '*' + state: latest + +# Apply all available patches +- zypper: + name: '*' + state: latest + type: patch + +# Perform a dist-upgrade with additional arguments +- zypper: + name: '*' + state: dist-upgrade + extra_args: '--no-allow-vendor-change --allow-arch-change' + +# Refresh repositories and update package "openssl" +- zypper: + name: openssl + state: present + update_cache: yes + +# Install specific version (possible comparisons: <, >, <=, >=, =) +- zypper: + name: 'docker>=1.10' + state: present + +# Wait 20 seconds to acquire the lock before failing +- zypper: + name: mosh + state: present + environment: + ZYPP_LOCK_TIMEOUT: 20 +''' + +import xml +import re +from xml.dom.minidom import parseString as parseXML +from ansible.module_utils.six import iteritems +from ansible.module_utils._text import to_native + +# import module snippets +from ansible.module_utils.basic import AnsibleModule + + +class Package: + def __init__(self, name, prefix, version): + self.name = name + self.prefix = prefix + self.version = version + self.shouldinstall = (prefix == '+') + + def __str__(self): + return self.prefix + self.name + self.version + + +def split_name_version(name): + """splits of the package name and desired version + + example formats: + - docker>=1.10 + - apache=2.4 + + Allowed version specifiers: <, >, <=, >=, = + Allowed version format: [0-9.-]* + + Also allows a prefix indicating remove "-", "~" or install "+" + """ + + prefix = '' + if name[0] in ['-', '~', '+']: + prefix = name[0] + name = name[1:] + if prefix == '~': + prefix = '-' + + version_check = re.compile('^(.*?)((?:<|>|<=|>=|=)[0-9.-]*)?$') + try: + reres = version_check.match(name) + name, version = reres.groups() + if version is None: + version = '' + return prefix, name, version + except Exception: + return prefix, name, '' + + +def get_want_state(names, remove=False): + packages = [] + urls = [] + for name in names: + if '://' in name or name.endswith('.rpm'): + urls.append(name) + else: + prefix, pname, version = split_name_version(name) + if prefix not in ['-', '+']: + if remove: + prefix = '-' + else: + prefix = '+' + packages.append(Package(pname, prefix, version)) + return packages, urls + + +def get_installed_state(m, packages): + "get installed state of packages" + + cmd = get_cmd(m, 'search') + cmd.extend(['--match-exact', '--details', '--installed-only']) + cmd.extend([p.name for p in packages]) + return parse_zypper_xml(m, cmd, fail_not_found=False)[0] + + +def parse_zypper_xml(m, cmd, fail_not_found=True, packages=None): + rc, stdout, stderr = m.run_command(cmd, check_rc=False) + + try: + dom = parseXML(stdout) + except xml.parsers.expat.ExpatError as exc: + m.fail_json(msg="Failed to parse zypper xml output: %s" % to_native(exc), + rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + + if rc == 104: + # exit code 104 is ZYPPER_EXIT_INF_CAP_NOT_FOUND (no packages found) + if fail_not_found: + errmsg = dom.getElementsByTagName('message')[-1].childNodes[0].data + m.fail_json(msg=errmsg, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + else: + return {}, rc, stdout, stderr + elif rc in [0, 106, 103]: + # zypper exit codes + # 0: success + # 106: signature verification failed + # 103: zypper was upgraded, run same command again + if packages is None: + firstrun = True + packages = {} + solvable_list = dom.getElementsByTagName('solvable') + for solvable in solvable_list: + name = solvable.getAttribute('name') + packages[name] = {} + packages[name]['version'] = solvable.getAttribute('edition') + packages[name]['oldversion'] = solvable.getAttribute('edition-old') + status = solvable.getAttribute('status') + packages[name]['installed'] = status == "installed" + packages[name]['group'] = solvable.parentNode.nodeName + if rc == 103 and firstrun: + # if this was the first run and it failed with 103 + # run zypper again with the same command to complete update + return parse_zypper_xml(m, cmd, fail_not_found=fail_not_found, packages=packages) + + return packages, rc, stdout, stderr + m.fail_json(msg='Zypper run command failed with return code %s.' % rc, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + + +def get_cmd(m, subcommand): + "puts together the basic zypper command arguments with those passed to the module" + is_install = subcommand in ['install', 'update', 'patch', 'dist-upgrade'] + is_refresh = subcommand == 'refresh' + cmd = ['/usr/bin/zypper', '--quiet', '--non-interactive', '--xmlout'] + if m.params['extra_args_precommand']: + args_list = m.params['extra_args_precommand'].split() + cmd.extend(args_list) + # add global options before zypper command + if (is_install or is_refresh) and m.params['disable_gpg_check']: + cmd.append('--no-gpg-checks') + + if subcommand == 'search': + cmd.append('--disable-repositories') + + cmd.append(subcommand) + if subcommand not in ['patch', 'dist-upgrade'] and not is_refresh: + cmd.extend(['--type', m.params['type']]) + if m.check_mode and subcommand != 'search': + cmd.append('--dry-run') + if is_install: + cmd.append('--auto-agree-with-licenses') + if m.params['disable_recommends']: + cmd.append('--no-recommends') + if m.params['force']: + cmd.append('--force') + if m.params['force_resolution']: + cmd.append('--force-resolution') + if m.params['oldpackage']: + cmd.append('--oldpackage') + if m.params['extra_args']: + args_list = m.params['extra_args'].split(' ') + cmd.extend(args_list) + + return cmd + + +def set_diff(m, retvals, result): + # TODO: if there is only one package, set before/after to version numbers + packages = {'installed': [], 'removed': [], 'upgraded': []} + if result: + for p in result: + group = result[p]['group'] + if group == 'to-upgrade': + versions = ' (' + result[p]['oldversion'] + ' => ' + result[p]['version'] + ')' + packages['upgraded'].append(p + versions) + elif group == 'to-install': + packages['installed'].append(p) + elif group == 'to-remove': + packages['removed'].append(p) + + output = '' + for state in packages: + if packages[state]: + output += state + ': ' + ', '.join(packages[state]) + '\n' + if 'diff' not in retvals: + retvals['diff'] = {} + if 'prepared' not in retvals['diff']: + retvals['diff']['prepared'] = output + else: + retvals['diff']['prepared'] += '\n' + output + + +def package_present(m, name, want_latest): + "install and update (if want_latest) the packages in name_install, while removing the packages in name_remove" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + packages, urls = get_want_state(name) + + # add oldpackage flag when a version is given to allow downgrades + if any(p.version for p in packages): + m.params['oldpackage'] = True + + if not want_latest: + # for state=present: filter out already installed packages + # if a version is given leave the package in to let zypper handle the version + # resolution + packageswithoutversion = [p for p in packages if not p.version] + prerun_state = get_installed_state(m, packageswithoutversion) + # generate lists of packages to install or remove + packages = [p for p in packages if p.shouldinstall != (p.name in prerun_state)] + + if not packages and not urls: + # nothing to install/remove and nothing to update + return None, retvals + + # zypper install also updates packages + cmd = get_cmd(m, 'install') + cmd.append('--') + cmd.extend(urls) + # pass packages to zypper + # allow for + or - prefixes in install/remove lists + # also add version specifier if given + # do this in one zypper run to allow for dependency-resolution + # for example "-exim postfix" runs without removing packages depending on mailserver + cmd.extend([str(p) for p in packages]) + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + + return result, retvals + + +def package_update_all(m): + "run update or patch on all available packages" + + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + if m.params['type'] == 'patch': + cmdname = 'patch' + elif m.params['state'] == 'dist-upgrade': + cmdname = 'dist-upgrade' + else: + cmdname = 'update' + + cmd = get_cmd(m, cmdname) + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + return result, retvals + + +def package_absent(m, name): + "remove the packages in name" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + # Get package state + packages, urls = get_want_state(name, remove=True) + if any(p.prefix == '+' for p in packages): + m.fail_json(msg="Can not combine '+' prefix with state=remove/absent.") + if urls: + m.fail_json(msg="Can not remove via URL.") + if m.params['type'] == 'patch': + m.fail_json(msg="Can not remove patches.") + prerun_state = get_installed_state(m, packages) + packages = [p for p in packages if p.name in prerun_state] + + if not packages: + return None, retvals + + cmd = get_cmd(m, 'remove') + cmd.extend([p.name + p.version for p in packages]) + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + return result, retvals + + +def repo_refresh(m): + "update the repositories" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + + cmd = get_cmd(m, 'refresh') + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + + return retvals + +# =========================================== +# Main control flow + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(required=True, aliases=['pkg'], type='list'), + state=dict(required=False, default='present', choices=['absent', 'installed', 'latest', 'present', 'removed', 'dist-upgrade']), + type=dict(required=False, default='package', choices=['package', 'patch', 'pattern', 'product', 'srcpackage', 'application']), + extra_args_precommand=dict(required=False, default=None), + disable_gpg_check=dict(required=False, default='no', type='bool'), + disable_recommends=dict(required=False, default='yes', type='bool'), + force=dict(required=False, default='no', type='bool'), + force_resolution=dict(required=False, default='no', type='bool'), + update_cache=dict(required=False, aliases=['refresh'], default='no', type='bool'), + oldpackage=dict(required=False, default='no', type='bool'), + extra_args=dict(required=False, default=None), + ), + supports_check_mode=True + ) + + name = module.params['name'] + state = module.params['state'] + update_cache = module.params['update_cache'] + + # remove empty strings from package list + name = list(filter(None, name)) + + # Refresh repositories + if update_cache and not module.check_mode: + retvals = repo_refresh(module) + + if retvals['rc'] != 0: + module.fail_json(msg="Zypper refresh run failed.", **retvals) + + # Perform requested action + if name == ['*'] and state in ['latest', 'dist-upgrade']: + packages_changed, retvals = package_update_all(module) + elif name != ['*'] and state == 'dist-upgrade': + module.fail_json(msg="Can not dist-upgrade specific packages.") + else: + if state in ['absent', 'removed']: + packages_changed, retvals = package_absent(module, name) + elif state in ['installed', 'present', 'latest']: + packages_changed, retvals = package_present(module, name, state == 'latest') + + retvals['changed'] = retvals['rc'] == 0 and bool(packages_changed) + + if module._diff: + set_diff(module, retvals, packages_changed) + + if retvals['rc'] != 0: + module.fail_json(msg="Zypper run failed.", **retvals) + + if not retvals['changed']: + del retvals['stdout'] + del retvals['stderr'] + + module.exit_json(name=name, state=state, update_cache=update_cache, **retvals) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py new file mode 100644 index 0000000..089b339 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py @@ -0,0 +1,40 @@ +# +# Copyright 2018 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from ansible_collections.ansible.netcommon.plugins.action.network import ( + ActionModule as ActionNetworkModule, +) + + +class ActionModule(ActionNetworkModule): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + self._config_module = True + if self._play_context.connection.split(".")[-1] != "network_cli": + return { + "failed": True, + "msg": "Connection type %s is not valid for cli_config module" + % self._play_context.connection, + } + + return super(ActionModule, self).run(task_vars=task_vars) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py new file mode 100644 index 0000000..542dcfe --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py @@ -0,0 +1,90 @@ +# Copyright: (c) 2015, Ansible Inc, +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import copy + +from ansible.errors import AnsibleError +from ansible.plugins.action import ActionBase +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + result = {} + play_context = copy.deepcopy(self._play_context) + play_context.network_os = self._get_network_os(task_vars) + new_task = self._task.copy() + + module = self._get_implementation_module( + play_context.network_os, self._task.action + ) + if not module: + if self._task.args["fail_on_missing_module"]: + result["failed"] = True + else: + result["failed"] = False + + result["msg"] = ( + "Could not find implementation module %s for %s" + % (self._task.action, play_context.network_os) + ) + return result + + new_task.action = module + + action = self._shared_loader_obj.action_loader.get( + play_context.network_os, + task=new_task, + connection=self._connection, + play_context=play_context, + loader=self._loader, + templar=self._templar, + shared_loader_obj=self._shared_loader_obj, + ) + display.vvvv("Running implementation module %s" % module) + return action.run(task_vars=task_vars) + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host to use platform agnostic modules" + ) + + return network_os + + def _get_implementation_module(self, network_os, platform_agnostic_module): + module_name = ( + network_os.split(".")[-1] + + "_" + + platform_agnostic_module.partition("_")[2] + ) + if "." in network_os: + fqcn_module = ".".join(network_os.split(".")[0:-1]) + implementation_module = fqcn_module + "." + module_name + else: + implementation_module = module_name + + if implementation_module not in self._shared_loader_obj.module_loader: + implementation_module = None + + return implementation_module diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py new file mode 100644 index 0000000..40205a4 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py @@ -0,0 +1,199 @@ +# (c) 2018, Ansible Inc, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import re +import uuid +import hashlib + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible.plugins.action import ActionBase +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + socket_path = None + self._get_network_os(task_vars) + persistent_connection = self._play_context.connection.split(".")[-1] + + result = super(ActionModule, self).run(task_vars=task_vars) + + if persistent_connection != "network_cli": + # It is supported only with network_cli + result["failed"] = True + result["msg"] = ( + "connection type %s is not valid for net_get module," + " please use fully qualified name of network_cli connection type" + % self._play_context.connection + ) + return result + + try: + src = self._task.args["src"] + except KeyError as exc: + return { + "failed": True, + "msg": "missing required argument: %s" % exc, + } + + # Get destination file if specified + dest = self._task.args.get("dest") + + if dest is None: + dest = self._get_default_dest(src) + else: + dest = self._handle_dest_path(dest) + + # Get proto + proto = self._task.args.get("protocol") + if proto is None: + proto = "scp" + + if socket_path is None: + socket_path = self._connection.socket_path + + conn = Connection(socket_path) + sock_timeout = conn.get_option("persistent_command_timeout") + + try: + changed = self._handle_existing_file( + conn, src, dest, proto, sock_timeout + ) + if changed is False: + result["changed"] = changed + result["destination"] = dest + return result + except Exception as exc: + result["msg"] = ( + "Warning: %s idempotency check failed. Check dest" % exc + ) + + try: + conn.get_file( + source=src, destination=dest, proto=proto, timeout=sock_timeout + ) + except Exception as exc: + result["failed"] = True + result["msg"] = "Exception received: %s" % exc + + result["changed"] = changed + result["destination"] = dest + return result + + def _handle_dest_path(self, dest): + working_path = self._get_working_path() + + if os.path.isabs(dest) or urlsplit("dest").scheme: + dst = dest + else: + dst = self._loader.path_dwim_relative(working_path, "", dest) + + return dst + + def _get_src_filename_from_path(self, src_path): + filename_list = re.split("/|:", src_path) + return filename_list[-1] + + def _get_default_dest(self, src_path): + dest_path = self._get_working_path() + src_fname = self._get_src_filename_from_path(src_path) + filename = "%s/%s" % (dest_path, src_fname) + return filename + + def _handle_existing_file(self, conn, source, dest, proto, timeout): + """ + Determines whether the source and destination file match. + + :return: False if source and dest both exist and have matching sha1 sums, True otherwise. + """ + if not os.path.exists(dest): + return True + + cwd = self._loader.get_basedir() + filename = str(uuid.uuid4()) + tmp_dest_file = os.path.join(cwd, filename) + try: + conn.get_file( + source=source, + destination=tmp_dest_file, + proto=proto, + timeout=timeout, + ) + except ConnectionError as exc: + error = to_text(exc) + if error.endswith("No such file or directory"): + if os.path.exists(tmp_dest_file): + os.remove(tmp_dest_file) + return True + + try: + with open(tmp_dest_file, "r") as f: + new_content = f.read() + with open(dest, "r") as f: + old_content = f.read() + except (IOError, OSError): + os.remove(tmp_dest_file) + raise + + sha1 = hashlib.sha1() + old_content_b = to_bytes(old_content, errors="surrogate_or_strict") + sha1.update(old_content_b) + checksum_old = sha1.digest() + + sha1 = hashlib.sha1() + new_content_b = to_bytes(new_content, errors="surrogate_or_strict") + sha1.update(new_content_b) + checksum_new = sha1.digest() + os.remove(tmp_dest_file) + if checksum_old == checksum_new: + return False + return True + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py new file mode 100644 index 0000000..955329d --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py @@ -0,0 +1,235 @@ +# (c) 2018, Ansible Inc, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import uuid +import hashlib + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible.plugins.action import ActionBase +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + socket_path = None + network_os = self._get_network_os(task_vars).split(".")[-1] + persistent_connection = self._play_context.connection.split(".")[-1] + + result = super(ActionModule, self).run(task_vars=task_vars) + + if persistent_connection != "network_cli": + # It is supported only with network_cli + result["failed"] = True + result["msg"] = ( + "connection type %s is not valid for net_put module," + " please use fully qualified name of network_cli connection type" + % self._play_context.connection + ) + return result + + try: + src = self._task.args["src"] + except KeyError as exc: + return { + "failed": True, + "msg": "missing required argument: %s" % exc, + } + + src_file_path_name = src + + # Get destination file if specified + dest = self._task.args.get("dest") + + # Get proto + proto = self._task.args.get("protocol") + if proto is None: + proto = "scp" + + # Get mode if set + mode = self._task.args.get("mode") + if mode is None: + mode = "binary" + + if mode == "text": + try: + self._handle_template(convert_data=False) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + # Now src has resolved file write to disk in current diectory for scp + src = self._task.args.get("src") + filename = str(uuid.uuid4()) + cwd = self._loader.get_basedir() + output_file = os.path.join(cwd, filename) + try: + with open(output_file, "wb") as f: + f.write(to_bytes(src, encoding="utf-8")) + except Exception: + os.remove(output_file) + raise + else: + try: + output_file = self._get_binary_src_file(src) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + if socket_path is None: + socket_path = self._connection.socket_path + + conn = Connection(socket_path) + sock_timeout = conn.get_option("persistent_command_timeout") + + if dest is None: + dest = src_file_path_name + + try: + changed = self._handle_existing_file( + conn, output_file, dest, proto, sock_timeout + ) + if changed is False: + result["changed"] = changed + result["destination"] = dest + return result + except Exception as exc: + result["msg"] = ( + "Warning: %s idempotency check failed. Check dest" % exc + ) + + try: + conn.copy_file( + source=output_file, + destination=dest, + proto=proto, + timeout=sock_timeout, + ) + except Exception as exc: + if to_text(exc) == "No response from server": + if network_os == "iosxr": + # IOSXR sometimes closes socket prematurely after completion + # of file transfer + result[ + "msg" + ] = "Warning: iosxr scp server pre close issue. Please check dest" + else: + result["failed"] = True + result["msg"] = "Exception received: %s" % exc + + if mode == "text": + # Cleanup tmp file expanded wih ansible vars + os.remove(output_file) + + result["changed"] = changed + result["destination"] = dest + return result + + def _handle_existing_file(self, conn, source, dest, proto, timeout): + """ + Determines whether the source and destination file match. + + :return: False if source and dest both exist and have matching sha1 sums, True otherwise. + """ + cwd = self._loader.get_basedir() + filename = str(uuid.uuid4()) + tmp_source_file = os.path.join(cwd, filename) + try: + conn.get_file( + source=dest, + destination=tmp_source_file, + proto=proto, + timeout=timeout, + ) + except ConnectionError as exc: + error = to_text(exc) + if error.endswith("No such file or directory"): + if os.path.exists(tmp_source_file): + os.remove(tmp_source_file) + return True + + try: + with open(source, "r") as f: + new_content = f.read() + with open(tmp_source_file, "r") as f: + old_content = f.read() + except (IOError, OSError): + os.remove(tmp_source_file) + raise + + sha1 = hashlib.sha1() + old_content_b = to_bytes(old_content, errors="surrogate_or_strict") + sha1.update(old_content_b) + checksum_old = sha1.digest() + + sha1 = hashlib.sha1() + new_content_b = to_bytes(new_content, errors="surrogate_or_strict") + sha1.update(new_content_b) + checksum_new = sha1.digest() + os.remove(tmp_source_file) + if checksum_old == checksum_new: + return False + return True + + def _get_binary_src_file(self, src): + working_path = self._get_working_path() + + if os.path.isabs(src) or urlsplit("src").scheme: + source = src + else: + source = self._loader.path_dwim_relative( + working_path, "templates", src + ) + if not source: + source = self._loader.path_dwim_relative(working_path, src) + + if not os.path.exists(source): + raise ValueError("path specified in src not found") + + return source + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py new file mode 100644 index 0000000..5d05d33 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py @@ -0,0 +1,209 @@ +# +# (c) 2018 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import time +import re + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.plugins.action.normal import ActionModule as _ActionModule +from ansible.utils.display import Display + +display = Display() + +PRIVATE_KEYS_RE = re.compile("__.+__") + + +class ActionModule(_ActionModule): + def run(self, task_vars=None): + config_module = hasattr(self, "_config_module") and self._config_module + if config_module and self._task.args.get("src"): + try: + self._handle_src_option() + except AnsibleError as e: + return {"failed": True, "msg": e.message, "changed": False} + + result = super(ActionModule, self).run(task_vars=task_vars) + + if ( + config_module + and self._task.args.get("backup") + and not result.get("failed") + ): + self._handle_backup_option(result, task_vars) + + return result + + def _handle_backup_option(self, result, task_vars): + + filename = None + backup_path = None + try: + content = result["__backup__"] + except KeyError: + raise AnsibleError("Failed while reading configuration backup") + + backup_options = self._task.args.get("backup_options") + if backup_options: + filename = backup_options.get("filename") + backup_path = backup_options.get("dir_path") + + if not backup_path: + cwd = self._get_working_path() + backup_path = os.path.join(cwd, "backup") + if not filename: + tstamp = time.strftime( + "%Y-%m-%d@%H:%M:%S", time.localtime(time.time()) + ) + filename = "%s_config.%s" % ( + task_vars["inventory_hostname"], + tstamp, + ) + + dest = os.path.join(backup_path, filename) + backup_path = os.path.expanduser( + os.path.expandvars( + to_bytes(backup_path, errors="surrogate_or_strict") + ) + ) + + if not os.path.exists(backup_path): + os.makedirs(backup_path) + + new_task = self._task.copy() + for item in self._task.args: + if not item.startswith("_"): + new_task.args.pop(item, None) + + new_task.args.update(dict(content=content, dest=dest)) + copy_action = self._shared_loader_obj.action_loader.get( + "copy", + task=new_task, + connection=self._connection, + play_context=self._play_context, + loader=self._loader, + templar=self._templar, + shared_loader_obj=self._shared_loader_obj, + ) + copy_result = copy_action.run(task_vars=task_vars) + if copy_result.get("failed"): + result["failed"] = copy_result["failed"] + result["msg"] = copy_result.get("msg") + return + + result["backup_path"] = dest + if copy_result.get("changed", False): + result["changed"] = copy_result["changed"] + + if backup_options and backup_options.get("filename"): + result["date"] = time.strftime( + "%Y-%m-%d", + time.gmtime(os.stat(result["backup_path"]).st_ctime), + ) + result["time"] = time.strftime( + "%H:%M:%S", + time.gmtime(os.stat(result["backup_path"]).st_ctime), + ) + + else: + result["date"] = tstamp.split("@")[0] + result["time"] = tstamp.split("@")[1] + result["shortname"] = result["backup_path"][::-1].split(".", 1)[1][ + ::-1 + ] + result["filename"] = result["backup_path"].split("/")[-1] + + # strip out any keys that have two leading and two trailing + # underscore characters + for key in list(result.keys()): + if PRIVATE_KEYS_RE.match(key): + del result[key] + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _handle_src_option(self, convert_data=True): + src = self._task.args.get("src") + working_path = self._get_working_path() + + if os.path.isabs(src) or urlsplit("src").scheme: + source = src + else: + source = self._loader.path_dwim_relative( + working_path, "templates", src + ) + if not source: + source = self._loader.path_dwim_relative(working_path, src) + + if not os.path.exists(source): + raise AnsibleError("path specified in src not found") + + try: + with open(source, "r") as f: + template_data = to_text(f.read()) + except IOError as e: + raise AnsibleError( + "unable to load src file {0}, I/O error({1}): {2}".format( + source, e.errno, e.strerror + ) + ) + + # Create a template search path in the following order: + # [working_path, self_role_path, dependent_role_paths, dirname(source)] + searchpath = [working_path] + if self._task._role is not None: + searchpath.append(self._task._role._role_path) + if hasattr(self._task, "_block:"): + dep_chain = self._task._block.get_dep_chain() + if dep_chain is not None: + for role in dep_chain: + searchpath.append(role._role_path) + searchpath.append(os.path.dirname(source)) + with self._templar.set_temporary_context(searchpath=searchpath): + self._task.args["src"] = self._templar.template( + template_data, convert_data=convert_data + ) + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py new file mode 100644 index 0000000..33938fd --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """become: enable +short_description: Switch to elevated permissions on a network device +description: +- This become plugins allows elevated permissions on a remote network device. +author: ansible (@core) +options: + become_pass: + description: password + ini: + - section: enable_become_plugin + key: password + vars: + - name: ansible_become_password + - name: ansible_become_pass + - name: ansible_enable_pass + env: + - name: ANSIBLE_BECOME_PASS + - name: ANSIBLE_ENABLE_PASS +notes: +- enable is really implemented in the network connection handler and as such can only + be used with network connections. +- This plugin ignores the 'become_exe' and 'become_user' settings as it uses an API + and not an executable. +""" + +from ansible.plugins.become import BecomeBase + + +class BecomeModule(BecomeBase): + + name = "ansible.netcommon.enable" + + def build_become_command(self, cmd, shell): + # enable is implemented inside the network connection plugins + return cmd diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py new file mode 100644 index 0000000..b063ef0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py @@ -0,0 +1,324 @@ +# (c) 2018 Red Hat Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +connection: httpapi +short_description: Use httpapi to run command on network appliances +description: +- This connection plugin provides a connection to remote devices over a HTTP(S)-based + api. +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the HTTP(S) connection + to. + default: inventory_hostname + vars: + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the HTTP(S) connection. + - When unspecified, will pick 80 or 443 based on the value of use_ssl. + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_httpapi_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load the correct httpapi plugin to communicate with the remote device + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the API connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + needed for the device API. + vars: + - name: ansible_password + - name: ansible_httpapi_pass + - name: ansible_httpapi_password + use_ssl: + type: boolean + description: + - Whether to connect using SSL (HTTPS) or not (HTTP). + default: false + vars: + - name: ansible_httpapi_use_ssl + validate_certs: + type: boolean + description: + - Whether to validate SSL certificates + default: true + vars: + - name: ansible_httpapi_validate_certs + use_proxy: + type: boolean + description: + - Whether to use https_proxy for requests. + default: true + vars: + - name: ansible_httpapi_use_proxy + become: + type: boolean + description: + - The become option will instruct the CLI session to attempt privilege escalation + on platforms that support it. Normally this means transitioning from user mode + to C(enable) mode in the CLI session. If become is set to True and the remote + device does not support privilege escalation or the privilege has already been + elevated, then this option is silently ignored. + - Can be configured from the CLI via the C(--become) or C(-b) options. + default: false + ini: + - section: privilege_escalation + key: become + env: + - name: ANSIBLE_BECOME + vars: + - name: ansible_become + become_method: + description: + - This option allows the become method to be specified in for handling privilege + escalation. Typically the become_method value is set to C(enable) but could + be defined as other values. + default: sudo + ini: + - section: privilege_escalation + key: become_method + env: + - name: ANSIBLE_BECOME_METHOD + vars: + - name: ansible_become_method + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this option + as it could create a security vulnerability by logging sensitive information + in log file. + default: false + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" + +from io import BytesIO + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_bytes +from ansible.module_utils.six import PY3 +from ansible.module_utils.six.moves import cPickle +from ansible.module_utils.six.moves.urllib.error import HTTPError, URLError +from ansible.module_utils.urls import open_url +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import httpapi_loader +from ansible.plugins.connection import NetworkConnectionBase, ensure_connect + + +class Connection(NetworkConnectionBase): + """Network API connection""" + + transport = "ansible.netcommon.httpapi" + has_pipelining = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + + self._url = None + self._auth = None + + if self._network_os: + + self.httpapi = httpapi_loader.get(self._network_os, self) + if self.httpapi: + self._sub_plugin = { + "type": "httpapi", + "name": self.httpapi._load_name, + "obj": self.httpapi, + } + self.queue_message( + "vvvv", + "loaded API plugin %s from path %s for network_os %s" + % ( + self.httpapi._load_name, + self.httpapi._original_path, + self._network_os, + ), + ) + else: + raise AnsibleConnectionFailure( + "unable to load API plugin for network_os %s" + % self._network_os + ) + + else: + raise AnsibleConnectionFailure( + "Unable to automatically determine host network os. Please " + "manually configure ansible_network_os value for this host" + ) + self.queue_message("log", "network_os is set to %s" % self._network_os) + + def update_play_context(self, pc_data): + """Updates the play context information for the connection""" + pc_data = to_bytes(pc_data) + if PY3: + pc_data = cPickle.loads(pc_data, encoding="bytes") + else: + pc_data = cPickle.loads(pc_data) + play_context = PlayContext() + play_context.deserialize(pc_data) + + self.queue_message("vvvv", "updating play_context for connection") + if self._play_context.become ^ play_context.become: + self.set_become(play_context) + if play_context.become is True: + self.queue_message("vvvv", "authorizing connection") + else: + self.queue_message("vvvv", "deauthorizing connection") + + self._play_context = play_context + + def _connect(self): + if not self.connected: + protocol = "https" if self.get_option("use_ssl") else "http" + host = self.get_option("host") + port = self.get_option("port") or ( + 443 if protocol == "https" else 80 + ) + self._url = "%s://%s:%s" % (protocol, host, port) + + self.queue_message( + "vvv", + "ESTABLISH HTTP(S) CONNECTFOR USER: %s TO %s" + % (self._play_context.remote_user, self._url), + ) + self.httpapi.set_become(self._play_context) + self._connected = True + + self.httpapi.login( + self.get_option("remote_user"), self.get_option("password") + ) + + def close(self): + """ + Close the active session to the device + """ + # only close the connection if its connected. + if self._connected: + self.queue_message("vvvv", "closing http(s) connection to device") + self.logout() + + super(Connection, self).close() + + @ensure_connect + def send(self, path, data, **kwargs): + """ + Sends the command to the device over api + """ + url_kwargs = dict( + timeout=self.get_option("persistent_command_timeout"), + validate_certs=self.get_option("validate_certs"), + use_proxy=self.get_option("use_proxy"), + headers={}, + ) + url_kwargs.update(kwargs) + if self._auth: + # Avoid modifying passed-in headers + headers = dict(kwargs.get("headers", {})) + headers.update(self._auth) + url_kwargs["headers"] = headers + else: + url_kwargs["force_basic_auth"] = True + url_kwargs["url_username"] = self.get_option("remote_user") + url_kwargs["url_password"] = self.get_option("password") + + try: + url = self._url + path + self._log_messages( + "send url '%s' with data '%s' and kwargs '%s'" + % (url, data, url_kwargs) + ) + response = open_url(url, data=data, **url_kwargs) + except HTTPError as exc: + is_handled = self.handle_httperror(exc) + if is_handled is True: + return self.send(path, data, **kwargs) + elif is_handled is False: + raise + else: + response = is_handled + except URLError as exc: + raise AnsibleConnectionFailure( + "Could not connect to {0}: {1}".format( + self._url + path, exc.reason + ) + ) + + response_buffer = BytesIO() + resp_data = response.read() + self._log_messages("received response: '%s'" % resp_data) + response_buffer.write(resp_data) + + # Try to assign a new auth token if one is given + self._auth = self.update_auth(response, response_buffer) or self._auth + + response_buffer.seek(0) + + return response, response_buffer diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py new file mode 100644 index 0000000..1e2d3ca --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py @@ -0,0 +1,404 @@ +# (c) 2016 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +connection: netconf +short_description: Provides a persistent connection using the netconf protocol +description: +- This connection plugin provides a connection to remote devices over the SSH NETCONF + subsystem. This connection plugin is typically used by network devices for sending + and receiving RPC calls over NETCONF. +- Note this connection plugin requires ncclient to be installed on the local Ansible + controller. +requirements: +- ncclient +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the SSH connection + to. + default: inventory_hostname + vars: + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the SSH connection. + default: 830 + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load a device specific netconf plugin. If this option is not configured + (or set to C(auto)), then Ansible will attempt to guess the correct network_os + to use. If it can not guess a network_os correctly it will use C(default). + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the SSH connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + first establishing the SSH connection. + vars: + - name: ansible_password + - name: ansible_ssh_pass + - name: ansible_ssh_password + - name: ansible_netconf_password + private_key_file: + description: + - The private SSH key or certificate file used to authenticate to the remote device + when first establishing the SSH connection. + ini: + - section: defaults + key: private_key_file + env: + - name: ANSIBLE_PRIVATE_KEY_FILE + vars: + - name: ansible_private_key_file + look_for_keys: + default: true + description: + - Enables looking for ssh keys in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`). + env: + - name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS + ini: + - section: paramiko_connection + key: look_for_keys + type: boolean + host_key_checking: + description: Set this to "False" if you want to avoid host key checking by the + underlying tools Ansible uses to connect to the host + type: boolean + default: true + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + - name: ANSIBLE_NETCONF_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: paramiko_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + - name: ansible_netconf_host_key_checking + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + netconf_ssh_config: + description: + - This variable is used to enable bastion/jump host with netconf connection. If + set to True the bastion/jump host ssh settings should be present in ~/.ssh/config + file, alternatively it can be set to custom ssh configuration file path to read + the bastion/jump host settings. + ini: + - section: netconf_connection + key: ssh_config + version_added: '2.7' + env: + - name: ANSIBLE_NETCONF_SSH_CONFIG + vars: + - name: ansible_netconf_ssh_config + version_added: '2.7' + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this option + as it could create a security vulnerability by logging sensitive information + in log file. + default: false + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" + +import os +import logging +import json + +from ansible.errors import AnsibleConnectionFailure, AnsibleError +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.parsing.convert_bool import ( + BOOLEANS_TRUE, + BOOLEANS_FALSE, +) +from ansible.plugins.loader import netconf_loader +from ansible.plugins.connection import NetworkConnectionBase, ensure_connect + +try: + from ncclient import manager + from ncclient.operations import RPCError + from ncclient.transport.errors import SSHUnknownHostError + from ncclient.xml_ import to_ele, to_xml + + HAS_NCCLIENT = True + NCCLIENT_IMP_ERR = None +except ( + ImportError, + AttributeError, +) as err: # paramiko and gssapi are incompatible and raise AttributeError not ImportError + HAS_NCCLIENT = False + NCCLIENT_IMP_ERR = err + +logging.getLogger("ncclient").setLevel(logging.INFO) + + +class Connection(NetworkConnectionBase): + """NetConf connections""" + + transport = "ansible.netcommon.netconf" + has_pipelining = False + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + + # If network_os is not specified then set the network os to auto + # This will be used to trigger the use of guess_network_os when connecting. + self._network_os = self._network_os or "auto" + + self.netconf = netconf_loader.get(self._network_os, self) + if self.netconf: + self._sub_plugin = { + "type": "netconf", + "name": self.netconf._load_name, + "obj": self.netconf, + } + self.queue_message( + "vvvv", + "loaded netconf plugin %s from path %s for network_os %s" + % ( + self.netconf._load_name, + self.netconf._original_path, + self._network_os, + ), + ) + else: + self.netconf = netconf_loader.get("default", self) + self._sub_plugin = { + "type": "netconf", + "name": "default", + "obj": self.netconf, + } + self.queue_message( + "display", + "unable to load netconf plugin for network_os %s, falling back to default plugin" + % self._network_os, + ) + + self.queue_message("log", "network_os is set to %s" % self._network_os) + self._manager = None + self.key_filename = None + self._ssh_config = None + + def exec_command(self, cmd, in_data=None, sudoable=True): + """Sends the request to the node and returns the reply + The method accepts two forms of request. The first form is as a byte + string that represents xml string be send over netconf session. + The second form is a json-rpc (2.0) byte string. + """ + if self._manager: + # to_ele operates on native strings + request = to_ele(to_native(cmd, errors="surrogate_or_strict")) + + if request is None: + return "unable to parse request" + + try: + reply = self._manager.rpc(request) + except RPCError as exc: + error = self.internal_error( + data=to_text(to_xml(exc.xml), errors="surrogate_or_strict") + ) + return json.dumps(error) + + return reply.data_xml + else: + return super(Connection, self).exec_command(cmd, in_data, sudoable) + + @property + @ensure_connect + def manager(self): + return self._manager + + def _connect(self): + if not HAS_NCCLIENT: + raise AnsibleError( + "%s: %s" + % ( + missing_required_lib("ncclient"), + to_native(NCCLIENT_IMP_ERR), + ) + ) + + self.queue_message("log", "ssh connection done, starting ncclient") + + allow_agent = True + if self._play_context.password is not None: + allow_agent = False + setattr(self._play_context, "allow_agent", allow_agent) + + self.key_filename = ( + self._play_context.private_key_file + or self.get_option("private_key_file") + ) + if self.key_filename: + self.key_filename = str(os.path.expanduser(self.key_filename)) + + self._ssh_config = self.get_option("netconf_ssh_config") + if self._ssh_config in BOOLEANS_TRUE: + self._ssh_config = True + elif self._ssh_config in BOOLEANS_FALSE: + self._ssh_config = None + + # Try to guess the network_os if the network_os is set to auto + if self._network_os == "auto": + for cls in netconf_loader.all(class_only=True): + network_os = cls.guess_network_os(self) + if network_os: + self.queue_message( + "vvv", "discovered network_os %s" % network_os + ) + self._network_os = network_os + + # If we have tried to detect the network_os but were unable to i.e. network_os is still 'auto' + # then use default as the network_os + + if self._network_os == "auto": + # Network os not discovered. Set it to default + self.queue_message( + "vvv", + "Unable to discover network_os. Falling back to default.", + ) + self._network_os = "default" + try: + ncclient_device_handler = self.netconf.get_option( + "ncclient_device_handler" + ) + except KeyError: + ncclient_device_handler = "default" + self.queue_message( + "vvv", + "identified ncclient device handler: %s." + % ncclient_device_handler, + ) + device_params = {"name": ncclient_device_handler} + + try: + port = self._play_context.port or 830 + self.queue_message( + "vvv", + "ESTABLISH NETCONF SSH CONNECTION FOR USER: %s on PORT %s TO %s WITH SSH_CONFIG = %s" + % ( + self._play_context.remote_user, + port, + self._play_context.remote_addr, + self._ssh_config, + ), + ) + self._manager = manager.connect( + host=self._play_context.remote_addr, + port=port, + username=self._play_context.remote_user, + password=self._play_context.password, + key_filename=self.key_filename, + hostkey_verify=self.get_option("host_key_checking"), + look_for_keys=self.get_option("look_for_keys"), + device_params=device_params, + allow_agent=self._play_context.allow_agent, + timeout=self.get_option("persistent_connect_timeout"), + ssh_config=self._ssh_config, + ) + + self._manager._timeout = self.get_option( + "persistent_command_timeout" + ) + except SSHUnknownHostError as exc: + raise AnsibleConnectionFailure(to_native(exc)) + except ImportError: + raise AnsibleError( + "connection=netconf is not supported on {0}".format( + self._network_os + ) + ) + + if not self._manager.connected: + return 1, b"", b"not connected" + + self.queue_message( + "log", "ncclient manager object created successfully" + ) + + self._connected = True + + super(Connection, self)._connect() + + return ( + 0, + to_bytes(self._manager.session_id, errors="surrogate_or_strict"), + b"", + ) + + def close(self): + if self._manager: + self._manager.close_session() + super(Connection, self).close() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py new file mode 100644 index 0000000..fef4081 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py @@ -0,0 +1,1386 @@ +# (c) 2016 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """ +author: + - Ansible Networking Team (@ansible-network) +name: network_cli +short_description: Use network_cli to run command on network appliances +description: +- This connection plugin provides a connection to remote devices over the SSH and + implements a CLI shell. This connection plugin is typically used by network devices + for sending and receiving CLi commands to network devices. +version_added: 1.0.0 +requirements: +- ansible-pylibssh if using I(ssh_type=libssh) +extends_documentation_fragment: +- ansible.netcommon.connection_persistent +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the SSH connection + to. + default: inventory_hostname + vars: + - name: inventory_hostname + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the SSH connection. + default: 22 + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load the correct terminal and cliconf plugins to communicate with the remote + device. + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the SSH connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + first establishing the SSH connection. + vars: + - name: ansible_password + - name: ansible_ssh_pass + - name: ansible_ssh_password + private_key_file: + description: + - The private SSH key or certificate file used to authenticate to the remote device + when first establishing the SSH connection. + ini: + - section: defaults + key: private_key_file + env: + - name: ANSIBLE_PRIVATE_KEY_FILE + vars: + - name: ansible_private_key_file + become: + type: boolean + description: + - The become option will instruct the CLI session to attempt privilege escalation + on platforms that support it. Normally this means transitioning from user mode + to C(enable) mode in the CLI session. If become is set to True and the remote + device does not support privilege escalation or the privilege has already been + elevated, then this option is silently ignored. + - Can be configured from the CLI via the C(--become) or C(-b) options. + default: false + ini: + - section: privilege_escalation + key: become + env: + - name: ANSIBLE_BECOME + vars: + - name: ansible_become + become_errors: + type: str + description: + - This option determines how privilege escalation failures are handled when + I(become) is enabled. + - When set to C(ignore), the errors are silently ignored. + When set to C(warn), a warning message is displayed. + The default option C(fail), triggers a failure and halts execution. + vars: + - name: ansible_network_become_errors + default: fail + choices: ["ignore", "warn", "fail"] + terminal_errors: + type: str + description: + - This option determines how failures while setting terminal parameters + are handled. + - When set to C(ignore), the errors are silently ignored. + When set to C(warn), a warning message is displayed. + The default option C(fail), triggers a failure and halts execution. + vars: + - name: ansible_network_terminal_errors + default: fail + choices: ["ignore", "warn", "fail"] + version_added: 3.1.0 + become_method: + description: + - This option allows the become method to be specified in for handling privilege + escalation. Typically the become_method value is set to C(enable) but could + be defined as other values. + default: sudo + ini: + - section: privilege_escalation + key: become_method + env: + - name: ANSIBLE_BECOME_METHOD + vars: + - name: ansible_become_method + host_key_auto_add: + type: boolean + description: + - By default, Ansible will prompt the user before adding SSH keys to the known + hosts file. Since persistent connections such as network_cli run in background + processes, the user will never be prompted. By enabling this option, unknown + host keys will automatically be added to the known hosts file. + - Be sure to fully understand the security implications of enabling this option + on production systems as it could create a security vulnerability. + default: false + ini: + - section: paramiko_connection + key: host_key_auto_add + env: + - name: ANSIBLE_HOST_KEY_AUTO_ADD + persistent_buffer_read_timeout: + type: float + description: + - Configures, in seconds, the amount of time to wait for the data to be read from + Paramiko channel after the command prompt is matched. This timeout value ensures + that command prompt matched is correct and there is no more data left to be + received from remote host. + default: 0.1 + ini: + - section: persistent_connection + key: buffer_read_timeout + env: + - name: ANSIBLE_PERSISTENT_BUFFER_READ_TIMEOUT + vars: + - name: ansible_buffer_read_timeout + terminal_stdout_re: + type: list + elements: dict + description: + - A single regex pattern or a sequence of patterns along with optional flags to + match the command prompt from the received response chunk. This option accepts + C(pattern) and C(flags) keys. The value of C(pattern) is a python regex pattern + to match the response and the value of C(flags) is the value accepted by I(flags) + argument of I(re.compile) python method to control the way regex is matched + with the response, for example I('re.I'). + vars: + - name: ansible_terminal_stdout_re + terminal_stderr_re: + type: list + elements: dict + description: + - This option provides the regex pattern and optional flags to match the error + string from the received response chunk. This option accepts C(pattern) and + C(flags) keys. The value of C(pattern) is a python regex pattern to match the + response and the value of C(flags) is the value accepted by I(flags) argument + of I(re.compile) python method to control the way regex is matched with the + response, for example I('re.I'). + vars: + - name: ansible_terminal_stderr_re + terminal_initial_prompt: + type: list + elements: string + description: + - A single regex pattern or a sequence of patterns to evaluate the expected prompt + at the time of initial login to the remote host. + vars: + - name: ansible_terminal_initial_prompt + terminal_initial_answer: + type: list + elements: string + description: + - The answer to reply with if the C(terminal_initial_prompt) is matched. The value + can be a single answer or a list of answers for multiple terminal_initial_prompt. + In case the login menu has multiple prompts the sequence of the prompt and excepted + answer should be in same order and the value of I(terminal_prompt_checkall) + should be set to I(True) if all the values in C(terminal_initial_prompt) are + expected to be matched and set to I(False) if any one login prompt is to be + matched. + vars: + - name: ansible_terminal_initial_answer + terminal_initial_prompt_checkall: + type: boolean + description: + - By default the value is set to I(False) and any one of the prompts mentioned + in C(terminal_initial_prompt) option is matched it won't check for other prompts. + When set to I(True) it will check for all the prompts mentioned in C(terminal_initial_prompt) + option in the given order and all the prompts should be received from remote + host if not it will result in timeout. + default: false + vars: + - name: ansible_terminal_initial_prompt_checkall + terminal_inital_prompt_newline: + type: boolean + description: + - This boolean flag, that when set to I(True) will send newline in the response + if any of values in I(terminal_initial_prompt) is matched. + default: true + vars: + - name: ansible_terminal_initial_prompt_newline + network_cli_retries: + description: + - Number of attempts to connect to remote host. The delay time between the retires + increases after every attempt by power of 2 in seconds till either the maximum + attempts are exhausted or any of the C(persistent_command_timeout) or C(persistent_connect_timeout) + timers are triggered. + default: 3 + type: integer + env: + - name: ANSIBLE_NETWORK_CLI_RETRIES + ini: + - section: persistent_connection + key: network_cli_retries + vars: + - name: ansible_network_cli_retries + ssh_type: + description: + - The python package that will be used by the C(network_cli) connection plugin to create a SSH connection to remote host. + - I(libssh) will use the ansible-pylibssh package, which needs to be installed in order to work. + - I(paramiko) will instead use the paramiko package to manage the SSH connection. + - I(auto) will use ansible-pylibssh if that package is installed, otherwise will fallback to paramiko. + default: auto + choices: ["libssh", "paramiko", "auto"] + env: + - name: ANSIBLE_NETWORK_CLI_SSH_TYPE + ini: + - section: persistent_connection + key: ssh_type + vars: + - name: ansible_network_cli_ssh_type + host_key_checking: + description: 'Set this to "False" if you want to avoid host key checking by the underlying tools Ansible uses to connect to the host' + type: boolean + default: True + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: persistent_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + single_user_mode: + type: boolean + default: false + version_added: 2.0.0 + description: + - This option enables caching of data fetched from the target for re-use. + The cache is invalidated when the target device enters configuration mode. + - Applicable only for platforms where this has been implemented. + env: + - name: ANSIBLE_NETWORK_SINGLE_USER_MODE + vars: + - name: ansible_network_single_user_mode +""" + +import getpass +import json +import logging +import os +import re +import signal +import socket +import time +import traceback +from functools import wraps +from io import BytesIO + +from ansible.errors import AnsibleConnectionFailure, AnsibleError +from ansible.module_utils._text import to_bytes, to_text +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.six import PY3 +from ansible.module_utils.six.moves import cPickle +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import ( + cache_loader, + cliconf_loader, + connection_loader, + terminal_loader, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible_collections.ansible.netcommon.plugins.plugin_utils.connection_base import ( + NetworkConnectionBase, +) + +try: + from scp import SCPClient + + HAS_SCP = True +except ImportError: + HAS_SCP = False + +HAS_PYLIBSSH = False + + +def ensure_connect(func): + @wraps(func) + def wrapped(self, *args, **kwargs): + if not self._connected: + self._connect() + self.update_cli_prompt_context() + return func(self, *args, **kwargs) + + return wrapped + + +class AnsibleCmdRespRecv(Exception): + pass + + +class Connection(NetworkConnectionBase): + """CLI (shell) SSH connections on Paramiko""" + + transport = "ansible.netcommon.network_cli" + has_pipelining = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._ssh_shell = None + + self._matched_prompt = None + self._matched_cmd_prompt = None + self._matched_pattern = None + self._last_response = None + self._history = list() + self._command_response = None + self._last_recv_window = None + self._cache = None + + self._terminal = None + self.cliconf = None + + # Managing prompt context + self._check_prompt = False + + self._task_uuid = to_text(kwargs.get("task_uuid", "")) + self._ssh_type_conn = None + self._ssh_type = None + + self._single_user_mode = False + + if self._network_os: + self._terminal = terminal_loader.get(self._network_os, self) + if not self._terminal: + raise AnsibleConnectionFailure( + "network os %s is not supported" % self._network_os + ) + + self.cliconf = cliconf_loader.get(self._network_os, self) + if self.cliconf: + self._sub_plugin = { + "type": "cliconf", + "name": self.cliconf._load_name, + "obj": self.cliconf, + } + self.queue_message( + "vvvv", + "loaded cliconf plugin %s from path %s for network_os %s" + % ( + self.cliconf._load_name, + self.cliconf._original_path, + self._network_os, + ), + ) + else: + self.queue_message( + "vvvv", + "unable to load cliconf for network_os %s" + % self._network_os, + ) + else: + raise AnsibleConnectionFailure( + "Unable to automatically determine host network os. Please " + "manually configure ansible_network_os value for this host" + ) + self.queue_message("log", "network_os is set to %s" % self._network_os) + + @property + def ssh_type(self): + if self._ssh_type is None: + self._ssh_type = self.get_option("ssh_type") + self.queue_message( + "vvvv", "ssh type is set to %s" % self._ssh_type + ) + # Support autodetection of supported library + if self._ssh_type == "auto": + self.queue_message("vvvv", "autodetecting ssh_type") + if HAS_PYLIBSSH: + self._ssh_type = "libssh" + else: + self.queue_message( + "warning", + "ansible-pylibssh not installed, falling back to paramiko", + ) + self._ssh_type = "paramiko" + self.queue_message( + "vvvv", "ssh type is now set to %s" % self._ssh_type + ) + + if self._ssh_type not in ["paramiko", "libssh"]: + raise AnsibleConnectionFailure( + "Invalid value '%s' set for ssh_type option." + " Expected value is either 'libssh' or 'paramiko'" + % self._ssh_type + ) + + return self._ssh_type + + @property + def ssh_type_conn(self): + if self._ssh_type_conn is None: + if self.ssh_type == "libssh": + connection_plugin = "ansible.netcommon.libssh" + elif self.ssh_type == "paramiko": + # NOTE: This MUST be paramiko or things will break + connection_plugin = "paramiko" + else: + raise AnsibleConnectionFailure( + "Invalid value '%s' set for ssh_type option." + " Expected value is either 'libssh' or 'paramiko'" + % self._ssh_type + ) + + self._ssh_type_conn = connection_loader.get( + connection_plugin, self._play_context, "/dev/null" + ) + + return self._ssh_type_conn + + # To maintain backward compatibility + @property + def paramiko_conn(self): + return self.ssh_type_conn + + def _get_log_channel(self): + name = "p=%s u=%s | " % (os.getpid(), getpass.getuser()) + name += "%s [%s]" % (self.ssh_type, self._play_context.remote_addr) + return name + + @ensure_connect + def get_prompt(self): + """Returns the current prompt from the device""" + return self._matched_prompt + + def exec_command(self, cmd, in_data=None, sudoable=True): + # this try..except block is just to handle the transition to supporting + # network_cli as a toplevel connection. Once connection=local is gone, + # this block can be removed as well and all calls passed directly to + # the local connection + if self._ssh_shell: + try: + cmd = json.loads(to_text(cmd, errors="surrogate_or_strict")) + kwargs = { + "command": to_bytes( + cmd["command"], errors="surrogate_or_strict" + ) + } + for key in ( + "prompt", + "answer", + "sendonly", + "newline", + "prompt_retry_check", + ): + if cmd.get(key) is True or cmd.get(key) is False: + kwargs[key] = cmd[key] + elif cmd.get(key) is not None: + kwargs[key] = to_bytes( + cmd[key], errors="surrogate_or_strict" + ) + return self.send(**kwargs) + except ValueError: + cmd = to_bytes(cmd, errors="surrogate_or_strict") + return self.send(command=cmd) + + else: + return super(Connection, self).exec_command(cmd, in_data, sudoable) + + def get_options(self, hostvars=None): + options = super(Connection, self).get_options(hostvars=hostvars) + options.update(self.ssh_type_conn.get_options(hostvars=hostvars)) + return options + + def set_options(self, task_keys=None, var_options=None, direct=None): + super(Connection, self).set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + self.ssh_type_conn.set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + # Retain old look_for_keys behaviour, but only if not set + if not any( + [ + task_keys and ("look_for_keys" in task_keys), + var_options and ("look_for_keys" in var_options), + direct and ("look_for_keys" in direct), + ] + ): + look_for_keys = not bool( + self.get_option("password") + and not self.get_option("private_key_file") + ) + if not look_for_keys: + # This actually can't be overridden yet without changes in ansible-core + # TODO: Uncomment when appropriate + # self.queue_message( + # "warning", + # "Option look_for_keys has been implicitly set to {0} because " + # "it was not set explicitly. This is retained to maintain " + # "backwards compatibility with the old behavior. This behavior " + # "will be removed in some release after 2024-01-01".format( + # look_for_keys + # ), + # ) + self.ssh_type_conn.set_option("look_for_keys", look_for_keys) + + def update_play_context(self, pc_data): + """Updates the play context information for the connection""" + pc_data = to_bytes(pc_data) + if PY3: + pc_data = cPickle.loads(pc_data, encoding="bytes") + else: + pc_data = cPickle.loads(pc_data) + play_context = PlayContext() + play_context.deserialize(pc_data) + + self.queue_message("vvvv", "updating play_context for connection") + if self._play_context.become ^ play_context.become: + if play_context.become is True: + auth_pass = play_context.become_pass + self._on_become(become_pass=auth_pass) + self.queue_message("vvvv", "authorizing connection") + else: + self._terminal.on_unbecome() + self.queue_message("vvvv", "deauthorizing connection") + + self._play_context = play_context + if self._ssh_type_conn is not None: + # TODO: This works, but is not really ideal. We would rather use + # set_options, but then we need more custom handling in that + # method. + self._ssh_type_conn._play_context = play_context + + if hasattr(self, "reset_history"): + self.reset_history() + if hasattr(self, "disable_response_logging"): + self.disable_response_logging() + + self._single_user_mode = self.get_option("single_user_mode") + + def set_check_prompt(self, task_uuid): + self._check_prompt = task_uuid + + def update_cli_prompt_context(self): + # set cli prompt context at the start of new task run only + if self._check_prompt and self._task_uuid != self._check_prompt: + self._task_uuid, self._check_prompt = self._check_prompt, False + self.set_cli_prompt_context() + + def _connect(self): + """ + Connects to the remote device and starts the terminal + """ + if self._play_context.verbosity > 3: + logging.getLogger(self.ssh_type).setLevel(logging.DEBUG) + + self.queue_message( + "vvvv", "invoked shell using ssh_type: %s" % self.ssh_type + ) + + self._single_user_mode = self.get_option("single_user_mode") + + if not self.connected: + self.ssh_type_conn._set_log_channel(self._get_log_channel()) + self.ssh_type_conn.force_persistence = self.force_persistence + + command_timeout = self.get_option("persistent_command_timeout") + max_pause = min( + [ + self.get_option("persistent_connect_timeout"), + command_timeout, + ] + ) + retries = self.get_option("network_cli_retries") + total_pause = 0 + + for attempt in range(retries + 1): + try: + ssh = self.ssh_type_conn._connect() + break + except AnsibleError: + raise + except Exception as e: + pause = 2 ** (attempt + 1) + if attempt == retries or total_pause >= max_pause: + raise AnsibleConnectionFailure( + to_text(e, errors="surrogate_or_strict") + ) + else: + msg = ( + "network_cli_retry: attempt: %d, caught exception(%s), " + "pausing for %d seconds" + % ( + attempt + 1, + to_text(e, errors="surrogate_or_strict"), + pause, + ) + ) + + self.queue_message("vv", msg) + time.sleep(pause) + total_pause += pause + continue + + self.queue_message("vvvv", "ssh connection done, setting terminal") + self._connected = True + + self._ssh_shell = ssh.ssh.invoke_shell() + if self.ssh_type == "paramiko": + self._ssh_shell.settimeout(command_timeout) + + self.queue_message( + "vvvv", + "loaded terminal plugin for network_os %s" % self._network_os, + ) + + terminal_initial_prompt = ( + self.get_option("terminal_initial_prompt") + or self._terminal.terminal_initial_prompt + ) + terminal_initial_answer = ( + self.get_option("terminal_initial_answer") + or self._terminal.terminal_initial_answer + ) + newline = ( + self.get_option("terminal_inital_prompt_newline") + or self._terminal.terminal_inital_prompt_newline + ) + check_all = ( + self.get_option("terminal_initial_prompt_checkall") or False + ) + + self.receive( + prompts=terminal_initial_prompt, + answer=terminal_initial_answer, + newline=newline, + check_all=check_all, + ) + + if self._play_context.become: + self.queue_message("vvvv", "firing event: on_become") + auth_pass = self._play_context.become_pass + self._on_become(become_pass=auth_pass) + + self.queue_message("vvvv", "firing event: on_open_shell()") + self._on_open_shell() + + self.queue_message( + "vvvv", "ssh connection has completed successfully" + ) + + return self + + def _on_become(self, become_pass=None): + """ + Wraps terminal.on_become() to handle + privilege escalation failures based on user preference + """ + on_become_error = self.get_option("become_errors") + try: + self._terminal.on_become(passwd=become_pass) + except AnsibleConnectionFailure: + if on_become_error == "ignore": + pass + elif on_become_error == "warn": + self.queue_message( + "warning", "on_become: privilege escalation failed" + ) + else: + raise + + def _on_open_shell(self): + """ + Wraps terminal.on_open_shell() to handle + terminal setting failures based on user preference + """ + on_terminal_error = self.get_option("terminal_errors") + try: + self._terminal.on_open_shell() + except AnsibleConnectionFailure: + if on_terminal_error == "ignore": + pass + elif on_terminal_error == "warn": + self.queue_message( + "warning", + "on_open_shell: failed to set terminal parameters", + ) + else: + raise + + def close(self): + """ + Close the active connection to the device + """ + # only close the connection if its connected. + if self._connected: + self.queue_message("debug", "closing ssh connection to device") + if self._ssh_shell: + self.queue_message("debug", "firing event: on_close_shell()") + self._terminal.on_close_shell() + self._ssh_shell.close() + self._ssh_shell = None + self.queue_message("debug", "cli session is now closed") + + self.ssh_type_conn.close() + self._ssh_type_conn = None + self.queue_message( + "debug", "ssh connection has been closed successfully" + ) + super(Connection, self).close() + + def _read_post_command_prompt_match(self): + time.sleep(self.get_option("persistent_buffer_read_timeout")) + data = self._ssh_shell.read_bulk_response() + return data if data else None + + def receive_paramiko( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + + recv = BytesIO() + cache_socket_timeout = self.get_option("persistent_command_timeout") + self._ssh_shell.settimeout(cache_socket_timeout) + command_prompt_matched = False + handled = False + errored_response = None + + while True: + if command_prompt_matched: + try: + signal.signal( + signal.SIGALRM, self._handle_buffer_read_timeout + ) + signal.setitimer( + signal.ITIMER_REAL, self._buffer_read_timeout + ) + data = self._ssh_shell.recv(256) + signal.alarm(0) + self._log_messages( + "response-%s: %s" % (self._window_count + 1, data) + ) + # if data is still received on channel it indicates the prompt string + # is wrongly matched in between response chunks, continue to read + # remaining response. + command_prompt_matched = False + + # restart command_timeout timer + signal.signal(signal.SIGALRM, self._handle_command_timeout) + signal.alarm(self._command_timeout) + + except AnsibleCmdRespRecv: + # reset socket timeout to global timeout + return self._command_response + else: + data = self._ssh_shell.recv(256) + self._log_messages( + "response-%s: %s" % (self._window_count + 1, data) + ) + # when a channel stream is closed, received data will be empty + if not data: + break + + recv.write(data) + offset = recv.tell() - 256 if recv.tell() > 256 else 0 + recv.seek(offset) + + window = self._strip(recv.read()) + self._last_recv_window = window + self._window_count += 1 + + if prompts and not handled: + handled = self._handle_prompt( + window, prompts, answer, newline, False, check_all + ) + self._matched_prompt_window = self._window_count + elif ( + prompts + and handled + and prompt_retry_check + and self._matched_prompt_window + 1 == self._window_count + ): + # check again even when handled, if same prompt repeats in next window + # (like in the case of a wrong enable password, etc) indicates + # value of answer is wrong, report this as error. + if self._handle_prompt( + window, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + ): + raise AnsibleConnectionFailure( + "For matched prompt '%s', answer is not valid" + % self._matched_cmd_prompt + ) + + if self._find_error(window): + # We can't exit here, as we need to drain the buffer in case + # the error isn't fatal, and will be using the buffer again + errored_response = window + + if self._find_prompt(window): + if errored_response: + raise AnsibleConnectionFailure(errored_response) + self._last_response = recv.getvalue() + resp = self._strip(self._last_response) + self._command_response = self._sanitize( + resp, command, strip_prompt + ) + if self._buffer_read_timeout == 0.0: + # reset socket timeout to global timeout + return self._command_response + else: + command_prompt_matched = True + + def receive_libssh( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + self._command_response = resp = b"" + command_prompt_matched = False + handled = False + errored_response = None + + while True: + + if command_prompt_matched: + data = self._read_post_command_prompt_match() + if data: + command_prompt_matched = False + else: + return self._command_response + else: + try: + data = self._ssh_shell.read_bulk_response() + # TODO: Should be ConnectionError when pylibssh drops Python 2 support + except OSError: + # Socket has closed + break + + if not data: + continue + self._last_recv_window = self._strip(data) + resp += self._last_recv_window + self._window_count += 1 + + self._log_messages("response-%s: %s" % (self._window_count, data)) + + if prompts and not handled: + handled = self._handle_prompt( + resp, prompts, answer, newline, False, check_all + ) + self._matched_prompt_window = self._window_count + elif ( + prompts + and handled + and prompt_retry_check + and self._matched_prompt_window + 1 == self._window_count + ): + # check again even when handled, if same prompt repeats in next window + # (like in the case of a wrong enable password, etc) indicates + # value of answer is wrong, report this as error. + if self._handle_prompt( + resp, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + ): + raise AnsibleConnectionFailure( + "For matched prompt '%s', answer is not valid" + % self._matched_cmd_prompt + ) + + if self._find_error(resp): + # We can't exit here, as we need to drain the buffer in case + # the error isn't fatal, and will be using the buffer again + errored_response = resp + + if self._find_prompt(resp): + if errored_response: + raise AnsibleConnectionFailure(errored_response) + self._last_response = data + self._command_response += self._sanitize( + resp, command, strip_prompt + ) + command_prompt_matched = True + + def receive( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + """ + Handles receiving of output from command + """ + self._matched_prompt = None + self._matched_cmd_prompt = None + self._matched_prompt_window = 0 + self._window_count = 0 + + # set terminal regex values for command prompt and errors in response + self._terminal_stderr_re = self._get_terminal_std_re( + "terminal_stderr_re" + ) + self._terminal_stdout_re = self._get_terminal_std_re( + "terminal_stdout_re" + ) + + self._command_timeout = self.get_option("persistent_command_timeout") + self._validate_timeout_value( + self._command_timeout, "persistent_command_timeout" + ) + + self._buffer_read_timeout = self.get_option( + "persistent_buffer_read_timeout" + ) + self._validate_timeout_value( + self._buffer_read_timeout, "persistent_buffer_read_timeout" + ) + + self._log_messages("command: %s" % command) + if self.ssh_type == "libssh": + response = self.receive_libssh( + command, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + elif self.ssh_type == "paramiko": + response = self.receive_paramiko( + command, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + + return response + + @ensure_connect + def send( + self, + command, + prompt=None, + answer=None, + newline=True, + sendonly=False, + prompt_retry_check=False, + check_all=False, + strip_prompt=True, + ): + """ + Sends the command to the device in the opened shell + """ + # try cache first + if (not prompt) and (self._single_user_mode): + out = self.get_cache().lookup(command) + if out: + self.queue_message( + "vvvv", "cache hit for command: %s" % command + ) + return out + + if check_all: + prompt_len = len(to_list(prompt)) + answer_len = len(to_list(answer)) + if prompt_len != answer_len: + raise AnsibleConnectionFailure( + "Number of prompts (%s) is not same as that of answers (%s)" + % (prompt_len, answer_len) + ) + try: + cmd = b"%s\r" % command + self._history.append(cmd) + self._ssh_shell.sendall(cmd) + self._log_messages("send command: %s" % cmd) + if sendonly: + return + response = self.receive( + command, + prompt, + answer, + newline, + prompt_retry_check, + check_all, + strip_prompt, + ) + response = to_text(response, errors="surrogate_then_replace") + + if (not prompt) and (self._single_user_mode): + if self._needs_cache_invalidation(command): + # invalidate the existing cache + if self.get_cache().keys(): + self.queue_message( + "vvvv", "invalidating existing cache" + ) + self.get_cache().invalidate() + else: + # populate cache + self.queue_message( + "vvvv", "populating cache for command: %s" % command + ) + self.get_cache().populate(command, response) + + return response + except (socket.timeout, AttributeError): + self.queue_message("error", traceback.format_exc()) + raise AnsibleConnectionFailure( + "timeout value %s seconds reached while trying to send command: %s" + % (self._ssh_shell.gettimeout(), command.strip()) + ) + + def _handle_buffer_read_timeout(self, signum, frame): + self.queue_message( + "vvvv", + "Response received, triggered 'persistent_buffer_read_timeout' timer of %s seconds" + % self.get_option("persistent_buffer_read_timeout"), + ) + raise AnsibleCmdRespRecv() + + def _handle_command_timeout(self, signum, frame): + msg = ( + "command timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and Troubleshooting Guide." + % self.get_option("persistent_command_timeout") + ) + self.queue_message("log", msg) + raise AnsibleConnectionFailure(msg) + + def _strip(self, data): + """ + Removes ANSI codes from device response + """ + for regex in self._terminal.ansi_re: + data = regex.sub(b"", data) + return data + + def _handle_prompt( + self, + resp, + prompts, + answer, + newline, + prompt_retry_check=False, + check_all=False, + ): + """ + Matches the command prompt and responds + + :arg resp: Byte string containing the raw response from the remote + :arg prompts: Sequence of byte strings that we consider prompts for input + :arg answer: Sequence of Byte string to send back to the remote if we find a prompt. + A carriage return is automatically appended to this string. + :param prompt_retry_check: Bool value for trying to detect more prompts + :param check_all: Bool value to indicate if all the values in prompt sequence should be matched or any one of + given prompt. + :returns: True if a prompt was found in ``resp``. If check_all is True + will True only after all the prompt in the prompts list are matched. False otherwise. + """ + single_prompt = False + if not isinstance(prompts, list): + prompts = [prompts] + single_prompt = True + if not isinstance(answer, list): + answer = [answer] + try: + prompts_regex = [re.compile(to_bytes(r), re.I) for r in prompts] + except re.error as exc: + raise ConnectionError( + "Failed to compile one or more terminal prompt regexes: %s.\n" + "Prompts provided: %s" % (to_text(exc), prompts) + ) + for index, regex in enumerate(prompts_regex): + match = regex.search(resp) + if match: + self._matched_cmd_prompt = match.group() + self._log_messages( + "matched command prompt: %s" % self._matched_cmd_prompt + ) + + # if prompt_retry_check is enabled to check if same prompt is + # repeated don't send answer again. + if not prompt_retry_check: + prompt_answer = to_bytes( + answer[index] if len(answer) > index else answer[0] + ) + if newline: + prompt_answer += b"\r" + self._ssh_shell.sendall(prompt_answer) + self._log_messages( + "matched command prompt answer: %s" % prompt_answer + ) + if check_all and prompts and not single_prompt: + prompts.pop(0) + answer.pop(0) + return False + return True + return False + + def _sanitize(self, resp, command=None, strip_prompt=True): + """ + Removes elements from the response before returning to the caller + """ + cleaned = [] + for line in resp.splitlines(): + if command and line.strip() == command.strip(): + continue + + for prompt in self._matched_prompt.strip().splitlines(): + if prompt.strip() in line and strip_prompt: + break + else: + cleaned.append(line) + + return b"\n".join(cleaned).strip() + + def _find_error(self, response): + """Searches the buffered response for a matching error condition""" + for stderr_regex in self._terminal_stderr_re: + if stderr_regex.search(response): + self._log_messages( + "matched error regex (terminal_stderr_re) '%s' from response '%s'" + % (stderr_regex.pattern, response) + ) + + self._log_messages( + "matched stdout regex (terminal_stdout_re) '%s' from error response '%s'" + % (self._matched_pattern, response) + ) + return True + + return False + + def _find_prompt(self, response): + """Searches the buffered response for a matching command prompt""" + for stdout_regex in self._terminal_stdout_re: + match = stdout_regex.search(response) + if match: + self._matched_pattern = stdout_regex.pattern + self._matched_prompt = match.group() + self._log_messages( + "matched cli prompt '%s' with regex '%s' from response '%s'" + % (self._matched_prompt, self._matched_pattern, response) + ) + return True + + return False + + def _validate_timeout_value(self, timeout, timer_name): + if timeout < 0: + raise AnsibleConnectionFailure( + "'%s' timer value '%s' is invalid, value should be greater than or equal to zero." + % (timer_name, timeout) + ) + + def transport_test(self, connect_timeout): + """This method enables wait_for_connection to work. + + As it is used by wait_for_connection, it is called by that module's action plugin, + which is on the controller process, which means that nothing done on this instance + should impact the actual persistent connection... this check is for informational + purposes only and should be properly cleaned up. + """ + + # Force a fresh connect if for some reason we have connected before. + self.close() + self._connect() + self.close() + + def _get_terminal_std_re(self, option): + terminal_std_option = self.get_option(option) + terminal_std_re = [] + + if terminal_std_option: + for item in terminal_std_option: + if "pattern" not in item: + raise AnsibleConnectionFailure( + "'pattern' is a required key for option '%s'," + " received option value is %s" % (option, item) + ) + pattern = rb"%s" % to_bytes(item["pattern"]) + flag = item.get("flags", 0) + if flag: + flag = getattr(re, flag.split(".")[1]) + terminal_std_re.append(re.compile(pattern, flag)) + else: + # To maintain backward compatibility + terminal_std_re = getattr(self._terminal, option) + + return terminal_std_re + + def copy_file( + self, source=None, destination=None, proto="scp", timeout=30 + ): + """Copies file over scp/sftp to remote device + + :param source: Source file path + :param destination: Destination file path on remote device + :param proto: Protocol to be used for file transfer, + supported protocol: scp and sftp + :param timeout: Specifies the wait time to receive response from + remote host before triggering timeout exception + :return: None + """ + ssh = self.ssh_type_conn._connect_uncached() + if self.ssh_type == "libssh": + self.ssh_type_conn.put_file(source, destination, proto=proto) + elif self.ssh_type == "paramiko": + if proto == "scp": + if not HAS_SCP: + raise AnsibleError(missing_required_lib("scp")) + with SCPClient( + ssh.get_transport(), socket_timeout=timeout + ) as scp: + scp.put(source, destination) + elif proto == "sftp": + with ssh.open_sftp() as sftp: + sftp.put(source, destination) + else: + raise AnsibleError( + "Do not know how to do transfer file over protocol %s" + % proto + ) + else: + raise AnsibleError( + "Do not know how to do SCP with ssh_type %s" % self.ssh_type + ) + + def get_file(self, source=None, destination=None, proto="scp", timeout=30): + """Fetch file over scp/sftp from remote device + :param source: Source file path + :param destination: Destination file path + :param proto: Protocol to be used for file transfer, + supported protocol: scp and sftp + :param timeout: Specifies the wait time to receive response from + remote host before triggering timeout exception + :return: None + """ + """Fetch file over scp/sftp from remote device""" + ssh = self.ssh_type_conn._connect_uncached() + if self.ssh_type == "libssh": + self.ssh_type_conn.fetch_file(source, destination, proto=proto) + elif self.ssh_type == "paramiko": + if proto == "scp": + if not HAS_SCP: + raise AnsibleError(missing_required_lib("scp")) + try: + with SCPClient( + ssh.get_transport(), socket_timeout=timeout + ) as scp: + scp.get(source, destination) + except EOFError: + # This appears to be benign. + pass + elif proto == "sftp": + with ssh.open_sftp() as sftp: + sftp.get(source, destination) + else: + raise AnsibleError( + "Do not know how to do transfer file over protocol %s" + % proto + ) + else: + raise AnsibleError( + "Do not know how to do SCP with ssh_type %s" % self.ssh_type + ) + + def get_cache(self): + if not self._cache: + # TO-DO: support jsonfile or other modes of caching with + # a configurable option + self._cache = cache_loader.get("ansible.netcommon.memory") + return self._cache + + def _is_in_config_mode(self): + """ + Check if the target device is in config mode by comparing + the current prompt with the platform's `terminal_config_prompt`. + Returns False if `terminal_config_prompt` is not defined. + + :returns: A boolean indicating if the device is in config mode or not. + """ + cfg_mode = False + cur_prompt = to_text( + self.get_prompt(), errors="surrogate_then_replace" + ).strip() + cfg_prompt = getattr(self._terminal, "terminal_config_prompt", None) + if cfg_prompt and cfg_prompt.match(cur_prompt): + cfg_mode = True + return cfg_mode + + def _needs_cache_invalidation(self, command): + """ + This method determines if it is necessary to invalidate + the existing cache based on whether the device has entered + configuration mode or if the last command sent to the device + is potentially capable of making configuration changes. + + :param command: The last command sent to the target device. + :returns: A boolean indicating if cache invalidation is required or not. + """ + invalidate = False + cfg_cmds = [] + try: + # AnsiblePlugin base class in Ansible 2.9 does not have has_option() method. + # TO-DO: use has_option() when we drop 2.9 support. + cfg_cmds = self.cliconf.get_option("config_commands") + except AttributeError: + cfg_cmds = [] + if (self._is_in_config_mode()) or (to_text(command) in cfg_cmds): + invalidate = True + return invalidate diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py new file mode 100644 index 0000000..b29b487 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py @@ -0,0 +1,97 @@ +# 2017 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Core Team +connection: persistent +short_description: Use a persistent unix socket for connection +description: +- This is a helper plugin to allow making other connections persistent. +options: + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close + default: 10 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout +""" +from ansible.executor.task_executor import start_connection +from ansible.plugins.connection import ConnectionBase +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import Connection as SocketConnection +from ansible.utils.display import Display + +display = Display() + + +class Connection(ConnectionBase): + """ Local based connections """ + + transport = "ansible.netcommon.persistent" + has_pipelining = False + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._task_uuid = to_text(kwargs.get("task_uuid", "")) + + def _connect(self): + self._connected = True + return self + + def exec_command(self, cmd, in_data=None, sudoable=True): + display.vvvv( + "exec_command(), socket_path=%s" % self.socket_path, + host=self._play_context.remote_addr, + ) + connection = SocketConnection(self.socket_path) + out = connection.exec_command(cmd, in_data=in_data, sudoable=sudoable) + return 0, out, "" + + def put_file(self, in_path, out_path): + pass + + def fetch_file(self, in_path, out_path): + pass + + def close(self): + self._connected = False + + def run(self): + """Returns the path of the persistent connection socket. + + Attempts to ensure (within playcontext.timeout seconds) that the + socket path exists. If the path exists (or the timeout has expired), + returns the socket path. + """ + display.vvvv( + "starting connection from persistent connection plugin", + host=self._play_context.remote_addr, + ) + variables = { + "ansible_command_timeout": self.get_option( + "persistent_command_timeout" + ) + } + socket_path = start_connection( + self._play_context, variables, self._task_uuid + ) + display.vvvv( + "local domain socket path is %s" % socket_path, + host=self._play_context.remote_addr, + ) + setattr(self, "_socket_path", socket_path) + return socket_path diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py new file mode 100644 index 0000000..d572c30 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/connection_persistent.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r""" +options: + import_modules: + type: boolean + description: + - Reduce CPU usage and network module execution time + by enabling direct execution. Instead of the module being packaged + and executed by the shell, it will be directly executed by the Ansible + control node using the same python interpreter as the Ansible process. + Note- Incompatible with C(asynchronous mode). + Note- Python 3 and Ansible 2.9.16 or greater required. + Note- With Ansible 2.9.x fully qualified modules names are required in tasks. + default: true + ini: + - section: ansible_network + key: import_modules + env: + - name: ANSIBLE_NETWORK_IMPORT_MODULES + vars: + - name: ansible_network_import_modules + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to + return from the remote device. If this timer is exceeded before the + command returns, the connection plugin will raise an exception and + close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this + option as it could create a security vulnerability by logging sensitive information in log file. + default: False + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py new file mode 100644 index 0000000..8789075 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: + host: + description: + - Specifies the DNS host name or address for connecting to the remote device over + the specified transport. The value of host is used as the destination address + for the transport. + type: str + required: true + port: + description: + - Specifies the port to use when building the connection to the remote device. The + port value will default to port 830. + type: int + default: 830 + username: + description: + - Configures the username to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value is + not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME) + will be used instead. + type: str + password: + description: + - Specifies the password to use to authenticate the connection to the remote device. This + value is used to authenticate the SSH session. If the value is not specified + in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) will + be used instead. + type: str + timeout: + description: + - Specifies the timeout in seconds for communicating with the network device for + either connecting or sending commands. If the timeout is exceeded before the + operation is completed, the module will error. + type: int + default: 10 + ssh_keyfile: + description: + - Specifies the SSH key to use to authenticate the connection to the remote device. This + value is the path to the key used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_SSH_KEYFILE) + will be used instead. + type: path + hostkey_verify: + description: + - If set to C(yes), the ssh host key of the device must match a ssh key present + on the host if set to C(no), the ssh host key of the device is not checked. + type: bool + default: true + look_for_keys: + description: + - Enables looking in the usual locations for the ssh keys (e.g. :file:`~/.ssh/id_*`) + type: bool + default: true +notes: +- For information on using netconf see the :ref:`Platform Options guide using Netconf<netconf_enabled_platform_options>` +- For more information on using Ansible to manage network devices see the :ref:`Ansible + Network Guide <network_guide>` +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py new file mode 100644 index 0000000..ad65f6e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019 Ansible, Inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: {} +notes: +- This module is supported on C(ansible_network_os) network platforms. See the :ref:`Network + Platform Options <platform_options>` for details. +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py new file mode 100644 index 0000000..6ae47a7 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py @@ -0,0 +1,1186 @@ +# (c) 2014, Maciej Delmanowski <drybjed@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from functools import partial +import types + +try: + import netaddr +except ImportError: + # in this case, we'll make the filters return error messages (see bottom) + netaddr = None +else: + + class mac_linux(netaddr.mac_unix): + pass + + mac_linux.word_fmt = "%.2x" + +from ansible import errors + + +# ---- IP address and network query helpers ---- +def _empty_ipaddr_query(v, vtype): + # We don't have any query to process, so just check what type the user + # expects, and return the IP address in a correct format + if v: + if vtype == "address": + return str(v.ip) + elif vtype == "network": + return str(v) + + +def _first_last(v): + if v.size == 2: + first_usable = int(netaddr.IPAddress(v.first)) + last_usable = int(netaddr.IPAddress(v.last)) + return first_usable, last_usable + elif v.size > 1: + first_usable = int(netaddr.IPAddress(v.first + 1)) + last_usable = int(netaddr.IPAddress(v.last - 1)) + return first_usable, last_usable + + +def _6to4_query(v, vtype, value): + if v.version == 4: + + if v.size == 1: + ipconv = str(v.ip) + elif v.size > 1: + if v.ip != v.network: + ipconv = str(v.ip) + else: + ipconv = False + + if ipaddr(ipconv, "public"): + numbers = list(map(int, ipconv.split("."))) + + try: + return "2002:{:02x}{:02x}:{:02x}{:02x}::1/48".format(*numbers) + except Exception: + return False + + elif v.version == 6: + if vtype == "address": + if ipaddr(str(v), "2002::/16"): + return value + elif vtype == "network": + if v.ip != v.network: + if ipaddr(str(v.ip), "2002::/16"): + return value + else: + return False + + +def _ip_query(v): + if v.size == 1: + return str(v.ip) + if v.size > 1: + # /31 networks in netaddr have no broadcast address + if v.ip != v.network or not v.broadcast: + return str(v.ip) + + +def _gateway_query(v): + if v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _address_prefix_query(v): + if v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _bool_ipaddr_query(v): + if v: + return True + + +def _broadcast_query(v): + if v.size > 2: + return str(v.broadcast) + + +def _cidr_query(v): + return str(v) + + +def _cidr_lookup_query(v, iplist, value): + try: + if v in iplist: + return value + except Exception: + return False + + +def _first_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size == 2: + return str(netaddr.IPAddress(int(v.network))) + elif v.size > 1: + return str(netaddr.IPAddress(int(v.network) + 1)) + + +def _host_query(v): + if v.size == 1: + return str(v) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _hostmask_query(v): + return str(v.hostmask) + + +def _int_query(v, vtype): + if vtype == "address": + return int(v.ip) + elif vtype == "network": + return str(int(v.ip)) + "/" + str(int(v.prefixlen)) + + +def _ip_prefix_query(v): + if v.size == 2: + return str(v.ip) + "/" + str(v.prefixlen) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _ip_netmask_query(v): + if v.size == 2: + return str(v.ip) + " " + str(v.netmask) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + " " + str(v.netmask) + + +""" +def _ip_wildcard_query(v): + if v.size == 2: + return str(v.ip) + ' ' + str(v.hostmask) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + ' ' + str(v.hostmask) +""" + + +def _ipv4_query(v, value): + if v.version == 6: + try: + return str(v.ipv4()) + except Exception: + return False + else: + return value + + +def _ipv6_query(v, value): + if v.version == 4: + return str(v.ipv6()) + else: + return value + + +def _last_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + return str(netaddr.IPAddress(last_usable)) + + +def _link_local_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if v.version == 4: + if ipaddr(str(v_ip), "169.254.0.0/24"): + return value + + elif v.version == 6: + if ipaddr(str(v_ip), "fe80::/10"): + return value + + +def _loopback_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if v_ip.is_loopback(): + return value + + +def _multicast_query(v, value): + if v.is_multicast(): + return value + + +def _net_query(v): + if v.size > 1: + if v.ip == v.network: + return str(v.network) + "/" + str(v.prefixlen) + + +def _netmask_query(v): + return str(v.netmask) + + +def _network_query(v): + """Return the network of a given IP or subnet""" + return str(v.network) + + +def _network_id_query(v): + """Return the network of a given IP or subnet""" + return str(v.network) + + +def _network_netmask_query(v): + return str(v.network) + " " + str(v.netmask) + + +def _network_wildcard_query(v): + return str(v.network) + " " + str(v.hostmask) + + +def _next_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + next_ip = int(netaddr.IPAddress(int(v.ip) + 1)) + if next_ip >= first_usable and next_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) + 1)) + + +def _peer_query(v, vtype): + if vtype == "address": + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size == 2: + return str(netaddr.IPAddress(int(v.ip) ^ 1)) + if v.size == 4: + if int(v.ip) % 4 == 0: + raise errors.AnsibleFilterError( + "Network address of /30 has no peer" + ) + if int(v.ip) % 4 == 3: + raise errors.AnsibleFilterError( + "Broadcast address of /30 has no peer" + ) + return str(netaddr.IPAddress(int(v.ip) ^ 3)) + raise errors.AnsibleFilterError("Not a point-to-point network") + + +def _prefix_query(v): + return int(v.prefixlen) + + +def _previous_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + previous_ip = int(netaddr.IPAddress(int(v.ip) - 1)) + if previous_ip >= first_usable and previous_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) - 1)) + + +def _private_query(v, value): + if v.is_private(): + return value + + +def _public_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if ( + v_ip.is_unicast() + and not v_ip.is_private() + and not v_ip.is_loopback() + and not v_ip.is_netmask() + and not v_ip.is_hostmask() + ): + return value + + +def _range_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + first_usable = str(netaddr.IPAddress(first_usable)) + last_usable = str(netaddr.IPAddress(last_usable)) + return "{0}-{1}".format(first_usable, last_usable) + + +def _revdns_query(v): + v_ip = netaddr.IPAddress(str(v.ip)) + return v_ip.reverse_dns + + +def _size_query(v): + return v.size + + +def _size_usable_query(v): + if v.size == 1: + return 0 + elif v.size == 2: + return 2 + return v.size - 2 + + +def _subnet_query(v): + return str(v.cidr) + + +def _type_query(v): + if v.size == 1: + return "address" + if v.size > 1: + if v.ip != v.network: + return "address" + else: + return "network" + + +def _unicast_query(v, value): + if v.is_unicast(): + return value + + +def _version_query(v): + return v.version + + +def _wrap_query(v, vtype, value): + if v.version == 6: + if vtype == "address": + return "[" + str(v.ip) + "]" + elif vtype == "network": + return "[" + str(v.ip) + "]/" + str(v.prefixlen) + else: + return value + + +# ---- HWaddr query helpers ---- +def _bare_query(v): + v.dialect = netaddr.mac_bare + return str(v) + + +def _bool_hwaddr_query(v): + if v: + return True + + +def _int_hwaddr_query(v): + return int(v) + + +def _cisco_query(v): + v.dialect = netaddr.mac_cisco + return str(v) + + +def _empty_hwaddr_query(v, value): + if v: + return value + + +def _linux_query(v): + v.dialect = mac_linux + return str(v) + + +def _postgresql_query(v): + v.dialect = netaddr.mac_pgsql + return str(v) + + +def _unix_query(v): + v.dialect = netaddr.mac_unix + return str(v) + + +def _win_query(v): + v.dialect = netaddr.mac_eui48 + return str(v) + + +# ---- IP address and network filters ---- + +# Returns a minified list of subnets or a single subnet that spans all of +# the inputs. +def cidr_merge(value, action="merge"): + if not hasattr(value, "__iter__"): + raise errors.AnsibleFilterError( + "cidr_merge: expected iterable, got " + repr(value) + ) + + if action == "merge": + try: + return [str(ip) for ip in netaddr.cidr_merge(value)] + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + + elif action == "span": + # spanning_cidr needs at least two values + if len(value) == 0: + return None + elif len(value) == 1: + try: + return str(netaddr.IPNetwork(value[0])) + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + else: + try: + return str(netaddr.spanning_cidr(value)) + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + + else: + raise errors.AnsibleFilterError( + "cidr_merge: invalid action '%s'" % action + ) + + +def ipaddr(value, query="", version=False, alias="ipaddr"): + """ Check if string is an IP address or network and filter it """ + + query_func_extra_args = { + "": ("vtype",), + "6to4": ("vtype", "value"), + "cidr_lookup": ("iplist", "value"), + "first_usable": ("vtype",), + "int": ("vtype",), + "ipv4": ("value",), + "ipv6": ("value",), + "last_usable": ("vtype",), + "link-local": ("value",), + "loopback": ("value",), + "lo": ("value",), + "multicast": ("value",), + "next_usable": ("vtype",), + "peer": ("vtype",), + "previous_usable": ("vtype",), + "private": ("value",), + "public": ("value",), + "unicast": ("value",), + "range_usable": ("vtype",), + "wrap": ("vtype", "value"), + } + + query_func_map = { + "": _empty_ipaddr_query, + "6to4": _6to4_query, + "address": _ip_query, + "address/prefix": _address_prefix_query, # deprecate + "bool": _bool_ipaddr_query, + "broadcast": _broadcast_query, + "cidr": _cidr_query, + "cidr_lookup": _cidr_lookup_query, + "first_usable": _first_usable_query, + "gateway": _gateway_query, # deprecate + "gw": _gateway_query, # deprecate + "host": _host_query, + "host/prefix": _address_prefix_query, # deprecate + "hostmask": _hostmask_query, + "hostnet": _gateway_query, # deprecate + "int": _int_query, + "ip": _ip_query, + "ip/prefix": _ip_prefix_query, + "ip_netmask": _ip_netmask_query, + # 'ip_wildcard': _ip_wildcard_query, built then could not think of use case + "ipv4": _ipv4_query, + "ipv6": _ipv6_query, + "last_usable": _last_usable_query, + "link-local": _link_local_query, + "lo": _loopback_query, + "loopback": _loopback_query, + "multicast": _multicast_query, + "net": _net_query, + "next_usable": _next_usable_query, + "netmask": _netmask_query, + "network": _network_query, + "network_id": _network_id_query, + "network/prefix": _subnet_query, + "network_netmask": _network_netmask_query, + "network_wildcard": _network_wildcard_query, + "peer": _peer_query, + "prefix": _prefix_query, + "previous_usable": _previous_usable_query, + "private": _private_query, + "public": _public_query, + "range_usable": _range_usable_query, + "revdns": _revdns_query, + "router": _gateway_query, # deprecate + "size": _size_query, + "size_usable": _size_usable_query, + "subnet": _subnet_query, + "type": _type_query, + "unicast": _unicast_query, + "v4": _ipv4_query, + "v6": _ipv6_query, + "version": _version_query, + "wildcard": _hostmask_query, + "wrap": _wrap_query, + } + + vtype = None + + if not value: + return False + + elif value is True: + return False + + # Check if value is a list and parse each element + elif isinstance(value, (list, tuple, types.GeneratorType)): + + _ret = [] + for element in value: + if ipaddr(element, str(query), version): + _ret.append(ipaddr(element, str(query), version)) + + if _ret: + return _ret + else: + return list() + + # Check if value is a number and convert it to an IP address + elif str(value).isdigit(): + + # We don't know what IP version to assume, so let's check IPv4 first, + # then IPv6 + try: + if (not version) or (version and version == 4): + v = netaddr.IPNetwork("0.0.0.0/0") + v.value = int(value) + v.prefixlen = 32 + elif version and version == 6: + v = netaddr.IPNetwork("::/0") + v.value = int(value) + v.prefixlen = 128 + + # IPv4 didn't work the first time, so it definitely has to be IPv6 + except Exception: + try: + v = netaddr.IPNetwork("::/0") + v.value = int(value) + v.prefixlen = 128 + + # The value is too big for IPv6. Are you a nanobot? + except Exception: + return False + + # We got an IP address, let's mark it as such + value = str(v) + vtype = "address" + + # value has not been recognized, check if it's a valid IP string + else: + try: + v = netaddr.IPNetwork(value) + + # value is a valid IP string, check if user specified + # CIDR prefix or just an IP address, this will indicate default + # output format + try: + address, prefix = value.split("/") + vtype = "network" + except Exception: + vtype = "address" + + # value hasn't been recognized, maybe it's a numerical CIDR? + except Exception: + try: + address, prefix = value.split("/") + address.isdigit() + address = int(address) + prefix.isdigit() + prefix = int(prefix) + + # It's not numerical CIDR, give up + except Exception: + return False + + # It is something, so let's try and build a CIDR from the parts + try: + v = netaddr.IPNetwork("0.0.0.0/0") + v.value = address + v.prefixlen = prefix + + # It's not a valid IPv4 CIDR + except Exception: + try: + v = netaddr.IPNetwork("::/0") + v.value = address + v.prefixlen = prefix + + # It's not a valid IPv6 CIDR. Give up. + except Exception: + return False + + # We have a valid CIDR, so let's write it in correct format + value = str(v) + vtype = "network" + + # We have a query string but it's not in the known query types. Check if + # that string is a valid subnet, if so, we can check later if given IP + # address/network is inside that specific subnet + try: + # ?? 6to4 and link-local were True here before. Should they still? + if ( + query + and (query not in query_func_map or query == "cidr_lookup") + and not str(query).isdigit() + and ipaddr(query, "network") + ): + iplist = netaddr.IPSet([netaddr.IPNetwork(query)]) + query = "cidr_lookup" + except Exception: + pass + + # This code checks if value maches the IP version the user wants, ie. if + # it's any version ("ipaddr()"), IPv4 ("ipv4()") or IPv6 ("ipv6()") + # If version does not match, return False + if version and v.version != version: + return False + + extras = [] + for arg in query_func_extra_args.get(query, tuple()): + extras.append(locals()[arg]) + try: + return query_func_map[query](v, *extras) + except KeyError: + try: + float(query) + if v.size == 1: + if vtype == "address": + return str(v.ip) + elif vtype == "network": + return str(v) + + elif v.size > 1: + try: + return str(v[query]) + "/" + str(v.prefixlen) + except Exception: + return False + + else: + return value + + except Exception: + raise errors.AnsibleFilterError( + alias + ": unknown filter type: %s" % query + ) + + return False + + +def ipmath(value, amount): + try: + if "/" in value: + ip = netaddr.IPNetwork(value).ip + else: + ip = netaddr.IPAddress(value) + except (netaddr.AddrFormatError, ValueError): + msg = "You must pass a valid IP address; {0} is invalid".format(value) + raise errors.AnsibleFilterError(msg) + + if not isinstance(amount, int): + msg = ( + "You must pass an integer for arithmetic; " + "{0} is not a valid integer" + ).format(amount) + raise errors.AnsibleFilterError(msg) + + return str(ip + amount) + + +def ipwrap(value, query=""): + try: + if isinstance(value, (list, tuple, types.GeneratorType)): + _ret = [] + for element in value: + if ipaddr(element, query, version=False, alias="ipwrap"): + _ret.append(ipaddr(element, "wrap")) + else: + _ret.append(element) + + return _ret + else: + _ret = ipaddr(value, query, version=False, alias="ipwrap") + if _ret: + return ipaddr(_ret, "wrap") + else: + return value + + except Exception: + return value + + +def ipv4(value, query=""): + return ipaddr(value, query, version=4, alias="ipv4") + + +def ipv6(value, query=""): + return ipaddr(value, query, version=6, alias="ipv6") + + +# Split given subnet into smaller subnets or find out the biggest subnet of +# a given IP address with given CIDR prefix +# Usage: +# +# - address or address/prefix | ipsubnet +# returns CIDR subnet of a given input +# +# - address/prefix | ipsubnet(cidr) +# returns number of possible subnets for given CIDR prefix +# +# - address/prefix | ipsubnet(cidr, index) +# returns new subnet with given CIDR prefix +# +# - address | ipsubnet(cidr) +# returns biggest subnet with given CIDR prefix that address belongs to +# +# - address | ipsubnet(cidr, index) +# returns next indexed subnet which contains given address +# +# - address/prefix | ipsubnet(subnet/prefix) +# return the index of the subnet in the subnet +def ipsubnet(value, query="", index="x"): + """ Manipulate IPv4/IPv6 subnets """ + + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + value = netaddr.IPNetwork(v) + except Exception: + return False + query_string = str(query) + if not query: + return str(value) + + elif query_string.isdigit(): + vsize = ipaddr(v, "size") + query = int(query) + + try: + float(index) + index = int(index) + + if vsize > 1: + try: + return str(list(value.subnet(query))[index]) + except Exception: + return False + + elif vsize == 1: + try: + return str(value.supernet(query)[index]) + except Exception: + return False + + except Exception: + if vsize > 1: + try: + return str(len(list(value.subnet(query)))) + except Exception: + return False + + elif vsize == 1: + try: + return str(value.supernet(query)[0]) + except Exception: + return False + + elif query_string: + vtype = ipaddr(query, "type") + if vtype == "address": + v = ipaddr(query, "cidr") + elif vtype == "network": + v = ipaddr(query, "subnet") + else: + msg = "You must pass a valid subnet or IP address; {0} is invalid".format( + query_string + ) + raise errors.AnsibleFilterError(msg) + query = netaddr.IPNetwork(v) + for i, subnet in enumerate(query.subnet(value.prefixlen), 1): + if subnet == value: + return str(i) + msg = "{0} is not in the subnet {1}".format(value.cidr, query.cidr) + raise errors.AnsibleFilterError(msg) + return False + + +# Returns the nth host within a network described by value. +# Usage: +# +# - address or address/prefix | nthhost(nth) +# returns the nth host within the given network +def nthhost(value, query=""): + """ Get the nth host within a given network """ + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + value = netaddr.IPNetwork(v) + except Exception: + return False + + if not query: + return False + + try: + nth = int(query) + if value.size > nth: + return value[nth] + + except ValueError: + return False + + return False + + +# Returns the next nth usable ip within a network described by value. +def next_nth_usable(value, offset): + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + v = netaddr.IPNetwork(v) + except Exception: + return False + + if type(offset) != int: + raise errors.AnsibleFilterError("Must pass in an integer") + if v.size > 1: + first_usable, last_usable = _first_last(v) + nth_ip = int(netaddr.IPAddress(int(v.ip) + offset)) + if nth_ip >= first_usable and nth_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) + offset)) + + +# Returns the previous nth usable ip within a network described by value. +def previous_nth_usable(value, offset): + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + v = netaddr.IPNetwork(v) + except Exception: + return False + + if type(offset) != int: + raise errors.AnsibleFilterError("Must pass in an integer") + if v.size > 1: + first_usable, last_usable = _first_last(v) + nth_ip = int(netaddr.IPAddress(int(v.ip) - offset)) + if nth_ip >= first_usable and nth_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) - offset)) + + +def _range_checker(ip_check, first, last): + """ + Tests whether an ip address is within the bounds of the first and last address. + + :param ip_check: The ip to test if it is within first and last. + :param first: The first IP in the range to test against. + :param last: The last IP in the range to test against. + + :return: bool + """ + if ip_check >= first and ip_check <= last: + return True + else: + return False + + +def _address_normalizer(value): + """ + Used to validate an address or network type and return it in a consistent format. + This is being used for future use cases not currently available such as an address range. + + :param value: The string representation of an address or network. + + :return: The address or network in the normalized form. + """ + try: + vtype = ipaddr(value, "type") + if vtype == "address" or vtype == "network": + v = ipaddr(value, "subnet") + except Exception: + return False + + return v + + +def network_in_usable(value, test): + """ + Checks whether 'test' is a useable address or addresses in 'value' + + :param: value: The string representation of an address or network to test against. + :param test: The string representation of an address or network to validate if it is within the range of 'value'. + + :return: bool + """ + # normalize value and test variables into an ipaddr + v = _address_normalizer(value) + w = _address_normalizer(test) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + v_first = ipaddr(ipaddr(v, "first_usable") or ipaddr(v, "address"), "int") + v_last = ipaddr(ipaddr(v, "last_usable") or ipaddr(v, "address"), "int") + w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int") + w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int") + + if _range_checker(w_first, v_first, v_last) and _range_checker( + w_last, v_first, v_last + ): + return True + else: + return False + + +def network_in_network(value, test): + """ + Checks whether the 'test' address or addresses are in 'value', including broadcast and network + + :param: value: The network address or range to test against. + :param test: The address or network to validate if it is within the range of 'value'. + + :return: bool + """ + # normalize value and test variables into an ipaddr + v = _address_normalizer(value) + w = _address_normalizer(test) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + v_first = ipaddr(ipaddr(v, "network") or ipaddr(v, "address"), "int") + v_last = ipaddr(ipaddr(v, "broadcast") or ipaddr(v, "address"), "int") + w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int") + w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int") + + if _range_checker(w_first, v_first, v_last) and _range_checker( + w_last, v_first, v_last + ): + return True + else: + return False + + +def reduce_on_network(value, network): + """ + Reduces a list of addresses to only the addresses that match a given network. + + :param: value: The list of addresses to filter on. + :param: network: The network to validate against. + + :return: The reduced list of addresses. + """ + # normalize network variable into an ipaddr + n = _address_normalizer(network) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + n_first = ipaddr(ipaddr(n, "network") or ipaddr(n, "address"), "int") + n_last = ipaddr(ipaddr(n, "broadcast") or ipaddr(n, "address"), "int") + + # create an empty list to fill and return + r = [] + + for address in value: + # normalize address variables into an ipaddr + a = _address_normalizer(address) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + a_first = ipaddr(ipaddr(a, "network") or ipaddr(a, "address"), "int") + a_last = ipaddr(ipaddr(a, "broadcast") or ipaddr(a, "address"), "int") + + if _range_checker(a_first, n_first, n_last) and _range_checker( + a_last, n_first, n_last + ): + r.append(address) + + return r + + +# Returns the SLAAC address within a network for a given HW/MAC address. +# Usage: +# +# - prefix | slaac(mac) +def slaac(value, query=""): + """ Get the SLAAC address within given network """ + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + if ipaddr(value, "version") != 6: + return False + + value = netaddr.IPNetwork(v) + except Exception: + return False + + if not query: + return False + + try: + mac = hwaddr(query, alias="slaac") + + eui = netaddr.EUI(mac) + except Exception: + return False + + return eui.ipv6(value.network) + + +# ---- HWaddr / MAC address filters ---- +def hwaddr(value, query="", alias="hwaddr"): + """ Check if string is a HW/MAC address and filter it """ + + query_func_extra_args = {"": ("value",)} + + query_func_map = { + "": _empty_hwaddr_query, + "bare": _bare_query, + "bool": _bool_hwaddr_query, + "int": _int_hwaddr_query, + "cisco": _cisco_query, + "eui48": _win_query, + "linux": _linux_query, + "pgsql": _postgresql_query, + "postgresql": _postgresql_query, + "psql": _postgresql_query, + "unix": _unix_query, + "win": _win_query, + } + + try: + v = netaddr.EUI(value) + except Exception: + if query and query != "bool": + raise errors.AnsibleFilterError( + alias + ": not a hardware address: %s" % value + ) + + extras = [] + for arg in query_func_extra_args.get(query, tuple()): + extras.append(locals()[arg]) + try: + return query_func_map[query](v, *extras) + except KeyError: + raise errors.AnsibleFilterError( + alias + ": unknown filter type: %s" % query + ) + + return False + + +def macaddr(value, query=""): + return hwaddr(value, query, alias="macaddr") + + +def _need_netaddr(f_name, *args, **kwargs): + raise errors.AnsibleFilterError( + "The %s filter requires python's netaddr be " + "installed on the ansible controller" % f_name + ) + + +def ip4_hex(arg, delimiter=""): + """ Convert an IPv4 address to Hexadecimal notation """ + numbers = list(map(int, arg.split("."))) + return "{0:02x}{sep}{1:02x}{sep}{2:02x}{sep}{3:02x}".format( + *numbers, sep=delimiter + ) + + +# ---- Ansible filters ---- +class FilterModule(object): + """ IP address and network manipulation filters """ + + filter_map = { + # IP addresses and networks + "cidr_merge": cidr_merge, + "ipaddr": ipaddr, + "ipmath": ipmath, + "ipwrap": ipwrap, + "ip4_hex": ip4_hex, + "ipv4": ipv4, + "ipv6": ipv6, + "ipsubnet": ipsubnet, + "next_nth_usable": next_nth_usable, + "network_in_network": network_in_network, + "network_in_usable": network_in_usable, + "reduce_on_network": reduce_on_network, + "nthhost": nthhost, + "previous_nth_usable": previous_nth_usable, + "slaac": slaac, + # MAC / HW addresses + "hwaddr": hwaddr, + "macaddr": macaddr, + } + + def filters(self): + if netaddr: + return self.filter_map + else: + # Need to install python's netaddr for these filters to work + return dict( + (f, partial(_need_netaddr, f)) for f in self.filter_map + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py new file mode 100644 index 0000000..72d6c86 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py @@ -0,0 +1,531 @@ +# +# {c) 2017 Red Hat, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import re +import os +import traceback +import string + +from collections.abc import Mapping +from xml.etree.ElementTree import fromstring + +from ansible.module_utils._text import to_native, to_text +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + Template, +) +from ansible.module_utils.six import iteritems, string_types +from ansible.errors import AnsibleError, AnsibleFilterError +from ansible.utils.display import Display +from ansible.utils.encrypt import passlib_or_crypt, random_password + +try: + import yaml + + HAS_YAML = True +except ImportError: + HAS_YAML = False + +try: + import textfsm + + HAS_TEXTFSM = True +except ImportError: + HAS_TEXTFSM = False + +display = Display() + + +def re_matchall(regex, value): + objects = list() + for match in re.findall(regex.pattern, value, re.M): + obj = {} + if regex.groupindex: + for name, index in iteritems(regex.groupindex): + if len(regex.groupindex) == 1: + obj[name] = match + else: + obj[name] = match[index - 1] + objects.append(obj) + return objects + + +def re_search(regex, value): + obj = {} + match = regex.search(value, re.M) + if match: + items = list(match.groups()) + if regex.groupindex: + for name, index in iteritems(regex.groupindex): + obj[name] = items[index - 1] + return obj + + +def parse_cli(output, tmpl): + if not isinstance(output, string_types): + raise AnsibleError( + "parse_cli input should be a string, but was given a input of %s" + % (type(output)) + ) + + if not os.path.exists(tmpl): + raise AnsibleError("unable to locate parse_cli template: %s" % tmpl) + + try: + template = Template() + except ImportError as exc: + raise AnsibleError(to_native(exc)) + + with open(tmpl) as tmpl_fh: + tmpl_content = tmpl_fh.read() + + spec = yaml.safe_load(tmpl_content) + obj = {} + + for name, attrs in iteritems(spec["keys"]): + value = attrs["value"] + + try: + variables = spec.get("vars", {}) + value = template(value, variables) + except Exception: + pass + + if "start_block" in attrs and "end_block" in attrs: + start_block = re.compile(attrs["start_block"]) + end_block = re.compile(attrs["end_block"]) + + blocks = list() + lines = None + block_started = False + + for line in output.split("\n"): + match_start = start_block.match(line) + match_end = end_block.match(line) + + if match_start: + lines = list() + lines.append(line) + block_started = True + + elif match_end: + if lines: + lines.append(line) + blocks.append("\n".join(lines)) + block_started = False + + elif block_started: + if lines: + lines.append(line) + + regex_items = [re.compile(r) for r in attrs["items"]] + objects = list() + + for block in blocks: + if isinstance(value, Mapping) and "key" not in value: + items = list() + for regex in regex_items: + match = regex.search(block) + if match: + item_values = match.groupdict() + item_values["match"] = list(match.groups()) + items.append(item_values) + else: + items.append(None) + + obj = {} + for k, v in iteritems(value): + try: + obj[k] = template( + v, {"item": items}, fail_on_undefined=False + ) + except Exception: + obj[k] = None + objects.append(obj) + + elif isinstance(value, Mapping): + items = list() + for regex in regex_items: + match = regex.search(block) + if match: + item_values = match.groupdict() + item_values["match"] = list(match.groups()) + items.append(item_values) + else: + items.append(None) + + key = template(value["key"], {"item": items}) + values = dict( + [ + (k, template(v, {"item": items})) + for k, v in iteritems(value["values"]) + ] + ) + objects.append({key: values}) + + return objects + + elif "items" in attrs: + regexp = re.compile(attrs["items"]) + when = attrs.get("when") + conditional = ( + "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when + ) + + if isinstance(value, Mapping) and "key" not in value: + values = list() + + for item in re_matchall(regexp, output): + entry = {} + + for item_key, item_value in iteritems(value): + entry[item_key] = template(item_value, {"item": item}) + + if when: + if template(conditional, {"item": entry}): + values.append(entry) + else: + values.append(entry) + + obj[name] = values + + elif isinstance(value, Mapping): + values = dict() + + for item in re_matchall(regexp, output): + entry = {} + + for item_key, item_value in iteritems(value["values"]): + entry[item_key] = template(item_value, {"item": item}) + + key = template(value["key"], {"item": item}) + + if when: + if template( + conditional, {"item": {"key": key, "value": entry}} + ): + values[key] = entry + else: + values[key] = entry + + obj[name] = values + + else: + item = re_search(regexp, output) + obj[name] = template(value, {"item": item}) + + else: + obj[name] = value + + return obj + + +def parse_cli_textfsm(value, template): + if not HAS_TEXTFSM: + raise AnsibleError( + "parse_cli_textfsm filter requires TextFSM library to be installed" + ) + + if not isinstance(value, string_types): + raise AnsibleError( + "parse_cli_textfsm input should be a string, but was given a input of %s" + % (type(value)) + ) + + if not os.path.exists(template): + raise AnsibleError( + "unable to locate parse_cli_textfsm template: %s" % template + ) + + try: + template = open(template) + except IOError as exc: + raise AnsibleError(to_native(exc)) + + re_table = textfsm.TextFSM(template) + fsm_results = re_table.ParseText(value) + + results = list() + for item in fsm_results: + results.append(dict(zip(re_table.header, item))) + + return results + + +def _extract_param(template, root, attrs, value): + + key = None + when = attrs.get("when") + conditional = "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when + param_to_xpath_map = attrs["items"] + + if isinstance(value, Mapping): + key = value.get("key", None) + if key: + value = value["values"] + + entries = dict() if key else list() + + for element in root.findall(attrs["top"]): + entry = dict() + item_dict = dict() + for param, param_xpath in iteritems(param_to_xpath_map): + fields = None + try: + fields = element.findall(param_xpath) + except Exception: + display.warning( + "Failed to evaluate value of '%s' with XPath '%s'.\nUnexpected error: %s." + % (param, param_xpath, traceback.format_exc()) + ) + + tags = param_xpath.split("/") + + # check if xpath ends with attribute. + # If yes set attribute key/value dict to param value in case attribute matches + # else if it is a normal xpath assign matched element text value. + if len(tags) and tags[-1].endswith("]"): + if fields: + if len(fields) > 1: + item_dict[param] = [field.attrib for field in fields] + else: + item_dict[param] = fields[0].attrib + else: + item_dict[param] = {} + else: + if fields: + if len(fields) > 1: + item_dict[param] = [field.text for field in fields] + else: + item_dict[param] = fields[0].text + else: + item_dict[param] = None + + if isinstance(value, Mapping): + for item_key, item_value in iteritems(value): + entry[item_key] = template(item_value, {"item": item_dict}) + else: + entry = template(value, {"item": item_dict}) + + if key: + expanded_key = template(key, {"item": item_dict}) + if when: + if template( + conditional, + {"item": {"key": expanded_key, "value": entry}}, + ): + entries[expanded_key] = entry + else: + entries[expanded_key] = entry + else: + if when: + if template(conditional, {"item": entry}): + entries.append(entry) + else: + entries.append(entry) + + return entries + + +def parse_xml(output, tmpl): + if not os.path.exists(tmpl): + raise AnsibleError("unable to locate parse_xml template: %s" % tmpl) + + if not isinstance(output, string_types): + raise AnsibleError( + "parse_xml works on string input, but given input of : %s" + % type(output) + ) + + root = fromstring(output) + try: + template = Template() + except ImportError as exc: + raise AnsibleError(to_native(exc)) + + with open(tmpl) as tmpl_fh: + tmpl_content = tmpl_fh.read() + + spec = yaml.safe_load(tmpl_content) + obj = {} + + for name, attrs in iteritems(spec["keys"]): + value = attrs["value"] + + try: + variables = spec.get("vars", {}) + value = template(value, variables) + except Exception: + pass + + if "items" in attrs: + obj[name] = _extract_param(template, root, attrs, value) + else: + obj[name] = value + + return obj + + +def type5_pw(password, salt=None): + if not isinstance(password, string_types): + raise AnsibleFilterError( + "type5_pw password input should be a string, but was given a input of %s" + % (type(password).__name__) + ) + + salt_chars = u"".join( + (to_text(string.ascii_letters), to_text(string.digits), u"./") + ) + if salt is not None and not isinstance(salt, string_types): + raise AnsibleFilterError( + "type5_pw salt input should be a string, but was given a input of %s" + % (type(salt).__name__) + ) + elif not salt: + salt = random_password(length=4, chars=salt_chars) + elif not set(salt) <= set(salt_chars): + raise AnsibleFilterError( + "type5_pw salt used inproper characters, must be one of %s" + % (salt_chars) + ) + + encrypted_password = passlib_or_crypt(password, "md5_crypt", salt=salt) + + return encrypted_password + + +def hash_salt(password): + + split_password = password.split("$") + if len(split_password) != 4: + raise AnsibleFilterError( + "Could not parse salt out password correctly from {0}".format( + password + ) + ) + else: + return split_password[2] + + +def comp_type5( + unencrypted_password, encrypted_password, return_original=False +): + + salt = hash_salt(encrypted_password) + if type5_pw(unencrypted_password, salt) == encrypted_password: + if return_original is True: + return encrypted_password + else: + return True + return False + + +def vlan_parser(vlan_list, first_line_len=48, other_line_len=44): + + """ + Input: Unsorted list of vlan integers + Output: Sorted string list of integers according to IOS-like vlan list rules + + 1. Vlans are listed in ascending order + 2. Runs of 3 or more consecutive vlans are listed with a dash + 3. The first line of the list can be first_line_len characters long + 4. Subsequent list lines can be other_line_len characters + """ + + # Sort and remove duplicates + sorted_list = sorted(set(vlan_list)) + + if sorted_list[0] < 1 or sorted_list[-1] > 4094: + raise AnsibleFilterError("Valid VLAN range is 1-4094") + + parse_list = [] + idx = 0 + while idx < len(sorted_list): + start = idx + end = start + while end < len(sorted_list) - 1: + if sorted_list[end + 1] - sorted_list[end] == 1: + end += 1 + else: + break + + if start == end: + # Single VLAN + parse_list.append(str(sorted_list[idx])) + elif start + 1 == end: + # Run of 2 VLANs + parse_list.append(str(sorted_list[start])) + parse_list.append(str(sorted_list[end])) + else: + # Run of 3 or more VLANs + parse_list.append( + str(sorted_list[start]) + "-" + str(sorted_list[end]) + ) + idx = end + 1 + + line_count = 0 + result = [""] + for vlans in parse_list: + # First line (" switchport trunk allowed vlan ") + if line_count == 0: + if len(result[line_count] + vlans) > first_line_len: + result.append("") + line_count += 1 + result[line_count] += vlans + "," + else: + result[line_count] += vlans + "," + + # Subsequent lines (" switchport trunk allowed vlan add ") + else: + if len(result[line_count] + vlans) > other_line_len: + result.append("") + line_count += 1 + result[line_count] += vlans + "," + else: + result[line_count] += vlans + "," + + # Remove trailing orphan commas + for idx in range(0, len(result)): + result[idx] = result[idx].rstrip(",") + + # Sometimes text wraps to next line, but there are no remaining VLANs + if "" in result: + result.remove("") + + return result + + +class FilterModule(object): + """Filters for working with output from network devices""" + + filter_map = { + "parse_cli": parse_cli, + "parse_cli_textfsm": parse_cli_textfsm, + "parse_xml": parse_xml, + "type5_pw": type5_pw, + "hash_salt": hash_salt, + "comp_type5": comp_type5, + "vlan_parser": vlan_parser, + } + + def filters(self): + return self.filter_map diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py new file mode 100644 index 0000000..8afb3e5 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py @@ -0,0 +1,91 @@ +# Copyright (c) 2018 Cisco and/or its affiliates. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +httpapi: restconf +short_description: HttpApi Plugin for devices supporting Restconf API +description: +- This HttpApi plugin provides methods to connect to Restconf API endpoints. +options: + root_path: + type: str + description: + - Specifies the location of the Restconf root. + default: /restconf + vars: + - name: ansible_httpapi_restconf_root +""" + +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import ConnectionError +from ansible.module_utils.six.moves.urllib.error import HTTPError +from ansible.plugins.httpapi import HttpApiBase + + +CONTENT_TYPE = "application/yang-data+json" + + +class HttpApi(HttpApiBase): + def send_request(self, data, **message_kwargs): + if data: + data = json.dumps(data) + + path = "/".join( + [ + self.get_option("root_path").rstrip("/"), + message_kwargs.get("path", "").lstrip("/"), + ] + ) + + headers = { + "Content-Type": message_kwargs.get("content_type") or CONTENT_TYPE, + "Accept": message_kwargs.get("accept") or CONTENT_TYPE, + } + response, response_data = self.connection.send( + path, data, headers=headers, method=message_kwargs.get("method") + ) + + return handle_response(response, response_data) + + +def handle_response(response, response_data): + try: + response_data = json.loads(response_data.read()) + except ValueError: + response_data = response_data.read() + + if isinstance(response, HTTPError): + if response_data: + if "errors" in response_data: + errors = response_data["errors"]["error"] + error_text = "\n".join( + (error["error-message"] for error in errors) + ) + else: + error_text = response_data + + raise ConnectionError(error_text, code=response.code) + raise ConnectionError(to_text(response), code=response.code) + + return response_data diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py new file mode 100644 index 0000000..dc0a19f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py @@ -0,0 +1,2578 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file, and this file only, is based on +# Lib/ipaddress.py of cpython +# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved" +# are retained in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. + +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +from __future__ import unicode_literals + + +import itertools +import struct + + +# The following makes it easier for us to script updates of the bundled code and is not part of +# upstream +_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"} + +__version__ = "1.0.22" + +# Compatibility functions +_compat_int_types = (int,) +try: + _compat_int_types = (int, long) +except NameError: + pass +try: + _compat_str = unicode +except NameError: + _compat_str = str + assert bytes != str +if b"\0"[0] == 0: # Python 3 semantics + + def _compat_bytes_to_byte_vals(byt): + return byt + + +else: + + def _compat_bytes_to_byte_vals(byt): + return [struct.unpack(b"!B", b)[0] for b in byt] + + +try: + _compat_int_from_byte_vals = int.from_bytes +except AttributeError: + + def _compat_int_from_byte_vals(bytvals, endianess): + assert endianess == "big" + res = 0 + for bv in bytvals: + assert isinstance(bv, _compat_int_types) + res = (res << 8) + bv + return res + + +def _compat_to_bytes(intval, length, endianess): + assert isinstance(intval, _compat_int_types) + assert endianess == "big" + if length == 4: + if intval < 0 or intval >= 2 ** 32: + raise struct.error("integer out of range for 'I' format code") + return struct.pack(b"!I", intval) + elif length == 16: + if intval < 0 or intval >= 2 ** 128: + raise struct.error("integer out of range for 'QQ' format code") + return struct.pack(b"!QQ", intval >> 64, intval & 0xFFFFFFFFFFFFFFFF) + else: + raise NotImplementedError() + + +if hasattr(int, "bit_length"): + # Not int.bit_length , since that won't work in 2.7 where long exists + def _compat_bit_length(i): + return i.bit_length() + + +else: + + def _compat_bit_length(i): + for res in itertools.count(): + if i >> res == 0: + return res + + +def _compat_range(start, end, step=1): + assert step > 0 + i = start + while i < end: + yield i + i += step + + +class _TotalOrderingMixin(object): + __slots__ = () + + # Helper that derives the other comparison operations from + # __lt__ and __eq__ + # We avoid functools.total_ordering because it doesn't handle + # NotImplemented correctly yet (http://bugs.python.org/issue10042) + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + + def __lt__(self, other): + raise NotImplementedError + + def __le__(self, other): + less = self.__lt__(other) + if less is NotImplemented or not less: + return self.__eq__(other) + return less + + def __gt__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not (less or equal) + + def __ge__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + return not less + + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + "%r does not appear to be an IPv4 or IPv6 address. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" % address + ) + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 address" % address + ) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + "%r does not appear to be an IPv4 or IPv6 network. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" % address + ) + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 network" % address + ) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. + + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 interface" % address + ) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return _compat_to_bytes(address, 4, "big") + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return _compat_to_bytes(address, 16, "big") + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = _compat_str(address).split("/") + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) # pylint: disable=stop-iteration-return + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, _compat_bit_length(~number & (number - 1))) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. + + """ + if not ( + isinstance(first, _BaseAddress) and isinstance(last, _BaseAddress) + ): + raise TypeError("first and last must be IP addresses, not networks") + if first.version != last.version: + raise TypeError( + "%s and %s are not of the same version" % (first, last) + ) + if first > last: + raise ValueError("last IP address must be greater than first") + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError("unknown IP version") + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min( + _count_righthand_zero_bits(first_int, ip_bits), + _compat_bit_length(last_int - first_int + 1) - 1, + ) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, + # last.network_address <= net.network_address is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. + + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, ips[-1]) + ) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, ips[-1]) + ) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, nets[-1]) + ) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase(_TotalOrderingMixin): + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return _compat_str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = "%200s has no version specified" % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError( + msg % (address, self._max_prefixlen, self._version) + ) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = ( + "%r (len %d != %d) is not permitted as an IPv%d address. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" + ) + raise AddressValueError( + msg % (address, address_len, expected_len, self._version) + ) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. + + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits( + ip_int, cls._max_prefixlen + ) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = _compat_to_bytes(ip_int, byteslen, "big") + msg = "Netmask pattern %r mixes zeroes & ones" + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = "%r is not a valid netmask" % netmask_str + raise NetmaskValueError(msg) + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + def __reduce__(self): + return self.__class__, (_compat_str(self),) + + +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. + """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + try: + return self._ip == other._ip and self._version == other._version + except AttributeError: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseAddress): + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + if self._ip != other._ip: + return self._ip < other._ip + return False + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return _compat_str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + +class _BaseNetwork(_IPAddressBase): + + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + + """ + + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return "%s/%d" % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError("address out of range") + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError("address out of range") + return self._address_class(broadcast + n) + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseNetwork): + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __eq__(self, other): + try: + return ( + self._version == other._version + and self.network_address == other.network_address + and int(self.netmask) == int(other.netmask) + ) + except AttributeError: + return NotImplemented + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. + if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return ( + int(self.network_address) + <= int(other._ip) + <= int(self.broadcast_address) + ) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other + or ( + other.network_address in self + or (other.broadcast_address in self) + ) + ) + + @property + def broadcast_address(self): + x = self._cache.get("broadcast_address") + if x is None: + x = self._address_class( + int(self.network_address) | int(self.hostmask) + ) + self._cache["broadcast_address"] = x + return x + + @property + def hostmask(self): + x = self._cache.get("hostmask") + if x is None: + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) + self._cache["hostmask"] = x + return x + + @property + def with_prefixlen(self): + return "%s/%d" % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = "%200s has no associated address class" % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not other.subnet_of(self): + raise ValueError("%s not contained in %s" % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__( + "%s/%s" % (other.network_address, other.prefixlen) + ) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if other.subnet_of(s1): + yield s2 + s1, s2 = s1.subnets() + elif other.subnet_of(s2): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. + raise AssertionError( + "Error performing exclusion: " + "s1: %s s2: %s other: %s" % (s1, s2, other) + ) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. + raise AssertionError( + "Error performing exclusion: " + "s1: %s s2: %s other: %s" % (s1, s2, other) + ) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. + OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError("new prefix must be longer") + if prefixlen_diff != 1: + raise ValueError("cannot set prefixlen_diff and new_prefix") + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError("prefix length diff must be > 0") + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + "prefix length diff %d is invalid for netblock %s" + % (new_prefixlen, self) + ) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in _compat_range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError("new prefix must be shorter") + if prefixlen_diff != 1: + raise ValueError("cannot set prefixlen_diff and new_prefix") + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + "current prefixlen is %d, cannot have a prefixlen_diff of %d" + % (self.prefixlen, prefixlen_diff) + ) + return self.__class__( + ( + int(self.network_address) + & (int(self.netmask) << prefixlen_diff), + new_prefixlen, + ) + ) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return ( + self.network_address.is_multicast + and self.broadcast_address.is_multicast + ) + + @staticmethod + def _is_subnet_of(a, b): + try: + # Always false if one is v4 and the other is v6. + if a._version != b._version: + raise TypeError( + "%s and %s are not of the same version" % (a, b) + ) + return ( + b.network_address <= a.network_address + and b.broadcast_address >= a.broadcast_address + ) + except AttributeError: + raise TypeError( + "Unable to test subnet containment " + "between %s and %s" % (a, b) + ) + + def subnet_of(self, other): + """Return True if this network is a subnet of other.""" + return self._is_subnet_of(self, other) + + def supernet_of(self, other): + """Return True if this network is a supernet of other.""" + return self._is_subnet_of(other, self) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return ( + self.network_address.is_reserved + and self.broadcast_address.is_reserved + ) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return ( + self.network_address.is_link_local + and self.broadcast_address.is_link_local + ) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return ( + self.network_address.is_private + and self.broadcast_address.is_private + ) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return ( + self.network_address.is_unspecified + and self.broadcast_address.is_unspecified + ) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return ( + self.network_address.is_loopback + and self.broadcast_address.is_loopback + ) + + +class _BaseV4(object): + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2 ** IPV4LENGTH) - 1 + _DECIMAL_DIGITS = frozenset("0123456789") + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0]) + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return _compat_str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. + + """ + if not ip_str: + raise AddressValueError("Address cannot be empty") + + octets = ip_str.split(".") + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return _compat_int_from_byte_vals( + map(cls._parse_octet, octets), "big" + ) + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._DECIMAL_DIGITS.issuperset(octet_str): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + # Any octets that look like they *might* be written in octal, + # and which don't look exactly the same in both octal and + # decimal are rejected as ambiguous + if octet_int > 7 and octet_str[0] == "0": + msg = "Ambiguous (octal/decimal) value in %r not permitted" + raise ValueError(msg % octet_str) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return ".".join( + _compat_str( + struct.unpack(b"!B", b)[0] if isinstance(b, bytes) else b + ) + for b in _compat_to_bytes(ip_int, 4, "big") + ) + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split(".") + try: + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = _compat_str(self).split(".")[::-1] + return ".".join(reverse_octets) + ".in-addr.arpa" + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ("_ip", "__weakref__") + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, "big") + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if "/" in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in self._constants._reserved_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + return ( + self not in self._constants._public_network and not self.is_private + ) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self == self._constants._unspecified_address + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in self._constants._loopback_network + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. + + """ + return self in self._constants._linklocal_network + + +class IPv4Interface(IPv4Address): + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv4Address.__init__(self, address) + self.network = IPv4Network(self._ip) + self._prefixlen = self._max_prefixlen + return + + if isinstance(address, tuple): + IPv4Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + + self.network = IPv4Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv4Address.__init__(self, addr[0]) + + self.network = IPv4Network(address, strict=False) + self._prefixlen = self.network._prefixlen + + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + + def __str__(self): + return "%s/%d" % ( + self._string_from_ip_int(self._ip), + self.network.prefixlen, + ) + + def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return ( + self.network < other.network + or self.network == other.network + and address_less + ) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv4Address(self._ip) + + @property + def with_prefixlen(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask) + + +class IPv4Network(_BaseV4, _BaseNetwork): + + """This class represents and manipulates 32-bit IPv4 network + addresses.. + + Attributes: [examples for IPv4Network('192.0.2.0/27')] + .network_address: IPv4Address('192.0.2.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast_address: IPv4Address('192.0.2.32') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + + # Class to use when creating address objects + _address_class = IPv4Address + + def __init__(self, address, strict=True): + + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.0.2.0/24' + '192.0.2.0/255.255.255.0' + '192.0.0.2/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.0.2.1' + '192.0.2.1/255.255.255.255' + '192.0.2.1/32' + are also functionally equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.0.2.1') == IPv4Network(3221225985) + or, more generally + IPv4Interface(int(IPv4Interface('192.0.2.1'))) == + IPv4Interface('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict is True and a network address is not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (_compat_int_types, bytes)): + self.network_address = IPv4Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen + ) + # fixme: address/network test here. + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + # We weren't given an address[1] + arg = self._max_prefixlen + self.network_address = IPv4Address(address[0]) + self.netmask, self._prefixlen = self._make_netmask(arg) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError("%s has host bits set" % self) + else: + self.network_address = IPv4Address( + packed & int(self.netmask) + ) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if ( + IPv4Address(int(self.network_address) & int(self.netmask)) + != self.network_address + ): + raise ValueError("%s has host bits set" % self) + self.network_address = IPv4Address( + int(self.network_address) & int(self.netmask) + ) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry. + + """ + return ( + not ( + self.network_address in IPv4Network("100.64.0.0/10") + and self.broadcast_address in IPv4Network("100.64.0.0/10") + ) + and not self.is_private + ) + + +class _IPv4Constants(object): + + _linklocal_network = IPv4Network("169.254.0.0/16") + + _loopback_network = IPv4Network("127.0.0.0/8") + + _multicast_network = IPv4Network("224.0.0.0/4") + + _public_network = IPv4Network("100.64.0.0/10") + + _private_networks = [ + IPv4Network("0.0.0.0/8"), + IPv4Network("10.0.0.0/8"), + IPv4Network("127.0.0.0/8"), + IPv4Network("169.254.0.0/16"), + IPv4Network("172.16.0.0/12"), + IPv4Network("192.0.0.0/29"), + IPv4Network("192.0.0.170/31"), + IPv4Network("192.0.2.0/24"), + IPv4Network("192.168.0.0/16"), + IPv4Network("198.18.0.0/15"), + IPv4Network("198.51.100.0/24"), + IPv4Network("203.0.113.0/24"), + IPv4Network("240.0.0.0/4"), + IPv4Network("255.255.255.255/32"), + ] + + _reserved_network = IPv4Network("240.0.0.0/4") + + _unspecified_address = IPv4Address("0.0.0.0") + + +IPv4Address._constants = _IPv4Constants + + +class _BaseV6(object): + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2 ** IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset("0123456789ABCDEFabcdef") + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError("Address cannot be empty") + + parts = ip_str.split(":") + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if "." in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + parts.append("%x" % ((ipv4_int >> 16) & 0xFFFF)) + parts.append("%x" % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % ( + _max_parts - 1, + ip_str, + ) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in _compat_range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. The + # endpoints could still be empty, but _parse_hextet() will check + # for that. + if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == "0": + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = ( + best_doublecolon_start + best_doublecolon_len + ) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [""] + hextets[best_doublecolon_start:best_doublecolon_end] = [""] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [""] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. + + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError("IPv6 address is too large") + + hex_str = "%032x" % ip_int + hextets = ["%x" % int(hex_str[x : x + 4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ":".join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = _compat_str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = _compat_str(self.ip) + else: + ip_str = _compat_str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = "%032x" % ip_int + parts = [hex_str[x : x + 4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return "%s/%d" % (":".join(parts), self._prefixlen) + return ":".join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(":", "") + return ".".join(reverse_chars) + ".ip6.arpa" + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ("_ip", "__weakref__") + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, "big") + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if "/" in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001::/32) + + """ + if (self._ip >> 96) != 0x20010000: + return None + return ( + IPv4Address((self._ip >> 64) & 0xFFFFFFFF), + IPv4Address(~self._ip & 0xFFFFFFFF), + ) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + if (self._ip >> 112) != 0x2002: + return None + return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) + + +class IPv6Interface(IPv6Address): + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv6Address.__init__(self, address) + self.network = IPv6Network(self._ip) + self._prefixlen = self._max_prefixlen + return + if isinstance(address, tuple): + IPv6Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv6Address.__init__(self, addr[0]) + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self._prefixlen = self.network._prefixlen + self.hostmask = self.network.hostmask + + def __str__(self): + return "%s/%d" % ( + self._string_from_ip_int(self._ip), + self.network.prefixlen, + ) + + def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return ( + self.network < other.network + or self.network == other.network + and address_less + ) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv6Address(self._ip) + + @property + def with_prefixlen(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask) + + @property + def is_unspecified(self): + return self._ip == 0 and self.network.is_unspecified + + @property + def is_loopback(self): + return self._ip == 1 and self.network.is_loopback + + +class IPv6Network(_BaseV6, _BaseNetwork): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:db8::1000/124')] + .network_address: IPv6Address('2001:db8::1000') + .hostmask: IPv6Address('::f') + .broadcast_address: IPv6Address('2001:db8::100f') + .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') + .prefixlen: 124 + + """ + + # Class to use when creating address objects + _address_class = IPv6Address + + def __init__(self, address, strict=True): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the + IP and prefix/netmask. + '2001:db8::/128' + '2001:db8:0000:0000:0000:0000:0000:0000/128' + '2001:db8::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:db8::') == + IPv6Network(42540766411282592856903984951653826560) + or, more generally + IPv6Network(int(IPv6Network('2001:db8::'))) == + IPv6Network('2001:db8::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 2001:db8::1000/124 and not an + IP address on a network, eg, 2001:db8::1/124. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Efficient constructor from integer or packed address + if isinstance(address, (bytes, _compat_int_types)): + self.network_address = IPv6Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen + ) + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + self.network_address = IPv6Address(address[0]) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError("%s has host bits set" % self) + else: + self.network_address = IPv6Address( + packed & int(self.netmask) + ) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + + self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if ( + IPv6Address(int(self.network_address) & int(self.netmask)) + != self.network_address + ): + raise ValueError("%s has host bits set" % self) + self.network_address = IPv6Address( + int(self.network_address) & int(self.netmask) + ) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the + Subnet-Router anycast address. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast + 1): + yield self._address_class(x) + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return ( + self.network_address.is_site_local + and self.broadcast_address.is_site_local + ) + + +class _IPv6Constants(object): + + _linklocal_network = IPv6Network("fe80::/10") + + _multicast_network = IPv6Network("ff00::/8") + + _private_networks = [ + IPv6Network("::1/128"), + IPv6Network("::/128"), + IPv6Network("::ffff:0:0/96"), + IPv6Network("100::/64"), + IPv6Network("2001::/23"), + IPv6Network("2001:2::/48"), + IPv6Network("2001:db8::/32"), + IPv6Network("2001:10::/28"), + IPv6Network("fc00::/7"), + IPv6Network("fe80::/10"), + ] + + _reserved_networks = [ + IPv6Network("::/8"), + IPv6Network("100::/8"), + IPv6Network("200::/7"), + IPv6Network("400::/6"), + IPv6Network("800::/5"), + IPv6Network("1000::/4"), + IPv6Network("4000::/3"), + IPv6Network("6000::/3"), + IPv6Network("8000::/3"), + IPv6Network("A000::/3"), + IPv6Network("C000::/3"), + IPv6Network("E000::/4"), + IPv6Network("F000::/5"), + IPv6Network("F800::/6"), + IPv6Network("FE00::/9"), + ] + + _sitelocal_network = IPv6Network("fec0::/10") + + +IPv6Address._constants = _IPv6Constants diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py new file mode 100644 index 0000000..68608d1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py @@ -0,0 +1,27 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The base class for all resource modules +""" + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import ( + get_resource_connection, +) + + +class ConfigBase(object): + """ The base class for all resource modules + """ + + ACTION_STATES = ["merged", "replaced", "overridden", "deleted"] + + def __init__(self, module): + self._module = module + self.state = module.params["state"] + self._connection = None + + if self.state not in ["rendered", "parsed"]: + self._connection = get_resource_connection(module) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py new file mode 100644 index 0000000..bc458eb --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py @@ -0,0 +1,473 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import re +import hashlib + +from ansible.module_utils.six.moves import zip +from ansible.module_utils._text import to_bytes, to_native + +DEFAULT_COMMENT_TOKENS = ["#", "!", "/*", "*/", "echo"] + +DEFAULT_IGNORE_LINES_RE = set( + [ + re.compile(r"Using \d+ out of \d+ bytes"), + re.compile(r"Building configuration"), + re.compile(r"Current configuration : \d+ bytes"), + ] +) + + +try: + Pattern = re._pattern_type +except AttributeError: + Pattern = re.Pattern + + +class ConfigLine(object): + def __init__(self, raw): + self.text = str(raw).strip() + self.raw = raw + self._children = list() + self._parents = list() + + def __str__(self): + return self.raw + + def __eq__(self, other): + return self.line == other.line + + def __ne__(self, other): + return not self.__eq__(other) + + def __getitem__(self, key): + for item in self._children: + if item.text == key: + return item + raise KeyError(key) + + @property + def line(self): + line = self.parents + line.append(self.text) + return " ".join(line) + + @property + def children(self): + return _obj_to_text(self._children) + + @property + def child_objs(self): + return self._children + + @property + def parents(self): + return _obj_to_text(self._parents) + + @property + def path(self): + config = _obj_to_raw(self._parents) + config.append(self.raw) + return "\n".join(config) + + @property + def has_children(self): + return len(self._children) > 0 + + @property + def has_parents(self): + return len(self._parents) > 0 + + def add_child(self, obj): + if not isinstance(obj, ConfigLine): + raise AssertionError("child must be of type `ConfigLine`") + self._children.append(obj) + + +def ignore_line(text, tokens=None): + for item in tokens or DEFAULT_COMMENT_TOKENS: + if text.startswith(item): + return True + for regex in DEFAULT_IGNORE_LINES_RE: + if regex.match(text): + return True + + +def _obj_to_text(x): + return [o.text for o in x] + + +def _obj_to_raw(x): + return [o.raw for o in x] + + +def _obj_to_block(objects, visited=None): + items = list() + for o in objects: + if o not in items: + items.append(o) + for child in o._children: + if child not in items: + items.append(child) + return _obj_to_raw(items) + + +def dumps(objects, output="block", comments=False): + if output == "block": + items = _obj_to_block(objects) + elif output == "commands": + items = _obj_to_text(objects) + elif output == "raw": + items = _obj_to_raw(objects) + else: + raise TypeError("unknown value supplied for keyword output") + + if output == "block": + if comments: + for index, item in enumerate(items): + nextitem = index + 1 + if ( + nextitem < len(items) + and not item.startswith(" ") + and items[nextitem].startswith(" ") + ): + item = "!\n%s" % item + items[index] = item + items.append("!") + items.append("end") + + return "\n".join(items) + + +class NetworkConfig(object): + def __init__(self, indent=1, contents=None, ignore_lines=None): + self._indent = indent + self._items = list() + self._config_text = None + + if ignore_lines: + for item in ignore_lines: + if not isinstance(item, Pattern): + item = re.compile(item) + DEFAULT_IGNORE_LINES_RE.add(item) + + if contents: + self.load(contents) + + @property + def items(self): + return self._items + + @property + def config_text(self): + return self._config_text + + @property + def sha1(self): + sha1 = hashlib.sha1() + sha1.update(to_bytes(str(self), errors="surrogate_or_strict")) + return sha1.digest() + + def __getitem__(self, key): + for line in self: + if line.text == key: + return line + raise KeyError(key) + + def __iter__(self): + return iter(self._items) + + def __str__(self): + return "\n".join([c.raw for c in self.items]) + + def __len__(self): + return len(self._items) + + def load(self, s): + self._config_text = s + self._items = self.parse(s) + + def loadfp(self, fp): + with open(fp) as f: + return self.load(f.read()) + + def parse(self, lines, comment_tokens=None): + toplevel = re.compile(r"\S") + childline = re.compile(r"^\s*(.+)$") + entry_reg = re.compile(r"([{};])") + + ancestors = list() + config = list() + + indents = [0] + + for linenum, line in enumerate( + to_native(lines, errors="surrogate_or_strict").split("\n") + ): + text = entry_reg.sub("", line).strip() + + cfg = ConfigLine(line) + + if not text or ignore_line(text, comment_tokens): + continue + + # handle top level commands + if toplevel.match(line): + ancestors = [cfg] + indents = [0] + + # handle sub level commands + else: + match = childline.match(line) + line_indent = match.start(1) + + if line_indent < indents[-1]: + while indents[-1] > line_indent: + indents.pop() + + if line_indent > indents[-1]: + indents.append(line_indent) + + curlevel = len(indents) - 1 + parent_level = curlevel - 1 + + cfg._parents = ancestors[:curlevel] + + if curlevel > len(ancestors): + config.append(cfg) + continue + + for i in range(curlevel, len(ancestors)): + ancestors.pop() + + ancestors.append(cfg) + ancestors[parent_level].add_child(cfg) + + config.append(cfg) + + return config + + def get_object(self, path): + for item in self.items: + if item.text == path[-1]: + if item.parents == path[:-1]: + return item + + def get_block(self, path): + if not isinstance(path, list): + raise AssertionError("path argument must be a list object") + obj = self.get_object(path) + if not obj: + raise ValueError("path does not exist in config") + return self._expand_block(obj) + + def get_block_config(self, path): + block = self.get_block(path) + return dumps(block, "block") + + def _expand_block(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj._children: + if child in S: + continue + self._expand_block(child, S) + return S + + def _diff_line(self, other): + updates = list() + for item in self.items: + if item not in other: + updates.append(item) + return updates + + def _diff_strict(self, other): + updates = list() + # block extracted from other does not have all parents + # but the last one. In case of multiple parents we need + # to add additional parents. + if other and isinstance(other, list) and len(other) > 0: + start_other = other[0] + if start_other.parents: + for parent in start_other.parents: + other.insert(0, ConfigLine(parent)) + for index, line in enumerate(self.items): + try: + if str(line).strip() != str(other[index]).strip(): + updates.append(line) + except (AttributeError, IndexError): + updates.append(line) + return updates + + def _diff_exact(self, other): + updates = list() + if len(other) != len(self.items): + updates.extend(self.items) + else: + for ours, theirs in zip(self.items, other): + if ours != theirs: + updates.extend(self.items) + break + return updates + + def difference(self, other, match="line", path=None, replace=None): + """Perform a config diff against the another network config + + :param other: instance of NetworkConfig to diff against + :param match: type of diff to perform. valid values are 'line', + 'strict', 'exact' + :param path: context in the network config to filter the diff + :param replace: the method used to generate the replacement lines. + valid values are 'block', 'line' + + :returns: a string of lines that are different + """ + if path and match != "line": + try: + other = other.get_block(path) + except ValueError: + other = list() + else: + other = other.items + + # generate a list of ConfigLines that aren't in other + meth = getattr(self, "_diff_%s" % match) + updates = meth(other) + + if replace == "block": + parents = list() + for item in updates: + if not item.has_parents: + parents.append(item) + else: + for p in item._parents: + if p not in parents: + parents.append(p) + + updates = list() + for item in parents: + updates.extend(self._expand_block(item)) + + visited = set() + expanded = list() + + for item in updates: + for p in item._parents: + if p.line not in visited: + visited.add(p.line) + expanded.append(p) + expanded.append(item) + visited.add(item.line) + + return expanded + + def add(self, lines, parents=None): + ancestors = list() + offset = 0 + obj = None + + # global config command + if not parents: + for line in lines: + # handle ignore lines + if ignore_line(line): + continue + + item = ConfigLine(line) + item.raw = line + if item not in self.items: + self.items.append(item) + + else: + for index, p in enumerate(parents): + try: + i = index + 1 + obj = self.get_block(parents[:i])[0] + ancestors.append(obj) + + except ValueError: + # add parent to config + offset = index * self._indent + obj = ConfigLine(p) + obj.raw = p.rjust(len(p) + offset) + if ancestors: + obj._parents = list(ancestors) + ancestors[-1]._children.append(obj) + self.items.append(obj) + ancestors.append(obj) + + # add child objects + for line in lines: + # handle ignore lines + if ignore_line(line): + continue + + # check if child already exists + for child in ancestors[-1]._children: + if child.text == line: + break + else: + offset = len(parents) * self._indent + item = ConfigLine(line) + item.raw = line.rjust(len(line) + offset) + item._parents = ancestors + ancestors[-1]._children.append(item) + self.items.append(item) + + +class CustomNetworkConfig(NetworkConfig): + def items_text(self): + return [item.text for item in self.items] + + def expand_section(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj.child_objs: + if child in S: + continue + self.expand_section(child, S) + return S + + def to_block(self, section): + return "\n".join([item.raw for item in section]) + + def get_section(self, path): + try: + section = self.get_section_objects(path) + return self.to_block(section) + except ValueError: + return list() + + def get_section_objects(self, path): + if not isinstance(path, list): + path = [path] + obj = self.get_object(path) + if not obj: + raise ValueError("path does not exist in config") + return self.expand_section(obj) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py new file mode 100644 index 0000000..477d318 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py @@ -0,0 +1,162 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The facts base class +this contains methods common to all facts subsets +""" +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import ( + get_resource_connection, +) +from ansible.module_utils.six import iteritems + + +class FactsBase(object): + """ + The facts base class + """ + + def __init__(self, module): + self._module = module + self._warnings = [] + self._gather_subset = module.params.get("gather_subset") + self._gather_network_resources = module.params.get( + "gather_network_resources" + ) + self._connection = None + if module.params.get("state") not in ["rendered", "parsed"]: + self._connection = get_resource_connection(module) + + self.ansible_facts = {"ansible_network_resources": {}} + self.ansible_facts["ansible_net_gather_network_resources"] = list() + self.ansible_facts["ansible_net_gather_subset"] = list() + + if not self._gather_subset: + self._gather_subset = ["!config"] + if not self._gather_network_resources: + self._gather_network_resources = ["!all"] + + def gen_runable(self, subsets, valid_subsets, resource_facts=False): + """ Generate the runable subset + + :param module: The module instance + :param subsets: The provided subsets + :param valid_subsets: The valid subsets + :param resource_facts: A boolean flag + :rtype: list + :returns: The runable subsets + """ + runable_subsets = set() + exclude_subsets = set() + minimal_gather_subset = set() + if not resource_facts: + minimal_gather_subset = frozenset(["default"]) + + for subset in subsets: + if subset == "all": + runable_subsets.update(valid_subsets) + continue + if subset == "min" and minimal_gather_subset: + runable_subsets.update(minimal_gather_subset) + continue + if subset.startswith("!"): + subset = subset[1:] + if subset == "min": + exclude_subsets.update(minimal_gather_subset) + continue + if subset == "all": + exclude_subsets.update( + valid_subsets - minimal_gather_subset + ) + continue + exclude = True + else: + exclude = False + + if subset not in valid_subsets: + self._module.fail_json( + msg="Subset must be one of [%s], got %s" + % ( + ", ".join(sorted([item for item in valid_subsets])), + subset, + ) + ) + + if exclude: + exclude_subsets.add(subset) + else: + runable_subsets.add(subset) + + if not runable_subsets: + runable_subsets.update(valid_subsets) + runable_subsets.difference_update(exclude_subsets) + return runable_subsets + + def get_network_resources_facts( + self, facts_resource_obj_map, resource_facts_type=None, data=None + ): + """ + :param fact_resource_subsets: + :param data: previously collected configuration + :return: + """ + if not resource_facts_type: + resource_facts_type = self._gather_network_resources + + restorun_subsets = self.gen_runable( + resource_facts_type, + frozenset(facts_resource_obj_map.keys()), + resource_facts=True, + ) + if restorun_subsets: + self.ansible_facts["ansible_net_gather_network_resources"] = list( + restorun_subsets + ) + instances = list() + for key in restorun_subsets: + fact_cls_obj = facts_resource_obj_map.get(key) + if fact_cls_obj: + instances.append(fact_cls_obj(self._module)) + else: + self._warnings.extend( + [ + "network resource fact gathering for '%s' is not supported" + % key + ] + ) + + for inst in instances: + inst.populate_facts(self._connection, self.ansible_facts, data) + + def get_network_legacy_facts( + self, fact_legacy_obj_map, legacy_facts_type=None + ): + if not legacy_facts_type: + legacy_facts_type = self._gather_subset + + runable_subsets = self.gen_runable( + legacy_facts_type, frozenset(fact_legacy_obj_map.keys()) + ) + if runable_subsets: + facts = dict() + # default subset should always returned be with legacy facts subsets + if "default" not in runable_subsets: + runable_subsets.add("default") + self.ansible_facts["ansible_net_gather_subset"] = list( + runable_subsets + ) + + instances = list() + for key in runable_subsets: + instances.append(fact_legacy_obj_map[key](self._module)) + + for inst in instances: + inst.populate() + facts.update(inst.facts) + self._warnings.extend(inst.warnings) + + for key, value in iteritems(facts): + key = "ansible_net_%s" % key + self.ansible_facts[key] = value diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py new file mode 100644 index 0000000..53a91e8 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py @@ -0,0 +1,179 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import sys + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError + +try: + from ncclient.xml_ import NCElement, new_ele, sub_ele + + HAS_NCCLIENT = True +except (ImportError, AttributeError): + HAS_NCCLIENT = False + +try: + from lxml.etree import Element, fromstring, XMLSyntaxError +except ImportError: + from xml.etree.ElementTree import Element, fromstring + + if sys.version_info < (2, 7): + from xml.parsers.expat import ExpatError as XMLSyntaxError + else: + from xml.etree.ElementTree import ParseError as XMLSyntaxError + +NS_MAP = {"nc": "urn:ietf:params:xml:ns:netconf:base:1.0"} + + +def exec_rpc(module, *args, **kwargs): + connection = NetconfConnection(module._socket_path) + return connection.execute_rpc(*args, **kwargs) + + +class NetconfConnection(Connection): + def __init__(self, socket_path): + super(NetconfConnection, self).__init__(socket_path) + + def __rpc__(self, name, *args, **kwargs): + """Executes the json-rpc and returns the output received + from remote device. + :name: rpc method to be executed over connection plugin that implements jsonrpc 2.0 + :args: Ordered list of params passed as arguments to rpc method + :kwargs: Dict of valid key, value pairs passed as arguments to rpc method + + For usage refer the respective connection plugin docs. + """ + self.check_rc = kwargs.pop("check_rc", True) + self.ignore_warning = kwargs.pop("ignore_warning", True) + + response = self._exec_jsonrpc(name, *args, **kwargs) + if "error" in response: + rpc_error = response["error"].get("data") + return self.parse_rpc_error( + to_bytes(rpc_error, errors="surrogate_then_replace") + ) + + return fromstring( + to_bytes(response["result"], errors="surrogate_then_replace") + ) + + def parse_rpc_error(self, rpc_error): + if self.check_rc: + try: + error_root = fromstring(rpc_error) + root = Element("root") + root.append(error_root) + + error_list = root.findall(".//nc:rpc-error", NS_MAP) + if not error_list: + raise ConnectionError( + to_text(rpc_error, errors="surrogate_then_replace") + ) + + warnings = [] + for error in error_list: + message_ele = error.find("./nc:error-message", NS_MAP) + + if message_ele is None: + message_ele = error.find("./nc:error-info", NS_MAP) + + message = ( + message_ele.text if message_ele is not None else None + ) + + severity = error.find("./nc:error-severity", NS_MAP).text + + if ( + severity == "warning" + and self.ignore_warning + and message is not None + ): + warnings.append(message) + else: + raise ConnectionError( + to_text(rpc_error, errors="surrogate_then_replace") + ) + return warnings + except XMLSyntaxError: + raise ConnectionError(rpc_error) + + +def transform_reply(): + return b"""<xsl:stylesheet version="1.0" xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> + <xsl:output method="xml" indent="no"/> + + <xsl:template match="/|comment()|processing-instruction()"> + <xsl:copy> + <xsl:apply-templates/> + </xsl:copy> + </xsl:template> + + <xsl:template match="*"> + <xsl:element name="{local-name()}"> + <xsl:apply-templates select="@*|node()"/> + </xsl:element> + </xsl:template> + + <xsl:template match="@*"> + <xsl:attribute name="{local-name()}"> + <xsl:value-of select="."/> + </xsl:attribute> + </xsl:template> + </xsl:stylesheet> + """ + + +# Note: Workaround for ncclient 0.5.3 +def remove_namespaces(data): + if not HAS_NCCLIENT: + raise ImportError( + "ncclient is required but does not appear to be installed. " + "It can be installed using `pip install ncclient`" + ) + return NCElement(data, transform_reply()).data_xml + + +def build_root_xml_node(tag): + return new_ele(tag) + + +def build_child_xml_node(parent, tag, text=None, attrib=None): + element = sub_ele(parent, tag) + if text: + element.text = to_text(text) + if attrib: + element.attrib.update(attrib) + return element + + +def build_subtree(parent, path): + element = parent + for field in path.split("/"): + sub_element = build_child_xml_node(element, field) + element = sub_element + return element diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py new file mode 100644 index 0000000..555fc71 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py @@ -0,0 +1,275 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com> +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import traceback +import json + +from ansible.module_utils._text import to_text, to_native +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.basic import env_fallback +from ansible.module_utils.connection import Connection, ConnectionError +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import ( + NetconfConnection, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import ( + Cli, +) +from ansible.module_utils.six import iteritems + + +NET_TRANSPORT_ARGS = dict( + host=dict(required=True), + port=dict(type="int"), + username=dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])), + password=dict( + no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]) + ), + ssh_keyfile=dict( + fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path" + ), + authorize=dict( + default=False, + fallback=(env_fallback, ["ANSIBLE_NET_AUTHORIZE"]), + type="bool", + ), + auth_pass=dict( + no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_AUTH_PASS"]) + ), + provider=dict(type="dict", no_log=True), + transport=dict(choices=list()), + timeout=dict(default=10, type="int"), +) + +NET_CONNECTION_ARGS = dict() + +NET_CONNECTIONS = dict() + + +def _transitional_argument_spec(): + argument_spec = {} + for key, value in iteritems(NET_TRANSPORT_ARGS): + value["required"] = False + argument_spec[key] = value + return argument_spec + + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +class ModuleStub(object): + def __init__(self, argument_spec, fail_json): + self.params = dict() + for key, value in argument_spec.items(): + self.params[key] = value.get("default") + self.fail_json = fail_json + + +class NetworkError(Exception): + def __init__(self, msg, **kwargs): + super(NetworkError, self).__init__(msg) + self.kwargs = kwargs + + +class Config(object): + def __init__(self, connection): + self.connection = connection + + def __call__(self, commands, **kwargs): + lines = to_list(commands) + return self.connection.configure(lines, **kwargs) + + def load_config(self, commands, **kwargs): + commands = to_list(commands) + return self.connection.load_config(commands, **kwargs) + + def get_config(self, **kwargs): + return self.connection.get_config(**kwargs) + + def save_config(self): + return self.connection.save_config() + + +class NetworkModule(AnsibleModule): + def __init__(self, *args, **kwargs): + connect_on_load = kwargs.pop("connect_on_load", True) + + argument_spec = NET_TRANSPORT_ARGS.copy() + argument_spec["transport"]["choices"] = NET_CONNECTIONS.keys() + argument_spec.update(NET_CONNECTION_ARGS.copy()) + + if kwargs.get("argument_spec"): + argument_spec.update(kwargs["argument_spec"]) + kwargs["argument_spec"] = argument_spec + + super(NetworkModule, self).__init__(*args, **kwargs) + + self.connection = None + self._cli = None + self._config = None + + try: + transport = self.params["transport"] or "__default__" + cls = NET_CONNECTIONS[transport] + self.connection = cls() + except KeyError: + self.fail_json( + msg="Unknown transport or no default transport specified" + ) + except (TypeError, NetworkError) as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + if connect_on_load: + self.connect() + + @property + def cli(self): + if not self.connected: + self.connect() + if self._cli: + return self._cli + self._cli = Cli(self.connection) + return self._cli + + @property + def config(self): + if not self.connected: + self.connect() + if self._config: + return self._config + self._config = Config(self.connection) + return self._config + + @property + def connected(self): + return self.connection._connected + + def _load_params(self): + super(NetworkModule, self)._load_params() + provider = self.params.get("provider") or dict() + for key, value in provider.items(): + for args in [NET_TRANSPORT_ARGS, NET_CONNECTION_ARGS]: + if key in args: + if self.params.get(key) is None and value is not None: + self.params[key] = value + + def connect(self): + try: + if not self.connected: + self.connection.connect(self.params) + if self.params["authorize"]: + self.connection.authorize(self.params) + self.log( + "connected to %s:%s using %s" + % ( + self.params["host"], + self.params["port"], + self.params["transport"], + ) + ) + except NetworkError as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + def disconnect(self): + try: + if self.connected: + self.connection.disconnect() + self.log("disconnected from %s" % self.params["host"]) + except NetworkError as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + +def register_transport(transport, default=False): + def register(cls): + NET_CONNECTIONS[transport] = cls + if default: + NET_CONNECTIONS["__default__"] = cls + return cls + + return register + + +def add_argument(key, value): + NET_CONNECTION_ARGS[key] = value + + +def get_resource_connection(module): + if hasattr(module, "_connection"): + return module._connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api in ("cliconf", "nxapi", "eapi", "exosapi"): + module._connection = Connection(module._socket_path) + elif network_api == "netconf": + module._connection = NetconfConnection(module._socket_path) + elif network_api == "local": + # This isn't supported, but we shouldn't fail here. + # Set the connection to a fake connection so it fails sensibly. + module._connection = LocalResourceConnection(module) + else: + module.fail_json( + msg="Invalid connection type {0!s}".format(network_api) + ) + + return module._connection + + +def get_capabilities(module): + if hasattr(module, "capabilities"): + return module._capabilities + try: + capabilities = Connection(module._socket_path).get_capabilities() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + except AssertionError: + # No socket_path, connection most likely local. + return dict(network_api="local") + module._capabilities = json.loads(capabilities) + + return module._capabilities + + +class LocalResourceConnection: + def __init__(self, module): + self.module = module + + def get(self, *args, **kwargs): + self.module.fail_json( + msg="Network resource modules not supported over local connection." + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py new file mode 100644 index 0000000..2dd1de9 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py @@ -0,0 +1,316 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com> +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re +import shlex +import time + +from ansible.module_utils.parsing.convert_bool import ( + BOOLEANS_TRUE, + BOOLEANS_FALSE, +) +from ansible.module_utils.six import string_types, text_type +from ansible.module_utils.six.moves import zip + + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +class FailedConditionsError(Exception): + def __init__(self, msg, failed_conditions): + super(FailedConditionsError, self).__init__(msg) + self.failed_conditions = failed_conditions + + +class FailedConditionalError(Exception): + def __init__(self, msg, failed_conditional): + super(FailedConditionalError, self).__init__(msg) + self.failed_conditional = failed_conditional + + +class AddCommandError(Exception): + def __init__(self, msg, command): + super(AddCommandError, self).__init__(msg) + self.command = command + + +class AddConditionError(Exception): + def __init__(self, msg, condition): + super(AddConditionError, self).__init__(msg) + self.condition = condition + + +class Cli(object): + def __init__(self, connection): + self.connection = connection + self.default_output = connection.default_output or "text" + self._commands = list() + + @property + def commands(self): + return [str(c) for c in self._commands] + + def __call__(self, commands, output=None): + objects = list() + for cmd in to_list(commands): + objects.append(self.to_command(cmd, output)) + return self.connection.run_commands(objects) + + def to_command( + self, command, output=None, prompt=None, response=None, **kwargs + ): + output = output or self.default_output + if isinstance(command, Command): + return command + if isinstance(prompt, string_types): + prompt = re.compile(re.escape(prompt)) + return Command( + command, output, prompt=prompt, response=response, **kwargs + ) + + def add_commands(self, commands, output=None, **kwargs): + for cmd in commands: + self._commands.append(self.to_command(cmd, output, **kwargs)) + + def run_commands(self): + responses = self.connection.run_commands(self._commands) + for resp, cmd in zip(responses, self._commands): + cmd.response = resp + + # wipe out the commands list to avoid issues if additional + # commands are executed later + self._commands = list() + + return responses + + +class Command(object): + def __init__( + self, command, output=None, prompt=None, response=None, **kwargs + ): + + self.command = command + self.output = output + self.command_string = command + + self.prompt = prompt + self.response = response + + self.args = kwargs + + def __str__(self): + return self.command_string + + +class CommandRunner(object): + def __init__(self, module): + self.module = module + + self.items = list() + self.conditionals = set() + + self.commands = list() + + self.retries = 10 + self.interval = 1 + + self.match = "all" + + self._default_output = module.connection.default_output + + def add_command( + self, command, output=None, prompt=None, response=None, **kwargs + ): + if command in [str(c) for c in self.commands]: + raise AddCommandError( + "duplicated command detected", command=command + ) + cmd = self.module.cli.to_command( + command, output=output, prompt=prompt, response=response, **kwargs + ) + self.commands.append(cmd) + + def get_command(self, command, output=None): + for cmd in self.commands: + if cmd.command == command: + return cmd.response + raise ValueError("command '%s' not found" % command) + + def get_responses(self): + return [cmd.response for cmd in self.commands] + + def add_conditional(self, condition): + try: + self.conditionals.add(Conditional(condition)) + except AttributeError as exc: + raise AddConditionError(msg=str(exc), condition=condition) + + def run(self): + while self.retries > 0: + self.module.cli.add_commands(self.commands) + responses = self.module.cli.run_commands() + + for item in list(self.conditionals): + if item(responses): + if self.match == "any": + return item + self.conditionals.remove(item) + + if not self.conditionals: + break + + time.sleep(self.interval) + self.retries -= 1 + else: + failed_conditions = [item.raw for item in self.conditionals] + errmsg = ( + "One or more conditional statements have not been satisfied" + ) + raise FailedConditionsError(errmsg, failed_conditions) + + +class Conditional(object): + """Used in command modules to evaluate waitfor conditions + """ + + OPERATORS = { + "eq": ["eq", "=="], + "neq": ["neq", "ne", "!="], + "gt": ["gt", ">"], + "ge": ["ge", ">="], + "lt": ["lt", "<"], + "le": ["le", "<="], + "contains": ["contains"], + "matches": ["matches"], + } + + def __init__(self, conditional, encoding=None): + self.raw = conditional + self.negate = False + try: + components = shlex.split(conditional) + key, val = components[0], components[-1] + op_components = components[1:-1] + if "not" in op_components: + self.negate = True + op_components.pop(op_components.index("not")) + op = op_components[0] + + except ValueError: + raise ValueError("failed to parse conditional") + + self.key = key + self.func = self._func(op) + self.value = self._cast_value(val) + + def __call__(self, data): + value = self.get_value(dict(result=data)) + if not self.negate: + return self.func(value) + else: + return not self.func(value) + + def _cast_value(self, value): + if value in BOOLEANS_TRUE: + return True + elif value in BOOLEANS_FALSE: + return False + elif re.match(r"^\d+\.d+$", value): + return float(value) + elif re.match(r"^\d+$", value): + return int(value) + else: + return text_type(value) + + def _func(self, oper): + for func, operators in self.OPERATORS.items(): + if oper in operators: + return getattr(self, func) + raise AttributeError("unknown operator: %s" % oper) + + def get_value(self, result): + try: + return self.get_json(result) + except (IndexError, TypeError, AttributeError): + msg = "unable to apply conditional to result" + raise FailedConditionalError(msg, self.raw) + + def get_json(self, result): + string = re.sub(r"\[[\'|\"]", ".", self.key) + string = re.sub(r"[\'|\"]\]", ".", string) + parts = re.split(r"\.(?=[^\]]*(?:\[|$))", string) + for part in parts: + match = re.findall(r"\[(\S+?)\]", part) + if match: + key = part[: part.find("[")] + result = result[key] + for m in match: + try: + m = int(m) + except ValueError: + m = str(m) + result = result[m] + else: + result = result.get(part) + return result + + def number(self, value): + if "." in str(value): + return float(value) + else: + return int(value) + + def eq(self, value): + return value == self.value + + def neq(self, value): + return value != self.value + + def gt(self, value): + return self.number(value) > self.value + + def ge(self, value): + return self.number(value) >= self.value + + def lt(self, value): + return self.number(value) < self.value + + def le(self, value): + return self.number(value) <= self.value + + def contains(self, value): + return str(self.value) in value + + def matches(self, value): + match = re.search(self.value, value, re.M) + return match is not None diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py new file mode 100644 index 0000000..64eca15 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py @@ -0,0 +1,686 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Networking tools for network modules only + +import re +import ast +import operator +import socket +import json + +from itertools import chain + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils import basic +from ansible.module_utils.parsing.convert_bool import boolean + +# Backwards compatibility for 3rd party modules +# TODO(pabelanger): With move to ansible.netcommon, we should clean this code +# up and have modules import directly themself. +from ansible.module_utils.common.network import ( # noqa: F401 + to_bits, + is_netmask, + is_masklen, + to_netmask, + to_masklen, + to_subnet, + to_ipv6_network, + VALID_MASKS, +) + +try: + from jinja2 import Environment, StrictUndefined + from jinja2.exceptions import UndefinedError + + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False + + +OPERATORS = frozenset(["ge", "gt", "eq", "neq", "lt", "le"]) +ALIASES = frozenset( + [("min", "ge"), ("max", "le"), ("exactly", "eq"), ("neq", "ne")] +) + + +def to_list(val): + if isinstance(val, (list, tuple, set)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +def to_lines(stdout): + for item in stdout: + if isinstance(item, string_types): + item = to_text(item).split("\n") + yield item + + +def transform_commands(module): + transform = ComplexList( + dict( + command=dict(key=True), + output=dict(), + prompt=dict(type="list"), + answer=dict(type="list"), + newline=dict(type="bool", default=True), + sendonly=dict(type="bool", default=False), + check_all=dict(type="bool", default=False), + ), + module, + ) + + return transform(module.params["commands"]) + + +def sort_list(val): + if isinstance(val, list): + return sorted(val) + return val + + +class Entity(object): + """Transforms a dict to with an argument spec + + This class will take a dict and apply an Ansible argument spec to the + values. The resulting dict will contain all of the keys in the param + with appropriate values set. + + Example:: + + argument_spec = dict( + command=dict(key=True), + display=dict(default='text', choices=['text', 'json']), + validate=dict(type='bool') + ) + transform = Entity(module, argument_spec) + value = dict(command='foo') + result = transform(value) + print result + {'command': 'foo', 'display': 'text', 'validate': None} + + Supported argument spec: + * key - specifies how to map a single value to a dict + * read_from - read and apply the argument_spec from the module + * required - a value is required + * type - type of value (uses AnsibleModule type checker) + * fallback - implements fallback function + * choices - set of valid options + * default - default value + """ + + def __init__( + self, module, attrs=None, args=None, keys=None, from_argspec=False + ): + args = [] if args is None else args + + self._attributes = attrs or {} + self._module = module + + for arg in args: + self._attributes[arg] = dict() + if from_argspec: + self._attributes[arg]["read_from"] = arg + if keys and arg in keys: + self._attributes[arg]["key"] = True + + self.attr_names = frozenset(self._attributes.keys()) + + _has_key = False + + for name, attr in iteritems(self._attributes): + if attr.get("read_from"): + if attr["read_from"] not in self._module.argument_spec: + module.fail_json( + msg="argument %s does not exist" % attr["read_from"] + ) + spec = self._module.argument_spec.get(attr["read_from"]) + for key, value in iteritems(spec): + if key not in attr: + attr[key] = value + + if attr.get("key"): + if _has_key: + module.fail_json(msg="only one key value can be specified") + _has_key = True + attr["required"] = True + + def serialize(self): + return self._attributes + + def to_dict(self, value): + obj = {} + for name, attr in iteritems(self._attributes): + if attr.get("key"): + obj[name] = value + else: + obj[name] = attr.get("default") + return obj + + def __call__(self, value, strict=True): + if not isinstance(value, dict): + value = self.to_dict(value) + + if strict: + unknown = set(value).difference(self.attr_names) + if unknown: + self._module.fail_json( + msg="invalid keys: %s" % ",".join(unknown) + ) + + for name, attr in iteritems(self._attributes): + if value.get(name) is None: + value[name] = attr.get("default") + + if attr.get("fallback") and not value.get(name): + fallback = attr.get("fallback", (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + value[name] = fallback_strategy( + *fallback_args, **fallback_kwargs + ) + except basic.AnsibleFallbackNotFound: + continue + + if attr.get("required") and value.get(name) is None: + self._module.fail_json( + msg="missing required attribute %s" % name + ) + + if "choices" in attr: + if value[name] not in attr["choices"]: + self._module.fail_json( + msg="%s must be one of %s, got %s" + % (name, ", ".join(attr["choices"]), value[name]) + ) + + if value[name] is not None: + value_type = attr.get("type", "str") + type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[ + value_type + ] + type_checker(value[name]) + elif value.get(name): + value[name] = self._module.params[name] + + return value + + +class EntityCollection(Entity): + """Extends ```Entity``` to handle a list of dicts """ + + def __call__(self, iterable, strict=True): + if iterable is None: + iterable = [ + super(EntityCollection, self).__call__( + self._module.params, strict + ) + ] + + if not isinstance(iterable, (list, tuple)): + self._module.fail_json(msg="value must be an iterable") + + return [ + (super(EntityCollection, self).__call__(i, strict)) + for i in iterable + ] + + +# these two are for backwards compatibility and can be removed once all of the +# modules that use them are updated +class ComplexDict(Entity): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexDict, self).__init__(module, attrs, *args, **kwargs) + + +class ComplexList(EntityCollection): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexList, self).__init__(module, attrs, *args, **kwargs) + + +def dict_diff(base, comparable): + """ Generate a dict object of differences + + This function will compare two dict objects and return the difference + between them as a dict object. For scalar values, the key will reflect + the updated value. If the key does not exist in `comparable`, then then no + key will be returned. For lists, the value in comparable will wholly replace + the value in base for the key. For dicts, the returned value will only + return keys that are different. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + + :returns: new dict object with differences + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(comparable, dict): + if comparable is None: + comparable = dict() + else: + raise AssertionError("`comparable` must be of type <dict>") + + updates = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + item = comparable.get(key) + if item is not None: + sub_diff = dict_diff(value, comparable[key]) + if sub_diff: + updates[key] = sub_diff + else: + comparable_value = comparable.get(key) + if comparable_value is not None: + if sort_list(base[key]) != sort_list(comparable_value): + updates[key] = comparable_value + + for key in set(comparable.keys()).difference(base.keys()): + updates[key] = comparable.get(key) + + return updates + + +def dict_merge(base, other): + """ Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. If the value is a list object, the two lists will + be combined and duplicate entries removed. + + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(other, dict): + raise AssertionError("`other` must be of type <dict>") + + combined = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + if isinstance(other[key], Mapping): + combined[key] = dict_merge(value, other[key]) + else: + combined[key] = other[key] + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if item is not None: + try: + combined[key] = list(set(chain(value, item))) + except TypeError: + value.extend([i for i in item if i not in value]) + combined[key] = value + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) + + return combined + + +def param_list_to_dict(param_list, unique_key="name", remove_key=True): + """Rotates a list of dictionaries to be a dictionary of dictionaries. + + :param param_list: The aforementioned list of dictionaries + :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value + behind this key will be the key each dictionary can be found at in the new root dictionary + :param remove_key: If True, remove unique_key from the individual dictionaries before returning. + """ + param_dict = {} + for params in param_list: + params = params.copy() + if remove_key: + name = params.pop(unique_key) + else: + name = params.get(unique_key) + param_dict[name] = params + + return param_dict + + +def conditional(expr, val, cast=None): + match = re.match(r"^(.+)\((.+)\)$", str(expr), re.I) + if match: + op, arg = match.groups() + else: + op = "eq" + if " " in str(expr): + raise AssertionError("invalid expression: cannot contain spaces") + arg = expr + + if cast is None and val is not None: + arg = type(val)(arg) + elif callable(cast): + arg = cast(arg) + val = cast(val) + + op = next((oper for alias, oper in ALIASES if op == alias), op) + + if not hasattr(operator, op) and op not in OPERATORS: + raise ValueError("unknown operator: %s" % op) + + func = getattr(operator, op) + return func(val, arg) + + +def ternary(value, true_val, false_val): + """ value ? true_val : false_val """ + if value: + return true_val + else: + return false_val + + +def remove_default_spec(spec): + for item in spec: + if "default" in spec[item]: + del spec[item]["default"] + + +def validate_ip_address(address): + try: + socket.inet_aton(address) + except socket.error: + return False + return address.count(".") == 3 + + +def validate_ip_v6_address(address): + try: + socket.inet_pton(socket.AF_INET6, address) + except socket.error: + return False + return True + + +def validate_prefix(prefix): + if prefix and not 0 <= int(prefix) <= 32: + return False + return True + + +def load_provider(spec, args): + provider = args.get("provider") or {} + for key, value in iteritems(spec): + if key not in provider: + if "fallback" in value: + provider[key] = _fallback(value["fallback"]) + elif "default" in value: + provider[key] = value["default"] + else: + provider[key] = None + if "authorize" in provider: + # Coerce authorize to provider if a string has somehow snuck in. + provider["authorize"] = boolean(provider["authorize"] or False) + args["provider"] = provider + return provider + + +def _fallback(fallback): + strategy = fallback[0] + args = [] + kwargs = {} + + for item in fallback[1:]: + if isinstance(item, dict): + kwargs = item + else: + args = item + try: + return strategy(*args, **kwargs) + except basic.AnsibleFallbackNotFound: + pass + + +def generate_dict(spec): + """ + Generate dictionary which is in sync with argspec + + :param spec: A dictionary that is the argspec of the module + :rtype: A dictionary + :returns: A dictionary in sync with argspec with default value + """ + obj = {} + if not spec: + return obj + + for key, val in iteritems(spec): + if "default" in val: + dct = {key: val["default"]} + elif "type" in val and val["type"] == "dict": + dct = {key: generate_dict(val["options"])} + else: + dct = {key: None} + obj.update(dct) + return obj + + +def parse_conf_arg(cfg, arg): + """ + Parse config based on argument + + :param cfg: A text string which is a line of configuration. + :param arg: A text string which is to be matched. + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r"%s (.+)(\n|$)" % arg, cfg, re.M) + if match: + result = match.group(1).strip() + else: + result = None + return result + + +def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str="no"): + """ + Parse config based on command + + :param cfg: A text string which is a line of configuration. + :param cmd: A text string which is the command to be matched + :param res1: A text string to be returned if the command is present + :param res2: A text string to be returned if the negate command + is present + :param delete_str: A text string to identify the start of the + negate command + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r"\n\s+%s(\n|$)" % cmd, cfg) + if match: + return res1 + if res2 is not None: + match = re.search(r"\n\s+%s %s(\n|$)" % (delete_str, cmd), cfg) + if match: + return res2 + return None + + +def get_xml_conf_arg(cfg, path, data="text"): + """ + :param cfg: The top level configuration lxml Element tree object + :param path: The relative xpath w.r.t to top level element (cfg) + to be searched in the xml hierarchy + :param data: The type of data to be returned for the matched xml node. + Valid values are text, tag, attrib, with default as text. + :return: Returns the required type for the matched xml node or else None + """ + match = cfg.xpath(path) + if len(match): + if data == "tag": + result = getattr(match[0], "tag") + elif data == "attrib": + result = getattr(match[0], "attrib") + else: + result = getattr(match[0], "text") + else: + result = None + return result + + +def remove_empties(cfg_dict): + """ + Generate final config dictionary + + :param cfg_dict: A dictionary parsed in the facts system + :rtype: A dictionary + :returns: A dictionary by eliminating keys that have null values + """ + final_cfg = {} + if not cfg_dict: + return final_cfg + + for key, val in iteritems(cfg_dict): + dct = None + if isinstance(val, dict): + child_val = remove_empties(val) + if child_val: + dct = {key: child_val} + elif ( + isinstance(val, list) + and val + and all([isinstance(x, dict) for x in val]) + ): + child_val = [remove_empties(x) for x in val] + if child_val: + dct = {key: child_val} + elif val not in [None, [], {}, (), ""]: + dct = {key: val} + if dct: + final_cfg.update(dct) + return final_cfg + + +def validate_config(spec, data): + """ + Validate if the input data against the AnsibleModule spec format + :param spec: Ansible argument spec + :param data: Data to be validated + :return: + """ + params = basic._ANSIBLE_ARGS + basic._ANSIBLE_ARGS = to_bytes(json.dumps({"ANSIBLE_MODULE_ARGS": data})) + validated_data = basic.AnsibleModule(spec).params + basic._ANSIBLE_ARGS = params + return validated_data + + +def search_obj_in_list(name, lst, key="name"): + if not lst: + return None + else: + for item in lst: + if item.get(key) == name: + return item + + +class Template: + def __init__(self): + if not HAS_JINJA2: + raise ImportError( + "jinja2 is required but does not appear to be installed. " + "It can be installed using `pip install jinja2`" + ) + + self.env = Environment(undefined=StrictUndefined) + self.env.filters.update({"ternary": ternary}) + + def __call__(self, value, variables=None, fail_on_undefined=True): + variables = variables or {} + + if not self.contains_vars(value): + return value + + try: + value = self.env.from_string(value).render(variables) + except UndefinedError: + if not fail_on_undefined: + return None + raise + + if value: + try: + return ast.literal_eval(value) + except Exception: + return str(value) + else: + return None + + def contains_vars(self, data): + if isinstance(data, string_types): + for marker in ( + self.env.block_start_string, + self.env.variable_start_string, + self.env.comment_start_string, + ): + if marker in data: + return True + return False diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py new file mode 100644 index 0000000..1f03299 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py @@ -0,0 +1,147 @@ +# +# (c) 2018 Red Hat, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +import json + +from copy import deepcopy +from contextlib import contextmanager + +try: + from lxml.etree import fromstring, tostring +except ImportError: + from xml.etree.ElementTree import fromstring, tostring + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import ( + NetconfConnection, +) + + +IGNORE_XML_ATTRIBUTE = () + + +def get_connection(module): + if hasattr(module, "_netconf_connection"): + return module._netconf_connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api == "netconf": + module._netconf_connection = NetconfConnection(module._socket_path) + else: + module.fail_json(msg="Invalid connection type %s" % network_api) + + return module._netconf_connection + + +def get_capabilities(module): + if hasattr(module, "_netconf_capabilities"): + return module._netconf_capabilities + + capabilities = Connection(module._socket_path).get_capabilities() + module._netconf_capabilities = json.loads(capabilities) + return module._netconf_capabilities + + +def lock_configuration(module, target=None): + conn = get_connection(module) + return conn.lock(target=target) + + +def unlock_configuration(module, target=None): + conn = get_connection(module) + return conn.unlock(target=target) + + +@contextmanager +def locked_config(module, target=None): + try: + lock_configuration(module, target=target) + yield + finally: + unlock_configuration(module, target=target) + + +def get_config(module, source, filter=None, lock=False): + conn = get_connection(module) + try: + locked = False + if lock: + conn.lock(target=source) + locked = True + response = conn.get_config(source=source, filter=filter) + + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + finally: + if locked: + conn.unlock(target=source) + + return response + + +def get(module, filter, lock=False): + conn = get_connection(module) + try: + locked = False + if lock: + conn.lock(target="running") + locked = True + + response = conn.get(filter=filter) + + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + finally: + if locked: + conn.unlock(target="running") + + return response + + +def dispatch(module, request): + conn = get_connection(module) + try: + response = conn.dispatch(request) + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + return response + + +def sanitize_xml(data): + tree = fromstring( + to_bytes(deepcopy(data), errors="surrogate_then_replace") + ) + for element in tree.getiterator(): + # remove attributes + attribute = element.attrib + if attribute: + for key in list(attribute): + if key not in IGNORE_XML_ATTRIBUTE: + attribute.pop(key) + return to_text(tostring(tree), errors="surrogate_then_replace").strip() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py new file mode 100644 index 0000000..fba46be --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py @@ -0,0 +1,61 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2018 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from ansible.module_utils.connection import Connection + + +def get(module, path=None, content=None, fields=None, output="json"): + if path is None: + raise ValueError("path value must be provided") + if content: + path += "?" + "content=%s" % content + if fields: + path += "?" + "field=%s" % fields + + accept = None + if output == "xml": + accept = "application/yang-data+xml" + + connection = Connection(module._socket_path) + return connection.send_request( + None, path=path, method="GET", accept=accept + ) + + +def edit_config(module, path=None, content=None, method="GET", format="json"): + if path is None: + raise ValueError("path value must be provided") + + content_type = None + if format == "xml": + content_type = "application/yang-data+xml" + + connection = Connection(module._socket_path) + return connection.send_request( + content, path=path, method=method, content_type=content_type + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py new file mode 100644 index 0000000..c1384c1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py @@ -0,0 +1,444 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: cli_config +author: Trishna Guha (@trishnaguha) +notes: +- The commands will be returned only for platforms that do not support onbox diff. + The C(--diff) option with the playbook will return the difference in configuration + for devices that has support for onbox diff +short_description: Push text based configuration to network devices over network_cli +description: +- This module provides platform agnostic way of pushing text based configuration to + network devices over network_cli connection plugin. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + config: + description: + - The config to be pushed to the network device. This argument is mutually exclusive + with C(rollback) and either one of the option should be given as input. The + config should have indentation that the device uses. + type: str + commit: + description: + - The C(commit) argument instructs the module to push the configuration to the + device. This is mapped to module check mode. + type: bool + replace: + description: + - If the C(replace) argument is set to C(yes), it will replace the entire running-config + of the device with the C(config) argument value. For devices that support replacing + running configuration from file on device like NXOS/JUNOS, the C(replace) argument + takes path to the file on the device that will be used for replacing the entire + running-config. The value of C(config) option should be I(None) for such devices. + Nexus 9K devices only support replace. Use I(net_put) or I(nxos_file_copy) in + case of NXOS module to copy the flat file to remote device and then use set + the fullpath to this argument. + type: str + backup: + description: + - This argument will cause the module to create a full backup of the current running + config from the remote device before any changes are made. If the C(backup_options) + value is not given, the backup file is written to the C(backup) folder in the + playbook root directory or role root directory, if playbook is part of an ansible + role. If the directory does not exist, it is created. + type: bool + default: 'no' + rollback: + description: + - The C(rollback) argument instructs the module to rollback the current configuration + to the identifier specified in the argument. If the specified rollback identifier + does not exist on the remote device, the module will fail. To rollback to the + most recent commit, set the C(rollback) argument to 0. This option is mutually + exclusive with C(config). + commit_comment: + description: + - The C(commit_comment) argument specifies a text string to be used when committing + the configuration. If the C(commit) argument is set to False, this argument + is silently ignored. This argument is only valid for the platforms that support + commit operation with comment. + type: str + defaults: + description: + - The I(defaults) argument will influence how the running-config is collected + from the device. When the value is set to true, the command used to collect + the running-config is append with the all keyword. When the value is set to + false, the command is issued without the all keyword. + default: 'no' + type: bool + multiline_delimiter: + description: + - This argument is used when pushing a multiline configuration element to the + device. It specifies the character to use as the delimiting character. This + only applies to the configuration action. + type: str + diff_replace: + description: + - Instructs the module on the way to perform the configuration on the device. + If the C(diff_replace) argument is set to I(line) then the modified lines are + pushed to the device in configuration mode. If the argument is set to I(block) + then the entire command block is pushed to the device in configuration mode + if any line is not correct. Note that this parameter will be ignored if the + platform has onbox diff support. + choices: + - line + - block + - config + diff_match: + description: + - Instructs the module on the way to perform the matching of the set of commands + against the current device config. If C(diff_match) is set to I(line), commands + are matched line by line. If C(diff_match) is set to I(strict), command lines + are matched with respect to position. If C(diff_match) is set to I(exact), command + lines must be an equal match. Finally, if C(diff_match) is set to I(none), the + module will not attempt to compare the source configuration with the running + configuration on the remote device. Note that this parameter will be ignored + if the platform has onbox diff support. + choices: + - line + - strict + - exact + - none + diff_ignore_lines: + description: + - Use this argument to specify one or more lines that should be ignored during + the diff. This is used for lines in the configuration that are automatically + updated by the system. This argument takes a list of regular expressions or + exact line matches. Note that this parameter will be ignored if the platform + has onbox diff support. + backup_options: + description: + - This is a dict object containing configurable options related to backup file + path. The value of this option is read only when C(backup) is set to I(yes), + if C(backup) is set to I(no) this option will be silently ignored. + suboptions: + filename: + description: + - The filename to be used to store the backup configuration. If the filename + is not given it will be generated based on the hostname, current time and + date in format defined by <hostname>_config.<current-date>@<current-time> + dir_path: + description: + - This option provides the path ending with directory name in which the backup + configuration file will be stored. If the directory does not exist it will + be first created and the filename is either the value of C(filename) or + default filename as described in C(filename) options description. If the + path value is not given in that case a I(backup) directory will be created + in the current working directory and backup configuration will be copied + in C(filename) within I(backup) directory. + type: path + type: dict +""" + +EXAMPLES = """ +- name: configure device with config + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + +- name: multiline config + cli_config: + config: | + hostname foo + feature nxapi + +- name: configure device with config with defaults enabled + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + defaults: yes + +- name: Use diff_match + cli_config: + config: "{{ lookup('file', 'interface_config') }}" + diff_match: none + +- name: nxos replace config + cli_config: + replace: 'bootflash:nxoscfg' + +- name: junos replace config + cli_config: + replace: '/var/home/ansible/junos01.cfg' + +- name: commit with comment + cli_config: + config: set system host-name foo + commit_comment: this is a test + +- name: configurable backup path + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + backup: yes + backup_options: + filename: backup.cfg + dir_path: /home/user +""" + +RETURN = """ +commands: + description: The set of commands that will be pushed to the remote device + returned: always + type: list + sample: ['interface Loopback999', 'no shutdown'] +backup_path: + description: The full path to the backup file + returned: when backup is yes + type: str + sample: /playbooks/ansible/backup/hostname_config.2016-07-16@22:28:34 +""" + +import json + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.connection import Connection +from ansible.module_utils._text import to_text + + +def validate_args(module, device_operations): + """validate param if it is supported on the platform + """ + feature_list = [ + "replace", + "rollback", + "commit_comment", + "defaults", + "multiline_delimiter", + "diff_replace", + "diff_match", + "diff_ignore_lines", + ] + + for feature in feature_list: + if module.params[feature]: + supports_feature = device_operations.get("supports_%s" % feature) + if supports_feature is None: + module.fail_json( + "This platform does not specify whether %s is supported or not. " + "Please report an issue against this platform's cliconf plugin." + % feature + ) + elif not supports_feature: + module.fail_json( + msg="Option %s is not supported on this platform" % feature + ) + + +def run( + module, device_operations, connection, candidate, running, rollback_id +): + result = {} + resp = {} + config_diff = [] + banner_diff = {} + + replace = module.params["replace"] + commit_comment = module.params["commit_comment"] + multiline_delimiter = module.params["multiline_delimiter"] + diff_replace = module.params["diff_replace"] + diff_match = module.params["diff_match"] + diff_ignore_lines = module.params["diff_ignore_lines"] + + commit = not module.check_mode + + if replace in ("yes", "true", "True"): + replace = True + elif replace in ("no", "false", "False"): + replace = False + + if ( + replace is not None + and replace not in [True, False] + and candidate is not None + ): + module.fail_json( + msg="Replace value '%s' is a configuration file path already" + " present on the device. Hence 'replace' and 'config' options" + " are mutually exclusive" % replace + ) + + if rollback_id is not None: + resp = connection.rollback(rollback_id, commit) + if "diff" in resp: + result["changed"] = True + + elif device_operations.get("supports_onbox_diff"): + if diff_replace: + module.warn( + "diff_replace is ignored as the device supports onbox diff" + ) + if diff_match: + module.warn( + "diff_mattch is ignored as the device supports onbox diff" + ) + if diff_ignore_lines: + module.warn( + "diff_ignore_lines is ignored as the device supports onbox diff" + ) + + if candidate and not isinstance(candidate, list): + candidate = candidate.strip("\n").splitlines() + + kwargs = { + "candidate": candidate, + "commit": commit, + "replace": replace, + "comment": commit_comment, + } + resp = connection.edit_config(**kwargs) + + if "diff" in resp: + result["changed"] = True + + elif device_operations.get("supports_generate_diff"): + kwargs = {"candidate": candidate, "running": running} + if diff_match: + kwargs.update({"diff_match": diff_match}) + if diff_replace: + kwargs.update({"diff_replace": diff_replace}) + if diff_ignore_lines: + kwargs.update({"diff_ignore_lines": diff_ignore_lines}) + + diff_response = connection.get_diff(**kwargs) + + config_diff = diff_response.get("config_diff") + banner_diff = diff_response.get("banner_diff") + + if config_diff: + if isinstance(config_diff, list): + candidate = config_diff + else: + candidate = config_diff.splitlines() + + kwargs = { + "candidate": candidate, + "commit": commit, + "replace": replace, + "comment": commit_comment, + } + if commit: + connection.edit_config(**kwargs) + result["changed"] = True + result["commands"] = config_diff.split("\n") + + if banner_diff: + candidate = json.dumps(banner_diff) + + kwargs = {"candidate": candidate, "commit": commit} + if multiline_delimiter: + kwargs.update({"multiline_delimiter": multiline_delimiter}) + if commit: + connection.edit_banner(**kwargs) + result["changed"] = True + + if module._diff: + if "diff" in resp: + result["diff"] = {"prepared": resp["diff"]} + else: + diff = "" + if config_diff: + if isinstance(config_diff, list): + diff += "\n".join(config_diff) + else: + diff += config_diff + if banner_diff: + diff += json.dumps(banner_diff) + result["diff"] = {"prepared": diff} + + return result + + +def main(): + """main entry point for execution + """ + backup_spec = dict(filename=dict(), dir_path=dict(type="path")) + argument_spec = dict( + backup=dict(default=False, type="bool"), + backup_options=dict(type="dict", options=backup_spec), + config=dict(type="str"), + commit=dict(type="bool"), + replace=dict(type="str"), + rollback=dict(type="int"), + commit_comment=dict(type="str"), + defaults=dict(default=False, type="bool"), + multiline_delimiter=dict(type="str"), + diff_replace=dict(choices=["line", "block", "config"]), + diff_match=dict(choices=["line", "strict", "exact", "none"]), + diff_ignore_lines=dict(type="list"), + ) + + mutually_exclusive = [("config", "rollback")] + required_one_of = [["backup", "config", "rollback"]] + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=mutually_exclusive, + required_one_of=required_one_of, + supports_check_mode=True, + ) + + result = {"changed": False} + + connection = Connection(module._socket_path) + capabilities = module.from_json(connection.get_capabilities()) + + if capabilities: + device_operations = capabilities.get("device_operations", dict()) + validate_args(module, device_operations) + else: + device_operations = dict() + + if module.params["defaults"]: + if "get_default_flag" in capabilities.get("rpc"): + flags = connection.get_default_flag() + else: + flags = "all" + else: + flags = [] + + candidate = module.params["config"] + candidate = ( + to_text(candidate, errors="surrogate_then_replace") + if candidate + else None + ) + running = connection.get_config(flags=flags) + rollback_id = module.params["rollback"] + + if module.params["backup"]: + result["__backup__"] = running + + if candidate or rollback_id or module.params["replace"]: + try: + result.update( + run( + module, + device_operations, + connection, + candidate, + running, + rollback_id, + ) + ) + except Exception as exc: + module.fail_json(msg=to_text(exc)) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py new file mode 100644 index 0000000..f0910f5 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py @@ -0,0 +1,71 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: net_get +author: Deepak Agrawal (@dagrawal) +short_description: Copy a file from a network device to Ansible Controller +description: +- This module provides functionality to copy file from network device to ansible controller. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + src: + description: + - Specifies the source file. The path to the source file can either be the full + path on the network device or a relative path as per path supported by destination + network device. + required: true + protocol: + description: + - Protocol used to transfer file. + default: scp + choices: + - scp + - sftp + dest: + description: + - Specifies the destination file. The path to the destination file can either + be the full path on the Ansible control host or a relative path from the playbook + or role root directory. + default: + - Same filename as specified in I(src). The path will be playbook root or role + root directory if playbook is part of a role. +requirements: +- scp +notes: +- Some devices need specific configurations to be enabled before scp can work These + configuration should be pre-configured before using this module e.g ios - C(ip scp + server enable). +- User privilege to do scp on network device should be pre-configured e.g. ios - need + user privilege 15 by default for allowing scp. +- Default destination of source file. +""" + +EXAMPLES = """ +- name: copy file from the network device to Ansible controller + net_get: + src: running_cfg_ios1.txt + +- name: copy file from ios to common location at /tmp + net_get: + src: running_cfg_sw1.txt + dest : /tmp/ios1.txt +""" + +RETURN = """ +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py new file mode 100644 index 0000000..2fc4a98 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py @@ -0,0 +1,82 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: net_put +author: Deepak Agrawal (@dagrawal) +short_description: Copy a file from Ansible Controller to a network device +description: +- This module provides functionality to copy file from Ansible controller to network + devices. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + src: + description: + - Specifies the source file. The path to the source file can either be the full + path on the Ansible control host or a relative path from the playbook or role + root directory. + required: true + protocol: + description: + - Protocol used to transfer file. + default: scp + choices: + - scp + - sftp + dest: + description: + - Specifies the destination file. The path to destination file can either be the + full path or relative path as supported by network_os. + default: + - Filename from src and at default directory of user shell on network_os. + required: false + mode: + description: + - Set the file transfer mode. If mode is set to I(text) then I(src) file will + go through Jinja2 template engine to replace any vars if present in the src + file. If mode is set to I(binary) then file will be copied as it is to destination + device. + default: binary + choices: + - binary + - text +requirements: +- scp +notes: +- Some devices need specific configurations to be enabled before scp can work These + configuration should be pre-configured before using this module e.g ios - C(ip scp + server enable). +- User privilege to do scp on network device should be pre-configured e.g. ios - need + user privilege 15 by default for allowing scp. +- Default destination of source file. +""" + +EXAMPLES = """ +- name: copy file from ansible controller to a network device + net_put: + src: running_cfg_ios1.txt + +- name: copy file at root dir of flash in slot 3 of sw1(ios) + net_put: + src: running_cfg_sw1.txt + protocol: sftp + dest : flash3:/running_cfg_sw1.txt +""" + +RETURN = """ +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py new file mode 100644 index 0000000..e9332f2 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py @@ -0,0 +1,70 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +netconf: default +short_description: Use default netconf plugin to run standard netconf commands as + per RFC +description: +- This default plugin provides low level abstraction apis for sending and receiving + netconf commands as per Netconf RFC specification. +options: + ncclient_device_handler: + type: str + default: default + description: + - Specifies the ncclient device handler name for network os that support default + netconf implementation as per Netconf RFC specification. To identify the ncclient + device handler name refer ncclient library documentation. +""" +import json + +from ansible.module_utils._text import to_text +from ansible.plugins.netconf import NetconfBase + + +class Netconf(NetconfBase): + def get_text(self, ele, tag): + try: + return to_text( + ele.find(tag).text, errors="surrogate_then_replace" + ).strip() + except AttributeError: + pass + + def get_device_info(self): + device_info = dict() + device_info["network_os"] = "default" + return device_info + + def get_capabilities(self): + result = dict() + result["rpc"] = self.get_base_rpc() + result["network_api"] = "netconf" + result["device_info"] = self.get_device_info() + result["server_capabilities"] = [c for c in self.m.server_capabilities] + result["client_capabilities"] = [c for c in self.m.client_capabilities] + result["session_id"] = self.m.session_id + result["device_operations"] = self.get_device_operations( + result["server_capabilities"] + ) + return json.dumps(result) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py new file mode 100644 index 0000000..a38a775 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/plugin_utils/connection_base.py @@ -0,0 +1,185 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# (c) 2015 Toshio Kuratomi <tkuratomi@ansible.com> +# (c) 2017, Peter Sprygada <psprygad@redhat.com> +# (c) 2017 Ansible Project +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os + +from ansible import constants as C +from ansible.plugins.connection import ConnectionBase +from ansible.plugins.loader import connection_loader +from ansible.utils.display import Display +from ansible.utils.path import unfrackpath + +display = Display() + + +__all__ = ["NetworkConnectionBase"] + +BUFSIZE = 65536 + + +class NetworkConnectionBase(ConnectionBase): + """ + A base class for network-style connections. + """ + + force_persistence = True + # Do not use _remote_is_local in other connections + _remote_is_local = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(NetworkConnectionBase, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._messages = [] + self._conn_closed = False + + self._network_os = self._play_context.network_os + + self._local = connection_loader.get("local", play_context, "/dev/null") + self._local.set_options() + + self._sub_plugin = {} + self._cached_variables = (None, None, None) + + # reconstruct the socket_path and set instance values accordingly + self._ansible_playbook_pid = kwargs.get("ansible_playbook_pid") + self._update_connection_state() + + def __getattr__(self, name): + try: + return self.__dict__[name] + except KeyError: + if not name.startswith("_"): + plugin = self._sub_plugin.get("obj") + if plugin: + method = getattr(plugin, name, None) + if method is not None: + return method + raise AttributeError( + "'%s' object has no attribute '%s'" + % (self.__class__.__name__, name) + ) + + def exec_command(self, cmd, in_data=None, sudoable=True): + return self._local.exec_command(cmd, in_data, sudoable) + + def queue_message(self, level, message): + """ + Adds a message to the queue of messages waiting to be pushed back to the controller process. + + :arg level: A string which can either be the name of a method in display, or 'log'. When + the messages are returned to task_executor, a value of log will correspond to + ``display.display(message, log_only=True)``, while another value will call ``display.[level](message)`` + """ + self._messages.append((level, message)) + + def pop_messages(self): + messages, self._messages = self._messages, [] + return messages + + def put_file(self, in_path, out_path): + """Transfer a file from local to remote""" + return self._local.put_file(in_path, out_path) + + def fetch_file(self, in_path, out_path): + """Fetch a file from remote to local""" + return self._local.fetch_file(in_path, out_path) + + def reset(self): + """ + Reset the connection + """ + if self._socket_path: + self.queue_message( + "vvvv", + "resetting persistent connection for socket_path %s" + % self._socket_path, + ) + self.close() + self.queue_message("vvvv", "reset call on connection instance") + + def close(self): + self._conn_closed = True + if self._connected: + self._connected = False + + def get_options(self, hostvars=None): + options = super(NetworkConnectionBase, self).get_options( + hostvars=hostvars + ) + + if ( + self._sub_plugin.get("obj") + and self._sub_plugin.get("type") != "external" + ): + try: + options.update( + self._sub_plugin["obj"].get_options(hostvars=hostvars) + ) + except AttributeError: + pass + + return options + + def set_options(self, task_keys=None, var_options=None, direct=None): + super(NetworkConnectionBase, self).set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + if self.get_option("persistent_log_messages"): + warning = ( + "Persistent connection logging is enabled for %s. This will log ALL interactions" + % self._play_context.remote_addr + ) + logpath = getattr(C, "DEFAULT_LOG_PATH") + if logpath is not None: + warning += " to %s" % logpath + self.queue_message( + "warning", + "%s and WILL NOT redact sensitive configuration like passwords. USE WITH CAUTION!" + % warning, + ) + + if ( + self._sub_plugin.get("obj") + and self._sub_plugin.get("type") != "external" + ): + try: + self._sub_plugin["obj"].set_options( + task_keys=task_keys, var_options=var_options, direct=direct + ) + except AttributeError: + pass + + def _update_connection_state(self): + """ + Reconstruct the connection socket_path and check if it exists + + If the socket path exists then the connection is active and set + both the _socket_path value to the path and the _connected value + to True. If the socket path doesn't exist, leave the socket path + value to None and the _connected value to False + """ + ssh = connection_loader.get("ssh", class_only=True) + control_path = ssh._create_control_path( + self._play_context.remote_addr, + self._play_context.port, + self._play_context.remote_user, + self._play_context.connection, + self._ansible_playbook_pid, + ) + + tmp_path = unfrackpath(C.PERSISTENT_CONTROL_PATH_DIR) + socket_path = unfrackpath(control_path % dict(directory=tmp_path)) + + if os.path.exists(socket_path): + self._connected = True + self._socket_path = socket_path + + def _log_messages(self, message): + if self.get_option("persistent_log_messages"): + self.queue_message("log", message) diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py new file mode 100644 index 0000000..e3605d0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py @@ -0,0 +1,133 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import sys +import copy + +from ansible_collections.ansible.netcommon.plugins.action.network import ( + ActionModule as ActionNetworkModule, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + load_provider, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + ios_provider_spec, +) +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionNetworkModule): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + module_name = self._task.action.split(".")[-1] + self._config_module = True if module_name == "ios_config" else False + persistent_connection = self._play_context.connection.split(".")[-1] + warnings = [] + + if persistent_connection == "network_cli": + provider = self._task.args.get("provider", {}) + if any(provider.values()): + display.warning( + "provider is unnecessary when using network_cli and will be ignored" + ) + del self._task.args["provider"] + elif self._play_context.connection == "local": + provider = load_provider(ios_provider_spec, self._task.args) + pc = copy.deepcopy(self._play_context) + pc.connection = "ansible.netcommon.network_cli" + pc.network_os = "cisco.ios.ios" + pc.remote_addr = provider["host"] or self._play_context.remote_addr + pc.port = int(provider["port"] or self._play_context.port or 22) + pc.remote_user = ( + provider["username"] or self._play_context.connection_user + ) + pc.password = provider["password"] or self._play_context.password + pc.private_key_file = ( + provider["ssh_keyfile"] or self._play_context.private_key_file + ) + pc.become = provider["authorize"] or False + if pc.become: + pc.become_method = "enable" + pc.become_pass = provider["auth_pass"] + + connection = self._shared_loader_obj.connection_loader.get( + "ansible.netcommon.persistent", + pc, + sys.stdin, + task_uuid=self._task._uuid, + ) + + # TODO: Remove below code after ansible minimal is cut out + if connection is None: + pc.connection = "network_cli" + pc.network_os = "ios" + connection = self._shared_loader_obj.connection_loader.get( + "persistent", pc, sys.stdin, task_uuid=self._task._uuid + ) + + display.vvv( + "using connection plugin %s (was local)" % pc.connection, + pc.remote_addr, + ) + + command_timeout = ( + int(provider["timeout"]) + if provider["timeout"] + else connection.get_option("persistent_command_timeout") + ) + connection.set_options( + direct={"persistent_command_timeout": command_timeout} + ) + + socket_path = connection.run() + display.vvvv("socket_path: %s" % socket_path, pc.remote_addr) + if not socket_path: + return { + "failed": True, + "msg": "unable to open shell. Please see: " + + "https://docs.ansible.com/ansible/latest/network/user_guide/network_debug_troubleshooting.html#category-unable-to-open-shell", + } + + task_vars["ansible_socket"] = socket_path + warnings.append( + [ + "connection local support for this module is deprecated and will be removed in version 2.14, use connection %s" + % pc.connection + ] + ) + else: + return { + "failed": True, + "msg": "Connection type %s is not valid for this module" + % self._play_context.connection, + } + + result = super(ActionModule, self).run(task_vars=task_vars) + if warnings: + if "warnings" in result: + result["warnings"].extend(warnings) + else: + result["warnings"] = warnings + return result diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py new file mode 100644 index 0000000..feba971 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py @@ -0,0 +1,466 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """ +--- +author: Ansible Networking Team +cliconf: ios +short_description: Use ios cliconf to run command on Cisco IOS platform +description: + - This ios plugin provides low level abstraction apis for + sending and receiving CLI commands from Cisco IOS network devices. +version_added: "2.4" +""" + +import re +import time +import json + +from collections.abc import Mapping + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_text +from ansible.module_utils.six import iteritems +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import ( + NetworkConfig, + dumps, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible.plugins.cliconf import CliconfBase, enable_mode + + +class Cliconf(CliconfBase): + @enable_mode + def get_config(self, source="running", flags=None, format=None): + if source not in ("running", "startup"): + raise ValueError( + "fetching configuration from %s is not supported" % source + ) + + if format: + raise ValueError( + "'format' value %s is not supported for get_config" % format + ) + + if not flags: + flags = [] + if source == "running": + cmd = "show running-config " + else: + cmd = "show startup-config " + + cmd += " ".join(to_list(flags)) + cmd = cmd.strip() + + return self.send_command(cmd) + + def get_diff( + self, + candidate=None, + running=None, + diff_match="line", + diff_ignore_lines=None, + path=None, + diff_replace="line", + ): + """ + Generate diff between candidate and running configuration. If the + remote host supports onbox diff capabilities ie. supports_onbox_diff in that case + candidate and running configurations are not required to be passed as argument. + In case if onbox diff capability is not supported candidate argument is mandatory + and running argument is optional. + :param candidate: The configuration which is expected to be present on remote host. + :param running: The base configuration which is used to generate diff. + :param diff_match: Instructs how to match the candidate configuration with current device configuration + Valid values are 'line', 'strict', 'exact', 'none'. + 'line' - commands are matched line by line + 'strict' - command lines are matched with respect to position + 'exact' - command lines must be an equal match + 'none' - will not compare the candidate configuration with the running configuration + :param diff_ignore_lines: Use this argument to specify one or more lines that should be + ignored during the diff. This is used for lines in the configuration + that are automatically updated by the system. This argument takes + a list of regular expressions or exact line matches. + :param path: The ordered set of parents that uniquely identify the section or hierarchy + the commands should be checked against. If the parents argument + is omitted, the commands are checked against the set of top + level or global commands. + :param diff_replace: Instructs on the way to perform the configuration on the device. + If the replace argument is set to I(line) then the modified lines are + pushed to the device in configuration mode. If the replace argument is + set to I(block) then the entire command block is pushed to the device in + configuration mode if any line is not correct. + :return: Configuration diff in json format. + { + 'config_diff': '', + 'banner_diff': {} + } + + """ + diff = {} + device_operations = self.get_device_operations() + option_values = self.get_option_values() + + if candidate is None and device_operations["supports_generate_diff"]: + raise ValueError( + "candidate configuration is required to generate diff" + ) + + if diff_match not in option_values["diff_match"]: + raise ValueError( + "'match' value %s in invalid, valid values are %s" + % (diff_match, ", ".join(option_values["diff_match"])) + ) + + if diff_replace not in option_values["diff_replace"]: + raise ValueError( + "'replace' value %s in invalid, valid values are %s" + % (diff_replace, ", ".join(option_values["diff_replace"])) + ) + + # prepare candidate configuration + candidate_obj = NetworkConfig(indent=1) + want_src, want_banners = self._extract_banners(candidate) + candidate_obj.load(want_src) + + if running and diff_match != "none": + # running configuration + have_src, have_banners = self._extract_banners(running) + running_obj = NetworkConfig( + indent=1, contents=have_src, ignore_lines=diff_ignore_lines + ) + configdiffobjs = candidate_obj.difference( + running_obj, path=path, match=diff_match, replace=diff_replace + ) + + else: + configdiffobjs = candidate_obj.items + have_banners = {} + + diff["config_diff"] = ( + dumps(configdiffobjs, "commands") if configdiffobjs else "" + ) + banners = self._diff_banners(want_banners, have_banners) + diff["banner_diff"] = banners if banners else {} + return diff + + @enable_mode + def edit_config( + self, candidate=None, commit=True, replace=None, comment=None + ): + resp = {} + operations = self.get_device_operations() + self.check_edit_config_capability( + operations, candidate, commit, replace, comment + ) + + results = [] + requests = [] + if commit: + self.send_command("configure terminal") + for line in to_list(candidate): + if not isinstance(line, Mapping): + line = {"command": line} + + cmd = line["command"] + if cmd != "end" and cmd[0] != "!": + results.append(self.send_command(**line)) + requests.append(cmd) + + self.send_command("end") + else: + raise ValueError("check mode is not supported") + + resp["request"] = requests + resp["response"] = results + return resp + + def edit_macro( + self, candidate=None, commit=True, replace=None, comment=None + ): + """ + ios_config: + lines: "{{ macro_lines }}" + parents: "macro name {{ macro_name }}" + after: '@' + match: line + replace: block + """ + resp = {} + operations = self.get_device_operations() + self.check_edit_config_capability( + operations, candidate, commit, replace, comment + ) + + results = [] + requests = [] + if commit: + commands = "" + self.send_command("config terminal") + time.sleep(0.1) + # first item: macro command + commands += candidate.pop(0) + "\n" + multiline_delimiter = candidate.pop(-1) + for line in candidate: + commands += " " + line + "\n" + commands += multiline_delimiter + "\n" + obj = {"command": commands, "sendonly": True} + results.append(self.send_command(**obj)) + requests.append(commands) + + time.sleep(0.1) + self.send_command("end", sendonly=True) + time.sleep(0.1) + results.append(self.send_command("\n")) + requests.append("\n") + + resp["request"] = requests + resp["response"] = results + return resp + + def get( + self, + command=None, + prompt=None, + answer=None, + sendonly=False, + output=None, + newline=True, + check_all=False, + ): + if not command: + raise ValueError("must provide value of command to execute") + if output: + raise ValueError( + "'output' value %s is not supported for get" % output + ) + + return self.send_command( + command=command, + prompt=prompt, + answer=answer, + sendonly=sendonly, + newline=newline, + check_all=check_all, + ) + + def get_device_info(self): + device_info = {} + + device_info["network_os"] = "ios" + reply = self.get(command="show version") + data = to_text(reply, errors="surrogate_or_strict").strip() + + match = re.search(r"Version (\S+)", data) + if match: + device_info["network_os_version"] = match.group(1).strip(",") + + model_search_strs = [ + r"^[Cc]isco (.+) \(revision", + r"^[Cc]isco (\S+).+bytes of .*memory", + ] + for item in model_search_strs: + match = re.search(item, data, re.M) + if match: + version = match.group(1).split(" ") + device_info["network_os_model"] = version[0] + break + + match = re.search(r"^(.+) uptime", data, re.M) + if match: + device_info["network_os_hostname"] = match.group(1) + + match = re.search(r'image file is "(.+)"', data) + if match: + device_info["network_os_image"] = match.group(1) + + return device_info + + def get_device_operations(self): + return { + "supports_diff_replace": True, + "supports_commit": False, + "supports_rollback": False, + "supports_defaults": True, + "supports_onbox_diff": False, + "supports_commit_comment": False, + "supports_multiline_delimiter": True, + "supports_diff_match": True, + "supports_diff_ignore_lines": True, + "supports_generate_diff": True, + "supports_replace": False, + } + + def get_option_values(self): + return { + "format": ["text"], + "diff_match": ["line", "strict", "exact", "none"], + "diff_replace": ["line", "block"], + "output": [], + } + + def get_capabilities(self): + result = super(Cliconf, self).get_capabilities() + result["rpc"] += [ + "edit_banner", + "get_diff", + "run_commands", + "get_defaults_flag", + ] + result["device_operations"] = self.get_device_operations() + result.update(self.get_option_values()) + return json.dumps(result) + + def edit_banner( + self, candidate=None, multiline_delimiter="@", commit=True + ): + """ + Edit banner on remote device + :param banners: Banners to be loaded in json format + :param multiline_delimiter: Line delimiter for banner + :param commit: Boolean value that indicates if the device candidate + configuration should be pushed in the running configuration or discarded. + :param diff: Boolean flag to indicate if configuration that is applied on remote host should + generated and returned in response or not + :return: Returns response of executing the configuration command received + from remote host + """ + resp = {} + banners_obj = json.loads(candidate) + results = [] + requests = [] + if commit: + for key, value in iteritems(banners_obj): + key += " %s" % multiline_delimiter + self.send_command("config terminal", sendonly=True) + for cmd in [key, value, multiline_delimiter]: + obj = {"command": cmd, "sendonly": True} + results.append(self.send_command(**obj)) + requests.append(cmd) + + self.send_command("end", sendonly=True) + time.sleep(0.1) + results.append(self.send_command("\n")) + requests.append("\n") + + resp["request"] = requests + resp["response"] = results + + return resp + + def run_commands(self, commands=None, check_rc=True): + if commands is None: + raise ValueError("'commands' value is required") + + responses = list() + for cmd in to_list(commands): + if not isinstance(cmd, Mapping): + cmd = {"command": cmd} + + output = cmd.pop("output", None) + if output: + raise ValueError( + "'output' value %s is not supported for run_commands" + % output + ) + + try: + out = self.send_command(**cmd) + except AnsibleConnectionFailure as e: + if check_rc: + raise + out = getattr(e, "err", to_text(e)) + + responses.append(out) + + return responses + + def get_defaults_flag(self): + """ + The method identifies the filter that should be used to fetch running-configuration + with defaults. + :return: valid default filter + """ + out = self.get("show running-config ?") + out = to_text(out, errors="surrogate_then_replace") + + commands = set() + for line in out.splitlines(): + if line.strip(): + commands.add(line.strip().split()[0]) + + if "all" in commands: + return "all" + else: + return "full" + + def set_cli_prompt_context(self): + """ + Make sure we are in the operational cli mode + :return: None + """ + if self._connection.connected: + out = self._connection.get_prompt() + + if out is None: + raise AnsibleConnectionFailure( + message=u"cli prompt is not identified from the last received" + u" response window: %s" + % self._connection._last_recv_window + ) + + if re.search( + r"config.*\)#", + to_text(out, errors="surrogate_then_replace").strip(), + ): + self._connection.queue_message( + "vvvv", "wrong context, sending end to device" + ) + self._connection.send_command("end") + + def _extract_banners(self, config): + banners = {} + banner_cmds = re.findall(r"^banner (\w+)", config, re.M) + for cmd in banner_cmds: + regex = r"banner %s \^C(.+?)(?=\^C)" % cmd + match = re.search(regex, config, re.S) + if match: + key = "banner %s" % cmd + banners[key] = match.group(1).strip() + + for cmd in banner_cmds: + regex = r"banner %s \^C(.+?)(?=\^C)" % cmd + match = re.search(regex, config, re.S) + if match: + config = config.replace(str(match.group(1)), "") + + config = re.sub(r"banner \w+ \^C\^C", "!! banner removed", config) + return config, banners + + def _diff_banners(self, want, have): + candidate = {} + for key, value in iteritems(want): + if value != have.get(key): + candidate[key] = value + return candidate diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py new file mode 100644 index 0000000..ff22d27 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Peter Sprygada <psprygada@ansible.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: + provider: + description: + - B(Deprecated) + - 'Starting with Ansible 2.5 we recommend using C(connection: network_cli).' + - For more information please see the L(IOS Platform Options guide, ../network/user_guide/platform_ios.html). + - HORIZONTALLINE + - A dict object containing connection details. + type: dict + suboptions: + host: + description: + - Specifies the DNS host name or address for connecting to the remote device + over the specified transport. The value of host is used as the destination + address for the transport. + type: str + required: true + port: + description: + - Specifies the port to use when building the connection to the remote device. + type: int + default: 22 + username: + description: + - Configures the username to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME) + will be used instead. + type: str + password: + description: + - Specifies the password to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) + will be used instead. + type: str + timeout: + description: + - Specifies the timeout in seconds for communicating with the network device + for either connecting or sending commands. If the timeout is exceeded before + the operation is completed, the module will error. + type: int + default: 10 + ssh_keyfile: + description: + - Specifies the SSH key to use to authenticate the connection to the remote + device. This value is the path to the key used to authenticate the SSH + session. If the value is not specified in the task, the value of environment + variable C(ANSIBLE_NET_SSH_KEYFILE) will be used instead. + type: path + authorize: + description: + - Instructs the module to enter privileged mode on the remote device before + sending any commands. If not specified, the device will attempt to execute + all commands in non-privileged mode. If the value is not specified in the + task, the value of environment variable C(ANSIBLE_NET_AUTHORIZE) will be + used instead. + type: bool + default: false + auth_pass: + description: + - Specifies the password to use if required to enter privileged mode on the + remote device. If I(authorize) is false, then this argument does nothing. + If the value is not specified in the task, the value of environment variable + C(ANSIBLE_NET_AUTH_PASS) will be used instead. + type: str +notes: +- For more information on using Ansible to manage network devices see the :ref:`Ansible + Network Guide <network_guide>` +- For more information on using Ansible to manage Cisco devices see the `Cisco integration + page <https://www.ansible.com/integrations/networks/cisco>`_. +""" diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py new file mode 100644 index 0000000..6818a0c --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py @@ -0,0 +1,197 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import env_fallback +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible.module_utils.connection import Connection, ConnectionError + +_DEVICE_CONFIGS = {} + +ios_provider_spec = { + "host": dict(), + "port": dict(type="int"), + "username": dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])), + "password": dict( + fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]), no_log=True + ), + "ssh_keyfile": dict( + fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path" + ), + "authorize": dict( + fallback=(env_fallback, ["ANSIBLE_NET_AUTHORIZE"]), type="bool" + ), + "auth_pass": dict( + fallback=(env_fallback, ["ANSIBLE_NET_AUTH_PASS"]), no_log=True + ), + "timeout": dict(type="int"), +} +ios_argument_spec = { + "provider": dict( + type="dict", options=ios_provider_spec, removed_in_version=2.14 + ) +} + + +def get_provider_argspec(): + return ios_provider_spec + + +def get_connection(module): + if hasattr(module, "_ios_connection"): + return module._ios_connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api == "cliconf": + module._ios_connection = Connection(module._socket_path) + else: + module.fail_json(msg="Invalid connection type %s" % network_api) + + return module._ios_connection + + +def get_capabilities(module): + if hasattr(module, "_ios_capabilities"): + return module._ios_capabilities + try: + capabilities = Connection(module._socket_path).get_capabilities() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + module._ios_capabilities = json.loads(capabilities) + return module._ios_capabilities + + +def get_defaults_flag(module): + connection = get_connection(module) + try: + out = connection.get_defaults_flag() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + return to_text(out, errors="surrogate_then_replace").strip() + + +def get_config(module, flags=None): + flags = to_list(flags) + + section_filter = False + if flags and "section" in flags[-1]: + section_filter = True + + flag_str = " ".join(flags) + + try: + return _DEVICE_CONFIGS[flag_str] + except KeyError: + connection = get_connection(module) + try: + out = connection.get_config(flags=flags) + except ConnectionError as exc: + if section_filter: + # Some ios devices don't understand `| section foo` + out = get_config(module, flags=flags[:-1]) + else: + module.fail_json( + msg=to_text(exc, errors="surrogate_then_replace") + ) + cfg = to_text(out, errors="surrogate_then_replace").strip() + _DEVICE_CONFIGS[flag_str] = cfg + return cfg + + +def run_commands(module, commands, check_rc=True): + connection = get_connection(module) + try: + return connection.run_commands(commands=commands, check_rc=check_rc) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc)) + + +def load_config(module, commands): + connection = get_connection(module) + + try: + resp = connection.edit_config(commands) + return resp.get("response") + except ConnectionError as exc: + module.fail_json(msg=to_text(exc)) + + +def normalize_interface(name): + """Return the normalized interface name + """ + if not name: + return + + def _get_number(name): + digits = "" + for char in name: + if char.isdigit() or char in "/.": + digits += char + return digits + + if name.lower().startswith("gi"): + if_type = "GigabitEthernet" + elif name.lower().startswith("te"): + if_type = "TenGigabitEthernet" + elif name.lower().startswith("fa"): + if_type = "FastEthernet" + elif name.lower().startswith("fo"): + if_type = "FortyGigabitEthernet" + elif name.lower().startswith("et"): + if_type = "Ethernet" + elif name.lower().startswith("vl"): + if_type = "Vlan" + elif name.lower().startswith("lo"): + if_type = "loopback" + elif name.lower().startswith("po"): + if_type = "port-channel" + elif name.lower().startswith("nv"): + if_type = "nve" + elif name.lower().startswith("twe"): + if_type = "TwentyFiveGigE" + elif name.lower().startswith("hu"): + if_type = "HundredGigE" + else: + if_type = None + + number_list = name.split(" ") + if len(number_list) == 2: + if_number = number_list[-1].strip() + else: + if_number = _get_number(name) + + if if_type: + proper_interface = if_type + if_number + else: + proper_interface = name + + return proper_interface diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py new file mode 100644 index 0000000..ef383fc --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py @@ -0,0 +1,229 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: ios_command +author: Peter Sprygada (@privateip) +short_description: Run commands on remote devices running Cisco IOS +description: +- Sends arbitrary commands to an ios node and returns the results read from the device. + This module includes an argument that will cause the module to wait for a specific + condition before returning or timing out if the condition is not met. +- This module does not support running commands in configuration mode. Please use + M(ios_config) to configure IOS devices. +extends_documentation_fragment: +- cisco.ios.ios +notes: +- Tested against IOS 15.6 +options: + commands: + description: + - List of commands to send to the remote ios device over the configured provider. + The resulting output from the command is returned. If the I(wait_for) argument + is provided, the module is not returned until the condition is satisfied or + the number of retries has expired. If a command sent to the device requires + answering a prompt, it is possible to pass a dict containing I(command), I(answer) + and I(prompt). Common answers are 'y' or "\r" (carriage return, must be double + quotes). See examples. + required: true + wait_for: + description: + - List of conditions to evaluate against the output of the command. The task will + wait for each condition to be true before moving forward. If the conditional + is not true within the configured number of retries, the task fails. See examples. + aliases: + - waitfor + match: + description: + - The I(match) argument is used in conjunction with the I(wait_for) argument to + specify the match policy. Valid values are C(all) or C(any). If the value + is set to C(all) then all conditionals in the wait_for must be satisfied. If + the value is set to C(any) then only one of the values must be satisfied. + default: all + choices: + - any + - all + retries: + description: + - Specifies the number of retries a command should by tried before it is considered + failed. The command is run on the target device every retry and evaluated against + the I(wait_for) conditions. + default: 10 + interval: + description: + - Configures the interval in seconds to wait between retries of the command. If + the command does not pass the specified conditions, the interval indicates how + long to wait before trying the command again. + default: 1 +""" + +EXAMPLES = r""" +tasks: + - name: run show version on remote devices + ios_command: + commands: show version + + - name: run show version and check to see if output contains IOS + ios_command: + commands: show version + wait_for: result[0] contains IOS + + - name: run multiple commands on remote nodes + ios_command: + commands: + - show version + - show interfaces + + - name: run multiple commands and evaluate the output + ios_command: + commands: + - show version + - show interfaces + wait_for: + - result[0] contains IOS + - result[1] contains Loopback0 + + - name: run commands that require answering a prompt + ios_command: + commands: + - command: 'clear counters GigabitEthernet0/1' + prompt: 'Clear "show interface" counters on this interface \[confirm\]' + answer: 'y' + - command: 'clear counters GigabitEthernet0/2' + prompt: '[confirm]' + answer: "\r" +""" + +RETURN = """ +stdout: + description: The set of responses from the commands + returned: always apart from low level errors (such as action plugin) + type: list + sample: ['...', '...'] +stdout_lines: + description: The value of stdout split into a list + returned: always apart from low level errors (such as action plugin) + type: list + sample: [['...', '...'], ['...'], ['...']] +failed_conditions: + description: The list of conditionals that have failed + returned: failed + type: list + sample: ['...', '...'] +""" +import time + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import ( + Conditional, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + transform_commands, + to_lines, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + run_commands, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + ios_argument_spec, +) + + +def parse_commands(module, warnings): + commands = transform_commands(module) + + if module.check_mode: + for item in list(commands): + if not item["command"].startswith("show"): + warnings.append( + "Only show commands are supported when using check mode, not " + "executing %s" % item["command"] + ) + commands.remove(item) + + return commands + + +def main(): + """main entry point for module execution + """ + argument_spec = dict( + commands=dict(type="list", required=True), + wait_for=dict(type="list", aliases=["waitfor"]), + match=dict(default="all", choices=["all", "any"]), + retries=dict(default=10, type="int"), + interval=dict(default=1, type="int"), + ) + + argument_spec.update(ios_argument_spec) + + module = AnsibleModule( + argument_spec=argument_spec, supports_check_mode=True + ) + + warnings = list() + result = {"changed": False, "warnings": warnings} + commands = parse_commands(module, warnings) + wait_for = module.params["wait_for"] or list() + + try: + conditionals = [Conditional(c) for c in wait_for] + except AttributeError as exc: + module.fail_json(msg=to_text(exc)) + + retries = module.params["retries"] + interval = module.params["interval"] + match = module.params["match"] + + while retries > 0: + responses = run_commands(module, commands) + + for item in list(conditionals): + if item(responses): + if match == "any": + conditionals = list() + break + conditionals.remove(item) + + if not conditionals: + break + + time.sleep(interval) + retries -= 1 + + if conditionals: + failed_conditions = [item.raw for item in conditionals] + msg = "One or more conditional statements have not been satisfied" + module.fail_json(msg=msg, failed_conditions=failed_conditions) + + result.update( + {"stdout": responses, "stdout_lines": list(to_lines(responses))} + ) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py new file mode 100644 index 0000000..beec5b8 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py @@ -0,0 +1,596 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: ios_config +author: Peter Sprygada (@privateip) +short_description: Manage Cisco IOS configuration sections +description: +- Cisco IOS configurations use a simple block indent file syntax for segmenting configuration + into sections. This module provides an implementation for working with IOS configuration + sections in a deterministic way. +extends_documentation_fragment: +- cisco.ios.ios +notes: +- Tested against IOS 15.6 +- Abbreviated commands are NOT idempotent, see L(Network FAQ,../network/user_guide/faq.html#why-do-the-config-modules-always-return-changed-true-with-abbreviated-commands). +options: + lines: + description: + - The ordered set of commands that should be configured in the section. The commands + must be the exact same commands as found in the device running-config. Be sure + to note the configuration command syntax as some commands are automatically + modified by the device config parser. + aliases: + - commands + parents: + description: + - The ordered set of parents that uniquely identify the section or hierarchy the + commands should be checked against. If the parents argument is omitted, the + commands are checked against the set of top level or global commands. + src: + description: + - Specifies the source path to the file that contains the configuration or configuration + template to load. The path to the source file can either be the full path on + the Ansible control host or a relative path from the playbook or role root directory. This + argument is mutually exclusive with I(lines), I(parents). + before: + description: + - The ordered set of commands to push on to the command stack if a change needs + to be made. This allows the playbook designer the opportunity to perform configuration + commands prior to pushing any changes without affecting how the set of commands + are matched against the system. + after: + description: + - The ordered set of commands to append to the end of the command stack if a change + needs to be made. Just like with I(before) this allows the playbook designer + to append a set of commands to be executed after the command set. + match: + description: + - Instructs the module on the way to perform the matching of the set of commands + against the current device config. If match is set to I(line), commands are + matched line by line. If match is set to I(strict), command lines are matched + with respect to position. If match is set to I(exact), command lines must be + an equal match. Finally, if match is set to I(none), the module will not attempt + to compare the source configuration with the running configuration on the remote + device. + choices: + - line + - strict + - exact + - none + default: line + replace: + description: + - Instructs the module on the way to perform the configuration on the device. + If the replace argument is set to I(line) then the modified lines are pushed + to the device in configuration mode. If the replace argument is set to I(block) + then the entire command block is pushed to the device in configuration mode + if any line is not correct. + default: line + choices: + - line + - block + multiline_delimiter: + description: + - This argument is used when pushing a multiline configuration element to the + IOS device. It specifies the character to use as the delimiting character. This + only applies to the configuration action. + default: '@' + backup: + description: + - This argument will cause the module to create a full backup of the current C(running-config) + from the remote device before any changes are made. If the C(backup_options) + value is not given, the backup file is written to the C(backup) folder in the + playbook root directory or role root directory, if playbook is part of an ansible + role. If the directory does not exist, it is created. + type: bool + default: 'no' + running_config: + description: + - The module, by default, will connect to the remote device and retrieve the current + running-config to use as a base for comparing against the contents of source. + There are times when it is not desirable to have the task get the current running-config + for every task in a playbook. The I(running_config) argument allows the implementer + to pass in the configuration to use as the base config for comparison. + aliases: + - config + defaults: + description: + - This argument specifies whether or not to collect all defaults when getting + the remote device running config. When enabled, the module will get the current + config by issuing the command C(show running-config all). + type: bool + default: 'no' + save_when: + description: + - When changes are made to the device running-configuration, the changes are not + copied to non-volatile storage by default. Using this argument will change + that before. If the argument is set to I(always), then the running-config will + always be copied to the startup-config and the I(modified) flag will always + be set to True. If the argument is set to I(modified), then the running-config + will only be copied to the startup-config if it has changed since the last save + to startup-config. If the argument is set to I(never), the running-config will + never be copied to the startup-config. If the argument is set to I(changed), + then the running-config will only be copied to the startup-config if the task + has made a change. I(changed) was added in Ansible 2.5. + default: never + choices: + - always + - never + - modified + - changed + diff_against: + description: + - When using the C(ansible-playbook --diff) command line argument the module can + generate diffs against different sources. + - When this option is configure as I(startup), the module will return the diff + of the running-config against the startup-config. + - When this option is configured as I(intended), the module will return the diff + of the running-config against the configuration provided in the C(intended_config) + argument. + - When this option is configured as I(running), the module will return the before + and after diff of the running-config with respect to any changes made to the + device configuration. + choices: + - running + - startup + - intended + diff_ignore_lines: + description: + - Use this argument to specify one or more lines that should be ignored during + the diff. This is used for lines in the configuration that are automatically + updated by the system. This argument takes a list of regular expressions or + exact line matches. + intended_config: + description: + - The C(intended_config) provides the master configuration that the node should + conform to and is used to check the final running-config against. This argument + will not modify any settings on the remote device and is strictly used to check + the compliance of the current device's configuration against. When specifying + this argument, the task should also modify the C(diff_against) value and set + it to I(intended). + backup_options: + description: + - This is a dict object containing configurable options related to backup file + path. The value of this option is read only when C(backup) is set to I(yes), + if C(backup) is set to I(no) this option will be silently ignored. + suboptions: + filename: + description: + - The filename to be used to store the backup configuration. If the filename + is not given it will be generated based on the hostname, current time and + date in format defined by <hostname>_config.<current-date>@<current-time> + dir_path: + description: + - This option provides the path ending with directory name in which the backup + configuration file will be stored. If the directory does not exist it will + be first created and the filename is either the value of C(filename) or + default filename as described in C(filename) options description. If the + path value is not given in that case a I(backup) directory will be created + in the current working directory and backup configuration will be copied + in C(filename) within I(backup) directory. + type: path + type: dict +""" + +EXAMPLES = """ +- name: configure top level configuration + ios_config: + lines: hostname {{ inventory_hostname }} + +- name: configure interface settings + ios_config: + lines: + - description test interface + - ip address 172.31.1.1 255.255.255.0 + parents: interface Ethernet1 + +- name: configure ip helpers on multiple interfaces + ios_config: + lines: + - ip helper-address 172.26.1.10 + - ip helper-address 172.26.3.8 + parents: "{{ item }}" + with_items: + - interface Ethernet1 + - interface Ethernet2 + - interface GigabitEthernet1 + +- name: configure policer in Scavenger class + ios_config: + lines: + - conform-action transmit + - exceed-action drop + parents: + - policy-map Foo + - class Scavenger + - police cir 64000 + +- name: load new acl into device + ios_config: + lines: + - 10 permit ip host 192.0.2.1 any log + - 20 permit ip host 192.0.2.2 any log + - 30 permit ip host 192.0.2.3 any log + - 40 permit ip host 192.0.2.4 any log + - 50 permit ip host 192.0.2.5 any log + parents: ip access-list extended test + before: no ip access-list extended test + match: exact + +- name: check the running-config against master config + ios_config: + diff_against: intended + intended_config: "{{ lookup('file', 'master.cfg') }}" + +- name: check the startup-config against the running-config + ios_config: + diff_against: startup + diff_ignore_lines: + - ntp clock .* + +- name: save running to startup when modified + ios_config: + save_when: modified + +- name: for idempotency, use full-form commands + ios_config: + lines: + # - shut + - shutdown + # parents: int gig1/0/11 + parents: interface GigabitEthernet1/0/11 + +# Set boot image based on comparison to a group_var (version) and the version +# that is returned from the `ios_facts` module +- name: SETTING BOOT IMAGE + ios_config: + lines: + - no boot system + - boot system flash bootflash:{{new_image}} + host: "{{ inventory_hostname }}" + when: ansible_net_version != version + +- name: render a Jinja2 template onto an IOS device + ios_config: + backup: yes + src: ios_template.j2 + +- name: configurable backup path + ios_config: + src: ios_template.j2 + backup: yes + backup_options: + filename: backup.cfg + dir_path: /home/user +""" + +RETURN = """ +updates: + description: The set of commands that will be pushed to the remote device + returned: always + type: list + sample: ['hostname foo', 'router ospf 1', 'router-id 192.0.2.1'] +commands: + description: The set of commands that will be pushed to the remote device + returned: always + type: list + sample: ['hostname foo', 'router ospf 1', 'router-id 192.0.2.1'] +backup_path: + description: The full path to the backup file + returned: when backup is yes + type: str + sample: /playbooks/ansible/backup/ios_config.2016-07-16@22:28:34 +filename: + description: The name of the backup file + returned: when backup is yes and filename is not specified in backup options + type: str + sample: ios_config.2016-07-16@22:28:34 +shortname: + description: The full path to the backup file excluding the timestamp + returned: when backup is yes and filename is not specified in backup options + type: str + sample: /playbooks/ansible/backup/ios_config +date: + description: The date extracted from the backup file name + returned: when backup is yes + type: str + sample: "2016-07-16" +time: + description: The time extracted from the backup file name + returned: when backup is yes + type: str + sample: "22:28:34" +""" +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import ConnectionError +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + run_commands, + get_config, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + get_defaults_flag, + get_connection, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + ios_argument_spec, +) +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import ( + NetworkConfig, + dumps, +) + + +def check_args(module, warnings): + if module.params["multiline_delimiter"]: + if len(module.params["multiline_delimiter"]) != 1: + module.fail_json( + msg="multiline_delimiter value can only be a " + "single character" + ) + + +def edit_config_or_macro(connection, commands): + # only catch the macro configuration command, + # not negated 'no' variation. + if commands[0].startswith("macro name"): + connection.edit_macro(candidate=commands) + else: + connection.edit_config(candidate=commands) + + +def get_candidate_config(module): + candidate = "" + if module.params["src"]: + candidate = module.params["src"] + + elif module.params["lines"]: + candidate_obj = NetworkConfig(indent=1) + parents = module.params["parents"] or list() + candidate_obj.add(module.params["lines"], parents=parents) + candidate = dumps(candidate_obj, "raw") + + return candidate + + +def get_running_config(module, current_config=None, flags=None): + running = module.params["running_config"] + if not running: + if not module.params["defaults"] and current_config: + running = current_config + else: + running = get_config(module, flags=flags) + + return running + + +def save_config(module, result): + result["changed"] = True + if not module.check_mode: + run_commands(module, "copy running-config startup-config\r") + else: + module.warn( + "Skipping command `copy running-config startup-config` " + "due to check_mode. Configuration not copied to " + "non-volatile storage" + ) + + +def main(): + """ main entry point for module execution + """ + backup_spec = dict(filename=dict(), dir_path=dict(type="path")) + argument_spec = dict( + src=dict(type="path"), + lines=dict(aliases=["commands"], type="list"), + parents=dict(type="list"), + before=dict(type="list"), + after=dict(type="list"), + match=dict( + default="line", choices=["line", "strict", "exact", "none"] + ), + replace=dict(default="line", choices=["line", "block"]), + multiline_delimiter=dict(default="@"), + running_config=dict(aliases=["config"]), + intended_config=dict(), + defaults=dict(type="bool", default=False), + backup=dict(type="bool", default=False), + backup_options=dict(type="dict", options=backup_spec), + save_when=dict( + choices=["always", "never", "modified", "changed"], default="never" + ), + diff_against=dict(choices=["startup", "intended", "running"]), + diff_ignore_lines=dict(type="list"), + ) + + argument_spec.update(ios_argument_spec) + + mutually_exclusive = [("lines", "src"), ("parents", "src")] + + required_if = [ + ("match", "strict", ["lines"]), + ("match", "exact", ["lines"]), + ("replace", "block", ["lines"]), + ("diff_against", "intended", ["intended_config"]), + ] + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=mutually_exclusive, + required_if=required_if, + supports_check_mode=True, + ) + + result = {"changed": False} + + warnings = list() + check_args(module, warnings) + result["warnings"] = warnings + + diff_ignore_lines = module.params["diff_ignore_lines"] + config = None + contents = None + flags = get_defaults_flag(module) if module.params["defaults"] else [] + connection = get_connection(module) + + if module.params["backup"] or ( + module._diff and module.params["diff_against"] == "running" + ): + contents = get_config(module, flags=flags) + config = NetworkConfig(indent=1, contents=contents) + if module.params["backup"]: + result["__backup__"] = contents + + if any((module.params["lines"], module.params["src"])): + match = module.params["match"] + replace = module.params["replace"] + path = module.params["parents"] + + candidate = get_candidate_config(module) + running = get_running_config(module, contents, flags=flags) + try: + response = connection.get_diff( + candidate=candidate, + running=running, + diff_match=match, + diff_ignore_lines=diff_ignore_lines, + path=path, + diff_replace=replace, + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + config_diff = response["config_diff"] + banner_diff = response["banner_diff"] + + if config_diff or banner_diff: + commands = config_diff.split("\n") + + if module.params["before"]: + commands[:0] = module.params["before"] + + if module.params["after"]: + commands.extend(module.params["after"]) + + result["commands"] = commands + result["updates"] = commands + result["banners"] = banner_diff + + # send the configuration commands to the device and merge + # them with the current running config + if not module.check_mode: + if commands: + edit_config_or_macro(connection, commands) + if banner_diff: + connection.edit_banner( + candidate=json.dumps(banner_diff), + multiline_delimiter=module.params[ + "multiline_delimiter" + ], + ) + + result["changed"] = True + + running_config = module.params["running_config"] + startup_config = None + + if module.params["save_when"] == "always": + save_config(module, result) + elif module.params["save_when"] == "modified": + output = run_commands( + module, ["show running-config", "show startup-config"] + ) + + running_config = NetworkConfig( + indent=1, contents=output[0], ignore_lines=diff_ignore_lines + ) + startup_config = NetworkConfig( + indent=1, contents=output[1], ignore_lines=diff_ignore_lines + ) + + if running_config.sha1 != startup_config.sha1: + save_config(module, result) + elif module.params["save_when"] == "changed" and result["changed"]: + save_config(module, result) + + if module._diff: + if not running_config: + output = run_commands(module, "show running-config") + contents = output[0] + else: + contents = running_config + + # recreate the object in order to process diff_ignore_lines + running_config = NetworkConfig( + indent=1, contents=contents, ignore_lines=diff_ignore_lines + ) + + if module.params["diff_against"] == "running": + if module.check_mode: + module.warn( + "unable to perform diff against running-config due to check mode" + ) + contents = None + else: + contents = config.config_text + + elif module.params["diff_against"] == "startup": + if not startup_config: + output = run_commands(module, "show startup-config") + contents = output[0] + else: + contents = startup_config.config_text + + elif module.params["diff_against"] == "intended": + contents = module.params["intended_config"] + + if contents is not None: + base_config = NetworkConfig( + indent=1, contents=contents, ignore_lines=diff_ignore_lines + ) + + if running_config.sha1 != base_config.sha1: + if module.params["diff_against"] == "intended": + before = running_config + after = base_config + elif module.params["diff_against"] in ("startup", "running"): + before = base_config + after = running_config + + result.update( + { + "changed": True, + "diff": {"before": str(before), "after": str(after)}, + } + ) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py new file mode 100644 index 0000000..29f31b0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py @@ -0,0 +1,115 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import json +import re + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_text, to_bytes +from ansible.plugins.terminal import TerminalBase +from ansible.utils.display import Display + +display = Display() + + +class TerminalModule(TerminalBase): + + terminal_stdout_re = [ + re.compile(br"[\r\n]?[\w\+\-\.:\/\[\]]+(?:\([^\)]+\)){0,3}(?:[>#]) ?$") + ] + + terminal_stderr_re = [ + re.compile(br"% ?Error"), + # re.compile(br"^% \w+", re.M), + re.compile(br"% ?Bad secret"), + re.compile(br"[\r\n%] Bad passwords"), + re.compile(br"invalid input", re.I), + re.compile(br"(?:incomplete|ambiguous) command", re.I), + re.compile(br"connection timed out", re.I), + re.compile(br"[^\r\n]+ not found"), + re.compile(br"'[^']' +returned error code: ?\d+"), + re.compile(br"Bad mask", re.I), + re.compile(br"% ?(\S+) ?overlaps with ?(\S+)", re.I), + re.compile(br"[%\S] ?Error: ?[\s]+", re.I), + re.compile(br"[%\S] ?Informational: ?[\s]+", re.I), + re.compile(br"Command authorization failed"), + ] + + def on_open_shell(self): + try: + self._exec_cli_command(b"terminal length 0") + except AnsibleConnectionFailure: + raise AnsibleConnectionFailure("unable to set terminal parameters") + + try: + self._exec_cli_command(b"terminal width 512") + try: + self._exec_cli_command(b"terminal width 0") + except AnsibleConnectionFailure: + pass + except AnsibleConnectionFailure: + display.display( + "WARNING: Unable to set terminal width, command responses may be truncated" + ) + + def on_become(self, passwd=None): + if self._get_prompt().endswith(b"#"): + return + + cmd = {u"command": u"enable"} + if passwd: + # Note: python-3.5 cannot combine u"" and r"" together. Thus make + # an r string and use to_text to ensure it's text on both py2 and py3. + cmd[u"prompt"] = to_text( + r"[\r\n]?(?:.*)?[Pp]assword: ?$", errors="surrogate_or_strict" + ) + cmd[u"answer"] = passwd + cmd[u"prompt_retry_check"] = True + try: + self._exec_cli_command( + to_bytes(json.dumps(cmd), errors="surrogate_or_strict") + ) + prompt = self._get_prompt() + if prompt is None or not prompt.endswith(b"#"): + raise AnsibleConnectionFailure( + "failed to elevate privilege to enable mode still at prompt [%s]" + % prompt + ) + except AnsibleConnectionFailure as e: + prompt = self._get_prompt() + raise AnsibleConnectionFailure( + "unable to elevate privilege to enable mode, at prompt [%s] with error: %s" + % (prompt, e.message) + ) + + def on_unbecome(self): + prompt = self._get_prompt() + if prompt is None: + # if prompt is None most likely the terminal is hung up at a prompt + return + + if b"(config" in prompt: + self._exec_cli_command(b"end") + self._exec_cli_command(b"disable") + + elif prompt.endswith(b"#"): + self._exec_cli_command(b"disable") diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py new file mode 100644 index 0000000..b86a0c4 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py @@ -0,0 +1,129 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import sys +import copy + +from ansible_collections.ansible.netcommon.plugins.action.network import ( + ActionModule as ActionNetworkModule, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + load_provider, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_provider_spec, +) +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionNetworkModule): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + module_name = self._task.action.split(".")[-1] + self._config_module = True if module_name == "vyos_config" else False + persistent_connection = self._play_context.connection.split(".")[-1] + warnings = [] + + if persistent_connection == "network_cli": + provider = self._task.args.get("provider", {}) + if any(provider.values()): + display.warning( + "provider is unnecessary when using network_cli and will be ignored" + ) + del self._task.args["provider"] + elif self._play_context.connection == "local": + provider = load_provider(vyos_provider_spec, self._task.args) + pc = copy.deepcopy(self._play_context) + pc.connection = "ansible.netcommon.network_cli" + pc.network_os = "vyos.vyos.vyos" + pc.remote_addr = provider["host"] or self._play_context.remote_addr + pc.port = int(provider["port"] or self._play_context.port or 22) + pc.remote_user = ( + provider["username"] or self._play_context.connection_user + ) + pc.password = provider["password"] or self._play_context.password + pc.private_key_file = ( + provider["ssh_keyfile"] or self._play_context.private_key_file + ) + + connection = self._shared_loader_obj.connection_loader.get( + "ansible.netcommon.persistent", + pc, + sys.stdin, + task_uuid=self._task._uuid, + ) + + # TODO: Remove below code after ansible minimal is cut out + if connection is None: + pc.connection = "network_cli" + pc.network_os = "vyos" + connection = self._shared_loader_obj.connection_loader.get( + "persistent", pc, sys.stdin, task_uuid=self._task._uuid + ) + + display.vvv( + "using connection plugin %s (was local)" % pc.connection, + pc.remote_addr, + ) + + command_timeout = ( + int(provider["timeout"]) + if provider["timeout"] + else connection.get_option("persistent_command_timeout") + ) + connection.set_options( + direct={"persistent_command_timeout": command_timeout} + ) + + socket_path = connection.run() + display.vvvv("socket_path: %s" % socket_path, pc.remote_addr) + if not socket_path: + return { + "failed": True, + "msg": "unable to open shell. Please see: " + + "https://docs.ansible.com/ansible/latest/network/user_guide/network_debug_troubleshooting.html#category-unable-to-open-shell", + } + + task_vars["ansible_socket"] = socket_path + warnings.append( + [ + "connection local support for this module is deprecated and will be removed in version 2.14, use connection %s" + % pc.connection + ] + ) + else: + return { + "failed": True, + "msg": "Connection type %s is not valid for this module" + % self._play_context.connection, + } + + result = super(ActionModule, self).run(task_vars=task_vars) + if warnings: + if "warnings" in result: + result["warnings"].extend(warnings) + else: + result["warnings"] = warnings + return result diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py new file mode 100644 index 0000000..3212615 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py @@ -0,0 +1,343 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """ +--- +author: Ansible Networking Team +cliconf: vyos +short_description: Use vyos cliconf to run command on VyOS platform +description: + - This vyos plugin provides low level abstraction apis for + sending and receiving CLI commands from VyOS network devices. +version_added: "2.4" +""" + +import re +import json + +from collections.abc import Mapping + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_text +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import ( + NetworkConfig, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible.plugins.cliconf import CliconfBase + + +class Cliconf(CliconfBase): + def get_device_info(self): + device_info = {} + + device_info["network_os"] = "vyos" + reply = self.get("show version") + data = to_text(reply, errors="surrogate_or_strict").strip() + + match = re.search(r"Version:\s*(.*)", data) + if match: + device_info["network_os_version"] = match.group(1) + + match = re.search(r"HW model:\s*(\S+)", data) + if match: + device_info["network_os_model"] = match.group(1) + + reply = self.get("show host name") + device_info["network_os_hostname"] = to_text( + reply, errors="surrogate_or_strict" + ).strip() + + return device_info + + def get_config(self, flags=None, format=None): + if format: + option_values = self.get_option_values() + if format not in option_values["format"]: + raise ValueError( + "'format' value %s is invalid. Valid values of format are %s" + % (format, ", ".join(option_values["format"])) + ) + + if not flags: + flags = [] + + if format == "text": + command = "show configuration" + else: + command = "show configuration commands" + + command += " ".join(to_list(flags)) + command = command.strip() + + out = self.send_command(command) + return out + + def edit_config( + self, candidate=None, commit=True, replace=None, comment=None + ): + resp = {} + operations = self.get_device_operations() + self.check_edit_config_capability( + operations, candidate, commit, replace, comment + ) + + results = [] + requests = [] + self.send_command("configure") + for cmd in to_list(candidate): + if not isinstance(cmd, Mapping): + cmd = {"command": cmd} + + results.append(self.send_command(**cmd)) + requests.append(cmd["command"]) + out = self.get("compare") + out = to_text(out, errors="surrogate_or_strict") + diff_config = out if not out.startswith("No changes") else None + + if diff_config: + if commit: + try: + self.commit(comment) + except AnsibleConnectionFailure as e: + msg = "commit failed: %s" % e.message + self.discard_changes() + raise AnsibleConnectionFailure(msg) + else: + self.send_command("exit") + else: + self.discard_changes() + else: + self.send_command("exit") + if ( + to_text( + self._connection.get_prompt(), errors="surrogate_or_strict" + ) + .strip() + .endswith("#") + ): + self.discard_changes() + + if diff_config: + resp["diff"] = diff_config + resp["response"] = results + resp["request"] = requests + return resp + + def get( + self, + command=None, + prompt=None, + answer=None, + sendonly=False, + output=None, + newline=True, + check_all=False, + ): + if not command: + raise ValueError("must provide value of command to execute") + if output: + raise ValueError( + "'output' value %s is not supported for get" % output + ) + + return self.send_command( + command=command, + prompt=prompt, + answer=answer, + sendonly=sendonly, + newline=newline, + check_all=check_all, + ) + + def commit(self, comment=None): + if comment: + command = 'commit comment "{0}"'.format(comment) + else: + command = "commit" + self.send_command(command) + + def discard_changes(self): + self.send_command("exit discard") + + def get_diff( + self, + candidate=None, + running=None, + diff_match="line", + diff_ignore_lines=None, + path=None, + diff_replace=None, + ): + diff = {} + device_operations = self.get_device_operations() + option_values = self.get_option_values() + + if candidate is None and device_operations["supports_generate_diff"]: + raise ValueError( + "candidate configuration is required to generate diff" + ) + + if diff_match not in option_values["diff_match"]: + raise ValueError( + "'match' value %s in invalid, valid values are %s" + % (diff_match, ", ".join(option_values["diff_match"])) + ) + + if diff_replace: + raise ValueError("'replace' in diff is not supported") + + if diff_ignore_lines: + raise ValueError("'diff_ignore_lines' in diff is not supported") + + if path: + raise ValueError("'path' in diff is not supported") + + set_format = candidate.startswith("set") or candidate.startswith( + "delete" + ) + candidate_obj = NetworkConfig(indent=4, contents=candidate) + if not set_format: + config = [c.line for c in candidate_obj.items] + commands = list() + # this filters out less specific lines + for item in config: + for index, entry in enumerate(commands): + if item.startswith(entry): + del commands[index] + break + commands.append(item) + + candidate_commands = [ + "set %s" % cmd.replace(" {", "") for cmd in commands + ] + + else: + candidate_commands = str(candidate).strip().split("\n") + + if diff_match == "none": + diff["config_diff"] = list(candidate_commands) + return diff + + running_commands = [ + str(c).replace("'", "") for c in running.splitlines() + ] + + updates = list() + visited = set() + + for line in candidate_commands: + item = str(line).replace("'", "") + + if not item.startswith("set") and not item.startswith("delete"): + raise ValueError( + "line must start with either `set` or `delete`" + ) + + elif item.startswith("set") and item not in running_commands: + updates.append(line) + + elif item.startswith("delete"): + if not running_commands: + updates.append(line) + else: + item = re.sub(r"delete", "set", item) + for entry in running_commands: + if entry.startswith(item) and line not in visited: + updates.append(line) + visited.add(line) + + diff["config_diff"] = list(updates) + return diff + + def run_commands(self, commands=None, check_rc=True): + if commands is None: + raise ValueError("'commands' value is required") + + responses = list() + for cmd in to_list(commands): + if not isinstance(cmd, Mapping): + cmd = {"command": cmd} + + output = cmd.pop("output", None) + if output: + raise ValueError( + "'output' value %s is not supported for run_commands" + % output + ) + + try: + out = self.send_command(**cmd) + except AnsibleConnectionFailure as e: + if check_rc: + raise + out = getattr(e, "err", e) + + responses.append(out) + + return responses + + def get_device_operations(self): + return { + "supports_diff_replace": False, + "supports_commit": True, + "supports_rollback": False, + "supports_defaults": False, + "supports_onbox_diff": True, + "supports_commit_comment": True, + "supports_multiline_delimiter": False, + "supports_diff_match": True, + "supports_diff_ignore_lines": False, + "supports_generate_diff": False, + "supports_replace": False, + } + + def get_option_values(self): + return { + "format": ["text", "set"], + "diff_match": ["line", "none"], + "diff_replace": [], + "output": [], + } + + def get_capabilities(self): + result = super(Cliconf, self).get_capabilities() + result["rpc"] += [ + "commit", + "discard_changes", + "get_diff", + "run_commands", + ] + result["device_operations"] = self.get_device_operations() + result.update(self.get_option_values()) + return json.dumps(result) + + def set_cli_prompt_context(self): + """ + Make sure we are in the operational cli mode + :return: None + """ + if self._connection.connected: + self._update_cli_prompt_context( + config_context="#", exit_command="exit discard" + ) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py new file mode 100644 index 0000000..094963f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Peter Sprygada <psprygada@ansible.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: + provider: + description: + - B(Deprecated) + - 'Starting with Ansible 2.5 we recommend using C(connection: network_cli).' + - For more information please see the L(Network Guide, ../network/getting_started/network_differences.html#multiple-communication-protocols). + - HORIZONTALLINE + - A dict object containing connection details. + type: dict + suboptions: + host: + description: + - Specifies the DNS host name or address for connecting to the remote device + over the specified transport. The value of host is used as the destination + address for the transport. + type: str + required: true + port: + description: + - Specifies the port to use when building the connection to the remote device. + type: int + default: 22 + username: + description: + - Configures the username to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME) + will be used instead. + type: str + password: + description: + - Specifies the password to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) + will be used instead. + type: str + timeout: + description: + - Specifies the timeout in seconds for communicating with the network device + for either connecting or sending commands. If the timeout is exceeded before + the operation is completed, the module will error. + type: int + default: 10 + ssh_keyfile: + description: + - Specifies the SSH key to use to authenticate the connection to the remote + device. This value is the path to the key used to authenticate the SSH + session. If the value is not specified in the task, the value of environment + variable C(ANSIBLE_NET_SSH_KEYFILE) will be used instead. + type: path +notes: +- For more information on using Ansible to manage network devices see the :ref:`Ansible + Network Guide <network_guide>` +""" diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py new file mode 100644 index 0000000..46fabaa --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py @@ -0,0 +1,22 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The arg spec for the vyos facts module. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class FactsArgs(object): # pylint: disable=R0903 + """ The arg spec for the vyos facts module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "gather_subset": dict(default=["!config"], type="list"), + "gather_network_resources": dict(type="list"), + } diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py new file mode 100644 index 0000000..a018cc0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py @@ -0,0 +1,263 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_firewall_rules module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Firewall_rulesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_firewall_rules module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "afi": { + "choices": ["ipv4", "ipv6"], + "required": True, + "type": "str", + }, + "rule_sets": { + "elements": "dict", + "options": { + "default_action": { + "choices": ["drop", "reject", "accept"], + "type": "str", + }, + "description": {"type": "str"}, + "enable_default_log": {"type": "bool"}, + "name": {"type": "str"}, + "rules": { + "elements": "dict", + "options": { + "action": { + "choices": [ + "drop", + "reject", + "accept", + "inspect", + ], + "type": "str", + }, + "description": {"type": "str"}, + "destination": { + "options": { + "address": {"type": "str"}, + "group": { + "options": { + "address_group": { + "type": "str" + }, + "network_group": { + "type": "str" + }, + "port_group": {"type": "str"}, + }, + "type": "dict", + }, + "port": {"type": "str"}, + }, + "type": "dict", + }, + "disabled": {"type": "bool"}, + "fragment": { + "choices": [ + "match-frag", + "match-non-frag", + ], + "type": "str", + }, + "icmp": { + "options": { + "code": {"type": "int"}, + "type": {"type": "int"}, + "type_name": { + "choices": [ + "any", + "echo-reply", + "destination-unreachable", + "network-unreachable", + "host-unreachable", + "protocol-unreachable", + "port-unreachable", + "fragmentation-needed", + "source-route-failed", + "network-unknown", + "host-unknown", + "network-prohibited", + "host-prohibited", + "TOS-network-unreachable", + "TOS-host-unreachable", + "communication-prohibited", + "host-precedence-violation", + "precedence-cutoff", + "source-quench", + "redirect", + "network-redirect", + "host-redirect", + "TOS-network-redirect", + "TOS-host-redirect", + "echo-request", + "router-advertisement", + "router-solicitation", + "time-exceeded", + "ttl-zero-during-transit", + "ttl-zero-during-reassembly", + "parameter-problem", + "ip-header-bad", + "required-option-missing", + "timestamp-request", + "timestamp-reply", + "address-mask-request", + "address-mask-reply", + "ping", + "pong", + "ttl-exceeded", + ], + "type": "str", + }, + }, + "type": "dict", + }, + "ipsec": { + "choices": ["match-ipsec", "match-none"], + "type": "str", + }, + "limit": { + "options": { + "burst": {"type": "int"}, + "rate": { + "options": { + "number": {"type": "int"}, + "unit": {"type": "str"}, + }, + "type": "dict", + }, + }, + "type": "dict", + }, + "number": {"required": True, "type": "int"}, + "p2p": { + "elements": "dict", + "options": { + "application": { + "choices": [ + "all", + "applejuice", + "bittorrent", + "directconnect", + "edonkey", + "gnutella", + "kazaa", + ], + "type": "str", + } + }, + "type": "list", + }, + "protocol": {"type": "str"}, + "recent": { + "options": { + "count": {"type": "int"}, + "time": {"type": "int"}, + }, + "type": "dict", + }, + "source": { + "options": { + "address": {"type": "str"}, + "group": { + "options": { + "address_group": { + "type": "str" + }, + "network_group": { + "type": "str" + }, + "port_group": {"type": "str"}, + }, + "type": "dict", + }, + "mac_address": {"type": "str"}, + "port": {"type": "str"}, + }, + "type": "dict", + }, + "state": { + "options": { + "established": {"type": "bool"}, + "invalid": {"type": "bool"}, + "new": {"type": "bool"}, + "related": {"type": "bool"}, + }, + "type": "dict", + }, + "tcp": { + "options": {"flags": {"type": "str"}}, + "type": "dict", + }, + "time": { + "options": { + "monthdays": {"type": "str"}, + "startdate": {"type": "str"}, + "starttime": {"type": "str"}, + "stopdate": {"type": "str"}, + "stoptime": {"type": "str"}, + "utc": {"type": "bool"}, + "weekdays": {"type": "str"}, + }, + "type": "dict", + }, + }, + "type": "list", + }, + }, + "type": "list", + }, + }, + "type": "list", + }, + "running_config": {"type": "str"}, + "state": { + "choices": [ + "merged", + "replaced", + "overridden", + "deleted", + "gathered", + "rendered", + "parsed", + ], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py new file mode 100644 index 0000000..3542cb1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py @@ -0,0 +1,69 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_interfaces module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class InterfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "description": {"type": "str"}, + "duplex": {"choices": ["full", "half", "auto"]}, + "enabled": {"default": True, "type": "bool"}, + "mtu": {"type": "int"}, + "name": {"required": True, "type": "str"}, + "speed": { + "choices": ["auto", "10", "100", "1000", "2500", "10000"], + "type": "str", + }, + "vifs": { + "elements": "dict", + "options": { + "vlan_id": {"type": "int"}, + "description": {"type": "str"}, + "enabled": {"default": True, "type": "bool"}, + "mtu": {"type": "int"}, + }, + "type": "list", + }, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py new file mode 100644 index 0000000..91434e4 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py @@ -0,0 +1,81 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_l3_interfaces module +""" + + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class L3_interfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_l3_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "ipv4": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "ipv6": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "name": {"required": True, "type": "str"}, + "vifs": { + "elements": "dict", + "options": { + "ipv4": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "ipv6": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "vlan_id": {"type": "int"}, + }, + "type": "list", + }, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py new file mode 100644 index 0000000..97c5d5a --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py @@ -0,0 +1,80 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# + +""" +The arg spec for the vyos_lag_interfaces module +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Lag_interfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_lag_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "arp_monitor": { + "options": { + "interval": {"type": "int"}, + "target": {"type": "list"}, + }, + "type": "dict", + }, + "hash_policy": { + "choices": ["layer2", "layer2+3", "layer3+4"], + "type": "str", + }, + "members": { + "elements": "dict", + "options": {"member": {"type": "str"}}, + "type": "list", + }, + "mode": { + "choices": [ + "802.3ad", + "active-backup", + "broadcast", + "round-robin", + "transmit-load-balance", + "adaptive-load-balance", + "xor-hash", + ], + "type": "str", + }, + "name": {"required": True, "type": "str"}, + "primary": {"type": "str"}, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py new file mode 100644 index 0000000..84bbc00 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py @@ -0,0 +1,56 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# + +""" +The arg spec for the vyos_lldp_global module +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Lldp_globalArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_lldp_global module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "options": { + "address": {"type": "str"}, + "enable": {"type": "bool"}, + "legacy_protocols": { + "choices": ["cdp", "edp", "fdp", "sonmp"], + "type": "list", + }, + "snmp": {"type": "str"}, + }, + "type": "dict", + }, + "state": { + "choices": ["merged", "replaced", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py new file mode 100644 index 0000000..2976fc0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py @@ -0,0 +1,89 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_lldp_interfaces module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Lldp_interfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_lldp_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "enable": {"default": True, "type": "bool"}, + "location": { + "options": { + "civic_based": { + "options": { + "ca_info": { + "elements": "dict", + "options": { + "ca_type": {"type": "int"}, + "ca_value": {"type": "str"}, + }, + "type": "list", + }, + "country_code": { + "required": True, + "type": "str", + }, + }, + "type": "dict", + }, + "coordinate_based": { + "options": { + "altitude": {"type": "int"}, + "datum": { + "choices": ["WGS84", "NAD83", "MLLW"], + "type": "str", + }, + "latitude": {"required": True, "type": "str"}, + "longitude": {"required": True, "type": "str"}, + }, + "type": "dict", + }, + "elin": {"type": "str"}, + }, + "type": "dict", + }, + "name": {"required": True, "type": "str"}, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py new file mode 100644 index 0000000..8ecd955 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py @@ -0,0 +1,99 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_static_routes module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Static_routesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_static_routes module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "address_families": { + "elements": "dict", + "options": { + "afi": { + "choices": ["ipv4", "ipv6"], + "required": True, + "type": "str", + }, + "routes": { + "elements": "dict", + "options": { + "blackhole_config": { + "options": { + "distance": {"type": "int"}, + "type": {"type": "str"}, + }, + "type": "dict", + }, + "dest": {"required": True, "type": "str"}, + "next_hops": { + "elements": "dict", + "options": { + "admin_distance": {"type": "int"}, + "enabled": {"type": "bool"}, + "forward_router_address": { + "required": True, + "type": "str", + }, + "interface": {"type": "str"}, + }, + "type": "list", + }, + }, + "type": "list", + }, + }, + "type": "list", + } + }, + "type": "list", + }, + "running_config": {"type": "str"}, + "state": { + "choices": [ + "merged", + "replaced", + "overridden", + "deleted", + "gathered", + "rendered", + "parsed", + ], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py new file mode 100644 index 0000000..377fec9 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py @@ -0,0 +1,438 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos_lldp_interfaces class +It is in this file where the current configuration (as dict) +is compared to the provided configuration (as dict) and the command set +necessary to bring the current configuration to it's desired end-state is +created +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.cfg.base import ( + ConfigBase, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.facts import ( + Facts, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, + dict_diff, +) +from ansible.module_utils.six import iteritems +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.utils.utils import ( + search_obj_in_list, + search_dict_tv_in_list, + key_value_in_dict, + is_dict_element_present, +) + + +class Lldp_interfaces(ConfigBase): + """ + The vyos_lldp_interfaces class + """ + + gather_subset = [ + "!all", + "!min", + ] + + gather_network_resources = [ + "lldp_interfaces", + ] + + params = ["enable", "location", "name"] + + def __init__(self, module): + super(Lldp_interfaces, self).__init__(module) + + def get_lldp_interfaces_facts(self): + """ Get the 'facts' (the current configuration) + + :rtype: A dictionary + :returns: The current configuration as a dictionary + """ + facts, _warnings = Facts(self._module).get_facts( + self.gather_subset, self.gather_network_resources + ) + lldp_interfaces_facts = facts["ansible_network_resources"].get( + "lldp_interfaces" + ) + if not lldp_interfaces_facts: + return [] + return lldp_interfaces_facts + + def execute_module(self): + """ Execute the module + + :rtype: A dictionary + :returns: The result from module execution + """ + result = {"changed": False} + commands = list() + warnings = list() + existing_lldp_interfaces_facts = self.get_lldp_interfaces_facts() + commands.extend(self.set_config(existing_lldp_interfaces_facts)) + if commands: + if self._module.check_mode: + resp = self._connection.edit_config(commands, commit=False) + else: + resp = self._connection.edit_config(commands) + result["changed"] = True + + result["commands"] = commands + + if self._module._diff: + result["diff"] = resp["diff"] if result["changed"] else None + + changed_lldp_interfaces_facts = self.get_lldp_interfaces_facts() + result["before"] = existing_lldp_interfaces_facts + if result["changed"]: + result["after"] = changed_lldp_interfaces_facts + + result["warnings"] = warnings + return result + + def set_config(self, existing_lldp_interfaces_facts): + """ Collect the configuration from the args passed to the module, + collect the current configuration (as a dict from facts) + + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + want = self._module.params["config"] + have = existing_lldp_interfaces_facts + resp = self.set_state(want, have) + return to_list(resp) + + def set_state(self, want, have): + """ Select the appropriate function based on the state provided + + :param want: the desired configuration as a dictionary + :param have: the current configuration as a dictionary + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + commands = [] + state = self._module.params["state"] + if state in ("merged", "replaced", "overridden") and not want: + self._module.fail_json( + msg="value of config parameter must not be empty for state {0}".format( + state + ) + ) + if state == "overridden": + commands.extend(self._state_overridden(want=want, have=have)) + elif state == "deleted": + if want: + for item in want: + name = item["name"] + have_item = search_obj_in_list(name, have) + commands.extend( + self._state_deleted(want=None, have=have_item) + ) + else: + for have_item in have: + commands.extend( + self._state_deleted(want=None, have=have_item) + ) + else: + for want_item in want: + name = want_item["name"] + have_item = search_obj_in_list(name, have) + if state == "merged": + commands.extend( + self._state_merged(want=want_item, have=have_item) + ) + else: + commands.extend( + self._state_replaced(want=want_item, have=have_item) + ) + return commands + + def _state_replaced(self, want, have): + """ The command generator when state is replaced + + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + commands = [] + if have: + commands.extend(self._state_deleted(want, have)) + commands.extend(self._state_merged(want, have)) + return commands + + def _state_overridden(self, want, have): + """ The command generator when state is overridden + + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + commands = [] + for have_item in have: + lldp_name = have_item["name"] + lldp_in_want = search_obj_in_list(lldp_name, want) + if not lldp_in_want: + commands.append( + self._compute_command(have_item["name"], remove=True) + ) + + for want_item in want: + name = want_item["name"] + lldp_in_have = search_obj_in_list(name, have) + commands.extend(self._state_replaced(want_item, lldp_in_have)) + return commands + + def _state_merged(self, want, have): + """ The command generator when state is merged + + :rtype: A list + :returns: the commands necessary to merge the provided into + the current configuration + """ + commands = [] + if have: + commands.extend(self._render_updates(want, have)) + else: + commands.extend(self._render_set_commands(want)) + return commands + + def _state_deleted(self, want, have): + """ The command generator when state is deleted + + :rtype: A list + :returns: the commands necessary to remove the current configuration + of the provided objects + """ + commands = [] + if want: + params = Lldp_interfaces.params + for attrib in params: + if attrib == "location": + commands.extend( + self._update_location(have["name"], want, have) + ) + + elif have: + commands.append(self._compute_command(have["name"], remove=True)) + return commands + + def _render_updates(self, want, have): + commands = [] + lldp_name = have["name"] + commands.extend(self._configure_status(lldp_name, want, have)) + commands.extend(self._add_location(lldp_name, want, have)) + + return commands + + def _render_set_commands(self, want): + commands = [] + have = {} + lldp_name = want["name"] + params = Lldp_interfaces.params + + commands.extend(self._add_location(lldp_name, want, have)) + for attrib in params: + value = want[attrib] + if value: + if attrib == "location": + commands.extend(self._add_location(lldp_name, want, have)) + elif attrib == "enable": + if not value: + commands.append( + self._compute_command(lldp_name, value="disable") + ) + else: + commands.append(self._compute_command(lldp_name)) + + return commands + + def _configure_status(self, name, want_item, have_item): + commands = [] + if is_dict_element_present(have_item, "enable"): + temp_have_item = False + else: + temp_have_item = True + if want_item["enable"] != temp_have_item: + if want_item["enable"]: + commands.append( + self._compute_command(name, value="disable", remove=True) + ) + else: + commands.append(self._compute_command(name, value="disable")) + return commands + + def _add_location(self, name, want_item, have_item): + commands = [] + have_dict = {} + have_ca = {} + set_cmd = name + " location " + want_location_type = want_item.get("location") or {} + have_location_type = have_item.get("location") or {} + + if want_location_type["coordinate_based"]: + want_dict = want_location_type.get("coordinate_based") or {} + if is_dict_element_present(have_location_type, "coordinate_based"): + have_dict = have_location_type.get("coordinate_based") or {} + location_type = "coordinate-based" + updates = dict_diff(have_dict, want_dict) + for key, value in iteritems(updates): + if value: + commands.append( + self._compute_command( + set_cmd + location_type, key, str(value) + ) + ) + + elif want_location_type["civic_based"]: + location_type = "civic-based" + want_dict = want_location_type.get("civic_based") or {} + want_ca = want_dict.get("ca_info") or [] + if is_dict_element_present(have_location_type, "civic_based"): + have_dict = have_location_type.get("civic_based") or {} + have_ca = have_dict.get("ca_info") or [] + if want_dict["country_code"] != have_dict["country_code"]: + commands.append( + self._compute_command( + set_cmd + location_type, + "country-code", + str(want_dict["country_code"]), + ) + ) + else: + commands.append( + self._compute_command( + set_cmd + location_type, + "country-code", + str(want_dict["country_code"]), + ) + ) + commands.extend(self._add_civic_address(name, want_ca, have_ca)) + + elif want_location_type["elin"]: + location_type = "elin" + if is_dict_element_present(have_location_type, "elin"): + if want_location_type.get("elin") != have_location_type.get( + "elin" + ): + commands.append( + self._compute_command( + set_cmd + location_type, + value=str(want_location_type["elin"]), + ) + ) + else: + commands.append( + self._compute_command( + set_cmd + location_type, + value=str(want_location_type["elin"]), + ) + ) + return commands + + def _update_location(self, name, want_item, have_item): + commands = [] + del_cmd = name + " location" + want_location_type = want_item.get("location") or {} + have_location_type = have_item.get("location") or {} + + if want_location_type["coordinate_based"]: + want_dict = want_location_type.get("coordinate_based") or {} + if is_dict_element_present(have_location_type, "coordinate_based"): + have_dict = have_location_type.get("coordinate_based") or {} + location_type = "coordinate-based" + for key, value in iteritems(have_dict): + only_in_have = key_value_in_dict(key, value, want_dict) + if not only_in_have: + commands.append( + self._compute_command( + del_cmd + location_type, key, str(value), True + ) + ) + else: + commands.append(self._compute_command(del_cmd, remove=True)) + + elif want_location_type["civic_based"]: + want_dict = want_location_type.get("civic_based") or {} + want_ca = want_dict.get("ca_info") or [] + if is_dict_element_present(have_location_type, "civic_based"): + have_dict = have_location_type.get("civic_based") or {} + have_ca = have_dict.get("ca_info") + commands.extend( + self._update_civic_address(name, want_ca, have_ca) + ) + else: + commands.append(self._compute_command(del_cmd, remove=True)) + + else: + if is_dict_element_present(have_location_type, "elin"): + if want_location_type.get("elin") != have_location_type.get( + "elin" + ): + commands.append( + self._compute_command(del_cmd, remove=True) + ) + else: + commands.append(self._compute_command(del_cmd, remove=True)) + return commands + + def _add_civic_address(self, name, want, have): + commands = [] + for item in want: + ca_type = item["ca_type"] + ca_value = item["ca_value"] + obj_in_have = search_dict_tv_in_list( + ca_type, ca_value, have, "ca_type", "ca_value" + ) + if not obj_in_have: + commands.append( + self._compute_command( + key=name + " location civic-based ca-type", + attrib=str(ca_type) + " ca-value", + value=ca_value, + ) + ) + return commands + + def _update_civic_address(self, name, want, have): + commands = [] + for item in have: + ca_type = item["ca_type"] + ca_value = item["ca_value"] + in_want = search_dict_tv_in_list( + ca_type, ca_value, want, "ca_type", "ca_value" + ) + if not in_want: + commands.append( + self._compute_command( + name, + "location civic-based ca-type", + str(ca_type), + remove=True, + ) + ) + return commands + + def _compute_command(self, key, attrib=None, value=None, remove=False): + if remove: + cmd = "delete service lldp interface " + else: + cmd = "set service lldp interface " + cmd += key + if attrib: + cmd += " " + attrib + if value: + cmd += " '" + value + "'" + return cmd diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py new file mode 100644 index 0000000..8f0a3bb --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py @@ -0,0 +1,83 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The facts class for vyos +this file validates each subset of facts and selectively +calls the appropriate facts gathering function +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.facts.facts import ( + FactsBase, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.interfaces.interfaces import ( + InterfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.l3_interfaces.l3_interfaces import ( + L3_interfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lag_interfaces.lag_interfaces import ( + Lag_interfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lldp_global.lldp_global import ( + Lldp_globalFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lldp_interfaces.lldp_interfaces import ( + Lldp_interfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.firewall_rules.firewall_rules import ( + Firewall_rulesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.static_routes.static_routes import ( + Static_routesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.legacy.base import ( + Default, + Neighbors, + Config, +) + + +FACT_LEGACY_SUBSETS = dict(default=Default, neighbors=Neighbors, config=Config) +FACT_RESOURCE_SUBSETS = dict( + interfaces=InterfacesFacts, + l3_interfaces=L3_interfacesFacts, + lag_interfaces=Lag_interfacesFacts, + lldp_global=Lldp_globalFacts, + lldp_interfaces=Lldp_interfacesFacts, + static_routes=Static_routesFacts, + firewall_rules=Firewall_rulesFacts, +) + + +class Facts(FactsBase): + """ The fact class for vyos + """ + + VALID_LEGACY_GATHER_SUBSETS = frozenset(FACT_LEGACY_SUBSETS.keys()) + VALID_RESOURCE_SUBSETS = frozenset(FACT_RESOURCE_SUBSETS.keys()) + + def __init__(self, module): + super(Facts, self).__init__(module) + + def get_facts( + self, legacy_facts_type=None, resource_facts_type=None, data=None + ): + """ Collect the facts for vyos + :param legacy_facts_type: List of legacy facts types + :param resource_facts_type: List of resource fact types + :param data: previously collected conf + :rtype: dict + :return: the facts gathered + """ + if self.VALID_RESOURCE_SUBSETS: + self.get_network_resources_facts( + FACT_RESOURCE_SUBSETS, resource_facts_type, data + ) + if self.VALID_LEGACY_GATHER_SUBSETS: + self.get_network_legacy_facts( + FACT_LEGACY_SUBSETS, legacy_facts_type + ) + return self.ansible_facts, self._warnings diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py new file mode 100644 index 0000000..971ea6f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py @@ -0,0 +1,380 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos firewall_rules fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from re import findall, search, M +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.firewall_rules.firewall_rules import ( + Firewall_rulesArgs, +) + + +class Firewall_rulesFacts(object): + """ The vyos firewall_rules fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Firewall_rulesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def get_device_data(self, connection): + return connection.get_config() + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for firewall_rules + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + # typically data is populated from the current device configuration + # data = connection.get('show running-config | section ^interface') + # using mock data instead + data = self.get_device_data(connection) + # split the config into instances of the resource + objs = [] + v6_rules = findall( + r"^set firewall ipv6-name (?:\'*)(\S+)(?:\'*)", data, M + ) + v4_rules = findall(r"^set firewall name (?:\'*)(\S+)(?:\'*)", data, M) + if v6_rules: + config = self.get_rules(data, v6_rules, type="ipv6") + if config: + config = utils.remove_empties(config) + objs.append(config) + if v4_rules: + config = self.get_rules(data, v4_rules, type="ipv4") + if config: + config = utils.remove_empties(config) + objs.append(config) + + ansible_facts["ansible_network_resources"].pop("firewall_rules", None) + facts = {} + if objs: + facts["firewall_rules"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["firewall_rules"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def get_rules(self, data, rules, type): + """ + This function performs following: + - Form regex to fetch 'rule-sets' specific config from data. + - Form the rule-set list based on ip address. + :param data: configuration. + :param rules: list of rule-sets. + :param type: ip address type. + :return: generated rule-sets configuration. + """ + r_v4 = [] + r_v6 = [] + for r in set(rules): + rule_regex = r" %s .+$" % r.strip("'") + cfg = findall(rule_regex, data, M) + fr = self.render_config(cfg, r.strip("'")) + fr["name"] = r.strip("'") + if type == "ipv6": + r_v6.append(fr) + else: + r_v4.append(fr) + if r_v4: + config = {"afi": "ipv4", "rule_sets": r_v4} + if r_v6: + config = {"afi": "ipv6", "rule_sets": r_v6} + return config + + def render_config(self, conf, match): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + conf = "\n".join(filter(lambda x: x, conf)) + a_lst = ["description", "default_action", "enable_default_log"] + config = self.parse_attr(conf, a_lst, match) + if not config: + config = {} + config["rules"] = self.parse_rules_lst(conf) + return config + + def parse_rules_lst(self, conf): + """ + This function forms the regex to fetch the 'rules' with in + 'rule-sets' + :param conf: configuration data. + :return: generated rule list configuration. + """ + r_lst = [] + rules = findall(r"rule (?:\'*)(\d+)(?:\'*)", conf, M) + if rules: + rules_lst = [] + for r in set(rules): + r_regex = r" %s .+$" % r + cfg = "\n".join(findall(r_regex, conf, M)) + obj = self.parse_rules(cfg) + obj["number"] = int(r) + if obj: + rules_lst.append(obj) + r_lst = sorted(rules_lst, key=lambda i: i["number"]) + return r_lst + + def parse_rules(self, conf): + """ + This function triggers the parsing of 'rule' attributes. + a_lst is a list having rule attributes which doesn't + have further sub attributes. + :param conf: configuration + :return: generated rule configuration dictionary. + """ + a_lst = [ + "ipsec", + "action", + "protocol", + "fragment", + "disabled", + "description", + ] + rule = self.parse_attr(conf, a_lst) + r_sub = { + "p2p": self.parse_p2p(conf), + "tcp": self.parse_tcp(conf, "tcp"), + "icmp": self.parse_icmp(conf, "icmp"), + "time": self.parse_time(conf, "time"), + "limit": self.parse_limit(conf, "limit"), + "state": self.parse_state(conf, "state"), + "recent": self.parse_recent(conf, "recent"), + "source": self.parse_src_or_dest(conf, "source"), + "destination": self.parse_src_or_dest(conf, "destination"), + } + rule.update(r_sub) + return rule + + def parse_p2p(self, conf): + """ + This function forms the regex to fetch the 'p2p' with in + 'rules' + :param conf: configuration data. + :return: generated rule list configuration. + """ + a_lst = [] + applications = findall(r"p2p (?:\'*)(\d+)(?:\'*)", conf, M) + if applications: + app_lst = [] + for r in set(applications): + obj = {"application": r.strip("'")} + app_lst.append(obj) + a_lst = sorted(app_lst, key=lambda i: i["application"]) + return a_lst + + def parse_src_or_dest(self, conf, attrib=None): + """ + This function triggers the parsing of 'source or + destination' attributes. + :param conf: configuration. + :param attrib:'source/destination'. + :return:generated source/destination configuration dictionary. + """ + a_lst = ["port", "address", "mac_address"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + cfg_dict["group"] = self.parse_group(conf, attrib + " group") + return cfg_dict + + def parse_recent(self, conf, attrib=None): + """ + This function triggers the parsing of 'recent' attributes + :param conf: configuration. + :param attrib: 'recent'. + :return: generated config dictionary. + """ + a_lst = ["time", "count"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_tcp(self, conf, attrib=None): + """ + This function triggers the parsing of 'tcp' attributes. + :param conf: configuration. + :param attrib: 'tcp'. + :return: generated config dictionary. + """ + cfg_dict = self.parse_attr(conf, ["flags"], match=attrib) + return cfg_dict + + def parse_time(self, conf, attrib=None): + """ + This function triggers the parsing of 'time' attributes. + :param conf: configuration. + :param attrib: 'time'. + :return: generated config dictionary. + """ + a_lst = [ + "stopdate", + "stoptime", + "weekdays", + "monthdays", + "startdate", + "starttime", + ] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_state(self, conf, attrib=None): + """ + This function triggers the parsing of 'state' attributes. + :param conf: configuration + :param attrib: 'state'. + :return: generated config dictionary. + """ + a_lst = ["new", "invalid", "related", "established"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_group(self, conf, attrib=None): + """ + This function triggers the parsing of 'group' attributes. + :param conf: configuration. + :param attrib: 'group'. + :return: generated config dictionary. + """ + a_lst = ["port_group", "address_group", "network_group"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_icmp(self, conf, attrib=None): + """ + This function triggers the parsing of 'icmp' attributes. + :param conf: configuration to be parsed. + :param attrib: 'icmp'. + :return: generated config dictionary. + """ + a_lst = ["code", "type", "type_name"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_limit(self, conf, attrib=None): + """ + This function triggers the parsing of 'limit' attributes. + :param conf: configuration to be parsed. + :param attrib: 'limit' + :return: generated config dictionary. + """ + cfg_dict = self.parse_attr(conf, ["burst"], match=attrib) + cfg_dict["rate"] = self.parse_rate(conf, "rate") + return cfg_dict + + def parse_rate(self, conf, attrib=None): + """ + This function triggers the parsing of 'rate' attributes. + :param conf: configuration. + :param attrib: 'rate' + :return: generated config dictionary. + """ + a_lst = ["unit", "number"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_attr(self, conf, attr_list, match=None): + """ + This function peforms the following: + - Form the regex to fetch the required attribute config. + - Type cast the output in desired format. + :param conf: configuration. + :param attr_list: list of attributes. + :param match: parent node/attribute name. + :return: generated config dictionary. + """ + config = {} + for attrib in attr_list: + regex = self.map_regex(attrib) + if match: + regex = match + " " + regex + if conf: + if self.is_bool(attrib): + out = conf.find(attrib.replace("_", "-")) + + dis = conf.find(attrib.replace("_", "-") + " 'disable'") + if out >= 1: + if dis >= 1: + config[attrib] = False + else: + config[attrib] = True + else: + out = search(r"^.*" + regex + " (.+)", conf, M) + if out: + val = out.group(1).strip("'") + if self.is_num(attrib): + val = int(val) + config[attrib] = val + return config + + def map_regex(self, attrib): + """ + - This function construct the regex string. + - replace the underscore with hyphen. + :param attrib: attribute + :return: regex string + """ + regex = attrib.replace("_", "-") + if attrib == "disabled": + regex = "disable" + return regex + + def is_bool(self, attrib): + """ + This function looks for the attribute in predefined bool type set. + :param attrib: attribute. + :return: True/False + """ + bool_set = ( + "new", + "invalid", + "related", + "disabled", + "established", + "enable_default_log", + ) + return True if attrib in bool_set else False + + def is_num(self, attrib): + """ + This function looks for the attribute in predefined integer type set. + :param attrib: attribute. + :return: True/false. + """ + num_set = ("time", "code", "type", "count", "burst", "number") + return True if attrib in num_set else False diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py new file mode 100644 index 0000000..4b24803 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py @@ -0,0 +1,134 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +from re import findall, M +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.interfaces.interfaces import ( + InterfacesArgs, +) + + +class InterfacesFacts(object): + """ The vyos interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = InterfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for interfaces + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config(flags=["| grep interfaces"]) + + objs = [] + interface_names = findall( + r"^set interfaces (?:ethernet|bonding|vti|loopback|vxlan) (?:\'*)(\S+)(?:\'*)", + data, + M, + ) + if interface_names: + for interface in set(interface_names): + intf_regex = r" %s .+$" % interface.strip("'") + cfg = findall(intf_regex, data, M) + obj = self.render_config(cfg) + obj["name"] = interface.strip("'") + if obj: + objs.append(obj) + facts = {} + if objs: + facts["interfaces"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["interfaces"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + vif_conf = "\n".join(filter(lambda x: ("vif" in x), conf)) + eth_conf = "\n".join(filter(lambda x: ("vif" not in x), conf)) + config = self.parse_attribs( + ["description", "speed", "mtu", "duplex"], eth_conf + ) + config["vifs"] = self.parse_vifs(vif_conf) + + return utils.remove_empties(config) + + def parse_vifs(self, conf): + vif_names = findall(r"vif (?:\'*)(\d+)(?:\'*)", conf, M) + vifs_list = None + + if vif_names: + vifs_list = [] + for vif in set(vif_names): + vif_regex = r" %s .+$" % vif + cfg = "\n".join(findall(vif_regex, conf, M)) + obj = self.parse_attribs(["description", "mtu"], cfg) + obj["vlan_id"] = int(vif) + if obj: + vifs_list.append(obj) + vifs_list = sorted(vifs_list, key=lambda i: i["vlan_id"]) + + return vifs_list + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value and item == "mtu": + config[item] = int(value.strip("'")) + elif value: + config[item] = value.strip("'") + else: + config[item] = None + if "disable" in conf: + config["enabled"] = False + else: + config["enabled"] = True + + return utils.remove_empties(config) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py new file mode 100644 index 0000000..d1d62c2 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py @@ -0,0 +1,143 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos l3_interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +import re +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible.module_utils.six import iteritems +from ansible_collections.ansible.netcommon.plugins.module_utils.compat import ( + ipaddress, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.l3_interfaces.l3_interfaces import ( + L3_interfacesArgs, +) + + +class L3_interfacesFacts(object): + """ The vyos l3_interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = L3_interfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for l3_interfaces + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + # operate on a collection of resource x + objs = [] + interface_names = re.findall( + r"set interfaces (?:ethernet|bonding|vti|vxlan) (?:\'*)(\S+)(?:\'*)", + data, + re.M, + ) + if interface_names: + for interface in set(interface_names): + intf_regex = r" %s .+$" % interface + cfg = re.findall(intf_regex, data, re.M) + obj = self.render_config(cfg) + obj["name"] = interface.strip("'") + if obj: + objs.append(obj) + + ansible_facts["ansible_network_resources"].pop("l3_interfaces", None) + facts = {} + if objs: + facts["l3_interfaces"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["l3_interfaces"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys from spec for null values + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + vif_conf = "\n".join(filter(lambda x: ("vif" in x), conf)) + eth_conf = "\n".join(filter(lambda x: ("vif" not in x), conf)) + config = self.parse_attribs(eth_conf) + config["vifs"] = self.parse_vifs(vif_conf) + + return utils.remove_empties(config) + + def parse_vifs(self, conf): + vif_names = re.findall(r"vif (\d+)", conf, re.M) + vifs_list = None + if vif_names: + vifs_list = [] + for vif in set(vif_names): + vif_regex = r" %s .+$" % vif + cfg = "\n".join(re.findall(vif_regex, conf, re.M)) + obj = self.parse_attribs(cfg) + obj["vlan_id"] = vif + if obj: + vifs_list.append(obj) + + return vifs_list + + def parse_attribs(self, conf): + config = {} + ipaddrs = re.findall(r"address (\S+)", conf, re.M) + config["ipv4"] = [] + config["ipv6"] = [] + + for item in ipaddrs: + item = item.strip("'") + if item == "dhcp": + config["ipv4"].append({"address": item}) + elif item == "dhcpv6": + config["ipv6"].append({"address": item}) + else: + ip_version = ipaddress.ip_address(item.split("/")[0]).version + if ip_version == 4: + config["ipv4"].append({"address": item}) + else: + config["ipv6"].append({"address": item}) + + for key, value in iteritems(config): + if value == []: + config[key] = None + + return utils.remove_empties(config) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py new file mode 100644 index 0000000..9201e5c --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py @@ -0,0 +1,152 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos lag_interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from re import findall, search, M +from copy import deepcopy + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lag_interfaces.lag_interfaces import ( + Lag_interfacesArgs, +) + + +class Lag_interfacesFacts(object): + """ The vyos lag_interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Lag_interfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for lag_interfaces + :param module: the module instance + :param connection: the device connection + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + objs = [] + lag_names = findall(r"^set interfaces bonding (\S+)", data, M) + if lag_names: + for lag in set(lag_names): + lag_regex = r" %s .+$" % lag + cfg = findall(lag_regex, data, M) + obj = self.render_config(cfg) + + output = connection.run_commands( + ["show interfaces bonding " + lag + " slaves"] + ) + lines = output[0].splitlines() + members = [] + member = {} + if len(lines) > 1: + for line in lines[2:]: + splitted_line = line.split() + + if len(splitted_line) > 1: + member["member"] = splitted_line[0] + members.append(member) + else: + members = [] + member = {} + obj["name"] = lag.strip("'") + if members: + obj["members"] = members + + if obj: + objs.append(obj) + + facts = {} + if objs: + facts["lag_interfaces"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["lag_interfaces"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + arp_monitor_conf = "\n".join( + filter(lambda x: ("arp-monitor" in x), conf) + ) + hash_policy_conf = "\n".join( + filter(lambda x: ("hash-policy" in x), conf) + ) + lag_conf = "\n".join(filter(lambda x: ("bond" in x), conf)) + config = self.parse_attribs(["mode", "primary"], lag_conf) + config["arp_monitor"] = self.parse_arp_monitor(arp_monitor_conf) + config["hash_policy"] = self.parse_hash_policy(hash_policy_conf) + + return utils.remove_empties(config) + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value: + config[item] = value.strip("'") + else: + config[item] = None + return utils.remove_empties(config) + + def parse_arp_monitor(self, conf): + arp_monitor = None + if conf: + arp_monitor = {} + target_list = [] + interval = search(r"^.*arp-monitor interval (.+)", conf, M) + targets = findall(r"^.*arp-monitor target '(.+)'", conf, M) + if targets: + for target in targets: + target_list.append(target) + arp_monitor["target"] = target_list + if interval: + value = interval.group(1).strip("'") + arp_monitor["interval"] = int(value) + return arp_monitor + + def parse_hash_policy(self, conf): + hash_policy = None + if conf: + hash_policy = search(r"^.*hash-policy (.+)", conf, M) + hash_policy = hash_policy.group(1).strip("'") + return hash_policy diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py new file mode 100644 index 0000000..f6b343e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The VyOS interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +import platform +import re +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + run_commands, + get_capabilities, +) + + +class LegacyFactsBase(object): + + COMMANDS = frozenset() + + def __init__(self, module): + self.module = module + self.facts = dict() + self.warnings = list() + self.responses = None + + def populate(self): + self.responses = run_commands(self.module, list(self.COMMANDS)) + + +class Default(LegacyFactsBase): + + COMMANDS = [ + "show version", + ] + + def populate(self): + super(Default, self).populate() + data = self.responses[0] + self.facts["serialnum"] = self.parse_serialnum(data) + self.facts.update(self.platform_facts()) + + def parse_serialnum(self, data): + match = re.search(r"HW S/N:\s+(\S+)", data) + if match: + return match.group(1) + + def platform_facts(self): + platform_facts = {} + + resp = get_capabilities(self.module) + device_info = resp["device_info"] + + platform_facts["system"] = device_info["network_os"] + + for item in ("model", "image", "version", "platform", "hostname"): + val = device_info.get("network_os_%s" % item) + if val: + platform_facts[item] = val + + platform_facts["api"] = resp["network_api"] + platform_facts["python_version"] = platform.python_version() + + return platform_facts + + +class Config(LegacyFactsBase): + + COMMANDS = [ + "show configuration commands", + "show system commit", + ] + + def populate(self): + super(Config, self).populate() + + self.facts["config"] = self.responses + + commits = self.responses[1] + entries = list() + entry = None + + for line in commits.split("\n"): + match = re.match(r"(\d+)\s+(.+)by(.+)via(.+)", line) + if match: + if entry: + entries.append(entry) + + entry = dict( + revision=match.group(1), + datetime=match.group(2), + by=str(match.group(3)).strip(), + via=str(match.group(4)).strip(), + comment=None, + ) + else: + entry["comment"] = line.strip() + + self.facts["commits"] = entries + + +class Neighbors(LegacyFactsBase): + + COMMANDS = [ + "show lldp neighbors", + "show lldp neighbors detail", + ] + + def populate(self): + super(Neighbors, self).populate() + + all_neighbors = self.responses[0] + if "LLDP not configured" not in all_neighbors: + neighbors = self.parse(self.responses[1]) + self.facts["neighbors"] = self.parse_neighbors(neighbors) + + def parse(self, data): + parsed = list() + values = None + for line in data.split("\n"): + if not line: + continue + elif line[0] == " ": + values += "\n%s" % line + elif line.startswith("Interface"): + if values: + parsed.append(values) + values = line + if values: + parsed.append(values) + return parsed + + def parse_neighbors(self, data): + facts = dict() + for item in data: + interface = self.parse_interface(item) + host = self.parse_host(item) + port = self.parse_port(item) + if interface not in facts: + facts[interface] = list() + facts[interface].append(dict(host=host, port=port)) + return facts + + def parse_interface(self, data): + match = re.search(r"^Interface:\s+(\S+),", data) + return match.group(1) + + def parse_host(self, data): + match = re.search(r"SysName:\s+(.+)$", data, re.M) + if match: + return match.group(1) + + def parse_port(self, data): + match = re.search(r"PortDescr:\s+(.+)$", data, re.M) + if match: + return match.group(1) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py new file mode 100644 index 0000000..3c7e2f9 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py @@ -0,0 +1,116 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos lldp_global fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from re import findall, M +from copy import deepcopy + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_global.lldp_global import ( + Lldp_globalArgs, +) + + +class Lldp_globalFacts(object): + """ The vyos lldp_global fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Lldp_globalArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for lldp_global + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + objs = {} + lldp_output = findall(r"^set service lldp (\S+)", data, M) + if lldp_output: + for item in set(lldp_output): + lldp_regex = r" %s .+$" % item + cfg = findall(lldp_regex, data, M) + obj = self.render_config(cfg) + if obj: + objs.update(obj) + lldp_service = findall(r"^set service (lldp)?('lldp')", data, M) + if lldp_service or lldp_output: + lldp_obj = {} + lldp_obj["enable"] = True + objs.update(lldp_obj) + + facts = {} + params = utils.validate_config(self.argument_spec, {"config": objs}) + facts["lldp_global"] = utils.remove_empties(params["config"]) + + ansible_facts["ansible_network_resources"].update(facts) + + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + protocol_conf = "\n".join( + filter(lambda x: ("legacy-protocols" in x), conf) + ) + att_conf = "\n".join( + filter(lambda x: ("legacy-protocols" not in x), conf) + ) + config = self.parse_attribs(["snmp", "address"], att_conf) + config["legacy_protocols"] = self.parse_protocols(protocol_conf) + return utils.remove_empties(config) + + def parse_protocols(self, conf): + protocol_support = None + if conf: + protocols = findall(r"^.*legacy-protocols (.+)", conf, M) + if protocols: + protocol_support = [] + for protocol in protocols: + protocol_support.append(protocol.strip("'")) + return protocol_support + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value: + config[item] = value.strip("'") + else: + config[item] = None + return utils.remove_empties(config) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py new file mode 100644 index 0000000..dcfbc6e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py @@ -0,0 +1,155 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos lldp_interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +from re import findall, search, M +from copy import deepcopy + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_interfaces.lldp_interfaces import ( + Lldp_interfacesArgs, +) + + +class Lldp_interfacesFacts(object): + """ The vyos lldp_interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Lldp_interfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for lldp_interfaces + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + objs = [] + lldp_names = findall(r"^set service lldp interface (\S+)", data, M) + if lldp_names: + for lldp in set(lldp_names): + lldp_regex = r" %s .+$" % lldp + cfg = findall(lldp_regex, data, M) + obj = self.render_config(cfg) + obj["name"] = lldp.strip("'") + if obj: + objs.append(obj) + facts = {} + if objs: + facts["lldp_interfaces"] = objs + ansible_facts["ansible_network_resources"].update(facts) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + config = {} + location = {} + + civic_conf = "\n".join(filter(lambda x: ("civic-based" in x), conf)) + elin_conf = "\n".join(filter(lambda x: ("elin" in x), conf)) + coordinate_conf = "\n".join( + filter(lambda x: ("coordinate-based" in x), conf) + ) + disable = "\n".join(filter(lambda x: ("disable" in x), conf)) + + coordinate_based_conf = self.parse_attribs( + ["altitude", "datum", "longitude", "latitude"], coordinate_conf + ) + elin_based_conf = self.parse_lldp_elin_based(elin_conf) + civic_based_conf = self.parse_lldp_civic_based(civic_conf) + if disable: + config["enable"] = False + if coordinate_conf: + location["coordinate_based"] = coordinate_based_conf + config["location"] = location + elif civic_based_conf: + location["civic_based"] = civic_based_conf + config["location"] = location + elif elin_conf: + location["elin"] = elin_based_conf + config["location"] = location + + return utils.remove_empties(config) + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value: + value = value.strip("'") + if item == "altitude": + value = int(value) + config[item] = value + else: + config[item] = None + return utils.remove_empties(config) + + def parse_lldp_civic_based(self, conf): + civic_based = None + if conf: + civic_info_list = [] + civic_add_list = findall(r"^.*civic-based ca-type (.+)", conf, M) + if civic_add_list: + for civic_add in civic_add_list: + ca = civic_add.split(" ") + c_add = {} + c_add["ca_type"] = int(ca[0].strip("'")) + c_add["ca_value"] = ca[2].strip("'") + civic_info_list.append(c_add) + + country_code = search( + r"^.*civic-based country-code (.+)", conf, M + ) + civic_based = {} + civic_based["ca_info"] = civic_info_list + civic_based["country_code"] = country_code.group(1).strip("'") + return civic_based + + def parse_lldp_elin_based(self, conf): + elin_based = None + if conf: + e_num = search(r"^.* elin (.+)", conf, M) + elin_based = e_num.group(1).strip("'") + + return elin_based diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py new file mode 100644 index 0000000..0004947 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py @@ -0,0 +1,181 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos static_routes fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from re import findall, search, M +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.static_routes.static_routes import ( + Static_routesArgs, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.utils.utils import ( + get_route_type, +) + + +class Static_routesFacts(object): + """ The vyos static_routes fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Static_routesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def get_device_data(self, connection): + return connection.get_config() + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for static_routes + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = self.get_device_data(connection) + # typically data is populated from the current device configuration + # data = connection.get('show running-config | section ^interface') + # using mock data instead + objs = [] + r_v4 = [] + r_v6 = [] + af = [] + static_routes = findall( + r"set protocols static route(6)? (\S+)", data, M + ) + if static_routes: + for route in set(static_routes): + route_regex = r" %s .+$" % route[1] + cfg = findall(route_regex, data, M) + sr = self.render_config(cfg) + sr["dest"] = route[1].strip("'") + afi = self.get_afi(sr["dest"]) + if afi == "ipv4": + r_v4.append(sr) + else: + r_v6.append(sr) + if r_v4: + afi_v4 = {"afi": "ipv4", "routes": r_v4} + af.append(afi_v4) + if r_v6: + afi_v6 = {"afi": "ipv6", "routes": r_v6} + af.append(afi_v6) + config = {"address_families": af} + if config: + objs.append(config) + + ansible_facts["ansible_network_resources"].pop("static_routes", None) + facts = {} + if objs: + facts["static_routes"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["static_routes"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + next_hops_conf = "\n".join(filter(lambda x: ("next-hop" in x), conf)) + blackhole_conf = "\n".join(filter(lambda x: ("blackhole" in x), conf)) + routes_dict = { + "blackhole_config": self.parse_blackhole(blackhole_conf), + "next_hops": self.parse_next_hop(next_hops_conf), + } + return routes_dict + + def parse_blackhole(self, conf): + blackhole = None + if conf: + distance = search(r"^.*blackhole distance (.\S+)", conf, M) + bh = conf.find("blackhole") + if distance is not None: + blackhole = {} + value = distance.group(1).strip("'") + blackhole["distance"] = int(value) + elif bh: + blackhole = {} + blackhole["type"] = "blackhole" + return blackhole + + def get_afi(self, address): + route_type = get_route_type(address) + if route_type == "route": + return "ipv4" + elif route_type == "route6": + return "ipv6" + + def parse_next_hop(self, conf): + nh_list = None + if conf: + nh_list = [] + hop_list = findall(r"^.*next-hop (.+)", conf, M) + if hop_list: + for hop in hop_list: + distance = search(r"^.*distance (.\S+)", hop, M) + interface = search(r"^.*interface (.\S+)", hop, M) + + dis = hop.find("disable") + hop_info = hop.split(" ") + nh_info = { + "forward_router_address": hop_info[0].strip("'") + } + if interface: + nh_info["interface"] = interface.group(1).strip("'") + if distance: + value = distance.group(1).strip("'") + nh_info["admin_distance"] = int(value) + elif dis >= 1: + nh_info["enabled"] = False + for element in nh_list: + if ( + element["forward_router_address"] + == nh_info["forward_router_address"] + ): + if "interface" in nh_info.keys(): + element["interface"] = nh_info["interface"] + if "admin_distance" in nh_info.keys(): + element["admin_distance"] = nh_info[ + "admin_distance" + ] + if "enabled" in nh_info.keys(): + element["enabled"] = nh_info["enabled"] + nh_info = None + if nh_info is not None: + nh_list.append(nh_info) + return nh_list diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py new file mode 100644 index 0000000..402adfc --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# utils +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from ansible.module_utils.six import iteritems +from ansible_collections.ansible.netcommon.plugins.module_utils.compat import ( + ipaddress, +) + + +def search_obj_in_list(name, lst, key="name"): + for item in lst: + if item[key] == name: + return item + return None + + +def get_interface_type(interface): + """Gets the type of interface + """ + if interface.startswith("eth"): + return "ethernet" + elif interface.startswith("bond"): + return "bonding" + elif interface.startswith("vti"): + return "vti" + elif interface.startswith("lo"): + return "loopback" + + +def dict_delete(base, comparable): + """ + This function generates a dict containing key, value pairs for keys + that are present in the `base` dict but not present in the `comparable` + dict. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + :returns: new dict object with key, value pairs that needs to be deleted. + + """ + to_delete = dict() + + for key in base: + if isinstance(base[key], dict): + sub_diff = dict_delete(base[key], comparable.get(key, {})) + if sub_diff: + to_delete[key] = sub_diff + else: + if key not in comparable: + to_delete[key] = base[key] + + return to_delete + + +def diff_list_of_dicts(want, have): + diff = [] + + set_w = set(tuple(d.items()) for d in want) + set_h = set(tuple(d.items()) for d in have) + difference = set_w.difference(set_h) + + for element in difference: + diff.append(dict((x, y) for x, y in element)) + + return diff + + +def get_lst_diff_for_dicts(want, have, lst): + """ + This function generates a list containing values + that are only in want and not in list in have dict + :param want: dict object to want + :param have: dict object to have + :param lst: list the diff on + :return: new list object with values which are only in want. + """ + if not have: + diff = want.get(lst) or [] + + else: + want_elements = want.get(lst) or {} + have_elements = have.get(lst) or {} + diff = list_diff_want_only(want_elements, have_elements) + return diff + + +def get_lst_same_for_dicts(want, have, lst): + """ + This function generates a list containing values + that are common for list in want and list in have dict + :param want: dict object to want + :param have: dict object to have + :param lst: list the comparison on + :return: new list object with values which are common in want and have. + """ + diff = None + if want and have: + want_list = want.get(lst) or {} + have_list = have.get(lst) or {} + diff = [ + i + for i in want_list and have_list + if i in have_list and i in want_list + ] + return diff + + +def list_diff_have_only(want_list, have_list): + """ + This function generated the list containing values + that are only in have list. + :param want_list: + :param have_list: + :return: new list with values which are only in have list + """ + if have_list and not want_list: + diff = have_list + elif not have_list: + diff = None + else: + diff = [ + i + for i in have_list + want_list + if i in have_list and i not in want_list + ] + return diff + + +def list_diff_want_only(want_list, have_list): + """ + This function generated the list containing values + that are only in want list. + :param want_list: + :param have_list: + :return: new list with values which are only in want list + """ + if have_list and not want_list: + diff = None + elif not have_list: + diff = want_list + else: + diff = [ + i + for i in have_list + want_list + if i in want_list and i not in have_list + ] + return diff + + +def search_dict_tv_in_list(d_val1, d_val2, lst, key1, key2): + """ + This function return the dict object if it exist in list. + :param d_val1: + :param d_val2: + :param lst: + :param key1: + :param key2: + :return: + """ + obj = next( + ( + item + for item in lst + if item[key1] == d_val1 and item[key2] == d_val2 + ), + None, + ) + if obj: + return obj + else: + return None + + +def key_value_in_dict(have_key, have_value, want_dict): + """ + This function checks whether the key and values exist in dict + :param have_key: + :param have_value: + :param want_dict: + :return: + """ + for key, value in iteritems(want_dict): + if key == have_key and value == have_value: + return True + return False + + +def is_dict_element_present(dict, key): + """ + This function checks whether the key is present in dict. + :param dict: + :param key: + :return: + """ + for item in dict: + if item == key: + return True + return False + + +def get_ip_address_version(address): + """ + This function returns the version of IP address + :param address: IP address + :return: + """ + try: + address = unicode(address) + except NameError: + address = str(address) + version = ipaddress.ip_address(address.split("/")[0]).version + return version + + +def get_route_type(address): + """ + This function returns the route type based on IP address + :param address: + :return: + """ + version = get_ip_address_version(address) + if version == 6: + return "route6" + elif version == 4: + return "route" diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py new file mode 100644 index 0000000..908395a --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py @@ -0,0 +1,124 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import env_fallback +from ansible.module_utils.connection import Connection, ConnectionError + +_DEVICE_CONFIGS = {} + +vyos_provider_spec = { + "host": dict(), + "port": dict(type="int"), + "username": dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])), + "password": dict( + fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]), no_log=True + ), + "ssh_keyfile": dict( + fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path" + ), + "timeout": dict(type="int"), +} +vyos_argument_spec = { + "provider": dict( + type="dict", options=vyos_provider_spec, removed_in_version=2.14 + ), +} + + +def get_provider_argspec(): + return vyos_provider_spec + + +def get_connection(module): + if hasattr(module, "_vyos_connection"): + return module._vyos_connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api == "cliconf": + module._vyos_connection = Connection(module._socket_path) + else: + module.fail_json(msg="Invalid connection type %s" % network_api) + + return module._vyos_connection + + +def get_capabilities(module): + if hasattr(module, "_vyos_capabilities"): + return module._vyos_capabilities + + try: + capabilities = Connection(module._socket_path).get_capabilities() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + module._vyos_capabilities = json.loads(capabilities) + return module._vyos_capabilities + + +def get_config(module, flags=None, format=None): + flags = [] if flags is None else flags + global _DEVICE_CONFIGS + + if _DEVICE_CONFIGS != {}: + return _DEVICE_CONFIGS + else: + connection = get_connection(module) + try: + out = connection.get_config(flags=flags, format=format) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + cfg = to_text(out, errors="surrogate_then_replace").strip() + _DEVICE_CONFIGS = cfg + return cfg + + +def run_commands(module, commands, check_rc=True): + connection = get_connection(module) + try: + response = connection.run_commands( + commands=commands, check_rc=check_rc + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + return response + + +def load_config(module, commands, commit=False, comment=None): + connection = get_connection(module) + + try: + response = connection.edit_config( + candidate=commands, commit=commit, comment=comment + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + return response.get("diff") diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py new file mode 100644 index 0000000..1853849 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py @@ -0,0 +1,223 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: vyos_command +author: Nathaniel Case (@Qalthos) +short_description: Run one or more commands on VyOS devices +description: +- The command module allows running one or more commands on remote devices running + VyOS. This module can also be introspected to validate key parameters before returning + successfully. If the conditional statements are not met in the wait period, the + task fails. +- Certain C(show) commands in VyOS produce many lines of output and use a custom pager + that can cause this module to hang. If the value of the environment variable C(ANSIBLE_VYOS_TERMINAL_LENGTH) + is not set, the default number of 10000 is used. +extends_documentation_fragment: +- vyos.vyos.vyos +options: + commands: + description: + - The ordered set of commands to execute on the remote device running VyOS. The + output from the command execution is returned to the playbook. If the I(wait_for) + argument is provided, the module is not returned until the condition is satisfied + or the number of retries has been exceeded. + required: true + wait_for: + description: + - Specifies what to evaluate from the output of the command and what conditionals + to apply. This argument will cause the task to wait for a particular conditional + to be true before moving forward. If the conditional is not true by the configured + I(retries), the task fails. See examples. + aliases: + - waitfor + match: + description: + - The I(match) argument is used in conjunction with the I(wait_for) argument to + specify the match policy. Valid values are C(all) or C(any). If the value is + set to C(all) then all conditionals in the wait_for must be satisfied. If the + value is set to C(any) then only one of the values must be satisfied. + default: all + choices: + - any + - all + retries: + description: + - Specifies the number of retries a command should be tried before it is considered + failed. The command is run on the target device every retry and evaluated against + the I(wait_for) conditionals. + default: 10 + interval: + description: + - Configures the interval in seconds to wait between I(retries) of the command. + If the command does not pass the specified conditions, the interval indicates + how long to wait before trying the command again. + default: 1 +notes: +- Tested against VyOS 1.1.8 (helium). +- Running C(show system boot-messages all) will cause the module to hang since VyOS + is using a custom pager setting to display the output of that command. +- If a command sent to the device requires answering a prompt, it is possible to pass + a dict containing I(command), I(answer) and I(prompt). See examples. +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +""" + +EXAMPLES = """ +tasks: + - name: show configuration on ethernet devices eth0 and eth1 + vyos_command: + commands: + - show interfaces ethernet {{ item }} + with_items: + - eth0 + - eth1 + + - name: run multiple commands and check if version output contains specific version string + vyos_command: + commands: + - show version + - show hardware cpu + wait_for: + - "result[0] contains 'VyOS 1.1.7'" + + - name: run command that requires answering a prompt + vyos_command: + commands: + - command: 'rollback 1' + prompt: 'Proceed with reboot? [confirm][y]' + answer: y +""" + +RETURN = """ +stdout: + description: The set of responses from the commands + returned: always apart from low level errors (such as action plugin) + type: list + sample: ['...', '...'] +stdout_lines: + description: The value of stdout split into a list + returned: always + type: list + sample: [['...', '...'], ['...'], ['...']] +failed_conditions: + description: The list of conditionals that have failed + returned: failed + type: list + sample: ['...', '...'] +warnings: + description: The list of warnings (if any) generated by module based on arguments + returned: always + type: list + sample: ['...', '...'] +""" +import time + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import ( + Conditional, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + transform_commands, + to_lines, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + run_commands, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_argument_spec, +) + + +def parse_commands(module, warnings): + commands = transform_commands(module) + + if module.check_mode: + for item in list(commands): + if not item["command"].startswith("show"): + warnings.append( + "Only show commands are supported when using check mode, not " + "executing %s" % item["command"] + ) + commands.remove(item) + + return commands + + +def main(): + spec = dict( + commands=dict(type="list", required=True), + wait_for=dict(type="list", aliases=["waitfor"]), + match=dict(default="all", choices=["all", "any"]), + retries=dict(default=10, type="int"), + interval=dict(default=1, type="int"), + ) + + spec.update(vyos_argument_spec) + + module = AnsibleModule(argument_spec=spec, supports_check_mode=True) + + warnings = list() + result = {"changed": False, "warnings": warnings} + commands = parse_commands(module, warnings) + wait_for = module.params["wait_for"] or list() + + try: + conditionals = [Conditional(c) for c in wait_for] + except AttributeError as exc: + module.fail_json(msg=to_text(exc)) + + retries = module.params["retries"] + interval = module.params["interval"] + match = module.params["match"] + + for _ in range(retries): + responses = run_commands(module, commands) + + for item in list(conditionals): + if item(responses): + if match == "any": + conditionals = list() + break + conditionals.remove(item) + + if not conditionals: + break + + time.sleep(interval) + + if conditionals: + failed_conditions = [item.raw for item in conditionals] + msg = "One or more conditional statements have not been satisfied" + module.fail_json(msg=msg, failed_conditions=failed_conditions) + + result.update( + {"stdout": responses, "stdout_lines": list(to_lines(responses)),} + ) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py new file mode 100644 index 0000000..b899045 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py @@ -0,0 +1,354 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: vyos_config +author: Nathaniel Case (@Qalthos) +short_description: Manage VyOS configuration on remote device +description: +- This module provides configuration file management of VyOS devices. It provides + arguments for managing both the configuration file and state of the active configuration. + All configuration statements are based on `set` and `delete` commands in the device + configuration. +extends_documentation_fragment: +- vyos.vyos.vyos +notes: +- Tested against VyOS 1.1.8 (helium). +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +options: + lines: + description: + - The ordered set of configuration lines to be managed and compared with the existing + configuration on the remote device. + src: + description: + - The C(src) argument specifies the path to the source config file to load. The + source config file can either be in bracket format or set format. The source + file can include Jinja2 template variables. + match: + description: + - The C(match) argument controls the method used to match against the current + active configuration. By default, the desired config is matched against the + active config and the deltas are loaded. If the C(match) argument is set to + C(none) the active configuration is ignored and the configuration is always + loaded. + default: line + choices: + - line + - none + backup: + description: + - The C(backup) argument will backup the current devices active configuration + to the Ansible control host prior to making any changes. If the C(backup_options) + value is not given, the backup file will be located in the backup folder in + the playbook root directory or role root directory, if playbook is part of an + ansible role. If the directory does not exist, it is created. + type: bool + default: 'no' + comment: + description: + - Allows a commit description to be specified to be included when the configuration + is committed. If the configuration is not changed or committed, this argument + is ignored. + default: configured by vyos_config + config: + description: + - The C(config) argument specifies the base configuration to use to compare against + the desired configuration. If this value is not specified, the module will + automatically retrieve the current active configuration from the remote device. + save: + description: + - The C(save) argument controls whether or not changes made to the active configuration + are saved to disk. This is independent of committing the config. When set + to True, the active configuration is saved. + type: bool + default: 'no' + backup_options: + description: + - This is a dict object containing configurable options related to backup file + path. The value of this option is read only when C(backup) is set to I(yes), + if C(backup) is set to I(no) this option will be silently ignored. + suboptions: + filename: + description: + - The filename to be used to store the backup configuration. If the filename + is not given it will be generated based on the hostname, current time and + date in format defined by <hostname>_config.<current-date>@<current-time> + dir_path: + description: + - This option provides the path ending with directory name in which the backup + configuration file will be stored. If the directory does not exist it will + be first created and the filename is either the value of C(filename) or + default filename as described in C(filename) options description. If the + path value is not given in that case a I(backup) directory will be created + in the current working directory and backup configuration will be copied + in C(filename) within I(backup) directory. + type: path + type: dict +""" + +EXAMPLES = """ +- name: configure the remote device + vyos_config: + lines: + - set system host-name {{ inventory_hostname }} + - set service lldp + - delete service dhcp-server + +- name: backup and load from file + vyos_config: + src: vyos.cfg + backup: yes + +- name: render a Jinja2 template onto the VyOS router + vyos_config: + src: vyos_template.j2 + +- name: for idempotency, use full-form commands + vyos_config: + lines: + # - set int eth eth2 description 'OUTSIDE' + - set interface ethernet eth2 description 'OUTSIDE' + +- name: configurable backup path + vyos_config: + backup: yes + backup_options: + filename: backup.cfg + dir_path: /home/user +""" + +RETURN = """ +commands: + description: The list of configuration commands sent to the device + returned: always + type: list + sample: ['...', '...'] +filtered: + description: The list of configuration commands removed to avoid a load failure + returned: always + type: list + sample: ['...', '...'] +backup_path: + description: The full path to the backup file + returned: when backup is yes + type: str + sample: /playbooks/ansible/backup/vyos_config.2016-07-16@22:28:34 +filename: + description: The name of the backup file + returned: when backup is yes and filename is not specified in backup options + type: str + sample: vyos_config.2016-07-16@22:28:34 +shortname: + description: The full path to the backup file excluding the timestamp + returned: when backup is yes and filename is not specified in backup options + type: str + sample: /playbooks/ansible/backup/vyos_config +date: + description: The date extracted from the backup file name + returned: when backup is yes + type: str + sample: "2016-07-16" +time: + description: The time extracted from the backup file name + returned: when backup is yes + type: str + sample: "22:28:34" +""" +import re + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.connection import ConnectionError +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + load_config, + get_config, + run_commands, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_argument_spec, + get_connection, +) + + +DEFAULT_COMMENT = "configured by vyos_config" + +CONFIG_FILTERS = [ + re.compile(r"set system login user \S+ authentication encrypted-password") +] + + +def get_candidate(module): + contents = module.params["src"] or module.params["lines"] + + if module.params["src"]: + contents = format_commands(contents.splitlines()) + + contents = "\n".join(contents) + return contents + + +def format_commands(commands): + """ + This function format the input commands and removes the prepend white spaces + for command lines having 'set' or 'delete' and it skips empty lines. + :param commands: + :return: list of commands + """ + return [ + line.strip() if line.split()[0] in ("set", "delete") else line + for line in commands + if len(line.strip()) > 0 + ] + + +def diff_config(commands, config): + config = [str(c).replace("'", "") for c in config.splitlines()] + + updates = list() + visited = set() + + for line in commands: + item = str(line).replace("'", "") + + if not item.startswith("set") and not item.startswith("delete"): + raise ValueError("line must start with either `set` or `delete`") + + elif item.startswith("set") and item not in config: + updates.append(line) + + elif item.startswith("delete"): + if not config: + updates.append(line) + else: + item = re.sub(r"delete", "set", item) + for entry in config: + if entry.startswith(item) and line not in visited: + updates.append(line) + visited.add(line) + + return list(updates) + + +def sanitize_config(config, result): + result["filtered"] = list() + index_to_filter = list() + for regex in CONFIG_FILTERS: + for index, line in enumerate(list(config)): + if regex.search(line): + result["filtered"].append(line) + index_to_filter.append(index) + # Delete all filtered configs + for filter_index in sorted(index_to_filter, reverse=True): + del config[filter_index] + + +def run(module, result): + # get the current active config from the node or passed in via + # the config param + config = module.params["config"] or get_config(module) + + # create the candidate config object from the arguments + candidate = get_candidate(module) + + # create loadable config that includes only the configuration updates + connection = get_connection(module) + try: + response = connection.get_diff( + candidate=candidate, + running=config, + diff_match=module.params["match"], + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + commands = response.get("config_diff") + sanitize_config(commands, result) + + result["commands"] = commands + + commit = not module.check_mode + comment = module.params["comment"] + + diff = None + if commands: + diff = load_config(module, commands, commit=commit, comment=comment) + + if result.get("filtered"): + result["warnings"].append( + "Some configuration commands were " + "removed, please see the filtered key" + ) + + result["changed"] = True + + if module._diff: + result["diff"] = {"prepared": diff} + + +def main(): + backup_spec = dict(filename=dict(), dir_path=dict(type="path")) + argument_spec = dict( + src=dict(type="path"), + lines=dict(type="list"), + match=dict(default="line", choices=["line", "none"]), + comment=dict(default=DEFAULT_COMMENT), + config=dict(), + backup=dict(type="bool", default=False), + backup_options=dict(type="dict", options=backup_spec), + save=dict(type="bool", default=False), + ) + + argument_spec.update(vyos_argument_spec) + + mutually_exclusive = [("lines", "src")] + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=mutually_exclusive, + supports_check_mode=True, + ) + + warnings = list() + + result = dict(changed=False, warnings=warnings) + + if module.params["backup"]: + result["__backup__"] = get_config(module=module) + + if any((module.params["src"], module.params["lines"])): + run(module, result) + + if module.params["save"]: + diff = run_commands(module, commands=["configure", "compare saved"])[1] + if diff != "[edit]": + run_commands(module, commands=["save"]) + result["changed"] = True + run_commands(module, commands=["exit"]) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py new file mode 100644 index 0000000..19fb727 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py @@ -0,0 +1,174 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The module file for vyos_facts +""" + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": [u"preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: vyos_facts +short_description: Get facts about vyos devices. +description: +- Collects facts from network devices running the vyos operating system. This module + places the facts gathered in the fact tree keyed by the respective resource name. The + facts module will always collect a base set of facts from the device and can enable + or disable collection of additional facts. +author: +- Nathaniel Case (@qalthos) +- Nilashish Chakraborty (@Nilashishc) +- Rohit Thakur (@rohitthakur2590) +extends_documentation_fragment: +- vyos.vyos.vyos +notes: +- Tested against VyOS 1.1.8 (helium). +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +options: + gather_subset: + description: + - When supplied, this argument will restrict the facts collected to a given subset. Possible + values for this argument include all, default, config, and neighbors. Can specify + a list of values to include a larger subset. Values can also be used with an + initial C(M(!)) to specify that a specific subset should not be collected. + required: false + default: '!config' + gather_network_resources: + description: + - When supplied, this argument will restrict the facts collected to a given subset. + Possible values for this argument include all and the resources like interfaces. + Can specify a list of values to include a larger subset. Values can also be + used with an initial C(M(!)) to specify that a specific subset should not be + collected. Valid subsets are 'all', 'interfaces', 'l3_interfaces', 'lag_interfaces', + 'lldp_global', 'lldp_interfaces', 'static_routes', 'firewall_rules'. + required: false +""" + +EXAMPLES = """ +# Gather all facts +- vyos_facts: + gather_subset: all + gather_network_resources: all + +# collect only the config and default facts +- vyos_facts: + gather_subset: config + +# collect everything exception the config +- vyos_facts: + gather_subset: "!config" + +# Collect only the interfaces facts +- vyos_facts: + gather_subset: + - '!all' + - '!min' + gather_network_resources: + - interfaces + +# Do not collect interfaces facts +- vyos_facts: + gather_network_resources: + - "!interfaces" + +# Collect interfaces and minimal default facts +- vyos_facts: + gather_subset: min + gather_network_resources: interfaces +""" + +RETURN = """ +ansible_net_config: + description: The running-config from the device + returned: when config is configured + type: str +ansible_net_commits: + description: The set of available configuration revisions + returned: when present + type: list +ansible_net_hostname: + description: The configured system hostname + returned: always + type: str +ansible_net_model: + description: The device model string + returned: always + type: str +ansible_net_serialnum: + description: The serial number of the device + returned: always + type: str +ansible_net_version: + description: The version of the software running + returned: always + type: str +ansible_net_neighbors: + description: The set of LLDP neighbors + returned: when interface is configured + type: list +ansible_net_gather_subset: + description: The list of subsets gathered by the module + returned: always + type: list +ansible_net_api: + description: The name of the transport + returned: always + type: str +ansible_net_python_version: + description: The Python version Ansible controller is using + returned: always + type: str +ansible_net_gather_network_resources: + description: The list of fact resource subsets collected from the device + returned: always + type: list +""" + +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.facts.facts import ( + FactsArgs, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.facts import ( + Facts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_argument_spec, +) + + +def main(): + """ + Main entry point for module execution + + :returns: ansible_facts + """ + argument_spec = FactsArgs.argument_spec + argument_spec.update(vyos_argument_spec) + + module = AnsibleModule( + argument_spec=argument_spec, supports_check_mode=True + ) + + warnings = [] + if module.params["gather_subset"] == "!config": + warnings.append( + "default value for `gather_subset` will be changed to `min` from `!config` v2.11 onwards" + ) + + result = Facts(module).get_facts() + + ansible_facts, additional_warnings = result + warnings.extend(additional_warnings) + + module.exit_json(ansible_facts=ansible_facts, warnings=warnings) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py new file mode 100644 index 0000000..8fe572b --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py @@ -0,0 +1,513 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# + +""" +The module file for vyos_lldp_interfaces +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + +DOCUMENTATION = """module: vyos_lldp_interfaces +short_description: Manages attributes of lldp interfaces on VyOS devices. +description: This module manages attributes of lldp interfaces on VyOS network devices. +notes: +- Tested against VyOS 1.1.8 (helium). +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +author: +- Rohit Thakur (@rohitthakur2590) +options: + config: + description: A list of lldp interfaces configurations. + type: list + suboptions: + name: + description: + - Name of the lldp interface. + type: str + required: true + enable: + description: + - to disable lldp on the interface. + type: bool + default: true + location: + description: + - LLDP-MED location data. + type: dict + suboptions: + civic_based: + description: + - Civic-based location data. + type: dict + suboptions: + ca_info: + description: LLDP-MED address info + type: list + suboptions: + ca_type: + description: LLDP-MED Civic Address type. + type: int + required: true + ca_value: + description: LLDP-MED Civic Address value. + type: str + required: true + country_code: + description: Country Code + type: str + required: true + coordinate_based: + description: + - Coordinate-based location. + type: dict + suboptions: + altitude: + description: Altitude in meters. + type: int + datum: + description: Coordinate datum type. + type: str + choices: + - WGS84 + - NAD83 + - MLLW + latitude: + description: Latitude. + type: str + required: true + longitude: + description: Longitude. + type: str + required: true + elin: + description: Emergency Call Service ELIN number (between 10-25 numbers). + type: str + state: + description: + - The state of the configuration after module completion. + type: str + choices: + - merged + - replaced + - overridden + - deleted + default: merged +""" +EXAMPLES = """ +# Using merged +# +# Before state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# +- name: Merge provided configuration with device configuration + vyos_lldp_interfaces: + config: + - name: 'eth1' + location: + civic_based: + country_code: 'US' + ca_info: + - ca_type: 0 + ca_value: 'ENGLISH' + + - name: 'eth2' + location: + coordinate_based: + altitude: 2200 + datum: 'WGS84' + longitude: '222.267255W' + latitude: '33.524449N' + state: merged +# +# +# ------------------------- +# Module Execution Result +# ------------------------- +# +# before": [] +# +# "commands": [ +# "set service lldp interface eth1 location civic-based country-code 'US'", +# "set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH'", +# "set service lldp interface eth1", +# "set service lldp interface eth2 location coordinate-based latitude '33.524449N'", +# "set service lldp interface eth2 location coordinate-based altitude '2200'", +# "set service lldp interface eth2 location coordinate-based datum 'WGS84'", +# "set service lldp interface eth2 location coordinate-based longitude '222.267255W'", +# "set service lldp interface eth2 location coordinate-based latitude '33.524449N'", +# "set service lldp interface eth2 location coordinate-based altitude '2200'", +# "set service lldp interface eth2 location coordinate-based datum 'WGS84'", +# "set service lldp interface eth2 location coordinate-based longitude '222.267255W'", +# "set service lldp interface eth2" +# +# "after": [ +# { +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth2" +# }, +# { +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth1" +# } +# ], +# +# After state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth1 location civic-based country-code 'US' +# set service lldp interface eth2 location coordinate-based altitude '2200' +# set service lldp interface eth2 location coordinate-based datum 'WGS84' +# set service lldp interface eth2 location coordinate-based latitude '33.524449N' +# set service lldp interface eth2 location coordinate-based longitude '222.267255W' + + +# Using replaced +# +# Before state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth1 location civic-based country-code 'US' +# set service lldp interface eth2 location coordinate-based altitude '2200' +# set service lldp interface eth2 location coordinate-based datum 'WGS84' +# set service lldp interface eth2 location coordinate-based latitude '33.524449N' +# set service lldp interface eth2 location coordinate-based longitude '222.267255W' +# +- name: Replace device configurations of listed LLDP interfaces with provided configurations + vyos_lldp_interfaces: + config: + - name: 'eth2' + location: + civic_based: + country_code: 'US' + ca_info: + - ca_type: 0 + ca_value: 'ENGLISH' + + - name: 'eth1' + location: + coordinate_based: + altitude: 2200 + datum: 'WGS84' + longitude: '222.267255W' + latitude: '33.524449N' + state: replaced +# +# +# ------------------------- +# Module Execution Result +# ------------------------- +# +# "before": [ +# { +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth2" +# }, +# { +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth1" +# } +# ] +# +# "commands": [ +# "delete service lldp interface eth2 location", +# "set service lldp interface eth2 'disable'", +# "set service lldp interface eth2 location civic-based country-code 'US'", +# "set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH'", +# "delete service lldp interface eth1 location", +# "set service lldp interface eth1 'disable'", +# "set service lldp interface eth1 location coordinate-based latitude '33.524449N'", +# "set service lldp interface eth1 location coordinate-based altitude '2200'", +# "set service lldp interface eth1 location coordinate-based datum 'WGS84'", +# "set service lldp interface eth1 location coordinate-based longitude '222.267255W'" +# ] +# +# "after": [ +# { +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth2" +# }, +# { +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth1" +# } +# ] +# +# After state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 'disable' +# set service lldp interface eth1 location coordinate-based altitude '2200' +# set service lldp interface eth1 location coordinate-based datum 'WGS84' +# set service lldp interface eth1 location coordinate-based latitude '33.524449N' +# set service lldp interface eth1 location coordinate-based longitude '222.267255W' +# set service lldp interface eth2 'disable' +# set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth2 location civic-based country-code 'US' + + +# Using overridden +# +# Before state +# -------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 'disable' +# set service lldp interface eth1 location coordinate-based altitude '2200' +# set service lldp interface eth1 location coordinate-based datum 'WGS84' +# set service lldp interface eth1 location coordinate-based latitude '33.524449N' +# set service lldp interface eth1 location coordinate-based longitude '222.267255W' +# set service lldp interface eth2 'disable' +# set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth2 location civic-based country-code 'US' +# +- name: Overrides all device configuration with provided configuration + vyos_lag_interfaces: + config: + - name: 'eth2' + location: + elin: 0000000911 + + state: overridden +# +# +# ------------------------- +# Module Execution Result +# ------------------------- +# +# "before": [ +# { +# "enable": false, +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth2" +# }, +# { +# "enable": false, +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth1" +# } +# ] +# +# "commands": [ +# "delete service lldp interface eth2 location", +# "delete service lldp interface eth2 disable", +# "set service lldp interface eth2 location elin 0000000911" +# +# +# "after": [ +# { +# "location": { +# "elin": 0000000911 +# }, +# "name": "eth2" +# } +# ] +# +# +# After state +# ------------ +# +# vyos@vyos# run show configuration commands | grep lldp +# set service lldp interface eth2 location elin '0000000911' + + +# Using deleted +# +# Before state +# ------------- +# +# vyos@vyos# run show configuration commands | grep lldp +# set service lldp interface eth2 location elin '0000000911' +# +- name: Delete lldp interface attributes of given interfaces. + vyos_lag_interfaces: + config: + - name: 'eth2' + state: deleted +# +# +# ------------------------ +# Module Execution Results +# ------------------------ +# + "before": [ + { + "location": { + "elin": 0000000911 + }, + "name": "eth2" + } + ] +# "commands": [ +# "commands": [ +# "delete service lldp interface eth2" +# ] +# +# "after": [] +# After state +# ------------ +# vyos@vyos# run show configuration commands | grep lldp +# set service 'lldp' + + +""" +RETURN = """ +before: + description: The configuration as structured data prior to module invocation. + returned: always + type: list + sample: > + The configuration returned will always be in the same format + of the parameters above. +after: + description: The configuration as structured data after module completion. + returned: when changed + type: list + sample: > + The configuration returned will always be in the same format + of the parameters above. +commands: + description: The set of commands pushed to the remote device. + returned: always + type: list + sample: + - "set service lldp interface eth2 'disable'" + - "delete service lldp interface eth1 location" +""" + + +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_interfaces.lldp_interfaces import ( + Lldp_interfacesArgs, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.config.lldp_interfaces.lldp_interfaces import ( + Lldp_interfaces, +) + + +def main(): + """ + Main entry point for module execution + + :returns: the result form module invocation + """ + required_if = [ + ("state", "merged", ("config",)), + ("state", "replaced", ("config",)), + ("state", "overridden", ("config",)), + ] + module = AnsibleModule( + argument_spec=Lldp_interfacesArgs.argument_spec, + required_if=required_if, + supports_check_mode=True, + ) + + result = Lldp_interfaces(module).execute_module() + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py new file mode 100644 index 0000000..fe7712f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py @@ -0,0 +1,53 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import re + +from ansible.plugins.terminal import TerminalBase +from ansible.errors import AnsibleConnectionFailure + + +class TerminalModule(TerminalBase): + + terminal_stdout_re = [ + re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"), + re.compile(br"\@[\w\-\.]+:\S+?[>#\$] ?$"), + ] + + terminal_stderr_re = [ + re.compile(br"\n\s*Invalid command:"), + re.compile(br"\nCommit failed"), + re.compile(br"\n\s+Set failed"), + ] + + terminal_length = os.getenv("ANSIBLE_VYOS_TERMINAL_LENGTH", 10000) + + def on_open_shell(self): + try: + for cmd in (b"set terminal length 0", b"set terminal width 512"): + self._exec_cli_command(cmd) + self._exec_cli_command( + b"set terminal length %d" % self.terminal_length + ) + except AnsibleConnectionFailure: + raise AnsibleConnectionFailure("unable to set terminal parameters") diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py new file mode 100644 index 0000000..adb918b --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py @@ -0,0 +1,522 @@ +# This file is part of Ansible + +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import base64 +import json +import os +import os.path +import shutil +import tempfile +import traceback +import zipfile + +from ansible import constants as C +from ansible.errors import AnsibleError, AnsibleFileNotFound +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.parsing.convert_bool import boolean +from ansible.plugins.action import ActionBase +from ansible.utils.hashing import checksum + + +def _walk_dirs(topdir, loader, decrypt=True, base_path=None, local_follow=False, trailing_slash_detector=None, checksum_check=False): + """ + Walk a filesystem tree returning enough information to copy the files. + This is similar to the _walk_dirs function in ``copy.py`` but returns + a dict instead of a tuple for each entry and includes the checksum of + a local file if wanted. + + :arg topdir: The directory that the filesystem tree is rooted at + :arg loader: The self._loader object from ActionBase + :kwarg decrypt: Whether to decrypt a file encrypted with ansible-vault + :kwarg base_path: The initial directory structure to strip off of the + files for the destination directory. If this is None (the default), + the base_path is set to ``top_dir``. + :kwarg local_follow: Whether to follow symlinks on the source. When set + to False, no symlinks are dereferenced. When set to True (the + default), the code will dereference most symlinks. However, symlinks + can still be present if needed to break a circular link. + :kwarg trailing_slash_detector: Function to determine if a path has + a trailing directory separator. Only needed when dealing with paths on + a remote machine (in which case, pass in a function that is aware of the + directory separator conventions on the remote machine). + :kawrg whether to get the checksum of the local file and add to the dict + :returns: dictionary of dictionaries. All of the path elements in the structure are text string. + This separates all the files, directories, and symlinks along with + import information about each:: + + { + 'files'; [{ + src: '/absolute/path/to/copy/from', + dest: 'relative/path/to/copy/to', + checksum: 'b54ba7f5621240d403f06815f7246006ef8c7d43' + }, ...], + 'directories'; [{ + src: '/absolute/path/to/copy/from', + dest: 'relative/path/to/copy/to' + }, ...], + 'symlinks'; [{ + src: '/symlink/target/path', + dest: 'relative/path/to/copy/to' + }, ...], + + } + + The ``symlinks`` field is only populated if ``local_follow`` is set to False + *or* a circular symlink cannot be dereferenced. The ``checksum`` entry is set + to None if checksum_check=False. + + """ + # Convert the path segments into byte strings + + r_files = {'files': [], 'directories': [], 'symlinks': []} + + def _recurse(topdir, rel_offset, parent_dirs, rel_base=u'', checksum_check=False): + """ + This is a closure (function utilizing variables from it's parent + function's scope) so that we only need one copy of all the containers. + Note that this function uses side effects (See the Variables used from + outer scope). + + :arg topdir: The directory we are walking for files + :arg rel_offset: Integer defining how many characters to strip off of + the beginning of a path + :arg parent_dirs: Directories that we're copying that this directory is in. + :kwarg rel_base: String to prepend to the path after ``rel_offset`` is + applied to form the relative path. + + Variables used from the outer scope + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + :r_files: Dictionary of files in the hierarchy. See the return value + for :func:`walk` for the structure of this dictionary. + :local_follow: Read-only inside of :func:`_recurse`. Whether to follow symlinks + """ + for base_path, sub_folders, files in os.walk(topdir): + for filename in files: + filepath = os.path.join(base_path, filename) + dest_filepath = os.path.join(rel_base, filepath[rel_offset:]) + + if os.path.islink(filepath): + # Dereference the symlnk + real_file = loader.get_real_file(os.path.realpath(filepath), decrypt=decrypt) + if local_follow and os.path.isfile(real_file): + # Add the file pointed to by the symlink + r_files['files'].append( + { + "src": real_file, + "dest": dest_filepath, + "checksum": _get_local_checksum(checksum_check, real_file) + } + ) + else: + # Mark this file as a symlink to copy + r_files['symlinks'].append({"src": os.readlink(filepath), "dest": dest_filepath}) + else: + # Just a normal file + real_file = loader.get_real_file(filepath, decrypt=decrypt) + r_files['files'].append( + { + "src": real_file, + "dest": dest_filepath, + "checksum": _get_local_checksum(checksum_check, real_file) + } + ) + + for dirname in sub_folders: + dirpath = os.path.join(base_path, dirname) + dest_dirpath = os.path.join(rel_base, dirpath[rel_offset:]) + real_dir = os.path.realpath(dirpath) + dir_stats = os.stat(real_dir) + + if os.path.islink(dirpath): + if local_follow: + if (dir_stats.st_dev, dir_stats.st_ino) in parent_dirs: + # Just insert the symlink if the target directory + # exists inside of the copy already + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Walk the dirpath to find all parent directories. + new_parents = set() + parent_dir_list = os.path.dirname(dirpath).split(os.path.sep) + for parent in range(len(parent_dir_list), 0, -1): + parent_stat = os.stat(u'/'.join(parent_dir_list[:parent])) + if (parent_stat.st_dev, parent_stat.st_ino) in parent_dirs: + # Reached the point at which the directory + # tree is already known. Don't add any + # more or we might go to an ancestor that + # isn't being copied. + break + new_parents.add((parent_stat.st_dev, parent_stat.st_ino)) + + if (dir_stats.st_dev, dir_stats.st_ino) in new_parents: + # This was a a circular symlink. So add it as + # a symlink + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Walk the directory pointed to by the symlink + r_files['directories'].append({"src": real_dir, "dest": dest_dirpath}) + offset = len(real_dir) + 1 + _recurse(real_dir, offset, parent_dirs.union(new_parents), + rel_base=dest_dirpath, + checksum_check=checksum_check) + else: + # Add the symlink to the destination + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Just a normal directory + r_files['directories'].append({"src": dirpath, "dest": dest_dirpath}) + + # Check if the source ends with a "/" so that we know which directory + # level to work at (similar to rsync) + source_trailing_slash = False + if trailing_slash_detector: + source_trailing_slash = trailing_slash_detector(topdir) + else: + source_trailing_slash = topdir.endswith(os.path.sep) + + # Calculate the offset needed to strip the base_path to make relative + # paths + if base_path is None: + base_path = topdir + if not source_trailing_slash: + base_path = os.path.dirname(base_path) + if topdir.startswith(base_path): + offset = len(base_path) + + # Make sure we're making the new paths relative + if trailing_slash_detector and not trailing_slash_detector(base_path): + offset += 1 + elif not base_path.endswith(os.path.sep): + offset += 1 + + if os.path.islink(topdir) and not local_follow: + r_files['symlinks'] = {"src": os.readlink(topdir), "dest": os.path.basename(topdir)} + return r_files + + dir_stats = os.stat(topdir) + parents = frozenset(((dir_stats.st_dev, dir_stats.st_ino),)) + # Actually walk the directory hierarchy + _recurse(topdir, offset, parents, checksum_check=checksum_check) + + return r_files + + +def _get_local_checksum(get_checksum, local_path): + if get_checksum: + return checksum(local_path) + else: + return None + + +class ActionModule(ActionBase): + + WIN_PATH_SEPARATOR = "\\" + + def _create_content_tempfile(self, content): + ''' Create a tempfile containing defined content ''' + fd, content_tempfile = tempfile.mkstemp(dir=C.DEFAULT_LOCAL_TMP) + f = os.fdopen(fd, 'wb') + content = to_bytes(content) + try: + f.write(content) + except Exception as err: + os.remove(content_tempfile) + raise Exception(err) + finally: + f.close() + return content_tempfile + + def _create_zip_tempfile(self, files, directories): + tmpdir = tempfile.mkdtemp(dir=C.DEFAULT_LOCAL_TMP) + zip_file_path = os.path.join(tmpdir, "win_copy.zip") + zip_file = zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_STORED, True) + + # encoding the file/dir name with base64 so Windows can unzip a unicode + # filename and get the right name, Windows doesn't handle unicode names + # very well + for directory in directories: + directory_path = to_bytes(directory['src'], errors='surrogate_or_strict') + archive_path = to_bytes(directory['dest'], errors='surrogate_or_strict') + + encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict') + zip_file.write(directory_path, encoded_path, zipfile.ZIP_DEFLATED) + + for file in files: + file_path = to_bytes(file['src'], errors='surrogate_or_strict') + archive_path = to_bytes(file['dest'], errors='surrogate_or_strict') + + encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict') + zip_file.write(file_path, encoded_path, zipfile.ZIP_DEFLATED) + + return zip_file_path + + def _remove_tempfile_if_content_defined(self, content, content_tempfile): + if content is not None: + os.remove(content_tempfile) + + def _copy_single_file(self, local_file, dest, source_rel, task_vars, tmp, backup): + if self._play_context.check_mode: + module_return = dict(changed=True) + return module_return + + # copy the file across to the server + tmp_src = self._connection._shell.join_path(tmp, 'source') + self._transfer_file(local_file, tmp_src) + + copy_args = self._task.args.copy() + copy_args.update( + dict( + dest=dest, + src=tmp_src, + _original_basename=source_rel, + _copy_mode="single", + backup=backup, + ) + ) + copy_args.pop('content', None) + + copy_result = self._execute_module(module_name="copy", + module_args=copy_args, + task_vars=task_vars) + + return copy_result + + def _copy_zip_file(self, dest, files, directories, task_vars, tmp, backup): + # create local zip file containing all the files and directories that + # need to be copied to the server + if self._play_context.check_mode: + module_return = dict(changed=True) + return module_return + + try: + zip_file = self._create_zip_tempfile(files, directories) + except Exception as e: + module_return = dict( + changed=False, + failed=True, + msg="failed to create tmp zip file: %s" % to_text(e), + exception=traceback.format_exc() + ) + return module_return + + zip_path = self._loader.get_real_file(zip_file) + + # send zip file to remote, file must end in .zip so + # Com Shell.Application works + tmp_src = self._connection._shell.join_path(tmp, 'source.zip') + self._transfer_file(zip_path, tmp_src) + + # run the explode operation of win_copy on remote + copy_args = self._task.args.copy() + copy_args.update( + dict( + src=tmp_src, + dest=dest, + _copy_mode="explode", + backup=backup, + ) + ) + copy_args.pop('content', None) + module_return = self._execute_module(module_name='copy', + module_args=copy_args, + task_vars=task_vars) + shutil.rmtree(os.path.dirname(zip_path)) + return module_return + + def run(self, tmp=None, task_vars=None): + ''' handler for file transfer operations ''' + if task_vars is None: + task_vars = dict() + + result = super(ActionModule, self).run(tmp, task_vars) + del tmp # tmp no longer has any effect + + source = self._task.args.get('src', None) + content = self._task.args.get('content', None) + dest = self._task.args.get('dest', None) + remote_src = boolean(self._task.args.get('remote_src', False), strict=False) + local_follow = boolean(self._task.args.get('local_follow', False), strict=False) + force = boolean(self._task.args.get('force', True), strict=False) + decrypt = boolean(self._task.args.get('decrypt', True), strict=False) + backup = boolean(self._task.args.get('backup', False), strict=False) + + result['src'] = source + result['dest'] = dest + + result['failed'] = True + if (source is None and content is None) or dest is None: + result['msg'] = "src (or content) and dest are required" + elif source is not None and content is not None: + result['msg'] = "src and content are mutually exclusive" + elif content is not None and dest is not None and ( + dest.endswith(os.path.sep) or dest.endswith(self.WIN_PATH_SEPARATOR)): + result['msg'] = "dest must be a file if content is defined" + else: + del result['failed'] + + if result.get('failed'): + return result + + # If content is defined make a temp file and write the content into it + content_tempfile = None + if content is not None: + try: + # if content comes to us as a dict it should be decoded json. + # We need to encode it back into a string and write it out + if isinstance(content, dict) or isinstance(content, list): + content_tempfile = self._create_content_tempfile(json.dumps(content)) + else: + content_tempfile = self._create_content_tempfile(content) + source = content_tempfile + except Exception as err: + result['failed'] = True + result['msg'] = "could not write content tmp file: %s" % to_native(err) + return result + # all actions should occur on the remote server, run win_copy module + elif remote_src: + new_module_args = self._task.args.copy() + new_module_args.update( + dict( + _copy_mode="remote", + dest=dest, + src=source, + force=force, + backup=backup, + ) + ) + new_module_args.pop('content', None) + result.update(self._execute_module(module_args=new_module_args, task_vars=task_vars)) + return result + # find_needle returns a path that may not have a trailing slash on a + # directory so we need to find that out first and append at the end + else: + trailing_slash = source.endswith(os.path.sep) + try: + # find in expected paths + source = self._find_needle('files', source) + except AnsibleError as e: + result['failed'] = True + result['msg'] = to_text(e) + result['exception'] = traceback.format_exc() + return result + + if trailing_slash != source.endswith(os.path.sep): + if source[-1] == os.path.sep: + source = source[:-1] + else: + source = source + os.path.sep + + # A list of source file tuples (full_path, relative_path) which will try to copy to the destination + source_files = {'files': [], 'directories': [], 'symlinks': []} + + # If source is a directory populate our list else source is a file and translate it to a tuple. + if os.path.isdir(to_bytes(source, errors='surrogate_or_strict')): + result['operation'] = 'folder_copy' + + # Get a list of the files we want to replicate on the remote side + source_files = _walk_dirs(source, self._loader, decrypt=decrypt, local_follow=local_follow, + trailing_slash_detector=self._connection._shell.path_has_trailing_slash, + checksum_check=force) + + # If it's recursive copy, destination is always a dir, + # explicitly mark it so (note - win_copy module relies on this). + if not self._connection._shell.path_has_trailing_slash(dest): + dest = "%s%s" % (dest, self.WIN_PATH_SEPARATOR) + + check_dest = dest + # Source is a file, add details to source_files dict + else: + result['operation'] = 'file_copy' + + # If the local file does not exist, get_real_file() raises AnsibleFileNotFound + try: + source_full = self._loader.get_real_file(source, decrypt=decrypt) + except AnsibleFileNotFound as e: + result['failed'] = True + result['msg'] = "could not find src=%s, %s" % (source_full, to_text(e)) + return result + + original_basename = os.path.basename(source) + result['original_basename'] = original_basename + + # check if dest ends with / or \ and append source filename to dest + if self._connection._shell.path_has_trailing_slash(dest): + check_dest = dest + filename = original_basename + result['dest'] = self._connection._shell.join_path(dest, filename) + else: + # replace \\ with / so we can use os.path to get the filename or dirname + unix_path = dest.replace(self.WIN_PATH_SEPARATOR, os.path.sep) + filename = os.path.basename(unix_path) + check_dest = os.path.dirname(unix_path) + + file_checksum = _get_local_checksum(force, source_full) + source_files['files'].append( + dict( + src=source_full, + dest=filename, + checksum=file_checksum + ) + ) + result['checksum'] = file_checksum + result['size'] = os.path.getsize(to_bytes(source_full, errors='surrogate_or_strict')) + + # find out the files/directories/symlinks that we need to copy to the server + query_args = self._task.args.copy() + query_args.update( + dict( + _copy_mode="query", + dest=check_dest, + force=force, + files=source_files['files'], + directories=source_files['directories'], + symlinks=source_files['symlinks'], + ) + ) + # src is not required for query, will fail path validation is src has unix allowed chars + query_args.pop('src', None) + + query_args.pop('content', None) + query_return = self._execute_module(module_args=query_args, + task_vars=task_vars) + + if query_return.get('failed') is True: + result.update(query_return) + return result + + if len(query_return['files']) > 0 or len(query_return['directories']) > 0 and self._connection._shell.tmpdir is None: + self._connection._shell.tmpdir = self._make_tmp_path() + + if len(query_return['files']) == 1 and len(query_return['directories']) == 0: + # we only need to copy 1 file, don't mess around with zips + file_src = query_return['files'][0]['src'] + file_dest = query_return['files'][0]['dest'] + result.update(self._copy_single_file(file_src, dest, file_dest, + task_vars, self._connection._shell.tmpdir, backup)) + if result.get('failed') is True: + result['msg'] = "failed to copy file %s: %s" % (file_src, result['msg']) + result['changed'] = True + + elif len(query_return['files']) > 0 or len(query_return['directories']) > 0: + # either multiple files or directories need to be copied, compress + # to a zip and 'explode' the zip on the server + # TODO: handle symlinks + result.update(self._copy_zip_file(dest, source_files['files'], + source_files['directories'], + task_vars, self._connection._shell.tmpdir, backup)) + result['changed'] = True + else: + # no operations need to occur + result['failed'] = False + result['changed'] = False + + # remove the content tmp file and remote tmp file if it was created + self._remove_tempfile_if_content_defined(content, content_tempfile) + self._remove_tmp_path(self._connection._shell.tmpdir) + return result diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/module_utils/WebRequest.psm1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/module_utils/WebRequest.psm1 new file mode 100644 index 0000000..8d077bd --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/module_utils/WebRequest.psm1 @@ -0,0 +1,518 @@ +# Copyright (c) 2020 Ansible Project +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +Function Get-AnsibleWindowsWebRequest { + <# + .SYNOPSIS + Creates a System.Net.WebRequest object based on common URL module options in Ansible. + + .DESCRIPTION + Will create a WebRequest based on common input options within Ansible. This can be used manually or with + Invoke-AnsibleWindowsWebRequest. + + .PARAMETER Uri + The URI to create the web request for. + + .PARAMETER UrlMethod + The protocol method to use, if omitted, will use the default value for the URI protocol specified. + + .PARAMETER FollowRedirects + Whether to follow redirect reponses. This is only valid when using a HTTP URI. + all - Will follow all redirects + none - Will follow no redirects + safe - Will only follow redirects when GET or HEAD is used as the UrlMethod + + .PARAMETER Headers + A hashtable or dictionary of header values to set on the request. This is only valid for a HTTP URI. + + .PARAMETER HttpAgent + A string to set for the 'User-Agent' header. This is only valid for a HTTP URI. + + .PARAMETER MaximumRedirection + The maximum number of redirections that will be followed. This is only valid for a HTTP URI. + + .PARAMETER UrlTimeout + The timeout in seconds that defines how long to wait until the request times out. + + .PARAMETER ValidateCerts + Whether to validate SSL certificates, default to True. + + .PARAMETER ClientCert + The path to PFX file to use for X509 authentication. This is only valid for a HTTP URI. This path can either + be a filesystem path (C:\folder\cert.pfx) or a PSPath to a credential (Cert:\CurrentUser\My\<thumbprint>). + + .PARAMETER ClientCertPassword + The password for the PFX certificate if required. This is only valid for a HTTP URI. + + .PARAMETER ForceBasicAuth + Whether to set the Basic auth header on the first request instead of when required. This is only valid for a + HTTP URI. + + .PARAMETER UrlUsername + The username to use for authenticating with the target. + + .PARAMETER UrlPassword + The password to use for authenticating with the target. + + .PARAMETER UseDefaultCredential + Whether to use the current user's credentials if available. This will only work when using Become, using SSH with + password auth, or WinRM with CredSSP or Kerberos with credential delegation. + + .PARAMETER UseProxy + Whether to use the default proxy defined in IE (WinINet) for the user or set no proxy at all. This should not + be set to True when ProxyUrl is also defined. + + .PARAMETER ProxyUrl + An explicit proxy server to use for the request instead of relying on the default proxy in IE. This is only + valid for a HTTP URI. + + .PARAMETER ProxyUsername + An optional username to use for proxy authentication. + + .PARAMETER ProxyPassword + The password for ProxyUsername. + + .PARAMETER ProxyUseDefaultCredential + Whether to use the current user's credentials for proxy authentication if available. This will only work when + using Become, using SSH with password auth, or WinRM with CredSSP or Kerberos with credential delegation. + + .PARAMETER Module + The AnsibleBasic module that can be used as a backup parameter source or a way to return warnings back to the + Ansible controller. + + .EXAMPLE + $spec = @{ + options = @{} + } + $module = [Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWindowsWebRequestSpec)) + + $web_request = Get-AnsibleWindowsWebRequest -Module $module + #> + [CmdletBinding()] + [OutputType([System.Net.WebRequest])] + Param ( + [Alias("url")] + [System.Uri] + $Uri, + + [Alias("url_method")] + [System.String] + $UrlMethod, + + [Alias("follow_redirects")] + [ValidateSet("all", "none", "safe")] + [System.String] + $FollowRedirects = "safe", + + [System.Collections.IDictionary] + $Headers, + + [Alias("http_agent")] + [System.String] + $HttpAgent = "ansible-httpget", + + [Alias("maximum_redirection")] + [System.Int32] + $MaximumRedirection = 50, + + [Alias("url_timeout")] + [System.Int32] + $UrlTimeout = 30, + + [Alias("validate_certs")] + [System.Boolean] + $ValidateCerts = $true, + + # Credential params + [Alias("client_cert")] + [System.String] + $ClientCert, + + [Alias("client_cert_password")] + [System.String] + $ClientCertPassword, + + [Alias("force_basic_auth")] + [Switch] + $ForceBasicAuth, + + [Alias("url_username")] + [System.String] + $UrlUsername, + + [Alias("url_password")] + [System.String] + $UrlPassword, + + [Alias("use_default_credential")] + [Switch] + $UseDefaultCredential, + + # Proxy params + [Alias("use_proxy")] + [System.Boolean] + $UseProxy = $true, + + [Alias("proxy_url")] + [System.String] + $ProxyUrl, + + [Alias("proxy_username")] + [System.String] + $ProxyUsername, + + [Alias("proxy_password")] + [System.String] + $ProxyPassword, + + [Alias("proxy_use_default_credential")] + [Switch] + $ProxyUseDefaultCredential, + + [ValidateScript({ $_.GetType().FullName -eq 'Ansible.Basic.AnsibleModule' })] + [System.Object] + $Module + ) + + # Set module options for parameters unless they were explicitly passed in. + if ($Module) { + foreach ($param in $PSCmdlet.MyInvocation.MyCommand.Parameters.GetEnumerator()) { + if ($PSBoundParameters.ContainsKey($param.Key)) { + # Was set explicitly we want to use that value + continue + } + + foreach ($alias in @($Param.Key) + $param.Value.Aliases) { + if ($Module.Params.ContainsKey($alias)) { + $var_value = $Module.Params.$alias -as $param.Value.ParameterType + Set-Variable -Name $param.Key -Value $var_value + break + } + } + } + } + + # Disable certificate validation if requested + # FUTURE: set this on ServerCertificateValidationCallback of the HttpWebRequest once .NET 4.5 is the minimum + if (-not $ValidateCerts) { + [System.Net.ServicePointManager]::ServerCertificateValidationCallback = { $true } + } + + # Enable TLS1.1/TLS1.2 if they're available but disabled (eg. .NET 4.5) + $security_protocols = [System.Net.ServicePointManager]::SecurityProtocol -bor [System.Net.SecurityProtocolType]::SystemDefault + if ([System.Net.SecurityProtocolType].GetMember("Tls11").Count -gt 0) { + $security_protocols = $security_protocols -bor [System.Net.SecurityProtocolType]::Tls11 + } + if ([System.Net.SecurityProtocolType].GetMember("Tls12").Count -gt 0) { + $security_protocols = $security_protocols -bor [System.Net.SecurityProtocolType]::Tls12 + } + [System.Net.ServicePointManager]::SecurityProtocol = $security_protocols + + $web_request = [System.Net.WebRequest]::Create($Uri) + if ($UrlMethod) { + $web_request.Method = $UrlMethod + } + $web_request.Timeout = $UrlTimeout * 1000 + + if ($UseDefaultCredential -and $web_request -is [System.Net.HttpWebRequest]) { + $web_request.UseDefaultCredentials = $true + } elseif ($UrlUsername) { + if ($ForceBasicAuth) { + $auth_value = [System.Convert]::ToBase64String([System.Text.Encoding]::ASCII.GetBytes(("{0}:{1}" -f $UrlUsername, $UrlPassword))) + $web_request.Headers.Add("Authorization", "Basic $auth_value") + } else { + $credential = New-Object -TypeName System.Net.NetworkCredential -ArgumentList $UrlUsername, $UrlPassword + $web_request.Credentials = $credential + } + } + + if ($ClientCert) { + # Expecting either a filepath or PSPath (Cert:\CurrentUser\My\<thumbprint>) + $cert = Get-Item -LiteralPath $ClientCert -ErrorAction SilentlyContinue + if ($null -eq $cert) { + Write-Error -Message "Client certificate '$ClientCert' does not exist" -Category ObjectNotFound + return + } + + $crypto_ns = 'System.Security.Cryptography.X509Certificates' + if ($cert.PSProvider.Name -ne 'Certificate') { + try { + $cert = New-Object -TypeName "$crypto_ns.X509Certificate2" -ArgumentList @( + $ClientCert, $ClientCertPassword + ) + } catch [System.Security.Cryptography.CryptographicException] { + Write-Error -Message "Failed to read client certificate at '$ClientCert'" -Exception $_.Exception -Category SecurityError + return + } + } + $web_request.ClientCertificates = New-Object -TypeName "$crypto_ns.X509Certificate2Collection" -ArgumentList @( + $cert + ) + } + + if (-not $UseProxy) { + $proxy = $null + } elseif ($ProxyUrl) { + $proxy = New-Object -TypeName System.Net.WebProxy -ArgumentList $ProxyUrl, $true + } else { + $proxy = $web_request.Proxy + } + + # $web_request.Proxy may return $null for a FTP web request. We only set the credentials if we have an actual + # proxy to work with, otherwise just ignore the credentials property. + if ($null -ne $proxy) { + if ($ProxyUseDefaultCredential) { + # Weird hack, $web_request.Proxy returns an IWebProxy object which only gurantees the Credentials + # property. We cannot set UseDefaultCredentials so we just set the Credentials to the + # DefaultCredentials in the CredentialCache which does the same thing. + $proxy.Credentials = [System.Net.CredentialCache]::DefaultCredentials + } elseif ($ProxyUsername) { + $proxy.Credentials = New-Object -TypeName System.Net.NetworkCredential -ArgumentList @( + $ProxyUsername, $ProxyPassword + ) + } else { + $proxy.Credentials = $null + } + } + + $web_request.Proxy = $proxy + + # Some parameters only apply when dealing with a HttpWebRequest + if ($web_request -is [System.Net.HttpWebRequest]) { + if ($Headers) { + foreach ($header in $Headers.GetEnumerator()) { + switch ($header.Key) { + Accept { $web_request.Accept = $header.Value } + Connection { $web_request.Connection = $header.Value } + Content-Length { $web_request.ContentLength = $header.Value } + Content-Type { $web_request.ContentType = $header.Value } + Expect { $web_request.Expect = $header.Value } + Date { $web_request.Date = $header.Value } + Host { $web_request.Host = $header.Value } + If-Modified-Since { $web_request.IfModifiedSince = $header.Value } + Range { $web_request.AddRange($header.Value) } + Referer { $web_request.Referer = $header.Value } + Transfer-Encoding { + $web_request.SendChunked = $true + $web_request.TransferEncoding = $header.Value + } + User-Agent { continue } + default { $web_request.Headers.Add($header.Key, $header.Value) } + } + } + } + + # For backwards compatibility we need to support setting the User-Agent if the header was set in the task. + # We just need to make sure that if an explicit http_agent module was set then that takes priority. + if ($Headers -and $Headers.ContainsKey("User-Agent")) { + $options = (Get-AnsibleWindowsWebRequestSpec).options + if ($HttpAgent -eq $options.http_agent.default) { + $HttpAgent = $Headers['User-Agent'] + } elseif ($null -ne $Module) { + $Module.Warn("The 'User-Agent' header and the 'http_agent' was set, using the 'http_agent' for web request") + } + } + $web_request.UserAgent = $HttpAgent + + switch ($FollowRedirects) { + none { $web_request.AllowAutoRedirect = $false } + safe { + if ($web_request.Method -in @("GET", "HEAD")) { + $web_request.AllowAutoRedirect = $true + } else { + $web_request.AllowAutoRedirect = $false + } + } + all { $web_request.AllowAutoRedirect = $true } + } + + if ($MaximumRedirection -eq 0) { + $web_request.AllowAutoRedirect = $false + } else { + $web_request.MaximumAutomaticRedirections = $MaximumRedirection + } + } + + return $web_request +} + +Function Invoke-AnsibleWindowsWebRequest { + <# + .SYNOPSIS + Invokes a ScriptBlock with the WebRequest. + + .DESCRIPTION + Invokes the ScriptBlock and handle extra information like accessing the response stream, closing those streams + safely as well as setting common module return values. + + .PARAMETER Module + The Ansible.Basic module to set the return values for. This will set the following return values; + elapsed - The total time, in seconds, that it took to send the web request and process the response + msg - The human readable description of the response status code + status_code - An int that is the response status code + + .PARAMETER Request + The System.Net.WebRequest to call. This can either be manually crafted or created with + Get-AnsibleWindowsWebRequest. + + .PARAMETER Script + The ScriptBlock to invoke during the web request. This ScriptBlock should take in the params + Param ([System.Net.WebResponse]$Response, [System.IO.Stream]$Stream) + + This scriptblock should manage the response based on what it need to do. + + .PARAMETER Body + An optional Stream to send to the target during the request. + + .PARAMETER IgnoreBadResponse + By default a WebException will be raised for a non 2xx status code and the Script will not be invoked. This + parameter can be set to process all responses regardless of the status code. + + .EXAMPLE Basic module that downloads a file + $spec = @{ + options = @{ + path = @{ type = "path"; required = $true } + } + } + $module = Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWindowsWebRequestSpec)) + + $web_request = Get-AnsibleWindowsWebRequest -Module $module + + Invoke-AnsibleWindowsWebRequest -Module $module -Request $web_request -Script { + Param ([System.Net.WebResponse]$Response, [System.IO.Stream]$Stream) + + $fs = [System.IO.File]::Create($module.Params.path) + try { + $Stream.CopyTo($fs) + $fs.Flush() + } finally { + $fs.Dispose() + } + } + #> + [CmdletBinding()] + param ( + [Parameter(Mandatory=$true)] + [System.Object] + [ValidateScript({ $_.GetType().FullName -eq 'Ansible.Basic.AnsibleModule' })] + $Module, + + [Parameter(Mandatory=$true)] + [System.Net.WebRequest] + $Request, + + [Parameter(Mandatory=$true)] + [ScriptBlock] + $Script, + + [AllowNull()] + [System.IO.Stream] + $Body, + + [Switch] + $IgnoreBadResponse + ) + + $start = Get-Date + if ($null -ne $Body) { + $request_st = $Request.GetRequestStream() + try { + $Body.CopyTo($request_st) + $request_st.Flush() + } finally { + $request_st.Close() + } + } + + try { + try { + $web_response = $Request.GetResponse() + } catch [System.Net.WebException] { + # A WebResponse with a status code not in the 200 range will raise a WebException. We check if the + # exception raised contains the actual response and continue on if IgnoreBadResponse is set. We also + # make sure we set the status_code return value on the Module object if possible + + if ($_.Exception.PSObject.Properties.Name -match "Response") { + $web_response = $_.Exception.Response + + if (-not $IgnoreBadResponse -or $null -eq $web_response) { + $Module.Result.msg = $_.Exception.StatusDescription + $Module.Result.status_code = $_.Exception.Response.StatusCode + throw $_ + } + } else { + throw $_ + } + } + + if ($Request.RequestUri.IsFile) { + # A FileWebResponse won't have these properties set + $Module.Result.msg = "OK" + $Module.Result.status_code = 200 + } else { + $Module.Result.msg = $web_response.StatusDescription + $Module.Result.status_code = $web_response.StatusCode + } + + $response_stream = $web_response.GetResponseStream() + try { + # Invoke the ScriptBlock and pass in WebResponse and ResponseStream + &$Script -Response $web_response -Stream $response_stream + } finally { + $response_stream.Dispose() + } + } finally { + if ($web_response) { + $web_response.Close() + } + $Module.Result.elapsed = ((Get-date) - $start).TotalSeconds + } +} + +Function Get-AnsibleWindowsWebRequestSpec { + <# + .SYNOPSIS + Used by modules to get the argument spec fragment for AnsibleModule. + + .EXAMPLES + $spec = @{ + options = @{} + } + $module = [Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWindowsWebRequestSpec)) + + .NOTES + The options here are reflected in the doc fragment 'ansible.windows.web_request' at + 'plugins/doc_fragments/web_request.py'. + #> + @{ + options = @{ + url_method = @{ type = 'str' } + follow_redirects = @{ type = 'str'; choices = @('all', 'none', 'safe'); default = 'safe' } + headers = @{ type = 'dict' } + http_agent = @{ type = 'str'; default = 'ansible-httpget' } + maximum_redirection = @{ type = 'int'; default = 50 } + url_timeout = @{ type = 'int'; default = 30 } + validate_certs = @{ type = 'bool'; default = $true } + + # Credential options + client_cert = @{ type = 'str' } + client_cert_password = @{ type = 'str'; no_log = $true } + force_basic_auth = @{ type = 'bool'; default = $false } + url_username = @{ type = 'str' } + url_password = @{ type = 'str'; no_log = $true } + use_default_credential = @{ type = 'bool'; default = $false } + + # Proxy options + use_proxy = @{ type = 'bool'; default = $true } + proxy_url = @{ type = 'str' } + proxy_username = @{ type = 'str' } + proxy_password = @{ type = 'str'; no_log = $true } + proxy_use_default_credential = @{ type = 'bool'; default = $false } + } + } +} + +$export_members = @{ + Function = "Get-AnsibleWindowsWebRequest", "Get-AnsibleWindowsWebRequestSpec", "Invoke-AnsibleWindowsWebRequest" +} +Export-ModuleMember @export_members diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1 new file mode 100644 index 0000000..1ce3ff4 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1 @@ -0,0 +1,58 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$results = @{changed=$false} + +$parsed_args = Parse-Args $args +$jid = Get-AnsibleParam $parsed_args "jid" -failifempty $true -resultobj $results +$mode = Get-AnsibleParam $parsed_args "mode" -Default "status" -ValidateSet "status","cleanup" + +# parsed in from the async_status action plugin +$async_dir = Get-AnsibleParam $parsed_args "_async_dir" -type "path" -failifempty $true + +$log_path = [System.IO.Path]::Combine($async_dir, $jid) + +If(-not $(Test-Path $log_path)) +{ + Fail-Json @{ansible_job_id=$jid; started=1; finished=1} "could not find job at '$async_dir'" +} + +If($mode -eq "cleanup") { + Remove-Item $log_path -Recurse + Exit-Json @{ansible_job_id=$jid; erased=$log_path} +} + +# NOT in cleanup mode, assume regular status mode +# no remote kill mode currently exists, but probably should +# consider log_path + ".pid" file and also unlink that above + +$data = $null +Try { + $data_raw = Get-Content $log_path + + # TODO: move this into module_utils/powershell.ps1? + $jss = New-Object System.Web.Script.Serialization.JavaScriptSerializer + $data = $jss.DeserializeObject($data_raw) +} +Catch { + If(-not $data_raw) { + # file not written yet? That means it is running + Exit-Json @{results_file=$log_path; ansible_job_id=$jid; started=1; finished=0} + } + Else { + Fail-Json @{ansible_job_id=$jid; results_file=$log_path; started=1; finished=1} "Could not parse job output: $data" + } +} + +If (-not $data.ContainsKey("started")) { + $data['finished'] = 1 + $data['ansible_job_id'] = $jid +} +ElseIf (-not $data.ContainsKey("finished")) { + $data['finished'] = 0 +} + +Exit-Json $data diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1 new file mode 100644 index 0000000..e3c3813 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1 @@ -0,0 +1,225 @@ +#!powershell + +# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com> +# Copyright: (c) 2015, Trond Hindenes +# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.PrivilegeUtil +#Requires -Module Ansible.ModuleUtils.SID + +$ErrorActionPreference = "Stop" + +# win_acl module (File/Resources Permission Additions/Removal) + +#Functions +function Get-UserSID { + param( + [String]$AccountName + ) + + $userSID = $null + $searchAppPools = $false + + if ($AccountName.Split("\").Count -gt 1) { + if ($AccountName.Split("\")[0] -eq "IIS APPPOOL") { + $searchAppPools = $true + $AccountName = $AccountName.Split("\")[1] + } + } + + if ($searchAppPools) { + Import-Module -Name WebAdministration + $testIISPath = Test-Path -LiteralPath "IIS:" + if ($testIISPath) { + $appPoolObj = Get-ItemProperty -LiteralPath "IIS:\AppPools\$AccountName" + $userSID = $appPoolObj.applicationPoolSid + } + } + else { + $userSID = Convert-ToSID -account_name $AccountName + } + + return $userSID +} + +$params = Parse-Args $args + +Function SetPrivilegeTokens() { + # Set privilege tokens only if admin. + # Admins would have these privs or be able to set these privs in the UI Anyway + + $adminRole=[System.Security.Principal.WindowsBuiltInRole]::Administrator + $myWindowsID=[System.Security.Principal.WindowsIdentity]::GetCurrent() + $myWindowsPrincipal=new-object System.Security.Principal.WindowsPrincipal($myWindowsID) + + + if ($myWindowsPrincipal.IsInRole($adminRole)) { + # Need to adjust token privs when executing Set-ACL in certain cases. + # e.g. d:\testdir is owned by group in which current user is not a member and no perms are inherited from d:\ + # This also sets us up for setting the owner as a feature. + # See the following for details of each privilege + # https://msdn.microsoft.com/en-us/library/windows/desktop/bb530716(v=vs.85).aspx + $privileges = @( + "SeRestorePrivilege", # Grants all write access control to any file, regardless of ACL. + "SeBackupPrivilege", # Grants all read access control to any file, regardless of ACL. + "SeTakeOwnershipPrivilege" # Grants ability to take owernship of an object w/out being granted discretionary access + ) + foreach ($privilege in $privileges) { + $state = Get-AnsiblePrivilege -Name $privilege + if ($state -eq $false) { + Set-AnsiblePrivilege -Name $privilege -Value $true + } + } + } +} + + +$result = @{ + changed = $false +} + +$path = Get-AnsibleParam -obj $params -name "path" -type "str" -failifempty $true +$user = Get-AnsibleParam -obj $params -name "user" -type "str" -failifempty $true +$rights = Get-AnsibleParam -obj $params -name "rights" -type "str" -failifempty $true + +$type = Get-AnsibleParam -obj $params -name "type" -type "str" -failifempty $true -validateset "allow","deny" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "absent","present" + +$inherit = Get-AnsibleParam -obj $params -name "inherit" -type "str" +$propagation = Get-AnsibleParam -obj $params -name "propagation" -type "str" -default "None" -validateset "InheritOnly","None","NoPropagateInherit" + +# We mount the HKCR, HKU, and HKCC registry hives so PS can access them. +# Network paths have no qualifiers so we use -EA SilentlyContinue to ignore that +$path_qualifier = Split-Path -Path $path -Qualifier -ErrorAction SilentlyContinue +if ($path_qualifier -eq "HKCR:" -and (-not (Test-Path -LiteralPath HKCR:\))) { + New-PSDrive -Name HKCR -PSProvider Registry -Root HKEY_CLASSES_ROOT > $null +} +if ($path_qualifier -eq "HKU:" -and (-not (Test-Path -LiteralPath HKU:\))) { + New-PSDrive -Name HKU -PSProvider Registry -Root HKEY_USERS > $null +} +if ($path_qualifier -eq "HKCC:" -and (-not (Test-Path -LiteralPath HKCC:\))) { + New-PSDrive -Name HKCC -PSProvider Registry -Root HKEY_CURRENT_CONFIG > $null +} + +If (-Not (Test-Path -LiteralPath $path)) { + Fail-Json -obj $result -message "$path file or directory does not exist on the host" +} + +# Test that the user/group is resolvable on the local machine +$sid = Get-UserSID -AccountName $user +if (!$sid) { + Fail-Json -obj $result -message "$user is not a valid user or group on the host machine or domain" +} + +If (Test-Path -LiteralPath $path -PathType Leaf) { + $inherit = "None" +} +ElseIf ($null -eq $inherit) { + $inherit = "ContainerInherit, ObjectInherit" +} + +# Bug in Set-Acl, Get-Acl where -LiteralPath only works for the Registry provider if the location is in that root +# qualifier. We also don't have a qualifier for a network path so only change if not null +if ($null -ne $path_qualifier) { + Push-Location -LiteralPath $path_qualifier +} + +Try { + SetPrivilegeTokens + $path_item = Get-Item -LiteralPath $path -Force + If ($path_item.PSProvider.Name -eq "Registry") { + $colRights = [System.Security.AccessControl.RegistryRights]$rights + } + Else { + $colRights = [System.Security.AccessControl.FileSystemRights]$rights + } + + $InheritanceFlag = [System.Security.AccessControl.InheritanceFlags]$inherit + $PropagationFlag = [System.Security.AccessControl.PropagationFlags]$propagation + + If ($type -eq "allow") { + $objType =[System.Security.AccessControl.AccessControlType]::Allow + } + Else { + $objType =[System.Security.AccessControl.AccessControlType]::Deny + } + + $objUser = New-Object System.Security.Principal.SecurityIdentifier($sid) + If ($path_item.PSProvider.Name -eq "Registry") { + $objACE = New-Object System.Security.AccessControl.RegistryAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType) + } + Else { + $objACE = New-Object System.Security.AccessControl.FileSystemAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType) + } + $objACL = Get-ACL -LiteralPath $path + + # Check if the ACE exists already in the objects ACL list + $match = $false + + ForEach($rule in $objACL.GetAccessRules($true, $true, [System.Security.Principal.SecurityIdentifier])){ + + If ($path_item.PSProvider.Name -eq "Registry") { + If (($rule.RegistryRights -eq $objACE.RegistryRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) { + $match = $true + Break + } + } else { + If (($rule.FileSystemRights -eq $objACE.FileSystemRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) { + $match = $true + Break + } + } + } + + If ($state -eq "present" -And $match -eq $false) { + Try { + $objACL.AddAccessRule($objACE) + If ($path_item.PSProvider.Name -eq "Registry") { + Set-ACL -LiteralPath $path -AclObject $objACL + } else { + (Get-Item -LiteralPath $path).SetAccessControl($objACL) + } + $result.changed = $true + } + Catch { + Fail-Json -obj $result -message "an exception occurred when adding the specified rule - $($_.Exception.Message)" + } + } + ElseIf ($state -eq "absent" -And $match -eq $true) { + Try { + $objACL.RemoveAccessRule($objACE) + If ($path_item.PSProvider.Name -eq "Registry") { + Set-ACL -LiteralPath $path -AclObject $objACL + } else { + (Get-Item -LiteralPath $path).SetAccessControl($objACL) + } + $result.changed = $true + } + Catch { + Fail-Json -obj $result -message "an exception occurred when removing the specified rule - $($_.Exception.Message)" + } + } + Else { + # A rule was attempting to be added but already exists + If ($match -eq $true) { + Exit-Json -obj $result -message "the specified rule already exists" + } + # A rule didn't exist that was trying to be removed + Else { + Exit-Json -obj $result -message "the specified rule does not exist" + } + } +} +Catch { + Fail-Json -obj $result -message "an error occurred when attempting to $state $rights permission(s) on $path for $user - $($_.Exception.Message)" +} +Finally { + # Make sure we revert the location stack to the original path just for cleanups sake + if ($null -ne $path_qualifier) { + Pop-Location + } +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py new file mode 100644 index 0000000..14fbd82 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py @@ -0,0 +1,132 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com> +# Copyright: (c) 2015, Trond Hindenes +# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_acl +version_added: "2.0" +short_description: Set file/directory/registry permissions for a system user or group +description: +- Add or remove rights/permissions for a given user or group for the specified + file, folder, registry key or AppPool identifies. +options: + path: + description: + - The path to the file or directory. + type: str + required: yes + user: + description: + - User or Group to add specified rights to act on src file/folder or + registry key. + type: str + required: yes + state: + description: + - Specify whether to add C(present) or remove C(absent) the specified access rule. + type: str + choices: [ absent, present ] + default: present + type: + description: + - Specify whether to allow or deny the rights specified. + type: str + required: yes + choices: [ allow, deny ] + rights: + description: + - The rights/permissions that are to be allowed/denied for the specified + user or group for the item at C(path). + - If C(path) is a file or directory, rights can be any right under MSDN + FileSystemRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.filesystemrights.aspx). + - If C(path) is a registry key, rights can be any right under MSDN + RegistryRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.registryrights.aspx). + type: str + required: yes + inherit: + description: + - Inherit flags on the ACL rules. + - Can be specified as a comma separated list, e.g. C(ContainerInherit), + C(ObjectInherit). + - For more information on the choices see MSDN InheritanceFlags enumeration + at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.inheritanceflags.aspx). + - Defaults to C(ContainerInherit, ObjectInherit) for Directories. + type: str + choices: [ ContainerInherit, ObjectInherit ] + propagation: + description: + - Propagation flag on the ACL rules. + - For more information on the choices see MSDN PropagationFlags enumeration + at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.propagationflags.aspx). + type: str + choices: [ InheritOnly, None, NoPropagateInherit ] + default: "None" +notes: +- If adding ACL's for AppPool identities (available since 2.3), the Windows + Feature "Web-Scripting-Tools" must be enabled. +seealso: +- module: win_acl_inheritance +- module: win_file +- module: win_owner +- module: win_stat +author: +- Phil Schwartz (@schwartzmx) +- Trond Hindenes (@trondhindenes) +- Hans-Joachim Kliemeck (@h0nIg) +''' + +EXAMPLES = r''' +- name: Restrict write and execute access to User Fed-Phil + win_acl: + user: Fed-Phil + path: C:\Important\Executable.exe + type: deny + rights: ExecuteFile,Write + +- name: Add IIS_IUSRS allow rights + win_acl: + path: C:\inetpub\wwwroot\MySite + user: IIS_IUSRS + rights: FullControl + type: allow + state: present + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Set registry key right + win_acl: + path: HKCU:\Bovine\Key + user: BUILTIN\Users + rights: EnumerateSubKeys + type: allow + state: present + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Remove FullControl AccessRule for IIS_IUSRS + win_acl: + path: C:\inetpub\wwwroot\MySite + user: IIS_IUSRS + rights: FullControl + type: allow + state: absent + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Deny Intern + win_acl: + path: C:\Administrator\Documents + user: Intern + rights: Read,Write,Modify,FullControl,Delete + type: deny + state: present +''' diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1 new file mode 100644 index 0000000..6a26ee7 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1 @@ -0,0 +1,403 @@ +#!powershell + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.Backup + +$ErrorActionPreference = 'Stop' + +$params = Parse-Args -arguments $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false + +# there are 4 modes to win_copy which are driven by the action plugins: +# explode: src is a zip file which needs to be extracted to dest, for use with multiple files +# query: win_copy action plugin wants to get the state of remote files to check whether it needs to send them +# remote: all copy action is happening remotely (remote_src=True) +# single: a single file has been copied, also used with template +$copy_mode = Get-AnsibleParam -obj $params -name "_copy_mode" -type "str" -default "single" -validateset "explode","query","remote","single" + +# used in explode, remote and single mode +$src = Get-AnsibleParam -obj $params -name "src" -type "path" -failifempty ($copy_mode -in @("explode","process","single")) +$dest = Get-AnsibleParam -obj $params -name "dest" -type "path" -failifempty $true +$backup = Get-AnsibleParam -obj $params -name "backup" -type "bool" -default $false + +# used in single mode +$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str" + +# used in query and remote mode +$force = Get-AnsibleParam -obj $params -name "force" -type "bool" -default $true + +# used in query mode, contains the local files/directories/symlinks that are to be copied +$files = Get-AnsibleParam -obj $params -name "files" -type "list" +$directories = Get-AnsibleParam -obj $params -name "directories" -type "list" + +$result = @{ + changed = $false +} + +if ($diff_mode) { + $result.diff = @{} +} + +Function Copy-File($source, $dest) { + $diff = "" + $copy_file = $false + $source_checksum = $null + if ($force) { + $source_checksum = Get-FileChecksum -path $source + } + + if (Test-Path -LiteralPath $dest -PathType Container) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': dest is already a folder" + } elseif (Test-Path -LiteralPath $dest -PathType Leaf) { + if ($force) { + $target_checksum = Get-FileChecksum -path $dest + if ($source_checksum -ne $target_checksum) { + $copy_file = $true + } + } + } else { + $copy_file = $true + } + + if ($copy_file) { + $file_dir = [System.IO.Path]::GetDirectoryName($dest) + # validate the parent dir is not a file and that it exists + if (Test-Path -LiteralPath $file_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder" + } elseif (-not (Test-Path -LiteralPath $file_dir)) { + # directory doesn't exist, need to create + New-Item -Path $file_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + $diff += "+$file_dir\`n" + } + + if ($backup) { + $result.backup_file = Backup-File -path $dest -WhatIf:$check_mode + } + + if (Test-Path -LiteralPath $dest -PathType Leaf) { + Remove-Item -LiteralPath $dest -Force -Recurse -WhatIf:$check_mode | Out-Null + $diff += "-$dest`n" + } + + if (-not $check_mode) { + # cannot run with -WhatIf:$check_mode as if the parent dir didn't + # exist and was created above would still not exist in check mode + Copy-Item -LiteralPath $source -Destination $dest -Force | Out-Null + } + $diff += "+$dest`n" + + $result.changed = $true + } + + # ugly but to save us from running the checksum twice, let's return it for + # the main code to add it to $result + return ,@{ diff = $diff; checksum = $source_checksum } +} + +Function Copy-Folder($source, $dest) { + $diff = "" + + if (-not (Test-Path -LiteralPath $dest -PathType Container)) { + $parent_dir = [System.IO.Path]::GetDirectoryName($dest) + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder" + } + if (Test-Path -LiteralPath $dest -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder from '$source' to '$dest': dest is already a file" + } + + New-Item -Path $dest -ItemType Container -WhatIf:$check_mode | Out-Null + $diff += "+$dest\`n" + $result.changed = $true + } + + $child_items = Get-ChildItem -LiteralPath $source -Force + foreach ($child_item in $child_items) { + $dest_child_path = Join-Path -Path $dest -ChildPath $child_item.Name + if ($child_item.PSIsContainer) { + $diff += (Copy-Folder -source $child_item.Fullname -dest $dest_child_path) + } else { + $diff += (Copy-File -source $child_item.Fullname -dest $dest_child_path).diff + } + } + + return $diff +} + +Function Get-FileSize($path) { + $file = Get-Item -LiteralPath $path -Force + if ($file.PSIsContainer) { + $size = (Get-ChildItem -Literalpath $file.FullName -Recurse -Force | ` + Where-Object { $_.PSObject.Properties.Name -contains 'Length' } | ` + Measure-Object -Property Length -Sum).Sum + if ($null -eq $size) { + $size = 0 + } + } else { + $size = $file.Length + } + + $size +} + +Function Extract-Zip($src, $dest) { + $archive = [System.IO.Compression.ZipFile]::Open($src, [System.IO.Compression.ZipArchiveMode]::Read, [System.Text.Encoding]::UTF8) + foreach ($entry in $archive.Entries) { + $archive_name = $entry.FullName + + # FullName may be appended with / or \, determine if it is padded and remove it + $padding_length = $archive_name.Length % 4 + if ($padding_length -eq 0) { + $is_dir = $false + $base64_name = $archive_name + } elseif ($padding_length -eq 1) { + $is_dir = $true + if ($archive_name.EndsWith("/") -or $archive_name.EndsWith("`\")) { + $base64_name = $archive_name.Substring(0, $archive_name.Length - 1) + } else { + throw "invalid base64 archive name '$archive_name'" + } + } else { + throw "invalid base64 length '$archive_name'" + } + + # to handle unicode character, win_copy action plugin has encoded the filename + $decoded_archive_name = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($base64_name)) + # re-add the / to the entry full name if it was a directory + if ($is_dir) { + $decoded_archive_name = "$decoded_archive_name/" + } + $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_name) + $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path) + + if (-not (Test-Path -LiteralPath $entry_dir)) { + New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + } + + if ($is_dir -eq $false) { + if (-not $check_mode) { + [System.IO.Compression.ZipFileExtensions]::ExtractToFile($entry, $entry_target_path, $true) + } + } + } + $archive.Dispose() # release the handle of the zip file +} + +Function Extract-ZipLegacy($src, $dest) { + if (-not (Test-Path -LiteralPath $dest)) { + New-Item -Path $dest -ItemType Directory -WhatIf:$check_mode | Out-Null + } + $shell = New-Object -ComObject Shell.Application + $zip = $shell.NameSpace($src) + $dest_path = $shell.NameSpace($dest) + + foreach ($entry in $zip.Items()) { + $is_dir = $entry.IsFolder + $encoded_archive_entry = $entry.Name + # to handle unicode character, win_copy action plugin has encoded the filename + $decoded_archive_entry = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($encoded_archive_entry)) + if ($is_dir) { + $decoded_archive_entry = "$decoded_archive_entry/" + } + + $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_entry) + $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path) + + if (-not (Test-Path -LiteralPath $entry_dir)) { + New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + } + + if ($is_dir -eq $false -and (-not $check_mode)) { + # https://msdn.microsoft.com/en-us/library/windows/desktop/bb787866.aspx + # From Folder.CopyHere documentation, 1044 means: + # - 1024: do not display a user interface if an error occurs + # - 16: respond with "yes to all" for any dialog box that is displayed + # - 4: do not display a progress dialog box + $dest_path.CopyHere($entry, 1044) + + # once file is extraced, we need to rename it with non base64 name + $combined_encoded_path = [System.IO.Path]::Combine($dest, $encoded_archive_entry) + Move-Item -LiteralPath $combined_encoded_path -Destination $entry_target_path -Force | Out-Null + } + } +} + +if ($copy_mode -eq "query") { + # we only return a list of files/directories that need to be copied over + # the source of the local file will be the key used + $changed_files = @() + $changed_directories = @() + $changed_symlinks = @() + + foreach ($file in $files) { + $filename = $file.dest + $local_checksum = $file.checksum + + $filepath = Join-Path -Path $dest -ChildPath $filename + if (Test-Path -LiteralPath $filepath -PathType Leaf) { + if ($force) { + $checksum = Get-FileChecksum -path $filepath + if ($checksum -ne $local_checksum) { + $changed_files += $file + } + } + } elseif (Test-Path -LiteralPath $filepath -PathType Container) { + Fail-Json -obj $result -message "cannot copy file to dest '$filepath': object at path is already a directory" + } else { + $changed_files += $file + } + } + + foreach ($directory in $directories) { + $dirname = $directory.dest + + $dirpath = Join-Path -Path $dest -ChildPath $dirname + $parent_dir = [System.IO.Path]::GetDirectoryName($dirpath) + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at parent directory path is already a file" + } + if (Test-Path -LiteralPath $dirpath -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at path is already a file" + } elseif (-not (Test-Path -LiteralPath $dirpath -PathType Container)) { + $changed_directories += $directory + } + } + + # TODO: Handle symlinks + + $result.files = $changed_files + $result.directories = $changed_directories + $result.symlinks = $changed_symlinks +} elseif ($copy_mode -eq "explode") { + # a single zip file containing the files and directories needs to be + # expanded this will always result in a change as the calculation is done + # on the win_copy action plugin and is only run if a change needs to occur + if (-not (Test-Path -LiteralPath $src -PathType Leaf)) { + Fail-Json -obj $result -message "Cannot expand src zip file: '$src' as it does not exist" + } + + # Detect if the PS zip assemblies are available or whether to use Shell + $use_legacy = $false + try { + Add-Type -AssemblyName System.IO.Compression.FileSystem | Out-Null + Add-Type -AssemblyName System.IO.Compression | Out-Null + } catch { + $use_legacy = $true + } + if ($use_legacy) { + Extract-ZipLegacy -src $src -dest $dest + } else { + Extract-Zip -src $src -dest $dest + } + + $result.changed = $true +} elseif ($copy_mode -eq "remote") { + # all copy actions are happening on the remote side (windows host), need + # too copy source and dest using PS code + $result.src = $src + $result.dest = $dest + + if (-not (Test-Path -LiteralPath $src)) { + Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist" + } + + if (Test-Path -LiteralPath $src -PathType Container) { + # we are copying a directory or the contents of a directory + $result.operation = 'folder_copy' + if ($src.EndsWith("/") -or $src.EndsWith("`\")) { + # copying the folder's contents to dest + $diff = "" + $child_files = Get-ChildItem -LiteralPath $src -Force + foreach ($child_file in $child_files) { + $dest_child_path = Join-Path -Path $dest -ChildPath $child_file.Name + if ($child_file.PSIsContainer) { + $diff += Copy-Folder -source $child_file.FullName -dest $dest_child_path + } else { + $diff += (Copy-File -source $child_file.FullName -dest $dest_child_path).diff + } + } + } else { + # copying the folder and it's contents to dest + $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name + $result.dest = $dest + $diff = Copy-Folder -source $src -dest $dest + } + } else { + # we are just copying a single file to dest + $result.operation = 'file_copy' + + $source_basename = (Get-Item -LiteralPath $src -Force).Name + $result.original_basename = $source_basename + + if ($dest.EndsWith("/") -or $dest.EndsWith("`\")) { + $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name + $result.dest = $dest + } else { + # check if the parent dir exists, this is only done if src is a + # file and dest if the path to a file (doesn't end with \ or /) + $parent_dir = Split-Path -LiteralPath $dest + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist" + } + } + $copy_result = Copy-File -source $src -dest $dest + $diff = $copy_result.diff + $result.checksum = $copy_result.checksum + } + + # the file might not exist if running in check mode + if (-not $check_mode -or (Test-Path -LiteralPath $dest -PathType Leaf)) { + $result.size = Get-FileSize -path $dest + } else { + $result.size = $null + } + if ($diff_mode) { + $result.diff.prepared = $diff + } +} elseif ($copy_mode -eq "single") { + # a single file is located in src and we need to copy to dest, this will + # always result in a change as the calculation is done on the Ansible side + # before this is run. This should also never run in check mode + if (-not (Test-Path -LiteralPath $src -PathType Leaf)) { + Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist" + } + + # the dest parameter is a directory, we need to append original_basename + if ($dest.EndsWith("/") -or $dest.EndsWith("`\") -or (Test-Path -LiteralPath $dest -PathType Container)) { + $remote_dest = Join-Path -Path $dest -ChildPath $original_basename + $parent_dir = Split-Path -LiteralPath $remote_dest + + # when dest ends with /, we need to create the destination directories + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + New-Item -Path $parent_dir -ItemType Directory | Out-Null + } + } else { + $remote_dest = $dest + $parent_dir = Split-Path -LiteralPath $remote_dest + + # check if the dest parent dirs exist, need to fail if they don't + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist" + } + } + + if ($backup) { + $result.backup_file = Backup-File -path $remote_dest -WhatIf:$check_mode + } + + Copy-Item -LiteralPath $src -Destination $remote_dest -Force | Out-Null + $result.changed = $true +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py new file mode 100644 index 0000000..a55f4c6 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py @@ -0,0 +1,207 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_copy +version_added: '1.9.2' +short_description: Copies files to remote locations on windows hosts +description: +- The C(win_copy) module copies a file on the local box to remote windows locations. +- For non-Windows targets, use the M(copy) module instead. +options: + content: + description: + - When used instead of C(src), sets the contents of a file directly to the + specified value. + - This is for simple values, for anything complex or with formatting please + switch to the M(template) module. + type: str + version_added: '2.3' + decrypt: + description: + - This option controls the autodecryption of source files using vault. + type: bool + default: yes + version_added: '2.5' + dest: + description: + - Remote absolute path where the file should be copied to. + - If C(src) is a directory, this must be a directory too. + - Use \ for path separators or \\ when in "double quotes". + - If C(dest) ends with \ then source or the contents of source will be + copied to the directory without renaming. + - If C(dest) is a nonexistent path, it will only be created if C(dest) ends + with "/" or "\", or C(src) is a directory. + - If C(src) and C(dest) are files and if the parent directory of C(dest) + doesn't exist, then the task will fail. + type: path + required: yes + backup: + description: + - Determine whether a backup should be created. + - When set to C(yes), create a backup file including the timestamp information + so you can get the original file back if you somehow clobbered it incorrectly. + - No backup is taken when C(remote_src=False) and multiple files are being + copied. + type: bool + default: no + version_added: '2.8' + force: + description: + - If set to C(yes), the file will only be transferred if the content + is different than destination. + - If set to C(no), the file will only be transferred if the + destination does not exist. + - If set to C(no), no checksuming of the content is performed which can + help improve performance on larger files. + type: bool + default: yes + version_added: '2.3' + local_follow: + description: + - This flag indicates that filesystem links in the source tree, if they + exist, should be followed. + type: bool + default: yes + version_added: '2.4' + remote_src: + description: + - If C(no), it will search for src at originating/master machine. + - If C(yes), it will go to the remote/target machine for the src. + type: bool + default: no + version_added: '2.3' + src: + description: + - Local path to a file to copy to the remote server; can be absolute or + relative. + - If path is a directory, it is copied (including the source folder name) + recursively to C(dest). + - If path is a directory and ends with "/", only the inside contents of + that directory are copied to the destination. Otherwise, if it does not + end with "/", the directory itself with all contents is copied. + - If path is a file and dest ends with "\", the file is copied to the + folder with the same filename. + - Required unless using C(content). + type: path +notes: +- Currently win_copy does not support copying symbolic links from both local to + remote and remote to remote. +- It is recommended that backslashes C(\) are used instead of C(/) when dealing + with remote paths. +- Because win_copy runs over WinRM, it is not a very efficient transfer + mechanism. If sending large files consider hosting them on a web service and + using M(win_get_url) instead. +seealso: +- module: assemble +- module: copy +- module: win_get_url +- module: win_robocopy +author: +- Jon Hawkesworth (@jhawkesworth) +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Copy a single file + win_copy: + src: /srv/myfiles/foo.conf + dest: C:\Temp\renamed-foo.conf + +- name: Copy a single file, but keep a backup + win_copy: + src: /srv/myfiles/foo.conf + dest: C:\Temp\renamed-foo.conf + backup: yes + +- name: Copy a single file keeping the filename + win_copy: + src: /src/myfiles/foo.conf + dest: C:\Temp\ + +- name: Copy folder to C:\Temp (results in C:\Temp\temp_files) + win_copy: + src: files/temp_files + dest: C:\Temp + +- name: Copy folder contents recursively + win_copy: + src: files/temp_files/ + dest: C:\Temp + +- name: Copy a single file where the source is on the remote host + win_copy: + src: C:\Temp\foo.txt + dest: C:\ansible\foo.txt + remote_src: yes + +- name: Copy a folder recursively where the source is on the remote host + win_copy: + src: C:\Temp + dest: C:\ansible + remote_src: yes + +- name: Set the contents of a file + win_copy: + content: abc123 + dest: C:\Temp\foo.txt + +- name: Copy a single file as another user + win_copy: + src: NuGet.config + dest: '%AppData%\NuGet\NuGet.config' + vars: + ansible_become_user: user + ansible_become_password: pass + # The tmp dir must be set when using win_copy as another user + # This ensures the become user will have permissions for the operation + # Make sure to specify a folder both the ansible_user and the become_user have access to (i.e not %TEMP% which is user specific and requires Admin) + ansible_remote_tmp: 'c:\tmp' +''' + +RETURN = r''' +backup_file: + description: Name of the backup file that was created. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +dest: + description: Destination file/path. + returned: changed + type: str + sample: C:\Temp\ +src: + description: Source file used for the copy on the target machine. + returned: changed + type: str + sample: /home/httpd/.ansible/tmp/ansible-tmp-1423796390.97-147729857856000/source +checksum: + description: SHA1 checksum of the file after running copy. + returned: success, src is a file + type: str + sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827 +size: + description: Size of the target, after execution. + returned: changed, src is a file + type: int + sample: 1220 +operation: + description: Whether a single file copy took place or a folder copy. + returned: success + type: str + sample: file_copy +original_basename: + description: Basename of the copied file. + returned: changed, src is a file + type: str + sample: foo.txt +''' diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1 new file mode 100644 index 0000000..5442754 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1 @@ -0,0 +1,152 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$ErrorActionPreference = "Stop" + +$params = Parse-Args $args -supports_check_mode $true + +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -default $false +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$path = Get-AnsibleParam -obj $params -name "path" -type "path" -failifempty $true -aliases "dest","name" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -validateset "absent","directory","file","touch" + +# used in template/copy when dest is the path to a dir and source is a file +$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str" +if ((Test-Path -LiteralPath $path -PathType Container) -and ($null -ne $original_basename)) { + $path = Join-Path -Path $path -ChildPath $original_basename +} + +$result = @{ + changed = $false +} + +# Used to delete symlinks as powershell cannot delete broken symlinks +$symlink_util = @" +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Ansible.Command { + public class SymLinkHelper { + [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)] + public static extern bool DeleteFileW(string lpFileName); + + [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)] + public static extern bool RemoveDirectoryW(string lpPathName); + + public static void DeleteDirectory(string path) { + if (!RemoveDirectoryW(path)) + throw new Exception(String.Format("RemoveDirectoryW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message)); + } + + public static void DeleteFile(string path) { + if (!DeleteFileW(path)) + throw new Exception(String.Format("DeleteFileW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message)); + } + } +} +"@ +$original_tmp = $env:TMP +$env:TMP = $_remote_tmp +Add-Type -TypeDefinition $symlink_util +$env:TMP = $original_tmp + +# Used to delete directories and files with logic on handling symbolic links +function Remove-File($file, $checkmode) { + try { + if ($file.Attributes -band [System.IO.FileAttributes]::ReparsePoint) { + # Bug with powershell, if you try and delete a symbolic link that is pointing + # to an invalid path it will fail, using Win32 API to do this instead + if ($file.PSIsContainer) { + if (-not $checkmode) { + [Ansible.Command.SymLinkHelper]::DeleteDirectory($file.FullName) + } + } else { + if (-not $checkmode) { + [Ansible.Command.SymlinkHelper]::DeleteFile($file.FullName) + } + } + } elseif ($file.PSIsContainer) { + Remove-Directory -directory $file -checkmode $checkmode + } else { + Remove-Item -LiteralPath $file.FullName -Force -WhatIf:$checkmode + } + } catch [Exception] { + Fail-Json $result "Failed to delete $($file.FullName): $($_.Exception.Message)" + } +} + +function Remove-Directory($directory, $checkmode) { + foreach ($file in Get-ChildItem -LiteralPath $directory.FullName) { + Remove-File -file $file -checkmode $checkmode + } + Remove-Item -LiteralPath $directory.FullName -Force -Recurse -WhatIf:$checkmode +} + + +if ($state -eq "touch") { + if (Test-Path -LiteralPath $path) { + if (-not $check_mode) { + (Get-ChildItem -LiteralPath $path).LastWriteTime = Get-Date + } + $result.changed = $true + } else { + Write-Output $null | Out-File -LiteralPath $path -Encoding ASCII -WhatIf:$check_mode + $result.changed = $true + } +} + +if (Test-Path -LiteralPath $path) { + $fileinfo = Get-Item -LiteralPath $path -Force + if ($state -eq "absent") { + Remove-File -file $fileinfo -checkmode $check_mode + $result.changed = $true + } else { + if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a directory" + } + + if ($state -eq "file" -and $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a file" + } + } + +} else { + + # If state is not supplied, test the $path to see if it looks like + # a file or a folder and set state to file or folder + if ($null -eq $state) { + $basename = Split-Path -Path $path -Leaf + if ($basename.length -gt 0) { + $state = "file" + } else { + $state = "directory" + } + } + + if ($state -eq "directory") { + try { + New-Item -Path $path -ItemType Directory -WhatIf:$check_mode | Out-Null + } catch { + if ($_.CategoryInfo.Category -eq "ResourceExists") { + $fileinfo = Get-Item -LiteralPath $_.CategoryInfo.TargetName + if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a directory" + } + } else { + Fail-Json $result $_.Exception.Message + } + } + $result.changed = $true + } elseif ($state -eq "file") { + Fail-Json $result "path $path will not be created" + } + +} + +Exit-Json $result diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py new file mode 100644 index 0000000..2814957 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py @@ -0,0 +1,70 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_file +version_added: "1.9.2" +short_description: Creates, touches or removes files or directories +description: + - Creates (empty) files, updates file modification stamps of existing files, + and can create or remove directories. + - Unlike M(file), does not modify ownership, permissions or manipulate links. + - For non-Windows targets, use the M(file) module instead. +options: + path: + description: + - Path to the file being managed. + required: yes + type: path + aliases: [ dest, name ] + state: + description: + - If C(directory), all immediate subdirectories will be created if they + do not exist. + - If C(file), the file will NOT be created if it does not exist, see the M(copy) + or M(template) module if you want that behavior. + - If C(absent), directories will be recursively deleted, and files will be removed. + - If C(touch), an empty file will be created if the C(path) does not + exist, while an existing file or directory will receive updated file access and + modification times (similar to the way C(touch) works from the command line). + type: str + choices: [ absent, directory, file, touch ] +seealso: +- module: file +- module: win_acl +- module: win_acl_inheritance +- module: win_owner +- module: win_stat +author: +- Jon Hawkesworth (@jhawkesworth) +''' + +EXAMPLES = r''' +- name: Touch a file (creates if not present, updates modification time if present) + win_file: + path: C:\Temp\foo.conf + state: touch + +- name: Remove a file, if present + win_file: + path: C:\Temp\foo.conf + state: absent + +- name: Create directory structure + win_file: + path: C:\Temp\folder\subfolder + state: directory + +- name: Remove directory structure + win_file: + path: C:\Temp + state: absent +''' diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1 new file mode 100644 index 0000000..c848b91 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1 @@ -0,0 +1,21 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic + +$spec = @{ + options = @{ + data = @{ type = "str"; default = "pong" } + } + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) +$data = $module.Params.data + +if ($data -eq "crash") { + throw "boom" +} + +$module.Result.ping = $data +$module.ExitJson() diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py new file mode 100644 index 0000000..6d35f37 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py @@ -0,0 +1,55 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>, and others +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_ping +version_added: "1.7" +short_description: A windows version of the classic ping module +description: + - Checks management connectivity of a windows host. + - This is NOT ICMP ping, this is just a trivial test module. + - For non-Windows targets, use the M(ping) module instead. + - For Network targets, use the M(net_ping) module instead. +options: + data: + description: + - Alternate data to return instead of 'pong'. + - If this parameter is set to C(crash), the module will cause an exception. + type: str + default: pong +seealso: +- module: ping +author: +- Chris Church (@cchurch) +''' + +EXAMPLES = r''' +# Test connectivity to a windows host +# ansible winserver -m win_ping + +- name: Example from an Ansible Playbook + win_ping: + +- name: Induce an exception to see what happens + win_ping: + data: crash +''' + +RETURN = r''' +ping: + description: Value provided with the data parameter. + returned: success + type: str + sample: pong +''' diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1 new file mode 100644 index 0000000..54aef8d --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1 @@ -0,0 +1,138 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.CommandUtil +#Requires -Module Ansible.ModuleUtils.FileUtil + +# TODO: add check mode support + +Set-StrictMode -Version 2 +$ErrorActionPreference = "Stop" + +# Cleanse CLIXML from stderr (sift out error stream data, discard others for now) +Function Cleanse-Stderr($raw_stderr) { + Try { + # NB: this regex isn't perfect, but is decent at finding CLIXML amongst other stderr noise + If($raw_stderr -match "(?s)(?<prenoise1>.*)#< CLIXML(?<prenoise2>.*)(?<clixml><Objs.+</Objs>)(?<postnoise>.*)") { + $clixml = [xml]$matches["clixml"] + + $merged_stderr = "{0}{1}{2}{3}" -f @( + $matches["prenoise1"], + $matches["prenoise2"], + # filter out just the Error-tagged strings for now, and zap embedded CRLF chars + ($clixml.Objs.ChildNodes | Where-Object { $_.Name -eq 'S' } | Where-Object { $_.S -eq 'Error' } | ForEach-Object { $_.'#text'.Replace('_x000D__x000A_','') } | Out-String), + $matches["postnoise"]) | Out-String + + return $merged_stderr.Trim() + + # FUTURE: parse/return other streams + } + Else { + $raw_stderr + } + } + Catch { + "***EXCEPTION PARSING CLIXML: $_***" + $raw_stderr + } +} + +$params = Parse-Args $args -supports_check_mode $false + +$raw_command_line = Get-AnsibleParam -obj $params -name "_raw_params" -type "str" -failifempty $true +$chdir = Get-AnsibleParam -obj $params -name "chdir" -type "path" +$executable = Get-AnsibleParam -obj $params -name "executable" -type "path" +$creates = Get-AnsibleParam -obj $params -name "creates" -type "path" +$removes = Get-AnsibleParam -obj $params -name "removes" -type "path" +$stdin = Get-AnsibleParam -obj $params -name "stdin" -type "str" +$no_profile = Get-AnsibleParam -obj $params -name "no_profile" -type "bool" -default $false +$output_encoding_override = Get-AnsibleParam -obj $params -name "output_encoding_override" -type "str" + +$raw_command_line = $raw_command_line.Trim() + +$result = @{ + changed = $true + cmd = $raw_command_line +} + +if ($creates -and $(Test-AnsiblePath -Path $creates)) { + Exit-Json @{msg="skipped, since $creates exists";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +if ($removes -and -not $(Test-AnsiblePath -Path $removes)) { + Exit-Json @{msg="skipped, since $removes does not exist";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +$exec_args = $null +If(-not $executable -or $executable -eq "powershell") { + $exec_application = "powershell.exe" + + # force input encoding to preamble-free UTF8 so PS sub-processes (eg, Start-Job) don't blow up + $raw_command_line = "[Console]::InputEncoding = New-Object Text.UTF8Encoding `$false; " + $raw_command_line + + # Base64 encode the command so we don't have to worry about the various levels of escaping + $encoded_command = [Convert]::ToBase64String([System.Text.Encoding]::Unicode.GetBytes($raw_command_line)) + + if ($stdin) { + $exec_args = "-encodedcommand $encoded_command" + } else { + $exec_args = "-noninteractive -encodedcommand $encoded_command" + } + + if ($no_profile) { + $exec_args = "-noprofile $exec_args" + } +} +Else { + # FUTURE: support arg translation from executable (or executable_args?) to process arguments for arbitrary interpreter? + $exec_application = $executable + if (-not ($exec_application.EndsWith(".exe"))) { + $exec_application = "$($exec_application).exe" + } + $exec_args = "/c $raw_command_line" +} + +$command = "`"$exec_application`" $exec_args" +$run_command_arg = @{ + command = $command +} +if ($chdir) { + $run_command_arg['working_directory'] = $chdir +} +if ($stdin) { + $run_command_arg['stdin'] = $stdin +} +if ($output_encoding_override) { + $run_command_arg['output_encoding_override'] = $output_encoding_override +} + +$start_datetime = [DateTime]::UtcNow +try { + $command_result = Run-Command @run_command_arg +} catch { + $result.changed = $false + try { + $result.rc = $_.Exception.NativeErrorCode + } catch { + $result.rc = 2 + } + Fail-Json -obj $result -message $_.Exception.Message +} + +# TODO: decode CLIXML stderr output (and other streams?) +$result.stdout = $command_result.stdout +$result.stderr = Cleanse-Stderr $command_result.stderr +$result.rc = $command_result.rc + +$end_datetime = [DateTime]::UtcNow +$result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff") + +If ($result.rc -ne 0) { + Fail-Json -obj $result -message "non-zero return code" +} + +Exit-Json $result diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py new file mode 100644 index 0000000..ee2cd76 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py @@ -0,0 +1,167 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Ansible, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_shell +short_description: Execute shell commands on target hosts +version_added: 2.2 +description: + - The C(win_shell) module takes the command name followed by a list of space-delimited arguments. + It is similar to the M(win_command) module, but runs + the command via a shell (defaults to PowerShell) on the target host. + - For non-Windows targets, use the M(shell) module instead. +options: + free_form: + description: + - The C(win_shell) module takes a free form command to run. + - There is no parameter actually named 'free form'. See the examples! + type: str + required: yes + creates: + description: + - A path or path filter pattern; when the referenced path exists on the target host, the task will be skipped. + type: path + removes: + description: + - A path or path filter pattern; when the referenced path B(does not) exist on the target host, the task will be skipped. + type: path + chdir: + description: + - Set the specified path as the current working directory before executing a command + type: path + executable: + description: + - Change the shell used to execute the command (eg, C(cmd)). + - The target shell must accept a C(/c) parameter followed by the raw command line to be executed. + type: path + stdin: + description: + - Set the stdin of the command directly to the specified value. + type: str + version_added: '2.5' + no_profile: + description: + - Do not load the user profile before running a command. This is only valid + when using PowerShell as the executable. + type: bool + default: no + version_added: '2.8' + output_encoding_override: + description: + - This option overrides the encoding of stdout/stderr output. + - You can use this option when you need to run a command which ignore the console's codepage. + - You should only need to use this option in very rare circumstances. + - This value can be any valid encoding C(Name) based on the output of C([System.Text.Encoding]::GetEncodings()). + See U(https://docs.microsoft.com/dotnet/api/system.text.encoding.getencodings). + type: str + version_added: '2.10' +notes: + - If you want to run an executable securely and predictably, it may be + better to use the M(win_command) module instead. Best practices when writing + playbooks will follow the trend of using M(win_command) unless C(win_shell) is + explicitly required. When running ad-hoc commands, use your best judgement. + - WinRM will not return from a command execution until all child processes created have exited. + Thus, it is not possible to use C(win_shell) to spawn long-running child or background processes. + Consider creating a Windows service for managing background processes. +seealso: +- module: psexec +- module: raw +- module: script +- module: shell +- module: win_command +- module: win_psexec +author: + - Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +# Execute a command in the remote shell; stdout goes to the specified +# file on the remote. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt + +# Change the working directory to somedir/ before executing the command. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt chdir=C:\somedir + +# You can also use the 'args' form to provide the options. This command +# will change the working directory to somedir/ and will only run when +# somedir/somelog.txt doesn't exist. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt + args: + chdir: C:\somedir + creates: C:\somelog.txt + +# Run a command under a non-Powershell interpreter (cmd in this case) +- win_shell: echo %HOMEDIR% + args: + executable: cmd + register: homedir_out + +- name: Run multi-lined shell commands + win_shell: | + $value = Test-Path -Path C:\temp + if ($value) { + Remove-Item -Path C:\temp -Force + } + New-Item -Path C:\temp -ItemType Directory + +- name: Retrieve the input based on stdin + win_shell: '$string = [Console]::In.ReadToEnd(); Write-Output $string.Trim()' + args: + stdin: Input message +''' + +RETURN = r''' +msg: + description: Changed. + returned: always + type: bool + sample: true +start: + description: The command execution start time. + returned: always + type: str + sample: '2016-02-25 09:18:26.429568' +end: + description: The command execution end time. + returned: always + type: str + sample: '2016-02-25 09:18:26.755339' +delta: + description: The command execution delta time. + returned: always + type: str + sample: '0:00:00.325771' +stdout: + description: The command standard output. + returned: always + type: str + sample: 'Clustering node rabbit@slave1 with rabbit@master ...' +stderr: + description: The command standard error. + returned: always + type: str + sample: 'ls: cannot access foo: No such file or directory' +cmd: + description: The command executed by the task. + returned: always + type: str + sample: 'rabbitmqctl join_cluster rabbit@master' +rc: + description: The command return code (0 means success). + returned: always + type: int + sample: 0 +stdout_lines: + description: The command standard output split in lines. + returned: always + type: list + sample: [u'Clustering node rabbit@slave1 with rabbit@master ...'] +''' diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1 new file mode 100644 index 0000000..071eb11 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1 @@ -0,0 +1,186 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.FileUtil +#Requires -Module Ansible.ModuleUtils.LinkUtil + +function ConvertTo-Timestamp($start_date, $end_date) { + if ($start_date -and $end_date) { + return (New-TimeSpan -Start $start_date -End $end_date).TotalSeconds + } +} + +function Get-FileChecksum($path, $algorithm) { + switch ($algorithm) { + 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider } + 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider } + 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider } + 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider } + 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider } + default { Fail-Json -obj $result -message "Unsupported hash algorithm supplied '$algorithm'" } + } + + $fp = [System.IO.File]::Open($path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, [System.IO.FileShare]::ReadWrite) + try { + $hash = [System.BitConverter]::ToString($sp.ComputeHash($fp)).Replace("-", "").ToLower() + } finally { + $fp.Dispose() + } + + return $hash +} + +function Get-FileInfo { + param([String]$Path, [Switch]$Follow) + + $info = Get-AnsibleItem -Path $Path -ErrorAction SilentlyContinue + $link_info = $null + if ($null -ne $info) { + try { + $link_info = Get-Link -link_path $info.FullName + } catch { + $module.Warn("Failed to check/get link info for file: $($_.Exception.Message)") + } + + # If follow=true we want to follow the link all the way back to root object + if ($Follow -and $null -ne $link_info -and $link_info.Type -in @("SymbolicLink", "JunctionPoint")) { + $info, $link_info = Get-FileInfo -Path $link_info.AbsolutePath -Follow + } + } + + return $info, $link_info +} + +$spec = @{ + options = @{ + path = @{ type='path'; required=$true; aliases=@( 'dest', 'name' ) } + get_checksum = @{ type='bool'; default=$true } + checksum_algorithm = @{ type='str'; default='sha1'; choices=@( 'md5', 'sha1', 'sha256', 'sha384', 'sha512' ) } + follow = @{ type='bool'; default=$false } + } + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$path = $module.Params.path +$get_checksum = $module.Params.get_checksum +$checksum_algorithm = $module.Params.checksum_algorithm +$follow = $module.Params.follow + +$module.Result.stat = @{ exists=$false } + +Load-LinkUtils +$info, $link_info = Get-FileInfo -Path $path -Follow:$follow +If ($null -ne $info) { + $epoch_date = Get-Date -Date "01/01/1970" + $attributes = @() + foreach ($attribute in ($info.Attributes -split ',')) { + $attributes += $attribute.Trim() + } + + # default values that are always set, specific values are set below this + # but are kept commented for easier readability + $stat = @{ + exists = $true + attributes = $info.Attributes.ToString() + isarchive = ($attributes -contains "Archive") + isdir = $false + ishidden = ($attributes -contains "Hidden") + isjunction = $false + islnk = $false + isreadonly = ($attributes -contains "ReadOnly") + isreg = $false + isshared = $false + nlink = 1 # Number of links to the file (hard links), overriden below if islnk + # lnk_target = islnk or isjunction Target of the symlink. Note that relative paths remain relative + # lnk_source = islnk os isjunction Target of the symlink normalized for the remote filesystem + hlnk_targets = @() + creationtime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.CreationTime) + lastaccesstime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastAccessTime) + lastwritetime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastWriteTime) + # size = a file and directory - calculated below + path = $info.FullName + filename = $info.Name + # extension = a file + # owner = set outsite this dict in case it fails + # sharename = a directory and isshared is True + # checksum = a file and get_checksum: True + } + try { + $stat.owner = $info.GetAccessControl().Owner + } catch { + # may not have rights, historical behaviour was to just set to $null + # due to ErrorActionPreference being set to "Continue" + $stat.owner = $null + } + + # values that are set according to the type of file + if ($info.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) { + $stat.isdir = $true + $share_info = Get-CimInstance -ClassName Win32_Share -Filter "Path='$($stat.path -replace '\\', '\\')'" + if ($null -ne $share_info) { + $stat.isshared = $true + $stat.sharename = $share_info.Name + } + + try { + $size = 0 + foreach ($file in $info.EnumerateFiles("*", [System.IO.SearchOption]::AllDirectories)) { + $size += $file.Length + } + $stat.size = $size + } catch { + $stat.size = 0 + } + } else { + $stat.extension = $info.Extension + $stat.isreg = $true + $stat.size = $info.Length + + if ($get_checksum) { + try { + $stat.checksum = Get-FileChecksum -path $path -algorithm $checksum_algorithm + } catch { + $module.FailJson("Failed to get hash of file, set get_checksum to False to ignore this error: $($_.Exception.Message)", $_) + } + } + } + + # Get symbolic link, junction point, hard link info + if ($null -ne $link_info) { + switch ($link_info.Type) { + "SymbolicLink" { + $stat.islnk = $true + $stat.isreg = $false + $stat.lnk_target = $link_info.TargetPath + $stat.lnk_source = $link_info.AbsolutePath + break + } + "JunctionPoint" { + $stat.isjunction = $true + $stat.isreg = $false + $stat.lnk_target = $link_info.TargetPath + $stat.lnk_source = $link_info.AbsolutePath + break + } + "HardLink" { + $stat.lnk_type = "hard" + $stat.nlink = $link_info.HardTargets.Count + + # remove current path from the targets + $hlnk_targets = $link_info.HardTargets | Where-Object { $_ -ne $stat.path } + $stat.hlnk_targets = @($hlnk_targets) + break + } + } + } + + $module.Result.stat = $stat +} + +$module.ExitJson() + diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py new file mode 100644 index 0000000..0676b5b --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py @@ -0,0 +1,236 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_stat +version_added: "1.7" +short_description: Get information about Windows files +description: + - Returns information about a Windows file. + - For non-Windows targets, use the M(stat) module instead. +options: + path: + description: + - The full path of the file/object to get the facts of; both forward and + back slashes are accepted. + type: path + required: yes + aliases: [ dest, name ] + get_checksum: + description: + - Whether to return a checksum of the file (default sha1) + type: bool + default: yes + version_added: "2.1" + checksum_algorithm: + description: + - Algorithm to determine checksum of file. + - Will throw an error if the host is unable to use specified algorithm. + type: str + default: sha1 + choices: [ md5, sha1, sha256, sha384, sha512 ] + version_added: "2.3" + follow: + description: + - Whether to follow symlinks or junction points. + - In the case of C(path) pointing to another link, then that will + be followed until no more links are found. + type: bool + default: no + version_added: "2.8" +seealso: +- module: stat +- module: win_acl +- module: win_file +- module: win_owner +author: +- Chris Church (@cchurch) +''' + +EXAMPLES = r''' +- name: Obtain information about a file + win_stat: + path: C:\foo.ini + register: file_info + +- name: Obtain information about a folder + win_stat: + path: C:\bar + register: folder_info + +- name: Get MD5 checksum of a file + win_stat: + path: C:\foo.ini + get_checksum: yes + checksum_algorithm: md5 + register: md5_checksum + +- debug: + var: md5_checksum.stat.checksum + +- name: Get SHA1 checksum of file + win_stat: + path: C:\foo.ini + get_checksum: yes + register: sha1_checksum + +- debug: + var: sha1_checksum.stat.checksum + +- name: Get SHA256 checksum of file + win_stat: + path: C:\foo.ini + get_checksum: yes + checksum_algorithm: sha256 + register: sha256_checksum + +- debug: + var: sha256_checksum.stat.checksum +''' + +RETURN = r''' +changed: + description: Whether anything was changed + returned: always + type: bool + sample: true +stat: + description: dictionary containing all the stat data + returned: success + type: complex + contains: + attributes: + description: Attributes of the file at path in raw form. + returned: success, path exists + type: str + sample: "Archive, Hidden" + checksum: + description: The checksum of a file based on checksum_algorithm specified. + returned: success, path exist, path is a file, get_checksum == True + checksum_algorithm specified is supported + type: str + sample: 09cb79e8fc7453c84a07f644e441fd81623b7f98 + creationtime: + description: The create time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + exists: + description: If the path exists or not. + returned: success + type: bool + sample: true + extension: + description: The extension of the file at path. + returned: success, path exists, path is a file + type: str + sample: ".ps1" + filename: + description: The name of the file (without path). + returned: success, path exists, path is a file + type: str + sample: foo.ini + hlnk_targets: + description: List of other files pointing to the same file (hard links), excludes the current file. + returned: success, path exists + type: list + sample: + - C:\temp\file.txt + - C:\Windows\update.log + isarchive: + description: If the path is ready for archiving or not. + returned: success, path exists + type: bool + sample: true + isdir: + description: If the path is a directory or not. + returned: success, path exists + type: bool + sample: true + ishidden: + description: If the path is hidden or not. + returned: success, path exists + type: bool + sample: true + isjunction: + description: If the path is a junction point or not. + returned: success, path exists + type: bool + sample: true + islnk: + description: If the path is a symbolic link or not. + returned: success, path exists + type: bool + sample: true + isreadonly: + description: If the path is read only or not. + returned: success, path exists + type: bool + sample: true + isreg: + description: If the path is a regular file. + returned: success, path exists + type: bool + sample: true + isshared: + description: If the path is shared or not. + returned: success, path exists + type: bool + sample: true + lastaccesstime: + description: The last access time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lastwritetime: + description: The last modification time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lnk_source: + description: Target of the symlink normalized for the remote filesystem. + returned: success, path exists and the path is a symbolic link or junction point + type: str + sample: C:\temp\link + lnk_target: + description: Target of the symlink. Note that relative paths remain relative. + returned: success, path exists and the path is a symbolic link or junction point + type: str + sample: ..\link + nlink: + description: Number of links to the file (hard links). + returned: success, path exists + type: int + sample: 1 + owner: + description: The owner of the file. + returned: success, path exists + type: str + sample: BUILTIN\Administrators + path: + description: The full absolute path to the file. + returned: success, path exists, file exists + type: str + sample: C:\foo.ini + sharename: + description: The name of share if folder is shared. + returned: success, path exists, file is a directory and isshared == True + type: str + sample: file-share + size: + description: The size in bytes of a file or folder. + returned: success, path exists, file is not a link + type: int + sample: 1024 +''' diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_uri.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_uri.ps1 new file mode 100644 index 0000000..9d7c68b --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_uri.ps1 @@ -0,0 +1,219 @@ +#!powershell + +# Copyright: (c) 2015, Corwin Brown <corwin@corwinbrown.com> +# Copyright: (c) 2017, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.CamelConversion +#Requires -Module Ansible.ModuleUtils.FileUtil +#Requires -Module Ansible.ModuleUtils.Legacy +#AnsibleRequires -PowerShell ..module_utils.WebRequest + +$spec = @{ + options = @{ + url = @{ type = "str"; required = $true } + content_type = @{ type = "str" } + body = @{ type = "raw" } + dest = @{ type = "path" } + creates = @{ type = "path" } + removes = @{ type = "path" } + return_content = @{ type = "bool"; default = $false } + status_code = @{ type = "list"; elements = "int"; default = @(200) } + + # Defined for ease of use and backwards compatibility + url_timeout = @{ + aliases = "timeout" + } + url_method = @{ + aliases = "method" + default = "GET" + } + + # Defined for the alias backwards compatibility, remove once aliases are removed + url_username = @{ + aliases = @("user", "username") + deprecated_aliases = @( + @{ name = "user"; date = [DateTime]::ParseExact("2022-07-01", "yyyy-MM-dd", $null); collection_name = 'ansible.windows' }, + @{ name = "username"; date = [DateTime]::ParseExact("2022-07-01", "yyyy-MM-dd", $null); collection_name = 'ansible.windows' } + ) + } + url_password = @{ + aliases = @("password") + deprecated_aliases = @( + @{ name = "password"; date = [DateTime]::ParseExact("2022-07-01", "yyyy-MM-dd", $null); collection_name = 'ansible.windows' } + ) + } + } + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWindowsWebRequestSpec)) + +$url = $module.Params.url +$method = $module.Params.url_method.ToUpper() +$content_type = $module.Params.content_type +$body = $module.Params.body +$dest = $module.Params.dest +$creates = $module.Params.creates +$removes = $module.Params.removes +$return_content = $module.Params.return_content +$status_code = $module.Params.status_code + +$JSON_CANDIDATES = @('text', 'json', 'javascript') + +$module.Result.elapsed = 0 +$module.Result.url = $url + +Function ConvertFrom-SafeJson { + <# + .SYNOPSIS + Safely convert a JSON string to an object, this is like ConvertFrom-Json except it respect -ErrorAction. + + .PAREMTER InputObject + The input object string to convert from. + #> + [CmdletBinding()] + param ( + [Parameter(Mandatory=$true)] + [AllowEmptyString()] + [AllowNull()] + [String] + $InputObject + ) + + if (-not $InputObject) { + return + } + + try { + # Make sure we output the actual object without unpacking with the unary comma + ,[Ansible.Basic.AnsibleModule]::FromJson($InputObject) + } catch [System.ArgumentException] { + Write-Error -Message "Invalid json string as input object: $($_.Exception.Message)" -Exception $_.Exception + } +} + +if (-not ($method -cmatch '^[A-Z]+$')) { + $module.FailJson("Parameter 'method' needs to be a single word in uppercase, like GET or POST.") +} + +if ($creates -and (Test-AnsiblePath -Path $creates)) { + $module.Result.skipped = $true + $module.Result.msg = "The 'creates' file or directory ($creates) already exists." + $module.ExitJson() +} + +if ($removes -and -not (Test-AnsiblePath -Path $removes)) { + $module.Result.skipped = $true + $module.Result.msg = "The 'removes' file or directory ($removes) does not exist." + $module.ExitJson() +} + +$client = Get-AnsibleWindowsWebRequest -Uri $url -Module $module + +if ($null -ne $content_type) { + $client.ContentType = $content_type +} + +$response_script = { + param($Response, $Stream) + + ForEach ($prop in $Response.PSObject.Properties) { + $result_key = Convert-StringToSnakeCase -string $prop.Name + $prop_value = $prop.Value + # convert and DateTime values to ISO 8601 standard + if ($prop_value -is [System.DateTime]) { + $prop_value = $prop_value.ToString("o", [System.Globalization.CultureInfo]::InvariantCulture) + } + $module.Result.$result_key = $prop_value + } + + # manually get the headers as not all of them are in the response properties + foreach ($header_key in $Response.Headers.GetEnumerator()) { + $header_value = $Response.Headers[$header_key] + $header_key = $header_key.Replace("-", "") # replace - with _ for snake conversion + $header_key = Convert-StringToSnakeCase -string $header_key + $module.Result.$header_key = $header_value + } + + # we only care about the return body if we need to return the content or create a file + if ($return_content -or $dest) { + # copy to a MemoryStream so we can read it multiple times + $memory_st = New-Object -TypeName System.IO.MemoryStream + try { + $Stream.CopyTo($memory_st) + + if ($return_content) { + $memory_st.Seek(0, [System.IO.SeekOrigin]::Begin) > $null + $content_bytes = $memory_st.ToArray() + $module.Result.content = [System.Text.Encoding]::UTF8.GetString($content_bytes) + if ($module.Result.ContainsKey("content_type") -and $module.Result.content_type -Match ($JSON_CANDIDATES -join '|')) { + $json = ConvertFrom-SafeJson -InputObject $module.Result.content -ErrorAction SilentlyContinue + if ($json) { + $module.Result.json = $json + } + } + } + + if ($dest) { + $memory_st.Seek(0, [System.IO.SeekOrigin]::Begin) > $null + $changed = $true + + if (Test-AnsiblePath -Path $dest) { + $actual_checksum = Get-FileChecksum -path $dest -algorithm "sha1" + + $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider + $content_checksum = [System.BitConverter]::ToString($sp.ComputeHash($memory_st)).Replace("-", "").ToLower() + + if ($actual_checksum -eq $content_checksum) { + $changed = $false + } + } + + $module.Result.changed = $changed + if ($changed -and (-not $module.CheckMode)) { + $memory_st.Seek(0, [System.IO.SeekOrigin]::Begin) > $null + $file_stream = [System.IO.File]::Create($dest) + try { + $memory_st.CopyTo($file_stream) + } finally { + $file_stream.Flush() + $file_stream.Close() + } + } + } + } finally { + $memory_st.Close() + } + } + + if ($status_code -notcontains $Response.StatusCode) { + $module.FailJson("Status code of request '$([int]$Response.StatusCode)' is not in list of valid status codes $status_code : $($Response.StatusCode)'.") + } +} + +$body_st = $null +if ($null -ne $body) { + if ($body -is [System.Collections.IDictionary] -or $body -is [System.Collections.IList]) { + $body_string = ConvertTo-Json -InputObject $body -Compress + } elseif ($body -isnot [String]) { + $body_string = $body.ToString() + } else { + $body_string = $body + } + $buffer = [System.Text.Encoding]::UTF8.GetBytes($body_string) + + $body_st = New-Object -TypeName System.IO.MemoryStream -ArgumentList @(,$buffer) +} + +try { + Invoke-AnsibleWindowsWebRequest -Module $module -Request $client -Script $response_script -Body $body_st -IgnoreBadResponse +} catch { + $module.FailJson("Unhandled exception occurred when sending web request. Exception: $($_.Exception.Message)", $_) +} finally { + if ($null -ne $body_st) { + $body_st.Dispose() + } +} + +$module.ExitJson() diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_uri.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_uri.py new file mode 100644 index 0000000..3b1094e --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_uri.py @@ -0,0 +1,155 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Corwin Brown <corwin@corwinbrown.com> +# Copyright: (c) 2017, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +DOCUMENTATION = r''' +--- +module: win_uri +short_description: Interacts with webservices +description: +- Interacts with FTP, HTTP and HTTPS web services. +- Supports Digest, Basic and WSSE HTTP authentication mechanisms. +- For non-Windows targets, use the M(ansible.builtin.uri) module instead. +options: + url: + description: + - Supports FTP, HTTP or HTTPS URLs in the form of (ftp|http|https)://host.domain:port/path. + type: str + required: yes + content_type: + description: + - Sets the "Content-Type" header. + type: str + body: + description: + - The body of the HTTP request/response to the web service. + type: raw + dest: + description: + - Output the response body to a file. + type: path + creates: + description: + - A filename, when it already exists, this step will be skipped. + type: path + removes: + description: + - A filename, when it does not exist, this step will be skipped. + type: path + return_content: + description: + - Whether or not to return the body of the response as a "content" key in + the dictionary result. If the reported Content-type is + "application/json", then the JSON is additionally loaded into a key + called C(json) in the dictionary results. + type: bool + default: no + status_code: + description: + - A valid, numeric, HTTP status code that signifies success of the request. + - Can also be comma separated list of status codes. + type: list + elements: int + default: [ 200 ] + + url_method: + default: GET + aliases: + - method + url_timeout: + aliases: + - timeout + + # Following defined in the web_request fragment but the module contains deprecated aliases for backwards compatibility. + url_username: + description: + - The username to use for authentication. + - The alias I(user) and I(username) is deprecated and will be removed on + the major release after C(2022-07-01). + aliases: + - user + - username + url_password: + description: + - The password for I(url_username). + - The alias I(password) is deprecated and will be removed on the major + release after C(2022-07-01). + aliases: + - password +extends_documentation_fragment: +- ansible.windows.web_request + +seealso: +- module: ansible.builtin.uri +- module: ansible.windows.win_get_url +author: +- Corwin Brown (@blakfeld) +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r''' +- name: Perform a GET and Store Output + ansible.windows.win_uri: + url: http://example.com/endpoint + register: http_output + +# Set a HOST header to hit an internal webserver: +- name: Hit a Specific Host on the Server + ansible.windows.win_uri: + url: http://example.com/ + method: GET + headers: + host: www.somesite.com + +- name: Perform a HEAD on an Endpoint + ansible.windows.win_uri: + url: http://www.example.com/ + method: HEAD + +- name: POST a Body to an Endpoint + ansible.windows.win_uri: + url: http://www.somesite.com/ + method: POST + body: "{ 'some': 'json' }" +''' + +RETURN = r''' +elapsed: + description: The number of seconds that elapsed while performing the download. + returned: always + type: float + sample: 23.2 +url: + description: The Target URL. + returned: always + type: str + sample: https://www.ansible.com +status_code: + description: The HTTP Status Code of the response. + returned: success + type: int + sample: 200 +status_description: + description: A summary of the status. + returned: success + type: str + sample: OK +content: + description: The raw content of the HTTP response. + returned: success and return_content is True + type: str + sample: '{"foo": "bar"}' +content_length: + description: The byte size of the response. + returned: success + type: int + sample: 54447 +json: + description: The json structure returned under content as a dictionary. + returned: success and Content-Type is "application/json" or "application/javascript" and return_content is True + type: dict + sample: {"this-is-dependent": "on the actual return content"} +''' diff --git a/test/support/windows-integration/plugins/action/win_copy.py b/test/support/windows-integration/plugins/action/win_copy.py new file mode 100644 index 0000000..adb918b --- /dev/null +++ b/test/support/windows-integration/plugins/action/win_copy.py @@ -0,0 +1,522 @@ +# This file is part of Ansible + +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import base64 +import json +import os +import os.path +import shutil +import tempfile +import traceback +import zipfile + +from ansible import constants as C +from ansible.errors import AnsibleError, AnsibleFileNotFound +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.parsing.convert_bool import boolean +from ansible.plugins.action import ActionBase +from ansible.utils.hashing import checksum + + +def _walk_dirs(topdir, loader, decrypt=True, base_path=None, local_follow=False, trailing_slash_detector=None, checksum_check=False): + """ + Walk a filesystem tree returning enough information to copy the files. + This is similar to the _walk_dirs function in ``copy.py`` but returns + a dict instead of a tuple for each entry and includes the checksum of + a local file if wanted. + + :arg topdir: The directory that the filesystem tree is rooted at + :arg loader: The self._loader object from ActionBase + :kwarg decrypt: Whether to decrypt a file encrypted with ansible-vault + :kwarg base_path: The initial directory structure to strip off of the + files for the destination directory. If this is None (the default), + the base_path is set to ``top_dir``. + :kwarg local_follow: Whether to follow symlinks on the source. When set + to False, no symlinks are dereferenced. When set to True (the + default), the code will dereference most symlinks. However, symlinks + can still be present if needed to break a circular link. + :kwarg trailing_slash_detector: Function to determine if a path has + a trailing directory separator. Only needed when dealing with paths on + a remote machine (in which case, pass in a function that is aware of the + directory separator conventions on the remote machine). + :kawrg whether to get the checksum of the local file and add to the dict + :returns: dictionary of dictionaries. All of the path elements in the structure are text string. + This separates all the files, directories, and symlinks along with + import information about each:: + + { + 'files'; [{ + src: '/absolute/path/to/copy/from', + dest: 'relative/path/to/copy/to', + checksum: 'b54ba7f5621240d403f06815f7246006ef8c7d43' + }, ...], + 'directories'; [{ + src: '/absolute/path/to/copy/from', + dest: 'relative/path/to/copy/to' + }, ...], + 'symlinks'; [{ + src: '/symlink/target/path', + dest: 'relative/path/to/copy/to' + }, ...], + + } + + The ``symlinks`` field is only populated if ``local_follow`` is set to False + *or* a circular symlink cannot be dereferenced. The ``checksum`` entry is set + to None if checksum_check=False. + + """ + # Convert the path segments into byte strings + + r_files = {'files': [], 'directories': [], 'symlinks': []} + + def _recurse(topdir, rel_offset, parent_dirs, rel_base=u'', checksum_check=False): + """ + This is a closure (function utilizing variables from it's parent + function's scope) so that we only need one copy of all the containers. + Note that this function uses side effects (See the Variables used from + outer scope). + + :arg topdir: The directory we are walking for files + :arg rel_offset: Integer defining how many characters to strip off of + the beginning of a path + :arg parent_dirs: Directories that we're copying that this directory is in. + :kwarg rel_base: String to prepend to the path after ``rel_offset`` is + applied to form the relative path. + + Variables used from the outer scope + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + :r_files: Dictionary of files in the hierarchy. See the return value + for :func:`walk` for the structure of this dictionary. + :local_follow: Read-only inside of :func:`_recurse`. Whether to follow symlinks + """ + for base_path, sub_folders, files in os.walk(topdir): + for filename in files: + filepath = os.path.join(base_path, filename) + dest_filepath = os.path.join(rel_base, filepath[rel_offset:]) + + if os.path.islink(filepath): + # Dereference the symlnk + real_file = loader.get_real_file(os.path.realpath(filepath), decrypt=decrypt) + if local_follow and os.path.isfile(real_file): + # Add the file pointed to by the symlink + r_files['files'].append( + { + "src": real_file, + "dest": dest_filepath, + "checksum": _get_local_checksum(checksum_check, real_file) + } + ) + else: + # Mark this file as a symlink to copy + r_files['symlinks'].append({"src": os.readlink(filepath), "dest": dest_filepath}) + else: + # Just a normal file + real_file = loader.get_real_file(filepath, decrypt=decrypt) + r_files['files'].append( + { + "src": real_file, + "dest": dest_filepath, + "checksum": _get_local_checksum(checksum_check, real_file) + } + ) + + for dirname in sub_folders: + dirpath = os.path.join(base_path, dirname) + dest_dirpath = os.path.join(rel_base, dirpath[rel_offset:]) + real_dir = os.path.realpath(dirpath) + dir_stats = os.stat(real_dir) + + if os.path.islink(dirpath): + if local_follow: + if (dir_stats.st_dev, dir_stats.st_ino) in parent_dirs: + # Just insert the symlink if the target directory + # exists inside of the copy already + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Walk the dirpath to find all parent directories. + new_parents = set() + parent_dir_list = os.path.dirname(dirpath).split(os.path.sep) + for parent in range(len(parent_dir_list), 0, -1): + parent_stat = os.stat(u'/'.join(parent_dir_list[:parent])) + if (parent_stat.st_dev, parent_stat.st_ino) in parent_dirs: + # Reached the point at which the directory + # tree is already known. Don't add any + # more or we might go to an ancestor that + # isn't being copied. + break + new_parents.add((parent_stat.st_dev, parent_stat.st_ino)) + + if (dir_stats.st_dev, dir_stats.st_ino) in new_parents: + # This was a a circular symlink. So add it as + # a symlink + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Walk the directory pointed to by the symlink + r_files['directories'].append({"src": real_dir, "dest": dest_dirpath}) + offset = len(real_dir) + 1 + _recurse(real_dir, offset, parent_dirs.union(new_parents), + rel_base=dest_dirpath, + checksum_check=checksum_check) + else: + # Add the symlink to the destination + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Just a normal directory + r_files['directories'].append({"src": dirpath, "dest": dest_dirpath}) + + # Check if the source ends with a "/" so that we know which directory + # level to work at (similar to rsync) + source_trailing_slash = False + if trailing_slash_detector: + source_trailing_slash = trailing_slash_detector(topdir) + else: + source_trailing_slash = topdir.endswith(os.path.sep) + + # Calculate the offset needed to strip the base_path to make relative + # paths + if base_path is None: + base_path = topdir + if not source_trailing_slash: + base_path = os.path.dirname(base_path) + if topdir.startswith(base_path): + offset = len(base_path) + + # Make sure we're making the new paths relative + if trailing_slash_detector and not trailing_slash_detector(base_path): + offset += 1 + elif not base_path.endswith(os.path.sep): + offset += 1 + + if os.path.islink(topdir) and not local_follow: + r_files['symlinks'] = {"src": os.readlink(topdir), "dest": os.path.basename(topdir)} + return r_files + + dir_stats = os.stat(topdir) + parents = frozenset(((dir_stats.st_dev, dir_stats.st_ino),)) + # Actually walk the directory hierarchy + _recurse(topdir, offset, parents, checksum_check=checksum_check) + + return r_files + + +def _get_local_checksum(get_checksum, local_path): + if get_checksum: + return checksum(local_path) + else: + return None + + +class ActionModule(ActionBase): + + WIN_PATH_SEPARATOR = "\\" + + def _create_content_tempfile(self, content): + ''' Create a tempfile containing defined content ''' + fd, content_tempfile = tempfile.mkstemp(dir=C.DEFAULT_LOCAL_TMP) + f = os.fdopen(fd, 'wb') + content = to_bytes(content) + try: + f.write(content) + except Exception as err: + os.remove(content_tempfile) + raise Exception(err) + finally: + f.close() + return content_tempfile + + def _create_zip_tempfile(self, files, directories): + tmpdir = tempfile.mkdtemp(dir=C.DEFAULT_LOCAL_TMP) + zip_file_path = os.path.join(tmpdir, "win_copy.zip") + zip_file = zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_STORED, True) + + # encoding the file/dir name with base64 so Windows can unzip a unicode + # filename and get the right name, Windows doesn't handle unicode names + # very well + for directory in directories: + directory_path = to_bytes(directory['src'], errors='surrogate_or_strict') + archive_path = to_bytes(directory['dest'], errors='surrogate_or_strict') + + encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict') + zip_file.write(directory_path, encoded_path, zipfile.ZIP_DEFLATED) + + for file in files: + file_path = to_bytes(file['src'], errors='surrogate_or_strict') + archive_path = to_bytes(file['dest'], errors='surrogate_or_strict') + + encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict') + zip_file.write(file_path, encoded_path, zipfile.ZIP_DEFLATED) + + return zip_file_path + + def _remove_tempfile_if_content_defined(self, content, content_tempfile): + if content is not None: + os.remove(content_tempfile) + + def _copy_single_file(self, local_file, dest, source_rel, task_vars, tmp, backup): + if self._play_context.check_mode: + module_return = dict(changed=True) + return module_return + + # copy the file across to the server + tmp_src = self._connection._shell.join_path(tmp, 'source') + self._transfer_file(local_file, tmp_src) + + copy_args = self._task.args.copy() + copy_args.update( + dict( + dest=dest, + src=tmp_src, + _original_basename=source_rel, + _copy_mode="single", + backup=backup, + ) + ) + copy_args.pop('content', None) + + copy_result = self._execute_module(module_name="copy", + module_args=copy_args, + task_vars=task_vars) + + return copy_result + + def _copy_zip_file(self, dest, files, directories, task_vars, tmp, backup): + # create local zip file containing all the files and directories that + # need to be copied to the server + if self._play_context.check_mode: + module_return = dict(changed=True) + return module_return + + try: + zip_file = self._create_zip_tempfile(files, directories) + except Exception as e: + module_return = dict( + changed=False, + failed=True, + msg="failed to create tmp zip file: %s" % to_text(e), + exception=traceback.format_exc() + ) + return module_return + + zip_path = self._loader.get_real_file(zip_file) + + # send zip file to remote, file must end in .zip so + # Com Shell.Application works + tmp_src = self._connection._shell.join_path(tmp, 'source.zip') + self._transfer_file(zip_path, tmp_src) + + # run the explode operation of win_copy on remote + copy_args = self._task.args.copy() + copy_args.update( + dict( + src=tmp_src, + dest=dest, + _copy_mode="explode", + backup=backup, + ) + ) + copy_args.pop('content', None) + module_return = self._execute_module(module_name='copy', + module_args=copy_args, + task_vars=task_vars) + shutil.rmtree(os.path.dirname(zip_path)) + return module_return + + def run(self, tmp=None, task_vars=None): + ''' handler for file transfer operations ''' + if task_vars is None: + task_vars = dict() + + result = super(ActionModule, self).run(tmp, task_vars) + del tmp # tmp no longer has any effect + + source = self._task.args.get('src', None) + content = self._task.args.get('content', None) + dest = self._task.args.get('dest', None) + remote_src = boolean(self._task.args.get('remote_src', False), strict=False) + local_follow = boolean(self._task.args.get('local_follow', False), strict=False) + force = boolean(self._task.args.get('force', True), strict=False) + decrypt = boolean(self._task.args.get('decrypt', True), strict=False) + backup = boolean(self._task.args.get('backup', False), strict=False) + + result['src'] = source + result['dest'] = dest + + result['failed'] = True + if (source is None and content is None) or dest is None: + result['msg'] = "src (or content) and dest are required" + elif source is not None and content is not None: + result['msg'] = "src and content are mutually exclusive" + elif content is not None and dest is not None and ( + dest.endswith(os.path.sep) or dest.endswith(self.WIN_PATH_SEPARATOR)): + result['msg'] = "dest must be a file if content is defined" + else: + del result['failed'] + + if result.get('failed'): + return result + + # If content is defined make a temp file and write the content into it + content_tempfile = None + if content is not None: + try: + # if content comes to us as a dict it should be decoded json. + # We need to encode it back into a string and write it out + if isinstance(content, dict) or isinstance(content, list): + content_tempfile = self._create_content_tempfile(json.dumps(content)) + else: + content_tempfile = self._create_content_tempfile(content) + source = content_tempfile + except Exception as err: + result['failed'] = True + result['msg'] = "could not write content tmp file: %s" % to_native(err) + return result + # all actions should occur on the remote server, run win_copy module + elif remote_src: + new_module_args = self._task.args.copy() + new_module_args.update( + dict( + _copy_mode="remote", + dest=dest, + src=source, + force=force, + backup=backup, + ) + ) + new_module_args.pop('content', None) + result.update(self._execute_module(module_args=new_module_args, task_vars=task_vars)) + return result + # find_needle returns a path that may not have a trailing slash on a + # directory so we need to find that out first and append at the end + else: + trailing_slash = source.endswith(os.path.sep) + try: + # find in expected paths + source = self._find_needle('files', source) + except AnsibleError as e: + result['failed'] = True + result['msg'] = to_text(e) + result['exception'] = traceback.format_exc() + return result + + if trailing_slash != source.endswith(os.path.sep): + if source[-1] == os.path.sep: + source = source[:-1] + else: + source = source + os.path.sep + + # A list of source file tuples (full_path, relative_path) which will try to copy to the destination + source_files = {'files': [], 'directories': [], 'symlinks': []} + + # If source is a directory populate our list else source is a file and translate it to a tuple. + if os.path.isdir(to_bytes(source, errors='surrogate_or_strict')): + result['operation'] = 'folder_copy' + + # Get a list of the files we want to replicate on the remote side + source_files = _walk_dirs(source, self._loader, decrypt=decrypt, local_follow=local_follow, + trailing_slash_detector=self._connection._shell.path_has_trailing_slash, + checksum_check=force) + + # If it's recursive copy, destination is always a dir, + # explicitly mark it so (note - win_copy module relies on this). + if not self._connection._shell.path_has_trailing_slash(dest): + dest = "%s%s" % (dest, self.WIN_PATH_SEPARATOR) + + check_dest = dest + # Source is a file, add details to source_files dict + else: + result['operation'] = 'file_copy' + + # If the local file does not exist, get_real_file() raises AnsibleFileNotFound + try: + source_full = self._loader.get_real_file(source, decrypt=decrypt) + except AnsibleFileNotFound as e: + result['failed'] = True + result['msg'] = "could not find src=%s, %s" % (source_full, to_text(e)) + return result + + original_basename = os.path.basename(source) + result['original_basename'] = original_basename + + # check if dest ends with / or \ and append source filename to dest + if self._connection._shell.path_has_trailing_slash(dest): + check_dest = dest + filename = original_basename + result['dest'] = self._connection._shell.join_path(dest, filename) + else: + # replace \\ with / so we can use os.path to get the filename or dirname + unix_path = dest.replace(self.WIN_PATH_SEPARATOR, os.path.sep) + filename = os.path.basename(unix_path) + check_dest = os.path.dirname(unix_path) + + file_checksum = _get_local_checksum(force, source_full) + source_files['files'].append( + dict( + src=source_full, + dest=filename, + checksum=file_checksum + ) + ) + result['checksum'] = file_checksum + result['size'] = os.path.getsize(to_bytes(source_full, errors='surrogate_or_strict')) + + # find out the files/directories/symlinks that we need to copy to the server + query_args = self._task.args.copy() + query_args.update( + dict( + _copy_mode="query", + dest=check_dest, + force=force, + files=source_files['files'], + directories=source_files['directories'], + symlinks=source_files['symlinks'], + ) + ) + # src is not required for query, will fail path validation is src has unix allowed chars + query_args.pop('src', None) + + query_args.pop('content', None) + query_return = self._execute_module(module_args=query_args, + task_vars=task_vars) + + if query_return.get('failed') is True: + result.update(query_return) + return result + + if len(query_return['files']) > 0 or len(query_return['directories']) > 0 and self._connection._shell.tmpdir is None: + self._connection._shell.tmpdir = self._make_tmp_path() + + if len(query_return['files']) == 1 and len(query_return['directories']) == 0: + # we only need to copy 1 file, don't mess around with zips + file_src = query_return['files'][0]['src'] + file_dest = query_return['files'][0]['dest'] + result.update(self._copy_single_file(file_src, dest, file_dest, + task_vars, self._connection._shell.tmpdir, backup)) + if result.get('failed') is True: + result['msg'] = "failed to copy file %s: %s" % (file_src, result['msg']) + result['changed'] = True + + elif len(query_return['files']) > 0 or len(query_return['directories']) > 0: + # either multiple files or directories need to be copied, compress + # to a zip and 'explode' the zip on the server + # TODO: handle symlinks + result.update(self._copy_zip_file(dest, source_files['files'], + source_files['directories'], + task_vars, self._connection._shell.tmpdir, backup)) + result['changed'] = True + else: + # no operations need to occur + result['failed'] = False + result['changed'] = False + + # remove the content tmp file and remote tmp file if it was created + self._remove_tempfile_if_content_defined(content, content_tempfile) + self._remove_tmp_path(self._connection._shell.tmpdir) + return result diff --git a/test/support/windows-integration/plugins/action/win_reboot.py b/test/support/windows-integration/plugins/action/win_reboot.py new file mode 100644 index 0000000..c408f4f --- /dev/null +++ b/test/support/windows-integration/plugins/action/win_reboot.py @@ -0,0 +1,96 @@ +# Copyright: (c) 2018, Matt Davis <mdavis@ansible.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from datetime import datetime + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_native +from ansible.plugins.action import ActionBase +from ansible.plugins.action.reboot import ActionModule as RebootActionModule +from ansible.utils.display import Display + +display = Display() + + +class TimedOutException(Exception): + pass + + +class ActionModule(RebootActionModule, ActionBase): + TRANSFERS_FILES = False + _VALID_ARGS = frozenset(( + 'connect_timeout', 'connect_timeout_sec', 'msg', 'post_reboot_delay', 'post_reboot_delay_sec', 'pre_reboot_delay', 'pre_reboot_delay_sec', + 'reboot_timeout', 'reboot_timeout_sec', 'shutdown_timeout', 'shutdown_timeout_sec', 'test_command', + )) + + DEFAULT_BOOT_TIME_COMMAND = "(Get-WmiObject -ClassName Win32_OperatingSystem).LastBootUpTime" + DEFAULT_CONNECT_TIMEOUT = 5 + DEFAULT_PRE_REBOOT_DELAY = 2 + DEFAULT_SUDOABLE = False + DEFAULT_SHUTDOWN_COMMAND_ARGS = '/r /t {delay_sec} /c "{message}"' + + DEPRECATED_ARGS = { + 'shutdown_timeout': '2.5', + 'shutdown_timeout_sec': '2.5', + } + + def __init__(self, *args, **kwargs): + super(ActionModule, self).__init__(*args, **kwargs) + + def get_distribution(self, task_vars): + return {'name': 'windows', 'version': '', 'family': ''} + + def get_shutdown_command(self, task_vars, distribution): + return self.DEFAULT_SHUTDOWN_COMMAND + + def run_test_command(self, distribution, **kwargs): + # Need to wrap the test_command in our PowerShell encoded wrapper. This is done to align the command input to a + # common shell and to allow the psrp connection plugin to report the correct exit code without manually setting + # $LASTEXITCODE for just that plugin. + test_command = self._task.args.get('test_command', self.DEFAULT_TEST_COMMAND) + kwargs['test_command'] = self._connection._shell._encode_script(test_command) + super(ActionModule, self).run_test_command(distribution, **kwargs) + + def perform_reboot(self, task_vars, distribution): + shutdown_command = self.get_shutdown_command(task_vars, distribution) + shutdown_command_args = self.get_shutdown_command_args(distribution) + reboot_command = self._connection._shell._encode_script('{0} {1}'.format(shutdown_command, shutdown_command_args)) + + display.vvv("{action}: rebooting server...".format(action=self._task.action)) + display.debug("{action}: distribution: {dist}".format(action=self._task.action, dist=distribution)) + display.debug("{action}: rebooting server with command '{command}'".format(action=self._task.action, command=reboot_command)) + + result = {} + reboot_result = self._low_level_execute_command(reboot_command, sudoable=self.DEFAULT_SUDOABLE) + result['start'] = datetime.utcnow() + + # Test for "A system shutdown has already been scheduled. (1190)" and handle it gracefully + stdout = reboot_result['stdout'] + stderr = reboot_result['stderr'] + if reboot_result['rc'] == 1190 or (reboot_result['rc'] != 0 and "(1190)" in reboot_result['stderr']): + display.warning('A scheduled reboot was pre-empted by Ansible.') + + # Try to abort (this may fail if it was already aborted) + result1 = self._low_level_execute_command(self._connection._shell._encode_script('shutdown /a'), + sudoable=self.DEFAULT_SUDOABLE) + + # Initiate reboot again + result2 = self._low_level_execute_command(reboot_command, sudoable=self.DEFAULT_SUDOABLE) + + reboot_result['rc'] = result2['rc'] + stdout += result1['stdout'] + result2['stdout'] + stderr += result1['stderr'] + result2['stderr'] + + if reboot_result['rc'] != 0: + result['failed'] = True + result['rebooted'] = False + result['msg'] = "Reboot command failed, error was: {stdout} {stderr}".format( + stdout=to_native(stdout.strip()), + stderr=to_native(stderr.strip())) + return result + + result['failed'] = False + return result diff --git a/test/support/windows-integration/plugins/action/win_template.py b/test/support/windows-integration/plugins/action/win_template.py new file mode 100644 index 0000000..20494b9 --- /dev/null +++ b/test/support/windows-integration/plugins/action/win_template.py @@ -0,0 +1,29 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.plugins.action import ActionBase +from ansible.plugins.action.template import ActionModule as TemplateActionModule + + +# Even though TemplateActionModule inherits from ActionBase, we still need to +# directly inherit from ActionBase to appease the plugin loader. +class ActionModule(TemplateActionModule, ActionBase): + DEFAULT_NEWLINE_SEQUENCE = '\r\n' diff --git a/test/support/windows-integration/plugins/become/runas.py b/test/support/windows-integration/plugins/become/runas.py new file mode 100644 index 0000000..c8ae881 --- /dev/null +++ b/test/support/windows-integration/plugins/become/runas.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = """ + become: runas + short_description: Run As user + description: + - This become plugins allows your remote/login user to execute commands as another user via the windows runas facility. + author: ansible (@core) + version_added: "2.8" + options: + become_user: + description: User you 'become' to execute the task + ini: + - section: privilege_escalation + key: become_user + - section: runas_become_plugin + key: user + vars: + - name: ansible_become_user + - name: ansible_runas_user + env: + - name: ANSIBLE_BECOME_USER + - name: ANSIBLE_RUNAS_USER + required: True + become_flags: + description: Options to pass to runas, a space delimited list of k=v pairs + default: '' + ini: + - section: privilege_escalation + key: become_flags + - section: runas_become_plugin + key: flags + vars: + - name: ansible_become_flags + - name: ansible_runas_flags + env: + - name: ANSIBLE_BECOME_FLAGS + - name: ANSIBLE_RUNAS_FLAGS + become_pass: + description: password + ini: + - section: runas_become_plugin + key: password + vars: + - name: ansible_become_password + - name: ansible_become_pass + - name: ansible_runas_pass + env: + - name: ANSIBLE_BECOME_PASS + - name: ANSIBLE_RUNAS_PASS + notes: + - runas is really implemented in the powershell module handler and as such can only be used with winrm connections. + - This plugin ignores the 'become_exe' setting as it uses an API and not an executable. + - The Secondary Logon service (seclogon) must be running to use runas +""" + +from ansible.plugins.become import BecomeBase + + +class BecomeModule(BecomeBase): + + name = 'runas' + + def build_become_command(self, cmd, shell): + # runas is implemented inside the winrm connection plugin + return cmd diff --git a/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs b/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs new file mode 100644 index 0000000..be0f3db --- /dev/null +++ b/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs @@ -0,0 +1,1341 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; +using Ansible.Privilege; + +namespace Ansible.Service +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct ENUM_SERVICE_STATUSW + { + public string lpServiceName; + public string lpDisplayName; + public SERVICE_STATUS ServiceStatus; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct QUERY_SERVICE_CONFIGW + { + public ServiceType dwServiceType; + public ServiceStartType dwStartType; + public ErrorControl dwErrorControl; + [MarshalAs(UnmanagedType.LPWStr)] public string lpBinaryPathName; + [MarshalAs(UnmanagedType.LPWStr)] public string lpLoadOrderGroup; + public Int32 dwTagId; + public IntPtr lpDependencies; // Can't rely on marshaling as dependencies are delimited by \0. + [MarshalAs(UnmanagedType.LPWStr)] public string lpServiceStartName; + [MarshalAs(UnmanagedType.LPWStr)] public string lpDisplayName; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SC_ACTION + { + public FailureAction Type; + public UInt32 Delay; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_DELAYED_AUTO_START_INFO + { + public bool fDelayedAutostart; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct SERVICE_DESCRIPTIONW + { + [MarshalAs(UnmanagedType.LPWStr)] public string lpDescription; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_FAILURE_ACTIONS_FLAG + { + public bool fFailureActionsOnNonCrashFailures; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct SERVICE_FAILURE_ACTIONSW + { + public UInt32 dwResetPeriod; + [MarshalAs(UnmanagedType.LPWStr)] public string lpRebootMsg; + [MarshalAs(UnmanagedType.LPWStr)] public string lpCommand; + public UInt32 cActions; + public IntPtr lpsaActions; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_LAUNCH_PROTECTED_INFO + { + public LaunchProtection dwLaunchProtected; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_PREFERRED_NODE_INFO + { + public UInt16 usPreferredNode; + public bool fDelete; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_PRESHUTDOWN_INFO + { + public UInt32 dwPreshutdownTimeout; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct SERVICE_REQUIRED_PRIVILEGES_INFOW + { + // Can't rely on marshaling as privileges are delimited by \0. + public IntPtr pmszRequiredPrivileges; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_SID_INFO + { + public ServiceSidInfo dwServiceSidType; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_STATUS + { + public ServiceType dwServiceType; + public ServiceStatus dwCurrentState; + public ControlsAccepted dwControlsAccepted; + public UInt32 dwWin32ExitCode; + public UInt32 dwServiceSpecificExitCode; + public UInt32 dwCheckPoint; + public UInt32 dwWaitHint; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_STATUS_PROCESS + { + public ServiceType dwServiceType; + public ServiceStatus dwCurrentState; + public ControlsAccepted dwControlsAccepted; + public UInt32 dwWin32ExitCode; + public UInt32 dwServiceSpecificExitCode; + public UInt32 dwCheckPoint; + public UInt32 dwWaitHint; + public UInt32 dwProcessId; + public ServiceFlags dwServiceFlags; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_TRIGGER + { + public TriggerType dwTriggerType; + public TriggerAction dwAction; + public IntPtr pTriggerSubtype; + public UInt32 cDataItems; + public IntPtr pDataItems; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_TRIGGER_SPECIFIC_DATA_ITEM + { + public TriggerDataType dwDataType; + public UInt32 cbData; + public IntPtr pData; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_TRIGGER_INFO + { + public UInt32 cTriggers; + public IntPtr pTriggers; + public IntPtr pReserved; + } + + public enum ConfigInfoLevel : uint + { + SERVICE_CONFIG_DESCRIPTION = 0x00000001, + SERVICE_CONFIG_FAILURE_ACTIONS = 0x00000002, + SERVICE_CONFIG_DELAYED_AUTO_START_INFO = 0x00000003, + SERVICE_CONFIG_FAILURE_ACTIONS_FLAG = 0x00000004, + SERVICE_CONFIG_SERVICE_SID_INFO = 0x00000005, + SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO = 0x00000006, + SERVICE_CONFIG_PRESHUTDOWN_INFO = 0x00000007, + SERVICE_CONFIG_TRIGGER_INFO = 0x00000008, + SERVICE_CONFIG_PREFERRED_NODE = 0x00000009, + SERVICE_CONFIG_LAUNCH_PROTECTED = 0x0000000c, + } + } + + internal class NativeMethods + { + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool ChangeServiceConfigW( + SafeHandle hService, + ServiceType dwServiceType, + ServiceStartType dwStartType, + ErrorControl dwErrorControl, + string lpBinaryPathName, + string lpLoadOrderGroup, + IntPtr lpdwTagId, + string lpDependencies, + string lpServiceStartName, + string lpPassword, + string lpDisplayName); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool ChangeServiceConfig2W( + SafeHandle hService, + NativeHelpers.ConfigInfoLevel dwInfoLevel, + IntPtr lpInfo); + + [DllImport("Advapi32.dll", SetLastError = true)] + public static extern bool CloseServiceHandle( + IntPtr hSCObject); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern SafeServiceHandle CreateServiceW( + SafeHandle hSCManager, + string lpServiceName, + string lpDisplayName, + ServiceRights dwDesiredAccess, + ServiceType dwServiceType, + ServiceStartType dwStartType, + ErrorControl dwErrorControl, + string lpBinaryPathName, + string lpLoadOrderGroup, + IntPtr lpdwTagId, + string lpDependencies, + string lpServiceStartName, + string lpPassword); + + [DllImport("Advapi32.dll", SetLastError = true)] + public static extern bool DeleteService( + SafeHandle hService); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool EnumDependentServicesW( + SafeHandle hService, + UInt32 dwServiceState, + SafeMemoryBuffer lpServices, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded, + out UInt32 lpServicesReturned); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern SafeServiceHandle OpenSCManagerW( + string lpMachineName, + string lpDatabaseNmae, + SCMRights dwDesiredAccess); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern SafeServiceHandle OpenServiceW( + SafeHandle hSCManager, + string lpServiceName, + ServiceRights dwDesiredAccess); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool QueryServiceConfigW( + SafeHandle hService, + IntPtr lpServiceConfig, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool QueryServiceConfig2W( + SafeHandle hservice, + NativeHelpers.ConfigInfoLevel dwInfoLevel, + IntPtr lpBuffer, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded); + + [DllImport("Advapi32.dll", SetLastError = true)] + public static extern bool QueryServiceStatusEx( + SafeHandle hService, + UInt32 InfoLevel, + IntPtr lpBuffer, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public UInt32 BufferLength { get; internal set; } + + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + BufferLength = (UInt32)cb; + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + internal class SafeServiceHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeServiceHandle() : base(true) { } + public SafeServiceHandle(IntPtr handle) : base(true) { this.handle = handle; } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return NativeMethods.CloseServiceHandle(handle); + } + } + + [Flags] + public enum ControlsAccepted : uint + { + None = 0x00000000, + Stop = 0x00000001, + PauseContinue = 0x00000002, + Shutdown = 0x00000004, + ParamChange = 0x00000008, + NetbindChange = 0x00000010, + HardwareProfileChange = 0x00000020, + PowerEvent = 0x00000040, + SessionChange = 0x00000080, + PreShutdown = 0x00000100, + } + + public enum ErrorControl : uint + { + Ignore = 0x00000000, + Normal = 0x00000001, + Severe = 0x00000002, + Critical = 0x00000003, + } + + public enum FailureAction : uint + { + None = 0x00000000, + Restart = 0x00000001, + Reboot = 0x00000002, + RunCommand = 0x00000003, + } + + public enum LaunchProtection : uint + { + None = 0, + Windows = 1, + WindowsLight = 2, + AntimalwareLight = 3, + } + + [Flags] + public enum SCMRights : uint + { + Connect = 0x00000001, + CreateService = 0x00000002, + EnumerateService = 0x00000004, + Lock = 0x00000008, + QueryLockStatus = 0x00000010, + ModifyBootConfig = 0x00000020, + AllAccess = 0x000F003F, + } + + [Flags] + public enum ServiceFlags : uint + { + None = 0x0000000, + RunsInSystemProcess = 0x00000001, + } + + [Flags] + public enum ServiceRights : uint + { + QueryConfig = 0x00000001, + ChangeConfig = 0x00000002, + QueryStatus = 0x00000004, + EnumerateDependents = 0x00000008, + Start = 0x00000010, + Stop = 0x00000020, + PauseContinue = 0x00000040, + Interrogate = 0x00000080, + UserDefinedControl = 0x00000100, + Delete = 0x00010000, + ReadControl = 0x00020000, + WriteDac = 0x00040000, + WriteOwner = 0x00080000, + AllAccess = 0x000F01FF, + AccessSystemSecurity = 0x01000000, + } + + public enum ServiceStartType : uint + { + BootStart = 0x00000000, + SystemStart = 0x00000001, + AutoStart = 0x00000002, + DemandStart = 0x00000003, + Disabled = 0x00000004, + + // Not part of ChangeServiceConfig enumeration but built by the Srvice class for the StartType property. + AutoStartDelayed = 0x1000000 + } + + [Flags] + public enum ServiceType : uint + { + KernelDriver = 0x00000001, + FileSystemDriver = 0x00000002, + Adapter = 0x00000004, + RecognizerDriver = 0x00000008, + Driver = KernelDriver | FileSystemDriver | RecognizerDriver, + Win32OwnProcess = 0x00000010, + Win32ShareProcess = 0x00000020, + Win32 = Win32OwnProcess | Win32ShareProcess, + UserProcess = 0x00000040, + UserOwnprocess = Win32OwnProcess | UserProcess, + UserShareProcess = Win32ShareProcess | UserProcess, + UserServiceInstance = 0x00000080, + InteractiveProcess = 0x00000100, + PkgService = 0x00000200, + } + + public enum ServiceSidInfo : uint + { + None, + Unrestricted, + Restricted = 3, + } + + public enum ServiceStatus : uint + { + Stopped = 0x00000001, + StartPending = 0x00000002, + StopPending = 0x00000003, + Running = 0x00000004, + ContinuePending = 0x00000005, + PausePending = 0x00000006, + Paused = 0x00000007, + } + + public enum TriggerAction : uint + { + ServiceStart = 0x00000001, + ServiceStop = 0x000000002, + } + + public enum TriggerDataType : uint + { + Binary = 00000001, + String = 0x00000002, + Level = 0x00000003, + KeywordAny = 0x00000004, + KeywordAll = 0x00000005, + } + + public enum TriggerType : uint + { + DeviceInterfaceArrival = 0x00000001, + IpAddressAvailability = 0x00000002, + DomainJoin = 0x00000003, + FirewallPortEvent = 0x00000004, + GroupPolicy = 0x00000005, + NetworkEndpoint = 0x00000006, + Custom = 0x00000014, + } + + public class ServiceManagerException : System.ComponentModel.Win32Exception + { + private string _msg; + + public ServiceManagerException(string message) : this(Marshal.GetLastWin32Error(), message) { } + public ServiceManagerException(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2} - 0x{2:X8})", message, base.Message, errorCode); + } + + public override string Message { get { return _msg; } } + public static explicit operator ServiceManagerException(string message) + { + return new ServiceManagerException(message); + } + } + + public class Action + { + public FailureAction Type; + public UInt32 Delay; + } + + public class FailureActions + { + public UInt32? ResetPeriod = null; // Get is always populated, can be null on set to preserve existing. + public string RebootMsg = null; + public string Command = null; + public List<Action> Actions = null; + + public FailureActions() { } + + internal FailureActions(NativeHelpers.SERVICE_FAILURE_ACTIONSW actions) + { + ResetPeriod = actions.dwResetPeriod; + RebootMsg = actions.lpRebootMsg; + Command = actions.lpCommand; + Actions = new List<Action>(); + + int actionLength = Marshal.SizeOf(typeof(NativeHelpers.SC_ACTION)); + for (int i = 0; i < actions.cActions; i++) + { + IntPtr actionPtr = IntPtr.Add(actions.lpsaActions, i * actionLength); + + NativeHelpers.SC_ACTION rawAction = (NativeHelpers.SC_ACTION)Marshal.PtrToStructure( + actionPtr, typeof(NativeHelpers.SC_ACTION)); + + Actions.Add(new Action() + { + Type = rawAction.Type, + Delay = rawAction.Delay, + }); + } + } + } + + public class TriggerItem + { + public TriggerDataType Type; + public object Data; // Can be string, List<string>, byte, byte[], or Int64 depending on Type. + + public TriggerItem() { } + + internal TriggerItem(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM dataItem) + { + Type = dataItem.dwDataType; + + byte[] itemBytes = new byte[dataItem.cbData]; + Marshal.Copy(dataItem.pData, itemBytes, 0, itemBytes.Length); + + switch (dataItem.dwDataType) + { + case TriggerDataType.String: + string value = Encoding.Unicode.GetString(itemBytes, 0, itemBytes.Length); + + if (value.EndsWith("\0\0")) + { + // Multistring with a delimiter of \0 and terminated with \0\0. + Data = new List<string>(value.Split(new char[1] { '\0' }, StringSplitOptions.RemoveEmptyEntries)); + } + else + // Just a single string with null character at the end, strip it off. + Data = value.Substring(0, value.Length - 1); + break; + case TriggerDataType.Level: + Data = itemBytes[0]; + break; + case TriggerDataType.KeywordAll: + case TriggerDataType.KeywordAny: + Data = BitConverter.ToUInt64(itemBytes, 0); + break; + default: + Data = itemBytes; + break; + } + } + } + + public class Trigger + { + // https://docs.microsoft.com/en-us/windows/win32/api/winsvc/ns-winsvc-service_trigger + public const string NAMED_PIPE_EVENT_GUID = "1f81d131-3fac-4537-9e0c-7e7b0c2f4b55"; + public const string RPC_INTERFACE_EVENT_GUID = "bc90d167-9470-4139-a9ba-be0bbbf5b74d"; + public const string DOMAIN_JOIN_GUID = "1ce20aba-9851-4421-9430-1ddeb766e809"; + public const string DOMAIN_LEAVE_GUID = "ddaf516e-58c2-4866-9574-c3b615d42ea1"; + public const string FIREWALL_PORT_OPEN_GUID = "b7569e07-8421-4ee0-ad10-86915afdad09"; + public const string FIREWALL_PORT_CLOSE_GUID = "a144ed38-8e12-4de4-9d96-e64740b1a524"; + public const string MACHINE_POLICY_PRESENT_GUID = "659fcae6-5bdb-4da9-b1ff-ca2a178d46e0"; + public const string NETWORK_MANAGER_FIRST_IP_ADDRESS_ARRIVAL_GUID = "4f27f2de-14e2-430b-a549-7cd48cbc8245"; + public const string NETWORK_MANAGER_LAST_IP_ADDRESS_REMOVAL_GUID = "cc4ba62a-162e-4648-847a-b6bdf993e335"; + public const string USER_POLICY_PRESENT_GUID = "54fb46c8-f089-464c-b1fd-59d1b62c3b50"; + + public TriggerType Type; + public TriggerAction Action; + public Guid SubType; + public List<TriggerItem> DataItems = new List<TriggerItem>(); + + public Trigger() { } + + internal Trigger(NativeHelpers.SERVICE_TRIGGER trigger) + { + Type = trigger.dwTriggerType; + Action = trigger.dwAction; + SubType = (Guid)Marshal.PtrToStructure(trigger.pTriggerSubtype, typeof(Guid)); + + int dataItemLength = Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)); + for (int i = 0; i < trigger.cDataItems; i++) + { + IntPtr dataPtr = IntPtr.Add(trigger.pDataItems, i * dataItemLength); + + var dataItem = (NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)Marshal.PtrToStructure( + dataPtr, typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)); + + DataItems.Add(new TriggerItem(dataItem)); + } + } + } + + public class Service : IDisposable + { + private const UInt32 SERVICE_NO_CHANGE = 0xFFFFFFFF; + + private SafeServiceHandle _scmHandle; + private SafeServiceHandle _serviceHandle; + private SafeMemoryBuffer _rawServiceConfig; + private NativeHelpers.SERVICE_STATUS_PROCESS _statusProcess; + + private NativeHelpers.QUERY_SERVICE_CONFIGW _ServiceConfig + { + get + { + return (NativeHelpers.QUERY_SERVICE_CONFIGW)Marshal.PtrToStructure( + _rawServiceConfig.DangerousGetHandle(), typeof(NativeHelpers.QUERY_SERVICE_CONFIGW)); + } + } + + // ServiceConfig + public string ServiceName { get; private set; } + + public ServiceType ServiceType + { + get { return _ServiceConfig.dwServiceType; } + set { ChangeServiceConfig(serviceType: value); } + } + + public ServiceStartType StartType + { + get + { + ServiceStartType startType = _ServiceConfig.dwStartType; + if (startType == ServiceStartType.AutoStart) + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_DELAYED_AUTO_START_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DELAYED_AUTO_START_INFO); + + if (value.fDelayedAutostart) + startType = ServiceStartType.AutoStartDelayed; + } + + return startType; + } + set + { + ServiceStartType newStartType = value; + bool delayedStart = false; + if (value == ServiceStartType.AutoStartDelayed) + { + newStartType = ServiceStartType.AutoStart; + delayedStart = true; + } + + ChangeServiceConfig(startType: newStartType); + + var info = new NativeHelpers.SERVICE_DELAYED_AUTO_START_INFO() + { + fDelayedAutostart = delayedStart, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DELAYED_AUTO_START_INFO, info); + } + } + + public ErrorControl ErrorControl + { + get { return _ServiceConfig.dwErrorControl; } + set { ChangeServiceConfig(errorControl: value); } + } + + public string Path + { + get { return _ServiceConfig.lpBinaryPathName; } + set { ChangeServiceConfig(binaryPath: value); } + } + + public string LoadOrderGroup + { + get { return _ServiceConfig.lpLoadOrderGroup; } + set { ChangeServiceConfig(loadOrderGroup: value); } + } + + public List<string> DependentOn + { + get + { + StringBuilder deps = new StringBuilder(); + IntPtr depPtr = _ServiceConfig.lpDependencies; + + bool wasNull = false; + while (true) + { + // Get the current char at the pointer and add it to the StringBuilder. + byte[] charBytes = new byte[sizeof(char)]; + Marshal.Copy(depPtr, charBytes, 0, charBytes.Length); + depPtr = IntPtr.Add(depPtr, charBytes.Length); + char currentChar = BitConverter.ToChar(charBytes, 0); + deps.Append(currentChar); + + // If the previous and current char is \0 exit the loop. + if (currentChar == '\0' && wasNull) + break; + wasNull = currentChar == '\0'; + } + + return new List<string>(deps.ToString().Split(new char[1] { '\0' }, + StringSplitOptions.RemoveEmptyEntries)); + } + set { ChangeServiceConfig(dependencies: value); } + } + + public IdentityReference Account + { + get + { + if (_ServiceConfig.lpServiceStartName == null) + // User services don't have the start name specified and will be null. + return null; + else if (_ServiceConfig.lpServiceStartName == "LocalSystem") + // Special string used for the SYSTEM account, this is the same even for different localisations. + return (NTAccount)new SecurityIdentifier("S-1-5-18").Translate(typeof(NTAccount)); + else + return new NTAccount(_ServiceConfig.lpServiceStartName); + } + set + { + string startName = null; + string pass = null; + + if (value != null) + { + // Create a SID and convert back from a SID to get the Netlogon form regardless of the input + // specified. + SecurityIdentifier accountSid = (SecurityIdentifier)value.Translate(typeof(SecurityIdentifier)); + NTAccount accountName = (NTAccount)accountSid.Translate(typeof(NTAccount)); + string[] accountSplit = accountName.Value.Split(new char[1] { '\\' }, 2); + + // SYSTEM, Local Service, Network Service + List<string> serviceAccounts = new List<string> { "S-1-5-18", "S-1-5-19", "S-1-5-20" }; + + // Well known service accounts and MSAs should have no password set. Explicitly blank out the + // existing password to ensure older passwords are no longer stored by Windows. + if (serviceAccounts.Contains(accountSid.Value) || accountSplit[1].EndsWith("$")) + pass = ""; + + // The SYSTEM account uses this special string to specify that account otherwise use the original + // NTAccount value in case it is in a custom format (not Netlogon) for a reason. + if (accountSid.Value == serviceAccounts[0]) + startName = "LocalSystem"; + else + startName = value.Translate(typeof(NTAccount)).Value; + } + + ChangeServiceConfig(startName: startName, password: pass); + } + } + + public string Password { set { ChangeServiceConfig(password: value); } } + + public string DisplayName + { + get { return _ServiceConfig.lpDisplayName; } + set { ChangeServiceConfig(displayName: value); } + } + + // ServiceConfig2 + + public string Description + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_DESCRIPTIONW>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DESCRIPTION); + + return value.lpDescription; + } + set + { + var info = new NativeHelpers.SERVICE_DESCRIPTIONW() + { + lpDescription = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DESCRIPTION, info); + } + } + + public FailureActions FailureActions + { + get + { + using (SafeMemoryBuffer b = QueryServiceConfig2( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS)) + { + NativeHelpers.SERVICE_FAILURE_ACTIONSW value = (NativeHelpers.SERVICE_FAILURE_ACTIONSW) + Marshal.PtrToStructure(b.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_FAILURE_ACTIONSW)); + + return new FailureActions(value); + } + } + set + { + // dwResetPeriod and lpsaActions must be set together, we need to read the existing config if someone + // wants to update 1 or the other but both aren't explicitly defined. + UInt32? resetPeriod = value.ResetPeriod; + List<Action> actions = value.Actions; + if ((resetPeriod != null && actions == null) || (resetPeriod == null && actions != null)) + { + FailureActions existingValue = this.FailureActions; + + if (resetPeriod != null && existingValue.Actions.Count == 0) + throw new ArgumentException( + "Cannot set FailureAction ResetPeriod without explicit Actions and no existing Actions"); + else if (resetPeriod == null) + resetPeriod = (UInt32)existingValue.ResetPeriod; + + if (actions == null) + actions = existingValue.Actions; + } + + var info = new NativeHelpers.SERVICE_FAILURE_ACTIONSW() + { + dwResetPeriod = resetPeriod == null ? 0 : (UInt32)resetPeriod, + lpRebootMsg = value.RebootMsg, + lpCommand = value.Command, + cActions = actions == null ? 0 : (UInt32)actions.Count, + lpsaActions = IntPtr.Zero, + }; + + // null means to keep the existing actions whereas an empty list deletes the actions. + if (actions == null) + { + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS, info); + return; + } + + int actionLength = Marshal.SizeOf(typeof(NativeHelpers.SC_ACTION)); + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(actionLength * actions.Count)) + { + info.lpsaActions = buffer.DangerousGetHandle(); + HashSet<string> privileges = new HashSet<string>(); + + for (int i = 0; i < actions.Count; i++) + { + IntPtr actionPtr = IntPtr.Add(info.lpsaActions, i * actionLength); + NativeHelpers.SC_ACTION action = new NativeHelpers.SC_ACTION() + { + Delay = actions[i].Delay, + Type = actions[i].Type, + }; + Marshal.StructureToPtr(action, actionPtr, false); + + // Need to make sure the SeShutdownPrivilege is enabled when adding a reboot failure action. + if (action.Type == FailureAction.Reboot) + privileges.Add("SeShutdownPrivilege"); + } + + using (new PrivilegeEnabler(true, privileges.ToList().ToArray())) + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS, info); + } + } + } + + public bool FailureActionsOnNonCrashFailures + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_FAILURE_ACTIONS_FLAG>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG); + + return value.fFailureActionsOnNonCrashFailures; + } + set + { + var info = new NativeHelpers.SERVICE_FAILURE_ACTIONS_FLAG() + { + fFailureActionsOnNonCrashFailures = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG, info); + } + } + + public ServiceSidInfo ServiceSidInfo + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_SID_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_SERVICE_SID_INFO); + + return value.dwServiceSidType; + } + set + { + var info = new NativeHelpers.SERVICE_SID_INFO() + { + dwServiceSidType = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_SERVICE_SID_INFO, info); + } + } + + public List<string> RequiredPrivileges + { + get + { + using (SafeMemoryBuffer buffer = QueryServiceConfig2( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO)) + { + var value = (NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW)Marshal.PtrToStructure( + buffer.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW)); + + int structLength = Marshal.SizeOf(value); + int stringLength = ((int)buffer.BufferLength - structLength) / sizeof(char); + + if (stringLength > 0) + { + string privilegesString = Marshal.PtrToStringUni(value.pmszRequiredPrivileges, stringLength); + return new List<string>(privilegesString.Split(new char[1] { '\0' }, + StringSplitOptions.RemoveEmptyEntries)); + } + else + return new List<string>(); + } + } + set + { + string privilegeString = String.Join("\0", value ?? new List<string>()) + "\0\0"; + + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(Marshal.StringToHGlobalUni(privilegeString))) + { + var info = new NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW() + { + pmszRequiredPrivileges = buffer.DangerousGetHandle(), + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO, info); + } + } + } + + public UInt32 PreShutdownTimeout + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_PRESHUTDOWN_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PRESHUTDOWN_INFO); + + return value.dwPreshutdownTimeout; + } + set + { + var info = new NativeHelpers.SERVICE_PRESHUTDOWN_INFO() + { + dwPreshutdownTimeout = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PRESHUTDOWN_INFO, info); + } + } + + public List<Trigger> Triggers + { + get + { + List<Trigger> triggers = new List<Trigger>(); + + using (SafeMemoryBuffer b = QueryServiceConfig2( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO)) + { + var value = (NativeHelpers.SERVICE_TRIGGER_INFO)Marshal.PtrToStructure( + b.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_TRIGGER_INFO)); + + int triggerLength = Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER)); + for (int i = 0; i < value.cTriggers; i++) + { + IntPtr triggerPtr = IntPtr.Add(value.pTriggers, i * triggerLength); + var trigger = (NativeHelpers.SERVICE_TRIGGER)Marshal.PtrToStructure(triggerPtr, + typeof(NativeHelpers.SERVICE_TRIGGER)); + + triggers.Add(new Trigger(trigger)); + } + } + + return triggers; + } + set + { + var info = new NativeHelpers.SERVICE_TRIGGER_INFO() + { + cTriggers = value == null ? 0 : (UInt32)value.Count, + pTriggers = IntPtr.Zero, + pReserved = IntPtr.Zero, + }; + + if (info.cTriggers == 0) + { + try + { + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO, info); + } + catch (ServiceManagerException e) + { + // Can fail with ERROR_INVALID_PARAMETER if no triggers were already set on the service, just + // continue as the service is what we want it to be. + if (e.NativeErrorCode != 87) + throw; + } + return; + } + + // Due to the dynamic nature of the trigger structure(s) we need to manually calculate the size of the + // data items on each trigger if present. This also serializes the raw data items to bytes here. + int structDataLength = 0; + int dataLength = 0; + Queue<byte[]> dataItems = new Queue<byte[]>(); + foreach (Trigger trigger in value) + { + if (trigger.DataItems == null || trigger.DataItems.Count == 0) + continue; + + foreach (TriggerItem dataItem in trigger.DataItems) + { + structDataLength += Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)); + + byte[] dataItemBytes; + Type dataItemType = dataItem.Data.GetType(); + if (dataItemType == typeof(byte)) + dataItemBytes = new byte[1] { (byte)dataItem.Data }; + else if (dataItemType == typeof(byte[])) + dataItemBytes = (byte[])dataItem.Data; + else if (dataItemType == typeof(UInt64)) + dataItemBytes = BitConverter.GetBytes((UInt64)dataItem.Data); + else if (dataItemType == typeof(string)) + dataItemBytes = Encoding.Unicode.GetBytes((string)dataItem.Data + "\0"); + else if (dataItemType == typeof(List<string>)) + dataItemBytes = Encoding.Unicode.GetBytes( + String.Join("\0", (List<string>)dataItem.Data) + "\0"); + else + throw new ArgumentException(String.Format("Trigger data type '{0}' not a value type", + dataItemType.Name)); + + dataLength += dataItemBytes.Length; + dataItems.Enqueue(dataItemBytes); + } + } + + using (SafeMemoryBuffer triggerBuffer = new SafeMemoryBuffer( + value.Count * Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER)))) + using (SafeMemoryBuffer triggerGuidBuffer = new SafeMemoryBuffer( + value.Count * Marshal.SizeOf(typeof(Guid)))) + using (SafeMemoryBuffer dataItemBuffer = new SafeMemoryBuffer(structDataLength)) + using (SafeMemoryBuffer dataBuffer = new SafeMemoryBuffer(dataLength)) + { + info.pTriggers = triggerBuffer.DangerousGetHandle(); + + IntPtr triggerPtr = triggerBuffer.DangerousGetHandle(); + IntPtr guidPtr = triggerGuidBuffer.DangerousGetHandle(); + IntPtr dataItemPtr = dataItemBuffer.DangerousGetHandle(); + IntPtr dataPtr = dataBuffer.DangerousGetHandle(); + + foreach (Trigger trigger in value) + { + int dataCount = trigger.DataItems == null ? 0 : trigger.DataItems.Count; + var rawTrigger = new NativeHelpers.SERVICE_TRIGGER() + { + dwTriggerType = trigger.Type, + dwAction = trigger.Action, + pTriggerSubtype = guidPtr, + cDataItems = (UInt32)dataCount, + pDataItems = dataCount == 0 ? IntPtr.Zero : dataItemPtr, + }; + guidPtr = StructureToPtr(trigger.SubType, guidPtr); + + for (int i = 0; i < rawTrigger.cDataItems; i++) + { + byte[] dataItemBytes = dataItems.Dequeue(); + var rawTriggerData = new NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM() + { + dwDataType = trigger.DataItems[i].Type, + cbData = (UInt32)dataItemBytes.Length, + pData = dataPtr, + }; + Marshal.Copy(dataItemBytes, 0, dataPtr, dataItemBytes.Length); + dataPtr = IntPtr.Add(dataPtr, dataItemBytes.Length); + + dataItemPtr = StructureToPtr(rawTriggerData, dataItemPtr); + } + + triggerPtr = StructureToPtr(rawTrigger, triggerPtr); + } + + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO, info); + } + } + } + + public UInt16? PreferredNode + { + get + { + try + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_PREFERRED_NODE_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PREFERRED_NODE); + + return value.usPreferredNode; + } + catch (ServiceManagerException e) + { + // If host has no NUMA support this will fail with ERROR_INVALID_PARAMETER + if (e.NativeErrorCode == 0x00000057) // ERROR_INVALID_PARAMETER + return null; + + throw; + } + } + set + { + var info = new NativeHelpers.SERVICE_PREFERRED_NODE_INFO(); + if (value == null) + info.fDelete = true; + else + info.usPreferredNode = (UInt16)value; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PREFERRED_NODE, info); + } + } + + public LaunchProtection LaunchProtection + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_LAUNCH_PROTECTED_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_LAUNCH_PROTECTED); + + return value.dwLaunchProtected; + } + set + { + var info = new NativeHelpers.SERVICE_LAUNCH_PROTECTED_INFO() + { + dwLaunchProtected = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_LAUNCH_PROTECTED, info); + } + } + + // ServiceStatus + public ServiceStatus State { get { return _statusProcess.dwCurrentState; } } + + public ControlsAccepted ControlsAccepted { get { return _statusProcess.dwControlsAccepted; } } + + public UInt32 Win32ExitCode { get { return _statusProcess.dwWin32ExitCode; } } + + public UInt32 ServiceExitCode { get { return _statusProcess.dwServiceSpecificExitCode; } } + + public UInt32 Checkpoint { get { return _statusProcess.dwCheckPoint; } } + + public UInt32 WaitHint { get { return _statusProcess.dwWaitHint; } } + + public UInt32 ProcessId { get { return _statusProcess.dwProcessId; } } + + public ServiceFlags ServiceFlags { get { return _statusProcess.dwServiceFlags; } } + + public Service(string name) : this(name, ServiceRights.AllAccess) { } + + public Service(string name, ServiceRights access) : this(name, access, SCMRights.Connect) { } + + public Service(string name, ServiceRights access, SCMRights scmAccess) + { + ServiceName = name; + _scmHandle = OpenSCManager(scmAccess); + _serviceHandle = NativeMethods.OpenServiceW(_scmHandle, name, access); + if (_serviceHandle.IsInvalid) + throw new ServiceManagerException(String.Format("Failed to open service '{0}'", name)); + + Refresh(); + } + + private Service(SafeServiceHandle scmHandle, SafeServiceHandle serviceHandle, string name) + { + ServiceName = name; + _scmHandle = scmHandle; + _serviceHandle = serviceHandle; + + Refresh(); + } + + // EnumDependentServices + public List<string> DependedBy + { + get + { + UInt32 bytesNeeded = 0; + UInt32 numServices = 0; + NativeMethods.EnumDependentServicesW(_serviceHandle, 3, new SafeMemoryBuffer(IntPtr.Zero), 0, + out bytesNeeded, out numServices); + + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded)) + { + if (!NativeMethods.EnumDependentServicesW(_serviceHandle, 3, buffer, bytesNeeded, out bytesNeeded, + out numServices)) + { + throw new ServiceManagerException("Failed to enumerated dependent services"); + } + + List<string> dependents = new List<string>(); + Type enumType = typeof(NativeHelpers.ENUM_SERVICE_STATUSW); + for (int i = 0; i < numServices; i++) + { + var service = (NativeHelpers.ENUM_SERVICE_STATUSW)Marshal.PtrToStructure( + IntPtr.Add(buffer.DangerousGetHandle(), i * Marshal.SizeOf(enumType)), enumType); + + dependents.Add(service.lpServiceName); + } + + return dependents; + } + } + } + + public static Service Create(string name, string binaryPath, string displayName = null, + ServiceType serviceType = ServiceType.Win32OwnProcess, + ServiceStartType startType = ServiceStartType.DemandStart, ErrorControl errorControl = ErrorControl.Normal, + string loadOrderGroup = null, List<string> dependencies = null, string startName = null, + string password = null) + { + SafeServiceHandle scmHandle = OpenSCManager(SCMRights.CreateService | SCMRights.Connect); + + if (displayName == null) + displayName = name; + + string depString = null; + if (dependencies != null && dependencies.Count > 0) + depString = String.Join("\0", dependencies) + "\0\0"; + + SafeServiceHandle serviceHandle = NativeMethods.CreateServiceW(scmHandle, name, displayName, + ServiceRights.AllAccess, serviceType, startType, errorControl, binaryPath, + loadOrderGroup, IntPtr.Zero, depString, startName, password); + + if (serviceHandle.IsInvalid) + throw new ServiceManagerException(String.Format("Failed to create new service '{0}'", name)); + + return new Service(scmHandle, serviceHandle, name); + } + + public void Delete() + { + if (!NativeMethods.DeleteService(_serviceHandle)) + throw new ServiceManagerException("Failed to delete service"); + Dispose(); + } + + public void Dispose() + { + if (_serviceHandle != null) + _serviceHandle.Dispose(); + + if (_scmHandle != null) + _scmHandle.Dispose(); + GC.SuppressFinalize(this); + } + + public void Refresh() + { + UInt32 bytesNeeded; + NativeMethods.QueryServiceConfigW(_serviceHandle, IntPtr.Zero, 0, out bytesNeeded); + + _rawServiceConfig = new SafeMemoryBuffer((int)bytesNeeded); + if (!NativeMethods.QueryServiceConfigW(_serviceHandle, _rawServiceConfig.DangerousGetHandle(), bytesNeeded, + out bytesNeeded)) + { + throw new ServiceManagerException("Failed to query service config"); + } + + NativeMethods.QueryServiceStatusEx(_serviceHandle, 0, IntPtr.Zero, 0, out bytesNeeded); + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded)) + { + if (!NativeMethods.QueryServiceStatusEx(_serviceHandle, 0, buffer.DangerousGetHandle(), bytesNeeded, + out bytesNeeded)) + { + throw new ServiceManagerException("Failed to query service status"); + } + + _statusProcess = (NativeHelpers.SERVICE_STATUS_PROCESS)Marshal.PtrToStructure( + buffer.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_STATUS_PROCESS)); + } + } + + private void ChangeServiceConfig(ServiceType serviceType = (ServiceType)SERVICE_NO_CHANGE, + ServiceStartType startType = (ServiceStartType)SERVICE_NO_CHANGE, + ErrorControl errorControl = (ErrorControl)SERVICE_NO_CHANGE, string binaryPath = null, + string loadOrderGroup = null, List<string> dependencies = null, string startName = null, + string password = null, string displayName = null) + { + string depString = null; + if (dependencies != null && dependencies.Count > 0) + depString = String.Join("\0", dependencies) + "\0\0"; + + if (!NativeMethods.ChangeServiceConfigW(_serviceHandle, serviceType, startType, errorControl, binaryPath, + loadOrderGroup, IntPtr.Zero, depString, startName, password, displayName)) + { + throw new ServiceManagerException("Failed to change service config"); + } + + Refresh(); + } + + private void ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel infoLevel, object info) + { + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(Marshal.SizeOf(info))) + { + Marshal.StructureToPtr(info, buffer.DangerousGetHandle(), false); + + if (!NativeMethods.ChangeServiceConfig2W(_serviceHandle, infoLevel, buffer.DangerousGetHandle())) + throw new ServiceManagerException("Failed to change service config"); + } + } + + private static SafeServiceHandle OpenSCManager(SCMRights desiredAccess) + { + SafeServiceHandle handle = NativeMethods.OpenSCManagerW(null, null, desiredAccess); + if (handle.IsInvalid) + throw new ServiceManagerException("Failed to open SCManager"); + + return handle; + } + + private T QueryServiceConfig2<T>(NativeHelpers.ConfigInfoLevel infoLevel) + { + using (SafeMemoryBuffer buffer = QueryServiceConfig2(infoLevel)) + return (T)Marshal.PtrToStructure(buffer.DangerousGetHandle(), typeof(T)); + } + + private SafeMemoryBuffer QueryServiceConfig2(NativeHelpers.ConfigInfoLevel infoLevel) + { + UInt32 bytesNeeded = 0; + NativeMethods.QueryServiceConfig2W(_serviceHandle, infoLevel, IntPtr.Zero, 0, out bytesNeeded); + + SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded); + if (!NativeMethods.QueryServiceConfig2W(_serviceHandle, infoLevel, buffer.DangerousGetHandle(), bytesNeeded, + out bytesNeeded)) + { + throw new ServiceManagerException(String.Format("QueryServiceConfig2W({0}) failed", + infoLevel.ToString())); + } + + return buffer; + } + + private static IntPtr StructureToPtr(object structure, IntPtr ptr) + { + Marshal.StructureToPtr(structure, ptr, false); + return IntPtr.Add(ptr, Marshal.SizeOf(structure)); + } + + ~Service() { Dispose(); } + } +} diff --git a/test/support/windows-integration/plugins/modules/async_status.ps1 b/test/support/windows-integration/plugins/modules/async_status.ps1 new file mode 100644 index 0000000..1ce3ff4 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/async_status.ps1 @@ -0,0 +1,58 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$results = @{changed=$false} + +$parsed_args = Parse-Args $args +$jid = Get-AnsibleParam $parsed_args "jid" -failifempty $true -resultobj $results +$mode = Get-AnsibleParam $parsed_args "mode" -Default "status" -ValidateSet "status","cleanup" + +# parsed in from the async_status action plugin +$async_dir = Get-AnsibleParam $parsed_args "_async_dir" -type "path" -failifempty $true + +$log_path = [System.IO.Path]::Combine($async_dir, $jid) + +If(-not $(Test-Path $log_path)) +{ + Fail-Json @{ansible_job_id=$jid; started=1; finished=1} "could not find job at '$async_dir'" +} + +If($mode -eq "cleanup") { + Remove-Item $log_path -Recurse + Exit-Json @{ansible_job_id=$jid; erased=$log_path} +} + +# NOT in cleanup mode, assume regular status mode +# no remote kill mode currently exists, but probably should +# consider log_path + ".pid" file and also unlink that above + +$data = $null +Try { + $data_raw = Get-Content $log_path + + # TODO: move this into module_utils/powershell.ps1? + $jss = New-Object System.Web.Script.Serialization.JavaScriptSerializer + $data = $jss.DeserializeObject($data_raw) +} +Catch { + If(-not $data_raw) { + # file not written yet? That means it is running + Exit-Json @{results_file=$log_path; ansible_job_id=$jid; started=1; finished=0} + } + Else { + Fail-Json @{ansible_job_id=$jid; results_file=$log_path; started=1; finished=1} "Could not parse job output: $data" + } +} + +If (-not $data.ContainsKey("started")) { + $data['finished'] = 1 + $data['ansible_job_id'] = $jid +} +ElseIf (-not $data.ContainsKey("finished")) { + $data['finished'] = 0 +} + +Exit-Json $data diff --git a/test/support/windows-integration/plugins/modules/setup.ps1 b/test/support/windows-integration/plugins/modules/setup.ps1 new file mode 100644 index 0000000..5064723 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/setup.ps1 @@ -0,0 +1,516 @@ +#!powershell + +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +Function Get-CustomFacts { + [cmdletBinding()] + param ( + [Parameter(mandatory=$false)] + $factpath = $null + ) + + if (Test-Path -Path $factpath) { + $FactsFiles = Get-ChildItem -Path $factpath | Where-Object -FilterScript {($PSItem.PSIsContainer -eq $false) -and ($PSItem.Extension -eq '.ps1')} + + foreach ($FactsFile in $FactsFiles) { + $out = & $($FactsFile.FullName) + $result.ansible_facts.Add("ansible_$(($FactsFile.Name).Split('.')[0])", $out) + } + } + else + { + Add-Warning $result "Non existing path was set for local facts - $factpath" + } +} + +Function Get-MachineSid { + # The Machine SID is stored in HKLM:\SECURITY\SAM\Domains\Account and is + # only accessible by the Local System account. This method get's the local + # admin account (ends with -500) and lops it off to get the machine sid. + + $machine_sid = $null + + try { + $admins_sid = "S-1-5-32-544" + $admin_group = ([Security.Principal.SecurityIdentifier]$admins_sid).Translate([Security.Principal.NTAccount]).Value + + Add-Type -AssemblyName System.DirectoryServices.AccountManagement + $principal_context = New-Object -TypeName System.DirectoryServices.AccountManagement.PrincipalContext([System.DirectoryServices.AccountManagement.ContextType]::Machine) + $group_principal = New-Object -TypeName System.DirectoryServices.AccountManagement.GroupPrincipal($principal_context, $admin_group) + $searcher = New-Object -TypeName System.DirectoryServices.AccountManagement.PrincipalSearcher($group_principal) + $groups = $searcher.FindOne() + + foreach ($user in $groups.Members) { + $user_sid = $user.Sid + if ($user_sid.Value.EndsWith("-500")) { + $machine_sid = $user_sid.AccountDomainSid.Value + break + } + } + } catch { + #can fail for any number of reasons, if it does just return the original null + Add-Warning -obj $result -message "Error during machine sid retrieval: $($_.Exception.Message)" + } + + return $machine_sid +} + +$cim_instances = @{} + +Function Get-LazyCimInstance([string]$instance_name, [string]$namespace="Root\CIMV2") { + if(-not $cim_instances.ContainsKey($instance_name)) { + $cim_instances[$instance_name] = $(Get-CimInstance -Namespace $namespace -ClassName $instance_name) + } + + return $cim_instances[$instance_name] +} + +$result = @{ + ansible_facts = @{ } + changed = $false +} + +$grouped_subsets = @{ + min=[System.Collections.Generic.List[string]]@('date_time','distribution','dns','env','local','platform','powershell_version','user') + network=[System.Collections.Generic.List[string]]@('all_ipv4_addresses','all_ipv6_addresses','interfaces','windows_domain', 'winrm') + hardware=[System.Collections.Generic.List[string]]@('bios','memory','processor','uptime','virtual') + external=[System.Collections.Generic.List[string]]@('facter') +} + +# build "all" set from everything mentioned in the group- this means every value must be in at least one subset to be considered legal +$all_set = [System.Collections.Generic.HashSet[string]]@() + +foreach($kv in $grouped_subsets.GetEnumerator()) { + [void] $all_set.UnionWith($kv.Value) +} + +# dynamically create an "all" subset now that we know what should be in it +$grouped_subsets['all'] = [System.Collections.Generic.List[string]]$all_set + +# start with all, build up gather and exclude subsets +$gather_subset = [System.Collections.Generic.HashSet[string]]$grouped_subsets.all +$explicit_subset = [System.Collections.Generic.HashSet[string]]@() +$exclude_subset = [System.Collections.Generic.HashSet[string]]@() + +$params = Parse-Args $args -supports_check_mode $true +$factpath = Get-AnsibleParam -obj $params -name "fact_path" -type "path" +$gather_subset_source = Get-AnsibleParam -obj $params -name "gather_subset" -type "list" -default "all" + +foreach($item in $gather_subset_source) { + if(([string]$item).StartsWith("!")) { + $item = ([string]$item).Substring(1) + if($item -eq "all") { + $all_minus_min = [System.Collections.Generic.HashSet[string]]@($all_set) + [void] $all_minus_min.ExceptWith($grouped_subsets.min) + [void] $exclude_subset.UnionWith($all_minus_min) + } + elseif($grouped_subsets.ContainsKey($item)) { + [void] $exclude_subset.UnionWith($grouped_subsets[$item]) + } + elseif($all_set.Contains($item)) { + [void] $exclude_subset.Add($item) + } + # NB: invalid exclude values are ignored, since that's what posix setup does + } + else { + if($grouped_subsets.ContainsKey($item)) { + [void] $explicit_subset.UnionWith($grouped_subsets[$item]) + } + elseif($all_set.Contains($item)) { + [void] $explicit_subset.Add($item) + } + else { + # NB: POSIX setup fails on invalid value; we warn, because we don't implement the same set as POSIX + # and we don't have platform-specific config for this... + Add-Warning $result "invalid value $item specified in gather_subset" + } + } +} + +[void] $gather_subset.ExceptWith($exclude_subset) +[void] $gather_subset.UnionWith($explicit_subset) + +$ansible_facts = @{ + gather_subset=@($gather_subset_source) + module_setup=$true +} + +$osversion = [Environment]::OSVersion + +if ($osversion.Version -lt [version]"6.2") { + # Server 2008, 2008 R2, and Windows 7 are not tested in CI and we want to let customers know about it before + # removing support altogether. + $version_string = "{0}.{1}" -f ($osversion.Version.Major, $osversion.Version.Minor) + $msg = "Windows version '$version_string' will no longer be supported or tested in the next Ansible release" + Add-DeprecationWarning -obj $result -message $msg -version "2.11" +} + +if($gather_subset.Contains('all_ipv4_addresses') -or $gather_subset.Contains('all_ipv6_addresses')) { + $netcfg = Get-LazyCimInstance Win32_NetworkAdapterConfiguration + + # TODO: split v4/v6 properly, return in separate keys + $ips = @() + Foreach ($ip in $netcfg.IPAddress) { + If ($ip) { + $ips += $ip + } + } + + $ansible_facts += @{ + ansible_ip_addresses = $ips + } +} + +if($gather_subset.Contains('bios')) { + $win32_bios = Get-LazyCimInstance Win32_Bios + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $ansible_facts += @{ + ansible_bios_date = $win32_bios.ReleaseDate.ToString("MM/dd/yyyy") + ansible_bios_version = $win32_bios.SMBIOSBIOSVersion + ansible_product_name = $win32_cs.Model.Trim() + ansible_product_serial = $win32_bios.SerialNumber + # ansible_product_version = ([string] $win32_cs.SystemFamily) + } +} + +if($gather_subset.Contains('date_time')) { + $datetime = (Get-Date) + $datetime_utc = $datetime.ToUniversalTime() + $date = @{ + date = $datetime.ToString("yyyy-MM-dd") + day = $datetime.ToString("dd") + epoch = (Get-Date -UFormat "%s") + hour = $datetime.ToString("HH") + iso8601 = $datetime_utc.ToString("yyyy-MM-ddTHH:mm:ssZ") + iso8601_basic = $datetime.ToString("yyyyMMddTHHmmssffffff") + iso8601_basic_short = $datetime.ToString("yyyyMMddTHHmmss") + iso8601_micro = $datetime_utc.ToString("yyyy-MM-ddTHH:mm:ss.ffffffZ") + minute = $datetime.ToString("mm") + month = $datetime.ToString("MM") + second = $datetime.ToString("ss") + time = $datetime.ToString("HH:mm:ss") + tz = ([System.TimeZoneInfo]::Local.Id) + tz_offset = $datetime.ToString("zzzz") + # Ensure that the weekday is in English + weekday = $datetime.ToString("dddd", [System.Globalization.CultureInfo]::InvariantCulture) + weekday_number = (Get-Date -UFormat "%w") + weeknumber = (Get-Date -UFormat "%W") + year = $datetime.ToString("yyyy") + } + + $ansible_facts += @{ + ansible_date_time = $date + } +} + +if($gather_subset.Contains('distribution')) { + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $product_type = switch($win32_os.ProductType) { + 1 { "workstation" } + 2 { "domain_controller" } + 3 { "server" } + default { "unknown" } + } + + $installation_type = $null + $current_version_path = "HKLM:\SOFTWARE\Microsoft\Windows NT\CurrentVersion" + if (Test-Path -LiteralPath $current_version_path) { + $install_type_prop = Get-ItemProperty -LiteralPath $current_version_path -ErrorAction SilentlyContinue + $installation_type = [String]$install_type_prop.InstallationType + } + + $ansible_facts += @{ + ansible_distribution = $win32_os.Caption + ansible_distribution_version = $osversion.Version.ToString() + ansible_distribution_major_version = $osversion.Version.Major.ToString() + ansible_os_family = "Windows" + ansible_os_name = ($win32_os.Name.Split('|')[0]).Trim() + ansible_os_product_type = $product_type + ansible_os_installation_type = $installation_type + } +} + +if($gather_subset.Contains('env')) { + $env_vars = @{ } + foreach ($item in Get-ChildItem Env:) { + $name = $item | Select-Object -ExpandProperty Name + # Powershell ConvertTo-Json fails if string ends with \ + $value = ($item | Select-Object -ExpandProperty Value).TrimEnd("\") + $env_vars.Add($name, $value) + } + + $ansible_facts += @{ + ansible_env = $env_vars + } +} + +if($gather_subset.Contains('facter')) { + # See if Facter is on the System Path + Try { + Get-Command facter -ErrorAction Stop > $null + $facter_installed = $true + } Catch { + $facter_installed = $false + } + + # Get JSON from Facter, and parse it out. + if ($facter_installed) { + &facter -j | Tee-Object -Variable facter_output > $null + $facts = "$facter_output" | ConvertFrom-Json + ForEach($fact in $facts.PSObject.Properties) { + $fact_name = $fact.Name + $ansible_facts.Add("facter_$fact_name", $fact.Value) + } + } +} + +if($gather_subset.Contains('interfaces')) { + $netcfg = Get-LazyCimInstance Win32_NetworkAdapterConfiguration + $ActiveNetcfg = @() + $ActiveNetcfg += $netcfg | Where-Object {$_.ipaddress -ne $null} + + $namespaces = Get-LazyCimInstance __Namespace -namespace root + if ($namespaces | Where-Object { $_.Name -eq "StandardCimv" }) { + $net_adapters = Get-LazyCimInstance MSFT_NetAdapter -namespace Root\StandardCimv2 + $guid_key = "InterfaceGUID" + $name_key = "Name" + } else { + $net_adapters = Get-LazyCimInstance Win32_NetworkAdapter + $guid_key = "GUID" + $name_key = "NetConnectionID" + } + + $formattednetcfg = @() + foreach ($adapter in $ActiveNetcfg) + { + $thisadapter = @{ + default_gateway = $null + connection_name = $null + dns_domain = $adapter.dnsdomain + interface_index = $adapter.InterfaceIndex + interface_name = $adapter.description + macaddress = $adapter.macaddress + } + + if ($adapter.defaultIPGateway) + { + $thisadapter.default_gateway = $adapter.DefaultIPGateway[0].ToString() + } + $net_adapter = $net_adapters | Where-Object { $_.$guid_key -eq $adapter.SettingID } + if ($net_adapter) { + $thisadapter.connection_name = $net_adapter.$name_key + } + + $formattednetcfg += $thisadapter + } + + $ansible_facts += @{ + ansible_interfaces = $formattednetcfg + } +} + +if ($gather_subset.Contains("local") -and $null -ne $factpath) { + # Get any custom facts; results are updated in the + Get-CustomFacts -factpath $factpath +} + +if($gather_subset.Contains('memory')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $ansible_facts += @{ + # Win32_PhysicalMemory is empty on some virtual platforms + ansible_memtotal_mb = ([math]::ceiling($win32_cs.TotalPhysicalMemory / 1024 / 1024)) + ansible_memfree_mb = ([math]::ceiling($win32_os.FreePhysicalMemory / 1024)) + ansible_swaptotal_mb = ([math]::round($win32_os.TotalSwapSpaceSize / 1024)) + ansible_pagefiletotal_mb = ([math]::round($win32_os.SizeStoredInPagingFiles / 1024)) + ansible_pagefilefree_mb = ([math]::round($win32_os.FreeSpaceInPagingFiles / 1024)) + } +} + + +if($gather_subset.Contains('platform')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $domain_suffix = $win32_cs.Domain.Substring($win32_cs.Workgroup.length) + $fqdn = $win32_cs.DNSHostname + + if( $domain_suffix -ne "") + { + $fqdn = $win32_cs.DNSHostname + "." + $domain_suffix + } + + try { + $ansible_reboot_pending = Get-PendingRebootStatus + } catch { + # fails for non-admin users, set to null in this case + $ansible_reboot_pending = $null + } + + $ansible_facts += @{ + ansible_architecture = $win32_os.OSArchitecture + ansible_domain = $domain_suffix + ansible_fqdn = $fqdn + ansible_hostname = $win32_cs.DNSHostname + ansible_netbios_name = $win32_cs.Name + ansible_kernel = $osversion.Version.ToString() + ansible_nodename = $fqdn + ansible_machine_id = Get-MachineSid + ansible_owner_contact = ([string] $win32_cs.PrimaryOwnerContact) + ansible_owner_name = ([string] $win32_cs.PrimaryOwnerName) + # FUTURE: should this live in its own subset? + ansible_reboot_pending = $ansible_reboot_pending + ansible_system = $osversion.Platform.ToString() + ansible_system_description = ([string] $win32_os.Description) + ansible_system_vendor = $win32_cs.Manufacturer + } +} + +if($gather_subset.Contains('powershell_version')) { + $ansible_facts += @{ + ansible_powershell_version = ($PSVersionTable.PSVersion.Major) + } +} + +if($gather_subset.Contains('processor')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $win32_cpu = Get-LazyCimInstance Win32_Processor + if ($win32_cpu -is [array]) { + # multi-socket, pick first + $win32_cpu = $win32_cpu[0] + } + + $cpu_list = @( ) + for ($i=1; $i -le $win32_cs.NumberOfLogicalProcessors; $i++) { + $cpu_list += $win32_cpu.Manufacturer + $cpu_list += $win32_cpu.Name + } + + $ansible_facts += @{ + ansible_processor = $cpu_list + ansible_processor_cores = $win32_cpu.NumberOfCores + ansible_processor_count = $win32_cs.NumberOfProcessors + ansible_processor_threads_per_core = ($win32_cpu.NumberOfLogicalProcessors / $win32_cpu.NumberofCores) + ansible_processor_vcpus = $win32_cs.NumberOfLogicalProcessors + } +} + +if($gather_subset.Contains('uptime')) { + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $ansible_facts += @{ + ansible_lastboot = $win32_os.lastbootuptime.ToString("u") + ansible_uptime_seconds = $([System.Convert]::ToInt64($(Get-Date).Subtract($win32_os.lastbootuptime).TotalSeconds)) + } +} + +if($gather_subset.Contains('user')) { + $user = [Security.Principal.WindowsIdentity]::GetCurrent() + $ansible_facts += @{ + ansible_user_dir = $env:userprofile + # Win32_UserAccount.FullName is probably the right thing here, but it can be expensive to get on large domains + ansible_user_gecos = "" + ansible_user_id = $env:username + ansible_user_sid = $user.User.Value + } +} + +if($gather_subset.Contains('windows_domain')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $domain_roles = @{ + 0 = "Stand-alone workstation" + 1 = "Member workstation" + 2 = "Stand-alone server" + 3 = "Member server" + 4 = "Backup domain controller" + 5 = "Primary domain controller" + } + + $domain_role = $domain_roles.Get_Item([Int32]$win32_cs.DomainRole) + + $ansible_facts += @{ + ansible_windows_domain = $win32_cs.Domain + ansible_windows_domain_member = $win32_cs.PartOfDomain + ansible_windows_domain_role = $domain_role + } +} + +if($gather_subset.Contains('winrm')) { + + $winrm_https_listener_parent_paths = Get-ChildItem -Path WSMan:\localhost\Listener -Recurse -ErrorAction SilentlyContinue | ` + Where-Object {$_.PSChildName -eq "Transport" -and $_.Value -eq "HTTPS"} | Select-Object PSParentPath + if ($winrm_https_listener_parent_paths -isnot [array]) { + $winrm_https_listener_parent_paths = @($winrm_https_listener_parent_paths) + } + + $winrm_https_listener_paths = @() + foreach ($winrm_https_listener_parent_path in $winrm_https_listener_parent_paths) { + $winrm_https_listener_paths += $winrm_https_listener_parent_path.PSParentPath.Substring($winrm_https_listener_parent_path.PSParentPath.LastIndexOf("\")) + } + + $https_listeners = @() + foreach ($winrm_https_listener_path in $winrm_https_listener_paths) { + $https_listeners += Get-ChildItem -Path "WSMan:\localhost\Listener$winrm_https_listener_path" + } + + $winrm_cert_thumbprints = @() + foreach ($https_listener in $https_listeners) { + $winrm_cert_thumbprints += $https_listener | Where-Object {$_.Name -EQ "CertificateThumbprint" } | Select-Object Value + } + + $winrm_cert_expiry = @() + foreach ($winrm_cert_thumbprint in $winrm_cert_thumbprints) { + Try { + $winrm_cert_expiry += Get-ChildItem -Path Cert:\LocalMachine\My | Where-Object Thumbprint -EQ $winrm_cert_thumbprint.Value.ToString().ToUpper() | Select-Object NotAfter + } Catch { + Add-Warning -obj $result -message "Error during certificate expiration retrieval: $($_.Exception.Message)" + } + } + + $winrm_cert_expirations = $winrm_cert_expiry | Sort-Object NotAfter + if ($winrm_cert_expirations) { + # this fact was renamed from ansible_winrm_certificate_expires due to collision with ansible_winrm_X connection var pattern + $ansible_facts.Add("ansible_win_rm_certificate_expires", $winrm_cert_expirations[0].NotAfter.ToString("yyyy-MM-dd HH:mm:ss")) + } +} + +if($gather_subset.Contains('virtual')) { + $machine_info = Get-LazyCimInstance Win32_ComputerSystem + + switch ($machine_info.model) { + "Virtual Machine" { + $machine_type="Hyper-V" + $machine_role="guest" + } + + "VMware Virtual Platform" { + $machine_type="VMware" + $machine_role="guest" + } + + "VirtualBox" { + $machine_type="VirtualBox" + $machine_role="guest" + } + + "HVM domU" { + $machine_type="Xen" + $machine_role="guest" + } + + default { + $machine_type="NA" + $machine_role="NA" + } + } + + $ansible_facts += @{ + ansible_virtualization_role = $machine_role + ansible_virtualization_type = $machine_type + } +} + +$result.ansible_facts += $ansible_facts + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/slurp.ps1 b/test/support/windows-integration/plugins/modules/slurp.ps1 new file mode 100644 index 0000000..eb506c7 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/slurp.ps1 @@ -0,0 +1,28 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$params = Parse-Args $args -supports_check_mode $true; +$src = Get-AnsibleParam -obj $params -name "src" -type "path" -aliases "path" -failifempty $true; + +$result = @{ + changed = $false; +} + +If (Test-Path -LiteralPath $src -PathType Leaf) +{ + $bytes = [System.IO.File]::ReadAllBytes($src); + $result.content = [System.Convert]::ToBase64String($bytes); + $result.encoding = "base64"; + Exit-Json $result; +} +ElseIf (Test-Path -LiteralPath $src -PathType Container) +{ + Fail-Json $result "Path $src is a directory"; +} +Else +{ + Fail-Json $result "Path $src is not found"; +} diff --git a/test/support/windows-integration/plugins/modules/win_acl.ps1 b/test/support/windows-integration/plugins/modules/win_acl.ps1 new file mode 100644 index 0000000..e3c3813 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_acl.ps1 @@ -0,0 +1,225 @@ +#!powershell + +# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com> +# Copyright: (c) 2015, Trond Hindenes +# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.PrivilegeUtil +#Requires -Module Ansible.ModuleUtils.SID + +$ErrorActionPreference = "Stop" + +# win_acl module (File/Resources Permission Additions/Removal) + +#Functions +function Get-UserSID { + param( + [String]$AccountName + ) + + $userSID = $null + $searchAppPools = $false + + if ($AccountName.Split("\").Count -gt 1) { + if ($AccountName.Split("\")[0] -eq "IIS APPPOOL") { + $searchAppPools = $true + $AccountName = $AccountName.Split("\")[1] + } + } + + if ($searchAppPools) { + Import-Module -Name WebAdministration + $testIISPath = Test-Path -LiteralPath "IIS:" + if ($testIISPath) { + $appPoolObj = Get-ItemProperty -LiteralPath "IIS:\AppPools\$AccountName" + $userSID = $appPoolObj.applicationPoolSid + } + } + else { + $userSID = Convert-ToSID -account_name $AccountName + } + + return $userSID +} + +$params = Parse-Args $args + +Function SetPrivilegeTokens() { + # Set privilege tokens only if admin. + # Admins would have these privs or be able to set these privs in the UI Anyway + + $adminRole=[System.Security.Principal.WindowsBuiltInRole]::Administrator + $myWindowsID=[System.Security.Principal.WindowsIdentity]::GetCurrent() + $myWindowsPrincipal=new-object System.Security.Principal.WindowsPrincipal($myWindowsID) + + + if ($myWindowsPrincipal.IsInRole($adminRole)) { + # Need to adjust token privs when executing Set-ACL in certain cases. + # e.g. d:\testdir is owned by group in which current user is not a member and no perms are inherited from d:\ + # This also sets us up for setting the owner as a feature. + # See the following for details of each privilege + # https://msdn.microsoft.com/en-us/library/windows/desktop/bb530716(v=vs.85).aspx + $privileges = @( + "SeRestorePrivilege", # Grants all write access control to any file, regardless of ACL. + "SeBackupPrivilege", # Grants all read access control to any file, regardless of ACL. + "SeTakeOwnershipPrivilege" # Grants ability to take owernship of an object w/out being granted discretionary access + ) + foreach ($privilege in $privileges) { + $state = Get-AnsiblePrivilege -Name $privilege + if ($state -eq $false) { + Set-AnsiblePrivilege -Name $privilege -Value $true + } + } + } +} + + +$result = @{ + changed = $false +} + +$path = Get-AnsibleParam -obj $params -name "path" -type "str" -failifempty $true +$user = Get-AnsibleParam -obj $params -name "user" -type "str" -failifempty $true +$rights = Get-AnsibleParam -obj $params -name "rights" -type "str" -failifempty $true + +$type = Get-AnsibleParam -obj $params -name "type" -type "str" -failifempty $true -validateset "allow","deny" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "absent","present" + +$inherit = Get-AnsibleParam -obj $params -name "inherit" -type "str" +$propagation = Get-AnsibleParam -obj $params -name "propagation" -type "str" -default "None" -validateset "InheritOnly","None","NoPropagateInherit" + +# We mount the HKCR, HKU, and HKCC registry hives so PS can access them. +# Network paths have no qualifiers so we use -EA SilentlyContinue to ignore that +$path_qualifier = Split-Path -Path $path -Qualifier -ErrorAction SilentlyContinue +if ($path_qualifier -eq "HKCR:" -and (-not (Test-Path -LiteralPath HKCR:\))) { + New-PSDrive -Name HKCR -PSProvider Registry -Root HKEY_CLASSES_ROOT > $null +} +if ($path_qualifier -eq "HKU:" -and (-not (Test-Path -LiteralPath HKU:\))) { + New-PSDrive -Name HKU -PSProvider Registry -Root HKEY_USERS > $null +} +if ($path_qualifier -eq "HKCC:" -and (-not (Test-Path -LiteralPath HKCC:\))) { + New-PSDrive -Name HKCC -PSProvider Registry -Root HKEY_CURRENT_CONFIG > $null +} + +If (-Not (Test-Path -LiteralPath $path)) { + Fail-Json -obj $result -message "$path file or directory does not exist on the host" +} + +# Test that the user/group is resolvable on the local machine +$sid = Get-UserSID -AccountName $user +if (!$sid) { + Fail-Json -obj $result -message "$user is not a valid user or group on the host machine or domain" +} + +If (Test-Path -LiteralPath $path -PathType Leaf) { + $inherit = "None" +} +ElseIf ($null -eq $inherit) { + $inherit = "ContainerInherit, ObjectInherit" +} + +# Bug in Set-Acl, Get-Acl where -LiteralPath only works for the Registry provider if the location is in that root +# qualifier. We also don't have a qualifier for a network path so only change if not null +if ($null -ne $path_qualifier) { + Push-Location -LiteralPath $path_qualifier +} + +Try { + SetPrivilegeTokens + $path_item = Get-Item -LiteralPath $path -Force + If ($path_item.PSProvider.Name -eq "Registry") { + $colRights = [System.Security.AccessControl.RegistryRights]$rights + } + Else { + $colRights = [System.Security.AccessControl.FileSystemRights]$rights + } + + $InheritanceFlag = [System.Security.AccessControl.InheritanceFlags]$inherit + $PropagationFlag = [System.Security.AccessControl.PropagationFlags]$propagation + + If ($type -eq "allow") { + $objType =[System.Security.AccessControl.AccessControlType]::Allow + } + Else { + $objType =[System.Security.AccessControl.AccessControlType]::Deny + } + + $objUser = New-Object System.Security.Principal.SecurityIdentifier($sid) + If ($path_item.PSProvider.Name -eq "Registry") { + $objACE = New-Object System.Security.AccessControl.RegistryAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType) + } + Else { + $objACE = New-Object System.Security.AccessControl.FileSystemAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType) + } + $objACL = Get-ACL -LiteralPath $path + + # Check if the ACE exists already in the objects ACL list + $match = $false + + ForEach($rule in $objACL.GetAccessRules($true, $true, [System.Security.Principal.SecurityIdentifier])){ + + If ($path_item.PSProvider.Name -eq "Registry") { + If (($rule.RegistryRights -eq $objACE.RegistryRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) { + $match = $true + Break + } + } else { + If (($rule.FileSystemRights -eq $objACE.FileSystemRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) { + $match = $true + Break + } + } + } + + If ($state -eq "present" -And $match -eq $false) { + Try { + $objACL.AddAccessRule($objACE) + If ($path_item.PSProvider.Name -eq "Registry") { + Set-ACL -LiteralPath $path -AclObject $objACL + } else { + (Get-Item -LiteralPath $path).SetAccessControl($objACL) + } + $result.changed = $true + } + Catch { + Fail-Json -obj $result -message "an exception occurred when adding the specified rule - $($_.Exception.Message)" + } + } + ElseIf ($state -eq "absent" -And $match -eq $true) { + Try { + $objACL.RemoveAccessRule($objACE) + If ($path_item.PSProvider.Name -eq "Registry") { + Set-ACL -LiteralPath $path -AclObject $objACL + } else { + (Get-Item -LiteralPath $path).SetAccessControl($objACL) + } + $result.changed = $true + } + Catch { + Fail-Json -obj $result -message "an exception occurred when removing the specified rule - $($_.Exception.Message)" + } + } + Else { + # A rule was attempting to be added but already exists + If ($match -eq $true) { + Exit-Json -obj $result -message "the specified rule already exists" + } + # A rule didn't exist that was trying to be removed + Else { + Exit-Json -obj $result -message "the specified rule does not exist" + } + } +} +Catch { + Fail-Json -obj $result -message "an error occurred when attempting to $state $rights permission(s) on $path for $user - $($_.Exception.Message)" +} +Finally { + # Make sure we revert the location stack to the original path just for cleanups sake + if ($null -ne $path_qualifier) { + Pop-Location + } +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/plugins/modules/win_acl.py b/test/support/windows-integration/plugins/modules/win_acl.py new file mode 100644 index 0000000..14fbd82 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_acl.py @@ -0,0 +1,132 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com> +# Copyright: (c) 2015, Trond Hindenes +# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_acl +version_added: "2.0" +short_description: Set file/directory/registry permissions for a system user or group +description: +- Add or remove rights/permissions for a given user or group for the specified + file, folder, registry key or AppPool identifies. +options: + path: + description: + - The path to the file or directory. + type: str + required: yes + user: + description: + - User or Group to add specified rights to act on src file/folder or + registry key. + type: str + required: yes + state: + description: + - Specify whether to add C(present) or remove C(absent) the specified access rule. + type: str + choices: [ absent, present ] + default: present + type: + description: + - Specify whether to allow or deny the rights specified. + type: str + required: yes + choices: [ allow, deny ] + rights: + description: + - The rights/permissions that are to be allowed/denied for the specified + user or group for the item at C(path). + - If C(path) is a file or directory, rights can be any right under MSDN + FileSystemRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.filesystemrights.aspx). + - If C(path) is a registry key, rights can be any right under MSDN + RegistryRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.registryrights.aspx). + type: str + required: yes + inherit: + description: + - Inherit flags on the ACL rules. + - Can be specified as a comma separated list, e.g. C(ContainerInherit), + C(ObjectInherit). + - For more information on the choices see MSDN InheritanceFlags enumeration + at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.inheritanceflags.aspx). + - Defaults to C(ContainerInherit, ObjectInherit) for Directories. + type: str + choices: [ ContainerInherit, ObjectInherit ] + propagation: + description: + - Propagation flag on the ACL rules. + - For more information on the choices see MSDN PropagationFlags enumeration + at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.propagationflags.aspx). + type: str + choices: [ InheritOnly, None, NoPropagateInherit ] + default: "None" +notes: +- If adding ACL's for AppPool identities (available since 2.3), the Windows + Feature "Web-Scripting-Tools" must be enabled. +seealso: +- module: win_acl_inheritance +- module: win_file +- module: win_owner +- module: win_stat +author: +- Phil Schwartz (@schwartzmx) +- Trond Hindenes (@trondhindenes) +- Hans-Joachim Kliemeck (@h0nIg) +''' + +EXAMPLES = r''' +- name: Restrict write and execute access to User Fed-Phil + win_acl: + user: Fed-Phil + path: C:\Important\Executable.exe + type: deny + rights: ExecuteFile,Write + +- name: Add IIS_IUSRS allow rights + win_acl: + path: C:\inetpub\wwwroot\MySite + user: IIS_IUSRS + rights: FullControl + type: allow + state: present + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Set registry key right + win_acl: + path: HKCU:\Bovine\Key + user: BUILTIN\Users + rights: EnumerateSubKeys + type: allow + state: present + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Remove FullControl AccessRule for IIS_IUSRS + win_acl: + path: C:\inetpub\wwwroot\MySite + user: IIS_IUSRS + rights: FullControl + type: allow + state: absent + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Deny Intern + win_acl: + path: C:\Administrator\Documents + user: Intern + rights: Read,Write,Modify,FullControl,Delete + type: deny + state: present +''' diff --git a/test/support/windows-integration/plugins/modules/win_certificate_store.ps1 b/test/support/windows-integration/plugins/modules/win_certificate_store.ps1 new file mode 100644 index 0000000..db98413 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_certificate_store.ps1 @@ -0,0 +1,260 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic + +$store_name_values = ([System.Security.Cryptography.X509Certificates.StoreName]).GetEnumValues() | ForEach-Object { $_.ToString() } +$store_location_values = ([System.Security.Cryptography.X509Certificates.StoreLocation]).GetEnumValues() | ForEach-Object { $_.ToString() } + +$spec = @{ + options = @{ + state = @{ type = "str"; default = "present"; choices = "absent", "exported", "present" } + path = @{ type = "path" } + thumbprint = @{ type = "str" } + store_name = @{ type = "str"; default = "My"; choices = $store_name_values } + store_location = @{ type = "str"; default = "LocalMachine"; choices = $store_location_values } + password = @{ type = "str"; no_log = $true } + key_exportable = @{ type = "bool"; default = $true } + key_storage = @{ type = "str"; default = "default"; choices = "default", "machine", "user" } + file_type = @{ type = "str"; default = "der"; choices = "der", "pem", "pkcs12" } + } + required_if = @( + @("state", "absent", @("path", "thumbprint"), $true), + @("state", "exported", @("path", "thumbprint")), + @("state", "present", @("path")) + ) + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +Function Get-CertFile($module, $path, $password, $key_exportable, $key_storage) { + # parses a certificate file and returns X509Certificate2Collection + if (-not (Test-Path -LiteralPath $path -PathType Leaf)) { + $module.FailJson("File at '$path' either does not exist or is not a file") + } + + # must set at least the PersistKeySet flag so that the PrivateKey + # is stored in a permanent container and not deleted once the handle + # is gone. + $store_flags = [System.Security.Cryptography.X509Certificates.X509KeyStorageFlags]::PersistKeySet + + $key_storage = $key_storage.substring(0,1).ToUpper() + $key_storage.substring(1).ToLower() + $store_flags = $store_flags -bor [Enum]::Parse([System.Security.Cryptography.X509Certificates.X509KeyStorageFlags], "$($key_storage)KeySet") + if ($key_exportable) { + $store_flags = $store_flags -bor [System.Security.Cryptography.X509Certificates.X509KeyStorageFlags]::Exportable + } + + # TODO: If I'm feeling adventurours, write code to parse PKCS#12 PEM encoded + # file as .NET does not have an easy way to import this + $certs = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Certificate2Collection + + try { + $certs.Import($path, $password, $store_flags) + } catch { + $module.FailJson("Failed to load cert from file: $($_.Exception.Message)", $_) + } + + return $certs +} + +Function New-CertFile($module, $cert, $path, $type, $password) { + $content_type = switch ($type) { + "pem" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Cert } + "der" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Cert } + "pkcs12" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Pkcs12 } + } + if ($type -eq "pkcs12") { + $missing_key = $false + if ($null -eq $cert.PrivateKey) { + $missing_key = $true + } elseif ($cert.PrivateKey.CspKeyContainerInfo.Exportable -eq $false) { + $missing_key = $true + } + if ($missing_key) { + $module.FailJson("Cannot export cert with key as PKCS12 when the key is not marked as exportable or not accessible by the current user") + } + } + + if (Test-Path -LiteralPath $path) { + Remove-Item -LiteralPath $path -Force + $module.Result.changed = $true + } + try { + $cert_bytes = $cert.Export($content_type, $password) + } catch { + $module.FailJson("Failed to export certificate as bytes: $($_.Exception.Message)", $_) + } + + # Need to manually handle a PEM file + if ($type -eq "pem") { + $cert_content = "-----BEGIN CERTIFICATE-----`r`n" + $base64_string = [System.Convert]::ToBase64String($cert_bytes, [System.Base64FormattingOptions]::InsertLineBreaks) + $cert_content += $base64_string + $cert_content += "`r`n-----END CERTIFICATE-----" + $file_encoding = [System.Text.Encoding]::ASCII + $cert_bytes = $file_encoding.GetBytes($cert_content) + } elseif ($type -eq "pkcs12") { + $module.Result.key_exported = $false + if ($null -ne $cert.PrivateKey) { + $module.Result.key_exportable = $cert.PrivateKey.CspKeyContainerInfo.Exportable + } + } + + if (-not $module.CheckMode) { + try { + [System.IO.File]::WriteAllBytes($path, $cert_bytes) + } catch [System.ArgumentNullException] { + $module.FailJson("Failed to write cert to file, cert was null: $($_.Exception.Message)", $_) + } catch [System.IO.IOException] { + $module.FailJson("Failed to write cert to file due to IO Exception: $($_.Exception.Message)", $_) + } catch [System.UnauthorizedAccessException] { + $module.FailJson("Failed to write cert to file due to permissions: $($_.Exception.Message)", $_) + } catch { + $module.FailJson("Failed to write cert to file: $($_.Exception.Message)", $_) + } + } + $module.Result.changed = $true +} + +Function Get-CertFileType($path, $password) { + $certs = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Certificate2Collection + try { + $certs.Import($path, $password, 0) + } catch [System.Security.Cryptography.CryptographicException] { + # the file is a pkcs12 we just had the wrong password + return "pkcs12" + } catch { + return "unknown" + } + + $file_contents = Get-Content -LiteralPath $path -Raw + if ($file_contents.StartsWith("-----BEGIN CERTIFICATE-----")) { + return "pem" + } elseif ($file_contents.StartsWith("-----BEGIN PKCS7-----")) { + return "pkcs7-ascii" + } elseif ($certs.Count -gt 1) { + # multiple certs must be pkcs7 + return "pkcs7-binary" + } elseif ($certs[0].HasPrivateKey) { + return "pkcs12" + } elseif ($path.EndsWith(".pfx") -or $path.EndsWith(".p12")) { + # no way to differenciate a pfx with a der file so we must rely on the + # extension + return "pkcs12" + } else { + return "der" + } +} + +$state = $module.Params.state +$path = $module.Params.path +$thumbprint = $module.Params.thumbprint +$store_name = [System.Security.Cryptography.X509Certificates.StoreName]"$($module.Params.store_name)" +$store_location = [System.Security.Cryptography.X509Certificates.Storelocation]"$($module.Params.store_location)" +$password = $module.Params.password +$key_exportable = $module.Params.key_exportable +$key_storage = $module.Params.key_storage +$file_type = $module.Params.file_type + +$module.Result.thumbprints = @() + +$store = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Store -ArgumentList $store_name, $store_location +try { + $store.Open([System.Security.Cryptography.X509Certificates.OpenFlags]::ReadWrite) +} catch [System.Security.Cryptography.CryptographicException] { + $module.FailJson("Unable to open the store as it is not readable: $($_.Exception.Message)", $_) +} catch [System.Security.SecurityException] { + $module.FailJson("Unable to open the store with the current permissions: $($_.Exception.Message)", $_) +} catch { + $module.FailJson("Unable to open the store: $($_.Exception.Message)", $_) +} +$store_certificates = $store.Certificates + +try { + if ($state -eq "absent") { + $cert_thumbprints = @() + + if ($null -ne $path) { + $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage + foreach ($cert in $certs) { + $cert_thumbprints += $cert.Thumbprint + } + } elseif ($null -ne $thumbprint) { + $cert_thumbprints += $thumbprint + } + + foreach ($cert_thumbprint in $cert_thumbprints) { + $module.Result.thumbprints += $cert_thumbprint + $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $cert_thumbprint, $false) + if ($found_certs.Count -gt 0) { + foreach ($found_cert in $found_certs) { + try { + if (-not $module.CheckMode) { + $store.Remove($found_cert) + } + } catch [System.Security.SecurityException] { + $module.FailJson("Unable to remove cert with thumbprint '$cert_thumbprint' with current permissions: $($_.Exception.Message)", $_) + } catch { + $module.FailJson("Unable to remove cert with thumbprint '$cert_thumbprint': $($_.Exception.Message)", $_) + } + $module.Result.changed = $true + } + } + } + } elseif ($state -eq "exported") { + # TODO: Add support for PKCS7 and exporting a cert chain + $module.Result.thumbprints += $thumbprint + $export = $true + if (Test-Path -LiteralPath $path -PathType Container) { + $module.FailJson("Cannot export cert to path '$path' as it is a directory") + } elseif (Test-Path -LiteralPath $path -PathType Leaf) { + $actual_cert_type = Get-CertFileType -path $path -password $password + if ($actual_cert_type -eq $file_type) { + try { + $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage + } catch { + # failed to load the file so we set the thumbprint to something + # that will fail validation + $certs = @{Thumbprint = $null} + } + + if ($certs.Thumbprint -eq $thumbprint) { + $export = $false + } + } + } + + if ($export) { + $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $thumbprint, $false) + if ($found_certs.Count -ne 1) { + $module.FailJson("Found $($found_certs.Count) certs when only expecting 1") + } + + New-CertFile -module $module -cert $found_certs -path $path -type $file_type -password $password + } + } else { + $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage + foreach ($cert in $certs) { + $module.Result.thumbprints += $cert.Thumbprint + $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $cert.Thumbprint, $false) + if ($found_certs.Count -eq 0) { + try { + if (-not $module.CheckMode) { + $store.Add($cert) + } + } catch [System.Security.Cryptography.CryptographicException] { + $module.FailJson("Unable to import certificate with thumbprint '$($cert.Thumbprint)' with the current permissions: $($_.Exception.Message)", $_) + } catch { + $module.FailJson("Unable to import certificate with thumbprint '$($cert.Thumbprint)': $($_.Exception.Message)", $_) + } + $module.Result.changed = $true + } + } + } +} finally { + $store.Close() +} + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_certificate_store.py b/test/support/windows-integration/plugins/modules/win_certificate_store.py new file mode 100644 index 0000000..dc617e3 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_certificate_store.py @@ -0,0 +1,208 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_certificate_store +version_added: '2.5' +short_description: Manages the certificate store +description: +- Used to import/export and remove certificates and keys from the local + certificate store. +- This module is not used to create certificates and will only manage existing + certs as a file or in the store. +- It can be used to import PEM, DER, P7B, PKCS12 (PFX) certificates and export + PEM, DER and PKCS12 certificates. +options: + state: + description: + - If C(present), will ensure that the certificate at I(path) is imported + into the certificate store specified. + - If C(absent), will ensure that the certificate specified by I(thumbprint) + or the thumbprint of the cert at I(path) is removed from the store + specified. + - If C(exported), will ensure the file at I(path) is a certificate + specified by I(thumbprint). + - When exporting a certificate, if I(path) is a directory then the module + will fail, otherwise the file will be replaced if needed. + type: str + choices: [ absent, exported, present ] + default: present + path: + description: + - The path to a certificate file. + - This is required when I(state) is C(present) or C(exported). + - When I(state) is C(absent) and I(thumbprint) is not specified, the + thumbprint is derived from the certificate at this path. + type: path + thumbprint: + description: + - The thumbprint as a hex string to either export or remove. + - See the examples for how to specify the thumbprint. + type: str + store_name: + description: + - The store name to use when importing a certificate or searching for a + certificate. + - "C(AddressBook): The X.509 certificate store for other users" + - "C(AuthRoot): The X.509 certificate store for third-party certificate authorities (CAs)" + - "C(CertificateAuthority): The X.509 certificate store for intermediate certificate authorities (CAs)" + - "C(Disallowed): The X.509 certificate store for revoked certificates" + - "C(My): The X.509 certificate store for personal certificates" + - "C(Root): The X.509 certificate store for trusted root certificate authorities (CAs)" + - "C(TrustedPeople): The X.509 certificate store for directly trusted people and resources" + - "C(TrustedPublisher): The X.509 certificate store for directly trusted publishers" + type: str + choices: + - AddressBook + - AuthRoot + - CertificateAuthority + - Disallowed + - My + - Root + - TrustedPeople + - TrustedPublisher + default: My + store_location: + description: + - The store location to use when importing a certificate or searching for a + certificate. + choices: [ CurrentUser, LocalMachine ] + default: LocalMachine + password: + description: + - The password of the pkcs12 certificate key. + - This is used when reading a pkcs12 certificate file or the password to + set when C(state=exported) and C(file_type=pkcs12). + - If the pkcs12 file has no password set or no password should be set on + the exported file, do not set this option. + type: str + key_exportable: + description: + - Whether to allow the private key to be exported. + - If C(no), then this module and other process will only be able to export + the certificate and the private key cannot be exported. + - Used when C(state=present) only. + type: bool + default: yes + key_storage: + description: + - Specifies where Windows will store the private key when it is imported. + - When set to C(default), the default option as set by Windows is used, typically C(user). + - When set to C(machine), the key is stored in a path accessible by various + users. + - When set to C(user), the key is stored in a path only accessible by the + current user. + - Used when C(state=present) only and cannot be changed once imported. + - See U(https://msdn.microsoft.com/en-us/library/system.security.cryptography.x509certificates.x509keystorageflags.aspx) + for more details. + type: str + choices: [ default, machine, user ] + default: default + file_type: + description: + - The file type to export the certificate as when C(state=exported). + - C(der) is a binary ASN.1 encoded file. + - C(pem) is a base64 encoded file of a der file in the OpenSSL form. + - C(pkcs12) (also known as pfx) is a binary container that contains both + the certificate and private key unlike the other options. + - When C(pkcs12) is set and the private key is not exportable or accessible + by the current user, it will throw an exception. + type: str + choices: [ der, pem, pkcs12 ] + default: der +notes: +- Some actions on PKCS12 certificates and keys may fail with the error + C(the specified network password is not correct), either use CredSSP or + Kerberos with credential delegation, or use C(become) to bypass these + restrictions. +- The certificates must be located on the Windows host to be set with I(path). +- When importing a certificate for usage in IIS, it is generally required + to use the C(machine) key_storage option, as both C(default) and C(user) + will make the private key unreadable to IIS APPPOOL identities and prevent + binding the certificate to the https endpoint. +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Import a certificate + win_certificate_store: + path: C:\Temp\cert.pem + state: present + +- name: Import pfx certificate that is password protected + win_certificate_store: + path: C:\Temp\cert.pfx + state: present + password: VeryStrongPasswordHere! + become: yes + become_method: runas + +- name: Import pfx certificate without password and set private key as un-exportable + win_certificate_store: + path: C:\Temp\cert.pfx + state: present + key_exportable: no + # usually you don't set this here but it is for illustrative purposes + vars: + ansible_winrm_transport: credssp + +- name: Remove a certificate based on file thumbprint + win_certificate_store: + path: C:\Temp\cert.pem + state: absent + +- name: Remove a certificate based on thumbprint + win_certificate_store: + thumbprint: BD7AF104CF1872BDB518D95C9534EA941665FD27 + state: absent + +- name: Remove certificate based on thumbprint is CurrentUser/TrustedPublishers store + win_certificate_store: + thumbprint: BD7AF104CF1872BDB518D95C9534EA941665FD27 + state: absent + store_location: CurrentUser + store_name: TrustedPublisher + +- name: Export certificate as der encoded file + win_certificate_store: + path: C:\Temp\cert.cer + state: exported + file_type: der + +- name: Export certificate and key as pfx encoded file + win_certificate_store: + path: C:\Temp\cert.pfx + state: exported + file_type: pkcs12 + password: AnotherStrongPass! + become: yes + become_method: runas + become_user: SYSTEM + +- name: Import certificate be used by IIS + win_certificate_store: + path: C:\Temp\cert.pfx + file_type: pkcs12 + password: StrongPassword! + store_location: LocalMachine + key_storage: machine + state: present +''' + +RETURN = r''' +thumbprints: + description: A list of certificate thumbprints that were touched by the + module. + returned: success + type: list + sample: ["BC05633694E675449136679A658281F17A191087"] +''' diff --git a/test/support/windows-integration/plugins/modules/win_command.ps1 b/test/support/windows-integration/plugins/modules/win_command.ps1 new file mode 100644 index 0000000..e2a3065 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_command.ps1 @@ -0,0 +1,78 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.CommandUtil +#Requires -Module Ansible.ModuleUtils.FileUtil + +# TODO: add check mode support + +Set-StrictMode -Version 2 +$ErrorActionPreference = 'Stop' + +$params = Parse-Args $args -supports_check_mode $false + +$raw_command_line = Get-AnsibleParam -obj $params -name "_raw_params" -type "str" -failifempty $true +$chdir = Get-AnsibleParam -obj $params -name "chdir" -type "path" +$creates = Get-AnsibleParam -obj $params -name "creates" -type "path" +$removes = Get-AnsibleParam -obj $params -name "removes" -type "path" +$stdin = Get-AnsibleParam -obj $params -name "stdin" -type "str" +$output_encoding_override = Get-AnsibleParam -obj $params -name "output_encoding_override" -type "str" + +$raw_command_line = $raw_command_line.Trim() + +$result = @{ + changed = $true + cmd = $raw_command_line +} + +if ($creates -and $(Test-AnsiblePath -Path $creates)) { + Exit-Json @{msg="skipped, since $creates exists";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +if ($removes -and -not $(Test-AnsiblePath -Path $removes)) { + Exit-Json @{msg="skipped, since $removes does not exist";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +$command_args = @{ + command = $raw_command_line +} +if ($chdir) { + $command_args['working_directory'] = $chdir +} +if ($stdin) { + $command_args['stdin'] = $stdin +} +if ($output_encoding_override) { + $command_args['output_encoding_override'] = $output_encoding_override +} + +$start_datetime = [DateTime]::UtcNow +try { + $command_result = Run-Command @command_args +} catch { + $result.changed = $false + try { + $result.rc = $_.Exception.NativeErrorCode + } catch { + $result.rc = 2 + } + Fail-Json -obj $result -message $_.Exception.Message +} + +$result.stdout = $command_result.stdout +$result.stderr = $command_result.stderr +$result.rc = $command_result.rc + +$end_datetime = [DateTime]::UtcNow +$result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff") + +If ($result.rc -ne 0) { + Fail-Json -obj $result -message "non-zero return code" +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_command.py b/test/support/windows-integration/plugins/modules/win_command.py new file mode 100644 index 0000000..508419b --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_command.py @@ -0,0 +1,136 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Ansible, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_command +short_description: Executes a command on a remote Windows node +version_added: 2.2 +description: + - The C(win_command) module takes the command name followed by a list of space-delimited arguments. + - The given command will be executed on all selected nodes. It will not be + processed through the shell, so variables like C($env:HOME) and operations + like C("<"), C(">"), C("|"), and C(";") will not work (use the M(win_shell) + module if you need these features). + - For non-Windows targets, use the M(command) module instead. +options: + free_form: + description: + - The C(win_command) module takes a free form command to run. + - There is no parameter actually named 'free form'. See the examples! + type: str + required: yes + creates: + description: + - A path or path filter pattern; when the referenced path exists on the target host, the task will be skipped. + type: path + removes: + description: + - A path or path filter pattern; when the referenced path B(does not) exist on the target host, the task will be skipped. + type: path + chdir: + description: + - Set the specified path as the current working directory before executing a command. + type: path + stdin: + description: + - Set the stdin of the command directly to the specified value. + type: str + version_added: '2.5' + output_encoding_override: + description: + - This option overrides the encoding of stdout/stderr output. + - You can use this option when you need to run a command which ignore the console's codepage. + - You should only need to use this option in very rare circumstances. + - This value can be any valid encoding C(Name) based on the output of C([System.Text.Encoding]::GetEncodings()). + See U(https://docs.microsoft.com/dotnet/api/system.text.encoding.getencodings). + type: str + version_added: '2.10' +notes: + - If you want to run a command through a shell (say you are using C(<), + C(>), C(|), etc), you actually want the M(win_shell) module instead. The + C(win_command) module is much more secure as it's not affected by the user's + environment. + - C(creates), C(removes), and C(chdir) can be specified after the command. For instance, if you only want to run a command if a certain file does not + exist, use this. +seealso: +- module: command +- module: psexec +- module: raw +- module: win_psexec +- module: win_shell +author: + - Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +- name: Save the result of 'whoami' in 'whoami_out' + win_command: whoami + register: whoami_out + +- name: Run command that only runs if folder exists and runs from a specific folder + win_command: wbadmin -backupTarget:C:\backup\ + args: + chdir: C:\somedir\ + creates: C:\backup\ + +- name: Run an executable and send data to the stdin for the executable + win_command: powershell.exe - + args: + stdin: Write-Host test +''' + +RETURN = r''' +msg: + description: changed + returned: always + type: bool + sample: true +start: + description: The command execution start time + returned: always + type: str + sample: '2016-02-25 09:18:26.429568' +end: + description: The command execution end time + returned: always + type: str + sample: '2016-02-25 09:18:26.755339' +delta: + description: The command execution delta time + returned: always + type: str + sample: '0:00:00.325771' +stdout: + description: The command standard output + returned: always + type: str + sample: 'Clustering node rabbit@slave1 with rabbit@master ...' +stderr: + description: The command standard error + returned: always + type: str + sample: 'ls: cannot access foo: No such file or directory' +cmd: + description: The command executed by the task + returned: always + type: str + sample: 'rabbitmqctl join_cluster rabbit@master' +rc: + description: The command return code (0 means success) + returned: always + type: int + sample: 0 +stdout_lines: + description: The command standard output split in lines + returned: always + type: list + sample: [u'Clustering node rabbit@slave1 with rabbit@master ...'] +''' diff --git a/test/support/windows-integration/plugins/modules/win_copy.ps1 b/test/support/windows-integration/plugins/modules/win_copy.ps1 new file mode 100644 index 0000000..6a26ee7 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_copy.ps1 @@ -0,0 +1,403 @@ +#!powershell + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.Backup + +$ErrorActionPreference = 'Stop' + +$params = Parse-Args -arguments $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false + +# there are 4 modes to win_copy which are driven by the action plugins: +# explode: src is a zip file which needs to be extracted to dest, for use with multiple files +# query: win_copy action plugin wants to get the state of remote files to check whether it needs to send them +# remote: all copy action is happening remotely (remote_src=True) +# single: a single file has been copied, also used with template +$copy_mode = Get-AnsibleParam -obj $params -name "_copy_mode" -type "str" -default "single" -validateset "explode","query","remote","single" + +# used in explode, remote and single mode +$src = Get-AnsibleParam -obj $params -name "src" -type "path" -failifempty ($copy_mode -in @("explode","process","single")) +$dest = Get-AnsibleParam -obj $params -name "dest" -type "path" -failifempty $true +$backup = Get-AnsibleParam -obj $params -name "backup" -type "bool" -default $false + +# used in single mode +$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str" + +# used in query and remote mode +$force = Get-AnsibleParam -obj $params -name "force" -type "bool" -default $true + +# used in query mode, contains the local files/directories/symlinks that are to be copied +$files = Get-AnsibleParam -obj $params -name "files" -type "list" +$directories = Get-AnsibleParam -obj $params -name "directories" -type "list" + +$result = @{ + changed = $false +} + +if ($diff_mode) { + $result.diff = @{} +} + +Function Copy-File($source, $dest) { + $diff = "" + $copy_file = $false + $source_checksum = $null + if ($force) { + $source_checksum = Get-FileChecksum -path $source + } + + if (Test-Path -LiteralPath $dest -PathType Container) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': dest is already a folder" + } elseif (Test-Path -LiteralPath $dest -PathType Leaf) { + if ($force) { + $target_checksum = Get-FileChecksum -path $dest + if ($source_checksum -ne $target_checksum) { + $copy_file = $true + } + } + } else { + $copy_file = $true + } + + if ($copy_file) { + $file_dir = [System.IO.Path]::GetDirectoryName($dest) + # validate the parent dir is not a file and that it exists + if (Test-Path -LiteralPath $file_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder" + } elseif (-not (Test-Path -LiteralPath $file_dir)) { + # directory doesn't exist, need to create + New-Item -Path $file_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + $diff += "+$file_dir\`n" + } + + if ($backup) { + $result.backup_file = Backup-File -path $dest -WhatIf:$check_mode + } + + if (Test-Path -LiteralPath $dest -PathType Leaf) { + Remove-Item -LiteralPath $dest -Force -Recurse -WhatIf:$check_mode | Out-Null + $diff += "-$dest`n" + } + + if (-not $check_mode) { + # cannot run with -WhatIf:$check_mode as if the parent dir didn't + # exist and was created above would still not exist in check mode + Copy-Item -LiteralPath $source -Destination $dest -Force | Out-Null + } + $diff += "+$dest`n" + + $result.changed = $true + } + + # ugly but to save us from running the checksum twice, let's return it for + # the main code to add it to $result + return ,@{ diff = $diff; checksum = $source_checksum } +} + +Function Copy-Folder($source, $dest) { + $diff = "" + + if (-not (Test-Path -LiteralPath $dest -PathType Container)) { + $parent_dir = [System.IO.Path]::GetDirectoryName($dest) + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder" + } + if (Test-Path -LiteralPath $dest -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder from '$source' to '$dest': dest is already a file" + } + + New-Item -Path $dest -ItemType Container -WhatIf:$check_mode | Out-Null + $diff += "+$dest\`n" + $result.changed = $true + } + + $child_items = Get-ChildItem -LiteralPath $source -Force + foreach ($child_item in $child_items) { + $dest_child_path = Join-Path -Path $dest -ChildPath $child_item.Name + if ($child_item.PSIsContainer) { + $diff += (Copy-Folder -source $child_item.Fullname -dest $dest_child_path) + } else { + $diff += (Copy-File -source $child_item.Fullname -dest $dest_child_path).diff + } + } + + return $diff +} + +Function Get-FileSize($path) { + $file = Get-Item -LiteralPath $path -Force + if ($file.PSIsContainer) { + $size = (Get-ChildItem -Literalpath $file.FullName -Recurse -Force | ` + Where-Object { $_.PSObject.Properties.Name -contains 'Length' } | ` + Measure-Object -Property Length -Sum).Sum + if ($null -eq $size) { + $size = 0 + } + } else { + $size = $file.Length + } + + $size +} + +Function Extract-Zip($src, $dest) { + $archive = [System.IO.Compression.ZipFile]::Open($src, [System.IO.Compression.ZipArchiveMode]::Read, [System.Text.Encoding]::UTF8) + foreach ($entry in $archive.Entries) { + $archive_name = $entry.FullName + + # FullName may be appended with / or \, determine if it is padded and remove it + $padding_length = $archive_name.Length % 4 + if ($padding_length -eq 0) { + $is_dir = $false + $base64_name = $archive_name + } elseif ($padding_length -eq 1) { + $is_dir = $true + if ($archive_name.EndsWith("/") -or $archive_name.EndsWith("`\")) { + $base64_name = $archive_name.Substring(0, $archive_name.Length - 1) + } else { + throw "invalid base64 archive name '$archive_name'" + } + } else { + throw "invalid base64 length '$archive_name'" + } + + # to handle unicode character, win_copy action plugin has encoded the filename + $decoded_archive_name = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($base64_name)) + # re-add the / to the entry full name if it was a directory + if ($is_dir) { + $decoded_archive_name = "$decoded_archive_name/" + } + $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_name) + $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path) + + if (-not (Test-Path -LiteralPath $entry_dir)) { + New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + } + + if ($is_dir -eq $false) { + if (-not $check_mode) { + [System.IO.Compression.ZipFileExtensions]::ExtractToFile($entry, $entry_target_path, $true) + } + } + } + $archive.Dispose() # release the handle of the zip file +} + +Function Extract-ZipLegacy($src, $dest) { + if (-not (Test-Path -LiteralPath $dest)) { + New-Item -Path $dest -ItemType Directory -WhatIf:$check_mode | Out-Null + } + $shell = New-Object -ComObject Shell.Application + $zip = $shell.NameSpace($src) + $dest_path = $shell.NameSpace($dest) + + foreach ($entry in $zip.Items()) { + $is_dir = $entry.IsFolder + $encoded_archive_entry = $entry.Name + # to handle unicode character, win_copy action plugin has encoded the filename + $decoded_archive_entry = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($encoded_archive_entry)) + if ($is_dir) { + $decoded_archive_entry = "$decoded_archive_entry/" + } + + $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_entry) + $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path) + + if (-not (Test-Path -LiteralPath $entry_dir)) { + New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + } + + if ($is_dir -eq $false -and (-not $check_mode)) { + # https://msdn.microsoft.com/en-us/library/windows/desktop/bb787866.aspx + # From Folder.CopyHere documentation, 1044 means: + # - 1024: do not display a user interface if an error occurs + # - 16: respond with "yes to all" for any dialog box that is displayed + # - 4: do not display a progress dialog box + $dest_path.CopyHere($entry, 1044) + + # once file is extraced, we need to rename it with non base64 name + $combined_encoded_path = [System.IO.Path]::Combine($dest, $encoded_archive_entry) + Move-Item -LiteralPath $combined_encoded_path -Destination $entry_target_path -Force | Out-Null + } + } +} + +if ($copy_mode -eq "query") { + # we only return a list of files/directories that need to be copied over + # the source of the local file will be the key used + $changed_files = @() + $changed_directories = @() + $changed_symlinks = @() + + foreach ($file in $files) { + $filename = $file.dest + $local_checksum = $file.checksum + + $filepath = Join-Path -Path $dest -ChildPath $filename + if (Test-Path -LiteralPath $filepath -PathType Leaf) { + if ($force) { + $checksum = Get-FileChecksum -path $filepath + if ($checksum -ne $local_checksum) { + $changed_files += $file + } + } + } elseif (Test-Path -LiteralPath $filepath -PathType Container) { + Fail-Json -obj $result -message "cannot copy file to dest '$filepath': object at path is already a directory" + } else { + $changed_files += $file + } + } + + foreach ($directory in $directories) { + $dirname = $directory.dest + + $dirpath = Join-Path -Path $dest -ChildPath $dirname + $parent_dir = [System.IO.Path]::GetDirectoryName($dirpath) + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at parent directory path is already a file" + } + if (Test-Path -LiteralPath $dirpath -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at path is already a file" + } elseif (-not (Test-Path -LiteralPath $dirpath -PathType Container)) { + $changed_directories += $directory + } + } + + # TODO: Handle symlinks + + $result.files = $changed_files + $result.directories = $changed_directories + $result.symlinks = $changed_symlinks +} elseif ($copy_mode -eq "explode") { + # a single zip file containing the files and directories needs to be + # expanded this will always result in a change as the calculation is done + # on the win_copy action plugin and is only run if a change needs to occur + if (-not (Test-Path -LiteralPath $src -PathType Leaf)) { + Fail-Json -obj $result -message "Cannot expand src zip file: '$src' as it does not exist" + } + + # Detect if the PS zip assemblies are available or whether to use Shell + $use_legacy = $false + try { + Add-Type -AssemblyName System.IO.Compression.FileSystem | Out-Null + Add-Type -AssemblyName System.IO.Compression | Out-Null + } catch { + $use_legacy = $true + } + if ($use_legacy) { + Extract-ZipLegacy -src $src -dest $dest + } else { + Extract-Zip -src $src -dest $dest + } + + $result.changed = $true +} elseif ($copy_mode -eq "remote") { + # all copy actions are happening on the remote side (windows host), need + # too copy source and dest using PS code + $result.src = $src + $result.dest = $dest + + if (-not (Test-Path -LiteralPath $src)) { + Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist" + } + + if (Test-Path -LiteralPath $src -PathType Container) { + # we are copying a directory or the contents of a directory + $result.operation = 'folder_copy' + if ($src.EndsWith("/") -or $src.EndsWith("`\")) { + # copying the folder's contents to dest + $diff = "" + $child_files = Get-ChildItem -LiteralPath $src -Force + foreach ($child_file in $child_files) { + $dest_child_path = Join-Path -Path $dest -ChildPath $child_file.Name + if ($child_file.PSIsContainer) { + $diff += Copy-Folder -source $child_file.FullName -dest $dest_child_path + } else { + $diff += (Copy-File -source $child_file.FullName -dest $dest_child_path).diff + } + } + } else { + # copying the folder and it's contents to dest + $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name + $result.dest = $dest + $diff = Copy-Folder -source $src -dest $dest + } + } else { + # we are just copying a single file to dest + $result.operation = 'file_copy' + + $source_basename = (Get-Item -LiteralPath $src -Force).Name + $result.original_basename = $source_basename + + if ($dest.EndsWith("/") -or $dest.EndsWith("`\")) { + $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name + $result.dest = $dest + } else { + # check if the parent dir exists, this is only done if src is a + # file and dest if the path to a file (doesn't end with \ or /) + $parent_dir = Split-Path -LiteralPath $dest + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist" + } + } + $copy_result = Copy-File -source $src -dest $dest + $diff = $copy_result.diff + $result.checksum = $copy_result.checksum + } + + # the file might not exist if running in check mode + if (-not $check_mode -or (Test-Path -LiteralPath $dest -PathType Leaf)) { + $result.size = Get-FileSize -path $dest + } else { + $result.size = $null + } + if ($diff_mode) { + $result.diff.prepared = $diff + } +} elseif ($copy_mode -eq "single") { + # a single file is located in src and we need to copy to dest, this will + # always result in a change as the calculation is done on the Ansible side + # before this is run. This should also never run in check mode + if (-not (Test-Path -LiteralPath $src -PathType Leaf)) { + Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist" + } + + # the dest parameter is a directory, we need to append original_basename + if ($dest.EndsWith("/") -or $dest.EndsWith("`\") -or (Test-Path -LiteralPath $dest -PathType Container)) { + $remote_dest = Join-Path -Path $dest -ChildPath $original_basename + $parent_dir = Split-Path -LiteralPath $remote_dest + + # when dest ends with /, we need to create the destination directories + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + New-Item -Path $parent_dir -ItemType Directory | Out-Null + } + } else { + $remote_dest = $dest + $parent_dir = Split-Path -LiteralPath $remote_dest + + # check if the dest parent dirs exist, need to fail if they don't + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist" + } + } + + if ($backup) { + $result.backup_file = Backup-File -path $remote_dest -WhatIf:$check_mode + } + + Copy-Item -LiteralPath $src -Destination $remote_dest -Force | Out-Null + $result.changed = $true +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/plugins/modules/win_copy.py b/test/support/windows-integration/plugins/modules/win_copy.py new file mode 100644 index 0000000..a55f4c6 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_copy.py @@ -0,0 +1,207 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_copy +version_added: '1.9.2' +short_description: Copies files to remote locations on windows hosts +description: +- The C(win_copy) module copies a file on the local box to remote windows locations. +- For non-Windows targets, use the M(copy) module instead. +options: + content: + description: + - When used instead of C(src), sets the contents of a file directly to the + specified value. + - This is for simple values, for anything complex or with formatting please + switch to the M(template) module. + type: str + version_added: '2.3' + decrypt: + description: + - This option controls the autodecryption of source files using vault. + type: bool + default: yes + version_added: '2.5' + dest: + description: + - Remote absolute path where the file should be copied to. + - If C(src) is a directory, this must be a directory too. + - Use \ for path separators or \\ when in "double quotes". + - If C(dest) ends with \ then source or the contents of source will be + copied to the directory without renaming. + - If C(dest) is a nonexistent path, it will only be created if C(dest) ends + with "/" or "\", or C(src) is a directory. + - If C(src) and C(dest) are files and if the parent directory of C(dest) + doesn't exist, then the task will fail. + type: path + required: yes + backup: + description: + - Determine whether a backup should be created. + - When set to C(yes), create a backup file including the timestamp information + so you can get the original file back if you somehow clobbered it incorrectly. + - No backup is taken when C(remote_src=False) and multiple files are being + copied. + type: bool + default: no + version_added: '2.8' + force: + description: + - If set to C(yes), the file will only be transferred if the content + is different than destination. + - If set to C(no), the file will only be transferred if the + destination does not exist. + - If set to C(no), no checksuming of the content is performed which can + help improve performance on larger files. + type: bool + default: yes + version_added: '2.3' + local_follow: + description: + - This flag indicates that filesystem links in the source tree, if they + exist, should be followed. + type: bool + default: yes + version_added: '2.4' + remote_src: + description: + - If C(no), it will search for src at originating/master machine. + - If C(yes), it will go to the remote/target machine for the src. + type: bool + default: no + version_added: '2.3' + src: + description: + - Local path to a file to copy to the remote server; can be absolute or + relative. + - If path is a directory, it is copied (including the source folder name) + recursively to C(dest). + - If path is a directory and ends with "/", only the inside contents of + that directory are copied to the destination. Otherwise, if it does not + end with "/", the directory itself with all contents is copied. + - If path is a file and dest ends with "\", the file is copied to the + folder with the same filename. + - Required unless using C(content). + type: path +notes: +- Currently win_copy does not support copying symbolic links from both local to + remote and remote to remote. +- It is recommended that backslashes C(\) are used instead of C(/) when dealing + with remote paths. +- Because win_copy runs over WinRM, it is not a very efficient transfer + mechanism. If sending large files consider hosting them on a web service and + using M(win_get_url) instead. +seealso: +- module: assemble +- module: copy +- module: win_get_url +- module: win_robocopy +author: +- Jon Hawkesworth (@jhawkesworth) +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Copy a single file + win_copy: + src: /srv/myfiles/foo.conf + dest: C:\Temp\renamed-foo.conf + +- name: Copy a single file, but keep a backup + win_copy: + src: /srv/myfiles/foo.conf + dest: C:\Temp\renamed-foo.conf + backup: yes + +- name: Copy a single file keeping the filename + win_copy: + src: /src/myfiles/foo.conf + dest: C:\Temp\ + +- name: Copy folder to C:\Temp (results in C:\Temp\temp_files) + win_copy: + src: files/temp_files + dest: C:\Temp + +- name: Copy folder contents recursively + win_copy: + src: files/temp_files/ + dest: C:\Temp + +- name: Copy a single file where the source is on the remote host + win_copy: + src: C:\Temp\foo.txt + dest: C:\ansible\foo.txt + remote_src: yes + +- name: Copy a folder recursively where the source is on the remote host + win_copy: + src: C:\Temp + dest: C:\ansible + remote_src: yes + +- name: Set the contents of a file + win_copy: + content: abc123 + dest: C:\Temp\foo.txt + +- name: Copy a single file as another user + win_copy: + src: NuGet.config + dest: '%AppData%\NuGet\NuGet.config' + vars: + ansible_become_user: user + ansible_become_password: pass + # The tmp dir must be set when using win_copy as another user + # This ensures the become user will have permissions for the operation + # Make sure to specify a folder both the ansible_user and the become_user have access to (i.e not %TEMP% which is user specific and requires Admin) + ansible_remote_tmp: 'c:\tmp' +''' + +RETURN = r''' +backup_file: + description: Name of the backup file that was created. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +dest: + description: Destination file/path. + returned: changed + type: str + sample: C:\Temp\ +src: + description: Source file used for the copy on the target machine. + returned: changed + type: str + sample: /home/httpd/.ansible/tmp/ansible-tmp-1423796390.97-147729857856000/source +checksum: + description: SHA1 checksum of the file after running copy. + returned: success, src is a file + type: str + sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827 +size: + description: Size of the target, after execution. + returned: changed, src is a file + type: int + sample: 1220 +operation: + description: Whether a single file copy took place or a folder copy. + returned: success + type: str + sample: file_copy +original_basename: + description: Basename of the copied file. + returned: changed, src is a file + type: str + sample: foo.txt +''' diff --git a/test/support/windows-integration/plugins/modules/win_file.ps1 b/test/support/windows-integration/plugins/modules/win_file.ps1 new file mode 100644 index 0000000..5442754 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_file.ps1 @@ -0,0 +1,152 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$ErrorActionPreference = "Stop" + +$params = Parse-Args $args -supports_check_mode $true + +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -default $false +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$path = Get-AnsibleParam -obj $params -name "path" -type "path" -failifempty $true -aliases "dest","name" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -validateset "absent","directory","file","touch" + +# used in template/copy when dest is the path to a dir and source is a file +$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str" +if ((Test-Path -LiteralPath $path -PathType Container) -and ($null -ne $original_basename)) { + $path = Join-Path -Path $path -ChildPath $original_basename +} + +$result = @{ + changed = $false +} + +# Used to delete symlinks as powershell cannot delete broken symlinks +$symlink_util = @" +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Ansible.Command { + public class SymLinkHelper { + [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)] + public static extern bool DeleteFileW(string lpFileName); + + [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)] + public static extern bool RemoveDirectoryW(string lpPathName); + + public static void DeleteDirectory(string path) { + if (!RemoveDirectoryW(path)) + throw new Exception(String.Format("RemoveDirectoryW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message)); + } + + public static void DeleteFile(string path) { + if (!DeleteFileW(path)) + throw new Exception(String.Format("DeleteFileW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message)); + } + } +} +"@ +$original_tmp = $env:TMP +$env:TMP = $_remote_tmp +Add-Type -TypeDefinition $symlink_util +$env:TMP = $original_tmp + +# Used to delete directories and files with logic on handling symbolic links +function Remove-File($file, $checkmode) { + try { + if ($file.Attributes -band [System.IO.FileAttributes]::ReparsePoint) { + # Bug with powershell, if you try and delete a symbolic link that is pointing + # to an invalid path it will fail, using Win32 API to do this instead + if ($file.PSIsContainer) { + if (-not $checkmode) { + [Ansible.Command.SymLinkHelper]::DeleteDirectory($file.FullName) + } + } else { + if (-not $checkmode) { + [Ansible.Command.SymlinkHelper]::DeleteFile($file.FullName) + } + } + } elseif ($file.PSIsContainer) { + Remove-Directory -directory $file -checkmode $checkmode + } else { + Remove-Item -LiteralPath $file.FullName -Force -WhatIf:$checkmode + } + } catch [Exception] { + Fail-Json $result "Failed to delete $($file.FullName): $($_.Exception.Message)" + } +} + +function Remove-Directory($directory, $checkmode) { + foreach ($file in Get-ChildItem -LiteralPath $directory.FullName) { + Remove-File -file $file -checkmode $checkmode + } + Remove-Item -LiteralPath $directory.FullName -Force -Recurse -WhatIf:$checkmode +} + + +if ($state -eq "touch") { + if (Test-Path -LiteralPath $path) { + if (-not $check_mode) { + (Get-ChildItem -LiteralPath $path).LastWriteTime = Get-Date + } + $result.changed = $true + } else { + Write-Output $null | Out-File -LiteralPath $path -Encoding ASCII -WhatIf:$check_mode + $result.changed = $true + } +} + +if (Test-Path -LiteralPath $path) { + $fileinfo = Get-Item -LiteralPath $path -Force + if ($state -eq "absent") { + Remove-File -file $fileinfo -checkmode $check_mode + $result.changed = $true + } else { + if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a directory" + } + + if ($state -eq "file" -and $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a file" + } + } + +} else { + + # If state is not supplied, test the $path to see if it looks like + # a file or a folder and set state to file or folder + if ($null -eq $state) { + $basename = Split-Path -Path $path -Leaf + if ($basename.length -gt 0) { + $state = "file" + } else { + $state = "directory" + } + } + + if ($state -eq "directory") { + try { + New-Item -Path $path -ItemType Directory -WhatIf:$check_mode | Out-Null + } catch { + if ($_.CategoryInfo.Category -eq "ResourceExists") { + $fileinfo = Get-Item -LiteralPath $_.CategoryInfo.TargetName + if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a directory" + } + } else { + Fail-Json $result $_.Exception.Message + } + } + $result.changed = $true + } elseif ($state -eq "file") { + Fail-Json $result "path $path will not be created" + } + +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_file.py b/test/support/windows-integration/plugins/modules/win_file.py new file mode 100644 index 0000000..2814957 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_file.py @@ -0,0 +1,70 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_file +version_added: "1.9.2" +short_description: Creates, touches or removes files or directories +description: + - Creates (empty) files, updates file modification stamps of existing files, + and can create or remove directories. + - Unlike M(file), does not modify ownership, permissions or manipulate links. + - For non-Windows targets, use the M(file) module instead. +options: + path: + description: + - Path to the file being managed. + required: yes + type: path + aliases: [ dest, name ] + state: + description: + - If C(directory), all immediate subdirectories will be created if they + do not exist. + - If C(file), the file will NOT be created if it does not exist, see the M(copy) + or M(template) module if you want that behavior. + - If C(absent), directories will be recursively deleted, and files will be removed. + - If C(touch), an empty file will be created if the C(path) does not + exist, while an existing file or directory will receive updated file access and + modification times (similar to the way C(touch) works from the command line). + type: str + choices: [ absent, directory, file, touch ] +seealso: +- module: file +- module: win_acl +- module: win_acl_inheritance +- module: win_owner +- module: win_stat +author: +- Jon Hawkesworth (@jhawkesworth) +''' + +EXAMPLES = r''' +- name: Touch a file (creates if not present, updates modification time if present) + win_file: + path: C:\Temp\foo.conf + state: touch + +- name: Remove a file, if present + win_file: + path: C:\Temp\foo.conf + state: absent + +- name: Create directory structure + win_file: + path: C:\Temp\folder\subfolder + state: directory + +- name: Remove directory structure + win_file: + path: C:\Temp + state: absent +''' diff --git a/test/support/windows-integration/plugins/modules/win_get_url.ps1 b/test/support/windows-integration/plugins/modules/win_get_url.ps1 new file mode 100644 index 0000000..1d8dd5a --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_get_url.ps1 @@ -0,0 +1,274 @@ +#!powershell + +# Copyright: (c) 2015, Paul Durivage <paul.durivage@rackspace.com> +# Copyright: (c) 2015, Tal Auslander <tal@cloudshare.com> +# Copyright: (c) 2017, Dag Wieers <dag@wieers.com> +# Copyright: (c) 2019, Viktor Utkin <viktor_utkin@epam.com> +# Copyright: (c) 2019, Uladzimir Klybik <uladzimir_klybik@epam.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.FileUtil +#Requires -Module Ansible.ModuleUtils.WebRequest + +$spec = @{ + options = @{ + url = @{ type="str"; required=$true } + dest = @{ type='path'; required=$true } + force = @{ type='bool'; default=$true } + checksum = @{ type='str' } + checksum_algorithm = @{ type='str'; default='sha1'; choices = @("md5", "sha1", "sha256", "sha384", "sha512") } + checksum_url = @{ type='str' } + + # Defined for the alias backwards compatibility, remove once aliases are removed + url_username = @{ + aliases = @("user", "username") + deprecated_aliases = @( + @{ name = "user"; version = "2.14" }, + @{ name = "username"; version = "2.14" } + ) + } + url_password = @{ + aliases = @("password") + deprecated_aliases = @( + @{ name = "password"; version = "2.14" } + ) + } + } + mutually_exclusive = @( + ,@('checksum', 'checksum_url') + ) + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWebRequestSpec)) + +$url = $module.Params.url +$dest = $module.Params.dest +$force = $module.Params.force +$checksum = $module.Params.checksum +$checksum_algorithm = $module.Params.checksum_algorithm +$checksum_url = $module.Params.checksum_url + +$module.Result.elapsed = 0 +$module.Result.url = $url + +Function Get-ChecksumFromUri { + param( + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Parameter(Mandatory=$true)][Uri]$Uri, + [Uri]$SourceUri + ) + + $script = { + param($Response, $Stream) + + $read_stream = New-Object -TypeName System.IO.StreamReader -ArgumentList $Stream + $web_checksum = $read_stream.ReadToEnd() + $basename = (Split-Path -Path $SourceUri.LocalPath -Leaf) + $basename = [regex]::Escape($basename) + $web_checksum_str = $web_checksum -split '\r?\n' | Select-String -Pattern $("\s+\.?\/?\\?" + $basename + "\s*$") + if (-not $web_checksum_str) { + $Module.FailJson("Checksum record not found for file name '$basename' in file from url: '$Uri'") + } + + $web_checksum_str_splitted = $web_checksum_str[0].ToString().split(" ", 2) + $hash_from_file = $web_checksum_str_splitted[0].Trim() + # Remove any non-alphanumeric characters + $hash_from_file = $hash_from_file -replace '\W+', '' + + Write-Output -InputObject $hash_from_file + } + $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module + + try { + Invoke-WithWebRequest -Module $Module -Request $web_request -Script $script + } catch { + $Module.FailJson("Error when getting the remote checksum from '$Uri'. $($_.Exception.Message)", $_) + } +} + +Function Compare-ModifiedFile { + <# + .SYNOPSIS + Compares the remote URI resource against the local Dest resource. Will + return true if the LastWriteTime/LastModificationDate of the remote is + newer than the local resource date. + #> + param( + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Parameter(Mandatory=$true)][Uri]$Uri, + [Parameter(Mandatory=$true)][String]$Dest + ) + + $dest_last_mod = (Get-AnsibleItem -Path $Dest).LastWriteTimeUtc + + # If the URI is a file we don't need to go through the whole WebRequest + if ($Uri.IsFile) { + $src_last_mod = (Get-AnsibleItem -Path $Uri.AbsolutePath).LastWriteTimeUtc + } else { + $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module + $web_request.Method = switch ($web_request.GetType().Name) { + FtpWebRequest { [System.Net.WebRequestMethods+Ftp]::GetDateTimestamp } + HttpWebRequest { [System.Net.WebRequestMethods+Http]::Head } + } + $script = { param($Response, $Stream); $Response.LastModified } + + try { + $src_last_mod = Invoke-WithWebRequest -Module $Module -Request $web_request -Script $script + } catch { + $Module.FailJson("Error when requesting 'Last-Modified' date from '$Uri'. $($_.Exception.Message)", $_) + } + } + + # Return $true if the Uri LastModification date is newer than the Dest LastModification date + ((Get-Date -Date $src_last_mod).ToUniversalTime() -gt $dest_last_mod) +} + +Function Get-Checksum { + param( + [Parameter(Mandatory=$true)][String]$Path, + [String]$Algorithm = "sha1" + ) + + switch ($Algorithm) { + 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider } + 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider } + 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider } + 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider } + 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider } + } + + $fs = [System.IO.File]::Open($Path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, + [System.IO.FileShare]::ReadWrite) + try { + $hash = [System.BitConverter]::ToString($sp.ComputeHash($fs)).Replace("-", "").ToLower() + } finally { + $fs.Dispose() + } + return $hash +} + +Function Invoke-DownloadFile { + param( + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Parameter(Mandatory=$true)][Uri]$Uri, + [Parameter(Mandatory=$true)][String]$Dest, + [String]$Checksum, + [String]$ChecksumAlgorithm + ) + + # Check $dest parent folder exists before attempting download, which avoids unhelpful generic error message. + $dest_parent = Split-Path -LiteralPath $Dest + if (-not (Test-Path -LiteralPath $dest_parent -PathType Container)) { + $module.FailJson("The path '$dest_parent' does not exist for destination '$Dest', or is not visible to the current user. Ensure download destination folder exists (perhaps using win_file state=directory) before win_get_url runs.") + } + + $download_script = { + param($Response, $Stream) + + # Download the file to a temporary directory so we can compare it + $tmp_dest = Join-Path -Path $Module.Tmpdir -ChildPath ([System.IO.Path]::GetRandomFileName()) + $fs = [System.IO.File]::Create($tmp_dest) + try { + $Stream.CopyTo($fs) + $fs.Flush() + } finally { + $fs.Dispose() + } + $tmp_checksum = Get-Checksum -Path $tmp_dest -Algorithm $ChecksumAlgorithm + $Module.Result.checksum_src = $tmp_checksum + + # If the checksum has been set, verify the checksum of the remote against the input checksum. + if ($Checksum -and $Checksum -ne $tmp_checksum) { + $Module.FailJson(("The checksum for {0} did not match '{1}', it was '{2}'" -f $Uri, $Checksum, $tmp_checksum)) + } + + $download = $true + if (Test-Path -LiteralPath $Dest) { + # Validate the remote checksum against the existing downloaded file + $dest_checksum = Get-Checksum -Path $Dest -Algorithm $ChecksumAlgorithm + + # If we don't need to download anything, save the dest checksum so we don't waste time calculating it + # again at the end of the script + if ($dest_checksum -eq $tmp_checksum) { + $download = $false + $Module.Result.checksum_dest = $dest_checksum + $Module.Result.size = (Get-AnsibleItem -Path $Dest).Length + } + } + + if ($download) { + Copy-Item -LiteralPath $tmp_dest -Destination $Dest -Force -WhatIf:$Module.CheckMode > $null + $Module.Result.changed = $true + } + } + $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module + + try { + Invoke-WithWebRequest -Module $Module -Request $web_request -Script $download_script + } catch { + $Module.FailJson("Error downloading '$Uri' to '$Dest': $($_.Exception.Message)", $_) + } +} + +# Use last part of url for dest file name if a directory is supplied for $dest +if (Test-Path -LiteralPath $dest -PathType Container) { + $uri = [System.Uri]$url + $basename = Split-Path -Path $uri.LocalPath -Leaf + if ($uri.LocalPath -and $uri.LocalPath -ne '/' -and $basename) { + $url_basename = Split-Path -Path $uri.LocalPath -Leaf + $dest = Join-Path -Path $dest -ChildPath $url_basename + } else { + $dest = Join-Path -Path $dest -ChildPath $uri.Host + } + + # Ensure we have a string instead of a PS object to avoid serialization issues + $dest = $dest.ToString() +} elseif (([System.IO.Path]::GetFileName($dest)) -eq '') { + # We have a trailing path separator + $module.FailJson("The destination path '$dest' does not exist, or is not visible to the current user. Ensure download destination folder exists (perhaps using win_file state=directory) before win_get_url runs.") +} + +$module.Result.dest = $dest + +if ($checksum) { + $checksum = $checksum.Trim().ToLower() +} +if ($checksum_algorithm) { + $checksum_algorithm = $checksum_algorithm.Trim().ToLower() +} +if ($checksum_url) { + $checksum_url = $checksum_url.Trim() +} + +# Check for case $checksum variable contain url. If yes, get file data from url and replace original value in $checksum +if ($checksum_url) { + $checksum_uri = [System.Uri]$checksum_url + if ($checksum_uri.Scheme -notin @("file", "ftp", "http", "https")) { + $module.FailJson("Unsupported 'checksum_url' value for '$dest': '$checksum_url'") + } + + $checksum = Get-ChecksumFromUri -Module $Module -Uri $checksum_uri -SourceUri $url +} + +if ($force -or -not (Test-Path -LiteralPath $dest)) { + # force=yes or dest does not exist, download the file + # Note: Invoke-DownloadFile will compare the checksums internally if dest exists + Invoke-DownloadFile -Module $module -Uri $url -Dest $dest -Checksum $checksum ` + -ChecksumAlgorithm $checksum_algorithm +} else { + # force=no, we want to check the last modified dates and only download if they don't match + $is_modified = Compare-ModifiedFile -Module $module -Uri $url -Dest $dest + if ($is_modified) { + Invoke-DownloadFile -Module $module -Uri $url -Dest $dest -Checksum $checksum ` + -ChecksumAlgorithm $checksum_algorithm + } +} + +if ((-not $module.Result.ContainsKey("checksum_dest")) -and (Test-Path -LiteralPath $dest)) { + # Calculate the dest file checksum if it hasn't already been done + $module.Result.checksum_dest = Get-Checksum -Path $dest -Algorithm $checksum_algorithm + $module.Result.size = (Get-AnsibleItem -Path $dest).Length +} + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_get_url.py b/test/support/windows-integration/plugins/modules/win_get_url.py new file mode 100644 index 0000000..ef5b5f9 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_get_url.py @@ -0,0 +1,215 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com>, and others +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# This is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_get_url +version_added: "1.7" +short_description: Downloads file from HTTP, HTTPS, or FTP to node +description: +- Downloads files from HTTP, HTTPS, or FTP to the remote server. +- The remote server I(must) have direct access to the remote resource. +- For non-Windows targets, use the M(get_url) module instead. +options: + url: + description: + - The full URL of a file to download. + type: str + required: yes + dest: + description: + - The location to save the file at the URL. + - Be sure to include a filename and extension as appropriate. + type: path + required: yes + force: + description: + - If C(yes), will download the file every time and replace the file if the contents change. If C(no), will only + download the file if it does not exist or the remote file has been + modified more recently than the local file. + - This works by sending an http HEAD request to retrieve last modified + time of the requested resource, so for this to work, the remote web + server must support HEAD requests. + type: bool + default: yes + version_added: "2.0" + checksum: + description: + - If a I(checksum) is passed to this parameter, the digest of the + destination file will be calculated after it is downloaded to ensure + its integrity and verify that the transfer completed successfully. + - This option cannot be set with I(checksum_url). + type: str + version_added: "2.8" + checksum_algorithm: + description: + - Specifies the hashing algorithm used when calculating the checksum of + the remote and destination file. + type: str + choices: + - md5 + - sha1 + - sha256 + - sha384 + - sha512 + default: sha1 + version_added: "2.8" + checksum_url: + description: + - Specifies a URL that contains the checksum values for the resource at + I(url). + - Like C(checksum), this is used to verify the integrity of the remote + transfer. + - This option cannot be set with I(checksum). + type: str + version_added: "2.8" + url_username: + description: + - The username to use for authentication. + - The aliases I(user) and I(username) are deprecated and will be removed in + Ansible 2.14. + aliases: + - user + - username + url_password: + description: + - The password for I(url_username). + - The alias I(password) is deprecated and will be removed in Ansible 2.14. + aliases: + - password + proxy_url: + version_added: "2.0" + proxy_username: + version_added: "2.0" + proxy_password: + version_added: "2.0" + headers: + version_added: "2.4" + use_proxy: + version_added: "2.4" + follow_redirects: + version_added: "2.9" + maximum_redirection: + version_added: "2.9" + client_cert: + version_added: "2.9" + client_cert_password: + version_added: "2.9" + method: + description: + - This option is not for use with C(win_get_url) and should be ignored. + version_added: "2.9" +notes: +- If your URL includes an escaped slash character (%2F) this module will convert it to a real slash. + This is a result of the behaviour of the System.Uri class as described in + L(the documentation,https://docs.microsoft.com/en-us/dotnet/framework/configure-apps/file-schema/network/schemesettings-element-uri-settings#remarks). +- Since Ansible 2.8, the module will skip reporting a change if the remote + checksum is the same as the local local even when C(force=yes). This is to + better align with M(get_url). +extends_documentation_fragment: +- url_windows +seealso: +- module: get_url +- module: uri +- module: win_uri +author: +- Paul Durivage (@angstwad) +- Takeshi Kuramochi (@tksarah) +''' + +EXAMPLES = r''' +- name: Download earthrise.jpg to specified path + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\Users\RandomUser\earthrise.jpg + +- name: Download earthrise.jpg to specified path only if modified + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\Users\RandomUser\earthrise.jpg + force: no + +- name: Download earthrise.jpg to specified path through a proxy server. + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\Users\RandomUser\earthrise.jpg + proxy_url: http://10.0.0.1:8080 + proxy_username: username + proxy_password: password + +- name: Download file from FTP with authentication + win_get_url: + url: ftp://server/file.txt + dest: '%TEMP%\ftp-file.txt' + url_username: ftp-user + url_password: ftp-password + +- name: Download src with sha256 checksum url + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\temp\earthrise.jpg + checksum_url: http://www.example.com/sha256sum.txt + checksum_algorithm: sha256 + force: True + +- name: Download src with sha256 checksum url + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\temp\earthrise.jpg + checksum: a97e6837f60cec6da4491bab387296bbcd72bdba + checksum_algorithm: sha1 + force: True +''' + +RETURN = r''' +dest: + description: destination file/path + returned: always + type: str + sample: C:\Users\RandomUser\earthrise.jpg +checksum_dest: + description: <algorithm> checksum of the file after the download + returned: success and dest has been downloaded + type: str + sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827 +checksum_src: + description: <algorithm> checksum of the remote resource + returned: force=yes or dest did not exist + type: str + sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827 +elapsed: + description: The elapsed seconds between the start of poll and the end of the module. + returned: always + type: float + sample: 2.1406487 +size: + description: size of the dest file + returned: success + type: int + sample: 1220 +url: + description: requested url + returned: always + type: str + sample: http://www.example.com/earthrise.jpg +msg: + description: Error message, or HTTP status message from web-server + returned: always + type: str + sample: OK +status_code: + description: HTTP status code + returned: always + type: int + sample: 200 +''' diff --git a/test/support/windows-integration/plugins/modules/win_lineinfile.ps1 b/test/support/windows-integration/plugins/modules/win_lineinfile.ps1 new file mode 100644 index 0000000..38dd8b8 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_lineinfile.ps1 @@ -0,0 +1,450 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.Backup + +function WriteLines($outlines, $path, $linesep, $encodingobj, $validate, $check_mode) { + Try { + $temppath = [System.IO.Path]::GetTempFileName(); + } + Catch { + Fail-Json @{} "Cannot create temporary file! ($($_.Exception.Message))"; + } + $joined = $outlines -join $linesep; + [System.IO.File]::WriteAllText($temppath, $joined, $encodingobj); + + If ($validate) { + + If (-not ($validate -like "*%s*")) { + Fail-Json @{} "validate must contain %s: $validate"; + } + + $validate = $validate.Replace("%s", $temppath); + + $parts = [System.Collections.ArrayList] $validate.Split(" "); + $cmdname = $parts[0]; + + $cmdargs = $validate.Substring($cmdname.Length + 1); + + $process = [Diagnostics.Process]::Start($cmdname, $cmdargs); + $process.WaitForExit(); + + If ($process.ExitCode -ne 0) { + [string] $output = $process.StandardOutput.ReadToEnd(); + [string] $error = $process.StandardError.ReadToEnd(); + Remove-Item $temppath -force; + Fail-Json @{} "failed to validate $cmdname $cmdargs with error: $output $error"; + } + + } + + # Commit changes to the path + $cleanpath = $path.Replace("/", "\"); + Try { + Copy-Item -Path $temppath -Destination $cleanpath -Force -WhatIf:$check_mode; + } + Catch { + Fail-Json @{} "Cannot write to: $cleanpath ($($_.Exception.Message))"; + } + + Try { + Remove-Item -Path $temppath -Force -WhatIf:$check_mode; + } + Catch { + Fail-Json @{} "Cannot remove temporary file: $temppath ($($_.Exception.Message))"; + } + + return $joined; + +} + + +# Implement the functionality for state == 'present' +function Present($path, $regex, $line, $insertafter, $insertbefore, $create, $backup, $backrefs, $validate, $encodingobj, $linesep, $check_mode, $diff_support) { + + # Note that we have to clean up the path because ansible wants to treat / and \ as + # interchangeable in windows pathnames, but .NET framework internals do not support that. + $cleanpath = $path.Replace("/", "\"); + + # Check if path exists. If it does not exist, either create it if create == "yes" + # was specified or fail with a reasonable error message. + If (-not (Test-Path -LiteralPath $path)) { + If (-not $create) { + Fail-Json @{} "Path $path does not exist !"; + } + # Create new empty file, using the specified encoding to write correct BOM + [System.IO.File]::WriteAllLines($cleanpath, "", $encodingobj); + } + + # Initialize result information + $result = @{ + backup = ""; + changed = $false; + msg = ""; + } + + # Read the dest file lines using the indicated encoding into a mutable ArrayList. + $before = [System.IO.File]::ReadAllLines($cleanpath, $encodingobj) + If ($null -eq $before) { + $lines = New-Object System.Collections.ArrayList; + } + Else { + $lines = [System.Collections.ArrayList] $before; + } + + if ($diff_support) { + $result.diff = @{ + before = $before -join $linesep; + } + } + + # Compile the regex specified, if provided + $mre = $null; + If ($regex) { + $mre = New-Object Regex $regex, 'Compiled'; + } + + # Compile the regex for insertafter or insertbefore, if provided + $insre = $null; + If ($insertafter -and $insertafter -ne "BOF" -and $insertafter -ne "EOF") { + $insre = New-Object Regex $insertafter, 'Compiled'; + } + ElseIf ($insertbefore -and $insertbefore -ne "BOF") { + $insre = New-Object Regex $insertbefore, 'Compiled'; + } + + # index[0] is the line num where regex has been found + # index[1] is the line num where insertafter/insertbefore has been found + $index = -1, -1; + $lineno = 0; + + # The latest match object and matched line + $matched_line = ""; + + # Iterate through the lines in the file looking for matches + Foreach ($cur_line in $lines) { + If ($regex) { + $m = $mre.Match($cur_line); + $match_found = $m.Success; + If ($match_found) { + $matched_line = $cur_line; + } + } + Else { + $match_found = $line -ceq $cur_line; + } + If ($match_found) { + $index[0] = $lineno; + } + ElseIf ($insre -and $insre.Match($cur_line).Success) { + If ($insertafter) { + $index[1] = $lineno + 1; + } + If ($insertbefore) { + $index[1] = $lineno; + } + } + $lineno = $lineno + 1; + } + + If ($index[0] -ne -1) { + If ($backrefs) { + $new_line = [regex]::Replace($matched_line, $regex, $line); + } + Else { + $new_line = $line; + } + If ($lines[$index[0]] -cne $new_line) { + $lines[$index[0]] = $new_line; + $result.changed = $true; + $result.msg = "line replaced"; + } + } + ElseIf ($backrefs) { + # No matches - no-op + } + ElseIf ($insertbefore -eq "BOF" -or $insertafter -eq "BOF") { + $lines.Insert(0, $line); + $result.changed = $true; + $result.msg = "line added"; + } + ElseIf ($insertafter -eq "EOF" -or $index[1] -eq -1) { + $lines.Add($line) > $null; + $result.changed = $true; + $result.msg = "line added"; + } + Else { + $lines.Insert($index[1], $line); + $result.changed = $true; + $result.msg = "line added"; + } + + # Write changes to the path if changes were made + If ($result.changed) { + + # Write backup file if backup == "yes" + If ($backup) { + $result.backup_file = Backup-File -path $path -WhatIf:$check_mode + # Ensure backward compatibility (deprecate in future) + $result.backup = $result.backup_file + } + + $writelines_params = @{ + outlines = $lines + path = $path + linesep = $linesep + encodingobj = $encodingobj + validate = $validate + check_mode = $check_mode + } + $after = WriteLines @writelines_params; + + if ($diff_support) { + $result.diff.after = $after; + } + } + + $result.encoding = $encodingobj.WebName; + + Exit-Json $result; +} + + +# Implement the functionality for state == 'absent' +function Absent($path, $regex, $line, $backup, $validate, $encodingobj, $linesep, $check_mode, $diff_support) { + + # Check if path exists. If it does not exist, fail with a reasonable error message. + If (-not (Test-Path -LiteralPath $path)) { + Fail-Json @{} "Path $path does not exist !"; + } + + # Initialize result information + $result = @{ + backup = ""; + changed = $false; + msg = ""; + } + + # Read the dest file lines using the indicated encoding into a mutable ArrayList. Note + # that we have to clean up the path because ansible wants to treat / and \ as + # interchangeable in windows pathnames, but .NET framework internals do not support that. + $cleanpath = $path.Replace("/", "\"); + $before = [System.IO.File]::ReadAllLines($cleanpath, $encodingobj); + If ($null -eq $before) { + $lines = New-Object System.Collections.ArrayList; + } + Else { + $lines = [System.Collections.ArrayList] $before; + } + + if ($diff_support) { + $result.diff = @{ + before = $before -join $linesep; + } + } + + # Compile the regex specified, if provided + $cre = $null; + If ($regex) { + $cre = New-Object Regex $regex, 'Compiled'; + } + + $found = New-Object System.Collections.ArrayList; + $left = New-Object System.Collections.ArrayList; + + Foreach ($cur_line in $lines) { + If ($regex) { + $m = $cre.Match($cur_line); + $match_found = $m.Success; + } + Else { + $match_found = $line -ceq $cur_line; + } + If ($match_found) { + $found.Add($cur_line) > $null; + $result.changed = $true; + } + Else { + $left.Add($cur_line) > $null; + } + } + + # Write changes to the path if changes were made + If ($result.changed) { + + # Write backup file if backup == "yes" + If ($backup) { + $result.backup_file = Backup-File -path $path -WhatIf:$check_mode + # Ensure backward compatibility (deprecate in future) + $result.backup = $result.backup_file + } + + $writelines_params = @{ + outlines = $left + path = $path + linesep = $linesep + encodingobj = $encodingobj + validate = $validate + check_mode = $check_mode + } + $after = WriteLines @writelines_params; + + if ($diff_support) { + $result.diff.after = $after; + } + } + + $result.encoding = $encodingobj.WebName; + $result.found = $found.Count; + $result.msg = "$($found.Count) line(s) removed"; + + Exit-Json $result; +} + + +# Parse the parameters file dropped by the Ansible machinery +$params = Parse-Args $args -supports_check_mode $true; +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false; +$diff_support = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false; + +# Initialize defaults for input parameters. +$path = Get-AnsibleParam -obj $params -name "path" -type "path" -failifempty $true -aliases "dest","destfile","name"; +$regex = Get-AnsibleParam -obj $params -name "regex" -type "str" -aliases "regexp"; +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent"; +$line = Get-AnsibleParam -obj $params -name "line" -type "str"; +$backrefs = Get-AnsibleParam -obj $params -name "backrefs" -type "bool" -default $false; +$insertafter = Get-AnsibleParam -obj $params -name "insertafter" -type "str"; +$insertbefore = Get-AnsibleParam -obj $params -name "insertbefore" -type "str"; +$create = Get-AnsibleParam -obj $params -name "create" -type "bool" -default $false; +$backup = Get-AnsibleParam -obj $params -name "backup" -type "bool" -default $false; +$validate = Get-AnsibleParam -obj $params -name "validate" -type "str"; +$encoding = Get-AnsibleParam -obj $params -name "encoding" -type "str" -default "auto"; +$newline = Get-AnsibleParam -obj $params -name "newline" -type "str" -default "windows" -validateset "unix","windows"; + +# Fail if the path is not a file +If (Test-Path -LiteralPath $path -PathType "container") { + Fail-Json @{} "Path $path is a directory"; +} + +# Default to windows line separator - probably most common +$linesep = "`r`n" +If ($newline -eq "unix") { + $linesep = "`n"; +} + +# Figure out the proper encoding to use for reading / writing the target file. + +# The default encoding is UTF-8 without BOM +$encodingobj = [System.Text.UTF8Encoding] $false; + +# If an explicit encoding is specified, use that instead +If ($encoding -ne "auto") { + $encodingobj = [System.Text.Encoding]::GetEncoding($encoding); +} + +# Otherwise see if we can determine the current encoding of the target file. +# If the file doesn't exist yet (create == 'yes') we use the default or +# explicitly specified encoding set above. +ElseIf (Test-Path -LiteralPath $path) { + + # Get a sorted list of encodings with preambles, longest first + $max_preamble_len = 0; + $sortedlist = New-Object System.Collections.SortedList; + Foreach ($encodinginfo in [System.Text.Encoding]::GetEncodings()) { + $encoding = $encodinginfo.GetEncoding(); + $plen = $encoding.GetPreamble().Length; + If ($plen -gt $max_preamble_len) { + $max_preamble_len = $plen; + } + If ($plen -gt 0) { + $sortedlist.Add(-($plen * 1000000 + $encoding.CodePage), $encoding) > $null; + } + } + + # Get the first N bytes from the file, where N is the max preamble length we saw + [Byte[]]$bom = Get-Content -Encoding Byte -ReadCount $max_preamble_len -TotalCount $max_preamble_len -LiteralPath $path; + + # Iterate through the sorted encodings, looking for a full match. + $found = $false; + Foreach ($encoding in $sortedlist.GetValueList()) { + $preamble = $encoding.GetPreamble(); + If ($preamble -and $bom) { + Foreach ($i in 0..($preamble.Length - 1)) { + If ($i -ge $bom.Length) { + break; + } + If ($preamble[$i] -ne $bom[$i]) { + break; + } + ElseIf ($i + 1 -eq $preamble.Length) { + $encodingobj = $encoding; + $found = $true; + } + } + If ($found) { + break; + } + } + } +} + + +# Main dispatch - based on the value of 'state', perform argument validation and +# call the appropriate handler function. +If ($state -eq "present") { + + If ($backrefs -and -not $regex) { + Fail-Json @{} "regexp= is required with backrefs=true"; + } + + If (-not $line) { + Fail-Json @{} "line= is required with state=present"; + } + + If ($insertbefore -and $insertafter) { + Add-Warning $result "Both insertbefore and insertafter parameters found, ignoring `"insertafter=$insertafter`"" + } + + If (-not $insertbefore -and -not $insertafter) { + $insertafter = "EOF"; + } + + $present_params = @{ + path = $path + regex = $regex + line = $line + insertafter = $insertafter + insertbefore = $insertbefore + create = $create + backup = $backup + backrefs = $backrefs + validate = $validate + encodingobj = $encodingobj + linesep = $linesep + check_mode = $check_mode + diff_support = $diff_support + } + Present @present_params; + +} +ElseIf ($state -eq "absent") { + + If (-not $regex -and -not $line) { + Fail-Json @{} "one of line= or regexp= is required with state=absent"; + } + + $absent_params = @{ + path = $path + regex = $regex + line = $line + backup = $backup + validate = $validate + encodingobj = $encodingobj + linesep = $linesep + check_mode = $check_mode + diff_support = $diff_support + } + Absent @absent_params; +} diff --git a/test/support/windows-integration/plugins/modules/win_lineinfile.py b/test/support/windows-integration/plugins/modules/win_lineinfile.py new file mode 100644 index 0000000..f4fb7f5 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_lineinfile.py @@ -0,0 +1,180 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_lineinfile +short_description: Ensure a particular line is in a file, or replace an existing line using a back-referenced regular expression +description: + - This module will search a file for a line, and ensure that it is present or absent. + - This is primarily useful when you want to change a single line in a file only. +version_added: "2.0" +options: + path: + description: + - The path of the file to modify. + - Note that the Windows path delimiter C(\) must be escaped as C(\\) when the line is double quoted. + - Before Ansible 2.3 this option was only usable as I(dest), I(destfile) and I(name). + type: path + required: yes + aliases: [ dest, destfile, name ] + backup: + description: + - Determine whether a backup should be created. + - When set to C(yes), create a backup file including the timestamp information + so you can get the original file back if you somehow clobbered it incorrectly. + type: bool + default: no + regex: + description: + - The regular expression to look for in every line of the file. For C(state=present), the pattern to replace if found; only the last line found + will be replaced. For C(state=absent), the pattern of the line to remove. Uses .NET compatible regular expressions; + see U(https://msdn.microsoft.com/en-us/library/hs600312%28v=vs.110%29.aspx). + aliases: [ "regexp" ] + state: + description: + - Whether the line should be there or not. + type: str + choices: [ absent, present ] + default: present + line: + description: + - Required for C(state=present). The line to insert/replace into the file. If C(backrefs) is set, may contain backreferences that will get + expanded with the C(regexp) capture groups if the regexp matches. + - Be aware that the line is processed first on the controller and thus is dependent on yaml quoting rules. Any double quoted line + will have control characters, such as '\r\n', expanded. To print such characters literally, use single or no quotes. + type: str + backrefs: + description: + - Used with C(state=present). If set, line can contain backreferences (both positional and named) that will get populated if the C(regexp) + matches. This flag changes the operation of the module slightly; C(insertbefore) and C(insertafter) will be ignored, and if the C(regexp) + doesn't match anywhere in the file, the file will be left unchanged. + - If the C(regexp) does match, the last matching line will be replaced by the expanded line parameter. + type: bool + default: no + insertafter: + description: + - Used with C(state=present). If specified, the line will be inserted after the last match of specified regular expression. A special value is + available; C(EOF) for inserting the line at the end of the file. + - If specified regular expression has no matches, EOF will be used instead. May not be used with C(backrefs). + type: str + choices: [ EOF, '*regex*' ] + default: EOF + insertbefore: + description: + - Used with C(state=present). If specified, the line will be inserted before the last match of specified regular expression. A value is available; + C(BOF) for inserting the line at the beginning of the file. + - If specified regular expression has no matches, the line will be inserted at the end of the file. May not be used with C(backrefs). + type: str + choices: [ BOF, '*regex*' ] + create: + description: + - Used with C(state=present). If specified, the file will be created if it does not already exist. By default it will fail if the file is missing. + type: bool + default: no + validate: + description: + - Validation to run before copying into place. Use %s in the command to indicate the current file to validate. + - The command is passed securely so shell features like expansion and pipes won't work. + type: str + encoding: + description: + - Specifies the encoding of the source text file to operate on (and thus what the output encoding will be). The default of C(auto) will cause + the module to auto-detect the encoding of the source file and ensure that the modified file is written with the same encoding. + - An explicit encoding can be passed as a string that is a valid value to pass to the .NET framework System.Text.Encoding.GetEncoding() method - + see U(https://msdn.microsoft.com/en-us/library/system.text.encoding%28v=vs.110%29.aspx). + - This is mostly useful with C(create=yes) if you want to create a new file with a specific encoding. If C(create=yes) is specified without a + specific encoding, the default encoding (UTF-8, no BOM) will be used. + type: str + default: auto + newline: + description: + - Specifies the line separator style to use for the modified file. This defaults to the windows line separator (C(\r\n)). Note that the indicated + line separator will be used for file output regardless of the original line separator that appears in the input file. + type: str + choices: [ unix, windows ] + default: windows +notes: + - As of Ansible 2.3, the I(dest) option has been changed to I(path) as default, but I(dest) still works as well. +seealso: +- module: assemble +- module: lineinfile +author: +- Brian Lloyd (@brianlloyd) +''' + +EXAMPLES = r''' +# Before Ansible 2.3, option 'dest', 'destfile' or 'name' was used instead of 'path' +- name: Insert path without converting \r\n + win_lineinfile: + path: c:\file.txt + line: c:\return\new + +- win_lineinfile: + path: C:\Temp\example.conf + regex: '^name=' + line: 'name=JohnDoe' + +- win_lineinfile: + path: C:\Temp\example.conf + regex: '^name=' + state: absent + +- win_lineinfile: + path: C:\Temp\example.conf + regex: '^127\.0\.0\.1' + line: '127.0.0.1 localhost' + +- win_lineinfile: + path: C:\Temp\httpd.conf + regex: '^Listen ' + insertafter: '^#Listen ' + line: Listen 8080 + +- win_lineinfile: + path: C:\Temp\services + regex: '^# port for http' + insertbefore: '^www.*80/tcp' + line: '# port for http by default' + +- name: Create file if it doesn't exist with a specific encoding + win_lineinfile: + path: C:\Temp\utf16.txt + create: yes + encoding: utf-16 + line: This is a utf-16 encoded file + +- name: Add a line to a file and ensure the resulting file uses unix line separators + win_lineinfile: + path: C:\Temp\testfile.txt + line: Line added to file + newline: unix + +- name: Update a line using backrefs + win_lineinfile: + path: C:\Temp\example.conf + backrefs: yes + regex: '(^name=)' + line: '$1JohnDoe' +''' + +RETURN = r''' +backup: + description: + - Name of the backup file that was created. + - This is now deprecated, use C(backup_file) instead. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +backup_file: + description: Name of the backup file that was created. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +''' diff --git a/test/support/windows-integration/plugins/modules/win_ping.ps1 b/test/support/windows-integration/plugins/modules/win_ping.ps1 new file mode 100644 index 0000000..c848b91 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_ping.ps1 @@ -0,0 +1,21 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic + +$spec = @{ + options = @{ + data = @{ type = "str"; default = "pong" } + } + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) +$data = $module.Params.data + +if ($data -eq "crash") { + throw "boom" +} + +$module.Result.ping = $data +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_ping.py b/test/support/windows-integration/plugins/modules/win_ping.py new file mode 100644 index 0000000..6d35f37 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_ping.py @@ -0,0 +1,55 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>, and others +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_ping +version_added: "1.7" +short_description: A windows version of the classic ping module +description: + - Checks management connectivity of a windows host. + - This is NOT ICMP ping, this is just a trivial test module. + - For non-Windows targets, use the M(ping) module instead. + - For Network targets, use the M(net_ping) module instead. +options: + data: + description: + - Alternate data to return instead of 'pong'. + - If this parameter is set to C(crash), the module will cause an exception. + type: str + default: pong +seealso: +- module: ping +author: +- Chris Church (@cchurch) +''' + +EXAMPLES = r''' +# Test connectivity to a windows host +# ansible winserver -m win_ping + +- name: Example from an Ansible Playbook + win_ping: + +- name: Induce an exception to see what happens + win_ping: + data: crash +''' + +RETURN = r''' +ping: + description: Value provided with the data parameter. + returned: success + type: str + sample: pong +''' diff --git a/test/support/windows-integration/plugins/modules/win_reboot.py b/test/support/windows-integration/plugins/modules/win_reboot.py new file mode 100644 index 0000000..1431804 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_reboot.py @@ -0,0 +1,131 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_reboot +short_description: Reboot a windows machine +description: +- Reboot a Windows machine, wait for it to go down, come back up, and respond to commands. +- For non-Windows targets, use the M(reboot) module instead. +version_added: '2.1' +options: + pre_reboot_delay: + description: + - Seconds to wait before reboot. Passed as a parameter to the reboot command. + type: int + default: 2 + aliases: [ pre_reboot_delay_sec ] + post_reboot_delay: + description: + - Seconds to wait after the reboot command was successful before attempting to validate the system rebooted successfully. + - This is useful if you want wait for something to settle despite your connection already working. + type: int + default: 0 + version_added: '2.4' + aliases: [ post_reboot_delay_sec ] + shutdown_timeout: + description: + - Maximum seconds to wait for shutdown to occur. + - Increase this timeout for very slow hardware, large update applications, etc. + - This option has been removed since Ansible 2.5 as the win_reboot behavior has changed. + type: int + default: 600 + aliases: [ shutdown_timeout_sec ] + reboot_timeout: + description: + - Maximum seconds to wait for machine to re-appear on the network and respond to a test command. + - This timeout is evaluated separately for both reboot verification and test command success so maximum clock time is actually twice this value. + type: int + default: 600 + aliases: [ reboot_timeout_sec ] + connect_timeout: + description: + - Maximum seconds to wait for a single successful TCP connection to the WinRM endpoint before trying again. + type: int + default: 5 + aliases: [ connect_timeout_sec ] + test_command: + description: + - Command to expect success for to determine the machine is ready for management. + type: str + default: whoami + msg: + description: + - Message to display to users. + type: str + default: Reboot initiated by Ansible + boot_time_command: + description: + - Command to run that returns a unique string indicating the last time the system was booted. + - Setting this to a command that has different output each time it is run will cause the task to fail. + type: str + default: '(Get-WmiObject -ClassName Win32_OperatingSystem).LastBootUpTime' + version_added: '2.10' +notes: +- If a shutdown was already scheduled on the system, C(win_reboot) will abort the scheduled shutdown and enforce its own shutdown. +- Beware that when C(win_reboot) returns, the Windows system may not have settled yet and some base services could be in limbo. + This can result in unexpected behavior. Check the examples for ways to mitigate this. +- The connection user must have the C(SeRemoteShutdownPrivilege) privilege enabled, see + U(https://docs.microsoft.com/en-us/windows/security/threat-protection/security-policy-settings/force-shutdown-from-a-remote-system) + for more information. +seealso: +- module: reboot +author: +- Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +- name: Reboot the machine with all defaults + win_reboot: + +- name: Reboot a slow machine that might have lots of updates to apply + win_reboot: + reboot_timeout: 3600 + +# Install a Windows feature and reboot if necessary +- name: Install IIS Web-Server + win_feature: + name: Web-Server + register: iis_install + +- name: Reboot when Web-Server feature requires it + win_reboot: + when: iis_install.reboot_required + +# One way to ensure the system is reliable, is to set WinRM to a delayed startup +- name: Ensure WinRM starts when the system has settled and is ready to work reliably + win_service: + name: WinRM + start_mode: delayed + + +# Additionally, you can add a delay before running the next task +- name: Reboot a machine that takes time to settle after being booted + win_reboot: + post_reboot_delay: 120 + +# Or you can make win_reboot validate exactly what you need to work before running the next task +- name: Validate that the netlogon service has started, before running the next task + win_reboot: + test_command: 'exit (Get-Service -Name Netlogon).Status -ne "Running"' +''' + +RETURN = r''' +rebooted: + description: True if the machine was rebooted. + returned: always + type: bool + sample: true +elapsed: + description: The number of seconds that elapsed waiting for the system to be rebooted. + returned: always + type: float + sample: 23.2 +''' diff --git a/test/support/windows-integration/plugins/modules/win_regedit.ps1 b/test/support/windows-integration/plugins/modules/win_regedit.ps1 new file mode 100644 index 0000000..c56b483 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_regedit.ps1 @@ -0,0 +1,495 @@ +#!powershell + +# Copyright: (c) 2015, Adam Keech <akeech@chathamfinancial.com> +# Copyright: (c) 2015, Josh Ludwig <jludwig@chathamfinancial.com> +# Copyright: (c) 2017, Jordan Borean <jborean93@gmail.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.PrivilegeUtil + +$params = Parse-Args -arguments $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$path = Get-AnsibleParam -obj $params -name "path" -type "str" -failifempty $true -aliases "key" +$name = Get-AnsibleParam -obj $params -name "name" -type "str" -aliases "entry","value" +$data = Get-AnsibleParam -obj $params -name "data" +$type = Get-AnsibleParam -obj $params -name "type" -type "str" -default "string" -validateset "none","binary","dword","expandstring","multistring","string","qword" -aliases "datatype" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent" +$delete_key = Get-AnsibleParam -obj $params -name "delete_key" -type "bool" -default $true +$hive = Get-AnsibleParam -obj $params -name "hive" -type "path" + +$result = @{ + changed = $false + data_changed = $false + data_type_changed = $false +} + +if ($diff_mode) { + $result.diff = @{ + before = "" + after = "" + } +} + +$registry_util = @' +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; + +namespace Ansible.WinRegedit +{ + internal class NativeMethods + { + [DllImport("advapi32.dll", CharSet = CharSet.Unicode)] + public static extern int RegLoadKeyW( + UInt32 hKey, + string lpSubKey, + string lpFile); + + [DllImport("advapi32.dll", CharSet = CharSet.Unicode)] + public static extern int RegUnLoadKeyW( + UInt32 hKey, + string lpSubKey); + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + public class Hive : IDisposable + { + private const UInt32 SCOPE = 0x80000002; // HKLM + private string hiveKey; + private bool loaded = false; + + public Hive(string hiveKey, string hivePath) + { + this.hiveKey = hiveKey; + int ret = NativeMethods.RegLoadKeyW(SCOPE, hiveKey, hivePath); + if (ret != 0) + throw new Win32Exception(ret, String.Format("Failed to load registry hive at {0}", hivePath)); + loaded = true; + } + + public static void UnloadHive(string hiveKey) + { + int ret = NativeMethods.RegUnLoadKeyW(SCOPE, hiveKey); + if (ret != 0) + throw new Win32Exception(ret, String.Format("Failed to unload registry hive at {0}", hiveKey)); + } + + public void Dispose() + { + if (loaded) + { + // Make sure the garbage collector disposes all unused handles and waits until it is complete + GC.Collect(); + GC.WaitForPendingFinalizers(); + + UnloadHive(hiveKey); + loaded = false; + } + GC.SuppressFinalize(this); + } + ~Hive() { this.Dispose(); } + } +} +'@ + +# fire a warning if the property name isn't specified, the (Default) key ($null) can only be a string +if ($null -eq $name -and $type -ne "string") { + Add-Warning -obj $result -message "the data type when name is not specified can only be 'string', the type has automatically been converted" + $type = "string" +} + +# Check that the registry path is in PSDrive format: HKCC, HKCR, HKCU, HKLM, HKU +if ($path -notmatch "^HK(CC|CR|CU|LM|U):\\") { + Fail-Json $result "path: $path is not a valid powershell path, see module documentation for examples." +} + +# Add a warning if the path does not contains a \ and is not the leaf path +$registry_path = (Split-Path -Path $path -NoQualifier).Substring(1) # removes the hive: and leading \ +$registry_leaf = Split-Path -Path $path -Leaf +if ($registry_path -ne $registry_leaf -and -not $registry_path.Contains('\')) { + $msg = "path is not using '\' as a separator, support for '/' as a separator will be removed in a future Ansible version" + Add-DeprecationWarning -obj $result -message $msg -version 2.12 + $registry_path = $registry_path.Replace('/', '\') +} + +# Simplified version of Convert-HexStringToByteArray from +# https://cyber-defense.sans.org/blog/2010/02/11/powershell-byte-array-hex-convert +# Expects a hex in the format you get when you run reg.exe export, +# and converts to a byte array so powershell can modify binary registry entries +# import format is like 'hex:be,ef,be,ef,be,ef,be,ef,be,ef' +Function Convert-RegExportHexStringToByteArray($string) { + # Remove 'hex:' from the front of the string if present + $string = $string.ToLower() -replace '^hex\:','' + + # Remove whitespace and any other non-hex crud. + $string = $string -replace '[^a-f0-9\\,x\-\:]','' + + # Turn commas into colons + $string = $string -replace ',',':' + + # Maybe there's nothing left over to convert... + if ($string.Length -eq 0) { + return ,@() + } + + # Split string with or without colon delimiters. + if ($string.Length -eq 1) { + return ,@([System.Convert]::ToByte($string,16)) + } elseif (($string.Length % 2 -eq 0) -and ($string.IndexOf(":") -eq -1)) { + return ,@($string -split '([a-f0-9]{2})' | foreach-object { if ($_) {[System.Convert]::ToByte($_,16)}}) + } elseif ($string.IndexOf(":") -ne -1) { + return ,@($string -split ':+' | foreach-object {[System.Convert]::ToByte($_,16)}) + } else { + return ,@() + } +} + +Function Compare-RegistryProperties($existing, $new) { + # Outputs $true if the property values don't match + if ($existing -is [Array]) { + (Compare-Object -ReferenceObject $existing -DifferenceObject $new -SyncWindow 0).Length -ne 0 + } else { + $existing -cne $new + } +} + +Function Get-DiffValue { + param( + [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryValueKind]$Type, + [Parameter(Mandatory=$true)][Object]$Value + ) + + $diff = @{ type = $Type.ToString(); value = $Value } + + $enum = [Microsoft.Win32.RegistryValueKind] + if ($Type -in @($enum::Binary, $enum::None)) { + $diff.value = [System.Collections.Generic.List`1[String]]@() + foreach ($dec_value in $Value) { + $diff.value.Add("0x{0:x2}" -f $dec_value) + } + } elseif ($Type -eq $enum::DWord) { + $diff.value = "0x{0:x8}" -f $Value + } elseif ($Type -eq $enum::QWord) { + $diff.value = "0x{0:x16}" -f $Value + } + + return $diff +} + +Function Set-StateAbsent { + param( + # Used for diffs and exception messages to match up against Ansible input + [Parameter(Mandatory=$true)][String]$PrintPath, + [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryKey]$Hive, + [Parameter(Mandatory=$true)][String]$Path, + [String]$Name, + [Switch]$DeleteKey + ) + + $key = $Hive.OpenSubKey($Path, $true) + if ($null -eq $key) { + # Key does not exist, no need to delete anything + return + } + + try { + if ($DeleteKey -and -not $Name) { + # delete_key=yes is set and name is null/empty, so delete the entire key + $key.Dispose() + $key = $null + if (-not $check_mode) { + try { + $Hive.DeleteSubKeyTree($Path, $false) + } catch { + Fail-Json -obj $result -message "failed to delete registry key at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.before = @{$PrintPath = @{}} + $result.diff.after = @{} + } + } else { + # delete_key=no or name is not null/empty, delete the property not the full key + $property = $key.GetValue($Name) + if ($null -eq $property) { + # property does not exist + return + } + $property_type = $key.GetValueKind($Name) # used for the diff + + if (-not $check_mode) { + try { + $key.DeleteValue($Name) + } catch { + Fail-Json -obj $result -message "failed to delete registry property '$Name' at $($PrintPath): $($_.Exception.Message)" + } + } + + $result.changed = $true + if ($diff_mode) { + $diff_value = Get-DiffValue -Type $property_type -Value $property + $result.diff.before = @{ $PrintPath = @{ $Name = $diff_value } } + $result.diff.after = @{ $PrintPath = @{} } + } + } + } finally { + if ($key) { + $key.Dispose() + } + } +} + +Function Set-StatePresent { + param( + [Parameter(Mandatory=$true)][String]$PrintPath, + [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryKey]$Hive, + [Parameter(Mandatory=$true)][String]$Path, + [String]$Name, + [Object]$Data, + [Microsoft.Win32.RegistryValueKind]$Type + ) + + $key = $Hive.OpenSubKey($Path, $true) + try { + if ($null -eq $key) { + # the key does not exist, create it so the next steps work + if (-not $check_mode) { + try { + $key = $Hive.CreateSubKey($Path) + } catch { + Fail-Json -obj $result -message "failed to create registry key at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.before = @{} + $result.diff.after = @{$PrintPath = @{}} + } + } elseif ($diff_mode) { + # Make sure the diff is in an expected state for the key + $result.diff.before = @{$PrintPath = @{}} + $result.diff.after = @{$PrintPath = @{}} + } + + if ($null -eq $key -or $null -eq $Data) { + # Check mode and key was created above, we cannot do any more work, or $Data is $null which happens when + # we create a new key but haven't explicitly set the data + return + } + + $property = $key.GetValue($Name, $null, [Microsoft.Win32.RegistryValueOptions]::DoNotExpandEnvironmentNames) + if ($null -ne $property) { + # property exists, need to compare the values and type + $existing_type = $key.GetValueKind($name) + $change_value = $false + + if ($Type -ne $existing_type) { + $change_value = $true + $result.data_type_changed = $true + $data_mismatch = Compare-RegistryProperties -existing $property -new $Data + if ($data_mismatch) { + $result.data_changed = $true + } + } else { + $data_mismatch = Compare-RegistryProperties -existing $property -new $Data + if ($data_mismatch) { + $change_value = $true + $result.data_changed = $true + } + } + + if ($change_value) { + if (-not $check_mode) { + try { + $key.SetValue($Name, $Data, $Type) + } catch { + Fail-Json -obj $result -message "failed to change registry property '$Name' at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.before.$PrintPath.$Name = Get-DiffValue -Type $existing_type -Value $property + $result.diff.after.$PrintPath.$Name = Get-DiffValue -Type $Type -Value $Data + } + } elseif ($diff_mode) { + $diff_value = Get-DiffValue -Type $existing_type -Value $property + $result.diff.before.$PrintPath.$Name = $diff_value + $result.diff.after.$PrintPath.$Name = $diff_value + } + } else { + # property doesn't exist just create a new one + if (-not $check_mode) { + try { + $key.SetValue($Name, $Data, $Type) + } catch { + Fail-Json -obj $result -message "failed to create registry property '$Name' at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.after.$PrintPath.$Name = Get-DiffValue -Type $Type -Value $Data + } + } + } finally { + if ($key) { + $key.Dispose() + } + } +} + +# convert property names "" to $null as "" refers to (Default) +if ($name -eq "") { + $name = $null +} + +# convert the data to the required format +if ($type -in @("binary", "none")) { + if ($null -eq $data) { + $data = "" + } + + # convert the data from string to byte array if in hex: format + if ($data -is [String]) { + $data = [byte[]](Convert-RegExportHexStringToByteArray -string $data) + } elseif ($data -is [Int]) { + if ($data -gt 255) { + Fail-Json $result "cannot convert binary data '$data' to byte array, please specify this value as a yaml byte array or a comma separated hex value string" + } + $data = [byte[]]@([byte]$data) + } elseif ($data -is [Array]) { + $data = [byte[]]$data + } +} elseif ($type -in @("dword", "qword")) { + # dword's and dword's don't allow null values, set to 0 + if ($null -eq $data) { + $data = 0 + } + + if ($data -is [String]) { + # if the data is a string we need to convert it to an unsigned int64 + # it needs to be unsigned as Ansible passes in an unsigned value while + # powershell uses a signed data type. The value will then be converted + # below + $data = [UInt64]$data + } + + if ($type -eq "dword") { + if ($data -gt [UInt32]::MaxValue) { + Fail-Json $result "data cannot be larger than 0xffffffff when type is dword" + } elseif ($data -gt [Int32]::MaxValue) { + # when dealing with larger int32 (> 2147483647 or 0x7FFFFFFF) powershell + # automatically converts it to a signed int64. We need to convert this to + # signed int32 by parsing the hex string value. + $data = "0x$("{0:x}" -f $data)" + } + $data = [Int32]$data + } else { + if ($data -gt [UInt64]::MaxValue) { + Fail-Json $result "data cannot be larger than 0xffffffffffffffff when type is qword" + } elseif ($data -gt [Int64]::MaxValue) { + $data = "0x$("{0:x}" -f $data)" + } + $data = [Int64]$data + } +} elseif ($type -in @("string", "expandstring") -and $name) { + # a null string or expandstring must be empty quotes + # Only do this if $name has been defined (not the default key) + if ($null -eq $data) { + $data = "" + } +} elseif ($type -eq "multistring") { + # convert the data for a multistring to a String[] array + if ($null -eq $data) { + $data = [String[]]@() + } elseif ($data -isnot [Array]) { + $new_data = New-Object -TypeName String[] -ArgumentList 1 + $new_data[0] = $data.ToString([CultureInfo]::InvariantCulture) + $data = $new_data + } else { + $new_data = New-Object -TypeName String[] -ArgumentList $data.Count + foreach ($entry in $data) { + $new_data[$data.IndexOf($entry)] = $entry.ToString([CultureInfo]::InvariantCulture) + } + $data = $new_data + } +} + +# convert the type string to the .NET class +$type = [System.Enum]::Parse([Microsoft.Win32.RegistryValueKind], $type, $true) + +$registry_hive = switch(Split-Path -Path $path -Qualifier) { + "HKCR:" { [Microsoft.Win32.Registry]::ClassesRoot } + "HKCC:" { [Microsoft.Win32.Registry]::CurrentConfig } + "HKCU:" { [Microsoft.Win32.Registry]::CurrentUser } + "HKLM:" { [Microsoft.Win32.Registry]::LocalMachine } + "HKU:" { [Microsoft.Win32.Registry]::Users } +} +$loaded_hive = $null +try { + if ($hive) { + if (-not (Test-Path -LiteralPath $hive)) { + Fail-Json -obj $result -message "hive at path '$hive' is not valid or accessible, cannot load hive" + } + + $original_tmp = $env:TMP + $env:TMP = $_remote_tmp + Add-Type -TypeDefinition $registry_util + $env:TMP = $original_tmp + + try { + Set-AnsiblePrivilege -Name SeBackupPrivilege -Value $true + Set-AnsiblePrivilege -Name SeRestorePrivilege -Value $true + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "failed to enable SeBackupPrivilege and SeRestorePrivilege for the current process: $($_.Exception.Message)" + } + + if (Test-Path -Path HKLM:\ANSIBLE) { + Add-Warning -obj $result -message "hive already loaded at HKLM:\ANSIBLE, had to unload hive for win_regedit to continue" + try { + [Ansible.WinRegedit.Hive]::UnloadHive("ANSIBLE") + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "failed to unload registry hive HKLM:\ANSIBLE from $($hive): $($_.Exception.Message)" + } + } + + try { + $loaded_hive = New-Object -TypeName Ansible.WinRegedit.Hive -ArgumentList "ANSIBLE", $hive + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "failed to load registry hive from '$hive' to HKLM:\ANSIBLE: $($_.Exception.Message)" + } + } + + if ($state -eq "present") { + Set-StatePresent -PrintPath $path -Hive $registry_hive -Path $registry_path -Name $name -Data $data -Type $type + } else { + Set-StateAbsent -PrintPath $path -Hive $registry_hive -Path $registry_path -Name $name -DeleteKey:$delete_key + } +} finally { + $registry_hive.Dispose() + if ($loaded_hive) { + $loaded_hive.Dispose() + } +} + +Exit-Json $result + diff --git a/test/support/windows-integration/plugins/modules/win_regedit.py b/test/support/windows-integration/plugins/modules/win_regedit.py new file mode 100644 index 0000000..2c0fff7 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_regedit.py @@ -0,0 +1,210 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Adam Keech <akeech@chathamfinancial.com> +# Copyright: (c) 2015, Josh Ludwig <jludwig@chathamfinancial.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + + +DOCUMENTATION = r''' +--- +module: win_regedit +version_added: '2.0' +short_description: Add, change, or remove registry keys and values +description: +- Add, modify or remove registry keys and values. +- More information about the windows registry from Wikipedia + U(https://en.wikipedia.org/wiki/Windows_Registry). +options: + path: + description: + - Name of the registry path. + - 'Should be in one of the following registry hives: HKCC, HKCR, HKCU, + HKLM, HKU.' + type: str + required: yes + aliases: [ key ] + name: + description: + - Name of the registry entry in the above C(path) parameters. + - If not provided, or empty then the '(Default)' property for the key will + be used. + type: str + aliases: [ entry, value ] + data: + description: + - Value of the registry entry C(name) in C(path). + - If not specified then the value for the property will be null for the + corresponding C(type). + - Binary and None data should be expressed in a yaml byte array or as comma + separated hex values. + - An easy way to generate this is to run C(regedit.exe) and use the + I(export) option to save the registry values to a file. + - In the exported file, binary value will look like C(hex:be,ef,be,ef), the + C(hex:) prefix is optional. + - DWORD and QWORD values should either be represented as a decimal number + or a hex value. + - Multistring values should be passed in as a list. + - See the examples for more details on how to format this data. + type: str + type: + description: + - The registry value data type. + type: str + choices: [ binary, dword, expandstring, multistring, string, qword ] + default: string + aliases: [ datatype ] + state: + description: + - The state of the registry entry. + type: str + choices: [ absent, present ] + default: present + delete_key: + description: + - When C(state) is 'absent' then this will delete the entire key. + - If C(no) then it will only clear out the '(Default)' property for + that key. + type: bool + default: yes + version_added: '2.4' + hive: + description: + - A path to a hive key like C:\Users\Default\NTUSER.DAT to load in the + registry. + - This hive is loaded under the HKLM:\ANSIBLE key which can then be used + in I(name) like any other path. + - This can be used to load the default user profile registry hive or any + other hive saved as a file. + - Using this function requires the user to have the C(SeRestorePrivilege) + and C(SeBackupPrivilege) privileges enabled. + type: path + version_added: '2.5' +notes: +- Check-mode C(-C/--check) and diff output C(-D/--diff) are supported, so that you can test every change against the active configuration before + applying changes. +- Beware that some registry hives (C(HKEY_USERS) in particular) do not allow to create new registry paths in the root folder. +- Since ansible 2.4, when checking if a string registry value has changed, a case-sensitive test is used. Previously the test was case-insensitive. +seealso: +- module: win_reg_stat +- module: win_regmerge +author: +- Adam Keech (@smadam813) +- Josh Ludwig (@joshludwig) +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Create registry path MyCompany + win_regedit: + path: HKCU:\Software\MyCompany + +- name: Add or update registry path MyCompany, with entry 'hello', and containing 'world' + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: world + +- name: Add or update registry path MyCompany, with dword entry 'hello', and containing 1337 as the decimal value + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: 1337 + type: dword + +- name: Add or update registry path MyCompany, with dword entry 'hello', and containing 0xff2500ae as the hex value + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: 0xff2500ae + type: dword + +- name: Add or update registry path MyCompany, with binary entry 'hello', and containing binary data in hex-string format + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: hex:be,ef,be,ef,be,ef,be,ef,be,ef + type: binary + +- name: Add or update registry path MyCompany, with binary entry 'hello', and containing binary data in yaml format + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: [0xbe,0xef,0xbe,0xef,0xbe,0xef,0xbe,0xef,0xbe,0xef] + type: binary + +- name: Add or update registry path MyCompany, with expand string entry 'hello' + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: '%appdata%\local' + type: expandstring + +- name: Add or update registry path MyCompany, with multi string entry 'hello' + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: ['hello', 'world'] + type: multistring + +- name: Disable keyboard layout hotkey for all users (changes existing) + win_regedit: + path: HKU:\.DEFAULT\Keyboard Layout\Toggle + name: Layout Hotkey + data: 3 + type: dword + +- name: Disable language hotkey for current users (adds new) + win_regedit: + path: HKCU:\Keyboard Layout\Toggle + name: Language Hotkey + data: 3 + type: dword + +- name: Remove registry path MyCompany (including all entries it contains) + win_regedit: + path: HKCU:\Software\MyCompany + state: absent + delete_key: yes + +- name: Clear the existing (Default) entry at path MyCompany + win_regedit: + path: HKCU:\Software\MyCompany + state: absent + delete_key: no + +- name: Remove entry 'hello' from registry path MyCompany + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + state: absent + +- name: Change default mouse trailing settings for new users + win_regedit: + path: HKLM:\ANSIBLE\Control Panel\Mouse + name: MouseTrails + data: 10 + type: str + state: present + hive: C:\Users\Default\NTUSER.dat +''' + +RETURN = r''' +data_changed: + description: Whether this invocation changed the data in the registry value. + returned: success + type: bool + sample: false +data_type_changed: + description: Whether this invocation changed the datatype of the registry value. + returned: success + type: bool + sample: true +''' diff --git a/test/support/windows-integration/plugins/modules/win_shell.ps1 b/test/support/windows-integration/plugins/modules/win_shell.ps1 new file mode 100644 index 0000000..54aef8d --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_shell.ps1 @@ -0,0 +1,138 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.CommandUtil +#Requires -Module Ansible.ModuleUtils.FileUtil + +# TODO: add check mode support + +Set-StrictMode -Version 2 +$ErrorActionPreference = "Stop" + +# Cleanse CLIXML from stderr (sift out error stream data, discard others for now) +Function Cleanse-Stderr($raw_stderr) { + Try { + # NB: this regex isn't perfect, but is decent at finding CLIXML amongst other stderr noise + If($raw_stderr -match "(?s)(?<prenoise1>.*)#< CLIXML(?<prenoise2>.*)(?<clixml><Objs.+</Objs>)(?<postnoise>.*)") { + $clixml = [xml]$matches["clixml"] + + $merged_stderr = "{0}{1}{2}{3}" -f @( + $matches["prenoise1"], + $matches["prenoise2"], + # filter out just the Error-tagged strings for now, and zap embedded CRLF chars + ($clixml.Objs.ChildNodes | Where-Object { $_.Name -eq 'S' } | Where-Object { $_.S -eq 'Error' } | ForEach-Object { $_.'#text'.Replace('_x000D__x000A_','') } | Out-String), + $matches["postnoise"]) | Out-String + + return $merged_stderr.Trim() + + # FUTURE: parse/return other streams + } + Else { + $raw_stderr + } + } + Catch { + "***EXCEPTION PARSING CLIXML: $_***" + $raw_stderr + } +} + +$params = Parse-Args $args -supports_check_mode $false + +$raw_command_line = Get-AnsibleParam -obj $params -name "_raw_params" -type "str" -failifempty $true +$chdir = Get-AnsibleParam -obj $params -name "chdir" -type "path" +$executable = Get-AnsibleParam -obj $params -name "executable" -type "path" +$creates = Get-AnsibleParam -obj $params -name "creates" -type "path" +$removes = Get-AnsibleParam -obj $params -name "removes" -type "path" +$stdin = Get-AnsibleParam -obj $params -name "stdin" -type "str" +$no_profile = Get-AnsibleParam -obj $params -name "no_profile" -type "bool" -default $false +$output_encoding_override = Get-AnsibleParam -obj $params -name "output_encoding_override" -type "str" + +$raw_command_line = $raw_command_line.Trim() + +$result = @{ + changed = $true + cmd = $raw_command_line +} + +if ($creates -and $(Test-AnsiblePath -Path $creates)) { + Exit-Json @{msg="skipped, since $creates exists";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +if ($removes -and -not $(Test-AnsiblePath -Path $removes)) { + Exit-Json @{msg="skipped, since $removes does not exist";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +$exec_args = $null +If(-not $executable -or $executable -eq "powershell") { + $exec_application = "powershell.exe" + + # force input encoding to preamble-free UTF8 so PS sub-processes (eg, Start-Job) don't blow up + $raw_command_line = "[Console]::InputEncoding = New-Object Text.UTF8Encoding `$false; " + $raw_command_line + + # Base64 encode the command so we don't have to worry about the various levels of escaping + $encoded_command = [Convert]::ToBase64String([System.Text.Encoding]::Unicode.GetBytes($raw_command_line)) + + if ($stdin) { + $exec_args = "-encodedcommand $encoded_command" + } else { + $exec_args = "-noninteractive -encodedcommand $encoded_command" + } + + if ($no_profile) { + $exec_args = "-noprofile $exec_args" + } +} +Else { + # FUTURE: support arg translation from executable (or executable_args?) to process arguments for arbitrary interpreter? + $exec_application = $executable + if (-not ($exec_application.EndsWith(".exe"))) { + $exec_application = "$($exec_application).exe" + } + $exec_args = "/c $raw_command_line" +} + +$command = "`"$exec_application`" $exec_args" +$run_command_arg = @{ + command = $command +} +if ($chdir) { + $run_command_arg['working_directory'] = $chdir +} +if ($stdin) { + $run_command_arg['stdin'] = $stdin +} +if ($output_encoding_override) { + $run_command_arg['output_encoding_override'] = $output_encoding_override +} + +$start_datetime = [DateTime]::UtcNow +try { + $command_result = Run-Command @run_command_arg +} catch { + $result.changed = $false + try { + $result.rc = $_.Exception.NativeErrorCode + } catch { + $result.rc = 2 + } + Fail-Json -obj $result -message $_.Exception.Message +} + +# TODO: decode CLIXML stderr output (and other streams?) +$result.stdout = $command_result.stdout +$result.stderr = Cleanse-Stderr $command_result.stderr +$result.rc = $command_result.rc + +$end_datetime = [DateTime]::UtcNow +$result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff") + +If ($result.rc -ne 0) { + Fail-Json -obj $result -message "non-zero return code" +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_shell.py b/test/support/windows-integration/plugins/modules/win_shell.py new file mode 100644 index 0000000..ee2cd76 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_shell.py @@ -0,0 +1,167 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Ansible, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_shell +short_description: Execute shell commands on target hosts +version_added: 2.2 +description: + - The C(win_shell) module takes the command name followed by a list of space-delimited arguments. + It is similar to the M(win_command) module, but runs + the command via a shell (defaults to PowerShell) on the target host. + - For non-Windows targets, use the M(shell) module instead. +options: + free_form: + description: + - The C(win_shell) module takes a free form command to run. + - There is no parameter actually named 'free form'. See the examples! + type: str + required: yes + creates: + description: + - A path or path filter pattern; when the referenced path exists on the target host, the task will be skipped. + type: path + removes: + description: + - A path or path filter pattern; when the referenced path B(does not) exist on the target host, the task will be skipped. + type: path + chdir: + description: + - Set the specified path as the current working directory before executing a command + type: path + executable: + description: + - Change the shell used to execute the command (eg, C(cmd)). + - The target shell must accept a C(/c) parameter followed by the raw command line to be executed. + type: path + stdin: + description: + - Set the stdin of the command directly to the specified value. + type: str + version_added: '2.5' + no_profile: + description: + - Do not load the user profile before running a command. This is only valid + when using PowerShell as the executable. + type: bool + default: no + version_added: '2.8' + output_encoding_override: + description: + - This option overrides the encoding of stdout/stderr output. + - You can use this option when you need to run a command which ignore the console's codepage. + - You should only need to use this option in very rare circumstances. + - This value can be any valid encoding C(Name) based on the output of C([System.Text.Encoding]::GetEncodings()). + See U(https://docs.microsoft.com/dotnet/api/system.text.encoding.getencodings). + type: str + version_added: '2.10' +notes: + - If you want to run an executable securely and predictably, it may be + better to use the M(win_command) module instead. Best practices when writing + playbooks will follow the trend of using M(win_command) unless C(win_shell) is + explicitly required. When running ad-hoc commands, use your best judgement. + - WinRM will not return from a command execution until all child processes created have exited. + Thus, it is not possible to use C(win_shell) to spawn long-running child or background processes. + Consider creating a Windows service for managing background processes. +seealso: +- module: psexec +- module: raw +- module: script +- module: shell +- module: win_command +- module: win_psexec +author: + - Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +# Execute a command in the remote shell; stdout goes to the specified +# file on the remote. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt + +# Change the working directory to somedir/ before executing the command. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt chdir=C:\somedir + +# You can also use the 'args' form to provide the options. This command +# will change the working directory to somedir/ and will only run when +# somedir/somelog.txt doesn't exist. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt + args: + chdir: C:\somedir + creates: C:\somelog.txt + +# Run a command under a non-Powershell interpreter (cmd in this case) +- win_shell: echo %HOMEDIR% + args: + executable: cmd + register: homedir_out + +- name: Run multi-lined shell commands + win_shell: | + $value = Test-Path -Path C:\temp + if ($value) { + Remove-Item -Path C:\temp -Force + } + New-Item -Path C:\temp -ItemType Directory + +- name: Retrieve the input based on stdin + win_shell: '$string = [Console]::In.ReadToEnd(); Write-Output $string.Trim()' + args: + stdin: Input message +''' + +RETURN = r''' +msg: + description: Changed. + returned: always + type: bool + sample: true +start: + description: The command execution start time. + returned: always + type: str + sample: '2016-02-25 09:18:26.429568' +end: + description: The command execution end time. + returned: always + type: str + sample: '2016-02-25 09:18:26.755339' +delta: + description: The command execution delta time. + returned: always + type: str + sample: '0:00:00.325771' +stdout: + description: The command standard output. + returned: always + type: str + sample: 'Clustering node rabbit@slave1 with rabbit@master ...' +stderr: + description: The command standard error. + returned: always + type: str + sample: 'ls: cannot access foo: No such file or directory' +cmd: + description: The command executed by the task. + returned: always + type: str + sample: 'rabbitmqctl join_cluster rabbit@master' +rc: + description: The command return code (0 means success). + returned: always + type: int + sample: 0 +stdout_lines: + description: The command standard output split in lines. + returned: always + type: list + sample: [u'Clustering node rabbit@slave1 with rabbit@master ...'] +''' diff --git a/test/support/windows-integration/plugins/modules/win_stat.ps1 b/test/support/windows-integration/plugins/modules/win_stat.ps1 new file mode 100644 index 0000000..071eb11 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_stat.ps1 @@ -0,0 +1,186 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.FileUtil +#Requires -Module Ansible.ModuleUtils.LinkUtil + +function ConvertTo-Timestamp($start_date, $end_date) { + if ($start_date -and $end_date) { + return (New-TimeSpan -Start $start_date -End $end_date).TotalSeconds + } +} + +function Get-FileChecksum($path, $algorithm) { + switch ($algorithm) { + 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider } + 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider } + 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider } + 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider } + 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider } + default { Fail-Json -obj $result -message "Unsupported hash algorithm supplied '$algorithm'" } + } + + $fp = [System.IO.File]::Open($path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, [System.IO.FileShare]::ReadWrite) + try { + $hash = [System.BitConverter]::ToString($sp.ComputeHash($fp)).Replace("-", "").ToLower() + } finally { + $fp.Dispose() + } + + return $hash +} + +function Get-FileInfo { + param([String]$Path, [Switch]$Follow) + + $info = Get-AnsibleItem -Path $Path -ErrorAction SilentlyContinue + $link_info = $null + if ($null -ne $info) { + try { + $link_info = Get-Link -link_path $info.FullName + } catch { + $module.Warn("Failed to check/get link info for file: $($_.Exception.Message)") + } + + # If follow=true we want to follow the link all the way back to root object + if ($Follow -and $null -ne $link_info -and $link_info.Type -in @("SymbolicLink", "JunctionPoint")) { + $info, $link_info = Get-FileInfo -Path $link_info.AbsolutePath -Follow + } + } + + return $info, $link_info +} + +$spec = @{ + options = @{ + path = @{ type='path'; required=$true; aliases=@( 'dest', 'name' ) } + get_checksum = @{ type='bool'; default=$true } + checksum_algorithm = @{ type='str'; default='sha1'; choices=@( 'md5', 'sha1', 'sha256', 'sha384', 'sha512' ) } + follow = @{ type='bool'; default=$false } + } + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$path = $module.Params.path +$get_checksum = $module.Params.get_checksum +$checksum_algorithm = $module.Params.checksum_algorithm +$follow = $module.Params.follow + +$module.Result.stat = @{ exists=$false } + +Load-LinkUtils +$info, $link_info = Get-FileInfo -Path $path -Follow:$follow +If ($null -ne $info) { + $epoch_date = Get-Date -Date "01/01/1970" + $attributes = @() + foreach ($attribute in ($info.Attributes -split ',')) { + $attributes += $attribute.Trim() + } + + # default values that are always set, specific values are set below this + # but are kept commented for easier readability + $stat = @{ + exists = $true + attributes = $info.Attributes.ToString() + isarchive = ($attributes -contains "Archive") + isdir = $false + ishidden = ($attributes -contains "Hidden") + isjunction = $false + islnk = $false + isreadonly = ($attributes -contains "ReadOnly") + isreg = $false + isshared = $false + nlink = 1 # Number of links to the file (hard links), overriden below if islnk + # lnk_target = islnk or isjunction Target of the symlink. Note that relative paths remain relative + # lnk_source = islnk os isjunction Target of the symlink normalized for the remote filesystem + hlnk_targets = @() + creationtime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.CreationTime) + lastaccesstime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastAccessTime) + lastwritetime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastWriteTime) + # size = a file and directory - calculated below + path = $info.FullName + filename = $info.Name + # extension = a file + # owner = set outsite this dict in case it fails + # sharename = a directory and isshared is True + # checksum = a file and get_checksum: True + } + try { + $stat.owner = $info.GetAccessControl().Owner + } catch { + # may not have rights, historical behaviour was to just set to $null + # due to ErrorActionPreference being set to "Continue" + $stat.owner = $null + } + + # values that are set according to the type of file + if ($info.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) { + $stat.isdir = $true + $share_info = Get-CimInstance -ClassName Win32_Share -Filter "Path='$($stat.path -replace '\\', '\\')'" + if ($null -ne $share_info) { + $stat.isshared = $true + $stat.sharename = $share_info.Name + } + + try { + $size = 0 + foreach ($file in $info.EnumerateFiles("*", [System.IO.SearchOption]::AllDirectories)) { + $size += $file.Length + } + $stat.size = $size + } catch { + $stat.size = 0 + } + } else { + $stat.extension = $info.Extension + $stat.isreg = $true + $stat.size = $info.Length + + if ($get_checksum) { + try { + $stat.checksum = Get-FileChecksum -path $path -algorithm $checksum_algorithm + } catch { + $module.FailJson("Failed to get hash of file, set get_checksum to False to ignore this error: $($_.Exception.Message)", $_) + } + } + } + + # Get symbolic link, junction point, hard link info + if ($null -ne $link_info) { + switch ($link_info.Type) { + "SymbolicLink" { + $stat.islnk = $true + $stat.isreg = $false + $stat.lnk_target = $link_info.TargetPath + $stat.lnk_source = $link_info.AbsolutePath + break + } + "JunctionPoint" { + $stat.isjunction = $true + $stat.isreg = $false + $stat.lnk_target = $link_info.TargetPath + $stat.lnk_source = $link_info.AbsolutePath + break + } + "HardLink" { + $stat.lnk_type = "hard" + $stat.nlink = $link_info.HardTargets.Count + + # remove current path from the targets + $hlnk_targets = $link_info.HardTargets | Where-Object { $_ -ne $stat.path } + $stat.hlnk_targets = @($hlnk_targets) + break + } + } + } + + $module.Result.stat = $stat +} + +$module.ExitJson() + diff --git a/test/support/windows-integration/plugins/modules/win_stat.py b/test/support/windows-integration/plugins/modules/win_stat.py new file mode 100644 index 0000000..0676b5b --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_stat.py @@ -0,0 +1,236 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_stat +version_added: "1.7" +short_description: Get information about Windows files +description: + - Returns information about a Windows file. + - For non-Windows targets, use the M(stat) module instead. +options: + path: + description: + - The full path of the file/object to get the facts of; both forward and + back slashes are accepted. + type: path + required: yes + aliases: [ dest, name ] + get_checksum: + description: + - Whether to return a checksum of the file (default sha1) + type: bool + default: yes + version_added: "2.1" + checksum_algorithm: + description: + - Algorithm to determine checksum of file. + - Will throw an error if the host is unable to use specified algorithm. + type: str + default: sha1 + choices: [ md5, sha1, sha256, sha384, sha512 ] + version_added: "2.3" + follow: + description: + - Whether to follow symlinks or junction points. + - In the case of C(path) pointing to another link, then that will + be followed until no more links are found. + type: bool + default: no + version_added: "2.8" +seealso: +- module: stat +- module: win_acl +- module: win_file +- module: win_owner +author: +- Chris Church (@cchurch) +''' + +EXAMPLES = r''' +- name: Obtain information about a file + win_stat: + path: C:\foo.ini + register: file_info + +- name: Obtain information about a folder + win_stat: + path: C:\bar + register: folder_info + +- name: Get MD5 checksum of a file + win_stat: + path: C:\foo.ini + get_checksum: yes + checksum_algorithm: md5 + register: md5_checksum + +- debug: + var: md5_checksum.stat.checksum + +- name: Get SHA1 checksum of file + win_stat: + path: C:\foo.ini + get_checksum: yes + register: sha1_checksum + +- debug: + var: sha1_checksum.stat.checksum + +- name: Get SHA256 checksum of file + win_stat: + path: C:\foo.ini + get_checksum: yes + checksum_algorithm: sha256 + register: sha256_checksum + +- debug: + var: sha256_checksum.stat.checksum +''' + +RETURN = r''' +changed: + description: Whether anything was changed + returned: always + type: bool + sample: true +stat: + description: dictionary containing all the stat data + returned: success + type: complex + contains: + attributes: + description: Attributes of the file at path in raw form. + returned: success, path exists + type: str + sample: "Archive, Hidden" + checksum: + description: The checksum of a file based on checksum_algorithm specified. + returned: success, path exist, path is a file, get_checksum == True + checksum_algorithm specified is supported + type: str + sample: 09cb79e8fc7453c84a07f644e441fd81623b7f98 + creationtime: + description: The create time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + exists: + description: If the path exists or not. + returned: success + type: bool + sample: true + extension: + description: The extension of the file at path. + returned: success, path exists, path is a file + type: str + sample: ".ps1" + filename: + description: The name of the file (without path). + returned: success, path exists, path is a file + type: str + sample: foo.ini + hlnk_targets: + description: List of other files pointing to the same file (hard links), excludes the current file. + returned: success, path exists + type: list + sample: + - C:\temp\file.txt + - C:\Windows\update.log + isarchive: + description: If the path is ready for archiving or not. + returned: success, path exists + type: bool + sample: true + isdir: + description: If the path is a directory or not. + returned: success, path exists + type: bool + sample: true + ishidden: + description: If the path is hidden or not. + returned: success, path exists + type: bool + sample: true + isjunction: + description: If the path is a junction point or not. + returned: success, path exists + type: bool + sample: true + islnk: + description: If the path is a symbolic link or not. + returned: success, path exists + type: bool + sample: true + isreadonly: + description: If the path is read only or not. + returned: success, path exists + type: bool + sample: true + isreg: + description: If the path is a regular file. + returned: success, path exists + type: bool + sample: true + isshared: + description: If the path is shared or not. + returned: success, path exists + type: bool + sample: true + lastaccesstime: + description: The last access time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lastwritetime: + description: The last modification time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lnk_source: + description: Target of the symlink normalized for the remote filesystem. + returned: success, path exists and the path is a symbolic link or junction point + type: str + sample: C:\temp\link + lnk_target: + description: Target of the symlink. Note that relative paths remain relative. + returned: success, path exists and the path is a symbolic link or junction point + type: str + sample: ..\link + nlink: + description: Number of links to the file (hard links). + returned: success, path exists + type: int + sample: 1 + owner: + description: The owner of the file. + returned: success, path exists + type: str + sample: BUILTIN\Administrators + path: + description: The full absolute path to the file. + returned: success, path exists, file exists + type: str + sample: C:\foo.ini + sharename: + description: The name of share if folder is shared. + returned: success, path exists, file is a directory and isshared == True + type: str + sample: file-share + size: + description: The size in bytes of a file or folder. + returned: success, path exists, file is not a link + type: int + sample: 1024 +''' diff --git a/test/support/windows-integration/plugins/modules/win_tempfile.ps1 b/test/support/windows-integration/plugins/modules/win_tempfile.ps1 new file mode 100644 index 0000000..9a1a717 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_tempfile.ps1 @@ -0,0 +1,72 @@ +#!powershell + +# Copyright: (c) 2017, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic + +Function New-TempFile { + Param ([string]$path, [string]$prefix, [string]$suffix, [string]$type, [bool]$checkmode) + $temppath = $null + $curerror = $null + $attempt = 0 + + # Since we don't know if the file already exists, we try 5 times with a random name + do { + $attempt += 1 + $randomname = [System.IO.Path]::GetRandomFileName() + $temppath = (Join-Path -Path $path -ChildPath "$prefix$randomname$suffix") + Try { + $file = New-Item -Path $temppath -ItemType $type -WhatIf:$checkmode + # Makes sure we get the full absolute path of the created temp file and not a relative or DOS 8.3 dir + if (-not $checkmode) { + $temppath = $file.FullName + } else { + # Just rely on GetFulLpath for check mode + $temppath = [System.IO.Path]::GetFullPath($temppath) + } + } Catch { + $temppath = $null + $curerror = $_ + } + } until (($null -ne $temppath) -or ($attempt -ge 5)) + + # If it fails 5 times, something is wrong and we have to report the details + if ($null -eq $temppath) { + $module.FailJson("No random temporary file worked in $attempt attempts. Error: $($curerror.Exception.Message)", $curerror) + } + + return $temppath.ToString() +} + +$spec = @{ + options = @{ + path = @{ type='path'; default='%TEMP%'; aliases=@( 'dest' ) } + state = @{ type='str'; default='file'; choices=@( 'directory', 'file') } + prefix = @{ type='str'; default='ansible.' } + suffix = @{ type='str' } + } + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$path = $module.Params.path +$state = $module.Params.state +$prefix = $module.Params.prefix +$suffix = $module.Params.suffix + +# Expand environment variables on non-path types +if ($null -ne $prefix) { + $prefix = [System.Environment]::ExpandEnvironmentVariables($prefix) +} +if ($null -ne $suffix) { + $suffix = [System.Environment]::ExpandEnvironmentVariables($suffix) +} + +$module.Result.changed = $true +$module.Result.state = $state + +$module.Result.path = New-TempFile -Path $path -Prefix $prefix -Suffix $suffix -Type $state -CheckMode $module.CheckMode + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_user.ps1 b/test/support/windows-integration/plugins/modules/win_user.ps1 new file mode 100644 index 0000000..54905cb --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user.ps1 @@ -0,0 +1,273 @@ +#!powershell + +# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.AccessToken +#Requires -Module Ansible.ModuleUtils.Legacy + +######## +$ADS_UF_PASSWD_CANT_CHANGE = 64 +$ADS_UF_DONT_EXPIRE_PASSWD = 65536 + +$adsi = [ADSI]"WinNT://$env:COMPUTERNAME" + +function Get-User($user) { + $adsi.Children | Where-Object {$_.SchemaClassName -eq 'user' -and $_.Name -eq $user } + return +} + +function Get-UserFlag($user, $flag) { + If ($user.UserFlags[0] -band $flag) { + $true + } + Else { + $false + } +} + +function Set-UserFlag($user, $flag) { + $user.UserFlags = ($user.UserFlags[0] -BOR $flag) +} + +function Clear-UserFlag($user, $flag) { + $user.UserFlags = ($user.UserFlags[0] -BXOR $flag) +} + +function Get-Group($grp) { + $adsi.Children | Where-Object { $_.SchemaClassName -eq 'Group' -and $_.Name -eq $grp } + return +} + +Function Test-LocalCredential { + param([String]$Username, [String]$Password) + + try { + $handle = [Ansible.AccessToken.TokenUtil]::LogonUser($Username, $null, $Password, "Network", "Default") + $handle.Dispose() + $valid_credentials = $true + } catch [Ansible.AccessToken.Win32Exception] { + # following errors indicate the creds are correct but the user was + # unable to log on for other reasons, which we don't care about + $success_codes = @( + 0x0000052F, # ERROR_ACCOUNT_RESTRICTION + 0x00000530, # ERROR_INVALID_LOGON_HOURS + 0x00000531, # ERROR_INVALID_WORKSTATION + 0x00000569 # ERROR_LOGON_TYPE_GRANTED + ) + + if ($_.Exception.NativeErrorCode -eq 0x0000052E) { + # ERROR_LOGON_FAILURE - the user or pass was incorrect + $valid_credentials = $false + } elseif ($_.Exception.NativeErrorCode -in $success_codes) { + $valid_credentials = $true + } else { + # an unknown failure, reraise exception + throw $_ + } + } + return $valid_credentials +} + +######## + +$params = Parse-Args $args; + +$result = @{ + changed = $false +}; + +$username = Get-AnsibleParam -obj $params -name "name" -type "str" -failifempty $true +$fullname = Get-AnsibleParam -obj $params -name "fullname" -type "str" +$description = Get-AnsibleParam -obj $params -name "description" -type "str" +$password = Get-AnsibleParam -obj $params -name "password" -type "str" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent","query" +$update_password = Get-AnsibleParam -obj $params -name "update_password" -type "str" -default "always" -validateset "always","on_create" +$password_expired = Get-AnsibleParam -obj $params -name "password_expired" -type "bool" +$password_never_expires = Get-AnsibleParam -obj $params -name "password_never_expires" -type "bool" +$user_cannot_change_password = Get-AnsibleParam -obj $params -name "user_cannot_change_password" -type "bool" +$account_disabled = Get-AnsibleParam -obj $params -name "account_disabled" -type "bool" +$account_locked = Get-AnsibleParam -obj $params -name "account_locked" -type "bool" +$groups = Get-AnsibleParam -obj $params -name "groups" +$groups_action = Get-AnsibleParam -obj $params -name "groups_action" -type "str" -default "replace" -validateset "add","remove","replace" + +If ($null -ne $account_locked -and $account_locked) { + Fail-Json $result "account_locked must be set to 'no' if provided" +} + +If ($null -ne $groups) { + If ($groups -is [System.String]) { + [string[]]$groups = $groups.Split(",") + } + ElseIf ($groups -isnot [System.Collections.IList]) { + Fail-Json $result "groups must be a string or array" + } + $groups = $groups | ForEach-Object { ([string]$_).Trim() } | Where-Object { $_ } + If ($null -eq $groups) { + $groups = @() + } +} + +$user_obj = Get-User $username + +If ($state -eq 'present') { + # Add or update user + try { + If (-not $user_obj) { + $user_obj = $adsi.Create("User", $username) + If ($null -ne $password) { + $user_obj.SetPassword($password) + } + $user_obj.SetInfo() + $result.changed = $true + } + ElseIf (($null -ne $password) -and ($update_password -eq 'always')) { + # ValidateCredentials will fail if either of these are true- just force update... + If($user_obj.AccountDisabled -or $user_obj.PasswordExpired) { + $password_match = $false + } + Else { + try { + $password_match = Test-LocalCredential -Username $username -Password $password + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "Failed to validate the user's credentials: $($_.Exception.Message)" + } + } + + If (-not $password_match) { + $user_obj.SetPassword($password) + $result.changed = $true + } + } + If (($null -ne $fullname) -and ($fullname -ne $user_obj.FullName[0])) { + $user_obj.FullName = $fullname + $result.changed = $true + } + If (($null -ne $description) -and ($description -ne $user_obj.Description[0])) { + $user_obj.Description = $description + $result.changed = $true + } + If (($null -ne $password_expired) -and ($password_expired -ne ($user_obj.PasswordExpired | ConvertTo-Bool))) { + $user_obj.PasswordExpired = If ($password_expired) { 1 } Else { 0 } + $result.changed = $true + } + If (($null -ne $password_never_expires) -and ($password_never_expires -ne (Get-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD))) { + If ($password_never_expires) { + Set-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD + } + Else { + Clear-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD + } + $result.changed = $true + } + If (($null -ne $user_cannot_change_password) -and ($user_cannot_change_password -ne (Get-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE))) { + If ($user_cannot_change_password) { + Set-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE + } + Else { + Clear-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE + } + $result.changed = $true + } + If (($null -ne $account_disabled) -and ($account_disabled -ne $user_obj.AccountDisabled)) { + $user_obj.AccountDisabled = $account_disabled + $result.changed = $true + } + If (($null -ne $account_locked) -and ($account_locked -ne $user_obj.IsAccountLocked)) { + $user_obj.IsAccountLocked = $account_locked + $result.changed = $true + } + If ($result.changed) { + $user_obj.SetInfo() + } + If ($null -ne $groups) { + [string[]]$current_groups = $user_obj.Groups() | ForEach-Object { $_.GetType().InvokeMember("Name", "GetProperty", $null, $_, $null) } + If (($groups_action -eq "remove") -or ($groups_action -eq "replace")) { + ForEach ($grp in $current_groups) { + If ((($groups_action -eq "remove") -and ($groups -contains $grp)) -or (($groups_action -eq "replace") -and ($groups -notcontains $grp))) { + $group_obj = Get-Group $grp + If ($group_obj) { + $group_obj.Remove($user_obj.Path) + $result.changed = $true + } + Else { + Fail-Json $result "group '$grp' not found" + } + } + } + } + If (($groups_action -eq "add") -or ($groups_action -eq "replace")) { + ForEach ($grp in $groups) { + If ($current_groups -notcontains $grp) { + $group_obj = Get-Group $grp + If ($group_obj) { + $group_obj.Add($user_obj.Path) + $result.changed = $true + } + Else { + Fail-Json $result "group '$grp' not found" + } + } + } + } + } + } + catch { + Fail-Json $result $_.Exception.Message + } +} +ElseIf ($state -eq 'absent') { + # Remove user + try { + If ($user_obj) { + $username = $user_obj.Name.Value + $adsi.delete("User", $user_obj.Name.Value) + $result.changed = $true + $result.msg = "User '$username' deleted successfully" + $user_obj = $null + } else { + $result.msg = "User '$username' was not found" + } + } + catch { + Fail-Json $result $_.Exception.Message + } +} + +try { + If ($user_obj -and $user_obj -is [System.DirectoryServices.DirectoryEntry]) { + $user_obj.RefreshCache() + $result.name = $user_obj.Name[0] + $result.fullname = $user_obj.FullName[0] + $result.path = $user_obj.Path + $result.description = $user_obj.Description[0] + $result.password_expired = ($user_obj.PasswordExpired | ConvertTo-Bool) + $result.password_never_expires = (Get-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD) + $result.user_cannot_change_password = (Get-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE) + $result.account_disabled = $user_obj.AccountDisabled + $result.account_locked = $user_obj.IsAccountLocked + $result.sid = (New-Object System.Security.Principal.SecurityIdentifier($user_obj.ObjectSid.Value, 0)).Value + $user_groups = @() + ForEach ($grp in $user_obj.Groups()) { + $group_result = @{ + name = $grp.GetType().InvokeMember("Name", "GetProperty", $null, $grp, $null) + path = $grp.GetType().InvokeMember("ADsPath", "GetProperty", $null, $grp, $null) + } + $user_groups += $group_result; + } + $result.groups = $user_groups + $result.state = "present" + } + Else { + $result.name = $username + if ($state -eq 'query') { + $result.msg = "User '$username' was not found" + } + $result.state = "absent" + } +} +catch { + Fail-Json $result $_.Exception.Message +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_user.py b/test/support/windows-integration/plugins/modules/win_user.py new file mode 100644 index 0000000..5fc0633 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user.py @@ -0,0 +1,194 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2014, Matt Martz <matt@sivel.net>, and others +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_user +version_added: "1.7" +short_description: Manages local Windows user accounts +description: + - Manages local Windows user accounts. + - For non-Windows targets, use the M(user) module instead. +options: + name: + description: + - Name of the user to create, remove or modify. + type: str + required: yes + fullname: + description: + - Full name of the user. + type: str + version_added: "1.9" + description: + description: + - Description of the user. + type: str + version_added: "1.9" + password: + description: + - Optionally set the user's password to this (plain text) value. + type: str + update_password: + description: + - C(always) will update passwords if they differ. C(on_create) will + only set the password for newly created users. + type: str + choices: [ always, on_create ] + default: always + version_added: "1.9" + password_expired: + description: + - C(yes) will require the user to change their password at next login. + - C(no) will clear the expired password flag. + type: bool + version_added: "1.9" + password_never_expires: + description: + - C(yes) will set the password to never expire. + - C(no) will allow the password to expire. + type: bool + version_added: "1.9" + user_cannot_change_password: + description: + - C(yes) will prevent the user from changing their password. + - C(no) will allow the user to change their password. + type: bool + version_added: "1.9" + account_disabled: + description: + - C(yes) will disable the user account. + - C(no) will clear the disabled flag. + type: bool + version_added: "1.9" + account_locked: + description: + - C(no) will unlock the user account if locked. + choices: [ 'no' ] + version_added: "1.9" + groups: + description: + - Adds or removes the user from this comma-separated list of groups, + depending on the value of I(groups_action). + - When I(groups_action) is C(replace) and I(groups) is set to the empty + string ('groups='), the user is removed from all groups. + version_added: "1.9" + groups_action: + description: + - If C(add), the user is added to each group in I(groups) where not + already a member. + - If C(replace), the user is added as a member of each group in + I(groups) and removed from any other groups. + - If C(remove), the user is removed from each group in I(groups). + type: str + choices: [ add, replace, remove ] + default: replace + version_added: "1.9" + state: + description: + - When C(absent), removes the user account if it exists. + - When C(present), creates or updates the user account. + - When C(query) (new in 1.9), retrieves the user account details + without making any changes. + type: str + choices: [ absent, present, query ] + default: present +seealso: +- module: user +- module: win_domain_membership +- module: win_domain_user +- module: win_group +- module: win_group_membership +- module: win_user_profile +author: + - Paul Durivage (@angstwad) + - Chris Church (@cchurch) +''' + +EXAMPLES = r''' +- name: Ensure user bob is present + win_user: + name: bob + password: B0bP4ssw0rd + state: present + groups: + - Users + +- name: Ensure user bob is absent + win_user: + name: bob + state: absent +''' + +RETURN = r''' +account_disabled: + description: Whether the user is disabled. + returned: user exists + type: bool + sample: false +account_locked: + description: Whether the user is locked. + returned: user exists + type: bool + sample: false +description: + description: The description set for the user. + returned: user exists + type: str + sample: Username for test +fullname: + description: The full name set for the user. + returned: user exists + type: str + sample: Test Username +groups: + description: A list of groups and their ADSI path the user is a member of. + returned: user exists + type: list + sample: [ + { + "name": "Administrators", + "path": "WinNT://WORKGROUP/USER-PC/Administrators" + } + ] +name: + description: The name of the user + returned: always + type: str + sample: username +password_expired: + description: Whether the password is expired. + returned: user exists + type: bool + sample: false +password_never_expires: + description: Whether the password is set to never expire. + returned: user exists + type: bool + sample: true +path: + description: The ADSI path for the user. + returned: user exists + type: str + sample: "WinNT://WORKGROUP/USER-PC/username" +sid: + description: The SID for the user. + returned: user exists + type: str + sample: S-1-5-21-3322259488-2828151810-3939402796-1001 +user_cannot_change_password: + description: Whether the user can change their own password. + returned: user exists + type: bool + sample: false +''' diff --git a/test/support/windows-integration/plugins/modules/win_user_right.ps1 b/test/support/windows-integration/plugins/modules/win_user_right.ps1 new file mode 100644 index 0000000..3fac52a --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user_right.ps1 @@ -0,0 +1,349 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.SID + +$ErrorActionPreference = 'Stop' + +$params = Parse-Args $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$name = Get-AnsibleParam -obj $params -name "name" -type "str" -failifempty $true +$users = Get-AnsibleParam -obj $params -name "users" -type "list" -failifempty $true +$action = Get-AnsibleParam -obj $params -name "action" -type "str" -default "set" -validateset "add","remove","set" + +$result = @{ + changed = $false + added = @() + removed = @() +} + +if ($diff_mode) { + $result.diff = @{} +} + +$sec_helper_util = @" +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Security.Principal; + +namespace Ansible +{ + public class LsaRightHelper : IDisposable + { + // Code modified from https://gallery.technet.microsoft.com/scriptcenter/Grant-Revoke-Query-user-26e259b0 + + enum Access : int + { + POLICY_READ = 0x20006, + POLICY_ALL_ACCESS = 0x00F0FFF, + POLICY_EXECUTE = 0X20801, + POLICY_WRITE = 0X207F8 + } + + IntPtr lsaHandle; + + const string LSA_DLL = "advapi32.dll"; + const CharSet DEFAULT_CHAR_SET = CharSet.Unicode; + + const uint STATUS_NO_MORE_ENTRIES = 0x8000001a; + const uint STATUS_NO_SUCH_PRIVILEGE = 0xc0000060; + + internal sealed class Sid : IDisposable + { + public IntPtr pSid = IntPtr.Zero; + public SecurityIdentifier sid = null; + + public Sid(string sidString) + { + try + { + sid = new SecurityIdentifier(sidString); + } catch + { + throw new ArgumentException(String.Format("SID string {0} could not be converted to SecurityIdentifier", sidString)); + } + + Byte[] buffer = new Byte[sid.BinaryLength]; + sid.GetBinaryForm(buffer, 0); + + pSid = Marshal.AllocHGlobal(sid.BinaryLength); + Marshal.Copy(buffer, 0, pSid, sid.BinaryLength); + } + + public void Dispose() + { + if (pSid != IntPtr.Zero) + { + Marshal.FreeHGlobal(pSid); + pSid = IntPtr.Zero; + } + GC.SuppressFinalize(this); + } + ~Sid() { Dispose(); } + } + + [StructLayout(LayoutKind.Sequential)] + private struct LSA_OBJECT_ATTRIBUTES + { + public int Length; + public IntPtr RootDirectory; + public IntPtr ObjectName; + public int Attributes; + public IntPtr SecurityDescriptor; + public IntPtr SecurityQualityOfService; + } + + [StructLayout(LayoutKind.Sequential, CharSet = DEFAULT_CHAR_SET)] + private struct LSA_UNICODE_STRING + { + public ushort Length; + public ushort MaximumLength; + [MarshalAs(UnmanagedType.LPWStr)] + public string Buffer; + } + + [StructLayout(LayoutKind.Sequential)] + private struct LSA_ENUMERATION_INFORMATION + { + public IntPtr Sid; + } + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaOpenPolicy( + LSA_UNICODE_STRING[] SystemName, + ref LSA_OBJECT_ATTRIBUTES ObjectAttributes, + int AccessMask, + out IntPtr PolicyHandle + ); + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaAddAccountRights( + IntPtr PolicyHandle, + IntPtr pSID, + LSA_UNICODE_STRING[] UserRights, + int CountOfRights + ); + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaRemoveAccountRights( + IntPtr PolicyHandle, + IntPtr pSID, + bool AllRights, + LSA_UNICODE_STRING[] UserRights, + int CountOfRights + ); + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaEnumerateAccountsWithUserRight( + IntPtr PolicyHandle, + LSA_UNICODE_STRING[] UserRights, + out IntPtr EnumerationBuffer, + out ulong CountReturned + ); + + [DllImport(LSA_DLL)] + private static extern int LsaNtStatusToWinError(int NTSTATUS); + + [DllImport(LSA_DLL)] + private static extern int LsaClose(IntPtr PolicyHandle); + + [DllImport(LSA_DLL)] + private static extern int LsaFreeMemory(IntPtr Buffer); + + public LsaRightHelper() + { + LSA_OBJECT_ATTRIBUTES lsaAttr; + lsaAttr.RootDirectory = IntPtr.Zero; + lsaAttr.ObjectName = IntPtr.Zero; + lsaAttr.Attributes = 0; + lsaAttr.SecurityDescriptor = IntPtr.Zero; + lsaAttr.SecurityQualityOfService = IntPtr.Zero; + lsaAttr.Length = Marshal.SizeOf(typeof(LSA_OBJECT_ATTRIBUTES)); + + lsaHandle = IntPtr.Zero; + + LSA_UNICODE_STRING[] system = new LSA_UNICODE_STRING[1]; + system[0] = InitLsaString(""); + + uint ret = LsaOpenPolicy(system, ref lsaAttr, (int)Access.POLICY_ALL_ACCESS, out lsaHandle); + if (ret != 0) + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + + public void AddPrivilege(string sidString, string privilege) + { + uint ret = 0; + using (Sid sid = new Sid(sidString)) + { + LSA_UNICODE_STRING[] privileges = new LSA_UNICODE_STRING[1]; + privileges[0] = InitLsaString(privilege); + ret = LsaAddAccountRights(lsaHandle, sid.pSid, privileges, 1); + } + if (ret != 0) + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + + public void RemovePrivilege(string sidString, string privilege) + { + uint ret = 0; + using (Sid sid = new Sid(sidString)) + { + LSA_UNICODE_STRING[] privileges = new LSA_UNICODE_STRING[1]; + privileges[0] = InitLsaString(privilege); + ret = LsaRemoveAccountRights(lsaHandle, sid.pSid, false, privileges, 1); + } + if (ret != 0) + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + + public string[] EnumerateAccountsWithUserRight(string privilege) + { + uint ret = 0; + ulong count = 0; + LSA_UNICODE_STRING[] rights = new LSA_UNICODE_STRING[1]; + rights[0] = InitLsaString(privilege); + IntPtr buffer = IntPtr.Zero; + + ret = LsaEnumerateAccountsWithUserRight(lsaHandle, rights, out buffer, out count); + switch (ret) + { + case 0: + string[] accounts = new string[count]; + for (int i = 0; i < (int)count; i++) + { + LSA_ENUMERATION_INFORMATION LsaInfo = (LSA_ENUMERATION_INFORMATION)Marshal.PtrToStructure( + IntPtr.Add(buffer, i * Marshal.SizeOf(typeof(LSA_ENUMERATION_INFORMATION))), + typeof(LSA_ENUMERATION_INFORMATION)); + + accounts[i] = new SecurityIdentifier(LsaInfo.Sid).ToString(); + } + LsaFreeMemory(buffer); + return accounts; + + case STATUS_NO_MORE_ENTRIES: + return new string[0]; + + case STATUS_NO_SUCH_PRIVILEGE: + throw new ArgumentException(String.Format("Invalid privilege {0} not found in LSA database", privilege)); + + default: + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + } + + static LSA_UNICODE_STRING InitLsaString(string s) + { + // Unicode strings max. 32KB + if (s.Length > 0x7ffe) + throw new ArgumentException("String too long"); + + LSA_UNICODE_STRING lus = new LSA_UNICODE_STRING(); + lus.Buffer = s; + lus.Length = (ushort)(s.Length * sizeof(char)); + lus.MaximumLength = (ushort)(lus.Length + sizeof(char)); + + return lus; + } + + public void Dispose() + { + if (lsaHandle != IntPtr.Zero) + { + LsaClose(lsaHandle); + lsaHandle = IntPtr.Zero; + } + GC.SuppressFinalize(this); + } + ~LsaRightHelper() { Dispose(); } + } +} +"@ + +$original_tmp = $env:TMP +$env:TMP = $_remote_tmp +Add-Type -TypeDefinition $sec_helper_util +$env:TMP = $original_tmp + +Function Compare-UserList($existing_users, $new_users) { + $added_users = [String[]]@() + $removed_users = [String[]]@() + if ($action -eq "add") { + $added_users = [Linq.Enumerable]::Except($new_users, $existing_users) + } elseif ($action -eq "remove") { + $removed_users = [Linq.Enumerable]::Intersect($new_users, $existing_users) + } else { + $added_users = [Linq.Enumerable]::Except($new_users, $existing_users) + $removed_users = [Linq.Enumerable]::Except($existing_users, $new_users) + } + + $change_result = @{ + added = $added_users + removed = $removed_users + } + + return $change_result +} + +# C# class we can use to enumerate/add/remove rights +$lsa_helper = New-Object -TypeName Ansible.LsaRightHelper + +$new_users = [System.Collections.ArrayList]@() +foreach ($user in $users) { + $user_sid = Convert-ToSID -account_name $user + $new_users.Add($user_sid) > $null +} +$new_users = [String[]]$new_users.ToArray() +try { + $existing_users = $lsa_helper.EnumerateAccountsWithUserRight($name) +} catch [ArgumentException] { + Fail-Json -obj $result -message "the specified right $name is not a valid right" +} catch { + Fail-Json -obj $result -message "failed to enumerate existing accounts with right: $($_.Exception.Message)" +} + +$change_result = Compare-UserList -existing_users $existing_users -new_user $new_users +if (($change_result.added.Length -gt 0) -or ($change_result.removed.Length -gt 0)) { + $result.changed = $true + $diff_text = "[$name]`n" + + # used in diff mode calculation + $new_user_list = [System.Collections.ArrayList]$existing_users + foreach ($user in $change_result.removed) { + if (-not $check_mode) { + $lsa_helper.RemovePrivilege($user, $name) + } + $user_name = Convert-FromSID -sid $user + $result.removed += $user_name + $diff_text += "-$user_name`n" + $new_user_list.Remove($user) > $null + } + foreach ($user in $change_result.added) { + if (-not $check_mode) { + $lsa_helper.AddPrivilege($user, $name) + } + $user_name = Convert-FromSID -sid $user + $result.added += $user_name + $diff_text += "+$user_name`n" + $new_user_list.Add($user) > $null + } + + if ($diff_mode) { + if ($new_user_list.Count -eq 0) { + $diff_text = "-$diff_text" + } else { + if ($existing_users.Count -eq 0) { + $diff_text = "+$diff_text" + } + } + $result.diff.prepared = $diff_text + } +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_user_right.py b/test/support/windows-integration/plugins/modules/win_user_right.py new file mode 100644 index 0000000..5588208 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user_right.py @@ -0,0 +1,108 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_user_right +version_added: '2.4' +short_description: Manage Windows User Rights +description: +- Add, remove or set User Rights for a group or users or groups. +- You can set user rights for both local and domain accounts. +options: + name: + description: + - The name of the User Right as shown by the C(Constant Name) value from + U(https://technet.microsoft.com/en-us/library/dd349804.aspx). + - The module will return an error if the right is invalid. + type: str + required: yes + users: + description: + - A list of users or groups to add/remove on the User Right. + - These can be in the form DOMAIN\user-group, user-group@DOMAIN.COM for + domain users/groups. + - For local users/groups it can be in the form user-group, .\user-group, + SERVERNAME\user-group where SERVERNAME is the name of the remote server. + - You can also add special local accounts like SYSTEM and others. + - Can be set to an empty list with I(action=set) to remove all accounts + from the right. + type: list + required: yes + action: + description: + - C(add) will add the users/groups to the existing right. + - C(remove) will remove the users/groups from the existing right. + - C(set) will replace the users/groups of the existing right. + type: str + default: set + choices: [ add, remove, set ] +notes: +- If the server is domain joined this module can change a right but if a GPO + governs this right then the changes won't last. +seealso: +- module: win_group +- module: win_group_membership +- module: win_user +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +--- +- name: Replace the entries of Deny log on locally + win_user_right: + name: SeDenyInteractiveLogonRight + users: + - Guest + - Users + action: set + +- name: Add account to Log on as a service + win_user_right: + name: SeServiceLogonRight + users: + - .\Administrator + - '{{ansible_hostname}}\local-user' + action: add + +- name: Remove accounts who can create Symbolic links + win_user_right: + name: SeCreateSymbolicLinkPrivilege + users: + - SYSTEM + - Administrators + - DOMAIN\User + - group@DOMAIN.COM + action: remove + +- name: Remove all accounts who cannot log on remote interactively + win_user_right: + name: SeDenyRemoteInteractiveLogonRight + users: [] +''' + +RETURN = r''' +added: + description: A list of accounts that were added to the right, this is empty + if no accounts were added. + returned: success + type: list + sample: ["NT AUTHORITY\\SYSTEM", "DOMAIN\\User"] +removed: + description: A list of accounts that were removed from the right, this is + empty if no accounts were removed. + returned: success + type: list + sample: ["SERVERNAME\\Administrator", "BUILTIN\\Administrators"] +''' diff --git a/test/support/windows-integration/plugins/modules/win_wait_for.ps1 b/test/support/windows-integration/plugins/modules/win_wait_for.ps1 new file mode 100644 index 0000000..e0a9a72 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_wait_for.ps1 @@ -0,0 +1,259 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.FileUtil + +$ErrorActionPreference = "Stop" + +$params = Parse-Args -arguments $args -supports_check_mode $true + +$connect_timeout = Get-AnsibleParam -obj $params -name "connect_timeout" -type "int" -default 5 +$delay = Get-AnsibleParam -obj $params -name "delay" -type "int" +$exclude_hosts = Get-AnsibleParam -obj $params -name "exclude_hosts" -type "list" +$hostname = Get-AnsibleParam -obj $params -name "host" -type "str" -default "127.0.0.1" +$path = Get-AnsibleParam -obj $params -name "path" -type "path" +$port = Get-AnsibleParam -obj $params -name "port" -type "int" +$regex = Get-AnsibleParam -obj $params -name "regex" -type "str" -aliases "search_regex","regexp" +$sleep = Get-AnsibleParam -obj $params -name "sleep" -type "int" -default 1 +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "started" -validateset "present","started","stopped","absent","drained" +$timeout = Get-AnsibleParam -obj $params -name "timeout" -type "int" -default 300 + +$result = @{ + changed = $false + elapsed = 0 +} + +# validate the input with the various options +if ($null -ne $port -and $null -ne $path) { + Fail-Json $result "port and path parameter can not both be passed to win_wait_for" +} +if ($null -ne $exclude_hosts -and $state -ne "drained") { + Fail-Json $result "exclude_hosts should only be with state=drained" +} +if ($null -ne $path) { + if ($state -in @("stopped","drained")) { + Fail-Json $result "state=$state should only be used for checking a port in the win_wait_for module" + } + + if ($null -ne $exclude_hosts) { + Fail-Json $result "exclude_hosts should only be used when checking a port and state=drained in the win_wait_for module" + } +} + +if ($null -ne $port) { + if ($null -ne $regex) { + Fail-Json $result "regex should by used when checking a string in a file in the win_wait_for module" + } + + if ($null -ne $exclude_hosts -and $state -ne "drained") { + Fail-Json $result "exclude_hosts should be used when state=drained in the win_wait_for module" + } +} + +Function Test-Port($hostname, $port) { + $timeout = $connect_timeout * 1000 + $socket = New-Object -TypeName System.Net.Sockets.TcpClient + $connect = $socket.BeginConnect($hostname, $port, $null, $null) + $wait = $connect.AsyncWaitHandle.WaitOne($timeout, $false) + + if ($wait) { + try { + $socket.EndConnect($connect) | Out-Null + $valid = $true + } catch { + $valid = $false + } + } else { + $valid = $false + } + + $socket.Close() + $socket.Dispose() + + $valid +} + +Function Get-PortConnections($hostname, $port) { + $connections = @() + + $conn_info = [Net.NetworkInformation.IPGlobalProperties]::GetIPGlobalProperties() + if ($hostname -eq "0.0.0.0") { + $active_connections = $conn_info.GetActiveTcpConnections() | Where-Object { $_.LocalEndPoint.Port -eq $port } + } else { + $active_connections = $conn_info.GetActiveTcpConnections() | Where-Object { $_.LocalEndPoint.Address -eq $hostname -and $_.LocalEndPoint.Port -eq $port } + } + + if ($null -ne $active_connections) { + foreach ($active_connection in $active_connections) { + $connections += $active_connection.RemoteEndPoint.Address + } + } + + $connections +} + +$module_start = Get-Date + +if ($null -ne $delay) { + Start-Sleep -Seconds $delay +} + +$attempts = 0 +if ($null -eq $path -and $null -eq $port -and $state -ne "drained") { + Start-Sleep -Seconds $timeout +} elseif ($null -ne $path) { + if ($state -in @("present", "started")) { + # check if the file exists or string exists in file + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + if (Test-AnsiblePath -Path $path) { + if ($null -eq $regex) { + $complete = $true + break + } else { + $file_contents = Get-Content -Path $path -Raw + if ($file_contents -match $regex) { + $complete = $true + break + } + } + } + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + if ($null -eq $regex) { + Fail-Json $result "timeout while waiting for file $path to be present" + } else { + Fail-Json $result "timeout while waiting for string regex $regex in file $path to match" + } + } + } elseif ($state -in @("absent")) { + # check if the file is deleted or string doesn't exist in file + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + if (Test-AnsiblePath -Path $path) { + if ($null -ne $regex) { + $file_contents = Get-Content -Path $path -Raw + if ($file_contents -notmatch $regex) { + $complete = $true + break + } + } + } else { + $complete = $true + break + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + if ($null -eq $regex) { + Fail-Json $result "timeout while waiting for file $path to be absent" + } else { + Fail-Json $result "timeout while waiting for string regex $regex in file $path to not match" + } + } + } +} elseif ($null -ne $port) { + if ($state -in @("started","present")) { + # check that the port is online and is listening + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + $port_result = Test-Port -hostname $hostname -port $port + if ($port_result -eq $true) { + $complete = $true + break + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + Fail-Json $result "timeout while waiting for $($hostname):$port to start listening" + } + } elseif ($state -in @("stopped","absent")) { + # check that the port is offline and is not listening + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + $port_result = Test-Port -hostname $hostname -port $port + if ($port_result -eq $false) { + $complete = $true + break + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + Fail-Json $result "timeout while waiting for $($hostname):$port to stop listening" + } + } elseif ($state -eq "drained") { + # check that the local port is online but has no active connections + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + $active_connections = Get-PortConnections -hostname $hostname -port $port + if ($null -eq $active_connections) { + $complete = $true + break + } elseif ($active_connections.Count -eq 0) { + # no connections on port + $complete = $true + break + } else { + # there are listeners, check if we should ignore any hosts + if ($null -ne $exclude_hosts) { + $connection_info = $active_connections + foreach ($exclude_host in $exclude_hosts) { + try { + $exclude_ips = [System.Net.Dns]::GetHostAddresses($exclude_host) | ForEach-Object { Write-Output $_.IPAddressToString } + $connection_info = $connection_info | Where-Object { $_ -notin $exclude_ips } + } catch { # ignore invalid hostnames + Add-Warning -obj $result -message "Invalid hostname specified $exclude_host" + } + } + + if ($connection_info.Count -eq 0) { + $complete = $true + break + } + } + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + Fail-Json $result "timeout while waiting for $($hostname):$port to drain" + } + } +} + +$result.elapsed = ((Get-Date) - $module_start).TotalSeconds +$result.wait_attempts = $attempts + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_wait_for.py b/test/support/windows-integration/plugins/modules/win_wait_for.py new file mode 100644 index 0000000..85721e7 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_wait_for.py @@ -0,0 +1,155 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub, actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_wait_for +version_added: '2.4' +short_description: Waits for a condition before continuing +description: +- You can wait for a set amount of time C(timeout), this is the default if + nothing is specified. +- Waiting for a port to become available is useful for when services are not + immediately available after their init scripts return which is true of + certain Java application servers. +- You can wait for a file to exist or not exist on the filesystem. +- This module can also be used to wait for a regex match string to be present + in a file. +- You can wait for active connections to be closed before continuing on a + local port. +options: + connect_timeout: + description: + - The maximum number of seconds to wait for a connection to happen before + closing and retrying. + type: int + default: 5 + delay: + description: + - The number of seconds to wait before starting to poll. + type: int + exclude_hosts: + description: + - The list of hosts or IPs to ignore when looking for active TCP + connections when C(state=drained). + type: list + host: + description: + - A resolvable hostname or IP address to wait for. + - If C(state=drained) then it will only check for connections on the IP + specified, you can use '0.0.0.0' to use all host IPs. + type: str + default: '127.0.0.1' + path: + description: + - The path to a file on the filesystem to check. + - If C(state) is present or started then it will wait until the file + exists. + - If C(state) is absent then it will wait until the file does not exist. + type: path + port: + description: + - The port number to poll on C(host). + type: int + regex: + description: + - Can be used to match a string in a file. + - If C(state) is present or started then it will wait until the regex + matches. + - If C(state) is absent then it will wait until the regex does not match. + - Defaults to a multiline regex. + type: str + aliases: [ "search_regex", "regexp" ] + sleep: + description: + - Number of seconds to sleep between checks. + type: int + default: 1 + state: + description: + - When checking a port, C(started) will ensure the port is open, C(stopped) + will check that is it closed and C(drained) will check for active + connections. + - When checking for a file or a search string C(present) or C(started) will + ensure that the file or string is present, C(absent) will check that the + file or search string is absent or removed. + type: str + choices: [ absent, drained, present, started, stopped ] + default: started + timeout: + description: + - The maximum number of seconds to wait for. + type: int + default: 300 +seealso: +- module: wait_for +- module: win_wait_for_process +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Wait 300 seconds for port 8000 to become open on the host, don't start checking for 10 seconds + win_wait_for: + port: 8000 + delay: 10 + +- name: Wait 150 seconds for port 8000 of any IP to close active connections + win_wait_for: + host: 0.0.0.0 + port: 8000 + state: drained + timeout: 150 + +- name: Wait for port 8000 of any IP to close active connection, ignoring certain hosts + win_wait_for: + host: 0.0.0.0 + port: 8000 + state: drained + exclude_hosts: ['10.2.1.2', '10.2.1.3'] + +- name: Wait for file C:\temp\log.txt to exist before continuing + win_wait_for: + path: C:\temp\log.txt + +- name: Wait until process complete is in the file before continuing + win_wait_for: + path: C:\temp\log.txt + regex: process complete + +- name: Wait until file is removed + win_wait_for: + path: C:\temp\log.txt + state: absent + +- name: Wait until port 1234 is offline but try every 10 seconds + win_wait_for: + port: 1234 + state: absent + sleep: 10 +''' + +RETURN = r''' +wait_attempts: + description: The number of attempts to poll the file or port before module + finishes. + returned: always + type: int + sample: 1 +elapsed: + description: The elapsed seconds between the start of poll and the end of the + module. This includes the delay if the option is set. + returned: always + type: float + sample: 2.1406487 +''' diff --git a/test/support/windows-integration/plugins/modules/win_whoami.ps1 b/test/support/windows-integration/plugins/modules/win_whoami.ps1 new file mode 100644 index 0000000..6c9965a --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_whoami.ps1 @@ -0,0 +1,837 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.CamelConversion + +$ErrorActionPreference = "Stop" + +$params = Parse-Args $args -supports_check_mode $true +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$session_util = @' +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; + +namespace Ansible +{ + public class SessionInfo + { + // SECURITY_LOGON_SESSION_DATA + public UInt64 LogonId { get; internal set; } + public Sid Account { get; internal set; } + public string LoginDomain { get; internal set; } + public string AuthenticationPackage { get; internal set; } + public SECURITY_LOGON_TYPE LogonType { get; internal set; } + public string LoginTime { get; internal set; } + public string LogonServer { get; internal set; } + public string DnsDomainName { get; internal set; } + public string Upn { get; internal set; } + public ArrayList UserFlags { get; internal set; } + + // TOKEN_STATISTICS + public SECURITY_IMPERSONATION_LEVEL ImpersonationLevel { get; internal set; } + public TOKEN_TYPE TokenType { get; internal set; } + + // TOKEN_GROUPS + public ArrayList Groups { get; internal set; } + public ArrayList Rights { get; internal set; } + + // TOKEN_MANDATORY_LABEL + public Sid Label { get; internal set; } + + // TOKEN_PRIVILEGES + public Hashtable Privileges { get; internal set; } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct LSA_UNICODE_STRING + { + public UInt16 Length; + public UInt16 MaximumLength; + public IntPtr buffer; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LUID + { + public UInt32 LowPart; + public Int32 HighPart; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SECURITY_LOGON_SESSION_DATA + { + public UInt32 Size; + public LUID LogonId; + public LSA_UNICODE_STRING Username; + public LSA_UNICODE_STRING LoginDomain; + public LSA_UNICODE_STRING AuthenticationPackage; + public SECURITY_LOGON_TYPE LogonType; + public UInt32 Session; + public IntPtr Sid; + public UInt64 LoginTime; + public LSA_UNICODE_STRING LogonServer; + public LSA_UNICODE_STRING DnsDomainName; + public LSA_UNICODE_STRING Upn; + public UInt32 UserFlags; + public LSA_LAST_INTER_LOGON_INFO LastLogonInfo; + public LSA_UNICODE_STRING LogonScript; + public LSA_UNICODE_STRING ProfilePath; + public LSA_UNICODE_STRING HomeDirectory; + public LSA_UNICODE_STRING HomeDirectoryDrive; + public UInt64 LogoffTime; + public UInt64 KickOffTime; + public UInt64 PasswordLastSet; + public UInt64 PasswordCanChange; + public UInt64 PasswordMustChange; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LSA_LAST_INTER_LOGON_INFO + { + public UInt64 LastSuccessfulLogon; + public UInt64 LastFailedLogon; + public UInt32 FailedAttemptCountSinceLastSuccessfulLogon; + } + + public enum TOKEN_TYPE + { + TokenPrimary = 1, + TokenImpersonation + } + + public enum SECURITY_IMPERSONATION_LEVEL + { + SecurityAnonymous, + SecurityIdentification, + SecurityImpersonation, + SecurityDelegation + } + + public enum SECURITY_LOGON_TYPE + { + System = 0, // Used only by the Sytem account + Interactive = 2, + Network, + Batch, + Service, + Proxy, + Unlock, + NetworkCleartext, + NewCredentials, + RemoteInteractive, + CachedInteractive, + CachedRemoteInteractive, + CachedUnlock + } + + [Flags] + public enum TokenGroupAttributes : uint + { + SE_GROUP_ENABLED = 0x00000004, + SE_GROUP_ENABLED_BY_DEFAULT = 0x00000002, + SE_GROUP_INTEGRITY = 0x00000020, + SE_GROUP_INTEGRITY_ENABLED = 0x00000040, + SE_GROUP_LOGON_ID = 0xC0000000, + SE_GROUP_MANDATORY = 0x00000001, + SE_GROUP_OWNER = 0x00000008, + SE_GROUP_RESOURCE = 0x20000000, + SE_GROUP_USE_FOR_DENY_ONLY = 0x00000010, + } + + [Flags] + public enum UserFlags : uint + { + LOGON_OPTIMIZED = 0x4000, + LOGON_WINLOGON = 0x8000, + LOGON_PKINIT = 0x10000, + LOGON_NOT_OPTMIZED = 0x20000, + } + + [StructLayout(LayoutKind.Sequential)] + public struct SID_AND_ATTRIBUTES + { + public IntPtr Sid; + public UInt32 Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LUID_AND_ATTRIBUTES + { + public LUID Luid; + public UInt32 Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_GROUPS + { + public UInt32 GroupCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public SID_AND_ATTRIBUTES[] Groups; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_MANDATORY_LABEL + { + public SID_AND_ATTRIBUTES Label; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_STATISTICS + { + public LUID TokenId; + public LUID AuthenticationId; + public UInt64 ExpirationTime; + public TOKEN_TYPE TokenType; + public SECURITY_IMPERSONATION_LEVEL ImpersonationLevel; + public UInt32 DynamicCharged; + public UInt32 DynamicAvailable; + public UInt32 GroupCount; + public UInt32 PrivilegeCount; + public LUID ModifiedId; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_PRIVILEGES + { + public UInt32 PrivilegeCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public LUID_AND_ATTRIBUTES[] Privileges; + } + + public class AccessToken : IDisposable + { + public enum TOKEN_INFORMATION_CLASS + { + TokenUser = 1, + TokenGroups, + TokenPrivileges, + TokenOwner, + TokenPrimaryGroup, + TokenDefaultDacl, + TokenSource, + TokenType, + TokenImpersonationLevel, + TokenStatistics, + TokenRestrictedSids, + TokenSessionId, + TokenGroupsAndPrivileges, + TokenSessionReference, + TokenSandBoxInert, + TokenAuditPolicy, + TokenOrigin, + TokenElevationType, + TokenLinkedToken, + TokenElevation, + TokenHasRestrictions, + TokenAccessInformation, + TokenVirtualizationAllowed, + TokenVirtualizationEnabled, + TokenIntegrityLevel, + TokenUIAccess, + TokenMandatoryPolicy, + TokenLogonSid, + TokenIsAppContainer, + TokenCapabilities, + TokenAppContainerSid, + TokenAppContainerNumber, + TokenUserClaimAttributes, + TokenDeviceClaimAttributes, + TokenRestrictedUserClaimAttributes, + TokenRestrictedDeviceClaimAttributes, + TokenDeviceGroups, + TokenRestrictedDeviceGroups, + TokenSecurityAttributes, + TokenIsRestricted, + MaxTokenInfoClass + } + + public IntPtr hToken = IntPtr.Zero; + + [DllImport("kernel32.dll")] + private static extern IntPtr GetCurrentProcess(); + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern bool OpenProcessToken( + IntPtr ProcessHandle, + TokenAccessLevels DesiredAccess, + out IntPtr TokenHandle); + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern bool GetTokenInformation( + IntPtr TokenHandle, + TOKEN_INFORMATION_CLASS TokenInformationClass, + IntPtr TokenInformation, + UInt32 TokenInformationLength, + out UInt32 ReturnLength); + + public AccessToken(TokenAccessLevels tokenAccessLevels) + { + IntPtr currentProcess = GetCurrentProcess(); + if (!OpenProcessToken(currentProcess, tokenAccessLevels, out hToken)) + throw new Win32Exception("OpenProcessToken() for current process failed"); + } + + public IntPtr GetTokenInformation<T>(out T tokenInformation, TOKEN_INFORMATION_CLASS tokenClass) + { + UInt32 tokenLength = 0; + GetTokenInformation(hToken, tokenClass, IntPtr.Zero, 0, out tokenLength); + + IntPtr infoPtr = Marshal.AllocHGlobal((int)tokenLength); + + if (!GetTokenInformation(hToken, tokenClass, infoPtr, tokenLength, out tokenLength)) + throw new Win32Exception(String.Format("GetTokenInformation() data for {0} failed", tokenClass.ToString())); + + tokenInformation = (T)Marshal.PtrToStructure(infoPtr, typeof(T)); + return infoPtr; + } + + public void Dispose() + { + GC.SuppressFinalize(this); + } + + ~AccessToken() { Dispose(); } + } + + public class LsaHandle : IDisposable + { + [Flags] + public enum DesiredAccess : uint + { + POLICY_VIEW_LOCAL_INFORMATION = 0x00000001, + POLICY_VIEW_AUDIT_INFORMATION = 0x00000002, + POLICY_GET_PRIVATE_INFORMATION = 0x00000004, + POLICY_TRUST_ADMIN = 0x00000008, + POLICY_CREATE_ACCOUNT = 0x00000010, + POLICY_CREATE_SECRET = 0x00000020, + POLICY_CREATE_PRIVILEGE = 0x00000040, + POLICY_SET_DEFAULT_QUOTA_LIMITS = 0x00000080, + POLICY_SET_AUDIT_REQUIREMENTS = 0x00000100, + POLICY_AUDIT_LOG_ADMIN = 0x00000200, + POLICY_SERVER_ADMIN = 0x00000400, + POLICY_LOOKUP_NAMES = 0x00000800, + POLICY_NOTIFICATION = 0x00001000 + } + + public IntPtr handle = IntPtr.Zero; + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern uint LsaOpenPolicy( + LSA_UNICODE_STRING[] SystemName, + ref LSA_OBJECT_ATTRIBUTES ObjectAttributes, + DesiredAccess AccessMask, + out IntPtr PolicyHandle); + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern uint LsaClose( + IntPtr ObjectHandle); + + [DllImport("advapi32.dll", SetLastError = false)] + private static extern int LsaNtStatusToWinError( + uint Status); + + [StructLayout(LayoutKind.Sequential)] + public struct LSA_OBJECT_ATTRIBUTES + { + public int Length; + public IntPtr RootDirectory; + public IntPtr ObjectName; + public int Attributes; + public IntPtr SecurityDescriptor; + public IntPtr SecurityQualityOfService; + } + + public LsaHandle(DesiredAccess desiredAccess) + { + LSA_OBJECT_ATTRIBUTES lsaAttr; + lsaAttr.RootDirectory = IntPtr.Zero; + lsaAttr.ObjectName = IntPtr.Zero; + lsaAttr.Attributes = 0; + lsaAttr.SecurityDescriptor = IntPtr.Zero; + lsaAttr.SecurityQualityOfService = IntPtr.Zero; + lsaAttr.Length = Marshal.SizeOf(typeof(LSA_OBJECT_ATTRIBUTES)); + LSA_UNICODE_STRING[] system = new LSA_UNICODE_STRING[1]; + system[0].buffer = IntPtr.Zero; + + uint res = LsaOpenPolicy(system, ref lsaAttr, desiredAccess, out handle); + if (res != 0) + throw new Win32Exception(LsaNtStatusToWinError(res), "LsaOpenPolicy() failed"); + } + + public void Dispose() + { + if (handle != IntPtr.Zero) + { + LsaClose(handle); + handle = IntPtr.Zero; + } + GC.SuppressFinalize(this); + } + + ~LsaHandle() { Dispose(); } + } + + public class Sid + { + public string SidString { get; internal set; } + public string DomainName { get; internal set; } + public string AccountName { get; internal set; } + public SID_NAME_USE SidType { get; internal set; } + + public enum SID_NAME_USE + { + SidTypeUser = 1, + SidTypeGroup, + SidTypeDomain, + SidTypeAlias, + SidTypeWellKnownGroup, + SidTypeDeletedAccount, + SidTypeInvalid, + SidTypeUnknown, + SidTypeComputer, + SidTypeLabel, + SidTypeLogon, + } + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + private static extern bool LookupAccountSid( + string lpSystemName, + [MarshalAs(UnmanagedType.LPArray)] + byte[] Sid, + StringBuilder lpName, + ref UInt32 cchName, + StringBuilder ReferencedDomainName, + ref UInt32 cchReferencedDomainName, + out SID_NAME_USE peUse); + + public Sid(IntPtr sidPtr) + { + SecurityIdentifier sid; + try + { + sid = new SecurityIdentifier(sidPtr); + } + catch (Exception e) + { + throw new ArgumentException(String.Format("Failed to cast IntPtr to SecurityIdentifier: {0}", e)); + } + + SetSidInfo(sid); + } + + public Sid(SecurityIdentifier sid) + { + SetSidInfo(sid); + } + + public override string ToString() + { + return SidString; + } + + private void SetSidInfo(SecurityIdentifier sid) + { + byte[] sidBytes = new byte[sid.BinaryLength]; + sid.GetBinaryForm(sidBytes, 0); + + StringBuilder lpName = new StringBuilder(); + UInt32 cchName = 0; + StringBuilder referencedDomainName = new StringBuilder(); + UInt32 cchReferencedDomainName = 0; + SID_NAME_USE peUse; + LookupAccountSid(null, sidBytes, lpName, ref cchName, referencedDomainName, ref cchReferencedDomainName, out peUse); + + lpName.EnsureCapacity((int)cchName); + referencedDomainName.EnsureCapacity((int)cchReferencedDomainName); + + SidString = sid.ToString(); + if (!LookupAccountSid(null, sidBytes, lpName, ref cchName, referencedDomainName, ref cchReferencedDomainName, out peUse)) + { + int lastError = Marshal.GetLastWin32Error(); + + if (lastError != 1332 && lastError != 1789) // Fails to lookup Logon Sid + { + throw new Win32Exception(lastError, String.Format("LookupAccountSid() failed for SID: {0} {1}", sid.ToString(), lastError)); + } + else if (SidString.StartsWith("S-1-5-5-")) + { + AccountName = String.Format("LogonSessionId_{0}", SidString.Substring(8)); + DomainName = "NT AUTHORITY"; + SidType = SID_NAME_USE.SidTypeLogon; + } + else + { + AccountName = null; + DomainName = null; + SidType = SID_NAME_USE.SidTypeUnknown; + } + } + else + { + AccountName = lpName.ToString(); + DomainName = referencedDomainName.ToString(); + SidType = peUse; + } + } + } + + public class SessionUtil + { + [DllImport("secur32.dll", SetLastError = false)] + private static extern uint LsaFreeReturnBuffer( + IntPtr Buffer); + + [DllImport("secur32.dll", SetLastError = false)] + private static extern uint LsaEnumerateLogonSessions( + out UInt64 LogonSessionCount, + out IntPtr LogonSessionList); + + [DllImport("secur32.dll", SetLastError = false)] + private static extern uint LsaGetLogonSessionData( + IntPtr LogonId, + out IntPtr ppLogonSessionData); + + [DllImport("advapi32.dll", SetLastError = false)] + private static extern int LsaNtStatusToWinError( + uint Status); + + [DllImport("advapi32", SetLastError = true)] + private static extern uint LsaEnumerateAccountRights( + IntPtr PolicyHandle, + IntPtr AccountSid, + out IntPtr UserRights, + out UInt64 CountOfRights); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + private static extern bool LookupPrivilegeName( + string lpSystemName, + ref LUID lpLuid, + StringBuilder lpName, + ref UInt32 cchName); + + private const UInt32 SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001; + private const UInt32 SE_PRIVILEGE_ENABLED = 0x00000002; + private const UInt32 STATUS_OBJECT_NAME_NOT_FOUND = 0xC0000034; + private const UInt32 STATUS_ACCESS_DENIED = 0xC0000022; + + public static SessionInfo GetSessionInfo() + { + AccessToken accessToken = new AccessToken(TokenAccessLevels.Query); + + // Get Privileges + Hashtable privilegeInfo = new Hashtable(); + TOKEN_PRIVILEGES privileges; + IntPtr privilegesPtr = accessToken.GetTokenInformation(out privileges, AccessToken.TOKEN_INFORMATION_CLASS.TokenPrivileges); + LUID_AND_ATTRIBUTES[] luidAndAttributes = new LUID_AND_ATTRIBUTES[privileges.PrivilegeCount]; + try + { + PtrToStructureArray(luidAndAttributes, privilegesPtr.ToInt64() + Marshal.SizeOf(privileges.PrivilegeCount)); + } + finally + { + Marshal.FreeHGlobal(privilegesPtr); + } + foreach (LUID_AND_ATTRIBUTES luidAndAttribute in luidAndAttributes) + { + LUID privLuid = luidAndAttribute.Luid; + UInt32 privNameLen = 0; + StringBuilder privName = new StringBuilder(); + LookupPrivilegeName(null, ref privLuid, null, ref privNameLen); + privName.EnsureCapacity((int)(privNameLen + 1)); + if (!LookupPrivilegeName(null, ref privLuid, privName, ref privNameLen)) + throw new Win32Exception("LookupPrivilegeName() failed"); + + string state = "disabled"; + if ((luidAndAttribute.Attributes & SE_PRIVILEGE_ENABLED) == SE_PRIVILEGE_ENABLED) + state = "enabled"; + if ((luidAndAttribute.Attributes & SE_PRIVILEGE_ENABLED_BY_DEFAULT) == SE_PRIVILEGE_ENABLED_BY_DEFAULT) + state = "enabled-by-default"; + privilegeInfo.Add(privName.ToString(), state); + } + + // Get Current Process LogonSID, User Rights and Groups + ArrayList userRights = new ArrayList(); + ArrayList userGroups = new ArrayList(); + TOKEN_GROUPS groups; + IntPtr groupsPtr = accessToken.GetTokenInformation(out groups, AccessToken.TOKEN_INFORMATION_CLASS.TokenGroups); + SID_AND_ATTRIBUTES[] sidAndAttributes = new SID_AND_ATTRIBUTES[groups.GroupCount]; + LsaHandle lsaHandle = null; + // We can only get rights if we are an admin + if (new WindowsPrincipal(WindowsIdentity.GetCurrent()).IsInRole(WindowsBuiltInRole.Administrator)) + lsaHandle = new LsaHandle(LsaHandle.DesiredAccess.POLICY_LOOKUP_NAMES); + try + { + PtrToStructureArray(sidAndAttributes, groupsPtr.ToInt64() + IntPtr.Size); + foreach (SID_AND_ATTRIBUTES sidAndAttribute in sidAndAttributes) + { + TokenGroupAttributes attributes = (TokenGroupAttributes)sidAndAttribute.Attributes; + if (attributes.HasFlag(TokenGroupAttributes.SE_GROUP_ENABLED) && lsaHandle != null) + { + ArrayList rights = GetAccountRights(lsaHandle.handle, sidAndAttribute.Sid); + foreach (string right in rights) + { + // Includes both Privileges and Account Rights, only add the ones with Logon in the name + // https://msdn.microsoft.com/en-us/library/windows/desktop/bb545671(v=vs.85).aspx + if (!userRights.Contains(right) && right.Contains("Logon")) + userRights.Add(right); + } + } + // Do not include the Logon SID in the groups category + if (!attributes.HasFlag(TokenGroupAttributes.SE_GROUP_LOGON_ID)) + { + Hashtable groupInfo = new Hashtable(); + Sid group = new Sid(sidAndAttribute.Sid); + ArrayList groupAttributes = new ArrayList(); + foreach (TokenGroupAttributes attribute in Enum.GetValues(typeof(TokenGroupAttributes))) + { + if (attributes.HasFlag(attribute)) + { + string attributeName = attribute.ToString().Substring(9); + attributeName = attributeName.Replace('_', ' '); + attributeName = attributeName.First().ToString().ToUpper() + attributeName.Substring(1).ToLower(); + groupAttributes.Add(attributeName); + } + } + // Using snake_case here as I can't generically convert all dict keys in PS (see Privileges) + groupInfo.Add("sid", group.SidString); + groupInfo.Add("domain_name", group.DomainName); + groupInfo.Add("account_name", group.AccountName); + groupInfo.Add("type", group.SidType); + groupInfo.Add("attributes", groupAttributes); + userGroups.Add(groupInfo); + } + } + } + finally + { + Marshal.FreeHGlobal(groupsPtr); + if (lsaHandle != null) + lsaHandle.Dispose(); + } + + // Get Integrity Level + Sid integritySid = null; + TOKEN_MANDATORY_LABEL mandatoryLabel; + IntPtr mandatoryLabelPtr = accessToken.GetTokenInformation(out mandatoryLabel, AccessToken.TOKEN_INFORMATION_CLASS.TokenIntegrityLevel); + Marshal.FreeHGlobal(mandatoryLabelPtr); + integritySid = new Sid(mandatoryLabel.Label.Sid); + + // Get Token Statistics + TOKEN_STATISTICS tokenStats; + IntPtr tokenStatsPtr = accessToken.GetTokenInformation(out tokenStats, AccessToken.TOKEN_INFORMATION_CLASS.TokenStatistics); + Marshal.FreeHGlobal(tokenStatsPtr); + + SessionInfo sessionInfo = GetSessionDataForLogonSession(tokenStats.AuthenticationId); + sessionInfo.Groups = userGroups; + sessionInfo.Label = integritySid; + sessionInfo.ImpersonationLevel = tokenStats.ImpersonationLevel; + sessionInfo.TokenType = tokenStats.TokenType; + sessionInfo.Privileges = privilegeInfo; + sessionInfo.Rights = userRights; + return sessionInfo; + } + + private static ArrayList GetAccountRights(IntPtr lsaHandle, IntPtr sid) + { + UInt32 res; + ArrayList rights = new ArrayList(); + IntPtr userRightsPointer = IntPtr.Zero; + UInt64 countOfRights = 0; + + res = LsaEnumerateAccountRights(lsaHandle, sid, out userRightsPointer, out countOfRights); + if (res != 0 && res != STATUS_OBJECT_NAME_NOT_FOUND) + throw new Win32Exception(LsaNtStatusToWinError(res), "LsaEnumerateAccountRights() failed"); + else if (res != STATUS_OBJECT_NAME_NOT_FOUND) + { + LSA_UNICODE_STRING[] userRights = new LSA_UNICODE_STRING[countOfRights]; + PtrToStructureArray(userRights, userRightsPointer.ToInt64()); + rights = new ArrayList(); + foreach (LSA_UNICODE_STRING right in userRights) + rights.Add(Marshal.PtrToStringUni(right.buffer)); + } + + return rights; + } + + private static SessionInfo GetSessionDataForLogonSession(LUID logonSession) + { + uint res; + UInt64 count = 0; + IntPtr luidPtr = IntPtr.Zero; + SessionInfo sessionInfo = null; + UInt64 processDataId = ConvertLuidToUint(logonSession); + + res = LsaEnumerateLogonSessions(out count, out luidPtr); + if (res != 0) + throw new Win32Exception(LsaNtStatusToWinError(res), "LsaEnumerateLogonSessions() failed"); + Int64 luidAddr = luidPtr.ToInt64(); + + try + { + for (UInt64 i = 0; i < count; i++) + { + IntPtr dataPointer = IntPtr.Zero; + res = LsaGetLogonSessionData(luidPtr, out dataPointer); + if (res == STATUS_ACCESS_DENIED) // Non admins won't be able to get info for session's that are not their own + { + luidPtr = new IntPtr(luidPtr.ToInt64() + Marshal.SizeOf(typeof(LUID))); + continue; + } + else if (res != 0) + throw new Win32Exception(LsaNtStatusToWinError(res), String.Format("LsaGetLogonSessionData() failed {0}", res)); + + SECURITY_LOGON_SESSION_DATA sessionData = (SECURITY_LOGON_SESSION_DATA)Marshal.PtrToStructure(dataPointer, typeof(SECURITY_LOGON_SESSION_DATA)); + UInt64 sessionDataid = ConvertLuidToUint(sessionData.LogonId); + + if (sessionDataid == processDataId) + { + ArrayList userFlags = new ArrayList(); + UserFlags flags = (UserFlags)sessionData.UserFlags; + foreach (UserFlags flag in Enum.GetValues(typeof(UserFlags))) + { + if (flags.HasFlag(flag)) + { + string flagName = flag.ToString().Substring(6); + flagName = flagName.Replace('_', ' '); + flagName = flagName.First().ToString().ToUpper() + flagName.Substring(1).ToLower(); + userFlags.Add(flagName); + } + } + + sessionInfo = new SessionInfo() + { + AuthenticationPackage = Marshal.PtrToStringUni(sessionData.AuthenticationPackage.buffer), + DnsDomainName = Marshal.PtrToStringUni(sessionData.DnsDomainName.buffer), + LoginDomain = Marshal.PtrToStringUni(sessionData.LoginDomain.buffer), + LoginTime = ConvertIntegerToDateString(sessionData.LoginTime), + LogonId = ConvertLuidToUint(sessionData.LogonId), + LogonServer = Marshal.PtrToStringUni(sessionData.LogonServer.buffer), + LogonType = sessionData.LogonType, + Upn = Marshal.PtrToStringUni(sessionData.Upn.buffer), + UserFlags = userFlags, + Account = new Sid(sessionData.Sid) + }; + break; + } + luidPtr = new IntPtr(luidPtr.ToInt64() + Marshal.SizeOf(typeof(LUID))); + } + } + finally + { + LsaFreeReturnBuffer(new IntPtr(luidAddr)); + } + + if (sessionInfo == null) + throw new Exception(String.Format("Could not find the data for logon session {0}", processDataId)); + return sessionInfo; + } + + private static string ConvertIntegerToDateString(UInt64 time) + { + if (time == 0) + return null; + if (time > (UInt64)DateTime.MaxValue.ToFileTime()) + return null; + + DateTime dateTime = DateTime.FromFileTime((long)time); + return dateTime.ToString("o"); + } + + private static UInt64 ConvertLuidToUint(LUID luid) + { + UInt32 low = luid.LowPart; + UInt64 high = (UInt64)luid.HighPart; + high = high << 32; + UInt64 uintValue = (high | (UInt64)low); + return uintValue; + } + + private static void PtrToStructureArray<T>(T[] array, Int64 pointerAddress) + { + Int64 pointerOffset = pointerAddress; + for (int i = 0; i < array.Length; i++, pointerOffset += Marshal.SizeOf(typeof(T))) + array[i] = (T)Marshal.PtrToStructure(new IntPtr(pointerOffset), typeof(T)); + } + + public static IEnumerable<T> GetValues<T>() + { + return Enum.GetValues(typeof(T)).Cast<T>(); + } + } +} +'@ + +$original_tmp = $env:TMP +$env:TMP = $_remote_tmp +Add-Type -TypeDefinition $session_util +$env:TMP = $original_tmp + +$session_info = [Ansible.SessionUtil]::GetSessionInfo() + +Function Convert-Value($value) { + $new_value = $value + if ($value -is [System.Collections.ArrayList]) { + $new_value = [System.Collections.ArrayList]@() + foreach ($list_value in $value) { + $new_list_value = Convert-Value -value $list_value + [void]$new_value.Add($new_list_value) + } + } elseif ($value -is [Hashtable]) { + $new_value = @{} + foreach ($entry in $value.GetEnumerator()) { + $entry_value = Convert-Value -value $entry.Value + # manually convert Sid type entry to remove the SidType prefix + if ($entry.Name -eq "type") { + $entry_value = $entry_value.Replace("SidType", "") + } + $new_value[$entry.Name] = $entry_value + } + } elseif ($value -is [Ansible.Sid]) { + $new_value = @{ + sid = $value.SidString + account_name = $value.AccountName + domain_name = $value.DomainName + type = $value.SidType.ToString().Replace("SidType", "") + } + } elseif ($value -is [Enum]) { + $new_value = $value.ToString() + } + + return ,$new_value +} + +$result = @{ + changed = $false +} + +$properties = [type][Ansible.SessionInfo] +foreach ($property in $properties.DeclaredProperties) { + $property_name = $property.Name + $property_value = $session_info.$property_name + $snake_name = Convert-StringToSnakeCase -string $property_name + + $result.$snake_name = Convert-Value -value $property_value +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/plugins/modules/win_whoami.py b/test/support/windows-integration/plugins/modules/win_whoami.py new file mode 100644 index 0000000..d647374 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_whoami.py @@ -0,0 +1,203 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_whoami +version_added: "2.5" +short_description: Get information about the current user and process +description: +- Designed to return the same information as the C(whoami /all) command. +- Also includes information missing from C(whoami) such as logon metadata like + logon rights, id, type. +notes: +- If running this module with a non admin user, the logon rights will be an + empty list as Administrator rights are required to query LSA for the + information. +seealso: +- module: win_credential +- module: win_group_membership +- module: win_user_right +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Get whoami information + win_whoami: +''' + +RETURN = r''' +authentication_package: + description: The name of the authentication package used to authenticate the + user in the session. + returned: success + type: str + sample: Negotiate +user_flags: + description: The user flags for the logon session, see UserFlags in + U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa380128). + returned: success + type: str + sample: Winlogon +upn: + description: The user principal name of the current user. + returned: success + type: str + sample: Administrator@DOMAIN.COM +logon_type: + description: The logon type that identifies the logon method, see + U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa380129.aspx). + returned: success + type: str + sample: Network +privileges: + description: A dictionary of privileges and their state on the logon token. + returned: success + type: dict + sample: { + "SeChangeNotifyPrivileges": "enabled-by-default", + "SeRemoteShutdownPrivilege": "disabled", + "SeDebugPrivilege": "enabled" + } +label: + description: The mandatory label set to the logon session. + returned: success + type: complex + contains: + domain_name: + description: The domain name of the label SID. + returned: success + type: str + sample: Mandatory Label + sid: + description: The SID in string form. + returned: success + type: str + sample: S-1-16-12288 + account_name: + description: The account name of the label SID. + returned: success + type: str + sample: High Mandatory Level + type: + description: The type of SID. + returned: success + type: str + sample: Label +impersonation_level: + description: The impersonation level of the token, only valid if + C(token_type) is C(TokenImpersonation), see + U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa379572.aspx). + returned: success + type: str + sample: SecurityAnonymous +login_time: + description: The logon time in ISO 8601 format + returned: success + type: str + sample: '2017-11-27T06:24:14.3321665+10:00' +groups: + description: A list of groups and attributes that the user is a member of. + returned: success + type: list + sample: [ + { + "account_name": "Domain Users", + "domain_name": "DOMAIN", + "attributes": [ + "Mandatory", + "Enabled by default", + "Enabled" + ], + "sid": "S-1-5-21-1654078763-769949647-2968445802-513", + "type": "Group" + }, + { + "account_name": "Administrators", + "domain_name": "BUILTIN", + "attributes": [ + "Mandatory", + "Enabled by default", + "Enabled", + "Owner" + ], + "sid": "S-1-5-32-544", + "type": "Alias" + } + ] +account: + description: The running account SID details. + returned: success + type: complex + contains: + domain_name: + description: The domain name of the account SID. + returned: success + type: str + sample: DOMAIN + sid: + description: The SID in string form. + returned: success + type: str + sample: S-1-5-21-1654078763-769949647-2968445802-500 + account_name: + description: The account name of the account SID. + returned: success + type: str + sample: Administrator + type: + description: The type of SID. + returned: success + type: str + sample: User +login_domain: + description: The name of the domain used to authenticate the owner of the + session. + returned: success + type: str + sample: DOMAIN +rights: + description: A list of logon rights assigned to the logon. + returned: success and running user is a member of the local Administrators group + type: list + sample: [ + "SeNetworkLogonRight", + "SeInteractiveLogonRight", + "SeBatchLogonRight", + "SeRemoteInteractiveLogonRight" + ] +logon_server: + description: The name of the server used to authenticate the owner of the + logon session. + returned: success + type: str + sample: DC01 +logon_id: + description: The unique identifier of the logon session. + returned: success + type: int + sample: 20470143 +dns_domain_name: + description: The DNS name of the logon session, this is an empty string if + this is not set. + returned: success + type: str + sample: DOMAIN.COM +token_type: + description: The token type to indicate whether it is a primary or + impersonation token. + returned: success + type: str + sample: TokenPrimary +''' |