diff options
Diffstat (limited to 'test/support')
227 files changed, 69953 insertions, 0 deletions
diff --git a/test/support/integration/plugins/cache/jsonfile.py b/test/support/integration/plugins/cache/jsonfile.py new file mode 100644 index 00000000..80b16f55 --- /dev/null +++ b/test/support/integration/plugins/cache/jsonfile.py @@ -0,0 +1,63 @@ +# (c) 2014, Brian Coca, Josh Drake, et al +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = ''' + cache: jsonfile + short_description: JSON formatted files. + description: + - This cache uses JSON formatted, per host, files saved to the filesystem. + version_added: "1.9" + author: Ansible Core (@ansible-core) + options: + _uri: + required: True + description: + - Path in which the cache plugin will save the JSON files + env: + - name: ANSIBLE_CACHE_PLUGIN_CONNECTION + ini: + - key: fact_caching_connection + section: defaults + _prefix: + description: User defined prefix to use when creating the JSON files + env: + - name: ANSIBLE_CACHE_PLUGIN_PREFIX + ini: + - key: fact_caching_prefix + section: defaults + _timeout: + default: 86400 + description: Expiration timeout in seconds for the cache plugin data. Set to 0 to never expire + env: + - name: ANSIBLE_CACHE_PLUGIN_TIMEOUT + ini: + - key: fact_caching_timeout + section: defaults + type: integer +''' + +import codecs +import json + +from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder +from ansible.plugins.cache import BaseFileCacheModule + + +class CacheModule(BaseFileCacheModule): + """ + A caching module backed by json files. + """ + + def _load(self, filepath): + # Valid JSON is always UTF-8 encoded. + with codecs.open(filepath, 'r', encoding='utf-8') as f: + return json.load(f, cls=AnsibleJSONDecoder) + + def _dump(self, value, filepath): + with codecs.open(filepath, 'w', encoding='utf-8') as f: + f.write(json.dumps(value, cls=AnsibleJSONEncoder, sort_keys=True, indent=4)) diff --git a/test/support/integration/plugins/filter/json_query.py b/test/support/integration/plugins/filter/json_query.py new file mode 100644 index 00000000..d1da71b4 --- /dev/null +++ b/test/support/integration/plugins/filter/json_query.py @@ -0,0 +1,53 @@ +# (c) 2015, Filipe Niero Felisbino <filipenf@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.errors import AnsibleError, AnsibleFilterError + +try: + import jmespath + HAS_LIB = True +except ImportError: + HAS_LIB = False + + +def json_query(data, expr): + '''Query data using jmespath query language ( http://jmespath.org ). Example: + - debug: msg="{{ instance | json_query(tagged_instances[*].block_device_mapping.*.volume_id') }}" + ''' + if not HAS_LIB: + raise AnsibleError('You need to install "jmespath" prior to running ' + 'json_query filter') + + try: + return jmespath.search(expr, data) + except jmespath.exceptions.JMESPathError as e: + raise AnsibleFilterError('JMESPathError in json_query filter plugin:\n%s' % e) + except Exception as e: + # For older jmespath, we can get ValueError and TypeError without much info. + raise AnsibleFilterError('Error in jmespath.search in json_query filter plugin:\n%s' % e) + + +class FilterModule(object): + ''' Query filter ''' + + def filters(self): + return { + 'json_query': json_query + } diff --git a/test/support/integration/plugins/inventory/aws_ec2.py b/test/support/integration/plugins/inventory/aws_ec2.py new file mode 100644 index 00000000..09c42cf9 --- /dev/null +++ b/test/support/integration/plugins/inventory/aws_ec2.py @@ -0,0 +1,760 @@ +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = ''' + name: aws_ec2 + plugin_type: inventory + short_description: EC2 inventory source + requirements: + - boto3 + - botocore + extends_documentation_fragment: + - inventory_cache + - constructed + description: + - Get inventory hosts from Amazon Web Services EC2. + - Uses a YAML configuration file that ends with C(aws_ec2.(yml|yaml)). + notes: + - If no credentials are provided and the control node has an associated IAM instance profile then the + role will be used for authentication. + author: + - Sloane Hertel (@s-hertel) + options: + aws_profile: + description: The AWS profile + type: str + aliases: [ boto_profile ] + env: + - name: AWS_DEFAULT_PROFILE + - name: AWS_PROFILE + aws_access_key: + description: The AWS access key to use. + type: str + aliases: [ aws_access_key_id ] + env: + - name: EC2_ACCESS_KEY + - name: AWS_ACCESS_KEY + - name: AWS_ACCESS_KEY_ID + aws_secret_key: + description: The AWS secret key that corresponds to the access key. + type: str + aliases: [ aws_secret_access_key ] + env: + - name: EC2_SECRET_KEY + - name: AWS_SECRET_KEY + - name: AWS_SECRET_ACCESS_KEY + aws_security_token: + description: The AWS security token if using temporary access and secret keys. + type: str + env: + - name: EC2_SECURITY_TOKEN + - name: AWS_SESSION_TOKEN + - name: AWS_SECURITY_TOKEN + plugin: + description: Token that ensures this is a source file for the plugin. + required: True + choices: ['aws_ec2'] + iam_role_arn: + description: The ARN of the IAM role to assume to perform the inventory lookup. You should still provide AWS + credentials with enough privilege to perform the AssumeRole action. + version_added: '2.9' + regions: + description: + - A list of regions in which to describe EC2 instances. + - If empty (the default) default this will include all regions, except possibly restricted ones like us-gov-west-1 and cn-north-1. + type: list + default: [] + hostnames: + description: + - A list in order of precedence for hostname variables. + - You can use the options specified in U(http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options). + - To use tags as hostnames use the syntax tag:Name=Value to use the hostname Name_Value, or tag:Name to use the value of the Name tag. + type: list + default: [] + filters: + description: + - A dictionary of filter value pairs. + - Available filters are listed here U(http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options). + type: dict + default: {} + include_extra_api_calls: + description: + - Add two additional API calls for every instance to include 'persistent' and 'events' host variables. + - Spot instances may be persistent and instances may have associated events. + type: bool + default: False + version_added: '2.8' + strict_permissions: + description: + - By default if a 403 (Forbidden) error code is encountered this plugin will fail. + - You can set this option to False in the inventory config file which will allow 403 errors to be gracefully skipped. + type: bool + default: True + use_contrib_script_compatible_sanitization: + description: + - By default this plugin is using a general group name sanitization to create safe and usable group names for use in Ansible. + This option allows you to override that, in efforts to allow migration from the old inventory script and + matches the sanitization of groups when the script's ``replace_dash_in_groups`` option is set to ``False``. + To replicate behavior of ``replace_dash_in_groups = True`` with constructed groups, + you will need to replace hyphens with underscores via the regex_replace filter for those entries. + - For this to work you should also turn off the TRANSFORM_INVALID_GROUP_CHARS setting, + otherwise the core engine will just use the standard sanitization on top. + - This is not the default as such names break certain functionality as not all characters are valid Python identifiers + which group names end up being used as. + type: bool + default: False + version_added: '2.8' +''' + +EXAMPLES = ''' +# Minimal example using environment vars or instance role credentials +# Fetch all hosts in us-east-1, the hostname is the public DNS if it exists, otherwise the private IP address +plugin: aws_ec2 +regions: + - us-east-1 + +# Example using filters, ignoring permission errors, and specifying the hostname precedence +plugin: aws_ec2 +boto_profile: aws_profile +# Populate inventory with instances in these regions +regions: + - us-east-1 + - us-east-2 +filters: + # All instances with their `Environment` tag set to `dev` + tag:Environment: dev + # All dev and QA hosts + tag:Environment: + - dev + - qa + instance.group-id: sg-xxxxxxxx +# Ignores 403 errors rather than failing +strict_permissions: False +# Note: I(hostnames) sets the inventory_hostname. To modify ansible_host without modifying +# inventory_hostname use compose (see example below). +hostnames: + - tag:Name=Tag1,Name=Tag2 # Return specific hosts only + - tag:CustomDNSName + - dns-name + - private-ip-address + +# Example using constructed features to create groups and set ansible_host +plugin: aws_ec2 +regions: + - us-east-1 + - us-west-1 +# keyed_groups may be used to create custom groups +strict: False +keyed_groups: + # Add e.g. x86_64 hosts to an arch_x86_64 group + - prefix: arch + key: 'architecture' + # Add hosts to tag_Name_Value groups for each Name/Value tag pair + - prefix: tag + key: tags + # Add hosts to e.g. instance_type_z3_tiny + - prefix: instance_type + key: instance_type + # Create security_groups_sg_abcd1234 group for each SG + - key: 'security_groups|json_query("[].group_id")' + prefix: 'security_groups' + # Create a group for each value of the Application tag + - key: tags.Application + separator: '' + # Create a group per region e.g. aws_region_us_east_2 + - key: placement.region + prefix: aws_region + # Create a group (or groups) based on the value of a custom tag "Role" and add them to a metagroup called "project" + - key: tags['Role'] + prefix: foo + parent_group: "project" +# Set individual variables with compose +compose: + # Use the private IP address to connect to the host + # (note: this does not modify inventory_hostname, which is set via I(hostnames)) + ansible_host: private_ip_address +''' + +import re + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_native, to_text +from ansible.module_utils.common.dict_transformations import camel_dict_to_snake_dict +from ansible.plugins.inventory import BaseInventoryPlugin, Constructable, Cacheable +from ansible.utils.display import Display +from ansible.module_utils.six import string_types + +try: + import boto3 + import botocore +except ImportError: + raise AnsibleError('The ec2 dynamic inventory plugin requires boto3 and botocore.') + +display = Display() + +# The mappings give an array of keys to get from the filter name to the value +# returned by boto3's EC2 describe_instances method. + +instance_meta_filter_to_boto_attr = { + 'group-id': ('Groups', 'GroupId'), + 'group-name': ('Groups', 'GroupName'), + 'network-interface.attachment.instance-owner-id': ('OwnerId',), + 'owner-id': ('OwnerId',), + 'requester-id': ('RequesterId',), + 'reservation-id': ('ReservationId',), +} + +instance_data_filter_to_boto_attr = { + 'affinity': ('Placement', 'Affinity'), + 'architecture': ('Architecture',), + 'availability-zone': ('Placement', 'AvailabilityZone'), + 'block-device-mapping.attach-time': ('BlockDeviceMappings', 'Ebs', 'AttachTime'), + 'block-device-mapping.delete-on-termination': ('BlockDeviceMappings', 'Ebs', 'DeleteOnTermination'), + 'block-device-mapping.device-name': ('BlockDeviceMappings', 'DeviceName'), + 'block-device-mapping.status': ('BlockDeviceMappings', 'Ebs', 'Status'), + 'block-device-mapping.volume-id': ('BlockDeviceMappings', 'Ebs', 'VolumeId'), + 'client-token': ('ClientToken',), + 'dns-name': ('PublicDnsName',), + 'host-id': ('Placement', 'HostId'), + 'hypervisor': ('Hypervisor',), + 'iam-instance-profile.arn': ('IamInstanceProfile', 'Arn'), + 'image-id': ('ImageId',), + 'instance-id': ('InstanceId',), + 'instance-lifecycle': ('InstanceLifecycle',), + 'instance-state-code': ('State', 'Code'), + 'instance-state-name': ('State', 'Name'), + 'instance-type': ('InstanceType',), + 'instance.group-id': ('SecurityGroups', 'GroupId'), + 'instance.group-name': ('SecurityGroups', 'GroupName'), + 'ip-address': ('PublicIpAddress',), + 'kernel-id': ('KernelId',), + 'key-name': ('KeyName',), + 'launch-index': ('AmiLaunchIndex',), + 'launch-time': ('LaunchTime',), + 'monitoring-state': ('Monitoring', 'State'), + 'network-interface.addresses.private-ip-address': ('NetworkInterfaces', 'PrivateIpAddress'), + 'network-interface.addresses.primary': ('NetworkInterfaces', 'PrivateIpAddresses', 'Primary'), + 'network-interface.addresses.association.public-ip': ('NetworkInterfaces', 'PrivateIpAddresses', 'Association', 'PublicIp'), + 'network-interface.addresses.association.ip-owner-id': ('NetworkInterfaces', 'PrivateIpAddresses', 'Association', 'IpOwnerId'), + 'network-interface.association.public-ip': ('NetworkInterfaces', 'Association', 'PublicIp'), + 'network-interface.association.ip-owner-id': ('NetworkInterfaces', 'Association', 'IpOwnerId'), + 'network-interface.association.allocation-id': ('ElasticGpuAssociations', 'ElasticGpuId'), + 'network-interface.association.association-id': ('ElasticGpuAssociations', 'ElasticGpuAssociationId'), + 'network-interface.attachment.attachment-id': ('NetworkInterfaces', 'Attachment', 'AttachmentId'), + 'network-interface.attachment.instance-id': ('InstanceId',), + 'network-interface.attachment.device-index': ('NetworkInterfaces', 'Attachment', 'DeviceIndex'), + 'network-interface.attachment.status': ('NetworkInterfaces', 'Attachment', 'Status'), + 'network-interface.attachment.attach-time': ('NetworkInterfaces', 'Attachment', 'AttachTime'), + 'network-interface.attachment.delete-on-termination': ('NetworkInterfaces', 'Attachment', 'DeleteOnTermination'), + 'network-interface.availability-zone': ('Placement', 'AvailabilityZone'), + 'network-interface.description': ('NetworkInterfaces', 'Description'), + 'network-interface.group-id': ('NetworkInterfaces', 'Groups', 'GroupId'), + 'network-interface.group-name': ('NetworkInterfaces', 'Groups', 'GroupName'), + 'network-interface.ipv6-addresses.ipv6-address': ('NetworkInterfaces', 'Ipv6Addresses', 'Ipv6Address'), + 'network-interface.mac-address': ('NetworkInterfaces', 'MacAddress'), + 'network-interface.network-interface-id': ('NetworkInterfaces', 'NetworkInterfaceId'), + 'network-interface.owner-id': ('NetworkInterfaces', 'OwnerId'), + 'network-interface.private-dns-name': ('NetworkInterfaces', 'PrivateDnsName'), + # 'network-interface.requester-id': (), + 'network-interface.requester-managed': ('NetworkInterfaces', 'Association', 'IpOwnerId'), + 'network-interface.status': ('NetworkInterfaces', 'Status'), + 'network-interface.source-dest-check': ('NetworkInterfaces', 'SourceDestCheck'), + 'network-interface.subnet-id': ('NetworkInterfaces', 'SubnetId'), + 'network-interface.vpc-id': ('NetworkInterfaces', 'VpcId'), + 'placement-group-name': ('Placement', 'GroupName'), + 'platform': ('Platform',), + 'private-dns-name': ('PrivateDnsName',), + 'private-ip-address': ('PrivateIpAddress',), + 'product-code': ('ProductCodes', 'ProductCodeId'), + 'product-code.type': ('ProductCodes', 'ProductCodeType'), + 'ramdisk-id': ('RamdiskId',), + 'reason': ('StateTransitionReason',), + 'root-device-name': ('RootDeviceName',), + 'root-device-type': ('RootDeviceType',), + 'source-dest-check': ('SourceDestCheck',), + 'spot-instance-request-id': ('SpotInstanceRequestId',), + 'state-reason-code': ('StateReason', 'Code'), + 'state-reason-message': ('StateReason', 'Message'), + 'subnet-id': ('SubnetId',), + 'tag': ('Tags',), + 'tag-key': ('Tags',), + 'tag-value': ('Tags',), + 'tenancy': ('Placement', 'Tenancy'), + 'virtualization-type': ('VirtualizationType',), + 'vpc-id': ('VpcId',), +} + + +class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): + + NAME = 'aws_ec2' + + def __init__(self): + super(InventoryModule, self).__init__() + + self.group_prefix = 'aws_ec2_' + + # credentials + self.boto_profile = None + self.aws_secret_access_key = None + self.aws_access_key_id = None + self.aws_security_token = None + self.iam_role_arn = None + + def _compile_values(self, obj, attr): + ''' + :param obj: A list or dict of instance attributes + :param attr: A key + :return The value(s) found via the attr + ''' + if obj is None: + return + + temp_obj = [] + + if isinstance(obj, list) or isinstance(obj, tuple): + for each in obj: + value = self._compile_values(each, attr) + if value: + temp_obj.append(value) + else: + temp_obj = obj.get(attr) + + has_indexes = any([isinstance(temp_obj, list), isinstance(temp_obj, tuple)]) + if has_indexes and len(temp_obj) == 1: + return temp_obj[0] + + return temp_obj + + def _get_boto_attr_chain(self, filter_name, instance): + ''' + :param filter_name: The filter + :param instance: instance dict returned by boto3 ec2 describe_instances() + ''' + allowed_filters = sorted(list(instance_data_filter_to_boto_attr.keys()) + list(instance_meta_filter_to_boto_attr.keys())) + if filter_name not in allowed_filters: + raise AnsibleError("Invalid filter '%s' provided; filter must be one of %s." % (filter_name, + allowed_filters)) + if filter_name in instance_data_filter_to_boto_attr: + boto_attr_list = instance_data_filter_to_boto_attr[filter_name] + else: + boto_attr_list = instance_meta_filter_to_boto_attr[filter_name] + + instance_value = instance + for attribute in boto_attr_list: + instance_value = self._compile_values(instance_value, attribute) + return instance_value + + def _get_credentials(self): + ''' + :return A dictionary of boto client credentials + ''' + boto_params = {} + for credential in (('aws_access_key_id', self.aws_access_key_id), + ('aws_secret_access_key', self.aws_secret_access_key), + ('aws_session_token', self.aws_security_token)): + if credential[1]: + boto_params[credential[0]] = credential[1] + + return boto_params + + def _get_connection(self, credentials, region='us-east-1'): + try: + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region, **credentials) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + if self.boto_profile: + try: + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + else: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + return connection + + def _boto3_assume_role(self, credentials, region): + """ + Assume an IAM role passed by iam_role_arn parameter + + :return: a dict containing the credentials of the assumed role + """ + + iam_role_arn = self.iam_role_arn + + try: + sts_connection = boto3.session.Session(profile_name=self.boto_profile).client('sts', region, **credentials) + sts_session = sts_connection.assume_role(RoleArn=iam_role_arn, RoleSessionName='ansible_aws_ec2_dynamic_inventory') + return dict( + aws_access_key_id=sts_session['Credentials']['AccessKeyId'], + aws_secret_access_key=sts_session['Credentials']['SecretAccessKey'], + aws_session_token=sts_session['Credentials']['SessionToken'] + ) + except botocore.exceptions.ClientError as e: + raise AnsibleError("Unable to assume IAM role: %s" % to_native(e)) + + def _boto3_conn(self, regions): + ''' + :param regions: A list of regions to create a boto3 client + + Generator that yields a boto3 client and the region + ''' + + credentials = self._get_credentials() + iam_role_arn = self.iam_role_arn + + if not regions: + try: + # as per https://boto3.amazonaws.com/v1/documentation/api/latest/guide/ec2-example-regions-avail-zones.html + client = self._get_connection(credentials) + resp = client.describe_regions() + regions = [x['RegionName'] for x in resp.get('Regions', [])] + except botocore.exceptions.NoRegionError: + # above seems to fail depending on boto3 version, ignore and lets try something else + pass + + # fallback to local list hardcoded in boto3 if still no regions + if not regions: + session = boto3.Session() + regions = session.get_available_regions('ec2') + + # I give up, now you MUST give me regions + if not regions: + raise AnsibleError('Unable to get regions list from available methods, you must specify the "regions" option to continue.') + + for region in regions: + connection = self._get_connection(credentials, region) + try: + if iam_role_arn is not None: + assumed_credentials = self._boto3_assume_role(credentials, region) + else: + assumed_credentials = credentials + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region, **assumed_credentials) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + if self.boto_profile: + try: + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + else: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + yield connection, region + + def _get_instances_by_region(self, regions, filters, strict_permissions): + ''' + :param regions: a list of regions in which to describe instances + :param filters: a list of boto3 filter dictionaries + :param strict_permissions: a boolean determining whether to fail or ignore 403 error codes + :return A list of instance dictionaries + ''' + all_instances = [] + + for connection, region in self._boto3_conn(regions): + try: + # By default find non-terminated/terminating instances + if not any([f['Name'] == 'instance-state-name' for f in filters]): + filters.append({'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopping', 'stopped']}) + paginator = connection.get_paginator('describe_instances') + reservations = paginator.paginate(Filters=filters).build_full_result().get('Reservations') + instances = [] + for r in reservations: + new_instances = r['Instances'] + for instance in new_instances: + instance.update(self._get_reservation_details(r)) + if self.get_option('include_extra_api_calls'): + instance.update(self._get_event_set_and_persistence(connection, instance['InstanceId'], instance.get('SpotInstanceRequestId'))) + instances.extend(new_instances) + except botocore.exceptions.ClientError as e: + if e.response['ResponseMetadata']['HTTPStatusCode'] == 403 and not strict_permissions: + instances = [] + else: + raise AnsibleError("Failed to describe instances: %s" % to_native(e)) + except botocore.exceptions.BotoCoreError as e: + raise AnsibleError("Failed to describe instances: %s" % to_native(e)) + + all_instances.extend(instances) + + return sorted(all_instances, key=lambda x: x['InstanceId']) + + def _get_reservation_details(self, reservation): + return { + 'OwnerId': reservation['OwnerId'], + 'RequesterId': reservation.get('RequesterId', ''), + 'ReservationId': reservation['ReservationId'] + } + + def _get_event_set_and_persistence(self, connection, instance_id, spot_instance): + host_vars = {'Events': '', 'Persistent': False} + try: + kwargs = {'InstanceIds': [instance_id]} + host_vars['Events'] = connection.describe_instance_status(**kwargs)['InstanceStatuses'][0].get('Events', '') + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + if not self.get_option('strict_permissions'): + pass + else: + raise AnsibleError("Failed to describe instance status: %s" % to_native(e)) + if spot_instance: + try: + kwargs = {'SpotInstanceRequestIds': [spot_instance]} + host_vars['Persistent'] = bool( + connection.describe_spot_instance_requests(**kwargs)['SpotInstanceRequests'][0].get('Type') == 'persistent' + ) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + if not self.get_option('strict_permissions'): + pass + else: + raise AnsibleError("Failed to describe spot instance requests: %s" % to_native(e)) + return host_vars + + def _get_tag_hostname(self, preference, instance): + tag_hostnames = preference.split('tag:', 1)[1] + if ',' in tag_hostnames: + tag_hostnames = tag_hostnames.split(',') + else: + tag_hostnames = [tag_hostnames] + tags = boto3_tag_list_to_ansible_dict(instance.get('Tags', [])) + for v in tag_hostnames: + if '=' in v: + tag_name, tag_value = v.split('=') + if tags.get(tag_name) == tag_value: + return to_text(tag_name) + "_" + to_text(tag_value) + else: + tag_value = tags.get(v) + if tag_value: + return to_text(tag_value) + return None + + def _get_hostname(self, instance, hostnames): + ''' + :param instance: an instance dict returned by boto3 ec2 describe_instances() + :param hostnames: a list of hostname destination variables in order of preference + :return the preferred identifer for the host + ''' + if not hostnames: + hostnames = ['dns-name', 'private-dns-name'] + + hostname = None + for preference in hostnames: + if 'tag' in preference: + if not preference.startswith('tag:'): + raise AnsibleError("To name a host by tags name_value, use 'tag:name=value'.") + hostname = self._get_tag_hostname(preference, instance) + else: + hostname = self._get_boto_attr_chain(preference, instance) + if hostname: + break + if hostname: + if ':' in to_text(hostname): + return self._sanitize_group_name((to_text(hostname))) + else: + return to_text(hostname) + + def _query(self, regions, filters, strict_permissions): + ''' + :param regions: a list of regions to query + :param filters: a list of boto3 filter dictionaries + :param hostnames: a list of hostname destination variables in order of preference + :param strict_permissions: a boolean determining whether to fail or ignore 403 error codes + ''' + return {'aws_ec2': self._get_instances_by_region(regions, filters, strict_permissions)} + + def _populate(self, groups, hostnames): + for group in groups: + group = self.inventory.add_group(group) + self._add_hosts(hosts=groups[group], group=group, hostnames=hostnames) + self.inventory.add_child('all', group) + + def _add_hosts(self, hosts, group, hostnames): + ''' + :param hosts: a list of hosts to be added to a group + :param group: the name of the group to which the hosts belong + :param hostnames: a list of hostname destination variables in order of preference + ''' + for host in hosts: + hostname = self._get_hostname(host, hostnames) + + host = camel_dict_to_snake_dict(host, ignore_list=['Tags']) + host['tags'] = boto3_tag_list_to_ansible_dict(host.get('tags', [])) + + # Allow easier grouping by region + host['placement']['region'] = host['placement']['availability_zone'][:-1] + + if not hostname: + continue + self.inventory.add_host(hostname, group=group) + for hostvar, hostval in host.items(): + self.inventory.set_variable(hostname, hostvar, hostval) + + # Use constructed if applicable + + strict = self.get_option('strict') + + # Composed variables + self._set_composite_vars(self.get_option('compose'), host, hostname, strict=strict) + + # Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group + self._add_host_to_composed_groups(self.get_option('groups'), host, hostname, strict=strict) + + # Create groups based on variable values and add the corresponding hosts to it + self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host, hostname, strict=strict) + + def _set_credentials(self): + ''' + :param config_data: contents of the inventory config file + ''' + + self.boto_profile = self.get_option('aws_profile') + self.aws_access_key_id = self.get_option('aws_access_key') + self.aws_secret_access_key = self.get_option('aws_secret_key') + self.aws_security_token = self.get_option('aws_security_token') + self.iam_role_arn = self.get_option('iam_role_arn') + + if not self.boto_profile and not (self.aws_access_key_id and self.aws_secret_access_key): + session = botocore.session.get_session() + try: + credentials = session.get_credentials().get_frozen_credentials() + except AttributeError: + pass + else: + self.aws_access_key_id = credentials.access_key + self.aws_secret_access_key = credentials.secret_key + self.aws_security_token = credentials.token + + if not self.boto_profile and not (self.aws_access_key_id and self.aws_secret_access_key): + raise AnsibleError("Insufficient boto credentials found. Please provide them in your " + "inventory configuration file or set them as environment variables.") + + def verify_file(self, path): + ''' + :param loader: an ansible.parsing.dataloader.DataLoader object + :param path: the path to the inventory config file + :return the contents of the config file + ''' + if super(InventoryModule, self).verify_file(path): + if path.endswith(('aws_ec2.yml', 'aws_ec2.yaml')): + return True + display.debug("aws_ec2 inventory filename must end with 'aws_ec2.yml' or 'aws_ec2.yaml'") + return False + + def parse(self, inventory, loader, path, cache=True): + + super(InventoryModule, self).parse(inventory, loader, path) + + self._read_config_data(path) + + if self.get_option('use_contrib_script_compatible_sanitization'): + self._sanitize_group_name = self._legacy_script_compatible_group_sanitization + + self._set_credentials() + + # get user specifications + regions = self.get_option('regions') + filters = ansible_dict_to_boto3_filter_list(self.get_option('filters')) + hostnames = self.get_option('hostnames') + strict_permissions = self.get_option('strict_permissions') + + cache_key = self.get_cache_key(path) + # false when refresh_cache or --flush-cache is used + if cache: + # get the user-specified directive + cache = self.get_option('cache') + + # Generate inventory + cache_needs_update = False + if cache: + try: + results = self._cache[cache_key] + except KeyError: + # if cache expires or cache file doesn't exist + cache_needs_update = True + + if not cache or cache_needs_update: + results = self._query(regions, filters, strict_permissions) + + self._populate(results, hostnames) + + # If the cache has expired/doesn't exist or if refresh_inventory/flush cache is used + # when the user is using caching, update the cached inventory + if cache_needs_update or (not cache and self.get_option('cache')): + self._cache[cache_key] = results + + @staticmethod + def _legacy_script_compatible_group_sanitization(name): + + # note that while this mirrors what the script used to do, it has many issues with unicode and usability in python + regex = re.compile(r"[^A-Za-z0-9\_\-]") + + return regex.sub('_', name) + + +def ansible_dict_to_boto3_filter_list(filters_dict): + + """ Convert an Ansible dict of filters to list of dicts that boto3 can use + Args: + filters_dict (dict): Dict of AWS filters. + Basic Usage: + >>> filters = {'some-aws-id': 'i-01234567'} + >>> ansible_dict_to_boto3_filter_list(filters) + { + 'some-aws-id': 'i-01234567' + } + Returns: + List: List of AWS filters and their values + [ + { + 'Name': 'some-aws-id', + 'Values': [ + 'i-01234567', + ] + } + ] + """ + + filters_list = [] + for k, v in filters_dict.items(): + filter_dict = {'Name': k} + if isinstance(v, string_types): + filter_dict['Values'] = [v] + else: + filter_dict['Values'] = v + + filters_list.append(filter_dict) + + return filters_list + + +def boto3_tag_list_to_ansible_dict(tags_list, tag_name_key_name=None, tag_value_key_name=None): + + """ Convert a boto3 list of resource tags to a flat dict of key:value pairs + Args: + tags_list (list): List of dicts representing AWS tags. + tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key") + tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value") + Basic Usage: + >>> tags_list = [{'Key': 'MyTagKey', 'Value': 'MyTagValue'}] + >>> boto3_tag_list_to_ansible_dict(tags_list) + [ + { + 'Key': 'MyTagKey', + 'Value': 'MyTagValue' + } + ] + Returns: + Dict: Dict of key:value pairs representing AWS tags + { + 'MyTagKey': 'MyTagValue', + } + """ + + if tag_name_key_name and tag_value_key_name: + tag_candidates = {tag_name_key_name: tag_value_key_name} + else: + tag_candidates = {'key': 'value', 'Key': 'Value'} + + if not tags_list: + return {} + for k, v in tag_candidates.items(): + if k in tags_list[0] and v in tags_list[0]: + return dict((tag[k], tag[v]) for tag in tags_list) + raise ValueError("Couldn't find tag key (candidates %s) in tag list %s" % (str(tag_candidates), str(tags_list))) diff --git a/test/support/integration/plugins/inventory/docker_swarm.py b/test/support/integration/plugins/inventory/docker_swarm.py new file mode 100644 index 00000000..d0a95ca0 --- /dev/null +++ b/test/support/integration/plugins/inventory/docker_swarm.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2018, Stefan Heitmueller <stefan.heitmueller@gmx.com> +# Copyright (c) 2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) + +__metaclass__ = type + +DOCUMENTATION = ''' + name: docker_swarm + plugin_type: inventory + version_added: '2.8' + author: + - Stefan Heitmüller (@morph027) <stefan.heitmueller@gmx.com> + short_description: Ansible dynamic inventory plugin for Docker swarm nodes. + requirements: + - python >= 2.7 + - L(Docker SDK for Python,https://docker-py.readthedocs.io/en/stable/) >= 1.10.0 + extends_documentation_fragment: + - constructed + description: + - Reads inventories from the Docker swarm API. + - Uses a YAML configuration file docker_swarm.[yml|yaml]. + - "The plugin returns following groups of swarm nodes: I(all) - all hosts; I(workers) - all worker nodes; + I(managers) - all manager nodes; I(leader) - the swarm leader node; + I(nonleaders) - all nodes except the swarm leader." + options: + plugin: + description: The name of this plugin, it should always be set to C(docker_swarm) for this plugin to + recognize it as it's own. + type: str + required: true + choices: docker_swarm + docker_host: + description: + - Socket of a Docker swarm manager node (C(tcp), C(unix)). + - "Use C(unix://var/run/docker.sock) to connect via local socket." + type: str + required: true + aliases: [ docker_url ] + verbose_output: + description: Toggle to (not) include all available nodes metadata (e.g. C(Platform), C(Architecture), C(OS), + C(EngineVersion)) + type: bool + default: yes + tls: + description: Connect using TLS without verifying the authenticity of the Docker host server. + type: bool + default: no + validate_certs: + description: Toggle if connecting using TLS with or without verifying the authenticity of the Docker + host server. + type: bool + default: no + aliases: [ tls_verify ] + client_key: + description: Path to the client's TLS key file. + type: path + aliases: [ tls_client_key, key_path ] + ca_cert: + description: Use a CA certificate when performing server verification by providing the path to a CA + certificate file. + type: path + aliases: [ tls_ca_cert, cacert_path ] + client_cert: + description: Path to the client's TLS certificate file. + type: path + aliases: [ tls_client_cert, cert_path ] + tls_hostname: + description: When verifying the authenticity of the Docker host server, provide the expected name of + the server. + type: str + ssl_version: + description: Provide a valid SSL version number. Default value determined by ssl.py module. + type: str + api_version: + description: + - The version of the Docker API running on the Docker Host. + - Defaults to the latest version of the API supported by docker-py. + type: str + aliases: [ docker_api_version ] + timeout: + description: + - The maximum amount of time in seconds to wait on a response from the API. + - If the value is not specified in the task, the value of environment variable C(DOCKER_TIMEOUT) + will be used instead. If the environment variable is not set, the default value will be used. + type: int + default: 60 + aliases: [ time_out ] + include_host_uri: + description: Toggle to return the additional attribute C(ansible_host_uri) which contains the URI of the + swarm leader in format of C(tcp://172.16.0.1:2376). This value may be used without additional + modification as value of option I(docker_host) in Docker Swarm modules when connecting via API. + The port always defaults to C(2376). + type: bool + default: no + include_host_uri_port: + description: Override the detected port number included in I(ansible_host_uri) + type: int +''' + +EXAMPLES = ''' +# Minimal example using local docker +plugin: docker_swarm +docker_host: unix://var/run/docker.sock + +# Minimal example using remote docker +plugin: docker_swarm +docker_host: tcp://my-docker-host:2375 + +# Example using remote docker with unverified TLS +plugin: docker_swarm +docker_host: tcp://my-docker-host:2376 +tls: yes + +# Example using remote docker with verified TLS and client certificate verification +plugin: docker_swarm +docker_host: tcp://my-docker-host:2376 +validate_certs: yes +ca_cert: /somewhere/ca.pem +client_key: /somewhere/key.pem +client_cert: /somewhere/cert.pem + +# Example using constructed features to create groups and set ansible_host +plugin: docker_swarm +docker_host: tcp://my-docker-host:2375 +strict: False +keyed_groups: + # add e.g. x86_64 hosts to an arch_x86_64 group + - prefix: arch + key: 'Description.Platform.Architecture' + # add e.g. linux hosts to an os_linux group + - prefix: os + key: 'Description.Platform.OS' + # create a group per node label + # e.g. a node labeled w/ "production" ends up in group "label_production" + # hint: labels containing special characters will be converted to safe names + - key: 'Spec.Labels' + prefix: label +''' + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_native +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ansible.plugins.inventory import BaseInventoryPlugin, Constructable +from ansible.parsing.utils.addresses import parse_address + +try: + import docker + from docker.errors import TLSParameterError + from docker.tls import TLSConfig + HAS_DOCKER = True +except ImportError: + HAS_DOCKER = False + + +def update_tls_hostname(result): + if result['tls_hostname'] is None: + # get default machine name from the url + parsed_url = urlparse(result['docker_host']) + if ':' in parsed_url.netloc: + result['tls_hostname'] = parsed_url.netloc[:parsed_url.netloc.rindex(':')] + else: + result['tls_hostname'] = parsed_url + + +def _get_tls_config(fail_function, **kwargs): + try: + tls_config = TLSConfig(**kwargs) + return tls_config + except TLSParameterError as exc: + fail_function("TLS config error: %s" % exc) + + +def get_connect_params(auth, fail_function): + if auth['tls'] or auth['tls_verify']: + auth['docker_host'] = auth['docker_host'].replace('tcp://', 'https://') + + if auth['tls_verify'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and host verification + if auth['cacert_path']: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + ca_cert=auth['cacert_path'], + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + else: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify'] and auth['cacert_path']: + # TLS with cacert only + tls_config = _get_tls_config(ca_cert=auth['cacert_path'], + assert_hostname=auth['tls_hostname'], + verify=True, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify']: + # TLS with verify and no certs + tls_config = _get_tls_config(verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and no host verification + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls']: + # TLS with no certs and not host verification + tls_config = _get_tls_config(verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + # No TLS + return dict(base_url=auth['docker_host'], + version=auth['api_version'], + timeout=auth['timeout']) + + +class InventoryModule(BaseInventoryPlugin, Constructable): + ''' Host inventory parser for ansible using Docker swarm as source. ''' + + NAME = 'docker_swarm' + + def _fail(self, msg): + raise AnsibleError(msg) + + def _populate(self): + raw_params = dict( + docker_host=self.get_option('docker_host'), + tls=self.get_option('tls'), + tls_verify=self.get_option('validate_certs'), + key_path=self.get_option('client_key'), + cacert_path=self.get_option('ca_cert'), + cert_path=self.get_option('client_cert'), + tls_hostname=self.get_option('tls_hostname'), + api_version=self.get_option('api_version'), + timeout=self.get_option('timeout'), + ssl_version=self.get_option('ssl_version'), + debug=None, + ) + update_tls_hostname(raw_params) + connect_params = get_connect_params(raw_params, fail_function=self._fail) + self.client = docker.DockerClient(**connect_params) + self.inventory.add_group('all') + self.inventory.add_group('manager') + self.inventory.add_group('worker') + self.inventory.add_group('leader') + self.inventory.add_group('nonleaders') + + if self.get_option('include_host_uri'): + if self.get_option('include_host_uri_port'): + host_uri_port = str(self.get_option('include_host_uri_port')) + elif self.get_option('tls') or self.get_option('validate_certs'): + host_uri_port = '2376' + else: + host_uri_port = '2375' + + try: + self.nodes = self.client.nodes.list() + for self.node in self.nodes: + self.node_attrs = self.client.nodes.get(self.node.id).attrs + self.inventory.add_host(self.node_attrs['ID']) + self.inventory.add_host(self.node_attrs['ID'], group=self.node_attrs['Spec']['Role']) + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host', + self.node_attrs['Status']['Addr']) + if self.get_option('include_host_uri'): + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host_uri', + 'tcp://' + self.node_attrs['Status']['Addr'] + ':' + host_uri_port) + if self.get_option('verbose_output'): + self.inventory.set_variable(self.node_attrs['ID'], 'docker_swarm_node_attributes', self.node_attrs) + if 'ManagerStatus' in self.node_attrs: + if self.node_attrs['ManagerStatus'].get('Leader'): + # This is workaround of bug in Docker when in some cases the Leader IP is 0.0.0.0 + # Check moby/moby#35437 for details + swarm_leader_ip = parse_address(self.node_attrs['ManagerStatus']['Addr'])[0] or \ + self.node_attrs['Status']['Addr'] + if self.get_option('include_host_uri'): + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host_uri', + 'tcp://' + swarm_leader_ip + ':' + host_uri_port) + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host', swarm_leader_ip) + self.inventory.add_host(self.node_attrs['ID'], group='leader') + else: + self.inventory.add_host(self.node_attrs['ID'], group='nonleaders') + else: + self.inventory.add_host(self.node_attrs['ID'], group='nonleaders') + # Use constructed if applicable + strict = self.get_option('strict') + # Composed variables + self._set_composite_vars(self.get_option('compose'), + self.node_attrs, + self.node_attrs['ID'], + strict=strict) + # Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group + self._add_host_to_composed_groups(self.get_option('groups'), + self.node_attrs, + self.node_attrs['ID'], + strict=strict) + # Create groups based on variable values and add the corresponding hosts to it + self._add_host_to_keyed_groups(self.get_option('keyed_groups'), + self.node_attrs, + self.node_attrs['ID'], + strict=strict) + except Exception as e: + raise AnsibleError('Unable to fetch hosts from Docker swarm API, this was the original exception: %s' % + to_native(e)) + + def verify_file(self, path): + """Return the possibly of a file being consumable by this plugin.""" + return ( + super(InventoryModule, self).verify_file(path) and + path.endswith((self.NAME + '.yaml', self.NAME + '.yml'))) + + def parse(self, inventory, loader, path, cache=True): + if not HAS_DOCKER: + raise AnsibleError('The Docker swarm dynamic inventory plugin requires the Docker SDK for Python: ' + 'https://github.com/docker/docker-py.') + super(InventoryModule, self).parse(inventory, loader, path, cache) + self._read_config_data(path) + self._populate() diff --git a/test/support/integration/plugins/inventory/foreman.py b/test/support/integration/plugins/inventory/foreman.py new file mode 100644 index 00000000..43073f81 --- /dev/null +++ b/test/support/integration/plugins/inventory/foreman.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2016 Guido Günther <agx@sigxcpu.org>, Daniel Lobato Garcia <dlobatog@redhat.com> +# Copyright (c) 2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = ''' + name: foreman + plugin_type: inventory + short_description: foreman inventory source + version_added: "2.6" + requirements: + - requests >= 1.1 + description: + - Get inventory hosts from the foreman service. + - "Uses a configuration file as an inventory source, it must end in ``.foreman.yml`` or ``.foreman.yaml`` and has a ``plugin: foreman`` entry." + extends_documentation_fragment: + - inventory_cache + - constructed + options: + plugin: + description: the name of this plugin, it should always be set to 'foreman' for this plugin to recognize it as it's own. + required: True + choices: ['foreman'] + url: + description: url to foreman + default: 'http://localhost:3000' + env: + - name: FOREMAN_SERVER + version_added: "2.8" + user: + description: foreman authentication user + required: True + env: + - name: FOREMAN_USER + version_added: "2.8" + password: + description: foreman authentication password + required: True + env: + - name: FOREMAN_PASSWORD + version_added: "2.8" + validate_certs: + description: verify SSL certificate if using https + type: boolean + default: False + group_prefix: + description: prefix to apply to foreman groups + default: foreman_ + vars_prefix: + description: prefix to apply to host variables, does not include facts nor params + default: foreman_ + want_facts: + description: Toggle, if True the plugin will retrieve host facts from the server + type: boolean + default: False + want_params: + description: Toggle, if true the inventory will retrieve 'all_parameters' information as host vars + type: boolean + default: False + want_hostcollections: + description: Toggle, if true the plugin will create Ansible groups for host collections + type: boolean + default: False + version_added: '2.10' + want_ansible_ssh_host: + description: Toggle, if true the plugin will populate the ansible_ssh_host variable to explicitly specify the connection target + type: boolean + default: False + version_added: '2.10' + +''' + +EXAMPLES = ''' +# my.foreman.yml +plugin: foreman +url: http://localhost:2222 +user: ansible-tester +password: secure +validate_certs: False +''' + +from distutils.version import LooseVersion + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.common._collections_compat import MutableMapping +from ansible.plugins.inventory import BaseInventoryPlugin, Cacheable, to_safe_group_name, Constructable + +# 3rd party imports +try: + import requests + if LooseVersion(requests.__version__) < LooseVersion('1.1.0'): + raise ImportError +except ImportError: + raise AnsibleError('This script requires python-requests 1.1 as a minimum version') + +from requests.auth import HTTPBasicAuth + + +class InventoryModule(BaseInventoryPlugin, Cacheable, Constructable): + ''' Host inventory parser for ansible using foreman as source. ''' + + NAME = 'foreman' + + def __init__(self): + + super(InventoryModule, self).__init__() + + # from config + self.foreman_url = None + + self.session = None + self.cache_key = None + self.use_cache = None + + def verify_file(self, path): + + valid = False + if super(InventoryModule, self).verify_file(path): + if path.endswith(('foreman.yaml', 'foreman.yml')): + valid = True + else: + self.display.vvv('Skipping due to inventory source not ending in "foreman.yaml" nor "foreman.yml"') + return valid + + def _get_session(self): + if not self.session: + self.session = requests.session() + self.session.auth = HTTPBasicAuth(self.get_option('user'), to_bytes(self.get_option('password'))) + self.session.verify = self.get_option('validate_certs') + return self.session + + def _get_json(self, url, ignore_errors=None): + + if not self.use_cache or url not in self._cache.get(self.cache_key, {}): + + if self.cache_key not in self._cache: + self._cache[self.cache_key] = {url: ''} + + results = [] + s = self._get_session() + params = {'page': 1, 'per_page': 250} + while True: + ret = s.get(url, params=params) + if ignore_errors and ret.status_code in ignore_errors: + break + ret.raise_for_status() + json = ret.json() + + # process results + # FIXME: This assumes 'return type' matches a specific query, + # it will break if we expand the queries and they dont have different types + if 'results' not in json: + # /hosts/:id dos not have a 'results' key + results = json + break + elif isinstance(json['results'], MutableMapping): + # /facts are returned as dict in 'results' + results = json['results'] + break + else: + # /hosts 's 'results' is a list of all hosts, returned is paginated + results = results + json['results'] + + # check for end of paging + if len(results) >= json['subtotal']: + break + if len(json['results']) == 0: + self.display.warning("Did not make any progress during loop. expected %d got %d" % (json['subtotal'], len(results))) + break + + # get next page + params['page'] += 1 + + self._cache[self.cache_key][url] = results + + return self._cache[self.cache_key][url] + + def _get_hosts(self): + return self._get_json("%s/api/v2/hosts" % self.foreman_url) + + def _get_all_params_by_id(self, hid): + url = "%s/api/v2/hosts/%s" % (self.foreman_url, hid) + ret = self._get_json(url, [404]) + if not ret or not isinstance(ret, MutableMapping) or not ret.get('all_parameters', False): + return {} + return ret.get('all_parameters') + + def _get_facts_by_id(self, hid): + url = "%s/api/v2/hosts/%s/facts" % (self.foreman_url, hid) + return self._get_json(url) + + def _get_host_data_by_id(self, hid): + url = "%s/api/v2/hosts/%s" % (self.foreman_url, hid) + return self._get_json(url) + + def _get_facts(self, host): + """Fetch all host facts of the host""" + + ret = self._get_facts_by_id(host['id']) + if len(ret.values()) == 0: + facts = {} + elif len(ret.values()) == 1: + facts = list(ret.values())[0] + else: + raise ValueError("More than one set of facts returned for '%s'" % host) + return facts + + def _populate(self): + + for host in self._get_hosts(): + + if host.get('name'): + host_name = self.inventory.add_host(host['name']) + + # create directly mapped groups + group_name = host.get('hostgroup_title', host.get('hostgroup_name')) + if group_name: + group_name = to_safe_group_name('%s%s' % (self.get_option('group_prefix'), group_name.lower().replace(" ", ""))) + group_name = self.inventory.add_group(group_name) + self.inventory.add_child(group_name, host_name) + + # set host vars from host info + try: + for k, v in host.items(): + if k not in ('name', 'hostgroup_title', 'hostgroup_name'): + try: + self.inventory.set_variable(host_name, self.get_option('vars_prefix') + k, v) + except ValueError as e: + self.display.warning("Could not set host info hostvar for %s, skipping %s: %s" % (host, k, to_text(e))) + except ValueError as e: + self.display.warning("Could not get host info for %s, skipping: %s" % (host_name, to_text(e))) + + # set host vars from params + if self.get_option('want_params'): + for p in self._get_all_params_by_id(host['id']): + try: + self.inventory.set_variable(host_name, p['name'], p['value']) + except ValueError as e: + self.display.warning("Could not set hostvar %s to '%s' for the '%s' host, skipping: %s" % + (p['name'], to_native(p['value']), host, to_native(e))) + + # set host vars from facts + if self.get_option('want_facts'): + self.inventory.set_variable(host_name, 'foreman_facts', self._get_facts(host)) + + # create group for host collections + if self.get_option('want_hostcollections'): + host_data = self._get_host_data_by_id(host['id']) + hostcollections = host_data.get('host_collections') + if hostcollections: + # Create Ansible groups for host collections + for hostcollection in hostcollections: + try: + hostcollection_group = to_safe_group_name('%shostcollection_%s' % (self.get_option('group_prefix'), + hostcollection['name'].lower().replace(" ", ""))) + hostcollection_group = self.inventory.add_group(hostcollection_group) + self.inventory.add_child(hostcollection_group, host_name) + except ValueError as e: + self.display.warning("Could not create groups for host collections for %s, skipping: %s" % (host_name, to_text(e))) + + # put ansible_ssh_host as hostvar + if self.get_option('want_ansible_ssh_host'): + for key in ('ip', 'ipv4', 'ipv6'): + if host.get(key): + try: + self.inventory.set_variable(host_name, 'ansible_ssh_host', host[key]) + break + except ValueError as e: + self.display.warning("Could not set hostvar ansible_ssh_host to '%s' for the '%s' host, skipping: %s" % + (host[key], host_name, to_text(e))) + + strict = self.get_option('strict') + + hostvars = self.inventory.get_host(host_name).get_vars() + self._set_composite_vars(self.get_option('compose'), hostvars, host_name, strict) + self._add_host_to_composed_groups(self.get_option('groups'), hostvars, host_name, strict) + self._add_host_to_keyed_groups(self.get_option('keyed_groups'), hostvars, host_name, strict) + + def parse(self, inventory, loader, path, cache=True): + + super(InventoryModule, self).parse(inventory, loader, path) + + # read config from file, this sets 'options' + self._read_config_data(path) + + # get connection host + self.foreman_url = self.get_option('url') + self.cache_key = self.get_cache_key(path) + self.use_cache = cache and self.get_option('cache') + + # actually populate inventory + self._populate() diff --git a/test/support/integration/plugins/lookup/rabbitmq.py b/test/support/integration/plugins/lookup/rabbitmq.py new file mode 100644 index 00000000..7c2745f4 --- /dev/null +++ b/test/support/integration/plugins/lookup/rabbitmq.py @@ -0,0 +1,190 @@ +# (c) 2018, John Imison <john+github@imison.net> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = """ + lookup: rabbitmq + author: John Imison <@Im0> + version_added: "2.8" + short_description: Retrieve messages from an AMQP/AMQPS RabbitMQ queue. + description: + - This lookup uses a basic get to retrieve all, or a limited number C(count), messages from a RabbitMQ queue. + options: + url: + description: + - An URI connection string to connect to the AMQP/AMQPS RabbitMQ server. + - For more information refer to the URI spec U(https://www.rabbitmq.com/uri-spec.html). + required: True + queue: + description: + - The queue to get messages from. + required: True + count: + description: + - How many messages to collect from the queue. + - If not set, defaults to retrieving all the messages from the queue. + requirements: + - The python pika package U(https://pypi.org/project/pika/). + notes: + - This lookup implements BlockingChannel.basic_get to get messages from a RabbitMQ server. + - After retrieving a message from the server, receipt of the message is acknowledged and the message on the server is deleted. + - Pika is a pure-Python implementation of the AMQP 0-9-1 protocol that tries to stay fairly independent of the underlying network support library. + - More information about pika can be found at U(https://pika.readthedocs.io/en/stable/). + - This plugin is tested against RabbitMQ. Other AMQP 0.9.1 protocol based servers may work but not tested/guaranteed. + - Assigning the return messages to a variable under C(vars) may result in unexpected results as the lookup is evaluated every time the + variable is referenced. + - Currently this plugin only handles text based messages from a queue. Unexpected results may occur when retrieving binary data. +""" + + +EXAMPLES = """ +- name: Get all messages off a queue + debug: + msg: "{{ lookup('rabbitmq', url='amqp://guest:guest@192.168.0.10:5672/%2F', queue='hello') }}" + + +# If you are intending on using the returned messages as a variable in more than +# one task (eg. debug, template), it is recommended to set_fact. + +- name: Get 2 messages off a queue and set a fact for re-use + set_fact: + messages: "{{ lookup('rabbitmq', url='amqp://guest:guest@192.168.0.10:5672/%2F', queue='hello', count=2) }}" + +- name: Dump out contents of the messages + debug: + var: messages + +""" + +RETURN = """ + _list: + description: + - A list of dictionaries with keys and value from the queue. + type: list + contains: + content_type: + description: The content_type on the message in the queue. + type: str + delivery_mode: + description: The delivery_mode on the message in the queue. + type: str + delivery_tag: + description: The delivery_tag on the message in the queue. + type: str + exchange: + description: The exchange the message came from. + type: str + message_count: + description: The message_count for the message on the queue. + type: str + msg: + description: The content of the message. + type: str + redelivered: + description: The redelivered flag. True if the message has been delivered before. + type: bool + routing_key: + description: The routing_key on the message in the queue. + type: str + headers: + description: The headers for the message returned from the queue. + type: dict + json: + description: If application/json is specified in content_type, json will be loaded into variables. + type: dict + +""" + +import json + +from ansible.errors import AnsibleError, AnsibleParserError +from ansible.plugins.lookup import LookupBase +from ansible.module_utils._text import to_native, to_text +from ansible.utils.display import Display + +try: + import pika + from pika import spec + HAS_PIKA = True +except ImportError: + HAS_PIKA = False + +display = Display() + + +class LookupModule(LookupBase): + + def run(self, terms, variables=None, url=None, queue=None, count=None): + if not HAS_PIKA: + raise AnsibleError('pika python package is required for rabbitmq lookup.') + if not url: + raise AnsibleError('URL is required for rabbitmq lookup.') + if not queue: + raise AnsibleError('Queue is required for rabbitmq lookup.') + + display.vvv(u"terms:%s : variables:%s url:%s queue:%s count:%s" % (terms, variables, url, queue, count)) + + try: + parameters = pika.URLParameters(url) + except Exception as e: + raise AnsibleError("URL malformed: %s" % to_native(e)) + + try: + connection = pika.BlockingConnection(parameters) + except Exception as e: + raise AnsibleError("Connection issue: %s" % to_native(e)) + + try: + conn_channel = connection.channel() + except pika.exceptions.AMQPChannelError as e: + try: + connection.close() + except pika.exceptions.AMQPConnectionError as ie: + raise AnsibleError("Channel and connection closing issues: %s / %s" % to_native(e), to_native(ie)) + raise AnsibleError("Channel issue: %s" % to_native(e)) + + ret = [] + idx = 0 + + while True: + method_frame, properties, body = conn_channel.basic_get(queue=queue) + if method_frame: + display.vvv(u"%s, %s, %s " % (method_frame, properties, to_text(body))) + + # TODO: In the future consider checking content_type and handle text/binary data differently. + msg_details = dict({ + 'msg': to_text(body), + 'message_count': method_frame.message_count, + 'routing_key': method_frame.routing_key, + 'delivery_tag': method_frame.delivery_tag, + 'redelivered': method_frame.redelivered, + 'exchange': method_frame.exchange, + 'delivery_mode': properties.delivery_mode, + 'content_type': properties.content_type, + 'headers': properties.headers + }) + if properties.content_type == 'application/json': + try: + msg_details['json'] = json.loads(msg_details['msg']) + except ValueError as e: + raise AnsibleError("Unable to decode JSON for message %s: %s" % (method_frame.delivery_tag, to_native(e))) + + ret.append(msg_details) + conn_channel.basic_ack(method_frame.delivery_tag) + idx += 1 + if method_frame.message_count == 0 or idx == count: + break + # If we didn't get a method_frame, exit. + else: + break + + if connection.is_closed: + return [ret] + else: + try: + connection.close() + except pika.exceptions.AMQPConnectionError: + pass + return [ret] diff --git a/test/support/integration/plugins/module_utils/aws/__init__.py b/test/support/integration/plugins/module_utils/aws/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/__init__.py diff --git a/test/support/integration/plugins/module_utils/aws/core.py b/test/support/integration/plugins/module_utils/aws/core.py new file mode 100644 index 00000000..c4527b6d --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/core.py @@ -0,0 +1,335 @@ +# +# Copyright 2017 Michael De La Rue | Ansible +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +"""This module adds shared support for generic Amazon AWS modules + +**This code is not yet ready for use in user modules. As of 2017** +**and through to 2018, the interface is likely to change** +**aggressively as the exact correct interface for ansible AWS modules** +**is identified. In particular, until this notice goes away or is** +**changed, methods may disappear from the interface. Please don't** +**publish modules using this except directly to the main Ansible** +**development repository.** + +In order to use this module, include it as part of a custom +module as shown below. + + from ansible.module_utils.aws import AnsibleAWSModule + module = AnsibleAWSModule(argument_spec=dictionary, supports_check_mode=boolean + mutually_exclusive=list1, required_together=list2) + +The 'AnsibleAWSModule' module provides similar, but more restricted, +interfaces to the normal Ansible module. It also includes the +additional methods for connecting to AWS using the standard module arguments + + m.resource('lambda') # - get an AWS connection as a boto3 resource. + +or + + m.client('sts') # - get an AWS connection as a boto3 client. + +To make use of AWSRetry easier, it can now be wrapped around any call from a +module-created client. To add retries to a client, create a client: + + m.client('ec2', retry_decorator=AWSRetry.jittered_backoff(retries=10)) + +Any calls from that client can be made to use the decorator passed at call-time +using the `aws_retry` argument. By default, no retries are used. + + ec2 = m.client('ec2', retry_decorator=AWSRetry.jittered_backoff(retries=10)) + ec2.describe_instances(InstanceIds=['i-123456789'], aws_retry=True) + +The call will be retried the specified number of times, so the calling functions +don't need to be wrapped in the backoff decorator. +""" + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import logging +import traceback +from functools import wraps +from distutils.version import LooseVersion + +try: + from cStringIO import StringIO +except ImportError: + # Python 3 + from io import StringIO + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native +from ansible.module_utils.ec2 import HAS_BOTO3, camel_dict_to_snake_dict, ec2_argument_spec, boto3_conn +from ansible.module_utils.ec2 import get_aws_connection_info, get_aws_region + +# We will also export HAS_BOTO3 so end user modules can use it. +__all__ = ('AnsibleAWSModule', 'HAS_BOTO3', 'is_boto3_error_code') + + +class AnsibleAWSModule(object): + """An ansible module class for AWS modules + + AnsibleAWSModule provides an a class for building modules which + connect to Amazon Web Services. The interface is currently more + restricted than the basic module class with the aim that later the + basic module class can be reduced. If you find that any key + feature is missing please contact the author/Ansible AWS team + (available on #ansible-aws on IRC) to request the additional + features needed. + """ + default_settings = { + "default_args": True, + "check_boto3": True, + "auto_retry": True, + "module_class": AnsibleModule + } + + def __init__(self, **kwargs): + local_settings = {} + for key in AnsibleAWSModule.default_settings: + try: + local_settings[key] = kwargs.pop(key) + except KeyError: + local_settings[key] = AnsibleAWSModule.default_settings[key] + self.settings = local_settings + + if local_settings["default_args"]: + # ec2_argument_spec contains the region so we use that; there's a patch coming which + # will add it to aws_argument_spec so if that's accepted then later we should change + # over + argument_spec_full = ec2_argument_spec() + try: + argument_spec_full.update(kwargs["argument_spec"]) + except (TypeError, NameError): + pass + kwargs["argument_spec"] = argument_spec_full + + self._module = AnsibleAWSModule.default_settings["module_class"](**kwargs) + + if local_settings["check_boto3"] and not HAS_BOTO3: + self._module.fail_json( + msg=missing_required_lib('botocore or boto3')) + + self.check_mode = self._module.check_mode + self._diff = self._module._diff + self._name = self._module._name + + self._botocore_endpoint_log_stream = StringIO() + self.logger = None + if self.params.get('debug_botocore_endpoint_logs'): + self.logger = logging.getLogger('botocore.endpoint') + self.logger.setLevel(logging.DEBUG) + self.logger.addHandler(logging.StreamHandler(self._botocore_endpoint_log_stream)) + + @property + def params(self): + return self._module.params + + def _get_resource_action_list(self): + actions = [] + for ln in self._botocore_endpoint_log_stream.getvalue().split('\n'): + ln = ln.strip() + if not ln: + continue + found_operational_request = re.search(r"OperationModel\(name=.*?\)", ln) + if found_operational_request: + operation_request = found_operational_request.group(0)[20:-1] + resource = re.search(r"https://.*?\.", ln).group(0)[8:-1] + actions.append("{0}:{1}".format(resource, operation_request)) + return list(set(actions)) + + def exit_json(self, *args, **kwargs): + if self.params.get('debug_botocore_endpoint_logs'): + kwargs['resource_actions'] = self._get_resource_action_list() + return self._module.exit_json(*args, **kwargs) + + def fail_json(self, *args, **kwargs): + if self.params.get('debug_botocore_endpoint_logs'): + kwargs['resource_actions'] = self._get_resource_action_list() + return self._module.fail_json(*args, **kwargs) + + def debug(self, *args, **kwargs): + return self._module.debug(*args, **kwargs) + + def warn(self, *args, **kwargs): + return self._module.warn(*args, **kwargs) + + def deprecate(self, *args, **kwargs): + return self._module.deprecate(*args, **kwargs) + + def boolean(self, *args, **kwargs): + return self._module.boolean(*args, **kwargs) + + def md5(self, *args, **kwargs): + return self._module.md5(*args, **kwargs) + + def client(self, service, retry_decorator=None): + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(self, boto3=True) + conn = boto3_conn(self, conn_type='client', resource=service, + region=region, endpoint=ec2_url, **aws_connect_kwargs) + return conn if retry_decorator is None else _RetryingBotoClientWrapper(conn, retry_decorator) + + def resource(self, service): + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(self, boto3=True) + return boto3_conn(self, conn_type='resource', resource=service, + region=region, endpoint=ec2_url, **aws_connect_kwargs) + + @property + def region(self, boto3=True): + return get_aws_region(self, boto3) + + def fail_json_aws(self, exception, msg=None): + """call fail_json with processed exception + + function for converting exceptions thrown by AWS SDK modules, + botocore, boto3 and boto, into nice error messages. + """ + last_traceback = traceback.format_exc() + + # to_native is trusted to handle exceptions that str() could + # convert to text. + try: + except_msg = to_native(exception.message) + except AttributeError: + except_msg = to_native(exception) + + if msg is not None: + message = '{0}: {1}'.format(msg, except_msg) + else: + message = except_msg + + try: + response = exception.response + except AttributeError: + response = None + + failure = dict( + msg=message, + exception=last_traceback, + **self._gather_versions() + ) + + if response is not None: + failure.update(**camel_dict_to_snake_dict(response)) + + self.fail_json(**failure) + + def _gather_versions(self): + """Gather AWS SDK (boto3 and botocore) dependency versions + + Returns {'boto3_version': str, 'botocore_version': str} + Returns {} if neither are installed + """ + if not HAS_BOTO3: + return {} + import boto3 + import botocore + return dict(boto3_version=boto3.__version__, + botocore_version=botocore.__version__) + + def boto3_at_least(self, desired): + """Check if the available boto3 version is greater than or equal to a desired version. + + Usage: + if module.params.get('assign_ipv6_address') and not module.boto3_at_least('1.4.4'): + # conditionally fail on old boto3 versions if a specific feature is not supported + module.fail_json(msg="Boto3 can't deal with EC2 IPv6 addresses before version 1.4.4.") + """ + existing = self._gather_versions() + return LooseVersion(existing['boto3_version']) >= LooseVersion(desired) + + def botocore_at_least(self, desired): + """Check if the available botocore version is greater than or equal to a desired version. + + Usage: + if not module.botocore_at_least('1.2.3'): + module.fail_json(msg='The Serverless Elastic Load Compute Service is not in botocore before v1.2.3') + if not module.botocore_at_least('1.5.3'): + module.warn('Botocore did not include waiters for Service X before 1.5.3. ' + 'To wait until Service X resources are fully available, update botocore.') + """ + existing = self._gather_versions() + return LooseVersion(existing['botocore_version']) >= LooseVersion(desired) + + +class _RetryingBotoClientWrapper(object): + __never_wait = ( + 'get_paginator', 'can_paginate', + 'get_waiter', 'generate_presigned_url', + ) + + def __init__(self, client, retry): + self.client = client + self.retry = retry + + def _create_optional_retry_wrapper_function(self, unwrapped): + retrying_wrapper = self.retry(unwrapped) + + @wraps(unwrapped) + def deciding_wrapper(aws_retry=False, *args, **kwargs): + if aws_retry: + return retrying_wrapper(*args, **kwargs) + else: + return unwrapped(*args, **kwargs) + return deciding_wrapper + + def __getattr__(self, name): + unwrapped = getattr(self.client, name) + if name in self.__never_wait: + return unwrapped + elif callable(unwrapped): + wrapped = self._create_optional_retry_wrapper_function(unwrapped) + setattr(self, name, wrapped) + return wrapped + else: + return unwrapped + + +def is_boto3_error_code(code, e=None): + """Check if the botocore exception is raised by a specific error code. + + Returns ClientError if the error code matches, a dummy exception if it does not have an error code or does not match + + Example: + try: + ec2.describe_instances(InstanceIds=['potato']) + except is_boto3_error_code('InvalidInstanceID.Malformed'): + # handle the error for that code case + except botocore.exceptions.ClientError as e: + # handle the generic error case for all other codes + """ + from botocore.exceptions import ClientError + if e is None: + import sys + dummy, e, dummy = sys.exc_info() + if isinstance(e, ClientError) and e.response['Error']['Code'] == code: + return ClientError + return type('NeverEverRaisedException', (Exception,), {}) + + +def get_boto3_client_method_parameters(client, method_name, required=False): + op = client.meta.method_to_api_mapping.get(method_name) + input_shape = client._service_model.operation_model(op).input_shape + if not input_shape: + parameters = [] + elif required: + parameters = list(input_shape.required_members) + else: + parameters = list(input_shape.members.keys()) + return parameters diff --git a/test/support/integration/plugins/module_utils/aws/iam.py b/test/support/integration/plugins/module_utils/aws/iam.py new file mode 100644 index 00000000..f05999aa --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/iam.py @@ -0,0 +1,49 @@ +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import traceback + +try: + from botocore.exceptions import ClientError, NoCredentialsError +except ImportError: + pass # caught by HAS_BOTO3 + +from ansible.module_utils._text import to_native + + +def get_aws_account_id(module): + """ Given AnsibleAWSModule instance, get the active AWS account ID + + get_account_id tries too find out the account that we are working + on. It's not guaranteed that this will be easy so we try in + several different ways. Giving either IAM or STS privilages to + the account should be enough to permit this. + """ + account_id = None + try: + sts_client = module.client('sts') + account_id = sts_client.get_caller_identity().get('Account') + # non-STS sessions may also get NoCredentialsError from this STS call, so + # we must catch that too and try the IAM version + except (ClientError, NoCredentialsError): + try: + iam_client = module.client('iam') + account_id = iam_client.get_user()['User']['Arn'].split(':')[4] + except ClientError as e: + if (e.response['Error']['Code'] == 'AccessDenied'): + except_msg = to_native(e) + # don't match on `arn:aws` because of China region `arn:aws-cn` and similar + account_id = except_msg.search(r"arn:\w+:iam::([0-9]{12,32}):\w+/").group(1) + if account_id is None: + module.fail_json_aws(e, msg="Could not get AWS account information") + except Exception as e: + module.fail_json( + msg="Failed to get AWS account information, Try allowing sts:GetCallerIdentity or iam:GetUser permissions.", + exception=traceback.format_exc() + ) + if not account_id: + module.fail_json(msg="Failed while determining AWS account ID. Try allowing sts:GetCallerIdentity or iam:GetUser permissions.") + return to_native(account_id) diff --git a/test/support/integration/plugins/module_utils/aws/s3.py b/test/support/integration/plugins/module_utils/aws/s3.py new file mode 100644 index 00000000..2185869d --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/s3.py @@ -0,0 +1,50 @@ +# Copyright (c) 2018 Red Hat, Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +try: + from botocore.exceptions import BotoCoreError, ClientError +except ImportError: + pass # Handled by the calling module + +HAS_MD5 = True +try: + from hashlib import md5 +except ImportError: + try: + from md5 import md5 + except ImportError: + HAS_MD5 = False + + +def calculate_etag(module, filename, etag, s3, bucket, obj, version=None): + if not HAS_MD5: + return None + + if '-' in etag: + # Multi-part ETag; a hash of the hashes of each part. + parts = int(etag[1:-1].split('-')[1]) + digests = [] + + s3_kwargs = dict( + Bucket=bucket, + Key=obj, + ) + if version: + s3_kwargs['VersionId'] = version + + with open(filename, 'rb') as f: + for part_num in range(1, parts + 1): + s3_kwargs['PartNumber'] = part_num + try: + head = s3.head_object(**s3_kwargs) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get head object") + digests.append(md5(f.read(int(head['ContentLength'])))) + + digest_squared = md5(b''.join(m.digest() for m in digests)) + return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests)) + else: # Compute the MD5 sum normally + return '"{0}"'.format(module.md5(filename)) diff --git a/test/support/integration/plugins/module_utils/aws/waiters.py b/test/support/integration/plugins/module_utils/aws/waiters.py new file mode 100644 index 00000000..25db598b --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/waiters.py @@ -0,0 +1,405 @@ +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +try: + import botocore.waiter as core_waiter +except ImportError: + pass # caught by HAS_BOTO3 + + +ec2_data = { + "version": 2, + "waiters": { + "InternetGatewayExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeInternetGateways", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(InternetGateways) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidInternetGatewayID.NotFound", + "state": "retry" + }, + ] + }, + "RouteTableExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeRouteTables", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(RouteTables[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidRouteTableID.NotFound", + "state": "retry" + }, + ] + }, + "SecurityGroupExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSecurityGroups", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(SecurityGroups[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidGroup.NotFound", + "state": "retry" + }, + ] + }, + "SubnetExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(Subnets[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidSubnetID.NotFound", + "state": "retry" + }, + ] + }, + "SubnetHasMapPublic": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": True, + "argument": "Subnets[].MapPublicIpOnLaunch", + "state": "success" + }, + ] + }, + "SubnetNoMapPublic": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": False, + "argument": "Subnets[].MapPublicIpOnLaunch", + "state": "success" + }, + ] + }, + "SubnetHasAssignIpv6": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": True, + "argument": "Subnets[].AssignIpv6AddressOnCreation", + "state": "success" + }, + ] + }, + "SubnetNoAssignIpv6": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": False, + "argument": "Subnets[].AssignIpv6AddressOnCreation", + "state": "success" + }, + ] + }, + "SubnetDeleted": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(Subnets[]) > `0`", + "state": "retry" + }, + { + "matcher": "error", + "expected": "InvalidSubnetID.NotFound", + "state": "success" + }, + ] + }, + "VpnGatewayExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeVpnGateways", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(VpnGateways[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidVpnGatewayID.NotFound", + "state": "retry" + }, + ] + }, + "VpnGatewayDetached": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeVpnGateways", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "VpnGateways[0].State == 'available'", + "state": "success" + }, + ] + }, + } +} + + +waf_data = { + "version": 2, + "waiters": { + "ChangeTokenInSync": { + "delay": 20, + "maxAttempts": 60, + "operation": "GetChangeTokenStatus", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "ChangeTokenStatus == 'INSYNC'", + "state": "success" + }, + { + "matcher": "error", + "expected": "WAFInternalErrorException", + "state": "retry" + } + ] + } + } +} + +eks_data = { + "version": 2, + "waiters": { + "ClusterActive": { + "delay": 20, + "maxAttempts": 60, + "operation": "DescribeCluster", + "acceptors": [ + { + "state": "success", + "matcher": "path", + "argument": "cluster.status", + "expected": "ACTIVE" + }, + { + "state": "retry", + "matcher": "error", + "expected": "ResourceNotFoundException" + } + ] + }, + "ClusterDeleted": { + "delay": 20, + "maxAttempts": 60, + "operation": "DescribeCluster", + "acceptors": [ + { + "state": "retry", + "matcher": "path", + "argument": "cluster.status != 'DELETED'", + "expected": True + }, + { + "state": "success", + "matcher": "error", + "expected": "ResourceNotFoundException" + } + ] + } + } +} + + +rds_data = { + "version": 2, + "waiters": { + "DBInstanceStopped": { + "delay": 20, + "maxAttempts": 60, + "operation": "DescribeDBInstances", + "acceptors": [ + { + "state": "success", + "matcher": "pathAll", + "argument": "DBInstances[].DBInstanceStatus", + "expected": "stopped" + }, + ] + } + } +} + + +def ec2_model(name): + ec2_models = core_waiter.WaiterModel(waiter_config=ec2_data) + return ec2_models.get_waiter(name) + + +def waf_model(name): + waf_models = core_waiter.WaiterModel(waiter_config=waf_data) + return waf_models.get_waiter(name) + + +def eks_model(name): + eks_models = core_waiter.WaiterModel(waiter_config=eks_data) + return eks_models.get_waiter(name) + + +def rds_model(name): + rds_models = core_waiter.WaiterModel(waiter_config=rds_data) + return rds_models.get_waiter(name) + + +waiters_by_name = { + ('EC2', 'internet_gateway_exists'): lambda ec2: core_waiter.Waiter( + 'internet_gateway_exists', + ec2_model('InternetGatewayExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_internet_gateways + )), + ('EC2', 'route_table_exists'): lambda ec2: core_waiter.Waiter( + 'route_table_exists', + ec2_model('RouteTableExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_route_tables + )), + ('EC2', 'security_group_exists'): lambda ec2: core_waiter.Waiter( + 'security_group_exists', + ec2_model('SecurityGroupExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_security_groups + )), + ('EC2', 'subnet_exists'): lambda ec2: core_waiter.Waiter( + 'subnet_exists', + ec2_model('SubnetExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_has_map_public'): lambda ec2: core_waiter.Waiter( + 'subnet_has_map_public', + ec2_model('SubnetHasMapPublic'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_no_map_public'): lambda ec2: core_waiter.Waiter( + 'subnet_no_map_public', + ec2_model('SubnetNoMapPublic'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_has_assign_ipv6'): lambda ec2: core_waiter.Waiter( + 'subnet_has_assign_ipv6', + ec2_model('SubnetHasAssignIpv6'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_no_assign_ipv6'): lambda ec2: core_waiter.Waiter( + 'subnet_no_assign_ipv6', + ec2_model('SubnetNoAssignIpv6'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_deleted'): lambda ec2: core_waiter.Waiter( + 'subnet_deleted', + ec2_model('SubnetDeleted'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'vpn_gateway_exists'): lambda ec2: core_waiter.Waiter( + 'vpn_gateway_exists', + ec2_model('VpnGatewayExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_vpn_gateways + )), + ('EC2', 'vpn_gateway_detached'): lambda ec2: core_waiter.Waiter( + 'vpn_gateway_detached', + ec2_model('VpnGatewayDetached'), + core_waiter.NormalizedOperationMethod( + ec2.describe_vpn_gateways + )), + ('WAF', 'change_token_in_sync'): lambda waf: core_waiter.Waiter( + 'change_token_in_sync', + waf_model('ChangeTokenInSync'), + core_waiter.NormalizedOperationMethod( + waf.get_change_token_status + )), + ('WAFRegional', 'change_token_in_sync'): lambda waf: core_waiter.Waiter( + 'change_token_in_sync', + waf_model('ChangeTokenInSync'), + core_waiter.NormalizedOperationMethod( + waf.get_change_token_status + )), + ('EKS', 'cluster_active'): lambda eks: core_waiter.Waiter( + 'cluster_active', + eks_model('ClusterActive'), + core_waiter.NormalizedOperationMethod( + eks.describe_cluster + )), + ('EKS', 'cluster_deleted'): lambda eks: core_waiter.Waiter( + 'cluster_deleted', + eks_model('ClusterDeleted'), + core_waiter.NormalizedOperationMethod( + eks.describe_cluster + )), + ('RDS', 'db_instance_stopped'): lambda rds: core_waiter.Waiter( + 'db_instance_stopped', + rds_model('DBInstanceStopped'), + core_waiter.NormalizedOperationMethod( + rds.describe_db_instances + )), +} + + +def get_waiter(client, waiter_name): + try: + return waiters_by_name[(client.__class__.__name__, waiter_name)](client) + except KeyError: + raise NotImplementedError("Waiter {0} could not be found for client {1}. Available waiters: {2}".format( + waiter_name, type(client), ', '.join(repr(k) for k in waiters_by_name.keys()))) diff --git a/test/support/integration/plugins/module_utils/azure_rm_common.py b/test/support/integration/plugins/module_utils/azure_rm_common.py new file mode 100644 index 00000000..a7b55e97 --- /dev/null +++ b/test/support/integration/plugins/module_utils/azure_rm_common.py @@ -0,0 +1,1473 @@ +# Copyright (c) 2016 Matt Davis, <mdavis@ansible.com> +# Chris Houseknecht, <house@redhat.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import os +import re +import types +import copy +import inspect +import traceback +import json + +from os.path import expanduser + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +try: + from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION +except Exception: + ANSIBLE_VERSION = 'unknown' +from ansible.module_utils.six.moves import configparser +import ansible.module_utils.six.moves.urllib.parse as urlparse + +AZURE_COMMON_ARGS = dict( + auth_source=dict( + type='str', + choices=['auto', 'cli', 'env', 'credential_file', 'msi'] + ), + profile=dict(type='str'), + subscription_id=dict(type='str'), + client_id=dict(type='str', no_log=True), + secret=dict(type='str', no_log=True), + tenant=dict(type='str', no_log=True), + ad_user=dict(type='str', no_log=True), + password=dict(type='str', no_log=True), + cloud_environment=dict(type='str', default='AzureCloud'), + cert_validation_mode=dict(type='str', choices=['validate', 'ignore']), + api_profile=dict(type='str', default='latest'), + adfs_authority_url=dict(type='str', default=None) +) + +AZURE_CREDENTIAL_ENV_MAPPING = dict( + profile='AZURE_PROFILE', + subscription_id='AZURE_SUBSCRIPTION_ID', + client_id='AZURE_CLIENT_ID', + secret='AZURE_SECRET', + tenant='AZURE_TENANT', + ad_user='AZURE_AD_USER', + password='AZURE_PASSWORD', + cloud_environment='AZURE_CLOUD_ENVIRONMENT', + cert_validation_mode='AZURE_CERT_VALIDATION_MODE', + adfs_authority_url='AZURE_ADFS_AUTHORITY_URL' +) + + +class SDKProfile(object): # pylint: disable=too-few-public-methods + + def __init__(self, default_api_version, profile=None): + """Constructor. + + :param str default_api_version: Default API version if not overridden by a profile. Nullable. + :param profile: A dict operation group name to API version. + :type profile: dict[str, str] + """ + self.profile = profile if profile is not None else {} + self.profile[None] = default_api_version + + @property + def default_api_version(self): + return self.profile[None] + + +# FUTURE: this should come from the SDK or an external location. +# For now, we have to copy from azure-cli +AZURE_API_PROFILES = { + 'latest': { + 'ContainerInstanceManagementClient': '2018-02-01-preview', + 'ComputeManagementClient': dict( + default_api_version='2018-10-01', + resource_skus='2018-10-01', + disks='2018-06-01', + snapshots='2018-10-01', + virtual_machine_run_commands='2018-10-01' + ), + 'NetworkManagementClient': '2018-08-01', + 'ResourceManagementClient': '2017-05-10', + 'StorageManagementClient': '2017-10-01', + 'WebSiteManagementClient': '2018-02-01', + 'PostgreSQLManagementClient': '2017-12-01', + 'MySQLManagementClient': '2017-12-01', + 'MariaDBManagementClient': '2019-03-01', + 'ManagementLockClient': '2016-09-01' + }, + '2019-03-01-hybrid': { + 'StorageManagementClient': '2017-10-01', + 'NetworkManagementClient': '2017-10-01', + 'ComputeManagementClient': SDKProfile('2017-12-01', { + 'resource_skus': '2017-09-01', + 'disks': '2017-03-30', + 'snapshots': '2017-03-30' + }), + 'ManagementLinkClient': '2016-09-01', + 'ManagementLockClient': '2016-09-01', + 'PolicyClient': '2016-12-01', + 'ResourceManagementClient': '2018-05-01', + 'SubscriptionClient': '2016-06-01', + 'DnsManagementClient': '2016-04-01', + 'KeyVaultManagementClient': '2016-10-01', + 'AuthorizationManagementClient': SDKProfile('2015-07-01', { + 'classic_administrators': '2015-06-01', + 'policy_assignments': '2016-12-01', + 'policy_definitions': '2016-12-01' + }), + 'KeyVaultClient': '2016-10-01', + 'azure.multiapi.storage': '2017-11-09', + 'azure.multiapi.cosmosdb': '2017-04-17' + }, + '2018-03-01-hybrid': { + 'StorageManagementClient': '2016-01-01', + 'NetworkManagementClient': '2017-10-01', + 'ComputeManagementClient': SDKProfile('2017-03-30'), + 'ManagementLinkClient': '2016-09-01', + 'ManagementLockClient': '2016-09-01', + 'PolicyClient': '2016-12-01', + 'ResourceManagementClient': '2018-02-01', + 'SubscriptionClient': '2016-06-01', + 'DnsManagementClient': '2016-04-01', + 'KeyVaultManagementClient': '2016-10-01', + 'AuthorizationManagementClient': SDKProfile('2015-07-01', { + 'classic_administrators': '2015-06-01' + }), + 'KeyVaultClient': '2016-10-01', + 'azure.multiapi.storage': '2017-04-17', + 'azure.multiapi.cosmosdb': '2017-04-17' + }, + '2017-03-09-profile': { + 'StorageManagementClient': '2016-01-01', + 'NetworkManagementClient': '2015-06-15', + 'ComputeManagementClient': SDKProfile('2016-03-30'), + 'ManagementLinkClient': '2016-09-01', + 'ManagementLockClient': '2015-01-01', + 'PolicyClient': '2015-10-01-preview', + 'ResourceManagementClient': '2016-02-01', + 'SubscriptionClient': '2016-06-01', + 'DnsManagementClient': '2016-04-01', + 'KeyVaultManagementClient': '2016-10-01', + 'AuthorizationManagementClient': SDKProfile('2015-07-01', { + 'classic_administrators': '2015-06-01' + }), + 'KeyVaultClient': '2016-10-01', + 'azure.multiapi.storage': '2015-04-05' + } +} + +AZURE_TAG_ARGS = dict( + tags=dict(type='dict'), + append_tags=dict(type='bool', default=True), +) + +AZURE_COMMON_REQUIRED_IF = [ + ('log_mode', 'file', ['log_path']) +] + +ANSIBLE_USER_AGENT = 'Ansible/{0}'.format(ANSIBLE_VERSION) +CLOUDSHELL_USER_AGENT_KEY = 'AZURE_HTTP_USER_AGENT' +VSCODEEXT_USER_AGENT_KEY = 'VSCODEEXT_USER_AGENT' + +CIDR_PATTERN = re.compile(r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1" + r"[0-9]{2}|2[0-4][0-9]|25[0-5])(/([0-9]|[1-2][0-9]|3[0-2]))") + +AZURE_SUCCESS_STATE = "Succeeded" +AZURE_FAILED_STATE = "Failed" + +HAS_AZURE = True +HAS_AZURE_EXC = None +HAS_AZURE_CLI_CORE = True +HAS_AZURE_CLI_CORE_EXC = None + +HAS_MSRESTAZURE = True +HAS_MSRESTAZURE_EXC = None + +try: + import importlib +except ImportError: + # This passes the sanity import test, but does not provide a user friendly error message. + # Doing so would require catching Exception for all imports of Azure dependencies in modules and module_utils. + importlib = None + +try: + from packaging.version import Version + HAS_PACKAGING_VERSION = True + HAS_PACKAGING_VERSION_EXC = None +except ImportError: + Version = None + HAS_PACKAGING_VERSION = False + HAS_PACKAGING_VERSION_EXC = traceback.format_exc() + +# NB: packaging issue sometimes cause msrestazure not to be installed, check it separately +try: + from msrest.serialization import Serializer +except ImportError: + HAS_MSRESTAZURE_EXC = traceback.format_exc() + HAS_MSRESTAZURE = False + +try: + from enum import Enum + from msrestazure.azure_active_directory import AADTokenCredentials + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_active_directory import MSIAuthentication + from msrestazure.tools import parse_resource_id, resource_id, is_valid_resource_id + from msrestazure import azure_cloud + from azure.common.credentials import ServicePrincipalCredentials, UserPassCredentials + from azure.mgmt.monitor.version import VERSION as monitor_client_version + from azure.mgmt.network.version import VERSION as network_client_version + from azure.mgmt.storage.version import VERSION as storage_client_version + from azure.mgmt.compute.version import VERSION as compute_client_version + from azure.mgmt.resource.version import VERSION as resource_client_version + from azure.mgmt.dns.version import VERSION as dns_client_version + from azure.mgmt.web.version import VERSION as web_client_version + from azure.mgmt.network import NetworkManagementClient + from azure.mgmt.resource.resources import ResourceManagementClient + from azure.mgmt.resource.subscriptions import SubscriptionClient + from azure.mgmt.storage import StorageManagementClient + from azure.mgmt.compute import ComputeManagementClient + from azure.mgmt.dns import DnsManagementClient + from azure.mgmt.monitor import MonitorManagementClient + from azure.mgmt.web import WebSiteManagementClient + from azure.mgmt.containerservice import ContainerServiceClient + from azure.mgmt.marketplaceordering import MarketplaceOrderingAgreements + from azure.mgmt.trafficmanager import TrafficManagerManagementClient + from azure.storage.cloudstorageaccount import CloudStorageAccount + from azure.storage.blob import PageBlobService, BlockBlobService + from adal.authentication_context import AuthenticationContext + from azure.mgmt.sql import SqlManagementClient + from azure.mgmt.servicebus import ServiceBusManagementClient + import azure.mgmt.servicebus.models as ServicebusModel + from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient + from azure.mgmt.rdbms.mysql import MySQLManagementClient + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from azure.mgmt.containerregistry import ContainerRegistryManagementClient + from azure.mgmt.containerinstance import ContainerInstanceManagementClient + from azure.mgmt.loganalytics import LogAnalyticsManagementClient + import azure.mgmt.loganalytics.models as LogAnalyticsModels + from azure.mgmt.automation import AutomationClient + import azure.mgmt.automation.models as AutomationModel + from azure.mgmt.iothub import IotHubClient + from azure.mgmt.iothub import models as IoTHubModels + from msrest.service_client import ServiceClient + from msrestazure import AzureConfiguration + from msrest.authentication import Authentication + from azure.mgmt.resource.locks import ManagementLockClient +except ImportError as exc: + Authentication = object + HAS_AZURE_EXC = traceback.format_exc() + HAS_AZURE = False + +from base64 import b64encode, b64decode +from hashlib import sha256 +from hmac import HMAC +from time import time + +try: + from urllib import (urlencode, quote_plus) +except ImportError: + from urllib.parse import (urlencode, quote_plus) + +try: + from azure.cli.core.util import CLIError + from azure.common.credentials import get_azure_cli_credentials, get_cli_profile + from azure.common.cloud import get_cli_active_cloud +except ImportError: + HAS_AZURE_CLI_CORE = False + HAS_AZURE_CLI_CORE_EXC = None + CLIError = Exception + + +def azure_id_to_dict(id): + pieces = re.sub(r'^\/', '', id).split('/') + result = {} + index = 0 + while index < len(pieces) - 1: + result[pieces[index]] = pieces[index + 1] + index += 1 + return result + + +def format_resource_id(val, subscription_id, namespace, types, resource_group): + return resource_id(name=val, + resource_group=resource_group, + namespace=namespace, + type=types, + subscription=subscription_id) if not is_valid_resource_id(val) else val + + +def normalize_location_name(name): + return name.replace(' ', '').lower() + + +# FUTURE: either get this from the requirements file (if we can be sure it's always available at runtime) +# or generate the requirements files from this so we only have one source of truth to maintain... +AZURE_PKG_VERSIONS = { + 'StorageManagementClient': { + 'package_name': 'storage', + 'expected_version': '3.1.0' + }, + 'ComputeManagementClient': { + 'package_name': 'compute', + 'expected_version': '4.4.0' + }, + 'ContainerInstanceManagementClient': { + 'package_name': 'containerinstance', + 'expected_version': '0.4.0' + }, + 'NetworkManagementClient': { + 'package_name': 'network', + 'expected_version': '2.3.0' + }, + 'ResourceManagementClient': { + 'package_name': 'resource', + 'expected_version': '2.1.0' + }, + 'DnsManagementClient': { + 'package_name': 'dns', + 'expected_version': '2.1.0' + }, + 'WebSiteManagementClient': { + 'package_name': 'web', + 'expected_version': '0.41.0' + }, + 'TrafficManagerManagementClient': { + 'package_name': 'trafficmanager', + 'expected_version': '0.50.0' + }, +} if HAS_AZURE else {} + + +AZURE_MIN_RELEASE = '2.0.0' + + +class AzureRMModuleBase(object): + def __init__(self, derived_arg_spec, bypass_checks=False, no_log=False, + mutually_exclusive=None, required_together=None, + required_one_of=None, add_file_common_args=False, supports_check_mode=False, + required_if=None, supports_tags=True, facts_module=False, skip_exec=False): + + merged_arg_spec = dict() + merged_arg_spec.update(AZURE_COMMON_ARGS) + if supports_tags: + merged_arg_spec.update(AZURE_TAG_ARGS) + + if derived_arg_spec: + merged_arg_spec.update(derived_arg_spec) + + merged_required_if = list(AZURE_COMMON_REQUIRED_IF) + if required_if: + merged_required_if += required_if + + self.module = AnsibleModule(argument_spec=merged_arg_spec, + bypass_checks=bypass_checks, + no_log=no_log, + mutually_exclusive=mutually_exclusive, + required_together=required_together, + required_one_of=required_one_of, + add_file_common_args=add_file_common_args, + supports_check_mode=supports_check_mode, + required_if=merged_required_if) + + if not HAS_PACKAGING_VERSION: + self.fail(msg=missing_required_lib('packaging'), + exception=HAS_PACKAGING_VERSION_EXC) + + if not HAS_MSRESTAZURE: + self.fail(msg=missing_required_lib('msrestazure'), + exception=HAS_MSRESTAZURE_EXC) + + if not HAS_AZURE: + self.fail(msg=missing_required_lib('ansible[azure] (azure >= {0})'.format(AZURE_MIN_RELEASE)), + exception=HAS_AZURE_EXC) + + self._network_client = None + self._storage_client = None + self._resource_client = None + self._compute_client = None + self._dns_client = None + self._web_client = None + self._marketplace_client = None + self._sql_client = None + self._mysql_client = None + self._mariadb_client = None + self._postgresql_client = None + self._containerregistry_client = None + self._containerinstance_client = None + self._containerservice_client = None + self._managedcluster_client = None + self._traffic_manager_management_client = None + self._monitor_client = None + self._resource = None + self._log_analytics_client = None + self._servicebus_client = None + self._automation_client = None + self._IoThub_client = None + self._lock_client = None + + self.check_mode = self.module.check_mode + self.api_profile = self.module.params.get('api_profile') + self.facts_module = facts_module + # self.debug = self.module.params.get('debug') + + # delegate auth to AzureRMAuth class (shared with all plugin types) + self.azure_auth = AzureRMAuth(fail_impl=self.fail, **self.module.params) + + # common parameter validation + if self.module.params.get('tags'): + self.validate_tags(self.module.params['tags']) + + if not skip_exec: + res = self.exec_module(**self.module.params) + self.module.exit_json(**res) + + def check_client_version(self, client_type): + # Ensure Azure modules are at least 2.0.0rc5. + package_version = AZURE_PKG_VERSIONS.get(client_type.__name__, None) + if package_version is not None: + client_name = package_version.get('package_name') + try: + client_module = importlib.import_module(client_type.__module__) + client_version = client_module.VERSION + except (RuntimeError, AttributeError): + # can't get at the module version for some reason, just fail silently... + return + expected_version = package_version.get('expected_version') + if Version(client_version) < Version(expected_version): + self.fail("Installed azure-mgmt-{0} client version is {1}. The minimum supported version is {2}. Try " + "`pip install ansible[azure]`".format(client_name, client_version, expected_version)) + if Version(client_version) != Version(expected_version): + self.module.warn("Installed azure-mgmt-{0} client version is {1}. The expected version is {2}. Try " + "`pip install ansible[azure]`".format(client_name, client_version, expected_version)) + + def exec_module(self, **kwargs): + self.fail("Error: {0} failed to implement exec_module method.".format(self.__class__.__name__)) + + def fail(self, msg, **kwargs): + ''' + Shortcut for calling module.fail() + + :param msg: Error message text. + :param kwargs: Any key=value pairs + :return: None + ''' + self.module.fail_json(msg=msg, **kwargs) + + def deprecate(self, msg, version=None, collection_name=None): + self.module.deprecate(msg, version, collection_name=collection_name) + + def log(self, msg, pretty_print=False): + if pretty_print: + self.module.debug(json.dumps(msg, indent=4, sort_keys=True)) + else: + self.module.debug(msg) + + def validate_tags(self, tags): + ''' + Check if tags dictionary contains string:string pairs. + + :param tags: dictionary of string:string pairs + :return: None + ''' + if not self.facts_module: + if not isinstance(tags, dict): + self.fail("Tags must be a dictionary of string:string values.") + for key, value in tags.items(): + if not isinstance(value, str): + self.fail("Tags values must be strings. Found {0}:{1}".format(str(key), str(value))) + + def update_tags(self, tags): + ''' + Call from the module to update metadata tags. Returns tuple + with bool indicating if there was a change and dict of new + tags to assign to the object. + + :param tags: metadata tags from the object + :return: bool, dict + ''' + tags = tags or dict() + new_tags = copy.copy(tags) if isinstance(tags, dict) else dict() + param_tags = self.module.params.get('tags') if isinstance(self.module.params.get('tags'), dict) else dict() + append_tags = self.module.params.get('append_tags') if self.module.params.get('append_tags') is not None else True + changed = False + # check add or update + for key, value in param_tags.items(): + if not new_tags.get(key) or new_tags[key] != value: + changed = True + new_tags[key] = value + # check remove + if not append_tags: + for key, value in tags.items(): + if not param_tags.get(key): + new_tags.pop(key) + changed = True + return changed, new_tags + + def has_tags(self, obj_tags, tag_list): + ''' + Used in fact modules to compare object tags to list of parameter tags. Return true if list of parameter tags + exists in object tags. + + :param obj_tags: dictionary of tags from an Azure object. + :param tag_list: list of tag keys or tag key:value pairs + :return: bool + ''' + + if not obj_tags and tag_list: + return False + + if not tag_list: + return True + + matches = 0 + result = False + for tag in tag_list: + tag_key = tag + tag_value = None + if ':' in tag: + tag_key, tag_value = tag.split(':') + if tag_value and obj_tags.get(tag_key) == tag_value: + matches += 1 + elif not tag_value and obj_tags.get(tag_key): + matches += 1 + if matches == len(tag_list): + result = True + return result + + def get_resource_group(self, resource_group): + ''' + Fetch a resource group. + + :param resource_group: name of a resource group + :return: resource group object + ''' + try: + return self.rm_client.resource_groups.get(resource_group) + except CloudError as cloud_error: + self.fail("Error retrieving resource group {0} - {1}".format(resource_group, cloud_error.message)) + except Exception as exc: + self.fail("Error retrieving resource group {0} - {1}".format(resource_group, str(exc))) + + def parse_resource_to_dict(self, resource): + ''' + Return a dict of the give resource, which contains name and resource group. + + :param resource: It can be a resource name, id or a dict contains name and resource group. + ''' + resource_dict = parse_resource_id(resource) if not isinstance(resource, dict) else resource + resource_dict['resource_group'] = resource_dict.get('resource_group', self.resource_group) + resource_dict['subscription_id'] = resource_dict.get('subscription_id', self.subscription_id) + return resource_dict + + def serialize_obj(self, obj, class_name, enum_modules=None): + ''' + Return a JSON representation of an Azure object. + + :param obj: Azure object + :param class_name: Name of the object's class + :param enum_modules: List of module names to build enum dependencies from. + :return: serialized result + ''' + enum_modules = [] if enum_modules is None else enum_modules + + dependencies = dict() + if enum_modules: + for module_name in enum_modules: + mod = importlib.import_module(module_name) + for mod_class_name, mod_class_obj in inspect.getmembers(mod, predicate=inspect.isclass): + dependencies[mod_class_name] = mod_class_obj + self.log("dependencies: ") + self.log(str(dependencies)) + serializer = Serializer(classes=dependencies) + return serializer.body(obj, class_name, keep_readonly=True) + + def get_poller_result(self, poller, wait=5): + ''' + Consistent method of waiting on and retrieving results from Azure's long poller + + :param poller Azure poller object + :return object resulting from the original request + ''' + try: + delay = wait + while not poller.done(): + self.log("Waiting for {0} sec".format(delay)) + poller.wait(timeout=delay) + return poller.result() + except Exception as exc: + self.log(str(exc)) + raise + + def check_provisioning_state(self, azure_object, requested_state='present'): + ''' + Check an Azure object's provisioning state. If something did not complete the provisioning + process, then we cannot operate on it. + + :param azure_object An object such as a subnet, storageaccount, etc. Must have provisioning_state + and name attributes. + :return None + ''' + + if hasattr(azure_object, 'properties') and hasattr(azure_object.properties, 'provisioning_state') and \ + hasattr(azure_object, 'name'): + # resource group object fits this model + if isinstance(azure_object.properties.provisioning_state, Enum): + if azure_object.properties.provisioning_state.value != AZURE_SUCCESS_STATE and \ + requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format( + azure_object.name, azure_object.properties.provisioning_state, AZURE_SUCCESS_STATE)) + return + if azure_object.properties.provisioning_state != AZURE_SUCCESS_STATE and \ + requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format( + azure_object.name, azure_object.properties.provisioning_state, AZURE_SUCCESS_STATE)) + return + + if hasattr(azure_object, 'provisioning_state') or not hasattr(azure_object, 'name'): + if isinstance(azure_object.provisioning_state, Enum): + if azure_object.provisioning_state.value != AZURE_SUCCESS_STATE and requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format( + azure_object.name, azure_object.provisioning_state, AZURE_SUCCESS_STATE)) + return + if azure_object.provisioning_state != AZURE_SUCCESS_STATE and requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format( + azure_object.name, azure_object.provisioning_state, AZURE_SUCCESS_STATE)) + + def get_blob_client(self, resource_group_name, storage_account_name, storage_blob_type='block'): + keys = dict() + try: + # Get keys from the storage account + self.log('Getting keys') + account_keys = self.storage_client.storage_accounts.list_keys(resource_group_name, storage_account_name) + except Exception as exc: + self.fail("Error getting keys for account {0} - {1}".format(storage_account_name, str(exc))) + + try: + self.log('Create blob service') + if storage_blob_type == 'page': + return PageBlobService(endpoint_suffix=self._cloud_environment.suffixes.storage_endpoint, + account_name=storage_account_name, + account_key=account_keys.keys[0].value) + elif storage_blob_type == 'block': + return BlockBlobService(endpoint_suffix=self._cloud_environment.suffixes.storage_endpoint, + account_name=storage_account_name, + account_key=account_keys.keys[0].value) + else: + raise Exception("Invalid storage blob type defined.") + except Exception as exc: + self.fail("Error creating blob service client for storage account {0} - {1}".format(storage_account_name, + str(exc))) + + def create_default_pip(self, resource_group, location, public_ip_name, allocation_method='Dynamic', sku=None): + ''' + Create a default public IP address <public_ip_name> to associate with a network interface. + If a PIP address matching <public_ip_name> exists, return it. Otherwise, create one. + + :param resource_group: name of an existing resource group + :param location: a valid azure location + :param public_ip_name: base name to assign the public IP address + :param allocation_method: one of 'Static' or 'Dynamic' + :param sku: sku + :return: PIP object + ''' + pip = None + + self.log("Starting create_default_pip {0}".format(public_ip_name)) + self.log("Check to see if public IP {0} exists".format(public_ip_name)) + try: + pip = self.network_client.public_ip_addresses.get(resource_group, public_ip_name) + except CloudError: + pass + + if pip: + self.log("Public ip {0} found.".format(public_ip_name)) + self.check_provisioning_state(pip) + return pip + + params = self.network_models.PublicIPAddress( + location=location, + public_ip_allocation_method=allocation_method, + sku=sku + ) + self.log('Creating default public IP {0}'.format(public_ip_name)) + try: + poller = self.network_client.public_ip_addresses.create_or_update(resource_group, public_ip_name, params) + except Exception as exc: + self.fail("Error creating {0} - {1}".format(public_ip_name, str(exc))) + + return self.get_poller_result(poller) + + def create_default_securitygroup(self, resource_group, location, security_group_name, os_type, open_ports): + ''' + Create a default security group <security_group_name> to associate with a network interface. If a security group matching + <security_group_name> exists, return it. Otherwise, create one. + + :param resource_group: Resource group name + :param location: azure location name + :param security_group_name: base name to use for the security group + :param os_type: one of 'Windows' or 'Linux'. Determins any default rules added to the security group. + :param ssh_port: for os_type 'Linux' port used in rule allowing SSH access. + :param rdp_port: for os_type 'Windows' port used in rule allowing RDP access. + :return: security_group object + ''' + group = None + + self.log("Create security group {0}".format(security_group_name)) + self.log("Check to see if security group {0} exists".format(security_group_name)) + try: + group = self.network_client.network_security_groups.get(resource_group, security_group_name) + except CloudError: + pass + + if group: + self.log("Security group {0} found.".format(security_group_name)) + self.check_provisioning_state(group) + return group + + parameters = self.network_models.NetworkSecurityGroup() + parameters.location = location + + if not open_ports: + # Open default ports based on OS type + if os_type == 'Linux': + # add an inbound SSH rule + parameters.security_rules = [ + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + description='Allow SSH Access', + source_port_range='*', + destination_port_range='22', + priority=100, + name='SSH') + ] + parameters.location = location + else: + # for windows add inbound RDP and WinRM rules + parameters.security_rules = [ + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + description='Allow RDP port 3389', + source_port_range='*', + destination_port_range='3389', + priority=100, + name='RDP01'), + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + description='Allow WinRM HTTPS port 5986', + source_port_range='*', + destination_port_range='5986', + priority=101, + name='WinRM01'), + ] + else: + # Open custom ports + parameters.security_rules = [] + priority = 100 + for port in open_ports: + priority += 1 + rule_name = "Rule_{0}".format(priority) + parameters.security_rules.append( + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + source_port_range='*', + destination_port_range=str(port), + priority=priority, + name=rule_name) + ) + + self.log('Creating default security group {0}'.format(security_group_name)) + try: + poller = self.network_client.network_security_groups.create_or_update(resource_group, + security_group_name, + parameters) + except Exception as exc: + self.fail("Error creating default security rule {0} - {1}".format(security_group_name, str(exc))) + + return self.get_poller_result(poller) + + @staticmethod + def _validation_ignore_callback(session, global_config, local_config, **kwargs): + session.verify = False + + def get_api_profile(self, client_type_name, api_profile_name): + profile_all_clients = AZURE_API_PROFILES.get(api_profile_name) + + if not profile_all_clients: + raise KeyError("unknown Azure API profile: {0}".format(api_profile_name)) + + profile_raw = profile_all_clients.get(client_type_name, None) + + if not profile_raw: + self.module.warn("Azure API profile {0} does not define an entry for {1}".format(api_profile_name, client_type_name)) + + if isinstance(profile_raw, dict): + if not profile_raw.get('default_api_version'): + raise KeyError("Azure API profile {0} does not define 'default_api_version'".format(api_profile_name)) + return profile_raw + + # wrap basic strings in a dict that just defines the default + return dict(default_api_version=profile_raw) + + def get_mgmt_svc_client(self, client_type, base_url=None, api_version=None): + self.log('Getting management service client {0}'.format(client_type.__name__)) + self.check_client_version(client_type) + + client_argspec = inspect.getargspec(client_type.__init__) + + if not base_url: + # most things are resource_manager, don't make everyone specify + base_url = self.azure_auth._cloud_environment.endpoints.resource_manager + + client_kwargs = dict(credentials=self.azure_auth.azure_credentials, subscription_id=self.azure_auth.subscription_id, base_url=base_url) + + api_profile_dict = {} + + if self.api_profile: + api_profile_dict = self.get_api_profile(client_type.__name__, self.api_profile) + + # unversioned clients won't accept profile; only send it if necessary + # clients without a version specified in the profile will use the default + if api_profile_dict and 'profile' in client_argspec.args: + client_kwargs['profile'] = api_profile_dict + + # If the client doesn't accept api_version, it's unversioned. + # If it does, favor explicitly-specified api_version, fall back to api_profile + if 'api_version' in client_argspec.args: + profile_default_version = api_profile_dict.get('default_api_version', None) + if api_version or profile_default_version: + client_kwargs['api_version'] = api_version or profile_default_version + if 'profile' in client_kwargs: + # remove profile; only pass API version if specified + client_kwargs.pop('profile') + + client = client_type(**client_kwargs) + + # FUTURE: remove this once everything exposes models directly (eg, containerinstance) + try: + getattr(client, "models") + except AttributeError: + def _ansible_get_models(self, *arg, **kwarg): + return self._ansible_models + + setattr(client, '_ansible_models', importlib.import_module(client_type.__module__).models) + client.models = types.MethodType(_ansible_get_models, client) + + client.config = self.add_user_agent(client.config) + + if self.azure_auth._cert_validation_mode == 'ignore': + client.config.session_configuration_callback = self._validation_ignore_callback + + return client + + def add_user_agent(self, config): + # Add user agent for Ansible + config.add_user_agent(ANSIBLE_USER_AGENT) + # Add user agent when running from Cloud Shell + if CLOUDSHELL_USER_AGENT_KEY in os.environ: + config.add_user_agent(os.environ[CLOUDSHELL_USER_AGENT_KEY]) + # Add user agent when running from VSCode extension + if VSCODEEXT_USER_AGENT_KEY in os.environ: + config.add_user_agent(os.environ[VSCODEEXT_USER_AGENT_KEY]) + return config + + def generate_sas_token(self, **kwags): + base_url = kwags.get('base_url', None) + expiry = kwags.get('expiry', time() + 3600) + key = kwags.get('key', None) + policy = kwags.get('policy', None) + url = quote_plus(base_url) + ttl = int(expiry) + sign_key = '{0}\n{1}'.format(url, ttl) + signature = b64encode(HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest()) + result = { + 'sr': url, + 'sig': signature, + 'se': str(ttl), + } + if policy: + result['skn'] = policy + return 'SharedAccessSignature ' + urlencode(result) + + def get_data_svc_client(self, **kwags): + url = kwags.get('base_url', None) + config = AzureConfiguration(base_url='https://{0}'.format(url)) + config.credentials = AzureSASAuthentication(token=self.generate_sas_token(**kwags)) + config = self.add_user_agent(config) + return ServiceClient(creds=config.credentials, config=config) + + # passthru methods to AzureAuth instance for backcompat + @property + def credentials(self): + return self.azure_auth.credentials + + @property + def _cloud_environment(self): + return self.azure_auth._cloud_environment + + @property + def subscription_id(self): + return self.azure_auth.subscription_id + + @property + def storage_client(self): + self.log('Getting storage client...') + if not self._storage_client: + self._storage_client = self.get_mgmt_svc_client(StorageManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-07-01') + return self._storage_client + + @property + def storage_models(self): + return StorageManagementClient.models("2018-07-01") + + @property + def network_client(self): + self.log('Getting network client') + if not self._network_client: + self._network_client = self.get_mgmt_svc_client(NetworkManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2019-06-01') + return self._network_client + + @property + def network_models(self): + self.log("Getting network models...") + return NetworkManagementClient.models("2018-08-01") + + @property + def rm_client(self): + self.log('Getting resource manager client') + if not self._resource_client: + self._resource_client = self.get_mgmt_svc_client(ResourceManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2017-05-10') + return self._resource_client + + @property + def rm_models(self): + self.log("Getting resource manager models") + return ResourceManagementClient.models("2017-05-10") + + @property + def compute_client(self): + self.log('Getting compute client') + if not self._compute_client: + self._compute_client = self.get_mgmt_svc_client(ComputeManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2019-07-01') + return self._compute_client + + @property + def compute_models(self): + self.log("Getting compute models") + return ComputeManagementClient.models("2019-07-01") + + @property + def dns_client(self): + self.log('Getting dns client') + if not self._dns_client: + self._dns_client = self.get_mgmt_svc_client(DnsManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-05-01') + return self._dns_client + + @property + def dns_models(self): + self.log("Getting dns models...") + return DnsManagementClient.models('2018-05-01') + + @property + def web_client(self): + self.log('Getting web client') + if not self._web_client: + self._web_client = self.get_mgmt_svc_client(WebSiteManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-02-01') + return self._web_client + + @property + def containerservice_client(self): + self.log('Getting container service client') + if not self._containerservice_client: + self._containerservice_client = self.get_mgmt_svc_client(ContainerServiceClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2017-07-01') + return self._containerservice_client + + @property + def managedcluster_models(self): + self.log("Getting container service models") + return ContainerServiceClient.models('2018-03-31') + + @property + def managedcluster_client(self): + self.log('Getting container service client') + if not self._managedcluster_client: + self._managedcluster_client = self.get_mgmt_svc_client(ContainerServiceClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-03-31') + return self._managedcluster_client + + @property + def sql_client(self): + self.log('Getting SQL client') + if not self._sql_client: + self._sql_client = self.get_mgmt_svc_client(SqlManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._sql_client + + @property + def postgresql_client(self): + self.log('Getting PostgreSQL client') + if not self._postgresql_client: + self._postgresql_client = self.get_mgmt_svc_client(PostgreSQLManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._postgresql_client + + @property + def mysql_client(self): + self.log('Getting MySQL client') + if not self._mysql_client: + self._mysql_client = self.get_mgmt_svc_client(MySQLManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._mysql_client + + @property + def mariadb_client(self): + self.log('Getting MariaDB client') + if not self._mariadb_client: + self._mariadb_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._mariadb_client + + @property + def sql_client(self): + self.log('Getting SQL client') + if not self._sql_client: + self._sql_client = self.get_mgmt_svc_client(SqlManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._sql_client + + @property + def containerregistry_client(self): + self.log('Getting container registry mgmt client') + if not self._containerregistry_client: + self._containerregistry_client = self.get_mgmt_svc_client(ContainerRegistryManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2017-10-01') + + return self._containerregistry_client + + @property + def containerinstance_client(self): + self.log('Getting container instance mgmt client') + if not self._containerinstance_client: + self._containerinstance_client = self.get_mgmt_svc_client(ContainerInstanceManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-06-01') + + return self._containerinstance_client + + @property + def marketplace_client(self): + self.log('Getting marketplace agreement client') + if not self._marketplace_client: + self._marketplace_client = self.get_mgmt_svc_client(MarketplaceOrderingAgreements, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._marketplace_client + + @property + def traffic_manager_management_client(self): + self.log('Getting traffic manager client') + if not self._traffic_manager_management_client: + self._traffic_manager_management_client = self.get_mgmt_svc_client(TrafficManagerManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._traffic_manager_management_client + + @property + def monitor_client(self): + self.log('Getting monitor client') + if not self._monitor_client: + self._monitor_client = self.get_mgmt_svc_client(MonitorManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._monitor_client + + @property + def log_analytics_client(self): + self.log('Getting log analytics client') + if not self._log_analytics_client: + self._log_analytics_client = self.get_mgmt_svc_client(LogAnalyticsManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._log_analytics_client + + @property + def log_analytics_models(self): + self.log('Getting log analytics models') + return LogAnalyticsModels + + @property + def servicebus_client(self): + self.log('Getting servicebus client') + if not self._servicebus_client: + self._servicebus_client = self.get_mgmt_svc_client(ServiceBusManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._servicebus_client + + @property + def servicebus_models(self): + return ServicebusModel + + @property + def automation_client(self): + self.log('Getting automation client') + if not self._automation_client: + self._automation_client = self.get_mgmt_svc_client(AutomationClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._automation_client + + @property + def automation_models(self): + return AutomationModel + + @property + def IoThub_client(self): + self.log('Getting iothub client') + if not self._IoThub_client: + self._IoThub_client = self.get_mgmt_svc_client(IotHubClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._IoThub_client + + @property + def IoThub_models(self): + return IoTHubModels + + @property + def automation_client(self): + self.log('Getting automation client') + if not self._automation_client: + self._automation_client = self.get_mgmt_svc_client(AutomationClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._automation_client + + @property + def automation_models(self): + return AutomationModel + + @property + def lock_client(self): + self.log('Getting lock client') + if not self._lock_client: + self._lock_client = self.get_mgmt_svc_client(ManagementLockClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2016-09-01') + return self._lock_client + + @property + def lock_models(self): + self.log("Getting lock models") + return ManagementLockClient.models('2016-09-01') + + +class AzureSASAuthentication(Authentication): + """Simple SAS Authentication. + An implementation of Authentication in + https://github.com/Azure/msrest-for-python/blob/0732bc90bdb290e5f58c675ffdd7dbfa9acefc93/msrest/authentication.py + + :param str token: SAS token + """ + def __init__(self, token): + self.token = token + + def signed_session(self): + session = super(AzureSASAuthentication, self).signed_session() + session.headers['Authorization'] = self.token + return session + + +class AzureRMAuthException(Exception): + pass + + +class AzureRMAuth(object): + def __init__(self, auth_source='auto', profile=None, subscription_id=None, client_id=None, secret=None, + tenant=None, ad_user=None, password=None, cloud_environment='AzureCloud', cert_validation_mode='validate', + api_profile='latest', adfs_authority_url=None, fail_impl=None, **kwargs): + + if fail_impl: + self._fail_impl = fail_impl + else: + self._fail_impl = self._default_fail_impl + + self._cloud_environment = None + self._adfs_authority_url = None + + # authenticate + self.credentials = self._get_credentials( + dict(auth_source=auth_source, profile=profile, subscription_id=subscription_id, client_id=client_id, secret=secret, + tenant=tenant, ad_user=ad_user, password=password, cloud_environment=cloud_environment, + cert_validation_mode=cert_validation_mode, api_profile=api_profile, adfs_authority_url=adfs_authority_url)) + + if not self.credentials: + if HAS_AZURE_CLI_CORE: + self.fail("Failed to get credentials. Either pass as parameters, set environment variables, " + "define a profile in ~/.azure/credentials, or log in with Azure CLI (`az login`).") + else: + self.fail("Failed to get credentials. Either pass as parameters, set environment variables, " + "define a profile in ~/.azure/credentials, or install Azure CLI and log in (`az login`).") + + # cert validation mode precedence: module-arg, credential profile, env, "validate" + self._cert_validation_mode = cert_validation_mode or self.credentials.get('cert_validation_mode') or \ + os.environ.get('AZURE_CERT_VALIDATION_MODE') or 'validate' + + if self._cert_validation_mode not in ['validate', 'ignore']: + self.fail('invalid cert_validation_mode: {0}'.format(self._cert_validation_mode)) + + # if cloud_environment specified, look up/build Cloud object + raw_cloud_env = self.credentials.get('cloud_environment') + if self.credentials.get('credentials') is not None and raw_cloud_env is not None: + self._cloud_environment = raw_cloud_env + elif not raw_cloud_env: + self._cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD # SDK default + else: + # try to look up "well-known" values via the name attribute on azure_cloud members + all_clouds = [x[1] for x in inspect.getmembers(azure_cloud) if isinstance(x[1], azure_cloud.Cloud)] + matched_clouds = [x for x in all_clouds if x.name == raw_cloud_env] + if len(matched_clouds) == 1: + self._cloud_environment = matched_clouds[0] + elif len(matched_clouds) > 1: + self.fail("Azure SDK failure: more than one cloud matched for cloud_environment name '{0}'".format(raw_cloud_env)) + else: + if not urlparse.urlparse(raw_cloud_env).scheme: + self.fail("cloud_environment must be an endpoint discovery URL or one of {0}".format([x.name for x in all_clouds])) + try: + self._cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(raw_cloud_env) + except Exception as e: + self.fail("cloud_environment {0} could not be resolved: {1}".format(raw_cloud_env, e.message), exception=traceback.format_exc()) + + if self.credentials.get('subscription_id', None) is None and self.credentials.get('credentials') is None: + self.fail("Credentials did not include a subscription_id value.") + self.log("setting subscription_id") + self.subscription_id = self.credentials['subscription_id'] + + # get authentication authority + # for adfs, user could pass in authority or not. + # for others, use default authority from cloud environment + if self.credentials.get('adfs_authority_url') is None: + self._adfs_authority_url = self._cloud_environment.endpoints.active_directory + else: + self._adfs_authority_url = self.credentials.get('adfs_authority_url') + + # get resource from cloud environment + self._resource = self._cloud_environment.endpoints.active_directory_resource_id + + if self.credentials.get('credentials') is not None: + # AzureCLI credentials + self.azure_credentials = self.credentials['credentials'] + elif self.credentials.get('client_id') is not None and \ + self.credentials.get('secret') is not None and \ + self.credentials.get('tenant') is not None: + self.azure_credentials = ServicePrincipalCredentials(client_id=self.credentials['client_id'], + secret=self.credentials['secret'], + tenant=self.credentials['tenant'], + cloud_environment=self._cloud_environment, + verify=self._cert_validation_mode == 'validate') + + elif self.credentials.get('ad_user') is not None and \ + self.credentials.get('password') is not None and \ + self.credentials.get('client_id') is not None and \ + self.credentials.get('tenant') is not None: + + self.azure_credentials = self.acquire_token_with_username_password( + self._adfs_authority_url, + self._resource, + self.credentials['ad_user'], + self.credentials['password'], + self.credentials['client_id'], + self.credentials['tenant']) + + elif self.credentials.get('ad_user') is not None and self.credentials.get('password') is not None: + tenant = self.credentials.get('tenant') + if not tenant: + tenant = 'common' # SDK default + + self.azure_credentials = UserPassCredentials(self.credentials['ad_user'], + self.credentials['password'], + tenant=tenant, + cloud_environment=self._cloud_environment, + verify=self._cert_validation_mode == 'validate') + else: + self.fail("Failed to authenticate with provided credentials. Some attributes were missing. " + "Credentials must include client_id, secret and tenant or ad_user and password, or " + "ad_user, password, client_id, tenant and adfs_authority_url(optional) for ADFS authentication, or " + "be logged in using AzureCLI.") + + def fail(self, msg, exception=None, **kwargs): + self._fail_impl(msg) + + def _default_fail_impl(self, msg, exception=None, **kwargs): + raise AzureRMAuthException(msg) + + def _get_profile(self, profile="default"): + path = expanduser("~/.azure/credentials") + try: + config = configparser.ConfigParser() + config.read(path) + except Exception as exc: + self.fail("Failed to access {0}. Check that the file exists and you have read " + "access. {1}".format(path, str(exc))) + credentials = dict() + for key in AZURE_CREDENTIAL_ENV_MAPPING: + try: + credentials[key] = config.get(profile, key, raw=True) + except Exception: + pass + + if credentials.get('subscription_id'): + return credentials + + return None + + def _get_msi_credentials(self, subscription_id_param=None, **kwargs): + client_id = kwargs.get('client_id', None) + credentials = MSIAuthentication(client_id=client_id) + subscription_id = subscription_id_param or os.environ.get(AZURE_CREDENTIAL_ENV_MAPPING['subscription_id'], None) + if not subscription_id: + try: + # use the first subscription of the MSI + subscription_client = SubscriptionClient(credentials) + subscription = next(subscription_client.subscriptions.list()) + subscription_id = str(subscription.subscription_id) + except Exception as exc: + self.fail("Failed to get MSI token: {0}. " + "Please check whether your machine enabled MSI or grant access to any subscription.".format(str(exc))) + return { + 'credentials': credentials, + 'subscription_id': subscription_id + } + + def _get_azure_cli_credentials(self): + credentials, subscription_id = get_azure_cli_credentials() + cloud_environment = get_cli_active_cloud() + + cli_credentials = { + 'credentials': credentials, + 'subscription_id': subscription_id, + 'cloud_environment': cloud_environment + } + return cli_credentials + + def _get_env_credentials(self): + env_credentials = dict() + for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items(): + env_credentials[attribute] = os.environ.get(env_variable, None) + + if env_credentials['profile']: + credentials = self._get_profile(env_credentials['profile']) + return credentials + + if env_credentials.get('subscription_id') is not None: + return env_credentials + + return None + + # TODO: use explicit kwargs instead of intermediate dict + def _get_credentials(self, params): + # Get authentication credentials. + self.log('Getting credentials') + + arg_credentials = dict() + for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items(): + arg_credentials[attribute] = params.get(attribute, None) + + auth_source = params.get('auth_source', None) + if not auth_source: + auth_source = os.environ.get('ANSIBLE_AZURE_AUTH_SOURCE', 'auto') + + if auth_source == 'msi': + self.log('Retrieving credenitals from MSI') + return self._get_msi_credentials(arg_credentials['subscription_id'], client_id=params.get('client_id', None)) + + if auth_source == 'cli': + if not HAS_AZURE_CLI_CORE: + self.fail(msg=missing_required_lib('azure-cli', reason='for `cli` auth_source'), + exception=HAS_AZURE_CLI_CORE_EXC) + try: + self.log('Retrieving credentials from Azure CLI profile') + cli_credentials = self._get_azure_cli_credentials() + return cli_credentials + except CLIError as err: + self.fail("Azure CLI profile cannot be loaded - {0}".format(err)) + + if auth_source == 'env': + self.log('Retrieving credentials from environment') + env_credentials = self._get_env_credentials() + return env_credentials + + if auth_source == 'credential_file': + self.log("Retrieving credentials from credential file") + profile = params.get('profile') or 'default' + default_credentials = self._get_profile(profile) + return default_credentials + + # auto, precedence: module parameters -> environment variables -> default profile in ~/.azure/credentials + # try module params + if arg_credentials['profile'] is not None: + self.log('Retrieving credentials with profile parameter.') + credentials = self._get_profile(arg_credentials['profile']) + return credentials + + if arg_credentials['subscription_id']: + self.log('Received credentials from parameters.') + return arg_credentials + + # try environment + env_credentials = self._get_env_credentials() + if env_credentials: + self.log('Received credentials from env.') + return env_credentials + + # try default profile from ~./azure/credentials + default_credentials = self._get_profile() + if default_credentials: + self.log('Retrieved default profile credentials from ~/.azure/credentials.') + return default_credentials + + try: + if HAS_AZURE_CLI_CORE: + self.log('Retrieving credentials from AzureCLI profile') + cli_credentials = self._get_azure_cli_credentials() + return cli_credentials + except CLIError as ce: + self.log('Error getting AzureCLI profile credentials - {0}'.format(ce)) + + return None + + def acquire_token_with_username_password(self, authority, resource, username, password, client_id, tenant): + authority_uri = authority + + if tenant is not None: + authority_uri = authority + '/' + tenant + + context = AuthenticationContext(authority_uri) + token_response = context.acquire_token_with_username_password(resource, username, password, client_id) + + return AADTokenCredentials(token_response) + + def log(self, msg, pretty_print=False): + pass + # Use only during module development + # if self.debug: + # log_file = open('azure_rm.log', 'a') + # if pretty_print: + # log_file.write(json.dumps(msg, indent=4, sort_keys=True)) + # else: + # log_file.write(msg + u'\n') diff --git a/test/support/integration/plugins/module_utils/azure_rm_common_rest.py b/test/support/integration/plugins/module_utils/azure_rm_common_rest.py new file mode 100644 index 00000000..4fd7eaa3 --- /dev/null +++ b/test/support/integration/plugins/module_utils/azure_rm_common_rest.py @@ -0,0 +1,97 @@ +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION + +try: + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_configuration import AzureConfiguration + from msrest.service_client import ServiceClient + from msrest.pipeline import ClientRawResponse + from msrest.polling import LROPoller + from msrestazure.polling.arm_polling import ARMPolling + import uuid + import json +except ImportError: + # This is handled in azure_rm_common + AzureConfiguration = object + +ANSIBLE_USER_AGENT = 'Ansible/{0}'.format(ANSIBLE_VERSION) + + +class GenericRestClientConfiguration(AzureConfiguration): + + def __init__(self, credentials, subscription_id, base_url=None): + + if credentials is None: + raise ValueError("Parameter 'credentials' must not be None.") + if subscription_id is None: + raise ValueError("Parameter 'subscription_id' must not be None.") + if not base_url: + base_url = 'https://management.azure.com' + + super(GenericRestClientConfiguration, self).__init__(base_url) + + self.add_user_agent(ANSIBLE_USER_AGENT) + + self.credentials = credentials + self.subscription_id = subscription_id + + +class GenericRestClient(object): + + def __init__(self, credentials, subscription_id, base_url=None): + self.config = GenericRestClientConfiguration(credentials, subscription_id, base_url) + self._client = ServiceClient(self.config.credentials, self.config) + self.models = None + + def query(self, url, method, query_parameters, header_parameters, body, expected_status_codes, polling_timeout, polling_interval): + # Construct and send request + operation_config = {} + + request = None + + if header_parameters is None: + header_parameters = {} + + header_parameters['x-ms-client-request-id'] = str(uuid.uuid1()) + + if method == 'GET': + request = self._client.get(url, query_parameters) + elif method == 'PUT': + request = self._client.put(url, query_parameters) + elif method == 'POST': + request = self._client.post(url, query_parameters) + elif method == 'HEAD': + request = self._client.head(url, query_parameters) + elif method == 'PATCH': + request = self._client.patch(url, query_parameters) + elif method == 'DELETE': + request = self._client.delete(url, query_parameters) + elif method == 'MERGE': + request = self._client.merge(url, query_parameters) + + response = self._client.send(request, header_parameters, body, **operation_config) + + if response.status_code not in expected_status_codes: + exp = CloudError(response) + exp.request_id = response.headers.get('x-ms-request-id') + raise exp + elif response.status_code == 202 and polling_timeout > 0: + def get_long_running_output(response): + return response + poller = LROPoller(self._client, + ClientRawResponse(None, response), + get_long_running_output, + ARMPolling(polling_interval, **operation_config)) + response = self.get_poller_result(poller, polling_timeout) + + return response + + def get_poller_result(self, poller, timeout): + try: + poller.wait(timeout=timeout) + return poller.result() + except Exception as exc: + raise diff --git a/test/support/integration/plugins/module_utils/cloud.py b/test/support/integration/plugins/module_utils/cloud.py new file mode 100644 index 00000000..0d29071f --- /dev/null +++ b/test/support/integration/plugins/module_utils/cloud.py @@ -0,0 +1,217 @@ +# +# (c) 2016 Allen Sanabria, <asanabria@linuxdynasty.org> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +""" +This module adds shared support for generic cloud modules + +In order to use this module, include it as part of a custom +module as shown below. + +from ansible.module_utils.cloud import CloudRetry + +The 'cloud' module provides the following common classes: + + * CloudRetry + - The base class to be used by other cloud providers, in order to + provide a backoff/retry decorator based on status codes. + + - Example using the AWSRetry class which inherits from CloudRetry. + + @AWSRetry.exponential_backoff(retries=10, delay=3) + get_ec2_security_group_ids_from_names() + + @AWSRetry.jittered_backoff() + get_ec2_security_group_ids_from_names() + +""" +import random +from functools import wraps +import syslog +import time + + +def _exponential_backoff(retries=10, delay=2, backoff=2, max_delay=60): + """ Customizable exponential backoff strategy. + Args: + retries (int): Maximum number of times to retry a request. + delay (float): Initial (base) delay. + backoff (float): base of the exponent to use for exponential + backoff. + max_delay (int): Optional. If provided each delay generated is capped + at this amount. Defaults to 60 seconds. + Returns: + Callable that returns a generator. This generator yields durations in + seconds to be used as delays for an exponential backoff strategy. + Usage: + >>> backoff = _exponential_backoff() + >>> backoff + <function backoff_backoff at 0x7f0d939facf8> + >>> list(backoff()) + [2, 4, 8, 16, 32, 60, 60, 60, 60, 60] + """ + def backoff_gen(): + for retry in range(0, retries): + sleep = delay * backoff ** retry + yield sleep if max_delay is None else min(sleep, max_delay) + return backoff_gen + + +def _full_jitter_backoff(retries=10, delay=3, max_delay=60, _random=random): + """ Implements the "Full Jitter" backoff strategy described here + https://www.awsarchitectureblog.com/2015/03/backoff.html + Args: + retries (int): Maximum number of times to retry a request. + delay (float): Approximate number of seconds to sleep for the first + retry. + max_delay (int): The maximum number of seconds to sleep for any retry. + _random (random.Random or None): Makes this generator testable by + allowing developers to explicitly pass in the a seeded Random. + Returns: + Callable that returns a generator. This generator yields durations in + seconds to be used as delays for a full jitter backoff strategy. + Usage: + >>> backoff = _full_jitter_backoff(retries=5) + >>> backoff + <function backoff_backoff at 0x7f0d939facf8> + >>> list(backoff()) + [3, 6, 5, 23, 38] + >>> list(backoff()) + [2, 1, 6, 6, 31] + """ + def backoff_gen(): + for retry in range(0, retries): + yield _random.randint(0, min(max_delay, delay * 2 ** retry)) + return backoff_gen + + +class CloudRetry(object): + """ CloudRetry can be used by any cloud provider, in order to implement a + backoff algorithm/retry effect based on Status Code from Exceptions. + """ + # This is the base class of the exception. + # AWS Example botocore.exceptions.ClientError + base_class = None + + @staticmethod + def status_code_from_exception(error): + """ Return the status code from the exception object + Args: + error (object): The exception itself. + """ + pass + + @staticmethod + def found(response_code, catch_extra_error_codes=None): + """ Return True if the Response Code to retry on was found. + Args: + response_code (str): This is the Response Code that is being matched against. + """ + pass + + @classmethod + def _backoff(cls, backoff_strategy, catch_extra_error_codes=None): + """ Retry calling the Cloud decorated function using the provided + backoff strategy. + Args: + backoff_strategy (callable): Callable that returns a generator. The + generator should yield sleep times for each retry of the decorated + function. + """ + def deco(f): + @wraps(f) + def retry_func(*args, **kwargs): + for delay in backoff_strategy(): + try: + return f(*args, **kwargs) + except Exception as e: + if isinstance(e, cls.base_class): + response_code = cls.status_code_from_exception(e) + if cls.found(response_code, catch_extra_error_codes): + msg = "{0}: Retrying in {1} seconds...".format(str(e), delay) + syslog.syslog(syslog.LOG_INFO, msg) + time.sleep(delay) + else: + # Return original exception if exception is not a ClientError + raise e + else: + # Return original exception if exception is not a ClientError + raise e + return f(*args, **kwargs) + + return retry_func # true decorator + + return deco + + @classmethod + def exponential_backoff(cls, retries=10, delay=3, backoff=2, max_delay=60, catch_extra_error_codes=None): + """ + Retry calling the Cloud decorated function using an exponential backoff. + + Kwargs: + retries (int): Number of times to retry a failed request before giving up + default=10 + delay (int or float): Initial delay between retries in seconds + default=3 + backoff (int or float): backoff multiplier e.g. value of 2 will + double the delay each retry + default=1.1 + max_delay (int or None): maximum amount of time to wait between retries. + default=60 + """ + return cls._backoff(_exponential_backoff( + retries=retries, delay=delay, backoff=backoff, max_delay=max_delay), catch_extra_error_codes) + + @classmethod + def jittered_backoff(cls, retries=10, delay=3, max_delay=60, catch_extra_error_codes=None): + """ + Retry calling the Cloud decorated function using a jittered backoff + strategy. More on this strategy here: + + https://www.awsarchitectureblog.com/2015/03/backoff.html + + Kwargs: + retries (int): Number of times to retry a failed request before giving up + default=10 + delay (int): Initial delay between retries in seconds + default=3 + max_delay (int): maximum amount of time to wait between retries. + default=60 + """ + return cls._backoff(_full_jitter_backoff( + retries=retries, delay=delay, max_delay=max_delay), catch_extra_error_codes) + + @classmethod + def backoff(cls, tries=10, delay=3, backoff=1.1, catch_extra_error_codes=None): + """ + Retry calling the Cloud decorated function using an exponential backoff. + + Compatibility for the original implementation of CloudRetry.backoff that + did not provide configurable backoff strategies. Developers should use + CloudRetry.exponential_backoff instead. + + Kwargs: + tries (int): Number of times to try (not retry) before giving up + default=10 + delay (int or float): Initial delay between retries in seconds + default=3 + backoff (int or float): backoff multiplier e.g. value of 2 will + double the delay each retry + default=1.1 + """ + return cls.exponential_backoff( + retries=tries - 1, delay=delay, backoff=backoff, max_delay=None, catch_extra_error_codes=catch_extra_error_codes) diff --git a/test/support/integration/plugins/module_utils/compat/__init__.py b/test/support/integration/plugins/module_utils/compat/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/compat/__init__.py diff --git a/test/support/integration/plugins/module_utils/compat/ipaddress.py b/test/support/integration/plugins/module_utils/compat/ipaddress.py new file mode 100644 index 00000000..c46ad72a --- /dev/null +++ b/test/support/integration/plugins/module_utils/compat/ipaddress.py @@ -0,0 +1,2476 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file, and this file only, is based on +# Lib/ipaddress.py of cpython +# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved" +# are retained in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. + +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +from __future__ import unicode_literals + + +import itertools +import struct + + +# The following makes it easier for us to script updates of the bundled code and is not part of +# upstream +_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"} + +__version__ = '1.0.22' + +# Compatibility functions +_compat_int_types = (int,) +try: + _compat_int_types = (int, long) +except NameError: + pass +try: + _compat_str = unicode +except NameError: + _compat_str = str + assert bytes != str +if b'\0'[0] == 0: # Python 3 semantics + def _compat_bytes_to_byte_vals(byt): + return byt +else: + def _compat_bytes_to_byte_vals(byt): + return [struct.unpack(b'!B', b)[0] for b in byt] +try: + _compat_int_from_byte_vals = int.from_bytes +except AttributeError: + def _compat_int_from_byte_vals(bytvals, endianess): + assert endianess == 'big' + res = 0 + for bv in bytvals: + assert isinstance(bv, _compat_int_types) + res = (res << 8) + bv + return res + + +def _compat_to_bytes(intval, length, endianess): + assert isinstance(intval, _compat_int_types) + assert endianess == 'big' + if length == 4: + if intval < 0 or intval >= 2 ** 32: + raise struct.error("integer out of range for 'I' format code") + return struct.pack(b'!I', intval) + elif length == 16: + if intval < 0 or intval >= 2 ** 128: + raise struct.error("integer out of range for 'QQ' format code") + return struct.pack(b'!QQ', intval >> 64, intval & 0xffffffffffffffff) + else: + raise NotImplementedError() + + +if hasattr(int, 'bit_length'): + # Not int.bit_length , since that won't work in 2.7 where long exists + def _compat_bit_length(i): + return i.bit_length() +else: + def _compat_bit_length(i): + for res in itertools.count(): + if i >> res == 0: + return res + + +def _compat_range(start, end, step=1): + assert step > 0 + i = start + while i < end: + yield i + i += step + + +class _TotalOrderingMixin(object): + __slots__ = () + + # Helper that derives the other comparison operations from + # __lt__ and __eq__ + # We avoid functools.total_ordering because it doesn't handle + # NotImplemented correctly yet (http://bugs.python.org/issue10042) + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + + def __lt__(self, other): + raise NotImplementedError + + def __le__(self, other): + less = self.__lt__(other) + if less is NotImplemented or not less: + return self.__eq__(other) + return less + + def __gt__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not (less or equal) + + def __ge__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + return not less + + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + '%r does not appear to be an IPv4 or IPv6 address. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?' % address) + + raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % + address) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + '%r does not appear to be an IPv4 or IPv6 network. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?' % address) + + raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % + address) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. + + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' % + address) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return _compat_to_bytes(address, 4, 'big') + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return _compat_to_bytes(address, 16, 'big') + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = _compat_str(address).split('/') + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) # pylint: disable=stop-iteration-return + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, _compat_bit_length(~number & (number - 1))) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. + + """ + if (not (isinstance(first, _BaseAddress) and + isinstance(last, _BaseAddress))): + raise TypeError('first and last must be IP addresses, not networks') + if first.version != last.version: + raise TypeError("%s and %s are not of the same version" % ( + first, last)) + if first > last: + raise ValueError('last IP address must be greater than first') + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError('unknown IP version') + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min(_count_righthand_zero_bits(first_int, ip_bits), + _compat_bit_length(last_int - first_int + 1) - 1) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, + # last.network_address <= net.network_address is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. + + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, nets[-1])) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase(_TotalOrderingMixin): + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return _compat_str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = '%200s has no version specified' % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._max_prefixlen, + self._version)) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = ( + '%r (len %d != %d) is not permitted as an IPv%d address. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?') + raise AddressValueError(msg % (address, address_len, + expected_len, self._version)) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. + + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits(ip_int, + cls._max_prefixlen) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = _compat_to_bytes(ip_int, byteslen, 'big') + msg = 'Netmask pattern %r mixes zeroes & ones' + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = '%r is not a valid netmask' % netmask_str + raise NetmaskValueError(msg) + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + def __reduce__(self): + return self.__class__, (_compat_str(self),) + + +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. + """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + try: + return (self._ip == other._ip and + self._version == other._version) + except AttributeError: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseAddress): + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self._ip != other._ip: + return self._ip < other._ip + return False + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return _compat_str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + +class _BaseNetwork(_IPAddressBase): + + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + + """ + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return '%s/%d' % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError('address out of range') + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError('address out of range') + return self._address_class(broadcast + n) + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseNetwork): + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __eq__(self, other): + try: + return (self._version == other._version and + self.network_address == other.network_address and + int(self.netmask) == int(other.netmask)) + except AttributeError: + return NotImplemented + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. + if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return (int(self.network_address) <= int(other._ip) <= + int(self.broadcast_address)) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other or ( + other.network_address in self or ( + other.broadcast_address in self))) + + @property + def broadcast_address(self): + x = self._cache.get('broadcast_address') + if x is None: + x = self._address_class(int(self.network_address) | + int(self.hostmask)) + self._cache['broadcast_address'] = x + return x + + @property + def hostmask(self): + x = self._cache.get('hostmask') + if x is None: + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) + self._cache['hostmask'] = x + return x + + @property + def with_prefixlen(self): + return '%s/%d' % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = '%200s has no associated address class' % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError("%s and %s are not of the same version" % ( + self, other)) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not other.subnet_of(self): + raise ValueError('%s not contained in %s' % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__('%s/%s' % (other.network_address, + other.prefixlen)) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if other.subnet_of(s1): + yield s2 + s1, s2 = s1.subnets() + elif other.subnet_of(s2): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. + OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError('new prefix must be longer') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError('prefix length diff must be > 0') + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + 'prefix length diff %d is invalid for netblock %s' % ( + new_prefixlen, self)) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in _compat_range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError('new prefix must be shorter') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % + (self.prefixlen, prefixlen_diff)) + return self.__class__(( + int(self.network_address) & (int(self.netmask) << prefixlen_diff), + new_prefixlen)) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return (self.network_address.is_multicast and + self.broadcast_address.is_multicast) + + @staticmethod + def _is_subnet_of(a, b): + try: + # Always false if one is v4 and the other is v6. + if a._version != b._version: + raise TypeError("%s and %s are not of the same version" % (a, b)) + return (b.network_address <= a.network_address and + b.broadcast_address >= a.broadcast_address) + except AttributeError: + raise TypeError("Unable to test subnet containment " + "between %s and %s" % (a, b)) + + def subnet_of(self, other): + """Return True if this network is a subnet of other.""" + return self._is_subnet_of(self, other) + + def supernet_of(self, other): + """Return True if this network is a supernet of other.""" + return self._is_subnet_of(other, self) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return (self.network_address.is_reserved and + self.broadcast_address.is_reserved) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return (self.network_address.is_link_local and + self.broadcast_address.is_link_local) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return (self.network_address.is_private and + self.broadcast_address.is_private) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return (self.network_address.is_unspecified and + self.broadcast_address.is_unspecified) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return (self.network_address.is_loopback and + self.broadcast_address.is_loopback) + + +class _BaseV4(object): + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2 ** IPV4LENGTH) - 1 + _DECIMAL_DIGITS = frozenset('0123456789') + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0]) + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return _compat_str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + octets = ip_str.split('.') + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return _compat_int_from_byte_vals( + map(cls._parse_octet, octets), 'big') + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._DECIMAL_DIGITS.issuperset(octet_str): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + # Any octets that look like they *might* be written in octal, + # and which don't look exactly the same in both octal and + # decimal are rejected as ambiguous + if octet_int > 7 and octet_str[0] == '0': + msg = "Ambiguous (octal/decimal) value in %r not permitted" + raise ValueError(msg % octet_str) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return '.'.join(_compat_str(struct.unpack(b'!B', b)[0] + if isinstance(b, bytes) + else b) + for b in _compat_to_bytes(ip_int, 4, 'big')) + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split('.') + try: + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = _compat_str(self).split('.')[::-1] + return '.'.join(reverse_octets) + '.in-addr.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in self._constants._reserved_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + return ( + self not in self._constants._public_network and + not self.is_private) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self == self._constants._unspecified_address + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in self._constants._loopback_network + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. + + """ + return self in self._constants._linklocal_network + + +class IPv4Interface(IPv4Address): + + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv4Address.__init__(self, address) + self.network = IPv4Network(self._ip) + self._prefixlen = self._max_prefixlen + return + + if isinstance(address, tuple): + IPv4Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + + self.network = IPv4Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv4Address.__init__(self, addr[0]) + + self.network = IPv4Network(address, strict=False) + self._prefixlen = self.network._prefixlen + + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv4Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + +class IPv4Network(_BaseV4, _BaseNetwork): + + """This class represents and manipulates 32-bit IPv4 network + addresses.. + + Attributes: [examples for IPv4Network('192.0.2.0/27')] + .network_address: IPv4Address('192.0.2.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast_address: IPv4Address('192.0.2.32') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + # Class to use when creating address objects + _address_class = IPv4Address + + def __init__(self, address, strict=True): + + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.0.2.0/24' + '192.0.2.0/255.255.255.0' + '192.0.0.2/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.0.2.1' + '192.0.2.1/255.255.255.255' + '192.0.2.1/32' + are also functionally equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.0.2.1') == IPv4Network(3221225985) + or, more generally + IPv4Interface(int(IPv4Interface('192.0.2.1'))) == + IPv4Interface('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict is True and a network address is not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (_compat_int_types, bytes)): + self.network_address = IPv4Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen) + # fixme: address/network test here. + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + # We weren't given an address[1] + arg = self._max_prefixlen + self.network_address = IPv4Address(address[0]) + self.netmask, self._prefixlen = self._make_netmask(arg) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv4Address(packed & + int(self.netmask)) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if (IPv4Address(int(self.network_address) & int(self.netmask)) != + self.network_address): + raise ValueError('%s has host bits set' % self) + self.network_address = IPv4Address(int(self.network_address) & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry. + + """ + return (not (self.network_address in IPv4Network('100.64.0.0/10') and + self.broadcast_address in IPv4Network('100.64.0.0/10')) and + not self.is_private) + + +class _IPv4Constants(object): + + _linklocal_network = IPv4Network('169.254.0.0/16') + + _loopback_network = IPv4Network('127.0.0.0/8') + + _multicast_network = IPv4Network('224.0.0.0/4') + + _public_network = IPv4Network('100.64.0.0/10') + + _private_networks = [ + IPv4Network('0.0.0.0/8'), + IPv4Network('10.0.0.0/8'), + IPv4Network('127.0.0.0/8'), + IPv4Network('169.254.0.0/16'), + IPv4Network('172.16.0.0/12'), + IPv4Network('192.0.0.0/29'), + IPv4Network('192.0.0.170/31'), + IPv4Network('192.0.2.0/24'), + IPv4Network('192.168.0.0/16'), + IPv4Network('198.18.0.0/15'), + IPv4Network('198.51.100.0/24'), + IPv4Network('203.0.113.0/24'), + IPv4Network('240.0.0.0/4'), + IPv4Network('255.255.255.255/32'), + ] + + _reserved_network = IPv4Network('240.0.0.0/4') + + _unspecified_address = IPv4Address('0.0.0.0') + + +IPv4Address._constants = _IPv4Constants + + +class _BaseV6(object): + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2 ** IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef') + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + parts = ip_str.split(':') + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if '.' in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF)) + parts.append('%x' % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % ( + _max_parts - 1, ip_str) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in _compat_range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. The + # endpoints could still be empty, but _parse_hextet() will check + # for that. + if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == '0': + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = (best_doublecolon_start + + best_doublecolon_len) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [''] + hextets[best_doublecolon_start:best_doublecolon_end] = [''] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [''] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. + + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError('IPv6 address is too large') + + hex_str = '%032x' % ip_int + hextets = ['%x' % int(hex_str[x:x + 4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ':'.join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = _compat_str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = _compat_str(self.ip) + else: + ip_str = _compat_str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = '%032x' % ip_int + parts = [hex_str[x:x + 4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return '%s/%d' % (':'.join(parts), self._prefixlen) + return ':'.join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(':', '') + return '.'.join(reverse_chars) + '.ip6.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001::/32) + + """ + if (self._ip >> 96) != 0x20010000: + return None + return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF), + IPv4Address(~self._ip & 0xFFFFFFFF)) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + if (self._ip >> 112) != 0x2002: + return None + return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) + + +class IPv6Interface(IPv6Address): + + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv6Address.__init__(self, address) + self.network = IPv6Network(self._ip) + self._prefixlen = self._max_prefixlen + return + if isinstance(address, tuple): + IPv6Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv6Address.__init__(self, addr[0]) + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self._prefixlen = self.network._prefixlen + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv6Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + @property + def is_unspecified(self): + return self._ip == 0 and self.network.is_unspecified + + @property + def is_loopback(self): + return self._ip == 1 and self.network.is_loopback + + +class IPv6Network(_BaseV6, _BaseNetwork): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:db8::1000/124')] + .network_address: IPv6Address('2001:db8::1000') + .hostmask: IPv6Address('::f') + .broadcast_address: IPv6Address('2001:db8::100f') + .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') + .prefixlen: 124 + + """ + + # Class to use when creating address objects + _address_class = IPv6Address + + def __init__(self, address, strict=True): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the + IP and prefix/netmask. + '2001:db8::/128' + '2001:db8:0000:0000:0000:0000:0000:0000/128' + '2001:db8::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:db8::') == + IPv6Network(42540766411282592856903984951653826560) + or, more generally + IPv6Network(int(IPv6Network('2001:db8::'))) == + IPv6Network('2001:db8::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 2001:db8::1000/124 and not an + IP address on a network, eg, 2001:db8::1/124. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Efficient constructor from integer or packed address + if isinstance(address, (bytes, _compat_int_types)): + self.network_address = IPv6Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen) + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + self.network_address = IPv6Address(address[0]) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv6Address(packed & + int(self.netmask)) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + + self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if (IPv6Address(int(self.network_address) & int(self.netmask)) != + self.network_address): + raise ValueError('%s has host bits set' % self) + self.network_address = IPv6Address(int(self.network_address) & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the + Subnet-Router anycast address. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast + 1): + yield self._address_class(x) + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return (self.network_address.is_site_local and + self.broadcast_address.is_site_local) + + +class _IPv6Constants(object): + + _linklocal_network = IPv6Network('fe80::/10') + + _multicast_network = IPv6Network('ff00::/8') + + _private_networks = [ + IPv6Network('::1/128'), + IPv6Network('::/128'), + IPv6Network('::ffff:0:0/96'), + IPv6Network('100::/64'), + IPv6Network('2001::/23'), + IPv6Network('2001:2::/48'), + IPv6Network('2001:db8::/32'), + IPv6Network('2001:10::/28'), + IPv6Network('fc00::/7'), + IPv6Network('fe80::/10'), + ] + + _reserved_networks = [ + IPv6Network('::/8'), IPv6Network('100::/8'), + IPv6Network('200::/7'), IPv6Network('400::/6'), + IPv6Network('800::/5'), IPv6Network('1000::/4'), + IPv6Network('4000::/3'), IPv6Network('6000::/3'), + IPv6Network('8000::/3'), IPv6Network('A000::/3'), + IPv6Network('C000::/3'), IPv6Network('E000::/4'), + IPv6Network('F000::/5'), IPv6Network('F800::/6'), + IPv6Network('FE00::/9'), + ] + + _sitelocal_network = IPv6Network('fec0::/10') + + +IPv6Address._constants = _IPv6Constants diff --git a/test/support/integration/plugins/module_utils/crypto.py b/test/support/integration/plugins/module_utils/crypto.py new file mode 100644 index 00000000..e67eeff1 --- /dev/null +++ b/test/support/integration/plugins/module_utils/crypto.py @@ -0,0 +1,2125 @@ +# -*- coding: utf-8 -*- +# +# (c) 2016, Yanis Guenane <yanis+ansible@guenane.org> +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +# ---------------------------------------------------------------------- +# A clearly marked portion of this file is licensed under the BSD license +# Copyright (c) 2015, 2016 Paul Kehrer (@reaperhulk) +# Copyright (c) 2017 Fraser Tweedale (@frasertweedale) +# For more details, search for the function _obj2txt(). +# --------------------------------------------------------------------- +# A clearly marked portion of this file is extracted from a project that +# is licensed under the Apache License 2.0 +# Copyright (c) the OpenSSL contributors +# For more details, search for the function _OID_MAP. + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +import sys +from distutils.version import LooseVersion + +try: + import OpenSSL + from OpenSSL import crypto +except ImportError: + # An error will be raised in the calling class to let the end + # user know that OpenSSL couldn't be found. + pass + +try: + import cryptography + from cryptography import x509 + from cryptography.hazmat.backends import default_backend as cryptography_backend + from cryptography.hazmat.primitives.serialization import load_pem_private_key + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives import serialization + import ipaddress + + # Older versions of cryptography (< 2.1) do not have __hash__ functions for + # general name objects (DNSName, IPAddress, ...), while providing overloaded + # equality and string representation operations. This makes it impossible to + # use them in hash-based data structures such as set or dict. Since we are + # actually doing that in openssl_certificate, and potentially in other code, + # we need to monkey-patch __hash__ for these classes to make sure our code + # works fine. + if LooseVersion(cryptography.__version__) < LooseVersion('2.1'): + # A very simply hash function which relies on the representation + # of an object to be implemented. This is the case since at least + # cryptography 1.0, see + # https://github.com/pyca/cryptography/commit/7a9abce4bff36c05d26d8d2680303a6f64a0e84f + def simple_hash(self): + return hash(repr(self)) + + # The hash functions for the following types were added for cryptography 2.1: + # https://github.com/pyca/cryptography/commit/fbfc36da2a4769045f2373b004ddf0aff906cf38 + x509.DNSName.__hash__ = simple_hash + x509.DirectoryName.__hash__ = simple_hash + x509.GeneralName.__hash__ = simple_hash + x509.IPAddress.__hash__ = simple_hash + x509.OtherName.__hash__ = simple_hash + x509.RegisteredID.__hash__ = simple_hash + + if LooseVersion(cryptography.__version__) < LooseVersion('1.2'): + # The hash functions for the following types were added for cryptography 1.2: + # https://github.com/pyca/cryptography/commit/b642deed88a8696e5f01ce6855ccf89985fc35d0 + # https://github.com/pyca/cryptography/commit/d1b5681f6db2bde7a14625538bd7907b08dfb486 + x509.RFC822Name.__hash__ = simple_hash + x509.UniformResourceIdentifier.__hash__ = simple_hash + + # Test whether we have support for X25519, X448, Ed25519 and/or Ed448 + try: + import cryptography.hazmat.primitives.asymmetric.x25519 + CRYPTOGRAPHY_HAS_X25519 = True + try: + cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.private_bytes + CRYPTOGRAPHY_HAS_X25519_FULL = True + except AttributeError: + CRYPTOGRAPHY_HAS_X25519_FULL = False + except ImportError: + CRYPTOGRAPHY_HAS_X25519 = False + CRYPTOGRAPHY_HAS_X25519_FULL = False + try: + import cryptography.hazmat.primitives.asymmetric.x448 + CRYPTOGRAPHY_HAS_X448 = True + except ImportError: + CRYPTOGRAPHY_HAS_X448 = False + try: + import cryptography.hazmat.primitives.asymmetric.ed25519 + CRYPTOGRAPHY_HAS_ED25519 = True + except ImportError: + CRYPTOGRAPHY_HAS_ED25519 = False + try: + import cryptography.hazmat.primitives.asymmetric.ed448 + CRYPTOGRAPHY_HAS_ED448 = True + except ImportError: + CRYPTOGRAPHY_HAS_ED448 = False + + HAS_CRYPTOGRAPHY = True +except ImportError: + # Error handled in the calling module. + CRYPTOGRAPHY_HAS_X25519 = False + CRYPTOGRAPHY_HAS_X25519_FULL = False + CRYPTOGRAPHY_HAS_X448 = False + CRYPTOGRAPHY_HAS_ED25519 = False + CRYPTOGRAPHY_HAS_ED448 = False + HAS_CRYPTOGRAPHY = False + + +import abc +import base64 +import binascii +import datetime +import errno +import hashlib +import os +import re +import tempfile + +from ansible.module_utils import six +from ansible.module_utils._text import to_native, to_bytes, to_text + + +class OpenSSLObjectError(Exception): + pass + + +class OpenSSLBadPassphraseError(OpenSSLObjectError): + pass + + +def get_fingerprint_of_bytes(source): + """Generate the fingerprint of the given bytes.""" + + fingerprint = {} + + try: + algorithms = hashlib.algorithms + except AttributeError: + try: + algorithms = hashlib.algorithms_guaranteed + except AttributeError: + return None + + for algo in algorithms: + f = getattr(hashlib, algo) + try: + h = f(source) + except ValueError: + # This can happen for hash algorithms not supported in FIPS mode + # (https://github.com/ansible/ansible/issues/67213) + continue + try: + # Certain hash functions have a hexdigest() which expects a length parameter + pubkey_digest = h.hexdigest() + except TypeError: + pubkey_digest = h.hexdigest(32) + fingerprint[algo] = ':'.join(pubkey_digest[i:i + 2] for i in range(0, len(pubkey_digest), 2)) + + return fingerprint + + +def get_fingerprint(path, passphrase=None, content=None, backend='pyopenssl'): + """Generate the fingerprint of the public key. """ + + privatekey = load_privatekey(path, passphrase=passphrase, content=content, check_passphrase=False, backend=backend) + + if backend == 'pyopenssl': + try: + publickey = crypto.dump_publickey(crypto.FILETYPE_ASN1, privatekey) + except AttributeError: + # If PyOpenSSL < 16.0 crypto.dump_publickey() will fail. + try: + bio = crypto._new_mem_buf() + rc = crypto._lib.i2d_PUBKEY_bio(bio, privatekey._pkey) + if rc != 1: + crypto._raise_current_error() + publickey = crypto._bio_to_string(bio) + except AttributeError: + # By doing this we prevent the code from raising an error + # yet we return no value in the fingerprint hash. + return None + elif backend == 'cryptography': + publickey = privatekey.public_key().public_bytes( + serialization.Encoding.DER, + serialization.PublicFormat.SubjectPublicKeyInfo + ) + + return get_fingerprint_of_bytes(publickey) + + +def load_file_if_exists(path, module=None, ignore_errors=False): + try: + with open(path, 'rb') as f: + return f.read() + except EnvironmentError as exc: + if exc.errno == errno.ENOENT: + return None + if ignore_errors: + return None + if module is None: + raise + module.fail_json('Error while loading {0} - {1}'.format(path, str(exc))) + except Exception as exc: + if ignore_errors: + return None + if module is None: + raise + module.fail_json('Error while loading {0} - {1}'.format(path, str(exc))) + + +def load_privatekey(path, passphrase=None, check_passphrase=True, content=None, backend='pyopenssl'): + """Load the specified OpenSSL private key. + + The content can also be specified via content; in that case, + this function will not load the key from disk. + """ + + try: + if content is None: + with open(path, 'rb') as b_priv_key_fh: + priv_key_detail = b_priv_key_fh.read() + else: + priv_key_detail = content + + if backend == 'pyopenssl': + + # First try: try to load with real passphrase (resp. empty string) + # Will work if this is the correct passphrase, or the key is not + # password-protected. + try: + result = crypto.load_privatekey(crypto.FILETYPE_PEM, + priv_key_detail, + to_bytes(passphrase or '')) + except crypto.Error as e: + if len(e.args) > 0 and len(e.args[0]) > 0: + if e.args[0][0][2] in ('bad decrypt', 'bad password read'): + # This happens in case we have the wrong passphrase. + if passphrase is not None: + raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key!') + else: + raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!') + raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e)) + if check_passphrase: + # Next we want to make sure that the key is actually protected by + # a passphrase (in case we did try the empty string before, make + # sure that the key is not protected by the empty string) + try: + crypto.load_privatekey(crypto.FILETYPE_PEM, + priv_key_detail, + to_bytes('y' if passphrase == 'x' else 'x')) + if passphrase is not None: + # Since we can load the key without an exception, the + # key isn't password-protected + raise OpenSSLBadPassphraseError('Passphrase provided, but private key is not password-protected!') + except crypto.Error as e: + if passphrase is None and len(e.args) > 0 and len(e.args[0]) > 0: + if e.args[0][0][2] in ('bad decrypt', 'bad password read'): + # The key is obviously protected by the empty string. + # Don't do this at home (if it's possible at all)... + raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!') + elif backend == 'cryptography': + try: + result = load_pem_private_key(priv_key_detail, + None if passphrase is None else to_bytes(passphrase), + cryptography_backend()) + except TypeError as dummy: + raise OpenSSLBadPassphraseError('Wrong or empty passphrase provided for private key') + except ValueError as dummy: + raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key') + + return result + except (IOError, OSError) as exc: + raise OpenSSLObjectError(exc) + + +def load_certificate(path, content=None, backend='pyopenssl'): + """Load the specified certificate.""" + + try: + if content is None: + with open(path, 'rb') as cert_fh: + cert_content = cert_fh.read() + else: + cert_content = content + if backend == 'pyopenssl': + return crypto.load_certificate(crypto.FILETYPE_PEM, cert_content) + elif backend == 'cryptography': + return x509.load_pem_x509_certificate(cert_content, cryptography_backend()) + except (IOError, OSError) as exc: + raise OpenSSLObjectError(exc) + + +def load_certificate_request(path, content=None, backend='pyopenssl'): + """Load the specified certificate signing request.""" + try: + if content is None: + with open(path, 'rb') as csr_fh: + csr_content = csr_fh.read() + else: + csr_content = content + except (IOError, OSError) as exc: + raise OpenSSLObjectError(exc) + if backend == 'pyopenssl': + return crypto.load_certificate_request(crypto.FILETYPE_PEM, csr_content) + elif backend == 'cryptography': + return x509.load_pem_x509_csr(csr_content, cryptography_backend()) + + +def parse_name_field(input_dict): + """Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" + + result = [] + for key in input_dict: + if isinstance(input_dict[key], list): + for entry in input_dict[key]: + result.append((key, entry)) + else: + result.append((key, input_dict[key])) + return result + + +def convert_relative_to_datetime(relative_time_string): + """Get a datetime.datetime or None from a string in the time format described in sshd_config(5)""" + + parsed_result = re.match( + r"^(?P<prefix>[+-])((?P<weeks>\d+)[wW])?((?P<days>\d+)[dD])?((?P<hours>\d+)[hH])?((?P<minutes>\d+)[mM])?((?P<seconds>\d+)[sS]?)?$", + relative_time_string) + + if parsed_result is None or len(relative_time_string) == 1: + # not matched or only a single "+" or "-" + return None + + offset = datetime.timedelta(0) + if parsed_result.group("weeks") is not None: + offset += datetime.timedelta(weeks=int(parsed_result.group("weeks"))) + if parsed_result.group("days") is not None: + offset += datetime.timedelta(days=int(parsed_result.group("days"))) + if parsed_result.group("hours") is not None: + offset += datetime.timedelta(hours=int(parsed_result.group("hours"))) + if parsed_result.group("minutes") is not None: + offset += datetime.timedelta( + minutes=int(parsed_result.group("minutes"))) + if parsed_result.group("seconds") is not None: + offset += datetime.timedelta( + seconds=int(parsed_result.group("seconds"))) + + if parsed_result.group("prefix") == "+": + return datetime.datetime.utcnow() + offset + else: + return datetime.datetime.utcnow() - offset + + +def get_relative_time_option(input_string, input_name, backend='cryptography'): + """Return an absolute timespec if a relative timespec or an ASN1 formatted + string is provided. + + The return value will be a datetime object for the cryptography backend, + and a ASN1 formatted string for the pyopenssl backend.""" + result = to_native(input_string) + if result is None: + raise OpenSSLObjectError( + 'The timespec "%s" for %s is not valid' % + input_string, input_name) + # Relative time + if result.startswith("+") or result.startswith("-"): + result_datetime = convert_relative_to_datetime(result) + if backend == 'pyopenssl': + return result_datetime.strftime("%Y%m%d%H%M%SZ") + elif backend == 'cryptography': + return result_datetime + # Absolute time + if backend == 'pyopenssl': + return input_string + elif backend == 'cryptography': + for date_fmt in ['%Y%m%d%H%M%SZ', '%Y%m%d%H%MZ', '%Y%m%d%H%M%S%z', '%Y%m%d%H%M%z']: + try: + return datetime.datetime.strptime(result, date_fmt) + except ValueError: + pass + + raise OpenSSLObjectError( + 'The time spec "%s" for %s is invalid' % + (input_string, input_name) + ) + + +def select_message_digest(digest_string): + digest = None + if digest_string == 'sha256': + digest = hashes.SHA256() + elif digest_string == 'sha384': + digest = hashes.SHA384() + elif digest_string == 'sha512': + digest = hashes.SHA512() + elif digest_string == 'sha1': + digest = hashes.SHA1() + elif digest_string == 'md5': + digest = hashes.MD5() + return digest + + +def write_file(module, content, default_mode=None, path=None): + ''' + Writes content into destination file as securely as possible. + Uses file arguments from module. + ''' + # Find out parameters for file + file_args = module.load_file_common_arguments(module.params, path=path) + if file_args['mode'] is None: + file_args['mode'] = default_mode + # Create tempfile name + tmp_fd, tmp_name = tempfile.mkstemp(prefix=b'.ansible_tmp') + try: + os.close(tmp_fd) + except Exception as dummy: + pass + module.add_cleanup_file(tmp_name) # if we fail, let Ansible try to remove the file + try: + try: + # Create tempfile + file = os.open(tmp_name, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + os.write(file, content) + os.close(file) + except Exception as e: + try: + os.remove(tmp_name) + except Exception as dummy: + pass + module.fail_json(msg='Error while writing result into temporary file: {0}'.format(e)) + # Update destination to wanted permissions + if os.path.exists(file_args['path']): + module.set_fs_attributes_if_different(file_args, False) + # Move tempfile to final destination + module.atomic_move(tmp_name, file_args['path']) + # Try to update permissions again + module.set_fs_attributes_if_different(file_args, False) + except Exception as e: + try: + os.remove(tmp_name) + except Exception as dummy: + pass + module.fail_json(msg='Error while writing result: {0}'.format(e)) + + +@six.add_metaclass(abc.ABCMeta) +class OpenSSLObject(object): + + def __init__(self, path, state, force, check_mode): + self.path = path + self.state = state + self.force = force + self.name = os.path.basename(path) + self.changed = False + self.check_mode = check_mode + + def check(self, module, perms_required=True): + """Ensure the resource is in its desired state.""" + + def _check_state(): + return os.path.exists(self.path) + + def _check_perms(module): + file_args = module.load_file_common_arguments(module.params) + return not module.set_fs_attributes_if_different(file_args, False) + + if not perms_required: + return _check_state() + + return _check_state() and _check_perms(module) + + @abc.abstractmethod + def dump(self): + """Serialize the object into a dictionary.""" + + pass + + @abc.abstractmethod + def generate(self): + """Generate the resource.""" + + pass + + def remove(self, module): + """Remove the resource from the filesystem.""" + + try: + os.remove(self.path) + self.changed = True + except OSError as exc: + if exc.errno != errno.ENOENT: + raise OpenSSLObjectError(exc) + else: + pass + + +# ##################################################################################### +# ##################################################################################### +# This has been extracted from the OpenSSL project's objects.txt: +# https://github.com/openssl/openssl/blob/9537fe5757bb07761fa275d779bbd40bcf5530e4/crypto/objects/objects.txt +# Extracted with https://gist.github.com/felixfontein/376748017ad65ead093d56a45a5bf376 +# +# In case the following data structure has any copyrightable content, note that it is licensed as follows: +# Copyright (c) the OpenSSL contributors +# Licensed under the Apache License 2.0 +# https://github.com/openssl/openssl/blob/master/LICENSE +_OID_MAP = { + '0': ('itu-t', 'ITU-T', 'ccitt'), + '0.3.4401.5': ('ntt-ds', ), + '0.3.4401.5.3.1.9': ('camellia', ), + '0.3.4401.5.3.1.9.1': ('camellia-128-ecb', 'CAMELLIA-128-ECB'), + '0.3.4401.5.3.1.9.3': ('camellia-128-ofb', 'CAMELLIA-128-OFB'), + '0.3.4401.5.3.1.9.4': ('camellia-128-cfb', 'CAMELLIA-128-CFB'), + '0.3.4401.5.3.1.9.6': ('camellia-128-gcm', 'CAMELLIA-128-GCM'), + '0.3.4401.5.3.1.9.7': ('camellia-128-ccm', 'CAMELLIA-128-CCM'), + '0.3.4401.5.3.1.9.9': ('camellia-128-ctr', 'CAMELLIA-128-CTR'), + '0.3.4401.5.3.1.9.10': ('camellia-128-cmac', 'CAMELLIA-128-CMAC'), + '0.3.4401.5.3.1.9.21': ('camellia-192-ecb', 'CAMELLIA-192-ECB'), + '0.3.4401.5.3.1.9.23': ('camellia-192-ofb', 'CAMELLIA-192-OFB'), + '0.3.4401.5.3.1.9.24': ('camellia-192-cfb', 'CAMELLIA-192-CFB'), + '0.3.4401.5.3.1.9.26': ('camellia-192-gcm', 'CAMELLIA-192-GCM'), + '0.3.4401.5.3.1.9.27': ('camellia-192-ccm', 'CAMELLIA-192-CCM'), + '0.3.4401.5.3.1.9.29': ('camellia-192-ctr', 'CAMELLIA-192-CTR'), + '0.3.4401.5.3.1.9.30': ('camellia-192-cmac', 'CAMELLIA-192-CMAC'), + '0.3.4401.5.3.1.9.41': ('camellia-256-ecb', 'CAMELLIA-256-ECB'), + '0.3.4401.5.3.1.9.43': ('camellia-256-ofb', 'CAMELLIA-256-OFB'), + '0.3.4401.5.3.1.9.44': ('camellia-256-cfb', 'CAMELLIA-256-CFB'), + '0.3.4401.5.3.1.9.46': ('camellia-256-gcm', 'CAMELLIA-256-GCM'), + '0.3.4401.5.3.1.9.47': ('camellia-256-ccm', 'CAMELLIA-256-CCM'), + '0.3.4401.5.3.1.9.49': ('camellia-256-ctr', 'CAMELLIA-256-CTR'), + '0.3.4401.5.3.1.9.50': ('camellia-256-cmac', 'CAMELLIA-256-CMAC'), + '0.9': ('data', ), + '0.9.2342': ('pss', ), + '0.9.2342.19200300': ('ucl', ), + '0.9.2342.19200300.100': ('pilot', ), + '0.9.2342.19200300.100.1': ('pilotAttributeType', ), + '0.9.2342.19200300.100.1.1': ('userId', 'UID'), + '0.9.2342.19200300.100.1.2': ('textEncodedORAddress', ), + '0.9.2342.19200300.100.1.3': ('rfc822Mailbox', 'mail'), + '0.9.2342.19200300.100.1.4': ('info', ), + '0.9.2342.19200300.100.1.5': ('favouriteDrink', ), + '0.9.2342.19200300.100.1.6': ('roomNumber', ), + '0.9.2342.19200300.100.1.7': ('photo', ), + '0.9.2342.19200300.100.1.8': ('userClass', ), + '0.9.2342.19200300.100.1.9': ('host', ), + '0.9.2342.19200300.100.1.10': ('manager', ), + '0.9.2342.19200300.100.1.11': ('documentIdentifier', ), + '0.9.2342.19200300.100.1.12': ('documentTitle', ), + '0.9.2342.19200300.100.1.13': ('documentVersion', ), + '0.9.2342.19200300.100.1.14': ('documentAuthor', ), + '0.9.2342.19200300.100.1.15': ('documentLocation', ), + '0.9.2342.19200300.100.1.20': ('homeTelephoneNumber', ), + '0.9.2342.19200300.100.1.21': ('secretary', ), + '0.9.2342.19200300.100.1.22': ('otherMailbox', ), + '0.9.2342.19200300.100.1.23': ('lastModifiedTime', ), + '0.9.2342.19200300.100.1.24': ('lastModifiedBy', ), + '0.9.2342.19200300.100.1.25': ('domainComponent', 'DC'), + '0.9.2342.19200300.100.1.26': ('aRecord', ), + '0.9.2342.19200300.100.1.27': ('pilotAttributeType27', ), + '0.9.2342.19200300.100.1.28': ('mXRecord', ), + '0.9.2342.19200300.100.1.29': ('nSRecord', ), + '0.9.2342.19200300.100.1.30': ('sOARecord', ), + '0.9.2342.19200300.100.1.31': ('cNAMERecord', ), + '0.9.2342.19200300.100.1.37': ('associatedDomain', ), + '0.9.2342.19200300.100.1.38': ('associatedName', ), + '0.9.2342.19200300.100.1.39': ('homePostalAddress', ), + '0.9.2342.19200300.100.1.40': ('personalTitle', ), + '0.9.2342.19200300.100.1.41': ('mobileTelephoneNumber', ), + '0.9.2342.19200300.100.1.42': ('pagerTelephoneNumber', ), + '0.9.2342.19200300.100.1.43': ('friendlyCountryName', ), + '0.9.2342.19200300.100.1.44': ('uniqueIdentifier', 'uid'), + '0.9.2342.19200300.100.1.45': ('organizationalStatus', ), + '0.9.2342.19200300.100.1.46': ('janetMailbox', ), + '0.9.2342.19200300.100.1.47': ('mailPreferenceOption', ), + '0.9.2342.19200300.100.1.48': ('buildingName', ), + '0.9.2342.19200300.100.1.49': ('dSAQuality', ), + '0.9.2342.19200300.100.1.50': ('singleLevelQuality', ), + '0.9.2342.19200300.100.1.51': ('subtreeMinimumQuality', ), + '0.9.2342.19200300.100.1.52': ('subtreeMaximumQuality', ), + '0.9.2342.19200300.100.1.53': ('personalSignature', ), + '0.9.2342.19200300.100.1.54': ('dITRedirect', ), + '0.9.2342.19200300.100.1.55': ('audio', ), + '0.9.2342.19200300.100.1.56': ('documentPublisher', ), + '0.9.2342.19200300.100.3': ('pilotAttributeSyntax', ), + '0.9.2342.19200300.100.3.4': ('iA5StringSyntax', ), + '0.9.2342.19200300.100.3.5': ('caseIgnoreIA5StringSyntax', ), + '0.9.2342.19200300.100.4': ('pilotObjectClass', ), + '0.9.2342.19200300.100.4.3': ('pilotObject', ), + '0.9.2342.19200300.100.4.4': ('pilotPerson', ), + '0.9.2342.19200300.100.4.5': ('account', ), + '0.9.2342.19200300.100.4.6': ('document', ), + '0.9.2342.19200300.100.4.7': ('room', ), + '0.9.2342.19200300.100.4.9': ('documentSeries', ), + '0.9.2342.19200300.100.4.13': ('Domain', 'domain'), + '0.9.2342.19200300.100.4.14': ('rFC822localPart', ), + '0.9.2342.19200300.100.4.15': ('dNSDomain', ), + '0.9.2342.19200300.100.4.17': ('domainRelatedObject', ), + '0.9.2342.19200300.100.4.18': ('friendlyCountry', ), + '0.9.2342.19200300.100.4.19': ('simpleSecurityObject', ), + '0.9.2342.19200300.100.4.20': ('pilotOrganization', ), + '0.9.2342.19200300.100.4.21': ('pilotDSA', ), + '0.9.2342.19200300.100.4.22': ('qualityLabelledData', ), + '0.9.2342.19200300.100.10': ('pilotGroups', ), + '1': ('iso', 'ISO'), + '1.0.9797.3.4': ('gmac', 'GMAC'), + '1.0.10118.3.0.55': ('whirlpool', ), + '1.2': ('ISO Member Body', 'member-body'), + '1.2.156': ('ISO CN Member Body', 'ISO-CN'), + '1.2.156.10197': ('oscca', ), + '1.2.156.10197.1': ('sm-scheme', ), + '1.2.156.10197.1.104.1': ('sm4-ecb', 'SM4-ECB'), + '1.2.156.10197.1.104.2': ('sm4-cbc', 'SM4-CBC'), + '1.2.156.10197.1.104.3': ('sm4-ofb', 'SM4-OFB'), + '1.2.156.10197.1.104.4': ('sm4-cfb', 'SM4-CFB'), + '1.2.156.10197.1.104.5': ('sm4-cfb1', 'SM4-CFB1'), + '1.2.156.10197.1.104.6': ('sm4-cfb8', 'SM4-CFB8'), + '1.2.156.10197.1.104.7': ('sm4-ctr', 'SM4-CTR'), + '1.2.156.10197.1.301': ('sm2', 'SM2'), + '1.2.156.10197.1.401': ('sm3', 'SM3'), + '1.2.156.10197.1.501': ('SM2-with-SM3', 'SM2-SM3'), + '1.2.156.10197.1.504': ('sm3WithRSAEncryption', 'RSA-SM3'), + '1.2.392.200011.61.1.1.1.2': ('camellia-128-cbc', 'CAMELLIA-128-CBC'), + '1.2.392.200011.61.1.1.1.3': ('camellia-192-cbc', 'CAMELLIA-192-CBC'), + '1.2.392.200011.61.1.1.1.4': ('camellia-256-cbc', 'CAMELLIA-256-CBC'), + '1.2.392.200011.61.1.1.3.2': ('id-camellia128-wrap', ), + '1.2.392.200011.61.1.1.3.3': ('id-camellia192-wrap', ), + '1.2.392.200011.61.1.1.3.4': ('id-camellia256-wrap', ), + '1.2.410.200004': ('kisa', 'KISA'), + '1.2.410.200004.1.3': ('seed-ecb', 'SEED-ECB'), + '1.2.410.200004.1.4': ('seed-cbc', 'SEED-CBC'), + '1.2.410.200004.1.5': ('seed-cfb', 'SEED-CFB'), + '1.2.410.200004.1.6': ('seed-ofb', 'SEED-OFB'), + '1.2.410.200046.1.1': ('aria', ), + '1.2.410.200046.1.1.1': ('aria-128-ecb', 'ARIA-128-ECB'), + '1.2.410.200046.1.1.2': ('aria-128-cbc', 'ARIA-128-CBC'), + '1.2.410.200046.1.1.3': ('aria-128-cfb', 'ARIA-128-CFB'), + '1.2.410.200046.1.1.4': ('aria-128-ofb', 'ARIA-128-OFB'), + '1.2.410.200046.1.1.5': ('aria-128-ctr', 'ARIA-128-CTR'), + '1.2.410.200046.1.1.6': ('aria-192-ecb', 'ARIA-192-ECB'), + '1.2.410.200046.1.1.7': ('aria-192-cbc', 'ARIA-192-CBC'), + '1.2.410.200046.1.1.8': ('aria-192-cfb', 'ARIA-192-CFB'), + '1.2.410.200046.1.1.9': ('aria-192-ofb', 'ARIA-192-OFB'), + '1.2.410.200046.1.1.10': ('aria-192-ctr', 'ARIA-192-CTR'), + '1.2.410.200046.1.1.11': ('aria-256-ecb', 'ARIA-256-ECB'), + '1.2.410.200046.1.1.12': ('aria-256-cbc', 'ARIA-256-CBC'), + '1.2.410.200046.1.1.13': ('aria-256-cfb', 'ARIA-256-CFB'), + '1.2.410.200046.1.1.14': ('aria-256-ofb', 'ARIA-256-OFB'), + '1.2.410.200046.1.1.15': ('aria-256-ctr', 'ARIA-256-CTR'), + '1.2.410.200046.1.1.34': ('aria-128-gcm', 'ARIA-128-GCM'), + '1.2.410.200046.1.1.35': ('aria-192-gcm', 'ARIA-192-GCM'), + '1.2.410.200046.1.1.36': ('aria-256-gcm', 'ARIA-256-GCM'), + '1.2.410.200046.1.1.37': ('aria-128-ccm', 'ARIA-128-CCM'), + '1.2.410.200046.1.1.38': ('aria-192-ccm', 'ARIA-192-CCM'), + '1.2.410.200046.1.1.39': ('aria-256-ccm', 'ARIA-256-CCM'), + '1.2.643.2.2': ('cryptopro', ), + '1.2.643.2.2.3': ('GOST R 34.11-94 with GOST R 34.10-2001', 'id-GostR3411-94-with-GostR3410-2001'), + '1.2.643.2.2.4': ('GOST R 34.11-94 with GOST R 34.10-94', 'id-GostR3411-94-with-GostR3410-94'), + '1.2.643.2.2.9': ('GOST R 34.11-94', 'md_gost94'), + '1.2.643.2.2.10': ('HMAC GOST 34.11-94', 'id-HMACGostR3411-94'), + '1.2.643.2.2.14.0': ('id-Gost28147-89-None-KeyMeshing', ), + '1.2.643.2.2.14.1': ('id-Gost28147-89-CryptoPro-KeyMeshing', ), + '1.2.643.2.2.19': ('GOST R 34.10-2001', 'gost2001'), + '1.2.643.2.2.20': ('GOST R 34.10-94', 'gost94'), + '1.2.643.2.2.20.1': ('id-GostR3410-94-a', ), + '1.2.643.2.2.20.2': ('id-GostR3410-94-aBis', ), + '1.2.643.2.2.20.3': ('id-GostR3410-94-b', ), + '1.2.643.2.2.20.4': ('id-GostR3410-94-bBis', ), + '1.2.643.2.2.21': ('GOST 28147-89', 'gost89'), + '1.2.643.2.2.22': ('GOST 28147-89 MAC', 'gost-mac'), + '1.2.643.2.2.23': ('GOST R 34.11-94 PRF', 'prf-gostr3411-94'), + '1.2.643.2.2.30.0': ('id-GostR3411-94-TestParamSet', ), + '1.2.643.2.2.30.1': ('id-GostR3411-94-CryptoProParamSet', ), + '1.2.643.2.2.31.0': ('id-Gost28147-89-TestParamSet', ), + '1.2.643.2.2.31.1': ('id-Gost28147-89-CryptoPro-A-ParamSet', ), + '1.2.643.2.2.31.2': ('id-Gost28147-89-CryptoPro-B-ParamSet', ), + '1.2.643.2.2.31.3': ('id-Gost28147-89-CryptoPro-C-ParamSet', ), + '1.2.643.2.2.31.4': ('id-Gost28147-89-CryptoPro-D-ParamSet', ), + '1.2.643.2.2.31.5': ('id-Gost28147-89-CryptoPro-Oscar-1-1-ParamSet', ), + '1.2.643.2.2.31.6': ('id-Gost28147-89-CryptoPro-Oscar-1-0-ParamSet', ), + '1.2.643.2.2.31.7': ('id-Gost28147-89-CryptoPro-RIC-1-ParamSet', ), + '1.2.643.2.2.32.0': ('id-GostR3410-94-TestParamSet', ), + '1.2.643.2.2.32.2': ('id-GostR3410-94-CryptoPro-A-ParamSet', ), + '1.2.643.2.2.32.3': ('id-GostR3410-94-CryptoPro-B-ParamSet', ), + '1.2.643.2.2.32.4': ('id-GostR3410-94-CryptoPro-C-ParamSet', ), + '1.2.643.2.2.32.5': ('id-GostR3410-94-CryptoPro-D-ParamSet', ), + '1.2.643.2.2.33.1': ('id-GostR3410-94-CryptoPro-XchA-ParamSet', ), + '1.2.643.2.2.33.2': ('id-GostR3410-94-CryptoPro-XchB-ParamSet', ), + '1.2.643.2.2.33.3': ('id-GostR3410-94-CryptoPro-XchC-ParamSet', ), + '1.2.643.2.2.35.0': ('id-GostR3410-2001-TestParamSet', ), + '1.2.643.2.2.35.1': ('id-GostR3410-2001-CryptoPro-A-ParamSet', ), + '1.2.643.2.2.35.2': ('id-GostR3410-2001-CryptoPro-B-ParamSet', ), + '1.2.643.2.2.35.3': ('id-GostR3410-2001-CryptoPro-C-ParamSet', ), + '1.2.643.2.2.36.0': ('id-GostR3410-2001-CryptoPro-XchA-ParamSet', ), + '1.2.643.2.2.36.1': ('id-GostR3410-2001-CryptoPro-XchB-ParamSet', ), + '1.2.643.2.2.98': ('GOST R 34.10-2001 DH', 'id-GostR3410-2001DH'), + '1.2.643.2.2.99': ('GOST R 34.10-94 DH', 'id-GostR3410-94DH'), + '1.2.643.2.9': ('cryptocom', ), + '1.2.643.2.9.1.3.3': ('GOST R 34.11-94 with GOST R 34.10-94 Cryptocom', 'id-GostR3411-94-with-GostR3410-94-cc'), + '1.2.643.2.9.1.3.4': ('GOST R 34.11-94 with GOST R 34.10-2001 Cryptocom', 'id-GostR3411-94-with-GostR3410-2001-cc'), + '1.2.643.2.9.1.5.3': ('GOST 34.10-94 Cryptocom', 'gost94cc'), + '1.2.643.2.9.1.5.4': ('GOST 34.10-2001 Cryptocom', 'gost2001cc'), + '1.2.643.2.9.1.6.1': ('GOST 28147-89 Cryptocom ParamSet', 'id-Gost28147-89-cc'), + '1.2.643.2.9.1.8.1': ('GOST R 3410-2001 Parameter Set Cryptocom', 'id-GostR3410-2001-ParamSet-cc'), + '1.2.643.3.131.1.1': ('INN', 'INN'), + '1.2.643.7.1': ('id-tc26', ), + '1.2.643.7.1.1': ('id-tc26-algorithms', ), + '1.2.643.7.1.1.1': ('id-tc26-sign', ), + '1.2.643.7.1.1.1.1': ('GOST R 34.10-2012 with 256 bit modulus', 'gost2012_256'), + '1.2.643.7.1.1.1.2': ('GOST R 34.10-2012 with 512 bit modulus', 'gost2012_512'), + '1.2.643.7.1.1.2': ('id-tc26-digest', ), + '1.2.643.7.1.1.2.2': ('GOST R 34.11-2012 with 256 bit hash', 'md_gost12_256'), + '1.2.643.7.1.1.2.3': ('GOST R 34.11-2012 with 512 bit hash', 'md_gost12_512'), + '1.2.643.7.1.1.3': ('id-tc26-signwithdigest', ), + '1.2.643.7.1.1.3.2': ('GOST R 34.10-2012 with GOST R 34.11-2012 (256 bit)', 'id-tc26-signwithdigest-gost3410-2012-256'), + '1.2.643.7.1.1.3.3': ('GOST R 34.10-2012 with GOST R 34.11-2012 (512 bit)', 'id-tc26-signwithdigest-gost3410-2012-512'), + '1.2.643.7.1.1.4': ('id-tc26-mac', ), + '1.2.643.7.1.1.4.1': ('HMAC GOST 34.11-2012 256 bit', 'id-tc26-hmac-gost-3411-2012-256'), + '1.2.643.7.1.1.4.2': ('HMAC GOST 34.11-2012 512 bit', 'id-tc26-hmac-gost-3411-2012-512'), + '1.2.643.7.1.1.5': ('id-tc26-cipher', ), + '1.2.643.7.1.1.5.1': ('id-tc26-cipher-gostr3412-2015-magma', ), + '1.2.643.7.1.1.5.1.1': ('id-tc26-cipher-gostr3412-2015-magma-ctracpkm', ), + '1.2.643.7.1.1.5.1.2': ('id-tc26-cipher-gostr3412-2015-magma-ctracpkm-omac', ), + '1.2.643.7.1.1.5.2': ('id-tc26-cipher-gostr3412-2015-kuznyechik', ), + '1.2.643.7.1.1.5.2.1': ('id-tc26-cipher-gostr3412-2015-kuznyechik-ctracpkm', ), + '1.2.643.7.1.1.5.2.2': ('id-tc26-cipher-gostr3412-2015-kuznyechik-ctracpkm-omac', ), + '1.2.643.7.1.1.6': ('id-tc26-agreement', ), + '1.2.643.7.1.1.6.1': ('id-tc26-agreement-gost-3410-2012-256', ), + '1.2.643.7.1.1.6.2': ('id-tc26-agreement-gost-3410-2012-512', ), + '1.2.643.7.1.1.7': ('id-tc26-wrap', ), + '1.2.643.7.1.1.7.1': ('id-tc26-wrap-gostr3412-2015-magma', ), + '1.2.643.7.1.1.7.1.1': ('id-tc26-wrap-gostr3412-2015-magma-kexp15', 'id-tc26-wrap-gostr3412-2015-kuznyechik-kexp15'), + '1.2.643.7.1.1.7.2': ('id-tc26-wrap-gostr3412-2015-kuznyechik', ), + '1.2.643.7.1.2': ('id-tc26-constants', ), + '1.2.643.7.1.2.1': ('id-tc26-sign-constants', ), + '1.2.643.7.1.2.1.1': ('id-tc26-gost-3410-2012-256-constants', ), + '1.2.643.7.1.2.1.1.1': ('GOST R 34.10-2012 (256 bit) ParamSet A', 'id-tc26-gost-3410-2012-256-paramSetA'), + '1.2.643.7.1.2.1.1.2': ('GOST R 34.10-2012 (256 bit) ParamSet B', 'id-tc26-gost-3410-2012-256-paramSetB'), + '1.2.643.7.1.2.1.1.3': ('GOST R 34.10-2012 (256 bit) ParamSet C', 'id-tc26-gost-3410-2012-256-paramSetC'), + '1.2.643.7.1.2.1.1.4': ('GOST R 34.10-2012 (256 bit) ParamSet D', 'id-tc26-gost-3410-2012-256-paramSetD'), + '1.2.643.7.1.2.1.2': ('id-tc26-gost-3410-2012-512-constants', ), + '1.2.643.7.1.2.1.2.0': ('GOST R 34.10-2012 (512 bit) testing parameter set', 'id-tc26-gost-3410-2012-512-paramSetTest'), + '1.2.643.7.1.2.1.2.1': ('GOST R 34.10-2012 (512 bit) ParamSet A', 'id-tc26-gost-3410-2012-512-paramSetA'), + '1.2.643.7.1.2.1.2.2': ('GOST R 34.10-2012 (512 bit) ParamSet B', 'id-tc26-gost-3410-2012-512-paramSetB'), + '1.2.643.7.1.2.1.2.3': ('GOST R 34.10-2012 (512 bit) ParamSet C', 'id-tc26-gost-3410-2012-512-paramSetC'), + '1.2.643.7.1.2.2': ('id-tc26-digest-constants', ), + '1.2.643.7.1.2.5': ('id-tc26-cipher-constants', ), + '1.2.643.7.1.2.5.1': ('id-tc26-gost-28147-constants', ), + '1.2.643.7.1.2.5.1.1': ('GOST 28147-89 TC26 parameter set', 'id-tc26-gost-28147-param-Z'), + '1.2.643.100.1': ('OGRN', 'OGRN'), + '1.2.643.100.3': ('SNILS', 'SNILS'), + '1.2.643.100.111': ('Signing Tool of Subject', 'subjectSignTool'), + '1.2.643.100.112': ('Signing Tool of Issuer', 'issuerSignTool'), + '1.2.804': ('ISO-UA', ), + '1.2.804.2.1.1.1': ('ua-pki', ), + '1.2.804.2.1.1.1.1.1.1': ('DSTU Gost 28147-2009', 'dstu28147'), + '1.2.804.2.1.1.1.1.1.1.2': ('DSTU Gost 28147-2009 OFB mode', 'dstu28147-ofb'), + '1.2.804.2.1.1.1.1.1.1.3': ('DSTU Gost 28147-2009 CFB mode', 'dstu28147-cfb'), + '1.2.804.2.1.1.1.1.1.1.5': ('DSTU Gost 28147-2009 key wrap', 'dstu28147-wrap'), + '1.2.804.2.1.1.1.1.1.2': ('HMAC DSTU Gost 34311-95', 'hmacWithDstu34311'), + '1.2.804.2.1.1.1.1.2.1': ('DSTU Gost 34311-95', 'dstu34311'), + '1.2.804.2.1.1.1.1.3.1.1': ('DSTU 4145-2002 little endian', 'dstu4145le'), + '1.2.804.2.1.1.1.1.3.1.1.1.1': ('DSTU 4145-2002 big endian', 'dstu4145be'), + '1.2.804.2.1.1.1.1.3.1.1.2.0': ('DSTU curve 0', 'uacurve0'), + '1.2.804.2.1.1.1.1.3.1.1.2.1': ('DSTU curve 1', 'uacurve1'), + '1.2.804.2.1.1.1.1.3.1.1.2.2': ('DSTU curve 2', 'uacurve2'), + '1.2.804.2.1.1.1.1.3.1.1.2.3': ('DSTU curve 3', 'uacurve3'), + '1.2.804.2.1.1.1.1.3.1.1.2.4': ('DSTU curve 4', 'uacurve4'), + '1.2.804.2.1.1.1.1.3.1.1.2.5': ('DSTU curve 5', 'uacurve5'), + '1.2.804.2.1.1.1.1.3.1.1.2.6': ('DSTU curve 6', 'uacurve6'), + '1.2.804.2.1.1.1.1.3.1.1.2.7': ('DSTU curve 7', 'uacurve7'), + '1.2.804.2.1.1.1.1.3.1.1.2.8': ('DSTU curve 8', 'uacurve8'), + '1.2.804.2.1.1.1.1.3.1.1.2.9': ('DSTU curve 9', 'uacurve9'), + '1.2.840': ('ISO US Member Body', 'ISO-US'), + '1.2.840.10040': ('X9.57', 'X9-57'), + '1.2.840.10040.2': ('holdInstruction', ), + '1.2.840.10040.2.1': ('Hold Instruction None', 'holdInstructionNone'), + '1.2.840.10040.2.2': ('Hold Instruction Call Issuer', 'holdInstructionCallIssuer'), + '1.2.840.10040.2.3': ('Hold Instruction Reject', 'holdInstructionReject'), + '1.2.840.10040.4': ('X9.57 CM ?', 'X9cm'), + '1.2.840.10040.4.1': ('dsaEncryption', 'DSA'), + '1.2.840.10040.4.3': ('dsaWithSHA1', 'DSA-SHA1'), + '1.2.840.10045': ('ANSI X9.62', 'ansi-X9-62'), + '1.2.840.10045.1': ('id-fieldType', ), + '1.2.840.10045.1.1': ('prime-field', ), + '1.2.840.10045.1.2': ('characteristic-two-field', ), + '1.2.840.10045.1.2.3': ('id-characteristic-two-basis', ), + '1.2.840.10045.1.2.3.1': ('onBasis', ), + '1.2.840.10045.1.2.3.2': ('tpBasis', ), + '1.2.840.10045.1.2.3.3': ('ppBasis', ), + '1.2.840.10045.2': ('id-publicKeyType', ), + '1.2.840.10045.2.1': ('id-ecPublicKey', ), + '1.2.840.10045.3': ('ellipticCurve', ), + '1.2.840.10045.3.0': ('c-TwoCurve', ), + '1.2.840.10045.3.0.1': ('c2pnb163v1', ), + '1.2.840.10045.3.0.2': ('c2pnb163v2', ), + '1.2.840.10045.3.0.3': ('c2pnb163v3', ), + '1.2.840.10045.3.0.4': ('c2pnb176v1', ), + '1.2.840.10045.3.0.5': ('c2tnb191v1', ), + '1.2.840.10045.3.0.6': ('c2tnb191v2', ), + '1.2.840.10045.3.0.7': ('c2tnb191v3', ), + '1.2.840.10045.3.0.8': ('c2onb191v4', ), + '1.2.840.10045.3.0.9': ('c2onb191v5', ), + '1.2.840.10045.3.0.10': ('c2pnb208w1', ), + '1.2.840.10045.3.0.11': ('c2tnb239v1', ), + '1.2.840.10045.3.0.12': ('c2tnb239v2', ), + '1.2.840.10045.3.0.13': ('c2tnb239v3', ), + '1.2.840.10045.3.0.14': ('c2onb239v4', ), + '1.2.840.10045.3.0.15': ('c2onb239v5', ), + '1.2.840.10045.3.0.16': ('c2pnb272w1', ), + '1.2.840.10045.3.0.17': ('c2pnb304w1', ), + '1.2.840.10045.3.0.18': ('c2tnb359v1', ), + '1.2.840.10045.3.0.19': ('c2pnb368w1', ), + '1.2.840.10045.3.0.20': ('c2tnb431r1', ), + '1.2.840.10045.3.1': ('primeCurve', ), + '1.2.840.10045.3.1.1': ('prime192v1', ), + '1.2.840.10045.3.1.2': ('prime192v2', ), + '1.2.840.10045.3.1.3': ('prime192v3', ), + '1.2.840.10045.3.1.4': ('prime239v1', ), + '1.2.840.10045.3.1.5': ('prime239v2', ), + '1.2.840.10045.3.1.6': ('prime239v3', ), + '1.2.840.10045.3.1.7': ('prime256v1', ), + '1.2.840.10045.4': ('id-ecSigType', ), + '1.2.840.10045.4.1': ('ecdsa-with-SHA1', ), + '1.2.840.10045.4.2': ('ecdsa-with-Recommended', ), + '1.2.840.10045.4.3': ('ecdsa-with-Specified', ), + '1.2.840.10045.4.3.1': ('ecdsa-with-SHA224', ), + '1.2.840.10045.4.3.2': ('ecdsa-with-SHA256', ), + '1.2.840.10045.4.3.3': ('ecdsa-with-SHA384', ), + '1.2.840.10045.4.3.4': ('ecdsa-with-SHA512', ), + '1.2.840.10046.2.1': ('X9.42 DH', 'dhpublicnumber'), + '1.2.840.113533.7.66.10': ('cast5-cbc', 'CAST5-CBC'), + '1.2.840.113533.7.66.12': ('pbeWithMD5AndCast5CBC', ), + '1.2.840.113533.7.66.13': ('password based MAC', 'id-PasswordBasedMAC'), + '1.2.840.113533.7.66.30': ('Diffie-Hellman based MAC', 'id-DHBasedMac'), + '1.2.840.113549': ('RSA Data Security, Inc.', 'rsadsi'), + '1.2.840.113549.1': ('RSA Data Security, Inc. PKCS', 'pkcs'), + '1.2.840.113549.1.1': ('pkcs1', ), + '1.2.840.113549.1.1.1': ('rsaEncryption', ), + '1.2.840.113549.1.1.2': ('md2WithRSAEncryption', 'RSA-MD2'), + '1.2.840.113549.1.1.3': ('md4WithRSAEncryption', 'RSA-MD4'), + '1.2.840.113549.1.1.4': ('md5WithRSAEncryption', 'RSA-MD5'), + '1.2.840.113549.1.1.5': ('sha1WithRSAEncryption', 'RSA-SHA1'), + '1.2.840.113549.1.1.6': ('rsaOAEPEncryptionSET', ), + '1.2.840.113549.1.1.7': ('rsaesOaep', 'RSAES-OAEP'), + '1.2.840.113549.1.1.8': ('mgf1', 'MGF1'), + '1.2.840.113549.1.1.9': ('pSpecified', 'PSPECIFIED'), + '1.2.840.113549.1.1.10': ('rsassaPss', 'RSASSA-PSS'), + '1.2.840.113549.1.1.11': ('sha256WithRSAEncryption', 'RSA-SHA256'), + '1.2.840.113549.1.1.12': ('sha384WithRSAEncryption', 'RSA-SHA384'), + '1.2.840.113549.1.1.13': ('sha512WithRSAEncryption', 'RSA-SHA512'), + '1.2.840.113549.1.1.14': ('sha224WithRSAEncryption', 'RSA-SHA224'), + '1.2.840.113549.1.1.15': ('sha512-224WithRSAEncryption', 'RSA-SHA512/224'), + '1.2.840.113549.1.1.16': ('sha512-256WithRSAEncryption', 'RSA-SHA512/256'), + '1.2.840.113549.1.3': ('pkcs3', ), + '1.2.840.113549.1.3.1': ('dhKeyAgreement', ), + '1.2.840.113549.1.5': ('pkcs5', ), + '1.2.840.113549.1.5.1': ('pbeWithMD2AndDES-CBC', 'PBE-MD2-DES'), + '1.2.840.113549.1.5.3': ('pbeWithMD5AndDES-CBC', 'PBE-MD5-DES'), + '1.2.840.113549.1.5.4': ('pbeWithMD2AndRC2-CBC', 'PBE-MD2-RC2-64'), + '1.2.840.113549.1.5.6': ('pbeWithMD5AndRC2-CBC', 'PBE-MD5-RC2-64'), + '1.2.840.113549.1.5.10': ('pbeWithSHA1AndDES-CBC', 'PBE-SHA1-DES'), + '1.2.840.113549.1.5.11': ('pbeWithSHA1AndRC2-CBC', 'PBE-SHA1-RC2-64'), + '1.2.840.113549.1.5.12': ('PBKDF2', ), + '1.2.840.113549.1.5.13': ('PBES2', ), + '1.2.840.113549.1.5.14': ('PBMAC1', ), + '1.2.840.113549.1.7': ('pkcs7', ), + '1.2.840.113549.1.7.1': ('pkcs7-data', ), + '1.2.840.113549.1.7.2': ('pkcs7-signedData', ), + '1.2.840.113549.1.7.3': ('pkcs7-envelopedData', ), + '1.2.840.113549.1.7.4': ('pkcs7-signedAndEnvelopedData', ), + '1.2.840.113549.1.7.5': ('pkcs7-digestData', ), + '1.2.840.113549.1.7.6': ('pkcs7-encryptedData', ), + '1.2.840.113549.1.9': ('pkcs9', ), + '1.2.840.113549.1.9.1': ('emailAddress', ), + '1.2.840.113549.1.9.2': ('unstructuredName', ), + '1.2.840.113549.1.9.3': ('contentType', ), + '1.2.840.113549.1.9.4': ('messageDigest', ), + '1.2.840.113549.1.9.5': ('signingTime', ), + '1.2.840.113549.1.9.6': ('countersignature', ), + '1.2.840.113549.1.9.7': ('challengePassword', ), + '1.2.840.113549.1.9.8': ('unstructuredAddress', ), + '1.2.840.113549.1.9.9': ('extendedCertificateAttributes', ), + '1.2.840.113549.1.9.14': ('Extension Request', 'extReq'), + '1.2.840.113549.1.9.15': ('S/MIME Capabilities', 'SMIME-CAPS'), + '1.2.840.113549.1.9.16': ('S/MIME', 'SMIME'), + '1.2.840.113549.1.9.16.0': ('id-smime-mod', ), + '1.2.840.113549.1.9.16.0.1': ('id-smime-mod-cms', ), + '1.2.840.113549.1.9.16.0.2': ('id-smime-mod-ess', ), + '1.2.840.113549.1.9.16.0.3': ('id-smime-mod-oid', ), + '1.2.840.113549.1.9.16.0.4': ('id-smime-mod-msg-v3', ), + '1.2.840.113549.1.9.16.0.5': ('id-smime-mod-ets-eSignature-88', ), + '1.2.840.113549.1.9.16.0.6': ('id-smime-mod-ets-eSignature-97', ), + '1.2.840.113549.1.9.16.0.7': ('id-smime-mod-ets-eSigPolicy-88', ), + '1.2.840.113549.1.9.16.0.8': ('id-smime-mod-ets-eSigPolicy-97', ), + '1.2.840.113549.1.9.16.1': ('id-smime-ct', ), + '1.2.840.113549.1.9.16.1.1': ('id-smime-ct-receipt', ), + '1.2.840.113549.1.9.16.1.2': ('id-smime-ct-authData', ), + '1.2.840.113549.1.9.16.1.3': ('id-smime-ct-publishCert', ), + '1.2.840.113549.1.9.16.1.4': ('id-smime-ct-TSTInfo', ), + '1.2.840.113549.1.9.16.1.5': ('id-smime-ct-TDTInfo', ), + '1.2.840.113549.1.9.16.1.6': ('id-smime-ct-contentInfo', ), + '1.2.840.113549.1.9.16.1.7': ('id-smime-ct-DVCSRequestData', ), + '1.2.840.113549.1.9.16.1.8': ('id-smime-ct-DVCSResponseData', ), + '1.2.840.113549.1.9.16.1.9': ('id-smime-ct-compressedData', ), + '1.2.840.113549.1.9.16.1.19': ('id-smime-ct-contentCollection', ), + '1.2.840.113549.1.9.16.1.23': ('id-smime-ct-authEnvelopedData', ), + '1.2.840.113549.1.9.16.1.27': ('id-ct-asciiTextWithCRLF', ), + '1.2.840.113549.1.9.16.1.28': ('id-ct-xml', ), + '1.2.840.113549.1.9.16.2': ('id-smime-aa', ), + '1.2.840.113549.1.9.16.2.1': ('id-smime-aa-receiptRequest', ), + '1.2.840.113549.1.9.16.2.2': ('id-smime-aa-securityLabel', ), + '1.2.840.113549.1.9.16.2.3': ('id-smime-aa-mlExpandHistory', ), + '1.2.840.113549.1.9.16.2.4': ('id-smime-aa-contentHint', ), + '1.2.840.113549.1.9.16.2.5': ('id-smime-aa-msgSigDigest', ), + '1.2.840.113549.1.9.16.2.6': ('id-smime-aa-encapContentType', ), + '1.2.840.113549.1.9.16.2.7': ('id-smime-aa-contentIdentifier', ), + '1.2.840.113549.1.9.16.2.8': ('id-smime-aa-macValue', ), + '1.2.840.113549.1.9.16.2.9': ('id-smime-aa-equivalentLabels', ), + '1.2.840.113549.1.9.16.2.10': ('id-smime-aa-contentReference', ), + '1.2.840.113549.1.9.16.2.11': ('id-smime-aa-encrypKeyPref', ), + '1.2.840.113549.1.9.16.2.12': ('id-smime-aa-signingCertificate', ), + '1.2.840.113549.1.9.16.2.13': ('id-smime-aa-smimeEncryptCerts', ), + '1.2.840.113549.1.9.16.2.14': ('id-smime-aa-timeStampToken', ), + '1.2.840.113549.1.9.16.2.15': ('id-smime-aa-ets-sigPolicyId', ), + '1.2.840.113549.1.9.16.2.16': ('id-smime-aa-ets-commitmentType', ), + '1.2.840.113549.1.9.16.2.17': ('id-smime-aa-ets-signerLocation', ), + '1.2.840.113549.1.9.16.2.18': ('id-smime-aa-ets-signerAttr', ), + '1.2.840.113549.1.9.16.2.19': ('id-smime-aa-ets-otherSigCert', ), + '1.2.840.113549.1.9.16.2.20': ('id-smime-aa-ets-contentTimestamp', ), + '1.2.840.113549.1.9.16.2.21': ('id-smime-aa-ets-CertificateRefs', ), + '1.2.840.113549.1.9.16.2.22': ('id-smime-aa-ets-RevocationRefs', ), + '1.2.840.113549.1.9.16.2.23': ('id-smime-aa-ets-certValues', ), + '1.2.840.113549.1.9.16.2.24': ('id-smime-aa-ets-revocationValues', ), + '1.2.840.113549.1.9.16.2.25': ('id-smime-aa-ets-escTimeStamp', ), + '1.2.840.113549.1.9.16.2.26': ('id-smime-aa-ets-certCRLTimestamp', ), + '1.2.840.113549.1.9.16.2.27': ('id-smime-aa-ets-archiveTimeStamp', ), + '1.2.840.113549.1.9.16.2.28': ('id-smime-aa-signatureType', ), + '1.2.840.113549.1.9.16.2.29': ('id-smime-aa-dvcs-dvc', ), + '1.2.840.113549.1.9.16.2.47': ('id-smime-aa-signingCertificateV2', ), + '1.2.840.113549.1.9.16.3': ('id-smime-alg', ), + '1.2.840.113549.1.9.16.3.1': ('id-smime-alg-ESDHwith3DES', ), + '1.2.840.113549.1.9.16.3.2': ('id-smime-alg-ESDHwithRC2', ), + '1.2.840.113549.1.9.16.3.3': ('id-smime-alg-3DESwrap', ), + '1.2.840.113549.1.9.16.3.4': ('id-smime-alg-RC2wrap', ), + '1.2.840.113549.1.9.16.3.5': ('id-smime-alg-ESDH', ), + '1.2.840.113549.1.9.16.3.6': ('id-smime-alg-CMS3DESwrap', ), + '1.2.840.113549.1.9.16.3.7': ('id-smime-alg-CMSRC2wrap', ), + '1.2.840.113549.1.9.16.3.8': ('zlib compression', 'ZLIB'), + '1.2.840.113549.1.9.16.3.9': ('id-alg-PWRI-KEK', ), + '1.2.840.113549.1.9.16.4': ('id-smime-cd', ), + '1.2.840.113549.1.9.16.4.1': ('id-smime-cd-ldap', ), + '1.2.840.113549.1.9.16.5': ('id-smime-spq', ), + '1.2.840.113549.1.9.16.5.1': ('id-smime-spq-ets-sqt-uri', ), + '1.2.840.113549.1.9.16.5.2': ('id-smime-spq-ets-sqt-unotice', ), + '1.2.840.113549.1.9.16.6': ('id-smime-cti', ), + '1.2.840.113549.1.9.16.6.1': ('id-smime-cti-ets-proofOfOrigin', ), + '1.2.840.113549.1.9.16.6.2': ('id-smime-cti-ets-proofOfReceipt', ), + '1.2.840.113549.1.9.16.6.3': ('id-smime-cti-ets-proofOfDelivery', ), + '1.2.840.113549.1.9.16.6.4': ('id-smime-cti-ets-proofOfSender', ), + '1.2.840.113549.1.9.16.6.5': ('id-smime-cti-ets-proofOfApproval', ), + '1.2.840.113549.1.9.16.6.6': ('id-smime-cti-ets-proofOfCreation', ), + '1.2.840.113549.1.9.20': ('friendlyName', ), + '1.2.840.113549.1.9.21': ('localKeyID', ), + '1.2.840.113549.1.9.22': ('certTypes', ), + '1.2.840.113549.1.9.22.1': ('x509Certificate', ), + '1.2.840.113549.1.9.22.2': ('sdsiCertificate', ), + '1.2.840.113549.1.9.23': ('crlTypes', ), + '1.2.840.113549.1.9.23.1': ('x509Crl', ), + '1.2.840.113549.1.12': ('pkcs12', ), + '1.2.840.113549.1.12.1': ('pkcs12-pbeids', ), + '1.2.840.113549.1.12.1.1': ('pbeWithSHA1And128BitRC4', 'PBE-SHA1-RC4-128'), + '1.2.840.113549.1.12.1.2': ('pbeWithSHA1And40BitRC4', 'PBE-SHA1-RC4-40'), + '1.2.840.113549.1.12.1.3': ('pbeWithSHA1And3-KeyTripleDES-CBC', 'PBE-SHA1-3DES'), + '1.2.840.113549.1.12.1.4': ('pbeWithSHA1And2-KeyTripleDES-CBC', 'PBE-SHA1-2DES'), + '1.2.840.113549.1.12.1.5': ('pbeWithSHA1And128BitRC2-CBC', 'PBE-SHA1-RC2-128'), + '1.2.840.113549.1.12.1.6': ('pbeWithSHA1And40BitRC2-CBC', 'PBE-SHA1-RC2-40'), + '1.2.840.113549.1.12.10': ('pkcs12-Version1', ), + '1.2.840.113549.1.12.10.1': ('pkcs12-BagIds', ), + '1.2.840.113549.1.12.10.1.1': ('keyBag', ), + '1.2.840.113549.1.12.10.1.2': ('pkcs8ShroudedKeyBag', ), + '1.2.840.113549.1.12.10.1.3': ('certBag', ), + '1.2.840.113549.1.12.10.1.4': ('crlBag', ), + '1.2.840.113549.1.12.10.1.5': ('secretBag', ), + '1.2.840.113549.1.12.10.1.6': ('safeContentsBag', ), + '1.2.840.113549.2.2': ('md2', 'MD2'), + '1.2.840.113549.2.4': ('md4', 'MD4'), + '1.2.840.113549.2.5': ('md5', 'MD5'), + '1.2.840.113549.2.6': ('hmacWithMD5', ), + '1.2.840.113549.2.7': ('hmacWithSHA1', ), + '1.2.840.113549.2.8': ('hmacWithSHA224', ), + '1.2.840.113549.2.9': ('hmacWithSHA256', ), + '1.2.840.113549.2.10': ('hmacWithSHA384', ), + '1.2.840.113549.2.11': ('hmacWithSHA512', ), + '1.2.840.113549.2.12': ('hmacWithSHA512-224', ), + '1.2.840.113549.2.13': ('hmacWithSHA512-256', ), + '1.2.840.113549.3.2': ('rc2-cbc', 'RC2-CBC'), + '1.2.840.113549.3.4': ('rc4', 'RC4'), + '1.2.840.113549.3.7': ('des-ede3-cbc', 'DES-EDE3-CBC'), + '1.2.840.113549.3.8': ('rc5-cbc', 'RC5-CBC'), + '1.2.840.113549.3.10': ('des-cdmf', 'DES-CDMF'), + '1.3': ('identified-organization', 'org', 'ORG'), + '1.3.6': ('dod', 'DOD'), + '1.3.6.1': ('iana', 'IANA', 'internet'), + '1.3.6.1.1': ('Directory', 'directory'), + '1.3.6.1.2': ('Management', 'mgmt'), + '1.3.6.1.3': ('Experimental', 'experimental'), + '1.3.6.1.4': ('Private', 'private'), + '1.3.6.1.4.1': ('Enterprises', 'enterprises'), + '1.3.6.1.4.1.188.7.1.1.2': ('idea-cbc', 'IDEA-CBC'), + '1.3.6.1.4.1.311.2.1.14': ('Microsoft Extension Request', 'msExtReq'), + '1.3.6.1.4.1.311.2.1.21': ('Microsoft Individual Code Signing', 'msCodeInd'), + '1.3.6.1.4.1.311.2.1.22': ('Microsoft Commercial Code Signing', 'msCodeCom'), + '1.3.6.1.4.1.311.10.3.1': ('Microsoft Trust List Signing', 'msCTLSign'), + '1.3.6.1.4.1.311.10.3.3': ('Microsoft Server Gated Crypto', 'msSGC'), + '1.3.6.1.4.1.311.10.3.4': ('Microsoft Encrypted File System', 'msEFS'), + '1.3.6.1.4.1.311.17.1': ('Microsoft CSP Name', 'CSPName'), + '1.3.6.1.4.1.311.17.2': ('Microsoft Local Key set', 'LocalKeySet'), + '1.3.6.1.4.1.311.20.2.2': ('Microsoft Smartcardlogin', 'msSmartcardLogin'), + '1.3.6.1.4.1.311.20.2.3': ('Microsoft Universal Principal Name', 'msUPN'), + '1.3.6.1.4.1.311.60.2.1.1': ('jurisdictionLocalityName', 'jurisdictionL'), + '1.3.6.1.4.1.311.60.2.1.2': ('jurisdictionStateOrProvinceName', 'jurisdictionST'), + '1.3.6.1.4.1.311.60.2.1.3': ('jurisdictionCountryName', 'jurisdictionC'), + '1.3.6.1.4.1.1466.344': ('dcObject', 'dcobject'), + '1.3.6.1.4.1.1722.12.2.1.16': ('blake2b512', 'BLAKE2b512'), + '1.3.6.1.4.1.1722.12.2.2.8': ('blake2s256', 'BLAKE2s256'), + '1.3.6.1.4.1.3029.1.2': ('bf-cbc', 'BF-CBC'), + '1.3.6.1.4.1.11129.2.4.2': ('CT Precertificate SCTs', 'ct_precert_scts'), + '1.3.6.1.4.1.11129.2.4.3': ('CT Precertificate Poison', 'ct_precert_poison'), + '1.3.6.1.4.1.11129.2.4.4': ('CT Precertificate Signer', 'ct_precert_signer'), + '1.3.6.1.4.1.11129.2.4.5': ('CT Certificate SCTs', 'ct_cert_scts'), + '1.3.6.1.4.1.11591.4.11': ('scrypt', 'id-scrypt'), + '1.3.6.1.5': ('Security', 'security'), + '1.3.6.1.5.2.3': ('id-pkinit', ), + '1.3.6.1.5.2.3.4': ('PKINIT Client Auth', 'pkInitClientAuth'), + '1.3.6.1.5.2.3.5': ('Signing KDC Response', 'pkInitKDC'), + '1.3.6.1.5.5.7': ('PKIX', ), + '1.3.6.1.5.5.7.0': ('id-pkix-mod', ), + '1.3.6.1.5.5.7.0.1': ('id-pkix1-explicit-88', ), + '1.3.6.1.5.5.7.0.2': ('id-pkix1-implicit-88', ), + '1.3.6.1.5.5.7.0.3': ('id-pkix1-explicit-93', ), + '1.3.6.1.5.5.7.0.4': ('id-pkix1-implicit-93', ), + '1.3.6.1.5.5.7.0.5': ('id-mod-crmf', ), + '1.3.6.1.5.5.7.0.6': ('id-mod-cmc', ), + '1.3.6.1.5.5.7.0.7': ('id-mod-kea-profile-88', ), + '1.3.6.1.5.5.7.0.8': ('id-mod-kea-profile-93', ), + '1.3.6.1.5.5.7.0.9': ('id-mod-cmp', ), + '1.3.6.1.5.5.7.0.10': ('id-mod-qualified-cert-88', ), + '1.3.6.1.5.5.7.0.11': ('id-mod-qualified-cert-93', ), + '1.3.6.1.5.5.7.0.12': ('id-mod-attribute-cert', ), + '1.3.6.1.5.5.7.0.13': ('id-mod-timestamp-protocol', ), + '1.3.6.1.5.5.7.0.14': ('id-mod-ocsp', ), + '1.3.6.1.5.5.7.0.15': ('id-mod-dvcs', ), + '1.3.6.1.5.5.7.0.16': ('id-mod-cmp2000', ), + '1.3.6.1.5.5.7.1': ('id-pe', ), + '1.3.6.1.5.5.7.1.1': ('Authority Information Access', 'authorityInfoAccess'), + '1.3.6.1.5.5.7.1.2': ('Biometric Info', 'biometricInfo'), + '1.3.6.1.5.5.7.1.3': ('qcStatements', ), + '1.3.6.1.5.5.7.1.4': ('ac-auditEntity', ), + '1.3.6.1.5.5.7.1.5': ('ac-targeting', ), + '1.3.6.1.5.5.7.1.6': ('aaControls', ), + '1.3.6.1.5.5.7.1.7': ('sbgp-ipAddrBlock', ), + '1.3.6.1.5.5.7.1.8': ('sbgp-autonomousSysNum', ), + '1.3.6.1.5.5.7.1.9': ('sbgp-routerIdentifier', ), + '1.3.6.1.5.5.7.1.10': ('ac-proxying', ), + '1.3.6.1.5.5.7.1.11': ('Subject Information Access', 'subjectInfoAccess'), + '1.3.6.1.5.5.7.1.14': ('Proxy Certificate Information', 'proxyCertInfo'), + '1.3.6.1.5.5.7.1.24': ('TLS Feature', 'tlsfeature'), + '1.3.6.1.5.5.7.2': ('id-qt', ), + '1.3.6.1.5.5.7.2.1': ('Policy Qualifier CPS', 'id-qt-cps'), + '1.3.6.1.5.5.7.2.2': ('Policy Qualifier User Notice', 'id-qt-unotice'), + '1.3.6.1.5.5.7.2.3': ('textNotice', ), + '1.3.6.1.5.5.7.3': ('id-kp', ), + '1.3.6.1.5.5.7.3.1': ('TLS Web Server Authentication', 'serverAuth'), + '1.3.6.1.5.5.7.3.2': ('TLS Web Client Authentication', 'clientAuth'), + '1.3.6.1.5.5.7.3.3': ('Code Signing', 'codeSigning'), + '1.3.6.1.5.5.7.3.4': ('E-mail Protection', 'emailProtection'), + '1.3.6.1.5.5.7.3.5': ('IPSec End System', 'ipsecEndSystem'), + '1.3.6.1.5.5.7.3.6': ('IPSec Tunnel', 'ipsecTunnel'), + '1.3.6.1.5.5.7.3.7': ('IPSec User', 'ipsecUser'), + '1.3.6.1.5.5.7.3.8': ('Time Stamping', 'timeStamping'), + '1.3.6.1.5.5.7.3.9': ('OCSP Signing', 'OCSPSigning'), + '1.3.6.1.5.5.7.3.10': ('dvcs', 'DVCS'), + '1.3.6.1.5.5.7.3.17': ('ipsec Internet Key Exchange', 'ipsecIKE'), + '1.3.6.1.5.5.7.3.18': ('Ctrl/provision WAP Access', 'capwapAC'), + '1.3.6.1.5.5.7.3.19': ('Ctrl/Provision WAP Termination', 'capwapWTP'), + '1.3.6.1.5.5.7.3.21': ('SSH Client', 'secureShellClient'), + '1.3.6.1.5.5.7.3.22': ('SSH Server', 'secureShellServer'), + '1.3.6.1.5.5.7.3.23': ('Send Router', 'sendRouter'), + '1.3.6.1.5.5.7.3.24': ('Send Proxied Router', 'sendProxiedRouter'), + '1.3.6.1.5.5.7.3.25': ('Send Owner', 'sendOwner'), + '1.3.6.1.5.5.7.3.26': ('Send Proxied Owner', 'sendProxiedOwner'), + '1.3.6.1.5.5.7.3.27': ('CMC Certificate Authority', 'cmcCA'), + '1.3.6.1.5.5.7.3.28': ('CMC Registration Authority', 'cmcRA'), + '1.3.6.1.5.5.7.4': ('id-it', ), + '1.3.6.1.5.5.7.4.1': ('id-it-caProtEncCert', ), + '1.3.6.1.5.5.7.4.2': ('id-it-signKeyPairTypes', ), + '1.3.6.1.5.5.7.4.3': ('id-it-encKeyPairTypes', ), + '1.3.6.1.5.5.7.4.4': ('id-it-preferredSymmAlg', ), + '1.3.6.1.5.5.7.4.5': ('id-it-caKeyUpdateInfo', ), + '1.3.6.1.5.5.7.4.6': ('id-it-currentCRL', ), + '1.3.6.1.5.5.7.4.7': ('id-it-unsupportedOIDs', ), + '1.3.6.1.5.5.7.4.8': ('id-it-subscriptionRequest', ), + '1.3.6.1.5.5.7.4.9': ('id-it-subscriptionResponse', ), + '1.3.6.1.5.5.7.4.10': ('id-it-keyPairParamReq', ), + '1.3.6.1.5.5.7.4.11': ('id-it-keyPairParamRep', ), + '1.3.6.1.5.5.7.4.12': ('id-it-revPassphrase', ), + '1.3.6.1.5.5.7.4.13': ('id-it-implicitConfirm', ), + '1.3.6.1.5.5.7.4.14': ('id-it-confirmWaitTime', ), + '1.3.6.1.5.5.7.4.15': ('id-it-origPKIMessage', ), + '1.3.6.1.5.5.7.4.16': ('id-it-suppLangTags', ), + '1.3.6.1.5.5.7.5': ('id-pkip', ), + '1.3.6.1.5.5.7.5.1': ('id-regCtrl', ), + '1.3.6.1.5.5.7.5.1.1': ('id-regCtrl-regToken', ), + '1.3.6.1.5.5.7.5.1.2': ('id-regCtrl-authenticator', ), + '1.3.6.1.5.5.7.5.1.3': ('id-regCtrl-pkiPublicationInfo', ), + '1.3.6.1.5.5.7.5.1.4': ('id-regCtrl-pkiArchiveOptions', ), + '1.3.6.1.5.5.7.5.1.5': ('id-regCtrl-oldCertID', ), + '1.3.6.1.5.5.7.5.1.6': ('id-regCtrl-protocolEncrKey', ), + '1.3.6.1.5.5.7.5.2': ('id-regInfo', ), + '1.3.6.1.5.5.7.5.2.1': ('id-regInfo-utf8Pairs', ), + '1.3.6.1.5.5.7.5.2.2': ('id-regInfo-certReq', ), + '1.3.6.1.5.5.7.6': ('id-alg', ), + '1.3.6.1.5.5.7.6.1': ('id-alg-des40', ), + '1.3.6.1.5.5.7.6.2': ('id-alg-noSignature', ), + '1.3.6.1.5.5.7.6.3': ('id-alg-dh-sig-hmac-sha1', ), + '1.3.6.1.5.5.7.6.4': ('id-alg-dh-pop', ), + '1.3.6.1.5.5.7.7': ('id-cmc', ), + '1.3.6.1.5.5.7.7.1': ('id-cmc-statusInfo', ), + '1.3.6.1.5.5.7.7.2': ('id-cmc-identification', ), + '1.3.6.1.5.5.7.7.3': ('id-cmc-identityProof', ), + '1.3.6.1.5.5.7.7.4': ('id-cmc-dataReturn', ), + '1.3.6.1.5.5.7.7.5': ('id-cmc-transactionId', ), + '1.3.6.1.5.5.7.7.6': ('id-cmc-senderNonce', ), + '1.3.6.1.5.5.7.7.7': ('id-cmc-recipientNonce', ), + '1.3.6.1.5.5.7.7.8': ('id-cmc-addExtensions', ), + '1.3.6.1.5.5.7.7.9': ('id-cmc-encryptedPOP', ), + '1.3.6.1.5.5.7.7.10': ('id-cmc-decryptedPOP', ), + '1.3.6.1.5.5.7.7.11': ('id-cmc-lraPOPWitness', ), + '1.3.6.1.5.5.7.7.15': ('id-cmc-getCert', ), + '1.3.6.1.5.5.7.7.16': ('id-cmc-getCRL', ), + '1.3.6.1.5.5.7.7.17': ('id-cmc-revokeRequest', ), + '1.3.6.1.5.5.7.7.18': ('id-cmc-regInfo', ), + '1.3.6.1.5.5.7.7.19': ('id-cmc-responseInfo', ), + '1.3.6.1.5.5.7.7.21': ('id-cmc-queryPending', ), + '1.3.6.1.5.5.7.7.22': ('id-cmc-popLinkRandom', ), + '1.3.6.1.5.5.7.7.23': ('id-cmc-popLinkWitness', ), + '1.3.6.1.5.5.7.7.24': ('id-cmc-confirmCertAcceptance', ), + '1.3.6.1.5.5.7.8': ('id-on', ), + '1.3.6.1.5.5.7.8.1': ('id-on-personalData', ), + '1.3.6.1.5.5.7.8.3': ('Permanent Identifier', 'id-on-permanentIdentifier'), + '1.3.6.1.5.5.7.9': ('id-pda', ), + '1.3.6.1.5.5.7.9.1': ('id-pda-dateOfBirth', ), + '1.3.6.1.5.5.7.9.2': ('id-pda-placeOfBirth', ), + '1.3.6.1.5.5.7.9.3': ('id-pda-gender', ), + '1.3.6.1.5.5.7.9.4': ('id-pda-countryOfCitizenship', ), + '1.3.6.1.5.5.7.9.5': ('id-pda-countryOfResidence', ), + '1.3.6.1.5.5.7.10': ('id-aca', ), + '1.3.6.1.5.5.7.10.1': ('id-aca-authenticationInfo', ), + '1.3.6.1.5.5.7.10.2': ('id-aca-accessIdentity', ), + '1.3.6.1.5.5.7.10.3': ('id-aca-chargingIdentity', ), + '1.3.6.1.5.5.7.10.4': ('id-aca-group', ), + '1.3.6.1.5.5.7.10.5': ('id-aca-role', ), + '1.3.6.1.5.5.7.10.6': ('id-aca-encAttrs', ), + '1.3.6.1.5.5.7.11': ('id-qcs', ), + '1.3.6.1.5.5.7.11.1': ('id-qcs-pkixQCSyntax-v1', ), + '1.3.6.1.5.5.7.12': ('id-cct', ), + '1.3.6.1.5.5.7.12.1': ('id-cct-crs', ), + '1.3.6.1.5.5.7.12.2': ('id-cct-PKIData', ), + '1.3.6.1.5.5.7.12.3': ('id-cct-PKIResponse', ), + '1.3.6.1.5.5.7.21': ('id-ppl', ), + '1.3.6.1.5.5.7.21.0': ('Any language', 'id-ppl-anyLanguage'), + '1.3.6.1.5.5.7.21.1': ('Inherit all', 'id-ppl-inheritAll'), + '1.3.6.1.5.5.7.21.2': ('Independent', 'id-ppl-independent'), + '1.3.6.1.5.5.7.48': ('id-ad', ), + '1.3.6.1.5.5.7.48.1': ('OCSP', 'OCSP', 'id-pkix-OCSP'), + '1.3.6.1.5.5.7.48.1.1': ('Basic OCSP Response', 'basicOCSPResponse'), + '1.3.6.1.5.5.7.48.1.2': ('OCSP Nonce', 'Nonce'), + '1.3.6.1.5.5.7.48.1.3': ('OCSP CRL ID', 'CrlID'), + '1.3.6.1.5.5.7.48.1.4': ('Acceptable OCSP Responses', 'acceptableResponses'), + '1.3.6.1.5.5.7.48.1.5': ('OCSP No Check', 'noCheck'), + '1.3.6.1.5.5.7.48.1.6': ('OCSP Archive Cutoff', 'archiveCutoff'), + '1.3.6.1.5.5.7.48.1.7': ('OCSP Service Locator', 'serviceLocator'), + '1.3.6.1.5.5.7.48.1.8': ('Extended OCSP Status', 'extendedStatus'), + '1.3.6.1.5.5.7.48.1.9': ('valid', ), + '1.3.6.1.5.5.7.48.1.10': ('path', ), + '1.3.6.1.5.5.7.48.1.11': ('Trust Root', 'trustRoot'), + '1.3.6.1.5.5.7.48.2': ('CA Issuers', 'caIssuers'), + '1.3.6.1.5.5.7.48.3': ('AD Time Stamping', 'ad_timestamping'), + '1.3.6.1.5.5.7.48.4': ('ad dvcs', 'AD_DVCS'), + '1.3.6.1.5.5.7.48.5': ('CA Repository', 'caRepository'), + '1.3.6.1.5.5.8.1.1': ('hmac-md5', 'HMAC-MD5'), + '1.3.6.1.5.5.8.1.2': ('hmac-sha1', 'HMAC-SHA1'), + '1.3.6.1.6': ('SNMPv2', 'snmpv2'), + '1.3.6.1.7': ('Mail', ), + '1.3.6.1.7.1': ('MIME MHS', 'mime-mhs'), + '1.3.6.1.7.1.1': ('mime-mhs-headings', 'mime-mhs-headings'), + '1.3.6.1.7.1.1.1': ('id-hex-partial-message', 'id-hex-partial-message'), + '1.3.6.1.7.1.1.2': ('id-hex-multipart-message', 'id-hex-multipart-message'), + '1.3.6.1.7.1.2': ('mime-mhs-bodies', 'mime-mhs-bodies'), + '1.3.14.3.2': ('algorithm', 'algorithm'), + '1.3.14.3.2.3': ('md5WithRSA', 'RSA-NP-MD5'), + '1.3.14.3.2.6': ('des-ecb', 'DES-ECB'), + '1.3.14.3.2.7': ('des-cbc', 'DES-CBC'), + '1.3.14.3.2.8': ('des-ofb', 'DES-OFB'), + '1.3.14.3.2.9': ('des-cfb', 'DES-CFB'), + '1.3.14.3.2.11': ('rsaSignature', ), + '1.3.14.3.2.12': ('dsaEncryption-old', 'DSA-old'), + '1.3.14.3.2.13': ('dsaWithSHA', 'DSA-SHA'), + '1.3.14.3.2.15': ('shaWithRSAEncryption', 'RSA-SHA'), + '1.3.14.3.2.17': ('des-ede', 'DES-EDE'), + '1.3.14.3.2.18': ('sha', 'SHA'), + '1.3.14.3.2.26': ('sha1', 'SHA1'), + '1.3.14.3.2.27': ('dsaWithSHA1-old', 'DSA-SHA1-old'), + '1.3.14.3.2.29': ('sha1WithRSA', 'RSA-SHA1-2'), + '1.3.36.3.2.1': ('ripemd160', 'RIPEMD160'), + '1.3.36.3.3.1.2': ('ripemd160WithRSA', 'RSA-RIPEMD160'), + '1.3.36.3.3.2.8.1.1.1': ('brainpoolP160r1', ), + '1.3.36.3.3.2.8.1.1.2': ('brainpoolP160t1', ), + '1.3.36.3.3.2.8.1.1.3': ('brainpoolP192r1', ), + '1.3.36.3.3.2.8.1.1.4': ('brainpoolP192t1', ), + '1.3.36.3.3.2.8.1.1.5': ('brainpoolP224r1', ), + '1.3.36.3.3.2.8.1.1.6': ('brainpoolP224t1', ), + '1.3.36.3.3.2.8.1.1.7': ('brainpoolP256r1', ), + '1.3.36.3.3.2.8.1.1.8': ('brainpoolP256t1', ), + '1.3.36.3.3.2.8.1.1.9': ('brainpoolP320r1', ), + '1.3.36.3.3.2.8.1.1.10': ('brainpoolP320t1', ), + '1.3.36.3.3.2.8.1.1.11': ('brainpoolP384r1', ), + '1.3.36.3.3.2.8.1.1.12': ('brainpoolP384t1', ), + '1.3.36.3.3.2.8.1.1.13': ('brainpoolP512r1', ), + '1.3.36.3.3.2.8.1.1.14': ('brainpoolP512t1', ), + '1.3.36.8.3.3': ('Professional Information or basis for Admission', 'x509ExtAdmission'), + '1.3.101.1.4.1': ('Strong Extranet ID', 'SXNetID'), + '1.3.101.110': ('X25519', ), + '1.3.101.111': ('X448', ), + '1.3.101.112': ('ED25519', ), + '1.3.101.113': ('ED448', ), + '1.3.111': ('ieee', ), + '1.3.111.2.1619': ('IEEE Security in Storage Working Group', 'ieee-siswg'), + '1.3.111.2.1619.0.1.1': ('aes-128-xts', 'AES-128-XTS'), + '1.3.111.2.1619.0.1.2': ('aes-256-xts', 'AES-256-XTS'), + '1.3.132': ('certicom-arc', ), + '1.3.132.0': ('secg_ellipticCurve', ), + '1.3.132.0.1': ('sect163k1', ), + '1.3.132.0.2': ('sect163r1', ), + '1.3.132.0.3': ('sect239k1', ), + '1.3.132.0.4': ('sect113r1', ), + '1.3.132.0.5': ('sect113r2', ), + '1.3.132.0.6': ('secp112r1', ), + '1.3.132.0.7': ('secp112r2', ), + '1.3.132.0.8': ('secp160r1', ), + '1.3.132.0.9': ('secp160k1', ), + '1.3.132.0.10': ('secp256k1', ), + '1.3.132.0.15': ('sect163r2', ), + '1.3.132.0.16': ('sect283k1', ), + '1.3.132.0.17': ('sect283r1', ), + '1.3.132.0.22': ('sect131r1', ), + '1.3.132.0.23': ('sect131r2', ), + '1.3.132.0.24': ('sect193r1', ), + '1.3.132.0.25': ('sect193r2', ), + '1.3.132.0.26': ('sect233k1', ), + '1.3.132.0.27': ('sect233r1', ), + '1.3.132.0.28': ('secp128r1', ), + '1.3.132.0.29': ('secp128r2', ), + '1.3.132.0.30': ('secp160r2', ), + '1.3.132.0.31': ('secp192k1', ), + '1.3.132.0.32': ('secp224k1', ), + '1.3.132.0.33': ('secp224r1', ), + '1.3.132.0.34': ('secp384r1', ), + '1.3.132.0.35': ('secp521r1', ), + '1.3.132.0.36': ('sect409k1', ), + '1.3.132.0.37': ('sect409r1', ), + '1.3.132.0.38': ('sect571k1', ), + '1.3.132.0.39': ('sect571r1', ), + '1.3.132.1': ('secg-scheme', ), + '1.3.132.1.11.0': ('dhSinglePass-stdDH-sha224kdf-scheme', ), + '1.3.132.1.11.1': ('dhSinglePass-stdDH-sha256kdf-scheme', ), + '1.3.132.1.11.2': ('dhSinglePass-stdDH-sha384kdf-scheme', ), + '1.3.132.1.11.3': ('dhSinglePass-stdDH-sha512kdf-scheme', ), + '1.3.132.1.14.0': ('dhSinglePass-cofactorDH-sha224kdf-scheme', ), + '1.3.132.1.14.1': ('dhSinglePass-cofactorDH-sha256kdf-scheme', ), + '1.3.132.1.14.2': ('dhSinglePass-cofactorDH-sha384kdf-scheme', ), + '1.3.132.1.14.3': ('dhSinglePass-cofactorDH-sha512kdf-scheme', ), + '1.3.133.16.840.63.0': ('x9-63-scheme', ), + '1.3.133.16.840.63.0.2': ('dhSinglePass-stdDH-sha1kdf-scheme', ), + '1.3.133.16.840.63.0.3': ('dhSinglePass-cofactorDH-sha1kdf-scheme', ), + '2': ('joint-iso-itu-t', 'JOINT-ISO-ITU-T', 'joint-iso-ccitt'), + '2.5': ('directory services (X.500)', 'X500'), + '2.5.1.5': ('Selected Attribute Types', 'selected-attribute-types'), + '2.5.1.5.55': ('clearance', ), + '2.5.4': ('X509', ), + '2.5.4.3': ('commonName', 'CN'), + '2.5.4.4': ('surname', 'SN'), + '2.5.4.5': ('serialNumber', ), + '2.5.4.6': ('countryName', 'C'), + '2.5.4.7': ('localityName', 'L'), + '2.5.4.8': ('stateOrProvinceName', 'ST'), + '2.5.4.9': ('streetAddress', 'street'), + '2.5.4.10': ('organizationName', 'O'), + '2.5.4.11': ('organizationalUnitName', 'OU'), + '2.5.4.12': ('title', 'title'), + '2.5.4.13': ('description', ), + '2.5.4.14': ('searchGuide', ), + '2.5.4.15': ('businessCategory', ), + '2.5.4.16': ('postalAddress', ), + '2.5.4.17': ('postalCode', ), + '2.5.4.18': ('postOfficeBox', ), + '2.5.4.19': ('physicalDeliveryOfficeName', ), + '2.5.4.20': ('telephoneNumber', ), + '2.5.4.21': ('telexNumber', ), + '2.5.4.22': ('teletexTerminalIdentifier', ), + '2.5.4.23': ('facsimileTelephoneNumber', ), + '2.5.4.24': ('x121Address', ), + '2.5.4.25': ('internationaliSDNNumber', ), + '2.5.4.26': ('registeredAddress', ), + '2.5.4.27': ('destinationIndicator', ), + '2.5.4.28': ('preferredDeliveryMethod', ), + '2.5.4.29': ('presentationAddress', ), + '2.5.4.30': ('supportedApplicationContext', ), + '2.5.4.31': ('member', ), + '2.5.4.32': ('owner', ), + '2.5.4.33': ('roleOccupant', ), + '2.5.4.34': ('seeAlso', ), + '2.5.4.35': ('userPassword', ), + '2.5.4.36': ('userCertificate', ), + '2.5.4.37': ('cACertificate', ), + '2.5.4.38': ('authorityRevocationList', ), + '2.5.4.39': ('certificateRevocationList', ), + '2.5.4.40': ('crossCertificatePair', ), + '2.5.4.41': ('name', 'name'), + '2.5.4.42': ('givenName', 'GN'), + '2.5.4.43': ('initials', 'initials'), + '2.5.4.44': ('generationQualifier', ), + '2.5.4.45': ('x500UniqueIdentifier', ), + '2.5.4.46': ('dnQualifier', 'dnQualifier'), + '2.5.4.47': ('enhancedSearchGuide', ), + '2.5.4.48': ('protocolInformation', ), + '2.5.4.49': ('distinguishedName', ), + '2.5.4.50': ('uniqueMember', ), + '2.5.4.51': ('houseIdentifier', ), + '2.5.4.52': ('supportedAlgorithms', ), + '2.5.4.53': ('deltaRevocationList', ), + '2.5.4.54': ('dmdName', ), + '2.5.4.65': ('pseudonym', ), + '2.5.4.72': ('role', 'role'), + '2.5.4.97': ('organizationIdentifier', ), + '2.5.4.98': ('countryCode3c', 'c3'), + '2.5.4.99': ('countryCode3n', 'n3'), + '2.5.4.100': ('dnsName', ), + '2.5.8': ('directory services - algorithms', 'X500algorithms'), + '2.5.8.1.1': ('rsa', 'RSA'), + '2.5.8.3.100': ('mdc2WithRSA', 'RSA-MDC2'), + '2.5.8.3.101': ('mdc2', 'MDC2'), + '2.5.29': ('id-ce', ), + '2.5.29.9': ('X509v3 Subject Directory Attributes', 'subjectDirectoryAttributes'), + '2.5.29.14': ('X509v3 Subject Key Identifier', 'subjectKeyIdentifier'), + '2.5.29.15': ('X509v3 Key Usage', 'keyUsage'), + '2.5.29.16': ('X509v3 Private Key Usage Period', 'privateKeyUsagePeriod'), + '2.5.29.17': ('X509v3 Subject Alternative Name', 'subjectAltName'), + '2.5.29.18': ('X509v3 Issuer Alternative Name', 'issuerAltName'), + '2.5.29.19': ('X509v3 Basic Constraints', 'basicConstraints'), + '2.5.29.20': ('X509v3 CRL Number', 'crlNumber'), + '2.5.29.21': ('X509v3 CRL Reason Code', 'CRLReason'), + '2.5.29.23': ('Hold Instruction Code', 'holdInstructionCode'), + '2.5.29.24': ('Invalidity Date', 'invalidityDate'), + '2.5.29.27': ('X509v3 Delta CRL Indicator', 'deltaCRL'), + '2.5.29.28': ('X509v3 Issuing Distribution Point', 'issuingDistributionPoint'), + '2.5.29.29': ('X509v3 Certificate Issuer', 'certificateIssuer'), + '2.5.29.30': ('X509v3 Name Constraints', 'nameConstraints'), + '2.5.29.31': ('X509v3 CRL Distribution Points', 'crlDistributionPoints'), + '2.5.29.32': ('X509v3 Certificate Policies', 'certificatePolicies'), + '2.5.29.32.0': ('X509v3 Any Policy', 'anyPolicy'), + '2.5.29.33': ('X509v3 Policy Mappings', 'policyMappings'), + '2.5.29.35': ('X509v3 Authority Key Identifier', 'authorityKeyIdentifier'), + '2.5.29.36': ('X509v3 Policy Constraints', 'policyConstraints'), + '2.5.29.37': ('X509v3 Extended Key Usage', 'extendedKeyUsage'), + '2.5.29.37.0': ('Any Extended Key Usage', 'anyExtendedKeyUsage'), + '2.5.29.46': ('X509v3 Freshest CRL', 'freshestCRL'), + '2.5.29.54': ('X509v3 Inhibit Any Policy', 'inhibitAnyPolicy'), + '2.5.29.55': ('X509v3 AC Targeting', 'targetInformation'), + '2.5.29.56': ('X509v3 No Revocation Available', 'noRevAvail'), + '2.16.840.1.101.3': ('csor', ), + '2.16.840.1.101.3.4': ('nistAlgorithms', ), + '2.16.840.1.101.3.4.1': ('aes', ), + '2.16.840.1.101.3.4.1.1': ('aes-128-ecb', 'AES-128-ECB'), + '2.16.840.1.101.3.4.1.2': ('aes-128-cbc', 'AES-128-CBC'), + '2.16.840.1.101.3.4.1.3': ('aes-128-ofb', 'AES-128-OFB'), + '2.16.840.1.101.3.4.1.4': ('aes-128-cfb', 'AES-128-CFB'), + '2.16.840.1.101.3.4.1.5': ('id-aes128-wrap', ), + '2.16.840.1.101.3.4.1.6': ('aes-128-gcm', 'id-aes128-GCM'), + '2.16.840.1.101.3.4.1.7': ('aes-128-ccm', 'id-aes128-CCM'), + '2.16.840.1.101.3.4.1.8': ('id-aes128-wrap-pad', ), + '2.16.840.1.101.3.4.1.21': ('aes-192-ecb', 'AES-192-ECB'), + '2.16.840.1.101.3.4.1.22': ('aes-192-cbc', 'AES-192-CBC'), + '2.16.840.1.101.3.4.1.23': ('aes-192-ofb', 'AES-192-OFB'), + '2.16.840.1.101.3.4.1.24': ('aes-192-cfb', 'AES-192-CFB'), + '2.16.840.1.101.3.4.1.25': ('id-aes192-wrap', ), + '2.16.840.1.101.3.4.1.26': ('aes-192-gcm', 'id-aes192-GCM'), + '2.16.840.1.101.3.4.1.27': ('aes-192-ccm', 'id-aes192-CCM'), + '2.16.840.1.101.3.4.1.28': ('id-aes192-wrap-pad', ), + '2.16.840.1.101.3.4.1.41': ('aes-256-ecb', 'AES-256-ECB'), + '2.16.840.1.101.3.4.1.42': ('aes-256-cbc', 'AES-256-CBC'), + '2.16.840.1.101.3.4.1.43': ('aes-256-ofb', 'AES-256-OFB'), + '2.16.840.1.101.3.4.1.44': ('aes-256-cfb', 'AES-256-CFB'), + '2.16.840.1.101.3.4.1.45': ('id-aes256-wrap', ), + '2.16.840.1.101.3.4.1.46': ('aes-256-gcm', 'id-aes256-GCM'), + '2.16.840.1.101.3.4.1.47': ('aes-256-ccm', 'id-aes256-CCM'), + '2.16.840.1.101.3.4.1.48': ('id-aes256-wrap-pad', ), + '2.16.840.1.101.3.4.2': ('nist_hashalgs', ), + '2.16.840.1.101.3.4.2.1': ('sha256', 'SHA256'), + '2.16.840.1.101.3.4.2.2': ('sha384', 'SHA384'), + '2.16.840.1.101.3.4.2.3': ('sha512', 'SHA512'), + '2.16.840.1.101.3.4.2.4': ('sha224', 'SHA224'), + '2.16.840.1.101.3.4.2.5': ('sha512-224', 'SHA512-224'), + '2.16.840.1.101.3.4.2.6': ('sha512-256', 'SHA512-256'), + '2.16.840.1.101.3.4.2.7': ('sha3-224', 'SHA3-224'), + '2.16.840.1.101.3.4.2.8': ('sha3-256', 'SHA3-256'), + '2.16.840.1.101.3.4.2.9': ('sha3-384', 'SHA3-384'), + '2.16.840.1.101.3.4.2.10': ('sha3-512', 'SHA3-512'), + '2.16.840.1.101.3.4.2.11': ('shake128', 'SHAKE128'), + '2.16.840.1.101.3.4.2.12': ('shake256', 'SHAKE256'), + '2.16.840.1.101.3.4.2.13': ('hmac-sha3-224', 'id-hmacWithSHA3-224'), + '2.16.840.1.101.3.4.2.14': ('hmac-sha3-256', 'id-hmacWithSHA3-256'), + '2.16.840.1.101.3.4.2.15': ('hmac-sha3-384', 'id-hmacWithSHA3-384'), + '2.16.840.1.101.3.4.2.16': ('hmac-sha3-512', 'id-hmacWithSHA3-512'), + '2.16.840.1.101.3.4.3': ('dsa_with_sha2', 'sigAlgs'), + '2.16.840.1.101.3.4.3.1': ('dsa_with_SHA224', ), + '2.16.840.1.101.3.4.3.2': ('dsa_with_SHA256', ), + '2.16.840.1.101.3.4.3.3': ('dsa_with_SHA384', 'id-dsa-with-sha384'), + '2.16.840.1.101.3.4.3.4': ('dsa_with_SHA512', 'id-dsa-with-sha512'), + '2.16.840.1.101.3.4.3.5': ('dsa_with_SHA3-224', 'id-dsa-with-sha3-224'), + '2.16.840.1.101.3.4.3.6': ('dsa_with_SHA3-256', 'id-dsa-with-sha3-256'), + '2.16.840.1.101.3.4.3.7': ('dsa_with_SHA3-384', 'id-dsa-with-sha3-384'), + '2.16.840.1.101.3.4.3.8': ('dsa_with_SHA3-512', 'id-dsa-with-sha3-512'), + '2.16.840.1.101.3.4.3.9': ('ecdsa_with_SHA3-224', 'id-ecdsa-with-sha3-224'), + '2.16.840.1.101.3.4.3.10': ('ecdsa_with_SHA3-256', 'id-ecdsa-with-sha3-256'), + '2.16.840.1.101.3.4.3.11': ('ecdsa_with_SHA3-384', 'id-ecdsa-with-sha3-384'), + '2.16.840.1.101.3.4.3.12': ('ecdsa_with_SHA3-512', 'id-ecdsa-with-sha3-512'), + '2.16.840.1.101.3.4.3.13': ('RSA-SHA3-224', 'id-rsassa-pkcs1-v1_5-with-sha3-224'), + '2.16.840.1.101.3.4.3.14': ('RSA-SHA3-256', 'id-rsassa-pkcs1-v1_5-with-sha3-256'), + '2.16.840.1.101.3.4.3.15': ('RSA-SHA3-384', 'id-rsassa-pkcs1-v1_5-with-sha3-384'), + '2.16.840.1.101.3.4.3.16': ('RSA-SHA3-512', 'id-rsassa-pkcs1-v1_5-with-sha3-512'), + '2.16.840.1.113730': ('Netscape Communications Corp.', 'Netscape'), + '2.16.840.1.113730.1': ('Netscape Certificate Extension', 'nsCertExt'), + '2.16.840.1.113730.1.1': ('Netscape Cert Type', 'nsCertType'), + '2.16.840.1.113730.1.2': ('Netscape Base Url', 'nsBaseUrl'), + '2.16.840.1.113730.1.3': ('Netscape Revocation Url', 'nsRevocationUrl'), + '2.16.840.1.113730.1.4': ('Netscape CA Revocation Url', 'nsCaRevocationUrl'), + '2.16.840.1.113730.1.7': ('Netscape Renewal Url', 'nsRenewalUrl'), + '2.16.840.1.113730.1.8': ('Netscape CA Policy Url', 'nsCaPolicyUrl'), + '2.16.840.1.113730.1.12': ('Netscape SSL Server Name', 'nsSslServerName'), + '2.16.840.1.113730.1.13': ('Netscape Comment', 'nsComment'), + '2.16.840.1.113730.2': ('Netscape Data Type', 'nsDataType'), + '2.16.840.1.113730.2.5': ('Netscape Certificate Sequence', 'nsCertSequence'), + '2.16.840.1.113730.4.1': ('Netscape Server Gated Crypto', 'nsSGC'), + '2.23': ('International Organizations', 'international-organizations'), + '2.23.42': ('Secure Electronic Transactions', 'id-set'), + '2.23.42.0': ('content types', 'set-ctype'), + '2.23.42.0.0': ('setct-PANData', ), + '2.23.42.0.1': ('setct-PANToken', ), + '2.23.42.0.2': ('setct-PANOnly', ), + '2.23.42.0.3': ('setct-OIData', ), + '2.23.42.0.4': ('setct-PI', ), + '2.23.42.0.5': ('setct-PIData', ), + '2.23.42.0.6': ('setct-PIDataUnsigned', ), + '2.23.42.0.7': ('setct-HODInput', ), + '2.23.42.0.8': ('setct-AuthResBaggage', ), + '2.23.42.0.9': ('setct-AuthRevReqBaggage', ), + '2.23.42.0.10': ('setct-AuthRevResBaggage', ), + '2.23.42.0.11': ('setct-CapTokenSeq', ), + '2.23.42.0.12': ('setct-PInitResData', ), + '2.23.42.0.13': ('setct-PI-TBS', ), + '2.23.42.0.14': ('setct-PResData', ), + '2.23.42.0.16': ('setct-AuthReqTBS', ), + '2.23.42.0.17': ('setct-AuthResTBS', ), + '2.23.42.0.18': ('setct-AuthResTBSX', ), + '2.23.42.0.19': ('setct-AuthTokenTBS', ), + '2.23.42.0.20': ('setct-CapTokenData', ), + '2.23.42.0.21': ('setct-CapTokenTBS', ), + '2.23.42.0.22': ('setct-AcqCardCodeMsg', ), + '2.23.42.0.23': ('setct-AuthRevReqTBS', ), + '2.23.42.0.24': ('setct-AuthRevResData', ), + '2.23.42.0.25': ('setct-AuthRevResTBS', ), + '2.23.42.0.26': ('setct-CapReqTBS', ), + '2.23.42.0.27': ('setct-CapReqTBSX', ), + '2.23.42.0.28': ('setct-CapResData', ), + '2.23.42.0.29': ('setct-CapRevReqTBS', ), + '2.23.42.0.30': ('setct-CapRevReqTBSX', ), + '2.23.42.0.31': ('setct-CapRevResData', ), + '2.23.42.0.32': ('setct-CredReqTBS', ), + '2.23.42.0.33': ('setct-CredReqTBSX', ), + '2.23.42.0.34': ('setct-CredResData', ), + '2.23.42.0.35': ('setct-CredRevReqTBS', ), + '2.23.42.0.36': ('setct-CredRevReqTBSX', ), + '2.23.42.0.37': ('setct-CredRevResData', ), + '2.23.42.0.38': ('setct-PCertReqData', ), + '2.23.42.0.39': ('setct-PCertResTBS', ), + '2.23.42.0.40': ('setct-BatchAdminReqData', ), + '2.23.42.0.41': ('setct-BatchAdminResData', ), + '2.23.42.0.42': ('setct-CardCInitResTBS', ), + '2.23.42.0.43': ('setct-MeAqCInitResTBS', ), + '2.23.42.0.44': ('setct-RegFormResTBS', ), + '2.23.42.0.45': ('setct-CertReqData', ), + '2.23.42.0.46': ('setct-CertReqTBS', ), + '2.23.42.0.47': ('setct-CertResData', ), + '2.23.42.0.48': ('setct-CertInqReqTBS', ), + '2.23.42.0.49': ('setct-ErrorTBS', ), + '2.23.42.0.50': ('setct-PIDualSignedTBE', ), + '2.23.42.0.51': ('setct-PIUnsignedTBE', ), + '2.23.42.0.52': ('setct-AuthReqTBE', ), + '2.23.42.0.53': ('setct-AuthResTBE', ), + '2.23.42.0.54': ('setct-AuthResTBEX', ), + '2.23.42.0.55': ('setct-AuthTokenTBE', ), + '2.23.42.0.56': ('setct-CapTokenTBE', ), + '2.23.42.0.57': ('setct-CapTokenTBEX', ), + '2.23.42.0.58': ('setct-AcqCardCodeMsgTBE', ), + '2.23.42.0.59': ('setct-AuthRevReqTBE', ), + '2.23.42.0.60': ('setct-AuthRevResTBE', ), + '2.23.42.0.61': ('setct-AuthRevResTBEB', ), + '2.23.42.0.62': ('setct-CapReqTBE', ), + '2.23.42.0.63': ('setct-CapReqTBEX', ), + '2.23.42.0.64': ('setct-CapResTBE', ), + '2.23.42.0.65': ('setct-CapRevReqTBE', ), + '2.23.42.0.66': ('setct-CapRevReqTBEX', ), + '2.23.42.0.67': ('setct-CapRevResTBE', ), + '2.23.42.0.68': ('setct-CredReqTBE', ), + '2.23.42.0.69': ('setct-CredReqTBEX', ), + '2.23.42.0.70': ('setct-CredResTBE', ), + '2.23.42.0.71': ('setct-CredRevReqTBE', ), + '2.23.42.0.72': ('setct-CredRevReqTBEX', ), + '2.23.42.0.73': ('setct-CredRevResTBE', ), + '2.23.42.0.74': ('setct-BatchAdminReqTBE', ), + '2.23.42.0.75': ('setct-BatchAdminResTBE', ), + '2.23.42.0.76': ('setct-RegFormReqTBE', ), + '2.23.42.0.77': ('setct-CertReqTBE', ), + '2.23.42.0.78': ('setct-CertReqTBEX', ), + '2.23.42.0.79': ('setct-CertResTBE', ), + '2.23.42.0.80': ('setct-CRLNotificationTBS', ), + '2.23.42.0.81': ('setct-CRLNotificationResTBS', ), + '2.23.42.0.82': ('setct-BCIDistributionTBS', ), + '2.23.42.1': ('message extensions', 'set-msgExt'), + '2.23.42.1.1': ('generic cryptogram', 'setext-genCrypt'), + '2.23.42.1.3': ('merchant initiated auth', 'setext-miAuth'), + '2.23.42.1.4': ('setext-pinSecure', ), + '2.23.42.1.5': ('setext-pinAny', ), + '2.23.42.1.7': ('setext-track2', ), + '2.23.42.1.8': ('additional verification', 'setext-cv'), + '2.23.42.3': ('set-attr', ), + '2.23.42.3.0': ('setAttr-Cert', ), + '2.23.42.3.0.0': ('set-rootKeyThumb', ), + '2.23.42.3.0.1': ('set-addPolicy', ), + '2.23.42.3.1': ('payment gateway capabilities', 'setAttr-PGWYcap'), + '2.23.42.3.2': ('setAttr-TokenType', ), + '2.23.42.3.2.1': ('setAttr-Token-EMV', ), + '2.23.42.3.2.2': ('setAttr-Token-B0Prime', ), + '2.23.42.3.3': ('issuer capabilities', 'setAttr-IssCap'), + '2.23.42.3.3.3': ('setAttr-IssCap-CVM', ), + '2.23.42.3.3.3.1': ('generate cryptogram', 'setAttr-GenCryptgrm'), + '2.23.42.3.3.4': ('setAttr-IssCap-T2', ), + '2.23.42.3.3.4.1': ('encrypted track 2', 'setAttr-T2Enc'), + '2.23.42.3.3.4.2': ('cleartext track 2', 'setAttr-T2cleartxt'), + '2.23.42.3.3.5': ('setAttr-IssCap-Sig', ), + '2.23.42.3.3.5.1': ('ICC or token signature', 'setAttr-TokICCsig'), + '2.23.42.3.3.5.2': ('secure device signature', 'setAttr-SecDevSig'), + '2.23.42.5': ('set-policy', ), + '2.23.42.5.0': ('set-policy-root', ), + '2.23.42.7': ('certificate extensions', 'set-certExt'), + '2.23.42.7.0': ('setCext-hashedRoot', ), + '2.23.42.7.1': ('setCext-certType', ), + '2.23.42.7.2': ('setCext-merchData', ), + '2.23.42.7.3': ('setCext-cCertRequired', ), + '2.23.42.7.4': ('setCext-tunneling', ), + '2.23.42.7.5': ('setCext-setExt', ), + '2.23.42.7.6': ('setCext-setQualf', ), + '2.23.42.7.7': ('setCext-PGWYcapabilities', ), + '2.23.42.7.8': ('setCext-TokenIdentifier', ), + '2.23.42.7.9': ('setCext-Track2Data', ), + '2.23.42.7.10': ('setCext-TokenType', ), + '2.23.42.7.11': ('setCext-IssuerCapabilities', ), + '2.23.42.8': ('set-brand', ), + '2.23.42.8.1': ('set-brand-IATA-ATA', ), + '2.23.42.8.4': ('set-brand-Visa', ), + '2.23.42.8.5': ('set-brand-MasterCard', ), + '2.23.42.8.30': ('set-brand-Diners', ), + '2.23.42.8.34': ('set-brand-AmericanExpress', ), + '2.23.42.8.35': ('set-brand-JCB', ), + '2.23.42.8.6011': ('set-brand-Novus', ), + '2.23.43': ('wap', ), + '2.23.43.1': ('wap-wsg', ), + '2.23.43.1.4': ('wap-wsg-idm-ecid', ), + '2.23.43.1.4.1': ('wap-wsg-idm-ecid-wtls1', ), + '2.23.43.1.4.3': ('wap-wsg-idm-ecid-wtls3', ), + '2.23.43.1.4.4': ('wap-wsg-idm-ecid-wtls4', ), + '2.23.43.1.4.5': ('wap-wsg-idm-ecid-wtls5', ), + '2.23.43.1.4.6': ('wap-wsg-idm-ecid-wtls6', ), + '2.23.43.1.4.7': ('wap-wsg-idm-ecid-wtls7', ), + '2.23.43.1.4.8': ('wap-wsg-idm-ecid-wtls8', ), + '2.23.43.1.4.9': ('wap-wsg-idm-ecid-wtls9', ), + '2.23.43.1.4.10': ('wap-wsg-idm-ecid-wtls10', ), + '2.23.43.1.4.11': ('wap-wsg-idm-ecid-wtls11', ), + '2.23.43.1.4.12': ('wap-wsg-idm-ecid-wtls12', ), +} +# ##################################################################################### +# ##################################################################################### + +_OID_LOOKUP = dict() +_NORMALIZE_NAMES = dict() +_NORMALIZE_NAMES_SHORT = dict() + +for dotted, names in _OID_MAP.items(): + for name in names: + if name in _NORMALIZE_NAMES and _OID_LOOKUP[name] != dotted: + raise AssertionError( + 'Name collision during setup: "{0}" for OIDs {1} and {2}' + .format(name, dotted, _OID_LOOKUP[name]) + ) + _NORMALIZE_NAMES[name] = names[0] + _NORMALIZE_NAMES_SHORT[name] = names[-1] + _OID_LOOKUP[name] = dotted +for alias, original in [('userID', 'userId')]: + if alias in _NORMALIZE_NAMES: + raise AssertionError( + 'Name collision during adding aliases: "{0}" (alias for "{1}") is already mapped to OID {2}' + .format(alias, original, _OID_LOOKUP[alias]) + ) + _NORMALIZE_NAMES[alias] = original + _NORMALIZE_NAMES_SHORT[alias] = _NORMALIZE_NAMES_SHORT[original] + _OID_LOOKUP[alias] = _OID_LOOKUP[original] + + +def pyopenssl_normalize_name(name, short=False): + nid = OpenSSL._util.lib.OBJ_txt2nid(to_bytes(name)) + if nid != 0: + b_name = OpenSSL._util.lib.OBJ_nid2ln(nid) + name = to_text(OpenSSL._util.ffi.string(b_name)) + if short: + return _NORMALIZE_NAMES_SHORT.get(name, name) + else: + return _NORMALIZE_NAMES.get(name, name) + + +# ##################################################################################### +# ##################################################################################### +# # This excerpt is dual licensed under the terms of the Apache License, Version +# # 2.0, and the BSD License. See the LICENSE file at +# # https://github.com/pyca/cryptography/blob/master/LICENSE for complete details. +# # +# # Adapted from cryptography's hazmat/backends/openssl/decode_asn1.py +# # +# # Copyright (c) 2015, 2016 Paul Kehrer (@reaperhulk) +# # Copyright (c) 2017 Fraser Tweedale (@frasertweedale) +# # +# # Relevant commits from cryptography project (https://github.com/pyca/cryptography): +# # pyca/cryptography@719d536dd691e84e208534798f2eb4f82aaa2e07 +# # pyca/cryptography@5ab6d6a5c05572bd1c75f05baf264a2d0001894a +# # pyca/cryptography@2e776e20eb60378e0af9b7439000d0e80da7c7e3 +# # pyca/cryptography@fb309ed24647d1be9e319b61b1f2aa8ebb87b90b +# # pyca/cryptography@2917e460993c475c72d7146c50dc3bbc2414280d +# # pyca/cryptography@3057f91ea9a05fb593825006d87a391286a4d828 +# # pyca/cryptography@d607dd7e5bc5c08854ec0c9baff70ba4a35be36f +def _obj2txt(openssl_lib, openssl_ffi, obj): + # Set to 80 on the recommendation of + # https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values + # + # But OIDs longer than this occur in real life (e.g. Active + # Directory makes some very long OIDs). So we need to detect + # and properly handle the case where the default buffer is not + # big enough. + # + buf_len = 80 + buf = openssl_ffi.new("char[]", buf_len) + + # 'res' is the number of bytes that *would* be written if the + # buffer is large enough. If 'res' > buf_len - 1, we need to + # alloc a big-enough buffer and go again. + res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1) + if res > buf_len - 1: # account for terminating null byte + buf_len = res + 1 + buf = openssl_ffi.new("char[]", buf_len) + res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1) + return openssl_ffi.buffer(buf, res)[:].decode() +# ##################################################################################### +# ##################################################################################### + + +def cryptography_get_extensions_from_cert(cert): + # Since cryptography won't give us the DER value for an extension + # (that is only stored for unrecognized extensions), we have to re-do + # the extension parsing outselves. + result = dict() + backend = cert._backend + x509_obj = cert._x509 + + for i in range(backend._lib.X509_get_ext_count(x509_obj)): + ext = backend._lib.X509_get_ext(x509_obj, i) + if ext == backend._ffi.NULL: + continue + crit = backend._lib.X509_EXTENSION_get_critical(ext) + data = backend._lib.X509_EXTENSION_get_data(ext) + backend.openssl_assert(data != backend._ffi.NULL) + der = backend._ffi.buffer(data.data, data.length)[:] + entry = dict( + critical=(crit == 1), + value=base64.b64encode(der), + ) + oid = _obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext)) + result[oid] = entry + return result + + +def cryptography_get_extensions_from_csr(csr): + # Since cryptography won't give us the DER value for an extension + # (that is only stored for unrecognized extensions), we have to re-do + # the extension parsing outselves. + result = dict() + backend = csr._backend + + extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) + extensions = backend._ffi.gc( + extensions, + lambda ext: backend._lib.sk_X509_EXTENSION_pop_free( + ext, + backend._ffi.addressof(backend._lib._original_lib, "X509_EXTENSION_free") + ) + ) + + for i in range(backend._lib.sk_X509_EXTENSION_num(extensions)): + ext = backend._lib.sk_X509_EXTENSION_value(extensions, i) + if ext == backend._ffi.NULL: + continue + crit = backend._lib.X509_EXTENSION_get_critical(ext) + data = backend._lib.X509_EXTENSION_get_data(ext) + backend.openssl_assert(data != backend._ffi.NULL) + der = backend._ffi.buffer(data.data, data.length)[:] + entry = dict( + critical=(crit == 1), + value=base64.b64encode(der), + ) + oid = _obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext)) + result[oid] = entry + return result + + +def pyopenssl_get_extensions_from_cert(cert): + # While pyOpenSSL allows us to get an extension's DER value, it won't + # give us the dotted string for an OID. So we have to do some magic to + # get hold of it. + result = dict() + ext_count = cert.get_extension_count() + for i in range(0, ext_count): + ext = cert.get_extension(i) + entry = dict( + critical=bool(ext.get_critical()), + value=base64.b64encode(ext.get_data()), + ) + oid = _obj2txt( + OpenSSL._util.lib, + OpenSSL._util.ffi, + OpenSSL._util.lib.X509_EXTENSION_get_object(ext._extension) + ) + # This could also be done a bit simpler: + # + # oid = _obj2txt(OpenSSL._util.lib, OpenSSL._util.ffi, OpenSSL._util.lib.OBJ_nid2obj(ext._nid)) + # + # Unfortunately this gives the wrong result in case the linked OpenSSL + # doesn't know the OID. That's why we have to get the OID dotted string + # similarly to how cryptography does it. + result[oid] = entry + return result + + +def pyopenssl_get_extensions_from_csr(csr): + # While pyOpenSSL allows us to get an extension's DER value, it won't + # give us the dotted string for an OID. So we have to do some magic to + # get hold of it. + result = dict() + for ext in csr.get_extensions(): + entry = dict( + critical=bool(ext.get_critical()), + value=base64.b64encode(ext.get_data()), + ) + oid = _obj2txt( + OpenSSL._util.lib, + OpenSSL._util.ffi, + OpenSSL._util.lib.X509_EXTENSION_get_object(ext._extension) + ) + # This could also be done a bit simpler: + # + # oid = _obj2txt(OpenSSL._util.lib, OpenSSL._util.ffi, OpenSSL._util.lib.OBJ_nid2obj(ext._nid)) + # + # Unfortunately this gives the wrong result in case the linked OpenSSL + # doesn't know the OID. That's why we have to get the OID dotted string + # similarly to how cryptography does it. + result[oid] = entry + return result + + +def cryptography_name_to_oid(name): + dotted = _OID_LOOKUP.get(name) + if dotted is None: + raise OpenSSLObjectError('Cannot find OID for "{0}"'.format(name)) + return x509.oid.ObjectIdentifier(dotted) + + +def cryptography_oid_to_name(oid, short=False): + dotted_string = oid.dotted_string + names = _OID_MAP.get(dotted_string) + name = names[0] if names else oid._name + if short: + return _NORMALIZE_NAMES_SHORT.get(name, name) + else: + return _NORMALIZE_NAMES.get(name, name) + + +def cryptography_get_name(name): + ''' + Given a name string, returns a cryptography x509.Name object. + Raises an OpenSSLObjectError if the name is unknown or cannot be parsed. + ''' + try: + if name.startswith('DNS:'): + return x509.DNSName(to_text(name[4:])) + if name.startswith('IP:'): + return x509.IPAddress(ipaddress.ip_address(to_text(name[3:]))) + if name.startswith('email:'): + return x509.RFC822Name(to_text(name[6:])) + if name.startswith('URI:'): + return x509.UniformResourceIdentifier(to_text(name[4:])) + except Exception as e: + raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}": {1}'.format(name, e)) + if ':' not in name: + raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}" (forgot "DNS:" prefix?)'.format(name)) + raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}" (potentially unsupported by cryptography backend)'.format(name)) + + +def _get_hex(bytesstr): + if bytesstr is None: + return bytesstr + data = binascii.hexlify(bytesstr) + data = to_text(b':'.join(data[i:i + 2] for i in range(0, len(data), 2))) + return data + + +def cryptography_decode_name(name): + ''' + Given a cryptography x509.Name object, returns a string. + Raises an OpenSSLObjectError if the name is not supported. + ''' + if isinstance(name, x509.DNSName): + return 'DNS:{0}'.format(name.value) + if isinstance(name, x509.IPAddress): + return 'IP:{0}'.format(name.value.compressed) + if isinstance(name, x509.RFC822Name): + return 'email:{0}'.format(name.value) + if isinstance(name, x509.UniformResourceIdentifier): + return 'URI:{0}'.format(name.value) + if isinstance(name, x509.DirectoryName): + # FIXME: test + return 'DirName:' + ''.join(['/{0}:{1}'.format(attribute.oid._name, attribute.value) for attribute in name.value]) + if isinstance(name, x509.RegisteredID): + # FIXME: test + return 'RegisteredID:{0}'.format(name.value) + if isinstance(name, x509.OtherName): + # FIXME: test + return '{0}:{1}'.format(name.type_id.dotted_string, _get_hex(name.value)) + raise OpenSSLObjectError('Cannot decode name "{0}"'.format(name)) + + +def _cryptography_get_keyusage(usage): + ''' + Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage(). + Raises an OpenSSLObjectError if the identifier is unknown. + ''' + if usage in ('Digital Signature', 'digitalSignature'): + return 'digital_signature' + if usage in ('Non Repudiation', 'nonRepudiation'): + return 'content_commitment' + if usage in ('Key Encipherment', 'keyEncipherment'): + return 'key_encipherment' + if usage in ('Data Encipherment', 'dataEncipherment'): + return 'data_encipherment' + if usage in ('Key Agreement', 'keyAgreement'): + return 'key_agreement' + if usage in ('Certificate Sign', 'keyCertSign'): + return 'key_cert_sign' + if usage in ('CRL Sign', 'cRLSign'): + return 'crl_sign' + if usage in ('Encipher Only', 'encipherOnly'): + return 'encipher_only' + if usage in ('Decipher Only', 'decipherOnly'): + return 'decipher_only' + raise OpenSSLObjectError('Unknown key usage "{0}"'.format(usage)) + + +def cryptography_parse_key_usage_params(usages): + ''' + Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage(). + Raises an OpenSSLObjectError if an identifier is unknown. + ''' + params = dict( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) + for usage in usages: + params[_cryptography_get_keyusage(usage)] = True + return params + + +def cryptography_get_basic_constraints(constraints): + ''' + Given a list of constraints, returns a tuple (ca, path_length). + Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed. + ''' + ca = False + path_length = None + if constraints: + for constraint in constraints: + if constraint.startswith('CA:'): + if constraint == 'CA:TRUE': + ca = True + elif constraint == 'CA:FALSE': + ca = False + else: + raise OpenSSLObjectError('Unknown basic constraint value "{0}" for CA'.format(constraint[3:])) + elif constraint.startswith('pathlen:'): + v = constraint[len('pathlen:'):] + try: + path_length = int(v) + except Exception as e: + raise OpenSSLObjectError('Cannot parse path length constraint "{0}" ({1})'.format(v, e)) + else: + raise OpenSSLObjectError('Unknown basic constraint "{0}"'.format(constraint)) + return ca, path_length + + +def binary_exp_mod(f, e, m): + '''Computes f^e mod m in O(log e) multiplications modulo m.''' + # Compute len_e = floor(log_2(e)) + len_e = -1 + x = e + while x > 0: + x >>= 1 + len_e += 1 + # Compute f**e mod m + result = 1 + for k in range(len_e, -1, -1): + result = (result * result) % m + if ((e >> k) & 1) != 0: + result = (result * f) % m + return result + + +def simple_gcd(a, b): + '''Compute GCD of its two inputs.''' + while b != 0: + a, b = b, a % b + return a + + +def quick_is_not_prime(n): + '''Does some quick checks to see if we can poke a hole into the primality of n. + + A result of `False` does **not** mean that the number is prime; it just means + that we couldn't detect quickly whether it is not prime. + ''' + if n <= 2: + return True + # The constant in the next line is the product of all primes < 200 + if simple_gcd(n, 7799922041683461553249199106329813876687996789903550945093032474868511536164700810) > 1: + return True + # TODO: maybe do some iterations of Miller-Rabin to increase confidence + # (https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test) + return False + + +python_version = (sys.version_info[0], sys.version_info[1]) +if python_version >= (2, 7) or python_version >= (3, 1): + # Ansible still supports Python 2.6 on remote nodes + def count_bits(no): + no = abs(no) + if no == 0: + return 0 + return no.bit_length() +else: + # Slow, but works + def count_bits(no): + no = abs(no) + count = 0 + while no > 0: + no >>= 1 + count += 1 + return count + + +PEM_START = '-----BEGIN ' +PEM_END = '-----' +PKCS8_PRIVATEKEY_NAMES = ('PRIVATE KEY', 'ENCRYPTED PRIVATE KEY') +PKCS1_PRIVATEKEY_SUFFIX = ' PRIVATE KEY' + + +def identify_private_key_format(content): + '''Given the contents of a private key file, identifies its format.''' + # See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85 + # (PEM_read_bio_PrivateKey) + # and https://github.com/openssl/openssl/blob/master/include/openssl/pem.h#L46-L47 + # (PEM_STRING_PKCS8, PEM_STRING_PKCS8INF) + try: + lines = content.decode('utf-8').splitlines(False) + if lines[0].startswith(PEM_START) and lines[0].endswith(PEM_END) and len(lines[0]) > len(PEM_START) + len(PEM_END): + name = lines[0][len(PEM_START):-len(PEM_END)] + if name in PKCS8_PRIVATEKEY_NAMES: + return 'pkcs8' + if len(name) > len(PKCS1_PRIVATEKEY_SUFFIX) and name.endswith(PKCS1_PRIVATEKEY_SUFFIX): + return 'pkcs1' + return 'unknown-pem' + except UnicodeDecodeError: + pass + return 'raw' + + +def cryptography_key_needs_digest_for_signing(key): + '''Tests whether the given private key requires a digest algorithm for signing. + + Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm. + ''' + if CRYPTOGRAPHY_HAS_ED25519 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey): + return False + if CRYPTOGRAPHY_HAS_ED448 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey): + return False + return True + + +def cryptography_compare_public_keys(key1, key2): + '''Tests whether two public keys are the same. + + Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers(). + ''' + if CRYPTOGRAPHY_HAS_ED25519: + a = isinstance(key1, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey) + b = isinstance(key2, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey) + if a or b: + if not a or not b: + return False + a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + return a == b + if CRYPTOGRAPHY_HAS_ED448: + a = isinstance(key1, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey) + b = isinstance(key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey) + if a or b: + if not a or not b: + return False + a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + return a == b + return key1.public_numbers() == key2.public_numbers() + + +if HAS_CRYPTOGRAPHY: + REVOCATION_REASON_MAP = { + 'unspecified': x509.ReasonFlags.unspecified, + 'key_compromise': x509.ReasonFlags.key_compromise, + 'ca_compromise': x509.ReasonFlags.ca_compromise, + 'affiliation_changed': x509.ReasonFlags.affiliation_changed, + 'superseded': x509.ReasonFlags.superseded, + 'cessation_of_operation': x509.ReasonFlags.cessation_of_operation, + 'certificate_hold': x509.ReasonFlags.certificate_hold, + 'privilege_withdrawn': x509.ReasonFlags.privilege_withdrawn, + 'aa_compromise': x509.ReasonFlags.aa_compromise, + 'remove_from_crl': x509.ReasonFlags.remove_from_crl, + } + REVOCATION_REASON_MAP_INVERSE = dict() + for k, v in REVOCATION_REASON_MAP.items(): + REVOCATION_REASON_MAP_INVERSE[v] = k + + +def cryptography_decode_revoked_certificate(cert): + result = { + 'serial_number': cert.serial_number, + 'revocation_date': cert.revocation_date, + 'issuer': None, + 'issuer_critical': False, + 'reason': None, + 'reason_critical': False, + 'invalidity_date': None, + 'invalidity_date_critical': False, + } + try: + ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer) + result['issuer'] = list(ext.value) + result['issuer_critical'] = ext.critical + except x509.ExtensionNotFound: + pass + try: + ext = cert.extensions.get_extension_for_class(x509.CRLReason) + result['reason'] = ext.value.reason + result['reason_critical'] = ext.critical + except x509.ExtensionNotFound: + pass + try: + ext = cert.extensions.get_extension_for_class(x509.InvalidityDate) + result['invalidity_date'] = ext.value.invalidity_date + result['invalidity_date_critical'] = ext.critical + except x509.ExtensionNotFound: + pass + return result diff --git a/test/support/integration/plugins/module_utils/database.py b/test/support/integration/plugins/module_utils/database.py new file mode 100644 index 00000000..014939a2 --- /dev/null +++ b/test/support/integration/plugins/module_utils/database.py @@ -0,0 +1,142 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com> +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +class SQLParseError(Exception): + pass + + +class UnclosedQuoteError(SQLParseError): + pass + + +# maps a type of identifier to the maximum number of dot levels that are +# allowed to specify that identifier. For example, a database column can be +# specified by up to 4 levels: database.schema.table.column +_PG_IDENTIFIER_TO_DOT_LEVEL = dict( + database=1, + schema=2, + table=3, + column=4, + role=1, + tablespace=1, + sequence=3, + publication=1, +) +_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1, vars=1) + + +def _find_end_quote(identifier, quote_char): + accumulate = 0 + while True: + try: + quote = identifier.index(quote_char) + except ValueError: + raise UnclosedQuoteError + accumulate = accumulate + quote + try: + next_char = identifier[quote + 1] + except IndexError: + return accumulate + if next_char == quote_char: + try: + identifier = identifier[quote + 2:] + accumulate = accumulate + 2 + except IndexError: + raise UnclosedQuoteError + else: + return accumulate + + +def _identifier_parse(identifier, quote_char): + if not identifier: + raise SQLParseError('Identifier name unspecified or unquoted trailing dot') + + already_quoted = False + if identifier.startswith(quote_char): + already_quoted = True + try: + end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1 + except UnclosedQuoteError: + already_quoted = False + else: + if end_quote < len(identifier) - 1: + if identifier[end_quote + 1] == '.': + dot = end_quote + 1 + first_identifier = identifier[:dot] + next_identifier = identifier[dot + 1:] + further_identifiers = _identifier_parse(next_identifier, quote_char) + further_identifiers.insert(0, first_identifier) + else: + raise SQLParseError('User escaped identifiers must escape extra quotes') + else: + further_identifiers = [identifier] + + if not already_quoted: + try: + dot = identifier.index('.') + except ValueError: + identifier = identifier.replace(quote_char, quote_char * 2) + identifier = ''.join((quote_char, identifier, quote_char)) + further_identifiers = [identifier] + else: + if dot == 0 or dot >= len(identifier) - 1: + identifier = identifier.replace(quote_char, quote_char * 2) + identifier = ''.join((quote_char, identifier, quote_char)) + further_identifiers = [identifier] + else: + first_identifier = identifier[:dot] + next_identifier = identifier[dot + 1:] + further_identifiers = _identifier_parse(next_identifier, quote_char) + first_identifier = first_identifier.replace(quote_char, quote_char * 2) + first_identifier = ''.join((quote_char, first_identifier, quote_char)) + further_identifiers.insert(0, first_identifier) + + return further_identifiers + + +def pg_quote_identifier(identifier, id_type): + identifier_fragments = _identifier_parse(identifier, quote_char='"') + if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]: + raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type])) + return '.'.join(identifier_fragments) + + +def mysql_quote_identifier(identifier, id_type): + identifier_fragments = _identifier_parse(identifier, quote_char='`') + if (len(identifier_fragments) - 1) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]: + raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type])) + + special_cased_fragments = [] + for fragment in identifier_fragments: + if fragment == '`*`': + special_cased_fragments.append('*') + else: + special_cased_fragments.append(fragment) + + return '.'.join(special_cased_fragments) diff --git a/test/support/integration/plugins/module_utils/docker/__init__.py b/test/support/integration/plugins/module_utils/docker/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/docker/__init__.py diff --git a/test/support/integration/plugins/module_utils/docker/common.py b/test/support/integration/plugins/module_utils/docker/common.py new file mode 100644 index 00000000..03307250 --- /dev/null +++ b/test/support/integration/plugins/module_utils/docker/common.py @@ -0,0 +1,1022 @@ +# +# Copyright 2016 Red Hat | Ansible +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +import os +import platform +import re +import sys +from datetime import timedelta +from distutils.version import LooseVersion + + +from ansible.module_utils.basic import AnsibleModule, env_fallback, missing_required_lib +from ansible.module_utils.common._collections_compat import Mapping, Sequence +from ansible.module_utils.six import string_types +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE, BOOLEANS_FALSE + +HAS_DOCKER_PY = True +HAS_DOCKER_PY_2 = False +HAS_DOCKER_PY_3 = False +HAS_DOCKER_ERROR = None + +try: + from requests.exceptions import SSLError + from docker import __version__ as docker_version + from docker.errors import APIError, NotFound, TLSParameterError + from docker.tls import TLSConfig + from docker import auth + + if LooseVersion(docker_version) >= LooseVersion('3.0.0'): + HAS_DOCKER_PY_3 = True + from docker import APIClient as Client + elif LooseVersion(docker_version) >= LooseVersion('2.0.0'): + HAS_DOCKER_PY_2 = True + from docker import APIClient as Client + else: + from docker import Client + +except ImportError as exc: + HAS_DOCKER_ERROR = str(exc) + HAS_DOCKER_PY = False + + +# The next 2 imports ``docker.models`` and ``docker.ssladapter`` are used +# to ensure the user does not have both ``docker`` and ``docker-py`` modules +# installed, as they utilize the same namespace are are incompatible +try: + # docker (Docker SDK for Python >= 2.0.0) + import docker.models # noqa: F401 + HAS_DOCKER_MODELS = True +except ImportError: + HAS_DOCKER_MODELS = False + +try: + # docker-py (Docker SDK for Python < 2.0.0) + import docker.ssladapter # noqa: F401 + HAS_DOCKER_SSLADAPTER = True +except ImportError: + HAS_DOCKER_SSLADAPTER = False + + +try: + from requests.exceptions import RequestException +except ImportError: + # Either docker-py is no longer using requests, or docker-py isn't around either, + # or docker-py's dependency requests is missing. In any case, define an exception + # class RequestException so that our code doesn't break. + class RequestException(Exception): + pass + + +DEFAULT_DOCKER_HOST = 'unix://var/run/docker.sock' +DEFAULT_TLS = False +DEFAULT_TLS_VERIFY = False +DEFAULT_TLS_HOSTNAME = 'localhost' +MIN_DOCKER_VERSION = "1.8.0" +DEFAULT_TIMEOUT_SECONDS = 60 + +DOCKER_COMMON_ARGS = dict( + docker_host=dict(type='str', default=DEFAULT_DOCKER_HOST, fallback=(env_fallback, ['DOCKER_HOST']), aliases=['docker_url']), + tls_hostname=dict(type='str', default=DEFAULT_TLS_HOSTNAME, fallback=(env_fallback, ['DOCKER_TLS_HOSTNAME'])), + api_version=dict(type='str', default='auto', fallback=(env_fallback, ['DOCKER_API_VERSION']), aliases=['docker_api_version']), + timeout=dict(type='int', default=DEFAULT_TIMEOUT_SECONDS, fallback=(env_fallback, ['DOCKER_TIMEOUT'])), + ca_cert=dict(type='path', aliases=['tls_ca_cert', 'cacert_path']), + client_cert=dict(type='path', aliases=['tls_client_cert', 'cert_path']), + client_key=dict(type='path', aliases=['tls_client_key', 'key_path']), + ssl_version=dict(type='str', fallback=(env_fallback, ['DOCKER_SSL_VERSION'])), + tls=dict(type='bool', default=DEFAULT_TLS, fallback=(env_fallback, ['DOCKER_TLS'])), + validate_certs=dict(type='bool', default=DEFAULT_TLS_VERIFY, fallback=(env_fallback, ['DOCKER_TLS_VERIFY']), aliases=['tls_verify']), + debug=dict(type='bool', default=False) +) + +DOCKER_MUTUALLY_EXCLUSIVE = [] + +DOCKER_REQUIRED_TOGETHER = [ + ['client_cert', 'client_key'] +] + +DEFAULT_DOCKER_REGISTRY = 'https://index.docker.io/v1/' +EMAIL_REGEX = r'[^@]+@[^@]+\.[^@]+' +BYTE_SUFFIXES = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + + +if not HAS_DOCKER_PY: + docker_version = None + + # No Docker SDK for Python. Create a place holder client to allow + # instantiation of AnsibleModule and proper error handing + class Client(object): # noqa: F811 + def __init__(self, **kwargs): + pass + + class APIError(Exception): # noqa: F811 + pass + + class NotFound(Exception): # noqa: F811 + pass + + +def is_image_name_id(name): + """Check whether the given image name is in fact an image ID (hash).""" + if re.match('^sha256:[0-9a-fA-F]{64}$', name): + return True + return False + + +def is_valid_tag(tag, allow_empty=False): + """Check whether the given string is a valid docker tag name.""" + if not tag: + return allow_empty + # See here ("Extended description") for a definition what tags can be: + # https://docs.docker.com/engine/reference/commandline/tag/ + return bool(re.match('^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$', tag)) + + +def sanitize_result(data): + """Sanitize data object for return to Ansible. + + When the data object contains types such as docker.types.containers.HostConfig, + Ansible will fail when these are returned via exit_json or fail_json. + HostConfig is derived from dict, but its constructor requires additional + arguments. This function sanitizes data structures by recursively converting + everything derived from dict to dict and everything derived from list (and tuple) + to a list. + """ + if isinstance(data, dict): + return dict((k, sanitize_result(v)) for k, v in data.items()) + elif isinstance(data, (list, tuple)): + return [sanitize_result(v) for v in data] + else: + return data + + +class DockerBaseClass(object): + + def __init__(self): + self.debug = False + + def log(self, msg, pretty_print=False): + pass + # if self.debug: + # log_file = open('docker.log', 'a') + # if pretty_print: + # log_file.write(json.dumps(msg, sort_keys=True, indent=4, separators=(',', ': '))) + # log_file.write(u'\n') + # else: + # log_file.write(msg + u'\n') + + +def update_tls_hostname(result): + if result['tls_hostname'] is None: + # get default machine name from the url + parsed_url = urlparse(result['docker_host']) + if ':' in parsed_url.netloc: + result['tls_hostname'] = parsed_url.netloc[:parsed_url.netloc.rindex(':')] + else: + result['tls_hostname'] = parsed_url + + +def _get_tls_config(fail_function, **kwargs): + try: + tls_config = TLSConfig(**kwargs) + return tls_config + except TLSParameterError as exc: + fail_function("TLS config error: %s" % exc) + + +def get_connect_params(auth, fail_function): + if auth['tls'] or auth['tls_verify']: + auth['docker_host'] = auth['docker_host'].replace('tcp://', 'https://') + + if auth['tls_verify'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and host verification + if auth['cacert_path']: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + ca_cert=auth['cacert_path'], + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + else: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify'] and auth['cacert_path']: + # TLS with cacert only + tls_config = _get_tls_config(ca_cert=auth['cacert_path'], + assert_hostname=auth['tls_hostname'], + verify=True, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify']: + # TLS with verify and no certs + tls_config = _get_tls_config(verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and no host verification + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls']: + # TLS with no certs and not host verification + tls_config = _get_tls_config(verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + # No TLS + return dict(base_url=auth['docker_host'], + version=auth['api_version'], + timeout=auth['timeout']) + + +DOCKERPYUPGRADE_SWITCH_TO_DOCKER = "Try `pip uninstall docker-py` followed by `pip install docker`." +DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade." +DOCKERPYUPGRADE_RECOMMEND_DOCKER = ("Use `pip install --upgrade docker-py` to upgrade. " + "Hint: if you do not need Python 2.6 support, try " + "`pip uninstall docker-py` instead, followed by `pip install docker`.") + + +class AnsibleDockerClient(Client): + + def __init__(self, argument_spec=None, supports_check_mode=False, mutually_exclusive=None, + required_together=None, required_if=None, min_docker_version=MIN_DOCKER_VERSION, + min_docker_api_version=None, option_minimal_versions=None, + option_minimal_versions_ignore_params=None, fail_results=None): + + # Modules can put information in here which will always be returned + # in case client.fail() is called. + self.fail_results = fail_results or {} + + merged_arg_spec = dict() + merged_arg_spec.update(DOCKER_COMMON_ARGS) + if argument_spec: + merged_arg_spec.update(argument_spec) + self.arg_spec = merged_arg_spec + + mutually_exclusive_params = [] + mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE + if mutually_exclusive: + mutually_exclusive_params += mutually_exclusive + + required_together_params = [] + required_together_params += DOCKER_REQUIRED_TOGETHER + if required_together: + required_together_params += required_together + + self.module = AnsibleModule( + argument_spec=merged_arg_spec, + supports_check_mode=supports_check_mode, + mutually_exclusive=mutually_exclusive_params, + required_together=required_together_params, + required_if=required_if) + + NEEDS_DOCKER_PY2 = (LooseVersion(min_docker_version) >= LooseVersion('2.0.0')) + + self.docker_py_version = LooseVersion(docker_version) + + if HAS_DOCKER_MODELS and HAS_DOCKER_SSLADAPTER: + self.fail("Cannot have both the docker-py and docker python modules (old and new version of Docker " + "SDK for Python) installed together as they use the same namespace and cause a corrupt " + "installation. Please uninstall both packages, and re-install only the docker-py or docker " + "python module (for %s's Python %s). It is recommended to install the docker module if no " + "support for Python 2.6 is required. Please note that simply uninstalling one of the modules " + "can leave the other module in a broken state." % (platform.node(), sys.executable)) + + if not HAS_DOCKER_PY: + if NEEDS_DOCKER_PY2: + msg = missing_required_lib("Docker SDK for Python: docker") + msg = msg + ", for example via `pip install docker`. The error was: %s" + else: + msg = missing_required_lib("Docker SDK for Python: docker (Python >= 2.7) or docker-py (Python 2.6)") + msg = msg + ", for example via `pip install docker` or `pip install docker-py` (Python 2.6). The error was: %s" + self.fail(msg % HAS_DOCKER_ERROR) + + if self.docker_py_version < LooseVersion(min_docker_version): + msg = "Error: Docker SDK for Python version is %s (%s's Python %s). Minimum version required is %s." + if not NEEDS_DOCKER_PY2: + # The minimal required version is < 2.0 (and the current version as well). + # Advertise docker (instead of docker-py) for non-Python-2.6 users. + msg += DOCKERPYUPGRADE_RECOMMEND_DOCKER + elif docker_version < LooseVersion('2.0'): + msg += DOCKERPYUPGRADE_SWITCH_TO_DOCKER + else: + msg += DOCKERPYUPGRADE_UPGRADE_DOCKER + self.fail(msg % (docker_version, platform.node(), sys.executable, min_docker_version)) + + self.debug = self.module.params.get('debug') + self.check_mode = self.module.check_mode + self._connect_params = get_connect_params(self.auth_params, fail_function=self.fail) + + try: + super(AnsibleDockerClient, self).__init__(**self._connect_params) + self.docker_api_version_str = self.version()['ApiVersion'] + except APIError as exc: + self.fail("Docker API error: %s" % exc) + except Exception as exc: + self.fail("Error connecting: %s" % exc) + + self.docker_api_version = LooseVersion(self.docker_api_version_str) + if min_docker_api_version is not None: + if self.docker_api_version < LooseVersion(min_docker_api_version): + self.fail('Docker API version is %s. Minimum version required is %s.' % (self.docker_api_version_str, min_docker_api_version)) + + if option_minimal_versions is not None: + self._get_minimal_versions(option_minimal_versions, option_minimal_versions_ignore_params) + + def log(self, msg, pretty_print=False): + pass + # if self.debug: + # log_file = open('docker.log', 'a') + # if pretty_print: + # log_file.write(json.dumps(msg, sort_keys=True, indent=4, separators=(',', ': '))) + # log_file.write(u'\n') + # else: + # log_file.write(msg + u'\n') + + def fail(self, msg, **kwargs): + self.fail_results.update(kwargs) + self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) + + @staticmethod + def _get_value(param_name, param_value, env_variable, default_value): + if param_value is not None: + # take module parameter value + if param_value in BOOLEANS_TRUE: + return True + if param_value in BOOLEANS_FALSE: + return False + return param_value + + if env_variable is not None: + env_value = os.environ.get(env_variable) + if env_value is not None: + # take the env variable value + if param_name == 'cert_path': + return os.path.join(env_value, 'cert.pem') + if param_name == 'cacert_path': + return os.path.join(env_value, 'ca.pem') + if param_name == 'key_path': + return os.path.join(env_value, 'key.pem') + if env_value in BOOLEANS_TRUE: + return True + if env_value in BOOLEANS_FALSE: + return False + return env_value + + # take the default + return default_value + + @property + def auth_params(self): + # Get authentication credentials. + # Precedence: module parameters-> environment variables-> defaults. + + self.log('Getting credentials') + + params = dict() + for key in DOCKER_COMMON_ARGS: + params[key] = self.module.params.get(key) + + if self.module.params.get('use_tls'): + # support use_tls option in docker_image.py. This will be deprecated. + use_tls = self.module.params.get('use_tls') + if use_tls == 'encrypt': + params['tls'] = True + if use_tls == 'verify': + params['validate_certs'] = True + + result = dict( + docker_host=self._get_value('docker_host', params['docker_host'], 'DOCKER_HOST', + DEFAULT_DOCKER_HOST), + tls_hostname=self._get_value('tls_hostname', params['tls_hostname'], + 'DOCKER_TLS_HOSTNAME', DEFAULT_TLS_HOSTNAME), + api_version=self._get_value('api_version', params['api_version'], 'DOCKER_API_VERSION', + 'auto'), + cacert_path=self._get_value('cacert_path', params['ca_cert'], 'DOCKER_CERT_PATH', None), + cert_path=self._get_value('cert_path', params['client_cert'], 'DOCKER_CERT_PATH', None), + key_path=self._get_value('key_path', params['client_key'], 'DOCKER_CERT_PATH', None), + ssl_version=self._get_value('ssl_version', params['ssl_version'], 'DOCKER_SSL_VERSION', None), + tls=self._get_value('tls', params['tls'], 'DOCKER_TLS', DEFAULT_TLS), + tls_verify=self._get_value('tls_verfy', params['validate_certs'], 'DOCKER_TLS_VERIFY', + DEFAULT_TLS_VERIFY), + timeout=self._get_value('timeout', params['timeout'], 'DOCKER_TIMEOUT', + DEFAULT_TIMEOUT_SECONDS), + ) + + update_tls_hostname(result) + + return result + + def _handle_ssl_error(self, error): + match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) + if match: + self.fail("You asked for verification that Docker daemons certificate's hostname matches %s. " + "The actual certificate's hostname is %s. Most likely you need to set DOCKER_TLS_HOSTNAME " + "or pass `tls_hostname` with a value of %s. You may also use TLS without verification by " + "setting the `tls` parameter to true." + % (self.auth_params['tls_hostname'], match.group(1), match.group(1))) + self.fail("SSL Exception: %s" % (error)) + + def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): + self.option_minimal_versions = dict() + for option in self.module.argument_spec: + if ignore_params is not None: + if option in ignore_params: + continue + self.option_minimal_versions[option] = dict() + self.option_minimal_versions.update(option_minimal_versions) + + for option, data in self.option_minimal_versions.items(): + # Test whether option is supported, and store result + support_docker_py = True + support_docker_api = True + if 'docker_py_version' in data: + support_docker_py = self.docker_py_version >= LooseVersion(data['docker_py_version']) + if 'docker_api_version' in data: + support_docker_api = self.docker_api_version >= LooseVersion(data['docker_api_version']) + data['supported'] = support_docker_py and support_docker_api + # Fail if option is not supported but used + if not data['supported']: + # Test whether option is specified + if 'detect_usage' in data: + used = data['detect_usage'](self) + else: + used = self.module.params.get(option) is not None + if used and 'default' in self.module.argument_spec[option]: + used = self.module.params[option] != self.module.argument_spec[option]['default'] + if used: + # If the option is used, compose error message. + if 'usage_msg' in data: + usg = data['usage_msg'] + else: + usg = 'set %s option' % (option, ) + if not support_docker_api: + msg = 'Docker API version is %s. Minimum version required is %s to %s.' + msg = msg % (self.docker_api_version_str, data['docker_api_version'], usg) + elif not support_docker_py: + msg = "Docker SDK for Python version is %s (%s's Python %s). Minimum version required is %s to %s. " + if LooseVersion(data['docker_py_version']) < LooseVersion('2.0.0'): + msg += DOCKERPYUPGRADE_RECOMMEND_DOCKER + elif self.docker_py_version < LooseVersion('2.0.0'): + msg += DOCKERPYUPGRADE_SWITCH_TO_DOCKER + else: + msg += DOCKERPYUPGRADE_UPGRADE_DOCKER + msg = msg % (docker_version, platform.node(), sys.executable, data['docker_py_version'], usg) + else: + # should not happen + msg = 'Cannot %s with your configuration.' % (usg, ) + self.fail(msg) + + def get_container_by_id(self, container_id): + try: + self.log("Inspecting container Id %s" % container_id) + result = self.inspect_container(container=container_id) + self.log("Completed container inspection") + return result + except NotFound as dummy: + return None + except Exception as exc: + self.fail("Error inspecting container: %s" % exc) + + def get_container(self, name=None): + ''' + Lookup a container and return the inspection results. + ''' + if name is None: + return None + + search_name = name + if not name.startswith('/'): + search_name = '/' + name + + result = None + try: + for container in self.containers(all=True): + self.log("testing container: %s" % (container['Names'])) + if isinstance(container['Names'], list) and search_name in container['Names']: + result = container + break + if container['Id'].startswith(name): + result = container + break + if container['Id'] == name: + result = container + break + except SSLError as exc: + self._handle_ssl_error(exc) + except Exception as exc: + self.fail("Error retrieving container list: %s" % exc) + + if result is None: + return None + + return self.get_container_by_id(result['Id']) + + def get_network(self, name=None, network_id=None): + ''' + Lookup a network and return the inspection results. + ''' + if name is None and network_id is None: + return None + + result = None + + if network_id is None: + try: + for network in self.networks(): + self.log("testing network: %s" % (network['Name'])) + if name == network['Name']: + result = network + break + if network['Id'].startswith(name): + result = network + break + except SSLError as exc: + self._handle_ssl_error(exc) + except Exception as exc: + self.fail("Error retrieving network list: %s" % exc) + + if result is not None: + network_id = result['Id'] + + if network_id is not None: + try: + self.log("Inspecting network Id %s" % network_id) + result = self.inspect_network(network_id) + self.log("Completed network inspection") + except NotFound as dummy: + return None + except Exception as exc: + self.fail("Error inspecting network: %s" % exc) + + return result + + def find_image(self, name, tag): + ''' + Lookup an image (by name and tag) and return the inspection results. + ''' + if not name: + return None + + self.log("Find image %s:%s" % (name, tag)) + images = self._image_lookup(name, tag) + if not images: + # In API <= 1.20 seeing 'docker.io/<name>' as the name of images pulled from docker hub + registry, repo_name = auth.resolve_repository_name(name) + if registry == 'docker.io': + # If docker.io is explicitly there in name, the image + # isn't found in some cases (#41509) + self.log("Check for docker.io image: %s" % repo_name) + images = self._image_lookup(repo_name, tag) + if not images and repo_name.startswith('library/'): + # Sometimes library/xxx images are not found + lookup = repo_name[len('library/'):] + self.log("Check for docker.io image: %s" % lookup) + images = self._image_lookup(lookup, tag) + if not images: + # Last case: if docker.io wasn't there, it can be that + # the image wasn't found either (#15586) + lookup = "%s/%s" % (registry, repo_name) + self.log("Check for docker.io image: %s" % lookup) + images = self._image_lookup(lookup, tag) + + if len(images) > 1: + self.fail("Registry returned more than one result for %s:%s" % (name, tag)) + + if len(images) == 1: + try: + inspection = self.inspect_image(images[0]['Id']) + except Exception as exc: + self.fail("Error inspecting image %s:%s - %s" % (name, tag, str(exc))) + return inspection + + self.log("Image %s:%s not found." % (name, tag)) + return None + + def find_image_by_id(self, image_id): + ''' + Lookup an image (by ID) and return the inspection results. + ''' + if not image_id: + return None + + self.log("Find image %s (by ID)" % image_id) + try: + inspection = self.inspect_image(image_id) + except Exception as exc: + self.fail("Error inspecting image ID %s - %s" % (image_id, str(exc))) + return inspection + + def _image_lookup(self, name, tag): + ''' + Including a tag in the name parameter sent to the Docker SDK for Python images method + does not work consistently. Instead, get the result set for name and manually check + if the tag exists. + ''' + try: + response = self.images(name=name) + except Exception as exc: + self.fail("Error searching for image %s - %s" % (name, str(exc))) + images = response + if tag: + lookup = "%s:%s" % (name, tag) + lookup_digest = "%s@%s" % (name, tag) + images = [] + for image in response: + tags = image.get('RepoTags') + digests = image.get('RepoDigests') + if (tags and lookup in tags) or (digests and lookup_digest in digests): + images = [image] + break + return images + + def pull_image(self, name, tag="latest"): + ''' + Pull an image + ''' + self.log("Pulling image %s:%s" % (name, tag)) + old_tag = self.find_image(name, tag) + try: + for line in self.pull(name, tag=tag, stream=True, decode=True): + self.log(line, pretty_print=True) + if line.get('error'): + if line.get('errorDetail'): + error_detail = line.get('errorDetail') + self.fail("Error pulling %s - code: %s message: %s" % (name, + error_detail.get('code'), + error_detail.get('message'))) + else: + self.fail("Error pulling %s - %s" % (name, line.get('error'))) + except Exception as exc: + self.fail("Error pulling image %s:%s - %s" % (name, tag, str(exc))) + + new_tag = self.find_image(name, tag) + + return new_tag, old_tag == new_tag + + def report_warnings(self, result, warnings_key=None): + ''' + Checks result of client operation for warnings, and if present, outputs them. + + warnings_key should be a list of keys used to crawl the result dictionary. + For example, if warnings_key == ['a', 'b'], the function will consider + result['a']['b'] if these keys exist. If the result is a non-empty string, it + will be reported as a warning. If the result is a list, every entry will be + reported as a warning. + + In most cases (if warnings are returned at all), warnings_key should be + ['Warnings'] or ['Warning']. The default value (if not specified) is ['Warnings']. + ''' + if warnings_key is None: + warnings_key = ['Warnings'] + for key in warnings_key: + if not isinstance(result, Mapping): + return + result = result.get(key) + if isinstance(result, Sequence): + for warning in result: + self.module.warn('Docker warning: {0}'.format(warning)) + elif isinstance(result, string_types) and result: + self.module.warn('Docker warning: {0}'.format(result)) + + def inspect_distribution(self, image, **kwargs): + ''' + Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 + since prior versions did not support accessing private repositories. + ''' + if self.docker_py_version < LooseVersion('4.0.0'): + registry = auth.resolve_repository_name(image)[0] + header = auth.get_config_header(self, registry) + if header: + return self._result(self._get( + self._url('/distribution/{0}/json', image), + headers={'X-Registry-Auth': header} + ), json=True) + return super(AnsibleDockerClient, self).inspect_distribution(image, **kwargs) + + +def compare_dict_allow_more_present(av, bv): + ''' + Compare two dictionaries for whether every entry of the first is in the second. + ''' + for key, value in av.items(): + if key not in bv: + return False + if bv[key] != value: + return False + return True + + +def compare_generic(a, b, method, datatype): + ''' + Compare values a and b as described by method and datatype. + + Returns ``True`` if the values compare equal, and ``False`` if not. + + ``a`` is usually the module's parameter, while ``b`` is a property + of the current object. ``a`` must not be ``None`` (except for + ``datatype == 'value'``). + + Valid values for ``method`` are: + - ``ignore`` (always compare as equal); + - ``strict`` (only compare if really equal) + - ``allow_more_present`` (allow b to have elements which a does not have). + + Valid values for ``datatype`` are: + - ``value``: for simple values (strings, numbers, ...); + - ``list``: for ``list``s or ``tuple``s where order matters; + - ``set``: for ``list``s, ``tuple``s or ``set``s where order does not + matter; + - ``set(dict)``: for ``list``s, ``tuple``s or ``sets`` where order does + not matter and which contain ``dict``s; ``allow_more_present`` is used + for the ``dict``s, and these are assumed to be dictionaries of values; + - ``dict``: for dictionaries of values. + ''' + if method == 'ignore': + return True + # If a or b is None: + if a is None or b is None: + # If both are None: equality + if a == b: + return True + # Otherwise, not equal for values, and equal + # if the other is empty for set/list/dict + if datatype == 'value': + return False + # For allow_more_present, allow a to be None + if method == 'allow_more_present' and a is None: + return True + # Otherwise, the iterable object which is not None must have length 0 + return len(b if a is None else a) == 0 + # Do proper comparison (both objects not None) + if datatype == 'value': + return a == b + elif datatype == 'list': + if method == 'strict': + return a == b + else: + i = 0 + for v in a: + while i < len(b) and b[i] != v: + i += 1 + if i == len(b): + return False + i += 1 + return True + elif datatype == 'dict': + if method == 'strict': + return a == b + else: + return compare_dict_allow_more_present(a, b) + elif datatype == 'set': + set_a = set(a) + set_b = set(b) + if method == 'strict': + return set_a == set_b + else: + return set_b >= set_a + elif datatype == 'set(dict)': + for av in a: + found = False + for bv in b: + if compare_dict_allow_more_present(av, bv): + found = True + break + if not found: + return False + if method == 'strict': + # If we would know that both a and b do not contain duplicates, + # we could simply compare len(a) to len(b) to finish this test. + # We can assume that b has no duplicates (as it is returned by + # docker), but we don't know for a. + for bv in b: + found = False + for av in a: + if compare_dict_allow_more_present(av, bv): + found = True + break + if not found: + return False + return True + + +class DifferenceTracker(object): + def __init__(self): + self._diff = [] + + def add(self, name, parameter=None, active=None): + self._diff.append(dict( + name=name, + parameter=parameter, + active=active, + )) + + def merge(self, other_tracker): + self._diff.extend(other_tracker._diff) + + @property + def empty(self): + return len(self._diff) == 0 + + def get_before_after(self): + ''' + Return texts ``before`` and ``after``. + ''' + before = dict() + after = dict() + for item in self._diff: + before[item['name']] = item['active'] + after[item['name']] = item['parameter'] + return before, after + + def has_difference_for(self, name): + ''' + Returns a boolean if a difference exists for name + ''' + return any(diff for diff in self._diff if diff['name'] == name) + + def get_legacy_docker_container_diffs(self): + ''' + Return differences in the docker_container legacy format. + ''' + result = [] + for entry in self._diff: + item = dict() + item[entry['name']] = dict( + parameter=entry['parameter'], + container=entry['active'], + ) + result.append(item) + return result + + def get_legacy_docker_diffs(self): + ''' + Return differences in the docker_container legacy format. + ''' + result = [entry['name'] for entry in self._diff] + return result + + +def clean_dict_booleans_for_docker_api(data): + ''' + Go doesn't like Python booleans 'True' or 'False', while Ansible is just + fine with them in YAML. As such, they need to be converted in cases where + we pass dictionaries to the Docker API (e.g. docker_network's + driver_options and docker_prune's filters). + ''' + result = dict() + if data is not None: + for k, v in data.items(): + if v is True: + v = 'true' + elif v is False: + v = 'false' + else: + v = str(v) + result[str(k)] = v + return result + + +def convert_duration_to_nanosecond(time_str): + """ + Return time duration in nanosecond. + """ + if not isinstance(time_str, str): + raise ValueError('Missing unit in duration - %s' % time_str) + + regex = re.compile( + r'^(((?P<hours>\d+)h)?' + r'((?P<minutes>\d+)m(?!s))?' + r'((?P<seconds>\d+)s)?' + r'((?P<milliseconds>\d+)ms)?' + r'((?P<microseconds>\d+)us)?)$' + ) + parts = regex.match(time_str) + + if not parts: + raise ValueError('Invalid time duration - %s' % time_str) + + parts = parts.groupdict() + time_params = {} + for (name, value) in parts.items(): + if value: + time_params[name] = int(value) + + delta = timedelta(**time_params) + time_in_nanoseconds = ( + delta.microseconds + (delta.seconds + delta.days * 24 * 3600) * 10 ** 6 + ) * 10 ** 3 + + return time_in_nanoseconds + + +def parse_healthcheck(healthcheck): + """ + Return dictionary of healthcheck parameters and boolean if + healthcheck defined in image was requested to be disabled. + """ + if (not healthcheck) or (not healthcheck.get('test')): + return None, None + + result = dict() + + # All supported healthcheck parameters + options = dict( + test='test', + interval='interval', + timeout='timeout', + start_period='start_period', + retries='retries' + ) + + duration_options = ['interval', 'timeout', 'start_period'] + + for (key, value) in options.items(): + if value in healthcheck: + if healthcheck.get(value) is None: + # due to recursive argument_spec, all keys are always present + # (but have default value None if not specified) + continue + if value in duration_options: + time = convert_duration_to_nanosecond(healthcheck.get(value)) + if time: + result[key] = time + elif healthcheck.get(value): + result[key] = healthcheck.get(value) + if key == 'test': + if isinstance(result[key], (tuple, list)): + result[key] = [str(e) for e in result[key]] + else: + result[key] = ['CMD-SHELL', str(result[key])] + elif key == 'retries': + try: + result[key] = int(result[key]) + except ValueError: + raise ValueError( + 'Cannot parse number of retries for healthcheck. ' + 'Expected an integer, got "{0}".'.format(result[key]) + ) + + if result['test'] == ['NONE']: + # If the user explicitly disables the healthcheck, return None + # as the healthcheck object, and set disable_healthcheck to True + return None, True + + return result, False + + +def omit_none_from_dict(d): + """ + Return a copy of the dictionary with all keys with value None omitted. + """ + return dict((k, v) for (k, v) in d.items() if v is not None) diff --git a/test/support/integration/plugins/module_utils/docker/swarm.py b/test/support/integration/plugins/module_utils/docker/swarm.py new file mode 100644 index 00000000..55d94db0 --- /dev/null +++ b/test/support/integration/plugins/module_utils/docker/swarm.py @@ -0,0 +1,280 @@ +# (c) 2019 Piotr Wojciechowski (@wojciechowskipiotr) <piotr@it-playground.pl> +# (c) Thierry Bouvet (@tbouvet) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +import json +from time import sleep + +try: + from docker.errors import APIError, NotFound +except ImportError: + # missing Docker SDK for Python handled in ansible.module_utils.docker.common + pass + +from ansible.module_utils._text import to_native +from ansible.module_utils.docker.common import ( + AnsibleDockerClient, + LooseVersion, +) + + +class AnsibleDockerSwarmClient(AnsibleDockerClient): + + def __init__(self, **kwargs): + super(AnsibleDockerSwarmClient, self).__init__(**kwargs) + + def get_swarm_node_id(self): + """ + Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID + of Docker host the module is executed on + :return: + NodeID of host or 'None' if not part of Swarm + """ + + try: + info = self.info() + except APIError as exc: + self.fail("Failed to get node information for %s" % to_native(exc)) + + if info: + json_str = json.dumps(info, ensure_ascii=False) + swarm_info = json.loads(json_str) + if swarm_info['Swarm']['NodeID']: + return swarm_info['Swarm']['NodeID'] + return None + + def check_if_swarm_node(self, node_id=None): + """ + Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host + system information looking if specific key in output exists. If 'node_id' is provided then it tries to + read node information assuming it is run on Swarm manager. The get_node_inspect() method handles exception if + it is not executed on Swarm manager + + :param node_id: Node identifier + :return: + bool: True if node is part of Swarm, False otherwise + """ + + if node_id is None: + try: + info = self.info() + except APIError: + self.fail("Failed to get host information.") + + if info: + json_str = json.dumps(info, ensure_ascii=False) + swarm_info = json.loads(json_str) + if swarm_info['Swarm']['NodeID']: + return True + if swarm_info['Swarm']['LocalNodeState'] in ('active', 'pending', 'locked'): + return True + return False + else: + try: + node_info = self.get_node_inspect(node_id=node_id) + except APIError: + return + + if node_info['ID'] is not None: + return True + return False + + def check_if_swarm_manager(self): + """ + Checks if node role is set as Manager in Swarm. The node is the docker host on which module action + is performed. The inspect_swarm() will fail if node is not a manager + + :return: True if node is Swarm Manager, False otherwise + """ + + try: + self.inspect_swarm() + return True + except APIError: + return False + + def fail_task_if_not_swarm_manager(self): + """ + If host is not a swarm manager then Ansible task on this host should end with 'failed' state + """ + if not self.check_if_swarm_manager(): + self.fail("Error running docker swarm module: must run on swarm manager node") + + def check_if_swarm_worker(self): + """ + Checks if node role is set as Worker in Swarm. The node is the docker host on which module action + is performed. Will fail if run on host that is not part of Swarm via check_if_swarm_node() + + :return: True if node is Swarm Worker, False otherwise + """ + + if self.check_if_swarm_node() and not self.check_if_swarm_manager(): + return True + return False + + def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1): + """ + Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about + node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or + host that is not part of Swarm it will fail the playbook + + :param repeat_check: number of check attempts with 5 seconds delay between them, by default check only once + :param node_id: node ID or name, if None then method will try to get node_id of host module run on + :return: + True if node is part of swarm but its state is down, False otherwise + """ + + if repeat_check < 1: + repeat_check = 1 + + if node_id is None: + node_id = self.get_swarm_node_id() + + for retry in range(0, repeat_check): + if retry > 0: + sleep(5) + node_info = self.get_node_inspect(node_id=node_id) + if node_info['Status']['State'] == 'down': + return True + return False + + def get_node_inspect(self, node_id=None, skip_missing=False): + """ + Returns Swarm node info as in 'docker node inspect' command about single node + + :param skip_missing: if True then function will return None instead of failing the task + :param node_id: node ID or name, if None then method will try to get node_id of host module run on + :return: + Single node information structure + """ + + if node_id is None: + node_id = self.get_swarm_node_id() + + if node_id is None: + self.fail("Failed to get node information.") + + try: + node_info = self.inspect_node(node_id=node_id) + except APIError as exc: + if exc.status_code == 503: + self.fail("Cannot inspect node: To inspect node execute module on Swarm Manager") + if exc.status_code == 404: + if skip_missing: + return None + self.fail("Error while reading from Swarm manager: %s" % to_native(exc)) + except Exception as exc: + self.fail("Error inspecting swarm node: %s" % exc) + + json_str = json.dumps(node_info, ensure_ascii=False) + node_info = json.loads(json_str) + + if 'ManagerStatus' in node_info: + if node_info['ManagerStatus'].get('Leader'): + # This is workaround of bug in Docker when in some cases the Leader IP is 0.0.0.0 + # Check moby/moby#35437 for details + count_colons = node_info['ManagerStatus']['Addr'].count(":") + if count_colons == 1: + swarm_leader_ip = node_info['ManagerStatus']['Addr'].split(":", 1)[0] or node_info['Status']['Addr'] + else: + swarm_leader_ip = node_info['Status']['Addr'] + node_info['Status']['Addr'] = swarm_leader_ip + return node_info + + def get_all_nodes_inspect(self): + """ + Returns Swarm node info as in 'docker node inspect' command about all registered nodes + + :return: + Structure with information about all nodes + """ + try: + node_info = self.nodes() + except APIError as exc: + if exc.status_code == 503: + self.fail("Cannot inspect node: To inspect node execute module on Swarm Manager") + self.fail("Error while reading from Swarm manager: %s" % to_native(exc)) + except Exception as exc: + self.fail("Error inspecting swarm node: %s" % exc) + + json_str = json.dumps(node_info, ensure_ascii=False) + node_info = json.loads(json_str) + return node_info + + def get_all_nodes_list(self, output='short'): + """ + Returns list of nodes registered in Swarm + + :param output: Defines format of returned data + :return: + If 'output' is 'short' then return data is list of nodes hostnames registered in Swarm, + if 'output' is 'long' then returns data is list of dict containing the attributes as in + output of command 'docker node ls' + """ + nodes_list = [] + + nodes_inspect = self.get_all_nodes_inspect() + if nodes_inspect is None: + return None + + if output == 'short': + for node in nodes_inspect: + nodes_list.append(node['Description']['Hostname']) + elif output == 'long': + for node in nodes_inspect: + node_property = {} + + node_property.update({'ID': node['ID']}) + node_property.update({'Hostname': node['Description']['Hostname']}) + node_property.update({'Status': node['Status']['State']}) + node_property.update({'Availability': node['Spec']['Availability']}) + if 'ManagerStatus' in node: + if node['ManagerStatus']['Leader'] is True: + node_property.update({'Leader': True}) + node_property.update({'ManagerStatus': node['ManagerStatus']['Reachability']}) + node_property.update({'EngineVersion': node['Description']['Engine']['EngineVersion']}) + + nodes_list.append(node_property) + else: + return None + + return nodes_list + + def get_node_name_by_id(self, nodeid): + return self.get_node_inspect(nodeid)['Description']['Hostname'] + + def get_unlock_key(self): + if self.docker_py_version < LooseVersion('2.7.0'): + return None + return super(AnsibleDockerSwarmClient, self).get_unlock_key() + + def get_service_inspect(self, service_id, skip_missing=False): + """ + Returns Swarm service info as in 'docker service inspect' command about single service + + :param service_id: service ID or name + :param skip_missing: if True then function will return None instead of failing the task + :return: + Single service information structure + """ + try: + service_info = self.inspect_service(service_id) + except NotFound as exc: + if skip_missing is False: + self.fail("Error while reading from Swarm manager: %s" % to_native(exc)) + else: + return None + except APIError as exc: + if exc.status_code == 503: + self.fail("Cannot inspect service: To inspect service execute module on Swarm Manager") + self.fail("Error inspecting swarm service: %s" % exc) + except Exception as exc: + self.fail("Error inspecting swarm service: %s" % exc) + + json_str = json.dumps(service_info, ensure_ascii=False) + service_info = json.loads(json_str) + return service_info diff --git a/test/support/integration/plugins/module_utils/ec2.py b/test/support/integration/plugins/module_utils/ec2.py new file mode 100644 index 00000000..0d28108d --- /dev/null +++ b/test/support/integration/plugins/module_utils/ec2.py @@ -0,0 +1,758 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013 +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re +import sys +import traceback + +from ansible.module_utils.ansible_release import __version__ +from ansible.module_utils.basic import missing_required_lib, env_fallback +from ansible.module_utils._text import to_native, to_text +from ansible.module_utils.cloud import CloudRetry +from ansible.module_utils.six import string_types, binary_type, text_type +from ansible.module_utils.common.dict_transformations import ( + camel_dict_to_snake_dict, snake_dict_to_camel_dict, + _camel_to_snake, _snake_to_camel, +) + +BOTO_IMP_ERR = None +try: + import boto + import boto.ec2 # boto does weird import stuff + HAS_BOTO = True +except ImportError: + BOTO_IMP_ERR = traceback.format_exc() + HAS_BOTO = False + +BOTO3_IMP_ERR = None +try: + import boto3 + import botocore + HAS_BOTO3 = True +except Exception: + BOTO3_IMP_ERR = traceback.format_exc() + HAS_BOTO3 = False + +try: + # Although this is to allow Python 3 the ability to use the custom comparison as a key, Python 2.7 also + # uses this (and it works as expected). Python 2.6 will trigger the ImportError. + from functools import cmp_to_key + PY3_COMPARISON = True +except ImportError: + PY3_COMPARISON = False + + +class AnsibleAWSError(Exception): + pass + + +def _botocore_exception_maybe(): + """ + Allow for boto3 not being installed when using these utils by wrapping + botocore.exceptions instead of assigning from it directly. + """ + if HAS_BOTO3: + return botocore.exceptions.ClientError + return type(None) + + +class AWSRetry(CloudRetry): + base_class = _botocore_exception_maybe() + + @staticmethod + def status_code_from_exception(error): + return error.response['Error']['Code'] + + @staticmethod + def found(response_code, catch_extra_error_codes=None): + # This list of failures is based on this API Reference + # http://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html + # + # TooManyRequestsException comes from inside botocore when it + # does retrys, unfortunately however it does not try long + # enough to allow some services such as API Gateway to + # complete configuration. At the moment of writing there is a + # botocore/boto3 bug open to fix this. + # + # https://github.com/boto/boto3/issues/876 (and linked PRs etc) + retry_on = [ + 'RequestLimitExceeded', 'Unavailable', 'ServiceUnavailable', + 'InternalFailure', 'InternalError', 'TooManyRequestsException', + 'Throttling' + ] + if catch_extra_error_codes: + retry_on.extend(catch_extra_error_codes) + + return response_code in retry_on + + +def boto3_conn(module, conn_type=None, resource=None, region=None, endpoint=None, **params): + try: + return _boto3_conn(conn_type=conn_type, resource=resource, region=region, endpoint=endpoint, **params) + except ValueError as e: + module.fail_json(msg="Couldn't connect to AWS: %s" % to_native(e)) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError, + botocore.exceptions.NoCredentialsError, botocore.exceptions.ConfigParseError) as e: + module.fail_json(msg=to_native(e)) + except botocore.exceptions.NoRegionError as e: + module.fail_json(msg="The %s module requires a region and none was found in configuration, " + "environment variables or module parameters" % module._name) + + +def _boto3_conn(conn_type=None, resource=None, region=None, endpoint=None, **params): + profile = params.pop('profile_name', None) + + if conn_type not in ['both', 'resource', 'client']: + raise ValueError('There is an issue in the calling code. You ' + 'must specify either both, resource, or client to ' + 'the conn_type parameter in the boto3_conn function ' + 'call') + + config = botocore.config.Config( + user_agent_extra='Ansible/{0}'.format(__version__), + ) + + if params.get('config') is not None: + config = config.merge(params.pop('config')) + if params.get('aws_config') is not None: + config = config.merge(params.pop('aws_config')) + + session = boto3.session.Session( + profile_name=profile, + ) + + if conn_type == 'resource': + return session.resource(resource, config=config, region_name=region, endpoint_url=endpoint, **params) + elif conn_type == 'client': + return session.client(resource, config=config, region_name=region, endpoint_url=endpoint, **params) + else: + client = session.client(resource, region_name=region, endpoint_url=endpoint, **params) + resource = session.resource(resource, region_name=region, endpoint_url=endpoint, **params) + return client, resource + + +boto3_inventory_conn = _boto3_conn + + +def boto_exception(err): + """ + Extracts the error message from a boto exception. + + :param err: Exception from boto + :return: Error message + """ + if hasattr(err, 'error_message'): + error = err.error_message + elif hasattr(err, 'message'): + error = str(err.message) + ' ' + str(err) + ' - ' + str(type(err)) + else: + error = '%s: %s' % (Exception, err) + + return error + + +def aws_common_argument_spec(): + return dict( + debug_botocore_endpoint_logs=dict(fallback=(env_fallback, ['ANSIBLE_DEBUG_BOTOCORE_LOGS']), default=False, type='bool'), + ec2_url=dict(), + aws_secret_key=dict(aliases=['ec2_secret_key', 'secret_key'], no_log=True), + aws_access_key=dict(aliases=['ec2_access_key', 'access_key']), + validate_certs=dict(default=True, type='bool'), + security_token=dict(aliases=['access_token'], no_log=True), + profile=dict(), + aws_config=dict(type='dict'), + ) + + +def ec2_argument_spec(): + spec = aws_common_argument_spec() + spec.update( + dict( + region=dict(aliases=['aws_region', 'ec2_region']), + ) + ) + return spec + + +def get_aws_region(module, boto3=False): + region = module.params.get('region') + + if region: + return region + + if 'AWS_REGION' in os.environ: + return os.environ['AWS_REGION'] + if 'AWS_DEFAULT_REGION' in os.environ: + return os.environ['AWS_DEFAULT_REGION'] + if 'EC2_REGION' in os.environ: + return os.environ['EC2_REGION'] + + if not boto3: + if not HAS_BOTO: + module.fail_json(msg=missing_required_lib('boto'), exception=BOTO_IMP_ERR) + # boto.config.get returns None if config not found + region = boto.config.get('Boto', 'aws_region') + if region: + return region + return boto.config.get('Boto', 'ec2_region') + + if not HAS_BOTO3: + module.fail_json(msg=missing_required_lib('boto3'), exception=BOTO3_IMP_ERR) + + # here we don't need to make an additional call, will default to 'us-east-1' if the below evaluates to None. + try: + profile_name = module.params.get('profile') + return botocore.session.Session(profile=profile_name).get_config_variable('region') + except botocore.exceptions.ProfileNotFound as e: + return None + + +def get_aws_connection_info(module, boto3=False): + + # Check module args for credentials, then check environment vars + # access_key + + ec2_url = module.params.get('ec2_url') + access_key = module.params.get('aws_access_key') + secret_key = module.params.get('aws_secret_key') + security_token = module.params.get('security_token') + region = get_aws_region(module, boto3) + profile_name = module.params.get('profile') + validate_certs = module.params.get('validate_certs') + config = module.params.get('aws_config') + + if not ec2_url: + if 'AWS_URL' in os.environ: + ec2_url = os.environ['AWS_URL'] + elif 'EC2_URL' in os.environ: + ec2_url = os.environ['EC2_URL'] + + if not access_key: + if os.environ.get('AWS_ACCESS_KEY_ID'): + access_key = os.environ['AWS_ACCESS_KEY_ID'] + elif os.environ.get('AWS_ACCESS_KEY'): + access_key = os.environ['AWS_ACCESS_KEY'] + elif os.environ.get('EC2_ACCESS_KEY'): + access_key = os.environ['EC2_ACCESS_KEY'] + elif HAS_BOTO and boto.config.get('Credentials', 'aws_access_key_id'): + access_key = boto.config.get('Credentials', 'aws_access_key_id') + elif HAS_BOTO and boto.config.get('default', 'aws_access_key_id'): + access_key = boto.config.get('default', 'aws_access_key_id') + else: + # in case access_key came in as empty string + access_key = None + + if not secret_key: + if os.environ.get('AWS_SECRET_ACCESS_KEY'): + secret_key = os.environ['AWS_SECRET_ACCESS_KEY'] + elif os.environ.get('AWS_SECRET_KEY'): + secret_key = os.environ['AWS_SECRET_KEY'] + elif os.environ.get('EC2_SECRET_KEY'): + secret_key = os.environ['EC2_SECRET_KEY'] + elif HAS_BOTO and boto.config.get('Credentials', 'aws_secret_access_key'): + secret_key = boto.config.get('Credentials', 'aws_secret_access_key') + elif HAS_BOTO and boto.config.get('default', 'aws_secret_access_key'): + secret_key = boto.config.get('default', 'aws_secret_access_key') + else: + # in case secret_key came in as empty string + secret_key = None + + if not security_token: + if os.environ.get('AWS_SECURITY_TOKEN'): + security_token = os.environ['AWS_SECURITY_TOKEN'] + elif os.environ.get('AWS_SESSION_TOKEN'): + security_token = os.environ['AWS_SESSION_TOKEN'] + elif os.environ.get('EC2_SECURITY_TOKEN'): + security_token = os.environ['EC2_SECURITY_TOKEN'] + elif HAS_BOTO and boto.config.get('Credentials', 'aws_security_token'): + security_token = boto.config.get('Credentials', 'aws_security_token') + elif HAS_BOTO and boto.config.get('default', 'aws_security_token'): + security_token = boto.config.get('default', 'aws_security_token') + else: + # in case secret_token came in as empty string + security_token = None + + if HAS_BOTO3 and boto3: + boto_params = dict(aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + aws_session_token=security_token) + boto_params['verify'] = validate_certs + + if profile_name: + boto_params = dict(aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) + boto_params['profile_name'] = profile_name + + else: + boto_params = dict(aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + security_token=security_token) + + # only set profile_name if passed as an argument + if profile_name: + boto_params['profile_name'] = profile_name + + boto_params['validate_certs'] = validate_certs + + if config is not None: + if HAS_BOTO3 and boto3: + boto_params['aws_config'] = botocore.config.Config(**config) + elif HAS_BOTO and not boto3: + if 'user_agent' in config: + sys.modules["boto.connection"].UserAgent = config['user_agent'] + + for param, value in boto_params.items(): + if isinstance(value, binary_type): + boto_params[param] = text_type(value, 'utf-8', 'strict') + + return region, ec2_url, boto_params + + +def get_ec2_creds(module): + ''' for compatibility mode with old modules that don't/can't yet + use ec2_connect method ''' + region, ec2_url, boto_params = get_aws_connection_info(module) + return ec2_url, boto_params['aws_access_key_id'], boto_params['aws_secret_access_key'], region + + +def boto_fix_security_token_in_profile(conn, profile_name): + ''' monkey patch for boto issue boto/boto#2100 ''' + profile = 'profile ' + profile_name + if boto.config.has_option(profile, 'aws_security_token'): + conn.provider.set_security_token(boto.config.get(profile, 'aws_security_token')) + return conn + + +def connect_to_aws(aws_module, region, **params): + try: + conn = aws_module.connect_to_region(region, **params) + except(boto.provider.ProfileNotFoundError): + raise AnsibleAWSError("Profile given for AWS was not found. Please fix and retry.") + if not conn: + if region not in [aws_module_region.name for aws_module_region in aws_module.regions()]: + raise AnsibleAWSError("Region %s does not seem to be available for aws module %s. If the region definitely exists, you may need to upgrade " + "boto or extend with endpoints_path" % (region, aws_module.__name__)) + else: + raise AnsibleAWSError("Unknown problem connecting to region %s for aws module %s." % (region, aws_module.__name__)) + if params.get('profile_name'): + conn = boto_fix_security_token_in_profile(conn, params['profile_name']) + return conn + + +def ec2_connect(module): + + """ Return an ec2 connection""" + + region, ec2_url, boto_params = get_aws_connection_info(module) + + # If we have a region specified, connect to its endpoint. + if region: + try: + ec2 = connect_to_aws(boto.ec2, region, **boto_params) + except (boto.exception.NoAuthHandlerFound, AnsibleAWSError, boto.provider.ProfileNotFoundError) as e: + module.fail_json(msg=str(e)) + # Otherwise, no region so we fallback to the old connection method + elif ec2_url: + try: + ec2 = boto.connect_ec2_endpoint(ec2_url, **boto_params) + except (boto.exception.NoAuthHandlerFound, AnsibleAWSError, boto.provider.ProfileNotFoundError) as e: + module.fail_json(msg=str(e)) + else: + module.fail_json(msg="Either region or ec2_url must be specified") + + return ec2 + + +def ansible_dict_to_boto3_filter_list(filters_dict): + + """ Convert an Ansible dict of filters to list of dicts that boto3 can use + Args: + filters_dict (dict): Dict of AWS filters. + Basic Usage: + >>> filters = {'some-aws-id': 'i-01234567'} + >>> ansible_dict_to_boto3_filter_list(filters) + { + 'some-aws-id': 'i-01234567' + } + Returns: + List: List of AWS filters and their values + [ + { + 'Name': 'some-aws-id', + 'Values': [ + 'i-01234567', + ] + } + ] + """ + + filters_list = [] + for k, v in filters_dict.items(): + filter_dict = {'Name': k} + if isinstance(v, string_types): + filter_dict['Values'] = [v] + else: + filter_dict['Values'] = v + + filters_list.append(filter_dict) + + return filters_list + + +def boto3_tag_list_to_ansible_dict(tags_list, tag_name_key_name=None, tag_value_key_name=None): + + """ Convert a boto3 list of resource tags to a flat dict of key:value pairs + Args: + tags_list (list): List of dicts representing AWS tags. + tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key") + tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value") + Basic Usage: + >>> tags_list = [{'Key': 'MyTagKey', 'Value': 'MyTagValue'}] + >>> boto3_tag_list_to_ansible_dict(tags_list) + [ + { + 'Key': 'MyTagKey', + 'Value': 'MyTagValue' + } + ] + Returns: + Dict: Dict of key:value pairs representing AWS tags + { + 'MyTagKey': 'MyTagValue', + } + """ + + if tag_name_key_name and tag_value_key_name: + tag_candidates = {tag_name_key_name: tag_value_key_name} + else: + tag_candidates = {'key': 'value', 'Key': 'Value'} + + if not tags_list: + return {} + for k, v in tag_candidates.items(): + if k in tags_list[0] and v in tags_list[0]: + return dict((tag[k], tag[v]) for tag in tags_list) + raise ValueError("Couldn't find tag key (candidates %s) in tag list %s" % (str(tag_candidates), str(tags_list))) + + +def ansible_dict_to_boto3_tag_list(tags_dict, tag_name_key_name='Key', tag_value_key_name='Value'): + + """ Convert a flat dict of key:value pairs representing AWS resource tags to a boto3 list of dicts + Args: + tags_dict (dict): Dict representing AWS resource tags. + tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key") + tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value") + Basic Usage: + >>> tags_dict = {'MyTagKey': 'MyTagValue'} + >>> ansible_dict_to_boto3_tag_list(tags_dict) + { + 'MyTagKey': 'MyTagValue' + } + Returns: + List: List of dicts containing tag keys and values + [ + { + 'Key': 'MyTagKey', + 'Value': 'MyTagValue' + } + ] + """ + + tags_list = [] + for k, v in tags_dict.items(): + tags_list.append({tag_name_key_name: k, tag_value_key_name: to_native(v)}) + + return tags_list + + +def get_ec2_security_group_ids_from_names(sec_group_list, ec2_connection, vpc_id=None, boto3=True): + + """ Return list of security group IDs from security group names. Note that security group names are not unique + across VPCs. If a name exists across multiple VPCs and no VPC ID is supplied, all matching IDs will be returned. This + will probably lead to a boto exception if you attempt to assign both IDs to a resource so ensure you wrap the call in + a try block + """ + + def get_sg_name(sg, boto3): + + if boto3: + return sg['GroupName'] + else: + return sg.name + + def get_sg_id(sg, boto3): + + if boto3: + return sg['GroupId'] + else: + return sg.id + + sec_group_id_list = [] + + if isinstance(sec_group_list, string_types): + sec_group_list = [sec_group_list] + + # Get all security groups + if boto3: + if vpc_id: + filters = [ + { + 'Name': 'vpc-id', + 'Values': [ + vpc_id, + ] + } + ] + all_sec_groups = ec2_connection.describe_security_groups(Filters=filters)['SecurityGroups'] + else: + all_sec_groups = ec2_connection.describe_security_groups()['SecurityGroups'] + else: + if vpc_id: + filters = {'vpc-id': vpc_id} + all_sec_groups = ec2_connection.get_all_security_groups(filters=filters) + else: + all_sec_groups = ec2_connection.get_all_security_groups() + + unmatched = set(sec_group_list).difference(str(get_sg_name(all_sg, boto3)) for all_sg in all_sec_groups) + sec_group_name_list = list(set(sec_group_list) - set(unmatched)) + + if len(unmatched) > 0: + # If we have unmatched names that look like an ID, assume they are + import re + sec_group_id_list = [sg for sg in unmatched if re.match('sg-[a-fA-F0-9]+$', sg)] + still_unmatched = [sg for sg in unmatched if not re.match('sg-[a-fA-F0-9]+$', sg)] + if len(still_unmatched) > 0: + raise ValueError("The following group names are not valid: %s" % ', '.join(still_unmatched)) + + sec_group_id_list += [str(get_sg_id(all_sg, boto3)) for all_sg in all_sec_groups if str(get_sg_name(all_sg, boto3)) in sec_group_name_list] + + return sec_group_id_list + + +def _hashable_policy(policy, policy_list): + """ + Takes a policy and returns a list, the contents of which are all hashable and sorted. + Example input policy: + {'Version': '2012-10-17', + 'Statement': [{'Action': 's3:PutObjectAcl', + 'Sid': 'AddCannedAcl2', + 'Resource': 'arn:aws:s3:::test_policy/*', + 'Effect': 'Allow', + 'Principal': {'AWS': ['arn:aws:iam::XXXXXXXXXXXX:user/username1', 'arn:aws:iam::XXXXXXXXXXXX:user/username2']} + }]} + Returned value: + [('Statement', ((('Action', (u's3:PutObjectAcl',)), + ('Effect', (u'Allow',)), + ('Principal', ('AWS', ((u'arn:aws:iam::XXXXXXXXXXXX:user/username1',), (u'arn:aws:iam::XXXXXXXXXXXX:user/username2',)))), + ('Resource', (u'arn:aws:s3:::test_policy/*',)), ('Sid', (u'AddCannedAcl2',)))), + ('Version', (u'2012-10-17',)))] + + """ + # Amazon will automatically convert bool and int to strings for us + if isinstance(policy, bool): + return tuple([str(policy).lower()]) + elif isinstance(policy, int): + return tuple([str(policy)]) + + if isinstance(policy, list): + for each in policy: + tupleified = _hashable_policy(each, []) + if isinstance(tupleified, list): + tupleified = tuple(tupleified) + policy_list.append(tupleified) + elif isinstance(policy, string_types) or isinstance(policy, binary_type): + policy = to_text(policy) + # convert root account ARNs to just account IDs + if policy.startswith('arn:aws:iam::') and policy.endswith(':root'): + policy = policy.split(':')[4] + return [policy] + elif isinstance(policy, dict): + sorted_keys = list(policy.keys()) + sorted_keys.sort() + for key in sorted_keys: + tupleified = _hashable_policy(policy[key], []) + if isinstance(tupleified, list): + tupleified = tuple(tupleified) + policy_list.append((key, tupleified)) + + # ensure we aren't returning deeply nested structures of length 1 + if len(policy_list) == 1 and isinstance(policy_list[0], tuple): + policy_list = policy_list[0] + if isinstance(policy_list, list): + if PY3_COMPARISON: + policy_list.sort(key=cmp_to_key(py3cmp)) + else: + policy_list.sort() + return policy_list + + +def py3cmp(a, b): + """ Python 2 can sort lists of mixed types. Strings < tuples. Without this function this fails on Python 3.""" + try: + if a > b: + return 1 + elif a < b: + return -1 + else: + return 0 + except TypeError as e: + # check to see if they're tuple-string + # always say strings are less than tuples (to maintain compatibility with python2) + str_ind = to_text(e).find('str') + tup_ind = to_text(e).find('tuple') + if -1 not in (str_ind, tup_ind): + if str_ind < tup_ind: + return -1 + elif tup_ind < str_ind: + return 1 + raise + + +def compare_policies(current_policy, new_policy): + """ Compares the existing policy and the updated policy + Returns True if there is a difference between policies. + """ + return set(_hashable_policy(new_policy, [])) != set(_hashable_policy(current_policy, [])) + + +def sort_json_policy_dict(policy_dict): + + """ Sort any lists in an IAM JSON policy so that comparison of two policies with identical values but + different orders will return true + Args: + policy_dict (dict): Dict representing IAM JSON policy. + Basic Usage: + >>> my_iam_policy = {'Principle': {'AWS':["31","7","14","101"]} + >>> sort_json_policy_dict(my_iam_policy) + Returns: + Dict: Will return a copy of the policy as a Dict but any List will be sorted + { + 'Principle': { + 'AWS': [ '7', '14', '31', '101' ] + } + } + """ + + def value_is_list(my_list): + + checked_list = [] + for item in my_list: + if isinstance(item, dict): + checked_list.append(sort_json_policy_dict(item)) + elif isinstance(item, list): + checked_list.append(value_is_list(item)) + else: + checked_list.append(item) + + # Sort list. If it's a list of dictionaries, sort by tuple of key-value + # pairs, since Python 3 doesn't allow comparisons such as `<` between dictionaries. + checked_list.sort(key=lambda x: sorted(x.items()) if isinstance(x, dict) else x) + return checked_list + + ordered_policy_dict = {} + for key, value in policy_dict.items(): + if isinstance(value, dict): + ordered_policy_dict[key] = sort_json_policy_dict(value) + elif isinstance(value, list): + ordered_policy_dict[key] = value_is_list(value) + else: + ordered_policy_dict[key] = value + + return ordered_policy_dict + + +def map_complex_type(complex_type, type_map): + """ + Allows to cast elements within a dictionary to a specific type + Example of usage: + + DEPLOYMENT_CONFIGURATION_TYPE_MAP = { + 'maximum_percent': 'int', + 'minimum_healthy_percent': 'int' + } + + deployment_configuration = map_complex_type(module.params['deployment_configuration'], + DEPLOYMENT_CONFIGURATION_TYPE_MAP) + + This ensures all keys within the root element are casted and valid integers + """ + + if complex_type is None: + return + new_type = type(complex_type)() + if isinstance(complex_type, dict): + for key in complex_type: + if key in type_map: + if isinstance(type_map[key], list): + new_type[key] = map_complex_type( + complex_type[key], + type_map[key][0]) + else: + new_type[key] = map_complex_type( + complex_type[key], + type_map[key]) + else: + return complex_type + elif isinstance(complex_type, list): + for i in range(len(complex_type)): + new_type.append(map_complex_type( + complex_type[i], + type_map)) + elif type_map: + return globals()['__builtins__'][type_map](complex_type) + return new_type + + +def compare_aws_tags(current_tags_dict, new_tags_dict, purge_tags=True): + """ + Compare two dicts of AWS tags. Dicts are expected to of been created using 'boto3_tag_list_to_ansible_dict' helper function. + Two dicts are returned - the first is tags to be set, the second is any tags to remove. Since the AWS APIs differ + these may not be able to be used out of the box. + + :param current_tags_dict: + :param new_tags_dict: + :param purge_tags: + :return: tag_key_value_pairs_to_set: a dict of key value pairs that need to be set in AWS. If all tags are identical this dict will be empty + :return: tag_keys_to_unset: a list of key names (type str) that need to be unset in AWS. If no tags need to be unset this list will be empty + """ + + tag_key_value_pairs_to_set = {} + tag_keys_to_unset = [] + + for key in current_tags_dict.keys(): + if key not in new_tags_dict and purge_tags: + tag_keys_to_unset.append(key) + + for key in set(new_tags_dict.keys()) - set(tag_keys_to_unset): + if to_text(new_tags_dict[key]) != current_tags_dict.get(key): + tag_key_value_pairs_to_set[key] = new_tags_dict[key] + + return tag_key_value_pairs_to_set, tag_keys_to_unset diff --git a/test/support/integration/plugins/module_utils/ecs/__init__.py b/test/support/integration/plugins/module_utils/ecs/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/ecs/__init__.py diff --git a/test/support/integration/plugins/module_utils/ecs/api.py b/test/support/integration/plugins/module_utils/ecs/api.py new file mode 100644 index 00000000..d89b0333 --- /dev/null +++ b/test/support/integration/plugins/module_utils/ecs/api.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is licensed under the +# Modified BSD License. Modules you write using this snippet, which is embedded +# dynamically by Ansible, still belong to the author of the module, and may assign +# their own license to the complete work. +# +# Copyright (c), Entrust Datacard Corporation, 2019 +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import json +import os +import re +import time +import traceback + +from ansible.module_utils._text import to_text, to_native +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.six.moves.urllib.parse import urlencode +from ansible.module_utils.six.moves.urllib.error import HTTPError +from ansible.module_utils.urls import Request + +YAML_IMP_ERR = None +try: + import yaml +except ImportError: + YAML_FOUND = False + YAML_IMP_ERR = traceback.format_exc() +else: + YAML_FOUND = True + +valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$") + + +def ecs_client_argument_spec(): + return dict( + entrust_api_user=dict(type='str', required=True), + entrust_api_key=dict(type='str', required=True, no_log=True), + entrust_api_client_cert_path=dict(type='path', required=True), + entrust_api_client_cert_key_path=dict(type='path', required=True, no_log=True), + entrust_api_specification_path=dict(type='path', default='https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml'), + ) + + +class SessionConfigurationException(Exception): + """ Raised if we cannot configure a session with the API """ + + pass + + +class RestOperationException(Exception): + """ Encapsulate a REST API error """ + + def __init__(self, error): + self.status = to_native(error.get("status", None)) + self.errors = [to_native(err.get("message")) for err in error.get("errors", {})] + self.message = to_native(" ".join(self.errors)) + + +def generate_docstring(operation_spec): + """Generate a docstring for an operation defined in operation_spec (swagger)""" + # Description of the operation + docs = operation_spec.get("description", "No Description") + docs += "\n\n" + + # Parameters of the operation + parameters = operation_spec.get("parameters", []) + if len(parameters) != 0: + docs += "\tArguments:\n\n" + for parameter in parameters: + docs += "{0} ({1}:{2}): {3}\n".format( + parameter.get("name"), + parameter.get("type", "No Type"), + "Required" if parameter.get("required", False) else "Not Required", + parameter.get("description"), + ) + + return docs + + +def bind(instance, method, operation_spec): + def binding_scope_fn(*args, **kwargs): + return method(instance, *args, **kwargs) + + # Make sure we don't confuse users; add the proper name and documentation to the function. + # Users can use !help(<function>) to get help on the function from interactive python or pdb + operation_name = operation_spec.get("operationId").split("Using")[0] + binding_scope_fn.__name__ = str(operation_name) + binding_scope_fn.__doc__ = generate_docstring(operation_spec) + + return binding_scope_fn + + +class RestOperation(object): + def __init__(self, session, uri, method, parameters=None): + self.session = session + self.method = method + if parameters is None: + self.parameters = {} + else: + self.parameters = parameters + self.url = "{scheme}://{host}{base_path}{uri}".format(scheme="https", host=session._spec.get("host"), base_path=session._spec.get("basePath"), uri=uri) + + def restmethod(self, *args, **kwargs): + """Do the hard work of making the request here""" + + # gather named path parameters and do substitution on the URL + if self.parameters: + path_parameters = {} + body_parameters = {} + query_parameters = {} + for x in self.parameters: + expected_location = x.get("in") + key_name = x.get("name", None) + key_value = kwargs.get(key_name, None) + if expected_location == "path" and key_name and key_value: + path_parameters.update({key_name: key_value}) + elif expected_location == "body" and key_name and key_value: + body_parameters.update({key_name: key_value}) + elif expected_location == "query" and key_name and key_value: + query_parameters.update({key_name: key_value}) + + if len(body_parameters.keys()) >= 1: + body_parameters = body_parameters.get(list(body_parameters.keys())[0]) + else: + body_parameters = None + else: + path_parameters = {} + query_parameters = {} + body_parameters = None + + # This will fail if we have not set path parameters with a KeyError + url = self.url.format(**path_parameters) + if query_parameters: + # modify the URL to add path parameters + url = url + "?" + urlencode(query_parameters) + + try: + if body_parameters: + body_parameters_json = json.dumps(body_parameters) + response = self.session.request.open(method=self.method, url=url, data=body_parameters_json) + else: + response = self.session.request.open(method=self.method, url=url) + request_error = False + except HTTPError as e: + # An HTTPError has the same methods available as a valid response from request.open + response = e + request_error = True + + # Return the result if JSON and success ({} for empty responses) + # Raise an exception if there was a failure. + try: + result_code = response.getcode() + result = json.loads(response.read()) + except ValueError: + result = {} + + if result or result == {}: + if result_code and result_code < 400: + return result + else: + raise RestOperationException(result) + + # Raise a generic RestOperationException if this fails + raise RestOperationException({"status": result_code, "errors": [{"message": "REST Operation Failed"}]}) + + +class Resource(object): + """ Implement basic CRUD operations against a path. """ + + def __init__(self, session): + self.session = session + self.parameters = {} + + for url in session._spec.get("paths").keys(): + methods = session._spec.get("paths").get(url) + for method in methods.keys(): + operation_spec = methods.get(method) + operation_name = operation_spec.get("operationId", None) + parameters = operation_spec.get("parameters") + + if not operation_name: + if method.lower() == "post": + operation_name = "Create" + elif method.lower() == "get": + operation_name = "Get" + elif method.lower() == "put": + operation_name = "Update" + elif method.lower() == "delete": + operation_name = "Delete" + elif method.lower() == "patch": + operation_name = "Patch" + else: + raise SessionConfigurationException(to_native("Invalid REST method type {0}".format(method))) + + # Get the non-parameter parts of the URL and append to the operation name + # e.g /application/version -> GetApplicationVersion + # e.g. /application/{id} -> GetApplication + # This may lead to duplicates, which we must prevent. + operation_name += re.sub(r"{(.*)}", "", url).replace("/", " ").title().replace(" ", "") + operation_spec["operationId"] = operation_name + + op = RestOperation(session, url, method, parameters) + setattr(self, operation_name, bind(self, op.restmethod, operation_spec)) + + +# Session to encapsulate the connection parameters of the module_utils Request object, the api spec, etc +class ECSSession(object): + def __init__(self, name, **kwargs): + """ + Initialize our session + """ + + self._set_config(name, **kwargs) + + def client(self): + resource = Resource(self) + return resource + + def _set_config(self, name, **kwargs): + headers = { + "Content-Type": "application/json", + "Connection": "keep-alive", + } + self.request = Request(headers=headers, timeout=60) + + configurators = [self._read_config_vars] + for configurator in configurators: + self._config = configurator(name, **kwargs) + if self._config: + break + if self._config is None: + raise SessionConfigurationException(to_native("No Configuration Found.")) + + # set up auth if passed + entrust_api_user = self.get_config("entrust_api_user") + entrust_api_key = self.get_config("entrust_api_key") + if entrust_api_user and entrust_api_key: + self.request.url_username = entrust_api_user + self.request.url_password = entrust_api_key + else: + raise SessionConfigurationException(to_native("User and key must be provided.")) + + # set up client certificate if passed (support all-in one or cert + key) + entrust_api_cert = self.get_config("entrust_api_cert") + entrust_api_cert_key = self.get_config("entrust_api_cert_key") + if entrust_api_cert: + self.request.client_cert = entrust_api_cert + if entrust_api_cert_key: + self.request.client_key = entrust_api_cert_key + else: + raise SessionConfigurationException(to_native("Client certificate for authentication to the API must be provided.")) + + # set up the spec + entrust_api_specification_path = self.get_config("entrust_api_specification_path") + + if not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path): + raise SessionConfigurationException(to_native("OpenAPI specification was not found at location {0}.".format(entrust_api_specification_path))) + if not valid_file_format.match(entrust_api_specification_path): + raise SessionConfigurationException(to_native("OpenAPI specification filename must end in .json, .yml or .yaml")) + + self.verify = True + + if entrust_api_specification_path.startswith("http"): + try: + http_response = Request().open(method="GET", url=entrust_api_specification_path) + http_response_contents = http_response.read() + if entrust_api_specification_path.endswith(".json"): + self._spec = json.load(http_response_contents) + elif entrust_api_specification_path.endswith(".yml") or entrust_api_specification_path.endswith(".yaml"): + self._spec = yaml.safe_load(http_response_contents) + except HTTPError as e: + raise SessionConfigurationException(to_native("Error downloading specification from address '{0}', received error code '{1}'".format( + entrust_api_specification_path, e.getcode()))) + else: + with open(entrust_api_specification_path) as f: + if ".json" in entrust_api_specification_path: + self._spec = json.load(f) + elif ".yml" in entrust_api_specification_path or ".yaml" in entrust_api_specification_path: + self._spec = yaml.safe_load(f) + + def get_config(self, item): + return self._config.get(item, None) + + def _read_config_vars(self, name, **kwargs): + """ Read configuration from variables passed to the module. """ + config = {} + + entrust_api_specification_path = kwargs.get("entrust_api_specification_path") + if not entrust_api_specification_path or (not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path)): + raise SessionConfigurationException( + to_native( + "Parameter provided for entrust_api_specification_path of value '{0}' was not a valid file path or HTTPS address.".format( + entrust_api_specification_path + ) + ) + ) + + for required_file in ["entrust_api_cert", "entrust_api_cert_key"]: + file_path = kwargs.get(required_file) + if not file_path or not os.path.isfile(file_path): + raise SessionConfigurationException( + to_native("Parameter provided for {0} of value '{1}' was not a valid file path.".format(required_file, file_path)) + ) + + for required_var in ["entrust_api_user", "entrust_api_key"]: + if not kwargs.get(required_var): + raise SessionConfigurationException(to_native("Parameter provided for {0} was missing.".format(required_var))) + + config["entrust_api_cert"] = kwargs.get("entrust_api_cert") + config["entrust_api_cert_key"] = kwargs.get("entrust_api_cert_key") + config["entrust_api_specification_path"] = kwargs.get("entrust_api_specification_path") + config["entrust_api_user"] = kwargs.get("entrust_api_user") + config["entrust_api_key"] = kwargs.get("entrust_api_key") + + return config + + +def ECSClient(entrust_api_user=None, entrust_api_key=None, entrust_api_cert=None, entrust_api_cert_key=None, entrust_api_specification_path=None): + """Create an ECS client""" + + if not YAML_FOUND: + raise SessionConfigurationException(missing_required_lib("PyYAML"), exception=YAML_IMP_ERR) + + if entrust_api_specification_path is None: + entrust_api_specification_path = "https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml" + + # Not functionally necessary with current uses of this module_util, but better to be explicit for future use cases + entrust_api_user = to_text(entrust_api_user) + entrust_api_key = to_text(entrust_api_key) + entrust_api_cert_key = to_text(entrust_api_cert_key) + entrust_api_specification_path = to_text(entrust_api_specification_path) + + return ECSSession( + "ecs", + entrust_api_user=entrust_api_user, + entrust_api_key=entrust_api_key, + entrust_api_cert=entrust_api_cert, + entrust_api_cert_key=entrust_api_cert_key, + entrust_api_specification_path=entrust_api_specification_path, + ).client() diff --git a/test/support/integration/plugins/module_utils/mysql.py b/test/support/integration/plugins/module_utils/mysql.py new file mode 100644 index 00000000..46198f36 --- /dev/null +++ b/test/support/integration/plugins/module_utils/mysql.py @@ -0,0 +1,106 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Jonathan Mainguy <jon@soh.re>, 2015 +# Most of this was originally added by Sven Schliesing @muffl0n in the mysql_user.py module +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os + +try: + import pymysql as mysql_driver + _mysql_cursor_param = 'cursor' +except ImportError: + try: + import MySQLdb as mysql_driver + import MySQLdb.cursors + _mysql_cursor_param = 'cursorclass' + except ImportError: + mysql_driver = None + +mysql_driver_fail_msg = 'The PyMySQL (Python 2.7 and Python 3.X) or MySQL-python (Python 2.X) module is required.' + + +def mysql_connect(module, login_user=None, login_password=None, config_file='', ssl_cert=None, ssl_key=None, ssl_ca=None, db=None, cursor_class=None, + connect_timeout=30, autocommit=False): + config = {} + + if ssl_ca is not None or ssl_key is not None or ssl_cert is not None: + config['ssl'] = {} + + if module.params['login_unix_socket']: + config['unix_socket'] = module.params['login_unix_socket'] + else: + config['host'] = module.params['login_host'] + config['port'] = module.params['login_port'] + + if os.path.exists(config_file): + config['read_default_file'] = config_file + + # If login_user or login_password are given, they should override the + # config file + if login_user is not None: + config['user'] = login_user + if login_password is not None: + config['passwd'] = login_password + if ssl_cert is not None: + config['ssl']['cert'] = ssl_cert + if ssl_key is not None: + config['ssl']['key'] = ssl_key + if ssl_ca is not None: + config['ssl']['ca'] = ssl_ca + if db is not None: + config['db'] = db + if connect_timeout is not None: + config['connect_timeout'] = connect_timeout + + if _mysql_cursor_param == 'cursor': + # In case of PyMySQL driver: + db_connection = mysql_driver.connect(autocommit=autocommit, **config) + else: + # In case of MySQLdb driver + db_connection = mysql_driver.connect(**config) + if autocommit: + db_connection.autocommit(True) + + if cursor_class == 'DictCursor': + return db_connection.cursor(**{_mysql_cursor_param: mysql_driver.cursors.DictCursor}), db_connection + else: + return db_connection.cursor(), db_connection + + +def mysql_common_argument_spec(): + return dict( + login_user=dict(type='str', default=None), + login_password=dict(type='str', no_log=True), + login_host=dict(type='str', default='localhost'), + login_port=dict(type='int', default=3306), + login_unix_socket=dict(type='str'), + config_file=dict(type='path', default='~/.my.cnf'), + connect_timeout=dict(type='int', default=30), + client_cert=dict(type='path', aliases=['ssl_cert']), + client_key=dict(type='path', aliases=['ssl_key']), + ca_cert=dict(type='path', aliases=['ssl_ca']), + ) diff --git a/test/support/integration/plugins/module_utils/net_tools/__init__.py b/test/support/integration/plugins/module_utils/net_tools/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/net_tools/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/__init__.py b/test/support/integration/plugins/module_utils/network/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/common/__init__.py b/test/support/integration/plugins/module_utils/network/common/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/common/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/common/utils.py b/test/support/integration/plugins/module_utils/network/common/utils.py new file mode 100644 index 00000000..80317387 --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/common/utils.py @@ -0,0 +1,643 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Networking tools for network modules only + +import re +import ast +import operator +import socket +import json + +from itertools import chain + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils import basic +from ansible.module_utils.parsing.convert_bool import boolean + +# Backwards compatibility for 3rd party modules +# TODO(pabelanger): With move to ansible.netcommon, we should clean this code +# up and have modules import directly themself. +from ansible.module_utils.common.network import ( # noqa: F401 + to_bits, is_netmask, is_masklen, to_netmask, to_masklen, to_subnet, to_ipv6_network, VALID_MASKS +) + +try: + from jinja2 import Environment, StrictUndefined + from jinja2.exceptions import UndefinedError + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False + + +OPERATORS = frozenset(['ge', 'gt', 'eq', 'neq', 'lt', 'le']) +ALIASES = frozenset([('min', 'ge'), ('max', 'le'), ('exactly', 'eq'), ('neq', 'ne')]) + + +def to_list(val): + if isinstance(val, (list, tuple, set)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +def to_lines(stdout): + for item in stdout: + if isinstance(item, string_types): + item = to_text(item).split('\n') + yield item + + +def transform_commands(module): + transform = ComplexList(dict( + command=dict(key=True), + output=dict(), + prompt=dict(type='list'), + answer=dict(type='list'), + newline=dict(type='bool', default=True), + sendonly=dict(type='bool', default=False), + check_all=dict(type='bool', default=False), + ), module) + + return transform(module.params['commands']) + + +def sort_list(val): + if isinstance(val, list): + return sorted(val) + return val + + +class Entity(object): + """Transforms a dict to with an argument spec + + This class will take a dict and apply an Ansible argument spec to the + values. The resulting dict will contain all of the keys in the param + with appropriate values set. + + Example:: + + argument_spec = dict( + command=dict(key=True), + display=dict(default='text', choices=['text', 'json']), + validate=dict(type='bool') + ) + transform = Entity(module, argument_spec) + value = dict(command='foo') + result = transform(value) + print result + {'command': 'foo', 'display': 'text', 'validate': None} + + Supported argument spec: + * key - specifies how to map a single value to a dict + * read_from - read and apply the argument_spec from the module + * required - a value is required + * type - type of value (uses AnsibleModule type checker) + * fallback - implements fallback function + * choices - set of valid options + * default - default value + """ + + def __init__(self, module, attrs=None, args=None, keys=None, from_argspec=False): + args = [] if args is None else args + + self._attributes = attrs or {} + self._module = module + + for arg in args: + self._attributes[arg] = dict() + if from_argspec: + self._attributes[arg]['read_from'] = arg + if keys and arg in keys: + self._attributes[arg]['key'] = True + + self.attr_names = frozenset(self._attributes.keys()) + + _has_key = False + + for name, attr in iteritems(self._attributes): + if attr.get('read_from'): + if attr['read_from'] not in self._module.argument_spec: + module.fail_json(msg='argument %s does not exist' % attr['read_from']) + spec = self._module.argument_spec.get(attr['read_from']) + for key, value in iteritems(spec): + if key not in attr: + attr[key] = value + + if attr.get('key'): + if _has_key: + module.fail_json(msg='only one key value can be specified') + _has_key = True + attr['required'] = True + + def serialize(self): + return self._attributes + + def to_dict(self, value): + obj = {} + for name, attr in iteritems(self._attributes): + if attr.get('key'): + obj[name] = value + else: + obj[name] = attr.get('default') + return obj + + def __call__(self, value, strict=True): + if not isinstance(value, dict): + value = self.to_dict(value) + + if strict: + unknown = set(value).difference(self.attr_names) + if unknown: + self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown)) + + for name, attr in iteritems(self._attributes): + if value.get(name) is None: + value[name] = attr.get('default') + + if attr.get('fallback') and not value.get(name): + fallback = attr.get('fallback', (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + value[name] = fallback_strategy(*fallback_args, **fallback_kwargs) + except basic.AnsibleFallbackNotFound: + continue + + if attr.get('required') and value.get(name) is None: + self._module.fail_json(msg='missing required attribute %s' % name) + + if 'choices' in attr: + if value[name] not in attr['choices']: + self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name])) + + if value[name] is not None: + value_type = attr.get('type', 'str') + type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type] + type_checker(value[name]) + elif value.get(name): + value[name] = self._module.params[name] + + return value + + +class EntityCollection(Entity): + """Extends ```Entity``` to handle a list of dicts """ + + def __call__(self, iterable, strict=True): + if iterable is None: + iterable = [super(EntityCollection, self).__call__(self._module.params, strict)] + + if not isinstance(iterable, (list, tuple)): + self._module.fail_json(msg='value must be an iterable') + + return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable] + + +# these two are for backwards compatibility and can be removed once all of the +# modules that use them are updated +class ComplexDict(Entity): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexDict, self).__init__(module, attrs, *args, **kwargs) + + +class ComplexList(EntityCollection): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexList, self).__init__(module, attrs, *args, **kwargs) + + +def dict_diff(base, comparable): + """ Generate a dict object of differences + + This function will compare two dict objects and return the difference + between them as a dict object. For scalar values, the key will reflect + the updated value. If the key does not exist in `comparable`, then then no + key will be returned. For lists, the value in comparable will wholly replace + the value in base for the key. For dicts, the returned value will only + return keys that are different. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + + :returns: new dict object with differences + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(comparable, dict): + if comparable is None: + comparable = dict() + else: + raise AssertionError("`comparable` must be of type <dict>") + + updates = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + item = comparable.get(key) + if item is not None: + sub_diff = dict_diff(value, comparable[key]) + if sub_diff: + updates[key] = sub_diff + else: + comparable_value = comparable.get(key) + if comparable_value is not None: + if sort_list(base[key]) != sort_list(comparable_value): + updates[key] = comparable_value + + for key in set(comparable.keys()).difference(base.keys()): + updates[key] = comparable.get(key) + + return updates + + +def dict_merge(base, other): + """ Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. If the value is a list object, the two lists will + be combined and duplicate entries removed. + + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(other, dict): + raise AssertionError("`other` must be of type <dict>") + + combined = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + if isinstance(other[key], Mapping): + combined[key] = dict_merge(value, other[key]) + else: + combined[key] = other[key] + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if item is not None: + try: + combined[key] = list(set(chain(value, item))) + except TypeError: + value.extend([i for i in item if i not in value]) + combined[key] = value + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) + + return combined + + +def param_list_to_dict(param_list, unique_key="name", remove_key=True): + """Rotates a list of dictionaries to be a dictionary of dictionaries. + + :param param_list: The aforementioned list of dictionaries + :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value + behind this key will be the key each dictionary can be found at in the new root dictionary + :param remove_key: If True, remove unique_key from the individual dictionaries before returning. + """ + param_dict = {} + for params in param_list: + params = params.copy() + if remove_key: + name = params.pop(unique_key) + else: + name = params.get(unique_key) + param_dict[name] = params + + return param_dict + + +def conditional(expr, val, cast=None): + match = re.match(r'^(.+)\((.+)\)$', str(expr), re.I) + if match: + op, arg = match.groups() + else: + op = 'eq' + if ' ' in str(expr): + raise AssertionError('invalid expression: cannot contain spaces') + arg = expr + + if cast is None and val is not None: + arg = type(val)(arg) + elif callable(cast): + arg = cast(arg) + val = cast(val) + + op = next((oper for alias, oper in ALIASES if op == alias), op) + + if not hasattr(operator, op) and op not in OPERATORS: + raise ValueError('unknown operator: %s' % op) + + func = getattr(operator, op) + return func(val, arg) + + +def ternary(value, true_val, false_val): + ''' value ? true_val : false_val ''' + if value: + return true_val + else: + return false_val + + +def remove_default_spec(spec): + for item in spec: + if 'default' in spec[item]: + del spec[item]['default'] + + +def validate_ip_address(address): + try: + socket.inet_aton(address) + except socket.error: + return False + return address.count('.') == 3 + + +def validate_ip_v6_address(address): + try: + socket.inet_pton(socket.AF_INET6, address) + except socket.error: + return False + return True + + +def validate_prefix(prefix): + if prefix and not 0 <= int(prefix) <= 32: + return False + return True + + +def load_provider(spec, args): + provider = args.get('provider') or {} + for key, value in iteritems(spec): + if key not in provider: + if 'fallback' in value: + provider[key] = _fallback(value['fallback']) + elif 'default' in value: + provider[key] = value['default'] + else: + provider[key] = None + if 'authorize' in provider: + # Coerce authorize to provider if a string has somehow snuck in. + provider['authorize'] = boolean(provider['authorize'] or False) + args['provider'] = provider + return provider + + +def _fallback(fallback): + strategy = fallback[0] + args = [] + kwargs = {} + + for item in fallback[1:]: + if isinstance(item, dict): + kwargs = item + else: + args = item + try: + return strategy(*args, **kwargs) + except basic.AnsibleFallbackNotFound: + pass + + +def generate_dict(spec): + """ + Generate dictionary which is in sync with argspec + + :param spec: A dictionary that is the argspec of the module + :rtype: A dictionary + :returns: A dictionary in sync with argspec with default value + """ + obj = {} + if not spec: + return obj + + for key, val in iteritems(spec): + if 'default' in val: + dct = {key: val['default']} + elif 'type' in val and val['type'] == 'dict': + dct = {key: generate_dict(val['options'])} + else: + dct = {key: None} + obj.update(dct) + return obj + + +def parse_conf_arg(cfg, arg): + """ + Parse config based on argument + + :param cfg: A text string which is a line of configuration. + :param arg: A text string which is to be matched. + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r'%s (.+)(\n|$)' % arg, cfg, re.M) + if match: + result = match.group(1).strip() + else: + result = None + return result + + +def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str='no'): + """ + Parse config based on command + + :param cfg: A text string which is a line of configuration. + :param cmd: A text string which is the command to be matched + :param res1: A text string to be returned if the command is present + :param res2: A text string to be returned if the negate command + is present + :param delete_str: A text string to identify the start of the + negate command + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r'\n\s+%s(\n|$)' % cmd, cfg) + if match: + return res1 + if res2 is not None: + match = re.search(r'\n\s+%s %s(\n|$)' % (delete_str, cmd), cfg) + if match: + return res2 + return None + + +def get_xml_conf_arg(cfg, path, data='text'): + """ + :param cfg: The top level configuration lxml Element tree object + :param path: The relative xpath w.r.t to top level element (cfg) + to be searched in the xml hierarchy + :param data: The type of data to be returned for the matched xml node. + Valid values are text, tag, attrib, with default as text. + :return: Returns the required type for the matched xml node or else None + """ + match = cfg.xpath(path) + if len(match): + if data == 'tag': + result = getattr(match[0], 'tag') + elif data == 'attrib': + result = getattr(match[0], 'attrib') + else: + result = getattr(match[0], 'text') + else: + result = None + return result + + +def remove_empties(cfg_dict): + """ + Generate final config dictionary + + :param cfg_dict: A dictionary parsed in the facts system + :rtype: A dictionary + :returns: A dictionary by eliminating keys that have null values + """ + final_cfg = {} + if not cfg_dict: + return final_cfg + + for key, val in iteritems(cfg_dict): + dct = None + if isinstance(val, dict): + child_val = remove_empties(val) + if child_val: + dct = {key: child_val} + elif (isinstance(val, list) and val + and all([isinstance(x, dict) for x in val])): + child_val = [remove_empties(x) for x in val] + if child_val: + dct = {key: child_val} + elif val not in [None, [], {}, (), '']: + dct = {key: val} + if dct: + final_cfg.update(dct) + return final_cfg + + +def validate_config(spec, data): + """ + Validate if the input data against the AnsibleModule spec format + :param spec: Ansible argument spec + :param data: Data to be validated + :return: + """ + params = basic._ANSIBLE_ARGS + basic._ANSIBLE_ARGS = to_bytes(json.dumps({'ANSIBLE_MODULE_ARGS': data})) + validated_data = basic.AnsibleModule(spec).params + basic._ANSIBLE_ARGS = params + return validated_data + + +def search_obj_in_list(name, lst, key='name'): + if not lst: + return None + else: + for item in lst: + if item.get(key) == name: + return item + + +class Template: + + def __init__(self): + if not HAS_JINJA2: + raise ImportError("jinja2 is required but does not appear to be installed. " + "It can be installed using `pip install jinja2`") + + self.env = Environment(undefined=StrictUndefined) + self.env.filters.update({'ternary': ternary}) + + def __call__(self, value, variables=None, fail_on_undefined=True): + variables = variables or {} + + if not self.contains_vars(value): + return value + + try: + value = self.env.from_string(value).render(variables) + except UndefinedError: + if not fail_on_undefined: + return None + raise + + if value: + try: + return ast.literal_eval(value) + except Exception: + return str(value) + else: + return None + + def contains_vars(self, data): + if isinstance(data, string_types): + for marker in (self.env.block_start_string, self.env.variable_start_string, self.env.comment_start_string): + if marker in data: + return True + return False diff --git a/test/support/integration/plugins/module_utils/postgres.py b/test/support/integration/plugins/module_utils/postgres.py new file mode 100644 index 00000000..63811c30 --- /dev/null +++ b/test/support/integration/plugins/module_utils/postgres.py @@ -0,0 +1,330 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Ted Timmons <ted@timmons.me>, 2017. +# Most of this was originally added by other creators in the postgresql_user module. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +psycopg2 = None # This line needs for unit tests +try: + import psycopg2 + HAS_PSYCOPG2 = True +except ImportError: + HAS_PSYCOPG2 = False + +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils._text import to_native +from ansible.module_utils.six import iteritems +from distutils.version import LooseVersion + + +def postgres_common_argument_spec(): + """ + Return a dictionary with connection options. + + The options are commonly used by most of PostgreSQL modules. + """ + return dict( + login_user=dict(default='postgres'), + login_password=dict(default='', no_log=True), + login_host=dict(default=''), + login_unix_socket=dict(default=''), + port=dict(type='int', default=5432, aliases=['login_port']), + ssl_mode=dict(default='prefer', choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']), + ca_cert=dict(aliases=['ssl_rootcert']), + ) + + +def ensure_required_libs(module): + """Check required libraries.""" + if not HAS_PSYCOPG2: + module.fail_json(msg=missing_required_lib('psycopg2')) + + if module.params.get('ca_cert') and LooseVersion(psycopg2.__version__) < LooseVersion('2.4.3'): + module.fail_json(msg='psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter') + + +def connect_to_db(module, conn_params, autocommit=False, fail_on_conn=True): + """Connect to a PostgreSQL database. + + Return psycopg2 connection object. + + Args: + module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class + conn_params (dict) -- dictionary with connection parameters + + Kwargs: + autocommit (bool) -- commit automatically (default False) + fail_on_conn (bool) -- fail if connection failed or just warn and return None (default True) + """ + ensure_required_libs(module) + + db_connection = None + try: + db_connection = psycopg2.connect(**conn_params) + if autocommit: + if LooseVersion(psycopg2.__version__) >= LooseVersion('2.4.2'): + db_connection.set_session(autocommit=True) + else: + db_connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + + # Switch role, if specified: + if module.params.get('session_role'): + cursor = db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) + + try: + cursor.execute('SET ROLE "%s"' % module.params['session_role']) + except Exception as e: + module.fail_json(msg="Could not switch role: %s" % to_native(e)) + finally: + cursor.close() + + except TypeError as e: + if 'sslrootcert' in e.args[0]: + module.fail_json(msg='Postgresql server must be at least ' + 'version 8.4 to support sslrootcert') + + if fail_on_conn: + module.fail_json(msg="unable to connect to database: %s" % to_native(e)) + else: + module.warn("PostgreSQL server is unavailable: %s" % to_native(e)) + db_connection = None + + except Exception as e: + if fail_on_conn: + module.fail_json(msg="unable to connect to database: %s" % to_native(e)) + else: + module.warn("PostgreSQL server is unavailable: %s" % to_native(e)) + db_connection = None + + return db_connection + + +def exec_sql(obj, query, query_params=None, ddl=False, add_to_executed=True, dont_exec=False): + """Execute SQL. + + Auxiliary function for PostgreSQL user classes. + + Returns a query result if possible or True/False if ddl=True arg was passed. + It necessary for statements that don't return any result (like DDL queries). + + Args: + obj (obj) -- must be an object of a user class. + The object must have module (AnsibleModule class object) and + cursor (psycopg cursor object) attributes + query (str) -- SQL query to execute + + Kwargs: + query_params (dict or tuple) -- Query parameters to prevent SQL injections, + could be a dict or tuple + ddl (bool) -- must return True or False instead of rows (typical for DDL queries) + (default False) + add_to_executed (bool) -- append the query to obj.executed_queries attribute + dont_exec (bool) -- used with add_to_executed=True to generate a query, add it + to obj.executed_queries list and return True (default False) + """ + + if dont_exec: + # This is usually needed to return queries in check_mode + # without execution + query = obj.cursor.mogrify(query, query_params) + if add_to_executed: + obj.executed_queries.append(query) + + return True + + try: + if query_params is not None: + obj.cursor.execute(query, query_params) + else: + obj.cursor.execute(query) + + if add_to_executed: + if query_params is not None: + obj.executed_queries.append(obj.cursor.mogrify(query, query_params)) + else: + obj.executed_queries.append(query) + + if not ddl: + res = obj.cursor.fetchall() + return res + return True + except Exception as e: + obj.module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e))) + return False + + +def get_conn_params(module, params_dict, warn_db_default=True): + """Get connection parameters from the passed dictionary. + + Return a dictionary with parameters to connect to PostgreSQL server. + + Args: + module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class + params_dict (dict) -- dictionary with variables + + Kwargs: + warn_db_default (bool) -- warn that the default DB is used (default True) + """ + # To use defaults values, keyword arguments must be absent, so + # check which values are empty and don't include in the return dictionary + params_map = { + "login_host": "host", + "login_user": "user", + "login_password": "password", + "port": "port", + "ssl_mode": "sslmode", + "ca_cert": "sslrootcert" + } + + # Might be different in the modules: + if params_dict.get('db'): + params_map['db'] = 'database' + elif params_dict.get('database'): + params_map['database'] = 'database' + elif params_dict.get('login_db'): + params_map['login_db'] = 'database' + else: + if warn_db_default: + module.warn('Database name has not been passed, ' + 'used default database to connect to.') + + kw = dict((params_map[k], v) for (k, v) in iteritems(params_dict) + if k in params_map and v != '' and v is not None) + + # If a login_unix_socket is specified, incorporate it here. + is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost" + if is_localhost and params_dict["login_unix_socket"] != "": + kw["host"] = params_dict["login_unix_socket"] + + return kw + + +class PgMembership(object): + def __init__(self, module, cursor, groups, target_roles, fail_on_role=True): + self.module = module + self.cursor = cursor + self.target_roles = [r.strip() for r in target_roles] + self.groups = [r.strip() for r in groups] + self.executed_queries = [] + self.granted = {} + self.revoked = {} + self.fail_on_role = fail_on_role + self.non_existent_roles = [] + self.changed = False + self.__check_roles_exist() + + def grant(self): + for group in self.groups: + self.granted[group] = [] + + for role in self.target_roles: + # If role is in a group now, pass: + if self.__check_membership(group, role): + continue + + query = 'GRANT "%s" TO "%s"' % (group, role) + self.changed = exec_sql(self, query, ddl=True) + + if self.changed: + self.granted[group].append(role) + + return self.changed + + def revoke(self): + for group in self.groups: + self.revoked[group] = [] + + for role in self.target_roles: + # If role is not in a group now, pass: + if not self.__check_membership(group, role): + continue + + query = 'REVOKE "%s" FROM "%s"' % (group, role) + self.changed = exec_sql(self, query, ddl=True) + + if self.changed: + self.revoked[group].append(role) + + return self.changed + + def __check_membership(self, src_role, dst_role): + query = ("SELECT ARRAY(SELECT b.rolname FROM " + "pg_catalog.pg_auth_members m " + "JOIN pg_catalog.pg_roles b ON (m.roleid = b.oid) " + "WHERE m.member = r.oid) " + "FROM pg_catalog.pg_roles r " + "WHERE r.rolname = %(dst_role)s") + + res = exec_sql(self, query, query_params={'dst_role': dst_role}, add_to_executed=False) + membership = [] + if res: + membership = res[0][0] + + if not membership: + return False + + if src_role in membership: + return True + + return False + + def __check_roles_exist(self): + existent_groups = self.__roles_exist(self.groups) + existent_roles = self.__roles_exist(self.target_roles) + + for group in self.groups: + if group not in existent_groups: + if self.fail_on_role: + self.module.fail_json(msg="Role %s does not exist" % group) + else: + self.module.warn("Role %s does not exist, pass" % group) + self.non_existent_roles.append(group) + + for role in self.target_roles: + if role not in existent_roles: + if self.fail_on_role: + self.module.fail_json(msg="Role %s does not exist" % role) + else: + self.module.warn("Role %s does not exist, pass" % role) + + if role not in self.groups: + self.non_existent_roles.append(role) + + else: + if self.fail_on_role: + self.module.exit_json(msg="Role role '%s' is a member of role '%s'" % (role, role)) + else: + self.module.warn("Role role '%s' is a member of role '%s', pass" % (role, role)) + + # Update role lists, excluding non existent roles: + self.groups = [g for g in self.groups if g not in self.non_existent_roles] + + self.target_roles = [r for r in self.target_roles if r not in self.non_existent_roles] + + def __roles_exist(self, roles): + tmp = ["'" + x + "'" for x in roles] + query = "SELECT rolname FROM pg_roles WHERE rolname IN (%s)" % ','.join(tmp) + return [x[0] for x in exec_sql(self, query, add_to_executed=False)] diff --git a/test/support/integration/plugins/module_utils/rabbitmq.py b/test/support/integration/plugins/module_utils/rabbitmq.py new file mode 100644 index 00000000..cf764006 --- /dev/null +++ b/test/support/integration/plugins/module_utils/rabbitmq.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# +# Copyright: (c) 2016, Jorge Rodriguez <jorge.rodriguez@tiriel.eu> +# Copyright: (c) 2018, John Imison <john+github@imison.net> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils._text import to_native +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.six.moves.urllib import parse as urllib_parse +from mimetypes import MimeTypes + +import os +import json +import traceback + +PIKA_IMP_ERR = None +try: + import pika + import pika.exceptions + from pika import spec + HAS_PIKA = True +except ImportError: + PIKA_IMP_ERR = traceback.format_exc() + HAS_PIKA = False + + +def rabbitmq_argument_spec(): + return dict( + login_user=dict(type='str', default='guest'), + login_password=dict(type='str', default='guest', no_log=True), + login_host=dict(type='str', default='localhost'), + login_port=dict(type='str', default='15672'), + login_protocol=dict(type='str', default='http', choices=['http', 'https']), + ca_cert=dict(type='path', aliases=['cacert']), + client_cert=dict(type='path', aliases=['cert']), + client_key=dict(type='path', aliases=['key']), + vhost=dict(type='str', default='/'), + ) + + +# notification/rabbitmq_basic_publish.py +class RabbitClient(): + def __init__(self, module): + self.module = module + self.params = module.params + self.check_required_library() + self.check_host_params() + self.url = self.params['url'] + self.proto = self.params['proto'] + self.username = self.params['username'] + self.password = self.params['password'] + self.host = self.params['host'] + self.port = self.params['port'] + self.vhost = self.params['vhost'] + self.queue = self.params['queue'] + self.headers = self.params['headers'] + self.cafile = self.params['cafile'] + self.certfile = self.params['certfile'] + self.keyfile = self.params['keyfile'] + + if self.host is not None: + self.build_url() + + if self.cafile is not None: + self.append_ssl_certs() + + self.connect_to_rabbitmq() + + def check_required_library(self): + if not HAS_PIKA: + self.module.fail_json(msg=missing_required_lib("pika"), exception=PIKA_IMP_ERR) + + def check_host_params(self): + # Fail if url is specified and other conflicting parameters have been specified + if self.params['url'] is not None and any(self.params[k] is not None for k in ['proto', 'host', 'port', 'password', 'username', 'vhost']): + self.module.fail_json(msg="url and proto, host, port, vhost, username or password cannot be specified at the same time.") + + # Fail if url not specified and there is a missing parameter to build the url + if self.params['url'] is None and any(self.params[k] is None for k in ['proto', 'host', 'port', 'password', 'username', 'vhost']): + self.module.fail_json(msg="Connection parameters must be passed via url, or, proto, host, port, vhost, username or password.") + + def append_ssl_certs(self): + ssl_options = {} + if self.cafile: + ssl_options['cafile'] = self.cafile + if self.certfile: + ssl_options['certfile'] = self.certfile + if self.keyfile: + ssl_options['keyfile'] = self.keyfile + + self.url = self.url + '?ssl_options=' + urllib_parse.quote(json.dumps(ssl_options)) + + @staticmethod + def rabbitmq_argument_spec(): + return dict( + url=dict(type='str'), + proto=dict(type='str', choices=['amqp', 'amqps']), + host=dict(type='str'), + port=dict(type='int'), + username=dict(type='str'), + password=dict(type='str', no_log=True), + vhost=dict(type='str'), + queue=dict(type='str') + ) + + ''' Consider some file size limits here ''' + def _read_file(self, path): + try: + with open(path, "rb") as file_handle: + return file_handle.read() + except IOError as e: + self.module.fail_json(msg="Unable to open file %s: %s" % (path, to_native(e))) + + @staticmethod + def _check_file_mime_type(path): + mime = MimeTypes() + return mime.guess_type(path) + + def build_url(self): + self.url = '{0}://{1}:{2}@{3}:{4}/{5}'.format(self.proto, + self.username, + self.password, + self.host, + self.port, + self.vhost) + + def connect_to_rabbitmq(self): + """ + Function to connect to rabbitmq using username and password + """ + try: + parameters = pika.URLParameters(self.url) + except Exception as e: + self.module.fail_json(msg="URL malformed: %s" % to_native(e)) + + try: + self.connection = pika.BlockingConnection(parameters) + except Exception as e: + self.module.fail_json(msg="Connection issue: %s" % to_native(e)) + + try: + self.conn_channel = self.connection.channel() + except pika.exceptions.AMQPChannelError as e: + self.close_connection() + self.module.fail_json(msg="Channel issue: %s" % to_native(e)) + + def close_connection(self): + try: + self.connection.close() + except pika.exceptions.AMQPConnectionError: + pass + + def basic_publish(self): + self.content_type = self.params.get("content_type") + + if self.params.get("body") is not None: + args = dict( + body=self.params.get("body"), + exchange=self.params.get("exchange"), + routing_key=self.params.get("routing_key"), + properties=pika.BasicProperties(content_type=self.content_type, delivery_mode=1, headers=self.headers)) + + # If src (file) is defined and content_type is left as default, do a mime lookup on the file + if self.params.get("src") is not None and self.content_type == 'text/plain': + self.content_type = RabbitClient._check_file_mime_type(self.params.get("src"))[0] + self.headers.update( + filename=os.path.basename(self.params.get("src")) + ) + + args = dict( + body=self._read_file(self.params.get("src")), + exchange=self.params.get("exchange"), + routing_key=self.params.get("routing_key"), + properties=pika.BasicProperties(content_type=self.content_type, + delivery_mode=1, + headers=self.headers + )) + elif self.params.get("src") is not None: + args = dict( + body=self._read_file(self.params.get("src")), + exchange=self.params.get("exchange"), + routing_key=self.params.get("routing_key"), + properties=pika.BasicProperties(content_type=self.content_type, + delivery_mode=1, + headers=self.headers + )) + + try: + # If queue is not defined, RabbitMQ will return the queue name of the automatically generated queue. + if self.queue is None: + result = self.conn_channel.queue_declare(durable=self.params.get("durable"), + exclusive=self.params.get("exclusive"), + auto_delete=self.params.get("auto_delete")) + self.conn_channel.confirm_delivery() + self.queue = result.method.queue + else: + self.conn_channel.queue_declare(queue=self.queue, + durable=self.params.get("durable"), + exclusive=self.params.get("exclusive"), + auto_delete=self.params.get("auto_delete")) + self.conn_channel.confirm_delivery() + except Exception as e: + self.module.fail_json(msg="Queue declare issue: %s" % to_native(e)) + + # https://github.com/ansible/ansible/blob/devel/lib/ansible/module_utils/cloudstack.py#L150 + if args['routing_key'] is None: + args['routing_key'] = self.queue + + if args['exchange'] is None: + args['exchange'] = '' + + try: + self.conn_channel.basic_publish(**args) + return True + except pika.exceptions.UnroutableError: + return False diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py new file mode 120000 index 00000000..f9993bfb --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbconfiguration_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py new file mode 120000 index 00000000..b8293e64 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbdatabase_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py new file mode 120000 index 00000000..4311a0c1 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbfirewallrule_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py new file mode 120000 index 00000000..5f76e0e9 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbserver_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_resource_facts.py b/test/support/integration/plugins/modules/_azure_rm_resource_facts.py new file mode 120000 index 00000000..710fda10 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_resource_facts.py @@ -0,0 +1 @@ +azure_rm_resource_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py b/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py new file mode 120000 index 00000000..ead87c85 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py @@ -0,0 +1 @@ +azure_rm_webapp_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/aws_az_info.py b/test/support/integration/plugins/modules/aws_az_info.py new file mode 100644 index 00000000..c1efed6f --- /dev/null +++ b/test/support/integration/plugins/modules/aws_az_info.py @@ -0,0 +1,111 @@ +#!/usr/bin/python +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'supported_by': 'community', + 'status': ['preview'] +} + +DOCUMENTATION = ''' +module: aws_az_info +short_description: Gather information about availability zones in AWS. +description: + - Gather information about availability zones in AWS. + - This module was called C(aws_az_facts) before Ansible 2.9. The usage did not change. +version_added: '2.5' +author: 'Henrique Rodrigues (@Sodki)' +options: + filters: + description: + - A dict of filters to apply. Each dict item consists of a filter key and a filter value. See + U(https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeAvailabilityZones.html) for + possible filters. Filter names and values are case sensitive. You can also use underscores + instead of dashes (-) in the filter keys, which will take precedence in case of conflict. + required: false + default: {} + type: dict +extends_documentation_fragment: + - aws + - ec2 +requirements: [botocore, boto3] +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +# Gather information about all availability zones +- aws_az_info: + +# Gather information about a single availability zone +- aws_az_info: + filters: + zone-name: eu-west-1a +''' + +RETURN = ''' +availability_zones: + returned: on success + description: > + Availability zones that match the provided filters. Each element consists of a dict with all the information + related to that available zone. + type: list + sample: "[ + { + 'messages': [], + 'region_name': 'us-west-1', + 'state': 'available', + 'zone_name': 'us-west-1b' + }, + { + 'messages': [], + 'region_name': 'us-west-1', + 'state': 'available', + 'zone_name': 'us-west-1c' + } + ]" +''' + +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import AWSRetry, ansible_dict_to_boto3_filter_list, camel_dict_to_snake_dict + +try: + from botocore.exceptions import ClientError, BotoCoreError +except ImportError: + pass # Handled by AnsibleAWSModule + + +def main(): + argument_spec = dict( + filters=dict(default={}, type='dict') + ) + + module = AnsibleAWSModule(argument_spec=argument_spec) + if module._name == 'aws_az_facts': + module.deprecate("The 'aws_az_facts' module has been renamed to 'aws_az_info'", + version='2.14', collection_name='ansible.builtin') + + connection = module.client('ec2', retry_decorator=AWSRetry.jittered_backoff()) + + # Replace filter key underscores with dashes, for compatibility + sanitized_filters = dict((k.replace('_', '-'), v) for k, v in module.params.get('filters').items()) + + try: + availability_zones = connection.describe_availability_zones( + Filters=ansible_dict_to_boto3_filter_list(sanitized_filters) + ) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to describe availability zones.") + + # Turn the boto3 result into ansible_friendly_snaked_names + snaked_availability_zones = [camel_dict_to_snake_dict(az) for az in availability_zones['AvailabilityZones']] + + module.exit_json(availability_zones=snaked_availability_zones) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/aws_s3.py b/test/support/integration/plugins/modules/aws_s3.py new file mode 100644 index 00000000..54874f05 --- /dev/null +++ b/test/support/integration/plugins/modules/aws_s3.py @@ -0,0 +1,925 @@ +#!/usr/bin/python +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: aws_s3 +short_description: manage objects in S3. +description: + - This module allows the user to manage S3 buckets and the objects within them. Includes support for creating and + deleting both objects and buckets, retrieving objects as files or strings and generating download links. + This module has a dependency on boto3 and botocore. +notes: + - In 2.4, this module has been renamed from C(s3) into M(aws_s3). +version_added: "1.1" +options: + bucket: + description: + - Bucket name. + required: true + type: str + dest: + description: + - The destination file path when downloading an object/key with a GET operation. + version_added: "1.3" + type: path + encrypt: + description: + - When set for PUT mode, asks for server-side encryption. + default: true + version_added: "2.0" + type: bool + encryption_mode: + description: + - What encryption mode to use if I(encrypt=true). + default: AES256 + choices: + - AES256 + - aws:kms + version_added: "2.7" + type: str + expiry: + description: + - Time limit (in seconds) for the URL generated and returned by S3/Walrus when performing a I(mode=put) or I(mode=geturl) operation. + default: 600 + aliases: ['expiration'] + type: int + headers: + description: + - Custom headers for PUT operation, as a dictionary of 'key=value' and 'key=value,key=value'. + version_added: "2.0" + type: dict + marker: + description: + - Specifies the key to start with when using list mode. Object keys are returned in alphabetical order, starting with key after the marker in order. + version_added: "2.0" + type: str + max_keys: + description: + - Max number of results to return in list mode, set this if you want to retrieve fewer than the default 1000 keys. + default: 1000 + version_added: "2.0" + type: int + metadata: + description: + - Metadata for PUT operation, as a dictionary of 'key=value' and 'key=value,key=value'. + version_added: "1.6" + type: dict + mode: + description: + - Switches the module behaviour between put (upload), get (download), geturl (return download url, Ansible 1.3+), + getstr (download object as string (1.3+)), list (list keys, Ansible 2.0+), create (bucket), delete (bucket), + and delobj (delete object, Ansible 2.0+). + required: true + choices: ['get', 'put', 'delete', 'create', 'geturl', 'getstr', 'delobj', 'list'] + type: str + object: + description: + - Keyname of the object inside the bucket. Can be used to create "virtual directories", see examples. + type: str + permission: + description: + - This option lets the user set the canned permissions on the object/bucket that are created. + The permissions that can be set are C(private), C(public-read), C(public-read-write), C(authenticated-read) for a bucket or + C(private), C(public-read), C(public-read-write), C(aws-exec-read), C(authenticated-read), C(bucket-owner-read), + C(bucket-owner-full-control) for an object. Multiple permissions can be specified as a list. + default: ['private'] + version_added: "2.0" + type: list + elements: str + prefix: + description: + - Limits the response to keys that begin with the specified prefix for list mode. + default: "" + version_added: "2.0" + type: str + version: + description: + - Version ID of the object inside the bucket. Can be used to get a specific version of a file if versioning is enabled in the target bucket. + version_added: "2.0" + type: str + overwrite: + description: + - Force overwrite either locally on the filesystem or remotely with the object/key. Used with PUT and GET operations. + Boolean or one of [always, never, different], true is equal to 'always' and false is equal to 'never', new in 2.0. + When this is set to 'different', the md5 sum of the local file is compared with the 'ETag' of the object/key in S3. + The ETag may or may not be an MD5 digest of the object data. See the ETag response header here + U(https://docs.aws.amazon.com/AmazonS3/latest/API/RESTCommonResponseHeaders.html) + default: 'always' + aliases: ['force'] + version_added: "1.2" + type: str + retries: + description: + - On recoverable failure, how many times to retry before actually failing. + default: 0 + version_added: "2.0" + type: int + aliases: ['retry'] + s3_url: + description: + - S3 URL endpoint for usage with Ceph, Eucalyptus and fakes3 etc. Otherwise assumes AWS. + aliases: [ S3_URL ] + type: str + dualstack: + description: + - Enables Amazon S3 Dual-Stack Endpoints, allowing S3 communications using both IPv4 and IPv6. + - Requires at least botocore version 1.4.45. + type: bool + default: false + version_added: "2.7" + rgw: + description: + - Enable Ceph RGW S3 support. This option requires an explicit url via I(s3_url). + default: false + version_added: "2.2" + type: bool + src: + description: + - The source file path when performing a PUT operation. + version_added: "1.3" + type: str + ignore_nonexistent_bucket: + description: + - "Overrides initial bucket lookups in case bucket or iam policies are restrictive. Example: a user may have the + GetObject permission but no other permissions. In this case using the option mode: get will fail without specifying + I(ignore_nonexistent_bucket=true)." + version_added: "2.3" + type: bool + encryption_kms_key_id: + description: + - KMS key id to use when encrypting objects using I(encrypting=aws:kms). Ignored if I(encryption) is not C(aws:kms) + version_added: "2.7" + type: str +requirements: [ "boto3", "botocore" ] +author: + - "Lester Wade (@lwade)" + - "Sloane Hertel (@s-hertel)" +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +- name: Simple PUT operation + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + +- name: Simple PUT operation in Ceph RGW S3 + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + rgw: true + s3_url: "http://localhost:8000" + +- name: Simple GET operation + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + dest: /usr/local/myfile.txt + mode: get + +- name: Get a specific version of an object. + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + version: 48c9ee5131af7a716edc22df9772aa6f + dest: /usr/local/myfile.txt + mode: get + +- name: PUT/upload with metadata + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + metadata: 'Content-Encoding=gzip,Cache-Control=no-cache' + +- name: PUT/upload with custom headers + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + headers: 'x-amz-grant-full-control=emailAddress=owner@example.com' + +- name: List keys simple + aws_s3: + bucket: mybucket + mode: list + +- name: List keys all options + aws_s3: + bucket: mybucket + mode: list + prefix: /my/desired/ + marker: /my/desired/0023.txt + max_keys: 472 + +- name: Create an empty bucket + aws_s3: + bucket: mybucket + mode: create + permission: public-read + +- name: Create a bucket with key as directory, in the EU region + aws_s3: + bucket: mybucket + object: /my/directory/path + mode: create + region: eu-west-1 + +- name: Delete a bucket and all contents + aws_s3: + bucket: mybucket + mode: delete + +- name: GET an object but don't download if the file checksums match. New in 2.0 + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + dest: /usr/local/myfile.txt + mode: get + overwrite: different + +- name: Delete an object from a bucket + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + mode: delobj +''' + +RETURN = ''' +msg: + description: Message indicating the status of the operation. + returned: always + type: str + sample: PUT operation complete +url: + description: URL of the object. + returned: (for put and geturl operations) + type: str + sample: https://my-bucket.s3.amazonaws.com/my-key.txt?AWSAccessKeyId=<access-key>&Expires=1506888865&Signature=<signature> +expiry: + description: Number of seconds the presigned url is valid for. + returned: (for geturl operation) + type: int + sample: 600 +contents: + description: Contents of the object as string. + returned: (for getstr operation) + type: str + sample: "Hello, world!" +s3_keys: + description: List of object keys. + returned: (for list operation) + type: list + elements: str + sample: + - prefix1/ + - prefix1/key1 + - prefix1/key2 +''' + +import mimetypes +import os +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ssl import SSLError +from ansible.module_utils.basic import to_text, to_native +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.aws.s3 import calculate_etag, HAS_MD5 +from ansible.module_utils.ec2 import get_aws_connection_info, boto3_conn + +try: + import botocore +except ImportError: + pass # will be detected by imported AnsibleAWSModule + +IGNORE_S3_DROP_IN_EXCEPTIONS = ['XNotImplemented', 'NotImplemented'] + + +class Sigv4Required(Exception): + pass + + +def key_check(module, s3, bucket, obj, version=None, validate=True): + exists = True + try: + if version: + s3.head_object(Bucket=bucket, Key=obj, VersionId=version) + else: + s3.head_object(Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + # if a client error is thrown, check if it's a 404 error + # if it's a 404 error, then the object does not exist + error_code = int(e.response['Error']['Code']) + if error_code == 404: + exists = False + elif error_code == 403 and validate is False: + pass + else: + module.fail_json_aws(e, msg="Failed while looking up object (during key check) %s." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while looking up object (during key check) %s." % obj) + return exists + + +def etag_compare(module, local_file, s3, bucket, obj, version=None): + s3_etag = get_etag(s3, bucket, obj, version=version) + local_etag = calculate_etag(module, local_file, s3_etag, s3, bucket, obj, version) + + return s3_etag == local_etag + + +def get_etag(s3, bucket, obj, version=None): + if version: + key_check = s3.head_object(Bucket=bucket, Key=obj, VersionId=version) + else: + key_check = s3.head_object(Bucket=bucket, Key=obj) + if not key_check: + return None + return key_check['ETag'] + + +def bucket_check(module, s3, bucket, validate=True): + exists = True + try: + s3.head_bucket(Bucket=bucket) + except botocore.exceptions.ClientError as e: + # If a client error is thrown, then check that it was a 404 error. + # If it was a 404 error, then the bucket does not exist. + error_code = int(e.response['Error']['Code']) + if error_code == 404: + exists = False + elif error_code == 403 and validate is False: + pass + else: + module.fail_json_aws(e, msg="Failed while looking up bucket (during bucket_check) %s." % bucket) + except botocore.exceptions.EndpointConnectionError as e: + module.fail_json_aws(e, msg="Invalid endpoint provided") + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while looking up bucket (during bucket_check) %s." % bucket) + return exists + + +def create_bucket(module, s3, bucket, location=None): + if module.check_mode: + module.exit_json(msg="CREATE operation skipped - running in check mode", changed=True) + configuration = {} + if location not in ('us-east-1', None): + configuration['LocationConstraint'] = location + try: + if len(configuration) > 0: + s3.create_bucket(Bucket=bucket, CreateBucketConfiguration=configuration) + else: + s3.create_bucket(Bucket=bucket) + if module.params.get('permission'): + # Wait for the bucket to exist before setting ACLs + s3.get_waiter('bucket_exists').wait(Bucket=bucket) + for acl in module.params.get('permission'): + s3.put_bucket_acl(ACL=acl, Bucket=bucket) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS: + module.warn("PutBucketAcl is not implemented by your storage provider. Set the permission parameters to the empty list to avoid this warning") + else: + module.fail_json_aws(e, msg="Failed while creating bucket or setting acl (check that you have CreateBucket and PutBucketAcl permission).") + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while creating bucket or setting acl (check that you have CreateBucket and PutBucketAcl permission).") + + if bucket: + return True + + +def paginated_list(s3, **pagination_params): + pg = s3.get_paginator('list_objects_v2') + for page in pg.paginate(**pagination_params): + yield [data['Key'] for data in page.get('Contents', [])] + + +def paginated_versioned_list_with_fallback(s3, **pagination_params): + try: + versioned_pg = s3.get_paginator('list_object_versions') + for page in versioned_pg.paginate(**pagination_params): + delete_markers = [{'Key': data['Key'], 'VersionId': data['VersionId']} for data in page.get('DeleteMarkers', [])] + current_objects = [{'Key': data['Key'], 'VersionId': data['VersionId']} for data in page.get('Versions', [])] + yield delete_markers + current_objects + except botocore.exceptions.ClientError as e: + if to_text(e.response['Error']['Code']) in IGNORE_S3_DROP_IN_EXCEPTIONS + ['AccessDenied']: + for page in paginated_list(s3, **pagination_params): + yield [{'Key': data['Key']} for data in page] + + +def list_keys(module, s3, bucket, prefix, marker, max_keys): + pagination_params = {'Bucket': bucket} + for param_name, param_value in (('Prefix', prefix), ('StartAfter', marker), ('MaxKeys', max_keys)): + pagination_params[param_name] = param_value + try: + keys = sum(paginated_list(s3, **pagination_params), []) + module.exit_json(msg="LIST operation complete", s3_keys=keys) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while listing the keys in the bucket {0}".format(bucket)) + + +def delete_bucket(module, s3, bucket): + if module.check_mode: + module.exit_json(msg="DELETE operation skipped - running in check mode", changed=True) + try: + exists = bucket_check(module, s3, bucket) + if exists is False: + return False + # if there are contents then we need to delete them before we can delete the bucket + for keys in paginated_versioned_list_with_fallback(s3, Bucket=bucket): + if keys: + s3.delete_objects(Bucket=bucket, Delete={'Objects': keys}) + s3.delete_bucket(Bucket=bucket) + return True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while deleting bucket %s." % bucket) + + +def delete_key(module, s3, bucket, obj): + if module.check_mode: + module.exit_json(msg="DELETE operation skipped - running in check mode", changed=True) + try: + s3.delete_object(Bucket=bucket, Key=obj) + module.exit_json(msg="Object deleted from bucket %s." % (bucket), changed=True) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while trying to delete %s." % obj) + + +def create_dirkey(module, s3, bucket, obj, encrypt): + if module.check_mode: + module.exit_json(msg="PUT operation skipped - running in check mode", changed=True) + try: + params = {'Bucket': bucket, 'Key': obj, 'Body': b''} + if encrypt: + params['ServerSideEncryption'] = module.params['encryption_mode'] + if module.params['encryption_kms_key_id'] and module.params['encryption_mode'] == 'aws:kms': + params['SSEKMSKeyId'] = module.params['encryption_kms_key_id'] + + s3.put_object(**params) + for acl in module.params.get('permission'): + s3.put_object_acl(ACL=acl, Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS: + module.warn("PutObjectAcl is not implemented by your storage provider. Set the permissions parameters to the empty list to avoid this warning") + else: + module.fail_json_aws(e, msg="Failed while creating object %s." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while creating object %s." % obj) + module.exit_json(msg="Virtual directory %s created in bucket %s" % (obj, bucket), changed=True) + + +def path_check(path): + if os.path.exists(path): + return True + else: + return False + + +def option_in_extra_args(option): + temp_option = option.replace('-', '').lower() + + allowed_extra_args = {'acl': 'ACL', 'cachecontrol': 'CacheControl', 'contentdisposition': 'ContentDisposition', + 'contentencoding': 'ContentEncoding', 'contentlanguage': 'ContentLanguage', + 'contenttype': 'ContentType', 'expires': 'Expires', 'grantfullcontrol': 'GrantFullControl', + 'grantread': 'GrantRead', 'grantreadacp': 'GrantReadACP', 'grantwriteacp': 'GrantWriteACP', + 'metadata': 'Metadata', 'requestpayer': 'RequestPayer', 'serversideencryption': 'ServerSideEncryption', + 'storageclass': 'StorageClass', 'ssecustomeralgorithm': 'SSECustomerAlgorithm', 'ssecustomerkey': 'SSECustomerKey', + 'ssecustomerkeymd5': 'SSECustomerKeyMD5', 'ssekmskeyid': 'SSEKMSKeyId', 'websiteredirectlocation': 'WebsiteRedirectLocation'} + + if temp_option in allowed_extra_args: + return allowed_extra_args[temp_option] + + +def upload_s3file(module, s3, bucket, obj, src, expiry, metadata, encrypt, headers): + if module.check_mode: + module.exit_json(msg="PUT operation skipped - running in check mode", changed=True) + try: + extra = {} + if encrypt: + extra['ServerSideEncryption'] = module.params['encryption_mode'] + if module.params['encryption_kms_key_id'] and module.params['encryption_mode'] == 'aws:kms': + extra['SSEKMSKeyId'] = module.params['encryption_kms_key_id'] + if metadata: + extra['Metadata'] = {} + + # determine object metadata and extra arguments + for option in metadata: + extra_args_option = option_in_extra_args(option) + if extra_args_option is not None: + extra[extra_args_option] = metadata[option] + else: + extra['Metadata'][option] = metadata[option] + + if 'ContentType' not in extra: + content_type = mimetypes.guess_type(src)[0] + if content_type is None: + # s3 default content type + content_type = 'binary/octet-stream' + extra['ContentType'] = content_type + + s3.upload_file(Filename=src, Bucket=bucket, Key=obj, ExtraArgs=extra) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to complete PUT operation.") + try: + for acl in module.params.get('permission'): + s3.put_object_acl(ACL=acl, Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS: + module.warn("PutObjectAcl is not implemented by your storage provider. Set the permission parameters to the empty list to avoid this warning") + else: + module.fail_json_aws(e, msg="Unable to set object ACL") + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Unable to set object ACL") + try: + url = s3.generate_presigned_url(ClientMethod='put_object', + Params={'Bucket': bucket, 'Key': obj}, + ExpiresIn=expiry) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to generate presigned URL") + module.exit_json(msg="PUT operation complete", url=url, changed=True) + + +def download_s3file(module, s3, bucket, obj, dest, retries, version=None): + if module.check_mode: + module.exit_json(msg="GET operation skipped - running in check mode", changed=True) + # retries is the number of loops; range/xrange needs to be one + # more to get that count of loops. + try: + if version: + key = s3.get_object(Bucket=bucket, Key=obj, VersionId=version) + else: + key = s3.get_object(Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] == 'InvalidArgument' and 'require AWS Signature Version 4' in to_text(e): + raise Sigv4Required() + elif e.response['Error']['Code'] not in ("403", "404"): + # AccessDenied errors may be triggered if 1) file does not exist or 2) file exists but + # user does not have the s3:GetObject permission. 404 errors are handled by download_file(). + module.fail_json_aws(e, msg="Could not find the key %s." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Could not find the key %s." % obj) + + optional_kwargs = {'ExtraArgs': {'VersionId': version}} if version else {} + for x in range(0, retries + 1): + try: + s3.download_file(bucket, obj, dest, **optional_kwargs) + module.exit_json(msg="GET operation complete", changed=True) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + # actually fail on last pass through the loop. + if x >= retries: + module.fail_json_aws(e, msg="Failed while downloading %s." % obj) + # otherwise, try again, this may be a transient timeout. + except SSLError as e: # will ClientError catch SSLError? + # actually fail on last pass through the loop. + if x >= retries: + module.fail_json_aws(e, msg="s3 download failed") + # otherwise, try again, this may be a transient timeout. + + +def download_s3str(module, s3, bucket, obj, version=None, validate=True): + if module.check_mode: + module.exit_json(msg="GET operation skipped - running in check mode", changed=True) + try: + if version: + contents = to_native(s3.get_object(Bucket=bucket, Key=obj, VersionId=version)["Body"].read()) + else: + contents = to_native(s3.get_object(Bucket=bucket, Key=obj)["Body"].read()) + module.exit_json(msg="GET operation complete", contents=contents, changed=True) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] == 'InvalidArgument' and 'require AWS Signature Version 4' in to_text(e): + raise Sigv4Required() + else: + module.fail_json_aws(e, msg="Failed while getting contents of object %s as a string." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while getting contents of object %s as a string." % obj) + + +def get_download_url(module, s3, bucket, obj, expiry, changed=True): + try: + url = s3.generate_presigned_url(ClientMethod='get_object', + Params={'Bucket': bucket, 'Key': obj}, + ExpiresIn=expiry) + module.exit_json(msg="Download url:", url=url, expiry=expiry, changed=changed) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while getting download url.") + + +def is_fakes3(s3_url): + """ Return True if s3_url has scheme fakes3:// """ + if s3_url is not None: + return urlparse(s3_url).scheme in ('fakes3', 'fakes3s') + else: + return False + + +def get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=False): + if s3_url and rgw: # TODO - test this + rgw = urlparse(s3_url) + params = dict(module=module, conn_type='client', resource='s3', use_ssl=rgw.scheme == 'https', region=location, endpoint=s3_url, **aws_connect_kwargs) + elif is_fakes3(s3_url): + fakes3 = urlparse(s3_url) + port = fakes3.port + if fakes3.scheme == 'fakes3s': + protocol = "https" + if port is None: + port = 443 + else: + protocol = "http" + if port is None: + port = 80 + params = dict(module=module, conn_type='client', resource='s3', region=location, + endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)), + use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs) + else: + params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=s3_url, **aws_connect_kwargs) + if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms': + params['config'] = botocore.client.Config(signature_version='s3v4') + elif module.params['mode'] in ('get', 'getstr') and sig_4: + params['config'] = botocore.client.Config(signature_version='s3v4') + if module.params['dualstack']: + dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True}) + if 'config' in params: + params['config'] = params['config'].merge(dualconf) + else: + params['config'] = dualconf + return boto3_conn(**params) + + +def main(): + argument_spec = dict( + bucket=dict(required=True), + dest=dict(default=None, type='path'), + encrypt=dict(default=True, type='bool'), + encryption_mode=dict(choices=['AES256', 'aws:kms'], default='AES256'), + expiry=dict(default=600, type='int', aliases=['expiration']), + headers=dict(type='dict'), + marker=dict(default=""), + max_keys=dict(default=1000, type='int'), + metadata=dict(type='dict'), + mode=dict(choices=['get', 'put', 'delete', 'create', 'geturl', 'getstr', 'delobj', 'list'], required=True), + object=dict(), + permission=dict(type='list', default=['private']), + version=dict(default=None), + overwrite=dict(aliases=['force'], default='always'), + prefix=dict(default=""), + retries=dict(aliases=['retry'], type='int', default=0), + s3_url=dict(aliases=['S3_URL']), + dualstack=dict(default='no', type='bool'), + rgw=dict(default='no', type='bool'), + src=dict(), + ignore_nonexistent_bucket=dict(default=False, type='bool'), + encryption_kms_key_id=dict() + ) + module = AnsibleAWSModule( + argument_spec=argument_spec, + supports_check_mode=True, + required_if=[['mode', 'put', ['src', 'object']], + ['mode', 'get', ['dest', 'object']], + ['mode', 'getstr', ['object']], + ['mode', 'geturl', ['object']]], + ) + + bucket = module.params.get('bucket') + encrypt = module.params.get('encrypt') + expiry = module.params.get('expiry') + dest = module.params.get('dest', '') + headers = module.params.get('headers') + marker = module.params.get('marker') + max_keys = module.params.get('max_keys') + metadata = module.params.get('metadata') + mode = module.params.get('mode') + obj = module.params.get('object') + version = module.params.get('version') + overwrite = module.params.get('overwrite') + prefix = module.params.get('prefix') + retries = module.params.get('retries') + s3_url = module.params.get('s3_url') + dualstack = module.params.get('dualstack') + rgw = module.params.get('rgw') + src = module.params.get('src') + ignore_nonexistent_bucket = module.params.get('ignore_nonexistent_bucket') + + object_canned_acl = ["private", "public-read", "public-read-write", "aws-exec-read", "authenticated-read", "bucket-owner-read", "bucket-owner-full-control"] + bucket_canned_acl = ["private", "public-read", "public-read-write", "authenticated-read"] + + if overwrite not in ['always', 'never', 'different']: + if module.boolean(overwrite): + overwrite = 'always' + else: + overwrite = 'never' + + if overwrite == 'different' and not HAS_MD5: + module.fail_json(msg='overwrite=different is unavailable: ETag calculation requires MD5 support') + + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True) + + if region in ('us-east-1', '', None): + # default to US Standard region + location = 'us-east-1' + else: + # Boto uses symbolic names for locations but region strings will + # actually work fine for everything except us-east-1 (US Standard) + location = region + + if module.params.get('object'): + obj = module.params['object'] + # If there is a top level object, do nothing - if the object starts with / + # remove the leading character to maintain compatibility with Ansible versions < 2.4 + if obj.startswith('/'): + obj = obj[1:] + + # Bucket deletion does not require obj. Prevents ambiguity with delobj. + if obj and mode == "delete": + module.fail_json(msg='Parameter obj cannot be used with mode=delete') + + # allow eucarc environment variables to be used if ansible vars aren't set + if not s3_url and 'S3_URL' in os.environ: + s3_url = os.environ['S3_URL'] + + if dualstack and s3_url is not None and 'amazonaws.com' not in s3_url: + module.fail_json(msg='dualstack only applies to AWS S3') + + if dualstack and not module.botocore_at_least('1.4.45'): + module.fail_json(msg='dualstack requires botocore >= 1.4.45') + + # rgw requires an explicit url + if rgw and not s3_url: + module.fail_json(msg='rgw flavour requires s3_url') + + # Look at s3_url and tweak connection settings + # if connecting to RGW, Walrus or fakes3 + if s3_url: + for key in ['validate_certs', 'security_token', 'profile_name']: + aws_connect_kwargs.pop(key, None) + s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url) + + validate = not ignore_nonexistent_bucket + + # separate types of ACLs + bucket_acl = [acl for acl in module.params.get('permission') if acl in bucket_canned_acl] + object_acl = [acl for acl in module.params.get('permission') if acl in object_canned_acl] + error_acl = [acl for acl in module.params.get('permission') if acl not in bucket_canned_acl and acl not in object_canned_acl] + if error_acl: + module.fail_json(msg='Unknown permission specified: %s' % error_acl) + + # First, we check to see if the bucket exists, we get "bucket" returned. + bucketrtn = bucket_check(module, s3, bucket, validate=validate) + + if validate and mode not in ('create', 'put', 'delete') and not bucketrtn: + module.fail_json(msg="Source bucket cannot be found.") + + if mode == 'get': + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + if keyrtn is False: + if version: + module.fail_json(msg="Key %s with version id %s does not exist." % (obj, version)) + else: + module.fail_json(msg="Key %s does not exist." % obj) + + if path_check(dest) and overwrite != 'always': + if overwrite == 'never': + module.exit_json(msg="Local object already exists and overwrite is disabled.", changed=False) + if etag_compare(module, dest, s3, bucket, obj, version=version): + module.exit_json(msg="Local and remote object are identical, ignoring. Use overwrite=always parameter to force.", changed=False) + + try: + download_s3file(module, s3, bucket, obj, dest, retries, version=version) + except Sigv4Required: + s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=True) + download_s3file(module, s3, bucket, obj, dest, retries, version=version) + + if mode == 'put': + + # if putting an object in a bucket yet to be created, acls for the bucket and/or the object may be specified + # these were separated into the variables bucket_acl and object_acl above + + if not path_check(src): + module.fail_json(msg="Local object for PUT does not exist") + + if bucketrtn: + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + else: + # If the bucket doesn't exist we should create it. + # only use valid bucket acls for create_bucket function + module.params['permission'] = bucket_acl + create_bucket(module, s3, bucket, location) + + if keyrtn and overwrite != 'always': + if overwrite == 'never' or etag_compare(module, src, s3, bucket, obj): + # Return the download URL for the existing object + get_download_url(module, s3, bucket, obj, expiry, changed=False) + + # only use valid object acls for the upload_s3file function + module.params['permission'] = object_acl + upload_s3file(module, s3, bucket, obj, src, expiry, metadata, encrypt, headers) + + # Delete an object from a bucket, not the entire bucket + if mode == 'delobj': + if obj is None: + module.fail_json(msg="object parameter is required") + if bucket: + deletertn = delete_key(module, s3, bucket, obj) + if deletertn is True: + module.exit_json(msg="Object deleted from bucket %s." % bucket, changed=True) + else: + module.fail_json(msg="Bucket parameter is required.") + + # Delete an entire bucket, including all objects in the bucket + if mode == 'delete': + if bucket: + deletertn = delete_bucket(module, s3, bucket) + if deletertn is True: + module.exit_json(msg="Bucket %s and all keys have been deleted." % bucket, changed=True) + else: + module.fail_json(msg="Bucket parameter is required.") + + # Support for listing a set of keys + if mode == 'list': + exists = bucket_check(module, s3, bucket) + + # If the bucket does not exist then bail out + if not exists: + module.fail_json(msg="Target bucket (%s) cannot be found" % bucket) + + list_keys(module, s3, bucket, prefix, marker, max_keys) + + # Need to research how to create directories without "populating" a key, so this should just do bucket creation for now. + # WE SHOULD ENABLE SOME WAY OF CREATING AN EMPTY KEY TO CREATE "DIRECTORY" STRUCTURE, AWS CONSOLE DOES THIS. + if mode == 'create': + + # if both creating a bucket and putting an object in it, acls for the bucket and/or the object may be specified + # these were separated above into the variables bucket_acl and object_acl + + if bucket and not obj: + if bucketrtn: + module.exit_json(msg="Bucket already exists.", changed=False) + else: + # only use valid bucket acls when creating the bucket + module.params['permission'] = bucket_acl + module.exit_json(msg="Bucket created successfully", changed=create_bucket(module, s3, bucket, location)) + if bucket and obj: + if obj.endswith('/'): + dirobj = obj + else: + dirobj = obj + "/" + if bucketrtn: + if key_check(module, s3, bucket, dirobj): + module.exit_json(msg="Bucket %s and key %s already exists." % (bucket, obj), changed=False) + else: + # setting valid object acls for the create_dirkey function + module.params['permission'] = object_acl + create_dirkey(module, s3, bucket, dirobj, encrypt) + else: + # only use valid bucket acls for the create_bucket function + module.params['permission'] = bucket_acl + created = create_bucket(module, s3, bucket, location) + # only use valid object acls for the create_dirkey function + module.params['permission'] = object_acl + create_dirkey(module, s3, bucket, dirobj, encrypt) + + # Support for grabbing the time-expired URL for an object in S3/Walrus. + if mode == 'geturl': + if not bucket and not obj: + module.fail_json(msg="Bucket and Object parameters must be set") + + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + if keyrtn: + get_download_url(module, s3, bucket, obj, expiry) + else: + module.fail_json(msg="Key %s does not exist." % obj) + + if mode == 'getstr': + if bucket and obj: + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + if keyrtn: + try: + download_s3str(module, s3, bucket, obj, version=version) + except Sigv4Required: + s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=True) + download_s3str(module, s3, bucket, obj, version=version) + elif version is not None: + module.fail_json(msg="Key %s with version id %s does not exist." % (obj, version)) + else: + module.fail_json(msg="Key %s does not exist." % obj) + + module.exit_json(failed=False) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_appserviceplan.py b/test/support/integration/plugins/modules/azure_rm_appserviceplan.py new file mode 100644 index 00000000..ee871c35 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_appserviceplan.py @@ -0,0 +1,379 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_appserviceplan +version_added: "2.7" +short_description: Manage App Service Plan +description: + - Create, update and delete instance of App Service Plan. + +options: + resource_group: + description: + - Name of the resource group to which the resource belongs. + required: True + + name: + description: + - Unique name of the app service plan to create or update. + required: True + + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + + sku: + description: + - The pricing tiers, e.g., C(F1), C(D1), C(B1), C(B2), C(B3), C(S1), C(P1), C(P1V2) etc. + - Please see U(https://azure.microsoft.com/en-us/pricing/details/app-service/plans/) for more detail. + - For Linux app service plan, please see U(https://azure.microsoft.com/en-us/pricing/details/app-service/linux/) for more detail. + is_linux: + description: + - Describe whether to host webapp on Linux worker. + type: bool + default: false + + number_of_workers: + description: + - Describe number of workers to be allocated. + + state: + description: + - Assert the state of the app service plan. + - Use C(present) to create or update an app service plan and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Yunge Zhu (@yungezz) + +''' + +EXAMPLES = ''' + - name: Create a windows app service plan + azure_rm_appserviceplan: + resource_group: myResourceGroup + name: myAppPlan + location: eastus + sku: S1 + + - name: Create a linux app service plan + azure_rm_appserviceplan: + resource_group: myResourceGroup + name: myAppPlan + location: eastus + sku: S1 + is_linux: true + number_of_workers: 1 + + - name: update sku of existing windows app service plan + azure_rm_appserviceplan: + resource_group: myResourceGroup + name: myAppPlan + location: eastus + sku: S2 +''' + +RETURN = ''' +azure_appserviceplan: + description: Facts about the current state of the app service plan. + returned: always + type: dict + sample: { + "id": "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/myAppPlan" + } +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrestazure.azure_operation import AzureOperationPoller + from msrest.serialization import Model + from azure.mgmt.web.models import ( + app_service_plan, AppServicePlan, SkuDescription + ) +except ImportError: + # This is handled in azure_rm_common + pass + + +def _normalize_sku(sku): + if sku is None: + return sku + + sku = sku.upper() + if sku == 'FREE': + return 'F1' + elif sku == 'SHARED': + return 'D1' + return sku + + +def get_sku_name(tier): + tier = tier.upper() + if tier == 'F1' or tier == "FREE": + return 'FREE' + elif tier == 'D1' or tier == "SHARED": + return 'SHARED' + elif tier in ['B1', 'B2', 'B3', 'BASIC']: + return 'BASIC' + elif tier in ['S1', 'S2', 'S3']: + return 'STANDARD' + elif tier in ['P1', 'P2', 'P3']: + return 'PREMIUM' + elif tier in ['P1V2', 'P2V2', 'P3V2']: + return 'PREMIUMV2' + else: + return None + + +def appserviceplan_to_dict(plan): + return dict( + id=plan.id, + name=plan.name, + kind=plan.kind, + location=plan.location, + reserved=plan.reserved, + is_linux=plan.reserved, + provisioning_state=plan.provisioning_state, + status=plan.status, + target_worker_count=plan.target_worker_count, + sku=dict( + name=plan.sku.name, + size=plan.sku.size, + tier=plan.sku.tier, + family=plan.sku.family, + capacity=plan.sku.capacity + ), + resource_group=plan.resource_group, + number_of_sites=plan.number_of_sites, + tags=plan.tags if plan.tags else None + ) + + +class AzureRMAppServicePlans(AzureRMModuleBase): + """Configuration class for an Azure RM App Service Plan resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + location=dict( + type='str' + ), + sku=dict( + type='str' + ), + is_linux=dict( + type='bool', + default=False + ), + number_of_workers=dict( + type='str' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.name = None + self.location = None + + self.sku = None + self.is_linux = None + self.number_of_workers = 1 + + self.tags = None + + self.results = dict( + changed=False, + ansible_facts=dict(azure_appserviceplan=None) + ) + self.state = None + + super(AzureRMAppServicePlans, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if kwargs[key]: + setattr(self, key, kwargs[key]) + + old_response = None + response = None + to_be_updated = False + + # set location + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + self.location = resource_group.location + + # get app service plan + old_response = self.get_plan() + + # if not existing + if not old_response: + self.log("App Service plan doesn't exist") + + if self.state == "present": + to_be_updated = True + + if not self.sku: + self.fail('Please specify sku in plan when creation') + + else: + # existing app service plan, do update + self.log("App Service Plan already exists") + + if self.state == 'present': + self.log('Result: {0}'.format(old_response)) + + update_tags, newtags = self.update_tags(old_response.get('tags', dict())) + + if update_tags: + to_be_updated = True + self.tags = newtags + + # check if sku changed + if self.sku and _normalize_sku(self.sku) != old_response['sku']['size']: + to_be_updated = True + + # check if number_of_workers changed + if self.number_of_workers and int(self.number_of_workers) != old_response['sku']['capacity']: + to_be_updated = True + + if self.is_linux and self.is_linux != old_response['reserved']: + self.fail("Operation not allowed: cannot update reserved of app service plan.") + + if old_response: + self.results['id'] = old_response['id'] + + if to_be_updated: + self.log('Need to Create/Update app service plan') + self.results['changed'] = True + + if self.check_mode: + return self.results + + response = self.create_or_update_plan() + self.results['id'] = response['id'] + + if self.state == 'absent' and old_response: + self.log("Delete app service plan") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_plan() + + self.log('App service plan instance deleted') + + return self.results + + def get_plan(self): + ''' + Gets app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Get App Service Plan {0}".format(self.name)) + + try: + response = self.web_client.app_service_plans.get(self.resource_group, self.name) + if response: + self.log("Response : {0}".format(response)) + self.log("App Service Plan : {0} found".format(response.name)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + self.log("Didn't find app service plan {0} in resource group {1}".format(self.name, self.resource_group)) + + return False + + def create_or_update_plan(self): + ''' + Creates app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Create App Service Plan {0}".format(self.name)) + + try: + # normalize sku + sku = _normalize_sku(self.sku) + + sku_def = SkuDescription(tier=get_sku_name( + sku), name=sku, capacity=self.number_of_workers) + plan_def = AppServicePlan( + location=self.location, app_service_plan_name=self.name, sku=sku_def, reserved=self.is_linux, tags=self.tags if self.tags else None) + + response = self.web_client.app_service_plans.create_or_update(self.resource_group, self.name, plan_def) + + if isinstance(response, LROPoller) or isinstance(response, AzureOperationPoller): + response = self.get_poller_result(response) + + self.log("Response : {0}".format(response)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + self.fail("Failed to create app service plan {0} in resource group {1}: {2}".format(self.name, self.resource_group, str(ex))) + + def delete_plan(self): + ''' + Deletes specified App service plan in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the App service plan {0}".format(self.name)) + try: + response = self.web_client.app_service_plans.delete(resource_group_name=self.resource_group, + name=self.name) + except CloudError as e: + self.log('Error attempting to delete App service plan.') + self.fail( + "Error deleting the App service plan : {0}".format(str(e))) + + return True + + +def main(): + """Main execution""" + AzureRMAppServicePlans() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_functionapp.py b/test/support/integration/plugins/modules/azure_rm_functionapp.py new file mode 100644 index 00000000..0c372a88 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_functionapp.py @@ -0,0 +1,421 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Thomas Stringer <tomstr@microsoft.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_functionapp +version_added: "2.4" +short_description: Manage Azure Function Apps +description: + - Create, update or delete an Azure Function App. +options: + resource_group: + description: + - Name of resource group. + required: true + aliases: + - resource_group_name + name: + description: + - Name of the Azure Function App. + required: true + location: + description: + - Valid Azure location. Defaults to location of the resource group. + plan: + description: + - App service plan. + - It can be name of existing app service plan in same resource group as function app. + - It can be resource id of existing app service plan. + - Resource id. For example /subscriptions/<subs_id>/resourceGroups/<resource_group>/providers/Microsoft.Web/serverFarms/<plan_name>. + - It can be a dict which contains C(name), C(resource_group). + - C(name). Name of app service plan. + - C(resource_group). Resource group name of app service plan. + version_added: "2.8" + container_settings: + description: Web app container settings. + suboptions: + name: + description: + - Name of container. For example "imagename:tag". + registry_server_url: + description: + - Container registry server url. For example C(mydockerregistry.io). + registry_server_user: + description: + - The container registry server user name. + registry_server_password: + description: + - The container registry server password. + version_added: "2.8" + storage_account: + description: + - Name of the storage account to use. + required: true + aliases: + - storage + - storage_account_name + app_settings: + description: + - Dictionary containing application settings. + state: + description: + - Assert the state of the Function App. Use C(present) to create or update a Function App and C(absent) to delete. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Thomas Stringer (@trstringer) +''' + +EXAMPLES = ''' +- name: Create a function app + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + storage_account: myStorageAccount + +- name: Create a function app with app settings + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + storage_account: myStorageAccount + app_settings: + setting1: value1 + setting2: value2 + +- name: Create container based function app + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + storage_account: myStorageAccount + plan: + resource_group: myResourceGroup + name: myAppPlan + container_settings: + name: httpd + registry_server_url: index.docker.io + +- name: Delete a function app + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + state: absent +''' + +RETURN = ''' +state: + description: + - Current state of the Azure Function App. + returned: success + type: dict + example: + id: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myFunctionApp + name: myfunctionapp + kind: functionapp + location: East US + type: Microsoft.Web/sites + state: Running + host_names: + - myfunctionapp.azurewebsites.net + repository_site_name: myfunctionapp + usage_state: Normal + enabled: true + enabled_host_names: + - myfunctionapp.azurewebsites.net + - myfunctionapp.scm.azurewebsites.net + availability_state: Normal + host_name_ssl_states: + - name: myfunctionapp.azurewebsites.net + ssl_state: Disabled + host_type: Standard + - name: myfunctionapp.scm.azurewebsites.net + ssl_state: Disabled + host_type: Repository + server_farm_id: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/EastUSPlan + reserved: false + last_modified_time_utc: 2017-08-22T18:54:01.190Z + scm_site_also_stopped: false + client_affinity_enabled: true + client_cert_enabled: false + host_names_disabled: false + outbound_ip_addresses: ............ + container_size: 1536 + daily_memory_time_quota: 0 + resource_group: myResourceGroup + default_host_name: myfunctionapp.azurewebsites.net +''' # NOQA + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from azure.mgmt.web.models import ( + site_config, app_service_plan, Site, SiteConfig, NameValuePair, SiteSourceControl, + AppServicePlan, SkuDescription + ) + from azure.mgmt.resource.resources import ResourceManagementClient + from msrest.polling import LROPoller +except ImportError: + # This is handled in azure_rm_common + pass + +container_settings_spec = dict( + name=dict(type='str', required=True), + registry_server_url=dict(type='str'), + registry_server_user=dict(type='str'), + registry_server_password=dict(type='str', no_log=True) +) + + +class AzureRMFunctionApp(AzureRMModuleBase): + + def __init__(self): + + self.module_arg_spec = dict( + resource_group=dict(type='str', required=True, aliases=['resource_group_name']), + name=dict(type='str', required=True), + state=dict(type='str', default='present', choices=['present', 'absent']), + location=dict(type='str'), + storage_account=dict( + type='str', + aliases=['storage', 'storage_account_name'] + ), + app_settings=dict(type='dict'), + plan=dict( + type='raw' + ), + container_settings=dict( + type='dict', + options=container_settings_spec + ) + ) + + self.results = dict( + changed=False, + state=dict() + ) + + self.resource_group = None + self.name = None + self.state = None + self.location = None + self.storage_account = None + self.app_settings = None + self.plan = None + self.container_settings = None + + required_if = [('state', 'present', ['storage_account'])] + + super(AzureRMFunctionApp, self).__init__( + self.module_arg_spec, + supports_check_mode=True, + required_if=required_if + ) + + def exec_module(self, **kwargs): + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + if self.app_settings is None: + self.app_settings = dict() + + try: + resource_group = self.rm_client.resource_groups.get(self.resource_group) + except CloudError: + self.fail('Unable to retrieve resource group') + + self.location = self.location or resource_group.location + + try: + function_app = self.web_client.web_apps.get( + resource_group_name=self.resource_group, + name=self.name + ) + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + exists = function_app is not None + except CloudError as exc: + exists = False + + if self.state == 'absent': + if exists: + if self.check_mode: + self.results['changed'] = True + return self.results + try: + self.web_client.web_apps.delete( + resource_group_name=self.resource_group, + name=self.name + ) + self.results['changed'] = True + except CloudError as exc: + self.fail('Failure while deleting web app: {0}'.format(exc)) + else: + self.results['changed'] = False + else: + kind = 'functionapp' + linux_fx_version = None + if self.container_settings and self.container_settings.get('name'): + kind = 'functionapp,linux,container' + linux_fx_version = 'DOCKER|' + if self.container_settings.get('registry_server_url'): + self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url'] + linux_fx_version += self.container_settings['registry_server_url'] + '/' + linux_fx_version += self.container_settings['name'] + if self.container_settings.get('registry_server_user'): + self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings.get('registry_server_user') + + if self.container_settings.get('registry_server_password'): + self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings.get('registry_server_password') + + if not self.plan and function_app: + self.plan = function_app.server_farm_id + + if not exists: + function_app = Site( + location=self.location, + kind=kind, + site_config=SiteConfig( + app_settings=self.aggregated_app_settings(), + scm_type='LocalGit' + ) + ) + self.results['changed'] = True + else: + self.results['changed'], function_app = self.update(function_app) + + # get app service plan + if self.plan: + if isinstance(self.plan, dict): + self.plan = "/subscriptions/{0}/resourceGroups/{1}/providers/Microsoft.Web/serverfarms/{2}".format( + self.subscription_id, + self.plan.get('resource_group', self.resource_group), + self.plan.get('name') + ) + function_app.server_farm_id = self.plan + + # set linux fx version + if linux_fx_version: + function_app.site_config.linux_fx_version = linux_fx_version + + if self.check_mode: + self.results['state'] = function_app.as_dict() + elif self.results['changed']: + try: + new_function_app = self.web_client.web_apps.create_or_update( + resource_group_name=self.resource_group, + name=self.name, + site_envelope=function_app + ).result() + self.results['state'] = new_function_app.as_dict() + except CloudError as exc: + self.fail('Error creating or updating web app: {0}'.format(exc)) + + return self.results + + def update(self, source_function_app): + """Update the Site object if there are any changes""" + + source_app_settings = self.web_client.web_apps.list_application_settings( + resource_group_name=self.resource_group, + name=self.name + ) + + changed, target_app_settings = self.update_app_settings(source_app_settings.properties) + + source_function_app.site_config = SiteConfig( + app_settings=target_app_settings, + scm_type='LocalGit' + ) + + return changed, source_function_app + + def update_app_settings(self, source_app_settings): + """Update app settings""" + + target_app_settings = self.aggregated_app_settings() + target_app_settings_dict = dict([(i.name, i.value) for i in target_app_settings]) + return target_app_settings_dict != source_app_settings, target_app_settings + + def necessary_functionapp_settings(self): + """Construct the necessary app settings required for an Azure Function App""" + + function_app_settings = [] + + if self.container_settings is None: + for key in ['AzureWebJobsStorage', 'WEBSITE_CONTENTAZUREFILECONNECTIONSTRING', 'AzureWebJobsDashboard']: + function_app_settings.append(NameValuePair(name=key, value=self.storage_connection_string)) + function_app_settings.append(NameValuePair(name='FUNCTIONS_EXTENSION_VERSION', value='~1')) + function_app_settings.append(NameValuePair(name='WEBSITE_NODE_DEFAULT_VERSION', value='6.5.0')) + function_app_settings.append(NameValuePair(name='WEBSITE_CONTENTSHARE', value=self.name)) + else: + function_app_settings.append(NameValuePair(name='FUNCTIONS_EXTENSION_VERSION', value='~2')) + function_app_settings.append(NameValuePair(name='WEBSITES_ENABLE_APP_SERVICE_STORAGE', value=False)) + function_app_settings.append(NameValuePair(name='AzureWebJobsStorage', value=self.storage_connection_string)) + + return function_app_settings + + def aggregated_app_settings(self): + """Combine both system and user app settings""" + + function_app_settings = self.necessary_functionapp_settings() + for app_setting_key in self.app_settings: + found_setting = None + for s in function_app_settings: + if s.name == app_setting_key: + found_setting = s + break + if found_setting: + found_setting.value = self.app_settings[app_setting_key] + else: + function_app_settings.append(NameValuePair( + name=app_setting_key, + value=self.app_settings[app_setting_key] + )) + return function_app_settings + + @property + def storage_connection_string(self): + """Construct the storage account connection string""" + + return 'DefaultEndpointsProtocol=https;AccountName={0};AccountKey={1}'.format( + self.storage_account, + self.storage_key + ) + + @property + def storage_key(self): + """Retrieve the storage account key""" + + return self.storage_client.storage_accounts.list_keys( + resource_group_name=self.resource_group, + account_name=self.storage_account + ).keys[0].value + + +def main(): + """Main function execution""" + + AzureRMFunctionApp() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_functionapp_info.py b/test/support/integration/plugins/modules/azure_rm_functionapp_info.py new file mode 100644 index 00000000..40672f95 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_functionapp_info.py @@ -0,0 +1,207 @@ +#!/usr/bin/python +# +# Copyright (c) 2016 Thomas Stringer, <tomstr@microsoft.com> + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_functionapp_info +version_added: "2.9" +short_description: Get Azure Function App facts +description: + - Get facts for one Azure Function App or all Function Apps within a resource group. +options: + name: + description: + - Only show results for a specific Function App. + resource_group: + description: + - Limit results to a resource group. Required when filtering by name. + aliases: + - resource_group_name + tags: + description: + - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'. + +extends_documentation_fragment: + - azure + +author: + - Thomas Stringer (@trstringer) +''' + +EXAMPLES = ''' + - name: Get facts for one Function App + azure_rm_functionapp_info: + resource_group: myResourceGroup + name: myfunctionapp + + - name: Get facts for all Function Apps in a resource group + azure_rm_functionapp_info: + resource_group: myResourceGroup + + - name: Get facts for all Function Apps by tags + azure_rm_functionapp_info: + tags: + - testing +''' + +RETURN = ''' +azure_functionapps: + description: + - List of Azure Function Apps dicts. + returned: always + type: list + example: + id: /subscriptions/.../resourceGroups/ansible-rg/providers/Microsoft.Web/sites/myfunctionapp + name: myfunctionapp + kind: functionapp + location: East US + type: Microsoft.Web/sites + state: Running + host_names: + - myfunctionapp.azurewebsites.net + repository_site_name: myfunctionapp + usage_state: Normal + enabled: true + enabled_host_names: + - myfunctionapp.azurewebsites.net + - myfunctionapp.scm.azurewebsites.net + availability_state: Normal + host_name_ssl_states: + - name: myfunctionapp.azurewebsites.net + ssl_state: Disabled + host_type: Standard + - name: myfunctionapp.scm.azurewebsites.net + ssl_state: Disabled + host_type: Repository + server_farm_id: /subscriptions/.../resourceGroups/ansible-rg/providers/Microsoft.Web/serverfarms/EastUSPlan + reserved: false + last_modified_time_utc: 2017-08-22T18:54:01.190Z + scm_site_also_stopped: false + client_affinity_enabled: true + client_cert_enabled: false + host_names_disabled: false + outbound_ip_addresses: ............ + container_size: 1536 + daily_memory_time_quota: 0 + resource_group: myResourceGroup + default_host_name: myfunctionapp.azurewebsites.net +''' + +try: + from msrestazure.azure_exceptions import CloudError +except Exception: + # This is handled in azure_rm_common + pass + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + + +class AzureRMFunctionAppInfo(AzureRMModuleBase): + def __init__(self): + + self.module_arg_spec = dict( + name=dict(type='str'), + resource_group=dict(type='str', aliases=['resource_group_name']), + tags=dict(type='list'), + ) + + self.results = dict( + changed=False, + ansible_info=dict(azure_functionapps=[]) + ) + + self.name = None + self.resource_group = None + self.tags = None + + super(AzureRMFunctionAppInfo, self).__init__( + self.module_arg_spec, + supports_tags=False, + facts_module=True + ) + + def exec_module(self, **kwargs): + + is_old_facts = self.module._name == 'azure_rm_functionapp_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_functionapp_facts' module has been renamed to 'azure_rm_functionapp_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if self.name and not self.resource_group: + self.fail("Parameter error: resource group required when filtering by name.") + + if self.name: + self.results['ansible_info']['azure_functionapps'] = self.get_functionapp() + elif self.resource_group: + self.results['ansible_info']['azure_functionapps'] = self.list_resource_group() + else: + self.results['ansible_info']['azure_functionapps'] = self.list_all() + + return self.results + + def get_functionapp(self): + self.log('Get properties for Function App {0}'.format(self.name)) + function_app = None + result = [] + + try: + function_app = self.web_client.web_apps.get( + self.resource_group, + self.name + ) + except CloudError: + pass + + if function_app and self.has_tags(function_app.tags, self.tags): + result = function_app.as_dict() + + return [result] + + def list_resource_group(self): + self.log('List items') + try: + response = self.web_client.web_apps.list_by_resource_group(self.resource_group) + except Exception as exc: + self.fail("Error listing for resource group {0} - {1}".format(self.resource_group, str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + results.append(item.as_dict()) + return results + + def list_all(self): + self.log('List all items') + try: + response = self.web_client.web_apps.list_by_resource_group(self.resource_group) + except Exception as exc: + self.fail("Error listing all items - {0}".format(str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + results.append(item.as_dict()) + return results + + +def main(): + AzureRMFunctionAppInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py new file mode 100644 index 00000000..212cf795 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py @@ -0,0 +1,241 @@ +#!/usr/bin/python +# +# Copyright (c) 2019 Zim Kalinowski, (@zikalino) +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbconfiguration +version_added: "2.8" +short_description: Manage Configuration instance +description: + - Create, update and delete instance of Configuration. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. + required: True + server_name: + description: + - The name of the server. + required: True + name: + description: + - The name of the server configuration. + required: True + value: + description: + - Value of the configuration. + state: + description: + - Assert the state of the MariaDB configuration. Use C(present) to update setting, or C(absent) to reset to default value. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) +''' + +EXAMPLES = ''' + - name: Update SQL Server setting + azure_rm_mariadbconfiguration: + resource_group: myResourceGroup + server_name: myServer + name: event_scheduler + value: "ON" +''' + +RETURN = ''' +id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/myServer/confi + gurations/event_scheduler" +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from azure.mgmt.rdbms.mysql import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbConfiguration(AzureRMModuleBase): + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + value=dict( + type='str' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.server_name = None + self.name = None + self.value = None + + self.results = dict(changed=False) + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbConfiguration, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=False) + + def exec_module(self, **kwargs): + + for key in list(self.module_arg_spec.keys()): + if hasattr(self, key): + setattr(self, key, kwargs[key]) + + old_response = None + response = None + + old_response = self.get_configuration() + + if not old_response: + self.log("Configuration instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("Configuration instance already exists") + if self.state == 'absent' and old_response['source'] == 'user-override': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if Configuration instance has to be deleted or may be updated") + if self.value != old_response.get('value'): + self.to_do = Actions.Update + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the Configuration instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_configuration() + + self.results['changed'] = True + self.log("Creation / Update done") + elif self.to_do == Actions.Delete: + self.log("Configuration instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_configuration() + else: + self.log("Configuration instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + + return self.results + + def create_update_configuration(self): + self.log("Creating / Updating the Configuration instance {0}".format(self.name)) + + try: + response = self.mariadb_client.configurations.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name, + value=self.value, + source='user-override') + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the Configuration instance.') + self.fail("Error creating the Configuration instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_configuration(self): + self.log("Deleting the Configuration instance {0}".format(self.name)) + try: + response = self.mariadb_client.configurations.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name, + source='system-default') + except CloudError as e: + self.log('Error attempting to delete the Configuration instance.') + self.fail("Error deleting the Configuration instance: {0}".format(str(e))) + + return True + + def get_configuration(self): + self.log("Checking if the Configuration instance {0} is present".format(self.name)) + found = False + try: + response = self.mariadb_client.configurations.get(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("Configuration instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the Configuration instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbConfiguration() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py new file mode 100644 index 00000000..3faac5eb --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py @@ -0,0 +1,217 @@ +#!/usr/bin/python +# +# Copyright (c) 2019 Zim Kalinowski, (@zikalino) +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbconfiguration_info +version_added: "2.9" +short_description: Get Azure MariaDB Configuration facts +description: + - Get facts of Azure MariaDB Configuration. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + type: str + server_name: + description: + - The name of the server. + required: True + type: str + name: + description: + - Setting name. + type: str + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get specific setting of MariaDB Server + azure_rm_mariadbconfiguration_info: + resource_group: myResourceGroup + server_name: testserver + name: deadlock_timeout + + - name: Get all settings of MariaDB Server + azure_rm_mariadbconfiguration_info: + resource_group: myResourceGroup + server_name: server_name +''' + +RETURN = ''' +settings: + description: + - A list of dictionaries containing MariaDB Server settings. + returned: always + type: complex + contains: + id: + description: + - Setting resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver + /configurations/deadlock_timeout" + name: + description: + - Setting name. + returned: always + type: str + sample: deadlock_timeout + value: + description: + - Setting value. + returned: always + type: raw + sample: 1000 + description: + description: + - Description of the configuration. + returned: always + type: str + sample: Deadlock timeout. + source: + description: + - Source of the configuration. + returned: always + type: str + sample: system-default +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_operation import AzureOperationPoller + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbConfigurationInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict(changed=False) + self.mgmt_client = None + self.resource_group = None + self.server_name = None + self.name = None + super(AzureRMMariaDbConfigurationInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbconfiguration_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbconfiguration_facts' module has been renamed to 'azure_rm_mariadbconfiguration_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if self.name is not None: + self.results['settings'] = self.get() + else: + self.results['settings'] = self.list_by_server() + return self.results + + def get(self): + ''' + Gets facts of the specified MariaDB Configuration. + + :return: deserialized MariaDB Configurationinstance state dictionary + ''' + response = None + results = [] + try: + response = self.mgmt_client.configurations.get(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for Configurations.') + + if response is not None: + results.append(self.format_item(response)) + + return results + + def list_by_server(self): + ''' + Gets facts of the specified MariaDB Configuration. + + :return: deserialized MariaDB Configurationinstance state dictionary + ''' + response = None + results = [] + try: + response = self.mgmt_client.configurations.list_by_server(resource_group_name=self.resource_group, + server_name=self.server_name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for Configurations.') + + if response is not None: + for item in response: + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'resource_group': self.resource_group, + 'server_name': self.server_name, + 'id': d['id'], + 'name': d['name'], + 'value': d['value'], + 'description': d['description'], + 'source': d['source'] + } + return d + + +def main(): + AzureRMMariaDbConfigurationInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py new file mode 100644 index 00000000..8492b968 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py @@ -0,0 +1,304 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbdatabase +version_added: "2.8" +short_description: Manage MariaDB Database instance +description: + - Create, update and delete instance of MariaDB Database. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + server_name: + description: + - The name of the server. + required: True + name: + description: + - The name of the database. + required: True + charset: + description: + - The charset of the database. Check MariaDB documentation for possible values. + - This is only set on creation, use I(force_update) to recreate a database if the values don't match. + collation: + description: + - The collation of the database. Check MariaDB documentation for possible values. + - This is only set on creation, use I(force_update) to recreate a database if the values don't match. + force_update: + description: + - When set to C(true), will delete and recreate the existing MariaDB database if any of the properties don't match what is set. + - When set to C(false), no change will occur to the database even if any of the properties do not match. + type: bool + default: 'no' + state: + description: + - Assert the state of the MariaDB Database. Use C(present) to create or update a database and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Create (or update) MariaDB Database + azure_rm_mariadbdatabase: + resource_group: myResourceGroup + server_name: testserver + name: db1 +''' + +RETURN = ''' +id: + description: + - Resource ID. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/databases/db1 +name: + description: + - Resource name. + returned: always + type: str + sample: db1 +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbDatabase(AzureRMModuleBase): + """Configuration class for an Azure RM MariaDB Database resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + charset=dict( + type='str' + ), + collation=dict( + type='str' + ), + force_update=dict( + type='bool', + default=False + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.server_name = None + self.name = None + self.force_update = None + self.parameters = dict() + + self.results = dict(changed=False) + self.mgmt_client = None + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbDatabase, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=False) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()): + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "charset": + self.parameters["charset"] = kwargs[key] + elif key == "collation": + self.parameters["collation"] = kwargs[key] + + old_response = None + response = None + + self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + resource_group = self.get_resource_group(self.resource_group) + + old_response = self.get_mariadbdatabase() + + if not old_response: + self.log("MariaDB Database instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("MariaDB Database instance already exists") + if self.state == 'absent': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if MariaDB Database instance has to be deleted or may be updated") + if ('collation' in self.parameters) and (self.parameters['collation'] != old_response['collation']): + self.to_do = Actions.Update + if ('charset' in self.parameters) and (self.parameters['charset'] != old_response['charset']): + self.to_do = Actions.Update + if self.to_do == Actions.Update: + if self.force_update: + if not self.check_mode: + self.delete_mariadbdatabase() + else: + self.fail("Database properties cannot be updated without setting 'force_update' option") + self.to_do = Actions.NoAction + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the MariaDB Database instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_mariadbdatabase() + self.results['changed'] = True + self.log("Creation / Update done") + elif self.to_do == Actions.Delete: + self.log("MariaDB Database instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_mariadbdatabase() + # make sure instance is actually deleted, for some Azure resources, instance is hanging around + # for some time after deletion -- this should be really fixed in Azure + while self.get_mariadbdatabase(): + time.sleep(20) + else: + self.log("MariaDB Database instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + self.results["name"] = response["name"] + + return self.results + + def create_update_mariadbdatabase(self): + ''' + Creates or updates MariaDB Database with the specified configuration. + + :return: deserialized MariaDB Database instance state dictionary + ''' + self.log("Creating / Updating the MariaDB Database instance {0}".format(self.name)) + + try: + response = self.mgmt_client.databases.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name, + parameters=self.parameters) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the MariaDB Database instance.') + self.fail("Error creating the MariaDB Database instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_mariadbdatabase(self): + ''' + Deletes specified MariaDB Database instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the MariaDB Database instance {0}".format(self.name)) + try: + response = self.mgmt_client.databases.delete(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name) + except CloudError as e: + self.log('Error attempting to delete the MariaDB Database instance.') + self.fail("Error deleting the MariaDB Database instance: {0}".format(str(e))) + + return True + + def get_mariadbdatabase(self): + ''' + Gets the properties of the specified MariaDB Database. + + :return: deserialized MariaDB Database instance state dictionary + ''' + self.log("Checking if the MariaDB Database instance {0} is present".format(self.name)) + found = False + try: + response = self.mgmt_client.databases.get(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("MariaDB Database instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the MariaDB Database instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbDatabase() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py new file mode 100644 index 00000000..e9c99c14 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py @@ -0,0 +1,212 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbdatabase_info +version_added: "2.9" +short_description: Get Azure MariaDB Database facts +description: + - Get facts of MariaDB Database. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + type: str + server_name: + description: + - The name of the server. + required: True + type: str + name: + description: + - The name of the database. + type: str + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get instance of MariaDB Database + azure_rm_mariadbdatabase_info: + resource_group: myResourceGroup + server_name: server_name + name: database_name + + - name: List instances of MariaDB Database + azure_rm_mariadbdatabase_info: + resource_group: myResourceGroup + server_name: server_name +''' + +RETURN = ''' +databases: + description: + - A list of dictionaries containing facts for MariaDB Databases. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testser + ver/databases/db1" + resource_group: + description: + - Resource group name. + returned: always + type: str + sample: testrg + server_name: + description: + - Server name. + returned: always + type: str + sample: testserver + name: + description: + - Resource name. + returned: always + type: str + sample: db1 + charset: + description: + - The charset of the database. + returned: always + type: str + sample: UTF8 + collation: + description: + - The collation of the database. + returned: always + type: str + sample: English_United States.1252 +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbDatabaseInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict( + changed=False + ) + self.resource_group = None + self.server_name = None + self.name = None + super(AzureRMMariaDbDatabaseInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbdatabase_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbdatabase_facts' module has been renamed to 'azure_rm_mariadbdatabase_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if (self.resource_group is not None and + self.server_name is not None and + self.name is not None): + self.results['databases'] = self.get() + elif (self.resource_group is not None and + self.server_name is not None): + self.results['databases'] = self.list_by_server() + return self.results + + def get(self): + response = None + results = [] + try: + response = self.mariadb_client.databases.get(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for Databases.') + + if response is not None: + results.append(self.format_item(response)) + + return results + + def list_by_server(self): + response = None + results = [] + try: + response = self.mariadb_client.databases.list_by_server(resource_group_name=self.resource_group, + server_name=self.server_name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.fail("Error listing for server {0} - {1}".format(self.server_name, str(e))) + + if response is not None: + for item in response: + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'resource_group': self.resource_group, + 'server_name': self.server_name, + 'name': d['name'], + 'charset': d['charset'], + 'collation': d['collation'] + } + return d + + +def main(): + AzureRMMariaDbDatabaseInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py new file mode 100644 index 00000000..1fc8c5e7 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py @@ -0,0 +1,277 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbfirewallrule +version_added: "2.8" +short_description: Manage MariaDB firewall rule instance +description: + - Create, update and delete instance of MariaDB firewall rule. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + server_name: + description: + - The name of the server. + required: True + name: + description: + - The name of the MariaDB firewall rule. + required: True + start_ip_address: + description: + - The start IP address of the MariaDB firewall rule. Must be IPv4 format. + end_ip_address: + description: + - The end IP address of the MariaDB firewall rule. Must be IPv4 format. + state: + description: + - Assert the state of the MariaDB firewall rule. Use C(present) to create or update a rule and C(absent) to ensure it is not present. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Create (or update) MariaDB firewall rule + azure_rm_mariadbfirewallrule: + resource_group: myResourceGroup + server_name: testserver + name: rule1 + start_ip_address: 10.0.0.17 + end_ip_address: 10.0.0.20 +''' + +RETURN = ''' +id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/fire + wallRules/rule1" +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbFirewallRule(AzureRMModuleBase): + """Configuration class for an Azure RM MariaDB firewall rule resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + start_ip_address=dict( + type='str' + ), + end_ip_address=dict( + type='str' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.server_name = None + self.name = None + self.start_ip_address = None + self.end_ip_address = None + + self.results = dict(changed=False) + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbFirewallRule, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=False) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()): + if hasattr(self, key): + setattr(self, key, kwargs[key]) + + old_response = None + response = None + + resource_group = self.get_resource_group(self.resource_group) + + old_response = self.get_firewallrule() + + if not old_response: + self.log("MariaDB firewall rule instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("MariaDB firewall rule instance already exists") + if self.state == 'absent': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if MariaDB firewall rule instance has to be deleted or may be updated") + if (self.start_ip_address is not None) and (self.start_ip_address != old_response['start_ip_address']): + self.to_do = Actions.Update + if (self.end_ip_address is not None) and (self.end_ip_address != old_response['end_ip_address']): + self.to_do = Actions.Update + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the MariaDB firewall rule instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_firewallrule() + + if not old_response: + self.results['changed'] = True + else: + self.results['changed'] = old_response.__ne__(response) + self.log("Creation / Update done") + elif self.to_do == Actions.Delete: + self.log("MariaDB firewall rule instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_firewallrule() + # make sure instance is actually deleted, for some Azure resources, instance is hanging around + # for some time after deletion -- this should be really fixed in Azure + while self.get_firewallrule(): + time.sleep(20) + else: + self.log("MariaDB firewall rule instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + + return self.results + + def create_update_firewallrule(self): + ''' + Creates or updates MariaDB firewall rule with the specified configuration. + + :return: deserialized MariaDB firewall rule instance state dictionary + ''' + self.log("Creating / Updating the MariaDB firewall rule instance {0}".format(self.name)) + + try: + response = self.mariadb_client.firewall_rules.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name, + start_ip_address=self.start_ip_address, + end_ip_address=self.end_ip_address) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the MariaDB firewall rule instance.') + self.fail("Error creating the MariaDB firewall rule instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_firewallrule(self): + ''' + Deletes specified MariaDB firewall rule instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the MariaDB firewall rule instance {0}".format(self.name)) + try: + response = self.mariadb_client.firewall_rules.delete(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name) + except CloudError as e: + self.log('Error attempting to delete the MariaDB firewall rule instance.') + self.fail("Error deleting the MariaDB firewall rule instance: {0}".format(str(e))) + + return True + + def get_firewallrule(self): + ''' + Gets the properties of the specified MariaDB firewall rule. + + :return: deserialized MariaDB firewall rule instance state dictionary + ''' + self.log("Checking if the MariaDB firewall rule instance {0} is present".format(self.name)) + found = False + try: + response = self.mariadb_client.firewall_rules.get(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("MariaDB firewall rule instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the MariaDB firewall rule instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbFirewallRule() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py new file mode 100644 index 00000000..ef71be8d --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py @@ -0,0 +1,208 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbfirewallrule_info +version_added: "2.9" +short_description: Get Azure MariaDB Firewall Rule facts +description: + - Get facts of Azure MariaDB Firewall Rule. + +options: + resource_group: + description: + - The name of the resource group. + required: True + type: str + server_name: + description: + - The name of the server. + required: True + type: str + name: + description: + - The name of the server firewall rule. + type: str + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get instance of MariaDB Firewall Rule + azure_rm_mariadbfirewallrule_info: + resource_group: myResourceGroup + server_name: server_name + name: firewall_rule_name + + - name: List instances of MariaDB Firewall Rule + azure_rm_mariadbfirewallrule_info: + resource_group: myResourceGroup + server_name: server_name +''' + +RETURN = ''' +rules: + description: + - A list of dictionaries containing facts for MariaDB Firewall Rule. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/TestGroup/providers/Microsoft.DBforMariaDB/servers/testserver/fire + wallRules/rule1" + server_name: + description: + - The name of the server. + returned: always + type: str + sample: testserver + name: + description: + - Resource name. + returned: always + type: str + sample: rule1 + start_ip_address: + description: + - The start IP address of the MariaDB firewall rule. + returned: always + type: str + sample: 10.0.0.16 + end_ip_address: + description: + - The end IP address of the MariaDB firewall rule. + returned: always + type: str + sample: 10.0.0.18 +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_operation import AzureOperationPoller + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbFirewallRuleInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict( + changed=False + ) + self.mgmt_client = None + self.resource_group = None + self.server_name = None + self.name = None + super(AzureRMMariaDbFirewallRuleInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbfirewallrule_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbfirewallrule_facts' module has been renamed to 'azure_rm_mariadbfirewallrule_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if (self.name is not None): + self.results['rules'] = self.get() + else: + self.results['rules'] = self.list_by_server() + return self.results + + def get(self): + response = None + results = [] + try: + response = self.mgmt_client.firewall_rules.get(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for FirewallRules.') + + if response is not None: + results.append(self.format_item(response)) + + return results + + def list_by_server(self): + response = None + results = [] + try: + response = self.mgmt_client.firewall_rules.list_by_server(resource_group_name=self.resource_group, + server_name=self.server_name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for FirewallRules.') + + if response is not None: + for item in response: + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'resource_group': self.resource_group, + 'id': d['id'], + 'server_name': self.server_name, + 'name': d['name'], + 'start_ip_address': d['start_ip_address'], + 'end_ip_address': d['end_ip_address'] + } + return d + + +def main(): + AzureRMMariaDbFirewallRuleInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbserver.py b/test/support/integration/plugins/modules/azure_rm_mariadbserver.py new file mode 100644 index 00000000..30a29988 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbserver.py @@ -0,0 +1,388 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbserver +version_added: "2.8" +short_description: Manage MariaDB Server instance +description: + - Create, update and delete instance of MariaDB Server. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + name: + description: + - The name of the server. + required: True + sku: + description: + - The SKU (pricing tier) of the server. + suboptions: + name: + description: + - The name of the SKU, typically, tier + family + cores, for example C(B_Gen4_1), C(GP_Gen5_8). + tier: + description: + - The tier of the particular SKU, for example C(Basic). + choices: + - basic + - standard + capacity: + description: + - The scale up/out capacity, representing server's compute units. + type: int + size: + description: + - The size code, to be interpreted by resource as appropriate. + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + storage_mb: + description: + - The maximum storage allowed for a server. + type: int + version: + description: + - Server version. + choices: + - 10.2 + enforce_ssl: + description: + - Enable SSL enforcement. + type: bool + default: False + admin_username: + description: + - The administrator's login name of a server. Can only be specified when the server is being created (and is required for creation). + admin_password: + description: + - The password of the administrator login. + create_mode: + description: + - Create mode of SQL Server. + default: Default + state: + description: + - Assert the state of the MariaDB Server. Use C(present) to create or update a server and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Create (or update) MariaDB Server + azure_rm_mariadbserver: + resource_group: myResourceGroup + name: testserver + sku: + name: B_Gen5_1 + tier: Basic + location: eastus + storage_mb: 1024 + enforce_ssl: True + version: 10.2 + admin_username: cloudsa + admin_password: password +''' + +RETURN = ''' +id: + description: + - Resource ID. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/mariadbsrv1b6dd89593 +version: + description: + - Server version. Possible values include C(10.2). + returned: always + type: str + sample: 10.2 +state: + description: + - A state of a server that is visible to user. Possible values include C(Ready), C(Dropping), C(Disabled). + returned: always + type: str + sample: Ready +fully_qualified_domain_name: + description: + - The fully qualified domain name of a server. + returned: always + type: str + sample: mariadbsrv1b6dd89593.mariadb.database.azure.com +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbServers(AzureRMModuleBase): + """Configuration class for an Azure RM MariaDB Server resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + sku=dict( + type='dict' + ), + location=dict( + type='str' + ), + storage_mb=dict( + type='int' + ), + version=dict( + type='str', + choices=['10.2'] + ), + enforce_ssl=dict( + type='bool', + default=False + ), + create_mode=dict( + type='str', + default='Default' + ), + admin_username=dict( + type='str' + ), + admin_password=dict( + type='str', + no_log=True + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.name = None + self.parameters = dict() + self.tags = None + + self.results = dict(changed=False) + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbServers, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "sku": + ev = kwargs[key] + if 'tier' in ev: + if ev['tier'] == 'basic': + ev['tier'] = 'Basic' + elif ev['tier'] == 'standard': + ev['tier'] = 'Standard' + self.parameters["sku"] = ev + elif key == "location": + self.parameters["location"] = kwargs[key] + elif key == "storage_mb": + self.parameters.setdefault("properties", {}).setdefault("storage_profile", {})["storage_mb"] = kwargs[key] + elif key == "version": + self.parameters.setdefault("properties", {})["version"] = kwargs[key] + elif key == "enforce_ssl": + self.parameters.setdefault("properties", {})["ssl_enforcement"] = 'Enabled' if kwargs[key] else 'Disabled' + elif key == "create_mode": + self.parameters.setdefault("properties", {})["create_mode"] = kwargs[key] + elif key == "admin_username": + self.parameters.setdefault("properties", {})["administrator_login"] = kwargs[key] + elif key == "admin_password": + self.parameters.setdefault("properties", {})["administrator_login_password"] = kwargs[key] + + old_response = None + response = None + + resource_group = self.get_resource_group(self.resource_group) + + if "location" not in self.parameters: + self.parameters["location"] = resource_group.location + + old_response = self.get_mariadbserver() + + if not old_response: + self.log("MariaDB Server instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("MariaDB Server instance already exists") + if self.state == 'absent': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if MariaDB Server instance has to be deleted or may be updated") + update_tags, newtags = self.update_tags(old_response.get('tags', {})) + if update_tags: + self.tags = newtags + self.to_do = Actions.Update + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the MariaDB Server instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_mariadbserver() + + if not old_response: + self.results['changed'] = True + else: + self.results['changed'] = old_response.__ne__(response) + self.log("Creation / Update done") + elif self.to_do == Actions.Delete: + self.log("MariaDB Server instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_mariadbserver() + # make sure instance is actually deleted, for some Azure resources, instance is hanging around + # for some time after deletion -- this should be really fixed in Azure + while self.get_mariadbserver(): + time.sleep(20) + else: + self.log("MariaDB Server instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + self.results["version"] = response["version"] + self.results["state"] = response["user_visible_state"] + self.results["fully_qualified_domain_name"] = response["fully_qualified_domain_name"] + + return self.results + + def create_update_mariadbserver(self): + ''' + Creates or updates MariaDB Server with the specified configuration. + + :return: deserialized MariaDB Server instance state dictionary + ''' + self.log("Creating / Updating the MariaDB Server instance {0}".format(self.name)) + + try: + self.parameters['tags'] = self.tags + if self.to_do == Actions.Create: + response = self.mariadb_client.servers.create(resource_group_name=self.resource_group, + server_name=self.name, + parameters=self.parameters) + else: + # structure of parameters for update must be changed + self.parameters.update(self.parameters.pop("properties", {})) + response = self.mariadb_client.servers.update(resource_group_name=self.resource_group, + server_name=self.name, + parameters=self.parameters) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the MariaDB Server instance.') + self.fail("Error creating the MariaDB Server instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_mariadbserver(self): + ''' + Deletes specified MariaDB Server instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the MariaDB Server instance {0}".format(self.name)) + try: + response = self.mariadb_client.servers.delete(resource_group_name=self.resource_group, + server_name=self.name) + except CloudError as e: + self.log('Error attempting to delete the MariaDB Server instance.') + self.fail("Error deleting the MariaDB Server instance: {0}".format(str(e))) + + return True + + def get_mariadbserver(self): + ''' + Gets the properties of the specified MariaDB Server. + + :return: deserialized MariaDB Server instance state dictionary + ''' + self.log("Checking if the MariaDB Server instance {0} is present".format(self.name)) + found = False + try: + response = self.mariadb_client.servers.get(resource_group_name=self.resource_group, + server_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("MariaDB Server instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the MariaDB Server instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbServers() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py new file mode 100644 index 00000000..464aa4d8 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py @@ -0,0 +1,265 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbserver_info +version_added: "2.9" +short_description: Get Azure MariaDB Server facts +description: + - Get facts of MariaDB Server. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + type: str + name: + description: + - The name of the server. + type: str + tags: + description: + - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'. + type: list + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get instance of MariaDB Server + azure_rm_mariadbserver_info: + resource_group: myResourceGroup + name: server_name + + - name: List instances of MariaDB Server + azure_rm_mariadbserver_info: + resource_group: myResourceGroup +''' + +RETURN = ''' +servers: + description: + - A list of dictionaries containing facts for MariaDB servers. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/myabdud1223 + resource_group: + description: + - Resource group name. + returned: always + type: str + sample: myResourceGroup + name: + description: + - Resource name. + returned: always + type: str + sample: myabdud1223 + location: + description: + - The location the resource resides in. + returned: always + type: str + sample: eastus + sku: + description: + - The SKU of the server. + returned: always + type: complex + contains: + name: + description: + - The name of the SKU. + returned: always + type: str + sample: GP_Gen4_2 + tier: + description: + - The tier of the particular SKU. + returned: always + type: str + sample: GeneralPurpose + capacity: + description: + - The scale capacity. + returned: always + type: int + sample: 2 + storage_mb: + description: + - The maximum storage allowed for a server. + returned: always + type: int + sample: 128000 + enforce_ssl: + description: + - Enable SSL enforcement. + returned: always + type: bool + sample: False + admin_username: + description: + - The administrator's login name of a server. + returned: always + type: str + sample: serveradmin + version: + description: + - Server version. + returned: always + type: str + sample: "9.6" + user_visible_state: + description: + - A state of a server that is visible to user. + returned: always + type: str + sample: Ready + fully_qualified_domain_name: + description: + - The fully qualified domain name of a server. + returned: always + type: str + sample: myabdud1223.mys.database.azure.com + tags: + description: + - Tags assigned to the resource. Dictionary of string:string pairs. + type: dict + sample: { tag1: abc } +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbServerInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str' + ), + tags=dict( + type='list' + ) + ) + # store the results of the module operation + self.results = dict( + changed=False + ) + self.resource_group = None + self.name = None + self.tags = None + super(AzureRMMariaDbServerInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbserver_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbserver_facts' module has been renamed to 'azure_rm_mariadbserver_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if (self.resource_group is not None and + self.name is not None): + self.results['servers'] = self.get() + elif (self.resource_group is not None): + self.results['servers'] = self.list_by_resource_group() + return self.results + + def get(self): + response = None + results = [] + try: + response = self.mariadb_client.servers.get(resource_group_name=self.resource_group, + server_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for MariaDB Server.') + + if response and self.has_tags(response.tags, self.tags): + results.append(self.format_item(response)) + + return results + + def list_by_resource_group(self): + response = None + results = [] + try: + response = self.mariadb_client.servers.list_by_resource_group(resource_group_name=self.resource_group) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for MariaDB Servers.') + + if response is not None: + for item in response: + if self.has_tags(item.tags, self.tags): + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'id': d['id'], + 'resource_group': self.resource_group, + 'name': d['name'], + 'sku': d['sku'], + 'location': d['location'], + 'storage_mb': d['storage_profile']['storage_mb'], + 'version': d['version'], + 'enforce_ssl': (d['ssl_enforcement'] == 'Enabled'), + 'admin_username': d['administrator_login'], + 'user_visible_state': d['user_visible_state'], + 'fully_qualified_domain_name': d['fully_qualified_domain_name'], + 'tags': d.get('tags') + } + + return d + + +def main(): + AzureRMMariaDbServerInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_resource.py b/test/support/integration/plugins/modules/azure_rm_resource.py new file mode 100644 index 00000000..6ea3e3bb --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_resource.py @@ -0,0 +1,427 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_resource +version_added: "2.6" +short_description: Create any Azure resource +description: + - Create, update or delete any Azure resource using Azure REST API. + - This module gives access to resources that are not supported via Ansible modules. + - Refer to U(https://docs.microsoft.com/en-us/rest/api/) regarding details related to specific resource REST API. + +options: + url: + description: + - Azure RM Resource URL. + api_version: + description: + - Specific API version to be used. + provider: + description: + - Provider type. + - Required if URL is not specified. + resource_group: + description: + - Resource group to be used. + - Required if URL is not specified. + resource_type: + description: + - Resource type. + - Required if URL is not specified. + resource_name: + description: + - Resource name. + - Required if URL Is not specified. + subresource: + description: + - List of subresources. + suboptions: + namespace: + description: + - Subresource namespace. + type: + description: + - Subresource type. + name: + description: + - Subresource name. + body: + description: + - The body of the HTTP request/response to the web service. + method: + description: + - The HTTP method of the request or response. It must be uppercase. + choices: + - GET + - PUT + - POST + - HEAD + - PATCH + - DELETE + - MERGE + default: "PUT" + status_code: + description: + - A valid, numeric, HTTP status code that signifies success of the request. Can also be comma separated list of status codes. + type: list + default: [ 200, 201, 202 ] + idempotency: + description: + - If enabled, idempotency check will be done by using I(method=GET) first and then comparing with I(body). + default: no + type: bool + polling_timeout: + description: + - If enabled, idempotency check will be done by using I(method=GET) first and then comparing with I(body). + default: 0 + type: int + version_added: "2.8" + polling_interval: + description: + - If enabled, idempotency check will be done by using I(method=GET) first and then comparing with I(body). + default: 60 + type: int + version_added: "2.8" + state: + description: + - Assert the state of the resource. Use C(present) to create or update resource or C(absent) to delete resource. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + +''' + +EXAMPLES = ''' + - name: Update scaleset info using azure_rm_resource + azure_rm_resource: + resource_group: myResourceGroup + provider: compute + resource_type: virtualmachinescalesets + resource_name: myVmss + api_version: "2017-12-01" + body: { body } +''' + +RETURN = ''' +response: + description: + - Response specific to resource type. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + type: str + returned: always + sample: "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Storage/storageAccounts/staccb57dc95183" + kind: + description: + - The kind of storage. + type: str + returned: always + sample: Storage + location: + description: + - The resource location, defaults to location of the resource group. + type: str + returned: always + sample: eastus + name: + description: + The storage account name. + type: str + returned: always + sample: staccb57dc95183 + properties: + description: + - The storage account's related properties. + type: dict + returned: always + sample: { + "creationTime": "2019-06-13T06:34:33.0996676Z", + "encryption": { + "keySource": "Microsoft.Storage", + "services": { + "blob": { + "enabled": true, + "lastEnabledTime": "2019-06-13T06:34:33.1934074Z" + }, + "file": { + "enabled": true, + "lastEnabledTime": "2019-06-13T06:34:33.1934074Z" + } + } + }, + "networkAcls": { + "bypass": "AzureServices", + "defaultAction": "Allow", + "ipRules": [], + "virtualNetworkRules": [] + }, + "primaryEndpoints": { + "blob": "https://staccb57dc95183.blob.core.windows.net/", + "file": "https://staccb57dc95183.file.core.windows.net/", + "queue": "https://staccb57dc95183.queue.core.windows.net/", + "table": "https://staccb57dc95183.table.core.windows.net/" + }, + "primaryLocation": "eastus", + "provisioningState": "Succeeded", + "secondaryLocation": "westus", + "statusOfPrimary": "available", + "statusOfSecondary": "available", + "supportsHttpsTrafficOnly": false + } + sku: + description: + - The storage account SKU. + type: dict + returned: always + sample: { + "name": "Standard_GRS", + "tier": "Standard" + } + tags: + description: + - Resource tags. + type: dict + returned: always + sample: { 'key1': 'value1' } + type: + description: + - The resource type. + type: str + returned: always + sample: "Microsoft.Storage/storageAccounts" + +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase +from ansible.module_utils.azure_rm_common_rest import GenericRestClient +from ansible.module_utils.common.dict_transformations import dict_merge + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.service_client import ServiceClient + from msrestazure.tools import resource_id, is_valid_resource_id + import json + +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMResource(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + url=dict( + type='str' + ), + provider=dict( + type='str', + ), + resource_group=dict( + type='str', + ), + resource_type=dict( + type='str', + ), + resource_name=dict( + type='str', + ), + subresource=dict( + type='list', + default=[] + ), + api_version=dict( + type='str' + ), + method=dict( + type='str', + default='PUT', + choices=["GET", "PUT", "POST", "HEAD", "PATCH", "DELETE", "MERGE"] + ), + body=dict( + type='raw' + ), + status_code=dict( + type='list', + default=[200, 201, 202] + ), + idempotency=dict( + type='bool', + default=False + ), + polling_timeout=dict( + type='int', + default=0 + ), + polling_interval=dict( + type='int', + default=60 + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + # store the results of the module operation + self.results = dict( + changed=False, + response=None + ) + self.mgmt_client = None + self.url = None + self.api_version = None + self.provider = None + self.resource_group = None + self.resource_type = None + self.resource_name = None + self.subresource_type = None + self.subresource_name = None + self.subresource = [] + self.method = None + self.status_code = [] + self.idempotency = False + self.polling_timeout = None + self.polling_interval = None + self.state = None + self.body = None + super(AzureRMResource, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(GenericRestClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if self.state == 'absent': + self.method = 'DELETE' + self.status_code.append(204) + + if self.url is None: + orphan = None + rargs = dict() + rargs['subscription'] = self.subscription_id + rargs['resource_group'] = self.resource_group + if not (self.provider is None or self.provider.lower().startswith('.microsoft')): + rargs['namespace'] = "Microsoft." + self.provider + else: + rargs['namespace'] = self.provider + + if self.resource_type is not None and self.resource_name is not None: + rargs['type'] = self.resource_type + rargs['name'] = self.resource_name + for i in range(len(self.subresource)): + resource_ns = self.subresource[i].get('namespace', None) + resource_type = self.subresource[i].get('type', None) + resource_name = self.subresource[i].get('name', None) + if resource_type is not None and resource_name is not None: + rargs['child_namespace_' + str(i + 1)] = resource_ns + rargs['child_type_' + str(i + 1)] = resource_type + rargs['child_name_' + str(i + 1)] = resource_name + else: + orphan = resource_type + else: + orphan = self.resource_type + + self.url = resource_id(**rargs) + + if orphan is not None: + self.url += '/' + orphan + + # if api_version was not specified, get latest one + if not self.api_version: + try: + # extract provider and resource type + if "/providers/" in self.url: + provider = self.url.split("/providers/")[1].split("/")[0] + resourceType = self.url.split(provider + "/")[1].split("/")[0] + url = "/subscriptions/" + self.subscription_id + "/providers/" + provider + api_versions = json.loads(self.mgmt_client.query(url, "GET", {'api-version': '2015-01-01'}, None, None, [200], 0, 0).text) + for rt in api_versions['resourceTypes']: + if rt['resourceType'].lower() == resourceType.lower(): + self.api_version = rt['apiVersions'][0] + break + else: + # if there's no provider in API version, assume Microsoft.Resources + self.api_version = '2018-05-01' + if not self.api_version: + self.fail("Couldn't find api version for {0}/{1}".format(provider, resourceType)) + except Exception as exc: + self.fail("Failed to obtain API version: {0}".format(str(exc))) + + query_parameters = {} + query_parameters['api-version'] = self.api_version + + header_parameters = {} + header_parameters['Content-Type'] = 'application/json; charset=utf-8' + + needs_update = True + response = None + + if self.idempotency: + original = self.mgmt_client.query(self.url, "GET", query_parameters, None, None, [200, 404], 0, 0) + + if original.status_code == 404: + if self.state == 'absent': + needs_update = False + else: + try: + response = json.loads(original.text) + needs_update = (dict_merge(response, self.body) != response) + except Exception: + pass + + if needs_update: + response = self.mgmt_client.query(self.url, + self.method, + query_parameters, + header_parameters, + self.body, + self.status_code, + self.polling_timeout, + self.polling_interval) + if self.state == 'present': + try: + response = json.loads(response.text) + except Exception: + response = response.text + else: + response = None + + self.results['response'] = response + self.results['changed'] = needs_update + + return self.results + + +def main(): + AzureRMResource() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_resource_info.py b/test/support/integration/plugins/modules/azure_rm_resource_info.py new file mode 100644 index 00000000..f797f662 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_resource_info.py @@ -0,0 +1,432 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_resource_info +version_added: "2.9" +short_description: Generic facts of Azure resources +description: + - Obtain facts of any resource using Azure REST API. + - This module gives access to resources that are not supported via Ansible modules. + - Refer to U(https://docs.microsoft.com/en-us/rest/api/) regarding details related to specific resource REST API. + +options: + url: + description: + - Azure RM Resource URL. + api_version: + description: + - Specific API version to be used. + provider: + description: + - Provider type, should be specified in no URL is given. + resource_group: + description: + - Resource group to be used. + - Required if URL is not specified. + resource_type: + description: + - Resource type. + resource_name: + description: + - Resource name. + subresource: + description: + - List of subresources. + suboptions: + namespace: + description: + - Subresource namespace. + type: + description: + - Subresource type. + name: + description: + - Subresource name. + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + +''' + +EXAMPLES = ''' + - name: Get scaleset info + azure_rm_resource_info: + resource_group: myResourceGroup + provider: compute + resource_type: virtualmachinescalesets + resource_name: myVmss + api_version: "2017-12-01" + + - name: Query all the resources in the resource group + azure_rm_resource_info: + resource_group: "{{ resource_group }}" + resource_type: resources +''' + +RETURN = ''' +response: + description: + - Response specific to resource type. + returned: always + type: complex + contains: + id: + description: + - Id of the Azure resource. + type: str + returned: always + sample: "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Compute/virtualMachines/myVM" + location: + description: + - Resource location. + type: str + returned: always + sample: eastus + name: + description: + - Resource name. + type: str + returned: always + sample: myVM + properties: + description: + - Specifies the virtual machine's property. + type: complex + returned: always + contains: + diagnosticsProfile: + description: + - Specifies the boot diagnostic settings state. + type: complex + returned: always + contains: + bootDiagnostics: + description: + - A debugging feature, which to view Console Output and Screenshot to diagnose VM status. + type: dict + returned: always + sample: { + "enabled": true, + "storageUri": "https://vxisurgdiag.blob.core.windows.net/" + } + hardwareProfile: + description: + - Specifies the hardware settings for the virtual machine. + type: dict + returned: always + sample: { + "vmSize": "Standard_D2s_v3" + } + networkProfile: + description: + - Specifies the network interfaces of the virtual machine. + type: complex + returned: always + contains: + networkInterfaces: + description: + - Describes a network interface reference. + type: list + returned: always + sample: + - { + "id": "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Network/networkInterfaces/myvm441" + } + osProfile: + description: + - Specifies the operating system settings for the virtual machine. + type: complex + returned: always + contains: + adminUsername: + description: + - Specifies the name of the administrator account. + type: str + returned: always + sample: azureuser + allowExtensionOperations: + description: + - Specifies whether extension operations should be allowed on the virtual machine. + - This may only be set to False when no extensions are present on the virtual machine. + type: bool + returned: always + sample: true + computerName: + description: + - Specifies the host OS name of the virtual machine. + type: str + returned: always + sample: myVM + requireGuestProvisionSignale: + description: + - Specifies the host require guest provision signal or not. + type: bool + returned: always + sample: true + secrets: + description: + - Specifies set of certificates that should be installed onto the virtual machine. + type: list + returned: always + sample: [] + linuxConfiguration: + description: + - Specifies the Linux operating system settings on the virtual machine. + type: dict + returned: when OS type is Linux + sample: { + "disablePasswordAuthentication": false, + "provisionVMAgent": true + } + provisioningState: + description: + - The provisioning state. + type: str + returned: always + sample: Succeeded + vmID: + description: + - Specifies the VM unique ID which is a 128-bits identifier that is encoded and stored in all Azure laaS VMs SMBIOS. + - It can be read using platform BIOS commands. + type: str + returned: always + sample: "eb86d9bb-6725-4787-a487-2e497d5b340c" + storageProfile: + description: + - Specifies the storage account type for the managed disk. + type: complex + returned: always + contains: + dataDisks: + description: + - Specifies the parameters that are used to add a data disk to virtual machine. + type: list + returned: always + sample: + - { + "caching": "None", + "createOption": "Attach", + "diskSizeGB": 1023, + "lun": 2, + "managedDisk": { + "id": "/subscriptions/xxxx....xxxx/resourceGroups/V-XISURG/providers/Microsoft.Compute/disks/testdisk2", + "storageAccountType": "StandardSSD_LRS" + }, + "name": "testdisk2" + } + - { + "caching": "None", + "createOption": "Attach", + "diskSizeGB": 1023, + "lun": 1, + "managedDisk": { + "id": "/subscriptions/xxxx...xxxx/resourceGroups/V-XISURG/providers/Microsoft.Compute/disks/testdisk3", + "storageAccountType": "StandardSSD_LRS" + }, + "name": "testdisk3" + } + + imageReference: + description: + - Specifies information about the image to use. + type: dict + returned: always + sample: { + "offer": "UbuntuServer", + "publisher": "Canonical", + "sku": "18.04-LTS", + "version": "latest" + } + osDisk: + description: + - Specifies information about the operating system disk used by the virtual machine. + type: dict + returned: always + sample: { + "caching": "ReadWrite", + "createOption": "FromImage", + "diskSizeGB": 30, + "managedDisk": { + "id": "/subscriptions/xxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Compute/disks/myVM_disk1_xxx", + "storageAccountType": "Premium_LRS" + }, + "name": "myVM_disk1_xxx", + "osType": "Linux" + } + type: + description: + - The type of identity used for the virtual machine. + type: str + returned: always + sample: "Microsoft.Compute/virtualMachines" +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase +from ansible.module_utils.azure_rm_common_rest import GenericRestClient + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.service_client import ServiceClient + from msrestazure.tools import resource_id, is_valid_resource_id + import json + +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMResourceInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + url=dict( + type='str' + ), + provider=dict( + type='str' + ), + resource_group=dict( + type='str' + ), + resource_type=dict( + type='str' + ), + resource_name=dict( + type='str' + ), + subresource=dict( + type='list', + default=[] + ), + api_version=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict( + response=[] + ) + self.mgmt_client = None + self.url = None + self.api_version = None + self.provider = None + self.resource_group = None + self.resource_type = None + self.resource_name = None + self.subresource = [] + super(AzureRMResourceInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_resource_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_resource_facts' module has been renamed to 'azure_rm_resource_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(GenericRestClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if self.url is None: + orphan = None + rargs = dict() + rargs['subscription'] = self.subscription_id + rargs['resource_group'] = self.resource_group + if not (self.provider is None or self.provider.lower().startswith('.microsoft')): + rargs['namespace'] = "Microsoft." + self.provider + else: + rargs['namespace'] = self.provider + + if self.resource_type is not None and self.resource_name is not None: + rargs['type'] = self.resource_type + rargs['name'] = self.resource_name + for i in range(len(self.subresource)): + resource_ns = self.subresource[i].get('namespace', None) + resource_type = self.subresource[i].get('type', None) + resource_name = self.subresource[i].get('name', None) + if resource_type is not None and resource_name is not None: + rargs['child_namespace_' + str(i + 1)] = resource_ns + rargs['child_type_' + str(i + 1)] = resource_type + rargs['child_name_' + str(i + 1)] = resource_name + else: + orphan = resource_type + else: + orphan = self.resource_type + + self.url = resource_id(**rargs) + + if orphan is not None: + self.url += '/' + orphan + + # if api_version was not specified, get latest one + if not self.api_version: + try: + # extract provider and resource type + if "/providers/" in self.url: + provider = self.url.split("/providers/")[1].split("/")[0] + resourceType = self.url.split(provider + "/")[1].split("/")[0] + url = "/subscriptions/" + self.subscription_id + "/providers/" + provider + api_versions = json.loads(self.mgmt_client.query(url, "GET", {'api-version': '2015-01-01'}, None, None, [200], 0, 0).text) + for rt in api_versions['resourceTypes']: + if rt['resourceType'].lower() == resourceType.lower(): + self.api_version = rt['apiVersions'][0] + break + else: + # if there's no provider in API version, assume Microsoft.Resources + self.api_version = '2018-05-01' + if not self.api_version: + self.fail("Couldn't find api version for {0}/{1}".format(provider, resourceType)) + except Exception as exc: + self.fail("Failed to obtain API version: {0}".format(str(exc))) + + self.results['url'] = self.url + + query_parameters = {} + query_parameters['api-version'] = self.api_version + + header_parameters = {} + header_parameters['Content-Type'] = 'application/json; charset=utf-8' + skiptoken = None + + while True: + if skiptoken: + query_parameters['skiptoken'] = skiptoken + response = self.mgmt_client.query(self.url, "GET", query_parameters, header_parameters, None, [200, 404], 0, 0) + try: + response = json.loads(response.text) + if isinstance(response, dict): + if response.get('value'): + self.results['response'] = self.results['response'] + response['value'] + skiptoken = response.get('nextLink') + else: + self.results['response'] = self.results['response'] + [response] + except Exception as e: + self.fail('Failed to parse response: ' + str(e)) + if not skiptoken: + break + return self.results + + +def main(): + AzureRMResourceInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_storageaccount.py b/test/support/integration/plugins/modules/azure_rm_storageaccount.py new file mode 100644 index 00000000..d4158bbd --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_storageaccount.py @@ -0,0 +1,684 @@ +#!/usr/bin/python +# +# Copyright (c) 2016 Matt Davis, <mdavis@ansible.com> +# Chris Houseknecht, <house@redhat.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_storageaccount +version_added: "2.1" +short_description: Manage Azure storage accounts +description: + - Create, update or delete a storage account. +options: + resource_group: + description: + - Name of the resource group to use. + required: true + aliases: + - resource_group_name + name: + description: + - Name of the storage account to update or create. + state: + description: + - State of the storage account. Use C(present) to create or update a storage account and use C(absent) to delete an account. + default: present + choices: + - absent + - present + location: + description: + - Valid Azure location. Defaults to location of the resource group. + account_type: + description: + - Type of storage account. Required when creating a storage account. + - C(Standard_ZRS) and C(Premium_LRS) accounts cannot be changed to other account types. + - Other account types cannot be changed to C(Standard_ZRS) or C(Premium_LRS). + choices: + - Premium_LRS + - Standard_GRS + - Standard_LRS + - StandardSSD_LRS + - Standard_RAGRS + - Standard_ZRS + - Premium_ZRS + aliases: + - type + custom_domain: + description: + - User domain assigned to the storage account. + - Must be a dictionary with I(name) and I(use_sub_domain) keys where I(name) is the CNAME source. + - Only one custom domain is supported per storage account at this time. + - To clear the existing custom domain, use an empty string for the custom domain name property. + - Can be added to an existing storage account. Will be ignored during storage account creation. + aliases: + - custom_dns_domain_suffix + kind: + description: + - The kind of storage. + default: 'Storage' + choices: + - Storage + - StorageV2 + - BlobStorage + version_added: "2.2" + access_tier: + description: + - The access tier for this storage account. Required when I(kind=BlobStorage). + choices: + - Hot + - Cool + version_added: "2.4" + force_delete_nonempty: + description: + - Attempt deletion if resource already exists and cannot be updated. + type: bool + aliases: + - force + https_only: + description: + - Allows https traffic only to storage service when set to C(true). + type: bool + version_added: "2.8" + blob_cors: + description: + - Specifies CORS rules for the Blob service. + - You can include up to five CorsRule elements in the request. + - If no blob_cors elements are included in the argument list, nothing about CORS will be changed. + - If you want to delete all CORS rules and disable CORS for the Blob service, explicitly set I(blob_cors=[]). + type: list + version_added: "2.8" + suboptions: + allowed_origins: + description: + - A list of origin domains that will be allowed via CORS, or "*" to allow all domains. + type: list + required: true + allowed_methods: + description: + - A list of HTTP methods that are allowed to be executed by the origin. + type: list + required: true + max_age_in_seconds: + description: + - The number of seconds that the client/browser should cache a preflight response. + type: int + required: true + exposed_headers: + description: + - A list of response headers to expose to CORS clients. + type: list + required: true + allowed_headers: + description: + - A list of headers allowed to be part of the cross-origin request. + type: list + required: true + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Chris Houseknecht (@chouseknecht) + - Matt Davis (@nitzmahone) +''' + +EXAMPLES = ''' + - name: remove account, if it exists + azure_rm_storageaccount: + resource_group: myResourceGroup + name: clh0002 + state: absent + + - name: create an account + azure_rm_storageaccount: + resource_group: myResourceGroup + name: clh0002 + type: Standard_RAGRS + tags: + testing: testing + delete: on-exit + + - name: create an account with blob CORS + azure_rm_storageaccount: + resource_group: myResourceGroup + name: clh002 + type: Standard_RAGRS + blob_cors: + - allowed_origins: + - http://www.example.com/ + allowed_methods: + - GET + - POST + allowed_headers: + - x-ms-meta-data* + - x-ms-meta-target* + - x-ms-meta-abc + exposed_headers: + - x-ms-meta-* + max_age_in_seconds: 200 +''' + + +RETURN = ''' +state: + description: + - Current state of the storage account. + returned: always + type: complex + contains: + account_type: + description: + - Type of storage account. + returned: always + type: str + sample: Standard_RAGRS + custom_domain: + description: + - User domain assigned to the storage account. + returned: always + type: complex + contains: + name: + description: + - CNAME source. + returned: always + type: str + sample: testaccount + use_sub_domain: + description: + - Whether to use sub domain. + returned: always + type: bool + sample: true + id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Storage/storageAccounts/clh0003" + location: + description: + - Valid Azure location. Defaults to location of the resource group. + returned: always + type: str + sample: eastus2 + name: + description: + - Name of the storage account to update or create. + returned: always + type: str + sample: clh0003 + primary_endpoints: + description: + - The URLs to retrieve the public I(blob), I(queue), or I(table) object from the primary location. + returned: always + type: dict + sample: { + "blob": "https://clh0003.blob.core.windows.net/", + "queue": "https://clh0003.queue.core.windows.net/", + "table": "https://clh0003.table.core.windows.net/" + } + primary_location: + description: + - The location of the primary data center for the storage account. + returned: always + type: str + sample: eastus2 + provisioning_state: + description: + - The status of the storage account. + - Possible values include C(Creating), C(ResolvingDNS), C(Succeeded). + returned: always + type: str + sample: Succeeded + resource_group: + description: + - The resource group's name. + returned: always + type: str + sample: Testing + secondary_endpoints: + description: + - The URLs to retrieve the public I(blob), I(queue), or I(table) object from the secondary location. + returned: always + type: dict + sample: { + "blob": "https://clh0003-secondary.blob.core.windows.net/", + "queue": "https://clh0003-secondary.queue.core.windows.net/", + "table": "https://clh0003-secondary.table.core.windows.net/" + } + secondary_location: + description: + - The location of the geo-replicated secondary for the storage account. + returned: always + type: str + sample: centralus + status_of_primary: + description: + - The status of the primary location of the storage account; either C(available) or C(unavailable). + returned: always + type: str + sample: available + status_of_secondary: + description: + - The status of the secondary location of the storage account; either C(available) or C(unavailable). + returned: always + type: str + sample: available + tags: + description: + - Resource tags. + returned: always + type: dict + sample: { 'tags1': 'value1' } + type: + description: + - The storage account type. + returned: always + type: str + sample: "Microsoft.Storage/storageAccounts" +''' + +try: + from msrestazure.azure_exceptions import CloudError + from azure.storage.cloudstorageaccount import CloudStorageAccount + from azure.common import AzureMissingResourceHttpError +except ImportError: + # This is handled in azure_rm_common + pass + +import copy +from ansible.module_utils.azure_rm_common import AZURE_SUCCESS_STATE, AzureRMModuleBase +from ansible.module_utils._text import to_native + +cors_rule_spec = dict( + allowed_origins=dict(type='list', elements='str', required=True), + allowed_methods=dict(type='list', elements='str', required=True), + max_age_in_seconds=dict(type='int', required=True), + exposed_headers=dict(type='list', elements='str', required=True), + allowed_headers=dict(type='list', elements='str', required=True), +) + + +def compare_cors(cors1, cors2): + if len(cors1) != len(cors2): + return False + copy2 = copy.copy(cors2) + for rule1 in cors1: + matched = False + for rule2 in copy2: + if (rule1['max_age_in_seconds'] == rule2['max_age_in_seconds'] + and set(rule1['allowed_methods']) == set(rule2['allowed_methods']) + and set(rule1['allowed_origins']) == set(rule2['allowed_origins']) + and set(rule1['allowed_headers']) == set(rule2['allowed_headers']) + and set(rule1['exposed_headers']) == set(rule2['exposed_headers'])): + matched = True + copy2.remove(rule2) + if not matched: + return False + return True + + +class AzureRMStorageAccount(AzureRMModuleBase): + + def __init__(self): + + self.module_arg_spec = dict( + account_type=dict(type='str', + choices=['Premium_LRS', 'Standard_GRS', 'Standard_LRS', 'StandardSSD_LRS', 'Standard_RAGRS', 'Standard_ZRS', 'Premium_ZRS'], + aliases=['type']), + custom_domain=dict(type='dict', aliases=['custom_dns_domain_suffix']), + location=dict(type='str'), + name=dict(type='str', required=True), + resource_group=dict(required=True, type='str', aliases=['resource_group_name']), + state=dict(default='present', choices=['present', 'absent']), + force_delete_nonempty=dict(type='bool', default=False, aliases=['force']), + tags=dict(type='dict'), + kind=dict(type='str', default='Storage', choices=['Storage', 'StorageV2', 'BlobStorage']), + access_tier=dict(type='str', choices=['Hot', 'Cool']), + https_only=dict(type='bool', default=False), + blob_cors=dict(type='list', options=cors_rule_spec, elements='dict') + ) + + self.results = dict( + changed=False, + state=dict() + ) + + self.account_dict = None + self.resource_group = None + self.name = None + self.state = None + self.location = None + self.account_type = None + self.custom_domain = None + self.tags = None + self.force_delete_nonempty = None + self.kind = None + self.access_tier = None + self.https_only = None + self.blob_cors = None + + super(AzureRMStorageAccount, self).__init__(self.module_arg_spec, + supports_check_mode=True) + + def exec_module(self, **kwargs): + + for key in list(self.module_arg_spec.keys()) + ['tags']: + setattr(self, key, kwargs[key]) + + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + # Set default location + self.location = resource_group.location + + if len(self.name) < 3 or len(self.name) > 24: + self.fail("Parameter error: name length must be between 3 and 24 characters.") + + if self.custom_domain: + if self.custom_domain.get('name', None) is None: + self.fail("Parameter error: expecting custom_domain to have a name attribute of type string.") + if self.custom_domain.get('use_sub_domain', None) is None: + self.fail("Parameter error: expecting custom_domain to have a use_sub_domain " + "attribute of type boolean.") + + self.account_dict = self.get_account() + + if self.state == 'present' and self.account_dict and \ + self.account_dict['provisioning_state'] != AZURE_SUCCESS_STATE: + self.fail("Error: storage account {0} has not completed provisioning. State is {1}. Expecting state " + "to be {2}.".format(self.name, self.account_dict['provisioning_state'], AZURE_SUCCESS_STATE)) + + if self.account_dict is not None: + self.results['state'] = self.account_dict + else: + self.results['state'] = dict() + + if self.state == 'present': + if not self.account_dict: + self.results['state'] = self.create_account() + else: + self.update_account() + elif self.state == 'absent' and self.account_dict: + self.delete_account() + self.results['state'] = dict(Status='Deleted') + + return self.results + + def check_name_availability(self): + self.log('Checking name availability for {0}'.format(self.name)) + try: + response = self.storage_client.storage_accounts.check_name_availability(self.name) + except CloudError as e: + self.log('Error attempting to validate name.') + self.fail("Error checking name availability: {0}".format(str(e))) + if not response.name_available: + self.log('Error name not available.') + self.fail("{0} - {1}".format(response.message, response.reason)) + + def get_account(self): + self.log('Get properties for account {0}'.format(self.name)) + account_obj = None + blob_service_props = None + account_dict = None + + try: + account_obj = self.storage_client.storage_accounts.get_properties(self.resource_group, self.name) + blob_service_props = self.storage_client.blob_services.get_service_properties(self.resource_group, self.name) + except CloudError: + pass + + if account_obj: + account_dict = self.account_obj_to_dict(account_obj, blob_service_props) + + return account_dict + + def account_obj_to_dict(self, account_obj, blob_service_props=None): + account_dict = dict( + id=account_obj.id, + name=account_obj.name, + location=account_obj.location, + resource_group=self.resource_group, + type=account_obj.type, + access_tier=(account_obj.access_tier.value + if account_obj.access_tier is not None else None), + sku_tier=account_obj.sku.tier.value, + sku_name=account_obj.sku.name.value, + provisioning_state=account_obj.provisioning_state.value, + secondary_location=account_obj.secondary_location, + status_of_primary=(account_obj.status_of_primary.value + if account_obj.status_of_primary is not None else None), + status_of_secondary=(account_obj.status_of_secondary.value + if account_obj.status_of_secondary is not None else None), + primary_location=account_obj.primary_location, + https_only=account_obj.enable_https_traffic_only + ) + account_dict['custom_domain'] = None + if account_obj.custom_domain: + account_dict['custom_domain'] = dict( + name=account_obj.custom_domain.name, + use_sub_domain=account_obj.custom_domain.use_sub_domain + ) + + account_dict['primary_endpoints'] = None + if account_obj.primary_endpoints: + account_dict['primary_endpoints'] = dict( + blob=account_obj.primary_endpoints.blob, + queue=account_obj.primary_endpoints.queue, + table=account_obj.primary_endpoints.table + ) + account_dict['secondary_endpoints'] = None + if account_obj.secondary_endpoints: + account_dict['secondary_endpoints'] = dict( + blob=account_obj.secondary_endpoints.blob, + queue=account_obj.secondary_endpoints.queue, + table=account_obj.secondary_endpoints.table + ) + account_dict['tags'] = None + if account_obj.tags: + account_dict['tags'] = account_obj.tags + if blob_service_props and blob_service_props.cors and blob_service_props.cors.cors_rules: + account_dict['blob_cors'] = [dict( + allowed_origins=[to_native(y) for y in x.allowed_origins], + allowed_methods=[to_native(y) for y in x.allowed_methods], + max_age_in_seconds=x.max_age_in_seconds, + exposed_headers=[to_native(y) for y in x.exposed_headers], + allowed_headers=[to_native(y) for y in x.allowed_headers] + ) for x in blob_service_props.cors.cors_rules] + return account_dict + + def update_account(self): + self.log('Update storage account {0}'.format(self.name)) + if bool(self.https_only) != bool(self.account_dict.get('https_only')): + self.results['changed'] = True + self.account_dict['https_only'] = self.https_only + if not self.check_mode: + try: + parameters = self.storage_models.StorageAccountUpdateParameters(enable_https_traffic_only=self.https_only) + self.storage_client.storage_accounts.update(self.resource_group, + self.name, + parameters) + except Exception as exc: + self.fail("Failed to update account type: {0}".format(str(exc))) + + if self.account_type: + if self.account_type != self.account_dict['sku_name']: + # change the account type + SkuName = self.storage_models.SkuName + if self.account_dict['sku_name'] in [SkuName.premium_lrs, SkuName.standard_zrs]: + self.fail("Storage accounts of type {0} and {1} cannot be changed.".format( + SkuName.premium_lrs, SkuName.standard_zrs)) + if self.account_type in [SkuName.premium_lrs, SkuName.standard_zrs]: + self.fail("Storage account of type {0} cannot be changed to a type of {1} or {2}.".format( + self.account_dict['sku_name'], SkuName.premium_lrs, SkuName.standard_zrs)) + + self.results['changed'] = True + self.account_dict['sku_name'] = self.account_type + + if self.results['changed'] and not self.check_mode: + # Perform the update. The API only allows changing one attribute per call. + try: + self.log("sku_name: %s" % self.account_dict['sku_name']) + self.log("sku_tier: %s" % self.account_dict['sku_tier']) + sku = self.storage_models.Sku(name=SkuName(self.account_dict['sku_name'])) + sku.tier = self.storage_models.SkuTier(self.account_dict['sku_tier']) + parameters = self.storage_models.StorageAccountUpdateParameters(sku=sku) + self.storage_client.storage_accounts.update(self.resource_group, + self.name, + parameters) + except Exception as exc: + self.fail("Failed to update account type: {0}".format(str(exc))) + + if self.custom_domain: + if not self.account_dict['custom_domain'] or self.account_dict['custom_domain'] != self.custom_domain: + self.results['changed'] = True + self.account_dict['custom_domain'] = self.custom_domain + + if self.results['changed'] and not self.check_mode: + new_domain = self.storage_models.CustomDomain(name=self.custom_domain['name'], + use_sub_domain=self.custom_domain['use_sub_domain']) + parameters = self.storage_models.StorageAccountUpdateParameters(custom_domain=new_domain) + try: + self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters) + except Exception as exc: + self.fail("Failed to update custom domain: {0}".format(str(exc))) + + if self.access_tier: + if not self.account_dict['access_tier'] or self.account_dict['access_tier'] != self.access_tier: + self.results['changed'] = True + self.account_dict['access_tier'] = self.access_tier + + if self.results['changed'] and not self.check_mode: + parameters = self.storage_models.StorageAccountUpdateParameters(access_tier=self.access_tier) + try: + self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters) + except Exception as exc: + self.fail("Failed to update access tier: {0}".format(str(exc))) + + update_tags, self.account_dict['tags'] = self.update_tags(self.account_dict['tags']) + if update_tags: + self.results['changed'] = True + if not self.check_mode: + parameters = self.storage_models.StorageAccountUpdateParameters(tags=self.account_dict['tags']) + try: + self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters) + except Exception as exc: + self.fail("Failed to update tags: {0}".format(str(exc))) + + if self.blob_cors and not compare_cors(self.account_dict.get('blob_cors', []), self.blob_cors): + self.results['changed'] = True + if not self.check_mode: + self.set_blob_cors() + + def create_account(self): + self.log("Creating account {0}".format(self.name)) + + if not self.location: + self.fail('Parameter error: location required when creating a storage account.') + + if not self.account_type: + self.fail('Parameter error: account_type required when creating a storage account.') + + if not self.access_tier and self.kind == 'BlobStorage': + self.fail('Parameter error: access_tier required when creating a storage account of type BlobStorage.') + + self.check_name_availability() + self.results['changed'] = True + + if self.check_mode: + account_dict = dict( + location=self.location, + account_type=self.account_type, + name=self.name, + resource_group=self.resource_group, + enable_https_traffic_only=self.https_only, + tags=dict() + ) + if self.tags: + account_dict['tags'] = self.tags + if self.blob_cors: + account_dict['blob_cors'] = self.blob_cors + return account_dict + sku = self.storage_models.Sku(name=self.storage_models.SkuName(self.account_type)) + sku.tier = self.storage_models.SkuTier.standard if 'Standard' in self.account_type else \ + self.storage_models.SkuTier.premium + parameters = self.storage_models.StorageAccountCreateParameters(sku=sku, + kind=self.kind, + location=self.location, + tags=self.tags, + access_tier=self.access_tier) + self.log(str(parameters)) + try: + poller = self.storage_client.storage_accounts.create(self.resource_group, self.name, parameters) + self.get_poller_result(poller) + except CloudError as e: + self.log('Error creating storage account.') + self.fail("Failed to create account: {0}".format(str(e))) + if self.blob_cors: + self.set_blob_cors() + # the poller doesn't actually return anything + return self.get_account() + + def delete_account(self): + if self.account_dict['provisioning_state'] == self.storage_models.ProvisioningState.succeeded.value and \ + not self.force_delete_nonempty and self.account_has_blob_containers(): + self.fail("Account contains blob containers. Is it in use? Use the force_delete_nonempty option to attempt deletion.") + + self.log('Delete storage account {0}'.format(self.name)) + self.results['changed'] = True + if not self.check_mode: + try: + status = self.storage_client.storage_accounts.delete(self.resource_group, self.name) + self.log("delete status: ") + self.log(str(status)) + except CloudError as e: + self.fail("Failed to delete the account: {0}".format(str(e))) + return True + + def account_has_blob_containers(self): + ''' + If there are blob containers, then there are likely VMs depending on this account and it should + not be deleted. + ''' + self.log('Checking for existing blob containers') + blob_service = self.get_blob_client(self.resource_group, self.name) + try: + response = blob_service.list_containers() + except AzureMissingResourceHttpError: + # No blob storage available? + return False + + if len(response.items) > 0: + return True + return False + + def set_blob_cors(self): + try: + cors_rules = self.storage_models.CorsRules(cors_rules=[self.storage_models.CorsRule(**x) for x in self.blob_cors]) + self.storage_client.blob_services.set_service_properties(self.resource_group, + self.name, + self.storage_models.BlobServiceProperties(cors=cors_rules)) + except Exception as exc: + self.fail("Failed to set CORS rules: {0}".format(str(exc))) + + +def main(): + AzureRMStorageAccount() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_webapp.py b/test/support/integration/plugins/modules/azure_rm_webapp.py new file mode 100644 index 00000000..4f185f45 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_webapp.py @@ -0,0 +1,1070 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_webapp +version_added: "2.7" +short_description: Manage Web App instances +description: + - Create, update and delete instance of Web App. + +options: + resource_group: + description: + - Name of the resource group to which the resource belongs. + required: True + name: + description: + - Unique name of the app to create or update. To create or update a deployment slot, use the {slot} parameter. + required: True + + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + + plan: + description: + - App service plan. Required for creation. + - Can be name of existing app service plan in same resource group as web app. + - Can be the resource ID of an existing app service plan. For example + /subscriptions/<subs_id>/resourceGroups/<resource_group>/providers/Microsoft.Web/serverFarms/<plan_name>. + - Can be a dict containing five parameters, defined below. + - C(name), name of app service plan. + - C(resource_group), resource group of the app service plan. + - C(sku), SKU of app service plan, allowed values listed on U(https://azure.microsoft.com/en-us/pricing/details/app-service/linux/). + - C(is_linux), whether or not the app service plan is Linux. defaults to C(False). + - C(number_of_workers), number of workers for app service plan. + + frameworks: + description: + - Set of run time framework settings. Each setting is a dictionary. + - See U(https://docs.microsoft.com/en-us/azure/app-service/app-service-web-overview) for more info. + suboptions: + name: + description: + - Name of the framework. + - Supported framework list for Windows web app and Linux web app is different. + - Windows web apps support C(java), C(net_framework), C(php), C(python), and C(node) from June 2018. + - Windows web apps support multiple framework at the same time. + - Linux web apps support C(java), C(ruby), C(php), C(dotnetcore), and C(node) from June 2018. + - Linux web apps support only one framework. + - Java framework is mutually exclusive with others. + choices: + - java + - net_framework + - php + - python + - ruby + - dotnetcore + - node + version: + description: + - Version of the framework. For Linux web app supported value, see U(https://aka.ms/linux-stacks) for more info. + - C(net_framework) supported value sample, C(v4.0) for .NET 4.6 and C(v3.0) for .NET 3.5. + - C(php) supported value sample, C(5.5), C(5.6), C(7.0). + - C(python) supported value sample, C(5.5), C(5.6), C(7.0). + - C(node) supported value sample, C(6.6), C(6.9). + - C(dotnetcore) supported value sample, C(1.0), C(1.1), C(1.2). + - C(ruby) supported value sample, C(2.3). + - C(java) supported value sample, C(1.9) for Windows web app. C(1.8) for Linux web app. + settings: + description: + - List of settings of the framework. + suboptions: + java_container: + description: + - Name of Java container. + - Supported only when I(frameworks=java). Sample values C(Tomcat), C(Jetty). + java_container_version: + description: + - Version of Java container. + - Supported only when I(frameworks=java). + - Sample values for C(Tomcat), C(8.0), C(8.5), C(9.0). For C(Jetty,), C(9.1), C(9.3). + + container_settings: + description: + - Web app container settings. + suboptions: + name: + description: + - Name of container, for example C(imagename:tag). + registry_server_url: + description: + - Container registry server URL, for example C(mydockerregistry.io). + registry_server_user: + description: + - The container registry server user name. + registry_server_password: + description: + - The container registry server password. + + scm_type: + description: + - Repository type of deployment source, for example C(LocalGit), C(GitHub). + - List of supported values maintained at U(https://docs.microsoft.com/en-us/rest/api/appservice/webapps/createorupdate#scmtype). + + deployment_source: + description: + - Deployment source for git. + suboptions: + url: + description: + - Repository url of deployment source. + + branch: + description: + - The branch name of the repository. + startup_file: + description: + - The web's startup file. + - Used only for Linux web apps. + + client_affinity_enabled: + description: + - Whether or not to send session affinity cookies, which route client requests in the same session to the same instance. + type: bool + default: True + + https_only: + description: + - Configures web site to accept only https requests. + type: bool + + dns_registration: + description: + - Whether or not the web app hostname is registered with DNS on creation. Set to C(false) to register. + type: bool + + skip_custom_domain_verification: + description: + - Whether or not to skip verification of custom (non *.azurewebsites.net) domains associated with web app. Set to C(true) to skip. + type: bool + + ttl_in_seconds: + description: + - Time to live in seconds for web app default domain name. + + app_settings: + description: + - Configure web app application settings. Suboptions are in key value pair format. + + purge_app_settings: + description: + - Purge any existing application settings. Replace web app application settings with app_settings. + type: bool + + app_state: + description: + - Start/Stop/Restart the web app. + type: str + choices: + - started + - stopped + - restarted + default: started + + state: + description: + - State of the Web App. + - Use C(present) to create or update a Web App and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Yunge Zhu (@yungezz) + +''' + +EXAMPLES = ''' + - name: Create a windows web app with non-exist app service plan + azure_rm_webapp: + resource_group: myResourceGroup + name: myWinWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + is_linux: false + sku: S1 + + - name: Create a docker web app with some app settings, with docker image + azure_rm_webapp: + resource_group: myResourceGroup + name: myDockerWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + is_linux: true + sku: S1 + number_of_workers: 2 + app_settings: + testkey: testvalue + testkey2: testvalue2 + container_settings: + name: ansible/ansible:ubuntu1404 + + - name: Create a docker web app with private acr registry + azure_rm_webapp: + resource_group: myResourceGroup + name: myDockerWebapp + plan: myAppServicePlan + app_settings: + testkey: testvalue + container_settings: + name: ansible/ubuntu1404 + registry_server_url: myregistry.io + registry_server_user: user + registry_server_password: pass + + - name: Create a linux web app with Node 6.6 framework + azure_rm_webapp: + resource_group: myResourceGroup + name: myLinuxWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey: testvalue + frameworks: + - name: "node" + version: "6.6" + + - name: Create a windows web app with node, php + azure_rm_webapp: + resource_group: myResourceGroup + name: myWinWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey: testvalue + frameworks: + - name: "node" + version: 6.6 + - name: "php" + version: "7.0" + + - name: Create a stage deployment slot for an existing web app + azure_rm_webapp: + resource_group: myResourceGroup + name: myWebapp/slots/stage + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey:testvalue + + - name: Create a linux web app with java framework + azure_rm_webapp: + resource_group: myResourceGroup + name: myLinuxWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey: testvalue + frameworks: + - name: "java" + version: "8" + settings: + java_container: "Tomcat" + java_container_version: "8.5" +''' + +RETURN = ''' +azure_webapp: + description: + - ID of current web app. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myWebApp" +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model + from azure.mgmt.web.models import ( + site_config, app_service_plan, Site, + AppServicePlan, SkuDescription, NameValuePair + ) +except ImportError: + # This is handled in azure_rm_common + pass + +container_settings_spec = dict( + name=dict(type='str', required=True), + registry_server_url=dict(type='str'), + registry_server_user=dict(type='str'), + registry_server_password=dict(type='str', no_log=True) +) + +deployment_source_spec = dict( + url=dict(type='str'), + branch=dict(type='str') +) + + +framework_settings_spec = dict( + java_container=dict(type='str', required=True), + java_container_version=dict(type='str', required=True) +) + + +framework_spec = dict( + name=dict( + type='str', + required=True, + choices=['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby']), + version=dict(type='str', required=True), + settings=dict(type='dict', options=framework_settings_spec) +) + + +def _normalize_sku(sku): + if sku is None: + return sku + + sku = sku.upper() + if sku == 'FREE': + return 'F1' + elif sku == 'SHARED': + return 'D1' + return sku + + +def get_sku_name(tier): + tier = tier.upper() + if tier == 'F1' or tier == "FREE": + return 'FREE' + elif tier == 'D1' or tier == "SHARED": + return 'SHARED' + elif tier in ['B1', 'B2', 'B3', 'BASIC']: + return 'BASIC' + elif tier in ['S1', 'S2', 'S3']: + return 'STANDARD' + elif tier in ['P1', 'P2', 'P3']: + return 'PREMIUM' + elif tier in ['P1V2', 'P2V2', 'P3V2']: + return 'PREMIUMV2' + else: + return None + + +def appserviceplan_to_dict(plan): + return dict( + id=plan.id, + name=plan.name, + kind=plan.kind, + location=plan.location, + reserved=plan.reserved, + is_linux=plan.reserved, + provisioning_state=plan.provisioning_state, + tags=plan.tags if plan.tags else None + ) + + +def webapp_to_dict(webapp): + return dict( + id=webapp.id, + name=webapp.name, + location=webapp.location, + client_cert_enabled=webapp.client_cert_enabled, + enabled=webapp.enabled, + reserved=webapp.reserved, + client_affinity_enabled=webapp.client_affinity_enabled, + server_farm_id=webapp.server_farm_id, + host_names_disabled=webapp.host_names_disabled, + https_only=webapp.https_only if hasattr(webapp, 'https_only') else None, + skip_custom_domain_verification=webapp.skip_custom_domain_verification if hasattr(webapp, 'skip_custom_domain_verification') else None, + ttl_in_seconds=webapp.ttl_in_seconds if hasattr(webapp, 'ttl_in_seconds') else None, + state=webapp.state, + tags=webapp.tags if webapp.tags else None + ) + + +class Actions: + CreateOrUpdate, UpdateAppSettings, Delete = range(3) + + +class AzureRMWebApps(AzureRMModuleBase): + """Configuration class for an Azure RM Web App resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + location=dict( + type='str' + ), + plan=dict( + type='raw' + ), + frameworks=dict( + type='list', + elements='dict', + options=framework_spec + ), + container_settings=dict( + type='dict', + options=container_settings_spec + ), + scm_type=dict( + type='str', + ), + deployment_source=dict( + type='dict', + options=deployment_source_spec + ), + startup_file=dict( + type='str' + ), + client_affinity_enabled=dict( + type='bool', + default=True + ), + dns_registration=dict( + type='bool' + ), + https_only=dict( + type='bool' + ), + skip_custom_domain_verification=dict( + type='bool' + ), + ttl_in_seconds=dict( + type='int' + ), + app_settings=dict( + type='dict' + ), + purge_app_settings=dict( + type='bool', + default=False + ), + app_state=dict( + type='str', + choices=['started', 'stopped', 'restarted'], + default='started' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + mutually_exclusive = [['container_settings', 'frameworks']] + + self.resource_group = None + self.name = None + self.location = None + + # update in create_or_update as parameters + self.client_affinity_enabled = True + self.dns_registration = None + self.skip_custom_domain_verification = None + self.ttl_in_seconds = None + self.https_only = None + + self.tags = None + + # site config, e.g app settings, ssl + self.site_config = dict() + self.app_settings = dict() + self.app_settings_strDic = None + + # app service plan + self.plan = None + + # siteSourceControl + self.deployment_source = dict() + + # site, used at level creation, or update. e.g windows/linux, client_affinity etc first level args + self.site = None + + # property for internal usage, not used for sdk + self.container_settings = None + + self.purge_app_settings = False + self.app_state = 'started' + + self.results = dict( + changed=False, + id=None, + ) + self.state = None + self.to_do = [] + + self.frameworks = None + + # set site_config value from kwargs + self.site_config_updatable_properties = ["net_framework_version", + "java_version", + "php_version", + "python_version", + "scm_type"] + + # updatable_properties + self.updatable_properties = ["client_affinity_enabled", + "force_dns_registration", + "https_only", + "skip_custom_domain_verification", + "ttl_in_seconds"] + + self.supported_linux_frameworks = ['ruby', 'php', 'dotnetcore', 'node', 'java'] + self.supported_windows_frameworks = ['net_framework', 'php', 'python', 'node', 'java'] + + super(AzureRMWebApps, self).__init__(derived_arg_spec=self.module_arg_spec, + mutually_exclusive=mutually_exclusive, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "scm_type": + self.site_config[key] = kwargs[key] + + old_response = None + response = None + to_be_updated = False + + # set location + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + self.location = resource_group.location + + # get existing web app + old_response = self.get_webapp() + + if old_response: + self.results['id'] = old_response['id'] + + if self.state == 'present': + if not self.plan and not old_response: + self.fail("Please specify plan for newly created web app.") + + if not self.plan: + self.plan = old_response['server_farm_id'] + + self.plan = self.parse_resource_to_dict(self.plan) + + # get app service plan + is_linux = False + old_plan = self.get_app_service_plan() + if old_plan: + is_linux = old_plan['reserved'] + else: + is_linux = self.plan['is_linux'] if 'is_linux' in self.plan else False + + if self.frameworks: + # java is mutually exclusive with other frameworks + if len(self.frameworks) > 1 and any(f['name'] == 'java' for f in self.frameworks): + self.fail('Java is mutually exclusive with other frameworks.') + + if is_linux: + if len(self.frameworks) != 1: + self.fail('Can specify one framework only for Linux web app.') + + if self.frameworks[0]['name'] not in self.supported_linux_frameworks: + self.fail('Unsupported framework {0} for Linux web app.'.format(self.frameworks[0]['name'])) + + self.site_config['linux_fx_version'] = (self.frameworks[0]['name'] + '|' + self.frameworks[0]['version']).upper() + + if self.frameworks[0]['name'] == 'java': + if self.frameworks[0]['version'] != '8': + self.fail("Linux web app only supports java 8.") + if self.frameworks[0]['settings'] and self.frameworks[0]['settings']['java_container'].lower() != 'tomcat': + self.fail("Linux web app only supports tomcat container.") + + if self.frameworks[0]['settings'] and self.frameworks[0]['settings']['java_container'].lower() == 'tomcat': + self.site_config['linux_fx_version'] = 'TOMCAT|' + self.frameworks[0]['settings']['java_container_version'] + '-jre8' + else: + self.site_config['linux_fx_version'] = 'JAVA|8-jre8' + else: + for fx in self.frameworks: + if fx.get('name') not in self.supported_windows_frameworks: + self.fail('Unsupported framework {0} for Windows web app.'.format(fx.get('name'))) + else: + self.site_config[fx.get('name') + '_version'] = fx.get('version') + + if 'settings' in fx and fx['settings'] is not None: + for key, value in fx['settings'].items(): + self.site_config[key] = value + + if not self.app_settings: + self.app_settings = dict() + + if self.container_settings: + linux_fx_version = 'DOCKER|' + + if self.container_settings.get('registry_server_url'): + self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url'] + + linux_fx_version += self.container_settings['registry_server_url'] + '/' + + linux_fx_version += self.container_settings['name'] + + self.site_config['linux_fx_version'] = linux_fx_version + + if self.container_settings.get('registry_server_user'): + self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings['registry_server_user'] + + if self.container_settings.get('registry_server_password'): + self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings['registry_server_password'] + + # init site + self.site = Site(location=self.location, site_config=self.site_config) + + if self.https_only is not None: + self.site.https_only = self.https_only + + if self.client_affinity_enabled: + self.site.client_affinity_enabled = self.client_affinity_enabled + + # check if the web app already present in the resource group + if not old_response: + self.log("Web App instance doesn't exist") + + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + self.site.tags = self.tags + + # service plan is required for creation + if not self.plan: + self.fail("Please specify app service plan in plan parameter.") + + if not old_plan: + # no existing service plan, create one + if (not self.plan.get('name') or not self.plan.get('sku')): + self.fail('Please specify name, is_linux, sku in plan') + + if 'location' not in self.plan: + plan_resource_group = self.get_resource_group(self.plan['resource_group']) + self.plan['location'] = plan_resource_group.location + + old_plan = self.create_app_service_plan() + + self.site.server_farm_id = old_plan['id'] + + # if linux, setup startup_file + if old_plan['is_linux']: + if hasattr(self, 'startup_file'): + self.site_config['app_command_line'] = self.startup_file + + # set app setting + if self.app_settings: + app_settings = [] + for key in self.app_settings.keys(): + app_settings.append(NameValuePair(name=key, value=self.app_settings[key])) + + self.site_config['app_settings'] = app_settings + else: + # existing web app, do update + self.log("Web App instance already exists") + + self.log('Result: {0}'.format(old_response)) + + update_tags, self.site.tags = self.update_tags(old_response.get('tags', None)) + + if update_tags: + to_be_updated = True + + # check if root level property changed + if self.is_updatable_property_changed(old_response): + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + + # check if site_config changed + old_config = self.get_webapp_configuration() + + if self.is_site_config_changed(old_config): + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + + # check if linux_fx_version changed + if old_config.linux_fx_version != self.site_config.get('linux_fx_version', ''): + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + + self.app_settings_strDic = self.list_app_settings() + + # purge existing app_settings: + if self.purge_app_settings: + to_be_updated = True + self.app_settings_strDic = dict() + self.to_do.append(Actions.UpdateAppSettings) + + # check if app settings changed + if self.purge_app_settings or self.is_app_settings_changed(): + to_be_updated = True + self.to_do.append(Actions.UpdateAppSettings) + + if self.app_settings: + for key in self.app_settings.keys(): + self.app_settings_strDic[key] = self.app_settings[key] + + elif self.state == 'absent': + if old_response: + self.log("Delete Web App instance") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_webapp() + + self.log('Web App instance deleted') + + else: + self.fail("Web app {0} not exists.".format(self.name)) + + if to_be_updated: + self.log('Need to Create/Update web app') + self.results['changed'] = True + + if self.check_mode: + return self.results + + if Actions.CreateOrUpdate in self.to_do: + response = self.create_update_webapp() + + self.results['id'] = response['id'] + + if Actions.UpdateAppSettings in self.to_do: + update_response = self.update_app_settings() + self.results['id'] = update_response.id + + webapp = None + if old_response: + webapp = old_response + if response: + webapp = response + + if webapp: + if (webapp['state'] != 'Stopped' and self.app_state == 'stopped') or \ + (webapp['state'] != 'Running' and self.app_state == 'started') or \ + self.app_state == 'restarted': + + self.results['changed'] = True + if self.check_mode: + return self.results + + self.set_webapp_state(self.app_state) + + return self.results + + # compare existing web app with input, determine weather it's update operation + def is_updatable_property_changed(self, existing_webapp): + for property_name in self.updatable_properties: + if hasattr(self, property_name) and getattr(self, property_name) is not None and \ + getattr(self, property_name) != existing_webapp.get(property_name, None): + return True + + return False + + # compare xxx_version + def is_site_config_changed(self, existing_config): + for fx_version in self.site_config_updatable_properties: + if self.site_config.get(fx_version): + if not getattr(existing_config, fx_version) or \ + getattr(existing_config, fx_version).upper() != self.site_config.get(fx_version).upper(): + return True + + return False + + # comparing existing app setting with input, determine whether it's changed + def is_app_settings_changed(self): + if self.app_settings: + if self.app_settings_strDic: + for key in self.app_settings.keys(): + if self.app_settings[key] != self.app_settings_strDic.get(key, None): + return True + else: + return True + return False + + # comparing deployment source with input, determine wheather it's changed + def is_deployment_source_changed(self, existing_webapp): + if self.deployment_source: + if self.deployment_source.get('url') \ + and self.deployment_source['url'] != existing_webapp.get('site_source_control')['url']: + return True + + if self.deployment_source.get('branch') \ + and self.deployment_source['branch'] != existing_webapp.get('site_source_control')['branch']: + return True + + return False + + def create_update_webapp(self): + ''' + Creates or updates Web App with the specified configuration. + + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Creating / Updating the Web App instance {0}".format(self.name)) + + try: + skip_dns_registration = self.dns_registration + force_dns_registration = None if self.dns_registration is None else not self.dns_registration + + response = self.web_client.web_apps.create_or_update(resource_group_name=self.resource_group, + name=self.name, + site_envelope=self.site, + skip_dns_registration=skip_dns_registration, + skip_custom_domain_verification=self.skip_custom_domain_verification, + force_dns_registration=force_dns_registration, + ttl_in_seconds=self.ttl_in_seconds) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the Web App instance.') + self.fail( + "Error creating the Web App instance: {0}".format(str(exc))) + return webapp_to_dict(response) + + def delete_webapp(self): + ''' + Deletes specified Web App instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the Web App instance {0}".format(self.name)) + try: + response = self.web_client.web_apps.delete(resource_group_name=self.resource_group, + name=self.name) + except CloudError as e: + self.log('Error attempting to delete the Web App instance.') + self.fail( + "Error deleting the Web App instance: {0}".format(str(e))) + + return True + + def get_webapp(self): + ''' + Gets the properties of the specified Web App. + + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Checking if the Web App instance {0} is present".format(self.name)) + + response = None + + try: + response = self.web_client.web_apps.get(resource_group_name=self.resource_group, + name=self.name) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("Web App instance : {0} found".format(response.name)) + return webapp_to_dict(response) + + except CloudError as ex: + pass + + self.log("Didn't find web app {0} in resource group {1}".format( + self.name, self.resource_group)) + + return False + + def get_app_service_plan(self): + ''' + Gets app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Get App Service Plan {0}".format(self.plan['name'])) + + try: + response = self.web_client.app_service_plans.get( + resource_group_name=self.plan['resource_group'], + name=self.plan['name']) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("App Service Plan : {0} found".format(response.name)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + pass + + self.log("Didn't find app service plan {0} in resource group {1}".format( + self.plan['name'], self.plan['resource_group'])) + + return False + + def create_app_service_plan(self): + ''' + Creates app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Create App Service Plan {0}".format(self.plan['name'])) + + try: + # normalize sku + sku = _normalize_sku(self.plan['sku']) + + sku_def = SkuDescription(tier=get_sku_name( + sku), name=sku, capacity=(self.plan.get('number_of_workers', None))) + plan_def = AppServicePlan( + location=self.plan['location'], app_service_plan_name=self.plan['name'], sku=sku_def, reserved=(self.plan.get('is_linux', None))) + + poller = self.web_client.app_service_plans.create_or_update( + self.plan['resource_group'], self.plan['name'], plan_def) + + if isinstance(poller, LROPoller): + response = self.get_poller_result(poller) + + self.log("Response : {0}".format(response)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + self.fail("Failed to create app service plan {0} in resource group {1}: {2}".format( + self.plan['name'], self.plan['resource_group'], str(ex))) + + def list_app_settings(self): + ''' + List application settings + :return: deserialized list response + ''' + self.log("List application setting") + + try: + + response = self.web_client.web_apps.list_application_settings( + resource_group_name=self.resource_group, name=self.name) + self.log("Response : {0}".format(response)) + + return response.properties + except CloudError as ex: + self.fail("Failed to list application settings for web app {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def update_app_settings(self): + ''' + Update application settings + :return: deserialized updating response + ''' + self.log("Update application setting") + + try: + response = self.web_client.web_apps.update_application_settings( + resource_group_name=self.resource_group, name=self.name, properties=self.app_settings_strDic) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to update application settings for web app {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def create_or_update_source_control(self): + ''' + Update site source control + :return: deserialized updating response + ''' + self.log("Update site source control") + + if self.deployment_source is None: + return False + + self.deployment_source['is_manual_integration'] = False + self.deployment_source['is_mercurial'] = False + + try: + response = self.web_client.web_client.create_or_update_source_control( + self.resource_group, self.name, self.deployment_source) + self.log("Response : {0}".format(response)) + + return response.as_dict() + except CloudError as ex: + self.fail("Failed to update site source control for web app {0} in resource group {1}".format( + self.name, self.resource_group)) + + def get_webapp_configuration(self): + ''' + Get web app configuration + :return: deserialized web app configuration response + ''' + self.log("Get web app configuration") + + try: + + response = self.web_client.web_apps.get_configuration( + resource_group_name=self.resource_group, name=self.name) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.log("Failed to get configuration for web app {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + return False + + def set_webapp_state(self, appstate): + ''' + Start/stop/restart web app + :return: deserialized updating response + ''' + try: + if appstate == 'started': + response = self.web_client.web_apps.start(resource_group_name=self.resource_group, name=self.name) + elif appstate == 'stopped': + response = self.web_client.web_apps.stop(resource_group_name=self.resource_group, name=self.name) + elif appstate == 'restarted': + response = self.web_client.web_apps.restart(resource_group_name=self.resource_group, name=self.name) + else: + self.fail("Invalid web app state {0}".format(appstate)) + + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.log("Failed to {0} web app {1} in resource group {2}, request_id {3} - {4}".format( + appstate, self.name, self.resource_group, request_id, str(ex))) + + +def main(): + """Main execution""" + AzureRMWebApps() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_webapp_info.py b/test/support/integration/plugins/modules/azure_rm_webapp_info.py new file mode 100644 index 00000000..22286803 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_webapp_info.py @@ -0,0 +1,489 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_webapp_info + +version_added: "2.9" + +short_description: Get Azure web app facts + +description: + - Get facts for a specific web app or all web app in a resource group, or all web app in current subscription. + +options: + name: + description: + - Only show results for a specific web app. + resource_group: + description: + - Limit results by resource group. + return_publish_profile: + description: + - Indicate whether to return publishing profile of the web app. + default: False + type: bool + tags: + description: + - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'. + +extends_documentation_fragment: + - azure + +author: + - Yunge Zhu (@yungezz) +''' + +EXAMPLES = ''' + - name: Get facts for web app by name + azure_rm_webapp_info: + resource_group: myResourceGroup + name: winwebapp1 + + - name: Get facts for web apps in resource group + azure_rm_webapp_info: + resource_group: myResourceGroup + + - name: Get facts for web apps with tags + azure_rm_webapp_info: + tags: + - testtag + - foo:bar +''' + +RETURN = ''' +webapps: + description: + - List of web apps. + returned: always + type: complex + contains: + id: + description: + - ID of the web app. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myWebApp + name: + description: + - Name of the web app. + returned: always + type: str + sample: winwebapp1 + resource_group: + description: + - Resource group of the web app. + returned: always + type: str + sample: myResourceGroup + location: + description: + - Location of the web app. + returned: always + type: str + sample: eastus + plan: + description: + - ID of app service plan used by the web app. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/myAppServicePlan + app_settings: + description: + - App settings of the application. Only returned when web app has app settings. + returned: always + type: dict + sample: { + "testkey": "testvalue", + "testkey2": "testvalue2" + } + frameworks: + description: + - Frameworks of the application. Only returned when web app has frameworks. + returned: always + type: list + sample: [ + { + "name": "net_framework", + "version": "v4.0" + }, + { + "name": "java", + "settings": { + "java_container": "tomcat", + "java_container_version": "8.5" + }, + "version": "1.7" + }, + { + "name": "php", + "version": "5.6" + } + ] + availability_state: + description: + - Availability of this web app. + returned: always + type: str + sample: Normal + default_host_name: + description: + - Host name of the web app. + returned: always + type: str + sample: vxxisurg397winapp4.azurewebsites.net + enabled: + description: + - Indicates the web app enabled or not. + returned: always + type: bool + sample: true + enabled_host_names: + description: + - Enabled host names of the web app. + returned: always + type: list + sample: [ + "vxxisurg397winapp4.azurewebsites.net", + "vxxisurg397winapp4.scm.azurewebsites.net" + ] + host_name_ssl_states: + description: + - SSL state per host names of the web app. + returned: always + type: list + sample: [ + { + "hostType": "Standard", + "name": "vxxisurg397winapp4.azurewebsites.net", + "sslState": "Disabled" + }, + { + "hostType": "Repository", + "name": "vxxisurg397winapp4.scm.azurewebsites.net", + "sslState": "Disabled" + } + ] + host_names: + description: + - Host names of the web app. + returned: always + type: list + sample: [ + "vxxisurg397winapp4.azurewebsites.net" + ] + outbound_ip_addresses: + description: + - Outbound IP address of the web app. + returned: always + type: str + sample: "40.71.11.131,40.85.166.200,168.62.166.67,137.135.126.248,137.135.121.45" + ftp_publish_url: + description: + - Publishing URL of the web app when deployment type is FTP. + returned: always + type: str + sample: ftp://xxxx.ftp.azurewebsites.windows.net + state: + description: + - State of the web app. + returned: always + type: str + sample: running + publishing_username: + description: + - Publishing profile user name. + returned: only when I(return_publish_profile=True). + type: str + sample: "$vxxisuRG397winapp4" + publishing_password: + description: + - Publishing profile password. + returned: only when I(return_publish_profile=True). + type: str + sample: "uvANsPQpGjWJmrFfm4Ssd5rpBSqGhjMk11pMSgW2vCsQtNx9tcgZ0xN26s9A" + tags: + description: + - Tags assigned to the resource. Dictionary of string:string pairs. + returned: always + type: dict + sample: { tag1: abc } +''' +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from azure.common import AzureMissingResourceHttpError, AzureHttpError +except Exception: + # This is handled in azure_rm_common + pass + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +AZURE_OBJECT_CLASS = 'WebApp' + + +class AzureRMWebAppInfo(AzureRMModuleBase): + + def __init__(self): + + self.module_arg_spec = dict( + name=dict(type='str'), + resource_group=dict(type='str'), + tags=dict(type='list'), + return_publish_profile=dict(type='bool', default=False), + ) + + self.results = dict( + changed=False, + webapps=[], + ) + + self.name = None + self.resource_group = None + self.tags = None + self.return_publish_profile = False + + self.framework_names = ['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby'] + + super(AzureRMWebAppInfo, self).__init__(self.module_arg_spec, + supports_tags=False, + facts_module=True) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_webapp_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_webapp_facts' module has been renamed to 'azure_rm_webapp_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if self.name: + self.results['webapps'] = self.list_by_name() + elif self.resource_group: + self.results['webapps'] = self.list_by_resource_group() + else: + self.results['webapps'] = self.list_all() + + return self.results + + def list_by_name(self): + self.log('Get web app {0}'.format(self.name)) + item = None + result = [] + + try: + item = self.web_client.web_apps.get(self.resource_group, self.name) + except CloudError: + pass + + if item and self.has_tags(item.tags, self.tags): + curated_result = self.get_curated_webapp(self.resource_group, self.name, item) + result = [curated_result] + + return result + + def list_by_resource_group(self): + self.log('List web apps in resource groups {0}'.format(self.resource_group)) + try: + response = list(self.web_client.web_apps.list_by_resource_group(self.resource_group)) + except CloudError as exc: + request_id = exc.request_id if exc.request_id else '' + self.fail("Error listing web apps in resource groups {0}, request id: {1} - {2}".format(self.resource_group, request_id, str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + curated_output = self.get_curated_webapp(self.resource_group, item.name, item) + results.append(curated_output) + return results + + def list_all(self): + self.log('List web apps in current subscription') + try: + response = list(self.web_client.web_apps.list()) + except CloudError as exc: + request_id = exc.request_id if exc.request_id else '' + self.fail("Error listing web apps, request id {0} - {1}".format(request_id, str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + curated_output = self.get_curated_webapp(item.resource_group, item.name, item) + results.append(curated_output) + return results + + def list_webapp_configuration(self, resource_group, name): + self.log('Get web app {0} configuration'.format(name)) + + response = [] + + try: + response = self.web_client.web_apps.get_configuration(resource_group_name=resource_group, name=name) + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail('Error getting web app {0} configuration, request id {1} - {2}'.format(name, request_id, str(ex))) + + return response.as_dict() + + def list_webapp_appsettings(self, resource_group, name): + self.log('Get web app {0} app settings'.format(name)) + + response = [] + + try: + response = self.web_client.web_apps.list_application_settings(resource_group_name=resource_group, name=name) + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail('Error getting web app {0} app settings, request id {1} - {2}'.format(name, request_id, str(ex))) + + return response.as_dict() + + def get_publish_credentials(self, resource_group, name): + self.log('Get web app {0} publish credentials'.format(name)) + try: + poller = self.web_client.web_apps.list_publishing_credentials(resource_group, name) + if isinstance(poller, LROPoller): + response = self.get_poller_result(poller) + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail('Error getting web app {0} publishing credentials - {1}'.format(request_id, str(ex))) + return response + + def get_webapp_ftp_publish_url(self, resource_group, name): + import xmltodict + + self.log('Get web app {0} app publish profile'.format(name)) + + url = None + try: + content = self.web_client.web_apps.list_publishing_profile_xml_with_secrets(resource_group_name=resource_group, name=name) + if not content: + return url + + full_xml = '' + for f in content: + full_xml += f.decode() + profiles = xmltodict.parse(full_xml, xml_attribs=True)['publishData']['publishProfile'] + + if not profiles: + return url + + for profile in profiles: + if profile['@publishMethod'] == 'FTP': + url = profile['@publishUrl'] + + except CloudError as ex: + self.fail('Error getting web app {0} app settings'.format(name)) + + return url + + def get_curated_webapp(self, resource_group, name, webapp): + pip = self.serialize_obj(webapp, AZURE_OBJECT_CLASS) + + try: + site_config = self.list_webapp_configuration(resource_group, name) + app_settings = self.list_webapp_appsettings(resource_group, name) + publish_cred = self.get_publish_credentials(resource_group, name) + ftp_publish_url = self.get_webapp_ftp_publish_url(resource_group, name) + except CloudError as ex: + pass + return self.construct_curated_webapp(webapp=pip, + configuration=site_config, + app_settings=app_settings, + deployment_slot=None, + ftp_publish_url=ftp_publish_url, + publish_credentials=publish_cred) + + def construct_curated_webapp(self, + webapp, + configuration=None, + app_settings=None, + deployment_slot=None, + ftp_publish_url=None, + publish_credentials=None): + curated_output = dict() + curated_output['id'] = webapp['id'] + curated_output['name'] = webapp['name'] + curated_output['resource_group'] = webapp['properties']['resourceGroup'] + curated_output['location'] = webapp['location'] + curated_output['plan'] = webapp['properties']['serverFarmId'] + curated_output['tags'] = webapp.get('tags', None) + + # important properties from output. not match input arguments. + curated_output['app_state'] = webapp['properties']['state'] + curated_output['availability_state'] = webapp['properties']['availabilityState'] + curated_output['default_host_name'] = webapp['properties']['defaultHostName'] + curated_output['host_names'] = webapp['properties']['hostNames'] + curated_output['enabled'] = webapp['properties']['enabled'] + curated_output['enabled_host_names'] = webapp['properties']['enabledHostNames'] + curated_output['host_name_ssl_states'] = webapp['properties']['hostNameSslStates'] + curated_output['outbound_ip_addresses'] = webapp['properties']['outboundIpAddresses'] + + # curated site_config + if configuration: + curated_output['frameworks'] = [] + for fx_name in self.framework_names: + fx_version = configuration.get(fx_name + '_version', None) + if fx_version: + fx = { + 'name': fx_name, + 'version': fx_version + } + # java container setting + if fx_name == 'java': + if configuration['java_container'] and configuration['java_container_version']: + settings = { + 'java_container': configuration['java_container'].lower(), + 'java_container_version': configuration['java_container_version'] + } + fx['settings'] = settings + + curated_output['frameworks'].append(fx) + + # linux_fx_version + if configuration.get('linux_fx_version', None): + tmp = configuration.get('linux_fx_version').split("|") + if len(tmp) == 2: + curated_output['frameworks'].append({'name': tmp[0].lower(), 'version': tmp[1]}) + + # curated app_settings + if app_settings and app_settings.get('properties', None): + curated_output['app_settings'] = dict() + for item in app_settings['properties']: + curated_output['app_settings'][item] = app_settings['properties'][item] + + # curated deploymenet_slot + if deployment_slot: + curated_output['deployment_slot'] = deployment_slot + + # ftp_publish_url + if ftp_publish_url: + curated_output['ftp_publish_url'] = ftp_publish_url + + # curated publish credentials + if publish_credentials and self.return_publish_profile: + curated_output['publishing_username'] = publish_credentials.publishing_user_name + curated_output['publishing_password'] = publish_credentials.publishing_password + return curated_output + + +def main(): + AzureRMWebAppInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_webappslot.py b/test/support/integration/plugins/modules/azure_rm_webappslot.py new file mode 100644 index 00000000..ddba710b --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_webappslot.py @@ -0,0 +1,1058 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_webappslot +version_added: "2.8" +short_description: Manage Azure Web App slot +description: + - Create, update and delete Azure Web App slot. + +options: + resource_group: + description: + - Name of the resource group to which the resource belongs. + required: True + name: + description: + - Unique name of the deployment slot to create or update. + required: True + webapp_name: + description: + - Web app name which this deployment slot belongs to. + required: True + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + configuration_source: + description: + - Source slot to clone configurations from when creating slot. Use webapp's name to refer to the production slot. + auto_swap_slot_name: + description: + - Used to configure target slot name to auto swap, or disable auto swap. + - Set it target slot name to auto swap. + - Set it to False to disable auto slot swap. + swap: + description: + - Swap deployment slots of a web app. + suboptions: + action: + description: + - Swap types. + - C(preview) is to apply target slot settings on source slot first. + - C(swap) is to complete swapping. + - C(reset) is to reset the swap. + choices: + - preview + - swap + - reset + default: preview + target_slot: + description: + - Name of target slot to swap. If set to None, then swap with production slot. + preserve_vnet: + description: + - C(True) to preserve virtual network to the slot during swap. Otherwise C(False). + type: bool + default: True + frameworks: + description: + - Set of run time framework settings. Each setting is a dictionary. + - See U(https://docs.microsoft.com/en-us/azure/app-service/app-service-web-overview) for more info. + suboptions: + name: + description: + - Name of the framework. + - Supported framework list for Windows web app and Linux web app is different. + - Windows web apps support C(java), C(net_framework), C(php), C(python), and C(node) from June 2018. + - Windows web apps support multiple framework at same time. + - Linux web apps support C(java), C(ruby), C(php), C(dotnetcore), and C(node) from June 2018. + - Linux web apps support only one framework. + - Java framework is mutually exclusive with others. + choices: + - java + - net_framework + - php + - python + - ruby + - dotnetcore + - node + version: + description: + - Version of the framework. For Linux web app supported value, see U(https://aka.ms/linux-stacks) for more info. + - C(net_framework) supported value sample, C(v4.0) for .NET 4.6 and C(v3.0) for .NET 3.5. + - C(php) supported value sample, C(5.5), C(5.6), C(7.0). + - C(python) supported value sample, C(5.5), C(5.6), C(7.0). + - C(node) supported value sample, C(6.6), C(6.9). + - C(dotnetcore) supported value sample, C(1.0), C(1.1), C(1.2). + - C(ruby) supported value sample, 2.3. + - C(java) supported value sample, C(1.9) for Windows web app. C(1.8) for Linux web app. + settings: + description: + - List of settings of the framework. + suboptions: + java_container: + description: + - Name of Java container. This is supported by specific framework C(java) onlys, for example C(Tomcat), C(Jetty). + java_container_version: + description: + - Version of Java container. This is supported by specific framework C(java) only. + - For C(Tomcat), for example C(8.0), C(8.5), C(9.0). For C(Jetty), for example C(9.1), C(9.3). + container_settings: + description: + - Web app slot container settings. + suboptions: + name: + description: + - Name of container, for example C(imagename:tag). + registry_server_url: + description: + - Container registry server URL, for example C(mydockerregistry.io). + registry_server_user: + description: + - The container registry server user name. + registry_server_password: + description: + - The container registry server password. + startup_file: + description: + - The slot startup file. + - This only applies for Linux web app slot. + app_settings: + description: + - Configure web app slot application settings. Suboptions are in key value pair format. + purge_app_settings: + description: + - Purge any existing application settings. Replace slot application settings with app_settings. + type: bool + deployment_source: + description: + - Deployment source for git. + suboptions: + url: + description: + - Repository URL of deployment source. + branch: + description: + - The branch name of the repository. + app_state: + description: + - Start/Stop/Restart the slot. + type: str + choices: + - started + - stopped + - restarted + default: started + state: + description: + - State of the Web App deployment slot. + - Use C(present) to create or update a slot and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Yunge Zhu(@yungezz) + +''' + +EXAMPLES = ''' + - name: Create a webapp slot + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + configuration_source: myJavaWebApp + app_settings: + testkey: testvalue + + - name: swap the slot with production slot + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + swap: + action: swap + + - name: stop the slot + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + app_state: stopped + + - name: udpate a webapp slot app settings + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + app_settings: + testkey: testvalue2 + + - name: udpate a webapp slot frameworks + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + frameworks: + - name: "node" + version: "10.1" +''' + +RETURN = ''' +id: + description: + - ID of current slot. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/testapp/slots/stage1 +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model + from azure.mgmt.web.models import ( + site_config, app_service_plan, Site, + AppServicePlan, SkuDescription, NameValuePair + ) +except ImportError: + # This is handled in azure_rm_common + pass + +swap_spec = dict( + action=dict( + type='str', + choices=[ + 'preview', + 'swap', + 'reset' + ], + default='preview' + ), + target_slot=dict( + type='str' + ), + preserve_vnet=dict( + type='bool', + default=True + ) +) + +container_settings_spec = dict( + name=dict(type='str', required=True), + registry_server_url=dict(type='str'), + registry_server_user=dict(type='str'), + registry_server_password=dict(type='str', no_log=True) +) + +deployment_source_spec = dict( + url=dict(type='str'), + branch=dict(type='str') +) + + +framework_settings_spec = dict( + java_container=dict(type='str', required=True), + java_container_version=dict(type='str', required=True) +) + + +framework_spec = dict( + name=dict( + type='str', + required=True, + choices=['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby']), + version=dict(type='str', required=True), + settings=dict(type='dict', options=framework_settings_spec) +) + + +def webapp_to_dict(webapp): + return dict( + id=webapp.id, + name=webapp.name, + location=webapp.location, + client_cert_enabled=webapp.client_cert_enabled, + enabled=webapp.enabled, + reserved=webapp.reserved, + client_affinity_enabled=webapp.client_affinity_enabled, + server_farm_id=webapp.server_farm_id, + host_names_disabled=webapp.host_names_disabled, + https_only=webapp.https_only if hasattr(webapp, 'https_only') else None, + skip_custom_domain_verification=webapp.skip_custom_domain_verification if hasattr(webapp, 'skip_custom_domain_verification') else None, + ttl_in_seconds=webapp.ttl_in_seconds if hasattr(webapp, 'ttl_in_seconds') else None, + state=webapp.state, + tags=webapp.tags if webapp.tags else None + ) + + +def slot_to_dict(slot): + return dict( + id=slot.id, + resource_group=slot.resource_group, + server_farm_id=slot.server_farm_id, + target_swap_slot=slot.target_swap_slot, + enabled_host_names=slot.enabled_host_names, + slot_swap_status=slot.slot_swap_status, + name=slot.name, + location=slot.location, + enabled=slot.enabled, + reserved=slot.reserved, + host_names_disabled=slot.host_names_disabled, + state=slot.state, + repository_site_name=slot.repository_site_name, + default_host_name=slot.default_host_name, + kind=slot.kind, + site_config=slot.site_config, + tags=slot.tags if slot.tags else None + ) + + +class Actions: + NoAction, CreateOrUpdate, UpdateAppSettings, Delete = range(4) + + +class AzureRMWebAppSlots(AzureRMModuleBase): + """Configuration class for an Azure RM Web App slot resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + webapp_name=dict( + type='str', + required=True + ), + location=dict( + type='str' + ), + configuration_source=dict( + type='str' + ), + auto_swap_slot_name=dict( + type='raw' + ), + swap=dict( + type='dict', + options=swap_spec + ), + frameworks=dict( + type='list', + elements='dict', + options=framework_spec + ), + container_settings=dict( + type='dict', + options=container_settings_spec + ), + deployment_source=dict( + type='dict', + options=deployment_source_spec + ), + startup_file=dict( + type='str' + ), + app_settings=dict( + type='dict' + ), + purge_app_settings=dict( + type='bool', + default=False + ), + app_state=dict( + type='str', + choices=['started', 'stopped', 'restarted'], + default='started' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + mutually_exclusive = [['container_settings', 'frameworks']] + + self.resource_group = None + self.name = None + self.webapp_name = None + self.location = None + + self.auto_swap_slot_name = None + self.swap = None + self.tags = None + self.startup_file = None + self.configuration_source = None + self.clone = False + + # site config, e.g app settings, ssl + self.site_config = dict() + self.app_settings = dict() + self.app_settings_strDic = None + + # siteSourceControl + self.deployment_source = dict() + + # site, used at level creation, or update. + self.site = None + + # property for internal usage, not used for sdk + self.container_settings = None + + self.purge_app_settings = False + self.app_state = 'started' + + self.results = dict( + changed=False, + id=None, + ) + self.state = None + self.to_do = Actions.NoAction + + self.frameworks = None + + # set site_config value from kwargs + self.site_config_updatable_frameworks = ["net_framework_version", + "java_version", + "php_version", + "python_version", + "linux_fx_version"] + + self.supported_linux_frameworks = ['ruby', 'php', 'dotnetcore', 'node', 'java'] + self.supported_windows_frameworks = ['net_framework', 'php', 'python', 'node', 'java'] + + super(AzureRMWebAppSlots, self).__init__(derived_arg_spec=self.module_arg_spec, + mutually_exclusive=mutually_exclusive, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "scm_type": + self.site_config[key] = kwargs[key] + + old_response = None + response = None + to_be_updated = False + + # set location + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + self.location = resource_group.location + + # get web app + webapp_response = self.get_webapp() + + if not webapp_response: + self.fail("Web app {0} does not exist in resource group {1}.".format(self.webapp_name, self.resource_group)) + + # get slot + old_response = self.get_slot() + + # set is_linux + is_linux = True if webapp_response['reserved'] else False + + if self.state == 'present': + if self.frameworks: + # java is mutually exclusive with other frameworks + if len(self.frameworks) > 1 and any(f['name'] == 'java' for f in self.frameworks): + self.fail('Java is mutually exclusive with other frameworks.') + + if is_linux: + if len(self.frameworks) != 1: + self.fail('Can specify one framework only for Linux web app.') + + if self.frameworks[0]['name'] not in self.supported_linux_frameworks: + self.fail('Unsupported framework {0} for Linux web app.'.format(self.frameworks[0]['name'])) + + self.site_config['linux_fx_version'] = (self.frameworks[0]['name'] + '|' + self.frameworks[0]['version']).upper() + + if self.frameworks[0]['name'] == 'java': + if self.frameworks[0]['version'] != '8': + self.fail("Linux web app only supports java 8.") + + if self.frameworks[0].get('settings', {}) and self.frameworks[0]['settings'].get('java_container', None) and \ + self.frameworks[0]['settings']['java_container'].lower() != 'tomcat': + self.fail("Linux web app only supports tomcat container.") + + if self.frameworks[0].get('settings', {}) and self.frameworks[0]['settings'].get('java_container', None) and \ + self.frameworks[0]['settings']['java_container'].lower() == 'tomcat': + self.site_config['linux_fx_version'] = 'TOMCAT|' + self.frameworks[0]['settings']['java_container_version'] + '-jre8' + else: + self.site_config['linux_fx_version'] = 'JAVA|8-jre8' + else: + for fx in self.frameworks: + if fx.get('name') not in self.supported_windows_frameworks: + self.fail('Unsupported framework {0} for Windows web app.'.format(fx.get('name'))) + else: + self.site_config[fx.get('name') + '_version'] = fx.get('version') + + if 'settings' in fx and fx['settings'] is not None: + for key, value in fx['settings'].items(): + self.site_config[key] = value + + if not self.app_settings: + self.app_settings = dict() + + if self.container_settings: + linux_fx_version = 'DOCKER|' + + if self.container_settings.get('registry_server_url'): + self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url'] + + linux_fx_version += self.container_settings['registry_server_url'] + '/' + + linux_fx_version += self.container_settings['name'] + + self.site_config['linux_fx_version'] = linux_fx_version + + if self.container_settings.get('registry_server_user'): + self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings['registry_server_user'] + + if self.container_settings.get('registry_server_password'): + self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings['registry_server_password'] + + # set auto_swap_slot_name + if self.auto_swap_slot_name and isinstance(self.auto_swap_slot_name, str): + self.site_config['auto_swap_slot_name'] = self.auto_swap_slot_name + if self.auto_swap_slot_name is False: + self.site_config['auto_swap_slot_name'] = None + + # init site + self.site = Site(location=self.location, site_config=self.site_config) + + # check if the slot already present in the webapp + if not old_response: + self.log("Web App slot doesn't exist") + + to_be_updated = True + self.to_do = Actions.CreateOrUpdate + self.site.tags = self.tags + + # if linux, setup startup_file + if self.startup_file: + self.site_config['app_command_line'] = self.startup_file + + # set app setting + if self.app_settings: + app_settings = [] + for key in self.app_settings.keys(): + app_settings.append(NameValuePair(name=key, value=self.app_settings[key])) + + self.site_config['app_settings'] = app_settings + + # clone slot + if self.configuration_source: + self.clone = True + + else: + # existing slot, do update + self.log("Web App slot already exists") + + self.log('Result: {0}'.format(old_response)) + + update_tags, self.site.tags = self.update_tags(old_response.get('tags', None)) + + if update_tags: + to_be_updated = True + + # check if site_config changed + old_config = self.get_configuration_slot(self.name) + + if self.is_site_config_changed(old_config): + to_be_updated = True + self.to_do = Actions.CreateOrUpdate + + self.app_settings_strDic = self.list_app_settings_slot(self.name) + + # purge existing app_settings: + if self.purge_app_settings: + to_be_updated = True + self.to_do = Actions.UpdateAppSettings + self.app_settings_strDic = dict() + + # check if app settings changed + if self.purge_app_settings or self.is_app_settings_changed(): + to_be_updated = True + self.to_do = Actions.UpdateAppSettings + + if self.app_settings: + for key in self.app_settings.keys(): + self.app_settings_strDic[key] = self.app_settings[key] + + elif self.state == 'absent': + if old_response: + self.log("Delete Web App slot") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_slot() + + self.log('Web App slot deleted') + + else: + self.log("Web app slot {0} not exists.".format(self.name)) + + if to_be_updated: + self.log('Need to Create/Update web app') + self.results['changed'] = True + + if self.check_mode: + return self.results + + if self.to_do == Actions.CreateOrUpdate: + response = self.create_update_slot() + + self.results['id'] = response['id'] + + if self.clone: + self.clone_slot() + + if self.to_do == Actions.UpdateAppSettings: + self.update_app_settings_slot() + + slot = None + if response: + slot = response + if old_response: + slot = old_response + + if slot: + if (slot['state'] != 'Stopped' and self.app_state == 'stopped') or \ + (slot['state'] != 'Running' and self.app_state == 'started') or \ + self.app_state == 'restarted': + + self.results['changed'] = True + if self.check_mode: + return self.results + + self.set_state_slot(self.app_state) + + if self.swap: + self.results['changed'] = True + if self.check_mode: + return self.results + + self.swap_slot() + + return self.results + + # compare site config + def is_site_config_changed(self, existing_config): + for fx_version in self.site_config_updatable_frameworks: + if self.site_config.get(fx_version): + if not getattr(existing_config, fx_version) or \ + getattr(existing_config, fx_version).upper() != self.site_config.get(fx_version).upper(): + return True + + if self.auto_swap_slot_name is False and existing_config.auto_swap_slot_name is not None: + return True + elif self.auto_swap_slot_name and self.auto_swap_slot_name != getattr(existing_config, 'auto_swap_slot_name', None): + return True + return False + + # comparing existing app setting with input, determine whether it's changed + def is_app_settings_changed(self): + if self.app_settings: + if len(self.app_settings_strDic) != len(self.app_settings): + return True + + if self.app_settings_strDic != self.app_settings: + return True + return False + + # comparing deployment source with input, determine whether it's changed + def is_deployment_source_changed(self, existing_webapp): + if self.deployment_source: + if self.deployment_source.get('url') \ + and self.deployment_source['url'] != existing_webapp.get('site_source_control')['url']: + return True + + if self.deployment_source.get('branch') \ + and self.deployment_source['branch'] != existing_webapp.get('site_source_control')['branch']: + return True + + return False + + def create_update_slot(self): + ''' + Creates or updates Web App slot with the specified configuration. + + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Creating / Updating the Web App slot {0}".format(self.name)) + + try: + response = self.web_client.web_apps.create_or_update_slot(resource_group_name=self.resource_group, + slot=self.name, + name=self.webapp_name, + site_envelope=self.site) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the Web App slot instance.') + self.fail("Error creating the Web App slot: {0}".format(str(exc))) + return slot_to_dict(response) + + def delete_slot(self): + ''' + Deletes specified Web App slot in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the Web App slot {0}".format(self.name)) + try: + response = self.web_client.web_apps.delete_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name) + except CloudError as e: + self.log('Error attempting to delete the Web App slot.') + self.fail( + "Error deleting the Web App slots: {0}".format(str(e))) + + return True + + def get_webapp(self): + ''' + Gets the properties of the specified Web App. + + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Checking if the Web App instance {0} is present".format(self.webapp_name)) + + response = None + + try: + response = self.web_client.web_apps.get(resource_group_name=self.resource_group, + name=self.webapp_name) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("Web App instance : {0} found".format(response.name)) + return webapp_to_dict(response) + + except CloudError as ex: + pass + + self.log("Didn't find web app {0} in resource group {1}".format( + self.webapp_name, self.resource_group)) + + return False + + def get_slot(self): + ''' + Gets the properties of the specified Web App slot. + + :return: deserialized Web App slot state dictionary + ''' + self.log( + "Checking if the Web App slot {0} is present".format(self.name)) + + response = None + + try: + response = self.web_client.web_apps.get_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("Web App slot: {0} found".format(response.name)) + return slot_to_dict(response) + + except CloudError as ex: + pass + + self.log("Does not find web app slot {0} in resource group {1}".format(self.name, self.resource_group)) + + return False + + def list_app_settings(self): + ''' + List webapp application settings + :return: deserialized list response + ''' + self.log("List webapp application setting") + + try: + + response = self.web_client.web_apps.list_application_settings( + resource_group_name=self.resource_group, name=self.webapp_name) + self.log("Response : {0}".format(response)) + + return response.properties + except CloudError as ex: + self.fail("Failed to list application settings for web app {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def list_app_settings_slot(self, slot_name): + ''' + List application settings + :return: deserialized list response + ''' + self.log("List application setting") + + try: + + response = self.web_client.web_apps.list_application_settings_slot( + resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name) + self.log("Response : {0}".format(response)) + + return response.properties + except CloudError as ex: + self.fail("Failed to list application settings for web app slot {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def update_app_settings_slot(self, slot_name=None, app_settings=None): + ''' + Update application settings + :return: deserialized updating response + ''' + self.log("Update application setting") + + if slot_name is None: + slot_name = self.name + if app_settings is None: + app_settings = self.app_settings_strDic + try: + response = self.web_client.web_apps.update_application_settings_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=slot_name, + kind=None, + properties=app_settings) + self.log("Response : {0}".format(response)) + + return response.as_dict() + except CloudError as ex: + self.fail("Failed to update application settings for web app slot {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + return response + + def create_or_update_source_control_slot(self): + ''' + Update site source control + :return: deserialized updating response + ''' + self.log("Update site source control") + + if self.deployment_source is None: + return False + + self.deployment_source['is_manual_integration'] = False + self.deployment_source['is_mercurial'] = False + + try: + response = self.web_client.web_client.create_or_update_source_control_slot( + resource_group_name=self.resource_group, + name=self.webapp_name, + site_source_control=self.deployment_source, + slot=self.name) + self.log("Response : {0}".format(response)) + + return response.as_dict() + except CloudError as ex: + self.fail("Failed to update site source control for web app slot {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def get_configuration(self): + ''' + Get web app configuration + :return: deserialized web app configuration response + ''' + self.log("Get web app configuration") + + try: + + response = self.web_client.web_apps.get_configuration( + resource_group_name=self.resource_group, name=self.webapp_name) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to get configuration for web app {0} in resource group {1}: {2}".format( + self.webapp_name, self.resource_group, str(ex))) + + def get_configuration_slot(self, slot_name): + ''' + Get slot configuration + :return: deserialized slot configuration response + ''' + self.log("Get web app slot configuration") + + try: + + response = self.web_client.web_apps.get_configuration_slot( + resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to get configuration for web app slot {0} in resource group {1}: {2}".format( + slot_name, self.resource_group, str(ex))) + + def update_configuration_slot(self, slot_name=None, site_config=None): + ''' + Update slot configuration + :return: deserialized slot configuration response + ''' + self.log("Update web app slot configuration") + + if slot_name is None: + slot_name = self.name + if site_config is None: + site_config = self.site_config + try: + + response = self.web_client.web_apps.update_configuration_slot( + resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name, site_config=site_config) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to update configuration for web app slot {0} in resource group {1}: {2}".format( + slot_name, self.resource_group, str(ex))) + + def set_state_slot(self, appstate): + ''' + Start/stop/restart web app slot + :return: deserialized updating response + ''' + try: + if appstate == 'started': + response = self.web_client.web_apps.start_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name) + elif appstate == 'stopped': + response = self.web_client.web_apps.stop_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name) + elif appstate == 'restarted': + response = self.web_client.web_apps.restart_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name) + else: + self.fail("Invalid web app slot state {0}".format(appstate)) + + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail("Failed to {0} web app slot {1} in resource group {2}, request_id {3} - {4}".format( + appstate, self.name, self.resource_group, request_id, str(ex))) + + def swap_slot(self): + ''' + Swap slot + :return: deserialized response + ''' + self.log("Swap slot") + + try: + if self.swap['action'] == 'swap': + if self.swap['target_slot'] is None: + response = self.web_client.web_apps.swap_slot_with_production(resource_group_name=self.resource_group, + name=self.webapp_name, + target_slot=self.name, + preserve_vnet=self.swap['preserve_vnet']) + else: + response = self.web_client.web_apps.swap_slot_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name, + target_slot=self.swap['target_slot'], + preserve_vnet=self.swap['preserve_vnet']) + elif self.swap['action'] == 'preview': + if self.swap['target_slot'] is None: + response = self.web_client.web_apps.apply_slot_config_to_production(resource_group_name=self.resource_group, + name=self.webapp_name, + target_slot=self.name, + preserve_vnet=self.swap['preserve_vnet']) + else: + response = self.web_client.web_apps.apply_slot_configuration_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name, + target_slot=self.swap['target_slot'], + preserve_vnet=self.swap['preserve_vnet']) + elif self.swap['action'] == 'reset': + if self.swap['target_slot'] is None: + response = self.web_client.web_apps.reset_production_slot_config(resource_group_name=self.resource_group, + name=self.webapp_name) + else: + response = self.web_client.web_apps.reset_slot_configuration_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.swap['target_slot']) + response = self.web_client.web_apps.reset_slot_configuration_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name) + + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to swap web app slot {0} in resource group {1}: {2}".format(self.name, self.resource_group, str(ex))) + + def clone_slot(self): + if self.configuration_source: + src_slot = None if self.configuration_source.lower() == self.webapp_name.lower() else self.configuration_source + + if src_slot is None: + site_config_clone_from = self.get_configuration() + else: + site_config_clone_from = self.get_configuration_slot(slot_name=src_slot) + + self.update_configuration_slot(site_config=site_config_clone_from) + + if src_slot is None: + app_setting_clone_from = self.list_app_settings() + else: + app_setting_clone_from = self.list_app_settings_slot(src_slot) + + if self.app_settings: + app_setting_clone_from.update(self.app_settings) + + self.update_app_settings_slot(app_settings=app_setting_clone_from) + + +def main(): + """Main execution""" + AzureRMWebAppSlots() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/cloud_init_data_facts.py b/test/support/integration/plugins/modules/cloud_init_data_facts.py new file mode 100644 index 00000000..4f871b99 --- /dev/null +++ b/test/support/integration/plugins/modules/cloud_init_data_facts.py @@ -0,0 +1,134 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# (c) 2018, René Moser <mail@renemoser.net> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: cloud_init_data_facts +short_description: Retrieve facts of cloud-init. +description: + - Gathers facts by reading the status.json and result.json of cloud-init. +version_added: 2.6 +author: René Moser (@resmo) +options: + filter: + description: + - Filter facts + choices: [ status, result ] +notes: + - See http://cloudinit.readthedocs.io/ for more information about cloud-init. +''' + +EXAMPLES = ''' +- name: Gather all facts of cloud init + cloud_init_data_facts: + register: result + +- debug: + var: result + +- name: Wait for cloud init to finish + cloud_init_data_facts: + filter: status + register: res + until: "res.cloud_init_data_facts.status.v1.stage is defined and not res.cloud_init_data_facts.status.v1.stage" + retries: 50 + delay: 5 +''' + +RETURN = ''' +--- +cloud_init_data_facts: + description: Facts of result and status. + returned: success + type: dict + sample: '{ + "status": { + "v1": { + "datasource": "DataSourceCloudStack", + "errors": [] + }, + "result": { + "v1": { + "datasource": "DataSourceCloudStack", + "init": { + "errors": [], + "finished": 1522066377.0185432, + "start": 1522066375.2648022 + }, + "init-local": { + "errors": [], + "finished": 1522066373.70919, + "start": 1522066373.4726632 + }, + "modules-config": { + "errors": [], + "finished": 1522066380.9097016, + "start": 1522066379.0011985 + }, + "modules-final": { + "errors": [], + "finished": 1522066383.56594, + "start": 1522066382.3449218 + }, + "stage": null + } + }' +''' + +import os + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_text + + +CLOUD_INIT_PATH = "/var/lib/cloud/data/" + + +def gather_cloud_init_data_facts(module): + res = { + 'cloud_init_data_facts': dict() + } + + for i in ['result', 'status']: + filter = module.params.get('filter') + if filter is None or filter == i: + res['cloud_init_data_facts'][i] = dict() + json_file = CLOUD_INIT_PATH + i + '.json' + + if os.path.exists(json_file): + f = open(json_file, 'rb') + contents = to_text(f.read(), errors='surrogate_or_strict') + f.close() + + if contents: + res['cloud_init_data_facts'][i] = module.from_json(contents) + return res + + +def main(): + module = AnsibleModule( + argument_spec=dict( + filter=dict(choices=['result', 'status']), + ), + supports_check_mode=True, + ) + + facts = gather_cloud_init_data_facts(module) + result = dict(changed=False, ansible_facts=facts, **facts) + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/cloudformation.py b/test/support/integration/plugins/modules/cloudformation.py new file mode 100644 index 00000000..cd031465 --- /dev/null +++ b/test/support/integration/plugins/modules/cloudformation.py @@ -0,0 +1,837 @@ +#!/usr/bin/python + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: cloudformation +short_description: Create or delete an AWS CloudFormation stack +description: + - Launches or updates an AWS CloudFormation stack and waits for it complete. +notes: + - CloudFormation features change often, and this module tries to keep up. That means your botocore version should be fresh. + The version listed in the requirements is the oldest version that works with the module as a whole. + Some features may require recent versions, and we do not pinpoint a minimum version for each feature. + Instead of relying on the minimum version, keep botocore up to date. AWS is always releasing features and fixing bugs. +version_added: "1.1" +options: + stack_name: + description: + - Name of the CloudFormation stack. + required: true + type: str + disable_rollback: + description: + - If a stacks fails to form, rollback will remove the stack. + default: false + type: bool + on_create_failure: + description: + - Action to take upon failure of stack creation. Incompatible with the I(disable_rollback) option. + choices: + - DO_NOTHING + - ROLLBACK + - DELETE + version_added: "2.8" + type: str + create_timeout: + description: + - The amount of time (in minutes) that can pass before the stack status becomes CREATE_FAILED + version_added: "2.6" + type: int + template_parameters: + description: + - A list of hashes of all the template variables for the stack. The value can be a string or a dict. + - Dict can be used to set additional template parameter attributes like UsePreviousValue (see example). + default: {} + type: dict + state: + description: + - If I(state=present), stack will be created. + - If I(state=present) and if stack exists and template has changed, it will be updated. + - If I(state=absent), stack will be removed. + default: present + choices: [ present, absent ] + type: str + template: + description: + - The local path of the CloudFormation template. + - This must be the full path to the file, relative to the working directory. If using roles this may look + like C(roles/cloudformation/files/cloudformation-example.json). + - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url) + must be specified (but only one of them). + - If I(state=present), the stack does exist, and neither I(template), + I(template_body) nor I(template_url) are specified, the previous template will be reused. + type: path + notification_arns: + description: + - A comma separated list of Simple Notification Service (SNS) topic ARNs to publish stack related events. + version_added: "2.0" + type: str + stack_policy: + description: + - The path of the CloudFormation stack policy. A policy cannot be removed once placed, but it can be modified. + for instance, allow all updates U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/protect-stack-resources.html#d0e9051) + version_added: "1.9" + type: str + tags: + description: + - Dictionary of tags to associate with stack and its resources during stack creation. + - Can be updated later, updating tags removes previous entries. + version_added: "1.4" + type: dict + template_url: + description: + - Location of file containing the template body. The URL must point to a template (max size 307,200 bytes) located in an + S3 bucket in the same region as the stack. + - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url) + must be specified (but only one of them). + - If I(state=present), the stack does exist, and neither I(template), I(template_body) nor I(template_url) are specified, + the previous template will be reused. + version_added: "2.0" + type: str + create_changeset: + description: + - "If stack already exists create a changeset instead of directly applying changes. See the AWS Change Sets docs + U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-changesets.html)." + - "WARNING: if the stack does not exist, it will be created without changeset. If I(state=absent), the stack will be + deleted immediately with no changeset." + type: bool + default: false + version_added: "2.4" + changeset_name: + description: + - Name given to the changeset when creating a changeset. + - Only used when I(create_changeset=true). + - By default a name prefixed with Ansible-STACKNAME is generated based on input parameters. + See the AWS Change Sets docs for more information + U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-changesets.html) + version_added: "2.4" + type: str + template_format: + description: + - This parameter is ignored since Ansible 2.3 and will be removed in Ansible 2.14. + - Templates are now passed raw to CloudFormation regardless of format. + version_added: "2.0" + type: str + role_arn: + description: + - The role that AWS CloudFormation assumes to create the stack. See the AWS CloudFormation Service Role + docs U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-iam-servicerole.html) + version_added: "2.3" + type: str + termination_protection: + description: + - Enable or disable termination protection on the stack. Only works with botocore >= 1.7.18. + type: bool + version_added: "2.5" + template_body: + description: + - Template body. Use this to pass in the actual body of the CloudFormation template. + - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url) + must be specified (but only one of them). + - If I(state=present), the stack does exist, and neither I(template), I(template_body) nor I(template_url) + are specified, the previous template will be reused. + version_added: "2.5" + type: str + events_limit: + description: + - Maximum number of CloudFormation events to fetch from a stack when creating or updating it. + default: 200 + version_added: "2.7" + type: int + backoff_delay: + description: + - Number of seconds to wait for the next retry. + default: 3 + version_added: "2.8" + type: int + required: False + backoff_max_delay: + description: + - Maximum amount of time to wait between retries. + default: 30 + version_added: "2.8" + type: int + required: False + backoff_retries: + description: + - Number of times to retry operation. + - AWS API throttling mechanism fails CloudFormation module so we have to retry a couple of times. + default: 10 + version_added: "2.8" + type: int + required: False + capabilities: + description: + - Specify capabilities that stack template contains. + - Valid values are C(CAPABILITY_IAM), C(CAPABILITY_NAMED_IAM) and C(CAPABILITY_AUTO_EXPAND). + type: list + elements: str + version_added: "2.8" + default: [ CAPABILITY_IAM, CAPABILITY_NAMED_IAM ] + +author: "James S. Martin (@jsmartin)" +extends_documentation_fragment: +- aws +- ec2 +requirements: [ boto3, botocore>=1.5.45 ] +''' + +EXAMPLES = ''' +- name: create a cloudformation stack + cloudformation: + stack_name: "ansible-cloudformation" + state: "present" + region: "us-east-1" + disable_rollback: true + template: "files/cloudformation-example.json" + template_parameters: + KeyName: "jmartin" + DiskType: "ephemeral" + InstanceType: "m1.small" + ClusterSize: 3 + tags: + Stack: "ansible-cloudformation" + +# Basic role example +- name: create a stack, specify role that cloudformation assumes + cloudformation: + stack_name: "ansible-cloudformation" + state: "present" + region: "us-east-1" + disable_rollback: true + template: "roles/cloudformation/files/cloudformation-example.json" + role_arn: 'arn:aws:iam::123456789012:role/cloudformation-iam-role' + +- name: delete a stack + cloudformation: + stack_name: "ansible-cloudformation-old" + state: "absent" + +# Create a stack, pass in template from a URL, disable rollback if stack creation fails, +# pass in some parameters to the template, provide tags for resources created +- name: create a stack, pass in the template via an URL + cloudformation: + stack_name: "ansible-cloudformation" + state: present + region: us-east-1 + disable_rollback: true + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + template_parameters: + KeyName: jmartin + DiskType: ephemeral + InstanceType: m1.small + ClusterSize: 3 + tags: + Stack: ansible-cloudformation + +# Create a stack, passing in template body using lookup of Jinja2 template, disable rollback if stack creation fails, +# pass in some parameters to the template, provide tags for resources created +- name: create a stack, pass in the template body via lookup template + cloudformation: + stack_name: "ansible-cloudformation" + state: present + region: us-east-1 + disable_rollback: true + template_body: "{{ lookup('template', 'cloudformation.j2') }}" + template_parameters: + KeyName: jmartin + DiskType: ephemeral + InstanceType: m1.small + ClusterSize: 3 + tags: + Stack: ansible-cloudformation + +# Pass a template parameter which uses CloudFormation's UsePreviousValue attribute +# When use_previous_value is set to True, the given value will be ignored and +# CloudFormation will use the value from a previously submitted template. +# If use_previous_value is set to False (default) the given value is used. +- cloudformation: + stack_name: "ansible-cloudformation" + state: "present" + region: "us-east-1" + template: "files/cloudformation-example.json" + template_parameters: + DBSnapshotIdentifier: + use_previous_value: True + value: arn:aws:rds:es-east-1:000000000000:snapshot:rds:my-db-snapshot + DBName: + use_previous_value: True + tags: + Stack: "ansible-cloudformation" + +# Enable termination protection on a stack. +# If the stack already exists, this will update its termination protection +- name: enable termination protection during stack creation + cloudformation: + stack_name: my_stack + state: present + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + termination_protection: yes + +# Configure TimeoutInMinutes before the stack status becomes CREATE_FAILED +# In this case, if disable_rollback is not set or is set to false, the stack will be rolled back. +- name: enable termination protection during stack creation + cloudformation: + stack_name: my_stack + state: present + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + create_timeout: 5 + +# Configure rollback behaviour on the unsuccessful creation of a stack allowing +# CloudFormation to clean up, or do nothing in the event of an unsuccessful +# deployment +# In this case, if on_create_failure is set to "DELETE", it will clean up the stack if +# it fails to create +- name: create stack which will delete on creation failure + cloudformation: + stack_name: my_stack + state: present + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + on_create_failure: DELETE +''' + +RETURN = ''' +events: + type: list + description: Most recent events in CloudFormation's event log. This may be from a previous run in some cases. + returned: always + sample: ["StackEvent AWS::CloudFormation::Stack stackname UPDATE_COMPLETE", "StackEvent AWS::CloudFormation::Stack stackname UPDATE_COMPLETE_CLEANUP_IN_PROGRESS"] +log: + description: Debugging logs. Useful when modifying or finding an error. + returned: always + type: list + sample: ["updating stack"] +change_set_id: + description: The ID of the stack change set if one was created + returned: I(state=present) and I(create_changeset=true) + type: str + sample: "arn:aws:cloudformation:us-east-1:012345678901:changeSet/Ansible-StackName-f4496805bd1b2be824d1e315c6884247ede41eb0" +stack_resources: + description: AWS stack resources and their status. List of dictionaries, one dict per resource. + returned: state == present + type: list + sample: [ + { + "last_updated_time": "2016-10-11T19:40:14.979000+00:00", + "logical_resource_id": "CFTestSg", + "physical_resource_id": "cloudformation2-CFTestSg-16UQ4CYQ57O9F", + "resource_type": "AWS::EC2::SecurityGroup", + "status": "UPDATE_COMPLETE", + "status_reason": null + } + ] +stack_outputs: + type: dict + description: A key:value dictionary of all the stack outputs currently defined. If there are no stack outputs, it is an empty dictionary. + returned: state == present + sample: {"MySg": "AnsibleModuleTestYAML-CFTestSg-C8UVS567B6NS"} +''' # NOQA + +import json +import time +import uuid +import traceback +from hashlib import sha1 + +try: + import boto3 + import botocore + HAS_BOTO3 = True +except ImportError: + HAS_BOTO3 = False + +from ansible.module_utils.ec2 import ansible_dict_to_boto3_tag_list, AWSRetry, boto3_conn, boto_exception, ec2_argument_spec, get_aws_connection_info +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_bytes, to_native + + +def get_stack_events(cfn, stack_name, events_limit, token_filter=None): + '''This event data was never correct, it worked as a side effect. So the v2.3 format is different.''' + ret = {'events': [], 'log': []} + + try: + pg = cfn.get_paginator( + 'describe_stack_events' + ).paginate( + StackName=stack_name, + PaginationConfig={'MaxItems': events_limit} + ) + if token_filter is not None: + events = list(pg.search( + "StackEvents[?ClientRequestToken == '{0}']".format(token_filter) + )) + else: + events = list(pg.search("StackEvents[*]")) + except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err: + error_msg = boto_exception(err) + if 'does not exist' in error_msg: + # missing stack, don't bail. + ret['log'].append('Stack does not exist.') + return ret + ret['log'].append('Unknown error: ' + str(error_msg)) + return ret + + for e in events: + eventline = 'StackEvent {ResourceType} {LogicalResourceId} {ResourceStatus}'.format(**e) + ret['events'].append(eventline) + + if e['ResourceStatus'].endswith('FAILED'): + failline = '{ResourceType} {LogicalResourceId} {ResourceStatus}: {ResourceStatusReason}'.format(**e) + ret['log'].append(failline) + + return ret + + +def create_stack(module, stack_params, cfn, events_limit): + if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params: + module.fail_json(msg="Either 'template', 'template_body' or 'template_url' is required when the stack does not exist.") + + # 'DisableRollback', 'TimeoutInMinutes', 'EnableTerminationProtection' and + # 'OnFailure' only apply on creation, not update. + if module.params.get('on_create_failure') is not None: + stack_params['OnFailure'] = module.params['on_create_failure'] + else: + stack_params['DisableRollback'] = module.params['disable_rollback'] + + if module.params.get('create_timeout') is not None: + stack_params['TimeoutInMinutes'] = module.params['create_timeout'] + if module.params.get('termination_protection') is not None: + if boto_supports_termination_protection(cfn): + stack_params['EnableTerminationProtection'] = bool(module.params.get('termination_protection')) + else: + module.fail_json(msg="termination_protection parameter requires botocore >= 1.7.18") + + try: + response = cfn.create_stack(**stack_params) + # Use stack ID to follow stack state in case of on_create_failure = DELETE + result = stack_operation(cfn, response['StackId'], 'CREATE', events_limit, stack_params.get('ClientRequestToken', None)) + except Exception as err: + error_msg = boto_exception(err) + module.fail_json(msg="Failed to create stack {0}: {1}.".format(stack_params.get('StackName'), error_msg), exception=traceback.format_exc()) + if not result: + module.fail_json(msg="empty result") + return result + + +def list_changesets(cfn, stack_name): + res = cfn.list_change_sets(StackName=stack_name) + return [cs['ChangeSetName'] for cs in res['Summaries']] + + +def create_changeset(module, stack_params, cfn, events_limit): + if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params: + module.fail_json(msg="Either 'template' or 'template_url' is required.") + if module.params['changeset_name'] is not None: + stack_params['ChangeSetName'] = module.params['changeset_name'] + + # changesets don't accept ClientRequestToken parameters + stack_params.pop('ClientRequestToken', None) + + try: + changeset_name = build_changeset_name(stack_params) + stack_params['ChangeSetName'] = changeset_name + + # Determine if this changeset already exists + pending_changesets = list_changesets(cfn, stack_params['StackName']) + if changeset_name in pending_changesets: + warning = 'WARNING: %d pending changeset(s) exist(s) for this stack!' % len(pending_changesets) + result = dict(changed=False, output='ChangeSet %s already exists.' % changeset_name, warnings=[warning]) + else: + cs = cfn.create_change_set(**stack_params) + # Make sure we don't enter an infinite loop + time_end = time.time() + 600 + while time.time() < time_end: + try: + newcs = cfn.describe_change_set(ChangeSetName=cs['Id']) + except botocore.exceptions.BotoCoreError as err: + error_msg = boto_exception(err) + module.fail_json(msg=error_msg) + if newcs['Status'] == 'CREATE_PENDING' or newcs['Status'] == 'CREATE_IN_PROGRESS': + time.sleep(1) + elif newcs['Status'] == 'FAILED' and "The submitted information didn't contain changes" in newcs['StatusReason']: + cfn.delete_change_set(ChangeSetName=cs['Id']) + result = dict(changed=False, + output='The created Change Set did not contain any changes to this stack and was deleted.') + # a failed change set does not trigger any stack events so we just want to + # skip any further processing of result and just return it directly + return result + else: + break + # Lets not hog the cpu/spam the AWS API + time.sleep(1) + result = stack_operation(cfn, stack_params['StackName'], 'CREATE_CHANGESET', events_limit) + result['change_set_id'] = cs['Id'] + result['warnings'] = ['Created changeset named %s for stack %s' % (changeset_name, stack_params['StackName']), + 'You can execute it using: aws cloudformation execute-change-set --change-set-name %s' % cs['Id'], + 'NOTE that dependencies on this stack might fail due to pending changes!'] + except Exception as err: + error_msg = boto_exception(err) + if 'No updates are to be performed.' in error_msg: + result = dict(changed=False, output='Stack is already up-to-date.') + else: + module.fail_json(msg="Failed to create change set: {0}".format(error_msg), exception=traceback.format_exc()) + + if not result: + module.fail_json(msg="empty result") + return result + + +def update_stack(module, stack_params, cfn, events_limit): + if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params: + stack_params['UsePreviousTemplate'] = True + + # if the state is present and the stack already exists, we try to update it. + # AWS will tell us if the stack template and parameters are the same and + # don't need to be updated. + try: + cfn.update_stack(**stack_params) + result = stack_operation(cfn, stack_params['StackName'], 'UPDATE', events_limit, stack_params.get('ClientRequestToken', None)) + except Exception as err: + error_msg = boto_exception(err) + if 'No updates are to be performed.' in error_msg: + result = dict(changed=False, output='Stack is already up-to-date.') + else: + module.fail_json(msg="Failed to update stack {0}: {1}".format(stack_params.get('StackName'), error_msg), exception=traceback.format_exc()) + if not result: + module.fail_json(msg="empty result") + return result + + +def update_termination_protection(module, cfn, stack_name, desired_termination_protection_state): + '''updates termination protection of a stack''' + if not boto_supports_termination_protection(cfn): + module.fail_json(msg="termination_protection parameter requires botocore >= 1.7.18") + stack = get_stack_facts(cfn, stack_name) + if stack: + if stack['EnableTerminationProtection'] is not desired_termination_protection_state: + try: + cfn.update_termination_protection( + EnableTerminationProtection=desired_termination_protection_state, + StackName=stack_name) + except botocore.exceptions.ClientError as e: + module.fail_json(msg=boto_exception(e), exception=traceback.format_exc()) + + +def boto_supports_termination_protection(cfn): + '''termination protection was added in botocore 1.7.18''' + return hasattr(cfn, "update_termination_protection") + + +def stack_operation(cfn, stack_name, operation, events_limit, op_token=None): + '''gets the status of a stack while it is created/updated/deleted''' + existed = [] + while True: + try: + stack = get_stack_facts(cfn, stack_name) + existed.append('yes') + except Exception: + # If the stack previously existed, and now can't be found then it's + # been deleted successfully. + if 'yes' in existed or operation == 'DELETE': # stacks may delete fast, look in a few ways. + ret = get_stack_events(cfn, stack_name, events_limit, op_token) + ret.update({'changed': True, 'output': 'Stack Deleted'}) + return ret + else: + return {'changed': True, 'failed': True, 'output': 'Stack Not Found', 'exception': traceback.format_exc()} + ret = get_stack_events(cfn, stack_name, events_limit, op_token) + if not stack: + if 'yes' in existed or operation == 'DELETE': # stacks may delete fast, look in a few ways. + ret = get_stack_events(cfn, stack_name, events_limit, op_token) + ret.update({'changed': True, 'output': 'Stack Deleted'}) + return ret + else: + ret.update({'changed': False, 'failed': True, 'output': 'Stack not found.'}) + return ret + # it covers ROLLBACK_COMPLETE and UPDATE_ROLLBACK_COMPLETE + # Possible states: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-describing-stacks.html#w1ab2c15c17c21c13 + elif stack['StackStatus'].endswith('ROLLBACK_COMPLETE') and operation != 'CREATE_CHANGESET': + ret.update({'changed': True, 'failed': True, 'output': 'Problem with %s. Rollback complete' % operation}) + return ret + elif stack['StackStatus'] == 'DELETE_COMPLETE' and operation == 'CREATE': + ret.update({'changed': True, 'failed': True, 'output': 'Stack create failed. Delete complete.'}) + return ret + # note the ordering of ROLLBACK_COMPLETE, DELETE_COMPLETE, and COMPLETE, because otherwise COMPLETE will match all cases. + elif stack['StackStatus'].endswith('_COMPLETE'): + ret.update({'changed': True, 'output': 'Stack %s complete' % operation}) + return ret + elif stack['StackStatus'].endswith('_ROLLBACK_FAILED'): + ret.update({'changed': True, 'failed': True, 'output': 'Stack %s rollback failed' % operation}) + return ret + # note the ordering of ROLLBACK_FAILED and FAILED, because otherwise FAILED will match both cases. + elif stack['StackStatus'].endswith('_FAILED'): + ret.update({'changed': True, 'failed': True, 'output': 'Stack %s failed' % operation}) + return ret + else: + # this can loop forever :/ + time.sleep(5) + return {'failed': True, 'output': 'Failed for unknown reasons.'} + + +def build_changeset_name(stack_params): + if 'ChangeSetName' in stack_params: + return stack_params['ChangeSetName'] + + json_params = json.dumps(stack_params, sort_keys=True) + + return 'Ansible-{0}-{1}'.format( + stack_params['StackName'], + sha1(to_bytes(json_params, errors='surrogate_or_strict')).hexdigest() + ) + + +def check_mode_changeset(module, stack_params, cfn): + """Create a change set, describe it and delete it before returning check mode outputs.""" + stack_params['ChangeSetName'] = build_changeset_name(stack_params) + # changesets don't accept ClientRequestToken parameters + stack_params.pop('ClientRequestToken', None) + + try: + change_set = cfn.create_change_set(**stack_params) + for i in range(60): # total time 5 min + description = cfn.describe_change_set(ChangeSetName=change_set['Id']) + if description['Status'] in ('CREATE_COMPLETE', 'FAILED'): + break + time.sleep(5) + else: + # if the changeset doesn't finish in 5 mins, this `else` will trigger and fail + module.fail_json(msg="Failed to create change set %s" % stack_params['ChangeSetName']) + + cfn.delete_change_set(ChangeSetName=change_set['Id']) + + reason = description.get('StatusReason') + + if description['Status'] == 'FAILED' and "didn't contain changes" in description['StatusReason']: + return {'changed': False, 'msg': reason, 'meta': description['StatusReason']} + return {'changed': True, 'msg': reason, 'meta': description['Changes']} + + except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err: + error_msg = boto_exception(err) + module.fail_json(msg=error_msg, exception=traceback.format_exc()) + + +def get_stack_facts(cfn, stack_name): + try: + stack_response = cfn.describe_stacks(StackName=stack_name) + stack_info = stack_response['Stacks'][0] + except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err: + error_msg = boto_exception(err) + if 'does not exist' in error_msg: + # missing stack, don't bail. + return None + + # other error, bail. + raise err + + if stack_response and stack_response.get('Stacks', None): + stacks = stack_response['Stacks'] + if len(stacks): + stack_info = stacks[0] + + return stack_info + + +def main(): + argument_spec = ec2_argument_spec() + argument_spec.update(dict( + stack_name=dict(required=True), + template_parameters=dict(required=False, type='dict', default={}), + state=dict(default='present', choices=['present', 'absent']), + template=dict(default=None, required=False, type='path'), + notification_arns=dict(default=None, required=False), + stack_policy=dict(default=None, required=False), + disable_rollback=dict(default=False, type='bool'), + on_create_failure=dict(default=None, required=False, choices=['DO_NOTHING', 'ROLLBACK', 'DELETE']), + create_timeout=dict(default=None, type='int'), + template_url=dict(default=None, required=False), + template_body=dict(default=None, required=False), + template_format=dict(removed_in_version='2.14'), + create_changeset=dict(default=False, type='bool'), + changeset_name=dict(default=None, required=False), + role_arn=dict(default=None, required=False), + tags=dict(default=None, type='dict'), + termination_protection=dict(default=None, type='bool'), + events_limit=dict(default=200, type='int'), + backoff_retries=dict(type='int', default=10, required=False), + backoff_delay=dict(type='int', default=3, required=False), + backoff_max_delay=dict(type='int', default=30, required=False), + capabilities=dict(type='list', default=['CAPABILITY_IAM', 'CAPABILITY_NAMED_IAM']) + ) + ) + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=[['template_url', 'template', 'template_body'], + ['disable_rollback', 'on_create_failure']], + supports_check_mode=True + ) + if not HAS_BOTO3: + module.fail_json(msg='boto3 and botocore are required for this module') + + invalid_capabilities = [] + user_capabilities = module.params.get('capabilities') + for user_cap in user_capabilities: + if user_cap not in ['CAPABILITY_IAM', 'CAPABILITY_NAMED_IAM', 'CAPABILITY_AUTO_EXPAND']: + invalid_capabilities.append(user_cap) + + if invalid_capabilities: + module.fail_json(msg="Specified capabilities are invalid : %r," + " please check documentation for valid capabilities" % invalid_capabilities) + + # collect the parameters that are passed to boto3. Keeps us from having so many scalars floating around. + stack_params = { + 'Capabilities': user_capabilities, + 'ClientRequestToken': to_native(uuid.uuid4()), + } + state = module.params['state'] + stack_params['StackName'] = module.params['stack_name'] + + if module.params['template'] is not None: + with open(module.params['template'], 'r') as template_fh: + stack_params['TemplateBody'] = template_fh.read() + elif module.params['template_body'] is not None: + stack_params['TemplateBody'] = module.params['template_body'] + elif module.params['template_url'] is not None: + stack_params['TemplateURL'] = module.params['template_url'] + + if module.params.get('notification_arns'): + stack_params['NotificationARNs'] = module.params['notification_arns'].split(',') + else: + stack_params['NotificationARNs'] = [] + + # can't check the policy when verifying. + if module.params['stack_policy'] is not None and not module.check_mode and not module.params['create_changeset']: + with open(module.params['stack_policy'], 'r') as stack_policy_fh: + stack_params['StackPolicyBody'] = stack_policy_fh.read() + + template_parameters = module.params['template_parameters'] + + stack_params['Parameters'] = [] + for k, v in template_parameters.items(): + if isinstance(v, dict): + # set parameter based on a dict to allow additional CFN Parameter Attributes + param = dict(ParameterKey=k) + + if 'value' in v: + param['ParameterValue'] = str(v['value']) + + if 'use_previous_value' in v and bool(v['use_previous_value']): + param['UsePreviousValue'] = True + param.pop('ParameterValue', None) + + stack_params['Parameters'].append(param) + else: + # allow default k/v configuration to set a template parameter + stack_params['Parameters'].append({'ParameterKey': k, 'ParameterValue': str(v)}) + + if isinstance(module.params.get('tags'), dict): + stack_params['Tags'] = ansible_dict_to_boto3_tag_list(module.params['tags']) + + if module.params.get('role_arn'): + stack_params['RoleARN'] = module.params['role_arn'] + + result = {} + + try: + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True) + cfn = boto3_conn(module, conn_type='client', resource='cloudformation', region=region, endpoint=ec2_url, **aws_connect_kwargs) + except botocore.exceptions.NoCredentialsError as e: + module.fail_json(msg=boto_exception(e)) + + # Wrap the cloudformation client methods that this module uses with + # automatic backoff / retry for throttling error codes + backoff_wrapper = AWSRetry.jittered_backoff( + retries=module.params.get('backoff_retries'), + delay=module.params.get('backoff_delay'), + max_delay=module.params.get('backoff_max_delay') + ) + cfn.describe_stack_events = backoff_wrapper(cfn.describe_stack_events) + cfn.create_stack = backoff_wrapper(cfn.create_stack) + cfn.list_change_sets = backoff_wrapper(cfn.list_change_sets) + cfn.create_change_set = backoff_wrapper(cfn.create_change_set) + cfn.update_stack = backoff_wrapper(cfn.update_stack) + cfn.describe_stacks = backoff_wrapper(cfn.describe_stacks) + cfn.list_stack_resources = backoff_wrapper(cfn.list_stack_resources) + cfn.delete_stack = backoff_wrapper(cfn.delete_stack) + if boto_supports_termination_protection(cfn): + cfn.update_termination_protection = backoff_wrapper(cfn.update_termination_protection) + + stack_info = get_stack_facts(cfn, stack_params['StackName']) + + if module.check_mode: + if state == 'absent' and stack_info: + module.exit_json(changed=True, msg='Stack would be deleted', meta=[]) + elif state == 'absent' and not stack_info: + module.exit_json(changed=False, msg='Stack doesn\'t exist', meta=[]) + elif state == 'present' and not stack_info: + module.exit_json(changed=True, msg='New stack would be created', meta=[]) + else: + module.exit_json(**check_mode_changeset(module, stack_params, cfn)) + + if state == 'present': + if not stack_info: + result = create_stack(module, stack_params, cfn, module.params.get('events_limit')) + elif module.params.get('create_changeset'): + result = create_changeset(module, stack_params, cfn, module.params.get('events_limit')) + else: + if module.params.get('termination_protection') is not None: + update_termination_protection(module, cfn, stack_params['StackName'], + bool(module.params.get('termination_protection'))) + result = update_stack(module, stack_params, cfn, module.params.get('events_limit')) + + # format the stack output + + stack = get_stack_facts(cfn, stack_params['StackName']) + if stack is not None: + if result.get('stack_outputs') is None: + # always define stack_outputs, but it may be empty + result['stack_outputs'] = {} + for output in stack.get('Outputs', []): + result['stack_outputs'][output['OutputKey']] = output['OutputValue'] + stack_resources = [] + reslist = cfn.list_stack_resources(StackName=stack_params['StackName']) + for res in reslist.get('StackResourceSummaries', []): + stack_resources.append({ + "logical_resource_id": res['LogicalResourceId'], + "physical_resource_id": res.get('PhysicalResourceId', ''), + "resource_type": res['ResourceType'], + "last_updated_time": res['LastUpdatedTimestamp'], + "status": res['ResourceStatus'], + "status_reason": res.get('ResourceStatusReason') # can be blank, apparently + }) + result['stack_resources'] = stack_resources + + elif state == 'absent': + # absent state is different because of the way delete_stack works. + # problem is it it doesn't give an error if stack isn't found + # so must describe the stack first + + try: + stack = get_stack_facts(cfn, stack_params['StackName']) + if not stack: + result = {'changed': False, 'output': 'Stack not found.'} + else: + if stack_params.get('RoleARN') is None: + cfn.delete_stack(StackName=stack_params['StackName']) + else: + cfn.delete_stack(StackName=stack_params['StackName'], RoleARN=stack_params['RoleARN']) + result = stack_operation(cfn, stack_params['StackName'], 'DELETE', module.params.get('events_limit'), + stack_params.get('ClientRequestToken', None)) + except Exception as err: + module.fail_json(msg=boto_exception(err), exception=traceback.format_exc()) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/cloudformation_info.py b/test/support/integration/plugins/modules/cloudformation_info.py new file mode 100644 index 00000000..ee2e5c17 --- /dev/null +++ b/test/support/integration/plugins/modules/cloudformation_info.py @@ -0,0 +1,355 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: cloudformation_info +short_description: Obtain information about an AWS CloudFormation stack +description: + - Gets information about an AWS CloudFormation stack. + - This module was called C(cloudformation_facts) before Ansible 2.9, returning C(ansible_facts). + Note that the M(cloudformation_info) module no longer returns C(ansible_facts)! +requirements: + - boto3 >= 1.0.0 + - python >= 2.6 +version_added: "2.2" +author: + - Justin Menga (@jmenga) + - Kevin Coming (@waffie1) +options: + stack_name: + description: + - The name or id of the CloudFormation stack. Gathers information on all stacks by default. + type: str + all_facts: + description: + - Get all stack information for the stack. + type: bool + default: false + stack_events: + description: + - Get stack events for the stack. + type: bool + default: false + stack_template: + description: + - Get stack template body for the stack. + type: bool + default: false + stack_resources: + description: + - Get stack resources for the stack. + type: bool + default: false + stack_policy: + description: + - Get stack policy for the stack. + type: bool + default: false + stack_change_sets: + description: + - Get stack change sets for the stack + type: bool + default: false + version_added: '2.10' +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +# Get summary information about a stack +- cloudformation_info: + stack_name: my-cloudformation-stack + register: output + +- debug: + msg: "{{ output['cloudformation']['my-cloudformation-stack'] }}" + +# When the module is called as cloudformation_facts, return values are published +# in ansible_facts['cloudformation'][<stack_name>] and can be used as follows. +# Note that this is deprecated and will stop working in Ansible 2.13. + +- cloudformation_facts: + stack_name: my-cloudformation-stack + +- debug: + msg: "{{ ansible_facts['cloudformation']['my-cloudformation-stack'] }}" + +# Get stack outputs, when you have the stack name available as a fact +- set_fact: + stack_name: my-awesome-stack + +- cloudformation_info: + stack_name: "{{ stack_name }}" + register: my_stack + +- debug: + msg: "{{ my_stack.cloudformation[stack_name].stack_outputs }}" + +# Get all stack information about a stack +- cloudformation_info: + stack_name: my-cloudformation-stack + all_facts: true + +# Get stack resource and stack policy information about a stack +- cloudformation_info: + stack_name: my-cloudformation-stack + stack_resources: true + stack_policy: true + +# Fail if the stack doesn't exist +- name: try to get facts about a stack but fail if it doesn't exist + cloudformation_info: + stack_name: nonexistent-stack + all_facts: yes + failed_when: cloudformation['nonexistent-stack'] is undefined +''' + +RETURN = ''' +stack_description: + description: Summary facts about the stack + returned: if the stack exists + type: dict +stack_outputs: + description: Dictionary of stack outputs keyed by the value of each output 'OutputKey' parameter and corresponding value of each + output 'OutputValue' parameter + returned: if the stack exists + type: dict + sample: + ApplicationDatabaseName: dazvlpr01xj55a.ap-southeast-2.rds.amazonaws.com +stack_parameters: + description: Dictionary of stack parameters keyed by the value of each parameter 'ParameterKey' parameter and corresponding value of + each parameter 'ParameterValue' parameter + returned: if the stack exists + type: dict + sample: + DatabaseEngine: mysql + DatabasePassword: "***" +stack_events: + description: All stack events for the stack + returned: only if all_facts or stack_events is true and the stack exists + type: list +stack_policy: + description: Describes the stack policy for the stack + returned: only if all_facts or stack_policy is true and the stack exists + type: dict +stack_template: + description: Describes the stack template for the stack + returned: only if all_facts or stack_template is true and the stack exists + type: dict +stack_resource_list: + description: Describes stack resources for the stack + returned: only if all_facts or stack_resourses is true and the stack exists + type: list +stack_resources: + description: Dictionary of stack resources keyed by the value of each resource 'LogicalResourceId' parameter and corresponding value of each + resource 'PhysicalResourceId' parameter + returned: only if all_facts or stack_resourses is true and the stack exists + type: dict + sample: + AutoScalingGroup: "dev-someapp-AutoscalingGroup-1SKEXXBCAN0S7" + AutoScalingSecurityGroup: "sg-abcd1234" + ApplicationDatabase: "dazvlpr01xj55a" +stack_change_sets: + description: A list of stack change sets. Each item in the list represents the details of a specific changeset + + returned: only if all_facts or stack_change_sets is true and the stack exists + type: list +''' + +import json +import traceback + +from functools import partial +from ansible.module_utils._text import to_native +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import (camel_dict_to_snake_dict, AWSRetry, boto3_tag_list_to_ansible_dict) + +try: + import botocore +except ImportError: + pass # handled by AnsibleAWSModule + + +class CloudFormationServiceManager: + """Handles CloudFormation Services""" + + def __init__(self, module): + self.module = module + self.client = module.client('cloudformation') + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def describe_stacks_with_backoff(self, **kwargs): + paginator = self.client.get_paginator('describe_stacks') + return paginator.paginate(**kwargs).build_full_result()['Stacks'] + + def describe_stacks(self, stack_name=None): + try: + kwargs = {'StackName': stack_name} if stack_name else {} + response = self.describe_stacks_with_backoff(**kwargs) + if response is not None: + return response + self.module.fail_json(msg="Error describing stack(s) - an empty response was returned") + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + if 'does not exist' in e.response['Error']['Message']: + # missing stack, don't bail. + return {} + self.module.fail_json_aws(e, msg="Error describing stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def list_stack_resources_with_backoff(self, stack_name): + paginator = self.client.get_paginator('list_stack_resources') + return paginator.paginate(StackName=stack_name).build_full_result()['StackResourceSummaries'] + + def list_stack_resources(self, stack_name): + try: + return self.list_stack_resources_with_backoff(stack_name) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error listing stack resources for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def describe_stack_events_with_backoff(self, stack_name): + paginator = self.client.get_paginator('describe_stack_events') + return paginator.paginate(StackName=stack_name).build_full_result()['StackEvents'] + + def describe_stack_events(self, stack_name): + try: + return self.describe_stack_events_with_backoff(stack_name) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error listing stack events for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def list_stack_change_sets_with_backoff(self, stack_name): + paginator = self.client.get_paginator('list_change_sets') + return paginator.paginate(StackName=stack_name).build_full_result()['Summaries'] + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def describe_stack_change_set_with_backoff(self, **kwargs): + paginator = self.client.get_paginator('describe_change_set') + return paginator.paginate(**kwargs).build_full_result() + + def describe_stack_change_sets(self, stack_name): + changes = [] + try: + change_sets = self.list_stack_change_sets_with_backoff(stack_name) + for item in change_sets: + changes.append(self.describe_stack_change_set_with_backoff( + StackName=stack_name, + ChangeSetName=item['ChangeSetName'])) + return changes + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error describing stack change sets for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def get_stack_policy_with_backoff(self, stack_name): + return self.client.get_stack_policy(StackName=stack_name) + + def get_stack_policy(self, stack_name): + try: + response = self.get_stack_policy_with_backoff(stack_name) + stack_policy = response.get('StackPolicyBody') + if stack_policy: + return json.loads(stack_policy) + return dict() + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error getting stack policy for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def get_template_with_backoff(self, stack_name): + return self.client.get_template(StackName=stack_name) + + def get_template(self, stack_name): + try: + response = self.get_template_with_backoff(stack_name) + return response.get('TemplateBody') + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error getting stack template for stack " + stack_name) + + +def to_dict(items, key, value): + ''' Transforms a list of items to a Key/Value dictionary ''' + if items: + return dict(zip([i.get(key) for i in items], [i.get(value) for i in items])) + else: + return dict() + + +def main(): + argument_spec = dict( + stack_name=dict(), + all_facts=dict(required=False, default=False, type='bool'), + stack_policy=dict(required=False, default=False, type='bool'), + stack_events=dict(required=False, default=False, type='bool'), + stack_resources=dict(required=False, default=False, type='bool'), + stack_template=dict(required=False, default=False, type='bool'), + stack_change_sets=dict(required=False, default=False, type='bool'), + ) + module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True) + + is_old_facts = module._name == 'cloudformation_facts' + if is_old_facts: + module.deprecate("The 'cloudformation_facts' module has been renamed to 'cloudformation_info', " + "and the renamed one no longer returns ansible_facts", + version='2.13', collection_name='ansible.builtin') + + service_mgr = CloudFormationServiceManager(module) + + if is_old_facts: + result = {'ansible_facts': {'cloudformation': {}}} + else: + result = {'cloudformation': {}} + + for stack_description in service_mgr.describe_stacks(module.params.get('stack_name')): + facts = {'stack_description': stack_description} + stack_name = stack_description.get('StackName') + + # Create stack output and stack parameter dictionaries + if facts['stack_description']: + facts['stack_outputs'] = to_dict(facts['stack_description'].get('Outputs'), 'OutputKey', 'OutputValue') + facts['stack_parameters'] = to_dict(facts['stack_description'].get('Parameters'), + 'ParameterKey', 'ParameterValue') + facts['stack_tags'] = boto3_tag_list_to_ansible_dict(facts['stack_description'].get('Tags')) + + # Create optional stack outputs + all_facts = module.params.get('all_facts') + if all_facts or module.params.get('stack_resources'): + facts['stack_resource_list'] = service_mgr.list_stack_resources(stack_name) + facts['stack_resources'] = to_dict(facts.get('stack_resource_list'), + 'LogicalResourceId', 'PhysicalResourceId') + if all_facts or module.params.get('stack_template'): + facts['stack_template'] = service_mgr.get_template(stack_name) + if all_facts or module.params.get('stack_policy'): + facts['stack_policy'] = service_mgr.get_stack_policy(stack_name) + if all_facts or module.params.get('stack_events'): + facts['stack_events'] = service_mgr.describe_stack_events(stack_name) + if all_facts or module.params.get('stack_change_sets'): + facts['stack_change_sets'] = service_mgr.describe_stack_change_sets(stack_name) + + if is_old_facts: + result['ansible_facts']['cloudformation'][stack_name] = facts + else: + result['cloudformation'][stack_name] = camel_dict_to_snake_dict(facts, ignore_list=('stack_outputs', + 'stack_parameters', + 'stack_policy', + 'stack_resources', + 'stack_tags', + 'stack_template')) + + module.exit_json(changed=False, **result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/deploy_helper.py b/test/support/integration/plugins/modules/deploy_helper.py new file mode 100644 index 00000000..38594dde --- /dev/null +++ b/test/support/integration/plugins/modules/deploy_helper.py @@ -0,0 +1,521 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2014, Jasper N. Brouwer <jasper@nerdsweide.nl> +# (c) 2014, Ramon de la Fuente <ramon@delafuente.nl> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: deploy_helper +version_added: "2.0" +author: "Ramon de la Fuente (@ramondelafuente)" +short_description: Manages some of the steps common in deploying projects. +description: + - The Deploy Helper manages some of the steps common in deploying software. + It creates a folder structure, manages a symlink for the current release + and cleans up old releases. + - "Running it with the C(state=query) or C(state=present) will return the C(deploy_helper) fact. + C(project_path), whatever you set in the path parameter, + C(current_path), the path to the symlink that points to the active release, + C(releases_path), the path to the folder to keep releases in, + C(shared_path), the path to the folder to keep shared resources in, + C(unfinished_filename), the file to check for to recognize unfinished builds, + C(previous_release), the release the 'current' symlink is pointing to, + C(previous_release_path), the full path to the 'current' symlink target, + C(new_release), either the 'release' parameter or a generated timestamp, + C(new_release_path), the path to the new release folder (not created by the module)." + +options: + path: + required: True + aliases: ['dest'] + description: + - the root path of the project. Alias I(dest). + Returned in the C(deploy_helper.project_path) fact. + + state: + description: + - the state of the project. + C(query) will only gather facts, + C(present) will create the project I(root) folder, and in it the I(releases) and I(shared) folders, + C(finalize) will remove the unfinished_filename file, create a symlink to the newly + deployed release and optionally clean old releases, + C(clean) will remove failed & old releases, + C(absent) will remove the project folder (synonymous to the M(file) module with C(state=absent)) + choices: [ present, finalize, absent, clean, query ] + default: present + + release: + description: + - the release version that is being deployed. Defaults to a timestamp format %Y%m%d%H%M%S (i.e. '20141119223359'). + This parameter is optional during C(state=present), but needs to be set explicitly for C(state=finalize). + You can use the generated fact C(release={{ deploy_helper.new_release }}). + + releases_path: + description: + - the name of the folder that will hold the releases. This can be relative to C(path) or absolute. + Returned in the C(deploy_helper.releases_path) fact. + default: releases + + shared_path: + description: + - the name of the folder that will hold the shared resources. This can be relative to C(path) or absolute. + If this is set to an empty string, no shared folder will be created. + Returned in the C(deploy_helper.shared_path) fact. + default: shared + + current_path: + description: + - the name of the symlink that is created when the deploy is finalized. Used in C(finalize) and C(clean). + Returned in the C(deploy_helper.current_path) fact. + default: current + + unfinished_filename: + description: + - the name of the file that indicates a deploy has not finished. All folders in the releases_path that + contain this file will be deleted on C(state=finalize) with clean=True, or C(state=clean). This file is + automatically deleted from the I(new_release_path) during C(state=finalize). + default: DEPLOY_UNFINISHED + + clean: + description: + - Whether to run the clean procedure in case of C(state=finalize). + type: bool + default: 'yes' + + keep_releases: + description: + - the number of old releases to keep when cleaning. Used in C(finalize) and C(clean). Any unfinished builds + will be deleted first, so only correct releases will count. The current version will not count. + default: 5 + +notes: + - Facts are only returned for C(state=query) and C(state=present). If you use both, you should pass any overridden + parameters to both calls, otherwise the second call will overwrite the facts of the first one. + - When using C(state=clean), the releases are ordered by I(creation date). You should be able to switch to a + new naming strategy without problems. + - Because of the default behaviour of generating the I(new_release) fact, this module will not be idempotent + unless you pass your own release name with C(release). Due to the nature of deploying software, this should not + be much of a problem. +''' + +EXAMPLES = ''' + +# General explanation, starting with an example folder structure for a project: + +# root: +# releases: +# - 20140415234508 +# - 20140415235146 +# - 20140416082818 +# +# shared: +# - sessions +# - uploads +# +# current: releases/20140416082818 + + +# The 'releases' folder holds all the available releases. A release is a complete build of the application being +# deployed. This can be a clone of a repository for example, or a sync of a local folder on your filesystem. +# Having timestamped folders is one way of having distinct releases, but you could choose your own strategy like +# git tags or commit hashes. +# +# During a deploy, a new folder should be created in the releases folder and any build steps required should be +# performed. Once the new build is ready, the deploy procedure is 'finalized' by replacing the 'current' symlink +# with a link to this build. +# +# The 'shared' folder holds any resource that is shared between releases. Examples of this are web-server +# session files, or files uploaded by users of your application. It's quite common to have symlinks from a release +# folder pointing to a shared/subfolder, and creating these links would be automated as part of the build steps. +# +# The 'current' symlink points to one of the releases. Probably the latest one, unless a deploy is in progress. +# The web-server's root for the project will go through this symlink, so the 'downtime' when switching to a new +# release is reduced to the time it takes to switch the link. +# +# To distinguish between successful builds and unfinished ones, a file can be placed in the folder of the release +# that is currently in progress. The existence of this file will mark it as unfinished, and allow an automated +# procedure to remove it during cleanup. + + +# Typical usage +- name: Initialize the deploy root and gather facts + deploy_helper: + path: /path/to/root +- name: Clone the project to the new release folder + git: + repo: git://foosball.example.org/path/to/repo.git + dest: '{{ deploy_helper.new_release_path }}' + version: v1.1.1 +- name: Add an unfinished file, to allow cleanup on successful finalize + file: + path: '{{ deploy_helper.new_release_path }}/{{ deploy_helper.unfinished_filename }}' + state: touch +- name: Perform some build steps, like running your dependency manager for example + composer: + command: install + working_dir: '{{ deploy_helper.new_release_path }}' +- name: Create some folders in the shared folder + file: + path: '{{ deploy_helper.shared_path }}/{{ item }}' + state: directory + with_items: + - sessions + - uploads +- name: Add symlinks from the new release to the shared folder + file: + path: '{{ deploy_helper.new_release_path }}/{{ item.path }}' + src: '{{ deploy_helper.shared_path }}/{{ item.src }}' + state: link + with_items: + - path: app/sessions + src: sessions + - path: web/uploads + src: uploads +- name: Finalize the deploy, removing the unfinished file and switching the symlink + deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + +# Retrieving facts before running a deploy +- name: Run 'state=query' to gather facts without changing anything + deploy_helper: + path: /path/to/root + state: query +# Remember to set the 'release' parameter when you actually call 'state=present' later +- name: Initialize the deploy root + deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: present + +# all paths can be absolute or relative (to the 'path' parameter) +- deploy_helper: + path: /path/to/root + releases_path: /var/www/project/releases + shared_path: /var/www/shared + current_path: /var/www/active + +# Using your own naming strategy for releases (a version tag in this case): +- deploy_helper: + path: /path/to/root + release: v1.1.1 + state: present +- deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + +# Using a different unfinished_filename: +- deploy_helper: + path: /path/to/root + unfinished_filename: README.md + release: '{{ deploy_helper.new_release }}' + state: finalize + +# Postponing the cleanup of older builds: +- deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + clean: False +- deploy_helper: + path: /path/to/root + state: clean +# Or running the cleanup ahead of the new deploy +- deploy_helper: + path: /path/to/root + state: clean +- deploy_helper: + path: /path/to/root + state: present + +# Keeping more old releases: +- deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + keep_releases: 10 +# Or, if you use 'clean=false' on finalize: +- deploy_helper: + path: /path/to/root + state: clean + keep_releases: 10 + +# Removing the entire project root folder +- deploy_helper: + path: /path/to/root + state: absent + +# Debugging the facts returned by the module +- deploy_helper: + path: /path/to/root +- debug: + var: deploy_helper +''' +import os +import shutil +import time +import traceback + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_native + + +class DeployHelper(object): + + def __init__(self, module): + self.module = module + self.file_args = module.load_file_common_arguments(module.params) + + self.clean = module.params['clean'] + self.current_path = module.params['current_path'] + self.keep_releases = module.params['keep_releases'] + self.path = module.params['path'] + self.release = module.params['release'] + self.releases_path = module.params['releases_path'] + self.shared_path = module.params['shared_path'] + self.state = module.params['state'] + self.unfinished_filename = module.params['unfinished_filename'] + + def gather_facts(self): + current_path = os.path.join(self.path, self.current_path) + releases_path = os.path.join(self.path, self.releases_path) + if self.shared_path: + shared_path = os.path.join(self.path, self.shared_path) + else: + shared_path = None + + previous_release, previous_release_path = self._get_last_release(current_path) + + if not self.release and (self.state == 'query' or self.state == 'present'): + self.release = time.strftime("%Y%m%d%H%M%S") + + if self.release: + new_release_path = os.path.join(releases_path, self.release) + else: + new_release_path = None + + return { + 'project_path': self.path, + 'current_path': current_path, + 'releases_path': releases_path, + 'shared_path': shared_path, + 'previous_release': previous_release, + 'previous_release_path': previous_release_path, + 'new_release': self.release, + 'new_release_path': new_release_path, + 'unfinished_filename': self.unfinished_filename + } + + def delete_path(self, path): + if not os.path.lexists(path): + return False + + if not os.path.isdir(path): + self.module.fail_json(msg="%s exists but is not a directory" % path) + + if not self.module.check_mode: + try: + shutil.rmtree(path, ignore_errors=False) + except Exception as e: + self.module.fail_json(msg="rmtree failed: %s" % to_native(e), exception=traceback.format_exc()) + + return True + + def create_path(self, path): + changed = False + + if not os.path.lexists(path): + changed = True + if not self.module.check_mode: + os.makedirs(path) + + elif not os.path.isdir(path): + self.module.fail_json(msg="%s exists but is not a directory" % path) + + changed += self.module.set_directory_attributes_if_different(self._get_file_args(path), changed) + + return changed + + def check_link(self, path): + if os.path.lexists(path): + if not os.path.islink(path): + self.module.fail_json(msg="%s exists but is not a symbolic link" % path) + + def create_link(self, source, link_name): + changed = False + + if os.path.islink(link_name): + norm_link = os.path.normpath(os.path.realpath(link_name)) + norm_source = os.path.normpath(os.path.realpath(source)) + if norm_link == norm_source: + changed = False + else: + changed = True + if not self.module.check_mode: + if not os.path.lexists(source): + self.module.fail_json(msg="the symlink target %s doesn't exists" % source) + tmp_link_name = link_name + '.' + self.unfinished_filename + if os.path.islink(tmp_link_name): + os.unlink(tmp_link_name) + os.symlink(source, tmp_link_name) + os.rename(tmp_link_name, link_name) + else: + changed = True + if not self.module.check_mode: + os.symlink(source, link_name) + + return changed + + def remove_unfinished_file(self, new_release_path): + changed = False + unfinished_file_path = os.path.join(new_release_path, self.unfinished_filename) + if os.path.lexists(unfinished_file_path): + changed = True + if not self.module.check_mode: + os.remove(unfinished_file_path) + + return changed + + def remove_unfinished_builds(self, releases_path): + changes = 0 + + for release in os.listdir(releases_path): + if os.path.isfile(os.path.join(releases_path, release, self.unfinished_filename)): + if self.module.check_mode: + changes += 1 + else: + changes += self.delete_path(os.path.join(releases_path, release)) + + return changes + + def remove_unfinished_link(self, path): + changed = False + + tmp_link_name = os.path.join(path, self.release + '.' + self.unfinished_filename) + if not self.module.check_mode and os.path.exists(tmp_link_name): + changed = True + os.remove(tmp_link_name) + + return changed + + def cleanup(self, releases_path, reserve_version): + changes = 0 + + if os.path.lexists(releases_path): + releases = [f for f in os.listdir(releases_path) if os.path.isdir(os.path.join(releases_path, f))] + try: + releases.remove(reserve_version) + except ValueError: + pass + + if not self.module.check_mode: + releases.sort(key=lambda x: os.path.getctime(os.path.join(releases_path, x)), reverse=True) + for release in releases[self.keep_releases:]: + changes += self.delete_path(os.path.join(releases_path, release)) + elif len(releases) > self.keep_releases: + changes += (len(releases) - self.keep_releases) + + return changes + + def _get_file_args(self, path): + file_args = self.file_args.copy() + file_args['path'] = path + return file_args + + def _get_last_release(self, current_path): + previous_release = None + previous_release_path = None + + if os.path.lexists(current_path): + previous_release_path = os.path.realpath(current_path) + previous_release = os.path.basename(previous_release_path) + + return previous_release, previous_release_path + + +def main(): + + module = AnsibleModule( + argument_spec=dict( + path=dict(aliases=['dest'], required=True, type='path'), + release=dict(required=False, type='str', default=None), + releases_path=dict(required=False, type='str', default='releases'), + shared_path=dict(required=False, type='path', default='shared'), + current_path=dict(required=False, type='path', default='current'), + keep_releases=dict(required=False, type='int', default=5), + clean=dict(required=False, type='bool', default=True), + unfinished_filename=dict(required=False, type='str', default='DEPLOY_UNFINISHED'), + state=dict(required=False, choices=['present', 'absent', 'clean', 'finalize', 'query'], default='present') + ), + add_file_common_args=True, + supports_check_mode=True + ) + + deploy_helper = DeployHelper(module) + facts = deploy_helper.gather_facts() + + result = { + 'state': deploy_helper.state + } + + changes = 0 + + if deploy_helper.state == 'query': + result['ansible_facts'] = {'deploy_helper': facts} + + elif deploy_helper.state == 'present': + deploy_helper.check_link(facts['current_path']) + changes += deploy_helper.create_path(facts['project_path']) + changes += deploy_helper.create_path(facts['releases_path']) + if deploy_helper.shared_path: + changes += deploy_helper.create_path(facts['shared_path']) + + result['ansible_facts'] = {'deploy_helper': facts} + + elif deploy_helper.state == 'finalize': + if not deploy_helper.release: + module.fail_json(msg="'release' is a required parameter for state=finalize (try the 'deploy_helper.new_release' fact)") + if deploy_helper.keep_releases <= 0: + module.fail_json(msg="'keep_releases' should be at least 1") + + changes += deploy_helper.remove_unfinished_file(facts['new_release_path']) + changes += deploy_helper.create_link(facts['new_release_path'], facts['current_path']) + if deploy_helper.clean: + changes += deploy_helper.remove_unfinished_link(facts['project_path']) + changes += deploy_helper.remove_unfinished_builds(facts['releases_path']) + changes += deploy_helper.cleanup(facts['releases_path'], facts['new_release']) + + elif deploy_helper.state == 'clean': + changes += deploy_helper.remove_unfinished_link(facts['project_path']) + changes += deploy_helper.remove_unfinished_builds(facts['releases_path']) + changes += deploy_helper.cleanup(facts['releases_path'], facts['new_release']) + + elif deploy_helper.state == 'absent': + # destroy the facts + result['ansible_facts'] = {'deploy_helper': []} + changes += deploy_helper.delete_path(facts['project_path']) + + if changes > 0: + result['changed'] = True + else: + result['changed'] = False + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/docker_swarm.py b/test/support/integration/plugins/modules/docker_swarm.py new file mode 100644 index 00000000..a2c076c5 --- /dev/null +++ b/test/support/integration/plugins/modules/docker_swarm.py @@ -0,0 +1,681 @@ +#!/usr/bin/python + +# Copyright 2016 Red Hat | Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +--- +module: docker_swarm +short_description: Manage Swarm cluster +version_added: "2.7" +description: + - Create a new Swarm cluster. + - Add/Remove nodes or managers to an existing cluster. +options: + advertise_addr: + description: + - Externally reachable address advertised to other nodes. + - This can either be an address/port combination + in the form C(192.168.1.1:4567), or an interface followed by a + port number, like C(eth0:4567). + - If the port number is omitted, + the port number from the listen address is used. + - If I(advertise_addr) is not specified, it will be automatically + detected when possible. + - Only used when swarm is initialised or joined. Because of this it's not + considered for idempotency checking. + type: str + default_addr_pool: + description: + - Default address pool in CIDR format. + - Only used when swarm is initialised. Because of this it's not considered + for idempotency checking. + - Requires API version >= 1.39. + type: list + elements: str + version_added: "2.8" + subnet_size: + description: + - Default address pool subnet mask length. + - Only used when swarm is initialised. Because of this it's not considered + for idempotency checking. + - Requires API version >= 1.39. + type: int + version_added: "2.8" + listen_addr: + description: + - Listen address used for inter-manager communication. + - This can either be an address/port combination in the form + C(192.168.1.1:4567), or an interface followed by a port number, + like C(eth0:4567). + - If the port number is omitted, the default swarm listening port + is used. + - Only used when swarm is initialised or joined. Because of this it's not + considered for idempotency checking. + type: str + default: 0.0.0.0:2377 + force: + description: + - Use with state C(present) to force creating a new Swarm, even if already part of one. + - Use with state C(absent) to Leave the swarm even if this node is a manager. + type: bool + default: no + state: + description: + - Set to C(present), to create/update a new cluster. + - Set to C(join), to join an existing cluster. + - Set to C(absent), to leave an existing cluster. + - Set to C(remove), to remove an absent node from the cluster. + Note that removing requires Docker SDK for Python >= 2.4.0. + - Set to C(inspect) to display swarm informations. + type: str + default: present + choices: + - present + - join + - absent + - remove + - inspect + node_id: + description: + - Swarm id of the node to remove. + - Used with I(state=remove). + type: str + join_token: + description: + - Swarm token used to join a swarm cluster. + - Used with I(state=join). + type: str + remote_addrs: + description: + - Remote address of one or more manager nodes of an existing Swarm to connect to. + - Used with I(state=join). + type: list + elements: str + task_history_retention_limit: + description: + - Maximum number of tasks history stored. + - Docker default value is C(5). + type: int + snapshot_interval: + description: + - Number of logs entries between snapshot. + - Docker default value is C(10000). + type: int + keep_old_snapshots: + description: + - Number of snapshots to keep beyond the current snapshot. + - Docker default value is C(0). + type: int + log_entries_for_slow_followers: + description: + - Number of log entries to keep around to sync up slow followers after a snapshot is created. + type: int + heartbeat_tick: + description: + - Amount of ticks (in seconds) between each heartbeat. + - Docker default value is C(1s). + type: int + election_tick: + description: + - Amount of ticks (in seconds) needed without a leader to trigger a new election. + - Docker default value is C(10s). + type: int + dispatcher_heartbeat_period: + description: + - The delay for an agent to send a heartbeat to the dispatcher. + - Docker default value is C(5s). + type: int + node_cert_expiry: + description: + - Automatic expiry for nodes certificates. + - Docker default value is C(3months). + type: int + name: + description: + - The name of the swarm. + type: str + labels: + description: + - User-defined key/value metadata. + - Label operations in this module apply to the docker swarm cluster. + Use M(docker_node) module to add/modify/remove swarm node labels. + - Requires API version >= 1.32. + type: dict + signing_ca_cert: + description: + - The desired signing CA certificate for all swarm node TLS leaf certificates, in PEM format. + - This must not be a path to a certificate, but the contents of the certificate. + - Requires API version >= 1.30. + type: str + signing_ca_key: + description: + - The desired signing CA key for all swarm node TLS leaf certificates, in PEM format. + - This must not be a path to a key, but the contents of the key. + - Requires API version >= 1.30. + type: str + ca_force_rotate: + description: + - An integer whose purpose is to force swarm to generate a new signing CA certificate and key, + if none have been specified. + - Docker default value is C(0). + - Requires API version >= 1.30. + type: int + autolock_managers: + description: + - If set, generate a key and use it to lock data stored on the managers. + - Docker default value is C(no). + - M(docker_swarm_info) can be used to retrieve the unlock key. + type: bool + rotate_worker_token: + description: Rotate the worker join token. + type: bool + default: no + rotate_manager_token: + description: Rotate the manager join token. + type: bool + default: no +extends_documentation_fragment: + - docker + - docker.docker_py_1_documentation +requirements: + - "L(Docker SDK for Python,https://docker-py.readthedocs.io/en/stable/) >= 1.10.0 (use L(docker-py,https://pypi.org/project/docker-py/) for Python 2.6)" + - Docker API >= 1.25 +author: + - Thierry Bouvet (@tbouvet) + - Piotr Wojciechowski (@WojciechowskiPiotr) +''' + +EXAMPLES = ''' + +- name: Init a new swarm with default parameters + docker_swarm: + state: present + +- name: Update swarm configuration + docker_swarm: + state: present + election_tick: 5 + +- name: Add nodes + docker_swarm: + state: join + advertise_addr: 192.168.1.2 + join_token: SWMTKN-1--xxxxx + remote_addrs: [ '192.168.1.1:2377' ] + +- name: Leave swarm for a node + docker_swarm: + state: absent + +- name: Remove a swarm manager + docker_swarm: + state: absent + force: true + +- name: Remove node from swarm + docker_swarm: + state: remove + node_id: mynode + +- name: Inspect swarm + docker_swarm: + state: inspect + register: swarm_info +''' + +RETURN = ''' +swarm_facts: + description: Informations about swarm. + returned: success + type: dict + contains: + JoinTokens: + description: Tokens to connect to the Swarm. + returned: success + type: dict + contains: + Worker: + description: Token to create a new *worker* node + returned: success + type: str + example: SWMTKN-1--xxxxx + Manager: + description: Token to create a new *manager* node + returned: success + type: str + example: SWMTKN-1--xxxxx + UnlockKey: + description: The swarm unlock-key if I(autolock_managers) is C(true). + returned: on success if I(autolock_managers) is C(true) + and swarm is initialised, or if I(autolock_managers) has changed. + type: str + example: SWMKEY-1-xxx + +actions: + description: Provides the actions done on the swarm. + returned: when action failed. + type: list + elements: str + example: "['This cluster is already a swarm cluster']" + +''' + +import json +import traceback + +try: + from docker.errors import DockerException, APIError +except ImportError: + # missing Docker SDK for Python handled in ansible.module_utils.docker.common + pass + +from ansible.module_utils.docker.common import ( + DockerBaseClass, + DifferenceTracker, + RequestException, +) + +from ansible.module_utils.docker.swarm import AnsibleDockerSwarmClient + +from ansible.module_utils._text import to_native + + +class TaskParameters(DockerBaseClass): + def __init__(self): + super(TaskParameters, self).__init__() + + self.advertise_addr = None + self.listen_addr = None + self.remote_addrs = None + self.join_token = None + + # Spec + self.snapshot_interval = None + self.task_history_retention_limit = None + self.keep_old_snapshots = None + self.log_entries_for_slow_followers = None + self.heartbeat_tick = None + self.election_tick = None + self.dispatcher_heartbeat_period = None + self.node_cert_expiry = None + self.name = None + self.labels = None + self.log_driver = None + self.signing_ca_cert = None + self.signing_ca_key = None + self.ca_force_rotate = None + self.autolock_managers = None + self.rotate_worker_token = None + self.rotate_manager_token = None + self.default_addr_pool = None + self.subnet_size = None + + @staticmethod + def from_ansible_params(client): + result = TaskParameters() + for key, value in client.module.params.items(): + if key in result.__dict__: + setattr(result, key, value) + + result.update_parameters(client) + return result + + def update_from_swarm_info(self, swarm_info): + spec = swarm_info['Spec'] + + ca_config = spec.get('CAConfig') or dict() + if self.node_cert_expiry is None: + self.node_cert_expiry = ca_config.get('NodeCertExpiry') + if self.ca_force_rotate is None: + self.ca_force_rotate = ca_config.get('ForceRotate') + + dispatcher = spec.get('Dispatcher') or dict() + if self.dispatcher_heartbeat_period is None: + self.dispatcher_heartbeat_period = dispatcher.get('HeartbeatPeriod') + + raft = spec.get('Raft') or dict() + if self.snapshot_interval is None: + self.snapshot_interval = raft.get('SnapshotInterval') + if self.keep_old_snapshots is None: + self.keep_old_snapshots = raft.get('KeepOldSnapshots') + if self.heartbeat_tick is None: + self.heartbeat_tick = raft.get('HeartbeatTick') + if self.log_entries_for_slow_followers is None: + self.log_entries_for_slow_followers = raft.get('LogEntriesForSlowFollowers') + if self.election_tick is None: + self.election_tick = raft.get('ElectionTick') + + orchestration = spec.get('Orchestration') or dict() + if self.task_history_retention_limit is None: + self.task_history_retention_limit = orchestration.get('TaskHistoryRetentionLimit') + + encryption_config = spec.get('EncryptionConfig') or dict() + if self.autolock_managers is None: + self.autolock_managers = encryption_config.get('AutoLockManagers') + + if self.name is None: + self.name = spec['Name'] + + if self.labels is None: + self.labels = spec.get('Labels') or {} + + if 'LogDriver' in spec['TaskDefaults']: + self.log_driver = spec['TaskDefaults']['LogDriver'] + + def update_parameters(self, client): + assign = dict( + snapshot_interval='snapshot_interval', + task_history_retention_limit='task_history_retention_limit', + keep_old_snapshots='keep_old_snapshots', + log_entries_for_slow_followers='log_entries_for_slow_followers', + heartbeat_tick='heartbeat_tick', + election_tick='election_tick', + dispatcher_heartbeat_period='dispatcher_heartbeat_period', + node_cert_expiry='node_cert_expiry', + name='name', + labels='labels', + signing_ca_cert='signing_ca_cert', + signing_ca_key='signing_ca_key', + ca_force_rotate='ca_force_rotate', + autolock_managers='autolock_managers', + log_driver='log_driver', + ) + params = dict() + for dest, source in assign.items(): + if not client.option_minimal_versions[source]['supported']: + continue + value = getattr(self, source) + if value is not None: + params[dest] = value + self.spec = client.create_swarm_spec(**params) + + def compare_to_active(self, other, client, differences): + for k in self.__dict__: + if k in ('advertise_addr', 'listen_addr', 'remote_addrs', 'join_token', + 'rotate_worker_token', 'rotate_manager_token', 'spec', + 'default_addr_pool', 'subnet_size'): + continue + if not client.option_minimal_versions[k]['supported']: + continue + value = getattr(self, k) + if value is None: + continue + other_value = getattr(other, k) + if value != other_value: + differences.add(k, parameter=value, active=other_value) + if self.rotate_worker_token: + differences.add('rotate_worker_token', parameter=True, active=False) + if self.rotate_manager_token: + differences.add('rotate_manager_token', parameter=True, active=False) + return differences + + +class SwarmManager(DockerBaseClass): + + def __init__(self, client, results): + + super(SwarmManager, self).__init__() + + self.client = client + self.results = results + self.check_mode = self.client.check_mode + self.swarm_info = {} + + self.state = client.module.params['state'] + self.force = client.module.params['force'] + self.node_id = client.module.params['node_id'] + + self.differences = DifferenceTracker() + self.parameters = TaskParameters.from_ansible_params(client) + + self.created = False + + def __call__(self): + choice_map = { + "present": self.init_swarm, + "join": self.join, + "absent": self.leave, + "remove": self.remove, + "inspect": self.inspect_swarm + } + + if self.state == 'inspect': + self.client.module.deprecate( + "The 'inspect' state is deprecated, please use 'docker_swarm_info' to inspect swarm cluster", + version='2.12', collection_name='ansible.builtin') + + choice_map.get(self.state)() + + if self.client.module._diff or self.parameters.debug: + diff = dict() + diff['before'], diff['after'] = self.differences.get_before_after() + self.results['diff'] = diff + + def inspect_swarm(self): + try: + data = self.client.inspect_swarm() + json_str = json.dumps(data, ensure_ascii=False) + self.swarm_info = json.loads(json_str) + + self.results['changed'] = False + self.results['swarm_facts'] = self.swarm_info + + unlock_key = self.get_unlock_key() + self.swarm_info.update(unlock_key) + except APIError: + return + + def get_unlock_key(self): + default = {'UnlockKey': None} + if not self.has_swarm_lock_changed(): + return default + try: + return self.client.get_unlock_key() or default + except APIError: + return default + + def has_swarm_lock_changed(self): + return self.parameters.autolock_managers and ( + self.created or self.differences.has_difference_for('autolock_managers') + ) + + def init_swarm(self): + if not self.force and self.client.check_if_swarm_manager(): + self.__update_swarm() + return + + if not self.check_mode: + init_arguments = { + 'advertise_addr': self.parameters.advertise_addr, + 'listen_addr': self.parameters.listen_addr, + 'force_new_cluster': self.force, + 'swarm_spec': self.parameters.spec, + } + if self.parameters.default_addr_pool is not None: + init_arguments['default_addr_pool'] = self.parameters.default_addr_pool + if self.parameters.subnet_size is not None: + init_arguments['subnet_size'] = self.parameters.subnet_size + try: + self.client.init_swarm(**init_arguments) + except APIError as exc: + self.client.fail("Can not create a new Swarm Cluster: %s" % to_native(exc)) + + if not self.client.check_if_swarm_manager(): + if not self.check_mode: + self.client.fail("Swarm not created or other error!") + + self.created = True + self.inspect_swarm() + self.results['actions'].append("New Swarm cluster created: %s" % (self.swarm_info.get('ID'))) + self.differences.add('state', parameter='present', active='absent') + self.results['changed'] = True + self.results['swarm_facts'] = { + 'JoinTokens': self.swarm_info.get('JoinTokens'), + 'UnlockKey': self.swarm_info.get('UnlockKey') + } + + def __update_swarm(self): + try: + self.inspect_swarm() + version = self.swarm_info['Version']['Index'] + self.parameters.update_from_swarm_info(self.swarm_info) + old_parameters = TaskParameters() + old_parameters.update_from_swarm_info(self.swarm_info) + self.parameters.compare_to_active(old_parameters, self.client, self.differences) + if self.differences.empty: + self.results['actions'].append("No modification") + self.results['changed'] = False + return + update_parameters = TaskParameters.from_ansible_params(self.client) + update_parameters.update_parameters(self.client) + if not self.check_mode: + self.client.update_swarm( + version=version, swarm_spec=update_parameters.spec, + rotate_worker_token=self.parameters.rotate_worker_token, + rotate_manager_token=self.parameters.rotate_manager_token) + except APIError as exc: + self.client.fail("Can not update a Swarm Cluster: %s" % to_native(exc)) + return + + self.inspect_swarm() + self.results['actions'].append("Swarm cluster updated") + self.results['changed'] = True + + def join(self): + if self.client.check_if_swarm_node(): + self.results['actions'].append("This node is already part of a swarm.") + return + if not self.check_mode: + try: + self.client.join_swarm( + remote_addrs=self.parameters.remote_addrs, join_token=self.parameters.join_token, + listen_addr=self.parameters.listen_addr, advertise_addr=self.parameters.advertise_addr) + except APIError as exc: + self.client.fail("Can not join the Swarm Cluster: %s" % to_native(exc)) + self.results['actions'].append("New node is added to swarm cluster") + self.differences.add('joined', parameter=True, active=False) + self.results['changed'] = True + + def leave(self): + if not self.client.check_if_swarm_node(): + self.results['actions'].append("This node is not part of a swarm.") + return + if not self.check_mode: + try: + self.client.leave_swarm(force=self.force) + except APIError as exc: + self.client.fail("This node can not leave the Swarm Cluster: %s" % to_native(exc)) + self.results['actions'].append("Node has left the swarm cluster") + self.differences.add('joined', parameter='absent', active='present') + self.results['changed'] = True + + def remove(self): + if not self.client.check_if_swarm_manager(): + self.client.fail("This node is not a manager.") + + try: + status_down = self.client.check_if_swarm_node_is_down(node_id=self.node_id, repeat_check=5) + except APIError: + return + + if not status_down: + self.client.fail("Can not remove the node. The status node is ready and not down.") + + if not self.check_mode: + try: + self.client.remove_node(node_id=self.node_id, force=self.force) + except APIError as exc: + self.client.fail("Can not remove the node from the Swarm Cluster: %s" % to_native(exc)) + self.results['actions'].append("Node is removed from swarm cluster.") + self.differences.add('joined', parameter=False, active=True) + self.results['changed'] = True + + +def _detect_remove_operation(client): + return client.module.params['state'] == 'remove' + + +def main(): + argument_spec = dict( + advertise_addr=dict(type='str'), + state=dict(type='str', default='present', choices=['present', 'join', 'absent', 'remove', 'inspect']), + force=dict(type='bool', default=False), + listen_addr=dict(type='str', default='0.0.0.0:2377'), + remote_addrs=dict(type='list', elements='str'), + join_token=dict(type='str'), + snapshot_interval=dict(type='int'), + task_history_retention_limit=dict(type='int'), + keep_old_snapshots=dict(type='int'), + log_entries_for_slow_followers=dict(type='int'), + heartbeat_tick=dict(type='int'), + election_tick=dict(type='int'), + dispatcher_heartbeat_period=dict(type='int'), + node_cert_expiry=dict(type='int'), + name=dict(type='str'), + labels=dict(type='dict'), + signing_ca_cert=dict(type='str'), + signing_ca_key=dict(type='str'), + ca_force_rotate=dict(type='int'), + autolock_managers=dict(type='bool'), + node_id=dict(type='str'), + rotate_worker_token=dict(type='bool', default=False), + rotate_manager_token=dict(type='bool', default=False), + default_addr_pool=dict(type='list', elements='str'), + subnet_size=dict(type='int'), + ) + + required_if = [ + ('state', 'join', ['advertise_addr', 'remote_addrs', 'join_token']), + ('state', 'remove', ['node_id']) + ] + + option_minimal_versions = dict( + labels=dict(docker_py_version='2.6.0', docker_api_version='1.32'), + signing_ca_cert=dict(docker_py_version='2.6.0', docker_api_version='1.30'), + signing_ca_key=dict(docker_py_version='2.6.0', docker_api_version='1.30'), + ca_force_rotate=dict(docker_py_version='2.6.0', docker_api_version='1.30'), + autolock_managers=dict(docker_py_version='2.6.0'), + log_driver=dict(docker_py_version='2.6.0'), + remove_operation=dict( + docker_py_version='2.4.0', + detect_usage=_detect_remove_operation, + usage_msg='remove swarm nodes' + ), + default_addr_pool=dict(docker_py_version='4.0.0', docker_api_version='1.39'), + subnet_size=dict(docker_py_version='4.0.0', docker_api_version='1.39'), + ) + + client = AnsibleDockerSwarmClient( + argument_spec=argument_spec, + supports_check_mode=True, + required_if=required_if, + min_docker_version='1.10.0', + min_docker_api_version='1.25', + option_minimal_versions=option_minimal_versions, + ) + + try: + results = dict( + changed=False, + result='', + actions=[] + ) + + SwarmManager(client, results)() + client.module.exit_json(**results) + except DockerException as e: + client.fail('An unexpected docker error occurred: {0}'.format(e), exception=traceback.format_exc()) + except RequestException as e: + client.fail('An unexpected requests error occurred when docker-py tried to talk to the docker daemon: {0}'.format(e), exception=traceback.format_exc()) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2.py b/test/support/integration/plugins/modules/ec2.py new file mode 100644 index 00000000..952aa5a1 --- /dev/null +++ b/test/support/integration/plugins/modules/ec2.py @@ -0,0 +1,1766 @@ +#!/usr/bin/python +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: ec2 +short_description: create, terminate, start or stop an instance in ec2 +description: + - Creates or terminates ec2 instances. + - > + Note: This module uses the older boto Python module to interact with the EC2 API. + M(ec2) will still receive bug fixes, but no new features. + Consider using the M(ec2_instance) module instead. + If M(ec2_instance) does not support a feature you need that is available in M(ec2), please + file a feature request. +version_added: "0.9" +options: + key_name: + description: + - Key pair to use on the instance. + - The SSH key must already exist in AWS in order to use this argument. + - Keys can be created / deleted using the M(ec2_key) module. + aliases: ['keypair'] + type: str + id: + version_added: "1.1" + description: + - Identifier for this instance or set of instances, so that the module will be idempotent with respect to EC2 instances. + - This identifier is valid for at least 24 hours after the termination of the instance, and should not be reused for another call later on. + - For details, see the description of client token at U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Run_Instance_Idempotency.html). + type: str + group: + description: + - Security group (or list of groups) to use with the instance. + aliases: [ 'groups' ] + type: list + elements: str + group_id: + version_added: "1.1" + description: + - Security group id (or list of ids) to use with the instance. + type: list + elements: str + zone: + version_added: "1.2" + description: + - AWS availability zone in which to launch the instance. + aliases: [ 'aws_zone', 'ec2_zone' ] + type: str + instance_type: + description: + - Instance type to use for the instance, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-types.html). + - Required when creating a new instance. + type: str + aliases: ['type'] + tenancy: + version_added: "1.9" + description: + - An instance with a tenancy of C(dedicated) runs on single-tenant hardware and can only be launched into a VPC. + - Note that to use dedicated tenancy you MUST specify a I(vpc_subnet_id) as well. + - Dedicated tenancy is not available for EC2 "micro" instances. + default: default + choices: [ "default", "dedicated" ] + type: str + spot_price: + version_added: "1.5" + description: + - Maximum spot price to bid. If not set, a regular on-demand instance is requested. + - A spot request is made with this maximum bid. When it is filled, the instance is started. + type: str + spot_type: + version_added: "2.0" + description: + - The type of spot request. + - After being interrupted a C(persistent) spot instance will be started once there is capacity to fill the request again. + default: "one-time" + choices: [ "one-time", "persistent" ] + type: str + image: + description: + - I(ami) ID to use for the instance. + - Required when I(state=present). + type: str + kernel: + description: + - Kernel eki to use for the instance. + type: str + ramdisk: + description: + - Ramdisk eri to use for the instance. + type: str + wait: + description: + - Wait for the instance to reach its desired state before returning. + - Does not wait for SSH, see the 'wait_for_connection' example for details. + type: bool + default: false + wait_timeout: + description: + - How long before wait gives up, in seconds. + default: 300 + type: int + spot_wait_timeout: + version_added: "1.5" + description: + - How long to wait for the spot instance request to be fulfilled. Affects 'Request valid until' for setting spot request lifespan. + default: 600 + type: int + count: + description: + - Number of instances to launch. + default: 1 + type: int + monitoring: + version_added: "1.1" + description: + - Enable detailed monitoring (CloudWatch) for instance. + type: bool + default: false + user_data: + version_added: "0.9" + description: + - Opaque blob of data which is made available to the EC2 instance. + type: str + instance_tags: + version_added: "1.0" + description: + - A hash/dictionary of tags to add to the new instance or for starting/stopping instance by tag; '{"key":"value"}' and '{"key":"value","key":"value"}'. + type: dict + placement_group: + version_added: "1.3" + description: + - Placement group for the instance when using EC2 Clustered Compute. + type: str + vpc_subnet_id: + version_added: "1.1" + description: + - the subnet ID in which to launch the instance (VPC). + type: str + assign_public_ip: + version_added: "1.5" + description: + - When provisioning within vpc, assign a public IP address. Boto library must be 2.13.0+. + type: bool + private_ip: + version_added: "1.2" + description: + - The private ip address to assign the instance (from the vpc subnet). + type: str + instance_profile_name: + version_added: "1.3" + description: + - Name of the IAM instance profile (i.e. what the EC2 console refers to as an "IAM Role") to use. Boto library must be 2.5.0+. + type: str + instance_ids: + version_added: "1.3" + description: + - "list of instance ids, currently used for states: absent, running, stopped" + aliases: ['instance_id'] + type: list + elements: str + source_dest_check: + version_added: "1.6" + description: + - Enable or Disable the Source/Destination checks (for NAT instances and Virtual Routers). + When initially creating an instance the EC2 API defaults this to C(True). + type: bool + termination_protection: + version_added: "2.0" + description: + - Enable or Disable the Termination Protection. + type: bool + default: false + instance_initiated_shutdown_behavior: + version_added: "2.2" + description: + - Set whether AWS will Stop or Terminate an instance on shutdown. This parameter is ignored when using instance-store. + images (which require termination on shutdown). + default: 'stop' + choices: [ "stop", "terminate" ] + type: str + state: + version_added: "1.3" + description: + - Create, terminate, start, stop or restart instances. The state 'restarted' was added in Ansible 2.2. + - When I(state=absent), I(instance_ids) is required. + - When I(state=running), I(state=stopped) or I(state=restarted) then either I(instance_ids) or I(instance_tags) is required. + default: 'present' + choices: ['absent', 'present', 'restarted', 'running', 'stopped'] + type: str + volumes: + version_added: "1.5" + description: + - A list of hash/dictionaries of volumes to add to the new instance. + type: list + elements: dict + suboptions: + device_name: + type: str + required: true + description: + - A name for the device (For example C(/dev/sda)). + delete_on_termination: + type: bool + default: false + description: + - Whether the volume should be automatically deleted when the instance is terminated. + ephemeral: + type: str + description: + - Whether the volume should be ephemeral. + - Data on ephemeral volumes is lost when the instance is stopped. + - Mutually exclusive with the I(snapshot) parameter. + encrypted: + type: bool + default: false + description: + - Whether the volume should be encrypted using the 'aws/ebs' KMS CMK. + snapshot: + type: str + description: + - The ID of an EBS snapshot to copy when creating the volume. + - Mutually exclusive with the I(ephemeral) parameter. + volume_type: + type: str + description: + - The type of volume to create. + - See U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html) for more information on the available volume types. + volume_size: + type: int + description: + - The size of the volume (in GiB). + iops: + type: int + description: + - The number of IOPS per second to provision for the volume. + - Required when I(volume_type=io1). + ebs_optimized: + version_added: "1.6" + description: + - Whether instance is using optimized EBS volumes, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSOptimized.html). + default: false + type: bool + exact_count: + version_added: "1.5" + description: + - An integer value which indicates how many instances that match the 'count_tag' parameter should be running. + Instances are either created or terminated based on this value. + type: int + count_tag: + version_added: "1.5" + description: + - Used with I(exact_count) to determine how many nodes based on a specific tag criteria should be running. + This can be expressed in multiple ways and is shown in the EXAMPLES section. For instance, one can request 25 servers + that are tagged with "class=webserver". The specified tag must already exist or be passed in as the I(instance_tags) option. + type: raw + network_interfaces: + version_added: "2.0" + description: + - A list of existing network interfaces to attach to the instance at launch. When specifying existing network interfaces, + none of the I(assign_public_ip), I(private_ip), I(vpc_subnet_id), I(group), or I(group_id) parameters may be used. (Those parameters are + for creating a new network interface at launch.) + aliases: ['network_interface'] + type: list + elements: str + spot_launch_group: + version_added: "2.1" + description: + - Launch group for spot requests, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html#spot-launch-group). + type: str +author: + - "Tim Gerla (@tgerla)" + - "Lester Wade (@lwade)" + - "Seth Vidal (@skvidal)" +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +# Basic provisioning example +- ec2: + key_name: mykey + instance_type: t2.micro + image: ami-123456 + wait: yes + group: webserver + count: 3 + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Advanced example with tagging and CloudWatch +- ec2: + key_name: mykey + group: databases + instance_type: t2.micro + image: ami-123456 + wait: yes + wait_timeout: 500 + count: 5 + instance_tags: + db: postgres + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Single instance with additional IOPS volume from snapshot and volume delete on termination +- ec2: + key_name: mykey + group: webserver + instance_type: c3.medium + image: ami-123456 + wait: yes + wait_timeout: 500 + volumes: + - device_name: /dev/sdb + snapshot: snap-abcdef12 + volume_type: io1 + iops: 1000 + volume_size: 100 + delete_on_termination: true + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Single instance with ssd gp2 root volume +- ec2: + key_name: mykey + group: webserver + instance_type: c3.medium + image: ami-123456 + wait: yes + wait_timeout: 500 + volumes: + - device_name: /dev/xvda + volume_type: gp2 + volume_size: 8 + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + count_tag: + Name: dbserver + exact_count: 1 + +# Multiple groups example +- ec2: + key_name: mykey + group: ['databases', 'internal-services', 'sshable', 'and-so-forth'] + instance_type: m1.large + image: ami-6e649707 + wait: yes + wait_timeout: 500 + count: 5 + instance_tags: + db: postgres + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Multiple instances with additional volume from snapshot +- ec2: + key_name: mykey + group: webserver + instance_type: m1.large + image: ami-6e649707 + wait: yes + wait_timeout: 500 + count: 5 + volumes: + - device_name: /dev/sdb + snapshot: snap-abcdef12 + volume_size: 10 + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Dedicated tenancy example +- local_action: + module: ec2 + assign_public_ip: yes + group_id: sg-1dc53f72 + key_name: mykey + image: ami-6e649707 + instance_type: m1.small + tenancy: dedicated + vpc_subnet_id: subnet-29e63245 + wait: yes + +# Spot instance example +- ec2: + spot_price: 0.24 + spot_wait_timeout: 600 + keypair: mykey + group_id: sg-1dc53f72 + instance_type: m1.small + image: ami-6e649707 + wait: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + spot_launch_group: report_generators + instance_initiated_shutdown_behavior: terminate + +# Examples using pre-existing network interfaces +- ec2: + key_name: mykey + instance_type: t2.small + image: ami-f005ba11 + network_interface: eni-deadbeef + +- ec2: + key_name: mykey + instance_type: t2.small + image: ami-f005ba11 + network_interfaces: ['eni-deadbeef', 'eni-5ca1ab1e'] + +# Launch instances, runs some tasks +# and then terminate them + +- name: Create a sandbox instance + hosts: localhost + gather_facts: False + vars: + keypair: my_keypair + instance_type: m1.small + security_group: my_securitygroup + image: my_ami_id + region: us-east-1 + tasks: + - name: Launch instance + ec2: + key_name: "{{ keypair }}" + group: "{{ security_group }}" + instance_type: "{{ instance_type }}" + image: "{{ image }}" + wait: true + region: "{{ region }}" + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + register: ec2 + + - name: Add new instance to host group + add_host: + hostname: "{{ item.public_ip }}" + groupname: launched + loop: "{{ ec2.instances }}" + + - name: Wait for SSH to come up + delegate_to: "{{ item.public_dns_name }}" + wait_for_connection: + delay: 60 + timeout: 320 + loop: "{{ ec2.instances }}" + +- name: Configure instance(s) + hosts: launched + become: True + gather_facts: True + roles: + - my_awesome_role + - my_awesome_test + +- name: Terminate instances + hosts: localhost + tasks: + - name: Terminate instances that were previously launched + ec2: + state: 'absent' + instance_ids: '{{ ec2.instance_ids }}' + +# Start a few existing instances, run some tasks +# and stop the instances + +- name: Start sandbox instances + hosts: localhost + gather_facts: false + vars: + instance_ids: + - 'i-xxxxxx' + - 'i-xxxxxx' + - 'i-xxxxxx' + region: us-east-1 + tasks: + - name: Start the sandbox instances + ec2: + instance_ids: '{{ instance_ids }}' + region: '{{ region }}' + state: running + wait: True + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + roles: + - do_neat_stuff + - do_more_neat_stuff + +- name: Stop sandbox instances + hosts: localhost + gather_facts: false + vars: + instance_ids: + - 'i-xxxxxx' + - 'i-xxxxxx' + - 'i-xxxxxx' + region: us-east-1 + tasks: + - name: Stop the sandbox instances + ec2: + instance_ids: '{{ instance_ids }}' + region: '{{ region }}' + state: stopped + wait: True + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# +# Start stopped instances specified by tag +# +- local_action: + module: ec2 + instance_tags: + Name: ExtraPower + state: running + +# +# Restart instances specified by tag +# +- local_action: + module: ec2 + instance_tags: + Name: ExtraPower + state: restarted + +# +# Enforce that 5 instances with a tag "foo" are running +# (Highly recommended!) +# + +- ec2: + key_name: mykey + instance_type: c1.medium + image: ami-40603AD1 + wait: yes + group: webserver + instance_tags: + foo: bar + exact_count: 5 + count_tag: foo + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# +# Enforce that 5 running instances named "database" with a "dbtype" of "postgres" +# + +- ec2: + key_name: mykey + instance_type: c1.medium + image: ami-40603AD1 + wait: yes + group: webserver + instance_tags: + Name: database + dbtype: postgres + exact_count: 5 + count_tag: + Name: database + dbtype: postgres + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# +# count_tag complex argument examples +# + + # instances with tag foo +- ec2: + count_tag: + foo: + + # instances with tag foo=bar +- ec2: + count_tag: + foo: bar + + # instances with tags foo=bar & baz +- ec2: + count_tag: + foo: bar + baz: + + # instances with tags foo & bar & baz=bang +- ec2: + count_tag: + - foo + - bar + - baz: bang + +''' + +import time +import datetime +import traceback +from ast import literal_eval +from distutils.version import LooseVersion + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.ec2 import get_aws_connection_info, ec2_argument_spec, ec2_connect +from ansible.module_utils.six import get_function_code, string_types +from ansible.module_utils._text import to_bytes, to_text + +try: + import boto.ec2 + from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping + from boto.exception import EC2ResponseError + from boto import connect_ec2_endpoint + from boto import connect_vpc + HAS_BOTO = True +except ImportError: + HAS_BOTO = False + + +def find_running_instances_by_count_tag(module, ec2, vpc, count_tag, zone=None): + + # get reservations for instances that match tag(s) and are in the desired state + state = module.params.get('state') + if state not in ['running', 'stopped']: + state = None + reservations = get_reservations(module, ec2, vpc, tags=count_tag, state=state, zone=zone) + + instances = [] + for res in reservations: + if hasattr(res, 'instances'): + for inst in res.instances: + if inst.state == 'terminated' or inst.state == 'shutting-down': + continue + instances.append(inst) + + return reservations, instances + + +def _set_none_to_blank(dictionary): + result = dictionary + for k in result: + if isinstance(result[k], dict): + result[k] = _set_none_to_blank(result[k]) + elif not result[k]: + result[k] = "" + return result + + +def get_reservations(module, ec2, vpc, tags=None, state=None, zone=None): + # TODO: filters do not work with tags that have underscores + filters = dict() + + vpc_subnet_id = module.params.get('vpc_subnet_id') + vpc_id = None + if vpc_subnet_id: + filters.update({"subnet-id": vpc_subnet_id}) + if vpc: + vpc_id = vpc.get_all_subnets(subnet_ids=[vpc_subnet_id])[0].vpc_id + + if vpc_id: + filters.update({"vpc-id": vpc_id}) + + if tags is not None: + + if isinstance(tags, str): + try: + tags = literal_eval(tags) + except Exception: + pass + + # if not a string type, convert and make sure it's a text string + if isinstance(tags, int): + tags = to_text(tags) + + # if string, we only care that a tag of that name exists + if isinstance(tags, str): + filters.update({"tag-key": tags}) + + # if list, append each item to filters + if isinstance(tags, list): + for x in tags: + if isinstance(x, dict): + x = _set_none_to_blank(x) + filters.update(dict(("tag:" + tn, tv) for (tn, tv) in x.items())) + else: + filters.update({"tag-key": x}) + + # if dict, add the key and value to the filter + if isinstance(tags, dict): + tags = _set_none_to_blank(tags) + filters.update(dict(("tag:" + tn, tv) for (tn, tv) in tags.items())) + + # lets check to see if the filters dict is empty, if so then stop + if not filters: + module.fail_json(msg="Filters based on tag is empty => tags: %s" % (tags)) + + if state: + # http://stackoverflow.com/questions/437511/what-are-the-valid-instancestates-for-the-amazon-ec2-api + filters.update({'instance-state-name': state}) + + if zone: + filters.update({'availability-zone': zone}) + + if module.params.get('id'): + filters['client-token'] = module.params['id'] + + results = ec2.get_all_instances(filters=filters) + + return results + + +def get_instance_info(inst): + """ + Retrieves instance information from an instance + ID and returns it as a dictionary + """ + instance_info = {'id': inst.id, + 'ami_launch_index': inst.ami_launch_index, + 'private_ip': inst.private_ip_address, + 'private_dns_name': inst.private_dns_name, + 'public_ip': inst.ip_address, + 'dns_name': inst.dns_name, + 'public_dns_name': inst.public_dns_name, + 'state_code': inst.state_code, + 'architecture': inst.architecture, + 'image_id': inst.image_id, + 'key_name': inst.key_name, + 'placement': inst.placement, + 'region': inst.placement[:-1], + 'kernel': inst.kernel, + 'ramdisk': inst.ramdisk, + 'launch_time': inst.launch_time, + 'instance_type': inst.instance_type, + 'root_device_type': inst.root_device_type, + 'root_device_name': inst.root_device_name, + 'state': inst.state, + 'hypervisor': inst.hypervisor, + 'tags': inst.tags, + 'groups': dict((group.id, group.name) for group in inst.groups), + } + try: + instance_info['virtualization_type'] = getattr(inst, 'virtualization_type') + except AttributeError: + instance_info['virtualization_type'] = None + + try: + instance_info['ebs_optimized'] = getattr(inst, 'ebs_optimized') + except AttributeError: + instance_info['ebs_optimized'] = False + + try: + bdm_dict = {} + bdm = getattr(inst, 'block_device_mapping') + for device_name in bdm.keys(): + bdm_dict[device_name] = { + 'status': bdm[device_name].status, + 'volume_id': bdm[device_name].volume_id, + 'delete_on_termination': bdm[device_name].delete_on_termination + } + instance_info['block_device_mapping'] = bdm_dict + except AttributeError: + instance_info['block_device_mapping'] = False + + try: + instance_info['tenancy'] = getattr(inst, 'placement_tenancy') + except AttributeError: + instance_info['tenancy'] = 'default' + + return instance_info + + +def boto_supports_associate_public_ip_address(ec2): + """ + Check if Boto library has associate_public_ip_address in the NetworkInterfaceSpecification + class. Added in Boto 2.13.0 + + ec2: authenticated ec2 connection object + + Returns: + True if Boto library accepts associate_public_ip_address argument, else false + """ + + try: + network_interface = boto.ec2.networkinterface.NetworkInterfaceSpecification() + getattr(network_interface, "associate_public_ip_address") + return True + except AttributeError: + return False + + +def boto_supports_profile_name_arg(ec2): + """ + Check if Boto library has instance_profile_name argument. instance_profile_name has been added in Boto 2.5.0 + + ec2: authenticated ec2 connection object + + Returns: + True if Boto library accept instance_profile_name argument, else false + """ + run_instances_method = getattr(ec2, 'run_instances') + return 'instance_profile_name' in get_function_code(run_instances_method).co_varnames + + +def boto_supports_volume_encryption(): + """ + Check if Boto library supports encryption of EBS volumes (added in 2.29.0) + + Returns: + True if boto library has the named param as an argument on the request_spot_instances method, else False + """ + return hasattr(boto, 'Version') and LooseVersion(boto.Version) >= LooseVersion('2.29.0') + + +def create_block_device(module, ec2, volume): + # Not aware of a way to determine this programatically + # http://aws.amazon.com/about-aws/whats-new/2013/10/09/ebs-provisioned-iops-maximum-iops-gb-ratio-increased-to-30-1/ + MAX_IOPS_TO_SIZE_RATIO = 30 + + volume_type = volume.get('volume_type') + + if 'snapshot' not in volume and 'ephemeral' not in volume: + if 'volume_size' not in volume: + module.fail_json(msg='Size must be specified when creating a new volume or modifying the root volume') + if 'snapshot' in volume: + if volume_type == 'io1' and 'iops' not in volume: + module.fail_json(msg='io1 volumes must have an iops value set') + if 'iops' in volume: + snapshot = ec2.get_all_snapshots(snapshot_ids=[volume['snapshot']])[0] + size = volume.get('volume_size', snapshot.volume_size) + if int(volume['iops']) > MAX_IOPS_TO_SIZE_RATIO * size: + module.fail_json(msg='IOPS must be at most %d times greater than size' % MAX_IOPS_TO_SIZE_RATIO) + if 'ephemeral' in volume: + if 'snapshot' in volume: + module.fail_json(msg='Cannot set both ephemeral and snapshot') + if boto_supports_volume_encryption(): + return BlockDeviceType(snapshot_id=volume.get('snapshot'), + ephemeral_name=volume.get('ephemeral'), + size=volume.get('volume_size'), + volume_type=volume_type, + delete_on_termination=volume.get('delete_on_termination', False), + iops=volume.get('iops'), + encrypted=volume.get('encrypted', None)) + else: + return BlockDeviceType(snapshot_id=volume.get('snapshot'), + ephemeral_name=volume.get('ephemeral'), + size=volume.get('volume_size'), + volume_type=volume_type, + delete_on_termination=volume.get('delete_on_termination', False), + iops=volume.get('iops')) + + +def boto_supports_param_in_spot_request(ec2, param): + """ + Check if Boto library has a <param> in its request_spot_instances() method. For example, the placement_group parameter wasn't added until 2.3.0. + + ec2: authenticated ec2 connection object + + Returns: + True if boto library has the named param as an argument on the request_spot_instances method, else False + """ + method = getattr(ec2, 'request_spot_instances') + return param in get_function_code(method).co_varnames + + +def await_spot_requests(module, ec2, spot_requests, count): + """ + Wait for a group of spot requests to be fulfilled, or fail. + + module: Ansible module object + ec2: authenticated ec2 connection object + spot_requests: boto.ec2.spotinstancerequest.SpotInstanceRequest object returned by ec2.request_spot_instances + count: Total number of instances to be created by the spot requests + + Returns: + list of instance ID's created by the spot request(s) + """ + spot_wait_timeout = int(module.params.get('spot_wait_timeout')) + wait_complete = time.time() + spot_wait_timeout + + spot_req_inst_ids = dict() + while time.time() < wait_complete: + reqs = ec2.get_all_spot_instance_requests() + for sirb in spot_requests: + if sirb.id in spot_req_inst_ids: + continue + for sir in reqs: + if sir.id != sirb.id: + continue # this is not our spot instance + if sir.instance_id is not None: + spot_req_inst_ids[sirb.id] = sir.instance_id + elif sir.state == 'open': + continue # still waiting, nothing to do here + elif sir.state == 'active': + continue # Instance is created already, nothing to do here + elif sir.state == 'failed': + module.fail_json(msg="Spot instance request %s failed with status %s and fault %s:%s" % ( + sir.id, sir.status.code, sir.fault.code, sir.fault.message)) + elif sir.state == 'cancelled': + module.fail_json(msg="Spot instance request %s was cancelled before it could be fulfilled." % sir.id) + elif sir.state == 'closed': + # instance is terminating or marked for termination + # this may be intentional on the part of the operator, + # or it may have been terminated by AWS due to capacity, + # price, or group constraints in this case, we'll fail + # the module if the reason for the state is anything + # other than termination by user. Codes are documented at + # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-bid-status.html + if sir.status.code == 'instance-terminated-by-user': + # do nothing, since the user likely did this on purpose + pass + else: + spot_msg = "Spot instance request %s was closed by AWS with the status %s and fault %s:%s" + module.fail_json(msg=spot_msg % (sir.id, sir.status.code, sir.fault.code, sir.fault.message)) + + if len(spot_req_inst_ids) < count: + time.sleep(5) + else: + return list(spot_req_inst_ids.values()) + module.fail_json(msg="wait for spot requests timeout on %s" % time.asctime()) + + +def enforce_count(module, ec2, vpc): + + exact_count = module.params.get('exact_count') + count_tag = module.params.get('count_tag') + zone = module.params.get('zone') + + # fail here if the exact count was specified without filtering + # on a tag, as this may lead to a undesired removal of instances + if exact_count and count_tag is None: + module.fail_json(msg="you must use the 'count_tag' option with exact_count") + + reservations, instances = find_running_instances_by_count_tag(module, ec2, vpc, count_tag, zone) + + changed = None + checkmode = False + instance_dict_array = [] + changed_instance_ids = None + + if len(instances) == exact_count: + changed = False + elif len(instances) < exact_count: + changed = True + to_create = exact_count - len(instances) + if not checkmode: + (instance_dict_array, changed_instance_ids, changed) \ + = create_instances(module, ec2, vpc, override_count=to_create) + + for inst in instance_dict_array: + instances.append(inst) + elif len(instances) > exact_count: + changed = True + to_remove = len(instances) - exact_count + if not checkmode: + all_instance_ids = sorted([x.id for x in instances]) + remove_ids = all_instance_ids[0:to_remove] + + instances = [x for x in instances if x.id not in remove_ids] + + (changed, instance_dict_array, changed_instance_ids) \ + = terminate_instances(module, ec2, remove_ids) + terminated_list = [] + for inst in instance_dict_array: + inst['state'] = "terminated" + terminated_list.append(inst) + instance_dict_array = terminated_list + + # ensure all instances are dictionaries + all_instances = [] + for inst in instances: + + if not isinstance(inst, dict): + warn_if_public_ip_assignment_changed(module, inst) + inst = get_instance_info(inst) + all_instances.append(inst) + + return (all_instances, instance_dict_array, changed_instance_ids, changed) + + +def create_instances(module, ec2, vpc, override_count=None): + """ + Creates new instances + + module : AnsibleModule object + ec2: authenticated ec2 connection object + + Returns: + A list of dictionaries with instance information + about the instances that were launched + """ + + key_name = module.params.get('key_name') + id = module.params.get('id') + group_name = module.params.get('group') + group_id = module.params.get('group_id') + zone = module.params.get('zone') + instance_type = module.params.get('instance_type') + tenancy = module.params.get('tenancy') + spot_price = module.params.get('spot_price') + spot_type = module.params.get('spot_type') + image = module.params.get('image') + if override_count: + count = override_count + else: + count = module.params.get('count') + monitoring = module.params.get('monitoring') + kernel = module.params.get('kernel') + ramdisk = module.params.get('ramdisk') + wait = module.params.get('wait') + wait_timeout = int(module.params.get('wait_timeout')) + spot_wait_timeout = int(module.params.get('spot_wait_timeout')) + placement_group = module.params.get('placement_group') + user_data = module.params.get('user_data') + instance_tags = module.params.get('instance_tags') + vpc_subnet_id = module.params.get('vpc_subnet_id') + assign_public_ip = module.boolean(module.params.get('assign_public_ip')) + private_ip = module.params.get('private_ip') + instance_profile_name = module.params.get('instance_profile_name') + volumes = module.params.get('volumes') + ebs_optimized = module.params.get('ebs_optimized') + exact_count = module.params.get('exact_count') + count_tag = module.params.get('count_tag') + source_dest_check = module.boolean(module.params.get('source_dest_check')) + termination_protection = module.boolean(module.params.get('termination_protection')) + network_interfaces = module.params.get('network_interfaces') + spot_launch_group = module.params.get('spot_launch_group') + instance_initiated_shutdown_behavior = module.params.get('instance_initiated_shutdown_behavior') + + vpc_id = None + if vpc_subnet_id: + if not vpc: + module.fail_json(msg="region must be specified") + else: + vpc_id = vpc.get_all_subnets(subnet_ids=[vpc_subnet_id])[0].vpc_id + else: + vpc_id = None + + try: + # Here we try to lookup the group id from the security group name - if group is set. + if group_name: + if vpc_id: + grp_details = ec2.get_all_security_groups(filters={'vpc_id': vpc_id}) + else: + grp_details = ec2.get_all_security_groups() + if isinstance(group_name, string_types): + group_name = [group_name] + unmatched = set(group_name).difference(str(grp.name) for grp in grp_details) + if len(unmatched) > 0: + module.fail_json(msg="The following group names are not valid: %s" % ', '.join(unmatched)) + group_id = [str(grp.id) for grp in grp_details if str(grp.name) in group_name] + # Now we try to lookup the group id testing if group exists. + elif group_id: + # wrap the group_id in a list if it's not one already + if isinstance(group_id, string_types): + group_id = [group_id] + grp_details = ec2.get_all_security_groups(group_ids=group_id) + group_name = [grp_item.name for grp_item in grp_details] + except boto.exception.NoAuthHandlerFound as e: + module.fail_json(msg=str(e)) + + # Lookup any instances that much our run id. + + running_instances = [] + count_remaining = int(count) + + if id is not None: + filter_dict = {'client-token': id, 'instance-state-name': 'running'} + previous_reservations = ec2.get_all_instances(None, filter_dict) + for res in previous_reservations: + for prev_instance in res.instances: + running_instances.append(prev_instance) + count_remaining = count_remaining - len(running_instances) + + # Both min_count and max_count equal count parameter. This means the launch request is explicit (we want count, or fail) in how many instances we want. + + if count_remaining == 0: + changed = False + else: + changed = True + try: + params = {'image_id': image, + 'key_name': key_name, + 'monitoring_enabled': monitoring, + 'placement': zone, + 'instance_type': instance_type, + 'kernel_id': kernel, + 'ramdisk_id': ramdisk} + if user_data is not None: + params['user_data'] = to_bytes(user_data, errors='surrogate_or_strict') + + if ebs_optimized: + params['ebs_optimized'] = ebs_optimized + + # 'tenancy' always has a default value, but it is not a valid parameter for spot instance request + if not spot_price: + params['tenancy'] = tenancy + + if boto_supports_profile_name_arg(ec2): + params['instance_profile_name'] = instance_profile_name + else: + if instance_profile_name is not None: + module.fail_json( + msg="instance_profile_name parameter requires Boto version 2.5.0 or higher") + + if assign_public_ip is not None: + if not boto_supports_associate_public_ip_address(ec2): + module.fail_json( + msg="assign_public_ip parameter requires Boto version 2.13.0 or higher.") + elif not vpc_subnet_id: + module.fail_json( + msg="assign_public_ip only available with vpc_subnet_id") + + else: + if private_ip: + interface = boto.ec2.networkinterface.NetworkInterfaceSpecification( + subnet_id=vpc_subnet_id, + private_ip_address=private_ip, + groups=group_id, + associate_public_ip_address=assign_public_ip) + else: + interface = boto.ec2.networkinterface.NetworkInterfaceSpecification( + subnet_id=vpc_subnet_id, + groups=group_id, + associate_public_ip_address=assign_public_ip) + interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface) + params['network_interfaces'] = interfaces + else: + if network_interfaces: + if isinstance(network_interfaces, string_types): + network_interfaces = [network_interfaces] + interfaces = [] + for i, network_interface_id in enumerate(network_interfaces): + interface = boto.ec2.networkinterface.NetworkInterfaceSpecification( + network_interface_id=network_interface_id, + device_index=i) + interfaces.append(interface) + params['network_interfaces'] = \ + boto.ec2.networkinterface.NetworkInterfaceCollection(*interfaces) + else: + params['subnet_id'] = vpc_subnet_id + if vpc_subnet_id: + params['security_group_ids'] = group_id + else: + params['security_groups'] = group_name + + if volumes: + bdm = BlockDeviceMapping() + for volume in volumes: + if 'device_name' not in volume: + module.fail_json(msg='Device name must be set for volume') + # Minimum volume size is 1GiB. We'll use volume size explicitly set to 0 + # to be a signal not to create this volume + if 'volume_size' not in volume or int(volume['volume_size']) > 0: + bdm[volume['device_name']] = create_block_device(module, ec2, volume) + + params['block_device_map'] = bdm + + # check to see if we're using spot pricing first before starting instances + if not spot_price: + if assign_public_ip is not None and private_ip: + params.update( + dict( + min_count=count_remaining, + max_count=count_remaining, + client_token=id, + placement_group=placement_group, + ) + ) + else: + params.update( + dict( + min_count=count_remaining, + max_count=count_remaining, + client_token=id, + placement_group=placement_group, + private_ip_address=private_ip, + ) + ) + + # For ordinary (not spot) instances, we can select 'stop' + # (the default) or 'terminate' here. + params['instance_initiated_shutdown_behavior'] = instance_initiated_shutdown_behavior or 'stop' + + try: + res = ec2.run_instances(**params) + except boto.exception.EC2ResponseError as e: + if (params['instance_initiated_shutdown_behavior'] != 'terminate' and + "InvalidParameterCombination" == e.error_code): + params['instance_initiated_shutdown_behavior'] = 'terminate' + res = ec2.run_instances(**params) + else: + raise + + instids = [i.id for i in res.instances] + while True: + try: + ec2.get_all_instances(instids) + break + except boto.exception.EC2ResponseError as e: + if "<Code>InvalidInstanceID.NotFound</Code>" in str(e): + # there's a race between start and get an instance + continue + else: + module.fail_json(msg=str(e)) + + # The instances returned through ec2.run_instances above can be in + # terminated state due to idempotency. See commit 7f11c3d for a complete + # explanation. + terminated_instances = [ + str(instance.id) for instance in res.instances if instance.state == 'terminated' + ] + if terminated_instances: + module.fail_json(msg="Instances with id(s) %s " % terminated_instances + + "were created previously but have since been terminated - " + + "use a (possibly different) 'instanceid' parameter") + + else: + if private_ip: + module.fail_json( + msg='private_ip only available with on-demand (non-spot) instances') + if boto_supports_param_in_spot_request(ec2, 'placement_group'): + params['placement_group'] = placement_group + elif placement_group: + module.fail_json( + msg="placement_group parameter requires Boto version 2.3.0 or higher.") + + # You can't tell spot instances to 'stop'; they will always be + # 'terminate'd. For convenience, we'll ignore the latter value. + if instance_initiated_shutdown_behavior and instance_initiated_shutdown_behavior != 'terminate': + module.fail_json( + msg="instance_initiated_shutdown_behavior=stop is not supported for spot instances.") + + if spot_launch_group and isinstance(spot_launch_group, string_types): + params['launch_group'] = spot_launch_group + + params.update(dict( + count=count_remaining, + type=spot_type, + )) + + # Set spot ValidUntil + # ValidUntil -> (timestamp). The end date of the request, in + # UTC format (for example, YYYY -MM -DD T*HH* :MM :SS Z). + utc_valid_until = ( + datetime.datetime.utcnow() + + datetime.timedelta(seconds=spot_wait_timeout)) + params['valid_until'] = utc_valid_until.strftime('%Y-%m-%dT%H:%M:%S.000Z') + + res = ec2.request_spot_instances(spot_price, **params) + + # Now we have to do the intermediate waiting + if wait: + instids = await_spot_requests(module, ec2, res, count) + else: + instids = [] + except boto.exception.BotoServerError as e: + module.fail_json(msg="Instance creation failed => %s: %s" % (e.error_code, e.error_message)) + + # wait here until the instances are up + num_running = 0 + wait_timeout = time.time() + wait_timeout + res_list = () + while wait_timeout > time.time() and num_running < len(instids): + try: + res_list = ec2.get_all_instances(instids) + except boto.exception.BotoServerError as e: + if e.error_code == 'InvalidInstanceID.NotFound': + time.sleep(1) + continue + else: + raise + + num_running = 0 + for res in res_list: + num_running += len([i for i in res.instances if i.state == 'running']) + if len(res_list) <= 0: + # got a bad response of some sort, possibly due to + # stale/cached data. Wait a second and then try again + time.sleep(1) + continue + if wait and num_running < len(instids): + time.sleep(5) + else: + break + + if wait and wait_timeout <= time.time(): + # waiting took too long + module.fail_json(msg="wait for instances running timeout on %s" % time.asctime()) + + # We do this after the loop ends so that we end up with one list + for res in res_list: + running_instances.extend(res.instances) + + # Enabled by default by AWS + if source_dest_check is False: + for inst in res.instances: + inst.modify_attribute('sourceDestCheck', False) + + # Disabled by default by AWS + if termination_protection is True: + for inst in res.instances: + inst.modify_attribute('disableApiTermination', True) + + # Leave this as late as possible to try and avoid InvalidInstanceID.NotFound + if instance_tags and instids: + try: + ec2.create_tags(instids, instance_tags) + except boto.exception.EC2ResponseError as e: + module.fail_json(msg="Instance tagging failed => %s: %s" % (e.error_code, e.error_message)) + + instance_dict_array = [] + created_instance_ids = [] + for inst in running_instances: + inst.update() + d = get_instance_info(inst) + created_instance_ids.append(inst.id) + instance_dict_array.append(d) + + return (instance_dict_array, created_instance_ids, changed) + + +def terminate_instances(module, ec2, instance_ids): + """ + Terminates a list of instances + + module: Ansible module object + ec2: authenticated ec2 connection object + termination_list: a list of instances to terminate in the form of + [ {id: <inst-id>}, ..] + + Returns a dictionary of instance information + about the instances terminated. + + If the instance to be terminated is running + "changed" will be set to False. + + """ + + # Whether to wait for termination to complete before returning + wait = module.params.get('wait') + wait_timeout = int(module.params.get('wait_timeout')) + + changed = False + instance_dict_array = [] + + if not isinstance(instance_ids, list) or len(instance_ids) < 1: + module.fail_json(msg='instance_ids should be a list of instances, aborting') + + terminated_instance_ids = [] + for res in ec2.get_all_instances(instance_ids): + for inst in res.instances: + if inst.state == 'running' or inst.state == 'stopped': + terminated_instance_ids.append(inst.id) + instance_dict_array.append(get_instance_info(inst)) + try: + ec2.terminate_instances([inst.id]) + except EC2ResponseError as e: + module.fail_json(msg='Unable to terminate instance {0}, error: {1}'.format(inst.id, e)) + changed = True + + # wait here until the instances are 'terminated' + if wait: + num_terminated = 0 + wait_timeout = time.time() + wait_timeout + while wait_timeout > time.time() and num_terminated < len(terminated_instance_ids): + response = ec2.get_all_instances(instance_ids=terminated_instance_ids, + filters={'instance-state-name': 'terminated'}) + try: + num_terminated = sum([len(res.instances) for res in response]) + except Exception as e: + # got a bad response of some sort, possibly due to + # stale/cached data. Wait a second and then try again + time.sleep(1) + continue + + if num_terminated < len(terminated_instance_ids): + time.sleep(5) + + # waiting took too long + if wait_timeout < time.time() and num_terminated < len(terminated_instance_ids): + module.fail_json(msg="wait for instance termination timeout on %s" % time.asctime()) + # Lets get the current state of the instances after terminating - issue600 + instance_dict_array = [] + for res in ec2.get_all_instances(instance_ids=terminated_instance_ids, filters={'instance-state-name': 'terminated'}): + for inst in res.instances: + instance_dict_array.append(get_instance_info(inst)) + + return (changed, instance_dict_array, terminated_instance_ids) + + +def startstop_instances(module, ec2, instance_ids, state, instance_tags): + """ + Starts or stops a list of existing instances + + module: Ansible module object + ec2: authenticated ec2 connection object + instance_ids: The list of instances to start in the form of + [ {id: <inst-id>}, ..] + instance_tags: A dict of tag keys and values in the form of + {key: value, ... } + state: Intended state ("running" or "stopped") + + Returns a dictionary of instance information + about the instances started/stopped. + + If the instance was not able to change state, + "changed" will be set to False. + + Note that if instance_ids and instance_tags are both non-empty, + this method will process the intersection of the two + """ + + wait = module.params.get('wait') + wait_timeout = int(module.params.get('wait_timeout')) + group_id = module.params.get('group_id') + group_name = module.params.get('group') + changed = False + instance_dict_array = [] + + if not isinstance(instance_ids, list) or len(instance_ids) < 1: + # Fail unless the user defined instance tags + if not instance_tags: + module.fail_json(msg='instance_ids should be a list of instances, aborting') + + # To make an EC2 tag filter, we need to prepend 'tag:' to each key. + # An empty filter does no filtering, so it's safe to pass it to the + # get_all_instances method even if the user did not specify instance_tags + filters = {} + if instance_tags: + for key, value in instance_tags.items(): + filters["tag:" + key] = value + + if module.params.get('id'): + filters['client-token'] = module.params['id'] + # Check that our instances are not in the state we want to take + + # Check (and eventually change) instances attributes and instances state + existing_instances_array = [] + for res in ec2.get_all_instances(instance_ids, filters=filters): + for inst in res.instances: + + warn_if_public_ip_assignment_changed(module, inst) + + changed = (check_source_dest_attr(module, inst, ec2) or + check_termination_protection(module, inst) or changed) + + # Check security groups and if we're using ec2-vpc; ec2-classic security groups may not be modified + if inst.vpc_id and group_name: + grp_details = ec2.get_all_security_groups(filters={'vpc_id': inst.vpc_id}) + if isinstance(group_name, string_types): + group_name = [group_name] + unmatched = set(group_name) - set(to_text(grp.name) for grp in grp_details) + if unmatched: + module.fail_json(msg="The following group names are not valid: %s" % ', '.join(unmatched)) + group_ids = [to_text(grp.id) for grp in grp_details if to_text(grp.name) in group_name] + elif inst.vpc_id and group_id: + if isinstance(group_id, string_types): + group_id = [group_id] + grp_details = ec2.get_all_security_groups(group_ids=group_id) + group_ids = [grp_item.id for grp_item in grp_details] + if inst.vpc_id and (group_name or group_id): + if set(sg.id for sg in inst.groups) != set(group_ids): + changed = inst.modify_attribute('groupSet', group_ids) + + # Check instance state + if inst.state != state: + instance_dict_array.append(get_instance_info(inst)) + try: + if state == 'running': + inst.start() + else: + inst.stop() + except EC2ResponseError as e: + module.fail_json(msg='Unable to change state for instance {0}, error: {1}'.format(inst.id, e)) + changed = True + existing_instances_array.append(inst.id) + + instance_ids = list(set(existing_instances_array + (instance_ids or []))) + # Wait for all the instances to finish starting or stopping + wait_timeout = time.time() + wait_timeout + while wait and wait_timeout > time.time(): + instance_dict_array = [] + matched_instances = [] + for res in ec2.get_all_instances(instance_ids): + for i in res.instances: + if i.state == state: + instance_dict_array.append(get_instance_info(i)) + matched_instances.append(i) + if len(matched_instances) < len(instance_ids): + time.sleep(5) + else: + break + + if wait and wait_timeout <= time.time(): + # waiting took too long + module.fail_json(msg="wait for instances running timeout on %s" % time.asctime()) + + return (changed, instance_dict_array, instance_ids) + + +def restart_instances(module, ec2, instance_ids, state, instance_tags): + """ + Restarts a list of existing instances + + module: Ansible module object + ec2: authenticated ec2 connection object + instance_ids: The list of instances to start in the form of + [ {id: <inst-id>}, ..] + instance_tags: A dict of tag keys and values in the form of + {key: value, ... } + state: Intended state ("restarted") + + Returns a dictionary of instance information + about the instances. + + If the instance was not able to change state, + "changed" will be set to False. + + Wait will not apply here as this is a OS level operation. + + Note that if instance_ids and instance_tags are both non-empty, + this method will process the intersection of the two. + """ + + changed = False + instance_dict_array = [] + + if not isinstance(instance_ids, list) or len(instance_ids) < 1: + # Fail unless the user defined instance tags + if not instance_tags: + module.fail_json(msg='instance_ids should be a list of instances, aborting') + + # To make an EC2 tag filter, we need to prepend 'tag:' to each key. + # An empty filter does no filtering, so it's safe to pass it to the + # get_all_instances method even if the user did not specify instance_tags + filters = {} + if instance_tags: + for key, value in instance_tags.items(): + filters["tag:" + key] = value + if module.params.get('id'): + filters['client-token'] = module.params['id'] + + # Check that our instances are not in the state we want to take + + # Check (and eventually change) instances attributes and instances state + for res in ec2.get_all_instances(instance_ids, filters=filters): + for inst in res.instances: + + warn_if_public_ip_assignment_changed(module, inst) + + changed = (check_source_dest_attr(module, inst, ec2) or + check_termination_protection(module, inst) or changed) + + # Check instance state + if inst.state != state: + instance_dict_array.append(get_instance_info(inst)) + try: + inst.reboot() + except EC2ResponseError as e: + module.fail_json(msg='Unable to change state for instance {0}, error: {1}'.format(inst.id, e)) + changed = True + + return (changed, instance_dict_array, instance_ids) + + +def check_termination_protection(module, inst): + """ + Check the instance disableApiTermination attribute. + + module: Ansible module object + inst: EC2 instance object + + returns: True if state changed None otherwise + """ + + termination_protection = module.params.get('termination_protection') + + if (inst.get_attribute('disableApiTermination')['disableApiTermination'] != termination_protection and termination_protection is not None): + inst.modify_attribute('disableApiTermination', termination_protection) + return True + + +def check_source_dest_attr(module, inst, ec2): + """ + Check the instance sourceDestCheck attribute. + + module: Ansible module object + inst: EC2 instance object + + returns: True if state changed None otherwise + """ + + source_dest_check = module.params.get('source_dest_check') + + if source_dest_check is not None: + try: + if inst.vpc_id is not None and inst.get_attribute('sourceDestCheck')['sourceDestCheck'] != source_dest_check: + inst.modify_attribute('sourceDestCheck', source_dest_check) + return True + except boto.exception.EC2ResponseError as exc: + # instances with more than one Elastic Network Interface will + # fail, because they have the sourceDestCheck attribute defined + # per-interface + if exc.code == 'InvalidInstanceID': + for interface in inst.interfaces: + if interface.source_dest_check != source_dest_check: + ec2.modify_network_interface_attribute(interface.id, "sourceDestCheck", source_dest_check) + return True + else: + module.fail_json(msg='Failed to handle source_dest_check state for instance {0}, error: {1}'.format(inst.id, exc), + exception=traceback.format_exc()) + + +def warn_if_public_ip_assignment_changed(module, instance): + # This is a non-modifiable attribute. + assign_public_ip = module.params.get('assign_public_ip') + + # Check that public ip assignment is the same and warn if not + public_dns_name = getattr(instance, 'public_dns_name', None) + if (assign_public_ip or public_dns_name) and (not public_dns_name or assign_public_ip is False): + module.warn("Unable to modify public ip assignment to {0} for instance {1}. " + "Whether or not to assign a public IP is determined during instance creation.".format(assign_public_ip, instance.id)) + + +def main(): + argument_spec = ec2_argument_spec() + argument_spec.update( + dict( + key_name=dict(aliases=['keypair']), + id=dict(), + group=dict(type='list', aliases=['groups']), + group_id=dict(type='list'), + zone=dict(aliases=['aws_zone', 'ec2_zone']), + instance_type=dict(aliases=['type']), + spot_price=dict(), + spot_type=dict(default='one-time', choices=["one-time", "persistent"]), + spot_launch_group=dict(), + image=dict(), + kernel=dict(), + count=dict(type='int', default='1'), + monitoring=dict(type='bool', default=False), + ramdisk=dict(), + wait=dict(type='bool', default=False), + wait_timeout=dict(type='int', default=300), + spot_wait_timeout=dict(type='int', default=600), + placement_group=dict(), + user_data=dict(), + instance_tags=dict(type='dict'), + vpc_subnet_id=dict(), + assign_public_ip=dict(type='bool'), + private_ip=dict(), + instance_profile_name=dict(), + instance_ids=dict(type='list', aliases=['instance_id']), + source_dest_check=dict(type='bool', default=None), + termination_protection=dict(type='bool', default=None), + state=dict(default='present', choices=['present', 'absent', 'running', 'restarted', 'stopped']), + instance_initiated_shutdown_behavior=dict(default='stop', choices=['stop', 'terminate']), + exact_count=dict(type='int', default=None), + count_tag=dict(type='raw'), + volumes=dict(type='list'), + ebs_optimized=dict(type='bool', default=False), + tenancy=dict(default='default', choices=['default', 'dedicated']), + network_interfaces=dict(type='list', aliases=['network_interface']) + ) + ) + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=[ + # Can be uncommented when we finish the deprecation cycle. + # ['group', 'group_id'], + ['exact_count', 'count'], + ['exact_count', 'state'], + ['exact_count', 'instance_ids'], + ['network_interfaces', 'assign_public_ip'], + ['network_interfaces', 'group'], + ['network_interfaces', 'group_id'], + ['network_interfaces', 'private_ip'], + ['network_interfaces', 'vpc_subnet_id'], + ], + ) + + if module.params.get('group') and module.params.get('group_id'): + module.deprecate( + msg='Support for passing both group and group_id has been deprecated. ' + 'Currently group_id is ignored, in future passing both will result in an error', + version='2.14', collection_name='ansible.builtin') + + if not HAS_BOTO: + module.fail_json(msg='boto required for this module') + + try: + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module) + if module.params.get('region') or not module.params.get('ec2_url'): + ec2 = ec2_connect(module) + elif module.params.get('ec2_url'): + ec2 = connect_ec2_endpoint(ec2_url, **aws_connect_kwargs) + + if 'region' not in aws_connect_kwargs: + aws_connect_kwargs['region'] = ec2.region + + vpc = connect_vpc(**aws_connect_kwargs) + except boto.exception.NoAuthHandlerFound as e: + module.fail_json(msg="Failed to get connection: %s" % e.message, exception=traceback.format_exc()) + + tagged_instances = [] + + state = module.params['state'] + + if state == 'absent': + instance_ids = module.params['instance_ids'] + if not instance_ids: + module.fail_json(msg='instance_ids list is required for absent state') + + (changed, instance_dict_array, new_instance_ids) = terminate_instances(module, ec2, instance_ids) + + elif state in ('running', 'stopped'): + instance_ids = module.params.get('instance_ids') + instance_tags = module.params.get('instance_tags') + if not (isinstance(instance_ids, list) or isinstance(instance_tags, dict)): + module.fail_json(msg='running list needs to be a list of instances or set of tags to run: %s' % instance_ids) + + (changed, instance_dict_array, new_instance_ids) = startstop_instances(module, ec2, instance_ids, state, instance_tags) + + elif state in ('restarted'): + instance_ids = module.params.get('instance_ids') + instance_tags = module.params.get('instance_tags') + if not (isinstance(instance_ids, list) or isinstance(instance_tags, dict)): + module.fail_json(msg='running list needs to be a list of instances or set of tags to run: %s' % instance_ids) + + (changed, instance_dict_array, new_instance_ids) = restart_instances(module, ec2, instance_ids, state, instance_tags) + + elif state == 'present': + # Changed is always set to true when provisioning new instances + if not module.params.get('image'): + module.fail_json(msg='image parameter is required for new instance') + + if module.params.get('exact_count') is None: + (instance_dict_array, new_instance_ids, changed) = create_instances(module, ec2, vpc) + else: + (tagged_instances, instance_dict_array, new_instance_ids, changed) = enforce_count(module, ec2, vpc) + + # Always return instances in the same order + if new_instance_ids: + new_instance_ids.sort() + if instance_dict_array: + instance_dict_array.sort(key=lambda x: x['id']) + if tagged_instances: + tagged_instances.sort(key=lambda x: x['id']) + + module.exit_json(changed=changed, instance_ids=new_instance_ids, instances=instance_dict_array, tagged_instances=tagged_instances) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_ami_info.py b/test/support/integration/plugins/modules/ec2_ami_info.py new file mode 100644 index 00000000..53c2374d --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_ami_info.py @@ -0,0 +1,282 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: ec2_ami_info +version_added: '2.5' +short_description: Gather information about ec2 AMIs +description: + - Gather information about ec2 AMIs + - This module was called C(ec2_ami_facts) before Ansible 2.9. The usage did not change. +author: + - Prasad Katti (@prasadkatti) +requirements: [ boto3 ] +options: + image_ids: + description: One or more image IDs. + aliases: [image_id] + type: list + elements: str + filters: + description: + - A dict of filters to apply. Each dict item consists of a filter key and a filter value. + - See U(https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeImages.html) for possible filters. + - Filter names and values are case sensitive. + type: dict + owners: + description: + - Filter the images by the owner. Valid options are an AWS account ID, self, + or an AWS owner alias ( amazon | aws-marketplace | microsoft ). + aliases: [owner] + type: list + elements: str + executable_users: + description: + - Filter images by users with explicit launch permissions. Valid options are an AWS account ID, self, or all (public AMIs). + aliases: [executable_user] + type: list + elements: str + describe_image_attributes: + description: + - Describe attributes (like launchPermission) of the images found. + default: no + type: bool + +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +- name: gather information about an AMI using ami-id + ec2_ami_info: + image_ids: ami-5b488823 + +- name: gather information about all AMIs with tag key Name and value webapp + ec2_ami_info: + filters: + "tag:Name": webapp + +- name: gather information about an AMI with 'AMI Name' equal to foobar + ec2_ami_info: + filters: + name: foobar + +- name: gather information about Ubuntu 17.04 AMIs published by Canonical (099720109477) + ec2_ami_info: + owners: 099720109477 + filters: + name: "ubuntu/images/ubuntu-zesty-17.04-*" +''' + +RETURN = ''' +images: + description: A list of images. + returned: always + type: list + elements: dict + contains: + architecture: + description: The architecture of the image. + returned: always + type: str + sample: x86_64 + block_device_mappings: + description: Any block device mapping entries. + returned: always + type: list + elements: dict + contains: + device_name: + description: The device name exposed to the instance. + returned: always + type: str + sample: /dev/sda1 + ebs: + description: EBS volumes + returned: always + type: complex + creation_date: + description: The date and time the image was created. + returned: always + type: str + sample: '2017-10-16T19:22:13.000Z' + description: + description: The description of the AMI. + returned: always + type: str + sample: '' + ena_support: + description: Whether enhanced networking with ENA is enabled. + returned: always + type: bool + sample: true + hypervisor: + description: The hypervisor type of the image. + returned: always + type: str + sample: xen + image_id: + description: The ID of the AMI. + returned: always + type: str + sample: ami-5b466623 + image_location: + description: The location of the AMI. + returned: always + type: str + sample: 408466080000/Webapp + image_type: + description: The type of image. + returned: always + type: str + sample: machine + launch_permissions: + description: A List of AWS accounts may launch the AMI. + returned: When image is owned by calling account and I(describe_image_attributes) is yes. + type: list + elements: dict + contains: + group: + description: A value of 'all' means the AMI is public. + type: str + user_id: + description: An AWS account ID with permissions to launch the AMI. + type: str + sample: [{"group": "all"}, {"user_id": "408466080000"}] + name: + description: The name of the AMI that was provided during image creation. + returned: always + type: str + sample: Webapp + owner_id: + description: The AWS account ID of the image owner. + returned: always + type: str + sample: '408466080000' + public: + description: Whether the image has public launch permissions. + returned: always + type: bool + sample: true + root_device_name: + description: The device name of the root device. + returned: always + type: str + sample: /dev/sda1 + root_device_type: + description: The type of root device used by the AMI. + returned: always + type: str + sample: ebs + sriov_net_support: + description: Whether enhanced networking is enabled. + returned: always + type: str + sample: simple + state: + description: The current state of the AMI. + returned: always + type: str + sample: available + tags: + description: Any tags assigned to the image. + returned: always + type: dict + virtualization_type: + description: The type of virtualization of the AMI. + returned: always + type: str + sample: hvm +''' + +try: + from botocore.exceptions import ClientError, BotoCoreError +except ImportError: + pass # caught by AnsibleAWSModule + +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import ansible_dict_to_boto3_filter_list, camel_dict_to_snake_dict, boto3_tag_list_to_ansible_dict + + +def list_ec2_images(ec2_client, module): + + image_ids = module.params.get("image_ids") + owners = module.params.get("owners") + executable_users = module.params.get("executable_users") + filters = module.params.get("filters") + owner_param = [] + + # describe_images is *very* slow if you pass the `Owners` + # param (unless it's self), for some reason. + # Converting the owners to filters and removing from the + # owners param greatly speeds things up. + # Implementation based on aioue's suggestion in #24886 + for owner in owners: + if owner.isdigit(): + if 'owner-id' not in filters: + filters['owner-id'] = list() + filters['owner-id'].append(owner) + elif owner == 'self': + # self not a valid owner-alias filter (https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeImages.html) + owner_param.append(owner) + else: + if 'owner-alias' not in filters: + filters['owner-alias'] = list() + filters['owner-alias'].append(owner) + + filters = ansible_dict_to_boto3_filter_list(filters) + + try: + images = ec2_client.describe_images(ImageIds=image_ids, Filters=filters, Owners=owner_param, ExecutableUsers=executable_users) + images = [camel_dict_to_snake_dict(image) for image in images["Images"]] + except (ClientError, BotoCoreError) as err: + module.fail_json_aws(err, msg="error describing images") + for image in images: + try: + image['tags'] = boto3_tag_list_to_ansible_dict(image.get('tags', [])) + if module.params.get("describe_image_attributes"): + launch_permissions = ec2_client.describe_image_attribute(Attribute='launchPermission', ImageId=image['image_id'])['LaunchPermissions'] + image['launch_permissions'] = [camel_dict_to_snake_dict(perm) for perm in launch_permissions] + except (ClientError, BotoCoreError) as err: + # describing launch permissions of images owned by others is not permitted, but shouldn't cause failures + pass + + images.sort(key=lambda e: e.get('creation_date', '')) # it may be possible that creation_date does not always exist + module.exit_json(images=images) + + +def main(): + + argument_spec = dict( + image_ids=dict(default=[], type='list', aliases=['image_id']), + filters=dict(default={}, type='dict'), + owners=dict(default=[], type='list', aliases=['owner']), + executable_users=dict(default=[], type='list', aliases=['executable_user']), + describe_image_attributes=dict(default=False, type='bool') + ) + + module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True) + if module._module._name == 'ec2_ami_facts': + module._module.deprecate("The 'ec2_ami_facts' module has been renamed to 'ec2_ami_info'", + version='2.13', collection_name='ansible.builtin') + + ec2_client = module.client('ec2') + + list_ec2_images(ec2_client, module) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_group.py b/test/support/integration/plugins/modules/ec2_group.py new file mode 100644 index 00000000..bc416f66 --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_group.py @@ -0,0 +1,1345 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = ''' +--- +module: ec2_group +author: "Andrew de Quincey (@adq)" +version_added: "1.3" +requirements: [ boto3 ] +short_description: maintain an ec2 VPC security group. +description: + - Maintains ec2 security groups. This module has a dependency on python-boto >= 2.5. +options: + name: + description: + - Name of the security group. + - One of and only one of I(name) or I(group_id) is required. + - Required if I(state=present). + required: false + type: str + group_id: + description: + - Id of group to delete (works only with absent). + - One of and only one of I(name) or I(group_id) is required. + required: false + version_added: "2.4" + type: str + description: + description: + - Description of the security group. Required when C(state) is C(present). + required: false + type: str + vpc_id: + description: + - ID of the VPC to create the group in. + required: false + type: str + rules: + description: + - List of firewall inbound rules to enforce in this group (see example). If none are supplied, + no inbound rules will be enabled. Rules list may include its own name in `group_name`. + This allows idempotent loopback additions (e.g. allow group to access itself). + Rule sources list support was added in version 2.4. This allows to define multiple sources per + source type as well as multiple source types per rule. Prior to 2.4 an individual source is allowed. + In version 2.5 support for rule descriptions was added. + required: false + type: list + elements: dict + suboptions: + cidr_ip: + type: str + description: + - The IPv4 CIDR range traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + cidr_ipv6: + type: str + description: + - The IPv6 CIDR range traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + ip_prefix: + type: str + description: + - The IP Prefix U(https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-prefix-lists.html) + that traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_id: + type: str + description: + - The ID of the Security Group that traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_name: + type: str + description: + - Name of the Security Group that traffic is coming from. + - If the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_desc: + type: str + description: + - If the I(group_name) is set and the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + proto: + type: str + description: + - The IP protocol name (C(tcp), C(udp), C(icmp), C(icmpv6)) or number (U(https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers)) + from_port: + type: int + description: The start of the range of ports that traffic is coming from. A value of C(-1) indicates all ports. + to_port: + type: int + description: The end of the range of ports that traffic is coming from. A value of C(-1) indicates all ports. + rule_desc: + type: str + description: A description for the rule. + rules_egress: + description: + - List of firewall outbound rules to enforce in this group (see example). If none are supplied, + a default all-out rule is assumed. If an empty list is supplied, no outbound rules will be enabled. + Rule Egress sources list support was added in version 2.4. In version 2.5 support for rule descriptions + was added. + required: false + version_added: "1.6" + type: list + elements: dict + suboptions: + cidr_ip: + type: str + description: + - The IPv4 CIDR range traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + cidr_ipv6: + type: str + description: + - The IPv6 CIDR range traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + ip_prefix: + type: str + description: + - The IP Prefix U(https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-prefix-lists.html) + that traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_id: + type: str + description: + - The ID of the Security Group that traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_name: + type: str + description: + - Name of the Security Group that traffic is going to. + - If the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_desc: + type: str + description: + - If the I(group_name) is set and the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + proto: + type: str + description: + - The IP protocol name (C(tcp), C(udp), C(icmp), C(icmpv6)) or number (U(https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers)) + from_port: + type: int + description: The start of the range of ports that traffic is going to. A value of C(-1) indicates all ports. + to_port: + type: int + description: The end of the range of ports that traffic is going to. A value of C(-1) indicates all ports. + rule_desc: + type: str + description: A description for the rule. + state: + version_added: "1.4" + description: + - Create or delete a security group. + required: false + default: 'present' + choices: [ "present", "absent" ] + aliases: [] + type: str + purge_rules: + version_added: "1.8" + description: + - Purge existing rules on security group that are not found in rules. + required: false + default: 'true' + aliases: [] + type: bool + purge_rules_egress: + version_added: "1.8" + description: + - Purge existing rules_egress on security group that are not found in rules_egress. + required: false + default: 'true' + aliases: [] + type: bool + tags: + version_added: "2.4" + description: + - A dictionary of one or more tags to assign to the security group. + required: false + type: dict + aliases: ['resource_tags'] + purge_tags: + version_added: "2.4" + description: + - If yes, existing tags will be purged from the resource to match exactly what is defined by I(tags) parameter. If the I(tags) parameter is not set then + tags will not be modified. + required: false + default: yes + type: bool + +extends_documentation_fragment: + - aws + - ec2 + +notes: + - If a rule declares a group_name and that group doesn't exist, it will be + automatically created. In that case, group_desc should be provided as well. + The module will refuse to create a depended-on group without a description. + - Preview diff mode support is added in version 2.7. +''' + +EXAMPLES = ''' +- name: example using security group rule descriptions + ec2_group: + name: "{{ name }}" + description: sg with rule descriptions + vpc_id: vpc-xxxxxxxx + profile: "{{ aws_profile }}" + region: us-east-1 + rules: + - proto: tcp + ports: + - 80 + cidr_ip: 0.0.0.0/0 + rule_desc: allow all on port 80 + +- name: example ec2 group + ec2_group: + name: example + description: an example EC2 group + vpc_id: 12345 + region: eu-west-1 + aws_secret_key: SECRET + aws_access_key: ACCESS + rules: + - proto: tcp + from_port: 80 + to_port: 80 + cidr_ip: 0.0.0.0/0 + - proto: tcp + from_port: 22 + to_port: 22 + cidr_ip: 10.0.0.0/8 + - proto: tcp + from_port: 443 + to_port: 443 + # this should only be needed for EC2 Classic security group rules + # because in a VPC an ELB will use a user-account security group + group_id: amazon-elb/sg-87654321/amazon-elb-sg + - proto: tcp + from_port: 3306 + to_port: 3306 + group_id: 123412341234/sg-87654321/exact-name-of-sg + - proto: udp + from_port: 10050 + to_port: 10050 + cidr_ip: 10.0.0.0/8 + - proto: udp + from_port: 10051 + to_port: 10051 + group_id: sg-12345678 + - proto: icmp + from_port: 8 # icmp type, -1 = any type + to_port: -1 # icmp subtype, -1 = any subtype + cidr_ip: 10.0.0.0/8 + - proto: all + # the containing group name may be specified here + group_name: example + - proto: all + # in the 'proto' attribute, if you specify -1, all, or a protocol number other than tcp, udp, icmp, or 58 (ICMPv6), + # traffic on all ports is allowed, regardless of any ports you specify + from_port: 10050 # this value is ignored + to_port: 10050 # this value is ignored + cidr_ip: 10.0.0.0/8 + + rules_egress: + - proto: tcp + from_port: 80 + to_port: 80 + cidr_ip: 0.0.0.0/0 + cidr_ipv6: 64:ff9b::/96 + group_name: example-other + # description to use if example-other needs to be created + group_desc: other example EC2 group + +- name: example2 ec2 group + ec2_group: + name: example2 + description: an example2 EC2 group + vpc_id: 12345 + region: eu-west-1 + rules: + # 'ports' rule keyword was introduced in version 2.4. It accepts a single port value or a list of values including ranges (from_port-to_port). + - proto: tcp + ports: 22 + group_name: example-vpn + - proto: tcp + ports: + - 80 + - 443 + - 8080-8099 + cidr_ip: 0.0.0.0/0 + # Rule sources list support was added in version 2.4. This allows to define multiple sources per source type as well as multiple source types per rule. + - proto: tcp + ports: + - 6379 + - 26379 + group_name: + - example-vpn + - example-redis + - proto: tcp + ports: 5665 + group_name: example-vpn + cidr_ip: + - 172.16.1.0/24 + - 172.16.17.0/24 + cidr_ipv6: + - 2607:F8B0::/32 + - 64:ff9b::/96 + group_id: + - sg-edcd9784 + diff: True + +- name: "Delete group by its id" + ec2_group: + region: eu-west-1 + group_id: sg-33b4ee5b + state: absent +''' + +RETURN = ''' +group_name: + description: Security group name + sample: My Security Group + type: str + returned: on create/update +group_id: + description: Security group id + sample: sg-abcd1234 + type: str + returned: on create/update +description: + description: Description of security group + sample: My Security Group + type: str + returned: on create/update +tags: + description: Tags associated with the security group + sample: + Name: My Security Group + Purpose: protecting stuff + type: dict + returned: on create/update +vpc_id: + description: ID of VPC to which the security group belongs + sample: vpc-abcd1234 + type: str + returned: on create/update +ip_permissions: + description: Inbound rules associated with the security group. + sample: + - from_port: 8182 + ip_protocol: tcp + ip_ranges: + - cidr_ip: "1.1.1.1/32" + ipv6_ranges: [] + prefix_list_ids: [] + to_port: 8182 + user_id_group_pairs: [] + type: list + returned: on create/update +ip_permissions_egress: + description: Outbound rules associated with the security group. + sample: + - ip_protocol: -1 + ip_ranges: + - cidr_ip: "0.0.0.0/0" + ipv6_ranges: [] + prefix_list_ids: [] + user_id_group_pairs: [] + type: list + returned: on create/update +owner_id: + description: AWS Account ID of the security group + sample: 123456789012 + type: int + returned: on create/update +''' + +import json +import re +import itertools +from copy import deepcopy +from time import sleep +from collections import namedtuple +from ansible.module_utils.aws.core import AnsibleAWSModule, is_boto3_error_code +from ansible.module_utils.aws.iam import get_aws_account_id +from ansible.module_utils.aws.waiters import get_waiter +from ansible.module_utils.ec2 import AWSRetry, camel_dict_to_snake_dict, compare_aws_tags +from ansible.module_utils.ec2 import ansible_dict_to_boto3_filter_list, boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list +from ansible.module_utils.common.network import to_ipv6_subnet, to_subnet +from ansible.module_utils.compat.ipaddress import ip_network, IPv6Network +from ansible.module_utils._text import to_text +from ansible.module_utils.six import string_types + +try: + from botocore.exceptions import BotoCoreError, ClientError +except ImportError: + pass # caught by AnsibleAWSModule + + +Rule = namedtuple('Rule', ['port_range', 'protocol', 'target', 'target_type', 'description']) +valid_targets = set(['ipv4', 'ipv6', 'group', 'ip_prefix']) +current_account_id = None + + +def rule_cmp(a, b): + """Compare rules without descriptions""" + for prop in ['port_range', 'protocol', 'target', 'target_type']: + if prop == 'port_range' and to_text(a.protocol) == to_text(b.protocol): + # equal protocols can interchange `(-1, -1)` and `(None, None)` + if a.port_range in ((None, None), (-1, -1)) and b.port_range in ((None, None), (-1, -1)): + continue + elif getattr(a, prop) != getattr(b, prop): + return False + elif getattr(a, prop) != getattr(b, prop): + return False + return True + + +def rules_to_permissions(rules): + return [to_permission(rule) for rule in rules] + + +def to_permission(rule): + # take a Rule, output the serialized grant + perm = { + 'IpProtocol': rule.protocol, + } + perm['FromPort'], perm['ToPort'] = rule.port_range + if rule.target_type == 'ipv4': + perm['IpRanges'] = [{ + 'CidrIp': rule.target, + }] + if rule.description: + perm['IpRanges'][0]['Description'] = rule.description + elif rule.target_type == 'ipv6': + perm['Ipv6Ranges'] = [{ + 'CidrIpv6': rule.target, + }] + if rule.description: + perm['Ipv6Ranges'][0]['Description'] = rule.description + elif rule.target_type == 'group': + if isinstance(rule.target, tuple): + pair = {} + if rule.target[0]: + pair['UserId'] = rule.target[0] + # group_id/group_name are mutually exclusive - give group_id more precedence as it is more specific + if rule.target[1]: + pair['GroupId'] = rule.target[1] + elif rule.target[2]: + pair['GroupName'] = rule.target[2] + perm['UserIdGroupPairs'] = [pair] + else: + perm['UserIdGroupPairs'] = [{ + 'GroupId': rule.target + }] + if rule.description: + perm['UserIdGroupPairs'][0]['Description'] = rule.description + elif rule.target_type == 'ip_prefix': + perm['PrefixListIds'] = [{ + 'PrefixListId': rule.target, + }] + if rule.description: + perm['PrefixListIds'][0]['Description'] = rule.description + elif rule.target_type not in valid_targets: + raise ValueError('Invalid target type for rule {0}'.format(rule)) + return fix_port_and_protocol(perm) + + +def rule_from_group_permission(perm): + def ports_from_permission(p): + if 'FromPort' not in p and 'ToPort' not in p: + return (None, None) + return (int(perm['FromPort']), int(perm['ToPort'])) + + # outputs a rule tuple + for target_key, target_subkey, target_type in [ + ('IpRanges', 'CidrIp', 'ipv4'), + ('Ipv6Ranges', 'CidrIpv6', 'ipv6'), + ('PrefixListIds', 'PrefixListId', 'ip_prefix'), + ]: + if target_key not in perm: + continue + for r in perm[target_key]: + # there may be several IP ranges here, which is ok + yield Rule( + ports_from_permission(perm), + to_text(perm['IpProtocol']), + r[target_subkey], + target_type, + r.get('Description') + ) + if 'UserIdGroupPairs' in perm and perm['UserIdGroupPairs']: + for pair in perm['UserIdGroupPairs']: + target = ( + pair.get('UserId', None), + pair.get('GroupId', None), + pair.get('GroupName', None), + ) + if pair.get('UserId', '').startswith('amazon-'): + # amazon-elb and amazon-prefix rules don't need + # group-id specified, so remove it when querying + # from permission + target = ( + target[0], + None, + target[2], + ) + elif 'VpcPeeringConnectionId' in pair or pair['UserId'] != current_account_id: + target = ( + pair.get('UserId', None), + pair.get('GroupId', None), + pair.get('GroupName', None), + ) + + yield Rule( + ports_from_permission(perm), + to_text(perm['IpProtocol']), + target, + 'group', + pair.get('Description') + ) + + +@AWSRetry.backoff(tries=5, delay=5, backoff=2.0, catch_extra_error_codes=['InvalidGroup.NotFound']) +def get_security_groups_with_backoff(connection, **kwargs): + return connection.describe_security_groups(**kwargs) + + +@AWSRetry.backoff(tries=5, delay=5, backoff=2.0) +def sg_exists_with_backoff(connection, **kwargs): + try: + return connection.describe_security_groups(**kwargs) + except is_boto3_error_code('InvalidGroup.NotFound'): + return {'SecurityGroups': []} + + +def deduplicate_rules_args(rules): + """Returns unique rules""" + if rules is None: + return None + return list(dict(zip((json.dumps(r, sort_keys=True) for r in rules), rules)).values()) + + +def validate_rule(module, rule): + VALID_PARAMS = ('cidr_ip', 'cidr_ipv6', 'ip_prefix', + 'group_id', 'group_name', 'group_desc', + 'proto', 'from_port', 'to_port', 'rule_desc') + if not isinstance(rule, dict): + module.fail_json(msg='Invalid rule parameter type [%s].' % type(rule)) + for k in rule: + if k not in VALID_PARAMS: + module.fail_json(msg='Invalid rule parameter \'{0}\' for rule: {1}'.format(k, rule)) + + if 'group_id' in rule and 'cidr_ip' in rule: + module.fail_json(msg='Specify group_id OR cidr_ip, not both') + elif 'group_name' in rule and 'cidr_ip' in rule: + module.fail_json(msg='Specify group_name OR cidr_ip, not both') + elif 'group_id' in rule and 'cidr_ipv6' in rule: + module.fail_json(msg="Specify group_id OR cidr_ipv6, not both") + elif 'group_name' in rule and 'cidr_ipv6' in rule: + module.fail_json(msg="Specify group_name OR cidr_ipv6, not both") + elif 'cidr_ip' in rule and 'cidr_ipv6' in rule: + module.fail_json(msg="Specify cidr_ip OR cidr_ipv6, not both") + elif 'group_id' in rule and 'group_name' in rule: + module.fail_json(msg='Specify group_id OR group_name, not both') + + +def get_target_from_rule(module, client, rule, name, group, groups, vpc_id): + """ + Returns tuple of (target_type, target, group_created) after validating rule params. + + rule: Dict describing a rule. + name: Name of the security group being managed. + groups: Dict of all available security groups. + + AWS accepts an ip range or a security group as target of a rule. This + function validate the rule specification and return either a non-None + group_id or a non-None ip range. + """ + FOREIGN_SECURITY_GROUP_REGEX = r'^([^/]+)/?(sg-\S+)?/(\S+)' + group_id = None + group_name = None + target_group_created = False + + validate_rule(module, rule) + if rule.get('group_id') and re.match(FOREIGN_SECURITY_GROUP_REGEX, rule['group_id']): + # this is a foreign Security Group. Since you can't fetch it you must create an instance of it + owner_id, group_id, group_name = re.match(FOREIGN_SECURITY_GROUP_REGEX, rule['group_id']).groups() + group_instance = dict(UserId=owner_id, GroupId=group_id, GroupName=group_name) + groups[group_id] = group_instance + groups[group_name] = group_instance + # group_id/group_name are mutually exclusive - give group_id more precedence as it is more specific + if group_id and group_name: + group_name = None + return 'group', (owner_id, group_id, group_name), False + elif 'group_id' in rule: + return 'group', rule['group_id'], False + elif 'group_name' in rule: + group_name = rule['group_name'] + if group_name == name: + group_id = group['GroupId'] + groups[group_id] = group + groups[group_name] = group + elif group_name in groups and group.get('VpcId') and groups[group_name].get('VpcId'): + # both are VPC groups, this is ok + group_id = groups[group_name]['GroupId'] + elif group_name in groups and not (group.get('VpcId') or groups[group_name].get('VpcId')): + # both are EC2 classic, this is ok + group_id = groups[group_name]['GroupId'] + else: + auto_group = None + filters = {'group-name': group_name} + if vpc_id: + filters['vpc-id'] = vpc_id + # if we got here, either the target group does not exist, or there + # is a mix of EC2 classic + VPC groups. Mixing of EC2 classic + VPC + # is bad, so we have to create a new SG because no compatible group + # exists + if not rule.get('group_desc', '').strip(): + # retry describing the group once + try: + auto_group = get_security_groups_with_backoff(client, Filters=ansible_dict_to_boto3_filter_list(filters)).get('SecurityGroups', [])[0] + except (is_boto3_error_code('InvalidGroup.NotFound'), IndexError): + module.fail_json(msg="group %s will be automatically created by rule %s but " + "no description was provided" % (group_name, rule)) + except ClientError as e: # pylint: disable=duplicate-except + module.fail_json_aws(e) + elif not module.check_mode: + params = dict(GroupName=group_name, Description=rule['group_desc']) + if vpc_id: + params['VpcId'] = vpc_id + try: + auto_group = client.create_security_group(**params) + get_waiter( + client, 'security_group_exists', + ).wait( + GroupIds=[auto_group['GroupId']], + ) + except is_boto3_error_code('InvalidGroup.Duplicate'): + # The group exists, but didn't show up in any of our describe-security-groups calls + # Try searching on a filter for the name, and allow a retry window for AWS to update + # the model on their end. + try: + auto_group = get_security_groups_with_backoff(client, Filters=ansible_dict_to_boto3_filter_list(filters)).get('SecurityGroups', [])[0] + except IndexError as e: + module.fail_json(msg="Could not create or use existing group '{0}' in rule. Make sure the group exists".format(group_name)) + except ClientError as e: + module.fail_json_aws( + e, + msg="Could not create or use existing group '{0}' in rule. Make sure the group exists".format(group_name)) + if auto_group is not None: + group_id = auto_group['GroupId'] + groups[group_id] = auto_group + groups[group_name] = auto_group + target_group_created = True + return 'group', group_id, target_group_created + elif 'cidr_ip' in rule: + return 'ipv4', validate_ip(module, rule['cidr_ip']), False + elif 'cidr_ipv6' in rule: + return 'ipv6', validate_ip(module, rule['cidr_ipv6']), False + elif 'ip_prefix' in rule: + return 'ip_prefix', rule['ip_prefix'], False + + module.fail_json(msg="Could not match target for rule {0}".format(rule), failed_rule=rule) + + +def ports_expand(ports): + # takes a list of ports and returns a list of (port_from, port_to) + ports_expanded = [] + for port in ports: + if not isinstance(port, string_types): + ports_expanded.append((port,) * 2) + elif '-' in port: + ports_expanded.append(tuple(int(p.strip()) for p in port.split('-', 1))) + else: + ports_expanded.append((int(port.strip()),) * 2) + + return ports_expanded + + +def rule_expand_ports(rule): + # takes a rule dict and returns a list of expanded rule dicts + if 'ports' not in rule: + if isinstance(rule.get('from_port'), string_types): + rule['from_port'] = int(rule.get('from_port')) + if isinstance(rule.get('to_port'), string_types): + rule['to_port'] = int(rule.get('to_port')) + return [rule] + + ports = rule['ports'] if isinstance(rule['ports'], list) else [rule['ports']] + + rule_expanded = [] + for from_to in ports_expand(ports): + temp_rule = rule.copy() + del temp_rule['ports'] + temp_rule['from_port'], temp_rule['to_port'] = sorted(from_to) + rule_expanded.append(temp_rule) + + return rule_expanded + + +def rules_expand_ports(rules): + # takes a list of rules and expands it based on 'ports' + if not rules: + return rules + + return [rule for rule_complex in rules + for rule in rule_expand_ports(rule_complex)] + + +def rule_expand_source(rule, source_type): + # takes a rule dict and returns a list of expanded rule dicts for specified source_type + sources = rule[source_type] if isinstance(rule[source_type], list) else [rule[source_type]] + source_types_all = ('cidr_ip', 'cidr_ipv6', 'group_id', 'group_name', 'ip_prefix') + + rule_expanded = [] + for source in sources: + temp_rule = rule.copy() + for s in source_types_all: + temp_rule.pop(s, None) + temp_rule[source_type] = source + rule_expanded.append(temp_rule) + + return rule_expanded + + +def rule_expand_sources(rule): + # takes a rule dict and returns a list of expanded rule discts + source_types = (stype for stype in ('cidr_ip', 'cidr_ipv6', 'group_id', 'group_name', 'ip_prefix') if stype in rule) + + return [r for stype in source_types + for r in rule_expand_source(rule, stype)] + + +def rules_expand_sources(rules): + # takes a list of rules and expands it based on 'cidr_ip', 'group_id', 'group_name' + if not rules: + return rules + + return [rule for rule_complex in rules + for rule in rule_expand_sources(rule_complex)] + + +def update_rules_description(module, client, rule_type, group_id, ip_permissions): + if module.check_mode: + return + try: + if rule_type == "in": + client.update_security_group_rule_descriptions_ingress(GroupId=group_id, IpPermissions=ip_permissions) + if rule_type == "out": + client.update_security_group_rule_descriptions_egress(GroupId=group_id, IpPermissions=ip_permissions) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to update rule description for group %s" % group_id) + + +def fix_port_and_protocol(permission): + for key in ('FromPort', 'ToPort'): + if key in permission: + if permission[key] is None: + del permission[key] + else: + permission[key] = int(permission[key]) + + permission['IpProtocol'] = to_text(permission['IpProtocol']) + + return permission + + +def remove_old_permissions(client, module, revoke_ingress, revoke_egress, group_id): + if revoke_ingress: + revoke(client, module, revoke_ingress, group_id, 'in') + if revoke_egress: + revoke(client, module, revoke_egress, group_id, 'out') + return bool(revoke_ingress or revoke_egress) + + +def revoke(client, module, ip_permissions, group_id, rule_type): + if not module.check_mode: + try: + if rule_type == 'in': + client.revoke_security_group_ingress(GroupId=group_id, IpPermissions=ip_permissions) + elif rule_type == 'out': + client.revoke_security_group_egress(GroupId=group_id, IpPermissions=ip_permissions) + except (BotoCoreError, ClientError) as e: + rules = 'ingress rules' if rule_type == 'in' else 'egress rules' + module.fail_json_aws(e, "Unable to revoke {0}: {1}".format(rules, ip_permissions)) + + +def add_new_permissions(client, module, new_ingress, new_egress, group_id): + if new_ingress: + authorize(client, module, new_ingress, group_id, 'in') + if new_egress: + authorize(client, module, new_egress, group_id, 'out') + return bool(new_ingress or new_egress) + + +def authorize(client, module, ip_permissions, group_id, rule_type): + if not module.check_mode: + try: + if rule_type == 'in': + client.authorize_security_group_ingress(GroupId=group_id, IpPermissions=ip_permissions) + elif rule_type == 'out': + client.authorize_security_group_egress(GroupId=group_id, IpPermissions=ip_permissions) + except (BotoCoreError, ClientError) as e: + rules = 'ingress rules' if rule_type == 'in' else 'egress rules' + module.fail_json_aws(e, "Unable to authorize {0}: {1}".format(rules, ip_permissions)) + + +def validate_ip(module, cidr_ip): + split_addr = cidr_ip.split('/') + if len(split_addr) == 2: + # this_ip is a IPv4 or IPv6 CIDR that may or may not have host bits set + # Get the network bits if IPv4, and validate if IPv6. + try: + ip = to_subnet(split_addr[0], split_addr[1]) + if ip != cidr_ip: + module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, " + "check the network mask and make sure that only network bits are set: {1}.".format( + cidr_ip, ip)) + except ValueError: + # to_subnet throws a ValueError on IPv6 networks, so we should be working with v6 if we get here + try: + isinstance(ip_network(to_text(cidr_ip)), IPv6Network) + ip = cidr_ip + except ValueError: + # If a host bit is set on something other than a /128, IPv6Network will throw a ValueError + # The ipv6_cidr in this case probably looks like "2001:DB8:A0B:12F0::1/64" and we just want the network bits + ip6 = to_ipv6_subnet(split_addr[0]) + "/" + split_addr[1] + if ip6 != cidr_ip: + module.warn("One of your IPv6 CIDR addresses ({0}) has host bits set. To get rid of this warning, " + "check the network mask and make sure that only network bits are set: {1}.".format(cidr_ip, ip6)) + return ip6 + return ip + return cidr_ip + + +def update_tags(client, module, group_id, current_tags, tags, purge_tags): + tags_need_modify, tags_to_delete = compare_aws_tags(current_tags, tags, purge_tags) + + if not module.check_mode: + if tags_to_delete: + try: + client.delete_tags(Resources=[group_id], Tags=[{'Key': tag} for tag in tags_to_delete]) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to delete tags {0}".format(tags_to_delete)) + + # Add/update tags + if tags_need_modify: + try: + client.create_tags(Resources=[group_id], Tags=ansible_dict_to_boto3_tag_list(tags_need_modify)) + except (BotoCoreError, ClientError) as e: + module.fail_json(e, msg="Unable to add tags {0}".format(tags_need_modify)) + + return bool(tags_need_modify or tags_to_delete) + + +def update_rule_descriptions(module, group_id, present_ingress, named_tuple_ingress_list, present_egress, named_tuple_egress_list): + changed = False + client = module.client('ec2') + ingress_needs_desc_update = [] + egress_needs_desc_update = [] + + for present_rule in present_egress: + needs_update = [r for r in named_tuple_egress_list if rule_cmp(r, present_rule) and r.description != present_rule.description] + for r in needs_update: + named_tuple_egress_list.remove(r) + egress_needs_desc_update.extend(needs_update) + for present_rule in present_ingress: + needs_update = [r for r in named_tuple_ingress_list if rule_cmp(r, present_rule) and r.description != present_rule.description] + for r in needs_update: + named_tuple_ingress_list.remove(r) + ingress_needs_desc_update.extend(needs_update) + + if ingress_needs_desc_update: + update_rules_description(module, client, 'in', group_id, rules_to_permissions(ingress_needs_desc_update)) + changed |= True + if egress_needs_desc_update: + update_rules_description(module, client, 'out', group_id, rules_to_permissions(egress_needs_desc_update)) + changed |= True + return changed + + +def create_security_group(client, module, name, description, vpc_id): + if not module.check_mode: + params = dict(GroupName=name, Description=description) + if vpc_id: + params['VpcId'] = vpc_id + try: + group = client.create_security_group(**params) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to create security group") + # When a group is created, an egress_rule ALLOW ALL + # to 0.0.0.0/0 is added automatically but it's not + # reflected in the object returned by the AWS API + # call. We re-read the group for getting an updated object + # amazon sometimes takes a couple seconds to update the security group so wait till it exists + while True: + sleep(3) + group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0] + if group.get('VpcId') and not group.get('IpPermissionsEgress'): + pass + else: + break + return group + return None + + +def wait_for_rule_propagation(module, group, desired_ingress, desired_egress, purge_ingress, purge_egress): + group_id = group['GroupId'] + tries = 6 + + def await_rules(group, desired_rules, purge, rule_key): + for i in range(tries): + current_rules = set(sum([list(rule_from_group_permission(p)) for p in group[rule_key]], [])) + if purge and len(current_rules ^ set(desired_rules)) == 0: + return group + elif purge: + conflicts = current_rules ^ set(desired_rules) + # For cases where set comparison is equivalent, but invalid port/proto exist + for a, b in itertools.combinations(conflicts, 2): + if rule_cmp(a, b): + conflicts.discard(a) + conflicts.discard(b) + if not len(conflicts): + return group + elif current_rules.issuperset(desired_rules) and not purge: + return group + sleep(10) + group = get_security_groups_with_backoff(module.client('ec2'), GroupIds=[group_id])['SecurityGroups'][0] + module.warn("Ran out of time waiting for {0} {1}. Current: {2}, Desired: {3}".format(group_id, rule_key, current_rules, desired_rules)) + return group + + group = get_security_groups_with_backoff(module.client('ec2'), GroupIds=[group_id])['SecurityGroups'][0] + if 'VpcId' in group and module.params.get('rules_egress') is not None: + group = await_rules(group, desired_egress, purge_egress, 'IpPermissionsEgress') + return await_rules(group, desired_ingress, purge_ingress, 'IpPermissions') + + +def group_exists(client, module, vpc_id, group_id, name): + params = {'Filters': []} + if group_id: + params['GroupIds'] = [group_id] + if name: + # Add name to filters rather than params['GroupNames'] + # because params['GroupNames'] only checks the default vpc if no vpc is provided + params['Filters'].append({'Name': 'group-name', 'Values': [name]}) + if vpc_id: + params['Filters'].append({'Name': 'vpc-id', 'Values': [vpc_id]}) + # Don't filter by description to maintain backwards compatibility + + try: + security_groups = sg_exists_with_backoff(client, **params).get('SecurityGroups', []) + all_groups = get_security_groups_with_backoff(client).get('SecurityGroups', []) + except (BotoCoreError, ClientError) as e: # pylint: disable=duplicate-except + module.fail_json_aws(e, msg="Error in describe_security_groups") + + if security_groups: + groups = dict((group['GroupId'], group) for group in all_groups) + groups.update(dict((group['GroupName'], group) for group in all_groups)) + if vpc_id: + vpc_wins = dict((group['GroupName'], group) for group in all_groups if group.get('VpcId') and group['VpcId'] == vpc_id) + groups.update(vpc_wins) + # maintain backwards compatibility by using the last matching group + return security_groups[-1], groups + return None, {} + + +def verify_rules_with_descriptions_permitted(client, module, rules, rules_egress): + if not hasattr(client, "update_security_group_rule_descriptions_egress"): + all_rules = rules if rules else [] + rules_egress if rules_egress else [] + if any('rule_desc' in rule for rule in all_rules): + module.fail_json(msg="Using rule descriptions requires botocore version >= 1.7.2.") + + +def get_diff_final_resource(client, module, security_group): + def get_account_id(security_group, module): + try: + owner_id = security_group.get('owner_id', module.client('sts').get_caller_identity()['Account']) + except (BotoCoreError, ClientError) as e: + owner_id = "Unable to determine owner_id: {0}".format(to_text(e)) + return owner_id + + def get_final_tags(security_group_tags, specified_tags, purge_tags): + if specified_tags is None: + return security_group_tags + tags_need_modify, tags_to_delete = compare_aws_tags(security_group_tags, specified_tags, purge_tags) + end_result_tags = dict((k, v) for k, v in specified_tags.items() if k not in tags_to_delete) + end_result_tags.update(dict((k, v) for k, v in security_group_tags.items() if k not in tags_to_delete)) + end_result_tags.update(tags_need_modify) + return end_result_tags + + def get_final_rules(client, module, security_group_rules, specified_rules, purge_rules): + if specified_rules is None: + return security_group_rules + if purge_rules: + final_rules = [] + else: + final_rules = list(security_group_rules) + specified_rules = flatten_nested_targets(module, deepcopy(specified_rules)) + for rule in specified_rules: + format_rule = { + 'from_port': None, 'to_port': None, 'ip_protocol': rule.get('proto', 'tcp'), + 'ip_ranges': [], 'ipv6_ranges': [], 'prefix_list_ids': [], 'user_id_group_pairs': [] + } + if rule.get('proto', 'tcp') in ('all', '-1', -1): + format_rule['ip_protocol'] = '-1' + format_rule.pop('from_port') + format_rule.pop('to_port') + elif rule.get('ports'): + if rule.get('ports') and (isinstance(rule['ports'], string_types) or isinstance(rule['ports'], int)): + rule['ports'] = [rule['ports']] + for port in rule.get('ports'): + if isinstance(port, string_types) and '-' in port: + format_rule['from_port'], format_rule['to_port'] = port.split('-') + else: + format_rule['from_port'] = format_rule['to_port'] = port + elif rule.get('from_port') or rule.get('to_port'): + format_rule['from_port'] = rule.get('from_port', rule.get('to_port')) + format_rule['to_port'] = rule.get('to_port', rule.get('from_port')) + for source_type in ('cidr_ip', 'cidr_ipv6', 'prefix_list_id'): + if rule.get(source_type): + rule_key = {'cidr_ip': 'ip_ranges', 'cidr_ipv6': 'ipv6_ranges', 'prefix_list_id': 'prefix_list_ids'}.get(source_type) + if rule.get('rule_desc'): + format_rule[rule_key] = [{source_type: rule[source_type], 'description': rule['rule_desc']}] + else: + if not isinstance(rule[source_type], list): + rule[source_type] = [rule[source_type]] + format_rule[rule_key] = [{source_type: target} for target in rule[source_type]] + if rule.get('group_id') or rule.get('group_name'): + rule_sg = camel_dict_to_snake_dict(group_exists(client, module, module.params['vpc_id'], rule.get('group_id'), rule.get('group_name'))[0]) + format_rule['user_id_group_pairs'] = [{ + 'description': rule_sg.get('description', rule_sg.get('group_desc')), + 'group_id': rule_sg.get('group_id', rule.get('group_id')), + 'group_name': rule_sg.get('group_name', rule.get('group_name')), + 'peering_status': rule_sg.get('peering_status'), + 'user_id': rule_sg.get('user_id', get_account_id(security_group, module)), + 'vpc_id': rule_sg.get('vpc_id', module.params['vpc_id']), + 'vpc_peering_connection_id': rule_sg.get('vpc_peering_connection_id') + }] + for k, v in list(format_rule['user_id_group_pairs'][0].items()): + if v is None: + format_rule['user_id_group_pairs'][0].pop(k) + final_rules.append(format_rule) + # Order final rules consistently + final_rules.sort(key=get_ip_permissions_sort_key) + return final_rules + security_group_ingress = security_group.get('ip_permissions', []) + specified_ingress = module.params['rules'] + purge_ingress = module.params['purge_rules'] + security_group_egress = security_group.get('ip_permissions_egress', []) + specified_egress = module.params['rules_egress'] + purge_egress = module.params['purge_rules_egress'] + return { + 'description': module.params['description'], + 'group_id': security_group.get('group_id', 'sg-xxxxxxxx'), + 'group_name': security_group.get('group_name', module.params['name']), + 'ip_permissions': get_final_rules(client, module, security_group_ingress, specified_ingress, purge_ingress), + 'ip_permissions_egress': get_final_rules(client, module, security_group_egress, specified_egress, purge_egress), + 'owner_id': get_account_id(security_group, module), + 'tags': get_final_tags(security_group.get('tags', {}), module.params['tags'], module.params['purge_tags']), + 'vpc_id': security_group.get('vpc_id', module.params['vpc_id'])} + + +def flatten_nested_targets(module, rules): + def _flatten(targets): + for target in targets: + if isinstance(target, list): + for t in _flatten(target): + yield t + elif isinstance(target, string_types): + yield target + + if rules is not None: + for rule in rules: + target_list_type = None + if isinstance(rule.get('cidr_ip'), list): + target_list_type = 'cidr_ip' + elif isinstance(rule.get('cidr_ipv6'), list): + target_list_type = 'cidr_ipv6' + if target_list_type is not None: + rule[target_list_type] = list(_flatten(rule[target_list_type])) + return rules + + +def get_rule_sort_key(dicts): + if dicts.get('cidr_ip'): + return dicts.get('cidr_ip') + elif dicts.get('cidr_ipv6'): + return dicts.get('cidr_ipv6') + elif dicts.get('prefix_list_id'): + return dicts.get('prefix_list_id') + elif dicts.get('group_id'): + return dicts.get('group_id') + return None + + +def get_ip_permissions_sort_key(rule): + if rule.get('ip_ranges'): + rule.get('ip_ranges').sort(key=get_rule_sort_key) + return rule.get('ip_ranges')[0]['cidr_ip'] + elif rule.get('ipv6_ranges'): + rule.get('ipv6_ranges').sort(key=get_rule_sort_key) + return rule.get('ipv6_ranges')[0]['cidr_ipv6'] + elif rule.get('prefix_list_ids'): + rule.get('prefix_list_ids').sort(key=get_rule_sort_key) + return rule.get('prefix_list_ids')[0]['prefix_list_id'] + elif rule.get('user_id_group_pairs'): + rule.get('user_id_group_pairs').sort(key=get_rule_sort_key) + return rule.get('user_id_group_pairs')[0]['group_id'] + return None + + +def main(): + argument_spec = dict( + name=dict(), + group_id=dict(), + description=dict(), + vpc_id=dict(), + rules=dict(type='list'), + rules_egress=dict(type='list'), + state=dict(default='present', type='str', choices=['present', 'absent']), + purge_rules=dict(default=True, required=False, type='bool'), + purge_rules_egress=dict(default=True, required=False, type='bool'), + tags=dict(required=False, type='dict', aliases=['resource_tags']), + purge_tags=dict(default=True, required=False, type='bool') + ) + module = AnsibleAWSModule( + argument_spec=argument_spec, + supports_check_mode=True, + required_one_of=[['name', 'group_id']], + required_if=[['state', 'present', ['name']]], + ) + + name = module.params['name'] + group_id = module.params['group_id'] + description = module.params['description'] + vpc_id = module.params['vpc_id'] + rules = flatten_nested_targets(module, deepcopy(module.params['rules'])) + rules_egress = flatten_nested_targets(module, deepcopy(module.params['rules_egress'])) + rules = deduplicate_rules_args(rules_expand_sources(rules_expand_ports(rules))) + rules_egress = deduplicate_rules_args(rules_expand_sources(rules_expand_ports(rules_egress))) + state = module.params.get('state') + purge_rules = module.params['purge_rules'] + purge_rules_egress = module.params['purge_rules_egress'] + tags = module.params['tags'] + purge_tags = module.params['purge_tags'] + + if state == 'present' and not description: + module.fail_json(msg='Must provide description when state is present.') + + changed = False + client = module.client('ec2') + + verify_rules_with_descriptions_permitted(client, module, rules, rules_egress) + group, groups = group_exists(client, module, vpc_id, group_id, name) + group_created_new = not bool(group) + + global current_account_id + current_account_id = get_aws_account_id(module) + + before = {} + after = {} + + # Ensure requested group is absent + if state == 'absent': + if group: + # found a match, delete it + before = camel_dict_to_snake_dict(group, ignore_list=['Tags']) + before['tags'] = boto3_tag_list_to_ansible_dict(before.get('tags', [])) + try: + if not module.check_mode: + client.delete_security_group(GroupId=group['GroupId']) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to delete security group '%s'" % group) + else: + group = None + changed = True + else: + # no match found, no changes required + pass + + # Ensure requested group is present + elif state == 'present': + if group: + # existing group + before = camel_dict_to_snake_dict(group, ignore_list=['Tags']) + before['tags'] = boto3_tag_list_to_ansible_dict(before.get('tags', [])) + if group['Description'] != description: + module.warn("Group description does not match existing group. Descriptions cannot be changed without deleting " + "and re-creating the security group. Try using state=absent to delete, then rerunning this task.") + else: + # no match found, create it + group = create_security_group(client, module, name, description, vpc_id) + changed = True + + if tags is not None and group is not None: + current_tags = boto3_tag_list_to_ansible_dict(group.get('Tags', [])) + changed |= update_tags(client, module, group['GroupId'], current_tags, tags, purge_tags) + + if group: + named_tuple_ingress_list = [] + named_tuple_egress_list = [] + current_ingress = sum([list(rule_from_group_permission(p)) for p in group['IpPermissions']], []) + current_egress = sum([list(rule_from_group_permission(p)) for p in group['IpPermissionsEgress']], []) + + for new_rules, rule_type, named_tuple_rule_list in [(rules, 'in', named_tuple_ingress_list), + (rules_egress, 'out', named_tuple_egress_list)]: + if new_rules is None: + continue + for rule in new_rules: + target_type, target, target_group_created = get_target_from_rule( + module, client, rule, name, group, groups, vpc_id) + changed |= target_group_created + + if rule.get('proto', 'tcp') in ('all', '-1', -1): + rule['proto'] = '-1' + rule['from_port'] = None + rule['to_port'] = None + try: + int(rule.get('proto', 'tcp')) + rule['proto'] = to_text(rule.get('proto', 'tcp')) + rule['from_port'] = None + rule['to_port'] = None + except ValueError: + # rule does not use numeric protocol spec + pass + + named_tuple_rule_list.append( + Rule( + port_range=(rule['from_port'], rule['to_port']), + protocol=to_text(rule.get('proto', 'tcp')), + target=target, target_type=target_type, + description=rule.get('rule_desc'), + ) + ) + + # List comprehensions for rules to add, rules to modify, and rule ids to determine purging + new_ingress_permissions = [to_permission(r) for r in (set(named_tuple_ingress_list) - set(current_ingress))] + new_egress_permissions = [to_permission(r) for r in (set(named_tuple_egress_list) - set(current_egress))] + + if module.params.get('rules_egress') is None and 'VpcId' in group: + # when no egress rules are specified and we're in a VPC, + # we add in a default allow all out rule, which was the + # default behavior before egress rules were added + rule = Rule((None, None), '-1', '0.0.0.0/0', 'ipv4', None) + if rule in current_egress: + named_tuple_egress_list.append(rule) + if rule not in current_egress: + current_egress.append(rule) + + # List comprehensions for rules to add, rules to modify, and rule ids to determine purging + present_ingress = list(set(named_tuple_ingress_list).union(set(current_ingress))) + present_egress = list(set(named_tuple_egress_list).union(set(current_egress))) + + if purge_rules: + revoke_ingress = [] + for p in present_ingress: + if not any([rule_cmp(p, b) for b in named_tuple_ingress_list]): + revoke_ingress.append(to_permission(p)) + else: + revoke_ingress = [] + if purge_rules_egress and module.params.get('rules_egress') is not None: + if module.params.get('rules_egress') is []: + revoke_egress = [ + to_permission(r) for r in set(present_egress) - set(named_tuple_egress_list) + if r != Rule((None, None), '-1', '0.0.0.0/0', 'ipv4', None) + ] + else: + revoke_egress = [] + for p in present_egress: + if not any([rule_cmp(p, b) for b in named_tuple_egress_list]): + revoke_egress.append(to_permission(p)) + else: + revoke_egress = [] + + # named_tuple_ingress_list and named_tuple_egress_list got updated by + # method update_rule_descriptions, deep copy these two lists to new + # variables for the record of the 'desired' ingress and egress sg permissions + desired_ingress = deepcopy(named_tuple_ingress_list) + desired_egress = deepcopy(named_tuple_egress_list) + + changed |= update_rule_descriptions(module, group['GroupId'], present_ingress, named_tuple_ingress_list, present_egress, named_tuple_egress_list) + + # Revoke old rules + changed |= remove_old_permissions(client, module, revoke_ingress, revoke_egress, group['GroupId']) + rule_msg = 'Revoking {0}, and egress {1}'.format(revoke_ingress, revoke_egress) + + new_ingress_permissions = [to_permission(r) for r in (set(named_tuple_ingress_list) - set(current_ingress))] + new_ingress_permissions = rules_to_permissions(set(named_tuple_ingress_list) - set(current_ingress)) + new_egress_permissions = rules_to_permissions(set(named_tuple_egress_list) - set(current_egress)) + # Authorize new rules + changed |= add_new_permissions(client, module, new_ingress_permissions, new_egress_permissions, group['GroupId']) + + if group_created_new and module.params.get('rules') is None and module.params.get('rules_egress') is None: + # A new group with no rules provided is already being awaited. + # When it is created we wait for the default egress rule to be added by AWS + security_group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0] + elif changed and not module.check_mode: + # keep pulling until current security group rules match the desired ingress and egress rules + security_group = wait_for_rule_propagation(module, group, desired_ingress, desired_egress, purge_rules, purge_rules_egress) + else: + security_group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0] + security_group = camel_dict_to_snake_dict(security_group, ignore_list=['Tags']) + security_group['tags'] = boto3_tag_list_to_ansible_dict(security_group.get('tags', [])) + + else: + security_group = {'group_id': None} + + if module._diff: + if module.params['state'] == 'present': + after = get_diff_final_resource(client, module, security_group) + if before.get('ip_permissions'): + before['ip_permissions'].sort(key=get_ip_permissions_sort_key) + + security_group['diff'] = [{'before': before, 'after': after}] + + module.exit_json(changed=changed, **security_group) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_vpc_net.py b/test/support/integration/plugins/modules/ec2_vpc_net.py new file mode 100644 index 00000000..30e4b1e9 --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_vpc_net.py @@ -0,0 +1,524 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: ec2_vpc_net +short_description: Configure AWS virtual private clouds +description: + - Create, modify, and terminate AWS virtual private clouds. +version_added: "2.0" +author: + - Jonathan Davila (@defionscode) + - Sloane Hertel (@s-hertel) +options: + name: + description: + - The name to give your VPC. This is used in combination with C(cidr_block) to determine if a VPC already exists. + required: yes + type: str + cidr_block: + description: + - The primary CIDR of the VPC. After 2.5 a list of CIDRs can be provided. The first in the list will be used as the primary CIDR + and is used in conjunction with the C(name) to ensure idempotence. + required: yes + type: list + elements: str + ipv6_cidr: + description: + - Request an Amazon-provided IPv6 CIDR block with /56 prefix length. You cannot specify the range of IPv6 addresses, + or the size of the CIDR block. + default: False + type: bool + version_added: '2.10' + purge_cidrs: + description: + - Remove CIDRs that are associated with the VPC and are not specified in C(cidr_block). + default: no + type: bool + version_added: '2.5' + tenancy: + description: + - Whether to be default or dedicated tenancy. This cannot be changed after the VPC has been created. + default: default + choices: [ 'default', 'dedicated' ] + type: str + dns_support: + description: + - Whether to enable AWS DNS support. + default: yes + type: bool + dns_hostnames: + description: + - Whether to enable AWS hostname support. + default: yes + type: bool + dhcp_opts_id: + description: + - The id of the DHCP options to use for this VPC. + type: str + tags: + description: + - The tags you want attached to the VPC. This is independent of the name value, note if you pass a 'Name' key it would override the Name of + the VPC if it's different. + aliases: [ 'resource_tags' ] + type: dict + state: + description: + - The state of the VPC. Either absent or present. + default: present + choices: [ 'present', 'absent' ] + type: str + multi_ok: + description: + - By default the module will not create another VPC if there is another VPC with the same name and CIDR block. Specify this as true if you want + duplicate VPCs created. + type: bool + default: false +requirements: + - boto3 + - botocore +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +- name: create a VPC with dedicated tenancy and a couple of tags + ec2_vpc_net: + name: Module_dev2 + cidr_block: 10.10.0.0/16 + region: us-east-1 + tags: + module: ec2_vpc_net + this: works + tenancy: dedicated + +- name: create a VPC with dedicated tenancy and request an IPv6 CIDR + ec2_vpc_net: + name: Module_dev2 + cidr_block: 10.10.0.0/16 + ipv6_cidr: True + region: us-east-1 + tenancy: dedicated +''' + +RETURN = ''' +vpc: + description: info about the VPC that was created or deleted + returned: always + type: complex + contains: + cidr_block: + description: The CIDR of the VPC + returned: always + type: str + sample: 10.0.0.0/16 + cidr_block_association_set: + description: IPv4 CIDR blocks associated with the VPC + returned: success + type: list + sample: + "cidr_block_association_set": [ + { + "association_id": "vpc-cidr-assoc-97aeeefd", + "cidr_block": "20.0.0.0/24", + "cidr_block_state": { + "state": "associated" + } + } + ] + classic_link_enabled: + description: indicates whether ClassicLink is enabled + returned: always + type: bool + sample: false + dhcp_options_id: + description: the id of the DHCP options associated with this VPC + returned: always + type: str + sample: dopt-0fb8bd6b + id: + description: VPC resource id + returned: always + type: str + sample: vpc-c2e00da5 + instance_tenancy: + description: indicates whether VPC uses default or dedicated tenancy + returned: always + type: str + sample: default + ipv6_cidr_block_association_set: + description: IPv6 CIDR blocks associated with the VPC + returned: success + type: list + sample: + "ipv6_cidr_block_association_set": [ + { + "association_id": "vpc-cidr-assoc-97aeeefd", + "ipv6_cidr_block": "2001:db8::/56", + "ipv6_cidr_block_state": { + "state": "associated" + } + } + ] + is_default: + description: indicates whether this is the default VPC + returned: always + type: bool + sample: false + state: + description: state of the VPC + returned: always + type: str + sample: available + tags: + description: tags attached to the VPC, includes name + returned: always + type: complex + contains: + Name: + description: name tag for the VPC + returned: always + type: str + sample: pk_vpc4 +''' + +try: + import botocore +except ImportError: + pass # Handled by AnsibleAWSModule + +from time import sleep, time +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import (AWSRetry, camel_dict_to_snake_dict, compare_aws_tags, + ansible_dict_to_boto3_tag_list, boto3_tag_list_to_ansible_dict) +from ansible.module_utils.six import string_types +from ansible.module_utils._text import to_native +from ansible.module_utils.network.common.utils import to_subnet + + +def vpc_exists(module, vpc, name, cidr_block, multi): + """Returns None or a vpc object depending on the existence of a VPC. When supplied + with a CIDR, it will check for matching tags to determine if it is a match + otherwise it will assume the VPC does not exist and thus return None. + """ + try: + matching_vpcs = vpc.describe_vpcs(Filters=[{'Name': 'tag:Name', 'Values': [name]}, {'Name': 'cidr-block', 'Values': cidr_block}])['Vpcs'] + # If an exact matching using a list of CIDRs isn't found, check for a match with the first CIDR as is documented for C(cidr_block) + if not matching_vpcs: + matching_vpcs = vpc.describe_vpcs(Filters=[{'Name': 'tag:Name', 'Values': [name]}, {'Name': 'cidr-block', 'Values': [cidr_block[0]]}])['Vpcs'] + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to describe VPCs") + + if multi: + return None + elif len(matching_vpcs) == 1: + return matching_vpcs[0]['VpcId'] + elif len(matching_vpcs) > 1: + module.fail_json(msg='Currently there are %d VPCs that have the same name and ' + 'CIDR block you specified. If you would like to create ' + 'the VPC anyway please pass True to the multi_ok param.' % len(matching_vpcs)) + return None + + +@AWSRetry.backoff(delay=3, tries=8, catch_extra_error_codes=['InvalidVpcID.NotFound']) +def get_classic_link_with_backoff(connection, vpc_id): + try: + return connection.describe_vpc_classic_link(VpcIds=[vpc_id])['Vpcs'][0].get('ClassicLinkEnabled') + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Message"] == "The functionality you requested is not available in this region.": + return False + else: + raise + + +def get_vpc(module, connection, vpc_id): + # wait for vpc to be available + try: + connection.get_waiter('vpc_available').wait(VpcIds=[vpc_id]) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to wait for VPC {0} to be available.".format(vpc_id)) + + try: + vpc_obj = connection.describe_vpcs(VpcIds=[vpc_id], aws_retry=True)['Vpcs'][0] + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to describe VPCs") + try: + vpc_obj['ClassicLinkEnabled'] = get_classic_link_with_backoff(connection, vpc_id) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to describe VPCs") + + return vpc_obj + + +def update_vpc_tags(connection, module, vpc_id, tags, name): + if tags is None: + tags = dict() + + tags.update({'Name': name}) + tags = dict((k, to_native(v)) for k, v in tags.items()) + try: + current_tags = dict((t['Key'], t['Value']) for t in connection.describe_tags(Filters=[{'Name': 'resource-id', 'Values': [vpc_id]}])['Tags']) + tags_to_update, dummy = compare_aws_tags(current_tags, tags, False) + if tags_to_update: + if not module.check_mode: + tags = ansible_dict_to_boto3_tag_list(tags_to_update) + vpc_obj = connection.create_tags(Resources=[vpc_id], Tags=tags, aws_retry=True) + + # Wait for tags to be updated + expected_tags = boto3_tag_list_to_ansible_dict(tags) + filters = [{'Name': 'tag:{0}'.format(key), 'Values': [value]} for key, value in expected_tags.items()] + connection.get_waiter('vpc_available').wait(VpcIds=[vpc_id], Filters=filters) + + return True + else: + return False + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to update tags") + + +def update_dhcp_opts(connection, module, vpc_obj, dhcp_id): + if vpc_obj['DhcpOptionsId'] != dhcp_id: + if not module.check_mode: + try: + connection.associate_dhcp_options(DhcpOptionsId=dhcp_id, VpcId=vpc_obj['VpcId']) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to associate DhcpOptionsId {0}".format(dhcp_id)) + + try: + # Wait for DhcpOptionsId to be updated + filters = [{'Name': 'dhcp-options-id', 'Values': [dhcp_id]}] + connection.get_waiter('vpc_available').wait(VpcIds=[vpc_obj['VpcId']], Filters=filters) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json(msg="Failed to wait for DhcpOptionsId to be updated") + + return True + else: + return False + + +def create_vpc(connection, module, cidr_block, tenancy): + try: + if not module.check_mode: + vpc_obj = connection.create_vpc(CidrBlock=cidr_block, InstanceTenancy=tenancy) + else: + module.exit_json(changed=True) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to create the VPC") + + # wait for vpc to exist + try: + connection.get_waiter('vpc_exists').wait(VpcIds=[vpc_obj['Vpc']['VpcId']]) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to wait for VPC {0} to be created.".format(vpc_obj['Vpc']['VpcId'])) + + return vpc_obj['Vpc']['VpcId'] + + +def wait_for_vpc_attribute(connection, module, vpc_id, attribute, expected_value): + start_time = time() + updated = False + while time() < start_time + 300: + current_value = connection.describe_vpc_attribute( + Attribute=attribute, + VpcId=vpc_id + )['{0}{1}'.format(attribute[0].upper(), attribute[1:])]['Value'] + if current_value != expected_value: + sleep(3) + else: + updated = True + break + if not updated: + module.fail_json(msg="Failed to wait for {0} to be updated".format(attribute)) + + +def get_cidr_network_bits(module, cidr_block): + fixed_cidrs = [] + for cidr in cidr_block: + split_addr = cidr.split('/') + if len(split_addr) == 2: + # this_ip is a IPv4 CIDR that may or may not have host bits set + # Get the network bits. + valid_cidr = to_subnet(split_addr[0], split_addr[1]) + if cidr != valid_cidr: + module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, " + "check the network mask and make sure that only network bits are set: {1}.".format(cidr, valid_cidr)) + fixed_cidrs.append(valid_cidr) + else: + # let AWS handle invalid CIDRs + fixed_cidrs.append(cidr) + return fixed_cidrs + + +def main(): + argument_spec = dict( + name=dict(required=True), + cidr_block=dict(type='list', required=True), + ipv6_cidr=dict(type='bool', default=False), + tenancy=dict(choices=['default', 'dedicated'], default='default'), + dns_support=dict(type='bool', default=True), + dns_hostnames=dict(type='bool', default=True), + dhcp_opts_id=dict(), + tags=dict(type='dict', aliases=['resource_tags']), + state=dict(choices=['present', 'absent'], default='present'), + multi_ok=dict(type='bool', default=False), + purge_cidrs=dict(type='bool', default=False), + ) + + module = AnsibleAWSModule( + argument_spec=argument_spec, + supports_check_mode=True + ) + + name = module.params.get('name') + cidr_block = get_cidr_network_bits(module, module.params.get('cidr_block')) + ipv6_cidr = module.params.get('ipv6_cidr') + purge_cidrs = module.params.get('purge_cidrs') + tenancy = module.params.get('tenancy') + dns_support = module.params.get('dns_support') + dns_hostnames = module.params.get('dns_hostnames') + dhcp_id = module.params.get('dhcp_opts_id') + tags = module.params.get('tags') + state = module.params.get('state') + multi = module.params.get('multi_ok') + + changed = False + + connection = module.client( + 'ec2', + retry_decorator=AWSRetry.jittered_backoff( + retries=8, delay=3, catch_extra_error_codes=['InvalidVpcID.NotFound'] + ) + ) + + if dns_hostnames and not dns_support: + module.fail_json(msg='In order to enable DNS Hostnames you must also enable DNS support') + + if state == 'present': + + # Check if VPC exists + vpc_id = vpc_exists(module, connection, name, cidr_block, multi) + + if vpc_id is None: + vpc_id = create_vpc(connection, module, cidr_block[0], tenancy) + changed = True + + vpc_obj = get_vpc(module, connection, vpc_id) + + associated_cidrs = dict((cidr['CidrBlock'], cidr['AssociationId']) for cidr in vpc_obj.get('CidrBlockAssociationSet', []) + if cidr['CidrBlockState']['State'] != 'disassociated') + to_add = [cidr for cidr in cidr_block if cidr not in associated_cidrs] + to_remove = [associated_cidrs[cidr] for cidr in associated_cidrs if cidr not in cidr_block] + expected_cidrs = [cidr for cidr in associated_cidrs if associated_cidrs[cidr] not in to_remove] + to_add + + if len(cidr_block) > 1: + for cidr in to_add: + changed = True + try: + connection.associate_vpc_cidr_block(CidrBlock=cidr, VpcId=vpc_id) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Unable to associate CIDR {0}.".format(ipv6_cidr)) + if ipv6_cidr: + if 'Ipv6CidrBlockAssociationSet' in vpc_obj.keys(): + module.warn("Only one IPv6 CIDR is permitted per VPC, {0} already has CIDR {1}".format( + vpc_id, + vpc_obj['Ipv6CidrBlockAssociationSet'][0]['Ipv6CidrBlock'])) + else: + try: + connection.associate_vpc_cidr_block(AmazonProvidedIpv6CidrBlock=ipv6_cidr, VpcId=vpc_id) + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Unable to associate CIDR {0}.".format(ipv6_cidr)) + + if purge_cidrs: + for association_id in to_remove: + changed = True + try: + connection.disassociate_vpc_cidr_block(AssociationId=association_id) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Unable to disassociate {0}. You must detach or delete all gateways and resources that " + "are associated with the CIDR block before you can disassociate it.".format(association_id)) + + if dhcp_id is not None: + try: + if update_dhcp_opts(connection, module, vpc_obj, dhcp_id): + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to update DHCP options") + + if tags is not None or name is not None: + try: + if update_vpc_tags(connection, module, vpc_id, tags, name): + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to update tags") + + current_dns_enabled = connection.describe_vpc_attribute(Attribute='enableDnsSupport', VpcId=vpc_id, aws_retry=True)['EnableDnsSupport']['Value'] + current_dns_hostnames = connection.describe_vpc_attribute(Attribute='enableDnsHostnames', VpcId=vpc_id, aws_retry=True)['EnableDnsHostnames']['Value'] + if current_dns_enabled != dns_support: + changed = True + if not module.check_mode: + try: + connection.modify_vpc_attribute(VpcId=vpc_id, EnableDnsSupport={'Value': dns_support}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to update enabled dns support attribute") + if current_dns_hostnames != dns_hostnames: + changed = True + if not module.check_mode: + try: + connection.modify_vpc_attribute(VpcId=vpc_id, EnableDnsHostnames={'Value': dns_hostnames}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to update enabled dns hostnames attribute") + + # wait for associated cidrs to match + if to_add or to_remove: + try: + connection.get_waiter('vpc_available').wait( + VpcIds=[vpc_id], + Filters=[{'Name': 'cidr-block-association.cidr-block', 'Values': expected_cidrs}] + ) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to wait for CIDRs to update") + + # try to wait for enableDnsSupport and enableDnsHostnames to match + wait_for_vpc_attribute(connection, module, vpc_id, 'enableDnsSupport', dns_support) + wait_for_vpc_attribute(connection, module, vpc_id, 'enableDnsHostnames', dns_hostnames) + + final_state = camel_dict_to_snake_dict(get_vpc(module, connection, vpc_id)) + final_state['tags'] = boto3_tag_list_to_ansible_dict(final_state.get('tags', [])) + final_state['id'] = final_state.pop('vpc_id') + + module.exit_json(changed=changed, vpc=final_state) + + elif state == 'absent': + + # Check if VPC exists + vpc_id = vpc_exists(module, connection, name, cidr_block, multi) + + if vpc_id is not None: + try: + if not module.check_mode: + connection.delete_vpc(VpcId=vpc_id) + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to delete VPC {0} You may want to use the ec2_vpc_subnet, ec2_vpc_igw, " + "and/or ec2_vpc_route_table modules to ensure the other components are absent.".format(vpc_id)) + + module.exit_json(changed=changed, vpc={}) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_vpc_subnet.py b/test/support/integration/plugins/modules/ec2_vpc_subnet.py new file mode 100644 index 00000000..5085e99b --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_vpc_subnet.py @@ -0,0 +1,604 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: ec2_vpc_subnet +short_description: Manage subnets in AWS virtual private clouds +description: + - Manage subnets in AWS virtual private clouds. +version_added: "2.0" +author: +- Robert Estelle (@erydo) +- Brad Davidson (@brandond) +requirements: [ boto3 ] +options: + az: + description: + - "The availability zone for the subnet." + type: str + cidr: + description: + - "The CIDR block for the subnet. E.g. 192.0.2.0/24." + type: str + required: true + ipv6_cidr: + description: + - "The IPv6 CIDR block for the subnet. The VPC must have a /56 block assigned and this value must be a valid IPv6 /64 that falls in the VPC range." + - "Required if I(assign_instances_ipv6=true)" + version_added: "2.5" + type: str + tags: + description: + - "A dict of tags to apply to the subnet. Any tags currently applied to the subnet and not present here will be removed." + aliases: [ 'resource_tags' ] + type: dict + state: + description: + - "Create or remove the subnet." + default: present + choices: [ 'present', 'absent' ] + type: str + vpc_id: + description: + - "VPC ID of the VPC in which to create or delete the subnet." + required: true + type: str + map_public: + description: + - "Specify C(yes) to indicate that instances launched into the subnet should be assigned public IP address by default." + type: bool + default: 'no' + version_added: "2.4" + assign_instances_ipv6: + description: + - "Specify C(yes) to indicate that instances launched into the subnet should be automatically assigned an IPv6 address." + type: bool + default: false + version_added: "2.5" + wait: + description: + - "When I(wait=true) and I(state=present), module will wait for subnet to be in available state before continuing." + type: bool + default: true + version_added: "2.5" + wait_timeout: + description: + - "Number of seconds to wait for subnet to become available I(wait=True)." + default: 300 + version_added: "2.5" + type: int + purge_tags: + description: + - Whether or not to remove tags that do not appear in the I(tags) list. + type: bool + default: true + version_added: "2.5" +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +- name: Create subnet for database servers + ec2_vpc_subnet: + state: present + vpc_id: vpc-123456 + cidr: 10.0.1.16/28 + tags: + Name: Database Subnet + register: database_subnet + +- name: Remove subnet for database servers + ec2_vpc_subnet: + state: absent + vpc_id: vpc-123456 + cidr: 10.0.1.16/28 + +- name: Create subnet with IPv6 block assigned + ec2_vpc_subnet: + state: present + vpc_id: vpc-123456 + cidr: 10.1.100.0/24 + ipv6_cidr: 2001:db8:0:102::/64 + +- name: Remove IPv6 block assigned to subnet + ec2_vpc_subnet: + state: present + vpc_id: vpc-123456 + cidr: 10.1.100.0/24 + ipv6_cidr: '' +''' + +RETURN = ''' +subnet: + description: Dictionary of subnet values + returned: I(state=present) + type: complex + contains: + id: + description: Subnet resource id + returned: I(state=present) + type: str + sample: subnet-b883b2c4 + cidr_block: + description: The IPv4 CIDR of the Subnet + returned: I(state=present) + type: str + sample: "10.0.0.0/16" + ipv6_cidr_block: + description: The IPv6 CIDR block actively associated with the Subnet + returned: I(state=present) + type: str + sample: "2001:db8:0:102::/64" + availability_zone: + description: Availability zone of the Subnet + returned: I(state=present) + type: str + sample: us-east-1a + state: + description: state of the Subnet + returned: I(state=present) + type: str + sample: available + tags: + description: tags attached to the Subnet, includes name + returned: I(state=present) + type: dict + sample: {"Name": "My Subnet", "env": "staging"} + map_public_ip_on_launch: + description: whether public IP is auto-assigned to new instances + returned: I(state=present) + type: bool + sample: false + assign_ipv6_address_on_creation: + description: whether IPv6 address is auto-assigned to new instances + returned: I(state=present) + type: bool + sample: false + vpc_id: + description: the id of the VPC where this Subnet exists + returned: I(state=present) + type: str + sample: vpc-67236184 + available_ip_address_count: + description: number of available IPv4 addresses + returned: I(state=present) + type: str + sample: 251 + default_for_az: + description: indicates whether this is the default Subnet for this Availability Zone + returned: I(state=present) + type: bool + sample: false + ipv6_association_id: + description: The IPv6 association ID for the currently associated CIDR + returned: I(state=present) + type: str + sample: subnet-cidr-assoc-b85c74d2 + ipv6_cidr_block_association_set: + description: An array of IPv6 cidr block association set information. + returned: I(state=present) + type: complex + contains: + association_id: + description: The association ID + returned: always + type: str + ipv6_cidr_block: + description: The IPv6 CIDR block that is associated with the subnet. + returned: always + type: str + ipv6_cidr_block_state: + description: A hash/dict that contains a single item. The state of the cidr block association. + returned: always + type: dict + contains: + state: + description: The CIDR block association state. + returned: always + type: str +''' + + +import time + +try: + import botocore +except ImportError: + pass # caught by AnsibleAWSModule + +from ansible.module_utils._text import to_text +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.aws.waiters import get_waiter +from ansible.module_utils.ec2 import (ansible_dict_to_boto3_filter_list, ansible_dict_to_boto3_tag_list, + camel_dict_to_snake_dict, boto3_tag_list_to_ansible_dict, compare_aws_tags, AWSRetry) + + +def get_subnet_info(subnet): + if 'Subnets' in subnet: + return [get_subnet_info(s) for s in subnet['Subnets']] + elif 'Subnet' in subnet: + subnet = camel_dict_to_snake_dict(subnet['Subnet']) + else: + subnet = camel_dict_to_snake_dict(subnet) + + if 'tags' in subnet: + subnet['tags'] = boto3_tag_list_to_ansible_dict(subnet['tags']) + else: + subnet['tags'] = dict() + + if 'subnet_id' in subnet: + subnet['id'] = subnet['subnet_id'] + del subnet['subnet_id'] + + subnet['ipv6_cidr_block'] = '' + subnet['ipv6_association_id'] = '' + ipv6set = subnet.get('ipv6_cidr_block_association_set') + if ipv6set: + for item in ipv6set: + if item.get('ipv6_cidr_block_state', {}).get('state') in ('associated', 'associating'): + subnet['ipv6_cidr_block'] = item['ipv6_cidr_block'] + subnet['ipv6_association_id'] = item['association_id'] + + return subnet + + +@AWSRetry.exponential_backoff() +def describe_subnets_with_backoff(client, **params): + return client.describe_subnets(**params) + + +def waiter_params(module, params, start_time): + if not module.botocore_at_least("1.7.0"): + remaining_wait_timeout = int(module.params['wait_timeout'] + start_time - time.time()) + params['WaiterConfig'] = {'Delay': 5, 'MaxAttempts': remaining_wait_timeout // 5} + return params + + +def handle_waiter(conn, module, waiter_name, params, start_time): + try: + get_waiter(conn, waiter_name).wait( + **waiter_params(module, params, start_time) + ) + except botocore.exceptions.WaiterError as e: + module.fail_json_aws(e, "Failed to wait for updates to complete") + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "An exception happened while trying to wait for updates") + + +def create_subnet(conn, module, vpc_id, cidr, ipv6_cidr=None, az=None, start_time=None): + wait = module.params['wait'] + wait_timeout = module.params['wait_timeout'] + + params = dict(VpcId=vpc_id, + CidrBlock=cidr) + + if ipv6_cidr: + params['Ipv6CidrBlock'] = ipv6_cidr + + if az: + params['AvailabilityZone'] = az + + try: + subnet = get_subnet_info(conn.create_subnet(**params)) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't create subnet") + + # Sometimes AWS takes its time to create a subnet and so using + # new subnets's id to do things like create tags results in + # exception. + if wait and subnet.get('state') != 'available': + handle_waiter(conn, module, 'subnet_exists', {'SubnetIds': [subnet['id']]}, start_time) + try: + conn.get_waiter('subnet_available').wait( + **waiter_params(module, {'SubnetIds': [subnet['id']]}, start_time) + ) + subnet['state'] = 'available' + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Create subnet action timed out waiting for subnet to become available") + + return subnet + + +def ensure_tags(conn, module, subnet, tags, purge_tags, start_time): + changed = False + + filters = ansible_dict_to_boto3_filter_list({'resource-id': subnet['id'], 'resource-type': 'subnet'}) + try: + cur_tags = conn.describe_tags(Filters=filters) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't describe tags") + + to_update, to_delete = compare_aws_tags(boto3_tag_list_to_ansible_dict(cur_tags.get('Tags')), tags, purge_tags) + + if to_update: + try: + if not module.check_mode: + AWSRetry.exponential_backoff( + catch_extra_error_codes=['InvalidSubnetID.NotFound'] + )(conn.create_tags)( + Resources=[subnet['id']], + Tags=ansible_dict_to_boto3_tag_list(to_update) + ) + + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't create tags") + + if to_delete: + try: + if not module.check_mode: + tags_list = [] + for key in to_delete: + tags_list.append({'Key': key}) + + AWSRetry.exponential_backoff( + catch_extra_error_codes=['InvalidSubnetID.NotFound'] + )(conn.delete_tags)(Resources=[subnet['id']], Tags=tags_list) + + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't delete tags") + + if module.params['wait'] and not module.check_mode: + # Wait for tags to be updated + filters = [{'Name': 'tag:{0}'.format(k), 'Values': [v]} for k, v in tags.items()] + handle_waiter(conn, module, 'subnet_exists', + {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time) + + return changed + + +def ensure_map_public(conn, module, subnet, map_public, check_mode, start_time): + if check_mode: + return + try: + conn.modify_subnet_attribute(SubnetId=subnet['id'], MapPublicIpOnLaunch={'Value': map_public}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't modify subnet attribute") + + +def ensure_assign_ipv6_on_create(conn, module, subnet, assign_instances_ipv6, check_mode, start_time): + if check_mode: + return + try: + conn.modify_subnet_attribute(SubnetId=subnet['id'], AssignIpv6AddressOnCreation={'Value': assign_instances_ipv6}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't modify subnet attribute") + + +def disassociate_ipv6_cidr(conn, module, subnet, start_time): + if subnet.get('assign_ipv6_address_on_creation'): + ensure_assign_ipv6_on_create(conn, module, subnet, False, False, start_time) + + try: + conn.disassociate_subnet_cidr_block(AssociationId=subnet['ipv6_association_id']) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't disassociate ipv6 cidr block id {0} from subnet {1}" + .format(subnet['ipv6_association_id'], subnet['id'])) + + # Wait for cidr block to be disassociated + if module.params['wait']: + filters = ansible_dict_to_boto3_filter_list( + {'ipv6-cidr-block-association.state': ['disassociated'], + 'vpc-id': subnet['vpc_id']} + ) + handle_waiter(conn, module, 'subnet_exists', + {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time) + + +def ensure_ipv6_cidr_block(conn, module, subnet, ipv6_cidr, check_mode, start_time): + wait = module.params['wait'] + changed = False + + if subnet['ipv6_association_id'] and not ipv6_cidr: + if not check_mode: + disassociate_ipv6_cidr(conn, module, subnet, start_time) + changed = True + + if ipv6_cidr: + filters = ansible_dict_to_boto3_filter_list({'ipv6-cidr-block-association.ipv6-cidr-block': ipv6_cidr, + 'vpc-id': subnet['vpc_id']}) + + try: + check_subnets = get_subnet_info(describe_subnets_with_backoff(conn, Filters=filters)) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't get subnet info") + + if check_subnets and check_subnets[0]['ipv6_cidr_block']: + module.fail_json(msg="The IPv6 CIDR '{0}' conflicts with another subnet".format(ipv6_cidr)) + + if subnet['ipv6_association_id']: + if not check_mode: + disassociate_ipv6_cidr(conn, module, subnet, start_time) + changed = True + + try: + if not check_mode: + associate_resp = conn.associate_subnet_cidr_block(SubnetId=subnet['id'], Ipv6CidrBlock=ipv6_cidr) + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't associate ipv6 cidr {0} to {1}".format(ipv6_cidr, subnet['id'])) + else: + if not check_mode and wait: + filters = ansible_dict_to_boto3_filter_list( + {'ipv6-cidr-block-association.state': ['associated'], + 'vpc-id': subnet['vpc_id']} + ) + handle_waiter(conn, module, 'subnet_exists', + {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time) + + if associate_resp.get('Ipv6CidrBlockAssociation', {}).get('AssociationId'): + subnet['ipv6_association_id'] = associate_resp['Ipv6CidrBlockAssociation']['AssociationId'] + subnet['ipv6_cidr_block'] = associate_resp['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'] + if subnet['ipv6_cidr_block_association_set']: + subnet['ipv6_cidr_block_association_set'][0] = camel_dict_to_snake_dict(associate_resp['Ipv6CidrBlockAssociation']) + else: + subnet['ipv6_cidr_block_association_set'].append(camel_dict_to_snake_dict(associate_resp['Ipv6CidrBlockAssociation'])) + + return changed + + +def get_matching_subnet(conn, module, vpc_id, cidr): + filters = ansible_dict_to_boto3_filter_list({'vpc-id': vpc_id, 'cidr-block': cidr}) + try: + subnets = get_subnet_info(describe_subnets_with_backoff(conn, Filters=filters)) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't get matching subnet") + + if subnets: + return subnets[0] + + return None + + +def ensure_subnet_present(conn, module): + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + changed = False + + # Initialize start so max time does not exceed the specified wait_timeout for multiple operations + start_time = time.time() + + if subnet is None: + if not module.check_mode: + subnet = create_subnet(conn, module, module.params['vpc_id'], module.params['cidr'], + ipv6_cidr=module.params['ipv6_cidr'], az=module.params['az'], start_time=start_time) + changed = True + # Subnet will be None when check_mode is true + if subnet is None: + return { + 'changed': changed, + 'subnet': {} + } + if module.params['wait']: + handle_waiter(conn, module, 'subnet_exists', {'SubnetIds': [subnet['id']]}, start_time) + + if module.params['ipv6_cidr'] != subnet.get('ipv6_cidr_block'): + if ensure_ipv6_cidr_block(conn, module, subnet, module.params['ipv6_cidr'], module.check_mode, start_time): + changed = True + + if module.params['map_public'] != subnet['map_public_ip_on_launch']: + ensure_map_public(conn, module, subnet, module.params['map_public'], module.check_mode, start_time) + changed = True + + if module.params['assign_instances_ipv6'] != subnet.get('assign_ipv6_address_on_creation'): + ensure_assign_ipv6_on_create(conn, module, subnet, module.params['assign_instances_ipv6'], module.check_mode, start_time) + changed = True + + if module.params['tags'] != subnet['tags']: + stringified_tags_dict = dict((to_text(k), to_text(v)) for k, v in module.params['tags'].items()) + if ensure_tags(conn, module, subnet, stringified_tags_dict, module.params['purge_tags'], start_time): + changed = True + + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + if not module.check_mode and module.params['wait']: + # GET calls are not monotonic for map_public_ip_on_launch and assign_ipv6_address_on_creation + # so we only wait for those if necessary just before returning the subnet + subnet = ensure_final_subnet(conn, module, subnet, start_time) + + return { + 'changed': changed, + 'subnet': subnet + } + + +def ensure_final_subnet(conn, module, subnet, start_time): + for rewait in range(0, 30): + map_public_correct = False + assign_ipv6_correct = False + + if module.params['map_public'] == subnet['map_public_ip_on_launch']: + map_public_correct = True + else: + if module.params['map_public']: + handle_waiter(conn, module, 'subnet_has_map_public', {'SubnetIds': [subnet['id']]}, start_time) + else: + handle_waiter(conn, module, 'subnet_no_map_public', {'SubnetIds': [subnet['id']]}, start_time) + + if module.params['assign_instances_ipv6'] == subnet.get('assign_ipv6_address_on_creation'): + assign_ipv6_correct = True + else: + if module.params['assign_instances_ipv6']: + handle_waiter(conn, module, 'subnet_has_assign_ipv6', {'SubnetIds': [subnet['id']]}, start_time) + else: + handle_waiter(conn, module, 'subnet_no_assign_ipv6', {'SubnetIds': [subnet['id']]}, start_time) + + if map_public_correct and assign_ipv6_correct: + break + + time.sleep(5) + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + + return subnet + + +def ensure_subnet_absent(conn, module): + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + if subnet is None: + return {'changed': False} + + try: + if not module.check_mode: + conn.delete_subnet(SubnetId=subnet['id']) + if module.params['wait']: + handle_waiter(conn, module, 'subnet_deleted', {'SubnetIds': [subnet['id']]}, time.time()) + return {'changed': True} + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't delete subnet") + + +def main(): + argument_spec = dict( + az=dict(default=None, required=False), + cidr=dict(required=True), + ipv6_cidr=dict(default='', required=False), + state=dict(default='present', choices=['present', 'absent']), + tags=dict(default={}, required=False, type='dict', aliases=['resource_tags']), + vpc_id=dict(required=True), + map_public=dict(default=False, required=False, type='bool'), + assign_instances_ipv6=dict(default=False, required=False, type='bool'), + wait=dict(type='bool', default=True), + wait_timeout=dict(type='int', default=300, required=False), + purge_tags=dict(default=True, type='bool') + ) + + required_if = [('assign_instances_ipv6', True, ['ipv6_cidr'])] + + module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True, required_if=required_if) + + if module.params.get('assign_instances_ipv6') and not module.params.get('ipv6_cidr'): + module.fail_json(msg="assign_instances_ipv6 is True but ipv6_cidr is None or an empty string") + + if not module.botocore_at_least("1.7.0"): + module.warn("botocore >= 1.7.0 is required to use wait_timeout for custom wait times") + + connection = module.client('ec2') + + state = module.params.get('state') + + try: + if state == 'present': + result = ensure_subnet_present(connection, module) + elif state == 'absent': + result = ensure_subnet_absent(connection, module) + except botocore.exceptions.ClientError as e: + module.fail_json_aws(e) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/flatpak_remote.py b/test/support/integration/plugins/modules/flatpak_remote.py new file mode 100644 index 00000000..db208f1b --- /dev/null +++ b/test/support/integration/plugins/modules/flatpak_remote.py @@ -0,0 +1,243 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017 John Kwiatkoski (@JayKayy) <jkwiat40@gmail.com> +# Copyright: (c) 2018 Alexander Bethke (@oolongbrothers) <oolongbrothers@gmx.net> +# Copyright: (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +# ATTENTION CONTRIBUTORS! +# +# TL;DR: Run this module's integration tests manually before opening a pull request +# +# Long explanation: +# The integration tests for this module are currently NOT run on the Ansible project's continuous +# delivery pipeline. So please: When you make changes to this module, make sure that you run the +# included integration tests manually for both Python 2 and Python 3: +# +# Python 2: +# ansible-test integration -v --docker fedora28 --docker-privileged --allow-unsupported --python 2.7 flatpak_remote +# Python 3: +# ansible-test integration -v --docker fedora28 --docker-privileged --allow-unsupported --python 3.6 flatpak_remote +# +# Because of external dependencies, the current integration tests are somewhat too slow and brittle +# to be included right now. I have plans to rewrite the integration tests based on a local flatpak +# repository so that they can be included into the normal CI pipeline. +# //oolongbrothers + + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: flatpak_remote +version_added: '2.6' +short_description: Manage flatpak repository remotes +description: +- Allows users to add or remove flatpak remotes. +- The flatpak remotes concept is comparable to what is called repositories in other packaging + formats. +- Currently, remote addition is only supported via I(flatpakrepo) file URLs. +- Existing remotes will not be updated. +- See the M(flatpak) module for managing flatpaks. +author: +- John Kwiatkoski (@JayKayy) +- Alexander Bethke (@oolongbrothers) +requirements: +- flatpak +options: + executable: + description: + - The path to the C(flatpak) executable to use. + - By default, this module looks for the C(flatpak) executable on the path. + default: flatpak + flatpakrepo_url: + description: + - The URL to the I(flatpakrepo) file representing the repository remote to add. + - When used with I(state=present), the flatpak remote specified under the I(flatpakrepo_url) + is added using the specified installation C(method). + - When used with I(state=absent), this is not required. + - Required when I(state=present). + method: + description: + - The installation method to use. + - Defines if the I(flatpak) is supposed to be installed globally for the whole C(system) + or only for the current C(user). + choices: [ system, user ] + default: system + name: + description: + - The desired name for the flatpak remote to be registered under on the managed host. + - When used with I(state=present), the remote will be added to the managed host under + the specified I(name). + - When used with I(state=absent) the remote with that name will be removed. + required: true + state: + description: + - Indicates the desired package state. + choices: [ absent, present ] + default: present +''' + +EXAMPLES = r''' +- name: Add the Gnome flatpak remote to the system installation + flatpak_remote: + name: gnome + state: present + flatpakrepo_url: https://sdk.gnome.org/gnome-apps.flatpakrepo + +- name: Add the flathub flatpak repository remote to the user installation + flatpak_remote: + name: flathub + state: present + flatpakrepo_url: https://dl.flathub.org/repo/flathub.flatpakrepo + method: user + +- name: Remove the Gnome flatpak remote from the user installation + flatpak_remote: + name: gnome + state: absent + method: user + +- name: Remove the flathub remote from the system installation + flatpak_remote: + name: flathub + state: absent +''' + +RETURN = r''' +command: + description: The exact flatpak command that was executed + returned: When a flatpak command has been executed + type: str + sample: "/usr/bin/flatpak remote-add --system flatpak-test https://dl.flathub.org/repo/flathub.flatpakrepo" +msg: + description: Module error message + returned: failure + type: str + sample: "Executable '/usr/local/bin/flatpak' was not found on the system." +rc: + description: Return code from flatpak binary + returned: When a flatpak command has been executed + type: int + sample: 0 +stderr: + description: Error output from flatpak binary + returned: When a flatpak command has been executed + type: str + sample: "error: GPG verification enabled, but no summary found (check that the configured URL in remote config is correct)\n" +stdout: + description: Output from flatpak binary + returned: When a flatpak command has been executed + type: str + sample: "flathub\tFlathub\thttps://dl.flathub.org/repo/\t1\t\n" +''' + +import subprocess +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_bytes, to_native + + +def add_remote(module, binary, name, flatpakrepo_url, method): + """Add a new remote.""" + global result + command = "{0} remote-add --{1} {2} {3}".format( + binary, method, name, flatpakrepo_url) + _flatpak_command(module, module.check_mode, command) + result['changed'] = True + + +def remove_remote(module, binary, name, method): + """Remove an existing remote.""" + global result + command = "{0} remote-delete --{1} --force {2} ".format( + binary, method, name) + _flatpak_command(module, module.check_mode, command) + result['changed'] = True + + +def remote_exists(module, binary, name, method): + """Check if the remote exists.""" + command = "{0} remote-list -d --{1}".format(binary, method) + # The query operation for the remote needs to be run even in check mode + output = _flatpak_command(module, False, command) + for line in output.splitlines(): + listed_remote = line.split() + if len(listed_remote) == 0: + continue + if listed_remote[0] == to_native(name): + return True + return False + + +def _flatpak_command(module, noop, command): + global result + if noop: + result['rc'] = 0 + result['command'] = command + return "" + + process = subprocess.Popen( + command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout_data, stderr_data = process.communicate() + result['rc'] = process.returncode + result['command'] = command + result['stdout'] = stdout_data + result['stderr'] = stderr_data + if result['rc'] != 0: + module.fail_json(msg="Failed to execute flatpak command", **result) + return to_native(stdout_data) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(type='str', required=True), + flatpakrepo_url=dict(type='str'), + method=dict(type='str', default='system', + choices=['user', 'system']), + state=dict(type='str', default="present", + choices=['absent', 'present']), + executable=dict(type='str', default="flatpak") + ), + # This module supports check mode + supports_check_mode=True, + ) + + name = module.params['name'] + flatpakrepo_url = module.params['flatpakrepo_url'] + method = module.params['method'] + state = module.params['state'] + executable = module.params['executable'] + binary = module.get_bin_path(executable, None) + + if flatpakrepo_url is None: + flatpakrepo_url = '' + + global result + result = dict( + changed=False + ) + + # If the binary was not found, fail the operation + if not binary: + module.fail_json(msg="Executable '%s' was not found on the system." % executable, **result) + + remote_already_exists = remote_exists(module, binary, to_bytes(name), method) + + if state == 'present' and not remote_already_exists: + add_remote(module, binary, name, flatpakrepo_url, method) + elif state == 'absent' and remote_already_exists: + remove_remote(module, binary, name, method) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/htpasswd.py b/test/support/integration/plugins/modules/htpasswd.py new file mode 100644 index 00000000..ad12b0c0 --- /dev/null +++ b/test/support/integration/plugins/modules/htpasswd.py @@ -0,0 +1,275 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, Nimbis Services, Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = """ +module: htpasswd +version_added: "1.3" +short_description: manage user files for basic authentication +description: + - Add and remove username/password entries in a password file using htpasswd. + - This is used by web servers such as Apache and Nginx for basic authentication. +options: + path: + required: true + aliases: [ dest, destfile ] + description: + - Path to the file that contains the usernames and passwords + name: + required: true + aliases: [ username ] + description: + - User name to add or remove + password: + required: false + description: + - Password associated with user. + - Must be specified if user does not exist yet. + crypt_scheme: + required: false + choices: ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"] + default: "apr_md5_crypt" + description: + - Encryption scheme to be used. As well as the four choices listed + here, you can also use any other hash supported by passlib, such as + md5_crypt and sha256_crypt, which are linux passwd hashes. If you + do so the password file will not be compatible with Apache or Nginx + state: + required: false + choices: [ present, absent ] + default: "present" + description: + - Whether the user entry should be present or not + create: + required: false + type: bool + default: "yes" + description: + - Used with C(state=present). If specified, the file will be created + if it does not already exist. If set to "no", will fail if the + file does not exist +notes: + - "This module depends on the I(passlib) Python library, which needs to be installed on all target systems." + - "On Debian, Ubuntu, or Fedora: install I(python-passlib)." + - "On RHEL or CentOS: Enable EPEL, then install I(python-passlib)." +requirements: [ passlib>=1.6 ] +author: "Ansible Core Team" +extends_documentation_fragment: files +""" + +EXAMPLES = """ +# Add a user to a password file and ensure permissions are set +- htpasswd: + path: /etc/nginx/passwdfile + name: janedoe + password: '9s36?;fyNp' + owner: root + group: www-data + mode: 0640 + +# Remove a user from a password file +- htpasswd: + path: /etc/apache2/passwdfile + name: foobar + state: absent + +# Add a user to a password file suitable for use by libpam-pwdfile +- htpasswd: + path: /etc/mail/passwords + name: alex + password: oedu2eGh + crypt_scheme: md5_crypt +""" + + +import os +import tempfile +import traceback +from distutils.version import LooseVersion +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native + +PASSLIB_IMP_ERR = None +try: + from passlib.apache import HtpasswdFile, htpasswd_context + from passlib.context import CryptContext + import passlib +except ImportError: + PASSLIB_IMP_ERR = traceback.format_exc() + passlib_installed = False +else: + passlib_installed = True + +apache_hashes = ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"] + + +def create_missing_directories(dest): + destpath = os.path.dirname(dest) + if not os.path.exists(destpath): + os.makedirs(destpath) + + +def present(dest, username, password, crypt_scheme, create, check_mode): + """ Ensures user is present + + Returns (msg, changed) """ + if crypt_scheme in apache_hashes: + context = htpasswd_context + else: + context = CryptContext(schemes=[crypt_scheme] + apache_hashes) + if not os.path.exists(dest): + if not create: + raise ValueError('Destination %s does not exist' % dest) + if check_mode: + return ("Create %s" % dest, True) + create_missing_directories(dest) + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=True, default_scheme=crypt_scheme, context=context) + else: + ht = HtpasswdFile(dest, autoload=False, default=crypt_scheme, context=context) + if getattr(ht, 'set_password', None): + ht.set_password(username, password) + else: + ht.update(username, password) + ht.save() + return ("Created %s and added %s" % (dest, username), True) + else: + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=False, default_scheme=crypt_scheme, context=context) + else: + ht = HtpasswdFile(dest, default=crypt_scheme, context=context) + + found = None + if getattr(ht, 'check_password', None): + found = ht.check_password(username, password) + else: + found = ht.verify(username, password) + + if found: + return ("%s already present" % username, False) + else: + if not check_mode: + if getattr(ht, 'set_password', None): + ht.set_password(username, password) + else: + ht.update(username, password) + ht.save() + return ("Add/update %s" % username, True) + + +def absent(dest, username, check_mode): + """ Ensures user is absent + + Returns (msg, changed) """ + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=False) + else: + ht = HtpasswdFile(dest) + + if username not in ht.users(): + return ("%s not present" % username, False) + else: + if not check_mode: + ht.delete(username) + ht.save() + return ("Remove %s" % username, True) + + +def check_file_attrs(module, changed, message): + + file_args = module.load_file_common_arguments(module.params) + if module.set_fs_attributes_if_different(file_args, False): + + if changed: + message += " and " + changed = True + message += "ownership, perms or SE linux context changed" + + return message, changed + + +def main(): + arg_spec = dict( + path=dict(required=True, aliases=["dest", "destfile"]), + name=dict(required=True, aliases=["username"]), + password=dict(required=False, default=None, no_log=True), + crypt_scheme=dict(required=False, default="apr_md5_crypt"), + state=dict(required=False, default="present"), + create=dict(type='bool', default='yes'), + + ) + module = AnsibleModule(argument_spec=arg_spec, + add_file_common_args=True, + supports_check_mode=True) + + path = module.params['path'] + username = module.params['name'] + password = module.params['password'] + crypt_scheme = module.params['crypt_scheme'] + state = module.params['state'] + create = module.params['create'] + check_mode = module.check_mode + + if not passlib_installed: + module.fail_json(msg=missing_required_lib("passlib"), exception=PASSLIB_IMP_ERR) + + # Check file for blank lines in effort to avoid "need more than 1 value to unpack" error. + try: + f = open(path, "r") + except IOError: + # No preexisting file to remove blank lines from + f = None + else: + try: + lines = f.readlines() + finally: + f.close() + + # If the file gets edited, it returns true, so only edit the file if it has blank lines + strip = False + for line in lines: + if not line.strip(): + strip = True + break + + if strip: + # If check mode, create a temporary file + if check_mode: + temp = tempfile.NamedTemporaryFile() + path = temp.name + f = open(path, "w") + try: + [f.write(line) for line in lines if line.strip()] + finally: + f.close() + + try: + if state == 'present': + (msg, changed) = present(path, username, password, crypt_scheme, create, check_mode) + elif state == 'absent': + if not os.path.exists(path): + module.exit_json(msg="%s not present" % username, + warnings="%s does not exist" % path, changed=False) + (msg, changed) = absent(path, username, check_mode) + else: + module.fail_json(msg="Invalid state: %s" % state) + + check_file_attrs(module, changed, msg) + module.exit_json(msg=msg, changed=changed) + except Exception as e: + module.fail_json(msg=to_native(e)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/locale_gen.py b/test/support/integration/plugins/modules/locale_gen.py new file mode 100644 index 00000000..4968b834 --- /dev/null +++ b/test/support/integration/plugins/modules/locale_gen.py @@ -0,0 +1,237 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +--- +module: locale_gen +short_description: Creates or removes locales +description: + - Manages locales by editing /etc/locale.gen and invoking locale-gen. +version_added: "1.6" +author: +- Augustus Kling (@AugustusKling) +options: + name: + description: + - Name and encoding of the locale, such as "en_GB.UTF-8". + required: true + state: + description: + - Whether the locale shall be present. + choices: [ absent, present ] + default: present +''' + +EXAMPLES = ''' +- name: Ensure a locale exists + locale_gen: + name: de_CH.UTF-8 + state: present +''' + +import os +import re +from subprocess import Popen, PIPE, call + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_native + +LOCALE_NORMALIZATION = { + ".utf8": ".UTF-8", + ".eucjp": ".EUC-JP", + ".iso885915": ".ISO-8859-15", + ".cp1251": ".CP1251", + ".koi8r": ".KOI8-R", + ".armscii8": ".ARMSCII-8", + ".euckr": ".EUC-KR", + ".gbk": ".GBK", + ".gb18030": ".GB18030", + ".euctw": ".EUC-TW", +} + + +# =========================================== +# location module specific support methods. +# + +def is_available(name, ubuntuMode): + """Check if the given locale is available on the system. This is done by + checking either : + * if the locale is present in /etc/locales.gen + * or if the locale is present in /usr/share/i18n/SUPPORTED""" + if ubuntuMode: + __regexp = r'^(?P<locale>\S+_\S+) (?P<charset>\S+)\s*$' + __locales_available = '/usr/share/i18n/SUPPORTED' + else: + __regexp = r'^#{0,1}\s*(?P<locale>\S+_\S+) (?P<charset>\S+)\s*$' + __locales_available = '/etc/locale.gen' + + re_compiled = re.compile(__regexp) + fd = open(__locales_available, 'r') + for line in fd: + result = re_compiled.match(line) + if result and result.group('locale') == name: + return True + fd.close() + return False + + +def is_present(name): + """Checks if the given locale is currently installed.""" + output = Popen(["locale", "-a"], stdout=PIPE).communicate()[0] + output = to_native(output) + return any(fix_case(name) == fix_case(line) for line in output.splitlines()) + + +def fix_case(name): + """locale -a might return the encoding in either lower or upper case. + Passing through this function makes them uniform for comparisons.""" + for s, r in LOCALE_NORMALIZATION.items(): + name = name.replace(s, r) + return name + + +def replace_line(existing_line, new_line): + """Replaces lines in /etc/locale.gen""" + try: + f = open("/etc/locale.gen", "r") + lines = [line.replace(existing_line, new_line) for line in f] + finally: + f.close() + try: + f = open("/etc/locale.gen", "w") + f.write("".join(lines)) + finally: + f.close() + + +def set_locale(name, enabled=True): + """ Sets the state of the locale. Defaults to enabled. """ + search_string = r'#{0,1}\s*%s (?P<charset>.+)' % name + if enabled: + new_string = r'%s \g<charset>' % (name) + else: + new_string = r'# %s \g<charset>' % (name) + try: + f = open("/etc/locale.gen", "r") + lines = [re.sub(search_string, new_string, line) for line in f] + finally: + f.close() + try: + f = open("/etc/locale.gen", "w") + f.write("".join(lines)) + finally: + f.close() + + +def apply_change(targetState, name): + """Create or remove locale. + + Keyword arguments: + targetState -- Desired state, either present or absent. + name -- Name including encoding such as de_CH.UTF-8. + """ + if targetState == "present": + # Create locale. + set_locale(name, enabled=True) + else: + # Delete locale. + set_locale(name, enabled=False) + + localeGenExitValue = call("locale-gen") + if localeGenExitValue != 0: + raise EnvironmentError(localeGenExitValue, "locale.gen failed to execute, it returned " + str(localeGenExitValue)) + + +def apply_change_ubuntu(targetState, name): + """Create or remove locale. + + Keyword arguments: + targetState -- Desired state, either present or absent. + name -- Name including encoding such as de_CH.UTF-8. + """ + if targetState == "present": + # Create locale. + # Ubuntu's patched locale-gen automatically adds the new locale to /var/lib/locales/supported.d/local + localeGenExitValue = call(["locale-gen", name]) + else: + # Delete locale involves discarding the locale from /var/lib/locales/supported.d/local and regenerating all locales. + try: + f = open("/var/lib/locales/supported.d/local", "r") + content = f.readlines() + finally: + f.close() + try: + f = open("/var/lib/locales/supported.d/local", "w") + for line in content: + locale, charset = line.split(' ') + if locale != name: + f.write(line) + finally: + f.close() + # Purge locales and regenerate. + # Please provide a patch if you know how to avoid regenerating the locales to keep! + localeGenExitValue = call(["locale-gen", "--purge"]) + + if localeGenExitValue != 0: + raise EnvironmentError(localeGenExitValue, "locale.gen failed to execute, it returned " + str(localeGenExitValue)) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(type='str', required=True), + state=dict(type='str', default='present', choices=['absent', 'present']), + ), + supports_check_mode=True, + ) + + name = module.params['name'] + state = module.params['state'] + + if not os.path.exists("/etc/locale.gen"): + if os.path.exists("/var/lib/locales/supported.d/"): + # Ubuntu created its own system to manage locales. + ubuntuMode = True + else: + module.fail_json(msg="/etc/locale.gen and /var/lib/locales/supported.d/local are missing. Is the package \"locales\" installed?") + else: + # We found the common way to manage locales. + ubuntuMode = False + + if not is_available(name, ubuntuMode): + module.fail_json(msg="The locale you've entered is not available " + "on your system.") + + if is_present(name): + prev_state = "present" + else: + prev_state = "absent" + changed = (prev_state != state) + + if module.check_mode: + module.exit_json(changed=changed) + else: + if changed: + try: + if ubuntuMode is False: + apply_change(state, name) + else: + apply_change_ubuntu(state, name) + except EnvironmentError as e: + module.fail_json(msg=to_native(e), exitValue=e.errno) + + module.exit_json(name=name, changed=changed, msg="OK") + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/lvg.py b/test/support/integration/plugins/modules/lvg.py new file mode 100644 index 00000000..e2035f68 --- /dev/null +++ b/test/support/integration/plugins/modules/lvg.py @@ -0,0 +1,295 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2013, Alexander Bulimov <lazywolf0@gmail.com> +# Based on lvol module by Jeroen Hoekx <jeroen.hoekx@dsquare.be> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +author: +- Alexander Bulimov (@abulimov) +module: lvg +short_description: Configure LVM volume groups +description: + - This module creates, removes or resizes volume groups. +version_added: "1.1" +options: + vg: + description: + - The name of the volume group. + type: str + required: true + pvs: + description: + - List of comma-separated devices to use as physical devices in this volume group. + - Required when creating or resizing volume group. + - The module will take care of running pvcreate if needed. + type: list + pesize: + description: + - "The size of the physical extent. I(pesize) must be a power of 2 of at least 1 sector + (where the sector size is the largest sector size of the PVs currently used in the VG), + or at least 128KiB." + - Since Ansible 2.6, pesize can be optionally suffixed by a UNIT (k/K/m/M/g/G), default unit is megabyte. + type: str + default: "4" + pv_options: + description: + - Additional options to pass to C(pvcreate) when creating the volume group. + type: str + version_added: "2.4" + vg_options: + description: + - Additional options to pass to C(vgcreate) when creating the volume group. + type: str + version_added: "1.6" + state: + description: + - Control if the volume group exists. + type: str + choices: [ absent, present ] + default: present + force: + description: + - If C(yes), allows to remove volume group with logical volumes. + type: bool + default: no +seealso: +- module: filesystem +- module: lvol +- module: parted +notes: + - This module does not modify PE size for already present volume group. +''' + +EXAMPLES = r''' +- name: Create a volume group on top of /dev/sda1 with physical extent size = 32MB + lvg: + vg: vg.services + pvs: /dev/sda1 + pesize: 32 + +- name: Create a volume group on top of /dev/sdb with physical extent size = 128KiB + lvg: + vg: vg.services + pvs: /dev/sdb + pesize: 128K + +# If, for example, we already have VG vg.services on top of /dev/sdb1, +# this VG will be extended by /dev/sdc5. Or if vg.services was created on +# top of /dev/sda5, we first extend it with /dev/sdb1 and /dev/sdc5, +# and then reduce by /dev/sda5. +- name: Create or resize a volume group on top of /dev/sdb1 and /dev/sdc5. + lvg: + vg: vg.services + pvs: /dev/sdb1,/dev/sdc5 + +- name: Remove a volume group with name vg.services + lvg: + vg: vg.services + state: absent +''' + +import itertools +import os + +from ansible.module_utils.basic import AnsibleModule + + +def parse_vgs(data): + vgs = [] + for line in data.splitlines(): + parts = line.strip().split(';') + vgs.append({ + 'name': parts[0], + 'pv_count': int(parts[1]), + 'lv_count': int(parts[2]), + }) + return vgs + + +def find_mapper_device_name(module, dm_device): + dmsetup_cmd = module.get_bin_path('dmsetup', True) + mapper_prefix = '/dev/mapper/' + rc, dm_name, err = module.run_command("%s info -C --noheadings -o name %s" % (dmsetup_cmd, dm_device)) + if rc != 0: + module.fail_json(msg="Failed executing dmsetup command.", rc=rc, err=err) + mapper_device = mapper_prefix + dm_name.rstrip() + return mapper_device + + +def parse_pvs(module, data): + pvs = [] + dm_prefix = '/dev/dm-' + for line in data.splitlines(): + parts = line.strip().split(';') + if parts[0].startswith(dm_prefix): + parts[0] = find_mapper_device_name(module, parts[0]) + pvs.append({ + 'name': parts[0], + 'vg_name': parts[1], + }) + return pvs + + +def main(): + module = AnsibleModule( + argument_spec=dict( + vg=dict(type='str', required=True), + pvs=dict(type='list'), + pesize=dict(type='str', default='4'), + pv_options=dict(type='str', default=''), + vg_options=dict(type='str', default=''), + state=dict(type='str', default='present', choices=['absent', 'present']), + force=dict(type='bool', default=False), + ), + supports_check_mode=True, + ) + + vg = module.params['vg'] + state = module.params['state'] + force = module.boolean(module.params['force']) + pesize = module.params['pesize'] + pvoptions = module.params['pv_options'].split() + vgoptions = module.params['vg_options'].split() + + dev_list = [] + if module.params['pvs']: + dev_list = list(module.params['pvs']) + elif state == 'present': + module.fail_json(msg="No physical volumes given.") + + # LVM always uses real paths not symlinks so replace symlinks with actual path + for idx, dev in enumerate(dev_list): + dev_list[idx] = os.path.realpath(dev) + + if state == 'present': + # check given devices + for test_dev in dev_list: + if not os.path.exists(test_dev): + module.fail_json(msg="Device %s not found." % test_dev) + + # get pv list + pvs_cmd = module.get_bin_path('pvs', True) + if dev_list: + pvs_filter_pv_name = ' || '.join( + 'pv_name = {0}'.format(x) + for x in itertools.chain(dev_list, module.params['pvs']) + ) + pvs_filter_vg_name = 'vg_name = {0}'.format(vg) + pvs_filter = "--select '{0} || {1}' ".format(pvs_filter_pv_name, pvs_filter_vg_name) + else: + pvs_filter = '' + rc, current_pvs, err = module.run_command("%s --noheadings -o pv_name,vg_name --separator ';' %s" % (pvs_cmd, pvs_filter)) + if rc != 0: + module.fail_json(msg="Failed executing pvs command.", rc=rc, err=err) + + # check pv for devices + pvs = parse_pvs(module, current_pvs) + used_pvs = [pv for pv in pvs if pv['name'] in dev_list and pv['vg_name'] and pv['vg_name'] != vg] + if used_pvs: + module.fail_json(msg="Device %s is already in %s volume group." % (used_pvs[0]['name'], used_pvs[0]['vg_name'])) + + vgs_cmd = module.get_bin_path('vgs', True) + rc, current_vgs, err = module.run_command("%s --noheadings -o vg_name,pv_count,lv_count --separator ';'" % vgs_cmd) + + if rc != 0: + module.fail_json(msg="Failed executing vgs command.", rc=rc, err=err) + + changed = False + + vgs = parse_vgs(current_vgs) + + for test_vg in vgs: + if test_vg['name'] == vg: + this_vg = test_vg + break + else: + this_vg = None + + if this_vg is None: + if state == 'present': + # create VG + if module.check_mode: + changed = True + else: + # create PV + pvcreate_cmd = module.get_bin_path('pvcreate', True) + for current_dev in dev_list: + rc, _, err = module.run_command([pvcreate_cmd] + pvoptions + ['-f', str(current_dev)]) + if rc == 0: + changed = True + else: + module.fail_json(msg="Creating physical volume '%s' failed" % current_dev, rc=rc, err=err) + vgcreate_cmd = module.get_bin_path('vgcreate') + rc, _, err = module.run_command([vgcreate_cmd] + vgoptions + ['-s', pesize, vg] + dev_list) + if rc == 0: + changed = True + else: + module.fail_json(msg="Creating volume group '%s' failed" % vg, rc=rc, err=err) + else: + if state == 'absent': + if module.check_mode: + module.exit_json(changed=True) + else: + if this_vg['lv_count'] == 0 or force: + # remove VG + vgremove_cmd = module.get_bin_path('vgremove', True) + rc, _, err = module.run_command("%s --force %s" % (vgremove_cmd, vg)) + if rc == 0: + module.exit_json(changed=True) + else: + module.fail_json(msg="Failed to remove volume group %s" % (vg), rc=rc, err=err) + else: + module.fail_json(msg="Refuse to remove non-empty volume group %s without force=yes" % (vg)) + + # resize VG + current_devs = [os.path.realpath(pv['name']) for pv in pvs if pv['vg_name'] == vg] + devs_to_remove = list(set(current_devs) - set(dev_list)) + devs_to_add = list(set(dev_list) - set(current_devs)) + + if devs_to_add or devs_to_remove: + if module.check_mode: + changed = True + else: + if devs_to_add: + devs_to_add_string = ' '.join(devs_to_add) + # create PV + pvcreate_cmd = module.get_bin_path('pvcreate', True) + for current_dev in devs_to_add: + rc, _, err = module.run_command([pvcreate_cmd] + pvoptions + ['-f', str(current_dev)]) + if rc == 0: + changed = True + else: + module.fail_json(msg="Creating physical volume '%s' failed" % current_dev, rc=rc, err=err) + # add PV to our VG + vgextend_cmd = module.get_bin_path('vgextend', True) + rc, _, err = module.run_command("%s %s %s" % (vgextend_cmd, vg, devs_to_add_string)) + if rc == 0: + changed = True + else: + module.fail_json(msg="Unable to extend %s by %s." % (vg, devs_to_add_string), rc=rc, err=err) + + # remove some PV from our VG + if devs_to_remove: + devs_to_remove_string = ' '.join(devs_to_remove) + vgreduce_cmd = module.get_bin_path('vgreduce', True) + rc, _, err = module.run_command("%s --force %s %s" % (vgreduce_cmd, vg, devs_to_remove_string)) + if rc == 0: + changed = True + else: + module.fail_json(msg="Unable to reduce %s by %s." % (vg, devs_to_remove_string), rc=rc, err=err) + + module.exit_json(changed=changed) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/mongodb_parameter.py b/test/support/integration/plugins/modules/mongodb_parameter.py new file mode 100644 index 00000000..05de42b2 --- /dev/null +++ b/test/support/integration/plugins/modules/mongodb_parameter.py @@ -0,0 +1,223 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2016, Loic Blot <loic.blot@unix-experience.fr> +# Sponsored by Infopro Digital. http://www.infopro-digital.com/ +# Sponsored by E.T.A.I. http://www.etai.fr/ +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = r''' +--- +module: mongodb_parameter +short_description: Change an administrative parameter on a MongoDB server +description: + - Change an administrative parameter on a MongoDB server. +version_added: "2.1" +options: + login_user: + description: + - The MongoDB username used to authenticate with. + type: str + login_password: + description: + - The login user's password used to authenticate with. + type: str + login_host: + description: + - The host running the database. + type: str + default: localhost + login_port: + description: + - The MongoDB port to connect to. + default: 27017 + type: int + login_database: + description: + - The database where login credentials are stored. + type: str + replica_set: + description: + - Replica set to connect to (automatically connects to primary for writes). + type: str + ssl: + description: + - Whether to use an SSL connection when connecting to the database. + type: bool + default: no + param: + description: + - MongoDB administrative parameter to modify. + type: str + required: true + value: + description: + - MongoDB administrative parameter value to set. + type: str + required: true + param_type: + description: + - Define the type of parameter value. + default: str + type: str + choices: [int, str] + +notes: + - Requires the pymongo Python package on the remote host, version 2.4.2+. + - This can be installed using pip or the OS package manager. + - See also U(http://api.mongodb.org/python/current/installation.html) +requirements: [ "pymongo" ] +author: "Loic Blot (@nerzhul)" +''' + +EXAMPLES = r''' +- name: Set MongoDB syncdelay to 60 (this is an int) + mongodb_parameter: + param: syncdelay + value: 60 + param_type: int +''' + +RETURN = r''' +before: + description: value before modification + returned: success + type: str +after: + description: value after modification + returned: success + type: str +''' + +import os +import traceback + +try: + from pymongo.errors import ConnectionFailure + from pymongo.errors import OperationFailure + from pymongo import version as PyMongoVersion + from pymongo import MongoClient +except ImportError: + try: # for older PyMongo 2.2 + from pymongo import Connection as MongoClient + except ImportError: + pymongo_found = False + else: + pymongo_found = True +else: + pymongo_found = True + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.six.moves import configparser +from ansible.module_utils._text import to_native + + +# ========================================= +# MongoDB module specific support methods. +# + +def load_mongocnf(): + config = configparser.RawConfigParser() + mongocnf = os.path.expanduser('~/.mongodb.cnf') + + try: + config.readfp(open(mongocnf)) + creds = dict( + user=config.get('client', 'user'), + password=config.get('client', 'pass') + ) + except (configparser.NoOptionError, IOError): + return False + + return creds + + +# ========================================= +# Module execution. +# + +def main(): + module = AnsibleModule( + argument_spec=dict( + login_user=dict(default=None), + login_password=dict(default=None, no_log=True), + login_host=dict(default='localhost'), + login_port=dict(default=27017, type='int'), + login_database=dict(default=None), + replica_set=dict(default=None), + param=dict(required=True), + value=dict(required=True), + param_type=dict(default="str", choices=['str', 'int']), + ssl=dict(default=False, type='bool'), + ) + ) + + if not pymongo_found: + module.fail_json(msg=missing_required_lib('pymongo')) + + login_user = module.params['login_user'] + login_password = module.params['login_password'] + login_host = module.params['login_host'] + login_port = module.params['login_port'] + login_database = module.params['login_database'] + + replica_set = module.params['replica_set'] + ssl = module.params['ssl'] + + param = module.params['param'] + param_type = module.params['param_type'] + value = module.params['value'] + + # Verify parameter is coherent with specified type + try: + if param_type == 'int': + value = int(value) + except ValueError: + module.fail_json(msg="value '%s' is not %s" % (value, param_type)) + + try: + if replica_set: + client = MongoClient(login_host, int(login_port), replicaset=replica_set, ssl=ssl) + else: + client = MongoClient(login_host, int(login_port), ssl=ssl) + + if login_user is None and login_password is None: + mongocnf_creds = load_mongocnf() + if mongocnf_creds is not False: + login_user = mongocnf_creds['user'] + login_password = mongocnf_creds['password'] + elif login_password is None or login_user is None: + module.fail_json(msg='when supplying login arguments, both login_user and login_password must be provided') + + if login_user is not None and login_password is not None: + client.admin.authenticate(login_user, login_password, source=login_database) + + except ConnectionFailure as e: + module.fail_json(msg='unable to connect to database: %s' % to_native(e), exception=traceback.format_exc()) + + db = client.admin + + try: + after_value = db.command("setParameter", **{param: value}) + except OperationFailure as e: + module.fail_json(msg="unable to change parameter: %s" % to_native(e), exception=traceback.format_exc()) + + if "was" not in after_value: + module.exit_json(changed=True, msg="Unable to determine old value, assume it changed.") + else: + module.exit_json(changed=(value != after_value["was"]), before=after_value["was"], + after=value) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/mongodb_user.py b/test/support/integration/plugins/modules/mongodb_user.py new file mode 100644 index 00000000..362b3aa4 --- /dev/null +++ b/test/support/integration/plugins/modules/mongodb_user.py @@ -0,0 +1,474 @@ +#!/usr/bin/python + +# (c) 2012, Elliott Foster <elliott@fourkitchens.com> +# Sponsored by Four Kitchens http://fourkitchens.com. +# (c) 2014, Epic Games, Inc. +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: mongodb_user +short_description: Adds or removes a user from a MongoDB database +description: + - Adds or removes a user from a MongoDB database. +version_added: "1.1" +options: + login_user: + description: + - The MongoDB username used to authenticate with. + type: str + login_password: + description: + - The login user's password used to authenticate with. + type: str + login_host: + description: + - The host running the database. + default: localhost + type: str + login_port: + description: + - The MongoDB port to connect to. + default: '27017' + type: str + login_database: + version_added: "2.0" + description: + - The database where login credentials are stored. + type: str + replica_set: + version_added: "1.6" + description: + - Replica set to connect to (automatically connects to primary for writes). + type: str + database: + description: + - The name of the database to add/remove the user from. + required: true + type: str + aliases: [db] + name: + description: + - The name of the user to add or remove. + required: true + aliases: [user] + type: str + password: + description: + - The password to use for the user. + type: str + aliases: [pass] + ssl: + version_added: "1.8" + description: + - Whether to use an SSL connection when connecting to the database. + type: bool + ssl_cert_reqs: + version_added: "2.2" + description: + - Specifies whether a certificate is required from the other side of the connection, + and whether it will be validated if provided. + default: CERT_REQUIRED + choices: [CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED] + type: str + roles: + version_added: "1.3" + type: list + elements: raw + description: + - > + The database user roles valid values could either be one or more of the following strings: + 'read', 'readWrite', 'dbAdmin', 'userAdmin', 'clusterAdmin', 'readAnyDatabase', 'readWriteAnyDatabase', 'userAdminAnyDatabase', + 'dbAdminAnyDatabase' + - "Or the following dictionary '{ db: DATABASE_NAME, role: ROLE_NAME }'." + - "This param requires pymongo 2.5+. If it is a string, mongodb 2.4+ is also required. If it is a dictionary, mongo 2.6+ is required." + state: + description: + - The database user state. + default: present + choices: [absent, present] + type: str + update_password: + default: always + choices: [always, on_create] + version_added: "2.1" + description: + - C(always) will update passwords if they differ. + - C(on_create) will only set the password for newly created users. + type: str + +notes: + - Requires the pymongo Python package on the remote host, version 2.4.2+. This + can be installed using pip or the OS package manager. @see http://api.mongodb.org/python/current/installation.html +requirements: [ "pymongo" ] +author: + - "Elliott Foster (@elliotttf)" + - "Julien Thebault (@Lujeni)" +''' + +EXAMPLES = ''' +- name: Create 'burgers' database user with name 'bob' and password '12345'. + mongodb_user: + database: burgers + name: bob + password: 12345 + state: present + +- name: Create a database user via SSL (MongoDB must be compiled with the SSL option and configured properly) + mongodb_user: + database: burgers + name: bob + password: 12345 + state: present + ssl: True + +- name: Delete 'burgers' database user with name 'bob'. + mongodb_user: + database: burgers + name: bob + state: absent + +- name: Define more users with various specific roles (if not defined, no roles is assigned, and the user will be added via pre mongo 2.2 style) + mongodb_user: + database: burgers + name: ben + password: 12345 + roles: read + state: present + +- name: Define roles + mongodb_user: + database: burgers + name: jim + password: 12345 + roles: readWrite,dbAdmin,userAdmin + state: present + +- name: Define roles + mongodb_user: + database: burgers + name: joe + password: 12345 + roles: readWriteAnyDatabase + state: present + +- name: Add a user to database in a replica set, the primary server is automatically discovered and written to + mongodb_user: + database: burgers + name: bob + replica_set: belcher + password: 12345 + roles: readWriteAnyDatabase + state: present + +# add a user 'oplog_reader' with read only access to the 'local' database on the replica_set 'belcher'. This is useful for oplog access (MONGO_OPLOG_URL). +# please notice the credentials must be added to the 'admin' database because the 'local' database is not synchronized and can't receive user credentials +# To login with such user, the connection string should be MONGO_OPLOG_URL="mongodb://oplog_reader:oplog_reader_password@server1,server2/local?authSource=admin" +# This syntax requires mongodb 2.6+ and pymongo 2.5+ +- name: Roles as a dictionary + mongodb_user: + login_user: root + login_password: root_password + database: admin + user: oplog_reader + password: oplog_reader_password + state: present + replica_set: belcher + roles: + - db: local + role: read + +''' + +RETURN = ''' +user: + description: The name of the user to add or remove. + returned: success + type: str +''' + +import os +import ssl as ssl_lib +import traceback +from distutils.version import LooseVersion +from operator import itemgetter + +try: + from pymongo.errors import ConnectionFailure + from pymongo.errors import OperationFailure + from pymongo import version as PyMongoVersion + from pymongo import MongoClient +except ImportError: + try: # for older PyMongo 2.2 + from pymongo import Connection as MongoClient + except ImportError: + pymongo_found = False + else: + pymongo_found = True +else: + pymongo_found = True + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.six import binary_type, text_type +from ansible.module_utils.six.moves import configparser +from ansible.module_utils._text import to_native + + +# ========================================= +# MongoDB module specific support methods. +# + +def check_compatibility(module, client): + """Check the compatibility between the driver and the database. + + See: https://docs.mongodb.com/ecosystem/drivers/driver-compatibility-reference/#python-driver-compatibility + + Args: + module: Ansible module. + client (cursor): Mongodb cursor on admin database. + """ + loose_srv_version = LooseVersion(client.server_info()['version']) + loose_driver_version = LooseVersion(PyMongoVersion) + + if loose_srv_version >= LooseVersion('3.2') and loose_driver_version < LooseVersion('3.2'): + module.fail_json(msg=' (Note: you must use pymongo 3.2+ with MongoDB >= 3.2)') + + elif loose_srv_version >= LooseVersion('3.0') and loose_driver_version <= LooseVersion('2.8'): + module.fail_json(msg=' (Note: you must use pymongo 2.8+ with MongoDB 3.0)') + + elif loose_srv_version >= LooseVersion('2.6') and loose_driver_version <= LooseVersion('2.7'): + module.fail_json(msg=' (Note: you must use pymongo 2.7+ with MongoDB 2.6)') + + elif LooseVersion(PyMongoVersion) <= LooseVersion('2.5'): + module.fail_json(msg=' (Note: you must be on mongodb 2.4+ and pymongo 2.5+ to use the roles param)') + + +def user_find(client, user, db_name): + """Check if the user exists. + + Args: + client (cursor): Mongodb cursor on admin database. + user (str): User to check. + db_name (str): User's database. + + Returns: + dict: when user exists, False otherwise. + """ + for mongo_user in client["admin"].system.users.find(): + if mongo_user['user'] == user: + # NOTE: there is no 'db' field in mongo 2.4. + if 'db' not in mongo_user: + return mongo_user + + if mongo_user["db"] == db_name: + return mongo_user + return False + + +def user_add(module, client, db_name, user, password, roles): + # pymongo's user_add is a _create_or_update_user so we won't know if it was changed or updated + # without reproducing a lot of the logic in database.py of pymongo + db = client[db_name] + + if roles is None: + db.add_user(user, password, False) + else: + db.add_user(user, password, None, roles=roles) + + +def user_remove(module, client, db_name, user): + exists = user_find(client, user, db_name) + if exists: + if module.check_mode: + module.exit_json(changed=True, user=user) + db = client[db_name] + db.remove_user(user) + else: + module.exit_json(changed=False, user=user) + + +def load_mongocnf(): + config = configparser.RawConfigParser() + mongocnf = os.path.expanduser('~/.mongodb.cnf') + + try: + config.readfp(open(mongocnf)) + creds = dict( + user=config.get('client', 'user'), + password=config.get('client', 'pass') + ) + except (configparser.NoOptionError, IOError): + return False + + return creds + + +def check_if_roles_changed(uinfo, roles, db_name): + # We must be aware of users which can read the oplog on a replicaset + # Such users must have access to the local DB, but since this DB does not store users credentials + # and is not synchronized among replica sets, the user must be stored on the admin db + # Therefore their structure is the following : + # { + # "_id" : "admin.oplog_reader", + # "user" : "oplog_reader", + # "db" : "admin", # <-- admin DB + # "roles" : [ + # { + # "role" : "read", + # "db" : "local" # <-- local DB + # } + # ] + # } + + def make_sure_roles_are_a_list_of_dict(roles, db_name): + output = list() + for role in roles: + if isinstance(role, (binary_type, text_type)): + new_role = {"role": role, "db": db_name} + output.append(new_role) + else: + output.append(role) + return output + + roles_as_list_of_dict = make_sure_roles_are_a_list_of_dict(roles, db_name) + uinfo_roles = uinfo.get('roles', []) + + if sorted(roles_as_list_of_dict, key=itemgetter('db')) == sorted(uinfo_roles, key=itemgetter('db')): + return False + return True + + +# ========================================= +# Module execution. +# + +def main(): + module = AnsibleModule( + argument_spec=dict( + login_user=dict(default=None), + login_password=dict(default=None, no_log=True), + login_host=dict(default='localhost'), + login_port=dict(default='27017'), + login_database=dict(default=None), + replica_set=dict(default=None), + database=dict(required=True, aliases=['db']), + name=dict(required=True, aliases=['user']), + password=dict(aliases=['pass'], no_log=True), + ssl=dict(default=False, type='bool'), + roles=dict(default=None, type='list', elements='raw'), + state=dict(default='present', choices=['absent', 'present']), + update_password=dict(default="always", choices=["always", "on_create"]), + ssl_cert_reqs=dict(default='CERT_REQUIRED', choices=['CERT_NONE', 'CERT_OPTIONAL', 'CERT_REQUIRED']), + ), + supports_check_mode=True + ) + + if not pymongo_found: + module.fail_json(msg=missing_required_lib('pymongo')) + + login_user = module.params['login_user'] + login_password = module.params['login_password'] + login_host = module.params['login_host'] + login_port = module.params['login_port'] + login_database = module.params['login_database'] + + replica_set = module.params['replica_set'] + db_name = module.params['database'] + user = module.params['name'] + password = module.params['password'] + ssl = module.params['ssl'] + roles = module.params['roles'] or [] + state = module.params['state'] + update_password = module.params['update_password'] + + try: + connection_params = { + "host": login_host, + "port": int(login_port), + } + + if replica_set: + connection_params["replicaset"] = replica_set + + if ssl: + connection_params["ssl"] = ssl + connection_params["ssl_cert_reqs"] = getattr(ssl_lib, module.params['ssl_cert_reqs']) + + client = MongoClient(**connection_params) + + # NOTE: this check must be done ASAP. + # We doesn't need to be authenticated (this ability has lost in PyMongo 3.6) + if LooseVersion(PyMongoVersion) <= LooseVersion('3.5'): + check_compatibility(module, client) + + if login_user is None and login_password is None: + mongocnf_creds = load_mongocnf() + if mongocnf_creds is not False: + login_user = mongocnf_creds['user'] + login_password = mongocnf_creds['password'] + elif login_password is None or login_user is None: + module.fail_json(msg='when supplying login arguments, both login_user and login_password must be provided') + + if login_user is not None and login_password is not None: + client.admin.authenticate(login_user, login_password, source=login_database) + elif LooseVersion(PyMongoVersion) >= LooseVersion('3.0'): + if db_name != "admin": + module.fail_json(msg='The localhost login exception only allows the first admin account to be created') + # else: this has to be the first admin user added + + except Exception as e: + module.fail_json(msg='unable to connect to database: %s' % to_native(e), exception=traceback.format_exc()) + + if state == 'present': + if password is None and update_password == 'always': + module.fail_json(msg='password parameter required when adding a user unless update_password is set to on_create') + + try: + if update_password != 'always': + uinfo = user_find(client, user, db_name) + if uinfo: + password = None + if not check_if_roles_changed(uinfo, roles, db_name): + module.exit_json(changed=False, user=user) + + if module.check_mode: + module.exit_json(changed=True, user=user) + + user_add(module, client, db_name, user, password, roles) + except Exception as e: + module.fail_json(msg='Unable to add or update user: %s' % to_native(e), exception=traceback.format_exc()) + finally: + try: + client.close() + except Exception: + pass + # Here we can check password change if mongo provide a query for that : https://jira.mongodb.org/browse/SERVER-22848 + # newuinfo = user_find(client, user, db_name) + # if uinfo['role'] == newuinfo['role'] and CheckPasswordHere: + # module.exit_json(changed=False, user=user) + + elif state == 'absent': + try: + user_remove(module, client, db_name, user) + except Exception as e: + module.fail_json(msg='Unable to remove user: %s' % to_native(e), exception=traceback.format_exc()) + finally: + try: + client.close() + except Exception: + pass + module.exit_json(changed=True, user=user) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/pids.py b/test/support/integration/plugins/modules/pids.py new file mode 100644 index 00000000..4cbf45a9 --- /dev/null +++ b/test/support/integration/plugins/modules/pids.py @@ -0,0 +1,89 @@ +#!/usr/bin/python +# Copyright: (c) 2019, Saranya Sridharan +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +module: pids +version_added: 2.8 +description: "Retrieves a list of PIDs of given process name in Ansible controller/controlled machines.Returns an empty list if no process in that name exists." +short_description: "Retrieves process IDs list if the process is running otherwise return empty list" +author: + - Saranya Sridharan (@saranyasridharan) +requirements: + - psutil(python module) +options: + name: + description: the name of the process you want to get PID for. + required: true + type: str +''' + +EXAMPLES = ''' +# Pass the process name +- name: Getting process IDs of the process + pids: + name: python + register: pids_of_python + +- name: Printing the process IDs obtained + debug: + msg: "PIDS of python:{{pids_of_python.pids|join(',')}}" +''' + +RETURN = ''' +pids: + description: Process IDs of the given process + returned: list of none, one, or more process IDs + type: list + sample: [100,200] +''' + +from ansible.module_utils.basic import AnsibleModule +try: + import psutil + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + + +def compare_lower(a, b): + if a is None or b is None: + # this could just be "return False" but would lead to surprising behavior if both a and b are None + return a == b + + return a.lower() == b.lower() + + +def get_pid(name): + pids = [] + + for proc in psutil.process_iter(attrs=['name', 'cmdline']): + if compare_lower(proc.info['name'], name) or \ + proc.info['cmdline'] and compare_lower(proc.info['cmdline'][0], name): + pids.append(proc.pid) + + return pids + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(required=True, type="str"), + ), + supports_check_mode=True, + ) + if not HAS_PSUTIL: + module.fail_json(msg="Missing required 'psutil' python module. Try installing it with: pip install psutil") + name = module.params["name"] + response = dict(pids=get_pid(name)) + module.exit_json(**response) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/pkgng.py b/test/support/integration/plugins/modules/pkgng.py new file mode 100644 index 00000000..11363479 --- /dev/null +++ b/test/support/integration/plugins/modules/pkgng.py @@ -0,0 +1,406 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, bleader +# Written by bleader <bleader@ratonland.org> +# Based on pkgin module written by Shaun Zinck <shaun.zinck at gmail.com> +# that was based on pacman module written by Afterburn <https://github.com/afterburn> +# that was based on apt module written by Matthew Williams <matthew@flowroute.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: pkgng +short_description: Package manager for FreeBSD >= 9.0 +description: + - Manage binary packages for FreeBSD using 'pkgng' which is available in versions after 9.0. +version_added: "1.2" +options: + name: + description: + - Name or list of names of packages to install/remove. + required: true + state: + description: + - State of the package. + - 'Note: "latest" added in 2.7' + choices: [ 'present', 'latest', 'absent' ] + required: false + default: present + cached: + description: + - Use local package base instead of fetching an updated one. + type: bool + required: false + default: no + annotation: + description: + - A comma-separated list of keyvalue-pairs of the form + C(<+/-/:><key>[=<value>]). A C(+) denotes adding an annotation, a + C(-) denotes removing an annotation, and C(:) denotes modifying an + annotation. + If setting or modifying annotations, a value must be provided. + required: false + version_added: "1.6" + pkgsite: + description: + - For pkgng versions before 1.1.4, specify packagesite to use + for downloading packages. If not specified, use settings from + C(/usr/local/etc/pkg.conf). + - For newer pkgng versions, specify a the name of a repository + configured in C(/usr/local/etc/pkg/repos). + required: false + rootdir: + description: + - For pkgng versions 1.5 and later, pkg will install all packages + within the specified root directory. + - Can not be used together with I(chroot) or I(jail) options. + required: false + chroot: + version_added: "2.1" + description: + - Pkg will chroot in the specified environment. + - Can not be used together with I(rootdir) or I(jail) options. + required: false + jail: + version_added: "2.4" + description: + - Pkg will execute in the given jail name or id. + - Can not be used together with I(chroot) or I(rootdir) options. + autoremove: + version_added: "2.2" + description: + - Remove automatically installed packages which are no longer needed. + required: false + type: bool + default: no +author: "bleader (@bleader)" +notes: + - When using pkgsite, be careful that already in cache packages won't be downloaded again. + - When used with a `loop:` each package will be processed individually, + it is much more efficient to pass the list directly to the `name` option. +''' + +EXAMPLES = ''' +- name: Install package foo + pkgng: + name: foo + state: present + +- name: Annotate package foo and bar + pkgng: + name: foo,bar + annotation: '+test1=baz,-test2,:test3=foobar' + +- name: Remove packages foo and bar + pkgng: + name: foo,bar + state: absent + +# "latest" support added in 2.7 +- name: Upgrade package baz + pkgng: + name: baz + state: latest +''' + + +import re +from ansible.module_utils.basic import AnsibleModule + + +def query_package(module, pkgng_path, name, dir_arg): + + rc, out, err = module.run_command("%s %s info -g -e %s" % (pkgng_path, dir_arg, name)) + + if rc == 0: + return True + + return False + + +def query_update(module, pkgng_path, name, dir_arg, old_pkgng, pkgsite): + + # Check to see if a package upgrade is available. + # rc = 0, no updates available or package not installed + # rc = 1, updates available + if old_pkgng: + rc, out, err = module.run_command("%s %s upgrade -g -n %s" % (pkgsite, pkgng_path, name)) + else: + rc, out, err = module.run_command("%s %s upgrade %s -g -n %s" % (pkgng_path, dir_arg, pkgsite, name)) + + if rc == 1: + return True + + return False + + +def pkgng_older_than(module, pkgng_path, compare_version): + + rc, out, err = module.run_command("%s -v" % pkgng_path) + version = [int(x) for x in re.split(r'[\._]', out)] + + i = 0 + new_pkgng = True + while compare_version[i] == version[i]: + i += 1 + if i == min(len(compare_version), len(version)): + break + else: + if compare_version[i] > version[i]: + new_pkgng = False + return not new_pkgng + + +def remove_packages(module, pkgng_path, packages, dir_arg): + + remove_c = 0 + # Using a for loop in case of error, we can report the package that failed + for package in packages: + # Query the package first, to see if we even need to remove + if not query_package(module, pkgng_path, package, dir_arg): + continue + + if not module.check_mode: + rc, out, err = module.run_command("%s %s delete -y %s" % (pkgng_path, dir_arg, package)) + + if not module.check_mode and query_package(module, pkgng_path, package, dir_arg): + module.fail_json(msg="failed to remove %s: %s" % (package, out)) + + remove_c += 1 + + if remove_c > 0: + + return (True, "removed %s package(s)" % remove_c) + + return (False, "package(s) already absent") + + +def install_packages(module, pkgng_path, packages, cached, pkgsite, dir_arg, state): + + install_c = 0 + + # as of pkg-1.1.4, PACKAGESITE is deprecated in favor of repository definitions + # in /usr/local/etc/pkg/repos + old_pkgng = pkgng_older_than(module, pkgng_path, [1, 1, 4]) + if pkgsite != "": + if old_pkgng: + pkgsite = "PACKAGESITE=%s" % (pkgsite) + else: + pkgsite = "-r %s" % (pkgsite) + + # This environment variable skips mid-install prompts, + # setting them to their default values. + batch_var = 'env BATCH=yes' + + if not module.check_mode and not cached: + if old_pkgng: + rc, out, err = module.run_command("%s %s update" % (pkgsite, pkgng_path)) + else: + rc, out, err = module.run_command("%s %s update" % (pkgng_path, dir_arg)) + if rc != 0: + module.fail_json(msg="Could not update catalogue [%d]: %s %s" % (rc, out, err)) + + for package in packages: + already_installed = query_package(module, pkgng_path, package, dir_arg) + if already_installed and state == "present": + continue + + update_available = query_update(module, pkgng_path, package, dir_arg, old_pkgng, pkgsite) + if not update_available and already_installed and state == "latest": + continue + + if not module.check_mode: + if already_installed: + action = "upgrade" + else: + action = "install" + if old_pkgng: + rc, out, err = module.run_command("%s %s %s %s -g -U -y %s" % (batch_var, pkgsite, pkgng_path, action, package)) + else: + rc, out, err = module.run_command("%s %s %s %s %s -g -U -y %s" % (batch_var, pkgng_path, dir_arg, action, pkgsite, package)) + + if not module.check_mode and not query_package(module, pkgng_path, package, dir_arg): + module.fail_json(msg="failed to %s %s: %s" % (action, package, out), stderr=err) + + install_c += 1 + + if install_c > 0: + return (True, "added %s package(s)" % (install_c)) + + return (False, "package(s) already %s" % (state)) + + +def annotation_query(module, pkgng_path, package, tag, dir_arg): + rc, out, err = module.run_command("%s %s info -g -A %s" % (pkgng_path, dir_arg, package)) + match = re.search(r'^\s*(?P<tag>%s)\s*:\s*(?P<value>\w+)' % tag, out, flags=re.MULTILINE) + if match: + return match.group('value') + return False + + +def annotation_add(module, pkgng_path, package, tag, value, dir_arg): + _value = annotation_query(module, pkgng_path, package, tag, dir_arg) + if not _value: + # Annotation does not exist, add it. + rc, out, err = module.run_command('%s %s annotate -y -A %s %s "%s"' + % (pkgng_path, dir_arg, package, tag, value)) + if rc != 0: + module.fail_json(msg="could not annotate %s: %s" + % (package, out), stderr=err) + return True + elif _value != value: + # Annotation exists, but value differs + module.fail_json( + mgs="failed to annotate %s, because %s is already set to %s, but should be set to %s" + % (package, tag, _value, value)) + return False + else: + # Annotation exists, nothing to do + return False + + +def annotation_delete(module, pkgng_path, package, tag, value, dir_arg): + _value = annotation_query(module, pkgng_path, package, tag, dir_arg) + if _value: + rc, out, err = module.run_command('%s %s annotate -y -D %s %s' + % (pkgng_path, dir_arg, package, tag)) + if rc != 0: + module.fail_json(msg="could not delete annotation to %s: %s" + % (package, out), stderr=err) + return True + return False + + +def annotation_modify(module, pkgng_path, package, tag, value, dir_arg): + _value = annotation_query(module, pkgng_path, package, tag, dir_arg) + if not value: + # No such tag + module.fail_json(msg="could not change annotation to %s: tag %s does not exist" + % (package, tag)) + elif _value == value: + # No change in value + return False + else: + rc, out, err = module.run_command('%s %s annotate -y -M %s %s "%s"' + % (pkgng_path, dir_arg, package, tag, value)) + if rc != 0: + module.fail_json(msg="could not change annotation annotation to %s: %s" + % (package, out), stderr=err) + return True + + +def annotate_packages(module, pkgng_path, packages, annotation, dir_arg): + annotate_c = 0 + annotations = map(lambda _annotation: + re.match(r'(?P<operation>[\+-:])(?P<tag>\w+)(=(?P<value>\w+))?', + _annotation).groupdict(), + re.split(r',', annotation)) + + operation = { + '+': annotation_add, + '-': annotation_delete, + ':': annotation_modify + } + + for package in packages: + for _annotation in annotations: + if operation[_annotation['operation']](module, pkgng_path, package, _annotation['tag'], _annotation['value']): + annotate_c += 1 + + if annotate_c > 0: + return (True, "added %s annotations." % annotate_c) + return (False, "changed no annotations") + + +def autoremove_packages(module, pkgng_path, dir_arg): + rc, out, err = module.run_command("%s %s autoremove -n" % (pkgng_path, dir_arg)) + + autoremove_c = 0 + + match = re.search('^Deinstallation has been requested for the following ([0-9]+) packages', out, re.MULTILINE) + if match: + autoremove_c = int(match.group(1)) + + if autoremove_c == 0: + return False, "no package(s) to autoremove" + + if not module.check_mode: + rc, out, err = module.run_command("%s %s autoremove -y" % (pkgng_path, dir_arg)) + + return True, "autoremoved %d package(s)" % (autoremove_c) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + state=dict(default="present", choices=["present", "latest", "absent"], required=False), + name=dict(aliases=["pkg"], required=True, type='list'), + cached=dict(default=False, type='bool'), + annotation=dict(default="", required=False), + pkgsite=dict(default="", required=False), + rootdir=dict(default="", required=False, type='path'), + chroot=dict(default="", required=False, type='path'), + jail=dict(default="", required=False, type='str'), + autoremove=dict(default=False, type='bool')), + supports_check_mode=True, + mutually_exclusive=[["rootdir", "chroot", "jail"]]) + + pkgng_path = module.get_bin_path('pkg', True) + + p = module.params + + pkgs = p["name"] + + changed = False + msgs = [] + dir_arg = "" + + if p["rootdir"] != "": + old_pkgng = pkgng_older_than(module, pkgng_path, [1, 5, 0]) + if old_pkgng: + module.fail_json(msg="To use option 'rootdir' pkg version must be 1.5 or greater") + else: + dir_arg = "--rootdir %s" % (p["rootdir"]) + + if p["chroot"] != "": + dir_arg = '--chroot %s' % (p["chroot"]) + + if p["jail"] != "": + dir_arg = '--jail %s' % (p["jail"]) + + if p["state"] in ("present", "latest"): + _changed, _msg = install_packages(module, pkgng_path, pkgs, p["cached"], p["pkgsite"], dir_arg, p["state"]) + changed = changed or _changed + msgs.append(_msg) + + elif p["state"] == "absent": + _changed, _msg = remove_packages(module, pkgng_path, pkgs, dir_arg) + changed = changed or _changed + msgs.append(_msg) + + if p["autoremove"]: + _changed, _msg = autoremove_packages(module, pkgng_path, dir_arg) + changed = changed or _changed + msgs.append(_msg) + + if p["annotation"]: + _changed, _msg = annotate_packages(module, pkgng_path, pkgs, p["annotation"], dir_arg) + changed = changed or _changed + msgs.append(_msg) + + module.exit_json(changed=changed, msg=", ".join(msgs)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_db.py b/test/support/integration/plugins/modules/postgresql_db.py new file mode 100644 index 00000000..40858d99 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_db.py @@ -0,0 +1,657 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: postgresql_db +short_description: Add or remove PostgreSQL databases from a remote host. +description: + - Add or remove PostgreSQL databases from a remote host. +version_added: '0.6' +options: + name: + description: + - Name of the database to add or remove + type: str + required: true + aliases: [ db ] + port: + description: + - Database port to connect (if needed) + type: int + default: 5432 + aliases: + - login_port + owner: + description: + - Name of the role to set as owner of the database + type: str + template: + description: + - Template used to create the database + type: str + encoding: + description: + - Encoding of the database + type: str + lc_collate: + description: + - Collation order (LC_COLLATE) to use in the database. Must match collation order of template database unless C(template0) is used as template. + type: str + lc_ctype: + description: + - Character classification (LC_CTYPE) to use in the database (e.g. lower, upper, ...) Must match LC_CTYPE of template database unless C(template0) + is used as template. + type: str + session_role: + description: + - Switch to session_role after connecting. The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally. + type: str + version_added: '2.8' + state: + description: + - The database state. + - C(present) implies that the database should be created if necessary. + - C(absent) implies that the database should be removed if present. + - C(dump) requires a target definition to which the database will be backed up. (Added in Ansible 2.4) + Note that in some PostgreSQL versions of pg_dump, which is an embedded PostgreSQL utility and is used by the module, + returns rc 0 even when errors occurred (e.g. the connection is forbidden by pg_hba.conf, etc.), + so the module returns changed=True but the dump has not actually been done. Please, be sure that your version of + pg_dump returns rc 1 in this case. + - C(restore) also requires a target definition from which the database will be restored. (Added in Ansible 2.4) + - The format of the backup will be detected based on the target name. + - Supported compression formats for dump and restore include C(.pgc), C(.bz2), C(.gz) and C(.xz) + - Supported formats for dump and restore include C(.sql) and C(.tar) + type: str + choices: [ absent, dump, present, restore ] + default: present + target: + description: + - File to back up or restore from. + - Used when I(state) is C(dump) or C(restore). + type: path + version_added: '2.4' + target_opts: + description: + - Further arguments for pg_dump or pg_restore. + - Used when I(state) is C(dump) or C(restore). + type: str + version_added: '2.4' + maintenance_db: + description: + - The value specifies the initial database (which is also called as maintenance DB) that Ansible connects to. + type: str + default: postgres + version_added: '2.5' + conn_limit: + description: + - Specifies the database connection limit. + type: str + version_added: '2.8' + tablespace: + description: + - The tablespace to set for the database + U(https://www.postgresql.org/docs/current/sql-alterdatabase.html). + - If you want to move the database back to the default tablespace, + explicitly set this to pg_default. + type: path + version_added: '2.9' + dump_extra_args: + description: + - Provides additional arguments when I(state) is C(dump). + - Cannot be used with dump-file-format-related arguments like ``--format=d``. + type: str + version_added: '2.10' +seealso: +- name: CREATE DATABASE reference + description: Complete reference of the CREATE DATABASE command documentation. + link: https://www.postgresql.org/docs/current/sql-createdatabase.html +- name: DROP DATABASE reference + description: Complete reference of the DROP DATABASE command documentation. + link: https://www.postgresql.org/docs/current/sql-dropdatabase.html +- name: pg_dump reference + description: Complete reference of pg_dump documentation. + link: https://www.postgresql.org/docs/current/app-pgdump.html +- name: pg_restore reference + description: Complete reference of pg_restore documentation. + link: https://www.postgresql.org/docs/current/app-pgrestore.html +- module: postgresql_tablespace +- module: postgresql_info +- module: postgresql_ping +notes: +- State C(dump) and C(restore) don't require I(psycopg2) since version 2.8. +author: "Ansible Core Team" +extends_documentation_fragment: +- postgres +''' + +EXAMPLES = r''' +- name: Create a new database with name "acme" + postgresql_db: + name: acme + +# Note: If a template different from "template0" is specified, encoding and locale settings must match those of the template. +- name: Create a new database with name "acme" and specific encoding and locale # settings. + postgresql_db: + name: acme + encoding: UTF-8 + lc_collate: de_DE.UTF-8 + lc_ctype: de_DE.UTF-8 + template: template0 + +# Note: Default limit for the number of concurrent connections to a specific database is "-1", which means "unlimited" +- name: Create a new database with name "acme" which has a limit of 100 concurrent connections + postgresql_db: + name: acme + conn_limit: "100" + +- name: Dump an existing database to a file + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql + +- name: Dump an existing database to a file excluding the test table + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql + dump_extra_args: --exclude-table=test + +- name: Dump an existing database to a file (with compression) + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql.gz + +- name: Dump a single schema for an existing database + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql + target_opts: "-n public" + +# Note: In the example below, if database foo exists and has another tablespace +# the tablespace will be changed to foo. Access to the database will be locked +# until the copying of database files is finished. +- name: Create a new database called foo in tablespace bar + postgresql_db: + name: foo + tablespace: bar +''' + +RETURN = r''' +executed_commands: + description: List of commands which tried to run. + returned: always + type: list + sample: ["CREATE DATABASE acme"] + version_added: '2.10' +''' + + +import os +import subprocess +import traceback + +try: + import psycopg2 + import psycopg2.extras +except ImportError: + HAS_PSYCOPG2 = False +else: + HAS_PSYCOPG2 = True + +import ansible.module_utils.postgres as pgutils +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.database import SQLParseError, pg_quote_identifier +from ansible.module_utils.six import iteritems +from ansible.module_utils.six.moves import shlex_quote +from ansible.module_utils._text import to_native + +executed_commands = [] + + +class NotSupportedError(Exception): + pass + +# =========================================== +# PostgreSQL module specific support methods. +# + + +def set_owner(cursor, db, owner): + query = 'ALTER DATABASE %s OWNER TO "%s"' % ( + pg_quote_identifier(db, 'database'), + owner) + executed_commands.append(query) + cursor.execute(query) + return True + + +def set_conn_limit(cursor, db, conn_limit): + query = "ALTER DATABASE %s CONNECTION LIMIT %s" % ( + pg_quote_identifier(db, 'database'), + conn_limit) + executed_commands.append(query) + cursor.execute(query) + return True + + +def get_encoding_id(cursor, encoding): + query = "SELECT pg_char_to_encoding(%(encoding)s) AS encoding_id;" + cursor.execute(query, {'encoding': encoding}) + return cursor.fetchone()['encoding_id'] + + +def get_db_info(cursor, db): + query = """ + SELECT rolname AS owner, + pg_encoding_to_char(encoding) AS encoding, encoding AS encoding_id, + datcollate AS lc_collate, datctype AS lc_ctype, pg_database.datconnlimit AS conn_limit, + spcname AS tablespace + FROM pg_database + JOIN pg_roles ON pg_roles.oid = pg_database.datdba + JOIN pg_tablespace ON pg_tablespace.oid = pg_database.dattablespace + WHERE datname = %(db)s + """ + cursor.execute(query, {'db': db}) + return cursor.fetchone() + + +def db_exists(cursor, db): + query = "SELECT * FROM pg_database WHERE datname=%(db)s" + cursor.execute(query, {'db': db}) + return cursor.rowcount == 1 + + +def db_delete(cursor, db): + if db_exists(cursor, db): + query = "DROP DATABASE %s" % pg_quote_identifier(db, 'database') + executed_commands.append(query) + cursor.execute(query) + return True + else: + return False + + +def db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace): + params = dict(enc=encoding, collate=lc_collate, ctype=lc_ctype, conn_limit=conn_limit, tablespace=tablespace) + if not db_exists(cursor, db): + query_fragments = ['CREATE DATABASE %s' % pg_quote_identifier(db, 'database')] + if owner: + query_fragments.append('OWNER "%s"' % owner) + if template: + query_fragments.append('TEMPLATE %s' % pg_quote_identifier(template, 'database')) + if encoding: + query_fragments.append('ENCODING %(enc)s') + if lc_collate: + query_fragments.append('LC_COLLATE %(collate)s') + if lc_ctype: + query_fragments.append('LC_CTYPE %(ctype)s') + if tablespace: + query_fragments.append('TABLESPACE %s' % pg_quote_identifier(tablespace, 'tablespace')) + if conn_limit: + query_fragments.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit}) + query = ' '.join(query_fragments) + executed_commands.append(cursor.mogrify(query, params)) + cursor.execute(query, params) + return True + else: + db_info = get_db_info(cursor, db) + if (encoding and get_encoding_id(cursor, encoding) != db_info['encoding_id']): + raise NotSupportedError( + 'Changing database encoding is not supported. ' + 'Current encoding: %s' % db_info['encoding'] + ) + elif lc_collate and lc_collate != db_info['lc_collate']: + raise NotSupportedError( + 'Changing LC_COLLATE is not supported. ' + 'Current LC_COLLATE: %s' % db_info['lc_collate'] + ) + elif lc_ctype and lc_ctype != db_info['lc_ctype']: + raise NotSupportedError( + 'Changing LC_CTYPE is not supported.' + 'Current LC_CTYPE: %s' % db_info['lc_ctype'] + ) + else: + changed = False + + if owner and owner != db_info['owner']: + changed = set_owner(cursor, db, owner) + + if conn_limit and conn_limit != str(db_info['conn_limit']): + changed = set_conn_limit(cursor, db, conn_limit) + + if tablespace and tablespace != db_info['tablespace']: + changed = set_tablespace(cursor, db, tablespace) + + return changed + + +def db_matches(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace): + if not db_exists(cursor, db): + return False + else: + db_info = get_db_info(cursor, db) + if (encoding and get_encoding_id(cursor, encoding) != db_info['encoding_id']): + return False + elif lc_collate and lc_collate != db_info['lc_collate']: + return False + elif lc_ctype and lc_ctype != db_info['lc_ctype']: + return False + elif owner and owner != db_info['owner']: + return False + elif conn_limit and conn_limit != str(db_info['conn_limit']): + return False + elif tablespace and tablespace != db_info['tablespace']: + return False + else: + return True + + +def db_dump(module, target, target_opts="", + db=None, + dump_extra_args=None, + user=None, + password=None, + host=None, + port=None, + **kw): + + flags = login_flags(db, host, port, user, db_prefix=False) + cmd = module.get_bin_path('pg_dump', True) + comp_prog_path = None + + if os.path.splitext(target)[-1] == '.tar': + flags.append(' --format=t') + elif os.path.splitext(target)[-1] == '.pgc': + flags.append(' --format=c') + if os.path.splitext(target)[-1] == '.gz': + if module.get_bin_path('pigz'): + comp_prog_path = module.get_bin_path('pigz', True) + else: + comp_prog_path = module.get_bin_path('gzip', True) + elif os.path.splitext(target)[-1] == '.bz2': + comp_prog_path = module.get_bin_path('bzip2', True) + elif os.path.splitext(target)[-1] == '.xz': + comp_prog_path = module.get_bin_path('xz', True) + + cmd += "".join(flags) + + if dump_extra_args: + cmd += " {0} ".format(dump_extra_args) + + if target_opts: + cmd += " {0} ".format(target_opts) + + if comp_prog_path: + # Use a fifo to be notified of an error in pg_dump + # Using shell pipe has no way to return the code of the first command + # in a portable way. + fifo = os.path.join(module.tmpdir, 'pg_fifo') + os.mkfifo(fifo) + cmd = '{1} <{3} > {2} & {0} >{3}'.format(cmd, comp_prog_path, shlex_quote(target), fifo) + else: + cmd = '{0} > {1}'.format(cmd, shlex_quote(target)) + + return do_with_password(module, cmd, password) + + +def db_restore(module, target, target_opts="", + db=None, + user=None, + password=None, + host=None, + port=None, + **kw): + + flags = login_flags(db, host, port, user) + comp_prog_path = None + cmd = module.get_bin_path('psql', True) + + if os.path.splitext(target)[-1] == '.sql': + flags.append(' --file={0}'.format(target)) + + elif os.path.splitext(target)[-1] == '.tar': + flags.append(' --format=Tar') + cmd = module.get_bin_path('pg_restore', True) + + elif os.path.splitext(target)[-1] == '.pgc': + flags.append(' --format=Custom') + cmd = module.get_bin_path('pg_restore', True) + + elif os.path.splitext(target)[-1] == '.gz': + comp_prog_path = module.get_bin_path('zcat', True) + + elif os.path.splitext(target)[-1] == '.bz2': + comp_prog_path = module.get_bin_path('bzcat', True) + + elif os.path.splitext(target)[-1] == '.xz': + comp_prog_path = module.get_bin_path('xzcat', True) + + cmd += "".join(flags) + if target_opts: + cmd += " {0} ".format(target_opts) + + if comp_prog_path: + env = os.environ.copy() + if password: + env = {"PGPASSWORD": password} + p1 = subprocess.Popen([comp_prog_path, target], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p2 = subprocess.Popen(cmd, stdin=p1.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=env) + (stdout2, stderr2) = p2.communicate() + p1.stdout.close() + p1.wait() + if p1.returncode != 0: + stderr1 = p1.stderr.read() + return p1.returncode, '', stderr1, 'cmd: ****' + else: + return p2.returncode, '', stderr2, 'cmd: ****' + else: + cmd = '{0} < {1}'.format(cmd, shlex_quote(target)) + + return do_with_password(module, cmd, password) + + +def login_flags(db, host, port, user, db_prefix=True): + """ + returns a list of connection argument strings each prefixed + with a space and quoted where necessary to later be combined + in a single shell string with `"".join(rv)` + + db_prefix determines if "--dbname" is prefixed to the db argument, + since the argument was introduced in 9.3. + """ + flags = [] + if db: + if db_prefix: + flags.append(' --dbname={0}'.format(shlex_quote(db))) + else: + flags.append(' {0}'.format(shlex_quote(db))) + if host: + flags.append(' --host={0}'.format(host)) + if port: + flags.append(' --port={0}'.format(port)) + if user: + flags.append(' --username={0}'.format(user)) + return flags + + +def do_with_password(module, cmd, password): + env = {} + if password: + env = {"PGPASSWORD": password} + executed_commands.append(cmd) + rc, stderr, stdout = module.run_command(cmd, use_unsafe_shell=True, environ_update=env) + return rc, stderr, stdout, cmd + + +def set_tablespace(cursor, db, tablespace): + query = "ALTER DATABASE %s SET TABLESPACE %s" % ( + pg_quote_identifier(db, 'database'), + pg_quote_identifier(tablespace, 'tablespace')) + executed_commands.append(query) + cursor.execute(query) + return True + +# =========================================== +# Module execution. +# + + +def main(): + argument_spec = pgutils.postgres_common_argument_spec() + argument_spec.update( + db=dict(type='str', required=True, aliases=['name']), + owner=dict(type='str', default=''), + template=dict(type='str', default=''), + encoding=dict(type='str', default=''), + lc_collate=dict(type='str', default=''), + lc_ctype=dict(type='str', default=''), + state=dict(type='str', default='present', choices=['absent', 'dump', 'present', 'restore']), + target=dict(type='path', default=''), + target_opts=dict(type='str', default=''), + maintenance_db=dict(type='str', default="postgres"), + session_role=dict(type='str'), + conn_limit=dict(type='str', default=''), + tablespace=dict(type='path', default=''), + dump_extra_args=dict(type='str', default=None), + ) + + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True + ) + + db = module.params["db"] + owner = module.params["owner"] + template = module.params["template"] + encoding = module.params["encoding"] + lc_collate = module.params["lc_collate"] + lc_ctype = module.params["lc_ctype"] + target = module.params["target"] + target_opts = module.params["target_opts"] + state = module.params["state"] + changed = False + maintenance_db = module.params['maintenance_db'] + session_role = module.params["session_role"] + conn_limit = module.params['conn_limit'] + tablespace = module.params['tablespace'] + dump_extra_args = module.params['dump_extra_args'] + + raw_connection = state in ("dump", "restore") + + if not raw_connection: + pgutils.ensure_required_libs(module) + + # To use defaults values, keyword arguments must be absent, so + # check which values are empty and don't include in the **kw + # dictionary + params_map = { + "login_host": "host", + "login_user": "user", + "login_password": "password", + "port": "port", + "ssl_mode": "sslmode", + "ca_cert": "sslrootcert" + } + kw = dict((params_map[k], v) for (k, v) in iteritems(module.params) + if k in params_map and v != '' and v is not None) + + # If a login_unix_socket is specified, incorporate it here. + is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost" + + if is_localhost and module.params["login_unix_socket"] != "": + kw["host"] = module.params["login_unix_socket"] + + if target == "": + target = "{0}/{1}.sql".format(os.getcwd(), db) + target = os.path.expanduser(target) + + if not raw_connection: + try: + db_connection = psycopg2.connect(database=maintenance_db, **kw) + + # Enable autocommit so we can create databases + if psycopg2.__version__ >= '2.4.2': + db_connection.autocommit = True + else: + db_connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + cursor = db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) + + except TypeError as e: + if 'sslrootcert' in e.args[0]: + module.fail_json(msg='Postgresql server must be at least version 8.4 to support sslrootcert. Exception: {0}'.format(to_native(e)), + exception=traceback.format_exc()) + module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc()) + + except Exception as e: + module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc()) + + if session_role: + try: + cursor.execute('SET ROLE "%s"' % session_role) + except Exception as e: + module.fail_json(msg="Could not switch role: %s" % to_native(e), exception=traceback.format_exc()) + + try: + if module.check_mode: + if state == "absent": + changed = db_exists(cursor, db) + elif state == "present": + changed = not db_matches(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace) + module.exit_json(changed=changed, db=db, executed_commands=executed_commands) + + if state == "absent": + try: + changed = db_delete(cursor, db) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + elif state == "present": + try: + changed = db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + elif state in ("dump", "restore"): + method = state == "dump" and db_dump or db_restore + try: + if state == 'dump': + rc, stdout, stderr, cmd = method(module, target, target_opts, db, dump_extra_args, **kw) + else: + rc, stdout, stderr, cmd = method(module, target, target_opts, db, **kw) + + if rc != 0: + module.fail_json(msg=stderr, stdout=stdout, rc=rc, cmd=cmd) + else: + module.exit_json(changed=True, msg=stdout, stderr=stderr, rc=rc, cmd=cmd, + executed_commands=executed_commands) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + except NotSupportedError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + except SystemExit: + # Avoid catching this on Python 2.4 + raise + except Exception as e: + module.fail_json(msg="Database query failed: %s" % to_native(e), exception=traceback.format_exc()) + + module.exit_json(changed=changed, db=db, executed_commands=executed_commands) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_privs.py b/test/support/integration/plugins/modules/postgresql_privs.py new file mode 100644 index 00000000..ba8324dd --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_privs.py @@ -0,0 +1,1097 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# Copyright: (c) 2019, Tobias Birkefeld (@tcraxs) <t@craxs.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: postgresql_privs +version_added: '1.2' +short_description: Grant or revoke privileges on PostgreSQL database objects +description: +- Grant or revoke privileges on PostgreSQL database objects. +- This module is basically a wrapper around most of the functionality of + PostgreSQL's GRANT and REVOKE statements with detection of changes + (GRANT/REVOKE I(privs) ON I(type) I(objs) TO/FROM I(roles)). +options: + database: + description: + - Name of database to connect to. + required: yes + type: str + aliases: + - db + - login_db + state: + description: + - If C(present), the specified privileges are granted, if C(absent) they are revoked. + type: str + default: present + choices: [ absent, present ] + privs: + description: + - Comma separated list of privileges to grant/revoke. + type: str + aliases: + - priv + type: + description: + - Type of database object to set privileges on. + - The C(default_privs) choice is available starting at version 2.7. + - The C(foreign_data_wrapper) and C(foreign_server) object types are available from Ansible version '2.8'. + - The C(type) choice is available from Ansible version '2.10'. + type: str + default: table + choices: [ database, default_privs, foreign_data_wrapper, foreign_server, function, + group, language, table, tablespace, schema, sequence, type ] + objs: + description: + - Comma separated list of database objects to set privileges on. + - If I(type) is C(table), C(partition table), C(sequence) or C(function), + the special valueC(ALL_IN_SCHEMA) can be provided instead to specify all + database objects of type I(type) in the schema specified via I(schema). + (This also works with PostgreSQL < 9.0.) (C(ALL_IN_SCHEMA) is available + for C(function) and C(partition table) from version 2.8) + - If I(type) is C(database), this parameter can be omitted, in which case + privileges are set for the database specified via I(database). + - 'If I(type) is I(function), colons (":") in object names will be + replaced with commas (needed to specify function signatures, see examples)' + type: str + aliases: + - obj + schema: + description: + - Schema that contains the database objects specified via I(objs). + - May only be provided if I(type) is C(table), C(sequence), C(function), C(type), + or C(default_privs). Defaults to C(public) in these cases. + - Pay attention, for embedded types when I(type=type) + I(schema) can be C(pg_catalog) or C(information_schema) respectively. + type: str + roles: + description: + - Comma separated list of role (user/group) names to set permissions for. + - The special value C(PUBLIC) can be provided instead to set permissions + for the implicitly defined PUBLIC group. + type: str + required: yes + aliases: + - role + fail_on_role: + version_added: '2.8' + description: + - If C(yes), fail when target role (for whom privs need to be granted) does not exist. + Otherwise just warn and continue. + default: yes + type: bool + session_role: + version_added: '2.8' + description: + - Switch to session_role after connecting. + - The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally. + type: str + target_roles: + description: + - A list of existing role (user/group) names to set as the + default permissions for database objects subsequently created by them. + - Parameter I(target_roles) is only available with C(type=default_privs). + type: str + version_added: '2.8' + grant_option: + description: + - Whether C(role) may grant/revoke the specified privileges/group memberships to others. + - Set to C(no) to revoke GRANT OPTION, leave unspecified to make no changes. + - I(grant_option) only has an effect if I(state) is C(present). + type: bool + aliases: + - admin_option + host: + description: + - Database host address. If unspecified, connect via Unix socket. + type: str + aliases: + - login_host + port: + description: + - Database port to connect to. + type: int + default: 5432 + aliases: + - login_port + unix_socket: + description: + - Path to a Unix domain socket for local connections. + type: str + aliases: + - login_unix_socket + login: + description: + - The username to authenticate with. + type: str + default: postgres + aliases: + - login_user + password: + description: + - The password to authenticate with. + type: str + aliases: + - login_password + ssl_mode: + description: + - Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated with the server. + - See https://www.postgresql.org/docs/current/static/libpq-ssl.html for more information on the modes. + - Default of C(prefer) matches libpq default. + type: str + default: prefer + choices: [ allow, disable, prefer, require, verify-ca, verify-full ] + version_added: '2.3' + ca_cert: + description: + - Specifies the name of a file containing SSL certificate authority (CA) certificate(s). + - If the file exists, the server's certificate will be verified to be signed by one of these authorities. + version_added: '2.3' + type: str + aliases: + - ssl_rootcert + +notes: +- Parameters that accept comma separated lists (I(privs), I(objs), I(roles)) + have singular alias names (I(priv), I(obj), I(role)). +- To revoke only C(GRANT OPTION) for a specific object, set I(state) to + C(present) and I(grant_option) to C(no) (see examples). +- Note that when revoking privileges from a role R, this role may still have + access via privileges granted to any role R is a member of including C(PUBLIC). +- Note that when revoking privileges from a role R, you do so as the user + specified via I(login). If R has been granted the same privileges by + another user also, R can still access database objects via these privileges. +- When revoking privileges, C(RESTRICT) is assumed (see PostgreSQL docs). + +seealso: +- module: postgresql_user +- module: postgresql_owner +- module: postgresql_membership +- name: PostgreSQL privileges + description: General information about PostgreSQL privileges. + link: https://www.postgresql.org/docs/current/ddl-priv.html +- name: PostgreSQL GRANT command reference + description: Complete reference of the PostgreSQL GRANT command documentation. + link: https://www.postgresql.org/docs/current/sql-grant.html +- name: PostgreSQL REVOKE command reference + description: Complete reference of the PostgreSQL REVOKE command documentation. + link: https://www.postgresql.org/docs/current/sql-revoke.html + +extends_documentation_fragment: +- postgres + +author: +- Bernhard Weitzhofer (@b6d) +- Tobias Birkefeld (@tcraxs) +''' + +EXAMPLES = r''' +# On database "library": +# GRANT SELECT, INSERT, UPDATE ON TABLE public.books, public.authors +# TO librarian, reader WITH GRANT OPTION +- name: Grant privs to librarian and reader on database library + postgresql_privs: + database: library + state: present + privs: SELECT,INSERT,UPDATE + type: table + objs: books,authors + schema: public + roles: librarian,reader + grant_option: yes + +- name: Same as above leveraging default values + postgresql_privs: + db: library + privs: SELECT,INSERT,UPDATE + objs: books,authors + roles: librarian,reader + grant_option: yes + +# REVOKE GRANT OPTION FOR INSERT ON TABLE books FROM reader +# Note that role "reader" will be *granted* INSERT privilege itself if this +# isn't already the case (since state: present). +- name: Revoke privs from reader + postgresql_privs: + db: library + state: present + priv: INSERT + obj: books + role: reader + grant_option: no + +# "public" is the default schema. This also works for PostgreSQL 8.x. +- name: REVOKE INSERT, UPDATE ON ALL TABLES IN SCHEMA public FROM reader + postgresql_privs: + db: library + state: absent + privs: INSERT,UPDATE + objs: ALL_IN_SCHEMA + role: reader + +- name: GRANT ALL PRIVILEGES ON SCHEMA public, math TO librarian + postgresql_privs: + db: library + privs: ALL + type: schema + objs: public,math + role: librarian + +# Note the separation of arguments with colons. +- name: GRANT ALL PRIVILEGES ON FUNCTION math.add(int, int) TO librarian, reader + postgresql_privs: + db: library + privs: ALL + type: function + obj: add(int:int) + schema: math + roles: librarian,reader + +# Note that group role memberships apply cluster-wide and therefore are not +# restricted to database "library" here. +- name: GRANT librarian, reader TO alice, bob WITH ADMIN OPTION + postgresql_privs: + db: library + type: group + objs: librarian,reader + roles: alice,bob + admin_option: yes + +# Note that here "db: postgres" specifies the database to connect to, not the +# database to grant privileges on (which is specified via the "objs" param) +- name: GRANT ALL PRIVILEGES ON DATABASE library TO librarian + postgresql_privs: + db: postgres + privs: ALL + type: database + obj: library + role: librarian + +# If objs is omitted for type "database", it defaults to the database +# to which the connection is established +- name: GRANT ALL PRIVILEGES ON DATABASE library TO librarian + postgresql_privs: + db: library + privs: ALL + type: database + role: librarian + +# Available since version 2.7 +# Objs must be set, ALL_DEFAULT to TABLES/SEQUENCES/TYPES/FUNCTIONS +# ALL_DEFAULT works only with privs=ALL +# For specific +- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO librarian + postgresql_privs: + db: library + objs: ALL_DEFAULT + privs: ALL + type: default_privs + role: librarian + grant_option: yes + +# Available since version 2.7 +# Objs must be set, ALL_DEFAULT to TABLES/SEQUENCES/TYPES/FUNCTIONS +# ALL_DEFAULT works only with privs=ALL +# For specific +- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO reader, step 1 + postgresql_privs: + db: library + objs: TABLES,SEQUENCES + privs: SELECT + type: default_privs + role: reader + +- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO reader, step 2 + postgresql_privs: + db: library + objs: TYPES + privs: USAGE + type: default_privs + role: reader + +# Available since version 2.8 +- name: GRANT ALL PRIVILEGES ON FOREIGN DATA WRAPPER fdw TO reader + postgresql_privs: + db: test + objs: fdw + privs: ALL + type: foreign_data_wrapper + role: reader + +# Available since version 2.10 +- name: GRANT ALL PRIVILEGES ON TYPE customtype TO reader + postgresql_privs: + db: test + objs: customtype + privs: ALL + type: type + role: reader + +# Available since version 2.8 +- name: GRANT ALL PRIVILEGES ON FOREIGN SERVER fdw_server TO reader + postgresql_privs: + db: test + objs: fdw_server + privs: ALL + type: foreign_server + role: reader + +# Available since version 2.8 +# Grant 'execute' permissions on all functions in schema 'common' to role 'caller' +- name: GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA common TO caller + postgresql_privs: + type: function + state: present + privs: EXECUTE + roles: caller + objs: ALL_IN_SCHEMA + schema: common + +# Available since version 2.8 +# ALTER DEFAULT PRIVILEGES FOR ROLE librarian IN SCHEMA library GRANT SELECT ON TABLES TO reader +# GRANT SELECT privileges for new TABLES objects created by librarian as +# default to the role reader. +# For specific +- name: ALTER privs + postgresql_privs: + db: library + schema: library + objs: TABLES + privs: SELECT + type: default_privs + role: reader + target_roles: librarian + +# Available since version 2.8 +# ALTER DEFAULT PRIVILEGES FOR ROLE librarian IN SCHEMA library REVOKE SELECT ON TABLES FROM reader +# REVOKE SELECT privileges for new TABLES objects created by librarian as +# default from the role reader. +# For specific +- name: ALTER privs + postgresql_privs: + db: library + state: absent + schema: library + objs: TABLES + privs: SELECT + type: default_privs + role: reader + target_roles: librarian + +# Available since version 2.10 +- name: Grant type privileges for pg_catalog.numeric type to alice + postgresql_privs: + type: type + roles: alice + privs: ALL + objs: numeric + schema: pg_catalog + db: acme +''' + +RETURN = r''' +queries: + description: List of executed queries. + returned: always + type: list + sample: ['REVOKE GRANT OPTION FOR INSERT ON TABLE "books" FROM "reader";'] + version_added: '2.8' +''' + +import traceback + +PSYCOPG2_IMP_ERR = None +try: + import psycopg2 + import psycopg2.extensions +except ImportError: + PSYCOPG2_IMP_ERR = traceback.format_exc() + psycopg2 = None + +# import module snippets +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.database import pg_quote_identifier +from ansible.module_utils.postgres import postgres_common_argument_spec +from ansible.module_utils._text import to_native + +VALID_PRIVS = frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', + 'REFERENCES', 'TRIGGER', 'CREATE', 'CONNECT', + 'TEMPORARY', 'TEMP', 'EXECUTE', 'USAGE', 'ALL', 'USAGE')) +VALID_DEFAULT_OBJS = {'TABLES': ('ALL', 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER'), + 'SEQUENCES': ('ALL', 'SELECT', 'UPDATE', 'USAGE'), + 'FUNCTIONS': ('ALL', 'EXECUTE'), + 'TYPES': ('ALL', 'USAGE')} + +executed_queries = [] + + +class Error(Exception): + pass + + +def role_exists(module, cursor, rolname): + """Check user exists or not""" + query = "SELECT 1 FROM pg_roles WHERE rolname = '%s'" % rolname + try: + cursor.execute(query) + return cursor.rowcount > 0 + + except Exception as e: + module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e))) + + return False + + +# We don't have functools.partial in Python < 2.5 +def partial(f, *args, **kwargs): + """Partial function application""" + + def g(*g_args, **g_kwargs): + new_kwargs = kwargs.copy() + new_kwargs.update(g_kwargs) + return f(*(args + g_args), **g_kwargs) + + g.f = f + g.args = args + g.kwargs = kwargs + return g + + +class Connection(object): + """Wrapper around a psycopg2 connection with some convenience methods""" + + def __init__(self, params, module): + self.database = params.database + self.module = module + # To use defaults values, keyword arguments must be absent, so + # check which values are empty and don't include in the **kw + # dictionary + params_map = { + "host": "host", + "login": "user", + "password": "password", + "port": "port", + "database": "database", + "ssl_mode": "sslmode", + "ca_cert": "sslrootcert" + } + + kw = dict((params_map[k], getattr(params, k)) for k in params_map + if getattr(params, k) != '' and getattr(params, k) is not None) + + # If a unix_socket is specified, incorporate it here. + is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost" + if is_localhost and params.unix_socket != "": + kw["host"] = params.unix_socket + + sslrootcert = params.ca_cert + if psycopg2.__version__ < '2.4.3' and sslrootcert is not None: + raise ValueError('psycopg2 must be at least 2.4.3 in order to user the ca_cert parameter') + + self.connection = psycopg2.connect(**kw) + self.cursor = self.connection.cursor() + + def commit(self): + self.connection.commit() + + def rollback(self): + self.connection.rollback() + + @property + def encoding(self): + """Connection encoding in Python-compatible form""" + return psycopg2.extensions.encodings[self.connection.encoding] + + # Methods for querying database objects + + # PostgreSQL < 9.0 doesn't support "ALL TABLES IN SCHEMA schema"-like + # phrases in GRANT or REVOKE statements, therefore alternative methods are + # provided here. + + def schema_exists(self, schema): + query = """SELECT count(*) + FROM pg_catalog.pg_namespace WHERE nspname = %s""" + self.cursor.execute(query, (schema,)) + return self.cursor.fetchone()[0] > 0 + + def get_all_tables_in_schema(self, schema): + if not self.schema_exists(schema): + raise Error('Schema "%s" does not exist.' % schema) + query = """SELECT relname + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind in ('r', 'v', 'm', 'p')""" + self.cursor.execute(query, (schema,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_all_sequences_in_schema(self, schema): + if not self.schema_exists(schema): + raise Error('Schema "%s" does not exist.' % schema) + query = """SELECT relname + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind = 'S'""" + self.cursor.execute(query, (schema,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_all_functions_in_schema(self, schema): + if not self.schema_exists(schema): + raise Error('Schema "%s" does not exist.' % schema) + query = """SELECT p.proname, oidvectortypes(p.proargtypes) + FROM pg_catalog.pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE nspname = %s""" + self.cursor.execute(query, (schema,)) + return ["%s(%s)" % (t[0], t[1]) for t in self.cursor.fetchall()] + + # Methods for getting access control lists and group membership info + + # To determine whether anything has changed after granting/revoking + # privileges, we compare the access control lists of the specified database + # objects before and afterwards. Python's list/string comparison should + # suffice for change detection, we should not actually have to parse ACLs. + # The same should apply to group membership information. + + def get_table_acls(self, schema, tables): + query = """SELECT relacl + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind in ('r','p','v','m') AND relname = ANY (%s) + ORDER BY relname""" + self.cursor.execute(query, (schema, tables)) + return [t[0] for t in self.cursor.fetchall()] + + def get_sequence_acls(self, schema, sequences): + query = """SELECT relacl + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind = 'S' AND relname = ANY (%s) + ORDER BY relname""" + self.cursor.execute(query, (schema, sequences)) + return [t[0] for t in self.cursor.fetchall()] + + def get_function_acls(self, schema, function_signatures): + funcnames = [f.split('(', 1)[0] for f in function_signatures] + query = """SELECT proacl + FROM pg_catalog.pg_proc p + JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + WHERE nspname = %s AND proname = ANY (%s) + ORDER BY proname, proargtypes""" + self.cursor.execute(query, (schema, funcnames)) + return [t[0] for t in self.cursor.fetchall()] + + def get_schema_acls(self, schemas): + query = """SELECT nspacl FROM pg_catalog.pg_namespace + WHERE nspname = ANY (%s) ORDER BY nspname""" + self.cursor.execute(query, (schemas,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_language_acls(self, languages): + query = """SELECT lanacl FROM pg_catalog.pg_language + WHERE lanname = ANY (%s) ORDER BY lanname""" + self.cursor.execute(query, (languages,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_tablespace_acls(self, tablespaces): + query = """SELECT spcacl FROM pg_catalog.pg_tablespace + WHERE spcname = ANY (%s) ORDER BY spcname""" + self.cursor.execute(query, (tablespaces,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_database_acls(self, databases): + query = """SELECT datacl FROM pg_catalog.pg_database + WHERE datname = ANY (%s) ORDER BY datname""" + self.cursor.execute(query, (databases,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_group_memberships(self, groups): + query = """SELECT roleid, grantor, member, admin_option + FROM pg_catalog.pg_auth_members am + JOIN pg_catalog.pg_roles r ON r.oid = am.roleid + WHERE r.rolname = ANY(%s) + ORDER BY roleid, grantor, member""" + self.cursor.execute(query, (groups,)) + return self.cursor.fetchall() + + def get_default_privs(self, schema, *args): + query = """SELECT defaclacl + FROM pg_default_acl a + JOIN pg_namespace b ON a.defaclnamespace=b.oid + WHERE b.nspname = %s;""" + self.cursor.execute(query, (schema,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_foreign_data_wrapper_acls(self, fdws): + query = """SELECT fdwacl FROM pg_catalog.pg_foreign_data_wrapper + WHERE fdwname = ANY (%s) ORDER BY fdwname""" + self.cursor.execute(query, (fdws,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_foreign_server_acls(self, fs): + query = """SELECT srvacl FROM pg_catalog.pg_foreign_server + WHERE srvname = ANY (%s) ORDER BY srvname""" + self.cursor.execute(query, (fs,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_type_acls(self, schema, types): + query = """SELECT t.typacl FROM pg_catalog.pg_type t + JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE n.nspname = %s AND t.typname = ANY (%s) ORDER BY typname""" + self.cursor.execute(query, (schema, types)) + return [t[0] for t in self.cursor.fetchall()] + + # Manipulating privileges + + def manipulate_privs(self, obj_type, privs, objs, roles, target_roles, + state, grant_option, schema_qualifier=None, fail_on_role=True): + """Manipulate database object privileges. + + :param obj_type: Type of database object to grant/revoke + privileges for. + :param privs: Either a list of privileges to grant/revoke + or None if type is "group". + :param objs: List of database objects to grant/revoke + privileges for. + :param roles: Either a list of role names or "PUBLIC" + for the implicitly defined "PUBLIC" group + :param target_roles: List of role names to grant/revoke + default privileges as. + :param state: "present" to grant privileges, "absent" to revoke. + :param grant_option: Only for state "present": If True, set + grant/admin option. If False, revoke it. + If None, don't change grant option. + :param schema_qualifier: Some object types ("TABLE", "SEQUENCE", + "FUNCTION") must be qualified by schema. + Ignored for other Types. + """ + # get_status: function to get current status + if obj_type == 'table': + get_status = partial(self.get_table_acls, schema_qualifier) + elif obj_type == 'sequence': + get_status = partial(self.get_sequence_acls, schema_qualifier) + elif obj_type == 'function': + get_status = partial(self.get_function_acls, schema_qualifier) + elif obj_type == 'schema': + get_status = self.get_schema_acls + elif obj_type == 'language': + get_status = self.get_language_acls + elif obj_type == 'tablespace': + get_status = self.get_tablespace_acls + elif obj_type == 'database': + get_status = self.get_database_acls + elif obj_type == 'group': + get_status = self.get_group_memberships + elif obj_type == 'default_privs': + get_status = partial(self.get_default_privs, schema_qualifier) + elif obj_type == 'foreign_data_wrapper': + get_status = self.get_foreign_data_wrapper_acls + elif obj_type == 'foreign_server': + get_status = self.get_foreign_server_acls + elif obj_type == 'type': + get_status = partial(self.get_type_acls, schema_qualifier) + else: + raise Error('Unsupported database object type "%s".' % obj_type) + + # Return False (nothing has changed) if there are no objs to work on. + if not objs: + return False + + # obj_ids: quoted db object identifiers (sometimes schema-qualified) + if obj_type == 'function': + obj_ids = [] + for obj in objs: + try: + f, args = obj.split('(', 1) + except Exception: + raise Error('Illegal function signature: "%s".' % obj) + obj_ids.append('"%s"."%s"(%s' % (schema_qualifier, f, args)) + elif obj_type in ['table', 'sequence', 'type']: + obj_ids = ['"%s"."%s"' % (schema_qualifier, o) for o in objs] + else: + obj_ids = ['"%s"' % o for o in objs] + + # set_what: SQL-fragment specifying what to set for the target roles: + # Either group membership or privileges on objects of a certain type + if obj_type == 'group': + set_what = ','.join('"%s"' % i for i in obj_ids) + elif obj_type == 'default_privs': + # We don't want privs to be quoted here + set_what = ','.join(privs) + else: + # function types are already quoted above + if obj_type != 'function': + obj_ids = [pg_quote_identifier(i, 'table') for i in obj_ids] + # Note: obj_type has been checked against a set of string literals + # and privs was escaped when it was parsed + # Note: Underscores are replaced with spaces to support multi-word obj_type + set_what = '%s ON %s %s' % (','.join(privs), obj_type.replace('_', ' '), + ','.join(obj_ids)) + + # for_whom: SQL-fragment specifying for whom to set the above + if roles == 'PUBLIC': + for_whom = 'PUBLIC' + else: + for_whom = [] + for r in roles: + if not role_exists(self.module, self.cursor, r): + if fail_on_role: + self.module.fail_json(msg="Role '%s' does not exist" % r.strip()) + + else: + self.module.warn("Role '%s' does not exist, pass it" % r.strip()) + else: + for_whom.append('"%s"' % r) + + if not for_whom: + return False + + for_whom = ','.join(for_whom) + + # as_who: + as_who = None + if target_roles: + as_who = ','.join('"%s"' % r for r in target_roles) + + status_before = get_status(objs) + + query = QueryBuilder(state) \ + .for_objtype(obj_type) \ + .with_grant_option(grant_option) \ + .for_whom(for_whom) \ + .as_who(as_who) \ + .for_schema(schema_qualifier) \ + .set_what(set_what) \ + .for_objs(objs) \ + .build() + + executed_queries.append(query) + self.cursor.execute(query) + status_after = get_status(objs) + + def nonesorted(e): + # For python 3+ that can fail trying + # to compare NoneType elements by sort method. + if e is None: + return '' + return e + + status_before.sort(key=nonesorted) + status_after.sort(key=nonesorted) + return status_before != status_after + + +class QueryBuilder(object): + def __init__(self, state): + self._grant_option = None + self._for_whom = None + self._as_who = None + self._set_what = None + self._obj_type = None + self._state = state + self._schema = None + self._objs = None + self.query = [] + + def for_objs(self, objs): + self._objs = objs + return self + + def for_schema(self, schema): + self._schema = schema + return self + + def with_grant_option(self, option): + self._grant_option = option + return self + + def for_whom(self, who): + self._for_whom = who + return self + + def as_who(self, target_roles): + self._as_who = target_roles + return self + + def set_what(self, what): + self._set_what = what + return self + + def for_objtype(self, objtype): + self._obj_type = objtype + return self + + def build(self): + if self._state == 'present': + self.build_present() + elif self._state == 'absent': + self.build_absent() + else: + self.build_absent() + return '\n'.join(self.query) + + def add_default_revoke(self): + for obj in self._objs: + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} REVOKE ALL ON {2} FROM {3};'.format(self._as_who, + self._schema, obj, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} REVOKE ALL ON {1} FROM {2};'.format(self._schema, obj, + self._for_whom)) + + def add_grant_option(self): + if self._grant_option: + if self._obj_type == 'group': + self.query[-1] += ' WITH ADMIN OPTION;' + else: + self.query[-1] += ' WITH GRANT OPTION;' + else: + self.query[-1] += ';' + if self._obj_type == 'group': + self.query.append('REVOKE ADMIN OPTION FOR {0} FROM {1};'.format(self._set_what, self._for_whom)) + elif not self._obj_type == 'default_privs': + self.query.append('REVOKE GRANT OPTION FOR {0} FROM {1};'.format(self._set_what, self._for_whom)) + + def add_default_priv(self): + for obj in self._objs: + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} GRANT {2} ON {3} TO {4}'.format(self._as_who, + self._schema, + self._set_what, + obj, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} GRANT {1} ON {2} TO {3}'.format(self._schema, + self._set_what, + obj, + self._for_whom)) + self.add_grant_option() + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} GRANT USAGE ON TYPES TO {2}'.format(self._as_who, + self._schema, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} GRANT USAGE ON TYPES TO {1}'.format(self._schema, self._for_whom)) + self.add_grant_option() + + def build_present(self): + if self._obj_type == 'default_privs': + self.add_default_revoke() + self.add_default_priv() + else: + self.query.append('GRANT {0} TO {1}'.format(self._set_what, self._for_whom)) + self.add_grant_option() + + def build_absent(self): + if self._obj_type == 'default_privs': + self.query = [] + for obj in ['TABLES', 'SEQUENCES', 'TYPES']: + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} REVOKE ALL ON {2} FROM {3};'.format(self._as_who, + self._schema, obj, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} REVOKE ALL ON {1} FROM {2};'.format(self._schema, obj, + self._for_whom)) + else: + self.query.append('REVOKE {0} FROM {1};'.format(self._set_what, self._for_whom)) + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + database=dict(required=True, aliases=['db', 'login_db']), + state=dict(default='present', choices=['present', 'absent']), + privs=dict(required=False, aliases=['priv']), + type=dict(default='table', + choices=['table', + 'sequence', + 'function', + 'database', + 'schema', + 'language', + 'tablespace', + 'group', + 'default_privs', + 'foreign_data_wrapper', + 'foreign_server', + 'type', ]), + objs=dict(required=False, aliases=['obj']), + schema=dict(required=False), + roles=dict(required=True, aliases=['role']), + session_role=dict(required=False), + target_roles=dict(required=False), + grant_option=dict(required=False, type='bool', + aliases=['admin_option']), + host=dict(default='', aliases=['login_host']), + unix_socket=dict(default='', aliases=['login_unix_socket']), + login=dict(default='postgres', aliases=['login_user']), + password=dict(default='', aliases=['login_password'], no_log=True), + fail_on_role=dict(type='bool', default=True), + ) + + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True, + ) + + fail_on_role = module.params['fail_on_role'] + + # Create type object as namespace for module params + p = type('Params', (), module.params) + # param "schema": default, allowed depends on param "type" + if p.type in ['table', 'sequence', 'function', 'type', 'default_privs']: + p.schema = p.schema or 'public' + elif p.schema: + module.fail_json(msg='Argument "schema" is not allowed ' + 'for type "%s".' % p.type) + + # param "objs": default, required depends on param "type" + if p.type == 'database': + p.objs = p.objs or p.database + elif not p.objs: + module.fail_json(msg='Argument "objs" is required ' + 'for type "%s".' % p.type) + + # param "privs": allowed, required depends on param "type" + if p.type == 'group': + if p.privs: + module.fail_json(msg='Argument "privs" is not allowed ' + 'for type "group".') + elif not p.privs: + module.fail_json(msg='Argument "privs" is required ' + 'for type "%s".' % p.type) + + # Connect to Database + if not psycopg2: + module.fail_json(msg=missing_required_lib('psycopg2'), exception=PSYCOPG2_IMP_ERR) + try: + conn = Connection(p, module) + except psycopg2.Error as e: + module.fail_json(msg='Could not connect to database: %s' % to_native(e), exception=traceback.format_exc()) + except TypeError as e: + if 'sslrootcert' in e.args[0]: + module.fail_json(msg='Postgresql server must be at least version 8.4 to support sslrootcert') + module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc()) + except ValueError as e: + # We raise this when the psycopg library is too old + module.fail_json(msg=to_native(e)) + + if p.session_role: + try: + conn.cursor.execute('SET ROLE "%s"' % p.session_role) + except Exception as e: + module.fail_json(msg="Could not switch to role %s: %s" % (p.session_role, to_native(e)), exception=traceback.format_exc()) + + try: + # privs + if p.privs: + privs = frozenset(pr.upper() for pr in p.privs.split(',')) + if not privs.issubset(VALID_PRIVS): + module.fail_json(msg='Invalid privileges specified: %s' % privs.difference(VALID_PRIVS)) + else: + privs = None + # objs: + if p.type == 'table' and p.objs == 'ALL_IN_SCHEMA': + objs = conn.get_all_tables_in_schema(p.schema) + elif p.type == 'sequence' and p.objs == 'ALL_IN_SCHEMA': + objs = conn.get_all_sequences_in_schema(p.schema) + elif p.type == 'function' and p.objs == 'ALL_IN_SCHEMA': + objs = conn.get_all_functions_in_schema(p.schema) + elif p.type == 'default_privs': + if p.objs == 'ALL_DEFAULT': + objs = frozenset(VALID_DEFAULT_OBJS.keys()) + else: + objs = frozenset(obj.upper() for obj in p.objs.split(',')) + if not objs.issubset(VALID_DEFAULT_OBJS): + module.fail_json( + msg='Invalid Object set specified: %s' % objs.difference(VALID_DEFAULT_OBJS.keys())) + # Again, do we have valid privs specified for object type: + valid_objects_for_priv = frozenset(obj for obj in objs if privs.issubset(VALID_DEFAULT_OBJS[obj])) + if not valid_objects_for_priv == objs: + module.fail_json( + msg='Invalid priv specified. Valid object for priv: {0}. Objects: {1}'.format( + valid_objects_for_priv, objs)) + else: + objs = p.objs.split(',') + + # function signatures are encoded using ':' to separate args + if p.type == 'function': + objs = [obj.replace(':', ',') for obj in objs] + + # roles + if p.roles == 'PUBLIC': + roles = 'PUBLIC' + else: + roles = p.roles.split(',') + + if len(roles) == 1 and not role_exists(module, conn.cursor, roles[0]): + module.exit_json(changed=False) + + if fail_on_role: + module.fail_json(msg="Role '%s' does not exist" % roles[0].strip()) + + else: + module.warn("Role '%s' does not exist, nothing to do" % roles[0].strip()) + + # check if target_roles is set with type: default_privs + if p.target_roles and not p.type == 'default_privs': + module.warn('"target_roles" will be ignored ' + 'Argument "type: default_privs" is required for usage of "target_roles".') + + # target roles + if p.target_roles: + target_roles = p.target_roles.split(',') + else: + target_roles = None + + changed = conn.manipulate_privs( + obj_type=p.type, + privs=privs, + objs=objs, + roles=roles, + target_roles=target_roles, + state=p.state, + grant_option=p.grant_option, + schema_qualifier=p.schema, + fail_on_role=fail_on_role, + ) + + except Error as e: + conn.rollback() + module.fail_json(msg=e.message, exception=traceback.format_exc()) + + except psycopg2.Error as e: + conn.rollback() + module.fail_json(msg=to_native(e.message)) + + if module.check_mode: + conn.rollback() + else: + conn.commit() + module.exit_json(changed=changed, queries=executed_queries) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_query.py b/test/support/integration/plugins/modules/postgresql_query.py new file mode 100644 index 00000000..18d63e33 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_query.py @@ -0,0 +1,364 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Felix Archambault +# Copyright: (c) 2019, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'supported_by': 'community', + 'status': ['preview'] +} + +DOCUMENTATION = r''' +--- +module: postgresql_query +short_description: Run PostgreSQL queries +description: +- Runs arbitrary PostgreSQL queries. +- Can run queries from SQL script files. +- Does not run against backup files. Use M(postgresql_db) with I(state=restore) + to run queries on files made by pg_dump/pg_dumpall utilities. +version_added: '2.8' +options: + query: + description: + - SQL query to run. Variables can be escaped with psycopg2 syntax + U(http://initd.org/psycopg/docs/usage.html). + type: str + positional_args: + description: + - List of values to be passed as positional arguments to the query. + When the value is a list, it will be converted to PostgreSQL array. + - Mutually exclusive with I(named_args). + type: list + elements: raw + named_args: + description: + - Dictionary of key-value arguments to pass to the query. + When the value is a list, it will be converted to PostgreSQL array. + - Mutually exclusive with I(positional_args). + type: dict + path_to_script: + description: + - Path to SQL script on the remote host. + - Returns result of the last query in the script. + - Mutually exclusive with I(query). + type: path + session_role: + description: + - Switch to session_role after connecting. The specified session_role must + be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though + the session_role were the one that had logged in originally. + type: str + db: + description: + - Name of database to connect to and run queries against. + type: str + aliases: + - login_db + autocommit: + description: + - Execute in autocommit mode when the query can't be run inside a transaction block + (e.g., VACUUM). + - Mutually exclusive with I(check_mode). + type: bool + default: no + version_added: '2.9' + encoding: + description: + - Set the client encoding for the current session (e.g. C(UTF-8)). + - The default is the encoding defined by the database. + type: str + version_added: '2.10' +seealso: +- module: postgresql_db +author: +- Felix Archambault (@archf) +- Andrew Klychkov (@Andersson007) +- Will Rouesnel (@wrouesnel) +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Simple select query to acme db + postgresql_query: + db: acme + query: SELECT version() + +- name: Select query to db acme with positional arguments and non-default credentials + postgresql_query: + db: acme + login_user: django + login_password: mysecretpass + query: SELECT * FROM acme WHERE id = %s AND story = %s + positional_args: + - 1 + - test + +- name: Select query to test_db with named_args + postgresql_query: + db: test_db + query: SELECT * FROM test WHERE id = %(id_val)s AND story = %(story_val)s + named_args: + id_val: 1 + story_val: test + +- name: Insert query to test_table in db test_db + postgresql_query: + db: test_db + query: INSERT INTO test_table (id, story) VALUES (2, 'my_long_story') + +- name: Run queries from SQL script using UTF-8 client encoding for session + postgresql_query: + db: test_db + path_to_script: /var/lib/pgsql/test.sql + positional_args: + - 1 + encoding: UTF-8 + +- name: Example of using autocommit parameter + postgresql_query: + db: test_db + query: VACUUM + autocommit: yes + +- name: > + Insert data to the column of array type using positional_args. + Note that we use quotes here, the same as for passing JSON, etc. + postgresql_query: + query: INSERT INTO test_table (array_column) VALUES (%s) + positional_args: + - '{1,2,3}' + +# Pass list and string vars as positional_args +- name: Set vars + set_fact: + my_list: + - 1 + - 2 + - 3 + my_arr: '{1, 2, 3}' + +- name: Select from test table by passing positional_args as arrays + postgresql_query: + query: SELECT * FROM test_array_table WHERE arr_col1 = %s AND arr_col2 = %s + positional_args: + - '{{ my_list }}' + - '{{ my_arr|string }}' +''' + +RETURN = r''' +query: + description: Query that was tried to be executed. + returned: always + type: str + sample: 'SELECT * FROM bar' +statusmessage: + description: Attribute containing the message returned by the command. + returned: always + type: str + sample: 'INSERT 0 1' +query_result: + description: + - List of dictionaries in column:value form representing returned rows. + returned: changed + type: list + sample: [{"Column": "Value1"},{"Column": "Value2"}] +rowcount: + description: Number of affected rows. + returned: changed + type: int + sample: 5 +''' + +try: + from psycopg2 import ProgrammingError as Psycopg2ProgrammingError + from psycopg2.extras import DictCursor +except ImportError: + # it is needed for checking 'no result to fetch' in main(), + # psycopg2 availability will be checked by connect_to_db() into + # ansible.module_utils.postgres + pass + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) +from ansible.module_utils._text import to_native +from ansible.module_utils.six import iteritems + + +# =========================================== +# Module execution. +# + +def list_to_pg_array(elem): + """Convert the passed list to PostgreSQL array + represented as a string. + + Args: + elem (list): List that needs to be converted. + + Returns: + elem (str): String representation of PostgreSQL array. + """ + elem = str(elem).strip('[]') + elem = '{' + elem + '}' + return elem + + +def convert_elements_to_pg_arrays(obj): + """Convert list elements of the passed object + to PostgreSQL arrays represented as strings. + + Args: + obj (dict or list): Object whose elements need to be converted. + + Returns: + obj (dict or list): Object with converted elements. + """ + if isinstance(obj, dict): + for (key, elem) in iteritems(obj): + if isinstance(elem, list): + obj[key] = list_to_pg_array(elem) + + elif isinstance(obj, list): + for i, elem in enumerate(obj): + if isinstance(elem, list): + obj[i] = list_to_pg_array(elem) + + return obj + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + query=dict(type='str'), + db=dict(type='str', aliases=['login_db']), + positional_args=dict(type='list', elements='raw'), + named_args=dict(type='dict'), + session_role=dict(type='str'), + path_to_script=dict(type='path'), + autocommit=dict(type='bool', default=False), + encoding=dict(type='str'), + ) + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=(('positional_args', 'named_args'),), + supports_check_mode=True, + ) + + query = module.params["query"] + positional_args = module.params["positional_args"] + named_args = module.params["named_args"] + path_to_script = module.params["path_to_script"] + autocommit = module.params["autocommit"] + encoding = module.params["encoding"] + + if autocommit and module.check_mode: + module.fail_json(msg="Using autocommit is mutually exclusive with check_mode") + + if path_to_script and query: + module.fail_json(msg="path_to_script is mutually exclusive with query") + + if positional_args: + positional_args = convert_elements_to_pg_arrays(positional_args) + + elif named_args: + named_args = convert_elements_to_pg_arrays(named_args) + + if path_to_script: + try: + with open(path_to_script, 'rb') as f: + query = to_native(f.read()) + except Exception as e: + module.fail_json(msg="Cannot read file '%s' : %s" % (path_to_script, to_native(e))) + + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=autocommit) + if encoding is not None: + db_connection.set_client_encoding(encoding) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + # Prepare args: + if module.params.get("positional_args"): + arguments = module.params["positional_args"] + elif module.params.get("named_args"): + arguments = module.params["named_args"] + else: + arguments = None + + # Set defaults: + changed = False + + # Execute query: + try: + cursor.execute(query, arguments) + except Exception as e: + if not autocommit: + db_connection.rollback() + + cursor.close() + db_connection.close() + module.fail_json(msg="Cannot execute SQL '%s' %s: %s" % (query, arguments, to_native(e))) + + statusmessage = cursor.statusmessage + rowcount = cursor.rowcount + + try: + query_result = [dict(row) for row in cursor.fetchall()] + except Psycopg2ProgrammingError as e: + if to_native(e) == 'no results to fetch': + query_result = {} + + except Exception as e: + module.fail_json(msg="Cannot fetch rows from cursor: %s" % to_native(e)) + + if 'SELECT' not in statusmessage: + if 'UPDATE' in statusmessage or 'INSERT' in statusmessage or 'DELETE' in statusmessage: + s = statusmessage.split() + if len(s) == 3: + if statusmessage.split()[2] != '0': + changed = True + + elif len(s) == 2: + if statusmessage.split()[1] != '0': + changed = True + + else: + changed = True + + else: + changed = True + + if module.check_mode: + db_connection.rollback() + else: + if not autocommit: + db_connection.commit() + + kw = dict( + changed=changed, + query=cursor.query, + statusmessage=statusmessage, + query_result=query_result, + rowcount=rowcount if rowcount >= 0 else 0, + ) + + cursor.close() + db_connection.close() + + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_set.py b/test/support/integration/plugins/modules/postgresql_set.py new file mode 100644 index 00000000..cfbdae64 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_set.py @@ -0,0 +1,434 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2018, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community' +} + +DOCUMENTATION = r''' +--- +module: postgresql_set +short_description: Change a PostgreSQL server configuration parameter +description: + - Allows to change a PostgreSQL server configuration parameter. + - The module uses ALTER SYSTEM command and applies changes by reload server configuration. + - ALTER SYSTEM is used for changing server configuration parameters across the entire database cluster. + - It can be more convenient and safe than the traditional method of manually editing the postgresql.conf file. + - ALTER SYSTEM writes the given parameter setting to the $PGDATA/postgresql.auto.conf file, + which is read in addition to postgresql.conf. + - The module allows to reset parameter to boot_val (cluster initial value) by I(reset=yes) or remove parameter + string from postgresql.auto.conf and reload I(value=default) (for settings with postmaster context restart is required). + - After change you can see in the ansible output the previous and + the new parameter value and other information using returned values and M(debug) module. +version_added: '2.8' +options: + name: + description: + - Name of PostgreSQL server parameter. + type: str + required: true + value: + description: + - Parameter value to set. + - To remove parameter string from postgresql.auto.conf and + reload the server configuration you must pass I(value=default). + With I(value=default) the playbook always returns changed is true. + type: str + reset: + description: + - Restore parameter to initial state (boot_val). Mutually exclusive with I(value). + type: bool + default: false + session_role: + description: + - Switch to session_role after connecting. The specified session_role must + be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though + the session_role were the one that had logged in originally. + type: str + db: + description: + - Name of database to connect. + type: str + aliases: + - login_db +notes: +- Supported version of PostgreSQL is 9.4 and later. +- Pay attention, change setting with 'postmaster' context can return changed is true + when actually nothing changes because the same value may be presented in + several different form, for example, 1024MB, 1GB, etc. However in pg_settings + system view it can be defined like 131072 number of 8kB pages. + The final check of the parameter value cannot compare it because the server was + not restarted and the value in pg_settings is not updated yet. +- For some parameters restart of PostgreSQL server is required. + See official documentation U(https://www.postgresql.org/docs/current/view-pg-settings.html). +seealso: +- module: postgresql_info +- name: PostgreSQL server configuration + description: General information about PostgreSQL server configuration. + link: https://www.postgresql.org/docs/current/runtime-config.html +- name: PostgreSQL view pg_settings reference + description: Complete reference of the pg_settings view documentation. + link: https://www.postgresql.org/docs/current/view-pg-settings.html +- name: PostgreSQL ALTER SYSTEM command reference + description: Complete reference of the ALTER SYSTEM command documentation. + link: https://www.postgresql.org/docs/current/sql-altersystem.html +author: +- Andrew Klychkov (@Andersson007) +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Restore wal_keep_segments parameter to initial state + postgresql_set: + name: wal_keep_segments + reset: yes + +# Set work_mem parameter to 32MB and show what's been changed and restart is required or not +# (output example: "msg": "work_mem 4MB >> 64MB restart_req: False") +- name: Set work mem parameter + postgresql_set: + name: work_mem + value: 32mb + register: set + +- debug: + msg: "{{ set.name }} {{ set.prev_val_pretty }} >> {{ set.value_pretty }} restart_req: {{ set.restart_required }}" + when: set.changed +# Ensure that the restart of PostgreSQL server must be required for some parameters. +# In this situation you see the same parameter in prev_val and value_prettyue, but 'changed=True' +# (If you passed the value that was different from the current server setting). + +- name: Set log_min_duration_statement parameter to 1 second + postgresql_set: + name: log_min_duration_statement + value: 1s + +- name: Set wal_log_hints parameter to default value (remove parameter from postgresql.auto.conf) + postgresql_set: + name: wal_log_hints + value: default +''' + +RETURN = r''' +name: + description: Name of PostgreSQL server parameter. + returned: always + type: str + sample: 'shared_buffers' +restart_required: + description: Information about parameter current state. + returned: always + type: bool + sample: true +prev_val_pretty: + description: Information about previous state of the parameter. + returned: always + type: str + sample: '4MB' +value_pretty: + description: Information about current state of the parameter. + returned: always + type: str + sample: '64MB' +value: + description: + - Dictionary that contains the current parameter value (at the time of playbook finish). + - Pay attention that for real change some parameters restart of PostgreSQL server is required. + - Returns the current value in the check mode. + returned: always + type: dict + sample: { "value": 67108864, "unit": "b" } +context: + description: + - PostgreSQL setting context. + returned: always + type: str + sample: user +''' + +try: + from psycopg2.extras import DictCursor +except Exception: + # psycopg2 is checked by connect_to_db() + # from ansible.module_utils.postgres + pass + +from copy import deepcopy + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) +from ansible.module_utils._text import to_native + +PG_REQ_VER = 90400 + +# To allow to set value like 1mb instead of 1MB, etc: +POSSIBLE_SIZE_UNITS = ("mb", "gb", "tb") + +# =========================================== +# PostgreSQL module specific support methods. +# + + +def param_get(cursor, module, name): + query = ("SELECT name, setting, unit, context, boot_val " + "FROM pg_settings WHERE name = %(name)s") + try: + cursor.execute(query, {'name': name}) + info = cursor.fetchall() + cursor.execute("SHOW %s" % name) + val = cursor.fetchone() + + except Exception as e: + module.fail_json(msg="Unable to get %s value due to : %s" % (name, to_native(e))) + + raw_val = info[0][1] + unit = info[0][2] + context = info[0][3] + boot_val = info[0][4] + + if val[0] == 'True': + val[0] = 'on' + elif val[0] == 'False': + val[0] = 'off' + + if unit == 'kB': + if int(raw_val) > 0: + raw_val = int(raw_val) * 1024 + if int(boot_val) > 0: + boot_val = int(boot_val) * 1024 + + unit = 'b' + + elif unit == 'MB': + if int(raw_val) > 0: + raw_val = int(raw_val) * 1024 * 1024 + if int(boot_val) > 0: + boot_val = int(boot_val) * 1024 * 1024 + + unit = 'b' + + return (val[0], raw_val, unit, boot_val, context) + + +def pretty_to_bytes(pretty_val): + # The function returns a value in bytes + # if the value contains 'B', 'kB', 'MB', 'GB', 'TB'. + # Otherwise it returns the passed argument. + + val_in_bytes = None + + if 'kB' in pretty_val: + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 + + elif 'MB' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 * 1024 + + elif 'GB' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 * 1024 * 1024 + + elif 'TB' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 * 1024 * 1024 * 1024 + + elif 'B' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part + + else: + return pretty_val + + return val_in_bytes + + +def param_set(cursor, module, name, value, context): + try: + if str(value).lower() == 'default': + query = "ALTER SYSTEM SET %s = DEFAULT" % name + else: + query = "ALTER SYSTEM SET %s = '%s'" % (name, value) + cursor.execute(query) + + if context != 'postmaster': + cursor.execute("SELECT pg_reload_conf()") + + except Exception as e: + module.fail_json(msg="Unable to get %s value due to : %s" % (name, to_native(e))) + + return True + + +# =========================================== +# Module execution. +# + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + name=dict(type='str', required=True), + db=dict(type='str', aliases=['login_db']), + value=dict(type='str'), + reset=dict(type='bool'), + session_role=dict(type='str'), + ) + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True, + ) + + name = module.params["name"] + value = module.params["value"] + reset = module.params["reset"] + + # Allow to pass values like 1mb instead of 1MB, etc: + if value: + for unit in POSSIBLE_SIZE_UNITS: + if value[:-2].isdigit() and unit in value[-2:]: + value = value.upper() + + if value and reset: + module.fail_json(msg="%s: value and reset params are mutually exclusive" % name) + + if not value and not reset: + module.fail_json(msg="%s: at least one of value or reset param must be specified" % name) + + conn_params = get_conn_params(module, module.params, warn_db_default=False) + db_connection = connect_to_db(module, conn_params, autocommit=True) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + kw = {} + # Check server version (needs 9.4 or later): + ver = db_connection.server_version + if ver < PG_REQ_VER: + module.warn("PostgreSQL is %s version but %s or later is required" % (ver, PG_REQ_VER)) + kw = dict( + changed=False, + restart_required=False, + value_pretty="", + prev_val_pretty="", + value={"value": "", "unit": ""}, + ) + kw['name'] = name + db_connection.close() + module.exit_json(**kw) + + # Set default returned values: + restart_required = False + changed = False + kw['name'] = name + kw['restart_required'] = False + + # Get info about param state: + res = param_get(cursor, module, name) + current_value = res[0] + raw_val = res[1] + unit = res[2] + boot_val = res[3] + context = res[4] + + if value == 'True': + value = 'on' + elif value == 'False': + value = 'off' + + kw['prev_val_pretty'] = current_value + kw['value_pretty'] = deepcopy(kw['prev_val_pretty']) + kw['context'] = context + + # Do job + if context == "internal": + module.fail_json(msg="%s: cannot be changed (internal context). See " + "https://www.postgresql.org/docs/current/runtime-config-preset.html" % name) + + if context == "postmaster": + restart_required = True + + # If check_mode, just compare and exit: + if module.check_mode: + if pretty_to_bytes(value) == pretty_to_bytes(current_value): + kw['changed'] = False + + else: + kw['value_pretty'] = value + kw['changed'] = True + + # Anyway returns current raw value in the check_mode: + kw['value'] = dict( + value=raw_val, + unit=unit, + ) + kw['restart_required'] = restart_required + module.exit_json(**kw) + + # Set param: + if value and value != current_value: + changed = param_set(cursor, module, name, value, context) + + kw['value_pretty'] = value + + # Reset param: + elif reset: + if raw_val == boot_val: + # nothing to change, exit: + kw['value'] = dict( + value=raw_val, + unit=unit, + ) + module.exit_json(**kw) + + changed = param_set(cursor, module, name, boot_val, context) + + if restart_required: + module.warn("Restart of PostgreSQL is required for setting %s" % name) + + cursor.close() + db_connection.close() + + # Reconnect and recheck current value: + if context in ('sighup', 'superuser-backend', 'backend', 'superuser', 'user'): + db_connection = connect_to_db(module, conn_params, autocommit=True) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + res = param_get(cursor, module, name) + # f_ means 'final' + f_value = res[0] + f_raw_val = res[1] + + if raw_val == f_raw_val: + changed = False + + else: + changed = True + + kw['value_pretty'] = f_value + kw['value'] = dict( + value=f_raw_val, + unit=unit, + ) + + cursor.close() + db_connection.close() + + kw['changed'] = changed + kw['restart_required'] = restart_required + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_table.py b/test/support/integration/plugins/modules/postgresql_table.py new file mode 100644 index 00000000..3bef03b0 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_table.py @@ -0,0 +1,601 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community' +} + +DOCUMENTATION = r''' +--- +module: postgresql_table +short_description: Create, drop, or modify a PostgreSQL table +description: +- Allows to create, drop, rename, truncate a table, or change some table attributes. +version_added: '2.8' +options: + table: + description: + - Table name. + required: true + aliases: + - name + type: str + state: + description: + - The table state. I(state=absent) is mutually exclusive with I(tablespace), I(owner), I(unlogged), + I(like), I(including), I(columns), I(truncate), I(storage_params) and, I(rename). + type: str + default: present + choices: [ absent, present ] + tablespace: + description: + - Set a tablespace for the table. + required: false + type: str + owner: + description: + - Set a table owner. + type: str + unlogged: + description: + - Create an unlogged table. + type: bool + default: no + like: + description: + - Create a table like another table (with similar DDL). + Mutually exclusive with I(columns), I(rename), and I(truncate). + type: str + including: + description: + - Keywords that are used with like parameter, may be DEFAULTS, CONSTRAINTS, INDEXES, STORAGE, COMMENTS or ALL. + Needs I(like) specified. Mutually exclusive with I(columns), I(rename), and I(truncate). + type: str + columns: + description: + - Columns that are needed. + type: list + elements: str + rename: + description: + - New table name. Mutually exclusive with I(tablespace), I(owner), + I(unlogged), I(like), I(including), I(columns), I(truncate), and I(storage_params). + type: str + truncate: + description: + - Truncate a table. Mutually exclusive with I(tablespace), I(owner), I(unlogged), + I(like), I(including), I(columns), I(rename), and I(storage_params). + type: bool + default: no + storage_params: + description: + - Storage parameters like fillfactor, autovacuum_vacuum_treshold, etc. + Mutually exclusive with I(rename) and I(truncate). + type: list + elements: str + db: + description: + - Name of database to connect and where the table will be created. + type: str + aliases: + - login_db + session_role: + description: + - Switch to session_role after connecting. + The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though + the session_role were the one that had logged in originally. + type: str + cascade: + description: + - Automatically drop objects that depend on the table (such as views). + Used with I(state=absent) only. + type: bool + default: no + version_added: '2.9' +notes: +- If you do not pass db parameter, tables will be created in the database + named postgres. +- PostgreSQL allows to create columnless table, so columns param is optional. +- Unlogged tables are available from PostgreSQL server version 9.1. +seealso: +- module: postgresql_sequence +- module: postgresql_idx +- module: postgresql_info +- module: postgresql_tablespace +- module: postgresql_owner +- module: postgresql_privs +- module: postgresql_copy +- name: CREATE TABLE reference + description: Complete reference of the CREATE TABLE command documentation. + link: https://www.postgresql.org/docs/current/sql-createtable.html +- name: ALTER TABLE reference + description: Complete reference of the ALTER TABLE command documentation. + link: https://www.postgresql.org/docs/current/sql-altertable.html +- name: DROP TABLE reference + description: Complete reference of the DROP TABLE command documentation. + link: https://www.postgresql.org/docs/current/sql-droptable.html +- name: PostgreSQL data types + description: Complete reference of the PostgreSQL data types documentation. + link: https://www.postgresql.org/docs/current/datatype.html +author: +- Andrei Klychkov (@Andersson007) +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Create tbl2 in the acme database with the DDL like tbl1 with testuser as an owner + postgresql_table: + db: acme + name: tbl2 + like: tbl1 + owner: testuser + +- name: Create tbl2 in the acme database and tablespace ssd with the DDL like tbl1 including comments and indexes + postgresql_table: + db: acme + table: tbl2 + like: tbl1 + including: comments, indexes + tablespace: ssd + +- name: Create test_table with several columns in ssd tablespace with fillfactor=10 and autovacuum_analyze_threshold=1 + postgresql_table: + name: test_table + columns: + - id bigserial primary key + - num bigint + - stories text + tablespace: ssd + storage_params: + - fillfactor=10 + - autovacuum_analyze_threshold=1 + +- name: Create an unlogged table in schema acme + postgresql_table: + name: acme.useless_data + columns: waste_id int + unlogged: true + +- name: Rename table foo to bar + postgresql_table: + table: foo + rename: bar + +- name: Rename table foo from schema acme to bar + postgresql_table: + name: acme.foo + rename: bar + +- name: Set owner to someuser + postgresql_table: + name: foo + owner: someuser + +- name: Change tablespace of foo table to new_tablespace and set owner to new_user + postgresql_table: + name: foo + tablespace: new_tablespace + owner: new_user + +- name: Truncate table foo + postgresql_table: + name: foo + truncate: yes + +- name: Drop table foo from schema acme + postgresql_table: + name: acme.foo + state: absent + +- name: Drop table bar cascade + postgresql_table: + name: bar + state: absent + cascade: yes +''' + +RETURN = r''' +table: + description: Name of a table. + returned: always + type: str + sample: 'foo' +state: + description: Table state. + returned: always + type: str + sample: 'present' +owner: + description: Table owner. + returned: always + type: str + sample: 'postgres' +tablespace: + description: Tablespace. + returned: always + type: str + sample: 'ssd_tablespace' +queries: + description: List of executed queries. + returned: always + type: str + sample: [ 'CREATE TABLE "test_table" (id bigint)' ] +storage_params: + description: Storage parameters. + returned: always + type: list + sample: [ "fillfactor=100", "autovacuum_analyze_threshold=1" ] +''' + +try: + from psycopg2.extras import DictCursor +except ImportError: + # psycopg2 is checked by connect_to_db() + # from ansible.module_utils.postgres + pass + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.database import pg_quote_identifier +from ansible.module_utils.postgres import ( + connect_to_db, + exec_sql, + get_conn_params, + postgres_common_argument_spec, +) + + +# =========================================== +# PostgreSQL module specific support methods. +# + +class Table(object): + def __init__(self, name, module, cursor): + self.name = name + self.module = module + self.cursor = cursor + self.info = { + 'owner': '', + 'tblspace': '', + 'storage_params': [], + } + self.exists = False + self.__exists_in_db() + self.executed_queries = [] + + def get_info(self): + """Getter to refresh and get table info""" + self.__exists_in_db() + + def __exists_in_db(self): + """Check table exists and refresh info""" + if "." in self.name: + schema = self.name.split('.')[-2] + tblname = self.name.split('.')[-1] + else: + schema = 'public' + tblname = self.name + + query = ("SELECT t.tableowner, t.tablespace, c.reloptions " + "FROM pg_tables AS t " + "INNER JOIN pg_class AS c ON c.relname = t.tablename " + "INNER JOIN pg_namespace AS n ON c.relnamespace = n.oid " + "WHERE t.tablename = %(tblname)s " + "AND n.nspname = %(schema)s") + res = exec_sql(self, query, query_params={'tblname': tblname, 'schema': schema}, + add_to_executed=False) + if res: + self.exists = True + self.info = dict( + owner=res[0][0], + tblspace=res[0][1] if res[0][1] else '', + storage_params=res[0][2] if res[0][2] else [], + ) + + return True + else: + self.exists = False + return False + + def create(self, columns='', params='', tblspace='', + unlogged=False, owner=''): + """ + Create table. + If table exists, check passed args (params, tblspace, owner) and, + if they're different from current, change them. + Arguments: + params - storage params (passed by "WITH (...)" in SQL), + comma separated. + tblspace - tablespace. + owner - table owner. + unlogged - create unlogged table. + columns - column string (comma separated). + """ + name = pg_quote_identifier(self.name, 'table') + + changed = False + + if self.exists: + if tblspace == 'pg_default' and self.info['tblspace'] is None: + pass # Because they have the same meaning + elif tblspace and self.info['tblspace'] != tblspace: + self.set_tblspace(tblspace) + changed = True + + if owner and self.info['owner'] != owner: + self.set_owner(owner) + changed = True + + if params: + param_list = [p.strip(' ') for p in params.split(',')] + + new_param = False + for p in param_list: + if p not in self.info['storage_params']: + new_param = True + + if new_param: + self.set_stor_params(params) + changed = True + + if changed: + return True + return False + + query = "CREATE" + if unlogged: + query += " UNLOGGED TABLE %s" % name + else: + query += " TABLE %s" % name + + if columns: + query += " (%s)" % columns + else: + query += " ()" + + if params: + query += " WITH (%s)" % params + + if tblspace: + query += " TABLESPACE %s" % pg_quote_identifier(tblspace, 'database') + + if exec_sql(self, query, ddl=True): + changed = True + + if owner: + changed = self.set_owner(owner) + + return changed + + def create_like(self, src_table, including='', tblspace='', + unlogged=False, params='', owner=''): + """ + Create table like another table (with similar DDL). + Arguments: + src_table - source table. + including - corresponds to optional INCLUDING expression + in CREATE TABLE ... LIKE statement. + params - storage params (passed by "WITH (...)" in SQL), + comma separated. + tblspace - tablespace. + owner - table owner. + unlogged - create unlogged table. + """ + changed = False + + name = pg_quote_identifier(self.name, 'table') + + query = "CREATE" + if unlogged: + query += " UNLOGGED TABLE %s" % name + else: + query += " TABLE %s" % name + + query += " (LIKE %s" % pg_quote_identifier(src_table, 'table') + + if including: + including = including.split(',') + for i in including: + query += " INCLUDING %s" % i + + query += ')' + + if params: + query += " WITH (%s)" % params + + if tblspace: + query += " TABLESPACE %s" % pg_quote_identifier(tblspace, 'database') + + if exec_sql(self, query, ddl=True): + changed = True + + if owner: + changed = self.set_owner(owner) + + return changed + + def truncate(self): + query = "TRUNCATE TABLE %s" % pg_quote_identifier(self.name, 'table') + return exec_sql(self, query, ddl=True) + + def rename(self, newname): + query = "ALTER TABLE %s RENAME TO %s" % (pg_quote_identifier(self.name, 'table'), + pg_quote_identifier(newname, 'table')) + return exec_sql(self, query, ddl=True) + + def set_owner(self, username): + query = "ALTER TABLE %s OWNER TO %s" % (pg_quote_identifier(self.name, 'table'), + pg_quote_identifier(username, 'role')) + return exec_sql(self, query, ddl=True) + + def drop(self, cascade=False): + if not self.exists: + return False + + query = "DROP TABLE %s" % pg_quote_identifier(self.name, 'table') + if cascade: + query += " CASCADE" + return exec_sql(self, query, ddl=True) + + def set_tblspace(self, tblspace): + query = "ALTER TABLE %s SET TABLESPACE %s" % (pg_quote_identifier(self.name, 'table'), + pg_quote_identifier(tblspace, 'database')) + return exec_sql(self, query, ddl=True) + + def set_stor_params(self, params): + query = "ALTER TABLE %s SET (%s)" % (pg_quote_identifier(self.name, 'table'), params) + return exec_sql(self, query, ddl=True) + + +# =========================================== +# Module execution. +# + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + table=dict(type='str', required=True, aliases=['name']), + state=dict(type='str', default="present", choices=["absent", "present"]), + db=dict(type='str', default='', aliases=['login_db']), + tablespace=dict(type='str'), + owner=dict(type='str'), + unlogged=dict(type='bool', default=False), + like=dict(type='str'), + including=dict(type='str'), + rename=dict(type='str'), + truncate=dict(type='bool', default=False), + columns=dict(type='list', elements='str'), + storage_params=dict(type='list', elements='str'), + session_role=dict(type='str'), + cascade=dict(type='bool', default=False), + ) + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True, + ) + + table = module.params["table"] + state = module.params["state"] + tablespace = module.params["tablespace"] + owner = module.params["owner"] + unlogged = module.params["unlogged"] + like = module.params["like"] + including = module.params["including"] + newname = module.params["rename"] + storage_params = module.params["storage_params"] + truncate = module.params["truncate"] + columns = module.params["columns"] + cascade = module.params["cascade"] + + if state == 'present' and cascade: + module.warn("cascade=true is ignored when state=present") + + # Check mutual exclusive parameters: + if state == 'absent' and (truncate or newname or columns or tablespace or like or storage_params or unlogged or owner or including): + module.fail_json(msg="%s: state=absent is mutually exclusive with: " + "truncate, rename, columns, tablespace, " + "including, like, storage_params, unlogged, owner" % table) + + if truncate and (newname or columns or like or unlogged or storage_params or owner or tablespace or including): + module.fail_json(msg="%s: truncate is mutually exclusive with: " + "rename, columns, like, unlogged, including, " + "storage_params, owner, tablespace" % table) + + if newname and (columns or like or unlogged or storage_params or owner or tablespace or including): + module.fail_json(msg="%s: rename is mutually exclusive with: " + "columns, like, unlogged, including, " + "storage_params, owner, tablespace" % table) + + if like and columns: + module.fail_json(msg="%s: like and columns params are mutually exclusive" % table) + if including and not like: + module.fail_json(msg="%s: including param needs like param specified" % table) + + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + if storage_params: + storage_params = ','.join(storage_params) + + if columns: + columns = ','.join(columns) + + ############## + # Do main job: + table_obj = Table(table, module, cursor) + + # Set default returned values: + changed = False + kw = {} + kw['table'] = table + kw['state'] = '' + if table_obj.exists: + kw = dict( + table=table, + state='present', + owner=table_obj.info['owner'], + tablespace=table_obj.info['tblspace'], + storage_params=table_obj.info['storage_params'], + ) + + if state == 'absent': + changed = table_obj.drop(cascade=cascade) + + elif truncate: + changed = table_obj.truncate() + + elif newname: + changed = table_obj.rename(newname) + q = table_obj.executed_queries + table_obj = Table(newname, module, cursor) + table_obj.executed_queries = q + + elif state == 'present' and not like: + changed = table_obj.create(columns, storage_params, + tablespace, unlogged, owner) + + elif state == 'present' and like: + changed = table_obj.create_like(like, including, tablespace, + unlogged, storage_params) + + if changed: + if module.check_mode: + db_connection.rollback() + else: + db_connection.commit() + + # Refresh table info for RETURN. + # Note, if table has been renamed, it gets info by newname: + table_obj.get_info() + db_connection.commit() + if table_obj.exists: + kw = dict( + table=table, + state='present', + owner=table_obj.info['owner'], + tablespace=table_obj.info['tblspace'], + storage_params=table_obj.info['storage_params'], + ) + else: + # We just change the table state here + # to keep other information about the dropped table: + kw['state'] = 'absent' + + kw['queries'] = table_obj.executed_queries + kw['changed'] = changed + db_connection.close() + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_user.py b/test/support/integration/plugins/modules/postgresql_user.py new file mode 100644 index 00000000..10afd0a0 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_user.py @@ -0,0 +1,927 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'community' +} + +DOCUMENTATION = r''' +--- +module: postgresql_user +short_description: Add or remove a user (role) from a PostgreSQL server instance +description: +- Adds or removes a user (role) from a PostgreSQL server instance + ("cluster" in PostgreSQL terminology) and, optionally, + grants the user access to an existing database or tables. +- A user is a role with login privilege. +- The fundamental function of the module is to create, or delete, users from + a PostgreSQL instances. Privilege assignment, or removal, is an optional + step, which works on one database at a time. This allows for the module to + be called several times in the same module to modify the permissions on + different databases, or to grant permissions to already existing users. +- A user cannot be removed until all the privileges have been stripped from + the user. In such situation, if the module tries to remove the user it + will fail. To avoid this from happening the fail_on_user option signals + the module to try to remove the user, but if not possible keep going; the + module will report if changes happened and separately if the user was + removed or not. +version_added: '0.6' +options: + name: + description: + - Name of the user (role) to add or remove. + type: str + required: true + aliases: + - user + password: + description: + - Set the user's password, before 1.4 this was required. + - Password can be passed unhashed or hashed (MD5-hashed). + - Unhashed password will automatically be hashed when saved into the + database if C(encrypted) parameter is set, otherwise it will be save in + plain text format. + - When passing a hashed password it must be generated with the format + C('str["md5"] + md5[ password + username ]'), resulting in a total of + 35 characters. An easy way to do this is C(echo "md5$(echo -n + 'verysecretpasswordJOE' | md5sum | awk '{print $1}')"). + - Note that if the provided password string is already in MD5-hashed + format, then it is used as-is, regardless of C(encrypted) parameter. + type: str + db: + description: + - Name of database to connect to and where user's permissions will be granted. + type: str + aliases: + - login_db + fail_on_user: + description: + - If C(yes), fail when user (role) can't be removed. Otherwise just log and continue. + default: 'yes' + type: bool + aliases: + - fail_on_role + priv: + description: + - "Slash-separated PostgreSQL privileges string: C(priv1/priv2), where + privileges can be defined for database ( allowed options - 'CREATE', + 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL'. For example C(CONNECT) ) or + for table ( allowed options - 'SELECT', 'INSERT', 'UPDATE', 'DELETE', + 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL'. For example + C(table:SELECT) ). Mixed example of this string: + C(CONNECT/CREATE/table1:SELECT/table2:INSERT)." + type: str + role_attr_flags: + description: + - "PostgreSQL user attributes string in the format: CREATEDB,CREATEROLE,SUPERUSER." + - Note that '[NO]CREATEUSER' is deprecated. + - To create a simple role for using it like a group, use C(NOLOGIN) flag. + type: str + choices: [ '[NO]SUPERUSER', '[NO]CREATEROLE', '[NO]CREATEDB', + '[NO]INHERIT', '[NO]LOGIN', '[NO]REPLICATION', '[NO]BYPASSRLS' ] + session_role: + version_added: '2.8' + description: + - Switch to session_role after connecting. + - The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally. + type: str + state: + description: + - The user (role) state. + type: str + default: present + choices: [ absent, present ] + encrypted: + description: + - Whether the password is stored hashed in the database. + - Passwords can be passed already hashed or unhashed, and postgresql + ensures the stored password is hashed when C(encrypted) is set. + - "Note: Postgresql 10 and newer doesn't support unhashed passwords." + - Previous to Ansible 2.6, this was C(no) by default. + default: 'yes' + type: bool + version_added: '1.4' + expires: + description: + - The date at which the user's password is to expire. + - If set to C('infinity'), user's password never expire. + - Note that this value should be a valid SQL date and time type. + type: str + version_added: '1.4' + no_password_changes: + description: + - If C(yes), don't inspect database for password changes. Effective when + C(pg_authid) is not accessible (such as AWS RDS). Otherwise, make + password changes as necessary. + default: 'no' + type: bool + version_added: '2.0' + conn_limit: + description: + - Specifies the user (role) connection limit. + type: int + version_added: '2.4' + ssl_mode: + description: + - Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated with the server. + - See https://www.postgresql.org/docs/current/static/libpq-ssl.html for more information on the modes. + - Default of C(prefer) matches libpq default. + type: str + default: prefer + choices: [ allow, disable, prefer, require, verify-ca, verify-full ] + version_added: '2.3' + ca_cert: + description: + - Specifies the name of a file containing SSL certificate authority (CA) certificate(s). + - If the file exists, the server's certificate will be verified to be signed by one of these authorities. + type: str + aliases: [ ssl_rootcert ] + version_added: '2.3' + groups: + description: + - The list of groups (roles) that need to be granted to the user. + type: list + elements: str + version_added: '2.9' + comment: + description: + - Add a comment on the user (equal to the COMMENT ON ROLE statement result). + type: str + version_added: '2.10' +notes: +- The module creates a user (role) with login privilege by default. + Use NOLOGIN role_attr_flags to change this behaviour. +- If you specify PUBLIC as the user (role), then the privilege changes will apply to all users (roles). + You may not specify password or role_attr_flags when the PUBLIC user is specified. +seealso: +- module: postgresql_privs +- module: postgresql_membership +- module: postgresql_owner +- name: PostgreSQL database roles + description: Complete reference of the PostgreSQL database roles documentation. + link: https://www.postgresql.org/docs/current/user-manag.html +author: +- Ansible Core Team +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Connect to acme database, create django user, and grant access to database and products table + postgresql_user: + db: acme + name: django + password: ceec4eif7ya + priv: "CONNECT/products:ALL" + expires: "Jan 31 2020" + +- name: Add a comment on django user + postgresql_user: + db: acme + name: django + comment: This is a test user + +# Connect to default database, create rails user, set its password (MD5-hashed), +# and grant privilege to create other databases and demote rails from super user status if user exists +- name: Create rails user, set MD5-hashed password, grant privs + postgresql_user: + name: rails + password: md59543f1d82624df2b31672ec0f7050460 + role_attr_flags: CREATEDB,NOSUPERUSER + +- name: Connect to acme database and remove test user privileges from there + postgresql_user: + db: acme + name: test + priv: "ALL/products:ALL" + state: absent + fail_on_user: no + +- name: Connect to test database, remove test user from cluster + postgresql_user: + db: test + name: test + priv: ALL + state: absent + +- name: Connect to acme database and set user's password with no expire date + postgresql_user: + db: acme + name: django + password: mysupersecretword + priv: "CONNECT/products:ALL" + expires: infinity + +# Example privileges string format +# INSERT,UPDATE/table:SELECT/anothertable:ALL + +- name: Connect to test database and remove an existing user's password + postgresql_user: + db: test + user: test + password: "" + +- name: Create user test and grant group user_ro and user_rw to it + postgresql_user: + name: test + groups: + - user_ro + - user_rw +''' + +RETURN = r''' +queries: + description: List of executed queries. + returned: always + type: list + sample: ['CREATE USER "alice"', 'GRANT CONNECT ON DATABASE "acme" TO "alice"'] + version_added: '2.8' +''' + +import itertools +import re +import traceback +from hashlib import md5 + +try: + import psycopg2 + from psycopg2.extras import DictCursor +except ImportError: + # psycopg2 is checked by connect_to_db() + # from ansible.module_utils.postgres + pass + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.database import pg_quote_identifier, SQLParseError +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + PgMembership, + postgres_common_argument_spec, +) +from ansible.module_utils._text import to_bytes, to_native +from ansible.module_utils.six import iteritems + + +FLAGS = ('SUPERUSER', 'CREATEROLE', 'CREATEDB', 'INHERIT', 'LOGIN', 'REPLICATION') +FLAGS_BY_VERSION = {'BYPASSRLS': 90500} + +VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL')), + database=frozenset( + ('CREATE', 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL')), + ) + +# map to cope with idiosyncracies of SUPERUSER and LOGIN +PRIV_TO_AUTHID_COLUMN = dict(SUPERUSER='rolsuper', CREATEROLE='rolcreaterole', + CREATEDB='rolcreatedb', INHERIT='rolinherit', LOGIN='rolcanlogin', + REPLICATION='rolreplication', BYPASSRLS='rolbypassrls') + +executed_queries = [] + + +class InvalidFlagsError(Exception): + pass + + +class InvalidPrivsError(Exception): + pass + +# =========================================== +# PostgreSQL module specific support methods. +# + + +def user_exists(cursor, user): + # The PUBLIC user is a special case that is always there + if user == 'PUBLIC': + return True + query = "SELECT rolname FROM pg_roles WHERE rolname=%(user)s" + cursor.execute(query, {'user': user}) + return cursor.rowcount > 0 + + +def user_add(cursor, user, password, role_attr_flags, encrypted, expires, conn_limit): + """Create a new database user (role).""" + # Note: role_attr_flags escaped by parse_role_attrs and encrypted is a + # literal + query_password_data = dict(password=password, expires=expires) + query = ['CREATE USER "%(user)s"' % + {"user": user}] + if password is not None and password != '': + query.append("WITH %(crypt)s" % {"crypt": encrypted}) + query.append("PASSWORD %(password)s") + if expires is not None: + query.append("VALID UNTIL %(expires)s") + if conn_limit is not None: + query.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit}) + query.append(role_attr_flags) + query = ' '.join(query) + executed_queries.append(query) + cursor.execute(query, query_password_data) + return True + + +def user_should_we_change_password(current_role_attrs, user, password, encrypted): + """Check if we should change the user's password. + + Compare the proposed password with the existing one, comparing + hashes if encrypted. If we can't access it assume yes. + """ + + if current_role_attrs is None: + # on some databases, E.g. AWS RDS instances, there is no access to + # the pg_authid relation to check the pre-existing password, so we + # just assume password is different + return True + + # Do we actually need to do anything? + pwchanging = False + if password is not None: + # Empty password means that the role shouldn't have a password, which + # means we need to check if the current password is None. + if password == '': + if current_role_attrs['rolpassword'] is not None: + pwchanging = True + # 32: MD5 hashes are represented as a sequence of 32 hexadecimal digits + # 3: The size of the 'md5' prefix + # When the provided password looks like a MD5-hash, value of + # 'encrypted' is ignored. + elif (password.startswith('md5') and len(password) == 32 + 3) or encrypted == 'UNENCRYPTED': + if password != current_role_attrs['rolpassword']: + pwchanging = True + elif encrypted == 'ENCRYPTED': + hashed_password = 'md5{0}'.format(md5(to_bytes(password) + to_bytes(user)).hexdigest()) + if hashed_password != current_role_attrs['rolpassword']: + pwchanging = True + + return pwchanging + + +def user_alter(db_connection, module, user, password, role_attr_flags, encrypted, expires, no_password_changes, conn_limit): + """Change user password and/or attributes. Return True if changed, False otherwise.""" + changed = False + + cursor = db_connection.cursor(cursor_factory=DictCursor) + # Note: role_attr_flags escaped by parse_role_attrs and encrypted is a + # literal + if user == 'PUBLIC': + if password is not None: + module.fail_json(msg="cannot change the password for PUBLIC user") + elif role_attr_flags != '': + module.fail_json(msg="cannot change the role_attr_flags for PUBLIC user") + else: + return False + + # Handle passwords. + if not no_password_changes and (password is not None or role_attr_flags != '' or expires is not None or conn_limit is not None): + # Select password and all flag-like columns in order to verify changes. + try: + select = "SELECT * FROM pg_authid where rolname=%(user)s" + cursor.execute(select, {"user": user}) + # Grab current role attributes. + current_role_attrs = cursor.fetchone() + except psycopg2.ProgrammingError: + current_role_attrs = None + db_connection.rollback() + + pwchanging = user_should_we_change_password(current_role_attrs, user, password, encrypted) + + if current_role_attrs is None: + try: + # AWS RDS instances does not allow user to access pg_authid + # so try to get current_role_attrs from pg_roles tables + select = "SELECT * FROM pg_roles where rolname=%(user)s" + cursor.execute(select, {"user": user}) + # Grab current role attributes from pg_roles + current_role_attrs = cursor.fetchone() + except psycopg2.ProgrammingError as e: + db_connection.rollback() + module.fail_json(msg="Failed to get role details for current user %s: %s" % (user, e)) + + role_attr_flags_changing = False + if role_attr_flags: + role_attr_flags_dict = {} + for r in role_attr_flags.split(' '): + if r.startswith('NO'): + role_attr_flags_dict[r.replace('NO', '', 1)] = False + else: + role_attr_flags_dict[r] = True + + for role_attr_name, role_attr_value in role_attr_flags_dict.items(): + if current_role_attrs[PRIV_TO_AUTHID_COLUMN[role_attr_name]] != role_attr_value: + role_attr_flags_changing = True + + if expires is not None: + cursor.execute("SELECT %s::timestamptz;", (expires,)) + expires_with_tz = cursor.fetchone()[0] + expires_changing = expires_with_tz != current_role_attrs.get('rolvaliduntil') + else: + expires_changing = False + + conn_limit_changing = (conn_limit is not None and conn_limit != current_role_attrs['rolconnlimit']) + + if not pwchanging and not role_attr_flags_changing and not expires_changing and not conn_limit_changing: + return False + + alter = ['ALTER USER "%(user)s"' % {"user": user}] + if pwchanging: + if password != '': + alter.append("WITH %(crypt)s" % {"crypt": encrypted}) + alter.append("PASSWORD %(password)s") + else: + alter.append("WITH PASSWORD NULL") + alter.append(role_attr_flags) + elif role_attr_flags: + alter.append('WITH %s' % role_attr_flags) + if expires is not None: + alter.append("VALID UNTIL %(expires)s") + if conn_limit is not None: + alter.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit}) + + query_password_data = dict(password=password, expires=expires) + try: + cursor.execute(' '.join(alter), query_password_data) + changed = True + except psycopg2.InternalError as e: + if e.pgcode == '25006': + # Handle errors due to read-only transactions indicated by pgcode 25006 + # ERROR: cannot execute ALTER ROLE in a read-only transaction + changed = False + module.fail_json(msg=e.pgerror, exception=traceback.format_exc()) + return changed + else: + raise psycopg2.InternalError(e) + except psycopg2.NotSupportedError as e: + module.fail_json(msg=e.pgerror, exception=traceback.format_exc()) + + elif no_password_changes and role_attr_flags != '': + # Grab role information from pg_roles instead of pg_authid + select = "SELECT * FROM pg_roles where rolname=%(user)s" + cursor.execute(select, {"user": user}) + # Grab current role attributes. + current_role_attrs = cursor.fetchone() + + role_attr_flags_changing = False + + if role_attr_flags: + role_attr_flags_dict = {} + for r in role_attr_flags.split(' '): + if r.startswith('NO'): + role_attr_flags_dict[r.replace('NO', '', 1)] = False + else: + role_attr_flags_dict[r] = True + + for role_attr_name, role_attr_value in role_attr_flags_dict.items(): + if current_role_attrs[PRIV_TO_AUTHID_COLUMN[role_attr_name]] != role_attr_value: + role_attr_flags_changing = True + + if not role_attr_flags_changing: + return False + + alter = ['ALTER USER "%(user)s"' % + {"user": user}] + if role_attr_flags: + alter.append('WITH %s' % role_attr_flags) + + try: + cursor.execute(' '.join(alter)) + except psycopg2.InternalError as e: + if e.pgcode == '25006': + # Handle errors due to read-only transactions indicated by pgcode 25006 + # ERROR: cannot execute ALTER ROLE in a read-only transaction + changed = False + module.fail_json(msg=e.pgerror, exception=traceback.format_exc()) + return changed + else: + raise psycopg2.InternalError(e) + + # Grab new role attributes. + cursor.execute(select, {"user": user}) + new_role_attrs = cursor.fetchone() + + # Detect any differences between current_ and new_role_attrs. + changed = current_role_attrs != new_role_attrs + + return changed + + +def user_delete(cursor, user): + """Try to remove a user. Returns True if successful otherwise False""" + cursor.execute("SAVEPOINT ansible_pgsql_user_delete") + try: + query = 'DROP USER "%s"' % user + executed_queries.append(query) + cursor.execute(query) + except Exception: + cursor.execute("ROLLBACK TO SAVEPOINT ansible_pgsql_user_delete") + cursor.execute("RELEASE SAVEPOINT ansible_pgsql_user_delete") + return False + + cursor.execute("RELEASE SAVEPOINT ansible_pgsql_user_delete") + return True + + +def has_table_privileges(cursor, user, table, privs): + """ + Return the difference between the privileges that a user already has and + the privileges that they desire to have. + + :returns: tuple of: + * privileges that they have and were requested + * privileges they currently hold but were not requested + * privileges requested that they do not hold + """ + cur_privs = get_table_privileges(cursor, user, table) + have_currently = cur_privs.intersection(privs) + other_current = cur_privs.difference(privs) + desired = privs.difference(cur_privs) + return (have_currently, other_current, desired) + + +def get_table_privileges(cursor, user, table): + if '.' in table: + schema, table = table.split('.', 1) + else: + schema = 'public' + query = ("SELECT privilege_type FROM information_schema.role_table_grants " + "WHERE grantee=%(user)s AND table_name=%(table)s AND table_schema=%(schema)s") + cursor.execute(query, {'user': user, 'table': table, 'schema': schema}) + return frozenset([x[0] for x in cursor.fetchall()]) + + +def grant_table_privileges(cursor, user, table, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + query = 'GRANT %s ON TABLE %s TO "%s"' % ( + privs, pg_quote_identifier(table, 'table'), user) + executed_queries.append(query) + cursor.execute(query) + + +def revoke_table_privileges(cursor, user, table, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + query = 'REVOKE %s ON TABLE %s FROM "%s"' % ( + privs, pg_quote_identifier(table, 'table'), user) + executed_queries.append(query) + cursor.execute(query) + + +def get_database_privileges(cursor, user, db): + priv_map = { + 'C': 'CREATE', + 'T': 'TEMPORARY', + 'c': 'CONNECT', + } + query = 'SELECT datacl FROM pg_database WHERE datname = %s' + cursor.execute(query, (db,)) + datacl = cursor.fetchone()[0] + if datacl is None: + return set() + r = re.search(r'%s\\?"?=(C?T?c?)/[^,]+,?' % user, datacl) + if r is None: + return set() + o = set() + for v in r.group(1): + o.add(priv_map[v]) + return normalize_privileges(o, 'database') + + +def has_database_privileges(cursor, user, db, privs): + """ + Return the difference between the privileges that a user already has and + the privileges that they desire to have. + + :returns: tuple of: + * privileges that they have and were requested + * privileges they currently hold but were not requested + * privileges requested that they do not hold + """ + cur_privs = get_database_privileges(cursor, user, db) + have_currently = cur_privs.intersection(privs) + other_current = cur_privs.difference(privs) + desired = privs.difference(cur_privs) + return (have_currently, other_current, desired) + + +def grant_database_privileges(cursor, user, db, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + if user == "PUBLIC": + query = 'GRANT %s ON DATABASE %s TO PUBLIC' % ( + privs, pg_quote_identifier(db, 'database')) + else: + query = 'GRANT %s ON DATABASE %s TO "%s"' % ( + privs, pg_quote_identifier(db, 'database'), user) + + executed_queries.append(query) + cursor.execute(query) + + +def revoke_database_privileges(cursor, user, db, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + if user == "PUBLIC": + query = 'REVOKE %s ON DATABASE %s FROM PUBLIC' % ( + privs, pg_quote_identifier(db, 'database')) + else: + query = 'REVOKE %s ON DATABASE %s FROM "%s"' % ( + privs, pg_quote_identifier(db, 'database'), user) + + executed_queries.append(query) + cursor.execute(query) + + +def revoke_privileges(cursor, user, privs): + if privs is None: + return False + + revoke_funcs = dict(table=revoke_table_privileges, + database=revoke_database_privileges) + check_funcs = dict(table=has_table_privileges, + database=has_database_privileges) + + changed = False + for type_ in privs: + for name, privileges in iteritems(privs[type_]): + # Check that any of the privileges requested to be removed are + # currently granted to the user + differences = check_funcs[type_](cursor, user, name, privileges) + if differences[0]: + revoke_funcs[type_](cursor, user, name, privileges) + changed = True + return changed + + +def grant_privileges(cursor, user, privs): + if privs is None: + return False + + grant_funcs = dict(table=grant_table_privileges, + database=grant_database_privileges) + check_funcs = dict(table=has_table_privileges, + database=has_database_privileges) + + changed = False + for type_ in privs: + for name, privileges in iteritems(privs[type_]): + # Check that any of the privileges requested for the user are + # currently missing + differences = check_funcs[type_](cursor, user, name, privileges) + if differences[2]: + grant_funcs[type_](cursor, user, name, privileges) + changed = True + return changed + + +def parse_role_attrs(cursor, role_attr_flags): + """ + Parse role attributes string for user creation. + Format: + + attributes[,attributes,...] + + Where: + + attributes := CREATEDB,CREATEROLE,NOSUPERUSER,... + [ "[NO]SUPERUSER","[NO]CREATEROLE", "[NO]CREATEDB", + "[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION", + "[NO]BYPASSRLS" ] + + Note: "[NO]BYPASSRLS" role attribute introduced in 9.5 + Note: "[NO]CREATEUSER" role attribute is deprecated. + + """ + flags = frozenset(role.upper() for role in role_attr_flags.split(',') if role) + + valid_flags = frozenset(itertools.chain(FLAGS, get_valid_flags_by_version(cursor))) + valid_flags = frozenset(itertools.chain(valid_flags, ('NO%s' % flag for flag in valid_flags))) + + if not flags.issubset(valid_flags): + raise InvalidFlagsError('Invalid role_attr_flags specified: %s' % + ' '.join(flags.difference(valid_flags))) + + return ' '.join(flags) + + +def normalize_privileges(privs, type_): + new_privs = set(privs) + if 'ALL' in new_privs: + new_privs.update(VALID_PRIVS[type_]) + new_privs.remove('ALL') + if 'TEMP' in new_privs: + new_privs.add('TEMPORARY') + new_privs.remove('TEMP') + + return new_privs + + +def parse_privs(privs, db): + """ + Parse privilege string to determine permissions for database db. + Format: + + privileges[/privileges/...] + + Where: + + privileges := DATABASE_PRIVILEGES[,DATABASE_PRIVILEGES,...] | + TABLE_NAME:TABLE_PRIVILEGES[,TABLE_PRIVILEGES,...] + """ + if privs is None: + return privs + + o_privs = { + 'database': {}, + 'table': {} + } + for token in privs.split('/'): + if ':' not in token: + type_ = 'database' + name = db + priv_set = frozenset(x.strip().upper() + for x in token.split(',') if x.strip()) + else: + type_ = 'table' + name, privileges = token.split(':', 1) + priv_set = frozenset(x.strip().upper() + for x in privileges.split(',') if x.strip()) + + if not priv_set.issubset(VALID_PRIVS[type_]): + raise InvalidPrivsError('Invalid privs specified for %s: %s' % + (type_, ' '.join(priv_set.difference(VALID_PRIVS[type_])))) + + priv_set = normalize_privileges(priv_set, type_) + o_privs[type_][name] = priv_set + + return o_privs + + +def get_valid_flags_by_version(cursor): + """ + Some role attributes were introduced after certain versions. We want to + compile a list of valid flags against the current Postgres version. + """ + current_version = cursor.connection.server_version + + return [ + flag + for flag, version_introduced in FLAGS_BY_VERSION.items() + if current_version >= version_introduced + ] + + +def get_comment(cursor, user): + """Get user's comment.""" + query = ("SELECT pg_catalog.shobj_description(r.oid, 'pg_authid') " + "FROM pg_catalog.pg_roles r " + "WHERE r.rolname = %(user)s") + cursor.execute(query, {'user': user}) + return cursor.fetchone()[0] + + +def add_comment(cursor, user, comment): + """Add comment on user.""" + if comment != get_comment(cursor, user): + query = 'COMMENT ON ROLE "%s" IS ' % user + cursor.execute(query + '%(comment)s', {'comment': comment}) + executed_queries.append(cursor.mogrify(query + '%(comment)s', {'comment': comment})) + return True + else: + return False + + +# =========================================== +# Module execution. +# + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + user=dict(type='str', required=True, aliases=['name']), + password=dict(type='str', default=None, no_log=True), + state=dict(type='str', default='present', choices=['absent', 'present']), + priv=dict(type='str', default=None), + db=dict(type='str', default='', aliases=['login_db']), + fail_on_user=dict(type='bool', default='yes', aliases=['fail_on_role']), + role_attr_flags=dict(type='str', default=''), + encrypted=dict(type='bool', default='yes'), + no_password_changes=dict(type='bool', default='no'), + expires=dict(type='str', default=None), + conn_limit=dict(type='int', default=None), + session_role=dict(type='str'), + groups=dict(type='list', elements='str'), + comment=dict(type='str', default=None), + ) + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True + ) + + user = module.params["user"] + password = module.params["password"] + state = module.params["state"] + fail_on_user = module.params["fail_on_user"] + if module.params['db'] == '' and module.params["priv"] is not None: + module.fail_json(msg="privileges require a database to be specified") + privs = parse_privs(module.params["priv"], module.params["db"]) + no_password_changes = module.params["no_password_changes"] + if module.params["encrypted"]: + encrypted = "ENCRYPTED" + else: + encrypted = "UNENCRYPTED" + expires = module.params["expires"] + conn_limit = module.params["conn_limit"] + role_attr_flags = module.params["role_attr_flags"] + groups = module.params["groups"] + if groups: + groups = [e.strip() for e in groups] + comment = module.params["comment"] + + conn_params = get_conn_params(module, module.params, warn_db_default=False) + db_connection = connect_to_db(module, conn_params) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + try: + role_attr_flags = parse_role_attrs(cursor, role_attr_flags) + except InvalidFlagsError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + kw = dict(user=user) + changed = False + user_removed = False + + if state == "present": + if user_exists(cursor, user): + try: + changed = user_alter(db_connection, module, user, password, + role_attr_flags, encrypted, expires, no_password_changes, conn_limit) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + else: + try: + changed = user_add(cursor, user, password, + role_attr_flags, encrypted, expires, conn_limit) + except psycopg2.ProgrammingError as e: + module.fail_json(msg="Unable to add user with given requirement " + "due to : %s" % to_native(e), + exception=traceback.format_exc()) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + try: + changed = grant_privileges(cursor, user, privs) or changed + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + if groups: + target_roles = [] + target_roles.append(user) + pg_membership = PgMembership(module, cursor, groups, target_roles) + changed = pg_membership.grant() or changed + executed_queries.extend(pg_membership.executed_queries) + + if comment is not None: + try: + changed = add_comment(cursor, user, comment) or changed + except Exception as e: + module.fail_json(msg='Unable to add comment on role: %s' % to_native(e), + exception=traceback.format_exc()) + + else: + if user_exists(cursor, user): + if module.check_mode: + changed = True + kw['user_removed'] = True + else: + try: + changed = revoke_privileges(cursor, user, privs) + user_removed = user_delete(cursor, user) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + changed = changed or user_removed + if fail_on_user and not user_removed: + msg = "Unable to remove user" + module.fail_json(msg=msg) + kw['user_removed'] = user_removed + + if changed: + if module.check_mode: + db_connection.rollback() + else: + db_connection.commit() + + kw['changed'] = changed + kw['queries'] = executed_queries + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/rabbitmq_plugin.py b/test/support/integration/plugins/modules/rabbitmq_plugin.py new file mode 100644 index 00000000..301bbfe2 --- /dev/null +++ b/test/support/integration/plugins/modules/rabbitmq_plugin.py @@ -0,0 +1,180 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2013, Chatham Financial <oss@chathamfinancial.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community' +} + + +DOCUMENTATION = ''' +--- +module: rabbitmq_plugin +short_description: Manage RabbitMQ plugins +description: + - This module can be used to enable or disable RabbitMQ plugins. +version_added: "1.1" +author: + - Chris Hoffman (@chrishoffman) +options: + names: + description: + - Comma-separated list of plugin names. Also, accepts plugin name. + required: true + aliases: [name] + new_only: + description: + - Only enable missing plugins. + - Does not disable plugins that are not in the names list. + type: bool + default: "no" + state: + description: + - Specify if plugins are to be enabled or disabled. + default: enabled + choices: [enabled, disabled] + prefix: + description: + - Specify a custom install prefix to a Rabbit. + version_added: "1.3" +''' + +EXAMPLES = ''' +- name: Enables the rabbitmq_management plugin + rabbitmq_plugin: + names: rabbitmq_management + state: enabled + +- name: Enable multiple rabbitmq plugins + rabbitmq_plugin: + names: rabbitmq_management,rabbitmq_management_visualiser + state: enabled + +- name: Disable plugin + rabbitmq_plugin: + names: rabbitmq_management + state: disabled + +- name: Enable every plugin in list with existing plugins + rabbitmq_plugin: + names: rabbitmq_management,rabbitmq_management_visualiser,rabbitmq_shovel,rabbitmq_shovel_management + state: enabled + new_only: 'yes' +''' + +RETURN = ''' +enabled: + description: list of plugins enabled during task run + returned: always + type: list + sample: ["rabbitmq_management"] +disabled: + description: list of plugins disabled during task run + returned: always + type: list + sample: ["rabbitmq_management"] +''' + +import os +from ansible.module_utils.basic import AnsibleModule + + +class RabbitMqPlugins(object): + + def __init__(self, module): + self.module = module + bin_path = '' + if module.params['prefix']: + if os.path.isdir(os.path.join(module.params['prefix'], 'bin')): + bin_path = os.path.join(module.params['prefix'], 'bin') + elif os.path.isdir(os.path.join(module.params['prefix'], 'sbin')): + bin_path = os.path.join(module.params['prefix'], 'sbin') + else: + # No such path exists. + module.fail_json(msg="No binary folder in prefix %s" % module.params['prefix']) + + self._rabbitmq_plugins = os.path.join(bin_path, "rabbitmq-plugins") + else: + self._rabbitmq_plugins = module.get_bin_path('rabbitmq-plugins', True) + + def _exec(self, args, run_in_check_mode=False): + if not self.module.check_mode or (self.module.check_mode and run_in_check_mode): + cmd = [self._rabbitmq_plugins] + rc, out, err = self.module.run_command(cmd + args, check_rc=True) + return out.splitlines() + return list() + + def get_all(self): + list_output = self._exec(['list', '-E', '-m'], True) + plugins = [] + for plugin in list_output: + if not plugin: + break + plugins.append(plugin) + + return plugins + + def enable(self, name): + self._exec(['enable', name]) + + def disable(self, name): + self._exec(['disable', name]) + + +def main(): + arg_spec = dict( + names=dict(required=True, aliases=['name']), + new_only=dict(default='no', type='bool'), + state=dict(default='enabled', choices=['enabled', 'disabled']), + prefix=dict(required=False, default=None) + ) + module = AnsibleModule( + argument_spec=arg_spec, + supports_check_mode=True + ) + + result = dict() + names = module.params['names'].split(',') + new_only = module.params['new_only'] + state = module.params['state'] + + rabbitmq_plugins = RabbitMqPlugins(module) + enabled_plugins = rabbitmq_plugins.get_all() + + enabled = [] + disabled = [] + if state == 'enabled': + if not new_only: + for plugin in enabled_plugins: + if " " in plugin: + continue + if plugin not in names: + rabbitmq_plugins.disable(plugin) + disabled.append(plugin) + + for name in names: + if name not in enabled_plugins: + rabbitmq_plugins.enable(name) + enabled.append(name) + else: + for plugin in enabled_plugins: + if plugin in names: + rabbitmq_plugins.disable(plugin) + disabled.append(plugin) + + result['changed'] = len(enabled) > 0 or len(disabled) > 0 + result['enabled'] = enabled + result['disabled'] = disabled + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/rabbitmq_queue.py b/test/support/integration/plugins/modules/rabbitmq_queue.py new file mode 100644 index 00000000..567ec813 --- /dev/null +++ b/test/support/integration/plugins/modules/rabbitmq_queue.py @@ -0,0 +1,257 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Manuel Sousa <manuel.sousa@gmail.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: rabbitmq_queue +author: Manuel Sousa (@manuel-sousa) +version_added: "2.0" + +short_description: Manage rabbitMQ queues +description: + - This module uses rabbitMQ Rest API to create/delete queues +requirements: [ "requests >= 1.0.0" ] +options: + name: + description: + - Name of the queue + required: true + state: + description: + - Whether the queue should be present or absent + choices: [ "present", "absent" ] + default: present + durable: + description: + - whether queue is durable or not + type: bool + default: 'yes' + auto_delete: + description: + - if the queue should delete itself after all queues/queues unbound from it + type: bool + default: 'no' + message_ttl: + description: + - How long a message can live in queue before it is discarded (milliseconds) + default: forever + auto_expires: + description: + - How long a queue can be unused before it is automatically deleted (milliseconds) + default: forever + max_length: + description: + - How many messages can the queue contain before it starts rejecting + default: no limit + dead_letter_exchange: + description: + - Optional name of an exchange to which messages will be republished if they + - are rejected or expire + dead_letter_routing_key: + description: + - Optional replacement routing key to use when a message is dead-lettered. + - Original routing key will be used if unset + max_priority: + description: + - Maximum number of priority levels for the queue to support. + - If not set, the queue will not support message priorities. + - Larger numbers indicate higher priority. + version_added: "2.4" + arguments: + description: + - extra arguments for queue. If defined this argument is a key/value dictionary + default: {} +extends_documentation_fragment: + - rabbitmq +''' + +EXAMPLES = ''' +# Create a queue +- rabbitmq_queue: + name: myQueue + +# Create a queue on remote host +- rabbitmq_queue: + name: myRemoteQueue + login_user: user + login_password: secret + login_host: remote.example.org +''' + +import json +import traceback + +REQUESTS_IMP_ERR = None +try: + import requests + HAS_REQUESTS = True +except ImportError: + REQUESTS_IMP_ERR = traceback.format_exc() + HAS_REQUESTS = False + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.six.moves.urllib import parse as urllib_parse +from ansible.module_utils.rabbitmq import rabbitmq_argument_spec + + +def main(): + + argument_spec = rabbitmq_argument_spec() + argument_spec.update( + dict( + state=dict(default='present', choices=['present', 'absent'], type='str'), + name=dict(required=True, type='str'), + durable=dict(default=True, type='bool'), + auto_delete=dict(default=False, type='bool'), + message_ttl=dict(default=None, type='int'), + auto_expires=dict(default=None, type='int'), + max_length=dict(default=None, type='int'), + dead_letter_exchange=dict(default=None, type='str'), + dead_letter_routing_key=dict(default=None, type='str'), + arguments=dict(default=dict(), type='dict'), + max_priority=dict(default=None, type='int') + ) + ) + module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=True) + + url = "%s://%s:%s/api/queues/%s/%s" % ( + module.params['login_protocol'], + module.params['login_host'], + module.params['login_port'], + urllib_parse.quote(module.params['vhost'], ''), + module.params['name'] + ) + + if not HAS_REQUESTS: + module.fail_json(msg=missing_required_lib("requests"), exception=REQUESTS_IMP_ERR) + + result = dict(changed=False, name=module.params['name']) + + # Check if queue already exists + r = requests.get(url, auth=(module.params['login_user'], module.params['login_password']), + verify=module.params['ca_cert'], cert=(module.params['client_cert'], module.params['client_key'])) + + if r.status_code == 200: + queue_exists = True + response = r.json() + elif r.status_code == 404: + queue_exists = False + response = r.text + else: + module.fail_json( + msg="Invalid response from RESTAPI when trying to check if queue exists", + details=r.text + ) + + if module.params['state'] == 'present': + change_required = not queue_exists + else: + change_required = queue_exists + + # Check if attributes change on existing queue + if not change_required and r.status_code == 200 and module.params['state'] == 'present': + if not ( + response['durable'] == module.params['durable'] and + response['auto_delete'] == module.params['auto_delete'] and + ( + ('x-message-ttl' in response['arguments'] and response['arguments']['x-message-ttl'] == module.params['message_ttl']) or + ('x-message-ttl' not in response['arguments'] and module.params['message_ttl'] is None) + ) and + ( + ('x-expires' in response['arguments'] and response['arguments']['x-expires'] == module.params['auto_expires']) or + ('x-expires' not in response['arguments'] and module.params['auto_expires'] is None) + ) and + ( + ('x-max-length' in response['arguments'] and response['arguments']['x-max-length'] == module.params['max_length']) or + ('x-max-length' not in response['arguments'] and module.params['max_length'] is None) + ) and + ( + ('x-dead-letter-exchange' in response['arguments'] and + response['arguments']['x-dead-letter-exchange'] == module.params['dead_letter_exchange']) or + ('x-dead-letter-exchange' not in response['arguments'] and module.params['dead_letter_exchange'] is None) + ) and + ( + ('x-dead-letter-routing-key' in response['arguments'] and + response['arguments']['x-dead-letter-routing-key'] == module.params['dead_letter_routing_key']) or + ('x-dead-letter-routing-key' not in response['arguments'] and module.params['dead_letter_routing_key'] is None) + ) and + ( + ('x-max-priority' in response['arguments'] and + response['arguments']['x-max-priority'] == module.params['max_priority']) or + ('x-max-priority' not in response['arguments'] and module.params['max_priority'] is None) + ) + ): + module.fail_json( + msg="RabbitMQ RESTAPI doesn't support attribute changes for existing queues", + ) + + # Copy parameters to arguments as used by RabbitMQ + for k, v in { + 'message_ttl': 'x-message-ttl', + 'auto_expires': 'x-expires', + 'max_length': 'x-max-length', + 'dead_letter_exchange': 'x-dead-letter-exchange', + 'dead_letter_routing_key': 'x-dead-letter-routing-key', + 'max_priority': 'x-max-priority' + }.items(): + if module.params[k] is not None: + module.params['arguments'][v] = module.params[k] + + # Exit if check_mode + if module.check_mode: + result['changed'] = change_required + result['details'] = response + result['arguments'] = module.params['arguments'] + module.exit_json(**result) + + # Do changes + if change_required: + if module.params['state'] == 'present': + r = requests.put( + url, + auth=(module.params['login_user'], module.params['login_password']), + headers={"content-type": "application/json"}, + data=json.dumps({ + "durable": module.params['durable'], + "auto_delete": module.params['auto_delete'], + "arguments": module.params['arguments'] + }), + verify=module.params['ca_cert'], + cert=(module.params['client_cert'], module.params['client_key']) + ) + elif module.params['state'] == 'absent': + r = requests.delete(url, auth=(module.params['login_user'], module.params['login_password']), + verify=module.params['ca_cert'], cert=(module.params['client_cert'], module.params['client_key'])) + + # RabbitMQ 3.6.7 changed this response code from 204 to 201 + if r.status_code == 204 or r.status_code == 201: + result['changed'] = True + module.exit_json(**result) + else: + module.fail_json( + msg="Error creating queue", + status=r.status_code, + details=r.text + ) + + else: + module.exit_json( + changed=False, + name=module.params['name'] + ) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/s3_bucket.py b/test/support/integration/plugins/modules/s3_bucket.py new file mode 100644 index 00000000..f35cf53b --- /dev/null +++ b/test/support/integration/plugins/modules/s3_bucket.py @@ -0,0 +1,740 @@ +#!/usr/bin/python +# +# This is a free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This Ansible library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: s3_bucket +short_description: Manage S3 buckets in AWS, DigitalOcean, Ceph, Walrus, FakeS3 and StorageGRID +description: + - Manage S3 buckets in AWS, DigitalOcean, Ceph, Walrus, FakeS3 and StorageGRID +version_added: "2.0" +requirements: [ boto3 ] +author: "Rob White (@wimnat)" +options: + force: + description: + - When trying to delete a bucket, delete all keys (including versions and delete markers) + in the bucket first (an s3 bucket must be empty for a successful deletion) + type: bool + default: 'no' + name: + description: + - Name of the s3 bucket + required: true + type: str + policy: + description: + - The JSON policy as a string. + type: json + s3_url: + description: + - S3 URL endpoint for usage with DigitalOcean, Ceph, Eucalyptus and fakes3 etc. + - Assumes AWS if not specified. + - For Walrus, use FQDN of the endpoint without scheme nor path. + aliases: [ S3_URL ] + type: str + ceph: + description: + - Enable API compatibility with Ceph. It takes into account the S3 API subset working + with Ceph in order to provide the same module behaviour where possible. + type: bool + version_added: "2.2" + requester_pays: + description: + - With Requester Pays buckets, the requester instead of the bucket owner pays the cost + of the request and the data download from the bucket. + type: bool + default: False + state: + description: + - Create or remove the s3 bucket + required: false + default: present + choices: [ 'present', 'absent' ] + type: str + tags: + description: + - tags dict to apply to bucket + type: dict + purge_tags: + description: + - whether to remove tags that aren't present in the C(tags) parameter + type: bool + default: True + version_added: "2.9" + versioning: + description: + - Whether versioning is enabled or disabled (note that once versioning is enabled, it can only be suspended) + type: bool + encryption: + description: + - Describes the default server-side encryption to apply to new objects in the bucket. + In order to remove the server-side encryption, the encryption needs to be set to 'none' explicitly. + choices: [ 'none', 'AES256', 'aws:kms' ] + version_added: "2.9" + type: str + encryption_key_id: + description: KMS master key ID to use for the default encryption. This parameter is allowed if encryption is aws:kms. If + not specified then it will default to the AWS provided KMS key. + version_added: "2.9" + type: str +extends_documentation_fragment: + - aws + - ec2 +notes: + - If C(requestPayment), C(policy), C(tagging) or C(versioning) + operations/API aren't implemented by the endpoint, module doesn't fail + if each parameter satisfies the following condition. + I(requester_pays) is C(False), I(policy), I(tags), and I(versioning) are C(None). +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +# Create a simple s3 bucket +- s3_bucket: + name: mys3bucket + state: present + +# Create a simple s3 bucket on Ceph Rados Gateway +- s3_bucket: + name: mys3bucket + s3_url: http://your-ceph-rados-gateway-server.xxx + ceph: true + +# Remove an s3 bucket and any keys it contains +- s3_bucket: + name: mys3bucket + state: absent + force: yes + +# Create a bucket, add a policy from a file, enable requester pays, enable versioning and tag +- s3_bucket: + name: mys3bucket + policy: "{{ lookup('file','policy.json') }}" + requester_pays: yes + versioning: yes + tags: + example: tag1 + another: tag2 + +# Create a simple DigitalOcean Spaces bucket using their provided regional endpoint +- s3_bucket: + name: mydobucket + s3_url: 'https://nyc3.digitaloceanspaces.com' + +# Create a bucket with AES256 encryption +- s3_bucket: + name: mys3bucket + state: present + encryption: "AES256" + +# Create a bucket with aws:kms encryption, KMS key +- s3_bucket: + name: mys3bucket + state: present + encryption: "aws:kms" + encryption_key_id: "arn:aws:kms:us-east-1:1234/5678example" + +# Create a bucket with aws:kms encryption, default key +- s3_bucket: + name: mys3bucket + state: present + encryption: "aws:kms" +''' + +import json +import os +import time + +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ansible.module_utils.six import string_types +from ansible.module_utils.basic import to_text +from ansible.module_utils.aws.core import AnsibleAWSModule, is_boto3_error_code +from ansible.module_utils.ec2 import compare_policies, ec2_argument_spec, boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list +from ansible.module_utils.ec2 import get_aws_connection_info, boto3_conn, AWSRetry + +try: + from botocore.exceptions import BotoCoreError, ClientError, EndpointConnectionError, WaiterError +except ImportError: + pass # handled by AnsibleAWSModule + + +def create_or_update_bucket(s3_client, module, location): + + policy = module.params.get("policy") + name = module.params.get("name") + requester_pays = module.params.get("requester_pays") + tags = module.params.get("tags") + purge_tags = module.params.get("purge_tags") + versioning = module.params.get("versioning") + encryption = module.params.get("encryption") + encryption_key_id = module.params.get("encryption_key_id") + changed = False + result = {} + + try: + bucket_is_present = bucket_exists(s3_client, name) + except EndpointConnectionError as e: + module.fail_json_aws(e, msg="Invalid endpoint provided: %s" % to_text(e)) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to check bucket presence") + + if not bucket_is_present: + try: + bucket_changed = create_bucket(s3_client, name, location) + s3_client.get_waiter('bucket_exists').wait(Bucket=name) + changed = changed or bucket_changed + except WaiterError as e: + module.fail_json_aws(e, msg='An error occurred waiting for the bucket to become available') + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed while creating bucket") + + # Versioning + try: + versioning_status = get_bucket_versioning(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket versioning") + except ClientError as exp: + if exp.response['Error']['Code'] != 'NotImplemented' or versioning is not None: + module.fail_json_aws(exp, msg="Failed to get bucket versioning") + else: + if versioning is not None: + required_versioning = None + if versioning and versioning_status.get('Status') != "Enabled": + required_versioning = 'Enabled' + elif not versioning and versioning_status.get('Status') == "Enabled": + required_versioning = 'Suspended' + + if required_versioning: + try: + put_bucket_versioning(s3_client, name, required_versioning) + changed = True + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to update bucket versioning") + + versioning_status = wait_versioning_is_applied(module, s3_client, name, required_versioning) + + # This output format is there to ensure compatibility with previous versions of the module + result['versioning'] = { + 'Versioning': versioning_status.get('Status', 'Disabled'), + 'MfaDelete': versioning_status.get('MFADelete', 'Disabled'), + } + + # Requester pays + try: + requester_pays_status = get_bucket_request_payment(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket request payment") + except ClientError as exp: + if exp.response['Error']['Code'] not in ('NotImplemented', 'XNotImplemented') or requester_pays: + module.fail_json_aws(exp, msg="Failed to get bucket request payment") + else: + if requester_pays: + payer = 'Requester' if requester_pays else 'BucketOwner' + if requester_pays_status != payer: + put_bucket_request_payment(s3_client, name, payer) + requester_pays_status = wait_payer_is_applied(module, s3_client, name, payer, should_fail=False) + if requester_pays_status is None: + # We have seen that it happens quite a lot of times that the put request was not taken into + # account, so we retry one more time + put_bucket_request_payment(s3_client, name, payer) + requester_pays_status = wait_payer_is_applied(module, s3_client, name, payer, should_fail=True) + changed = True + + result['requester_pays'] = requester_pays + + # Policy + try: + current_policy = get_bucket_policy(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket policy") + except ClientError as exp: + if exp.response['Error']['Code'] != 'NotImplemented' or policy is not None: + module.fail_json_aws(exp, msg="Failed to get bucket policy") + else: + if policy is not None: + if isinstance(policy, string_types): + policy = json.loads(policy) + + if not policy and current_policy: + try: + delete_bucket_policy(s3_client, name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket policy") + current_policy = wait_policy_is_applied(module, s3_client, name, policy) + changed = True + elif compare_policies(current_policy, policy): + try: + put_bucket_policy(s3_client, name, policy) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to update bucket policy") + current_policy = wait_policy_is_applied(module, s3_client, name, policy, should_fail=False) + if current_policy is None: + # As for request payement, it happens quite a lot of times that the put request was not taken into + # account, so we retry one more time + put_bucket_policy(s3_client, name, policy) + current_policy = wait_policy_is_applied(module, s3_client, name, policy, should_fail=True) + changed = True + + result['policy'] = current_policy + + # Tags + try: + current_tags_dict = get_current_bucket_tags_dict(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket tags") + except ClientError as exp: + if exp.response['Error']['Code'] not in ('NotImplemented', 'XNotImplemented') or tags is not None: + module.fail_json_aws(exp, msg="Failed to get bucket tags") + else: + if tags is not None: + # Tags are always returned as text + tags = dict((to_text(k), to_text(v)) for k, v in tags.items()) + if not purge_tags: + # Ensure existing tags that aren't updated by desired tags remain + current_copy = current_tags_dict.copy() + current_copy.update(tags) + tags = current_copy + if current_tags_dict != tags: + if tags: + try: + put_bucket_tagging(s3_client, name, tags) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to update bucket tags") + else: + if purge_tags: + try: + delete_bucket_tagging(s3_client, name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket tags") + current_tags_dict = wait_tags_are_applied(module, s3_client, name, tags) + changed = True + + result['tags'] = current_tags_dict + + # Encryption + if hasattr(s3_client, "get_bucket_encryption"): + try: + current_encryption = get_bucket_encryption(s3_client, name) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to get bucket encryption") + elif encryption is not None: + module.fail_json(msg="Using bucket encryption requires botocore version >= 1.7.41") + + if encryption is not None: + current_encryption_algorithm = current_encryption.get('SSEAlgorithm') if current_encryption else None + current_encryption_key = current_encryption.get('KMSMasterKeyID') if current_encryption else None + if encryption == 'none' and current_encryption_algorithm is not None: + try: + delete_bucket_encryption(s3_client, name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket encryption") + current_encryption = wait_encryption_is_applied(module, s3_client, name, None) + changed = True + elif encryption != 'none' and (encryption != current_encryption_algorithm) or (encryption == 'aws:kms' and current_encryption_key != encryption_key_id): + expected_encryption = {'SSEAlgorithm': encryption} + if encryption == 'aws:kms' and encryption_key_id is not None: + expected_encryption.update({'KMSMasterKeyID': encryption_key_id}) + try: + put_bucket_encryption(s3_client, name, expected_encryption) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to set bucket encryption") + current_encryption = wait_encryption_is_applied(module, s3_client, name, expected_encryption) + changed = True + + result['encryption'] = current_encryption + + module.exit_json(changed=changed, name=name, **result) + + +def bucket_exists(s3_client, bucket_name): + # head_bucket appeared to be really inconsistent, so we use list_buckets instead, + # and loop over all the buckets, even if we know it's less performant :( + all_buckets = s3_client.list_buckets(Bucket=bucket_name)['Buckets'] + return any(bucket['Name'] == bucket_name for bucket in all_buckets) + + +@AWSRetry.exponential_backoff(max_delay=120) +def create_bucket(s3_client, bucket_name, location): + try: + configuration = {} + if location not in ('us-east-1', None): + configuration['LocationConstraint'] = location + if len(configuration) > 0: + s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration=configuration) + else: + s3_client.create_bucket(Bucket=bucket_name) + return True + except ClientError as e: + error_code = e.response['Error']['Code'] + if error_code == 'BucketAlreadyOwnedByYou': + # We should never get there since we check the bucket presence before calling the create_or_update_bucket + # method. However, the AWS Api sometimes fails to report bucket presence, so we catch this exception + return False + else: + raise e + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_tagging(s3_client, bucket_name, tags): + s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={'TagSet': ansible_dict_to_boto3_tag_list(tags)}) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_policy(s3_client, bucket_name, policy): + s3_client.put_bucket_policy(Bucket=bucket_name, Policy=json.dumps(policy)) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def delete_bucket_policy(s3_client, bucket_name): + s3_client.delete_bucket_policy(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_policy(s3_client, bucket_name): + try: + current_policy = json.loads(s3_client.get_bucket_policy(Bucket=bucket_name).get('Policy')) + except ClientError as e: + if e.response['Error']['Code'] == 'NoSuchBucketPolicy': + current_policy = None + else: + raise e + return current_policy + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_request_payment(s3_client, bucket_name, payer): + s3_client.put_bucket_request_payment(Bucket=bucket_name, RequestPaymentConfiguration={'Payer': payer}) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_request_payment(s3_client, bucket_name): + return s3_client.get_bucket_request_payment(Bucket=bucket_name).get('Payer') + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_versioning(s3_client, bucket_name): + return s3_client.get_bucket_versioning(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_versioning(s3_client, bucket_name, required_versioning): + s3_client.put_bucket_versioning(Bucket=bucket_name, VersioningConfiguration={'Status': required_versioning}) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_encryption(s3_client, bucket_name): + try: + result = s3_client.get_bucket_encryption(Bucket=bucket_name) + return result.get('ServerSideEncryptionConfiguration', {}).get('Rules', [])[0].get('ApplyServerSideEncryptionByDefault') + except ClientError as e: + if e.response['Error']['Code'] == 'ServerSideEncryptionConfigurationNotFoundError': + return None + else: + raise e + except (IndexError, KeyError): + return None + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_encryption(s3_client, bucket_name, encryption): + server_side_encryption_configuration = {'Rules': [{'ApplyServerSideEncryptionByDefault': encryption}]} + s3_client.put_bucket_encryption(Bucket=bucket_name, ServerSideEncryptionConfiguration=server_side_encryption_configuration) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def delete_bucket_tagging(s3_client, bucket_name): + s3_client.delete_bucket_tagging(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def delete_bucket_encryption(s3_client, bucket_name): + s3_client.delete_bucket_encryption(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120) +def delete_bucket(s3_client, bucket_name): + try: + s3_client.delete_bucket(Bucket=bucket_name) + except ClientError as e: + if e.response['Error']['Code'] == 'NoSuchBucket': + # This means bucket should have been in a deleting state when we checked it existence + # We just ignore the error + pass + else: + raise e + + +def wait_policy_is_applied(module, s3_client, bucket_name, expected_policy, should_fail=True): + for dummy in range(0, 12): + try: + current_policy = get_bucket_policy(s3_client, bucket_name) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to get bucket policy") + + if compare_policies(current_policy, expected_policy): + time.sleep(5) + else: + return current_policy + if should_fail: + module.fail_json(msg="Bucket policy failed to apply in the expected time") + else: + return None + + +def wait_payer_is_applied(module, s3_client, bucket_name, expected_payer, should_fail=True): + for dummy in range(0, 12): + try: + requester_pays_status = get_bucket_request_payment(s3_client, bucket_name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get bucket request payment") + if requester_pays_status != expected_payer: + time.sleep(5) + else: + return requester_pays_status + if should_fail: + module.fail_json(msg="Bucket request payment failed to apply in the expected time") + else: + return None + + +def wait_encryption_is_applied(module, s3_client, bucket_name, expected_encryption): + for dummy in range(0, 12): + try: + encryption = get_bucket_encryption(s3_client, bucket_name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get updated encryption for bucket") + if encryption != expected_encryption: + time.sleep(5) + else: + return encryption + module.fail_json(msg="Bucket encryption failed to apply in the expected time") + + +def wait_versioning_is_applied(module, s3_client, bucket_name, required_versioning): + for dummy in range(0, 24): + try: + versioning_status = get_bucket_versioning(s3_client, bucket_name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get updated versioning for bucket") + if versioning_status.get('Status') != required_versioning: + time.sleep(8) + else: + return versioning_status + module.fail_json(msg="Bucket versioning failed to apply in the expected time") + + +def wait_tags_are_applied(module, s3_client, bucket_name, expected_tags_dict): + for dummy in range(0, 12): + try: + current_tags_dict = get_current_bucket_tags_dict(s3_client, bucket_name) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to get bucket policy") + if current_tags_dict != expected_tags_dict: + time.sleep(5) + else: + return current_tags_dict + module.fail_json(msg="Bucket tags failed to apply in the expected time") + + +def get_current_bucket_tags_dict(s3_client, bucket_name): + try: + current_tags = s3_client.get_bucket_tagging(Bucket=bucket_name).get('TagSet') + except ClientError as e: + if e.response['Error']['Code'] == 'NoSuchTagSet': + return {} + raise e + + return boto3_tag_list_to_ansible_dict(current_tags) + + +def paginated_list(s3_client, **pagination_params): + pg = s3_client.get_paginator('list_objects_v2') + for page in pg.paginate(**pagination_params): + yield [data['Key'] for data in page.get('Contents', [])] + + +def paginated_versions_list(s3_client, **pagination_params): + try: + pg = s3_client.get_paginator('list_object_versions') + for page in pg.paginate(**pagination_params): + # We have to merge the Versions and DeleteMarker lists here, as DeleteMarkers can still prevent a bucket deletion + yield [(data['Key'], data['VersionId']) for data in (page.get('Versions', []) + page.get('DeleteMarkers', []))] + except is_boto3_error_code('NoSuchBucket'): + yield [] + + +def destroy_bucket(s3_client, module): + + force = module.params.get("force") + name = module.params.get("name") + try: + bucket_is_present = bucket_exists(s3_client, name) + except EndpointConnectionError as e: + module.fail_json_aws(e, msg="Invalid endpoint provided: %s" % to_text(e)) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to check bucket presence") + + if not bucket_is_present: + module.exit_json(changed=False) + + if force: + # if there are contents then we need to delete them (including versions) before we can delete the bucket + try: + for key_version_pairs in paginated_versions_list(s3_client, Bucket=name): + formatted_keys = [{'Key': key, 'VersionId': version} for key, version in key_version_pairs] + for fk in formatted_keys: + # remove VersionId from cases where they are `None` so that + # unversioned objects are deleted using `DeleteObject` + # rather than `DeleteObjectVersion`, improving backwards + # compatibility with older IAM policies. + if not fk.get('VersionId'): + fk.pop('VersionId') + + if formatted_keys: + resp = s3_client.delete_objects(Bucket=name, Delete={'Objects': formatted_keys}) + if resp.get('Errors'): + module.fail_json( + msg='Could not empty bucket before deleting. Could not delete objects: {0}'.format( + ', '.join([k['Key'] for k in resp['Errors']]) + ), + errors=resp['Errors'], response=resp + ) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed while deleting bucket") + + try: + delete_bucket(s3_client, name) + s3_client.get_waiter('bucket_not_exists').wait(Bucket=name, WaiterConfig=dict(Delay=5, MaxAttempts=60)) + except WaiterError as e: + module.fail_json_aws(e, msg='An error occurred waiting for the bucket to be deleted.') + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket") + + module.exit_json(changed=True) + + +def is_fakes3(s3_url): + """ Return True if s3_url has scheme fakes3:// """ + if s3_url is not None: + return urlparse(s3_url).scheme in ('fakes3', 'fakes3s') + else: + return False + + +def get_s3_client(module, aws_connect_kwargs, location, ceph, s3_url): + if s3_url and ceph: # TODO - test this + ceph = urlparse(s3_url) + params = dict(module=module, conn_type='client', resource='s3', use_ssl=ceph.scheme == 'https', region=location, endpoint=s3_url, **aws_connect_kwargs) + elif is_fakes3(s3_url): + fakes3 = urlparse(s3_url) + port = fakes3.port + if fakes3.scheme == 'fakes3s': + protocol = "https" + if port is None: + port = 443 + else: + protocol = "http" + if port is None: + port = 80 + params = dict(module=module, conn_type='client', resource='s3', region=location, + endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)), + use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs) + else: + params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=s3_url, **aws_connect_kwargs) + return boto3_conn(**params) + + +def main(): + + argument_spec = ec2_argument_spec() + argument_spec.update( + dict( + force=dict(default=False, type='bool'), + policy=dict(type='json'), + name=dict(required=True), + requester_pays=dict(default=False, type='bool'), + s3_url=dict(aliases=['S3_URL']), + state=dict(default='present', choices=['present', 'absent']), + tags=dict(type='dict'), + purge_tags=dict(type='bool', default=True), + versioning=dict(type='bool'), + ceph=dict(default=False, type='bool'), + encryption=dict(choices=['none', 'AES256', 'aws:kms']), + encryption_key_id=dict() + ) + ) + + module = AnsibleAWSModule( + argument_spec=argument_spec, + ) + + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True) + + if region in ('us-east-1', '', None): + # default to US Standard region + location = 'us-east-1' + else: + # Boto uses symbolic names for locations but region strings will + # actually work fine for everything except us-east-1 (US Standard) + location = region + + s3_url = module.params.get('s3_url') + ceph = module.params.get('ceph') + + # allow eucarc environment variables to be used if ansible vars aren't set + if not s3_url and 'S3_URL' in os.environ: + s3_url = os.environ['S3_URL'] + + if ceph and not s3_url: + module.fail_json(msg='ceph flavour requires s3_url') + + # Look at s3_url and tweak connection settings + # if connecting to Ceph RGW, Walrus or fakes3 + if s3_url: + for key in ['validate_certs', 'security_token', 'profile_name']: + aws_connect_kwargs.pop(key, None) + s3_client = get_s3_client(module, aws_connect_kwargs, location, ceph, s3_url) + + if s3_client is None: # this should never happen + module.fail_json(msg='Unknown error, failed to create s3 connection, no information from boto.') + + state = module.params.get("state") + encryption = module.params.get("encryption") + encryption_key_id = module.params.get("encryption_key_id") + + # Parameter validation + if encryption_key_id is not None and encryption is None: + module.fail_json(msg="You must specify encryption parameter along with encryption_key_id.") + elif encryption_key_id is not None and encryption != 'aws:kms': + module.fail_json(msg="Only 'aws:kms' is a valid option for encryption parameter when you specify encryption_key_id.") + + if state == 'present': + create_or_update_bucket(s3_client, module, location) + elif state == 'absent': + destroy_bucket(s3_client, module) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/sefcontext.py b/test/support/integration/plugins/modules/sefcontext.py new file mode 100644 index 00000000..33e3fd2e --- /dev/null +++ b/test/support/integration/plugins/modules/sefcontext.py @@ -0,0 +1,298 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: sefcontext +short_description: Manages SELinux file context mapping definitions +description: +- Manages SELinux file context mapping definitions. +- Similar to the C(semanage fcontext) command. +version_added: '2.2' +options: + target: + description: + - Target path (expression). + type: str + required: yes + aliases: [ path ] + ftype: + description: + - The file type that should have SELinux contexts applied. + - "The following file type options are available:" + - C(a) for all files, + - C(b) for block devices, + - C(c) for character devices, + - C(d) for directories, + - C(f) for regular files, + - C(l) for symbolic links, + - C(p) for named pipes, + - C(s) for socket files. + type: str + choices: [ a, b, c, d, f, l, p, s ] + default: a + setype: + description: + - SELinux type for the specified target. + type: str + required: yes + seuser: + description: + - SELinux user for the specified target. + type: str + selevel: + description: + - SELinux range for the specified target. + type: str + aliases: [ serange ] + state: + description: + - Whether the SELinux file context must be C(absent) or C(present). + type: str + choices: [ absent, present ] + default: present + reload: + description: + - Reload SELinux policy after commit. + - Note that this does not apply SELinux file contexts to existing files. + type: bool + default: yes + ignore_selinux_state: + description: + - Useful for scenarios (chrooted environment) that you can't get the real SELinux state. + type: bool + default: no + version_added: '2.8' +notes: +- The changes are persistent across reboots. +- The M(sefcontext) module does not modify existing files to the new + SELinux context(s), so it is advisable to first create the SELinux + file contexts before creating files, or run C(restorecon) manually + for the existing files that require the new SELinux file contexts. +- Not applying SELinux fcontexts to existing files is a deliberate + decision as it would be unclear what reported changes would entail + to, and there's no guarantee that applying SELinux fcontext does + not pick up other unrelated prior changes. +requirements: +- libselinux-python +- policycoreutils-python +author: +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r''' +- name: Allow apache to modify files in /srv/git_repos + sefcontext: + target: '/srv/git_repos(/.*)?' + setype: httpd_git_rw_content_t + state: present + +- name: Apply new SELinux file context to filesystem + command: restorecon -irv /srv/git_repos +''' + +RETURN = r''' +# Default return values +''' + +import traceback + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native + +SELINUX_IMP_ERR = None +try: + import selinux + HAVE_SELINUX = True +except ImportError: + SELINUX_IMP_ERR = traceback.format_exc() + HAVE_SELINUX = False + +SEOBJECT_IMP_ERR = None +try: + import seobject + HAVE_SEOBJECT = True +except ImportError: + SEOBJECT_IMP_ERR = traceback.format_exc() + HAVE_SEOBJECT = False + +# Add missing entries (backward compatible) +if HAVE_SEOBJECT: + seobject.file_types.update( + a=seobject.SEMANAGE_FCONTEXT_ALL, + b=seobject.SEMANAGE_FCONTEXT_BLOCK, + c=seobject.SEMANAGE_FCONTEXT_CHAR, + d=seobject.SEMANAGE_FCONTEXT_DIR, + f=seobject.SEMANAGE_FCONTEXT_REG, + l=seobject.SEMANAGE_FCONTEXT_LINK, + p=seobject.SEMANAGE_FCONTEXT_PIPE, + s=seobject.SEMANAGE_FCONTEXT_SOCK, + ) + +# Make backward compatible +option_to_file_type_str = dict( + a='all files', + b='block device', + c='character device', + d='directory', + f='regular file', + l='symbolic link', + p='named pipe', + s='socket', +) + + +def get_runtime_status(ignore_selinux_state=False): + return True if ignore_selinux_state is True else selinux.is_selinux_enabled() + + +def semanage_fcontext_exists(sefcontext, target, ftype): + ''' Get the SELinux file context mapping definition from policy. Return None if it does not exist. ''' + + # Beware that records comprise of a string representation of the file_type + record = (target, option_to_file_type_str[ftype]) + records = sefcontext.get_all() + try: + return records[record] + except KeyError: + return None + + +def semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser, sestore=''): + ''' Add or modify SELinux file context mapping definition to the policy. ''' + + changed = False + prepared_diff = '' + + try: + sefcontext = seobject.fcontextRecords(sestore) + sefcontext.set_reload(do_reload) + exists = semanage_fcontext_exists(sefcontext, target, ftype) + if exists: + # Modify existing entry + orig_seuser, orig_serole, orig_setype, orig_serange = exists + + if seuser is None: + seuser = orig_seuser + if serange is None: + serange = orig_serange + + if setype != orig_setype or seuser != orig_seuser or serange != orig_serange: + if not module.check_mode: + sefcontext.modify(target, setype, ftype, serange, seuser) + changed = True + + if module._diff: + prepared_diff += '# Change to semanage file context mappings\n' + prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, orig_seuser, orig_serole, orig_setype, orig_serange) + prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, orig_serole, setype, serange) + else: + # Add missing entry + if seuser is None: + seuser = 'system_u' + if serange is None: + serange = 's0' + + if not module.check_mode: + sefcontext.add(target, setype, ftype, serange, seuser) + changed = True + + if module._diff: + prepared_diff += '# Addition to semanage file context mappings\n' + prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, 'object_r', setype, serange) + + except Exception as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e))) + + if module._diff and prepared_diff: + result['diff'] = dict(prepared=prepared_diff) + + module.exit_json(changed=changed, seuser=seuser, serange=serange, **result) + + +def semanage_fcontext_delete(module, result, target, ftype, do_reload, sestore=''): + ''' Delete SELinux file context mapping definition from the policy. ''' + + changed = False + prepared_diff = '' + + try: + sefcontext = seobject.fcontextRecords(sestore) + sefcontext.set_reload(do_reload) + exists = semanage_fcontext_exists(sefcontext, target, ftype) + if exists: + # Remove existing entry + orig_seuser, orig_serole, orig_setype, orig_serange = exists + + if not module.check_mode: + sefcontext.delete(target, ftype) + changed = True + + if module._diff: + prepared_diff += '# Deletion to semanage file context mappings\n' + prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, exists[0], exists[1], exists[2], exists[3]) + + except Exception as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e))) + + if module._diff and prepared_diff: + result['diff'] = dict(prepared=prepared_diff) + + module.exit_json(changed=changed, **result) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + ignore_selinux_state=dict(type='bool', default=False), + target=dict(type='str', required=True, aliases=['path']), + ftype=dict(type='str', default='a', choices=option_to_file_type_str.keys()), + setype=dict(type='str', required=True), + seuser=dict(type='str'), + selevel=dict(type='str', aliases=['serange']), + state=dict(type='str', default='present', choices=['absent', 'present']), + reload=dict(type='bool', default=True), + ), + supports_check_mode=True, + ) + if not HAVE_SELINUX: + module.fail_json(msg=missing_required_lib("libselinux-python"), exception=SELINUX_IMP_ERR) + + if not HAVE_SEOBJECT: + module.fail_json(msg=missing_required_lib("policycoreutils-python"), exception=SEOBJECT_IMP_ERR) + + ignore_selinux_state = module.params['ignore_selinux_state'] + + if not get_runtime_status(ignore_selinux_state): + module.fail_json(msg="SELinux is disabled on this host.") + + target = module.params['target'] + ftype = module.params['ftype'] + setype = module.params['setype'] + seuser = module.params['seuser'] + serange = module.params['selevel'] + state = module.params['state'] + do_reload = module.params['reload'] + + result = dict(target=target, ftype=ftype, setype=setype, state=state) + + if state == 'present': + semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser) + elif state == 'absent': + semanage_fcontext_delete(module, result, target, ftype, do_reload) + else: + module.fail_json(msg='Invalid value of argument "state": {0}'.format(state)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/selogin.py b/test/support/integration/plugins/modules/selogin.py new file mode 100644 index 00000000..6429ef36 --- /dev/null +++ b/test/support/integration/plugins/modules/selogin.py @@ -0,0 +1,260 @@ +#!/usr/bin/python + +# (c) 2017, Petr Lautrbach <plautrba@redhat.com> +# Based on seport.py module (c) 2014, Dan Keder <dan.keder@gmail.com> + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +--- +module: selogin +short_description: Manages linux user to SELinux user mapping +description: + - Manages linux user to SELinux user mapping +version_added: "2.8" +options: + login: + description: + - a Linux user + required: true + seuser: + description: + - SELinux user name + required: true + selevel: + aliases: [ serange ] + description: + - MLS/MCS Security Range (MLS/MCS Systems only) SELinux Range for SELinux login mapping defaults to the SELinux user record range. + default: s0 + state: + description: + - Desired mapping value. + required: true + default: present + choices: [ 'present', 'absent' ] + reload: + description: + - Reload SELinux policy after commit. + default: yes + ignore_selinux_state: + description: + - Run independent of selinux runtime state + type: bool + default: false +notes: + - The changes are persistent across reboots + - Not tested on any debian based system +requirements: [ 'libselinux', 'policycoreutils' ] +author: +- Dan Keder (@dankeder) +- Petr Lautrbach (@bachradsusi) +- James Cassell (@jamescassell) +''' + +EXAMPLES = ''' +# Modify the default user on the system to the guest_u user +- selogin: + login: __default__ + seuser: guest_u + state: present + +# Assign gijoe user on an MLS machine a range and to the staff_u user +- selogin: + login: gijoe + seuser: staff_u + serange: SystemLow-Secret + state: present + +# Assign all users in the engineering group to the staff_u user +- selogin: + login: '%engineering' + seuser: staff_u + state: present +''' + +RETURN = r''' +# Default return values +''' + + +import traceback + +SELINUX_IMP_ERR = None +try: + import selinux + HAVE_SELINUX = True +except ImportError: + SELINUX_IMP_ERR = traceback.format_exc() + HAVE_SELINUX = False + +SEOBJECT_IMP_ERR = None +try: + import seobject + HAVE_SEOBJECT = True +except ImportError: + SEOBJECT_IMP_ERR = traceback.format_exc() + HAVE_SEOBJECT = False + + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native + + +def semanage_login_add(module, login, seuser, do_reload, serange='s0', sestore=''): + """ Add linux user to SELinux user mapping + + :type module: AnsibleModule + :param module: Ansible module + + :type login: str + :param login: a Linux User or a Linux group if it begins with % + + :type seuser: str + :param proto: An SELinux user ('__default__', 'unconfined_u', 'staff_u', ...), see 'semanage login -l' + + :type serange: str + :param serange: SELinux MLS/MCS range (defaults to 's0') + + :type do_reload: bool + :param do_reload: Whether to reload SELinux policy after commit + + :type sestore: str + :param sestore: SELinux store + + :rtype: bool + :return: True if the policy was changed, otherwise False + """ + try: + selogin = seobject.loginRecords(sestore) + selogin.set_reload(do_reload) + change = False + all_logins = selogin.get_all() + # module.fail_json(msg="%s: %s %s" % (all_logins, login, sestore)) + # for local_login in all_logins: + if login not in all_logins.keys(): + change = True + if not module.check_mode: + selogin.add(login, seuser, serange) + else: + if all_logins[login][0] != seuser or all_logins[login][1] != serange: + change = True + if not module.check_mode: + selogin.modify(login, seuser, serange) + + except (ValueError, KeyError, OSError, RuntimeError) as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)), exception=traceback.format_exc()) + + return change + + +def semanage_login_del(module, login, seuser, do_reload, sestore=''): + """ Delete linux user to SELinux user mapping + + :type module: AnsibleModule + :param module: Ansible module + + :type login: str + :param login: a Linux User or a Linux group if it begins with % + + :type seuser: str + :param proto: An SELinux user ('__default__', 'unconfined_u', 'staff_u', ...), see 'semanage login -l' + + :type do_reload: bool + :param do_reload: Whether to reload SELinux policy after commit + + :type sestore: str + :param sestore: SELinux store + + :rtype: bool + :return: True if the policy was changed, otherwise False + """ + try: + selogin = seobject.loginRecords(sestore) + selogin.set_reload(do_reload) + change = False + all_logins = selogin.get_all() + # module.fail_json(msg="%s: %s %s" % (all_logins, login, sestore)) + if login in all_logins.keys(): + change = True + if not module.check_mode: + selogin.delete(login) + + except (ValueError, KeyError, OSError, RuntimeError) as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)), exception=traceback.format_exc()) + + return change + + +def get_runtime_status(ignore_selinux_state=False): + return True if ignore_selinux_state is True else selinux.is_selinux_enabled() + + +def main(): + module = AnsibleModule( + argument_spec=dict( + ignore_selinux_state=dict(type='bool', default=False), + login=dict(type='str', required=True), + seuser=dict(type='str'), + selevel=dict(type='str', aliases=['serange'], default='s0'), + state=dict(type='str', default='present', choices=['absent', 'present']), + reload=dict(type='bool', default=True), + ), + required_if=[ + ["state", "present", ["seuser"]] + ], + supports_check_mode=True + ) + if not HAVE_SELINUX: + module.fail_json(msg=missing_required_lib("libselinux"), exception=SELINUX_IMP_ERR) + + if not HAVE_SEOBJECT: + module.fail_json(msg=missing_required_lib("seobject from policycoreutils"), exception=SEOBJECT_IMP_ERR) + + ignore_selinux_state = module.params['ignore_selinux_state'] + + if not get_runtime_status(ignore_selinux_state): + module.fail_json(msg="SELinux is disabled on this host.") + + login = module.params['login'] + seuser = module.params['seuser'] + serange = module.params['selevel'] + state = module.params['state'] + do_reload = module.params['reload'] + + result = { + 'login': login, + 'seuser': seuser, + 'serange': serange, + 'state': state, + } + + if state == 'present': + result['changed'] = semanage_login_add(module, login, seuser, do_reload, serange) + elif state == 'absent': + result['changed'] = semanage_login_del(module, login, seuser, do_reload) + else: + module.fail_json(msg='Invalid value of argument "state": {0}'.format(state)) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/synchronize.py b/test/support/integration/plugins/modules/synchronize.py new file mode 100644 index 00000000..e4c520b7 --- /dev/null +++ b/test/support/integration/plugins/modules/synchronize.py @@ -0,0 +1,618 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2012-2013, Timothy Appnel <tim@appnel.com> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: synchronize +version_added: "1.4" +short_description: A wrapper around rsync to make common tasks in your playbooks quick and easy +description: + - C(synchronize) is a wrapper around rsync to make common tasks in your playbooks quick and easy. + - It is run and originates on the local host where Ansible is being run. + - Of course, you could just use the C(command) action to call rsync yourself, but you also have to add a fair number of + boilerplate options and host facts. + - This module is not intended to provide access to the full power of rsync, but does make the most common + invocations easier to implement. You `still` may need to call rsync directly via C(command) or C(shell) depending on your use case. +options: + src: + description: + - Path on the source host that will be synchronized to the destination. + - The path can be absolute or relative. + type: str + required: true + dest: + description: + - Path on the destination host that will be synchronized from the source. + - The path can be absolute or relative. + type: str + required: true + dest_port: + description: + - Port number for ssh on the destination host. + - Prior to Ansible 2.0, the ansible_ssh_port inventory var took precedence over this value. + - This parameter defaults to the value of C(ansible_ssh_port) or C(ansible_port), + the C(remote_port) config setting or the value from ssh client configuration + if none of the former have been set. + type: int + version_added: "1.5" + mode: + description: + - Specify the direction of the synchronization. + - In push mode the localhost or delegate is the source. + - In pull mode the remote host in context is the source. + type: str + choices: [ pull, push ] + default: push + archive: + description: + - Mirrors the rsync archive flag, enables recursive, links, perms, times, owner, group flags and -D. + type: bool + default: yes + checksum: + description: + - Skip based on checksum, rather than mod-time & size; Note that that "archive" option is still enabled by default - the "checksum" option will + not disable it. + type: bool + default: no + version_added: "1.6" + compress: + description: + - Compress file data during the transfer. + - In most cases, leave this enabled unless it causes problems. + type: bool + default: yes + version_added: "1.7" + existing_only: + description: + - Skip creating new files on receiver. + type: bool + default: no + version_added: "1.5" + delete: + description: + - Delete files in C(dest) that don't exist (after transfer, not before) in the C(src) path. + - This option requires C(recursive=yes). + - This option ignores excluded files and behaves like the rsync opt --delete-excluded. + type: bool + default: no + dirs: + description: + - Transfer directories without recursing. + type: bool + default: no + recursive: + description: + - Recurse into directories. + - This parameter defaults to the value of the archive option. + type: bool + links: + description: + - Copy symlinks as symlinks. + - This parameter defaults to the value of the archive option. + type: bool + copy_links: + description: + - Copy symlinks as the item that they point to (the referent) is copied, rather than the symlink. + type: bool + default: no + perms: + description: + - Preserve permissions. + - This parameter defaults to the value of the archive option. + type: bool + times: + description: + - Preserve modification times. + - This parameter defaults to the value of the archive option. + type: bool + owner: + description: + - Preserve owner (super user only). + - This parameter defaults to the value of the archive option. + type: bool + group: + description: + - Preserve group. + - This parameter defaults to the value of the archive option. + type: bool + rsync_path: + description: + - Specify the rsync command to run on the remote host. See C(--rsync-path) on the rsync man page. + - To specify the rsync command to run on the local host, you need to set this your task var C(ansible_rsync_path). + type: str + rsync_timeout: + description: + - Specify a C(--timeout) for the rsync command in seconds. + type: int + default: 0 + set_remote_user: + description: + - Put user@ for the remote paths. + - If you have a custom ssh config to define the remote user for a host + that does not match the inventory user, you should set this parameter to C(no). + type: bool + default: yes + use_ssh_args: + description: + - Use the ssh_args specified in ansible.cfg. + type: bool + default: no + version_added: "2.0" + rsync_opts: + description: + - Specify additional rsync options by passing in an array. + - Note that an empty string in C(rsync_opts) will end up transfer the current working directory. + type: list + default: + version_added: "1.6" + partial: + description: + - Tells rsync to keep the partial file which should make a subsequent transfer of the rest of the file much faster. + type: bool + default: no + version_added: "2.0" + verify_host: + description: + - Verify destination host key. + type: bool + default: no + version_added: "2.0" + private_key: + description: + - Specify the private key to use for SSH-based rsync connections (e.g. C(~/.ssh/id_rsa)). + type: path + version_added: "1.6" + link_dest: + description: + - Add a destination to hard link against during the rsync. + type: list + default: + version_added: "2.5" +notes: + - rsync must be installed on both the local and remote host. + - For the C(synchronize) module, the "local host" is the host `the synchronize task originates on`, and the "destination host" is the host + `synchronize is connecting to`. + - The "local host" can be changed to a different host by using `delegate_to`. This enables copying between two remote hosts or entirely on one + remote machine. + - > + The user and permissions for the synchronize `src` are those of the user running the Ansible task on the local host (or the remote_user for a + delegate_to host when delegate_to is used). + - The user and permissions for the synchronize `dest` are those of the `remote_user` on the destination host or the `become_user` if `become=yes` is active. + - In Ansible 2.0 a bug in the synchronize module made become occur on the "local host". This was fixed in Ansible 2.0.1. + - Currently, synchronize is limited to elevating permissions via passwordless sudo. This is because rsync itself is connecting to the remote machine + and rsync doesn't give us a way to pass sudo credentials in. + - Currently there are only a few connection types which support synchronize (ssh, paramiko, local, and docker) because a sync strategy has been + determined for those connection types. Note that the connection for these must not need a password as rsync itself is making the connection and + rsync does not provide us a way to pass a password to the connection. + - Expect that dest=~/x will be ~<remote_user>/x even if using sudo. + - Inspect the verbose output to validate the destination user/host/path are what was expected. + - To exclude files and directories from being synchronized, you may add C(.rsync-filter) files to the source directory. + - rsync daemon must be up and running with correct permission when using rsync protocol in source or destination path. + - The C(synchronize) module forces `--delay-updates` to avoid leaving a destination in a broken in-between state if the underlying rsync process + encounters an error. Those synchronizing large numbers of files that are willing to trade safety for performance should call rsync directly. + - link_destination is subject to the same limitations as the underlying rsync daemon. Hard links are only preserved if the relative subtrees + of the source and destination are the same. Attempts to hardlink into a directory that is a subdirectory of the source will be prevented. +seealso: +- module: copy +- module: win_robocopy +author: +- Timothy Appnel (@tima) +''' + +EXAMPLES = ''' +- name: Synchronization of src on the control machine to dest on the remote hosts + synchronize: + src: some/relative/path + dest: /some/absolute/path + +- name: Synchronization using rsync protocol (push) + synchronize: + src: some/relative/path/ + dest: rsync://somehost.com/path/ + +- name: Synchronization using rsync protocol (pull) + synchronize: + mode: pull + src: rsync://somehost.com/path/ + dest: /some/absolute/path/ + +- name: Synchronization using rsync protocol on delegate host (push) + synchronize: + src: /some/absolute/path/ + dest: rsync://somehost.com/path/ + delegate_to: delegate.host + +- name: Synchronization using rsync protocol on delegate host (pull) + synchronize: + mode: pull + src: rsync://somehost.com/path/ + dest: /some/absolute/path/ + delegate_to: delegate.host + +- name: Synchronization without any --archive options enabled + synchronize: + src: some/relative/path + dest: /some/absolute/path + archive: no + +- name: Synchronization with --archive options enabled except for --recursive + synchronize: + src: some/relative/path + dest: /some/absolute/path + recursive: no + +- name: Synchronization with --archive options enabled except for --times, with --checksum option enabled + synchronize: + src: some/relative/path + dest: /some/absolute/path + checksum: yes + times: no + +- name: Synchronization without --archive options enabled except use --links + synchronize: + src: some/relative/path + dest: /some/absolute/path + archive: no + links: yes + +- name: Synchronization of two paths both on the control machine + synchronize: + src: some/relative/path + dest: /some/absolute/path + delegate_to: localhost + +- name: Synchronization of src on the inventory host to the dest on the localhost in pull mode + synchronize: + mode: pull + src: some/relative/path + dest: /some/absolute/path + +- name: Synchronization of src on delegate host to dest on the current inventory host. + synchronize: + src: /first/absolute/path + dest: /second/absolute/path + delegate_to: delegate.host + +- name: Synchronize two directories on one remote host. + synchronize: + src: /first/absolute/path + dest: /second/absolute/path + delegate_to: "{{ inventory_hostname }}" + +- name: Synchronize and delete files in dest on the remote host that are not found in src of localhost. + synchronize: + src: some/relative/path + dest: /some/absolute/path + delete: yes + recursive: yes + +# This specific command is granted su privileges on the destination +- name: Synchronize using an alternate rsync command + synchronize: + src: some/relative/path + dest: /some/absolute/path + rsync_path: su -c rsync + +# Example .rsync-filter file in the source directory +# - var # exclude any path whose last part is 'var' +# - /var # exclude any path starting with 'var' starting at the source directory +# + /var/conf # include /var/conf even though it was previously excluded + +- name: Synchronize passing in extra rsync options + synchronize: + src: /tmp/helloworld + dest: /var/www/helloworld + rsync_opts: + - "--no-motd" + - "--exclude=.git" + +# Hardlink files if they didn't change +- name: Use hardlinks when synchronizing filesystems + synchronize: + src: /tmp/path_a/foo.txt + dest: /tmp/path_b/foo.txt + link_dest: /tmp/path_a/ + +# Specify the rsync binary to use on remote host and on local host +- hosts: groupofhosts + vars: + ansible_rsync_path: /usr/gnu/bin/rsync + + tasks: + - name: copy /tmp/localpath/ to remote location /tmp/remotepath + synchronize: + src: /tmp/localpath/ + dest: /tmp/remotepath + rsync_path: /usr/gnu/bin/rsync +''' + + +import os +import errno + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_bytes, to_native +from ansible.module_utils.six.moves import shlex_quote + + +client_addr = None + + +def substitute_controller(path): + global client_addr + if not client_addr: + ssh_env_string = os.environ.get('SSH_CLIENT', None) + try: + client_addr, _ = ssh_env_string.split(None, 1) + except AttributeError: + ssh_env_string = os.environ.get('SSH_CONNECTION', None) + try: + client_addr, _ = ssh_env_string.split(None, 1) + except AttributeError: + pass + if not client_addr: + raise ValueError + + if path.startswith('localhost:'): + path = path.replace('localhost', client_addr, 1) + return path + + +def is_rsh_needed(source, dest): + if source.startswith('rsync://') or dest.startswith('rsync://'): + return False + if ':' in source or ':' in dest: + return True + return False + + +def main(): + module = AnsibleModule( + argument_spec=dict( + src=dict(type='str', required=True), + dest=dict(type='str', required=True), + dest_port=dict(type='int'), + delete=dict(type='bool', default=False), + private_key=dict(type='path'), + rsync_path=dict(type='str'), + _local_rsync_path=dict(type='path', default='rsync'), + _local_rsync_password=dict(type='str', no_log=True), + _substitute_controller=dict(type='bool', default=False), + archive=dict(type='bool', default=True), + checksum=dict(type='bool', default=False), + compress=dict(type='bool', default=True), + existing_only=dict(type='bool', default=False), + dirs=dict(type='bool', default=False), + recursive=dict(type='bool'), + links=dict(type='bool'), + copy_links=dict(type='bool', default=False), + perms=dict(type='bool'), + times=dict(type='bool'), + owner=dict(type='bool'), + group=dict(type='bool'), + set_remote_user=dict(type='bool', default=True), + rsync_timeout=dict(type='int', default=0), + rsync_opts=dict(type='list', default=[]), + ssh_args=dict(type='str'), + partial=dict(type='bool', default=False), + verify_host=dict(type='bool', default=False), + mode=dict(type='str', default='push', choices=['pull', 'push']), + link_dest=dict(type='list') + ), + supports_check_mode=True, + ) + + if module.params['_substitute_controller']: + try: + source = substitute_controller(module.params['src']) + dest = substitute_controller(module.params['dest']) + except ValueError: + module.fail_json(msg='Could not determine controller hostname for rsync to send to') + else: + source = module.params['src'] + dest = module.params['dest'] + dest_port = module.params['dest_port'] + delete = module.params['delete'] + private_key = module.params['private_key'] + rsync_path = module.params['rsync_path'] + rsync = module.params.get('_local_rsync_path', 'rsync') + rsync_password = module.params.get('_local_rsync_password') + rsync_timeout = module.params.get('rsync_timeout', 'rsync_timeout') + archive = module.params['archive'] + checksum = module.params['checksum'] + compress = module.params['compress'] + existing_only = module.params['existing_only'] + dirs = module.params['dirs'] + partial = module.params['partial'] + # the default of these params depends on the value of archive + recursive = module.params['recursive'] + links = module.params['links'] + copy_links = module.params['copy_links'] + perms = module.params['perms'] + times = module.params['times'] + owner = module.params['owner'] + group = module.params['group'] + rsync_opts = module.params['rsync_opts'] + ssh_args = module.params['ssh_args'] + verify_host = module.params['verify_host'] + link_dest = module.params['link_dest'] + + if '/' not in rsync: + rsync = module.get_bin_path(rsync, required=True) + + cmd = [rsync, '--delay-updates', '-F'] + _sshpass_pipe = None + if rsync_password: + try: + module.run_command(["sshpass"]) + except OSError: + module.fail_json( + msg="to use rsync connection with passwords, you must install the sshpass program" + ) + _sshpass_pipe = os.pipe() + cmd = ['sshpass', '-d' + to_native(_sshpass_pipe[0], errors='surrogate_or_strict')] + cmd + if compress: + cmd.append('--compress') + if rsync_timeout: + cmd.append('--timeout=%s' % rsync_timeout) + if module.check_mode: + cmd.append('--dry-run') + if delete: + cmd.append('--delete-after') + if existing_only: + cmd.append('--existing') + if checksum: + cmd.append('--checksum') + if copy_links: + cmd.append('--copy-links') + if archive: + cmd.append('--archive') + if recursive is False: + cmd.append('--no-recursive') + if links is False: + cmd.append('--no-links') + if perms is False: + cmd.append('--no-perms') + if times is False: + cmd.append('--no-times') + if owner is False: + cmd.append('--no-owner') + if group is False: + cmd.append('--no-group') + else: + if recursive is True: + cmd.append('--recursive') + if links is True: + cmd.append('--links') + if perms is True: + cmd.append('--perms') + if times is True: + cmd.append('--times') + if owner is True: + cmd.append('--owner') + if group is True: + cmd.append('--group') + if dirs: + cmd.append('--dirs') + + if source.startswith('rsync://') and dest.startswith('rsync://'): + module.fail_json(msg='either src or dest must be a localhost', rc=1) + + if is_rsh_needed(source, dest): + + # https://github.com/ansible/ansible/issues/15907 + has_rsh = False + for rsync_opt in rsync_opts: + if '--rsh' in rsync_opt: + has_rsh = True + break + + # if the user has not supplied an --rsh option go ahead and add ours + if not has_rsh: + ssh_cmd = [module.get_bin_path('ssh', required=True), '-S', 'none'] + if private_key is not None: + ssh_cmd.extend(['-i', private_key]) + # If the user specified a port value + # Note: The action plugin takes care of setting this to a port from + # inventory if the user didn't specify an explicit dest_port + if dest_port is not None: + ssh_cmd.extend(['-o', 'Port=%s' % dest_port]) + if not verify_host: + ssh_cmd.extend(['-o', 'StrictHostKeyChecking=no', '-o', 'UserKnownHostsFile=/dev/null']) + ssh_cmd_str = ' '.join(shlex_quote(arg) for arg in ssh_cmd) + if ssh_args: + ssh_cmd_str += ' %s' % ssh_args + cmd.append('--rsh=%s' % ssh_cmd_str) + + if rsync_path: + cmd.append('--rsync-path=%s' % rsync_path) + + if rsync_opts: + if '' in rsync_opts: + module.warn('The empty string is present in rsync_opts which will cause rsync to' + ' transfer the current working directory. If this is intended, use "."' + ' instead to get rid of this warning. If this is unintended, check for' + ' problems in your playbook leading to empty string in rsync_opts.') + cmd.extend(rsync_opts) + + if partial: + cmd.append('--partial') + + if link_dest: + cmd.append('-H') + # verbose required because rsync does not believe that adding a + # hardlink is actually a change + cmd.append('-vv') + for x in link_dest: + link_path = os.path.abspath(os.path.expanduser(x)) + destination_path = os.path.abspath(os.path.dirname(dest)) + if destination_path.find(link_path) == 0: + module.fail_json(msg='Hardlinking into a subdirectory of the source would cause recursion. %s and %s' % (destination_path, dest)) + cmd.append('--link-dest=%s' % link_path) + + changed_marker = '<<CHANGED>>' + cmd.append('--out-format=' + changed_marker + '%i %n%L') + + # expand the paths + if '@' not in source: + source = os.path.expanduser(source) + if '@' not in dest: + dest = os.path.expanduser(dest) + + cmd.append(source) + cmd.append(dest) + cmdstr = ' '.join(cmd) + + # If we are using password authentication, write the password into the pipe + if rsync_password: + def _write_password_to_pipe(proc): + os.close(_sshpass_pipe[0]) + try: + os.write(_sshpass_pipe[1], to_bytes(rsync_password) + b'\n') + except OSError as exc: + # Ignore broken pipe errors if the sshpass process has exited. + if exc.errno != errno.EPIPE or proc.poll() is None: + raise + + (rc, out, err) = module.run_command( + cmd, pass_fds=_sshpass_pipe, + before_communicate_callback=_write_password_to_pipe) + else: + (rc, out, err) = module.run_command(cmd) + + if rc: + return module.fail_json(msg=err, rc=rc, cmd=cmdstr) + + if link_dest: + # a leading period indicates no change + changed = (changed_marker + '.') not in out + else: + changed = changed_marker in out + + out_clean = out.replace(changed_marker, '') + out_lines = out_clean.split('\n') + while '' in out_lines: + out_lines.remove('') + if module._diff: + diff = {'prepared': out_clean} + return module.exit_json(changed=changed, msg=out_clean, + rc=rc, cmd=cmdstr, stdout_lines=out_lines, + diff=diff) + + return module.exit_json(changed=changed, msg=out_clean, + rc=rc, cmd=cmdstr, stdout_lines=out_lines) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/timezone.py b/test/support/integration/plugins/modules/timezone.py new file mode 100644 index 00000000..b7439a12 --- /dev/null +++ b/test/support/integration/plugins/modules/timezone.py @@ -0,0 +1,909 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Shinichi TAMURA (@tmshn) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: timezone +short_description: Configure timezone setting +description: + - This module configures the timezone setting, both of the system clock and of the hardware clock. If you want to set up the NTP, use M(service) module. + - It is recommended to restart C(crond) after changing the timezone, otherwise the jobs may run at the wrong time. + - Several different tools are used depending on the OS/Distribution involved. + For Linux it can use C(timedatectl) or edit C(/etc/sysconfig/clock) or C(/etc/timezone) and C(hwclock). + On SmartOS, C(sm-set-timezone), for macOS, C(systemsetup), for BSD, C(/etc/localtime) is modified. + On AIX, C(chtz) is used. + - As of Ansible 2.3 support was added for SmartOS and BSDs. + - As of Ansible 2.4 support was added for macOS. + - As of Ansible 2.9 support was added for AIX 6.1+ + - Windows and HPUX are not supported, please let us know if you find any other OS/distro in which this fails. +version_added: "2.2" +options: + name: + description: + - Name of the timezone for the system clock. + - Default is to keep current setting. + - B(At least one of name and hwclock are required.) + type: str + hwclock: + description: + - Whether the hardware clock is in UTC or in local timezone. + - Default is to keep current setting. + - Note that this option is recommended not to change and may fail + to configure, especially on virtual environments such as AWS. + - B(At least one of name and hwclock are required.) + - I(Only used on Linux.) + type: str + aliases: [ rtc ] + choices: [ local, UTC ] +notes: + - On SmartOS the C(sm-set-timezone) utility (part of the smtools package) is required to set the zone timezone + - On AIX only Olson/tz database timezones are useable (POSIX is not supported). + - An OS reboot is also required on AIX for the new timezone setting to take effect. +author: + - Shinichi TAMURA (@tmshn) + - Jasper Lievisse Adriaanse (@jasperla) + - Indrajit Raychaudhuri (@indrajitr) +''' + +RETURN = r''' +diff: + description: The differences about the given arguments. + returned: success + type: complex + contains: + before: + description: The values before change + type: dict + after: + description: The values after change + type: dict +''' + +EXAMPLES = r''' +- name: Set timezone to Asia/Tokyo + timezone: + name: Asia/Tokyo +''' + +import errno +import os +import platform +import random +import re +import string +import filecmp + +from ansible.module_utils.basic import AnsibleModule, get_distribution +from ansible.module_utils.six import iteritems + + +class Timezone(object): + """This is a generic Timezone manipulation class that is subclassed based on platform. + + A subclass may wish to override the following action methods: + - get(key, phase) ... get the value from the system at `phase` + - set(key, value) ... set the value to the current system + """ + + def __new__(cls, module): + """Return the platform-specific subclass. + + It does not use load_platform_subclass() because it needs to judge based + on whether the `timedatectl` command exists and is available. + + Args: + module: The AnsibleModule. + """ + if platform.system() == 'Linux': + timedatectl = module.get_bin_path('timedatectl') + if timedatectl is not None: + rc, stdout, stderr = module.run_command(timedatectl) + if rc == 0: + return super(Timezone, SystemdTimezone).__new__(SystemdTimezone) + else: + module.warn('timedatectl command was found but not usable: %s. using other method.' % stderr) + return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone) + else: + return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone) + elif re.match('^joyent_.*Z', platform.version()): + # platform.system() returns SunOS, which is too broad. So look at the + # platform version instead. However we have to ensure that we're not + # running in the global zone where changing the timezone has no effect. + zonename_cmd = module.get_bin_path('zonename') + if zonename_cmd is not None: + (rc, stdout, _) = module.run_command(zonename_cmd) + if rc == 0 and stdout.strip() == 'global': + module.fail_json(msg='Adjusting timezone is not supported in Global Zone') + + return super(Timezone, SmartOSTimezone).__new__(SmartOSTimezone) + elif platform.system() == 'Darwin': + return super(Timezone, DarwinTimezone).__new__(DarwinTimezone) + elif re.match('^(Free|Net|Open)BSD', platform.platform()): + return super(Timezone, BSDTimezone).__new__(BSDTimezone) + elif platform.system() == 'AIX': + AIXoslevel = int(platform.version() + platform.release()) + if AIXoslevel >= 61: + return super(Timezone, AIXTimezone).__new__(AIXTimezone) + else: + module.fail_json(msg='AIX os level must be >= 61 for timezone module (Target: %s).' % AIXoslevel) + else: + # Not supported yet + return super(Timezone, Timezone).__new__(Timezone) + + def __init__(self, module): + """Initialize of the class. + + Args: + module: The AnsibleModule. + """ + super(Timezone, self).__init__() + self.msg = [] + # `self.value` holds the values for each params on each phases. + # Initially there's only info of "planned" phase, but the + # `self.check()` function will fill out it. + self.value = dict() + for key in module.argument_spec: + value = module.params[key] + if value is not None: + self.value[key] = dict(planned=value) + self.module = module + + def abort(self, msg): + """Abort the process with error message. + + This is just the wrapper of module.fail_json(). + + Args: + msg: The error message. + """ + error_msg = ['Error message:', msg] + if len(self.msg) > 0: + error_msg.append('Other message(s):') + error_msg.extend(self.msg) + self.module.fail_json(msg='\n'.join(error_msg)) + + def execute(self, *commands, **kwargs): + """Execute the shell command. + + This is just the wrapper of module.run_command(). + + Args: + *commands: The command to execute. + It will be concatenated with single space. + **kwargs: Only 'log' key is checked. + If kwargs['log'] is true, record the command to self.msg. + + Returns: + stdout: Standard output of the command. + """ + command = ' '.join(commands) + (rc, stdout, stderr) = self.module.run_command(command, check_rc=True) + if kwargs.get('log', False): + self.msg.append('executed `%s`' % command) + return stdout + + def diff(self, phase1='before', phase2='after'): + """Calculate the difference between given 2 phases. + + Args: + phase1, phase2: The names of phase to compare. + + Returns: + diff: The difference of value between phase1 and phase2. + This is in the format which can be used with the + `--diff` option of ansible-playbook. + """ + diff = {phase1: {}, phase2: {}} + for key, value in iteritems(self.value): + diff[phase1][key] = value[phase1] + diff[phase2][key] = value[phase2] + return diff + + def check(self, phase): + """Check the state in given phase and set it to `self.value`. + + Args: + phase: The name of the phase to check. + + Returns: + NO RETURN VALUE + """ + if phase == 'planned': + return + for key, value in iteritems(self.value): + value[phase] = self.get(key, phase) + + def change(self): + """Make the changes effect based on `self.value`.""" + for key, value in iteritems(self.value): + if value['before'] != value['planned']: + self.set(key, value['planned']) + + # =========================================== + # Platform specific methods (must be replaced by subclass). + + def get(self, key, phase): + """Get the value for the key at the given phase. + + Called from self.check(). + + Args: + key: The key to get the value + phase: The phase to get the value + + Return: + value: The value for the key at the given phase. + """ + self.abort('get(key, phase) is not implemented on target platform') + + def set(self, key, value): + """Set the value for the key (of course, for the phase 'after'). + + Called from self.change(). + + Args: + key: Key to set the value + value: Value to set + """ + self.abort('set(key, value) is not implemented on target platform') + + def _verify_timezone(self): + tz = self.value['name']['planned'] + tzfile = '/usr/share/zoneinfo/%s' % tz + if not os.path.isfile(tzfile): + self.abort('given timezone "%s" is not available' % tz) + return tzfile + + +class SystemdTimezone(Timezone): + """This is a Timezone manipulation class for systemd-powered Linux. + + It uses the `timedatectl` command to check/set all arguments. + """ + + regexps = dict( + hwclock=re.compile(r'^\s*RTC in local TZ\s*:\s*([^\s]+)', re.MULTILINE), + name=re.compile(r'^\s*Time ?zone\s*:\s*([^\s]+)', re.MULTILINE) + ) + + subcmds = dict( + hwclock='set-local-rtc', + name='set-timezone' + ) + + def __init__(self, module): + super(SystemdTimezone, self).__init__(module) + self.timedatectl = module.get_bin_path('timedatectl', required=True) + self.status = dict() + # Validate given timezone + if 'name' in self.value: + self._verify_timezone() + + def _get_status(self, phase): + if phase not in self.status: + self.status[phase] = self.execute(self.timedatectl, 'status') + return self.status[phase] + + def get(self, key, phase): + status = self._get_status(phase) + value = self.regexps[key].search(status).group(1) + if key == 'hwclock': + # For key='hwclock'; convert yes/no -> local/UTC + if self.module.boolean(value): + value = 'local' + else: + value = 'UTC' + return value + + def set(self, key, value): + # For key='hwclock'; convert UTC/local -> yes/no + if key == 'hwclock': + if value == 'local': + value = 'yes' + else: + value = 'no' + self.execute(self.timedatectl, self.subcmds[key], value, log=True) + + +class NosystemdTimezone(Timezone): + """This is a Timezone manipulation class for non systemd-powered Linux. + + For timezone setting, it edits the following file and reflect changes: + - /etc/sysconfig/clock ... RHEL/CentOS + - /etc/timezone ... Debian/Ubuntu + For hwclock setting, it executes `hwclock --systohc` command with the + '--utc' or '--localtime' option. + """ + + conf_files = dict( + name=None, # To be set in __init__ + hwclock=None, # To be set in __init__ + adjtime='/etc/adjtime' + ) + + # It's fine if all tree config files don't exist + allow_no_file = dict( + name=True, + hwclock=True, + adjtime=True + ) + + regexps = dict( + name=None, # To be set in __init__ + hwclock=re.compile(r'^UTC\s*=\s*([^\s]+)', re.MULTILINE), + adjtime=re.compile(r'^(UTC|LOCAL)$', re.MULTILINE) + ) + + dist_regexps = dict( + SuSE=re.compile(r'^TIMEZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE), + redhat=re.compile(r'^ZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE) + ) + + dist_tzline_format = dict( + SuSE='TIMEZONE="%s"\n', + redhat='ZONE="%s"\n' + ) + + def __init__(self, module): + super(NosystemdTimezone, self).__init__(module) + # Validate given timezone + if 'name' in self.value: + tzfile = self._verify_timezone() + # `--remove-destination` is needed if /etc/localtime is a symlink so + # that it overwrites it instead of following it. + self.update_timezone = ['%s --remove-destination %s /etc/localtime' % (self.module.get_bin_path('cp', required=True), tzfile)] + self.update_hwclock = self.module.get_bin_path('hwclock', required=True) + # Distribution-specific configurations + if self.module.get_bin_path('dpkg-reconfigure') is not None: + # Debian/Ubuntu + if 'name' in self.value: + self.update_timezone = ['%s -sf %s /etc/localtime' % (self.module.get_bin_path('ln', required=True), tzfile), + '%s --frontend noninteractive tzdata' % self.module.get_bin_path('dpkg-reconfigure', required=True)] + self.conf_files['name'] = '/etc/timezone' + self.conf_files['hwclock'] = '/etc/default/rcS' + self.regexps['name'] = re.compile(r'^([^\s]+)', re.MULTILINE) + self.tzline_format = '%s\n' + else: + # RHEL/CentOS/SUSE + if self.module.get_bin_path('tzdata-update') is not None: + # tzdata-update cannot update the timezone if /etc/localtime is + # a symlink so we have to use cp to update the time zone which + # was set above. + if not os.path.islink('/etc/localtime'): + self.update_timezone = [self.module.get_bin_path('tzdata-update', required=True)] + # else: + # self.update_timezone = 'cp --remove-destination ...' <- configured above + self.conf_files['name'] = '/etc/sysconfig/clock' + self.conf_files['hwclock'] = '/etc/sysconfig/clock' + try: + f = open(self.conf_files['name'], 'r') + except IOError as err: + if self._allow_ioerror(err, 'name'): + # If the config file doesn't exist detect the distribution and set regexps. + distribution = get_distribution() + if distribution == 'SuSE': + # For SUSE + self.regexps['name'] = self.dist_regexps['SuSE'] + self.tzline_format = self.dist_tzline_format['SuSE'] + else: + # For RHEL/CentOS + self.regexps['name'] = self.dist_regexps['redhat'] + self.tzline_format = self.dist_tzline_format['redhat'] + else: + self.abort('could not read configuration file "%s"' % self.conf_files['name']) + else: + # The key for timezone might be `ZONE` or `TIMEZONE` + # (the former is used in RHEL/CentOS and the latter is used in SUSE linux). + # So check the content of /etc/sysconfig/clock and decide which key to use. + sysconfig_clock = f.read() + f.close() + if re.search(r'^TIMEZONE\s*=', sysconfig_clock, re.MULTILINE): + # For SUSE + self.regexps['name'] = self.dist_regexps['SuSE'] + self.tzline_format = self.dist_tzline_format['SuSE'] + else: + # For RHEL/CentOS + self.regexps['name'] = self.dist_regexps['redhat'] + self.tzline_format = self.dist_tzline_format['redhat'] + + def _allow_ioerror(self, err, key): + # In some cases, even if the target file does not exist, + # simply creating it may solve the problem. + # In such cases, we should continue the configuration rather than aborting. + if err.errno != errno.ENOENT: + # If the error is not ENOENT ("No such file or directory"), + # (e.g., permission error, etc), we should abort. + return False + return self.allow_no_file.get(key, False) + + def _edit_file(self, filename, regexp, value, key): + """Replace the first matched line with given `value`. + + If `regexp` matched more than once, other than the first line will be deleted. + + Args: + filename: The name of the file to edit. + regexp: The regular expression to search with. + value: The line which will be inserted. + key: For what key the file is being editted. + """ + # Read the file + try: + file = open(filename, 'r') + except IOError as err: + if self._allow_ioerror(err, key): + lines = [] + else: + self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename)) + else: + lines = file.readlines() + file.close() + # Find the all matched lines + matched_indices = [] + for i, line in enumerate(lines): + if regexp.search(line): + matched_indices.append(i) + if len(matched_indices) > 0: + insert_line = matched_indices[0] + else: + insert_line = 0 + # Remove all matched lines + for i in matched_indices[::-1]: + del lines[i] + # ...and insert the value + lines.insert(insert_line, value) + # Write the changes + try: + file = open(filename, 'w') + except IOError: + self.abort('tried to configure %s using a file "%s", but could not write to it' % (key, filename)) + else: + file.writelines(lines) + file.close() + self.msg.append('Added 1 line and deleted %s line(s) on %s' % (len(matched_indices), filename)) + + def _get_value_from_config(self, key, phase): + filename = self.conf_files[key] + try: + file = open(filename, mode='r') + except IOError as err: + if self._allow_ioerror(err, key): + if key == 'hwclock': + return 'n/a' + elif key == 'adjtime': + return 'UTC' + elif key == 'name': + return 'n/a' + else: + self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename)) + else: + status = file.read() + file.close() + try: + value = self.regexps[key].search(status).group(1) + except AttributeError: + if key == 'hwclock': + # If we cannot find UTC in the config that's fine. + return 'n/a' + elif key == 'adjtime': + # If we cannot find UTC/LOCAL in /etc/cannot that means UTC + # will be used by default. + return 'UTC' + elif key == 'name': + if phase == 'before': + # In 'before' phase UTC/LOCAL doesn't need to be set in + # the timezone config file, so we ignore this error. + return 'n/a' + else: + self.abort('tried to configure %s using a file "%s", but could not find a valid value in it' % (key, filename)) + else: + if key == 'hwclock': + # convert yes/no -> UTC/local + if self.module.boolean(value): + value = 'UTC' + else: + value = 'local' + elif key == 'adjtime': + # convert LOCAL -> local + if value != 'UTC': + value = value.lower() + return value + + def get(self, key, phase): + planned = self.value[key]['planned'] + if key == 'hwclock': + value = self._get_value_from_config(key, phase) + if value == planned: + # If the value in the config file is the same as the 'planned' + # value, we need to check /etc/adjtime. + value = self._get_value_from_config('adjtime', phase) + elif key == 'name': + value = self._get_value_from_config(key, phase) + if value == planned: + # If the planned values is the same as the one in the config file + # we need to check if /etc/localtime is also set to the 'planned' zone. + if os.path.islink('/etc/localtime'): + # If /etc/localtime is a symlink and is not set to the TZ we 'planned' + # to set, we need to return the TZ which the symlink points to. + if os.path.exists('/etc/localtime'): + # We use readlink() because on some distros zone files are symlinks + # to other zone files, so it's hard to get which TZ is actually set + # if we follow the symlink. + path = os.readlink('/etc/localtime') + linktz = re.search(r'/usr/share/zoneinfo/(.*)', path, re.MULTILINE) + if linktz: + valuelink = linktz.group(1) + if valuelink != planned: + value = valuelink + else: + # Set current TZ to 'n/a' if the symlink points to a path + # which isn't a zone file. + value = 'n/a' + else: + # Set current TZ to 'n/a' if the symlink to the zone file is broken. + value = 'n/a' + else: + # If /etc/localtime is not a symlink best we can do is compare it with + # the 'planned' zone info file and return 'n/a' if they are different. + try: + if not filecmp.cmp('/etc/localtime', '/usr/share/zoneinfo/' + planned): + return 'n/a' + except Exception: + return 'n/a' + else: + self.abort('unknown parameter "%s"' % key) + return value + + def set_timezone(self, value): + self._edit_file(filename=self.conf_files['name'], + regexp=self.regexps['name'], + value=self.tzline_format % value, + key='name') + for cmd in self.update_timezone: + self.execute(cmd) + + def set_hwclock(self, value): + if value == 'local': + option = '--localtime' + utc = 'no' + else: + option = '--utc' + utc = 'yes' + if self.conf_files['hwclock'] is not None: + self._edit_file(filename=self.conf_files['hwclock'], + regexp=self.regexps['hwclock'], + value='UTC=%s\n' % utc, + key='hwclock') + self.execute(self.update_hwclock, '--systohc', option, log=True) + + def set(self, key, value): + if key == 'name': + self.set_timezone(value) + elif key == 'hwclock': + self.set_hwclock(value) + else: + self.abort('unknown parameter "%s"' % key) + + +class SmartOSTimezone(Timezone): + """This is a Timezone manipulation class for SmartOS instances. + + It uses the C(sm-set-timezone) utility to set the timezone, and + inspects C(/etc/default/init) to determine the current timezone. + + NB: A zone needs to be rebooted in order for the change to be + activated. + """ + + def __init__(self, module): + super(SmartOSTimezone, self).__init__(module) + self.settimezone = self.module.get_bin_path('sm-set-timezone', required=False) + if not self.settimezone: + module.fail_json(msg='sm-set-timezone not found. Make sure the smtools package is installed.') + + def get(self, key, phase): + """Lookup the current timezone name in `/etc/default/init`. If anything else + is requested, or if the TZ field is not set we fail. + """ + if key == 'name': + try: + f = open('/etc/default/init', 'r') + for line in f: + m = re.match('^TZ=(.*)$', line.strip()) + if m: + return m.groups()[0] + except Exception: + self.module.fail_json(msg='Failed to read /etc/default/init') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + """Set the requested timezone through sm-set-timezone, an invalid timezone name + will be rejected and we have no further input validation to perform. + """ + if key == 'name': + cmd = 'sm-set-timezone %s' % value + + (rc, stdout, stderr) = self.module.run_command(cmd) + + if rc != 0: + self.module.fail_json(msg=stderr) + + # sm-set-timezone knows no state and will always set the timezone. + # XXX: https://github.com/joyent/smtools/pull/2 + m = re.match(r'^\* Changed (to)? timezone (to)? (%s).*' % value, stdout.splitlines()[1]) + if not (m and m.groups()[-1] == value): + self.module.fail_json(msg='Failed to set timezone') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class DarwinTimezone(Timezone): + """This is the timezone implementation for Darwin which, unlike other *BSD + implementations, uses the `systemsetup` command on Darwin to check/set + the timezone. + """ + + regexps = dict( + name=re.compile(r'^\s*Time ?Zone\s*:\s*([^\s]+)', re.MULTILINE) + ) + + def __init__(self, module): + super(DarwinTimezone, self).__init__(module) + self.systemsetup = module.get_bin_path('systemsetup', required=True) + self.status = dict() + # Validate given timezone + if 'name' in self.value: + self._verify_timezone() + + def _get_current_timezone(self, phase): + """Lookup the current timezone via `systemsetup -gettimezone`.""" + if phase not in self.status: + self.status[phase] = self.execute(self.systemsetup, '-gettimezone') + return self.status[phase] + + def _verify_timezone(self): + tz = self.value['name']['planned'] + # Lookup the list of supported timezones via `systemsetup -listtimezones`. + # Note: Skip the first line that contains the label 'Time Zones:' + out = self.execute(self.systemsetup, '-listtimezones').splitlines()[1:] + tz_list = list(map(lambda x: x.strip(), out)) + if tz not in tz_list: + self.abort('given timezone "%s" is not available' % tz) + return tz + + def get(self, key, phase): + if key == 'name': + status = self._get_current_timezone(phase) + value = self.regexps[key].search(status).group(1) + return value + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + if key == 'name': + self.execute(self.systemsetup, '-settimezone', value, log=True) + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class BSDTimezone(Timezone): + """This is the timezone implementation for *BSD which works simply through + updating the `/etc/localtime` symlink to point to a valid timezone name under + `/usr/share/zoneinfo`. + """ + + def __init__(self, module): + super(BSDTimezone, self).__init__(module) + + def __get_timezone(self): + zoneinfo_dir = '/usr/share/zoneinfo/' + localtime_file = '/etc/localtime' + + # Strategy 1: + # If /etc/localtime does not exist, assum the timezone is UTC. + if not os.path.exists(localtime_file): + self.module.warn('Could not read /etc/localtime. Assuming UTC.') + return 'UTC' + + # Strategy 2: + # Follow symlink of /etc/localtime + zoneinfo_file = localtime_file + while not zoneinfo_file.startswith(zoneinfo_dir): + try: + zoneinfo_file = os.readlink(localtime_file) + except OSError: + # OSError means "end of symlink chain" or broken link. + break + else: + return zoneinfo_file.replace(zoneinfo_dir, '') + + # Strategy 3: + # (If /etc/localtime is not symlinked) + # Check all files in /usr/share/zoneinfo and return first non-link match. + for dname, _, fnames in sorted(os.walk(zoneinfo_dir)): + for fname in sorted(fnames): + zoneinfo_file = os.path.join(dname, fname) + if not os.path.islink(zoneinfo_file) and filecmp.cmp(zoneinfo_file, localtime_file): + return zoneinfo_file.replace(zoneinfo_dir, '') + + # Strategy 4: + # As a fall-back, return 'UTC' as default assumption. + self.module.warn('Could not identify timezone name from /etc/localtime. Assuming UTC.') + return 'UTC' + + def get(self, key, phase): + """Lookup the current timezone by resolving `/etc/localtime`.""" + if key == 'name': + return self.__get_timezone() + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + if key == 'name': + # First determine if the requested timezone is valid by looking in + # the zoneinfo directory. + zonefile = '/usr/share/zoneinfo/' + value + try: + if not os.path.isfile(zonefile): + self.module.fail_json(msg='%s is not a recognized timezone' % value) + except Exception: + self.module.fail_json(msg='Failed to stat %s' % zonefile) + + # Now (somewhat) atomically update the symlink by creating a new + # symlink and move it into place. Otherwise we have to remove the + # original symlink and create the new symlink, however that would + # create a race condition in case another process tries to read + # /etc/localtime between removal and creation. + suffix = "".join([random.choice(string.ascii_letters + string.digits) for x in range(0, 10)]) + new_localtime = '/etc/localtime.' + suffix + + try: + os.symlink(zonefile, new_localtime) + os.rename(new_localtime, '/etc/localtime') + except Exception: + os.remove(new_localtime) + self.module.fail_json(msg='Could not update /etc/localtime') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class AIXTimezone(Timezone): + """This is a Timezone manipulation class for AIX instances. + + It uses the C(chtz) utility to set the timezone, and + inspects C(/etc/environment) to determine the current timezone. + + While AIX time zones can be set using two formats (POSIX and + Olson) the prefered method is Olson. + See the following article for more information: + https://developer.ibm.com/articles/au-aix-posix/ + + NB: AIX needs to be rebooted in order for the change to be + activated. + """ + + def __init__(self, module): + super(AIXTimezone, self).__init__(module) + self.settimezone = self.module.get_bin_path('chtz', required=True) + + def __get_timezone(self): + """ Return the current value of TZ= in /etc/environment """ + try: + f = open('/etc/environment', 'r') + etcenvironment = f.read() + f.close() + except Exception: + self.module.fail_json(msg='Issue reading contents of /etc/environment') + + match = re.search(r'^TZ=(.*)$', etcenvironment, re.MULTILINE) + if match: + return match.group(1) + else: + return None + + def get(self, key, phase): + """Lookup the current timezone name in `/etc/environment`. If anything else + is requested, or if the TZ field is not set we fail. + """ + if key == 'name': + return self.__get_timezone() + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + """Set the requested timezone through chtz, an invalid timezone name + will be rejected and we have no further input validation to perform. + """ + if key == 'name': + # chtz seems to always return 0 on AIX 7.2, even for invalid timezone values. + # It will only return non-zero if the chtz command itself fails, it does not check for + # valid timezones. We need to perform a basic check to confirm that the timezone + # definition exists in /usr/share/lib/zoneinfo + # This does mean that we can only support Olson for now. The below commented out regex + # detects Olson date formats, so in the future we could detect Posix or Olson and + # act accordingly. + + # regex_olson = re.compile('^([a-z0-9_\-\+]+\/?)+$', re.IGNORECASE) + # if not regex_olson.match(value): + # msg = 'Supplied timezone (%s) does not appear to a be valid Olson string' % value + # self.module.fail_json(msg=msg) + + # First determine if the requested timezone is valid by looking in the zoneinfo + # directory. + zonefile = '/usr/share/lib/zoneinfo/' + value + try: + if not os.path.isfile(zonefile): + self.module.fail_json(msg='%s is not a recognized timezone.' % value) + except Exception: + self.module.fail_json(msg='Failed to check %s.' % zonefile) + + # Now set the TZ using chtz + cmd = 'chtz %s' % value + (rc, stdout, stderr) = self.module.run_command(cmd) + + if rc != 0: + self.module.fail_json(msg=stderr) + + # The best condition check we can do is to check the value of TZ after making the + # change. + TZ = self.__get_timezone() + if TZ != value: + msg = 'TZ value does not match post-change (Actual: %s, Expected: %s).' % (TZ, value) + self.module.fail_json(msg=msg) + + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +def main(): + # Construct 'module' and 'tz' + module = AnsibleModule( + argument_spec=dict( + hwclock=dict(type='str', choices=['local', 'UTC'], aliases=['rtc']), + name=dict(type='str'), + ), + required_one_of=[ + ['hwclock', 'name'] + ], + supports_check_mode=True, + ) + tz = Timezone(module) + + # Check the current state + tz.check(phase='before') + if module.check_mode: + diff = tz.diff('before', 'planned') + # In check mode, 'planned' state is treated as 'after' state + diff['after'] = diff.pop('planned') + else: + # Make change + tz.change() + # Check the current state + tz.check(phase='after') + # Examine if the current state matches planned state + (after, planned) = tz.diff('after', 'planned').values() + if after != planned: + tz.abort('still not desired state, though changes have made - ' + 'planned: %s, after: %s' % (str(planned), str(after))) + diff = tz.diff('before', 'after') + + changed = (diff['before'] != diff['after']) + if len(tz.msg) > 0: + module.exit_json(changed=changed, diff=diff, msg='\n'.join(tz.msg)) + else: + module.exit_json(changed=changed, diff=diff) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/x509_crl.py b/test/support/integration/plugins/modules/x509_crl.py new file mode 100644 index 00000000..ef601eda --- /dev/null +++ b/test/support/integration/plugins/modules/x509_crl.py @@ -0,0 +1,783 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019, Felix Fontein <felix@fontein.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: x509_crl +version_added: "2.10" +short_description: Generate Certificate Revocation Lists (CRLs) +description: + - This module allows one to (re)generate or update Certificate Revocation Lists (CRLs). + - Certificates on the revocation list can be either specified via serial number and (optionally) their issuer, + or as a path to a certificate file in PEM format. +requirements: + - cryptography >= 1.2 +author: + - Felix Fontein (@felixfontein) +options: + state: + description: + - Whether the CRL file should exist or not, taking action if the state is different from what is stated. + type: str + default: present + choices: [ absent, present ] + + mode: + description: + - Defines how to process entries of existing CRLs. + - If set to C(generate), makes sure that the CRL has the exact set of revoked certificates + as specified in I(revoked_certificates). + - If set to C(update), makes sure that the CRL contains the revoked certificates from + I(revoked_certificates), but can also contain other revoked certificates. If the CRL file + already exists, all entries from the existing CRL will also be included in the new CRL. + When using C(update), you might be interested in setting I(ignore_timestamps) to C(yes). + type: str + default: generate + choices: [ generate, update ] + + force: + description: + - Should the CRL be forced to be regenerated. + type: bool + default: no + + backup: + description: + - Create a backup file including a timestamp so you can get the original + CRL back if you overwrote it with a new one by accident. + type: bool + default: no + + path: + description: + - Remote absolute path where the generated CRL file should be created or is already located. + type: path + required: yes + + privatekey_path: + description: + - Path to the CA's private key to use when signing the CRL. + - Either I(privatekey_path) or I(privatekey_content) must be specified if I(state) is C(present), but not both. + type: path + + privatekey_content: + description: + - The content of the CA's private key to use when signing the CRL. + - Either I(privatekey_path) or I(privatekey_content) must be specified if I(state) is C(present), but not both. + type: str + + privatekey_passphrase: + description: + - The passphrase for the I(privatekey_path). + - This is required if the private key is password protected. + type: str + + issuer: + description: + - Key/value pairs that will be present in the issuer name field of the CRL. + - If you need to specify more than one value with the same key, use a list as value. + - Required if I(state) is C(present). + type: dict + + last_update: + description: + - The point in time from which this CRL can be trusted. + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent, except when + I(ignore_timestamps) is set to C(yes). + type: str + default: "+0s" + + next_update: + description: + - "The absolute latest point in time by which this I(issuer) is expected to have issued + another CRL. Many clients will treat a CRL as expired once I(next_update) occurs." + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent, except when + I(ignore_timestamps) is set to C(yes). + - Required if I(state) is C(present). + type: str + + digest: + description: + - Digest algorithm to be used when signing the CRL. + type: str + default: sha256 + + revoked_certificates: + description: + - List of certificates to be revoked. + - Required if I(state) is C(present). + type: list + elements: dict + suboptions: + path: + description: + - Path to a certificate in PEM format. + - The serial number and issuer will be extracted from the certificate. + - Mutually exclusive with I(content) and I(serial_number). One of these three options + must be specified. + type: path + content: + description: + - Content of a certificate in PEM format. + - The serial number and issuer will be extracted from the certificate. + - Mutually exclusive with I(path) and I(serial_number). One of these three options + must be specified. + type: str + serial_number: + description: + - Serial number of the certificate. + - Mutually exclusive with I(path) and I(content). One of these three options must + be specified. + type: int + revocation_date: + description: + - The point in time the certificate was revoked. + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent, except when + I(ignore_timestamps) is set to C(yes). + type: str + default: "+0s" + issuer: + description: + - The certificate's issuer. + - "Example: C(DNS:ca.example.org)" + type: list + elements: str + issuer_critical: + description: + - Whether the certificate issuer extension should be critical. + type: bool + default: no + reason: + description: + - The value for the revocation reason extension. + type: str + choices: + - unspecified + - key_compromise + - ca_compromise + - affiliation_changed + - superseded + - cessation_of_operation + - certificate_hold + - privilege_withdrawn + - aa_compromise + - remove_from_crl + reason_critical: + description: + - Whether the revocation reason extension should be critical. + type: bool + default: no + invalidity_date: + description: + - The point in time it was known/suspected that the private key was compromised + or that the certificate otherwise became invalid. + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent. This will NOT + change when I(ignore_timestamps) is set to C(yes). + type: str + invalidity_date_critical: + description: + - Whether the invalidity date extension should be critical. + type: bool + default: no + + ignore_timestamps: + description: + - Whether the timestamps I(last_update), I(next_update) and I(revocation_date) (in + I(revoked_certificates)) should be ignored for idempotency checks. The timestamp + I(invalidity_date) in I(revoked_certificates) will never be ignored. + - Use this in combination with relative timestamps for these values to get idempotency. + type: bool + default: no + + return_content: + description: + - If set to C(yes), will return the (current or generated) CRL's content as I(crl). + type: bool + default: no + +extends_documentation_fragment: + - files + +notes: + - All ASN.1 TIME values should be specified following the YYYYMMDDHHMMSSZ pattern. + - Date specified should be UTC. Minutes and seconds are mandatory. +''' + +EXAMPLES = r''' +- name: Generate a CRL + x509_crl: + path: /etc/ssl/my-ca.crl + privatekey_path: /etc/ssl/private/my-ca.pem + issuer: + CN: My CA + last_update: "+0s" + next_update: "+7d" + revoked_certificates: + - serial_number: 1234 + revocation_date: 20190331202428Z + issuer: + CN: My CA + - serial_number: 2345 + revocation_date: 20191013152910Z + reason: affiliation_changed + invalidity_date: 20191001000000Z + - path: /etc/ssl/crt/revoked-cert.pem + revocation_date: 20191010010203Z +''' + +RETURN = r''' +filename: + description: Path to the generated CRL + returned: changed or success + type: str + sample: /path/to/my-ca.crl +backup_file: + description: Name of backup file created. + returned: changed and if I(backup) is C(yes) + type: str + sample: /path/to/my-ca.crl.2019-03-09@11:22~ +privatekey: + description: Path to the private CA key + returned: changed or success + type: str + sample: /path/to/my-ca.pem +issuer: + description: + - The CRL's issuer. + - Note that for repeated values, only the last one will be returned. + returned: success + type: dict + sample: '{"organizationName": "Ansible", "commonName": "ca.example.com"}' +issuer_ordered: + description: The CRL's issuer as an ordered list of tuples. + returned: success + type: list + elements: list + sample: '[["organizationName", "Ansible"], ["commonName": "ca.example.com"]]' +last_update: + description: The point in time from which this CRL can be trusted as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +next_update: + description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +digest: + description: The signature algorithm used to sign the CRL. + returned: success + type: str + sample: sha256WithRSAEncryption +revoked_certificates: + description: List of certificates to be revoked. + returned: success + type: list + elements: dict + contains: + serial_number: + description: Serial number of the certificate. + type: int + sample: 1234 + revocation_date: + description: The point in time the certificate was revoked as ASN.1 TIME. + type: str + sample: 20190413202428Z + issuer: + description: The certificate's issuer. + type: list + elements: str + sample: '["DNS:ca.example.org"]' + issuer_critical: + description: Whether the certificate issuer extension is critical. + type: bool + sample: no + reason: + description: + - The value for the revocation reason extension. + - One of C(unspecified), C(key_compromise), C(ca_compromise), C(affiliation_changed), C(superseded), + C(cessation_of_operation), C(certificate_hold), C(privilege_withdrawn), C(aa_compromise), and + C(remove_from_crl). + type: str + sample: key_compromise + reason_critical: + description: Whether the revocation reason extension is critical. + type: bool + sample: no + invalidity_date: + description: | + The point in time it was known/suspected that the private key was compromised + or that the certificate otherwise became invalid as ASN.1 TIME. + type: str + sample: 20190413202428Z + invalidity_date_critical: + description: Whether the invalidity date extension is critical. + type: bool + sample: no +crl: + description: The (current or generated) CRL's content. + returned: if I(state) is C(present) and I(return_content) is C(yes) + type: str +''' + + +import os +import traceback +from distutils.version import LooseVersion + +from ansible.module_utils import crypto as crypto_utils +from ansible.module_utils._text import to_native, to_text +from ansible.module_utils.basic import AnsibleModule, missing_required_lib + +MINIMAL_CRYPTOGRAPHY_VERSION = '1.2' + +CRYPTOGRAPHY_IMP_ERR = None +try: + import cryptography + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.serialization import Encoding + from cryptography.x509 import ( + CertificateRevocationListBuilder, + RevokedCertificateBuilder, + NameAttribute, + Name, + ) + CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__) +except ImportError: + CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() + CRYPTOGRAPHY_FOUND = False +else: + CRYPTOGRAPHY_FOUND = True + + +TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ" + + +class CRLError(crypto_utils.OpenSSLObjectError): + pass + + +class CRL(crypto_utils.OpenSSLObject): + + def __init__(self, module): + super(CRL, self).__init__( + module.params['path'], + module.params['state'], + module.params['force'], + module.check_mode + ) + + self.update = module.params['mode'] == 'update' + self.ignore_timestamps = module.params['ignore_timestamps'] + self.return_content = module.params['return_content'] + self.crl_content = None + + self.privatekey_path = module.params['privatekey_path'] + self.privatekey_content = module.params['privatekey_content'] + if self.privatekey_content is not None: + self.privatekey_content = self.privatekey_content.encode('utf-8') + self.privatekey_passphrase = module.params['privatekey_passphrase'] + + self.issuer = crypto_utils.parse_name_field(module.params['issuer']) + self.issuer = [(entry[0], entry[1]) for entry in self.issuer if entry[1]] + + self.last_update = crypto_utils.get_relative_time_option(module.params['last_update'], 'last_update') + self.next_update = crypto_utils.get_relative_time_option(module.params['next_update'], 'next_update') + + self.digest = crypto_utils.select_message_digest(module.params['digest']) + if self.digest is None: + raise CRLError('The digest "{0}" is not supported'.format(module.params['digest'])) + + self.revoked_certificates = [] + for i, rc in enumerate(module.params['revoked_certificates']): + result = { + 'serial_number': None, + 'revocation_date': None, + 'issuer': None, + 'issuer_critical': False, + 'reason': None, + 'reason_critical': False, + 'invalidity_date': None, + 'invalidity_date_critical': False, + } + path_prefix = 'revoked_certificates[{0}].'.format(i) + if rc['path'] is not None or rc['content'] is not None: + # Load certificate from file or content + try: + if rc['content'] is not None: + rc['content'] = rc['content'].encode('utf-8') + cert = crypto_utils.load_certificate(rc['path'], content=rc['content'], backend='cryptography') + try: + result['serial_number'] = cert.serial_number + except AttributeError: + # The property was called "serial" before cryptography 1.4 + result['serial_number'] = cert.serial + except crypto_utils.OpenSSLObjectError as e: + if rc['content'] is not None: + module.fail_json( + msg='Cannot parse certificate from {0}content: {1}'.format(path_prefix, to_native(e)) + ) + else: + module.fail_json( + msg='Cannot read certificate "{1}" from {0}path: {2}'.format(path_prefix, rc['path'], to_native(e)) + ) + else: + # Specify serial_number (and potentially issuer) directly + result['serial_number'] = rc['serial_number'] + # All other options + if rc['issuer']: + result['issuer'] = [crypto_utils.cryptography_get_name(issuer) for issuer in rc['issuer']] + result['issuer_critical'] = rc['issuer_critical'] + result['revocation_date'] = crypto_utils.get_relative_time_option( + rc['revocation_date'], + path_prefix + 'revocation_date' + ) + if rc['reason']: + result['reason'] = crypto_utils.REVOCATION_REASON_MAP[rc['reason']] + result['reason_critical'] = rc['reason_critical'] + if rc['invalidity_date']: + result['invalidity_date'] = crypto_utils.get_relative_time_option( + rc['invalidity_date'], + path_prefix + 'invalidity_date' + ) + result['invalidity_date_critical'] = rc['invalidity_date_critical'] + self.revoked_certificates.append(result) + + self.module = module + + self.backup = module.params['backup'] + self.backup_file = None + + try: + self.privatekey = crypto_utils.load_privatekey( + path=self.privatekey_path, + content=self.privatekey_content, + passphrase=self.privatekey_passphrase, + backend='cryptography' + ) + except crypto_utils.OpenSSLBadPassphraseError as exc: + raise CRLError(exc) + + self.crl = None + try: + with open(self.path, 'rb') as f: + data = f.read() + self.crl = x509.load_pem_x509_crl(data, default_backend()) + if self.return_content: + self.crl_content = data + except Exception as dummy: + self.crl_content = None + + def remove(self): + if self.backup: + self.backup_file = self.module.backup_local(self.path) + super(CRL, self).remove(self.module) + + def _compress_entry(self, entry): + if self.ignore_timestamps: + # Throw out revocation_date + return ( + entry['serial_number'], + tuple(entry['issuer']) if entry['issuer'] is not None else None, + entry['issuer_critical'], + entry['reason'], + entry['reason_critical'], + entry['invalidity_date'], + entry['invalidity_date_critical'], + ) + else: + return ( + entry['serial_number'], + entry['revocation_date'], + tuple(entry['issuer']) if entry['issuer'] is not None else None, + entry['issuer_critical'], + entry['reason'], + entry['reason_critical'], + entry['invalidity_date'], + entry['invalidity_date_critical'], + ) + + def check(self, perms_required=True): + """Ensure the resource is in its desired state.""" + + state_and_perms = super(CRL, self).check(self.module, perms_required) + + if not state_and_perms: + return False + + if self.crl is None: + return False + + if self.last_update != self.crl.last_update and not self.ignore_timestamps: + return False + if self.next_update != self.crl.next_update and not self.ignore_timestamps: + return False + if self.digest.name != self.crl.signature_hash_algorithm.name: + return False + + want_issuer = [(crypto_utils.cryptography_name_to_oid(entry[0]), entry[1]) for entry in self.issuer] + if want_issuer != [(sub.oid, sub.value) for sub in self.crl.issuer]: + return False + + old_entries = [self._compress_entry(crypto_utils.cryptography_decode_revoked_certificate(cert)) for cert in self.crl] + new_entries = [self._compress_entry(cert) for cert in self.revoked_certificates] + if self.update: + # We don't simply use a set so that duplicate entries are treated correctly + for entry in new_entries: + try: + old_entries.remove(entry) + except ValueError: + return False + else: + if old_entries != new_entries: + return False + + return True + + def _generate_crl(self): + backend = default_backend() + crl = CertificateRevocationListBuilder() + + try: + crl = crl.issuer_name(Name([ + NameAttribute(crypto_utils.cryptography_name_to_oid(entry[0]), to_text(entry[1])) + for entry in self.issuer + ])) + except ValueError as e: + raise CRLError(e) + + crl = crl.last_update(self.last_update) + crl = crl.next_update(self.next_update) + + if self.update and self.crl: + new_entries = set([self._compress_entry(entry) for entry in self.revoked_certificates]) + for entry in self.crl: + decoded_entry = self._compress_entry(crypto_utils.cryptography_decode_revoked_certificate(entry)) + if decoded_entry not in new_entries: + crl = crl.add_revoked_certificate(entry) + for entry in self.revoked_certificates: + revoked_cert = RevokedCertificateBuilder() + revoked_cert = revoked_cert.serial_number(entry['serial_number']) + revoked_cert = revoked_cert.revocation_date(entry['revocation_date']) + if entry['issuer'] is not None: + revoked_cert = revoked_cert.add_extension( + x509.CertificateIssuer([ + crypto_utils.cryptography_get_name(name) for name in self.entry['issuer'] + ]), + entry['issuer_critical'] + ) + if entry['reason'] is not None: + revoked_cert = revoked_cert.add_extension( + x509.CRLReason(entry['reason']), + entry['reason_critical'] + ) + if entry['invalidity_date'] is not None: + revoked_cert = revoked_cert.add_extension( + x509.InvalidityDate(entry['invalidity_date']), + entry['invalidity_date_critical'] + ) + crl = crl.add_revoked_certificate(revoked_cert.build(backend)) + + self.crl = crl.sign(self.privatekey, self.digest, backend=backend) + return self.crl.public_bytes(Encoding.PEM) + + def generate(self): + if not self.check(perms_required=False) or self.force: + result = self._generate_crl() + if self.return_content: + self.crl_content = result + if self.backup: + self.backup_file = self.module.backup_local(self.path) + crypto_utils.write_file(self.module, result) + self.changed = True + + file_args = self.module.load_file_common_arguments(self.module.params) + if self.module.set_fs_attributes_if_different(file_args, False): + self.changed = True + + def _dump_revoked(self, entry): + return { + 'serial_number': entry['serial_number'], + 'revocation_date': entry['revocation_date'].strftime(TIMESTAMP_FORMAT), + 'issuer': + [crypto_utils.cryptography_decode_name(issuer) for issuer in entry['issuer']] + if entry['issuer'] is not None else None, + 'issuer_critical': entry['issuer_critical'], + 'reason': crypto_utils.REVOCATION_REASON_MAP_INVERSE.get(entry['reason']) if entry['reason'] is not None else None, + 'reason_critical': entry['reason_critical'], + 'invalidity_date': + entry['invalidity_date'].strftime(TIMESTAMP_FORMAT) + if entry['invalidity_date'] is not None else None, + 'invalidity_date_critical': entry['invalidity_date_critical'], + } + + def dump(self, check_mode=False): + result = { + 'changed': self.changed, + 'filename': self.path, + 'privatekey': self.privatekey_path, + 'last_update': None, + 'next_update': None, + 'digest': None, + 'issuer_ordered': None, + 'issuer': None, + 'revoked_certificates': [], + } + if self.backup_file: + result['backup_file'] = self.backup_file + + if check_mode: + result['last_update'] = self.last_update.strftime(TIMESTAMP_FORMAT) + result['next_update'] = self.next_update.strftime(TIMESTAMP_FORMAT) + # result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid) + result['digest'] = self.module.params['digest'] + result['issuer_ordered'] = self.issuer + result['issuer'] = {} + for k, v in self.issuer: + result['issuer'][k] = v + result['revoked_certificates'] = [] + for entry in self.revoked_certificates: + result['revoked_certificates'].append(self._dump_revoked(entry)) + elif self.crl: + result['last_update'] = self.crl.last_update.strftime(TIMESTAMP_FORMAT) + result['next_update'] = self.crl.next_update.strftime(TIMESTAMP_FORMAT) + try: + result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid) + except AttributeError: + # Older cryptography versions don't have signature_algorithm_oid yet + dotted = crypto_utils._obj2txt( + self.crl._backend._lib, + self.crl._backend._ffi, + self.crl._x509_crl.sig_alg.algorithm + ) + oid = x509.oid.ObjectIdentifier(dotted) + result['digest'] = crypto_utils.cryptography_oid_to_name(oid) + issuer = [] + for attribute in self.crl.issuer: + issuer.append([crypto_utils.cryptography_oid_to_name(attribute.oid), attribute.value]) + result['issuer_ordered'] = issuer + result['issuer'] = {} + for k, v in issuer: + result['issuer'][k] = v + result['revoked_certificates'] = [] + for cert in self.crl: + entry = crypto_utils.cryptography_decode_revoked_certificate(cert) + result['revoked_certificates'].append(self._dump_revoked(entry)) + + if self.return_content: + result['crl'] = self.crl_content + + return result + + +def main(): + module = AnsibleModule( + argument_spec=dict( + state=dict(type='str', default='present', choices=['present', 'absent']), + mode=dict(type='str', default='generate', choices=['generate', 'update']), + force=dict(type='bool', default=False), + backup=dict(type='bool', default=False), + path=dict(type='path', required=True), + privatekey_path=dict(type='path'), + privatekey_content=dict(type='str'), + privatekey_passphrase=dict(type='str', no_log=True), + issuer=dict(type='dict'), + last_update=dict(type='str', default='+0s'), + next_update=dict(type='str'), + digest=dict(type='str', default='sha256'), + ignore_timestamps=dict(type='bool', default=False), + return_content=dict(type='bool', default=False), + revoked_certificates=dict( + type='list', + elements='dict', + options=dict( + path=dict(type='path'), + content=dict(type='str'), + serial_number=dict(type='int'), + revocation_date=dict(type='str', default='+0s'), + issuer=dict(type='list', elements='str'), + issuer_critical=dict(type='bool', default=False), + reason=dict( + type='str', + choices=[ + 'unspecified', 'key_compromise', 'ca_compromise', 'affiliation_changed', + 'superseded', 'cessation_of_operation', 'certificate_hold', + 'privilege_withdrawn', 'aa_compromise', 'remove_from_crl' + ] + ), + reason_critical=dict(type='bool', default=False), + invalidity_date=dict(type='str'), + invalidity_date_critical=dict(type='bool', default=False), + ), + required_one_of=[['path', 'content', 'serial_number']], + mutually_exclusive=[['path', 'content', 'serial_number']], + ), + ), + required_if=[ + ('state', 'present', ['privatekey_path', 'privatekey_content'], True), + ('state', 'present', ['issuer', 'next_update', 'revoked_certificates'], False), + ], + mutually_exclusive=( + ['privatekey_path', 'privatekey_content'], + ), + supports_check_mode=True, + add_file_common_args=True, + ) + + if not CRYPTOGRAPHY_FOUND: + module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)), + exception=CRYPTOGRAPHY_IMP_ERR) + + try: + crl = CRL(module) + + if module.params['state'] == 'present': + if module.check_mode: + result = crl.dump(check_mode=True) + result['changed'] = module.params['force'] or not crl.check() + module.exit_json(**result) + + crl.generate() + else: + if module.check_mode: + result = crl.dump(check_mode=True) + result['changed'] = os.path.exists(module.params['path']) + module.exit_json(**result) + + crl.remove() + + result = crl.dump() + module.exit_json(**result) + except crypto_utils.OpenSSLObjectError as exc: + module.fail_json(msg=to_native(exc)) + + +if __name__ == "__main__": + main() diff --git a/test/support/integration/plugins/modules/x509_crl_info.py b/test/support/integration/plugins/modules/x509_crl_info.py new file mode 100644 index 00000000..b61db26f --- /dev/null +++ b/test/support/integration/plugins/modules/x509_crl_info.py @@ -0,0 +1,281 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2020, Felix Fontein <felix@fontein.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: x509_crl_info +version_added: "2.10" +short_description: Retrieve information on Certificate Revocation Lists (CRLs) +description: + - This module allows one to retrieve information on Certificate Revocation Lists (CRLs). +requirements: + - cryptography >= 1.2 +author: + - Felix Fontein (@felixfontein) +options: + path: + description: + - Remote absolute path where the generated CRL file should be created or is already located. + - Either I(path) or I(content) must be specified, but not both. + type: path + content: + description: + - Content of the X.509 certificate in PEM format. + - Either I(path) or I(content) must be specified, but not both. + type: str + +notes: + - All timestamp values are provided in ASN.1 TIME format, i.e. following the C(YYYYMMDDHHMMSSZ) pattern. + They are all in UTC. +seealso: + - module: x509_crl +''' + +EXAMPLES = r''' +- name: Get information on CRL + x509_crl_info: + path: /etc/ssl/my-ca.crl + register: result + +- debug: + msg: "{{ result }}" +''' + +RETURN = r''' +issuer: + description: + - The CRL's issuer. + - Note that for repeated values, only the last one will be returned. + returned: success + type: dict + sample: '{"organizationName": "Ansible", "commonName": "ca.example.com"}' +issuer_ordered: + description: The CRL's issuer as an ordered list of tuples. + returned: success + type: list + elements: list + sample: '[["organizationName", "Ansible"], ["commonName": "ca.example.com"]]' +last_update: + description: The point in time from which this CRL can be trusted as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +next_update: + description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +digest: + description: The signature algorithm used to sign the CRL. + returned: success + type: str + sample: sha256WithRSAEncryption +revoked_certificates: + description: List of certificates to be revoked. + returned: success + type: list + elements: dict + contains: + serial_number: + description: Serial number of the certificate. + type: int + sample: 1234 + revocation_date: + description: The point in time the certificate was revoked as ASN.1 TIME. + type: str + sample: 20190413202428Z + issuer: + description: The certificate's issuer. + type: list + elements: str + sample: '["DNS:ca.example.org"]' + issuer_critical: + description: Whether the certificate issuer extension is critical. + type: bool + sample: no + reason: + description: + - The value for the revocation reason extension. + - One of C(unspecified), C(key_compromise), C(ca_compromise), C(affiliation_changed), C(superseded), + C(cessation_of_operation), C(certificate_hold), C(privilege_withdrawn), C(aa_compromise), and + C(remove_from_crl). + type: str + sample: key_compromise + reason_critical: + description: Whether the revocation reason extension is critical. + type: bool + sample: no + invalidity_date: + description: | + The point in time it was known/suspected that the private key was compromised + or that the certificate otherwise became invalid as ASN.1 TIME. + type: str + sample: 20190413202428Z + invalidity_date_critical: + description: Whether the invalidity date extension is critical. + type: bool + sample: no +''' + + +import traceback +from distutils.version import LooseVersion + +from ansible.module_utils import crypto as crypto_utils +from ansible.module_utils._text import to_native +from ansible.module_utils.basic import AnsibleModule, missing_required_lib + +MINIMAL_CRYPTOGRAPHY_VERSION = '1.2' + +CRYPTOGRAPHY_IMP_ERR = None +try: + import cryptography + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__) +except ImportError: + CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() + CRYPTOGRAPHY_FOUND = False +else: + CRYPTOGRAPHY_FOUND = True + + +TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ" + + +class CRLError(crypto_utils.OpenSSLObjectError): + pass + + +class CRLInfo(crypto_utils.OpenSSLObject): + """The main module implementation.""" + + def __init__(self, module): + super(CRLInfo, self).__init__( + module.params['path'] or '', + 'present', + False, + module.check_mode + ) + + self.content = module.params['content'] + + self.module = module + + self.crl = None + if self.content is None: + try: + with open(self.path, 'rb') as f: + data = f.read() + except Exception as e: + self.module.fail_json(msg='Error while reading CRL file from disk: {0}'.format(e)) + else: + data = self.content.encode('utf-8') + + try: + self.crl = x509.load_pem_x509_crl(data, default_backend()) + except Exception as e: + self.module.fail_json(msg='Error while decoding CRL: {0}'.format(e)) + + def _dump_revoked(self, entry): + return { + 'serial_number': entry['serial_number'], + 'revocation_date': entry['revocation_date'].strftime(TIMESTAMP_FORMAT), + 'issuer': + [crypto_utils.cryptography_decode_name(issuer) for issuer in entry['issuer']] + if entry['issuer'] is not None else None, + 'issuer_critical': entry['issuer_critical'], + 'reason': crypto_utils.REVOCATION_REASON_MAP_INVERSE.get(entry['reason']) if entry['reason'] is not None else None, + 'reason_critical': entry['reason_critical'], + 'invalidity_date': + entry['invalidity_date'].strftime(TIMESTAMP_FORMAT) + if entry['invalidity_date'] is not None else None, + 'invalidity_date_critical': entry['invalidity_date_critical'], + } + + def get_info(self): + result = { + 'changed': False, + 'last_update': None, + 'next_update': None, + 'digest': None, + 'issuer_ordered': None, + 'issuer': None, + 'revoked_certificates': [], + } + + result['last_update'] = self.crl.last_update.strftime(TIMESTAMP_FORMAT) + result['next_update'] = self.crl.next_update.strftime(TIMESTAMP_FORMAT) + try: + result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid) + except AttributeError: + # Older cryptography versions don't have signature_algorithm_oid yet + dotted = crypto_utils._obj2txt( + self.crl._backend._lib, + self.crl._backend._ffi, + self.crl._x509_crl.sig_alg.algorithm + ) + oid = x509.oid.ObjectIdentifier(dotted) + result['digest'] = crypto_utils.cryptography_oid_to_name(oid) + issuer = [] + for attribute in self.crl.issuer: + issuer.append([crypto_utils.cryptography_oid_to_name(attribute.oid), attribute.value]) + result['issuer_ordered'] = issuer + result['issuer'] = {} + for k, v in issuer: + result['issuer'][k] = v + result['revoked_certificates'] = [] + for cert in self.crl: + entry = crypto_utils.cryptography_decode_revoked_certificate(cert) + result['revoked_certificates'].append(self._dump_revoked(entry)) + + return result + + def generate(self): + # Empty method because crypto_utils.OpenSSLObject wants this + pass + + def dump(self): + # Empty method because crypto_utils.OpenSSLObject wants this + pass + + +def main(): + module = AnsibleModule( + argument_spec=dict( + path=dict(type='path'), + content=dict(type='str'), + ), + required_one_of=( + ['path', 'content'], + ), + mutually_exclusive=( + ['path', 'content'], + ), + supports_check_mode=True, + ) + + if not CRYPTOGRAPHY_FOUND: + module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)), + exception=CRYPTOGRAPHY_IMP_ERR) + + try: + crl = CRLInfo(module) + result = crl.get_info() + module.exit_json(**result) + except crypto_utils.OpenSSLObjectError as e: + module.fail_json(msg=to_native(e)) + + +if __name__ == "__main__": + main() diff --git a/test/support/integration/plugins/modules/xml.py b/test/support/integration/plugins/modules/xml.py new file mode 100644 index 00000000..b5b35a38 --- /dev/null +++ b/test/support/integration/plugins/modules/xml.py @@ -0,0 +1,966 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2014, Red Hat, Inc. +# Copyright: (c) 2014, Tim Bielawa <tbielawa@redhat.com> +# Copyright: (c) 2014, Magnus Hedemark <mhedemar@redhat.com> +# Copyright: (c) 2017, Dag Wieers <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: xml +short_description: Manage bits and pieces of XML files or strings +description: +- A CRUD-like interface to managing bits of XML files. +version_added: '2.4' +options: + path: + description: + - Path to the file to operate on. + - This file must exist ahead of time. + - This parameter is required, unless C(xmlstring) is given. + type: path + required: yes + aliases: [ dest, file ] + xmlstring: + description: + - A string containing XML on which to operate. + - This parameter is required, unless C(path) is given. + type: str + required: yes + xpath: + description: + - A valid XPath expression describing the item(s) you want to manipulate. + - Operates on the document root, C(/), by default. + type: str + namespaces: + description: + - The namespace C(prefix:uri) mapping for the XPath expression. + - Needs to be a C(dict), not a C(list) of items. + type: dict + state: + description: + - Set or remove an xpath selection (node(s), attribute(s)). + type: str + choices: [ absent, present ] + default: present + aliases: [ ensure ] + attribute: + description: + - The attribute to select when using parameter C(value). + - This is a string, not prepended with C(@). + type: raw + value: + description: + - Desired state of the selected attribute. + - Either a string, or to unset a value, the Python C(None) keyword (YAML Equivalent, C(null)). + - Elements default to no value (but present). + - Attributes default to an empty string. + type: raw + add_children: + description: + - Add additional child-element(s) to a selected element for a given C(xpath). + - Child elements must be given in a list and each item may be either a string + (eg. C(children=ansible) to add an empty C(<ansible/>) child element), + or a hash where the key is an element name and the value is the element value. + - This parameter requires C(xpath) to be set. + type: list + set_children: + description: + - Set the child-element(s) of a selected element for a given C(xpath). + - Removes any existing children. + - Child elements must be specified as in C(add_children). + - This parameter requires C(xpath) to be set. + type: list + count: + description: + - Search for a given C(xpath) and provide the count of any matches. + - This parameter requires C(xpath) to be set. + type: bool + default: no + print_match: + description: + - Search for a given C(xpath) and print out any matches. + - This parameter requires C(xpath) to be set. + type: bool + default: no + pretty_print: + description: + - Pretty print XML output. + type: bool + default: no + content: + description: + - Search for a given C(xpath) and get content. + - This parameter requires C(xpath) to be set. + type: str + choices: [ attribute, text ] + input_type: + description: + - Type of input for C(add_children) and C(set_children). + type: str + choices: [ xml, yaml ] + default: yaml + backup: + description: + - Create a backup file including the timestamp information so you can get + the original file back if you somehow clobbered it incorrectly. + type: bool + default: no + strip_cdata_tags: + description: + - Remove CDATA tags surrounding text values. + - Note that this might break your XML file if text values contain characters that could be interpreted as XML. + type: bool + default: no + version_added: '2.7' + insertbefore: + description: + - Add additional child-element(s) before the first selected element for a given C(xpath). + - Child elements must be given in a list and each item may be either a string + (eg. C(children=ansible) to add an empty C(<ansible/>) child element), + or a hash where the key is an element name and the value is the element value. + - This parameter requires C(xpath) to be set. + type: bool + default: no + version_added: '2.8' + insertafter: + description: + - Add additional child-element(s) after the last selected element for a given C(xpath). + - Child elements must be given in a list and each item may be either a string + (eg. C(children=ansible) to add an empty C(<ansible/>) child element), + or a hash where the key is an element name and the value is the element value. + - This parameter requires C(xpath) to be set. + type: bool + default: no + version_added: '2.8' +requirements: +- lxml >= 2.3.0 +notes: +- Use the C(--check) and C(--diff) options when testing your expressions. +- The diff output is automatically pretty-printed, so may not reflect the actual file content, only the file structure. +- This module does not handle complicated xpath expressions, so limit xpath selectors to simple expressions. +- Beware that in case your XML elements are namespaced, you need to use the C(namespaces) parameter, see the examples. +- Namespaces prefix should be used for all children of an element where namespace is defined, unless another namespace is defined for them. +seealso: +- name: Xml module development community wiki + description: More information related to the development of this xml module. + link: https://github.com/ansible/community/wiki/Module:-xml +- name: Introduction to XPath + description: A brief tutorial on XPath (w3schools.com). + link: https://www.w3schools.com/xml/xpath_intro.asp +- name: XPath Reference document + description: The reference documentation on XSLT/XPath (developer.mozilla.org). + link: https://developer.mozilla.org/en-US/docs/Web/XPath +author: +- Tim Bielawa (@tbielawa) +- Magnus Hedemark (@magnus919) +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r''' +# Consider the following XML file: +# +# <business type="bar"> +# <name>Tasty Beverage Co.</name> +# <beers> +# <beer>Rochefort 10</beer> +# <beer>St. Bernardus Abbot 12</beer> +# <beer>Schlitz</beer> +# </beers> +# <rating subjective="true">10</rating> +# <website> +# <mobilefriendly/> +# <address>http://tastybeverageco.com</address> +# </website> +# </business> + +- name: Remove the 'subjective' attribute of the 'rating' element + xml: + path: /foo/bar.xml + xpath: /business/rating/@subjective + state: absent + +- name: Set the rating to '11' + xml: + path: /foo/bar.xml + xpath: /business/rating + value: 11 + +# Retrieve and display the number of nodes +- name: Get count of 'beers' nodes + xml: + path: /foo/bar.xml + xpath: /business/beers/beer + count: yes + register: hits + +- debug: + var: hits.count + +# Example where parent XML nodes are created automatically +- name: Add a 'phonenumber' element to the 'business' element + xml: + path: /foo/bar.xml + xpath: /business/phonenumber + value: 555-555-1234 + +- name: Add several more beers to the 'beers' element + xml: + path: /foo/bar.xml + xpath: /business/beers + add_children: + - beer: Old Rasputin + - beer: Old Motor Oil + - beer: Old Curmudgeon + +- name: Add several more beers to the 'beers' element and add them before the 'Rochefort 10' element + xml: + path: /foo/bar.xml + xpath: '/business/beers/beer[text()="Rochefort 10"]' + insertbefore: yes + add_children: + - beer: Old Rasputin + - beer: Old Motor Oil + - beer: Old Curmudgeon + +# NOTE: The 'state' defaults to 'present' and 'value' defaults to 'null' for elements +- name: Add a 'validxhtml' element to the 'website' element + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml + +- name: Add an empty 'validatedon' attribute to the 'validxhtml' element + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml/@validatedon + +- name: Add or modify an attribute, add element if needed + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml + attribute: validatedon + value: 1976-08-05 + +# How to read an attribute value and access it in Ansible +- name: Read an element's attribute values + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml + content: attribute + register: xmlresp + +- name: Show an attribute value + debug: + var: xmlresp.matches[0].validxhtml.validatedon + +- name: Remove all children from the 'website' element (option 1) + xml: + path: /foo/bar.xml + xpath: /business/website/* + state: absent + +- name: Remove all children from the 'website' element (option 2) + xml: + path: /foo/bar.xml + xpath: /business/website + children: [] + +# In case of namespaces, like in below XML, they have to be explicitly stated. +# +# <foo xmlns="http://x.test" xmlns:attr="http://z.test"> +# <bar> +# <baz xmlns="http://y.test" attr:my_namespaced_attribute="true" /> +# </bar> +# </foo> + +# NOTE: There is the prefix 'x' in front of the 'bar' element, too. +- name: Set namespaced '/x:foo/x:bar/y:baz/@z:my_namespaced_attribute' to 'false' + xml: + path: foo.xml + xpath: /x:foo/x:bar/y:baz + namespaces: + x: http://x.test + y: http://y.test + z: http://z.test + attribute: z:my_namespaced_attribute + value: 'false' +''' + +RETURN = r''' +actions: + description: A dictionary with the original xpath, namespaces and state. + type: dict + returned: success + sample: {xpath: xpath, namespaces: [namespace1, namespace2], state=present} +backup_file: + description: The name of the backup file that was created + type: str + returned: when backup=yes + sample: /path/to/file.xml.1942.2017-08-24@14:16:01~ +count: + description: The count of xpath matches. + type: int + returned: when parameter 'count' is set + sample: 2 +matches: + description: The xpath matches found. + type: list + returned: when parameter 'print_match' is set +msg: + description: A message related to the performed action(s). + type: str + returned: always +xmlstring: + description: An XML string of the resulting output. + type: str + returned: when parameter 'xmlstring' is set +''' + +import copy +import json +import os +import re +import traceback + +from distutils.version import LooseVersion +from io import BytesIO + +LXML_IMP_ERR = None +try: + from lxml import etree, objectify + HAS_LXML = True +except ImportError: + LXML_IMP_ERR = traceback.format_exc() + HAS_LXML = False + +from ansible.module_utils.basic import AnsibleModule, json_dict_bytes_to_unicode, missing_required_lib +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils._text import to_bytes, to_native +from ansible.module_utils.common._collections_compat import MutableMapping + +_IDENT = r"[a-zA-Z-][a-zA-Z0-9_\-\.]*" +_NSIDENT = _IDENT + "|" + _IDENT + ":" + _IDENT +# Note: we can't reasonably support the 'if you need to put both ' and " in a string, concatenate +# strings wrapped by the other delimiter' XPath trick, especially as simple XPath. +_XPSTR = "('(?:.*)'|\"(?:.*)\")" + +_RE_SPLITSIMPLELAST = re.compile("^(.*)/(" + _NSIDENT + ")$") +_RE_SPLITSIMPLELASTEQVALUE = re.compile("^(.*)/(" + _NSIDENT + ")/text\\(\\)=" + _XPSTR + "$") +_RE_SPLITSIMPLEATTRLAST = re.compile("^(.*)/(@(?:" + _NSIDENT + "))$") +_RE_SPLITSIMPLEATTRLASTEQVALUE = re.compile("^(.*)/(@(?:" + _NSIDENT + "))=" + _XPSTR + "$") +_RE_SPLITSUBLAST = re.compile("^(.*)/(" + _NSIDENT + ")\\[(.*)\\]$") +_RE_SPLITONLYEQVALUE = re.compile("^(.*)/text\\(\\)=" + _XPSTR + "$") + + +def has_changed(doc): + orig_obj = etree.tostring(objectify.fromstring(etree.tostring(orig_doc))) + obj = etree.tostring(objectify.fromstring(etree.tostring(doc))) + return (orig_obj != obj) + + +def do_print_match(module, tree, xpath, namespaces): + match = tree.xpath(xpath, namespaces=namespaces) + match_xpaths = [] + for m in match: + match_xpaths.append(tree.getpath(m)) + match_str = json.dumps(match_xpaths) + msg = "selector '%s' match: %s" % (xpath, match_str) + finish(module, tree, xpath, namespaces, changed=False, msg=msg) + + +def count_nodes(module, tree, xpath, namespaces): + """ Return the count of nodes matching the xpath """ + hits = tree.xpath("count(/%s)" % xpath, namespaces=namespaces) + msg = "found %d nodes" % hits + finish(module, tree, xpath, namespaces, changed=False, msg=msg, hitcount=int(hits)) + + +def is_node(tree, xpath, namespaces): + """ Test if a given xpath matches anything and if that match is a node. + + For now we just assume you're only searching for one specific thing.""" + if xpath_matches(tree, xpath, namespaces): + # OK, it found something + match = tree.xpath(xpath, namespaces=namespaces) + if isinstance(match[0], etree._Element): + return True + + return False + + +def is_attribute(tree, xpath, namespaces): + """ Test if a given xpath matches and that match is an attribute + + An xpath attribute search will only match one item""" + if xpath_matches(tree, xpath, namespaces): + match = tree.xpath(xpath, namespaces=namespaces) + if isinstance(match[0], etree._ElementStringResult): + return True + elif isinstance(match[0], etree._ElementUnicodeResult): + return True + return False + + +def xpath_matches(tree, xpath, namespaces): + """ Test if a node exists """ + if tree.xpath(xpath, namespaces=namespaces): + return True + return False + + +def delete_xpath_target(module, tree, xpath, namespaces): + """ Delete an attribute or element from a tree """ + try: + for result in tree.xpath(xpath, namespaces=namespaces): + # Get the xpath for this result + if is_attribute(tree, xpath, namespaces): + # Delete an attribute + parent = result.getparent() + # Pop this attribute match out of the parent + # node's 'attrib' dict by using this match's + # 'attrname' attribute for the key + parent.attrib.pop(result.attrname) + elif is_node(tree, xpath, namespaces): + # Delete an element + result.getparent().remove(result) + else: + raise Exception("Impossible error") + except Exception as e: + module.fail_json(msg="Couldn't delete xpath target: %s (%s)" % (xpath, e)) + else: + finish(module, tree, xpath, namespaces, changed=True) + + +def replace_children_of(children, match): + for element in list(match): + match.remove(element) + match.extend(children) + + +def set_target_children_inner(module, tree, xpath, namespaces, children, in_type): + matches = tree.xpath(xpath, namespaces=namespaces) + + # Create a list of our new children + children = children_to_nodes(module, children, in_type) + children_as_string = [etree.tostring(c) for c in children] + + changed = False + + # xpaths always return matches as a list, so.... + for match in matches: + # Check if elements differ + if len(list(match)) == len(children): + for idx, element in enumerate(list(match)): + if etree.tostring(element) != children_as_string[idx]: + replace_children_of(children, match) + changed = True + break + else: + replace_children_of(children, match) + changed = True + + return changed + + +def set_target_children(module, tree, xpath, namespaces, children, in_type): + changed = set_target_children_inner(module, tree, xpath, namespaces, children, in_type) + # Write it out + finish(module, tree, xpath, namespaces, changed=changed) + + +def add_target_children(module, tree, xpath, namespaces, children, in_type, insertbefore, insertafter): + if is_node(tree, xpath, namespaces): + new_kids = children_to_nodes(module, children, in_type) + if insertbefore or insertafter: + insert_target_children(tree, xpath, namespaces, new_kids, insertbefore, insertafter) + else: + for node in tree.xpath(xpath, namespaces=namespaces): + node.extend(new_kids) + finish(module, tree, xpath, namespaces, changed=True) + else: + finish(module, tree, xpath, namespaces) + + +def insert_target_children(tree, xpath, namespaces, children, insertbefore, insertafter): + """ + Insert the given children before or after the given xpath. If insertbefore is True, it is inserted before the + first xpath hit, with insertafter, it is inserted after the last xpath hit. + """ + insert_target = tree.xpath(xpath, namespaces=namespaces) + loc_index = 0 if insertbefore else -1 + index_in_parent = insert_target[loc_index].getparent().index(insert_target[loc_index]) + parent = insert_target[0].getparent() + if insertafter: + index_in_parent += 1 + for child in children: + parent.insert(index_in_parent, child) + index_in_parent += 1 + + +def _extract_xpstr(g): + return g[1:-1] + + +def split_xpath_last(xpath): + """split an XPath of the form /foo/bar/baz into /foo/bar and baz""" + xpath = xpath.strip() + m = _RE_SPLITSIMPLELAST.match(xpath) + if m: + # requesting an element to exist + return (m.group(1), [(m.group(2), None)]) + m = _RE_SPLITSIMPLELASTEQVALUE.match(xpath) + if m: + # requesting an element to exist with an inner text + return (m.group(1), [(m.group(2), _extract_xpstr(m.group(3)))]) + + m = _RE_SPLITSIMPLEATTRLAST.match(xpath) + if m: + # requesting an attribute to exist + return (m.group(1), [(m.group(2), None)]) + m = _RE_SPLITSIMPLEATTRLASTEQVALUE.match(xpath) + if m: + # requesting an attribute to exist with a value + return (m.group(1), [(m.group(2), _extract_xpstr(m.group(3)))]) + + m = _RE_SPLITSUBLAST.match(xpath) + if m: + content = [x.strip() for x in m.group(3).split(" and ")] + return (m.group(1), [('/' + m.group(2), content)]) + + m = _RE_SPLITONLYEQVALUE.match(xpath) + if m: + # requesting a change of inner text + return (m.group(1), [("", _extract_xpstr(m.group(2)))]) + return (xpath, []) + + +def nsnameToClark(name, namespaces): + if ":" in name: + (nsname, rawname) = name.split(":") + # return "{{%s}}%s" % (namespaces[nsname], rawname) + return "{{{0}}}{1}".format(namespaces[nsname], rawname) + + # no namespace name here + return name + + +def check_or_make_target(module, tree, xpath, namespaces): + (inner_xpath, changes) = split_xpath_last(xpath) + if (inner_xpath == xpath) or (changes is None): + module.fail_json(msg="Can't process Xpath %s in order to spawn nodes! tree is %s" % + (xpath, etree.tostring(tree, pretty_print=True))) + return False + + changed = False + + if not is_node(tree, inner_xpath, namespaces): + changed = check_or_make_target(module, tree, inner_xpath, namespaces) + + # we test again after calling check_or_make_target + if is_node(tree, inner_xpath, namespaces) and changes: + for (eoa, eoa_value) in changes: + if eoa and eoa[0] != '@' and eoa[0] != '/': + # implicitly creating an element + new_kids = children_to_nodes(module, [nsnameToClark(eoa, namespaces)], "yaml") + if eoa_value: + for nk in new_kids: + nk.text = eoa_value + + for node in tree.xpath(inner_xpath, namespaces=namespaces): + node.extend(new_kids) + changed = True + # module.fail_json(msg="now tree=%s" % etree.tostring(tree, pretty_print=True)) + elif eoa and eoa[0] == '/': + element = eoa[1:] + new_kids = children_to_nodes(module, [nsnameToClark(element, namespaces)], "yaml") + for node in tree.xpath(inner_xpath, namespaces=namespaces): + node.extend(new_kids) + for nk in new_kids: + for subexpr in eoa_value: + # module.fail_json(msg="element=%s subexpr=%s node=%s now tree=%s" % + # (element, subexpr, etree.tostring(node, pretty_print=True), etree.tostring(tree, pretty_print=True)) + check_or_make_target(module, nk, "./" + subexpr, namespaces) + changed = True + + # module.fail_json(msg="now tree=%s" % etree.tostring(tree, pretty_print=True)) + elif eoa == "": + for node in tree.xpath(inner_xpath, namespaces=namespaces): + if (node.text != eoa_value): + node.text = eoa_value + changed = True + + elif eoa and eoa[0] == '@': + attribute = nsnameToClark(eoa[1:], namespaces) + + for element in tree.xpath(inner_xpath, namespaces=namespaces): + changing = (attribute not in element.attrib or element.attrib[attribute] != eoa_value) + + if changing: + changed = changed or changing + if eoa_value is None: + value = "" + else: + value = eoa_value + element.attrib[attribute] = value + + # module.fail_json(msg="arf %s changing=%s as curval=%s changed tree=%s" % + # (xpath, changing, etree.tostring(tree, changing, element[attribute], pretty_print=True))) + + else: + module.fail_json(msg="unknown tree transformation=%s" % etree.tostring(tree, pretty_print=True)) + + return changed + + +def ensure_xpath_exists(module, tree, xpath, namespaces): + changed = False + + if not is_node(tree, xpath, namespaces): + changed = check_or_make_target(module, tree, xpath, namespaces) + + finish(module, tree, xpath, namespaces, changed) + + +def set_target_inner(module, tree, xpath, namespaces, attribute, value): + changed = False + + try: + if not is_node(tree, xpath, namespaces): + changed = check_or_make_target(module, tree, xpath, namespaces) + except Exception as e: + missing_namespace = "" + # NOTE: This checks only the namespaces defined in root element! + # TODO: Implement a more robust check to check for child namespaces' existence + if tree.getroot().nsmap and ":" not in xpath: + missing_namespace = "XML document has namespace(s) defined, but no namespace prefix(es) used in xpath!\n" + module.fail_json(msg="%sXpath %s causes a failure: %s\n -- tree is %s" % + (missing_namespace, xpath, e, etree.tostring(tree, pretty_print=True)), exception=traceback.format_exc()) + + if not is_node(tree, xpath, namespaces): + module.fail_json(msg="Xpath %s does not reference a node! tree is %s" % + (xpath, etree.tostring(tree, pretty_print=True))) + + for element in tree.xpath(xpath, namespaces=namespaces): + if not attribute: + changed = changed or (element.text != value) + if element.text != value: + element.text = value + else: + changed = changed or (element.get(attribute) != value) + if ":" in attribute: + attr_ns, attr_name = attribute.split(":") + # attribute = "{{%s}}%s" % (namespaces[attr_ns], attr_name) + attribute = "{{{0}}}{1}".format(namespaces[attr_ns], attr_name) + if element.get(attribute) != value: + element.set(attribute, value) + + return changed + + +def set_target(module, tree, xpath, namespaces, attribute, value): + changed = set_target_inner(module, tree, xpath, namespaces, attribute, value) + finish(module, tree, xpath, namespaces, changed) + + +def get_element_text(module, tree, xpath, namespaces): + if not is_node(tree, xpath, namespaces): + module.fail_json(msg="Xpath %s does not reference a node!" % xpath) + + elements = [] + for element in tree.xpath(xpath, namespaces=namespaces): + elements.append({element.tag: element.text}) + + finish(module, tree, xpath, namespaces, changed=False, msg=len(elements), hitcount=len(elements), matches=elements) + + +def get_element_attr(module, tree, xpath, namespaces): + if not is_node(tree, xpath, namespaces): + module.fail_json(msg="Xpath %s does not reference a node!" % xpath) + + elements = [] + for element in tree.xpath(xpath, namespaces=namespaces): + child = {} + for key in element.keys(): + value = element.get(key) + child.update({key: value}) + elements.append({element.tag: child}) + + finish(module, tree, xpath, namespaces, changed=False, msg=len(elements), hitcount=len(elements), matches=elements) + + +def child_to_element(module, child, in_type): + if in_type == 'xml': + infile = BytesIO(to_bytes(child, errors='surrogate_or_strict')) + + try: + parser = etree.XMLParser() + node = etree.parse(infile, parser) + return node.getroot() + except etree.XMLSyntaxError as e: + module.fail_json(msg="Error while parsing child element: %s" % e) + elif in_type == 'yaml': + if isinstance(child, string_types): + return etree.Element(child) + elif isinstance(child, MutableMapping): + if len(child) > 1: + module.fail_json(msg="Can only create children from hashes with one key") + + (key, value) = next(iteritems(child)) + if isinstance(value, MutableMapping): + children = value.pop('_', None) + + node = etree.Element(key, value) + + if children is not None: + if not isinstance(children, list): + module.fail_json(msg="Invalid children type: %s, must be list." % type(children)) + + subnodes = children_to_nodes(module, children) + node.extend(subnodes) + else: + node = etree.Element(key) + node.text = value + return node + else: + module.fail_json(msg="Invalid child type: %s. Children must be either strings or hashes." % type(child)) + else: + module.fail_json(msg="Invalid child input type: %s. Type must be either xml or yaml." % in_type) + + +def children_to_nodes(module=None, children=None, type='yaml'): + """turn a str/hash/list of str&hash into a list of elements""" + children = [] if children is None else children + + return [child_to_element(module, child, type) for child in children] + + +def make_pretty(module, tree): + xml_string = etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + result = dict( + changed=False, + ) + + if module.params['path']: + xml_file = module.params['path'] + with open(xml_file, 'rb') as xml_content: + if xml_string != xml_content.read(): + result['changed'] = True + if not module.check_mode: + if module.params['backup']: + result['backup_file'] = module.backup_local(module.params['path']) + tree.write(xml_file, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + elif module.params['xmlstring']: + result['xmlstring'] = xml_string + # NOTE: Modifying a string is not considered a change ! + if xml_string != module.params['xmlstring']: + result['changed'] = True + + module.exit_json(**result) + + +def finish(module, tree, xpath, namespaces, changed=False, msg='', hitcount=0, matches=tuple()): + + result = dict( + actions=dict( + xpath=xpath, + namespaces=namespaces, + state=module.params['state'] + ), + changed=has_changed(tree), + ) + + if module.params['count'] or hitcount: + result['count'] = hitcount + + if module.params['print_match'] or matches: + result['matches'] = matches + + if msg: + result['msg'] = msg + + if result['changed']: + if module._diff: + result['diff'] = dict( + before=etree.tostring(orig_doc, xml_declaration=True, encoding='UTF-8', pretty_print=True), + after=etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=True), + ) + + if module.params['path'] and not module.check_mode: + if module.params['backup']: + result['backup_file'] = module.backup_local(module.params['path']) + tree.write(module.params['path'], xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + if module.params['xmlstring']: + result['xmlstring'] = etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + module.exit_json(**result) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + path=dict(type='path', aliases=['dest', 'file']), + xmlstring=dict(type='str'), + xpath=dict(type='str'), + namespaces=dict(type='dict', default={}), + state=dict(type='str', default='present', choices=['absent', 'present'], aliases=['ensure']), + value=dict(type='raw'), + attribute=dict(type='raw'), + add_children=dict(type='list'), + set_children=dict(type='list'), + count=dict(type='bool', default=False), + print_match=dict(type='bool', default=False), + pretty_print=dict(type='bool', default=False), + content=dict(type='str', choices=['attribute', 'text']), + input_type=dict(type='str', default='yaml', choices=['xml', 'yaml']), + backup=dict(type='bool', default=False), + strip_cdata_tags=dict(type='bool', default=False), + insertbefore=dict(type='bool', default=False), + insertafter=dict(type='bool', default=False), + ), + supports_check_mode=True, + required_by=dict( + add_children=['xpath'], + # TODO: Reinstate this in Ansible v2.12 when we have deprecated the incorrect use below + # attribute=['value'], + content=['xpath'], + set_children=['xpath'], + value=['xpath'], + ), + required_if=[ + ['count', True, ['xpath']], + ['print_match', True, ['xpath']], + ['insertbefore', True, ['xpath']], + ['insertafter', True, ['xpath']], + ], + required_one_of=[ + ['path', 'xmlstring'], + ['add_children', 'content', 'count', 'pretty_print', 'print_match', 'set_children', 'value'], + ], + mutually_exclusive=[ + ['add_children', 'content', 'count', 'print_match', 'set_children', 'value'], + ['path', 'xmlstring'], + ['insertbefore', 'insertafter'], + ], + ) + + xml_file = module.params['path'] + xml_string = module.params['xmlstring'] + xpath = module.params['xpath'] + namespaces = module.params['namespaces'] + state = module.params['state'] + value = json_dict_bytes_to_unicode(module.params['value']) + attribute = module.params['attribute'] + set_children = json_dict_bytes_to_unicode(module.params['set_children']) + add_children = json_dict_bytes_to_unicode(module.params['add_children']) + pretty_print = module.params['pretty_print'] + content = module.params['content'] + input_type = module.params['input_type'] + print_match = module.params['print_match'] + count = module.params['count'] + backup = module.params['backup'] + strip_cdata_tags = module.params['strip_cdata_tags'] + insertbefore = module.params['insertbefore'] + insertafter = module.params['insertafter'] + + # Check if we have lxml 2.3.0 or newer installed + if not HAS_LXML: + module.fail_json(msg=missing_required_lib("lxml"), exception=LXML_IMP_ERR) + elif LooseVersion('.'.join(to_native(f) for f in etree.LXML_VERSION)) < LooseVersion('2.3.0'): + module.fail_json(msg='The xml ansible module requires lxml 2.3.0 or newer installed on the managed machine') + elif LooseVersion('.'.join(to_native(f) for f in etree.LXML_VERSION)) < LooseVersion('3.0.0'): + module.warn('Using lxml version lower than 3.0.0 does not guarantee predictable element attribute order.') + + # Report wrongly used attribute parameter when using content=attribute + # TODO: Remove this in Ansible v2.12 (and reinstate strict parameter test above) and remove the integration test example + if content == 'attribute' and attribute is not None: + module.deprecate("Parameter 'attribute=%s' is ignored when using 'content=attribute' only 'xpath' is used. Please remove entry." % attribute, + '2.12', collection_name='ansible.builtin') + + # Check if the file exists + if xml_string: + infile = BytesIO(to_bytes(xml_string, errors='surrogate_or_strict')) + elif os.path.isfile(xml_file): + infile = open(xml_file, 'rb') + else: + module.fail_json(msg="The target XML source '%s' does not exist." % xml_file) + + # Parse and evaluate xpath expression + if xpath is not None: + try: + etree.XPath(xpath) + except etree.XPathSyntaxError as e: + module.fail_json(msg="Syntax error in xpath expression: %s (%s)" % (xpath, e)) + except etree.XPathEvalError as e: + module.fail_json(msg="Evaluation error in xpath expression: %s (%s)" % (xpath, e)) + + # Try to parse in the target XML file + try: + parser = etree.XMLParser(remove_blank_text=pretty_print, strip_cdata=strip_cdata_tags) + doc = etree.parse(infile, parser) + except etree.XMLSyntaxError as e: + module.fail_json(msg="Error while parsing document: %s (%s)" % (xml_file or 'xml_string', e)) + + # Ensure we have the original copy to compare + global orig_doc + orig_doc = copy.deepcopy(doc) + + if print_match: + do_print_match(module, doc, xpath, namespaces) + + if count: + count_nodes(module, doc, xpath, namespaces) + + if content == 'attribute': + get_element_attr(module, doc, xpath, namespaces) + elif content == 'text': + get_element_text(module, doc, xpath, namespaces) + + # File exists: + if state == 'absent': + # - absent: delete xpath target + delete_xpath_target(module, doc, xpath, namespaces) + + # - present: carry on + + # children && value both set?: should have already aborted by now + # add_children && set_children both set?: should have already aborted by now + + # set_children set? + if set_children: + set_target_children(module, doc, xpath, namespaces, set_children, input_type) + + # add_children set? + if add_children: + add_target_children(module, doc, xpath, namespaces, add_children, input_type, insertbefore, insertafter) + + # No?: Carry on + + # Is the xpath target an attribute selector? + if value is not None: + set_target(module, doc, xpath, namespaces, attribute, value) + + # If an xpath was provided, we need to do something with the data + if xpath is not None: + ensure_xpath_exists(module, doc, xpath, namespaces) + + # Otherwise only reformat the xml data? + if pretty_print: + make_pretty(module, doc) + + module.fail_json(msg="Don't know what to do") + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/zypper.py b/test/support/integration/plugins/modules/zypper.py new file mode 100644 index 00000000..bfb31819 --- /dev/null +++ b/test/support/integration/plugins/modules/zypper.py @@ -0,0 +1,540 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, Patrick Callahan <pmc@patrickcallahan.com> +# based on +# openbsd_pkg +# (c) 2013 +# Patrik Lundin <patrik.lundin.swe@gmail.com> +# +# yum +# (c) 2012, Red Hat, Inc +# Written by Seth Vidal <skvidal at fedoraproject.org> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: zypper +author: + - "Patrick Callahan (@dirtyharrycallahan)" + - "Alexander Gubin (@alxgu)" + - "Thomas O'Donnell (@andytom)" + - "Robin Roth (@robinro)" + - "Andrii Radyk (@AnderEnder)" +version_added: "1.2" +short_description: Manage packages on SUSE and openSUSE +description: + - Manage packages on SUSE and openSUSE using the zypper and rpm tools. +options: + name: + description: + - Package name C(name) or package specifier or a list of either. + - Can include a version like C(name=1.0), C(name>3.4) or C(name<=2.7). If a version is given, C(oldpackage) is implied and zypper is allowed to + update the package within the version range given. + - You can also pass a url or a local path to a rpm file. + - When using state=latest, this can be '*', which updates all installed packages. + required: true + aliases: [ 'pkg' ] + state: + description: + - C(present) will make sure the package is installed. + C(latest) will make sure the latest version of the package is installed. + C(absent) will make sure the specified package is not installed. + C(dist-upgrade) will make sure the latest version of all installed packages from all enabled repositories is installed. + - When using C(dist-upgrade), I(name) should be C('*'). + required: false + choices: [ present, latest, absent, dist-upgrade ] + default: "present" + type: + description: + - The type of package to be operated on. + required: false + choices: [ package, patch, pattern, product, srcpackage, application ] + default: "package" + version_added: "2.0" + extra_args_precommand: + version_added: "2.6" + required: false + description: + - Add additional global target options to C(zypper). + - Options should be supplied in a single line as if given in the command line. + disable_gpg_check: + description: + - Whether to disable to GPG signature checking of the package + signature being installed. Has an effect only if state is + I(present) or I(latest). + required: false + default: "no" + type: bool + disable_recommends: + version_added: "1.8" + description: + - Corresponds to the C(--no-recommends) option for I(zypper). Default behavior (C(yes)) modifies zypper's default behavior; C(no) does + install recommended packages. + required: false + default: "yes" + type: bool + force: + version_added: "2.2" + description: + - Adds C(--force) option to I(zypper). Allows to downgrade packages and change vendor or architecture. + required: false + default: "no" + type: bool + force_resolution: + version_added: "2.10" + description: + - Adds C(--force-resolution) option to I(zypper). Allows to (un)install packages with conflicting requirements (resolver will choose a solution). + required: false + default: "no" + type: bool + update_cache: + version_added: "2.2" + description: + - Run the equivalent of C(zypper refresh) before the operation. Disabled in check mode. + required: false + default: "no" + type: bool + aliases: [ "refresh" ] + oldpackage: + version_added: "2.2" + description: + - Adds C(--oldpackage) option to I(zypper). Allows to downgrade packages with less side-effects than force. This is implied as soon as a + version is specified as part of the package name. + required: false + default: "no" + type: bool + extra_args: + version_added: "2.4" + required: false + description: + - Add additional options to C(zypper) command. + - Options should be supplied in a single line as if given in the command line. +notes: + - When used with a `loop:` each package will be processed individually, + it is much more efficient to pass the list directly to the `name` option. +# informational: requirements for nodes +requirements: + - "zypper >= 1.0 # included in openSUSE >= 11.1 or SUSE Linux Enterprise Server/Desktop >= 11.0" + - python-xml + - rpm +''' + +EXAMPLES = ''' +# Install "nmap" +- zypper: + name: nmap + state: present + +# Install apache2 with recommended packages +- zypper: + name: apache2 + state: present + disable_recommends: no + +# Apply a given patch +- zypper: + name: openSUSE-2016-128 + state: present + type: patch + +# Remove the "nmap" package +- zypper: + name: nmap + state: absent + +# Install the nginx rpm from a remote repo +- zypper: + name: 'http://nginx.org/packages/sles/12/x86_64/RPMS/nginx-1.8.0-1.sles12.ngx.x86_64.rpm' + state: present + +# Install local rpm file +- zypper: + name: /tmp/fancy-software.rpm + state: present + +# Update all packages +- zypper: + name: '*' + state: latest + +# Apply all available patches +- zypper: + name: '*' + state: latest + type: patch + +# Perform a dist-upgrade with additional arguments +- zypper: + name: '*' + state: dist-upgrade + extra_args: '--no-allow-vendor-change --allow-arch-change' + +# Refresh repositories and update package "openssl" +- zypper: + name: openssl + state: present + update_cache: yes + +# Install specific version (possible comparisons: <, >, <=, >=, =) +- zypper: + name: 'docker>=1.10' + state: present + +# Wait 20 seconds to acquire the lock before failing +- zypper: + name: mosh + state: present + environment: + ZYPP_LOCK_TIMEOUT: 20 +''' + +import xml +import re +from xml.dom.minidom import parseString as parseXML +from ansible.module_utils.six import iteritems +from ansible.module_utils._text import to_native + +# import module snippets +from ansible.module_utils.basic import AnsibleModule + + +class Package: + def __init__(self, name, prefix, version): + self.name = name + self.prefix = prefix + self.version = version + self.shouldinstall = (prefix == '+') + + def __str__(self): + return self.prefix + self.name + self.version + + +def split_name_version(name): + """splits of the package name and desired version + + example formats: + - docker>=1.10 + - apache=2.4 + + Allowed version specifiers: <, >, <=, >=, = + Allowed version format: [0-9.-]* + + Also allows a prefix indicating remove "-", "~" or install "+" + """ + + prefix = '' + if name[0] in ['-', '~', '+']: + prefix = name[0] + name = name[1:] + if prefix == '~': + prefix = '-' + + version_check = re.compile('^(.*?)((?:<|>|<=|>=|=)[0-9.-]*)?$') + try: + reres = version_check.match(name) + name, version = reres.groups() + if version is None: + version = '' + return prefix, name, version + except Exception: + return prefix, name, '' + + +def get_want_state(names, remove=False): + packages = [] + urls = [] + for name in names: + if '://' in name or name.endswith('.rpm'): + urls.append(name) + else: + prefix, pname, version = split_name_version(name) + if prefix not in ['-', '+']: + if remove: + prefix = '-' + else: + prefix = '+' + packages.append(Package(pname, prefix, version)) + return packages, urls + + +def get_installed_state(m, packages): + "get installed state of packages" + + cmd = get_cmd(m, 'search') + cmd.extend(['--match-exact', '--details', '--installed-only']) + cmd.extend([p.name for p in packages]) + return parse_zypper_xml(m, cmd, fail_not_found=False)[0] + + +def parse_zypper_xml(m, cmd, fail_not_found=True, packages=None): + rc, stdout, stderr = m.run_command(cmd, check_rc=False) + + try: + dom = parseXML(stdout) + except xml.parsers.expat.ExpatError as exc: + m.fail_json(msg="Failed to parse zypper xml output: %s" % to_native(exc), + rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + + if rc == 104: + # exit code 104 is ZYPPER_EXIT_INF_CAP_NOT_FOUND (no packages found) + if fail_not_found: + errmsg = dom.getElementsByTagName('message')[-1].childNodes[0].data + m.fail_json(msg=errmsg, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + else: + return {}, rc, stdout, stderr + elif rc in [0, 106, 103]: + # zypper exit codes + # 0: success + # 106: signature verification failed + # 103: zypper was upgraded, run same command again + if packages is None: + firstrun = True + packages = {} + solvable_list = dom.getElementsByTagName('solvable') + for solvable in solvable_list: + name = solvable.getAttribute('name') + packages[name] = {} + packages[name]['version'] = solvable.getAttribute('edition') + packages[name]['oldversion'] = solvable.getAttribute('edition-old') + status = solvable.getAttribute('status') + packages[name]['installed'] = status == "installed" + packages[name]['group'] = solvable.parentNode.nodeName + if rc == 103 and firstrun: + # if this was the first run and it failed with 103 + # run zypper again with the same command to complete update + return parse_zypper_xml(m, cmd, fail_not_found=fail_not_found, packages=packages) + + return packages, rc, stdout, stderr + m.fail_json(msg='Zypper run command failed with return code %s.' % rc, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + + +def get_cmd(m, subcommand): + "puts together the basic zypper command arguments with those passed to the module" + is_install = subcommand in ['install', 'update', 'patch', 'dist-upgrade'] + is_refresh = subcommand == 'refresh' + cmd = ['/usr/bin/zypper', '--quiet', '--non-interactive', '--xmlout'] + if m.params['extra_args_precommand']: + args_list = m.params['extra_args_precommand'].split() + cmd.extend(args_list) + # add global options before zypper command + if (is_install or is_refresh) and m.params['disable_gpg_check']: + cmd.append('--no-gpg-checks') + + if subcommand == 'search': + cmd.append('--disable-repositories') + + cmd.append(subcommand) + if subcommand not in ['patch', 'dist-upgrade'] and not is_refresh: + cmd.extend(['--type', m.params['type']]) + if m.check_mode and subcommand != 'search': + cmd.append('--dry-run') + if is_install: + cmd.append('--auto-agree-with-licenses') + if m.params['disable_recommends']: + cmd.append('--no-recommends') + if m.params['force']: + cmd.append('--force') + if m.params['force_resolution']: + cmd.append('--force-resolution') + if m.params['oldpackage']: + cmd.append('--oldpackage') + if m.params['extra_args']: + args_list = m.params['extra_args'].split(' ') + cmd.extend(args_list) + + return cmd + + +def set_diff(m, retvals, result): + # TODO: if there is only one package, set before/after to version numbers + packages = {'installed': [], 'removed': [], 'upgraded': []} + if result: + for p in result: + group = result[p]['group'] + if group == 'to-upgrade': + versions = ' (' + result[p]['oldversion'] + ' => ' + result[p]['version'] + ')' + packages['upgraded'].append(p + versions) + elif group == 'to-install': + packages['installed'].append(p) + elif group == 'to-remove': + packages['removed'].append(p) + + output = '' + for state in packages: + if packages[state]: + output += state + ': ' + ', '.join(packages[state]) + '\n' + if 'diff' not in retvals: + retvals['diff'] = {} + if 'prepared' not in retvals['diff']: + retvals['diff']['prepared'] = output + else: + retvals['diff']['prepared'] += '\n' + output + + +def package_present(m, name, want_latest): + "install and update (if want_latest) the packages in name_install, while removing the packages in name_remove" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + packages, urls = get_want_state(name) + + # add oldpackage flag when a version is given to allow downgrades + if any(p.version for p in packages): + m.params['oldpackage'] = True + + if not want_latest: + # for state=present: filter out already installed packages + # if a version is given leave the package in to let zypper handle the version + # resolution + packageswithoutversion = [p for p in packages if not p.version] + prerun_state = get_installed_state(m, packageswithoutversion) + # generate lists of packages to install or remove + packages = [p for p in packages if p.shouldinstall != (p.name in prerun_state)] + + if not packages and not urls: + # nothing to install/remove and nothing to update + return None, retvals + + # zypper install also updates packages + cmd = get_cmd(m, 'install') + cmd.append('--') + cmd.extend(urls) + # pass packages to zypper + # allow for + or - prefixes in install/remove lists + # also add version specifier if given + # do this in one zypper run to allow for dependency-resolution + # for example "-exim postfix" runs without removing packages depending on mailserver + cmd.extend([str(p) for p in packages]) + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + + return result, retvals + + +def package_update_all(m): + "run update or patch on all available packages" + + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + if m.params['type'] == 'patch': + cmdname = 'patch' + elif m.params['state'] == 'dist-upgrade': + cmdname = 'dist-upgrade' + else: + cmdname = 'update' + + cmd = get_cmd(m, cmdname) + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + return result, retvals + + +def package_absent(m, name): + "remove the packages in name" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + # Get package state + packages, urls = get_want_state(name, remove=True) + if any(p.prefix == '+' for p in packages): + m.fail_json(msg="Can not combine '+' prefix with state=remove/absent.") + if urls: + m.fail_json(msg="Can not remove via URL.") + if m.params['type'] == 'patch': + m.fail_json(msg="Can not remove patches.") + prerun_state = get_installed_state(m, packages) + packages = [p for p in packages if p.name in prerun_state] + + if not packages: + return None, retvals + + cmd = get_cmd(m, 'remove') + cmd.extend([p.name + p.version for p in packages]) + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + return result, retvals + + +def repo_refresh(m): + "update the repositories" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + + cmd = get_cmd(m, 'refresh') + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + + return retvals + +# =========================================== +# Main control flow + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(required=True, aliases=['pkg'], type='list'), + state=dict(required=False, default='present', choices=['absent', 'installed', 'latest', 'present', 'removed', 'dist-upgrade']), + type=dict(required=False, default='package', choices=['package', 'patch', 'pattern', 'product', 'srcpackage', 'application']), + extra_args_precommand=dict(required=False, default=None), + disable_gpg_check=dict(required=False, default='no', type='bool'), + disable_recommends=dict(required=False, default='yes', type='bool'), + force=dict(required=False, default='no', type='bool'), + force_resolution=dict(required=False, default='no', type='bool'), + update_cache=dict(required=False, aliases=['refresh'], default='no', type='bool'), + oldpackage=dict(required=False, default='no', type='bool'), + extra_args=dict(required=False, default=None), + ), + supports_check_mode=True + ) + + name = module.params['name'] + state = module.params['state'] + update_cache = module.params['update_cache'] + + # remove empty strings from package list + name = list(filter(None, name)) + + # Refresh repositories + if update_cache and not module.check_mode: + retvals = repo_refresh(module) + + if retvals['rc'] != 0: + module.fail_json(msg="Zypper refresh run failed.", **retvals) + + # Perform requested action + if name == ['*'] and state in ['latest', 'dist-upgrade']: + packages_changed, retvals = package_update_all(module) + elif name != ['*'] and state == 'dist-upgrade': + module.fail_json(msg="Can not dist-upgrade specific packages.") + else: + if state in ['absent', 'removed']: + packages_changed, retvals = package_absent(module, name) + elif state in ['installed', 'present', 'latest']: + packages_changed, retvals = package_present(module, name, state == 'latest') + + retvals['changed'] = retvals['rc'] == 0 and bool(packages_changed) + + if module._diff: + set_diff(module, retvals, packages_changed) + + if retvals['rc'] != 0: + module.fail_json(msg="Zypper run failed.", **retvals) + + if not retvals['changed']: + del retvals['stdout'] + del retvals['stderr'] + + module.exit_json(name=name, state=state, update_cache=update_cache, **retvals) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py new file mode 100644 index 00000000..089b339f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/cli_config.py @@ -0,0 +1,40 @@ +# +# Copyright 2018 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from ansible_collections.ansible.netcommon.plugins.action.network import ( + ActionModule as ActionNetworkModule, +) + + +class ActionModule(ActionNetworkModule): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + self._config_module = True + if self._play_context.connection.split(".")[-1] != "network_cli": + return { + "failed": True, + "msg": "Connection type %s is not valid for cli_config module" + % self._play_context.connection, + } + + return super(ActionModule, self).run(task_vars=task_vars) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py new file mode 100644 index 00000000..542dcfef --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_base.py @@ -0,0 +1,90 @@ +# Copyright: (c) 2015, Ansible Inc, +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import copy + +from ansible.errors import AnsibleError +from ansible.plugins.action import ActionBase +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + result = {} + play_context = copy.deepcopy(self._play_context) + play_context.network_os = self._get_network_os(task_vars) + new_task = self._task.copy() + + module = self._get_implementation_module( + play_context.network_os, self._task.action + ) + if not module: + if self._task.args["fail_on_missing_module"]: + result["failed"] = True + else: + result["failed"] = False + + result["msg"] = ( + "Could not find implementation module %s for %s" + % (self._task.action, play_context.network_os) + ) + return result + + new_task.action = module + + action = self._shared_loader_obj.action_loader.get( + play_context.network_os, + task=new_task, + connection=self._connection, + play_context=play_context, + loader=self._loader, + templar=self._templar, + shared_loader_obj=self._shared_loader_obj, + ) + display.vvvv("Running implementation module %s" % module) + return action.run(task_vars=task_vars) + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host to use platform agnostic modules" + ) + + return network_os + + def _get_implementation_module(self, network_os, platform_agnostic_module): + module_name = ( + network_os.split(".")[-1] + + "_" + + platform_agnostic_module.partition("_")[2] + ) + if "." in network_os: + fqcn_module = ".".join(network_os.split(".")[0:-1]) + implementation_module = fqcn_module + "." + module_name + else: + implementation_module = module_name + + if implementation_module not in self._shared_loader_obj.module_loader: + implementation_module = None + + return implementation_module diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py new file mode 100644 index 00000000..40205a46 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_get.py @@ -0,0 +1,199 @@ +# (c) 2018, Ansible Inc, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import re +import uuid +import hashlib + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible.plugins.action import ActionBase +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + socket_path = None + self._get_network_os(task_vars) + persistent_connection = self._play_context.connection.split(".")[-1] + + result = super(ActionModule, self).run(task_vars=task_vars) + + if persistent_connection != "network_cli": + # It is supported only with network_cli + result["failed"] = True + result["msg"] = ( + "connection type %s is not valid for net_get module," + " please use fully qualified name of network_cli connection type" + % self._play_context.connection + ) + return result + + try: + src = self._task.args["src"] + except KeyError as exc: + return { + "failed": True, + "msg": "missing required argument: %s" % exc, + } + + # Get destination file if specified + dest = self._task.args.get("dest") + + if dest is None: + dest = self._get_default_dest(src) + else: + dest = self._handle_dest_path(dest) + + # Get proto + proto = self._task.args.get("protocol") + if proto is None: + proto = "scp" + + if socket_path is None: + socket_path = self._connection.socket_path + + conn = Connection(socket_path) + sock_timeout = conn.get_option("persistent_command_timeout") + + try: + changed = self._handle_existing_file( + conn, src, dest, proto, sock_timeout + ) + if changed is False: + result["changed"] = changed + result["destination"] = dest + return result + except Exception as exc: + result["msg"] = ( + "Warning: %s idempotency check failed. Check dest" % exc + ) + + try: + conn.get_file( + source=src, destination=dest, proto=proto, timeout=sock_timeout + ) + except Exception as exc: + result["failed"] = True + result["msg"] = "Exception received: %s" % exc + + result["changed"] = changed + result["destination"] = dest + return result + + def _handle_dest_path(self, dest): + working_path = self._get_working_path() + + if os.path.isabs(dest) or urlsplit("dest").scheme: + dst = dest + else: + dst = self._loader.path_dwim_relative(working_path, "", dest) + + return dst + + def _get_src_filename_from_path(self, src_path): + filename_list = re.split("/|:", src_path) + return filename_list[-1] + + def _get_default_dest(self, src_path): + dest_path = self._get_working_path() + src_fname = self._get_src_filename_from_path(src_path) + filename = "%s/%s" % (dest_path, src_fname) + return filename + + def _handle_existing_file(self, conn, source, dest, proto, timeout): + """ + Determines whether the source and destination file match. + + :return: False if source and dest both exist and have matching sha1 sums, True otherwise. + """ + if not os.path.exists(dest): + return True + + cwd = self._loader.get_basedir() + filename = str(uuid.uuid4()) + tmp_dest_file = os.path.join(cwd, filename) + try: + conn.get_file( + source=source, + destination=tmp_dest_file, + proto=proto, + timeout=timeout, + ) + except ConnectionError as exc: + error = to_text(exc) + if error.endswith("No such file or directory"): + if os.path.exists(tmp_dest_file): + os.remove(tmp_dest_file) + return True + + try: + with open(tmp_dest_file, "r") as f: + new_content = f.read() + with open(dest, "r") as f: + old_content = f.read() + except (IOError, OSError): + os.remove(tmp_dest_file) + raise + + sha1 = hashlib.sha1() + old_content_b = to_bytes(old_content, errors="surrogate_or_strict") + sha1.update(old_content_b) + checksum_old = sha1.digest() + + sha1 = hashlib.sha1() + new_content_b = to_bytes(new_content, errors="surrogate_or_strict") + sha1.update(new_content_b) + checksum_new = sha1.digest() + os.remove(tmp_dest_file) + if checksum_old == checksum_new: + return False + return True + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py new file mode 100644 index 00000000..955329d4 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/net_put.py @@ -0,0 +1,235 @@ +# (c) 2018, Ansible Inc, +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import uuid +import hashlib + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible.plugins.action import ActionBase +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionBase): + def run(self, tmp=None, task_vars=None): + socket_path = None + network_os = self._get_network_os(task_vars).split(".")[-1] + persistent_connection = self._play_context.connection.split(".")[-1] + + result = super(ActionModule, self).run(task_vars=task_vars) + + if persistent_connection != "network_cli": + # It is supported only with network_cli + result["failed"] = True + result["msg"] = ( + "connection type %s is not valid for net_put module," + " please use fully qualified name of network_cli connection type" + % self._play_context.connection + ) + return result + + try: + src = self._task.args["src"] + except KeyError as exc: + return { + "failed": True, + "msg": "missing required argument: %s" % exc, + } + + src_file_path_name = src + + # Get destination file if specified + dest = self._task.args.get("dest") + + # Get proto + proto = self._task.args.get("protocol") + if proto is None: + proto = "scp" + + # Get mode if set + mode = self._task.args.get("mode") + if mode is None: + mode = "binary" + + if mode == "text": + try: + self._handle_template(convert_data=False) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + # Now src has resolved file write to disk in current diectory for scp + src = self._task.args.get("src") + filename = str(uuid.uuid4()) + cwd = self._loader.get_basedir() + output_file = os.path.join(cwd, filename) + try: + with open(output_file, "wb") as f: + f.write(to_bytes(src, encoding="utf-8")) + except Exception: + os.remove(output_file) + raise + else: + try: + output_file = self._get_binary_src_file(src) + except ValueError as exc: + return dict(failed=True, msg=to_text(exc)) + + if socket_path is None: + socket_path = self._connection.socket_path + + conn = Connection(socket_path) + sock_timeout = conn.get_option("persistent_command_timeout") + + if dest is None: + dest = src_file_path_name + + try: + changed = self._handle_existing_file( + conn, output_file, dest, proto, sock_timeout + ) + if changed is False: + result["changed"] = changed + result["destination"] = dest + return result + except Exception as exc: + result["msg"] = ( + "Warning: %s idempotency check failed. Check dest" % exc + ) + + try: + conn.copy_file( + source=output_file, + destination=dest, + proto=proto, + timeout=sock_timeout, + ) + except Exception as exc: + if to_text(exc) == "No response from server": + if network_os == "iosxr": + # IOSXR sometimes closes socket prematurely after completion + # of file transfer + result[ + "msg" + ] = "Warning: iosxr scp server pre close issue. Please check dest" + else: + result["failed"] = True + result["msg"] = "Exception received: %s" % exc + + if mode == "text": + # Cleanup tmp file expanded wih ansible vars + os.remove(output_file) + + result["changed"] = changed + result["destination"] = dest + return result + + def _handle_existing_file(self, conn, source, dest, proto, timeout): + """ + Determines whether the source and destination file match. + + :return: False if source and dest both exist and have matching sha1 sums, True otherwise. + """ + cwd = self._loader.get_basedir() + filename = str(uuid.uuid4()) + tmp_source_file = os.path.join(cwd, filename) + try: + conn.get_file( + source=dest, + destination=tmp_source_file, + proto=proto, + timeout=timeout, + ) + except ConnectionError as exc: + error = to_text(exc) + if error.endswith("No such file or directory"): + if os.path.exists(tmp_source_file): + os.remove(tmp_source_file) + return True + + try: + with open(source, "r") as f: + new_content = f.read() + with open(tmp_source_file, "r") as f: + old_content = f.read() + except (IOError, OSError): + os.remove(tmp_source_file) + raise + + sha1 = hashlib.sha1() + old_content_b = to_bytes(old_content, errors="surrogate_or_strict") + sha1.update(old_content_b) + checksum_old = sha1.digest() + + sha1 = hashlib.sha1() + new_content_b = to_bytes(new_content, errors="surrogate_or_strict") + sha1.update(new_content_b) + checksum_new = sha1.digest() + os.remove(tmp_source_file) + if checksum_old == checksum_new: + return False + return True + + def _get_binary_src_file(self, src): + working_path = self._get_working_path() + + if os.path.isabs(src) or urlsplit("src").scheme: + source = src + else: + source = self._loader.path_dwim_relative( + working_path, "templates", src + ) + if not source: + source = self._loader.path_dwim_relative(working_path, src) + + if not os.path.exists(source): + raise ValueError("path specified in src not found") + + return source + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py new file mode 100644 index 00000000..5d05d338 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/action/network.py @@ -0,0 +1,209 @@ +# +# (c) 2018 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import time +import re + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.six.moves.urllib.parse import urlsplit +from ansible.plugins.action.normal import ActionModule as _ActionModule +from ansible.utils.display import Display + +display = Display() + +PRIVATE_KEYS_RE = re.compile("__.+__") + + +class ActionModule(_ActionModule): + def run(self, task_vars=None): + config_module = hasattr(self, "_config_module") and self._config_module + if config_module and self._task.args.get("src"): + try: + self._handle_src_option() + except AnsibleError as e: + return {"failed": True, "msg": e.message, "changed": False} + + result = super(ActionModule, self).run(task_vars=task_vars) + + if ( + config_module + and self._task.args.get("backup") + and not result.get("failed") + ): + self._handle_backup_option(result, task_vars) + + return result + + def _handle_backup_option(self, result, task_vars): + + filename = None + backup_path = None + try: + content = result["__backup__"] + except KeyError: + raise AnsibleError("Failed while reading configuration backup") + + backup_options = self._task.args.get("backup_options") + if backup_options: + filename = backup_options.get("filename") + backup_path = backup_options.get("dir_path") + + if not backup_path: + cwd = self._get_working_path() + backup_path = os.path.join(cwd, "backup") + if not filename: + tstamp = time.strftime( + "%Y-%m-%d@%H:%M:%S", time.localtime(time.time()) + ) + filename = "%s_config.%s" % ( + task_vars["inventory_hostname"], + tstamp, + ) + + dest = os.path.join(backup_path, filename) + backup_path = os.path.expanduser( + os.path.expandvars( + to_bytes(backup_path, errors="surrogate_or_strict") + ) + ) + + if not os.path.exists(backup_path): + os.makedirs(backup_path) + + new_task = self._task.copy() + for item in self._task.args: + if not item.startswith("_"): + new_task.args.pop(item, None) + + new_task.args.update(dict(content=content, dest=dest)) + copy_action = self._shared_loader_obj.action_loader.get( + "copy", + task=new_task, + connection=self._connection, + play_context=self._play_context, + loader=self._loader, + templar=self._templar, + shared_loader_obj=self._shared_loader_obj, + ) + copy_result = copy_action.run(task_vars=task_vars) + if copy_result.get("failed"): + result["failed"] = copy_result["failed"] + result["msg"] = copy_result.get("msg") + return + + result["backup_path"] = dest + if copy_result.get("changed", False): + result["changed"] = copy_result["changed"] + + if backup_options and backup_options.get("filename"): + result["date"] = time.strftime( + "%Y-%m-%d", + time.gmtime(os.stat(result["backup_path"]).st_ctime), + ) + result["time"] = time.strftime( + "%H:%M:%S", + time.gmtime(os.stat(result["backup_path"]).st_ctime), + ) + + else: + result["date"] = tstamp.split("@")[0] + result["time"] = tstamp.split("@")[1] + result["shortname"] = result["backup_path"][::-1].split(".", 1)[1][ + ::-1 + ] + result["filename"] = result["backup_path"].split("/")[-1] + + # strip out any keys that have two leading and two trailing + # underscore characters + for key in list(result.keys()): + if PRIVATE_KEYS_RE.match(key): + del result[key] + + def _get_working_path(self): + cwd = self._loader.get_basedir() + if self._task._role is not None: + cwd = self._task._role._role_path + return cwd + + def _handle_src_option(self, convert_data=True): + src = self._task.args.get("src") + working_path = self._get_working_path() + + if os.path.isabs(src) or urlsplit("src").scheme: + source = src + else: + source = self._loader.path_dwim_relative( + working_path, "templates", src + ) + if not source: + source = self._loader.path_dwim_relative(working_path, src) + + if not os.path.exists(source): + raise AnsibleError("path specified in src not found") + + try: + with open(source, "r") as f: + template_data = to_text(f.read()) + except IOError as e: + raise AnsibleError( + "unable to load src file {0}, I/O error({1}): {2}".format( + source, e.errno, e.strerror + ) + ) + + # Create a template search path in the following order: + # [working_path, self_role_path, dependent_role_paths, dirname(source)] + searchpath = [working_path] + if self._task._role is not None: + searchpath.append(self._task._role._role_path) + if hasattr(self._task, "_block:"): + dep_chain = self._task._block.get_dep_chain() + if dep_chain is not None: + for role in dep_chain: + searchpath.append(role._role_path) + searchpath.append(os.path.dirname(source)) + with self._templar.set_temporary_context(searchpath=searchpath): + self._task.args["src"] = self._templar.template( + template_data, convert_data=convert_data + ) + + def _get_network_os(self, task_vars): + if "network_os" in self._task.args and self._task.args["network_os"]: + display.vvvv("Getting network OS from task argument") + network_os = self._task.args["network_os"] + elif self._play_context.network_os: + display.vvvv("Getting network OS from inventory") + network_os = self._play_context.network_os + elif ( + "network_os" in task_vars.get("ansible_facts", {}) + and task_vars["ansible_facts"]["network_os"] + ): + display.vvvv("Getting network OS from fact") + network_os = task_vars["ansible_facts"]["network_os"] + else: + raise AnsibleError( + "ansible_network_os must be specified on this host" + ) + + return network_os diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py new file mode 100644 index 00000000..33938fd1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/become/enable.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """become: enable +short_description: Switch to elevated permissions on a network device +description: +- This become plugins allows elevated permissions on a remote network device. +author: ansible (@core) +options: + become_pass: + description: password + ini: + - section: enable_become_plugin + key: password + vars: + - name: ansible_become_password + - name: ansible_become_pass + - name: ansible_enable_pass + env: + - name: ANSIBLE_BECOME_PASS + - name: ANSIBLE_ENABLE_PASS +notes: +- enable is really implemented in the network connection handler and as such can only + be used with network connections. +- This plugin ignores the 'become_exe' and 'become_user' settings as it uses an API + and not an executable. +""" + +from ansible.plugins.become import BecomeBase + + +class BecomeModule(BecomeBase): + + name = "ansible.netcommon.enable" + + def build_become_command(self, cmd, shell): + # enable is implemented inside the network connection plugins + return cmd diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py new file mode 100644 index 00000000..b063ef0d --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/httpapi.py @@ -0,0 +1,324 @@ +# (c) 2018 Red Hat Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +connection: httpapi +short_description: Use httpapi to run command on network appliances +description: +- This connection plugin provides a connection to remote devices over a HTTP(S)-based + api. +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the HTTP(S) connection + to. + default: inventory_hostname + vars: + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the HTTP(S) connection. + - When unspecified, will pick 80 or 443 based on the value of use_ssl. + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_httpapi_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load the correct httpapi plugin to communicate with the remote device + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the API connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + needed for the device API. + vars: + - name: ansible_password + - name: ansible_httpapi_pass + - name: ansible_httpapi_password + use_ssl: + type: boolean + description: + - Whether to connect using SSL (HTTPS) or not (HTTP). + default: false + vars: + - name: ansible_httpapi_use_ssl + validate_certs: + type: boolean + description: + - Whether to validate SSL certificates + default: true + vars: + - name: ansible_httpapi_validate_certs + use_proxy: + type: boolean + description: + - Whether to use https_proxy for requests. + default: true + vars: + - name: ansible_httpapi_use_proxy + become: + type: boolean + description: + - The become option will instruct the CLI session to attempt privilege escalation + on platforms that support it. Normally this means transitioning from user mode + to C(enable) mode in the CLI session. If become is set to True and the remote + device does not support privilege escalation or the privilege has already been + elevated, then this option is silently ignored. + - Can be configured from the CLI via the C(--become) or C(-b) options. + default: false + ini: + - section: privilege_escalation + key: become + env: + - name: ANSIBLE_BECOME + vars: + - name: ansible_become + become_method: + description: + - This option allows the become method to be specified in for handling privilege + escalation. Typically the become_method value is set to C(enable) but could + be defined as other values. + default: sudo + ini: + - section: privilege_escalation + key: become_method + env: + - name: ANSIBLE_BECOME_METHOD + vars: + - name: ansible_become_method + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this option + as it could create a security vulnerability by logging sensitive information + in log file. + default: false + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" + +from io import BytesIO + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_bytes +from ansible.module_utils.six import PY3 +from ansible.module_utils.six.moves import cPickle +from ansible.module_utils.six.moves.urllib.error import HTTPError, URLError +from ansible.module_utils.urls import open_url +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import httpapi_loader +from ansible.plugins.connection import NetworkConnectionBase, ensure_connect + + +class Connection(NetworkConnectionBase): + """Network API connection""" + + transport = "ansible.netcommon.httpapi" + has_pipelining = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + + self._url = None + self._auth = None + + if self._network_os: + + self.httpapi = httpapi_loader.get(self._network_os, self) + if self.httpapi: + self._sub_plugin = { + "type": "httpapi", + "name": self.httpapi._load_name, + "obj": self.httpapi, + } + self.queue_message( + "vvvv", + "loaded API plugin %s from path %s for network_os %s" + % ( + self.httpapi._load_name, + self.httpapi._original_path, + self._network_os, + ), + ) + else: + raise AnsibleConnectionFailure( + "unable to load API plugin for network_os %s" + % self._network_os + ) + + else: + raise AnsibleConnectionFailure( + "Unable to automatically determine host network os. Please " + "manually configure ansible_network_os value for this host" + ) + self.queue_message("log", "network_os is set to %s" % self._network_os) + + def update_play_context(self, pc_data): + """Updates the play context information for the connection""" + pc_data = to_bytes(pc_data) + if PY3: + pc_data = cPickle.loads(pc_data, encoding="bytes") + else: + pc_data = cPickle.loads(pc_data) + play_context = PlayContext() + play_context.deserialize(pc_data) + + self.queue_message("vvvv", "updating play_context for connection") + if self._play_context.become ^ play_context.become: + self.set_become(play_context) + if play_context.become is True: + self.queue_message("vvvv", "authorizing connection") + else: + self.queue_message("vvvv", "deauthorizing connection") + + self._play_context = play_context + + def _connect(self): + if not self.connected: + protocol = "https" if self.get_option("use_ssl") else "http" + host = self.get_option("host") + port = self.get_option("port") or ( + 443 if protocol == "https" else 80 + ) + self._url = "%s://%s:%s" % (protocol, host, port) + + self.queue_message( + "vvv", + "ESTABLISH HTTP(S) CONNECTFOR USER: %s TO %s" + % (self._play_context.remote_user, self._url), + ) + self.httpapi.set_become(self._play_context) + self._connected = True + + self.httpapi.login( + self.get_option("remote_user"), self.get_option("password") + ) + + def close(self): + """ + Close the active session to the device + """ + # only close the connection if its connected. + if self._connected: + self.queue_message("vvvv", "closing http(s) connection to device") + self.logout() + + super(Connection, self).close() + + @ensure_connect + def send(self, path, data, **kwargs): + """ + Sends the command to the device over api + """ + url_kwargs = dict( + timeout=self.get_option("persistent_command_timeout"), + validate_certs=self.get_option("validate_certs"), + use_proxy=self.get_option("use_proxy"), + headers={}, + ) + url_kwargs.update(kwargs) + if self._auth: + # Avoid modifying passed-in headers + headers = dict(kwargs.get("headers", {})) + headers.update(self._auth) + url_kwargs["headers"] = headers + else: + url_kwargs["force_basic_auth"] = True + url_kwargs["url_username"] = self.get_option("remote_user") + url_kwargs["url_password"] = self.get_option("password") + + try: + url = self._url + path + self._log_messages( + "send url '%s' with data '%s' and kwargs '%s'" + % (url, data, url_kwargs) + ) + response = open_url(url, data=data, **url_kwargs) + except HTTPError as exc: + is_handled = self.handle_httperror(exc) + if is_handled is True: + return self.send(path, data, **kwargs) + elif is_handled is False: + raise + else: + response = is_handled + except URLError as exc: + raise AnsibleConnectionFailure( + "Could not connect to {0}: {1}".format( + self._url + path, exc.reason + ) + ) + + response_buffer = BytesIO() + resp_data = response.read() + self._log_messages("received response: '%s'" % resp_data) + response_buffer.write(resp_data) + + # Try to assign a new auth token if one is given + self._auth = self.update_auth(response, response_buffer) or self._auth + + response_buffer.seek(0) + + return response, response_buffer diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py new file mode 100644 index 00000000..1e2d3caa --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/netconf.py @@ -0,0 +1,404 @@ +# (c) 2016 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +connection: netconf +short_description: Provides a persistent connection using the netconf protocol +description: +- This connection plugin provides a connection to remote devices over the SSH NETCONF + subsystem. This connection plugin is typically used by network devices for sending + and receiving RPC calls over NETCONF. +- Note this connection plugin requires ncclient to be installed on the local Ansible + controller. +requirements: +- ncclient +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the SSH connection + to. + default: inventory_hostname + vars: + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the SSH connection. + default: 830 + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load a device specific netconf plugin. If this option is not configured + (or set to C(auto)), then Ansible will attempt to guess the correct network_os + to use. If it can not guess a network_os correctly it will use C(default). + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the SSH connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + first establishing the SSH connection. + vars: + - name: ansible_password + - name: ansible_ssh_pass + - name: ansible_ssh_password + - name: ansible_netconf_password + private_key_file: + description: + - The private SSH key or certificate file used to authenticate to the remote device + when first establishing the SSH connection. + ini: + - section: defaults + key: private_key_file + env: + - name: ANSIBLE_PRIVATE_KEY_FILE + vars: + - name: ansible_private_key_file + look_for_keys: + default: true + description: + - Enables looking for ssh keys in the usual locations for ssh keys (e.g. :file:`~/.ssh/id_*`). + env: + - name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS + ini: + - section: paramiko_connection + key: look_for_keys + type: boolean + host_key_checking: + description: Set this to "False" if you want to avoid host key checking by the + underlying tools Ansible uses to connect to the host + type: boolean + default: true + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + - name: ANSIBLE_NETCONF_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: paramiko_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + - name: ansible_netconf_host_key_checking + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + netconf_ssh_config: + description: + - This variable is used to enable bastion/jump host with netconf connection. If + set to True the bastion/jump host ssh settings should be present in ~/.ssh/config + file, alternatively it can be set to custom ssh configuration file path to read + the bastion/jump host settings. + ini: + - section: netconf_connection + key: ssh_config + version_added: '2.7' + env: + - name: ANSIBLE_NETCONF_SSH_CONFIG + vars: + - name: ansible_netconf_ssh_config + version_added: '2.7' + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this option + as it could create a security vulnerability by logging sensitive information + in log file. + default: false + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages +""" + +import os +import logging +import json + +from ansible.errors import AnsibleConnectionFailure, AnsibleError +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.parsing.convert_bool import ( + BOOLEANS_TRUE, + BOOLEANS_FALSE, +) +from ansible.plugins.loader import netconf_loader +from ansible.plugins.connection import NetworkConnectionBase, ensure_connect + +try: + from ncclient import manager + from ncclient.operations import RPCError + from ncclient.transport.errors import SSHUnknownHostError + from ncclient.xml_ import to_ele, to_xml + + HAS_NCCLIENT = True + NCCLIENT_IMP_ERR = None +except ( + ImportError, + AttributeError, +) as err: # paramiko and gssapi are incompatible and raise AttributeError not ImportError + HAS_NCCLIENT = False + NCCLIENT_IMP_ERR = err + +logging.getLogger("ncclient").setLevel(logging.INFO) + + +class Connection(NetworkConnectionBase): + """NetConf connections""" + + transport = "ansible.netcommon.netconf" + has_pipelining = False + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + + # If network_os is not specified then set the network os to auto + # This will be used to trigger the use of guess_network_os when connecting. + self._network_os = self._network_os or "auto" + + self.netconf = netconf_loader.get(self._network_os, self) + if self.netconf: + self._sub_plugin = { + "type": "netconf", + "name": self.netconf._load_name, + "obj": self.netconf, + } + self.queue_message( + "vvvv", + "loaded netconf plugin %s from path %s for network_os %s" + % ( + self.netconf._load_name, + self.netconf._original_path, + self._network_os, + ), + ) + else: + self.netconf = netconf_loader.get("default", self) + self._sub_plugin = { + "type": "netconf", + "name": "default", + "obj": self.netconf, + } + self.queue_message( + "display", + "unable to load netconf plugin for network_os %s, falling back to default plugin" + % self._network_os, + ) + + self.queue_message("log", "network_os is set to %s" % self._network_os) + self._manager = None + self.key_filename = None + self._ssh_config = None + + def exec_command(self, cmd, in_data=None, sudoable=True): + """Sends the request to the node and returns the reply + The method accepts two forms of request. The first form is as a byte + string that represents xml string be send over netconf session. + The second form is a json-rpc (2.0) byte string. + """ + if self._manager: + # to_ele operates on native strings + request = to_ele(to_native(cmd, errors="surrogate_or_strict")) + + if request is None: + return "unable to parse request" + + try: + reply = self._manager.rpc(request) + except RPCError as exc: + error = self.internal_error( + data=to_text(to_xml(exc.xml), errors="surrogate_or_strict") + ) + return json.dumps(error) + + return reply.data_xml + else: + return super(Connection, self).exec_command(cmd, in_data, sudoable) + + @property + @ensure_connect + def manager(self): + return self._manager + + def _connect(self): + if not HAS_NCCLIENT: + raise AnsibleError( + "%s: %s" + % ( + missing_required_lib("ncclient"), + to_native(NCCLIENT_IMP_ERR), + ) + ) + + self.queue_message("log", "ssh connection done, starting ncclient") + + allow_agent = True + if self._play_context.password is not None: + allow_agent = False + setattr(self._play_context, "allow_agent", allow_agent) + + self.key_filename = ( + self._play_context.private_key_file + or self.get_option("private_key_file") + ) + if self.key_filename: + self.key_filename = str(os.path.expanduser(self.key_filename)) + + self._ssh_config = self.get_option("netconf_ssh_config") + if self._ssh_config in BOOLEANS_TRUE: + self._ssh_config = True + elif self._ssh_config in BOOLEANS_FALSE: + self._ssh_config = None + + # Try to guess the network_os if the network_os is set to auto + if self._network_os == "auto": + for cls in netconf_loader.all(class_only=True): + network_os = cls.guess_network_os(self) + if network_os: + self.queue_message( + "vvv", "discovered network_os %s" % network_os + ) + self._network_os = network_os + + # If we have tried to detect the network_os but were unable to i.e. network_os is still 'auto' + # then use default as the network_os + + if self._network_os == "auto": + # Network os not discovered. Set it to default + self.queue_message( + "vvv", + "Unable to discover network_os. Falling back to default.", + ) + self._network_os = "default" + try: + ncclient_device_handler = self.netconf.get_option( + "ncclient_device_handler" + ) + except KeyError: + ncclient_device_handler = "default" + self.queue_message( + "vvv", + "identified ncclient device handler: %s." + % ncclient_device_handler, + ) + device_params = {"name": ncclient_device_handler} + + try: + port = self._play_context.port or 830 + self.queue_message( + "vvv", + "ESTABLISH NETCONF SSH CONNECTION FOR USER: %s on PORT %s TO %s WITH SSH_CONFIG = %s" + % ( + self._play_context.remote_user, + port, + self._play_context.remote_addr, + self._ssh_config, + ), + ) + self._manager = manager.connect( + host=self._play_context.remote_addr, + port=port, + username=self._play_context.remote_user, + password=self._play_context.password, + key_filename=self.key_filename, + hostkey_verify=self.get_option("host_key_checking"), + look_for_keys=self.get_option("look_for_keys"), + device_params=device_params, + allow_agent=self._play_context.allow_agent, + timeout=self.get_option("persistent_connect_timeout"), + ssh_config=self._ssh_config, + ) + + self._manager._timeout = self.get_option( + "persistent_command_timeout" + ) + except SSHUnknownHostError as exc: + raise AnsibleConnectionFailure(to_native(exc)) + except ImportError: + raise AnsibleError( + "connection=netconf is not supported on {0}".format( + self._network_os + ) + ) + + if not self._manager.connected: + return 1, b"", b"not connected" + + self.queue_message( + "log", "ncclient manager object created successfully" + ) + + self._connected = True + + super(Connection, self)._connect() + + return ( + 0, + to_bytes(self._manager.session_id, errors="surrogate_or_strict"), + b"", + ) + + def close(self): + if self._manager: + self._manager.close_session() + super(Connection, self).close() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py new file mode 100644 index 00000000..8abcf8e8 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/network_cli.py @@ -0,0 +1,924 @@ +# (c) 2016 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +connection: network_cli +short_description: Use network_cli to run command on network appliances +description: +- This connection plugin provides a connection to remote devices over the SSH and + implements a CLI shell. This connection plugin is typically used by network devices + for sending and receiving CLi commands to network devices. +options: + host: + description: + - Specifies the remote device FQDN or IP address to establish the SSH connection + to. + default: inventory_hostname + vars: + - name: ansible_host + port: + type: int + description: + - Specifies the port on the remote device that listens for connections when establishing + the SSH connection. + default: 22 + ini: + - section: defaults + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + vars: + - name: ansible_port + network_os: + description: + - Configures the device platform network operating system. This value is used + to load the correct terminal and cliconf plugins to communicate with the remote + device. + vars: + - name: ansible_network_os + remote_user: + description: + - The username used to authenticate to the remote device when the SSH connection + is first established. If the remote_user is not specified, the connection will + use the username of the logged in user. + - Can be configured from the CLI via the C(--user) or C(-u) options. + ini: + - section: defaults + key: remote_user + env: + - name: ANSIBLE_REMOTE_USER + vars: + - name: ansible_user + password: + description: + - Configures the user password used to authenticate to the remote device when + first establishing the SSH connection. + vars: + - name: ansible_password + - name: ansible_ssh_pass + - name: ansible_ssh_password + private_key_file: + description: + - The private SSH key or certificate file used to authenticate to the remote device + when first establishing the SSH connection. + ini: + - section: defaults + key: private_key_file + env: + - name: ANSIBLE_PRIVATE_KEY_FILE + vars: + - name: ansible_private_key_file + become: + type: boolean + description: + - The become option will instruct the CLI session to attempt privilege escalation + on platforms that support it. Normally this means transitioning from user mode + to C(enable) mode in the CLI session. If become is set to True and the remote + device does not support privilege escalation or the privilege has already been + elevated, then this option is silently ignored. + - Can be configured from the CLI via the C(--become) or C(-b) options. + default: false + ini: + - section: privilege_escalation + key: become + env: + - name: ANSIBLE_BECOME + vars: + - name: ansible_become + become_method: + description: + - This option allows the become method to be specified in for handling privilege + escalation. Typically the become_method value is set to C(enable) but could + be defined as other values. + default: sudo + ini: + - section: privilege_escalation + key: become_method + env: + - name: ANSIBLE_BECOME_METHOD + vars: + - name: ansible_become_method + host_key_auto_add: + type: boolean + description: + - By default, Ansible will prompt the user before adding SSH keys to the known + hosts file. Since persistent connections such as network_cli run in background + processes, the user will never be prompted. By enabling this option, unknown + host keys will automatically be added to the known hosts file. + - Be sure to fully understand the security implications of enabling this option + on production systems as it could create a security vulnerability. + default: false + ini: + - section: paramiko_connection + key: host_key_auto_add + env: + - name: ANSIBLE_HOST_KEY_AUTO_ADD + persistent_connect_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait when trying to initially + establish a persistent connection. If this value expires before the connection + to the remote device is completed, the connection will fail. + default: 30 + ini: + - section: persistent_connection + key: connect_timeout + env: + - name: ANSIBLE_PERSISTENT_CONNECT_TIMEOUT + vars: + - name: ansible_connect_timeout + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close. + default: 30 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout + persistent_buffer_read_timeout: + type: float + description: + - Configures, in seconds, the amount of time to wait for the data to be read from + Paramiko channel after the command prompt is matched. This timeout value ensures + that command prompt matched is correct and there is no more data left to be + received from remote host. + default: 0.1 + ini: + - section: persistent_connection + key: buffer_read_timeout + env: + - name: ANSIBLE_PERSISTENT_BUFFER_READ_TIMEOUT + vars: + - name: ansible_buffer_read_timeout + persistent_log_messages: + type: boolean + description: + - This flag will enable logging the command executed and response received from + target device in the ansible log file. For this option to work 'log_path' ansible + configuration option is required to be set to a file path with write access. + - Be sure to fully understand the security implications of enabling this option + as it could create a security vulnerability by logging sensitive information + in log file. + default: false + ini: + - section: persistent_connection + key: log_messages + env: + - name: ANSIBLE_PERSISTENT_LOG_MESSAGES + vars: + - name: ansible_persistent_log_messages + terminal_stdout_re: + type: list + elements: dict + description: + - A single regex pattern or a sequence of patterns along with optional flags to + match the command prompt from the received response chunk. This option accepts + C(pattern) and C(flags) keys. The value of C(pattern) is a python regex pattern + to match the response and the value of C(flags) is the value accepted by I(flags) + argument of I(re.compile) python method to control the way regex is matched + with the response, for example I('re.I'). + vars: + - name: ansible_terminal_stdout_re + terminal_stderr_re: + type: list + elements: dict + description: + - This option provides the regex pattern and optional flags to match the error + string from the received response chunk. This option accepts C(pattern) and + C(flags) keys. The value of C(pattern) is a python regex pattern to match the + response and the value of C(flags) is the value accepted by I(flags) argument + of I(re.compile) python method to control the way regex is matched with the + response, for example I('re.I'). + vars: + - name: ansible_terminal_stderr_re + terminal_initial_prompt: + type: list + description: + - A single regex pattern or a sequence of patterns to evaluate the expected prompt + at the time of initial login to the remote host. + vars: + - name: ansible_terminal_initial_prompt + terminal_initial_answer: + type: list + description: + - The answer to reply with if the C(terminal_initial_prompt) is matched. The value + can be a single answer or a list of answers for multiple terminal_initial_prompt. + In case the login menu has multiple prompts the sequence of the prompt and excepted + answer should be in same order and the value of I(terminal_prompt_checkall) + should be set to I(True) if all the values in C(terminal_initial_prompt) are + expected to be matched and set to I(False) if any one login prompt is to be + matched. + vars: + - name: ansible_terminal_initial_answer + terminal_initial_prompt_checkall: + type: boolean + description: + - By default the value is set to I(False) and any one of the prompts mentioned + in C(terminal_initial_prompt) option is matched it won't check for other prompts. + When set to I(True) it will check for all the prompts mentioned in C(terminal_initial_prompt) + option in the given order and all the prompts should be received from remote + host if not it will result in timeout. + default: false + vars: + - name: ansible_terminal_initial_prompt_checkall + terminal_inital_prompt_newline: + type: boolean + description: + - This boolean flag, that when set to I(True) will send newline in the response + if any of values in I(terminal_initial_prompt) is matched. + default: true + vars: + - name: ansible_terminal_initial_prompt_newline + network_cli_retries: + description: + - Number of attempts to connect to remote host. The delay time between the retires + increases after every attempt by power of 2 in seconds till either the maximum + attempts are exhausted or any of the C(persistent_command_timeout) or C(persistent_connect_timeout) + timers are triggered. + default: 3 + type: integer + env: + - name: ANSIBLE_NETWORK_CLI_RETRIES + ini: + - section: persistent_connection + key: network_cli_retries + vars: + - name: ansible_network_cli_retries +""" + +from functools import wraps +import getpass +import json +import logging +import re +import os +import signal +import socket +import time +import traceback +from io import BytesIO + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils.six import PY3 +from ansible.module_utils.six.moves import cPickle +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible.module_utils._text import to_bytes, to_text +from ansible.playbook.play_context import PlayContext +from ansible.plugins.connection import NetworkConnectionBase +from ansible.plugins.loader import ( + cliconf_loader, + terminal_loader, + connection_loader, +) + + +def ensure_connect(func): + @wraps(func) + def wrapped(self, *args, **kwargs): + if not self._connected: + self._connect() + self.update_cli_prompt_context() + return func(self, *args, **kwargs) + + return wrapped + + +class AnsibleCmdRespRecv(Exception): + pass + + +class Connection(NetworkConnectionBase): + """ CLI (shell) SSH connections on Paramiko """ + + transport = "ansible.netcommon.network_cli" + has_pipelining = True + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._ssh_shell = None + + self._matched_prompt = None + self._matched_cmd_prompt = None + self._matched_pattern = None + self._last_response = None + self._history = list() + self._command_response = None + self._last_recv_window = None + + self._terminal = None + self.cliconf = None + self._paramiko_conn = None + + # Managing prompt context + self._check_prompt = False + self._task_uuid = to_text(kwargs.get("task_uuid", "")) + + if self._play_context.verbosity > 3: + logging.getLogger("paramiko").setLevel(logging.DEBUG) + + if self._network_os: + self._terminal = terminal_loader.get(self._network_os, self) + if not self._terminal: + raise AnsibleConnectionFailure( + "network os %s is not supported" % self._network_os + ) + + self.cliconf = cliconf_loader.get(self._network_os, self) + if self.cliconf: + self._sub_plugin = { + "type": "cliconf", + "name": self.cliconf._load_name, + "obj": self.cliconf, + } + self.queue_message( + "vvvv", + "loaded cliconf plugin %s from path %s for network_os %s" + % ( + self.cliconf._load_name, + self.cliconf._original_path, + self._network_os, + ), + ) + else: + self.queue_message( + "vvvv", + "unable to load cliconf for network_os %s" + % self._network_os, + ) + else: + raise AnsibleConnectionFailure( + "Unable to automatically determine host network os. Please " + "manually configure ansible_network_os value for this host" + ) + self.queue_message("log", "network_os is set to %s" % self._network_os) + + @property + def paramiko_conn(self): + if self._paramiko_conn is None: + self._paramiko_conn = connection_loader.get( + "paramiko", self._play_context, "/dev/null" + ) + self._paramiko_conn.set_options( + direct={ + "look_for_keys": not bool( + self._play_context.password + and not self._play_context.private_key_file + ) + } + ) + return self._paramiko_conn + + def _get_log_channel(self): + name = "p=%s u=%s | " % (os.getpid(), getpass.getuser()) + name += "paramiko [%s]" % self._play_context.remote_addr + return name + + @ensure_connect + def get_prompt(self): + """Returns the current prompt from the device""" + return self._matched_prompt + + def exec_command(self, cmd, in_data=None, sudoable=True): + # this try..except block is just to handle the transition to supporting + # network_cli as a toplevel connection. Once connection=local is gone, + # this block can be removed as well and all calls passed directly to + # the local connection + if self._ssh_shell: + try: + cmd = json.loads(to_text(cmd, errors="surrogate_or_strict")) + kwargs = { + "command": to_bytes( + cmd["command"], errors="surrogate_or_strict" + ) + } + for key in ( + "prompt", + "answer", + "sendonly", + "newline", + "prompt_retry_check", + ): + if cmd.get(key) is True or cmd.get(key) is False: + kwargs[key] = cmd[key] + elif cmd.get(key) is not None: + kwargs[key] = to_bytes( + cmd[key], errors="surrogate_or_strict" + ) + return self.send(**kwargs) + except ValueError: + cmd = to_bytes(cmd, errors="surrogate_or_strict") + return self.send(command=cmd) + + else: + return super(Connection, self).exec_command(cmd, in_data, sudoable) + + def update_play_context(self, pc_data): + """Updates the play context information for the connection""" + pc_data = to_bytes(pc_data) + if PY3: + pc_data = cPickle.loads(pc_data, encoding="bytes") + else: + pc_data = cPickle.loads(pc_data) + play_context = PlayContext() + play_context.deserialize(pc_data) + + self.queue_message("vvvv", "updating play_context for connection") + if self._play_context.become ^ play_context.become: + if play_context.become is True: + auth_pass = play_context.become_pass + self._terminal.on_become(passwd=auth_pass) + self.queue_message("vvvv", "authorizing connection") + else: + self._terminal.on_unbecome() + self.queue_message("vvvv", "deauthorizing connection") + + self._play_context = play_context + + if hasattr(self, "reset_history"): + self.reset_history() + if hasattr(self, "disable_response_logging"): + self.disable_response_logging() + + def set_check_prompt(self, task_uuid): + self._check_prompt = task_uuid + + def update_cli_prompt_context(self): + # set cli prompt context at the start of new task run only + if self._check_prompt and self._task_uuid != self._check_prompt: + self._task_uuid, self._check_prompt = self._check_prompt, False + self.set_cli_prompt_context() + + def _connect(self): + """ + Connects to the remote device and starts the terminal + """ + if not self.connected: + self.paramiko_conn._set_log_channel(self._get_log_channel()) + self.paramiko_conn.force_persistence = self.force_persistence + + command_timeout = self.get_option("persistent_command_timeout") + max_pause = min( + [ + self.get_option("persistent_connect_timeout"), + command_timeout, + ] + ) + retries = self.get_option("network_cli_retries") + total_pause = 0 + + for attempt in range(retries + 1): + try: + ssh = self.paramiko_conn._connect() + break + except Exception as e: + pause = 2 ** (attempt + 1) + if attempt == retries or total_pause >= max_pause: + raise AnsibleConnectionFailure( + to_text(e, errors="surrogate_or_strict") + ) + else: + msg = ( + u"network_cli_retry: attempt: %d, caught exception(%s), " + u"pausing for %d seconds" + % ( + attempt + 1, + to_text(e, errors="surrogate_or_strict"), + pause, + ) + ) + + self.queue_message("vv", msg) + time.sleep(pause) + total_pause += pause + continue + + self.queue_message("vvvv", "ssh connection done, setting terminal") + self._connected = True + + self._ssh_shell = ssh.ssh.invoke_shell() + self._ssh_shell.settimeout(command_timeout) + + self.queue_message( + "vvvv", + "loaded terminal plugin for network_os %s" % self._network_os, + ) + + terminal_initial_prompt = ( + self.get_option("terminal_initial_prompt") + or self._terminal.terminal_initial_prompt + ) + terminal_initial_answer = ( + self.get_option("terminal_initial_answer") + or self._terminal.terminal_initial_answer + ) + newline = ( + self.get_option("terminal_inital_prompt_newline") + or self._terminal.terminal_inital_prompt_newline + ) + check_all = ( + self.get_option("terminal_initial_prompt_checkall") or False + ) + + self.receive( + prompts=terminal_initial_prompt, + answer=terminal_initial_answer, + newline=newline, + check_all=check_all, + ) + + if self._play_context.become: + self.queue_message("vvvv", "firing event: on_become") + auth_pass = self._play_context.become_pass + self._terminal.on_become(passwd=auth_pass) + + self.queue_message("vvvv", "firing event: on_open_shell()") + self._terminal.on_open_shell() + + self.queue_message( + "vvvv", "ssh connection has completed successfully" + ) + + return self + + def close(self): + """ + Close the active connection to the device + """ + # only close the connection if its connected. + if self._connected: + self.queue_message("debug", "closing ssh connection to device") + if self._ssh_shell: + self.queue_message("debug", "firing event: on_close_shell()") + self._terminal.on_close_shell() + self._ssh_shell.close() + self._ssh_shell = None + self.queue_message("debug", "cli session is now closed") + + self.paramiko_conn.close() + self._paramiko_conn = None + self.queue_message( + "debug", "ssh connection has been closed successfully" + ) + super(Connection, self).close() + + def receive( + self, + command=None, + prompts=None, + answer=None, + newline=True, + prompt_retry_check=False, + check_all=False, + ): + """ + Handles receiving of output from command + """ + self._matched_prompt = None + self._matched_cmd_prompt = None + recv = BytesIO() + handled = False + command_prompt_matched = False + matched_prompt_window = window_count = 0 + + # set terminal regex values for command prompt and errors in response + self._terminal_stderr_re = self._get_terminal_std_re( + "terminal_stderr_re" + ) + self._terminal_stdout_re = self._get_terminal_std_re( + "terminal_stdout_re" + ) + + cache_socket_timeout = self._ssh_shell.gettimeout() + command_timeout = self.get_option("persistent_command_timeout") + self._validate_timeout_value( + command_timeout, "persistent_command_timeout" + ) + if cache_socket_timeout != command_timeout: + self._ssh_shell.settimeout(command_timeout) + + buffer_read_timeout = self.get_option("persistent_buffer_read_timeout") + self._validate_timeout_value( + buffer_read_timeout, "persistent_buffer_read_timeout" + ) + + self._log_messages("command: %s" % command) + while True: + if command_prompt_matched: + try: + signal.signal( + signal.SIGALRM, self._handle_buffer_read_timeout + ) + signal.setitimer(signal.ITIMER_REAL, buffer_read_timeout) + data = self._ssh_shell.recv(256) + signal.alarm(0) + self._log_messages( + "response-%s: %s" % (window_count + 1, data) + ) + # if data is still received on channel it indicates the prompt string + # is wrongly matched in between response chunks, continue to read + # remaining response. + command_prompt_matched = False + + # restart command_timeout timer + signal.signal(signal.SIGALRM, self._handle_command_timeout) + signal.alarm(command_timeout) + + except AnsibleCmdRespRecv: + # reset socket timeout to global timeout + self._ssh_shell.settimeout(cache_socket_timeout) + return self._command_response + else: + data = self._ssh_shell.recv(256) + self._log_messages( + "response-%s: %s" % (window_count + 1, data) + ) + # when a channel stream is closed, received data will be empty + if not data: + break + + recv.write(data) + offset = recv.tell() - 256 if recv.tell() > 256 else 0 + recv.seek(offset) + + window = self._strip(recv.read()) + self._last_recv_window = window + window_count += 1 + + if prompts and not handled: + handled = self._handle_prompt( + window, prompts, answer, newline, False, check_all + ) + matched_prompt_window = window_count + elif ( + prompts + and handled + and prompt_retry_check + and matched_prompt_window + 1 == window_count + ): + # check again even when handled, if same prompt repeats in next window + # (like in the case of a wrong enable password, etc) indicates + # value of answer is wrong, report this as error. + if self._handle_prompt( + window, + prompts, + answer, + newline, + prompt_retry_check, + check_all, + ): + raise AnsibleConnectionFailure( + "For matched prompt '%s', answer is not valid" + % self._matched_cmd_prompt + ) + + if self._find_prompt(window): + self._last_response = recv.getvalue() + resp = self._strip(self._last_response) + self._command_response = self._sanitize(resp, command) + if buffer_read_timeout == 0.0: + # reset socket timeout to global timeout + self._ssh_shell.settimeout(cache_socket_timeout) + return self._command_response + else: + command_prompt_matched = True + + @ensure_connect + def send( + self, + command, + prompt=None, + answer=None, + newline=True, + sendonly=False, + prompt_retry_check=False, + check_all=False, + ): + """ + Sends the command to the device in the opened shell + """ + if check_all: + prompt_len = len(to_list(prompt)) + answer_len = len(to_list(answer)) + if prompt_len != answer_len: + raise AnsibleConnectionFailure( + "Number of prompts (%s) is not same as that of answers (%s)" + % (prompt_len, answer_len) + ) + try: + cmd = b"%s\r" % command + self._history.append(cmd) + self._ssh_shell.sendall(cmd) + self._log_messages("send command: %s" % cmd) + if sendonly: + return + response = self.receive( + command, prompt, answer, newline, prompt_retry_check, check_all + ) + return to_text(response, errors="surrogate_then_replace") + except (socket.timeout, AttributeError): + self.queue_message("error", traceback.format_exc()) + raise AnsibleConnectionFailure( + "timeout value %s seconds reached while trying to send command: %s" + % (self._ssh_shell.gettimeout(), command.strip()) + ) + + def _handle_buffer_read_timeout(self, signum, frame): + self.queue_message( + "vvvv", + "Response received, triggered 'persistent_buffer_read_timeout' timer of %s seconds" + % self.get_option("persistent_buffer_read_timeout"), + ) + raise AnsibleCmdRespRecv() + + def _handle_command_timeout(self, signum, frame): + msg = ( + "command timeout triggered, timeout value is %s secs.\nSee the timeout setting options in the Network Debug and Troubleshooting Guide." + % self.get_option("persistent_command_timeout") + ) + self.queue_message("log", msg) + raise AnsibleConnectionFailure(msg) + + def _strip(self, data): + """ + Removes ANSI codes from device response + """ + for regex in self._terminal.ansi_re: + data = regex.sub(b"", data) + return data + + def _handle_prompt( + self, + resp, + prompts, + answer, + newline, + prompt_retry_check=False, + check_all=False, + ): + """ + Matches the command prompt and responds + + :arg resp: Byte string containing the raw response from the remote + :arg prompts: Sequence of byte strings that we consider prompts for input + :arg answer: Sequence of Byte string to send back to the remote if we find a prompt. + A carriage return is automatically appended to this string. + :param prompt_retry_check: Bool value for trying to detect more prompts + :param check_all: Bool value to indicate if all the values in prompt sequence should be matched or any one of + given prompt. + :returns: True if a prompt was found in ``resp``. If check_all is True + will True only after all the prompt in the prompts list are matched. False otherwise. + """ + single_prompt = False + if not isinstance(prompts, list): + prompts = [prompts] + single_prompt = True + if not isinstance(answer, list): + answer = [answer] + prompts_regex = [re.compile(to_bytes(r), re.I) for r in prompts] + for index, regex in enumerate(prompts_regex): + match = regex.search(resp) + if match: + self._matched_cmd_prompt = match.group() + self._log_messages( + "matched command prompt: %s" % self._matched_cmd_prompt + ) + + # if prompt_retry_check is enabled to check if same prompt is + # repeated don't send answer again. + if not prompt_retry_check: + prompt_answer = ( + answer[index] if len(answer) > index else answer[0] + ) + self._ssh_shell.sendall(b"%s" % prompt_answer) + if newline: + self._ssh_shell.sendall(b"\r") + prompt_answer += b"\r" + self._log_messages( + "matched command prompt answer: %s" % prompt_answer + ) + if check_all and prompts and not single_prompt: + prompts.pop(0) + answer.pop(0) + return False + return True + return False + + def _sanitize(self, resp, command=None): + """ + Removes elements from the response before returning to the caller + """ + cleaned = [] + for line in resp.splitlines(): + if command and line.strip() == command.strip(): + continue + + for prompt in self._matched_prompt.strip().splitlines(): + if prompt.strip() in line: + break + else: + cleaned.append(line) + return b"\n".join(cleaned).strip() + + def _find_prompt(self, response): + """Searches the buffered response for a matching command prompt + """ + errored_response = None + is_error_message = False + + for regex in self._terminal_stderr_re: + if regex.search(response): + is_error_message = True + + # Check if error response ends with command prompt if not + # receive it buffered prompt + for regex in self._terminal_stdout_re: + match = regex.search(response) + if match: + errored_response = response + self._matched_pattern = regex.pattern + self._matched_prompt = match.group() + self._log_messages( + "matched error regex '%s' from response '%s'" + % (self._matched_pattern, errored_response) + ) + break + + if not is_error_message: + for regex in self._terminal_stdout_re: + match = regex.search(response) + if match: + self._matched_pattern = regex.pattern + self._matched_prompt = match.group() + self._log_messages( + "matched cli prompt '%s' with regex '%s' from response '%s'" + % ( + self._matched_prompt, + self._matched_pattern, + response, + ) + ) + if not errored_response: + return True + + if errored_response: + raise AnsibleConnectionFailure(errored_response) + + return False + + def _validate_timeout_value(self, timeout, timer_name): + if timeout < 0: + raise AnsibleConnectionFailure( + "'%s' timer value '%s' is invalid, value should be greater than or equal to zero." + % (timer_name, timeout) + ) + + def transport_test(self, connect_timeout): + """This method enables wait_for_connection to work. + + As it is used by wait_for_connection, it is called by that module's action plugin, + which is on the controller process, which means that nothing done on this instance + should impact the actual persistent connection... this check is for informational + purposes only and should be properly cleaned up. + """ + + # Force a fresh connect if for some reason we have connected before. + self.close() + self._connect() + self.close() + + def _get_terminal_std_re(self, option): + terminal_std_option = self.get_option(option) + terminal_std_re = [] + + if terminal_std_option: + for item in terminal_std_option: + if "pattern" not in item: + raise AnsibleConnectionFailure( + "'pattern' is a required key for option '%s'," + " received option value is %s" % (option, item) + ) + pattern = br"%s" % to_bytes(item["pattern"]) + flag = item.get("flags", 0) + if flag: + flag = getattr(re, flag.split(".")[1]) + terminal_std_re.append(re.compile(pattern, flag)) + else: + # To maintain backward compatibility + terminal_std_re = getattr(self._terminal, option) + + return terminal_std_re diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py new file mode 100644 index 00000000..b29b4872 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/connection/persistent.py @@ -0,0 +1,97 @@ +# 2017 Red Hat Inc. +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Core Team +connection: persistent +short_description: Use a persistent unix socket for connection +description: +- This is a helper plugin to allow making other connections persistent. +options: + persistent_command_timeout: + type: int + description: + - Configures, in seconds, the amount of time to wait for a command to return from + the remote device. If this timer is exceeded before the command returns, the + connection plugin will raise an exception and close + default: 10 + ini: + - section: persistent_connection + key: command_timeout + env: + - name: ANSIBLE_PERSISTENT_COMMAND_TIMEOUT + vars: + - name: ansible_command_timeout +""" +from ansible.executor.task_executor import start_connection +from ansible.plugins.connection import ConnectionBase +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import Connection as SocketConnection +from ansible.utils.display import Display + +display = Display() + + +class Connection(ConnectionBase): + """ Local based connections """ + + transport = "ansible.netcommon.persistent" + has_pipelining = False + + def __init__(self, play_context, new_stdin, *args, **kwargs): + super(Connection, self).__init__( + play_context, new_stdin, *args, **kwargs + ) + self._task_uuid = to_text(kwargs.get("task_uuid", "")) + + def _connect(self): + self._connected = True + return self + + def exec_command(self, cmd, in_data=None, sudoable=True): + display.vvvv( + "exec_command(), socket_path=%s" % self.socket_path, + host=self._play_context.remote_addr, + ) + connection = SocketConnection(self.socket_path) + out = connection.exec_command(cmd, in_data=in_data, sudoable=sudoable) + return 0, out, "" + + def put_file(self, in_path, out_path): + pass + + def fetch_file(self, in_path, out_path): + pass + + def close(self): + self._connected = False + + def run(self): + """Returns the path of the persistent connection socket. + + Attempts to ensure (within playcontext.timeout seconds) that the + socket path exists. If the path exists (or the timeout has expired), + returns the socket path. + """ + display.vvvv( + "starting connection from persistent connection plugin", + host=self._play_context.remote_addr, + ) + variables = { + "ansible_command_timeout": self.get_option( + "persistent_command_timeout" + ) + } + socket_path = start_connection( + self._play_context, variables, self._task_uuid + ) + display.vvvv( + "local domain socket path is %s" % socket_path, + host=self._play_context.remote_addr, + ) + setattr(self, "_socket_path", socket_path) + return socket_path diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py new file mode 100644 index 00000000..8789075a --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/netconf.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: + host: + description: + - Specifies the DNS host name or address for connecting to the remote device over + the specified transport. The value of host is used as the destination address + for the transport. + type: str + required: true + port: + description: + - Specifies the port to use when building the connection to the remote device. The + port value will default to port 830. + type: int + default: 830 + username: + description: + - Configures the username to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value is + not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME) + will be used instead. + type: str + password: + description: + - Specifies the password to use to authenticate the connection to the remote device. This + value is used to authenticate the SSH session. If the value is not specified + in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) will + be used instead. + type: str + timeout: + description: + - Specifies the timeout in seconds for communicating with the network device for + either connecting or sending commands. If the timeout is exceeded before the + operation is completed, the module will error. + type: int + default: 10 + ssh_keyfile: + description: + - Specifies the SSH key to use to authenticate the connection to the remote device. This + value is the path to the key used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_SSH_KEYFILE) + will be used instead. + type: path + hostkey_verify: + description: + - If set to C(yes), the ssh host key of the device must match a ssh key present + on the host if set to C(no), the ssh host key of the device is not checked. + type: bool + default: true + look_for_keys: + description: + - Enables looking in the usual locations for the ssh keys (e.g. :file:`~/.ssh/id_*`) + type: bool + default: true +notes: +- For information on using netconf see the :ref:`Platform Options guide using Netconf<netconf_enabled_platform_options>` +- For more information on using Ansible to manage network devices see the :ref:`Ansible + Network Guide <network_guide>` +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py new file mode 100644 index 00000000..ad65f6ef --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/doc_fragments/network_agnostic.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019 Ansible, Inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: {} +notes: +- This module is supported on C(ansible_network_os) network platforms. See the :ref:`Network + Platform Options <platform_options>` for details. +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py new file mode 100644 index 00000000..6ae47a73 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/ipaddr.py @@ -0,0 +1,1186 @@ +# (c) 2014, Maciej Delmanowski <drybjed@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from functools import partial +import types + +try: + import netaddr +except ImportError: + # in this case, we'll make the filters return error messages (see bottom) + netaddr = None +else: + + class mac_linux(netaddr.mac_unix): + pass + + mac_linux.word_fmt = "%.2x" + +from ansible import errors + + +# ---- IP address and network query helpers ---- +def _empty_ipaddr_query(v, vtype): + # We don't have any query to process, so just check what type the user + # expects, and return the IP address in a correct format + if v: + if vtype == "address": + return str(v.ip) + elif vtype == "network": + return str(v) + + +def _first_last(v): + if v.size == 2: + first_usable = int(netaddr.IPAddress(v.first)) + last_usable = int(netaddr.IPAddress(v.last)) + return first_usable, last_usable + elif v.size > 1: + first_usable = int(netaddr.IPAddress(v.first + 1)) + last_usable = int(netaddr.IPAddress(v.last - 1)) + return first_usable, last_usable + + +def _6to4_query(v, vtype, value): + if v.version == 4: + + if v.size == 1: + ipconv = str(v.ip) + elif v.size > 1: + if v.ip != v.network: + ipconv = str(v.ip) + else: + ipconv = False + + if ipaddr(ipconv, "public"): + numbers = list(map(int, ipconv.split("."))) + + try: + return "2002:{:02x}{:02x}:{:02x}{:02x}::1/48".format(*numbers) + except Exception: + return False + + elif v.version == 6: + if vtype == "address": + if ipaddr(str(v), "2002::/16"): + return value + elif vtype == "network": + if v.ip != v.network: + if ipaddr(str(v.ip), "2002::/16"): + return value + else: + return False + + +def _ip_query(v): + if v.size == 1: + return str(v.ip) + if v.size > 1: + # /31 networks in netaddr have no broadcast address + if v.ip != v.network or not v.broadcast: + return str(v.ip) + + +def _gateway_query(v): + if v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _address_prefix_query(v): + if v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _bool_ipaddr_query(v): + if v: + return True + + +def _broadcast_query(v): + if v.size > 2: + return str(v.broadcast) + + +def _cidr_query(v): + return str(v) + + +def _cidr_lookup_query(v, iplist, value): + try: + if v in iplist: + return value + except Exception: + return False + + +def _first_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size == 2: + return str(netaddr.IPAddress(int(v.network))) + elif v.size > 1: + return str(netaddr.IPAddress(int(v.network) + 1)) + + +def _host_query(v): + if v.size == 1: + return str(v) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _hostmask_query(v): + return str(v.hostmask) + + +def _int_query(v, vtype): + if vtype == "address": + return int(v.ip) + elif vtype == "network": + return str(int(v.ip)) + "/" + str(int(v.prefixlen)) + + +def _ip_prefix_query(v): + if v.size == 2: + return str(v.ip) + "/" + str(v.prefixlen) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + "/" + str(v.prefixlen) + + +def _ip_netmask_query(v): + if v.size == 2: + return str(v.ip) + " " + str(v.netmask) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + " " + str(v.netmask) + + +""" +def _ip_wildcard_query(v): + if v.size == 2: + return str(v.ip) + ' ' + str(v.hostmask) + elif v.size > 1: + if v.ip != v.network: + return str(v.ip) + ' ' + str(v.hostmask) +""" + + +def _ipv4_query(v, value): + if v.version == 6: + try: + return str(v.ipv4()) + except Exception: + return False + else: + return value + + +def _ipv6_query(v, value): + if v.version == 4: + return str(v.ipv6()) + else: + return value + + +def _last_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + return str(netaddr.IPAddress(last_usable)) + + +def _link_local_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if v.version == 4: + if ipaddr(str(v_ip), "169.254.0.0/24"): + return value + + elif v.version == 6: + if ipaddr(str(v_ip), "fe80::/10"): + return value + + +def _loopback_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if v_ip.is_loopback(): + return value + + +def _multicast_query(v, value): + if v.is_multicast(): + return value + + +def _net_query(v): + if v.size > 1: + if v.ip == v.network: + return str(v.network) + "/" + str(v.prefixlen) + + +def _netmask_query(v): + return str(v.netmask) + + +def _network_query(v): + """Return the network of a given IP or subnet""" + return str(v.network) + + +def _network_id_query(v): + """Return the network of a given IP or subnet""" + return str(v.network) + + +def _network_netmask_query(v): + return str(v.network) + " " + str(v.netmask) + + +def _network_wildcard_query(v): + return str(v.network) + " " + str(v.hostmask) + + +def _next_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + next_ip = int(netaddr.IPAddress(int(v.ip) + 1)) + if next_ip >= first_usable and next_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) + 1)) + + +def _peer_query(v, vtype): + if vtype == "address": + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size == 2: + return str(netaddr.IPAddress(int(v.ip) ^ 1)) + if v.size == 4: + if int(v.ip) % 4 == 0: + raise errors.AnsibleFilterError( + "Network address of /30 has no peer" + ) + if int(v.ip) % 4 == 3: + raise errors.AnsibleFilterError( + "Broadcast address of /30 has no peer" + ) + return str(netaddr.IPAddress(int(v.ip) ^ 3)) + raise errors.AnsibleFilterError("Not a point-to-point network") + + +def _prefix_query(v): + return int(v.prefixlen) + + +def _previous_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + previous_ip = int(netaddr.IPAddress(int(v.ip) - 1)) + if previous_ip >= first_usable and previous_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) - 1)) + + +def _private_query(v, value): + if v.is_private(): + return value + + +def _public_query(v, value): + v_ip = netaddr.IPAddress(str(v.ip)) + if ( + v_ip.is_unicast() + and not v_ip.is_private() + and not v_ip.is_loopback() + and not v_ip.is_netmask() + and not v_ip.is_hostmask() + ): + return value + + +def _range_usable_query(v, vtype): + if vtype == "address": + "Does it make sense to raise an error" + raise errors.AnsibleFilterError("Not a network address") + elif vtype == "network": + if v.size > 1: + first_usable, last_usable = _first_last(v) + first_usable = str(netaddr.IPAddress(first_usable)) + last_usable = str(netaddr.IPAddress(last_usable)) + return "{0}-{1}".format(first_usable, last_usable) + + +def _revdns_query(v): + v_ip = netaddr.IPAddress(str(v.ip)) + return v_ip.reverse_dns + + +def _size_query(v): + return v.size + + +def _size_usable_query(v): + if v.size == 1: + return 0 + elif v.size == 2: + return 2 + return v.size - 2 + + +def _subnet_query(v): + return str(v.cidr) + + +def _type_query(v): + if v.size == 1: + return "address" + if v.size > 1: + if v.ip != v.network: + return "address" + else: + return "network" + + +def _unicast_query(v, value): + if v.is_unicast(): + return value + + +def _version_query(v): + return v.version + + +def _wrap_query(v, vtype, value): + if v.version == 6: + if vtype == "address": + return "[" + str(v.ip) + "]" + elif vtype == "network": + return "[" + str(v.ip) + "]/" + str(v.prefixlen) + else: + return value + + +# ---- HWaddr query helpers ---- +def _bare_query(v): + v.dialect = netaddr.mac_bare + return str(v) + + +def _bool_hwaddr_query(v): + if v: + return True + + +def _int_hwaddr_query(v): + return int(v) + + +def _cisco_query(v): + v.dialect = netaddr.mac_cisco + return str(v) + + +def _empty_hwaddr_query(v, value): + if v: + return value + + +def _linux_query(v): + v.dialect = mac_linux + return str(v) + + +def _postgresql_query(v): + v.dialect = netaddr.mac_pgsql + return str(v) + + +def _unix_query(v): + v.dialect = netaddr.mac_unix + return str(v) + + +def _win_query(v): + v.dialect = netaddr.mac_eui48 + return str(v) + + +# ---- IP address and network filters ---- + +# Returns a minified list of subnets or a single subnet that spans all of +# the inputs. +def cidr_merge(value, action="merge"): + if not hasattr(value, "__iter__"): + raise errors.AnsibleFilterError( + "cidr_merge: expected iterable, got " + repr(value) + ) + + if action == "merge": + try: + return [str(ip) for ip in netaddr.cidr_merge(value)] + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + + elif action == "span": + # spanning_cidr needs at least two values + if len(value) == 0: + return None + elif len(value) == 1: + try: + return str(netaddr.IPNetwork(value[0])) + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + else: + try: + return str(netaddr.spanning_cidr(value)) + except Exception as e: + raise errors.AnsibleFilterError( + "cidr_merge: error in netaddr:\n%s" % e + ) + + else: + raise errors.AnsibleFilterError( + "cidr_merge: invalid action '%s'" % action + ) + + +def ipaddr(value, query="", version=False, alias="ipaddr"): + """ Check if string is an IP address or network and filter it """ + + query_func_extra_args = { + "": ("vtype",), + "6to4": ("vtype", "value"), + "cidr_lookup": ("iplist", "value"), + "first_usable": ("vtype",), + "int": ("vtype",), + "ipv4": ("value",), + "ipv6": ("value",), + "last_usable": ("vtype",), + "link-local": ("value",), + "loopback": ("value",), + "lo": ("value",), + "multicast": ("value",), + "next_usable": ("vtype",), + "peer": ("vtype",), + "previous_usable": ("vtype",), + "private": ("value",), + "public": ("value",), + "unicast": ("value",), + "range_usable": ("vtype",), + "wrap": ("vtype", "value"), + } + + query_func_map = { + "": _empty_ipaddr_query, + "6to4": _6to4_query, + "address": _ip_query, + "address/prefix": _address_prefix_query, # deprecate + "bool": _bool_ipaddr_query, + "broadcast": _broadcast_query, + "cidr": _cidr_query, + "cidr_lookup": _cidr_lookup_query, + "first_usable": _first_usable_query, + "gateway": _gateway_query, # deprecate + "gw": _gateway_query, # deprecate + "host": _host_query, + "host/prefix": _address_prefix_query, # deprecate + "hostmask": _hostmask_query, + "hostnet": _gateway_query, # deprecate + "int": _int_query, + "ip": _ip_query, + "ip/prefix": _ip_prefix_query, + "ip_netmask": _ip_netmask_query, + # 'ip_wildcard': _ip_wildcard_query, built then could not think of use case + "ipv4": _ipv4_query, + "ipv6": _ipv6_query, + "last_usable": _last_usable_query, + "link-local": _link_local_query, + "lo": _loopback_query, + "loopback": _loopback_query, + "multicast": _multicast_query, + "net": _net_query, + "next_usable": _next_usable_query, + "netmask": _netmask_query, + "network": _network_query, + "network_id": _network_id_query, + "network/prefix": _subnet_query, + "network_netmask": _network_netmask_query, + "network_wildcard": _network_wildcard_query, + "peer": _peer_query, + "prefix": _prefix_query, + "previous_usable": _previous_usable_query, + "private": _private_query, + "public": _public_query, + "range_usable": _range_usable_query, + "revdns": _revdns_query, + "router": _gateway_query, # deprecate + "size": _size_query, + "size_usable": _size_usable_query, + "subnet": _subnet_query, + "type": _type_query, + "unicast": _unicast_query, + "v4": _ipv4_query, + "v6": _ipv6_query, + "version": _version_query, + "wildcard": _hostmask_query, + "wrap": _wrap_query, + } + + vtype = None + + if not value: + return False + + elif value is True: + return False + + # Check if value is a list and parse each element + elif isinstance(value, (list, tuple, types.GeneratorType)): + + _ret = [] + for element in value: + if ipaddr(element, str(query), version): + _ret.append(ipaddr(element, str(query), version)) + + if _ret: + return _ret + else: + return list() + + # Check if value is a number and convert it to an IP address + elif str(value).isdigit(): + + # We don't know what IP version to assume, so let's check IPv4 first, + # then IPv6 + try: + if (not version) or (version and version == 4): + v = netaddr.IPNetwork("0.0.0.0/0") + v.value = int(value) + v.prefixlen = 32 + elif version and version == 6: + v = netaddr.IPNetwork("::/0") + v.value = int(value) + v.prefixlen = 128 + + # IPv4 didn't work the first time, so it definitely has to be IPv6 + except Exception: + try: + v = netaddr.IPNetwork("::/0") + v.value = int(value) + v.prefixlen = 128 + + # The value is too big for IPv6. Are you a nanobot? + except Exception: + return False + + # We got an IP address, let's mark it as such + value = str(v) + vtype = "address" + + # value has not been recognized, check if it's a valid IP string + else: + try: + v = netaddr.IPNetwork(value) + + # value is a valid IP string, check if user specified + # CIDR prefix or just an IP address, this will indicate default + # output format + try: + address, prefix = value.split("/") + vtype = "network" + except Exception: + vtype = "address" + + # value hasn't been recognized, maybe it's a numerical CIDR? + except Exception: + try: + address, prefix = value.split("/") + address.isdigit() + address = int(address) + prefix.isdigit() + prefix = int(prefix) + + # It's not numerical CIDR, give up + except Exception: + return False + + # It is something, so let's try and build a CIDR from the parts + try: + v = netaddr.IPNetwork("0.0.0.0/0") + v.value = address + v.prefixlen = prefix + + # It's not a valid IPv4 CIDR + except Exception: + try: + v = netaddr.IPNetwork("::/0") + v.value = address + v.prefixlen = prefix + + # It's not a valid IPv6 CIDR. Give up. + except Exception: + return False + + # We have a valid CIDR, so let's write it in correct format + value = str(v) + vtype = "network" + + # We have a query string but it's not in the known query types. Check if + # that string is a valid subnet, if so, we can check later if given IP + # address/network is inside that specific subnet + try: + # ?? 6to4 and link-local were True here before. Should they still? + if ( + query + and (query not in query_func_map or query == "cidr_lookup") + and not str(query).isdigit() + and ipaddr(query, "network") + ): + iplist = netaddr.IPSet([netaddr.IPNetwork(query)]) + query = "cidr_lookup" + except Exception: + pass + + # This code checks if value maches the IP version the user wants, ie. if + # it's any version ("ipaddr()"), IPv4 ("ipv4()") or IPv6 ("ipv6()") + # If version does not match, return False + if version and v.version != version: + return False + + extras = [] + for arg in query_func_extra_args.get(query, tuple()): + extras.append(locals()[arg]) + try: + return query_func_map[query](v, *extras) + except KeyError: + try: + float(query) + if v.size == 1: + if vtype == "address": + return str(v.ip) + elif vtype == "network": + return str(v) + + elif v.size > 1: + try: + return str(v[query]) + "/" + str(v.prefixlen) + except Exception: + return False + + else: + return value + + except Exception: + raise errors.AnsibleFilterError( + alias + ": unknown filter type: %s" % query + ) + + return False + + +def ipmath(value, amount): + try: + if "/" in value: + ip = netaddr.IPNetwork(value).ip + else: + ip = netaddr.IPAddress(value) + except (netaddr.AddrFormatError, ValueError): + msg = "You must pass a valid IP address; {0} is invalid".format(value) + raise errors.AnsibleFilterError(msg) + + if not isinstance(amount, int): + msg = ( + "You must pass an integer for arithmetic; " + "{0} is not a valid integer" + ).format(amount) + raise errors.AnsibleFilterError(msg) + + return str(ip + amount) + + +def ipwrap(value, query=""): + try: + if isinstance(value, (list, tuple, types.GeneratorType)): + _ret = [] + for element in value: + if ipaddr(element, query, version=False, alias="ipwrap"): + _ret.append(ipaddr(element, "wrap")) + else: + _ret.append(element) + + return _ret + else: + _ret = ipaddr(value, query, version=False, alias="ipwrap") + if _ret: + return ipaddr(_ret, "wrap") + else: + return value + + except Exception: + return value + + +def ipv4(value, query=""): + return ipaddr(value, query, version=4, alias="ipv4") + + +def ipv6(value, query=""): + return ipaddr(value, query, version=6, alias="ipv6") + + +# Split given subnet into smaller subnets or find out the biggest subnet of +# a given IP address with given CIDR prefix +# Usage: +# +# - address or address/prefix | ipsubnet +# returns CIDR subnet of a given input +# +# - address/prefix | ipsubnet(cidr) +# returns number of possible subnets for given CIDR prefix +# +# - address/prefix | ipsubnet(cidr, index) +# returns new subnet with given CIDR prefix +# +# - address | ipsubnet(cidr) +# returns biggest subnet with given CIDR prefix that address belongs to +# +# - address | ipsubnet(cidr, index) +# returns next indexed subnet which contains given address +# +# - address/prefix | ipsubnet(subnet/prefix) +# return the index of the subnet in the subnet +def ipsubnet(value, query="", index="x"): + """ Manipulate IPv4/IPv6 subnets """ + + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + value = netaddr.IPNetwork(v) + except Exception: + return False + query_string = str(query) + if not query: + return str(value) + + elif query_string.isdigit(): + vsize = ipaddr(v, "size") + query = int(query) + + try: + float(index) + index = int(index) + + if vsize > 1: + try: + return str(list(value.subnet(query))[index]) + except Exception: + return False + + elif vsize == 1: + try: + return str(value.supernet(query)[index]) + except Exception: + return False + + except Exception: + if vsize > 1: + try: + return str(len(list(value.subnet(query)))) + except Exception: + return False + + elif vsize == 1: + try: + return str(value.supernet(query)[0]) + except Exception: + return False + + elif query_string: + vtype = ipaddr(query, "type") + if vtype == "address": + v = ipaddr(query, "cidr") + elif vtype == "network": + v = ipaddr(query, "subnet") + else: + msg = "You must pass a valid subnet or IP address; {0} is invalid".format( + query_string + ) + raise errors.AnsibleFilterError(msg) + query = netaddr.IPNetwork(v) + for i, subnet in enumerate(query.subnet(value.prefixlen), 1): + if subnet == value: + return str(i) + msg = "{0} is not in the subnet {1}".format(value.cidr, query.cidr) + raise errors.AnsibleFilterError(msg) + return False + + +# Returns the nth host within a network described by value. +# Usage: +# +# - address or address/prefix | nthhost(nth) +# returns the nth host within the given network +def nthhost(value, query=""): + """ Get the nth host within a given network """ + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + value = netaddr.IPNetwork(v) + except Exception: + return False + + if not query: + return False + + try: + nth = int(query) + if value.size > nth: + return value[nth] + + except ValueError: + return False + + return False + + +# Returns the next nth usable ip within a network described by value. +def next_nth_usable(value, offset): + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + v = netaddr.IPNetwork(v) + except Exception: + return False + + if type(offset) != int: + raise errors.AnsibleFilterError("Must pass in an integer") + if v.size > 1: + first_usable, last_usable = _first_last(v) + nth_ip = int(netaddr.IPAddress(int(v.ip) + offset)) + if nth_ip >= first_usable and nth_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) + offset)) + + +# Returns the previous nth usable ip within a network described by value. +def previous_nth_usable(value, offset): + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + v = netaddr.IPNetwork(v) + except Exception: + return False + + if type(offset) != int: + raise errors.AnsibleFilterError("Must pass in an integer") + if v.size > 1: + first_usable, last_usable = _first_last(v) + nth_ip = int(netaddr.IPAddress(int(v.ip) - offset)) + if nth_ip >= first_usable and nth_ip <= last_usable: + return str(netaddr.IPAddress(int(v.ip) - offset)) + + +def _range_checker(ip_check, first, last): + """ + Tests whether an ip address is within the bounds of the first and last address. + + :param ip_check: The ip to test if it is within first and last. + :param first: The first IP in the range to test against. + :param last: The last IP in the range to test against. + + :return: bool + """ + if ip_check >= first and ip_check <= last: + return True + else: + return False + + +def _address_normalizer(value): + """ + Used to validate an address or network type and return it in a consistent format. + This is being used for future use cases not currently available such as an address range. + + :param value: The string representation of an address or network. + + :return: The address or network in the normalized form. + """ + try: + vtype = ipaddr(value, "type") + if vtype == "address" or vtype == "network": + v = ipaddr(value, "subnet") + except Exception: + return False + + return v + + +def network_in_usable(value, test): + """ + Checks whether 'test' is a useable address or addresses in 'value' + + :param: value: The string representation of an address or network to test against. + :param test: The string representation of an address or network to validate if it is within the range of 'value'. + + :return: bool + """ + # normalize value and test variables into an ipaddr + v = _address_normalizer(value) + w = _address_normalizer(test) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + v_first = ipaddr(ipaddr(v, "first_usable") or ipaddr(v, "address"), "int") + v_last = ipaddr(ipaddr(v, "last_usable") or ipaddr(v, "address"), "int") + w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int") + w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int") + + if _range_checker(w_first, v_first, v_last) and _range_checker( + w_last, v_first, v_last + ): + return True + else: + return False + + +def network_in_network(value, test): + """ + Checks whether the 'test' address or addresses are in 'value', including broadcast and network + + :param: value: The network address or range to test against. + :param test: The address or network to validate if it is within the range of 'value'. + + :return: bool + """ + # normalize value and test variables into an ipaddr + v = _address_normalizer(value) + w = _address_normalizer(test) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + v_first = ipaddr(ipaddr(v, "network") or ipaddr(v, "address"), "int") + v_last = ipaddr(ipaddr(v, "broadcast") or ipaddr(v, "address"), "int") + w_first = ipaddr(ipaddr(w, "network") or ipaddr(w, "address"), "int") + w_last = ipaddr(ipaddr(w, "broadcast") or ipaddr(w, "address"), "int") + + if _range_checker(w_first, v_first, v_last) and _range_checker( + w_last, v_first, v_last + ): + return True + else: + return False + + +def reduce_on_network(value, network): + """ + Reduces a list of addresses to only the addresses that match a given network. + + :param: value: The list of addresses to filter on. + :param: network: The network to validate against. + + :return: The reduced list of addresses. + """ + # normalize network variable into an ipaddr + n = _address_normalizer(network) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + n_first = ipaddr(ipaddr(n, "network") or ipaddr(n, "address"), "int") + n_last = ipaddr(ipaddr(n, "broadcast") or ipaddr(n, "address"), "int") + + # create an empty list to fill and return + r = [] + + for address in value: + # normalize address variables into an ipaddr + a = _address_normalizer(address) + + # get first and last addresses as integers to compare value and test; or cathes value when case is /32 + a_first = ipaddr(ipaddr(a, "network") or ipaddr(a, "address"), "int") + a_last = ipaddr(ipaddr(a, "broadcast") or ipaddr(a, "address"), "int") + + if _range_checker(a_first, n_first, n_last) and _range_checker( + a_last, n_first, n_last + ): + r.append(address) + + return r + + +# Returns the SLAAC address within a network for a given HW/MAC address. +# Usage: +# +# - prefix | slaac(mac) +def slaac(value, query=""): + """ Get the SLAAC address within given network """ + try: + vtype = ipaddr(value, "type") + if vtype == "address": + v = ipaddr(value, "cidr") + elif vtype == "network": + v = ipaddr(value, "subnet") + + if ipaddr(value, "version") != 6: + return False + + value = netaddr.IPNetwork(v) + except Exception: + return False + + if not query: + return False + + try: + mac = hwaddr(query, alias="slaac") + + eui = netaddr.EUI(mac) + except Exception: + return False + + return eui.ipv6(value.network) + + +# ---- HWaddr / MAC address filters ---- +def hwaddr(value, query="", alias="hwaddr"): + """ Check if string is a HW/MAC address and filter it """ + + query_func_extra_args = {"": ("value",)} + + query_func_map = { + "": _empty_hwaddr_query, + "bare": _bare_query, + "bool": _bool_hwaddr_query, + "int": _int_hwaddr_query, + "cisco": _cisco_query, + "eui48": _win_query, + "linux": _linux_query, + "pgsql": _postgresql_query, + "postgresql": _postgresql_query, + "psql": _postgresql_query, + "unix": _unix_query, + "win": _win_query, + } + + try: + v = netaddr.EUI(value) + except Exception: + if query and query != "bool": + raise errors.AnsibleFilterError( + alias + ": not a hardware address: %s" % value + ) + + extras = [] + for arg in query_func_extra_args.get(query, tuple()): + extras.append(locals()[arg]) + try: + return query_func_map[query](v, *extras) + except KeyError: + raise errors.AnsibleFilterError( + alias + ": unknown filter type: %s" % query + ) + + return False + + +def macaddr(value, query=""): + return hwaddr(value, query, alias="macaddr") + + +def _need_netaddr(f_name, *args, **kwargs): + raise errors.AnsibleFilterError( + "The %s filter requires python's netaddr be " + "installed on the ansible controller" % f_name + ) + + +def ip4_hex(arg, delimiter=""): + """ Convert an IPv4 address to Hexadecimal notation """ + numbers = list(map(int, arg.split("."))) + return "{0:02x}{sep}{1:02x}{sep}{2:02x}{sep}{3:02x}".format( + *numbers, sep=delimiter + ) + + +# ---- Ansible filters ---- +class FilterModule(object): + """ IP address and network manipulation filters """ + + filter_map = { + # IP addresses and networks + "cidr_merge": cidr_merge, + "ipaddr": ipaddr, + "ipmath": ipmath, + "ipwrap": ipwrap, + "ip4_hex": ip4_hex, + "ipv4": ipv4, + "ipv6": ipv6, + "ipsubnet": ipsubnet, + "next_nth_usable": next_nth_usable, + "network_in_network": network_in_network, + "network_in_usable": network_in_usable, + "reduce_on_network": reduce_on_network, + "nthhost": nthhost, + "previous_nth_usable": previous_nth_usable, + "slaac": slaac, + # MAC / HW addresses + "hwaddr": hwaddr, + "macaddr": macaddr, + } + + def filters(self): + if netaddr: + return self.filter_map + else: + # Need to install python's netaddr for these filters to work + return dict( + (f, partial(_need_netaddr, f)) for f in self.filter_map + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py new file mode 100644 index 00000000..f99e6e76 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/filter/network.py @@ -0,0 +1,531 @@ +# +# {c) 2017 Red Hat, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import re +import os +import traceback +import string + +from xml.etree.ElementTree import fromstring + +from ansible.module_utils._text import to_native, to_text +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + Template, +) +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils.common._collections_compat import Mapping +from ansible.errors import AnsibleError, AnsibleFilterError +from ansible.utils.display import Display +from ansible.utils.encrypt import passlib_or_crypt, random_password + +try: + import yaml + + HAS_YAML = True +except ImportError: + HAS_YAML = False + +try: + import textfsm + + HAS_TEXTFSM = True +except ImportError: + HAS_TEXTFSM = False + +display = Display() + + +def re_matchall(regex, value): + objects = list() + for match in re.findall(regex.pattern, value, re.M): + obj = {} + if regex.groupindex: + for name, index in iteritems(regex.groupindex): + if len(regex.groupindex) == 1: + obj[name] = match + else: + obj[name] = match[index - 1] + objects.append(obj) + return objects + + +def re_search(regex, value): + obj = {} + match = regex.search(value, re.M) + if match: + items = list(match.groups()) + if regex.groupindex: + for name, index in iteritems(regex.groupindex): + obj[name] = items[index - 1] + return obj + + +def parse_cli(output, tmpl): + if not isinstance(output, string_types): + raise AnsibleError( + "parse_cli input should be a string, but was given a input of %s" + % (type(output)) + ) + + if not os.path.exists(tmpl): + raise AnsibleError("unable to locate parse_cli template: %s" % tmpl) + + try: + template = Template() + except ImportError as exc: + raise AnsibleError(to_native(exc)) + + with open(tmpl) as tmpl_fh: + tmpl_content = tmpl_fh.read() + + spec = yaml.safe_load(tmpl_content) + obj = {} + + for name, attrs in iteritems(spec["keys"]): + value = attrs["value"] + + try: + variables = spec.get("vars", {}) + value = template(value, variables) + except Exception: + pass + + if "start_block" in attrs and "end_block" in attrs: + start_block = re.compile(attrs["start_block"]) + end_block = re.compile(attrs["end_block"]) + + blocks = list() + lines = None + block_started = False + + for line in output.split("\n"): + match_start = start_block.match(line) + match_end = end_block.match(line) + + if match_start: + lines = list() + lines.append(line) + block_started = True + + elif match_end: + if lines: + lines.append(line) + blocks.append("\n".join(lines)) + block_started = False + + elif block_started: + if lines: + lines.append(line) + + regex_items = [re.compile(r) for r in attrs["items"]] + objects = list() + + for block in blocks: + if isinstance(value, Mapping) and "key" not in value: + items = list() + for regex in regex_items: + match = regex.search(block) + if match: + item_values = match.groupdict() + item_values["match"] = list(match.groups()) + items.append(item_values) + else: + items.append(None) + + obj = {} + for k, v in iteritems(value): + try: + obj[k] = template( + v, {"item": items}, fail_on_undefined=False + ) + except Exception: + obj[k] = None + objects.append(obj) + + elif isinstance(value, Mapping): + items = list() + for regex in regex_items: + match = regex.search(block) + if match: + item_values = match.groupdict() + item_values["match"] = list(match.groups()) + items.append(item_values) + else: + items.append(None) + + key = template(value["key"], {"item": items}) + values = dict( + [ + (k, template(v, {"item": items})) + for k, v in iteritems(value["values"]) + ] + ) + objects.append({key: values}) + + return objects + + elif "items" in attrs: + regexp = re.compile(attrs["items"]) + when = attrs.get("when") + conditional = ( + "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when + ) + + if isinstance(value, Mapping) and "key" not in value: + values = list() + + for item in re_matchall(regexp, output): + entry = {} + + for item_key, item_value in iteritems(value): + entry[item_key] = template(item_value, {"item": item}) + + if when: + if template(conditional, {"item": entry}): + values.append(entry) + else: + values.append(entry) + + obj[name] = values + + elif isinstance(value, Mapping): + values = dict() + + for item in re_matchall(regexp, output): + entry = {} + + for item_key, item_value in iteritems(value["values"]): + entry[item_key] = template(item_value, {"item": item}) + + key = template(value["key"], {"item": item}) + + if when: + if template( + conditional, {"item": {"key": key, "value": entry}} + ): + values[key] = entry + else: + values[key] = entry + + obj[name] = values + + else: + item = re_search(regexp, output) + obj[name] = template(value, {"item": item}) + + else: + obj[name] = value + + return obj + + +def parse_cli_textfsm(value, template): + if not HAS_TEXTFSM: + raise AnsibleError( + "parse_cli_textfsm filter requires TextFSM library to be installed" + ) + + if not isinstance(value, string_types): + raise AnsibleError( + "parse_cli_textfsm input should be a string, but was given a input of %s" + % (type(value)) + ) + + if not os.path.exists(template): + raise AnsibleError( + "unable to locate parse_cli_textfsm template: %s" % template + ) + + try: + template = open(template) + except IOError as exc: + raise AnsibleError(to_native(exc)) + + re_table = textfsm.TextFSM(template) + fsm_results = re_table.ParseText(value) + + results = list() + for item in fsm_results: + results.append(dict(zip(re_table.header, item))) + + return results + + +def _extract_param(template, root, attrs, value): + + key = None + when = attrs.get("when") + conditional = "{%% if %s %%}True{%% else %%}False{%% endif %%}" % when + param_to_xpath_map = attrs["items"] + + if isinstance(value, Mapping): + key = value.get("key", None) + if key: + value = value["values"] + + entries = dict() if key else list() + + for element in root.findall(attrs["top"]): + entry = dict() + item_dict = dict() + for param, param_xpath in iteritems(param_to_xpath_map): + fields = None + try: + fields = element.findall(param_xpath) + except Exception: + display.warning( + "Failed to evaluate value of '%s' with XPath '%s'.\nUnexpected error: %s." + % (param, param_xpath, traceback.format_exc()) + ) + + tags = param_xpath.split("/") + + # check if xpath ends with attribute. + # If yes set attribute key/value dict to param value in case attribute matches + # else if it is a normal xpath assign matched element text value. + if len(tags) and tags[-1].endswith("]"): + if fields: + if len(fields) > 1: + item_dict[param] = [field.attrib for field in fields] + else: + item_dict[param] = fields[0].attrib + else: + item_dict[param] = {} + else: + if fields: + if len(fields) > 1: + item_dict[param] = [field.text for field in fields] + else: + item_dict[param] = fields[0].text + else: + item_dict[param] = None + + if isinstance(value, Mapping): + for item_key, item_value in iteritems(value): + entry[item_key] = template(item_value, {"item": item_dict}) + else: + entry = template(value, {"item": item_dict}) + + if key: + expanded_key = template(key, {"item": item_dict}) + if when: + if template( + conditional, + {"item": {"key": expanded_key, "value": entry}}, + ): + entries[expanded_key] = entry + else: + entries[expanded_key] = entry + else: + if when: + if template(conditional, {"item": entry}): + entries.append(entry) + else: + entries.append(entry) + + return entries + + +def parse_xml(output, tmpl): + if not os.path.exists(tmpl): + raise AnsibleError("unable to locate parse_xml template: %s" % tmpl) + + if not isinstance(output, string_types): + raise AnsibleError( + "parse_xml works on string input, but given input of : %s" + % type(output) + ) + + root = fromstring(output) + try: + template = Template() + except ImportError as exc: + raise AnsibleError(to_native(exc)) + + with open(tmpl) as tmpl_fh: + tmpl_content = tmpl_fh.read() + + spec = yaml.safe_load(tmpl_content) + obj = {} + + for name, attrs in iteritems(spec["keys"]): + value = attrs["value"] + + try: + variables = spec.get("vars", {}) + value = template(value, variables) + except Exception: + pass + + if "items" in attrs: + obj[name] = _extract_param(template, root, attrs, value) + else: + obj[name] = value + + return obj + + +def type5_pw(password, salt=None): + if not isinstance(password, string_types): + raise AnsibleFilterError( + "type5_pw password input should be a string, but was given a input of %s" + % (type(password).__name__) + ) + + salt_chars = u"".join( + (to_text(string.ascii_letters), to_text(string.digits), u"./") + ) + if salt is not None and not isinstance(salt, string_types): + raise AnsibleFilterError( + "type5_pw salt input should be a string, but was given a input of %s" + % (type(salt).__name__) + ) + elif not salt: + salt = random_password(length=4, chars=salt_chars) + elif not set(salt) <= set(salt_chars): + raise AnsibleFilterError( + "type5_pw salt used inproper characters, must be one of %s" + % (salt_chars) + ) + + encrypted_password = passlib_or_crypt(password, "md5_crypt", salt=salt) + + return encrypted_password + + +def hash_salt(password): + + split_password = password.split("$") + if len(split_password) != 4: + raise AnsibleFilterError( + "Could not parse salt out password correctly from {0}".format( + password + ) + ) + else: + return split_password[2] + + +def comp_type5( + unencrypted_password, encrypted_password, return_original=False +): + + salt = hash_salt(encrypted_password) + if type5_pw(unencrypted_password, salt) == encrypted_password: + if return_original is True: + return encrypted_password + else: + return True + return False + + +def vlan_parser(vlan_list, first_line_len=48, other_line_len=44): + + """ + Input: Unsorted list of vlan integers + Output: Sorted string list of integers according to IOS-like vlan list rules + + 1. Vlans are listed in ascending order + 2. Runs of 3 or more consecutive vlans are listed with a dash + 3. The first line of the list can be first_line_len characters long + 4. Subsequent list lines can be other_line_len characters + """ + + # Sort and remove duplicates + sorted_list = sorted(set(vlan_list)) + + if sorted_list[0] < 1 or sorted_list[-1] > 4094: + raise AnsibleFilterError("Valid VLAN range is 1-4094") + + parse_list = [] + idx = 0 + while idx < len(sorted_list): + start = idx + end = start + while end < len(sorted_list) - 1: + if sorted_list[end + 1] - sorted_list[end] == 1: + end += 1 + else: + break + + if start == end: + # Single VLAN + parse_list.append(str(sorted_list[idx])) + elif start + 1 == end: + # Run of 2 VLANs + parse_list.append(str(sorted_list[start])) + parse_list.append(str(sorted_list[end])) + else: + # Run of 3 or more VLANs + parse_list.append( + str(sorted_list[start]) + "-" + str(sorted_list[end]) + ) + idx = end + 1 + + line_count = 0 + result = [""] + for vlans in parse_list: + # First line (" switchport trunk allowed vlan ") + if line_count == 0: + if len(result[line_count] + vlans) > first_line_len: + result.append("") + line_count += 1 + result[line_count] += vlans + "," + else: + result[line_count] += vlans + "," + + # Subsequent lines (" switchport trunk allowed vlan add ") + else: + if len(result[line_count] + vlans) > other_line_len: + result.append("") + line_count += 1 + result[line_count] += vlans + "," + else: + result[line_count] += vlans + "," + + # Remove trailing orphan commas + for idx in range(0, len(result)): + result[idx] = result[idx].rstrip(",") + + # Sometimes text wraps to next line, but there are no remaining VLANs + if "" in result: + result.remove("") + + return result + + +class FilterModule(object): + """Filters for working with output from network devices""" + + filter_map = { + "parse_cli": parse_cli, + "parse_cli_textfsm": parse_cli_textfsm, + "parse_xml": parse_xml, + "type5_pw": type5_pw, + "hash_salt": hash_salt, + "comp_type5": comp_type5, + "vlan_parser": vlan_parser, + } + + def filters(self): + return self.filter_map diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py new file mode 100644 index 00000000..8afb3e5e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/httpapi/restconf.py @@ -0,0 +1,91 @@ +# Copyright (c) 2018 Cisco and/or its affiliates. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +httpapi: restconf +short_description: HttpApi Plugin for devices supporting Restconf API +description: +- This HttpApi plugin provides methods to connect to Restconf API endpoints. +options: + root_path: + type: str + description: + - Specifies the location of the Restconf root. + default: /restconf + vars: + - name: ansible_httpapi_restconf_root +""" + +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import ConnectionError +from ansible.module_utils.six.moves.urllib.error import HTTPError +from ansible.plugins.httpapi import HttpApiBase + + +CONTENT_TYPE = "application/yang-data+json" + + +class HttpApi(HttpApiBase): + def send_request(self, data, **message_kwargs): + if data: + data = json.dumps(data) + + path = "/".join( + [ + self.get_option("root_path").rstrip("/"), + message_kwargs.get("path", "").lstrip("/"), + ] + ) + + headers = { + "Content-Type": message_kwargs.get("content_type") or CONTENT_TYPE, + "Accept": message_kwargs.get("accept") or CONTENT_TYPE, + } + response, response_data = self.connection.send( + path, data, headers=headers, method=message_kwargs.get("method") + ) + + return handle_response(response, response_data) + + +def handle_response(response, response_data): + try: + response_data = json.loads(response_data.read()) + except ValueError: + response_data = response_data.read() + + if isinstance(response, HTTPError): + if response_data: + if "errors" in response_data: + errors = response_data["errors"]["error"] + error_text = "\n".join( + (error["error-message"] for error in errors) + ) + else: + error_text = response_data + + raise ConnectionError(error_text, code=response.code) + raise ConnectionError(to_text(response), code=response.code) + + return response_data diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py new file mode 100644 index 00000000..dc0a19f7 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/compat/ipaddress.py @@ -0,0 +1,2578 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file, and this file only, is based on +# Lib/ipaddress.py of cpython +# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved" +# are retained in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. + +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +from __future__ import unicode_literals + + +import itertools +import struct + + +# The following makes it easier for us to script updates of the bundled code and is not part of +# upstream +_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"} + +__version__ = "1.0.22" + +# Compatibility functions +_compat_int_types = (int,) +try: + _compat_int_types = (int, long) +except NameError: + pass +try: + _compat_str = unicode +except NameError: + _compat_str = str + assert bytes != str +if b"\0"[0] == 0: # Python 3 semantics + + def _compat_bytes_to_byte_vals(byt): + return byt + + +else: + + def _compat_bytes_to_byte_vals(byt): + return [struct.unpack(b"!B", b)[0] for b in byt] + + +try: + _compat_int_from_byte_vals = int.from_bytes +except AttributeError: + + def _compat_int_from_byte_vals(bytvals, endianess): + assert endianess == "big" + res = 0 + for bv in bytvals: + assert isinstance(bv, _compat_int_types) + res = (res << 8) + bv + return res + + +def _compat_to_bytes(intval, length, endianess): + assert isinstance(intval, _compat_int_types) + assert endianess == "big" + if length == 4: + if intval < 0 or intval >= 2 ** 32: + raise struct.error("integer out of range for 'I' format code") + return struct.pack(b"!I", intval) + elif length == 16: + if intval < 0 or intval >= 2 ** 128: + raise struct.error("integer out of range for 'QQ' format code") + return struct.pack(b"!QQ", intval >> 64, intval & 0xFFFFFFFFFFFFFFFF) + else: + raise NotImplementedError() + + +if hasattr(int, "bit_length"): + # Not int.bit_length , since that won't work in 2.7 where long exists + def _compat_bit_length(i): + return i.bit_length() + + +else: + + def _compat_bit_length(i): + for res in itertools.count(): + if i >> res == 0: + return res + + +def _compat_range(start, end, step=1): + assert step > 0 + i = start + while i < end: + yield i + i += step + + +class _TotalOrderingMixin(object): + __slots__ = () + + # Helper that derives the other comparison operations from + # __lt__ and __eq__ + # We avoid functools.total_ordering because it doesn't handle + # NotImplemented correctly yet (http://bugs.python.org/issue10042) + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + + def __lt__(self, other): + raise NotImplementedError + + def __le__(self, other): + less = self.__lt__(other) + if less is NotImplemented or not less: + return self.__eq__(other) + return less + + def __gt__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not (less or equal) + + def __ge__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + return not less + + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + "%r does not appear to be an IPv4 or IPv6 address. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" % address + ) + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 address" % address + ) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + "%r does not appear to be an IPv4 or IPv6 network. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" % address + ) + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 network" % address + ) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. + + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError( + "%r does not appear to be an IPv4 or IPv6 interface" % address + ) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return _compat_to_bytes(address, 4, "big") + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return _compat_to_bytes(address, 16, "big") + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = _compat_str(address).split("/") + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) # pylint: disable=stop-iteration-return + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, _compat_bit_length(~number & (number - 1))) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. + + """ + if not ( + isinstance(first, _BaseAddress) and isinstance(last, _BaseAddress) + ): + raise TypeError("first and last must be IP addresses, not networks") + if first.version != last.version: + raise TypeError( + "%s and %s are not of the same version" % (first, last) + ) + if first > last: + raise ValueError("last IP address must be greater than first") + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError("unknown IP version") + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min( + _count_righthand_zero_bits(first_int, ip_bits), + _compat_bit_length(last_int - first_int + 1) - 1, + ) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, + # last.network_address <= net.network_address is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. + + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, ips[-1]) + ) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, ips[-1]) + ) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError( + "%s and %s are not of the same version" % (ip, nets[-1]) + ) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase(_TotalOrderingMixin): + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return _compat_str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = "%200s has no version specified" % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError( + msg % (address, self._max_prefixlen, self._version) + ) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = ( + "%r (len %d != %d) is not permitted as an IPv%d address. " + "Did you pass in a bytes (str in Python 2) instead of" + " a unicode object?" + ) + raise AddressValueError( + msg % (address, address_len, expected_len, self._version) + ) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. + + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. + + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits( + ip_int, cls._max_prefixlen + ) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = _compat_to_bytes(ip_int, byteslen, "big") + msg = "Netmask pattern %r mixes zeroes & ones" + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = "%r is not a valid netmask" % netmask_str + raise NetmaskValueError(msg) + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + def __reduce__(self): + return self.__class__, (_compat_str(self),) + + +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. + """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + try: + return self._ip == other._ip and self._version == other._version + except AttributeError: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseAddress): + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + if self._ip != other._ip: + return self._ip < other._ip + return False + + # Shorthand for Integer addition and subtraction. This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return _compat_str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + +class _BaseNetwork(_IPAddressBase): + + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + + """ + + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return "%s/%d" % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError("address out of range") + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError("address out of range") + return self._address_class(broadcast + n) + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseNetwork): + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __eq__(self, other): + try: + return ( + self._version == other._version + and self.network_address == other.network_address + and int(self.netmask) == int(other.netmask) + ) + except AttributeError: + return NotImplemented + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. + if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return ( + int(self.network_address) + <= int(other._ip) + <= int(self.broadcast_address) + ) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other + or ( + other.network_address in self + or (other.broadcast_address in self) + ) + ) + + @property + def broadcast_address(self): + x = self._cache.get("broadcast_address") + if x is None: + x = self._address_class( + int(self.network_address) | int(self.hostmask) + ) + self._cache["broadcast_address"] = x + return x + + @property + def hostmask(self): + x = self._cache.get("hostmask") + if x is None: + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) + self._cache["hostmask"] = x + return x + + @property + def with_prefixlen(self): + return "%s/%d" % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = "%200s has no associated address class" % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError( + "%s and %s are not of the same version" % (self, other) + ) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not other.subnet_of(self): + raise ValueError("%s not contained in %s" % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__( + "%s/%s" % (other.network_address, other.prefixlen) + ) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if other.subnet_of(s1): + yield s2 + s1, s2 = s1.subnets() + elif other.subnet_of(s2): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. + raise AssertionError( + "Error performing exclusion: " + "s1: %s s2: %s other: %s" % (s1, s2, other) + ) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. + raise AssertionError( + "Error performing exclusion: " + "s1: %s s2: %s other: %s" % (s1, s2, other) + ) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError( + "%s and %s are not of the same type" % (self, other) + ) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. + OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError("new prefix must be longer") + if prefixlen_diff != 1: + raise ValueError("cannot set prefixlen_diff and new_prefix") + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError("prefix length diff must be > 0") + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + "prefix length diff %d is invalid for netblock %s" + % (new_prefixlen, self) + ) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in _compat_range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError("new prefix must be shorter") + if prefixlen_diff != 1: + raise ValueError("cannot set prefixlen_diff and new_prefix") + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + "current prefixlen is %d, cannot have a prefixlen_diff of %d" + % (self.prefixlen, prefixlen_diff) + ) + return self.__class__( + ( + int(self.network_address) + & (int(self.netmask) << prefixlen_diff), + new_prefixlen, + ) + ) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return ( + self.network_address.is_multicast + and self.broadcast_address.is_multicast + ) + + @staticmethod + def _is_subnet_of(a, b): + try: + # Always false if one is v4 and the other is v6. + if a._version != b._version: + raise TypeError( + "%s and %s are not of the same version" % (a, b) + ) + return ( + b.network_address <= a.network_address + and b.broadcast_address >= a.broadcast_address + ) + except AttributeError: + raise TypeError( + "Unable to test subnet containment " + "between %s and %s" % (a, b) + ) + + def subnet_of(self, other): + """Return True if this network is a subnet of other.""" + return self._is_subnet_of(self, other) + + def supernet_of(self, other): + """Return True if this network is a supernet of other.""" + return self._is_subnet_of(other, self) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return ( + self.network_address.is_reserved + and self.broadcast_address.is_reserved + ) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return ( + self.network_address.is_link_local + and self.broadcast_address.is_link_local + ) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return ( + self.network_address.is_private + and self.broadcast_address.is_private + ) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return ( + self.network_address.is_unspecified + and self.broadcast_address.is_unspecified + ) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return ( + self.network_address.is_loopback + and self.broadcast_address.is_loopback + ) + + +class _BaseV4(object): + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2 ** IPV4LENGTH) - 1 + _DECIMAL_DIGITS = frozenset("0123456789") + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0]) + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return _compat_str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. + + """ + if not ip_str: + raise AddressValueError("Address cannot be empty") + + octets = ip_str.split(".") + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return _compat_int_from_byte_vals( + map(cls._parse_octet, octets), "big" + ) + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._DECIMAL_DIGITS.issuperset(octet_str): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + # Any octets that look like they *might* be written in octal, + # and which don't look exactly the same in both octal and + # decimal are rejected as ambiguous + if octet_int > 7 and octet_str[0] == "0": + msg = "Ambiguous (octal/decimal) value in %r not permitted" + raise ValueError(msg % octet_str) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return ".".join( + _compat_str( + struct.unpack(b"!B", b)[0] if isinstance(b, bytes) else b + ) + for b in _compat_to_bytes(ip_int, 4, "big") + ) + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split(".") + try: + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = _compat_str(self).split(".")[::-1] + return ".".join(reverse_octets) + ".in-addr.arpa" + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ("_ip", "__weakref__") + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, "big") + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if "/" in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in self._constants._reserved_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + return ( + self not in self._constants._public_network and not self.is_private + ) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self == self._constants._unspecified_address + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in self._constants._loopback_network + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. + + """ + return self in self._constants._linklocal_network + + +class IPv4Interface(IPv4Address): + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv4Address.__init__(self, address) + self.network = IPv4Network(self._ip) + self._prefixlen = self._max_prefixlen + return + + if isinstance(address, tuple): + IPv4Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + + self.network = IPv4Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv4Address.__init__(self, addr[0]) + + self.network = IPv4Network(address, strict=False) + self._prefixlen = self.network._prefixlen + + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + + def __str__(self): + return "%s/%d" % ( + self._string_from_ip_int(self._ip), + self.network.prefixlen, + ) + + def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return ( + self.network < other.network + or self.network == other.network + and address_less + ) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv4Address(self._ip) + + @property + def with_prefixlen(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask) + + +class IPv4Network(_BaseV4, _BaseNetwork): + + """This class represents and manipulates 32-bit IPv4 network + addresses.. + + Attributes: [examples for IPv4Network('192.0.2.0/27')] + .network_address: IPv4Address('192.0.2.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast_address: IPv4Address('192.0.2.32') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + + # Class to use when creating address objects + _address_class = IPv4Address + + def __init__(self, address, strict=True): + + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.0.2.0/24' + '192.0.2.0/255.255.255.0' + '192.0.0.2/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.0.2.1' + '192.0.2.1/255.255.255.255' + '192.0.2.1/32' + are also functionally equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.0.2.1') == IPv4Network(3221225985) + or, more generally + IPv4Interface(int(IPv4Interface('192.0.2.1'))) == + IPv4Interface('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict is True and a network address is not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (_compat_int_types, bytes)): + self.network_address = IPv4Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen + ) + # fixme: address/network test here. + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + # We weren't given an address[1] + arg = self._max_prefixlen + self.network_address = IPv4Address(address[0]) + self.netmask, self._prefixlen = self._make_netmask(arg) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError("%s has host bits set" % self) + else: + self.network_address = IPv4Address( + packed & int(self.netmask) + ) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if ( + IPv4Address(int(self.network_address) & int(self.netmask)) + != self.network_address + ): + raise ValueError("%s has host bits set" % self) + self.network_address = IPv4Address( + int(self.network_address) & int(self.netmask) + ) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry. + + """ + return ( + not ( + self.network_address in IPv4Network("100.64.0.0/10") + and self.broadcast_address in IPv4Network("100.64.0.0/10") + ) + and not self.is_private + ) + + +class _IPv4Constants(object): + + _linklocal_network = IPv4Network("169.254.0.0/16") + + _loopback_network = IPv4Network("127.0.0.0/8") + + _multicast_network = IPv4Network("224.0.0.0/4") + + _public_network = IPv4Network("100.64.0.0/10") + + _private_networks = [ + IPv4Network("0.0.0.0/8"), + IPv4Network("10.0.0.0/8"), + IPv4Network("127.0.0.0/8"), + IPv4Network("169.254.0.0/16"), + IPv4Network("172.16.0.0/12"), + IPv4Network("192.0.0.0/29"), + IPv4Network("192.0.0.170/31"), + IPv4Network("192.0.2.0/24"), + IPv4Network("192.168.0.0/16"), + IPv4Network("198.18.0.0/15"), + IPv4Network("198.51.100.0/24"), + IPv4Network("203.0.113.0/24"), + IPv4Network("240.0.0.0/4"), + IPv4Network("255.255.255.255/32"), + ] + + _reserved_network = IPv4Network("240.0.0.0/4") + + _unspecified_address = IPv4Address("0.0.0.0") + + +IPv4Address._constants = _IPv4Constants + + +class _BaseV6(object): + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2 ** IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset("0123456789ABCDEFabcdef") + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError("Address cannot be empty") + + parts = ip_str.split(":") + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if "." in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + parts.append("%x" % ((ipv4_int >> 16) & 0xFFFF)) + parts.append("%x" % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % ( + _max_parts - 1, + ip_str, + ) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in _compat_range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. The + # endpoints could still be empty, but _parse_hextet() will check + # for that. + if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == "0": + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = ( + best_doublecolon_start + best_doublecolon_len + ) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [""] + hextets[best_doublecolon_start:best_doublecolon_end] = [""] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [""] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. + + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError("IPv6 address is too large") + + hex_str = "%032x" % ip_int + hextets = ["%x" % int(hex_str[x : x + 4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ":".join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = _compat_str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = _compat_str(self.ip) + else: + ip_str = _compat_str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = "%032x" % ip_int + parts = [hex_str[x : x + 4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return "%s/%d" % (":".join(parts), self._prefixlen) + return ":".join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(":", "") + return ".".join(reverse_chars) + ".ip6.arpa" + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ("_ip", "__weakref__") + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, "big") + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if "/" in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001::/32) + + """ + if (self._ip >> 96) != 0x20010000: + return None + return ( + IPv4Address((self._ip >> 64) & 0xFFFFFFFF), + IPv4Address(~self._ip & 0xFFFFFFFF), + ) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + if (self._ip >> 112) != 0x2002: + return None + return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) + + +class IPv6Interface(IPv6Address): + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv6Address.__init__(self, address) + self.network = IPv6Network(self._ip) + self._prefixlen = self._max_prefixlen + return + if isinstance(address, tuple): + IPv6Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv6Address.__init__(self, addr[0]) + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self._prefixlen = self.network._prefixlen + self.hostmask = self.network.hostmask + + def __str__(self): + return "%s/%d" % ( + self._string_from_ip_int(self._ip), + self.network.prefixlen, + ) + + def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return ( + self.network < other.network + or self.network == other.network + and address_less + ) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv6Address(self._ip) + + @property + def with_prefixlen(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self._prefixlen) + + @property + def with_netmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.netmask) + + @property + def with_hostmask(self): + return "%s/%s" % (self._string_from_ip_int(self._ip), self.hostmask) + + @property + def is_unspecified(self): + return self._ip == 0 and self.network.is_unspecified + + @property + def is_loopback(self): + return self._ip == 1 and self.network.is_loopback + + +class IPv6Network(_BaseV6, _BaseNetwork): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:db8::1000/124')] + .network_address: IPv6Address('2001:db8::1000') + .hostmask: IPv6Address('::f') + .broadcast_address: IPv6Address('2001:db8::100f') + .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') + .prefixlen: 124 + + """ + + # Class to use when creating address objects + _address_class = IPv6Address + + def __init__(self, address, strict=True): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the + IP and prefix/netmask. + '2001:db8::/128' + '2001:db8:0000:0000:0000:0000:0000:0000/128' + '2001:db8::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:db8::') == + IPv6Network(42540766411282592856903984951653826560) + or, more generally + IPv6Network(int(IPv6Network('2001:db8::'))) == + IPv6Network('2001:db8::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 2001:db8::1000/124 and not an + IP address on a network, eg, 2001:db8::1/124. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Efficient constructor from integer or packed address + if isinstance(address, (bytes, _compat_int_types)): + self.network_address = IPv6Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen + ) + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + self.network_address = IPv6Address(address[0]) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError("%s has host bits set" % self) + else: + self.network_address = IPv6Address( + packed & int(self.netmask) + ) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + + self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if ( + IPv6Address(int(self.network_address) & int(self.netmask)) + != self.network_address + ): + raise ValueError("%s has host bits set" % self) + self.network_address = IPv6Address( + int(self.network_address) & int(self.netmask) + ) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the + Subnet-Router anycast address. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast + 1): + yield self._address_class(x) + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return ( + self.network_address.is_site_local + and self.broadcast_address.is_site_local + ) + + +class _IPv6Constants(object): + + _linklocal_network = IPv6Network("fe80::/10") + + _multicast_network = IPv6Network("ff00::/8") + + _private_networks = [ + IPv6Network("::1/128"), + IPv6Network("::/128"), + IPv6Network("::ffff:0:0/96"), + IPv6Network("100::/64"), + IPv6Network("2001::/23"), + IPv6Network("2001:2::/48"), + IPv6Network("2001:db8::/32"), + IPv6Network("2001:10::/28"), + IPv6Network("fc00::/7"), + IPv6Network("fe80::/10"), + ] + + _reserved_networks = [ + IPv6Network("::/8"), + IPv6Network("100::/8"), + IPv6Network("200::/7"), + IPv6Network("400::/6"), + IPv6Network("800::/5"), + IPv6Network("1000::/4"), + IPv6Network("4000::/3"), + IPv6Network("6000::/3"), + IPv6Network("8000::/3"), + IPv6Network("A000::/3"), + IPv6Network("C000::/3"), + IPv6Network("E000::/4"), + IPv6Network("F000::/5"), + IPv6Network("F800::/6"), + IPv6Network("FE00::/9"), + ] + + _sitelocal_network = IPv6Network("fec0::/10") + + +IPv6Address._constants = _IPv6Constants diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py new file mode 100644 index 00000000..68608d1b --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/cfg/base.py @@ -0,0 +1,27 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The base class for all resource modules +""" + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import ( + get_resource_connection, +) + + +class ConfigBase(object): + """ The base class for all resource modules + """ + + ACTION_STATES = ["merged", "replaced", "overridden", "deleted"] + + def __init__(self, module): + self._module = module + self.state = module.params["state"] + self._connection = None + + if self.state not in ["rendered", "parsed"]: + self._connection = get_resource_connection(module) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py new file mode 100644 index 00000000..bc458eb5 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/config.py @@ -0,0 +1,473 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import re +import hashlib + +from ansible.module_utils.six.moves import zip +from ansible.module_utils._text import to_bytes, to_native + +DEFAULT_COMMENT_TOKENS = ["#", "!", "/*", "*/", "echo"] + +DEFAULT_IGNORE_LINES_RE = set( + [ + re.compile(r"Using \d+ out of \d+ bytes"), + re.compile(r"Building configuration"), + re.compile(r"Current configuration : \d+ bytes"), + ] +) + + +try: + Pattern = re._pattern_type +except AttributeError: + Pattern = re.Pattern + + +class ConfigLine(object): + def __init__(self, raw): + self.text = str(raw).strip() + self.raw = raw + self._children = list() + self._parents = list() + + def __str__(self): + return self.raw + + def __eq__(self, other): + return self.line == other.line + + def __ne__(self, other): + return not self.__eq__(other) + + def __getitem__(self, key): + for item in self._children: + if item.text == key: + return item + raise KeyError(key) + + @property + def line(self): + line = self.parents + line.append(self.text) + return " ".join(line) + + @property + def children(self): + return _obj_to_text(self._children) + + @property + def child_objs(self): + return self._children + + @property + def parents(self): + return _obj_to_text(self._parents) + + @property + def path(self): + config = _obj_to_raw(self._parents) + config.append(self.raw) + return "\n".join(config) + + @property + def has_children(self): + return len(self._children) > 0 + + @property + def has_parents(self): + return len(self._parents) > 0 + + def add_child(self, obj): + if not isinstance(obj, ConfigLine): + raise AssertionError("child must be of type `ConfigLine`") + self._children.append(obj) + + +def ignore_line(text, tokens=None): + for item in tokens or DEFAULT_COMMENT_TOKENS: + if text.startswith(item): + return True + for regex in DEFAULT_IGNORE_LINES_RE: + if regex.match(text): + return True + + +def _obj_to_text(x): + return [o.text for o in x] + + +def _obj_to_raw(x): + return [o.raw for o in x] + + +def _obj_to_block(objects, visited=None): + items = list() + for o in objects: + if o not in items: + items.append(o) + for child in o._children: + if child not in items: + items.append(child) + return _obj_to_raw(items) + + +def dumps(objects, output="block", comments=False): + if output == "block": + items = _obj_to_block(objects) + elif output == "commands": + items = _obj_to_text(objects) + elif output == "raw": + items = _obj_to_raw(objects) + else: + raise TypeError("unknown value supplied for keyword output") + + if output == "block": + if comments: + for index, item in enumerate(items): + nextitem = index + 1 + if ( + nextitem < len(items) + and not item.startswith(" ") + and items[nextitem].startswith(" ") + ): + item = "!\n%s" % item + items[index] = item + items.append("!") + items.append("end") + + return "\n".join(items) + + +class NetworkConfig(object): + def __init__(self, indent=1, contents=None, ignore_lines=None): + self._indent = indent + self._items = list() + self._config_text = None + + if ignore_lines: + for item in ignore_lines: + if not isinstance(item, Pattern): + item = re.compile(item) + DEFAULT_IGNORE_LINES_RE.add(item) + + if contents: + self.load(contents) + + @property + def items(self): + return self._items + + @property + def config_text(self): + return self._config_text + + @property + def sha1(self): + sha1 = hashlib.sha1() + sha1.update(to_bytes(str(self), errors="surrogate_or_strict")) + return sha1.digest() + + def __getitem__(self, key): + for line in self: + if line.text == key: + return line + raise KeyError(key) + + def __iter__(self): + return iter(self._items) + + def __str__(self): + return "\n".join([c.raw for c in self.items]) + + def __len__(self): + return len(self._items) + + def load(self, s): + self._config_text = s + self._items = self.parse(s) + + def loadfp(self, fp): + with open(fp) as f: + return self.load(f.read()) + + def parse(self, lines, comment_tokens=None): + toplevel = re.compile(r"\S") + childline = re.compile(r"^\s*(.+)$") + entry_reg = re.compile(r"([{};])") + + ancestors = list() + config = list() + + indents = [0] + + for linenum, line in enumerate( + to_native(lines, errors="surrogate_or_strict").split("\n") + ): + text = entry_reg.sub("", line).strip() + + cfg = ConfigLine(line) + + if not text or ignore_line(text, comment_tokens): + continue + + # handle top level commands + if toplevel.match(line): + ancestors = [cfg] + indents = [0] + + # handle sub level commands + else: + match = childline.match(line) + line_indent = match.start(1) + + if line_indent < indents[-1]: + while indents[-1] > line_indent: + indents.pop() + + if line_indent > indents[-1]: + indents.append(line_indent) + + curlevel = len(indents) - 1 + parent_level = curlevel - 1 + + cfg._parents = ancestors[:curlevel] + + if curlevel > len(ancestors): + config.append(cfg) + continue + + for i in range(curlevel, len(ancestors)): + ancestors.pop() + + ancestors.append(cfg) + ancestors[parent_level].add_child(cfg) + + config.append(cfg) + + return config + + def get_object(self, path): + for item in self.items: + if item.text == path[-1]: + if item.parents == path[:-1]: + return item + + def get_block(self, path): + if not isinstance(path, list): + raise AssertionError("path argument must be a list object") + obj = self.get_object(path) + if not obj: + raise ValueError("path does not exist in config") + return self._expand_block(obj) + + def get_block_config(self, path): + block = self.get_block(path) + return dumps(block, "block") + + def _expand_block(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj._children: + if child in S: + continue + self._expand_block(child, S) + return S + + def _diff_line(self, other): + updates = list() + for item in self.items: + if item not in other: + updates.append(item) + return updates + + def _diff_strict(self, other): + updates = list() + # block extracted from other does not have all parents + # but the last one. In case of multiple parents we need + # to add additional parents. + if other and isinstance(other, list) and len(other) > 0: + start_other = other[0] + if start_other.parents: + for parent in start_other.parents: + other.insert(0, ConfigLine(parent)) + for index, line in enumerate(self.items): + try: + if str(line).strip() != str(other[index]).strip(): + updates.append(line) + except (AttributeError, IndexError): + updates.append(line) + return updates + + def _diff_exact(self, other): + updates = list() + if len(other) != len(self.items): + updates.extend(self.items) + else: + for ours, theirs in zip(self.items, other): + if ours != theirs: + updates.extend(self.items) + break + return updates + + def difference(self, other, match="line", path=None, replace=None): + """Perform a config diff against the another network config + + :param other: instance of NetworkConfig to diff against + :param match: type of diff to perform. valid values are 'line', + 'strict', 'exact' + :param path: context in the network config to filter the diff + :param replace: the method used to generate the replacement lines. + valid values are 'block', 'line' + + :returns: a string of lines that are different + """ + if path and match != "line": + try: + other = other.get_block(path) + except ValueError: + other = list() + else: + other = other.items + + # generate a list of ConfigLines that aren't in other + meth = getattr(self, "_diff_%s" % match) + updates = meth(other) + + if replace == "block": + parents = list() + for item in updates: + if not item.has_parents: + parents.append(item) + else: + for p in item._parents: + if p not in parents: + parents.append(p) + + updates = list() + for item in parents: + updates.extend(self._expand_block(item)) + + visited = set() + expanded = list() + + for item in updates: + for p in item._parents: + if p.line not in visited: + visited.add(p.line) + expanded.append(p) + expanded.append(item) + visited.add(item.line) + + return expanded + + def add(self, lines, parents=None): + ancestors = list() + offset = 0 + obj = None + + # global config command + if not parents: + for line in lines: + # handle ignore lines + if ignore_line(line): + continue + + item = ConfigLine(line) + item.raw = line + if item not in self.items: + self.items.append(item) + + else: + for index, p in enumerate(parents): + try: + i = index + 1 + obj = self.get_block(parents[:i])[0] + ancestors.append(obj) + + except ValueError: + # add parent to config + offset = index * self._indent + obj = ConfigLine(p) + obj.raw = p.rjust(len(p) + offset) + if ancestors: + obj._parents = list(ancestors) + ancestors[-1]._children.append(obj) + self.items.append(obj) + ancestors.append(obj) + + # add child objects + for line in lines: + # handle ignore lines + if ignore_line(line): + continue + + # check if child already exists + for child in ancestors[-1]._children: + if child.text == line: + break + else: + offset = len(parents) * self._indent + item = ConfigLine(line) + item.raw = line.rjust(len(line) + offset) + item._parents = ancestors + ancestors[-1]._children.append(item) + self.items.append(item) + + +class CustomNetworkConfig(NetworkConfig): + def items_text(self): + return [item.text for item in self.items] + + def expand_section(self, configobj, S=None): + if S is None: + S = list() + S.append(configobj) + for child in configobj.child_objs: + if child in S: + continue + self.expand_section(child, S) + return S + + def to_block(self, section): + return "\n".join([item.raw for item in section]) + + def get_section(self, path): + try: + section = self.get_section_objects(path) + return self.to_block(section) + except ValueError: + return list() + + def get_section_objects(self, path): + if not isinstance(path, list): + path = [path] + obj = self.get_object(path) + if not obj: + raise ValueError("path does not exist in config") + return self.expand_section(obj) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py new file mode 100644 index 00000000..477d3184 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/facts/facts.py @@ -0,0 +1,162 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The facts base class +this contains methods common to all facts subsets +""" +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.network import ( + get_resource_connection, +) +from ansible.module_utils.six import iteritems + + +class FactsBase(object): + """ + The facts base class + """ + + def __init__(self, module): + self._module = module + self._warnings = [] + self._gather_subset = module.params.get("gather_subset") + self._gather_network_resources = module.params.get( + "gather_network_resources" + ) + self._connection = None + if module.params.get("state") not in ["rendered", "parsed"]: + self._connection = get_resource_connection(module) + + self.ansible_facts = {"ansible_network_resources": {}} + self.ansible_facts["ansible_net_gather_network_resources"] = list() + self.ansible_facts["ansible_net_gather_subset"] = list() + + if not self._gather_subset: + self._gather_subset = ["!config"] + if not self._gather_network_resources: + self._gather_network_resources = ["!all"] + + def gen_runable(self, subsets, valid_subsets, resource_facts=False): + """ Generate the runable subset + + :param module: The module instance + :param subsets: The provided subsets + :param valid_subsets: The valid subsets + :param resource_facts: A boolean flag + :rtype: list + :returns: The runable subsets + """ + runable_subsets = set() + exclude_subsets = set() + minimal_gather_subset = set() + if not resource_facts: + minimal_gather_subset = frozenset(["default"]) + + for subset in subsets: + if subset == "all": + runable_subsets.update(valid_subsets) + continue + if subset == "min" and minimal_gather_subset: + runable_subsets.update(minimal_gather_subset) + continue + if subset.startswith("!"): + subset = subset[1:] + if subset == "min": + exclude_subsets.update(minimal_gather_subset) + continue + if subset == "all": + exclude_subsets.update( + valid_subsets - minimal_gather_subset + ) + continue + exclude = True + else: + exclude = False + + if subset not in valid_subsets: + self._module.fail_json( + msg="Subset must be one of [%s], got %s" + % ( + ", ".join(sorted([item for item in valid_subsets])), + subset, + ) + ) + + if exclude: + exclude_subsets.add(subset) + else: + runable_subsets.add(subset) + + if not runable_subsets: + runable_subsets.update(valid_subsets) + runable_subsets.difference_update(exclude_subsets) + return runable_subsets + + def get_network_resources_facts( + self, facts_resource_obj_map, resource_facts_type=None, data=None + ): + """ + :param fact_resource_subsets: + :param data: previously collected configuration + :return: + """ + if not resource_facts_type: + resource_facts_type = self._gather_network_resources + + restorun_subsets = self.gen_runable( + resource_facts_type, + frozenset(facts_resource_obj_map.keys()), + resource_facts=True, + ) + if restorun_subsets: + self.ansible_facts["ansible_net_gather_network_resources"] = list( + restorun_subsets + ) + instances = list() + for key in restorun_subsets: + fact_cls_obj = facts_resource_obj_map.get(key) + if fact_cls_obj: + instances.append(fact_cls_obj(self._module)) + else: + self._warnings.extend( + [ + "network resource fact gathering for '%s' is not supported" + % key + ] + ) + + for inst in instances: + inst.populate_facts(self._connection, self.ansible_facts, data) + + def get_network_legacy_facts( + self, fact_legacy_obj_map, legacy_facts_type=None + ): + if not legacy_facts_type: + legacy_facts_type = self._gather_subset + + runable_subsets = self.gen_runable( + legacy_facts_type, frozenset(fact_legacy_obj_map.keys()) + ) + if runable_subsets: + facts = dict() + # default subset should always returned be with legacy facts subsets + if "default" not in runable_subsets: + runable_subsets.add("default") + self.ansible_facts["ansible_net_gather_subset"] = list( + runable_subsets + ) + + instances = list() + for key in runable_subsets: + instances.append(fact_legacy_obj_map[key](self._module)) + + for inst in instances: + inst.populate() + facts.update(inst.facts) + self._warnings.extend(inst.warnings) + + for key, value in iteritems(facts): + key = "ansible_net_%s" % key + self.ansible_facts[key] = value diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py new file mode 100644 index 00000000..53a91e8c --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/netconf.py @@ -0,0 +1,179 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2017 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import sys + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError + +try: + from ncclient.xml_ import NCElement, new_ele, sub_ele + + HAS_NCCLIENT = True +except (ImportError, AttributeError): + HAS_NCCLIENT = False + +try: + from lxml.etree import Element, fromstring, XMLSyntaxError +except ImportError: + from xml.etree.ElementTree import Element, fromstring + + if sys.version_info < (2, 7): + from xml.parsers.expat import ExpatError as XMLSyntaxError + else: + from xml.etree.ElementTree import ParseError as XMLSyntaxError + +NS_MAP = {"nc": "urn:ietf:params:xml:ns:netconf:base:1.0"} + + +def exec_rpc(module, *args, **kwargs): + connection = NetconfConnection(module._socket_path) + return connection.execute_rpc(*args, **kwargs) + + +class NetconfConnection(Connection): + def __init__(self, socket_path): + super(NetconfConnection, self).__init__(socket_path) + + def __rpc__(self, name, *args, **kwargs): + """Executes the json-rpc and returns the output received + from remote device. + :name: rpc method to be executed over connection plugin that implements jsonrpc 2.0 + :args: Ordered list of params passed as arguments to rpc method + :kwargs: Dict of valid key, value pairs passed as arguments to rpc method + + For usage refer the respective connection plugin docs. + """ + self.check_rc = kwargs.pop("check_rc", True) + self.ignore_warning = kwargs.pop("ignore_warning", True) + + response = self._exec_jsonrpc(name, *args, **kwargs) + if "error" in response: + rpc_error = response["error"].get("data") + return self.parse_rpc_error( + to_bytes(rpc_error, errors="surrogate_then_replace") + ) + + return fromstring( + to_bytes(response["result"], errors="surrogate_then_replace") + ) + + def parse_rpc_error(self, rpc_error): + if self.check_rc: + try: + error_root = fromstring(rpc_error) + root = Element("root") + root.append(error_root) + + error_list = root.findall(".//nc:rpc-error", NS_MAP) + if not error_list: + raise ConnectionError( + to_text(rpc_error, errors="surrogate_then_replace") + ) + + warnings = [] + for error in error_list: + message_ele = error.find("./nc:error-message", NS_MAP) + + if message_ele is None: + message_ele = error.find("./nc:error-info", NS_MAP) + + message = ( + message_ele.text if message_ele is not None else None + ) + + severity = error.find("./nc:error-severity", NS_MAP).text + + if ( + severity == "warning" + and self.ignore_warning + and message is not None + ): + warnings.append(message) + else: + raise ConnectionError( + to_text(rpc_error, errors="surrogate_then_replace") + ) + return warnings + except XMLSyntaxError: + raise ConnectionError(rpc_error) + + +def transform_reply(): + return b"""<xsl:stylesheet version="1.0" xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> + <xsl:output method="xml" indent="no"/> + + <xsl:template match="/|comment()|processing-instruction()"> + <xsl:copy> + <xsl:apply-templates/> + </xsl:copy> + </xsl:template> + + <xsl:template match="*"> + <xsl:element name="{local-name()}"> + <xsl:apply-templates select="@*|node()"/> + </xsl:element> + </xsl:template> + + <xsl:template match="@*"> + <xsl:attribute name="{local-name()}"> + <xsl:value-of select="."/> + </xsl:attribute> + </xsl:template> + </xsl:stylesheet> + """ + + +# Note: Workaround for ncclient 0.5.3 +def remove_namespaces(data): + if not HAS_NCCLIENT: + raise ImportError( + "ncclient is required but does not appear to be installed. " + "It can be installed using `pip install ncclient`" + ) + return NCElement(data, transform_reply()).data_xml + + +def build_root_xml_node(tag): + return new_ele(tag) + + +def build_child_xml_node(parent, tag, text=None, attrib=None): + element = sub_ele(parent, tag) + if text: + element.text = to_text(text) + if attrib: + element.attrib.update(attrib) + return element + + +def build_subtree(parent, path): + element = parent + for field in path.split("/"): + sub_element = build_child_xml_node(element, field) + element = sub_element + return element diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py new file mode 100644 index 00000000..555fc713 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/network.py @@ -0,0 +1,275 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com> +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import traceback +import json + +from ansible.module_utils._text import to_text, to_native +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.basic import env_fallback +from ansible.module_utils.connection import Connection, ConnectionError +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import ( + NetconfConnection, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import ( + Cli, +) +from ansible.module_utils.six import iteritems + + +NET_TRANSPORT_ARGS = dict( + host=dict(required=True), + port=dict(type="int"), + username=dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])), + password=dict( + no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]) + ), + ssh_keyfile=dict( + fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path" + ), + authorize=dict( + default=False, + fallback=(env_fallback, ["ANSIBLE_NET_AUTHORIZE"]), + type="bool", + ), + auth_pass=dict( + no_log=True, fallback=(env_fallback, ["ANSIBLE_NET_AUTH_PASS"]) + ), + provider=dict(type="dict", no_log=True), + transport=dict(choices=list()), + timeout=dict(default=10, type="int"), +) + +NET_CONNECTION_ARGS = dict() + +NET_CONNECTIONS = dict() + + +def _transitional_argument_spec(): + argument_spec = {} + for key, value in iteritems(NET_TRANSPORT_ARGS): + value["required"] = False + argument_spec[key] = value + return argument_spec + + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +class ModuleStub(object): + def __init__(self, argument_spec, fail_json): + self.params = dict() + for key, value in argument_spec.items(): + self.params[key] = value.get("default") + self.fail_json = fail_json + + +class NetworkError(Exception): + def __init__(self, msg, **kwargs): + super(NetworkError, self).__init__(msg) + self.kwargs = kwargs + + +class Config(object): + def __init__(self, connection): + self.connection = connection + + def __call__(self, commands, **kwargs): + lines = to_list(commands) + return self.connection.configure(lines, **kwargs) + + def load_config(self, commands, **kwargs): + commands = to_list(commands) + return self.connection.load_config(commands, **kwargs) + + def get_config(self, **kwargs): + return self.connection.get_config(**kwargs) + + def save_config(self): + return self.connection.save_config() + + +class NetworkModule(AnsibleModule): + def __init__(self, *args, **kwargs): + connect_on_load = kwargs.pop("connect_on_load", True) + + argument_spec = NET_TRANSPORT_ARGS.copy() + argument_spec["transport"]["choices"] = NET_CONNECTIONS.keys() + argument_spec.update(NET_CONNECTION_ARGS.copy()) + + if kwargs.get("argument_spec"): + argument_spec.update(kwargs["argument_spec"]) + kwargs["argument_spec"] = argument_spec + + super(NetworkModule, self).__init__(*args, **kwargs) + + self.connection = None + self._cli = None + self._config = None + + try: + transport = self.params["transport"] or "__default__" + cls = NET_CONNECTIONS[transport] + self.connection = cls() + except KeyError: + self.fail_json( + msg="Unknown transport or no default transport specified" + ) + except (TypeError, NetworkError) as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + if connect_on_load: + self.connect() + + @property + def cli(self): + if not self.connected: + self.connect() + if self._cli: + return self._cli + self._cli = Cli(self.connection) + return self._cli + + @property + def config(self): + if not self.connected: + self.connect() + if self._config: + return self._config + self._config = Config(self.connection) + return self._config + + @property + def connected(self): + return self.connection._connected + + def _load_params(self): + super(NetworkModule, self)._load_params() + provider = self.params.get("provider") or dict() + for key, value in provider.items(): + for args in [NET_TRANSPORT_ARGS, NET_CONNECTION_ARGS]: + if key in args: + if self.params.get(key) is None and value is not None: + self.params[key] = value + + def connect(self): + try: + if not self.connected: + self.connection.connect(self.params) + if self.params["authorize"]: + self.connection.authorize(self.params) + self.log( + "connected to %s:%s using %s" + % ( + self.params["host"], + self.params["port"], + self.params["transport"], + ) + ) + except NetworkError as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + def disconnect(self): + try: + if self.connected: + self.connection.disconnect() + self.log("disconnected from %s" % self.params["host"]) + except NetworkError as exc: + self.fail_json( + msg=to_native(exc), exception=traceback.format_exc() + ) + + +def register_transport(transport, default=False): + def register(cls): + NET_CONNECTIONS[transport] = cls + if default: + NET_CONNECTIONS["__default__"] = cls + return cls + + return register + + +def add_argument(key, value): + NET_CONNECTION_ARGS[key] = value + + +def get_resource_connection(module): + if hasattr(module, "_connection"): + return module._connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api in ("cliconf", "nxapi", "eapi", "exosapi"): + module._connection = Connection(module._socket_path) + elif network_api == "netconf": + module._connection = NetconfConnection(module._socket_path) + elif network_api == "local": + # This isn't supported, but we shouldn't fail here. + # Set the connection to a fake connection so it fails sensibly. + module._connection = LocalResourceConnection(module) + else: + module.fail_json( + msg="Invalid connection type {0!s}".format(network_api) + ) + + return module._connection + + +def get_capabilities(module): + if hasattr(module, "capabilities"): + return module._capabilities + try: + capabilities = Connection(module._socket_path).get_capabilities() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + except AssertionError: + # No socket_path, connection most likely local. + return dict(network_api="local") + module._capabilities = json.loads(capabilities) + + return module._capabilities + + +class LocalResourceConnection: + def __init__(self, module): + self.module = module + + def get(self, *args, **kwargs): + self.module.fail_json( + msg="Network resource modules not supported over local connection." + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py new file mode 100644 index 00000000..2dd1de9e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/parsing.py @@ -0,0 +1,316 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2015 Peter Sprygada, <psprygada@ansible.com> +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import re +import shlex +import time + +from ansible.module_utils.parsing.convert_bool import ( + BOOLEANS_TRUE, + BOOLEANS_FALSE, +) +from ansible.module_utils.six import string_types, text_type +from ansible.module_utils.six.moves import zip + + +def to_list(val): + if isinstance(val, (list, tuple)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +class FailedConditionsError(Exception): + def __init__(self, msg, failed_conditions): + super(FailedConditionsError, self).__init__(msg) + self.failed_conditions = failed_conditions + + +class FailedConditionalError(Exception): + def __init__(self, msg, failed_conditional): + super(FailedConditionalError, self).__init__(msg) + self.failed_conditional = failed_conditional + + +class AddCommandError(Exception): + def __init__(self, msg, command): + super(AddCommandError, self).__init__(msg) + self.command = command + + +class AddConditionError(Exception): + def __init__(self, msg, condition): + super(AddConditionError, self).__init__(msg) + self.condition = condition + + +class Cli(object): + def __init__(self, connection): + self.connection = connection + self.default_output = connection.default_output or "text" + self._commands = list() + + @property + def commands(self): + return [str(c) for c in self._commands] + + def __call__(self, commands, output=None): + objects = list() + for cmd in to_list(commands): + objects.append(self.to_command(cmd, output)) + return self.connection.run_commands(objects) + + def to_command( + self, command, output=None, prompt=None, response=None, **kwargs + ): + output = output or self.default_output + if isinstance(command, Command): + return command + if isinstance(prompt, string_types): + prompt = re.compile(re.escape(prompt)) + return Command( + command, output, prompt=prompt, response=response, **kwargs + ) + + def add_commands(self, commands, output=None, **kwargs): + for cmd in commands: + self._commands.append(self.to_command(cmd, output, **kwargs)) + + def run_commands(self): + responses = self.connection.run_commands(self._commands) + for resp, cmd in zip(responses, self._commands): + cmd.response = resp + + # wipe out the commands list to avoid issues if additional + # commands are executed later + self._commands = list() + + return responses + + +class Command(object): + def __init__( + self, command, output=None, prompt=None, response=None, **kwargs + ): + + self.command = command + self.output = output + self.command_string = command + + self.prompt = prompt + self.response = response + + self.args = kwargs + + def __str__(self): + return self.command_string + + +class CommandRunner(object): + def __init__(self, module): + self.module = module + + self.items = list() + self.conditionals = set() + + self.commands = list() + + self.retries = 10 + self.interval = 1 + + self.match = "all" + + self._default_output = module.connection.default_output + + def add_command( + self, command, output=None, prompt=None, response=None, **kwargs + ): + if command in [str(c) for c in self.commands]: + raise AddCommandError( + "duplicated command detected", command=command + ) + cmd = self.module.cli.to_command( + command, output=output, prompt=prompt, response=response, **kwargs + ) + self.commands.append(cmd) + + def get_command(self, command, output=None): + for cmd in self.commands: + if cmd.command == command: + return cmd.response + raise ValueError("command '%s' not found" % command) + + def get_responses(self): + return [cmd.response for cmd in self.commands] + + def add_conditional(self, condition): + try: + self.conditionals.add(Conditional(condition)) + except AttributeError as exc: + raise AddConditionError(msg=str(exc), condition=condition) + + def run(self): + while self.retries > 0: + self.module.cli.add_commands(self.commands) + responses = self.module.cli.run_commands() + + for item in list(self.conditionals): + if item(responses): + if self.match == "any": + return item + self.conditionals.remove(item) + + if not self.conditionals: + break + + time.sleep(self.interval) + self.retries -= 1 + else: + failed_conditions = [item.raw for item in self.conditionals] + errmsg = ( + "One or more conditional statements have not been satisfied" + ) + raise FailedConditionsError(errmsg, failed_conditions) + + +class Conditional(object): + """Used in command modules to evaluate waitfor conditions + """ + + OPERATORS = { + "eq": ["eq", "=="], + "neq": ["neq", "ne", "!="], + "gt": ["gt", ">"], + "ge": ["ge", ">="], + "lt": ["lt", "<"], + "le": ["le", "<="], + "contains": ["contains"], + "matches": ["matches"], + } + + def __init__(self, conditional, encoding=None): + self.raw = conditional + self.negate = False + try: + components = shlex.split(conditional) + key, val = components[0], components[-1] + op_components = components[1:-1] + if "not" in op_components: + self.negate = True + op_components.pop(op_components.index("not")) + op = op_components[0] + + except ValueError: + raise ValueError("failed to parse conditional") + + self.key = key + self.func = self._func(op) + self.value = self._cast_value(val) + + def __call__(self, data): + value = self.get_value(dict(result=data)) + if not self.negate: + return self.func(value) + else: + return not self.func(value) + + def _cast_value(self, value): + if value in BOOLEANS_TRUE: + return True + elif value in BOOLEANS_FALSE: + return False + elif re.match(r"^\d+\.d+$", value): + return float(value) + elif re.match(r"^\d+$", value): + return int(value) + else: + return text_type(value) + + def _func(self, oper): + for func, operators in self.OPERATORS.items(): + if oper in operators: + return getattr(self, func) + raise AttributeError("unknown operator: %s" % oper) + + def get_value(self, result): + try: + return self.get_json(result) + except (IndexError, TypeError, AttributeError): + msg = "unable to apply conditional to result" + raise FailedConditionalError(msg, self.raw) + + def get_json(self, result): + string = re.sub(r"\[[\'|\"]", ".", self.key) + string = re.sub(r"[\'|\"]\]", ".", string) + parts = re.split(r"\.(?=[^\]]*(?:\[|$))", string) + for part in parts: + match = re.findall(r"\[(\S+?)\]", part) + if match: + key = part[: part.find("[")] + result = result[key] + for m in match: + try: + m = int(m) + except ValueError: + m = str(m) + result = result[m] + else: + result = result.get(part) + return result + + def number(self, value): + if "." in str(value): + return float(value) + else: + return int(value) + + def eq(self, value): + return value == self.value + + def neq(self, value): + return value != self.value + + def gt(self, value): + return self.number(value) > self.value + + def ge(self, value): + return self.number(value) >= self.value + + def lt(self, value): + return self.number(value) < self.value + + def le(self, value): + return self.number(value) <= self.value + + def contains(self, value): + return str(self.value) in value + + def matches(self, value): + match = re.search(self.value, value, re.M) + return match is not None diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py new file mode 100644 index 00000000..64eca157 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/common/utils.py @@ -0,0 +1,686 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Networking tools for network modules only + +import re +import ast +import operator +import socket +import json + +from itertools import chain + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils import basic +from ansible.module_utils.parsing.convert_bool import boolean + +# Backwards compatibility for 3rd party modules +# TODO(pabelanger): With move to ansible.netcommon, we should clean this code +# up and have modules import directly themself. +from ansible.module_utils.common.network import ( # noqa: F401 + to_bits, + is_netmask, + is_masklen, + to_netmask, + to_masklen, + to_subnet, + to_ipv6_network, + VALID_MASKS, +) + +try: + from jinja2 import Environment, StrictUndefined + from jinja2.exceptions import UndefinedError + + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False + + +OPERATORS = frozenset(["ge", "gt", "eq", "neq", "lt", "le"]) +ALIASES = frozenset( + [("min", "ge"), ("max", "le"), ("exactly", "eq"), ("neq", "ne")] +) + + +def to_list(val): + if isinstance(val, (list, tuple, set)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +def to_lines(stdout): + for item in stdout: + if isinstance(item, string_types): + item = to_text(item).split("\n") + yield item + + +def transform_commands(module): + transform = ComplexList( + dict( + command=dict(key=True), + output=dict(), + prompt=dict(type="list"), + answer=dict(type="list"), + newline=dict(type="bool", default=True), + sendonly=dict(type="bool", default=False), + check_all=dict(type="bool", default=False), + ), + module, + ) + + return transform(module.params["commands"]) + + +def sort_list(val): + if isinstance(val, list): + return sorted(val) + return val + + +class Entity(object): + """Transforms a dict to with an argument spec + + This class will take a dict and apply an Ansible argument spec to the + values. The resulting dict will contain all of the keys in the param + with appropriate values set. + + Example:: + + argument_spec = dict( + command=dict(key=True), + display=dict(default='text', choices=['text', 'json']), + validate=dict(type='bool') + ) + transform = Entity(module, argument_spec) + value = dict(command='foo') + result = transform(value) + print result + {'command': 'foo', 'display': 'text', 'validate': None} + + Supported argument spec: + * key - specifies how to map a single value to a dict + * read_from - read and apply the argument_spec from the module + * required - a value is required + * type - type of value (uses AnsibleModule type checker) + * fallback - implements fallback function + * choices - set of valid options + * default - default value + """ + + def __init__( + self, module, attrs=None, args=None, keys=None, from_argspec=False + ): + args = [] if args is None else args + + self._attributes = attrs or {} + self._module = module + + for arg in args: + self._attributes[arg] = dict() + if from_argspec: + self._attributes[arg]["read_from"] = arg + if keys and arg in keys: + self._attributes[arg]["key"] = True + + self.attr_names = frozenset(self._attributes.keys()) + + _has_key = False + + for name, attr in iteritems(self._attributes): + if attr.get("read_from"): + if attr["read_from"] not in self._module.argument_spec: + module.fail_json( + msg="argument %s does not exist" % attr["read_from"] + ) + spec = self._module.argument_spec.get(attr["read_from"]) + for key, value in iteritems(spec): + if key not in attr: + attr[key] = value + + if attr.get("key"): + if _has_key: + module.fail_json(msg="only one key value can be specified") + _has_key = True + attr["required"] = True + + def serialize(self): + return self._attributes + + def to_dict(self, value): + obj = {} + for name, attr in iteritems(self._attributes): + if attr.get("key"): + obj[name] = value + else: + obj[name] = attr.get("default") + return obj + + def __call__(self, value, strict=True): + if not isinstance(value, dict): + value = self.to_dict(value) + + if strict: + unknown = set(value).difference(self.attr_names) + if unknown: + self._module.fail_json( + msg="invalid keys: %s" % ",".join(unknown) + ) + + for name, attr in iteritems(self._attributes): + if value.get(name) is None: + value[name] = attr.get("default") + + if attr.get("fallback") and not value.get(name): + fallback = attr.get("fallback", (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + value[name] = fallback_strategy( + *fallback_args, **fallback_kwargs + ) + except basic.AnsibleFallbackNotFound: + continue + + if attr.get("required") and value.get(name) is None: + self._module.fail_json( + msg="missing required attribute %s" % name + ) + + if "choices" in attr: + if value[name] not in attr["choices"]: + self._module.fail_json( + msg="%s must be one of %s, got %s" + % (name, ", ".join(attr["choices"]), value[name]) + ) + + if value[name] is not None: + value_type = attr.get("type", "str") + type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[ + value_type + ] + type_checker(value[name]) + elif value.get(name): + value[name] = self._module.params[name] + + return value + + +class EntityCollection(Entity): + """Extends ```Entity``` to handle a list of dicts """ + + def __call__(self, iterable, strict=True): + if iterable is None: + iterable = [ + super(EntityCollection, self).__call__( + self._module.params, strict + ) + ] + + if not isinstance(iterable, (list, tuple)): + self._module.fail_json(msg="value must be an iterable") + + return [ + (super(EntityCollection, self).__call__(i, strict)) + for i in iterable + ] + + +# these two are for backwards compatibility and can be removed once all of the +# modules that use them are updated +class ComplexDict(Entity): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexDict, self).__init__(module, attrs, *args, **kwargs) + + +class ComplexList(EntityCollection): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexList, self).__init__(module, attrs, *args, **kwargs) + + +def dict_diff(base, comparable): + """ Generate a dict object of differences + + This function will compare two dict objects and return the difference + between them as a dict object. For scalar values, the key will reflect + the updated value. If the key does not exist in `comparable`, then then no + key will be returned. For lists, the value in comparable will wholly replace + the value in base for the key. For dicts, the returned value will only + return keys that are different. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + + :returns: new dict object with differences + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(comparable, dict): + if comparable is None: + comparable = dict() + else: + raise AssertionError("`comparable` must be of type <dict>") + + updates = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + item = comparable.get(key) + if item is not None: + sub_diff = dict_diff(value, comparable[key]) + if sub_diff: + updates[key] = sub_diff + else: + comparable_value = comparable.get(key) + if comparable_value is not None: + if sort_list(base[key]) != sort_list(comparable_value): + updates[key] = comparable_value + + for key in set(comparable.keys()).difference(base.keys()): + updates[key] = comparable.get(key) + + return updates + + +def dict_merge(base, other): + """ Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. If the value is a list object, the two lists will + be combined and duplicate entries removed. + + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(other, dict): + raise AssertionError("`other` must be of type <dict>") + + combined = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + if isinstance(other[key], Mapping): + combined[key] = dict_merge(value, other[key]) + else: + combined[key] = other[key] + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if item is not None: + try: + combined[key] = list(set(chain(value, item))) + except TypeError: + value.extend([i for i in item if i not in value]) + combined[key] = value + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) + + return combined + + +def param_list_to_dict(param_list, unique_key="name", remove_key=True): + """Rotates a list of dictionaries to be a dictionary of dictionaries. + + :param param_list: The aforementioned list of dictionaries + :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value + behind this key will be the key each dictionary can be found at in the new root dictionary + :param remove_key: If True, remove unique_key from the individual dictionaries before returning. + """ + param_dict = {} + for params in param_list: + params = params.copy() + if remove_key: + name = params.pop(unique_key) + else: + name = params.get(unique_key) + param_dict[name] = params + + return param_dict + + +def conditional(expr, val, cast=None): + match = re.match(r"^(.+)\((.+)\)$", str(expr), re.I) + if match: + op, arg = match.groups() + else: + op = "eq" + if " " in str(expr): + raise AssertionError("invalid expression: cannot contain spaces") + arg = expr + + if cast is None and val is not None: + arg = type(val)(arg) + elif callable(cast): + arg = cast(arg) + val = cast(val) + + op = next((oper for alias, oper in ALIASES if op == alias), op) + + if not hasattr(operator, op) and op not in OPERATORS: + raise ValueError("unknown operator: %s" % op) + + func = getattr(operator, op) + return func(val, arg) + + +def ternary(value, true_val, false_val): + """ value ? true_val : false_val """ + if value: + return true_val + else: + return false_val + + +def remove_default_spec(spec): + for item in spec: + if "default" in spec[item]: + del spec[item]["default"] + + +def validate_ip_address(address): + try: + socket.inet_aton(address) + except socket.error: + return False + return address.count(".") == 3 + + +def validate_ip_v6_address(address): + try: + socket.inet_pton(socket.AF_INET6, address) + except socket.error: + return False + return True + + +def validate_prefix(prefix): + if prefix and not 0 <= int(prefix) <= 32: + return False + return True + + +def load_provider(spec, args): + provider = args.get("provider") or {} + for key, value in iteritems(spec): + if key not in provider: + if "fallback" in value: + provider[key] = _fallback(value["fallback"]) + elif "default" in value: + provider[key] = value["default"] + else: + provider[key] = None + if "authorize" in provider: + # Coerce authorize to provider if a string has somehow snuck in. + provider["authorize"] = boolean(provider["authorize"] or False) + args["provider"] = provider + return provider + + +def _fallback(fallback): + strategy = fallback[0] + args = [] + kwargs = {} + + for item in fallback[1:]: + if isinstance(item, dict): + kwargs = item + else: + args = item + try: + return strategy(*args, **kwargs) + except basic.AnsibleFallbackNotFound: + pass + + +def generate_dict(spec): + """ + Generate dictionary which is in sync with argspec + + :param spec: A dictionary that is the argspec of the module + :rtype: A dictionary + :returns: A dictionary in sync with argspec with default value + """ + obj = {} + if not spec: + return obj + + for key, val in iteritems(spec): + if "default" in val: + dct = {key: val["default"]} + elif "type" in val and val["type"] == "dict": + dct = {key: generate_dict(val["options"])} + else: + dct = {key: None} + obj.update(dct) + return obj + + +def parse_conf_arg(cfg, arg): + """ + Parse config based on argument + + :param cfg: A text string which is a line of configuration. + :param arg: A text string which is to be matched. + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r"%s (.+)(\n|$)" % arg, cfg, re.M) + if match: + result = match.group(1).strip() + else: + result = None + return result + + +def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str="no"): + """ + Parse config based on command + + :param cfg: A text string which is a line of configuration. + :param cmd: A text string which is the command to be matched + :param res1: A text string to be returned if the command is present + :param res2: A text string to be returned if the negate command + is present + :param delete_str: A text string to identify the start of the + negate command + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r"\n\s+%s(\n|$)" % cmd, cfg) + if match: + return res1 + if res2 is not None: + match = re.search(r"\n\s+%s %s(\n|$)" % (delete_str, cmd), cfg) + if match: + return res2 + return None + + +def get_xml_conf_arg(cfg, path, data="text"): + """ + :param cfg: The top level configuration lxml Element tree object + :param path: The relative xpath w.r.t to top level element (cfg) + to be searched in the xml hierarchy + :param data: The type of data to be returned for the matched xml node. + Valid values are text, tag, attrib, with default as text. + :return: Returns the required type for the matched xml node or else None + """ + match = cfg.xpath(path) + if len(match): + if data == "tag": + result = getattr(match[0], "tag") + elif data == "attrib": + result = getattr(match[0], "attrib") + else: + result = getattr(match[0], "text") + else: + result = None + return result + + +def remove_empties(cfg_dict): + """ + Generate final config dictionary + + :param cfg_dict: A dictionary parsed in the facts system + :rtype: A dictionary + :returns: A dictionary by eliminating keys that have null values + """ + final_cfg = {} + if not cfg_dict: + return final_cfg + + for key, val in iteritems(cfg_dict): + dct = None + if isinstance(val, dict): + child_val = remove_empties(val) + if child_val: + dct = {key: child_val} + elif ( + isinstance(val, list) + and val + and all([isinstance(x, dict) for x in val]) + ): + child_val = [remove_empties(x) for x in val] + if child_val: + dct = {key: child_val} + elif val not in [None, [], {}, (), ""]: + dct = {key: val} + if dct: + final_cfg.update(dct) + return final_cfg + + +def validate_config(spec, data): + """ + Validate if the input data against the AnsibleModule spec format + :param spec: Ansible argument spec + :param data: Data to be validated + :return: + """ + params = basic._ANSIBLE_ARGS + basic._ANSIBLE_ARGS = to_bytes(json.dumps({"ANSIBLE_MODULE_ARGS": data})) + validated_data = basic.AnsibleModule(spec).params + basic._ANSIBLE_ARGS = params + return validated_data + + +def search_obj_in_list(name, lst, key="name"): + if not lst: + return None + else: + for item in lst: + if item.get(key) == name: + return item + + +class Template: + def __init__(self): + if not HAS_JINJA2: + raise ImportError( + "jinja2 is required but does not appear to be installed. " + "It can be installed using `pip install jinja2`" + ) + + self.env = Environment(undefined=StrictUndefined) + self.env.filters.update({"ternary": ternary}) + + def __call__(self, value, variables=None, fail_on_undefined=True): + variables = variables or {} + + if not self.contains_vars(value): + return value + + try: + value = self.env.from_string(value).render(variables) + except UndefinedError: + if not fail_on_undefined: + return None + raise + + if value: + try: + return ast.literal_eval(value) + except Exception: + return str(value) + else: + return None + + def contains_vars(self, data): + if isinstance(data, string_types): + for marker in ( + self.env.block_start_string, + self.env.variable_start_string, + self.env.comment_start_string, + ): + if marker in data: + return True + return False diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py new file mode 100644 index 00000000..1f03299b --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/netconf/netconf.py @@ -0,0 +1,147 @@ +# +# (c) 2018 Red Hat, Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +import json + +from copy import deepcopy +from contextlib import contextmanager + +try: + from lxml.etree import fromstring, tostring +except ImportError: + from xml.etree.ElementTree import fromstring, tostring + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.connection import Connection, ConnectionError +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.netconf import ( + NetconfConnection, +) + + +IGNORE_XML_ATTRIBUTE = () + + +def get_connection(module): + if hasattr(module, "_netconf_connection"): + return module._netconf_connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api == "netconf": + module._netconf_connection = NetconfConnection(module._socket_path) + else: + module.fail_json(msg="Invalid connection type %s" % network_api) + + return module._netconf_connection + + +def get_capabilities(module): + if hasattr(module, "_netconf_capabilities"): + return module._netconf_capabilities + + capabilities = Connection(module._socket_path).get_capabilities() + module._netconf_capabilities = json.loads(capabilities) + return module._netconf_capabilities + + +def lock_configuration(module, target=None): + conn = get_connection(module) + return conn.lock(target=target) + + +def unlock_configuration(module, target=None): + conn = get_connection(module) + return conn.unlock(target=target) + + +@contextmanager +def locked_config(module, target=None): + try: + lock_configuration(module, target=target) + yield + finally: + unlock_configuration(module, target=target) + + +def get_config(module, source, filter=None, lock=False): + conn = get_connection(module) + try: + locked = False + if lock: + conn.lock(target=source) + locked = True + response = conn.get_config(source=source, filter=filter) + + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + finally: + if locked: + conn.unlock(target=source) + + return response + + +def get(module, filter, lock=False): + conn = get_connection(module) + try: + locked = False + if lock: + conn.lock(target="running") + locked = True + + response = conn.get(filter=filter) + + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + finally: + if locked: + conn.unlock(target="running") + + return response + + +def dispatch(module, request): + conn = get_connection(module) + try: + response = conn.dispatch(request) + except ConnectionError as e: + module.fail_json( + msg=to_text(e, errors="surrogate_then_replace").strip() + ) + + return response + + +def sanitize_xml(data): + tree = fromstring( + to_bytes(deepcopy(data), errors="surrogate_then_replace") + ) + for element in tree.getiterator(): + # remove attributes + attribute = element.attrib + if attribute: + for key in list(attribute): + if key not in IGNORE_XML_ATTRIBUTE: + attribute.pop(key) + return to_text(tostring(tree), errors="surrogate_then_replace").strip() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py new file mode 100644 index 00000000..fba46be0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/module_utils/network/restconf/restconf.py @@ -0,0 +1,61 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2018 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from ansible.module_utils.connection import Connection + + +def get(module, path=None, content=None, fields=None, output="json"): + if path is None: + raise ValueError("path value must be provided") + if content: + path += "?" + "content=%s" % content + if fields: + path += "?" + "field=%s" % fields + + accept = None + if output == "xml": + accept = "application/yang-data+xml" + + connection = Connection(module._socket_path) + return connection.send_request( + None, path=path, method="GET", accept=accept + ) + + +def edit_config(module, path=None, content=None, method="GET", format="json"): + if path is None: + raise ValueError("path value must be provided") + + content_type = None + if format == "xml": + content_type = "application/yang-data+xml" + + connection = Connection(module._socket_path) + return connection.send_request( + content, path=path, method=method, content_type=content_type + ) diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py new file mode 100644 index 00000000..c1384c1d --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/cli_config.py @@ -0,0 +1,444 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: cli_config +author: Trishna Guha (@trishnaguha) +notes: +- The commands will be returned only for platforms that do not support onbox diff. + The C(--diff) option with the playbook will return the difference in configuration + for devices that has support for onbox diff +short_description: Push text based configuration to network devices over network_cli +description: +- This module provides platform agnostic way of pushing text based configuration to + network devices over network_cli connection plugin. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + config: + description: + - The config to be pushed to the network device. This argument is mutually exclusive + with C(rollback) and either one of the option should be given as input. The + config should have indentation that the device uses. + type: str + commit: + description: + - The C(commit) argument instructs the module to push the configuration to the + device. This is mapped to module check mode. + type: bool + replace: + description: + - If the C(replace) argument is set to C(yes), it will replace the entire running-config + of the device with the C(config) argument value. For devices that support replacing + running configuration from file on device like NXOS/JUNOS, the C(replace) argument + takes path to the file on the device that will be used for replacing the entire + running-config. The value of C(config) option should be I(None) for such devices. + Nexus 9K devices only support replace. Use I(net_put) or I(nxos_file_copy) in + case of NXOS module to copy the flat file to remote device and then use set + the fullpath to this argument. + type: str + backup: + description: + - This argument will cause the module to create a full backup of the current running + config from the remote device before any changes are made. If the C(backup_options) + value is not given, the backup file is written to the C(backup) folder in the + playbook root directory or role root directory, if playbook is part of an ansible + role. If the directory does not exist, it is created. + type: bool + default: 'no' + rollback: + description: + - The C(rollback) argument instructs the module to rollback the current configuration + to the identifier specified in the argument. If the specified rollback identifier + does not exist on the remote device, the module will fail. To rollback to the + most recent commit, set the C(rollback) argument to 0. This option is mutually + exclusive with C(config). + commit_comment: + description: + - The C(commit_comment) argument specifies a text string to be used when committing + the configuration. If the C(commit) argument is set to False, this argument + is silently ignored. This argument is only valid for the platforms that support + commit operation with comment. + type: str + defaults: + description: + - The I(defaults) argument will influence how the running-config is collected + from the device. When the value is set to true, the command used to collect + the running-config is append with the all keyword. When the value is set to + false, the command is issued without the all keyword. + default: 'no' + type: bool + multiline_delimiter: + description: + - This argument is used when pushing a multiline configuration element to the + device. It specifies the character to use as the delimiting character. This + only applies to the configuration action. + type: str + diff_replace: + description: + - Instructs the module on the way to perform the configuration on the device. + If the C(diff_replace) argument is set to I(line) then the modified lines are + pushed to the device in configuration mode. If the argument is set to I(block) + then the entire command block is pushed to the device in configuration mode + if any line is not correct. Note that this parameter will be ignored if the + platform has onbox diff support. + choices: + - line + - block + - config + diff_match: + description: + - Instructs the module on the way to perform the matching of the set of commands + against the current device config. If C(diff_match) is set to I(line), commands + are matched line by line. If C(diff_match) is set to I(strict), command lines + are matched with respect to position. If C(diff_match) is set to I(exact), command + lines must be an equal match. Finally, if C(diff_match) is set to I(none), the + module will not attempt to compare the source configuration with the running + configuration on the remote device. Note that this parameter will be ignored + if the platform has onbox diff support. + choices: + - line + - strict + - exact + - none + diff_ignore_lines: + description: + - Use this argument to specify one or more lines that should be ignored during + the diff. This is used for lines in the configuration that are automatically + updated by the system. This argument takes a list of regular expressions or + exact line matches. Note that this parameter will be ignored if the platform + has onbox diff support. + backup_options: + description: + - This is a dict object containing configurable options related to backup file + path. The value of this option is read only when C(backup) is set to I(yes), + if C(backup) is set to I(no) this option will be silently ignored. + suboptions: + filename: + description: + - The filename to be used to store the backup configuration. If the filename + is not given it will be generated based on the hostname, current time and + date in format defined by <hostname>_config.<current-date>@<current-time> + dir_path: + description: + - This option provides the path ending with directory name in which the backup + configuration file will be stored. If the directory does not exist it will + be first created and the filename is either the value of C(filename) or + default filename as described in C(filename) options description. If the + path value is not given in that case a I(backup) directory will be created + in the current working directory and backup configuration will be copied + in C(filename) within I(backup) directory. + type: path + type: dict +""" + +EXAMPLES = """ +- name: configure device with config + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + +- name: multiline config + cli_config: + config: | + hostname foo + feature nxapi + +- name: configure device with config with defaults enabled + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + defaults: yes + +- name: Use diff_match + cli_config: + config: "{{ lookup('file', 'interface_config') }}" + diff_match: none + +- name: nxos replace config + cli_config: + replace: 'bootflash:nxoscfg' + +- name: junos replace config + cli_config: + replace: '/var/home/ansible/junos01.cfg' + +- name: commit with comment + cli_config: + config: set system host-name foo + commit_comment: this is a test + +- name: configurable backup path + cli_config: + config: "{{ lookup('template', 'basic/config.j2') }}" + backup: yes + backup_options: + filename: backup.cfg + dir_path: /home/user +""" + +RETURN = """ +commands: + description: The set of commands that will be pushed to the remote device + returned: always + type: list + sample: ['interface Loopback999', 'no shutdown'] +backup_path: + description: The full path to the backup file + returned: when backup is yes + type: str + sample: /playbooks/ansible/backup/hostname_config.2016-07-16@22:28:34 +""" + +import json + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.connection import Connection +from ansible.module_utils._text import to_text + + +def validate_args(module, device_operations): + """validate param if it is supported on the platform + """ + feature_list = [ + "replace", + "rollback", + "commit_comment", + "defaults", + "multiline_delimiter", + "diff_replace", + "diff_match", + "diff_ignore_lines", + ] + + for feature in feature_list: + if module.params[feature]: + supports_feature = device_operations.get("supports_%s" % feature) + if supports_feature is None: + module.fail_json( + "This platform does not specify whether %s is supported or not. " + "Please report an issue against this platform's cliconf plugin." + % feature + ) + elif not supports_feature: + module.fail_json( + msg="Option %s is not supported on this platform" % feature + ) + + +def run( + module, device_operations, connection, candidate, running, rollback_id +): + result = {} + resp = {} + config_diff = [] + banner_diff = {} + + replace = module.params["replace"] + commit_comment = module.params["commit_comment"] + multiline_delimiter = module.params["multiline_delimiter"] + diff_replace = module.params["diff_replace"] + diff_match = module.params["diff_match"] + diff_ignore_lines = module.params["diff_ignore_lines"] + + commit = not module.check_mode + + if replace in ("yes", "true", "True"): + replace = True + elif replace in ("no", "false", "False"): + replace = False + + if ( + replace is not None + and replace not in [True, False] + and candidate is not None + ): + module.fail_json( + msg="Replace value '%s' is a configuration file path already" + " present on the device. Hence 'replace' and 'config' options" + " are mutually exclusive" % replace + ) + + if rollback_id is not None: + resp = connection.rollback(rollback_id, commit) + if "diff" in resp: + result["changed"] = True + + elif device_operations.get("supports_onbox_diff"): + if diff_replace: + module.warn( + "diff_replace is ignored as the device supports onbox diff" + ) + if diff_match: + module.warn( + "diff_mattch is ignored as the device supports onbox diff" + ) + if diff_ignore_lines: + module.warn( + "diff_ignore_lines is ignored as the device supports onbox diff" + ) + + if candidate and not isinstance(candidate, list): + candidate = candidate.strip("\n").splitlines() + + kwargs = { + "candidate": candidate, + "commit": commit, + "replace": replace, + "comment": commit_comment, + } + resp = connection.edit_config(**kwargs) + + if "diff" in resp: + result["changed"] = True + + elif device_operations.get("supports_generate_diff"): + kwargs = {"candidate": candidate, "running": running} + if diff_match: + kwargs.update({"diff_match": diff_match}) + if diff_replace: + kwargs.update({"diff_replace": diff_replace}) + if diff_ignore_lines: + kwargs.update({"diff_ignore_lines": diff_ignore_lines}) + + diff_response = connection.get_diff(**kwargs) + + config_diff = diff_response.get("config_diff") + banner_diff = diff_response.get("banner_diff") + + if config_diff: + if isinstance(config_diff, list): + candidate = config_diff + else: + candidate = config_diff.splitlines() + + kwargs = { + "candidate": candidate, + "commit": commit, + "replace": replace, + "comment": commit_comment, + } + if commit: + connection.edit_config(**kwargs) + result["changed"] = True + result["commands"] = config_diff.split("\n") + + if banner_diff: + candidate = json.dumps(banner_diff) + + kwargs = {"candidate": candidate, "commit": commit} + if multiline_delimiter: + kwargs.update({"multiline_delimiter": multiline_delimiter}) + if commit: + connection.edit_banner(**kwargs) + result["changed"] = True + + if module._diff: + if "diff" in resp: + result["diff"] = {"prepared": resp["diff"]} + else: + diff = "" + if config_diff: + if isinstance(config_diff, list): + diff += "\n".join(config_diff) + else: + diff += config_diff + if banner_diff: + diff += json.dumps(banner_diff) + result["diff"] = {"prepared": diff} + + return result + + +def main(): + """main entry point for execution + """ + backup_spec = dict(filename=dict(), dir_path=dict(type="path")) + argument_spec = dict( + backup=dict(default=False, type="bool"), + backup_options=dict(type="dict", options=backup_spec), + config=dict(type="str"), + commit=dict(type="bool"), + replace=dict(type="str"), + rollback=dict(type="int"), + commit_comment=dict(type="str"), + defaults=dict(default=False, type="bool"), + multiline_delimiter=dict(type="str"), + diff_replace=dict(choices=["line", "block", "config"]), + diff_match=dict(choices=["line", "strict", "exact", "none"]), + diff_ignore_lines=dict(type="list"), + ) + + mutually_exclusive = [("config", "rollback")] + required_one_of = [["backup", "config", "rollback"]] + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=mutually_exclusive, + required_one_of=required_one_of, + supports_check_mode=True, + ) + + result = {"changed": False} + + connection = Connection(module._socket_path) + capabilities = module.from_json(connection.get_capabilities()) + + if capabilities: + device_operations = capabilities.get("device_operations", dict()) + validate_args(module, device_operations) + else: + device_operations = dict() + + if module.params["defaults"]: + if "get_default_flag" in capabilities.get("rpc"): + flags = connection.get_default_flag() + else: + flags = "all" + else: + flags = [] + + candidate = module.params["config"] + candidate = ( + to_text(candidate, errors="surrogate_then_replace") + if candidate + else None + ) + running = connection.get_config(flags=flags) + rollback_id = module.params["rollback"] + + if module.params["backup"]: + result["__backup__"] = running + + if candidate or rollback_id or module.params["replace"]: + try: + result.update( + run( + module, + device_operations, + connection, + candidate, + running, + rollback_id, + ) + ) + except Exception as exc: + module.fail_json(msg=to_text(exc)) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py new file mode 100644 index 00000000..f0910f52 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_get.py @@ -0,0 +1,71 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: net_get +author: Deepak Agrawal (@dagrawal) +short_description: Copy a file from a network device to Ansible Controller +description: +- This module provides functionality to copy file from network device to ansible controller. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + src: + description: + - Specifies the source file. The path to the source file can either be the full + path on the network device or a relative path as per path supported by destination + network device. + required: true + protocol: + description: + - Protocol used to transfer file. + default: scp + choices: + - scp + - sftp + dest: + description: + - Specifies the destination file. The path to the destination file can either + be the full path on the Ansible control host or a relative path from the playbook + or role root directory. + default: + - Same filename as specified in I(src). The path will be playbook root or role + root directory if playbook is part of a role. +requirements: +- scp +notes: +- Some devices need specific configurations to be enabled before scp can work These + configuration should be pre-configured before using this module e.g ios - C(ip scp + server enable). +- User privilege to do scp on network device should be pre-configured e.g. ios - need + user privilege 15 by default for allowing scp. +- Default destination of source file. +""" + +EXAMPLES = """ +- name: copy file from the network device to Ansible controller + net_get: + src: running_cfg_ios1.txt + +- name: copy file from ios to common location at /tmp + net_get: + src: running_cfg_sw1.txt + dest : /tmp/ios1.txt +""" + +RETURN = """ +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py new file mode 100644 index 00000000..2fc4a98c --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/modules/net_put.py @@ -0,0 +1,82 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2018, Ansible by Red Hat, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: net_put +author: Deepak Agrawal (@dagrawal) +short_description: Copy a file from Ansible Controller to a network device +description: +- This module provides functionality to copy file from Ansible controller to network + devices. +extends_documentation_fragment: +- ansible.netcommon.network_agnostic +options: + src: + description: + - Specifies the source file. The path to the source file can either be the full + path on the Ansible control host or a relative path from the playbook or role + root directory. + required: true + protocol: + description: + - Protocol used to transfer file. + default: scp + choices: + - scp + - sftp + dest: + description: + - Specifies the destination file. The path to destination file can either be the + full path or relative path as supported by network_os. + default: + - Filename from src and at default directory of user shell on network_os. + required: false + mode: + description: + - Set the file transfer mode. If mode is set to I(text) then I(src) file will + go through Jinja2 template engine to replace any vars if present in the src + file. If mode is set to I(binary) then file will be copied as it is to destination + device. + default: binary + choices: + - binary + - text +requirements: +- scp +notes: +- Some devices need specific configurations to be enabled before scp can work These + configuration should be pre-configured before using this module e.g ios - C(ip scp + server enable). +- User privilege to do scp on network device should be pre-configured e.g. ios - need + user privilege 15 by default for allowing scp. +- Default destination of source file. +""" + +EXAMPLES = """ +- name: copy file from ansible controller to a network device + net_put: + src: running_cfg_ios1.txt + +- name: copy file at root dir of flash in slot 3 of sw1(ios) + net_put: + src: running_cfg_sw1.txt + protocol: sftp + dest : flash3:/running_cfg_sw1.txt +""" + +RETURN = """ +""" diff --git a/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py new file mode 100644 index 00000000..e9332f26 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/ansible/netcommon/plugins/netconf/default.py @@ -0,0 +1,70 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """author: Ansible Networking Team +netconf: default +short_description: Use default netconf plugin to run standard netconf commands as + per RFC +description: +- This default plugin provides low level abstraction apis for sending and receiving + netconf commands as per Netconf RFC specification. +options: + ncclient_device_handler: + type: str + default: default + description: + - Specifies the ncclient device handler name for network os that support default + netconf implementation as per Netconf RFC specification. To identify the ncclient + device handler name refer ncclient library documentation. +""" +import json + +from ansible.module_utils._text import to_text +from ansible.plugins.netconf import NetconfBase + + +class Netconf(NetconfBase): + def get_text(self, ele, tag): + try: + return to_text( + ele.find(tag).text, errors="surrogate_then_replace" + ).strip() + except AttributeError: + pass + + def get_device_info(self): + device_info = dict() + device_info["network_os"] = "default" + return device_info + + def get_capabilities(self): + result = dict() + result["rpc"] = self.get_base_rpc() + result["network_api"] = "netconf" + result["device_info"] = self.get_device_info() + result["server_capabilities"] = [c for c in self.m.server_capabilities] + result["client_capabilities"] = [c for c in self.m.client_capabilities] + result["session_id"] = self.m.session_id + result["device_operations"] = self.get_device_operations( + result["server_capabilities"] + ) + return json.dumps(result) diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py new file mode 100644 index 00000000..e5ac2cd1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/action/ios.py @@ -0,0 +1,133 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import sys +import copy + +from ansible_collections.ansible.netcommon.plugins.action.network import ( + ActionModule as ActionNetworkModule, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + load_provider, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + ios_provider_spec, +) +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionNetworkModule): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + module_name = self._task.action.split(".")[-1] + self._config_module = True if module_name == "ios_config" else False + persistent_connection = self._play_context.connection.split(".")[-1] + warnings = [] + + if persistent_connection == "network_cli": + provider = self._task.args.get("provider", {}) + if any(provider.values()): + display.warning( + "provider is unnecessary when using network_cli and will be ignored" + ) + del self._task.args["provider"] + elif self._play_context.connection == "local": + provider = load_provider(ios_provider_spec, self._task.args) + pc = copy.deepcopy(self._play_context) + pc.connection = "ansible.netcommon.network_cli" + pc.network_os = "cisco.ios.ios" + pc.remote_addr = provider["host"] or self._play_context.remote_addr + pc.port = int(provider["port"] or self._play_context.port or 22) + pc.remote_user = ( + provider["username"] or self._play_context.connection_user + ) + pc.password = provider["password"] or self._play_context.password + pc.private_key_file = ( + provider["ssh_keyfile"] or self._play_context.private_key_file + ) + pc.become = provider["authorize"] or False + if pc.become: + pc.become_method = "enable" + pc.become_pass = provider["auth_pass"] + + connection = self._shared_loader_obj.connection_loader.get( + "ansible.netcommon.persistent", + pc, + sys.stdin, + task_uuid=self._task._uuid, + ) + + # TODO: Remove below code after ansible minimal is cut out + if connection is None: + pc.connection = "network_cli" + pc.network_os = "ios" + connection = self._shared_loader_obj.connection_loader.get( + "persistent", pc, sys.stdin, task_uuid=self._task._uuid + ) + + display.vvv( + "using connection plugin %s (was local)" % pc.connection, + pc.remote_addr, + ) + + command_timeout = ( + int(provider["timeout"]) + if provider["timeout"] + else connection.get_option("persistent_command_timeout") + ) + connection.set_options( + direct={"persistent_command_timeout": command_timeout} + ) + + socket_path = connection.run() + display.vvvv("socket_path: %s" % socket_path, pc.remote_addr) + if not socket_path: + return { + "failed": True, + "msg": "unable to open shell. Please see: " + + "https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell", + } + + task_vars["ansible_socket"] = socket_path + warnings.append( + [ + "connection local support for this module is deprecated and will be removed in version 2.14, use connection %s" + % pc.connection + ] + ) + else: + return { + "failed": True, + "msg": "Connection type %s is not valid for this module" + % self._play_context.connection, + } + + result = super(ActionModule, self).run(task_vars=task_vars) + if warnings: + if "warnings" in result: + result["warnings"].extend(warnings) + else: + result["warnings"] = warnings + return result diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py new file mode 100644 index 00000000..8a390034 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/cliconf/ios.py @@ -0,0 +1,465 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """ +--- +author: Ansible Networking Team +cliconf: ios +short_description: Use ios cliconf to run command on Cisco IOS platform +description: + - This ios plugin provides low level abstraction apis for + sending and receiving CLI commands from Cisco IOS network devices. +version_added: "2.4" +""" + +import re +import time +import json + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_text +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import ( + NetworkConfig, + dumps, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible.plugins.cliconf import CliconfBase, enable_mode + + +class Cliconf(CliconfBase): + @enable_mode + def get_config(self, source="running", flags=None, format=None): + if source not in ("running", "startup"): + raise ValueError( + "fetching configuration from %s is not supported" % source + ) + + if format: + raise ValueError( + "'format' value %s is not supported for get_config" % format + ) + + if not flags: + flags = [] + if source == "running": + cmd = "show running-config " + else: + cmd = "show startup-config " + + cmd += " ".join(to_list(flags)) + cmd = cmd.strip() + + return self.send_command(cmd) + + def get_diff( + self, + candidate=None, + running=None, + diff_match="line", + diff_ignore_lines=None, + path=None, + diff_replace="line", + ): + """ + Generate diff between candidate and running configuration. If the + remote host supports onbox diff capabilities ie. supports_onbox_diff in that case + candidate and running configurations are not required to be passed as argument. + In case if onbox diff capability is not supported candidate argument is mandatory + and running argument is optional. + :param candidate: The configuration which is expected to be present on remote host. + :param running: The base configuration which is used to generate diff. + :param diff_match: Instructs how to match the candidate configuration with current device configuration + Valid values are 'line', 'strict', 'exact', 'none'. + 'line' - commands are matched line by line + 'strict' - command lines are matched with respect to position + 'exact' - command lines must be an equal match + 'none' - will not compare the candidate configuration with the running configuration + :param diff_ignore_lines: Use this argument to specify one or more lines that should be + ignored during the diff. This is used for lines in the configuration + that are automatically updated by the system. This argument takes + a list of regular expressions or exact line matches. + :param path: The ordered set of parents that uniquely identify the section or hierarchy + the commands should be checked against. If the parents argument + is omitted, the commands are checked against the set of top + level or global commands. + :param diff_replace: Instructs on the way to perform the configuration on the device. + If the replace argument is set to I(line) then the modified lines are + pushed to the device in configuration mode. If the replace argument is + set to I(block) then the entire command block is pushed to the device in + configuration mode if any line is not correct. + :return: Configuration diff in json format. + { + 'config_diff': '', + 'banner_diff': {} + } + + """ + diff = {} + device_operations = self.get_device_operations() + option_values = self.get_option_values() + + if candidate is None and device_operations["supports_generate_diff"]: + raise ValueError( + "candidate configuration is required to generate diff" + ) + + if diff_match not in option_values["diff_match"]: + raise ValueError( + "'match' value %s in invalid, valid values are %s" + % (diff_match, ", ".join(option_values["diff_match"])) + ) + + if diff_replace not in option_values["diff_replace"]: + raise ValueError( + "'replace' value %s in invalid, valid values are %s" + % (diff_replace, ", ".join(option_values["diff_replace"])) + ) + + # prepare candidate configuration + candidate_obj = NetworkConfig(indent=1) + want_src, want_banners = self._extract_banners(candidate) + candidate_obj.load(want_src) + + if running and diff_match != "none": + # running configuration + have_src, have_banners = self._extract_banners(running) + running_obj = NetworkConfig( + indent=1, contents=have_src, ignore_lines=diff_ignore_lines + ) + configdiffobjs = candidate_obj.difference( + running_obj, path=path, match=diff_match, replace=diff_replace + ) + + else: + configdiffobjs = candidate_obj.items + have_banners = {} + + diff["config_diff"] = ( + dumps(configdiffobjs, "commands") if configdiffobjs else "" + ) + banners = self._diff_banners(want_banners, have_banners) + diff["banner_diff"] = banners if banners else {} + return diff + + @enable_mode + def edit_config( + self, candidate=None, commit=True, replace=None, comment=None + ): + resp = {} + operations = self.get_device_operations() + self.check_edit_config_capability( + operations, candidate, commit, replace, comment + ) + + results = [] + requests = [] + if commit: + self.send_command("configure terminal") + for line in to_list(candidate): + if not isinstance(line, Mapping): + line = {"command": line} + + cmd = line["command"] + if cmd != "end" and cmd[0] != "!": + results.append(self.send_command(**line)) + requests.append(cmd) + + self.send_command("end") + else: + raise ValueError("check mode is not supported") + + resp["request"] = requests + resp["response"] = results + return resp + + def edit_macro( + self, candidate=None, commit=True, replace=None, comment=None + ): + """ + ios_config: + lines: "{{ macro_lines }}" + parents: "macro name {{ macro_name }}" + after: '@' + match: line + replace: block + """ + resp = {} + operations = self.get_device_operations() + self.check_edit_config_capability( + operations, candidate, commit, replace, comment + ) + + results = [] + requests = [] + if commit: + commands = "" + self.send_command("config terminal") + time.sleep(0.1) + # first item: macro command + commands += candidate.pop(0) + "\n" + multiline_delimiter = candidate.pop(-1) + for line in candidate: + commands += " " + line + "\n" + commands += multiline_delimiter + "\n" + obj = {"command": commands, "sendonly": True} + results.append(self.send_command(**obj)) + requests.append(commands) + + time.sleep(0.1) + self.send_command("end", sendonly=True) + time.sleep(0.1) + results.append(self.send_command("\n")) + requests.append("\n") + + resp["request"] = requests + resp["response"] = results + return resp + + def get( + self, + command=None, + prompt=None, + answer=None, + sendonly=False, + output=None, + newline=True, + check_all=False, + ): + if not command: + raise ValueError("must provide value of command to execute") + if output: + raise ValueError( + "'output' value %s is not supported for get" % output + ) + + return self.send_command( + command=command, + prompt=prompt, + answer=answer, + sendonly=sendonly, + newline=newline, + check_all=check_all, + ) + + def get_device_info(self): + device_info = {} + + device_info["network_os"] = "ios" + reply = self.get(command="show version") + data = to_text(reply, errors="surrogate_or_strict").strip() + + match = re.search(r"Version (\S+)", data) + if match: + device_info["network_os_version"] = match.group(1).strip(",") + + model_search_strs = [ + r"^[Cc]isco (.+) \(revision", + r"^[Cc]isco (\S+).+bytes of .*memory", + ] + for item in model_search_strs: + match = re.search(item, data, re.M) + if match: + version = match.group(1).split(" ") + device_info["network_os_model"] = version[0] + break + + match = re.search(r"^(.+) uptime", data, re.M) + if match: + device_info["network_os_hostname"] = match.group(1) + + match = re.search(r'image file is "(.+)"', data) + if match: + device_info["network_os_image"] = match.group(1) + + return device_info + + def get_device_operations(self): + return { + "supports_diff_replace": True, + "supports_commit": False, + "supports_rollback": False, + "supports_defaults": True, + "supports_onbox_diff": False, + "supports_commit_comment": False, + "supports_multiline_delimiter": True, + "supports_diff_match": True, + "supports_diff_ignore_lines": True, + "supports_generate_diff": True, + "supports_replace": False, + } + + def get_option_values(self): + return { + "format": ["text"], + "diff_match": ["line", "strict", "exact", "none"], + "diff_replace": ["line", "block"], + "output": [], + } + + def get_capabilities(self): + result = super(Cliconf, self).get_capabilities() + result["rpc"] += [ + "edit_banner", + "get_diff", + "run_commands", + "get_defaults_flag", + ] + result["device_operations"] = self.get_device_operations() + result.update(self.get_option_values()) + return json.dumps(result) + + def edit_banner( + self, candidate=None, multiline_delimiter="@", commit=True + ): + """ + Edit banner on remote device + :param banners: Banners to be loaded in json format + :param multiline_delimiter: Line delimiter for banner + :param commit: Boolean value that indicates if the device candidate + configuration should be pushed in the running configuration or discarded. + :param diff: Boolean flag to indicate if configuration that is applied on remote host should + generated and returned in response or not + :return: Returns response of executing the configuration command received + from remote host + """ + resp = {} + banners_obj = json.loads(candidate) + results = [] + requests = [] + if commit: + for key, value in iteritems(banners_obj): + key += " %s" % multiline_delimiter + self.send_command("config terminal", sendonly=True) + for cmd in [key, value, multiline_delimiter]: + obj = {"command": cmd, "sendonly": True} + results.append(self.send_command(**obj)) + requests.append(cmd) + + self.send_command("end", sendonly=True) + time.sleep(0.1) + results.append(self.send_command("\n")) + requests.append("\n") + + resp["request"] = requests + resp["response"] = results + + return resp + + def run_commands(self, commands=None, check_rc=True): + if commands is None: + raise ValueError("'commands' value is required") + + responses = list() + for cmd in to_list(commands): + if not isinstance(cmd, Mapping): + cmd = {"command": cmd} + + output = cmd.pop("output", None) + if output: + raise ValueError( + "'output' value %s is not supported for run_commands" + % output + ) + + try: + out = self.send_command(**cmd) + except AnsibleConnectionFailure as e: + if check_rc: + raise + out = getattr(e, "err", to_text(e)) + + responses.append(out) + + return responses + + def get_defaults_flag(self): + """ + The method identifies the filter that should be used to fetch running-configuration + with defaults. + :return: valid default filter + """ + out = self.get("show running-config ?") + out = to_text(out, errors="surrogate_then_replace") + + commands = set() + for line in out.splitlines(): + if line.strip(): + commands.add(line.strip().split()[0]) + + if "all" in commands: + return "all" + else: + return "full" + + def set_cli_prompt_context(self): + """ + Make sure we are in the operational cli mode + :return: None + """ + if self._connection.connected: + out = self._connection.get_prompt() + + if out is None: + raise AnsibleConnectionFailure( + message=u"cli prompt is not identified from the last received" + u" response window: %s" + % self._connection._last_recv_window + ) + + if re.search( + r"config.*\)#", + to_text(out, errors="surrogate_then_replace").strip(), + ): + self._connection.queue_message( + "vvvv", "wrong context, sending end to device" + ) + self._connection.send_command("end") + + def _extract_banners(self, config): + banners = {} + banner_cmds = re.findall(r"^banner (\w+)", config, re.M) + for cmd in banner_cmds: + regex = r"banner %s \^C(.+?)(?=\^C)" % cmd + match = re.search(regex, config, re.S) + if match: + key = "banner %s" % cmd + banners[key] = match.group(1).strip() + + for cmd in banner_cmds: + regex = r"banner %s \^C(.+?)(?=\^C)" % cmd + match = re.search(regex, config, re.S) + if match: + config = config.replace(str(match.group(1)), "") + + config = re.sub(r"banner \w+ \^C\^C", "!! banner removed", config) + return config, banners + + def _diff_banners(self, want, have): + candidate = {} + for key, value in iteritems(want): + if value != have.get(key): + candidate[key] = value + return candidate diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py new file mode 100644 index 00000000..ff22d27c --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/doc_fragments/ios.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Peter Sprygada <psprygada@ansible.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: + provider: + description: + - B(Deprecated) + - 'Starting with Ansible 2.5 we recommend using C(connection: network_cli).' + - For more information please see the L(IOS Platform Options guide, ../network/user_guide/platform_ios.html). + - HORIZONTALLINE + - A dict object containing connection details. + type: dict + suboptions: + host: + description: + - Specifies the DNS host name or address for connecting to the remote device + over the specified transport. The value of host is used as the destination + address for the transport. + type: str + required: true + port: + description: + - Specifies the port to use when building the connection to the remote device. + type: int + default: 22 + username: + description: + - Configures the username to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME) + will be used instead. + type: str + password: + description: + - Specifies the password to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) + will be used instead. + type: str + timeout: + description: + - Specifies the timeout in seconds for communicating with the network device + for either connecting or sending commands. If the timeout is exceeded before + the operation is completed, the module will error. + type: int + default: 10 + ssh_keyfile: + description: + - Specifies the SSH key to use to authenticate the connection to the remote + device. This value is the path to the key used to authenticate the SSH + session. If the value is not specified in the task, the value of environment + variable C(ANSIBLE_NET_SSH_KEYFILE) will be used instead. + type: path + authorize: + description: + - Instructs the module to enter privileged mode on the remote device before + sending any commands. If not specified, the device will attempt to execute + all commands in non-privileged mode. If the value is not specified in the + task, the value of environment variable C(ANSIBLE_NET_AUTHORIZE) will be + used instead. + type: bool + default: false + auth_pass: + description: + - Specifies the password to use if required to enter privileged mode on the + remote device. If I(authorize) is false, then this argument does nothing. + If the value is not specified in the task, the value of environment variable + C(ANSIBLE_NET_AUTH_PASS) will be used instead. + type: str +notes: +- For more information on using Ansible to manage network devices see the :ref:`Ansible + Network Guide <network_guide>` +- For more information on using Ansible to manage Cisco devices see the `Cisco integration + page <https://www.ansible.com/integrations/networks/cisco>`_. +""" diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py new file mode 100644 index 00000000..6818a0ce --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/module_utils/network/ios/ios.py @@ -0,0 +1,197 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import env_fallback +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible.module_utils.connection import Connection, ConnectionError + +_DEVICE_CONFIGS = {} + +ios_provider_spec = { + "host": dict(), + "port": dict(type="int"), + "username": dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])), + "password": dict( + fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]), no_log=True + ), + "ssh_keyfile": dict( + fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path" + ), + "authorize": dict( + fallback=(env_fallback, ["ANSIBLE_NET_AUTHORIZE"]), type="bool" + ), + "auth_pass": dict( + fallback=(env_fallback, ["ANSIBLE_NET_AUTH_PASS"]), no_log=True + ), + "timeout": dict(type="int"), +} +ios_argument_spec = { + "provider": dict( + type="dict", options=ios_provider_spec, removed_in_version=2.14 + ) +} + + +def get_provider_argspec(): + return ios_provider_spec + + +def get_connection(module): + if hasattr(module, "_ios_connection"): + return module._ios_connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api == "cliconf": + module._ios_connection = Connection(module._socket_path) + else: + module.fail_json(msg="Invalid connection type %s" % network_api) + + return module._ios_connection + + +def get_capabilities(module): + if hasattr(module, "_ios_capabilities"): + return module._ios_capabilities + try: + capabilities = Connection(module._socket_path).get_capabilities() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + module._ios_capabilities = json.loads(capabilities) + return module._ios_capabilities + + +def get_defaults_flag(module): + connection = get_connection(module) + try: + out = connection.get_defaults_flag() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + return to_text(out, errors="surrogate_then_replace").strip() + + +def get_config(module, flags=None): + flags = to_list(flags) + + section_filter = False + if flags and "section" in flags[-1]: + section_filter = True + + flag_str = " ".join(flags) + + try: + return _DEVICE_CONFIGS[flag_str] + except KeyError: + connection = get_connection(module) + try: + out = connection.get_config(flags=flags) + except ConnectionError as exc: + if section_filter: + # Some ios devices don't understand `| section foo` + out = get_config(module, flags=flags[:-1]) + else: + module.fail_json( + msg=to_text(exc, errors="surrogate_then_replace") + ) + cfg = to_text(out, errors="surrogate_then_replace").strip() + _DEVICE_CONFIGS[flag_str] = cfg + return cfg + + +def run_commands(module, commands, check_rc=True): + connection = get_connection(module) + try: + return connection.run_commands(commands=commands, check_rc=check_rc) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc)) + + +def load_config(module, commands): + connection = get_connection(module) + + try: + resp = connection.edit_config(commands) + return resp.get("response") + except ConnectionError as exc: + module.fail_json(msg=to_text(exc)) + + +def normalize_interface(name): + """Return the normalized interface name + """ + if not name: + return + + def _get_number(name): + digits = "" + for char in name: + if char.isdigit() or char in "/.": + digits += char + return digits + + if name.lower().startswith("gi"): + if_type = "GigabitEthernet" + elif name.lower().startswith("te"): + if_type = "TenGigabitEthernet" + elif name.lower().startswith("fa"): + if_type = "FastEthernet" + elif name.lower().startswith("fo"): + if_type = "FortyGigabitEthernet" + elif name.lower().startswith("et"): + if_type = "Ethernet" + elif name.lower().startswith("vl"): + if_type = "Vlan" + elif name.lower().startswith("lo"): + if_type = "loopback" + elif name.lower().startswith("po"): + if_type = "port-channel" + elif name.lower().startswith("nv"): + if_type = "nve" + elif name.lower().startswith("twe"): + if_type = "TwentyFiveGigE" + elif name.lower().startswith("hu"): + if_type = "HundredGigE" + else: + if_type = None + + number_list = name.split(" ") + if len(number_list) == 2: + if_number = number_list[-1].strip() + else: + if_number = _get_number(name) + + if if_type: + proper_interface = if_type + if_number + else: + proper_interface = name + + return proper_interface diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py new file mode 100644 index 00000000..ef383fcc --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_command.py @@ -0,0 +1,229 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: ios_command +author: Peter Sprygada (@privateip) +short_description: Run commands on remote devices running Cisco IOS +description: +- Sends arbitrary commands to an ios node and returns the results read from the device. + This module includes an argument that will cause the module to wait for a specific + condition before returning or timing out if the condition is not met. +- This module does not support running commands in configuration mode. Please use + M(ios_config) to configure IOS devices. +extends_documentation_fragment: +- cisco.ios.ios +notes: +- Tested against IOS 15.6 +options: + commands: + description: + - List of commands to send to the remote ios device over the configured provider. + The resulting output from the command is returned. If the I(wait_for) argument + is provided, the module is not returned until the condition is satisfied or + the number of retries has expired. If a command sent to the device requires + answering a prompt, it is possible to pass a dict containing I(command), I(answer) + and I(prompt). Common answers are 'y' or "\r" (carriage return, must be double + quotes). See examples. + required: true + wait_for: + description: + - List of conditions to evaluate against the output of the command. The task will + wait for each condition to be true before moving forward. If the conditional + is not true within the configured number of retries, the task fails. See examples. + aliases: + - waitfor + match: + description: + - The I(match) argument is used in conjunction with the I(wait_for) argument to + specify the match policy. Valid values are C(all) or C(any). If the value + is set to C(all) then all conditionals in the wait_for must be satisfied. If + the value is set to C(any) then only one of the values must be satisfied. + default: all + choices: + - any + - all + retries: + description: + - Specifies the number of retries a command should by tried before it is considered + failed. The command is run on the target device every retry and evaluated against + the I(wait_for) conditions. + default: 10 + interval: + description: + - Configures the interval in seconds to wait between retries of the command. If + the command does not pass the specified conditions, the interval indicates how + long to wait before trying the command again. + default: 1 +""" + +EXAMPLES = r""" +tasks: + - name: run show version on remote devices + ios_command: + commands: show version + + - name: run show version and check to see if output contains IOS + ios_command: + commands: show version + wait_for: result[0] contains IOS + + - name: run multiple commands on remote nodes + ios_command: + commands: + - show version + - show interfaces + + - name: run multiple commands and evaluate the output + ios_command: + commands: + - show version + - show interfaces + wait_for: + - result[0] contains IOS + - result[1] contains Loopback0 + + - name: run commands that require answering a prompt + ios_command: + commands: + - command: 'clear counters GigabitEthernet0/1' + prompt: 'Clear "show interface" counters on this interface \[confirm\]' + answer: 'y' + - command: 'clear counters GigabitEthernet0/2' + prompt: '[confirm]' + answer: "\r" +""" + +RETURN = """ +stdout: + description: The set of responses from the commands + returned: always apart from low level errors (such as action plugin) + type: list + sample: ['...', '...'] +stdout_lines: + description: The value of stdout split into a list + returned: always apart from low level errors (such as action plugin) + type: list + sample: [['...', '...'], ['...'], ['...']] +failed_conditions: + description: The list of conditionals that have failed + returned: failed + type: list + sample: ['...', '...'] +""" +import time + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import ( + Conditional, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + transform_commands, + to_lines, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + run_commands, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + ios_argument_spec, +) + + +def parse_commands(module, warnings): + commands = transform_commands(module) + + if module.check_mode: + for item in list(commands): + if not item["command"].startswith("show"): + warnings.append( + "Only show commands are supported when using check mode, not " + "executing %s" % item["command"] + ) + commands.remove(item) + + return commands + + +def main(): + """main entry point for module execution + """ + argument_spec = dict( + commands=dict(type="list", required=True), + wait_for=dict(type="list", aliases=["waitfor"]), + match=dict(default="all", choices=["all", "any"]), + retries=dict(default=10, type="int"), + interval=dict(default=1, type="int"), + ) + + argument_spec.update(ios_argument_spec) + + module = AnsibleModule( + argument_spec=argument_spec, supports_check_mode=True + ) + + warnings = list() + result = {"changed": False, "warnings": warnings} + commands = parse_commands(module, warnings) + wait_for = module.params["wait_for"] or list() + + try: + conditionals = [Conditional(c) for c in wait_for] + except AttributeError as exc: + module.fail_json(msg=to_text(exc)) + + retries = module.params["retries"] + interval = module.params["interval"] + match = module.params["match"] + + while retries > 0: + responses = run_commands(module, commands) + + for item in list(conditionals): + if item(responses): + if match == "any": + conditionals = list() + break + conditionals.remove(item) + + if not conditionals: + break + + time.sleep(interval) + retries -= 1 + + if conditionals: + failed_conditions = [item.raw for item in conditionals] + msg = "One or more conditional statements have not been satisfied" + module.fail_json(msg=msg, failed_conditions=failed_conditions) + + result.update( + {"stdout": responses, "stdout_lines": list(to_lines(responses))} + ) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py new file mode 100644 index 00000000..beec5b8d --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/modules/ios_config.py @@ -0,0 +1,596 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: ios_config +author: Peter Sprygada (@privateip) +short_description: Manage Cisco IOS configuration sections +description: +- Cisco IOS configurations use a simple block indent file syntax for segmenting configuration + into sections. This module provides an implementation for working with IOS configuration + sections in a deterministic way. +extends_documentation_fragment: +- cisco.ios.ios +notes: +- Tested against IOS 15.6 +- Abbreviated commands are NOT idempotent, see L(Network FAQ,../network/user_guide/faq.html#why-do-the-config-modules-always-return-changed-true-with-abbreviated-commands). +options: + lines: + description: + - The ordered set of commands that should be configured in the section. The commands + must be the exact same commands as found in the device running-config. Be sure + to note the configuration command syntax as some commands are automatically + modified by the device config parser. + aliases: + - commands + parents: + description: + - The ordered set of parents that uniquely identify the section or hierarchy the + commands should be checked against. If the parents argument is omitted, the + commands are checked against the set of top level or global commands. + src: + description: + - Specifies the source path to the file that contains the configuration or configuration + template to load. The path to the source file can either be the full path on + the Ansible control host or a relative path from the playbook or role root directory. This + argument is mutually exclusive with I(lines), I(parents). + before: + description: + - The ordered set of commands to push on to the command stack if a change needs + to be made. This allows the playbook designer the opportunity to perform configuration + commands prior to pushing any changes without affecting how the set of commands + are matched against the system. + after: + description: + - The ordered set of commands to append to the end of the command stack if a change + needs to be made. Just like with I(before) this allows the playbook designer + to append a set of commands to be executed after the command set. + match: + description: + - Instructs the module on the way to perform the matching of the set of commands + against the current device config. If match is set to I(line), commands are + matched line by line. If match is set to I(strict), command lines are matched + with respect to position. If match is set to I(exact), command lines must be + an equal match. Finally, if match is set to I(none), the module will not attempt + to compare the source configuration with the running configuration on the remote + device. + choices: + - line + - strict + - exact + - none + default: line + replace: + description: + - Instructs the module on the way to perform the configuration on the device. + If the replace argument is set to I(line) then the modified lines are pushed + to the device in configuration mode. If the replace argument is set to I(block) + then the entire command block is pushed to the device in configuration mode + if any line is not correct. + default: line + choices: + - line + - block + multiline_delimiter: + description: + - This argument is used when pushing a multiline configuration element to the + IOS device. It specifies the character to use as the delimiting character. This + only applies to the configuration action. + default: '@' + backup: + description: + - This argument will cause the module to create a full backup of the current C(running-config) + from the remote device before any changes are made. If the C(backup_options) + value is not given, the backup file is written to the C(backup) folder in the + playbook root directory or role root directory, if playbook is part of an ansible + role. If the directory does not exist, it is created. + type: bool + default: 'no' + running_config: + description: + - The module, by default, will connect to the remote device and retrieve the current + running-config to use as a base for comparing against the contents of source. + There are times when it is not desirable to have the task get the current running-config + for every task in a playbook. The I(running_config) argument allows the implementer + to pass in the configuration to use as the base config for comparison. + aliases: + - config + defaults: + description: + - This argument specifies whether or not to collect all defaults when getting + the remote device running config. When enabled, the module will get the current + config by issuing the command C(show running-config all). + type: bool + default: 'no' + save_when: + description: + - When changes are made to the device running-configuration, the changes are not + copied to non-volatile storage by default. Using this argument will change + that before. If the argument is set to I(always), then the running-config will + always be copied to the startup-config and the I(modified) flag will always + be set to True. If the argument is set to I(modified), then the running-config + will only be copied to the startup-config if it has changed since the last save + to startup-config. If the argument is set to I(never), the running-config will + never be copied to the startup-config. If the argument is set to I(changed), + then the running-config will only be copied to the startup-config if the task + has made a change. I(changed) was added in Ansible 2.5. + default: never + choices: + - always + - never + - modified + - changed + diff_against: + description: + - When using the C(ansible-playbook --diff) command line argument the module can + generate diffs against different sources. + - When this option is configure as I(startup), the module will return the diff + of the running-config against the startup-config. + - When this option is configured as I(intended), the module will return the diff + of the running-config against the configuration provided in the C(intended_config) + argument. + - When this option is configured as I(running), the module will return the before + and after diff of the running-config with respect to any changes made to the + device configuration. + choices: + - running + - startup + - intended + diff_ignore_lines: + description: + - Use this argument to specify one or more lines that should be ignored during + the diff. This is used for lines in the configuration that are automatically + updated by the system. This argument takes a list of regular expressions or + exact line matches. + intended_config: + description: + - The C(intended_config) provides the master configuration that the node should + conform to and is used to check the final running-config against. This argument + will not modify any settings on the remote device and is strictly used to check + the compliance of the current device's configuration against. When specifying + this argument, the task should also modify the C(diff_against) value and set + it to I(intended). + backup_options: + description: + - This is a dict object containing configurable options related to backup file + path. The value of this option is read only when C(backup) is set to I(yes), + if C(backup) is set to I(no) this option will be silently ignored. + suboptions: + filename: + description: + - The filename to be used to store the backup configuration. If the filename + is not given it will be generated based on the hostname, current time and + date in format defined by <hostname>_config.<current-date>@<current-time> + dir_path: + description: + - This option provides the path ending with directory name in which the backup + configuration file will be stored. If the directory does not exist it will + be first created and the filename is either the value of C(filename) or + default filename as described in C(filename) options description. If the + path value is not given in that case a I(backup) directory will be created + in the current working directory and backup configuration will be copied + in C(filename) within I(backup) directory. + type: path + type: dict +""" + +EXAMPLES = """ +- name: configure top level configuration + ios_config: + lines: hostname {{ inventory_hostname }} + +- name: configure interface settings + ios_config: + lines: + - description test interface + - ip address 172.31.1.1 255.255.255.0 + parents: interface Ethernet1 + +- name: configure ip helpers on multiple interfaces + ios_config: + lines: + - ip helper-address 172.26.1.10 + - ip helper-address 172.26.3.8 + parents: "{{ item }}" + with_items: + - interface Ethernet1 + - interface Ethernet2 + - interface GigabitEthernet1 + +- name: configure policer in Scavenger class + ios_config: + lines: + - conform-action transmit + - exceed-action drop + parents: + - policy-map Foo + - class Scavenger + - police cir 64000 + +- name: load new acl into device + ios_config: + lines: + - 10 permit ip host 192.0.2.1 any log + - 20 permit ip host 192.0.2.2 any log + - 30 permit ip host 192.0.2.3 any log + - 40 permit ip host 192.0.2.4 any log + - 50 permit ip host 192.0.2.5 any log + parents: ip access-list extended test + before: no ip access-list extended test + match: exact + +- name: check the running-config against master config + ios_config: + diff_against: intended + intended_config: "{{ lookup('file', 'master.cfg') }}" + +- name: check the startup-config against the running-config + ios_config: + diff_against: startup + diff_ignore_lines: + - ntp clock .* + +- name: save running to startup when modified + ios_config: + save_when: modified + +- name: for idempotency, use full-form commands + ios_config: + lines: + # - shut + - shutdown + # parents: int gig1/0/11 + parents: interface GigabitEthernet1/0/11 + +# Set boot image based on comparison to a group_var (version) and the version +# that is returned from the `ios_facts` module +- name: SETTING BOOT IMAGE + ios_config: + lines: + - no boot system + - boot system flash bootflash:{{new_image}} + host: "{{ inventory_hostname }}" + when: ansible_net_version != version + +- name: render a Jinja2 template onto an IOS device + ios_config: + backup: yes + src: ios_template.j2 + +- name: configurable backup path + ios_config: + src: ios_template.j2 + backup: yes + backup_options: + filename: backup.cfg + dir_path: /home/user +""" + +RETURN = """ +updates: + description: The set of commands that will be pushed to the remote device + returned: always + type: list + sample: ['hostname foo', 'router ospf 1', 'router-id 192.0.2.1'] +commands: + description: The set of commands that will be pushed to the remote device + returned: always + type: list + sample: ['hostname foo', 'router ospf 1', 'router-id 192.0.2.1'] +backup_path: + description: The full path to the backup file + returned: when backup is yes + type: str + sample: /playbooks/ansible/backup/ios_config.2016-07-16@22:28:34 +filename: + description: The name of the backup file + returned: when backup is yes and filename is not specified in backup options + type: str + sample: ios_config.2016-07-16@22:28:34 +shortname: + description: The full path to the backup file excluding the timestamp + returned: when backup is yes and filename is not specified in backup options + type: str + sample: /playbooks/ansible/backup/ios_config +date: + description: The date extracted from the backup file name + returned: when backup is yes + type: str + sample: "2016-07-16" +time: + description: The time extracted from the backup file name + returned: when backup is yes + type: str + sample: "22:28:34" +""" +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.connection import ConnectionError +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + run_commands, + get_config, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + get_defaults_flag, + get_connection, +) +from ansible_collections.cisco.ios.plugins.module_utils.network.ios.ios import ( + ios_argument_spec, +) +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import ( + NetworkConfig, + dumps, +) + + +def check_args(module, warnings): + if module.params["multiline_delimiter"]: + if len(module.params["multiline_delimiter"]) != 1: + module.fail_json( + msg="multiline_delimiter value can only be a " + "single character" + ) + + +def edit_config_or_macro(connection, commands): + # only catch the macro configuration command, + # not negated 'no' variation. + if commands[0].startswith("macro name"): + connection.edit_macro(candidate=commands) + else: + connection.edit_config(candidate=commands) + + +def get_candidate_config(module): + candidate = "" + if module.params["src"]: + candidate = module.params["src"] + + elif module.params["lines"]: + candidate_obj = NetworkConfig(indent=1) + parents = module.params["parents"] or list() + candidate_obj.add(module.params["lines"], parents=parents) + candidate = dumps(candidate_obj, "raw") + + return candidate + + +def get_running_config(module, current_config=None, flags=None): + running = module.params["running_config"] + if not running: + if not module.params["defaults"] and current_config: + running = current_config + else: + running = get_config(module, flags=flags) + + return running + + +def save_config(module, result): + result["changed"] = True + if not module.check_mode: + run_commands(module, "copy running-config startup-config\r") + else: + module.warn( + "Skipping command `copy running-config startup-config` " + "due to check_mode. Configuration not copied to " + "non-volatile storage" + ) + + +def main(): + """ main entry point for module execution + """ + backup_spec = dict(filename=dict(), dir_path=dict(type="path")) + argument_spec = dict( + src=dict(type="path"), + lines=dict(aliases=["commands"], type="list"), + parents=dict(type="list"), + before=dict(type="list"), + after=dict(type="list"), + match=dict( + default="line", choices=["line", "strict", "exact", "none"] + ), + replace=dict(default="line", choices=["line", "block"]), + multiline_delimiter=dict(default="@"), + running_config=dict(aliases=["config"]), + intended_config=dict(), + defaults=dict(type="bool", default=False), + backup=dict(type="bool", default=False), + backup_options=dict(type="dict", options=backup_spec), + save_when=dict( + choices=["always", "never", "modified", "changed"], default="never" + ), + diff_against=dict(choices=["startup", "intended", "running"]), + diff_ignore_lines=dict(type="list"), + ) + + argument_spec.update(ios_argument_spec) + + mutually_exclusive = [("lines", "src"), ("parents", "src")] + + required_if = [ + ("match", "strict", ["lines"]), + ("match", "exact", ["lines"]), + ("replace", "block", ["lines"]), + ("diff_against", "intended", ["intended_config"]), + ] + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=mutually_exclusive, + required_if=required_if, + supports_check_mode=True, + ) + + result = {"changed": False} + + warnings = list() + check_args(module, warnings) + result["warnings"] = warnings + + diff_ignore_lines = module.params["diff_ignore_lines"] + config = None + contents = None + flags = get_defaults_flag(module) if module.params["defaults"] else [] + connection = get_connection(module) + + if module.params["backup"] or ( + module._diff and module.params["diff_against"] == "running" + ): + contents = get_config(module, flags=flags) + config = NetworkConfig(indent=1, contents=contents) + if module.params["backup"]: + result["__backup__"] = contents + + if any((module.params["lines"], module.params["src"])): + match = module.params["match"] + replace = module.params["replace"] + path = module.params["parents"] + + candidate = get_candidate_config(module) + running = get_running_config(module, contents, flags=flags) + try: + response = connection.get_diff( + candidate=candidate, + running=running, + diff_match=match, + diff_ignore_lines=diff_ignore_lines, + path=path, + diff_replace=replace, + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + config_diff = response["config_diff"] + banner_diff = response["banner_diff"] + + if config_diff or banner_diff: + commands = config_diff.split("\n") + + if module.params["before"]: + commands[:0] = module.params["before"] + + if module.params["after"]: + commands.extend(module.params["after"]) + + result["commands"] = commands + result["updates"] = commands + result["banners"] = banner_diff + + # send the configuration commands to the device and merge + # them with the current running config + if not module.check_mode: + if commands: + edit_config_or_macro(connection, commands) + if banner_diff: + connection.edit_banner( + candidate=json.dumps(banner_diff), + multiline_delimiter=module.params[ + "multiline_delimiter" + ], + ) + + result["changed"] = True + + running_config = module.params["running_config"] + startup_config = None + + if module.params["save_when"] == "always": + save_config(module, result) + elif module.params["save_when"] == "modified": + output = run_commands( + module, ["show running-config", "show startup-config"] + ) + + running_config = NetworkConfig( + indent=1, contents=output[0], ignore_lines=diff_ignore_lines + ) + startup_config = NetworkConfig( + indent=1, contents=output[1], ignore_lines=diff_ignore_lines + ) + + if running_config.sha1 != startup_config.sha1: + save_config(module, result) + elif module.params["save_when"] == "changed" and result["changed"]: + save_config(module, result) + + if module._diff: + if not running_config: + output = run_commands(module, "show running-config") + contents = output[0] + else: + contents = running_config + + # recreate the object in order to process diff_ignore_lines + running_config = NetworkConfig( + indent=1, contents=contents, ignore_lines=diff_ignore_lines + ) + + if module.params["diff_against"] == "running": + if module.check_mode: + module.warn( + "unable to perform diff against running-config due to check mode" + ) + contents = None + else: + contents = config.config_text + + elif module.params["diff_against"] == "startup": + if not startup_config: + output = run_commands(module, "show startup-config") + contents = output[0] + else: + contents = startup_config.config_text + + elif module.params["diff_against"] == "intended": + contents = module.params["intended_config"] + + if contents is not None: + base_config = NetworkConfig( + indent=1, contents=contents, ignore_lines=diff_ignore_lines + ) + + if running_config.sha1 != base_config.sha1: + if module.params["diff_against"] == "intended": + before = running_config + after = base_config + elif module.params["diff_against"] in ("startup", "running"): + before = base_config + after = running_config + + result.update( + { + "changed": True, + "diff": {"before": str(before), "after": str(after)}, + } + ) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py new file mode 100644 index 00000000..29f31b0e --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/cisco/ios/plugins/terminal/ios.py @@ -0,0 +1,115 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import json +import re + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_text, to_bytes +from ansible.plugins.terminal import TerminalBase +from ansible.utils.display import Display + +display = Display() + + +class TerminalModule(TerminalBase): + + terminal_stdout_re = [ + re.compile(br"[\r\n]?[\w\+\-\.:\/\[\]]+(?:\([^\)]+\)){0,3}(?:[>#]) ?$") + ] + + terminal_stderr_re = [ + re.compile(br"% ?Error"), + # re.compile(br"^% \w+", re.M), + re.compile(br"% ?Bad secret"), + re.compile(br"[\r\n%] Bad passwords"), + re.compile(br"invalid input", re.I), + re.compile(br"(?:incomplete|ambiguous) command", re.I), + re.compile(br"connection timed out", re.I), + re.compile(br"[^\r\n]+ not found"), + re.compile(br"'[^']' +returned error code: ?\d+"), + re.compile(br"Bad mask", re.I), + re.compile(br"% ?(\S+) ?overlaps with ?(\S+)", re.I), + re.compile(br"[%\S] ?Error: ?[\s]+", re.I), + re.compile(br"[%\S] ?Informational: ?[\s]+", re.I), + re.compile(br"Command authorization failed"), + ] + + def on_open_shell(self): + try: + self._exec_cli_command(b"terminal length 0") + except AnsibleConnectionFailure: + raise AnsibleConnectionFailure("unable to set terminal parameters") + + try: + self._exec_cli_command(b"terminal width 512") + try: + self._exec_cli_command(b"terminal width 0") + except AnsibleConnectionFailure: + pass + except AnsibleConnectionFailure: + display.display( + "WARNING: Unable to set terminal width, command responses may be truncated" + ) + + def on_become(self, passwd=None): + if self._get_prompt().endswith(b"#"): + return + + cmd = {u"command": u"enable"} + if passwd: + # Note: python-3.5 cannot combine u"" and r"" together. Thus make + # an r string and use to_text to ensure it's text on both py2 and py3. + cmd[u"prompt"] = to_text( + r"[\r\n]?(?:.*)?[Pp]assword: ?$", errors="surrogate_or_strict" + ) + cmd[u"answer"] = passwd + cmd[u"prompt_retry_check"] = True + try: + self._exec_cli_command( + to_bytes(json.dumps(cmd), errors="surrogate_or_strict") + ) + prompt = self._get_prompt() + if prompt is None or not prompt.endswith(b"#"): + raise AnsibleConnectionFailure( + "failed to elevate privilege to enable mode still at prompt [%s]" + % prompt + ) + except AnsibleConnectionFailure as e: + prompt = self._get_prompt() + raise AnsibleConnectionFailure( + "unable to elevate privilege to enable mode, at prompt [%s] with error: %s" + % (prompt, e.message) + ) + + def on_unbecome(self): + prompt = self._get_prompt() + if prompt is None: + # if prompt is None most likely the terminal is hung up at a prompt + return + + if b"(config" in prompt: + self._exec_cli_command(b"end") + self._exec_cli_command(b"disable") + + elif prompt.endswith(b"#"): + self._exec_cli_command(b"disable") diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py new file mode 100644 index 00000000..cab2f3fd --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/action/vyos.py @@ -0,0 +1,129 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import sys +import copy + +from ansible_collections.ansible.netcommon.plugins.action.network import ( + ActionModule as ActionNetworkModule, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + load_provider, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_provider_spec, +) +from ansible.utils.display import Display + +display = Display() + + +class ActionModule(ActionNetworkModule): + def run(self, tmp=None, task_vars=None): + del tmp # tmp no longer has any effect + + module_name = self._task.action.split(".")[-1] + self._config_module = True if module_name == "vyos_config" else False + persistent_connection = self._play_context.connection.split(".")[-1] + warnings = [] + + if persistent_connection == "network_cli": + provider = self._task.args.get("provider", {}) + if any(provider.values()): + display.warning( + "provider is unnecessary when using network_cli and will be ignored" + ) + del self._task.args["provider"] + elif self._play_context.connection == "local": + provider = load_provider(vyos_provider_spec, self._task.args) + pc = copy.deepcopy(self._play_context) + pc.connection = "ansible.netcommon.network_cli" + pc.network_os = "vyos.vyos.vyos" + pc.remote_addr = provider["host"] or self._play_context.remote_addr + pc.port = int(provider["port"] or self._play_context.port or 22) + pc.remote_user = ( + provider["username"] or self._play_context.connection_user + ) + pc.password = provider["password"] or self._play_context.password + pc.private_key_file = ( + provider["ssh_keyfile"] or self._play_context.private_key_file + ) + + connection = self._shared_loader_obj.connection_loader.get( + "ansible.netcommon.persistent", + pc, + sys.stdin, + task_uuid=self._task._uuid, + ) + + # TODO: Remove below code after ansible minimal is cut out + if connection is None: + pc.connection = "network_cli" + pc.network_os = "vyos" + connection = self._shared_loader_obj.connection_loader.get( + "persistent", pc, sys.stdin, task_uuid=self._task._uuid + ) + + display.vvv( + "using connection plugin %s (was local)" % pc.connection, + pc.remote_addr, + ) + + command_timeout = ( + int(provider["timeout"]) + if provider["timeout"] + else connection.get_option("persistent_command_timeout") + ) + connection.set_options( + direct={"persistent_command_timeout": command_timeout} + ) + + socket_path = connection.run() + display.vvvv("socket_path: %s" % socket_path, pc.remote_addr) + if not socket_path: + return { + "failed": True, + "msg": "unable to open shell. Please see: " + + "https://docs.ansible.com/ansible/network_debug_troubleshooting.html#unable-to-open-shell", + } + + task_vars["ansible_socket"] = socket_path + warnings.append( + [ + "connection local support for this module is deprecated and will be removed in version 2.14, use connection %s" + % pc.connection + ] + ) + else: + return { + "failed": True, + "msg": "Connection type %s is not valid for this module" + % self._play_context.connection, + } + + result = super(ActionModule, self).run(task_vars=task_vars) + if warnings: + if "warnings" in result: + result["warnings"].extend(warnings) + else: + result["warnings"] = warnings + return result diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py new file mode 100644 index 00000000..30336031 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/cliconf/vyos.py @@ -0,0 +1,342 @@ +# +# (c) 2017 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +DOCUMENTATION = """ +--- +author: Ansible Networking Team +cliconf: vyos +short_description: Use vyos cliconf to run command on VyOS platform +description: + - This vyos plugin provides low level abstraction apis for + sending and receiving CLI commands from VyOS network devices. +version_added: "2.4" +""" + +import re +import json + +from ansible.errors import AnsibleConnectionFailure +from ansible.module_utils._text import to_text +from ansible.module_utils.common._collections_compat import Mapping +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.config import ( + NetworkConfig, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, +) +from ansible.plugins.cliconf import CliconfBase + + +class Cliconf(CliconfBase): + def get_device_info(self): + device_info = {} + + device_info["network_os"] = "vyos" + reply = self.get("show version") + data = to_text(reply, errors="surrogate_or_strict").strip() + + match = re.search(r"Version:\s*(.*)", data) + if match: + device_info["network_os_version"] = match.group(1) + + match = re.search(r"HW model:\s*(\S+)", data) + if match: + device_info["network_os_model"] = match.group(1) + + reply = self.get("show host name") + device_info["network_os_hostname"] = to_text( + reply, errors="surrogate_or_strict" + ).strip() + + return device_info + + def get_config(self, flags=None, format=None): + if format: + option_values = self.get_option_values() + if format not in option_values["format"]: + raise ValueError( + "'format' value %s is invalid. Valid values of format are %s" + % (format, ", ".join(option_values["format"])) + ) + + if not flags: + flags = [] + + if format == "text": + command = "show configuration" + else: + command = "show configuration commands" + + command += " ".join(to_list(flags)) + command = command.strip() + + out = self.send_command(command) + return out + + def edit_config( + self, candidate=None, commit=True, replace=None, comment=None + ): + resp = {} + operations = self.get_device_operations() + self.check_edit_config_capability( + operations, candidate, commit, replace, comment + ) + + results = [] + requests = [] + self.send_command("configure") + for cmd in to_list(candidate): + if not isinstance(cmd, Mapping): + cmd = {"command": cmd} + + results.append(self.send_command(**cmd)) + requests.append(cmd["command"]) + out = self.get("compare") + out = to_text(out, errors="surrogate_or_strict") + diff_config = out if not out.startswith("No changes") else None + + if diff_config: + if commit: + try: + self.commit(comment) + except AnsibleConnectionFailure as e: + msg = "commit failed: %s" % e.message + self.discard_changes() + raise AnsibleConnectionFailure(msg) + else: + self.send_command("exit") + else: + self.discard_changes() + else: + self.send_command("exit") + if ( + to_text( + self._connection.get_prompt(), errors="surrogate_or_strict" + ) + .strip() + .endswith("#") + ): + self.discard_changes() + + if diff_config: + resp["diff"] = diff_config + resp["response"] = results + resp["request"] = requests + return resp + + def get( + self, + command=None, + prompt=None, + answer=None, + sendonly=False, + output=None, + newline=True, + check_all=False, + ): + if not command: + raise ValueError("must provide value of command to execute") + if output: + raise ValueError( + "'output' value %s is not supported for get" % output + ) + + return self.send_command( + command=command, + prompt=prompt, + answer=answer, + sendonly=sendonly, + newline=newline, + check_all=check_all, + ) + + def commit(self, comment=None): + if comment: + command = 'commit comment "{0}"'.format(comment) + else: + command = "commit" + self.send_command(command) + + def discard_changes(self): + self.send_command("exit discard") + + def get_diff( + self, + candidate=None, + running=None, + diff_match="line", + diff_ignore_lines=None, + path=None, + diff_replace=None, + ): + diff = {} + device_operations = self.get_device_operations() + option_values = self.get_option_values() + + if candidate is None and device_operations["supports_generate_diff"]: + raise ValueError( + "candidate configuration is required to generate diff" + ) + + if diff_match not in option_values["diff_match"]: + raise ValueError( + "'match' value %s in invalid, valid values are %s" + % (diff_match, ", ".join(option_values["diff_match"])) + ) + + if diff_replace: + raise ValueError("'replace' in diff is not supported") + + if diff_ignore_lines: + raise ValueError("'diff_ignore_lines' in diff is not supported") + + if path: + raise ValueError("'path' in diff is not supported") + + set_format = candidate.startswith("set") or candidate.startswith( + "delete" + ) + candidate_obj = NetworkConfig(indent=4, contents=candidate) + if not set_format: + config = [c.line for c in candidate_obj.items] + commands = list() + # this filters out less specific lines + for item in config: + for index, entry in enumerate(commands): + if item.startswith(entry): + del commands[index] + break + commands.append(item) + + candidate_commands = [ + "set %s" % cmd.replace(" {", "") for cmd in commands + ] + + else: + candidate_commands = str(candidate).strip().split("\n") + + if diff_match == "none": + diff["config_diff"] = list(candidate_commands) + return diff + + running_commands = [ + str(c).replace("'", "") for c in running.splitlines() + ] + + updates = list() + visited = set() + + for line in candidate_commands: + item = str(line).replace("'", "") + + if not item.startswith("set") and not item.startswith("delete"): + raise ValueError( + "line must start with either `set` or `delete`" + ) + + elif item.startswith("set") and item not in running_commands: + updates.append(line) + + elif item.startswith("delete"): + if not running_commands: + updates.append(line) + else: + item = re.sub(r"delete", "set", item) + for entry in running_commands: + if entry.startswith(item) and line not in visited: + updates.append(line) + visited.add(line) + + diff["config_diff"] = list(updates) + return diff + + def run_commands(self, commands=None, check_rc=True): + if commands is None: + raise ValueError("'commands' value is required") + + responses = list() + for cmd in to_list(commands): + if not isinstance(cmd, Mapping): + cmd = {"command": cmd} + + output = cmd.pop("output", None) + if output: + raise ValueError( + "'output' value %s is not supported for run_commands" + % output + ) + + try: + out = self.send_command(**cmd) + except AnsibleConnectionFailure as e: + if check_rc: + raise + out = getattr(e, "err", e) + + responses.append(out) + + return responses + + def get_device_operations(self): + return { + "supports_diff_replace": False, + "supports_commit": True, + "supports_rollback": False, + "supports_defaults": False, + "supports_onbox_diff": True, + "supports_commit_comment": True, + "supports_multiline_delimiter": False, + "supports_diff_match": True, + "supports_diff_ignore_lines": False, + "supports_generate_diff": False, + "supports_replace": False, + } + + def get_option_values(self): + return { + "format": ["text", "set"], + "diff_match": ["line", "none"], + "diff_replace": [], + "output": [], + } + + def get_capabilities(self): + result = super(Cliconf, self).get_capabilities() + result["rpc"] += [ + "commit", + "discard_changes", + "get_diff", + "run_commands", + ] + result["device_operations"] = self.get_device_operations() + result.update(self.get_option_values()) + return json.dumps(result) + + def set_cli_prompt_context(self): + """ + Make sure we are in the operational cli mode + :return: None + """ + if self._connection.connected: + self._update_cli_prompt_context( + config_context="#", exit_command="exit discard" + ) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py new file mode 100644 index 00000000..094963f1 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/doc_fragments/vyos.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Peter Sprygada <psprygada@ansible.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +class ModuleDocFragment(object): + + # Standard files documentation fragment + DOCUMENTATION = r"""options: + provider: + description: + - B(Deprecated) + - 'Starting with Ansible 2.5 we recommend using C(connection: network_cli).' + - For more information please see the L(Network Guide, ../network/getting_started/network_differences.html#multiple-communication-protocols). + - HORIZONTALLINE + - A dict object containing connection details. + type: dict + suboptions: + host: + description: + - Specifies the DNS host name or address for connecting to the remote device + over the specified transport. The value of host is used as the destination + address for the transport. + type: str + required: true + port: + description: + - Specifies the port to use when building the connection to the remote device. + type: int + default: 22 + username: + description: + - Configures the username to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_USERNAME) + will be used instead. + type: str + password: + description: + - Specifies the password to use to authenticate the connection to the remote + device. This value is used to authenticate the SSH session. If the value + is not specified in the task, the value of environment variable C(ANSIBLE_NET_PASSWORD) + will be used instead. + type: str + timeout: + description: + - Specifies the timeout in seconds for communicating with the network device + for either connecting or sending commands. If the timeout is exceeded before + the operation is completed, the module will error. + type: int + default: 10 + ssh_keyfile: + description: + - Specifies the SSH key to use to authenticate the connection to the remote + device. This value is the path to the key used to authenticate the SSH + session. If the value is not specified in the task, the value of environment + variable C(ANSIBLE_NET_SSH_KEYFILE) will be used instead. + type: path +notes: +- For more information on using Ansible to manage network devices see the :ref:`Ansible + Network Guide <network_guide>` +""" diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py new file mode 100644 index 00000000..46fabaa2 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/facts/facts.py @@ -0,0 +1,22 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The arg spec for the vyos facts module. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class FactsArgs(object): # pylint: disable=R0903 + """ The arg spec for the vyos facts module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "gather_subset": dict(default=["!config"], type="list"), + "gather_network_resources": dict(type="list"), + } diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py new file mode 100644 index 00000000..a018cc0b --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/firewall_rules/firewall_rules.py @@ -0,0 +1,263 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_firewall_rules module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Firewall_rulesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_firewall_rules module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "afi": { + "choices": ["ipv4", "ipv6"], + "required": True, + "type": "str", + }, + "rule_sets": { + "elements": "dict", + "options": { + "default_action": { + "choices": ["drop", "reject", "accept"], + "type": "str", + }, + "description": {"type": "str"}, + "enable_default_log": {"type": "bool"}, + "name": {"type": "str"}, + "rules": { + "elements": "dict", + "options": { + "action": { + "choices": [ + "drop", + "reject", + "accept", + "inspect", + ], + "type": "str", + }, + "description": {"type": "str"}, + "destination": { + "options": { + "address": {"type": "str"}, + "group": { + "options": { + "address_group": { + "type": "str" + }, + "network_group": { + "type": "str" + }, + "port_group": {"type": "str"}, + }, + "type": "dict", + }, + "port": {"type": "str"}, + }, + "type": "dict", + }, + "disabled": {"type": "bool"}, + "fragment": { + "choices": [ + "match-frag", + "match-non-frag", + ], + "type": "str", + }, + "icmp": { + "options": { + "code": {"type": "int"}, + "type": {"type": "int"}, + "type_name": { + "choices": [ + "any", + "echo-reply", + "destination-unreachable", + "network-unreachable", + "host-unreachable", + "protocol-unreachable", + "port-unreachable", + "fragmentation-needed", + "source-route-failed", + "network-unknown", + "host-unknown", + "network-prohibited", + "host-prohibited", + "TOS-network-unreachable", + "TOS-host-unreachable", + "communication-prohibited", + "host-precedence-violation", + "precedence-cutoff", + "source-quench", + "redirect", + "network-redirect", + "host-redirect", + "TOS-network-redirect", + "TOS-host-redirect", + "echo-request", + "router-advertisement", + "router-solicitation", + "time-exceeded", + "ttl-zero-during-transit", + "ttl-zero-during-reassembly", + "parameter-problem", + "ip-header-bad", + "required-option-missing", + "timestamp-request", + "timestamp-reply", + "address-mask-request", + "address-mask-reply", + "ping", + "pong", + "ttl-exceeded", + ], + "type": "str", + }, + }, + "type": "dict", + }, + "ipsec": { + "choices": ["match-ipsec", "match-none"], + "type": "str", + }, + "limit": { + "options": { + "burst": {"type": "int"}, + "rate": { + "options": { + "number": {"type": "int"}, + "unit": {"type": "str"}, + }, + "type": "dict", + }, + }, + "type": "dict", + }, + "number": {"required": True, "type": "int"}, + "p2p": { + "elements": "dict", + "options": { + "application": { + "choices": [ + "all", + "applejuice", + "bittorrent", + "directconnect", + "edonkey", + "gnutella", + "kazaa", + ], + "type": "str", + } + }, + "type": "list", + }, + "protocol": {"type": "str"}, + "recent": { + "options": { + "count": {"type": "int"}, + "time": {"type": "int"}, + }, + "type": "dict", + }, + "source": { + "options": { + "address": {"type": "str"}, + "group": { + "options": { + "address_group": { + "type": "str" + }, + "network_group": { + "type": "str" + }, + "port_group": {"type": "str"}, + }, + "type": "dict", + }, + "mac_address": {"type": "str"}, + "port": {"type": "str"}, + }, + "type": "dict", + }, + "state": { + "options": { + "established": {"type": "bool"}, + "invalid": {"type": "bool"}, + "new": {"type": "bool"}, + "related": {"type": "bool"}, + }, + "type": "dict", + }, + "tcp": { + "options": {"flags": {"type": "str"}}, + "type": "dict", + }, + "time": { + "options": { + "monthdays": {"type": "str"}, + "startdate": {"type": "str"}, + "starttime": {"type": "str"}, + "stopdate": {"type": "str"}, + "stoptime": {"type": "str"}, + "utc": {"type": "bool"}, + "weekdays": {"type": "str"}, + }, + "type": "dict", + }, + }, + "type": "list", + }, + }, + "type": "list", + }, + }, + "type": "list", + }, + "running_config": {"type": "str"}, + "state": { + "choices": [ + "merged", + "replaced", + "overridden", + "deleted", + "gathered", + "rendered", + "parsed", + ], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py new file mode 100644 index 00000000..3542cb19 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/interfaces/interfaces.py @@ -0,0 +1,69 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_interfaces module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class InterfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "description": {"type": "str"}, + "duplex": {"choices": ["full", "half", "auto"]}, + "enabled": {"default": True, "type": "bool"}, + "mtu": {"type": "int"}, + "name": {"required": True, "type": "str"}, + "speed": { + "choices": ["auto", "10", "100", "1000", "2500", "10000"], + "type": "str", + }, + "vifs": { + "elements": "dict", + "options": { + "vlan_id": {"type": "int"}, + "description": {"type": "str"}, + "enabled": {"default": True, "type": "bool"}, + "mtu": {"type": "int"}, + }, + "type": "list", + }, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py new file mode 100644 index 00000000..91434e4b --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/l3_interfaces/l3_interfaces.py @@ -0,0 +1,81 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_l3_interfaces module +""" + + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class L3_interfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_l3_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "ipv4": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "ipv6": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "name": {"required": True, "type": "str"}, + "vifs": { + "elements": "dict", + "options": { + "ipv4": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "ipv6": { + "elements": "dict", + "options": {"address": {"type": "str"}}, + "type": "list", + }, + "vlan_id": {"type": "int"}, + }, + "type": "list", + }, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py new file mode 100644 index 00000000..97c5d5a2 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lag_interfaces/lag_interfaces.py @@ -0,0 +1,80 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# + +""" +The arg spec for the vyos_lag_interfaces module +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Lag_interfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_lag_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "arp_monitor": { + "options": { + "interval": {"type": "int"}, + "target": {"type": "list"}, + }, + "type": "dict", + }, + "hash_policy": { + "choices": ["layer2", "layer2+3", "layer3+4"], + "type": "str", + }, + "members": { + "elements": "dict", + "options": {"member": {"type": "str"}}, + "type": "list", + }, + "mode": { + "choices": [ + "802.3ad", + "active-backup", + "broadcast", + "round-robin", + "transmit-load-balance", + "adaptive-load-balance", + "xor-hash", + ], + "type": "str", + }, + "name": {"required": True, "type": "str"}, + "primary": {"type": "str"}, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py new file mode 100644 index 00000000..84bbc00c --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_global/lldp_global.py @@ -0,0 +1,56 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# + +""" +The arg spec for the vyos_lldp_global module +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Lldp_globalArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_lldp_global module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "options": { + "address": {"type": "str"}, + "enable": {"type": "bool"}, + "legacy_protocols": { + "choices": ["cdp", "edp", "fdp", "sonmp"], + "type": "list", + }, + "snmp": {"type": "str"}, + }, + "type": "dict", + }, + "state": { + "choices": ["merged", "replaced", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py new file mode 100644 index 00000000..2976fc09 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/lldp_interfaces/lldp_interfaces.py @@ -0,0 +1,89 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_lldp_interfaces module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Lldp_interfacesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_lldp_interfaces module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "enable": {"default": True, "type": "bool"}, + "location": { + "options": { + "civic_based": { + "options": { + "ca_info": { + "elements": "dict", + "options": { + "ca_type": {"type": "int"}, + "ca_value": {"type": "str"}, + }, + "type": "list", + }, + "country_code": { + "required": True, + "type": "str", + }, + }, + "type": "dict", + }, + "coordinate_based": { + "options": { + "altitude": {"type": "int"}, + "datum": { + "choices": ["WGS84", "NAD83", "MLLW"], + "type": "str", + }, + "latitude": {"required": True, "type": "str"}, + "longitude": {"required": True, "type": "str"}, + }, + "type": "dict", + }, + "elin": {"type": "str"}, + }, + "type": "dict", + }, + "name": {"required": True, "type": "str"}, + }, + "type": "list", + }, + "state": { + "choices": ["merged", "replaced", "overridden", "deleted"], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py new file mode 100644 index 00000000..8ecd955a --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/argspec/static_routes/static_routes.py @@ -0,0 +1,99 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# +""" +The arg spec for the vyos_static_routes module +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +class Static_routesArgs(object): # pylint: disable=R0903 + """The arg spec for the vyos_static_routes module + """ + + def __init__(self, **kwargs): + pass + + argument_spec = { + "config": { + "elements": "dict", + "options": { + "address_families": { + "elements": "dict", + "options": { + "afi": { + "choices": ["ipv4", "ipv6"], + "required": True, + "type": "str", + }, + "routes": { + "elements": "dict", + "options": { + "blackhole_config": { + "options": { + "distance": {"type": "int"}, + "type": {"type": "str"}, + }, + "type": "dict", + }, + "dest": {"required": True, "type": "str"}, + "next_hops": { + "elements": "dict", + "options": { + "admin_distance": {"type": "int"}, + "enabled": {"type": "bool"}, + "forward_router_address": { + "required": True, + "type": "str", + }, + "interface": {"type": "str"}, + }, + "type": "list", + }, + }, + "type": "list", + }, + }, + "type": "list", + } + }, + "type": "list", + }, + "running_config": {"type": "str"}, + "state": { + "choices": [ + "merged", + "replaced", + "overridden", + "deleted", + "gathered", + "rendered", + "parsed", + ], + "default": "merged", + "type": "str", + }, + } # pylint: disable=C0301 diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py new file mode 100644 index 00000000..377fec9a --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/config/lldp_interfaces/lldp_interfaces.py @@ -0,0 +1,438 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos_lldp_interfaces class +It is in this file where the current configuration (as dict) +is compared to the provided configuration (as dict) and the command set +necessary to bring the current configuration to it's desired end-state is +created +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.cfg.base import ( + ConfigBase, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.facts import ( + Facts, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + to_list, + dict_diff, +) +from ansible.module_utils.six import iteritems +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.utils.utils import ( + search_obj_in_list, + search_dict_tv_in_list, + key_value_in_dict, + is_dict_element_present, +) + + +class Lldp_interfaces(ConfigBase): + """ + The vyos_lldp_interfaces class + """ + + gather_subset = [ + "!all", + "!min", + ] + + gather_network_resources = [ + "lldp_interfaces", + ] + + params = ["enable", "location", "name"] + + def __init__(self, module): + super(Lldp_interfaces, self).__init__(module) + + def get_lldp_interfaces_facts(self): + """ Get the 'facts' (the current configuration) + + :rtype: A dictionary + :returns: The current configuration as a dictionary + """ + facts, _warnings = Facts(self._module).get_facts( + self.gather_subset, self.gather_network_resources + ) + lldp_interfaces_facts = facts["ansible_network_resources"].get( + "lldp_interfaces" + ) + if not lldp_interfaces_facts: + return [] + return lldp_interfaces_facts + + def execute_module(self): + """ Execute the module + + :rtype: A dictionary + :returns: The result from module execution + """ + result = {"changed": False} + commands = list() + warnings = list() + existing_lldp_interfaces_facts = self.get_lldp_interfaces_facts() + commands.extend(self.set_config(existing_lldp_interfaces_facts)) + if commands: + if self._module.check_mode: + resp = self._connection.edit_config(commands, commit=False) + else: + resp = self._connection.edit_config(commands) + result["changed"] = True + + result["commands"] = commands + + if self._module._diff: + result["diff"] = resp["diff"] if result["changed"] else None + + changed_lldp_interfaces_facts = self.get_lldp_interfaces_facts() + result["before"] = existing_lldp_interfaces_facts + if result["changed"]: + result["after"] = changed_lldp_interfaces_facts + + result["warnings"] = warnings + return result + + def set_config(self, existing_lldp_interfaces_facts): + """ Collect the configuration from the args passed to the module, + collect the current configuration (as a dict from facts) + + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + want = self._module.params["config"] + have = existing_lldp_interfaces_facts + resp = self.set_state(want, have) + return to_list(resp) + + def set_state(self, want, have): + """ Select the appropriate function based on the state provided + + :param want: the desired configuration as a dictionary + :param have: the current configuration as a dictionary + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + commands = [] + state = self._module.params["state"] + if state in ("merged", "replaced", "overridden") and not want: + self._module.fail_json( + msg="value of config parameter must not be empty for state {0}".format( + state + ) + ) + if state == "overridden": + commands.extend(self._state_overridden(want=want, have=have)) + elif state == "deleted": + if want: + for item in want: + name = item["name"] + have_item = search_obj_in_list(name, have) + commands.extend( + self._state_deleted(want=None, have=have_item) + ) + else: + for have_item in have: + commands.extend( + self._state_deleted(want=None, have=have_item) + ) + else: + for want_item in want: + name = want_item["name"] + have_item = search_obj_in_list(name, have) + if state == "merged": + commands.extend( + self._state_merged(want=want_item, have=have_item) + ) + else: + commands.extend( + self._state_replaced(want=want_item, have=have_item) + ) + return commands + + def _state_replaced(self, want, have): + """ The command generator when state is replaced + + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + commands = [] + if have: + commands.extend(self._state_deleted(want, have)) + commands.extend(self._state_merged(want, have)) + return commands + + def _state_overridden(self, want, have): + """ The command generator when state is overridden + + :rtype: A list + :returns: the commands necessary to migrate the current configuration + to the desired configuration + """ + commands = [] + for have_item in have: + lldp_name = have_item["name"] + lldp_in_want = search_obj_in_list(lldp_name, want) + if not lldp_in_want: + commands.append( + self._compute_command(have_item["name"], remove=True) + ) + + for want_item in want: + name = want_item["name"] + lldp_in_have = search_obj_in_list(name, have) + commands.extend(self._state_replaced(want_item, lldp_in_have)) + return commands + + def _state_merged(self, want, have): + """ The command generator when state is merged + + :rtype: A list + :returns: the commands necessary to merge the provided into + the current configuration + """ + commands = [] + if have: + commands.extend(self._render_updates(want, have)) + else: + commands.extend(self._render_set_commands(want)) + return commands + + def _state_deleted(self, want, have): + """ The command generator when state is deleted + + :rtype: A list + :returns: the commands necessary to remove the current configuration + of the provided objects + """ + commands = [] + if want: + params = Lldp_interfaces.params + for attrib in params: + if attrib == "location": + commands.extend( + self._update_location(have["name"], want, have) + ) + + elif have: + commands.append(self._compute_command(have["name"], remove=True)) + return commands + + def _render_updates(self, want, have): + commands = [] + lldp_name = have["name"] + commands.extend(self._configure_status(lldp_name, want, have)) + commands.extend(self._add_location(lldp_name, want, have)) + + return commands + + def _render_set_commands(self, want): + commands = [] + have = {} + lldp_name = want["name"] + params = Lldp_interfaces.params + + commands.extend(self._add_location(lldp_name, want, have)) + for attrib in params: + value = want[attrib] + if value: + if attrib == "location": + commands.extend(self._add_location(lldp_name, want, have)) + elif attrib == "enable": + if not value: + commands.append( + self._compute_command(lldp_name, value="disable") + ) + else: + commands.append(self._compute_command(lldp_name)) + + return commands + + def _configure_status(self, name, want_item, have_item): + commands = [] + if is_dict_element_present(have_item, "enable"): + temp_have_item = False + else: + temp_have_item = True + if want_item["enable"] != temp_have_item: + if want_item["enable"]: + commands.append( + self._compute_command(name, value="disable", remove=True) + ) + else: + commands.append(self._compute_command(name, value="disable")) + return commands + + def _add_location(self, name, want_item, have_item): + commands = [] + have_dict = {} + have_ca = {} + set_cmd = name + " location " + want_location_type = want_item.get("location") or {} + have_location_type = have_item.get("location") or {} + + if want_location_type["coordinate_based"]: + want_dict = want_location_type.get("coordinate_based") or {} + if is_dict_element_present(have_location_type, "coordinate_based"): + have_dict = have_location_type.get("coordinate_based") or {} + location_type = "coordinate-based" + updates = dict_diff(have_dict, want_dict) + for key, value in iteritems(updates): + if value: + commands.append( + self._compute_command( + set_cmd + location_type, key, str(value) + ) + ) + + elif want_location_type["civic_based"]: + location_type = "civic-based" + want_dict = want_location_type.get("civic_based") or {} + want_ca = want_dict.get("ca_info") or [] + if is_dict_element_present(have_location_type, "civic_based"): + have_dict = have_location_type.get("civic_based") or {} + have_ca = have_dict.get("ca_info") or [] + if want_dict["country_code"] != have_dict["country_code"]: + commands.append( + self._compute_command( + set_cmd + location_type, + "country-code", + str(want_dict["country_code"]), + ) + ) + else: + commands.append( + self._compute_command( + set_cmd + location_type, + "country-code", + str(want_dict["country_code"]), + ) + ) + commands.extend(self._add_civic_address(name, want_ca, have_ca)) + + elif want_location_type["elin"]: + location_type = "elin" + if is_dict_element_present(have_location_type, "elin"): + if want_location_type.get("elin") != have_location_type.get( + "elin" + ): + commands.append( + self._compute_command( + set_cmd + location_type, + value=str(want_location_type["elin"]), + ) + ) + else: + commands.append( + self._compute_command( + set_cmd + location_type, + value=str(want_location_type["elin"]), + ) + ) + return commands + + def _update_location(self, name, want_item, have_item): + commands = [] + del_cmd = name + " location" + want_location_type = want_item.get("location") or {} + have_location_type = have_item.get("location") or {} + + if want_location_type["coordinate_based"]: + want_dict = want_location_type.get("coordinate_based") or {} + if is_dict_element_present(have_location_type, "coordinate_based"): + have_dict = have_location_type.get("coordinate_based") or {} + location_type = "coordinate-based" + for key, value in iteritems(have_dict): + only_in_have = key_value_in_dict(key, value, want_dict) + if not only_in_have: + commands.append( + self._compute_command( + del_cmd + location_type, key, str(value), True + ) + ) + else: + commands.append(self._compute_command(del_cmd, remove=True)) + + elif want_location_type["civic_based"]: + want_dict = want_location_type.get("civic_based") or {} + want_ca = want_dict.get("ca_info") or [] + if is_dict_element_present(have_location_type, "civic_based"): + have_dict = have_location_type.get("civic_based") or {} + have_ca = have_dict.get("ca_info") + commands.extend( + self._update_civic_address(name, want_ca, have_ca) + ) + else: + commands.append(self._compute_command(del_cmd, remove=True)) + + else: + if is_dict_element_present(have_location_type, "elin"): + if want_location_type.get("elin") != have_location_type.get( + "elin" + ): + commands.append( + self._compute_command(del_cmd, remove=True) + ) + else: + commands.append(self._compute_command(del_cmd, remove=True)) + return commands + + def _add_civic_address(self, name, want, have): + commands = [] + for item in want: + ca_type = item["ca_type"] + ca_value = item["ca_value"] + obj_in_have = search_dict_tv_in_list( + ca_type, ca_value, have, "ca_type", "ca_value" + ) + if not obj_in_have: + commands.append( + self._compute_command( + key=name + " location civic-based ca-type", + attrib=str(ca_type) + " ca-value", + value=ca_value, + ) + ) + return commands + + def _update_civic_address(self, name, want, have): + commands = [] + for item in have: + ca_type = item["ca_type"] + ca_value = item["ca_value"] + in_want = search_dict_tv_in_list( + ca_type, ca_value, want, "ca_type", "ca_value" + ) + if not in_want: + commands.append( + self._compute_command( + name, + "location civic-based ca-type", + str(ca_type), + remove=True, + ) + ) + return commands + + def _compute_command(self, key, attrib=None, value=None, remove=False): + if remove: + cmd = "delete service lldp interface " + else: + cmd = "set service lldp interface " + cmd += key + if attrib: + cmd += " " + attrib + if value: + cmd += " '" + value + "'" + return cmd diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py new file mode 100644 index 00000000..8f0a3bb6 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/facts.py @@ -0,0 +1,83 @@ +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The facts class for vyos +this file validates each subset of facts and selectively +calls the appropriate facts gathering function +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.facts.facts import ( + FactsBase, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.interfaces.interfaces import ( + InterfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.l3_interfaces.l3_interfaces import ( + L3_interfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lag_interfaces.lag_interfaces import ( + Lag_interfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lldp_global.lldp_global import ( + Lldp_globalFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.lldp_interfaces.lldp_interfaces import ( + Lldp_interfacesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.firewall_rules.firewall_rules import ( + Firewall_rulesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.static_routes.static_routes import ( + Static_routesFacts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.legacy.base import ( + Default, + Neighbors, + Config, +) + + +FACT_LEGACY_SUBSETS = dict(default=Default, neighbors=Neighbors, config=Config) +FACT_RESOURCE_SUBSETS = dict( + interfaces=InterfacesFacts, + l3_interfaces=L3_interfacesFacts, + lag_interfaces=Lag_interfacesFacts, + lldp_global=Lldp_globalFacts, + lldp_interfaces=Lldp_interfacesFacts, + static_routes=Static_routesFacts, + firewall_rules=Firewall_rulesFacts, +) + + +class Facts(FactsBase): + """ The fact class for vyos + """ + + VALID_LEGACY_GATHER_SUBSETS = frozenset(FACT_LEGACY_SUBSETS.keys()) + VALID_RESOURCE_SUBSETS = frozenset(FACT_RESOURCE_SUBSETS.keys()) + + def __init__(self, module): + super(Facts, self).__init__(module) + + def get_facts( + self, legacy_facts_type=None, resource_facts_type=None, data=None + ): + """ Collect the facts for vyos + :param legacy_facts_type: List of legacy facts types + :param resource_facts_type: List of resource fact types + :param data: previously collected conf + :rtype: dict + :return: the facts gathered + """ + if self.VALID_RESOURCE_SUBSETS: + self.get_network_resources_facts( + FACT_RESOURCE_SUBSETS, resource_facts_type, data + ) + if self.VALID_LEGACY_GATHER_SUBSETS: + self.get_network_legacy_facts( + FACT_LEGACY_SUBSETS, legacy_facts_type + ) + return self.ansible_facts, self._warnings diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py new file mode 100644 index 00000000..971ea6fe --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/firewall_rules/firewall_rules.py @@ -0,0 +1,380 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos firewall_rules fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from re import findall, search, M +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.firewall_rules.firewall_rules import ( + Firewall_rulesArgs, +) + + +class Firewall_rulesFacts(object): + """ The vyos firewall_rules fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Firewall_rulesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def get_device_data(self, connection): + return connection.get_config() + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for firewall_rules + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + # typically data is populated from the current device configuration + # data = connection.get('show running-config | section ^interface') + # using mock data instead + data = self.get_device_data(connection) + # split the config into instances of the resource + objs = [] + v6_rules = findall( + r"^set firewall ipv6-name (?:\'*)(\S+)(?:\'*)", data, M + ) + v4_rules = findall(r"^set firewall name (?:\'*)(\S+)(?:\'*)", data, M) + if v6_rules: + config = self.get_rules(data, v6_rules, type="ipv6") + if config: + config = utils.remove_empties(config) + objs.append(config) + if v4_rules: + config = self.get_rules(data, v4_rules, type="ipv4") + if config: + config = utils.remove_empties(config) + objs.append(config) + + ansible_facts["ansible_network_resources"].pop("firewall_rules", None) + facts = {} + if objs: + facts["firewall_rules"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["firewall_rules"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def get_rules(self, data, rules, type): + """ + This function performs following: + - Form regex to fetch 'rule-sets' specific config from data. + - Form the rule-set list based on ip address. + :param data: configuration. + :param rules: list of rule-sets. + :param type: ip address type. + :return: generated rule-sets configuration. + """ + r_v4 = [] + r_v6 = [] + for r in set(rules): + rule_regex = r" %s .+$" % r.strip("'") + cfg = findall(rule_regex, data, M) + fr = self.render_config(cfg, r.strip("'")) + fr["name"] = r.strip("'") + if type == "ipv6": + r_v6.append(fr) + else: + r_v4.append(fr) + if r_v4: + config = {"afi": "ipv4", "rule_sets": r_v4} + if r_v6: + config = {"afi": "ipv6", "rule_sets": r_v6} + return config + + def render_config(self, conf, match): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + conf = "\n".join(filter(lambda x: x, conf)) + a_lst = ["description", "default_action", "enable_default_log"] + config = self.parse_attr(conf, a_lst, match) + if not config: + config = {} + config["rules"] = self.parse_rules_lst(conf) + return config + + def parse_rules_lst(self, conf): + """ + This function forms the regex to fetch the 'rules' with in + 'rule-sets' + :param conf: configuration data. + :return: generated rule list configuration. + """ + r_lst = [] + rules = findall(r"rule (?:\'*)(\d+)(?:\'*)", conf, M) + if rules: + rules_lst = [] + for r in set(rules): + r_regex = r" %s .+$" % r + cfg = "\n".join(findall(r_regex, conf, M)) + obj = self.parse_rules(cfg) + obj["number"] = int(r) + if obj: + rules_lst.append(obj) + r_lst = sorted(rules_lst, key=lambda i: i["number"]) + return r_lst + + def parse_rules(self, conf): + """ + This function triggers the parsing of 'rule' attributes. + a_lst is a list having rule attributes which doesn't + have further sub attributes. + :param conf: configuration + :return: generated rule configuration dictionary. + """ + a_lst = [ + "ipsec", + "action", + "protocol", + "fragment", + "disabled", + "description", + ] + rule = self.parse_attr(conf, a_lst) + r_sub = { + "p2p": self.parse_p2p(conf), + "tcp": self.parse_tcp(conf, "tcp"), + "icmp": self.parse_icmp(conf, "icmp"), + "time": self.parse_time(conf, "time"), + "limit": self.parse_limit(conf, "limit"), + "state": self.parse_state(conf, "state"), + "recent": self.parse_recent(conf, "recent"), + "source": self.parse_src_or_dest(conf, "source"), + "destination": self.parse_src_or_dest(conf, "destination"), + } + rule.update(r_sub) + return rule + + def parse_p2p(self, conf): + """ + This function forms the regex to fetch the 'p2p' with in + 'rules' + :param conf: configuration data. + :return: generated rule list configuration. + """ + a_lst = [] + applications = findall(r"p2p (?:\'*)(\d+)(?:\'*)", conf, M) + if applications: + app_lst = [] + for r in set(applications): + obj = {"application": r.strip("'")} + app_lst.append(obj) + a_lst = sorted(app_lst, key=lambda i: i["application"]) + return a_lst + + def parse_src_or_dest(self, conf, attrib=None): + """ + This function triggers the parsing of 'source or + destination' attributes. + :param conf: configuration. + :param attrib:'source/destination'. + :return:generated source/destination configuration dictionary. + """ + a_lst = ["port", "address", "mac_address"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + cfg_dict["group"] = self.parse_group(conf, attrib + " group") + return cfg_dict + + def parse_recent(self, conf, attrib=None): + """ + This function triggers the parsing of 'recent' attributes + :param conf: configuration. + :param attrib: 'recent'. + :return: generated config dictionary. + """ + a_lst = ["time", "count"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_tcp(self, conf, attrib=None): + """ + This function triggers the parsing of 'tcp' attributes. + :param conf: configuration. + :param attrib: 'tcp'. + :return: generated config dictionary. + """ + cfg_dict = self.parse_attr(conf, ["flags"], match=attrib) + return cfg_dict + + def parse_time(self, conf, attrib=None): + """ + This function triggers the parsing of 'time' attributes. + :param conf: configuration. + :param attrib: 'time'. + :return: generated config dictionary. + """ + a_lst = [ + "stopdate", + "stoptime", + "weekdays", + "monthdays", + "startdate", + "starttime", + ] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_state(self, conf, attrib=None): + """ + This function triggers the parsing of 'state' attributes. + :param conf: configuration + :param attrib: 'state'. + :return: generated config dictionary. + """ + a_lst = ["new", "invalid", "related", "established"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_group(self, conf, attrib=None): + """ + This function triggers the parsing of 'group' attributes. + :param conf: configuration. + :param attrib: 'group'. + :return: generated config dictionary. + """ + a_lst = ["port_group", "address_group", "network_group"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_icmp(self, conf, attrib=None): + """ + This function triggers the parsing of 'icmp' attributes. + :param conf: configuration to be parsed. + :param attrib: 'icmp'. + :return: generated config dictionary. + """ + a_lst = ["code", "type", "type_name"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_limit(self, conf, attrib=None): + """ + This function triggers the parsing of 'limit' attributes. + :param conf: configuration to be parsed. + :param attrib: 'limit' + :return: generated config dictionary. + """ + cfg_dict = self.parse_attr(conf, ["burst"], match=attrib) + cfg_dict["rate"] = self.parse_rate(conf, "rate") + return cfg_dict + + def parse_rate(self, conf, attrib=None): + """ + This function triggers the parsing of 'rate' attributes. + :param conf: configuration. + :param attrib: 'rate' + :return: generated config dictionary. + """ + a_lst = ["unit", "number"] + cfg_dict = self.parse_attr(conf, a_lst, match=attrib) + return cfg_dict + + def parse_attr(self, conf, attr_list, match=None): + """ + This function peforms the following: + - Form the regex to fetch the required attribute config. + - Type cast the output in desired format. + :param conf: configuration. + :param attr_list: list of attributes. + :param match: parent node/attribute name. + :return: generated config dictionary. + """ + config = {} + for attrib in attr_list: + regex = self.map_regex(attrib) + if match: + regex = match + " " + regex + if conf: + if self.is_bool(attrib): + out = conf.find(attrib.replace("_", "-")) + + dis = conf.find(attrib.replace("_", "-") + " 'disable'") + if out >= 1: + if dis >= 1: + config[attrib] = False + else: + config[attrib] = True + else: + out = search(r"^.*" + regex + " (.+)", conf, M) + if out: + val = out.group(1).strip("'") + if self.is_num(attrib): + val = int(val) + config[attrib] = val + return config + + def map_regex(self, attrib): + """ + - This function construct the regex string. + - replace the underscore with hyphen. + :param attrib: attribute + :return: regex string + """ + regex = attrib.replace("_", "-") + if attrib == "disabled": + regex = "disable" + return regex + + def is_bool(self, attrib): + """ + This function looks for the attribute in predefined bool type set. + :param attrib: attribute. + :return: True/False + """ + bool_set = ( + "new", + "invalid", + "related", + "disabled", + "established", + "enable_default_log", + ) + return True if attrib in bool_set else False + + def is_num(self, attrib): + """ + This function looks for the attribute in predefined integer type set. + :param attrib: attribute. + :return: True/false. + """ + num_set = ("time", "code", "type", "count", "burst", "number") + return True if attrib in num_set else False diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py new file mode 100644 index 00000000..4b24803b --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/interfaces/interfaces.py @@ -0,0 +1,134 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +from re import findall, M +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.interfaces.interfaces import ( + InterfacesArgs, +) + + +class InterfacesFacts(object): + """ The vyos interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = InterfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for interfaces + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config(flags=["| grep interfaces"]) + + objs = [] + interface_names = findall( + r"^set interfaces (?:ethernet|bonding|vti|loopback|vxlan) (?:\'*)(\S+)(?:\'*)", + data, + M, + ) + if interface_names: + for interface in set(interface_names): + intf_regex = r" %s .+$" % interface.strip("'") + cfg = findall(intf_regex, data, M) + obj = self.render_config(cfg) + obj["name"] = interface.strip("'") + if obj: + objs.append(obj) + facts = {} + if objs: + facts["interfaces"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["interfaces"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + vif_conf = "\n".join(filter(lambda x: ("vif" in x), conf)) + eth_conf = "\n".join(filter(lambda x: ("vif" not in x), conf)) + config = self.parse_attribs( + ["description", "speed", "mtu", "duplex"], eth_conf + ) + config["vifs"] = self.parse_vifs(vif_conf) + + return utils.remove_empties(config) + + def parse_vifs(self, conf): + vif_names = findall(r"vif (?:\'*)(\d+)(?:\'*)", conf, M) + vifs_list = None + + if vif_names: + vifs_list = [] + for vif in set(vif_names): + vif_regex = r" %s .+$" % vif + cfg = "\n".join(findall(vif_regex, conf, M)) + obj = self.parse_attribs(["description", "mtu"], cfg) + obj["vlan_id"] = int(vif) + if obj: + vifs_list.append(obj) + vifs_list = sorted(vifs_list, key=lambda i: i["vlan_id"]) + + return vifs_list + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value and item == "mtu": + config[item] = int(value.strip("'")) + elif value: + config[item] = value.strip("'") + else: + config[item] = None + if "disable" in conf: + config["enabled"] = False + else: + config["enabled"] = True + + return utils.remove_empties(config) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py new file mode 100644 index 00000000..d1d62c23 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/l3_interfaces/l3_interfaces.py @@ -0,0 +1,143 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos l3_interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +import re +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible.module_utils.six import iteritems +from ansible_collections.ansible.netcommon.plugins.module_utils.compat import ( + ipaddress, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.l3_interfaces.l3_interfaces import ( + L3_interfacesArgs, +) + + +class L3_interfacesFacts(object): + """ The vyos l3_interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = L3_interfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for l3_interfaces + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + # operate on a collection of resource x + objs = [] + interface_names = re.findall( + r"set interfaces (?:ethernet|bonding|vti|vxlan) (?:\'*)(\S+)(?:\'*)", + data, + re.M, + ) + if interface_names: + for interface in set(interface_names): + intf_regex = r" %s .+$" % interface + cfg = re.findall(intf_regex, data, re.M) + obj = self.render_config(cfg) + obj["name"] = interface.strip("'") + if obj: + objs.append(obj) + + ansible_facts["ansible_network_resources"].pop("l3_interfaces", None) + facts = {} + if objs: + facts["l3_interfaces"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["l3_interfaces"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys from spec for null values + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + vif_conf = "\n".join(filter(lambda x: ("vif" in x), conf)) + eth_conf = "\n".join(filter(lambda x: ("vif" not in x), conf)) + config = self.parse_attribs(eth_conf) + config["vifs"] = self.parse_vifs(vif_conf) + + return utils.remove_empties(config) + + def parse_vifs(self, conf): + vif_names = re.findall(r"vif (\d+)", conf, re.M) + vifs_list = None + if vif_names: + vifs_list = [] + for vif in set(vif_names): + vif_regex = r" %s .+$" % vif + cfg = "\n".join(re.findall(vif_regex, conf, re.M)) + obj = self.parse_attribs(cfg) + obj["vlan_id"] = vif + if obj: + vifs_list.append(obj) + + return vifs_list + + def parse_attribs(self, conf): + config = {} + ipaddrs = re.findall(r"address (\S+)", conf, re.M) + config["ipv4"] = [] + config["ipv6"] = [] + + for item in ipaddrs: + item = item.strip("'") + if item == "dhcp": + config["ipv4"].append({"address": item}) + elif item == "dhcpv6": + config["ipv6"].append({"address": item}) + else: + ip_version = ipaddress.ip_address(item.split("/")[0]).version + if ip_version == 4: + config["ipv4"].append({"address": item}) + else: + config["ipv6"].append({"address": item}) + + for key, value in iteritems(config): + if value == []: + config[key] = None + + return utils.remove_empties(config) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py new file mode 100644 index 00000000..9201e5c6 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lag_interfaces/lag_interfaces.py @@ -0,0 +1,152 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos lag_interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from re import findall, search, M +from copy import deepcopy + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lag_interfaces.lag_interfaces import ( + Lag_interfacesArgs, +) + + +class Lag_interfacesFacts(object): + """ The vyos lag_interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Lag_interfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for lag_interfaces + :param module: the module instance + :param connection: the device connection + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + objs = [] + lag_names = findall(r"^set interfaces bonding (\S+)", data, M) + if lag_names: + for lag in set(lag_names): + lag_regex = r" %s .+$" % lag + cfg = findall(lag_regex, data, M) + obj = self.render_config(cfg) + + output = connection.run_commands( + ["show interfaces bonding " + lag + " slaves"] + ) + lines = output[0].splitlines() + members = [] + member = {} + if len(lines) > 1: + for line in lines[2:]: + splitted_line = line.split() + + if len(splitted_line) > 1: + member["member"] = splitted_line[0] + members.append(member) + else: + members = [] + member = {} + obj["name"] = lag.strip("'") + if members: + obj["members"] = members + + if obj: + objs.append(obj) + + facts = {} + if objs: + facts["lag_interfaces"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["lag_interfaces"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + arp_monitor_conf = "\n".join( + filter(lambda x: ("arp-monitor" in x), conf) + ) + hash_policy_conf = "\n".join( + filter(lambda x: ("hash-policy" in x), conf) + ) + lag_conf = "\n".join(filter(lambda x: ("bond" in x), conf)) + config = self.parse_attribs(["mode", "primary"], lag_conf) + config["arp_monitor"] = self.parse_arp_monitor(arp_monitor_conf) + config["hash_policy"] = self.parse_hash_policy(hash_policy_conf) + + return utils.remove_empties(config) + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value: + config[item] = value.strip("'") + else: + config[item] = None + return utils.remove_empties(config) + + def parse_arp_monitor(self, conf): + arp_monitor = None + if conf: + arp_monitor = {} + target_list = [] + interval = search(r"^.*arp-monitor interval (.+)", conf, M) + targets = findall(r"^.*arp-monitor target '(.+)'", conf, M) + if targets: + for target in targets: + target_list.append(target) + arp_monitor["target"] = target_list + if interval: + value = interval.group(1).strip("'") + arp_monitor["interval"] = int(value) + return arp_monitor + + def parse_hash_policy(self, conf): + hash_policy = None + if conf: + hash_policy = search(r"^.*hash-policy (.+)", conf, M) + hash_policy = hash_policy.group(1).strip("'") + return hash_policy diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py new file mode 100644 index 00000000..f6b343e0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/legacy/base.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The VyOS interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +import platform +import re +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + run_commands, + get_capabilities, +) + + +class LegacyFactsBase(object): + + COMMANDS = frozenset() + + def __init__(self, module): + self.module = module + self.facts = dict() + self.warnings = list() + self.responses = None + + def populate(self): + self.responses = run_commands(self.module, list(self.COMMANDS)) + + +class Default(LegacyFactsBase): + + COMMANDS = [ + "show version", + ] + + def populate(self): + super(Default, self).populate() + data = self.responses[0] + self.facts["serialnum"] = self.parse_serialnum(data) + self.facts.update(self.platform_facts()) + + def parse_serialnum(self, data): + match = re.search(r"HW S/N:\s+(\S+)", data) + if match: + return match.group(1) + + def platform_facts(self): + platform_facts = {} + + resp = get_capabilities(self.module) + device_info = resp["device_info"] + + platform_facts["system"] = device_info["network_os"] + + for item in ("model", "image", "version", "platform", "hostname"): + val = device_info.get("network_os_%s" % item) + if val: + platform_facts[item] = val + + platform_facts["api"] = resp["network_api"] + platform_facts["python_version"] = platform.python_version() + + return platform_facts + + +class Config(LegacyFactsBase): + + COMMANDS = [ + "show configuration commands", + "show system commit", + ] + + def populate(self): + super(Config, self).populate() + + self.facts["config"] = self.responses + + commits = self.responses[1] + entries = list() + entry = None + + for line in commits.split("\n"): + match = re.match(r"(\d+)\s+(.+)by(.+)via(.+)", line) + if match: + if entry: + entries.append(entry) + + entry = dict( + revision=match.group(1), + datetime=match.group(2), + by=str(match.group(3)).strip(), + via=str(match.group(4)).strip(), + comment=None, + ) + else: + entry["comment"] = line.strip() + + self.facts["commits"] = entries + + +class Neighbors(LegacyFactsBase): + + COMMANDS = [ + "show lldp neighbors", + "show lldp neighbors detail", + ] + + def populate(self): + super(Neighbors, self).populate() + + all_neighbors = self.responses[0] + if "LLDP not configured" not in all_neighbors: + neighbors = self.parse(self.responses[1]) + self.facts["neighbors"] = self.parse_neighbors(neighbors) + + def parse(self, data): + parsed = list() + values = None + for line in data.split("\n"): + if not line: + continue + elif line[0] == " ": + values += "\n%s" % line + elif line.startswith("Interface"): + if values: + parsed.append(values) + values = line + if values: + parsed.append(values) + return parsed + + def parse_neighbors(self, data): + facts = dict() + for item in data: + interface = self.parse_interface(item) + host = self.parse_host(item) + port = self.parse_port(item) + if interface not in facts: + facts[interface] = list() + facts[interface].append(dict(host=host, port=port)) + return facts + + def parse_interface(self, data): + match = re.search(r"^Interface:\s+(\S+),", data) + return match.group(1) + + def parse_host(self, data): + match = re.search(r"SysName:\s+(.+)$", data, re.M) + if match: + return match.group(1) + + def parse_port(self, data): + match = re.search(r"PortDescr:\s+(.+)$", data, re.M) + if match: + return match.group(1) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py new file mode 100644 index 00000000..3c7e2f93 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_global/lldp_global.py @@ -0,0 +1,116 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos lldp_global fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from re import findall, M +from copy import deepcopy + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_global.lldp_global import ( + Lldp_globalArgs, +) + + +class Lldp_globalFacts(object): + """ The vyos lldp_global fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Lldp_globalArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for lldp_global + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + objs = {} + lldp_output = findall(r"^set service lldp (\S+)", data, M) + if lldp_output: + for item in set(lldp_output): + lldp_regex = r" %s .+$" % item + cfg = findall(lldp_regex, data, M) + obj = self.render_config(cfg) + if obj: + objs.update(obj) + lldp_service = findall(r"^set service (lldp)?('lldp')", data, M) + if lldp_service or lldp_output: + lldp_obj = {} + lldp_obj["enable"] = True + objs.update(lldp_obj) + + facts = {} + params = utils.validate_config(self.argument_spec, {"config": objs}) + facts["lldp_global"] = utils.remove_empties(params["config"]) + + ansible_facts["ansible_network_resources"].update(facts) + + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + protocol_conf = "\n".join( + filter(lambda x: ("legacy-protocols" in x), conf) + ) + att_conf = "\n".join( + filter(lambda x: ("legacy-protocols" not in x), conf) + ) + config = self.parse_attribs(["snmp", "address"], att_conf) + config["legacy_protocols"] = self.parse_protocols(protocol_conf) + return utils.remove_empties(config) + + def parse_protocols(self, conf): + protocol_support = None + if conf: + protocols = findall(r"^.*legacy-protocols (.+)", conf, M) + if protocols: + protocol_support = [] + for protocol in protocols: + protocol_support.append(protocol.strip("'")) + return protocol_support + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value: + config[item] = value.strip("'") + else: + config[item] = None + return utils.remove_empties(config) diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py new file mode 100644 index 00000000..dcfbc6ee --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/lldp_interfaces/lldp_interfaces.py @@ -0,0 +1,155 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos lldp_interfaces fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + + +from re import findall, search, M +from copy import deepcopy + +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_interfaces.lldp_interfaces import ( + Lldp_interfacesArgs, +) + + +class Lldp_interfacesFacts(object): + """ The vyos lldp_interfaces fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Lldp_interfacesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for lldp_interfaces + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = connection.get_config() + + objs = [] + lldp_names = findall(r"^set service lldp interface (\S+)", data, M) + if lldp_names: + for lldp in set(lldp_names): + lldp_regex = r" %s .+$" % lldp + cfg = findall(lldp_regex, data, M) + obj = self.render_config(cfg) + obj["name"] = lldp.strip("'") + if obj: + objs.append(obj) + facts = {} + if objs: + facts["lldp_interfaces"] = objs + ansible_facts["ansible_network_resources"].update(facts) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + config = {} + location = {} + + civic_conf = "\n".join(filter(lambda x: ("civic-based" in x), conf)) + elin_conf = "\n".join(filter(lambda x: ("elin" in x), conf)) + coordinate_conf = "\n".join( + filter(lambda x: ("coordinate-based" in x), conf) + ) + disable = "\n".join(filter(lambda x: ("disable" in x), conf)) + + coordinate_based_conf = self.parse_attribs( + ["altitude", "datum", "longitude", "latitude"], coordinate_conf + ) + elin_based_conf = self.parse_lldp_elin_based(elin_conf) + civic_based_conf = self.parse_lldp_civic_based(civic_conf) + if disable: + config["enable"] = False + if coordinate_conf: + location["coordinate_based"] = coordinate_based_conf + config["location"] = location + elif civic_based_conf: + location["civic_based"] = civic_based_conf + config["location"] = location + elif elin_conf: + location["elin"] = elin_based_conf + config["location"] = location + + return utils.remove_empties(config) + + def parse_attribs(self, attribs, conf): + config = {} + for item in attribs: + value = utils.parse_conf_arg(conf, item) + if value: + value = value.strip("'") + if item == "altitude": + value = int(value) + config[item] = value + else: + config[item] = None + return utils.remove_empties(config) + + def parse_lldp_civic_based(self, conf): + civic_based = None + if conf: + civic_info_list = [] + civic_add_list = findall(r"^.*civic-based ca-type (.+)", conf, M) + if civic_add_list: + for civic_add in civic_add_list: + ca = civic_add.split(" ") + c_add = {} + c_add["ca_type"] = int(ca[0].strip("'")) + c_add["ca_value"] = ca[2].strip("'") + civic_info_list.append(c_add) + + country_code = search( + r"^.*civic-based country-code (.+)", conf, M + ) + civic_based = {} + civic_based["ca_info"] = civic_info_list + civic_based["country_code"] = country_code.group(1).strip("'") + return civic_based + + def parse_lldp_elin_based(self, conf): + elin_based = None + if conf: + e_num = search(r"^.* elin (.+)", conf, M) + elin_based = e_num.group(1).strip("'") + + return elin_based diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py new file mode 100644 index 00000000..00049475 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/facts/static_routes/static_routes.py @@ -0,0 +1,181 @@ +# +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The vyos static_routes fact class +It is in this file the configuration is collected from the device +for a given resource, parsed, and the facts tree is populated +based on the configuration. +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from re import findall, search, M +from copy import deepcopy +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common import ( + utils, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.static_routes.static_routes import ( + Static_routesArgs, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.utils.utils import ( + get_route_type, +) + + +class Static_routesFacts(object): + """ The vyos static_routes fact class + """ + + def __init__(self, module, subspec="config", options="options"): + self._module = module + self.argument_spec = Static_routesArgs.argument_spec + spec = deepcopy(self.argument_spec) + if subspec: + if options: + facts_argument_spec = spec[subspec][options] + else: + facts_argument_spec = spec[subspec] + else: + facts_argument_spec = spec + + self.generated_spec = utils.generate_dict(facts_argument_spec) + + def get_device_data(self, connection): + return connection.get_config() + + def populate_facts(self, connection, ansible_facts, data=None): + """ Populate the facts for static_routes + :param connection: the device connection + :param ansible_facts: Facts dictionary + :param data: previously collected conf + :rtype: dictionary + :returns: facts + """ + if not data: + data = self.get_device_data(connection) + # typically data is populated from the current device configuration + # data = connection.get('show running-config | section ^interface') + # using mock data instead + objs = [] + r_v4 = [] + r_v6 = [] + af = [] + static_routes = findall( + r"set protocols static route(6)? (\S+)", data, M + ) + if static_routes: + for route in set(static_routes): + route_regex = r" %s .+$" % route[1] + cfg = findall(route_regex, data, M) + sr = self.render_config(cfg) + sr["dest"] = route[1].strip("'") + afi = self.get_afi(sr["dest"]) + if afi == "ipv4": + r_v4.append(sr) + else: + r_v6.append(sr) + if r_v4: + afi_v4 = {"afi": "ipv4", "routes": r_v4} + af.append(afi_v4) + if r_v6: + afi_v6 = {"afi": "ipv6", "routes": r_v6} + af.append(afi_v6) + config = {"address_families": af} + if config: + objs.append(config) + + ansible_facts["ansible_network_resources"].pop("static_routes", None) + facts = {} + if objs: + facts["static_routes"] = [] + params = utils.validate_config( + self.argument_spec, {"config": objs} + ) + for cfg in params["config"]: + facts["static_routes"].append(utils.remove_empties(cfg)) + + ansible_facts["ansible_network_resources"].update(facts) + return ansible_facts + + def render_config(self, conf): + """ + Render config as dictionary structure and delete keys + from spec for null values + + :param spec: The facts tree, generated from the argspec + :param conf: The configuration + :rtype: dictionary + :returns: The generated config + """ + next_hops_conf = "\n".join(filter(lambda x: ("next-hop" in x), conf)) + blackhole_conf = "\n".join(filter(lambda x: ("blackhole" in x), conf)) + routes_dict = { + "blackhole_config": self.parse_blackhole(blackhole_conf), + "next_hops": self.parse_next_hop(next_hops_conf), + } + return routes_dict + + def parse_blackhole(self, conf): + blackhole = None + if conf: + distance = search(r"^.*blackhole distance (.\S+)", conf, M) + bh = conf.find("blackhole") + if distance is not None: + blackhole = {} + value = distance.group(1).strip("'") + blackhole["distance"] = int(value) + elif bh: + blackhole = {} + blackhole["type"] = "blackhole" + return blackhole + + def get_afi(self, address): + route_type = get_route_type(address) + if route_type == "route": + return "ipv4" + elif route_type == "route6": + return "ipv6" + + def parse_next_hop(self, conf): + nh_list = None + if conf: + nh_list = [] + hop_list = findall(r"^.*next-hop (.+)", conf, M) + if hop_list: + for hop in hop_list: + distance = search(r"^.*distance (.\S+)", hop, M) + interface = search(r"^.*interface (.\S+)", hop, M) + + dis = hop.find("disable") + hop_info = hop.split(" ") + nh_info = { + "forward_router_address": hop_info[0].strip("'") + } + if interface: + nh_info["interface"] = interface.group(1).strip("'") + if distance: + value = distance.group(1).strip("'") + nh_info["admin_distance"] = int(value) + elif dis >= 1: + nh_info["enabled"] = False + for element in nh_list: + if ( + element["forward_router_address"] + == nh_info["forward_router_address"] + ): + if "interface" in nh_info.keys(): + element["interface"] = nh_info["interface"] + if "admin_distance" in nh_info.keys(): + element["admin_distance"] = nh_info[ + "admin_distance" + ] + if "enabled" in nh_info.keys(): + element["enabled"] = nh_info["enabled"] + nh_info = None + if nh_info is not None: + nh_list.append(nh_info) + return nh_list diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py new file mode 100644 index 00000000..402adfc9 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/utils/utils.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# utils +from __future__ import absolute_import, division, print_function + +__metaclass__ = type +from ansible.module_utils.six import iteritems +from ansible_collections.ansible.netcommon.plugins.module_utils.compat import ( + ipaddress, +) + + +def search_obj_in_list(name, lst, key="name"): + for item in lst: + if item[key] == name: + return item + return None + + +def get_interface_type(interface): + """Gets the type of interface + """ + if interface.startswith("eth"): + return "ethernet" + elif interface.startswith("bond"): + return "bonding" + elif interface.startswith("vti"): + return "vti" + elif interface.startswith("lo"): + return "loopback" + + +def dict_delete(base, comparable): + """ + This function generates a dict containing key, value pairs for keys + that are present in the `base` dict but not present in the `comparable` + dict. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + :returns: new dict object with key, value pairs that needs to be deleted. + + """ + to_delete = dict() + + for key in base: + if isinstance(base[key], dict): + sub_diff = dict_delete(base[key], comparable.get(key, {})) + if sub_diff: + to_delete[key] = sub_diff + else: + if key not in comparable: + to_delete[key] = base[key] + + return to_delete + + +def diff_list_of_dicts(want, have): + diff = [] + + set_w = set(tuple(d.items()) for d in want) + set_h = set(tuple(d.items()) for d in have) + difference = set_w.difference(set_h) + + for element in difference: + diff.append(dict((x, y) for x, y in element)) + + return diff + + +def get_lst_diff_for_dicts(want, have, lst): + """ + This function generates a list containing values + that are only in want and not in list in have dict + :param want: dict object to want + :param have: dict object to have + :param lst: list the diff on + :return: new list object with values which are only in want. + """ + if not have: + diff = want.get(lst) or [] + + else: + want_elements = want.get(lst) or {} + have_elements = have.get(lst) or {} + diff = list_diff_want_only(want_elements, have_elements) + return diff + + +def get_lst_same_for_dicts(want, have, lst): + """ + This function generates a list containing values + that are common for list in want and list in have dict + :param want: dict object to want + :param have: dict object to have + :param lst: list the comparison on + :return: new list object with values which are common in want and have. + """ + diff = None + if want and have: + want_list = want.get(lst) or {} + have_list = have.get(lst) or {} + diff = [ + i + for i in want_list and have_list + if i in have_list and i in want_list + ] + return diff + + +def list_diff_have_only(want_list, have_list): + """ + This function generated the list containing values + that are only in have list. + :param want_list: + :param have_list: + :return: new list with values which are only in have list + """ + if have_list and not want_list: + diff = have_list + elif not have_list: + diff = None + else: + diff = [ + i + for i in have_list + want_list + if i in have_list and i not in want_list + ] + return diff + + +def list_diff_want_only(want_list, have_list): + """ + This function generated the list containing values + that are only in want list. + :param want_list: + :param have_list: + :return: new list with values which are only in want list + """ + if have_list and not want_list: + diff = None + elif not have_list: + diff = want_list + else: + diff = [ + i + for i in have_list + want_list + if i in want_list and i not in have_list + ] + return diff + + +def search_dict_tv_in_list(d_val1, d_val2, lst, key1, key2): + """ + This function return the dict object if it exist in list. + :param d_val1: + :param d_val2: + :param lst: + :param key1: + :param key2: + :return: + """ + obj = next( + ( + item + for item in lst + if item[key1] == d_val1 and item[key2] == d_val2 + ), + None, + ) + if obj: + return obj + else: + return None + + +def key_value_in_dict(have_key, have_value, want_dict): + """ + This function checks whether the key and values exist in dict + :param have_key: + :param have_value: + :param want_dict: + :return: + """ + for key, value in iteritems(want_dict): + if key == have_key and value == have_value: + return True + return False + + +def is_dict_element_present(dict, key): + """ + This function checks whether the key is present in dict. + :param dict: + :param key: + :return: + """ + for item in dict: + if item == key: + return True + return False + + +def get_ip_address_version(address): + """ + This function returns the version of IP address + :param address: IP address + :return: + """ + try: + address = unicode(address) + except NameError: + address = str(address) + version = ipaddress.ip_address(address.split("/")[0]).version + return version + + +def get_route_type(address): + """ + This function returns the route type based on IP address + :param address: + :return: + """ + version = get_ip_address_version(address) + if version == 6: + return "route6" + elif version == 4: + return "route" diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py new file mode 100644 index 00000000..908395a6 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/module_utils/network/vyos/vyos.py @@ -0,0 +1,124 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +import json + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import env_fallback +from ansible.module_utils.connection import Connection, ConnectionError + +_DEVICE_CONFIGS = {} + +vyos_provider_spec = { + "host": dict(), + "port": dict(type="int"), + "username": dict(fallback=(env_fallback, ["ANSIBLE_NET_USERNAME"])), + "password": dict( + fallback=(env_fallback, ["ANSIBLE_NET_PASSWORD"]), no_log=True + ), + "ssh_keyfile": dict( + fallback=(env_fallback, ["ANSIBLE_NET_SSH_KEYFILE"]), type="path" + ), + "timeout": dict(type="int"), +} +vyos_argument_spec = { + "provider": dict( + type="dict", options=vyos_provider_spec, removed_in_version=2.14 + ), +} + + +def get_provider_argspec(): + return vyos_provider_spec + + +def get_connection(module): + if hasattr(module, "_vyos_connection"): + return module._vyos_connection + + capabilities = get_capabilities(module) + network_api = capabilities.get("network_api") + if network_api == "cliconf": + module._vyos_connection = Connection(module._socket_path) + else: + module.fail_json(msg="Invalid connection type %s" % network_api) + + return module._vyos_connection + + +def get_capabilities(module): + if hasattr(module, "_vyos_capabilities"): + return module._vyos_capabilities + + try: + capabilities = Connection(module._socket_path).get_capabilities() + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + module._vyos_capabilities = json.loads(capabilities) + return module._vyos_capabilities + + +def get_config(module, flags=None, format=None): + flags = [] if flags is None else flags + global _DEVICE_CONFIGS + + if _DEVICE_CONFIGS != {}: + return _DEVICE_CONFIGS + else: + connection = get_connection(module) + try: + out = connection.get_config(flags=flags, format=format) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + cfg = to_text(out, errors="surrogate_then_replace").strip() + _DEVICE_CONFIGS = cfg + return cfg + + +def run_commands(module, commands, check_rc=True): + connection = get_connection(module) + try: + response = connection.run_commands( + commands=commands, check_rc=check_rc + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + return response + + +def load_config(module, commands, commit=False, comment=None): + connection = get_connection(module) + + try: + response = connection.edit_config( + candidate=commands, commit=commit, comment=comment + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + return response.get("diff") diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py new file mode 100644 index 00000000..18538491 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_command.py @@ -0,0 +1,223 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: vyos_command +author: Nathaniel Case (@Qalthos) +short_description: Run one or more commands on VyOS devices +description: +- The command module allows running one or more commands on remote devices running + VyOS. This module can also be introspected to validate key parameters before returning + successfully. If the conditional statements are not met in the wait period, the + task fails. +- Certain C(show) commands in VyOS produce many lines of output and use a custom pager + that can cause this module to hang. If the value of the environment variable C(ANSIBLE_VYOS_TERMINAL_LENGTH) + is not set, the default number of 10000 is used. +extends_documentation_fragment: +- vyos.vyos.vyos +options: + commands: + description: + - The ordered set of commands to execute on the remote device running VyOS. The + output from the command execution is returned to the playbook. If the I(wait_for) + argument is provided, the module is not returned until the condition is satisfied + or the number of retries has been exceeded. + required: true + wait_for: + description: + - Specifies what to evaluate from the output of the command and what conditionals + to apply. This argument will cause the task to wait for a particular conditional + to be true before moving forward. If the conditional is not true by the configured + I(retries), the task fails. See examples. + aliases: + - waitfor + match: + description: + - The I(match) argument is used in conjunction with the I(wait_for) argument to + specify the match policy. Valid values are C(all) or C(any). If the value is + set to C(all) then all conditionals in the wait_for must be satisfied. If the + value is set to C(any) then only one of the values must be satisfied. + default: all + choices: + - any + - all + retries: + description: + - Specifies the number of retries a command should be tried before it is considered + failed. The command is run on the target device every retry and evaluated against + the I(wait_for) conditionals. + default: 10 + interval: + description: + - Configures the interval in seconds to wait between I(retries) of the command. + If the command does not pass the specified conditions, the interval indicates + how long to wait before trying the command again. + default: 1 +notes: +- Tested against VyOS 1.1.8 (helium). +- Running C(show system boot-messages all) will cause the module to hang since VyOS + is using a custom pager setting to display the output of that command. +- If a command sent to the device requires answering a prompt, it is possible to pass + a dict containing I(command), I(answer) and I(prompt). See examples. +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +""" + +EXAMPLES = """ +tasks: + - name: show configuration on ethernet devices eth0 and eth1 + vyos_command: + commands: + - show interfaces ethernet {{ item }} + with_items: + - eth0 + - eth1 + + - name: run multiple commands and check if version output contains specific version string + vyos_command: + commands: + - show version + - show hardware cpu + wait_for: + - "result[0] contains 'VyOS 1.1.7'" + + - name: run command that requires answering a prompt + vyos_command: + commands: + - command: 'rollback 1' + prompt: 'Proceed with reboot? [confirm][y]' + answer: y +""" + +RETURN = """ +stdout: + description: The set of responses from the commands + returned: always apart from low level errors (such as action plugin) + type: list + sample: ['...', '...'] +stdout_lines: + description: The value of stdout split into a list + returned: always + type: list + sample: [['...', '...'], ['...'], ['...']] +failed_conditions: + description: The list of conditionals that have failed + returned: failed + type: list + sample: ['...', '...'] +warnings: + description: The list of warnings (if any) generated by module based on arguments + returned: always + type: list + sample: ['...', '...'] +""" +import time + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.parsing import ( + Conditional, +) +from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import ( + transform_commands, + to_lines, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + run_commands, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_argument_spec, +) + + +def parse_commands(module, warnings): + commands = transform_commands(module) + + if module.check_mode: + for item in list(commands): + if not item["command"].startswith("show"): + warnings.append( + "Only show commands are supported when using check mode, not " + "executing %s" % item["command"] + ) + commands.remove(item) + + return commands + + +def main(): + spec = dict( + commands=dict(type="list", required=True), + wait_for=dict(type="list", aliases=["waitfor"]), + match=dict(default="all", choices=["all", "any"]), + retries=dict(default=10, type="int"), + interval=dict(default=1, type="int"), + ) + + spec.update(vyos_argument_spec) + + module = AnsibleModule(argument_spec=spec, supports_check_mode=True) + + warnings = list() + result = {"changed": False, "warnings": warnings} + commands = parse_commands(module, warnings) + wait_for = module.params["wait_for"] or list() + + try: + conditionals = [Conditional(c) for c in wait_for] + except AttributeError as exc: + module.fail_json(msg=to_text(exc)) + + retries = module.params["retries"] + interval = module.params["interval"] + match = module.params["match"] + + for _ in range(retries): + responses = run_commands(module, commands) + + for item in list(conditionals): + if item(responses): + if match == "any": + conditionals = list() + break + conditionals.remove(item) + + if not conditionals: + break + + time.sleep(interval) + + if conditionals: + failed_conditions = [item.raw for item in conditionals] + msg = "One or more conditional statements have not been satisfied" + module.fail_json(msg=msg, failed_conditions=failed_conditions) + + result.update( + {"stdout": responses, "stdout_lines": list(to_lines(responses)),} + ) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py new file mode 100644 index 00000000..b899045a --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_config.py @@ -0,0 +1,354 @@ +#!/usr/bin/python +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: vyos_config +author: Nathaniel Case (@Qalthos) +short_description: Manage VyOS configuration on remote device +description: +- This module provides configuration file management of VyOS devices. It provides + arguments for managing both the configuration file and state of the active configuration. + All configuration statements are based on `set` and `delete` commands in the device + configuration. +extends_documentation_fragment: +- vyos.vyos.vyos +notes: +- Tested against VyOS 1.1.8 (helium). +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +options: + lines: + description: + - The ordered set of configuration lines to be managed and compared with the existing + configuration on the remote device. + src: + description: + - The C(src) argument specifies the path to the source config file to load. The + source config file can either be in bracket format or set format. The source + file can include Jinja2 template variables. + match: + description: + - The C(match) argument controls the method used to match against the current + active configuration. By default, the desired config is matched against the + active config and the deltas are loaded. If the C(match) argument is set to + C(none) the active configuration is ignored and the configuration is always + loaded. + default: line + choices: + - line + - none + backup: + description: + - The C(backup) argument will backup the current devices active configuration + to the Ansible control host prior to making any changes. If the C(backup_options) + value is not given, the backup file will be located in the backup folder in + the playbook root directory or role root directory, if playbook is part of an + ansible role. If the directory does not exist, it is created. + type: bool + default: 'no' + comment: + description: + - Allows a commit description to be specified to be included when the configuration + is committed. If the configuration is not changed or committed, this argument + is ignored. + default: configured by vyos_config + config: + description: + - The C(config) argument specifies the base configuration to use to compare against + the desired configuration. If this value is not specified, the module will + automatically retrieve the current active configuration from the remote device. + save: + description: + - The C(save) argument controls whether or not changes made to the active configuration + are saved to disk. This is independent of committing the config. When set + to True, the active configuration is saved. + type: bool + default: 'no' + backup_options: + description: + - This is a dict object containing configurable options related to backup file + path. The value of this option is read only when C(backup) is set to I(yes), + if C(backup) is set to I(no) this option will be silently ignored. + suboptions: + filename: + description: + - The filename to be used to store the backup configuration. If the filename + is not given it will be generated based on the hostname, current time and + date in format defined by <hostname>_config.<current-date>@<current-time> + dir_path: + description: + - This option provides the path ending with directory name in which the backup + configuration file will be stored. If the directory does not exist it will + be first created and the filename is either the value of C(filename) or + default filename as described in C(filename) options description. If the + path value is not given in that case a I(backup) directory will be created + in the current working directory and backup configuration will be copied + in C(filename) within I(backup) directory. + type: path + type: dict +""" + +EXAMPLES = """ +- name: configure the remote device + vyos_config: + lines: + - set system host-name {{ inventory_hostname }} + - set service lldp + - delete service dhcp-server + +- name: backup and load from file + vyos_config: + src: vyos.cfg + backup: yes + +- name: render a Jinja2 template onto the VyOS router + vyos_config: + src: vyos_template.j2 + +- name: for idempotency, use full-form commands + vyos_config: + lines: + # - set int eth eth2 description 'OUTSIDE' + - set interface ethernet eth2 description 'OUTSIDE' + +- name: configurable backup path + vyos_config: + backup: yes + backup_options: + filename: backup.cfg + dir_path: /home/user +""" + +RETURN = """ +commands: + description: The list of configuration commands sent to the device + returned: always + type: list + sample: ['...', '...'] +filtered: + description: The list of configuration commands removed to avoid a load failure + returned: always + type: list + sample: ['...', '...'] +backup_path: + description: The full path to the backup file + returned: when backup is yes + type: str + sample: /playbooks/ansible/backup/vyos_config.2016-07-16@22:28:34 +filename: + description: The name of the backup file + returned: when backup is yes and filename is not specified in backup options + type: str + sample: vyos_config.2016-07-16@22:28:34 +shortname: + description: The full path to the backup file excluding the timestamp + returned: when backup is yes and filename is not specified in backup options + type: str + sample: /playbooks/ansible/backup/vyos_config +date: + description: The date extracted from the backup file name + returned: when backup is yes + type: str + sample: "2016-07-16" +time: + description: The time extracted from the backup file name + returned: when backup is yes + type: str + sample: "22:28:34" +""" +import re + +from ansible.module_utils._text import to_text +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.connection import ConnectionError +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + load_config, + get_config, + run_commands, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_argument_spec, + get_connection, +) + + +DEFAULT_COMMENT = "configured by vyos_config" + +CONFIG_FILTERS = [ + re.compile(r"set system login user \S+ authentication encrypted-password") +] + + +def get_candidate(module): + contents = module.params["src"] or module.params["lines"] + + if module.params["src"]: + contents = format_commands(contents.splitlines()) + + contents = "\n".join(contents) + return contents + + +def format_commands(commands): + """ + This function format the input commands and removes the prepend white spaces + for command lines having 'set' or 'delete' and it skips empty lines. + :param commands: + :return: list of commands + """ + return [ + line.strip() if line.split()[0] in ("set", "delete") else line + for line in commands + if len(line.strip()) > 0 + ] + + +def diff_config(commands, config): + config = [str(c).replace("'", "") for c in config.splitlines()] + + updates = list() + visited = set() + + for line in commands: + item = str(line).replace("'", "") + + if not item.startswith("set") and not item.startswith("delete"): + raise ValueError("line must start with either `set` or `delete`") + + elif item.startswith("set") and item not in config: + updates.append(line) + + elif item.startswith("delete"): + if not config: + updates.append(line) + else: + item = re.sub(r"delete", "set", item) + for entry in config: + if entry.startswith(item) and line not in visited: + updates.append(line) + visited.add(line) + + return list(updates) + + +def sanitize_config(config, result): + result["filtered"] = list() + index_to_filter = list() + for regex in CONFIG_FILTERS: + for index, line in enumerate(list(config)): + if regex.search(line): + result["filtered"].append(line) + index_to_filter.append(index) + # Delete all filtered configs + for filter_index in sorted(index_to_filter, reverse=True): + del config[filter_index] + + +def run(module, result): + # get the current active config from the node or passed in via + # the config param + config = module.params["config"] or get_config(module) + + # create the candidate config object from the arguments + candidate = get_candidate(module) + + # create loadable config that includes only the configuration updates + connection = get_connection(module) + try: + response = connection.get_diff( + candidate=candidate, + running=config, + diff_match=module.params["match"], + ) + except ConnectionError as exc: + module.fail_json(msg=to_text(exc, errors="surrogate_then_replace")) + + commands = response.get("config_diff") + sanitize_config(commands, result) + + result["commands"] = commands + + commit = not module.check_mode + comment = module.params["comment"] + + diff = None + if commands: + diff = load_config(module, commands, commit=commit, comment=comment) + + if result.get("filtered"): + result["warnings"].append( + "Some configuration commands were " + "removed, please see the filtered key" + ) + + result["changed"] = True + + if module._diff: + result["diff"] = {"prepared": diff} + + +def main(): + backup_spec = dict(filename=dict(), dir_path=dict(type="path")) + argument_spec = dict( + src=dict(type="path"), + lines=dict(type="list"), + match=dict(default="line", choices=["line", "none"]), + comment=dict(default=DEFAULT_COMMENT), + config=dict(), + backup=dict(type="bool", default=False), + backup_options=dict(type="dict", options=backup_spec), + save=dict(type="bool", default=False), + ) + + argument_spec.update(vyos_argument_spec) + + mutually_exclusive = [("lines", "src")] + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=mutually_exclusive, + supports_check_mode=True, + ) + + warnings = list() + + result = dict(changed=False, warnings=warnings) + + if module.params["backup"]: + result["__backup__"] = get_config(module=module) + + if any((module.params["src"], module.params["lines"])): + run(module, result) + + if module.params["save"]: + diff = run_commands(module, commands=["configure", "compare saved"])[1] + if diff != "[edit]": + run_commands(module, commands=["save"]) + result["changed"] = True + run_commands(module, commands=["exit"]) + + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py new file mode 100644 index 00000000..19fb727f --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_facts.py @@ -0,0 +1,174 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +""" +The module file for vyos_facts +""" + + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": [u"preview"], + "supported_by": "network", +} + + +DOCUMENTATION = """module: vyos_facts +short_description: Get facts about vyos devices. +description: +- Collects facts from network devices running the vyos operating system. This module + places the facts gathered in the fact tree keyed by the respective resource name. The + facts module will always collect a base set of facts from the device and can enable + or disable collection of additional facts. +author: +- Nathaniel Case (@qalthos) +- Nilashish Chakraborty (@Nilashishc) +- Rohit Thakur (@rohitthakur2590) +extends_documentation_fragment: +- vyos.vyos.vyos +notes: +- Tested against VyOS 1.1.8 (helium). +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +options: + gather_subset: + description: + - When supplied, this argument will restrict the facts collected to a given subset. Possible + values for this argument include all, default, config, and neighbors. Can specify + a list of values to include a larger subset. Values can also be used with an + initial C(M(!)) to specify that a specific subset should not be collected. + required: false + default: '!config' + gather_network_resources: + description: + - When supplied, this argument will restrict the facts collected to a given subset. + Possible values for this argument include all and the resources like interfaces. + Can specify a list of values to include a larger subset. Values can also be + used with an initial C(M(!)) to specify that a specific subset should not be + collected. Valid subsets are 'all', 'interfaces', 'l3_interfaces', 'lag_interfaces', + 'lldp_global', 'lldp_interfaces', 'static_routes', 'firewall_rules'. + required: false +""" + +EXAMPLES = """ +# Gather all facts +- vyos_facts: + gather_subset: all + gather_network_resources: all + +# collect only the config and default facts +- vyos_facts: + gather_subset: config + +# collect everything exception the config +- vyos_facts: + gather_subset: "!config" + +# Collect only the interfaces facts +- vyos_facts: + gather_subset: + - '!all' + - '!min' + gather_network_resources: + - interfaces + +# Do not collect interfaces facts +- vyos_facts: + gather_network_resources: + - "!interfaces" + +# Collect interfaces and minimal default facts +- vyos_facts: + gather_subset: min + gather_network_resources: interfaces +""" + +RETURN = """ +ansible_net_config: + description: The running-config from the device + returned: when config is configured + type: str +ansible_net_commits: + description: The set of available configuration revisions + returned: when present + type: list +ansible_net_hostname: + description: The configured system hostname + returned: always + type: str +ansible_net_model: + description: The device model string + returned: always + type: str +ansible_net_serialnum: + description: The serial number of the device + returned: always + type: str +ansible_net_version: + description: The version of the software running + returned: always + type: str +ansible_net_neighbors: + description: The set of LLDP neighbors + returned: when interface is configured + type: list +ansible_net_gather_subset: + description: The list of subsets gathered by the module + returned: always + type: list +ansible_net_api: + description: The name of the transport + returned: always + type: str +ansible_net_python_version: + description: The Python version Ansible controller is using + returned: always + type: str +ansible_net_gather_network_resources: + description: The list of fact resource subsets collected from the device + returned: always + type: list +""" + +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.facts.facts import ( + FactsArgs, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.facts.facts import ( + Facts, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.vyos import ( + vyos_argument_spec, +) + + +def main(): + """ + Main entry point for module execution + + :returns: ansible_facts + """ + argument_spec = FactsArgs.argument_spec + argument_spec.update(vyos_argument_spec) + + module = AnsibleModule( + argument_spec=argument_spec, supports_check_mode=True + ) + + warnings = [] + if module.params["gather_subset"] == "!config": + warnings.append( + "default value for `gather_subset` will be changed to `min` from `!config` v2.11 onwards" + ) + + result = Facts(module).get_facts() + + ansible_facts, additional_warnings = result + warnings.extend(additional_warnings) + + module.exit_json(ansible_facts=ansible_facts, warnings=warnings) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py new file mode 100644 index 00000000..8fe572b0 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/modules/vyos_lldp_interfaces.py @@ -0,0 +1,513 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# Copyright 2019 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +############################################# +# WARNING # +############################################# +# +# This file is auto generated by the resource +# module builder playbook. +# +# Do not edit this file manually. +# +# Changes to this file will be over written +# by the resource module builder. +# +# Changes should be made in the model used to +# generate this file or in the resource module +# builder template. +# +############################################# + +""" +The module file for vyos_lldp_interfaces +""" + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +ANSIBLE_METADATA = { + "metadata_version": "1.1", + "status": ["preview"], + "supported_by": "network", +} + +DOCUMENTATION = """module: vyos_lldp_interfaces +short_description: Manages attributes of lldp interfaces on VyOS devices. +description: This module manages attributes of lldp interfaces on VyOS network devices. +notes: +- Tested against VyOS 1.1.8 (helium). +- This module works with connection C(network_cli). See L(the VyOS OS Platform Options,../network/user_guide/platform_vyos.html). +author: +- Rohit Thakur (@rohitthakur2590) +options: + config: + description: A list of lldp interfaces configurations. + type: list + suboptions: + name: + description: + - Name of the lldp interface. + type: str + required: true + enable: + description: + - to disable lldp on the interface. + type: bool + default: true + location: + description: + - LLDP-MED location data. + type: dict + suboptions: + civic_based: + description: + - Civic-based location data. + type: dict + suboptions: + ca_info: + description: LLDP-MED address info + type: list + suboptions: + ca_type: + description: LLDP-MED Civic Address type. + type: int + required: true + ca_value: + description: LLDP-MED Civic Address value. + type: str + required: true + country_code: + description: Country Code + type: str + required: true + coordinate_based: + description: + - Coordinate-based location. + type: dict + suboptions: + altitude: + description: Altitude in meters. + type: int + datum: + description: Coordinate datum type. + type: str + choices: + - WGS84 + - NAD83 + - MLLW + latitude: + description: Latitude. + type: str + required: true + longitude: + description: Longitude. + type: str + required: true + elin: + description: Emergency Call Service ELIN number (between 10-25 numbers). + type: str + state: + description: + - The state of the configuration after module completion. + type: str + choices: + - merged + - replaced + - overridden + - deleted + default: merged +""" +EXAMPLES = """ +# Using merged +# +# Before state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# +- name: Merge provided configuration with device configuration + vyos_lldp_interfaces: + config: + - name: 'eth1' + location: + civic_based: + country_code: 'US' + ca_info: + - ca_type: 0 + ca_value: 'ENGLISH' + + - name: 'eth2' + location: + coordinate_based: + altitude: 2200 + datum: 'WGS84' + longitude: '222.267255W' + latitude: '33.524449N' + state: merged +# +# +# ------------------------- +# Module Execution Result +# ------------------------- +# +# before": [] +# +# "commands": [ +# "set service lldp interface eth1 location civic-based country-code 'US'", +# "set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH'", +# "set service lldp interface eth1", +# "set service lldp interface eth2 location coordinate-based latitude '33.524449N'", +# "set service lldp interface eth2 location coordinate-based altitude '2200'", +# "set service lldp interface eth2 location coordinate-based datum 'WGS84'", +# "set service lldp interface eth2 location coordinate-based longitude '222.267255W'", +# "set service lldp interface eth2 location coordinate-based latitude '33.524449N'", +# "set service lldp interface eth2 location coordinate-based altitude '2200'", +# "set service lldp interface eth2 location coordinate-based datum 'WGS84'", +# "set service lldp interface eth2 location coordinate-based longitude '222.267255W'", +# "set service lldp interface eth2" +# +# "after": [ +# { +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth2" +# }, +# { +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth1" +# } +# ], +# +# After state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth1 location civic-based country-code 'US' +# set service lldp interface eth2 location coordinate-based altitude '2200' +# set service lldp interface eth2 location coordinate-based datum 'WGS84' +# set service lldp interface eth2 location coordinate-based latitude '33.524449N' +# set service lldp interface eth2 location coordinate-based longitude '222.267255W' + + +# Using replaced +# +# Before state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth1 location civic-based country-code 'US' +# set service lldp interface eth2 location coordinate-based altitude '2200' +# set service lldp interface eth2 location coordinate-based datum 'WGS84' +# set service lldp interface eth2 location coordinate-based latitude '33.524449N' +# set service lldp interface eth2 location coordinate-based longitude '222.267255W' +# +- name: Replace device configurations of listed LLDP interfaces with provided configurations + vyos_lldp_interfaces: + config: + - name: 'eth2' + location: + civic_based: + country_code: 'US' + ca_info: + - ca_type: 0 + ca_value: 'ENGLISH' + + - name: 'eth1' + location: + coordinate_based: + altitude: 2200 + datum: 'WGS84' + longitude: '222.267255W' + latitude: '33.524449N' + state: replaced +# +# +# ------------------------- +# Module Execution Result +# ------------------------- +# +# "before": [ +# { +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth2" +# }, +# { +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth1" +# } +# ] +# +# "commands": [ +# "delete service lldp interface eth2 location", +# "set service lldp interface eth2 'disable'", +# "set service lldp interface eth2 location civic-based country-code 'US'", +# "set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH'", +# "delete service lldp interface eth1 location", +# "set service lldp interface eth1 'disable'", +# "set service lldp interface eth1 location coordinate-based latitude '33.524449N'", +# "set service lldp interface eth1 location coordinate-based altitude '2200'", +# "set service lldp interface eth1 location coordinate-based datum 'WGS84'", +# "set service lldp interface eth1 location coordinate-based longitude '222.267255W'" +# ] +# +# "after": [ +# { +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth2" +# }, +# { +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth1" +# } +# ] +# +# After state: +# ------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 'disable' +# set service lldp interface eth1 location coordinate-based altitude '2200' +# set service lldp interface eth1 location coordinate-based datum 'WGS84' +# set service lldp interface eth1 location coordinate-based latitude '33.524449N' +# set service lldp interface eth1 location coordinate-based longitude '222.267255W' +# set service lldp interface eth2 'disable' +# set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth2 location civic-based country-code 'US' + + +# Using overridden +# +# Before state +# -------------- +# +# vyos@vyos:~$ show configuration commands | grep lldp +# set service lldp interface eth1 'disable' +# set service lldp interface eth1 location coordinate-based altitude '2200' +# set service lldp interface eth1 location coordinate-based datum 'WGS84' +# set service lldp interface eth1 location coordinate-based latitude '33.524449N' +# set service lldp interface eth1 location coordinate-based longitude '222.267255W' +# set service lldp interface eth2 'disable' +# set service lldp interface eth2 location civic-based ca-type 0 ca-value 'ENGLISH' +# set service lldp interface eth2 location civic-based country-code 'US' +# +- name: Overrides all device configuration with provided configuration + vyos_lag_interfaces: + config: + - name: 'eth2' + location: + elin: 0000000911 + + state: overridden +# +# +# ------------------------- +# Module Execution Result +# ------------------------- +# +# "before": [ +# { +# "enable": false, +# "location": { +# "civic_based": { +# "ca_info": [ +# { +# "ca_type": 0, +# "ca_value": "ENGLISH" +# } +# ], +# "country_code": "US" +# } +# }, +# "name": "eth2" +# }, +# { +# "enable": false, +# "location": { +# "coordinate_based": { +# "altitude": 2200, +# "datum": "WGS84", +# "latitude": "33.524449N", +# "longitude": "222.267255W" +# } +# }, +# "name": "eth1" +# } +# ] +# +# "commands": [ +# "delete service lldp interface eth2 location", +# "delete service lldp interface eth2 disable", +# "set service lldp interface eth2 location elin 0000000911" +# +# +# "after": [ +# { +# "location": { +# "elin": 0000000911 +# }, +# "name": "eth2" +# } +# ] +# +# +# After state +# ------------ +# +# vyos@vyos# run show configuration commands | grep lldp +# set service lldp interface eth2 location elin '0000000911' + + +# Using deleted +# +# Before state +# ------------- +# +# vyos@vyos# run show configuration commands | grep lldp +# set service lldp interface eth2 location elin '0000000911' +# +- name: Delete lldp interface attributes of given interfaces. + vyos_lag_interfaces: + config: + - name: 'eth2' + state: deleted +# +# +# ------------------------ +# Module Execution Results +# ------------------------ +# + "before": [ + { + "location": { + "elin": 0000000911 + }, + "name": "eth2" + } + ] +# "commands": [ +# "commands": [ +# "delete service lldp interface eth2" +# ] +# +# "after": [] +# After state +# ------------ +# vyos@vyos# run show configuration commands | grep lldp +# set service 'lldp' + + +""" +RETURN = """ +before: + description: The configuration as structured data prior to module invocation. + returned: always + type: list + sample: > + The configuration returned will always be in the same format + of the parameters above. +after: + description: The configuration as structured data after module completion. + returned: when changed + type: list + sample: > + The configuration returned will always be in the same format + of the parameters above. +commands: + description: The set of commands pushed to the remote device. + returned: always + type: list + sample: + - "set service lldp interface eth2 'disable'" + - "delete service lldp interface eth1 location" +""" + + +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.argspec.lldp_interfaces.lldp_interfaces import ( + Lldp_interfacesArgs, +) +from ansible_collections.vyos.vyos.plugins.module_utils.network.vyos.config.lldp_interfaces.lldp_interfaces import ( + Lldp_interfaces, +) + + +def main(): + """ + Main entry point for module execution + + :returns: the result form module invocation + """ + required_if = [ + ("state", "merged", ("config",)), + ("state", "replaced", ("config",)), + ("state", "overridden", ("config",)), + ] + module = AnsibleModule( + argument_spec=Lldp_interfacesArgs.argument_spec, + required_if=required_if, + supports_check_mode=True, + ) + + result = Lldp_interfaces(module).execute_module() + module.exit_json(**result) + + +if __name__ == "__main__": + main() diff --git a/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py new file mode 100644 index 00000000..fe7712f6 --- /dev/null +++ b/test/support/network-integration/collections/ansible_collections/vyos/vyos/plugins/terminal/vyos.py @@ -0,0 +1,53 @@ +# +# (c) 2016 Red Hat Inc. +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import os +import re + +from ansible.plugins.terminal import TerminalBase +from ansible.errors import AnsibleConnectionFailure + + +class TerminalModule(TerminalBase): + + terminal_stdout_re = [ + re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"), + re.compile(br"\@[\w\-\.]+:\S+?[>#\$] ?$"), + ] + + terminal_stderr_re = [ + re.compile(br"\n\s*Invalid command:"), + re.compile(br"\nCommit failed"), + re.compile(br"\n\s+Set failed"), + ] + + terminal_length = os.getenv("ANSIBLE_VYOS_TERMINAL_LENGTH", 10000) + + def on_open_shell(self): + try: + for cmd in (b"set terminal length 0", b"set terminal width 512"): + self._exec_cli_command(cmd) + self._exec_cli_command( + b"set terminal length %d" % self.terminal_length + ) + except AnsibleConnectionFailure: + raise AnsibleConnectionFailure("unable to set terminal parameters") diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py new file mode 120000 index 00000000..0364d766 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/action/win_copy.py @@ -0,0 +1 @@ +../../../../../../plugins/action/win_copy.py
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1 new file mode 120000 index 00000000..6fc438d6 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/async_status.ps1 @@ -0,0 +1 @@ +../../../../../../plugins/modules/async_status.ps1
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1 new file mode 120000 index 00000000..81d8afa3 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.ps1 @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_acl.ps1
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py new file mode 120000 index 00000000..3a2434cf --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_acl.py @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_acl.py
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1 new file mode 120000 index 00000000..a34fb012 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.ps1 @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_copy.ps1
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py new file mode 120000 index 00000000..2d2c69a2 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_copy.py @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_copy.py
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1 new file mode 120000 index 00000000..8ee5c2b5 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.ps1 @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_file.ps1
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py new file mode 120000 index 00000000..b4bc0583 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_file.py @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_file.py
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1 new file mode 120000 index 00000000..d7b25ed0 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.ps1 @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_ping.ps1
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py new file mode 120000 index 00000000..0b97c87b --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_ping.py @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_ping.py
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1 new file mode 120000 index 00000000..eb07a017 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.ps1 @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_shell.ps1
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py new file mode 120000 index 00000000..3c6f0749 --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_shell.py @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_shell.py
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1 b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1 new file mode 120000 index 00000000..62a7a40a --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.ps1 @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_stat.ps1
\ No newline at end of file diff --git a/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py new file mode 120000 index 00000000..1db4c95e --- /dev/null +++ b/test/support/windows-integration/collections/ansible_collections/ansible/windows/plugins/modules/win_stat.py @@ -0,0 +1 @@ +../../../../../../plugins/modules/win_stat.py
\ No newline at end of file diff --git a/test/support/windows-integration/plugins/action/win_copy.py b/test/support/windows-integration/plugins/action/win_copy.py new file mode 100644 index 00000000..adb918be --- /dev/null +++ b/test/support/windows-integration/plugins/action/win_copy.py @@ -0,0 +1,522 @@ +# This file is part of Ansible + +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import base64 +import json +import os +import os.path +import shutil +import tempfile +import traceback +import zipfile + +from ansible import constants as C +from ansible.errors import AnsibleError, AnsibleFileNotFound +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.parsing.convert_bool import boolean +from ansible.plugins.action import ActionBase +from ansible.utils.hashing import checksum + + +def _walk_dirs(topdir, loader, decrypt=True, base_path=None, local_follow=False, trailing_slash_detector=None, checksum_check=False): + """ + Walk a filesystem tree returning enough information to copy the files. + This is similar to the _walk_dirs function in ``copy.py`` but returns + a dict instead of a tuple for each entry and includes the checksum of + a local file if wanted. + + :arg topdir: The directory that the filesystem tree is rooted at + :arg loader: The self._loader object from ActionBase + :kwarg decrypt: Whether to decrypt a file encrypted with ansible-vault + :kwarg base_path: The initial directory structure to strip off of the + files for the destination directory. If this is None (the default), + the base_path is set to ``top_dir``. + :kwarg local_follow: Whether to follow symlinks on the source. When set + to False, no symlinks are dereferenced. When set to True (the + default), the code will dereference most symlinks. However, symlinks + can still be present if needed to break a circular link. + :kwarg trailing_slash_detector: Function to determine if a path has + a trailing directory separator. Only needed when dealing with paths on + a remote machine (in which case, pass in a function that is aware of the + directory separator conventions on the remote machine). + :kawrg whether to get the checksum of the local file and add to the dict + :returns: dictionary of dictionaries. All of the path elements in the structure are text string. + This separates all the files, directories, and symlinks along with + import information about each:: + + { + 'files'; [{ + src: '/absolute/path/to/copy/from', + dest: 'relative/path/to/copy/to', + checksum: 'b54ba7f5621240d403f06815f7246006ef8c7d43' + }, ...], + 'directories'; [{ + src: '/absolute/path/to/copy/from', + dest: 'relative/path/to/copy/to' + }, ...], + 'symlinks'; [{ + src: '/symlink/target/path', + dest: 'relative/path/to/copy/to' + }, ...], + + } + + The ``symlinks`` field is only populated if ``local_follow`` is set to False + *or* a circular symlink cannot be dereferenced. The ``checksum`` entry is set + to None if checksum_check=False. + + """ + # Convert the path segments into byte strings + + r_files = {'files': [], 'directories': [], 'symlinks': []} + + def _recurse(topdir, rel_offset, parent_dirs, rel_base=u'', checksum_check=False): + """ + This is a closure (function utilizing variables from it's parent + function's scope) so that we only need one copy of all the containers. + Note that this function uses side effects (See the Variables used from + outer scope). + + :arg topdir: The directory we are walking for files + :arg rel_offset: Integer defining how many characters to strip off of + the beginning of a path + :arg parent_dirs: Directories that we're copying that this directory is in. + :kwarg rel_base: String to prepend to the path after ``rel_offset`` is + applied to form the relative path. + + Variables used from the outer scope + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + :r_files: Dictionary of files in the hierarchy. See the return value + for :func:`walk` for the structure of this dictionary. + :local_follow: Read-only inside of :func:`_recurse`. Whether to follow symlinks + """ + for base_path, sub_folders, files in os.walk(topdir): + for filename in files: + filepath = os.path.join(base_path, filename) + dest_filepath = os.path.join(rel_base, filepath[rel_offset:]) + + if os.path.islink(filepath): + # Dereference the symlnk + real_file = loader.get_real_file(os.path.realpath(filepath), decrypt=decrypt) + if local_follow and os.path.isfile(real_file): + # Add the file pointed to by the symlink + r_files['files'].append( + { + "src": real_file, + "dest": dest_filepath, + "checksum": _get_local_checksum(checksum_check, real_file) + } + ) + else: + # Mark this file as a symlink to copy + r_files['symlinks'].append({"src": os.readlink(filepath), "dest": dest_filepath}) + else: + # Just a normal file + real_file = loader.get_real_file(filepath, decrypt=decrypt) + r_files['files'].append( + { + "src": real_file, + "dest": dest_filepath, + "checksum": _get_local_checksum(checksum_check, real_file) + } + ) + + for dirname in sub_folders: + dirpath = os.path.join(base_path, dirname) + dest_dirpath = os.path.join(rel_base, dirpath[rel_offset:]) + real_dir = os.path.realpath(dirpath) + dir_stats = os.stat(real_dir) + + if os.path.islink(dirpath): + if local_follow: + if (dir_stats.st_dev, dir_stats.st_ino) in parent_dirs: + # Just insert the symlink if the target directory + # exists inside of the copy already + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Walk the dirpath to find all parent directories. + new_parents = set() + parent_dir_list = os.path.dirname(dirpath).split(os.path.sep) + for parent in range(len(parent_dir_list), 0, -1): + parent_stat = os.stat(u'/'.join(parent_dir_list[:parent])) + if (parent_stat.st_dev, parent_stat.st_ino) in parent_dirs: + # Reached the point at which the directory + # tree is already known. Don't add any + # more or we might go to an ancestor that + # isn't being copied. + break + new_parents.add((parent_stat.st_dev, parent_stat.st_ino)) + + if (dir_stats.st_dev, dir_stats.st_ino) in new_parents: + # This was a a circular symlink. So add it as + # a symlink + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Walk the directory pointed to by the symlink + r_files['directories'].append({"src": real_dir, "dest": dest_dirpath}) + offset = len(real_dir) + 1 + _recurse(real_dir, offset, parent_dirs.union(new_parents), + rel_base=dest_dirpath, + checksum_check=checksum_check) + else: + # Add the symlink to the destination + r_files['symlinks'].append({"src": os.readlink(dirpath), "dest": dest_dirpath}) + else: + # Just a normal directory + r_files['directories'].append({"src": dirpath, "dest": dest_dirpath}) + + # Check if the source ends with a "/" so that we know which directory + # level to work at (similar to rsync) + source_trailing_slash = False + if trailing_slash_detector: + source_trailing_slash = trailing_slash_detector(topdir) + else: + source_trailing_slash = topdir.endswith(os.path.sep) + + # Calculate the offset needed to strip the base_path to make relative + # paths + if base_path is None: + base_path = topdir + if not source_trailing_slash: + base_path = os.path.dirname(base_path) + if topdir.startswith(base_path): + offset = len(base_path) + + # Make sure we're making the new paths relative + if trailing_slash_detector and not trailing_slash_detector(base_path): + offset += 1 + elif not base_path.endswith(os.path.sep): + offset += 1 + + if os.path.islink(topdir) and not local_follow: + r_files['symlinks'] = {"src": os.readlink(topdir), "dest": os.path.basename(topdir)} + return r_files + + dir_stats = os.stat(topdir) + parents = frozenset(((dir_stats.st_dev, dir_stats.st_ino),)) + # Actually walk the directory hierarchy + _recurse(topdir, offset, parents, checksum_check=checksum_check) + + return r_files + + +def _get_local_checksum(get_checksum, local_path): + if get_checksum: + return checksum(local_path) + else: + return None + + +class ActionModule(ActionBase): + + WIN_PATH_SEPARATOR = "\\" + + def _create_content_tempfile(self, content): + ''' Create a tempfile containing defined content ''' + fd, content_tempfile = tempfile.mkstemp(dir=C.DEFAULT_LOCAL_TMP) + f = os.fdopen(fd, 'wb') + content = to_bytes(content) + try: + f.write(content) + except Exception as err: + os.remove(content_tempfile) + raise Exception(err) + finally: + f.close() + return content_tempfile + + def _create_zip_tempfile(self, files, directories): + tmpdir = tempfile.mkdtemp(dir=C.DEFAULT_LOCAL_TMP) + zip_file_path = os.path.join(tmpdir, "win_copy.zip") + zip_file = zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_STORED, True) + + # encoding the file/dir name with base64 so Windows can unzip a unicode + # filename and get the right name, Windows doesn't handle unicode names + # very well + for directory in directories: + directory_path = to_bytes(directory['src'], errors='surrogate_or_strict') + archive_path = to_bytes(directory['dest'], errors='surrogate_or_strict') + + encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict') + zip_file.write(directory_path, encoded_path, zipfile.ZIP_DEFLATED) + + for file in files: + file_path = to_bytes(file['src'], errors='surrogate_or_strict') + archive_path = to_bytes(file['dest'], errors='surrogate_or_strict') + + encoded_path = to_text(base64.b64encode(archive_path), errors='surrogate_or_strict') + zip_file.write(file_path, encoded_path, zipfile.ZIP_DEFLATED) + + return zip_file_path + + def _remove_tempfile_if_content_defined(self, content, content_tempfile): + if content is not None: + os.remove(content_tempfile) + + def _copy_single_file(self, local_file, dest, source_rel, task_vars, tmp, backup): + if self._play_context.check_mode: + module_return = dict(changed=True) + return module_return + + # copy the file across to the server + tmp_src = self._connection._shell.join_path(tmp, 'source') + self._transfer_file(local_file, tmp_src) + + copy_args = self._task.args.copy() + copy_args.update( + dict( + dest=dest, + src=tmp_src, + _original_basename=source_rel, + _copy_mode="single", + backup=backup, + ) + ) + copy_args.pop('content', None) + + copy_result = self._execute_module(module_name="copy", + module_args=copy_args, + task_vars=task_vars) + + return copy_result + + def _copy_zip_file(self, dest, files, directories, task_vars, tmp, backup): + # create local zip file containing all the files and directories that + # need to be copied to the server + if self._play_context.check_mode: + module_return = dict(changed=True) + return module_return + + try: + zip_file = self._create_zip_tempfile(files, directories) + except Exception as e: + module_return = dict( + changed=False, + failed=True, + msg="failed to create tmp zip file: %s" % to_text(e), + exception=traceback.format_exc() + ) + return module_return + + zip_path = self._loader.get_real_file(zip_file) + + # send zip file to remote, file must end in .zip so + # Com Shell.Application works + tmp_src = self._connection._shell.join_path(tmp, 'source.zip') + self._transfer_file(zip_path, tmp_src) + + # run the explode operation of win_copy on remote + copy_args = self._task.args.copy() + copy_args.update( + dict( + src=tmp_src, + dest=dest, + _copy_mode="explode", + backup=backup, + ) + ) + copy_args.pop('content', None) + module_return = self._execute_module(module_name='copy', + module_args=copy_args, + task_vars=task_vars) + shutil.rmtree(os.path.dirname(zip_path)) + return module_return + + def run(self, tmp=None, task_vars=None): + ''' handler for file transfer operations ''' + if task_vars is None: + task_vars = dict() + + result = super(ActionModule, self).run(tmp, task_vars) + del tmp # tmp no longer has any effect + + source = self._task.args.get('src', None) + content = self._task.args.get('content', None) + dest = self._task.args.get('dest', None) + remote_src = boolean(self._task.args.get('remote_src', False), strict=False) + local_follow = boolean(self._task.args.get('local_follow', False), strict=False) + force = boolean(self._task.args.get('force', True), strict=False) + decrypt = boolean(self._task.args.get('decrypt', True), strict=False) + backup = boolean(self._task.args.get('backup', False), strict=False) + + result['src'] = source + result['dest'] = dest + + result['failed'] = True + if (source is None and content is None) or dest is None: + result['msg'] = "src (or content) and dest are required" + elif source is not None and content is not None: + result['msg'] = "src and content are mutually exclusive" + elif content is not None and dest is not None and ( + dest.endswith(os.path.sep) or dest.endswith(self.WIN_PATH_SEPARATOR)): + result['msg'] = "dest must be a file if content is defined" + else: + del result['failed'] + + if result.get('failed'): + return result + + # If content is defined make a temp file and write the content into it + content_tempfile = None + if content is not None: + try: + # if content comes to us as a dict it should be decoded json. + # We need to encode it back into a string and write it out + if isinstance(content, dict) or isinstance(content, list): + content_tempfile = self._create_content_tempfile(json.dumps(content)) + else: + content_tempfile = self._create_content_tempfile(content) + source = content_tempfile + except Exception as err: + result['failed'] = True + result['msg'] = "could not write content tmp file: %s" % to_native(err) + return result + # all actions should occur on the remote server, run win_copy module + elif remote_src: + new_module_args = self._task.args.copy() + new_module_args.update( + dict( + _copy_mode="remote", + dest=dest, + src=source, + force=force, + backup=backup, + ) + ) + new_module_args.pop('content', None) + result.update(self._execute_module(module_args=new_module_args, task_vars=task_vars)) + return result + # find_needle returns a path that may not have a trailing slash on a + # directory so we need to find that out first and append at the end + else: + trailing_slash = source.endswith(os.path.sep) + try: + # find in expected paths + source = self._find_needle('files', source) + except AnsibleError as e: + result['failed'] = True + result['msg'] = to_text(e) + result['exception'] = traceback.format_exc() + return result + + if trailing_slash != source.endswith(os.path.sep): + if source[-1] == os.path.sep: + source = source[:-1] + else: + source = source + os.path.sep + + # A list of source file tuples (full_path, relative_path) which will try to copy to the destination + source_files = {'files': [], 'directories': [], 'symlinks': []} + + # If source is a directory populate our list else source is a file and translate it to a tuple. + if os.path.isdir(to_bytes(source, errors='surrogate_or_strict')): + result['operation'] = 'folder_copy' + + # Get a list of the files we want to replicate on the remote side + source_files = _walk_dirs(source, self._loader, decrypt=decrypt, local_follow=local_follow, + trailing_slash_detector=self._connection._shell.path_has_trailing_slash, + checksum_check=force) + + # If it's recursive copy, destination is always a dir, + # explicitly mark it so (note - win_copy module relies on this). + if not self._connection._shell.path_has_trailing_slash(dest): + dest = "%s%s" % (dest, self.WIN_PATH_SEPARATOR) + + check_dest = dest + # Source is a file, add details to source_files dict + else: + result['operation'] = 'file_copy' + + # If the local file does not exist, get_real_file() raises AnsibleFileNotFound + try: + source_full = self._loader.get_real_file(source, decrypt=decrypt) + except AnsibleFileNotFound as e: + result['failed'] = True + result['msg'] = "could not find src=%s, %s" % (source_full, to_text(e)) + return result + + original_basename = os.path.basename(source) + result['original_basename'] = original_basename + + # check if dest ends with / or \ and append source filename to dest + if self._connection._shell.path_has_trailing_slash(dest): + check_dest = dest + filename = original_basename + result['dest'] = self._connection._shell.join_path(dest, filename) + else: + # replace \\ with / so we can use os.path to get the filename or dirname + unix_path = dest.replace(self.WIN_PATH_SEPARATOR, os.path.sep) + filename = os.path.basename(unix_path) + check_dest = os.path.dirname(unix_path) + + file_checksum = _get_local_checksum(force, source_full) + source_files['files'].append( + dict( + src=source_full, + dest=filename, + checksum=file_checksum + ) + ) + result['checksum'] = file_checksum + result['size'] = os.path.getsize(to_bytes(source_full, errors='surrogate_or_strict')) + + # find out the files/directories/symlinks that we need to copy to the server + query_args = self._task.args.copy() + query_args.update( + dict( + _copy_mode="query", + dest=check_dest, + force=force, + files=source_files['files'], + directories=source_files['directories'], + symlinks=source_files['symlinks'], + ) + ) + # src is not required for query, will fail path validation is src has unix allowed chars + query_args.pop('src', None) + + query_args.pop('content', None) + query_return = self._execute_module(module_args=query_args, + task_vars=task_vars) + + if query_return.get('failed') is True: + result.update(query_return) + return result + + if len(query_return['files']) > 0 or len(query_return['directories']) > 0 and self._connection._shell.tmpdir is None: + self._connection._shell.tmpdir = self._make_tmp_path() + + if len(query_return['files']) == 1 and len(query_return['directories']) == 0: + # we only need to copy 1 file, don't mess around with zips + file_src = query_return['files'][0]['src'] + file_dest = query_return['files'][0]['dest'] + result.update(self._copy_single_file(file_src, dest, file_dest, + task_vars, self._connection._shell.tmpdir, backup)) + if result.get('failed') is True: + result['msg'] = "failed to copy file %s: %s" % (file_src, result['msg']) + result['changed'] = True + + elif len(query_return['files']) > 0 or len(query_return['directories']) > 0: + # either multiple files or directories need to be copied, compress + # to a zip and 'explode' the zip on the server + # TODO: handle symlinks + result.update(self._copy_zip_file(dest, source_files['files'], + source_files['directories'], + task_vars, self._connection._shell.tmpdir, backup)) + result['changed'] = True + else: + # no operations need to occur + result['failed'] = False + result['changed'] = False + + # remove the content tmp file and remote tmp file if it was created + self._remove_tempfile_if_content_defined(content, content_tempfile) + self._remove_tmp_path(self._connection._shell.tmpdir) + return result diff --git a/test/support/windows-integration/plugins/action/win_reboot.py b/test/support/windows-integration/plugins/action/win_reboot.py new file mode 100644 index 00000000..c408f4f3 --- /dev/null +++ b/test/support/windows-integration/plugins/action/win_reboot.py @@ -0,0 +1,96 @@ +# Copyright: (c) 2018, Matt Davis <mdavis@ansible.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from datetime import datetime + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_native +from ansible.plugins.action import ActionBase +from ansible.plugins.action.reboot import ActionModule as RebootActionModule +from ansible.utils.display import Display + +display = Display() + + +class TimedOutException(Exception): + pass + + +class ActionModule(RebootActionModule, ActionBase): + TRANSFERS_FILES = False + _VALID_ARGS = frozenset(( + 'connect_timeout', 'connect_timeout_sec', 'msg', 'post_reboot_delay', 'post_reboot_delay_sec', 'pre_reboot_delay', 'pre_reboot_delay_sec', + 'reboot_timeout', 'reboot_timeout_sec', 'shutdown_timeout', 'shutdown_timeout_sec', 'test_command', + )) + + DEFAULT_BOOT_TIME_COMMAND = "(Get-WmiObject -ClassName Win32_OperatingSystem).LastBootUpTime" + DEFAULT_CONNECT_TIMEOUT = 5 + DEFAULT_PRE_REBOOT_DELAY = 2 + DEFAULT_SUDOABLE = False + DEFAULT_SHUTDOWN_COMMAND_ARGS = '/r /t {delay_sec} /c "{message}"' + + DEPRECATED_ARGS = { + 'shutdown_timeout': '2.5', + 'shutdown_timeout_sec': '2.5', + } + + def __init__(self, *args, **kwargs): + super(ActionModule, self).__init__(*args, **kwargs) + + def get_distribution(self, task_vars): + return {'name': 'windows', 'version': '', 'family': ''} + + def get_shutdown_command(self, task_vars, distribution): + return self.DEFAULT_SHUTDOWN_COMMAND + + def run_test_command(self, distribution, **kwargs): + # Need to wrap the test_command in our PowerShell encoded wrapper. This is done to align the command input to a + # common shell and to allow the psrp connection plugin to report the correct exit code without manually setting + # $LASTEXITCODE for just that plugin. + test_command = self._task.args.get('test_command', self.DEFAULT_TEST_COMMAND) + kwargs['test_command'] = self._connection._shell._encode_script(test_command) + super(ActionModule, self).run_test_command(distribution, **kwargs) + + def perform_reboot(self, task_vars, distribution): + shutdown_command = self.get_shutdown_command(task_vars, distribution) + shutdown_command_args = self.get_shutdown_command_args(distribution) + reboot_command = self._connection._shell._encode_script('{0} {1}'.format(shutdown_command, shutdown_command_args)) + + display.vvv("{action}: rebooting server...".format(action=self._task.action)) + display.debug("{action}: distribution: {dist}".format(action=self._task.action, dist=distribution)) + display.debug("{action}: rebooting server with command '{command}'".format(action=self._task.action, command=reboot_command)) + + result = {} + reboot_result = self._low_level_execute_command(reboot_command, sudoable=self.DEFAULT_SUDOABLE) + result['start'] = datetime.utcnow() + + # Test for "A system shutdown has already been scheduled. (1190)" and handle it gracefully + stdout = reboot_result['stdout'] + stderr = reboot_result['stderr'] + if reboot_result['rc'] == 1190 or (reboot_result['rc'] != 0 and "(1190)" in reboot_result['stderr']): + display.warning('A scheduled reboot was pre-empted by Ansible.') + + # Try to abort (this may fail if it was already aborted) + result1 = self._low_level_execute_command(self._connection._shell._encode_script('shutdown /a'), + sudoable=self.DEFAULT_SUDOABLE) + + # Initiate reboot again + result2 = self._low_level_execute_command(reboot_command, sudoable=self.DEFAULT_SUDOABLE) + + reboot_result['rc'] = result2['rc'] + stdout += result1['stdout'] + result2['stdout'] + stderr += result1['stderr'] + result2['stderr'] + + if reboot_result['rc'] != 0: + result['failed'] = True + result['rebooted'] = False + result['msg'] = "Reboot command failed, error was: {stdout} {stderr}".format( + stdout=to_native(stdout.strip()), + stderr=to_native(stderr.strip())) + return result + + result['failed'] = False + return result diff --git a/test/support/windows-integration/plugins/action/win_template.py b/test/support/windows-integration/plugins/action/win_template.py new file mode 100644 index 00000000..20494b93 --- /dev/null +++ b/test/support/windows-integration/plugins/action/win_template.py @@ -0,0 +1,29 @@ +# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.plugins.action import ActionBase +from ansible.plugins.action.template import ActionModule as TemplateActionModule + + +# Even though TemplateActionModule inherits from ActionBase, we still need to +# directly inherit from ActionBase to appease the plugin loader. +class ActionModule(TemplateActionModule, ActionBase): + DEFAULT_NEWLINE_SEQUENCE = '\r\n' diff --git a/test/support/windows-integration/plugins/become/runas.py b/test/support/windows-integration/plugins/become/runas.py new file mode 100644 index 00000000..c8ae881c --- /dev/null +++ b/test/support/windows-integration/plugins/become/runas.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = """ + become: runas + short_description: Run As user + description: + - This become plugins allows your remote/login user to execute commands as another user via the windows runas facility. + author: ansible (@core) + version_added: "2.8" + options: + become_user: + description: User you 'become' to execute the task + ini: + - section: privilege_escalation + key: become_user + - section: runas_become_plugin + key: user + vars: + - name: ansible_become_user + - name: ansible_runas_user + env: + - name: ANSIBLE_BECOME_USER + - name: ANSIBLE_RUNAS_USER + required: True + become_flags: + description: Options to pass to runas, a space delimited list of k=v pairs + default: '' + ini: + - section: privilege_escalation + key: become_flags + - section: runas_become_plugin + key: flags + vars: + - name: ansible_become_flags + - name: ansible_runas_flags + env: + - name: ANSIBLE_BECOME_FLAGS + - name: ANSIBLE_RUNAS_FLAGS + become_pass: + description: password + ini: + - section: runas_become_plugin + key: password + vars: + - name: ansible_become_password + - name: ansible_become_pass + - name: ansible_runas_pass + env: + - name: ANSIBLE_BECOME_PASS + - name: ANSIBLE_RUNAS_PASS + notes: + - runas is really implemented in the powershell module handler and as such can only be used with winrm connections. + - This plugin ignores the 'become_exe' setting as it uses an API and not an executable. + - The Secondary Logon service (seclogon) must be running to use runas +""" + +from ansible.plugins.become import BecomeBase + + +class BecomeModule(BecomeBase): + + name = 'runas' + + def build_become_command(self, cmd, shell): + # runas is implemented inside the winrm connection plugin + return cmd diff --git a/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs b/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs new file mode 100644 index 00000000..be0f3db3 --- /dev/null +++ b/test/support/windows-integration/plugins/module_utils/Ansible.Service.cs @@ -0,0 +1,1341 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; +using Ansible.Privilege; + +namespace Ansible.Service +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct ENUM_SERVICE_STATUSW + { + public string lpServiceName; + public string lpDisplayName; + public SERVICE_STATUS ServiceStatus; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct QUERY_SERVICE_CONFIGW + { + public ServiceType dwServiceType; + public ServiceStartType dwStartType; + public ErrorControl dwErrorControl; + [MarshalAs(UnmanagedType.LPWStr)] public string lpBinaryPathName; + [MarshalAs(UnmanagedType.LPWStr)] public string lpLoadOrderGroup; + public Int32 dwTagId; + public IntPtr lpDependencies; // Can't rely on marshaling as dependencies are delimited by \0. + [MarshalAs(UnmanagedType.LPWStr)] public string lpServiceStartName; + [MarshalAs(UnmanagedType.LPWStr)] public string lpDisplayName; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SC_ACTION + { + public FailureAction Type; + public UInt32 Delay; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_DELAYED_AUTO_START_INFO + { + public bool fDelayedAutostart; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct SERVICE_DESCRIPTIONW + { + [MarshalAs(UnmanagedType.LPWStr)] public string lpDescription; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_FAILURE_ACTIONS_FLAG + { + public bool fFailureActionsOnNonCrashFailures; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct SERVICE_FAILURE_ACTIONSW + { + public UInt32 dwResetPeriod; + [MarshalAs(UnmanagedType.LPWStr)] public string lpRebootMsg; + [MarshalAs(UnmanagedType.LPWStr)] public string lpCommand; + public UInt32 cActions; + public IntPtr lpsaActions; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_LAUNCH_PROTECTED_INFO + { + public LaunchProtection dwLaunchProtected; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_PREFERRED_NODE_INFO + { + public UInt16 usPreferredNode; + public bool fDelete; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_PRESHUTDOWN_INFO + { + public UInt32 dwPreshutdownTimeout; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct SERVICE_REQUIRED_PRIVILEGES_INFOW + { + // Can't rely on marshaling as privileges are delimited by \0. + public IntPtr pmszRequiredPrivileges; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_SID_INFO + { + public ServiceSidInfo dwServiceSidType; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_STATUS + { + public ServiceType dwServiceType; + public ServiceStatus dwCurrentState; + public ControlsAccepted dwControlsAccepted; + public UInt32 dwWin32ExitCode; + public UInt32 dwServiceSpecificExitCode; + public UInt32 dwCheckPoint; + public UInt32 dwWaitHint; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_STATUS_PROCESS + { + public ServiceType dwServiceType; + public ServiceStatus dwCurrentState; + public ControlsAccepted dwControlsAccepted; + public UInt32 dwWin32ExitCode; + public UInt32 dwServiceSpecificExitCode; + public UInt32 dwCheckPoint; + public UInt32 dwWaitHint; + public UInt32 dwProcessId; + public ServiceFlags dwServiceFlags; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_TRIGGER + { + public TriggerType dwTriggerType; + public TriggerAction dwAction; + public IntPtr pTriggerSubtype; + public UInt32 cDataItems; + public IntPtr pDataItems; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_TRIGGER_SPECIFIC_DATA_ITEM + { + public TriggerDataType dwDataType; + public UInt32 cbData; + public IntPtr pData; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SERVICE_TRIGGER_INFO + { + public UInt32 cTriggers; + public IntPtr pTriggers; + public IntPtr pReserved; + } + + public enum ConfigInfoLevel : uint + { + SERVICE_CONFIG_DESCRIPTION = 0x00000001, + SERVICE_CONFIG_FAILURE_ACTIONS = 0x00000002, + SERVICE_CONFIG_DELAYED_AUTO_START_INFO = 0x00000003, + SERVICE_CONFIG_FAILURE_ACTIONS_FLAG = 0x00000004, + SERVICE_CONFIG_SERVICE_SID_INFO = 0x00000005, + SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO = 0x00000006, + SERVICE_CONFIG_PRESHUTDOWN_INFO = 0x00000007, + SERVICE_CONFIG_TRIGGER_INFO = 0x00000008, + SERVICE_CONFIG_PREFERRED_NODE = 0x00000009, + SERVICE_CONFIG_LAUNCH_PROTECTED = 0x0000000c, + } + } + + internal class NativeMethods + { + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool ChangeServiceConfigW( + SafeHandle hService, + ServiceType dwServiceType, + ServiceStartType dwStartType, + ErrorControl dwErrorControl, + string lpBinaryPathName, + string lpLoadOrderGroup, + IntPtr lpdwTagId, + string lpDependencies, + string lpServiceStartName, + string lpPassword, + string lpDisplayName); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool ChangeServiceConfig2W( + SafeHandle hService, + NativeHelpers.ConfigInfoLevel dwInfoLevel, + IntPtr lpInfo); + + [DllImport("Advapi32.dll", SetLastError = true)] + public static extern bool CloseServiceHandle( + IntPtr hSCObject); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern SafeServiceHandle CreateServiceW( + SafeHandle hSCManager, + string lpServiceName, + string lpDisplayName, + ServiceRights dwDesiredAccess, + ServiceType dwServiceType, + ServiceStartType dwStartType, + ErrorControl dwErrorControl, + string lpBinaryPathName, + string lpLoadOrderGroup, + IntPtr lpdwTagId, + string lpDependencies, + string lpServiceStartName, + string lpPassword); + + [DllImport("Advapi32.dll", SetLastError = true)] + public static extern bool DeleteService( + SafeHandle hService); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool EnumDependentServicesW( + SafeHandle hService, + UInt32 dwServiceState, + SafeMemoryBuffer lpServices, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded, + out UInt32 lpServicesReturned); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern SafeServiceHandle OpenSCManagerW( + string lpMachineName, + string lpDatabaseNmae, + SCMRights dwDesiredAccess); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern SafeServiceHandle OpenServiceW( + SafeHandle hSCManager, + string lpServiceName, + ServiceRights dwDesiredAccess); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool QueryServiceConfigW( + SafeHandle hService, + IntPtr lpServiceConfig, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded); + + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + public static extern bool QueryServiceConfig2W( + SafeHandle hservice, + NativeHelpers.ConfigInfoLevel dwInfoLevel, + IntPtr lpBuffer, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded); + + [DllImport("Advapi32.dll", SetLastError = true)] + public static extern bool QueryServiceStatusEx( + SafeHandle hService, + UInt32 InfoLevel, + IntPtr lpBuffer, + UInt32 cbBufSize, + out UInt32 pcbBytesNeeded); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public UInt32 BufferLength { get; internal set; } + + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + BufferLength = (UInt32)cb; + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + internal class SafeServiceHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeServiceHandle() : base(true) { } + public SafeServiceHandle(IntPtr handle) : base(true) { this.handle = handle; } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return NativeMethods.CloseServiceHandle(handle); + } + } + + [Flags] + public enum ControlsAccepted : uint + { + None = 0x00000000, + Stop = 0x00000001, + PauseContinue = 0x00000002, + Shutdown = 0x00000004, + ParamChange = 0x00000008, + NetbindChange = 0x00000010, + HardwareProfileChange = 0x00000020, + PowerEvent = 0x00000040, + SessionChange = 0x00000080, + PreShutdown = 0x00000100, + } + + public enum ErrorControl : uint + { + Ignore = 0x00000000, + Normal = 0x00000001, + Severe = 0x00000002, + Critical = 0x00000003, + } + + public enum FailureAction : uint + { + None = 0x00000000, + Restart = 0x00000001, + Reboot = 0x00000002, + RunCommand = 0x00000003, + } + + public enum LaunchProtection : uint + { + None = 0, + Windows = 1, + WindowsLight = 2, + AntimalwareLight = 3, + } + + [Flags] + public enum SCMRights : uint + { + Connect = 0x00000001, + CreateService = 0x00000002, + EnumerateService = 0x00000004, + Lock = 0x00000008, + QueryLockStatus = 0x00000010, + ModifyBootConfig = 0x00000020, + AllAccess = 0x000F003F, + } + + [Flags] + public enum ServiceFlags : uint + { + None = 0x0000000, + RunsInSystemProcess = 0x00000001, + } + + [Flags] + public enum ServiceRights : uint + { + QueryConfig = 0x00000001, + ChangeConfig = 0x00000002, + QueryStatus = 0x00000004, + EnumerateDependents = 0x00000008, + Start = 0x00000010, + Stop = 0x00000020, + PauseContinue = 0x00000040, + Interrogate = 0x00000080, + UserDefinedControl = 0x00000100, + Delete = 0x00010000, + ReadControl = 0x00020000, + WriteDac = 0x00040000, + WriteOwner = 0x00080000, + AllAccess = 0x000F01FF, + AccessSystemSecurity = 0x01000000, + } + + public enum ServiceStartType : uint + { + BootStart = 0x00000000, + SystemStart = 0x00000001, + AutoStart = 0x00000002, + DemandStart = 0x00000003, + Disabled = 0x00000004, + + // Not part of ChangeServiceConfig enumeration but built by the Srvice class for the StartType property. + AutoStartDelayed = 0x1000000 + } + + [Flags] + public enum ServiceType : uint + { + KernelDriver = 0x00000001, + FileSystemDriver = 0x00000002, + Adapter = 0x00000004, + RecognizerDriver = 0x00000008, + Driver = KernelDriver | FileSystemDriver | RecognizerDriver, + Win32OwnProcess = 0x00000010, + Win32ShareProcess = 0x00000020, + Win32 = Win32OwnProcess | Win32ShareProcess, + UserProcess = 0x00000040, + UserOwnprocess = Win32OwnProcess | UserProcess, + UserShareProcess = Win32ShareProcess | UserProcess, + UserServiceInstance = 0x00000080, + InteractiveProcess = 0x00000100, + PkgService = 0x00000200, + } + + public enum ServiceSidInfo : uint + { + None, + Unrestricted, + Restricted = 3, + } + + public enum ServiceStatus : uint + { + Stopped = 0x00000001, + StartPending = 0x00000002, + StopPending = 0x00000003, + Running = 0x00000004, + ContinuePending = 0x00000005, + PausePending = 0x00000006, + Paused = 0x00000007, + } + + public enum TriggerAction : uint + { + ServiceStart = 0x00000001, + ServiceStop = 0x000000002, + } + + public enum TriggerDataType : uint + { + Binary = 00000001, + String = 0x00000002, + Level = 0x00000003, + KeywordAny = 0x00000004, + KeywordAll = 0x00000005, + } + + public enum TriggerType : uint + { + DeviceInterfaceArrival = 0x00000001, + IpAddressAvailability = 0x00000002, + DomainJoin = 0x00000003, + FirewallPortEvent = 0x00000004, + GroupPolicy = 0x00000005, + NetworkEndpoint = 0x00000006, + Custom = 0x00000014, + } + + public class ServiceManagerException : System.ComponentModel.Win32Exception + { + private string _msg; + + public ServiceManagerException(string message) : this(Marshal.GetLastWin32Error(), message) { } + public ServiceManagerException(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2} - 0x{2:X8})", message, base.Message, errorCode); + } + + public override string Message { get { return _msg; } } + public static explicit operator ServiceManagerException(string message) + { + return new ServiceManagerException(message); + } + } + + public class Action + { + public FailureAction Type; + public UInt32 Delay; + } + + public class FailureActions + { + public UInt32? ResetPeriod = null; // Get is always populated, can be null on set to preserve existing. + public string RebootMsg = null; + public string Command = null; + public List<Action> Actions = null; + + public FailureActions() { } + + internal FailureActions(NativeHelpers.SERVICE_FAILURE_ACTIONSW actions) + { + ResetPeriod = actions.dwResetPeriod; + RebootMsg = actions.lpRebootMsg; + Command = actions.lpCommand; + Actions = new List<Action>(); + + int actionLength = Marshal.SizeOf(typeof(NativeHelpers.SC_ACTION)); + for (int i = 0; i < actions.cActions; i++) + { + IntPtr actionPtr = IntPtr.Add(actions.lpsaActions, i * actionLength); + + NativeHelpers.SC_ACTION rawAction = (NativeHelpers.SC_ACTION)Marshal.PtrToStructure( + actionPtr, typeof(NativeHelpers.SC_ACTION)); + + Actions.Add(new Action() + { + Type = rawAction.Type, + Delay = rawAction.Delay, + }); + } + } + } + + public class TriggerItem + { + public TriggerDataType Type; + public object Data; // Can be string, List<string>, byte, byte[], or Int64 depending on Type. + + public TriggerItem() { } + + internal TriggerItem(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM dataItem) + { + Type = dataItem.dwDataType; + + byte[] itemBytes = new byte[dataItem.cbData]; + Marshal.Copy(dataItem.pData, itemBytes, 0, itemBytes.Length); + + switch (dataItem.dwDataType) + { + case TriggerDataType.String: + string value = Encoding.Unicode.GetString(itemBytes, 0, itemBytes.Length); + + if (value.EndsWith("\0\0")) + { + // Multistring with a delimiter of \0 and terminated with \0\0. + Data = new List<string>(value.Split(new char[1] { '\0' }, StringSplitOptions.RemoveEmptyEntries)); + } + else + // Just a single string with null character at the end, strip it off. + Data = value.Substring(0, value.Length - 1); + break; + case TriggerDataType.Level: + Data = itemBytes[0]; + break; + case TriggerDataType.KeywordAll: + case TriggerDataType.KeywordAny: + Data = BitConverter.ToUInt64(itemBytes, 0); + break; + default: + Data = itemBytes; + break; + } + } + } + + public class Trigger + { + // https://docs.microsoft.com/en-us/windows/win32/api/winsvc/ns-winsvc-service_trigger + public const string NAMED_PIPE_EVENT_GUID = "1f81d131-3fac-4537-9e0c-7e7b0c2f4b55"; + public const string RPC_INTERFACE_EVENT_GUID = "bc90d167-9470-4139-a9ba-be0bbbf5b74d"; + public const string DOMAIN_JOIN_GUID = "1ce20aba-9851-4421-9430-1ddeb766e809"; + public const string DOMAIN_LEAVE_GUID = "ddaf516e-58c2-4866-9574-c3b615d42ea1"; + public const string FIREWALL_PORT_OPEN_GUID = "b7569e07-8421-4ee0-ad10-86915afdad09"; + public const string FIREWALL_PORT_CLOSE_GUID = "a144ed38-8e12-4de4-9d96-e64740b1a524"; + public const string MACHINE_POLICY_PRESENT_GUID = "659fcae6-5bdb-4da9-b1ff-ca2a178d46e0"; + public const string NETWORK_MANAGER_FIRST_IP_ADDRESS_ARRIVAL_GUID = "4f27f2de-14e2-430b-a549-7cd48cbc8245"; + public const string NETWORK_MANAGER_LAST_IP_ADDRESS_REMOVAL_GUID = "cc4ba62a-162e-4648-847a-b6bdf993e335"; + public const string USER_POLICY_PRESENT_GUID = "54fb46c8-f089-464c-b1fd-59d1b62c3b50"; + + public TriggerType Type; + public TriggerAction Action; + public Guid SubType; + public List<TriggerItem> DataItems = new List<TriggerItem>(); + + public Trigger() { } + + internal Trigger(NativeHelpers.SERVICE_TRIGGER trigger) + { + Type = trigger.dwTriggerType; + Action = trigger.dwAction; + SubType = (Guid)Marshal.PtrToStructure(trigger.pTriggerSubtype, typeof(Guid)); + + int dataItemLength = Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)); + for (int i = 0; i < trigger.cDataItems; i++) + { + IntPtr dataPtr = IntPtr.Add(trigger.pDataItems, i * dataItemLength); + + var dataItem = (NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)Marshal.PtrToStructure( + dataPtr, typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)); + + DataItems.Add(new TriggerItem(dataItem)); + } + } + } + + public class Service : IDisposable + { + private const UInt32 SERVICE_NO_CHANGE = 0xFFFFFFFF; + + private SafeServiceHandle _scmHandle; + private SafeServiceHandle _serviceHandle; + private SafeMemoryBuffer _rawServiceConfig; + private NativeHelpers.SERVICE_STATUS_PROCESS _statusProcess; + + private NativeHelpers.QUERY_SERVICE_CONFIGW _ServiceConfig + { + get + { + return (NativeHelpers.QUERY_SERVICE_CONFIGW)Marshal.PtrToStructure( + _rawServiceConfig.DangerousGetHandle(), typeof(NativeHelpers.QUERY_SERVICE_CONFIGW)); + } + } + + // ServiceConfig + public string ServiceName { get; private set; } + + public ServiceType ServiceType + { + get { return _ServiceConfig.dwServiceType; } + set { ChangeServiceConfig(serviceType: value); } + } + + public ServiceStartType StartType + { + get + { + ServiceStartType startType = _ServiceConfig.dwStartType; + if (startType == ServiceStartType.AutoStart) + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_DELAYED_AUTO_START_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DELAYED_AUTO_START_INFO); + + if (value.fDelayedAutostart) + startType = ServiceStartType.AutoStartDelayed; + } + + return startType; + } + set + { + ServiceStartType newStartType = value; + bool delayedStart = false; + if (value == ServiceStartType.AutoStartDelayed) + { + newStartType = ServiceStartType.AutoStart; + delayedStart = true; + } + + ChangeServiceConfig(startType: newStartType); + + var info = new NativeHelpers.SERVICE_DELAYED_AUTO_START_INFO() + { + fDelayedAutostart = delayedStart, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DELAYED_AUTO_START_INFO, info); + } + } + + public ErrorControl ErrorControl + { + get { return _ServiceConfig.dwErrorControl; } + set { ChangeServiceConfig(errorControl: value); } + } + + public string Path + { + get { return _ServiceConfig.lpBinaryPathName; } + set { ChangeServiceConfig(binaryPath: value); } + } + + public string LoadOrderGroup + { + get { return _ServiceConfig.lpLoadOrderGroup; } + set { ChangeServiceConfig(loadOrderGroup: value); } + } + + public List<string> DependentOn + { + get + { + StringBuilder deps = new StringBuilder(); + IntPtr depPtr = _ServiceConfig.lpDependencies; + + bool wasNull = false; + while (true) + { + // Get the current char at the pointer and add it to the StringBuilder. + byte[] charBytes = new byte[sizeof(char)]; + Marshal.Copy(depPtr, charBytes, 0, charBytes.Length); + depPtr = IntPtr.Add(depPtr, charBytes.Length); + char currentChar = BitConverter.ToChar(charBytes, 0); + deps.Append(currentChar); + + // If the previous and current char is \0 exit the loop. + if (currentChar == '\0' && wasNull) + break; + wasNull = currentChar == '\0'; + } + + return new List<string>(deps.ToString().Split(new char[1] { '\0' }, + StringSplitOptions.RemoveEmptyEntries)); + } + set { ChangeServiceConfig(dependencies: value); } + } + + public IdentityReference Account + { + get + { + if (_ServiceConfig.lpServiceStartName == null) + // User services don't have the start name specified and will be null. + return null; + else if (_ServiceConfig.lpServiceStartName == "LocalSystem") + // Special string used for the SYSTEM account, this is the same even for different localisations. + return (NTAccount)new SecurityIdentifier("S-1-5-18").Translate(typeof(NTAccount)); + else + return new NTAccount(_ServiceConfig.lpServiceStartName); + } + set + { + string startName = null; + string pass = null; + + if (value != null) + { + // Create a SID and convert back from a SID to get the Netlogon form regardless of the input + // specified. + SecurityIdentifier accountSid = (SecurityIdentifier)value.Translate(typeof(SecurityIdentifier)); + NTAccount accountName = (NTAccount)accountSid.Translate(typeof(NTAccount)); + string[] accountSplit = accountName.Value.Split(new char[1] { '\\' }, 2); + + // SYSTEM, Local Service, Network Service + List<string> serviceAccounts = new List<string> { "S-1-5-18", "S-1-5-19", "S-1-5-20" }; + + // Well known service accounts and MSAs should have no password set. Explicitly blank out the + // existing password to ensure older passwords are no longer stored by Windows. + if (serviceAccounts.Contains(accountSid.Value) || accountSplit[1].EndsWith("$")) + pass = ""; + + // The SYSTEM account uses this special string to specify that account otherwise use the original + // NTAccount value in case it is in a custom format (not Netlogon) for a reason. + if (accountSid.Value == serviceAccounts[0]) + startName = "LocalSystem"; + else + startName = value.Translate(typeof(NTAccount)).Value; + } + + ChangeServiceConfig(startName: startName, password: pass); + } + } + + public string Password { set { ChangeServiceConfig(password: value); } } + + public string DisplayName + { + get { return _ServiceConfig.lpDisplayName; } + set { ChangeServiceConfig(displayName: value); } + } + + // ServiceConfig2 + + public string Description + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_DESCRIPTIONW>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DESCRIPTION); + + return value.lpDescription; + } + set + { + var info = new NativeHelpers.SERVICE_DESCRIPTIONW() + { + lpDescription = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_DESCRIPTION, info); + } + } + + public FailureActions FailureActions + { + get + { + using (SafeMemoryBuffer b = QueryServiceConfig2( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS)) + { + NativeHelpers.SERVICE_FAILURE_ACTIONSW value = (NativeHelpers.SERVICE_FAILURE_ACTIONSW) + Marshal.PtrToStructure(b.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_FAILURE_ACTIONSW)); + + return new FailureActions(value); + } + } + set + { + // dwResetPeriod and lpsaActions must be set together, we need to read the existing config if someone + // wants to update 1 or the other but both aren't explicitly defined. + UInt32? resetPeriod = value.ResetPeriod; + List<Action> actions = value.Actions; + if ((resetPeriod != null && actions == null) || (resetPeriod == null && actions != null)) + { + FailureActions existingValue = this.FailureActions; + + if (resetPeriod != null && existingValue.Actions.Count == 0) + throw new ArgumentException( + "Cannot set FailureAction ResetPeriod without explicit Actions and no existing Actions"); + else if (resetPeriod == null) + resetPeriod = (UInt32)existingValue.ResetPeriod; + + if (actions == null) + actions = existingValue.Actions; + } + + var info = new NativeHelpers.SERVICE_FAILURE_ACTIONSW() + { + dwResetPeriod = resetPeriod == null ? 0 : (UInt32)resetPeriod, + lpRebootMsg = value.RebootMsg, + lpCommand = value.Command, + cActions = actions == null ? 0 : (UInt32)actions.Count, + lpsaActions = IntPtr.Zero, + }; + + // null means to keep the existing actions whereas an empty list deletes the actions. + if (actions == null) + { + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS, info); + return; + } + + int actionLength = Marshal.SizeOf(typeof(NativeHelpers.SC_ACTION)); + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(actionLength * actions.Count)) + { + info.lpsaActions = buffer.DangerousGetHandle(); + HashSet<string> privileges = new HashSet<string>(); + + for (int i = 0; i < actions.Count; i++) + { + IntPtr actionPtr = IntPtr.Add(info.lpsaActions, i * actionLength); + NativeHelpers.SC_ACTION action = new NativeHelpers.SC_ACTION() + { + Delay = actions[i].Delay, + Type = actions[i].Type, + }; + Marshal.StructureToPtr(action, actionPtr, false); + + // Need to make sure the SeShutdownPrivilege is enabled when adding a reboot failure action. + if (action.Type == FailureAction.Reboot) + privileges.Add("SeShutdownPrivilege"); + } + + using (new PrivilegeEnabler(true, privileges.ToList().ToArray())) + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS, info); + } + } + } + + public bool FailureActionsOnNonCrashFailures + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_FAILURE_ACTIONS_FLAG>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG); + + return value.fFailureActionsOnNonCrashFailures; + } + set + { + var info = new NativeHelpers.SERVICE_FAILURE_ACTIONS_FLAG() + { + fFailureActionsOnNonCrashFailures = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG, info); + } + } + + public ServiceSidInfo ServiceSidInfo + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_SID_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_SERVICE_SID_INFO); + + return value.dwServiceSidType; + } + set + { + var info = new NativeHelpers.SERVICE_SID_INFO() + { + dwServiceSidType = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_SERVICE_SID_INFO, info); + } + } + + public List<string> RequiredPrivileges + { + get + { + using (SafeMemoryBuffer buffer = QueryServiceConfig2( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO)) + { + var value = (NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW)Marshal.PtrToStructure( + buffer.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW)); + + int structLength = Marshal.SizeOf(value); + int stringLength = ((int)buffer.BufferLength - structLength) / sizeof(char); + + if (stringLength > 0) + { + string privilegesString = Marshal.PtrToStringUni(value.pmszRequiredPrivileges, stringLength); + return new List<string>(privilegesString.Split(new char[1] { '\0' }, + StringSplitOptions.RemoveEmptyEntries)); + } + else + return new List<string>(); + } + } + set + { + string privilegeString = String.Join("\0", value ?? new List<string>()) + "\0\0"; + + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(Marshal.StringToHGlobalUni(privilegeString))) + { + var info = new NativeHelpers.SERVICE_REQUIRED_PRIVILEGES_INFOW() + { + pmszRequiredPrivileges = buffer.DangerousGetHandle(), + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_REQUIRED_PRIVILEGES_INFO, info); + } + } + } + + public UInt32 PreShutdownTimeout + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_PRESHUTDOWN_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PRESHUTDOWN_INFO); + + return value.dwPreshutdownTimeout; + } + set + { + var info = new NativeHelpers.SERVICE_PRESHUTDOWN_INFO() + { + dwPreshutdownTimeout = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PRESHUTDOWN_INFO, info); + } + } + + public List<Trigger> Triggers + { + get + { + List<Trigger> triggers = new List<Trigger>(); + + using (SafeMemoryBuffer b = QueryServiceConfig2( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO)) + { + var value = (NativeHelpers.SERVICE_TRIGGER_INFO)Marshal.PtrToStructure( + b.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_TRIGGER_INFO)); + + int triggerLength = Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER)); + for (int i = 0; i < value.cTriggers; i++) + { + IntPtr triggerPtr = IntPtr.Add(value.pTriggers, i * triggerLength); + var trigger = (NativeHelpers.SERVICE_TRIGGER)Marshal.PtrToStructure(triggerPtr, + typeof(NativeHelpers.SERVICE_TRIGGER)); + + triggers.Add(new Trigger(trigger)); + } + } + + return triggers; + } + set + { + var info = new NativeHelpers.SERVICE_TRIGGER_INFO() + { + cTriggers = value == null ? 0 : (UInt32)value.Count, + pTriggers = IntPtr.Zero, + pReserved = IntPtr.Zero, + }; + + if (info.cTriggers == 0) + { + try + { + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO, info); + } + catch (ServiceManagerException e) + { + // Can fail with ERROR_INVALID_PARAMETER if no triggers were already set on the service, just + // continue as the service is what we want it to be. + if (e.NativeErrorCode != 87) + throw; + } + return; + } + + // Due to the dynamic nature of the trigger structure(s) we need to manually calculate the size of the + // data items on each trigger if present. This also serializes the raw data items to bytes here. + int structDataLength = 0; + int dataLength = 0; + Queue<byte[]> dataItems = new Queue<byte[]>(); + foreach (Trigger trigger in value) + { + if (trigger.DataItems == null || trigger.DataItems.Count == 0) + continue; + + foreach (TriggerItem dataItem in trigger.DataItems) + { + structDataLength += Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM)); + + byte[] dataItemBytes; + Type dataItemType = dataItem.Data.GetType(); + if (dataItemType == typeof(byte)) + dataItemBytes = new byte[1] { (byte)dataItem.Data }; + else if (dataItemType == typeof(byte[])) + dataItemBytes = (byte[])dataItem.Data; + else if (dataItemType == typeof(UInt64)) + dataItemBytes = BitConverter.GetBytes((UInt64)dataItem.Data); + else if (dataItemType == typeof(string)) + dataItemBytes = Encoding.Unicode.GetBytes((string)dataItem.Data + "\0"); + else if (dataItemType == typeof(List<string>)) + dataItemBytes = Encoding.Unicode.GetBytes( + String.Join("\0", (List<string>)dataItem.Data) + "\0"); + else + throw new ArgumentException(String.Format("Trigger data type '{0}' not a value type", + dataItemType.Name)); + + dataLength += dataItemBytes.Length; + dataItems.Enqueue(dataItemBytes); + } + } + + using (SafeMemoryBuffer triggerBuffer = new SafeMemoryBuffer( + value.Count * Marshal.SizeOf(typeof(NativeHelpers.SERVICE_TRIGGER)))) + using (SafeMemoryBuffer triggerGuidBuffer = new SafeMemoryBuffer( + value.Count * Marshal.SizeOf(typeof(Guid)))) + using (SafeMemoryBuffer dataItemBuffer = new SafeMemoryBuffer(structDataLength)) + using (SafeMemoryBuffer dataBuffer = new SafeMemoryBuffer(dataLength)) + { + info.pTriggers = triggerBuffer.DangerousGetHandle(); + + IntPtr triggerPtr = triggerBuffer.DangerousGetHandle(); + IntPtr guidPtr = triggerGuidBuffer.DangerousGetHandle(); + IntPtr dataItemPtr = dataItemBuffer.DangerousGetHandle(); + IntPtr dataPtr = dataBuffer.DangerousGetHandle(); + + foreach (Trigger trigger in value) + { + int dataCount = trigger.DataItems == null ? 0 : trigger.DataItems.Count; + var rawTrigger = new NativeHelpers.SERVICE_TRIGGER() + { + dwTriggerType = trigger.Type, + dwAction = trigger.Action, + pTriggerSubtype = guidPtr, + cDataItems = (UInt32)dataCount, + pDataItems = dataCount == 0 ? IntPtr.Zero : dataItemPtr, + }; + guidPtr = StructureToPtr(trigger.SubType, guidPtr); + + for (int i = 0; i < rawTrigger.cDataItems; i++) + { + byte[] dataItemBytes = dataItems.Dequeue(); + var rawTriggerData = new NativeHelpers.SERVICE_TRIGGER_SPECIFIC_DATA_ITEM() + { + dwDataType = trigger.DataItems[i].Type, + cbData = (UInt32)dataItemBytes.Length, + pData = dataPtr, + }; + Marshal.Copy(dataItemBytes, 0, dataPtr, dataItemBytes.Length); + dataPtr = IntPtr.Add(dataPtr, dataItemBytes.Length); + + dataItemPtr = StructureToPtr(rawTriggerData, dataItemPtr); + } + + triggerPtr = StructureToPtr(rawTrigger, triggerPtr); + } + + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_TRIGGER_INFO, info); + } + } + } + + public UInt16? PreferredNode + { + get + { + try + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_PREFERRED_NODE_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PREFERRED_NODE); + + return value.usPreferredNode; + } + catch (ServiceManagerException e) + { + // If host has no NUMA support this will fail with ERROR_INVALID_PARAMETER + if (e.NativeErrorCode == 0x00000057) // ERROR_INVALID_PARAMETER + return null; + + throw; + } + } + set + { + var info = new NativeHelpers.SERVICE_PREFERRED_NODE_INFO(); + if (value == null) + info.fDelete = true; + else + info.usPreferredNode = (UInt16)value; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_PREFERRED_NODE, info); + } + } + + public LaunchProtection LaunchProtection + { + get + { + var value = QueryServiceConfig2<NativeHelpers.SERVICE_LAUNCH_PROTECTED_INFO>( + NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_LAUNCH_PROTECTED); + + return value.dwLaunchProtected; + } + set + { + var info = new NativeHelpers.SERVICE_LAUNCH_PROTECTED_INFO() + { + dwLaunchProtected = value, + }; + ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel.SERVICE_CONFIG_LAUNCH_PROTECTED, info); + } + } + + // ServiceStatus + public ServiceStatus State { get { return _statusProcess.dwCurrentState; } } + + public ControlsAccepted ControlsAccepted { get { return _statusProcess.dwControlsAccepted; } } + + public UInt32 Win32ExitCode { get { return _statusProcess.dwWin32ExitCode; } } + + public UInt32 ServiceExitCode { get { return _statusProcess.dwServiceSpecificExitCode; } } + + public UInt32 Checkpoint { get { return _statusProcess.dwCheckPoint; } } + + public UInt32 WaitHint { get { return _statusProcess.dwWaitHint; } } + + public UInt32 ProcessId { get { return _statusProcess.dwProcessId; } } + + public ServiceFlags ServiceFlags { get { return _statusProcess.dwServiceFlags; } } + + public Service(string name) : this(name, ServiceRights.AllAccess) { } + + public Service(string name, ServiceRights access) : this(name, access, SCMRights.Connect) { } + + public Service(string name, ServiceRights access, SCMRights scmAccess) + { + ServiceName = name; + _scmHandle = OpenSCManager(scmAccess); + _serviceHandle = NativeMethods.OpenServiceW(_scmHandle, name, access); + if (_serviceHandle.IsInvalid) + throw new ServiceManagerException(String.Format("Failed to open service '{0}'", name)); + + Refresh(); + } + + private Service(SafeServiceHandle scmHandle, SafeServiceHandle serviceHandle, string name) + { + ServiceName = name; + _scmHandle = scmHandle; + _serviceHandle = serviceHandle; + + Refresh(); + } + + // EnumDependentServices + public List<string> DependedBy + { + get + { + UInt32 bytesNeeded = 0; + UInt32 numServices = 0; + NativeMethods.EnumDependentServicesW(_serviceHandle, 3, new SafeMemoryBuffer(IntPtr.Zero), 0, + out bytesNeeded, out numServices); + + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded)) + { + if (!NativeMethods.EnumDependentServicesW(_serviceHandle, 3, buffer, bytesNeeded, out bytesNeeded, + out numServices)) + { + throw new ServiceManagerException("Failed to enumerated dependent services"); + } + + List<string> dependents = new List<string>(); + Type enumType = typeof(NativeHelpers.ENUM_SERVICE_STATUSW); + for (int i = 0; i < numServices; i++) + { + var service = (NativeHelpers.ENUM_SERVICE_STATUSW)Marshal.PtrToStructure( + IntPtr.Add(buffer.DangerousGetHandle(), i * Marshal.SizeOf(enumType)), enumType); + + dependents.Add(service.lpServiceName); + } + + return dependents; + } + } + } + + public static Service Create(string name, string binaryPath, string displayName = null, + ServiceType serviceType = ServiceType.Win32OwnProcess, + ServiceStartType startType = ServiceStartType.DemandStart, ErrorControl errorControl = ErrorControl.Normal, + string loadOrderGroup = null, List<string> dependencies = null, string startName = null, + string password = null) + { + SafeServiceHandle scmHandle = OpenSCManager(SCMRights.CreateService | SCMRights.Connect); + + if (displayName == null) + displayName = name; + + string depString = null; + if (dependencies != null && dependencies.Count > 0) + depString = String.Join("\0", dependencies) + "\0\0"; + + SafeServiceHandle serviceHandle = NativeMethods.CreateServiceW(scmHandle, name, displayName, + ServiceRights.AllAccess, serviceType, startType, errorControl, binaryPath, + loadOrderGroup, IntPtr.Zero, depString, startName, password); + + if (serviceHandle.IsInvalid) + throw new ServiceManagerException(String.Format("Failed to create new service '{0}'", name)); + + return new Service(scmHandle, serviceHandle, name); + } + + public void Delete() + { + if (!NativeMethods.DeleteService(_serviceHandle)) + throw new ServiceManagerException("Failed to delete service"); + Dispose(); + } + + public void Dispose() + { + if (_serviceHandle != null) + _serviceHandle.Dispose(); + + if (_scmHandle != null) + _scmHandle.Dispose(); + GC.SuppressFinalize(this); + } + + public void Refresh() + { + UInt32 bytesNeeded; + NativeMethods.QueryServiceConfigW(_serviceHandle, IntPtr.Zero, 0, out bytesNeeded); + + _rawServiceConfig = new SafeMemoryBuffer((int)bytesNeeded); + if (!NativeMethods.QueryServiceConfigW(_serviceHandle, _rawServiceConfig.DangerousGetHandle(), bytesNeeded, + out bytesNeeded)) + { + throw new ServiceManagerException("Failed to query service config"); + } + + NativeMethods.QueryServiceStatusEx(_serviceHandle, 0, IntPtr.Zero, 0, out bytesNeeded); + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded)) + { + if (!NativeMethods.QueryServiceStatusEx(_serviceHandle, 0, buffer.DangerousGetHandle(), bytesNeeded, + out bytesNeeded)) + { + throw new ServiceManagerException("Failed to query service status"); + } + + _statusProcess = (NativeHelpers.SERVICE_STATUS_PROCESS)Marshal.PtrToStructure( + buffer.DangerousGetHandle(), typeof(NativeHelpers.SERVICE_STATUS_PROCESS)); + } + } + + private void ChangeServiceConfig(ServiceType serviceType = (ServiceType)SERVICE_NO_CHANGE, + ServiceStartType startType = (ServiceStartType)SERVICE_NO_CHANGE, + ErrorControl errorControl = (ErrorControl)SERVICE_NO_CHANGE, string binaryPath = null, + string loadOrderGroup = null, List<string> dependencies = null, string startName = null, + string password = null, string displayName = null) + { + string depString = null; + if (dependencies != null && dependencies.Count > 0) + depString = String.Join("\0", dependencies) + "\0\0"; + + if (!NativeMethods.ChangeServiceConfigW(_serviceHandle, serviceType, startType, errorControl, binaryPath, + loadOrderGroup, IntPtr.Zero, depString, startName, password, displayName)) + { + throw new ServiceManagerException("Failed to change service config"); + } + + Refresh(); + } + + private void ChangeServiceConfig2(NativeHelpers.ConfigInfoLevel infoLevel, object info) + { + using (SafeMemoryBuffer buffer = new SafeMemoryBuffer(Marshal.SizeOf(info))) + { + Marshal.StructureToPtr(info, buffer.DangerousGetHandle(), false); + + if (!NativeMethods.ChangeServiceConfig2W(_serviceHandle, infoLevel, buffer.DangerousGetHandle())) + throw new ServiceManagerException("Failed to change service config"); + } + } + + private static SafeServiceHandle OpenSCManager(SCMRights desiredAccess) + { + SafeServiceHandle handle = NativeMethods.OpenSCManagerW(null, null, desiredAccess); + if (handle.IsInvalid) + throw new ServiceManagerException("Failed to open SCManager"); + + return handle; + } + + private T QueryServiceConfig2<T>(NativeHelpers.ConfigInfoLevel infoLevel) + { + using (SafeMemoryBuffer buffer = QueryServiceConfig2(infoLevel)) + return (T)Marshal.PtrToStructure(buffer.DangerousGetHandle(), typeof(T)); + } + + private SafeMemoryBuffer QueryServiceConfig2(NativeHelpers.ConfigInfoLevel infoLevel) + { + UInt32 bytesNeeded = 0; + NativeMethods.QueryServiceConfig2W(_serviceHandle, infoLevel, IntPtr.Zero, 0, out bytesNeeded); + + SafeMemoryBuffer buffer = new SafeMemoryBuffer((int)bytesNeeded); + if (!NativeMethods.QueryServiceConfig2W(_serviceHandle, infoLevel, buffer.DangerousGetHandle(), bytesNeeded, + out bytesNeeded)) + { + throw new ServiceManagerException(String.Format("QueryServiceConfig2W({0}) failed", + infoLevel.ToString())); + } + + return buffer; + } + + private static IntPtr StructureToPtr(object structure, IntPtr ptr) + { + Marshal.StructureToPtr(structure, ptr, false); + return IntPtr.Add(ptr, Marshal.SizeOf(structure)); + } + + ~Service() { Dispose(); } + } +} diff --git a/test/support/windows-integration/plugins/modules/async_status.ps1 b/test/support/windows-integration/plugins/modules/async_status.ps1 new file mode 100644 index 00000000..1ce3ff40 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/async_status.ps1 @@ -0,0 +1,58 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$results = @{changed=$false} + +$parsed_args = Parse-Args $args +$jid = Get-AnsibleParam $parsed_args "jid" -failifempty $true -resultobj $results +$mode = Get-AnsibleParam $parsed_args "mode" -Default "status" -ValidateSet "status","cleanup" + +# parsed in from the async_status action plugin +$async_dir = Get-AnsibleParam $parsed_args "_async_dir" -type "path" -failifempty $true + +$log_path = [System.IO.Path]::Combine($async_dir, $jid) + +If(-not $(Test-Path $log_path)) +{ + Fail-Json @{ansible_job_id=$jid; started=1; finished=1} "could not find job at '$async_dir'" +} + +If($mode -eq "cleanup") { + Remove-Item $log_path -Recurse + Exit-Json @{ansible_job_id=$jid; erased=$log_path} +} + +# NOT in cleanup mode, assume regular status mode +# no remote kill mode currently exists, but probably should +# consider log_path + ".pid" file and also unlink that above + +$data = $null +Try { + $data_raw = Get-Content $log_path + + # TODO: move this into module_utils/powershell.ps1? + $jss = New-Object System.Web.Script.Serialization.JavaScriptSerializer + $data = $jss.DeserializeObject($data_raw) +} +Catch { + If(-not $data_raw) { + # file not written yet? That means it is running + Exit-Json @{results_file=$log_path; ansible_job_id=$jid; started=1; finished=0} + } + Else { + Fail-Json @{ansible_job_id=$jid; results_file=$log_path; started=1; finished=1} "Could not parse job output: $data" + } +} + +If (-not $data.ContainsKey("started")) { + $data['finished'] = 1 + $data['ansible_job_id'] = $jid +} +ElseIf (-not $data.ContainsKey("finished")) { + $data['finished'] = 0 +} + +Exit-Json $data diff --git a/test/support/windows-integration/plugins/modules/setup.ps1 b/test/support/windows-integration/plugins/modules/setup.ps1 new file mode 100644 index 00000000..50647239 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/setup.ps1 @@ -0,0 +1,516 @@ +#!powershell + +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +Function Get-CustomFacts { + [cmdletBinding()] + param ( + [Parameter(mandatory=$false)] + $factpath = $null + ) + + if (Test-Path -Path $factpath) { + $FactsFiles = Get-ChildItem -Path $factpath | Where-Object -FilterScript {($PSItem.PSIsContainer -eq $false) -and ($PSItem.Extension -eq '.ps1')} + + foreach ($FactsFile in $FactsFiles) { + $out = & $($FactsFile.FullName) + $result.ansible_facts.Add("ansible_$(($FactsFile.Name).Split('.')[0])", $out) + } + } + else + { + Add-Warning $result "Non existing path was set for local facts - $factpath" + } +} + +Function Get-MachineSid { + # The Machine SID is stored in HKLM:\SECURITY\SAM\Domains\Account and is + # only accessible by the Local System account. This method get's the local + # admin account (ends with -500) and lops it off to get the machine sid. + + $machine_sid = $null + + try { + $admins_sid = "S-1-5-32-544" + $admin_group = ([Security.Principal.SecurityIdentifier]$admins_sid).Translate([Security.Principal.NTAccount]).Value + + Add-Type -AssemblyName System.DirectoryServices.AccountManagement + $principal_context = New-Object -TypeName System.DirectoryServices.AccountManagement.PrincipalContext([System.DirectoryServices.AccountManagement.ContextType]::Machine) + $group_principal = New-Object -TypeName System.DirectoryServices.AccountManagement.GroupPrincipal($principal_context, $admin_group) + $searcher = New-Object -TypeName System.DirectoryServices.AccountManagement.PrincipalSearcher($group_principal) + $groups = $searcher.FindOne() + + foreach ($user in $groups.Members) { + $user_sid = $user.Sid + if ($user_sid.Value.EndsWith("-500")) { + $machine_sid = $user_sid.AccountDomainSid.Value + break + } + } + } catch { + #can fail for any number of reasons, if it does just return the original null + Add-Warning -obj $result -message "Error during machine sid retrieval: $($_.Exception.Message)" + } + + return $machine_sid +} + +$cim_instances = @{} + +Function Get-LazyCimInstance([string]$instance_name, [string]$namespace="Root\CIMV2") { + if(-not $cim_instances.ContainsKey($instance_name)) { + $cim_instances[$instance_name] = $(Get-CimInstance -Namespace $namespace -ClassName $instance_name) + } + + return $cim_instances[$instance_name] +} + +$result = @{ + ansible_facts = @{ } + changed = $false +} + +$grouped_subsets = @{ + min=[System.Collections.Generic.List[string]]@('date_time','distribution','dns','env','local','platform','powershell_version','user') + network=[System.Collections.Generic.List[string]]@('all_ipv4_addresses','all_ipv6_addresses','interfaces','windows_domain', 'winrm') + hardware=[System.Collections.Generic.List[string]]@('bios','memory','processor','uptime','virtual') + external=[System.Collections.Generic.List[string]]@('facter') +} + +# build "all" set from everything mentioned in the group- this means every value must be in at least one subset to be considered legal +$all_set = [System.Collections.Generic.HashSet[string]]@() + +foreach($kv in $grouped_subsets.GetEnumerator()) { + [void] $all_set.UnionWith($kv.Value) +} + +# dynamically create an "all" subset now that we know what should be in it +$grouped_subsets['all'] = [System.Collections.Generic.List[string]]$all_set + +# start with all, build up gather and exclude subsets +$gather_subset = [System.Collections.Generic.HashSet[string]]$grouped_subsets.all +$explicit_subset = [System.Collections.Generic.HashSet[string]]@() +$exclude_subset = [System.Collections.Generic.HashSet[string]]@() + +$params = Parse-Args $args -supports_check_mode $true +$factpath = Get-AnsibleParam -obj $params -name "fact_path" -type "path" +$gather_subset_source = Get-AnsibleParam -obj $params -name "gather_subset" -type "list" -default "all" + +foreach($item in $gather_subset_source) { + if(([string]$item).StartsWith("!")) { + $item = ([string]$item).Substring(1) + if($item -eq "all") { + $all_minus_min = [System.Collections.Generic.HashSet[string]]@($all_set) + [void] $all_minus_min.ExceptWith($grouped_subsets.min) + [void] $exclude_subset.UnionWith($all_minus_min) + } + elseif($grouped_subsets.ContainsKey($item)) { + [void] $exclude_subset.UnionWith($grouped_subsets[$item]) + } + elseif($all_set.Contains($item)) { + [void] $exclude_subset.Add($item) + } + # NB: invalid exclude values are ignored, since that's what posix setup does + } + else { + if($grouped_subsets.ContainsKey($item)) { + [void] $explicit_subset.UnionWith($grouped_subsets[$item]) + } + elseif($all_set.Contains($item)) { + [void] $explicit_subset.Add($item) + } + else { + # NB: POSIX setup fails on invalid value; we warn, because we don't implement the same set as POSIX + # and we don't have platform-specific config for this... + Add-Warning $result "invalid value $item specified in gather_subset" + } + } +} + +[void] $gather_subset.ExceptWith($exclude_subset) +[void] $gather_subset.UnionWith($explicit_subset) + +$ansible_facts = @{ + gather_subset=@($gather_subset_source) + module_setup=$true +} + +$osversion = [Environment]::OSVersion + +if ($osversion.Version -lt [version]"6.2") { + # Server 2008, 2008 R2, and Windows 7 are not tested in CI and we want to let customers know about it before + # removing support altogether. + $version_string = "{0}.{1}" -f ($osversion.Version.Major, $osversion.Version.Minor) + $msg = "Windows version '$version_string' will no longer be supported or tested in the next Ansible release" + Add-DeprecationWarning -obj $result -message $msg -version "2.11" +} + +if($gather_subset.Contains('all_ipv4_addresses') -or $gather_subset.Contains('all_ipv6_addresses')) { + $netcfg = Get-LazyCimInstance Win32_NetworkAdapterConfiguration + + # TODO: split v4/v6 properly, return in separate keys + $ips = @() + Foreach ($ip in $netcfg.IPAddress) { + If ($ip) { + $ips += $ip + } + } + + $ansible_facts += @{ + ansible_ip_addresses = $ips + } +} + +if($gather_subset.Contains('bios')) { + $win32_bios = Get-LazyCimInstance Win32_Bios + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $ansible_facts += @{ + ansible_bios_date = $win32_bios.ReleaseDate.ToString("MM/dd/yyyy") + ansible_bios_version = $win32_bios.SMBIOSBIOSVersion + ansible_product_name = $win32_cs.Model.Trim() + ansible_product_serial = $win32_bios.SerialNumber + # ansible_product_version = ([string] $win32_cs.SystemFamily) + } +} + +if($gather_subset.Contains('date_time')) { + $datetime = (Get-Date) + $datetime_utc = $datetime.ToUniversalTime() + $date = @{ + date = $datetime.ToString("yyyy-MM-dd") + day = $datetime.ToString("dd") + epoch = (Get-Date -UFormat "%s") + hour = $datetime.ToString("HH") + iso8601 = $datetime_utc.ToString("yyyy-MM-ddTHH:mm:ssZ") + iso8601_basic = $datetime.ToString("yyyyMMddTHHmmssffffff") + iso8601_basic_short = $datetime.ToString("yyyyMMddTHHmmss") + iso8601_micro = $datetime_utc.ToString("yyyy-MM-ddTHH:mm:ss.ffffffZ") + minute = $datetime.ToString("mm") + month = $datetime.ToString("MM") + second = $datetime.ToString("ss") + time = $datetime.ToString("HH:mm:ss") + tz = ([System.TimeZoneInfo]::Local.Id) + tz_offset = $datetime.ToString("zzzz") + # Ensure that the weekday is in English + weekday = $datetime.ToString("dddd", [System.Globalization.CultureInfo]::InvariantCulture) + weekday_number = (Get-Date -UFormat "%w") + weeknumber = (Get-Date -UFormat "%W") + year = $datetime.ToString("yyyy") + } + + $ansible_facts += @{ + ansible_date_time = $date + } +} + +if($gather_subset.Contains('distribution')) { + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $product_type = switch($win32_os.ProductType) { + 1 { "workstation" } + 2 { "domain_controller" } + 3 { "server" } + default { "unknown" } + } + + $installation_type = $null + $current_version_path = "HKLM:\SOFTWARE\Microsoft\Windows NT\CurrentVersion" + if (Test-Path -LiteralPath $current_version_path) { + $install_type_prop = Get-ItemProperty -LiteralPath $current_version_path -ErrorAction SilentlyContinue + $installation_type = [String]$install_type_prop.InstallationType + } + + $ansible_facts += @{ + ansible_distribution = $win32_os.Caption + ansible_distribution_version = $osversion.Version.ToString() + ansible_distribution_major_version = $osversion.Version.Major.ToString() + ansible_os_family = "Windows" + ansible_os_name = ($win32_os.Name.Split('|')[0]).Trim() + ansible_os_product_type = $product_type + ansible_os_installation_type = $installation_type + } +} + +if($gather_subset.Contains('env')) { + $env_vars = @{ } + foreach ($item in Get-ChildItem Env:) { + $name = $item | Select-Object -ExpandProperty Name + # Powershell ConvertTo-Json fails if string ends with \ + $value = ($item | Select-Object -ExpandProperty Value).TrimEnd("\") + $env_vars.Add($name, $value) + } + + $ansible_facts += @{ + ansible_env = $env_vars + } +} + +if($gather_subset.Contains('facter')) { + # See if Facter is on the System Path + Try { + Get-Command facter -ErrorAction Stop > $null + $facter_installed = $true + } Catch { + $facter_installed = $false + } + + # Get JSON from Facter, and parse it out. + if ($facter_installed) { + &facter -j | Tee-Object -Variable facter_output > $null + $facts = "$facter_output" | ConvertFrom-Json + ForEach($fact in $facts.PSObject.Properties) { + $fact_name = $fact.Name + $ansible_facts.Add("facter_$fact_name", $fact.Value) + } + } +} + +if($gather_subset.Contains('interfaces')) { + $netcfg = Get-LazyCimInstance Win32_NetworkAdapterConfiguration + $ActiveNetcfg = @() + $ActiveNetcfg += $netcfg | Where-Object {$_.ipaddress -ne $null} + + $namespaces = Get-LazyCimInstance __Namespace -namespace root + if ($namespaces | Where-Object { $_.Name -eq "StandardCimv" }) { + $net_adapters = Get-LazyCimInstance MSFT_NetAdapter -namespace Root\StandardCimv2 + $guid_key = "InterfaceGUID" + $name_key = "Name" + } else { + $net_adapters = Get-LazyCimInstance Win32_NetworkAdapter + $guid_key = "GUID" + $name_key = "NetConnectionID" + } + + $formattednetcfg = @() + foreach ($adapter in $ActiveNetcfg) + { + $thisadapter = @{ + default_gateway = $null + connection_name = $null + dns_domain = $adapter.dnsdomain + interface_index = $adapter.InterfaceIndex + interface_name = $adapter.description + macaddress = $adapter.macaddress + } + + if ($adapter.defaultIPGateway) + { + $thisadapter.default_gateway = $adapter.DefaultIPGateway[0].ToString() + } + $net_adapter = $net_adapters | Where-Object { $_.$guid_key -eq $adapter.SettingID } + if ($net_adapter) { + $thisadapter.connection_name = $net_adapter.$name_key + } + + $formattednetcfg += $thisadapter + } + + $ansible_facts += @{ + ansible_interfaces = $formattednetcfg + } +} + +if ($gather_subset.Contains("local") -and $null -ne $factpath) { + # Get any custom facts; results are updated in the + Get-CustomFacts -factpath $factpath +} + +if($gather_subset.Contains('memory')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $ansible_facts += @{ + # Win32_PhysicalMemory is empty on some virtual platforms + ansible_memtotal_mb = ([math]::ceiling($win32_cs.TotalPhysicalMemory / 1024 / 1024)) + ansible_memfree_mb = ([math]::ceiling($win32_os.FreePhysicalMemory / 1024)) + ansible_swaptotal_mb = ([math]::round($win32_os.TotalSwapSpaceSize / 1024)) + ansible_pagefiletotal_mb = ([math]::round($win32_os.SizeStoredInPagingFiles / 1024)) + ansible_pagefilefree_mb = ([math]::round($win32_os.FreeSpaceInPagingFiles / 1024)) + } +} + + +if($gather_subset.Contains('platform')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $domain_suffix = $win32_cs.Domain.Substring($win32_cs.Workgroup.length) + $fqdn = $win32_cs.DNSHostname + + if( $domain_suffix -ne "") + { + $fqdn = $win32_cs.DNSHostname + "." + $domain_suffix + } + + try { + $ansible_reboot_pending = Get-PendingRebootStatus + } catch { + # fails for non-admin users, set to null in this case + $ansible_reboot_pending = $null + } + + $ansible_facts += @{ + ansible_architecture = $win32_os.OSArchitecture + ansible_domain = $domain_suffix + ansible_fqdn = $fqdn + ansible_hostname = $win32_cs.DNSHostname + ansible_netbios_name = $win32_cs.Name + ansible_kernel = $osversion.Version.ToString() + ansible_nodename = $fqdn + ansible_machine_id = Get-MachineSid + ansible_owner_contact = ([string] $win32_cs.PrimaryOwnerContact) + ansible_owner_name = ([string] $win32_cs.PrimaryOwnerName) + # FUTURE: should this live in its own subset? + ansible_reboot_pending = $ansible_reboot_pending + ansible_system = $osversion.Platform.ToString() + ansible_system_description = ([string] $win32_os.Description) + ansible_system_vendor = $win32_cs.Manufacturer + } +} + +if($gather_subset.Contains('powershell_version')) { + $ansible_facts += @{ + ansible_powershell_version = ($PSVersionTable.PSVersion.Major) + } +} + +if($gather_subset.Contains('processor')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $win32_cpu = Get-LazyCimInstance Win32_Processor + if ($win32_cpu -is [array]) { + # multi-socket, pick first + $win32_cpu = $win32_cpu[0] + } + + $cpu_list = @( ) + for ($i=1; $i -le $win32_cs.NumberOfLogicalProcessors; $i++) { + $cpu_list += $win32_cpu.Manufacturer + $cpu_list += $win32_cpu.Name + } + + $ansible_facts += @{ + ansible_processor = $cpu_list + ansible_processor_cores = $win32_cpu.NumberOfCores + ansible_processor_count = $win32_cs.NumberOfProcessors + ansible_processor_threads_per_core = ($win32_cpu.NumberOfLogicalProcessors / $win32_cpu.NumberofCores) + ansible_processor_vcpus = $win32_cs.NumberOfLogicalProcessors + } +} + +if($gather_subset.Contains('uptime')) { + $win32_os = Get-LazyCimInstance Win32_OperatingSystem + $ansible_facts += @{ + ansible_lastboot = $win32_os.lastbootuptime.ToString("u") + ansible_uptime_seconds = $([System.Convert]::ToInt64($(Get-Date).Subtract($win32_os.lastbootuptime).TotalSeconds)) + } +} + +if($gather_subset.Contains('user')) { + $user = [Security.Principal.WindowsIdentity]::GetCurrent() + $ansible_facts += @{ + ansible_user_dir = $env:userprofile + # Win32_UserAccount.FullName is probably the right thing here, but it can be expensive to get on large domains + ansible_user_gecos = "" + ansible_user_id = $env:username + ansible_user_sid = $user.User.Value + } +} + +if($gather_subset.Contains('windows_domain')) { + $win32_cs = Get-LazyCimInstance Win32_ComputerSystem + $domain_roles = @{ + 0 = "Stand-alone workstation" + 1 = "Member workstation" + 2 = "Stand-alone server" + 3 = "Member server" + 4 = "Backup domain controller" + 5 = "Primary domain controller" + } + + $domain_role = $domain_roles.Get_Item([Int32]$win32_cs.DomainRole) + + $ansible_facts += @{ + ansible_windows_domain = $win32_cs.Domain + ansible_windows_domain_member = $win32_cs.PartOfDomain + ansible_windows_domain_role = $domain_role + } +} + +if($gather_subset.Contains('winrm')) { + + $winrm_https_listener_parent_paths = Get-ChildItem -Path WSMan:\localhost\Listener -Recurse -ErrorAction SilentlyContinue | ` + Where-Object {$_.PSChildName -eq "Transport" -and $_.Value -eq "HTTPS"} | Select-Object PSParentPath + if ($winrm_https_listener_parent_paths -isnot [array]) { + $winrm_https_listener_parent_paths = @($winrm_https_listener_parent_paths) + } + + $winrm_https_listener_paths = @() + foreach ($winrm_https_listener_parent_path in $winrm_https_listener_parent_paths) { + $winrm_https_listener_paths += $winrm_https_listener_parent_path.PSParentPath.Substring($winrm_https_listener_parent_path.PSParentPath.LastIndexOf("\")) + } + + $https_listeners = @() + foreach ($winrm_https_listener_path in $winrm_https_listener_paths) { + $https_listeners += Get-ChildItem -Path "WSMan:\localhost\Listener$winrm_https_listener_path" + } + + $winrm_cert_thumbprints = @() + foreach ($https_listener in $https_listeners) { + $winrm_cert_thumbprints += $https_listener | Where-Object {$_.Name -EQ "CertificateThumbprint" } | Select-Object Value + } + + $winrm_cert_expiry = @() + foreach ($winrm_cert_thumbprint in $winrm_cert_thumbprints) { + Try { + $winrm_cert_expiry += Get-ChildItem -Path Cert:\LocalMachine\My | Where-Object Thumbprint -EQ $winrm_cert_thumbprint.Value.ToString().ToUpper() | Select-Object NotAfter + } Catch { + Add-Warning -obj $result -message "Error during certificate expiration retrieval: $($_.Exception.Message)" + } + } + + $winrm_cert_expirations = $winrm_cert_expiry | Sort-Object NotAfter + if ($winrm_cert_expirations) { + # this fact was renamed from ansible_winrm_certificate_expires due to collision with ansible_winrm_X connection var pattern + $ansible_facts.Add("ansible_win_rm_certificate_expires", $winrm_cert_expirations[0].NotAfter.ToString("yyyy-MM-dd HH:mm:ss")) + } +} + +if($gather_subset.Contains('virtual')) { + $machine_info = Get-LazyCimInstance Win32_ComputerSystem + + switch ($machine_info.model) { + "Virtual Machine" { + $machine_type="Hyper-V" + $machine_role="guest" + } + + "VMware Virtual Platform" { + $machine_type="VMware" + $machine_role="guest" + } + + "VirtualBox" { + $machine_type="VirtualBox" + $machine_role="guest" + } + + "HVM domU" { + $machine_type="Xen" + $machine_role="guest" + } + + default { + $machine_type="NA" + $machine_role="NA" + } + } + + $ansible_facts += @{ + ansible_virtualization_role = $machine_role + ansible_virtualization_type = $machine_type + } +} + +$result.ansible_facts += $ansible_facts + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/slurp.ps1 b/test/support/windows-integration/plugins/modules/slurp.ps1 new file mode 100644 index 00000000..eb506c7c --- /dev/null +++ b/test/support/windows-integration/plugins/modules/slurp.ps1 @@ -0,0 +1,28 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$params = Parse-Args $args -supports_check_mode $true; +$src = Get-AnsibleParam -obj $params -name "src" -type "path" -aliases "path" -failifempty $true; + +$result = @{ + changed = $false; +} + +If (Test-Path -LiteralPath $src -PathType Leaf) +{ + $bytes = [System.IO.File]::ReadAllBytes($src); + $result.content = [System.Convert]::ToBase64String($bytes); + $result.encoding = "base64"; + Exit-Json $result; +} +ElseIf (Test-Path -LiteralPath $src -PathType Container) +{ + Fail-Json $result "Path $src is a directory"; +} +Else +{ + Fail-Json $result "Path $src is not found"; +} diff --git a/test/support/windows-integration/plugins/modules/win_acl.ps1 b/test/support/windows-integration/plugins/modules/win_acl.ps1 new file mode 100644 index 00000000..e3c38130 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_acl.ps1 @@ -0,0 +1,225 @@ +#!powershell + +# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com> +# Copyright: (c) 2015, Trond Hindenes +# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.PrivilegeUtil +#Requires -Module Ansible.ModuleUtils.SID + +$ErrorActionPreference = "Stop" + +# win_acl module (File/Resources Permission Additions/Removal) + +#Functions +function Get-UserSID { + param( + [String]$AccountName + ) + + $userSID = $null + $searchAppPools = $false + + if ($AccountName.Split("\").Count -gt 1) { + if ($AccountName.Split("\")[0] -eq "IIS APPPOOL") { + $searchAppPools = $true + $AccountName = $AccountName.Split("\")[1] + } + } + + if ($searchAppPools) { + Import-Module -Name WebAdministration + $testIISPath = Test-Path -LiteralPath "IIS:" + if ($testIISPath) { + $appPoolObj = Get-ItemProperty -LiteralPath "IIS:\AppPools\$AccountName" + $userSID = $appPoolObj.applicationPoolSid + } + } + else { + $userSID = Convert-ToSID -account_name $AccountName + } + + return $userSID +} + +$params = Parse-Args $args + +Function SetPrivilegeTokens() { + # Set privilege tokens only if admin. + # Admins would have these privs or be able to set these privs in the UI Anyway + + $adminRole=[System.Security.Principal.WindowsBuiltInRole]::Administrator + $myWindowsID=[System.Security.Principal.WindowsIdentity]::GetCurrent() + $myWindowsPrincipal=new-object System.Security.Principal.WindowsPrincipal($myWindowsID) + + + if ($myWindowsPrincipal.IsInRole($adminRole)) { + # Need to adjust token privs when executing Set-ACL in certain cases. + # e.g. d:\testdir is owned by group in which current user is not a member and no perms are inherited from d:\ + # This also sets us up for setting the owner as a feature. + # See the following for details of each privilege + # https://msdn.microsoft.com/en-us/library/windows/desktop/bb530716(v=vs.85).aspx + $privileges = @( + "SeRestorePrivilege", # Grants all write access control to any file, regardless of ACL. + "SeBackupPrivilege", # Grants all read access control to any file, regardless of ACL. + "SeTakeOwnershipPrivilege" # Grants ability to take owernship of an object w/out being granted discretionary access + ) + foreach ($privilege in $privileges) { + $state = Get-AnsiblePrivilege -Name $privilege + if ($state -eq $false) { + Set-AnsiblePrivilege -Name $privilege -Value $true + } + } + } +} + + +$result = @{ + changed = $false +} + +$path = Get-AnsibleParam -obj $params -name "path" -type "str" -failifempty $true +$user = Get-AnsibleParam -obj $params -name "user" -type "str" -failifempty $true +$rights = Get-AnsibleParam -obj $params -name "rights" -type "str" -failifempty $true + +$type = Get-AnsibleParam -obj $params -name "type" -type "str" -failifempty $true -validateset "allow","deny" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "absent","present" + +$inherit = Get-AnsibleParam -obj $params -name "inherit" -type "str" +$propagation = Get-AnsibleParam -obj $params -name "propagation" -type "str" -default "None" -validateset "InheritOnly","None","NoPropagateInherit" + +# We mount the HKCR, HKU, and HKCC registry hives so PS can access them. +# Network paths have no qualifiers so we use -EA SilentlyContinue to ignore that +$path_qualifier = Split-Path -Path $path -Qualifier -ErrorAction SilentlyContinue +if ($path_qualifier -eq "HKCR:" -and (-not (Test-Path -LiteralPath HKCR:\))) { + New-PSDrive -Name HKCR -PSProvider Registry -Root HKEY_CLASSES_ROOT > $null +} +if ($path_qualifier -eq "HKU:" -and (-not (Test-Path -LiteralPath HKU:\))) { + New-PSDrive -Name HKU -PSProvider Registry -Root HKEY_USERS > $null +} +if ($path_qualifier -eq "HKCC:" -and (-not (Test-Path -LiteralPath HKCC:\))) { + New-PSDrive -Name HKCC -PSProvider Registry -Root HKEY_CURRENT_CONFIG > $null +} + +If (-Not (Test-Path -LiteralPath $path)) { + Fail-Json -obj $result -message "$path file or directory does not exist on the host" +} + +# Test that the user/group is resolvable on the local machine +$sid = Get-UserSID -AccountName $user +if (!$sid) { + Fail-Json -obj $result -message "$user is not a valid user or group on the host machine or domain" +} + +If (Test-Path -LiteralPath $path -PathType Leaf) { + $inherit = "None" +} +ElseIf ($null -eq $inherit) { + $inherit = "ContainerInherit, ObjectInherit" +} + +# Bug in Set-Acl, Get-Acl where -LiteralPath only works for the Registry provider if the location is in that root +# qualifier. We also don't have a qualifier for a network path so only change if not null +if ($null -ne $path_qualifier) { + Push-Location -LiteralPath $path_qualifier +} + +Try { + SetPrivilegeTokens + $path_item = Get-Item -LiteralPath $path -Force + If ($path_item.PSProvider.Name -eq "Registry") { + $colRights = [System.Security.AccessControl.RegistryRights]$rights + } + Else { + $colRights = [System.Security.AccessControl.FileSystemRights]$rights + } + + $InheritanceFlag = [System.Security.AccessControl.InheritanceFlags]$inherit + $PropagationFlag = [System.Security.AccessControl.PropagationFlags]$propagation + + If ($type -eq "allow") { + $objType =[System.Security.AccessControl.AccessControlType]::Allow + } + Else { + $objType =[System.Security.AccessControl.AccessControlType]::Deny + } + + $objUser = New-Object System.Security.Principal.SecurityIdentifier($sid) + If ($path_item.PSProvider.Name -eq "Registry") { + $objACE = New-Object System.Security.AccessControl.RegistryAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType) + } + Else { + $objACE = New-Object System.Security.AccessControl.FileSystemAccessRule ($objUser, $colRights, $InheritanceFlag, $PropagationFlag, $objType) + } + $objACL = Get-ACL -LiteralPath $path + + # Check if the ACE exists already in the objects ACL list + $match = $false + + ForEach($rule in $objACL.GetAccessRules($true, $true, [System.Security.Principal.SecurityIdentifier])){ + + If ($path_item.PSProvider.Name -eq "Registry") { + If (($rule.RegistryRights -eq $objACE.RegistryRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) { + $match = $true + Break + } + } else { + If (($rule.FileSystemRights -eq $objACE.FileSystemRights) -And ($rule.AccessControlType -eq $objACE.AccessControlType) -And ($rule.IdentityReference -eq $objACE.IdentityReference) -And ($rule.IsInherited -eq $objACE.IsInherited) -And ($rule.InheritanceFlags -eq $objACE.InheritanceFlags) -And ($rule.PropagationFlags -eq $objACE.PropagationFlags)) { + $match = $true + Break + } + } + } + + If ($state -eq "present" -And $match -eq $false) { + Try { + $objACL.AddAccessRule($objACE) + If ($path_item.PSProvider.Name -eq "Registry") { + Set-ACL -LiteralPath $path -AclObject $objACL + } else { + (Get-Item -LiteralPath $path).SetAccessControl($objACL) + } + $result.changed = $true + } + Catch { + Fail-Json -obj $result -message "an exception occurred when adding the specified rule - $($_.Exception.Message)" + } + } + ElseIf ($state -eq "absent" -And $match -eq $true) { + Try { + $objACL.RemoveAccessRule($objACE) + If ($path_item.PSProvider.Name -eq "Registry") { + Set-ACL -LiteralPath $path -AclObject $objACL + } else { + (Get-Item -LiteralPath $path).SetAccessControl($objACL) + } + $result.changed = $true + } + Catch { + Fail-Json -obj $result -message "an exception occurred when removing the specified rule - $($_.Exception.Message)" + } + } + Else { + # A rule was attempting to be added but already exists + If ($match -eq $true) { + Exit-Json -obj $result -message "the specified rule already exists" + } + # A rule didn't exist that was trying to be removed + Else { + Exit-Json -obj $result -message "the specified rule does not exist" + } + } +} +Catch { + Fail-Json -obj $result -message "an error occurred when attempting to $state $rights permission(s) on $path for $user - $($_.Exception.Message)" +} +Finally { + # Make sure we revert the location stack to the original path just for cleanups sake + if ($null -ne $path_qualifier) { + Pop-Location + } +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/plugins/modules/win_acl.py b/test/support/windows-integration/plugins/modules/win_acl.py new file mode 100644 index 00000000..14fbd82f --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_acl.py @@ -0,0 +1,132 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Phil Schwartz <schwartzmx@gmail.com> +# Copyright: (c) 2015, Trond Hindenes +# Copyright: (c) 2015, Hans-Joachim Kliemeck <git@kliemeck.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_acl +version_added: "2.0" +short_description: Set file/directory/registry permissions for a system user or group +description: +- Add or remove rights/permissions for a given user or group for the specified + file, folder, registry key or AppPool identifies. +options: + path: + description: + - The path to the file or directory. + type: str + required: yes + user: + description: + - User or Group to add specified rights to act on src file/folder or + registry key. + type: str + required: yes + state: + description: + - Specify whether to add C(present) or remove C(absent) the specified access rule. + type: str + choices: [ absent, present ] + default: present + type: + description: + - Specify whether to allow or deny the rights specified. + type: str + required: yes + choices: [ allow, deny ] + rights: + description: + - The rights/permissions that are to be allowed/denied for the specified + user or group for the item at C(path). + - If C(path) is a file or directory, rights can be any right under MSDN + FileSystemRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.filesystemrights.aspx). + - If C(path) is a registry key, rights can be any right under MSDN + RegistryRights U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.registryrights.aspx). + type: str + required: yes + inherit: + description: + - Inherit flags on the ACL rules. + - Can be specified as a comma separated list, e.g. C(ContainerInherit), + C(ObjectInherit). + - For more information on the choices see MSDN InheritanceFlags enumeration + at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.inheritanceflags.aspx). + - Defaults to C(ContainerInherit, ObjectInherit) for Directories. + type: str + choices: [ ContainerInherit, ObjectInherit ] + propagation: + description: + - Propagation flag on the ACL rules. + - For more information on the choices see MSDN PropagationFlags enumeration + at U(https://msdn.microsoft.com/en-us/library/system.security.accesscontrol.propagationflags.aspx). + type: str + choices: [ InheritOnly, None, NoPropagateInherit ] + default: "None" +notes: +- If adding ACL's for AppPool identities (available since 2.3), the Windows + Feature "Web-Scripting-Tools" must be enabled. +seealso: +- module: win_acl_inheritance +- module: win_file +- module: win_owner +- module: win_stat +author: +- Phil Schwartz (@schwartzmx) +- Trond Hindenes (@trondhindenes) +- Hans-Joachim Kliemeck (@h0nIg) +''' + +EXAMPLES = r''' +- name: Restrict write and execute access to User Fed-Phil + win_acl: + user: Fed-Phil + path: C:\Important\Executable.exe + type: deny + rights: ExecuteFile,Write + +- name: Add IIS_IUSRS allow rights + win_acl: + path: C:\inetpub\wwwroot\MySite + user: IIS_IUSRS + rights: FullControl + type: allow + state: present + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Set registry key right + win_acl: + path: HKCU:\Bovine\Key + user: BUILTIN\Users + rights: EnumerateSubKeys + type: allow + state: present + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Remove FullControl AccessRule for IIS_IUSRS + win_acl: + path: C:\inetpub\wwwroot\MySite + user: IIS_IUSRS + rights: FullControl + type: allow + state: absent + inherit: ContainerInherit, ObjectInherit + propagation: 'None' + +- name: Deny Intern + win_acl: + path: C:\Administrator\Documents + user: Intern + rights: Read,Write,Modify,FullControl,Delete + type: deny + state: present +''' diff --git a/test/support/windows-integration/plugins/modules/win_certificate_store.ps1 b/test/support/windows-integration/plugins/modules/win_certificate_store.ps1 new file mode 100644 index 00000000..db984130 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_certificate_store.ps1 @@ -0,0 +1,260 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic + +$store_name_values = ([System.Security.Cryptography.X509Certificates.StoreName]).GetEnumValues() | ForEach-Object { $_.ToString() } +$store_location_values = ([System.Security.Cryptography.X509Certificates.StoreLocation]).GetEnumValues() | ForEach-Object { $_.ToString() } + +$spec = @{ + options = @{ + state = @{ type = "str"; default = "present"; choices = "absent", "exported", "present" } + path = @{ type = "path" } + thumbprint = @{ type = "str" } + store_name = @{ type = "str"; default = "My"; choices = $store_name_values } + store_location = @{ type = "str"; default = "LocalMachine"; choices = $store_location_values } + password = @{ type = "str"; no_log = $true } + key_exportable = @{ type = "bool"; default = $true } + key_storage = @{ type = "str"; default = "default"; choices = "default", "machine", "user" } + file_type = @{ type = "str"; default = "der"; choices = "der", "pem", "pkcs12" } + } + required_if = @( + @("state", "absent", @("path", "thumbprint"), $true), + @("state", "exported", @("path", "thumbprint")), + @("state", "present", @("path")) + ) + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +Function Get-CertFile($module, $path, $password, $key_exportable, $key_storage) { + # parses a certificate file and returns X509Certificate2Collection + if (-not (Test-Path -LiteralPath $path -PathType Leaf)) { + $module.FailJson("File at '$path' either does not exist or is not a file") + } + + # must set at least the PersistKeySet flag so that the PrivateKey + # is stored in a permanent container and not deleted once the handle + # is gone. + $store_flags = [System.Security.Cryptography.X509Certificates.X509KeyStorageFlags]::PersistKeySet + + $key_storage = $key_storage.substring(0,1).ToUpper() + $key_storage.substring(1).ToLower() + $store_flags = $store_flags -bor [Enum]::Parse([System.Security.Cryptography.X509Certificates.X509KeyStorageFlags], "$($key_storage)KeySet") + if ($key_exportable) { + $store_flags = $store_flags -bor [System.Security.Cryptography.X509Certificates.X509KeyStorageFlags]::Exportable + } + + # TODO: If I'm feeling adventurours, write code to parse PKCS#12 PEM encoded + # file as .NET does not have an easy way to import this + $certs = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Certificate2Collection + + try { + $certs.Import($path, $password, $store_flags) + } catch { + $module.FailJson("Failed to load cert from file: $($_.Exception.Message)", $_) + } + + return $certs +} + +Function New-CertFile($module, $cert, $path, $type, $password) { + $content_type = switch ($type) { + "pem" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Cert } + "der" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Cert } + "pkcs12" { [System.Security.Cryptography.X509Certificates.X509ContentType]::Pkcs12 } + } + if ($type -eq "pkcs12") { + $missing_key = $false + if ($null -eq $cert.PrivateKey) { + $missing_key = $true + } elseif ($cert.PrivateKey.CspKeyContainerInfo.Exportable -eq $false) { + $missing_key = $true + } + if ($missing_key) { + $module.FailJson("Cannot export cert with key as PKCS12 when the key is not marked as exportable or not accessible by the current user") + } + } + + if (Test-Path -LiteralPath $path) { + Remove-Item -LiteralPath $path -Force + $module.Result.changed = $true + } + try { + $cert_bytes = $cert.Export($content_type, $password) + } catch { + $module.FailJson("Failed to export certificate as bytes: $($_.Exception.Message)", $_) + } + + # Need to manually handle a PEM file + if ($type -eq "pem") { + $cert_content = "-----BEGIN CERTIFICATE-----`r`n" + $base64_string = [System.Convert]::ToBase64String($cert_bytes, [System.Base64FormattingOptions]::InsertLineBreaks) + $cert_content += $base64_string + $cert_content += "`r`n-----END CERTIFICATE-----" + $file_encoding = [System.Text.Encoding]::ASCII + $cert_bytes = $file_encoding.GetBytes($cert_content) + } elseif ($type -eq "pkcs12") { + $module.Result.key_exported = $false + if ($null -ne $cert.PrivateKey) { + $module.Result.key_exportable = $cert.PrivateKey.CspKeyContainerInfo.Exportable + } + } + + if (-not $module.CheckMode) { + try { + [System.IO.File]::WriteAllBytes($path, $cert_bytes) + } catch [System.ArgumentNullException] { + $module.FailJson("Failed to write cert to file, cert was null: $($_.Exception.Message)", $_) + } catch [System.IO.IOException] { + $module.FailJson("Failed to write cert to file due to IO Exception: $($_.Exception.Message)", $_) + } catch [System.UnauthorizedAccessException] { + $module.FailJson("Failed to write cert to file due to permissions: $($_.Exception.Message)", $_) + } catch { + $module.FailJson("Failed to write cert to file: $($_.Exception.Message)", $_) + } + } + $module.Result.changed = $true +} + +Function Get-CertFileType($path, $password) { + $certs = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Certificate2Collection + try { + $certs.Import($path, $password, 0) + } catch [System.Security.Cryptography.CryptographicException] { + # the file is a pkcs12 we just had the wrong password + return "pkcs12" + } catch { + return "unknown" + } + + $file_contents = Get-Content -LiteralPath $path -Raw + if ($file_contents.StartsWith("-----BEGIN CERTIFICATE-----")) { + return "pem" + } elseif ($file_contents.StartsWith("-----BEGIN PKCS7-----")) { + return "pkcs7-ascii" + } elseif ($certs.Count -gt 1) { + # multiple certs must be pkcs7 + return "pkcs7-binary" + } elseif ($certs[0].HasPrivateKey) { + return "pkcs12" + } elseif ($path.EndsWith(".pfx") -or $path.EndsWith(".p12")) { + # no way to differenciate a pfx with a der file so we must rely on the + # extension + return "pkcs12" + } else { + return "der" + } +} + +$state = $module.Params.state +$path = $module.Params.path +$thumbprint = $module.Params.thumbprint +$store_name = [System.Security.Cryptography.X509Certificates.StoreName]"$($module.Params.store_name)" +$store_location = [System.Security.Cryptography.X509Certificates.Storelocation]"$($module.Params.store_location)" +$password = $module.Params.password +$key_exportable = $module.Params.key_exportable +$key_storage = $module.Params.key_storage +$file_type = $module.Params.file_type + +$module.Result.thumbprints = @() + +$store = New-Object -TypeName System.Security.Cryptography.X509Certificates.X509Store -ArgumentList $store_name, $store_location +try { + $store.Open([System.Security.Cryptography.X509Certificates.OpenFlags]::ReadWrite) +} catch [System.Security.Cryptography.CryptographicException] { + $module.FailJson("Unable to open the store as it is not readable: $($_.Exception.Message)", $_) +} catch [System.Security.SecurityException] { + $module.FailJson("Unable to open the store with the current permissions: $($_.Exception.Message)", $_) +} catch { + $module.FailJson("Unable to open the store: $($_.Exception.Message)", $_) +} +$store_certificates = $store.Certificates + +try { + if ($state -eq "absent") { + $cert_thumbprints = @() + + if ($null -ne $path) { + $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage + foreach ($cert in $certs) { + $cert_thumbprints += $cert.Thumbprint + } + } elseif ($null -ne $thumbprint) { + $cert_thumbprints += $thumbprint + } + + foreach ($cert_thumbprint in $cert_thumbprints) { + $module.Result.thumbprints += $cert_thumbprint + $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $cert_thumbprint, $false) + if ($found_certs.Count -gt 0) { + foreach ($found_cert in $found_certs) { + try { + if (-not $module.CheckMode) { + $store.Remove($found_cert) + } + } catch [System.Security.SecurityException] { + $module.FailJson("Unable to remove cert with thumbprint '$cert_thumbprint' with current permissions: $($_.Exception.Message)", $_) + } catch { + $module.FailJson("Unable to remove cert with thumbprint '$cert_thumbprint': $($_.Exception.Message)", $_) + } + $module.Result.changed = $true + } + } + } + } elseif ($state -eq "exported") { + # TODO: Add support for PKCS7 and exporting a cert chain + $module.Result.thumbprints += $thumbprint + $export = $true + if (Test-Path -LiteralPath $path -PathType Container) { + $module.FailJson("Cannot export cert to path '$path' as it is a directory") + } elseif (Test-Path -LiteralPath $path -PathType Leaf) { + $actual_cert_type = Get-CertFileType -path $path -password $password + if ($actual_cert_type -eq $file_type) { + try { + $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage + } catch { + # failed to load the file so we set the thumbprint to something + # that will fail validation + $certs = @{Thumbprint = $null} + } + + if ($certs.Thumbprint -eq $thumbprint) { + $export = $false + } + } + } + + if ($export) { + $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $thumbprint, $false) + if ($found_certs.Count -ne 1) { + $module.FailJson("Found $($found_certs.Count) certs when only expecting 1") + } + + New-CertFile -module $module -cert $found_certs -path $path -type $file_type -password $password + } + } else { + $certs = Get-CertFile -module $module -path $path -password $password -key_exportable $key_exportable -key_storage $key_storage + foreach ($cert in $certs) { + $module.Result.thumbprints += $cert.Thumbprint + $found_certs = $store_certificates.Find([System.Security.Cryptography.X509Certificates.X509FindType]::FindByThumbprint, $cert.Thumbprint, $false) + if ($found_certs.Count -eq 0) { + try { + if (-not $module.CheckMode) { + $store.Add($cert) + } + } catch [System.Security.Cryptography.CryptographicException] { + $module.FailJson("Unable to import certificate with thumbprint '$($cert.Thumbprint)' with the current permissions: $($_.Exception.Message)", $_) + } catch { + $module.FailJson("Unable to import certificate with thumbprint '$($cert.Thumbprint)': $($_.Exception.Message)", $_) + } + $module.Result.changed = $true + } + } + } +} finally { + $store.Close() +} + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_certificate_store.py b/test/support/windows-integration/plugins/modules/win_certificate_store.py new file mode 100644 index 00000000..dc617e33 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_certificate_store.py @@ -0,0 +1,208 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_certificate_store +version_added: '2.5' +short_description: Manages the certificate store +description: +- Used to import/export and remove certificates and keys from the local + certificate store. +- This module is not used to create certificates and will only manage existing + certs as a file or in the store. +- It can be used to import PEM, DER, P7B, PKCS12 (PFX) certificates and export + PEM, DER and PKCS12 certificates. +options: + state: + description: + - If C(present), will ensure that the certificate at I(path) is imported + into the certificate store specified. + - If C(absent), will ensure that the certificate specified by I(thumbprint) + or the thumbprint of the cert at I(path) is removed from the store + specified. + - If C(exported), will ensure the file at I(path) is a certificate + specified by I(thumbprint). + - When exporting a certificate, if I(path) is a directory then the module + will fail, otherwise the file will be replaced if needed. + type: str + choices: [ absent, exported, present ] + default: present + path: + description: + - The path to a certificate file. + - This is required when I(state) is C(present) or C(exported). + - When I(state) is C(absent) and I(thumbprint) is not specified, the + thumbprint is derived from the certificate at this path. + type: path + thumbprint: + description: + - The thumbprint as a hex string to either export or remove. + - See the examples for how to specify the thumbprint. + type: str + store_name: + description: + - The store name to use when importing a certificate or searching for a + certificate. + - "C(AddressBook): The X.509 certificate store for other users" + - "C(AuthRoot): The X.509 certificate store for third-party certificate authorities (CAs)" + - "C(CertificateAuthority): The X.509 certificate store for intermediate certificate authorities (CAs)" + - "C(Disallowed): The X.509 certificate store for revoked certificates" + - "C(My): The X.509 certificate store for personal certificates" + - "C(Root): The X.509 certificate store for trusted root certificate authorities (CAs)" + - "C(TrustedPeople): The X.509 certificate store for directly trusted people and resources" + - "C(TrustedPublisher): The X.509 certificate store for directly trusted publishers" + type: str + choices: + - AddressBook + - AuthRoot + - CertificateAuthority + - Disallowed + - My + - Root + - TrustedPeople + - TrustedPublisher + default: My + store_location: + description: + - The store location to use when importing a certificate or searching for a + certificate. + choices: [ CurrentUser, LocalMachine ] + default: LocalMachine + password: + description: + - The password of the pkcs12 certificate key. + - This is used when reading a pkcs12 certificate file or the password to + set when C(state=exported) and C(file_type=pkcs12). + - If the pkcs12 file has no password set or no password should be set on + the exported file, do not set this option. + type: str + key_exportable: + description: + - Whether to allow the private key to be exported. + - If C(no), then this module and other process will only be able to export + the certificate and the private key cannot be exported. + - Used when C(state=present) only. + type: bool + default: yes + key_storage: + description: + - Specifies where Windows will store the private key when it is imported. + - When set to C(default), the default option as set by Windows is used, typically C(user). + - When set to C(machine), the key is stored in a path accessible by various + users. + - When set to C(user), the key is stored in a path only accessible by the + current user. + - Used when C(state=present) only and cannot be changed once imported. + - See U(https://msdn.microsoft.com/en-us/library/system.security.cryptography.x509certificates.x509keystorageflags.aspx) + for more details. + type: str + choices: [ default, machine, user ] + default: default + file_type: + description: + - The file type to export the certificate as when C(state=exported). + - C(der) is a binary ASN.1 encoded file. + - C(pem) is a base64 encoded file of a der file in the OpenSSL form. + - C(pkcs12) (also known as pfx) is a binary container that contains both + the certificate and private key unlike the other options. + - When C(pkcs12) is set and the private key is not exportable or accessible + by the current user, it will throw an exception. + type: str + choices: [ der, pem, pkcs12 ] + default: der +notes: +- Some actions on PKCS12 certificates and keys may fail with the error + C(the specified network password is not correct), either use CredSSP or + Kerberos with credential delegation, or use C(become) to bypass these + restrictions. +- The certificates must be located on the Windows host to be set with I(path). +- When importing a certificate for usage in IIS, it is generally required + to use the C(machine) key_storage option, as both C(default) and C(user) + will make the private key unreadable to IIS APPPOOL identities and prevent + binding the certificate to the https endpoint. +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Import a certificate + win_certificate_store: + path: C:\Temp\cert.pem + state: present + +- name: Import pfx certificate that is password protected + win_certificate_store: + path: C:\Temp\cert.pfx + state: present + password: VeryStrongPasswordHere! + become: yes + become_method: runas + +- name: Import pfx certificate without password and set private key as un-exportable + win_certificate_store: + path: C:\Temp\cert.pfx + state: present + key_exportable: no + # usually you don't set this here but it is for illustrative purposes + vars: + ansible_winrm_transport: credssp + +- name: Remove a certificate based on file thumbprint + win_certificate_store: + path: C:\Temp\cert.pem + state: absent + +- name: Remove a certificate based on thumbprint + win_certificate_store: + thumbprint: BD7AF104CF1872BDB518D95C9534EA941665FD27 + state: absent + +- name: Remove certificate based on thumbprint is CurrentUser/TrustedPublishers store + win_certificate_store: + thumbprint: BD7AF104CF1872BDB518D95C9534EA941665FD27 + state: absent + store_location: CurrentUser + store_name: TrustedPublisher + +- name: Export certificate as der encoded file + win_certificate_store: + path: C:\Temp\cert.cer + state: exported + file_type: der + +- name: Export certificate and key as pfx encoded file + win_certificate_store: + path: C:\Temp\cert.pfx + state: exported + file_type: pkcs12 + password: AnotherStrongPass! + become: yes + become_method: runas + become_user: SYSTEM + +- name: Import certificate be used by IIS + win_certificate_store: + path: C:\Temp\cert.pfx + file_type: pkcs12 + password: StrongPassword! + store_location: LocalMachine + key_storage: machine + state: present +''' + +RETURN = r''' +thumbprints: + description: A list of certificate thumbprints that were touched by the + module. + returned: success + type: list + sample: ["BC05633694E675449136679A658281F17A191087"] +''' diff --git a/test/support/windows-integration/plugins/modules/win_command.ps1 b/test/support/windows-integration/plugins/modules/win_command.ps1 new file mode 100644 index 00000000..e2a30650 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_command.ps1 @@ -0,0 +1,78 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.CommandUtil +#Requires -Module Ansible.ModuleUtils.FileUtil + +# TODO: add check mode support + +Set-StrictMode -Version 2 +$ErrorActionPreference = 'Stop' + +$params = Parse-Args $args -supports_check_mode $false + +$raw_command_line = Get-AnsibleParam -obj $params -name "_raw_params" -type "str" -failifempty $true +$chdir = Get-AnsibleParam -obj $params -name "chdir" -type "path" +$creates = Get-AnsibleParam -obj $params -name "creates" -type "path" +$removes = Get-AnsibleParam -obj $params -name "removes" -type "path" +$stdin = Get-AnsibleParam -obj $params -name "stdin" -type "str" +$output_encoding_override = Get-AnsibleParam -obj $params -name "output_encoding_override" -type "str" + +$raw_command_line = $raw_command_line.Trim() + +$result = @{ + changed = $true + cmd = $raw_command_line +} + +if ($creates -and $(Test-AnsiblePath -Path $creates)) { + Exit-Json @{msg="skipped, since $creates exists";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +if ($removes -and -not $(Test-AnsiblePath -Path $removes)) { + Exit-Json @{msg="skipped, since $removes does not exist";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +$command_args = @{ + command = $raw_command_line +} +if ($chdir) { + $command_args['working_directory'] = $chdir +} +if ($stdin) { + $command_args['stdin'] = $stdin +} +if ($output_encoding_override) { + $command_args['output_encoding_override'] = $output_encoding_override +} + +$start_datetime = [DateTime]::UtcNow +try { + $command_result = Run-Command @command_args +} catch { + $result.changed = $false + try { + $result.rc = $_.Exception.NativeErrorCode + } catch { + $result.rc = 2 + } + Fail-Json -obj $result -message $_.Exception.Message +} + +$result.stdout = $command_result.stdout +$result.stderr = $command_result.stderr +$result.rc = $command_result.rc + +$end_datetime = [DateTime]::UtcNow +$result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff") + +If ($result.rc -ne 0) { + Fail-Json -obj $result -message "non-zero return code" +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_command.py b/test/support/windows-integration/plugins/modules/win_command.py new file mode 100644 index 00000000..508419b2 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_command.py @@ -0,0 +1,136 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Ansible, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_command +short_description: Executes a command on a remote Windows node +version_added: 2.2 +description: + - The C(win_command) module takes the command name followed by a list of space-delimited arguments. + - The given command will be executed on all selected nodes. It will not be + processed through the shell, so variables like C($env:HOME) and operations + like C("<"), C(">"), C("|"), and C(";") will not work (use the M(win_shell) + module if you need these features). + - For non-Windows targets, use the M(command) module instead. +options: + free_form: + description: + - The C(win_command) module takes a free form command to run. + - There is no parameter actually named 'free form'. See the examples! + type: str + required: yes + creates: + description: + - A path or path filter pattern; when the referenced path exists on the target host, the task will be skipped. + type: path + removes: + description: + - A path or path filter pattern; when the referenced path B(does not) exist on the target host, the task will be skipped. + type: path + chdir: + description: + - Set the specified path as the current working directory before executing a command. + type: path + stdin: + description: + - Set the stdin of the command directly to the specified value. + type: str + version_added: '2.5' + output_encoding_override: + description: + - This option overrides the encoding of stdout/stderr output. + - You can use this option when you need to run a command which ignore the console's codepage. + - You should only need to use this option in very rare circumstances. + - This value can be any valid encoding C(Name) based on the output of C([System.Text.Encoding]::GetEncodings()). + See U(https://docs.microsoft.com/dotnet/api/system.text.encoding.getencodings). + type: str + version_added: '2.10' +notes: + - If you want to run a command through a shell (say you are using C(<), + C(>), C(|), etc), you actually want the M(win_shell) module instead. The + C(win_command) module is much more secure as it's not affected by the user's + environment. + - C(creates), C(removes), and C(chdir) can be specified after the command. For instance, if you only want to run a command if a certain file does not + exist, use this. +seealso: +- module: command +- module: psexec +- module: raw +- module: win_psexec +- module: win_shell +author: + - Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +- name: Save the result of 'whoami' in 'whoami_out' + win_command: whoami + register: whoami_out + +- name: Run command that only runs if folder exists and runs from a specific folder + win_command: wbadmin -backupTarget:C:\backup\ + args: + chdir: C:\somedir\ + creates: C:\backup\ + +- name: Run an executable and send data to the stdin for the executable + win_command: powershell.exe - + args: + stdin: Write-Host test +''' + +RETURN = r''' +msg: + description: changed + returned: always + type: bool + sample: true +start: + description: The command execution start time + returned: always + type: str + sample: '2016-02-25 09:18:26.429568' +end: + description: The command execution end time + returned: always + type: str + sample: '2016-02-25 09:18:26.755339' +delta: + description: The command execution delta time + returned: always + type: str + sample: '0:00:00.325771' +stdout: + description: The command standard output + returned: always + type: str + sample: 'Clustering node rabbit@slave1 with rabbit@master ...' +stderr: + description: The command standard error + returned: always + type: str + sample: 'ls: cannot access foo: No such file or directory' +cmd: + description: The command executed by the task + returned: always + type: str + sample: 'rabbitmqctl join_cluster rabbit@master' +rc: + description: The command return code (0 means success) + returned: always + type: int + sample: 0 +stdout_lines: + description: The command standard output split in lines + returned: always + type: list + sample: [u'Clustering node rabbit@slave1 with rabbit@master ...'] +''' diff --git a/test/support/windows-integration/plugins/modules/win_copy.ps1 b/test/support/windows-integration/plugins/modules/win_copy.ps1 new file mode 100644 index 00000000..6a26ee72 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_copy.ps1 @@ -0,0 +1,403 @@ +#!powershell + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.Backup + +$ErrorActionPreference = 'Stop' + +$params = Parse-Args -arguments $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false + +# there are 4 modes to win_copy which are driven by the action plugins: +# explode: src is a zip file which needs to be extracted to dest, for use with multiple files +# query: win_copy action plugin wants to get the state of remote files to check whether it needs to send them +# remote: all copy action is happening remotely (remote_src=True) +# single: a single file has been copied, also used with template +$copy_mode = Get-AnsibleParam -obj $params -name "_copy_mode" -type "str" -default "single" -validateset "explode","query","remote","single" + +# used in explode, remote and single mode +$src = Get-AnsibleParam -obj $params -name "src" -type "path" -failifempty ($copy_mode -in @("explode","process","single")) +$dest = Get-AnsibleParam -obj $params -name "dest" -type "path" -failifempty $true +$backup = Get-AnsibleParam -obj $params -name "backup" -type "bool" -default $false + +# used in single mode +$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str" + +# used in query and remote mode +$force = Get-AnsibleParam -obj $params -name "force" -type "bool" -default $true + +# used in query mode, contains the local files/directories/symlinks that are to be copied +$files = Get-AnsibleParam -obj $params -name "files" -type "list" +$directories = Get-AnsibleParam -obj $params -name "directories" -type "list" + +$result = @{ + changed = $false +} + +if ($diff_mode) { + $result.diff = @{} +} + +Function Copy-File($source, $dest) { + $diff = "" + $copy_file = $false + $source_checksum = $null + if ($force) { + $source_checksum = Get-FileChecksum -path $source + } + + if (Test-Path -LiteralPath $dest -PathType Container) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': dest is already a folder" + } elseif (Test-Path -LiteralPath $dest -PathType Leaf) { + if ($force) { + $target_checksum = Get-FileChecksum -path $dest + if ($source_checksum -ne $target_checksum) { + $copy_file = $true + } + } + } else { + $copy_file = $true + } + + if ($copy_file) { + $file_dir = [System.IO.Path]::GetDirectoryName($dest) + # validate the parent dir is not a file and that it exists + if (Test-Path -LiteralPath $file_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder" + } elseif (-not (Test-Path -LiteralPath $file_dir)) { + # directory doesn't exist, need to create + New-Item -Path $file_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + $diff += "+$file_dir\`n" + } + + if ($backup) { + $result.backup_file = Backup-File -path $dest -WhatIf:$check_mode + } + + if (Test-Path -LiteralPath $dest -PathType Leaf) { + Remove-Item -LiteralPath $dest -Force -Recurse -WhatIf:$check_mode | Out-Null + $diff += "-$dest`n" + } + + if (-not $check_mode) { + # cannot run with -WhatIf:$check_mode as if the parent dir didn't + # exist and was created above would still not exist in check mode + Copy-Item -LiteralPath $source -Destination $dest -Force | Out-Null + } + $diff += "+$dest`n" + + $result.changed = $true + } + + # ugly but to save us from running the checksum twice, let's return it for + # the main code to add it to $result + return ,@{ diff = $diff; checksum = $source_checksum } +} + +Function Copy-Folder($source, $dest) { + $diff = "" + + if (-not (Test-Path -LiteralPath $dest -PathType Container)) { + $parent_dir = [System.IO.Path]::GetDirectoryName($dest) + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy file from '$source' to '$dest': object at dest parent dir is not a folder" + } + if (Test-Path -LiteralPath $dest -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder from '$source' to '$dest': dest is already a file" + } + + New-Item -Path $dest -ItemType Container -WhatIf:$check_mode | Out-Null + $diff += "+$dest\`n" + $result.changed = $true + } + + $child_items = Get-ChildItem -LiteralPath $source -Force + foreach ($child_item in $child_items) { + $dest_child_path = Join-Path -Path $dest -ChildPath $child_item.Name + if ($child_item.PSIsContainer) { + $diff += (Copy-Folder -source $child_item.Fullname -dest $dest_child_path) + } else { + $diff += (Copy-File -source $child_item.Fullname -dest $dest_child_path).diff + } + } + + return $diff +} + +Function Get-FileSize($path) { + $file = Get-Item -LiteralPath $path -Force + if ($file.PSIsContainer) { + $size = (Get-ChildItem -Literalpath $file.FullName -Recurse -Force | ` + Where-Object { $_.PSObject.Properties.Name -contains 'Length' } | ` + Measure-Object -Property Length -Sum).Sum + if ($null -eq $size) { + $size = 0 + } + } else { + $size = $file.Length + } + + $size +} + +Function Extract-Zip($src, $dest) { + $archive = [System.IO.Compression.ZipFile]::Open($src, [System.IO.Compression.ZipArchiveMode]::Read, [System.Text.Encoding]::UTF8) + foreach ($entry in $archive.Entries) { + $archive_name = $entry.FullName + + # FullName may be appended with / or \, determine if it is padded and remove it + $padding_length = $archive_name.Length % 4 + if ($padding_length -eq 0) { + $is_dir = $false + $base64_name = $archive_name + } elseif ($padding_length -eq 1) { + $is_dir = $true + if ($archive_name.EndsWith("/") -or $archive_name.EndsWith("`\")) { + $base64_name = $archive_name.Substring(0, $archive_name.Length - 1) + } else { + throw "invalid base64 archive name '$archive_name'" + } + } else { + throw "invalid base64 length '$archive_name'" + } + + # to handle unicode character, win_copy action plugin has encoded the filename + $decoded_archive_name = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($base64_name)) + # re-add the / to the entry full name if it was a directory + if ($is_dir) { + $decoded_archive_name = "$decoded_archive_name/" + } + $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_name) + $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path) + + if (-not (Test-Path -LiteralPath $entry_dir)) { + New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + } + + if ($is_dir -eq $false) { + if (-not $check_mode) { + [System.IO.Compression.ZipFileExtensions]::ExtractToFile($entry, $entry_target_path, $true) + } + } + } + $archive.Dispose() # release the handle of the zip file +} + +Function Extract-ZipLegacy($src, $dest) { + if (-not (Test-Path -LiteralPath $dest)) { + New-Item -Path $dest -ItemType Directory -WhatIf:$check_mode | Out-Null + } + $shell = New-Object -ComObject Shell.Application + $zip = $shell.NameSpace($src) + $dest_path = $shell.NameSpace($dest) + + foreach ($entry in $zip.Items()) { + $is_dir = $entry.IsFolder + $encoded_archive_entry = $entry.Name + # to handle unicode character, win_copy action plugin has encoded the filename + $decoded_archive_entry = [System.Text.Encoding]::UTF8.GetString([System.Convert]::FromBase64String($encoded_archive_entry)) + if ($is_dir) { + $decoded_archive_entry = "$decoded_archive_entry/" + } + + $entry_target_path = [System.IO.Path]::Combine($dest, $decoded_archive_entry) + $entry_dir = [System.IO.Path]::GetDirectoryName($entry_target_path) + + if (-not (Test-Path -LiteralPath $entry_dir)) { + New-Item -Path $entry_dir -ItemType Directory -WhatIf:$check_mode | Out-Null + } + + if ($is_dir -eq $false -and (-not $check_mode)) { + # https://msdn.microsoft.com/en-us/library/windows/desktop/bb787866.aspx + # From Folder.CopyHere documentation, 1044 means: + # - 1024: do not display a user interface if an error occurs + # - 16: respond with "yes to all" for any dialog box that is displayed + # - 4: do not display a progress dialog box + $dest_path.CopyHere($entry, 1044) + + # once file is extraced, we need to rename it with non base64 name + $combined_encoded_path = [System.IO.Path]::Combine($dest, $encoded_archive_entry) + Move-Item -LiteralPath $combined_encoded_path -Destination $entry_target_path -Force | Out-Null + } + } +} + +if ($copy_mode -eq "query") { + # we only return a list of files/directories that need to be copied over + # the source of the local file will be the key used + $changed_files = @() + $changed_directories = @() + $changed_symlinks = @() + + foreach ($file in $files) { + $filename = $file.dest + $local_checksum = $file.checksum + + $filepath = Join-Path -Path $dest -ChildPath $filename + if (Test-Path -LiteralPath $filepath -PathType Leaf) { + if ($force) { + $checksum = Get-FileChecksum -path $filepath + if ($checksum -ne $local_checksum) { + $changed_files += $file + } + } + } elseif (Test-Path -LiteralPath $filepath -PathType Container) { + Fail-Json -obj $result -message "cannot copy file to dest '$filepath': object at path is already a directory" + } else { + $changed_files += $file + } + } + + foreach ($directory in $directories) { + $dirname = $directory.dest + + $dirpath = Join-Path -Path $dest -ChildPath $dirname + $parent_dir = [System.IO.Path]::GetDirectoryName($dirpath) + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at parent directory path is already a file" + } + if (Test-Path -LiteralPath $dirpath -PathType Leaf) { + Fail-Json -obj $result -message "cannot copy folder to dest '$dirpath': object at path is already a file" + } elseif (-not (Test-Path -LiteralPath $dirpath -PathType Container)) { + $changed_directories += $directory + } + } + + # TODO: Handle symlinks + + $result.files = $changed_files + $result.directories = $changed_directories + $result.symlinks = $changed_symlinks +} elseif ($copy_mode -eq "explode") { + # a single zip file containing the files and directories needs to be + # expanded this will always result in a change as the calculation is done + # on the win_copy action plugin and is only run if a change needs to occur + if (-not (Test-Path -LiteralPath $src -PathType Leaf)) { + Fail-Json -obj $result -message "Cannot expand src zip file: '$src' as it does not exist" + } + + # Detect if the PS zip assemblies are available or whether to use Shell + $use_legacy = $false + try { + Add-Type -AssemblyName System.IO.Compression.FileSystem | Out-Null + Add-Type -AssemblyName System.IO.Compression | Out-Null + } catch { + $use_legacy = $true + } + if ($use_legacy) { + Extract-ZipLegacy -src $src -dest $dest + } else { + Extract-Zip -src $src -dest $dest + } + + $result.changed = $true +} elseif ($copy_mode -eq "remote") { + # all copy actions are happening on the remote side (windows host), need + # too copy source and dest using PS code + $result.src = $src + $result.dest = $dest + + if (-not (Test-Path -LiteralPath $src)) { + Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist" + } + + if (Test-Path -LiteralPath $src -PathType Container) { + # we are copying a directory or the contents of a directory + $result.operation = 'folder_copy' + if ($src.EndsWith("/") -or $src.EndsWith("`\")) { + # copying the folder's contents to dest + $diff = "" + $child_files = Get-ChildItem -LiteralPath $src -Force + foreach ($child_file in $child_files) { + $dest_child_path = Join-Path -Path $dest -ChildPath $child_file.Name + if ($child_file.PSIsContainer) { + $diff += Copy-Folder -source $child_file.FullName -dest $dest_child_path + } else { + $diff += (Copy-File -source $child_file.FullName -dest $dest_child_path).diff + } + } + } else { + # copying the folder and it's contents to dest + $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name + $result.dest = $dest + $diff = Copy-Folder -source $src -dest $dest + } + } else { + # we are just copying a single file to dest + $result.operation = 'file_copy' + + $source_basename = (Get-Item -LiteralPath $src -Force).Name + $result.original_basename = $source_basename + + if ($dest.EndsWith("/") -or $dest.EndsWith("`\")) { + $dest = Join-Path -Path $dest -ChildPath (Get-Item -LiteralPath $src -Force).Name + $result.dest = $dest + } else { + # check if the parent dir exists, this is only done if src is a + # file and dest if the path to a file (doesn't end with \ or /) + $parent_dir = Split-Path -LiteralPath $dest + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist" + } + } + $copy_result = Copy-File -source $src -dest $dest + $diff = $copy_result.diff + $result.checksum = $copy_result.checksum + } + + # the file might not exist if running in check mode + if (-not $check_mode -or (Test-Path -LiteralPath $dest -PathType Leaf)) { + $result.size = Get-FileSize -path $dest + } else { + $result.size = $null + } + if ($diff_mode) { + $result.diff.prepared = $diff + } +} elseif ($copy_mode -eq "single") { + # a single file is located in src and we need to copy to dest, this will + # always result in a change as the calculation is done on the Ansible side + # before this is run. This should also never run in check mode + if (-not (Test-Path -LiteralPath $src -PathType Leaf)) { + Fail-Json -obj $result -message "Cannot copy src file: '$src' as it does not exist" + } + + # the dest parameter is a directory, we need to append original_basename + if ($dest.EndsWith("/") -or $dest.EndsWith("`\") -or (Test-Path -LiteralPath $dest -PathType Container)) { + $remote_dest = Join-Path -Path $dest -ChildPath $original_basename + $parent_dir = Split-Path -LiteralPath $remote_dest + + # when dest ends with /, we need to create the destination directories + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + New-Item -Path $parent_dir -ItemType Directory | Out-Null + } + } else { + $remote_dest = $dest + $parent_dir = Split-Path -LiteralPath $remote_dest + + # check if the dest parent dirs exist, need to fail if they don't + if (Test-Path -LiteralPath $parent_dir -PathType Leaf) { + Fail-Json -obj $result -message "object at destination parent dir '$parent_dir' is currently a file" + } elseif (-not (Test-Path -LiteralPath $parent_dir -PathType Container)) { + Fail-Json -obj $result -message "Destination directory '$parent_dir' does not exist" + } + } + + if ($backup) { + $result.backup_file = Backup-File -path $remote_dest -WhatIf:$check_mode + } + + Copy-Item -LiteralPath $src -Destination $remote_dest -Force | Out-Null + $result.changed = $true +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/plugins/modules/win_copy.py b/test/support/windows-integration/plugins/modules/win_copy.py new file mode 100644 index 00000000..a55f4c65 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_copy.py @@ -0,0 +1,207 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_copy +version_added: '1.9.2' +short_description: Copies files to remote locations on windows hosts +description: +- The C(win_copy) module copies a file on the local box to remote windows locations. +- For non-Windows targets, use the M(copy) module instead. +options: + content: + description: + - When used instead of C(src), sets the contents of a file directly to the + specified value. + - This is for simple values, for anything complex or with formatting please + switch to the M(template) module. + type: str + version_added: '2.3' + decrypt: + description: + - This option controls the autodecryption of source files using vault. + type: bool + default: yes + version_added: '2.5' + dest: + description: + - Remote absolute path where the file should be copied to. + - If C(src) is a directory, this must be a directory too. + - Use \ for path separators or \\ when in "double quotes". + - If C(dest) ends with \ then source or the contents of source will be + copied to the directory without renaming. + - If C(dest) is a nonexistent path, it will only be created if C(dest) ends + with "/" or "\", or C(src) is a directory. + - If C(src) and C(dest) are files and if the parent directory of C(dest) + doesn't exist, then the task will fail. + type: path + required: yes + backup: + description: + - Determine whether a backup should be created. + - When set to C(yes), create a backup file including the timestamp information + so you can get the original file back if you somehow clobbered it incorrectly. + - No backup is taken when C(remote_src=False) and multiple files are being + copied. + type: bool + default: no + version_added: '2.8' + force: + description: + - If set to C(yes), the file will only be transferred if the content + is different than destination. + - If set to C(no), the file will only be transferred if the + destination does not exist. + - If set to C(no), no checksuming of the content is performed which can + help improve performance on larger files. + type: bool + default: yes + version_added: '2.3' + local_follow: + description: + - This flag indicates that filesystem links in the source tree, if they + exist, should be followed. + type: bool + default: yes + version_added: '2.4' + remote_src: + description: + - If C(no), it will search for src at originating/master machine. + - If C(yes), it will go to the remote/target machine for the src. + type: bool + default: no + version_added: '2.3' + src: + description: + - Local path to a file to copy to the remote server; can be absolute or + relative. + - If path is a directory, it is copied (including the source folder name) + recursively to C(dest). + - If path is a directory and ends with "/", only the inside contents of + that directory are copied to the destination. Otherwise, if it does not + end with "/", the directory itself with all contents is copied. + - If path is a file and dest ends with "\", the file is copied to the + folder with the same filename. + - Required unless using C(content). + type: path +notes: +- Currently win_copy does not support copying symbolic links from both local to + remote and remote to remote. +- It is recommended that backslashes C(\) are used instead of C(/) when dealing + with remote paths. +- Because win_copy runs over WinRM, it is not a very efficient transfer + mechanism. If sending large files consider hosting them on a web service and + using M(win_get_url) instead. +seealso: +- module: assemble +- module: copy +- module: win_get_url +- module: win_robocopy +author: +- Jon Hawkesworth (@jhawkesworth) +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Copy a single file + win_copy: + src: /srv/myfiles/foo.conf + dest: C:\Temp\renamed-foo.conf + +- name: Copy a single file, but keep a backup + win_copy: + src: /srv/myfiles/foo.conf + dest: C:\Temp\renamed-foo.conf + backup: yes + +- name: Copy a single file keeping the filename + win_copy: + src: /src/myfiles/foo.conf + dest: C:\Temp\ + +- name: Copy folder to C:\Temp (results in C:\Temp\temp_files) + win_copy: + src: files/temp_files + dest: C:\Temp + +- name: Copy folder contents recursively + win_copy: + src: files/temp_files/ + dest: C:\Temp + +- name: Copy a single file where the source is on the remote host + win_copy: + src: C:\Temp\foo.txt + dest: C:\ansible\foo.txt + remote_src: yes + +- name: Copy a folder recursively where the source is on the remote host + win_copy: + src: C:\Temp + dest: C:\ansible + remote_src: yes + +- name: Set the contents of a file + win_copy: + content: abc123 + dest: C:\Temp\foo.txt + +- name: Copy a single file as another user + win_copy: + src: NuGet.config + dest: '%AppData%\NuGet\NuGet.config' + vars: + ansible_become_user: user + ansible_become_password: pass + # The tmp dir must be set when using win_copy as another user + # This ensures the become user will have permissions for the operation + # Make sure to specify a folder both the ansible_user and the become_user have access to (i.e not %TEMP% which is user specific and requires Admin) + ansible_remote_tmp: 'c:\tmp' +''' + +RETURN = r''' +backup_file: + description: Name of the backup file that was created. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +dest: + description: Destination file/path. + returned: changed + type: str + sample: C:\Temp\ +src: + description: Source file used for the copy on the target machine. + returned: changed + type: str + sample: /home/httpd/.ansible/tmp/ansible-tmp-1423796390.97-147729857856000/source +checksum: + description: SHA1 checksum of the file after running copy. + returned: success, src is a file + type: str + sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827 +size: + description: Size of the target, after execution. + returned: changed, src is a file + type: int + sample: 1220 +operation: + description: Whether a single file copy took place or a folder copy. + returned: success + type: str + sample: file_copy +original_basename: + description: Basename of the copied file. + returned: changed, src is a file + type: str + sample: foo.txt +''' diff --git a/test/support/windows-integration/plugins/modules/win_data_deduplication.ps1 b/test/support/windows-integration/plugins/modules/win_data_deduplication.ps1 new file mode 100644 index 00000000..593ee763 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_data_deduplication.ps1 @@ -0,0 +1,129 @@ +#!powershell + +# Copyright: 2019, rnsc(@rnsc) <github@rnsc.be> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt + +#AnsibleRequires -CSharpUtil Ansible.Basic +#AnsibleRequires -OSVersion 6.3 + +$spec = @{ + options = @{ + drive_letter = @{ type = "str"; required = $true } + state = @{ type = "str"; choices = "absent", "present"; default = "present"; } + settings = @{ + type = "dict" + required = $false + options = @{ + minimum_file_size = @{ type = "int"; default = 32768 } + minimum_file_age_days = @{ type = "int"; default = 2 } + no_compress = @{ type = "bool"; required = $false; default = $false } + optimize_in_use_files = @{ type = "bool"; required = $false; default = $false } + verify = @{ type = "bool"; required = $false; default = $false } + } + } + } + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$drive_letter = $module.Params.drive_letter +$state = $module.Params.state +$settings = $module.Params.settings + +$module.Result.changed = $false +$module.Result.reboot_required = $false +$module.Result.msg = "" + +function Set-DataDeduplication($volume, $state, $settings, $dedup_job) { + + $current_state = 'absent' + + try { + $dedup_info = Get-DedupVolume -Volume "$($volume.DriveLetter):" + } catch { + $dedup_info = $null + } + + if ($dedup_info.Enabled) { + $current_state = 'present' + } + + if ( $state -ne $current_state ) { + if( -not $module.CheckMode) { + if($state -eq 'present') { + # Enable-DedupVolume -Volume <String> + Enable-DedupVolume -Volume "$($volume.DriveLetter):" + } elseif ($state -eq 'absent') { + Disable-DedupVolume -Volume "$($volume.DriveLetter):" + } + } + $module.Result.changed = $true + } + + if ($state -eq 'present') { + if ($null -ne $settings) { + Set-DataDedupJobSettings -volume $volume -settings $settings + } + } +} + +function Set-DataDedupJobSettings ($volume, $settings) { + + try { + $dedup_info = Get-DedupVolume -Volume "$($volume.DriveLetter):" + } catch { + $dedup_info = $null + } + + ForEach ($key in $settings.keys) { + + # See Microsoft documentation: + # https://docs.microsoft.com/en-us/powershell/module/deduplication/set-dedupvolume?view=win10-ps + + $update_key = $key + $update_value = $settings.$($key) + # Transform Ansible style options to Powershell params + $update_key = $update_key -replace('_', '') + + if ($update_key -eq "MinimumFileSize" -and $update_value -lt 32768) { + $update_value = 32768 + } + + $current_value = ($dedup_info | Select-Object -ExpandProperty $update_key) + + if ($update_value -ne $current_value) { + $command_param = @{ + $($update_key) = $update_value + } + + # Set-DedupVolume -Volume <String>` + # -NoCompress <bool> ` + # -MinimumFileAgeDays <UInt32> ` + # -MinimumFileSize <UInt32> (minimum 32768) + if( -not $module.CheckMode ) { + Set-DedupVolume -Volume "$($volume.DriveLetter):" @command_param + } + + $module.Result.changed = $true + } + } + +} + +# Install required feature +$feature_name = "FS-Data-Deduplication" +if( -not $module.CheckMode) { + $feature = Install-WindowsFeature -Name $feature_name + + if ($feature.RestartNeeded -eq 'Yes') { + $module.Result.reboot_required = $true + $module.FailJson("$feature_name was installed but requires Windows to be rebooted to work.") + } +} + +$volume = Get-Volume -DriveLetter $drive_letter + +Set-DataDeduplication -volume $volume -state $state -settings $settings -dedup_job $dedup_job + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_data_deduplication.py b/test/support/windows-integration/plugins/modules/win_data_deduplication.py new file mode 100644 index 00000000..d320b9f7 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_data_deduplication.py @@ -0,0 +1,87 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: 2019, rnsc(@rnsc) <github@rnsc.be> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_data_deduplication +version_added: "2.10" +short_description: Module to enable Data Deduplication on a volume. +description: +- This module can be used to enable Data Deduplication on a Windows volume. +- The module will install the FS-Data-Deduplication feature (a reboot will be necessary). +options: + drive_letter: + description: + - Windows drive letter on which to enable data deduplication. + required: yes + type: str + state: + description: + - Wether to enable or disable data deduplication on the selected volume. + default: present + type: str + choices: [ present, absent ] + settings: + description: + - Dictionary of settings to pass to the Set-DedupVolume powershell command. + type: dict + suboptions: + minimum_file_size: + description: + - Minimum file size you want to target for deduplication. + - It will default to 32768 if not defined or if the value is less than 32768. + type: int + default: 32768 + minimum_file_age_days: + description: + - Minimum file age you want to target for deduplication. + type: int + default: 2 + no_compress: + description: + - Wether you want to enabled filesystem compression or not. + type: bool + default: no + optimize_in_use_files: + description: + - Indicates that the server attempts to optimize currently open files. + type: bool + default: no + verify: + description: + - Indicates whether the deduplication engine performs a byte-for-byte verification for each duplicate chunk + that optimization creates, rather than relying on a cryptographically strong hash. + - This option is not recommend. + - Setting this parameter to True can degrade optimization performance. + type: bool + default: no +author: +- rnsc (@rnsc) +''' + +EXAMPLES = r''' +- name: Enable Data Deduplication on D + win_data_deduplication: + drive_letter: 'D' + state: present + +- name: Enable Data Deduplication on D + win_data_deduplication: + drive_letter: 'D' + state: present + settings: + no_compress: true + minimum_file_age_days: 1 + minimum_file_size: 0 +''' + +RETURN = r''' +# +''' diff --git a/test/support/windows-integration/plugins/modules/win_dsc.ps1 b/test/support/windows-integration/plugins/modules/win_dsc.ps1 new file mode 100644 index 00000000..690f391a --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_dsc.ps1 @@ -0,0 +1,398 @@ +#!powershell + +# Copyright: (c) 2015, Trond Hindenes <trond@hindenes.com>, and others +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Version 5 + +Function ConvertTo-ArgSpecType { + <# + .SYNOPSIS + Converts the DSC parameter type to the arg spec type required for Ansible. + #> + param( + [Parameter(Mandatory=$true)][String]$CimType + ) + + $arg_type = switch($CimType) { + Boolean { "bool" } + Char16 { [Func[[Object], [Char]]]{ [System.Char]::Parse($args[0].ToString()) } } + DateTime { [Func[[Object], [DateTime]]]{ [System.DateTime]($args[0].ToString()) } } + Instance { "dict" } + Real32 { "float" } + Real64 { [Func[[Object], [Double]]]{ [System.Double]::Parse($args[0].ToString()) } } + Reference { "dict" } + SInt16 { [Func[[Object], [Int16]]]{ [System.Int16]::Parse($args[0].ToString()) } } + SInt32 { "int" } + SInt64 { [Func[[Object], [Int64]]]{ [System.Int64]::Parse($args[0].ToString()) } } + SInt8 { [Func[[Object], [SByte]]]{ [System.SByte]::Parse($args[0].ToString()) } } + String { "str" } + UInt16 { [Func[[Object], [UInt16]]]{ [System.UInt16]::Parse($args[0].ToString()) } } + UInt32 { [Func[[Object], [UInt32]]]{ [System.UInt32]::Parse($args[0].ToString()) } } + UInt64 { [Func[[Object], [UInt64]]]{ [System.UInt64]::Parse($args[0].ToString()) } } + UInt8 { [Func[[Object], [Byte]]]{ [System.Byte]::Parse($args[0].ToString()) } } + Unknown { "raw" } + default { "raw" } + } + return $arg_type +} + +Function Get-DscCimClassProperties { + <# + .SYNOPSIS + Get's a list of CimProperties of a CIM Class. It filters out any magic or + read only properties that we don't need to know about. + #> + param([Parameter(Mandatory=$true)][String]$ClassName) + + $resource = Get-CimClass -ClassName $ClassName -Namespace root\Microsoft\Windows\DesiredStateConfiguration + + # Filter out any magic properties that are used internally on an OMI_BaseResource + # https://github.com/PowerShell/PowerShell/blob/master/src/System.Management.Automation/DscSupport/CimDSCParser.cs#L1203 + $magic_properties = @("ResourceId", "SourceInfo", "ModuleName", "ModuleVersion", "ConfigurationName") + $properties = $resource.CimClassProperties | Where-Object { + + ($resource.CimSuperClassName -ne "OMI_BaseResource" -or $_.Name -notin $magic_properties) -and + -not $_.Flags.HasFlag([Microsoft.Management.Infrastructure.CimFlags]::ReadOnly) + } + + return ,$properties +} + +Function Add-PropertyOption { + <# + .SYNOPSIS + Adds the spec for the property type to the existing module specification. + #> + param( + [Parameter(Mandatory=$true)][Hashtable]$Spec, + [Parameter(Mandatory=$true)] + [Microsoft.Management.Infrastructure.CimPropertyDeclaration]$Property + ) + + $option = @{ + required = $false + } + $property_name = $Property.Name + $property_type = $Property.CimType.ToString() + + if ($Property.Flags.HasFlag([Microsoft.Management.Infrastructure.CimFlags]::Key) -or + $Property.Flags.HasFlag([Microsoft.Management.Infrastructure.CimFlags]::Required)) { + $option.required = $true + } + + if ($null -ne $Property.Qualifiers['Values']) { + $option.choices = [System.Collections.Generic.List`1[Object]]$Property.Qualifiers['Values'].Value + } + + if ($property_name -eq "Name") { + # For backwards compatibility we support specifying the Name DSC property as item_name + $option.aliases = @("item_name") + } elseif ($property_name -ceq "key") { + # There seems to be a bug in the CIM property parsing when the property name is 'Key'. The CIM instance will + # think the name is 'key' when the MOF actually defines it as 'Key'. We set the proper casing so the module arg + # validator won't fire a case sensitive warning + $property_name = "Key" + } + + if ($Property.ReferenceClassName -eq "MSFT_Credential") { + # Special handling for the MSFT_Credential type (PSCredential), we handle this with having 2 options that + # have the suffix _username and _password. + $option_spec_pass = @{ + type = "str" + required = $option.required + no_log = $true + } + $Spec.options."$($property_name)_password" = $option_spec_pass + $Spec.required_together.Add(@("$($property_name)_username", "$($property_name)_password")) > $null + + $property_name = "$($property_name)_username" + $option.type = "str" + } elseif ($Property.ReferenceClassName -eq "MSFT_KeyValuePair") { + $option.type = "dict" + } elseif ($property_type.EndsWith("Array")) { + $option.type = "list" + $option.elements = ConvertTo-ArgSpecType -CimType $property_type.Substring(0, $property_type.Length - 5) + } else { + $option.type = ConvertTo-ArgSpecType -CimType $property_type + } + + if (($option.type -eq "dict" -or ($option.type -eq "list" -and $option.elements -eq "dict")) -and + $Property.ReferenceClassName -ne "MSFT_KeyValuePair") { + # Get the sub spec if the type is a Instance (CimInstance/dict) + $sub_option_spec = Get-OptionSpec -ClassName $Property.ReferenceClassName + $option += $sub_option_spec + } + + $Spec.options.$property_name = $option +} + +Function Get-OptionSpec { + <# + .SYNOPSIS + Generates the specifiec used in AnsibleModule for a CIM MOF resource name. + + .NOTES + This won't be able to retrieve the default values for an option as that is not defined in the MOF for a resource. + Default values are still preserved in the DSC engine if we don't pass in the property at all, we just can't report + on what they are automatically. + #> + param( + [Parameter(Mandatory=$true)][String]$ClassName + ) + + $spec = @{ + options = @{} + required_together = [System.Collections.ArrayList]@() + } + $properties = Get-DscCimClassProperties -ClassName $ClassName + foreach ($property in $properties) { + Add-PropertyOption -Spec $spec -Property $property + } + + return $spec +} + +Function ConvertTo-CimInstance { + <# + .SYNOPSIS + Converts a dict to a CimInstance of the specified Class. Also provides a + better error message if this fails that contains the option name that failed. + #> + param( + [Parameter(Mandatory=$true)][String]$Name, + [Parameter(Mandatory=$true)][String]$ClassName, + [Parameter(Mandatory=$true)][System.Collections.IDictionary]$Value, + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Switch]$Recurse + ) + + $properties = @{} + foreach ($value_info in $Value.GetEnumerator()) { + # Need to remove all null values from existing dict so the conversion works + if ($null -eq $value_info.Value) { + continue + } + $properties.($value_info.Key) = $value_info.Value + } + + if ($Recurse) { + # We want to validate and convert and values to what's required by DSC + $properties = ConvertTo-DscProperty -ClassName $ClassName -Params $properties -Module $Module + } + + try { + return (New-CimInstance -ClassName $ClassName -Property $properties -ClientOnly) + } catch { + # New-CimInstance raises a poor error message, make sure we mention what option it is for + $Module.FailJson("Failed to cast dict value for option '$Name' to a CimInstance: $($_.Exception.Message)", $_) + } +} + +Function ConvertTo-DscProperty { + <# + .SYNOPSIS + Converts the input module parameters that have been validated and casted + into the types expected by the DSC engine. This is mostly done to deal with + types like PSCredential and Dictionaries. + #> + param( + [Parameter(Mandatory=$true)][String]$ClassName, + [Parameter(Mandatory=$true)][System.Collections.IDictionary]$Params, + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module + ) + $properties = Get-DscCimClassProperties -ClassName $ClassName + + $dsc_properties = @{} + foreach ($property in $properties) { + $property_name = $property.Name + $property_type = $property.CimType.ToString() + + if ($property.ReferenceClassName -eq "MSFT_Credential") { + $username = $Params."$($property_name)_username" + $password = $Params."$($property_name)_password" + + # No user set == No option set in playbook, skip this property + if ($null -eq $username) { + continue + } + $sec_password = ConvertTo-SecureString -String $password -AsPlainText -Force + $value = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList $username, $sec_password + } else { + $value = $Params.$property_name + + # The actual value wasn't set, skip adding this property + if ($null -eq $value) { + continue + } + + if ($property.ReferenceClassName -eq "MSFT_KeyValuePair") { + $key_value_pairs = [System.Collections.Generic.List`1[CimInstance]]@() + foreach ($value_info in $value.GetEnumerator()) { + $kvp = @{Key = $value_info.Key; Value = $value_info.Value.ToString()} + $cim_instance = ConvertTo-CimInstance -Name $property_name -ClassName MSFT_KeyValuePair ` + -Value $kvp -Module $Module + $key_value_pairs.Add($cim_instance) > $null + } + $value = $key_value_pairs.ToArray() + } elseif ($null -ne $property.ReferenceClassName) { + # Convert the dict to a CimInstance (or list of CimInstances) + $convert_args = @{ + ClassName = $property.ReferenceClassName + Module = $Module + Name = $property_name + Recurse = $true + } + if ($property_type.EndsWith("Array")) { + $value = [System.Collections.Generic.List`1[CimInstance]]@() + foreach ($raw in $Params.$property_name.GetEnumerator()) { + $cim_instance = ConvertTo-CimInstance -Value $raw @convert_args + $value.Add($cim_instance) > $null + } + $value = $value.ToArray() # Need to make sure we are dealing with an Array not a List + } else { + $value = ConvertTo-CimInstance -Value $value @convert_args + } + } + } + $dsc_properties.$property_name = $value + } + + return $dsc_properties +} + +Function Invoke-DscMethod { + <# + .SYNOPSIS + Invokes the DSC Resource Method specified in another PS pipeline. This is + done so we can retrieve the Verbose stream and return it back to the user + for futher debugging. + #> + param( + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Parameter(Mandatory=$true)][String]$Method, + [Parameter(Mandatory=$true)][Hashtable]$Arguments + ) + + # Invoke the DSC resource in a separate runspace so we can capture the Verbose output + $ps = [PowerShell]::Create() + $ps.AddCommand("Invoke-DscResource").AddParameter("Method", $Method) > $null + $ps.AddParameters($Arguments) > $null + + $result = $ps.Invoke() + + # Pass the warnings through to the AnsibleModule return result + foreach ($warning in $ps.Streams.Warning) { + $Module.Warn($warning.Message) + } + + # If running at a high enough verbosity, add the verbose output to the AnsibleModule return result + if ($Module.Verbosity -ge 3) { + $verbose_logs = [System.Collections.Generic.List`1[String]]@() + foreach ($verbosity in $ps.Streams.Verbose) { + $verbose_logs.Add($verbosity.Message) > $null + } + $Module.Result."verbose_$($Method.ToLower())" = $verbose_logs + } + + if ($ps.HadErrors) { + # Cannot pass in the ErrorRecord as it's a RemotingErrorRecord and doesn't contain the ScriptStackTrace + # or other info that would be useful + $Module.FailJson("Failed to invoke DSC $Method method: $($ps.Streams.Error[0].Exception.Message)") + } + + return $result +} + +# win_dsc is unique in that is builds the arg spec based on DSC Resource input. To get this info +# we need to read the resource_name and module_version value which is done outside of Ansible.Basic +if ($args.Length -gt 0) { + $params = Get-Content -Path $args[0] | ConvertFrom-Json +} else { + $params = $complex_args +} +if (-not $params.ContainsKey("resource_name")) { + $res = @{ + msg = "missing required argument: resource_name" + failed = $true + } + Write-Output -InputObject (ConvertTo-Json -Compress -InputObject $res) + exit 1 +} +$resource_name = $params.resource_name + +if ($params.ContainsKey("module_version")) { + $module_version = $params.module_version +} else { + $module_version = "latest" +} + +$module_versions = (Get-DscResource -Name $resource_name -ErrorAction SilentlyContinue | Sort-Object -Property Version) +$resource = $null +if ($module_version -eq "latest" -and $null -ne $module_versions) { + $resource = $module_versions[-1] +} elseif ($module_version -ne "latest") { + $resource = $module_versions | Where-Object { $_.Version -eq $module_version } +} + +if (-not $resource) { + if ($module_version -eq "latest") { + $msg = "Resource '$resource_name' not found." + } else { + $msg = "Resource '$resource_name' with version '$module_version' not found." + $msg += " Versions installed: '$($module_versions.Version -join "', '")'." + } + + Write-Output -InputObject (ConvertTo-Json -Compress -InputObject @{ failed = $true; msg = $msg }) + exit 1 +} + +# Build the base args for the DSC Invocation based on the resource selected +$dsc_args = @{ + Name = $resource.Name +} + +# Binary resources are not working very well with that approach - need to guesstimate module name/version +$module_version = $null +if ($resource.Module) { + $dsc_args.ModuleName = @{ + ModuleName = $resource.Module.Name + ModuleVersion = $resource.Module.Version + } + $module_version = $resource.Module.Version.ToString() +} else { + $dsc_args.ModuleName = "PSDesiredStateConfiguration" +} + +# To ensure the class registered with CIM is the one based on our version, we want to run the Get method so the DSC +# engine updates the metadata propery. We don't care about any errors here +try { + Invoke-DscResource -Method Get -Property @{Fake="Fake"} @dsc_args > $null +} catch {} + +# Dynamically build the option spec based on the resource_name specified and create the module object +$spec = Get-OptionSpec -ClassName $resource.ResourceType +$spec.supports_check_mode = $true +$spec.options.module_version = @{ type = "str"; default = "latest" } +$spec.options.resource_name = @{ type = "str"; required = $true } + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) +$module.Result.reboot_required = $false +$module.Result.module_version = $module_version + +# Build the DSC invocation arguments and invoke the resource +$dsc_args.Property = ConvertTo-DscProperty -ClassName $resource.ResourceType -Module $module -Params $Module.Params +$dsc_args.Verbose = $true + +$test_result = Invoke-DscMethod -Module $module -Method Test -Arguments $dsc_args +if ($test_result.InDesiredState -ne $true) { + if (-not $module.CheckMode) { + $result = Invoke-DscMethod -Module $module -Method Set -Arguments $dsc_args + $module.Result.reboot_required = $result.RebootRequired + } + $module.Result.changed = $true +} + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_dsc.py b/test/support/windows-integration/plugins/modules/win_dsc.py new file mode 100644 index 00000000..200d025e --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_dsc.py @@ -0,0 +1,183 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Trond Hindenes <trond@hindenes.com>, and others +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_dsc +version_added: "2.4" +short_description: Invokes a PowerShell DSC configuration +description: +- Configures a resource using PowerShell DSC. +- Requires PowerShell version 5.0 or newer. +- Most of the options for this module are dynamic and will vary depending on + the DSC Resource specified in I(resource_name). +- See :doc:`/user_guide/windows_dsc` for more information on how to use this module. +options: + resource_name: + description: + - The name of the DSC Resource to use. + - Must be accessible to PowerShell using any of the default paths. + type: str + required: yes + module_version: + description: + - Can be used to configure the exact version of the DSC resource to be + invoked. + - Useful if the target node has multiple versions installed of the module + containing the DSC resource. + - If not specified, the module will follow standard PowerShell convention + and use the highest version available. + type: str + default: latest + free_form: + description: + - The M(win_dsc) module takes in multiple free form options based on the + DSC resource being invoked by I(resource_name). + - There is no option actually named C(free_form) so see the examples. + - This module will try and convert the option to the correct type required + by the DSC resource and throw a warning if it fails. + - If the type of the DSC resource option is a C(CimInstance) or + C(CimInstance[]), this means the value should be a dictionary or list + of dictionaries based on the values required by that option. + - If the type of the DSC resource option is a C(PSCredential) then there + needs to be 2 options set in the Ansible task definition suffixed with + C(_username) and C(_password). + - If the type of the DSC resource option is an array, then a list should be + provided but a comma separated string also work. Use a list where + possible as no escaping is required and it works with more complex types + list C(CimInstance[]). + - If the type of the DSC resource option is a C(DateTime), you should use + a string in the form of an ISO 8901 string to ensure the exact date is + used. + - Since Ansible 2.8, Ansible will now validate the input fields against the + DSC resource definition automatically. Older versions will silently + ignore invalid fields. + type: str + required: true +notes: +- By default there are a few builtin resources that come with PowerShell 5.0, + see U(https://docs.microsoft.com/en-us/powershell/scripting/dsc/resources/resources) for + more information on these resources. +- Custom DSC resources can be installed with M(win_psmodule) using the I(name) + option. +- The DSC engine run's each task as the SYSTEM account, any resources that need + to be accessed with a different account need to have C(PsDscRunAsCredential) + set. +- To see the valid options for a DSC resource, run the module with C(-vvv) to + show the possible module invocation. Default values are not shown in this + output but are applied within the DSC engine. +author: +- Trond Hindenes (@trondhindenes) +''' + +EXAMPLES = r''' +- name: Extract zip file + win_dsc: + resource_name: Archive + Ensure: Present + Path: C:\Temp\zipfile.zip + Destination: C:\Temp\Temp2 + +- name: Install a Windows feature with the WindowsFeature resource + win_dsc: + resource_name: WindowsFeature + Name: telnet-client + +- name: Edit HKCU reg key under specific user + win_dsc: + resource_name: Registry + Ensure: Present + Key: HKEY_CURRENT_USER\ExampleKey + ValueName: TestValue + ValueData: TestData + PsDscRunAsCredential_username: '{{ansible_user}}' + PsDscRunAsCredential_password: '{{ansible_password}}' + no_log: true + +- name: Create file with multiple attributes + win_dsc: + resource_name: File + DestinationPath: C:\ansible\dsc + Attributes: # can also be a comma separated string, e.g. 'Hidden, System' + - Hidden + - System + Ensure: Present + Type: Directory + +- name: Call DSC resource with DateTime option + win_dsc: + resource_name: DateTimeResource + DateTimeOption: '2019-02-22T13:57:31.2311892+00:00' + +# more complex example using custom DSC resource and dict values +- name: Setup the xWebAdministration module + win_psmodule: + name: xWebAdministration + state: present + +- name: Create IIS Website with Binding and Authentication options + win_dsc: + resource_name: xWebsite + Ensure: Present + Name: DSC Website + State: Started + PhysicalPath: C:\inetpub\wwwroot + BindingInfo: # Example of a CimInstance[] DSC parameter (list of dicts) + - Protocol: https + Port: 1234 + CertificateStoreName: MY + CertificateThumbprint: C676A89018C4D5902353545343634F35E6B3A659 + HostName: DSCTest + IPAddress: '*' + SSLFlags: '1' + - Protocol: http + Port: 4321 + IPAddress: '*' + AuthenticationInfo: # Example of a CimInstance DSC parameter (dict) + Anonymous: no + Basic: true + Digest: false + Windows: yes +''' + +RETURN = r''' +module_version: + description: The version of the dsc resource/module used. + returned: always + type: str + sample: "1.0.1" +reboot_required: + description: Flag returned from the DSC engine indicating whether or not + the machine requires a reboot for the invoked changes to take effect. + returned: always + type: bool + sample: true +verbose_test: + description: The verbose output as a list from executing the DSC test + method. + returned: Ansible verbosity is -vvv or greater + type: list + sample: [ + "Perform operation 'Invoke CimMethod' with the following parameters, ", + "[SERVER]: LCM: [Start Test ] [[File]DirectResourceAccess]", + "Operation 'Invoke CimMethod' complete." + ] +verbose_set: + description: The verbose output as a list from executing the DSC Set + method. + returned: Ansible verbosity is -vvv or greater and a change occurred + type: list + sample: [ + "Perform operation 'Invoke CimMethod' with the following parameters, ", + "[SERVER]: LCM: [Start Set ] [[File]DirectResourceAccess]", + "Operation 'Invoke CimMethod' complete." + ] +''' diff --git a/test/support/windows-integration/plugins/modules/win_feature.ps1 b/test/support/windows-integration/plugins/modules/win_feature.ps1 new file mode 100644 index 00000000..9a7e1c30 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_feature.ps1 @@ -0,0 +1,111 @@ +#!powershell + +# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +Import-Module -Name ServerManager + +$result = @{ + changed = $false +} + +$params = Parse-Args $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false + +$name = Get-AnsibleParam -obj $params -name "name" -type "list" -failifempty $true +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent" + +$include_sub_features = Get-AnsibleParam -obj $params -name "include_sub_features" -type "bool" -default $false +$include_management_tools = Get-AnsibleParam -obj $params -name "include_management_tools" -type "bool" -default $false +$source = Get-AnsibleParam -obj $params -name "source" -type "str" + +$install_cmdlet = $false +if (Get-Command -Name Install-WindowsFeature -ErrorAction SilentlyContinue) { + Set-Alias -Name Install-AnsibleWindowsFeature -Value Install-WindowsFeature + Set-Alias -Name Uninstall-AnsibleWindowsFeature -Value Uninstall-WindowsFeature + $install_cmdlet = $true +} elseif (Get-Command -Name Add-WindowsFeature -ErrorAction SilentlyContinue) { + Set-Alias -Name Install-AnsibleWindowsFeature -Value Add-WindowsFeature + Set-Alias -Name Uninstall-AnsibleWindowsFeature -Value Remove-WindowsFeature +} else { + Fail-Json -obj $result -message "This version of Windows does not support the cmdlets Install-WindowsFeature or Add-WindowsFeature" +} + +if ($state -eq "present") { + $install_args = @{ + Name = $name + IncludeAllSubFeature = $include_sub_features + Restart = $false + WhatIf = $check_mode + ErrorAction = "Stop" + } + + if ($install_cmdlet) { + $install_args.IncludeManagementTools = $include_management_tools + $install_args.Confirm = $false + if ($source) { + if (-not (Test-Path -Path $source)) { + Fail-Json -obj $result -message "Failed to find source path $source for feature install" + } + $install_args.Source = $source + } + } + + try { + $action_results = Install-AnsibleWindowsFeature @install_args + } catch { + Fail-Json -obj $result -message "Failed to install Windows Feature: $($_.Exception.Message)" + } +} else { + $uninstall_args = @{ + Name = $name + Restart = $false + WhatIf = $check_mode + ErrorAction = "Stop" + } + if ($install_cmdlet) { + $uninstall_args.IncludeManagementTools = $include_management_tools + } + + try { + $action_results = Uninstall-AnsibleWindowsFeature @uninstall_args + } catch { + Fail-Json -obj $result -message "Failed to uninstall Windows Feature: $($_.Exception.Message)" + } +} + +# Loop through results and create a hash containing details about +# each role/feature that is installed/removed +# $action_results.FeatureResult is not empty if anything was changed +$feature_results = @() +foreach ($action_result in $action_results.FeatureResult) { + $message = @() + foreach ($msg in $action_result.Message) { + $message += @{ + message_type = $msg.MessageType.ToString() + error_code = $msg.ErrorCode + text = $msg.Text + } + } + + $feature_results += @{ + id = $action_result.Id + display_name = $action_result.DisplayName + message = $message + reboot_required = ConvertTo-Bool -obj $action_result.RestartNeeded + skip_reason = $action_result.SkipReason.ToString() + success = ConvertTo-Bool -obj $action_result.Success + restart_needed = ConvertTo-Bool -obj $action_result.RestartNeeded + } + $result.changed = $true +} +$result.feature_result = $feature_results +$result.success = ConvertTo-Bool -obj $action_results.Success +$result.exitcode = $action_results.ExitCode.ToString() +$result.reboot_required = ConvertTo-Bool -obj $action_results.RestartNeeded +# controls whether Ansible will fail or not +$result.failed = (-not $action_results.Success) + +Exit-Json -obj $result diff --git a/test/support/windows-integration/plugins/modules/win_feature.py b/test/support/windows-integration/plugins/modules/win_feature.py new file mode 100644 index 00000000..62e310b2 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_feature.py @@ -0,0 +1,149 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com> +# Copyright: (c) 2014, Trond Hindenes <trond@hindenes.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_feature +version_added: "1.7" +short_description: Installs and uninstalls Windows Features on Windows Server +description: + - Installs or uninstalls Windows Roles or Features on Windows Server. + - This module uses the Add/Remove-WindowsFeature Cmdlets on Windows 2008 R2 + and Install/Uninstall-WindowsFeature Cmdlets on Windows 2012, which are not available on client os machines. +options: + name: + description: + - Names of roles or features to install as a single feature or a comma-separated list of features. + - To list all available features use the PowerShell command C(Get-WindowsFeature). + type: list + required: yes + state: + description: + - State of the features or roles on the system. + type: str + choices: [ absent, present ] + default: present + include_sub_features: + description: + - Adds all subfeatures of the specified feature. + type: bool + default: no + include_management_tools: + description: + - Adds the corresponding management tools to the specified feature. + - Not supported in Windows 2008 R2 and will be ignored. + type: bool + default: no + source: + description: + - Specify a source to install the feature from. + - Not supported in Windows 2008 R2 and will be ignored. + - Can either be C({driveletter}:\sources\sxs) or C(\\{IP}\share\sources\sxs). + type: str + version_added: "2.1" +seealso: +- module: win_chocolatey +- module: win_package +author: + - Paul Durivage (@angstwad) + - Trond Hindenes (@trondhindenes) +''' + +EXAMPLES = r''' +- name: Install IIS (Web-Server only) + win_feature: + name: Web-Server + state: present + +- name: Install IIS (Web-Server and Web-Common-Http) + win_feature: + name: + - Web-Server + - Web-Common-Http + state: present + +- name: Install NET-Framework-Core from file + win_feature: + name: NET-Framework-Core + source: C:\Temp\iso\sources\sxs + state: present + +- name: Install IIS Web-Server with sub features and management tools + win_feature: + name: Web-Server + state: present + include_sub_features: yes + include_management_tools: yes + register: win_feature + +- name: Reboot if installing Web-Server feature requires it + win_reboot: + when: win_feature.reboot_required +''' + +RETURN = r''' +exitcode: + description: The stringified exit code from the feature installation/removal command. + returned: always + type: str + sample: Success +feature_result: + description: List of features that were installed or removed. + returned: success + type: complex + sample: + contains: + display_name: + description: Feature display name. + returned: always + type: str + sample: "Telnet Client" + id: + description: A list of KB article IDs that apply to the update. + returned: always + type: int + sample: 44 + message: + description: Any messages returned from the feature subsystem that occurred during installation or removal of this feature. + returned: always + type: list + elements: str + sample: [] + reboot_required: + description: True when the target server requires a reboot as a result of installing or removing this feature. + returned: always + type: bool + sample: true + restart_needed: + description: DEPRECATED in Ansible 2.4 (refer to C(reboot_required) instead). True when the target server requires a reboot as a + result of installing or removing this feature. + returned: always + type: bool + sample: true + skip_reason: + description: The reason a feature installation or removal was skipped. + returned: always + type: str + sample: NotSkipped + success: + description: If the feature installation or removal was successful. + returned: always + type: bool + sample: true +reboot_required: + description: True when the target server requires a reboot to complete updates (no further updates can be installed until after a reboot). + returned: success + type: bool + sample: true +''' diff --git a/test/support/windows-integration/plugins/modules/win_file.ps1 b/test/support/windows-integration/plugins/modules/win_file.ps1 new file mode 100644 index 00000000..54427549 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_file.ps1 @@ -0,0 +1,152 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$ErrorActionPreference = "Stop" + +$params = Parse-Args $args -supports_check_mode $true + +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -default $false +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$path = Get-AnsibleParam -obj $params -name "path" -type "path" -failifempty $true -aliases "dest","name" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -validateset "absent","directory","file","touch" + +# used in template/copy when dest is the path to a dir and source is a file +$original_basename = Get-AnsibleParam -obj $params -name "_original_basename" -type "str" +if ((Test-Path -LiteralPath $path -PathType Container) -and ($null -ne $original_basename)) { + $path = Join-Path -Path $path -ChildPath $original_basename +} + +$result = @{ + changed = $false +} + +# Used to delete symlinks as powershell cannot delete broken symlinks +$symlink_util = @" +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Ansible.Command { + public class SymLinkHelper { + [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)] + public static extern bool DeleteFileW(string lpFileName); + + [DllImport("kernel32.dll", CharSet=CharSet.Unicode, SetLastError=true)] + public static extern bool RemoveDirectoryW(string lpPathName); + + public static void DeleteDirectory(string path) { + if (!RemoveDirectoryW(path)) + throw new Exception(String.Format("RemoveDirectoryW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message)); + } + + public static void DeleteFile(string path) { + if (!DeleteFileW(path)) + throw new Exception(String.Format("DeleteFileW({0}) failed: {1}", path, new Win32Exception(Marshal.GetLastWin32Error()).Message)); + } + } +} +"@ +$original_tmp = $env:TMP +$env:TMP = $_remote_tmp +Add-Type -TypeDefinition $symlink_util +$env:TMP = $original_tmp + +# Used to delete directories and files with logic on handling symbolic links +function Remove-File($file, $checkmode) { + try { + if ($file.Attributes -band [System.IO.FileAttributes]::ReparsePoint) { + # Bug with powershell, if you try and delete a symbolic link that is pointing + # to an invalid path it will fail, using Win32 API to do this instead + if ($file.PSIsContainer) { + if (-not $checkmode) { + [Ansible.Command.SymLinkHelper]::DeleteDirectory($file.FullName) + } + } else { + if (-not $checkmode) { + [Ansible.Command.SymlinkHelper]::DeleteFile($file.FullName) + } + } + } elseif ($file.PSIsContainer) { + Remove-Directory -directory $file -checkmode $checkmode + } else { + Remove-Item -LiteralPath $file.FullName -Force -WhatIf:$checkmode + } + } catch [Exception] { + Fail-Json $result "Failed to delete $($file.FullName): $($_.Exception.Message)" + } +} + +function Remove-Directory($directory, $checkmode) { + foreach ($file in Get-ChildItem -LiteralPath $directory.FullName) { + Remove-File -file $file -checkmode $checkmode + } + Remove-Item -LiteralPath $directory.FullName -Force -Recurse -WhatIf:$checkmode +} + + +if ($state -eq "touch") { + if (Test-Path -LiteralPath $path) { + if (-not $check_mode) { + (Get-ChildItem -LiteralPath $path).LastWriteTime = Get-Date + } + $result.changed = $true + } else { + Write-Output $null | Out-File -LiteralPath $path -Encoding ASCII -WhatIf:$check_mode + $result.changed = $true + } +} + +if (Test-Path -LiteralPath $path) { + $fileinfo = Get-Item -LiteralPath $path -Force + if ($state -eq "absent") { + Remove-File -file $fileinfo -checkmode $check_mode + $result.changed = $true + } else { + if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a directory" + } + + if ($state -eq "file" -and $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a file" + } + } + +} else { + + # If state is not supplied, test the $path to see if it looks like + # a file or a folder and set state to file or folder + if ($null -eq $state) { + $basename = Split-Path -Path $path -Leaf + if ($basename.length -gt 0) { + $state = "file" + } else { + $state = "directory" + } + } + + if ($state -eq "directory") { + try { + New-Item -Path $path -ItemType Directory -WhatIf:$check_mode | Out-Null + } catch { + if ($_.CategoryInfo.Category -eq "ResourceExists") { + $fileinfo = Get-Item -LiteralPath $_.CategoryInfo.TargetName + if ($state -eq "directory" -and -not $fileinfo.PsIsContainer) { + Fail-Json $result "path $path is not a directory" + } + } else { + Fail-Json $result $_.Exception.Message + } + } + $result.changed = $true + } elseif ($state -eq "file") { + Fail-Json $result "path $path will not be created" + } + +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_file.py b/test/support/windows-integration/plugins/modules/win_file.py new file mode 100644 index 00000000..28149579 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_file.py @@ -0,0 +1,70 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Jon Hawkesworth (@jhawkesworth) <figs@unity.demon.co.uk> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_file +version_added: "1.9.2" +short_description: Creates, touches or removes files or directories +description: + - Creates (empty) files, updates file modification stamps of existing files, + and can create or remove directories. + - Unlike M(file), does not modify ownership, permissions or manipulate links. + - For non-Windows targets, use the M(file) module instead. +options: + path: + description: + - Path to the file being managed. + required: yes + type: path + aliases: [ dest, name ] + state: + description: + - If C(directory), all immediate subdirectories will be created if they + do not exist. + - If C(file), the file will NOT be created if it does not exist, see the M(copy) + or M(template) module if you want that behavior. + - If C(absent), directories will be recursively deleted, and files will be removed. + - If C(touch), an empty file will be created if the C(path) does not + exist, while an existing file or directory will receive updated file access and + modification times (similar to the way C(touch) works from the command line). + type: str + choices: [ absent, directory, file, touch ] +seealso: +- module: file +- module: win_acl +- module: win_acl_inheritance +- module: win_owner +- module: win_stat +author: +- Jon Hawkesworth (@jhawkesworth) +''' + +EXAMPLES = r''' +- name: Touch a file (creates if not present, updates modification time if present) + win_file: + path: C:\Temp\foo.conf + state: touch + +- name: Remove a file, if present + win_file: + path: C:\Temp\foo.conf + state: absent + +- name: Create directory structure + win_file: + path: C:\Temp\folder\subfolder + state: directory + +- name: Remove directory structure + win_file: + path: C:\Temp + state: absent +''' diff --git a/test/support/windows-integration/plugins/modules/win_find.ps1 b/test/support/windows-integration/plugins/modules/win_find.ps1 new file mode 100644 index 00000000..bc57c5ff --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_find.ps1 @@ -0,0 +1,416 @@ +#!powershell + +# Copyright: (c) 2016, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.LinkUtil + +$spec = @{ + options = @{ + paths = @{ type = "list"; elements = "str"; required = $true } + age = @{ type = "str" } + age_stamp = @{ type = "str"; default = "mtime"; choices = "mtime", "ctime", "atime" } + file_type = @{ type = "str"; default = "file"; choices = "file", "directory" } + follow = @{ type = "bool"; default = $false } + hidden = @{ type = "bool"; default = $false } + patterns = @{ type = "list"; elements = "str"; aliases = "regex", "regexp" } + recurse = @{ type = "bool"; default = $false } + size = @{ type = "str" } + use_regex = @{ type = "bool"; default = $false } + get_checksum = @{ type = "bool"; default = $true } + checksum_algorithm = @{ type = "str"; default = "sha1"; choices = "md5", "sha1", "sha256", "sha384", "sha512" } + } + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$paths = $module.Params.paths +$age = $module.Params.age +$age_stamp = $module.Params.age_stamp +$file_type = $module.Params.file_type +$follow = $module.Params.follow +$hidden = $module.Params.hidden +$patterns = $module.Params.patterns +$recurse = $module.Params.recurse +$size = $module.Params.size +$use_regex = $module.Params.use_regex +$get_checksum = $module.Params.get_checksum +$checksum_algorithm = $module.Params.checksum_algorithm + +$module.Result.examined = 0 +$module.Result.files = @() +$module.Result.matched = 0 + +Load-LinkUtils + +Function Assert-Age { + Param ( + [System.IO.FileSystemInfo]$File, + [System.Int64]$Age, + [System.String]$AgeStamp + ) + + $actual_age = switch ($AgeStamp) { + mtime { $File.LastWriteTime.Ticks } + ctime { $File.CreationTime.Ticks } + atime { $File.LastAccessTime.Ticks } + } + + if ($Age -ge 0) { + return $Age -ge $actual_age + } else { + return ($Age * -1) -le $actual_age + } +} + +Function Assert-FileType { + Param ( + [System.IO.FileSystemInfo]$File, + [System.String]$FileType + ) + + $is_dir = $File.Attributes.HasFlag([System.IO.FileAttributes]::Directory) + return ($FileType -eq 'directory' -and $is_dir) -or ($FileType -eq 'file' -and -not $is_dir) +} + +Function Assert-FileHidden { + Param ( + [System.IO.FileSystemInfo]$File, + [Switch]$IsHidden + ) + + $file_is_hidden = $File.Attributes.HasFlag([System.IO.FileAttributes]::Hidden) + return $IsHidden.IsPresent -eq $file_is_hidden +} + + +Function Assert-FileNamePattern { + Param ( + [System.IO.FileSystemInfo]$File, + [System.String[]]$Patterns, + [Switch]$UseRegex + ) + + $valid_match = $false + foreach ($pattern in $Patterns) { + if ($UseRegex) { + if ($File.Name -match $pattern) { + $valid_match = $true + break + } + } else { + if ($File.Name -like $pattern) { + $valid_match = $true + break + } + } + } + return $valid_match +} + +Function Assert-FileSize { + Param ( + [System.IO.FileSystemInfo]$File, + [System.Int64]$Size + ) + + if ($Size -ge 0) { + return $File.Length -ge $Size + } else { + return $File.Length -le ($Size * -1) + } +} + +Function Get-FileChecksum { + Param ( + [System.String]$Path, + [System.String]$Algorithm + ) + + $sp = switch ($algorithm) { + 'md5' { New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider } + 'sha1' { New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider } + 'sha256' { New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider } + 'sha384' { New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider } + 'sha512' { New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider } + } + + $fp = [System.IO.File]::Open($Path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, [System.IO.FileShare]::ReadWrite) + try { + $hash = [System.BitConverter]::ToString($sp.ComputeHash($fp)).Replace("-", "").ToLower() + } finally { + $fp.Dispose() + } + + return $hash +} + +Function Search-Path { + [CmdletBinding()] + Param ( + [Parameter(Mandatory=$true)] + [System.String] + $Path, + + [Parameter(Mandatory=$true)] + [AllowEmptyCollection()] + [System.Collections.Generic.HashSet`1[System.String]] + $CheckedPaths, + + [Parameter(Mandatory=$true)] + [Object] + $Module, + + [System.Int64] + $Age, + + [System.String] + $AgeStamp, + + [System.String] + $FileType, + + [Switch] + $Follow, + + [Switch] + $GetChecksum, + + [Switch] + $IsHidden, + + [System.String[]] + $Patterns, + + [Switch] + $Recurse, + + [System.Int64] + $Size, + + [Switch] + $UseRegex + ) + + $dir_obj = New-Object -TypeName System.IO.DirectoryInfo -ArgumentList $Path + if ([Int32]$dir_obj.Attributes -eq -1) { + $Module.Warn("Argument path '$Path' does not exist, skipping") + return + } elseif (-not $dir_obj.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) { + $Module.Warn("Argument path '$Path' is a file not a directory, skipping") + return + } + + $dir_files = @() + try { + $dir_files = $dir_obj.EnumerateFileSystemInfos("*", [System.IO.SearchOption]::TopDirectoryOnly) + } catch [System.IO.DirectoryNotFoundException] { # Broken ReparsePoint/Symlink, cannot enumerate + } catch [System.UnauthorizedAccessException] {} # No ListDirectory permissions, Get-ChildItem ignored this + + foreach ($dir_child in $dir_files) { + if ($dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Directory) -and $Recurse) { + if ($Follow -or -not $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::ReparsePoint)) { + $PSBoundParameters.Remove('Path') > $null + Search-Path -Path $dir_child.FullName @PSBoundParameters + } + } + + # Check to see if we've already encountered this path and skip if we have. + if (-not $CheckedPaths.Add($dir_child.FullName.ToLowerInvariant())) { + continue + } + + $Module.Result.examined++ + + if ($PSBoundParameters.ContainsKey('Age')) { + $age_match = Assert-Age -File $dir_child -Age $Age -AgeStamp $AgeStamp + } else { + $age_match = $true + } + + $file_type_match = Assert-FileType -File $dir_child -FileType $FileType + $hidden_match = Assert-FileHidden -File $dir_child -IsHidden:$IsHidden + + if ($PSBoundParameters.ContainsKey('Patterns')) { + $pattern_match = Assert-FileNamePattern -File $dir_child -Patterns $Patterns -UseRegex:$UseRegex.IsPresent + } else { + $pattern_match = $true + } + + if ($PSBoundParameters.ContainsKey('Size')) { + $size_match = Assert-FileSize -File $dir_child -Size $Size + } else { + $size_match = $true + } + + if (-not ($age_match -and $file_type_match -and $hidden_match -and $pattern_match -and $size_match)) { + continue + } + + # It passed all our filters so add it + $module.Result.matched++ + + # TODO: Make this generic so it can be shared with win_find and win_stat. + $epoch = New-Object -Type System.DateTime -ArgumentList 1970, 1, 1, 0, 0, 0, 0 + $file_info = @{ + attributes = $dir_child.Attributes.ToString() + checksum = $null + creationtime = (New-TimeSpan -Start $epoch -End $dir_child.CreationTime).TotalSeconds + exists = $true + extension = $null + filename = $dir_child.Name + isarchive = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Archive) + isdir = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Directory) + ishidden = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Hidden) + isreadonly = $dir_child.Attributes.HasFlag([System.IO.FileAttributes]::ReadOnly) + isreg = $false + isshared = $false + lastaccesstime = (New-TimeSpan -Start $epoch -End $dir_child.LastAccessTime).TotalSeconds + lastwritetime = (New-TimeSpan -Start $epoch -End $dir_child.LastWriteTime).TotalSeconds + owner = $null + path = $dir_child.FullName + sharename = $null + size = $null + } + + try { + $file_info.owner = $dir_child.GetAccessControl().Owner + } catch {} # May not have rights to get the Owner, historical behaviour is to ignore. + + if ($dir_child.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) { + $share_info = Get-CimInstance -ClassName Win32_Share -Filter "Path='$($dir_child.FullName -replace '\\', '\\')'" + if ($null -ne $share_info) { + $file_info.isshared = $true + $file_info.sharename = $share_info.Name + } + } else { + $file_info.extension = $dir_child.Extension + $file_info.isreg = $true + $file_info.size = $dir_child.Length + + if ($GetChecksum) { + try { + $file_info.checksum = Get-FileChecksum -Path $dir_child.FullName -Algorithm $checksum_algorithm + } catch {} # Just keep the checksum as $null in the case of a failure. + } + } + + # Append the link information if the path is a link + $link_info = @{ + isjunction = $false + islnk = $false + nlink = 1 + lnk_source = $null + lnk_target = $null + hlnk_targets = @() + } + $link_stat = Get-Link -link_path $dir_child.FullName + if ($null -ne $link_stat) { + switch ($link_stat.Type) { + "SymbolicLink" { + $link_info.islnk = $true + $link_info.isreg = $false + $link_info.lnk_source = $link_stat.AbsolutePath + $link_info.lnk_target = $link_stat.TargetPath + break + } + "JunctionPoint" { + $link_info.isjunction = $true + $link_info.isreg = $false + $link_info.lnk_source = $link_stat.AbsolutePath + $link_info.lnk_target = $link_stat.TargetPath + break + } + "HardLink" { + $link_info.nlink = $link_stat.HardTargets.Count + + # remove current path from the targets + $hlnk_targets = $link_info.HardTargets | Where-Object { $_ -ne $dir_child.FullName } + $link_info.hlnk_targets = @($hlnk_targets) + break + } + } + } + foreach ($kv in $link_info.GetEnumerator()) { + $file_info.$($kv.Key) = $kv.Value + } + + # Output the file_info object + $file_info + } +} + +$search_params = @{ + CheckedPaths = [System.Collections.Generic.HashSet`1[System.String]]@() + GetChecksum = $get_checksum + Module = $module + FileType = $file_type + Follow = $follow + IsHidden = $hidden + Recurse = $recurse +} + +if ($null -ne $age) { + $seconds_per_unit = @{'s'=1; 'm'=60; 'h'=3600; 'd'=86400; 'w'=604800} + $seconds_pattern = '^(-?\d+)(s|m|h|d|w)?$' + $match = $age -match $seconds_pattern + if ($Match) { + $specified_seconds = [Int64]$Matches[1] + if ($null -eq $Matches[2]) { + $chosen_unit = 's' + } else { + $chosen_unit = $Matches[2] + } + + $total_seconds = $specified_seconds * ($seconds_per_unit.$chosen_unit) + + if ($total_seconds -ge 0) { + $search_params.Age = (Get-Date).AddSeconds($total_seconds * -1).Ticks + } else { + # Make sure we add the positive value of seconds to current time then make it negative for later comparisons. + $age = (Get-Date).AddSeconds($total_seconds).Ticks + $search_params.Age = $age * -1 + } + $search_params.AgeStamp = $age_stamp + } else { + $module.FailJson("Invalid age pattern specified") + } +} + +if ($null -ne $patterns) { + $search_params.Patterns = $patterns + $search_params.UseRegex = $use_regex +} + +if ($null -ne $size) { + $bytes_per_unit = @{'b'=1; 'k'=1KB; 'm'=1MB; 'g'=1GB;'t'=1TB} + $size_pattern = '^(-?\d+)(b|k|m|g|t)?$' + $match = $size -match $size_pattern + if ($Match) { + $specified_size = [Int64]$Matches[1] + if ($null -eq $Matches[2]) { + $chosen_byte = 'b' + } else { + $chosen_byte = $Matches[2] + } + + $search_params.Size = $specified_size * ($bytes_per_unit.$chosen_byte) + } else { + $module.FailJson("Invalid size pattern specified") + } +} + +$matched_files = foreach ($path in $paths) { + # Ensure we pass in an absolute path. We use the ExecutionContext as this is based on the PSProvider path not the + # process location which can be different. + $abs_path = $ExecutionContext.SessionState.Path.GetUnresolvedProviderPathFromPSPath($path) + Search-Path -Path $abs_path @search_params +} + +# Make sure we sort the files in alphabetical order. +$module.Result.files = @() + ($matched_files | Sort-Object -Property {$_.path}) + +$module.ExitJson() + diff --git a/test/support/windows-integration/plugins/modules/win_find.py b/test/support/windows-integration/plugins/modules/win_find.py new file mode 100644 index 00000000..f506f956 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_find.py @@ -0,0 +1,345 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_find +version_added: "2.3" +short_description: Return a list of files based on specific criteria +description: + - Return a list of files based on specified criteria. + - Multiple criteria are AND'd together. + - For non-Windows targets, use the M(find) module instead. +options: + age: + description: + - Select files or folders whose age is equal to or greater than + the specified time. + - Use a negative age to find files equal to or less than + the specified time. + - You can choose seconds, minutes, hours, days or weeks + by specifying the first letter of an of + those words (e.g., "2s", "10d", 1w"). + type: str + age_stamp: + description: + - Choose the file property against which we compare C(age). + - The default attribute we compare with is the last modification time. + type: str + choices: [ atime, ctime, mtime ] + default: mtime + checksum_algorithm: + description: + - Algorithm to determine the checksum of a file. + - Will throw an error if the host is unable to use specified algorithm. + type: str + choices: [ md5, sha1, sha256, sha384, sha512 ] + default: sha1 + file_type: + description: Type of file to search for. + type: str + choices: [ directory, file ] + default: file + follow: + description: + - Set this to C(yes) to follow symlinks in the path. + - This needs to be used in conjunction with C(recurse). + type: bool + default: no + get_checksum: + description: + - Whether to return a checksum of the file in the return info (default sha1), + use C(checksum_algorithm) to change from the default. + type: bool + default: yes + hidden: + description: Set this to include hidden files or folders. + type: bool + default: no + paths: + description: + - List of paths of directories to search for files or folders in. + - This can be supplied as a single path or a list of paths. + type: list + required: yes + patterns: + description: + - One or more (powershell or regex) patterns to compare filenames with. + - The type of pattern matching is controlled by C(use_regex) option. + - The patterns restrict the list of files or folders to be returned based on the filenames. + - For a file to be matched it only has to match with one pattern in a list provided. + type: list + aliases: [ "regex", "regexp" ] + recurse: + description: + - Will recursively descend into the directory looking for files or folders. + type: bool + default: no + size: + description: + - Select files or folders whose size is equal to or greater than the specified size. + - Use a negative value to find files equal to or less than the specified size. + - You can specify the size with a suffix of the byte type i.e. kilo = k, mega = m... + - Size is not evaluated for symbolic links. + type: str + use_regex: + description: + - Will set patterns to run as a regex check if set to C(yes). + type: bool + default: no +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Find files in path + win_find: + paths: D:\Temp + +- name: Find hidden files in path + win_find: + paths: D:\Temp + hidden: yes + +- name: Find files in multiple paths + win_find: + paths: + - C:\Temp + - D:\Temp + +- name: Find files in directory while searching recursively + win_find: + paths: D:\Temp + recurse: yes + +- name: Find files in directory while following symlinks + win_find: + paths: D:\Temp + recurse: yes + follow: yes + +- name: Find files with .log and .out extension using powershell wildcards + win_find: + paths: D:\Temp + patterns: [ '*.log', '*.out' ] + +- name: Find files in path based on regex pattern + win_find: + paths: D:\Temp + patterns: out_\d{8}-\d{6}.log + +- name: Find files older than 1 day + win_find: + paths: D:\Temp + age: 86400 + +- name: Find files older than 1 day based on create time + win_find: + paths: D:\Temp + age: 86400 + age_stamp: ctime + +- name: Find files older than 1 day with unit syntax + win_find: + paths: D:\Temp + age: 1d + +- name: Find files newer than 1 hour + win_find: + paths: D:\Temp + age: -3600 + +- name: Find files newer than 1 hour with unit syntax + win_find: + paths: D:\Temp + age: -1h + +- name: Find files larger than 1MB + win_find: + paths: D:\Temp + size: 1048576 + +- name: Find files larger than 1GB with unit syntax + win_find: + paths: D:\Temp + size: 1g + +- name: Find files smaller than 1MB + win_find: + paths: D:\Temp + size: -1048576 + +- name: Find files smaller than 1GB with unit syntax + win_find: + paths: D:\Temp + size: -1g + +- name: Find folders/symlinks in multiple paths + win_find: + paths: + - C:\Temp + - D:\Temp + file_type: directory + +- name: Find files and return SHA256 checksum of files found + win_find: + paths: C:\Temp + get_checksum: yes + checksum_algorithm: sha256 + +- name: Find files and do not return the checksum + win_find: + paths: C:\Temp + get_checksum: no +''' + +RETURN = r''' +examined: + description: The number of files/folders that was checked. + returned: always + type: int + sample: 10 +matched: + description: The number of files/folders that match the criteria. + returned: always + type: int + sample: 2 +files: + description: Information on the files/folders that match the criteria returned as a list of dictionary elements + for each file matched. The entries are sorted by the path value alphabetically. + returned: success + type: complex + contains: + attributes: + description: attributes of the file at path in raw form. + returned: success, path exists + type: str + sample: "Archive, Hidden" + checksum: + description: The checksum of a file based on checksum_algorithm specified. + returned: success, path exists, path is a file, get_checksum == True + type: str + sample: 09cb79e8fc7453c84a07f644e441fd81623b7f98 + creationtime: + description: The create time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + exists: + description: Whether the file exists, will always be true for M(win_find). + returned: success, path exists + type: bool + sample: true + extension: + description: The extension of the file at path. + returned: success, path exists, path is a file + type: str + sample: ".ps1" + filename: + description: The name of the file. + returned: success, path exists + type: str + sample: temp + hlnk_targets: + description: List of other files pointing to the same file (hard links), excludes the current file. + returned: success, path exists + type: list + sample: + - C:\temp\file.txt + - C:\Windows\update.log + isarchive: + description: If the path is ready for archiving or not. + returned: success, path exists + type: bool + sample: true + isdir: + description: If the path is a directory or not. + returned: success, path exists + type: bool + sample: true + ishidden: + description: If the path is hidden or not. + returned: success, path exists + type: bool + sample: true + isjunction: + description: If the path is a junction point. + returned: success, path exists + type: bool + sample: true + islnk: + description: If the path is a symbolic link. + returned: success, path exists + type: bool + sample: true + isreadonly: + description: If the path is read only or not. + returned: success, path exists + type: bool + sample: true + isreg: + description: If the path is a regular file or not. + returned: success, path exists + type: bool + sample: true + isshared: + description: If the path is shared or not. + returned: success, path exists + type: bool + sample: true + lastaccesstime: + description: The last access time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lastwritetime: + description: The last modification time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lnk_source: + description: The target of the symlink normalized for the remote filesystem. + returned: success, path exists, path is a symbolic link or junction point + type: str + sample: C:\temp + lnk_target: + description: The target of the symlink. Note that relative paths remain relative, will return null if not a link. + returned: success, path exists, path is a symbolic link or junction point + type: str + sample: temp + nlink: + description: Number of links to the file (hard links) + returned: success, path exists + type: int + sample: 1 + owner: + description: The owner of the file. + returned: success, path exists + type: str + sample: BUILTIN\Administrators + path: + description: The full absolute path to the file. + returned: success, path exists + type: str + sample: BUILTIN\Administrators + sharename: + description: The name of share if folder is shared. + returned: success, path exists, path is a directory and isshared == True + type: str + sample: file-share + size: + description: The size in bytes of the file. + returned: success, path exists, path is a file + type: int + sample: 1024 +''' diff --git a/test/support/windows-integration/plugins/modules/win_format.ps1 b/test/support/windows-integration/plugins/modules/win_format.ps1 new file mode 100644 index 00000000..b5fd3ae0 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_format.ps1 @@ -0,0 +1,200 @@ +#!powershell + +# Copyright: (c) 2019, Varun Chopra (@chopraaa) <v@chopraaa.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#AnsibleRequires -OSVersion 6.2 + +Set-StrictMode -Version 2 + +$ErrorActionPreference = "Stop" + +$spec = @{ + options = @{ + drive_letter = @{ type = "str" } + path = @{ type = "str" } + label = @{ type = "str" } + new_label = @{ type = "str" } + file_system = @{ type = "str"; choices = "ntfs", "refs", "exfat", "fat32", "fat" } + allocation_unit_size = @{ type = "int" } + large_frs = @{ type = "bool" } + full = @{ type = "bool"; default = $false } + compress = @{ type = "bool" } + integrity_streams = @{ type = "bool" } + force = @{ type = "bool"; default = $false } + } + mutually_exclusive = @( + ,@('drive_letter', 'path', 'label') + ) + required_one_of = @( + ,@('drive_letter', 'path', 'label') + ) + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$drive_letter = $module.Params.drive_letter +$path = $module.Params.path +$label = $module.Params.label +$new_label = $module.Params.new_label +$file_system = $module.Params.file_system +$allocation_unit_size = $module.Params.allocation_unit_size +$large_frs = $module.Params.large_frs +$full_format = $module.Params.full +$compress_volume = $module.Params.compress +$integrity_streams = $module.Params.integrity_streams +$force_format = $module.Params.force + +# Some pre-checks +if ($null -ne $drive_letter -and $drive_letter -notmatch "^[a-zA-Z]$") { + $module.FailJson("The parameter drive_letter should be a single character A-Z") +} +if ($integrity_streams -eq $true -and $file_system -ne "refs") { + $module.FailJson("Integrity streams can be enabled only on ReFS volumes. You specified: $($file_system)") +} +if ($compress_volume -eq $true) { + if ($file_system -eq "ntfs") { + if ($null -ne $allocation_unit_size -and $allocation_unit_size -gt 4096) { + $module.FailJson("NTFS compression is not supported for allocation unit sizes above 4096") + } + } + else { + $module.FailJson("Compression can be enabled only on NTFS volumes. You specified: $($file_system)") + } +} + +function Get-AnsibleVolume { + param( + $DriveLetter, + $Path, + $Label + ) + + if ($null -ne $DriveLetter) { + try { + $volume = Get-Volume -DriveLetter $DriveLetter + } catch { + $module.FailJson("There was an error retrieving the volume using drive_letter $($DriveLetter): $($_.Exception.Message)", $_) + } + } + elseif ($null -ne $Path) { + try { + $volume = Get-Volume -Path $Path + } catch { + $module.FailJson("There was an error retrieving the volume using path $($Path): $($_.Exception.Message)", $_) + } + } + elseif ($null -ne $Label) { + try { + $volume = Get-Volume -FileSystemLabel $Label + } catch { + $module.FailJson("There was an error retrieving the volume using label $($Label): $($_.Exception.Message)", $_) + } + } + else { + $module.FailJson("Unable to locate volume: drive_letter, path and label were not specified") + } + + return $volume +} + +function Format-AnsibleVolume { + param( + $Path, + $Label, + $FileSystem, + $Full, + $UseLargeFRS, + $Compress, + $SetIntegrityStreams, + $AllocationUnitSize + ) + $parameters = @{ + Path = $Path + Full = $Full + } + if ($null -ne $UseLargeFRS) { + $parameters.Add("UseLargeFRS", $UseLargeFRS) + } + if ($null -ne $SetIntegrityStreams) { + $parameters.Add("SetIntegrityStreams", $SetIntegrityStreams) + } + if ($null -ne $Compress){ + $parameters.Add("Compress", $Compress) + } + if ($null -ne $Label) { + $parameters.Add("NewFileSystemLabel", $Label) + } + if ($null -ne $FileSystem) { + $parameters.Add("FileSystem", $FileSystem) + } + if ($null -ne $AllocationUnitSize) { + $parameters.Add("AllocationUnitSize", $AllocationUnitSize) + } + + Format-Volume @parameters -Confirm:$false | Out-Null + +} + +$ansible_volume = Get-AnsibleVolume -DriveLetter $drive_letter -Path $path -Label $label +$ansible_file_system = $ansible_volume.FileSystem +$ansible_volume_size = $ansible_volume.Size +$ansible_volume_alu = (Get-CimInstance -ClassName Win32_Volume -Filter "DeviceId = '$($ansible_volume.path.replace('\','\\'))'" -Property BlockSize).BlockSize + +$ansible_partition = Get-Partition -Volume $ansible_volume + +if (-not $force_format -and $null -ne $allocation_unit_size -and $ansible_volume_alu -ne 0 -and $null -ne $ansible_volume_alu -and $allocation_unit_size -ne $ansible_volume_alu) { + $module.FailJson("Force format must be specified since target allocation unit size: $($allocation_unit_size) is different from the current allocation unit size of the volume: $($ansible_volume_alu)") +} + +foreach ($access_path in $ansible_partition.AccessPaths) { + if ($access_path -ne $Path) { + if ($null -ne $file_system -and + -not [string]::IsNullOrEmpty($ansible_file_system) -and + $file_system -ne $ansible_file_system) + { + if (-not $force_format) + { + $no_files_in_volume = (Get-ChildItem -LiteralPath $access_path -ErrorAction SilentlyContinue | Measure-Object).Count -eq 0 + if($no_files_in_volume) + { + $module.FailJson("Force format must be specified since target file system: $($file_system) is different from the current file system of the volume: $($ansible_file_system.ToLower())") + } + else + { + $module.FailJson("Force format must be specified to format non-pristine volumes") + } + } + } + else + { + $pristine = -not $force_format + } + } +} + +if ($force_format) { + if (-not $module.CheckMode) { + Format-AnsibleVolume -Path $ansible_volume.Path -Full $full_format -Label $new_label -FileSystem $file_system -SetIntegrityStreams $integrity_streams -UseLargeFRS $large_frs -Compress $compress_volume -AllocationUnitSize $allocation_unit_size + } + $module.Result.changed = $true +} +else { + if ($pristine) { + if ($null -eq $new_label) { + $new_label = $ansible_volume.FileSystemLabel + } + # Conditions for formatting + if ($ansible_volume_size -eq 0 -or + $ansible_volume.FileSystemLabel -ne $new_label) { + if (-not $module.CheckMode) { + Format-AnsibleVolume -Path $ansible_volume.Path -Full $full_format -Label $new_label -FileSystem $file_system -SetIntegrityStreams $integrity_streams -UseLargeFRS $large_frs -Compress $compress_volume -AllocationUnitSize $allocation_unit_size + } + $module.Result.changed = $true + } + } +} + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_format.py b/test/support/windows-integration/plugins/modules/win_format.py new file mode 100644 index 00000000..f8f18ed7 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_format.py @@ -0,0 +1,103 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019, Varun Chopra (@chopraaa) <v@chopraaa.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community' +} + +DOCUMENTATION = r''' +module: win_format +version_added: '2.8' +short_description: Formats an existing volume or a new volume on an existing partition on Windows +description: + - The M(win_format) module formats an existing volume or a new volume on an existing partition on Windows +options: + drive_letter: + description: + - Used to specify the drive letter of the volume to be formatted. + type: str + path: + description: + - Used to specify the path to the volume to be formatted. + type: str + label: + description: + - Used to specify the label of the volume to be formatted. + type: str + new_label: + description: + - Used to specify the new file system label of the formatted volume. + type: str + file_system: + description: + - Used to specify the file system to be used when formatting the target volume. + type: str + choices: [ ntfs, refs, exfat, fat32, fat ] + allocation_unit_size: + description: + - Specifies the cluster size to use when formatting the volume. + - If no cluster size is specified when you format a partition, defaults are selected based on + the size of the partition. + - This value must be a multiple of the physical sector size of the disk. + type: int + large_frs: + description: + - Specifies that large File Record System (FRS) should be used. + type: bool + compress: + description: + - Enable compression on the resulting NTFS volume. + - NTFS compression is not supported where I(allocation_unit_size) is more than 4096. + type: bool + integrity_streams: + description: + - Enable integrity streams on the resulting ReFS volume. + type: bool + full: + description: + - A full format writes to every sector of the disk, takes much longer to perform than the + default (quick) format, and is not recommended on storage that is thinly provisioned. + - Specify C(true) for full format. + type: bool + force: + description: + - Specify if formatting should be forced for volumes that are not created from new partitions + or if the source and target file system are different. + type: bool +notes: + - Microsoft Windows Server 2012 or Microsoft Windows 8 or newer is required to use this module. To check if your system is compatible, see + U(https://docs.microsoft.com/en-us/windows/desktop/sysinfo/operating-system-version). + - One of three parameters (I(drive_letter), I(path) and I(label)) are mandatory to identify the target + volume but more than one cannot be specified at the same time. + - This module is idempotent if I(force) is not specified and file system labels remain preserved. + - For more information, see U(https://docs.microsoft.com/en-us/previous-versions/windows/desktop/stormgmt/format-msft-volume) +seealso: + - module: win_disk_facts + - module: win_partition +author: + - Varun Chopra (@chopraaa) <v@chopraaa.com> +''' + +EXAMPLES = r''' +- name: Create a partition with drive letter D and size 5 GiB + win_partition: + drive_letter: D + partition_size: 5 GiB + disk_number: 1 + +- name: Full format the newly created partition as NTFS and label it + win_format: + drive_letter: D + file_system: NTFS + new_label: Formatted + full: True +''' + +RETURN = r''' +# +''' diff --git a/test/support/windows-integration/plugins/modules/win_get_url.ps1 b/test/support/windows-integration/plugins/modules/win_get_url.ps1 new file mode 100644 index 00000000..1d8dd5a3 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_get_url.ps1 @@ -0,0 +1,274 @@ +#!powershell + +# Copyright: (c) 2015, Paul Durivage <paul.durivage@rackspace.com> +# Copyright: (c) 2015, Tal Auslander <tal@cloudshare.com> +# Copyright: (c) 2017, Dag Wieers <dag@wieers.com> +# Copyright: (c) 2019, Viktor Utkin <viktor_utkin@epam.com> +# Copyright: (c) 2019, Uladzimir Klybik <uladzimir_klybik@epam.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.FileUtil +#Requires -Module Ansible.ModuleUtils.WebRequest + +$spec = @{ + options = @{ + url = @{ type="str"; required=$true } + dest = @{ type='path'; required=$true } + force = @{ type='bool'; default=$true } + checksum = @{ type='str' } + checksum_algorithm = @{ type='str'; default='sha1'; choices = @("md5", "sha1", "sha256", "sha384", "sha512") } + checksum_url = @{ type='str' } + + # Defined for the alias backwards compatibility, remove once aliases are removed + url_username = @{ + aliases = @("user", "username") + deprecated_aliases = @( + @{ name = "user"; version = "2.14" }, + @{ name = "username"; version = "2.14" } + ) + } + url_password = @{ + aliases = @("password") + deprecated_aliases = @( + @{ name = "password"; version = "2.14" } + ) + } + } + mutually_exclusive = @( + ,@('checksum', 'checksum_url') + ) + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec, @(Get-AnsibleWebRequestSpec)) + +$url = $module.Params.url +$dest = $module.Params.dest +$force = $module.Params.force +$checksum = $module.Params.checksum +$checksum_algorithm = $module.Params.checksum_algorithm +$checksum_url = $module.Params.checksum_url + +$module.Result.elapsed = 0 +$module.Result.url = $url + +Function Get-ChecksumFromUri { + param( + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Parameter(Mandatory=$true)][Uri]$Uri, + [Uri]$SourceUri + ) + + $script = { + param($Response, $Stream) + + $read_stream = New-Object -TypeName System.IO.StreamReader -ArgumentList $Stream + $web_checksum = $read_stream.ReadToEnd() + $basename = (Split-Path -Path $SourceUri.LocalPath -Leaf) + $basename = [regex]::Escape($basename) + $web_checksum_str = $web_checksum -split '\r?\n' | Select-String -Pattern $("\s+\.?\/?\\?" + $basename + "\s*$") + if (-not $web_checksum_str) { + $Module.FailJson("Checksum record not found for file name '$basename' in file from url: '$Uri'") + } + + $web_checksum_str_splitted = $web_checksum_str[0].ToString().split(" ", 2) + $hash_from_file = $web_checksum_str_splitted[0].Trim() + # Remove any non-alphanumeric characters + $hash_from_file = $hash_from_file -replace '\W+', '' + + Write-Output -InputObject $hash_from_file + } + $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module + + try { + Invoke-WithWebRequest -Module $Module -Request $web_request -Script $script + } catch { + $Module.FailJson("Error when getting the remote checksum from '$Uri'. $($_.Exception.Message)", $_) + } +} + +Function Compare-ModifiedFile { + <# + .SYNOPSIS + Compares the remote URI resource against the local Dest resource. Will + return true if the LastWriteTime/LastModificationDate of the remote is + newer than the local resource date. + #> + param( + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Parameter(Mandatory=$true)][Uri]$Uri, + [Parameter(Mandatory=$true)][String]$Dest + ) + + $dest_last_mod = (Get-AnsibleItem -Path $Dest).LastWriteTimeUtc + + # If the URI is a file we don't need to go through the whole WebRequest + if ($Uri.IsFile) { + $src_last_mod = (Get-AnsibleItem -Path $Uri.AbsolutePath).LastWriteTimeUtc + } else { + $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module + $web_request.Method = switch ($web_request.GetType().Name) { + FtpWebRequest { [System.Net.WebRequestMethods+Ftp]::GetDateTimestamp } + HttpWebRequest { [System.Net.WebRequestMethods+Http]::Head } + } + $script = { param($Response, $Stream); $Response.LastModified } + + try { + $src_last_mod = Invoke-WithWebRequest -Module $Module -Request $web_request -Script $script + } catch { + $Module.FailJson("Error when requesting 'Last-Modified' date from '$Uri'. $($_.Exception.Message)", $_) + } + } + + # Return $true if the Uri LastModification date is newer than the Dest LastModification date + ((Get-Date -Date $src_last_mod).ToUniversalTime() -gt $dest_last_mod) +} + +Function Get-Checksum { + param( + [Parameter(Mandatory=$true)][String]$Path, + [String]$Algorithm = "sha1" + ) + + switch ($Algorithm) { + 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider } + 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider } + 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider } + 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider } + 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider } + } + + $fs = [System.IO.File]::Open($Path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, + [System.IO.FileShare]::ReadWrite) + try { + $hash = [System.BitConverter]::ToString($sp.ComputeHash($fs)).Replace("-", "").ToLower() + } finally { + $fs.Dispose() + } + return $hash +} + +Function Invoke-DownloadFile { + param( + [Parameter(Mandatory=$true)][Ansible.Basic.AnsibleModule]$Module, + [Parameter(Mandatory=$true)][Uri]$Uri, + [Parameter(Mandatory=$true)][String]$Dest, + [String]$Checksum, + [String]$ChecksumAlgorithm + ) + + # Check $dest parent folder exists before attempting download, which avoids unhelpful generic error message. + $dest_parent = Split-Path -LiteralPath $Dest + if (-not (Test-Path -LiteralPath $dest_parent -PathType Container)) { + $module.FailJson("The path '$dest_parent' does not exist for destination '$Dest', or is not visible to the current user. Ensure download destination folder exists (perhaps using win_file state=directory) before win_get_url runs.") + } + + $download_script = { + param($Response, $Stream) + + # Download the file to a temporary directory so we can compare it + $tmp_dest = Join-Path -Path $Module.Tmpdir -ChildPath ([System.IO.Path]::GetRandomFileName()) + $fs = [System.IO.File]::Create($tmp_dest) + try { + $Stream.CopyTo($fs) + $fs.Flush() + } finally { + $fs.Dispose() + } + $tmp_checksum = Get-Checksum -Path $tmp_dest -Algorithm $ChecksumAlgorithm + $Module.Result.checksum_src = $tmp_checksum + + # If the checksum has been set, verify the checksum of the remote against the input checksum. + if ($Checksum -and $Checksum -ne $tmp_checksum) { + $Module.FailJson(("The checksum for {0} did not match '{1}', it was '{2}'" -f $Uri, $Checksum, $tmp_checksum)) + } + + $download = $true + if (Test-Path -LiteralPath $Dest) { + # Validate the remote checksum against the existing downloaded file + $dest_checksum = Get-Checksum -Path $Dest -Algorithm $ChecksumAlgorithm + + # If we don't need to download anything, save the dest checksum so we don't waste time calculating it + # again at the end of the script + if ($dest_checksum -eq $tmp_checksum) { + $download = $false + $Module.Result.checksum_dest = $dest_checksum + $Module.Result.size = (Get-AnsibleItem -Path $Dest).Length + } + } + + if ($download) { + Copy-Item -LiteralPath $tmp_dest -Destination $Dest -Force -WhatIf:$Module.CheckMode > $null + $Module.Result.changed = $true + } + } + $web_request = Get-AnsibleWebRequest -Uri $Uri -Module $Module + + try { + Invoke-WithWebRequest -Module $Module -Request $web_request -Script $download_script + } catch { + $Module.FailJson("Error downloading '$Uri' to '$Dest': $($_.Exception.Message)", $_) + } +} + +# Use last part of url for dest file name if a directory is supplied for $dest +if (Test-Path -LiteralPath $dest -PathType Container) { + $uri = [System.Uri]$url + $basename = Split-Path -Path $uri.LocalPath -Leaf + if ($uri.LocalPath -and $uri.LocalPath -ne '/' -and $basename) { + $url_basename = Split-Path -Path $uri.LocalPath -Leaf + $dest = Join-Path -Path $dest -ChildPath $url_basename + } else { + $dest = Join-Path -Path $dest -ChildPath $uri.Host + } + + # Ensure we have a string instead of a PS object to avoid serialization issues + $dest = $dest.ToString() +} elseif (([System.IO.Path]::GetFileName($dest)) -eq '') { + # We have a trailing path separator + $module.FailJson("The destination path '$dest' does not exist, or is not visible to the current user. Ensure download destination folder exists (perhaps using win_file state=directory) before win_get_url runs.") +} + +$module.Result.dest = $dest + +if ($checksum) { + $checksum = $checksum.Trim().ToLower() +} +if ($checksum_algorithm) { + $checksum_algorithm = $checksum_algorithm.Trim().ToLower() +} +if ($checksum_url) { + $checksum_url = $checksum_url.Trim() +} + +# Check for case $checksum variable contain url. If yes, get file data from url and replace original value in $checksum +if ($checksum_url) { + $checksum_uri = [System.Uri]$checksum_url + if ($checksum_uri.Scheme -notin @("file", "ftp", "http", "https")) { + $module.FailJson("Unsupported 'checksum_url' value for '$dest': '$checksum_url'") + } + + $checksum = Get-ChecksumFromUri -Module $Module -Uri $checksum_uri -SourceUri $url +} + +if ($force -or -not (Test-Path -LiteralPath $dest)) { + # force=yes or dest does not exist, download the file + # Note: Invoke-DownloadFile will compare the checksums internally if dest exists + Invoke-DownloadFile -Module $module -Uri $url -Dest $dest -Checksum $checksum ` + -ChecksumAlgorithm $checksum_algorithm +} else { + # force=no, we want to check the last modified dates and only download if they don't match + $is_modified = Compare-ModifiedFile -Module $module -Uri $url -Dest $dest + if ($is_modified) { + Invoke-DownloadFile -Module $module -Uri $url -Dest $dest -Checksum $checksum ` + -ChecksumAlgorithm $checksum_algorithm + } +} + +if ((-not $module.Result.ContainsKey("checksum_dest")) -and (Test-Path -LiteralPath $dest)) { + # Calculate the dest file checksum if it hasn't already been done + $module.Result.checksum_dest = Get-Checksum -Path $dest -Algorithm $checksum_algorithm + $module.Result.size = (Get-AnsibleItem -Path $dest).Length +} + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_get_url.py b/test/support/windows-integration/plugins/modules/win_get_url.py new file mode 100644 index 00000000..ef5b5f97 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_get_url.py @@ -0,0 +1,215 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com>, and others +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# This is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_get_url +version_added: "1.7" +short_description: Downloads file from HTTP, HTTPS, or FTP to node +description: +- Downloads files from HTTP, HTTPS, or FTP to the remote server. +- The remote server I(must) have direct access to the remote resource. +- For non-Windows targets, use the M(get_url) module instead. +options: + url: + description: + - The full URL of a file to download. + type: str + required: yes + dest: + description: + - The location to save the file at the URL. + - Be sure to include a filename and extension as appropriate. + type: path + required: yes + force: + description: + - If C(yes), will download the file every time and replace the file if the contents change. If C(no), will only + download the file if it does not exist or the remote file has been + modified more recently than the local file. + - This works by sending an http HEAD request to retrieve last modified + time of the requested resource, so for this to work, the remote web + server must support HEAD requests. + type: bool + default: yes + version_added: "2.0" + checksum: + description: + - If a I(checksum) is passed to this parameter, the digest of the + destination file will be calculated after it is downloaded to ensure + its integrity and verify that the transfer completed successfully. + - This option cannot be set with I(checksum_url). + type: str + version_added: "2.8" + checksum_algorithm: + description: + - Specifies the hashing algorithm used when calculating the checksum of + the remote and destination file. + type: str + choices: + - md5 + - sha1 + - sha256 + - sha384 + - sha512 + default: sha1 + version_added: "2.8" + checksum_url: + description: + - Specifies a URL that contains the checksum values for the resource at + I(url). + - Like C(checksum), this is used to verify the integrity of the remote + transfer. + - This option cannot be set with I(checksum). + type: str + version_added: "2.8" + url_username: + description: + - The username to use for authentication. + - The aliases I(user) and I(username) are deprecated and will be removed in + Ansible 2.14. + aliases: + - user + - username + url_password: + description: + - The password for I(url_username). + - The alias I(password) is deprecated and will be removed in Ansible 2.14. + aliases: + - password + proxy_url: + version_added: "2.0" + proxy_username: + version_added: "2.0" + proxy_password: + version_added: "2.0" + headers: + version_added: "2.4" + use_proxy: + version_added: "2.4" + follow_redirects: + version_added: "2.9" + maximum_redirection: + version_added: "2.9" + client_cert: + version_added: "2.9" + client_cert_password: + version_added: "2.9" + method: + description: + - This option is not for use with C(win_get_url) and should be ignored. + version_added: "2.9" +notes: +- If your URL includes an escaped slash character (%2F) this module will convert it to a real slash. + This is a result of the behaviour of the System.Uri class as described in + L(the documentation,https://docs.microsoft.com/en-us/dotnet/framework/configure-apps/file-schema/network/schemesettings-element-uri-settings#remarks). +- Since Ansible 2.8, the module will skip reporting a change if the remote + checksum is the same as the local local even when C(force=yes). This is to + better align with M(get_url). +extends_documentation_fragment: +- url_windows +seealso: +- module: get_url +- module: uri +- module: win_uri +author: +- Paul Durivage (@angstwad) +- Takeshi Kuramochi (@tksarah) +''' + +EXAMPLES = r''' +- name: Download earthrise.jpg to specified path + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\Users\RandomUser\earthrise.jpg + +- name: Download earthrise.jpg to specified path only if modified + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\Users\RandomUser\earthrise.jpg + force: no + +- name: Download earthrise.jpg to specified path through a proxy server. + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\Users\RandomUser\earthrise.jpg + proxy_url: http://10.0.0.1:8080 + proxy_username: username + proxy_password: password + +- name: Download file from FTP with authentication + win_get_url: + url: ftp://server/file.txt + dest: '%TEMP%\ftp-file.txt' + url_username: ftp-user + url_password: ftp-password + +- name: Download src with sha256 checksum url + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\temp\earthrise.jpg + checksum_url: http://www.example.com/sha256sum.txt + checksum_algorithm: sha256 + force: True + +- name: Download src with sha256 checksum url + win_get_url: + url: http://www.example.com/earthrise.jpg + dest: C:\temp\earthrise.jpg + checksum: a97e6837f60cec6da4491bab387296bbcd72bdba + checksum_algorithm: sha1 + force: True +''' + +RETURN = r''' +dest: + description: destination file/path + returned: always + type: str + sample: C:\Users\RandomUser\earthrise.jpg +checksum_dest: + description: <algorithm> checksum of the file after the download + returned: success and dest has been downloaded + type: str + sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827 +checksum_src: + description: <algorithm> checksum of the remote resource + returned: force=yes or dest did not exist + type: str + sample: 6e642bb8dd5c2e027bf21dd923337cbb4214f827 +elapsed: + description: The elapsed seconds between the start of poll and the end of the module. + returned: always + type: float + sample: 2.1406487 +size: + description: size of the dest file + returned: success + type: int + sample: 1220 +url: + description: requested url + returned: always + type: str + sample: http://www.example.com/earthrise.jpg +msg: + description: Error message, or HTTP status message from web-server + returned: always + type: str + sample: OK +status_code: + description: HTTP status code + returned: always + type: int + sample: 200 +''' diff --git a/test/support/windows-integration/plugins/modules/win_lineinfile.ps1 b/test/support/windows-integration/plugins/modules/win_lineinfile.ps1 new file mode 100644 index 00000000..38dd8b8b --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_lineinfile.ps1 @@ -0,0 +1,450 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.Backup + +function WriteLines($outlines, $path, $linesep, $encodingobj, $validate, $check_mode) { + Try { + $temppath = [System.IO.Path]::GetTempFileName(); + } + Catch { + Fail-Json @{} "Cannot create temporary file! ($($_.Exception.Message))"; + } + $joined = $outlines -join $linesep; + [System.IO.File]::WriteAllText($temppath, $joined, $encodingobj); + + If ($validate) { + + If (-not ($validate -like "*%s*")) { + Fail-Json @{} "validate must contain %s: $validate"; + } + + $validate = $validate.Replace("%s", $temppath); + + $parts = [System.Collections.ArrayList] $validate.Split(" "); + $cmdname = $parts[0]; + + $cmdargs = $validate.Substring($cmdname.Length + 1); + + $process = [Diagnostics.Process]::Start($cmdname, $cmdargs); + $process.WaitForExit(); + + If ($process.ExitCode -ne 0) { + [string] $output = $process.StandardOutput.ReadToEnd(); + [string] $error = $process.StandardError.ReadToEnd(); + Remove-Item $temppath -force; + Fail-Json @{} "failed to validate $cmdname $cmdargs with error: $output $error"; + } + + } + + # Commit changes to the path + $cleanpath = $path.Replace("/", "\"); + Try { + Copy-Item -Path $temppath -Destination $cleanpath -Force -WhatIf:$check_mode; + } + Catch { + Fail-Json @{} "Cannot write to: $cleanpath ($($_.Exception.Message))"; + } + + Try { + Remove-Item -Path $temppath -Force -WhatIf:$check_mode; + } + Catch { + Fail-Json @{} "Cannot remove temporary file: $temppath ($($_.Exception.Message))"; + } + + return $joined; + +} + + +# Implement the functionality for state == 'present' +function Present($path, $regex, $line, $insertafter, $insertbefore, $create, $backup, $backrefs, $validate, $encodingobj, $linesep, $check_mode, $diff_support) { + + # Note that we have to clean up the path because ansible wants to treat / and \ as + # interchangeable in windows pathnames, but .NET framework internals do not support that. + $cleanpath = $path.Replace("/", "\"); + + # Check if path exists. If it does not exist, either create it if create == "yes" + # was specified or fail with a reasonable error message. + If (-not (Test-Path -LiteralPath $path)) { + If (-not $create) { + Fail-Json @{} "Path $path does not exist !"; + } + # Create new empty file, using the specified encoding to write correct BOM + [System.IO.File]::WriteAllLines($cleanpath, "", $encodingobj); + } + + # Initialize result information + $result = @{ + backup = ""; + changed = $false; + msg = ""; + } + + # Read the dest file lines using the indicated encoding into a mutable ArrayList. + $before = [System.IO.File]::ReadAllLines($cleanpath, $encodingobj) + If ($null -eq $before) { + $lines = New-Object System.Collections.ArrayList; + } + Else { + $lines = [System.Collections.ArrayList] $before; + } + + if ($diff_support) { + $result.diff = @{ + before = $before -join $linesep; + } + } + + # Compile the regex specified, if provided + $mre = $null; + If ($regex) { + $mre = New-Object Regex $regex, 'Compiled'; + } + + # Compile the regex for insertafter or insertbefore, if provided + $insre = $null; + If ($insertafter -and $insertafter -ne "BOF" -and $insertafter -ne "EOF") { + $insre = New-Object Regex $insertafter, 'Compiled'; + } + ElseIf ($insertbefore -and $insertbefore -ne "BOF") { + $insre = New-Object Regex $insertbefore, 'Compiled'; + } + + # index[0] is the line num where regex has been found + # index[1] is the line num where insertafter/insertbefore has been found + $index = -1, -1; + $lineno = 0; + + # The latest match object and matched line + $matched_line = ""; + + # Iterate through the lines in the file looking for matches + Foreach ($cur_line in $lines) { + If ($regex) { + $m = $mre.Match($cur_line); + $match_found = $m.Success; + If ($match_found) { + $matched_line = $cur_line; + } + } + Else { + $match_found = $line -ceq $cur_line; + } + If ($match_found) { + $index[0] = $lineno; + } + ElseIf ($insre -and $insre.Match($cur_line).Success) { + If ($insertafter) { + $index[1] = $lineno + 1; + } + If ($insertbefore) { + $index[1] = $lineno; + } + } + $lineno = $lineno + 1; + } + + If ($index[0] -ne -1) { + If ($backrefs) { + $new_line = [regex]::Replace($matched_line, $regex, $line); + } + Else { + $new_line = $line; + } + If ($lines[$index[0]] -cne $new_line) { + $lines[$index[0]] = $new_line; + $result.changed = $true; + $result.msg = "line replaced"; + } + } + ElseIf ($backrefs) { + # No matches - no-op + } + ElseIf ($insertbefore -eq "BOF" -or $insertafter -eq "BOF") { + $lines.Insert(0, $line); + $result.changed = $true; + $result.msg = "line added"; + } + ElseIf ($insertafter -eq "EOF" -or $index[1] -eq -1) { + $lines.Add($line) > $null; + $result.changed = $true; + $result.msg = "line added"; + } + Else { + $lines.Insert($index[1], $line); + $result.changed = $true; + $result.msg = "line added"; + } + + # Write changes to the path if changes were made + If ($result.changed) { + + # Write backup file if backup == "yes" + If ($backup) { + $result.backup_file = Backup-File -path $path -WhatIf:$check_mode + # Ensure backward compatibility (deprecate in future) + $result.backup = $result.backup_file + } + + $writelines_params = @{ + outlines = $lines + path = $path + linesep = $linesep + encodingobj = $encodingobj + validate = $validate + check_mode = $check_mode + } + $after = WriteLines @writelines_params; + + if ($diff_support) { + $result.diff.after = $after; + } + } + + $result.encoding = $encodingobj.WebName; + + Exit-Json $result; +} + + +# Implement the functionality for state == 'absent' +function Absent($path, $regex, $line, $backup, $validate, $encodingobj, $linesep, $check_mode, $diff_support) { + + # Check if path exists. If it does not exist, fail with a reasonable error message. + If (-not (Test-Path -LiteralPath $path)) { + Fail-Json @{} "Path $path does not exist !"; + } + + # Initialize result information + $result = @{ + backup = ""; + changed = $false; + msg = ""; + } + + # Read the dest file lines using the indicated encoding into a mutable ArrayList. Note + # that we have to clean up the path because ansible wants to treat / and \ as + # interchangeable in windows pathnames, but .NET framework internals do not support that. + $cleanpath = $path.Replace("/", "\"); + $before = [System.IO.File]::ReadAllLines($cleanpath, $encodingobj); + If ($null -eq $before) { + $lines = New-Object System.Collections.ArrayList; + } + Else { + $lines = [System.Collections.ArrayList] $before; + } + + if ($diff_support) { + $result.diff = @{ + before = $before -join $linesep; + } + } + + # Compile the regex specified, if provided + $cre = $null; + If ($regex) { + $cre = New-Object Regex $regex, 'Compiled'; + } + + $found = New-Object System.Collections.ArrayList; + $left = New-Object System.Collections.ArrayList; + + Foreach ($cur_line in $lines) { + If ($regex) { + $m = $cre.Match($cur_line); + $match_found = $m.Success; + } + Else { + $match_found = $line -ceq $cur_line; + } + If ($match_found) { + $found.Add($cur_line) > $null; + $result.changed = $true; + } + Else { + $left.Add($cur_line) > $null; + } + } + + # Write changes to the path if changes were made + If ($result.changed) { + + # Write backup file if backup == "yes" + If ($backup) { + $result.backup_file = Backup-File -path $path -WhatIf:$check_mode + # Ensure backward compatibility (deprecate in future) + $result.backup = $result.backup_file + } + + $writelines_params = @{ + outlines = $left + path = $path + linesep = $linesep + encodingobj = $encodingobj + validate = $validate + check_mode = $check_mode + } + $after = WriteLines @writelines_params; + + if ($diff_support) { + $result.diff.after = $after; + } + } + + $result.encoding = $encodingobj.WebName; + $result.found = $found.Count; + $result.msg = "$($found.Count) line(s) removed"; + + Exit-Json $result; +} + + +# Parse the parameters file dropped by the Ansible machinery +$params = Parse-Args $args -supports_check_mode $true; +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false; +$diff_support = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false; + +# Initialize defaults for input parameters. +$path = Get-AnsibleParam -obj $params -name "path" -type "path" -failifempty $true -aliases "dest","destfile","name"; +$regex = Get-AnsibleParam -obj $params -name "regex" -type "str" -aliases "regexp"; +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent"; +$line = Get-AnsibleParam -obj $params -name "line" -type "str"; +$backrefs = Get-AnsibleParam -obj $params -name "backrefs" -type "bool" -default $false; +$insertafter = Get-AnsibleParam -obj $params -name "insertafter" -type "str"; +$insertbefore = Get-AnsibleParam -obj $params -name "insertbefore" -type "str"; +$create = Get-AnsibleParam -obj $params -name "create" -type "bool" -default $false; +$backup = Get-AnsibleParam -obj $params -name "backup" -type "bool" -default $false; +$validate = Get-AnsibleParam -obj $params -name "validate" -type "str"; +$encoding = Get-AnsibleParam -obj $params -name "encoding" -type "str" -default "auto"; +$newline = Get-AnsibleParam -obj $params -name "newline" -type "str" -default "windows" -validateset "unix","windows"; + +# Fail if the path is not a file +If (Test-Path -LiteralPath $path -PathType "container") { + Fail-Json @{} "Path $path is a directory"; +} + +# Default to windows line separator - probably most common +$linesep = "`r`n" +If ($newline -eq "unix") { + $linesep = "`n"; +} + +# Figure out the proper encoding to use for reading / writing the target file. + +# The default encoding is UTF-8 without BOM +$encodingobj = [System.Text.UTF8Encoding] $false; + +# If an explicit encoding is specified, use that instead +If ($encoding -ne "auto") { + $encodingobj = [System.Text.Encoding]::GetEncoding($encoding); +} + +# Otherwise see if we can determine the current encoding of the target file. +# If the file doesn't exist yet (create == 'yes') we use the default or +# explicitly specified encoding set above. +ElseIf (Test-Path -LiteralPath $path) { + + # Get a sorted list of encodings with preambles, longest first + $max_preamble_len = 0; + $sortedlist = New-Object System.Collections.SortedList; + Foreach ($encodinginfo in [System.Text.Encoding]::GetEncodings()) { + $encoding = $encodinginfo.GetEncoding(); + $plen = $encoding.GetPreamble().Length; + If ($plen -gt $max_preamble_len) { + $max_preamble_len = $plen; + } + If ($plen -gt 0) { + $sortedlist.Add(-($plen * 1000000 + $encoding.CodePage), $encoding) > $null; + } + } + + # Get the first N bytes from the file, where N is the max preamble length we saw + [Byte[]]$bom = Get-Content -Encoding Byte -ReadCount $max_preamble_len -TotalCount $max_preamble_len -LiteralPath $path; + + # Iterate through the sorted encodings, looking for a full match. + $found = $false; + Foreach ($encoding in $sortedlist.GetValueList()) { + $preamble = $encoding.GetPreamble(); + If ($preamble -and $bom) { + Foreach ($i in 0..($preamble.Length - 1)) { + If ($i -ge $bom.Length) { + break; + } + If ($preamble[$i] -ne $bom[$i]) { + break; + } + ElseIf ($i + 1 -eq $preamble.Length) { + $encodingobj = $encoding; + $found = $true; + } + } + If ($found) { + break; + } + } + } +} + + +# Main dispatch - based on the value of 'state', perform argument validation and +# call the appropriate handler function. +If ($state -eq "present") { + + If ($backrefs -and -not $regex) { + Fail-Json @{} "regexp= is required with backrefs=true"; + } + + If (-not $line) { + Fail-Json @{} "line= is required with state=present"; + } + + If ($insertbefore -and $insertafter) { + Add-Warning $result "Both insertbefore and insertafter parameters found, ignoring `"insertafter=$insertafter`"" + } + + If (-not $insertbefore -and -not $insertafter) { + $insertafter = "EOF"; + } + + $present_params = @{ + path = $path + regex = $regex + line = $line + insertafter = $insertafter + insertbefore = $insertbefore + create = $create + backup = $backup + backrefs = $backrefs + validate = $validate + encodingobj = $encodingobj + linesep = $linesep + check_mode = $check_mode + diff_support = $diff_support + } + Present @present_params; + +} +ElseIf ($state -eq "absent") { + + If (-not $regex -and -not $line) { + Fail-Json @{} "one of line= or regexp= is required with state=absent"; + } + + $absent_params = @{ + path = $path + regex = $regex + line = $line + backup = $backup + validate = $validate + encodingobj = $encodingobj + linesep = $linesep + check_mode = $check_mode + diff_support = $diff_support + } + Absent @absent_params; +} diff --git a/test/support/windows-integration/plugins/modules/win_lineinfile.py b/test/support/windows-integration/plugins/modules/win_lineinfile.py new file mode 100644 index 00000000..f4fb7f5a --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_lineinfile.py @@ -0,0 +1,180 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_lineinfile +short_description: Ensure a particular line is in a file, or replace an existing line using a back-referenced regular expression +description: + - This module will search a file for a line, and ensure that it is present or absent. + - This is primarily useful when you want to change a single line in a file only. +version_added: "2.0" +options: + path: + description: + - The path of the file to modify. + - Note that the Windows path delimiter C(\) must be escaped as C(\\) when the line is double quoted. + - Before Ansible 2.3 this option was only usable as I(dest), I(destfile) and I(name). + type: path + required: yes + aliases: [ dest, destfile, name ] + backup: + description: + - Determine whether a backup should be created. + - When set to C(yes), create a backup file including the timestamp information + so you can get the original file back if you somehow clobbered it incorrectly. + type: bool + default: no + regex: + description: + - The regular expression to look for in every line of the file. For C(state=present), the pattern to replace if found; only the last line found + will be replaced. For C(state=absent), the pattern of the line to remove. Uses .NET compatible regular expressions; + see U(https://msdn.microsoft.com/en-us/library/hs600312%28v=vs.110%29.aspx). + aliases: [ "regexp" ] + state: + description: + - Whether the line should be there or not. + type: str + choices: [ absent, present ] + default: present + line: + description: + - Required for C(state=present). The line to insert/replace into the file. If C(backrefs) is set, may contain backreferences that will get + expanded with the C(regexp) capture groups if the regexp matches. + - Be aware that the line is processed first on the controller and thus is dependent on yaml quoting rules. Any double quoted line + will have control characters, such as '\r\n', expanded. To print such characters literally, use single or no quotes. + type: str + backrefs: + description: + - Used with C(state=present). If set, line can contain backreferences (both positional and named) that will get populated if the C(regexp) + matches. This flag changes the operation of the module slightly; C(insertbefore) and C(insertafter) will be ignored, and if the C(regexp) + doesn't match anywhere in the file, the file will be left unchanged. + - If the C(regexp) does match, the last matching line will be replaced by the expanded line parameter. + type: bool + default: no + insertafter: + description: + - Used with C(state=present). If specified, the line will be inserted after the last match of specified regular expression. A special value is + available; C(EOF) for inserting the line at the end of the file. + - If specified regular expression has no matches, EOF will be used instead. May not be used with C(backrefs). + type: str + choices: [ EOF, '*regex*' ] + default: EOF + insertbefore: + description: + - Used with C(state=present). If specified, the line will be inserted before the last match of specified regular expression. A value is available; + C(BOF) for inserting the line at the beginning of the file. + - If specified regular expression has no matches, the line will be inserted at the end of the file. May not be used with C(backrefs). + type: str + choices: [ BOF, '*regex*' ] + create: + description: + - Used with C(state=present). If specified, the file will be created if it does not already exist. By default it will fail if the file is missing. + type: bool + default: no + validate: + description: + - Validation to run before copying into place. Use %s in the command to indicate the current file to validate. + - The command is passed securely so shell features like expansion and pipes won't work. + type: str + encoding: + description: + - Specifies the encoding of the source text file to operate on (and thus what the output encoding will be). The default of C(auto) will cause + the module to auto-detect the encoding of the source file and ensure that the modified file is written with the same encoding. + - An explicit encoding can be passed as a string that is a valid value to pass to the .NET framework System.Text.Encoding.GetEncoding() method - + see U(https://msdn.microsoft.com/en-us/library/system.text.encoding%28v=vs.110%29.aspx). + - This is mostly useful with C(create=yes) if you want to create a new file with a specific encoding. If C(create=yes) is specified without a + specific encoding, the default encoding (UTF-8, no BOM) will be used. + type: str + default: auto + newline: + description: + - Specifies the line separator style to use for the modified file. This defaults to the windows line separator (C(\r\n)). Note that the indicated + line separator will be used for file output regardless of the original line separator that appears in the input file. + type: str + choices: [ unix, windows ] + default: windows +notes: + - As of Ansible 2.3, the I(dest) option has been changed to I(path) as default, but I(dest) still works as well. +seealso: +- module: assemble +- module: lineinfile +author: +- Brian Lloyd (@brianlloyd) +''' + +EXAMPLES = r''' +# Before Ansible 2.3, option 'dest', 'destfile' or 'name' was used instead of 'path' +- name: Insert path without converting \r\n + win_lineinfile: + path: c:\file.txt + line: c:\return\new + +- win_lineinfile: + path: C:\Temp\example.conf + regex: '^name=' + line: 'name=JohnDoe' + +- win_lineinfile: + path: C:\Temp\example.conf + regex: '^name=' + state: absent + +- win_lineinfile: + path: C:\Temp\example.conf + regex: '^127\.0\.0\.1' + line: '127.0.0.1 localhost' + +- win_lineinfile: + path: C:\Temp\httpd.conf + regex: '^Listen ' + insertafter: '^#Listen ' + line: Listen 8080 + +- win_lineinfile: + path: C:\Temp\services + regex: '^# port for http' + insertbefore: '^www.*80/tcp' + line: '# port for http by default' + +- name: Create file if it doesn't exist with a specific encoding + win_lineinfile: + path: C:\Temp\utf16.txt + create: yes + encoding: utf-16 + line: This is a utf-16 encoded file + +- name: Add a line to a file and ensure the resulting file uses unix line separators + win_lineinfile: + path: C:\Temp\testfile.txt + line: Line added to file + newline: unix + +- name: Update a line using backrefs + win_lineinfile: + path: C:\Temp\example.conf + backrefs: yes + regex: '(^name=)' + line: '$1JohnDoe' +''' + +RETURN = r''' +backup: + description: + - Name of the backup file that was created. + - This is now deprecated, use C(backup_file) instead. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +backup_file: + description: Name of the backup file that was created. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +''' diff --git a/test/support/windows-integration/plugins/modules/win_path.ps1 b/test/support/windows-integration/plugins/modules/win_path.ps1 new file mode 100644 index 00000000..04eb41a3 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_path.ps1 @@ -0,0 +1,145 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +Set-StrictMode -Version 2 +$ErrorActionPreference = "Stop" + +$system_path = "System\CurrentControlSet\Control\Session Manager\Environment" +$user_path = "Environment" + +# list/arraylist methods don't allow IEqualityComparer override for case/backslash/quote-insensitivity, roll our own search +Function Get-IndexOfPathElement ($list, [string]$value) { + $idx = 0 + $value = $value.Trim('"').Trim('\') + ForEach($el in $list) { + If ([string]$el.Trim('"').Trim('\') -ieq $value) { + return $idx + } + + $idx++ + } + + return -1 +} + +# alters list in place, returns true if at least one element was added +Function Add-Elements ($existing_elements, $elements_to_add) { + $last_idx = -1 + $changed = $false + + ForEach($el in $elements_to_add) { + $idx = Get-IndexOfPathElement $existing_elements $el + + # add missing elements at the end + If ($idx -eq -1) { + $last_idx = $existing_elements.Add($el) + $changed = $true + } + ElseIf ($idx -lt $last_idx) { + $existing_elements.RemoveAt($idx) | Out-Null + $existing_elements.Add($el) | Out-Null + $last_idx = $existing_elements.Count - 1 + $changed = $true + } + Else { + $last_idx = $idx + } + } + + return $changed +} + +# alters list in place, returns true if at least one element was removed +Function Remove-Elements ($existing_elements, $elements_to_remove) { + $count = $existing_elements.Count + + ForEach($el in $elements_to_remove) { + $idx = Get-IndexOfPathElement $existing_elements $el + $result.removed_idx = $idx + If ($idx -gt -1) { + $existing_elements.RemoveAt($idx) + } + } + + return $count -ne $existing_elements.Count +} + +# PS registry provider doesn't allow access to unexpanded REG_EXPAND_SZ; fall back to .NET +Function Get-RawPathVar ($scope) { + If ($scope -eq "user") { + $env_key = [Microsoft.Win32.Registry]::CurrentUser.OpenSubKey($user_path) + } + ElseIf ($scope -eq "machine") { + $env_key = [Microsoft.Win32.Registry]::LocalMachine.OpenSubKey($system_path) + } + + return $env_key.GetValue($var_name, "", [Microsoft.Win32.RegistryValueOptions]::DoNotExpandEnvironmentNames) +} + +Function Set-RawPathVar($path_value, $scope) { + If ($scope -eq "user") { + $var_path = "HKCU:\" + $user_path + } + ElseIf ($scope -eq "machine") { + $var_path = "HKLM:\" + $system_path + } + + Set-ItemProperty $var_path -Name $var_name -Value $path_value -Type ExpandString | Out-Null + + return $path_value +} + +$parsed_args = Parse-Args $args -supports_check_mode $true + +$result = @{changed=$false} + +$var_name = Get-AnsibleParam $parsed_args "name" -Default "PATH" +$elements = Get-AnsibleParam $parsed_args "elements" -FailIfEmpty $result +$state = Get-AnsibleParam $parsed_args "state" -Default "present" -ValidateSet "present","absent" +$scope = Get-AnsibleParam $parsed_args "scope" -Default "machine" -ValidateSet "machine","user" + +$check_mode = Get-AnsibleParam $parsed_args "_ansible_check_mode" -Default $false + +If ($elements -is [string]) { + $elements = @($elements) +} + +If ($elements -isnot [Array]) { + Fail-Json $result "elements must be a string or list of path strings" +} + +$current_value = Get-RawPathVar $scope +$result.path_value = $current_value + +# TODO: test case-canonicalization on wacky unicode values (eg turkish i) +# TODO: detect and warn/fail on unparseable path? (eg, unbalanced quotes, invalid path chars) +# TODO: detect and warn/fail if system path and Powershell isn't on it? + +$existing_elements = New-Object System.Collections.ArrayList + +# split on semicolons, accounting for quoted values with embedded semicolons (which may or may not be wrapped in whitespace) +$pathsplit_re = [regex] '((?<q>\s*"[^"]+"\s*)|(?<q>[^;]+))(;$|$|;)' + +ForEach ($m in $pathsplit_re.Matches($current_value)) { + $existing_elements.Add($m.Groups['q'].Value) | Out-Null +} + +If ($state -eq "absent") { + $result.changed = Remove-Elements $existing_elements $elements +} +ElseIf ($state -eq "present") { + $result.changed = Add-Elements $existing_elements $elements +} + +# calculate the new path value from the existing elements +$path_value = [String]::Join(";", $existing_elements.ToArray()) +$result.path_value = $path_value + +If ($result.changed -and -not $check_mode) { + Set-RawPathVar $path_value $scope | Out-Null +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_path.py b/test/support/windows-integration/plugins/modules/win_path.py new file mode 100644 index 00000000..6404504f --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_path.py @@ -0,0 +1,79 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Red Hat | Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# This is a windows documentation stub. Actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_path +version_added: "2.3" +short_description: Manage Windows path environment variables +description: + - Allows element-based ordering, addition, and removal of Windows path environment variables. +options: + name: + description: + - Target path environment variable name. + type: str + default: PATH + elements: + description: + - A single path element, or a list of path elements (ie, directories) to add or remove. + - When multiple elements are included in the list (and C(state) is C(present)), the elements are guaranteed to appear in the same relative order + in the resultant path value. + - Variable expansions (eg, C(%VARNAME%)) are allowed, and are stored unexpanded in the target path element. + - Any existing path elements not mentioned in C(elements) are always preserved in their current order. + - New path elements are appended to the path, and existing path elements may be moved closer to the end to satisfy the requested ordering. + - Paths are compared in a case-insensitive fashion, and trailing backslashes are ignored for comparison purposes. However, note that trailing + backslashes in YAML require quotes. + type: list + required: yes + state: + description: + - Whether the path elements specified in C(elements) should be present or absent. + type: str + choices: [ absent, present ] + scope: + description: + - The level at which the environment variable specified by C(name) should be managed (either for the current user or global machine scope). + type: str + choices: [ machine, user ] + default: machine +notes: + - This module is for modifying individual elements of path-like + environment variables. For general-purpose management of other + environment vars, use the M(win_environment) module. + - This module does not broadcast change events. + This means that the minority of windows applications which can have + their environment changed without restarting will not be notified and + therefore will need restarting to pick up new environment settings. + - User level environment variables will require an interactive user to + log out and in again before they become available. +seealso: +- module: win_environment +author: +- Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +- name: Ensure that system32 and Powershell are present on the global system path, and in the specified order + win_path: + elements: + - '%SystemRoot%\system32' + - '%SystemRoot%\system32\WindowsPowerShell\v1.0' + +- name: Ensure that C:\Program Files\MyJavaThing is not on the current user's CLASSPATH + win_path: + name: CLASSPATH + elements: C:\Program Files\MyJavaThing + scope: user + state: absent +''' diff --git a/test/support/windows-integration/plugins/modules/win_ping.ps1 b/test/support/windows-integration/plugins/modules/win_ping.ps1 new file mode 100644 index 00000000..c848b912 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_ping.ps1 @@ -0,0 +1,21 @@ +#!powershell + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic + +$spec = @{ + options = @{ + data = @{ type = "str"; default = "pong" } + } + supports_check_mode = $true +} +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) +$data = $module.Params.data + +if ($data -eq "crash") { + throw "boom" +} + +$module.Result.ping = $data +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_ping.py b/test/support/windows-integration/plugins/modules/win_ping.py new file mode 100644 index 00000000..6d35f379 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_ping.py @@ -0,0 +1,55 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2012, Michael DeHaan <michael.dehaan@gmail.com>, and others +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_ping +version_added: "1.7" +short_description: A windows version of the classic ping module +description: + - Checks management connectivity of a windows host. + - This is NOT ICMP ping, this is just a trivial test module. + - For non-Windows targets, use the M(ping) module instead. + - For Network targets, use the M(net_ping) module instead. +options: + data: + description: + - Alternate data to return instead of 'pong'. + - If this parameter is set to C(crash), the module will cause an exception. + type: str + default: pong +seealso: +- module: ping +author: +- Chris Church (@cchurch) +''' + +EXAMPLES = r''' +# Test connectivity to a windows host +# ansible winserver -m win_ping + +- name: Example from an Ansible Playbook + win_ping: + +- name: Induce an exception to see what happens + win_ping: + data: crash +''' + +RETURN = r''' +ping: + description: Value provided with the data parameter. + returned: success + type: str + sample: pong +''' diff --git a/test/support/windows-integration/plugins/modules/win_psexec.ps1 b/test/support/windows-integration/plugins/modules/win_psexec.ps1 new file mode 100644 index 00000000..04a51270 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_psexec.ps1 @@ -0,0 +1,152 @@ +#!powershell + +# Copyright: (c) 2017, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.ArgvParser +#Requires -Module Ansible.ModuleUtils.CommandUtil + +# See also: https://technet.microsoft.com/en-us/sysinternals/pxexec.aspx + +$spec = @{ + options = @{ + command = @{ type='str'; required=$true } + executable = @{ type='path'; default='psexec.exe' } + hostnames = @{ type='list' } + username = @{ type='str' } + password = @{ type='str'; no_log=$true } + chdir = @{ type='path' } + wait = @{ type='bool'; default=$true } + nobanner = @{ type='bool'; default=$false } + noprofile = @{ type='bool'; default=$false } + elevated = @{ type='bool'; default=$false } + limited = @{ type='bool'; default=$false } + system = @{ type='bool'; default=$false } + interactive = @{ type='bool'; default=$false } + session = @{ type='int' } + priority = @{ type='str'; choices=@( 'background', 'low', 'belownormal', 'abovenormal', 'high', 'realtime' ) } + timeout = @{ type='int' } + } +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$command = $module.Params.command +$executable = $module.Params.executable +$hostnames = $module.Params.hostnames +$username = $module.Params.username +$password = $module.Params.password +$chdir = $module.Params.chdir +$wait = $module.Params.wait +$nobanner = $module.Params.nobanner +$noprofile = $module.Params.noprofile +$elevated = $module.Params.elevated +$limited = $module.Params.limited +$system = $module.Params.system +$interactive = $module.Params.interactive +$session = $module.Params.session +$priority = $module.Params.Priority +$timeout = $module.Params.timeout + +$module.Result.changed = $true + +If (-Not (Get-Command $executable -ErrorAction SilentlyContinue)) { + $module.FailJson("Executable '$executable' was not found.") +} + +$arguments = [System.Collections.Generic.List`1[String]]@($executable) + +If ($nobanner -eq $true) { + $arguments.Add("-nobanner") +} + +# Support running on local system if no hostname is specified +If ($hostnames) { + $hostname_argument = ($hostnames | sort -Unique) -join ',' + $arguments.Add("\\$hostname_argument") +} + +# Username is optional +If ($null -ne $username) { + $arguments.Add("-u") + $arguments.Add($username) +} + +# Password is optional +If ($null -ne $password) { + $arguments.Add("-p") + $arguments.Add($password) +} + +If ($null -ne $chdir) { + $arguments.Add("-w") + $arguments.Add($chdir) +} + +If ($wait -eq $false) { + $arguments.Add("-d") +} + +If ($noprofile -eq $true) { + $arguments.Add("-e") +} + +If ($elevated -eq $true) { + $arguments.Add("-h") +} + +If ($system -eq $true) { + $arguments.Add("-s") +} + +If ($interactive -eq $true) { + $arguments.Add("-i") + If ($null -ne $session) { + $arguments.Add($session) + } +} + +If ($limited -eq $true) { + $arguments.Add("-l") +} + +If ($null -ne $priority) { + $arguments.Add("-$priority") +} + +If ($null -ne $timeout) { + $arguments.Add("-n") + $arguments.Add($timeout) +} + +$arguments.Add("-accepteula") + +$argument_string = Argv-ToString -arguments $arguments + +# Add the command at the end of the argument string, we don't want to escape +# that as psexec doesn't expect it to be one arg +$argument_string += " $command" + +$start_datetime = [DateTime]::UtcNow +$module.Result.psexec_command = $argument_string + +$command_result = Run-Command -command $argument_string + +$end_datetime = [DateTime]::UtcNow + +$module.Result.stdout = $command_result.stdout +$module.Result.stderr = $command_result.stderr + +If ($wait -eq $true) { + $module.Result.rc = $command_result.rc +} else { + $module.Result.rc = 0 + $module.Result.pid = $command_result.rc +} + +$module.Result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$module.Result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$module.Result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff") + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_psexec.py b/test/support/windows-integration/plugins/modules/win_psexec.py new file mode 100644 index 00000000..c3fc37e4 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_psexec.py @@ -0,0 +1,172 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: 2017, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_psexec +version_added: '2.3' +short_description: Runs commands (remotely) as another (privileged) user +description: +- Run commands (remotely) through the PsExec service. +- Run commands as another (domain) user (with elevated privileges). +requirements: +- Microsoft PsExec +options: + command: + description: + - The command line to run through PsExec (limited to 260 characters). + type: str + required: yes + executable: + description: + - The location of the PsExec utility (in case it is not located in your PATH). + type: path + default: psexec.exe + hostnames: + description: + - The hostnames to run the command. + - If not provided, the command is run locally. + type: list + username: + description: + - The (remote) user to run the command as. + - If not provided, the current user is used. + type: str + password: + description: + - The password for the (remote) user to run the command as. + - This is mandatory in order authenticate yourself. + type: str + chdir: + description: + - Run the command from this (remote) directory. + type: path + nobanner: + description: + - Do not display the startup banner and copyright message. + - This only works for specific versions of the PsExec binary. + type: bool + default: no + version_added: '2.4' + noprofile: + description: + - Run the command without loading the account's profile. + type: bool + default: no + elevated: + description: + - Run the command with elevated privileges. + type: bool + default: no + interactive: + description: + - Run the program so that it interacts with the desktop on the remote system. + type: bool + default: no + session: + description: + - Specifies the session ID to use. + - This parameter works in conjunction with I(interactive). + - It has no effect when I(interactive) is set to C(no). + type: int + version_added: '2.7' + limited: + description: + - Run the command as limited user (strips the Administrators group and allows only privileges assigned to the Users group). + type: bool + default: no + system: + description: + - Run the remote command in the System account. + type: bool + default: no + priority: + description: + - Used to run the command at a different priority. + choices: [ abovenormal, background, belownormal, high, low, realtime ] + timeout: + description: + - The connection timeout in seconds + type: int + wait: + description: + - Wait for the application to terminate. + - Only use for non-interactive applications. + type: bool + default: yes +notes: +- More information related to Microsoft PsExec is available from + U(https://technet.microsoft.com/en-us/sysinternals/bb897553.aspx) +seealso: +- module: psexec +- module: raw +- module: win_command +- module: win_shell +author: +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r''' +- name: Test the PsExec connection to the local system (target node) with your user + win_psexec: + command: whoami.exe + +- name: Run regedit.exe locally (on target node) as SYSTEM and interactively + win_psexec: + command: regedit.exe + interactive: yes + system: yes + +- name: Run the setup.exe installer on multiple servers using the Domain Administrator + win_psexec: + command: E:\setup.exe /i /IACCEPTEULA + hostnames: + - remote_server1 + - remote_server2 + username: DOMAIN\Administrator + password: some_password + priority: high + +- name: Run PsExec from custom location C:\Program Files\sysinternals\ + win_psexec: + command: netsh advfirewall set allprofiles state off + executable: C:\Program Files\sysinternals\psexec.exe + hostnames: [ remote_server ] + password: some_password + priority: low +''' + +RETURN = r''' +cmd: + description: The complete command line used by the module, including PsExec call and additional options. + returned: always + type: str + sample: psexec.exe -nobanner \\remote_server -u "DOMAIN\Administrator" -p "some_password" -accepteula E:\setup.exe +pid: + description: The PID of the async process created by PsExec. + returned: when C(wait=False) + type: int + sample: 1532 +rc: + description: The return code for the command. + returned: always + type: int + sample: 0 +stdout: + description: The standard output from the command. + returned: always + type: str + sample: Success. +stderr: + description: The error output from the command. + returned: always + type: str + sample: Error 15 running E:\setup.exe +''' diff --git a/test/support/windows-integration/plugins/modules/win_reboot.py b/test/support/windows-integration/plugins/modules/win_reboot.py new file mode 100644 index 00000000..14318041 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_reboot.py @@ -0,0 +1,131 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_reboot +short_description: Reboot a windows machine +description: +- Reboot a Windows machine, wait for it to go down, come back up, and respond to commands. +- For non-Windows targets, use the M(reboot) module instead. +version_added: '2.1' +options: + pre_reboot_delay: + description: + - Seconds to wait before reboot. Passed as a parameter to the reboot command. + type: int + default: 2 + aliases: [ pre_reboot_delay_sec ] + post_reboot_delay: + description: + - Seconds to wait after the reboot command was successful before attempting to validate the system rebooted successfully. + - This is useful if you want wait for something to settle despite your connection already working. + type: int + default: 0 + version_added: '2.4' + aliases: [ post_reboot_delay_sec ] + shutdown_timeout: + description: + - Maximum seconds to wait for shutdown to occur. + - Increase this timeout for very slow hardware, large update applications, etc. + - This option has been removed since Ansible 2.5 as the win_reboot behavior has changed. + type: int + default: 600 + aliases: [ shutdown_timeout_sec ] + reboot_timeout: + description: + - Maximum seconds to wait for machine to re-appear on the network and respond to a test command. + - This timeout is evaluated separately for both reboot verification and test command success so maximum clock time is actually twice this value. + type: int + default: 600 + aliases: [ reboot_timeout_sec ] + connect_timeout: + description: + - Maximum seconds to wait for a single successful TCP connection to the WinRM endpoint before trying again. + type: int + default: 5 + aliases: [ connect_timeout_sec ] + test_command: + description: + - Command to expect success for to determine the machine is ready for management. + type: str + default: whoami + msg: + description: + - Message to display to users. + type: str + default: Reboot initiated by Ansible + boot_time_command: + description: + - Command to run that returns a unique string indicating the last time the system was booted. + - Setting this to a command that has different output each time it is run will cause the task to fail. + type: str + default: '(Get-WmiObject -ClassName Win32_OperatingSystem).LastBootUpTime' + version_added: '2.10' +notes: +- If a shutdown was already scheduled on the system, C(win_reboot) will abort the scheduled shutdown and enforce its own shutdown. +- Beware that when C(win_reboot) returns, the Windows system may not have settled yet and some base services could be in limbo. + This can result in unexpected behavior. Check the examples for ways to mitigate this. +- The connection user must have the C(SeRemoteShutdownPrivilege) privilege enabled, see + U(https://docs.microsoft.com/en-us/windows/security/threat-protection/security-policy-settings/force-shutdown-from-a-remote-system) + for more information. +seealso: +- module: reboot +author: +- Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +- name: Reboot the machine with all defaults + win_reboot: + +- name: Reboot a slow machine that might have lots of updates to apply + win_reboot: + reboot_timeout: 3600 + +# Install a Windows feature and reboot if necessary +- name: Install IIS Web-Server + win_feature: + name: Web-Server + register: iis_install + +- name: Reboot when Web-Server feature requires it + win_reboot: + when: iis_install.reboot_required + +# One way to ensure the system is reliable, is to set WinRM to a delayed startup +- name: Ensure WinRM starts when the system has settled and is ready to work reliably + win_service: + name: WinRM + start_mode: delayed + + +# Additionally, you can add a delay before running the next task +- name: Reboot a machine that takes time to settle after being booted + win_reboot: + post_reboot_delay: 120 + +# Or you can make win_reboot validate exactly what you need to work before running the next task +- name: Validate that the netlogon service has started, before running the next task + win_reboot: + test_command: 'exit (Get-Service -Name Netlogon).Status -ne "Running"' +''' + +RETURN = r''' +rebooted: + description: True if the machine was rebooted. + returned: always + type: bool + sample: true +elapsed: + description: The number of seconds that elapsed waiting for the system to be rebooted. + returned: always + type: float + sample: 23.2 +''' diff --git a/test/support/windows-integration/plugins/modules/win_regedit.ps1 b/test/support/windows-integration/plugins/modules/win_regedit.ps1 new file mode 100644 index 00000000..c56b4833 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_regedit.ps1 @@ -0,0 +1,495 @@ +#!powershell + +# Copyright: (c) 2015, Adam Keech <akeech@chathamfinancial.com> +# Copyright: (c) 2015, Josh Ludwig <jludwig@chathamfinancial.com> +# Copyright: (c) 2017, Jordan Borean <jborean93@gmail.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.PrivilegeUtil + +$params = Parse-Args -arguments $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$path = Get-AnsibleParam -obj $params -name "path" -type "str" -failifempty $true -aliases "key" +$name = Get-AnsibleParam -obj $params -name "name" -type "str" -aliases "entry","value" +$data = Get-AnsibleParam -obj $params -name "data" +$type = Get-AnsibleParam -obj $params -name "type" -type "str" -default "string" -validateset "none","binary","dword","expandstring","multistring","string","qword" -aliases "datatype" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent" +$delete_key = Get-AnsibleParam -obj $params -name "delete_key" -type "bool" -default $true +$hive = Get-AnsibleParam -obj $params -name "hive" -type "path" + +$result = @{ + changed = $false + data_changed = $false + data_type_changed = $false +} + +if ($diff_mode) { + $result.diff = @{ + before = "" + after = "" + } +} + +$registry_util = @' +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; + +namespace Ansible.WinRegedit +{ + internal class NativeMethods + { + [DllImport("advapi32.dll", CharSet = CharSet.Unicode)] + public static extern int RegLoadKeyW( + UInt32 hKey, + string lpSubKey, + string lpFile); + + [DllImport("advapi32.dll", CharSet = CharSet.Unicode)] + public static extern int RegUnLoadKeyW( + UInt32 hKey, + string lpSubKey); + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + public class Hive : IDisposable + { + private const UInt32 SCOPE = 0x80000002; // HKLM + private string hiveKey; + private bool loaded = false; + + public Hive(string hiveKey, string hivePath) + { + this.hiveKey = hiveKey; + int ret = NativeMethods.RegLoadKeyW(SCOPE, hiveKey, hivePath); + if (ret != 0) + throw new Win32Exception(ret, String.Format("Failed to load registry hive at {0}", hivePath)); + loaded = true; + } + + public static void UnloadHive(string hiveKey) + { + int ret = NativeMethods.RegUnLoadKeyW(SCOPE, hiveKey); + if (ret != 0) + throw new Win32Exception(ret, String.Format("Failed to unload registry hive at {0}", hiveKey)); + } + + public void Dispose() + { + if (loaded) + { + // Make sure the garbage collector disposes all unused handles and waits until it is complete + GC.Collect(); + GC.WaitForPendingFinalizers(); + + UnloadHive(hiveKey); + loaded = false; + } + GC.SuppressFinalize(this); + } + ~Hive() { this.Dispose(); } + } +} +'@ + +# fire a warning if the property name isn't specified, the (Default) key ($null) can only be a string +if ($null -eq $name -and $type -ne "string") { + Add-Warning -obj $result -message "the data type when name is not specified can only be 'string', the type has automatically been converted" + $type = "string" +} + +# Check that the registry path is in PSDrive format: HKCC, HKCR, HKCU, HKLM, HKU +if ($path -notmatch "^HK(CC|CR|CU|LM|U):\\") { + Fail-Json $result "path: $path is not a valid powershell path, see module documentation for examples." +} + +# Add a warning if the path does not contains a \ and is not the leaf path +$registry_path = (Split-Path -Path $path -NoQualifier).Substring(1) # removes the hive: and leading \ +$registry_leaf = Split-Path -Path $path -Leaf +if ($registry_path -ne $registry_leaf -and -not $registry_path.Contains('\')) { + $msg = "path is not using '\' as a separator, support for '/' as a separator will be removed in a future Ansible version" + Add-DeprecationWarning -obj $result -message $msg -version 2.12 + $registry_path = $registry_path.Replace('/', '\') +} + +# Simplified version of Convert-HexStringToByteArray from +# https://cyber-defense.sans.org/blog/2010/02/11/powershell-byte-array-hex-convert +# Expects a hex in the format you get when you run reg.exe export, +# and converts to a byte array so powershell can modify binary registry entries +# import format is like 'hex:be,ef,be,ef,be,ef,be,ef,be,ef' +Function Convert-RegExportHexStringToByteArray($string) { + # Remove 'hex:' from the front of the string if present + $string = $string.ToLower() -replace '^hex\:','' + + # Remove whitespace and any other non-hex crud. + $string = $string -replace '[^a-f0-9\\,x\-\:]','' + + # Turn commas into colons + $string = $string -replace ',',':' + + # Maybe there's nothing left over to convert... + if ($string.Length -eq 0) { + return ,@() + } + + # Split string with or without colon delimiters. + if ($string.Length -eq 1) { + return ,@([System.Convert]::ToByte($string,16)) + } elseif (($string.Length % 2 -eq 0) -and ($string.IndexOf(":") -eq -1)) { + return ,@($string -split '([a-f0-9]{2})' | foreach-object { if ($_) {[System.Convert]::ToByte($_,16)}}) + } elseif ($string.IndexOf(":") -ne -1) { + return ,@($string -split ':+' | foreach-object {[System.Convert]::ToByte($_,16)}) + } else { + return ,@() + } +} + +Function Compare-RegistryProperties($existing, $new) { + # Outputs $true if the property values don't match + if ($existing -is [Array]) { + (Compare-Object -ReferenceObject $existing -DifferenceObject $new -SyncWindow 0).Length -ne 0 + } else { + $existing -cne $new + } +} + +Function Get-DiffValue { + param( + [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryValueKind]$Type, + [Parameter(Mandatory=$true)][Object]$Value + ) + + $diff = @{ type = $Type.ToString(); value = $Value } + + $enum = [Microsoft.Win32.RegistryValueKind] + if ($Type -in @($enum::Binary, $enum::None)) { + $diff.value = [System.Collections.Generic.List`1[String]]@() + foreach ($dec_value in $Value) { + $diff.value.Add("0x{0:x2}" -f $dec_value) + } + } elseif ($Type -eq $enum::DWord) { + $diff.value = "0x{0:x8}" -f $Value + } elseif ($Type -eq $enum::QWord) { + $diff.value = "0x{0:x16}" -f $Value + } + + return $diff +} + +Function Set-StateAbsent { + param( + # Used for diffs and exception messages to match up against Ansible input + [Parameter(Mandatory=$true)][String]$PrintPath, + [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryKey]$Hive, + [Parameter(Mandatory=$true)][String]$Path, + [String]$Name, + [Switch]$DeleteKey + ) + + $key = $Hive.OpenSubKey($Path, $true) + if ($null -eq $key) { + # Key does not exist, no need to delete anything + return + } + + try { + if ($DeleteKey -and -not $Name) { + # delete_key=yes is set and name is null/empty, so delete the entire key + $key.Dispose() + $key = $null + if (-not $check_mode) { + try { + $Hive.DeleteSubKeyTree($Path, $false) + } catch { + Fail-Json -obj $result -message "failed to delete registry key at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.before = @{$PrintPath = @{}} + $result.diff.after = @{} + } + } else { + # delete_key=no or name is not null/empty, delete the property not the full key + $property = $key.GetValue($Name) + if ($null -eq $property) { + # property does not exist + return + } + $property_type = $key.GetValueKind($Name) # used for the diff + + if (-not $check_mode) { + try { + $key.DeleteValue($Name) + } catch { + Fail-Json -obj $result -message "failed to delete registry property '$Name' at $($PrintPath): $($_.Exception.Message)" + } + } + + $result.changed = $true + if ($diff_mode) { + $diff_value = Get-DiffValue -Type $property_type -Value $property + $result.diff.before = @{ $PrintPath = @{ $Name = $diff_value } } + $result.diff.after = @{ $PrintPath = @{} } + } + } + } finally { + if ($key) { + $key.Dispose() + } + } +} + +Function Set-StatePresent { + param( + [Parameter(Mandatory=$true)][String]$PrintPath, + [Parameter(Mandatory=$true)][Microsoft.Win32.RegistryKey]$Hive, + [Parameter(Mandatory=$true)][String]$Path, + [String]$Name, + [Object]$Data, + [Microsoft.Win32.RegistryValueKind]$Type + ) + + $key = $Hive.OpenSubKey($Path, $true) + try { + if ($null -eq $key) { + # the key does not exist, create it so the next steps work + if (-not $check_mode) { + try { + $key = $Hive.CreateSubKey($Path) + } catch { + Fail-Json -obj $result -message "failed to create registry key at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.before = @{} + $result.diff.after = @{$PrintPath = @{}} + } + } elseif ($diff_mode) { + # Make sure the diff is in an expected state for the key + $result.diff.before = @{$PrintPath = @{}} + $result.diff.after = @{$PrintPath = @{}} + } + + if ($null -eq $key -or $null -eq $Data) { + # Check mode and key was created above, we cannot do any more work, or $Data is $null which happens when + # we create a new key but haven't explicitly set the data + return + } + + $property = $key.GetValue($Name, $null, [Microsoft.Win32.RegistryValueOptions]::DoNotExpandEnvironmentNames) + if ($null -ne $property) { + # property exists, need to compare the values and type + $existing_type = $key.GetValueKind($name) + $change_value = $false + + if ($Type -ne $existing_type) { + $change_value = $true + $result.data_type_changed = $true + $data_mismatch = Compare-RegistryProperties -existing $property -new $Data + if ($data_mismatch) { + $result.data_changed = $true + } + } else { + $data_mismatch = Compare-RegistryProperties -existing $property -new $Data + if ($data_mismatch) { + $change_value = $true + $result.data_changed = $true + } + } + + if ($change_value) { + if (-not $check_mode) { + try { + $key.SetValue($Name, $Data, $Type) + } catch { + Fail-Json -obj $result -message "failed to change registry property '$Name' at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.before.$PrintPath.$Name = Get-DiffValue -Type $existing_type -Value $property + $result.diff.after.$PrintPath.$Name = Get-DiffValue -Type $Type -Value $Data + } + } elseif ($diff_mode) { + $diff_value = Get-DiffValue -Type $existing_type -Value $property + $result.diff.before.$PrintPath.$Name = $diff_value + $result.diff.after.$PrintPath.$Name = $diff_value + } + } else { + # property doesn't exist just create a new one + if (-not $check_mode) { + try { + $key.SetValue($Name, $Data, $Type) + } catch { + Fail-Json -obj $result -message "failed to create registry property '$Name' at $($PrintPath): $($_.Exception.Message)" + } + } + $result.changed = $true + + if ($diff_mode) { + $result.diff.after.$PrintPath.$Name = Get-DiffValue -Type $Type -Value $Data + } + } + } finally { + if ($key) { + $key.Dispose() + } + } +} + +# convert property names "" to $null as "" refers to (Default) +if ($name -eq "") { + $name = $null +} + +# convert the data to the required format +if ($type -in @("binary", "none")) { + if ($null -eq $data) { + $data = "" + } + + # convert the data from string to byte array if in hex: format + if ($data -is [String]) { + $data = [byte[]](Convert-RegExportHexStringToByteArray -string $data) + } elseif ($data -is [Int]) { + if ($data -gt 255) { + Fail-Json $result "cannot convert binary data '$data' to byte array, please specify this value as a yaml byte array or a comma separated hex value string" + } + $data = [byte[]]@([byte]$data) + } elseif ($data -is [Array]) { + $data = [byte[]]$data + } +} elseif ($type -in @("dword", "qword")) { + # dword's and dword's don't allow null values, set to 0 + if ($null -eq $data) { + $data = 0 + } + + if ($data -is [String]) { + # if the data is a string we need to convert it to an unsigned int64 + # it needs to be unsigned as Ansible passes in an unsigned value while + # powershell uses a signed data type. The value will then be converted + # below + $data = [UInt64]$data + } + + if ($type -eq "dword") { + if ($data -gt [UInt32]::MaxValue) { + Fail-Json $result "data cannot be larger than 0xffffffff when type is dword" + } elseif ($data -gt [Int32]::MaxValue) { + # when dealing with larger int32 (> 2147483647 or 0x7FFFFFFF) powershell + # automatically converts it to a signed int64. We need to convert this to + # signed int32 by parsing the hex string value. + $data = "0x$("{0:x}" -f $data)" + } + $data = [Int32]$data + } else { + if ($data -gt [UInt64]::MaxValue) { + Fail-Json $result "data cannot be larger than 0xffffffffffffffff when type is qword" + } elseif ($data -gt [Int64]::MaxValue) { + $data = "0x$("{0:x}" -f $data)" + } + $data = [Int64]$data + } +} elseif ($type -in @("string", "expandstring") -and $name) { + # a null string or expandstring must be empty quotes + # Only do this if $name has been defined (not the default key) + if ($null -eq $data) { + $data = "" + } +} elseif ($type -eq "multistring") { + # convert the data for a multistring to a String[] array + if ($null -eq $data) { + $data = [String[]]@() + } elseif ($data -isnot [Array]) { + $new_data = New-Object -TypeName String[] -ArgumentList 1 + $new_data[0] = $data.ToString([CultureInfo]::InvariantCulture) + $data = $new_data + } else { + $new_data = New-Object -TypeName String[] -ArgumentList $data.Count + foreach ($entry in $data) { + $new_data[$data.IndexOf($entry)] = $entry.ToString([CultureInfo]::InvariantCulture) + } + $data = $new_data + } +} + +# convert the type string to the .NET class +$type = [System.Enum]::Parse([Microsoft.Win32.RegistryValueKind], $type, $true) + +$registry_hive = switch(Split-Path -Path $path -Qualifier) { + "HKCR:" { [Microsoft.Win32.Registry]::ClassesRoot } + "HKCC:" { [Microsoft.Win32.Registry]::CurrentConfig } + "HKCU:" { [Microsoft.Win32.Registry]::CurrentUser } + "HKLM:" { [Microsoft.Win32.Registry]::LocalMachine } + "HKU:" { [Microsoft.Win32.Registry]::Users } +} +$loaded_hive = $null +try { + if ($hive) { + if (-not (Test-Path -LiteralPath $hive)) { + Fail-Json -obj $result -message "hive at path '$hive' is not valid or accessible, cannot load hive" + } + + $original_tmp = $env:TMP + $env:TMP = $_remote_tmp + Add-Type -TypeDefinition $registry_util + $env:TMP = $original_tmp + + try { + Set-AnsiblePrivilege -Name SeBackupPrivilege -Value $true + Set-AnsiblePrivilege -Name SeRestorePrivilege -Value $true + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "failed to enable SeBackupPrivilege and SeRestorePrivilege for the current process: $($_.Exception.Message)" + } + + if (Test-Path -Path HKLM:\ANSIBLE) { + Add-Warning -obj $result -message "hive already loaded at HKLM:\ANSIBLE, had to unload hive for win_regedit to continue" + try { + [Ansible.WinRegedit.Hive]::UnloadHive("ANSIBLE") + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "failed to unload registry hive HKLM:\ANSIBLE from $($hive): $($_.Exception.Message)" + } + } + + try { + $loaded_hive = New-Object -TypeName Ansible.WinRegedit.Hive -ArgumentList "ANSIBLE", $hive + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "failed to load registry hive from '$hive' to HKLM:\ANSIBLE: $($_.Exception.Message)" + } + } + + if ($state -eq "present") { + Set-StatePresent -PrintPath $path -Hive $registry_hive -Path $registry_path -Name $name -Data $data -Type $type + } else { + Set-StateAbsent -PrintPath $path -Hive $registry_hive -Path $registry_path -Name $name -DeleteKey:$delete_key + } +} finally { + $registry_hive.Dispose() + if ($loaded_hive) { + $loaded_hive.Dispose() + } +} + +Exit-Json $result + diff --git a/test/support/windows-integration/plugins/modules/win_regedit.py b/test/support/windows-integration/plugins/modules/win_regedit.py new file mode 100644 index 00000000..2c0fff71 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_regedit.py @@ -0,0 +1,210 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Adam Keech <akeech@chathamfinancial.com> +# Copyright: (c) 2015, Josh Ludwig <jludwig@chathamfinancial.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + + +DOCUMENTATION = r''' +--- +module: win_regedit +version_added: '2.0' +short_description: Add, change, or remove registry keys and values +description: +- Add, modify or remove registry keys and values. +- More information about the windows registry from Wikipedia + U(https://en.wikipedia.org/wiki/Windows_Registry). +options: + path: + description: + - Name of the registry path. + - 'Should be in one of the following registry hives: HKCC, HKCR, HKCU, + HKLM, HKU.' + type: str + required: yes + aliases: [ key ] + name: + description: + - Name of the registry entry in the above C(path) parameters. + - If not provided, or empty then the '(Default)' property for the key will + be used. + type: str + aliases: [ entry, value ] + data: + description: + - Value of the registry entry C(name) in C(path). + - If not specified then the value for the property will be null for the + corresponding C(type). + - Binary and None data should be expressed in a yaml byte array or as comma + separated hex values. + - An easy way to generate this is to run C(regedit.exe) and use the + I(export) option to save the registry values to a file. + - In the exported file, binary value will look like C(hex:be,ef,be,ef), the + C(hex:) prefix is optional. + - DWORD and QWORD values should either be represented as a decimal number + or a hex value. + - Multistring values should be passed in as a list. + - See the examples for more details on how to format this data. + type: str + type: + description: + - The registry value data type. + type: str + choices: [ binary, dword, expandstring, multistring, string, qword ] + default: string + aliases: [ datatype ] + state: + description: + - The state of the registry entry. + type: str + choices: [ absent, present ] + default: present + delete_key: + description: + - When C(state) is 'absent' then this will delete the entire key. + - If C(no) then it will only clear out the '(Default)' property for + that key. + type: bool + default: yes + version_added: '2.4' + hive: + description: + - A path to a hive key like C:\Users\Default\NTUSER.DAT to load in the + registry. + - This hive is loaded under the HKLM:\ANSIBLE key which can then be used + in I(name) like any other path. + - This can be used to load the default user profile registry hive or any + other hive saved as a file. + - Using this function requires the user to have the C(SeRestorePrivilege) + and C(SeBackupPrivilege) privileges enabled. + type: path + version_added: '2.5' +notes: +- Check-mode C(-C/--check) and diff output C(-D/--diff) are supported, so that you can test every change against the active configuration before + applying changes. +- Beware that some registry hives (C(HKEY_USERS) in particular) do not allow to create new registry paths in the root folder. +- Since ansible 2.4, when checking if a string registry value has changed, a case-sensitive test is used. Previously the test was case-insensitive. +seealso: +- module: win_reg_stat +- module: win_regmerge +author: +- Adam Keech (@smadam813) +- Josh Ludwig (@joshludwig) +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Create registry path MyCompany + win_regedit: + path: HKCU:\Software\MyCompany + +- name: Add or update registry path MyCompany, with entry 'hello', and containing 'world' + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: world + +- name: Add or update registry path MyCompany, with dword entry 'hello', and containing 1337 as the decimal value + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: 1337 + type: dword + +- name: Add or update registry path MyCompany, with dword entry 'hello', and containing 0xff2500ae as the hex value + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: 0xff2500ae + type: dword + +- name: Add or update registry path MyCompany, with binary entry 'hello', and containing binary data in hex-string format + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: hex:be,ef,be,ef,be,ef,be,ef,be,ef + type: binary + +- name: Add or update registry path MyCompany, with binary entry 'hello', and containing binary data in yaml format + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: [0xbe,0xef,0xbe,0xef,0xbe,0xef,0xbe,0xef,0xbe,0xef] + type: binary + +- name: Add or update registry path MyCompany, with expand string entry 'hello' + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: '%appdata%\local' + type: expandstring + +- name: Add or update registry path MyCompany, with multi string entry 'hello' + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + data: ['hello', 'world'] + type: multistring + +- name: Disable keyboard layout hotkey for all users (changes existing) + win_regedit: + path: HKU:\.DEFAULT\Keyboard Layout\Toggle + name: Layout Hotkey + data: 3 + type: dword + +- name: Disable language hotkey for current users (adds new) + win_regedit: + path: HKCU:\Keyboard Layout\Toggle + name: Language Hotkey + data: 3 + type: dword + +- name: Remove registry path MyCompany (including all entries it contains) + win_regedit: + path: HKCU:\Software\MyCompany + state: absent + delete_key: yes + +- name: Clear the existing (Default) entry at path MyCompany + win_regedit: + path: HKCU:\Software\MyCompany + state: absent + delete_key: no + +- name: Remove entry 'hello' from registry path MyCompany + win_regedit: + path: HKCU:\Software\MyCompany + name: hello + state: absent + +- name: Change default mouse trailing settings for new users + win_regedit: + path: HKLM:\ANSIBLE\Control Panel\Mouse + name: MouseTrails + data: 10 + type: str + state: present + hive: C:\Users\Default\NTUSER.dat +''' + +RETURN = r''' +data_changed: + description: Whether this invocation changed the data in the registry value. + returned: success + type: bool + sample: false +data_type_changed: + description: Whether this invocation changed the datatype of the registry value. + returned: success + type: bool + sample: true +''' diff --git a/test/support/windows-integration/plugins/modules/win_security_policy.ps1 b/test/support/windows-integration/plugins/modules/win_security_policy.ps1 new file mode 100644 index 00000000..274204b6 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_security_policy.ps1 @@ -0,0 +1,196 @@ +#!powershell + +# Copyright: (c) 2017, Jordan Borean <jborean93@gmail.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy + +$ErrorActionPreference = 'Stop' + +$params = Parse-Args $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $Params -name "_ansible_diff" -type "bool" -default $false + +$section = Get-AnsibleParam -obj $params -name "section" -type "str" -failifempty $true +$key = Get-AnsibleParam -obj $params -name "key" -type "str" -failifempty $true +$value = Get-AnsibleParam -obj $params -name "value" -failifempty $true + +$result = @{ + changed = $false + section = $section + key = $key + value = $value +} + +if ($diff_mode) { + $result.diff = @{} +} + +Function Run-SecEdit($arguments) { + $stdout = $null + $stderr = $null + $log_path = [IO.Path]::GetTempFileName() + $arguments = $arguments + @("/log", $log_path) + + try { + $stdout = &SecEdit.exe $arguments | Out-String + } catch { + $stderr = $_.Exception.Message + } + $log = Get-Content -Path $log_path + Remove-Item -Path $log_path -Force + + $return = @{ + log = ($log -join "`n").Trim() + stdout = $stdout + stderr = $stderr + rc = $LASTEXITCODE + } + + return $return +} + +Function Export-SecEdit() { + $secedit_ini_path = [IO.Path]::GetTempFileName() + # while this will technically make a change to the system in check mode by + # creating a new file, we need these values to be able to do anything + # substantial in check mode + $export_result = Run-SecEdit -arguments @("/export", "/cfg", $secedit_ini_path, "/quiet") + + # check the return code and if the file has been populated, otherwise error out + if (($export_result.rc -ne 0) -or ((Get-Item -Path $secedit_ini_path).Length -eq 0)) { + Remove-Item -Path $secedit_ini_path -Force + $result.rc = $export_result.rc + $result.stdout = $export_result.stdout + $result.stderr = $export_result.stderr + Fail-Json $result "Failed to export secedit.ini file to $($secedit_ini_path)" + } + $secedit_ini = ConvertFrom-Ini -file_path $secedit_ini_path + + return $secedit_ini +} + +Function Import-SecEdit($ini) { + $secedit_ini_path = [IO.Path]::GetTempFileName() + $secedit_db_path = [IO.Path]::GetTempFileName() + Remove-Item -Path $secedit_db_path -Force # needs to be deleted for SecEdit.exe /import to work + + $ini_contents = ConvertTo-Ini -ini $ini + Set-Content -Path $secedit_ini_path -Value $ini_contents + $result.changed = $true + + $import_result = Run-SecEdit -arguments @("/configure", "/db", $secedit_db_path, "/cfg", $secedit_ini_path, "/quiet") + $result.import_log = $import_result.log + Remove-Item -Path $secedit_ini_path -Force + if ($import_result.rc -ne 0) { + $result.rc = $import_result.rc + $result.stdout = $import_result.stdout + $result.stderr = $import_result.stderr + Fail-Json $result "Failed to import secedit.ini file from $($secedit_ini_path)" + } +} + +Function ConvertTo-Ini($ini) { + $content = @() + foreach ($key in $ini.GetEnumerator()) { + $section = $key.Name + $values = $key.Value + + $content += "[$section]" + foreach ($value in $values.GetEnumerator()) { + $value_key = $value.Name + $value_value = $value.Value + + if ($null -ne $value_value) { + $content += "$value_key = $value_value" + } + } + } + + return $content -join "`r`n" +} + +Function ConvertFrom-Ini($file_path) { + $ini = @{} + switch -Regex -File $file_path { + "^\[(.+)\]" { + $section = $matches[1] + $ini.$section = @{} + } + "(.+?)\s*=(.*)" { + $name = $matches[1].Trim() + $value = $matches[2].Trim() + if ($value -match "^\d+$") { + $value = [int]$value + } elseif ($value.StartsWith('"') -and $value.EndsWith('"')) { + $value = $value.Substring(1, $value.Length - 2) + } + + $ini.$section.$name = $value + } + } + + return $ini +} + +if ($section -eq "Privilege Rights") { + Add-Warning -obj $result -message "Using this module to edit rights and privileges is error-prone, use the win_user_right module instead" +} + +$will_change = $false +$secedit_ini = Export-SecEdit +if (-not ($secedit_ini.ContainsKey($section))) { + Fail-Json $result "The section '$section' does not exist in SecEdit.exe output ini" +} + +if ($secedit_ini.$section.ContainsKey($key)) { + $current_value = $secedit_ini.$section.$key + + if ($current_value -cne $value) { + if ($diff_mode) { + $result.diff.prepared = @" +[$section] +-$key = $current_value ++$key = $value +"@ + } + + $secedit_ini.$section.$key = $value + $will_change = $true + } +} elseif ([string]$value -eq "") { + # Value is requested to be removed, and has already been removed, do nothing +} else { + if ($diff_mode) { + $result.diff.prepared = @" +[$section] ++$key = $value +"@ + } + $secedit_ini.$section.$key = $value + $will_change = $true +} + +if ($will_change -eq $true) { + $result.changed = $true + if (-not $check_mode) { + Import-SecEdit -ini $secedit_ini + + # secedit doesn't error out on improper entries, re-export and verify + # the changes occurred + $verification_ini = Export-SecEdit + $new_section_values = $verification_ini.$section + if ($new_section_values.ContainsKey($key)) { + $new_value = $new_section_values.$key + if ($new_value -cne $value) { + Fail-Json $result "Failed to change the value for key '$key' in section '$section', the value is still $new_value" + } + } elseif ([string]$value -eq "") { + # Value was empty, so OK if no longer in the result + } else { + Fail-Json $result "The key '$key' in section '$section' is not a valid key, cannot set this value" + } + } +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_security_policy.py b/test/support/windows-integration/plugins/modules/win_security_policy.py new file mode 100644 index 00000000..d582a532 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_security_policy.py @@ -0,0 +1,126 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub, actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_security_policy +version_added: '2.4' +short_description: Change local security policy settings +description: +- Allows you to set the local security policies that are configured by + SecEdit.exe. +options: + section: + description: + - The ini section the key exists in. + - If the section does not exist then the module will return an error. + - Example sections to use are 'Account Policies', 'Local Policies', + 'Event Log', 'Restricted Groups', 'System Services', 'Registry' and + 'File System' + - If wanting to edit the C(Privilege Rights) section, use the + M(win_user_right) module instead. + type: str + required: yes + key: + description: + - The ini key of the section or policy name to modify. + - The module will return an error if this key is invalid. + type: str + required: yes + value: + description: + - The value for the ini key or policy name. + - If the key takes in a boolean value then 0 = False and 1 = True. + type: str + required: yes +notes: +- This module uses the SecEdit.exe tool to configure the values, more details + of the areas and keys that can be configured can be found here + U(https://msdn.microsoft.com/en-us/library/bb742512.aspx). +- If you are in a domain environment these policies may be set by a GPO policy, + this module can temporarily change these values but the GPO will override + it if the value differs. +- You can also run C(SecEdit.exe /export /cfg C:\temp\output.ini) to view the + current policies set on your system. +- When assigning user rights, use the M(win_user_right) module instead. +seealso: +- module: win_user_right +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Change the guest account name + win_security_policy: + section: System Access + key: NewGuestName + value: Guest Account + +- name: Set the maximum password age + win_security_policy: + section: System Access + key: MaximumPasswordAge + value: 15 + +- name: Do not store passwords using reversible encryption + win_security_policy: + section: System Access + key: ClearTextPassword + value: 0 + +- name: Enable system events + win_security_policy: + section: Event Audit + key: AuditSystemEvents + value: 1 +''' + +RETURN = r''' +rc: + description: The return code after a failure when running SecEdit.exe. + returned: failure with secedit calls + type: int + sample: -1 +stdout: + description: The output of the STDOUT buffer after a failure when running + SecEdit.exe. + returned: failure with secedit calls + type: str + sample: check log for error details +stderr: + description: The output of the STDERR buffer after a failure when running + SecEdit.exe. + returned: failure with secedit calls + type: str + sample: failed to import security policy +import_log: + description: The log of the SecEdit.exe /configure job that configured the + local policies. This is used for debugging purposes on failures. + returned: secedit.exe /import run and change occurred + type: str + sample: Completed 6 percent (0/15) \tProcess Privilege Rights area. +key: + description: The key in the section passed to the module to modify. + returned: success + type: str + sample: NewGuestName +section: + description: The section passed to the module to modify. + returned: success + type: str + sample: System Access +value: + description: The value passed to the module to modify to. + returned: success + type: str + sample: Guest Account +''' diff --git a/test/support/windows-integration/plugins/modules/win_shell.ps1 b/test/support/windows-integration/plugins/modules/win_shell.ps1 new file mode 100644 index 00000000..54aef8de --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_shell.ps1 @@ -0,0 +1,138 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.CommandUtil +#Requires -Module Ansible.ModuleUtils.FileUtil + +# TODO: add check mode support + +Set-StrictMode -Version 2 +$ErrorActionPreference = "Stop" + +# Cleanse CLIXML from stderr (sift out error stream data, discard others for now) +Function Cleanse-Stderr($raw_stderr) { + Try { + # NB: this regex isn't perfect, but is decent at finding CLIXML amongst other stderr noise + If($raw_stderr -match "(?s)(?<prenoise1>.*)#< CLIXML(?<prenoise2>.*)(?<clixml><Objs.+</Objs>)(?<postnoise>.*)") { + $clixml = [xml]$matches["clixml"] + + $merged_stderr = "{0}{1}{2}{3}" -f @( + $matches["prenoise1"], + $matches["prenoise2"], + # filter out just the Error-tagged strings for now, and zap embedded CRLF chars + ($clixml.Objs.ChildNodes | Where-Object { $_.Name -eq 'S' } | Where-Object { $_.S -eq 'Error' } | ForEach-Object { $_.'#text'.Replace('_x000D__x000A_','') } | Out-String), + $matches["postnoise"]) | Out-String + + return $merged_stderr.Trim() + + # FUTURE: parse/return other streams + } + Else { + $raw_stderr + } + } + Catch { + "***EXCEPTION PARSING CLIXML: $_***" + $raw_stderr + } +} + +$params = Parse-Args $args -supports_check_mode $false + +$raw_command_line = Get-AnsibleParam -obj $params -name "_raw_params" -type "str" -failifempty $true +$chdir = Get-AnsibleParam -obj $params -name "chdir" -type "path" +$executable = Get-AnsibleParam -obj $params -name "executable" -type "path" +$creates = Get-AnsibleParam -obj $params -name "creates" -type "path" +$removes = Get-AnsibleParam -obj $params -name "removes" -type "path" +$stdin = Get-AnsibleParam -obj $params -name "stdin" -type "str" +$no_profile = Get-AnsibleParam -obj $params -name "no_profile" -type "bool" -default $false +$output_encoding_override = Get-AnsibleParam -obj $params -name "output_encoding_override" -type "str" + +$raw_command_line = $raw_command_line.Trim() + +$result = @{ + changed = $true + cmd = $raw_command_line +} + +if ($creates -and $(Test-AnsiblePath -Path $creates)) { + Exit-Json @{msg="skipped, since $creates exists";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +if ($removes -and -not $(Test-AnsiblePath -Path $removes)) { + Exit-Json @{msg="skipped, since $removes does not exist";cmd=$raw_command_line;changed=$false;skipped=$true;rc=0} +} + +$exec_args = $null +If(-not $executable -or $executable -eq "powershell") { + $exec_application = "powershell.exe" + + # force input encoding to preamble-free UTF8 so PS sub-processes (eg, Start-Job) don't blow up + $raw_command_line = "[Console]::InputEncoding = New-Object Text.UTF8Encoding `$false; " + $raw_command_line + + # Base64 encode the command so we don't have to worry about the various levels of escaping + $encoded_command = [Convert]::ToBase64String([System.Text.Encoding]::Unicode.GetBytes($raw_command_line)) + + if ($stdin) { + $exec_args = "-encodedcommand $encoded_command" + } else { + $exec_args = "-noninteractive -encodedcommand $encoded_command" + } + + if ($no_profile) { + $exec_args = "-noprofile $exec_args" + } +} +Else { + # FUTURE: support arg translation from executable (or executable_args?) to process arguments for arbitrary interpreter? + $exec_application = $executable + if (-not ($exec_application.EndsWith(".exe"))) { + $exec_application = "$($exec_application).exe" + } + $exec_args = "/c $raw_command_line" +} + +$command = "`"$exec_application`" $exec_args" +$run_command_arg = @{ + command = $command +} +if ($chdir) { + $run_command_arg['working_directory'] = $chdir +} +if ($stdin) { + $run_command_arg['stdin'] = $stdin +} +if ($output_encoding_override) { + $run_command_arg['output_encoding_override'] = $output_encoding_override +} + +$start_datetime = [DateTime]::UtcNow +try { + $command_result = Run-Command @run_command_arg +} catch { + $result.changed = $false + try { + $result.rc = $_.Exception.NativeErrorCode + } catch { + $result.rc = 2 + } + Fail-Json -obj $result -message $_.Exception.Message +} + +# TODO: decode CLIXML stderr output (and other streams?) +$result.stdout = $command_result.stdout +$result.stderr = Cleanse-Stderr $command_result.stderr +$result.rc = $command_result.rc + +$end_datetime = [DateTime]::UtcNow +$result.start = $start_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.end = $end_datetime.ToString("yyyy-MM-dd hh:mm:ss.ffffff") +$result.delta = $($end_datetime - $start_datetime).ToString("h\:mm\:ss\.ffffff") + +If ($result.rc -ne 0) { + Fail-Json -obj $result -message "non-zero return code" +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_shell.py b/test/support/windows-integration/plugins/modules/win_shell.py new file mode 100644 index 00000000..ee2cd762 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_shell.py @@ -0,0 +1,167 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Ansible, inc +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_shell +short_description: Execute shell commands on target hosts +version_added: 2.2 +description: + - The C(win_shell) module takes the command name followed by a list of space-delimited arguments. + It is similar to the M(win_command) module, but runs + the command via a shell (defaults to PowerShell) on the target host. + - For non-Windows targets, use the M(shell) module instead. +options: + free_form: + description: + - The C(win_shell) module takes a free form command to run. + - There is no parameter actually named 'free form'. See the examples! + type: str + required: yes + creates: + description: + - A path or path filter pattern; when the referenced path exists on the target host, the task will be skipped. + type: path + removes: + description: + - A path or path filter pattern; when the referenced path B(does not) exist on the target host, the task will be skipped. + type: path + chdir: + description: + - Set the specified path as the current working directory before executing a command + type: path + executable: + description: + - Change the shell used to execute the command (eg, C(cmd)). + - The target shell must accept a C(/c) parameter followed by the raw command line to be executed. + type: path + stdin: + description: + - Set the stdin of the command directly to the specified value. + type: str + version_added: '2.5' + no_profile: + description: + - Do not load the user profile before running a command. This is only valid + when using PowerShell as the executable. + type: bool + default: no + version_added: '2.8' + output_encoding_override: + description: + - This option overrides the encoding of stdout/stderr output. + - You can use this option when you need to run a command which ignore the console's codepage. + - You should only need to use this option in very rare circumstances. + - This value can be any valid encoding C(Name) based on the output of C([System.Text.Encoding]::GetEncodings()). + See U(https://docs.microsoft.com/dotnet/api/system.text.encoding.getencodings). + type: str + version_added: '2.10' +notes: + - If you want to run an executable securely and predictably, it may be + better to use the M(win_command) module instead. Best practices when writing + playbooks will follow the trend of using M(win_command) unless C(win_shell) is + explicitly required. When running ad-hoc commands, use your best judgement. + - WinRM will not return from a command execution until all child processes created have exited. + Thus, it is not possible to use C(win_shell) to spawn long-running child or background processes. + Consider creating a Windows service for managing background processes. +seealso: +- module: psexec +- module: raw +- module: script +- module: shell +- module: win_command +- module: win_psexec +author: + - Matt Davis (@nitzmahone) +''' + +EXAMPLES = r''' +# Execute a command in the remote shell; stdout goes to the specified +# file on the remote. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt + +# Change the working directory to somedir/ before executing the command. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt chdir=C:\somedir + +# You can also use the 'args' form to provide the options. This command +# will change the working directory to somedir/ and will only run when +# somedir/somelog.txt doesn't exist. +- win_shell: C:\somescript.ps1 >> C:\somelog.txt + args: + chdir: C:\somedir + creates: C:\somelog.txt + +# Run a command under a non-Powershell interpreter (cmd in this case) +- win_shell: echo %HOMEDIR% + args: + executable: cmd + register: homedir_out + +- name: Run multi-lined shell commands + win_shell: | + $value = Test-Path -Path C:\temp + if ($value) { + Remove-Item -Path C:\temp -Force + } + New-Item -Path C:\temp -ItemType Directory + +- name: Retrieve the input based on stdin + win_shell: '$string = [Console]::In.ReadToEnd(); Write-Output $string.Trim()' + args: + stdin: Input message +''' + +RETURN = r''' +msg: + description: Changed. + returned: always + type: bool + sample: true +start: + description: The command execution start time. + returned: always + type: str + sample: '2016-02-25 09:18:26.429568' +end: + description: The command execution end time. + returned: always + type: str + sample: '2016-02-25 09:18:26.755339' +delta: + description: The command execution delta time. + returned: always + type: str + sample: '0:00:00.325771' +stdout: + description: The command standard output. + returned: always + type: str + sample: 'Clustering node rabbit@slave1 with rabbit@master ...' +stderr: + description: The command standard error. + returned: always + type: str + sample: 'ls: cannot access foo: No such file or directory' +cmd: + description: The command executed by the task. + returned: always + type: str + sample: 'rabbitmqctl join_cluster rabbit@master' +rc: + description: The command return code (0 means success). + returned: always + type: int + sample: 0 +stdout_lines: + description: The command standard output split in lines. + returned: always + type: list + sample: [u'Clustering node rabbit@slave1 with rabbit@master ...'] +''' diff --git a/test/support/windows-integration/plugins/modules/win_stat.ps1 b/test/support/windows-integration/plugins/modules/win_stat.ps1 new file mode 100644 index 00000000..071eb11c --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_stat.ps1 @@ -0,0 +1,186 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic +#Requires -Module Ansible.ModuleUtils.FileUtil +#Requires -Module Ansible.ModuleUtils.LinkUtil + +function ConvertTo-Timestamp($start_date, $end_date) { + if ($start_date -and $end_date) { + return (New-TimeSpan -Start $start_date -End $end_date).TotalSeconds + } +} + +function Get-FileChecksum($path, $algorithm) { + switch ($algorithm) { + 'md5' { $sp = New-Object -TypeName System.Security.Cryptography.MD5CryptoServiceProvider } + 'sha1' { $sp = New-Object -TypeName System.Security.Cryptography.SHA1CryptoServiceProvider } + 'sha256' { $sp = New-Object -TypeName System.Security.Cryptography.SHA256CryptoServiceProvider } + 'sha384' { $sp = New-Object -TypeName System.Security.Cryptography.SHA384CryptoServiceProvider } + 'sha512' { $sp = New-Object -TypeName System.Security.Cryptography.SHA512CryptoServiceProvider } + default { Fail-Json -obj $result -message "Unsupported hash algorithm supplied '$algorithm'" } + } + + $fp = [System.IO.File]::Open($path, [System.IO.Filemode]::Open, [System.IO.FileAccess]::Read, [System.IO.FileShare]::ReadWrite) + try { + $hash = [System.BitConverter]::ToString($sp.ComputeHash($fp)).Replace("-", "").ToLower() + } finally { + $fp.Dispose() + } + + return $hash +} + +function Get-FileInfo { + param([String]$Path, [Switch]$Follow) + + $info = Get-AnsibleItem -Path $Path -ErrorAction SilentlyContinue + $link_info = $null + if ($null -ne $info) { + try { + $link_info = Get-Link -link_path $info.FullName + } catch { + $module.Warn("Failed to check/get link info for file: $($_.Exception.Message)") + } + + # If follow=true we want to follow the link all the way back to root object + if ($Follow -and $null -ne $link_info -and $link_info.Type -in @("SymbolicLink", "JunctionPoint")) { + $info, $link_info = Get-FileInfo -Path $link_info.AbsolutePath -Follow + } + } + + return $info, $link_info +} + +$spec = @{ + options = @{ + path = @{ type='path'; required=$true; aliases=@( 'dest', 'name' ) } + get_checksum = @{ type='bool'; default=$true } + checksum_algorithm = @{ type='str'; default='sha1'; choices=@( 'md5', 'sha1', 'sha256', 'sha384', 'sha512' ) } + follow = @{ type='bool'; default=$false } + } + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$path = $module.Params.path +$get_checksum = $module.Params.get_checksum +$checksum_algorithm = $module.Params.checksum_algorithm +$follow = $module.Params.follow + +$module.Result.stat = @{ exists=$false } + +Load-LinkUtils +$info, $link_info = Get-FileInfo -Path $path -Follow:$follow +If ($null -ne $info) { + $epoch_date = Get-Date -Date "01/01/1970" + $attributes = @() + foreach ($attribute in ($info.Attributes -split ',')) { + $attributes += $attribute.Trim() + } + + # default values that are always set, specific values are set below this + # but are kept commented for easier readability + $stat = @{ + exists = $true + attributes = $info.Attributes.ToString() + isarchive = ($attributes -contains "Archive") + isdir = $false + ishidden = ($attributes -contains "Hidden") + isjunction = $false + islnk = $false + isreadonly = ($attributes -contains "ReadOnly") + isreg = $false + isshared = $false + nlink = 1 # Number of links to the file (hard links), overriden below if islnk + # lnk_target = islnk or isjunction Target of the symlink. Note that relative paths remain relative + # lnk_source = islnk os isjunction Target of the symlink normalized for the remote filesystem + hlnk_targets = @() + creationtime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.CreationTime) + lastaccesstime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastAccessTime) + lastwritetime = (ConvertTo-Timestamp -start_date $epoch_date -end_date $info.LastWriteTime) + # size = a file and directory - calculated below + path = $info.FullName + filename = $info.Name + # extension = a file + # owner = set outsite this dict in case it fails + # sharename = a directory and isshared is True + # checksum = a file and get_checksum: True + } + try { + $stat.owner = $info.GetAccessControl().Owner + } catch { + # may not have rights, historical behaviour was to just set to $null + # due to ErrorActionPreference being set to "Continue" + $stat.owner = $null + } + + # values that are set according to the type of file + if ($info.Attributes.HasFlag([System.IO.FileAttributes]::Directory)) { + $stat.isdir = $true + $share_info = Get-CimInstance -ClassName Win32_Share -Filter "Path='$($stat.path -replace '\\', '\\')'" + if ($null -ne $share_info) { + $stat.isshared = $true + $stat.sharename = $share_info.Name + } + + try { + $size = 0 + foreach ($file in $info.EnumerateFiles("*", [System.IO.SearchOption]::AllDirectories)) { + $size += $file.Length + } + $stat.size = $size + } catch { + $stat.size = 0 + } + } else { + $stat.extension = $info.Extension + $stat.isreg = $true + $stat.size = $info.Length + + if ($get_checksum) { + try { + $stat.checksum = Get-FileChecksum -path $path -algorithm $checksum_algorithm + } catch { + $module.FailJson("Failed to get hash of file, set get_checksum to False to ignore this error: $($_.Exception.Message)", $_) + } + } + } + + # Get symbolic link, junction point, hard link info + if ($null -ne $link_info) { + switch ($link_info.Type) { + "SymbolicLink" { + $stat.islnk = $true + $stat.isreg = $false + $stat.lnk_target = $link_info.TargetPath + $stat.lnk_source = $link_info.AbsolutePath + break + } + "JunctionPoint" { + $stat.isjunction = $true + $stat.isreg = $false + $stat.lnk_target = $link_info.TargetPath + $stat.lnk_source = $link_info.AbsolutePath + break + } + "HardLink" { + $stat.lnk_type = "hard" + $stat.nlink = $link_info.HardTargets.Count + + # remove current path from the targets + $hlnk_targets = $link_info.HardTargets | Where-Object { $_ -ne $stat.path } + $stat.hlnk_targets = @($hlnk_targets) + break + } + } + } + + $module.Result.stat = $stat +} + +$module.ExitJson() + diff --git a/test/support/windows-integration/plugins/modules/win_stat.py b/test/support/windows-integration/plugins/modules/win_stat.py new file mode 100644 index 00000000..0676b5b2 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_stat.py @@ -0,0 +1,236 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_stat +version_added: "1.7" +short_description: Get information about Windows files +description: + - Returns information about a Windows file. + - For non-Windows targets, use the M(stat) module instead. +options: + path: + description: + - The full path of the file/object to get the facts of; both forward and + back slashes are accepted. + type: path + required: yes + aliases: [ dest, name ] + get_checksum: + description: + - Whether to return a checksum of the file (default sha1) + type: bool + default: yes + version_added: "2.1" + checksum_algorithm: + description: + - Algorithm to determine checksum of file. + - Will throw an error if the host is unable to use specified algorithm. + type: str + default: sha1 + choices: [ md5, sha1, sha256, sha384, sha512 ] + version_added: "2.3" + follow: + description: + - Whether to follow symlinks or junction points. + - In the case of C(path) pointing to another link, then that will + be followed until no more links are found. + type: bool + default: no + version_added: "2.8" +seealso: +- module: stat +- module: win_acl +- module: win_file +- module: win_owner +author: +- Chris Church (@cchurch) +''' + +EXAMPLES = r''' +- name: Obtain information about a file + win_stat: + path: C:\foo.ini + register: file_info + +- name: Obtain information about a folder + win_stat: + path: C:\bar + register: folder_info + +- name: Get MD5 checksum of a file + win_stat: + path: C:\foo.ini + get_checksum: yes + checksum_algorithm: md5 + register: md5_checksum + +- debug: + var: md5_checksum.stat.checksum + +- name: Get SHA1 checksum of file + win_stat: + path: C:\foo.ini + get_checksum: yes + register: sha1_checksum + +- debug: + var: sha1_checksum.stat.checksum + +- name: Get SHA256 checksum of file + win_stat: + path: C:\foo.ini + get_checksum: yes + checksum_algorithm: sha256 + register: sha256_checksum + +- debug: + var: sha256_checksum.stat.checksum +''' + +RETURN = r''' +changed: + description: Whether anything was changed + returned: always + type: bool + sample: true +stat: + description: dictionary containing all the stat data + returned: success + type: complex + contains: + attributes: + description: Attributes of the file at path in raw form. + returned: success, path exists + type: str + sample: "Archive, Hidden" + checksum: + description: The checksum of a file based on checksum_algorithm specified. + returned: success, path exist, path is a file, get_checksum == True + checksum_algorithm specified is supported + type: str + sample: 09cb79e8fc7453c84a07f644e441fd81623b7f98 + creationtime: + description: The create time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + exists: + description: If the path exists or not. + returned: success + type: bool + sample: true + extension: + description: The extension of the file at path. + returned: success, path exists, path is a file + type: str + sample: ".ps1" + filename: + description: The name of the file (without path). + returned: success, path exists, path is a file + type: str + sample: foo.ini + hlnk_targets: + description: List of other files pointing to the same file (hard links), excludes the current file. + returned: success, path exists + type: list + sample: + - C:\temp\file.txt + - C:\Windows\update.log + isarchive: + description: If the path is ready for archiving or not. + returned: success, path exists + type: bool + sample: true + isdir: + description: If the path is a directory or not. + returned: success, path exists + type: bool + sample: true + ishidden: + description: If the path is hidden or not. + returned: success, path exists + type: bool + sample: true + isjunction: + description: If the path is a junction point or not. + returned: success, path exists + type: bool + sample: true + islnk: + description: If the path is a symbolic link or not. + returned: success, path exists + type: bool + sample: true + isreadonly: + description: If the path is read only or not. + returned: success, path exists + type: bool + sample: true + isreg: + description: If the path is a regular file. + returned: success, path exists + type: bool + sample: true + isshared: + description: If the path is shared or not. + returned: success, path exists + type: bool + sample: true + lastaccesstime: + description: The last access time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lastwritetime: + description: The last modification time of the file represented in seconds since epoch. + returned: success, path exists + type: float + sample: 1477984205.15 + lnk_source: + description: Target of the symlink normalized for the remote filesystem. + returned: success, path exists and the path is a symbolic link or junction point + type: str + sample: C:\temp\link + lnk_target: + description: Target of the symlink. Note that relative paths remain relative. + returned: success, path exists and the path is a symbolic link or junction point + type: str + sample: ..\link + nlink: + description: Number of links to the file (hard links). + returned: success, path exists + type: int + sample: 1 + owner: + description: The owner of the file. + returned: success, path exists + type: str + sample: BUILTIN\Administrators + path: + description: The full absolute path to the file. + returned: success, path exists, file exists + type: str + sample: C:\foo.ini + sharename: + description: The name of share if folder is shared. + returned: success, path exists, file is a directory and isshared == True + type: str + sample: file-share + size: + description: The size in bytes of a file or folder. + returned: success, path exists, file is not a link + type: int + sample: 1024 +''' diff --git a/test/support/windows-integration/plugins/modules/win_tempfile.ps1 b/test/support/windows-integration/plugins/modules/win_tempfile.ps1 new file mode 100644 index 00000000..9a1a7174 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_tempfile.ps1 @@ -0,0 +1,72 @@ +#!powershell + +# Copyright: (c) 2017, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.Basic + +Function New-TempFile { + Param ([string]$path, [string]$prefix, [string]$suffix, [string]$type, [bool]$checkmode) + $temppath = $null + $curerror = $null + $attempt = 0 + + # Since we don't know if the file already exists, we try 5 times with a random name + do { + $attempt += 1 + $randomname = [System.IO.Path]::GetRandomFileName() + $temppath = (Join-Path -Path $path -ChildPath "$prefix$randomname$suffix") + Try { + $file = New-Item -Path $temppath -ItemType $type -WhatIf:$checkmode + # Makes sure we get the full absolute path of the created temp file and not a relative or DOS 8.3 dir + if (-not $checkmode) { + $temppath = $file.FullName + } else { + # Just rely on GetFulLpath for check mode + $temppath = [System.IO.Path]::GetFullPath($temppath) + } + } Catch { + $temppath = $null + $curerror = $_ + } + } until (($null -ne $temppath) -or ($attempt -ge 5)) + + # If it fails 5 times, something is wrong and we have to report the details + if ($null -eq $temppath) { + $module.FailJson("No random temporary file worked in $attempt attempts. Error: $($curerror.Exception.Message)", $curerror) + } + + return $temppath.ToString() +} + +$spec = @{ + options = @{ + path = @{ type='path'; default='%TEMP%'; aliases=@( 'dest' ) } + state = @{ type='str'; default='file'; choices=@( 'directory', 'file') } + prefix = @{ type='str'; default='ansible.' } + suffix = @{ type='str' } + } + supports_check_mode = $true +} + +$module = [Ansible.Basic.AnsibleModule]::Create($args, $spec) + +$path = $module.Params.path +$state = $module.Params.state +$prefix = $module.Params.prefix +$suffix = $module.Params.suffix + +# Expand environment variables on non-path types +if ($null -ne $prefix) { + $prefix = [System.Environment]::ExpandEnvironmentVariables($prefix) +} +if ($null -ne $suffix) { + $suffix = [System.Environment]::ExpandEnvironmentVariables($suffix) +} + +$module.Result.changed = $true +$module.Result.state = $state + +$module.Result.path = New-TempFile -Path $path -Prefix $prefix -Suffix $suffix -Type $state -CheckMode $module.CheckMode + +$module.ExitJson() diff --git a/test/support/windows-integration/plugins/modules/win_tempfile.py b/test/support/windows-integration/plugins/modules/win_tempfile.py new file mode 100644 index 00000000..58dd6501 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_tempfile.py @@ -0,0 +1,67 @@ +#!/usr/bin/python +# coding: utf-8 -*- + +# Copyright: (c) 2017, Dag Wieers <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_tempfile +version_added: "2.3" +short_description: Creates temporary files and directories +description: + - Creates temporary files and directories. + - For non-Windows targets, please use the M(tempfile) module instead. +options: + state: + description: + - Whether to create file or directory. + type: str + choices: [ directory, file ] + default: file + path: + description: + - Location where temporary file or directory should be created. + - If path is not specified default system temporary directory (%TEMP%) will be used. + type: path + default: '%TEMP%' + aliases: [ dest ] + prefix: + description: + - Prefix of file/directory name created by module. + type: str + default: ansible. + suffix: + description: + - Suffix of file/directory name created by module. + type: str + default: '' +seealso: +- module: tempfile +author: +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r""" +- name: Create temporary build directory + win_tempfile: + state: directory + suffix: build + +- name: Create temporary file + win_tempfile: + state: file + suffix: temp +""" + +RETURN = r''' +path: + description: The absolute path to the created file or directory. + returned: success + type: str + sample: C:\Users\Administrator\AppData\Local\Temp\ansible.bMlvdk +''' diff --git a/test/support/windows-integration/plugins/modules/win_template.py b/test/support/windows-integration/plugins/modules/win_template.py new file mode 100644 index 00000000..bd8b2492 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_template.py @@ -0,0 +1,66 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a virtual module that is entirely implemented server side + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_template +version_added: "1.9.2" +short_description: Template a file out to a remote server +options: + backup: + description: + - Determine whether a backup should be created. + - When set to C(yes), create a backup file including the timestamp information + so you can get the original file back if you somehow clobbered it incorrectly. + type: bool + default: no + version_added: '2.8' + newline_sequence: + default: '\r\n' + force: + version_added: '2.4' +notes: +- Beware fetching files from windows machines when creating templates because certain tools, such as Powershell ISE, + and regedit's export facility add a Byte Order Mark as the first character of the file, which can cause tracebacks. +- You can use the M(win_copy) module with the C(content:) option if you prefer the template inline, as part of the + playbook. +- For Linux you can use M(template) which uses '\\n' as C(newline_sequence) by default. +seealso: +- module: win_copy +- module: copy +- module: template +author: +- Jon Hawkesworth (@jhawkesworth) +extends_documentation_fragment: +- template_common +''' + +EXAMPLES = r''' +- name: Create a file from a Jinja2 template + win_template: + src: /mytemplates/file.conf.j2 + dest: C:\Temp\file.conf + +- name: Create a Unix-style file from a Jinja2 template + win_template: + src: unix/config.conf.j2 + dest: C:\share\unix\config.conf + newline_sequence: '\n' + backup: yes +''' + +RETURN = r''' +backup_file: + description: Name of the backup file that was created. + returned: if backup=yes + type: str + sample: C:\Path\To\File.txt.11540.20150212-220915.bak +''' diff --git a/test/support/windows-integration/plugins/modules/win_user.ps1 b/test/support/windows-integration/plugins/modules/win_user.ps1 new file mode 100644 index 00000000..54905cb2 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user.ps1 @@ -0,0 +1,273 @@ +#!powershell + +# Copyright: (c) 2014, Paul Durivage <paul.durivage@rackspace.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#AnsibleRequires -CSharpUtil Ansible.AccessToken +#Requires -Module Ansible.ModuleUtils.Legacy + +######## +$ADS_UF_PASSWD_CANT_CHANGE = 64 +$ADS_UF_DONT_EXPIRE_PASSWD = 65536 + +$adsi = [ADSI]"WinNT://$env:COMPUTERNAME" + +function Get-User($user) { + $adsi.Children | Where-Object {$_.SchemaClassName -eq 'user' -and $_.Name -eq $user } + return +} + +function Get-UserFlag($user, $flag) { + If ($user.UserFlags[0] -band $flag) { + $true + } + Else { + $false + } +} + +function Set-UserFlag($user, $flag) { + $user.UserFlags = ($user.UserFlags[0] -BOR $flag) +} + +function Clear-UserFlag($user, $flag) { + $user.UserFlags = ($user.UserFlags[0] -BXOR $flag) +} + +function Get-Group($grp) { + $adsi.Children | Where-Object { $_.SchemaClassName -eq 'Group' -and $_.Name -eq $grp } + return +} + +Function Test-LocalCredential { + param([String]$Username, [String]$Password) + + try { + $handle = [Ansible.AccessToken.TokenUtil]::LogonUser($Username, $null, $Password, "Network", "Default") + $handle.Dispose() + $valid_credentials = $true + } catch [Ansible.AccessToken.Win32Exception] { + # following errors indicate the creds are correct but the user was + # unable to log on for other reasons, which we don't care about + $success_codes = @( + 0x0000052F, # ERROR_ACCOUNT_RESTRICTION + 0x00000530, # ERROR_INVALID_LOGON_HOURS + 0x00000531, # ERROR_INVALID_WORKSTATION + 0x00000569 # ERROR_LOGON_TYPE_GRANTED + ) + + if ($_.Exception.NativeErrorCode -eq 0x0000052E) { + # ERROR_LOGON_FAILURE - the user or pass was incorrect + $valid_credentials = $false + } elseif ($_.Exception.NativeErrorCode -in $success_codes) { + $valid_credentials = $true + } else { + # an unknown failure, reraise exception + throw $_ + } + } + return $valid_credentials +} + +######## + +$params = Parse-Args $args; + +$result = @{ + changed = $false +}; + +$username = Get-AnsibleParam -obj $params -name "name" -type "str" -failifempty $true +$fullname = Get-AnsibleParam -obj $params -name "fullname" -type "str" +$description = Get-AnsibleParam -obj $params -name "description" -type "str" +$password = Get-AnsibleParam -obj $params -name "password" -type "str" +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "present" -validateset "present","absent","query" +$update_password = Get-AnsibleParam -obj $params -name "update_password" -type "str" -default "always" -validateset "always","on_create" +$password_expired = Get-AnsibleParam -obj $params -name "password_expired" -type "bool" +$password_never_expires = Get-AnsibleParam -obj $params -name "password_never_expires" -type "bool" +$user_cannot_change_password = Get-AnsibleParam -obj $params -name "user_cannot_change_password" -type "bool" +$account_disabled = Get-AnsibleParam -obj $params -name "account_disabled" -type "bool" +$account_locked = Get-AnsibleParam -obj $params -name "account_locked" -type "bool" +$groups = Get-AnsibleParam -obj $params -name "groups" +$groups_action = Get-AnsibleParam -obj $params -name "groups_action" -type "str" -default "replace" -validateset "add","remove","replace" + +If ($null -ne $account_locked -and $account_locked) { + Fail-Json $result "account_locked must be set to 'no' if provided" +} + +If ($null -ne $groups) { + If ($groups -is [System.String]) { + [string[]]$groups = $groups.Split(",") + } + ElseIf ($groups -isnot [System.Collections.IList]) { + Fail-Json $result "groups must be a string or array" + } + $groups = $groups | ForEach-Object { ([string]$_).Trim() } | Where-Object { $_ } + If ($null -eq $groups) { + $groups = @() + } +} + +$user_obj = Get-User $username + +If ($state -eq 'present') { + # Add or update user + try { + If (-not $user_obj) { + $user_obj = $adsi.Create("User", $username) + If ($null -ne $password) { + $user_obj.SetPassword($password) + } + $user_obj.SetInfo() + $result.changed = $true + } + ElseIf (($null -ne $password) -and ($update_password -eq 'always')) { + # ValidateCredentials will fail if either of these are true- just force update... + If($user_obj.AccountDisabled -or $user_obj.PasswordExpired) { + $password_match = $false + } + Else { + try { + $password_match = Test-LocalCredential -Username $username -Password $password + } catch [System.ComponentModel.Win32Exception] { + Fail-Json -obj $result -message "Failed to validate the user's credentials: $($_.Exception.Message)" + } + } + + If (-not $password_match) { + $user_obj.SetPassword($password) + $result.changed = $true + } + } + If (($null -ne $fullname) -and ($fullname -ne $user_obj.FullName[0])) { + $user_obj.FullName = $fullname + $result.changed = $true + } + If (($null -ne $description) -and ($description -ne $user_obj.Description[0])) { + $user_obj.Description = $description + $result.changed = $true + } + If (($null -ne $password_expired) -and ($password_expired -ne ($user_obj.PasswordExpired | ConvertTo-Bool))) { + $user_obj.PasswordExpired = If ($password_expired) { 1 } Else { 0 } + $result.changed = $true + } + If (($null -ne $password_never_expires) -and ($password_never_expires -ne (Get-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD))) { + If ($password_never_expires) { + Set-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD + } + Else { + Clear-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD + } + $result.changed = $true + } + If (($null -ne $user_cannot_change_password) -and ($user_cannot_change_password -ne (Get-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE))) { + If ($user_cannot_change_password) { + Set-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE + } + Else { + Clear-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE + } + $result.changed = $true + } + If (($null -ne $account_disabled) -and ($account_disabled -ne $user_obj.AccountDisabled)) { + $user_obj.AccountDisabled = $account_disabled + $result.changed = $true + } + If (($null -ne $account_locked) -and ($account_locked -ne $user_obj.IsAccountLocked)) { + $user_obj.IsAccountLocked = $account_locked + $result.changed = $true + } + If ($result.changed) { + $user_obj.SetInfo() + } + If ($null -ne $groups) { + [string[]]$current_groups = $user_obj.Groups() | ForEach-Object { $_.GetType().InvokeMember("Name", "GetProperty", $null, $_, $null) } + If (($groups_action -eq "remove") -or ($groups_action -eq "replace")) { + ForEach ($grp in $current_groups) { + If ((($groups_action -eq "remove") -and ($groups -contains $grp)) -or (($groups_action -eq "replace") -and ($groups -notcontains $grp))) { + $group_obj = Get-Group $grp + If ($group_obj) { + $group_obj.Remove($user_obj.Path) + $result.changed = $true + } + Else { + Fail-Json $result "group '$grp' not found" + } + } + } + } + If (($groups_action -eq "add") -or ($groups_action -eq "replace")) { + ForEach ($grp in $groups) { + If ($current_groups -notcontains $grp) { + $group_obj = Get-Group $grp + If ($group_obj) { + $group_obj.Add($user_obj.Path) + $result.changed = $true + } + Else { + Fail-Json $result "group '$grp' not found" + } + } + } + } + } + } + catch { + Fail-Json $result $_.Exception.Message + } +} +ElseIf ($state -eq 'absent') { + # Remove user + try { + If ($user_obj) { + $username = $user_obj.Name.Value + $adsi.delete("User", $user_obj.Name.Value) + $result.changed = $true + $result.msg = "User '$username' deleted successfully" + $user_obj = $null + } else { + $result.msg = "User '$username' was not found" + } + } + catch { + Fail-Json $result $_.Exception.Message + } +} + +try { + If ($user_obj -and $user_obj -is [System.DirectoryServices.DirectoryEntry]) { + $user_obj.RefreshCache() + $result.name = $user_obj.Name[0] + $result.fullname = $user_obj.FullName[0] + $result.path = $user_obj.Path + $result.description = $user_obj.Description[0] + $result.password_expired = ($user_obj.PasswordExpired | ConvertTo-Bool) + $result.password_never_expires = (Get-UserFlag $user_obj $ADS_UF_DONT_EXPIRE_PASSWD) + $result.user_cannot_change_password = (Get-UserFlag $user_obj $ADS_UF_PASSWD_CANT_CHANGE) + $result.account_disabled = $user_obj.AccountDisabled + $result.account_locked = $user_obj.IsAccountLocked + $result.sid = (New-Object System.Security.Principal.SecurityIdentifier($user_obj.ObjectSid.Value, 0)).Value + $user_groups = @() + ForEach ($grp in $user_obj.Groups()) { + $group_result = @{ + name = $grp.GetType().InvokeMember("Name", "GetProperty", $null, $grp, $null) + path = $grp.GetType().InvokeMember("ADsPath", "GetProperty", $null, $grp, $null) + } + $user_groups += $group_result; + } + $result.groups = $user_groups + $result.state = "present" + } + Else { + $result.name = $username + if ($state -eq 'query') { + $result.msg = "User '$username' was not found" + } + $result.state = "absent" + } +} +catch { + Fail-Json $result $_.Exception.Message +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_user.py b/test/support/windows-integration/plugins/modules/win_user.py new file mode 100644 index 00000000..5fc0633d --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user.py @@ -0,0 +1,194 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2014, Matt Martz <matt@sivel.net>, and others +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: win_user +version_added: "1.7" +short_description: Manages local Windows user accounts +description: + - Manages local Windows user accounts. + - For non-Windows targets, use the M(user) module instead. +options: + name: + description: + - Name of the user to create, remove or modify. + type: str + required: yes + fullname: + description: + - Full name of the user. + type: str + version_added: "1.9" + description: + description: + - Description of the user. + type: str + version_added: "1.9" + password: + description: + - Optionally set the user's password to this (plain text) value. + type: str + update_password: + description: + - C(always) will update passwords if they differ. C(on_create) will + only set the password for newly created users. + type: str + choices: [ always, on_create ] + default: always + version_added: "1.9" + password_expired: + description: + - C(yes) will require the user to change their password at next login. + - C(no) will clear the expired password flag. + type: bool + version_added: "1.9" + password_never_expires: + description: + - C(yes) will set the password to never expire. + - C(no) will allow the password to expire. + type: bool + version_added: "1.9" + user_cannot_change_password: + description: + - C(yes) will prevent the user from changing their password. + - C(no) will allow the user to change their password. + type: bool + version_added: "1.9" + account_disabled: + description: + - C(yes) will disable the user account. + - C(no) will clear the disabled flag. + type: bool + version_added: "1.9" + account_locked: + description: + - C(no) will unlock the user account if locked. + choices: [ 'no' ] + version_added: "1.9" + groups: + description: + - Adds or removes the user from this comma-separated list of groups, + depending on the value of I(groups_action). + - When I(groups_action) is C(replace) and I(groups) is set to the empty + string ('groups='), the user is removed from all groups. + version_added: "1.9" + groups_action: + description: + - If C(add), the user is added to each group in I(groups) where not + already a member. + - If C(replace), the user is added as a member of each group in + I(groups) and removed from any other groups. + - If C(remove), the user is removed from each group in I(groups). + type: str + choices: [ add, replace, remove ] + default: replace + version_added: "1.9" + state: + description: + - When C(absent), removes the user account if it exists. + - When C(present), creates or updates the user account. + - When C(query) (new in 1.9), retrieves the user account details + without making any changes. + type: str + choices: [ absent, present, query ] + default: present +seealso: +- module: user +- module: win_domain_membership +- module: win_domain_user +- module: win_group +- module: win_group_membership +- module: win_user_profile +author: + - Paul Durivage (@angstwad) + - Chris Church (@cchurch) +''' + +EXAMPLES = r''' +- name: Ensure user bob is present + win_user: + name: bob + password: B0bP4ssw0rd + state: present + groups: + - Users + +- name: Ensure user bob is absent + win_user: + name: bob + state: absent +''' + +RETURN = r''' +account_disabled: + description: Whether the user is disabled. + returned: user exists + type: bool + sample: false +account_locked: + description: Whether the user is locked. + returned: user exists + type: bool + sample: false +description: + description: The description set for the user. + returned: user exists + type: str + sample: Username for test +fullname: + description: The full name set for the user. + returned: user exists + type: str + sample: Test Username +groups: + description: A list of groups and their ADSI path the user is a member of. + returned: user exists + type: list + sample: [ + { + "name": "Administrators", + "path": "WinNT://WORKGROUP/USER-PC/Administrators" + } + ] +name: + description: The name of the user + returned: always + type: str + sample: username +password_expired: + description: Whether the password is expired. + returned: user exists + type: bool + sample: false +password_never_expires: + description: Whether the password is set to never expire. + returned: user exists + type: bool + sample: true +path: + description: The ADSI path for the user. + returned: user exists + type: str + sample: "WinNT://WORKGROUP/USER-PC/username" +sid: + description: The SID for the user. + returned: user exists + type: str + sample: S-1-5-21-3322259488-2828151810-3939402796-1001 +user_cannot_change_password: + description: Whether the user can change their own password. + returned: user exists + type: bool + sample: false +''' diff --git a/test/support/windows-integration/plugins/modules/win_user_right.ps1 b/test/support/windows-integration/plugins/modules/win_user_right.ps1 new file mode 100644 index 00000000..3fac52a8 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user_right.ps1 @@ -0,0 +1,349 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.SID + +$ErrorActionPreference = 'Stop' + +$params = Parse-Args $args -supports_check_mode $true +$check_mode = Get-AnsibleParam -obj $params -name "_ansible_check_mode" -type "bool" -default $false +$diff_mode = Get-AnsibleParam -obj $params -name "_ansible_diff" -type "bool" -default $false +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$name = Get-AnsibleParam -obj $params -name "name" -type "str" -failifempty $true +$users = Get-AnsibleParam -obj $params -name "users" -type "list" -failifempty $true +$action = Get-AnsibleParam -obj $params -name "action" -type "str" -default "set" -validateset "add","remove","set" + +$result = @{ + changed = $false + added = @() + removed = @() +} + +if ($diff_mode) { + $result.diff = @{} +} + +$sec_helper_util = @" +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Security.Principal; + +namespace Ansible +{ + public class LsaRightHelper : IDisposable + { + // Code modified from https://gallery.technet.microsoft.com/scriptcenter/Grant-Revoke-Query-user-26e259b0 + + enum Access : int + { + POLICY_READ = 0x20006, + POLICY_ALL_ACCESS = 0x00F0FFF, + POLICY_EXECUTE = 0X20801, + POLICY_WRITE = 0X207F8 + } + + IntPtr lsaHandle; + + const string LSA_DLL = "advapi32.dll"; + const CharSet DEFAULT_CHAR_SET = CharSet.Unicode; + + const uint STATUS_NO_MORE_ENTRIES = 0x8000001a; + const uint STATUS_NO_SUCH_PRIVILEGE = 0xc0000060; + + internal sealed class Sid : IDisposable + { + public IntPtr pSid = IntPtr.Zero; + public SecurityIdentifier sid = null; + + public Sid(string sidString) + { + try + { + sid = new SecurityIdentifier(sidString); + } catch + { + throw new ArgumentException(String.Format("SID string {0} could not be converted to SecurityIdentifier", sidString)); + } + + Byte[] buffer = new Byte[sid.BinaryLength]; + sid.GetBinaryForm(buffer, 0); + + pSid = Marshal.AllocHGlobal(sid.BinaryLength); + Marshal.Copy(buffer, 0, pSid, sid.BinaryLength); + } + + public void Dispose() + { + if (pSid != IntPtr.Zero) + { + Marshal.FreeHGlobal(pSid); + pSid = IntPtr.Zero; + } + GC.SuppressFinalize(this); + } + ~Sid() { Dispose(); } + } + + [StructLayout(LayoutKind.Sequential)] + private struct LSA_OBJECT_ATTRIBUTES + { + public int Length; + public IntPtr RootDirectory; + public IntPtr ObjectName; + public int Attributes; + public IntPtr SecurityDescriptor; + public IntPtr SecurityQualityOfService; + } + + [StructLayout(LayoutKind.Sequential, CharSet = DEFAULT_CHAR_SET)] + private struct LSA_UNICODE_STRING + { + public ushort Length; + public ushort MaximumLength; + [MarshalAs(UnmanagedType.LPWStr)] + public string Buffer; + } + + [StructLayout(LayoutKind.Sequential)] + private struct LSA_ENUMERATION_INFORMATION + { + public IntPtr Sid; + } + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaOpenPolicy( + LSA_UNICODE_STRING[] SystemName, + ref LSA_OBJECT_ATTRIBUTES ObjectAttributes, + int AccessMask, + out IntPtr PolicyHandle + ); + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaAddAccountRights( + IntPtr PolicyHandle, + IntPtr pSID, + LSA_UNICODE_STRING[] UserRights, + int CountOfRights + ); + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaRemoveAccountRights( + IntPtr PolicyHandle, + IntPtr pSID, + bool AllRights, + LSA_UNICODE_STRING[] UserRights, + int CountOfRights + ); + + [DllImport(LSA_DLL, CharSet = DEFAULT_CHAR_SET, SetLastError = true)] + private static extern uint LsaEnumerateAccountsWithUserRight( + IntPtr PolicyHandle, + LSA_UNICODE_STRING[] UserRights, + out IntPtr EnumerationBuffer, + out ulong CountReturned + ); + + [DllImport(LSA_DLL)] + private static extern int LsaNtStatusToWinError(int NTSTATUS); + + [DllImport(LSA_DLL)] + private static extern int LsaClose(IntPtr PolicyHandle); + + [DllImport(LSA_DLL)] + private static extern int LsaFreeMemory(IntPtr Buffer); + + public LsaRightHelper() + { + LSA_OBJECT_ATTRIBUTES lsaAttr; + lsaAttr.RootDirectory = IntPtr.Zero; + lsaAttr.ObjectName = IntPtr.Zero; + lsaAttr.Attributes = 0; + lsaAttr.SecurityDescriptor = IntPtr.Zero; + lsaAttr.SecurityQualityOfService = IntPtr.Zero; + lsaAttr.Length = Marshal.SizeOf(typeof(LSA_OBJECT_ATTRIBUTES)); + + lsaHandle = IntPtr.Zero; + + LSA_UNICODE_STRING[] system = new LSA_UNICODE_STRING[1]; + system[0] = InitLsaString(""); + + uint ret = LsaOpenPolicy(system, ref lsaAttr, (int)Access.POLICY_ALL_ACCESS, out lsaHandle); + if (ret != 0) + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + + public void AddPrivilege(string sidString, string privilege) + { + uint ret = 0; + using (Sid sid = new Sid(sidString)) + { + LSA_UNICODE_STRING[] privileges = new LSA_UNICODE_STRING[1]; + privileges[0] = InitLsaString(privilege); + ret = LsaAddAccountRights(lsaHandle, sid.pSid, privileges, 1); + } + if (ret != 0) + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + + public void RemovePrivilege(string sidString, string privilege) + { + uint ret = 0; + using (Sid sid = new Sid(sidString)) + { + LSA_UNICODE_STRING[] privileges = new LSA_UNICODE_STRING[1]; + privileges[0] = InitLsaString(privilege); + ret = LsaRemoveAccountRights(lsaHandle, sid.pSid, false, privileges, 1); + } + if (ret != 0) + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + + public string[] EnumerateAccountsWithUserRight(string privilege) + { + uint ret = 0; + ulong count = 0; + LSA_UNICODE_STRING[] rights = new LSA_UNICODE_STRING[1]; + rights[0] = InitLsaString(privilege); + IntPtr buffer = IntPtr.Zero; + + ret = LsaEnumerateAccountsWithUserRight(lsaHandle, rights, out buffer, out count); + switch (ret) + { + case 0: + string[] accounts = new string[count]; + for (int i = 0; i < (int)count; i++) + { + LSA_ENUMERATION_INFORMATION LsaInfo = (LSA_ENUMERATION_INFORMATION)Marshal.PtrToStructure( + IntPtr.Add(buffer, i * Marshal.SizeOf(typeof(LSA_ENUMERATION_INFORMATION))), + typeof(LSA_ENUMERATION_INFORMATION)); + + accounts[i] = new SecurityIdentifier(LsaInfo.Sid).ToString(); + } + LsaFreeMemory(buffer); + return accounts; + + case STATUS_NO_MORE_ENTRIES: + return new string[0]; + + case STATUS_NO_SUCH_PRIVILEGE: + throw new ArgumentException(String.Format("Invalid privilege {0} not found in LSA database", privilege)); + + default: + throw new Win32Exception(LsaNtStatusToWinError((int)ret)); + } + } + + static LSA_UNICODE_STRING InitLsaString(string s) + { + // Unicode strings max. 32KB + if (s.Length > 0x7ffe) + throw new ArgumentException("String too long"); + + LSA_UNICODE_STRING lus = new LSA_UNICODE_STRING(); + lus.Buffer = s; + lus.Length = (ushort)(s.Length * sizeof(char)); + lus.MaximumLength = (ushort)(lus.Length + sizeof(char)); + + return lus; + } + + public void Dispose() + { + if (lsaHandle != IntPtr.Zero) + { + LsaClose(lsaHandle); + lsaHandle = IntPtr.Zero; + } + GC.SuppressFinalize(this); + } + ~LsaRightHelper() { Dispose(); } + } +} +"@ + +$original_tmp = $env:TMP +$env:TMP = $_remote_tmp +Add-Type -TypeDefinition $sec_helper_util +$env:TMP = $original_tmp + +Function Compare-UserList($existing_users, $new_users) { + $added_users = [String[]]@() + $removed_users = [String[]]@() + if ($action -eq "add") { + $added_users = [Linq.Enumerable]::Except($new_users, $existing_users) + } elseif ($action -eq "remove") { + $removed_users = [Linq.Enumerable]::Intersect($new_users, $existing_users) + } else { + $added_users = [Linq.Enumerable]::Except($new_users, $existing_users) + $removed_users = [Linq.Enumerable]::Except($existing_users, $new_users) + } + + $change_result = @{ + added = $added_users + removed = $removed_users + } + + return $change_result +} + +# C# class we can use to enumerate/add/remove rights +$lsa_helper = New-Object -TypeName Ansible.LsaRightHelper + +$new_users = [System.Collections.ArrayList]@() +foreach ($user in $users) { + $user_sid = Convert-ToSID -account_name $user + $new_users.Add($user_sid) > $null +} +$new_users = [String[]]$new_users.ToArray() +try { + $existing_users = $lsa_helper.EnumerateAccountsWithUserRight($name) +} catch [ArgumentException] { + Fail-Json -obj $result -message "the specified right $name is not a valid right" +} catch { + Fail-Json -obj $result -message "failed to enumerate existing accounts with right: $($_.Exception.Message)" +} + +$change_result = Compare-UserList -existing_users $existing_users -new_user $new_users +if (($change_result.added.Length -gt 0) -or ($change_result.removed.Length -gt 0)) { + $result.changed = $true + $diff_text = "[$name]`n" + + # used in diff mode calculation + $new_user_list = [System.Collections.ArrayList]$existing_users + foreach ($user in $change_result.removed) { + if (-not $check_mode) { + $lsa_helper.RemovePrivilege($user, $name) + } + $user_name = Convert-FromSID -sid $user + $result.removed += $user_name + $diff_text += "-$user_name`n" + $new_user_list.Remove($user) > $null + } + foreach ($user in $change_result.added) { + if (-not $check_mode) { + $lsa_helper.AddPrivilege($user, $name) + } + $user_name = Convert-FromSID -sid $user + $result.added += $user_name + $diff_text += "+$user_name`n" + $new_user_list.Add($user) > $null + } + + if ($diff_mode) { + if ($new_user_list.Count -eq 0) { + $diff_text = "-$diff_text" + } else { + if ($existing_users.Count -eq 0) { + $diff_text = "+$diff_text" + } + } + $result.diff.prepared = $diff_text + } +} + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_user_right.py b/test/support/windows-integration/plugins/modules/win_user_right.py new file mode 100644 index 00000000..55882083 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_user_right.py @@ -0,0 +1,108 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_user_right +version_added: '2.4' +short_description: Manage Windows User Rights +description: +- Add, remove or set User Rights for a group or users or groups. +- You can set user rights for both local and domain accounts. +options: + name: + description: + - The name of the User Right as shown by the C(Constant Name) value from + U(https://technet.microsoft.com/en-us/library/dd349804.aspx). + - The module will return an error if the right is invalid. + type: str + required: yes + users: + description: + - A list of users or groups to add/remove on the User Right. + - These can be in the form DOMAIN\user-group, user-group@DOMAIN.COM for + domain users/groups. + - For local users/groups it can be in the form user-group, .\user-group, + SERVERNAME\user-group where SERVERNAME is the name of the remote server. + - You can also add special local accounts like SYSTEM and others. + - Can be set to an empty list with I(action=set) to remove all accounts + from the right. + type: list + required: yes + action: + description: + - C(add) will add the users/groups to the existing right. + - C(remove) will remove the users/groups from the existing right. + - C(set) will replace the users/groups of the existing right. + type: str + default: set + choices: [ add, remove, set ] +notes: +- If the server is domain joined this module can change a right but if a GPO + governs this right then the changes won't last. +seealso: +- module: win_group +- module: win_group_membership +- module: win_user +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +--- +- name: Replace the entries of Deny log on locally + win_user_right: + name: SeDenyInteractiveLogonRight + users: + - Guest + - Users + action: set + +- name: Add account to Log on as a service + win_user_right: + name: SeServiceLogonRight + users: + - .\Administrator + - '{{ansible_hostname}}\local-user' + action: add + +- name: Remove accounts who can create Symbolic links + win_user_right: + name: SeCreateSymbolicLinkPrivilege + users: + - SYSTEM + - Administrators + - DOMAIN\User + - group@DOMAIN.COM + action: remove + +- name: Remove all accounts who cannot log on remote interactively + win_user_right: + name: SeDenyRemoteInteractiveLogonRight + users: [] +''' + +RETURN = r''' +added: + description: A list of accounts that were added to the right, this is empty + if no accounts were added. + returned: success + type: list + sample: ["NT AUTHORITY\\SYSTEM", "DOMAIN\\User"] +removed: + description: A list of accounts that were removed from the right, this is + empty if no accounts were removed. + returned: success + type: list + sample: ["SERVERNAME\\Administrator", "BUILTIN\\Administrators"] +''' diff --git a/test/support/windows-integration/plugins/modules/win_wait_for.ps1 b/test/support/windows-integration/plugins/modules/win_wait_for.ps1 new file mode 100644 index 00000000..e0a9a720 --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_wait_for.ps1 @@ -0,0 +1,259 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.FileUtil + +$ErrorActionPreference = "Stop" + +$params = Parse-Args -arguments $args -supports_check_mode $true + +$connect_timeout = Get-AnsibleParam -obj $params -name "connect_timeout" -type "int" -default 5 +$delay = Get-AnsibleParam -obj $params -name "delay" -type "int" +$exclude_hosts = Get-AnsibleParam -obj $params -name "exclude_hosts" -type "list" +$hostname = Get-AnsibleParam -obj $params -name "host" -type "str" -default "127.0.0.1" +$path = Get-AnsibleParam -obj $params -name "path" -type "path" +$port = Get-AnsibleParam -obj $params -name "port" -type "int" +$regex = Get-AnsibleParam -obj $params -name "regex" -type "str" -aliases "search_regex","regexp" +$sleep = Get-AnsibleParam -obj $params -name "sleep" -type "int" -default 1 +$state = Get-AnsibleParam -obj $params -name "state" -type "str" -default "started" -validateset "present","started","stopped","absent","drained" +$timeout = Get-AnsibleParam -obj $params -name "timeout" -type "int" -default 300 + +$result = @{ + changed = $false + elapsed = 0 +} + +# validate the input with the various options +if ($null -ne $port -and $null -ne $path) { + Fail-Json $result "port and path parameter can not both be passed to win_wait_for" +} +if ($null -ne $exclude_hosts -and $state -ne "drained") { + Fail-Json $result "exclude_hosts should only be with state=drained" +} +if ($null -ne $path) { + if ($state -in @("stopped","drained")) { + Fail-Json $result "state=$state should only be used for checking a port in the win_wait_for module" + } + + if ($null -ne $exclude_hosts) { + Fail-Json $result "exclude_hosts should only be used when checking a port and state=drained in the win_wait_for module" + } +} + +if ($null -ne $port) { + if ($null -ne $regex) { + Fail-Json $result "regex should by used when checking a string in a file in the win_wait_for module" + } + + if ($null -ne $exclude_hosts -and $state -ne "drained") { + Fail-Json $result "exclude_hosts should be used when state=drained in the win_wait_for module" + } +} + +Function Test-Port($hostname, $port) { + $timeout = $connect_timeout * 1000 + $socket = New-Object -TypeName System.Net.Sockets.TcpClient + $connect = $socket.BeginConnect($hostname, $port, $null, $null) + $wait = $connect.AsyncWaitHandle.WaitOne($timeout, $false) + + if ($wait) { + try { + $socket.EndConnect($connect) | Out-Null + $valid = $true + } catch { + $valid = $false + } + } else { + $valid = $false + } + + $socket.Close() + $socket.Dispose() + + $valid +} + +Function Get-PortConnections($hostname, $port) { + $connections = @() + + $conn_info = [Net.NetworkInformation.IPGlobalProperties]::GetIPGlobalProperties() + if ($hostname -eq "0.0.0.0") { + $active_connections = $conn_info.GetActiveTcpConnections() | Where-Object { $_.LocalEndPoint.Port -eq $port } + } else { + $active_connections = $conn_info.GetActiveTcpConnections() | Where-Object { $_.LocalEndPoint.Address -eq $hostname -and $_.LocalEndPoint.Port -eq $port } + } + + if ($null -ne $active_connections) { + foreach ($active_connection in $active_connections) { + $connections += $active_connection.RemoteEndPoint.Address + } + } + + $connections +} + +$module_start = Get-Date + +if ($null -ne $delay) { + Start-Sleep -Seconds $delay +} + +$attempts = 0 +if ($null -eq $path -and $null -eq $port -and $state -ne "drained") { + Start-Sleep -Seconds $timeout +} elseif ($null -ne $path) { + if ($state -in @("present", "started")) { + # check if the file exists or string exists in file + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + if (Test-AnsiblePath -Path $path) { + if ($null -eq $regex) { + $complete = $true + break + } else { + $file_contents = Get-Content -Path $path -Raw + if ($file_contents -match $regex) { + $complete = $true + break + } + } + } + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + if ($null -eq $regex) { + Fail-Json $result "timeout while waiting for file $path to be present" + } else { + Fail-Json $result "timeout while waiting for string regex $regex in file $path to match" + } + } + } elseif ($state -in @("absent")) { + # check if the file is deleted or string doesn't exist in file + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + if (Test-AnsiblePath -Path $path) { + if ($null -ne $regex) { + $file_contents = Get-Content -Path $path -Raw + if ($file_contents -notmatch $regex) { + $complete = $true + break + } + } + } else { + $complete = $true + break + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + if ($null -eq $regex) { + Fail-Json $result "timeout while waiting for file $path to be absent" + } else { + Fail-Json $result "timeout while waiting for string regex $regex in file $path to not match" + } + } + } +} elseif ($null -ne $port) { + if ($state -in @("started","present")) { + # check that the port is online and is listening + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + $port_result = Test-Port -hostname $hostname -port $port + if ($port_result -eq $true) { + $complete = $true + break + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + Fail-Json $result "timeout while waiting for $($hostname):$port to start listening" + } + } elseif ($state -in @("stopped","absent")) { + # check that the port is offline and is not listening + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + $port_result = Test-Port -hostname $hostname -port $port + if ($port_result -eq $false) { + $complete = $true + break + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + Fail-Json $result "timeout while waiting for $($hostname):$port to stop listening" + } + } elseif ($state -eq "drained") { + # check that the local port is online but has no active connections + $start_time = Get-Date + $complete = $false + while (((Get-Date) - $start_time).TotalSeconds -lt $timeout) { + $attempts += 1 + $active_connections = Get-PortConnections -hostname $hostname -port $port + if ($null -eq $active_connections) { + $complete = $true + break + } elseif ($active_connections.Count -eq 0) { + # no connections on port + $complete = $true + break + } else { + # there are listeners, check if we should ignore any hosts + if ($null -ne $exclude_hosts) { + $connection_info = $active_connections + foreach ($exclude_host in $exclude_hosts) { + try { + $exclude_ips = [System.Net.Dns]::GetHostAddresses($exclude_host) | ForEach-Object { Write-Output $_.IPAddressToString } + $connection_info = $connection_info | Where-Object { $_ -notin $exclude_ips } + } catch { # ignore invalid hostnames + Add-Warning -obj $result -message "Invalid hostname specified $exclude_host" + } + } + + if ($connection_info.Count -eq 0) { + $complete = $true + break + } + } + } + + Start-Sleep -Seconds $sleep + } + + if ($complete -eq $false) { + $result.elapsed = ((Get-Date) - $module_start).TotalSeconds + $result.wait_attempts = $attempts + Fail-Json $result "timeout while waiting for $($hostname):$port to drain" + } + } +} + +$result.elapsed = ((Get-Date) - $module_start).TotalSeconds +$result.wait_attempts = $attempts + +Exit-Json $result diff --git a/test/support/windows-integration/plugins/modules/win_wait_for.py b/test/support/windows-integration/plugins/modules/win_wait_for.py new file mode 100644 index 00000000..85721e7d --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_wait_for.py @@ -0,0 +1,155 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub, actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_wait_for +version_added: '2.4' +short_description: Waits for a condition before continuing +description: +- You can wait for a set amount of time C(timeout), this is the default if + nothing is specified. +- Waiting for a port to become available is useful for when services are not + immediately available after their init scripts return which is true of + certain Java application servers. +- You can wait for a file to exist or not exist on the filesystem. +- This module can also be used to wait for a regex match string to be present + in a file. +- You can wait for active connections to be closed before continuing on a + local port. +options: + connect_timeout: + description: + - The maximum number of seconds to wait for a connection to happen before + closing and retrying. + type: int + default: 5 + delay: + description: + - The number of seconds to wait before starting to poll. + type: int + exclude_hosts: + description: + - The list of hosts or IPs to ignore when looking for active TCP + connections when C(state=drained). + type: list + host: + description: + - A resolvable hostname or IP address to wait for. + - If C(state=drained) then it will only check for connections on the IP + specified, you can use '0.0.0.0' to use all host IPs. + type: str + default: '127.0.0.1' + path: + description: + - The path to a file on the filesystem to check. + - If C(state) is present or started then it will wait until the file + exists. + - If C(state) is absent then it will wait until the file does not exist. + type: path + port: + description: + - The port number to poll on C(host). + type: int + regex: + description: + - Can be used to match a string in a file. + - If C(state) is present or started then it will wait until the regex + matches. + - If C(state) is absent then it will wait until the regex does not match. + - Defaults to a multiline regex. + type: str + aliases: [ "search_regex", "regexp" ] + sleep: + description: + - Number of seconds to sleep between checks. + type: int + default: 1 + state: + description: + - When checking a port, C(started) will ensure the port is open, C(stopped) + will check that is it closed and C(drained) will check for active + connections. + - When checking for a file or a search string C(present) or C(started) will + ensure that the file or string is present, C(absent) will check that the + file or search string is absent or removed. + type: str + choices: [ absent, drained, present, started, stopped ] + default: started + timeout: + description: + - The maximum number of seconds to wait for. + type: int + default: 300 +seealso: +- module: wait_for +- module: win_wait_for_process +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Wait 300 seconds for port 8000 to become open on the host, don't start checking for 10 seconds + win_wait_for: + port: 8000 + delay: 10 + +- name: Wait 150 seconds for port 8000 of any IP to close active connections + win_wait_for: + host: 0.0.0.0 + port: 8000 + state: drained + timeout: 150 + +- name: Wait for port 8000 of any IP to close active connection, ignoring certain hosts + win_wait_for: + host: 0.0.0.0 + port: 8000 + state: drained + exclude_hosts: ['10.2.1.2', '10.2.1.3'] + +- name: Wait for file C:\temp\log.txt to exist before continuing + win_wait_for: + path: C:\temp\log.txt + +- name: Wait until process complete is in the file before continuing + win_wait_for: + path: C:\temp\log.txt + regex: process complete + +- name: Wait until file is removed + win_wait_for: + path: C:\temp\log.txt + state: absent + +- name: Wait until port 1234 is offline but try every 10 seconds + win_wait_for: + port: 1234 + state: absent + sleep: 10 +''' + +RETURN = r''' +wait_attempts: + description: The number of attempts to poll the file or port before module + finishes. + returned: always + type: int + sample: 1 +elapsed: + description: The elapsed seconds between the start of poll and the end of the + module. This includes the delay if the option is set. + returned: always + type: float + sample: 2.1406487 +''' diff --git a/test/support/windows-integration/plugins/modules/win_whoami.ps1 b/test/support/windows-integration/plugins/modules/win_whoami.ps1 new file mode 100644 index 00000000..6c9965af --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_whoami.ps1 @@ -0,0 +1,837 @@ +#!powershell + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +#Requires -Module Ansible.ModuleUtils.Legacy +#Requires -Module Ansible.ModuleUtils.CamelConversion + +$ErrorActionPreference = "Stop" + +$params = Parse-Args $args -supports_check_mode $true +$_remote_tmp = Get-AnsibleParam $params "_ansible_remote_tmp" -type "path" -default $env:TMP + +$session_util = @' +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; + +namespace Ansible +{ + public class SessionInfo + { + // SECURITY_LOGON_SESSION_DATA + public UInt64 LogonId { get; internal set; } + public Sid Account { get; internal set; } + public string LoginDomain { get; internal set; } + public string AuthenticationPackage { get; internal set; } + public SECURITY_LOGON_TYPE LogonType { get; internal set; } + public string LoginTime { get; internal set; } + public string LogonServer { get; internal set; } + public string DnsDomainName { get; internal set; } + public string Upn { get; internal set; } + public ArrayList UserFlags { get; internal set; } + + // TOKEN_STATISTICS + public SECURITY_IMPERSONATION_LEVEL ImpersonationLevel { get; internal set; } + public TOKEN_TYPE TokenType { get; internal set; } + + // TOKEN_GROUPS + public ArrayList Groups { get; internal set; } + public ArrayList Rights { get; internal set; } + + // TOKEN_MANDATORY_LABEL + public Sid Label { get; internal set; } + + // TOKEN_PRIVILEGES + public Hashtable Privileges { get; internal set; } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct LSA_UNICODE_STRING + { + public UInt16 Length; + public UInt16 MaximumLength; + public IntPtr buffer; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LUID + { + public UInt32 LowPart; + public Int32 HighPart; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SECURITY_LOGON_SESSION_DATA + { + public UInt32 Size; + public LUID LogonId; + public LSA_UNICODE_STRING Username; + public LSA_UNICODE_STRING LoginDomain; + public LSA_UNICODE_STRING AuthenticationPackage; + public SECURITY_LOGON_TYPE LogonType; + public UInt32 Session; + public IntPtr Sid; + public UInt64 LoginTime; + public LSA_UNICODE_STRING LogonServer; + public LSA_UNICODE_STRING DnsDomainName; + public LSA_UNICODE_STRING Upn; + public UInt32 UserFlags; + public LSA_LAST_INTER_LOGON_INFO LastLogonInfo; + public LSA_UNICODE_STRING LogonScript; + public LSA_UNICODE_STRING ProfilePath; + public LSA_UNICODE_STRING HomeDirectory; + public LSA_UNICODE_STRING HomeDirectoryDrive; + public UInt64 LogoffTime; + public UInt64 KickOffTime; + public UInt64 PasswordLastSet; + public UInt64 PasswordCanChange; + public UInt64 PasswordMustChange; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LSA_LAST_INTER_LOGON_INFO + { + public UInt64 LastSuccessfulLogon; + public UInt64 LastFailedLogon; + public UInt32 FailedAttemptCountSinceLastSuccessfulLogon; + } + + public enum TOKEN_TYPE + { + TokenPrimary = 1, + TokenImpersonation + } + + public enum SECURITY_IMPERSONATION_LEVEL + { + SecurityAnonymous, + SecurityIdentification, + SecurityImpersonation, + SecurityDelegation + } + + public enum SECURITY_LOGON_TYPE + { + System = 0, // Used only by the Sytem account + Interactive = 2, + Network, + Batch, + Service, + Proxy, + Unlock, + NetworkCleartext, + NewCredentials, + RemoteInteractive, + CachedInteractive, + CachedRemoteInteractive, + CachedUnlock + } + + [Flags] + public enum TokenGroupAttributes : uint + { + SE_GROUP_ENABLED = 0x00000004, + SE_GROUP_ENABLED_BY_DEFAULT = 0x00000002, + SE_GROUP_INTEGRITY = 0x00000020, + SE_GROUP_INTEGRITY_ENABLED = 0x00000040, + SE_GROUP_LOGON_ID = 0xC0000000, + SE_GROUP_MANDATORY = 0x00000001, + SE_GROUP_OWNER = 0x00000008, + SE_GROUP_RESOURCE = 0x20000000, + SE_GROUP_USE_FOR_DENY_ONLY = 0x00000010, + } + + [Flags] + public enum UserFlags : uint + { + LOGON_OPTIMIZED = 0x4000, + LOGON_WINLOGON = 0x8000, + LOGON_PKINIT = 0x10000, + LOGON_NOT_OPTMIZED = 0x20000, + } + + [StructLayout(LayoutKind.Sequential)] + public struct SID_AND_ATTRIBUTES + { + public IntPtr Sid; + public UInt32 Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LUID_AND_ATTRIBUTES + { + public LUID Luid; + public UInt32 Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_GROUPS + { + public UInt32 GroupCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public SID_AND_ATTRIBUTES[] Groups; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_MANDATORY_LABEL + { + public SID_AND_ATTRIBUTES Label; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_STATISTICS + { + public LUID TokenId; + public LUID AuthenticationId; + public UInt64 ExpirationTime; + public TOKEN_TYPE TokenType; + public SECURITY_IMPERSONATION_LEVEL ImpersonationLevel; + public UInt32 DynamicCharged; + public UInt32 DynamicAvailable; + public UInt32 GroupCount; + public UInt32 PrivilegeCount; + public LUID ModifiedId; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_PRIVILEGES + { + public UInt32 PrivilegeCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public LUID_AND_ATTRIBUTES[] Privileges; + } + + public class AccessToken : IDisposable + { + public enum TOKEN_INFORMATION_CLASS + { + TokenUser = 1, + TokenGroups, + TokenPrivileges, + TokenOwner, + TokenPrimaryGroup, + TokenDefaultDacl, + TokenSource, + TokenType, + TokenImpersonationLevel, + TokenStatistics, + TokenRestrictedSids, + TokenSessionId, + TokenGroupsAndPrivileges, + TokenSessionReference, + TokenSandBoxInert, + TokenAuditPolicy, + TokenOrigin, + TokenElevationType, + TokenLinkedToken, + TokenElevation, + TokenHasRestrictions, + TokenAccessInformation, + TokenVirtualizationAllowed, + TokenVirtualizationEnabled, + TokenIntegrityLevel, + TokenUIAccess, + TokenMandatoryPolicy, + TokenLogonSid, + TokenIsAppContainer, + TokenCapabilities, + TokenAppContainerSid, + TokenAppContainerNumber, + TokenUserClaimAttributes, + TokenDeviceClaimAttributes, + TokenRestrictedUserClaimAttributes, + TokenRestrictedDeviceClaimAttributes, + TokenDeviceGroups, + TokenRestrictedDeviceGroups, + TokenSecurityAttributes, + TokenIsRestricted, + MaxTokenInfoClass + } + + public IntPtr hToken = IntPtr.Zero; + + [DllImport("kernel32.dll")] + private static extern IntPtr GetCurrentProcess(); + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern bool OpenProcessToken( + IntPtr ProcessHandle, + TokenAccessLevels DesiredAccess, + out IntPtr TokenHandle); + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern bool GetTokenInformation( + IntPtr TokenHandle, + TOKEN_INFORMATION_CLASS TokenInformationClass, + IntPtr TokenInformation, + UInt32 TokenInformationLength, + out UInt32 ReturnLength); + + public AccessToken(TokenAccessLevels tokenAccessLevels) + { + IntPtr currentProcess = GetCurrentProcess(); + if (!OpenProcessToken(currentProcess, tokenAccessLevels, out hToken)) + throw new Win32Exception("OpenProcessToken() for current process failed"); + } + + public IntPtr GetTokenInformation<T>(out T tokenInformation, TOKEN_INFORMATION_CLASS tokenClass) + { + UInt32 tokenLength = 0; + GetTokenInformation(hToken, tokenClass, IntPtr.Zero, 0, out tokenLength); + + IntPtr infoPtr = Marshal.AllocHGlobal((int)tokenLength); + + if (!GetTokenInformation(hToken, tokenClass, infoPtr, tokenLength, out tokenLength)) + throw new Win32Exception(String.Format("GetTokenInformation() data for {0} failed", tokenClass.ToString())); + + tokenInformation = (T)Marshal.PtrToStructure(infoPtr, typeof(T)); + return infoPtr; + } + + public void Dispose() + { + GC.SuppressFinalize(this); + } + + ~AccessToken() { Dispose(); } + } + + public class LsaHandle : IDisposable + { + [Flags] + public enum DesiredAccess : uint + { + POLICY_VIEW_LOCAL_INFORMATION = 0x00000001, + POLICY_VIEW_AUDIT_INFORMATION = 0x00000002, + POLICY_GET_PRIVATE_INFORMATION = 0x00000004, + POLICY_TRUST_ADMIN = 0x00000008, + POLICY_CREATE_ACCOUNT = 0x00000010, + POLICY_CREATE_SECRET = 0x00000020, + POLICY_CREATE_PRIVILEGE = 0x00000040, + POLICY_SET_DEFAULT_QUOTA_LIMITS = 0x00000080, + POLICY_SET_AUDIT_REQUIREMENTS = 0x00000100, + POLICY_AUDIT_LOG_ADMIN = 0x00000200, + POLICY_SERVER_ADMIN = 0x00000400, + POLICY_LOOKUP_NAMES = 0x00000800, + POLICY_NOTIFICATION = 0x00001000 + } + + public IntPtr handle = IntPtr.Zero; + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern uint LsaOpenPolicy( + LSA_UNICODE_STRING[] SystemName, + ref LSA_OBJECT_ATTRIBUTES ObjectAttributes, + DesiredAccess AccessMask, + out IntPtr PolicyHandle); + + [DllImport("advapi32.dll", SetLastError = true)] + private static extern uint LsaClose( + IntPtr ObjectHandle); + + [DllImport("advapi32.dll", SetLastError = false)] + private static extern int LsaNtStatusToWinError( + uint Status); + + [StructLayout(LayoutKind.Sequential)] + public struct LSA_OBJECT_ATTRIBUTES + { + public int Length; + public IntPtr RootDirectory; + public IntPtr ObjectName; + public int Attributes; + public IntPtr SecurityDescriptor; + public IntPtr SecurityQualityOfService; + } + + public LsaHandle(DesiredAccess desiredAccess) + { + LSA_OBJECT_ATTRIBUTES lsaAttr; + lsaAttr.RootDirectory = IntPtr.Zero; + lsaAttr.ObjectName = IntPtr.Zero; + lsaAttr.Attributes = 0; + lsaAttr.SecurityDescriptor = IntPtr.Zero; + lsaAttr.SecurityQualityOfService = IntPtr.Zero; + lsaAttr.Length = Marshal.SizeOf(typeof(LSA_OBJECT_ATTRIBUTES)); + LSA_UNICODE_STRING[] system = new LSA_UNICODE_STRING[1]; + system[0].buffer = IntPtr.Zero; + + uint res = LsaOpenPolicy(system, ref lsaAttr, desiredAccess, out handle); + if (res != 0) + throw new Win32Exception(LsaNtStatusToWinError(res), "LsaOpenPolicy() failed"); + } + + public void Dispose() + { + if (handle != IntPtr.Zero) + { + LsaClose(handle); + handle = IntPtr.Zero; + } + GC.SuppressFinalize(this); + } + + ~LsaHandle() { Dispose(); } + } + + public class Sid + { + public string SidString { get; internal set; } + public string DomainName { get; internal set; } + public string AccountName { get; internal set; } + public SID_NAME_USE SidType { get; internal set; } + + public enum SID_NAME_USE + { + SidTypeUser = 1, + SidTypeGroup, + SidTypeDomain, + SidTypeAlias, + SidTypeWellKnownGroup, + SidTypeDeletedAccount, + SidTypeInvalid, + SidTypeUnknown, + SidTypeComputer, + SidTypeLabel, + SidTypeLogon, + } + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + private static extern bool LookupAccountSid( + string lpSystemName, + [MarshalAs(UnmanagedType.LPArray)] + byte[] Sid, + StringBuilder lpName, + ref UInt32 cchName, + StringBuilder ReferencedDomainName, + ref UInt32 cchReferencedDomainName, + out SID_NAME_USE peUse); + + public Sid(IntPtr sidPtr) + { + SecurityIdentifier sid; + try + { + sid = new SecurityIdentifier(sidPtr); + } + catch (Exception e) + { + throw new ArgumentException(String.Format("Failed to cast IntPtr to SecurityIdentifier: {0}", e)); + } + + SetSidInfo(sid); + } + + public Sid(SecurityIdentifier sid) + { + SetSidInfo(sid); + } + + public override string ToString() + { + return SidString; + } + + private void SetSidInfo(SecurityIdentifier sid) + { + byte[] sidBytes = new byte[sid.BinaryLength]; + sid.GetBinaryForm(sidBytes, 0); + + StringBuilder lpName = new StringBuilder(); + UInt32 cchName = 0; + StringBuilder referencedDomainName = new StringBuilder(); + UInt32 cchReferencedDomainName = 0; + SID_NAME_USE peUse; + LookupAccountSid(null, sidBytes, lpName, ref cchName, referencedDomainName, ref cchReferencedDomainName, out peUse); + + lpName.EnsureCapacity((int)cchName); + referencedDomainName.EnsureCapacity((int)cchReferencedDomainName); + + SidString = sid.ToString(); + if (!LookupAccountSid(null, sidBytes, lpName, ref cchName, referencedDomainName, ref cchReferencedDomainName, out peUse)) + { + int lastError = Marshal.GetLastWin32Error(); + + if (lastError != 1332 && lastError != 1789) // Fails to lookup Logon Sid + { + throw new Win32Exception(lastError, String.Format("LookupAccountSid() failed for SID: {0} {1}", sid.ToString(), lastError)); + } + else if (SidString.StartsWith("S-1-5-5-")) + { + AccountName = String.Format("LogonSessionId_{0}", SidString.Substring(8)); + DomainName = "NT AUTHORITY"; + SidType = SID_NAME_USE.SidTypeLogon; + } + else + { + AccountName = null; + DomainName = null; + SidType = SID_NAME_USE.SidTypeUnknown; + } + } + else + { + AccountName = lpName.ToString(); + DomainName = referencedDomainName.ToString(); + SidType = peUse; + } + } + } + + public class SessionUtil + { + [DllImport("secur32.dll", SetLastError = false)] + private static extern uint LsaFreeReturnBuffer( + IntPtr Buffer); + + [DllImport("secur32.dll", SetLastError = false)] + private static extern uint LsaEnumerateLogonSessions( + out UInt64 LogonSessionCount, + out IntPtr LogonSessionList); + + [DllImport("secur32.dll", SetLastError = false)] + private static extern uint LsaGetLogonSessionData( + IntPtr LogonId, + out IntPtr ppLogonSessionData); + + [DllImport("advapi32.dll", SetLastError = false)] + private static extern int LsaNtStatusToWinError( + uint Status); + + [DllImport("advapi32", SetLastError = true)] + private static extern uint LsaEnumerateAccountRights( + IntPtr PolicyHandle, + IntPtr AccountSid, + out IntPtr UserRights, + out UInt64 CountOfRights); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + private static extern bool LookupPrivilegeName( + string lpSystemName, + ref LUID lpLuid, + StringBuilder lpName, + ref UInt32 cchName); + + private const UInt32 SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001; + private const UInt32 SE_PRIVILEGE_ENABLED = 0x00000002; + private const UInt32 STATUS_OBJECT_NAME_NOT_FOUND = 0xC0000034; + private const UInt32 STATUS_ACCESS_DENIED = 0xC0000022; + + public static SessionInfo GetSessionInfo() + { + AccessToken accessToken = new AccessToken(TokenAccessLevels.Query); + + // Get Privileges + Hashtable privilegeInfo = new Hashtable(); + TOKEN_PRIVILEGES privileges; + IntPtr privilegesPtr = accessToken.GetTokenInformation(out privileges, AccessToken.TOKEN_INFORMATION_CLASS.TokenPrivileges); + LUID_AND_ATTRIBUTES[] luidAndAttributes = new LUID_AND_ATTRIBUTES[privileges.PrivilegeCount]; + try + { + PtrToStructureArray(luidAndAttributes, privilegesPtr.ToInt64() + Marshal.SizeOf(privileges.PrivilegeCount)); + } + finally + { + Marshal.FreeHGlobal(privilegesPtr); + } + foreach (LUID_AND_ATTRIBUTES luidAndAttribute in luidAndAttributes) + { + LUID privLuid = luidAndAttribute.Luid; + UInt32 privNameLen = 0; + StringBuilder privName = new StringBuilder(); + LookupPrivilegeName(null, ref privLuid, null, ref privNameLen); + privName.EnsureCapacity((int)(privNameLen + 1)); + if (!LookupPrivilegeName(null, ref privLuid, privName, ref privNameLen)) + throw new Win32Exception("LookupPrivilegeName() failed"); + + string state = "disabled"; + if ((luidAndAttribute.Attributes & SE_PRIVILEGE_ENABLED) == SE_PRIVILEGE_ENABLED) + state = "enabled"; + if ((luidAndAttribute.Attributes & SE_PRIVILEGE_ENABLED_BY_DEFAULT) == SE_PRIVILEGE_ENABLED_BY_DEFAULT) + state = "enabled-by-default"; + privilegeInfo.Add(privName.ToString(), state); + } + + // Get Current Process LogonSID, User Rights and Groups + ArrayList userRights = new ArrayList(); + ArrayList userGroups = new ArrayList(); + TOKEN_GROUPS groups; + IntPtr groupsPtr = accessToken.GetTokenInformation(out groups, AccessToken.TOKEN_INFORMATION_CLASS.TokenGroups); + SID_AND_ATTRIBUTES[] sidAndAttributes = new SID_AND_ATTRIBUTES[groups.GroupCount]; + LsaHandle lsaHandle = null; + // We can only get rights if we are an admin + if (new WindowsPrincipal(WindowsIdentity.GetCurrent()).IsInRole(WindowsBuiltInRole.Administrator)) + lsaHandle = new LsaHandle(LsaHandle.DesiredAccess.POLICY_LOOKUP_NAMES); + try + { + PtrToStructureArray(sidAndAttributes, groupsPtr.ToInt64() + IntPtr.Size); + foreach (SID_AND_ATTRIBUTES sidAndAttribute in sidAndAttributes) + { + TokenGroupAttributes attributes = (TokenGroupAttributes)sidAndAttribute.Attributes; + if (attributes.HasFlag(TokenGroupAttributes.SE_GROUP_ENABLED) && lsaHandle != null) + { + ArrayList rights = GetAccountRights(lsaHandle.handle, sidAndAttribute.Sid); + foreach (string right in rights) + { + // Includes both Privileges and Account Rights, only add the ones with Logon in the name + // https://msdn.microsoft.com/en-us/library/windows/desktop/bb545671(v=vs.85).aspx + if (!userRights.Contains(right) && right.Contains("Logon")) + userRights.Add(right); + } + } + // Do not include the Logon SID in the groups category + if (!attributes.HasFlag(TokenGroupAttributes.SE_GROUP_LOGON_ID)) + { + Hashtable groupInfo = new Hashtable(); + Sid group = new Sid(sidAndAttribute.Sid); + ArrayList groupAttributes = new ArrayList(); + foreach (TokenGroupAttributes attribute in Enum.GetValues(typeof(TokenGroupAttributes))) + { + if (attributes.HasFlag(attribute)) + { + string attributeName = attribute.ToString().Substring(9); + attributeName = attributeName.Replace('_', ' '); + attributeName = attributeName.First().ToString().ToUpper() + attributeName.Substring(1).ToLower(); + groupAttributes.Add(attributeName); + } + } + // Using snake_case here as I can't generically convert all dict keys in PS (see Privileges) + groupInfo.Add("sid", group.SidString); + groupInfo.Add("domain_name", group.DomainName); + groupInfo.Add("account_name", group.AccountName); + groupInfo.Add("type", group.SidType); + groupInfo.Add("attributes", groupAttributes); + userGroups.Add(groupInfo); + } + } + } + finally + { + Marshal.FreeHGlobal(groupsPtr); + if (lsaHandle != null) + lsaHandle.Dispose(); + } + + // Get Integrity Level + Sid integritySid = null; + TOKEN_MANDATORY_LABEL mandatoryLabel; + IntPtr mandatoryLabelPtr = accessToken.GetTokenInformation(out mandatoryLabel, AccessToken.TOKEN_INFORMATION_CLASS.TokenIntegrityLevel); + Marshal.FreeHGlobal(mandatoryLabelPtr); + integritySid = new Sid(mandatoryLabel.Label.Sid); + + // Get Token Statistics + TOKEN_STATISTICS tokenStats; + IntPtr tokenStatsPtr = accessToken.GetTokenInformation(out tokenStats, AccessToken.TOKEN_INFORMATION_CLASS.TokenStatistics); + Marshal.FreeHGlobal(tokenStatsPtr); + + SessionInfo sessionInfo = GetSessionDataForLogonSession(tokenStats.AuthenticationId); + sessionInfo.Groups = userGroups; + sessionInfo.Label = integritySid; + sessionInfo.ImpersonationLevel = tokenStats.ImpersonationLevel; + sessionInfo.TokenType = tokenStats.TokenType; + sessionInfo.Privileges = privilegeInfo; + sessionInfo.Rights = userRights; + return sessionInfo; + } + + private static ArrayList GetAccountRights(IntPtr lsaHandle, IntPtr sid) + { + UInt32 res; + ArrayList rights = new ArrayList(); + IntPtr userRightsPointer = IntPtr.Zero; + UInt64 countOfRights = 0; + + res = LsaEnumerateAccountRights(lsaHandle, sid, out userRightsPointer, out countOfRights); + if (res != 0 && res != STATUS_OBJECT_NAME_NOT_FOUND) + throw new Win32Exception(LsaNtStatusToWinError(res), "LsaEnumerateAccountRights() failed"); + else if (res != STATUS_OBJECT_NAME_NOT_FOUND) + { + LSA_UNICODE_STRING[] userRights = new LSA_UNICODE_STRING[countOfRights]; + PtrToStructureArray(userRights, userRightsPointer.ToInt64()); + rights = new ArrayList(); + foreach (LSA_UNICODE_STRING right in userRights) + rights.Add(Marshal.PtrToStringUni(right.buffer)); + } + + return rights; + } + + private static SessionInfo GetSessionDataForLogonSession(LUID logonSession) + { + uint res; + UInt64 count = 0; + IntPtr luidPtr = IntPtr.Zero; + SessionInfo sessionInfo = null; + UInt64 processDataId = ConvertLuidToUint(logonSession); + + res = LsaEnumerateLogonSessions(out count, out luidPtr); + if (res != 0) + throw new Win32Exception(LsaNtStatusToWinError(res), "LsaEnumerateLogonSessions() failed"); + Int64 luidAddr = luidPtr.ToInt64(); + + try + { + for (UInt64 i = 0; i < count; i++) + { + IntPtr dataPointer = IntPtr.Zero; + res = LsaGetLogonSessionData(luidPtr, out dataPointer); + if (res == STATUS_ACCESS_DENIED) // Non admins won't be able to get info for session's that are not their own + { + luidPtr = new IntPtr(luidPtr.ToInt64() + Marshal.SizeOf(typeof(LUID))); + continue; + } + else if (res != 0) + throw new Win32Exception(LsaNtStatusToWinError(res), String.Format("LsaGetLogonSessionData() failed {0}", res)); + + SECURITY_LOGON_SESSION_DATA sessionData = (SECURITY_LOGON_SESSION_DATA)Marshal.PtrToStructure(dataPointer, typeof(SECURITY_LOGON_SESSION_DATA)); + UInt64 sessionDataid = ConvertLuidToUint(sessionData.LogonId); + + if (sessionDataid == processDataId) + { + ArrayList userFlags = new ArrayList(); + UserFlags flags = (UserFlags)sessionData.UserFlags; + foreach (UserFlags flag in Enum.GetValues(typeof(UserFlags))) + { + if (flags.HasFlag(flag)) + { + string flagName = flag.ToString().Substring(6); + flagName = flagName.Replace('_', ' '); + flagName = flagName.First().ToString().ToUpper() + flagName.Substring(1).ToLower(); + userFlags.Add(flagName); + } + } + + sessionInfo = new SessionInfo() + { + AuthenticationPackage = Marshal.PtrToStringUni(sessionData.AuthenticationPackage.buffer), + DnsDomainName = Marshal.PtrToStringUni(sessionData.DnsDomainName.buffer), + LoginDomain = Marshal.PtrToStringUni(sessionData.LoginDomain.buffer), + LoginTime = ConvertIntegerToDateString(sessionData.LoginTime), + LogonId = ConvertLuidToUint(sessionData.LogonId), + LogonServer = Marshal.PtrToStringUni(sessionData.LogonServer.buffer), + LogonType = sessionData.LogonType, + Upn = Marshal.PtrToStringUni(sessionData.Upn.buffer), + UserFlags = userFlags, + Account = new Sid(sessionData.Sid) + }; + break; + } + luidPtr = new IntPtr(luidPtr.ToInt64() + Marshal.SizeOf(typeof(LUID))); + } + } + finally + { + LsaFreeReturnBuffer(new IntPtr(luidAddr)); + } + + if (sessionInfo == null) + throw new Exception(String.Format("Could not find the data for logon session {0}", processDataId)); + return sessionInfo; + } + + private static string ConvertIntegerToDateString(UInt64 time) + { + if (time == 0) + return null; + if (time > (UInt64)DateTime.MaxValue.ToFileTime()) + return null; + + DateTime dateTime = DateTime.FromFileTime((long)time); + return dateTime.ToString("o"); + } + + private static UInt64 ConvertLuidToUint(LUID luid) + { + UInt32 low = luid.LowPart; + UInt64 high = (UInt64)luid.HighPart; + high = high << 32; + UInt64 uintValue = (high | (UInt64)low); + return uintValue; + } + + private static void PtrToStructureArray<T>(T[] array, Int64 pointerAddress) + { + Int64 pointerOffset = pointerAddress; + for (int i = 0; i < array.Length; i++, pointerOffset += Marshal.SizeOf(typeof(T))) + array[i] = (T)Marshal.PtrToStructure(new IntPtr(pointerOffset), typeof(T)); + } + + public static IEnumerable<T> GetValues<T>() + { + return Enum.GetValues(typeof(T)).Cast<T>(); + } + } +} +'@ + +$original_tmp = $env:TMP +$env:TMP = $_remote_tmp +Add-Type -TypeDefinition $session_util +$env:TMP = $original_tmp + +$session_info = [Ansible.SessionUtil]::GetSessionInfo() + +Function Convert-Value($value) { + $new_value = $value + if ($value -is [System.Collections.ArrayList]) { + $new_value = [System.Collections.ArrayList]@() + foreach ($list_value in $value) { + $new_list_value = Convert-Value -value $list_value + [void]$new_value.Add($new_list_value) + } + } elseif ($value -is [Hashtable]) { + $new_value = @{} + foreach ($entry in $value.GetEnumerator()) { + $entry_value = Convert-Value -value $entry.Value + # manually convert Sid type entry to remove the SidType prefix + if ($entry.Name -eq "type") { + $entry_value = $entry_value.Replace("SidType", "") + } + $new_value[$entry.Name] = $entry_value + } + } elseif ($value -is [Ansible.Sid]) { + $new_value = @{ + sid = $value.SidString + account_name = $value.AccountName + domain_name = $value.DomainName + type = $value.SidType.ToString().Replace("SidType", "") + } + } elseif ($value -is [Enum]) { + $new_value = $value.ToString() + } + + return ,$new_value +} + +$result = @{ + changed = $false +} + +$properties = [type][Ansible.SessionInfo] +foreach ($property in $properties.DeclaredProperties) { + $property_name = $property.Name + $property_value = $session_info.$property_name + $snake_name = Convert-StringToSnakeCase -string $property_name + + $result.$snake_name = Convert-Value -value $property_value +} + +Exit-Json -obj $result diff --git a/test/support/windows-integration/plugins/modules/win_whoami.py b/test/support/windows-integration/plugins/modules/win_whoami.py new file mode 100644 index 00000000..d647374b --- /dev/null +++ b/test/support/windows-integration/plugins/modules/win_whoami.py @@ -0,0 +1,203 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# this is a windows documentation stub. actual code lives in the .ps1 +# file of the same name + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: win_whoami +version_added: "2.5" +short_description: Get information about the current user and process +description: +- Designed to return the same information as the C(whoami /all) command. +- Also includes information missing from C(whoami) such as logon metadata like + logon rights, id, type. +notes: +- If running this module with a non admin user, the logon rights will be an + empty list as Administrator rights are required to query LSA for the + information. +seealso: +- module: win_credential +- module: win_group_membership +- module: win_user_right +author: +- Jordan Borean (@jborean93) +''' + +EXAMPLES = r''' +- name: Get whoami information + win_whoami: +''' + +RETURN = r''' +authentication_package: + description: The name of the authentication package used to authenticate the + user in the session. + returned: success + type: str + sample: Negotiate +user_flags: + description: The user flags for the logon session, see UserFlags in + U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa380128). + returned: success + type: str + sample: Winlogon +upn: + description: The user principal name of the current user. + returned: success + type: str + sample: Administrator@DOMAIN.COM +logon_type: + description: The logon type that identifies the logon method, see + U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa380129.aspx). + returned: success + type: str + sample: Network +privileges: + description: A dictionary of privileges and their state on the logon token. + returned: success + type: dict + sample: { + "SeChangeNotifyPrivileges": "enabled-by-default", + "SeRemoteShutdownPrivilege": "disabled", + "SeDebugPrivilege": "enabled" + } +label: + description: The mandatory label set to the logon session. + returned: success + type: complex + contains: + domain_name: + description: The domain name of the label SID. + returned: success + type: str + sample: Mandatory Label + sid: + description: The SID in string form. + returned: success + type: str + sample: S-1-16-12288 + account_name: + description: The account name of the label SID. + returned: success + type: str + sample: High Mandatory Level + type: + description: The type of SID. + returned: success + type: str + sample: Label +impersonation_level: + description: The impersonation level of the token, only valid if + C(token_type) is C(TokenImpersonation), see + U(https://msdn.microsoft.com/en-us/library/windows/desktop/aa379572.aspx). + returned: success + type: str + sample: SecurityAnonymous +login_time: + description: The logon time in ISO 8601 format + returned: success + type: str + sample: '2017-11-27T06:24:14.3321665+10:00' +groups: + description: A list of groups and attributes that the user is a member of. + returned: success + type: list + sample: [ + { + "account_name": "Domain Users", + "domain_name": "DOMAIN", + "attributes": [ + "Mandatory", + "Enabled by default", + "Enabled" + ], + "sid": "S-1-5-21-1654078763-769949647-2968445802-513", + "type": "Group" + }, + { + "account_name": "Administrators", + "domain_name": "BUILTIN", + "attributes": [ + "Mandatory", + "Enabled by default", + "Enabled", + "Owner" + ], + "sid": "S-1-5-32-544", + "type": "Alias" + } + ] +account: + description: The running account SID details. + returned: success + type: complex + contains: + domain_name: + description: The domain name of the account SID. + returned: success + type: str + sample: DOMAIN + sid: + description: The SID in string form. + returned: success + type: str + sample: S-1-5-21-1654078763-769949647-2968445802-500 + account_name: + description: The account name of the account SID. + returned: success + type: str + sample: Administrator + type: + description: The type of SID. + returned: success + type: str + sample: User +login_domain: + description: The name of the domain used to authenticate the owner of the + session. + returned: success + type: str + sample: DOMAIN +rights: + description: A list of logon rights assigned to the logon. + returned: success and running user is a member of the local Administrators group + type: list + sample: [ + "SeNetworkLogonRight", + "SeInteractiveLogonRight", + "SeBatchLogonRight", + "SeRemoteInteractiveLogonRight" + ] +logon_server: + description: The name of the server used to authenticate the owner of the + logon session. + returned: success + type: str + sample: DC01 +logon_id: + description: The unique identifier of the logon session. + returned: success + type: int + sample: 20470143 +dns_domain_name: + description: The DNS name of the logon session, this is an empty string if + this is not set. + returned: success + type: str + sample: DOMAIN.COM +token_type: + description: The token type to indicate whether it is a primary or + impersonation token. + returned: success + type: str + sample: TokenPrimary +''' |