Diffstat (limited to 'test/support/integration')
91 files changed, 40328 insertions, 0 deletions
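For context before the diff itself: the first file added below is the jsonfile cache plugin, which is normally enabled through ansible.cfg using the [defaults] ini keys listed in its DOCUMENTATION block. The sketch below is illustrative and not part of this commit; the cache directory and prefix are placeholder values.

```ini
; Minimal sketch: enable the jsonfile cache plugin for fact caching.
; Keys correspond to fact_caching_connection / _prefix / _timeout in the
; plugin's DOCUMENTATION; the path and prefix here are placeholders.
[defaults]
fact_caching = jsonfile
; directory where the per-host JSON files are written
fact_caching_connection = /tmp/ansible_fact_cache
fact_caching_prefix = ansible_facts_
; expiration in seconds; 0 means never expire
fact_caching_timeout = 86400
```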
diff --git a/test/support/integration/plugins/cache/jsonfile.py b/test/support/integration/plugins/cache/jsonfile.py new file mode 100644 index 00000000..80b16f55 --- /dev/null +++ b/test/support/integration/plugins/cache/jsonfile.py @@ -0,0 +1,63 @@ +# (c) 2014, Brian Coca, Josh Drake, et al +# (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +# Make coding more python3-ish +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = ''' + cache: jsonfile + short_description: JSON formatted files. + description: + - This cache uses JSON formatted, per host, files saved to the filesystem. + version_added: "1.9" + author: Ansible Core (@ansible-core) + options: + _uri: + required: True + description: + - Path in which the cache plugin will save the JSON files + env: + - name: ANSIBLE_CACHE_PLUGIN_CONNECTION + ini: + - key: fact_caching_connection + section: defaults + _prefix: + description: User defined prefix to use when creating the JSON files + env: + - name: ANSIBLE_CACHE_PLUGIN_PREFIX + ini: + - key: fact_caching_prefix + section: defaults + _timeout: + default: 86400 + description: Expiration timeout in seconds for the cache plugin data. Set to 0 to never expire + env: + - name: ANSIBLE_CACHE_PLUGIN_TIMEOUT + ini: + - key: fact_caching_timeout + section: defaults + type: integer +''' + +import codecs +import json + +from ansible.parsing.ajson import AnsibleJSONEncoder, AnsibleJSONDecoder +from ansible.plugins.cache import BaseFileCacheModule + + +class CacheModule(BaseFileCacheModule): + """ + A caching module backed by json files. + """ + + def _load(self, filepath): + # Valid JSON is always UTF-8 encoded. + with codecs.open(filepath, 'r', encoding='utf-8') as f: + return json.load(f, cls=AnsibleJSONDecoder) + + def _dump(self, value, filepath): + with codecs.open(filepath, 'w', encoding='utf-8') as f: + f.write(json.dumps(value, cls=AnsibleJSONEncoder, sort_keys=True, indent=4)) diff --git a/test/support/integration/plugins/filter/json_query.py b/test/support/integration/plugins/filter/json_query.py new file mode 100644 index 00000000..d1da71b4 --- /dev/null +++ b/test/support/integration/plugins/filter/json_query.py @@ -0,0 +1,53 @@ +# (c) 2015, Filipe Niero Felisbino <filipenf@gmail.com> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +from ansible.errors import AnsibleError, AnsibleFilterError + +try: + import jmespath + HAS_LIB = True +except ImportError: + HAS_LIB = False + + +def json_query(data, expr): + '''Query data using jmespath query language ( http://jmespath.org ). 
Example: + - debug: msg="{{ instance | json_query(tagged_instances[*].block_device_mapping.*.volume_id') }}" + ''' + if not HAS_LIB: + raise AnsibleError('You need to install "jmespath" prior to running ' + 'json_query filter') + + try: + return jmespath.search(expr, data) + except jmespath.exceptions.JMESPathError as e: + raise AnsibleFilterError('JMESPathError in json_query filter plugin:\n%s' % e) + except Exception as e: + # For older jmespath, we can get ValueError and TypeError without much info. + raise AnsibleFilterError('Error in jmespath.search in json_query filter plugin:\n%s' % e) + + +class FilterModule(object): + ''' Query filter ''' + + def filters(self): + return { + 'json_query': json_query + } diff --git a/test/support/integration/plugins/inventory/aws_ec2.py b/test/support/integration/plugins/inventory/aws_ec2.py new file mode 100644 index 00000000..09c42cf9 --- /dev/null +++ b/test/support/integration/plugins/inventory/aws_ec2.py @@ -0,0 +1,760 @@ +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = ''' + name: aws_ec2 + plugin_type: inventory + short_description: EC2 inventory source + requirements: + - boto3 + - botocore + extends_documentation_fragment: + - inventory_cache + - constructed + description: + - Get inventory hosts from Amazon Web Services EC2. + - Uses a YAML configuration file that ends with C(aws_ec2.(yml|yaml)). + notes: + - If no credentials are provided and the control node has an associated IAM instance profile then the + role will be used for authentication. + author: + - Sloane Hertel (@s-hertel) + options: + aws_profile: + description: The AWS profile + type: str + aliases: [ boto_profile ] + env: + - name: AWS_DEFAULT_PROFILE + - name: AWS_PROFILE + aws_access_key: + description: The AWS access key to use. + type: str + aliases: [ aws_access_key_id ] + env: + - name: EC2_ACCESS_KEY + - name: AWS_ACCESS_KEY + - name: AWS_ACCESS_KEY_ID + aws_secret_key: + description: The AWS secret key that corresponds to the access key. + type: str + aliases: [ aws_secret_access_key ] + env: + - name: EC2_SECRET_KEY + - name: AWS_SECRET_KEY + - name: AWS_SECRET_ACCESS_KEY + aws_security_token: + description: The AWS security token if using temporary access and secret keys. + type: str + env: + - name: EC2_SECURITY_TOKEN + - name: AWS_SESSION_TOKEN + - name: AWS_SECURITY_TOKEN + plugin: + description: Token that ensures this is a source file for the plugin. + required: True + choices: ['aws_ec2'] + iam_role_arn: + description: The ARN of the IAM role to assume to perform the inventory lookup. You should still provide AWS + credentials with enough privilege to perform the AssumeRole action. + version_added: '2.9' + regions: + description: + - A list of regions in which to describe EC2 instances. + - If empty (the default) default this will include all regions, except possibly restricted ones like us-gov-west-1 and cn-north-1. + type: list + default: [] + hostnames: + description: + - A list in order of precedence for hostname variables. + - You can use the options specified in U(http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options). + - To use tags as hostnames use the syntax tag:Name=Value to use the hostname Name_Value, or tag:Name to use the value of the Name tag. 
+ type: list + default: [] + filters: + description: + - A dictionary of filter value pairs. + - Available filters are listed here U(http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options). + type: dict + default: {} + include_extra_api_calls: + description: + - Add two additional API calls for every instance to include 'persistent' and 'events' host variables. + - Spot instances may be persistent and instances may have associated events. + type: bool + default: False + version_added: '2.8' + strict_permissions: + description: + - By default if a 403 (Forbidden) error code is encountered this plugin will fail. + - You can set this option to False in the inventory config file which will allow 403 errors to be gracefully skipped. + type: bool + default: True + use_contrib_script_compatible_sanitization: + description: + - By default this plugin is using a general group name sanitization to create safe and usable group names for use in Ansible. + This option allows you to override that, in efforts to allow migration from the old inventory script and + matches the sanitization of groups when the script's ``replace_dash_in_groups`` option is set to ``False``. + To replicate behavior of ``replace_dash_in_groups = True`` with constructed groups, + you will need to replace hyphens with underscores via the regex_replace filter for those entries. + - For this to work you should also turn off the TRANSFORM_INVALID_GROUP_CHARS setting, + otherwise the core engine will just use the standard sanitization on top. + - This is not the default as such names break certain functionality as not all characters are valid Python identifiers + which group names end up being used as. + type: bool + default: False + version_added: '2.8' +''' + +EXAMPLES = ''' +# Minimal example using environment vars or instance role credentials +# Fetch all hosts in us-east-1, the hostname is the public DNS if it exists, otherwise the private IP address +plugin: aws_ec2 +regions: + - us-east-1 + +# Example using filters, ignoring permission errors, and specifying the hostname precedence +plugin: aws_ec2 +boto_profile: aws_profile +# Populate inventory with instances in these regions +regions: + - us-east-1 + - us-east-2 +filters: + # All instances with their `Environment` tag set to `dev` + tag:Environment: dev + # All dev and QA hosts + tag:Environment: + - dev + - qa + instance.group-id: sg-xxxxxxxx +# Ignores 403 errors rather than failing +strict_permissions: False +# Note: I(hostnames) sets the inventory_hostname. To modify ansible_host without modifying +# inventory_hostname use compose (see example below). +hostnames: + - tag:Name=Tag1,Name=Tag2 # Return specific hosts only + - tag:CustomDNSName + - dns-name + - private-ip-address + +# Example using constructed features to create groups and set ansible_host +plugin: aws_ec2 +regions: + - us-east-1 + - us-west-1 +# keyed_groups may be used to create custom groups +strict: False +keyed_groups: + # Add e.g. x86_64 hosts to an arch_x86_64 group + - prefix: arch + key: 'architecture' + # Add hosts to tag_Name_Value groups for each Name/Value tag pair + - prefix: tag + key: tags + # Add hosts to e.g. instance_type_z3_tiny + - prefix: instance_type + key: instance_type + # Create security_groups_sg_abcd1234 group for each SG + - key: 'security_groups|json_query("[].group_id")' + prefix: 'security_groups' + # Create a group for each value of the Application tag + - key: tags.Application + separator: '' + # Create a group per region e.g. 
aws_region_us_east_2 + - key: placement.region + prefix: aws_region + # Create a group (or groups) based on the value of a custom tag "Role" and add them to a metagroup called "project" + - key: tags['Role'] + prefix: foo + parent_group: "project" +# Set individual variables with compose +compose: + # Use the private IP address to connect to the host + # (note: this does not modify inventory_hostname, which is set via I(hostnames)) + ansible_host: private_ip_address +''' + +import re + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_native, to_text +from ansible.module_utils.common.dict_transformations import camel_dict_to_snake_dict +from ansible.plugins.inventory import BaseInventoryPlugin, Constructable, Cacheable +from ansible.utils.display import Display +from ansible.module_utils.six import string_types + +try: + import boto3 + import botocore +except ImportError: + raise AnsibleError('The ec2 dynamic inventory plugin requires boto3 and botocore.') + +display = Display() + +# The mappings give an array of keys to get from the filter name to the value +# returned by boto3's EC2 describe_instances method. + +instance_meta_filter_to_boto_attr = { + 'group-id': ('Groups', 'GroupId'), + 'group-name': ('Groups', 'GroupName'), + 'network-interface.attachment.instance-owner-id': ('OwnerId',), + 'owner-id': ('OwnerId',), + 'requester-id': ('RequesterId',), + 'reservation-id': ('ReservationId',), +} + +instance_data_filter_to_boto_attr = { + 'affinity': ('Placement', 'Affinity'), + 'architecture': ('Architecture',), + 'availability-zone': ('Placement', 'AvailabilityZone'), + 'block-device-mapping.attach-time': ('BlockDeviceMappings', 'Ebs', 'AttachTime'), + 'block-device-mapping.delete-on-termination': ('BlockDeviceMappings', 'Ebs', 'DeleteOnTermination'), + 'block-device-mapping.device-name': ('BlockDeviceMappings', 'DeviceName'), + 'block-device-mapping.status': ('BlockDeviceMappings', 'Ebs', 'Status'), + 'block-device-mapping.volume-id': ('BlockDeviceMappings', 'Ebs', 'VolumeId'), + 'client-token': ('ClientToken',), + 'dns-name': ('PublicDnsName',), + 'host-id': ('Placement', 'HostId'), + 'hypervisor': ('Hypervisor',), + 'iam-instance-profile.arn': ('IamInstanceProfile', 'Arn'), + 'image-id': ('ImageId',), + 'instance-id': ('InstanceId',), + 'instance-lifecycle': ('InstanceLifecycle',), + 'instance-state-code': ('State', 'Code'), + 'instance-state-name': ('State', 'Name'), + 'instance-type': ('InstanceType',), + 'instance.group-id': ('SecurityGroups', 'GroupId'), + 'instance.group-name': ('SecurityGroups', 'GroupName'), + 'ip-address': ('PublicIpAddress',), + 'kernel-id': ('KernelId',), + 'key-name': ('KeyName',), + 'launch-index': ('AmiLaunchIndex',), + 'launch-time': ('LaunchTime',), + 'monitoring-state': ('Monitoring', 'State'), + 'network-interface.addresses.private-ip-address': ('NetworkInterfaces', 'PrivateIpAddress'), + 'network-interface.addresses.primary': ('NetworkInterfaces', 'PrivateIpAddresses', 'Primary'), + 'network-interface.addresses.association.public-ip': ('NetworkInterfaces', 'PrivateIpAddresses', 'Association', 'PublicIp'), + 'network-interface.addresses.association.ip-owner-id': ('NetworkInterfaces', 'PrivateIpAddresses', 'Association', 'IpOwnerId'), + 'network-interface.association.public-ip': ('NetworkInterfaces', 'Association', 'PublicIp'), + 'network-interface.association.ip-owner-id': ('NetworkInterfaces', 'Association', 'IpOwnerId'), + 'network-interface.association.allocation-id': ('ElasticGpuAssociations', 'ElasticGpuId'), + 
'network-interface.association.association-id': ('ElasticGpuAssociations', 'ElasticGpuAssociationId'), + 'network-interface.attachment.attachment-id': ('NetworkInterfaces', 'Attachment', 'AttachmentId'), + 'network-interface.attachment.instance-id': ('InstanceId',), + 'network-interface.attachment.device-index': ('NetworkInterfaces', 'Attachment', 'DeviceIndex'), + 'network-interface.attachment.status': ('NetworkInterfaces', 'Attachment', 'Status'), + 'network-interface.attachment.attach-time': ('NetworkInterfaces', 'Attachment', 'AttachTime'), + 'network-interface.attachment.delete-on-termination': ('NetworkInterfaces', 'Attachment', 'DeleteOnTermination'), + 'network-interface.availability-zone': ('Placement', 'AvailabilityZone'), + 'network-interface.description': ('NetworkInterfaces', 'Description'), + 'network-interface.group-id': ('NetworkInterfaces', 'Groups', 'GroupId'), + 'network-interface.group-name': ('NetworkInterfaces', 'Groups', 'GroupName'), + 'network-interface.ipv6-addresses.ipv6-address': ('NetworkInterfaces', 'Ipv6Addresses', 'Ipv6Address'), + 'network-interface.mac-address': ('NetworkInterfaces', 'MacAddress'), + 'network-interface.network-interface-id': ('NetworkInterfaces', 'NetworkInterfaceId'), + 'network-interface.owner-id': ('NetworkInterfaces', 'OwnerId'), + 'network-interface.private-dns-name': ('NetworkInterfaces', 'PrivateDnsName'), + # 'network-interface.requester-id': (), + 'network-interface.requester-managed': ('NetworkInterfaces', 'Association', 'IpOwnerId'), + 'network-interface.status': ('NetworkInterfaces', 'Status'), + 'network-interface.source-dest-check': ('NetworkInterfaces', 'SourceDestCheck'), + 'network-interface.subnet-id': ('NetworkInterfaces', 'SubnetId'), + 'network-interface.vpc-id': ('NetworkInterfaces', 'VpcId'), + 'placement-group-name': ('Placement', 'GroupName'), + 'platform': ('Platform',), + 'private-dns-name': ('PrivateDnsName',), + 'private-ip-address': ('PrivateIpAddress',), + 'product-code': ('ProductCodes', 'ProductCodeId'), + 'product-code.type': ('ProductCodes', 'ProductCodeType'), + 'ramdisk-id': ('RamdiskId',), + 'reason': ('StateTransitionReason',), + 'root-device-name': ('RootDeviceName',), + 'root-device-type': ('RootDeviceType',), + 'source-dest-check': ('SourceDestCheck',), + 'spot-instance-request-id': ('SpotInstanceRequestId',), + 'state-reason-code': ('StateReason', 'Code'), + 'state-reason-message': ('StateReason', 'Message'), + 'subnet-id': ('SubnetId',), + 'tag': ('Tags',), + 'tag-key': ('Tags',), + 'tag-value': ('Tags',), + 'tenancy': ('Placement', 'Tenancy'), + 'virtualization-type': ('VirtualizationType',), + 'vpc-id': ('VpcId',), +} + + +class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): + + NAME = 'aws_ec2' + + def __init__(self): + super(InventoryModule, self).__init__() + + self.group_prefix = 'aws_ec2_' + + # credentials + self.boto_profile = None + self.aws_secret_access_key = None + self.aws_access_key_id = None + self.aws_security_token = None + self.iam_role_arn = None + + def _compile_values(self, obj, attr): + ''' + :param obj: A list or dict of instance attributes + :param attr: A key + :return The value(s) found via the attr + ''' + if obj is None: + return + + temp_obj = [] + + if isinstance(obj, list) or isinstance(obj, tuple): + for each in obj: + value = self._compile_values(each, attr) + if value: + temp_obj.append(value) + else: + temp_obj = obj.get(attr) + + has_indexes = any([isinstance(temp_obj, list), isinstance(temp_obj, tuple)]) + if has_indexes and len(temp_obj) == 
1: + return temp_obj[0] + + return temp_obj + + def _get_boto_attr_chain(self, filter_name, instance): + ''' + :param filter_name: The filter + :param instance: instance dict returned by boto3 ec2 describe_instances() + ''' + allowed_filters = sorted(list(instance_data_filter_to_boto_attr.keys()) + list(instance_meta_filter_to_boto_attr.keys())) + if filter_name not in allowed_filters: + raise AnsibleError("Invalid filter '%s' provided; filter must be one of %s." % (filter_name, + allowed_filters)) + if filter_name in instance_data_filter_to_boto_attr: + boto_attr_list = instance_data_filter_to_boto_attr[filter_name] + else: + boto_attr_list = instance_meta_filter_to_boto_attr[filter_name] + + instance_value = instance + for attribute in boto_attr_list: + instance_value = self._compile_values(instance_value, attribute) + return instance_value + + def _get_credentials(self): + ''' + :return A dictionary of boto client credentials + ''' + boto_params = {} + for credential in (('aws_access_key_id', self.aws_access_key_id), + ('aws_secret_access_key', self.aws_secret_access_key), + ('aws_session_token', self.aws_security_token)): + if credential[1]: + boto_params[credential[0]] = credential[1] + + return boto_params + + def _get_connection(self, credentials, region='us-east-1'): + try: + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region, **credentials) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + if self.boto_profile: + try: + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + else: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + return connection + + def _boto3_assume_role(self, credentials, region): + """ + Assume an IAM role passed by iam_role_arn parameter + + :return: a dict containing the credentials of the assumed role + """ + + iam_role_arn = self.iam_role_arn + + try: + sts_connection = boto3.session.Session(profile_name=self.boto_profile).client('sts', region, **credentials) + sts_session = sts_connection.assume_role(RoleArn=iam_role_arn, RoleSessionName='ansible_aws_ec2_dynamic_inventory') + return dict( + aws_access_key_id=sts_session['Credentials']['AccessKeyId'], + aws_secret_access_key=sts_session['Credentials']['SecretAccessKey'], + aws_session_token=sts_session['Credentials']['SessionToken'] + ) + except botocore.exceptions.ClientError as e: + raise AnsibleError("Unable to assume IAM role: %s" % to_native(e)) + + def _boto3_conn(self, regions): + ''' + :param regions: A list of regions to create a boto3 client + + Generator that yields a boto3 client and the region + ''' + + credentials = self._get_credentials() + iam_role_arn = self.iam_role_arn + + if not regions: + try: + # as per https://boto3.amazonaws.com/v1/documentation/api/latest/guide/ec2-example-regions-avail-zones.html + client = self._get_connection(credentials) + resp = client.describe_regions() + regions = [x['RegionName'] for x in resp.get('Regions', [])] + except botocore.exceptions.NoRegionError: + # above seems to fail depending on boto3 version, ignore and lets try something else + pass + + # fallback to local list hardcoded in boto3 if still no regions + if not regions: + session = boto3.Session() + regions = session.get_available_regions('ec2') + + # I give up, now you MUST give 
me regions + if not regions: + raise AnsibleError('Unable to get regions list from available methods, you must specify the "regions" option to continue.') + + for region in regions: + connection = self._get_connection(credentials, region) + try: + if iam_role_arn is not None: + assumed_credentials = self._boto3_assume_role(credentials, region) + else: + assumed_credentials = credentials + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region, **assumed_credentials) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + if self.boto_profile: + try: + connection = boto3.session.Session(profile_name=self.boto_profile).client('ec2', region) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError) as e: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + else: + raise AnsibleError("Insufficient credentials found: %s" % to_native(e)) + yield connection, region + + def _get_instances_by_region(self, regions, filters, strict_permissions): + ''' + :param regions: a list of regions in which to describe instances + :param filters: a list of boto3 filter dictionaries + :param strict_permissions: a boolean determining whether to fail or ignore 403 error codes + :return A list of instance dictionaries + ''' + all_instances = [] + + for connection, region in self._boto3_conn(regions): + try: + # By default find non-terminated/terminating instances + if not any([f['Name'] == 'instance-state-name' for f in filters]): + filters.append({'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopping', 'stopped']}) + paginator = connection.get_paginator('describe_instances') + reservations = paginator.paginate(Filters=filters).build_full_result().get('Reservations') + instances = [] + for r in reservations: + new_instances = r['Instances'] + for instance in new_instances: + instance.update(self._get_reservation_details(r)) + if self.get_option('include_extra_api_calls'): + instance.update(self._get_event_set_and_persistence(connection, instance['InstanceId'], instance.get('SpotInstanceRequestId'))) + instances.extend(new_instances) + except botocore.exceptions.ClientError as e: + if e.response['ResponseMetadata']['HTTPStatusCode'] == 403 and not strict_permissions: + instances = [] + else: + raise AnsibleError("Failed to describe instances: %s" % to_native(e)) + except botocore.exceptions.BotoCoreError as e: + raise AnsibleError("Failed to describe instances: %s" % to_native(e)) + + all_instances.extend(instances) + + return sorted(all_instances, key=lambda x: x['InstanceId']) + + def _get_reservation_details(self, reservation): + return { + 'OwnerId': reservation['OwnerId'], + 'RequesterId': reservation.get('RequesterId', ''), + 'ReservationId': reservation['ReservationId'] + } + + def _get_event_set_and_persistence(self, connection, instance_id, spot_instance): + host_vars = {'Events': '', 'Persistent': False} + try: + kwargs = {'InstanceIds': [instance_id]} + host_vars['Events'] = connection.describe_instance_status(**kwargs)['InstanceStatuses'][0].get('Events', '') + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + if not self.get_option('strict_permissions'): + pass + else: + raise AnsibleError("Failed to describe instance status: %s" % to_native(e)) + if spot_instance: + try: + kwargs = {'SpotInstanceRequestIds': [spot_instance]} + host_vars['Persistent'] = bool( + 
connection.describe_spot_instance_requests(**kwargs)['SpotInstanceRequests'][0].get('Type') == 'persistent' + ) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + if not self.get_option('strict_permissions'): + pass + else: + raise AnsibleError("Failed to describe spot instance requests: %s" % to_native(e)) + return host_vars + + def _get_tag_hostname(self, preference, instance): + tag_hostnames = preference.split('tag:', 1)[1] + if ',' in tag_hostnames: + tag_hostnames = tag_hostnames.split(',') + else: + tag_hostnames = [tag_hostnames] + tags = boto3_tag_list_to_ansible_dict(instance.get('Tags', [])) + for v in tag_hostnames: + if '=' in v: + tag_name, tag_value = v.split('=') + if tags.get(tag_name) == tag_value: + return to_text(tag_name) + "_" + to_text(tag_value) + else: + tag_value = tags.get(v) + if tag_value: + return to_text(tag_value) + return None + + def _get_hostname(self, instance, hostnames): + ''' + :param instance: an instance dict returned by boto3 ec2 describe_instances() + :param hostnames: a list of hostname destination variables in order of preference + :return the preferred identifer for the host + ''' + if not hostnames: + hostnames = ['dns-name', 'private-dns-name'] + + hostname = None + for preference in hostnames: + if 'tag' in preference: + if not preference.startswith('tag:'): + raise AnsibleError("To name a host by tags name_value, use 'tag:name=value'.") + hostname = self._get_tag_hostname(preference, instance) + else: + hostname = self._get_boto_attr_chain(preference, instance) + if hostname: + break + if hostname: + if ':' in to_text(hostname): + return self._sanitize_group_name((to_text(hostname))) + else: + return to_text(hostname) + + def _query(self, regions, filters, strict_permissions): + ''' + :param regions: a list of regions to query + :param filters: a list of boto3 filter dictionaries + :param hostnames: a list of hostname destination variables in order of preference + :param strict_permissions: a boolean determining whether to fail or ignore 403 error codes + ''' + return {'aws_ec2': self._get_instances_by_region(regions, filters, strict_permissions)} + + def _populate(self, groups, hostnames): + for group in groups: + group = self.inventory.add_group(group) + self._add_hosts(hosts=groups[group], group=group, hostnames=hostnames) + self.inventory.add_child('all', group) + + def _add_hosts(self, hosts, group, hostnames): + ''' + :param hosts: a list of hosts to be added to a group + :param group: the name of the group to which the hosts belong + :param hostnames: a list of hostname destination variables in order of preference + ''' + for host in hosts: + hostname = self._get_hostname(host, hostnames) + + host = camel_dict_to_snake_dict(host, ignore_list=['Tags']) + host['tags'] = boto3_tag_list_to_ansible_dict(host.get('tags', [])) + + # Allow easier grouping by region + host['placement']['region'] = host['placement']['availability_zone'][:-1] + + if not hostname: + continue + self.inventory.add_host(hostname, group=group) + for hostvar, hostval in host.items(): + self.inventory.set_variable(hostname, hostvar, hostval) + + # Use constructed if applicable + + strict = self.get_option('strict') + + # Composed variables + self._set_composite_vars(self.get_option('compose'), host, hostname, strict=strict) + + # Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group + self._add_host_to_composed_groups(self.get_option('groups'), host, hostname, strict=strict) + + # Create 
groups based on variable values and add the corresponding hosts to it + self._add_host_to_keyed_groups(self.get_option('keyed_groups'), host, hostname, strict=strict) + + def _set_credentials(self): + ''' + :param config_data: contents of the inventory config file + ''' + + self.boto_profile = self.get_option('aws_profile') + self.aws_access_key_id = self.get_option('aws_access_key') + self.aws_secret_access_key = self.get_option('aws_secret_key') + self.aws_security_token = self.get_option('aws_security_token') + self.iam_role_arn = self.get_option('iam_role_arn') + + if not self.boto_profile and not (self.aws_access_key_id and self.aws_secret_access_key): + session = botocore.session.get_session() + try: + credentials = session.get_credentials().get_frozen_credentials() + except AttributeError: + pass + else: + self.aws_access_key_id = credentials.access_key + self.aws_secret_access_key = credentials.secret_key + self.aws_security_token = credentials.token + + if not self.boto_profile and not (self.aws_access_key_id and self.aws_secret_access_key): + raise AnsibleError("Insufficient boto credentials found. Please provide them in your " + "inventory configuration file or set them as environment variables.") + + def verify_file(self, path): + ''' + :param loader: an ansible.parsing.dataloader.DataLoader object + :param path: the path to the inventory config file + :return the contents of the config file + ''' + if super(InventoryModule, self).verify_file(path): + if path.endswith(('aws_ec2.yml', 'aws_ec2.yaml')): + return True + display.debug("aws_ec2 inventory filename must end with 'aws_ec2.yml' or 'aws_ec2.yaml'") + return False + + def parse(self, inventory, loader, path, cache=True): + + super(InventoryModule, self).parse(inventory, loader, path) + + self._read_config_data(path) + + if self.get_option('use_contrib_script_compatible_sanitization'): + self._sanitize_group_name = self._legacy_script_compatible_group_sanitization + + self._set_credentials() + + # get user specifications + regions = self.get_option('regions') + filters = ansible_dict_to_boto3_filter_list(self.get_option('filters')) + hostnames = self.get_option('hostnames') + strict_permissions = self.get_option('strict_permissions') + + cache_key = self.get_cache_key(path) + # false when refresh_cache or --flush-cache is used + if cache: + # get the user-specified directive + cache = self.get_option('cache') + + # Generate inventory + cache_needs_update = False + if cache: + try: + results = self._cache[cache_key] + except KeyError: + # if cache expires or cache file doesn't exist + cache_needs_update = True + + if not cache or cache_needs_update: + results = self._query(regions, filters, strict_permissions) + + self._populate(results, hostnames) + + # If the cache has expired/doesn't exist or if refresh_inventory/flush cache is used + # when the user is using caching, update the cached inventory + if cache_needs_update or (not cache and self.get_option('cache')): + self._cache[cache_key] = results + + @staticmethod + def _legacy_script_compatible_group_sanitization(name): + + # note that while this mirrors what the script used to do, it has many issues with unicode and usability in python + regex = re.compile(r"[^A-Za-z0-9\_\-]") + + return regex.sub('_', name) + + +def ansible_dict_to_boto3_filter_list(filters_dict): + + """ Convert an Ansible dict of filters to list of dicts that boto3 can use + Args: + filters_dict (dict): Dict of AWS filters. 
+ Basic Usage: + >>> filters = {'some-aws-id': 'i-01234567'} + >>> ansible_dict_to_boto3_filter_list(filters) + { + 'some-aws-id': 'i-01234567' + } + Returns: + List: List of AWS filters and their values + [ + { + 'Name': 'some-aws-id', + 'Values': [ + 'i-01234567', + ] + } + ] + """ + + filters_list = [] + for k, v in filters_dict.items(): + filter_dict = {'Name': k} + if isinstance(v, string_types): + filter_dict['Values'] = [v] + else: + filter_dict['Values'] = v + + filters_list.append(filter_dict) + + return filters_list + + +def boto3_tag_list_to_ansible_dict(tags_list, tag_name_key_name=None, tag_value_key_name=None): + + """ Convert a boto3 list of resource tags to a flat dict of key:value pairs + Args: + tags_list (list): List of dicts representing AWS tags. + tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key") + tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value") + Basic Usage: + >>> tags_list = [{'Key': 'MyTagKey', 'Value': 'MyTagValue'}] + >>> boto3_tag_list_to_ansible_dict(tags_list) + [ + { + 'Key': 'MyTagKey', + 'Value': 'MyTagValue' + } + ] + Returns: + Dict: Dict of key:value pairs representing AWS tags + { + 'MyTagKey': 'MyTagValue', + } + """ + + if tag_name_key_name and tag_value_key_name: + tag_candidates = {tag_name_key_name: tag_value_key_name} + else: + tag_candidates = {'key': 'value', 'Key': 'Value'} + + if not tags_list: + return {} + for k, v in tag_candidates.items(): + if k in tags_list[0] and v in tags_list[0]: + return dict((tag[k], tag[v]) for tag in tags_list) + raise ValueError("Couldn't find tag key (candidates %s) in tag list %s" % (str(tag_candidates), str(tags_list))) diff --git a/test/support/integration/plugins/inventory/docker_swarm.py b/test/support/integration/plugins/inventory/docker_swarm.py new file mode 100644 index 00000000..d0a95ca0 --- /dev/null +++ b/test/support/integration/plugins/inventory/docker_swarm.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2018, Stefan Heitmueller <stefan.heitmueller@gmx.com> +# Copyright (c) 2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) + +__metaclass__ = type + +DOCUMENTATION = ''' + name: docker_swarm + plugin_type: inventory + version_added: '2.8' + author: + - Stefan Heitmüller (@morph027) <stefan.heitmueller@gmx.com> + short_description: Ansible dynamic inventory plugin for Docker swarm nodes. + requirements: + - python >= 2.7 + - L(Docker SDK for Python,https://docker-py.readthedocs.io/en/stable/) >= 1.10.0 + extends_documentation_fragment: + - constructed + description: + - Reads inventories from the Docker swarm API. + - Uses a YAML configuration file docker_swarm.[yml|yaml]. + - "The plugin returns following groups of swarm nodes: I(all) - all hosts; I(workers) - all worker nodes; + I(managers) - all manager nodes; I(leader) - the swarm leader node; + I(nonleaders) - all nodes except the swarm leader." + options: + plugin: + description: The name of this plugin, it should always be set to C(docker_swarm) for this plugin to + recognize it as it's own. + type: str + required: true + choices: docker_swarm + docker_host: + description: + - Socket of a Docker swarm manager node (C(tcp), C(unix)). + - "Use C(unix://var/run/docker.sock) to connect via local socket." 
+ type: str + required: true + aliases: [ docker_url ] + verbose_output: + description: Toggle to (not) include all available nodes metadata (e.g. C(Platform), C(Architecture), C(OS), + C(EngineVersion)) + type: bool + default: yes + tls: + description: Connect using TLS without verifying the authenticity of the Docker host server. + type: bool + default: no + validate_certs: + description: Toggle if connecting using TLS with or without verifying the authenticity of the Docker + host server. + type: bool + default: no + aliases: [ tls_verify ] + client_key: + description: Path to the client's TLS key file. + type: path + aliases: [ tls_client_key, key_path ] + ca_cert: + description: Use a CA certificate when performing server verification by providing the path to a CA + certificate file. + type: path + aliases: [ tls_ca_cert, cacert_path ] + client_cert: + description: Path to the client's TLS certificate file. + type: path + aliases: [ tls_client_cert, cert_path ] + tls_hostname: + description: When verifying the authenticity of the Docker host server, provide the expected name of + the server. + type: str + ssl_version: + description: Provide a valid SSL version number. Default value determined by ssl.py module. + type: str + api_version: + description: + - The version of the Docker API running on the Docker Host. + - Defaults to the latest version of the API supported by docker-py. + type: str + aliases: [ docker_api_version ] + timeout: + description: + - The maximum amount of time in seconds to wait on a response from the API. + - If the value is not specified in the task, the value of environment variable C(DOCKER_TIMEOUT) + will be used instead. If the environment variable is not set, the default value will be used. + type: int + default: 60 + aliases: [ time_out ] + include_host_uri: + description: Toggle to return the additional attribute C(ansible_host_uri) which contains the URI of the + swarm leader in format of C(tcp://172.16.0.1:2376). This value may be used without additional + modification as value of option I(docker_host) in Docker Swarm modules when connecting via API. + The port always defaults to C(2376). + type: bool + default: no + include_host_uri_port: + description: Override the detected port number included in I(ansible_host_uri) + type: int +''' + +EXAMPLES = ''' +# Minimal example using local docker +plugin: docker_swarm +docker_host: unix://var/run/docker.sock + +# Minimal example using remote docker +plugin: docker_swarm +docker_host: tcp://my-docker-host:2375 + +# Example using remote docker with unverified TLS +plugin: docker_swarm +docker_host: tcp://my-docker-host:2376 +tls: yes + +# Example using remote docker with verified TLS and client certificate verification +plugin: docker_swarm +docker_host: tcp://my-docker-host:2376 +validate_certs: yes +ca_cert: /somewhere/ca.pem +client_key: /somewhere/key.pem +client_cert: /somewhere/cert.pem + +# Example using constructed features to create groups and set ansible_host +plugin: docker_swarm +docker_host: tcp://my-docker-host:2375 +strict: False +keyed_groups: + # add e.g. x86_64 hosts to an arch_x86_64 group + - prefix: arch + key: 'Description.Platform.Architecture' + # add e.g. linux hosts to an os_linux group + - prefix: os + key: 'Description.Platform.OS' + # create a group per node label + # e.g. 
a node labeled w/ "production" ends up in group "label_production" + # hint: labels containing special characters will be converted to safe names + - key: 'Spec.Labels' + prefix: label +''' + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_native +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ansible.plugins.inventory import BaseInventoryPlugin, Constructable +from ansible.parsing.utils.addresses import parse_address + +try: + import docker + from docker.errors import TLSParameterError + from docker.tls import TLSConfig + HAS_DOCKER = True +except ImportError: + HAS_DOCKER = False + + +def update_tls_hostname(result): + if result['tls_hostname'] is None: + # get default machine name from the url + parsed_url = urlparse(result['docker_host']) + if ':' in parsed_url.netloc: + result['tls_hostname'] = parsed_url.netloc[:parsed_url.netloc.rindex(':')] + else: + result['tls_hostname'] = parsed_url + + +def _get_tls_config(fail_function, **kwargs): + try: + tls_config = TLSConfig(**kwargs) + return tls_config + except TLSParameterError as exc: + fail_function("TLS config error: %s" % exc) + + +def get_connect_params(auth, fail_function): + if auth['tls'] or auth['tls_verify']: + auth['docker_host'] = auth['docker_host'].replace('tcp://', 'https://') + + if auth['tls_verify'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and host verification + if auth['cacert_path']: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + ca_cert=auth['cacert_path'], + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + else: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify'] and auth['cacert_path']: + # TLS with cacert only + tls_config = _get_tls_config(ca_cert=auth['cacert_path'], + assert_hostname=auth['tls_hostname'], + verify=True, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify']: + # TLS with verify and no certs + tls_config = _get_tls_config(verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and no host verification + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls']: + # TLS with no certs and not host verification + tls_config = _get_tls_config(verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + # No TLS + return dict(base_url=auth['docker_host'], + version=auth['api_version'], + timeout=auth['timeout']) + + +class 
InventoryModule(BaseInventoryPlugin, Constructable): + ''' Host inventory parser for ansible using Docker swarm as source. ''' + + NAME = 'docker_swarm' + + def _fail(self, msg): + raise AnsibleError(msg) + + def _populate(self): + raw_params = dict( + docker_host=self.get_option('docker_host'), + tls=self.get_option('tls'), + tls_verify=self.get_option('validate_certs'), + key_path=self.get_option('client_key'), + cacert_path=self.get_option('ca_cert'), + cert_path=self.get_option('client_cert'), + tls_hostname=self.get_option('tls_hostname'), + api_version=self.get_option('api_version'), + timeout=self.get_option('timeout'), + ssl_version=self.get_option('ssl_version'), + debug=None, + ) + update_tls_hostname(raw_params) + connect_params = get_connect_params(raw_params, fail_function=self._fail) + self.client = docker.DockerClient(**connect_params) + self.inventory.add_group('all') + self.inventory.add_group('manager') + self.inventory.add_group('worker') + self.inventory.add_group('leader') + self.inventory.add_group('nonleaders') + + if self.get_option('include_host_uri'): + if self.get_option('include_host_uri_port'): + host_uri_port = str(self.get_option('include_host_uri_port')) + elif self.get_option('tls') or self.get_option('validate_certs'): + host_uri_port = '2376' + else: + host_uri_port = '2375' + + try: + self.nodes = self.client.nodes.list() + for self.node in self.nodes: + self.node_attrs = self.client.nodes.get(self.node.id).attrs + self.inventory.add_host(self.node_attrs['ID']) + self.inventory.add_host(self.node_attrs['ID'], group=self.node_attrs['Spec']['Role']) + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host', + self.node_attrs['Status']['Addr']) + if self.get_option('include_host_uri'): + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host_uri', + 'tcp://' + self.node_attrs['Status']['Addr'] + ':' + host_uri_port) + if self.get_option('verbose_output'): + self.inventory.set_variable(self.node_attrs['ID'], 'docker_swarm_node_attributes', self.node_attrs) + if 'ManagerStatus' in self.node_attrs: + if self.node_attrs['ManagerStatus'].get('Leader'): + # This is workaround of bug in Docker when in some cases the Leader IP is 0.0.0.0 + # Check moby/moby#35437 for details + swarm_leader_ip = parse_address(self.node_attrs['ManagerStatus']['Addr'])[0] or \ + self.node_attrs['Status']['Addr'] + if self.get_option('include_host_uri'): + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host_uri', + 'tcp://' + swarm_leader_ip + ':' + host_uri_port) + self.inventory.set_variable(self.node_attrs['ID'], 'ansible_host', swarm_leader_ip) + self.inventory.add_host(self.node_attrs['ID'], group='leader') + else: + self.inventory.add_host(self.node_attrs['ID'], group='nonleaders') + else: + self.inventory.add_host(self.node_attrs['ID'], group='nonleaders') + # Use constructed if applicable + strict = self.get_option('strict') + # Composed variables + self._set_composite_vars(self.get_option('compose'), + self.node_attrs, + self.node_attrs['ID'], + strict=strict) + # Complex groups based on jinja2 conditionals, hosts that meet the conditional are added to group + self._add_host_to_composed_groups(self.get_option('groups'), + self.node_attrs, + self.node_attrs['ID'], + strict=strict) + # Create groups based on variable values and add the corresponding hosts to it + self._add_host_to_keyed_groups(self.get_option('keyed_groups'), + self.node_attrs, + self.node_attrs['ID'], + strict=strict) + except Exception as e: + raise AnsibleError('Unable to 
fetch hosts from Docker swarm API, this was the original exception: %s' % + to_native(e)) + + def verify_file(self, path): + """Return the possibly of a file being consumable by this plugin.""" + return ( + super(InventoryModule, self).verify_file(path) and + path.endswith((self.NAME + '.yaml', self.NAME + '.yml'))) + + def parse(self, inventory, loader, path, cache=True): + if not HAS_DOCKER: + raise AnsibleError('The Docker swarm dynamic inventory plugin requires the Docker SDK for Python: ' + 'https://github.com/docker/docker-py.') + super(InventoryModule, self).parse(inventory, loader, path, cache) + self._read_config_data(path) + self._populate() diff --git a/test/support/integration/plugins/inventory/foreman.py b/test/support/integration/plugins/inventory/foreman.py new file mode 100644 index 00000000..43073f81 --- /dev/null +++ b/test/support/integration/plugins/inventory/foreman.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2016 Guido Günther <agx@sigxcpu.org>, Daniel Lobato Garcia <dlobatog@redhat.com> +# Copyright (c) 2018 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = ''' + name: foreman + plugin_type: inventory + short_description: foreman inventory source + version_added: "2.6" + requirements: + - requests >= 1.1 + description: + - Get inventory hosts from the foreman service. + - "Uses a configuration file as an inventory source, it must end in ``.foreman.yml`` or ``.foreman.yaml`` and has a ``plugin: foreman`` entry." + extends_documentation_fragment: + - inventory_cache + - constructed + options: + plugin: + description: the name of this plugin, it should always be set to 'foreman' for this plugin to recognize it as it's own. 
+ required: True + choices: ['foreman'] + url: + description: url to foreman + default: 'http://localhost:3000' + env: + - name: FOREMAN_SERVER + version_added: "2.8" + user: + description: foreman authentication user + required: True + env: + - name: FOREMAN_USER + version_added: "2.8" + password: + description: foreman authentication password + required: True + env: + - name: FOREMAN_PASSWORD + version_added: "2.8" + validate_certs: + description: verify SSL certificate if using https + type: boolean + default: False + group_prefix: + description: prefix to apply to foreman groups + default: foreman_ + vars_prefix: + description: prefix to apply to host variables, does not include facts nor params + default: foreman_ + want_facts: + description: Toggle, if True the plugin will retrieve host facts from the server + type: boolean + default: False + want_params: + description: Toggle, if true the inventory will retrieve 'all_parameters' information as host vars + type: boolean + default: False + want_hostcollections: + description: Toggle, if true the plugin will create Ansible groups for host collections + type: boolean + default: False + version_added: '2.10' + want_ansible_ssh_host: + description: Toggle, if true the plugin will populate the ansible_ssh_host variable to explicitly specify the connection target + type: boolean + default: False + version_added: '2.10' + +''' + +EXAMPLES = ''' +# my.foreman.yml +plugin: foreman +url: http://localhost:2222 +user: ansible-tester +password: secure +validate_certs: False +''' + +from distutils.version import LooseVersion + +from ansible.errors import AnsibleError +from ansible.module_utils._text import to_bytes, to_native, to_text +from ansible.module_utils.common._collections_compat import MutableMapping +from ansible.plugins.inventory import BaseInventoryPlugin, Cacheable, to_safe_group_name, Constructable + +# 3rd party imports +try: + import requests + if LooseVersion(requests.__version__) < LooseVersion('1.1.0'): + raise ImportError +except ImportError: + raise AnsibleError('This script requires python-requests 1.1 as a minimum version') + +from requests.auth import HTTPBasicAuth + + +class InventoryModule(BaseInventoryPlugin, Cacheable, Constructable): + ''' Host inventory parser for ansible using foreman as source. 
''' + + NAME = 'foreman' + + def __init__(self): + + super(InventoryModule, self).__init__() + + # from config + self.foreman_url = None + + self.session = None + self.cache_key = None + self.use_cache = None + + def verify_file(self, path): + + valid = False + if super(InventoryModule, self).verify_file(path): + if path.endswith(('foreman.yaml', 'foreman.yml')): + valid = True + else: + self.display.vvv('Skipping due to inventory source not ending in "foreman.yaml" nor "foreman.yml"') + return valid + + def _get_session(self): + if not self.session: + self.session = requests.session() + self.session.auth = HTTPBasicAuth(self.get_option('user'), to_bytes(self.get_option('password'))) + self.session.verify = self.get_option('validate_certs') + return self.session + + def _get_json(self, url, ignore_errors=None): + + if not self.use_cache or url not in self._cache.get(self.cache_key, {}): + + if self.cache_key not in self._cache: + self._cache[self.cache_key] = {url: ''} + + results = [] + s = self._get_session() + params = {'page': 1, 'per_page': 250} + while True: + ret = s.get(url, params=params) + if ignore_errors and ret.status_code in ignore_errors: + break + ret.raise_for_status() + json = ret.json() + + # process results + # FIXME: This assumes 'return type' matches a specific query, + # it will break if we expand the queries and they dont have different types + if 'results' not in json: + # /hosts/:id dos not have a 'results' key + results = json + break + elif isinstance(json['results'], MutableMapping): + # /facts are returned as dict in 'results' + results = json['results'] + break + else: + # /hosts 's 'results' is a list of all hosts, returned is paginated + results = results + json['results'] + + # check for end of paging + if len(results) >= json['subtotal']: + break + if len(json['results']) == 0: + self.display.warning("Did not make any progress during loop. 
expected %d got %d" % (json['subtotal'], len(results))) + break + + # get next page + params['page'] += 1 + + self._cache[self.cache_key][url] = results + + return self._cache[self.cache_key][url] + + def _get_hosts(self): + return self._get_json("%s/api/v2/hosts" % self.foreman_url) + + def _get_all_params_by_id(self, hid): + url = "%s/api/v2/hosts/%s" % (self.foreman_url, hid) + ret = self._get_json(url, [404]) + if not ret or not isinstance(ret, MutableMapping) or not ret.get('all_parameters', False): + return {} + return ret.get('all_parameters') + + def _get_facts_by_id(self, hid): + url = "%s/api/v2/hosts/%s/facts" % (self.foreman_url, hid) + return self._get_json(url) + + def _get_host_data_by_id(self, hid): + url = "%s/api/v2/hosts/%s" % (self.foreman_url, hid) + return self._get_json(url) + + def _get_facts(self, host): + """Fetch all host facts of the host""" + + ret = self._get_facts_by_id(host['id']) + if len(ret.values()) == 0: + facts = {} + elif len(ret.values()) == 1: + facts = list(ret.values())[0] + else: + raise ValueError("More than one set of facts returned for '%s'" % host) + return facts + + def _populate(self): + + for host in self._get_hosts(): + + if host.get('name'): + host_name = self.inventory.add_host(host['name']) + + # create directly mapped groups + group_name = host.get('hostgroup_title', host.get('hostgroup_name')) + if group_name: + group_name = to_safe_group_name('%s%s' % (self.get_option('group_prefix'), group_name.lower().replace(" ", ""))) + group_name = self.inventory.add_group(group_name) + self.inventory.add_child(group_name, host_name) + + # set host vars from host info + try: + for k, v in host.items(): + if k not in ('name', 'hostgroup_title', 'hostgroup_name'): + try: + self.inventory.set_variable(host_name, self.get_option('vars_prefix') + k, v) + except ValueError as e: + self.display.warning("Could not set host info hostvar for %s, skipping %s: %s" % (host, k, to_text(e))) + except ValueError as e: + self.display.warning("Could not get host info for %s, skipping: %s" % (host_name, to_text(e))) + + # set host vars from params + if self.get_option('want_params'): + for p in self._get_all_params_by_id(host['id']): + try: + self.inventory.set_variable(host_name, p['name'], p['value']) + except ValueError as e: + self.display.warning("Could not set hostvar %s to '%s' for the '%s' host, skipping: %s" % + (p['name'], to_native(p['value']), host, to_native(e))) + + # set host vars from facts + if self.get_option('want_facts'): + self.inventory.set_variable(host_name, 'foreman_facts', self._get_facts(host)) + + # create group for host collections + if self.get_option('want_hostcollections'): + host_data = self._get_host_data_by_id(host['id']) + hostcollections = host_data.get('host_collections') + if hostcollections: + # Create Ansible groups for host collections + for hostcollection in hostcollections: + try: + hostcollection_group = to_safe_group_name('%shostcollection_%s' % (self.get_option('group_prefix'), + hostcollection['name'].lower().replace(" ", ""))) + hostcollection_group = self.inventory.add_group(hostcollection_group) + self.inventory.add_child(hostcollection_group, host_name) + except ValueError as e: + self.display.warning("Could not create groups for host collections for %s, skipping: %s" % (host_name, to_text(e))) + + # put ansible_ssh_host as hostvar + if self.get_option('want_ansible_ssh_host'): + for key in ('ip', 'ipv4', 'ipv6'): + if host.get(key): + try: + self.inventory.set_variable(host_name, 'ansible_ssh_host', host[key]) 
+ break + except ValueError as e: + self.display.warning("Could not set hostvar ansible_ssh_host to '%s' for the '%s' host, skipping: %s" % + (host[key], host_name, to_text(e))) + + strict = self.get_option('strict') + + hostvars = self.inventory.get_host(host_name).get_vars() + self._set_composite_vars(self.get_option('compose'), hostvars, host_name, strict) + self._add_host_to_composed_groups(self.get_option('groups'), hostvars, host_name, strict) + self._add_host_to_keyed_groups(self.get_option('keyed_groups'), hostvars, host_name, strict) + + def parse(self, inventory, loader, path, cache=True): + + super(InventoryModule, self).parse(inventory, loader, path) + + # read config from file, this sets 'options' + self._read_config_data(path) + + # get connection host + self.foreman_url = self.get_option('url') + self.cache_key = self.get_cache_key(path) + self.use_cache = cache and self.get_option('cache') + + # actually populate inventory + self._populate() diff --git a/test/support/integration/plugins/lookup/rabbitmq.py b/test/support/integration/plugins/lookup/rabbitmq.py new file mode 100644 index 00000000..7c2745f4 --- /dev/null +++ b/test/support/integration/plugins/lookup/rabbitmq.py @@ -0,0 +1,190 @@ +# (c) 2018, John Imison <john+github@imison.net> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +DOCUMENTATION = """ + lookup: rabbitmq + author: John Imison <@Im0> + version_added: "2.8" + short_description: Retrieve messages from an AMQP/AMQPS RabbitMQ queue. + description: + - This lookup uses a basic get to retrieve all, or a limited number C(count), messages from a RabbitMQ queue. + options: + url: + description: + - An URI connection string to connect to the AMQP/AMQPS RabbitMQ server. + - For more information refer to the URI spec U(https://www.rabbitmq.com/uri-spec.html). + required: True + queue: + description: + - The queue to get messages from. + required: True + count: + description: + - How many messages to collect from the queue. + - If not set, defaults to retrieving all the messages from the queue. + requirements: + - The python pika package U(https://pypi.org/project/pika/). + notes: + - This lookup implements BlockingChannel.basic_get to get messages from a RabbitMQ server. + - After retrieving a message from the server, receipt of the message is acknowledged and the message on the server is deleted. + - Pika is a pure-Python implementation of the AMQP 0-9-1 protocol that tries to stay fairly independent of the underlying network support library. + - More information about pika can be found at U(https://pika.readthedocs.io/en/stable/). + - This plugin is tested against RabbitMQ. Other AMQP 0.9.1 protocol based servers may work but not tested/guaranteed. + - Assigning the return messages to a variable under C(vars) may result in unexpected results as the lookup is evaluated every time the + variable is referenced. + - Currently this plugin only handles text based messages from a queue. Unexpected results may occur when retrieving binary data. +""" + + +EXAMPLES = """ +- name: Get all messages off a queue + debug: + msg: "{{ lookup('rabbitmq', url='amqp://guest:guest@192.168.0.10:5672/%2F', queue='hello') }}" + + +# If you are intending on using the returned messages as a variable in more than +# one task (eg. debug, template), it is recommended to set_fact. 
+ +- name: Get 2 messages off a queue and set a fact for re-use + set_fact: + messages: "{{ lookup('rabbitmq', url='amqp://guest:guest@192.168.0.10:5672/%2F', queue='hello', count=2) }}" + +- name: Dump out contents of the messages + debug: + var: messages + +""" + +RETURN = """ + _list: + description: + - A list of dictionaries with keys and value from the queue. + type: list + contains: + content_type: + description: The content_type on the message in the queue. + type: str + delivery_mode: + description: The delivery_mode on the message in the queue. + type: str + delivery_tag: + description: The delivery_tag on the message in the queue. + type: str + exchange: + description: The exchange the message came from. + type: str + message_count: + description: The message_count for the message on the queue. + type: str + msg: + description: The content of the message. + type: str + redelivered: + description: The redelivered flag. True if the message has been delivered before. + type: bool + routing_key: + description: The routing_key on the message in the queue. + type: str + headers: + description: The headers for the message returned from the queue. + type: dict + json: + description: If application/json is specified in content_type, json will be loaded into variables. + type: dict + +""" + +import json + +from ansible.errors import AnsibleError, AnsibleParserError +from ansible.plugins.lookup import LookupBase +from ansible.module_utils._text import to_native, to_text +from ansible.utils.display import Display + +try: + import pika + from pika import spec + HAS_PIKA = True +except ImportError: + HAS_PIKA = False + +display = Display() + + +class LookupModule(LookupBase): + + def run(self, terms, variables=None, url=None, queue=None, count=None): + if not HAS_PIKA: + raise AnsibleError('pika python package is required for rabbitmq lookup.') + if not url: + raise AnsibleError('URL is required for rabbitmq lookup.') + if not queue: + raise AnsibleError('Queue is required for rabbitmq lookup.') + + display.vvv(u"terms:%s : variables:%s url:%s queue:%s count:%s" % (terms, variables, url, queue, count)) + + try: + parameters = pika.URLParameters(url) + except Exception as e: + raise AnsibleError("URL malformed: %s" % to_native(e)) + + try: + connection = pika.BlockingConnection(parameters) + except Exception as e: + raise AnsibleError("Connection issue: %s" % to_native(e)) + + try: + conn_channel = connection.channel() + except pika.exceptions.AMQPChannelError as e: + try: + connection.close() + except pika.exceptions.AMQPConnectionError as ie: + raise AnsibleError("Channel and connection closing issues: %s / %s" % to_native(e), to_native(ie)) + raise AnsibleError("Channel issue: %s" % to_native(e)) + + ret = [] + idx = 0 + + while True: + method_frame, properties, body = conn_channel.basic_get(queue=queue) + if method_frame: + display.vvv(u"%s, %s, %s " % (method_frame, properties, to_text(body))) + + # TODO: In the future consider checking content_type and handle text/binary data differently. 
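+ # Collect the broker metadata for this delivery (routing key, exchange, delivery tag, headers, ...) together with the decoded body.
+ # Messages declaring an application/json content type are additionally parsed into the 'json' key; each message is then acknowledged, which removes it from the queue.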
+ msg_details = dict({ + 'msg': to_text(body), + 'message_count': method_frame.message_count, + 'routing_key': method_frame.routing_key, + 'delivery_tag': method_frame.delivery_tag, + 'redelivered': method_frame.redelivered, + 'exchange': method_frame.exchange, + 'delivery_mode': properties.delivery_mode, + 'content_type': properties.content_type, + 'headers': properties.headers + }) + if properties.content_type == 'application/json': + try: + msg_details['json'] = json.loads(msg_details['msg']) + except ValueError as e: + raise AnsibleError("Unable to decode JSON for message %s: %s" % (method_frame.delivery_tag, to_native(e))) + + ret.append(msg_details) + conn_channel.basic_ack(method_frame.delivery_tag) + idx += 1 + if method_frame.message_count == 0 or idx == count: + break + # If we didn't get a method_frame, exit. + else: + break + + if connection.is_closed: + return [ret] + else: + try: + connection.close() + except pika.exceptions.AMQPConnectionError: + pass + return [ret] diff --git a/test/support/integration/plugins/module_utils/aws/__init__.py b/test/support/integration/plugins/module_utils/aws/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/__init__.py diff --git a/test/support/integration/plugins/module_utils/aws/core.py b/test/support/integration/plugins/module_utils/aws/core.py new file mode 100644 index 00000000..c4527b6d --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/core.py @@ -0,0 +1,335 @@ +# +# Copyright 2017 Michael De La Rue | Ansible +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. + +"""This module adds shared support for generic Amazon AWS modules + +**This code is not yet ready for use in user modules. As of 2017** +**and through to 2018, the interface is likely to change** +**aggressively as the exact correct interface for ansible AWS modules** +**is identified. In particular, until this notice goes away or is** +**changed, methods may disappear from the interface. Please don't** +**publish modules using this except directly to the main Ansible** +**development repository.** + +In order to use this module, include it as part of a custom +module as shown below. + + from ansible.module_utils.aws import AnsibleAWSModule + module = AnsibleAWSModule(argument_spec=dictionary, supports_check_mode=boolean + mutually_exclusive=list1, required_together=list2) + +The 'AnsibleAWSModule' module provides similar, but more restricted, +interfaces to the normal Ansible module. It also includes the +additional methods for connecting to AWS using the standard module arguments + + m.resource('lambda') # - get an AWS connection as a boto3 resource. + +or + + m.client('sts') # - get an AWS connection as a boto3 client. + +To make use of AWSRetry easier, it can now be wrapped around any call from a +module-created client. 
To add retries to a client, create a client: + + m.client('ec2', retry_decorator=AWSRetry.jittered_backoff(retries=10)) + +Any calls from that client can be made to use the decorator passed at call-time +using the `aws_retry` argument. By default, no retries are used. + + ec2 = m.client('ec2', retry_decorator=AWSRetry.jittered_backoff(retries=10)) + ec2.describe_instances(InstanceIds=['i-123456789'], aws_retry=True) + +The call will be retried the specified number of times, so the calling functions +don't need to be wrapped in the backoff decorator. +""" + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import re +import logging +import traceback +from functools import wraps +from distutils.version import LooseVersion + +try: + from cStringIO import StringIO +except ImportError: + # Python 3 + from io import StringIO + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native +from ansible.module_utils.ec2 import HAS_BOTO3, camel_dict_to_snake_dict, ec2_argument_spec, boto3_conn +from ansible.module_utils.ec2 import get_aws_connection_info, get_aws_region + +# We will also export HAS_BOTO3 so end user modules can use it. +__all__ = ('AnsibleAWSModule', 'HAS_BOTO3', 'is_boto3_error_code') + + +class AnsibleAWSModule(object): + """An ansible module class for AWS modules + + AnsibleAWSModule provides an a class for building modules which + connect to Amazon Web Services. The interface is currently more + restricted than the basic module class with the aim that later the + basic module class can be reduced. If you find that any key + feature is missing please contact the author/Ansible AWS team + (available on #ansible-aws on IRC) to request the additional + features needed. 
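+
+ Besides the usual AnsibleModule surface (params, exit_json, fail_json, warn, deprecate, ...), the class
+ offers client() and resource() helpers that build boto3 connections from the standard AWS module
+ arguments, and fail_json_aws() for turning botocore/boto3 exceptions into readable module failures.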
+ """ + default_settings = { + "default_args": True, + "check_boto3": True, + "auto_retry": True, + "module_class": AnsibleModule + } + + def __init__(self, **kwargs): + local_settings = {} + for key in AnsibleAWSModule.default_settings: + try: + local_settings[key] = kwargs.pop(key) + except KeyError: + local_settings[key] = AnsibleAWSModule.default_settings[key] + self.settings = local_settings + + if local_settings["default_args"]: + # ec2_argument_spec contains the region so we use that; there's a patch coming which + # will add it to aws_argument_spec so if that's accepted then later we should change + # over + argument_spec_full = ec2_argument_spec() + try: + argument_spec_full.update(kwargs["argument_spec"]) + except (TypeError, NameError): + pass + kwargs["argument_spec"] = argument_spec_full + + self._module = AnsibleAWSModule.default_settings["module_class"](**kwargs) + + if local_settings["check_boto3"] and not HAS_BOTO3: + self._module.fail_json( + msg=missing_required_lib('botocore or boto3')) + + self.check_mode = self._module.check_mode + self._diff = self._module._diff + self._name = self._module._name + + self._botocore_endpoint_log_stream = StringIO() + self.logger = None + if self.params.get('debug_botocore_endpoint_logs'): + self.logger = logging.getLogger('botocore.endpoint') + self.logger.setLevel(logging.DEBUG) + self.logger.addHandler(logging.StreamHandler(self._botocore_endpoint_log_stream)) + + @property + def params(self): + return self._module.params + + def _get_resource_action_list(self): + actions = [] + for ln in self._botocore_endpoint_log_stream.getvalue().split('\n'): + ln = ln.strip() + if not ln: + continue + found_operational_request = re.search(r"OperationModel\(name=.*?\)", ln) + if found_operational_request: + operation_request = found_operational_request.group(0)[20:-1] + resource = re.search(r"https://.*?\.", ln).group(0)[8:-1] + actions.append("{0}:{1}".format(resource, operation_request)) + return list(set(actions)) + + def exit_json(self, *args, **kwargs): + if self.params.get('debug_botocore_endpoint_logs'): + kwargs['resource_actions'] = self._get_resource_action_list() + return self._module.exit_json(*args, **kwargs) + + def fail_json(self, *args, **kwargs): + if self.params.get('debug_botocore_endpoint_logs'): + kwargs['resource_actions'] = self._get_resource_action_list() + return self._module.fail_json(*args, **kwargs) + + def debug(self, *args, **kwargs): + return self._module.debug(*args, **kwargs) + + def warn(self, *args, **kwargs): + return self._module.warn(*args, **kwargs) + + def deprecate(self, *args, **kwargs): + return self._module.deprecate(*args, **kwargs) + + def boolean(self, *args, **kwargs): + return self._module.boolean(*args, **kwargs) + + def md5(self, *args, **kwargs): + return self._module.md5(*args, **kwargs) + + def client(self, service, retry_decorator=None): + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(self, boto3=True) + conn = boto3_conn(self, conn_type='client', resource=service, + region=region, endpoint=ec2_url, **aws_connect_kwargs) + return conn if retry_decorator is None else _RetryingBotoClientWrapper(conn, retry_decorator) + + def resource(self, service): + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(self, boto3=True) + return boto3_conn(self, conn_type='resource', resource=service, + region=region, endpoint=ec2_url, **aws_connect_kwargs) + + @property + def region(self, boto3=True): + return get_aws_region(self, boto3) + + def fail_json_aws(self, exception, 
msg=None): + """call fail_json with processed exception + + function for converting exceptions thrown by AWS SDK modules, + botocore, boto3 and boto, into nice error messages. + """ + last_traceback = traceback.format_exc() + + # to_native is trusted to handle exceptions that str() could + # convert to text. + try: + except_msg = to_native(exception.message) + except AttributeError: + except_msg = to_native(exception) + + if msg is not None: + message = '{0}: {1}'.format(msg, except_msg) + else: + message = except_msg + + try: + response = exception.response + except AttributeError: + response = None + + failure = dict( + msg=message, + exception=last_traceback, + **self._gather_versions() + ) + + if response is not None: + failure.update(**camel_dict_to_snake_dict(response)) + + self.fail_json(**failure) + + def _gather_versions(self): + """Gather AWS SDK (boto3 and botocore) dependency versions + + Returns {'boto3_version': str, 'botocore_version': str} + Returns {} if neither are installed + """ + if not HAS_BOTO3: + return {} + import boto3 + import botocore + return dict(boto3_version=boto3.__version__, + botocore_version=botocore.__version__) + + def boto3_at_least(self, desired): + """Check if the available boto3 version is greater than or equal to a desired version. + + Usage: + if module.params.get('assign_ipv6_address') and not module.boto3_at_least('1.4.4'): + # conditionally fail on old boto3 versions if a specific feature is not supported + module.fail_json(msg="Boto3 can't deal with EC2 IPv6 addresses before version 1.4.4.") + """ + existing = self._gather_versions() + return LooseVersion(existing['boto3_version']) >= LooseVersion(desired) + + def botocore_at_least(self, desired): + """Check if the available botocore version is greater than or equal to a desired version. + + Usage: + if not module.botocore_at_least('1.2.3'): + module.fail_json(msg='The Serverless Elastic Load Compute Service is not in botocore before v1.2.3') + if not module.botocore_at_least('1.5.3'): + module.warn('Botocore did not include waiters for Service X before 1.5.3. ' + 'To wait until Service X resources are fully available, update botocore.') + """ + existing = self._gather_versions() + return LooseVersion(existing['botocore_version']) >= LooseVersion(desired) + + +class _RetryingBotoClientWrapper(object): + __never_wait = ( + 'get_paginator', 'can_paginate', + 'get_waiter', 'generate_presigned_url', + ) + + def __init__(self, client, retry): + self.client = client + self.retry = retry + + def _create_optional_retry_wrapper_function(self, unwrapped): + retrying_wrapper = self.retry(unwrapped) + + @wraps(unwrapped) + def deciding_wrapper(aws_retry=False, *args, **kwargs): + if aws_retry: + return retrying_wrapper(*args, **kwargs) + else: + return unwrapped(*args, **kwargs) + return deciding_wrapper + + def __getattr__(self, name): + unwrapped = getattr(self.client, name) + if name in self.__never_wait: + return unwrapped + elif callable(unwrapped): + wrapped = self._create_optional_retry_wrapper_function(unwrapped) + setattr(self, name, wrapped) + return wrapped + else: + return unwrapped + + +def is_boto3_error_code(code, e=None): + """Check if the botocore exception is raised by a specific error code. 
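+ Designed to be used directly in an except clause, as the example below shows.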
+ + Returns ClientError if the error code matches, a dummy exception if it does not have an error code or does not match + + Example: + try: + ec2.describe_instances(InstanceIds=['potato']) + except is_boto3_error_code('InvalidInstanceID.Malformed'): + # handle the error for that code case + except botocore.exceptions.ClientError as e: + # handle the generic error case for all other codes + """ + from botocore.exceptions import ClientError + if e is None: + import sys + dummy, e, dummy = sys.exc_info() + if isinstance(e, ClientError) and e.response['Error']['Code'] == code: + return ClientError + return type('NeverEverRaisedException', (Exception,), {}) + + +def get_boto3_client_method_parameters(client, method_name, required=False): + op = client.meta.method_to_api_mapping.get(method_name) + input_shape = client._service_model.operation_model(op).input_shape + if not input_shape: + parameters = [] + elif required: + parameters = list(input_shape.required_members) + else: + parameters = list(input_shape.members.keys()) + return parameters diff --git a/test/support/integration/plugins/module_utils/aws/iam.py b/test/support/integration/plugins/module_utils/aws/iam.py new file mode 100644 index 00000000..f05999aa --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/iam.py @@ -0,0 +1,49 @@ +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import traceback + +try: + from botocore.exceptions import ClientError, NoCredentialsError +except ImportError: + pass # caught by HAS_BOTO3 + +from ansible.module_utils._text import to_native + + +def get_aws_account_id(module): + """ Given AnsibleAWSModule instance, get the active AWS account ID + + get_account_id tries too find out the account that we are working + on. It's not guaranteed that this will be easy so we try in + several different ways. Giving either IAM or STS privilages to + the account should be enough to permit this. + """ + account_id = None + try: + sts_client = module.client('sts') + account_id = sts_client.get_caller_identity().get('Account') + # non-STS sessions may also get NoCredentialsError from this STS call, so + # we must catch that too and try the IAM version + except (ClientError, NoCredentialsError): + try: + iam_client = module.client('iam') + account_id = iam_client.get_user()['User']['Arn'].split(':')[4] + except ClientError as e: + if (e.response['Error']['Code'] == 'AccessDenied'): + except_msg = to_native(e) + # don't match on `arn:aws` because of China region `arn:aws-cn` and similar + account_id = except_msg.search(r"arn:\w+:iam::([0-9]{12,32}):\w+/").group(1) + if account_id is None: + module.fail_json_aws(e, msg="Could not get AWS account information") + except Exception as e: + module.fail_json( + msg="Failed to get AWS account information, Try allowing sts:GetCallerIdentity or iam:GetUser permissions.", + exception=traceback.format_exc() + ) + if not account_id: + module.fail_json(msg="Failed while determining AWS account ID. Try allowing sts:GetCallerIdentity or iam:GetUser permissions.") + return to_native(account_id) diff --git a/test/support/integration/plugins/module_utils/aws/s3.py b/test/support/integration/plugins/module_utils/aws/s3.py new file mode 100644 index 00000000..2185869d --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/s3.py @@ -0,0 +1,50 @@ +# Copyright (c) 2018 Red Hat, Inc. 
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +try: + from botocore.exceptions import BotoCoreError, ClientError +except ImportError: + pass # Handled by the calling module + +HAS_MD5 = True +try: + from hashlib import md5 +except ImportError: + try: + from md5 import md5 + except ImportError: + HAS_MD5 = False + + +def calculate_etag(module, filename, etag, s3, bucket, obj, version=None): + if not HAS_MD5: + return None + + if '-' in etag: + # Multi-part ETag; a hash of the hashes of each part. + parts = int(etag[1:-1].split('-')[1]) + digests = [] + + s3_kwargs = dict( + Bucket=bucket, + Key=obj, + ) + if version: + s3_kwargs['VersionId'] = version + + with open(filename, 'rb') as f: + for part_num in range(1, parts + 1): + s3_kwargs['PartNumber'] = part_num + try: + head = s3.head_object(**s3_kwargs) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get head object") + digests.append(md5(f.read(int(head['ContentLength'])))) + + digest_squared = md5(b''.join(m.digest() for m in digests)) + return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests)) + else: # Compute the MD5 sum normally + return '"{0}"'.format(module.md5(filename)) diff --git a/test/support/integration/plugins/module_utils/aws/waiters.py b/test/support/integration/plugins/module_utils/aws/waiters.py new file mode 100644 index 00000000..25db598b --- /dev/null +++ b/test/support/integration/plugins/module_utils/aws/waiters.py @@ -0,0 +1,405 @@ +# Copyright: (c) 2018, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +try: + import botocore.waiter as core_waiter +except ImportError: + pass # caught by HAS_BOTO3 + + +ec2_data = { + "version": 2, + "waiters": { + "InternetGatewayExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeInternetGateways", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(InternetGateways) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidInternetGatewayID.NotFound", + "state": "retry" + }, + ] + }, + "RouteTableExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeRouteTables", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(RouteTables[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidRouteTableID.NotFound", + "state": "retry" + }, + ] + }, + "SecurityGroupExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSecurityGroups", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(SecurityGroups[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidGroup.NotFound", + "state": "retry" + }, + ] + }, + "SubnetExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(Subnets[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidSubnetID.NotFound", + "state": "retry" + }, + ] + }, + "SubnetHasMapPublic": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": True, + "argument": "Subnets[].MapPublicIpOnLaunch", + "state": 
"success" + }, + ] + }, + "SubnetNoMapPublic": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": False, + "argument": "Subnets[].MapPublicIpOnLaunch", + "state": "success" + }, + ] + }, + "SubnetHasAssignIpv6": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": True, + "argument": "Subnets[].AssignIpv6AddressOnCreation", + "state": "success" + }, + ] + }, + "SubnetNoAssignIpv6": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "pathAll", + "expected": False, + "argument": "Subnets[].AssignIpv6AddressOnCreation", + "state": "success" + }, + ] + }, + "SubnetDeleted": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeSubnets", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(Subnets[]) > `0`", + "state": "retry" + }, + { + "matcher": "error", + "expected": "InvalidSubnetID.NotFound", + "state": "success" + }, + ] + }, + "VpnGatewayExists": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeVpnGateways", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "length(VpnGateways[]) > `0`", + "state": "success" + }, + { + "matcher": "error", + "expected": "InvalidVpnGatewayID.NotFound", + "state": "retry" + }, + ] + }, + "VpnGatewayDetached": { + "delay": 5, + "maxAttempts": 40, + "operation": "DescribeVpnGateways", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "VpnGateways[0].State == 'available'", + "state": "success" + }, + ] + }, + } +} + + +waf_data = { + "version": 2, + "waiters": { + "ChangeTokenInSync": { + "delay": 20, + "maxAttempts": 60, + "operation": "GetChangeTokenStatus", + "acceptors": [ + { + "matcher": "path", + "expected": True, + "argument": "ChangeTokenStatus == 'INSYNC'", + "state": "success" + }, + { + "matcher": "error", + "expected": "WAFInternalErrorException", + "state": "retry" + } + ] + } + } +} + +eks_data = { + "version": 2, + "waiters": { + "ClusterActive": { + "delay": 20, + "maxAttempts": 60, + "operation": "DescribeCluster", + "acceptors": [ + { + "state": "success", + "matcher": "path", + "argument": "cluster.status", + "expected": "ACTIVE" + }, + { + "state": "retry", + "matcher": "error", + "expected": "ResourceNotFoundException" + } + ] + }, + "ClusterDeleted": { + "delay": 20, + "maxAttempts": 60, + "operation": "DescribeCluster", + "acceptors": [ + { + "state": "retry", + "matcher": "path", + "argument": "cluster.status != 'DELETED'", + "expected": True + }, + { + "state": "success", + "matcher": "error", + "expected": "ResourceNotFoundException" + } + ] + } + } +} + + +rds_data = { + "version": 2, + "waiters": { + "DBInstanceStopped": { + "delay": 20, + "maxAttempts": 60, + "operation": "DescribeDBInstances", + "acceptors": [ + { + "state": "success", + "matcher": "pathAll", + "argument": "DBInstances[].DBInstanceStatus", + "expected": "stopped" + }, + ] + } + } +} + + +def ec2_model(name): + ec2_models = core_waiter.WaiterModel(waiter_config=ec2_data) + return ec2_models.get_waiter(name) + + +def waf_model(name): + waf_models = core_waiter.WaiterModel(waiter_config=waf_data) + return waf_models.get_waiter(name) + + +def eks_model(name): + eks_models = core_waiter.WaiterModel(waiter_config=eks_data) + return eks_models.get_waiter(name) + + +def rds_model(name): + rds_models = core_waiter.WaiterModel(waiter_config=rds_data) + return 
rds_models.get_waiter(name) + + +waiters_by_name = { + ('EC2', 'internet_gateway_exists'): lambda ec2: core_waiter.Waiter( + 'internet_gateway_exists', + ec2_model('InternetGatewayExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_internet_gateways + )), + ('EC2', 'route_table_exists'): lambda ec2: core_waiter.Waiter( + 'route_table_exists', + ec2_model('RouteTableExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_route_tables + )), + ('EC2', 'security_group_exists'): lambda ec2: core_waiter.Waiter( + 'security_group_exists', + ec2_model('SecurityGroupExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_security_groups + )), + ('EC2', 'subnet_exists'): lambda ec2: core_waiter.Waiter( + 'subnet_exists', + ec2_model('SubnetExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_has_map_public'): lambda ec2: core_waiter.Waiter( + 'subnet_has_map_public', + ec2_model('SubnetHasMapPublic'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_no_map_public'): lambda ec2: core_waiter.Waiter( + 'subnet_no_map_public', + ec2_model('SubnetNoMapPublic'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_has_assign_ipv6'): lambda ec2: core_waiter.Waiter( + 'subnet_has_assign_ipv6', + ec2_model('SubnetHasAssignIpv6'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_no_assign_ipv6'): lambda ec2: core_waiter.Waiter( + 'subnet_no_assign_ipv6', + ec2_model('SubnetNoAssignIpv6'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'subnet_deleted'): lambda ec2: core_waiter.Waiter( + 'subnet_deleted', + ec2_model('SubnetDeleted'), + core_waiter.NormalizedOperationMethod( + ec2.describe_subnets + )), + ('EC2', 'vpn_gateway_exists'): lambda ec2: core_waiter.Waiter( + 'vpn_gateway_exists', + ec2_model('VpnGatewayExists'), + core_waiter.NormalizedOperationMethod( + ec2.describe_vpn_gateways + )), + ('EC2', 'vpn_gateway_detached'): lambda ec2: core_waiter.Waiter( + 'vpn_gateway_detached', + ec2_model('VpnGatewayDetached'), + core_waiter.NormalizedOperationMethod( + ec2.describe_vpn_gateways + )), + ('WAF', 'change_token_in_sync'): lambda waf: core_waiter.Waiter( + 'change_token_in_sync', + waf_model('ChangeTokenInSync'), + core_waiter.NormalizedOperationMethod( + waf.get_change_token_status + )), + ('WAFRegional', 'change_token_in_sync'): lambda waf: core_waiter.Waiter( + 'change_token_in_sync', + waf_model('ChangeTokenInSync'), + core_waiter.NormalizedOperationMethod( + waf.get_change_token_status + )), + ('EKS', 'cluster_active'): lambda eks: core_waiter.Waiter( + 'cluster_active', + eks_model('ClusterActive'), + core_waiter.NormalizedOperationMethod( + eks.describe_cluster + )), + ('EKS', 'cluster_deleted'): lambda eks: core_waiter.Waiter( + 'cluster_deleted', + eks_model('ClusterDeleted'), + core_waiter.NormalizedOperationMethod( + eks.describe_cluster + )), + ('RDS', 'db_instance_stopped'): lambda rds: core_waiter.Waiter( + 'db_instance_stopped', + rds_model('DBInstanceStopped'), + core_waiter.NormalizedOperationMethod( + rds.describe_db_instances + )), +} + + +def get_waiter(client, waiter_name): + try: + return waiters_by_name[(client.__class__.__name__, waiter_name)](client) + except KeyError: + raise NotImplementedError("Waiter {0} could not be found for client {1}. 
Available waiters: {2}".format( + waiter_name, type(client), ', '.join(repr(k) for k in waiters_by_name.keys()))) diff --git a/test/support/integration/plugins/module_utils/azure_rm_common.py b/test/support/integration/plugins/module_utils/azure_rm_common.py new file mode 100644 index 00000000..a7b55e97 --- /dev/null +++ b/test/support/integration/plugins/module_utils/azure_rm_common.py @@ -0,0 +1,1473 @@ +# Copyright (c) 2016 Matt Davis, <mdavis@ansible.com> +# Chris Houseknecht, <house@redhat.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import os +import re +import types +import copy +import inspect +import traceback +import json + +from os.path import expanduser + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +try: + from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION +except Exception: + ANSIBLE_VERSION = 'unknown' +from ansible.module_utils.six.moves import configparser +import ansible.module_utils.six.moves.urllib.parse as urlparse + +AZURE_COMMON_ARGS = dict( + auth_source=dict( + type='str', + choices=['auto', 'cli', 'env', 'credential_file', 'msi'] + ), + profile=dict(type='str'), + subscription_id=dict(type='str'), + client_id=dict(type='str', no_log=True), + secret=dict(type='str', no_log=True), + tenant=dict(type='str', no_log=True), + ad_user=dict(type='str', no_log=True), + password=dict(type='str', no_log=True), + cloud_environment=dict(type='str', default='AzureCloud'), + cert_validation_mode=dict(type='str', choices=['validate', 'ignore']), + api_profile=dict(type='str', default='latest'), + adfs_authority_url=dict(type='str', default=None) +) + +AZURE_CREDENTIAL_ENV_MAPPING = dict( + profile='AZURE_PROFILE', + subscription_id='AZURE_SUBSCRIPTION_ID', + client_id='AZURE_CLIENT_ID', + secret='AZURE_SECRET', + tenant='AZURE_TENANT', + ad_user='AZURE_AD_USER', + password='AZURE_PASSWORD', + cloud_environment='AZURE_CLOUD_ENVIRONMENT', + cert_validation_mode='AZURE_CERT_VALIDATION_MODE', + adfs_authority_url='AZURE_ADFS_AUTHORITY_URL' +) + + +class SDKProfile(object): # pylint: disable=too-few-public-methods + + def __init__(self, default_api_version, profile=None): + """Constructor. + + :param str default_api_version: Default API version if not overridden by a profile. Nullable. + :param profile: A dict operation group name to API version. + :type profile: dict[str, str] + """ + self.profile = profile if profile is not None else {} + self.profile[None] = default_api_version + + @property + def default_api_version(self): + return self.profile[None] + + +# FUTURE: this should come from the SDK or an external location. 
+# For now, we have to copy from azure-cli +AZURE_API_PROFILES = { + 'latest': { + 'ContainerInstanceManagementClient': '2018-02-01-preview', + 'ComputeManagementClient': dict( + default_api_version='2018-10-01', + resource_skus='2018-10-01', + disks='2018-06-01', + snapshots='2018-10-01', + virtual_machine_run_commands='2018-10-01' + ), + 'NetworkManagementClient': '2018-08-01', + 'ResourceManagementClient': '2017-05-10', + 'StorageManagementClient': '2017-10-01', + 'WebSiteManagementClient': '2018-02-01', + 'PostgreSQLManagementClient': '2017-12-01', + 'MySQLManagementClient': '2017-12-01', + 'MariaDBManagementClient': '2019-03-01', + 'ManagementLockClient': '2016-09-01' + }, + '2019-03-01-hybrid': { + 'StorageManagementClient': '2017-10-01', + 'NetworkManagementClient': '2017-10-01', + 'ComputeManagementClient': SDKProfile('2017-12-01', { + 'resource_skus': '2017-09-01', + 'disks': '2017-03-30', + 'snapshots': '2017-03-30' + }), + 'ManagementLinkClient': '2016-09-01', + 'ManagementLockClient': '2016-09-01', + 'PolicyClient': '2016-12-01', + 'ResourceManagementClient': '2018-05-01', + 'SubscriptionClient': '2016-06-01', + 'DnsManagementClient': '2016-04-01', + 'KeyVaultManagementClient': '2016-10-01', + 'AuthorizationManagementClient': SDKProfile('2015-07-01', { + 'classic_administrators': '2015-06-01', + 'policy_assignments': '2016-12-01', + 'policy_definitions': '2016-12-01' + }), + 'KeyVaultClient': '2016-10-01', + 'azure.multiapi.storage': '2017-11-09', + 'azure.multiapi.cosmosdb': '2017-04-17' + }, + '2018-03-01-hybrid': { + 'StorageManagementClient': '2016-01-01', + 'NetworkManagementClient': '2017-10-01', + 'ComputeManagementClient': SDKProfile('2017-03-30'), + 'ManagementLinkClient': '2016-09-01', + 'ManagementLockClient': '2016-09-01', + 'PolicyClient': '2016-12-01', + 'ResourceManagementClient': '2018-02-01', + 'SubscriptionClient': '2016-06-01', + 'DnsManagementClient': '2016-04-01', + 'KeyVaultManagementClient': '2016-10-01', + 'AuthorizationManagementClient': SDKProfile('2015-07-01', { + 'classic_administrators': '2015-06-01' + }), + 'KeyVaultClient': '2016-10-01', + 'azure.multiapi.storage': '2017-04-17', + 'azure.multiapi.cosmosdb': '2017-04-17' + }, + '2017-03-09-profile': { + 'StorageManagementClient': '2016-01-01', + 'NetworkManagementClient': '2015-06-15', + 'ComputeManagementClient': SDKProfile('2016-03-30'), + 'ManagementLinkClient': '2016-09-01', + 'ManagementLockClient': '2015-01-01', + 'PolicyClient': '2015-10-01-preview', + 'ResourceManagementClient': '2016-02-01', + 'SubscriptionClient': '2016-06-01', + 'DnsManagementClient': '2016-04-01', + 'KeyVaultManagementClient': '2016-10-01', + 'AuthorizationManagementClient': SDKProfile('2015-07-01', { + 'classic_administrators': '2015-06-01' + }), + 'KeyVaultClient': '2016-10-01', + 'azure.multiapi.storage': '2015-04-05' + } +} + +AZURE_TAG_ARGS = dict( + tags=dict(type='dict'), + append_tags=dict(type='bool', default=True), +) + +AZURE_COMMON_REQUIRED_IF = [ + ('log_mode', 'file', ['log_path']) +] + +ANSIBLE_USER_AGENT = 'Ansible/{0}'.format(ANSIBLE_VERSION) +CLOUDSHELL_USER_AGENT_KEY = 'AZURE_HTTP_USER_AGENT' +VSCODEEXT_USER_AGENT_KEY = 'VSCODEEXT_USER_AGENT' + +CIDR_PATTERN = re.compile(r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1" + r"[0-9]{2}|2[0-4][0-9]|25[0-5])(/([0-9]|[1-2][0-9]|3[0-2]))") + +AZURE_SUCCESS_STATE = "Succeeded" +AZURE_FAILED_STATE = "Failed" + +HAS_AZURE = True +HAS_AZURE_EXC = None +HAS_AZURE_CLI_CORE = True +HAS_AZURE_CLI_CORE_EXC = None + +HAS_MSRESTAZURE = True 
+HAS_MSRESTAZURE_EXC = None + +try: + import importlib +except ImportError: + # This passes the sanity import test, but does not provide a user friendly error message. + # Doing so would require catching Exception for all imports of Azure dependencies in modules and module_utils. + importlib = None + +try: + from packaging.version import Version + HAS_PACKAGING_VERSION = True + HAS_PACKAGING_VERSION_EXC = None +except ImportError: + Version = None + HAS_PACKAGING_VERSION = False + HAS_PACKAGING_VERSION_EXC = traceback.format_exc() + +# NB: packaging issue sometimes cause msrestazure not to be installed, check it separately +try: + from msrest.serialization import Serializer +except ImportError: + HAS_MSRESTAZURE_EXC = traceback.format_exc() + HAS_MSRESTAZURE = False + +try: + from enum import Enum + from msrestazure.azure_active_directory import AADTokenCredentials + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_active_directory import MSIAuthentication + from msrestazure.tools import parse_resource_id, resource_id, is_valid_resource_id + from msrestazure import azure_cloud + from azure.common.credentials import ServicePrincipalCredentials, UserPassCredentials + from azure.mgmt.monitor.version import VERSION as monitor_client_version + from azure.mgmt.network.version import VERSION as network_client_version + from azure.mgmt.storage.version import VERSION as storage_client_version + from azure.mgmt.compute.version import VERSION as compute_client_version + from azure.mgmt.resource.version import VERSION as resource_client_version + from azure.mgmt.dns.version import VERSION as dns_client_version + from azure.mgmt.web.version import VERSION as web_client_version + from azure.mgmt.network import NetworkManagementClient + from azure.mgmt.resource.resources import ResourceManagementClient + from azure.mgmt.resource.subscriptions import SubscriptionClient + from azure.mgmt.storage import StorageManagementClient + from azure.mgmt.compute import ComputeManagementClient + from azure.mgmt.dns import DnsManagementClient + from azure.mgmt.monitor import MonitorManagementClient + from azure.mgmt.web import WebSiteManagementClient + from azure.mgmt.containerservice import ContainerServiceClient + from azure.mgmt.marketplaceordering import MarketplaceOrderingAgreements + from azure.mgmt.trafficmanager import TrafficManagerManagementClient + from azure.storage.cloudstorageaccount import CloudStorageAccount + from azure.storage.blob import PageBlobService, BlockBlobService + from adal.authentication_context import AuthenticationContext + from azure.mgmt.sql import SqlManagementClient + from azure.mgmt.servicebus import ServiceBusManagementClient + import azure.mgmt.servicebus.models as ServicebusModel + from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient + from azure.mgmt.rdbms.mysql import MySQLManagementClient + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from azure.mgmt.containerregistry import ContainerRegistryManagementClient + from azure.mgmt.containerinstance import ContainerInstanceManagementClient + from azure.mgmt.loganalytics import LogAnalyticsManagementClient + import azure.mgmt.loganalytics.models as LogAnalyticsModels + from azure.mgmt.automation import AutomationClient + import azure.mgmt.automation.models as AutomationModel + from azure.mgmt.iothub import IotHubClient + from azure.mgmt.iothub import models as IoTHubModels + from msrest.service_client import ServiceClient + from msrestazure import AzureConfiguration + from 
msrest.authentication import Authentication + from azure.mgmt.resource.locks import ManagementLockClient +except ImportError as exc: + Authentication = object + HAS_AZURE_EXC = traceback.format_exc() + HAS_AZURE = False + +from base64 import b64encode, b64decode +from hashlib import sha256 +from hmac import HMAC +from time import time + +try: + from urllib import (urlencode, quote_plus) +except ImportError: + from urllib.parse import (urlencode, quote_plus) + +try: + from azure.cli.core.util import CLIError + from azure.common.credentials import get_azure_cli_credentials, get_cli_profile + from azure.common.cloud import get_cli_active_cloud +except ImportError: + HAS_AZURE_CLI_CORE = False + HAS_AZURE_CLI_CORE_EXC = None + CLIError = Exception + + +def azure_id_to_dict(id): + pieces = re.sub(r'^\/', '', id).split('/') + result = {} + index = 0 + while index < len(pieces) - 1: + result[pieces[index]] = pieces[index + 1] + index += 1 + return result + + +def format_resource_id(val, subscription_id, namespace, types, resource_group): + return resource_id(name=val, + resource_group=resource_group, + namespace=namespace, + type=types, + subscription=subscription_id) if not is_valid_resource_id(val) else val + + +def normalize_location_name(name): + return name.replace(' ', '').lower() + + +# FUTURE: either get this from the requirements file (if we can be sure it's always available at runtime) +# or generate the requirements files from this so we only have one source of truth to maintain... +AZURE_PKG_VERSIONS = { + 'StorageManagementClient': { + 'package_name': 'storage', + 'expected_version': '3.1.0' + }, + 'ComputeManagementClient': { + 'package_name': 'compute', + 'expected_version': '4.4.0' + }, + 'ContainerInstanceManagementClient': { + 'package_name': 'containerinstance', + 'expected_version': '0.4.0' + }, + 'NetworkManagementClient': { + 'package_name': 'network', + 'expected_version': '2.3.0' + }, + 'ResourceManagementClient': { + 'package_name': 'resource', + 'expected_version': '2.1.0' + }, + 'DnsManagementClient': { + 'package_name': 'dns', + 'expected_version': '2.1.0' + }, + 'WebSiteManagementClient': { + 'package_name': 'web', + 'expected_version': '0.41.0' + }, + 'TrafficManagerManagementClient': { + 'package_name': 'trafficmanager', + 'expected_version': '0.50.0' + }, +} if HAS_AZURE else {} + + +AZURE_MIN_RELEASE = '2.0.0' + + +class AzureRMModuleBase(object): + def __init__(self, derived_arg_spec, bypass_checks=False, no_log=False, + mutually_exclusive=None, required_together=None, + required_one_of=None, add_file_common_args=False, supports_check_mode=False, + required_if=None, supports_tags=True, facts_module=False, skip_exec=False): + + merged_arg_spec = dict() + merged_arg_spec.update(AZURE_COMMON_ARGS) + if supports_tags: + merged_arg_spec.update(AZURE_TAG_ARGS) + + if derived_arg_spec: + merged_arg_spec.update(derived_arg_spec) + + merged_required_if = list(AZURE_COMMON_REQUIRED_IF) + if required_if: + merged_required_if += required_if + + self.module = AnsibleModule(argument_spec=merged_arg_spec, + bypass_checks=bypass_checks, + no_log=no_log, + mutually_exclusive=mutually_exclusive, + required_together=required_together, + required_one_of=required_one_of, + add_file_common_args=add_file_common_args, + supports_check_mode=supports_check_mode, + required_if=merged_required_if) + + if not HAS_PACKAGING_VERSION: + self.fail(msg=missing_required_lib('packaging'), + exception=HAS_PACKAGING_VERSION_EXC) + + if not HAS_MSRESTAZURE: + 
self.fail(msg=missing_required_lib('msrestazure'), + exception=HAS_MSRESTAZURE_EXC) + + if not HAS_AZURE: + self.fail(msg=missing_required_lib('ansible[azure] (azure >= {0})'.format(AZURE_MIN_RELEASE)), + exception=HAS_AZURE_EXC) + + self._network_client = None + self._storage_client = None + self._resource_client = None + self._compute_client = None + self._dns_client = None + self._web_client = None + self._marketplace_client = None + self._sql_client = None + self._mysql_client = None + self._mariadb_client = None + self._postgresql_client = None + self._containerregistry_client = None + self._containerinstance_client = None + self._containerservice_client = None + self._managedcluster_client = None + self._traffic_manager_management_client = None + self._monitor_client = None + self._resource = None + self._log_analytics_client = None + self._servicebus_client = None + self._automation_client = None + self._IoThub_client = None + self._lock_client = None + + self.check_mode = self.module.check_mode + self.api_profile = self.module.params.get('api_profile') + self.facts_module = facts_module + # self.debug = self.module.params.get('debug') + + # delegate auth to AzureRMAuth class (shared with all plugin types) + self.azure_auth = AzureRMAuth(fail_impl=self.fail, **self.module.params) + + # common parameter validation + if self.module.params.get('tags'): + self.validate_tags(self.module.params['tags']) + + if not skip_exec: + res = self.exec_module(**self.module.params) + self.module.exit_json(**res) + + def check_client_version(self, client_type): + # Ensure Azure modules are at least 2.0.0rc5. + package_version = AZURE_PKG_VERSIONS.get(client_type.__name__, None) + if package_version is not None: + client_name = package_version.get('package_name') + try: + client_module = importlib.import_module(client_type.__module__) + client_version = client_module.VERSION + except (RuntimeError, AttributeError): + # can't get at the module version for some reason, just fail silently... + return + expected_version = package_version.get('expected_version') + if Version(client_version) < Version(expected_version): + self.fail("Installed azure-mgmt-{0} client version is {1}. The minimum supported version is {2}. Try " + "`pip install ansible[azure]`".format(client_name, client_version, expected_version)) + if Version(client_version) != Version(expected_version): + self.module.warn("Installed azure-mgmt-{0} client version is {1}. The expected version is {2}. Try " + "`pip install ansible[azure]`".format(client_name, client_version, expected_version)) + + def exec_module(self, **kwargs): + self.fail("Error: {0} failed to implement exec_module method.".format(self.__class__.__name__)) + + def fail(self, msg, **kwargs): + ''' + Shortcut for calling module.fail() + + :param msg: Error message text. + :param kwargs: Any key=value pairs + :return: None + ''' + self.module.fail_json(msg=msg, **kwargs) + + def deprecate(self, msg, version=None, collection_name=None): + self.module.deprecate(msg, version, collection_name=collection_name) + + def log(self, msg, pretty_print=False): + if pretty_print: + self.module.debug(json.dumps(msg, indent=4, sort_keys=True)) + else: + self.module.debug(msg) + + def validate_tags(self, tags): + ''' + Check if tags dictionary contains string:string pairs. 
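+ Fails the module when tags is not a dictionary or a value is not a string; the check is skipped for facts modules.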
+ + :param tags: dictionary of string:string pairs + :return: None + ''' + if not self.facts_module: + if not isinstance(tags, dict): + self.fail("Tags must be a dictionary of string:string values.") + for key, value in tags.items(): + if not isinstance(value, str): + self.fail("Tags values must be strings. Found {0}:{1}".format(str(key), str(value))) + + def update_tags(self, tags): + ''' + Call from the module to update metadata tags. Returns tuple + with bool indicating if there was a change and dict of new + tags to assign to the object. + + :param tags: metadata tags from the object + :return: bool, dict + ''' + tags = tags or dict() + new_tags = copy.copy(tags) if isinstance(tags, dict) else dict() + param_tags = self.module.params.get('tags') if isinstance(self.module.params.get('tags'), dict) else dict() + append_tags = self.module.params.get('append_tags') if self.module.params.get('append_tags') is not None else True + changed = False + # check add or update + for key, value in param_tags.items(): + if not new_tags.get(key) or new_tags[key] != value: + changed = True + new_tags[key] = value + # check remove + if not append_tags: + for key, value in tags.items(): + if not param_tags.get(key): + new_tags.pop(key) + changed = True + return changed, new_tags + + def has_tags(self, obj_tags, tag_list): + ''' + Used in fact modules to compare object tags to list of parameter tags. Return true if list of parameter tags + exists in object tags. + + :param obj_tags: dictionary of tags from an Azure object. + :param tag_list: list of tag keys or tag key:value pairs + :return: bool + ''' + + if not obj_tags and tag_list: + return False + + if not tag_list: + return True + + matches = 0 + result = False + for tag in tag_list: + tag_key = tag + tag_value = None + if ':' in tag: + tag_key, tag_value = tag.split(':') + if tag_value and obj_tags.get(tag_key) == tag_value: + matches += 1 + elif not tag_value and obj_tags.get(tag_key): + matches += 1 + if matches == len(tag_list): + result = True + return result + + def get_resource_group(self, resource_group): + ''' + Fetch a resource group. + + :param resource_group: name of a resource group + :return: resource group object + ''' + try: + return self.rm_client.resource_groups.get(resource_group) + except CloudError as cloud_error: + self.fail("Error retrieving resource group {0} - {1}".format(resource_group, cloud_error.message)) + except Exception as exc: + self.fail("Error retrieving resource group {0} - {1}".format(resource_group, str(exc))) + + def parse_resource_to_dict(self, resource): + ''' + Return a dict of the give resource, which contains name and resource group. + + :param resource: It can be a resource name, id or a dict contains name and resource group. + ''' + resource_dict = parse_resource_id(resource) if not isinstance(resource, dict) else resource + resource_dict['resource_group'] = resource_dict.get('resource_group', self.resource_group) + resource_dict['subscription_id'] = resource_dict.get('subscription_id', self.subscription_id) + return resource_dict + + def serialize_obj(self, obj, class_name, enum_modules=None): + ''' + Return a JSON representation of an Azure object. + + :param obj: Azure object + :param class_name: Name of the object's class + :param enum_modules: List of module names to build enum dependencies from. 
+ :return: serialized result + ''' + enum_modules = [] if enum_modules is None else enum_modules + + dependencies = dict() + if enum_modules: + for module_name in enum_modules: + mod = importlib.import_module(module_name) + for mod_class_name, mod_class_obj in inspect.getmembers(mod, predicate=inspect.isclass): + dependencies[mod_class_name] = mod_class_obj + self.log("dependencies: ") + self.log(str(dependencies)) + serializer = Serializer(classes=dependencies) + return serializer.body(obj, class_name, keep_readonly=True) + + def get_poller_result(self, poller, wait=5): + ''' + Consistent method of waiting on and retrieving results from Azure's long poller + + :param poller Azure poller object + :return object resulting from the original request + ''' + try: + delay = wait + while not poller.done(): + self.log("Waiting for {0} sec".format(delay)) + poller.wait(timeout=delay) + return poller.result() + except Exception as exc: + self.log(str(exc)) + raise + + def check_provisioning_state(self, azure_object, requested_state='present'): + ''' + Check an Azure object's provisioning state. If something did not complete the provisioning + process, then we cannot operate on it. + + :param azure_object An object such as a subnet, storageaccount, etc. Must have provisioning_state + and name attributes. + :return None + ''' + + if hasattr(azure_object, 'properties') and hasattr(azure_object.properties, 'provisioning_state') and \ + hasattr(azure_object, 'name'): + # resource group object fits this model + if isinstance(azure_object.properties.provisioning_state, Enum): + if azure_object.properties.provisioning_state.value != AZURE_SUCCESS_STATE and \ + requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format( + azure_object.name, azure_object.properties.provisioning_state, AZURE_SUCCESS_STATE)) + return + if azure_object.properties.provisioning_state != AZURE_SUCCESS_STATE and \ + requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format( + azure_object.name, azure_object.properties.provisioning_state, AZURE_SUCCESS_STATE)) + return + + if hasattr(azure_object, 'provisioning_state') or not hasattr(azure_object, 'name'): + if isinstance(azure_object.provisioning_state, Enum): + if azure_object.provisioning_state.value != AZURE_SUCCESS_STATE and requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. Expecting state to be {2}.".format( + azure_object.name, azure_object.provisioning_state, AZURE_SUCCESS_STATE)) + return + if azure_object.provisioning_state != AZURE_SUCCESS_STATE and requested_state != 'absent': + self.fail("Error {0} has a provisioning state of {1}. 
Expecting state to be {2}.".format( + azure_object.name, azure_object.provisioning_state, AZURE_SUCCESS_STATE)) + + def get_blob_client(self, resource_group_name, storage_account_name, storage_blob_type='block'): + keys = dict() + try: + # Get keys from the storage account + self.log('Getting keys') + account_keys = self.storage_client.storage_accounts.list_keys(resource_group_name, storage_account_name) + except Exception as exc: + self.fail("Error getting keys for account {0} - {1}".format(storage_account_name, str(exc))) + + try: + self.log('Create blob service') + if storage_blob_type == 'page': + return PageBlobService(endpoint_suffix=self._cloud_environment.suffixes.storage_endpoint, + account_name=storage_account_name, + account_key=account_keys.keys[0].value) + elif storage_blob_type == 'block': + return BlockBlobService(endpoint_suffix=self._cloud_environment.suffixes.storage_endpoint, + account_name=storage_account_name, + account_key=account_keys.keys[0].value) + else: + raise Exception("Invalid storage blob type defined.") + except Exception as exc: + self.fail("Error creating blob service client for storage account {0} - {1}".format(storage_account_name, + str(exc))) + + def create_default_pip(self, resource_group, location, public_ip_name, allocation_method='Dynamic', sku=None): + ''' + Create a default public IP address <public_ip_name> to associate with a network interface. + If a PIP address matching <public_ip_name> exists, return it. Otherwise, create one. + + :param resource_group: name of an existing resource group + :param location: a valid azure location + :param public_ip_name: base name to assign the public IP address + :param allocation_method: one of 'Static' or 'Dynamic' + :param sku: sku + :return: PIP object + ''' + pip = None + + self.log("Starting create_default_pip {0}".format(public_ip_name)) + self.log("Check to see if public IP {0} exists".format(public_ip_name)) + try: + pip = self.network_client.public_ip_addresses.get(resource_group, public_ip_name) + except CloudError: + pass + + if pip: + self.log("Public ip {0} found.".format(public_ip_name)) + self.check_provisioning_state(pip) + return pip + + params = self.network_models.PublicIPAddress( + location=location, + public_ip_allocation_method=allocation_method, + sku=sku + ) + self.log('Creating default public IP {0}'.format(public_ip_name)) + try: + poller = self.network_client.public_ip_addresses.create_or_update(resource_group, public_ip_name, params) + except Exception as exc: + self.fail("Error creating {0} - {1}".format(public_ip_name, str(exc))) + + return self.get_poller_result(poller) + + def create_default_securitygroup(self, resource_group, location, security_group_name, os_type, open_ports): + ''' + Create a default security group <security_group_name> to associate with a network interface. If a security group matching + <security_group_name> exists, return it. Otherwise, create one. + + :param resource_group: Resource group name + :param location: azure location name + :param security_group_name: base name to use for the security group + :param os_type: one of 'Windows' or 'Linux'. Determins any default rules added to the security group. + :param ssh_port: for os_type 'Linux' port used in rule allowing SSH access. + :param rdp_port: for os_type 'Windows' port used in rule allowing RDP access. 
+ :return: security_group object + ''' + group = None + + self.log("Create security group {0}".format(security_group_name)) + self.log("Check to see if security group {0} exists".format(security_group_name)) + try: + group = self.network_client.network_security_groups.get(resource_group, security_group_name) + except CloudError: + pass + + if group: + self.log("Security group {0} found.".format(security_group_name)) + self.check_provisioning_state(group) + return group + + parameters = self.network_models.NetworkSecurityGroup() + parameters.location = location + + if not open_ports: + # Open default ports based on OS type + if os_type == 'Linux': + # add an inbound SSH rule + parameters.security_rules = [ + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + description='Allow SSH Access', + source_port_range='*', + destination_port_range='22', + priority=100, + name='SSH') + ] + parameters.location = location + else: + # for windows add inbound RDP and WinRM rules + parameters.security_rules = [ + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + description='Allow RDP port 3389', + source_port_range='*', + destination_port_range='3389', + priority=100, + name='RDP01'), + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + description='Allow WinRM HTTPS port 5986', + source_port_range='*', + destination_port_range='5986', + priority=101, + name='WinRM01'), + ] + else: + # Open custom ports + parameters.security_rules = [] + priority = 100 + for port in open_ports: + priority += 1 + rule_name = "Rule_{0}".format(priority) + parameters.security_rules.append( + self.network_models.SecurityRule(protocol='Tcp', + source_address_prefix='*', + destination_address_prefix='*', + access='Allow', + direction='Inbound', + source_port_range='*', + destination_port_range=str(port), + priority=priority, + name=rule_name) + ) + + self.log('Creating default security group {0}'.format(security_group_name)) + try: + poller = self.network_client.network_security_groups.create_or_update(resource_group, + security_group_name, + parameters) + except Exception as exc: + self.fail("Error creating default security rule {0} - {1}".format(security_group_name, str(exc))) + + return self.get_poller_result(poller) + + @staticmethod + def _validation_ignore_callback(session, global_config, local_config, **kwargs): + session.verify = False + + def get_api_profile(self, client_type_name, api_profile_name): + profile_all_clients = AZURE_API_PROFILES.get(api_profile_name) + + if not profile_all_clients: + raise KeyError("unknown Azure API profile: {0}".format(api_profile_name)) + + profile_raw = profile_all_clients.get(client_type_name, None) + + if not profile_raw: + self.module.warn("Azure API profile {0} does not define an entry for {1}".format(api_profile_name, client_type_name)) + + if isinstance(profile_raw, dict): + if not profile_raw.get('default_api_version'): + raise KeyError("Azure API profile {0} does not define 'default_api_version'".format(api_profile_name)) + return profile_raw + + # wrap basic strings in a dict that just defines the default + return dict(default_api_version=profile_raw) + + def get_mgmt_svc_client(self, client_type, base_url=None, api_version=None): + self.log('Getting management 
service client {0}'.format(client_type.__name__)) + self.check_client_version(client_type) + + client_argspec = inspect.getargspec(client_type.__init__) + + if not base_url: + # most things are resource_manager, don't make everyone specify + base_url = self.azure_auth._cloud_environment.endpoints.resource_manager + + client_kwargs = dict(credentials=self.azure_auth.azure_credentials, subscription_id=self.azure_auth.subscription_id, base_url=base_url) + + api_profile_dict = {} + + if self.api_profile: + api_profile_dict = self.get_api_profile(client_type.__name__, self.api_profile) + + # unversioned clients won't accept profile; only send it if necessary + # clients without a version specified in the profile will use the default + if api_profile_dict and 'profile' in client_argspec.args: + client_kwargs['profile'] = api_profile_dict + + # If the client doesn't accept api_version, it's unversioned. + # If it does, favor explicitly-specified api_version, fall back to api_profile + if 'api_version' in client_argspec.args: + profile_default_version = api_profile_dict.get('default_api_version', None) + if api_version or profile_default_version: + client_kwargs['api_version'] = api_version or profile_default_version + if 'profile' in client_kwargs: + # remove profile; only pass API version if specified + client_kwargs.pop('profile') + + client = client_type(**client_kwargs) + + # FUTURE: remove this once everything exposes models directly (eg, containerinstance) + try: + getattr(client, "models") + except AttributeError: + def _ansible_get_models(self, *arg, **kwarg): + return self._ansible_models + + setattr(client, '_ansible_models', importlib.import_module(client_type.__module__).models) + client.models = types.MethodType(_ansible_get_models, client) + + client.config = self.add_user_agent(client.config) + + if self.azure_auth._cert_validation_mode == 'ignore': + client.config.session_configuration_callback = self._validation_ignore_callback + + return client + + def add_user_agent(self, config): + # Add user agent for Ansible + config.add_user_agent(ANSIBLE_USER_AGENT) + # Add user agent when running from Cloud Shell + if CLOUDSHELL_USER_AGENT_KEY in os.environ: + config.add_user_agent(os.environ[CLOUDSHELL_USER_AGENT_KEY]) + # Add user agent when running from VSCode extension + if VSCODEEXT_USER_AGENT_KEY in os.environ: + config.add_user_agent(os.environ[VSCODEEXT_USER_AGENT_KEY]) + return config + + def generate_sas_token(self, **kwags): + base_url = kwags.get('base_url', None) + expiry = kwags.get('expiry', time() + 3600) + key = kwags.get('key', None) + policy = kwags.get('policy', None) + url = quote_plus(base_url) + ttl = int(expiry) + sign_key = '{0}\n{1}'.format(url, ttl) + signature = b64encode(HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest()) + result = { + 'sr': url, + 'sig': signature, + 'se': str(ttl), + } + if policy: + result['skn'] = policy + return 'SharedAccessSignature ' + urlencode(result) + + def get_data_svc_client(self, **kwags): + url = kwags.get('base_url', None) + config = AzureConfiguration(base_url='https://{0}'.format(url)) + config.credentials = AzureSASAuthentication(token=self.generate_sas_token(**kwags)) + config = self.add_user_agent(config) + return ServiceClient(creds=config.credentials, config=config) + + # passthru methods to AzureAuth instance for backcompat + @property + def credentials(self): + return self.azure_auth.credentials + + @property + def _cloud_environment(self): + return self.azure_auth._cloud_environment + + @property + def 
subscription_id(self): + return self.azure_auth.subscription_id + + @property + def storage_client(self): + self.log('Getting storage client...') + if not self._storage_client: + self._storage_client = self.get_mgmt_svc_client(StorageManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-07-01') + return self._storage_client + + @property + def storage_models(self): + return StorageManagementClient.models("2018-07-01") + + @property + def network_client(self): + self.log('Getting network client') + if not self._network_client: + self._network_client = self.get_mgmt_svc_client(NetworkManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2019-06-01') + return self._network_client + + @property + def network_models(self): + self.log("Getting network models...") + return NetworkManagementClient.models("2018-08-01") + + @property + def rm_client(self): + self.log('Getting resource manager client') + if not self._resource_client: + self._resource_client = self.get_mgmt_svc_client(ResourceManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2017-05-10') + return self._resource_client + + @property + def rm_models(self): + self.log("Getting resource manager models") + return ResourceManagementClient.models("2017-05-10") + + @property + def compute_client(self): + self.log('Getting compute client') + if not self._compute_client: + self._compute_client = self.get_mgmt_svc_client(ComputeManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2019-07-01') + return self._compute_client + + @property + def compute_models(self): + self.log("Getting compute models") + return ComputeManagementClient.models("2019-07-01") + + @property + def dns_client(self): + self.log('Getting dns client') + if not self._dns_client: + self._dns_client = self.get_mgmt_svc_client(DnsManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-05-01') + return self._dns_client + + @property + def dns_models(self): + self.log("Getting dns models...") + return DnsManagementClient.models('2018-05-01') + + @property + def web_client(self): + self.log('Getting web client') + if not self._web_client: + self._web_client = self.get_mgmt_svc_client(WebSiteManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-02-01') + return self._web_client + + @property + def containerservice_client(self): + self.log('Getting container service client') + if not self._containerservice_client: + self._containerservice_client = self.get_mgmt_svc_client(ContainerServiceClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2017-07-01') + return self._containerservice_client + + @property + def managedcluster_models(self): + self.log("Getting container service models") + return ContainerServiceClient.models('2018-03-31') + + @property + def managedcluster_client(self): + self.log('Getting container service client') + if not self._managedcluster_client: + self._managedcluster_client = self.get_mgmt_svc_client(ContainerServiceClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-03-31') + return self._managedcluster_client + + @property + def sql_client(self): + self.log('Getting SQL client') + if not self._sql_client: + self._sql_client = self.get_mgmt_svc_client(SqlManagementClient, + 
base_url=self._cloud_environment.endpoints.resource_manager) + return self._sql_client + + @property + def postgresql_client(self): + self.log('Getting PostgreSQL client') + if not self._postgresql_client: + self._postgresql_client = self.get_mgmt_svc_client(PostgreSQLManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._postgresql_client + + @property + def mysql_client(self): + self.log('Getting MySQL client') + if not self._mysql_client: + self._mysql_client = self.get_mgmt_svc_client(MySQLManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._mysql_client + + @property + def mariadb_client(self): + self.log('Getting MariaDB client') + if not self._mariadb_client: + self._mariadb_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._mariadb_client + + @property + def sql_client(self): + self.log('Getting SQL client') + if not self._sql_client: + self._sql_client = self.get_mgmt_svc_client(SqlManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._sql_client + + @property + def containerregistry_client(self): + self.log('Getting container registry mgmt client') + if not self._containerregistry_client: + self._containerregistry_client = self.get_mgmt_svc_client(ContainerRegistryManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2017-10-01') + + return self._containerregistry_client + + @property + def containerinstance_client(self): + self.log('Getting container instance mgmt client') + if not self._containerinstance_client: + self._containerinstance_client = self.get_mgmt_svc_client(ContainerInstanceManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2018-06-01') + + return self._containerinstance_client + + @property + def marketplace_client(self): + self.log('Getting marketplace agreement client') + if not self._marketplace_client: + self._marketplace_client = self.get_mgmt_svc_client(MarketplaceOrderingAgreements, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._marketplace_client + + @property + def traffic_manager_management_client(self): + self.log('Getting traffic manager client') + if not self._traffic_manager_management_client: + self._traffic_manager_management_client = self.get_mgmt_svc_client(TrafficManagerManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._traffic_manager_management_client + + @property + def monitor_client(self): + self.log('Getting monitor client') + if not self._monitor_client: + self._monitor_client = self.get_mgmt_svc_client(MonitorManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._monitor_client + + @property + def log_analytics_client(self): + self.log('Getting log analytics client') + if not self._log_analytics_client: + self._log_analytics_client = self.get_mgmt_svc_client(LogAnalyticsManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._log_analytics_client + + @property + def log_analytics_models(self): + self.log('Getting log analytics models') + return LogAnalyticsModels + + @property + def servicebus_client(self): + self.log('Getting servicebus client') + if not self._servicebus_client: + self._servicebus_client = self.get_mgmt_svc_client(ServiceBusManagementClient, + 
base_url=self._cloud_environment.endpoints.resource_manager) + return self._servicebus_client + + @property + def servicebus_models(self): + return ServicebusModel + + @property + def automation_client(self): + self.log('Getting automation client') + if not self._automation_client: + self._automation_client = self.get_mgmt_svc_client(AutomationClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._automation_client + + @property + def automation_models(self): + return AutomationModel + + @property + def IoThub_client(self): + self.log('Getting iothub client') + if not self._IoThub_client: + self._IoThub_client = self.get_mgmt_svc_client(IotHubClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._IoThub_client + + @property + def IoThub_models(self): + return IoTHubModels + + @property + def automation_client(self): + self.log('Getting automation client') + if not self._automation_client: + self._automation_client = self.get_mgmt_svc_client(AutomationClient, + base_url=self._cloud_environment.endpoints.resource_manager) + return self._automation_client + + @property + def automation_models(self): + return AutomationModel + + @property + def lock_client(self): + self.log('Getting lock client') + if not self._lock_client: + self._lock_client = self.get_mgmt_svc_client(ManagementLockClient, + base_url=self._cloud_environment.endpoints.resource_manager, + api_version='2016-09-01') + return self._lock_client + + @property + def lock_models(self): + self.log("Getting lock models") + return ManagementLockClient.models('2016-09-01') + + +class AzureSASAuthentication(Authentication): + """Simple SAS Authentication. + An implementation of Authentication in + https://github.com/Azure/msrest-for-python/blob/0732bc90bdb290e5f58c675ffdd7dbfa9acefc93/msrest/authentication.py + + :param str token: SAS token + """ + def __init__(self, token): + self.token = token + + def signed_session(self): + session = super(AzureSASAuthentication, self).signed_session() + session.headers['Authorization'] = self.token + return session + + +class AzureRMAuthException(Exception): + pass + + +class AzureRMAuth(object): + def __init__(self, auth_source='auto', profile=None, subscription_id=None, client_id=None, secret=None, + tenant=None, ad_user=None, password=None, cloud_environment='AzureCloud', cert_validation_mode='validate', + api_profile='latest', adfs_authority_url=None, fail_impl=None, **kwargs): + + if fail_impl: + self._fail_impl = fail_impl + else: + self._fail_impl = self._default_fail_impl + + self._cloud_environment = None + self._adfs_authority_url = None + + # authenticate + self.credentials = self._get_credentials( + dict(auth_source=auth_source, profile=profile, subscription_id=subscription_id, client_id=client_id, secret=secret, + tenant=tenant, ad_user=ad_user, password=password, cloud_environment=cloud_environment, + cert_validation_mode=cert_validation_mode, api_profile=api_profile, adfs_authority_url=adfs_authority_url)) + + if not self.credentials: + if HAS_AZURE_CLI_CORE: + self.fail("Failed to get credentials. Either pass as parameters, set environment variables, " + "define a profile in ~/.azure/credentials, or log in with Azure CLI (`az login`).") + else: + self.fail("Failed to get credentials. 
Either pass as parameters, set environment variables, " + "define a profile in ~/.azure/credentials, or install Azure CLI and log in (`az login`).") + + # cert validation mode precedence: module-arg, credential profile, env, "validate" + self._cert_validation_mode = cert_validation_mode or self.credentials.get('cert_validation_mode') or \ + os.environ.get('AZURE_CERT_VALIDATION_MODE') or 'validate' + + if self._cert_validation_mode not in ['validate', 'ignore']: + self.fail('invalid cert_validation_mode: {0}'.format(self._cert_validation_mode)) + + # if cloud_environment specified, look up/build Cloud object + raw_cloud_env = self.credentials.get('cloud_environment') + if self.credentials.get('credentials') is not None and raw_cloud_env is not None: + self._cloud_environment = raw_cloud_env + elif not raw_cloud_env: + self._cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD # SDK default + else: + # try to look up "well-known" values via the name attribute on azure_cloud members + all_clouds = [x[1] for x in inspect.getmembers(azure_cloud) if isinstance(x[1], azure_cloud.Cloud)] + matched_clouds = [x for x in all_clouds if x.name == raw_cloud_env] + if len(matched_clouds) == 1: + self._cloud_environment = matched_clouds[0] + elif len(matched_clouds) > 1: + self.fail("Azure SDK failure: more than one cloud matched for cloud_environment name '{0}'".format(raw_cloud_env)) + else: + if not urlparse.urlparse(raw_cloud_env).scheme: + self.fail("cloud_environment must be an endpoint discovery URL or one of {0}".format([x.name for x in all_clouds])) + try: + self._cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(raw_cloud_env) + except Exception as e: + self.fail("cloud_environment {0} could not be resolved: {1}".format(raw_cloud_env, e.message), exception=traceback.format_exc()) + + if self.credentials.get('subscription_id', None) is None and self.credentials.get('credentials') is None: + self.fail("Credentials did not include a subscription_id value.") + self.log("setting subscription_id") + self.subscription_id = self.credentials['subscription_id'] + + # get authentication authority + # for adfs, user could pass in authority or not. 
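The fallback chain used above for cert_validation_mode (module argument, then credential profile, then environment variable, then 'validate') is a pattern worth seeing in isolation. A minimal sketch, assuming nothing beyond the environment variable name and the two accepted values; the function and argument names are illustrative only:

import os

def resolve_cert_validation_mode(arg_value=None, profile_value=None):
    # module argument -> credential profile -> environment variable -> 'validate'
    mode = (arg_value
            or profile_value
            or os.environ.get('AZURE_CERT_VALIDATION_MODE')
            or 'validate')
    if mode not in ('validate', 'ignore'):
        raise ValueError('invalid cert_validation_mode: {0}'.format(mode))
    return mode

print(resolve_cert_validation_mode())                        # validate
print(resolve_cert_validation_mode(profile_value='ignore'))  # ignore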
+ # for others, use default authority from cloud environment + if self.credentials.get('adfs_authority_url') is None: + self._adfs_authority_url = self._cloud_environment.endpoints.active_directory + else: + self._adfs_authority_url = self.credentials.get('adfs_authority_url') + + # get resource from cloud environment + self._resource = self._cloud_environment.endpoints.active_directory_resource_id + + if self.credentials.get('credentials') is not None: + # AzureCLI credentials + self.azure_credentials = self.credentials['credentials'] + elif self.credentials.get('client_id') is not None and \ + self.credentials.get('secret') is not None and \ + self.credentials.get('tenant') is not None: + self.azure_credentials = ServicePrincipalCredentials(client_id=self.credentials['client_id'], + secret=self.credentials['secret'], + tenant=self.credentials['tenant'], + cloud_environment=self._cloud_environment, + verify=self._cert_validation_mode == 'validate') + + elif self.credentials.get('ad_user') is not None and \ + self.credentials.get('password') is not None and \ + self.credentials.get('client_id') is not None and \ + self.credentials.get('tenant') is not None: + + self.azure_credentials = self.acquire_token_with_username_password( + self._adfs_authority_url, + self._resource, + self.credentials['ad_user'], + self.credentials['password'], + self.credentials['client_id'], + self.credentials['tenant']) + + elif self.credentials.get('ad_user') is not None and self.credentials.get('password') is not None: + tenant = self.credentials.get('tenant') + if not tenant: + tenant = 'common' # SDK default + + self.azure_credentials = UserPassCredentials(self.credentials['ad_user'], + self.credentials['password'], + tenant=tenant, + cloud_environment=self._cloud_environment, + verify=self._cert_validation_mode == 'validate') + else: + self.fail("Failed to authenticate with provided credentials. Some attributes were missing. " + "Credentials must include client_id, secret and tenant or ad_user and password, or " + "ad_user, password, client_id, tenant and adfs_authority_url(optional) for ADFS authentication, or " + "be logged in using AzureCLI.") + + def fail(self, msg, exception=None, **kwargs): + self._fail_impl(msg) + + def _default_fail_impl(self, msg, exception=None, **kwargs): + raise AzureRMAuthException(msg) + + def _get_profile(self, profile="default"): + path = expanduser("~/.azure/credentials") + try: + config = configparser.ConfigParser() + config.read(path) + except Exception as exc: + self.fail("Failed to access {0}. Check that the file exists and you have read " + "access. {1}".format(path, str(exc))) + credentials = dict() + for key in AZURE_CREDENTIAL_ENV_MAPPING: + try: + credentials[key] = config.get(profile, key, raw=True) + except Exception: + pass + + if credentials.get('subscription_id'): + return credentials + + return None + + def _get_msi_credentials(self, subscription_id_param=None, **kwargs): + client_id = kwargs.get('client_id', None) + credentials = MSIAuthentication(client_id=client_id) + subscription_id = subscription_id_param or os.environ.get(AZURE_CREDENTIAL_ENV_MAPPING['subscription_id'], None) + if not subscription_id: + try: + # use the first subscription of the MSI + subscription_client = SubscriptionClient(credentials) + subscription = next(subscription_client.subscriptions.list()) + subscription_id = str(subscription.subscription_id) + except Exception as exc: + self.fail("Failed to get MSI token: {0}. 
" + "Please check whether your machine enabled MSI or grant access to any subscription.".format(str(exc))) + return { + 'credentials': credentials, + 'subscription_id': subscription_id + } + + def _get_azure_cli_credentials(self): + credentials, subscription_id = get_azure_cli_credentials() + cloud_environment = get_cli_active_cloud() + + cli_credentials = { + 'credentials': credentials, + 'subscription_id': subscription_id, + 'cloud_environment': cloud_environment + } + return cli_credentials + + def _get_env_credentials(self): + env_credentials = dict() + for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items(): + env_credentials[attribute] = os.environ.get(env_variable, None) + + if env_credentials['profile']: + credentials = self._get_profile(env_credentials['profile']) + return credentials + + if env_credentials.get('subscription_id') is not None: + return env_credentials + + return None + + # TODO: use explicit kwargs instead of intermediate dict + def _get_credentials(self, params): + # Get authentication credentials. + self.log('Getting credentials') + + arg_credentials = dict() + for attribute, env_variable in AZURE_CREDENTIAL_ENV_MAPPING.items(): + arg_credentials[attribute] = params.get(attribute, None) + + auth_source = params.get('auth_source', None) + if not auth_source: + auth_source = os.environ.get('ANSIBLE_AZURE_AUTH_SOURCE', 'auto') + + if auth_source == 'msi': + self.log('Retrieving credenitals from MSI') + return self._get_msi_credentials(arg_credentials['subscription_id'], client_id=params.get('client_id', None)) + + if auth_source == 'cli': + if not HAS_AZURE_CLI_CORE: + self.fail(msg=missing_required_lib('azure-cli', reason='for `cli` auth_source'), + exception=HAS_AZURE_CLI_CORE_EXC) + try: + self.log('Retrieving credentials from Azure CLI profile') + cli_credentials = self._get_azure_cli_credentials() + return cli_credentials + except CLIError as err: + self.fail("Azure CLI profile cannot be loaded - {0}".format(err)) + + if auth_source == 'env': + self.log('Retrieving credentials from environment') + env_credentials = self._get_env_credentials() + return env_credentials + + if auth_source == 'credential_file': + self.log("Retrieving credentials from credential file") + profile = params.get('profile') or 'default' + default_credentials = self._get_profile(profile) + return default_credentials + + # auto, precedence: module parameters -> environment variables -> default profile in ~/.azure/credentials + # try module params + if arg_credentials['profile'] is not None: + self.log('Retrieving credentials with profile parameter.') + credentials = self._get_profile(arg_credentials['profile']) + return credentials + + if arg_credentials['subscription_id']: + self.log('Received credentials from parameters.') + return arg_credentials + + # try environment + env_credentials = self._get_env_credentials() + if env_credentials: + self.log('Received credentials from env.') + return env_credentials + + # try default profile from ~./azure/credentials + default_credentials = self._get_profile() + if default_credentials: + self.log('Retrieved default profile credentials from ~/.azure/credentials.') + return default_credentials + + try: + if HAS_AZURE_CLI_CORE: + self.log('Retrieving credentials from AzureCLI profile') + cli_credentials = self._get_azure_cli_credentials() + return cli_credentials + except CLIError as ce: + self.log('Error getting AzureCLI profile credentials - {0}'.format(ce)) + + return None + + def acquire_token_with_username_password(self, authority, 
resource, username, password, client_id, tenant): + authority_uri = authority + + if tenant is not None: + authority_uri = authority + '/' + tenant + + context = AuthenticationContext(authority_uri) + token_response = context.acquire_token_with_username_password(resource, username, password, client_id) + + return AADTokenCredentials(token_response) + + def log(self, msg, pretty_print=False): + pass + # Use only during module development + # if self.debug: + # log_file = open('azure_rm.log', 'a') + # if pretty_print: + # log_file.write(json.dumps(msg, indent=4, sort_keys=True)) + # else: + # log_file.write(msg + u'\n') diff --git a/test/support/integration/plugins/module_utils/azure_rm_common_rest.py b/test/support/integration/plugins/module_utils/azure_rm_common_rest.py new file mode 100644 index 00000000..4fd7eaa3 --- /dev/null +++ b/test/support/integration/plugins/module_utils/azure_rm_common_rest.py @@ -0,0 +1,97 @@ +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from ansible.module_utils.ansible_release import __version__ as ANSIBLE_VERSION + +try: + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_configuration import AzureConfiguration + from msrest.service_client import ServiceClient + from msrest.pipeline import ClientRawResponse + from msrest.polling import LROPoller + from msrestazure.polling.arm_polling import ARMPolling + import uuid + import json +except ImportError: + # This is handled in azure_rm_common + AzureConfiguration = object + +ANSIBLE_USER_AGENT = 'Ansible/{0}'.format(ANSIBLE_VERSION) + + +class GenericRestClientConfiguration(AzureConfiguration): + + def __init__(self, credentials, subscription_id, base_url=None): + + if credentials is None: + raise ValueError("Parameter 'credentials' must not be None.") + if subscription_id is None: + raise ValueError("Parameter 'subscription_id' must not be None.") + if not base_url: + base_url = 'https://management.azure.com' + + super(GenericRestClientConfiguration, self).__init__(base_url) + + self.add_user_agent(ANSIBLE_USER_AGENT) + + self.credentials = credentials + self.subscription_id = subscription_id + + +class GenericRestClient(object): + + def __init__(self, credentials, subscription_id, base_url=None): + self.config = GenericRestClientConfiguration(credentials, subscription_id, base_url) + self._client = ServiceClient(self.config.credentials, self.config) + self.models = None + + def query(self, url, method, query_parameters, header_parameters, body, expected_status_codes, polling_timeout, polling_interval): + # Construct and send request + operation_config = {} + + request = None + + if header_parameters is None: + header_parameters = {} + + header_parameters['x-ms-client-request-id'] = str(uuid.uuid1()) + + if method == 'GET': + request = self._client.get(url, query_parameters) + elif method == 'PUT': + request = self._client.put(url, query_parameters) + elif method == 'POST': + request = self._client.post(url, query_parameters) + elif method == 'HEAD': + request = self._client.head(url, query_parameters) + elif method == 'PATCH': + request = self._client.patch(url, query_parameters) + elif method == 'DELETE': + request = self._client.delete(url, query_parameters) + elif method == 'MERGE': + request = self._client.merge(url, query_parameters) + + response = self._client.send(request, header_parameters, body, **operation_config) + + if response.status_code not in 
expected_status_codes: + exp = CloudError(response) + exp.request_id = response.headers.get('x-ms-request-id') + raise exp + elif response.status_code == 202 and polling_timeout > 0: + def get_long_running_output(response): + return response + poller = LROPoller(self._client, + ClientRawResponse(None, response), + get_long_running_output, + ARMPolling(polling_interval, **operation_config)) + response = self.get_poller_result(poller, polling_timeout) + + return response + + def get_poller_result(self, poller, timeout): + try: + poller.wait(timeout=timeout) + return poller.result() + except Exception as exc: + raise diff --git a/test/support/integration/plugins/module_utils/cloud.py b/test/support/integration/plugins/module_utils/cloud.py new file mode 100644 index 00000000..0d29071f --- /dev/null +++ b/test/support/integration/plugins/module_utils/cloud.py @@ -0,0 +1,217 @@ +# +# (c) 2016 Allen Sanabria, <asanabria@linuxdynasty.org> +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. +# +""" +This module adds shared support for generic cloud modules + +In order to use this module, include it as part of a custom +module as shown below. + +from ansible.module_utils.cloud import CloudRetry + +The 'cloud' module provides the following common classes: + + * CloudRetry + - The base class to be used by other cloud providers, in order to + provide a backoff/retry decorator based on status codes. + + - Example using the AWSRetry class which inherits from CloudRetry. + + @AWSRetry.exponential_backoff(retries=10, delay=3) + get_ec2_security_group_ids_from_names() + + @AWSRetry.jittered_backoff() + get_ec2_security_group_ids_from_names() + +""" +import random +from functools import wraps +import syslog +import time + + +def _exponential_backoff(retries=10, delay=2, backoff=2, max_delay=60): + """ Customizable exponential backoff strategy. + Args: + retries (int): Maximum number of times to retry a request. + delay (float): Initial (base) delay. + backoff (float): base of the exponent to use for exponential + backoff. + max_delay (int): Optional. If provided each delay generated is capped + at this amount. Defaults to 60 seconds. + Returns: + Callable that returns a generator. This generator yields durations in + seconds to be used as delays for an exponential backoff strategy. + Usage: + >>> backoff = _exponential_backoff() + >>> backoff + <function backoff_backoff at 0x7f0d939facf8> + >>> list(backoff()) + [2, 4, 8, 16, 32, 60, 60, 60, 60, 60] + """ + def backoff_gen(): + for retry in range(0, retries): + sleep = delay * backoff ** retry + yield sleep if max_delay is None else min(sleep, max_delay) + return backoff_gen + + +def _full_jitter_backoff(retries=10, delay=3, max_delay=60, _random=random): + """ Implements the "Full Jitter" backoff strategy described here + https://www.awsarchitectureblog.com/2015/03/backoff.html + Args: + retries (int): Maximum number of times to retry a request. 
+ delay (float): Approximate number of seconds to sleep for the first + retry. + max_delay (int): The maximum number of seconds to sleep for any retry. + _random (random.Random or None): Makes this generator testable by + allowing developers to explicitly pass in the a seeded Random. + Returns: + Callable that returns a generator. This generator yields durations in + seconds to be used as delays for a full jitter backoff strategy. + Usage: + >>> backoff = _full_jitter_backoff(retries=5) + >>> backoff + <function backoff_backoff at 0x7f0d939facf8> + >>> list(backoff()) + [3, 6, 5, 23, 38] + >>> list(backoff()) + [2, 1, 6, 6, 31] + """ + def backoff_gen(): + for retry in range(0, retries): + yield _random.randint(0, min(max_delay, delay * 2 ** retry)) + return backoff_gen + + +class CloudRetry(object): + """ CloudRetry can be used by any cloud provider, in order to implement a + backoff algorithm/retry effect based on Status Code from Exceptions. + """ + # This is the base class of the exception. + # AWS Example botocore.exceptions.ClientError + base_class = None + + @staticmethod + def status_code_from_exception(error): + """ Return the status code from the exception object + Args: + error (object): The exception itself. + """ + pass + + @staticmethod + def found(response_code, catch_extra_error_codes=None): + """ Return True if the Response Code to retry on was found. + Args: + response_code (str): This is the Response Code that is being matched against. + """ + pass + + @classmethod + def _backoff(cls, backoff_strategy, catch_extra_error_codes=None): + """ Retry calling the Cloud decorated function using the provided + backoff strategy. + Args: + backoff_strategy (callable): Callable that returns a generator. The + generator should yield sleep times for each retry of the decorated + function. + """ + def deco(f): + @wraps(f) + def retry_func(*args, **kwargs): + for delay in backoff_strategy(): + try: + return f(*args, **kwargs) + except Exception as e: + if isinstance(e, cls.base_class): + response_code = cls.status_code_from_exception(e) + if cls.found(response_code, catch_extra_error_codes): + msg = "{0}: Retrying in {1} seconds...".format(str(e), delay) + syslog.syslog(syslog.LOG_INFO, msg) + time.sleep(delay) + else: + # Return original exception if exception is not a ClientError + raise e + else: + # Return original exception if exception is not a ClientError + raise e + return f(*args, **kwargs) + + return retry_func # true decorator + + return deco + + @classmethod + def exponential_backoff(cls, retries=10, delay=3, backoff=2, max_delay=60, catch_extra_error_codes=None): + """ + Retry calling the Cloud decorated function using an exponential backoff. + + Kwargs: + retries (int): Number of times to retry a failed request before giving up + default=10 + delay (int or float): Initial delay between retries in seconds + default=3 + backoff (int or float): backoff multiplier e.g. value of 2 will + double the delay each retry + default=1.1 + max_delay (int or None): maximum amount of time to wait between retries. + default=60 + """ + return cls._backoff(_exponential_backoff( + retries=retries, delay=delay, backoff=backoff, max_delay=max_delay), catch_extra_error_codes) + + @classmethod + def jittered_backoff(cls, retries=10, delay=3, max_delay=60, catch_extra_error_codes=None): + """ + Retry calling the Cloud decorated function using a jittered backoff + strategy. 
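The two generators above only yield sleep durations; CloudRetry._backoff turns them into a retry decorator. A sketch of a provider-specific subclass following the pattern the module docstring describes; the HTTPRetry name, the IOError base class and the retryable code set are illustrative, not part of this module:

from ansible.module_utils.cloud import CloudRetry

class HTTPRetry(CloudRetry):
    # Only exceptions of this type are considered retryable at all.
    base_class = IOError

    @staticmethod
    def status_code_from_exception(error):
        # Whatever the provider exposes as a status / error code.
        return getattr(error, 'errno', None)

    @staticmethod
    def found(response_code, catch_extra_error_codes=None):
        retryable = {110, 111}  # illustrative: timed out, connection refused
        if catch_extra_error_codes:
            retryable |= set(catch_extra_error_codes)
        return response_code in retryable

@HTTPRetry.jittered_backoff(retries=5, delay=2)
def flaky_call():
    pass  # call the cloud API here; retried on matching IOError codes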
More on this strategy here: + + https://www.awsarchitectureblog.com/2015/03/backoff.html + + Kwargs: + retries (int): Number of times to retry a failed request before giving up + default=10 + delay (int): Initial delay between retries in seconds + default=3 + max_delay (int): maximum amount of time to wait between retries. + default=60 + """ + return cls._backoff(_full_jitter_backoff( + retries=retries, delay=delay, max_delay=max_delay), catch_extra_error_codes) + + @classmethod + def backoff(cls, tries=10, delay=3, backoff=1.1, catch_extra_error_codes=None): + """ + Retry calling the Cloud decorated function using an exponential backoff. + + Compatibility for the original implementation of CloudRetry.backoff that + did not provide configurable backoff strategies. Developers should use + CloudRetry.exponential_backoff instead. + + Kwargs: + tries (int): Number of times to try (not retry) before giving up + default=10 + delay (int or float): Initial delay between retries in seconds + default=3 + backoff (int or float): backoff multiplier e.g. value of 2 will + double the delay each retry + default=1.1 + """ + return cls.exponential_backoff( + retries=tries - 1, delay=delay, backoff=backoff, max_delay=None, catch_extra_error_codes=catch_extra_error_codes) diff --git a/test/support/integration/plugins/module_utils/compat/__init__.py b/test/support/integration/plugins/module_utils/compat/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/compat/__init__.py diff --git a/test/support/integration/plugins/module_utils/compat/ipaddress.py b/test/support/integration/plugins/module_utils/compat/ipaddress.py new file mode 100644 index 00000000..c46ad72a --- /dev/null +++ b/test/support/integration/plugins/module_utils/compat/ipaddress.py @@ -0,0 +1,2476 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file, and this file only, is based on +# Lib/ipaddress.py of cpython +# It is licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 +# +# 1. This LICENSE AGREEMENT is between the Python Software Foundation +# ("PSF"), and the Individual or Organization ("Licensee") accessing and +# otherwise using this software ("Python") in source or binary form and +# its associated documentation. +# +# 2. Subject to the terms and conditions of this License Agreement, PSF hereby +# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +# analyze, test, perform and/or display publicly, prepare derivative works, +# distribute, and otherwise use Python alone or in any derivative version, +# provided, however, that PSF's License Agreement and PSF's notice of copyright, +# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, +# 2011, 2012, 2013, 2014, 2015 Python Software Foundation; All Rights Reserved" +# are retained in Python alone or in any derivative version prepared by Licensee. +# +# 3. In the event Licensee prepares a derivative work that is based on +# or incorporates Python or any part thereof, and wants to make +# the derivative work available to others as provided herein, then +# Licensee hereby agrees to include in any such work a brief summary of +# the changes made to Python. +# +# 4. PSF is making Python available to Licensee on an "AS IS" +# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +# IMPLIED. 
BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +# INFRINGE ANY THIRD PARTY RIGHTS. +# +# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. +# +# 6. This License Agreement will automatically terminate upon a material +# breach of its terms and conditions. +# +# 7. Nothing in this License Agreement shall be deemed to create any +# relationship of agency, partnership, or joint venture between PSF and +# Licensee. This License Agreement does not grant permission to use PSF +# trademarks or trade name in a trademark sense to endorse or promote +# products or services of Licensee, or any third party. +# +# 8. By copying, installing or otherwise using Python, Licensee +# agrees to be bound by the terms and conditions of this License +# Agreement. + +# Copyright 2007 Google Inc. +# Licensed to PSF under a Contributor Agreement. + +"""A fast, lightweight IPv4/IPv6 manipulation library in Python. + +This library is used to create/poke/manipulate IPv4 and IPv6 addresses +and networks. + +""" + +from __future__ import unicode_literals + + +import itertools +import struct + + +# The following makes it easier for us to script updates of the bundled code and is not part of +# upstream +_BUNDLED_METADATA = {"pypi_name": "ipaddress", "version": "1.0.22"} + +__version__ = '1.0.22' + +# Compatibility functions +_compat_int_types = (int,) +try: + _compat_int_types = (int, long) +except NameError: + pass +try: + _compat_str = unicode +except NameError: + _compat_str = str + assert bytes != str +if b'\0'[0] == 0: # Python 3 semantics + def _compat_bytes_to_byte_vals(byt): + return byt +else: + def _compat_bytes_to_byte_vals(byt): + return [struct.unpack(b'!B', b)[0] for b in byt] +try: + _compat_int_from_byte_vals = int.from_bytes +except AttributeError: + def _compat_int_from_byte_vals(bytvals, endianess): + assert endianess == 'big' + res = 0 + for bv in bytvals: + assert isinstance(bv, _compat_int_types) + res = (res << 8) + bv + return res + + +def _compat_to_bytes(intval, length, endianess): + assert isinstance(intval, _compat_int_types) + assert endianess == 'big' + if length == 4: + if intval < 0 or intval >= 2 ** 32: + raise struct.error("integer out of range for 'I' format code") + return struct.pack(b'!I', intval) + elif length == 16: + if intval < 0 or intval >= 2 ** 128: + raise struct.error("integer out of range for 'QQ' format code") + return struct.pack(b'!QQ', intval >> 64, intval & 0xffffffffffffffff) + else: + raise NotImplementedError() + + +if hasattr(int, 'bit_length'): + # Not int.bit_length , since that won't work in 2.7 where long exists + def _compat_bit_length(i): + return i.bit_length() +else: + def _compat_bit_length(i): + for res in itertools.count(): + if i >> res == 0: + return res + + +def _compat_range(start, end, step=1): + assert step > 0 + i = start + while i < end: + yield i + i += step + + +class _TotalOrderingMixin(object): + __slots__ = () + + # Helper that derives the other comparison operations from + # __lt__ and __eq__ + # We avoid functools.total_ordering because it doesn't handle + # NotImplemented correctly yet (http://bugs.python.org/issue10042) + def __eq__(self, other): + raise 
NotImplementedError + + def __ne__(self, other): + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not equal + + def __lt__(self, other): + raise NotImplementedError + + def __le__(self, other): + less = self.__lt__(other) + if less is NotImplemented or not less: + return self.__eq__(other) + return less + + def __gt__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + equal = self.__eq__(other) + if equal is NotImplemented: + return NotImplemented + return not (less or equal) + + def __ge__(self, other): + less = self.__lt__(other) + if less is NotImplemented: + return NotImplemented + return not less + + +IPV4LENGTH = 32 +IPV6LENGTH = 128 + + +class AddressValueError(ValueError): + """A Value Error related to the address.""" + + +class NetmaskValueError(ValueError): + """A Value Error related to the netmask.""" + + +def ip_address(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Address or IPv6Address object. + + Raises: + ValueError: if the *address* passed isn't either a v4 or a v6 + address + + """ + try: + return IPv4Address(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Address(address) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + '%r does not appear to be an IPv4 or IPv6 address. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?' % address) + + raise ValueError('%r does not appear to be an IPv4 or IPv6 address' % + address) + + +def ip_network(address, strict=True): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP network. Either IPv4 or + IPv6 networks may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Network or IPv6Network object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. Or if the network has host bits set. + + """ + try: + return IPv4Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Network(address, strict) + except (AddressValueError, NetmaskValueError): + pass + + if isinstance(address, bytes): + raise AddressValueError( + '%r does not appear to be an IPv4 or IPv6 network. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?' % address) + + raise ValueError('%r does not appear to be an IPv4 or IPv6 network' % + address) + + +def ip_interface(address): + """Take an IP string/int and return an object of the correct type. + + Args: + address: A string or integer, the IP address. Either IPv4 or + IPv6 addresses may be supplied; integers less than 2**32 will + be considered to be IPv4 by default. + + Returns: + An IPv4Interface or IPv6Interface object. + + Raises: + ValueError: if the string passed isn't either a v4 or a v6 + address. + + Notes: + The IPv?Interface classes describe an Address on a particular + Network, so they're basically a combination of both the Address + and Network classes. 
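This file is a bundled copy of the PyPI ipaddress backport, so the factory helpers above (ip_address, ip_network, ip_interface) behave like the standard library module of the same name. A short usage sketch against the stdlib for reference:

import ipaddress  # the stdlib module this file backports

addr = ipaddress.ip_address(u'192.0.2.1')          # IPv4Address
net = ipaddress.ip_network(u'2001:db8::/32')       # IPv6Network
iface = ipaddress.ip_interface(u'192.0.2.5/24')    # address plus its network

print(addr.version, net.num_addresses)             # 4 79228162514264337593543950336
print(iface.ip, iface.network)                     # 192.0.2.5 192.0.2.0/24
# With strict=True (the default) host bits raise an error:
# ipaddress.ip_network(u'192.0.2.1/24')  ->  ValueError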
+ + """ + try: + return IPv4Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + try: + return IPv6Interface(address) + except (AddressValueError, NetmaskValueError): + pass + + raise ValueError('%r does not appear to be an IPv4 or IPv6 interface' % + address) + + +def v4_int_to_packed(address): + """Represent an address as 4 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv4 IP address. + + Returns: + The integer address packed as 4 bytes in network (big-endian) order. + + Raises: + ValueError: If the integer is negative or too large to be an + IPv4 IP address. + + """ + try: + return _compat_to_bytes(address, 4, 'big') + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv4") + + +def v6_int_to_packed(address): + """Represent an address as 16 packed bytes in network (big-endian) order. + + Args: + address: An integer representation of an IPv6 IP address. + + Returns: + The integer address packed as 16 bytes in network (big-endian) order. + + """ + try: + return _compat_to_bytes(address, 16, 'big') + except (struct.error, OverflowError): + raise ValueError("Address negative or too large for IPv6") + + +def _split_optional_netmask(address): + """Helper to split the netmask and raise AddressValueError if needed""" + addr = _compat_str(address).split('/') + if len(addr) > 2: + raise AddressValueError("Only one '/' permitted in %r" % address) + return addr + + +def _find_address_range(addresses): + """Find a sequence of sorted deduplicated IPv#Address. + + Args: + addresses: a list of IPv#Address objects. + + Yields: + A tuple containing the first and last IP addresses in the sequence. + + """ + it = iter(addresses) + first = last = next(it) # pylint: disable=stop-iteration-return + for ip in it: + if ip._ip != last._ip + 1: + yield first, last + first = ip + last = ip + yield first, last + + +def _count_righthand_zero_bits(number, bits): + """Count the number of zero bits on the right hand side. + + Args: + number: an integer. + bits: maximum number of bits to count. + + Returns: + The number of zero bits on the right hand side of the number. + + """ + if number == 0: + return bits + return min(bits, _compat_bit_length(~number & (number - 1))) + + +def summarize_address_range(first, last): + """Summarize a network range given the first and last IP addresses. + + Example: + >>> list(summarize_address_range(IPv4Address('192.0.2.0'), + ... IPv4Address('192.0.2.130'))) + ... #doctest: +NORMALIZE_WHITESPACE + [IPv4Network('192.0.2.0/25'), IPv4Network('192.0.2.128/31'), + IPv4Network('192.0.2.130/32')] + + Args: + first: the first IPv4Address or IPv6Address in the range. + last: the last IPv4Address or IPv6Address in the range. + + Returns: + An iterator of the summarized IPv(4|6) network objects. + + Raise: + TypeError: + If the first and last objects are not IP addresses. + If the first and last objects are not the same version. + ValueError: + If the last object is not greater than the first. + If the version of the first address is not 4 or 6. 
+ + """ + if (not (isinstance(first, _BaseAddress) and + isinstance(last, _BaseAddress))): + raise TypeError('first and last must be IP addresses, not networks') + if first.version != last.version: + raise TypeError("%s and %s are not of the same version" % ( + first, last)) + if first > last: + raise ValueError('last IP address must be greater than first') + + if first.version == 4: + ip = IPv4Network + elif first.version == 6: + ip = IPv6Network + else: + raise ValueError('unknown IP version') + + ip_bits = first._max_prefixlen + first_int = first._ip + last_int = last._ip + while first_int <= last_int: + nbits = min(_count_righthand_zero_bits(first_int, ip_bits), + _compat_bit_length(last_int - first_int + 1) - 1) + net = ip((first_int, ip_bits - nbits)) + yield net + first_int += 1 << nbits + if first_int - 1 == ip._ALL_ONES: + break + + +def _collapse_addresses_internal(addresses): + """Loops through the addresses, collapsing concurrent netblocks. + + Example: + + ip1 = IPv4Network('192.0.2.0/26') + ip2 = IPv4Network('192.0.2.64/26') + ip3 = IPv4Network('192.0.2.128/26') + ip4 = IPv4Network('192.0.2.192/26') + + _collapse_addresses_internal([ip1, ip2, ip3, ip4]) -> + [IPv4Network('192.0.2.0/24')] + + This shouldn't be called directly; it is called via + collapse_addresses([]). + + Args: + addresses: A list of IPv4Network's or IPv6Network's + + Returns: + A list of IPv4Network's or IPv6Network's depending on what we were + passed. + + """ + # First merge + to_merge = list(addresses) + subnets = {} + while to_merge: + net = to_merge.pop() + supernet = net.supernet() + existing = subnets.get(supernet) + if existing is None: + subnets[supernet] = net + elif existing != net: + # Merge consecutive subnets + del subnets[supernet] + to_merge.append(supernet) + # Then iterate over resulting networks, skipping subsumed subnets + last = None + for net in sorted(subnets.values()): + if last is not None: + # Since they are sorted, + # last.network_address <= net.network_address is a given. + if last.broadcast_address >= net.broadcast_address: + continue + yield net + last = net + + +def collapse_addresses(addresses): + """Collapse a list of IP objects. + + Example: + collapse_addresses([IPv4Network('192.0.2.0/25'), + IPv4Network('192.0.2.128/25')]) -> + [IPv4Network('192.0.2.0/24')] + + Args: + addresses: An iterator of IPv4Network or IPv6Network objects. + + Returns: + An iterator of the collapsed IPv(4|6)Network objects. + + Raises: + TypeError: If passed a list of mixed version objects. 
+ + """ + addrs = [] + ips = [] + nets = [] + + # split IP addresses and networks + for ip in addresses: + if isinstance(ip, _BaseAddress): + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + ips.append(ip) + elif ip._prefixlen == ip._max_prefixlen: + if ips and ips[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, ips[-1])) + try: + ips.append(ip.ip) + except AttributeError: + ips.append(ip.network_address) + else: + if nets and nets[-1]._version != ip._version: + raise TypeError("%s and %s are not of the same version" % ( + ip, nets[-1])) + nets.append(ip) + + # sort and dedup + ips = sorted(set(ips)) + + # find consecutive address ranges in the sorted sequence and summarize them + if ips: + for first, last in _find_address_range(ips): + addrs.extend(summarize_address_range(first, last)) + + return _collapse_addresses_internal(addrs + nets) + + +def get_mixed_type_key(obj): + """Return a key suitable for sorting between networks and addresses. + + Address and Network objects are not sortable by default; they're + fundamentally different so the expression + + IPv4Address('192.0.2.0') <= IPv4Network('192.0.2.0/24') + + doesn't make any sense. There are some times however, where you may wish + to have ipaddress sort these for you anyway. If you need to do this, you + can use this function as the key= argument to sorted(). + + Args: + obj: either a Network or Address object. + Returns: + appropriate key. + + """ + if isinstance(obj, _BaseNetwork): + return obj._get_networks_key() + elif isinstance(obj, _BaseAddress): + return obj._get_address_key() + return NotImplemented + + +class _IPAddressBase(_TotalOrderingMixin): + + """The mother class.""" + + __slots__ = () + + @property + def exploded(self): + """Return the longhand version of the IP address as a string.""" + return self._explode_shorthand_ip_string() + + @property + def compressed(self): + """Return the shorthand version of the IP address as a string.""" + return _compat_str(self) + + @property + def reverse_pointer(self): + """The name of the reverse DNS pointer for the IP address, e.g.: + >>> ipaddress.ip_address("127.0.0.1").reverse_pointer + '1.0.0.127.in-addr.arpa' + >>> ipaddress.ip_address("2001:db8::1").reverse_pointer + '1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa' + + """ + return self._reverse_pointer() + + @property + def version(self): + msg = '%200s has no version specified' % (type(self),) + raise NotImplementedError(msg) + + def _check_int_address(self, address): + if address < 0: + msg = "%d (< 0) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._version)) + if address > self._ALL_ONES: + msg = "%d (>= 2**%d) is not permitted as an IPv%d address" + raise AddressValueError(msg % (address, self._max_prefixlen, + self._version)) + + def _check_packed_address(self, address, expected_len): + address_len = len(address) + if address_len != expected_len: + msg = ( + '%r (len %d != %d) is not permitted as an IPv%d address. ' + 'Did you pass in a bytes (str in Python 2) instead of' + ' a unicode object?') + raise AddressValueError(msg % (address, address_len, + expected_len, self._version)) + + @classmethod + def _ip_int_from_prefix(cls, prefixlen): + """Turn the prefix length into a bitwise netmask + + Args: + prefixlen: An integer, the prefix length. + + Returns: + An integer. 
+ + """ + return cls._ALL_ONES ^ (cls._ALL_ONES >> prefixlen) + + @classmethod + def _prefix_from_ip_int(cls, ip_int): + """Return prefix length from the bitwise netmask. + + Args: + ip_int: An integer, the netmask in expanded bitwise format + + Returns: + An integer, the prefix length. + + Raises: + ValueError: If the input intermingles zeroes & ones + """ + trailing_zeroes = _count_righthand_zero_bits(ip_int, + cls._max_prefixlen) + prefixlen = cls._max_prefixlen - trailing_zeroes + leading_ones = ip_int >> trailing_zeroes + all_ones = (1 << prefixlen) - 1 + if leading_ones != all_ones: + byteslen = cls._max_prefixlen // 8 + details = _compat_to_bytes(ip_int, byteslen, 'big') + msg = 'Netmask pattern %r mixes zeroes & ones' + raise ValueError(msg % details) + return prefixlen + + @classmethod + def _report_invalid_netmask(cls, netmask_str): + msg = '%r is not a valid netmask' % netmask_str + raise NetmaskValueError(msg) + + @classmethod + def _prefix_from_prefix_string(cls, prefixlen_str): + """Return prefix length from a numeric string + + Args: + prefixlen_str: The string to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask + """ + # int allows a leading +/- as well as surrounding whitespace, + # so we ensure that isn't the case + if not _BaseV4._DECIMAL_DIGITS.issuperset(prefixlen_str): + cls._report_invalid_netmask(prefixlen_str) + try: + prefixlen = int(prefixlen_str) + except ValueError: + cls._report_invalid_netmask(prefixlen_str) + if not (0 <= prefixlen <= cls._max_prefixlen): + cls._report_invalid_netmask(prefixlen_str) + return prefixlen + + @classmethod + def _prefix_from_ip_string(cls, ip_str): + """Turn a netmask/hostmask string into a prefix length + + Args: + ip_str: The netmask/hostmask to be converted + + Returns: + An integer, the prefix length. + + Raises: + NetmaskValueError: If the input is not a valid netmask/hostmask + """ + # Parse the netmask/hostmask like an IP address. + try: + ip_int = cls._ip_int_from_string(ip_str) + except AddressValueError: + cls._report_invalid_netmask(ip_str) + + # Try matching a netmask (this would be /1*0*/ as a bitwise regexp). + # Note that the two ambiguous cases (all-ones and all-zeroes) are + # treated as netmasks. + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + pass + + # Invert the bits, and try matching a /0+1+/ hostmask instead. + ip_int ^= cls._ALL_ONES + try: + return cls._prefix_from_ip_int(ip_int) + except ValueError: + cls._report_invalid_netmask(ip_str) + + def __reduce__(self): + return self.__class__, (_compat_str(self),) + + +class _BaseAddress(_IPAddressBase): + + """A generic IP object. + + This IP class contains the version independent methods which are + used by single IP addresses. + """ + + __slots__ = () + + def __int__(self): + return self._ip + + def __eq__(self, other): + try: + return (self._ip == other._ip and + self._version == other._version) + except AttributeError: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseAddress): + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self._ip != other._ip: + return self._ip < other._ip + return False + + # Shorthand for Integer addition and subtraction. 
This is not + # meant to ever support addition/subtraction of addresses. + def __add__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) + other) + + def __sub__(self, other): + if not isinstance(other, _compat_int_types): + return NotImplemented + return self.__class__(int(self) - other) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return _compat_str(self._string_from_ip_int(self._ip)) + + def __hash__(self): + return hash(hex(int(self._ip))) + + def _get_address_key(self): + return (self._version, self) + + def __reduce__(self): + return self.__class__, (self._ip,) + + +class _BaseNetwork(_IPAddressBase): + + """A generic IP network object. + + This IP class contains the version independent methods which are + used by networks. + + """ + def __init__(self, address): + self._cache = {} + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, _compat_str(self)) + + def __str__(self): + return '%s/%d' % (self.network_address, self.prefixlen) + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the network + or broadcast addresses. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast): + yield self._address_class(x) + + def __iter__(self): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network, broadcast + 1): + yield self._address_class(x) + + def __getitem__(self, n): + network = int(self.network_address) + broadcast = int(self.broadcast_address) + if n >= 0: + if network + n > broadcast: + raise IndexError('address out of range') + return self._address_class(network + n) + else: + n += 1 + if broadcast + n < network: + raise IndexError('address out of range') + return self._address_class(broadcast + n) + + def __lt__(self, other): + if not isinstance(other, _IPAddressBase): + return NotImplemented + if not isinstance(other, _BaseNetwork): + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + if self._version != other._version: + raise TypeError('%s and %s are not of the same version' % ( + self, other)) + if self.network_address != other.network_address: + return self.network_address < other.network_address + if self.netmask != other.netmask: + return self.netmask < other.netmask + return False + + def __eq__(self, other): + try: + return (self._version == other._version and + self.network_address == other.network_address and + int(self.netmask) == int(other.netmask)) + except AttributeError: + return NotImplemented + + def __hash__(self): + return hash(int(self.network_address) ^ int(self.netmask)) + + def __contains__(self, other): + # always false if one is v4 and the other is v6. + if self._version != other._version: + return False + # dealing with another network. 
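The dunder methods defined above make a network behave like a sequence of addresses: it can be iterated, indexed from either end, and tested for membership against single addresses. Exercised through the stdlib module this file backports:

import ipaddress

net = ipaddress.ip_network(u'192.0.2.0/29')
print(list(net.hosts()))                          # 192.0.2.1 .. 192.0.2.6
print(net[0], net[-1])                            # 192.0.2.0 192.0.2.7
print(ipaddress.ip_address(u'192.0.2.5') in net)  # True
print(net.num_addresses)                          # 8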
+ if isinstance(other, _BaseNetwork): + return False + # dealing with another address + else: + # address + return (int(self.network_address) <= int(other._ip) <= + int(self.broadcast_address)) + + def overlaps(self, other): + """Tell if self is partly contained in other.""" + return self.network_address in other or ( + self.broadcast_address in other or ( + other.network_address in self or ( + other.broadcast_address in self))) + + @property + def broadcast_address(self): + x = self._cache.get('broadcast_address') + if x is None: + x = self._address_class(int(self.network_address) | + int(self.hostmask)) + self._cache['broadcast_address'] = x + return x + + @property + def hostmask(self): + x = self._cache.get('hostmask') + if x is None: + x = self._address_class(int(self.netmask) ^ self._ALL_ONES) + self._cache['hostmask'] = x + return x + + @property + def with_prefixlen(self): + return '%s/%d' % (self.network_address, self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self.network_address, self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self.network_address, self.hostmask) + + @property + def num_addresses(self): + """Number of hosts in the current subnet.""" + return int(self.broadcast_address) - int(self.network_address) + 1 + + @property + def _address_class(self): + # Returning bare address objects (rather than interfaces) allows for + # more consistent behaviour across the network address, broadcast + # address and individual host addresses. + msg = '%200s has no associated address class' % (type(self),) + raise NotImplementedError(msg) + + @property + def prefixlen(self): + return self._prefixlen + + def address_exclude(self, other): + """Remove an address from a larger block. + + For example: + + addr1 = ip_network('192.0.2.0/28') + addr2 = ip_network('192.0.2.1/32') + list(addr1.address_exclude(addr2)) = + [IPv4Network('192.0.2.0/32'), IPv4Network('192.0.2.2/31'), + IPv4Network('192.0.2.4/30'), IPv4Network('192.0.2.8/29')] + + or IPv6: + + addr1 = ip_network('2001:db8::1/32') + addr2 = ip_network('2001:db8::1/128') + list(addr1.address_exclude(addr2)) = + [ip_network('2001:db8::1/128'), + ip_network('2001:db8::2/127'), + ip_network('2001:db8::4/126'), + ip_network('2001:db8::8/125'), + ... + ip_network('2001:db8:8000::/33')] + + Args: + other: An IPv4Network or IPv6Network object of the same type. + + Returns: + An iterator of the IPv(4|6)Network objects which is self + minus other. + + Raises: + TypeError: If self and other are of differing address + versions, or if other is not a network object. + ValueError: If other is not completely contained by self. + + """ + if not self._version == other._version: + raise TypeError("%s and %s are not of the same version" % ( + self, other)) + + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + + if not other.subnet_of(self): + raise ValueError('%s not contained in %s' % (other, self)) + if other == self: + return + + # Make sure we're comparing the network of other. + other = other.__class__('%s/%s' % (other.network_address, + other.prefixlen)) + + s1, s2 = self.subnets() + while s1 != other and s2 != other: + if other.subnet_of(s1): + yield s2 + s1, s2 = s1.subnets() + elif other.subnet_of(s2): + yield s1 + s1, s2 = s2.subnets() + else: + # If we got here, there's a bug somewhere. 
+ raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + if s1 == other: + yield s2 + elif s2 == other: + yield s1 + else: + # If we got here, there's a bug somewhere. + raise AssertionError('Error performing exclusion: ' + 's1: %s s2: %s other: %s' % + (s1, s2, other)) + + def compare_networks(self, other): + """Compare two IP objects. + + This is only concerned about the comparison of the integer + representation of the network addresses. This means that the + host bits aren't considered at all in this method. If you want + to compare host bits, you can easily enough do a + 'HostA._ip < HostB._ip' + + Args: + other: An IP object. + + Returns: + If the IP versions of self and other are the same, returns: + + -1 if self < other: + eg: IPv4Network('192.0.2.0/25') < IPv4Network('192.0.2.128/25') + IPv6Network('2001:db8::1000/124') < + IPv6Network('2001:db8::2000/124') + 0 if self == other + eg: IPv4Network('192.0.2.0/24') == IPv4Network('192.0.2.0/24') + IPv6Network('2001:db8::1000/124') == + IPv6Network('2001:db8::1000/124') + 1 if self > other + eg: IPv4Network('192.0.2.128/25') > IPv4Network('192.0.2.0/25') + IPv6Network('2001:db8::2000/124') > + IPv6Network('2001:db8::1000/124') + + Raises: + TypeError if the IP versions are different. + + """ + # does this need to raise a ValueError? + if self._version != other._version: + raise TypeError('%s and %s are not of the same type' % ( + self, other)) + # self._version == other._version below here: + if self.network_address < other.network_address: + return -1 + if self.network_address > other.network_address: + return 1 + # self.network_address == other.network_address below here: + if self.netmask < other.netmask: + return -1 + if self.netmask > other.netmask: + return 1 + return 0 + + def _get_networks_key(self): + """Network-only key function. + + Returns an object that identifies this address' network and + netmask. This function is a suitable "key" argument for sorted() + and list.sort(). + + """ + return (self._version, self.network_address, self.netmask) + + def subnets(self, prefixlen_diff=1, new_prefix=None): + """The subnets which join to make the current subnet. + + In the case that self contains only one IP + (self._prefixlen == 32 for IPv4 or self._prefixlen == 128 + for IPv6), yield an iterator with just ourself. + + Args: + prefixlen_diff: An integer, the amount the prefix length + should be increased by. This should not be set if + new_prefix is also set. + new_prefix: The desired new prefix length. This must be a + larger number (smaller prefix) than the existing prefix. + This should not be set if prefixlen_diff is also set. + + Returns: + An iterator of IPv(4|6) objects. + + Raises: + ValueError: The prefixlen_diff is too small or too large. 
+ OR + prefixlen_diff and new_prefix are both set or new_prefix + is a smaller number than the current prefix (smaller + number means a larger network) + + """ + if self._prefixlen == self._max_prefixlen: + yield self + return + + if new_prefix is not None: + if new_prefix < self._prefixlen: + raise ValueError('new prefix must be longer') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = new_prefix - self._prefixlen + + if prefixlen_diff < 0: + raise ValueError('prefix length diff must be > 0') + new_prefixlen = self._prefixlen + prefixlen_diff + + if new_prefixlen > self._max_prefixlen: + raise ValueError( + 'prefix length diff %d is invalid for netblock %s' % ( + new_prefixlen, self)) + + start = int(self.network_address) + end = int(self.broadcast_address) + 1 + step = (int(self.hostmask) + 1) >> prefixlen_diff + for new_addr in _compat_range(start, end, step): + current = self.__class__((new_addr, new_prefixlen)) + yield current + + def supernet(self, prefixlen_diff=1, new_prefix=None): + """The supernet containing the current network. + + Args: + prefixlen_diff: An integer, the amount the prefix length of + the network should be decreased by. For example, given a + /24 network and a prefixlen_diff of 3, a supernet with a + /21 netmask is returned. + + Returns: + An IPv4 network object. + + Raises: + ValueError: If self.prefixlen - prefixlen_diff < 0. I.e., you have + a negative prefix length. + OR + If prefixlen_diff and new_prefix are both set or new_prefix is a + larger number than the current prefix (larger number means a + smaller network) + + """ + if self._prefixlen == 0: + return self + + if new_prefix is not None: + if new_prefix > self._prefixlen: + raise ValueError('new prefix must be shorter') + if prefixlen_diff != 1: + raise ValueError('cannot set prefixlen_diff and new_prefix') + prefixlen_diff = self._prefixlen - new_prefix + + new_prefixlen = self.prefixlen - prefixlen_diff + if new_prefixlen < 0: + raise ValueError( + 'current prefixlen is %d, cannot have a prefixlen_diff of %d' % + (self.prefixlen, prefixlen_diff)) + return self.__class__(( + int(self.network_address) & (int(self.netmask) << prefixlen_diff), + new_prefixlen)) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return (self.network_address.is_multicast and + self.broadcast_address.is_multicast) + + @staticmethod + def _is_subnet_of(a, b): + try: + # Always false if one is v4 and the other is v6. + if a._version != b._version: + raise TypeError("%s and %s are not of the same version" % (a, b)) + return (b.network_address <= a.network_address and + b.broadcast_address >= a.broadcast_address) + except AttributeError: + raise TypeError("Unable to test subnet containment " + "between %s and %s" % (a, b)) + + def subnet_of(self, other): + """Return True if this network is a subnet of other.""" + return self._is_subnet_of(self, other) + + def supernet_of(self, other): + """Return True if this network is a supernet of other.""" + return self._is_subnet_of(other, self) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. 
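+            A network counts as reserved only when both its network address
+            and its broadcast address are reserved.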
+ + """ + return (self.network_address.is_reserved and + self.broadcast_address.is_reserved) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return (self.network_address.is_link_local and + self.broadcast_address.is_link_local) + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return (self.network_address.is_private and + self.broadcast_address.is_private) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry or iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return (self.network_address.is_unspecified and + self.broadcast_address.is_unspecified) + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return (self.network_address.is_loopback and + self.broadcast_address.is_loopback) + + +class _BaseV4(object): + + """Base IPv4 object. + + The following methods are used by IPv4 objects in both single IP + addresses and networks. + + """ + + __slots__ = () + _version = 4 + # Equivalent to 255.255.255.255 or 32 bits of 1's. + _ALL_ONES = (2 ** IPV4LENGTH) - 1 + _DECIMAL_DIGITS = frozenset('0123456789') + + # the valid octets for host and netmasks. only useful for IPv4. + _valid_mask_octets = frozenset([255, 254, 252, 248, 240, 224, 192, 128, 0]) + + _max_prefixlen = IPV4LENGTH + # There are only a handful of valid v4 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + def _explode_shorthand_ip_string(self): + return _compat_str(self) + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + try: + # Check for a netmask in prefix length form + prefixlen = cls._prefix_from_prefix_string(arg) + except NetmaskValueError: + # Check for a netmask or hostmask in dotted-quad form. + # This may raise NetmaskValueError. + prefixlen = cls._prefix_from_ip_string(arg) + netmask = IPv4Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn the given IP string into an integer for comparison. + + Args: + ip_str: A string, the IP ip_str. + + Returns: + The IP ip_str as an integer. + + Raises: + AddressValueError: if ip_str isn't a valid IPv4 Address. 
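+        For example, '192.0.2.1' is converted to the integer 3221225985.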
+ + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + octets = ip_str.split('.') + if len(octets) != 4: + raise AddressValueError("Expected 4 octets in %r" % ip_str) + + try: + return _compat_int_from_byte_vals( + map(cls._parse_octet, octets), 'big') + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_octet(cls, octet_str): + """Convert a decimal octet into an integer. + + Args: + octet_str: A string, the number to parse. + + Returns: + The octet as an integer. + + Raises: + ValueError: if the octet isn't strictly a decimal from [0..255]. + + """ + if not octet_str: + raise ValueError("Empty octet not permitted") + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._DECIMAL_DIGITS.issuperset(octet_str): + msg = "Only decimal digits permitted in %r" + raise ValueError(msg % octet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(octet_str) > 3: + msg = "At most 3 characters permitted in %r" + raise ValueError(msg % octet_str) + # Convert to integer (we know digits are legal) + octet_int = int(octet_str, 10) + # Any octets that look like they *might* be written in octal, + # and which don't look exactly the same in both octal and + # decimal are rejected as ambiguous + if octet_int > 7 and octet_str[0] == '0': + msg = "Ambiguous (octal/decimal) value in %r not permitted" + raise ValueError(msg % octet_str) + if octet_int > 255: + raise ValueError("Octet %d (> 255) not permitted" % octet_int) + return octet_int + + @classmethod + def _string_from_ip_int(cls, ip_int): + """Turns a 32-bit integer into dotted decimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + The IP address as a string in dotted decimal notation. + + """ + return '.'.join(_compat_str(struct.unpack(b'!B', b)[0] + if isinstance(b, bytes) + else b) + for b in _compat_to_bytes(ip_int, 4, 'big')) + + def _is_hostmask(self, ip_str): + """Test if the IP string is a hostmask (rather than a netmask). + + Args: + ip_str: A string, the potential hostmask. + + Returns: + A boolean, True if the IP string is a hostmask. + + """ + bits = ip_str.split('.') + try: + parts = [x for x in map(int, bits) if x in self._valid_mask_octets] + except ValueError: + return False + if len(parts) != len(bits): + return False + if parts[0] < parts[-1]: + return True + return False + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv4 address. + + This implements the method described in RFC1035 3.5. + + """ + reverse_octets = _compat_str(self).split('.')[::-1] + return '.'.join(reverse_octets) + '.in-addr.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv4Address(_BaseV4, _BaseAddress): + + """Represent and manipulate single IPv4 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + + """ + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv4Address('192.0.2.1') == IPv4Address(3221225985). + or, more generally + IPv4Address(int(IPv4Address('192.0.2.1'))) == + IPv4Address('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + + """ + # Efficient constructor from integer. 
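+        # (The constructor also accepts a packed 4-byte bytes object or a text
+        # address such as '192.0.2.1'; those forms are handled further below.)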
+ if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 4) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v4_int_to_packed(self._ip) + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within the + reserved IPv4 Network range. + + """ + return self in self._constants._reserved_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv4-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + return ( + self not in self._constants._public_network and + not self.is_private) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is multicast. + See RFC 3171 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 5735 3. + + """ + return self == self._constants._unspecified_address + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback per RFC 3330. + + """ + return self in self._constants._loopback_network + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is link-local per RFC 3927. 
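+            For example, IPv4Address('169.254.1.1').is_link_local is True.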
+ + """ + return self in self._constants._linklocal_network + + +class IPv4Interface(IPv4Address): + + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv4Address.__init__(self, address) + self.network = IPv4Network(self._ip) + self._prefixlen = self._max_prefixlen + return + + if isinstance(address, tuple): + IPv4Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + + self.network = IPv4Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv4Address.__init__(self, addr[0]) + + self.network = IPv4Network(address, strict=False) + self._prefixlen = self.network._prefixlen + + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv4Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. + return False + + def __lt__(self, other): + address_less = IPv4Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv4Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + +class IPv4Network(_BaseV4, _BaseNetwork): + + """This class represents and manipulates 32-bit IPv4 network + addresses.. + + Attributes: [examples for IPv4Network('192.0.2.0/27')] + .network_address: IPv4Address('192.0.2.0') + .hostmask: IPv4Address('0.0.0.31') + .broadcast_address: IPv4Address('192.0.2.32') + .netmask: IPv4Address('255.255.255.224') + .prefixlen: 27 + + """ + # Class to use when creating address objects + _address_class = IPv4Address + + def __init__(self, address, strict=True): + + """Instantiate a new IPv4 network object. + + Args: + address: A string or integer representing the IP [& network]. + '192.0.2.0/24' + '192.0.2.0/255.255.255.0' + '192.0.0.2/0.0.0.255' + are all functionally the same in IPv4. Similarly, + '192.0.2.1' + '192.0.2.1/255.255.255.255' + '192.0.2.1/32' + are also functionally equivalent. That is to say, failing to + provide a subnetmask will create an object with a mask of /32. + + If the mask (portion after the / in the argument) is given in + dotted quad form, it is treated as a netmask if it starts with a + non-zero field (e.g. /255.0.0.0 == /8) and as a hostmask if it + starts with a zero field (e.g. 
0.255.255.255 == /8), with the + single exception of an all-zero mask which is treated as a + netmask == /0. If no mask is given, a default of /32 is used. + + Additionally, an integer can be passed, so + IPv4Network('192.0.2.1') == IPv4Network(3221225985) + or, more generally + IPv4Interface(int(IPv4Interface('192.0.2.1'))) == + IPv4Interface('192.0.2.1') + + Raises: + AddressValueError: If ipaddress isn't a valid IPv4 address. + NetmaskValueError: If the netmask isn't valid for + an IPv4 address. + ValueError: If strict is True and a network address is not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Constructing from a packed address or integer + if isinstance(address, (_compat_int_types, bytes)): + self.network_address = IPv4Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen) + # fixme: address/network test here. + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + # We weren't given an address[1] + arg = self._max_prefixlen + self.network_address = IPv4Address(address[0]) + self.netmask, self._prefixlen = self._make_netmask(arg) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv4Address(packed & + int(self.netmask)) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. + addr = _split_optional_netmask(address) + self.network_address = IPv4Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if (IPv4Address(int(self.network_address) & int(self.netmask)) != + self.network_address): + raise ValueError('%s has host bits set' % self) + self.network_address = IPv4Address(int(self.network_address) & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, True if the address is not reserved per + iana-ipv4-special-registry. + + """ + return (not (self.network_address in IPv4Network('100.64.0.0/10') and + self.broadcast_address in IPv4Network('100.64.0.0/10')) and + not self.is_private) + + +class _IPv4Constants(object): + + _linklocal_network = IPv4Network('169.254.0.0/16') + + _loopback_network = IPv4Network('127.0.0.0/8') + + _multicast_network = IPv4Network('224.0.0.0/4') + + _public_network = IPv4Network('100.64.0.0/10') + + _private_networks = [ + IPv4Network('0.0.0.0/8'), + IPv4Network('10.0.0.0/8'), + IPv4Network('127.0.0.0/8'), + IPv4Network('169.254.0.0/16'), + IPv4Network('172.16.0.0/12'), + IPv4Network('192.0.0.0/29'), + IPv4Network('192.0.0.170/31'), + IPv4Network('192.0.2.0/24'), + IPv4Network('192.168.0.0/16'), + IPv4Network('198.18.0.0/15'), + IPv4Network('198.51.100.0/24'), + IPv4Network('203.0.113.0/24'), + IPv4Network('240.0.0.0/4'), + IPv4Network('255.255.255.255/32'), + ] + + _reserved_network = IPv4Network('240.0.0.0/4') + + _unspecified_address = IPv4Address('0.0.0.0') + + +IPv4Address._constants = _IPv4Constants + + +class _BaseV6(object): + + """Base IPv6 object. + + The following methods are used by IPv6 objects in both single IP + addresses and networks. 
+ + """ + + __slots__ = () + _version = 6 + _ALL_ONES = (2 ** IPV6LENGTH) - 1 + _HEXTET_COUNT = 8 + _HEX_DIGITS = frozenset('0123456789ABCDEFabcdef') + _max_prefixlen = IPV6LENGTH + + # There are only a bunch of valid v6 netmasks, so we cache them all + # when constructed (see _make_netmask()). + _netmask_cache = {} + + @classmethod + def _make_netmask(cls, arg): + """Make a (netmask, prefix_len) tuple from the given argument. + + Argument can be: + - an integer (the prefix length) + - a string representing the prefix length (e.g. "24") + - a string representing the prefix netmask (e.g. "255.255.255.0") + """ + if arg not in cls._netmask_cache: + if isinstance(arg, _compat_int_types): + prefixlen = arg + else: + prefixlen = cls._prefix_from_prefix_string(arg) + netmask = IPv6Address(cls._ip_int_from_prefix(prefixlen)) + cls._netmask_cache[arg] = netmask, prefixlen + return cls._netmask_cache[arg] + + @classmethod + def _ip_int_from_string(cls, ip_str): + """Turn an IPv6 ip_str into an integer. + + Args: + ip_str: A string, the IPv6 ip_str. + + Returns: + An int, the IPv6 address + + Raises: + AddressValueError: if ip_str isn't a valid IPv6 Address. + + """ + if not ip_str: + raise AddressValueError('Address cannot be empty') + + parts = ip_str.split(':') + + # An IPv6 address needs at least 2 colons (3 parts). + _min_parts = 3 + if len(parts) < _min_parts: + msg = "At least %d parts expected in %r" % (_min_parts, ip_str) + raise AddressValueError(msg) + + # If the address has an IPv4-style suffix, convert it to hexadecimal. + if '.' in parts[-1]: + try: + ipv4_int = IPv4Address(parts.pop())._ip + except AddressValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + parts.append('%x' % ((ipv4_int >> 16) & 0xFFFF)) + parts.append('%x' % (ipv4_int & 0xFFFF)) + + # An IPv6 address can't have more than 8 colons (9 parts). + # The extra colon comes from using the "::" notation for a single + # leading or trailing zero part. + _max_parts = cls._HEXTET_COUNT + 1 + if len(parts) > _max_parts: + msg = "At most %d colons permitted in %r" % ( + _max_parts - 1, ip_str) + raise AddressValueError(msg) + + # Disregarding the endpoints, find '::' with nothing in between. + # This indicates that a run of zeroes has been skipped. + skip_index = None + for i in _compat_range(1, len(parts) - 1): + if not parts[i]: + if skip_index is not None: + # Can't have more than one '::' + msg = "At most one '::' permitted in %r" % ip_str + raise AddressValueError(msg) + skip_index = i + + # parts_hi is the number of parts to copy from above/before the '::' + # parts_lo is the number of parts to copy from below/after the '::' + if skip_index is not None: + # If we found a '::', then check if it also covers the endpoints. + parts_hi = skip_index + parts_lo = len(parts) - skip_index - 1 + if not parts[0]: + parts_hi -= 1 + if parts_hi: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + parts_lo -= 1 + if parts_lo: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_skipped = cls._HEXTET_COUNT - (parts_hi + parts_lo) + if parts_skipped < 1: + msg = "Expected at most %d other parts with '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT - 1, ip_str)) + else: + # Otherwise, allocate the entire address to parts_hi. The + # endpoints could still be empty, but _parse_hextet() will check + # for that. 
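+            # Without '::' there must be exactly eight hextets, and neither the
+            # first nor the last one may be empty; the checks below enforce that.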
+ if len(parts) != cls._HEXTET_COUNT: + msg = "Exactly %d parts expected without '::' in %r" + raise AddressValueError(msg % (cls._HEXTET_COUNT, ip_str)) + if not parts[0]: + msg = "Leading ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # ^: requires ^:: + if not parts[-1]: + msg = "Trailing ':' only permitted as part of '::' in %r" + raise AddressValueError(msg % ip_str) # :$ requires ::$ + parts_hi = len(parts) + parts_lo = 0 + parts_skipped = 0 + + try: + # Now, parse the hextets into a 128-bit integer. + ip_int = 0 + for i in range(parts_hi): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + ip_int <<= 16 * parts_skipped + for i in range(-parts_lo, 0): + ip_int <<= 16 + ip_int |= cls._parse_hextet(parts[i]) + return ip_int + except ValueError as exc: + raise AddressValueError("%s in %r" % (exc, ip_str)) + + @classmethod + def _parse_hextet(cls, hextet_str): + """Convert an IPv6 hextet string into an integer. + + Args: + hextet_str: A string, the number to parse. + + Returns: + The hextet as an integer. + + Raises: + ValueError: if the input isn't strictly a hex number from + [0..FFFF]. + + """ + # Whitelist the characters, since int() allows a lot of bizarre stuff. + if not cls._HEX_DIGITS.issuperset(hextet_str): + raise ValueError("Only hex digits permitted in %r" % hextet_str) + # We do the length check second, since the invalid character error + # is likely to be more informative for the user + if len(hextet_str) > 4: + msg = "At most 4 characters permitted in %r" + raise ValueError(msg % hextet_str) + # Length check means we can skip checking the integer value + return int(hextet_str, 16) + + @classmethod + def _compress_hextets(cls, hextets): + """Compresses a list of hextets. + + Compresses a list of strings, replacing the longest continuous + sequence of "0" in the list with "" and adding empty strings at + the beginning or at the end of the string such that subsequently + calling ":".join(hextets) will produce the compressed version of + the IPv6 address. + + Args: + hextets: A list of strings, the hextets to compress. + + Returns: + A list of strings. + + """ + best_doublecolon_start = -1 + best_doublecolon_len = 0 + doublecolon_start = -1 + doublecolon_len = 0 + for index, hextet in enumerate(hextets): + if hextet == '0': + doublecolon_len += 1 + if doublecolon_start == -1: + # Start of a sequence of zeros. + doublecolon_start = index + if doublecolon_len > best_doublecolon_len: + # This is the longest sequence of zeros so far. + best_doublecolon_len = doublecolon_len + best_doublecolon_start = doublecolon_start + else: + doublecolon_len = 0 + doublecolon_start = -1 + + if best_doublecolon_len > 1: + best_doublecolon_end = (best_doublecolon_start + + best_doublecolon_len) + # For zeros at the end of the address. + if best_doublecolon_end == len(hextets): + hextets += [''] + hextets[best_doublecolon_start:best_doublecolon_end] = [''] + # For zeros at the beginning of the address. + if best_doublecolon_start == 0: + hextets = [''] + hextets + + return hextets + + @classmethod + def _string_from_ip_int(cls, ip_int=None): + """Turns a 128-bit integer into hexadecimal notation. + + Args: + ip_int: An integer, the IP address. + + Returns: + A string, the hexadecimal representation of the address. + + Raises: + ValueError: The address is bigger than 128 bits of all ones. 
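+        For example, an ip_int of 1 is rendered as '::1'.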
+ + """ + if ip_int is None: + ip_int = int(cls._ip) + + if ip_int > cls._ALL_ONES: + raise ValueError('IPv6 address is too large') + + hex_str = '%032x' % ip_int + hextets = ['%x' % int(hex_str[x:x + 4], 16) for x in range(0, 32, 4)] + + hextets = cls._compress_hextets(hextets) + return ':'.join(hextets) + + def _explode_shorthand_ip_string(self): + """Expand a shortened IPv6 address. + + Args: + ip_str: A string, the IPv6 address. + + Returns: + A string, the expanded IPv6 address. + + """ + if isinstance(self, IPv6Network): + ip_str = _compat_str(self.network_address) + elif isinstance(self, IPv6Interface): + ip_str = _compat_str(self.ip) + else: + ip_str = _compat_str(self) + + ip_int = self._ip_int_from_string(ip_str) + hex_str = '%032x' % ip_int + parts = [hex_str[x:x + 4] for x in range(0, 32, 4)] + if isinstance(self, (_BaseNetwork, IPv6Interface)): + return '%s/%d' % (':'.join(parts), self._prefixlen) + return ':'.join(parts) + + def _reverse_pointer(self): + """Return the reverse DNS pointer name for the IPv6 address. + + This implements the method described in RFC3596 2.5. + + """ + reverse_chars = self.exploded[::-1].replace(':', '') + return '.'.join(reverse_chars) + '.ip6.arpa' + + @property + def max_prefixlen(self): + return self._max_prefixlen + + @property + def version(self): + return self._version + + +class IPv6Address(_BaseV6, _BaseAddress): + + """Represent and manipulate single IPv6 Addresses.""" + + __slots__ = ('_ip', '__weakref__') + + def __init__(self, address): + """Instantiate a new IPv6 address object. + + Args: + address: A string or integer representing the IP + + Additionally, an integer can be passed, so + IPv6Address('2001:db8::') == + IPv6Address(42540766411282592856903984951653826560) + or, more generally + IPv6Address(int(IPv6Address('2001:db8::'))) == + IPv6Address('2001:db8::') + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + + """ + # Efficient constructor from integer. + if isinstance(address, _compat_int_types): + self._check_int_address(address) + self._ip = address + return + + # Constructing from a packed address + if isinstance(address, bytes): + self._check_packed_address(address, 16) + bvs = _compat_bytes_to_byte_vals(address) + self._ip = _compat_int_from_byte_vals(bvs, 'big') + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP string. + addr_str = _compat_str(address) + if '/' in addr_str: + raise AddressValueError("Unexpected '/' in %r" % address) + self._ip = self._ip_int_from_string(addr_str) + + @property + def packed(self): + """The binary representation of this address.""" + return v6_int_to_packed(self._ip) + + @property + def is_multicast(self): + """Test if the address is reserved for multicast use. + + Returns: + A boolean, True if the address is a multicast address. + See RFC 2373 2.7 for details. + + """ + return self in self._constants._multicast_network + + @property + def is_reserved(self): + """Test if the address is otherwise IETF reserved. + + Returns: + A boolean, True if the address is within one of the + reserved IPv6 Network ranges. + + """ + return any(self in x for x in self._constants._reserved_networks) + + @property + def is_link_local(self): + """Test if the address is reserved for link-local. + + Returns: + A boolean, True if the address is reserved per RFC 4291. + + """ + return self in self._constants._linklocal_network + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. 
+ + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return self in self._constants._sitelocal_network + + @property + def is_private(self): + """Test if this address is allocated for private networks. + + Returns: + A boolean, True if the address is reserved per + iana-ipv6-special-registry. + + """ + return any(self in net for net in self._constants._private_networks) + + @property + def is_global(self): + """Test if this address is allocated for public networks. + + Returns: + A boolean, true if the address is not reserved per + iana-ipv6-special-registry. + + """ + return not self.is_private + + @property + def is_unspecified(self): + """Test if the address is unspecified. + + Returns: + A boolean, True if this is the unspecified address as defined in + RFC 2373 2.5.2. + + """ + return self._ip == 0 + + @property + def is_loopback(self): + """Test if the address is a loopback address. + + Returns: + A boolean, True if the address is a loopback address as defined in + RFC 2373 2.5.3. + + """ + return self._ip == 1 + + @property + def ipv4_mapped(self): + """Return the IPv4 mapped address. + + Returns: + If the IPv6 address is a v4 mapped address, return the + IPv4 mapped address. Return None otherwise. + + """ + if (self._ip >> 32) != 0xFFFF: + return None + return IPv4Address(self._ip & 0xFFFFFFFF) + + @property + def teredo(self): + """Tuple of embedded teredo IPs. + + Returns: + Tuple of the (server, client) IPs or None if the address + doesn't appear to be a teredo address (doesn't start with + 2001::/32) + + """ + if (self._ip >> 96) != 0x20010000: + return None + return (IPv4Address((self._ip >> 64) & 0xFFFFFFFF), + IPv4Address(~self._ip & 0xFFFFFFFF)) + + @property + def sixtofour(self): + """Return the IPv4 6to4 embedded address. + + Returns: + The IPv4 6to4-embedded address if present or None if the + address doesn't appear to contain a 6to4 embedded address. + + """ + if (self._ip >> 112) != 0x2002: + return None + return IPv4Address((self._ip >> 80) & 0xFFFFFFFF) + + +class IPv6Interface(IPv6Address): + + def __init__(self, address): + if isinstance(address, (bytes, _compat_int_types)): + IPv6Address.__init__(self, address) + self.network = IPv6Network(self._ip) + self._prefixlen = self._max_prefixlen + return + if isinstance(address, tuple): + IPv6Address.__init__(self, address[0]) + if len(address) > 1: + self._prefixlen = int(address[1]) + else: + self._prefixlen = self._max_prefixlen + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self.hostmask = self.network.hostmask + return + + addr = _split_optional_netmask(address) + IPv6Address.__init__(self, addr[0]) + self.network = IPv6Network(address, strict=False) + self.netmask = self.network.netmask + self._prefixlen = self.network._prefixlen + self.hostmask = self.network.hostmask + + def __str__(self): + return '%s/%d' % (self._string_from_ip_int(self._ip), + self.network.prefixlen) + + def __eq__(self, other): + address_equal = IPv6Address.__eq__(self, other) + if not address_equal or address_equal is NotImplemented: + return address_equal + try: + return self.network == other.network + except AttributeError: + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. 
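+                # For example, IPv6Interface('::1') == IPv6Address('::1')
+                # evaluates to False here.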
+ return False + + def __lt__(self, other): + address_less = IPv6Address.__lt__(self, other) + if address_less is NotImplemented: + return NotImplemented + try: + return (self.network < other.network or + self.network == other.network and address_less) + except AttributeError: + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. + return False + + def __hash__(self): + return self._ip ^ self._prefixlen ^ int(self.network.network_address) + + __reduce__ = _IPAddressBase.__reduce__ + + @property + def ip(self): + return IPv6Address(self._ip) + + @property + def with_prefixlen(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self._prefixlen) + + @property + def with_netmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.netmask) + + @property + def with_hostmask(self): + return '%s/%s' % (self._string_from_ip_int(self._ip), + self.hostmask) + + @property + def is_unspecified(self): + return self._ip == 0 and self.network.is_unspecified + + @property + def is_loopback(self): + return self._ip == 1 and self.network.is_loopback + + +class IPv6Network(_BaseV6, _BaseNetwork): + + """This class represents and manipulates 128-bit IPv6 networks. + + Attributes: [examples for IPv6('2001:db8::1000/124')] + .network_address: IPv6Address('2001:db8::1000') + .hostmask: IPv6Address('::f') + .broadcast_address: IPv6Address('2001:db8::100f') + .netmask: IPv6Address('ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0') + .prefixlen: 124 + + """ + + # Class to use when creating address objects + _address_class = IPv6Address + + def __init__(self, address, strict=True): + """Instantiate a new IPv6 Network object. + + Args: + address: A string or integer representing the IPv6 network or the + IP and prefix/netmask. + '2001:db8::/128' + '2001:db8:0000:0000:0000:0000:0000:0000/128' + '2001:db8::' + are all functionally the same in IPv6. That is to say, + failing to provide a subnetmask will create an object with + a mask of /128. + + Additionally, an integer can be passed, so + IPv6Network('2001:db8::') == + IPv6Network(42540766411282592856903984951653826560) + or, more generally + IPv6Network(int(IPv6Network('2001:db8::'))) == + IPv6Network('2001:db8::') + + strict: A boolean. If true, ensure that we have been passed + A true network address, eg, 2001:db8::1000/124 and not an + IP address on a network, eg, 2001:db8::1/124. + + Raises: + AddressValueError: If address isn't a valid IPv6 address. + NetmaskValueError: If the netmask isn't valid for + an IPv6 address. + ValueError: If strict was True and a network address was not + supplied. + + """ + _BaseNetwork.__init__(self, address) + + # Efficient constructor from integer or packed address + if isinstance(address, (bytes, _compat_int_types)): + self.network_address = IPv6Address(address) + self.netmask, self._prefixlen = self._make_netmask( + self._max_prefixlen) + return + + if isinstance(address, tuple): + if len(address) > 1: + arg = address[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + self.network_address = IPv6Address(address[0]) + packed = int(self.network_address) + if packed & int(self.netmask) != packed: + if strict: + raise ValueError('%s has host bits set' % self) + else: + self.network_address = IPv6Address(packed & + int(self.netmask)) + return + + # Assume input argument to be string or any object representation + # which converts into a formatted IP prefix string. 
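+        # For example '2001:db8::/32', or just '2001:db8::' which gets the
+        # default /128 prefix below.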
+ addr = _split_optional_netmask(address) + + self.network_address = IPv6Address(self._ip_int_from_string(addr[0])) + + if len(addr) == 2: + arg = addr[1] + else: + arg = self._max_prefixlen + self.netmask, self._prefixlen = self._make_netmask(arg) + + if strict: + if (IPv6Address(int(self.network_address) & int(self.netmask)) != + self.network_address): + raise ValueError('%s has host bits set' % self) + self.network_address = IPv6Address(int(self.network_address) & + int(self.netmask)) + + if self._prefixlen == (self._max_prefixlen - 1): + self.hosts = self.__iter__ + + def hosts(self): + """Generate Iterator over usable hosts in a network. + + This is like __iter__ except it doesn't return the + Subnet-Router anycast address. + + """ + network = int(self.network_address) + broadcast = int(self.broadcast_address) + for x in _compat_range(network + 1, broadcast + 1): + yield self._address_class(x) + + @property + def is_site_local(self): + """Test if the address is reserved for site-local. + + Note that the site-local address space has been deprecated by RFC 3879. + Use is_private to test if this address is in the space of unique local + addresses as defined by RFC 4193. + + Returns: + A boolean, True if the address is reserved per RFC 3513 2.5.6. + + """ + return (self.network_address.is_site_local and + self.broadcast_address.is_site_local) + + +class _IPv6Constants(object): + + _linklocal_network = IPv6Network('fe80::/10') + + _multicast_network = IPv6Network('ff00::/8') + + _private_networks = [ + IPv6Network('::1/128'), + IPv6Network('::/128'), + IPv6Network('::ffff:0:0/96'), + IPv6Network('100::/64'), + IPv6Network('2001::/23'), + IPv6Network('2001:2::/48'), + IPv6Network('2001:db8::/32'), + IPv6Network('2001:10::/28'), + IPv6Network('fc00::/7'), + IPv6Network('fe80::/10'), + ] + + _reserved_networks = [ + IPv6Network('::/8'), IPv6Network('100::/8'), + IPv6Network('200::/7'), IPv6Network('400::/6'), + IPv6Network('800::/5'), IPv6Network('1000::/4'), + IPv6Network('4000::/3'), IPv6Network('6000::/3'), + IPv6Network('8000::/3'), IPv6Network('A000::/3'), + IPv6Network('C000::/3'), IPv6Network('E000::/4'), + IPv6Network('F000::/5'), IPv6Network('F800::/6'), + IPv6Network('FE00::/9'), + ] + + _sitelocal_network = IPv6Network('fec0::/10') + + +IPv6Address._constants = _IPv6Constants diff --git a/test/support/integration/plugins/module_utils/crypto.py b/test/support/integration/plugins/module_utils/crypto.py new file mode 100644 index 00000000..e67eeff1 --- /dev/null +++ b/test/support/integration/plugins/module_utils/crypto.py @@ -0,0 +1,2125 @@ +# -*- coding: utf-8 -*- +# +# (c) 2016, Yanis Guenane <yanis+ansible@guenane.org> +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. 
+# +# ---------------------------------------------------------------------- +# A clearly marked portion of this file is licensed under the BSD license +# Copyright (c) 2015, 2016 Paul Kehrer (@reaperhulk) +# Copyright (c) 2017 Fraser Tweedale (@frasertweedale) +# For more details, search for the function _obj2txt(). +# --------------------------------------------------------------------- +# A clearly marked portion of this file is extracted from a project that +# is licensed under the Apache License 2.0 +# Copyright (c) the OpenSSL contributors +# For more details, search for the function _OID_MAP. + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +import sys +from distutils.version import LooseVersion + +try: + import OpenSSL + from OpenSSL import crypto +except ImportError: + # An error will be raised in the calling class to let the end + # user know that OpenSSL couldn't be found. + pass + +try: + import cryptography + from cryptography import x509 + from cryptography.hazmat.backends import default_backend as cryptography_backend + from cryptography.hazmat.primitives.serialization import load_pem_private_key + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives import serialization + import ipaddress + + # Older versions of cryptography (< 2.1) do not have __hash__ functions for + # general name objects (DNSName, IPAddress, ...), while providing overloaded + # equality and string representation operations. This makes it impossible to + # use them in hash-based data structures such as set or dict. Since we are + # actually doing that in openssl_certificate, and potentially in other code, + # we need to monkey-patch __hash__ for these classes to make sure our code + # works fine. + if LooseVersion(cryptography.__version__) < LooseVersion('2.1'): + # A very simply hash function which relies on the representation + # of an object to be implemented. 
This is the case since at least + # cryptography 1.0, see + # https://github.com/pyca/cryptography/commit/7a9abce4bff36c05d26d8d2680303a6f64a0e84f + def simple_hash(self): + return hash(repr(self)) + + # The hash functions for the following types were added for cryptography 2.1: + # https://github.com/pyca/cryptography/commit/fbfc36da2a4769045f2373b004ddf0aff906cf38 + x509.DNSName.__hash__ = simple_hash + x509.DirectoryName.__hash__ = simple_hash + x509.GeneralName.__hash__ = simple_hash + x509.IPAddress.__hash__ = simple_hash + x509.OtherName.__hash__ = simple_hash + x509.RegisteredID.__hash__ = simple_hash + + if LooseVersion(cryptography.__version__) < LooseVersion('1.2'): + # The hash functions for the following types were added for cryptography 1.2: + # https://github.com/pyca/cryptography/commit/b642deed88a8696e5f01ce6855ccf89985fc35d0 + # https://github.com/pyca/cryptography/commit/d1b5681f6db2bde7a14625538bd7907b08dfb486 + x509.RFC822Name.__hash__ = simple_hash + x509.UniformResourceIdentifier.__hash__ = simple_hash + + # Test whether we have support for X25519, X448, Ed25519 and/or Ed448 + try: + import cryptography.hazmat.primitives.asymmetric.x25519 + CRYPTOGRAPHY_HAS_X25519 = True + try: + cryptography.hazmat.primitives.asymmetric.x25519.X25519PrivateKey.private_bytes + CRYPTOGRAPHY_HAS_X25519_FULL = True + except AttributeError: + CRYPTOGRAPHY_HAS_X25519_FULL = False + except ImportError: + CRYPTOGRAPHY_HAS_X25519 = False + CRYPTOGRAPHY_HAS_X25519_FULL = False + try: + import cryptography.hazmat.primitives.asymmetric.x448 + CRYPTOGRAPHY_HAS_X448 = True + except ImportError: + CRYPTOGRAPHY_HAS_X448 = False + try: + import cryptography.hazmat.primitives.asymmetric.ed25519 + CRYPTOGRAPHY_HAS_ED25519 = True + except ImportError: + CRYPTOGRAPHY_HAS_ED25519 = False + try: + import cryptography.hazmat.primitives.asymmetric.ed448 + CRYPTOGRAPHY_HAS_ED448 = True + except ImportError: + CRYPTOGRAPHY_HAS_ED448 = False + + HAS_CRYPTOGRAPHY = True +except ImportError: + # Error handled in the calling module. + CRYPTOGRAPHY_HAS_X25519 = False + CRYPTOGRAPHY_HAS_X25519_FULL = False + CRYPTOGRAPHY_HAS_X448 = False + CRYPTOGRAPHY_HAS_ED25519 = False + CRYPTOGRAPHY_HAS_ED448 = False + HAS_CRYPTOGRAPHY = False + + +import abc +import base64 +import binascii +import datetime +import errno +import hashlib +import os +import re +import tempfile + +from ansible.module_utils import six +from ansible.module_utils._text import to_native, to_bytes, to_text + + +class OpenSSLObjectError(Exception): + pass + + +class OpenSSLBadPassphraseError(OpenSSLObjectError): + pass + + +def get_fingerprint_of_bytes(source): + """Generate the fingerprint of the given bytes.""" + + fingerprint = {} + + try: + algorithms = hashlib.algorithms + except AttributeError: + try: + algorithms = hashlib.algorithms_guaranteed + except AttributeError: + return None + + for algo in algorithms: + f = getattr(hashlib, algo) + try: + h = f(source) + except ValueError: + # This can happen for hash algorithms not supported in FIPS mode + # (https://github.com/ansible/ansible/issues/67213) + continue + try: + # Certain hash functions have a hexdigest() which expects a length parameter + pubkey_digest = h.hexdigest() + except TypeError: + pubkey_digest = h.hexdigest(32) + fingerprint[algo] = ':'.join(pubkey_digest[i:i + 2] for i in range(0, len(pubkey_digest), 2)) + + return fingerprint + + +def get_fingerprint(path, passphrase=None, content=None, backend='pyopenssl'): + """Generate the fingerprint of the public key. 
""" + + privatekey = load_privatekey(path, passphrase=passphrase, content=content, check_passphrase=False, backend=backend) + + if backend == 'pyopenssl': + try: + publickey = crypto.dump_publickey(crypto.FILETYPE_ASN1, privatekey) + except AttributeError: + # If PyOpenSSL < 16.0 crypto.dump_publickey() will fail. + try: + bio = crypto._new_mem_buf() + rc = crypto._lib.i2d_PUBKEY_bio(bio, privatekey._pkey) + if rc != 1: + crypto._raise_current_error() + publickey = crypto._bio_to_string(bio) + except AttributeError: + # By doing this we prevent the code from raising an error + # yet we return no value in the fingerprint hash. + return None + elif backend == 'cryptography': + publickey = privatekey.public_key().public_bytes( + serialization.Encoding.DER, + serialization.PublicFormat.SubjectPublicKeyInfo + ) + + return get_fingerprint_of_bytes(publickey) + + +def load_file_if_exists(path, module=None, ignore_errors=False): + try: + with open(path, 'rb') as f: + return f.read() + except EnvironmentError as exc: + if exc.errno == errno.ENOENT: + return None + if ignore_errors: + return None + if module is None: + raise + module.fail_json('Error while loading {0} - {1}'.format(path, str(exc))) + except Exception as exc: + if ignore_errors: + return None + if module is None: + raise + module.fail_json('Error while loading {0} - {1}'.format(path, str(exc))) + + +def load_privatekey(path, passphrase=None, check_passphrase=True, content=None, backend='pyopenssl'): + """Load the specified OpenSSL private key. + + The content can also be specified via content; in that case, + this function will not load the key from disk. + """ + + try: + if content is None: + with open(path, 'rb') as b_priv_key_fh: + priv_key_detail = b_priv_key_fh.read() + else: + priv_key_detail = content + + if backend == 'pyopenssl': + + # First try: try to load with real passphrase (resp. empty string) + # Will work if this is the correct passphrase, or the key is not + # password-protected. + try: + result = crypto.load_privatekey(crypto.FILETYPE_PEM, + priv_key_detail, + to_bytes(passphrase or '')) + except crypto.Error as e: + if len(e.args) > 0 and len(e.args[0]) > 0: + if e.args[0][0][2] in ('bad decrypt', 'bad password read'): + # This happens in case we have the wrong passphrase. + if passphrase is not None: + raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key!') + else: + raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!') + raise OpenSSLObjectError('Error while deserializing key: {0}'.format(e)) + if check_passphrase: + # Next we want to make sure that the key is actually protected by + # a passphrase (in case we did try the empty string before, make + # sure that the key is not protected by the empty string) + try: + crypto.load_privatekey(crypto.FILETYPE_PEM, + priv_key_detail, + to_bytes('y' if passphrase == 'x' else 'x')) + if passphrase is not None: + # Since we can load the key without an exception, the + # key isn't password-protected + raise OpenSSLBadPassphraseError('Passphrase provided, but private key is not password-protected!') + except crypto.Error as e: + if passphrase is None and len(e.args) > 0 and len(e.args[0]) > 0: + if e.args[0][0][2] in ('bad decrypt', 'bad password read'): + # The key is obviously protected by the empty string. + # Don't do this at home (if it's possible at all)... 
+ raise OpenSSLBadPassphraseError('No passphrase provided, but private key is password-protected!') + elif backend == 'cryptography': + try: + result = load_pem_private_key(priv_key_detail, + None if passphrase is None else to_bytes(passphrase), + cryptography_backend()) + except TypeError as dummy: + raise OpenSSLBadPassphraseError('Wrong or empty passphrase provided for private key') + except ValueError as dummy: + raise OpenSSLBadPassphraseError('Wrong passphrase provided for private key') + + return result + except (IOError, OSError) as exc: + raise OpenSSLObjectError(exc) + + +def load_certificate(path, content=None, backend='pyopenssl'): + """Load the specified certificate.""" + + try: + if content is None: + with open(path, 'rb') as cert_fh: + cert_content = cert_fh.read() + else: + cert_content = content + if backend == 'pyopenssl': + return crypto.load_certificate(crypto.FILETYPE_PEM, cert_content) + elif backend == 'cryptography': + return x509.load_pem_x509_certificate(cert_content, cryptography_backend()) + except (IOError, OSError) as exc: + raise OpenSSLObjectError(exc) + + +def load_certificate_request(path, content=None, backend='pyopenssl'): + """Load the specified certificate signing request.""" + try: + if content is None: + with open(path, 'rb') as csr_fh: + csr_content = csr_fh.read() + else: + csr_content = content + except (IOError, OSError) as exc: + raise OpenSSLObjectError(exc) + if backend == 'pyopenssl': + return crypto.load_certificate_request(crypto.FILETYPE_PEM, csr_content) + elif backend == 'cryptography': + return x509.load_pem_x509_csr(csr_content, cryptography_backend()) + + +def parse_name_field(input_dict): + """Take a dict with key: value or key: list_of_values mappings and return a list of tuples""" + + result = [] + for key in input_dict: + if isinstance(input_dict[key], list): + for entry in input_dict[key]: + result.append((key, entry)) + else: + result.append((key, input_dict[key])) + return result + + +def convert_relative_to_datetime(relative_time_string): + """Get a datetime.datetime or None from a string in the time format described in sshd_config(5)""" + + parsed_result = re.match( + r"^(?P<prefix>[+-])((?P<weeks>\d+)[wW])?((?P<days>\d+)[dD])?((?P<hours>\d+)[hH])?((?P<minutes>\d+)[mM])?((?P<seconds>\d+)[sS]?)?$", + relative_time_string) + + if parsed_result is None or len(relative_time_string) == 1: + # not matched or only a single "+" or "-" + return None + + offset = datetime.timedelta(0) + if parsed_result.group("weeks") is not None: + offset += datetime.timedelta(weeks=int(parsed_result.group("weeks"))) + if parsed_result.group("days") is not None: + offset += datetime.timedelta(days=int(parsed_result.group("days"))) + if parsed_result.group("hours") is not None: + offset += datetime.timedelta(hours=int(parsed_result.group("hours"))) + if parsed_result.group("minutes") is not None: + offset += datetime.timedelta( + minutes=int(parsed_result.group("minutes"))) + if parsed_result.group("seconds") is not None: + offset += datetime.timedelta( + seconds=int(parsed_result.group("seconds"))) + + if parsed_result.group("prefix") == "+": + return datetime.datetime.utcnow() + offset + else: + return datetime.datetime.utcnow() - offset + + +def get_relative_time_option(input_string, input_name, backend='cryptography'): + """Return an absolute timespec if a relative timespec or an ASN1 formatted + string is provided. 
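+       Relative times start with '+' or '-' (for example '+32w1d2h'); absolute
+       times are ASN.1 style timestamps such as '20191013194800Z'.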
+
+    The return value will be a datetime object for the cryptography backend,
+    and an ASN1 formatted string for the pyopenssl backend."""
+    result = to_native(input_string)
+    if result is None:
+        raise OpenSSLObjectError(
+            'The timespec "%s" for %s is not valid' %
+            (input_string, input_name))
+    # Relative time
+    if result.startswith("+") or result.startswith("-"):
+        result_datetime = convert_relative_to_datetime(result)
+        if backend == 'pyopenssl':
+            return result_datetime.strftime("%Y%m%d%H%M%SZ")
+        elif backend == 'cryptography':
+            return result_datetime
+    # Absolute time
+    if backend == 'pyopenssl':
+        return input_string
+    elif backend == 'cryptography':
+        for date_fmt in ['%Y%m%d%H%M%SZ', '%Y%m%d%H%MZ', '%Y%m%d%H%M%S%z', '%Y%m%d%H%M%z']:
+            try:
+                return datetime.datetime.strptime(result, date_fmt)
+            except ValueError:
+                pass
+
+        raise OpenSSLObjectError(
+            'The time spec "%s" for %s is invalid' %
+            (input_string, input_name)
+        )
+
+
+def select_message_digest(digest_string):
+    digest = None
+    if digest_string == 'sha256':
+        digest = hashes.SHA256()
+    elif digest_string == 'sha384':
+        digest = hashes.SHA384()
+    elif digest_string == 'sha512':
+        digest = hashes.SHA512()
+    elif digest_string == 'sha1':
+        digest = hashes.SHA1()
+    elif digest_string == 'md5':
+        digest = hashes.MD5()
+    return digest
+
+
+def write_file(module, content, default_mode=None, path=None):
+    '''
+    Writes content into destination file as securely as possible.
+    Uses file arguments from module.
+    '''
+    # Find out parameters for file
+    file_args = module.load_file_common_arguments(module.params, path=path)
+    if file_args['mode'] is None:
+        file_args['mode'] = default_mode
+    # Create tempfile name
+    tmp_fd, tmp_name = tempfile.mkstemp(prefix=b'.ansible_tmp')
+    try:
+        os.close(tmp_fd)
+    except Exception as dummy:
+        pass
+    module.add_cleanup_file(tmp_name)  # if we fail, let Ansible try to remove the file
+    try:
+        try:
+            # Create tempfile
+            file = os.open(tmp_name, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
+            os.write(file, content)
+            os.close(file)
+        except Exception as e:
+            try:
+                os.remove(tmp_name)
+            except Exception as dummy:
+                pass
+            module.fail_json(msg='Error while writing result into temporary file: {0}'.format(e))
+        # Update destination to wanted permissions
+        if os.path.exists(file_args['path']):
+            module.set_fs_attributes_if_different(file_args, False)
+        # Move tempfile to final destination
+        module.atomic_move(tmp_name, file_args['path'])
+        # Try to update permissions again
+        module.set_fs_attributes_if_different(file_args, False)
+    except Exception as e:
+        try:
+            os.remove(tmp_name)
+        except Exception as dummy:
+            pass
+        module.fail_json(msg='Error while writing result: {0}'.format(e))
+
+
+@six.add_metaclass(abc.ABCMeta)
+class OpenSSLObject(object):
+
+    def __init__(self, path, state, force, check_mode):
+        self.path = path
+        self.state = state
+        self.force = force
+        self.name = os.path.basename(path)
+        self.changed = False
+        self.check_mode = check_mode
+
+    def check(self, module, perms_required=True):
+        """Ensure the resource is in its desired state."""
+
+        def _check_state():
+            return os.path.exists(self.path)
+
+        def _check_perms(module):
+            file_args = module.load_file_common_arguments(module.params)
+            return not module.set_fs_attributes_if_different(file_args, False)
+
+        if not perms_required:
+            return _check_state()
+
+        return _check_state() and _check_perms(module)
+
+    @abc.abstractmethod
+    def dump(self):
+        """Serialize the object into a dictionary."""
+
+        pass
+
+    @abc.abstractmethod
+    def 
generate(self): + """Generate the resource.""" + + pass + + def remove(self, module): + """Remove the resource from the filesystem.""" + + try: + os.remove(self.path) + self.changed = True + except OSError as exc: + if exc.errno != errno.ENOENT: + raise OpenSSLObjectError(exc) + else: + pass + + +# ##################################################################################### +# ##################################################################################### +# This has been extracted from the OpenSSL project's objects.txt: +# https://github.com/openssl/openssl/blob/9537fe5757bb07761fa275d779bbd40bcf5530e4/crypto/objects/objects.txt +# Extracted with https://gist.github.com/felixfontein/376748017ad65ead093d56a45a5bf376 +# +# In case the following data structure has any copyrightable content, note that it is licensed as follows: +# Copyright (c) the OpenSSL contributors +# Licensed under the Apache License 2.0 +# https://github.com/openssl/openssl/blob/master/LICENSE +_OID_MAP = { + '0': ('itu-t', 'ITU-T', 'ccitt'), + '0.3.4401.5': ('ntt-ds', ), + '0.3.4401.5.3.1.9': ('camellia', ), + '0.3.4401.5.3.1.9.1': ('camellia-128-ecb', 'CAMELLIA-128-ECB'), + '0.3.4401.5.3.1.9.3': ('camellia-128-ofb', 'CAMELLIA-128-OFB'), + '0.3.4401.5.3.1.9.4': ('camellia-128-cfb', 'CAMELLIA-128-CFB'), + '0.3.4401.5.3.1.9.6': ('camellia-128-gcm', 'CAMELLIA-128-GCM'), + '0.3.4401.5.3.1.9.7': ('camellia-128-ccm', 'CAMELLIA-128-CCM'), + '0.3.4401.5.3.1.9.9': ('camellia-128-ctr', 'CAMELLIA-128-CTR'), + '0.3.4401.5.3.1.9.10': ('camellia-128-cmac', 'CAMELLIA-128-CMAC'), + '0.3.4401.5.3.1.9.21': ('camellia-192-ecb', 'CAMELLIA-192-ECB'), + '0.3.4401.5.3.1.9.23': ('camellia-192-ofb', 'CAMELLIA-192-OFB'), + '0.3.4401.5.3.1.9.24': ('camellia-192-cfb', 'CAMELLIA-192-CFB'), + '0.3.4401.5.3.1.9.26': ('camellia-192-gcm', 'CAMELLIA-192-GCM'), + '0.3.4401.5.3.1.9.27': ('camellia-192-ccm', 'CAMELLIA-192-CCM'), + '0.3.4401.5.3.1.9.29': ('camellia-192-ctr', 'CAMELLIA-192-CTR'), + '0.3.4401.5.3.1.9.30': ('camellia-192-cmac', 'CAMELLIA-192-CMAC'), + '0.3.4401.5.3.1.9.41': ('camellia-256-ecb', 'CAMELLIA-256-ECB'), + '0.3.4401.5.3.1.9.43': ('camellia-256-ofb', 'CAMELLIA-256-OFB'), + '0.3.4401.5.3.1.9.44': ('camellia-256-cfb', 'CAMELLIA-256-CFB'), + '0.3.4401.5.3.1.9.46': ('camellia-256-gcm', 'CAMELLIA-256-GCM'), + '0.3.4401.5.3.1.9.47': ('camellia-256-ccm', 'CAMELLIA-256-CCM'), + '0.3.4401.5.3.1.9.49': ('camellia-256-ctr', 'CAMELLIA-256-CTR'), + '0.3.4401.5.3.1.9.50': ('camellia-256-cmac', 'CAMELLIA-256-CMAC'), + '0.9': ('data', ), + '0.9.2342': ('pss', ), + '0.9.2342.19200300': ('ucl', ), + '0.9.2342.19200300.100': ('pilot', ), + '0.9.2342.19200300.100.1': ('pilotAttributeType', ), + '0.9.2342.19200300.100.1.1': ('userId', 'UID'), + '0.9.2342.19200300.100.1.2': ('textEncodedORAddress', ), + '0.9.2342.19200300.100.1.3': ('rfc822Mailbox', 'mail'), + '0.9.2342.19200300.100.1.4': ('info', ), + '0.9.2342.19200300.100.1.5': ('favouriteDrink', ), + '0.9.2342.19200300.100.1.6': ('roomNumber', ), + '0.9.2342.19200300.100.1.7': ('photo', ), + '0.9.2342.19200300.100.1.8': ('userClass', ), + '0.9.2342.19200300.100.1.9': ('host', ), + '0.9.2342.19200300.100.1.10': ('manager', ), + '0.9.2342.19200300.100.1.11': ('documentIdentifier', ), + '0.9.2342.19200300.100.1.12': ('documentTitle', ), + '0.9.2342.19200300.100.1.13': ('documentVersion', ), + '0.9.2342.19200300.100.1.14': ('documentAuthor', ), + '0.9.2342.19200300.100.1.15': ('documentLocation', ), + '0.9.2342.19200300.100.1.20': ('homeTelephoneNumber', ), + '0.9.2342.19200300.100.1.21': 
('secretary', ), + '0.9.2342.19200300.100.1.22': ('otherMailbox', ), + '0.9.2342.19200300.100.1.23': ('lastModifiedTime', ), + '0.9.2342.19200300.100.1.24': ('lastModifiedBy', ), + '0.9.2342.19200300.100.1.25': ('domainComponent', 'DC'), + '0.9.2342.19200300.100.1.26': ('aRecord', ), + '0.9.2342.19200300.100.1.27': ('pilotAttributeType27', ), + '0.9.2342.19200300.100.1.28': ('mXRecord', ), + '0.9.2342.19200300.100.1.29': ('nSRecord', ), + '0.9.2342.19200300.100.1.30': ('sOARecord', ), + '0.9.2342.19200300.100.1.31': ('cNAMERecord', ), + '0.9.2342.19200300.100.1.37': ('associatedDomain', ), + '0.9.2342.19200300.100.1.38': ('associatedName', ), + '0.9.2342.19200300.100.1.39': ('homePostalAddress', ), + '0.9.2342.19200300.100.1.40': ('personalTitle', ), + '0.9.2342.19200300.100.1.41': ('mobileTelephoneNumber', ), + '0.9.2342.19200300.100.1.42': ('pagerTelephoneNumber', ), + '0.9.2342.19200300.100.1.43': ('friendlyCountryName', ), + '0.9.2342.19200300.100.1.44': ('uniqueIdentifier', 'uid'), + '0.9.2342.19200300.100.1.45': ('organizationalStatus', ), + '0.9.2342.19200300.100.1.46': ('janetMailbox', ), + '0.9.2342.19200300.100.1.47': ('mailPreferenceOption', ), + '0.9.2342.19200300.100.1.48': ('buildingName', ), + '0.9.2342.19200300.100.1.49': ('dSAQuality', ), + '0.9.2342.19200300.100.1.50': ('singleLevelQuality', ), + '0.9.2342.19200300.100.1.51': ('subtreeMinimumQuality', ), + '0.9.2342.19200300.100.1.52': ('subtreeMaximumQuality', ), + '0.9.2342.19200300.100.1.53': ('personalSignature', ), + '0.9.2342.19200300.100.1.54': ('dITRedirect', ), + '0.9.2342.19200300.100.1.55': ('audio', ), + '0.9.2342.19200300.100.1.56': ('documentPublisher', ), + '0.9.2342.19200300.100.3': ('pilotAttributeSyntax', ), + '0.9.2342.19200300.100.3.4': ('iA5StringSyntax', ), + '0.9.2342.19200300.100.3.5': ('caseIgnoreIA5StringSyntax', ), + '0.9.2342.19200300.100.4': ('pilotObjectClass', ), + '0.9.2342.19200300.100.4.3': ('pilotObject', ), + '0.9.2342.19200300.100.4.4': ('pilotPerson', ), + '0.9.2342.19200300.100.4.5': ('account', ), + '0.9.2342.19200300.100.4.6': ('document', ), + '0.9.2342.19200300.100.4.7': ('room', ), + '0.9.2342.19200300.100.4.9': ('documentSeries', ), + '0.9.2342.19200300.100.4.13': ('Domain', 'domain'), + '0.9.2342.19200300.100.4.14': ('rFC822localPart', ), + '0.9.2342.19200300.100.4.15': ('dNSDomain', ), + '0.9.2342.19200300.100.4.17': ('domainRelatedObject', ), + '0.9.2342.19200300.100.4.18': ('friendlyCountry', ), + '0.9.2342.19200300.100.4.19': ('simpleSecurityObject', ), + '0.9.2342.19200300.100.4.20': ('pilotOrganization', ), + '0.9.2342.19200300.100.4.21': ('pilotDSA', ), + '0.9.2342.19200300.100.4.22': ('qualityLabelledData', ), + '0.9.2342.19200300.100.10': ('pilotGroups', ), + '1': ('iso', 'ISO'), + '1.0.9797.3.4': ('gmac', 'GMAC'), + '1.0.10118.3.0.55': ('whirlpool', ), + '1.2': ('ISO Member Body', 'member-body'), + '1.2.156': ('ISO CN Member Body', 'ISO-CN'), + '1.2.156.10197': ('oscca', ), + '1.2.156.10197.1': ('sm-scheme', ), + '1.2.156.10197.1.104.1': ('sm4-ecb', 'SM4-ECB'), + '1.2.156.10197.1.104.2': ('sm4-cbc', 'SM4-CBC'), + '1.2.156.10197.1.104.3': ('sm4-ofb', 'SM4-OFB'), + '1.2.156.10197.1.104.4': ('sm4-cfb', 'SM4-CFB'), + '1.2.156.10197.1.104.5': ('sm4-cfb1', 'SM4-CFB1'), + '1.2.156.10197.1.104.6': ('sm4-cfb8', 'SM4-CFB8'), + '1.2.156.10197.1.104.7': ('sm4-ctr', 'SM4-CTR'), + '1.2.156.10197.1.301': ('sm2', 'SM2'), + '1.2.156.10197.1.401': ('sm3', 'SM3'), + '1.2.156.10197.1.501': ('SM2-with-SM3', 'SM2-SM3'), + '1.2.156.10197.1.504': ('sm3WithRSAEncryption', 'RSA-SM3'), + 
'1.2.392.200011.61.1.1.1.2': ('camellia-128-cbc', 'CAMELLIA-128-CBC'), + '1.2.392.200011.61.1.1.1.3': ('camellia-192-cbc', 'CAMELLIA-192-CBC'), + '1.2.392.200011.61.1.1.1.4': ('camellia-256-cbc', 'CAMELLIA-256-CBC'), + '1.2.392.200011.61.1.1.3.2': ('id-camellia128-wrap', ), + '1.2.392.200011.61.1.1.3.3': ('id-camellia192-wrap', ), + '1.2.392.200011.61.1.1.3.4': ('id-camellia256-wrap', ), + '1.2.410.200004': ('kisa', 'KISA'), + '1.2.410.200004.1.3': ('seed-ecb', 'SEED-ECB'), + '1.2.410.200004.1.4': ('seed-cbc', 'SEED-CBC'), + '1.2.410.200004.1.5': ('seed-cfb', 'SEED-CFB'), + '1.2.410.200004.1.6': ('seed-ofb', 'SEED-OFB'), + '1.2.410.200046.1.1': ('aria', ), + '1.2.410.200046.1.1.1': ('aria-128-ecb', 'ARIA-128-ECB'), + '1.2.410.200046.1.1.2': ('aria-128-cbc', 'ARIA-128-CBC'), + '1.2.410.200046.1.1.3': ('aria-128-cfb', 'ARIA-128-CFB'), + '1.2.410.200046.1.1.4': ('aria-128-ofb', 'ARIA-128-OFB'), + '1.2.410.200046.1.1.5': ('aria-128-ctr', 'ARIA-128-CTR'), + '1.2.410.200046.1.1.6': ('aria-192-ecb', 'ARIA-192-ECB'), + '1.2.410.200046.1.1.7': ('aria-192-cbc', 'ARIA-192-CBC'), + '1.2.410.200046.1.1.8': ('aria-192-cfb', 'ARIA-192-CFB'), + '1.2.410.200046.1.1.9': ('aria-192-ofb', 'ARIA-192-OFB'), + '1.2.410.200046.1.1.10': ('aria-192-ctr', 'ARIA-192-CTR'), + '1.2.410.200046.1.1.11': ('aria-256-ecb', 'ARIA-256-ECB'), + '1.2.410.200046.1.1.12': ('aria-256-cbc', 'ARIA-256-CBC'), + '1.2.410.200046.1.1.13': ('aria-256-cfb', 'ARIA-256-CFB'), + '1.2.410.200046.1.1.14': ('aria-256-ofb', 'ARIA-256-OFB'), + '1.2.410.200046.1.1.15': ('aria-256-ctr', 'ARIA-256-CTR'), + '1.2.410.200046.1.1.34': ('aria-128-gcm', 'ARIA-128-GCM'), + '1.2.410.200046.1.1.35': ('aria-192-gcm', 'ARIA-192-GCM'), + '1.2.410.200046.1.1.36': ('aria-256-gcm', 'ARIA-256-GCM'), + '1.2.410.200046.1.1.37': ('aria-128-ccm', 'ARIA-128-CCM'), + '1.2.410.200046.1.1.38': ('aria-192-ccm', 'ARIA-192-CCM'), + '1.2.410.200046.1.1.39': ('aria-256-ccm', 'ARIA-256-CCM'), + '1.2.643.2.2': ('cryptopro', ), + '1.2.643.2.2.3': ('GOST R 34.11-94 with GOST R 34.10-2001', 'id-GostR3411-94-with-GostR3410-2001'), + '1.2.643.2.2.4': ('GOST R 34.11-94 with GOST R 34.10-94', 'id-GostR3411-94-with-GostR3410-94'), + '1.2.643.2.2.9': ('GOST R 34.11-94', 'md_gost94'), + '1.2.643.2.2.10': ('HMAC GOST 34.11-94', 'id-HMACGostR3411-94'), + '1.2.643.2.2.14.0': ('id-Gost28147-89-None-KeyMeshing', ), + '1.2.643.2.2.14.1': ('id-Gost28147-89-CryptoPro-KeyMeshing', ), + '1.2.643.2.2.19': ('GOST R 34.10-2001', 'gost2001'), + '1.2.643.2.2.20': ('GOST R 34.10-94', 'gost94'), + '1.2.643.2.2.20.1': ('id-GostR3410-94-a', ), + '1.2.643.2.2.20.2': ('id-GostR3410-94-aBis', ), + '1.2.643.2.2.20.3': ('id-GostR3410-94-b', ), + '1.2.643.2.2.20.4': ('id-GostR3410-94-bBis', ), + '1.2.643.2.2.21': ('GOST 28147-89', 'gost89'), + '1.2.643.2.2.22': ('GOST 28147-89 MAC', 'gost-mac'), + '1.2.643.2.2.23': ('GOST R 34.11-94 PRF', 'prf-gostr3411-94'), + '1.2.643.2.2.30.0': ('id-GostR3411-94-TestParamSet', ), + '1.2.643.2.2.30.1': ('id-GostR3411-94-CryptoProParamSet', ), + '1.2.643.2.2.31.0': ('id-Gost28147-89-TestParamSet', ), + '1.2.643.2.2.31.1': ('id-Gost28147-89-CryptoPro-A-ParamSet', ), + '1.2.643.2.2.31.2': ('id-Gost28147-89-CryptoPro-B-ParamSet', ), + '1.2.643.2.2.31.3': ('id-Gost28147-89-CryptoPro-C-ParamSet', ), + '1.2.643.2.2.31.4': ('id-Gost28147-89-CryptoPro-D-ParamSet', ), + '1.2.643.2.2.31.5': ('id-Gost28147-89-CryptoPro-Oscar-1-1-ParamSet', ), + '1.2.643.2.2.31.6': ('id-Gost28147-89-CryptoPro-Oscar-1-0-ParamSet', ), + '1.2.643.2.2.31.7': ('id-Gost28147-89-CryptoPro-RIC-1-ParamSet', ), + 
'1.2.643.2.2.32.0': ('id-GostR3410-94-TestParamSet', ), + '1.2.643.2.2.32.2': ('id-GostR3410-94-CryptoPro-A-ParamSet', ), + '1.2.643.2.2.32.3': ('id-GostR3410-94-CryptoPro-B-ParamSet', ), + '1.2.643.2.2.32.4': ('id-GostR3410-94-CryptoPro-C-ParamSet', ), + '1.2.643.2.2.32.5': ('id-GostR3410-94-CryptoPro-D-ParamSet', ), + '1.2.643.2.2.33.1': ('id-GostR3410-94-CryptoPro-XchA-ParamSet', ), + '1.2.643.2.2.33.2': ('id-GostR3410-94-CryptoPro-XchB-ParamSet', ), + '1.2.643.2.2.33.3': ('id-GostR3410-94-CryptoPro-XchC-ParamSet', ), + '1.2.643.2.2.35.0': ('id-GostR3410-2001-TestParamSet', ), + '1.2.643.2.2.35.1': ('id-GostR3410-2001-CryptoPro-A-ParamSet', ), + '1.2.643.2.2.35.2': ('id-GostR3410-2001-CryptoPro-B-ParamSet', ), + '1.2.643.2.2.35.3': ('id-GostR3410-2001-CryptoPro-C-ParamSet', ), + '1.2.643.2.2.36.0': ('id-GostR3410-2001-CryptoPro-XchA-ParamSet', ), + '1.2.643.2.2.36.1': ('id-GostR3410-2001-CryptoPro-XchB-ParamSet', ), + '1.2.643.2.2.98': ('GOST R 34.10-2001 DH', 'id-GostR3410-2001DH'), + '1.2.643.2.2.99': ('GOST R 34.10-94 DH', 'id-GostR3410-94DH'), + '1.2.643.2.9': ('cryptocom', ), + '1.2.643.2.9.1.3.3': ('GOST R 34.11-94 with GOST R 34.10-94 Cryptocom', 'id-GostR3411-94-with-GostR3410-94-cc'), + '1.2.643.2.9.1.3.4': ('GOST R 34.11-94 with GOST R 34.10-2001 Cryptocom', 'id-GostR3411-94-with-GostR3410-2001-cc'), + '1.2.643.2.9.1.5.3': ('GOST 34.10-94 Cryptocom', 'gost94cc'), + '1.2.643.2.9.1.5.4': ('GOST 34.10-2001 Cryptocom', 'gost2001cc'), + '1.2.643.2.9.1.6.1': ('GOST 28147-89 Cryptocom ParamSet', 'id-Gost28147-89-cc'), + '1.2.643.2.9.1.8.1': ('GOST R 3410-2001 Parameter Set Cryptocom', 'id-GostR3410-2001-ParamSet-cc'), + '1.2.643.3.131.1.1': ('INN', 'INN'), + '1.2.643.7.1': ('id-tc26', ), + '1.2.643.7.1.1': ('id-tc26-algorithms', ), + '1.2.643.7.1.1.1': ('id-tc26-sign', ), + '1.2.643.7.1.1.1.1': ('GOST R 34.10-2012 with 256 bit modulus', 'gost2012_256'), + '1.2.643.7.1.1.1.2': ('GOST R 34.10-2012 with 512 bit modulus', 'gost2012_512'), + '1.2.643.7.1.1.2': ('id-tc26-digest', ), + '1.2.643.7.1.1.2.2': ('GOST R 34.11-2012 with 256 bit hash', 'md_gost12_256'), + '1.2.643.7.1.1.2.3': ('GOST R 34.11-2012 with 512 bit hash', 'md_gost12_512'), + '1.2.643.7.1.1.3': ('id-tc26-signwithdigest', ), + '1.2.643.7.1.1.3.2': ('GOST R 34.10-2012 with GOST R 34.11-2012 (256 bit)', 'id-tc26-signwithdigest-gost3410-2012-256'), + '1.2.643.7.1.1.3.3': ('GOST R 34.10-2012 with GOST R 34.11-2012 (512 bit)', 'id-tc26-signwithdigest-gost3410-2012-512'), + '1.2.643.7.1.1.4': ('id-tc26-mac', ), + '1.2.643.7.1.1.4.1': ('HMAC GOST 34.11-2012 256 bit', 'id-tc26-hmac-gost-3411-2012-256'), + '1.2.643.7.1.1.4.2': ('HMAC GOST 34.11-2012 512 bit', 'id-tc26-hmac-gost-3411-2012-512'), + '1.2.643.7.1.1.5': ('id-tc26-cipher', ), + '1.2.643.7.1.1.5.1': ('id-tc26-cipher-gostr3412-2015-magma', ), + '1.2.643.7.1.1.5.1.1': ('id-tc26-cipher-gostr3412-2015-magma-ctracpkm', ), + '1.2.643.7.1.1.5.1.2': ('id-tc26-cipher-gostr3412-2015-magma-ctracpkm-omac', ), + '1.2.643.7.1.1.5.2': ('id-tc26-cipher-gostr3412-2015-kuznyechik', ), + '1.2.643.7.1.1.5.2.1': ('id-tc26-cipher-gostr3412-2015-kuznyechik-ctracpkm', ), + '1.2.643.7.1.1.5.2.2': ('id-tc26-cipher-gostr3412-2015-kuznyechik-ctracpkm-omac', ), + '1.2.643.7.1.1.6': ('id-tc26-agreement', ), + '1.2.643.7.1.1.6.1': ('id-tc26-agreement-gost-3410-2012-256', ), + '1.2.643.7.1.1.6.2': ('id-tc26-agreement-gost-3410-2012-512', ), + '1.2.643.7.1.1.7': ('id-tc26-wrap', ), + '1.2.643.7.1.1.7.1': ('id-tc26-wrap-gostr3412-2015-magma', ), + '1.2.643.7.1.1.7.1.1': 
('id-tc26-wrap-gostr3412-2015-magma-kexp15', 'id-tc26-wrap-gostr3412-2015-kuznyechik-kexp15'), + '1.2.643.7.1.1.7.2': ('id-tc26-wrap-gostr3412-2015-kuznyechik', ), + '1.2.643.7.1.2': ('id-tc26-constants', ), + '1.2.643.7.1.2.1': ('id-tc26-sign-constants', ), + '1.2.643.7.1.2.1.1': ('id-tc26-gost-3410-2012-256-constants', ), + '1.2.643.7.1.2.1.1.1': ('GOST R 34.10-2012 (256 bit) ParamSet A', 'id-tc26-gost-3410-2012-256-paramSetA'), + '1.2.643.7.1.2.1.1.2': ('GOST R 34.10-2012 (256 bit) ParamSet B', 'id-tc26-gost-3410-2012-256-paramSetB'), + '1.2.643.7.1.2.1.1.3': ('GOST R 34.10-2012 (256 bit) ParamSet C', 'id-tc26-gost-3410-2012-256-paramSetC'), + '1.2.643.7.1.2.1.1.4': ('GOST R 34.10-2012 (256 bit) ParamSet D', 'id-tc26-gost-3410-2012-256-paramSetD'), + '1.2.643.7.1.2.1.2': ('id-tc26-gost-3410-2012-512-constants', ), + '1.2.643.7.1.2.1.2.0': ('GOST R 34.10-2012 (512 bit) testing parameter set', 'id-tc26-gost-3410-2012-512-paramSetTest'), + '1.2.643.7.1.2.1.2.1': ('GOST R 34.10-2012 (512 bit) ParamSet A', 'id-tc26-gost-3410-2012-512-paramSetA'), + '1.2.643.7.1.2.1.2.2': ('GOST R 34.10-2012 (512 bit) ParamSet B', 'id-tc26-gost-3410-2012-512-paramSetB'), + '1.2.643.7.1.2.1.2.3': ('GOST R 34.10-2012 (512 bit) ParamSet C', 'id-tc26-gost-3410-2012-512-paramSetC'), + '1.2.643.7.1.2.2': ('id-tc26-digest-constants', ), + '1.2.643.7.1.2.5': ('id-tc26-cipher-constants', ), + '1.2.643.7.1.2.5.1': ('id-tc26-gost-28147-constants', ), + '1.2.643.7.1.2.5.1.1': ('GOST 28147-89 TC26 parameter set', 'id-tc26-gost-28147-param-Z'), + '1.2.643.100.1': ('OGRN', 'OGRN'), + '1.2.643.100.3': ('SNILS', 'SNILS'), + '1.2.643.100.111': ('Signing Tool of Subject', 'subjectSignTool'), + '1.2.643.100.112': ('Signing Tool of Issuer', 'issuerSignTool'), + '1.2.804': ('ISO-UA', ), + '1.2.804.2.1.1.1': ('ua-pki', ), + '1.2.804.2.1.1.1.1.1.1': ('DSTU Gost 28147-2009', 'dstu28147'), + '1.2.804.2.1.1.1.1.1.1.2': ('DSTU Gost 28147-2009 OFB mode', 'dstu28147-ofb'), + '1.2.804.2.1.1.1.1.1.1.3': ('DSTU Gost 28147-2009 CFB mode', 'dstu28147-cfb'), + '1.2.804.2.1.1.1.1.1.1.5': ('DSTU Gost 28147-2009 key wrap', 'dstu28147-wrap'), + '1.2.804.2.1.1.1.1.1.2': ('HMAC DSTU Gost 34311-95', 'hmacWithDstu34311'), + '1.2.804.2.1.1.1.1.2.1': ('DSTU Gost 34311-95', 'dstu34311'), + '1.2.804.2.1.1.1.1.3.1.1': ('DSTU 4145-2002 little endian', 'dstu4145le'), + '1.2.804.2.1.1.1.1.3.1.1.1.1': ('DSTU 4145-2002 big endian', 'dstu4145be'), + '1.2.804.2.1.1.1.1.3.1.1.2.0': ('DSTU curve 0', 'uacurve0'), + '1.2.804.2.1.1.1.1.3.1.1.2.1': ('DSTU curve 1', 'uacurve1'), + '1.2.804.2.1.1.1.1.3.1.1.2.2': ('DSTU curve 2', 'uacurve2'), + '1.2.804.2.1.1.1.1.3.1.1.2.3': ('DSTU curve 3', 'uacurve3'), + '1.2.804.2.1.1.1.1.3.1.1.2.4': ('DSTU curve 4', 'uacurve4'), + '1.2.804.2.1.1.1.1.3.1.1.2.5': ('DSTU curve 5', 'uacurve5'), + '1.2.804.2.1.1.1.1.3.1.1.2.6': ('DSTU curve 6', 'uacurve6'), + '1.2.804.2.1.1.1.1.3.1.1.2.7': ('DSTU curve 7', 'uacurve7'), + '1.2.804.2.1.1.1.1.3.1.1.2.8': ('DSTU curve 8', 'uacurve8'), + '1.2.804.2.1.1.1.1.3.1.1.2.9': ('DSTU curve 9', 'uacurve9'), + '1.2.840': ('ISO US Member Body', 'ISO-US'), + '1.2.840.10040': ('X9.57', 'X9-57'), + '1.2.840.10040.2': ('holdInstruction', ), + '1.2.840.10040.2.1': ('Hold Instruction None', 'holdInstructionNone'), + '1.2.840.10040.2.2': ('Hold Instruction Call Issuer', 'holdInstructionCallIssuer'), + '1.2.840.10040.2.3': ('Hold Instruction Reject', 'holdInstructionReject'), + '1.2.840.10040.4': ('X9.57 CM ?', 'X9cm'), + '1.2.840.10040.4.1': ('dsaEncryption', 'DSA'), + '1.2.840.10040.4.3': ('dsaWithSHA1', 
'DSA-SHA1'), + '1.2.840.10045': ('ANSI X9.62', 'ansi-X9-62'), + '1.2.840.10045.1': ('id-fieldType', ), + '1.2.840.10045.1.1': ('prime-field', ), + '1.2.840.10045.1.2': ('characteristic-two-field', ), + '1.2.840.10045.1.2.3': ('id-characteristic-two-basis', ), + '1.2.840.10045.1.2.3.1': ('onBasis', ), + '1.2.840.10045.1.2.3.2': ('tpBasis', ), + '1.2.840.10045.1.2.3.3': ('ppBasis', ), + '1.2.840.10045.2': ('id-publicKeyType', ), + '1.2.840.10045.2.1': ('id-ecPublicKey', ), + '1.2.840.10045.3': ('ellipticCurve', ), + '1.2.840.10045.3.0': ('c-TwoCurve', ), + '1.2.840.10045.3.0.1': ('c2pnb163v1', ), + '1.2.840.10045.3.0.2': ('c2pnb163v2', ), + '1.2.840.10045.3.0.3': ('c2pnb163v3', ), + '1.2.840.10045.3.0.4': ('c2pnb176v1', ), + '1.2.840.10045.3.0.5': ('c2tnb191v1', ), + '1.2.840.10045.3.0.6': ('c2tnb191v2', ), + '1.2.840.10045.3.0.7': ('c2tnb191v3', ), + '1.2.840.10045.3.0.8': ('c2onb191v4', ), + '1.2.840.10045.3.0.9': ('c2onb191v5', ), + '1.2.840.10045.3.0.10': ('c2pnb208w1', ), + '1.2.840.10045.3.0.11': ('c2tnb239v1', ), + '1.2.840.10045.3.0.12': ('c2tnb239v2', ), + '1.2.840.10045.3.0.13': ('c2tnb239v3', ), + '1.2.840.10045.3.0.14': ('c2onb239v4', ), + '1.2.840.10045.3.0.15': ('c2onb239v5', ), + '1.2.840.10045.3.0.16': ('c2pnb272w1', ), + '1.2.840.10045.3.0.17': ('c2pnb304w1', ), + '1.2.840.10045.3.0.18': ('c2tnb359v1', ), + '1.2.840.10045.3.0.19': ('c2pnb368w1', ), + '1.2.840.10045.3.0.20': ('c2tnb431r1', ), + '1.2.840.10045.3.1': ('primeCurve', ), + '1.2.840.10045.3.1.1': ('prime192v1', ), + '1.2.840.10045.3.1.2': ('prime192v2', ), + '1.2.840.10045.3.1.3': ('prime192v3', ), + '1.2.840.10045.3.1.4': ('prime239v1', ), + '1.2.840.10045.3.1.5': ('prime239v2', ), + '1.2.840.10045.3.1.6': ('prime239v3', ), + '1.2.840.10045.3.1.7': ('prime256v1', ), + '1.2.840.10045.4': ('id-ecSigType', ), + '1.2.840.10045.4.1': ('ecdsa-with-SHA1', ), + '1.2.840.10045.4.2': ('ecdsa-with-Recommended', ), + '1.2.840.10045.4.3': ('ecdsa-with-Specified', ), + '1.2.840.10045.4.3.1': ('ecdsa-with-SHA224', ), + '1.2.840.10045.4.3.2': ('ecdsa-with-SHA256', ), + '1.2.840.10045.4.3.3': ('ecdsa-with-SHA384', ), + '1.2.840.10045.4.3.4': ('ecdsa-with-SHA512', ), + '1.2.840.10046.2.1': ('X9.42 DH', 'dhpublicnumber'), + '1.2.840.113533.7.66.10': ('cast5-cbc', 'CAST5-CBC'), + '1.2.840.113533.7.66.12': ('pbeWithMD5AndCast5CBC', ), + '1.2.840.113533.7.66.13': ('password based MAC', 'id-PasswordBasedMAC'), + '1.2.840.113533.7.66.30': ('Diffie-Hellman based MAC', 'id-DHBasedMac'), + '1.2.840.113549': ('RSA Data Security, Inc.', 'rsadsi'), + '1.2.840.113549.1': ('RSA Data Security, Inc. 
PKCS', 'pkcs'), + '1.2.840.113549.1.1': ('pkcs1', ), + '1.2.840.113549.1.1.1': ('rsaEncryption', ), + '1.2.840.113549.1.1.2': ('md2WithRSAEncryption', 'RSA-MD2'), + '1.2.840.113549.1.1.3': ('md4WithRSAEncryption', 'RSA-MD4'), + '1.2.840.113549.1.1.4': ('md5WithRSAEncryption', 'RSA-MD5'), + '1.2.840.113549.1.1.5': ('sha1WithRSAEncryption', 'RSA-SHA1'), + '1.2.840.113549.1.1.6': ('rsaOAEPEncryptionSET', ), + '1.2.840.113549.1.1.7': ('rsaesOaep', 'RSAES-OAEP'), + '1.2.840.113549.1.1.8': ('mgf1', 'MGF1'), + '1.2.840.113549.1.1.9': ('pSpecified', 'PSPECIFIED'), + '1.2.840.113549.1.1.10': ('rsassaPss', 'RSASSA-PSS'), + '1.2.840.113549.1.1.11': ('sha256WithRSAEncryption', 'RSA-SHA256'), + '1.2.840.113549.1.1.12': ('sha384WithRSAEncryption', 'RSA-SHA384'), + '1.2.840.113549.1.1.13': ('sha512WithRSAEncryption', 'RSA-SHA512'), + '1.2.840.113549.1.1.14': ('sha224WithRSAEncryption', 'RSA-SHA224'), + '1.2.840.113549.1.1.15': ('sha512-224WithRSAEncryption', 'RSA-SHA512/224'), + '1.2.840.113549.1.1.16': ('sha512-256WithRSAEncryption', 'RSA-SHA512/256'), + '1.2.840.113549.1.3': ('pkcs3', ), + '1.2.840.113549.1.3.1': ('dhKeyAgreement', ), + '1.2.840.113549.1.5': ('pkcs5', ), + '1.2.840.113549.1.5.1': ('pbeWithMD2AndDES-CBC', 'PBE-MD2-DES'), + '1.2.840.113549.1.5.3': ('pbeWithMD5AndDES-CBC', 'PBE-MD5-DES'), + '1.2.840.113549.1.5.4': ('pbeWithMD2AndRC2-CBC', 'PBE-MD2-RC2-64'), + '1.2.840.113549.1.5.6': ('pbeWithMD5AndRC2-CBC', 'PBE-MD5-RC2-64'), + '1.2.840.113549.1.5.10': ('pbeWithSHA1AndDES-CBC', 'PBE-SHA1-DES'), + '1.2.840.113549.1.5.11': ('pbeWithSHA1AndRC2-CBC', 'PBE-SHA1-RC2-64'), + '1.2.840.113549.1.5.12': ('PBKDF2', ), + '1.2.840.113549.1.5.13': ('PBES2', ), + '1.2.840.113549.1.5.14': ('PBMAC1', ), + '1.2.840.113549.1.7': ('pkcs7', ), + '1.2.840.113549.1.7.1': ('pkcs7-data', ), + '1.2.840.113549.1.7.2': ('pkcs7-signedData', ), + '1.2.840.113549.1.7.3': ('pkcs7-envelopedData', ), + '1.2.840.113549.1.7.4': ('pkcs7-signedAndEnvelopedData', ), + '1.2.840.113549.1.7.5': ('pkcs7-digestData', ), + '1.2.840.113549.1.7.6': ('pkcs7-encryptedData', ), + '1.2.840.113549.1.9': ('pkcs9', ), + '1.2.840.113549.1.9.1': ('emailAddress', ), + '1.2.840.113549.1.9.2': ('unstructuredName', ), + '1.2.840.113549.1.9.3': ('contentType', ), + '1.2.840.113549.1.9.4': ('messageDigest', ), + '1.2.840.113549.1.9.5': ('signingTime', ), + '1.2.840.113549.1.9.6': ('countersignature', ), + '1.2.840.113549.1.9.7': ('challengePassword', ), + '1.2.840.113549.1.9.8': ('unstructuredAddress', ), + '1.2.840.113549.1.9.9': ('extendedCertificateAttributes', ), + '1.2.840.113549.1.9.14': ('Extension Request', 'extReq'), + '1.2.840.113549.1.9.15': ('S/MIME Capabilities', 'SMIME-CAPS'), + '1.2.840.113549.1.9.16': ('S/MIME', 'SMIME'), + '1.2.840.113549.1.9.16.0': ('id-smime-mod', ), + '1.2.840.113549.1.9.16.0.1': ('id-smime-mod-cms', ), + '1.2.840.113549.1.9.16.0.2': ('id-smime-mod-ess', ), + '1.2.840.113549.1.9.16.0.3': ('id-smime-mod-oid', ), + '1.2.840.113549.1.9.16.0.4': ('id-smime-mod-msg-v3', ), + '1.2.840.113549.1.9.16.0.5': ('id-smime-mod-ets-eSignature-88', ), + '1.2.840.113549.1.9.16.0.6': ('id-smime-mod-ets-eSignature-97', ), + '1.2.840.113549.1.9.16.0.7': ('id-smime-mod-ets-eSigPolicy-88', ), + '1.2.840.113549.1.9.16.0.8': ('id-smime-mod-ets-eSigPolicy-97', ), + '1.2.840.113549.1.9.16.1': ('id-smime-ct', ), + '1.2.840.113549.1.9.16.1.1': ('id-smime-ct-receipt', ), + '1.2.840.113549.1.9.16.1.2': ('id-smime-ct-authData', ), + '1.2.840.113549.1.9.16.1.3': ('id-smime-ct-publishCert', ), + '1.2.840.113549.1.9.16.1.4': 
('id-smime-ct-TSTInfo', ), + '1.2.840.113549.1.9.16.1.5': ('id-smime-ct-TDTInfo', ), + '1.2.840.113549.1.9.16.1.6': ('id-smime-ct-contentInfo', ), + '1.2.840.113549.1.9.16.1.7': ('id-smime-ct-DVCSRequestData', ), + '1.2.840.113549.1.9.16.1.8': ('id-smime-ct-DVCSResponseData', ), + '1.2.840.113549.1.9.16.1.9': ('id-smime-ct-compressedData', ), + '1.2.840.113549.1.9.16.1.19': ('id-smime-ct-contentCollection', ), + '1.2.840.113549.1.9.16.1.23': ('id-smime-ct-authEnvelopedData', ), + '1.2.840.113549.1.9.16.1.27': ('id-ct-asciiTextWithCRLF', ), + '1.2.840.113549.1.9.16.1.28': ('id-ct-xml', ), + '1.2.840.113549.1.9.16.2': ('id-smime-aa', ), + '1.2.840.113549.1.9.16.2.1': ('id-smime-aa-receiptRequest', ), + '1.2.840.113549.1.9.16.2.2': ('id-smime-aa-securityLabel', ), + '1.2.840.113549.1.9.16.2.3': ('id-smime-aa-mlExpandHistory', ), + '1.2.840.113549.1.9.16.2.4': ('id-smime-aa-contentHint', ), + '1.2.840.113549.1.9.16.2.5': ('id-smime-aa-msgSigDigest', ), + '1.2.840.113549.1.9.16.2.6': ('id-smime-aa-encapContentType', ), + '1.2.840.113549.1.9.16.2.7': ('id-smime-aa-contentIdentifier', ), + '1.2.840.113549.1.9.16.2.8': ('id-smime-aa-macValue', ), + '1.2.840.113549.1.9.16.2.9': ('id-smime-aa-equivalentLabels', ), + '1.2.840.113549.1.9.16.2.10': ('id-smime-aa-contentReference', ), + '1.2.840.113549.1.9.16.2.11': ('id-smime-aa-encrypKeyPref', ), + '1.2.840.113549.1.9.16.2.12': ('id-smime-aa-signingCertificate', ), + '1.2.840.113549.1.9.16.2.13': ('id-smime-aa-smimeEncryptCerts', ), + '1.2.840.113549.1.9.16.2.14': ('id-smime-aa-timeStampToken', ), + '1.2.840.113549.1.9.16.2.15': ('id-smime-aa-ets-sigPolicyId', ), + '1.2.840.113549.1.9.16.2.16': ('id-smime-aa-ets-commitmentType', ), + '1.2.840.113549.1.9.16.2.17': ('id-smime-aa-ets-signerLocation', ), + '1.2.840.113549.1.9.16.2.18': ('id-smime-aa-ets-signerAttr', ), + '1.2.840.113549.1.9.16.2.19': ('id-smime-aa-ets-otherSigCert', ), + '1.2.840.113549.1.9.16.2.20': ('id-smime-aa-ets-contentTimestamp', ), + '1.2.840.113549.1.9.16.2.21': ('id-smime-aa-ets-CertificateRefs', ), + '1.2.840.113549.1.9.16.2.22': ('id-smime-aa-ets-RevocationRefs', ), + '1.2.840.113549.1.9.16.2.23': ('id-smime-aa-ets-certValues', ), + '1.2.840.113549.1.9.16.2.24': ('id-smime-aa-ets-revocationValues', ), + '1.2.840.113549.1.9.16.2.25': ('id-smime-aa-ets-escTimeStamp', ), + '1.2.840.113549.1.9.16.2.26': ('id-smime-aa-ets-certCRLTimestamp', ), + '1.2.840.113549.1.9.16.2.27': ('id-smime-aa-ets-archiveTimeStamp', ), + '1.2.840.113549.1.9.16.2.28': ('id-smime-aa-signatureType', ), + '1.2.840.113549.1.9.16.2.29': ('id-smime-aa-dvcs-dvc', ), + '1.2.840.113549.1.9.16.2.47': ('id-smime-aa-signingCertificateV2', ), + '1.2.840.113549.1.9.16.3': ('id-smime-alg', ), + '1.2.840.113549.1.9.16.3.1': ('id-smime-alg-ESDHwith3DES', ), + '1.2.840.113549.1.9.16.3.2': ('id-smime-alg-ESDHwithRC2', ), + '1.2.840.113549.1.9.16.3.3': ('id-smime-alg-3DESwrap', ), + '1.2.840.113549.1.9.16.3.4': ('id-smime-alg-RC2wrap', ), + '1.2.840.113549.1.9.16.3.5': ('id-smime-alg-ESDH', ), + '1.2.840.113549.1.9.16.3.6': ('id-smime-alg-CMS3DESwrap', ), + '1.2.840.113549.1.9.16.3.7': ('id-smime-alg-CMSRC2wrap', ), + '1.2.840.113549.1.9.16.3.8': ('zlib compression', 'ZLIB'), + '1.2.840.113549.1.9.16.3.9': ('id-alg-PWRI-KEK', ), + '1.2.840.113549.1.9.16.4': ('id-smime-cd', ), + '1.2.840.113549.1.9.16.4.1': ('id-smime-cd-ldap', ), + '1.2.840.113549.1.9.16.5': ('id-smime-spq', ), + '1.2.840.113549.1.9.16.5.1': ('id-smime-spq-ets-sqt-uri', ), + '1.2.840.113549.1.9.16.5.2': ('id-smime-spq-ets-sqt-unotice', ), + 
'1.2.840.113549.1.9.16.6': ('id-smime-cti', ), + '1.2.840.113549.1.9.16.6.1': ('id-smime-cti-ets-proofOfOrigin', ), + '1.2.840.113549.1.9.16.6.2': ('id-smime-cti-ets-proofOfReceipt', ), + '1.2.840.113549.1.9.16.6.3': ('id-smime-cti-ets-proofOfDelivery', ), + '1.2.840.113549.1.9.16.6.4': ('id-smime-cti-ets-proofOfSender', ), + '1.2.840.113549.1.9.16.6.5': ('id-smime-cti-ets-proofOfApproval', ), + '1.2.840.113549.1.9.16.6.6': ('id-smime-cti-ets-proofOfCreation', ), + '1.2.840.113549.1.9.20': ('friendlyName', ), + '1.2.840.113549.1.9.21': ('localKeyID', ), + '1.2.840.113549.1.9.22': ('certTypes', ), + '1.2.840.113549.1.9.22.1': ('x509Certificate', ), + '1.2.840.113549.1.9.22.2': ('sdsiCertificate', ), + '1.2.840.113549.1.9.23': ('crlTypes', ), + '1.2.840.113549.1.9.23.1': ('x509Crl', ), + '1.2.840.113549.1.12': ('pkcs12', ), + '1.2.840.113549.1.12.1': ('pkcs12-pbeids', ), + '1.2.840.113549.1.12.1.1': ('pbeWithSHA1And128BitRC4', 'PBE-SHA1-RC4-128'), + '1.2.840.113549.1.12.1.2': ('pbeWithSHA1And40BitRC4', 'PBE-SHA1-RC4-40'), + '1.2.840.113549.1.12.1.3': ('pbeWithSHA1And3-KeyTripleDES-CBC', 'PBE-SHA1-3DES'), + '1.2.840.113549.1.12.1.4': ('pbeWithSHA1And2-KeyTripleDES-CBC', 'PBE-SHA1-2DES'), + '1.2.840.113549.1.12.1.5': ('pbeWithSHA1And128BitRC2-CBC', 'PBE-SHA1-RC2-128'), + '1.2.840.113549.1.12.1.6': ('pbeWithSHA1And40BitRC2-CBC', 'PBE-SHA1-RC2-40'), + '1.2.840.113549.1.12.10': ('pkcs12-Version1', ), + '1.2.840.113549.1.12.10.1': ('pkcs12-BagIds', ), + '1.2.840.113549.1.12.10.1.1': ('keyBag', ), + '1.2.840.113549.1.12.10.1.2': ('pkcs8ShroudedKeyBag', ), + '1.2.840.113549.1.12.10.1.3': ('certBag', ), + '1.2.840.113549.1.12.10.1.4': ('crlBag', ), + '1.2.840.113549.1.12.10.1.5': ('secretBag', ), + '1.2.840.113549.1.12.10.1.6': ('safeContentsBag', ), + '1.2.840.113549.2.2': ('md2', 'MD2'), + '1.2.840.113549.2.4': ('md4', 'MD4'), + '1.2.840.113549.2.5': ('md5', 'MD5'), + '1.2.840.113549.2.6': ('hmacWithMD5', ), + '1.2.840.113549.2.7': ('hmacWithSHA1', ), + '1.2.840.113549.2.8': ('hmacWithSHA224', ), + '1.2.840.113549.2.9': ('hmacWithSHA256', ), + '1.2.840.113549.2.10': ('hmacWithSHA384', ), + '1.2.840.113549.2.11': ('hmacWithSHA512', ), + '1.2.840.113549.2.12': ('hmacWithSHA512-224', ), + '1.2.840.113549.2.13': ('hmacWithSHA512-256', ), + '1.2.840.113549.3.2': ('rc2-cbc', 'RC2-CBC'), + '1.2.840.113549.3.4': ('rc4', 'RC4'), + '1.2.840.113549.3.7': ('des-ede3-cbc', 'DES-EDE3-CBC'), + '1.2.840.113549.3.8': ('rc5-cbc', 'RC5-CBC'), + '1.2.840.113549.3.10': ('des-cdmf', 'DES-CDMF'), + '1.3': ('identified-organization', 'org', 'ORG'), + '1.3.6': ('dod', 'DOD'), + '1.3.6.1': ('iana', 'IANA', 'internet'), + '1.3.6.1.1': ('Directory', 'directory'), + '1.3.6.1.2': ('Management', 'mgmt'), + '1.3.6.1.3': ('Experimental', 'experimental'), + '1.3.6.1.4': ('Private', 'private'), + '1.3.6.1.4.1': ('Enterprises', 'enterprises'), + '1.3.6.1.4.1.188.7.1.1.2': ('idea-cbc', 'IDEA-CBC'), + '1.3.6.1.4.1.311.2.1.14': ('Microsoft Extension Request', 'msExtReq'), + '1.3.6.1.4.1.311.2.1.21': ('Microsoft Individual Code Signing', 'msCodeInd'), + '1.3.6.1.4.1.311.2.1.22': ('Microsoft Commercial Code Signing', 'msCodeCom'), + '1.3.6.1.4.1.311.10.3.1': ('Microsoft Trust List Signing', 'msCTLSign'), + '1.3.6.1.4.1.311.10.3.3': ('Microsoft Server Gated Crypto', 'msSGC'), + '1.3.6.1.4.1.311.10.3.4': ('Microsoft Encrypted File System', 'msEFS'), + '1.3.6.1.4.1.311.17.1': ('Microsoft CSP Name', 'CSPName'), + '1.3.6.1.4.1.311.17.2': ('Microsoft Local Key set', 'LocalKeySet'), + '1.3.6.1.4.1.311.20.2.2': ('Microsoft Smartcardlogin', 
'msSmartcardLogin'), + '1.3.6.1.4.1.311.20.2.3': ('Microsoft Universal Principal Name', 'msUPN'), + '1.3.6.1.4.1.311.60.2.1.1': ('jurisdictionLocalityName', 'jurisdictionL'), + '1.3.6.1.4.1.311.60.2.1.2': ('jurisdictionStateOrProvinceName', 'jurisdictionST'), + '1.3.6.1.4.1.311.60.2.1.3': ('jurisdictionCountryName', 'jurisdictionC'), + '1.3.6.1.4.1.1466.344': ('dcObject', 'dcobject'), + '1.3.6.1.4.1.1722.12.2.1.16': ('blake2b512', 'BLAKE2b512'), + '1.3.6.1.4.1.1722.12.2.2.8': ('blake2s256', 'BLAKE2s256'), + '1.3.6.1.4.1.3029.1.2': ('bf-cbc', 'BF-CBC'), + '1.3.6.1.4.1.11129.2.4.2': ('CT Precertificate SCTs', 'ct_precert_scts'), + '1.3.6.1.4.1.11129.2.4.3': ('CT Precertificate Poison', 'ct_precert_poison'), + '1.3.6.1.4.1.11129.2.4.4': ('CT Precertificate Signer', 'ct_precert_signer'), + '1.3.6.1.4.1.11129.2.4.5': ('CT Certificate SCTs', 'ct_cert_scts'), + '1.3.6.1.4.1.11591.4.11': ('scrypt', 'id-scrypt'), + '1.3.6.1.5': ('Security', 'security'), + '1.3.6.1.5.2.3': ('id-pkinit', ), + '1.3.6.1.5.2.3.4': ('PKINIT Client Auth', 'pkInitClientAuth'), + '1.3.6.1.5.2.3.5': ('Signing KDC Response', 'pkInitKDC'), + '1.3.6.1.5.5.7': ('PKIX', ), + '1.3.6.1.5.5.7.0': ('id-pkix-mod', ), + '1.3.6.1.5.5.7.0.1': ('id-pkix1-explicit-88', ), + '1.3.6.1.5.5.7.0.2': ('id-pkix1-implicit-88', ), + '1.3.6.1.5.5.7.0.3': ('id-pkix1-explicit-93', ), + '1.3.6.1.5.5.7.0.4': ('id-pkix1-implicit-93', ), + '1.3.6.1.5.5.7.0.5': ('id-mod-crmf', ), + '1.3.6.1.5.5.7.0.6': ('id-mod-cmc', ), + '1.3.6.1.5.5.7.0.7': ('id-mod-kea-profile-88', ), + '1.3.6.1.5.5.7.0.8': ('id-mod-kea-profile-93', ), + '1.3.6.1.5.5.7.0.9': ('id-mod-cmp', ), + '1.3.6.1.5.5.7.0.10': ('id-mod-qualified-cert-88', ), + '1.3.6.1.5.5.7.0.11': ('id-mod-qualified-cert-93', ), + '1.3.6.1.5.5.7.0.12': ('id-mod-attribute-cert', ), + '1.3.6.1.5.5.7.0.13': ('id-mod-timestamp-protocol', ), + '1.3.6.1.5.5.7.0.14': ('id-mod-ocsp', ), + '1.3.6.1.5.5.7.0.15': ('id-mod-dvcs', ), + '1.3.6.1.5.5.7.0.16': ('id-mod-cmp2000', ), + '1.3.6.1.5.5.7.1': ('id-pe', ), + '1.3.6.1.5.5.7.1.1': ('Authority Information Access', 'authorityInfoAccess'), + '1.3.6.1.5.5.7.1.2': ('Biometric Info', 'biometricInfo'), + '1.3.6.1.5.5.7.1.3': ('qcStatements', ), + '1.3.6.1.5.5.7.1.4': ('ac-auditEntity', ), + '1.3.6.1.5.5.7.1.5': ('ac-targeting', ), + '1.3.6.1.5.5.7.1.6': ('aaControls', ), + '1.3.6.1.5.5.7.1.7': ('sbgp-ipAddrBlock', ), + '1.3.6.1.5.5.7.1.8': ('sbgp-autonomousSysNum', ), + '1.3.6.1.5.5.7.1.9': ('sbgp-routerIdentifier', ), + '1.3.6.1.5.5.7.1.10': ('ac-proxying', ), + '1.3.6.1.5.5.7.1.11': ('Subject Information Access', 'subjectInfoAccess'), + '1.3.6.1.5.5.7.1.14': ('Proxy Certificate Information', 'proxyCertInfo'), + '1.3.6.1.5.5.7.1.24': ('TLS Feature', 'tlsfeature'), + '1.3.6.1.5.5.7.2': ('id-qt', ), + '1.3.6.1.5.5.7.2.1': ('Policy Qualifier CPS', 'id-qt-cps'), + '1.3.6.1.5.5.7.2.2': ('Policy Qualifier User Notice', 'id-qt-unotice'), + '1.3.6.1.5.5.7.2.3': ('textNotice', ), + '1.3.6.1.5.5.7.3': ('id-kp', ), + '1.3.6.1.5.5.7.3.1': ('TLS Web Server Authentication', 'serverAuth'), + '1.3.6.1.5.5.7.3.2': ('TLS Web Client Authentication', 'clientAuth'), + '1.3.6.1.5.5.7.3.3': ('Code Signing', 'codeSigning'), + '1.3.6.1.5.5.7.3.4': ('E-mail Protection', 'emailProtection'), + '1.3.6.1.5.5.7.3.5': ('IPSec End System', 'ipsecEndSystem'), + '1.3.6.1.5.5.7.3.6': ('IPSec Tunnel', 'ipsecTunnel'), + '1.3.6.1.5.5.7.3.7': ('IPSec User', 'ipsecUser'), + '1.3.6.1.5.5.7.3.8': ('Time Stamping', 'timeStamping'), + '1.3.6.1.5.5.7.3.9': ('OCSP Signing', 'OCSPSigning'), + '1.3.6.1.5.5.7.3.10': ('dvcs', 
'DVCS'), + '1.3.6.1.5.5.7.3.17': ('ipsec Internet Key Exchange', 'ipsecIKE'), + '1.3.6.1.5.5.7.3.18': ('Ctrl/provision WAP Access', 'capwapAC'), + '1.3.6.1.5.5.7.3.19': ('Ctrl/Provision WAP Termination', 'capwapWTP'), + '1.3.6.1.5.5.7.3.21': ('SSH Client', 'secureShellClient'), + '1.3.6.1.5.5.7.3.22': ('SSH Server', 'secureShellServer'), + '1.3.6.1.5.5.7.3.23': ('Send Router', 'sendRouter'), + '1.3.6.1.5.5.7.3.24': ('Send Proxied Router', 'sendProxiedRouter'), + '1.3.6.1.5.5.7.3.25': ('Send Owner', 'sendOwner'), + '1.3.6.1.5.5.7.3.26': ('Send Proxied Owner', 'sendProxiedOwner'), + '1.3.6.1.5.5.7.3.27': ('CMC Certificate Authority', 'cmcCA'), + '1.3.6.1.5.5.7.3.28': ('CMC Registration Authority', 'cmcRA'), + '1.3.6.1.5.5.7.4': ('id-it', ), + '1.3.6.1.5.5.7.4.1': ('id-it-caProtEncCert', ), + '1.3.6.1.5.5.7.4.2': ('id-it-signKeyPairTypes', ), + '1.3.6.1.5.5.7.4.3': ('id-it-encKeyPairTypes', ), + '1.3.6.1.5.5.7.4.4': ('id-it-preferredSymmAlg', ), + '1.3.6.1.5.5.7.4.5': ('id-it-caKeyUpdateInfo', ), + '1.3.6.1.5.5.7.4.6': ('id-it-currentCRL', ), + '1.3.6.1.5.5.7.4.7': ('id-it-unsupportedOIDs', ), + '1.3.6.1.5.5.7.4.8': ('id-it-subscriptionRequest', ), + '1.3.6.1.5.5.7.4.9': ('id-it-subscriptionResponse', ), + '1.3.6.1.5.5.7.4.10': ('id-it-keyPairParamReq', ), + '1.3.6.1.5.5.7.4.11': ('id-it-keyPairParamRep', ), + '1.3.6.1.5.5.7.4.12': ('id-it-revPassphrase', ), + '1.3.6.1.5.5.7.4.13': ('id-it-implicitConfirm', ), + '1.3.6.1.5.5.7.4.14': ('id-it-confirmWaitTime', ), + '1.3.6.1.5.5.7.4.15': ('id-it-origPKIMessage', ), + '1.3.6.1.5.5.7.4.16': ('id-it-suppLangTags', ), + '1.3.6.1.5.5.7.5': ('id-pkip', ), + '1.3.6.1.5.5.7.5.1': ('id-regCtrl', ), + '1.3.6.1.5.5.7.5.1.1': ('id-regCtrl-regToken', ), + '1.3.6.1.5.5.7.5.1.2': ('id-regCtrl-authenticator', ), + '1.3.6.1.5.5.7.5.1.3': ('id-regCtrl-pkiPublicationInfo', ), + '1.3.6.1.5.5.7.5.1.4': ('id-regCtrl-pkiArchiveOptions', ), + '1.3.6.1.5.5.7.5.1.5': ('id-regCtrl-oldCertID', ), + '1.3.6.1.5.5.7.5.1.6': ('id-regCtrl-protocolEncrKey', ), + '1.3.6.1.5.5.7.5.2': ('id-regInfo', ), + '1.3.6.1.5.5.7.5.2.1': ('id-regInfo-utf8Pairs', ), + '1.3.6.1.5.5.7.5.2.2': ('id-regInfo-certReq', ), + '1.3.6.1.5.5.7.6': ('id-alg', ), + '1.3.6.1.5.5.7.6.1': ('id-alg-des40', ), + '1.3.6.1.5.5.7.6.2': ('id-alg-noSignature', ), + '1.3.6.1.5.5.7.6.3': ('id-alg-dh-sig-hmac-sha1', ), + '1.3.6.1.5.5.7.6.4': ('id-alg-dh-pop', ), + '1.3.6.1.5.5.7.7': ('id-cmc', ), + '1.3.6.1.5.5.7.7.1': ('id-cmc-statusInfo', ), + '1.3.6.1.5.5.7.7.2': ('id-cmc-identification', ), + '1.3.6.1.5.5.7.7.3': ('id-cmc-identityProof', ), + '1.3.6.1.5.5.7.7.4': ('id-cmc-dataReturn', ), + '1.3.6.1.5.5.7.7.5': ('id-cmc-transactionId', ), + '1.3.6.1.5.5.7.7.6': ('id-cmc-senderNonce', ), + '1.3.6.1.5.5.7.7.7': ('id-cmc-recipientNonce', ), + '1.3.6.1.5.5.7.7.8': ('id-cmc-addExtensions', ), + '1.3.6.1.5.5.7.7.9': ('id-cmc-encryptedPOP', ), + '1.3.6.1.5.5.7.7.10': ('id-cmc-decryptedPOP', ), + '1.3.6.1.5.5.7.7.11': ('id-cmc-lraPOPWitness', ), + '1.3.6.1.5.5.7.7.15': ('id-cmc-getCert', ), + '1.3.6.1.5.5.7.7.16': ('id-cmc-getCRL', ), + '1.3.6.1.5.5.7.7.17': ('id-cmc-revokeRequest', ), + '1.3.6.1.5.5.7.7.18': ('id-cmc-regInfo', ), + '1.3.6.1.5.5.7.7.19': ('id-cmc-responseInfo', ), + '1.3.6.1.5.5.7.7.21': ('id-cmc-queryPending', ), + '1.3.6.1.5.5.7.7.22': ('id-cmc-popLinkRandom', ), + '1.3.6.1.5.5.7.7.23': ('id-cmc-popLinkWitness', ), + '1.3.6.1.5.5.7.7.24': ('id-cmc-confirmCertAcceptance', ), + '1.3.6.1.5.5.7.8': ('id-on', ), + '1.3.6.1.5.5.7.8.1': ('id-on-personalData', ), + '1.3.6.1.5.5.7.8.3': ('Permanent Identifier', 
'id-on-permanentIdentifier'), + '1.3.6.1.5.5.7.9': ('id-pda', ), + '1.3.6.1.5.5.7.9.1': ('id-pda-dateOfBirth', ), + '1.3.6.1.5.5.7.9.2': ('id-pda-placeOfBirth', ), + '1.3.6.1.5.5.7.9.3': ('id-pda-gender', ), + '1.3.6.1.5.5.7.9.4': ('id-pda-countryOfCitizenship', ), + '1.3.6.1.5.5.7.9.5': ('id-pda-countryOfResidence', ), + '1.3.6.1.5.5.7.10': ('id-aca', ), + '1.3.6.1.5.5.7.10.1': ('id-aca-authenticationInfo', ), + '1.3.6.1.5.5.7.10.2': ('id-aca-accessIdentity', ), + '1.3.6.1.5.5.7.10.3': ('id-aca-chargingIdentity', ), + '1.3.6.1.5.5.7.10.4': ('id-aca-group', ), + '1.3.6.1.5.5.7.10.5': ('id-aca-role', ), + '1.3.6.1.5.5.7.10.6': ('id-aca-encAttrs', ), + '1.3.6.1.5.5.7.11': ('id-qcs', ), + '1.3.6.1.5.5.7.11.1': ('id-qcs-pkixQCSyntax-v1', ), + '1.3.6.1.5.5.7.12': ('id-cct', ), + '1.3.6.1.5.5.7.12.1': ('id-cct-crs', ), + '1.3.6.1.5.5.7.12.2': ('id-cct-PKIData', ), + '1.3.6.1.5.5.7.12.3': ('id-cct-PKIResponse', ), + '1.3.6.1.5.5.7.21': ('id-ppl', ), + '1.3.6.1.5.5.7.21.0': ('Any language', 'id-ppl-anyLanguage'), + '1.3.6.1.5.5.7.21.1': ('Inherit all', 'id-ppl-inheritAll'), + '1.3.6.1.5.5.7.21.2': ('Independent', 'id-ppl-independent'), + '1.3.6.1.5.5.7.48': ('id-ad', ), + '1.3.6.1.5.5.7.48.1': ('OCSP', 'OCSP', 'id-pkix-OCSP'), + '1.3.6.1.5.5.7.48.1.1': ('Basic OCSP Response', 'basicOCSPResponse'), + '1.3.6.1.5.5.7.48.1.2': ('OCSP Nonce', 'Nonce'), + '1.3.6.1.5.5.7.48.1.3': ('OCSP CRL ID', 'CrlID'), + '1.3.6.1.5.5.7.48.1.4': ('Acceptable OCSP Responses', 'acceptableResponses'), + '1.3.6.1.5.5.7.48.1.5': ('OCSP No Check', 'noCheck'), + '1.3.6.1.5.5.7.48.1.6': ('OCSP Archive Cutoff', 'archiveCutoff'), + '1.3.6.1.5.5.7.48.1.7': ('OCSP Service Locator', 'serviceLocator'), + '1.3.6.1.5.5.7.48.1.8': ('Extended OCSP Status', 'extendedStatus'), + '1.3.6.1.5.5.7.48.1.9': ('valid', ), + '1.3.6.1.5.5.7.48.1.10': ('path', ), + '1.3.6.1.5.5.7.48.1.11': ('Trust Root', 'trustRoot'), + '1.3.6.1.5.5.7.48.2': ('CA Issuers', 'caIssuers'), + '1.3.6.1.5.5.7.48.3': ('AD Time Stamping', 'ad_timestamping'), + '1.3.6.1.5.5.7.48.4': ('ad dvcs', 'AD_DVCS'), + '1.3.6.1.5.5.7.48.5': ('CA Repository', 'caRepository'), + '1.3.6.1.5.5.8.1.1': ('hmac-md5', 'HMAC-MD5'), + '1.3.6.1.5.5.8.1.2': ('hmac-sha1', 'HMAC-SHA1'), + '1.3.6.1.6': ('SNMPv2', 'snmpv2'), + '1.3.6.1.7': ('Mail', ), + '1.3.6.1.7.1': ('MIME MHS', 'mime-mhs'), + '1.3.6.1.7.1.1': ('mime-mhs-headings', 'mime-mhs-headings'), + '1.3.6.1.7.1.1.1': ('id-hex-partial-message', 'id-hex-partial-message'), + '1.3.6.1.7.1.1.2': ('id-hex-multipart-message', 'id-hex-multipart-message'), + '1.3.6.1.7.1.2': ('mime-mhs-bodies', 'mime-mhs-bodies'), + '1.3.14.3.2': ('algorithm', 'algorithm'), + '1.3.14.3.2.3': ('md5WithRSA', 'RSA-NP-MD5'), + '1.3.14.3.2.6': ('des-ecb', 'DES-ECB'), + '1.3.14.3.2.7': ('des-cbc', 'DES-CBC'), + '1.3.14.3.2.8': ('des-ofb', 'DES-OFB'), + '1.3.14.3.2.9': ('des-cfb', 'DES-CFB'), + '1.3.14.3.2.11': ('rsaSignature', ), + '1.3.14.3.2.12': ('dsaEncryption-old', 'DSA-old'), + '1.3.14.3.2.13': ('dsaWithSHA', 'DSA-SHA'), + '1.3.14.3.2.15': ('shaWithRSAEncryption', 'RSA-SHA'), + '1.3.14.3.2.17': ('des-ede', 'DES-EDE'), + '1.3.14.3.2.18': ('sha', 'SHA'), + '1.3.14.3.2.26': ('sha1', 'SHA1'), + '1.3.14.3.2.27': ('dsaWithSHA1-old', 'DSA-SHA1-old'), + '1.3.14.3.2.29': ('sha1WithRSA', 'RSA-SHA1-2'), + '1.3.36.3.2.1': ('ripemd160', 'RIPEMD160'), + '1.3.36.3.3.1.2': ('ripemd160WithRSA', 'RSA-RIPEMD160'), + '1.3.36.3.3.2.8.1.1.1': ('brainpoolP160r1', ), + '1.3.36.3.3.2.8.1.1.2': ('brainpoolP160t1', ), + '1.3.36.3.3.2.8.1.1.3': ('brainpoolP192r1', ), + '1.3.36.3.3.2.8.1.1.4': 
('brainpoolP192t1', ), + '1.3.36.3.3.2.8.1.1.5': ('brainpoolP224r1', ), + '1.3.36.3.3.2.8.1.1.6': ('brainpoolP224t1', ), + '1.3.36.3.3.2.8.1.1.7': ('brainpoolP256r1', ), + '1.3.36.3.3.2.8.1.1.8': ('brainpoolP256t1', ), + '1.3.36.3.3.2.8.1.1.9': ('brainpoolP320r1', ), + '1.3.36.3.3.2.8.1.1.10': ('brainpoolP320t1', ), + '1.3.36.3.3.2.8.1.1.11': ('brainpoolP384r1', ), + '1.3.36.3.3.2.8.1.1.12': ('brainpoolP384t1', ), + '1.3.36.3.3.2.8.1.1.13': ('brainpoolP512r1', ), + '1.3.36.3.3.2.8.1.1.14': ('brainpoolP512t1', ), + '1.3.36.8.3.3': ('Professional Information or basis for Admission', 'x509ExtAdmission'), + '1.3.101.1.4.1': ('Strong Extranet ID', 'SXNetID'), + '1.3.101.110': ('X25519', ), + '1.3.101.111': ('X448', ), + '1.3.101.112': ('ED25519', ), + '1.3.101.113': ('ED448', ), + '1.3.111': ('ieee', ), + '1.3.111.2.1619': ('IEEE Security in Storage Working Group', 'ieee-siswg'), + '1.3.111.2.1619.0.1.1': ('aes-128-xts', 'AES-128-XTS'), + '1.3.111.2.1619.0.1.2': ('aes-256-xts', 'AES-256-XTS'), + '1.3.132': ('certicom-arc', ), + '1.3.132.0': ('secg_ellipticCurve', ), + '1.3.132.0.1': ('sect163k1', ), + '1.3.132.0.2': ('sect163r1', ), + '1.3.132.0.3': ('sect239k1', ), + '1.3.132.0.4': ('sect113r1', ), + '1.3.132.0.5': ('sect113r2', ), + '1.3.132.0.6': ('secp112r1', ), + '1.3.132.0.7': ('secp112r2', ), + '1.3.132.0.8': ('secp160r1', ), + '1.3.132.0.9': ('secp160k1', ), + '1.3.132.0.10': ('secp256k1', ), + '1.3.132.0.15': ('sect163r2', ), + '1.3.132.0.16': ('sect283k1', ), + '1.3.132.0.17': ('sect283r1', ), + '1.3.132.0.22': ('sect131r1', ), + '1.3.132.0.23': ('sect131r2', ), + '1.3.132.0.24': ('sect193r1', ), + '1.3.132.0.25': ('sect193r2', ), + '1.3.132.0.26': ('sect233k1', ), + '1.3.132.0.27': ('sect233r1', ), + '1.3.132.0.28': ('secp128r1', ), + '1.3.132.0.29': ('secp128r2', ), + '1.3.132.0.30': ('secp160r2', ), + '1.3.132.0.31': ('secp192k1', ), + '1.3.132.0.32': ('secp224k1', ), + '1.3.132.0.33': ('secp224r1', ), + '1.3.132.0.34': ('secp384r1', ), + '1.3.132.0.35': ('secp521r1', ), + '1.3.132.0.36': ('sect409k1', ), + '1.3.132.0.37': ('sect409r1', ), + '1.3.132.0.38': ('sect571k1', ), + '1.3.132.0.39': ('sect571r1', ), + '1.3.132.1': ('secg-scheme', ), + '1.3.132.1.11.0': ('dhSinglePass-stdDH-sha224kdf-scheme', ), + '1.3.132.1.11.1': ('dhSinglePass-stdDH-sha256kdf-scheme', ), + '1.3.132.1.11.2': ('dhSinglePass-stdDH-sha384kdf-scheme', ), + '1.3.132.1.11.3': ('dhSinglePass-stdDH-sha512kdf-scheme', ), + '1.3.132.1.14.0': ('dhSinglePass-cofactorDH-sha224kdf-scheme', ), + '1.3.132.1.14.1': ('dhSinglePass-cofactorDH-sha256kdf-scheme', ), + '1.3.132.1.14.2': ('dhSinglePass-cofactorDH-sha384kdf-scheme', ), + '1.3.132.1.14.3': ('dhSinglePass-cofactorDH-sha512kdf-scheme', ), + '1.3.133.16.840.63.0': ('x9-63-scheme', ), + '1.3.133.16.840.63.0.2': ('dhSinglePass-stdDH-sha1kdf-scheme', ), + '1.3.133.16.840.63.0.3': ('dhSinglePass-cofactorDH-sha1kdf-scheme', ), + '2': ('joint-iso-itu-t', 'JOINT-ISO-ITU-T', 'joint-iso-ccitt'), + '2.5': ('directory services (X.500)', 'X500'), + '2.5.1.5': ('Selected Attribute Types', 'selected-attribute-types'), + '2.5.1.5.55': ('clearance', ), + '2.5.4': ('X509', ), + '2.5.4.3': ('commonName', 'CN'), + '2.5.4.4': ('surname', 'SN'), + '2.5.4.5': ('serialNumber', ), + '2.5.4.6': ('countryName', 'C'), + '2.5.4.7': ('localityName', 'L'), + '2.5.4.8': ('stateOrProvinceName', 'ST'), + '2.5.4.9': ('streetAddress', 'street'), + '2.5.4.10': ('organizationName', 'O'), + '2.5.4.11': ('organizationalUnitName', 'OU'), + '2.5.4.12': ('title', 'title'), + '2.5.4.13': ('description', ), + 
'2.5.4.14': ('searchGuide', ), + '2.5.4.15': ('businessCategory', ), + '2.5.4.16': ('postalAddress', ), + '2.5.4.17': ('postalCode', ), + '2.5.4.18': ('postOfficeBox', ), + '2.5.4.19': ('physicalDeliveryOfficeName', ), + '2.5.4.20': ('telephoneNumber', ), + '2.5.4.21': ('telexNumber', ), + '2.5.4.22': ('teletexTerminalIdentifier', ), + '2.5.4.23': ('facsimileTelephoneNumber', ), + '2.5.4.24': ('x121Address', ), + '2.5.4.25': ('internationaliSDNNumber', ), + '2.5.4.26': ('registeredAddress', ), + '2.5.4.27': ('destinationIndicator', ), + '2.5.4.28': ('preferredDeliveryMethod', ), + '2.5.4.29': ('presentationAddress', ), + '2.5.4.30': ('supportedApplicationContext', ), + '2.5.4.31': ('member', ), + '2.5.4.32': ('owner', ), + '2.5.4.33': ('roleOccupant', ), + '2.5.4.34': ('seeAlso', ), + '2.5.4.35': ('userPassword', ), + '2.5.4.36': ('userCertificate', ), + '2.5.4.37': ('cACertificate', ), + '2.5.4.38': ('authorityRevocationList', ), + '2.5.4.39': ('certificateRevocationList', ), + '2.5.4.40': ('crossCertificatePair', ), + '2.5.4.41': ('name', 'name'), + '2.5.4.42': ('givenName', 'GN'), + '2.5.4.43': ('initials', 'initials'), + '2.5.4.44': ('generationQualifier', ), + '2.5.4.45': ('x500UniqueIdentifier', ), + '2.5.4.46': ('dnQualifier', 'dnQualifier'), + '2.5.4.47': ('enhancedSearchGuide', ), + '2.5.4.48': ('protocolInformation', ), + '2.5.4.49': ('distinguishedName', ), + '2.5.4.50': ('uniqueMember', ), + '2.5.4.51': ('houseIdentifier', ), + '2.5.4.52': ('supportedAlgorithms', ), + '2.5.4.53': ('deltaRevocationList', ), + '2.5.4.54': ('dmdName', ), + '2.5.4.65': ('pseudonym', ), + '2.5.4.72': ('role', 'role'), + '2.5.4.97': ('organizationIdentifier', ), + '2.5.4.98': ('countryCode3c', 'c3'), + '2.5.4.99': ('countryCode3n', 'n3'), + '2.5.4.100': ('dnsName', ), + '2.5.8': ('directory services - algorithms', 'X500algorithms'), + '2.5.8.1.1': ('rsa', 'RSA'), + '2.5.8.3.100': ('mdc2WithRSA', 'RSA-MDC2'), + '2.5.8.3.101': ('mdc2', 'MDC2'), + '2.5.29': ('id-ce', ), + '2.5.29.9': ('X509v3 Subject Directory Attributes', 'subjectDirectoryAttributes'), + '2.5.29.14': ('X509v3 Subject Key Identifier', 'subjectKeyIdentifier'), + '2.5.29.15': ('X509v3 Key Usage', 'keyUsage'), + '2.5.29.16': ('X509v3 Private Key Usage Period', 'privateKeyUsagePeriod'), + '2.5.29.17': ('X509v3 Subject Alternative Name', 'subjectAltName'), + '2.5.29.18': ('X509v3 Issuer Alternative Name', 'issuerAltName'), + '2.5.29.19': ('X509v3 Basic Constraints', 'basicConstraints'), + '2.5.29.20': ('X509v3 CRL Number', 'crlNumber'), + '2.5.29.21': ('X509v3 CRL Reason Code', 'CRLReason'), + '2.5.29.23': ('Hold Instruction Code', 'holdInstructionCode'), + '2.5.29.24': ('Invalidity Date', 'invalidityDate'), + '2.5.29.27': ('X509v3 Delta CRL Indicator', 'deltaCRL'), + '2.5.29.28': ('X509v3 Issuing Distribution Point', 'issuingDistributionPoint'), + '2.5.29.29': ('X509v3 Certificate Issuer', 'certificateIssuer'), + '2.5.29.30': ('X509v3 Name Constraints', 'nameConstraints'), + '2.5.29.31': ('X509v3 CRL Distribution Points', 'crlDistributionPoints'), + '2.5.29.32': ('X509v3 Certificate Policies', 'certificatePolicies'), + '2.5.29.32.0': ('X509v3 Any Policy', 'anyPolicy'), + '2.5.29.33': ('X509v3 Policy Mappings', 'policyMappings'), + '2.5.29.35': ('X509v3 Authority Key Identifier', 'authorityKeyIdentifier'), + '2.5.29.36': ('X509v3 Policy Constraints', 'policyConstraints'), + '2.5.29.37': ('X509v3 Extended Key Usage', 'extendedKeyUsage'), + '2.5.29.37.0': ('Any Extended Key Usage', 'anyExtendedKeyUsage'), + '2.5.29.46': ('X509v3 Freshest CRL', 
'freshestCRL'), + '2.5.29.54': ('X509v3 Inhibit Any Policy', 'inhibitAnyPolicy'), + '2.5.29.55': ('X509v3 AC Targeting', 'targetInformation'), + '2.5.29.56': ('X509v3 No Revocation Available', 'noRevAvail'), + '2.16.840.1.101.3': ('csor', ), + '2.16.840.1.101.3.4': ('nistAlgorithms', ), + '2.16.840.1.101.3.4.1': ('aes', ), + '2.16.840.1.101.3.4.1.1': ('aes-128-ecb', 'AES-128-ECB'), + '2.16.840.1.101.3.4.1.2': ('aes-128-cbc', 'AES-128-CBC'), + '2.16.840.1.101.3.4.1.3': ('aes-128-ofb', 'AES-128-OFB'), + '2.16.840.1.101.3.4.1.4': ('aes-128-cfb', 'AES-128-CFB'), + '2.16.840.1.101.3.4.1.5': ('id-aes128-wrap', ), + '2.16.840.1.101.3.4.1.6': ('aes-128-gcm', 'id-aes128-GCM'), + '2.16.840.1.101.3.4.1.7': ('aes-128-ccm', 'id-aes128-CCM'), + '2.16.840.1.101.3.4.1.8': ('id-aes128-wrap-pad', ), + '2.16.840.1.101.3.4.1.21': ('aes-192-ecb', 'AES-192-ECB'), + '2.16.840.1.101.3.4.1.22': ('aes-192-cbc', 'AES-192-CBC'), + '2.16.840.1.101.3.4.1.23': ('aes-192-ofb', 'AES-192-OFB'), + '2.16.840.1.101.3.4.1.24': ('aes-192-cfb', 'AES-192-CFB'), + '2.16.840.1.101.3.4.1.25': ('id-aes192-wrap', ), + '2.16.840.1.101.3.4.1.26': ('aes-192-gcm', 'id-aes192-GCM'), + '2.16.840.1.101.3.4.1.27': ('aes-192-ccm', 'id-aes192-CCM'), + '2.16.840.1.101.3.4.1.28': ('id-aes192-wrap-pad', ), + '2.16.840.1.101.3.4.1.41': ('aes-256-ecb', 'AES-256-ECB'), + '2.16.840.1.101.3.4.1.42': ('aes-256-cbc', 'AES-256-CBC'), + '2.16.840.1.101.3.4.1.43': ('aes-256-ofb', 'AES-256-OFB'), + '2.16.840.1.101.3.4.1.44': ('aes-256-cfb', 'AES-256-CFB'), + '2.16.840.1.101.3.4.1.45': ('id-aes256-wrap', ), + '2.16.840.1.101.3.4.1.46': ('aes-256-gcm', 'id-aes256-GCM'), + '2.16.840.1.101.3.4.1.47': ('aes-256-ccm', 'id-aes256-CCM'), + '2.16.840.1.101.3.4.1.48': ('id-aes256-wrap-pad', ), + '2.16.840.1.101.3.4.2': ('nist_hashalgs', ), + '2.16.840.1.101.3.4.2.1': ('sha256', 'SHA256'), + '2.16.840.1.101.3.4.2.2': ('sha384', 'SHA384'), + '2.16.840.1.101.3.4.2.3': ('sha512', 'SHA512'), + '2.16.840.1.101.3.4.2.4': ('sha224', 'SHA224'), + '2.16.840.1.101.3.4.2.5': ('sha512-224', 'SHA512-224'), + '2.16.840.1.101.3.4.2.6': ('sha512-256', 'SHA512-256'), + '2.16.840.1.101.3.4.2.7': ('sha3-224', 'SHA3-224'), + '2.16.840.1.101.3.4.2.8': ('sha3-256', 'SHA3-256'), + '2.16.840.1.101.3.4.2.9': ('sha3-384', 'SHA3-384'), + '2.16.840.1.101.3.4.2.10': ('sha3-512', 'SHA3-512'), + '2.16.840.1.101.3.4.2.11': ('shake128', 'SHAKE128'), + '2.16.840.1.101.3.4.2.12': ('shake256', 'SHAKE256'), + '2.16.840.1.101.3.4.2.13': ('hmac-sha3-224', 'id-hmacWithSHA3-224'), + '2.16.840.1.101.3.4.2.14': ('hmac-sha3-256', 'id-hmacWithSHA3-256'), + '2.16.840.1.101.3.4.2.15': ('hmac-sha3-384', 'id-hmacWithSHA3-384'), + '2.16.840.1.101.3.4.2.16': ('hmac-sha3-512', 'id-hmacWithSHA3-512'), + '2.16.840.1.101.3.4.3': ('dsa_with_sha2', 'sigAlgs'), + '2.16.840.1.101.3.4.3.1': ('dsa_with_SHA224', ), + '2.16.840.1.101.3.4.3.2': ('dsa_with_SHA256', ), + '2.16.840.1.101.3.4.3.3': ('dsa_with_SHA384', 'id-dsa-with-sha384'), + '2.16.840.1.101.3.4.3.4': ('dsa_with_SHA512', 'id-dsa-with-sha512'), + '2.16.840.1.101.3.4.3.5': ('dsa_with_SHA3-224', 'id-dsa-with-sha3-224'), + '2.16.840.1.101.3.4.3.6': ('dsa_with_SHA3-256', 'id-dsa-with-sha3-256'), + '2.16.840.1.101.3.4.3.7': ('dsa_with_SHA3-384', 'id-dsa-with-sha3-384'), + '2.16.840.1.101.3.4.3.8': ('dsa_with_SHA3-512', 'id-dsa-with-sha3-512'), + '2.16.840.1.101.3.4.3.9': ('ecdsa_with_SHA3-224', 'id-ecdsa-with-sha3-224'), + '2.16.840.1.101.3.4.3.10': ('ecdsa_with_SHA3-256', 'id-ecdsa-with-sha3-256'), + '2.16.840.1.101.3.4.3.11': ('ecdsa_with_SHA3-384', 
'id-ecdsa-with-sha3-384'), + '2.16.840.1.101.3.4.3.12': ('ecdsa_with_SHA3-512', 'id-ecdsa-with-sha3-512'), + '2.16.840.1.101.3.4.3.13': ('RSA-SHA3-224', 'id-rsassa-pkcs1-v1_5-with-sha3-224'), + '2.16.840.1.101.3.4.3.14': ('RSA-SHA3-256', 'id-rsassa-pkcs1-v1_5-with-sha3-256'), + '2.16.840.1.101.3.4.3.15': ('RSA-SHA3-384', 'id-rsassa-pkcs1-v1_5-with-sha3-384'), + '2.16.840.1.101.3.4.3.16': ('RSA-SHA3-512', 'id-rsassa-pkcs1-v1_5-with-sha3-512'), + '2.16.840.1.113730': ('Netscape Communications Corp.', 'Netscape'), + '2.16.840.1.113730.1': ('Netscape Certificate Extension', 'nsCertExt'), + '2.16.840.1.113730.1.1': ('Netscape Cert Type', 'nsCertType'), + '2.16.840.1.113730.1.2': ('Netscape Base Url', 'nsBaseUrl'), + '2.16.840.1.113730.1.3': ('Netscape Revocation Url', 'nsRevocationUrl'), + '2.16.840.1.113730.1.4': ('Netscape CA Revocation Url', 'nsCaRevocationUrl'), + '2.16.840.1.113730.1.7': ('Netscape Renewal Url', 'nsRenewalUrl'), + '2.16.840.1.113730.1.8': ('Netscape CA Policy Url', 'nsCaPolicyUrl'), + '2.16.840.1.113730.1.12': ('Netscape SSL Server Name', 'nsSslServerName'), + '2.16.840.1.113730.1.13': ('Netscape Comment', 'nsComment'), + '2.16.840.1.113730.2': ('Netscape Data Type', 'nsDataType'), + '2.16.840.1.113730.2.5': ('Netscape Certificate Sequence', 'nsCertSequence'), + '2.16.840.1.113730.4.1': ('Netscape Server Gated Crypto', 'nsSGC'), + '2.23': ('International Organizations', 'international-organizations'), + '2.23.42': ('Secure Electronic Transactions', 'id-set'), + '2.23.42.0': ('content types', 'set-ctype'), + '2.23.42.0.0': ('setct-PANData', ), + '2.23.42.0.1': ('setct-PANToken', ), + '2.23.42.0.2': ('setct-PANOnly', ), + '2.23.42.0.3': ('setct-OIData', ), + '2.23.42.0.4': ('setct-PI', ), + '2.23.42.0.5': ('setct-PIData', ), + '2.23.42.0.6': ('setct-PIDataUnsigned', ), + '2.23.42.0.7': ('setct-HODInput', ), + '2.23.42.0.8': ('setct-AuthResBaggage', ), + '2.23.42.0.9': ('setct-AuthRevReqBaggage', ), + '2.23.42.0.10': ('setct-AuthRevResBaggage', ), + '2.23.42.0.11': ('setct-CapTokenSeq', ), + '2.23.42.0.12': ('setct-PInitResData', ), + '2.23.42.0.13': ('setct-PI-TBS', ), + '2.23.42.0.14': ('setct-PResData', ), + '2.23.42.0.16': ('setct-AuthReqTBS', ), + '2.23.42.0.17': ('setct-AuthResTBS', ), + '2.23.42.0.18': ('setct-AuthResTBSX', ), + '2.23.42.0.19': ('setct-AuthTokenTBS', ), + '2.23.42.0.20': ('setct-CapTokenData', ), + '2.23.42.0.21': ('setct-CapTokenTBS', ), + '2.23.42.0.22': ('setct-AcqCardCodeMsg', ), + '2.23.42.0.23': ('setct-AuthRevReqTBS', ), + '2.23.42.0.24': ('setct-AuthRevResData', ), + '2.23.42.0.25': ('setct-AuthRevResTBS', ), + '2.23.42.0.26': ('setct-CapReqTBS', ), + '2.23.42.0.27': ('setct-CapReqTBSX', ), + '2.23.42.0.28': ('setct-CapResData', ), + '2.23.42.0.29': ('setct-CapRevReqTBS', ), + '2.23.42.0.30': ('setct-CapRevReqTBSX', ), + '2.23.42.0.31': ('setct-CapRevResData', ), + '2.23.42.0.32': ('setct-CredReqTBS', ), + '2.23.42.0.33': ('setct-CredReqTBSX', ), + '2.23.42.0.34': ('setct-CredResData', ), + '2.23.42.0.35': ('setct-CredRevReqTBS', ), + '2.23.42.0.36': ('setct-CredRevReqTBSX', ), + '2.23.42.0.37': ('setct-CredRevResData', ), + '2.23.42.0.38': ('setct-PCertReqData', ), + '2.23.42.0.39': ('setct-PCertResTBS', ), + '2.23.42.0.40': ('setct-BatchAdminReqData', ), + '2.23.42.0.41': ('setct-BatchAdminResData', ), + '2.23.42.0.42': ('setct-CardCInitResTBS', ), + '2.23.42.0.43': ('setct-MeAqCInitResTBS', ), + '2.23.42.0.44': ('setct-RegFormResTBS', ), + '2.23.42.0.45': ('setct-CertReqData', ), + '2.23.42.0.46': ('setct-CertReqTBS', ), + '2.23.42.0.47': 
('setct-CertResData', ), + '2.23.42.0.48': ('setct-CertInqReqTBS', ), + '2.23.42.0.49': ('setct-ErrorTBS', ), + '2.23.42.0.50': ('setct-PIDualSignedTBE', ), + '2.23.42.0.51': ('setct-PIUnsignedTBE', ), + '2.23.42.0.52': ('setct-AuthReqTBE', ), + '2.23.42.0.53': ('setct-AuthResTBE', ), + '2.23.42.0.54': ('setct-AuthResTBEX', ), + '2.23.42.0.55': ('setct-AuthTokenTBE', ), + '2.23.42.0.56': ('setct-CapTokenTBE', ), + '2.23.42.0.57': ('setct-CapTokenTBEX', ), + '2.23.42.0.58': ('setct-AcqCardCodeMsgTBE', ), + '2.23.42.0.59': ('setct-AuthRevReqTBE', ), + '2.23.42.0.60': ('setct-AuthRevResTBE', ), + '2.23.42.0.61': ('setct-AuthRevResTBEB', ), + '2.23.42.0.62': ('setct-CapReqTBE', ), + '2.23.42.0.63': ('setct-CapReqTBEX', ), + '2.23.42.0.64': ('setct-CapResTBE', ), + '2.23.42.0.65': ('setct-CapRevReqTBE', ), + '2.23.42.0.66': ('setct-CapRevReqTBEX', ), + '2.23.42.0.67': ('setct-CapRevResTBE', ), + '2.23.42.0.68': ('setct-CredReqTBE', ), + '2.23.42.0.69': ('setct-CredReqTBEX', ), + '2.23.42.0.70': ('setct-CredResTBE', ), + '2.23.42.0.71': ('setct-CredRevReqTBE', ), + '2.23.42.0.72': ('setct-CredRevReqTBEX', ), + '2.23.42.0.73': ('setct-CredRevResTBE', ), + '2.23.42.0.74': ('setct-BatchAdminReqTBE', ), + '2.23.42.0.75': ('setct-BatchAdminResTBE', ), + '2.23.42.0.76': ('setct-RegFormReqTBE', ), + '2.23.42.0.77': ('setct-CertReqTBE', ), + '2.23.42.0.78': ('setct-CertReqTBEX', ), + '2.23.42.0.79': ('setct-CertResTBE', ), + '2.23.42.0.80': ('setct-CRLNotificationTBS', ), + '2.23.42.0.81': ('setct-CRLNotificationResTBS', ), + '2.23.42.0.82': ('setct-BCIDistributionTBS', ), + '2.23.42.1': ('message extensions', 'set-msgExt'), + '2.23.42.1.1': ('generic cryptogram', 'setext-genCrypt'), + '2.23.42.1.3': ('merchant initiated auth', 'setext-miAuth'), + '2.23.42.1.4': ('setext-pinSecure', ), + '2.23.42.1.5': ('setext-pinAny', ), + '2.23.42.1.7': ('setext-track2', ), + '2.23.42.1.8': ('additional verification', 'setext-cv'), + '2.23.42.3': ('set-attr', ), + '2.23.42.3.0': ('setAttr-Cert', ), + '2.23.42.3.0.0': ('set-rootKeyThumb', ), + '2.23.42.3.0.1': ('set-addPolicy', ), + '2.23.42.3.1': ('payment gateway capabilities', 'setAttr-PGWYcap'), + '2.23.42.3.2': ('setAttr-TokenType', ), + '2.23.42.3.2.1': ('setAttr-Token-EMV', ), + '2.23.42.3.2.2': ('setAttr-Token-B0Prime', ), + '2.23.42.3.3': ('issuer capabilities', 'setAttr-IssCap'), + '2.23.42.3.3.3': ('setAttr-IssCap-CVM', ), + '2.23.42.3.3.3.1': ('generate cryptogram', 'setAttr-GenCryptgrm'), + '2.23.42.3.3.4': ('setAttr-IssCap-T2', ), + '2.23.42.3.3.4.1': ('encrypted track 2', 'setAttr-T2Enc'), + '2.23.42.3.3.4.2': ('cleartext track 2', 'setAttr-T2cleartxt'), + '2.23.42.3.3.5': ('setAttr-IssCap-Sig', ), + '2.23.42.3.3.5.1': ('ICC or token signature', 'setAttr-TokICCsig'), + '2.23.42.3.3.5.2': ('secure device signature', 'setAttr-SecDevSig'), + '2.23.42.5': ('set-policy', ), + '2.23.42.5.0': ('set-policy-root', ), + '2.23.42.7': ('certificate extensions', 'set-certExt'), + '2.23.42.7.0': ('setCext-hashedRoot', ), + '2.23.42.7.1': ('setCext-certType', ), + '2.23.42.7.2': ('setCext-merchData', ), + '2.23.42.7.3': ('setCext-cCertRequired', ), + '2.23.42.7.4': ('setCext-tunneling', ), + '2.23.42.7.5': ('setCext-setExt', ), + '2.23.42.7.6': ('setCext-setQualf', ), + '2.23.42.7.7': ('setCext-PGWYcapabilities', ), + '2.23.42.7.8': ('setCext-TokenIdentifier', ), + '2.23.42.7.9': ('setCext-Track2Data', ), + '2.23.42.7.10': ('setCext-TokenType', ), + '2.23.42.7.11': ('setCext-IssuerCapabilities', ), + '2.23.42.8': ('set-brand', ), + '2.23.42.8.1': 
('set-brand-IATA-ATA', ), + '2.23.42.8.4': ('set-brand-Visa', ), + '2.23.42.8.5': ('set-brand-MasterCard', ), + '2.23.42.8.30': ('set-brand-Diners', ), + '2.23.42.8.34': ('set-brand-AmericanExpress', ), + '2.23.42.8.35': ('set-brand-JCB', ), + '2.23.42.8.6011': ('set-brand-Novus', ), + '2.23.43': ('wap', ), + '2.23.43.1': ('wap-wsg', ), + '2.23.43.1.4': ('wap-wsg-idm-ecid', ), + '2.23.43.1.4.1': ('wap-wsg-idm-ecid-wtls1', ), + '2.23.43.1.4.3': ('wap-wsg-idm-ecid-wtls3', ), + '2.23.43.1.4.4': ('wap-wsg-idm-ecid-wtls4', ), + '2.23.43.1.4.5': ('wap-wsg-idm-ecid-wtls5', ), + '2.23.43.1.4.6': ('wap-wsg-idm-ecid-wtls6', ), + '2.23.43.1.4.7': ('wap-wsg-idm-ecid-wtls7', ), + '2.23.43.1.4.8': ('wap-wsg-idm-ecid-wtls8', ), + '2.23.43.1.4.9': ('wap-wsg-idm-ecid-wtls9', ), + '2.23.43.1.4.10': ('wap-wsg-idm-ecid-wtls10', ), + '2.23.43.1.4.11': ('wap-wsg-idm-ecid-wtls11', ), + '2.23.43.1.4.12': ('wap-wsg-idm-ecid-wtls12', ), +} +# ##################################################################################### +# ##################################################################################### + +_OID_LOOKUP = dict() +_NORMALIZE_NAMES = dict() +_NORMALIZE_NAMES_SHORT = dict() + +for dotted, names in _OID_MAP.items(): + for name in names: + if name in _NORMALIZE_NAMES and _OID_LOOKUP[name] != dotted: + raise AssertionError( + 'Name collision during setup: "{0}" for OIDs {1} and {2}' + .format(name, dotted, _OID_LOOKUP[name]) + ) + _NORMALIZE_NAMES[name] = names[0] + _NORMALIZE_NAMES_SHORT[name] = names[-1] + _OID_LOOKUP[name] = dotted +for alias, original in [('userID', 'userId')]: + if alias in _NORMALIZE_NAMES: + raise AssertionError( + 'Name collision during adding aliases: "{0}" (alias for "{1}") is already mapped to OID {2}' + .format(alias, original, _OID_LOOKUP[alias]) + ) + _NORMALIZE_NAMES[alias] = original + _NORMALIZE_NAMES_SHORT[alias] = _NORMALIZE_NAMES_SHORT[original] + _OID_LOOKUP[alias] = _OID_LOOKUP[original] + + +def pyopenssl_normalize_name(name, short=False): + nid = OpenSSL._util.lib.OBJ_txt2nid(to_bytes(name)) + if nid != 0: + b_name = OpenSSL._util.lib.OBJ_nid2ln(nid) + name = to_text(OpenSSL._util.ffi.string(b_name)) + if short: + return _NORMALIZE_NAMES_SHORT.get(name, name) + else: + return _NORMALIZE_NAMES.get(name, name) + + +# ##################################################################################### +# ##################################################################################### +# # This excerpt is dual licensed under the terms of the Apache License, Version +# # 2.0, and the BSD License. See the LICENSE file at +# # https://github.com/pyca/cryptography/blob/master/LICENSE for complete details. 
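For illustration, the normalization tables built above can be exercised in miniature; the following sketch uses a tiny stand-in map (not the full _OID_MAP) to show how any alias resolves to the dotted OID and to the canonical long/short names:

    _MINI_OID_MAP = {'2.5.4.3': ('commonName', 'CN')}
    _lookup, _long, _short = {}, {}, {}
    for dotted, names in _MINI_OID_MAP.items():
        for name in names:
            _lookup[name] = dotted      # any alias -> dotted OID
            _long[name] = names[0]      # any alias -> canonical long name
            _short[name] = names[-1]    # any alias -> canonical short name
    print(_lookup['CN'], _long['CN'], _short['commonName'])   # 2.5.4.3 commonName CN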
+# # +# # Adapted from cryptography's hazmat/backends/openssl/decode_asn1.py +# # +# # Copyright (c) 2015, 2016 Paul Kehrer (@reaperhulk) +# # Copyright (c) 2017 Fraser Tweedale (@frasertweedale) +# # +# # Relevant commits from cryptography project (https://github.com/pyca/cryptography): +# # pyca/cryptography@719d536dd691e84e208534798f2eb4f82aaa2e07 +# # pyca/cryptography@5ab6d6a5c05572bd1c75f05baf264a2d0001894a +# # pyca/cryptography@2e776e20eb60378e0af9b7439000d0e80da7c7e3 +# # pyca/cryptography@fb309ed24647d1be9e319b61b1f2aa8ebb87b90b +# # pyca/cryptography@2917e460993c475c72d7146c50dc3bbc2414280d +# # pyca/cryptography@3057f91ea9a05fb593825006d87a391286a4d828 +# # pyca/cryptography@d607dd7e5bc5c08854ec0c9baff70ba4a35be36f +def _obj2txt(openssl_lib, openssl_ffi, obj): + # Set to 80 on the recommendation of + # https://www.openssl.org/docs/crypto/OBJ_nid2ln.html#return_values + # + # But OIDs longer than this occur in real life (e.g. Active + # Directory makes some very long OIDs). So we need to detect + # and properly handle the case where the default buffer is not + # big enough. + # + buf_len = 80 + buf = openssl_ffi.new("char[]", buf_len) + + # 'res' is the number of bytes that *would* be written if the + # buffer is large enough. If 'res' > buf_len - 1, we need to + # alloc a big-enough buffer and go again. + res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1) + if res > buf_len - 1: # account for terminating null byte + buf_len = res + 1 + buf = openssl_ffi.new("char[]", buf_len) + res = openssl_lib.OBJ_obj2txt(buf, buf_len, obj, 1) + return openssl_ffi.buffer(buf, res)[:].decode() +# ##################################################################################### +# ##################################################################################### + + +def cryptography_get_extensions_from_cert(cert): + # Since cryptography won't give us the DER value for an extension + # (that is only stored for unrecognized extensions), we have to re-do + # the extension parsing outselves. + result = dict() + backend = cert._backend + x509_obj = cert._x509 + + for i in range(backend._lib.X509_get_ext_count(x509_obj)): + ext = backend._lib.X509_get_ext(x509_obj, i) + if ext == backend._ffi.NULL: + continue + crit = backend._lib.X509_EXTENSION_get_critical(ext) + data = backend._lib.X509_EXTENSION_get_data(ext) + backend.openssl_assert(data != backend._ffi.NULL) + der = backend._ffi.buffer(data.data, data.length)[:] + entry = dict( + critical=(crit == 1), + value=base64.b64encode(der), + ) + oid = _obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext)) + result[oid] = entry + return result + + +def cryptography_get_extensions_from_csr(csr): + # Since cryptography won't give us the DER value for an extension + # (that is only stored for unrecognized extensions), we have to re-do + # the extension parsing outselves. 
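The _obj2txt() helper above follows the usual C-interop pattern of calling once with a fixed buffer and retrying with a larger one when the reported length does not fit. A rough pure-Python illustration of that retry logic, with write_name() as a hypothetical stand-in for OBJ_obj2txt:

    def write_name(buf, buf_len, text):
        # Stand-in for OBJ_obj2txt: copy what fits, return the full length needed.
        data = text.encode('ascii')
        n = min(buf_len, len(data))
        buf[:n] = data[:n]
        return len(data)

    text = '1.3.6.1.4.1.311.21.8.' + '.'.join(['4086'] * 20)   # a long, AD-style OID
    buf_len = 80
    buf = bytearray(buf_len)
    res = write_name(buf, buf_len, text)
    if res > buf_len - 1:               # would not fit; allocate exactly enough and retry
        buf_len = res + 1
        buf = bytearray(buf_len)
        res = write_name(buf, buf_len, text)
    print(bytes(buf[:res]).decode('ascii') == text)   # True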
+ result = dict() + backend = csr._backend + + extensions = backend._lib.X509_REQ_get_extensions(csr._x509_req) + extensions = backend._ffi.gc( + extensions, + lambda ext: backend._lib.sk_X509_EXTENSION_pop_free( + ext, + backend._ffi.addressof(backend._lib._original_lib, "X509_EXTENSION_free") + ) + ) + + for i in range(backend._lib.sk_X509_EXTENSION_num(extensions)): + ext = backend._lib.sk_X509_EXTENSION_value(extensions, i) + if ext == backend._ffi.NULL: + continue + crit = backend._lib.X509_EXTENSION_get_critical(ext) + data = backend._lib.X509_EXTENSION_get_data(ext) + backend.openssl_assert(data != backend._ffi.NULL) + der = backend._ffi.buffer(data.data, data.length)[:] + entry = dict( + critical=(crit == 1), + value=base64.b64encode(der), + ) + oid = _obj2txt(backend._lib, backend._ffi, backend._lib.X509_EXTENSION_get_object(ext)) + result[oid] = entry + return result + + +def pyopenssl_get_extensions_from_cert(cert): + # While pyOpenSSL allows us to get an extension's DER value, it won't + # give us the dotted string for an OID. So we have to do some magic to + # get hold of it. + result = dict() + ext_count = cert.get_extension_count() + for i in range(0, ext_count): + ext = cert.get_extension(i) + entry = dict( + critical=bool(ext.get_critical()), + value=base64.b64encode(ext.get_data()), + ) + oid = _obj2txt( + OpenSSL._util.lib, + OpenSSL._util.ffi, + OpenSSL._util.lib.X509_EXTENSION_get_object(ext._extension) + ) + # This could also be done a bit simpler: + # + # oid = _obj2txt(OpenSSL._util.lib, OpenSSL._util.ffi, OpenSSL._util.lib.OBJ_nid2obj(ext._nid)) + # + # Unfortunately this gives the wrong result in case the linked OpenSSL + # doesn't know the OID. That's why we have to get the OID dotted string + # similarly to how cryptography does it. + result[oid] = entry + return result + + +def pyopenssl_get_extensions_from_csr(csr): + # While pyOpenSSL allows us to get an extension's DER value, it won't + # give us the dotted string for an OID. So we have to do some magic to + # get hold of it. + result = dict() + for ext in csr.get_extensions(): + entry = dict( + critical=bool(ext.get_critical()), + value=base64.b64encode(ext.get_data()), + ) + oid = _obj2txt( + OpenSSL._util.lib, + OpenSSL._util.ffi, + OpenSSL._util.lib.X509_EXTENSION_get_object(ext._extension) + ) + # This could also be done a bit simpler: + # + # oid = _obj2txt(OpenSSL._util.lib, OpenSSL._util.ffi, OpenSSL._util.lib.OBJ_nid2obj(ext._nid)) + # + # Unfortunately this gives the wrong result in case the linked OpenSSL + # doesn't know the OID. That's why we have to get the OID dotted string + # similarly to how cryptography does it. + result[oid] = entry + return result + + +def cryptography_name_to_oid(name): + dotted = _OID_LOOKUP.get(name) + if dotted is None: + raise OpenSSLObjectError('Cannot find OID for "{0}"'.format(name)) + return x509.oid.ObjectIdentifier(dotted) + + +def cryptography_oid_to_name(oid, short=False): + dotted_string = oid.dotted_string + names = _OID_MAP.get(dotted_string) + name = names[0] if names else oid._name + if short: + return _NORMALIZE_NAMES_SHORT.get(name, name) + else: + return _NORMALIZE_NAMES.get(name, name) + + +def cryptography_get_name(name): + ''' + Given a name string, returns a cryptography x509.Name object. + Raises an OpenSSLObjectError if the name is unknown or cannot be parsed. 
+ ''' + try: + if name.startswith('DNS:'): + return x509.DNSName(to_text(name[4:])) + if name.startswith('IP:'): + return x509.IPAddress(ipaddress.ip_address(to_text(name[3:]))) + if name.startswith('email:'): + return x509.RFC822Name(to_text(name[6:])) + if name.startswith('URI:'): + return x509.UniformResourceIdentifier(to_text(name[4:])) + except Exception as e: + raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}": {1}'.format(name, e)) + if ':' not in name: + raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}" (forgot "DNS:" prefix?)'.format(name)) + raise OpenSSLObjectError('Cannot parse Subject Alternative Name "{0}" (potentially unsupported by cryptography backend)'.format(name)) + + +def _get_hex(bytesstr): + if bytesstr is None: + return bytesstr + data = binascii.hexlify(bytesstr) + data = to_text(b':'.join(data[i:i + 2] for i in range(0, len(data), 2))) + return data + + +def cryptography_decode_name(name): + ''' + Given a cryptography x509.Name object, returns a string. + Raises an OpenSSLObjectError if the name is not supported. + ''' + if isinstance(name, x509.DNSName): + return 'DNS:{0}'.format(name.value) + if isinstance(name, x509.IPAddress): + return 'IP:{0}'.format(name.value.compressed) + if isinstance(name, x509.RFC822Name): + return 'email:{0}'.format(name.value) + if isinstance(name, x509.UniformResourceIdentifier): + return 'URI:{0}'.format(name.value) + if isinstance(name, x509.DirectoryName): + # FIXME: test + return 'DirName:' + ''.join(['/{0}:{1}'.format(attribute.oid._name, attribute.value) for attribute in name.value]) + if isinstance(name, x509.RegisteredID): + # FIXME: test + return 'RegisteredID:{0}'.format(name.value) + if isinstance(name, x509.OtherName): + # FIXME: test + return '{0}:{1}'.format(name.type_id.dotted_string, _get_hex(name.value)) + raise OpenSSLObjectError('Cannot decode name "{0}"'.format(name)) + + +def _cryptography_get_keyusage(usage): + ''' + Given a key usage identifier string, returns the parameter name used by cryptography's x509.KeyUsage(). + Raises an OpenSSLObjectError if the identifier is unknown. + ''' + if usage in ('Digital Signature', 'digitalSignature'): + return 'digital_signature' + if usage in ('Non Repudiation', 'nonRepudiation'): + return 'content_commitment' + if usage in ('Key Encipherment', 'keyEncipherment'): + return 'key_encipherment' + if usage in ('Data Encipherment', 'dataEncipherment'): + return 'data_encipherment' + if usage in ('Key Agreement', 'keyAgreement'): + return 'key_agreement' + if usage in ('Certificate Sign', 'keyCertSign'): + return 'key_cert_sign' + if usage in ('CRL Sign', 'cRLSign'): + return 'crl_sign' + if usage in ('Encipher Only', 'encipherOnly'): + return 'encipher_only' + if usage in ('Decipher Only', 'decipherOnly'): + return 'decipher_only' + raise OpenSSLObjectError('Unknown key usage "{0}"'.format(usage)) + + +def cryptography_parse_key_usage_params(usages): + ''' + Given a list of key usage identifier strings, returns the parameters for cryptography's x509.KeyUsage(). + Raises an OpenSSLObjectError if an identifier is unknown. 
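As a quick illustration of the prefixed string format that cryptography_get_name() and cryptography_decode_name() translate between, the round trip looks like this (requires the 'cryptography' package; the hostnames and addresses are examples only):

    import ipaddress
    from cryptography import x509

    dns = x509.DNSName(u'www.ansible.com')                     # from 'DNS:www.ansible.com'
    ip = x509.IPAddress(ipaddress.ip_address(u'192.0.2.10'))   # from 'IP:192.0.2.10'
    print('DNS:{0}'.format(dns.value))                         # back to the prefixed form
    print('IP:{0}'.format(ip.value.compressed))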
+ ''' + params = dict( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) + for usage in usages: + params[_cryptography_get_keyusage(usage)] = True + return params + + +def cryptography_get_basic_constraints(constraints): + ''' + Given a list of constraints, returns a tuple (ca, path_length). + Raises an OpenSSLObjectError if a constraint is unknown or cannot be parsed. + ''' + ca = False + path_length = None + if constraints: + for constraint in constraints: + if constraint.startswith('CA:'): + if constraint == 'CA:TRUE': + ca = True + elif constraint == 'CA:FALSE': + ca = False + else: + raise OpenSSLObjectError('Unknown basic constraint value "{0}" for CA'.format(constraint[3:])) + elif constraint.startswith('pathlen:'): + v = constraint[len('pathlen:'):] + try: + path_length = int(v) + except Exception as e: + raise OpenSSLObjectError('Cannot parse path length constraint "{0}" ({1})'.format(v, e)) + else: + raise OpenSSLObjectError('Unknown basic constraint "{0}"'.format(constraint)) + return ca, path_length + + +def binary_exp_mod(f, e, m): + '''Computes f^e mod m in O(log e) multiplications modulo m.''' + # Compute len_e = floor(log_2(e)) + len_e = -1 + x = e + while x > 0: + x >>= 1 + len_e += 1 + # Compute f**e mod m + result = 1 + for k in range(len_e, -1, -1): + result = (result * result) % m + if ((e >> k) & 1) != 0: + result = (result * f) % m + return result + + +def simple_gcd(a, b): + '''Compute GCD of its two inputs.''' + while b != 0: + a, b = b, a % b + return a + + +def quick_is_not_prime(n): + '''Does some quick checks to see if we can poke a hole into the primality of n. + + A result of `False` does **not** mean that the number is prime; it just means + that we couldn't detect quickly whether it is not prime. 
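A quick sanity check of the arithmetic helpers defined above, assuming they are in scope: binary_exp_mod() should agree with Python's built-in three-argument pow(), and simple_gcd() is the classic Euclidean algorithm:

    print(binary_exp_mod(5, 117, 19) == pow(5, 117, 19))   # True
    print(simple_gcd(1071, 462))                            # 21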
+ ''' + if n <= 2: + return True + # The constant in the next line is the product of all primes < 200 + if simple_gcd(n, 7799922041683461553249199106329813876687996789903550945093032474868511536164700810) > 1: + return True + # TODO: maybe do some iterations of Miller-Rabin to increase confidence + # (https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test) + return False + + +python_version = (sys.version_info[0], sys.version_info[1]) +if python_version >= (2, 7) or python_version >= (3, 1): + # Ansible still supports Python 2.6 on remote nodes + def count_bits(no): + no = abs(no) + if no == 0: + return 0 + return no.bit_length() +else: + # Slow, but works + def count_bits(no): + no = abs(no) + count = 0 + while no > 0: + no >>= 1 + count += 1 + return count + + +PEM_START = '-----BEGIN ' +PEM_END = '-----' +PKCS8_PRIVATEKEY_NAMES = ('PRIVATE KEY', 'ENCRYPTED PRIVATE KEY') +PKCS1_PRIVATEKEY_SUFFIX = ' PRIVATE KEY' + + +def identify_private_key_format(content): + '''Given the contents of a private key file, identifies its format.''' + # See https://github.com/openssl/openssl/blob/master/crypto/pem/pem_pkey.c#L40-L85 + # (PEM_read_bio_PrivateKey) + # and https://github.com/openssl/openssl/blob/master/include/openssl/pem.h#L46-L47 + # (PEM_STRING_PKCS8, PEM_STRING_PKCS8INF) + try: + lines = content.decode('utf-8').splitlines(False) + if lines[0].startswith(PEM_START) and lines[0].endswith(PEM_END) and len(lines[0]) > len(PEM_START) + len(PEM_END): + name = lines[0][len(PEM_START):-len(PEM_END)] + if name in PKCS8_PRIVATEKEY_NAMES: + return 'pkcs8' + if len(name) > len(PKCS1_PRIVATEKEY_SUFFIX) and name.endswith(PKCS1_PRIVATEKEY_SUFFIX): + return 'pkcs1' + return 'unknown-pem' + except UnicodeDecodeError: + pass + return 'raw' + + +def cryptography_key_needs_digest_for_signing(key): + '''Tests whether the given private key requires a digest algorithm for signing. + + Ed25519 and Ed448 keys do not; they need None to be passed as the digest algorithm. + ''' + if CRYPTOGRAPHY_HAS_ED25519 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PrivateKey): + return False + if CRYPTOGRAPHY_HAS_ED448 and isinstance(key, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PrivateKey): + return False + return True + + +def cryptography_compare_public_keys(key1, key2): + '''Tests whether two public keys are the same. + + Needs special logic for Ed25519 and Ed448 keys, since they do not have public_numbers(). 
+ ''' + if CRYPTOGRAPHY_HAS_ED25519: + a = isinstance(key1, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey) + b = isinstance(key2, cryptography.hazmat.primitives.asymmetric.ed25519.Ed25519PublicKey) + if a or b: + if not a or not b: + return False + a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + return a == b + if CRYPTOGRAPHY_HAS_ED448: + a = isinstance(key1, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey) + b = isinstance(key2, cryptography.hazmat.primitives.asymmetric.ed448.Ed448PublicKey) + if a or b: + if not a or not b: + return False + a = key1.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + b = key2.public_bytes(serialization.Encoding.Raw, serialization.PublicFormat.Raw) + return a == b + return key1.public_numbers() == key2.public_numbers() + + +if HAS_CRYPTOGRAPHY: + REVOCATION_REASON_MAP = { + 'unspecified': x509.ReasonFlags.unspecified, + 'key_compromise': x509.ReasonFlags.key_compromise, + 'ca_compromise': x509.ReasonFlags.ca_compromise, + 'affiliation_changed': x509.ReasonFlags.affiliation_changed, + 'superseded': x509.ReasonFlags.superseded, + 'cessation_of_operation': x509.ReasonFlags.cessation_of_operation, + 'certificate_hold': x509.ReasonFlags.certificate_hold, + 'privilege_withdrawn': x509.ReasonFlags.privilege_withdrawn, + 'aa_compromise': x509.ReasonFlags.aa_compromise, + 'remove_from_crl': x509.ReasonFlags.remove_from_crl, + } + REVOCATION_REASON_MAP_INVERSE = dict() + for k, v in REVOCATION_REASON_MAP.items(): + REVOCATION_REASON_MAP_INVERSE[v] = k + + +def cryptography_decode_revoked_certificate(cert): + result = { + 'serial_number': cert.serial_number, + 'revocation_date': cert.revocation_date, + 'issuer': None, + 'issuer_critical': False, + 'reason': None, + 'reason_critical': False, + 'invalidity_date': None, + 'invalidity_date_critical': False, + } + try: + ext = cert.extensions.get_extension_for_class(x509.CertificateIssuer) + result['issuer'] = list(ext.value) + result['issuer_critical'] = ext.critical + except x509.ExtensionNotFound: + pass + try: + ext = cert.extensions.get_extension_for_class(x509.CRLReason) + result['reason'] = ext.value.reason + result['reason_critical'] = ext.critical + except x509.ExtensionNotFound: + pass + try: + ext = cert.extensions.get_extension_for_class(x509.InvalidityDate) + result['invalidity_date'] = ext.value.invalidity_date + result['invalidity_date_critical'] = ext.critical + except x509.ExtensionNotFound: + pass + return result diff --git a/test/support/integration/plugins/module_utils/database.py b/test/support/integration/plugins/module_utils/database.py new file mode 100644 index 00000000..014939a2 --- /dev/null +++ b/test/support/integration/plugins/module_utils/database.py @@ -0,0 +1,142 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com> +# All rights reserved. 
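The cryptography_compare_public_keys() helper above falls back to comparing raw public bytes for Ed25519/Ed448, since those key types expose no public_numbers(). A minimal sketch of that comparison (requires a 'cryptography' build with Ed25519 support):

    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ed25519

    def raw_public_bytes(public_key):
        # Ed25519 public keys have no public_numbers(); compare their raw encodings instead.
        return public_key.public_bytes(serialization.Encoding.Raw,
                                       serialization.PublicFormat.Raw)

    key_a = ed25519.Ed25519PrivateKey.generate().public_key()
    key_b = ed25519.Ed25519PrivateKey.generate().public_key()
    print(raw_public_bytes(key_a) == raw_public_bytes(key_a))   # True
    print(raw_public_bytes(key_a) == raw_public_bytes(key_b))   # False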
+# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +class SQLParseError(Exception): + pass + + +class UnclosedQuoteError(SQLParseError): + pass + + +# maps a type of identifier to the maximum number of dot levels that are +# allowed to specify that identifier. For example, a database column can be +# specified by up to 4 levels: database.schema.table.column +_PG_IDENTIFIER_TO_DOT_LEVEL = dict( + database=1, + schema=2, + table=3, + column=4, + role=1, + tablespace=1, + sequence=3, + publication=1, +) +_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1, vars=1) + + +def _find_end_quote(identifier, quote_char): + accumulate = 0 + while True: + try: + quote = identifier.index(quote_char) + except ValueError: + raise UnclosedQuoteError + accumulate = accumulate + quote + try: + next_char = identifier[quote + 1] + except IndexError: + return accumulate + if next_char == quote_char: + try: + identifier = identifier[quote + 2:] + accumulate = accumulate + 2 + except IndexError: + raise UnclosedQuoteError + else: + return accumulate + + +def _identifier_parse(identifier, quote_char): + if not identifier: + raise SQLParseError('Identifier name unspecified or unquoted trailing dot') + + already_quoted = False + if identifier.startswith(quote_char): + already_quoted = True + try: + end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1 + except UnclosedQuoteError: + already_quoted = False + else: + if end_quote < len(identifier) - 1: + if identifier[end_quote + 1] == '.': + dot = end_quote + 1 + first_identifier = identifier[:dot] + next_identifier = identifier[dot + 1:] + further_identifiers = _identifier_parse(next_identifier, quote_char) + further_identifiers.insert(0, first_identifier) + else: + raise SQLParseError('User escaped identifiers must escape extra quotes') + else: + further_identifiers = [identifier] + + if not already_quoted: + try: + dot = identifier.index('.') + except ValueError: + identifier = identifier.replace(quote_char, quote_char * 2) + identifier = ''.join((quote_char, identifier, quote_char)) + further_identifiers = [identifier] + else: + if dot == 0 or dot >= len(identifier) - 1: + identifier = identifier.replace(quote_char, quote_char * 2) + identifier = ''.join((quote_char, identifier, quote_char)) + 
further_identifiers = [identifier] + else: + first_identifier = identifier[:dot] + next_identifier = identifier[dot + 1:] + further_identifiers = _identifier_parse(next_identifier, quote_char) + first_identifier = first_identifier.replace(quote_char, quote_char * 2) + first_identifier = ''.join((quote_char, first_identifier, quote_char)) + further_identifiers.insert(0, first_identifier) + + return further_identifiers + + +def pg_quote_identifier(identifier, id_type): + identifier_fragments = _identifier_parse(identifier, quote_char='"') + if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]: + raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type])) + return '.'.join(identifier_fragments) + + +def mysql_quote_identifier(identifier, id_type): + identifier_fragments = _identifier_parse(identifier, quote_char='`') + if (len(identifier_fragments) - 1) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]: + raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type])) + + special_cased_fragments = [] + for fragment in identifier_fragments: + if fragment == '`*`': + special_cased_fragments.append('*') + else: + special_cased_fragments.append(fragment) + + return '.'.join(special_cased_fragments) diff --git a/test/support/integration/plugins/module_utils/docker/__init__.py b/test/support/integration/plugins/module_utils/docker/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/docker/__init__.py diff --git a/test/support/integration/plugins/module_utils/docker/common.py b/test/support/integration/plugins/module_utils/docker/common.py new file mode 100644 index 00000000..03307250 --- /dev/null +++ b/test/support/integration/plugins/module_utils/docker/common.py @@ -0,0 +1,1022 @@ +# +# Copyright 2016 Red Hat | Ansible +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see <http://www.gnu.org/licenses/>. 
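The database helpers above (pg_quote_identifier / mysql_quote_identifier) split a dotted identifier and quote each part by doubling any embedded quote character. The core quoting rule, sketched independently of the module:

    def quote_part(name, quote_char='"'):
        # Double embedded quote characters, then wrap the identifier.
        return quote_char + name.replace(quote_char, quote_char * 2) + quote_char

    print(quote_part('weird"name'))                                   # "weird""name"
    print('.'.join(quote_part(p) for p in ('public', 'my table')))    # "public"."my table"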
+ +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +import os +import platform +import re +import sys +from datetime import timedelta +from distutils.version import LooseVersion + + +from ansible.module_utils.basic import AnsibleModule, env_fallback, missing_required_lib +from ansible.module_utils.common._collections_compat import Mapping, Sequence +from ansible.module_utils.six import string_types +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ansible.module_utils.parsing.convert_bool import BOOLEANS_TRUE, BOOLEANS_FALSE + +HAS_DOCKER_PY = True +HAS_DOCKER_PY_2 = False +HAS_DOCKER_PY_3 = False +HAS_DOCKER_ERROR = None + +try: + from requests.exceptions import SSLError + from docker import __version__ as docker_version + from docker.errors import APIError, NotFound, TLSParameterError + from docker.tls import TLSConfig + from docker import auth + + if LooseVersion(docker_version) >= LooseVersion('3.0.0'): + HAS_DOCKER_PY_3 = True + from docker import APIClient as Client + elif LooseVersion(docker_version) >= LooseVersion('2.0.0'): + HAS_DOCKER_PY_2 = True + from docker import APIClient as Client + else: + from docker import Client + +except ImportError as exc: + HAS_DOCKER_ERROR = str(exc) + HAS_DOCKER_PY = False + + +# The next 2 imports ``docker.models`` and ``docker.ssladapter`` are used +# to ensure the user does not have both ``docker`` and ``docker-py`` modules +# installed, as they utilize the same namespace are are incompatible +try: + # docker (Docker SDK for Python >= 2.0.0) + import docker.models # noqa: F401 + HAS_DOCKER_MODELS = True +except ImportError: + HAS_DOCKER_MODELS = False + +try: + # docker-py (Docker SDK for Python < 2.0.0) + import docker.ssladapter # noqa: F401 + HAS_DOCKER_SSLADAPTER = True +except ImportError: + HAS_DOCKER_SSLADAPTER = False + + +try: + from requests.exceptions import RequestException +except ImportError: + # Either docker-py is no longer using requests, or docker-py isn't around either, + # or docker-py's dependency requests is missing. In any case, define an exception + # class RequestException so that our code doesn't break. 
+ class RequestException(Exception): + pass + + +DEFAULT_DOCKER_HOST = 'unix://var/run/docker.sock' +DEFAULT_TLS = False +DEFAULT_TLS_VERIFY = False +DEFAULT_TLS_HOSTNAME = 'localhost' +MIN_DOCKER_VERSION = "1.8.0" +DEFAULT_TIMEOUT_SECONDS = 60 + +DOCKER_COMMON_ARGS = dict( + docker_host=dict(type='str', default=DEFAULT_DOCKER_HOST, fallback=(env_fallback, ['DOCKER_HOST']), aliases=['docker_url']), + tls_hostname=dict(type='str', default=DEFAULT_TLS_HOSTNAME, fallback=(env_fallback, ['DOCKER_TLS_HOSTNAME'])), + api_version=dict(type='str', default='auto', fallback=(env_fallback, ['DOCKER_API_VERSION']), aliases=['docker_api_version']), + timeout=dict(type='int', default=DEFAULT_TIMEOUT_SECONDS, fallback=(env_fallback, ['DOCKER_TIMEOUT'])), + ca_cert=dict(type='path', aliases=['tls_ca_cert', 'cacert_path']), + client_cert=dict(type='path', aliases=['tls_client_cert', 'cert_path']), + client_key=dict(type='path', aliases=['tls_client_key', 'key_path']), + ssl_version=dict(type='str', fallback=(env_fallback, ['DOCKER_SSL_VERSION'])), + tls=dict(type='bool', default=DEFAULT_TLS, fallback=(env_fallback, ['DOCKER_TLS'])), + validate_certs=dict(type='bool', default=DEFAULT_TLS_VERIFY, fallback=(env_fallback, ['DOCKER_TLS_VERIFY']), aliases=['tls_verify']), + debug=dict(type='bool', default=False) +) + +DOCKER_MUTUALLY_EXCLUSIVE = [] + +DOCKER_REQUIRED_TOGETHER = [ + ['client_cert', 'client_key'] +] + +DEFAULT_DOCKER_REGISTRY = 'https://index.docker.io/v1/' +EMAIL_REGEX = r'[^@]+@[^@]+\.[^@]+' +BYTE_SUFFIXES = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + + +if not HAS_DOCKER_PY: + docker_version = None + + # No Docker SDK for Python. Create a place holder client to allow + # instantiation of AnsibleModule and proper error handing + class Client(object): # noqa: F811 + def __init__(self, **kwargs): + pass + + class APIError(Exception): # noqa: F811 + pass + + class NotFound(Exception): # noqa: F811 + pass + + +def is_image_name_id(name): + """Check whether the given image name is in fact an image ID (hash).""" + if re.match('^sha256:[0-9a-fA-F]{64}$', name): + return True + return False + + +def is_valid_tag(tag, allow_empty=False): + """Check whether the given string is a valid docker tag name.""" + if not tag: + return allow_empty + # See here ("Extended description") for a definition what tags can be: + # https://docs.docker.com/engine/reference/commandline/tag/ + return bool(re.match('^[a-zA-Z0-9_][a-zA-Z0-9_.-]{0,127}$', tag)) + + +def sanitize_result(data): + """Sanitize data object for return to Ansible. + + When the data object contains types such as docker.types.containers.HostConfig, + Ansible will fail when these are returned via exit_json or fail_json. + HostConfig is derived from dict, but its constructor requires additional + arguments. This function sanitizes data structures by recursively converting + everything derived from dict to dict and everything derived from list (and tuple) + to a list. 
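Expected behaviour of the validators defined above, assuming they are in scope:

    print(is_image_name_id('sha256:' + 'a' * 64))   # True  (a 64-hex-digit image ID)
    print(is_valid_tag('v1.2.3'))                   # True
    print(is_valid_tag('-latest'))                  # False (tags must not start with '.' or '-')
    print(is_valid_tag('', allow_empty=True))       # True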
+ """ + if isinstance(data, dict): + return dict((k, sanitize_result(v)) for k, v in data.items()) + elif isinstance(data, (list, tuple)): + return [sanitize_result(v) for v in data] + else: + return data + + +class DockerBaseClass(object): + + def __init__(self): + self.debug = False + + def log(self, msg, pretty_print=False): + pass + # if self.debug: + # log_file = open('docker.log', 'a') + # if pretty_print: + # log_file.write(json.dumps(msg, sort_keys=True, indent=4, separators=(',', ': '))) + # log_file.write(u'\n') + # else: + # log_file.write(msg + u'\n') + + +def update_tls_hostname(result): + if result['tls_hostname'] is None: + # get default machine name from the url + parsed_url = urlparse(result['docker_host']) + if ':' in parsed_url.netloc: + result['tls_hostname'] = parsed_url.netloc[:parsed_url.netloc.rindex(':')] + else: + result['tls_hostname'] = parsed_url + + +def _get_tls_config(fail_function, **kwargs): + try: + tls_config = TLSConfig(**kwargs) + return tls_config + except TLSParameterError as exc: + fail_function("TLS config error: %s" % exc) + + +def get_connect_params(auth, fail_function): + if auth['tls'] or auth['tls_verify']: + auth['docker_host'] = auth['docker_host'].replace('tcp://', 'https://') + + if auth['tls_verify'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and host verification + if auth['cacert_path']: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + ca_cert=auth['cacert_path'], + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + else: + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify'] and auth['cacert_path']: + # TLS with cacert only + tls_config = _get_tls_config(ca_cert=auth['cacert_path'], + assert_hostname=auth['tls_hostname'], + verify=True, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls_verify']: + # TLS with verify and no certs + tls_config = _get_tls_config(verify=True, + assert_hostname=auth['tls_hostname'], + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls'] and auth['cert_path'] and auth['key_path']: + # TLS with certs and no host verification + tls_config = _get_tls_config(client_cert=(auth['cert_path'], auth['key_path']), + verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + if auth['tls']: + # TLS with no certs and not host verification + tls_config = _get_tls_config(verify=False, + ssl_version=auth['ssl_version'], + fail_function=fail_function) + return dict(base_url=auth['docker_host'], + tls=tls_config, + version=auth['api_version'], + timeout=auth['timeout']) + + # No TLS + return dict(base_url=auth['docker_host'], + version=auth['api_version'], + timeout=auth['timeout']) + + +DOCKERPYUPGRADE_SWITCH_TO_DOCKER = "Try `pip uninstall docker-py` followed by `pip 
install docker`." +DOCKERPYUPGRADE_UPGRADE_DOCKER = "Use `pip install --upgrade docker` to upgrade." +DOCKERPYUPGRADE_RECOMMEND_DOCKER = ("Use `pip install --upgrade docker-py` to upgrade. " + "Hint: if you do not need Python 2.6 support, try " + "`pip uninstall docker-py` instead, followed by `pip install docker`.") + + +class AnsibleDockerClient(Client): + + def __init__(self, argument_spec=None, supports_check_mode=False, mutually_exclusive=None, + required_together=None, required_if=None, min_docker_version=MIN_DOCKER_VERSION, + min_docker_api_version=None, option_minimal_versions=None, + option_minimal_versions_ignore_params=None, fail_results=None): + + # Modules can put information in here which will always be returned + # in case client.fail() is called. + self.fail_results = fail_results or {} + + merged_arg_spec = dict() + merged_arg_spec.update(DOCKER_COMMON_ARGS) + if argument_spec: + merged_arg_spec.update(argument_spec) + self.arg_spec = merged_arg_spec + + mutually_exclusive_params = [] + mutually_exclusive_params += DOCKER_MUTUALLY_EXCLUSIVE + if mutually_exclusive: + mutually_exclusive_params += mutually_exclusive + + required_together_params = [] + required_together_params += DOCKER_REQUIRED_TOGETHER + if required_together: + required_together_params += required_together + + self.module = AnsibleModule( + argument_spec=merged_arg_spec, + supports_check_mode=supports_check_mode, + mutually_exclusive=mutually_exclusive_params, + required_together=required_together_params, + required_if=required_if) + + NEEDS_DOCKER_PY2 = (LooseVersion(min_docker_version) >= LooseVersion('2.0.0')) + + self.docker_py_version = LooseVersion(docker_version) + + if HAS_DOCKER_MODELS and HAS_DOCKER_SSLADAPTER: + self.fail("Cannot have both the docker-py and docker python modules (old and new version of Docker " + "SDK for Python) installed together as they use the same namespace and cause a corrupt " + "installation. Please uninstall both packages, and re-install only the docker-py or docker " + "python module (for %s's Python %s). It is recommended to install the docker module if no " + "support for Python 2.6 is required. Please note that simply uninstalling one of the modules " + "can leave the other module in a broken state." % (platform.node(), sys.executable)) + + if not HAS_DOCKER_PY: + if NEEDS_DOCKER_PY2: + msg = missing_required_lib("Docker SDK for Python: docker") + msg = msg + ", for example via `pip install docker`. The error was: %s" + else: + msg = missing_required_lib("Docker SDK for Python: docker (Python >= 2.7) or docker-py (Python 2.6)") + msg = msg + ", for example via `pip install docker` or `pip install docker-py` (Python 2.6). The error was: %s" + self.fail(msg % HAS_DOCKER_ERROR) + + if self.docker_py_version < LooseVersion(min_docker_version): + msg = "Error: Docker SDK for Python version is %s (%s's Python %s). Minimum version required is %s." + if not NEEDS_DOCKER_PY2: + # The minimal required version is < 2.0 (and the current version as well). + # Advertise docker (instead of docker-py) for non-Python-2.6 users. 
+ msg += DOCKERPYUPGRADE_RECOMMEND_DOCKER + elif docker_version < LooseVersion('2.0'): + msg += DOCKERPYUPGRADE_SWITCH_TO_DOCKER + else: + msg += DOCKERPYUPGRADE_UPGRADE_DOCKER + self.fail(msg % (docker_version, platform.node(), sys.executable, min_docker_version)) + + self.debug = self.module.params.get('debug') + self.check_mode = self.module.check_mode + self._connect_params = get_connect_params(self.auth_params, fail_function=self.fail) + + try: + super(AnsibleDockerClient, self).__init__(**self._connect_params) + self.docker_api_version_str = self.version()['ApiVersion'] + except APIError as exc: + self.fail("Docker API error: %s" % exc) + except Exception as exc: + self.fail("Error connecting: %s" % exc) + + self.docker_api_version = LooseVersion(self.docker_api_version_str) + if min_docker_api_version is not None: + if self.docker_api_version < LooseVersion(min_docker_api_version): + self.fail('Docker API version is %s. Minimum version required is %s.' % (self.docker_api_version_str, min_docker_api_version)) + + if option_minimal_versions is not None: + self._get_minimal_versions(option_minimal_versions, option_minimal_versions_ignore_params) + + def log(self, msg, pretty_print=False): + pass + # if self.debug: + # log_file = open('docker.log', 'a') + # if pretty_print: + # log_file.write(json.dumps(msg, sort_keys=True, indent=4, separators=(',', ': '))) + # log_file.write(u'\n') + # else: + # log_file.write(msg + u'\n') + + def fail(self, msg, **kwargs): + self.fail_results.update(kwargs) + self.module.fail_json(msg=msg, **sanitize_result(self.fail_results)) + + @staticmethod + def _get_value(param_name, param_value, env_variable, default_value): + if param_value is not None: + # take module parameter value + if param_value in BOOLEANS_TRUE: + return True + if param_value in BOOLEANS_FALSE: + return False + return param_value + + if env_variable is not None: + env_value = os.environ.get(env_variable) + if env_value is not None: + # take the env variable value + if param_name == 'cert_path': + return os.path.join(env_value, 'cert.pem') + if param_name == 'cacert_path': + return os.path.join(env_value, 'ca.pem') + if param_name == 'key_path': + return os.path.join(env_value, 'key.pem') + if env_value in BOOLEANS_TRUE: + return True + if env_value in BOOLEANS_FALSE: + return False + return env_value + + # take the default + return default_value + + @property + def auth_params(self): + # Get authentication credentials. + # Precedence: module parameters-> environment variables-> defaults. + + self.log('Getting credentials') + + params = dict() + for key in DOCKER_COMMON_ARGS: + params[key] = self.module.params.get(key) + + if self.module.params.get('use_tls'): + # support use_tls option in docker_image.py. This will be deprecated. 
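The auth_params property above resolves each connection setting with the precedence: explicit module parameter, then environment variable, then built-in default. Roughly, the fallback logic is:

    import os

    def resolve(param_value, env_variable, default_value):
        if param_value is not None:
            return param_value              # explicit module parameter wins
        env_value = os.environ.get(env_variable)
        if env_value is not None:
            return env_value                # otherwise fall back to the environment
        return default_value                # finally, the built-in default

    print(resolve(None, 'DOCKER_HOST', 'unix://var/run/docker.sock'))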
+ use_tls = self.module.params.get('use_tls') + if use_tls == 'encrypt': + params['tls'] = True + if use_tls == 'verify': + params['validate_certs'] = True + + result = dict( + docker_host=self._get_value('docker_host', params['docker_host'], 'DOCKER_HOST', + DEFAULT_DOCKER_HOST), + tls_hostname=self._get_value('tls_hostname', params['tls_hostname'], + 'DOCKER_TLS_HOSTNAME', DEFAULT_TLS_HOSTNAME), + api_version=self._get_value('api_version', params['api_version'], 'DOCKER_API_VERSION', + 'auto'), + cacert_path=self._get_value('cacert_path', params['ca_cert'], 'DOCKER_CERT_PATH', None), + cert_path=self._get_value('cert_path', params['client_cert'], 'DOCKER_CERT_PATH', None), + key_path=self._get_value('key_path', params['client_key'], 'DOCKER_CERT_PATH', None), + ssl_version=self._get_value('ssl_version', params['ssl_version'], 'DOCKER_SSL_VERSION', None), + tls=self._get_value('tls', params['tls'], 'DOCKER_TLS', DEFAULT_TLS), + tls_verify=self._get_value('tls_verfy', params['validate_certs'], 'DOCKER_TLS_VERIFY', + DEFAULT_TLS_VERIFY), + timeout=self._get_value('timeout', params['timeout'], 'DOCKER_TIMEOUT', + DEFAULT_TIMEOUT_SECONDS), + ) + + update_tls_hostname(result) + + return result + + def _handle_ssl_error(self, error): + match = re.match(r"hostname.*doesn\'t match (\'.*\')", str(error)) + if match: + self.fail("You asked for verification that Docker daemons certificate's hostname matches %s. " + "The actual certificate's hostname is %s. Most likely you need to set DOCKER_TLS_HOSTNAME " + "or pass `tls_hostname` with a value of %s. You may also use TLS without verification by " + "setting the `tls` parameter to true." + % (self.auth_params['tls_hostname'], match.group(1), match.group(1))) + self.fail("SSL Exception: %s" % (error)) + + def _get_minimal_versions(self, option_minimal_versions, ignore_params=None): + self.option_minimal_versions = dict() + for option in self.module.argument_spec: + if ignore_params is not None: + if option in ignore_params: + continue + self.option_minimal_versions[option] = dict() + self.option_minimal_versions.update(option_minimal_versions) + + for option, data in self.option_minimal_versions.items(): + # Test whether option is supported, and store result + support_docker_py = True + support_docker_api = True + if 'docker_py_version' in data: + support_docker_py = self.docker_py_version >= LooseVersion(data['docker_py_version']) + if 'docker_api_version' in data: + support_docker_api = self.docker_api_version >= LooseVersion(data['docker_api_version']) + data['supported'] = support_docker_py and support_docker_api + # Fail if option is not supported but used + if not data['supported']: + # Test whether option is specified + if 'detect_usage' in data: + used = data['detect_usage'](self) + else: + used = self.module.params.get(option) is not None + if used and 'default' in self.module.argument_spec[option]: + used = self.module.params[option] != self.module.argument_spec[option]['default'] + if used: + # If the option is used, compose error message. + if 'usage_msg' in data: + usg = data['usage_msg'] + else: + usg = 'set %s option' % (option, ) + if not support_docker_api: + msg = 'Docker API version is %s. Minimum version required is %s to %s.' + msg = msg % (self.docker_api_version_str, data['docker_api_version'], usg) + elif not support_docker_py: + msg = "Docker SDK for Python version is %s (%s's Python %s). Minimum version required is %s to %s. 
" + if LooseVersion(data['docker_py_version']) < LooseVersion('2.0.0'): + msg += DOCKERPYUPGRADE_RECOMMEND_DOCKER + elif self.docker_py_version < LooseVersion('2.0.0'): + msg += DOCKERPYUPGRADE_SWITCH_TO_DOCKER + else: + msg += DOCKERPYUPGRADE_UPGRADE_DOCKER + msg = msg % (docker_version, platform.node(), sys.executable, data['docker_py_version'], usg) + else: + # should not happen + msg = 'Cannot %s with your configuration.' % (usg, ) + self.fail(msg) + + def get_container_by_id(self, container_id): + try: + self.log("Inspecting container Id %s" % container_id) + result = self.inspect_container(container=container_id) + self.log("Completed container inspection") + return result + except NotFound as dummy: + return None + except Exception as exc: + self.fail("Error inspecting container: %s" % exc) + + def get_container(self, name=None): + ''' + Lookup a container and return the inspection results. + ''' + if name is None: + return None + + search_name = name + if not name.startswith('/'): + search_name = '/' + name + + result = None + try: + for container in self.containers(all=True): + self.log("testing container: %s" % (container['Names'])) + if isinstance(container['Names'], list) and search_name in container['Names']: + result = container + break + if container['Id'].startswith(name): + result = container + break + if container['Id'] == name: + result = container + break + except SSLError as exc: + self._handle_ssl_error(exc) + except Exception as exc: + self.fail("Error retrieving container list: %s" % exc) + + if result is None: + return None + + return self.get_container_by_id(result['Id']) + + def get_network(self, name=None, network_id=None): + ''' + Lookup a network and return the inspection results. + ''' + if name is None and network_id is None: + return None + + result = None + + if network_id is None: + try: + for network in self.networks(): + self.log("testing network: %s" % (network['Name'])) + if name == network['Name']: + result = network + break + if network['Id'].startswith(name): + result = network + break + except SSLError as exc: + self._handle_ssl_error(exc) + except Exception as exc: + self.fail("Error retrieving network list: %s" % exc) + + if result is not None: + network_id = result['Id'] + + if network_id is not None: + try: + self.log("Inspecting network Id %s" % network_id) + result = self.inspect_network(network_id) + self.log("Completed network inspection") + except NotFound as dummy: + return None + except Exception as exc: + self.fail("Error inspecting network: %s" % exc) + + return result + + def find_image(self, name, tag): + ''' + Lookup an image (by name and tag) and return the inspection results. 
+ ''' + if not name: + return None + + self.log("Find image %s:%s" % (name, tag)) + images = self._image_lookup(name, tag) + if not images: + # In API <= 1.20 seeing 'docker.io/<name>' as the name of images pulled from docker hub + registry, repo_name = auth.resolve_repository_name(name) + if registry == 'docker.io': + # If docker.io is explicitly there in name, the image + # isn't found in some cases (#41509) + self.log("Check for docker.io image: %s" % repo_name) + images = self._image_lookup(repo_name, tag) + if not images and repo_name.startswith('library/'): + # Sometimes library/xxx images are not found + lookup = repo_name[len('library/'):] + self.log("Check for docker.io image: %s" % lookup) + images = self._image_lookup(lookup, tag) + if not images: + # Last case: if docker.io wasn't there, it can be that + # the image wasn't found either (#15586) + lookup = "%s/%s" % (registry, repo_name) + self.log("Check for docker.io image: %s" % lookup) + images = self._image_lookup(lookup, tag) + + if len(images) > 1: + self.fail("Registry returned more than one result for %s:%s" % (name, tag)) + + if len(images) == 1: + try: + inspection = self.inspect_image(images[0]['Id']) + except Exception as exc: + self.fail("Error inspecting image %s:%s - %s" % (name, tag, str(exc))) + return inspection + + self.log("Image %s:%s not found." % (name, tag)) + return None + + def find_image_by_id(self, image_id): + ''' + Lookup an image (by ID) and return the inspection results. + ''' + if not image_id: + return None + + self.log("Find image %s (by ID)" % image_id) + try: + inspection = self.inspect_image(image_id) + except Exception as exc: + self.fail("Error inspecting image ID %s - %s" % (image_id, str(exc))) + return inspection + + def _image_lookup(self, name, tag): + ''' + Including a tag in the name parameter sent to the Docker SDK for Python images method + does not work consistently. Instead, get the result set for name and manually check + if the tag exists. + ''' + try: + response = self.images(name=name) + except Exception as exc: + self.fail("Error searching for image %s - %s" % (name, str(exc))) + images = response + if tag: + lookup = "%s:%s" % (name, tag) + lookup_digest = "%s@%s" % (name, tag) + images = [] + for image in response: + tags = image.get('RepoTags') + digests = image.get('RepoDigests') + if (tags and lookup in tags) or (digests and lookup_digest in digests): + images = [image] + break + return images + + def pull_image(self, name, tag="latest"): + ''' + Pull an image + ''' + self.log("Pulling image %s:%s" % (name, tag)) + old_tag = self.find_image(name, tag) + try: + for line in self.pull(name, tag=tag, stream=True, decode=True): + self.log(line, pretty_print=True) + if line.get('error'): + if line.get('errorDetail'): + error_detail = line.get('errorDetail') + self.fail("Error pulling %s - code: %s message: %s" % (name, + error_detail.get('code'), + error_detail.get('message'))) + else: + self.fail("Error pulling %s - %s" % (name, line.get('error'))) + except Exception as exc: + self.fail("Error pulling image %s:%s - %s" % (name, tag, str(exc))) + + new_tag = self.find_image(name, tag) + + return new_tag, old_tag == new_tag + + def report_warnings(self, result, warnings_key=None): + ''' + Checks result of client operation for warnings, and if present, outputs them. + + warnings_key should be a list of keys used to crawl the result dictionary. + For example, if warnings_key == ['a', 'b'], the function will consider + result['a']['b'] if these keys exist. 
If the result is a non-empty string, it + will be reported as a warning. If the result is a list, every entry will be + reported as a warning. + + In most cases (if warnings are returned at all), warnings_key should be + ['Warnings'] or ['Warning']. The default value (if not specified) is ['Warnings']. + ''' + if warnings_key is None: + warnings_key = ['Warnings'] + for key in warnings_key: + if not isinstance(result, Mapping): + return + result = result.get(key) + if isinstance(result, Sequence): + for warning in result: + self.module.warn('Docker warning: {0}'.format(warning)) + elif isinstance(result, string_types) and result: + self.module.warn('Docker warning: {0}'.format(result)) + + def inspect_distribution(self, image, **kwargs): + ''' + Get image digest by directly calling the Docker API when running Docker SDK < 4.0.0 + since prior versions did not support accessing private repositories. + ''' + if self.docker_py_version < LooseVersion('4.0.0'): + registry = auth.resolve_repository_name(image)[0] + header = auth.get_config_header(self, registry) + if header: + return self._result(self._get( + self._url('/distribution/{0}/json', image), + headers={'X-Registry-Auth': header} + ), json=True) + return super(AnsibleDockerClient, self).inspect_distribution(image, **kwargs) + + +def compare_dict_allow_more_present(av, bv): + ''' + Compare two dictionaries for whether every entry of the first is in the second. + ''' + for key, value in av.items(): + if key not in bv: + return False + if bv[key] != value: + return False + return True + + +def compare_generic(a, b, method, datatype): + ''' + Compare values a and b as described by method and datatype. + + Returns ``True`` if the values compare equal, and ``False`` if not. + + ``a`` is usually the module's parameter, while ``b`` is a property + of the current object. ``a`` must not be ``None`` (except for + ``datatype == 'value'``). + + Valid values for ``method`` are: + - ``ignore`` (always compare as equal); + - ``strict`` (only compare if really equal) + - ``allow_more_present`` (allow b to have elements which a does not have). + + Valid values for ``datatype`` are: + - ``value``: for simple values (strings, numbers, ...); + - ``list``: for ``list``s or ``tuple``s where order matters; + - ``set``: for ``list``s, ``tuple``s or ``set``s where order does not + matter; + - ``set(dict)``: for ``list``s, ``tuple``s or ``sets`` where order does + not matter and which contain ``dict``s; ``allow_more_present`` is used + for the ``dict``s, and these are assumed to be dictionaries of values; + - ``dict``: for dictionaries of values. 
+ ''' + if method == 'ignore': + return True + # If a or b is None: + if a is None or b is None: + # If both are None: equality + if a == b: + return True + # Otherwise, not equal for values, and equal + # if the other is empty for set/list/dict + if datatype == 'value': + return False + # For allow_more_present, allow a to be None + if method == 'allow_more_present' and a is None: + return True + # Otherwise, the iterable object which is not None must have length 0 + return len(b if a is None else a) == 0 + # Do proper comparison (both objects not None) + if datatype == 'value': + return a == b + elif datatype == 'list': + if method == 'strict': + return a == b + else: + i = 0 + for v in a: + while i < len(b) and b[i] != v: + i += 1 + if i == len(b): + return False + i += 1 + return True + elif datatype == 'dict': + if method == 'strict': + return a == b + else: + return compare_dict_allow_more_present(a, b) + elif datatype == 'set': + set_a = set(a) + set_b = set(b) + if method == 'strict': + return set_a == set_b + else: + return set_b >= set_a + elif datatype == 'set(dict)': + for av in a: + found = False + for bv in b: + if compare_dict_allow_more_present(av, bv): + found = True + break + if not found: + return False + if method == 'strict': + # If we would know that both a and b do not contain duplicates, + # we could simply compare len(a) to len(b) to finish this test. + # We can assume that b has no duplicates (as it is returned by + # docker), but we don't know for a. + for bv in b: + found = False + for av in a: + if compare_dict_allow_more_present(av, bv): + found = True + break + if not found: + return False + return True + + +class DifferenceTracker(object): + def __init__(self): + self._diff = [] + + def add(self, name, parameter=None, active=None): + self._diff.append(dict( + name=name, + parameter=parameter, + active=active, + )) + + def merge(self, other_tracker): + self._diff.extend(other_tracker._diff) + + @property + def empty(self): + return len(self._diff) == 0 + + def get_before_after(self): + ''' + Return texts ``before`` and ``after``. + ''' + before = dict() + after = dict() + for item in self._diff: + before[item['name']] = item['active'] + after[item['name']] = item['parameter'] + return before, after + + def has_difference_for(self, name): + ''' + Returns a boolean if a difference exists for name + ''' + return any(diff for diff in self._diff if diff['name'] == name) + + def get_legacy_docker_container_diffs(self): + ''' + Return differences in the docker_container legacy format. + ''' + result = [] + for entry in self._diff: + item = dict() + item[entry['name']] = dict( + parameter=entry['parameter'], + container=entry['active'], + ) + result.append(item) + return result + + def get_legacy_docker_diffs(self): + ''' + Return differences in the docker_container legacy format. + ''' + result = [entry['name'] for entry in self._diff] + return result + + +def clean_dict_booleans_for_docker_api(data): + ''' + Go doesn't like Python booleans 'True' or 'False', while Ansible is just + fine with them in YAML. As such, they need to be converted in cases where + we pass dictionaries to the Docker API (e.g. docker_network's + driver_options and docker_prune's filters). + ''' + result = dict() + if data is not None: + for k, v in data.items(): + if v is True: + v = 'true' + elif v is False: + v = 'false' + else: + v = str(v) + result[str(k)] = v + return result + + +def convert_duration_to_nanosecond(time_str): + """ + Return time duration in nanosecond. 
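A few concrete cases for compare_generic() as defined above, assuming it is in scope:

    print(compare_generic(['a', 'b'], ['b', 'x', 'a'], 'allow_more_present', 'set'))        # True
    print(compare_generic(['a', 'b'], ['b', 'a'], 'strict', 'list'))                        # False (order matters)
    print(compare_generic({'x': '1'}, {'x': '1', 'y': '2'}, 'allow_more_present', 'dict'))  # True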
+ """ + if not isinstance(time_str, str): + raise ValueError('Missing unit in duration - %s' % time_str) + + regex = re.compile( + r'^(((?P<hours>\d+)h)?' + r'((?P<minutes>\d+)m(?!s))?' + r'((?P<seconds>\d+)s)?' + r'((?P<milliseconds>\d+)ms)?' + r'((?P<microseconds>\d+)us)?)$' + ) + parts = regex.match(time_str) + + if not parts: + raise ValueError('Invalid time duration - %s' % time_str) + + parts = parts.groupdict() + time_params = {} + for (name, value) in parts.items(): + if value: + time_params[name] = int(value) + + delta = timedelta(**time_params) + time_in_nanoseconds = ( + delta.microseconds + (delta.seconds + delta.days * 24 * 3600) * 10 ** 6 + ) * 10 ** 3 + + return time_in_nanoseconds + + +def parse_healthcheck(healthcheck): + """ + Return dictionary of healthcheck parameters and boolean if + healthcheck defined in image was requested to be disabled. + """ + if (not healthcheck) or (not healthcheck.get('test')): + return None, None + + result = dict() + + # All supported healthcheck parameters + options = dict( + test='test', + interval='interval', + timeout='timeout', + start_period='start_period', + retries='retries' + ) + + duration_options = ['interval', 'timeout', 'start_period'] + + for (key, value) in options.items(): + if value in healthcheck: + if healthcheck.get(value) is None: + # due to recursive argument_spec, all keys are always present + # (but have default value None if not specified) + continue + if value in duration_options: + time = convert_duration_to_nanosecond(healthcheck.get(value)) + if time: + result[key] = time + elif healthcheck.get(value): + result[key] = healthcheck.get(value) + if key == 'test': + if isinstance(result[key], (tuple, list)): + result[key] = [str(e) for e in result[key]] + else: + result[key] = ['CMD-SHELL', str(result[key])] + elif key == 'retries': + try: + result[key] = int(result[key]) + except ValueError: + raise ValueError( + 'Cannot parse number of retries for healthcheck. ' + 'Expected an integer, got "{0}".'.format(result[key]) + ) + + if result['test'] == ['NONE']: + # If the user explicitly disables the healthcheck, return None + # as the healthcheck object, and set disable_healthcheck to True + return None, True + + return result, False + + +def omit_none_from_dict(d): + """ + Return a copy of the dictionary with all keys with value None omitted. 
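# --- Illustrative usage sketch; not part of the original diff. ---
# Rough expectations for the healthcheck helpers above, assuming this file is
# importable as ansible.module_utils.docker.common; the healthcheck values are
# invented for the example.
from ansible.module_utils.docker.common import (
    convert_duration_to_nanosecond,
    parse_healthcheck,
)

# '1m30s' is 90 seconds expressed in nanoseconds.
assert convert_duration_to_nanosecond('1m30s') == 90 * 10 ** 9

# A regular healthcheck is normalised: durations become nanoseconds and a
# plain string test is wrapped in ['CMD-SHELL', ...].
healthcheck, disabled = parse_healthcheck({
    'test': 'curl -f http://localhost/',
    'interval': '30s',
    'timeout': '10s',
    'retries': 3,
})
assert healthcheck['test'] == ['CMD-SHELL', 'curl -f http://localhost/']
assert healthcheck['interval'] == 30 * 10 ** 9
assert disabled is False

# test: ['NONE'] explicitly disables the healthcheck defined in the image.
assert parse_healthcheck({'test': ['NONE']}) == (None, True)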
+ """ + return dict((k, v) for (k, v) in d.items() if v is not None) diff --git a/test/support/integration/plugins/module_utils/docker/swarm.py b/test/support/integration/plugins/module_utils/docker/swarm.py new file mode 100644 index 00000000..55d94db0 --- /dev/null +++ b/test/support/integration/plugins/module_utils/docker/swarm.py @@ -0,0 +1,280 @@ +# (c) 2019 Piotr Wojciechowski (@wojciechowskipiotr) <piotr@it-playground.pl> +# (c) Thierry Bouvet (@tbouvet) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + + +import json +from time import sleep + +try: + from docker.errors import APIError, NotFound +except ImportError: + # missing Docker SDK for Python handled in ansible.module_utils.docker.common + pass + +from ansible.module_utils._text import to_native +from ansible.module_utils.docker.common import ( + AnsibleDockerClient, + LooseVersion, +) + + +class AnsibleDockerSwarmClient(AnsibleDockerClient): + + def __init__(self, **kwargs): + super(AnsibleDockerSwarmClient, self).__init__(**kwargs) + + def get_swarm_node_id(self): + """ + Get the 'NodeID' of the Swarm node or 'None' if host is not in Swarm. It returns the NodeID + of Docker host the module is executed on + :return: + NodeID of host or 'None' if not part of Swarm + """ + + try: + info = self.info() + except APIError as exc: + self.fail("Failed to get node information for %s" % to_native(exc)) + + if info: + json_str = json.dumps(info, ensure_ascii=False) + swarm_info = json.loads(json_str) + if swarm_info['Swarm']['NodeID']: + return swarm_info['Swarm']['NodeID'] + return None + + def check_if_swarm_node(self, node_id=None): + """ + Checking if host is part of Docker Swarm. If 'node_id' is not provided it reads the Docker host + system information looking if specific key in output exists. If 'node_id' is provided then it tries to + read node information assuming it is run on Swarm manager. The get_node_inspect() method handles exception if + it is not executed on Swarm manager + + :param node_id: Node identifier + :return: + bool: True if node is part of Swarm, False otherwise + """ + + if node_id is None: + try: + info = self.info() + except APIError: + self.fail("Failed to get host information.") + + if info: + json_str = json.dumps(info, ensure_ascii=False) + swarm_info = json.loads(json_str) + if swarm_info['Swarm']['NodeID']: + return True + if swarm_info['Swarm']['LocalNodeState'] in ('active', 'pending', 'locked'): + return True + return False + else: + try: + node_info = self.get_node_inspect(node_id=node_id) + except APIError: + return + + if node_info['ID'] is not None: + return True + return False + + def check_if_swarm_manager(self): + """ + Checks if node role is set as Manager in Swarm. The node is the docker host on which module action + is performed. The inspect_swarm() will fail if node is not a manager + + :return: True if node is Swarm Manager, False otherwise + """ + + try: + self.inspect_swarm() + return True + except APIError: + return False + + def fail_task_if_not_swarm_manager(self): + """ + If host is not a swarm manager then Ansible task on this host should end with 'failed' state + """ + if not self.check_if_swarm_manager(): + self.fail("Error running docker swarm module: must run on swarm manager node") + + def check_if_swarm_worker(self): + """ + Checks if node role is set as Worker in Swarm. The node is the docker host on which module action + is performed. 
Will fail if run on host that is not part of Swarm via check_if_swarm_node() + + :return: True if node is Swarm Worker, False otherwise + """ + + if self.check_if_swarm_node() and not self.check_if_swarm_manager(): + return True + return False + + def check_if_swarm_node_is_down(self, node_id=None, repeat_check=1): + """ + Checks if node status on Swarm manager is 'down'. If node_id is provided it query manager about + node specified in parameter, otherwise it query manager itself. If run on Swarm Worker node or + host that is not part of Swarm it will fail the playbook + + :param repeat_check: number of check attempts with 5 seconds delay between them, by default check only once + :param node_id: node ID or name, if None then method will try to get node_id of host module run on + :return: + True if node is part of swarm but its state is down, False otherwise + """ + + if repeat_check < 1: + repeat_check = 1 + + if node_id is None: + node_id = self.get_swarm_node_id() + + for retry in range(0, repeat_check): + if retry > 0: + sleep(5) + node_info = self.get_node_inspect(node_id=node_id) + if node_info['Status']['State'] == 'down': + return True + return False + + def get_node_inspect(self, node_id=None, skip_missing=False): + """ + Returns Swarm node info as in 'docker node inspect' command about single node + + :param skip_missing: if True then function will return None instead of failing the task + :param node_id: node ID or name, if None then method will try to get node_id of host module run on + :return: + Single node information structure + """ + + if node_id is None: + node_id = self.get_swarm_node_id() + + if node_id is None: + self.fail("Failed to get node information.") + + try: + node_info = self.inspect_node(node_id=node_id) + except APIError as exc: + if exc.status_code == 503: + self.fail("Cannot inspect node: To inspect node execute module on Swarm Manager") + if exc.status_code == 404: + if skip_missing: + return None + self.fail("Error while reading from Swarm manager: %s" % to_native(exc)) + except Exception as exc: + self.fail("Error inspecting swarm node: %s" % exc) + + json_str = json.dumps(node_info, ensure_ascii=False) + node_info = json.loads(json_str) + + if 'ManagerStatus' in node_info: + if node_info['ManagerStatus'].get('Leader'): + # This is workaround of bug in Docker when in some cases the Leader IP is 0.0.0.0 + # Check moby/moby#35437 for details + count_colons = node_info['ManagerStatus']['Addr'].count(":") + if count_colons == 1: + swarm_leader_ip = node_info['ManagerStatus']['Addr'].split(":", 1)[0] or node_info['Status']['Addr'] + else: + swarm_leader_ip = node_info['Status']['Addr'] + node_info['Status']['Addr'] = swarm_leader_ip + return node_info + + def get_all_nodes_inspect(self): + """ + Returns Swarm node info as in 'docker node inspect' command about all registered nodes + + :return: + Structure with information about all nodes + """ + try: + node_info = self.nodes() + except APIError as exc: + if exc.status_code == 503: + self.fail("Cannot inspect node: To inspect node execute module on Swarm Manager") + self.fail("Error while reading from Swarm manager: %s" % to_native(exc)) + except Exception as exc: + self.fail("Error inspecting swarm node: %s" % exc) + + json_str = json.dumps(node_info, ensure_ascii=False) + node_info = json.loads(json_str) + return node_info + + def get_all_nodes_list(self, output='short'): + """ + Returns list of nodes registered in Swarm + + :param output: Defines format of returned data + :return: + If 'output' is 'short' 
then return data is list of nodes hostnames registered in Swarm, + if 'output' is 'long' then returns data is list of dict containing the attributes as in + output of command 'docker node ls' + """ + nodes_list = [] + + nodes_inspect = self.get_all_nodes_inspect() + if nodes_inspect is None: + return None + + if output == 'short': + for node in nodes_inspect: + nodes_list.append(node['Description']['Hostname']) + elif output == 'long': + for node in nodes_inspect: + node_property = {} + + node_property.update({'ID': node['ID']}) + node_property.update({'Hostname': node['Description']['Hostname']}) + node_property.update({'Status': node['Status']['State']}) + node_property.update({'Availability': node['Spec']['Availability']}) + if 'ManagerStatus' in node: + if node['ManagerStatus']['Leader'] is True: + node_property.update({'Leader': True}) + node_property.update({'ManagerStatus': node['ManagerStatus']['Reachability']}) + node_property.update({'EngineVersion': node['Description']['Engine']['EngineVersion']}) + + nodes_list.append(node_property) + else: + return None + + return nodes_list + + def get_node_name_by_id(self, nodeid): + return self.get_node_inspect(nodeid)['Description']['Hostname'] + + def get_unlock_key(self): + if self.docker_py_version < LooseVersion('2.7.0'): + return None + return super(AnsibleDockerSwarmClient, self).get_unlock_key() + + def get_service_inspect(self, service_id, skip_missing=False): + """ + Returns Swarm service info as in 'docker service inspect' command about single service + + :param service_id: service ID or name + :param skip_missing: if True then function will return None instead of failing the task + :return: + Single service information structure + """ + try: + service_info = self.inspect_service(service_id) + except NotFound as exc: + if skip_missing is False: + self.fail("Error while reading from Swarm manager: %s" % to_native(exc)) + else: + return None + except APIError as exc: + if exc.status_code == 503: + self.fail("Cannot inspect service: To inspect service execute module on Swarm Manager") + self.fail("Error inspecting swarm service: %s" % exc) + except Exception as exc: + self.fail("Error inspecting swarm service: %s" % exc) + + json_str = json.dumps(service_info, ensure_ascii=False) + service_info = json.loads(json_str) + return service_info diff --git a/test/support/integration/plugins/module_utils/ec2.py b/test/support/integration/plugins/module_utils/ec2.py new file mode 100644 index 00000000..0d28108d --- /dev/null +++ b/test/support/integration/plugins/module_utils/ec2.py @@ -0,0 +1,758 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Michael DeHaan <michael.dehaan@gmail.com>, 2012-2013 +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
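# --- Illustrative usage sketch; not part of the original diff. ---
# A hedged sketch of how a module could drive the AnsibleDockerSwarmClient
# defined in swarm.py above. It assumes the AnsibleDockerClient base class
# accepts the usual argument_spec/supports_check_mode keywords and builds the
# underlying AnsibleModule (and the docker connection options) itself, so this
# only makes sense inside a real module run against a Docker host.
from ansible.module_utils.docker.swarm import AnsibleDockerSwarmClient

argument_spec = dict(
    # module-specific options would go here; connection options such as
    # docker_host are assumed to be added by the base class
)
client = AnsibleDockerSwarmClient(argument_spec=argument_spec, supports_check_mode=True)

# Fail the task early unless we are talking to a swarm manager, then list the
# registered nodes in the same "long" shape as 'docker node ls'.
client.fail_task_if_not_swarm_manager()
nodes = client.get_all_nodes_list(output='long')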
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +import os +import re +import sys +import traceback + +from ansible.module_utils.ansible_release import __version__ +from ansible.module_utils.basic import missing_required_lib, env_fallback +from ansible.module_utils._text import to_native, to_text +from ansible.module_utils.cloud import CloudRetry +from ansible.module_utils.six import string_types, binary_type, text_type +from ansible.module_utils.common.dict_transformations import ( + camel_dict_to_snake_dict, snake_dict_to_camel_dict, + _camel_to_snake, _snake_to_camel, +) + +BOTO_IMP_ERR = None +try: + import boto + import boto.ec2 # boto does weird import stuff + HAS_BOTO = True +except ImportError: + BOTO_IMP_ERR = traceback.format_exc() + HAS_BOTO = False + +BOTO3_IMP_ERR = None +try: + import boto3 + import botocore + HAS_BOTO3 = True +except Exception: + BOTO3_IMP_ERR = traceback.format_exc() + HAS_BOTO3 = False + +try: + # Although this is to allow Python 3 the ability to use the custom comparison as a key, Python 2.7 also + # uses this (and it works as expected). Python 2.6 will trigger the ImportError. + from functools import cmp_to_key + PY3_COMPARISON = True +except ImportError: + PY3_COMPARISON = False + + +class AnsibleAWSError(Exception): + pass + + +def _botocore_exception_maybe(): + """ + Allow for boto3 not being installed when using these utils by wrapping + botocore.exceptions instead of assigning from it directly. + """ + if HAS_BOTO3: + return botocore.exceptions.ClientError + return type(None) + + +class AWSRetry(CloudRetry): + base_class = _botocore_exception_maybe() + + @staticmethod + def status_code_from_exception(error): + return error.response['Error']['Code'] + + @staticmethod + def found(response_code, catch_extra_error_codes=None): + # This list of failures is based on this API Reference + # http://docs.aws.amazon.com/AWSEC2/latest/APIReference/errors-overview.html + # + # TooManyRequestsException comes from inside botocore when it + # does retrys, unfortunately however it does not try long + # enough to allow some services such as API Gateway to + # complete configuration. At the moment of writing there is a + # botocore/boto3 bug open to fix this. 
+ # + # https://github.com/boto/boto3/issues/876 (and linked PRs etc) + retry_on = [ + 'RequestLimitExceeded', 'Unavailable', 'ServiceUnavailable', + 'InternalFailure', 'InternalError', 'TooManyRequestsException', + 'Throttling' + ] + if catch_extra_error_codes: + retry_on.extend(catch_extra_error_codes) + + return response_code in retry_on + + +def boto3_conn(module, conn_type=None, resource=None, region=None, endpoint=None, **params): + try: + return _boto3_conn(conn_type=conn_type, resource=resource, region=region, endpoint=endpoint, **params) + except ValueError as e: + module.fail_json(msg="Couldn't connect to AWS: %s" % to_native(e)) + except (botocore.exceptions.ProfileNotFound, botocore.exceptions.PartialCredentialsError, + botocore.exceptions.NoCredentialsError, botocore.exceptions.ConfigParseError) as e: + module.fail_json(msg=to_native(e)) + except botocore.exceptions.NoRegionError as e: + module.fail_json(msg="The %s module requires a region and none was found in configuration, " + "environment variables or module parameters" % module._name) + + +def _boto3_conn(conn_type=None, resource=None, region=None, endpoint=None, **params): + profile = params.pop('profile_name', None) + + if conn_type not in ['both', 'resource', 'client']: + raise ValueError('There is an issue in the calling code. You ' + 'must specify either both, resource, or client to ' + 'the conn_type parameter in the boto3_conn function ' + 'call') + + config = botocore.config.Config( + user_agent_extra='Ansible/{0}'.format(__version__), + ) + + if params.get('config') is not None: + config = config.merge(params.pop('config')) + if params.get('aws_config') is not None: + config = config.merge(params.pop('aws_config')) + + session = boto3.session.Session( + profile_name=profile, + ) + + if conn_type == 'resource': + return session.resource(resource, config=config, region_name=region, endpoint_url=endpoint, **params) + elif conn_type == 'client': + return session.client(resource, config=config, region_name=region, endpoint_url=endpoint, **params) + else: + client = session.client(resource, region_name=region, endpoint_url=endpoint, **params) + resource = session.resource(resource, region_name=region, endpoint_url=endpoint, **params) + return client, resource + + +boto3_inventory_conn = _boto3_conn + + +def boto_exception(err): + """ + Extracts the error message from a boto exception. 
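# --- Illustrative usage sketch; not part of the original diff. ---
# AWSRetry is meant to wrap throttling-prone boto3 calls in a retry decorator.
# This assumes the CloudRetry base class provides the usual jittered_backoff()
# classmethod; 'describe_vpcs_with_backoff' is a made-up helper name, and the
# client argument would typically come from boto3_conn() above.
from ansible.module_utils.ec2 import AWSRetry


@AWSRetry.jittered_backoff(retries=10)
def describe_vpcs_with_backoff(client):
    # Retried automatically when AWS answers with one of the error codes
    # listed in AWSRetry.found() above (Throttling, RequestLimitExceeded, ...).
    return client.describe_vpcs()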
+ + :param err: Exception from boto + :return: Error message + """ + if hasattr(err, 'error_message'): + error = err.error_message + elif hasattr(err, 'message'): + error = str(err.message) + ' ' + str(err) + ' - ' + str(type(err)) + else: + error = '%s: %s' % (Exception, err) + + return error + + +def aws_common_argument_spec(): + return dict( + debug_botocore_endpoint_logs=dict(fallback=(env_fallback, ['ANSIBLE_DEBUG_BOTOCORE_LOGS']), default=False, type='bool'), + ec2_url=dict(), + aws_secret_key=dict(aliases=['ec2_secret_key', 'secret_key'], no_log=True), + aws_access_key=dict(aliases=['ec2_access_key', 'access_key']), + validate_certs=dict(default=True, type='bool'), + security_token=dict(aliases=['access_token'], no_log=True), + profile=dict(), + aws_config=dict(type='dict'), + ) + + +def ec2_argument_spec(): + spec = aws_common_argument_spec() + spec.update( + dict( + region=dict(aliases=['aws_region', 'ec2_region']), + ) + ) + return spec + + +def get_aws_region(module, boto3=False): + region = module.params.get('region') + + if region: + return region + + if 'AWS_REGION' in os.environ: + return os.environ['AWS_REGION'] + if 'AWS_DEFAULT_REGION' in os.environ: + return os.environ['AWS_DEFAULT_REGION'] + if 'EC2_REGION' in os.environ: + return os.environ['EC2_REGION'] + + if not boto3: + if not HAS_BOTO: + module.fail_json(msg=missing_required_lib('boto'), exception=BOTO_IMP_ERR) + # boto.config.get returns None if config not found + region = boto.config.get('Boto', 'aws_region') + if region: + return region + return boto.config.get('Boto', 'ec2_region') + + if not HAS_BOTO3: + module.fail_json(msg=missing_required_lib('boto3'), exception=BOTO3_IMP_ERR) + + # here we don't need to make an additional call, will default to 'us-east-1' if the below evaluates to None. 
+ try: + profile_name = module.params.get('profile') + return botocore.session.Session(profile=profile_name).get_config_variable('region') + except botocore.exceptions.ProfileNotFound as e: + return None + + +def get_aws_connection_info(module, boto3=False): + + # Check module args for credentials, then check environment vars + # access_key + + ec2_url = module.params.get('ec2_url') + access_key = module.params.get('aws_access_key') + secret_key = module.params.get('aws_secret_key') + security_token = module.params.get('security_token') + region = get_aws_region(module, boto3) + profile_name = module.params.get('profile') + validate_certs = module.params.get('validate_certs') + config = module.params.get('aws_config') + + if not ec2_url: + if 'AWS_URL' in os.environ: + ec2_url = os.environ['AWS_URL'] + elif 'EC2_URL' in os.environ: + ec2_url = os.environ['EC2_URL'] + + if not access_key: + if os.environ.get('AWS_ACCESS_KEY_ID'): + access_key = os.environ['AWS_ACCESS_KEY_ID'] + elif os.environ.get('AWS_ACCESS_KEY'): + access_key = os.environ['AWS_ACCESS_KEY'] + elif os.environ.get('EC2_ACCESS_KEY'): + access_key = os.environ['EC2_ACCESS_KEY'] + elif HAS_BOTO and boto.config.get('Credentials', 'aws_access_key_id'): + access_key = boto.config.get('Credentials', 'aws_access_key_id') + elif HAS_BOTO and boto.config.get('default', 'aws_access_key_id'): + access_key = boto.config.get('default', 'aws_access_key_id') + else: + # in case access_key came in as empty string + access_key = None + + if not secret_key: + if os.environ.get('AWS_SECRET_ACCESS_KEY'): + secret_key = os.environ['AWS_SECRET_ACCESS_KEY'] + elif os.environ.get('AWS_SECRET_KEY'): + secret_key = os.environ['AWS_SECRET_KEY'] + elif os.environ.get('EC2_SECRET_KEY'): + secret_key = os.environ['EC2_SECRET_KEY'] + elif HAS_BOTO and boto.config.get('Credentials', 'aws_secret_access_key'): + secret_key = boto.config.get('Credentials', 'aws_secret_access_key') + elif HAS_BOTO and boto.config.get('default', 'aws_secret_access_key'): + secret_key = boto.config.get('default', 'aws_secret_access_key') + else: + # in case secret_key came in as empty string + secret_key = None + + if not security_token: + if os.environ.get('AWS_SECURITY_TOKEN'): + security_token = os.environ['AWS_SECURITY_TOKEN'] + elif os.environ.get('AWS_SESSION_TOKEN'): + security_token = os.environ['AWS_SESSION_TOKEN'] + elif os.environ.get('EC2_SECURITY_TOKEN'): + security_token = os.environ['EC2_SECURITY_TOKEN'] + elif HAS_BOTO and boto.config.get('Credentials', 'aws_security_token'): + security_token = boto.config.get('Credentials', 'aws_security_token') + elif HAS_BOTO and boto.config.get('default', 'aws_security_token'): + security_token = boto.config.get('default', 'aws_security_token') + else: + # in case secret_token came in as empty string + security_token = None + + if HAS_BOTO3 and boto3: + boto_params = dict(aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + aws_session_token=security_token) + boto_params['verify'] = validate_certs + + if profile_name: + boto_params = dict(aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) + boto_params['profile_name'] = profile_name + + else: + boto_params = dict(aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + security_token=security_token) + + # only set profile_name if passed as an argument + if profile_name: + boto_params['profile_name'] = profile_name + + boto_params['validate_certs'] = validate_certs + + if config is not None: + if HAS_BOTO3 and boto3: + 
boto_params['aws_config'] = botocore.config.Config(**config) + elif HAS_BOTO and not boto3: + if 'user_agent' in config: + sys.modules["boto.connection"].UserAgent = config['user_agent'] + + for param, value in boto_params.items(): + if isinstance(value, binary_type): + boto_params[param] = text_type(value, 'utf-8', 'strict') + + return region, ec2_url, boto_params + + +def get_ec2_creds(module): + ''' for compatibility mode with old modules that don't/can't yet + use ec2_connect method ''' + region, ec2_url, boto_params = get_aws_connection_info(module) + return ec2_url, boto_params['aws_access_key_id'], boto_params['aws_secret_access_key'], region + + +def boto_fix_security_token_in_profile(conn, profile_name): + ''' monkey patch for boto issue boto/boto#2100 ''' + profile = 'profile ' + profile_name + if boto.config.has_option(profile, 'aws_security_token'): + conn.provider.set_security_token(boto.config.get(profile, 'aws_security_token')) + return conn + + +def connect_to_aws(aws_module, region, **params): + try: + conn = aws_module.connect_to_region(region, **params) + except(boto.provider.ProfileNotFoundError): + raise AnsibleAWSError("Profile given for AWS was not found. Please fix and retry.") + if not conn: + if region not in [aws_module_region.name for aws_module_region in aws_module.regions()]: + raise AnsibleAWSError("Region %s does not seem to be available for aws module %s. If the region definitely exists, you may need to upgrade " + "boto or extend with endpoints_path" % (region, aws_module.__name__)) + else: + raise AnsibleAWSError("Unknown problem connecting to region %s for aws module %s." % (region, aws_module.__name__)) + if params.get('profile_name'): + conn = boto_fix_security_token_in_profile(conn, params['profile_name']) + return conn + + +def ec2_connect(module): + + """ Return an ec2 connection""" + + region, ec2_url, boto_params = get_aws_connection_info(module) + + # If we have a region specified, connect to its endpoint. + if region: + try: + ec2 = connect_to_aws(boto.ec2, region, **boto_params) + except (boto.exception.NoAuthHandlerFound, AnsibleAWSError, boto.provider.ProfileNotFoundError) as e: + module.fail_json(msg=str(e)) + # Otherwise, no region so we fallback to the old connection method + elif ec2_url: + try: + ec2 = boto.connect_ec2_endpoint(ec2_url, **boto_params) + except (boto.exception.NoAuthHandlerFound, AnsibleAWSError, boto.provider.ProfileNotFoundError) as e: + module.fail_json(msg=str(e)) + else: + module.fail_json(msg="Either region or ec2_url must be specified") + + return ec2 + + +def ansible_dict_to_boto3_filter_list(filters_dict): + + """ Convert an Ansible dict of filters to list of dicts that boto3 can use + Args: + filters_dict (dict): Dict of AWS filters. + Basic Usage: + >>> filters = {'some-aws-id': 'i-01234567'} + >>> ansible_dict_to_boto3_filter_list(filters) + { + 'some-aws-id': 'i-01234567' + } + Returns: + List: List of AWS filters and their values + [ + { + 'Name': 'some-aws-id', + 'Values': [ + 'i-01234567', + ] + } + ] + """ + + filters_list = [] + for k, v in filters_dict.items(): + filter_dict = {'Name': k} + if isinstance(v, string_types): + filter_dict['Values'] = [v] + else: + filter_dict['Values'] = v + + filters_list.append(filter_dict) + + return filters_list + + +def boto3_tag_list_to_ansible_dict(tags_list, tag_name_key_name=None, tag_value_key_name=None): + + """ Convert a boto3 list of resource tags to a flat dict of key:value pairs + Args: + tags_list (list): List of dicts representing AWS tags. 
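# --- Illustrative usage sketch; not part of the original diff. ---
# The usual shape of an AWS module built on the helpers above: extend
# ec2_argument_spec(), let get_aws_connection_info() resolve credentials from
# parameters and environment, open a boto3 client with boto3_conn(), and turn
# a plain dict of filters into the boto3 form. The 'instance_ids' option and
# the filter values are invented, and this only runs inside a real module
# invocation with AWS credentials available.
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.ec2 import (
    ansible_dict_to_boto3_filter_list,
    boto3_conn,
    ec2_argument_spec,
    get_aws_connection_info,
)

argument_spec = ec2_argument_spec()
argument_spec.update(dict(instance_ids=dict(type='list', elements='str', default=[])))
module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=True)

region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)
client = boto3_conn(module, conn_type='client', resource='ec2',
                    region=region, endpoint=ec2_url, **aws_connect_kwargs)

filters = ansible_dict_to_boto3_filter_list({'instance-state-name': 'running'})
reservations = client.describe_instances(Filters=filters)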
+ tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key") + tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value") + Basic Usage: + >>> tags_list = [{'Key': 'MyTagKey', 'Value': 'MyTagValue'}] + >>> boto3_tag_list_to_ansible_dict(tags_list) + [ + { + 'Key': 'MyTagKey', + 'Value': 'MyTagValue' + } + ] + Returns: + Dict: Dict of key:value pairs representing AWS tags + { + 'MyTagKey': 'MyTagValue', + } + """ + + if tag_name_key_name and tag_value_key_name: + tag_candidates = {tag_name_key_name: tag_value_key_name} + else: + tag_candidates = {'key': 'value', 'Key': 'Value'} + + if not tags_list: + return {} + for k, v in tag_candidates.items(): + if k in tags_list[0] and v in tags_list[0]: + return dict((tag[k], tag[v]) for tag in tags_list) + raise ValueError("Couldn't find tag key (candidates %s) in tag list %s" % (str(tag_candidates), str(tags_list))) + + +def ansible_dict_to_boto3_tag_list(tags_dict, tag_name_key_name='Key', tag_value_key_name='Value'): + + """ Convert a flat dict of key:value pairs representing AWS resource tags to a boto3 list of dicts + Args: + tags_dict (dict): Dict representing AWS resource tags. + tag_name_key_name (str): Value to use as the key for all tag keys (useful because boto3 doesn't always use "Key") + tag_value_key_name (str): Value to use as the key for all tag values (useful because boto3 doesn't always use "Value") + Basic Usage: + >>> tags_dict = {'MyTagKey': 'MyTagValue'} + >>> ansible_dict_to_boto3_tag_list(tags_dict) + { + 'MyTagKey': 'MyTagValue' + } + Returns: + List: List of dicts containing tag keys and values + [ + { + 'Key': 'MyTagKey', + 'Value': 'MyTagValue' + } + ] + """ + + tags_list = [] + for k, v in tags_dict.items(): + tags_list.append({tag_name_key_name: k, tag_value_key_name: to_native(v)}) + + return tags_list + + +def get_ec2_security_group_ids_from_names(sec_group_list, ec2_connection, vpc_id=None, boto3=True): + + """ Return list of security group IDs from security group names. Note that security group names are not unique + across VPCs. If a name exists across multiple VPCs and no VPC ID is supplied, all matching IDs will be returned. 
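# --- Illustrative usage sketch; not part of the original diff. ---
# The two tag helpers above are inverses of each other for the usual Key/Value
# shape; as described above, the first returns a flat dict. Assumes the file
# is importable as ansible.module_utils.ec2; the tag values are invented.
from ansible.module_utils.ec2 import (
    ansible_dict_to_boto3_tag_list,
    boto3_tag_list_to_ansible_dict,
)

boto3_tags = [{'Key': 'Name', 'Value': 'web-01'}, {'Key': 'Env', 'Value': 'prod'}]
assert boto3_tag_list_to_ansible_dict(boto3_tags) == {'Name': 'web-01', 'Env': 'prod'}
assert ansible_dict_to_boto3_tag_list({'Env': 'prod'}) == [{'Key': 'Env', 'Value': 'prod'}]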
This + will probably lead to a boto exception if you attempt to assign both IDs to a resource so ensure you wrap the call in + a try block + """ + + def get_sg_name(sg, boto3): + + if boto3: + return sg['GroupName'] + else: + return sg.name + + def get_sg_id(sg, boto3): + + if boto3: + return sg['GroupId'] + else: + return sg.id + + sec_group_id_list = [] + + if isinstance(sec_group_list, string_types): + sec_group_list = [sec_group_list] + + # Get all security groups + if boto3: + if vpc_id: + filters = [ + { + 'Name': 'vpc-id', + 'Values': [ + vpc_id, + ] + } + ] + all_sec_groups = ec2_connection.describe_security_groups(Filters=filters)['SecurityGroups'] + else: + all_sec_groups = ec2_connection.describe_security_groups()['SecurityGroups'] + else: + if vpc_id: + filters = {'vpc-id': vpc_id} + all_sec_groups = ec2_connection.get_all_security_groups(filters=filters) + else: + all_sec_groups = ec2_connection.get_all_security_groups() + + unmatched = set(sec_group_list).difference(str(get_sg_name(all_sg, boto3)) for all_sg in all_sec_groups) + sec_group_name_list = list(set(sec_group_list) - set(unmatched)) + + if len(unmatched) > 0: + # If we have unmatched names that look like an ID, assume they are + import re + sec_group_id_list = [sg for sg in unmatched if re.match('sg-[a-fA-F0-9]+$', sg)] + still_unmatched = [sg for sg in unmatched if not re.match('sg-[a-fA-F0-9]+$', sg)] + if len(still_unmatched) > 0: + raise ValueError("The following group names are not valid: %s" % ', '.join(still_unmatched)) + + sec_group_id_list += [str(get_sg_id(all_sg, boto3)) for all_sg in all_sec_groups if str(get_sg_name(all_sg, boto3)) in sec_group_name_list] + + return sec_group_id_list + + +def _hashable_policy(policy, policy_list): + """ + Takes a policy and returns a list, the contents of which are all hashable and sorted. 
+ Example input policy: + {'Version': '2012-10-17', + 'Statement': [{'Action': 's3:PutObjectAcl', + 'Sid': 'AddCannedAcl2', + 'Resource': 'arn:aws:s3:::test_policy/*', + 'Effect': 'Allow', + 'Principal': {'AWS': ['arn:aws:iam::XXXXXXXXXXXX:user/username1', 'arn:aws:iam::XXXXXXXXXXXX:user/username2']} + }]} + Returned value: + [('Statement', ((('Action', (u's3:PutObjectAcl',)), + ('Effect', (u'Allow',)), + ('Principal', ('AWS', ((u'arn:aws:iam::XXXXXXXXXXXX:user/username1',), (u'arn:aws:iam::XXXXXXXXXXXX:user/username2',)))), + ('Resource', (u'arn:aws:s3:::test_policy/*',)), ('Sid', (u'AddCannedAcl2',)))), + ('Version', (u'2012-10-17',)))] + + """ + # Amazon will automatically convert bool and int to strings for us + if isinstance(policy, bool): + return tuple([str(policy).lower()]) + elif isinstance(policy, int): + return tuple([str(policy)]) + + if isinstance(policy, list): + for each in policy: + tupleified = _hashable_policy(each, []) + if isinstance(tupleified, list): + tupleified = tuple(tupleified) + policy_list.append(tupleified) + elif isinstance(policy, string_types) or isinstance(policy, binary_type): + policy = to_text(policy) + # convert root account ARNs to just account IDs + if policy.startswith('arn:aws:iam::') and policy.endswith(':root'): + policy = policy.split(':')[4] + return [policy] + elif isinstance(policy, dict): + sorted_keys = list(policy.keys()) + sorted_keys.sort() + for key in sorted_keys: + tupleified = _hashable_policy(policy[key], []) + if isinstance(tupleified, list): + tupleified = tuple(tupleified) + policy_list.append((key, tupleified)) + + # ensure we aren't returning deeply nested structures of length 1 + if len(policy_list) == 1 and isinstance(policy_list[0], tuple): + policy_list = policy_list[0] + if isinstance(policy_list, list): + if PY3_COMPARISON: + policy_list.sort(key=cmp_to_key(py3cmp)) + else: + policy_list.sort() + return policy_list + + +def py3cmp(a, b): + """ Python 2 can sort lists of mixed types. Strings < tuples. Without this function this fails on Python 3.""" + try: + if a > b: + return 1 + elif a < b: + return -1 + else: + return 0 + except TypeError as e: + # check to see if they're tuple-string + # always say strings are less than tuples (to maintain compatibility with python2) + str_ind = to_text(e).find('str') + tup_ind = to_text(e).find('tuple') + if -1 not in (str_ind, tup_ind): + if str_ind < tup_ind: + return -1 + elif tup_ind < str_ind: + return 1 + raise + + +def compare_policies(current_policy, new_policy): + """ Compares the existing policy and the updated policy + Returns True if there is a difference between policies. + """ + return set(_hashable_policy(new_policy, [])) != set(_hashable_policy(current_policy, [])) + + +def sort_json_policy_dict(policy_dict): + + """ Sort any lists in an IAM JSON policy so that comparison of two policies with identical values but + different orders will return true + Args: + policy_dict (dict): Dict representing IAM JSON policy. 
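# --- Illustrative usage sketch; not part of the original diff. ---
# compare_policies() reports whether two IAM policy documents differ once key
# order and list order have been normalised away by _hashable_policy().
# Assumes the file is importable as ansible.module_utils.ec2; the policies are
# invented for the example.
from ansible.module_utils.ec2 import compare_policies

policy_a = {
    'Version': '2012-10-17',
    'Statement': [{'Effect': 'Allow',
                   'Action': ['s3:GetObject', 's3:PutObject'],
                   'Resource': '*'}],
}
# The same document with the dict keys and the Action list reordered.
policy_b = {
    'Statement': [{'Action': ['s3:PutObject', 's3:GetObject'],
                   'Resource': '*',
                   'Effect': 'Allow'}],
    'Version': '2012-10-17',
}
assert compare_policies(policy_a, policy_b) is False  # no difference
assert compare_policies(policy_a, {'Version': '2012-10-17'}) is True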
+ Basic Usage: + >>> my_iam_policy = {'Principle': {'AWS':["31","7","14","101"]} + >>> sort_json_policy_dict(my_iam_policy) + Returns: + Dict: Will return a copy of the policy as a Dict but any List will be sorted + { + 'Principle': { + 'AWS': [ '7', '14', '31', '101' ] + } + } + """ + + def value_is_list(my_list): + + checked_list = [] + for item in my_list: + if isinstance(item, dict): + checked_list.append(sort_json_policy_dict(item)) + elif isinstance(item, list): + checked_list.append(value_is_list(item)) + else: + checked_list.append(item) + + # Sort list. If it's a list of dictionaries, sort by tuple of key-value + # pairs, since Python 3 doesn't allow comparisons such as `<` between dictionaries. + checked_list.sort(key=lambda x: sorted(x.items()) if isinstance(x, dict) else x) + return checked_list + + ordered_policy_dict = {} + for key, value in policy_dict.items(): + if isinstance(value, dict): + ordered_policy_dict[key] = sort_json_policy_dict(value) + elif isinstance(value, list): + ordered_policy_dict[key] = value_is_list(value) + else: + ordered_policy_dict[key] = value + + return ordered_policy_dict + + +def map_complex_type(complex_type, type_map): + """ + Allows to cast elements within a dictionary to a specific type + Example of usage: + + DEPLOYMENT_CONFIGURATION_TYPE_MAP = { + 'maximum_percent': 'int', + 'minimum_healthy_percent': 'int' + } + + deployment_configuration = map_complex_type(module.params['deployment_configuration'], + DEPLOYMENT_CONFIGURATION_TYPE_MAP) + + This ensures all keys within the root element are casted and valid integers + """ + + if complex_type is None: + return + new_type = type(complex_type)() + if isinstance(complex_type, dict): + for key in complex_type: + if key in type_map: + if isinstance(type_map[key], list): + new_type[key] = map_complex_type( + complex_type[key], + type_map[key][0]) + else: + new_type[key] = map_complex_type( + complex_type[key], + type_map[key]) + else: + return complex_type + elif isinstance(complex_type, list): + for i in range(len(complex_type)): + new_type.append(map_complex_type( + complex_type[i], + type_map)) + elif type_map: + return globals()['__builtins__'][type_map](complex_type) + return new_type + + +def compare_aws_tags(current_tags_dict, new_tags_dict, purge_tags=True): + """ + Compare two dicts of AWS tags. Dicts are expected to of been created using 'boto3_tag_list_to_ansible_dict' helper function. + Two dicts are returned - the first is tags to be set, the second is any tags to remove. Since the AWS APIs differ + these may not be able to be used out of the box. + + :param current_tags_dict: + :param new_tags_dict: + :param purge_tags: + :return: tag_key_value_pairs_to_set: a dict of key value pairs that need to be set in AWS. If all tags are identical this dict will be empty + :return: tag_keys_to_unset: a list of key names (type str) that need to be unset in AWS. 
If no tags need to be unset this list will be empty + """ + + tag_key_value_pairs_to_set = {} + tag_keys_to_unset = [] + + for key in current_tags_dict.keys(): + if key not in new_tags_dict and purge_tags: + tag_keys_to_unset.append(key) + + for key in set(new_tags_dict.keys()) - set(tag_keys_to_unset): + if to_text(new_tags_dict[key]) != current_tags_dict.get(key): + tag_key_value_pairs_to_set[key] = new_tags_dict[key] + + return tag_key_value_pairs_to_set, tag_keys_to_unset diff --git a/test/support/integration/plugins/module_utils/ecs/__init__.py b/test/support/integration/plugins/module_utils/ecs/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/ecs/__init__.py diff --git a/test/support/integration/plugins/module_utils/ecs/api.py b/test/support/integration/plugins/module_utils/ecs/api.py new file mode 100644 index 00000000..d89b0333 --- /dev/null +++ b/test/support/integration/plugins/module_utils/ecs/api.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- + +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is licensed under the +# Modified BSD License. Modules you write using this snippet, which is embedded +# dynamically by Ansible, still belong to the author of the module, and may assign +# their own license to the complete work. +# +# Copyright (c), Entrust Datacard Corporation, 2019 +# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause) + +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
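# --- Illustrative usage sketch; not part of the original diff. ---
# Quick expectations for two more helpers from the ec2.py support file above,
# assuming it is importable as ansible.module_utils.ec2. The tag keys/values
# and the type map are invented for the example.
from ansible.module_utils.ec2 import compare_aws_tags, map_complex_type

# compare_aws_tags() splits the work into tags to (re)set and tag keys to
# remove; with purge_tags=False the extra 'Owner' tag would simply be kept.
current_tags = {'Name': 'web-01', 'Env': 'dev', 'Owner': 'ops'}
desired_tags = {'Name': 'web-01', 'Env': 'prod'}
to_set, to_unset = compare_aws_tags(current_tags, desired_tags, purge_tags=True)
assert to_set == {'Env': 'prod'}
assert to_unset == ['Owner']

# map_complex_type() casts string values coming from YAML into the concrete
# types an API expects, following the example in its docstring.
type_map = {'maximum_percent': 'int', 'minimum_healthy_percent': 'int'}
raw = {'maximum_percent': '200', 'minimum_healthy_percent': '50'}
assert map_complex_type(raw, type_map) == {'maximum_percent': 200, 'minimum_healthy_percent': 50}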
+ +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import json +import os +import re +import time +import traceback + +from ansible.module_utils._text import to_text, to_native +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.six.moves.urllib.parse import urlencode +from ansible.module_utils.six.moves.urllib.error import HTTPError +from ansible.module_utils.urls import Request + +YAML_IMP_ERR = None +try: + import yaml +except ImportError: + YAML_FOUND = False + YAML_IMP_ERR = traceback.format_exc() +else: + YAML_FOUND = True + +valid_file_format = re.compile(r".*(\.)(yml|yaml|json)$") + + +def ecs_client_argument_spec(): + return dict( + entrust_api_user=dict(type='str', required=True), + entrust_api_key=dict(type='str', required=True, no_log=True), + entrust_api_client_cert_path=dict(type='path', required=True), + entrust_api_client_cert_key_path=dict(type='path', required=True, no_log=True), + entrust_api_specification_path=dict(type='path', default='https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml'), + ) + + +class SessionConfigurationException(Exception): + """ Raised if we cannot configure a session with the API """ + + pass + + +class RestOperationException(Exception): + """ Encapsulate a REST API error """ + + def __init__(self, error): + self.status = to_native(error.get("status", None)) + self.errors = [to_native(err.get("message")) for err in error.get("errors", {})] + self.message = to_native(" ".join(self.errors)) + + +def generate_docstring(operation_spec): + """Generate a docstring for an operation defined in operation_spec (swagger)""" + # Description of the operation + docs = operation_spec.get("description", "No Description") + docs += "\n\n" + + # Parameters of the operation + parameters = operation_spec.get("parameters", []) + if len(parameters) != 0: + docs += "\tArguments:\n\n" + for parameter in parameters: + docs += "{0} ({1}:{2}): {3}\n".format( + parameter.get("name"), + parameter.get("type", "No Type"), + "Required" if parameter.get("required", False) else "Not Required", + parameter.get("description"), + ) + + return docs + + +def bind(instance, method, operation_spec): + def binding_scope_fn(*args, **kwargs): + return method(instance, *args, **kwargs) + + # Make sure we don't confuse users; add the proper name and documentation to the function. 
+ # Users can use !help(<function>) to get help on the function from interactive python or pdb + operation_name = operation_spec.get("operationId").split("Using")[0] + binding_scope_fn.__name__ = str(operation_name) + binding_scope_fn.__doc__ = generate_docstring(operation_spec) + + return binding_scope_fn + + +class RestOperation(object): + def __init__(self, session, uri, method, parameters=None): + self.session = session + self.method = method + if parameters is None: + self.parameters = {} + else: + self.parameters = parameters + self.url = "{scheme}://{host}{base_path}{uri}".format(scheme="https", host=session._spec.get("host"), base_path=session._spec.get("basePath"), uri=uri) + + def restmethod(self, *args, **kwargs): + """Do the hard work of making the request here""" + + # gather named path parameters and do substitution on the URL + if self.parameters: + path_parameters = {} + body_parameters = {} + query_parameters = {} + for x in self.parameters: + expected_location = x.get("in") + key_name = x.get("name", None) + key_value = kwargs.get(key_name, None) + if expected_location == "path" and key_name and key_value: + path_parameters.update({key_name: key_value}) + elif expected_location == "body" and key_name and key_value: + body_parameters.update({key_name: key_value}) + elif expected_location == "query" and key_name and key_value: + query_parameters.update({key_name: key_value}) + + if len(body_parameters.keys()) >= 1: + body_parameters = body_parameters.get(list(body_parameters.keys())[0]) + else: + body_parameters = None + else: + path_parameters = {} + query_parameters = {} + body_parameters = None + + # This will fail if we have not set path parameters with a KeyError + url = self.url.format(**path_parameters) + if query_parameters: + # modify the URL to add path parameters + url = url + "?" + urlencode(query_parameters) + + try: + if body_parameters: + body_parameters_json = json.dumps(body_parameters) + response = self.session.request.open(method=self.method, url=url, data=body_parameters_json) + else: + response = self.session.request.open(method=self.method, url=url) + request_error = False + except HTTPError as e: + # An HTTPError has the same methods available as a valid response from request.open + response = e + request_error = True + + # Return the result if JSON and success ({} for empty responses) + # Raise an exception if there was a failure. + try: + result_code = response.getcode() + result = json.loads(response.read()) + except ValueError: + result = {} + + if result or result == {}: + if result_code and result_code < 400: + return result + else: + raise RestOperationException(result) + + # Raise a generic RestOperationException if this fails + raise RestOperationException({"status": result_code, "errors": [{"message": "REST Operation Failed"}]}) + + +class Resource(object): + """ Implement basic CRUD operations against a path. 
""" + + def __init__(self, session): + self.session = session + self.parameters = {} + + for url in session._spec.get("paths").keys(): + methods = session._spec.get("paths").get(url) + for method in methods.keys(): + operation_spec = methods.get(method) + operation_name = operation_spec.get("operationId", None) + parameters = operation_spec.get("parameters") + + if not operation_name: + if method.lower() == "post": + operation_name = "Create" + elif method.lower() == "get": + operation_name = "Get" + elif method.lower() == "put": + operation_name = "Update" + elif method.lower() == "delete": + operation_name = "Delete" + elif method.lower() == "patch": + operation_name = "Patch" + else: + raise SessionConfigurationException(to_native("Invalid REST method type {0}".format(method))) + + # Get the non-parameter parts of the URL and append to the operation name + # e.g /application/version -> GetApplicationVersion + # e.g. /application/{id} -> GetApplication + # This may lead to duplicates, which we must prevent. + operation_name += re.sub(r"{(.*)}", "", url).replace("/", " ").title().replace(" ", "") + operation_spec["operationId"] = operation_name + + op = RestOperation(session, url, method, parameters) + setattr(self, operation_name, bind(self, op.restmethod, operation_spec)) + + +# Session to encapsulate the connection parameters of the module_utils Request object, the api spec, etc +class ECSSession(object): + def __init__(self, name, **kwargs): + """ + Initialize our session + """ + + self._set_config(name, **kwargs) + + def client(self): + resource = Resource(self) + return resource + + def _set_config(self, name, **kwargs): + headers = { + "Content-Type": "application/json", + "Connection": "keep-alive", + } + self.request = Request(headers=headers, timeout=60) + + configurators = [self._read_config_vars] + for configurator in configurators: + self._config = configurator(name, **kwargs) + if self._config: + break + if self._config is None: + raise SessionConfigurationException(to_native("No Configuration Found.")) + + # set up auth if passed + entrust_api_user = self.get_config("entrust_api_user") + entrust_api_key = self.get_config("entrust_api_key") + if entrust_api_user and entrust_api_key: + self.request.url_username = entrust_api_user + self.request.url_password = entrust_api_key + else: + raise SessionConfigurationException(to_native("User and key must be provided.")) + + # set up client certificate if passed (support all-in one or cert + key) + entrust_api_cert = self.get_config("entrust_api_cert") + entrust_api_cert_key = self.get_config("entrust_api_cert_key") + if entrust_api_cert: + self.request.client_cert = entrust_api_cert + if entrust_api_cert_key: + self.request.client_key = entrust_api_cert_key + else: + raise SessionConfigurationException(to_native("Client certificate for authentication to the API must be provided.")) + + # set up the spec + entrust_api_specification_path = self.get_config("entrust_api_specification_path") + + if not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path): + raise SessionConfigurationException(to_native("OpenAPI specification was not found at location {0}.".format(entrust_api_specification_path))) + if not valid_file_format.match(entrust_api_specification_path): + raise SessionConfigurationException(to_native("OpenAPI specification filename must end in .json, .yml or .yaml")) + + self.verify = True + + if entrust_api_specification_path.startswith("http"): + try: + http_response = 
Request().open(method="GET", url=entrust_api_specification_path) + http_response_contents = http_response.read() + if entrust_api_specification_path.endswith(".json"): + self._spec = json.load(http_response_contents) + elif entrust_api_specification_path.endswith(".yml") or entrust_api_specification_path.endswith(".yaml"): + self._spec = yaml.safe_load(http_response_contents) + except HTTPError as e: + raise SessionConfigurationException(to_native("Error downloading specification from address '{0}', received error code '{1}'".format( + entrust_api_specification_path, e.getcode()))) + else: + with open(entrust_api_specification_path) as f: + if ".json" in entrust_api_specification_path: + self._spec = json.load(f) + elif ".yml" in entrust_api_specification_path or ".yaml" in entrust_api_specification_path: + self._spec = yaml.safe_load(f) + + def get_config(self, item): + return self._config.get(item, None) + + def _read_config_vars(self, name, **kwargs): + """ Read configuration from variables passed to the module. """ + config = {} + + entrust_api_specification_path = kwargs.get("entrust_api_specification_path") + if not entrust_api_specification_path or (not entrust_api_specification_path.startswith("http") and not os.path.isfile(entrust_api_specification_path)): + raise SessionConfigurationException( + to_native( + "Parameter provided for entrust_api_specification_path of value '{0}' was not a valid file path or HTTPS address.".format( + entrust_api_specification_path + ) + ) + ) + + for required_file in ["entrust_api_cert", "entrust_api_cert_key"]: + file_path = kwargs.get(required_file) + if not file_path or not os.path.isfile(file_path): + raise SessionConfigurationException( + to_native("Parameter provided for {0} of value '{1}' was not a valid file path.".format(required_file, file_path)) + ) + + for required_var in ["entrust_api_user", "entrust_api_key"]: + if not kwargs.get(required_var): + raise SessionConfigurationException(to_native("Parameter provided for {0} was missing.".format(required_var))) + + config["entrust_api_cert"] = kwargs.get("entrust_api_cert") + config["entrust_api_cert_key"] = kwargs.get("entrust_api_cert_key") + config["entrust_api_specification_path"] = kwargs.get("entrust_api_specification_path") + config["entrust_api_user"] = kwargs.get("entrust_api_user") + config["entrust_api_key"] = kwargs.get("entrust_api_key") + + return config + + +def ECSClient(entrust_api_user=None, entrust_api_key=None, entrust_api_cert=None, entrust_api_cert_key=None, entrust_api_specification_path=None): + """Create an ECS client""" + + if not YAML_FOUND: + raise SessionConfigurationException(missing_required_lib("PyYAML"), exception=YAML_IMP_ERR) + + if entrust_api_specification_path is None: + entrust_api_specification_path = "https://cloud.entrust.net/EntrustCloud/documentation/cms-api-2.1.0.yaml" + + # Not functionally necessary with current uses of this module_util, but better to be explicit for future use cases + entrust_api_user = to_text(entrust_api_user) + entrust_api_key = to_text(entrust_api_key) + entrust_api_cert_key = to_text(entrust_api_cert_key) + entrust_api_specification_path = to_text(entrust_api_specification_path) + + return ECSSession( + "ecs", + entrust_api_user=entrust_api_user, + entrust_api_key=entrust_api_key, + entrust_api_cert=entrust_api_cert, + entrust_api_cert_key=entrust_api_cert_key, + entrust_api_specification_path=entrust_api_specification_path, + ).client() diff --git a/test/support/integration/plugins/module_utils/mysql.py 
b/test/support/integration/plugins/module_utils/mysql.py new file mode 100644 index 00000000..46198f36 --- /dev/null +++ b/test/support/integration/plugins/module_utils/mysql.py @@ -0,0 +1,106 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Jonathan Mainguy <jon@soh.re>, 2015 +# Most of this was originally added by Sven Schliesing @muffl0n in the mysql_user.py module +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os + +try: + import pymysql as mysql_driver + _mysql_cursor_param = 'cursor' +except ImportError: + try: + import MySQLdb as mysql_driver + import MySQLdb.cursors + _mysql_cursor_param = 'cursorclass' + except ImportError: + mysql_driver = None + +mysql_driver_fail_msg = 'The PyMySQL (Python 2.7 and Python 3.X) or MySQL-python (Python 2.X) module is required.' 
+ + +def mysql_connect(module, login_user=None, login_password=None, config_file='', ssl_cert=None, ssl_key=None, ssl_ca=None, db=None, cursor_class=None, + connect_timeout=30, autocommit=False): + config = {} + + if ssl_ca is not None or ssl_key is not None or ssl_cert is not None: + config['ssl'] = {} + + if module.params['login_unix_socket']: + config['unix_socket'] = module.params['login_unix_socket'] + else: + config['host'] = module.params['login_host'] + config['port'] = module.params['login_port'] + + if os.path.exists(config_file): + config['read_default_file'] = config_file + + # If login_user or login_password are given, they should override the + # config file + if login_user is not None: + config['user'] = login_user + if login_password is not None: + config['passwd'] = login_password + if ssl_cert is not None: + config['ssl']['cert'] = ssl_cert + if ssl_key is not None: + config['ssl']['key'] = ssl_key + if ssl_ca is not None: + config['ssl']['ca'] = ssl_ca + if db is not None: + config['db'] = db + if connect_timeout is not None: + config['connect_timeout'] = connect_timeout + + if _mysql_cursor_param == 'cursor': + # In case of PyMySQL driver: + db_connection = mysql_driver.connect(autocommit=autocommit, **config) + else: + # In case of MySQLdb driver + db_connection = mysql_driver.connect(**config) + if autocommit: + db_connection.autocommit(True) + + if cursor_class == 'DictCursor': + return db_connection.cursor(**{_mysql_cursor_param: mysql_driver.cursors.DictCursor}), db_connection + else: + return db_connection.cursor(), db_connection + + +def mysql_common_argument_spec(): + return dict( + login_user=dict(type='str', default=None), + login_password=dict(type='str', no_log=True), + login_host=dict(type='str', default='localhost'), + login_port=dict(type='int', default=3306), + login_unix_socket=dict(type='str'), + config_file=dict(type='path', default='~/.my.cnf'), + connect_timeout=dict(type='int', default=30), + client_cert=dict(type='path', aliases=['ssl_cert']), + client_key=dict(type='path', aliases=['ssl_key']), + ca_cert=dict(type='path', aliases=['ssl_ca']), + ) diff --git a/test/support/integration/plugins/module_utils/net_tools/__init__.py b/test/support/integration/plugins/module_utils/net_tools/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/net_tools/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/__init__.py b/test/support/integration/plugins/module_utils/network/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/common/__init__.py b/test/support/integration/plugins/module_utils/network/common/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/common/__init__.py diff --git a/test/support/integration/plugins/module_utils/network/common/utils.py b/test/support/integration/plugins/module_utils/network/common/utils.py new file mode 100644 index 00000000..80317387 --- /dev/null +++ b/test/support/integration/plugins/module_utils/network/common/utils.py @@ -0,0 +1,643 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. 
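# --- Illustrative usage sketch; not part of the original diff. ---
# A hedged sketch of how a database module might use the mysql.py helpers
# above: build the shared argument spec, bail out if neither PyMySQL nor
# MySQLdb is installed, then open a connection/cursor pair. Assumes the file
# is importable as ansible.module_utils.mysql and only runs inside a real
# module invocation against a reachable MySQL server.
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.mysql import (
    mysql_common_argument_spec,
    mysql_connect,
    mysql_driver,
    mysql_driver_fail_msg,
)

module = AnsibleModule(argument_spec=mysql_common_argument_spec())
if mysql_driver is None:
    module.fail_json(msg=mysql_driver_fail_msg)

cursor, db_connection = mysql_connect(
    module,
    login_user=module.params['login_user'],
    login_password=module.params['login_password'],
    config_file=module.params['config_file'],
    cursor_class='DictCursor',
)
cursor.execute('SELECT VERSION()')
module.exit_json(changed=False, version=cursor.fetchone())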
+# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# (c) 2016 Red Hat Inc. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Networking tools for network modules only + +import re +import ast +import operator +import socket +import json + +from itertools import chain + +from ansible.module_utils._text import to_text, to_bytes +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils import basic +from ansible.module_utils.parsing.convert_bool import boolean + +# Backwards compatibility for 3rd party modules +# TODO(pabelanger): With move to ansible.netcommon, we should clean this code +# up and have modules import directly themself. +from ansible.module_utils.common.network import ( # noqa: F401 + to_bits, is_netmask, is_masklen, to_netmask, to_masklen, to_subnet, to_ipv6_network, VALID_MASKS +) + +try: + from jinja2 import Environment, StrictUndefined + from jinja2.exceptions import UndefinedError + HAS_JINJA2 = True +except ImportError: + HAS_JINJA2 = False + + +OPERATORS = frozenset(['ge', 'gt', 'eq', 'neq', 'lt', 'le']) +ALIASES = frozenset([('min', 'ge'), ('max', 'le'), ('exactly', 'eq'), ('neq', 'ne')]) + + +def to_list(val): + if isinstance(val, (list, tuple, set)): + return list(val) + elif val is not None: + return [val] + else: + return list() + + +def to_lines(stdout): + for item in stdout: + if isinstance(item, string_types): + item = to_text(item).split('\n') + yield item + + +def transform_commands(module): + transform = ComplexList(dict( + command=dict(key=True), + output=dict(), + prompt=dict(type='list'), + answer=dict(type='list'), + newline=dict(type='bool', default=True), + sendonly=dict(type='bool', default=False), + check_all=dict(type='bool', default=False), + ), module) + + return transform(module.params['commands']) + + +def sort_list(val): + if isinstance(val, list): + return sorted(val) + return val + + +class Entity(object): + """Transforms a dict to with an argument spec + + This class will take a dict and apply an Ansible argument spec to the + values. 
The resulting dict will contain all of the keys in the param + with appropriate values set. + + Example:: + + argument_spec = dict( + command=dict(key=True), + display=dict(default='text', choices=['text', 'json']), + validate=dict(type='bool') + ) + transform = Entity(module, argument_spec) + value = dict(command='foo') + result = transform(value) + print result + {'command': 'foo', 'display': 'text', 'validate': None} + + Supported argument spec: + * key - specifies how to map a single value to a dict + * read_from - read and apply the argument_spec from the module + * required - a value is required + * type - type of value (uses AnsibleModule type checker) + * fallback - implements fallback function + * choices - set of valid options + * default - default value + """ + + def __init__(self, module, attrs=None, args=None, keys=None, from_argspec=False): + args = [] if args is None else args + + self._attributes = attrs or {} + self._module = module + + for arg in args: + self._attributes[arg] = dict() + if from_argspec: + self._attributes[arg]['read_from'] = arg + if keys and arg in keys: + self._attributes[arg]['key'] = True + + self.attr_names = frozenset(self._attributes.keys()) + + _has_key = False + + for name, attr in iteritems(self._attributes): + if attr.get('read_from'): + if attr['read_from'] not in self._module.argument_spec: + module.fail_json(msg='argument %s does not exist' % attr['read_from']) + spec = self._module.argument_spec.get(attr['read_from']) + for key, value in iteritems(spec): + if key not in attr: + attr[key] = value + + if attr.get('key'): + if _has_key: + module.fail_json(msg='only one key value can be specified') + _has_key = True + attr['required'] = True + + def serialize(self): + return self._attributes + + def to_dict(self, value): + obj = {} + for name, attr in iteritems(self._attributes): + if attr.get('key'): + obj[name] = value + else: + obj[name] = attr.get('default') + return obj + + def __call__(self, value, strict=True): + if not isinstance(value, dict): + value = self.to_dict(value) + + if strict: + unknown = set(value).difference(self.attr_names) + if unknown: + self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown)) + + for name, attr in iteritems(self._attributes): + if value.get(name) is None: + value[name] = attr.get('default') + + if attr.get('fallback') and not value.get(name): + fallback = attr.get('fallback', (None,)) + fallback_strategy = fallback[0] + fallback_args = [] + fallback_kwargs = {} + if fallback_strategy is not None: + for item in fallback[1:]: + if isinstance(item, dict): + fallback_kwargs = item + else: + fallback_args = item + try: + value[name] = fallback_strategy(*fallback_args, **fallback_kwargs) + except basic.AnsibleFallbackNotFound: + continue + + if attr.get('required') and value.get(name) is None: + self._module.fail_json(msg='missing required attribute %s' % name) + + if 'choices' in attr: + if value[name] not in attr['choices']: + self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name])) + + if value[name] is not None: + value_type = attr.get('type', 'str') + type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type] + type_checker(value[name]) + elif value.get(name): + value[name] = self._module.params[name] + + return value + + +class EntityCollection(Entity): + """Extends ```Entity``` to handle a list of dicts """ + + def __call__(self, iterable, strict=True): + if iterable is None: + iterable = [super(EntityCollection, 
self).__call__(self._module.params, strict)] + + if not isinstance(iterable, (list, tuple)): + self._module.fail_json(msg='value must be an iterable') + + return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable] + + +# these two are for backwards compatibility and can be removed once all of the +# modules that use them are updated +class ComplexDict(Entity): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexDict, self).__init__(module, attrs, *args, **kwargs) + + +class ComplexList(EntityCollection): + def __init__(self, attrs, module, *args, **kwargs): + super(ComplexList, self).__init__(module, attrs, *args, **kwargs) + + +def dict_diff(base, comparable): + """ Generate a dict object of differences + + This function will compare two dict objects and return the difference + between them as a dict object. For scalar values, the key will reflect + the updated value. If the key does not exist in `comparable`, then then no + key will be returned. For lists, the value in comparable will wholly replace + the value in base for the key. For dicts, the returned value will only + return keys that are different. + + :param base: dict object to base the diff on + :param comparable: dict object to compare against base + + :returns: new dict object with differences + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(comparable, dict): + if comparable is None: + comparable = dict() + else: + raise AssertionError("`comparable` must be of type <dict>") + + updates = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + item = comparable.get(key) + if item is not None: + sub_diff = dict_diff(value, comparable[key]) + if sub_diff: + updates[key] = sub_diff + else: + comparable_value = comparable.get(key) + if comparable_value is not None: + if sort_list(base[key]) != sort_list(comparable_value): + updates[key] = comparable_value + + for key in set(comparable.keys()).difference(base.keys()): + updates[key] = comparable.get(key) + + return updates + + +def dict_merge(base, other): + """ Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. If the value is a list object, the two lists will + be combined and duplicate entries removed. 
+ + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type <dict>") + if not isinstance(other, dict): + raise AssertionError("`other` must be of type <dict>") + + combined = dict() + + for key, value in iteritems(base): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + if isinstance(other[key], Mapping): + combined[key] = dict_merge(value, other[key]) + else: + combined[key] = other[key] + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if item is not None: + try: + combined[key] = list(set(chain(value, item))) + except TypeError: + value.extend([i for i in item if i not in value]) + combined[key] = value + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) + + return combined + + +def param_list_to_dict(param_list, unique_key="name", remove_key=True): + """Rotates a list of dictionaries to be a dictionary of dictionaries. + + :param param_list: The aforementioned list of dictionaries + :param unique_key: The name of a key which is present and unique in all of param_list's dictionaries. The value + behind this key will be the key each dictionary can be found at in the new root dictionary + :param remove_key: If True, remove unique_key from the individual dictionaries before returning. + """ + param_dict = {} + for params in param_list: + params = params.copy() + if remove_key: + name = params.pop(unique_key) + else: + name = params.get(unique_key) + param_dict[name] = params + + return param_dict + + +def conditional(expr, val, cast=None): + match = re.match(r'^(.+)\((.+)\)$', str(expr), re.I) + if match: + op, arg = match.groups() + else: + op = 'eq' + if ' ' in str(expr): + raise AssertionError('invalid expression: cannot contain spaces') + arg = expr + + if cast is None and val is not None: + arg = type(val)(arg) + elif callable(cast): + arg = cast(arg) + val = cast(val) + + op = next((oper for alias, oper in ALIASES if op == alias), op) + + if not hasattr(operator, op) and op not in OPERATORS: + raise ValueError('unknown operator: %s' % op) + + func = getattr(operator, op) + return func(val, arg) + + +def ternary(value, true_val, false_val): + ''' value ? 
true_val : false_val ''' + if value: + return true_val + else: + return false_val + + +def remove_default_spec(spec): + for item in spec: + if 'default' in spec[item]: + del spec[item]['default'] + + +def validate_ip_address(address): + try: + socket.inet_aton(address) + except socket.error: + return False + return address.count('.') == 3 + + +def validate_ip_v6_address(address): + try: + socket.inet_pton(socket.AF_INET6, address) + except socket.error: + return False + return True + + +def validate_prefix(prefix): + if prefix and not 0 <= int(prefix) <= 32: + return False + return True + + +def load_provider(spec, args): + provider = args.get('provider') or {} + for key, value in iteritems(spec): + if key not in provider: + if 'fallback' in value: + provider[key] = _fallback(value['fallback']) + elif 'default' in value: + provider[key] = value['default'] + else: + provider[key] = None + if 'authorize' in provider: + # Coerce authorize to provider if a string has somehow snuck in. + provider['authorize'] = boolean(provider['authorize'] or False) + args['provider'] = provider + return provider + + +def _fallback(fallback): + strategy = fallback[0] + args = [] + kwargs = {} + + for item in fallback[1:]: + if isinstance(item, dict): + kwargs = item + else: + args = item + try: + return strategy(*args, **kwargs) + except basic.AnsibleFallbackNotFound: + pass + + +def generate_dict(spec): + """ + Generate dictionary which is in sync with argspec + + :param spec: A dictionary that is the argspec of the module + :rtype: A dictionary + :returns: A dictionary in sync with argspec with default value + """ + obj = {} + if not spec: + return obj + + for key, val in iteritems(spec): + if 'default' in val: + dct = {key: val['default']} + elif 'type' in val and val['type'] == 'dict': + dct = {key: generate_dict(val['options'])} + else: + dct = {key: None} + obj.update(dct) + return obj + + +def parse_conf_arg(cfg, arg): + """ + Parse config based on argument + + :param cfg: A text string which is a line of configuration. + :param arg: A text string which is to be matched. + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r'%s (.+)(\n|$)' % arg, cfg, re.M) + if match: + result = match.group(1).strip() + else: + result = None + return result + + +def parse_conf_cmd_arg(cfg, cmd, res1, res2=None, delete_str='no'): + """ + Parse config based on command + + :param cfg: A text string which is a line of configuration. + :param cmd: A text string which is the command to be matched + :param res1: A text string to be returned if the command is present + :param res2: A text string to be returned if the negate command + is present + :param delete_str: A text string to identify the start of the + negate command + :rtype: A text string + :returns: A text string if match is found + """ + match = re.search(r'\n\s+%s(\n|$)' % cmd, cfg) + if match: + return res1 + if res2 is not None: + match = re.search(r'\n\s+%s %s(\n|$)' % (delete_str, cmd), cfg) + if match: + return res2 + return None + + +def get_xml_conf_arg(cfg, path, data='text'): + """ + :param cfg: The top level configuration lxml Element tree object + :param path: The relative xpath w.r.t to top level element (cfg) + to be searched in the xml hierarchy + :param data: The type of data to be returned for the matched xml node. + Valid values are text, tag, attrib, with default as text. 
+ :return: Returns the required type for the matched xml node or else None + """ + match = cfg.xpath(path) + if len(match): + if data == 'tag': + result = getattr(match[0], 'tag') + elif data == 'attrib': + result = getattr(match[0], 'attrib') + else: + result = getattr(match[0], 'text') + else: + result = None + return result + + +def remove_empties(cfg_dict): + """ + Generate final config dictionary + + :param cfg_dict: A dictionary parsed in the facts system + :rtype: A dictionary + :returns: A dictionary by eliminating keys that have null values + """ + final_cfg = {} + if not cfg_dict: + return final_cfg + + for key, val in iteritems(cfg_dict): + dct = None + if isinstance(val, dict): + child_val = remove_empties(val) + if child_val: + dct = {key: child_val} + elif (isinstance(val, list) and val + and all([isinstance(x, dict) for x in val])): + child_val = [remove_empties(x) for x in val] + if child_val: + dct = {key: child_val} + elif val not in [None, [], {}, (), '']: + dct = {key: val} + if dct: + final_cfg.update(dct) + return final_cfg + + +def validate_config(spec, data): + """ + Validate if the input data against the AnsibleModule spec format + :param spec: Ansible argument spec + :param data: Data to be validated + :return: + """ + params = basic._ANSIBLE_ARGS + basic._ANSIBLE_ARGS = to_bytes(json.dumps({'ANSIBLE_MODULE_ARGS': data})) + validated_data = basic.AnsibleModule(spec).params + basic._ANSIBLE_ARGS = params + return validated_data + + +def search_obj_in_list(name, lst, key='name'): + if not lst: + return None + else: + for item in lst: + if item.get(key) == name: + return item + + +class Template: + + def __init__(self): + if not HAS_JINJA2: + raise ImportError("jinja2 is required but does not appear to be installed. " + "It can be installed using `pip install jinja2`") + + self.env = Environment(undefined=StrictUndefined) + self.env.filters.update({'ternary': ternary}) + + def __call__(self, value, variables=None, fail_on_undefined=True): + variables = variables or {} + + if not self.contains_vars(value): + return value + + try: + value = self.env.from_string(value).render(variables) + except UndefinedError: + if not fail_on_undefined: + return None + raise + + if value: + try: + return ast.literal_eval(value) + except Exception: + return str(value) + else: + return None + + def contains_vars(self, data): + if isinstance(data, string_types): + for marker in (self.env.block_start_string, self.env.variable_start_string, self.env.comment_start_string): + if marker in data: + return True + return False diff --git a/test/support/integration/plugins/module_utils/postgres.py b/test/support/integration/plugins/module_utils/postgres.py new file mode 100644 index 00000000..63811c30 --- /dev/null +++ b/test/support/integration/plugins/module_utils/postgres.py @@ -0,0 +1,330 @@ +# This code is part of Ansible, but is an independent component. +# This particular file snippet, and this file snippet only, is BSD licensed. +# Modules you write using this snippet, which is embedded dynamically by Ansible +# still belong to the author of the module, and may assign their own license +# to the complete work. +# +# Copyright (c), Ted Timmons <ted@timmons.me>, 2017. +# Most of this was originally added by other creators in the postgresql_user module. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +psycopg2 = None # This line needs for unit tests +try: + import psycopg2 + HAS_PSYCOPG2 = True +except ImportError: + HAS_PSYCOPG2 = False + +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils._text import to_native +from ansible.module_utils.six import iteritems +from distutils.version import LooseVersion + + +def postgres_common_argument_spec(): + """ + Return a dictionary with connection options. + + The options are commonly used by most of PostgreSQL modules. + """ + return dict( + login_user=dict(default='postgres'), + login_password=dict(default='', no_log=True), + login_host=dict(default=''), + login_unix_socket=dict(default=''), + port=dict(type='int', default=5432, aliases=['login_port']), + ssl_mode=dict(default='prefer', choices=['allow', 'disable', 'prefer', 'require', 'verify-ca', 'verify-full']), + ca_cert=dict(aliases=['ssl_rootcert']), + ) + + +def ensure_required_libs(module): + """Check required libraries.""" + if not HAS_PSYCOPG2: + module.fail_json(msg=missing_required_lib('psycopg2')) + + if module.params.get('ca_cert') and LooseVersion(psycopg2.__version__) < LooseVersion('2.4.3'): + module.fail_json(msg='psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter') + + +def connect_to_db(module, conn_params, autocommit=False, fail_on_conn=True): + """Connect to a PostgreSQL database. + + Return psycopg2 connection object. 
+ + Args: + module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class + conn_params (dict) -- dictionary with connection parameters + + Kwargs: + autocommit (bool) -- commit automatically (default False) + fail_on_conn (bool) -- fail if connection failed or just warn and return None (default True) + """ + ensure_required_libs(module) + + db_connection = None + try: + db_connection = psycopg2.connect(**conn_params) + if autocommit: + if LooseVersion(psycopg2.__version__) >= LooseVersion('2.4.2'): + db_connection.set_session(autocommit=True) + else: + db_connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + + # Switch role, if specified: + if module.params.get('session_role'): + cursor = db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) + + try: + cursor.execute('SET ROLE "%s"' % module.params['session_role']) + except Exception as e: + module.fail_json(msg="Could not switch role: %s" % to_native(e)) + finally: + cursor.close() + + except TypeError as e: + if 'sslrootcert' in e.args[0]: + module.fail_json(msg='Postgresql server must be at least ' + 'version 8.4 to support sslrootcert') + + if fail_on_conn: + module.fail_json(msg="unable to connect to database: %s" % to_native(e)) + else: + module.warn("PostgreSQL server is unavailable: %s" % to_native(e)) + db_connection = None + + except Exception as e: + if fail_on_conn: + module.fail_json(msg="unable to connect to database: %s" % to_native(e)) + else: + module.warn("PostgreSQL server is unavailable: %s" % to_native(e)) + db_connection = None + + return db_connection + + +def exec_sql(obj, query, query_params=None, ddl=False, add_to_executed=True, dont_exec=False): + """Execute SQL. + + Auxiliary function for PostgreSQL user classes. + + Returns a query result if possible or True/False if ddl=True arg was passed. + It necessary for statements that don't return any result (like DDL queries). + + Args: + obj (obj) -- must be an object of a user class. + The object must have module (AnsibleModule class object) and + cursor (psycopg cursor object) attributes + query (str) -- SQL query to execute + + Kwargs: + query_params (dict or tuple) -- Query parameters to prevent SQL injections, + could be a dict or tuple + ddl (bool) -- must return True or False instead of rows (typical for DDL queries) + (default False) + add_to_executed (bool) -- append the query to obj.executed_queries attribute + dont_exec (bool) -- used with add_to_executed=True to generate a query, add it + to obj.executed_queries list and return True (default False) + """ + + if dont_exec: + # This is usually needed to return queries in check_mode + # without execution + query = obj.cursor.mogrify(query, query_params) + if add_to_executed: + obj.executed_queries.append(query) + + return True + + try: + if query_params is not None: + obj.cursor.execute(query, query_params) + else: + obj.cursor.execute(query) + + if add_to_executed: + if query_params is not None: + obj.executed_queries.append(obj.cursor.mogrify(query, query_params)) + else: + obj.executed_queries.append(query) + + if not ddl: + res = obj.cursor.fetchall() + return res + return True + except Exception as e: + obj.module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e))) + return False + + +def get_conn_params(module, params_dict, warn_db_default=True): + """Get connection parameters from the passed dictionary. + + Return a dictionary with parameters to connect to PostgreSQL server. 
+ + Args: + module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class + params_dict (dict) -- dictionary with variables + + Kwargs: + warn_db_default (bool) -- warn that the default DB is used (default True) + """ + # To use defaults values, keyword arguments must be absent, so + # check which values are empty and don't include in the return dictionary + params_map = { + "login_host": "host", + "login_user": "user", + "login_password": "password", + "port": "port", + "ssl_mode": "sslmode", + "ca_cert": "sslrootcert" + } + + # Might be different in the modules: + if params_dict.get('db'): + params_map['db'] = 'database' + elif params_dict.get('database'): + params_map['database'] = 'database' + elif params_dict.get('login_db'): + params_map['login_db'] = 'database' + else: + if warn_db_default: + module.warn('Database name has not been passed, ' + 'used default database to connect to.') + + kw = dict((params_map[k], v) for (k, v) in iteritems(params_dict) + if k in params_map and v != '' and v is not None) + + # If a login_unix_socket is specified, incorporate it here. + is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost" + if is_localhost and params_dict["login_unix_socket"] != "": + kw["host"] = params_dict["login_unix_socket"] + + return kw + + +class PgMembership(object): + def __init__(self, module, cursor, groups, target_roles, fail_on_role=True): + self.module = module + self.cursor = cursor + self.target_roles = [r.strip() for r in target_roles] + self.groups = [r.strip() for r in groups] + self.executed_queries = [] + self.granted = {} + self.revoked = {} + self.fail_on_role = fail_on_role + self.non_existent_roles = [] + self.changed = False + self.__check_roles_exist() + + def grant(self): + for group in self.groups: + self.granted[group] = [] + + for role in self.target_roles: + # If role is in a group now, pass: + if self.__check_membership(group, role): + continue + + query = 'GRANT "%s" TO "%s"' % (group, role) + self.changed = exec_sql(self, query, ddl=True) + + if self.changed: + self.granted[group].append(role) + + return self.changed + + def revoke(self): + for group in self.groups: + self.revoked[group] = [] + + for role in self.target_roles: + # If role is not in a group now, pass: + if not self.__check_membership(group, role): + continue + + query = 'REVOKE "%s" FROM "%s"' % (group, role) + self.changed = exec_sql(self, query, ddl=True) + + if self.changed: + self.revoked[group].append(role) + + return self.changed + + def __check_membership(self, src_role, dst_role): + query = ("SELECT ARRAY(SELECT b.rolname FROM " + "pg_catalog.pg_auth_members m " + "JOIN pg_catalog.pg_roles b ON (m.roleid = b.oid) " + "WHERE m.member = r.oid) " + "FROM pg_catalog.pg_roles r " + "WHERE r.rolname = %(dst_role)s") + + res = exec_sql(self, query, query_params={'dst_role': dst_role}, add_to_executed=False) + membership = [] + if res: + membership = res[0][0] + + if not membership: + return False + + if src_role in membership: + return True + + return False + + def __check_roles_exist(self): + existent_groups = self.__roles_exist(self.groups) + existent_roles = self.__roles_exist(self.target_roles) + + for group in self.groups: + if group not in existent_groups: + if self.fail_on_role: + self.module.fail_json(msg="Role %s does not exist" % group) + else: + self.module.warn("Role %s does not exist, pass" % group) + self.non_existent_roles.append(group) + + for role in self.target_roles: + if role not in existent_roles: + if 
self.fail_on_role: + self.module.fail_json(msg="Role %s does not exist" % role) + else: + self.module.warn("Role %s does not exist, pass" % role) + + if role not in self.groups: + self.non_existent_roles.append(role) + + else: + if self.fail_on_role: + self.module.exit_json(msg="Role role '%s' is a member of role '%s'" % (role, role)) + else: + self.module.warn("Role role '%s' is a member of role '%s', pass" % (role, role)) + + # Update role lists, excluding non existent roles: + self.groups = [g for g in self.groups if g not in self.non_existent_roles] + + self.target_roles = [r for r in self.target_roles if r not in self.non_existent_roles] + + def __roles_exist(self, roles): + tmp = ["'" + x + "'" for x in roles] + query = "SELECT rolname FROM pg_roles WHERE rolname IN (%s)" % ','.join(tmp) + return [x[0] for x in exec_sql(self, query, add_to_executed=False)] diff --git a/test/support/integration/plugins/module_utils/rabbitmq.py b/test/support/integration/plugins/module_utils/rabbitmq.py new file mode 100644 index 00000000..cf764006 --- /dev/null +++ b/test/support/integration/plugins/module_utils/rabbitmq.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# +# Copyright: (c) 2016, Jorge Rodriguez <jorge.rodriguez@tiriel.eu> +# Copyright: (c) 2018, John Imison <john+github@imison.net> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from ansible.module_utils._text import to_native +from ansible.module_utils.basic import missing_required_lib +from ansible.module_utils.six.moves.urllib import parse as urllib_parse +from mimetypes import MimeTypes + +import os +import json +import traceback + +PIKA_IMP_ERR = None +try: + import pika + import pika.exceptions + from pika import spec + HAS_PIKA = True +except ImportError: + PIKA_IMP_ERR = traceback.format_exc() + HAS_PIKA = False + + +def rabbitmq_argument_spec(): + return dict( + login_user=dict(type='str', default='guest'), + login_password=dict(type='str', default='guest', no_log=True), + login_host=dict(type='str', default='localhost'), + login_port=dict(type='str', default='15672'), + login_protocol=dict(type='str', default='http', choices=['http', 'https']), + ca_cert=dict(type='path', aliases=['cacert']), + client_cert=dict(type='path', aliases=['cert']), + client_key=dict(type='path', aliases=['key']), + vhost=dict(type='str', default='/'), + ) + + +# notification/rabbitmq_basic_publish.py +class RabbitClient(): + def __init__(self, module): + self.module = module + self.params = module.params + self.check_required_library() + self.check_host_params() + self.url = self.params['url'] + self.proto = self.params['proto'] + self.username = self.params['username'] + self.password = self.params['password'] + self.host = self.params['host'] + self.port = self.params['port'] + self.vhost = self.params['vhost'] + self.queue = self.params['queue'] + self.headers = self.params['headers'] + self.cafile = self.params['cafile'] + self.certfile = self.params['certfile'] + self.keyfile = self.params['keyfile'] + + if self.host is not None: + self.build_url() + + if self.cafile is not None: + self.append_ssl_certs() + + self.connect_to_rabbitmq() + + def check_required_library(self): + if not HAS_PIKA: + self.module.fail_json(msg=missing_required_lib("pika"), exception=PIKA_IMP_ERR) + + def check_host_params(self): + # Fail if url is specified and other conflicting parameters have been specified + if 
self.params['url'] is not None and any(self.params[k] is not None for k in ['proto', 'host', 'port', 'password', 'username', 'vhost']): + self.module.fail_json(msg="url and proto, host, port, vhost, username or password cannot be specified at the same time.") + + # Fail if url not specified and there is a missing parameter to build the url + if self.params['url'] is None and any(self.params[k] is None for k in ['proto', 'host', 'port', 'password', 'username', 'vhost']): + self.module.fail_json(msg="Connection parameters must be passed via url, or, proto, host, port, vhost, username or password.") + + def append_ssl_certs(self): + ssl_options = {} + if self.cafile: + ssl_options['cafile'] = self.cafile + if self.certfile: + ssl_options['certfile'] = self.certfile + if self.keyfile: + ssl_options['keyfile'] = self.keyfile + + self.url = self.url + '?ssl_options=' + urllib_parse.quote(json.dumps(ssl_options)) + + @staticmethod + def rabbitmq_argument_spec(): + return dict( + url=dict(type='str'), + proto=dict(type='str', choices=['amqp', 'amqps']), + host=dict(type='str'), + port=dict(type='int'), + username=dict(type='str'), + password=dict(type='str', no_log=True), + vhost=dict(type='str'), + queue=dict(type='str') + ) + + ''' Consider some file size limits here ''' + def _read_file(self, path): + try: + with open(path, "rb") as file_handle: + return file_handle.read() + except IOError as e: + self.module.fail_json(msg="Unable to open file %s: %s" % (path, to_native(e))) + + @staticmethod + def _check_file_mime_type(path): + mime = MimeTypes() + return mime.guess_type(path) + + def build_url(self): + self.url = '{0}://{1}:{2}@{3}:{4}/{5}'.format(self.proto, + self.username, + self.password, + self.host, + self.port, + self.vhost) + + def connect_to_rabbitmq(self): + """ + Function to connect to rabbitmq using username and password + """ + try: + parameters = pika.URLParameters(self.url) + except Exception as e: + self.module.fail_json(msg="URL malformed: %s" % to_native(e)) + + try: + self.connection = pika.BlockingConnection(parameters) + except Exception as e: + self.module.fail_json(msg="Connection issue: %s" % to_native(e)) + + try: + self.conn_channel = self.connection.channel() + except pika.exceptions.AMQPChannelError as e: + self.close_connection() + self.module.fail_json(msg="Channel issue: %s" % to_native(e)) + + def close_connection(self): + try: + self.connection.close() + except pika.exceptions.AMQPConnectionError: + pass + + def basic_publish(self): + self.content_type = self.params.get("content_type") + + if self.params.get("body") is not None: + args = dict( + body=self.params.get("body"), + exchange=self.params.get("exchange"), + routing_key=self.params.get("routing_key"), + properties=pika.BasicProperties(content_type=self.content_type, delivery_mode=1, headers=self.headers)) + + # If src (file) is defined and content_type is left as default, do a mime lookup on the file + if self.params.get("src") is not None and self.content_type == 'text/plain': + self.content_type = RabbitClient._check_file_mime_type(self.params.get("src"))[0] + self.headers.update( + filename=os.path.basename(self.params.get("src")) + ) + + args = dict( + body=self._read_file(self.params.get("src")), + exchange=self.params.get("exchange"), + routing_key=self.params.get("routing_key"), + properties=pika.BasicProperties(content_type=self.content_type, + delivery_mode=1, + headers=self.headers + )) + elif self.params.get("src") is not None: + args = dict( + body=self._read_file(self.params.get("src")), + 
exchange=self.params.get("exchange"), + routing_key=self.params.get("routing_key"), + properties=pika.BasicProperties(content_type=self.content_type, + delivery_mode=1, + headers=self.headers + )) + + try: + # If queue is not defined, RabbitMQ will return the queue name of the automatically generated queue. + if self.queue is None: + result = self.conn_channel.queue_declare(durable=self.params.get("durable"), + exclusive=self.params.get("exclusive"), + auto_delete=self.params.get("auto_delete")) + self.conn_channel.confirm_delivery() + self.queue = result.method.queue + else: + self.conn_channel.queue_declare(queue=self.queue, + durable=self.params.get("durable"), + exclusive=self.params.get("exclusive"), + auto_delete=self.params.get("auto_delete")) + self.conn_channel.confirm_delivery() + except Exception as e: + self.module.fail_json(msg="Queue declare issue: %s" % to_native(e)) + + # https://github.com/ansible/ansible/blob/devel/lib/ansible/module_utils/cloudstack.py#L150 + if args['routing_key'] is None: + args['routing_key'] = self.queue + + if args['exchange'] is None: + args['exchange'] = '' + + try: + self.conn_channel.basic_publish(**args) + return True + except pika.exceptions.UnroutableError: + return False diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py new file mode 120000 index 00000000..f9993bfb --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbconfiguration_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbconfiguration_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py new file mode 120000 index 00000000..b8293e64 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbdatabase_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbdatabase_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py new file mode 120000 index 00000000..4311a0c1 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbfirewallrule_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbfirewallrule_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py b/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py new file mode 120000 index 00000000..5f76e0e9 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_mariadbserver_facts.py @@ -0,0 +1 @@ +azure_rm_mariadbserver_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_resource_facts.py b/test/support/integration/plugins/modules/_azure_rm_resource_facts.py new file mode 120000 index 00000000..710fda10 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_resource_facts.py @@ -0,0 +1 @@ +azure_rm_resource_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py b/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py new file mode 120000 index 00000000..ead87c85 --- /dev/null +++ b/test/support/integration/plugins/modules/_azure_rm_webapp_facts.py @@ -0,0 +1 @@ +azure_rm_webapp_info.py
\ No newline at end of file diff --git a/test/support/integration/plugins/modules/aws_az_info.py b/test/support/integration/plugins/modules/aws_az_info.py new file mode 100644 index 00000000..c1efed6f --- /dev/null +++ b/test/support/integration/plugins/modules/aws_az_info.py @@ -0,0 +1,111 @@ +#!/usr/bin/python +# Copyright (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'supported_by': 'community', + 'status': ['preview'] +} + +DOCUMENTATION = ''' +module: aws_az_info +short_description: Gather information about availability zones in AWS. +description: + - Gather information about availability zones in AWS. + - This module was called C(aws_az_facts) before Ansible 2.9. The usage did not change. +version_added: '2.5' +author: 'Henrique Rodrigues (@Sodki)' +options: + filters: + description: + - A dict of filters to apply. Each dict item consists of a filter key and a filter value. See + U(https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeAvailabilityZones.html) for + possible filters. Filter names and values are case sensitive. You can also use underscores + instead of dashes (-) in the filter keys, which will take precedence in case of conflict. + required: false + default: {} + type: dict +extends_documentation_fragment: + - aws + - ec2 +requirements: [botocore, boto3] +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +# Gather information about all availability zones +- aws_az_info: + +# Gather information about a single availability zone +- aws_az_info: + filters: + zone-name: eu-west-1a +''' + +RETURN = ''' +availability_zones: + returned: on success + description: > + Availability zones that match the provided filters. Each element consists of a dict with all the information + related to that available zone. 
+ type: list + sample: "[ + { + 'messages': [], + 'region_name': 'us-west-1', + 'state': 'available', + 'zone_name': 'us-west-1b' + }, + { + 'messages': [], + 'region_name': 'us-west-1', + 'state': 'available', + 'zone_name': 'us-west-1c' + } + ]" +''' + +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import AWSRetry, ansible_dict_to_boto3_filter_list, camel_dict_to_snake_dict + +try: + from botocore.exceptions import ClientError, BotoCoreError +except ImportError: + pass # Handled by AnsibleAWSModule + + +def main(): + argument_spec = dict( + filters=dict(default={}, type='dict') + ) + + module = AnsibleAWSModule(argument_spec=argument_spec) + if module._name == 'aws_az_facts': + module.deprecate("The 'aws_az_facts' module has been renamed to 'aws_az_info'", + version='2.14', collection_name='ansible.builtin') + + connection = module.client('ec2', retry_decorator=AWSRetry.jittered_backoff()) + + # Replace filter key underscores with dashes, for compatibility + sanitized_filters = dict((k.replace('_', '-'), v) for k, v in module.params.get('filters').items()) + + try: + availability_zones = connection.describe_availability_zones( + Filters=ansible_dict_to_boto3_filter_list(sanitized_filters) + ) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to describe availability zones.") + + # Turn the boto3 result into ansible_friendly_snaked_names + snaked_availability_zones = [camel_dict_to_snake_dict(az) for az in availability_zones['AvailabilityZones']] + + module.exit_json(availability_zones=snaked_availability_zones) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/aws_s3.py b/test/support/integration/plugins/modules/aws_s3.py new file mode 100644 index 00000000..54874f05 --- /dev/null +++ b/test/support/integration/plugins/modules/aws_s3.py @@ -0,0 +1,925 @@ +#!/usr/bin/python +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: aws_s3 +short_description: manage objects in S3. +description: + - This module allows the user to manage S3 buckets and the objects within them. Includes support for creating and + deleting both objects and buckets, retrieving objects as files or strings and generating download links. + This module has a dependency on boto3 and botocore. +notes: + - In 2.4, this module has been renamed from C(s3) into M(aws_s3). +version_added: "1.1" +options: + bucket: + description: + - Bucket name. + required: true + type: str + dest: + description: + - The destination file path when downloading an object/key with a GET operation. + version_added: "1.3" + type: path + encrypt: + description: + - When set for PUT mode, asks for server-side encryption. + default: true + version_added: "2.0" + type: bool + encryption_mode: + description: + - What encryption mode to use if I(encrypt=true). + default: AES256 + choices: + - AES256 + - aws:kms + version_added: "2.7" + type: str + expiry: + description: + - Time limit (in seconds) for the URL generated and returned by S3/Walrus when performing a I(mode=put) or I(mode=geturl) operation. 
+ default: 600 + aliases: ['expiration'] + type: int + headers: + description: + - Custom headers for PUT operation, as a dictionary of 'key=value' and 'key=value,key=value'. + version_added: "2.0" + type: dict + marker: + description: + - Specifies the key to start with when using list mode. Object keys are returned in alphabetical order, starting with key after the marker in order. + version_added: "2.0" + type: str + max_keys: + description: + - Max number of results to return in list mode, set this if you want to retrieve fewer than the default 1000 keys. + default: 1000 + version_added: "2.0" + type: int + metadata: + description: + - Metadata for PUT operation, as a dictionary of 'key=value' and 'key=value,key=value'. + version_added: "1.6" + type: dict + mode: + description: + - Switches the module behaviour between put (upload), get (download), geturl (return download url, Ansible 1.3+), + getstr (download object as string (1.3+)), list (list keys, Ansible 2.0+), create (bucket), delete (bucket), + and delobj (delete object, Ansible 2.0+). + required: true + choices: ['get', 'put', 'delete', 'create', 'geturl', 'getstr', 'delobj', 'list'] + type: str + object: + description: + - Keyname of the object inside the bucket. Can be used to create "virtual directories", see examples. + type: str + permission: + description: + - This option lets the user set the canned permissions on the object/bucket that are created. + The permissions that can be set are C(private), C(public-read), C(public-read-write), C(authenticated-read) for a bucket or + C(private), C(public-read), C(public-read-write), C(aws-exec-read), C(authenticated-read), C(bucket-owner-read), + C(bucket-owner-full-control) for an object. Multiple permissions can be specified as a list. + default: ['private'] + version_added: "2.0" + type: list + elements: str + prefix: + description: + - Limits the response to keys that begin with the specified prefix for list mode. + default: "" + version_added: "2.0" + type: str + version: + description: + - Version ID of the object inside the bucket. Can be used to get a specific version of a file if versioning is enabled in the target bucket. + version_added: "2.0" + type: str + overwrite: + description: + - Force overwrite either locally on the filesystem or remotely with the object/key. Used with PUT and GET operations. + Boolean or one of [always, never, different], true is equal to 'always' and false is equal to 'never', new in 2.0. + When this is set to 'different', the md5 sum of the local file is compared with the 'ETag' of the object/key in S3. + The ETag may or may not be an MD5 digest of the object data. See the ETag response header here + U(https://docs.aws.amazon.com/AmazonS3/latest/API/RESTCommonResponseHeaders.html) + default: 'always' + aliases: ['force'] + version_added: "1.2" + type: str + retries: + description: + - On recoverable failure, how many times to retry before actually failing. + default: 0 + version_added: "2.0" + type: int + aliases: ['retry'] + s3_url: + description: + - S3 URL endpoint for usage with Ceph, Eucalyptus and fakes3 etc. Otherwise assumes AWS. + aliases: [ S3_URL ] + type: str + dualstack: + description: + - Enables Amazon S3 Dual-Stack Endpoints, allowing S3 communications using both IPv4 and IPv6. + - Requires at least botocore version 1.4.45. + type: bool + default: false + version_added: "2.7" + rgw: + description: + - Enable Ceph RGW S3 support. This option requires an explicit url via I(s3_url). 
+ default: false + version_added: "2.2" + type: bool + src: + description: + - The source file path when performing a PUT operation. + version_added: "1.3" + type: str + ignore_nonexistent_bucket: + description: + - "Overrides initial bucket lookups in case bucket or iam policies are restrictive. Example: a user may have the + GetObject permission but no other permissions. In this case using the option mode: get will fail without specifying + I(ignore_nonexistent_bucket=true)." + version_added: "2.3" + type: bool + encryption_kms_key_id: + description: + - KMS key id to use when encrypting objects using I(encrypting=aws:kms). Ignored if I(encryption) is not C(aws:kms) + version_added: "2.7" + type: str +requirements: [ "boto3", "botocore" ] +author: + - "Lester Wade (@lwade)" + - "Sloane Hertel (@s-hertel)" +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +- name: Simple PUT operation + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + +- name: Simple PUT operation in Ceph RGW S3 + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + rgw: true + s3_url: "http://localhost:8000" + +- name: Simple GET operation + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + dest: /usr/local/myfile.txt + mode: get + +- name: Get a specific version of an object. + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + version: 48c9ee5131af7a716edc22df9772aa6f + dest: /usr/local/myfile.txt + mode: get + +- name: PUT/upload with metadata + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + metadata: 'Content-Encoding=gzip,Cache-Control=no-cache' + +- name: PUT/upload with custom headers + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + src: /usr/local/myfile.txt + mode: put + headers: 'x-amz-grant-full-control=emailAddress=owner@example.com' + +- name: List keys simple + aws_s3: + bucket: mybucket + mode: list + +- name: List keys all options + aws_s3: + bucket: mybucket + mode: list + prefix: /my/desired/ + marker: /my/desired/0023.txt + max_keys: 472 + +- name: Create an empty bucket + aws_s3: + bucket: mybucket + mode: create + permission: public-read + +- name: Create a bucket with key as directory, in the EU region + aws_s3: + bucket: mybucket + object: /my/directory/path + mode: create + region: eu-west-1 + +- name: Delete a bucket and all contents + aws_s3: + bucket: mybucket + mode: delete + +- name: GET an object but don't download if the file checksums match. New in 2.0 + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + dest: /usr/local/myfile.txt + mode: get + overwrite: different + +- name: Delete an object from a bucket + aws_s3: + bucket: mybucket + object: /my/desired/key.txt + mode: delobj +''' + +RETURN = ''' +msg: + description: Message indicating the status of the operation. + returned: always + type: str + sample: PUT operation complete +url: + description: URL of the object. + returned: (for put and geturl operations) + type: str + sample: https://my-bucket.s3.amazonaws.com/my-key.txt?AWSAccessKeyId=<access-key>&Expires=1506888865&Signature=<signature> +expiry: + description: Number of seconds the presigned url is valid for. + returned: (for geturl operation) + type: int + sample: 600 +contents: + description: Contents of the object as string. + returned: (for getstr operation) + type: str + sample: "Hello, world!" +s3_keys: + description: List of object keys. 
+ returned: (for list operation) + type: list + elements: str + sample: + - prefix1/ + - prefix1/key1 + - prefix1/key2 +''' + +import mimetypes +import os +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ssl import SSLError +from ansible.module_utils.basic import to_text, to_native +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.aws.s3 import calculate_etag, HAS_MD5 +from ansible.module_utils.ec2 import get_aws_connection_info, boto3_conn + +try: + import botocore +except ImportError: + pass # will be detected by imported AnsibleAWSModule + +IGNORE_S3_DROP_IN_EXCEPTIONS = ['XNotImplemented', 'NotImplemented'] + + +class Sigv4Required(Exception): + pass + + +def key_check(module, s3, bucket, obj, version=None, validate=True): + exists = True + try: + if version: + s3.head_object(Bucket=bucket, Key=obj, VersionId=version) + else: + s3.head_object(Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + # if a client error is thrown, check if it's a 404 error + # if it's a 404 error, then the object does not exist + error_code = int(e.response['Error']['Code']) + if error_code == 404: + exists = False + elif error_code == 403 and validate is False: + pass + else: + module.fail_json_aws(e, msg="Failed while looking up object (during key check) %s." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while looking up object (during key check) %s." % obj) + return exists + + +def etag_compare(module, local_file, s3, bucket, obj, version=None): + s3_etag = get_etag(s3, bucket, obj, version=version) + local_etag = calculate_etag(module, local_file, s3_etag, s3, bucket, obj, version) + + return s3_etag == local_etag + + +def get_etag(s3, bucket, obj, version=None): + if version: + key_check = s3.head_object(Bucket=bucket, Key=obj, VersionId=version) + else: + key_check = s3.head_object(Bucket=bucket, Key=obj) + if not key_check: + return None + return key_check['ETag'] + + +def bucket_check(module, s3, bucket, validate=True): + exists = True + try: + s3.head_bucket(Bucket=bucket) + except botocore.exceptions.ClientError as e: + # If a client error is thrown, then check that it was a 404 error. + # If it was a 404 error, then the bucket does not exist. + error_code = int(e.response['Error']['Code']) + if error_code == 404: + exists = False + elif error_code == 403 and validate is False: + pass + else: + module.fail_json_aws(e, msg="Failed while looking up bucket (during bucket_check) %s." % bucket) + except botocore.exceptions.EndpointConnectionError as e: + module.fail_json_aws(e, msg="Invalid endpoint provided") + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while looking up bucket (during bucket_check) %s." 
% bucket) + return exists + + +def create_bucket(module, s3, bucket, location=None): + if module.check_mode: + module.exit_json(msg="CREATE operation skipped - running in check mode", changed=True) + configuration = {} + if location not in ('us-east-1', None): + configuration['LocationConstraint'] = location + try: + if len(configuration) > 0: + s3.create_bucket(Bucket=bucket, CreateBucketConfiguration=configuration) + else: + s3.create_bucket(Bucket=bucket) + if module.params.get('permission'): + # Wait for the bucket to exist before setting ACLs + s3.get_waiter('bucket_exists').wait(Bucket=bucket) + for acl in module.params.get('permission'): + s3.put_bucket_acl(ACL=acl, Bucket=bucket) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS: + module.warn("PutBucketAcl is not implemented by your storage provider. Set the permission parameters to the empty list to avoid this warning") + else: + module.fail_json_aws(e, msg="Failed while creating bucket or setting acl (check that you have CreateBucket and PutBucketAcl permission).") + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while creating bucket or setting acl (check that you have CreateBucket and PutBucketAcl permission).") + + if bucket: + return True + + +def paginated_list(s3, **pagination_params): + pg = s3.get_paginator('list_objects_v2') + for page in pg.paginate(**pagination_params): + yield [data['Key'] for data in page.get('Contents', [])] + + +def paginated_versioned_list_with_fallback(s3, **pagination_params): + try: + versioned_pg = s3.get_paginator('list_object_versions') + for page in versioned_pg.paginate(**pagination_params): + delete_markers = [{'Key': data['Key'], 'VersionId': data['VersionId']} for data in page.get('DeleteMarkers', [])] + current_objects = [{'Key': data['Key'], 'VersionId': data['VersionId']} for data in page.get('Versions', [])] + yield delete_markers + current_objects + except botocore.exceptions.ClientError as e: + if to_text(e.response['Error']['Code']) in IGNORE_S3_DROP_IN_EXCEPTIONS + ['AccessDenied']: + for page in paginated_list(s3, **pagination_params): + yield [{'Key': data['Key']} for data in page] + + +def list_keys(module, s3, bucket, prefix, marker, max_keys): + pagination_params = {'Bucket': bucket} + for param_name, param_value in (('Prefix', prefix), ('StartAfter', marker), ('MaxKeys', max_keys)): + pagination_params[param_name] = param_value + try: + keys = sum(paginated_list(s3, **pagination_params), []) + module.exit_json(msg="LIST operation complete", s3_keys=keys) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while listing the keys in the bucket {0}".format(bucket)) + + +def delete_bucket(module, s3, bucket): + if module.check_mode: + module.exit_json(msg="DELETE operation skipped - running in check mode", changed=True) + try: + exists = bucket_check(module, s3, bucket) + if exists is False: + return False + # if there are contents then we need to delete them before we can delete the bucket + for keys in paginated_versioned_list_with_fallback(s3, Bucket=bucket): + if keys: + s3.delete_objects(Bucket=bucket, Delete={'Objects': keys}) + s3.delete_bucket(Bucket=bucket) + return True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while deleting bucket %s." 
% bucket) + + +def delete_key(module, s3, bucket, obj): + if module.check_mode: + module.exit_json(msg="DELETE operation skipped - running in check mode", changed=True) + try: + s3.delete_object(Bucket=bucket, Key=obj) + module.exit_json(msg="Object deleted from bucket %s." % (bucket), changed=True) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while trying to delete %s." % obj) + + +def create_dirkey(module, s3, bucket, obj, encrypt): + if module.check_mode: + module.exit_json(msg="PUT operation skipped - running in check mode", changed=True) + try: + params = {'Bucket': bucket, 'Key': obj, 'Body': b''} + if encrypt: + params['ServerSideEncryption'] = module.params['encryption_mode'] + if module.params['encryption_kms_key_id'] and module.params['encryption_mode'] == 'aws:kms': + params['SSEKMSKeyId'] = module.params['encryption_kms_key_id'] + + s3.put_object(**params) + for acl in module.params.get('permission'): + s3.put_object_acl(ACL=acl, Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS: + module.warn("PutObjectAcl is not implemented by your storage provider. Set the permissions parameters to the empty list to avoid this warning") + else: + module.fail_json_aws(e, msg="Failed while creating object %s." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while creating object %s." % obj) + module.exit_json(msg="Virtual directory %s created in bucket %s" % (obj, bucket), changed=True) + + +def path_check(path): + if os.path.exists(path): + return True + else: + return False + + +def option_in_extra_args(option): + temp_option = option.replace('-', '').lower() + + allowed_extra_args = {'acl': 'ACL', 'cachecontrol': 'CacheControl', 'contentdisposition': 'ContentDisposition', + 'contentencoding': 'ContentEncoding', 'contentlanguage': 'ContentLanguage', + 'contenttype': 'ContentType', 'expires': 'Expires', 'grantfullcontrol': 'GrantFullControl', + 'grantread': 'GrantRead', 'grantreadacp': 'GrantReadACP', 'grantwriteacp': 'GrantWriteACP', + 'metadata': 'Metadata', 'requestpayer': 'RequestPayer', 'serversideencryption': 'ServerSideEncryption', + 'storageclass': 'StorageClass', 'ssecustomeralgorithm': 'SSECustomerAlgorithm', 'ssecustomerkey': 'SSECustomerKey', + 'ssecustomerkeymd5': 'SSECustomerKeyMD5', 'ssekmskeyid': 'SSEKMSKeyId', 'websiteredirectlocation': 'WebsiteRedirectLocation'} + + if temp_option in allowed_extra_args: + return allowed_extra_args[temp_option] + + +def upload_s3file(module, s3, bucket, obj, src, expiry, metadata, encrypt, headers): + if module.check_mode: + module.exit_json(msg="PUT operation skipped - running in check mode", changed=True) + try: + extra = {} + if encrypt: + extra['ServerSideEncryption'] = module.params['encryption_mode'] + if module.params['encryption_kms_key_id'] and module.params['encryption_mode'] == 'aws:kms': + extra['SSEKMSKeyId'] = module.params['encryption_kms_key_id'] + if metadata: + extra['Metadata'] = {} + + # determine object metadata and extra arguments + for option in metadata: + extra_args_option = option_in_extra_args(option) + if extra_args_option is not None: + extra[extra_args_option] = metadata[option] + else: + extra['Metadata'][option] = metadata[option] + + if 'ContentType' not in extra: + content_type = mimetypes.guess_type(src)[0] + if content_type is None: + # s3 default content type + content_type = 'binary/octet-stream' + 
extra['ContentType'] = content_type + + s3.upload_file(Filename=src, Bucket=bucket, Key=obj, ExtraArgs=extra) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to complete PUT operation.") + try: + for acl in module.params.get('permission'): + s3.put_object_acl(ACL=acl, Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] in IGNORE_S3_DROP_IN_EXCEPTIONS: + module.warn("PutObjectAcl is not implemented by your storage provider. Set the permission parameters to the empty list to avoid this warning") + else: + module.fail_json_aws(e, msg="Unable to set object ACL") + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Unable to set object ACL") + try: + url = s3.generate_presigned_url(ClientMethod='put_object', + Params={'Bucket': bucket, 'Key': obj}, + ExpiresIn=expiry) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to generate presigned URL") + module.exit_json(msg="PUT operation complete", url=url, changed=True) + + +def download_s3file(module, s3, bucket, obj, dest, retries, version=None): + if module.check_mode: + module.exit_json(msg="GET operation skipped - running in check mode", changed=True) + # retries is the number of loops; range/xrange needs to be one + # more to get that count of loops. + try: + if version: + key = s3.get_object(Bucket=bucket, Key=obj, VersionId=version) + else: + key = s3.get_object(Bucket=bucket, Key=obj) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] == 'InvalidArgument' and 'require AWS Signature Version 4' in to_text(e): + raise Sigv4Required() + elif e.response['Error']['Code'] not in ("403", "404"): + # AccessDenied errors may be triggered if 1) file does not exist or 2) file exists but + # user does not have the s3:GetObject permission. 404 errors are handled by download_file(). + module.fail_json_aws(e, msg="Could not find the key %s." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Could not find the key %s." % obj) + + optional_kwargs = {'ExtraArgs': {'VersionId': version}} if version else {} + for x in range(0, retries + 1): + try: + s3.download_file(bucket, obj, dest, **optional_kwargs) + module.exit_json(msg="GET operation complete", changed=True) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + # actually fail on last pass through the loop. + if x >= retries: + module.fail_json_aws(e, msg="Failed while downloading %s." % obj) + # otherwise, try again, this may be a transient timeout. + except SSLError as e: # will ClientError catch SSLError? + # actually fail on last pass through the loop. + if x >= retries: + module.fail_json_aws(e, msg="s3 download failed") + # otherwise, try again, this may be a transient timeout. 
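+            # (Answers the question above: no - ssl.SSLError is not a ClientError subclass,
+            # so the separate except clause is needed for it to get the same retry handling.)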
+ + +def download_s3str(module, s3, bucket, obj, version=None, validate=True): + if module.check_mode: + module.exit_json(msg="GET operation skipped - running in check mode", changed=True) + try: + if version: + contents = to_native(s3.get_object(Bucket=bucket, Key=obj, VersionId=version)["Body"].read()) + else: + contents = to_native(s3.get_object(Bucket=bucket, Key=obj)["Body"].read()) + module.exit_json(msg="GET operation complete", contents=contents, changed=True) + except botocore.exceptions.ClientError as e: + if e.response['Error']['Code'] == 'InvalidArgument' and 'require AWS Signature Version 4' in to_text(e): + raise Sigv4Required() + else: + module.fail_json_aws(e, msg="Failed while getting contents of object %s as a string." % obj) + except botocore.exceptions.BotoCoreError as e: + module.fail_json_aws(e, msg="Failed while getting contents of object %s as a string." % obj) + + +def get_download_url(module, s3, bucket, obj, expiry, changed=True): + try: + url = s3.generate_presigned_url(ClientMethod='get_object', + Params={'Bucket': bucket, 'Key': obj}, + ExpiresIn=expiry) + module.exit_json(msg="Download url:", url=url, expiry=expiry, changed=changed) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed while getting download url.") + + +def is_fakes3(s3_url): + """ Return True if s3_url has scheme fakes3:// """ + if s3_url is not None: + return urlparse(s3_url).scheme in ('fakes3', 'fakes3s') + else: + return False + + +def get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=False): + if s3_url and rgw: # TODO - test this + rgw = urlparse(s3_url) + params = dict(module=module, conn_type='client', resource='s3', use_ssl=rgw.scheme == 'https', region=location, endpoint=s3_url, **aws_connect_kwargs) + elif is_fakes3(s3_url): + fakes3 = urlparse(s3_url) + port = fakes3.port + if fakes3.scheme == 'fakes3s': + protocol = "https" + if port is None: + port = 443 + else: + protocol = "http" + if port is None: + port = 80 + params = dict(module=module, conn_type='client', resource='s3', region=location, + endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)), + use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs) + else: + params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=s3_url, **aws_connect_kwargs) + if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms': + params['config'] = botocore.client.Config(signature_version='s3v4') + elif module.params['mode'] in ('get', 'getstr') and sig_4: + params['config'] = botocore.client.Config(signature_version='s3v4') + if module.params['dualstack']: + dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True}) + if 'config' in params: + params['config'] = params['config'].merge(dualconf) + else: + params['config'] = dualconf + return boto3_conn(**params) + + +def main(): + argument_spec = dict( + bucket=dict(required=True), + dest=dict(default=None, type='path'), + encrypt=dict(default=True, type='bool'), + encryption_mode=dict(choices=['AES256', 'aws:kms'], default='AES256'), + expiry=dict(default=600, type='int', aliases=['expiration']), + headers=dict(type='dict'), + marker=dict(default=""), + max_keys=dict(default=1000, type='int'), + metadata=dict(type='dict'), + mode=dict(choices=['get', 'put', 'delete', 'create', 'geturl', 'getstr', 'delobj', 'list'], required=True), + object=dict(), + permission=dict(type='list', default=['private']), + 
version=dict(default=None), + overwrite=dict(aliases=['force'], default='always'), + prefix=dict(default=""), + retries=dict(aliases=['retry'], type='int', default=0), + s3_url=dict(aliases=['S3_URL']), + dualstack=dict(default='no', type='bool'), + rgw=dict(default='no', type='bool'), + src=dict(), + ignore_nonexistent_bucket=dict(default=False, type='bool'), + encryption_kms_key_id=dict() + ) + module = AnsibleAWSModule( + argument_spec=argument_spec, + supports_check_mode=True, + required_if=[['mode', 'put', ['src', 'object']], + ['mode', 'get', ['dest', 'object']], + ['mode', 'getstr', ['object']], + ['mode', 'geturl', ['object']]], + ) + + bucket = module.params.get('bucket') + encrypt = module.params.get('encrypt') + expiry = module.params.get('expiry') + dest = module.params.get('dest', '') + headers = module.params.get('headers') + marker = module.params.get('marker') + max_keys = module.params.get('max_keys') + metadata = module.params.get('metadata') + mode = module.params.get('mode') + obj = module.params.get('object') + version = module.params.get('version') + overwrite = module.params.get('overwrite') + prefix = module.params.get('prefix') + retries = module.params.get('retries') + s3_url = module.params.get('s3_url') + dualstack = module.params.get('dualstack') + rgw = module.params.get('rgw') + src = module.params.get('src') + ignore_nonexistent_bucket = module.params.get('ignore_nonexistent_bucket') + + object_canned_acl = ["private", "public-read", "public-read-write", "aws-exec-read", "authenticated-read", "bucket-owner-read", "bucket-owner-full-control"] + bucket_canned_acl = ["private", "public-read", "public-read-write", "authenticated-read"] + + if overwrite not in ['always', 'never', 'different']: + if module.boolean(overwrite): + overwrite = 'always' + else: + overwrite = 'never' + + if overwrite == 'different' and not HAS_MD5: + module.fail_json(msg='overwrite=different is unavailable: ETag calculation requires MD5 support') + + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True) + + if region in ('us-east-1', '', None): + # default to US Standard region + location = 'us-east-1' + else: + # Boto uses symbolic names for locations but region strings will + # actually work fine for everything except us-east-1 (US Standard) + location = region + + if module.params.get('object'): + obj = module.params['object'] + # If there is a top level object, do nothing - if the object starts with / + # remove the leading character to maintain compatibility with Ansible versions < 2.4 + if obj.startswith('/'): + obj = obj[1:] + + # Bucket deletion does not require obj. Prevents ambiguity with delobj. 
+ if obj and mode == "delete": + module.fail_json(msg='Parameter obj cannot be used with mode=delete') + + # allow eucarc environment variables to be used if ansible vars aren't set + if not s3_url and 'S3_URL' in os.environ: + s3_url = os.environ['S3_URL'] + + if dualstack and s3_url is not None and 'amazonaws.com' not in s3_url: + module.fail_json(msg='dualstack only applies to AWS S3') + + if dualstack and not module.botocore_at_least('1.4.45'): + module.fail_json(msg='dualstack requires botocore >= 1.4.45') + + # rgw requires an explicit url + if rgw and not s3_url: + module.fail_json(msg='rgw flavour requires s3_url') + + # Look at s3_url and tweak connection settings + # if connecting to RGW, Walrus or fakes3 + if s3_url: + for key in ['validate_certs', 'security_token', 'profile_name']: + aws_connect_kwargs.pop(key, None) + s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url) + + validate = not ignore_nonexistent_bucket + + # separate types of ACLs + bucket_acl = [acl for acl in module.params.get('permission') if acl in bucket_canned_acl] + object_acl = [acl for acl in module.params.get('permission') if acl in object_canned_acl] + error_acl = [acl for acl in module.params.get('permission') if acl not in bucket_canned_acl and acl not in object_canned_acl] + if error_acl: + module.fail_json(msg='Unknown permission specified: %s' % error_acl) + + # First, we check to see if the bucket exists, we get "bucket" returned. + bucketrtn = bucket_check(module, s3, bucket, validate=validate) + + if validate and mode not in ('create', 'put', 'delete') and not bucketrtn: + module.fail_json(msg="Source bucket cannot be found.") + + if mode == 'get': + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + if keyrtn is False: + if version: + module.fail_json(msg="Key %s with version id %s does not exist." % (obj, version)) + else: + module.fail_json(msg="Key %s does not exist." % obj) + + if path_check(dest) and overwrite != 'always': + if overwrite == 'never': + module.exit_json(msg="Local object already exists and overwrite is disabled.", changed=False) + if etag_compare(module, dest, s3, bucket, obj, version=version): + module.exit_json(msg="Local and remote object are identical, ignoring. Use overwrite=always parameter to force.", changed=False) + + try: + download_s3file(module, s3, bucket, obj, dest, retries, version=version) + except Sigv4Required: + s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=True) + download_s3file(module, s3, bucket, obj, dest, retries, version=version) + + if mode == 'put': + + # if putting an object in a bucket yet to be created, acls for the bucket and/or the object may be specified + # these were separated into the variables bucket_acl and object_acl above + + if not path_check(src): + module.fail_json(msg="Local object for PUT does not exist") + + if bucketrtn: + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + else: + # If the bucket doesn't exist we should create it. 
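+            # A bucket that still has to be created cannot already contain the object, and
+            # keyrtn would otherwise be referenced before assignment in the overwrite check below.
+            keyrtn = False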
+ # only use valid bucket acls for create_bucket function + module.params['permission'] = bucket_acl + create_bucket(module, s3, bucket, location) + + if keyrtn and overwrite != 'always': + if overwrite == 'never' or etag_compare(module, src, s3, bucket, obj): + # Return the download URL for the existing object + get_download_url(module, s3, bucket, obj, expiry, changed=False) + + # only use valid object acls for the upload_s3file function + module.params['permission'] = object_acl + upload_s3file(module, s3, bucket, obj, src, expiry, metadata, encrypt, headers) + + # Delete an object from a bucket, not the entire bucket + if mode == 'delobj': + if obj is None: + module.fail_json(msg="object parameter is required") + if bucket: + deletertn = delete_key(module, s3, bucket, obj) + if deletertn is True: + module.exit_json(msg="Object deleted from bucket %s." % bucket, changed=True) + else: + module.fail_json(msg="Bucket parameter is required.") + + # Delete an entire bucket, including all objects in the bucket + if mode == 'delete': + if bucket: + deletertn = delete_bucket(module, s3, bucket) + if deletertn is True: + module.exit_json(msg="Bucket %s and all keys have been deleted." % bucket, changed=True) + else: + module.fail_json(msg="Bucket parameter is required.") + + # Support for listing a set of keys + if mode == 'list': + exists = bucket_check(module, s3, bucket) + + # If the bucket does not exist then bail out + if not exists: + module.fail_json(msg="Target bucket (%s) cannot be found" % bucket) + + list_keys(module, s3, bucket, prefix, marker, max_keys) + + # Need to research how to create directories without "populating" a key, so this should just do bucket creation for now. + # WE SHOULD ENABLE SOME WAY OF CREATING AN EMPTY KEY TO CREATE "DIRECTORY" STRUCTURE, AWS CONSOLE DOES THIS. + if mode == 'create': + + # if both creating a bucket and putting an object in it, acls for the bucket and/or the object may be specified + # these were separated above into the variables bucket_acl and object_acl + + if bucket and not obj: + if bucketrtn: + module.exit_json(msg="Bucket already exists.", changed=False) + else: + # only use valid bucket acls when creating the bucket + module.params['permission'] = bucket_acl + module.exit_json(msg="Bucket created successfully", changed=create_bucket(module, s3, bucket, location)) + if bucket and obj: + if obj.endswith('/'): + dirobj = obj + else: + dirobj = obj + "/" + if bucketrtn: + if key_check(module, s3, bucket, dirobj): + module.exit_json(msg="Bucket %s and key %s already exists." % (bucket, obj), changed=False) + else: + # setting valid object acls for the create_dirkey function + module.params['permission'] = object_acl + create_dirkey(module, s3, bucket, dirobj, encrypt) + else: + # only use valid bucket acls for the create_bucket function + module.params['permission'] = bucket_acl + created = create_bucket(module, s3, bucket, location) + # only use valid object acls for the create_dirkey function + module.params['permission'] = object_acl + create_dirkey(module, s3, bucket, dirobj, encrypt) + + # Support for grabbing the time-expired URL for an object in S3/Walrus. + if mode == 'geturl': + if not bucket and not obj: + module.fail_json(msg="Bucket and Object parameters must be set") + + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + if keyrtn: + get_download_url(module, s3, bucket, obj, expiry) + else: + module.fail_json(msg="Key %s does not exist." 
% obj) + + if mode == 'getstr': + if bucket and obj: + keyrtn = key_check(module, s3, bucket, obj, version=version, validate=validate) + if keyrtn: + try: + download_s3str(module, s3, bucket, obj, version=version) + except Sigv4Required: + s3 = get_s3_connection(module, aws_connect_kwargs, location, rgw, s3_url, sig_4=True) + download_s3str(module, s3, bucket, obj, version=version) + elif version is not None: + module.fail_json(msg="Key %s with version id %s does not exist." % (obj, version)) + else: + module.fail_json(msg="Key %s does not exist." % obj) + + module.exit_json(failed=False) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_appserviceplan.py b/test/support/integration/plugins/modules/azure_rm_appserviceplan.py new file mode 100644 index 00000000..ee871c35 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_appserviceplan.py @@ -0,0 +1,379 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_appserviceplan +version_added: "2.7" +short_description: Manage App Service Plan +description: + - Create, update and delete instance of App Service Plan. + +options: + resource_group: + description: + - Name of the resource group to which the resource belongs. + required: True + + name: + description: + - Unique name of the app service plan to create or update. + required: True + + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + + sku: + description: + - The pricing tiers, e.g., C(F1), C(D1), C(B1), C(B2), C(B3), C(S1), C(P1), C(P1V2) etc. + - Please see U(https://azure.microsoft.com/en-us/pricing/details/app-service/plans/) for more detail. + - For Linux app service plan, please see U(https://azure.microsoft.com/en-us/pricing/details/app-service/linux/) for more detail. + is_linux: + description: + - Describe whether to host webapp on Linux worker. + type: bool + default: false + + number_of_workers: + description: + - Describe number of workers to be allocated. + + state: + description: + - Assert the state of the app service plan. + - Use C(present) to create or update an app service plan and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Yunge Zhu (@yungezz) + +''' + +EXAMPLES = ''' + - name: Create a windows app service plan + azure_rm_appserviceplan: + resource_group: myResourceGroup + name: myAppPlan + location: eastus + sku: S1 + + - name: Create a linux app service plan + azure_rm_appserviceplan: + resource_group: myResourceGroup + name: myAppPlan + location: eastus + sku: S1 + is_linux: true + number_of_workers: 1 + + - name: update sku of existing windows app service plan + azure_rm_appserviceplan: + resource_group: myResourceGroup + name: myAppPlan + location: eastus + sku: S2 +''' + +RETURN = ''' +azure_appserviceplan: + description: Facts about the current state of the app service plan. 
+ returned: always + type: dict + sample: { + "id": "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/myAppPlan" + } +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrestazure.azure_operation import AzureOperationPoller + from msrest.serialization import Model + from azure.mgmt.web.models import ( + app_service_plan, AppServicePlan, SkuDescription + ) +except ImportError: + # This is handled in azure_rm_common + pass + + +def _normalize_sku(sku): + if sku is None: + return sku + + sku = sku.upper() + if sku == 'FREE': + return 'F1' + elif sku == 'SHARED': + return 'D1' + return sku + + +def get_sku_name(tier): + tier = tier.upper() + if tier == 'F1' or tier == "FREE": + return 'FREE' + elif tier == 'D1' or tier == "SHARED": + return 'SHARED' + elif tier in ['B1', 'B2', 'B3', 'BASIC']: + return 'BASIC' + elif tier in ['S1', 'S2', 'S3']: + return 'STANDARD' + elif tier in ['P1', 'P2', 'P3']: + return 'PREMIUM' + elif tier in ['P1V2', 'P2V2', 'P3V2']: + return 'PREMIUMV2' + else: + return None + + +def appserviceplan_to_dict(plan): + return dict( + id=plan.id, + name=plan.name, + kind=plan.kind, + location=plan.location, + reserved=plan.reserved, + is_linux=plan.reserved, + provisioning_state=plan.provisioning_state, + status=plan.status, + target_worker_count=plan.target_worker_count, + sku=dict( + name=plan.sku.name, + size=plan.sku.size, + tier=plan.sku.tier, + family=plan.sku.family, + capacity=plan.sku.capacity + ), + resource_group=plan.resource_group, + number_of_sites=plan.number_of_sites, + tags=plan.tags if plan.tags else None + ) + + +class AzureRMAppServicePlans(AzureRMModuleBase): + """Configuration class for an Azure RM App Service Plan resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + location=dict( + type='str' + ), + sku=dict( + type='str' + ), + is_linux=dict( + type='bool', + default=False + ), + number_of_workers=dict( + type='str' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.name = None + self.location = None + + self.sku = None + self.is_linux = None + self.number_of_workers = 1 + + self.tags = None + + self.results = dict( + changed=False, + ansible_facts=dict(azure_appserviceplan=None) + ) + self.state = None + + super(AzureRMAppServicePlans, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if kwargs[key]: + setattr(self, key, kwargs[key]) + + old_response = None + response = None + to_be_updated = False + + # set location + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + self.location = resource_group.location + + # get app service plan + old_response = self.get_plan() + + # if not existing + if not old_response: + self.log("App Service plan doesn't exist") + + if self.state == "present": + to_be_updated = True + + if not self.sku: + self.fail('Please specify sku in plan when creation') + + else: + # existing app service plan, do update + self.log("App Service Plan already exists") + + if self.state == 'present': + 
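+                # Compare the existing plan against the requested tags, sku and worker count to
+                # decide whether an update is needed; 'reserved' (is_linux) cannot be changed in place.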
self.log('Result: {0}'.format(old_response)) + + update_tags, newtags = self.update_tags(old_response.get('tags', dict())) + + if update_tags: + to_be_updated = True + self.tags = newtags + + # check if sku changed + if self.sku and _normalize_sku(self.sku) != old_response['sku']['size']: + to_be_updated = True + + # check if number_of_workers changed + if self.number_of_workers and int(self.number_of_workers) != old_response['sku']['capacity']: + to_be_updated = True + + if self.is_linux and self.is_linux != old_response['reserved']: + self.fail("Operation not allowed: cannot update reserved of app service plan.") + + if old_response: + self.results['id'] = old_response['id'] + + if to_be_updated: + self.log('Need to Create/Update app service plan') + self.results['changed'] = True + + if self.check_mode: + return self.results + + response = self.create_or_update_plan() + self.results['id'] = response['id'] + + if self.state == 'absent' and old_response: + self.log("Delete app service plan") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_plan() + + self.log('App service plan instance deleted') + + return self.results + + def get_plan(self): + ''' + Gets app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Get App Service Plan {0}".format(self.name)) + + try: + response = self.web_client.app_service_plans.get(self.resource_group, self.name) + if response: + self.log("Response : {0}".format(response)) + self.log("App Service Plan : {0} found".format(response.name)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + self.log("Didn't find app service plan {0} in resource group {1}".format(self.name, self.resource_group)) + + return False + + def create_or_update_plan(self): + ''' + Creates app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Create App Service Plan {0}".format(self.name)) + + try: + # normalize sku + sku = _normalize_sku(self.sku) + + sku_def = SkuDescription(tier=get_sku_name( + sku), name=sku, capacity=self.number_of_workers) + plan_def = AppServicePlan( + location=self.location, app_service_plan_name=self.name, sku=sku_def, reserved=self.is_linux, tags=self.tags if self.tags else None) + + response = self.web_client.app_service_plans.create_or_update(self.resource_group, self.name, plan_def) + + if isinstance(response, LROPoller) or isinstance(response, AzureOperationPoller): + response = self.get_poller_result(response) + + self.log("Response : {0}".format(response)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + self.fail("Failed to create app service plan {0} in resource group {1}: {2}".format(self.name, self.resource_group, str(ex))) + + def delete_plan(self): + ''' + Deletes specified App service plan in the specified subscription and resource group. 
+ + :return: True + ''' + self.log("Deleting the App service plan {0}".format(self.name)) + try: + response = self.web_client.app_service_plans.delete(resource_group_name=self.resource_group, + name=self.name) + except CloudError as e: + self.log('Error attempting to delete App service plan.') + self.fail( + "Error deleting the App service plan : {0}".format(str(e))) + + return True + + +def main(): + """Main execution""" + AzureRMAppServicePlans() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_functionapp.py b/test/support/integration/plugins/modules/azure_rm_functionapp.py new file mode 100644 index 00000000..0c372a88 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_functionapp.py @@ -0,0 +1,421 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Thomas Stringer <tomstr@microsoft.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_functionapp +version_added: "2.4" +short_description: Manage Azure Function Apps +description: + - Create, update or delete an Azure Function App. +options: + resource_group: + description: + - Name of resource group. + required: true + aliases: + - resource_group_name + name: + description: + - Name of the Azure Function App. + required: true + location: + description: + - Valid Azure location. Defaults to location of the resource group. + plan: + description: + - App service plan. + - It can be name of existing app service plan in same resource group as function app. + - It can be resource id of existing app service plan. + - Resource id. For example /subscriptions/<subs_id>/resourceGroups/<resource_group>/providers/Microsoft.Web/serverFarms/<plan_name>. + - It can be a dict which contains C(name), C(resource_group). + - C(name). Name of app service plan. + - C(resource_group). Resource group name of app service plan. + version_added: "2.8" + container_settings: + description: Web app container settings. + suboptions: + name: + description: + - Name of container. For example "imagename:tag". + registry_server_url: + description: + - Container registry server url. For example C(mydockerregistry.io). + registry_server_user: + description: + - The container registry server user name. + registry_server_password: + description: + - The container registry server password. + version_added: "2.8" + storage_account: + description: + - Name of the storage account to use. + required: true + aliases: + - storage + - storage_account_name + app_settings: + description: + - Dictionary containing application settings. + state: + description: + - Assert the state of the Function App. Use C(present) to create or update a Function App and C(absent) to delete. 
+ default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Thomas Stringer (@trstringer) +''' + +EXAMPLES = ''' +- name: Create a function app + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + storage_account: myStorageAccount + +- name: Create a function app with app settings + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + storage_account: myStorageAccount + app_settings: + setting1: value1 + setting2: value2 + +- name: Create container based function app + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + storage_account: myStorageAccount + plan: + resource_group: myResourceGroup + name: myAppPlan + container_settings: + name: httpd + registry_server_url: index.docker.io + +- name: Delete a function app + azure_rm_functionapp: + resource_group: myResourceGroup + name: myFunctionApp + state: absent +''' + +RETURN = ''' +state: + description: + - Current state of the Azure Function App. + returned: success + type: dict + example: + id: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myFunctionApp + name: myfunctionapp + kind: functionapp + location: East US + type: Microsoft.Web/sites + state: Running + host_names: + - myfunctionapp.azurewebsites.net + repository_site_name: myfunctionapp + usage_state: Normal + enabled: true + enabled_host_names: + - myfunctionapp.azurewebsites.net + - myfunctionapp.scm.azurewebsites.net + availability_state: Normal + host_name_ssl_states: + - name: myfunctionapp.azurewebsites.net + ssl_state: Disabled + host_type: Standard + - name: myfunctionapp.scm.azurewebsites.net + ssl_state: Disabled + host_type: Repository + server_farm_id: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/EastUSPlan + reserved: false + last_modified_time_utc: 2017-08-22T18:54:01.190Z + scm_site_also_stopped: false + client_affinity_enabled: true + client_cert_enabled: false + host_names_disabled: false + outbound_ip_addresses: ............ 
+ container_size: 1536 + daily_memory_time_quota: 0 + resource_group: myResourceGroup + default_host_name: myfunctionapp.azurewebsites.net +''' # NOQA + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from azure.mgmt.web.models import ( + site_config, app_service_plan, Site, SiteConfig, NameValuePair, SiteSourceControl, + AppServicePlan, SkuDescription + ) + from azure.mgmt.resource.resources import ResourceManagementClient + from msrest.polling import LROPoller +except ImportError: + # This is handled in azure_rm_common + pass + +container_settings_spec = dict( + name=dict(type='str', required=True), + registry_server_url=dict(type='str'), + registry_server_user=dict(type='str'), + registry_server_password=dict(type='str', no_log=True) +) + + +class AzureRMFunctionApp(AzureRMModuleBase): + + def __init__(self): + + self.module_arg_spec = dict( + resource_group=dict(type='str', required=True, aliases=['resource_group_name']), + name=dict(type='str', required=True), + state=dict(type='str', default='present', choices=['present', 'absent']), + location=dict(type='str'), + storage_account=dict( + type='str', + aliases=['storage', 'storage_account_name'] + ), + app_settings=dict(type='dict'), + plan=dict( + type='raw' + ), + container_settings=dict( + type='dict', + options=container_settings_spec + ) + ) + + self.results = dict( + changed=False, + state=dict() + ) + + self.resource_group = None + self.name = None + self.state = None + self.location = None + self.storage_account = None + self.app_settings = None + self.plan = None + self.container_settings = None + + required_if = [('state', 'present', ['storage_account'])] + + super(AzureRMFunctionApp, self).__init__( + self.module_arg_spec, + supports_check_mode=True, + required_if=required_if + ) + + def exec_module(self, **kwargs): + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + if self.app_settings is None: + self.app_settings = dict() + + try: + resource_group = self.rm_client.resource_groups.get(self.resource_group) + except CloudError: + self.fail('Unable to retrieve resource group') + + self.location = self.location or resource_group.location + + try: + function_app = self.web_client.web_apps.get( + resource_group_name=self.resource_group, + name=self.name + ) + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + exists = function_app is not None + except CloudError as exc: + exists = False + + if self.state == 'absent': + if exists: + if self.check_mode: + self.results['changed'] = True + return self.results + try: + self.web_client.web_apps.delete( + resource_group_name=self.resource_group, + name=self.name + ) + self.results['changed'] = True + except CloudError as exc: + self.fail('Failure while deleting web app: {0}'.format(exc)) + else: + self.results['changed'] = False + else: + kind = 'functionapp' + linux_fx_version = None + if self.container_settings and self.container_settings.get('name'): + kind = 'functionapp,linux,container' + linux_fx_version = 'DOCKER|' + if self.container_settings.get('registry_server_url'): + self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url'] + linux_fx_version += self.container_settings['registry_server_url'] + '/' + linux_fx_version += self.container_settings['name'] + if self.container_settings.get('registry_server_user'): + self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = 
self.container_settings.get('registry_server_user') + + if self.container_settings.get('registry_server_password'): + self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings.get('registry_server_password') + + if not self.plan and function_app: + self.plan = function_app.server_farm_id + + if not exists: + function_app = Site( + location=self.location, + kind=kind, + site_config=SiteConfig( + app_settings=self.aggregated_app_settings(), + scm_type='LocalGit' + ) + ) + self.results['changed'] = True + else: + self.results['changed'], function_app = self.update(function_app) + + # get app service plan + if self.plan: + if isinstance(self.plan, dict): + self.plan = "/subscriptions/{0}/resourceGroups/{1}/providers/Microsoft.Web/serverfarms/{2}".format( + self.subscription_id, + self.plan.get('resource_group', self.resource_group), + self.plan.get('name') + ) + function_app.server_farm_id = self.plan + + # set linux fx version + if linux_fx_version: + function_app.site_config.linux_fx_version = linux_fx_version + + if self.check_mode: + self.results['state'] = function_app.as_dict() + elif self.results['changed']: + try: + new_function_app = self.web_client.web_apps.create_or_update( + resource_group_name=self.resource_group, + name=self.name, + site_envelope=function_app + ).result() + self.results['state'] = new_function_app.as_dict() + except CloudError as exc: + self.fail('Error creating or updating web app: {0}'.format(exc)) + + return self.results + + def update(self, source_function_app): + """Update the Site object if there are any changes""" + + source_app_settings = self.web_client.web_apps.list_application_settings( + resource_group_name=self.resource_group, + name=self.name + ) + + changed, target_app_settings = self.update_app_settings(source_app_settings.properties) + + source_function_app.site_config = SiteConfig( + app_settings=target_app_settings, + scm_type='LocalGit' + ) + + return changed, source_function_app + + def update_app_settings(self, source_app_settings): + """Update app settings""" + + target_app_settings = self.aggregated_app_settings() + target_app_settings_dict = dict([(i.name, i.value) for i in target_app_settings]) + return target_app_settings_dict != source_app_settings, target_app_settings + + def necessary_functionapp_settings(self): + """Construct the necessary app settings required for an Azure Function App""" + + function_app_settings = [] + + if self.container_settings is None: + for key in ['AzureWebJobsStorage', 'WEBSITE_CONTENTAZUREFILECONNECTIONSTRING', 'AzureWebJobsDashboard']: + function_app_settings.append(NameValuePair(name=key, value=self.storage_connection_string)) + function_app_settings.append(NameValuePair(name='FUNCTIONS_EXTENSION_VERSION', value='~1')) + function_app_settings.append(NameValuePair(name='WEBSITE_NODE_DEFAULT_VERSION', value='6.5.0')) + function_app_settings.append(NameValuePair(name='WEBSITE_CONTENTSHARE', value=self.name)) + else: + function_app_settings.append(NameValuePair(name='FUNCTIONS_EXTENSION_VERSION', value='~2')) + function_app_settings.append(NameValuePair(name='WEBSITES_ENABLE_APP_SERVICE_STORAGE', value=False)) + function_app_settings.append(NameValuePair(name='AzureWebJobsStorage', value=self.storage_connection_string)) + + return function_app_settings + + def aggregated_app_settings(self): + """Combine both system and user app settings""" + + function_app_settings = self.necessary_functionapp_settings() + for app_setting_key in self.app_settings: + found_setting = None + for s in 
function_app_settings: + if s.name == app_setting_key: + found_setting = s + break + if found_setting: + found_setting.value = self.app_settings[app_setting_key] + else: + function_app_settings.append(NameValuePair( + name=app_setting_key, + value=self.app_settings[app_setting_key] + )) + return function_app_settings + + @property + def storage_connection_string(self): + """Construct the storage account connection string""" + + return 'DefaultEndpointsProtocol=https;AccountName={0};AccountKey={1}'.format( + self.storage_account, + self.storage_key + ) + + @property + def storage_key(self): + """Retrieve the storage account key""" + + return self.storage_client.storage_accounts.list_keys( + resource_group_name=self.resource_group, + account_name=self.storage_account + ).keys[0].value + + +def main(): + """Main function execution""" + + AzureRMFunctionApp() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_functionapp_info.py b/test/support/integration/plugins/modules/azure_rm_functionapp_info.py new file mode 100644 index 00000000..40672f95 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_functionapp_info.py @@ -0,0 +1,207 @@ +#!/usr/bin/python +# +# Copyright (c) 2016 Thomas Stringer, <tomstr@microsoft.com> + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_functionapp_info +version_added: "2.9" +short_description: Get Azure Function App facts +description: + - Get facts for one Azure Function App or all Function Apps within a resource group. +options: + name: + description: + - Only show results for a specific Function App. + resource_group: + description: + - Limit results to a resource group. Required when filtering by name. + aliases: + - resource_group_name + tags: + description: + - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'. + +extends_documentation_fragment: + - azure + +author: + - Thomas Stringer (@trstringer) +''' + +EXAMPLES = ''' + - name: Get facts for one Function App + azure_rm_functionapp_info: + resource_group: myResourceGroup + name: myfunctionapp + + - name: Get facts for all Function Apps in a resource group + azure_rm_functionapp_info: + resource_group: myResourceGroup + + - name: Get facts for all Function Apps by tags + azure_rm_functionapp_info: + tags: + - testing +''' + +RETURN = ''' +azure_functionapps: + description: + - List of Azure Function Apps dicts. 
+ returned: always + type: list + example: + id: /subscriptions/.../resourceGroups/ansible-rg/providers/Microsoft.Web/sites/myfunctionapp + name: myfunctionapp + kind: functionapp + location: East US + type: Microsoft.Web/sites + state: Running + host_names: + - myfunctionapp.azurewebsites.net + repository_site_name: myfunctionapp + usage_state: Normal + enabled: true + enabled_host_names: + - myfunctionapp.azurewebsites.net + - myfunctionapp.scm.azurewebsites.net + availability_state: Normal + host_name_ssl_states: + - name: myfunctionapp.azurewebsites.net + ssl_state: Disabled + host_type: Standard + - name: myfunctionapp.scm.azurewebsites.net + ssl_state: Disabled + host_type: Repository + server_farm_id: /subscriptions/.../resourceGroups/ansible-rg/providers/Microsoft.Web/serverfarms/EastUSPlan + reserved: false + last_modified_time_utc: 2017-08-22T18:54:01.190Z + scm_site_also_stopped: false + client_affinity_enabled: true + client_cert_enabled: false + host_names_disabled: false + outbound_ip_addresses: ............ + container_size: 1536 + daily_memory_time_quota: 0 + resource_group: myResourceGroup + default_host_name: myfunctionapp.azurewebsites.net +''' + +try: + from msrestazure.azure_exceptions import CloudError +except Exception: + # This is handled in azure_rm_common + pass + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + + +class AzureRMFunctionAppInfo(AzureRMModuleBase): + def __init__(self): + + self.module_arg_spec = dict( + name=dict(type='str'), + resource_group=dict(type='str', aliases=['resource_group_name']), + tags=dict(type='list'), + ) + + self.results = dict( + changed=False, + ansible_info=dict(azure_functionapps=[]) + ) + + self.name = None + self.resource_group = None + self.tags = None + + super(AzureRMFunctionAppInfo, self).__init__( + self.module_arg_spec, + supports_tags=False, + facts_module=True + ) + + def exec_module(self, **kwargs): + + is_old_facts = self.module._name == 'azure_rm_functionapp_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_functionapp_facts' module has been renamed to 'azure_rm_functionapp_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if self.name and not self.resource_group: + self.fail("Parameter error: resource group required when filtering by name.") + + if self.name: + self.results['ansible_info']['azure_functionapps'] = self.get_functionapp() + elif self.resource_group: + self.results['ansible_info']['azure_functionapps'] = self.list_resource_group() + else: + self.results['ansible_info']['azure_functionapps'] = self.list_all() + + return self.results + + def get_functionapp(self): + self.log('Get properties for Function App {0}'.format(self.name)) + function_app = None + result = [] + + try: + function_app = self.web_client.web_apps.get( + self.resource_group, + self.name + ) + except CloudError: + pass + + if function_app and self.has_tags(function_app.tags, self.tags): + result = function_app.as_dict() + + return [result] + + def list_resource_group(self): + self.log('List items') + try: + response = self.web_client.web_apps.list_by_resource_group(self.resource_group) + except Exception as exc: + self.fail("Error listing for resource group {0} - {1}".format(self.resource_group, str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + results.append(item.as_dict()) + return results + + def list_all(self): + self.log('List all items') + try: + response = 
self.web_client.web_apps.list_by_resource_group(self.resource_group) + except Exception as exc: + self.fail("Error listing all items - {0}".format(str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + results.append(item.as_dict()) + return results + + +def main(): + AzureRMFunctionAppInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py new file mode 100644 index 00000000..212cf795 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration.py @@ -0,0 +1,241 @@ +#!/usr/bin/python +# +# Copyright (c) 2019 Zim Kalinowski, (@zikalino) +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbconfiguration +version_added: "2.8" +short_description: Manage Configuration instance +description: + - Create, update and delete instance of Configuration. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. + required: True + server_name: + description: + - The name of the server. + required: True + name: + description: + - The name of the server configuration. + required: True + value: + description: + - Value of the configuration. + state: + description: + - Assert the state of the MariaDB configuration. Use C(present) to update setting, or C(absent) to reset to default value. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) +''' + +EXAMPLES = ''' + - name: Update SQL Server setting + azure_rm_mariadbconfiguration: + resource_group: myResourceGroup + server_name: myServer + name: event_scheduler + value: "ON" +''' + +RETURN = ''' +id: + description: + - Resource ID. 
+ returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/myServer/confi + gurations/event_scheduler" +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from azure.mgmt.rdbms.mysql import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbConfiguration(AzureRMModuleBase): + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + value=dict( + type='str' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.server_name = None + self.name = None + self.value = None + + self.results = dict(changed=False) + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbConfiguration, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=False) + + def exec_module(self, **kwargs): + + for key in list(self.module_arg_spec.keys()): + if hasattr(self, key): + setattr(self, key, kwargs[key]) + + old_response = None + response = None + + old_response = self.get_configuration() + + if not old_response: + self.log("Configuration instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("Configuration instance already exists") + if self.state == 'absent' and old_response['source'] == 'user-override': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if Configuration instance has to be deleted or may be updated") + if self.value != old_response.get('value'): + self.to_do = Actions.Update + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the Configuration instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_configuration() + + self.results['changed'] = True + self.log("Creation / Update done") + elif self.to_do == Actions.Delete: + self.log("Configuration instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_configuration() + else: + self.log("Configuration instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + + return self.results + + def create_update_configuration(self): + self.log("Creating / Updating the Configuration instance {0}".format(self.name)) + + try: + response = self.mariadb_client.configurations.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name, + value=self.value, + source='user-override') + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the Configuration instance.') + self.fail("Error creating the Configuration instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_configuration(self): + 
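+        # "Deleting" a configuration means resetting it to the service default: the setting is
+        # re-applied below with source='system-default' rather than calling a delete operation.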
self.log("Deleting the Configuration instance {0}".format(self.name)) + try: + response = self.mariadb_client.configurations.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name, + source='system-default') + except CloudError as e: + self.log('Error attempting to delete the Configuration instance.') + self.fail("Error deleting the Configuration instance: {0}".format(str(e))) + + return True + + def get_configuration(self): + self.log("Checking if the Configuration instance {0} is present".format(self.name)) + found = False + try: + response = self.mariadb_client.configurations.get(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("Configuration instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the Configuration instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbConfiguration() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py new file mode 100644 index 00000000..3faac5eb --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbconfiguration_info.py @@ -0,0 +1,217 @@ +#!/usr/bin/python +# +# Copyright (c) 2019 Zim Kalinowski, (@zikalino) +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbconfiguration_info +version_added: "2.9" +short_description: Get Azure MariaDB Configuration facts +description: + - Get facts of Azure MariaDB Configuration. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + type: str + server_name: + description: + - The name of the server. + required: True + type: str + name: + description: + - Setting name. + type: str + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get specific setting of MariaDB Server + azure_rm_mariadbconfiguration_info: + resource_group: myResourceGroup + server_name: testserver + name: deadlock_timeout + + - name: Get all settings of MariaDB Server + azure_rm_mariadbconfiguration_info: + resource_group: myResourceGroup + server_name: server_name +''' + +RETURN = ''' +settings: + description: + - A list of dictionaries containing MariaDB Server settings. + returned: always + type: complex + contains: + id: + description: + - Setting resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver + /configurations/deadlock_timeout" + name: + description: + - Setting name. + returned: always + type: str + sample: deadlock_timeout + value: + description: + - Setting value. 
+ returned: always + type: raw + sample: 1000 + description: + description: + - Description of the configuration. + returned: always + type: str + sample: Deadlock timeout. + source: + description: + - Source of the configuration. + returned: always + type: str + sample: system-default +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_operation import AzureOperationPoller + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbConfigurationInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict(changed=False) + self.mgmt_client = None + self.resource_group = None + self.server_name = None + self.name = None + super(AzureRMMariaDbConfigurationInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbconfiguration_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbconfiguration_facts' module has been renamed to 'azure_rm_mariadbconfiguration_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if self.name is not None: + self.results['settings'] = self.get() + else: + self.results['settings'] = self.list_by_server() + return self.results + + def get(self): + ''' + Gets facts of the specified MariaDB Configuration. + + :return: deserialized MariaDB Configurationinstance state dictionary + ''' + response = None + results = [] + try: + response = self.mgmt_client.configurations.get(resource_group_name=self.resource_group, + server_name=self.server_name, + configuration_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for Configurations.') + + if response is not None: + results.append(self.format_item(response)) + + return results + + def list_by_server(self): + ''' + Gets facts of the specified MariaDB Configuration. 
+ + :return: deserialized MariaDB Configurationinstance state dictionary + ''' + response = None + results = [] + try: + response = self.mgmt_client.configurations.list_by_server(resource_group_name=self.resource_group, + server_name=self.server_name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for Configurations.') + + if response is not None: + for item in response: + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'resource_group': self.resource_group, + 'server_name': self.server_name, + 'id': d['id'], + 'name': d['name'], + 'value': d['value'], + 'description': d['description'], + 'source': d['source'] + } + return d + + +def main(): + AzureRMMariaDbConfigurationInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py new file mode 100644 index 00000000..8492b968 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase.py @@ -0,0 +1,304 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbdatabase +version_added: "2.8" +short_description: Manage MariaDB Database instance +description: + - Create, update and delete instance of MariaDB Database. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + server_name: + description: + - The name of the server. + required: True + name: + description: + - The name of the database. + required: True + charset: + description: + - The charset of the database. Check MariaDB documentation for possible values. + - This is only set on creation, use I(force_update) to recreate a database if the values don't match. + collation: + description: + - The collation of the database. Check MariaDB documentation for possible values. + - This is only set on creation, use I(force_update) to recreate a database if the values don't match. + force_update: + description: + - When set to C(true), will delete and recreate the existing MariaDB database if any of the properties don't match what is set. + - When set to C(false), no change will occur to the database even if any of the properties do not match. + type: bool + default: 'no' + state: + description: + - Assert the state of the MariaDB Database. Use C(present) to create or update a database and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Create (or update) MariaDB Database + azure_rm_mariadbdatabase: + resource_group: myResourceGroup + server_name: testserver + name: db1 +''' + +RETURN = ''' +id: + description: + - Resource ID. 
+ returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/databases/db1 +name: + description: + - Resource name. + returned: always + type: str + sample: db1 +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbDatabase(AzureRMModuleBase): + """Configuration class for an Azure RM MariaDB Database resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + charset=dict( + type='str' + ), + collation=dict( + type='str' + ), + force_update=dict( + type='bool', + default=False + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.server_name = None + self.name = None + self.force_update = None + self.parameters = dict() + + self.results = dict(changed=False) + self.mgmt_client = None + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbDatabase, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=False) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()): + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "charset": + self.parameters["charset"] = kwargs[key] + elif key == "collation": + self.parameters["collation"] = kwargs[key] + + old_response = None + response = None + + self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + resource_group = self.get_resource_group(self.resource_group) + + old_response = self.get_mariadbdatabase() + + if not old_response: + self.log("MariaDB Database instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("MariaDB Database instance already exists") + if self.state == 'absent': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if MariaDB Database instance has to be deleted or may be updated") + if ('collation' in self.parameters) and (self.parameters['collation'] != old_response['collation']): + self.to_do = Actions.Update + if ('charset' in self.parameters) and (self.parameters['charset'] != old_response['charset']): + self.to_do = Actions.Update + if self.to_do == Actions.Update: + if self.force_update: + if not self.check_mode: + self.delete_mariadbdatabase() + else: + self.fail("Database properties cannot be updated without setting 'force_update' option") + self.to_do = Actions.NoAction + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the MariaDB Database instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_mariadbdatabase() + self.results['changed'] = True + self.log("Creation 
/ Update done") + elif self.to_do == Actions.Delete: + self.log("MariaDB Database instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_mariadbdatabase() + # make sure instance is actually deleted, for some Azure resources, instance is hanging around + # for some time after deletion -- this should be really fixed in Azure + while self.get_mariadbdatabase(): + time.sleep(20) + else: + self.log("MariaDB Database instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + self.results["name"] = response["name"] + + return self.results + + def create_update_mariadbdatabase(self): + ''' + Creates or updates MariaDB Database with the specified configuration. + + :return: deserialized MariaDB Database instance state dictionary + ''' + self.log("Creating / Updating the MariaDB Database instance {0}".format(self.name)) + + try: + response = self.mgmt_client.databases.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name, + parameters=self.parameters) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the MariaDB Database instance.') + self.fail("Error creating the MariaDB Database instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_mariadbdatabase(self): + ''' + Deletes specified MariaDB Database instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the MariaDB Database instance {0}".format(self.name)) + try: + response = self.mgmt_client.databases.delete(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name) + except CloudError as e: + self.log('Error attempting to delete the MariaDB Database instance.') + self.fail("Error deleting the MariaDB Database instance: {0}".format(str(e))) + + return True + + def get_mariadbdatabase(self): + ''' + Gets the properties of the specified MariaDB Database. 
+ + :return: deserialized MariaDB Database instance state dictionary + ''' + self.log("Checking if the MariaDB Database instance {0} is present".format(self.name)) + found = False + try: + response = self.mgmt_client.databases.get(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("MariaDB Database instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the MariaDB Database instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbDatabase() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py new file mode 100644 index 00000000..e9c99c14 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbdatabase_info.py @@ -0,0 +1,212 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbdatabase_info +version_added: "2.9" +short_description: Get Azure MariaDB Database facts +description: + - Get facts of MariaDB Database. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + type: str + server_name: + description: + - The name of the server. + required: True + type: str + name: + description: + - The name of the database. + type: str + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get instance of MariaDB Database + azure_rm_mariadbdatabase_info: + resource_group: myResourceGroup + server_name: server_name + name: database_name + + - name: List instances of MariaDB Database + azure_rm_mariadbdatabase_info: + resource_group: myResourceGroup + server_name: server_name +''' + +RETURN = ''' +databases: + description: + - A list of dictionaries containing facts for MariaDB Databases. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testser + ver/databases/db1" + resource_group: + description: + - Resource group name. + returned: always + type: str + sample: testrg + server_name: + description: + - Server name. + returned: always + type: str + sample: testserver + name: + description: + - Resource name. + returned: always + type: str + sample: db1 + charset: + description: + - The charset of the database. + returned: always + type: str + sample: UTF8 + collation: + description: + - The collation of the database. 
+ returned: always + type: str + sample: English_United States.1252 +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbDatabaseInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict( + changed=False + ) + self.resource_group = None + self.server_name = None + self.name = None + super(AzureRMMariaDbDatabaseInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbdatabase_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbdatabase_facts' module has been renamed to 'azure_rm_mariadbdatabase_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if (self.resource_group is not None and + self.server_name is not None and + self.name is not None): + self.results['databases'] = self.get() + elif (self.resource_group is not None and + self.server_name is not None): + self.results['databases'] = self.list_by_server() + return self.results + + def get(self): + response = None + results = [] + try: + response = self.mariadb_client.databases.get(resource_group_name=self.resource_group, + server_name=self.server_name, + database_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for Databases.') + + if response is not None: + results.append(self.format_item(response)) + + return results + + def list_by_server(self): + response = None + results = [] + try: + response = self.mariadb_client.databases.list_by_server(resource_group_name=self.resource_group, + server_name=self.server_name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.fail("Error listing for server {0} - {1}".format(self.server_name, str(e))) + + if response is not None: + for item in response: + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'resource_group': self.resource_group, + 'server_name': self.server_name, + 'name': d['name'], + 'charset': d['charset'], + 'collation': d['collation'] + } + return d + + +def main(): + AzureRMMariaDbDatabaseInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py new file mode 100644 index 00000000..1fc8c5e7 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule.py @@ -0,0 +1,277 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 
'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbfirewallrule +version_added: "2.8" +short_description: Manage MariaDB firewall rule instance +description: + - Create, update and delete instance of MariaDB firewall rule. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + server_name: + description: + - The name of the server. + required: True + name: + description: + - The name of the MariaDB firewall rule. + required: True + start_ip_address: + description: + - The start IP address of the MariaDB firewall rule. Must be IPv4 format. + end_ip_address: + description: + - The end IP address of the MariaDB firewall rule. Must be IPv4 format. + state: + description: + - Assert the state of the MariaDB firewall rule. Use C(present) to create or update a rule and C(absent) to ensure it is not present. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Create (or update) MariaDB firewall rule + azure_rm_mariadbfirewallrule: + resource_group: myResourceGroup + server_name: testserver + name: rule1 + start_ip_address: 10.0.0.17 + end_ip_address: 10.0.0.20 +''' + +RETURN = ''' +id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/testserver/fire + wallRules/rule1" +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbFirewallRule(AzureRMModuleBase): + """Configuration class for an Azure RM MariaDB firewall rule resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + start_ip_address=dict( + type='str' + ), + end_ip_address=dict( + type='str' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.server_name = None + self.name = None + self.start_ip_address = None + self.end_ip_address = None + + self.results = dict(changed=False) + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbFirewallRule, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=False) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()): + if hasattr(self, key): + setattr(self, key, kwargs[key]) + + old_response = None + response = None + + resource_group = self.get_resource_group(self.resource_group) + + old_response = self.get_firewallrule() + + if not old_response: + self.log("MariaDB firewall rule instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("MariaDB firewall rule instance already 
exists") + if self.state == 'absent': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if MariaDB firewall rule instance has to be deleted or may be updated") + if (self.start_ip_address is not None) and (self.start_ip_address != old_response['start_ip_address']): + self.to_do = Actions.Update + if (self.end_ip_address is not None) and (self.end_ip_address != old_response['end_ip_address']): + self.to_do = Actions.Update + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the MariaDB firewall rule instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_firewallrule() + + if not old_response: + self.results['changed'] = True + else: + self.results['changed'] = old_response.__ne__(response) + self.log("Creation / Update done") + elif self.to_do == Actions.Delete: + self.log("MariaDB firewall rule instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_firewallrule() + # make sure instance is actually deleted, for some Azure resources, instance is hanging around + # for some time after deletion -- this should be really fixed in Azure + while self.get_firewallrule(): + time.sleep(20) + else: + self.log("MariaDB firewall rule instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + + return self.results + + def create_update_firewallrule(self): + ''' + Creates or updates MariaDB firewall rule with the specified configuration. + + :return: deserialized MariaDB firewall rule instance state dictionary + ''' + self.log("Creating / Updating the MariaDB firewall rule instance {0}".format(self.name)) + + try: + response = self.mariadb_client.firewall_rules.create_or_update(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name, + start_ip_address=self.start_ip_address, + end_ip_address=self.end_ip_address) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the MariaDB firewall rule instance.') + self.fail("Error creating the MariaDB firewall rule instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_firewallrule(self): + ''' + Deletes specified MariaDB firewall rule instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the MariaDB firewall rule instance {0}".format(self.name)) + try: + response = self.mariadb_client.firewall_rules.delete(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name) + except CloudError as e: + self.log('Error attempting to delete the MariaDB firewall rule instance.') + self.fail("Error deleting the MariaDB firewall rule instance: {0}".format(str(e))) + + return True + + def get_firewallrule(self): + ''' + Gets the properties of the specified MariaDB firewall rule. 
+ + :return: deserialized MariaDB firewall rule instance state dictionary + ''' + self.log("Checking if the MariaDB firewall rule instance {0} is present".format(self.name)) + found = False + try: + response = self.mariadb_client.firewall_rules.get(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("MariaDB firewall rule instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the MariaDB firewall rule instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbFirewallRule() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py new file mode 100644 index 00000000..ef71be8d --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbfirewallrule_info.py @@ -0,0 +1,208 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbfirewallrule_info +version_added: "2.9" +short_description: Get Azure MariaDB Firewall Rule facts +description: + - Get facts of Azure MariaDB Firewall Rule. + +options: + resource_group: + description: + - The name of the resource group. + required: True + type: str + server_name: + description: + - The name of the server. + required: True + type: str + name: + description: + - The name of the server firewall rule. + type: str + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get instance of MariaDB Firewall Rule + azure_rm_mariadbfirewallrule_info: + resource_group: myResourceGroup + server_name: server_name + name: firewall_rule_name + + - name: List instances of MariaDB Firewall Rule + azure_rm_mariadbfirewallrule_info: + resource_group: myResourceGroup + server_name: server_name +''' + +RETURN = ''' +rules: + description: + - A list of dictionaries containing facts for MariaDB Firewall Rule. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/TestGroup/providers/Microsoft.DBforMariaDB/servers/testserver/fire + wallRules/rule1" + server_name: + description: + - The name of the server. + returned: always + type: str + sample: testserver + name: + description: + - Resource name. + returned: always + type: str + sample: rule1 + start_ip_address: + description: + - The start IP address of the MariaDB firewall rule. + returned: always + type: str + sample: 10.0.0.16 + end_ip_address: + description: + - The end IP address of the MariaDB firewall rule. 
+ returned: always + type: str + sample: 10.0.0.18 +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrestazure.azure_operation import AzureOperationPoller + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbFirewallRuleInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + server_name=dict( + type='str', + required=True + ), + name=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict( + changed=False + ) + self.mgmt_client = None + self.resource_group = None + self.server_name = None + self.name = None + super(AzureRMMariaDbFirewallRuleInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbfirewallrule_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbfirewallrule_facts' module has been renamed to 'azure_rm_mariadbfirewallrule_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(MariaDBManagementClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if (self.name is not None): + self.results['rules'] = self.get() + else: + self.results['rules'] = self.list_by_server() + return self.results + + def get(self): + response = None + results = [] + try: + response = self.mgmt_client.firewall_rules.get(resource_group_name=self.resource_group, + server_name=self.server_name, + firewall_rule_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for FirewallRules.') + + if response is not None: + results.append(self.format_item(response)) + + return results + + def list_by_server(self): + response = None + results = [] + try: + response = self.mgmt_client.firewall_rules.list_by_server(resource_group_name=self.resource_group, + server_name=self.server_name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for FirewallRules.') + + if response is not None: + for item in response: + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'resource_group': self.resource_group, + 'id': d['id'], + 'server_name': self.server_name, + 'name': d['name'], + 'start_ip_address': d['start_ip_address'], + 'end_ip_address': d['end_ip_address'] + } + return d + + +def main(): + AzureRMMariaDbFirewallRuleInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbserver.py b/test/support/integration/plugins/modules/azure_rm_mariadbserver.py new file mode 100644 index 00000000..30a29988 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbserver.py @@ -0,0 +1,388 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + 
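Editor's note: the MariaDB database, firewall-rule, and server modules in this patch all follow the same exec_module pattern: fetch the existing resource, compare the requested properties, pick one of NoAction/Create/Update/Delete, short-circuit in check mode, and after a delete keep polling until the resource is really gone. The condensed, standalone sketch below is for reference only; the helper names resolve_action and wait_until_deleted are illustrative and do not appear in these modules.

    import time


    class Actions:
        NoAction, Create, Update, Delete = range(4)


    def resolve_action(state, old, desired):
        # Mirror the modules' decision logic:
        #   state   -- 'present' or 'absent'
        #   old     -- dict describing the existing resource, or None if it does not exist
        #   desired -- dict of only the properties the user actually set
        if old is None:
            return Actions.Create if state == 'present' else Actions.NoAction
        if state == 'absent':
            return Actions.Delete
        if any(old.get(key) != value for key, value in desired.items()):
            return Actions.Update
        return Actions.NoAction


    def wait_until_deleted(get_resource, interval=20, timeout=600):
        # The modules poll with time.sleep(20) until the resource disappears;
        # a timeout keeps the loop from hanging if Azure never finishes the delete.
        deadline = time.time() + timeout
        while get_resource():
            if time.time() > deadline:
                raise TimeoutError('resource still present after delete')
            time.sleep(interval)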
+ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbserver +version_added: "2.8" +short_description: Manage MariaDB Server instance +description: + - Create, update and delete instance of MariaDB Server. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + name: + description: + - The name of the server. + required: True + sku: + description: + - The SKU (pricing tier) of the server. + suboptions: + name: + description: + - The name of the SKU, typically, tier + family + cores, for example C(B_Gen4_1), C(GP_Gen5_8). + tier: + description: + - The tier of the particular SKU, for example C(Basic). + choices: + - basic + - standard + capacity: + description: + - The scale up/out capacity, representing server's compute units. + type: int + size: + description: + - The size code, to be interpreted by resource as appropriate. + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + storage_mb: + description: + - The maximum storage allowed for a server. + type: int + version: + description: + - Server version. + choices: + - 10.2 + enforce_ssl: + description: + - Enable SSL enforcement. + type: bool + default: False + admin_username: + description: + - The administrator's login name of a server. Can only be specified when the server is being created (and is required for creation). + admin_password: + description: + - The password of the administrator login. + create_mode: + description: + - Create mode of SQL Server. + default: Default + state: + description: + - Assert the state of the MariaDB Server. Use C(present) to create or update a server and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Create (or update) MariaDB Server + azure_rm_mariadbserver: + resource_group: myResourceGroup + name: testserver + sku: + name: B_Gen5_1 + tier: Basic + location: eastus + storage_mb: 1024 + enforce_ssl: True + version: 10.2 + admin_username: cloudsa + admin_password: password +''' + +RETURN = ''' +id: + description: + - Resource ID. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/mariadbsrv1b6dd89593 +version: + description: + - Server version. Possible values include C(10.2). + returned: always + type: str + sample: 10.2 +state: + description: + - A state of a server that is visible to user. Possible values include C(Ready), C(Dropping), C(Disabled). + returned: always + type: str + sample: Ready +fully_qualified_domain_name: + description: + - The fully qualified domain name of a server. 
+ returned: always + type: str + sample: mariadbsrv1b6dd89593.mariadb.database.azure.com +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class Actions: + NoAction, Create, Update, Delete = range(4) + + +class AzureRMMariaDbServers(AzureRMModuleBase): + """Configuration class for an Azure RM MariaDB Server resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + sku=dict( + type='dict' + ), + location=dict( + type='str' + ), + storage_mb=dict( + type='int' + ), + version=dict( + type='str', + choices=['10.2'] + ), + enforce_ssl=dict( + type='bool', + default=False + ), + create_mode=dict( + type='str', + default='Default' + ), + admin_username=dict( + type='str' + ), + admin_password=dict( + type='str', + no_log=True + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + self.resource_group = None + self.name = None + self.parameters = dict() + self.tags = None + + self.results = dict(changed=False) + self.state = None + self.to_do = Actions.NoAction + + super(AzureRMMariaDbServers, self).__init__(derived_arg_spec=self.module_arg_spec, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "sku": + ev = kwargs[key] + if 'tier' in ev: + if ev['tier'] == 'basic': + ev['tier'] = 'Basic' + elif ev['tier'] == 'standard': + ev['tier'] = 'Standard' + self.parameters["sku"] = ev + elif key == "location": + self.parameters["location"] = kwargs[key] + elif key == "storage_mb": + self.parameters.setdefault("properties", {}).setdefault("storage_profile", {})["storage_mb"] = kwargs[key] + elif key == "version": + self.parameters.setdefault("properties", {})["version"] = kwargs[key] + elif key == "enforce_ssl": + self.parameters.setdefault("properties", {})["ssl_enforcement"] = 'Enabled' if kwargs[key] else 'Disabled' + elif key == "create_mode": + self.parameters.setdefault("properties", {})["create_mode"] = kwargs[key] + elif key == "admin_username": + self.parameters.setdefault("properties", {})["administrator_login"] = kwargs[key] + elif key == "admin_password": + self.parameters.setdefault("properties", {})["administrator_login_password"] = kwargs[key] + + old_response = None + response = None + + resource_group = self.get_resource_group(self.resource_group) + + if "location" not in self.parameters: + self.parameters["location"] = resource_group.location + + old_response = self.get_mariadbserver() + + if not old_response: + self.log("MariaDB Server instance doesn't exist") + if self.state == 'absent': + self.log("Old instance didn't exist") + else: + self.to_do = Actions.Create + else: + self.log("MariaDB Server instance already exists") + if self.state == 'absent': + self.to_do = Actions.Delete + elif self.state == 'present': + self.log("Need to check if MariaDB Server instance has to be deleted or may be updated") + update_tags, newtags = self.update_tags(old_response.get('tags', {})) + if 
update_tags: + self.tags = newtags + self.to_do = Actions.Update + + if (self.to_do == Actions.Create) or (self.to_do == Actions.Update): + self.log("Need to Create / Update the MariaDB Server instance") + + if self.check_mode: + self.results['changed'] = True + return self.results + + response = self.create_update_mariadbserver() + + if not old_response: + self.results['changed'] = True + else: + self.results['changed'] = old_response.__ne__(response) + self.log("Creation / Update done") + elif self.to_do == Actions.Delete: + self.log("MariaDB Server instance deleted") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_mariadbserver() + # make sure instance is actually deleted, for some Azure resources, instance is hanging around + # for some time after deletion -- this should be really fixed in Azure + while self.get_mariadbserver(): + time.sleep(20) + else: + self.log("MariaDB Server instance unchanged") + self.results['changed'] = False + response = old_response + + if response: + self.results["id"] = response["id"] + self.results["version"] = response["version"] + self.results["state"] = response["user_visible_state"] + self.results["fully_qualified_domain_name"] = response["fully_qualified_domain_name"] + + return self.results + + def create_update_mariadbserver(self): + ''' + Creates or updates MariaDB Server with the specified configuration. + + :return: deserialized MariaDB Server instance state dictionary + ''' + self.log("Creating / Updating the MariaDB Server instance {0}".format(self.name)) + + try: + self.parameters['tags'] = self.tags + if self.to_do == Actions.Create: + response = self.mariadb_client.servers.create(resource_group_name=self.resource_group, + server_name=self.name, + parameters=self.parameters) + else: + # structure of parameters for update must be changed + self.parameters.update(self.parameters.pop("properties", {})) + response = self.mariadb_client.servers.update(resource_group_name=self.resource_group, + server_name=self.name, + parameters=self.parameters) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the MariaDB Server instance.') + self.fail("Error creating the MariaDB Server instance: {0}".format(str(exc))) + return response.as_dict() + + def delete_mariadbserver(self): + ''' + Deletes specified MariaDB Server instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the MariaDB Server instance {0}".format(self.name)) + try: + response = self.mariadb_client.servers.delete(resource_group_name=self.resource_group, + server_name=self.name) + except CloudError as e: + self.log('Error attempting to delete the MariaDB Server instance.') + self.fail("Error deleting the MariaDB Server instance: {0}".format(str(e))) + + return True + + def get_mariadbserver(self): + ''' + Gets the properties of the specified MariaDB Server. 
+ + :return: deserialized MariaDB Server instance state dictionary + ''' + self.log("Checking if the MariaDB Server instance {0} is present".format(self.name)) + found = False + try: + response = self.mariadb_client.servers.get(resource_group_name=self.resource_group, + server_name=self.name) + found = True + self.log("Response : {0}".format(response)) + self.log("MariaDB Server instance : {0} found".format(response.name)) + except CloudError as e: + self.log('Did not find the MariaDB Server instance.') + if found is True: + return response.as_dict() + + return False + + +def main(): + """Main execution""" + AzureRMMariaDbServers() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py b/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py new file mode 100644 index 00000000..464aa4d8 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_mariadbserver_info.py @@ -0,0 +1,265 @@ +#!/usr/bin/python +# +# Copyright (c) 2017 Zim Kalinowski, <zikalino@microsoft.com> +# Copyright (c) 2019 Matti Ranta, (@techknowlogick) +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_mariadbserver_info +version_added: "2.9" +short_description: Get Azure MariaDB Server facts +description: + - Get facts of MariaDB Server. + +options: + resource_group: + description: + - The name of the resource group that contains the resource. You can obtain this value from the Azure Resource Manager API or the portal. + required: True + type: str + name: + description: + - The name of the server. + type: str + tags: + description: + - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'. + type: list + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + - Matti Ranta (@techknowlogick) + +''' + +EXAMPLES = ''' + - name: Get instance of MariaDB Server + azure_rm_mariadbserver_info: + resource_group: myResourceGroup + name: server_name + + - name: List instances of MariaDB Server + azure_rm_mariadbserver_info: + resource_group: myResourceGroup +''' + +RETURN = ''' +servers: + description: + - A list of dictionaries containing facts for MariaDB servers. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.DBforMariaDB/servers/myabdud1223 + resource_group: + description: + - Resource group name. + returned: always + type: str + sample: myResourceGroup + name: + description: + - Resource name. + returned: always + type: str + sample: myabdud1223 + location: + description: + - The location the resource resides in. + returned: always + type: str + sample: eastus + sku: + description: + - The SKU of the server. + returned: always + type: complex + contains: + name: + description: + - The name of the SKU. + returned: always + type: str + sample: GP_Gen4_2 + tier: + description: + - The tier of the particular SKU. + returned: always + type: str + sample: GeneralPurpose + capacity: + description: + - The scale capacity. 
+ returned: always + type: int + sample: 2 + storage_mb: + description: + - The maximum storage allowed for a server. + returned: always + type: int + sample: 128000 + enforce_ssl: + description: + - Enable SSL enforcement. + returned: always + type: bool + sample: False + admin_username: + description: + - The administrator's login name of a server. + returned: always + type: str + sample: serveradmin + version: + description: + - Server version. + returned: always + type: str + sample: "9.6" + user_visible_state: + description: + - A state of a server that is visible to user. + returned: always + type: str + sample: Ready + fully_qualified_domain_name: + description: + - The fully qualified domain name of a server. + returned: always + type: str + sample: myabdud1223.mys.database.azure.com + tags: + description: + - Tags assigned to the resource. Dictionary of string:string pairs. + type: dict + sample: { tag1: abc } +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from azure.mgmt.rdbms.mariadb import MariaDBManagementClient + from msrest.serialization import Model +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMMariaDbServerInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str' + ), + tags=dict( + type='list' + ) + ) + # store the results of the module operation + self.results = dict( + changed=False + ) + self.resource_group = None + self.name = None + self.tags = None + super(AzureRMMariaDbServerInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_mariadbserver_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_mariadbserver_facts' module has been renamed to 'azure_rm_mariadbserver_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if (self.resource_group is not None and + self.name is not None): + self.results['servers'] = self.get() + elif (self.resource_group is not None): + self.results['servers'] = self.list_by_resource_group() + return self.results + + def get(self): + response = None + results = [] + try: + response = self.mariadb_client.servers.get(resource_group_name=self.resource_group, + server_name=self.name) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for MariaDB Server.') + + if response and self.has_tags(response.tags, self.tags): + results.append(self.format_item(response)) + + return results + + def list_by_resource_group(self): + response = None + results = [] + try: + response = self.mariadb_client.servers.list_by_resource_group(resource_group_name=self.resource_group) + self.log("Response : {0}".format(response)) + except CloudError as e: + self.log('Could not get facts for MariaDB Servers.') + + if response is not None: + for item in response: + if self.has_tags(item.tags, self.tags): + results.append(self.format_item(item)) + + return results + + def format_item(self, item): + d = item.as_dict() + d = { + 'id': d['id'], + 'resource_group': self.resource_group, + 'name': d['name'], + 'sku': d['sku'], + 'location': d['location'], + 'storage_mb': d['storage_profile']['storage_mb'], + 'version': d['version'], + 'enforce_ssl': (d['ssl_enforcement'] == 
'Enabled'),
+            'admin_username': d['administrator_login'],
+            'user_visible_state': d['user_visible_state'],
+            'fully_qualified_domain_name': d['fully_qualified_domain_name'],
+            'tags': d.get('tags')
+        }
+
+        return d
+
+
+def main():
+    AzureRMMariaDbServerInfo()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/test/support/integration/plugins/modules/azure_rm_resource.py b/test/support/integration/plugins/modules/azure_rm_resource.py
new file mode 100644
index 00000000..6ea3e3bb
--- /dev/null
+++ b/test/support/integration/plugins/modules/azure_rm_resource.py
@@ -0,0 +1,427 @@
+#!/usr/bin/python
+#
+# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com>
+#
+# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
+
+from __future__ import absolute_import, division, print_function
+__metaclass__ = type
+
+
+ANSIBLE_METADATA = {'metadata_version': '1.1',
+                    'status': ['preview'],
+                    'supported_by': 'community'}
+
+
+DOCUMENTATION = '''
+---
+module: azure_rm_resource
+version_added: "2.6"
+short_description: Create any Azure resource
+description:
+    - Create, update or delete any Azure resource using Azure REST API.
+    - This module gives access to resources that are not supported via Ansible modules.
+    - Refer to U(https://docs.microsoft.com/en-us/rest/api/) regarding details related to specific resource REST API.
+
+options:
+    url:
+        description:
+            - Azure RM Resource URL.
+    api_version:
+        description:
+            - Specific API version to be used.
+    provider:
+        description:
+            - Provider type.
+            - Required if URL is not specified.
+    resource_group:
+        description:
+            - Resource group to be used.
+            - Required if URL is not specified.
+    resource_type:
+        description:
+            - Resource type.
+            - Required if URL is not specified.
+    resource_name:
+        description:
+            - Resource name.
+            - Required if URL is not specified.
+    subresource:
+        description:
+            - List of subresources.
+        suboptions:
+            namespace:
+                description:
+                    - Subresource namespace.
+            type:
+                description:
+                    - Subresource type.
+            name:
+                description:
+                    - Subresource name.
+    body:
+        description:
+            - The body of the HTTP request/response to the web service.
+    method:
+        description:
+            - The HTTP method of the request or response. It must be uppercase.
+        choices:
+            - GET
+            - PUT
+            - POST
+            - HEAD
+            - PATCH
+            - DELETE
+            - MERGE
+        default: "PUT"
+    status_code:
+        description:
+            - A valid, numeric, HTTP status code that signifies success of the request. Can also be a comma-separated list of status codes.
+        type: list
+        default: [ 200, 201, 202 ]
+    idempotency:
+        description:
+            - If enabled, idempotency check will be done by using I(method=GET) first and then comparing with I(body).
+        default: no
+        type: bool
+    polling_timeout:
+        description:
+            - Timeout (in seconds) to wait for a long-running operation to finish.
+        default: 0
+        type: int
+        version_added: "2.8"
+    polling_interval:
+        description:
+            - Interval (in seconds) between polls of a long-running operation.
+        default: 60
+        type: int
+        version_added: "2.8"
+    state:
+        description:
+            - Assert the state of the resource. Use C(present) to create or update resource or C(absent) to delete resource.
+ default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + +''' + +EXAMPLES = ''' + - name: Update scaleset info using azure_rm_resource + azure_rm_resource: + resource_group: myResourceGroup + provider: compute + resource_type: virtualmachinescalesets + resource_name: myVmss + api_version: "2017-12-01" + body: { body } +''' + +RETURN = ''' +response: + description: + - Response specific to resource type. + returned: always + type: complex + contains: + id: + description: + - Resource ID. + type: str + returned: always + sample: "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Storage/storageAccounts/staccb57dc95183" + kind: + description: + - The kind of storage. + type: str + returned: always + sample: Storage + location: + description: + - The resource location, defaults to location of the resource group. + type: str + returned: always + sample: eastus + name: + description: + The storage account name. + type: str + returned: always + sample: staccb57dc95183 + properties: + description: + - The storage account's related properties. + type: dict + returned: always + sample: { + "creationTime": "2019-06-13T06:34:33.0996676Z", + "encryption": { + "keySource": "Microsoft.Storage", + "services": { + "blob": { + "enabled": true, + "lastEnabledTime": "2019-06-13T06:34:33.1934074Z" + }, + "file": { + "enabled": true, + "lastEnabledTime": "2019-06-13T06:34:33.1934074Z" + } + } + }, + "networkAcls": { + "bypass": "AzureServices", + "defaultAction": "Allow", + "ipRules": [], + "virtualNetworkRules": [] + }, + "primaryEndpoints": { + "blob": "https://staccb57dc95183.blob.core.windows.net/", + "file": "https://staccb57dc95183.file.core.windows.net/", + "queue": "https://staccb57dc95183.queue.core.windows.net/", + "table": "https://staccb57dc95183.table.core.windows.net/" + }, + "primaryLocation": "eastus", + "provisioningState": "Succeeded", + "secondaryLocation": "westus", + "statusOfPrimary": "available", + "statusOfSecondary": "available", + "supportsHttpsTrafficOnly": false + } + sku: + description: + - The storage account SKU. + type: dict + returned: always + sample: { + "name": "Standard_GRS", + "tier": "Standard" + } + tags: + description: + - Resource tags. + type: dict + returned: always + sample: { 'key1': 'value1' } + type: + description: + - The resource type. 
+ type: str + returned: always + sample: "Microsoft.Storage/storageAccounts" + +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase +from ansible.module_utils.azure_rm_common_rest import GenericRestClient +from ansible.module_utils.common.dict_transformations import dict_merge + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.service_client import ServiceClient + from msrestazure.tools import resource_id, is_valid_resource_id + import json + +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMResource(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + url=dict( + type='str' + ), + provider=dict( + type='str', + ), + resource_group=dict( + type='str', + ), + resource_type=dict( + type='str', + ), + resource_name=dict( + type='str', + ), + subresource=dict( + type='list', + default=[] + ), + api_version=dict( + type='str' + ), + method=dict( + type='str', + default='PUT', + choices=["GET", "PUT", "POST", "HEAD", "PATCH", "DELETE", "MERGE"] + ), + body=dict( + type='raw' + ), + status_code=dict( + type='list', + default=[200, 201, 202] + ), + idempotency=dict( + type='bool', + default=False + ), + polling_timeout=dict( + type='int', + default=0 + ), + polling_interval=dict( + type='int', + default=60 + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + # store the results of the module operation + self.results = dict( + changed=False, + response=None + ) + self.mgmt_client = None + self.url = None + self.api_version = None + self.provider = None + self.resource_group = None + self.resource_type = None + self.resource_name = None + self.subresource_type = None + self.subresource_name = None + self.subresource = [] + self.method = None + self.status_code = [] + self.idempotency = False + self.polling_timeout = None + self.polling_interval = None + self.state = None + self.body = None + super(AzureRMResource, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(GenericRestClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if self.state == 'absent': + self.method = 'DELETE' + self.status_code.append(204) + + if self.url is None: + orphan = None + rargs = dict() + rargs['subscription'] = self.subscription_id + rargs['resource_group'] = self.resource_group + if not (self.provider is None or self.provider.lower().startswith('.microsoft')): + rargs['namespace'] = "Microsoft." 
+ self.provider + else: + rargs['namespace'] = self.provider + + if self.resource_type is not None and self.resource_name is not None: + rargs['type'] = self.resource_type + rargs['name'] = self.resource_name + for i in range(len(self.subresource)): + resource_ns = self.subresource[i].get('namespace', None) + resource_type = self.subresource[i].get('type', None) + resource_name = self.subresource[i].get('name', None) + if resource_type is not None and resource_name is not None: + rargs['child_namespace_' + str(i + 1)] = resource_ns + rargs['child_type_' + str(i + 1)] = resource_type + rargs['child_name_' + str(i + 1)] = resource_name + else: + orphan = resource_type + else: + orphan = self.resource_type + + self.url = resource_id(**rargs) + + if orphan is not None: + self.url += '/' + orphan + + # if api_version was not specified, get latest one + if not self.api_version: + try: + # extract provider and resource type + if "/providers/" in self.url: + provider = self.url.split("/providers/")[1].split("/")[0] + resourceType = self.url.split(provider + "/")[1].split("/")[0] + url = "/subscriptions/" + self.subscription_id + "/providers/" + provider + api_versions = json.loads(self.mgmt_client.query(url, "GET", {'api-version': '2015-01-01'}, None, None, [200], 0, 0).text) + for rt in api_versions['resourceTypes']: + if rt['resourceType'].lower() == resourceType.lower(): + self.api_version = rt['apiVersions'][0] + break + else: + # if there's no provider in API version, assume Microsoft.Resources + self.api_version = '2018-05-01' + if not self.api_version: + self.fail("Couldn't find api version for {0}/{1}".format(provider, resourceType)) + except Exception as exc: + self.fail("Failed to obtain API version: {0}".format(str(exc))) + + query_parameters = {} + query_parameters['api-version'] = self.api_version + + header_parameters = {} + header_parameters['Content-Type'] = 'application/json; charset=utf-8' + + needs_update = True + response = None + + if self.idempotency: + original = self.mgmt_client.query(self.url, "GET", query_parameters, None, None, [200, 404], 0, 0) + + if original.status_code == 404: + if self.state == 'absent': + needs_update = False + else: + try: + response = json.loads(original.text) + needs_update = (dict_merge(response, self.body) != response) + except Exception: + pass + + if needs_update: + response = self.mgmt_client.query(self.url, + self.method, + query_parameters, + header_parameters, + self.body, + self.status_code, + self.polling_timeout, + self.polling_interval) + if self.state == 'present': + try: + response = json.loads(response.text) + except Exception: + response = response.text + else: + response = None + + self.results['response'] = response + self.results['changed'] = needs_update + + return self.results + + +def main(): + AzureRMResource() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_resource_info.py b/test/support/integration/plugins/modules/azure_rm_resource_info.py new file mode 100644 index 00000000..f797f662 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_resource_info.py @@ -0,0 +1,432 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Zim Kalinowski, <zikalino@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + 
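# A minimal, hedged sketch (annotation, not part of the patch above) of the idempotency
# check azure_rm_resource performs before issuing a PUT: it GETs the current resource
# body, overlays the requested body with dict_merge, and only writes when the merge
# would change anything. The names current_state / requested_body are illustrative,
# not module parameters.
from ansible.module_utils.common.dict_transformations import dict_merge


def put_is_needed(current_state, requested_body):
    # dict_merge returns current_state updated with requested_body; if the result
    # equals current_state, every requested value is already in place and the PUT
    # (and therefore changed=True) can be skipped.
    return dict_merge(current_state, requested_body) != current_state


# Example: a body that only repeats existing values does not trigger an update.
#   put_is_needed({'properties': {'httpsOnly': True}}, {'properties': {'httpsOnly': True}})  -> False
#   put_is_needed({'properties': {'httpsOnly': True}}, {'properties': {'httpsOnly': False}}) -> True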
+DOCUMENTATION = ''' +--- +module: azure_rm_resource_info +version_added: "2.9" +short_description: Generic facts of Azure resources +description: + - Obtain facts of any resource using Azure REST API. + - This module gives access to resources that are not supported via Ansible modules. + - Refer to U(https://docs.microsoft.com/en-us/rest/api/) regarding details related to specific resource REST API. + +options: + url: + description: + - Azure RM Resource URL. + api_version: + description: + - Specific API version to be used. + provider: + description: + - Provider type, should be specified in no URL is given. + resource_group: + description: + - Resource group to be used. + - Required if URL is not specified. + resource_type: + description: + - Resource type. + resource_name: + description: + - Resource name. + subresource: + description: + - List of subresources. + suboptions: + namespace: + description: + - Subresource namespace. + type: + description: + - Subresource type. + name: + description: + - Subresource name. + +extends_documentation_fragment: + - azure + +author: + - Zim Kalinowski (@zikalino) + +''' + +EXAMPLES = ''' + - name: Get scaleset info + azure_rm_resource_info: + resource_group: myResourceGroup + provider: compute + resource_type: virtualmachinescalesets + resource_name: myVmss + api_version: "2017-12-01" + + - name: Query all the resources in the resource group + azure_rm_resource_info: + resource_group: "{{ resource_group }}" + resource_type: resources +''' + +RETURN = ''' +response: + description: + - Response specific to resource type. + returned: always + type: complex + contains: + id: + description: + - Id of the Azure resource. + type: str + returned: always + sample: "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Compute/virtualMachines/myVM" + location: + description: + - Resource location. + type: str + returned: always + sample: eastus + name: + description: + - Resource name. + type: str + returned: always + sample: myVM + properties: + description: + - Specifies the virtual machine's property. + type: complex + returned: always + contains: + diagnosticsProfile: + description: + - Specifies the boot diagnostic settings state. + type: complex + returned: always + contains: + bootDiagnostics: + description: + - A debugging feature, which to view Console Output and Screenshot to diagnose VM status. + type: dict + returned: always + sample: { + "enabled": true, + "storageUri": "https://vxisurgdiag.blob.core.windows.net/" + } + hardwareProfile: + description: + - Specifies the hardware settings for the virtual machine. + type: dict + returned: always + sample: { + "vmSize": "Standard_D2s_v3" + } + networkProfile: + description: + - Specifies the network interfaces of the virtual machine. + type: complex + returned: always + contains: + networkInterfaces: + description: + - Describes a network interface reference. + type: list + returned: always + sample: + - { + "id": "/subscriptions/xxxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Network/networkInterfaces/myvm441" + } + osProfile: + description: + - Specifies the operating system settings for the virtual machine. + type: complex + returned: always + contains: + adminUsername: + description: + - Specifies the name of the administrator account. + type: str + returned: always + sample: azureuser + allowExtensionOperations: + description: + - Specifies whether extension operations should be allowed on the virtual machine. 
+ - This may only be set to False when no extensions are present on the virtual machine. + type: bool + returned: always + sample: true + computerName: + description: + - Specifies the host OS name of the virtual machine. + type: str + returned: always + sample: myVM + requireGuestProvisionSignale: + description: + - Specifies the host require guest provision signal or not. + type: bool + returned: always + sample: true + secrets: + description: + - Specifies set of certificates that should be installed onto the virtual machine. + type: list + returned: always + sample: [] + linuxConfiguration: + description: + - Specifies the Linux operating system settings on the virtual machine. + type: dict + returned: when OS type is Linux + sample: { + "disablePasswordAuthentication": false, + "provisionVMAgent": true + } + provisioningState: + description: + - The provisioning state. + type: str + returned: always + sample: Succeeded + vmID: + description: + - Specifies the VM unique ID which is a 128-bits identifier that is encoded and stored in all Azure laaS VMs SMBIOS. + - It can be read using platform BIOS commands. + type: str + returned: always + sample: "eb86d9bb-6725-4787-a487-2e497d5b340c" + storageProfile: + description: + - Specifies the storage account type for the managed disk. + type: complex + returned: always + contains: + dataDisks: + description: + - Specifies the parameters that are used to add a data disk to virtual machine. + type: list + returned: always + sample: + - { + "caching": "None", + "createOption": "Attach", + "diskSizeGB": 1023, + "lun": 2, + "managedDisk": { + "id": "/subscriptions/xxxx....xxxx/resourceGroups/V-XISURG/providers/Microsoft.Compute/disks/testdisk2", + "storageAccountType": "StandardSSD_LRS" + }, + "name": "testdisk2" + } + - { + "caching": "None", + "createOption": "Attach", + "diskSizeGB": 1023, + "lun": 1, + "managedDisk": { + "id": "/subscriptions/xxxx...xxxx/resourceGroups/V-XISURG/providers/Microsoft.Compute/disks/testdisk3", + "storageAccountType": "StandardSSD_LRS" + }, + "name": "testdisk3" + } + + imageReference: + description: + - Specifies information about the image to use. + type: dict + returned: always + sample: { + "offer": "UbuntuServer", + "publisher": "Canonical", + "sku": "18.04-LTS", + "version": "latest" + } + osDisk: + description: + - Specifies information about the operating system disk used by the virtual machine. + type: dict + returned: always + sample: { + "caching": "ReadWrite", + "createOption": "FromImage", + "diskSizeGB": 30, + "managedDisk": { + "id": "/subscriptions/xxx...xxxx/resourceGroups/v-xisuRG/providers/Microsoft.Compute/disks/myVM_disk1_xxx", + "storageAccountType": "Premium_LRS" + }, + "name": "myVM_disk1_xxx", + "osType": "Linux" + } + type: + description: + - The type of identity used for the virtual machine. 
+ type: str + returned: always + sample: "Microsoft.Compute/virtualMachines" +''' + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase +from ansible.module_utils.azure_rm_common_rest import GenericRestClient + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.service_client import ServiceClient + from msrestazure.tools import resource_id, is_valid_resource_id + import json + +except ImportError: + # This is handled in azure_rm_common + pass + + +class AzureRMResourceInfo(AzureRMModuleBase): + def __init__(self): + # define user inputs into argument + self.module_arg_spec = dict( + url=dict( + type='str' + ), + provider=dict( + type='str' + ), + resource_group=dict( + type='str' + ), + resource_type=dict( + type='str' + ), + resource_name=dict( + type='str' + ), + subresource=dict( + type='list', + default=[] + ), + api_version=dict( + type='str' + ) + ) + # store the results of the module operation + self.results = dict( + response=[] + ) + self.mgmt_client = None + self.url = None + self.api_version = None + self.provider = None + self.resource_group = None + self.resource_type = None + self.resource_name = None + self.subresource = [] + super(AzureRMResourceInfo, self).__init__(self.module_arg_spec, supports_tags=False) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_resource_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_resource_facts' module has been renamed to 'azure_rm_resource_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + self.mgmt_client = self.get_mgmt_svc_client(GenericRestClient, + base_url=self._cloud_environment.endpoints.resource_manager) + + if self.url is None: + orphan = None + rargs = dict() + rargs['subscription'] = self.subscription_id + rargs['resource_group'] = self.resource_group + if not (self.provider is None or self.provider.lower().startswith('.microsoft')): + rargs['namespace'] = "Microsoft." 
+ self.provider + else: + rargs['namespace'] = self.provider + + if self.resource_type is not None and self.resource_name is not None: + rargs['type'] = self.resource_type + rargs['name'] = self.resource_name + for i in range(len(self.subresource)): + resource_ns = self.subresource[i].get('namespace', None) + resource_type = self.subresource[i].get('type', None) + resource_name = self.subresource[i].get('name', None) + if resource_type is not None and resource_name is not None: + rargs['child_namespace_' + str(i + 1)] = resource_ns + rargs['child_type_' + str(i + 1)] = resource_type + rargs['child_name_' + str(i + 1)] = resource_name + else: + orphan = resource_type + else: + orphan = self.resource_type + + self.url = resource_id(**rargs) + + if orphan is not None: + self.url += '/' + orphan + + # if api_version was not specified, get latest one + if not self.api_version: + try: + # extract provider and resource type + if "/providers/" in self.url: + provider = self.url.split("/providers/")[1].split("/")[0] + resourceType = self.url.split(provider + "/")[1].split("/")[0] + url = "/subscriptions/" + self.subscription_id + "/providers/" + provider + api_versions = json.loads(self.mgmt_client.query(url, "GET", {'api-version': '2015-01-01'}, None, None, [200], 0, 0).text) + for rt in api_versions['resourceTypes']: + if rt['resourceType'].lower() == resourceType.lower(): + self.api_version = rt['apiVersions'][0] + break + else: + # if there's no provider in API version, assume Microsoft.Resources + self.api_version = '2018-05-01' + if not self.api_version: + self.fail("Couldn't find api version for {0}/{1}".format(provider, resourceType)) + except Exception as exc: + self.fail("Failed to obtain API version: {0}".format(str(exc))) + + self.results['url'] = self.url + + query_parameters = {} + query_parameters['api-version'] = self.api_version + + header_parameters = {} + header_parameters['Content-Type'] = 'application/json; charset=utf-8' + skiptoken = None + + while True: + if skiptoken: + query_parameters['skiptoken'] = skiptoken + response = self.mgmt_client.query(self.url, "GET", query_parameters, header_parameters, None, [200, 404], 0, 0) + try: + response = json.loads(response.text) + if isinstance(response, dict): + if response.get('value'): + self.results['response'] = self.results['response'] + response['value'] + skiptoken = response.get('nextLink') + else: + self.results['response'] = self.results['response'] + [response] + except Exception as e: + self.fail('Failed to parse response: ' + str(e)) + if not skiptoken: + break + return self.results + + +def main(): + AzureRMResourceInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_storageaccount.py b/test/support/integration/plugins/modules/azure_rm_storageaccount.py new file mode 100644 index 00000000..d4158bbd --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_storageaccount.py @@ -0,0 +1,684 @@ +#!/usr/bin/python +# +# Copyright (c) 2016 Matt Davis, <mdavis@ansible.com> +# Chris Houseknecht, <house@redhat.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_storageaccount +version_added: "2.1" +short_description: Manage Azure storage accounts +description: + - 
Create, update or delete a storage account. +options: + resource_group: + description: + - Name of the resource group to use. + required: true + aliases: + - resource_group_name + name: + description: + - Name of the storage account to update or create. + state: + description: + - State of the storage account. Use C(present) to create or update a storage account and use C(absent) to delete an account. + default: present + choices: + - absent + - present + location: + description: + - Valid Azure location. Defaults to location of the resource group. + account_type: + description: + - Type of storage account. Required when creating a storage account. + - C(Standard_ZRS) and C(Premium_LRS) accounts cannot be changed to other account types. + - Other account types cannot be changed to C(Standard_ZRS) or C(Premium_LRS). + choices: + - Premium_LRS + - Standard_GRS + - Standard_LRS + - StandardSSD_LRS + - Standard_RAGRS + - Standard_ZRS + - Premium_ZRS + aliases: + - type + custom_domain: + description: + - User domain assigned to the storage account. + - Must be a dictionary with I(name) and I(use_sub_domain) keys where I(name) is the CNAME source. + - Only one custom domain is supported per storage account at this time. + - To clear the existing custom domain, use an empty string for the custom domain name property. + - Can be added to an existing storage account. Will be ignored during storage account creation. + aliases: + - custom_dns_domain_suffix + kind: + description: + - The kind of storage. + default: 'Storage' + choices: + - Storage + - StorageV2 + - BlobStorage + version_added: "2.2" + access_tier: + description: + - The access tier for this storage account. Required when I(kind=BlobStorage). + choices: + - Hot + - Cool + version_added: "2.4" + force_delete_nonempty: + description: + - Attempt deletion if resource already exists and cannot be updated. + type: bool + aliases: + - force + https_only: + description: + - Allows https traffic only to storage service when set to C(true). + type: bool + version_added: "2.8" + blob_cors: + description: + - Specifies CORS rules for the Blob service. + - You can include up to five CorsRule elements in the request. + - If no blob_cors elements are included in the argument list, nothing about CORS will be changed. + - If you want to delete all CORS rules and disable CORS for the Blob service, explicitly set I(blob_cors=[]). + type: list + version_added: "2.8" + suboptions: + allowed_origins: + description: + - A list of origin domains that will be allowed via CORS, or "*" to allow all domains. + type: list + required: true + allowed_methods: + description: + - A list of HTTP methods that are allowed to be executed by the origin. + type: list + required: true + max_age_in_seconds: + description: + - The number of seconds that the client/browser should cache a preflight response. + type: int + required: true + exposed_headers: + description: + - A list of response headers to expose to CORS clients. + type: list + required: true + allowed_headers: + description: + - A list of headers allowed to be part of the cross-origin request. 
+ type: list + required: true + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Chris Houseknecht (@chouseknecht) + - Matt Davis (@nitzmahone) +''' + +EXAMPLES = ''' + - name: remove account, if it exists + azure_rm_storageaccount: + resource_group: myResourceGroup + name: clh0002 + state: absent + + - name: create an account + azure_rm_storageaccount: + resource_group: myResourceGroup + name: clh0002 + type: Standard_RAGRS + tags: + testing: testing + delete: on-exit + + - name: create an account with blob CORS + azure_rm_storageaccount: + resource_group: myResourceGroup + name: clh002 + type: Standard_RAGRS + blob_cors: + - allowed_origins: + - http://www.example.com/ + allowed_methods: + - GET + - POST + allowed_headers: + - x-ms-meta-data* + - x-ms-meta-target* + - x-ms-meta-abc + exposed_headers: + - x-ms-meta-* + max_age_in_seconds: 200 +''' + + +RETURN = ''' +state: + description: + - Current state of the storage account. + returned: always + type: complex + contains: + account_type: + description: + - Type of storage account. + returned: always + type: str + sample: Standard_RAGRS + custom_domain: + description: + - User domain assigned to the storage account. + returned: always + type: complex + contains: + name: + description: + - CNAME source. + returned: always + type: str + sample: testaccount + use_sub_domain: + description: + - Whether to use sub domain. + returned: always + type: bool + sample: true + id: + description: + - Resource ID. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Storage/storageAccounts/clh0003" + location: + description: + - Valid Azure location. Defaults to location of the resource group. + returned: always + type: str + sample: eastus2 + name: + description: + - Name of the storage account to update or create. + returned: always + type: str + sample: clh0003 + primary_endpoints: + description: + - The URLs to retrieve the public I(blob), I(queue), or I(table) object from the primary location. + returned: always + type: dict + sample: { + "blob": "https://clh0003.blob.core.windows.net/", + "queue": "https://clh0003.queue.core.windows.net/", + "table": "https://clh0003.table.core.windows.net/" + } + primary_location: + description: + - The location of the primary data center for the storage account. + returned: always + type: str + sample: eastus2 + provisioning_state: + description: + - The status of the storage account. + - Possible values include C(Creating), C(ResolvingDNS), C(Succeeded). + returned: always + type: str + sample: Succeeded + resource_group: + description: + - The resource group's name. + returned: always + type: str + sample: Testing + secondary_endpoints: + description: + - The URLs to retrieve the public I(blob), I(queue), or I(table) object from the secondary location. + returned: always + type: dict + sample: { + "blob": "https://clh0003-secondary.blob.core.windows.net/", + "queue": "https://clh0003-secondary.queue.core.windows.net/", + "table": "https://clh0003-secondary.table.core.windows.net/" + } + secondary_location: + description: + - The location of the geo-replicated secondary for the storage account. + returned: always + type: str + sample: centralus + status_of_primary: + description: + - The status of the primary location of the storage account; either C(available) or C(unavailable). 
+ returned: always + type: str + sample: available + status_of_secondary: + description: + - The status of the secondary location of the storage account; either C(available) or C(unavailable). + returned: always + type: str + sample: available + tags: + description: + - Resource tags. + returned: always + type: dict + sample: { 'tags1': 'value1' } + type: + description: + - The storage account type. + returned: always + type: str + sample: "Microsoft.Storage/storageAccounts" +''' + +try: + from msrestazure.azure_exceptions import CloudError + from azure.storage.cloudstorageaccount import CloudStorageAccount + from azure.common import AzureMissingResourceHttpError +except ImportError: + # This is handled in azure_rm_common + pass + +import copy +from ansible.module_utils.azure_rm_common import AZURE_SUCCESS_STATE, AzureRMModuleBase +from ansible.module_utils._text import to_native + +cors_rule_spec = dict( + allowed_origins=dict(type='list', elements='str', required=True), + allowed_methods=dict(type='list', elements='str', required=True), + max_age_in_seconds=dict(type='int', required=True), + exposed_headers=dict(type='list', elements='str', required=True), + allowed_headers=dict(type='list', elements='str', required=True), +) + + +def compare_cors(cors1, cors2): + if len(cors1) != len(cors2): + return False + copy2 = copy.copy(cors2) + for rule1 in cors1: + matched = False + for rule2 in copy2: + if (rule1['max_age_in_seconds'] == rule2['max_age_in_seconds'] + and set(rule1['allowed_methods']) == set(rule2['allowed_methods']) + and set(rule1['allowed_origins']) == set(rule2['allowed_origins']) + and set(rule1['allowed_headers']) == set(rule2['allowed_headers']) + and set(rule1['exposed_headers']) == set(rule2['exposed_headers'])): + matched = True + copy2.remove(rule2) + if not matched: + return False + return True + + +class AzureRMStorageAccount(AzureRMModuleBase): + + def __init__(self): + + self.module_arg_spec = dict( + account_type=dict(type='str', + choices=['Premium_LRS', 'Standard_GRS', 'Standard_LRS', 'StandardSSD_LRS', 'Standard_RAGRS', 'Standard_ZRS', 'Premium_ZRS'], + aliases=['type']), + custom_domain=dict(type='dict', aliases=['custom_dns_domain_suffix']), + location=dict(type='str'), + name=dict(type='str', required=True), + resource_group=dict(required=True, type='str', aliases=['resource_group_name']), + state=dict(default='present', choices=['present', 'absent']), + force_delete_nonempty=dict(type='bool', default=False, aliases=['force']), + tags=dict(type='dict'), + kind=dict(type='str', default='Storage', choices=['Storage', 'StorageV2', 'BlobStorage']), + access_tier=dict(type='str', choices=['Hot', 'Cool']), + https_only=dict(type='bool', default=False), + blob_cors=dict(type='list', options=cors_rule_spec, elements='dict') + ) + + self.results = dict( + changed=False, + state=dict() + ) + + self.account_dict = None + self.resource_group = None + self.name = None + self.state = None + self.location = None + self.account_type = None + self.custom_domain = None + self.tags = None + self.force_delete_nonempty = None + self.kind = None + self.access_tier = None + self.https_only = None + self.blob_cors = None + + super(AzureRMStorageAccount, self).__init__(self.module_arg_spec, + supports_check_mode=True) + + def exec_module(self, **kwargs): + + for key in list(self.module_arg_spec.keys()) + ['tags']: + setattr(self, key, kwargs[key]) + + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + # Set default location + self.location = 
resource_group.location + + if len(self.name) < 3 or len(self.name) > 24: + self.fail("Parameter error: name length must be between 3 and 24 characters.") + + if self.custom_domain: + if self.custom_domain.get('name', None) is None: + self.fail("Parameter error: expecting custom_domain to have a name attribute of type string.") + if self.custom_domain.get('use_sub_domain', None) is None: + self.fail("Parameter error: expecting custom_domain to have a use_sub_domain " + "attribute of type boolean.") + + self.account_dict = self.get_account() + + if self.state == 'present' and self.account_dict and \ + self.account_dict['provisioning_state'] != AZURE_SUCCESS_STATE: + self.fail("Error: storage account {0} has not completed provisioning. State is {1}. Expecting state " + "to be {2}.".format(self.name, self.account_dict['provisioning_state'], AZURE_SUCCESS_STATE)) + + if self.account_dict is not None: + self.results['state'] = self.account_dict + else: + self.results['state'] = dict() + + if self.state == 'present': + if not self.account_dict: + self.results['state'] = self.create_account() + else: + self.update_account() + elif self.state == 'absent' and self.account_dict: + self.delete_account() + self.results['state'] = dict(Status='Deleted') + + return self.results + + def check_name_availability(self): + self.log('Checking name availability for {0}'.format(self.name)) + try: + response = self.storage_client.storage_accounts.check_name_availability(self.name) + except CloudError as e: + self.log('Error attempting to validate name.') + self.fail("Error checking name availability: {0}".format(str(e))) + if not response.name_available: + self.log('Error name not available.') + self.fail("{0} - {1}".format(response.message, response.reason)) + + def get_account(self): + self.log('Get properties for account {0}'.format(self.name)) + account_obj = None + blob_service_props = None + account_dict = None + + try: + account_obj = self.storage_client.storage_accounts.get_properties(self.resource_group, self.name) + blob_service_props = self.storage_client.blob_services.get_service_properties(self.resource_group, self.name) + except CloudError: + pass + + if account_obj: + account_dict = self.account_obj_to_dict(account_obj, blob_service_props) + + return account_dict + + def account_obj_to_dict(self, account_obj, blob_service_props=None): + account_dict = dict( + id=account_obj.id, + name=account_obj.name, + location=account_obj.location, + resource_group=self.resource_group, + type=account_obj.type, + access_tier=(account_obj.access_tier.value + if account_obj.access_tier is not None else None), + sku_tier=account_obj.sku.tier.value, + sku_name=account_obj.sku.name.value, + provisioning_state=account_obj.provisioning_state.value, + secondary_location=account_obj.secondary_location, + status_of_primary=(account_obj.status_of_primary.value + if account_obj.status_of_primary is not None else None), + status_of_secondary=(account_obj.status_of_secondary.value + if account_obj.status_of_secondary is not None else None), + primary_location=account_obj.primary_location, + https_only=account_obj.enable_https_traffic_only + ) + account_dict['custom_domain'] = None + if account_obj.custom_domain: + account_dict['custom_domain'] = dict( + name=account_obj.custom_domain.name, + use_sub_domain=account_obj.custom_domain.use_sub_domain + ) + + account_dict['primary_endpoints'] = None + if account_obj.primary_endpoints: + account_dict['primary_endpoints'] = dict( + blob=account_obj.primary_endpoints.blob, + 
queue=account_obj.primary_endpoints.queue, + table=account_obj.primary_endpoints.table + ) + account_dict['secondary_endpoints'] = None + if account_obj.secondary_endpoints: + account_dict['secondary_endpoints'] = dict( + blob=account_obj.secondary_endpoints.blob, + queue=account_obj.secondary_endpoints.queue, + table=account_obj.secondary_endpoints.table + ) + account_dict['tags'] = None + if account_obj.tags: + account_dict['tags'] = account_obj.tags + if blob_service_props and blob_service_props.cors and blob_service_props.cors.cors_rules: + account_dict['blob_cors'] = [dict( + allowed_origins=[to_native(y) for y in x.allowed_origins], + allowed_methods=[to_native(y) for y in x.allowed_methods], + max_age_in_seconds=x.max_age_in_seconds, + exposed_headers=[to_native(y) for y in x.exposed_headers], + allowed_headers=[to_native(y) for y in x.allowed_headers] + ) for x in blob_service_props.cors.cors_rules] + return account_dict + + def update_account(self): + self.log('Update storage account {0}'.format(self.name)) + if bool(self.https_only) != bool(self.account_dict.get('https_only')): + self.results['changed'] = True + self.account_dict['https_only'] = self.https_only + if not self.check_mode: + try: + parameters = self.storage_models.StorageAccountUpdateParameters(enable_https_traffic_only=self.https_only) + self.storage_client.storage_accounts.update(self.resource_group, + self.name, + parameters) + except Exception as exc: + self.fail("Failed to update account type: {0}".format(str(exc))) + + if self.account_type: + if self.account_type != self.account_dict['sku_name']: + # change the account type + SkuName = self.storage_models.SkuName + if self.account_dict['sku_name'] in [SkuName.premium_lrs, SkuName.standard_zrs]: + self.fail("Storage accounts of type {0} and {1} cannot be changed.".format( + SkuName.premium_lrs, SkuName.standard_zrs)) + if self.account_type in [SkuName.premium_lrs, SkuName.standard_zrs]: + self.fail("Storage account of type {0} cannot be changed to a type of {1} or {2}.".format( + self.account_dict['sku_name'], SkuName.premium_lrs, SkuName.standard_zrs)) + + self.results['changed'] = True + self.account_dict['sku_name'] = self.account_type + + if self.results['changed'] and not self.check_mode: + # Perform the update. The API only allows changing one attribute per call. 
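                    # The stored sku_name / sku_tier are rebuilt as SDK Sku / SkuTier models and
                    # sent in their own StorageAccountUpdateParameters call; custom_domain,
                    # access_tier and tags are each pushed in separate update calls further below.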
+ try: + self.log("sku_name: %s" % self.account_dict['sku_name']) + self.log("sku_tier: %s" % self.account_dict['sku_tier']) + sku = self.storage_models.Sku(name=SkuName(self.account_dict['sku_name'])) + sku.tier = self.storage_models.SkuTier(self.account_dict['sku_tier']) + parameters = self.storage_models.StorageAccountUpdateParameters(sku=sku) + self.storage_client.storage_accounts.update(self.resource_group, + self.name, + parameters) + except Exception as exc: + self.fail("Failed to update account type: {0}".format(str(exc))) + + if self.custom_domain: + if not self.account_dict['custom_domain'] or self.account_dict['custom_domain'] != self.custom_domain: + self.results['changed'] = True + self.account_dict['custom_domain'] = self.custom_domain + + if self.results['changed'] and not self.check_mode: + new_domain = self.storage_models.CustomDomain(name=self.custom_domain['name'], + use_sub_domain=self.custom_domain['use_sub_domain']) + parameters = self.storage_models.StorageAccountUpdateParameters(custom_domain=new_domain) + try: + self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters) + except Exception as exc: + self.fail("Failed to update custom domain: {0}".format(str(exc))) + + if self.access_tier: + if not self.account_dict['access_tier'] or self.account_dict['access_tier'] != self.access_tier: + self.results['changed'] = True + self.account_dict['access_tier'] = self.access_tier + + if self.results['changed'] and not self.check_mode: + parameters = self.storage_models.StorageAccountUpdateParameters(access_tier=self.access_tier) + try: + self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters) + except Exception as exc: + self.fail("Failed to update access tier: {0}".format(str(exc))) + + update_tags, self.account_dict['tags'] = self.update_tags(self.account_dict['tags']) + if update_tags: + self.results['changed'] = True + if not self.check_mode: + parameters = self.storage_models.StorageAccountUpdateParameters(tags=self.account_dict['tags']) + try: + self.storage_client.storage_accounts.update(self.resource_group, self.name, parameters) + except Exception as exc: + self.fail("Failed to update tags: {0}".format(str(exc))) + + if self.blob_cors and not compare_cors(self.account_dict.get('blob_cors', []), self.blob_cors): + self.results['changed'] = True + if not self.check_mode: + self.set_blob_cors() + + def create_account(self): + self.log("Creating account {0}".format(self.name)) + + if not self.location: + self.fail('Parameter error: location required when creating a storage account.') + + if not self.account_type: + self.fail('Parameter error: account_type required when creating a storage account.') + + if not self.access_tier and self.kind == 'BlobStorage': + self.fail('Parameter error: access_tier required when creating a storage account of type BlobStorage.') + + self.check_name_availability() + self.results['changed'] = True + + if self.check_mode: + account_dict = dict( + location=self.location, + account_type=self.account_type, + name=self.name, + resource_group=self.resource_group, + enable_https_traffic_only=self.https_only, + tags=dict() + ) + if self.tags: + account_dict['tags'] = self.tags + if self.blob_cors: + account_dict['blob_cors'] = self.blob_cors + return account_dict + sku = self.storage_models.Sku(name=self.storage_models.SkuName(self.account_type)) + sku.tier = self.storage_models.SkuTier.standard if 'Standard' in self.account_type else \ + self.storage_models.SkuTier.premium + parameters 
= self.storage_models.StorageAccountCreateParameters(sku=sku, + kind=self.kind, + location=self.location, + tags=self.tags, + access_tier=self.access_tier) + self.log(str(parameters)) + try: + poller = self.storage_client.storage_accounts.create(self.resource_group, self.name, parameters) + self.get_poller_result(poller) + except CloudError as e: + self.log('Error creating storage account.') + self.fail("Failed to create account: {0}".format(str(e))) + if self.blob_cors: + self.set_blob_cors() + # the poller doesn't actually return anything + return self.get_account() + + def delete_account(self): + if self.account_dict['provisioning_state'] == self.storage_models.ProvisioningState.succeeded.value and \ + not self.force_delete_nonempty and self.account_has_blob_containers(): + self.fail("Account contains blob containers. Is it in use? Use the force_delete_nonempty option to attempt deletion.") + + self.log('Delete storage account {0}'.format(self.name)) + self.results['changed'] = True + if not self.check_mode: + try: + status = self.storage_client.storage_accounts.delete(self.resource_group, self.name) + self.log("delete status: ") + self.log(str(status)) + except CloudError as e: + self.fail("Failed to delete the account: {0}".format(str(e))) + return True + + def account_has_blob_containers(self): + ''' + If there are blob containers, then there are likely VMs depending on this account and it should + not be deleted. + ''' + self.log('Checking for existing blob containers') + blob_service = self.get_blob_client(self.resource_group, self.name) + try: + response = blob_service.list_containers() + except AzureMissingResourceHttpError: + # No blob storage available? + return False + + if len(response.items) > 0: + return True + return False + + def set_blob_cors(self): + try: + cors_rules = self.storage_models.CorsRules(cors_rules=[self.storage_models.CorsRule(**x) for x in self.blob_cors]) + self.storage_client.blob_services.set_service_properties(self.resource_group, + self.name, + self.storage_models.BlobServiceProperties(cors=cors_rules)) + except Exception as exc: + self.fail("Failed to set CORS rules: {0}".format(str(exc))) + + +def main(): + AzureRMStorageAccount() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_webapp.py b/test/support/integration/plugins/modules/azure_rm_webapp.py new file mode 100644 index 00000000..4f185f45 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_webapp.py @@ -0,0 +1,1070 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_webapp +version_added: "2.7" +short_description: Manage Web App instances +description: + - Create, update and delete instance of Web App. + +options: + resource_group: + description: + - Name of the resource group to which the resource belongs. + required: True + name: + description: + - Unique name of the app to create or update. To create or update a deployment slot, use the {slot} parameter. + required: True + + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + + plan: + description: + - App service plan. 
Required for creation. + - Can be name of existing app service plan in same resource group as web app. + - Can be the resource ID of an existing app service plan. For example + /subscriptions/<subs_id>/resourceGroups/<resource_group>/providers/Microsoft.Web/serverFarms/<plan_name>. + - Can be a dict containing five parameters, defined below. + - C(name), name of app service plan. + - C(resource_group), resource group of the app service plan. + - C(sku), SKU of app service plan, allowed values listed on U(https://azure.microsoft.com/en-us/pricing/details/app-service/linux/). + - C(is_linux), whether or not the app service plan is Linux. defaults to C(False). + - C(number_of_workers), number of workers for app service plan. + + frameworks: + description: + - Set of run time framework settings. Each setting is a dictionary. + - See U(https://docs.microsoft.com/en-us/azure/app-service/app-service-web-overview) for more info. + suboptions: + name: + description: + - Name of the framework. + - Supported framework list for Windows web app and Linux web app is different. + - Windows web apps support C(java), C(net_framework), C(php), C(python), and C(node) from June 2018. + - Windows web apps support multiple framework at the same time. + - Linux web apps support C(java), C(ruby), C(php), C(dotnetcore), and C(node) from June 2018. + - Linux web apps support only one framework. + - Java framework is mutually exclusive with others. + choices: + - java + - net_framework + - php + - python + - ruby + - dotnetcore + - node + version: + description: + - Version of the framework. For Linux web app supported value, see U(https://aka.ms/linux-stacks) for more info. + - C(net_framework) supported value sample, C(v4.0) for .NET 4.6 and C(v3.0) for .NET 3.5. + - C(php) supported value sample, C(5.5), C(5.6), C(7.0). + - C(python) supported value sample, C(5.5), C(5.6), C(7.0). + - C(node) supported value sample, C(6.6), C(6.9). + - C(dotnetcore) supported value sample, C(1.0), C(1.1), C(1.2). + - C(ruby) supported value sample, C(2.3). + - C(java) supported value sample, C(1.9) for Windows web app. C(1.8) for Linux web app. + settings: + description: + - List of settings of the framework. + suboptions: + java_container: + description: + - Name of Java container. + - Supported only when I(frameworks=java). Sample values C(Tomcat), C(Jetty). + java_container_version: + description: + - Version of Java container. + - Supported only when I(frameworks=java). + - Sample values for C(Tomcat), C(8.0), C(8.5), C(9.0). For C(Jetty,), C(9.1), C(9.3). + + container_settings: + description: + - Web app container settings. + suboptions: + name: + description: + - Name of container, for example C(imagename:tag). + registry_server_url: + description: + - Container registry server URL, for example C(mydockerregistry.io). + registry_server_user: + description: + - The container registry server user name. + registry_server_password: + description: + - The container registry server password. + + scm_type: + description: + - Repository type of deployment source, for example C(LocalGit), C(GitHub). + - List of supported values maintained at U(https://docs.microsoft.com/en-us/rest/api/appservice/webapps/createorupdate#scmtype). + + deployment_source: + description: + - Deployment source for git. + suboptions: + url: + description: + - Repository url of deployment source. + + branch: + description: + - The branch name of the repository. + startup_file: + description: + - The web's startup file. + - Used only for Linux web apps. 
+ + client_affinity_enabled: + description: + - Whether or not to send session affinity cookies, which route client requests in the same session to the same instance. + type: bool + default: True + + https_only: + description: + - Configures web site to accept only https requests. + type: bool + + dns_registration: + description: + - Whether or not the web app hostname is registered with DNS on creation. Set to C(false) to register. + type: bool + + skip_custom_domain_verification: + description: + - Whether or not to skip verification of custom (non *.azurewebsites.net) domains associated with web app. Set to C(true) to skip. + type: bool + + ttl_in_seconds: + description: + - Time to live in seconds for web app default domain name. + + app_settings: + description: + - Configure web app application settings. Suboptions are in key value pair format. + + purge_app_settings: + description: + - Purge any existing application settings. Replace web app application settings with app_settings. + type: bool + + app_state: + description: + - Start/Stop/Restart the web app. + type: str + choices: + - started + - stopped + - restarted + default: started + + state: + description: + - State of the Web App. + - Use C(present) to create or update a Web App and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Yunge Zhu (@yungezz) + +''' + +EXAMPLES = ''' + - name: Create a windows web app with non-exist app service plan + azure_rm_webapp: + resource_group: myResourceGroup + name: myWinWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + is_linux: false + sku: S1 + + - name: Create a docker web app with some app settings, with docker image + azure_rm_webapp: + resource_group: myResourceGroup + name: myDockerWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + is_linux: true + sku: S1 + number_of_workers: 2 + app_settings: + testkey: testvalue + testkey2: testvalue2 + container_settings: + name: ansible/ansible:ubuntu1404 + + - name: Create a docker web app with private acr registry + azure_rm_webapp: + resource_group: myResourceGroup + name: myDockerWebapp + plan: myAppServicePlan + app_settings: + testkey: testvalue + container_settings: + name: ansible/ubuntu1404 + registry_server_url: myregistry.io + registry_server_user: user + registry_server_password: pass + + - name: Create a linux web app with Node 6.6 framework + azure_rm_webapp: + resource_group: myResourceGroup + name: myLinuxWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey: testvalue + frameworks: + - name: "node" + version: "6.6" + + - name: Create a windows web app with node, php + azure_rm_webapp: + resource_group: myResourceGroup + name: myWinWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey: testvalue + frameworks: + - name: "node" + version: 6.6 + - name: "php" + version: "7.0" + + - name: Create a stage deployment slot for an existing web app + azure_rm_webapp: + resource_group: myResourceGroup + name: myWebapp/slots/stage + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey:testvalue + + - name: Create a linux web app with java framework + azure_rm_webapp: + resource_group: myResourceGroup + name: myLinuxWebapp + plan: + resource_group: myAppServicePlan_rg + name: myAppServicePlan + app_settings: + testkey: testvalue + frameworks: + 
- name: "java" + version: "8" + settings: + java_container: "Tomcat" + java_container_version: "8.5" +''' + +RETURN = ''' +azure_webapp: + description: + - ID of current web app. + returned: always + type: str + sample: "/subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myWebApp" +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model + from azure.mgmt.web.models import ( + site_config, app_service_plan, Site, + AppServicePlan, SkuDescription, NameValuePair + ) +except ImportError: + # This is handled in azure_rm_common + pass + +container_settings_spec = dict( + name=dict(type='str', required=True), + registry_server_url=dict(type='str'), + registry_server_user=dict(type='str'), + registry_server_password=dict(type='str', no_log=True) +) + +deployment_source_spec = dict( + url=dict(type='str'), + branch=dict(type='str') +) + + +framework_settings_spec = dict( + java_container=dict(type='str', required=True), + java_container_version=dict(type='str', required=True) +) + + +framework_spec = dict( + name=dict( + type='str', + required=True, + choices=['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby']), + version=dict(type='str', required=True), + settings=dict(type='dict', options=framework_settings_spec) +) + + +def _normalize_sku(sku): + if sku is None: + return sku + + sku = sku.upper() + if sku == 'FREE': + return 'F1' + elif sku == 'SHARED': + return 'D1' + return sku + + +def get_sku_name(tier): + tier = tier.upper() + if tier == 'F1' or tier == "FREE": + return 'FREE' + elif tier == 'D1' or tier == "SHARED": + return 'SHARED' + elif tier in ['B1', 'B2', 'B3', 'BASIC']: + return 'BASIC' + elif tier in ['S1', 'S2', 'S3']: + return 'STANDARD' + elif tier in ['P1', 'P2', 'P3']: + return 'PREMIUM' + elif tier in ['P1V2', 'P2V2', 'P3V2']: + return 'PREMIUMV2' + else: + return None + + +def appserviceplan_to_dict(plan): + return dict( + id=plan.id, + name=plan.name, + kind=plan.kind, + location=plan.location, + reserved=plan.reserved, + is_linux=plan.reserved, + provisioning_state=plan.provisioning_state, + tags=plan.tags if plan.tags else None + ) + + +def webapp_to_dict(webapp): + return dict( + id=webapp.id, + name=webapp.name, + location=webapp.location, + client_cert_enabled=webapp.client_cert_enabled, + enabled=webapp.enabled, + reserved=webapp.reserved, + client_affinity_enabled=webapp.client_affinity_enabled, + server_farm_id=webapp.server_farm_id, + host_names_disabled=webapp.host_names_disabled, + https_only=webapp.https_only if hasattr(webapp, 'https_only') else None, + skip_custom_domain_verification=webapp.skip_custom_domain_verification if hasattr(webapp, 'skip_custom_domain_verification') else None, + ttl_in_seconds=webapp.ttl_in_seconds if hasattr(webapp, 'ttl_in_seconds') else None, + state=webapp.state, + tags=webapp.tags if webapp.tags else None + ) + + +class Actions: + CreateOrUpdate, UpdateAppSettings, Delete = range(3) + + +class AzureRMWebApps(AzureRMModuleBase): + """Configuration class for an Azure RM Web App resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + location=dict( + type='str' + ), + plan=dict( + type='raw' + ), + frameworks=dict( + type='list', + elements='dict', + 
options=framework_spec + ), + container_settings=dict( + type='dict', + options=container_settings_spec + ), + scm_type=dict( + type='str', + ), + deployment_source=dict( + type='dict', + options=deployment_source_spec + ), + startup_file=dict( + type='str' + ), + client_affinity_enabled=dict( + type='bool', + default=True + ), + dns_registration=dict( + type='bool' + ), + https_only=dict( + type='bool' + ), + skip_custom_domain_verification=dict( + type='bool' + ), + ttl_in_seconds=dict( + type='int' + ), + app_settings=dict( + type='dict' + ), + purge_app_settings=dict( + type='bool', + default=False + ), + app_state=dict( + type='str', + choices=['started', 'stopped', 'restarted'], + default='started' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + mutually_exclusive = [['container_settings', 'frameworks']] + + self.resource_group = None + self.name = None + self.location = None + + # update in create_or_update as parameters + self.client_affinity_enabled = True + self.dns_registration = None + self.skip_custom_domain_verification = None + self.ttl_in_seconds = None + self.https_only = None + + self.tags = None + + # site config, e.g app settings, ssl + self.site_config = dict() + self.app_settings = dict() + self.app_settings_strDic = None + + # app service plan + self.plan = None + + # siteSourceControl + self.deployment_source = dict() + + # site, used at level creation, or update. e.g windows/linux, client_affinity etc first level args + self.site = None + + # property for internal usage, not used for sdk + self.container_settings = None + + self.purge_app_settings = False + self.app_state = 'started' + + self.results = dict( + changed=False, + id=None, + ) + self.state = None + self.to_do = [] + + self.frameworks = None + + # set site_config value from kwargs + self.site_config_updatable_properties = ["net_framework_version", + "java_version", + "php_version", + "python_version", + "scm_type"] + + # updatable_properties + self.updatable_properties = ["client_affinity_enabled", + "force_dns_registration", + "https_only", + "skip_custom_domain_verification", + "ttl_in_seconds"] + + self.supported_linux_frameworks = ['ruby', 'php', 'dotnetcore', 'node', 'java'] + self.supported_windows_frameworks = ['net_framework', 'php', 'python', 'node', 'java'] + + super(AzureRMWebApps, self).__init__(derived_arg_spec=self.module_arg_spec, + mutually_exclusive=mutually_exclusive, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "scm_type": + self.site_config[key] = kwargs[key] + + old_response = None + response = None + to_be_updated = False + + # set location + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + self.location = resource_group.location + + # get existing web app + old_response = self.get_webapp() + + if old_response: + self.results['id'] = old_response['id'] + + if self.state == 'present': + if not self.plan and not old_response: + self.fail("Please specify plan for newly created web app.") + + if not self.plan: + self.plan = old_response['server_farm_id'] + + self.plan = self.parse_resource_to_dict(self.plan) + + # get app service plan + is_linux = False + old_plan = self.get_app_service_plan() + if old_plan: + is_linux = old_plan['reserved'] + else: + is_linux 
= self.plan['is_linux'] if 'is_linux' in self.plan else False + + if self.frameworks: + # java is mutually exclusive with other frameworks + if len(self.frameworks) > 1 and any(f['name'] == 'java' for f in self.frameworks): + self.fail('Java is mutually exclusive with other frameworks.') + + if is_linux: + if len(self.frameworks) != 1: + self.fail('Can specify one framework only for Linux web app.') + + if self.frameworks[0]['name'] not in self.supported_linux_frameworks: + self.fail('Unsupported framework {0} for Linux web app.'.format(self.frameworks[0]['name'])) + + self.site_config['linux_fx_version'] = (self.frameworks[0]['name'] + '|' + self.frameworks[0]['version']).upper() + + if self.frameworks[0]['name'] == 'java': + if self.frameworks[0]['version'] != '8': + self.fail("Linux web app only supports java 8.") + if self.frameworks[0]['settings'] and self.frameworks[0]['settings']['java_container'].lower() != 'tomcat': + self.fail("Linux web app only supports tomcat container.") + + if self.frameworks[0]['settings'] and self.frameworks[0]['settings']['java_container'].lower() == 'tomcat': + self.site_config['linux_fx_version'] = 'TOMCAT|' + self.frameworks[0]['settings']['java_container_version'] + '-jre8' + else: + self.site_config['linux_fx_version'] = 'JAVA|8-jre8' + else: + for fx in self.frameworks: + if fx.get('name') not in self.supported_windows_frameworks: + self.fail('Unsupported framework {0} for Windows web app.'.format(fx.get('name'))) + else: + self.site_config[fx.get('name') + '_version'] = fx.get('version') + + if 'settings' in fx and fx['settings'] is not None: + for key, value in fx['settings'].items(): + self.site_config[key] = value + + if not self.app_settings: + self.app_settings = dict() + + if self.container_settings: + linux_fx_version = 'DOCKER|' + + if self.container_settings.get('registry_server_url'): + self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url'] + + linux_fx_version += self.container_settings['registry_server_url'] + '/' + + linux_fx_version += self.container_settings['name'] + + self.site_config['linux_fx_version'] = linux_fx_version + + if self.container_settings.get('registry_server_user'): + self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings['registry_server_user'] + + if self.container_settings.get('registry_server_password'): + self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings['registry_server_password'] + + # init site + self.site = Site(location=self.location, site_config=self.site_config) + + if self.https_only is not None: + self.site.https_only = self.https_only + + if self.client_affinity_enabled: + self.site.client_affinity_enabled = self.client_affinity_enabled + + # check if the web app already present in the resource group + if not old_response: + self.log("Web App instance doesn't exist") + + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + self.site.tags = self.tags + + # service plan is required for creation + if not self.plan: + self.fail("Please specify app service plan in plan parameter.") + + if not old_plan: + # no existing service plan, create one + if (not self.plan.get('name') or not self.plan.get('sku')): + self.fail('Please specify name, is_linux, sku in plan') + + if 'location' not in self.plan: + plan_resource_group = self.get_resource_group(self.plan['resource_group']) + self.plan['location'] = plan_resource_group.location + + old_plan = self.create_app_service_plan() + + 
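                # Tie the new Site to the existing or freshly created App Service plan by its
                # resource id; for Linux plans the optional startup_file is then surfaced to
                # the service as app_command_line in site_config.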
self.site.server_farm_id = old_plan['id'] + + # if linux, setup startup_file + if old_plan['is_linux']: + if hasattr(self, 'startup_file'): + self.site_config['app_command_line'] = self.startup_file + + # set app setting + if self.app_settings: + app_settings = [] + for key in self.app_settings.keys(): + app_settings.append(NameValuePair(name=key, value=self.app_settings[key])) + + self.site_config['app_settings'] = app_settings + else: + # existing web app, do update + self.log("Web App instance already exists") + + self.log('Result: {0}'.format(old_response)) + + update_tags, self.site.tags = self.update_tags(old_response.get('tags', None)) + + if update_tags: + to_be_updated = True + + # check if root level property changed + if self.is_updatable_property_changed(old_response): + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + + # check if site_config changed + old_config = self.get_webapp_configuration() + + if self.is_site_config_changed(old_config): + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + + # check if linux_fx_version changed + if old_config.linux_fx_version != self.site_config.get('linux_fx_version', ''): + to_be_updated = True + self.to_do.append(Actions.CreateOrUpdate) + + self.app_settings_strDic = self.list_app_settings() + + # purge existing app_settings: + if self.purge_app_settings: + to_be_updated = True + self.app_settings_strDic = dict() + self.to_do.append(Actions.UpdateAppSettings) + + # check if app settings changed + if self.purge_app_settings or self.is_app_settings_changed(): + to_be_updated = True + self.to_do.append(Actions.UpdateAppSettings) + + if self.app_settings: + for key in self.app_settings.keys(): + self.app_settings_strDic[key] = self.app_settings[key] + + elif self.state == 'absent': + if old_response: + self.log("Delete Web App instance") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_webapp() + + self.log('Web App instance deleted') + + else: + self.fail("Web app {0} not exists.".format(self.name)) + + if to_be_updated: + self.log('Need to Create/Update web app') + self.results['changed'] = True + + if self.check_mode: + return self.results + + if Actions.CreateOrUpdate in self.to_do: + response = self.create_update_webapp() + + self.results['id'] = response['id'] + + if Actions.UpdateAppSettings in self.to_do: + update_response = self.update_app_settings() + self.results['id'] = update_response.id + + webapp = None + if old_response: + webapp = old_response + if response: + webapp = response + + if webapp: + if (webapp['state'] != 'Stopped' and self.app_state == 'stopped') or \ + (webapp['state'] != 'Running' and self.app_state == 'started') or \ + self.app_state == 'restarted': + + self.results['changed'] = True + if self.check_mode: + return self.results + + self.set_webapp_state(self.app_state) + + return self.results + + # compare existing web app with input, determine weather it's update operation + def is_updatable_property_changed(self, existing_webapp): + for property_name in self.updatable_properties: + if hasattr(self, property_name) and getattr(self, property_name) is not None and \ + getattr(self, property_name) != existing_webapp.get(property_name, None): + return True + + return False + + # compare xxx_version + def is_site_config_changed(self, existing_config): + for fx_version in self.site_config_updatable_properties: + if self.site_config.get(fx_version): + if not getattr(existing_config, fx_version) or \ + getattr(existing_config, 
fx_version).upper() != self.site_config.get(fx_version).upper(): + return True + + return False + + # comparing existing app setting with input, determine whether it's changed + def is_app_settings_changed(self): + if self.app_settings: + if self.app_settings_strDic: + for key in self.app_settings.keys(): + if self.app_settings[key] != self.app_settings_strDic.get(key, None): + return True + else: + return True + return False + + # comparing deployment source with input, determine wheather it's changed + def is_deployment_source_changed(self, existing_webapp): + if self.deployment_source: + if self.deployment_source.get('url') \ + and self.deployment_source['url'] != existing_webapp.get('site_source_control')['url']: + return True + + if self.deployment_source.get('branch') \ + and self.deployment_source['branch'] != existing_webapp.get('site_source_control')['branch']: + return True + + return False + + def create_update_webapp(self): + ''' + Creates or updates Web App with the specified configuration. + + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Creating / Updating the Web App instance {0}".format(self.name)) + + try: + skip_dns_registration = self.dns_registration + force_dns_registration = None if self.dns_registration is None else not self.dns_registration + + response = self.web_client.web_apps.create_or_update(resource_group_name=self.resource_group, + name=self.name, + site_envelope=self.site, + skip_dns_registration=skip_dns_registration, + skip_custom_domain_verification=self.skip_custom_domain_verification, + force_dns_registration=force_dns_registration, + ttl_in_seconds=self.ttl_in_seconds) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the Web App instance.') + self.fail( + "Error creating the Web App instance: {0}".format(str(exc))) + return webapp_to_dict(response) + + def delete_webapp(self): + ''' + Deletes specified Web App instance in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the Web App instance {0}".format(self.name)) + try: + response = self.web_client.web_apps.delete(resource_group_name=self.resource_group, + name=self.name) + except CloudError as e: + self.log('Error attempting to delete the Web App instance.') + self.fail( + "Error deleting the Web App instance: {0}".format(str(e))) + + return True + + def get_webapp(self): + ''' + Gets the properties of the specified Web App. 
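+ Returns False when the web app cannot be found, so the caller can treat absence as a create.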
+ + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Checking if the Web App instance {0} is present".format(self.name)) + + response = None + + try: + response = self.web_client.web_apps.get(resource_group_name=self.resource_group, + name=self.name) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("Web App instance : {0} found".format(response.name)) + return webapp_to_dict(response) + + except CloudError as ex: + pass + + self.log("Didn't find web app {0} in resource group {1}".format( + self.name, self.resource_group)) + + return False + + def get_app_service_plan(self): + ''' + Gets app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Get App Service Plan {0}".format(self.plan['name'])) + + try: + response = self.web_client.app_service_plans.get( + resource_group_name=self.plan['resource_group'], + name=self.plan['name']) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("App Service Plan : {0} found".format(response.name)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + pass + + self.log("Didn't find app service plan {0} in resource group {1}".format( + self.plan['name'], self.plan['resource_group'])) + + return False + + def create_app_service_plan(self): + ''' + Creates app service plan + :return: deserialized app service plan dictionary + ''' + self.log("Create App Service Plan {0}".format(self.plan['name'])) + + try: + # normalize sku + sku = _normalize_sku(self.plan['sku']) + + sku_def = SkuDescription(tier=get_sku_name( + sku), name=sku, capacity=(self.plan.get('number_of_workers', None))) + plan_def = AppServicePlan( + location=self.plan['location'], app_service_plan_name=self.plan['name'], sku=sku_def, reserved=(self.plan.get('is_linux', None))) + + poller = self.web_client.app_service_plans.create_or_update( + self.plan['resource_group'], self.plan['name'], plan_def) + + if isinstance(poller, LROPoller): + response = self.get_poller_result(poller) + + self.log("Response : {0}".format(response)) + + return appserviceplan_to_dict(response) + except CloudError as ex: + self.fail("Failed to create app service plan {0} in resource group {1}: {2}".format( + self.plan['name'], self.plan['resource_group'], str(ex))) + + def list_app_settings(self): + ''' + List application settings + :return: deserialized list response + ''' + self.log("List application setting") + + try: + + response = self.web_client.web_apps.list_application_settings( + resource_group_name=self.resource_group, name=self.name) + self.log("Response : {0}".format(response)) + + return response.properties + except CloudError as ex: + self.fail("Failed to list application settings for web app {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def update_app_settings(self): + ''' + Update application settings + :return: deserialized updating response + ''' + self.log("Update application setting") + + try: + response = self.web_client.web_apps.update_application_settings( + resource_group_name=self.resource_group, name=self.name, properties=self.app_settings_strDic) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to update application settings for web app {0} in 
resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def create_or_update_source_control(self): + ''' + Update site source control + :return: deserialized updating response + ''' + self.log("Update site source control") + + if self.deployment_source is None: + return False + + self.deployment_source['is_manual_integration'] = False + self.deployment_source['is_mercurial'] = False + + try: + response = self.web_client.web_client.create_or_update_source_control( + self.resource_group, self.name, self.deployment_source) + self.log("Response : {0}".format(response)) + + return response.as_dict() + except CloudError as ex: + self.fail("Failed to update site source control for web app {0} in resource group {1}".format( + self.name, self.resource_group)) + + def get_webapp_configuration(self): + ''' + Get web app configuration + :return: deserialized web app configuration response + ''' + self.log("Get web app configuration") + + try: + + response = self.web_client.web_apps.get_configuration( + resource_group_name=self.resource_group, name=self.name) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.log("Failed to get configuration for web app {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + return False + + def set_webapp_state(self, appstate): + ''' + Start/stop/restart web app + :return: deserialized updating response + ''' + try: + if appstate == 'started': + response = self.web_client.web_apps.start(resource_group_name=self.resource_group, name=self.name) + elif appstate == 'stopped': + response = self.web_client.web_apps.stop(resource_group_name=self.resource_group, name=self.name) + elif appstate == 'restarted': + response = self.web_client.web_apps.restart(resource_group_name=self.resource_group, name=self.name) + else: + self.fail("Invalid web app state {0}".format(appstate)) + + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.log("Failed to {0} web app {1} in resource group {2}, request_id {3} - {4}".format( + appstate, self.name, self.resource_group, request_id, str(ex))) + + +def main(): + """Main execution""" + AzureRMWebApps() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_webapp_info.py b/test/support/integration/plugins/modules/azure_rm_webapp_info.py new file mode 100644 index 00000000..22286803 --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_webapp_info.py @@ -0,0 +1,489 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_webapp_info + +version_added: "2.9" + +short_description: Get Azure web app facts + +description: + - Get facts for a specific web app or all web app in a resource group, or all web app in current subscription. + +options: + name: + description: + - Only show results for a specific web app. + resource_group: + description: + - Limit results by resource group. + return_publish_profile: + description: + - Indicate whether to return publishing profile of the web app. 
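+ - When set, C(publishing_username) and C(publishing_password) are included for each returned web app.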
+ default: False + type: bool + tags: + description: + - Limit results by providing a list of tags. Format tags as 'key' or 'key:value'. + +extends_documentation_fragment: + - azure + +author: + - Yunge Zhu (@yungezz) +''' + +EXAMPLES = ''' + - name: Get facts for web app by name + azure_rm_webapp_info: + resource_group: myResourceGroup + name: winwebapp1 + + - name: Get facts for web apps in resource group + azure_rm_webapp_info: + resource_group: myResourceGroup + + - name: Get facts for web apps with tags + azure_rm_webapp_info: + tags: + - testtag + - foo:bar +''' + +RETURN = ''' +webapps: + description: + - List of web apps. + returned: always + type: complex + contains: + id: + description: + - ID of the web app. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/myWebApp + name: + description: + - Name of the web app. + returned: always + type: str + sample: winwebapp1 + resource_group: + description: + - Resource group of the web app. + returned: always + type: str + sample: myResourceGroup + location: + description: + - Location of the web app. + returned: always + type: str + sample: eastus + plan: + description: + - ID of app service plan used by the web app. + returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/serverfarms/myAppServicePlan + app_settings: + description: + - App settings of the application. Only returned when web app has app settings. + returned: always + type: dict + sample: { + "testkey": "testvalue", + "testkey2": "testvalue2" + } + frameworks: + description: + - Frameworks of the application. Only returned when web app has frameworks. + returned: always + type: list + sample: [ + { + "name": "net_framework", + "version": "v4.0" + }, + { + "name": "java", + "settings": { + "java_container": "tomcat", + "java_container_version": "8.5" + }, + "version": "1.7" + }, + { + "name": "php", + "version": "5.6" + } + ] + availability_state: + description: + - Availability of this web app. + returned: always + type: str + sample: Normal + default_host_name: + description: + - Host name of the web app. + returned: always + type: str + sample: vxxisurg397winapp4.azurewebsites.net + enabled: + description: + - Indicates the web app enabled or not. + returned: always + type: bool + sample: true + enabled_host_names: + description: + - Enabled host names of the web app. + returned: always + type: list + sample: [ + "vxxisurg397winapp4.azurewebsites.net", + "vxxisurg397winapp4.scm.azurewebsites.net" + ] + host_name_ssl_states: + description: + - SSL state per host names of the web app. + returned: always + type: list + sample: [ + { + "hostType": "Standard", + "name": "vxxisurg397winapp4.azurewebsites.net", + "sslState": "Disabled" + }, + { + "hostType": "Repository", + "name": "vxxisurg397winapp4.scm.azurewebsites.net", + "sslState": "Disabled" + } + ] + host_names: + description: + - Host names of the web app. + returned: always + type: list + sample: [ + "vxxisurg397winapp4.azurewebsites.net" + ] + outbound_ip_addresses: + description: + - Outbound IP address of the web app. + returned: always + type: str + sample: "40.71.11.131,40.85.166.200,168.62.166.67,137.135.126.248,137.135.121.45" + ftp_publish_url: + description: + - Publishing URL of the web app when deployment type is FTP. 
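+ - Parsed from the FTP entry of the web app's publish profile XML.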
+ returned: always + type: str + sample: ftp://xxxx.ftp.azurewebsites.windows.net + state: + description: + - State of the web app. + returned: always + type: str + sample: running + publishing_username: + description: + - Publishing profile user name. + returned: only when I(return_publish_profile=True). + type: str + sample: "$vxxisuRG397winapp4" + publishing_password: + description: + - Publishing profile password. + returned: only when I(return_publish_profile=True). + type: str + sample: "uvANsPQpGjWJmrFfm4Ssd5rpBSqGhjMk11pMSgW2vCsQtNx9tcgZ0xN26s9A" + tags: + description: + - Tags assigned to the resource. Dictionary of string:string pairs. + returned: always + type: dict + sample: { tag1: abc } +''' +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from azure.common import AzureMissingResourceHttpError, AzureHttpError +except Exception: + # This is handled in azure_rm_common + pass + +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +AZURE_OBJECT_CLASS = 'WebApp' + + +class AzureRMWebAppInfo(AzureRMModuleBase): + + def __init__(self): + + self.module_arg_spec = dict( + name=dict(type='str'), + resource_group=dict(type='str'), + tags=dict(type='list'), + return_publish_profile=dict(type='bool', default=False), + ) + + self.results = dict( + changed=False, + webapps=[], + ) + + self.name = None + self.resource_group = None + self.tags = None + self.return_publish_profile = False + + self.framework_names = ['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby'] + + super(AzureRMWebAppInfo, self).__init__(self.module_arg_spec, + supports_tags=False, + facts_module=True) + + def exec_module(self, **kwargs): + is_old_facts = self.module._name == 'azure_rm_webapp_facts' + if is_old_facts: + self.module.deprecate("The 'azure_rm_webapp_facts' module has been renamed to 'azure_rm_webapp_info'", + version='2.13', collection_name='ansible.builtin') + + for key in self.module_arg_spec: + setattr(self, key, kwargs[key]) + + if self.name: + self.results['webapps'] = self.list_by_name() + elif self.resource_group: + self.results['webapps'] = self.list_by_resource_group() + else: + self.results['webapps'] = self.list_all() + + return self.results + + def list_by_name(self): + self.log('Get web app {0}'.format(self.name)) + item = None + result = [] + + try: + item = self.web_client.web_apps.get(self.resource_group, self.name) + except CloudError: + pass + + if item and self.has_tags(item.tags, self.tags): + curated_result = self.get_curated_webapp(self.resource_group, self.name, item) + result = [curated_result] + + return result + + def list_by_resource_group(self): + self.log('List web apps in resource groups {0}'.format(self.resource_group)) + try: + response = list(self.web_client.web_apps.list_by_resource_group(self.resource_group)) + except CloudError as exc: + request_id = exc.request_id if exc.request_id else '' + self.fail("Error listing web apps in resource groups {0}, request id: {1} - {2}".format(self.resource_group, request_id, str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + curated_output = self.get_curated_webapp(self.resource_group, item.name, item) + results.append(curated_output) + return results + + def list_all(self): + self.log('List web apps in current subscription') + try: + response = list(self.web_client.web_apps.list()) + except CloudError as exc: + request_id = exc.request_id if exc.request_id else '' + self.fail("Error listing web apps, 
request id {0} - {1}".format(request_id, str(exc))) + + results = [] + for item in response: + if self.has_tags(item.tags, self.tags): + curated_output = self.get_curated_webapp(item.resource_group, item.name, item) + results.append(curated_output) + return results + + def list_webapp_configuration(self, resource_group, name): + self.log('Get web app {0} configuration'.format(name)) + + response = [] + + try: + response = self.web_client.web_apps.get_configuration(resource_group_name=resource_group, name=name) + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail('Error getting web app {0} configuration, request id {1} - {2}'.format(name, request_id, str(ex))) + + return response.as_dict() + + def list_webapp_appsettings(self, resource_group, name): + self.log('Get web app {0} app settings'.format(name)) + + response = [] + + try: + response = self.web_client.web_apps.list_application_settings(resource_group_name=resource_group, name=name) + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail('Error getting web app {0} app settings, request id {1} - {2}'.format(name, request_id, str(ex))) + + return response.as_dict() + + def get_publish_credentials(self, resource_group, name): + self.log('Get web app {0} publish credentials'.format(name)) + try: + poller = self.web_client.web_apps.list_publishing_credentials(resource_group, name) + if isinstance(poller, LROPoller): + response = self.get_poller_result(poller) + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail('Error getting web app {0} publishing credentials - {1}'.format(request_id, str(ex))) + return response + + def get_webapp_ftp_publish_url(self, resource_group, name): + import xmltodict + + self.log('Get web app {0} app publish profile'.format(name)) + + url = None + try: + content = self.web_client.web_apps.list_publishing_profile_xml_with_secrets(resource_group_name=resource_group, name=name) + if not content: + return url + + full_xml = '' + for f in content: + full_xml += f.decode() + profiles = xmltodict.parse(full_xml, xml_attribs=True)['publishData']['publishProfile'] + + if not profiles: + return url + + for profile in profiles: + if profile['@publishMethod'] == 'FTP': + url = profile['@publishUrl'] + + except CloudError as ex: + self.fail('Error getting web app {0} app settings'.format(name)) + + return url + + def get_curated_webapp(self, resource_group, name, webapp): + pip = self.serialize_obj(webapp, AZURE_OBJECT_CLASS) + + try: + site_config = self.list_webapp_configuration(resource_group, name) + app_settings = self.list_webapp_appsettings(resource_group, name) + publish_cred = self.get_publish_credentials(resource_group, name) + ftp_publish_url = self.get_webapp_ftp_publish_url(resource_group, name) + except CloudError as ex: + pass + return self.construct_curated_webapp(webapp=pip, + configuration=site_config, + app_settings=app_settings, + deployment_slot=None, + ftp_publish_url=ftp_publish_url, + publish_credentials=publish_cred) + + def construct_curated_webapp(self, + webapp, + configuration=None, + app_settings=None, + deployment_slot=None, + ftp_publish_url=None, + publish_credentials=None): + curated_output = dict() + curated_output['id'] = webapp['id'] + curated_output['name'] = webapp['name'] + curated_output['resource_group'] = webapp['properties']['resourceGroup'] + curated_output['location'] = webapp['location'] + curated_output['plan'] = webapp['properties']['serverFarmId'] + 
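+ # curated_output ends up as a flat dict matching the keys documented under RETURN (id, name, resource_group, location, plan, ...)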
curated_output['tags'] = webapp.get('tags', None) + + # important properties from output. not match input arguments. + curated_output['app_state'] = webapp['properties']['state'] + curated_output['availability_state'] = webapp['properties']['availabilityState'] + curated_output['default_host_name'] = webapp['properties']['defaultHostName'] + curated_output['host_names'] = webapp['properties']['hostNames'] + curated_output['enabled'] = webapp['properties']['enabled'] + curated_output['enabled_host_names'] = webapp['properties']['enabledHostNames'] + curated_output['host_name_ssl_states'] = webapp['properties']['hostNameSslStates'] + curated_output['outbound_ip_addresses'] = webapp['properties']['outboundIpAddresses'] + + # curated site_config + if configuration: + curated_output['frameworks'] = [] + for fx_name in self.framework_names: + fx_version = configuration.get(fx_name + '_version', None) + if fx_version: + fx = { + 'name': fx_name, + 'version': fx_version + } + # java container setting + if fx_name == 'java': + if configuration['java_container'] and configuration['java_container_version']: + settings = { + 'java_container': configuration['java_container'].lower(), + 'java_container_version': configuration['java_container_version'] + } + fx['settings'] = settings + + curated_output['frameworks'].append(fx) + + # linux_fx_version + if configuration.get('linux_fx_version', None): + tmp = configuration.get('linux_fx_version').split("|") + if len(tmp) == 2: + curated_output['frameworks'].append({'name': tmp[0].lower(), 'version': tmp[1]}) + + # curated app_settings + if app_settings and app_settings.get('properties', None): + curated_output['app_settings'] = dict() + for item in app_settings['properties']: + curated_output['app_settings'][item] = app_settings['properties'][item] + + # curated deploymenet_slot + if deployment_slot: + curated_output['deployment_slot'] = deployment_slot + + # ftp_publish_url + if ftp_publish_url: + curated_output['ftp_publish_url'] = ftp_publish_url + + # curated publish credentials + if publish_credentials and self.return_publish_profile: + curated_output['publishing_username'] = publish_credentials.publishing_user_name + curated_output['publishing_password'] = publish_credentials.publishing_password + return curated_output + + +def main(): + AzureRMWebAppInfo() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/azure_rm_webappslot.py b/test/support/integration/plugins/modules/azure_rm_webappslot.py new file mode 100644 index 00000000..ddba710b --- /dev/null +++ b/test/support/integration/plugins/modules/azure_rm_webappslot.py @@ -0,0 +1,1058 @@ +#!/usr/bin/python +# +# Copyright (c) 2018 Yunge Zhu, <yungez@microsoft.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: azure_rm_webappslot +version_added: "2.8" +short_description: Manage Azure Web App slot +description: + - Create, update and delete Azure Web App slot. + +options: + resource_group: + description: + - Name of the resource group to which the resource belongs. + required: True + name: + description: + - Unique name of the deployment slot to create or update. + required: True + webapp_name: + description: + - Web app name which this deployment slot belongs to. 
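+ - The web app must already exist; otherwise the module fails.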
+ required: True + location: + description: + - Resource location. If not set, location from the resource group will be used as default. + configuration_source: + description: + - Source slot to clone configurations from when creating slot. Use webapp's name to refer to the production slot. + auto_swap_slot_name: + description: + - Used to configure target slot name to auto swap, or disable auto swap. + - Set it target slot name to auto swap. + - Set it to False to disable auto slot swap. + swap: + description: + - Swap deployment slots of a web app. + suboptions: + action: + description: + - Swap types. + - C(preview) is to apply target slot settings on source slot first. + - C(swap) is to complete swapping. + - C(reset) is to reset the swap. + choices: + - preview + - swap + - reset + default: preview + target_slot: + description: + - Name of target slot to swap. If set to None, then swap with production slot. + preserve_vnet: + description: + - C(True) to preserve virtual network to the slot during swap. Otherwise C(False). + type: bool + default: True + frameworks: + description: + - Set of run time framework settings. Each setting is a dictionary. + - See U(https://docs.microsoft.com/en-us/azure/app-service/app-service-web-overview) for more info. + suboptions: + name: + description: + - Name of the framework. + - Supported framework list for Windows web app and Linux web app is different. + - Windows web apps support C(java), C(net_framework), C(php), C(python), and C(node) from June 2018. + - Windows web apps support multiple framework at same time. + - Linux web apps support C(java), C(ruby), C(php), C(dotnetcore), and C(node) from June 2018. + - Linux web apps support only one framework. + - Java framework is mutually exclusive with others. + choices: + - java + - net_framework + - php + - python + - ruby + - dotnetcore + - node + version: + description: + - Version of the framework. For Linux web app supported value, see U(https://aka.ms/linux-stacks) for more info. + - C(net_framework) supported value sample, C(v4.0) for .NET 4.6 and C(v3.0) for .NET 3.5. + - C(php) supported value sample, C(5.5), C(5.6), C(7.0). + - C(python) supported value sample, C(5.5), C(5.6), C(7.0). + - C(node) supported value sample, C(6.6), C(6.9). + - C(dotnetcore) supported value sample, C(1.0), C(1.1), C(1.2). + - C(ruby) supported value sample, 2.3. + - C(java) supported value sample, C(1.9) for Windows web app. C(1.8) for Linux web app. + settings: + description: + - List of settings of the framework. + suboptions: + java_container: + description: + - Name of Java container. This is supported by specific framework C(java) onlys, for example C(Tomcat), C(Jetty). + java_container_version: + description: + - Version of Java container. This is supported by specific framework C(java) only. + - For C(Tomcat), for example C(8.0), C(8.5), C(9.0). For C(Jetty), for example C(9.1), C(9.3). + container_settings: + description: + - Web app slot container settings. + suboptions: + name: + description: + - Name of container, for example C(imagename:tag). + registry_server_url: + description: + - Container registry server URL, for example C(mydockerregistry.io). + registry_server_user: + description: + - The container registry server user name. + registry_server_password: + description: + - The container registry server password. + startup_file: + description: + - The slot startup file. + - This only applies for Linux web app slot. + app_settings: + description: + - Configure web app slot application settings. 
Suboptions are in key value pair format. + purge_app_settings: + description: + - Purge any existing application settings. Replace slot application settings with app_settings. + type: bool + deployment_source: + description: + - Deployment source for git. + suboptions: + url: + description: + - Repository URL of deployment source. + branch: + description: + - The branch name of the repository. + app_state: + description: + - Start/Stop/Restart the slot. + type: str + choices: + - started + - stopped + - restarted + default: started + state: + description: + - State of the Web App deployment slot. + - Use C(present) to create or update a slot and C(absent) to delete it. + default: present + choices: + - absent + - present + +extends_documentation_fragment: + - azure + - azure_tags + +author: + - Yunge Zhu(@yungezz) + +''' + +EXAMPLES = ''' + - name: Create a webapp slot + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + configuration_source: myJavaWebApp + app_settings: + testkey: testvalue + + - name: Swap the slot with production slot + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + swap: + action: swap + + - name: Stop the slot + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + app_state: stopped + + - name: Update a webapp slot app settings + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + app_settings: + testkey: testvalue2 + + - name: Update a webapp slot frameworks + azure_rm_webappslot: + resource_group: myResourceGroup + webapp_name: myJavaWebApp + name: stage + frameworks: + - name: "node" + version: "10.1" +''' + +RETURN = ''' +id: + description: + - ID of the current slot. 
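+ - Fully qualified resource ID of the deployment slot.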
+ returned: always + type: str + sample: /subscriptions/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx/resourceGroups/myResourceGroup/providers/Microsoft.Web/sites/testapp/slots/stage1 +''' + +import time +from ansible.module_utils.azure_rm_common import AzureRMModuleBase + +try: + from msrestazure.azure_exceptions import CloudError + from msrest.polling import LROPoller + from msrest.serialization import Model + from azure.mgmt.web.models import ( + site_config, app_service_plan, Site, + AppServicePlan, SkuDescription, NameValuePair + ) +except ImportError: + # This is handled in azure_rm_common + pass + +swap_spec = dict( + action=dict( + type='str', + choices=[ + 'preview', + 'swap', + 'reset' + ], + default='preview' + ), + target_slot=dict( + type='str' + ), + preserve_vnet=dict( + type='bool', + default=True + ) +) + +container_settings_spec = dict( + name=dict(type='str', required=True), + registry_server_url=dict(type='str'), + registry_server_user=dict(type='str'), + registry_server_password=dict(type='str', no_log=True) +) + +deployment_source_spec = dict( + url=dict(type='str'), + branch=dict(type='str') +) + + +framework_settings_spec = dict( + java_container=dict(type='str', required=True), + java_container_version=dict(type='str', required=True) +) + + +framework_spec = dict( + name=dict( + type='str', + required=True, + choices=['net_framework', 'java', 'php', 'node', 'python', 'dotnetcore', 'ruby']), + version=dict(type='str', required=True), + settings=dict(type='dict', options=framework_settings_spec) +) + + +def webapp_to_dict(webapp): + return dict( + id=webapp.id, + name=webapp.name, + location=webapp.location, + client_cert_enabled=webapp.client_cert_enabled, + enabled=webapp.enabled, + reserved=webapp.reserved, + client_affinity_enabled=webapp.client_affinity_enabled, + server_farm_id=webapp.server_farm_id, + host_names_disabled=webapp.host_names_disabled, + https_only=webapp.https_only if hasattr(webapp, 'https_only') else None, + skip_custom_domain_verification=webapp.skip_custom_domain_verification if hasattr(webapp, 'skip_custom_domain_verification') else None, + ttl_in_seconds=webapp.ttl_in_seconds if hasattr(webapp, 'ttl_in_seconds') else None, + state=webapp.state, + tags=webapp.tags if webapp.tags else None + ) + + +def slot_to_dict(slot): + return dict( + id=slot.id, + resource_group=slot.resource_group, + server_farm_id=slot.server_farm_id, + target_swap_slot=slot.target_swap_slot, + enabled_host_names=slot.enabled_host_names, + slot_swap_status=slot.slot_swap_status, + name=slot.name, + location=slot.location, + enabled=slot.enabled, + reserved=slot.reserved, + host_names_disabled=slot.host_names_disabled, + state=slot.state, + repository_site_name=slot.repository_site_name, + default_host_name=slot.default_host_name, + kind=slot.kind, + site_config=slot.site_config, + tags=slot.tags if slot.tags else None + ) + + +class Actions: + NoAction, CreateOrUpdate, UpdateAppSettings, Delete = range(4) + + +class AzureRMWebAppSlots(AzureRMModuleBase): + """Configuration class for an Azure RM Web App slot resource""" + + def __init__(self): + self.module_arg_spec = dict( + resource_group=dict( + type='str', + required=True + ), + name=dict( + type='str', + required=True + ), + webapp_name=dict( + type='str', + required=True + ), + location=dict( + type='str' + ), + configuration_source=dict( + type='str' + ), + auto_swap_slot_name=dict( + type='raw' + ), + swap=dict( + type='dict', + options=swap_spec + ), + frameworks=dict( + type='list', + elements='dict', + 
options=framework_spec + ), + container_settings=dict( + type='dict', + options=container_settings_spec + ), + deployment_source=dict( + type='dict', + options=deployment_source_spec + ), + startup_file=dict( + type='str' + ), + app_settings=dict( + type='dict' + ), + purge_app_settings=dict( + type='bool', + default=False + ), + app_state=dict( + type='str', + choices=['started', 'stopped', 'restarted'], + default='started' + ), + state=dict( + type='str', + default='present', + choices=['present', 'absent'] + ) + ) + + mutually_exclusive = [['container_settings', 'frameworks']] + + self.resource_group = None + self.name = None + self.webapp_name = None + self.location = None + + self.auto_swap_slot_name = None + self.swap = None + self.tags = None + self.startup_file = None + self.configuration_source = None + self.clone = False + + # site config, e.g app settings, ssl + self.site_config = dict() + self.app_settings = dict() + self.app_settings_strDic = None + + # siteSourceControl + self.deployment_source = dict() + + # site, used at level creation, or update. + self.site = None + + # property for internal usage, not used for sdk + self.container_settings = None + + self.purge_app_settings = False + self.app_state = 'started' + + self.results = dict( + changed=False, + id=None, + ) + self.state = None + self.to_do = Actions.NoAction + + self.frameworks = None + + # set site_config value from kwargs + self.site_config_updatable_frameworks = ["net_framework_version", + "java_version", + "php_version", + "python_version", + "linux_fx_version"] + + self.supported_linux_frameworks = ['ruby', 'php', 'dotnetcore', 'node', 'java'] + self.supported_windows_frameworks = ['net_framework', 'php', 'python', 'node', 'java'] + + super(AzureRMWebAppSlots, self).__init__(derived_arg_spec=self.module_arg_spec, + mutually_exclusive=mutually_exclusive, + supports_check_mode=True, + supports_tags=True) + + def exec_module(self, **kwargs): + """Main module execution method""" + + for key in list(self.module_arg_spec.keys()) + ['tags']: + if hasattr(self, key): + setattr(self, key, kwargs[key]) + elif kwargs[key] is not None: + if key == "scm_type": + self.site_config[key] = kwargs[key] + + old_response = None + response = None + to_be_updated = False + + # set location + resource_group = self.get_resource_group(self.resource_group) + if not self.location: + self.location = resource_group.location + + # get web app + webapp_response = self.get_webapp() + + if not webapp_response: + self.fail("Web app {0} does not exist in resource group {1}.".format(self.webapp_name, self.resource_group)) + + # get slot + old_response = self.get_slot() + + # set is_linux + is_linux = True if webapp_response['reserved'] else False + + if self.state == 'present': + if self.frameworks: + # java is mutually exclusive with other frameworks + if len(self.frameworks) > 1 and any(f['name'] == 'java' for f in self.frameworks): + self.fail('Java is mutually exclusive with other frameworks.') + + if is_linux: + if len(self.frameworks) != 1: + self.fail('Can specify one framework only for Linux web app.') + + if self.frameworks[0]['name'] not in self.supported_linux_frameworks: + self.fail('Unsupported framework {0} for Linux web app.'.format(self.frameworks[0]['name'])) + + self.site_config['linux_fx_version'] = (self.frameworks[0]['name'] + '|' + self.frameworks[0]['version']).upper() + + if self.frameworks[0]['name'] == 'java': + if self.frameworks[0]['version'] != '8': + self.fail("Linux web app only supports java 8.") + + if 
self.frameworks[0].get('settings', {}) and self.frameworks[0]['settings'].get('java_container', None) and \ + self.frameworks[0]['settings']['java_container'].lower() != 'tomcat': + self.fail("Linux web app only supports tomcat container.") + + if self.frameworks[0].get('settings', {}) and self.frameworks[0]['settings'].get('java_container', None) and \ + self.frameworks[0]['settings']['java_container'].lower() == 'tomcat': + self.site_config['linux_fx_version'] = 'TOMCAT|' + self.frameworks[0]['settings']['java_container_version'] + '-jre8' + else: + self.site_config['linux_fx_version'] = 'JAVA|8-jre8' + else: + for fx in self.frameworks: + if fx.get('name') not in self.supported_windows_frameworks: + self.fail('Unsupported framework {0} for Windows web app.'.format(fx.get('name'))) + else: + self.site_config[fx.get('name') + '_version'] = fx.get('version') + + if 'settings' in fx and fx['settings'] is not None: + for key, value in fx['settings'].items(): + self.site_config[key] = value + + if not self.app_settings: + self.app_settings = dict() + + if self.container_settings: + linux_fx_version = 'DOCKER|' + + if self.container_settings.get('registry_server_url'): + self.app_settings['DOCKER_REGISTRY_SERVER_URL'] = 'https://' + self.container_settings['registry_server_url'] + + linux_fx_version += self.container_settings['registry_server_url'] + '/' + + linux_fx_version += self.container_settings['name'] + + self.site_config['linux_fx_version'] = linux_fx_version + + if self.container_settings.get('registry_server_user'): + self.app_settings['DOCKER_REGISTRY_SERVER_USERNAME'] = self.container_settings['registry_server_user'] + + if self.container_settings.get('registry_server_password'): + self.app_settings['DOCKER_REGISTRY_SERVER_PASSWORD'] = self.container_settings['registry_server_password'] + + # set auto_swap_slot_name + if self.auto_swap_slot_name and isinstance(self.auto_swap_slot_name, str): + self.site_config['auto_swap_slot_name'] = self.auto_swap_slot_name + if self.auto_swap_slot_name is False: + self.site_config['auto_swap_slot_name'] = None + + # init site + self.site = Site(location=self.location, site_config=self.site_config) + + # check if the slot already present in the webapp + if not old_response: + self.log("Web App slot doesn't exist") + + to_be_updated = True + self.to_do = Actions.CreateOrUpdate + self.site.tags = self.tags + + # if linux, setup startup_file + if self.startup_file: + self.site_config['app_command_line'] = self.startup_file + + # set app setting + if self.app_settings: + app_settings = [] + for key in self.app_settings.keys(): + app_settings.append(NameValuePair(name=key, value=self.app_settings[key])) + + self.site_config['app_settings'] = app_settings + + # clone slot + if self.configuration_source: + self.clone = True + + else: + # existing slot, do update + self.log("Web App slot already exists") + + self.log('Result: {0}'.format(old_response)) + + update_tags, self.site.tags = self.update_tags(old_response.get('tags', None)) + + if update_tags: + to_be_updated = True + + # check if site_config changed + old_config = self.get_configuration_slot(self.name) + + if self.is_site_config_changed(old_config): + to_be_updated = True + self.to_do = Actions.CreateOrUpdate + + self.app_settings_strDic = self.list_app_settings_slot(self.name) + + # purge existing app_settings: + if self.purge_app_settings: + to_be_updated = True + self.to_do = Actions.UpdateAppSettings + self.app_settings_strDic = dict() + + # check if app settings changed + if 
self.purge_app_settings or self.is_app_settings_changed(): + to_be_updated = True + self.to_do = Actions.UpdateAppSettings + + if self.app_settings: + for key in self.app_settings.keys(): + self.app_settings_strDic[key] = self.app_settings[key] + + elif self.state == 'absent': + if old_response: + self.log("Delete Web App slot") + self.results['changed'] = True + + if self.check_mode: + return self.results + + self.delete_slot() + + self.log('Web App slot deleted') + + else: + self.log("Web app slot {0} not exists.".format(self.name)) + + if to_be_updated: + self.log('Need to Create/Update web app') + self.results['changed'] = True + + if self.check_mode: + return self.results + + if self.to_do == Actions.CreateOrUpdate: + response = self.create_update_slot() + + self.results['id'] = response['id'] + + if self.clone: + self.clone_slot() + + if self.to_do == Actions.UpdateAppSettings: + self.update_app_settings_slot() + + slot = None + if response: + slot = response + if old_response: + slot = old_response + + if slot: + if (slot['state'] != 'Stopped' and self.app_state == 'stopped') or \ + (slot['state'] != 'Running' and self.app_state == 'started') or \ + self.app_state == 'restarted': + + self.results['changed'] = True + if self.check_mode: + return self.results + + self.set_state_slot(self.app_state) + + if self.swap: + self.results['changed'] = True + if self.check_mode: + return self.results + + self.swap_slot() + + return self.results + + # compare site config + def is_site_config_changed(self, existing_config): + for fx_version in self.site_config_updatable_frameworks: + if self.site_config.get(fx_version): + if not getattr(existing_config, fx_version) or \ + getattr(existing_config, fx_version).upper() != self.site_config.get(fx_version).upper(): + return True + + if self.auto_swap_slot_name is False and existing_config.auto_swap_slot_name is not None: + return True + elif self.auto_swap_slot_name and self.auto_swap_slot_name != getattr(existing_config, 'auto_swap_slot_name', None): + return True + return False + + # comparing existing app setting with input, determine whether it's changed + def is_app_settings_changed(self): + if self.app_settings: + if len(self.app_settings_strDic) != len(self.app_settings): + return True + + if self.app_settings_strDic != self.app_settings: + return True + return False + + # comparing deployment source with input, determine whether it's changed + def is_deployment_source_changed(self, existing_webapp): + if self.deployment_source: + if self.deployment_source.get('url') \ + and self.deployment_source['url'] != existing_webapp.get('site_source_control')['url']: + return True + + if self.deployment_source.get('branch') \ + and self.deployment_source['branch'] != existing_webapp.get('site_source_control')['branch']: + return True + + return False + + def create_update_slot(self): + ''' + Creates or updates Web App slot with the specified configuration. 
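+ Waits on the SDK's long-running operation before converting the result to a dict.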
+ + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Creating / Updating the Web App slot {0}".format(self.name)) + + try: + response = self.web_client.web_apps.create_or_update_slot(resource_group_name=self.resource_group, + slot=self.name, + name=self.webapp_name, + site_envelope=self.site) + if isinstance(response, LROPoller): + response = self.get_poller_result(response) + + except CloudError as exc: + self.log('Error attempting to create the Web App slot instance.') + self.fail("Error creating the Web App slot: {0}".format(str(exc))) + return slot_to_dict(response) + + def delete_slot(self): + ''' + Deletes specified Web App slot in the specified subscription and resource group. + + :return: True + ''' + self.log("Deleting the Web App slot {0}".format(self.name)) + try: + response = self.web_client.web_apps.delete_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name) + except CloudError as e: + self.log('Error attempting to delete the Web App slot.') + self.fail( + "Error deleting the Web App slots: {0}".format(str(e))) + + return True + + def get_webapp(self): + ''' + Gets the properties of the specified Web App. + + :return: deserialized Web App instance state dictionary + ''' + self.log( + "Checking if the Web App instance {0} is present".format(self.webapp_name)) + + response = None + + try: + response = self.web_client.web_apps.get(resource_group_name=self.resource_group, + name=self.webapp_name) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("Web App instance : {0} found".format(response.name)) + return webapp_to_dict(response) + + except CloudError as ex: + pass + + self.log("Didn't find web app {0} in resource group {1}".format( + self.webapp_name, self.resource_group)) + + return False + + def get_slot(self): + ''' + Gets the properties of the specified Web App slot. 
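+ Returns False when the slot cannot be found.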
+ + :return: deserialized Web App slot state dictionary + ''' + self.log( + "Checking if the Web App slot {0} is present".format(self.name)) + + response = None + + try: + response = self.web_client.web_apps.get_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name) + + # Newer SDK versions (0.40.0+) seem to return None if it doesn't exist instead of raising CloudError + if response is not None: + self.log("Response : {0}".format(response)) + self.log("Web App slot: {0} found".format(response.name)) + return slot_to_dict(response) + + except CloudError as ex: + pass + + self.log("Does not find web app slot {0} in resource group {1}".format(self.name, self.resource_group)) + + return False + + def list_app_settings(self): + ''' + List webapp application settings + :return: deserialized list response + ''' + self.log("List webapp application setting") + + try: + + response = self.web_client.web_apps.list_application_settings( + resource_group_name=self.resource_group, name=self.webapp_name) + self.log("Response : {0}".format(response)) + + return response.properties + except CloudError as ex: + self.fail("Failed to list application settings for web app {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def list_app_settings_slot(self, slot_name): + ''' + List application settings + :return: deserialized list response + ''' + self.log("List application setting") + + try: + + response = self.web_client.web_apps.list_application_settings_slot( + resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name) + self.log("Response : {0}".format(response)) + + return response.properties + except CloudError as ex: + self.fail("Failed to list application settings for web app slot {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def update_app_settings_slot(self, slot_name=None, app_settings=None): + ''' + Update application settings + :return: deserialized updating response + ''' + self.log("Update application setting") + + if slot_name is None: + slot_name = self.name + if app_settings is None: + app_settings = self.app_settings_strDic + try: + response = self.web_client.web_apps.update_application_settings_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=slot_name, + kind=None, + properties=app_settings) + self.log("Response : {0}".format(response)) + + return response.as_dict() + except CloudError as ex: + self.fail("Failed to update application settings for web app slot {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + return response + + def create_or_update_source_control_slot(self): + ''' + Update site source control + :return: deserialized updating response + ''' + self.log("Update site source control") + + if self.deployment_source is None: + return False + + self.deployment_source['is_manual_integration'] = False + self.deployment_source['is_mercurial'] = False + + try: + response = self.web_client.web_client.create_or_update_source_control_slot( + resource_group_name=self.resource_group, + name=self.webapp_name, + site_source_control=self.deployment_source, + slot=self.name) + self.log("Response : {0}".format(response)) + + return response.as_dict() + except CloudError as ex: + self.fail("Failed to update site source control for web app slot {0} in resource group {1}: {2}".format( + self.name, self.resource_group, str(ex))) + + def get_configuration(self): + ''' + Get web app configuration + 
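+ Used when cloning settings from the production slot.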
:return: deserialized web app configuration response + ''' + self.log("Get web app configuration") + + try: + + response = self.web_client.web_apps.get_configuration( + resource_group_name=self.resource_group, name=self.webapp_name) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to get configuration for web app {0} in resource group {1}: {2}".format( + self.webapp_name, self.resource_group, str(ex))) + + def get_configuration_slot(self, slot_name): + ''' + Get slot configuration + :return: deserialized slot configuration response + ''' + self.log("Get web app slot configuration") + + try: + + response = self.web_client.web_apps.get_configuration_slot( + resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to get configuration for web app slot {0} in resource group {1}: {2}".format( + slot_name, self.resource_group, str(ex))) + + def update_configuration_slot(self, slot_name=None, site_config=None): + ''' + Update slot configuration + :return: deserialized slot configuration response + ''' + self.log("Update web app slot configuration") + + if slot_name is None: + slot_name = self.name + if site_config is None: + site_config = self.site_config + try: + + response = self.web_client.web_apps.update_configuration_slot( + resource_group_name=self.resource_group, name=self.webapp_name, slot=slot_name, site_config=site_config) + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to update configuration for web app slot {0} in resource group {1}: {2}".format( + slot_name, self.resource_group, str(ex))) + + def set_state_slot(self, appstate): + ''' + Start/stop/restart web app slot + :return: deserialized updating response + ''' + try: + if appstate == 'started': + response = self.web_client.web_apps.start_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name) + elif appstate == 'stopped': + response = self.web_client.web_apps.stop_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name) + elif appstate == 'restarted': + response = self.web_client.web_apps.restart_slot(resource_group_name=self.resource_group, name=self.webapp_name, slot=self.name) + else: + self.fail("Invalid web app slot state {0}".format(appstate)) + + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + request_id = ex.request_id if ex.request_id else '' + self.fail("Failed to {0} web app slot {1} in resource group {2}, request_id {3} - {4}".format( + appstate, self.name, self.resource_group, request_id, str(ex))) + + def swap_slot(self): + ''' + Swap slot + :return: deserialized response + ''' + self.log("Swap slot") + + try: + if self.swap['action'] == 'swap': + if self.swap['target_slot'] is None: + response = self.web_client.web_apps.swap_slot_with_production(resource_group_name=self.resource_group, + name=self.webapp_name, + target_slot=self.name, + preserve_vnet=self.swap['preserve_vnet']) + else: + response = self.web_client.web_apps.swap_slot_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name, + target_slot=self.swap['target_slot'], + preserve_vnet=self.swap['preserve_vnet']) + elif self.swap['action'] == 'preview': + if self.swap['target_slot'] is None: + response = 
self.web_client.web_apps.apply_slot_config_to_production(resource_group_name=self.resource_group, + name=self.webapp_name, + target_slot=self.name, + preserve_vnet=self.swap['preserve_vnet']) + else: + response = self.web_client.web_apps.apply_slot_configuration_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name, + target_slot=self.swap['target_slot'], + preserve_vnet=self.swap['preserve_vnet']) + elif self.swap['action'] == 'reset': + if self.swap['target_slot'] is None: + response = self.web_client.web_apps.reset_production_slot_config(resource_group_name=self.resource_group, + name=self.webapp_name) + else: + response = self.web_client.web_apps.reset_slot_configuration_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.swap['target_slot']) + response = self.web_client.web_apps.reset_slot_configuration_slot(resource_group_name=self.resource_group, + name=self.webapp_name, + slot=self.name) + + self.log("Response : {0}".format(response)) + + return response + except CloudError as ex: + self.fail("Failed to swap web app slot {0} in resource group {1}: {2}".format(self.name, self.resource_group, str(ex))) + + def clone_slot(self): + if self.configuration_source: + src_slot = None if self.configuration_source.lower() == self.webapp_name.lower() else self.configuration_source + + if src_slot is None: + site_config_clone_from = self.get_configuration() + else: + site_config_clone_from = self.get_configuration_slot(slot_name=src_slot) + + self.update_configuration_slot(site_config=site_config_clone_from) + + if src_slot is None: + app_setting_clone_from = self.list_app_settings() + else: + app_setting_clone_from = self.list_app_settings_slot(src_slot) + + if self.app_settings: + app_setting_clone_from.update(self.app_settings) + + self.update_app_settings_slot(app_settings=app_setting_clone_from) + + +def main(): + """Main execution""" + AzureRMWebAppSlots() + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/cloud_init_data_facts.py b/test/support/integration/plugins/modules/cloud_init_data_facts.py new file mode 100644 index 00000000..4f871b99 --- /dev/null +++ b/test/support/integration/plugins/modules/cloud_init_data_facts.py @@ -0,0 +1,134 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# (c) 2018, René Moser <mail@renemoser.net> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: cloud_init_data_facts +short_description: Retrieve facts of cloud-init. +description: + - Gathers facts by reading the status.json and result.json of cloud-init. +version_added: 2.6 +author: René Moser (@resmo) +options: + filter: + description: + - Filter facts + choices: [ status, result ] +notes: + - See http://cloudinit.readthedocs.io/ for more information about cloud-init. 
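+ - Facts are read from C(status.json) and C(result.json) under C(/var/lib/cloud/data/).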
+''' + +EXAMPLES = ''' +- name: Gather all facts of cloud init + cloud_init_data_facts: + register: result + +- debug: + var: result + +- name: Wait for cloud init to finish + cloud_init_data_facts: + filter: status + register: res + until: "res.cloud_init_data_facts.status.v1.stage is defined and not res.cloud_init_data_facts.status.v1.stage" + retries: 50 + delay: 5 +''' + +RETURN = ''' +--- +cloud_init_data_facts: + description: Facts of result and status. + returned: success + type: dict + sample: '{ + "status": { + "v1": { + "datasource": "DataSourceCloudStack", + "errors": [] + }, + "result": { + "v1": { + "datasource": "DataSourceCloudStack", + "init": { + "errors": [], + "finished": 1522066377.0185432, + "start": 1522066375.2648022 + }, + "init-local": { + "errors": [], + "finished": 1522066373.70919, + "start": 1522066373.4726632 + }, + "modules-config": { + "errors": [], + "finished": 1522066380.9097016, + "start": 1522066379.0011985 + }, + "modules-final": { + "errors": [], + "finished": 1522066383.56594, + "start": 1522066382.3449218 + }, + "stage": null + } + }' +''' + +import os + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_text + + +CLOUD_INIT_PATH = "/var/lib/cloud/data/" + + +def gather_cloud_init_data_facts(module): + res = { + 'cloud_init_data_facts': dict() + } + + for i in ['result', 'status']: + filter = module.params.get('filter') + if filter is None or filter == i: + res['cloud_init_data_facts'][i] = dict() + json_file = CLOUD_INIT_PATH + i + '.json' + + if os.path.exists(json_file): + f = open(json_file, 'rb') + contents = to_text(f.read(), errors='surrogate_or_strict') + f.close() + + if contents: + res['cloud_init_data_facts'][i] = module.from_json(contents) + return res + + +def main(): + module = AnsibleModule( + argument_spec=dict( + filter=dict(choices=['result', 'status']), + ), + supports_check_mode=True, + ) + + facts = gather_cloud_init_data_facts(module) + result = dict(changed=False, ansible_facts=facts, **facts) + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/cloudformation.py b/test/support/integration/plugins/modules/cloudformation.py new file mode 100644 index 00000000..cd031465 --- /dev/null +++ b/test/support/integration/plugins/modules/cloudformation.py @@ -0,0 +1,837 @@ +#!/usr/bin/python + +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: cloudformation +short_description: Create or delete an AWS CloudFormation stack +description: + - Launches or updates an AWS CloudFormation stack and waits for it complete. +notes: + - CloudFormation features change often, and this module tries to keep up. That means your botocore version should be fresh. + The version listed in the requirements is the oldest version that works with the module as a whole. + Some features may require recent versions, and we do not pinpoint a minimum version for each feature. + Instead of relying on the minimum version, keep botocore up to date. AWS is always releasing features and fixing bugs. +version_added: "1.1" +options: + stack_name: + description: + - Name of the CloudFormation stack. 
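+ - Stack names must be unique within an AWS account and region.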
+ required: true + type: str + disable_rollback: + description: + - If a stacks fails to form, rollback will remove the stack. + default: false + type: bool + on_create_failure: + description: + - Action to take upon failure of stack creation. Incompatible with the I(disable_rollback) option. + choices: + - DO_NOTHING + - ROLLBACK + - DELETE + version_added: "2.8" + type: str + create_timeout: + description: + - The amount of time (in minutes) that can pass before the stack status becomes CREATE_FAILED + version_added: "2.6" + type: int + template_parameters: + description: + - A list of hashes of all the template variables for the stack. The value can be a string or a dict. + - Dict can be used to set additional template parameter attributes like UsePreviousValue (see example). + default: {} + type: dict + state: + description: + - If I(state=present), stack will be created. + - If I(state=present) and if stack exists and template has changed, it will be updated. + - If I(state=absent), stack will be removed. + default: present + choices: [ present, absent ] + type: str + template: + description: + - The local path of the CloudFormation template. + - This must be the full path to the file, relative to the working directory. If using roles this may look + like C(roles/cloudformation/files/cloudformation-example.json). + - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url) + must be specified (but only one of them). + - If I(state=present), the stack does exist, and neither I(template), + I(template_body) nor I(template_url) are specified, the previous template will be reused. + type: path + notification_arns: + description: + - A comma separated list of Simple Notification Service (SNS) topic ARNs to publish stack related events. + version_added: "2.0" + type: str + stack_policy: + description: + - The path of the CloudFormation stack policy. A policy cannot be removed once placed, but it can be modified. + for instance, allow all updates U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/protect-stack-resources.html#d0e9051) + version_added: "1.9" + type: str + tags: + description: + - Dictionary of tags to associate with stack and its resources during stack creation. + - Can be updated later, updating tags removes previous entries. + version_added: "1.4" + type: dict + template_url: + description: + - Location of file containing the template body. The URL must point to a template (max size 307,200 bytes) located in an + S3 bucket in the same region as the stack. + - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url) + must be specified (but only one of them). + - If I(state=present), the stack does exist, and neither I(template), I(template_body) nor I(template_url) are specified, + the previous template will be reused. + version_added: "2.0" + type: str + create_changeset: + description: + - "If stack already exists create a changeset instead of directly applying changes. See the AWS Change Sets docs + U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-changesets.html)." + - "WARNING: if the stack does not exist, it will be created without changeset. If I(state=absent), the stack will be + deleted immediately with no changeset." + type: bool + default: false + version_added: "2.4" + changeset_name: + description: + - Name given to the changeset when creating a changeset. + - Only used when I(create_changeset=true). 
+ - By default a name prefixed with Ansible-STACKNAME is generated based on input parameters. + See the AWS Change Sets docs for more information + U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-updating-stacks-changesets.html) + version_added: "2.4" + type: str + template_format: + description: + - This parameter is ignored since Ansible 2.3 and will be removed in Ansible 2.14. + - Templates are now passed raw to CloudFormation regardless of format. + version_added: "2.0" + type: str + role_arn: + description: + - The role that AWS CloudFormation assumes to create the stack. See the AWS CloudFormation Service Role + docs U(https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-iam-servicerole.html) + version_added: "2.3" + type: str + termination_protection: + description: + - Enable or disable termination protection on the stack. Only works with botocore >= 1.7.18. + type: bool + version_added: "2.5" + template_body: + description: + - Template body. Use this to pass in the actual body of the CloudFormation template. + - If I(state=present) and the stack does not exist yet, either I(template), I(template_body) or I(template_url) + must be specified (but only one of them). + - If I(state=present), the stack does exist, and neither I(template), I(template_body) nor I(template_url) + are specified, the previous template will be reused. + version_added: "2.5" + type: str + events_limit: + description: + - Maximum number of CloudFormation events to fetch from a stack when creating or updating it. + default: 200 + version_added: "2.7" + type: int + backoff_delay: + description: + - Number of seconds to wait for the next retry. + default: 3 + version_added: "2.8" + type: int + required: False + backoff_max_delay: + description: + - Maximum amount of time to wait between retries. + default: 30 + version_added: "2.8" + type: int + required: False + backoff_retries: + description: + - Number of times to retry operation. + - AWS API throttling mechanism fails CloudFormation module so we have to retry a couple of times. + default: 10 + version_added: "2.8" + type: int + required: False + capabilities: + description: + - Specify capabilities that stack template contains. + - Valid values are C(CAPABILITY_IAM), C(CAPABILITY_NAMED_IAM) and C(CAPABILITY_AUTO_EXPAND). + type: list + elements: str + version_added: "2.8" + default: [ CAPABILITY_IAM, CAPABILITY_NAMED_IAM ] + +author: "James S. 
Martin (@jsmartin)" +extends_documentation_fragment: +- aws +- ec2 +requirements: [ boto3, botocore>=1.5.45 ] +''' + +EXAMPLES = ''' +- name: create a cloudformation stack + cloudformation: + stack_name: "ansible-cloudformation" + state: "present" + region: "us-east-1" + disable_rollback: true + template: "files/cloudformation-example.json" + template_parameters: + KeyName: "jmartin" + DiskType: "ephemeral" + InstanceType: "m1.small" + ClusterSize: 3 + tags: + Stack: "ansible-cloudformation" + +# Basic role example +- name: create a stack, specify role that cloudformation assumes + cloudformation: + stack_name: "ansible-cloudformation" + state: "present" + region: "us-east-1" + disable_rollback: true + template: "roles/cloudformation/files/cloudformation-example.json" + role_arn: 'arn:aws:iam::123456789012:role/cloudformation-iam-role' + +- name: delete a stack + cloudformation: + stack_name: "ansible-cloudformation-old" + state: "absent" + +# Create a stack, pass in template from a URL, disable rollback if stack creation fails, +# pass in some parameters to the template, provide tags for resources created +- name: create a stack, pass in the template via an URL + cloudformation: + stack_name: "ansible-cloudformation" + state: present + region: us-east-1 + disable_rollback: true + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + template_parameters: + KeyName: jmartin + DiskType: ephemeral + InstanceType: m1.small + ClusterSize: 3 + tags: + Stack: ansible-cloudformation + +# Create a stack, passing in template body using lookup of Jinja2 template, disable rollback if stack creation fails, +# pass in some parameters to the template, provide tags for resources created +- name: create a stack, pass in the template body via lookup template + cloudformation: + stack_name: "ansible-cloudformation" + state: present + region: us-east-1 + disable_rollback: true + template_body: "{{ lookup('template', 'cloudformation.j2') }}" + template_parameters: + KeyName: jmartin + DiskType: ephemeral + InstanceType: m1.small + ClusterSize: 3 + tags: + Stack: ansible-cloudformation + +# Pass a template parameter which uses CloudFormation's UsePreviousValue attribute +# When use_previous_value is set to True, the given value will be ignored and +# CloudFormation will use the value from a previously submitted template. +# If use_previous_value is set to False (default) the given value is used. +- cloudformation: + stack_name: "ansible-cloudformation" + state: "present" + region: "us-east-1" + template: "files/cloudformation-example.json" + template_parameters: + DBSnapshotIdentifier: + use_previous_value: True + value: arn:aws:rds:es-east-1:000000000000:snapshot:rds:my-db-snapshot + DBName: + use_previous_value: True + tags: + Stack: "ansible-cloudformation" + +# Enable termination protection on a stack. +# If the stack already exists, this will update its termination protection +- name: enable termination protection during stack creation + cloudformation: + stack_name: my_stack + state: present + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + termination_protection: yes + +# Configure TimeoutInMinutes before the stack status becomes CREATE_FAILED +# In this case, if disable_rollback is not set or is set to false, the stack will be rolled back. 
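+# create_timeout maps to the CloudFormation TimeoutInMinutes setting and, like disable_rollback
+# and on_create_failure, only takes effect when the stack is first created.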
+- name: enable termination protection during stack creation + cloudformation: + stack_name: my_stack + state: present + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + create_timeout: 5 + +# Configure rollback behaviour on the unsuccessful creation of a stack allowing +# CloudFormation to clean up, or do nothing in the event of an unsuccessful +# deployment +# In this case, if on_create_failure is set to "DELETE", it will clean up the stack if +# it fails to create +- name: create stack which will delete on creation failure + cloudformation: + stack_name: my_stack + state: present + template_url: https://s3.amazonaws.com/my-bucket/cloudformation.template + on_create_failure: DELETE +''' + +RETURN = ''' +events: + type: list + description: Most recent events in CloudFormation's event log. This may be from a previous run in some cases. + returned: always + sample: ["StackEvent AWS::CloudFormation::Stack stackname UPDATE_COMPLETE", "StackEvent AWS::CloudFormation::Stack stackname UPDATE_COMPLETE_CLEANUP_IN_PROGRESS"] +log: + description: Debugging logs. Useful when modifying or finding an error. + returned: always + type: list + sample: ["updating stack"] +change_set_id: + description: The ID of the stack change set if one was created + returned: I(state=present) and I(create_changeset=true) + type: str + sample: "arn:aws:cloudformation:us-east-1:012345678901:changeSet/Ansible-StackName-f4496805bd1b2be824d1e315c6884247ede41eb0" +stack_resources: + description: AWS stack resources and their status. List of dictionaries, one dict per resource. + returned: state == present + type: list + sample: [ + { + "last_updated_time": "2016-10-11T19:40:14.979000+00:00", + "logical_resource_id": "CFTestSg", + "physical_resource_id": "cloudformation2-CFTestSg-16UQ4CYQ57O9F", + "resource_type": "AWS::EC2::SecurityGroup", + "status": "UPDATE_COMPLETE", + "status_reason": null + } + ] +stack_outputs: + type: dict + description: A key:value dictionary of all the stack outputs currently defined. If there are no stack outputs, it is an empty dictionary. + returned: state == present + sample: {"MySg": "AnsibleModuleTestYAML-CFTestSg-C8UVS567B6NS"} +''' # NOQA + +import json +import time +import uuid +import traceback +from hashlib import sha1 + +try: + import boto3 + import botocore + HAS_BOTO3 = True +except ImportError: + HAS_BOTO3 = False + +from ansible.module_utils.ec2 import ansible_dict_to_boto3_tag_list, AWSRetry, boto3_conn, boto_exception, ec2_argument_spec, get_aws_connection_info +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_bytes, to_native + + +def get_stack_events(cfn, stack_name, events_limit, token_filter=None): + '''This event data was never correct, it worked as a side effect. So the v2.3 format is different.''' + ret = {'events': [], 'log': []} + + try: + pg = cfn.get_paginator( + 'describe_stack_events' + ).paginate( + StackName=stack_name, + PaginationConfig={'MaxItems': events_limit} + ) + if token_filter is not None: + events = list(pg.search( + "StackEvents[?ClientRequestToken == '{0}']".format(token_filter) + )) + else: + events = list(pg.search("StackEvents[*]")) + except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err: + error_msg = boto_exception(err) + if 'does not exist' in error_msg: + # missing stack, don't bail. 
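+            # (this commonly happens while polling a stack that is being deleted,
+            #  so log it and return the result as-is instead of failing)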
+ ret['log'].append('Stack does not exist.') + return ret + ret['log'].append('Unknown error: ' + str(error_msg)) + return ret + + for e in events: + eventline = 'StackEvent {ResourceType} {LogicalResourceId} {ResourceStatus}'.format(**e) + ret['events'].append(eventline) + + if e['ResourceStatus'].endswith('FAILED'): + failline = '{ResourceType} {LogicalResourceId} {ResourceStatus}: {ResourceStatusReason}'.format(**e) + ret['log'].append(failline) + + return ret + + +def create_stack(module, stack_params, cfn, events_limit): + if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params: + module.fail_json(msg="Either 'template', 'template_body' or 'template_url' is required when the stack does not exist.") + + # 'DisableRollback', 'TimeoutInMinutes', 'EnableTerminationProtection' and + # 'OnFailure' only apply on creation, not update. + if module.params.get('on_create_failure') is not None: + stack_params['OnFailure'] = module.params['on_create_failure'] + else: + stack_params['DisableRollback'] = module.params['disable_rollback'] + + if module.params.get('create_timeout') is not None: + stack_params['TimeoutInMinutes'] = module.params['create_timeout'] + if module.params.get('termination_protection') is not None: + if boto_supports_termination_protection(cfn): + stack_params['EnableTerminationProtection'] = bool(module.params.get('termination_protection')) + else: + module.fail_json(msg="termination_protection parameter requires botocore >= 1.7.18") + + try: + response = cfn.create_stack(**stack_params) + # Use stack ID to follow stack state in case of on_create_failure = DELETE + result = stack_operation(cfn, response['StackId'], 'CREATE', events_limit, stack_params.get('ClientRequestToken', None)) + except Exception as err: + error_msg = boto_exception(err) + module.fail_json(msg="Failed to create stack {0}: {1}.".format(stack_params.get('StackName'), error_msg), exception=traceback.format_exc()) + if not result: + module.fail_json(msg="empty result") + return result + + +def list_changesets(cfn, stack_name): + res = cfn.list_change_sets(StackName=stack_name) + return [cs['ChangeSetName'] for cs in res['Summaries']] + + +def create_changeset(module, stack_params, cfn, events_limit): + if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params: + module.fail_json(msg="Either 'template' or 'template_url' is required.") + if module.params['changeset_name'] is not None: + stack_params['ChangeSetName'] = module.params['changeset_name'] + + # changesets don't accept ClientRequestToken parameters + stack_params.pop('ClientRequestToken', None) + + try: + changeset_name = build_changeset_name(stack_params) + stack_params['ChangeSetName'] = changeset_name + + # Determine if this changeset already exists + pending_changesets = list_changesets(cfn, stack_params['StackName']) + if changeset_name in pending_changesets: + warning = 'WARNING: %d pending changeset(s) exist(s) for this stack!' % len(pending_changesets) + result = dict(changed=False, output='ChangeSet %s already exists.' 
% changeset_name, warnings=[warning]) + else: + cs = cfn.create_change_set(**stack_params) + # Make sure we don't enter an infinite loop + time_end = time.time() + 600 + while time.time() < time_end: + try: + newcs = cfn.describe_change_set(ChangeSetName=cs['Id']) + except botocore.exceptions.BotoCoreError as err: + error_msg = boto_exception(err) + module.fail_json(msg=error_msg) + if newcs['Status'] == 'CREATE_PENDING' or newcs['Status'] == 'CREATE_IN_PROGRESS': + time.sleep(1) + elif newcs['Status'] == 'FAILED' and "The submitted information didn't contain changes" in newcs['StatusReason']: + cfn.delete_change_set(ChangeSetName=cs['Id']) + result = dict(changed=False, + output='The created Change Set did not contain any changes to this stack and was deleted.') + # a failed change set does not trigger any stack events so we just want to + # skip any further processing of result and just return it directly + return result + else: + break + # Lets not hog the cpu/spam the AWS API + time.sleep(1) + result = stack_operation(cfn, stack_params['StackName'], 'CREATE_CHANGESET', events_limit) + result['change_set_id'] = cs['Id'] + result['warnings'] = ['Created changeset named %s for stack %s' % (changeset_name, stack_params['StackName']), + 'You can execute it using: aws cloudformation execute-change-set --change-set-name %s' % cs['Id'], + 'NOTE that dependencies on this stack might fail due to pending changes!'] + except Exception as err: + error_msg = boto_exception(err) + if 'No updates are to be performed.' in error_msg: + result = dict(changed=False, output='Stack is already up-to-date.') + else: + module.fail_json(msg="Failed to create change set: {0}".format(error_msg), exception=traceback.format_exc()) + + if not result: + module.fail_json(msg="empty result") + return result + + +def update_stack(module, stack_params, cfn, events_limit): + if 'TemplateBody' not in stack_params and 'TemplateURL' not in stack_params: + stack_params['UsePreviousTemplate'] = True + + # if the state is present and the stack already exists, we try to update it. + # AWS will tell us if the stack template and parameters are the same and + # don't need to be updated. + try: + cfn.update_stack(**stack_params) + result = stack_operation(cfn, stack_params['StackName'], 'UPDATE', events_limit, stack_params.get('ClientRequestToken', None)) + except Exception as err: + error_msg = boto_exception(err) + if 'No updates are to be performed.' 
in error_msg: + result = dict(changed=False, output='Stack is already up-to-date.') + else: + module.fail_json(msg="Failed to update stack {0}: {1}".format(stack_params.get('StackName'), error_msg), exception=traceback.format_exc()) + if not result: + module.fail_json(msg="empty result") + return result + + +def update_termination_protection(module, cfn, stack_name, desired_termination_protection_state): + '''updates termination protection of a stack''' + if not boto_supports_termination_protection(cfn): + module.fail_json(msg="termination_protection parameter requires botocore >= 1.7.18") + stack = get_stack_facts(cfn, stack_name) + if stack: + if stack['EnableTerminationProtection'] is not desired_termination_protection_state: + try: + cfn.update_termination_protection( + EnableTerminationProtection=desired_termination_protection_state, + StackName=stack_name) + except botocore.exceptions.ClientError as e: + module.fail_json(msg=boto_exception(e), exception=traceback.format_exc()) + + +def boto_supports_termination_protection(cfn): + '''termination protection was added in botocore 1.7.18''' + return hasattr(cfn, "update_termination_protection") + + +def stack_operation(cfn, stack_name, operation, events_limit, op_token=None): + '''gets the status of a stack while it is created/updated/deleted''' + existed = [] + while True: + try: + stack = get_stack_facts(cfn, stack_name) + existed.append('yes') + except Exception: + # If the stack previously existed, and now can't be found then it's + # been deleted successfully. + if 'yes' in existed or operation == 'DELETE': # stacks may delete fast, look in a few ways. + ret = get_stack_events(cfn, stack_name, events_limit, op_token) + ret.update({'changed': True, 'output': 'Stack Deleted'}) + return ret + else: + return {'changed': True, 'failed': True, 'output': 'Stack Not Found', 'exception': traceback.format_exc()} + ret = get_stack_events(cfn, stack_name, events_limit, op_token) + if not stack: + if 'yes' in existed or operation == 'DELETE': # stacks may delete fast, look in a few ways. + ret = get_stack_events(cfn, stack_name, events_limit, op_token) + ret.update({'changed': True, 'output': 'Stack Deleted'}) + return ret + else: + ret.update({'changed': False, 'failed': True, 'output': 'Stack not found.'}) + return ret + # it covers ROLLBACK_COMPLETE and UPDATE_ROLLBACK_COMPLETE + # Possible states: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-describing-stacks.html#w1ab2c15c17c21c13 + elif stack['StackStatus'].endswith('ROLLBACK_COMPLETE') and operation != 'CREATE_CHANGESET': + ret.update({'changed': True, 'failed': True, 'output': 'Problem with %s. Rollback complete' % operation}) + return ret + elif stack['StackStatus'] == 'DELETE_COMPLETE' and operation == 'CREATE': + ret.update({'changed': True, 'failed': True, 'output': 'Stack create failed. Delete complete.'}) + return ret + # note the ordering of ROLLBACK_COMPLETE, DELETE_COMPLETE, and COMPLETE, because otherwise COMPLETE will match all cases. + elif stack['StackStatus'].endswith('_COMPLETE'): + ret.update({'changed': True, 'output': 'Stack %s complete' % operation}) + return ret + elif stack['StackStatus'].endswith('_ROLLBACK_FAILED'): + ret.update({'changed': True, 'failed': True, 'output': 'Stack %s rollback failed' % operation}) + return ret + # note the ordering of ROLLBACK_FAILED and FAILED, because otherwise FAILED will match both cases. 
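+        # (for example, 'UPDATE_ROLLBACK_FAILED' ends with both '_ROLLBACK_FAILED' and '_FAILED',
+        #  and must be caught by the more specific branch above)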
+ elif stack['StackStatus'].endswith('_FAILED'): + ret.update({'changed': True, 'failed': True, 'output': 'Stack %s failed' % operation}) + return ret + else: + # this can loop forever :/ + time.sleep(5) + return {'failed': True, 'output': 'Failed for unknown reasons.'} + + +def build_changeset_name(stack_params): + if 'ChangeSetName' in stack_params: + return stack_params['ChangeSetName'] + + json_params = json.dumps(stack_params, sort_keys=True) + + return 'Ansible-{0}-{1}'.format( + stack_params['StackName'], + sha1(to_bytes(json_params, errors='surrogate_or_strict')).hexdigest() + ) + + +def check_mode_changeset(module, stack_params, cfn): + """Create a change set, describe it and delete it before returning check mode outputs.""" + stack_params['ChangeSetName'] = build_changeset_name(stack_params) + # changesets don't accept ClientRequestToken parameters + stack_params.pop('ClientRequestToken', None) + + try: + change_set = cfn.create_change_set(**stack_params) + for i in range(60): # total time 5 min + description = cfn.describe_change_set(ChangeSetName=change_set['Id']) + if description['Status'] in ('CREATE_COMPLETE', 'FAILED'): + break + time.sleep(5) + else: + # if the changeset doesn't finish in 5 mins, this `else` will trigger and fail + module.fail_json(msg="Failed to create change set %s" % stack_params['ChangeSetName']) + + cfn.delete_change_set(ChangeSetName=change_set['Id']) + + reason = description.get('StatusReason') + + if description['Status'] == 'FAILED' and "didn't contain changes" in description['StatusReason']: + return {'changed': False, 'msg': reason, 'meta': description['StatusReason']} + return {'changed': True, 'msg': reason, 'meta': description['Changes']} + + except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err: + error_msg = boto_exception(err) + module.fail_json(msg=error_msg, exception=traceback.format_exc()) + + +def get_stack_facts(cfn, stack_name): + try: + stack_response = cfn.describe_stacks(StackName=stack_name) + stack_info = stack_response['Stacks'][0] + except (botocore.exceptions.ValidationError, botocore.exceptions.ClientError) as err: + error_msg = boto_exception(err) + if 'does not exist' in error_msg: + # missing stack, don't bail. + return None + + # other error, bail. 
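+        # (re-raise anything other than a missing stack so the caller can surface the error)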
+ raise err + + if stack_response and stack_response.get('Stacks', None): + stacks = stack_response['Stacks'] + if len(stacks): + stack_info = stacks[0] + + return stack_info + + +def main(): + argument_spec = ec2_argument_spec() + argument_spec.update(dict( + stack_name=dict(required=True), + template_parameters=dict(required=False, type='dict', default={}), + state=dict(default='present', choices=['present', 'absent']), + template=dict(default=None, required=False, type='path'), + notification_arns=dict(default=None, required=False), + stack_policy=dict(default=None, required=False), + disable_rollback=dict(default=False, type='bool'), + on_create_failure=dict(default=None, required=False, choices=['DO_NOTHING', 'ROLLBACK', 'DELETE']), + create_timeout=dict(default=None, type='int'), + template_url=dict(default=None, required=False), + template_body=dict(default=None, required=False), + template_format=dict(removed_in_version='2.14'), + create_changeset=dict(default=False, type='bool'), + changeset_name=dict(default=None, required=False), + role_arn=dict(default=None, required=False), + tags=dict(default=None, type='dict'), + termination_protection=dict(default=None, type='bool'), + events_limit=dict(default=200, type='int'), + backoff_retries=dict(type='int', default=10, required=False), + backoff_delay=dict(type='int', default=3, required=False), + backoff_max_delay=dict(type='int', default=30, required=False), + capabilities=dict(type='list', default=['CAPABILITY_IAM', 'CAPABILITY_NAMED_IAM']) + ) + ) + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=[['template_url', 'template', 'template_body'], + ['disable_rollback', 'on_create_failure']], + supports_check_mode=True + ) + if not HAS_BOTO3: + module.fail_json(msg='boto3 and botocore are required for this module') + + invalid_capabilities = [] + user_capabilities = module.params.get('capabilities') + for user_cap in user_capabilities: + if user_cap not in ['CAPABILITY_IAM', 'CAPABILITY_NAMED_IAM', 'CAPABILITY_AUTO_EXPAND']: + invalid_capabilities.append(user_cap) + + if invalid_capabilities: + module.fail_json(msg="Specified capabilities are invalid : %r," + " please check documentation for valid capabilities" % invalid_capabilities) + + # collect the parameters that are passed to boto3. Keeps us from having so many scalars floating around. + stack_params = { + 'Capabilities': user_capabilities, + 'ClientRequestToken': to_native(uuid.uuid4()), + } + state = module.params['state'] + stack_params['StackName'] = module.params['stack_name'] + + if module.params['template'] is not None: + with open(module.params['template'], 'r') as template_fh: + stack_params['TemplateBody'] = template_fh.read() + elif module.params['template_body'] is not None: + stack_params['TemplateBody'] = module.params['template_body'] + elif module.params['template_url'] is not None: + stack_params['TemplateURL'] = module.params['template_url'] + + if module.params.get('notification_arns'): + stack_params['NotificationARNs'] = module.params['notification_arns'].split(',') + else: + stack_params['NotificationARNs'] = [] + + # can't check the policy when verifying. 
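+    # (the StackPolicyBody is only read and sent on real create/update runs;
+    #  check mode and changeset creation skip it)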
+ if module.params['stack_policy'] is not None and not module.check_mode and not module.params['create_changeset']: + with open(module.params['stack_policy'], 'r') as stack_policy_fh: + stack_params['StackPolicyBody'] = stack_policy_fh.read() + + template_parameters = module.params['template_parameters'] + + stack_params['Parameters'] = [] + for k, v in template_parameters.items(): + if isinstance(v, dict): + # set parameter based on a dict to allow additional CFN Parameter Attributes + param = dict(ParameterKey=k) + + if 'value' in v: + param['ParameterValue'] = str(v['value']) + + if 'use_previous_value' in v and bool(v['use_previous_value']): + param['UsePreviousValue'] = True + param.pop('ParameterValue', None) + + stack_params['Parameters'].append(param) + else: + # allow default k/v configuration to set a template parameter + stack_params['Parameters'].append({'ParameterKey': k, 'ParameterValue': str(v)}) + + if isinstance(module.params.get('tags'), dict): + stack_params['Tags'] = ansible_dict_to_boto3_tag_list(module.params['tags']) + + if module.params.get('role_arn'): + stack_params['RoleARN'] = module.params['role_arn'] + + result = {} + + try: + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True) + cfn = boto3_conn(module, conn_type='client', resource='cloudformation', region=region, endpoint=ec2_url, **aws_connect_kwargs) + except botocore.exceptions.NoCredentialsError as e: + module.fail_json(msg=boto_exception(e)) + + # Wrap the cloudformation client methods that this module uses with + # automatic backoff / retry for throttling error codes + backoff_wrapper = AWSRetry.jittered_backoff( + retries=module.params.get('backoff_retries'), + delay=module.params.get('backoff_delay'), + max_delay=module.params.get('backoff_max_delay') + ) + cfn.describe_stack_events = backoff_wrapper(cfn.describe_stack_events) + cfn.create_stack = backoff_wrapper(cfn.create_stack) + cfn.list_change_sets = backoff_wrapper(cfn.list_change_sets) + cfn.create_change_set = backoff_wrapper(cfn.create_change_set) + cfn.update_stack = backoff_wrapper(cfn.update_stack) + cfn.describe_stacks = backoff_wrapper(cfn.describe_stacks) + cfn.list_stack_resources = backoff_wrapper(cfn.list_stack_resources) + cfn.delete_stack = backoff_wrapper(cfn.delete_stack) + if boto_supports_termination_protection(cfn): + cfn.update_termination_protection = backoff_wrapper(cfn.update_termination_protection) + + stack_info = get_stack_facts(cfn, stack_params['StackName']) + + if module.check_mode: + if state == 'absent' and stack_info: + module.exit_json(changed=True, msg='Stack would be deleted', meta=[]) + elif state == 'absent' and not stack_info: + module.exit_json(changed=False, msg='Stack doesn\'t exist', meta=[]) + elif state == 'present' and not stack_info: + module.exit_json(changed=True, msg='New stack would be created', meta=[]) + else: + module.exit_json(**check_mode_changeset(module, stack_params, cfn)) + + if state == 'present': + if not stack_info: + result = create_stack(module, stack_params, cfn, module.params.get('events_limit')) + elif module.params.get('create_changeset'): + result = create_changeset(module, stack_params, cfn, module.params.get('events_limit')) + else: + if module.params.get('termination_protection') is not None: + update_termination_protection(module, cfn, stack_params['StackName'], + bool(module.params.get('termination_protection'))) + result = update_stack(module, stack_params, cfn, module.params.get('events_limit')) + + # format the stack output + + stack = 
get_stack_facts(cfn, stack_params['StackName']) + if stack is not None: + if result.get('stack_outputs') is None: + # always define stack_outputs, but it may be empty + result['stack_outputs'] = {} + for output in stack.get('Outputs', []): + result['stack_outputs'][output['OutputKey']] = output['OutputValue'] + stack_resources = [] + reslist = cfn.list_stack_resources(StackName=stack_params['StackName']) + for res in reslist.get('StackResourceSummaries', []): + stack_resources.append({ + "logical_resource_id": res['LogicalResourceId'], + "physical_resource_id": res.get('PhysicalResourceId', ''), + "resource_type": res['ResourceType'], + "last_updated_time": res['LastUpdatedTimestamp'], + "status": res['ResourceStatus'], + "status_reason": res.get('ResourceStatusReason') # can be blank, apparently + }) + result['stack_resources'] = stack_resources + + elif state == 'absent': + # absent state is different because of the way delete_stack works. + # problem is it it doesn't give an error if stack isn't found + # so must describe the stack first + + try: + stack = get_stack_facts(cfn, stack_params['StackName']) + if not stack: + result = {'changed': False, 'output': 'Stack not found.'} + else: + if stack_params.get('RoleARN') is None: + cfn.delete_stack(StackName=stack_params['StackName']) + else: + cfn.delete_stack(StackName=stack_params['StackName'], RoleARN=stack_params['RoleARN']) + result = stack_operation(cfn, stack_params['StackName'], 'DELETE', module.params.get('events_limit'), + stack_params.get('ClientRequestToken', None)) + except Exception as err: + module.fail_json(msg=boto_exception(err), exception=traceback.format_exc()) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/cloudformation_info.py b/test/support/integration/plugins/modules/cloudformation_info.py new file mode 100644 index 00000000..ee2e5c17 --- /dev/null +++ b/test/support/integration/plugins/modules/cloudformation_info.py @@ -0,0 +1,355 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: cloudformation_info +short_description: Obtain information about an AWS CloudFormation stack +description: + - Gets information about an AWS CloudFormation stack. + - This module was called C(cloudformation_facts) before Ansible 2.9, returning C(ansible_facts). + Note that the M(cloudformation_info) module no longer returns C(ansible_facts)! +requirements: + - boto3 >= 1.0.0 + - python >= 2.6 +version_added: "2.2" +author: + - Justin Menga (@jmenga) + - Kevin Coming (@waffie1) +options: + stack_name: + description: + - The name or id of the CloudFormation stack. Gathers information on all stacks by default. + type: str + all_facts: + description: + - Get all stack information for the stack. + type: bool + default: false + stack_events: + description: + - Get stack events for the stack. + type: bool + default: false + stack_template: + description: + - Get stack template body for the stack. + type: bool + default: false + stack_resources: + description: + - Get stack resources for the stack. + type: bool + default: false + stack_policy: + description: + - Get stack policy for the stack. 
+ type: bool + default: false + stack_change_sets: + description: + - Get stack change sets for the stack + type: bool + default: false + version_added: '2.10' +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +# Get summary information about a stack +- cloudformation_info: + stack_name: my-cloudformation-stack + register: output + +- debug: + msg: "{{ output['cloudformation']['my-cloudformation-stack'] }}" + +# When the module is called as cloudformation_facts, return values are published +# in ansible_facts['cloudformation'][<stack_name>] and can be used as follows. +# Note that this is deprecated and will stop working in Ansible 2.13. + +- cloudformation_facts: + stack_name: my-cloudformation-stack + +- debug: + msg: "{{ ansible_facts['cloudformation']['my-cloudformation-stack'] }}" + +# Get stack outputs, when you have the stack name available as a fact +- set_fact: + stack_name: my-awesome-stack + +- cloudformation_info: + stack_name: "{{ stack_name }}" + register: my_stack + +- debug: + msg: "{{ my_stack.cloudformation[stack_name].stack_outputs }}" + +# Get all stack information about a stack +- cloudformation_info: + stack_name: my-cloudformation-stack + all_facts: true + +# Get stack resource and stack policy information about a stack +- cloudformation_info: + stack_name: my-cloudformation-stack + stack_resources: true + stack_policy: true + +# Fail if the stack doesn't exist +- name: try to get facts about a stack but fail if it doesn't exist + cloudformation_info: + stack_name: nonexistent-stack + all_facts: yes + failed_when: cloudformation['nonexistent-stack'] is undefined +''' + +RETURN = ''' +stack_description: + description: Summary facts about the stack + returned: if the stack exists + type: dict +stack_outputs: + description: Dictionary of stack outputs keyed by the value of each output 'OutputKey' parameter and corresponding value of each + output 'OutputValue' parameter + returned: if the stack exists + type: dict + sample: + ApplicationDatabaseName: dazvlpr01xj55a.ap-southeast-2.rds.amazonaws.com +stack_parameters: + description: Dictionary of stack parameters keyed by the value of each parameter 'ParameterKey' parameter and corresponding value of + each parameter 'ParameterValue' parameter + returned: if the stack exists + type: dict + sample: + DatabaseEngine: mysql + DatabasePassword: "***" +stack_events: + description: All stack events for the stack + returned: only if all_facts or stack_events is true and the stack exists + type: list +stack_policy: + description: Describes the stack policy for the stack + returned: only if all_facts or stack_policy is true and the stack exists + type: dict +stack_template: + description: Describes the stack template for the stack + returned: only if all_facts or stack_template is true and the stack exists + type: dict +stack_resource_list: + description: Describes stack resources for the stack + returned: only if all_facts or stack_resourses is true and the stack exists + type: list +stack_resources: + description: Dictionary of stack resources keyed by the value of each resource 'LogicalResourceId' parameter and corresponding value of each + resource 'PhysicalResourceId' parameter + returned: only if all_facts or stack_resourses is true and the stack exists + type: dict + sample: + AutoScalingGroup: "dev-someapp-AutoscalingGroup-1SKEXXBCAN0S7" + AutoScalingSecurityGroup: "sg-abcd1234" + ApplicationDatabase: 
"dazvlpr01xj55a" +stack_change_sets: + description: A list of stack change sets. Each item in the list represents the details of a specific changeset + + returned: only if all_facts or stack_change_sets is true and the stack exists + type: list +''' + +import json +import traceback + +from functools import partial +from ansible.module_utils._text import to_native +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import (camel_dict_to_snake_dict, AWSRetry, boto3_tag_list_to_ansible_dict) + +try: + import botocore +except ImportError: + pass # handled by AnsibleAWSModule + + +class CloudFormationServiceManager: + """Handles CloudFormation Services""" + + def __init__(self, module): + self.module = module + self.client = module.client('cloudformation') + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def describe_stacks_with_backoff(self, **kwargs): + paginator = self.client.get_paginator('describe_stacks') + return paginator.paginate(**kwargs).build_full_result()['Stacks'] + + def describe_stacks(self, stack_name=None): + try: + kwargs = {'StackName': stack_name} if stack_name else {} + response = self.describe_stacks_with_backoff(**kwargs) + if response is not None: + return response + self.module.fail_json(msg="Error describing stack(s) - an empty response was returned") + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + if 'does not exist' in e.response['Error']['Message']: + # missing stack, don't bail. + return {} + self.module.fail_json_aws(e, msg="Error describing stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def list_stack_resources_with_backoff(self, stack_name): + paginator = self.client.get_paginator('list_stack_resources') + return paginator.paginate(StackName=stack_name).build_full_result()['StackResourceSummaries'] + + def list_stack_resources(self, stack_name): + try: + return self.list_stack_resources_with_backoff(stack_name) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error listing stack resources for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def describe_stack_events_with_backoff(self, stack_name): + paginator = self.client.get_paginator('describe_stack_events') + return paginator.paginate(StackName=stack_name).build_full_result()['StackEvents'] + + def describe_stack_events(self, stack_name): + try: + return self.describe_stack_events_with_backoff(stack_name) + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error listing stack events for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def list_stack_change_sets_with_backoff(self, stack_name): + paginator = self.client.get_paginator('list_change_sets') + return paginator.paginate(StackName=stack_name).build_full_result()['Summaries'] + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def describe_stack_change_set_with_backoff(self, **kwargs): + paginator = self.client.get_paginator('describe_change_set') + return paginator.paginate(**kwargs).build_full_result() + + def describe_stack_change_sets(self, stack_name): + changes = [] + try: + change_sets = self.list_stack_change_sets_with_backoff(stack_name) + for item in change_sets: + changes.append(self.describe_stack_change_set_with_backoff( + StackName=stack_name, + ChangeSetName=item['ChangeSetName'])) + return changes + except 
(botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error describing stack change sets for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def get_stack_policy_with_backoff(self, stack_name): + return self.client.get_stack_policy(StackName=stack_name) + + def get_stack_policy(self, stack_name): + try: + response = self.get_stack_policy_with_backoff(stack_name) + stack_policy = response.get('StackPolicyBody') + if stack_policy: + return json.loads(stack_policy) + return dict() + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error getting stack policy for stack " + stack_name) + + @AWSRetry.exponential_backoff(retries=5, delay=5) + def get_template_with_backoff(self, stack_name): + return self.client.get_template(StackName=stack_name) + + def get_template(self, stack_name): + try: + response = self.get_template_with_backoff(stack_name) + return response.get('TemplateBody') + except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError) as e: + self.module.fail_json_aws(e, msg="Error getting stack template for stack " + stack_name) + + +def to_dict(items, key, value): + ''' Transforms a list of items to a Key/Value dictionary ''' + if items: + return dict(zip([i.get(key) for i in items], [i.get(value) for i in items])) + else: + return dict() + + +def main(): + argument_spec = dict( + stack_name=dict(), + all_facts=dict(required=False, default=False, type='bool'), + stack_policy=dict(required=False, default=False, type='bool'), + stack_events=dict(required=False, default=False, type='bool'), + stack_resources=dict(required=False, default=False, type='bool'), + stack_template=dict(required=False, default=False, type='bool'), + stack_change_sets=dict(required=False, default=False, type='bool'), + ) + module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True) + + is_old_facts = module._name == 'cloudformation_facts' + if is_old_facts: + module.deprecate("The 'cloudformation_facts' module has been renamed to 'cloudformation_info', " + "and the renamed one no longer returns ansible_facts", + version='2.13', collection_name='ansible.builtin') + + service_mgr = CloudFormationServiceManager(module) + + if is_old_facts: + result = {'ansible_facts': {'cloudformation': {}}} + else: + result = {'cloudformation': {}} + + for stack_description in service_mgr.describe_stacks(module.params.get('stack_name')): + facts = {'stack_description': stack_description} + stack_name = stack_description.get('StackName') + + # Create stack output and stack parameter dictionaries + if facts['stack_description']: + facts['stack_outputs'] = to_dict(facts['stack_description'].get('Outputs'), 'OutputKey', 'OutputValue') + facts['stack_parameters'] = to_dict(facts['stack_description'].get('Parameters'), + 'ParameterKey', 'ParameterValue') + facts['stack_tags'] = boto3_tag_list_to_ansible_dict(facts['stack_description'].get('Tags')) + + # Create optional stack outputs + all_facts = module.params.get('all_facts') + if all_facts or module.params.get('stack_resources'): + facts['stack_resource_list'] = service_mgr.list_stack_resources(stack_name) + facts['stack_resources'] = to_dict(facts.get('stack_resource_list'), + 'LogicalResourceId', 'PhysicalResourceId') + if all_facts or module.params.get('stack_template'): + facts['stack_template'] = service_mgr.get_template(stack_name) + if all_facts or module.params.get('stack_policy'): + 
facts['stack_policy'] = service_mgr.get_stack_policy(stack_name) + if all_facts or module.params.get('stack_events'): + facts['stack_events'] = service_mgr.describe_stack_events(stack_name) + if all_facts or module.params.get('stack_change_sets'): + facts['stack_change_sets'] = service_mgr.describe_stack_change_sets(stack_name) + + if is_old_facts: + result['ansible_facts']['cloudformation'][stack_name] = facts + else: + result['cloudformation'][stack_name] = camel_dict_to_snake_dict(facts, ignore_list=('stack_outputs', + 'stack_parameters', + 'stack_policy', + 'stack_resources', + 'stack_tags', + 'stack_template')) + + module.exit_json(changed=False, **result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/deploy_helper.py b/test/support/integration/plugins/modules/deploy_helper.py new file mode 100644 index 00000000..38594dde --- /dev/null +++ b/test/support/integration/plugins/modules/deploy_helper.py @@ -0,0 +1,521 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2014, Jasper N. Brouwer <jasper@nerdsweide.nl> +# (c) 2014, Ramon de la Fuente <ramon@delafuente.nl> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: deploy_helper +version_added: "2.0" +author: "Ramon de la Fuente (@ramondelafuente)" +short_description: Manages some of the steps common in deploying projects. +description: + - The Deploy Helper manages some of the steps common in deploying software. + It creates a folder structure, manages a symlink for the current release + and cleans up old releases. + - "Running it with the C(state=query) or C(state=present) will return the C(deploy_helper) fact. + C(project_path), whatever you set in the path parameter, + C(current_path), the path to the symlink that points to the active release, + C(releases_path), the path to the folder to keep releases in, + C(shared_path), the path to the folder to keep shared resources in, + C(unfinished_filename), the file to check for to recognize unfinished builds, + C(previous_release), the release the 'current' symlink is pointing to, + C(previous_release_path), the full path to the 'current' symlink target, + C(new_release), either the 'release' parameter or a generated timestamp, + C(new_release_path), the path to the new release folder (not created by the module)." + +options: + path: + required: True + aliases: ['dest'] + description: + - the root path of the project. Alias I(dest). + Returned in the C(deploy_helper.project_path) fact. + + state: + description: + - the state of the project. + C(query) will only gather facts, + C(present) will create the project I(root) folder, and in it the I(releases) and I(shared) folders, + C(finalize) will remove the unfinished_filename file, create a symlink to the newly + deployed release and optionally clean old releases, + C(clean) will remove failed & old releases, + C(absent) will remove the project folder (synonymous to the M(file) module with C(state=absent)) + choices: [ present, finalize, absent, clean, query ] + default: present + + release: + description: + - the release version that is being deployed. Defaults to a timestamp format %Y%m%d%H%M%S (i.e. '20141119223359'). 
+ This parameter is optional during C(state=present), but needs to be set explicitly for C(state=finalize). + You can use the generated fact C(release={{ deploy_helper.new_release }}). + + releases_path: + description: + - the name of the folder that will hold the releases. This can be relative to C(path) or absolute. + Returned in the C(deploy_helper.releases_path) fact. + default: releases + + shared_path: + description: + - the name of the folder that will hold the shared resources. This can be relative to C(path) or absolute. + If this is set to an empty string, no shared folder will be created. + Returned in the C(deploy_helper.shared_path) fact. + default: shared + + current_path: + description: + - the name of the symlink that is created when the deploy is finalized. Used in C(finalize) and C(clean). + Returned in the C(deploy_helper.current_path) fact. + default: current + + unfinished_filename: + description: + - the name of the file that indicates a deploy has not finished. All folders in the releases_path that + contain this file will be deleted on C(state=finalize) with clean=True, or C(state=clean). This file is + automatically deleted from the I(new_release_path) during C(state=finalize). + default: DEPLOY_UNFINISHED + + clean: + description: + - Whether to run the clean procedure in case of C(state=finalize). + type: bool + default: 'yes' + + keep_releases: + description: + - the number of old releases to keep when cleaning. Used in C(finalize) and C(clean). Any unfinished builds + will be deleted first, so only correct releases will count. The current version will not count. + default: 5 + +notes: + - Facts are only returned for C(state=query) and C(state=present). If you use both, you should pass any overridden + parameters to both calls, otherwise the second call will overwrite the facts of the first one. + - When using C(state=clean), the releases are ordered by I(creation date). You should be able to switch to a + new naming strategy without problems. + - Because of the default behaviour of generating the I(new_release) fact, this module will not be idempotent + unless you pass your own release name with C(release). Due to the nature of deploying software, this should not + be much of a problem. +''' + +EXAMPLES = ''' + +# General explanation, starting with an example folder structure for a project: + +# root: +# releases: +# - 20140415234508 +# - 20140415235146 +# - 20140416082818 +# +# shared: +# - sessions +# - uploads +# +# current: releases/20140416082818 + + +# The 'releases' folder holds all the available releases. A release is a complete build of the application being +# deployed. This can be a clone of a repository for example, or a sync of a local folder on your filesystem. +# Having timestamped folders is one way of having distinct releases, but you could choose your own strategy like +# git tags or commit hashes. +# +# During a deploy, a new folder should be created in the releases folder and any build steps required should be +# performed. Once the new build is ready, the deploy procedure is 'finalized' by replacing the 'current' symlink +# with a link to this build. +# +# The 'shared' folder holds any resource that is shared between releases. Examples of this are web-server +# session files, or files uploaded by users of your application. It's quite common to have symlinks from a release +# folder pointing to a shared/subfolder, and creating these links would be automated as part of the build steps. 
+# +# The 'current' symlink points to one of the releases. Probably the latest one, unless a deploy is in progress. +# The web-server's root for the project will go through this symlink, so the 'downtime' when switching to a new +# release is reduced to the time it takes to switch the link. +# +# To distinguish between successful builds and unfinished ones, a file can be placed in the folder of the release +# that is currently in progress. The existence of this file will mark it as unfinished, and allow an automated +# procedure to remove it during cleanup. + + +# Typical usage +- name: Initialize the deploy root and gather facts + deploy_helper: + path: /path/to/root +- name: Clone the project to the new release folder + git: + repo: git://foosball.example.org/path/to/repo.git + dest: '{{ deploy_helper.new_release_path }}' + version: v1.1.1 +- name: Add an unfinished file, to allow cleanup on successful finalize + file: + path: '{{ deploy_helper.new_release_path }}/{{ deploy_helper.unfinished_filename }}' + state: touch +- name: Perform some build steps, like running your dependency manager for example + composer: + command: install + working_dir: '{{ deploy_helper.new_release_path }}' +- name: Create some folders in the shared folder + file: + path: '{{ deploy_helper.shared_path }}/{{ item }}' + state: directory + with_items: + - sessions + - uploads +- name: Add symlinks from the new release to the shared folder + file: + path: '{{ deploy_helper.new_release_path }}/{{ item.path }}' + src: '{{ deploy_helper.shared_path }}/{{ item.src }}' + state: link + with_items: + - path: app/sessions + src: sessions + - path: web/uploads + src: uploads +- name: Finalize the deploy, removing the unfinished file and switching the symlink + deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + +# Retrieving facts before running a deploy +- name: Run 'state=query' to gather facts without changing anything + deploy_helper: + path: /path/to/root + state: query +# Remember to set the 'release' parameter when you actually call 'state=present' later +- name: Initialize the deploy root + deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: present + +# all paths can be absolute or relative (to the 'path' parameter) +- deploy_helper: + path: /path/to/root + releases_path: /var/www/project/releases + shared_path: /var/www/shared + current_path: /var/www/active + +# Using your own naming strategy for releases (a version tag in this case): +- deploy_helper: + path: /path/to/root + release: v1.1.1 + state: present +- deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + +# Using a different unfinished_filename: +- deploy_helper: + path: /path/to/root + unfinished_filename: README.md + release: '{{ deploy_helper.new_release }}' + state: finalize + +# Postponing the cleanup of older builds: +- deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + clean: False +- deploy_helper: + path: /path/to/root + state: clean +# Or running the cleanup ahead of the new deploy +- deploy_helper: + path: /path/to/root + state: clean +- deploy_helper: + path: /path/to/root + state: present + +# Keeping more old releases: +- deploy_helper: + path: /path/to/root + release: '{{ deploy_helper.new_release }}' + state: finalize + keep_releases: 10 +# Or, if you use 'clean=false' on finalize: +- deploy_helper: + path: /path/to/root + state: clean + keep_releases: 
10 + +# Removing the entire project root folder +- deploy_helper: + path: /path/to/root + state: absent + +# Debugging the facts returned by the module +- deploy_helper: + path: /path/to/root +- debug: + var: deploy_helper +''' +import os +import shutil +import time +import traceback + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_native + + +class DeployHelper(object): + + def __init__(self, module): + self.module = module + self.file_args = module.load_file_common_arguments(module.params) + + self.clean = module.params['clean'] + self.current_path = module.params['current_path'] + self.keep_releases = module.params['keep_releases'] + self.path = module.params['path'] + self.release = module.params['release'] + self.releases_path = module.params['releases_path'] + self.shared_path = module.params['shared_path'] + self.state = module.params['state'] + self.unfinished_filename = module.params['unfinished_filename'] + + def gather_facts(self): + current_path = os.path.join(self.path, self.current_path) + releases_path = os.path.join(self.path, self.releases_path) + if self.shared_path: + shared_path = os.path.join(self.path, self.shared_path) + else: + shared_path = None + + previous_release, previous_release_path = self._get_last_release(current_path) + + if not self.release and (self.state == 'query' or self.state == 'present'): + self.release = time.strftime("%Y%m%d%H%M%S") + + if self.release: + new_release_path = os.path.join(releases_path, self.release) + else: + new_release_path = None + + return { + 'project_path': self.path, + 'current_path': current_path, + 'releases_path': releases_path, + 'shared_path': shared_path, + 'previous_release': previous_release, + 'previous_release_path': previous_release_path, + 'new_release': self.release, + 'new_release_path': new_release_path, + 'unfinished_filename': self.unfinished_filename + } + + def delete_path(self, path): + if not os.path.lexists(path): + return False + + if not os.path.isdir(path): + self.module.fail_json(msg="%s exists but is not a directory" % path) + + if not self.module.check_mode: + try: + shutil.rmtree(path, ignore_errors=False) + except Exception as e: + self.module.fail_json(msg="rmtree failed: %s" % to_native(e), exception=traceback.format_exc()) + + return True + + def create_path(self, path): + changed = False + + if not os.path.lexists(path): + changed = True + if not self.module.check_mode: + os.makedirs(path) + + elif not os.path.isdir(path): + self.module.fail_json(msg="%s exists but is not a directory" % path) + + changed += self.module.set_directory_attributes_if_different(self._get_file_args(path), changed) + + return changed + + def check_link(self, path): + if os.path.lexists(path): + if not os.path.islink(path): + self.module.fail_json(msg="%s exists but is not a symbolic link" % path) + + def create_link(self, source, link_name): + changed = False + + if os.path.islink(link_name): + norm_link = os.path.normpath(os.path.realpath(link_name)) + norm_source = os.path.normpath(os.path.realpath(source)) + if norm_link == norm_source: + changed = False + else: + changed = True + if not self.module.check_mode: + if not os.path.lexists(source): + self.module.fail_json(msg="the symlink target %s doesn't exists" % source) + tmp_link_name = link_name + '.' 
+ self.unfinished_filename + if os.path.islink(tmp_link_name): + os.unlink(tmp_link_name) + os.symlink(source, tmp_link_name) + os.rename(tmp_link_name, link_name) + else: + changed = True + if not self.module.check_mode: + os.symlink(source, link_name) + + return changed + + def remove_unfinished_file(self, new_release_path): + changed = False + unfinished_file_path = os.path.join(new_release_path, self.unfinished_filename) + if os.path.lexists(unfinished_file_path): + changed = True + if not self.module.check_mode: + os.remove(unfinished_file_path) + + return changed + + def remove_unfinished_builds(self, releases_path): + changes = 0 + + for release in os.listdir(releases_path): + if os.path.isfile(os.path.join(releases_path, release, self.unfinished_filename)): + if self.module.check_mode: + changes += 1 + else: + changes += self.delete_path(os.path.join(releases_path, release)) + + return changes + + def remove_unfinished_link(self, path): + changed = False + + tmp_link_name = os.path.join(path, self.release + '.' + self.unfinished_filename) + if not self.module.check_mode and os.path.exists(tmp_link_name): + changed = True + os.remove(tmp_link_name) + + return changed + + def cleanup(self, releases_path, reserve_version): + changes = 0 + + if os.path.lexists(releases_path): + releases = [f for f in os.listdir(releases_path) if os.path.isdir(os.path.join(releases_path, f))] + try: + releases.remove(reserve_version) + except ValueError: + pass + + if not self.module.check_mode: + releases.sort(key=lambda x: os.path.getctime(os.path.join(releases_path, x)), reverse=True) + for release in releases[self.keep_releases:]: + changes += self.delete_path(os.path.join(releases_path, release)) + elif len(releases) > self.keep_releases: + changes += (len(releases) - self.keep_releases) + + return changes + + def _get_file_args(self, path): + file_args = self.file_args.copy() + file_args['path'] = path + return file_args + + def _get_last_release(self, current_path): + previous_release = None + previous_release_path = None + + if os.path.lexists(current_path): + previous_release_path = os.path.realpath(current_path) + previous_release = os.path.basename(previous_release_path) + + return previous_release, previous_release_path + + +def main(): + + module = AnsibleModule( + argument_spec=dict( + path=dict(aliases=['dest'], required=True, type='path'), + release=dict(required=False, type='str', default=None), + releases_path=dict(required=False, type='str', default='releases'), + shared_path=dict(required=False, type='path', default='shared'), + current_path=dict(required=False, type='path', default='current'), + keep_releases=dict(required=False, type='int', default=5), + clean=dict(required=False, type='bool', default=True), + unfinished_filename=dict(required=False, type='str', default='DEPLOY_UNFINISHED'), + state=dict(required=False, choices=['present', 'absent', 'clean', 'finalize', 'query'], default='present') + ), + add_file_common_args=True, + supports_check_mode=True + ) + + deploy_helper = DeployHelper(module) + facts = deploy_helper.gather_facts() + + result = { + 'state': deploy_helper.state + } + + changes = 0 + + if deploy_helper.state == 'query': + result['ansible_facts'] = {'deploy_helper': facts} + + elif deploy_helper.state == 'present': + deploy_helper.check_link(facts['current_path']) + changes += deploy_helper.create_path(facts['project_path']) + changes += deploy_helper.create_path(facts['releases_path']) + if deploy_helper.shared_path: + changes += 
deploy_helper.create_path(facts['shared_path']) + + result['ansible_facts'] = {'deploy_helper': facts} + + elif deploy_helper.state == 'finalize': + if not deploy_helper.release: + module.fail_json(msg="'release' is a required parameter for state=finalize (try the 'deploy_helper.new_release' fact)") + if deploy_helper.keep_releases <= 0: + module.fail_json(msg="'keep_releases' should be at least 1") + + changes += deploy_helper.remove_unfinished_file(facts['new_release_path']) + changes += deploy_helper.create_link(facts['new_release_path'], facts['current_path']) + if deploy_helper.clean: + changes += deploy_helper.remove_unfinished_link(facts['project_path']) + changes += deploy_helper.remove_unfinished_builds(facts['releases_path']) + changes += deploy_helper.cleanup(facts['releases_path'], facts['new_release']) + + elif deploy_helper.state == 'clean': + changes += deploy_helper.remove_unfinished_link(facts['project_path']) + changes += deploy_helper.remove_unfinished_builds(facts['releases_path']) + changes += deploy_helper.cleanup(facts['releases_path'], facts['new_release']) + + elif deploy_helper.state == 'absent': + # destroy the facts + result['ansible_facts'] = {'deploy_helper': []} + changes += deploy_helper.delete_path(facts['project_path']) + + if changes > 0: + result['changed'] = True + else: + result['changed'] = False + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/docker_swarm.py b/test/support/integration/plugins/modules/docker_swarm.py new file mode 100644 index 00000000..a2c076c5 --- /dev/null +++ b/test/support/integration/plugins/modules/docker_swarm.py @@ -0,0 +1,681 @@ +#!/usr/bin/python + +# Copyright 2016 Red Hat | Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +--- +module: docker_swarm +short_description: Manage Swarm cluster +version_added: "2.7" +description: + - Create a new Swarm cluster. + - Add/Remove nodes or managers to an existing cluster. +options: + advertise_addr: + description: + - Externally reachable address advertised to other nodes. + - This can either be an address/port combination + in the form C(192.168.1.1:4567), or an interface followed by a + port number, like C(eth0:4567). + - If the port number is omitted, + the port number from the listen address is used. + - If I(advertise_addr) is not specified, it will be automatically + detected when possible. + - Only used when swarm is initialised or joined. Because of this it's not + considered for idempotency checking. + type: str + default_addr_pool: + description: + - Default address pool in CIDR format. + - Only used when swarm is initialised. Because of this it's not considered + for idempotency checking. + - Requires API version >= 1.39. + type: list + elements: str + version_added: "2.8" + subnet_size: + description: + - Default address pool subnet mask length. + - Only used when swarm is initialised. Because of this it's not considered + for idempotency checking. + - Requires API version >= 1.39. + type: int + version_added: "2.8" + listen_addr: + description: + - Listen address used for inter-manager communication. 
+ - This can either be an address/port combination in the form + C(192.168.1.1:4567), or an interface followed by a port number, + like C(eth0:4567). + - If the port number is omitted, the default swarm listening port + is used. + - Only used when swarm is initialised or joined. Because of this it's not + considered for idempotency checking. + type: str + default: 0.0.0.0:2377 + force: + description: + - Use with state C(present) to force creating a new Swarm, even if already part of one. + - Use with state C(absent) to Leave the swarm even if this node is a manager. + type: bool + default: no + state: + description: + - Set to C(present), to create/update a new cluster. + - Set to C(join), to join an existing cluster. + - Set to C(absent), to leave an existing cluster. + - Set to C(remove), to remove an absent node from the cluster. + Note that removing requires Docker SDK for Python >= 2.4.0. + - Set to C(inspect) to display swarm informations. + type: str + default: present + choices: + - present + - join + - absent + - remove + - inspect + node_id: + description: + - Swarm id of the node to remove. + - Used with I(state=remove). + type: str + join_token: + description: + - Swarm token used to join a swarm cluster. + - Used with I(state=join). + type: str + remote_addrs: + description: + - Remote address of one or more manager nodes of an existing Swarm to connect to. + - Used with I(state=join). + type: list + elements: str + task_history_retention_limit: + description: + - Maximum number of tasks history stored. + - Docker default value is C(5). + type: int + snapshot_interval: + description: + - Number of logs entries between snapshot. + - Docker default value is C(10000). + type: int + keep_old_snapshots: + description: + - Number of snapshots to keep beyond the current snapshot. + - Docker default value is C(0). + type: int + log_entries_for_slow_followers: + description: + - Number of log entries to keep around to sync up slow followers after a snapshot is created. + type: int + heartbeat_tick: + description: + - Amount of ticks (in seconds) between each heartbeat. + - Docker default value is C(1s). + type: int + election_tick: + description: + - Amount of ticks (in seconds) needed without a leader to trigger a new election. + - Docker default value is C(10s). + type: int + dispatcher_heartbeat_period: + description: + - The delay for an agent to send a heartbeat to the dispatcher. + - Docker default value is C(5s). + type: int + node_cert_expiry: + description: + - Automatic expiry for nodes certificates. + - Docker default value is C(3months). + type: int + name: + description: + - The name of the swarm. + type: str + labels: + description: + - User-defined key/value metadata. + - Label operations in this module apply to the docker swarm cluster. + Use M(docker_node) module to add/modify/remove swarm node labels. + - Requires API version >= 1.32. + type: dict + signing_ca_cert: + description: + - The desired signing CA certificate for all swarm node TLS leaf certificates, in PEM format. + - This must not be a path to a certificate, but the contents of the certificate. + - Requires API version >= 1.30. + type: str + signing_ca_key: + description: + - The desired signing CA key for all swarm node TLS leaf certificates, in PEM format. + - This must not be a path to a key, but the contents of the key. + - Requires API version >= 1.30. 
+ type: str + ca_force_rotate: + description: + - An integer whose purpose is to force swarm to generate a new signing CA certificate and key, + if none have been specified. + - Docker default value is C(0). + - Requires API version >= 1.30. + type: int + autolock_managers: + description: + - If set, generate a key and use it to lock data stored on the managers. + - Docker default value is C(no). + - M(docker_swarm_info) can be used to retrieve the unlock key. + type: bool + rotate_worker_token: + description: Rotate the worker join token. + type: bool + default: no + rotate_manager_token: + description: Rotate the manager join token. + type: bool + default: no +extends_documentation_fragment: + - docker + - docker.docker_py_1_documentation +requirements: + - "L(Docker SDK for Python,https://docker-py.readthedocs.io/en/stable/) >= 1.10.0 (use L(docker-py,https://pypi.org/project/docker-py/) for Python 2.6)" + - Docker API >= 1.25 +author: + - Thierry Bouvet (@tbouvet) + - Piotr Wojciechowski (@WojciechowskiPiotr) +''' + +EXAMPLES = ''' + +- name: Init a new swarm with default parameters + docker_swarm: + state: present + +- name: Update swarm configuration + docker_swarm: + state: present + election_tick: 5 + +- name: Add nodes + docker_swarm: + state: join + advertise_addr: 192.168.1.2 + join_token: SWMTKN-1--xxxxx + remote_addrs: [ '192.168.1.1:2377' ] + +- name: Leave swarm for a node + docker_swarm: + state: absent + +- name: Remove a swarm manager + docker_swarm: + state: absent + force: true + +- name: Remove node from swarm + docker_swarm: + state: remove + node_id: mynode + +- name: Inspect swarm + docker_swarm: + state: inspect + register: swarm_info +''' + +RETURN = ''' +swarm_facts: + description: Informations about swarm. + returned: success + type: dict + contains: + JoinTokens: + description: Tokens to connect to the Swarm. + returned: success + type: dict + contains: + Worker: + description: Token to create a new *worker* node + returned: success + type: str + example: SWMTKN-1--xxxxx + Manager: + description: Token to create a new *manager* node + returned: success + type: str + example: SWMTKN-1--xxxxx + UnlockKey: + description: The swarm unlock-key if I(autolock_managers) is C(true). + returned: on success if I(autolock_managers) is C(true) + and swarm is initialised, or if I(autolock_managers) has changed. + type: str + example: SWMKEY-1-xxx + +actions: + description: Provides the actions done on the swarm. + returned: when action failed. 
+ type: list + elements: str + example: "['This cluster is already a swarm cluster']" + +''' + +import json +import traceback + +try: + from docker.errors import DockerException, APIError +except ImportError: + # missing Docker SDK for Python handled in ansible.module_utils.docker.common + pass + +from ansible.module_utils.docker.common import ( + DockerBaseClass, + DifferenceTracker, + RequestException, +) + +from ansible.module_utils.docker.swarm import AnsibleDockerSwarmClient + +from ansible.module_utils._text import to_native + + +class TaskParameters(DockerBaseClass): + def __init__(self): + super(TaskParameters, self).__init__() + + self.advertise_addr = None + self.listen_addr = None + self.remote_addrs = None + self.join_token = None + + # Spec + self.snapshot_interval = None + self.task_history_retention_limit = None + self.keep_old_snapshots = None + self.log_entries_for_slow_followers = None + self.heartbeat_tick = None + self.election_tick = None + self.dispatcher_heartbeat_period = None + self.node_cert_expiry = None + self.name = None + self.labels = None + self.log_driver = None + self.signing_ca_cert = None + self.signing_ca_key = None + self.ca_force_rotate = None + self.autolock_managers = None + self.rotate_worker_token = None + self.rotate_manager_token = None + self.default_addr_pool = None + self.subnet_size = None + + @staticmethod + def from_ansible_params(client): + result = TaskParameters() + for key, value in client.module.params.items(): + if key in result.__dict__: + setattr(result, key, value) + + result.update_parameters(client) + return result + + def update_from_swarm_info(self, swarm_info): + spec = swarm_info['Spec'] + + ca_config = spec.get('CAConfig') or dict() + if self.node_cert_expiry is None: + self.node_cert_expiry = ca_config.get('NodeCertExpiry') + if self.ca_force_rotate is None: + self.ca_force_rotate = ca_config.get('ForceRotate') + + dispatcher = spec.get('Dispatcher') or dict() + if self.dispatcher_heartbeat_period is None: + self.dispatcher_heartbeat_period = dispatcher.get('HeartbeatPeriod') + + raft = spec.get('Raft') or dict() + if self.snapshot_interval is None: + self.snapshot_interval = raft.get('SnapshotInterval') + if self.keep_old_snapshots is None: + self.keep_old_snapshots = raft.get('KeepOldSnapshots') + if self.heartbeat_tick is None: + self.heartbeat_tick = raft.get('HeartbeatTick') + if self.log_entries_for_slow_followers is None: + self.log_entries_for_slow_followers = raft.get('LogEntriesForSlowFollowers') + if self.election_tick is None: + self.election_tick = raft.get('ElectionTick') + + orchestration = spec.get('Orchestration') or dict() + if self.task_history_retention_limit is None: + self.task_history_retention_limit = orchestration.get('TaskHistoryRetentionLimit') + + encryption_config = spec.get('EncryptionConfig') or dict() + if self.autolock_managers is None: + self.autolock_managers = encryption_config.get('AutoLockManagers') + + if self.name is None: + self.name = spec['Name'] + + if self.labels is None: + self.labels = spec.get('Labels') or {} + + if 'LogDriver' in spec['TaskDefaults']: + self.log_driver = spec['TaskDefaults']['LogDriver'] + + def update_parameters(self, client): + assign = dict( + snapshot_interval='snapshot_interval', + task_history_retention_limit='task_history_retention_limit', + keep_old_snapshots='keep_old_snapshots', + log_entries_for_slow_followers='log_entries_for_slow_followers', + heartbeat_tick='heartbeat_tick', + election_tick='election_tick', + 
dispatcher_heartbeat_period='dispatcher_heartbeat_period', + node_cert_expiry='node_cert_expiry', + name='name', + labels='labels', + signing_ca_cert='signing_ca_cert', + signing_ca_key='signing_ca_key', + ca_force_rotate='ca_force_rotate', + autolock_managers='autolock_managers', + log_driver='log_driver', + ) + params = dict() + for dest, source in assign.items(): + if not client.option_minimal_versions[source]['supported']: + continue + value = getattr(self, source) + if value is not None: + params[dest] = value + self.spec = client.create_swarm_spec(**params) + + def compare_to_active(self, other, client, differences): + for k in self.__dict__: + if k in ('advertise_addr', 'listen_addr', 'remote_addrs', 'join_token', + 'rotate_worker_token', 'rotate_manager_token', 'spec', + 'default_addr_pool', 'subnet_size'): + continue + if not client.option_minimal_versions[k]['supported']: + continue + value = getattr(self, k) + if value is None: + continue + other_value = getattr(other, k) + if value != other_value: + differences.add(k, parameter=value, active=other_value) + if self.rotate_worker_token: + differences.add('rotate_worker_token', parameter=True, active=False) + if self.rotate_manager_token: + differences.add('rotate_manager_token', parameter=True, active=False) + return differences + + +class SwarmManager(DockerBaseClass): + + def __init__(self, client, results): + + super(SwarmManager, self).__init__() + + self.client = client + self.results = results + self.check_mode = self.client.check_mode + self.swarm_info = {} + + self.state = client.module.params['state'] + self.force = client.module.params['force'] + self.node_id = client.module.params['node_id'] + + self.differences = DifferenceTracker() + self.parameters = TaskParameters.from_ansible_params(client) + + self.created = False + + def __call__(self): + choice_map = { + "present": self.init_swarm, + "join": self.join, + "absent": self.leave, + "remove": self.remove, + "inspect": self.inspect_swarm + } + + if self.state == 'inspect': + self.client.module.deprecate( + "The 'inspect' state is deprecated, please use 'docker_swarm_info' to inspect swarm cluster", + version='2.12', collection_name='ansible.builtin') + + choice_map.get(self.state)() + + if self.client.module._diff or self.parameters.debug: + diff = dict() + diff['before'], diff['after'] = self.differences.get_before_after() + self.results['diff'] = diff + + def inspect_swarm(self): + try: + data = self.client.inspect_swarm() + json_str = json.dumps(data, ensure_ascii=False) + self.swarm_info = json.loads(json_str) + + self.results['changed'] = False + self.results['swarm_facts'] = self.swarm_info + + unlock_key = self.get_unlock_key() + self.swarm_info.update(unlock_key) + except APIError: + return + + def get_unlock_key(self): + default = {'UnlockKey': None} + if not self.has_swarm_lock_changed(): + return default + try: + return self.client.get_unlock_key() or default + except APIError: + return default + + def has_swarm_lock_changed(self): + return self.parameters.autolock_managers and ( + self.created or self.differences.has_difference_for('autolock_managers') + ) + + def init_swarm(self): + if not self.force and self.client.check_if_swarm_manager(): + self.__update_swarm() + return + + if not self.check_mode: + init_arguments = { + 'advertise_addr': self.parameters.advertise_addr, + 'listen_addr': self.parameters.listen_addr, + 'force_new_cluster': self.force, + 'swarm_spec': self.parameters.spec, + } + if self.parameters.default_addr_pool is not None: + 
init_arguments['default_addr_pool'] = self.parameters.default_addr_pool + if self.parameters.subnet_size is not None: + init_arguments['subnet_size'] = self.parameters.subnet_size + try: + self.client.init_swarm(**init_arguments) + except APIError as exc: + self.client.fail("Can not create a new Swarm Cluster: %s" % to_native(exc)) + + if not self.client.check_if_swarm_manager(): + if not self.check_mode: + self.client.fail("Swarm not created or other error!") + + self.created = True + self.inspect_swarm() + self.results['actions'].append("New Swarm cluster created: %s" % (self.swarm_info.get('ID'))) + self.differences.add('state', parameter='present', active='absent') + self.results['changed'] = True + self.results['swarm_facts'] = { + 'JoinTokens': self.swarm_info.get('JoinTokens'), + 'UnlockKey': self.swarm_info.get('UnlockKey') + } + + def __update_swarm(self): + try: + self.inspect_swarm() + version = self.swarm_info['Version']['Index'] + self.parameters.update_from_swarm_info(self.swarm_info) + old_parameters = TaskParameters() + old_parameters.update_from_swarm_info(self.swarm_info) + self.parameters.compare_to_active(old_parameters, self.client, self.differences) + if self.differences.empty: + self.results['actions'].append("No modification") + self.results['changed'] = False + return + update_parameters = TaskParameters.from_ansible_params(self.client) + update_parameters.update_parameters(self.client) + if not self.check_mode: + self.client.update_swarm( + version=version, swarm_spec=update_parameters.spec, + rotate_worker_token=self.parameters.rotate_worker_token, + rotate_manager_token=self.parameters.rotate_manager_token) + except APIError as exc: + self.client.fail("Can not update a Swarm Cluster: %s" % to_native(exc)) + return + + self.inspect_swarm() + self.results['actions'].append("Swarm cluster updated") + self.results['changed'] = True + + def join(self): + if self.client.check_if_swarm_node(): + self.results['actions'].append("This node is already part of a swarm.") + return + if not self.check_mode: + try: + self.client.join_swarm( + remote_addrs=self.parameters.remote_addrs, join_token=self.parameters.join_token, + listen_addr=self.parameters.listen_addr, advertise_addr=self.parameters.advertise_addr) + except APIError as exc: + self.client.fail("Can not join the Swarm Cluster: %s" % to_native(exc)) + self.results['actions'].append("New node is added to swarm cluster") + self.differences.add('joined', parameter=True, active=False) + self.results['changed'] = True + + def leave(self): + if not self.client.check_if_swarm_node(): + self.results['actions'].append("This node is not part of a swarm.") + return + if not self.check_mode: + try: + self.client.leave_swarm(force=self.force) + except APIError as exc: + self.client.fail("This node can not leave the Swarm Cluster: %s" % to_native(exc)) + self.results['actions'].append("Node has left the swarm cluster") + self.differences.add('joined', parameter='absent', active='present') + self.results['changed'] = True + + def remove(self): + if not self.client.check_if_swarm_manager(): + self.client.fail("This node is not a manager.") + + try: + status_down = self.client.check_if_swarm_node_is_down(node_id=self.node_id, repeat_check=5) + except APIError: + return + + if not status_down: + self.client.fail("Can not remove the node. 
The status node is ready and not down.") + + if not self.check_mode: + try: + self.client.remove_node(node_id=self.node_id, force=self.force) + except APIError as exc: + self.client.fail("Can not remove the node from the Swarm Cluster: %s" % to_native(exc)) + self.results['actions'].append("Node is removed from swarm cluster.") + self.differences.add('joined', parameter=False, active=True) + self.results['changed'] = True + + +def _detect_remove_operation(client): + return client.module.params['state'] == 'remove' + + +def main(): + argument_spec = dict( + advertise_addr=dict(type='str'), + state=dict(type='str', default='present', choices=['present', 'join', 'absent', 'remove', 'inspect']), + force=dict(type='bool', default=False), + listen_addr=dict(type='str', default='0.0.0.0:2377'), + remote_addrs=dict(type='list', elements='str'), + join_token=dict(type='str'), + snapshot_interval=dict(type='int'), + task_history_retention_limit=dict(type='int'), + keep_old_snapshots=dict(type='int'), + log_entries_for_slow_followers=dict(type='int'), + heartbeat_tick=dict(type='int'), + election_tick=dict(type='int'), + dispatcher_heartbeat_period=dict(type='int'), + node_cert_expiry=dict(type='int'), + name=dict(type='str'), + labels=dict(type='dict'), + signing_ca_cert=dict(type='str'), + signing_ca_key=dict(type='str'), + ca_force_rotate=dict(type='int'), + autolock_managers=dict(type='bool'), + node_id=dict(type='str'), + rotate_worker_token=dict(type='bool', default=False), + rotate_manager_token=dict(type='bool', default=False), + default_addr_pool=dict(type='list', elements='str'), + subnet_size=dict(type='int'), + ) + + required_if = [ + ('state', 'join', ['advertise_addr', 'remote_addrs', 'join_token']), + ('state', 'remove', ['node_id']) + ] + + option_minimal_versions = dict( + labels=dict(docker_py_version='2.6.0', docker_api_version='1.32'), + signing_ca_cert=dict(docker_py_version='2.6.0', docker_api_version='1.30'), + signing_ca_key=dict(docker_py_version='2.6.0', docker_api_version='1.30'), + ca_force_rotate=dict(docker_py_version='2.6.0', docker_api_version='1.30'), + autolock_managers=dict(docker_py_version='2.6.0'), + log_driver=dict(docker_py_version='2.6.0'), + remove_operation=dict( + docker_py_version='2.4.0', + detect_usage=_detect_remove_operation, + usage_msg='remove swarm nodes' + ), + default_addr_pool=dict(docker_py_version='4.0.0', docker_api_version='1.39'), + subnet_size=dict(docker_py_version='4.0.0', docker_api_version='1.39'), + ) + + client = AnsibleDockerSwarmClient( + argument_spec=argument_spec, + supports_check_mode=True, + required_if=required_if, + min_docker_version='1.10.0', + min_docker_api_version='1.25', + option_minimal_versions=option_minimal_versions, + ) + + try: + results = dict( + changed=False, + result='', + actions=[] + ) + + SwarmManager(client, results)() + client.module.exit_json(**results) + except DockerException as e: + client.fail('An unexpected docker error occurred: {0}'.format(e), exception=traceback.format_exc()) + except RequestException as e: + client.fail('An unexpected requests error occurred when docker-py tried to talk to the docker daemon: {0}'.format(e), exception=traceback.format_exc()) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2.py b/test/support/integration/plugins/modules/ec2.py new file mode 100644 index 00000000..952aa5a1 --- /dev/null +++ b/test/support/integration/plugins/modules/ec2.py @@ -0,0 +1,1766 @@ +#!/usr/bin/python +# This file is part of Ansible +# GNU 
General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: ec2 +short_description: create, terminate, start or stop an instance in ec2 +description: + - Creates or terminates ec2 instances. + - > + Note: This module uses the older boto Python module to interact with the EC2 API. + M(ec2) will still receive bug fixes, but no new features. + Consider using the M(ec2_instance) module instead. + If M(ec2_instance) does not support a feature you need that is available in M(ec2), please + file a feature request. +version_added: "0.9" +options: + key_name: + description: + - Key pair to use on the instance. + - The SSH key must already exist in AWS in order to use this argument. + - Keys can be created / deleted using the M(ec2_key) module. + aliases: ['keypair'] + type: str + id: + version_added: "1.1" + description: + - Identifier for this instance or set of instances, so that the module will be idempotent with respect to EC2 instances. + - This identifier is valid for at least 24 hours after the termination of the instance, and should not be reused for another call later on. + - For details, see the description of client token at U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Run_Instance_Idempotency.html). + type: str + group: + description: + - Security group (or list of groups) to use with the instance. + aliases: [ 'groups' ] + type: list + elements: str + group_id: + version_added: "1.1" + description: + - Security group id (or list of ids) to use with the instance. + type: list + elements: str + zone: + version_added: "1.2" + description: + - AWS availability zone in which to launch the instance. + aliases: [ 'aws_zone', 'ec2_zone' ] + type: str + instance_type: + description: + - Instance type to use for the instance, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-types.html). + - Required when creating a new instance. + type: str + aliases: ['type'] + tenancy: + version_added: "1.9" + description: + - An instance with a tenancy of C(dedicated) runs on single-tenant hardware and can only be launched into a VPC. + - Note that to use dedicated tenancy you MUST specify a I(vpc_subnet_id) as well. + - Dedicated tenancy is not available for EC2 "micro" instances. + default: default + choices: [ "default", "dedicated" ] + type: str + spot_price: + version_added: "1.5" + description: + - Maximum spot price to bid. If not set, a regular on-demand instance is requested. + - A spot request is made with this maximum bid. When it is filled, the instance is started. + type: str + spot_type: + version_added: "2.0" + description: + - The type of spot request. + - After being interrupted a C(persistent) spot instance will be started once there is capacity to fill the request again. + default: "one-time" + choices: [ "one-time", "persistent" ] + type: str + image: + description: + - I(ami) ID to use for the instance. + - Required when I(state=present). + type: str + kernel: + description: + - Kernel eki to use for the instance. + type: str + ramdisk: + description: + - Ramdisk eri to use for the instance. + type: str + wait: + description: + - Wait for the instance to reach its desired state before returning. + - Does not wait for SSH, see the 'wait_for_connection' example for details. 
+ type: bool + default: false + wait_timeout: + description: + - How long before wait gives up, in seconds. + default: 300 + type: int + spot_wait_timeout: + version_added: "1.5" + description: + - How long to wait for the spot instance request to be fulfilled. Affects 'Request valid until' for setting spot request lifespan. + default: 600 + type: int + count: + description: + - Number of instances to launch. + default: 1 + type: int + monitoring: + version_added: "1.1" + description: + - Enable detailed monitoring (CloudWatch) for instance. + type: bool + default: false + user_data: + version_added: "0.9" + description: + - Opaque blob of data which is made available to the EC2 instance. + type: str + instance_tags: + version_added: "1.0" + description: + - A hash/dictionary of tags to add to the new instance or for starting/stopping instance by tag; '{"key":"value"}' and '{"key":"value","key":"value"}'. + type: dict + placement_group: + version_added: "1.3" + description: + - Placement group for the instance when using EC2 Clustered Compute. + type: str + vpc_subnet_id: + version_added: "1.1" + description: + - the subnet ID in which to launch the instance (VPC). + type: str + assign_public_ip: + version_added: "1.5" + description: + - When provisioning within vpc, assign a public IP address. Boto library must be 2.13.0+. + type: bool + private_ip: + version_added: "1.2" + description: + - The private ip address to assign the instance (from the vpc subnet). + type: str + instance_profile_name: + version_added: "1.3" + description: + - Name of the IAM instance profile (i.e. what the EC2 console refers to as an "IAM Role") to use. Boto library must be 2.5.0+. + type: str + instance_ids: + version_added: "1.3" + description: + - "list of instance ids, currently used for states: absent, running, stopped" + aliases: ['instance_id'] + type: list + elements: str + source_dest_check: + version_added: "1.6" + description: + - Enable or Disable the Source/Destination checks (for NAT instances and Virtual Routers). + When initially creating an instance the EC2 API defaults this to C(True). + type: bool + termination_protection: + version_added: "2.0" + description: + - Enable or Disable the Termination Protection. + type: bool + default: false + instance_initiated_shutdown_behavior: + version_added: "2.2" + description: + - Set whether AWS will Stop or Terminate an instance on shutdown. This parameter is ignored when using instance-store. + images (which require termination on shutdown). + default: 'stop' + choices: [ "stop", "terminate" ] + type: str + state: + version_added: "1.3" + description: + - Create, terminate, start, stop or restart instances. The state 'restarted' was added in Ansible 2.2. + - When I(state=absent), I(instance_ids) is required. + - When I(state=running), I(state=stopped) or I(state=restarted) then either I(instance_ids) or I(instance_tags) is required. + default: 'present' + choices: ['absent', 'present', 'restarted', 'running', 'stopped'] + type: str + volumes: + version_added: "1.5" + description: + - A list of hash/dictionaries of volumes to add to the new instance. + type: list + elements: dict + suboptions: + device_name: + type: str + required: true + description: + - A name for the device (For example C(/dev/sda)). + delete_on_termination: + type: bool + default: false + description: + - Whether the volume should be automatically deleted when the instance is terminated. + ephemeral: + type: str + description: + - Whether the volume should be ephemeral. 
+ - Data on ephemeral volumes is lost when the instance is stopped. + - Mutually exclusive with the I(snapshot) parameter. + encrypted: + type: bool + default: false + description: + - Whether the volume should be encrypted using the 'aws/ebs' KMS CMK. + snapshot: + type: str + description: + - The ID of an EBS snapshot to copy when creating the volume. + - Mutually exclusive with the I(ephemeral) parameter. + volume_type: + type: str + description: + - The type of volume to create. + - See U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html) for more information on the available volume types. + volume_size: + type: int + description: + - The size of the volume (in GiB). + iops: + type: int + description: + - The number of IOPS per second to provision for the volume. + - Required when I(volume_type=io1). + ebs_optimized: + version_added: "1.6" + description: + - Whether instance is using optimized EBS volumes, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSOptimized.html). + default: false + type: bool + exact_count: + version_added: "1.5" + description: + - An integer value which indicates how many instances that match the 'count_tag' parameter should be running. + Instances are either created or terminated based on this value. + type: int + count_tag: + version_added: "1.5" + description: + - Used with I(exact_count) to determine how many nodes based on a specific tag criteria should be running. + This can be expressed in multiple ways and is shown in the EXAMPLES section. For instance, one can request 25 servers + that are tagged with "class=webserver". The specified tag must already exist or be passed in as the I(instance_tags) option. + type: raw + network_interfaces: + version_added: "2.0" + description: + - A list of existing network interfaces to attach to the instance at launch. When specifying existing network interfaces, + none of the I(assign_public_ip), I(private_ip), I(vpc_subnet_id), I(group), or I(group_id) parameters may be used. (Those parameters are + for creating a new network interface at launch.) + aliases: ['network_interface'] + type: list + elements: str + spot_launch_group: + version_added: "2.1" + description: + - Launch group for spot requests, see U(https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html#spot-launch-group). + type: str +author: + - "Tim Gerla (@tgerla)" + - "Lester Wade (@lwade)" + - "Seth Vidal (@skvidal)" +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. 
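+
+# A minimal sketch of passing credentials explicitly (placeholder values shown).
+# These parameters come from the shared 'aws'/'ec2' documentation fragments;
+# in practice prefer environment variables, a boto profile or an IAM instance
+# profile over hard-coding keys in a playbook.
+- ec2:
+    aws_access_key: 'AK123'
+    aws_secret_key: 'abc123'
+    region: us-east-1
+    key_name: mykey
+    instance_type: t2.micro
+    image: ami-123456
+    wait: yes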
+ +# Basic provisioning example +- ec2: + key_name: mykey + instance_type: t2.micro + image: ami-123456 + wait: yes + group: webserver + count: 3 + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Advanced example with tagging and CloudWatch +- ec2: + key_name: mykey + group: databases + instance_type: t2.micro + image: ami-123456 + wait: yes + wait_timeout: 500 + count: 5 + instance_tags: + db: postgres + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Single instance with additional IOPS volume from snapshot and volume delete on termination +- ec2: + key_name: mykey + group: webserver + instance_type: c3.medium + image: ami-123456 + wait: yes + wait_timeout: 500 + volumes: + - device_name: /dev/sdb + snapshot: snap-abcdef12 + volume_type: io1 + iops: 1000 + volume_size: 100 + delete_on_termination: true + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Single instance with ssd gp2 root volume +- ec2: + key_name: mykey + group: webserver + instance_type: c3.medium + image: ami-123456 + wait: yes + wait_timeout: 500 + volumes: + - device_name: /dev/xvda + volume_type: gp2 + volume_size: 8 + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + count_tag: + Name: dbserver + exact_count: 1 + +# Multiple groups example +- ec2: + key_name: mykey + group: ['databases', 'internal-services', 'sshable', 'and-so-forth'] + instance_type: m1.large + image: ami-6e649707 + wait: yes + wait_timeout: 500 + count: 5 + instance_tags: + db: postgres + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Multiple instances with additional volume from snapshot +- ec2: + key_name: mykey + group: webserver + instance_type: m1.large + image: ami-6e649707 + wait: yes + wait_timeout: 500 + count: 5 + volumes: + - device_name: /dev/sdb + snapshot: snap-abcdef12 + volume_size: 10 + monitoring: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# Dedicated tenancy example +- local_action: + module: ec2 + assign_public_ip: yes + group_id: sg-1dc53f72 + key_name: mykey + image: ami-6e649707 + instance_type: m1.small + tenancy: dedicated + vpc_subnet_id: subnet-29e63245 + wait: yes + +# Spot instance example +- ec2: + spot_price: 0.24 + spot_wait_timeout: 600 + keypair: mykey + group_id: sg-1dc53f72 + instance_type: m1.small + image: ami-6e649707 + wait: yes + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + spot_launch_group: report_generators + instance_initiated_shutdown_behavior: terminate + +# Examples using pre-existing network interfaces +- ec2: + key_name: mykey + instance_type: t2.small + image: ami-f005ba11 + network_interface: eni-deadbeef + +- ec2: + key_name: mykey + instance_type: t2.small + image: ami-f005ba11 + network_interfaces: ['eni-deadbeef', 'eni-5ca1ab1e'] + +# Launch instances, runs some tasks +# and then terminate them + +- name: Create a sandbox instance + hosts: localhost + gather_facts: False + vars: + keypair: my_keypair + instance_type: m1.small + security_group: my_securitygroup + image: my_ami_id + region: us-east-1 + tasks: + - name: Launch instance + ec2: + key_name: "{{ keypair }}" + group: "{{ security_group }}" + instance_type: "{{ instance_type }}" + image: "{{ image }}" + wait: true + region: "{{ region }}" + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + register: ec2 + + - name: Add new instance to host group + add_host: + hostname: "{{ item.public_ip }}" + groupname: launched + loop: "{{ ec2.instances }}" + + - name: Wait for SSH to come up + 
delegate_to: "{{ item.public_dns_name }}" + wait_for_connection: + delay: 60 + timeout: 320 + loop: "{{ ec2.instances }}" + +- name: Configure instance(s) + hosts: launched + become: True + gather_facts: True + roles: + - my_awesome_role + - my_awesome_test + +- name: Terminate instances + hosts: localhost + tasks: + - name: Terminate instances that were previously launched + ec2: + state: 'absent' + instance_ids: '{{ ec2.instance_ids }}' + +# Start a few existing instances, run some tasks +# and stop the instances + +- name: Start sandbox instances + hosts: localhost + gather_facts: false + vars: + instance_ids: + - 'i-xxxxxx' + - 'i-xxxxxx' + - 'i-xxxxxx' + region: us-east-1 + tasks: + - name: Start the sandbox instances + ec2: + instance_ids: '{{ instance_ids }}' + region: '{{ region }}' + state: running + wait: True + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + roles: + - do_neat_stuff + - do_more_neat_stuff + +- name: Stop sandbox instances + hosts: localhost + gather_facts: false + vars: + instance_ids: + - 'i-xxxxxx' + - 'i-xxxxxx' + - 'i-xxxxxx' + region: us-east-1 + tasks: + - name: Stop the sandbox instances + ec2: + instance_ids: '{{ instance_ids }}' + region: '{{ region }}' + state: stopped + wait: True + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# +# Start stopped instances specified by tag +# +- local_action: + module: ec2 + instance_tags: + Name: ExtraPower + state: running + +# +# Restart instances specified by tag +# +- local_action: + module: ec2 + instance_tags: + Name: ExtraPower + state: restarted + +# +# Enforce that 5 instances with a tag "foo" are running +# (Highly recommended!) +# + +- ec2: + key_name: mykey + instance_type: c1.medium + image: ami-40603AD1 + wait: yes + group: webserver + instance_tags: + foo: bar + exact_count: 5 + count_tag: foo + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# +# Enforce that 5 running instances named "database" with a "dbtype" of "postgres" +# + +- ec2: + key_name: mykey + instance_type: c1.medium + image: ami-40603AD1 + wait: yes + group: webserver + instance_tags: + Name: database + dbtype: postgres + exact_count: 5 + count_tag: + Name: database + dbtype: postgres + vpc_subnet_id: subnet-29e63245 + assign_public_ip: yes + +# +# count_tag complex argument examples +# + + # instances with tag foo +- ec2: + count_tag: + foo: + + # instances with tag foo=bar +- ec2: + count_tag: + foo: bar + + # instances with tags foo=bar & baz +- ec2: + count_tag: + foo: bar + baz: + + # instances with tags foo & bar & baz=bang +- ec2: + count_tag: + - foo + - bar + - baz: bang + +''' + +import time +import datetime +import traceback +from ast import literal_eval +from distutils.version import LooseVersion + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.ec2 import get_aws_connection_info, ec2_argument_spec, ec2_connect +from ansible.module_utils.six import get_function_code, string_types +from ansible.module_utils._text import to_bytes, to_text + +try: + import boto.ec2 + from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping + from boto.exception import EC2ResponseError + from boto import connect_ec2_endpoint + from boto import connect_vpc + HAS_BOTO = True +except ImportError: + HAS_BOTO = False + + +def find_running_instances_by_count_tag(module, ec2, vpc, count_tag, zone=None): + + # get reservations for instances that match tag(s) and are in the desired state + state = module.params.get('state') + if state not in ['running', 'stopped']: + 
state = None + reservations = get_reservations(module, ec2, vpc, tags=count_tag, state=state, zone=zone) + + instances = [] + for res in reservations: + if hasattr(res, 'instances'): + for inst in res.instances: + if inst.state == 'terminated' or inst.state == 'shutting-down': + continue + instances.append(inst) + + return reservations, instances + + +def _set_none_to_blank(dictionary): + result = dictionary + for k in result: + if isinstance(result[k], dict): + result[k] = _set_none_to_blank(result[k]) + elif not result[k]: + result[k] = "" + return result + + +def get_reservations(module, ec2, vpc, tags=None, state=None, zone=None): + # TODO: filters do not work with tags that have underscores + filters = dict() + + vpc_subnet_id = module.params.get('vpc_subnet_id') + vpc_id = None + if vpc_subnet_id: + filters.update({"subnet-id": vpc_subnet_id}) + if vpc: + vpc_id = vpc.get_all_subnets(subnet_ids=[vpc_subnet_id])[0].vpc_id + + if vpc_id: + filters.update({"vpc-id": vpc_id}) + + if tags is not None: + + if isinstance(tags, str): + try: + tags = literal_eval(tags) + except Exception: + pass + + # if not a string type, convert and make sure it's a text string + if isinstance(tags, int): + tags = to_text(tags) + + # if string, we only care that a tag of that name exists + if isinstance(tags, str): + filters.update({"tag-key": tags}) + + # if list, append each item to filters + if isinstance(tags, list): + for x in tags: + if isinstance(x, dict): + x = _set_none_to_blank(x) + filters.update(dict(("tag:" + tn, tv) for (tn, tv) in x.items())) + else: + filters.update({"tag-key": x}) + + # if dict, add the key and value to the filter + if isinstance(tags, dict): + tags = _set_none_to_blank(tags) + filters.update(dict(("tag:" + tn, tv) for (tn, tv) in tags.items())) + + # lets check to see if the filters dict is empty, if so then stop + if not filters: + module.fail_json(msg="Filters based on tag is empty => tags: %s" % (tags)) + + if state: + # http://stackoverflow.com/questions/437511/what-are-the-valid-instancestates-for-the-amazon-ec2-api + filters.update({'instance-state-name': state}) + + if zone: + filters.update({'availability-zone': zone}) + + if module.params.get('id'): + filters['client-token'] = module.params['id'] + + results = ec2.get_all_instances(filters=filters) + + return results + + +def get_instance_info(inst): + """ + Retrieves instance information from an instance + ID and returns it as a dictionary + """ + instance_info = {'id': inst.id, + 'ami_launch_index': inst.ami_launch_index, + 'private_ip': inst.private_ip_address, + 'private_dns_name': inst.private_dns_name, + 'public_ip': inst.ip_address, + 'dns_name': inst.dns_name, + 'public_dns_name': inst.public_dns_name, + 'state_code': inst.state_code, + 'architecture': inst.architecture, + 'image_id': inst.image_id, + 'key_name': inst.key_name, + 'placement': inst.placement, + 'region': inst.placement[:-1], + 'kernel': inst.kernel, + 'ramdisk': inst.ramdisk, + 'launch_time': inst.launch_time, + 'instance_type': inst.instance_type, + 'root_device_type': inst.root_device_type, + 'root_device_name': inst.root_device_name, + 'state': inst.state, + 'hypervisor': inst.hypervisor, + 'tags': inst.tags, + 'groups': dict((group.id, group.name) for group in inst.groups), + } + try: + instance_info['virtualization_type'] = getattr(inst, 'virtualization_type') + except AttributeError: + instance_info['virtualization_type'] = None + + try: + instance_info['ebs_optimized'] = getattr(inst, 'ebs_optimized') + except AttributeError: + 
instance_info['ebs_optimized'] = False + + try: + bdm_dict = {} + bdm = getattr(inst, 'block_device_mapping') + for device_name in bdm.keys(): + bdm_dict[device_name] = { + 'status': bdm[device_name].status, + 'volume_id': bdm[device_name].volume_id, + 'delete_on_termination': bdm[device_name].delete_on_termination + } + instance_info['block_device_mapping'] = bdm_dict + except AttributeError: + instance_info['block_device_mapping'] = False + + try: + instance_info['tenancy'] = getattr(inst, 'placement_tenancy') + except AttributeError: + instance_info['tenancy'] = 'default' + + return instance_info + + +def boto_supports_associate_public_ip_address(ec2): + """ + Check if Boto library has associate_public_ip_address in the NetworkInterfaceSpecification + class. Added in Boto 2.13.0 + + ec2: authenticated ec2 connection object + + Returns: + True if Boto library accepts associate_public_ip_address argument, else false + """ + + try: + network_interface = boto.ec2.networkinterface.NetworkInterfaceSpecification() + getattr(network_interface, "associate_public_ip_address") + return True + except AttributeError: + return False + + +def boto_supports_profile_name_arg(ec2): + """ + Check if Boto library has instance_profile_name argument. instance_profile_name has been added in Boto 2.5.0 + + ec2: authenticated ec2 connection object + + Returns: + True if Boto library accept instance_profile_name argument, else false + """ + run_instances_method = getattr(ec2, 'run_instances') + return 'instance_profile_name' in get_function_code(run_instances_method).co_varnames + + +def boto_supports_volume_encryption(): + """ + Check if Boto library supports encryption of EBS volumes (added in 2.29.0) + + Returns: + True if boto library has the named param as an argument on the request_spot_instances method, else False + """ + return hasattr(boto, 'Version') and LooseVersion(boto.Version) >= LooseVersion('2.29.0') + + +def create_block_device(module, ec2, volume): + # Not aware of a way to determine this programatically + # http://aws.amazon.com/about-aws/whats-new/2013/10/09/ebs-provisioned-iops-maximum-iops-gb-ratio-increased-to-30-1/ + MAX_IOPS_TO_SIZE_RATIO = 30 + + volume_type = volume.get('volume_type') + + if 'snapshot' not in volume and 'ephemeral' not in volume: + if 'volume_size' not in volume: + module.fail_json(msg='Size must be specified when creating a new volume or modifying the root volume') + if 'snapshot' in volume: + if volume_type == 'io1' and 'iops' not in volume: + module.fail_json(msg='io1 volumes must have an iops value set') + if 'iops' in volume: + snapshot = ec2.get_all_snapshots(snapshot_ids=[volume['snapshot']])[0] + size = volume.get('volume_size', snapshot.volume_size) + if int(volume['iops']) > MAX_IOPS_TO_SIZE_RATIO * size: + module.fail_json(msg='IOPS must be at most %d times greater than size' % MAX_IOPS_TO_SIZE_RATIO) + if 'ephemeral' in volume: + if 'snapshot' in volume: + module.fail_json(msg='Cannot set both ephemeral and snapshot') + if boto_supports_volume_encryption(): + return BlockDeviceType(snapshot_id=volume.get('snapshot'), + ephemeral_name=volume.get('ephemeral'), + size=volume.get('volume_size'), + volume_type=volume_type, + delete_on_termination=volume.get('delete_on_termination', False), + iops=volume.get('iops'), + encrypted=volume.get('encrypted', None)) + else: + return BlockDeviceType(snapshot_id=volume.get('snapshot'), + ephemeral_name=volume.get('ephemeral'), + size=volume.get('volume_size'), + volume_type=volume_type, + 
delete_on_termination=volume.get('delete_on_termination', False), + iops=volume.get('iops')) + + +def boto_supports_param_in_spot_request(ec2, param): + """ + Check if Boto library has a <param> in its request_spot_instances() method. For example, the placement_group parameter wasn't added until 2.3.0. + + ec2: authenticated ec2 connection object + + Returns: + True if boto library has the named param as an argument on the request_spot_instances method, else False + """ + method = getattr(ec2, 'request_spot_instances') + return param in get_function_code(method).co_varnames + + +def await_spot_requests(module, ec2, spot_requests, count): + """ + Wait for a group of spot requests to be fulfilled, or fail. + + module: Ansible module object + ec2: authenticated ec2 connection object + spot_requests: boto.ec2.spotinstancerequest.SpotInstanceRequest object returned by ec2.request_spot_instances + count: Total number of instances to be created by the spot requests + + Returns: + list of instance ID's created by the spot request(s) + """ + spot_wait_timeout = int(module.params.get('spot_wait_timeout')) + wait_complete = time.time() + spot_wait_timeout + + spot_req_inst_ids = dict() + while time.time() < wait_complete: + reqs = ec2.get_all_spot_instance_requests() + for sirb in spot_requests: + if sirb.id in spot_req_inst_ids: + continue + for sir in reqs: + if sir.id != sirb.id: + continue # this is not our spot instance + if sir.instance_id is not None: + spot_req_inst_ids[sirb.id] = sir.instance_id + elif sir.state == 'open': + continue # still waiting, nothing to do here + elif sir.state == 'active': + continue # Instance is created already, nothing to do here + elif sir.state == 'failed': + module.fail_json(msg="Spot instance request %s failed with status %s and fault %s:%s" % ( + sir.id, sir.status.code, sir.fault.code, sir.fault.message)) + elif sir.state == 'cancelled': + module.fail_json(msg="Spot instance request %s was cancelled before it could be fulfilled." % sir.id) + elif sir.state == 'closed': + # instance is terminating or marked for termination + # this may be intentional on the part of the operator, + # or it may have been terminated by AWS due to capacity, + # price, or group constraints in this case, we'll fail + # the module if the reason for the state is anything + # other than termination by user. 
Codes are documented at + # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-bid-status.html + if sir.status.code == 'instance-terminated-by-user': + # do nothing, since the user likely did this on purpose + pass + else: + spot_msg = "Spot instance request %s was closed by AWS with the status %s and fault %s:%s" + module.fail_json(msg=spot_msg % (sir.id, sir.status.code, sir.fault.code, sir.fault.message)) + + if len(spot_req_inst_ids) < count: + time.sleep(5) + else: + return list(spot_req_inst_ids.values()) + module.fail_json(msg="wait for spot requests timeout on %s" % time.asctime()) + + +def enforce_count(module, ec2, vpc): + + exact_count = module.params.get('exact_count') + count_tag = module.params.get('count_tag') + zone = module.params.get('zone') + + # fail here if the exact count was specified without filtering + # on a tag, as this may lead to a undesired removal of instances + if exact_count and count_tag is None: + module.fail_json(msg="you must use the 'count_tag' option with exact_count") + + reservations, instances = find_running_instances_by_count_tag(module, ec2, vpc, count_tag, zone) + + changed = None + checkmode = False + instance_dict_array = [] + changed_instance_ids = None + + if len(instances) == exact_count: + changed = False + elif len(instances) < exact_count: + changed = True + to_create = exact_count - len(instances) + if not checkmode: + (instance_dict_array, changed_instance_ids, changed) \ + = create_instances(module, ec2, vpc, override_count=to_create) + + for inst in instance_dict_array: + instances.append(inst) + elif len(instances) > exact_count: + changed = True + to_remove = len(instances) - exact_count + if not checkmode: + all_instance_ids = sorted([x.id for x in instances]) + remove_ids = all_instance_ids[0:to_remove] + + instances = [x for x in instances if x.id not in remove_ids] + + (changed, instance_dict_array, changed_instance_ids) \ + = terminate_instances(module, ec2, remove_ids) + terminated_list = [] + for inst in instance_dict_array: + inst['state'] = "terminated" + terminated_list.append(inst) + instance_dict_array = terminated_list + + # ensure all instances are dictionaries + all_instances = [] + for inst in instances: + + if not isinstance(inst, dict): + warn_if_public_ip_assignment_changed(module, inst) + inst = get_instance_info(inst) + all_instances.append(inst) + + return (all_instances, instance_dict_array, changed_instance_ids, changed) + + +def create_instances(module, ec2, vpc, override_count=None): + """ + Creates new instances + + module : AnsibleModule object + ec2: authenticated ec2 connection object + + Returns: + A list of dictionaries with instance information + about the instances that were launched + """ + + key_name = module.params.get('key_name') + id = module.params.get('id') + group_name = module.params.get('group') + group_id = module.params.get('group_id') + zone = module.params.get('zone') + instance_type = module.params.get('instance_type') + tenancy = module.params.get('tenancy') + spot_price = module.params.get('spot_price') + spot_type = module.params.get('spot_type') + image = module.params.get('image') + if override_count: + count = override_count + else: + count = module.params.get('count') + monitoring = module.params.get('monitoring') + kernel = module.params.get('kernel') + ramdisk = module.params.get('ramdisk') + wait = module.params.get('wait') + wait_timeout = int(module.params.get('wait_timeout')) + spot_wait_timeout = int(module.params.get('spot_wait_timeout')) + placement_group = 
module.params.get('placement_group') + user_data = module.params.get('user_data') + instance_tags = module.params.get('instance_tags') + vpc_subnet_id = module.params.get('vpc_subnet_id') + assign_public_ip = module.boolean(module.params.get('assign_public_ip')) + private_ip = module.params.get('private_ip') + instance_profile_name = module.params.get('instance_profile_name') + volumes = module.params.get('volumes') + ebs_optimized = module.params.get('ebs_optimized') + exact_count = module.params.get('exact_count') + count_tag = module.params.get('count_tag') + source_dest_check = module.boolean(module.params.get('source_dest_check')) + termination_protection = module.boolean(module.params.get('termination_protection')) + network_interfaces = module.params.get('network_interfaces') + spot_launch_group = module.params.get('spot_launch_group') + instance_initiated_shutdown_behavior = module.params.get('instance_initiated_shutdown_behavior') + + vpc_id = None + if vpc_subnet_id: + if not vpc: + module.fail_json(msg="region must be specified") + else: + vpc_id = vpc.get_all_subnets(subnet_ids=[vpc_subnet_id])[0].vpc_id + else: + vpc_id = None + + try: + # Here we try to lookup the group id from the security group name - if group is set. + if group_name: + if vpc_id: + grp_details = ec2.get_all_security_groups(filters={'vpc_id': vpc_id}) + else: + grp_details = ec2.get_all_security_groups() + if isinstance(group_name, string_types): + group_name = [group_name] + unmatched = set(group_name).difference(str(grp.name) for grp in grp_details) + if len(unmatched) > 0: + module.fail_json(msg="The following group names are not valid: %s" % ', '.join(unmatched)) + group_id = [str(grp.id) for grp in grp_details if str(grp.name) in group_name] + # Now we try to lookup the group id testing if group exists. + elif group_id: + # wrap the group_id in a list if it's not one already + if isinstance(group_id, string_types): + group_id = [group_id] + grp_details = ec2.get_all_security_groups(group_ids=group_id) + group_name = [grp_item.name for grp_item in grp_details] + except boto.exception.NoAuthHandlerFound as e: + module.fail_json(msg=str(e)) + + # Lookup any instances that much our run id. + + running_instances = [] + count_remaining = int(count) + + if id is not None: + filter_dict = {'client-token': id, 'instance-state-name': 'running'} + previous_reservations = ec2.get_all_instances(None, filter_dict) + for res in previous_reservations: + for prev_instance in res.instances: + running_instances.append(prev_instance) + count_remaining = count_remaining - len(running_instances) + + # Both min_count and max_count equal count parameter. This means the launch request is explicit (we want count, or fail) in how many instances we want. 
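+    # count_remaining was already reduced above by any instances that match this
+    # request's client token (the 'id' parameter), so only the shortfall is launched.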
+ + if count_remaining == 0: + changed = False + else: + changed = True + try: + params = {'image_id': image, + 'key_name': key_name, + 'monitoring_enabled': monitoring, + 'placement': zone, + 'instance_type': instance_type, + 'kernel_id': kernel, + 'ramdisk_id': ramdisk} + if user_data is not None: + params['user_data'] = to_bytes(user_data, errors='surrogate_or_strict') + + if ebs_optimized: + params['ebs_optimized'] = ebs_optimized + + # 'tenancy' always has a default value, but it is not a valid parameter for spot instance request + if not spot_price: + params['tenancy'] = tenancy + + if boto_supports_profile_name_arg(ec2): + params['instance_profile_name'] = instance_profile_name + else: + if instance_profile_name is not None: + module.fail_json( + msg="instance_profile_name parameter requires Boto version 2.5.0 or higher") + + if assign_public_ip is not None: + if not boto_supports_associate_public_ip_address(ec2): + module.fail_json( + msg="assign_public_ip parameter requires Boto version 2.13.0 or higher.") + elif not vpc_subnet_id: + module.fail_json( + msg="assign_public_ip only available with vpc_subnet_id") + + else: + if private_ip: + interface = boto.ec2.networkinterface.NetworkInterfaceSpecification( + subnet_id=vpc_subnet_id, + private_ip_address=private_ip, + groups=group_id, + associate_public_ip_address=assign_public_ip) + else: + interface = boto.ec2.networkinterface.NetworkInterfaceSpecification( + subnet_id=vpc_subnet_id, + groups=group_id, + associate_public_ip_address=assign_public_ip) + interfaces = boto.ec2.networkinterface.NetworkInterfaceCollection(interface) + params['network_interfaces'] = interfaces + else: + if network_interfaces: + if isinstance(network_interfaces, string_types): + network_interfaces = [network_interfaces] + interfaces = [] + for i, network_interface_id in enumerate(network_interfaces): + interface = boto.ec2.networkinterface.NetworkInterfaceSpecification( + network_interface_id=network_interface_id, + device_index=i) + interfaces.append(interface) + params['network_interfaces'] = \ + boto.ec2.networkinterface.NetworkInterfaceCollection(*interfaces) + else: + params['subnet_id'] = vpc_subnet_id + if vpc_subnet_id: + params['security_group_ids'] = group_id + else: + params['security_groups'] = group_name + + if volumes: + bdm = BlockDeviceMapping() + for volume in volumes: + if 'device_name' not in volume: + module.fail_json(msg='Device name must be set for volume') + # Minimum volume size is 1GiB. We'll use volume size explicitly set to 0 + # to be a signal not to create this volume + if 'volume_size' not in volume or int(volume['volume_size']) > 0: + bdm[volume['device_name']] = create_block_device(module, ec2, volume) + + params['block_device_map'] = bdm + + # check to see if we're using spot pricing first before starting instances + if not spot_price: + if assign_public_ip is not None and private_ip: + params.update( + dict( + min_count=count_remaining, + max_count=count_remaining, + client_token=id, + placement_group=placement_group, + ) + ) + else: + params.update( + dict( + min_count=count_remaining, + max_count=count_remaining, + client_token=id, + placement_group=placement_group, + private_ip_address=private_ip, + ) + ) + + # For ordinary (not spot) instances, we can select 'stop' + # (the default) or 'terminate' here. 
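            # Clarifying note: 'stop' is only valid for EBS-backed instances; the except
            # branch below retries with 'terminate' when EC2 rejects the combination
            # (InvalidParameterCombination), e.g. for instance-store-backed AMIs.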
+ params['instance_initiated_shutdown_behavior'] = instance_initiated_shutdown_behavior or 'stop' + + try: + res = ec2.run_instances(**params) + except boto.exception.EC2ResponseError as e: + if (params['instance_initiated_shutdown_behavior'] != 'terminate' and + "InvalidParameterCombination" == e.error_code): + params['instance_initiated_shutdown_behavior'] = 'terminate' + res = ec2.run_instances(**params) + else: + raise + + instids = [i.id for i in res.instances] + while True: + try: + ec2.get_all_instances(instids) + break + except boto.exception.EC2ResponseError as e: + if "<Code>InvalidInstanceID.NotFound</Code>" in str(e): + # there's a race between start and get an instance + continue + else: + module.fail_json(msg=str(e)) + + # The instances returned through ec2.run_instances above can be in + # terminated state due to idempotency. See commit 7f11c3d for a complete + # explanation. + terminated_instances = [ + str(instance.id) for instance in res.instances if instance.state == 'terminated' + ] + if terminated_instances: + module.fail_json(msg="Instances with id(s) %s " % terminated_instances + + "were created previously but have since been terminated - " + + "use a (possibly different) 'instanceid' parameter") + + else: + if private_ip: + module.fail_json( + msg='private_ip only available with on-demand (non-spot) instances') + if boto_supports_param_in_spot_request(ec2, 'placement_group'): + params['placement_group'] = placement_group + elif placement_group: + module.fail_json( + msg="placement_group parameter requires Boto version 2.3.0 or higher.") + + # You can't tell spot instances to 'stop'; they will always be + # 'terminate'd. For convenience, we'll ignore the latter value. + if instance_initiated_shutdown_behavior and instance_initiated_shutdown_behavior != 'terminate': + module.fail_json( + msg="instance_initiated_shutdown_behavior=stop is not supported for spot instances.") + + if spot_launch_group and isinstance(spot_launch_group, string_types): + params['launch_group'] = spot_launch_group + + params.update(dict( + count=count_remaining, + type=spot_type, + )) + + # Set spot ValidUntil + # ValidUntil -> (timestamp). The end date of the request, in + # UTC format (for example, YYYY -MM -DD T*HH* :MM :SS Z). + utc_valid_until = ( + datetime.datetime.utcnow() + + datetime.timedelta(seconds=spot_wait_timeout)) + params['valid_until'] = utc_valid_until.strftime('%Y-%m-%dT%H:%M:%S.000Z') + + res = ec2.request_spot_instances(spot_price, **params) + + # Now we have to do the intermediate waiting + if wait: + instids = await_spot_requests(module, ec2, res, count) + else: + instids = [] + except boto.exception.BotoServerError as e: + module.fail_json(msg="Instance creation failed => %s: %s" % (e.error_code, e.error_message)) + + # wait here until the instances are up + num_running = 0 + wait_timeout = time.time() + wait_timeout + res_list = () + while wait_timeout > time.time() and num_running < len(instids): + try: + res_list = ec2.get_all_instances(instids) + except boto.exception.BotoServerError as e: + if e.error_code == 'InvalidInstanceID.NotFound': + time.sleep(1) + continue + else: + raise + + num_running = 0 + for res in res_list: + num_running += len([i for i in res.instances if i.state == 'running']) + if len(res_list) <= 0: + # got a bad response of some sort, possibly due to + # stale/cached data. 
Wait a second and then try again + time.sleep(1) + continue + if wait and num_running < len(instids): + time.sleep(5) + else: + break + + if wait and wait_timeout <= time.time(): + # waiting took too long + module.fail_json(msg="wait for instances running timeout on %s" % time.asctime()) + + # We do this after the loop ends so that we end up with one list + for res in res_list: + running_instances.extend(res.instances) + + # Enabled by default by AWS + if source_dest_check is False: + for inst in res.instances: + inst.modify_attribute('sourceDestCheck', False) + + # Disabled by default by AWS + if termination_protection is True: + for inst in res.instances: + inst.modify_attribute('disableApiTermination', True) + + # Leave this as late as possible to try and avoid InvalidInstanceID.NotFound + if instance_tags and instids: + try: + ec2.create_tags(instids, instance_tags) + except boto.exception.EC2ResponseError as e: + module.fail_json(msg="Instance tagging failed => %s: %s" % (e.error_code, e.error_message)) + + instance_dict_array = [] + created_instance_ids = [] + for inst in running_instances: + inst.update() + d = get_instance_info(inst) + created_instance_ids.append(inst.id) + instance_dict_array.append(d) + + return (instance_dict_array, created_instance_ids, changed) + + +def terminate_instances(module, ec2, instance_ids): + """ + Terminates a list of instances + + module: Ansible module object + ec2: authenticated ec2 connection object + termination_list: a list of instances to terminate in the form of + [ {id: <inst-id>}, ..] + + Returns a dictionary of instance information + about the instances terminated. + + If the instance to be terminated is running + "changed" will be set to False. + + """ + + # Whether to wait for termination to complete before returning + wait = module.params.get('wait') + wait_timeout = int(module.params.get('wait_timeout')) + + changed = False + instance_dict_array = [] + + if not isinstance(instance_ids, list) or len(instance_ids) < 1: + module.fail_json(msg='instance_ids should be a list of instances, aborting') + + terminated_instance_ids = [] + for res in ec2.get_all_instances(instance_ids): + for inst in res.instances: + if inst.state == 'running' or inst.state == 'stopped': + terminated_instance_ids.append(inst.id) + instance_dict_array.append(get_instance_info(inst)) + try: + ec2.terminate_instances([inst.id]) + except EC2ResponseError as e: + module.fail_json(msg='Unable to terminate instance {0}, error: {1}'.format(inst.id, e)) + changed = True + + # wait here until the instances are 'terminated' + if wait: + num_terminated = 0 + wait_timeout = time.time() + wait_timeout + while wait_timeout > time.time() and num_terminated < len(terminated_instance_ids): + response = ec2.get_all_instances(instance_ids=terminated_instance_ids, + filters={'instance-state-name': 'terminated'}) + try: + num_terminated = sum([len(res.instances) for res in response]) + except Exception as e: + # got a bad response of some sort, possibly due to + # stale/cached data. 
Wait a second and then try again + time.sleep(1) + continue + + if num_terminated < len(terminated_instance_ids): + time.sleep(5) + + # waiting took too long + if wait_timeout < time.time() and num_terminated < len(terminated_instance_ids): + module.fail_json(msg="wait for instance termination timeout on %s" % time.asctime()) + # Lets get the current state of the instances after terminating - issue600 + instance_dict_array = [] + for res in ec2.get_all_instances(instance_ids=terminated_instance_ids, filters={'instance-state-name': 'terminated'}): + for inst in res.instances: + instance_dict_array.append(get_instance_info(inst)) + + return (changed, instance_dict_array, terminated_instance_ids) + + +def startstop_instances(module, ec2, instance_ids, state, instance_tags): + """ + Starts or stops a list of existing instances + + module: Ansible module object + ec2: authenticated ec2 connection object + instance_ids: The list of instances to start in the form of + [ {id: <inst-id>}, ..] + instance_tags: A dict of tag keys and values in the form of + {key: value, ... } + state: Intended state ("running" or "stopped") + + Returns a dictionary of instance information + about the instances started/stopped. + + If the instance was not able to change state, + "changed" will be set to False. + + Note that if instance_ids and instance_tags are both non-empty, + this method will process the intersection of the two + """ + + wait = module.params.get('wait') + wait_timeout = int(module.params.get('wait_timeout')) + group_id = module.params.get('group_id') + group_name = module.params.get('group') + changed = False + instance_dict_array = [] + + if not isinstance(instance_ids, list) or len(instance_ids) < 1: + # Fail unless the user defined instance tags + if not instance_tags: + module.fail_json(msg='instance_ids should be a list of instances, aborting') + + # To make an EC2 tag filter, we need to prepend 'tag:' to each key. 
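    # For example (values are illustrative only): instance_tags {'env': 'ci', 'Name': 'web'}
    # becomes filters {'tag:env': 'ci', 'tag:Name': 'web'}, the form accepted directly by
    # get_all_instances() below.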
+ # An empty filter does no filtering, so it's safe to pass it to the + # get_all_instances method even if the user did not specify instance_tags + filters = {} + if instance_tags: + for key, value in instance_tags.items(): + filters["tag:" + key] = value + + if module.params.get('id'): + filters['client-token'] = module.params['id'] + # Check that our instances are not in the state we want to take + + # Check (and eventually change) instances attributes and instances state + existing_instances_array = [] + for res in ec2.get_all_instances(instance_ids, filters=filters): + for inst in res.instances: + + warn_if_public_ip_assignment_changed(module, inst) + + changed = (check_source_dest_attr(module, inst, ec2) or + check_termination_protection(module, inst) or changed) + + # Check security groups and if we're using ec2-vpc; ec2-classic security groups may not be modified + if inst.vpc_id and group_name: + grp_details = ec2.get_all_security_groups(filters={'vpc_id': inst.vpc_id}) + if isinstance(group_name, string_types): + group_name = [group_name] + unmatched = set(group_name) - set(to_text(grp.name) for grp in grp_details) + if unmatched: + module.fail_json(msg="The following group names are not valid: %s" % ', '.join(unmatched)) + group_ids = [to_text(grp.id) for grp in grp_details if to_text(grp.name) in group_name] + elif inst.vpc_id and group_id: + if isinstance(group_id, string_types): + group_id = [group_id] + grp_details = ec2.get_all_security_groups(group_ids=group_id) + group_ids = [grp_item.id for grp_item in grp_details] + if inst.vpc_id and (group_name or group_id): + if set(sg.id for sg in inst.groups) != set(group_ids): + changed = inst.modify_attribute('groupSet', group_ids) + + # Check instance state + if inst.state != state: + instance_dict_array.append(get_instance_info(inst)) + try: + if state == 'running': + inst.start() + else: + inst.stop() + except EC2ResponseError as e: + module.fail_json(msg='Unable to change state for instance {0}, error: {1}'.format(inst.id, e)) + changed = True + existing_instances_array.append(inst.id) + + instance_ids = list(set(existing_instances_array + (instance_ids or []))) + # Wait for all the instances to finish starting or stopping + wait_timeout = time.time() + wait_timeout + while wait and wait_timeout > time.time(): + instance_dict_array = [] + matched_instances = [] + for res in ec2.get_all_instances(instance_ids): + for i in res.instances: + if i.state == state: + instance_dict_array.append(get_instance_info(i)) + matched_instances.append(i) + if len(matched_instances) < len(instance_ids): + time.sleep(5) + else: + break + + if wait and wait_timeout <= time.time(): + # waiting took too long + module.fail_json(msg="wait for instances running timeout on %s" % time.asctime()) + + return (changed, instance_dict_array, instance_ids) + + +def restart_instances(module, ec2, instance_ids, state, instance_tags): + """ + Restarts a list of existing instances + + module: Ansible module object + ec2: authenticated ec2 connection object + instance_ids: The list of instances to start in the form of + [ {id: <inst-id>}, ..] + instance_tags: A dict of tag keys and values in the form of + {key: value, ... } + state: Intended state ("restarted") + + Returns a dictionary of instance information + about the instances. + + If the instance was not able to change state, + "changed" will be set to False. + + Wait will not apply here as this is a OS level operation. 
+ + Note that if instance_ids and instance_tags are both non-empty, + this method will process the intersection of the two. + """ + + changed = False + instance_dict_array = [] + + if not isinstance(instance_ids, list) or len(instance_ids) < 1: + # Fail unless the user defined instance tags + if not instance_tags: + module.fail_json(msg='instance_ids should be a list of instances, aborting') + + # To make an EC2 tag filter, we need to prepend 'tag:' to each key. + # An empty filter does no filtering, so it's safe to pass it to the + # get_all_instances method even if the user did not specify instance_tags + filters = {} + if instance_tags: + for key, value in instance_tags.items(): + filters["tag:" + key] = value + if module.params.get('id'): + filters['client-token'] = module.params['id'] + + # Check that our instances are not in the state we want to take + + # Check (and eventually change) instances attributes and instances state + for res in ec2.get_all_instances(instance_ids, filters=filters): + for inst in res.instances: + + warn_if_public_ip_assignment_changed(module, inst) + + changed = (check_source_dest_attr(module, inst, ec2) or + check_termination_protection(module, inst) or changed) + + # Check instance state + if inst.state != state: + instance_dict_array.append(get_instance_info(inst)) + try: + inst.reboot() + except EC2ResponseError as e: + module.fail_json(msg='Unable to change state for instance {0}, error: {1}'.format(inst.id, e)) + changed = True + + return (changed, instance_dict_array, instance_ids) + + +def check_termination_protection(module, inst): + """ + Check the instance disableApiTermination attribute. + + module: Ansible module object + inst: EC2 instance object + + returns: True if state changed None otherwise + """ + + termination_protection = module.params.get('termination_protection') + + if (inst.get_attribute('disableApiTermination')['disableApiTermination'] != termination_protection and termination_protection is not None): + inst.modify_attribute('disableApiTermination', termination_protection) + return True + + +def check_source_dest_attr(module, inst, ec2): + """ + Check the instance sourceDestCheck attribute. + + module: Ansible module object + inst: EC2 instance object + + returns: True if state changed None otherwise + """ + + source_dest_check = module.params.get('source_dest_check') + + if source_dest_check is not None: + try: + if inst.vpc_id is not None and inst.get_attribute('sourceDestCheck')['sourceDestCheck'] != source_dest_check: + inst.modify_attribute('sourceDestCheck', source_dest_check) + return True + except boto.exception.EC2ResponseError as exc: + # instances with more than one Elastic Network Interface will + # fail, because they have the sourceDestCheck attribute defined + # per-interface + if exc.code == 'InvalidInstanceID': + for interface in inst.interfaces: + if interface.source_dest_check != source_dest_check: + ec2.modify_network_interface_attribute(interface.id, "sourceDestCheck", source_dest_check) + return True + else: + module.fail_json(msg='Failed to handle source_dest_check state for instance {0}, error: {1}'.format(inst.id, exc), + exception=traceback.format_exc()) + + +def warn_if_public_ip_assignment_changed(module, instance): + # This is a non-modifiable attribute. 
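    # The check below warns in exactly two cases: a public IP was requested
    # (assign_public_ip truthy) but the instance has no public DNS name, or
    # assign_public_ip is explicitly False while the instance does have one;
    # in other words, the requested setting differs from what was provisioned.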
+ assign_public_ip = module.params.get('assign_public_ip') + + # Check that public ip assignment is the same and warn if not + public_dns_name = getattr(instance, 'public_dns_name', None) + if (assign_public_ip or public_dns_name) and (not public_dns_name or assign_public_ip is False): + module.warn("Unable to modify public ip assignment to {0} for instance {1}. " + "Whether or not to assign a public IP is determined during instance creation.".format(assign_public_ip, instance.id)) + + +def main(): + argument_spec = ec2_argument_spec() + argument_spec.update( + dict( + key_name=dict(aliases=['keypair']), + id=dict(), + group=dict(type='list', aliases=['groups']), + group_id=dict(type='list'), + zone=dict(aliases=['aws_zone', 'ec2_zone']), + instance_type=dict(aliases=['type']), + spot_price=dict(), + spot_type=dict(default='one-time', choices=["one-time", "persistent"]), + spot_launch_group=dict(), + image=dict(), + kernel=dict(), + count=dict(type='int', default='1'), + monitoring=dict(type='bool', default=False), + ramdisk=dict(), + wait=dict(type='bool', default=False), + wait_timeout=dict(type='int', default=300), + spot_wait_timeout=dict(type='int', default=600), + placement_group=dict(), + user_data=dict(), + instance_tags=dict(type='dict'), + vpc_subnet_id=dict(), + assign_public_ip=dict(type='bool'), + private_ip=dict(), + instance_profile_name=dict(), + instance_ids=dict(type='list', aliases=['instance_id']), + source_dest_check=dict(type='bool', default=None), + termination_protection=dict(type='bool', default=None), + state=dict(default='present', choices=['present', 'absent', 'running', 'restarted', 'stopped']), + instance_initiated_shutdown_behavior=dict(default='stop', choices=['stop', 'terminate']), + exact_count=dict(type='int', default=None), + count_tag=dict(type='raw'), + volumes=dict(type='list'), + ebs_optimized=dict(type='bool', default=False), + tenancy=dict(default='default', choices=['default', 'dedicated']), + network_interfaces=dict(type='list', aliases=['network_interface']) + ) + ) + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=[ + # Can be uncommented when we finish the deprecation cycle. + # ['group', 'group_id'], + ['exact_count', 'count'], + ['exact_count', 'state'], + ['exact_count', 'instance_ids'], + ['network_interfaces', 'assign_public_ip'], + ['network_interfaces', 'group'], + ['network_interfaces', 'group_id'], + ['network_interfaces', 'private_ip'], + ['network_interfaces', 'vpc_subnet_id'], + ], + ) + + if module.params.get('group') and module.params.get('group_id'): + module.deprecate( + msg='Support for passing both group and group_id has been deprecated. 
' + 'Currently group_id is ignored, in future passing both will result in an error', + version='2.14', collection_name='ansible.builtin') + + if not HAS_BOTO: + module.fail_json(msg='boto required for this module') + + try: + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module) + if module.params.get('region') or not module.params.get('ec2_url'): + ec2 = ec2_connect(module) + elif module.params.get('ec2_url'): + ec2 = connect_ec2_endpoint(ec2_url, **aws_connect_kwargs) + + if 'region' not in aws_connect_kwargs: + aws_connect_kwargs['region'] = ec2.region + + vpc = connect_vpc(**aws_connect_kwargs) + except boto.exception.NoAuthHandlerFound as e: + module.fail_json(msg="Failed to get connection: %s" % e.message, exception=traceback.format_exc()) + + tagged_instances = [] + + state = module.params['state'] + + if state == 'absent': + instance_ids = module.params['instance_ids'] + if not instance_ids: + module.fail_json(msg='instance_ids list is required for absent state') + + (changed, instance_dict_array, new_instance_ids) = terminate_instances(module, ec2, instance_ids) + + elif state in ('running', 'stopped'): + instance_ids = module.params.get('instance_ids') + instance_tags = module.params.get('instance_tags') + if not (isinstance(instance_ids, list) or isinstance(instance_tags, dict)): + module.fail_json(msg='running list needs to be a list of instances or set of tags to run: %s' % instance_ids) + + (changed, instance_dict_array, new_instance_ids) = startstop_instances(module, ec2, instance_ids, state, instance_tags) + + elif state in ('restarted'): + instance_ids = module.params.get('instance_ids') + instance_tags = module.params.get('instance_tags') + if not (isinstance(instance_ids, list) or isinstance(instance_tags, dict)): + module.fail_json(msg='running list needs to be a list of instances or set of tags to run: %s' % instance_ids) + + (changed, instance_dict_array, new_instance_ids) = restart_instances(module, ec2, instance_ids, state, instance_tags) + + elif state == 'present': + # Changed is always set to true when provisioning new instances + if not module.params.get('image'): + module.fail_json(msg='image parameter is required for new instance') + + if module.params.get('exact_count') is None: + (instance_dict_array, new_instance_ids, changed) = create_instances(module, ec2, vpc) + else: + (tagged_instances, instance_dict_array, new_instance_ids, changed) = enforce_count(module, ec2, vpc) + + # Always return instances in the same order + if new_instance_ids: + new_instance_ids.sort() + if instance_dict_array: + instance_dict_array.sort(key=lambda x: x['id']) + if tagged_instances: + tagged_instances.sort(key=lambda x: x['id']) + + module.exit_json(changed=changed, instance_ids=new_instance_ids, instances=instance_dict_array, tagged_instances=tagged_instances) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_ami_info.py b/test/support/integration/plugins/modules/ec2_ami_info.py new file mode 100644 index 00000000..53c2374d --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_ami_info.py @@ -0,0 +1,282 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: ec2_ami_info 
+version_added: '2.5' +short_description: Gather information about ec2 AMIs +description: + - Gather information about ec2 AMIs + - This module was called C(ec2_ami_facts) before Ansible 2.9. The usage did not change. +author: + - Prasad Katti (@prasadkatti) +requirements: [ boto3 ] +options: + image_ids: + description: One or more image IDs. + aliases: [image_id] + type: list + elements: str + filters: + description: + - A dict of filters to apply. Each dict item consists of a filter key and a filter value. + - See U(https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeImages.html) for possible filters. + - Filter names and values are case sensitive. + type: dict + owners: + description: + - Filter the images by the owner. Valid options are an AWS account ID, self, + or an AWS owner alias ( amazon | aws-marketplace | microsoft ). + aliases: [owner] + type: list + elements: str + executable_users: + description: + - Filter images by users with explicit launch permissions. Valid options are an AWS account ID, self, or all (public AMIs). + aliases: [executable_user] + type: list + elements: str + describe_image_attributes: + description: + - Describe attributes (like launchPermission) of the images found. + default: no + type: bool + +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +- name: gather information about an AMI using ami-id + ec2_ami_info: + image_ids: ami-5b488823 + +- name: gather information about all AMIs with tag key Name and value webapp + ec2_ami_info: + filters: + "tag:Name": webapp + +- name: gather information about an AMI with 'AMI Name' equal to foobar + ec2_ami_info: + filters: + name: foobar + +- name: gather information about Ubuntu 17.04 AMIs published by Canonical (099720109477) + ec2_ami_info: + owners: 099720109477 + filters: + name: "ubuntu/images/ubuntu-zesty-17.04-*" +''' + +RETURN = ''' +images: + description: A list of images. + returned: always + type: list + elements: dict + contains: + architecture: + description: The architecture of the image. + returned: always + type: str + sample: x86_64 + block_device_mappings: + description: Any block device mapping entries. + returned: always + type: list + elements: dict + contains: + device_name: + description: The device name exposed to the instance. + returned: always + type: str + sample: /dev/sda1 + ebs: + description: EBS volumes + returned: always + type: complex + creation_date: + description: The date and time the image was created. + returned: always + type: str + sample: '2017-10-16T19:22:13.000Z' + description: + description: The description of the AMI. + returned: always + type: str + sample: '' + ena_support: + description: Whether enhanced networking with ENA is enabled. + returned: always + type: bool + sample: true + hypervisor: + description: The hypervisor type of the image. + returned: always + type: str + sample: xen + image_id: + description: The ID of the AMI. + returned: always + type: str + sample: ami-5b466623 + image_location: + description: The location of the AMI. + returned: always + type: str + sample: 408466080000/Webapp + image_type: + description: The type of image. + returned: always + type: str + sample: machine + launch_permissions: + description: A List of AWS accounts may launch the AMI. + returned: When image is owned by calling account and I(describe_image_attributes) is yes. 
+ type: list + elements: dict + contains: + group: + description: A value of 'all' means the AMI is public. + type: str + user_id: + description: An AWS account ID with permissions to launch the AMI. + type: str + sample: [{"group": "all"}, {"user_id": "408466080000"}] + name: + description: The name of the AMI that was provided during image creation. + returned: always + type: str + sample: Webapp + owner_id: + description: The AWS account ID of the image owner. + returned: always + type: str + sample: '408466080000' + public: + description: Whether the image has public launch permissions. + returned: always + type: bool + sample: true + root_device_name: + description: The device name of the root device. + returned: always + type: str + sample: /dev/sda1 + root_device_type: + description: The type of root device used by the AMI. + returned: always + type: str + sample: ebs + sriov_net_support: + description: Whether enhanced networking is enabled. + returned: always + type: str + sample: simple + state: + description: The current state of the AMI. + returned: always + type: str + sample: available + tags: + description: Any tags assigned to the image. + returned: always + type: dict + virtualization_type: + description: The type of virtualization of the AMI. + returned: always + type: str + sample: hvm +''' + +try: + from botocore.exceptions import ClientError, BotoCoreError +except ImportError: + pass # caught by AnsibleAWSModule + +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import ansible_dict_to_boto3_filter_list, camel_dict_to_snake_dict, boto3_tag_list_to_ansible_dict + + +def list_ec2_images(ec2_client, module): + + image_ids = module.params.get("image_ids") + owners = module.params.get("owners") + executable_users = module.params.get("executable_users") + filters = module.params.get("filters") + owner_param = [] + + # describe_images is *very* slow if you pass the `Owners` + # param (unless it's self), for some reason. + # Converting the owners to filters and removing from the + # owners param greatly speeds things up. 
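    # For example (values are illustrative only): owners ['099720109477', 'self', 'amazon']
    # ends up as filters {'owner-id': ['099720109477'], 'owner-alias': ['amazon']} with only
    # 'self' left in the Owners request parameter.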
+ # Implementation based on aioue's suggestion in #24886 + for owner in owners: + if owner.isdigit(): + if 'owner-id' not in filters: + filters['owner-id'] = list() + filters['owner-id'].append(owner) + elif owner == 'self': + # self not a valid owner-alias filter (https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeImages.html) + owner_param.append(owner) + else: + if 'owner-alias' not in filters: + filters['owner-alias'] = list() + filters['owner-alias'].append(owner) + + filters = ansible_dict_to_boto3_filter_list(filters) + + try: + images = ec2_client.describe_images(ImageIds=image_ids, Filters=filters, Owners=owner_param, ExecutableUsers=executable_users) + images = [camel_dict_to_snake_dict(image) for image in images["Images"]] + except (ClientError, BotoCoreError) as err: + module.fail_json_aws(err, msg="error describing images") + for image in images: + try: + image['tags'] = boto3_tag_list_to_ansible_dict(image.get('tags', [])) + if module.params.get("describe_image_attributes"): + launch_permissions = ec2_client.describe_image_attribute(Attribute='launchPermission', ImageId=image['image_id'])['LaunchPermissions'] + image['launch_permissions'] = [camel_dict_to_snake_dict(perm) for perm in launch_permissions] + except (ClientError, BotoCoreError) as err: + # describing launch permissions of images owned by others is not permitted, but shouldn't cause failures + pass + + images.sort(key=lambda e: e.get('creation_date', '')) # it may be possible that creation_date does not always exist + module.exit_json(images=images) + + +def main(): + + argument_spec = dict( + image_ids=dict(default=[], type='list', aliases=['image_id']), + filters=dict(default={}, type='dict'), + owners=dict(default=[], type='list', aliases=['owner']), + executable_users=dict(default=[], type='list', aliases=['executable_user']), + describe_image_attributes=dict(default=False, type='bool') + ) + + module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True) + if module._module._name == 'ec2_ami_facts': + module._module.deprecate("The 'ec2_ami_facts' module has been renamed to 'ec2_ami_info'", + version='2.13', collection_name='ansible.builtin') + + ec2_client = module.client('ec2') + + list_ec2_images(ec2_client, module) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_group.py b/test/support/integration/plugins/modules/ec2_group.py new file mode 100644 index 00000000..bc416f66 --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_group.py @@ -0,0 +1,1345 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + +DOCUMENTATION = ''' +--- +module: ec2_group +author: "Andrew de Quincey (@adq)" +version_added: "1.3" +requirements: [ boto3 ] +short_description: maintain an ec2 VPC security group. +description: + - Maintains ec2 security groups. This module has a dependency on python-boto >= 2.5. +options: + name: + description: + - Name of the security group. + - One of and only one of I(name) or I(group_id) is required. + - Required if I(state=present). + required: false + type: str + group_id: + description: + - Id of group to delete (works only with absent). 
+ - One of and only one of I(name) or I(group_id) is required. + required: false + version_added: "2.4" + type: str + description: + description: + - Description of the security group. Required when C(state) is C(present). + required: false + type: str + vpc_id: + description: + - ID of the VPC to create the group in. + required: false + type: str + rules: + description: + - List of firewall inbound rules to enforce in this group (see example). If none are supplied, + no inbound rules will be enabled. Rules list may include its own name in `group_name`. + This allows idempotent loopback additions (e.g. allow group to access itself). + Rule sources list support was added in version 2.4. This allows to define multiple sources per + source type as well as multiple source types per rule. Prior to 2.4 an individual source is allowed. + In version 2.5 support for rule descriptions was added. + required: false + type: list + elements: dict + suboptions: + cidr_ip: + type: str + description: + - The IPv4 CIDR range traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + cidr_ipv6: + type: str + description: + - The IPv6 CIDR range traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + ip_prefix: + type: str + description: + - The IP Prefix U(https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-prefix-lists.html) + that traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_id: + type: str + description: + - The ID of the Security Group that traffic is coming from. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_name: + type: str + description: + - Name of the Security Group that traffic is coming from. + - If the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_desc: + type: str + description: + - If the I(group_name) is set and the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + proto: + type: str + description: + - The IP protocol name (C(tcp), C(udp), C(icmp), C(icmpv6)) or number (U(https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers)) + from_port: + type: int + description: The start of the range of ports that traffic is coming from. A value of C(-1) indicates all ports. + to_port: + type: int + description: The end of the range of ports that traffic is coming from. A value of C(-1) indicates all ports. + rule_desc: + type: str + description: A description for the rule. + rules_egress: + description: + - List of firewall outbound rules to enforce in this group (see example). If none are supplied, + a default all-out rule is assumed. If an empty list is supplied, no outbound rules will be enabled. + Rule Egress sources list support was added in version 2.4. In version 2.5 support for rule descriptions + was added. + required: false + version_added: "1.6" + type: list + elements: dict + suboptions: + cidr_ip: + type: str + description: + - The IPv4 CIDR range traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). 
+ cidr_ipv6: + type: str + description: + - The IPv6 CIDR range traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + ip_prefix: + type: str + description: + - The IP Prefix U(https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-prefix-lists.html) + that traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_id: + type: str + description: + - The ID of the Security Group that traffic is going to. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_name: + type: str + description: + - Name of the Security Group that traffic is going to. + - If the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + - You can specify only one of I(cidr_ip), I(cidr_ipv6), I(ip_prefix), I(group_id) + and I(group_name). + group_desc: + type: str + description: + - If the I(group_name) is set and the Security Group doesn't exist a new Security Group will be + created with I(group_desc) as the description. + proto: + type: str + description: + - The IP protocol name (C(tcp), C(udp), C(icmp), C(icmpv6)) or number (U(https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers)) + from_port: + type: int + description: The start of the range of ports that traffic is going to. A value of C(-1) indicates all ports. + to_port: + type: int + description: The end of the range of ports that traffic is going to. A value of C(-1) indicates all ports. + rule_desc: + type: str + description: A description for the rule. + state: + version_added: "1.4" + description: + - Create or delete a security group. + required: false + default: 'present' + choices: [ "present", "absent" ] + aliases: [] + type: str + purge_rules: + version_added: "1.8" + description: + - Purge existing rules on security group that are not found in rules. + required: false + default: 'true' + aliases: [] + type: bool + purge_rules_egress: + version_added: "1.8" + description: + - Purge existing rules_egress on security group that are not found in rules_egress. + required: false + default: 'true' + aliases: [] + type: bool + tags: + version_added: "2.4" + description: + - A dictionary of one or more tags to assign to the security group. + required: false + type: dict + aliases: ['resource_tags'] + purge_tags: + version_added: "2.4" + description: + - If yes, existing tags will be purged from the resource to match exactly what is defined by I(tags) parameter. If the I(tags) parameter is not set then + tags will not be modified. + required: false + default: yes + type: bool + +extends_documentation_fragment: + - aws + - ec2 + +notes: + - If a rule declares a group_name and that group doesn't exist, it will be + automatically created. In that case, group_desc should be provided as well. + The module will refuse to create a depended-on group without a description. + - Preview diff mode support is added in version 2.7. 
+''' + +EXAMPLES = ''' +- name: example using security group rule descriptions + ec2_group: + name: "{{ name }}" + description: sg with rule descriptions + vpc_id: vpc-xxxxxxxx + profile: "{{ aws_profile }}" + region: us-east-1 + rules: + - proto: tcp + ports: + - 80 + cidr_ip: 0.0.0.0/0 + rule_desc: allow all on port 80 + +- name: example ec2 group + ec2_group: + name: example + description: an example EC2 group + vpc_id: 12345 + region: eu-west-1 + aws_secret_key: SECRET + aws_access_key: ACCESS + rules: + - proto: tcp + from_port: 80 + to_port: 80 + cidr_ip: 0.0.0.0/0 + - proto: tcp + from_port: 22 + to_port: 22 + cidr_ip: 10.0.0.0/8 + - proto: tcp + from_port: 443 + to_port: 443 + # this should only be needed for EC2 Classic security group rules + # because in a VPC an ELB will use a user-account security group + group_id: amazon-elb/sg-87654321/amazon-elb-sg + - proto: tcp + from_port: 3306 + to_port: 3306 + group_id: 123412341234/sg-87654321/exact-name-of-sg + - proto: udp + from_port: 10050 + to_port: 10050 + cidr_ip: 10.0.0.0/8 + - proto: udp + from_port: 10051 + to_port: 10051 + group_id: sg-12345678 + - proto: icmp + from_port: 8 # icmp type, -1 = any type + to_port: -1 # icmp subtype, -1 = any subtype + cidr_ip: 10.0.0.0/8 + - proto: all + # the containing group name may be specified here + group_name: example + - proto: all + # in the 'proto' attribute, if you specify -1, all, or a protocol number other than tcp, udp, icmp, or 58 (ICMPv6), + # traffic on all ports is allowed, regardless of any ports you specify + from_port: 10050 # this value is ignored + to_port: 10050 # this value is ignored + cidr_ip: 10.0.0.0/8 + + rules_egress: + - proto: tcp + from_port: 80 + to_port: 80 + cidr_ip: 0.0.0.0/0 + cidr_ipv6: 64:ff9b::/96 + group_name: example-other + # description to use if example-other needs to be created + group_desc: other example EC2 group + +- name: example2 ec2 group + ec2_group: + name: example2 + description: an example2 EC2 group + vpc_id: 12345 + region: eu-west-1 + rules: + # 'ports' rule keyword was introduced in version 2.4. It accepts a single port value or a list of values including ranges (from_port-to_port). + - proto: tcp + ports: 22 + group_name: example-vpn + - proto: tcp + ports: + - 80 + - 443 + - 8080-8099 + cidr_ip: 0.0.0.0/0 + # Rule sources list support was added in version 2.4. This allows to define multiple sources per source type as well as multiple source types per rule. 
+ - proto: tcp + ports: + - 6379 + - 26379 + group_name: + - example-vpn + - example-redis + - proto: tcp + ports: 5665 + group_name: example-vpn + cidr_ip: + - 172.16.1.0/24 + - 172.16.17.0/24 + cidr_ipv6: + - 2607:F8B0::/32 + - 64:ff9b::/96 + group_id: + - sg-edcd9784 + diff: True + +- name: "Delete group by its id" + ec2_group: + region: eu-west-1 + group_id: sg-33b4ee5b + state: absent +''' + +RETURN = ''' +group_name: + description: Security group name + sample: My Security Group + type: str + returned: on create/update +group_id: + description: Security group id + sample: sg-abcd1234 + type: str + returned: on create/update +description: + description: Description of security group + sample: My Security Group + type: str + returned: on create/update +tags: + description: Tags associated with the security group + sample: + Name: My Security Group + Purpose: protecting stuff + type: dict + returned: on create/update +vpc_id: + description: ID of VPC to which the security group belongs + sample: vpc-abcd1234 + type: str + returned: on create/update +ip_permissions: + description: Inbound rules associated with the security group. + sample: + - from_port: 8182 + ip_protocol: tcp + ip_ranges: + - cidr_ip: "1.1.1.1/32" + ipv6_ranges: [] + prefix_list_ids: [] + to_port: 8182 + user_id_group_pairs: [] + type: list + returned: on create/update +ip_permissions_egress: + description: Outbound rules associated with the security group. + sample: + - ip_protocol: -1 + ip_ranges: + - cidr_ip: "0.0.0.0/0" + ipv6_ranges: [] + prefix_list_ids: [] + user_id_group_pairs: [] + type: list + returned: on create/update +owner_id: + description: AWS Account ID of the security group + sample: 123456789012 + type: int + returned: on create/update +''' + +import json +import re +import itertools +from copy import deepcopy +from time import sleep +from collections import namedtuple +from ansible.module_utils.aws.core import AnsibleAWSModule, is_boto3_error_code +from ansible.module_utils.aws.iam import get_aws_account_id +from ansible.module_utils.aws.waiters import get_waiter +from ansible.module_utils.ec2 import AWSRetry, camel_dict_to_snake_dict, compare_aws_tags +from ansible.module_utils.ec2 import ansible_dict_to_boto3_filter_list, boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list +from ansible.module_utils.common.network import to_ipv6_subnet, to_subnet +from ansible.module_utils.compat.ipaddress import ip_network, IPv6Network +from ansible.module_utils._text import to_text +from ansible.module_utils.six import string_types + +try: + from botocore.exceptions import BotoCoreError, ClientError +except ImportError: + pass # caught by AnsibleAWSModule + + +Rule = namedtuple('Rule', ['port_range', 'protocol', 'target', 'target_type', 'description']) +valid_targets = set(['ipv4', 'ipv6', 'group', 'ip_prefix']) +current_account_id = None + + +def rule_cmp(a, b): + """Compare rules without descriptions""" + for prop in ['port_range', 'protocol', 'target', 'target_type']: + if prop == 'port_range' and to_text(a.protocol) == to_text(b.protocol): + # equal protocols can interchange `(-1, -1)` and `(None, None)` + if a.port_range in ((None, None), (-1, -1)) and b.port_range in ((None, None), (-1, -1)): + continue + elif getattr(a, prop) != getattr(b, prop): + return False + elif getattr(a, prop) != getattr(b, prop): + return False + return True + + +def rules_to_permissions(rules): + return [to_permission(rule) for rule in rules] + + +def to_permission(rule): + # take a Rule, output the serialized grant + 
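    # For example (values are illustrative only):
    #   Rule((80, 80), 'tcp', '10.0.0.0/8', 'ipv4', 'web')
    # serializes to
    #   {'IpProtocol': 'tcp', 'FromPort': 80, 'ToPort': 80,
    #    'IpRanges': [{'CidrIp': '10.0.0.0/8', 'Description': 'web'}]}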
perm = { + 'IpProtocol': rule.protocol, + } + perm['FromPort'], perm['ToPort'] = rule.port_range + if rule.target_type == 'ipv4': + perm['IpRanges'] = [{ + 'CidrIp': rule.target, + }] + if rule.description: + perm['IpRanges'][0]['Description'] = rule.description + elif rule.target_type == 'ipv6': + perm['Ipv6Ranges'] = [{ + 'CidrIpv6': rule.target, + }] + if rule.description: + perm['Ipv6Ranges'][0]['Description'] = rule.description + elif rule.target_type == 'group': + if isinstance(rule.target, tuple): + pair = {} + if rule.target[0]: + pair['UserId'] = rule.target[0] + # group_id/group_name are mutually exclusive - give group_id more precedence as it is more specific + if rule.target[1]: + pair['GroupId'] = rule.target[1] + elif rule.target[2]: + pair['GroupName'] = rule.target[2] + perm['UserIdGroupPairs'] = [pair] + else: + perm['UserIdGroupPairs'] = [{ + 'GroupId': rule.target + }] + if rule.description: + perm['UserIdGroupPairs'][0]['Description'] = rule.description + elif rule.target_type == 'ip_prefix': + perm['PrefixListIds'] = [{ + 'PrefixListId': rule.target, + }] + if rule.description: + perm['PrefixListIds'][0]['Description'] = rule.description + elif rule.target_type not in valid_targets: + raise ValueError('Invalid target type for rule {0}'.format(rule)) + return fix_port_and_protocol(perm) + + +def rule_from_group_permission(perm): + def ports_from_permission(p): + if 'FromPort' not in p and 'ToPort' not in p: + return (None, None) + return (int(perm['FromPort']), int(perm['ToPort'])) + + # outputs a rule tuple + for target_key, target_subkey, target_type in [ + ('IpRanges', 'CidrIp', 'ipv4'), + ('Ipv6Ranges', 'CidrIpv6', 'ipv6'), + ('PrefixListIds', 'PrefixListId', 'ip_prefix'), + ]: + if target_key not in perm: + continue + for r in perm[target_key]: + # there may be several IP ranges here, which is ok + yield Rule( + ports_from_permission(perm), + to_text(perm['IpProtocol']), + r[target_subkey], + target_type, + r.get('Description') + ) + if 'UserIdGroupPairs' in perm and perm['UserIdGroupPairs']: + for pair in perm['UserIdGroupPairs']: + target = ( + pair.get('UserId', None), + pair.get('GroupId', None), + pair.get('GroupName', None), + ) + if pair.get('UserId', '').startswith('amazon-'): + # amazon-elb and amazon-prefix rules don't need + # group-id specified, so remove it when querying + # from permission + target = ( + target[0], + None, + target[2], + ) + elif 'VpcPeeringConnectionId' in pair or pair['UserId'] != current_account_id: + target = ( + pair.get('UserId', None), + pair.get('GroupId', None), + pair.get('GroupName', None), + ) + + yield Rule( + ports_from_permission(perm), + to_text(perm['IpProtocol']), + target, + 'group', + pair.get('Description') + ) + + +@AWSRetry.backoff(tries=5, delay=5, backoff=2.0, catch_extra_error_codes=['InvalidGroup.NotFound']) +def get_security_groups_with_backoff(connection, **kwargs): + return connection.describe_security_groups(**kwargs) + + +@AWSRetry.backoff(tries=5, delay=5, backoff=2.0) +def sg_exists_with_backoff(connection, **kwargs): + try: + return connection.describe_security_groups(**kwargs) + except is_boto3_error_code('InvalidGroup.NotFound'): + return {'SecurityGroups': []} + + +def deduplicate_rules_args(rules): + """Returns unique rules""" + if rules is None: + return None + return list(dict(zip((json.dumps(r, sort_keys=True) for r in rules), rules)).values()) + + +def validate_rule(module, rule): + VALID_PARAMS = ('cidr_ip', 'cidr_ipv6', 'ip_prefix', + 'group_id', 'group_name', 'group_desc', + 'proto', 
'from_port', 'to_port', 'rule_desc') + if not isinstance(rule, dict): + module.fail_json(msg='Invalid rule parameter type [%s].' % type(rule)) + for k in rule: + if k not in VALID_PARAMS: + module.fail_json(msg='Invalid rule parameter \'{0}\' for rule: {1}'.format(k, rule)) + + if 'group_id' in rule and 'cidr_ip' in rule: + module.fail_json(msg='Specify group_id OR cidr_ip, not both') + elif 'group_name' in rule and 'cidr_ip' in rule: + module.fail_json(msg='Specify group_name OR cidr_ip, not both') + elif 'group_id' in rule and 'cidr_ipv6' in rule: + module.fail_json(msg="Specify group_id OR cidr_ipv6, not both") + elif 'group_name' in rule and 'cidr_ipv6' in rule: + module.fail_json(msg="Specify group_name OR cidr_ipv6, not both") + elif 'cidr_ip' in rule and 'cidr_ipv6' in rule: + module.fail_json(msg="Specify cidr_ip OR cidr_ipv6, not both") + elif 'group_id' in rule and 'group_name' in rule: + module.fail_json(msg='Specify group_id OR group_name, not both') + + +def get_target_from_rule(module, client, rule, name, group, groups, vpc_id): + """ + Returns tuple of (target_type, target, group_created) after validating rule params. + + rule: Dict describing a rule. + name: Name of the security group being managed. + groups: Dict of all available security groups. + + AWS accepts an ip range or a security group as target of a rule. This + function validate the rule specification and return either a non-None + group_id or a non-None ip range. + """ + FOREIGN_SECURITY_GROUP_REGEX = r'^([^/]+)/?(sg-\S+)?/(\S+)' + group_id = None + group_name = None + target_group_created = False + + validate_rule(module, rule) + if rule.get('group_id') and re.match(FOREIGN_SECURITY_GROUP_REGEX, rule['group_id']): + # this is a foreign Security Group. Since you can't fetch it you must create an instance of it + owner_id, group_id, group_name = re.match(FOREIGN_SECURITY_GROUP_REGEX, rule['group_id']).groups() + group_instance = dict(UserId=owner_id, GroupId=group_id, GroupName=group_name) + groups[group_id] = group_instance + groups[group_name] = group_instance + # group_id/group_name are mutually exclusive - give group_id more precedence as it is more specific + if group_id and group_name: + group_name = None + return 'group', (owner_id, group_id, group_name), False + elif 'group_id' in rule: + return 'group', rule['group_id'], False + elif 'group_name' in rule: + group_name = rule['group_name'] + if group_name == name: + group_id = group['GroupId'] + groups[group_id] = group + groups[group_name] = group + elif group_name in groups and group.get('VpcId') and groups[group_name].get('VpcId'): + # both are VPC groups, this is ok + group_id = groups[group_name]['GroupId'] + elif group_name in groups and not (group.get('VpcId') or groups[group_name].get('VpcId')): + # both are EC2 classic, this is ok + group_id = groups[group_name]['GroupId'] + else: + auto_group = None + filters = {'group-name': group_name} + if vpc_id: + filters['vpc-id'] = vpc_id + # if we got here, either the target group does not exist, or there + # is a mix of EC2 classic + VPC groups. 
Mixing of EC2 classic + VPC + # is bad, so we have to create a new SG because no compatible group + # exists + if not rule.get('group_desc', '').strip(): + # retry describing the group once + try: + auto_group = get_security_groups_with_backoff(client, Filters=ansible_dict_to_boto3_filter_list(filters)).get('SecurityGroups', [])[0] + except (is_boto3_error_code('InvalidGroup.NotFound'), IndexError): + module.fail_json(msg="group %s will be automatically created by rule %s but " + "no description was provided" % (group_name, rule)) + except ClientError as e: # pylint: disable=duplicate-except + module.fail_json_aws(e) + elif not module.check_mode: + params = dict(GroupName=group_name, Description=rule['group_desc']) + if vpc_id: + params['VpcId'] = vpc_id + try: + auto_group = client.create_security_group(**params) + get_waiter( + client, 'security_group_exists', + ).wait( + GroupIds=[auto_group['GroupId']], + ) + except is_boto3_error_code('InvalidGroup.Duplicate'): + # The group exists, but didn't show up in any of our describe-security-groups calls + # Try searching on a filter for the name, and allow a retry window for AWS to update + # the model on their end. + try: + auto_group = get_security_groups_with_backoff(client, Filters=ansible_dict_to_boto3_filter_list(filters)).get('SecurityGroups', [])[0] + except IndexError as e: + module.fail_json(msg="Could not create or use existing group '{0}' in rule. Make sure the group exists".format(group_name)) + except ClientError as e: + module.fail_json_aws( + e, + msg="Could not create or use existing group '{0}' in rule. Make sure the group exists".format(group_name)) + if auto_group is not None: + group_id = auto_group['GroupId'] + groups[group_id] = auto_group + groups[group_name] = auto_group + target_group_created = True + return 'group', group_id, target_group_created + elif 'cidr_ip' in rule: + return 'ipv4', validate_ip(module, rule['cidr_ip']), False + elif 'cidr_ipv6' in rule: + return 'ipv6', validate_ip(module, rule['cidr_ipv6']), False + elif 'ip_prefix' in rule: + return 'ip_prefix', rule['ip_prefix'], False + + module.fail_json(msg="Could not match target for rule {0}".format(rule), failed_rule=rule) + + +def ports_expand(ports): + # takes a list of ports and returns a list of (port_from, port_to) + ports_expanded = [] + for port in ports: + if not isinstance(port, string_types): + ports_expanded.append((port,) * 2) + elif '-' in port: + ports_expanded.append(tuple(int(p.strip()) for p in port.split('-', 1))) + else: + ports_expanded.append((int(port.strip()),) * 2) + + return ports_expanded + + +def rule_expand_ports(rule): + # takes a rule dict and returns a list of expanded rule dicts + if 'ports' not in rule: + if isinstance(rule.get('from_port'), string_types): + rule['from_port'] = int(rule.get('from_port')) + if isinstance(rule.get('to_port'), string_types): + rule['to_port'] = int(rule.get('to_port')) + return [rule] + + ports = rule['ports'] if isinstance(rule['ports'], list) else [rule['ports']] + + rule_expanded = [] + for from_to in ports_expand(ports): + temp_rule = rule.copy() + del temp_rule['ports'] + temp_rule['from_port'], temp_rule['to_port'] = sorted(from_to) + rule_expanded.append(temp_rule) + + return rule_expanded + + +def rules_expand_ports(rules): + # takes a list of rules and expands it based on 'ports' + if not rules: + return rules + + return [rule for rule_complex in rules + for rule in rule_expand_ports(rule_complex)] + + +def rule_expand_source(rule, source_type): + # takes a rule dict and returns 
a list of expanded rule dicts for specified source_type + sources = rule[source_type] if isinstance(rule[source_type], list) else [rule[source_type]] + source_types_all = ('cidr_ip', 'cidr_ipv6', 'group_id', 'group_name', 'ip_prefix') + + rule_expanded = [] + for source in sources: + temp_rule = rule.copy() + for s in source_types_all: + temp_rule.pop(s, None) + temp_rule[source_type] = source + rule_expanded.append(temp_rule) + + return rule_expanded + + +def rule_expand_sources(rule): + # takes a rule dict and returns a list of expanded rule discts + source_types = (stype for stype in ('cidr_ip', 'cidr_ipv6', 'group_id', 'group_name', 'ip_prefix') if stype in rule) + + return [r for stype in source_types + for r in rule_expand_source(rule, stype)] + + +def rules_expand_sources(rules): + # takes a list of rules and expands it based on 'cidr_ip', 'group_id', 'group_name' + if not rules: + return rules + + return [rule for rule_complex in rules + for rule in rule_expand_sources(rule_complex)] + + +def update_rules_description(module, client, rule_type, group_id, ip_permissions): + if module.check_mode: + return + try: + if rule_type == "in": + client.update_security_group_rule_descriptions_ingress(GroupId=group_id, IpPermissions=ip_permissions) + if rule_type == "out": + client.update_security_group_rule_descriptions_egress(GroupId=group_id, IpPermissions=ip_permissions) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to update rule description for group %s" % group_id) + + +def fix_port_and_protocol(permission): + for key in ('FromPort', 'ToPort'): + if key in permission: + if permission[key] is None: + del permission[key] + else: + permission[key] = int(permission[key]) + + permission['IpProtocol'] = to_text(permission['IpProtocol']) + + return permission + + +def remove_old_permissions(client, module, revoke_ingress, revoke_egress, group_id): + if revoke_ingress: + revoke(client, module, revoke_ingress, group_id, 'in') + if revoke_egress: + revoke(client, module, revoke_egress, group_id, 'out') + return bool(revoke_ingress or revoke_egress) + + +def revoke(client, module, ip_permissions, group_id, rule_type): + if not module.check_mode: + try: + if rule_type == 'in': + client.revoke_security_group_ingress(GroupId=group_id, IpPermissions=ip_permissions) + elif rule_type == 'out': + client.revoke_security_group_egress(GroupId=group_id, IpPermissions=ip_permissions) + except (BotoCoreError, ClientError) as e: + rules = 'ingress rules' if rule_type == 'in' else 'egress rules' + module.fail_json_aws(e, "Unable to revoke {0}: {1}".format(rules, ip_permissions)) + + +def add_new_permissions(client, module, new_ingress, new_egress, group_id): + if new_ingress: + authorize(client, module, new_ingress, group_id, 'in') + if new_egress: + authorize(client, module, new_egress, group_id, 'out') + return bool(new_ingress or new_egress) + + +def authorize(client, module, ip_permissions, group_id, rule_type): + if not module.check_mode: + try: + if rule_type == 'in': + client.authorize_security_group_ingress(GroupId=group_id, IpPermissions=ip_permissions) + elif rule_type == 'out': + client.authorize_security_group_egress(GroupId=group_id, IpPermissions=ip_permissions) + except (BotoCoreError, ClientError) as e: + rules = 'ingress rules' if rule_type == 'in' else 'egress rules' + module.fail_json_aws(e, "Unable to authorize {0}: {1}".format(rules, ip_permissions)) + + +def validate_ip(module, cidr_ip): + split_addr = cidr_ip.split('/') + if len(split_addr) == 2: + # 
this_ip is a IPv4 or IPv6 CIDR that may or may not have host bits set + # Get the network bits if IPv4, and validate if IPv6. + try: + ip = to_subnet(split_addr[0], split_addr[1]) + if ip != cidr_ip: + module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, " + "check the network mask and make sure that only network bits are set: {1}.".format( + cidr_ip, ip)) + except ValueError: + # to_subnet throws a ValueError on IPv6 networks, so we should be working with v6 if we get here + try: + isinstance(ip_network(to_text(cidr_ip)), IPv6Network) + ip = cidr_ip + except ValueError: + # If a host bit is set on something other than a /128, IPv6Network will throw a ValueError + # The ipv6_cidr in this case probably looks like "2001:DB8:A0B:12F0::1/64" and we just want the network bits + ip6 = to_ipv6_subnet(split_addr[0]) + "/" + split_addr[1] + if ip6 != cidr_ip: + module.warn("One of your IPv6 CIDR addresses ({0}) has host bits set. To get rid of this warning, " + "check the network mask and make sure that only network bits are set: {1}.".format(cidr_ip, ip6)) + return ip6 + return ip + return cidr_ip + + +def update_tags(client, module, group_id, current_tags, tags, purge_tags): + tags_need_modify, tags_to_delete = compare_aws_tags(current_tags, tags, purge_tags) + + if not module.check_mode: + if tags_to_delete: + try: + client.delete_tags(Resources=[group_id], Tags=[{'Key': tag} for tag in tags_to_delete]) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to delete tags {0}".format(tags_to_delete)) + + # Add/update tags + if tags_need_modify: + try: + client.create_tags(Resources=[group_id], Tags=ansible_dict_to_boto3_tag_list(tags_need_modify)) + except (BotoCoreError, ClientError) as e: + module.fail_json(e, msg="Unable to add tags {0}".format(tags_need_modify)) + + return bool(tags_need_modify or tags_to_delete) + + +def update_rule_descriptions(module, group_id, present_ingress, named_tuple_ingress_list, present_egress, named_tuple_egress_list): + changed = False + client = module.client('ec2') + ingress_needs_desc_update = [] + egress_needs_desc_update = [] + + for present_rule in present_egress: + needs_update = [r for r in named_tuple_egress_list if rule_cmp(r, present_rule) and r.description != present_rule.description] + for r in needs_update: + named_tuple_egress_list.remove(r) + egress_needs_desc_update.extend(needs_update) + for present_rule in present_ingress: + needs_update = [r for r in named_tuple_ingress_list if rule_cmp(r, present_rule) and r.description != present_rule.description] + for r in needs_update: + named_tuple_ingress_list.remove(r) + ingress_needs_desc_update.extend(needs_update) + + if ingress_needs_desc_update: + update_rules_description(module, client, 'in', group_id, rules_to_permissions(ingress_needs_desc_update)) + changed |= True + if egress_needs_desc_update: + update_rules_description(module, client, 'out', group_id, rules_to_permissions(egress_needs_desc_update)) + changed |= True + return changed + + +def create_security_group(client, module, name, description, vpc_id): + if not module.check_mode: + params = dict(GroupName=name, Description=description) + if vpc_id: + params['VpcId'] = vpc_id + try: + group = client.create_security_group(**params) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to create security group") + # When a group is created, an egress_rule ALLOW ALL + # to 0.0.0.0/0 is added automatically but it's not + # reflected in the object 
returned by the AWS API + # call. We re-read the group for getting an updated object + # amazon sometimes takes a couple seconds to update the security group so wait till it exists + while True: + sleep(3) + group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0] + if group.get('VpcId') and not group.get('IpPermissionsEgress'): + pass + else: + break + return group + return None + + +def wait_for_rule_propagation(module, group, desired_ingress, desired_egress, purge_ingress, purge_egress): + group_id = group['GroupId'] + tries = 6 + + def await_rules(group, desired_rules, purge, rule_key): + for i in range(tries): + current_rules = set(sum([list(rule_from_group_permission(p)) for p in group[rule_key]], [])) + if purge and len(current_rules ^ set(desired_rules)) == 0: + return group + elif purge: + conflicts = current_rules ^ set(desired_rules) + # For cases where set comparison is equivalent, but invalid port/proto exist + for a, b in itertools.combinations(conflicts, 2): + if rule_cmp(a, b): + conflicts.discard(a) + conflicts.discard(b) + if not len(conflicts): + return group + elif current_rules.issuperset(desired_rules) and not purge: + return group + sleep(10) + group = get_security_groups_with_backoff(module.client('ec2'), GroupIds=[group_id])['SecurityGroups'][0] + module.warn("Ran out of time waiting for {0} {1}. Current: {2}, Desired: {3}".format(group_id, rule_key, current_rules, desired_rules)) + return group + + group = get_security_groups_with_backoff(module.client('ec2'), GroupIds=[group_id])['SecurityGroups'][0] + if 'VpcId' in group and module.params.get('rules_egress') is not None: + group = await_rules(group, desired_egress, purge_egress, 'IpPermissionsEgress') + return await_rules(group, desired_ingress, purge_ingress, 'IpPermissions') + + +def group_exists(client, module, vpc_id, group_id, name): + params = {'Filters': []} + if group_id: + params['GroupIds'] = [group_id] + if name: + # Add name to filters rather than params['GroupNames'] + # because params['GroupNames'] only checks the default vpc if no vpc is provided + params['Filters'].append({'Name': 'group-name', 'Values': [name]}) + if vpc_id: + params['Filters'].append({'Name': 'vpc-id', 'Values': [vpc_id]}) + # Don't filter by description to maintain backwards compatibility + + try: + security_groups = sg_exists_with_backoff(client, **params).get('SecurityGroups', []) + all_groups = get_security_groups_with_backoff(client).get('SecurityGroups', []) + except (BotoCoreError, ClientError) as e: # pylint: disable=duplicate-except + module.fail_json_aws(e, msg="Error in describe_security_groups") + + if security_groups: + groups = dict((group['GroupId'], group) for group in all_groups) + groups.update(dict((group['GroupName'], group) for group in all_groups)) + if vpc_id: + vpc_wins = dict((group['GroupName'], group) for group in all_groups if group.get('VpcId') and group['VpcId'] == vpc_id) + groups.update(vpc_wins) + # maintain backwards compatibility by using the last matching group + return security_groups[-1], groups + return None, {} + + +def verify_rules_with_descriptions_permitted(client, module, rules, rules_egress): + if not hasattr(client, "update_security_group_rule_descriptions_egress"): + all_rules = rules if rules else [] + rules_egress if rules_egress else [] + if any('rule_desc' in rule for rule in all_rules): + module.fail_json(msg="Using rule descriptions requires botocore version >= 1.7.2.") + + +def get_diff_final_resource(client, module, security_group): 
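+    # Builds the "after" side of the --diff output: the security group state that should
+    # exist once the requested rules and tags have been applied. The nested helpers below
+    # compute the owner id, the final tag set and the final rule lists. The returned shape
+    # (illustrative values only; the keys mirror the return statement at the end of this
+    # function) is roughly:
+    #   {'description': '...', 'group_id': 'sg-xxxxxxxx', 'group_name': '...',
+    #    'ip_permissions': [...], 'ip_permissions_egress': [...],
+    #    'owner_id': '...', 'tags': {...}, 'vpc_id': '...'}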
+ def get_account_id(security_group, module): + try: + owner_id = security_group.get('owner_id', module.client('sts').get_caller_identity()['Account']) + except (BotoCoreError, ClientError) as e: + owner_id = "Unable to determine owner_id: {0}".format(to_text(e)) + return owner_id + + def get_final_tags(security_group_tags, specified_tags, purge_tags): + if specified_tags is None: + return security_group_tags + tags_need_modify, tags_to_delete = compare_aws_tags(security_group_tags, specified_tags, purge_tags) + end_result_tags = dict((k, v) for k, v in specified_tags.items() if k not in tags_to_delete) + end_result_tags.update(dict((k, v) for k, v in security_group_tags.items() if k not in tags_to_delete)) + end_result_tags.update(tags_need_modify) + return end_result_tags + + def get_final_rules(client, module, security_group_rules, specified_rules, purge_rules): + if specified_rules is None: + return security_group_rules + if purge_rules: + final_rules = [] + else: + final_rules = list(security_group_rules) + specified_rules = flatten_nested_targets(module, deepcopy(specified_rules)) + for rule in specified_rules: + format_rule = { + 'from_port': None, 'to_port': None, 'ip_protocol': rule.get('proto', 'tcp'), + 'ip_ranges': [], 'ipv6_ranges': [], 'prefix_list_ids': [], 'user_id_group_pairs': [] + } + if rule.get('proto', 'tcp') in ('all', '-1', -1): + format_rule['ip_protocol'] = '-1' + format_rule.pop('from_port') + format_rule.pop('to_port') + elif rule.get('ports'): + if rule.get('ports') and (isinstance(rule['ports'], string_types) or isinstance(rule['ports'], int)): + rule['ports'] = [rule['ports']] + for port in rule.get('ports'): + if isinstance(port, string_types) and '-' in port: + format_rule['from_port'], format_rule['to_port'] = port.split('-') + else: + format_rule['from_port'] = format_rule['to_port'] = port + elif rule.get('from_port') or rule.get('to_port'): + format_rule['from_port'] = rule.get('from_port', rule.get('to_port')) + format_rule['to_port'] = rule.get('to_port', rule.get('from_port')) + for source_type in ('cidr_ip', 'cidr_ipv6', 'prefix_list_id'): + if rule.get(source_type): + rule_key = {'cidr_ip': 'ip_ranges', 'cidr_ipv6': 'ipv6_ranges', 'prefix_list_id': 'prefix_list_ids'}.get(source_type) + if rule.get('rule_desc'): + format_rule[rule_key] = [{source_type: rule[source_type], 'description': rule['rule_desc']}] + else: + if not isinstance(rule[source_type], list): + rule[source_type] = [rule[source_type]] + format_rule[rule_key] = [{source_type: target} for target in rule[source_type]] + if rule.get('group_id') or rule.get('group_name'): + rule_sg = camel_dict_to_snake_dict(group_exists(client, module, module.params['vpc_id'], rule.get('group_id'), rule.get('group_name'))[0]) + format_rule['user_id_group_pairs'] = [{ + 'description': rule_sg.get('description', rule_sg.get('group_desc')), + 'group_id': rule_sg.get('group_id', rule.get('group_id')), + 'group_name': rule_sg.get('group_name', rule.get('group_name')), + 'peering_status': rule_sg.get('peering_status'), + 'user_id': rule_sg.get('user_id', get_account_id(security_group, module)), + 'vpc_id': rule_sg.get('vpc_id', module.params['vpc_id']), + 'vpc_peering_connection_id': rule_sg.get('vpc_peering_connection_id') + }] + for k, v in list(format_rule['user_id_group_pairs'][0].items()): + if v is None: + format_rule['user_id_group_pairs'][0].pop(k) + final_rules.append(format_rule) + # Order final rules consistently + final_rules.sort(key=get_ip_permissions_sort_key) + return final_rules + 
security_group_ingress = security_group.get('ip_permissions', []) + specified_ingress = module.params['rules'] + purge_ingress = module.params['purge_rules'] + security_group_egress = security_group.get('ip_permissions_egress', []) + specified_egress = module.params['rules_egress'] + purge_egress = module.params['purge_rules_egress'] + return { + 'description': module.params['description'], + 'group_id': security_group.get('group_id', 'sg-xxxxxxxx'), + 'group_name': security_group.get('group_name', module.params['name']), + 'ip_permissions': get_final_rules(client, module, security_group_ingress, specified_ingress, purge_ingress), + 'ip_permissions_egress': get_final_rules(client, module, security_group_egress, specified_egress, purge_egress), + 'owner_id': get_account_id(security_group, module), + 'tags': get_final_tags(security_group.get('tags', {}), module.params['tags'], module.params['purge_tags']), + 'vpc_id': security_group.get('vpc_id', module.params['vpc_id'])} + + +def flatten_nested_targets(module, rules): + def _flatten(targets): + for target in targets: + if isinstance(target, list): + for t in _flatten(target): + yield t + elif isinstance(target, string_types): + yield target + + if rules is not None: + for rule in rules: + target_list_type = None + if isinstance(rule.get('cidr_ip'), list): + target_list_type = 'cidr_ip' + elif isinstance(rule.get('cidr_ipv6'), list): + target_list_type = 'cidr_ipv6' + if target_list_type is not None: + rule[target_list_type] = list(_flatten(rule[target_list_type])) + return rules + + +def get_rule_sort_key(dicts): + if dicts.get('cidr_ip'): + return dicts.get('cidr_ip') + elif dicts.get('cidr_ipv6'): + return dicts.get('cidr_ipv6') + elif dicts.get('prefix_list_id'): + return dicts.get('prefix_list_id') + elif dicts.get('group_id'): + return dicts.get('group_id') + return None + + +def get_ip_permissions_sort_key(rule): + if rule.get('ip_ranges'): + rule.get('ip_ranges').sort(key=get_rule_sort_key) + return rule.get('ip_ranges')[0]['cidr_ip'] + elif rule.get('ipv6_ranges'): + rule.get('ipv6_ranges').sort(key=get_rule_sort_key) + return rule.get('ipv6_ranges')[0]['cidr_ipv6'] + elif rule.get('prefix_list_ids'): + rule.get('prefix_list_ids').sort(key=get_rule_sort_key) + return rule.get('prefix_list_ids')[0]['prefix_list_id'] + elif rule.get('user_id_group_pairs'): + rule.get('user_id_group_pairs').sort(key=get_rule_sort_key) + return rule.get('user_id_group_pairs')[0]['group_id'] + return None + + +def main(): + argument_spec = dict( + name=dict(), + group_id=dict(), + description=dict(), + vpc_id=dict(), + rules=dict(type='list'), + rules_egress=dict(type='list'), + state=dict(default='present', type='str', choices=['present', 'absent']), + purge_rules=dict(default=True, required=False, type='bool'), + purge_rules_egress=dict(default=True, required=False, type='bool'), + tags=dict(required=False, type='dict', aliases=['resource_tags']), + purge_tags=dict(default=True, required=False, type='bool') + ) + module = AnsibleAWSModule( + argument_spec=argument_spec, + supports_check_mode=True, + required_one_of=[['name', 'group_id']], + required_if=[['state', 'present', ['name']]], + ) + + name = module.params['name'] + group_id = module.params['group_id'] + description = module.params['description'] + vpc_id = module.params['vpc_id'] + rules = flatten_nested_targets(module, deepcopy(module.params['rules'])) + rules_egress = flatten_nested_targets(module, deepcopy(module.params['rules_egress'])) + rules = 
deduplicate_rules_args(rules_expand_sources(rules_expand_ports(rules))) + rules_egress = deduplicate_rules_args(rules_expand_sources(rules_expand_ports(rules_egress))) + state = module.params.get('state') + purge_rules = module.params['purge_rules'] + purge_rules_egress = module.params['purge_rules_egress'] + tags = module.params['tags'] + purge_tags = module.params['purge_tags'] + + if state == 'present' and not description: + module.fail_json(msg='Must provide description when state is present.') + + changed = False + client = module.client('ec2') + + verify_rules_with_descriptions_permitted(client, module, rules, rules_egress) + group, groups = group_exists(client, module, vpc_id, group_id, name) + group_created_new = not bool(group) + + global current_account_id + current_account_id = get_aws_account_id(module) + + before = {} + after = {} + + # Ensure requested group is absent + if state == 'absent': + if group: + # found a match, delete it + before = camel_dict_to_snake_dict(group, ignore_list=['Tags']) + before['tags'] = boto3_tag_list_to_ansible_dict(before.get('tags', [])) + try: + if not module.check_mode: + client.delete_security_group(GroupId=group['GroupId']) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Unable to delete security group '%s'" % group) + else: + group = None + changed = True + else: + # no match found, no changes required + pass + + # Ensure requested group is present + elif state == 'present': + if group: + # existing group + before = camel_dict_to_snake_dict(group, ignore_list=['Tags']) + before['tags'] = boto3_tag_list_to_ansible_dict(before.get('tags', [])) + if group['Description'] != description: + module.warn("Group description does not match existing group. Descriptions cannot be changed without deleting " + "and re-creating the security group. 
Try using state=absent to delete, then rerunning this task.") + else: + # no match found, create it + group = create_security_group(client, module, name, description, vpc_id) + changed = True + + if tags is not None and group is not None: + current_tags = boto3_tag_list_to_ansible_dict(group.get('Tags', [])) + changed |= update_tags(client, module, group['GroupId'], current_tags, tags, purge_tags) + + if group: + named_tuple_ingress_list = [] + named_tuple_egress_list = [] + current_ingress = sum([list(rule_from_group_permission(p)) for p in group['IpPermissions']], []) + current_egress = sum([list(rule_from_group_permission(p)) for p in group['IpPermissionsEgress']], []) + + for new_rules, rule_type, named_tuple_rule_list in [(rules, 'in', named_tuple_ingress_list), + (rules_egress, 'out', named_tuple_egress_list)]: + if new_rules is None: + continue + for rule in new_rules: + target_type, target, target_group_created = get_target_from_rule( + module, client, rule, name, group, groups, vpc_id) + changed |= target_group_created + + if rule.get('proto', 'tcp') in ('all', '-1', -1): + rule['proto'] = '-1' + rule['from_port'] = None + rule['to_port'] = None + try: + int(rule.get('proto', 'tcp')) + rule['proto'] = to_text(rule.get('proto', 'tcp')) + rule['from_port'] = None + rule['to_port'] = None + except ValueError: + # rule does not use numeric protocol spec + pass + + named_tuple_rule_list.append( + Rule( + port_range=(rule['from_port'], rule['to_port']), + protocol=to_text(rule.get('proto', 'tcp')), + target=target, target_type=target_type, + description=rule.get('rule_desc'), + ) + ) + + # List comprehensions for rules to add, rules to modify, and rule ids to determine purging + new_ingress_permissions = [to_permission(r) for r in (set(named_tuple_ingress_list) - set(current_ingress))] + new_egress_permissions = [to_permission(r) for r in (set(named_tuple_egress_list) - set(current_egress))] + + if module.params.get('rules_egress') is None and 'VpcId' in group: + # when no egress rules are specified and we're in a VPC, + # we add in a default allow all out rule, which was the + # default behavior before egress rules were added + rule = Rule((None, None), '-1', '0.0.0.0/0', 'ipv4', None) + if rule in current_egress: + named_tuple_egress_list.append(rule) + if rule not in current_egress: + current_egress.append(rule) + + # List comprehensions for rules to add, rules to modify, and rule ids to determine purging + present_ingress = list(set(named_tuple_ingress_list).union(set(current_ingress))) + present_egress = list(set(named_tuple_egress_list).union(set(current_egress))) + + if purge_rules: + revoke_ingress = [] + for p in present_ingress: + if not any([rule_cmp(p, b) for b in named_tuple_ingress_list]): + revoke_ingress.append(to_permission(p)) + else: + revoke_ingress = [] + if purge_rules_egress and module.params.get('rules_egress') is not None: + if module.params.get('rules_egress') is []: + revoke_egress = [ + to_permission(r) for r in set(present_egress) - set(named_tuple_egress_list) + if r != Rule((None, None), '-1', '0.0.0.0/0', 'ipv4', None) + ] + else: + revoke_egress = [] + for p in present_egress: + if not any([rule_cmp(p, b) for b in named_tuple_egress_list]): + revoke_egress.append(to_permission(p)) + else: + revoke_egress = [] + + # named_tuple_ingress_list and named_tuple_egress_list got updated by + # method update_rule_descriptions, deep copy these two lists to new + # variables for the record of the 'desired' ingress and egress sg permissions + desired_ingress = 
deepcopy(named_tuple_ingress_list) + desired_egress = deepcopy(named_tuple_egress_list) + + changed |= update_rule_descriptions(module, group['GroupId'], present_ingress, named_tuple_ingress_list, present_egress, named_tuple_egress_list) + + # Revoke old rules + changed |= remove_old_permissions(client, module, revoke_ingress, revoke_egress, group['GroupId']) + rule_msg = 'Revoking {0}, and egress {1}'.format(revoke_ingress, revoke_egress) + + new_ingress_permissions = [to_permission(r) for r in (set(named_tuple_ingress_list) - set(current_ingress))] + new_ingress_permissions = rules_to_permissions(set(named_tuple_ingress_list) - set(current_ingress)) + new_egress_permissions = rules_to_permissions(set(named_tuple_egress_list) - set(current_egress)) + # Authorize new rules + changed |= add_new_permissions(client, module, new_ingress_permissions, new_egress_permissions, group['GroupId']) + + if group_created_new and module.params.get('rules') is None and module.params.get('rules_egress') is None: + # A new group with no rules provided is already being awaited. + # When it is created we wait for the default egress rule to be added by AWS + security_group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0] + elif changed and not module.check_mode: + # keep pulling until current security group rules match the desired ingress and egress rules + security_group = wait_for_rule_propagation(module, group, desired_ingress, desired_egress, purge_rules, purge_rules_egress) + else: + security_group = get_security_groups_with_backoff(client, GroupIds=[group['GroupId']])['SecurityGroups'][0] + security_group = camel_dict_to_snake_dict(security_group, ignore_list=['Tags']) + security_group['tags'] = boto3_tag_list_to_ansible_dict(security_group.get('tags', [])) + + else: + security_group = {'group_id': None} + + if module._diff: + if module.params['state'] == 'present': + after = get_diff_final_resource(client, module, security_group) + if before.get('ip_permissions'): + before['ip_permissions'].sort(key=get_ip_permissions_sort_key) + + security_group['diff'] = [{'before': before, 'after': after}] + + module.exit_json(changed=changed, **security_group) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_vpc_net.py b/test/support/integration/plugins/modules/ec2_vpc_net.py new file mode 100644 index 00000000..30e4b1e9 --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_vpc_net.py @@ -0,0 +1,524 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: ec2_vpc_net +short_description: Configure AWS virtual private clouds +description: + - Create, modify, and terminate AWS virtual private clouds. +version_added: "2.0" +author: + - Jonathan Davila (@defionscode) + - Sloane Hertel (@s-hertel) +options: + name: + description: + - The name to give your VPC. This is used in combination with C(cidr_block) to determine if a VPC already exists. + required: yes + type: str + cidr_block: + description: + - The primary CIDR of the VPC. After 2.5 a list of CIDRs can be provided. The first in the list will be used as the primary CIDR + and is used in conjunction with the C(name) to ensure idempotence. 
+ required: yes + type: list + elements: str + ipv6_cidr: + description: + - Request an Amazon-provided IPv6 CIDR block with /56 prefix length. You cannot specify the range of IPv6 addresses, + or the size of the CIDR block. + default: False + type: bool + version_added: '2.10' + purge_cidrs: + description: + - Remove CIDRs that are associated with the VPC and are not specified in C(cidr_block). + default: no + type: bool + version_added: '2.5' + tenancy: + description: + - Whether to be default or dedicated tenancy. This cannot be changed after the VPC has been created. + default: default + choices: [ 'default', 'dedicated' ] + type: str + dns_support: + description: + - Whether to enable AWS DNS support. + default: yes + type: bool + dns_hostnames: + description: + - Whether to enable AWS hostname support. + default: yes + type: bool + dhcp_opts_id: + description: + - The id of the DHCP options to use for this VPC. + type: str + tags: + description: + - The tags you want attached to the VPC. This is independent of the name value, note if you pass a 'Name' key it would override the Name of + the VPC if it's different. + aliases: [ 'resource_tags' ] + type: dict + state: + description: + - The state of the VPC. Either absent or present. + default: present + choices: [ 'present', 'absent' ] + type: str + multi_ok: + description: + - By default the module will not create another VPC if there is another VPC with the same name and CIDR block. Specify this as true if you want + duplicate VPCs created. + type: bool + default: false +requirements: + - boto3 + - botocore +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +- name: create a VPC with dedicated tenancy and a couple of tags + ec2_vpc_net: + name: Module_dev2 + cidr_block: 10.10.0.0/16 + region: us-east-1 + tags: + module: ec2_vpc_net + this: works + tenancy: dedicated + +- name: create a VPC with dedicated tenancy and request an IPv6 CIDR + ec2_vpc_net: + name: Module_dev2 + cidr_block: 10.10.0.0/16 + ipv6_cidr: True + region: us-east-1 + tenancy: dedicated +''' + +RETURN = ''' +vpc: + description: info about the VPC that was created or deleted + returned: always + type: complex + contains: + cidr_block: + description: The CIDR of the VPC + returned: always + type: str + sample: 10.0.0.0/16 + cidr_block_association_set: + description: IPv4 CIDR blocks associated with the VPC + returned: success + type: list + sample: + "cidr_block_association_set": [ + { + "association_id": "vpc-cidr-assoc-97aeeefd", + "cidr_block": "20.0.0.0/24", + "cidr_block_state": { + "state": "associated" + } + } + ] + classic_link_enabled: + description: indicates whether ClassicLink is enabled + returned: always + type: bool + sample: false + dhcp_options_id: + description: the id of the DHCP options associated with this VPC + returned: always + type: str + sample: dopt-0fb8bd6b + id: + description: VPC resource id + returned: always + type: str + sample: vpc-c2e00da5 + instance_tenancy: + description: indicates whether VPC uses default or dedicated tenancy + returned: always + type: str + sample: default + ipv6_cidr_block_association_set: + description: IPv6 CIDR blocks associated with the VPC + returned: success + type: list + sample: + "ipv6_cidr_block_association_set": [ + { + "association_id": "vpc-cidr-assoc-97aeeefd", + "ipv6_cidr_block": "2001:db8::/56", + "ipv6_cidr_block_state": { + "state": "associated" + } + } + ] + is_default: + 
description: indicates whether this is the default VPC + returned: always + type: bool + sample: false + state: + description: state of the VPC + returned: always + type: str + sample: available + tags: + description: tags attached to the VPC, includes name + returned: always + type: complex + contains: + Name: + description: name tag for the VPC + returned: always + type: str + sample: pk_vpc4 +''' + +try: + import botocore +except ImportError: + pass # Handled by AnsibleAWSModule + +from time import sleep, time +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.ec2 import (AWSRetry, camel_dict_to_snake_dict, compare_aws_tags, + ansible_dict_to_boto3_tag_list, boto3_tag_list_to_ansible_dict) +from ansible.module_utils.six import string_types +from ansible.module_utils._text import to_native +from ansible.module_utils.network.common.utils import to_subnet + + +def vpc_exists(module, vpc, name, cidr_block, multi): + """Returns None or a vpc object depending on the existence of a VPC. When supplied + with a CIDR, it will check for matching tags to determine if it is a match + otherwise it will assume the VPC does not exist and thus return None. + """ + try: + matching_vpcs = vpc.describe_vpcs(Filters=[{'Name': 'tag:Name', 'Values': [name]}, {'Name': 'cidr-block', 'Values': cidr_block}])['Vpcs'] + # If an exact matching using a list of CIDRs isn't found, check for a match with the first CIDR as is documented for C(cidr_block) + if not matching_vpcs: + matching_vpcs = vpc.describe_vpcs(Filters=[{'Name': 'tag:Name', 'Values': [name]}, {'Name': 'cidr-block', 'Values': [cidr_block[0]]}])['Vpcs'] + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to describe VPCs") + + if multi: + return None + elif len(matching_vpcs) == 1: + return matching_vpcs[0]['VpcId'] + elif len(matching_vpcs) > 1: + module.fail_json(msg='Currently there are %d VPCs that have the same name and ' + 'CIDR block you specified. If you would like to create ' + 'the VPC anyway please pass True to the multi_ok param.' 
% len(matching_vpcs)) + return None + + +@AWSRetry.backoff(delay=3, tries=8, catch_extra_error_codes=['InvalidVpcID.NotFound']) +def get_classic_link_with_backoff(connection, vpc_id): + try: + return connection.describe_vpc_classic_link(VpcIds=[vpc_id])['Vpcs'][0].get('ClassicLinkEnabled') + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Message"] == "The functionality you requested is not available in this region.": + return False + else: + raise + + +def get_vpc(module, connection, vpc_id): + # wait for vpc to be available + try: + connection.get_waiter('vpc_available').wait(VpcIds=[vpc_id]) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to wait for VPC {0} to be available.".format(vpc_id)) + + try: + vpc_obj = connection.describe_vpcs(VpcIds=[vpc_id], aws_retry=True)['Vpcs'][0] + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to describe VPCs") + try: + vpc_obj['ClassicLinkEnabled'] = get_classic_link_with_backoff(connection, vpc_id) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to describe VPCs") + + return vpc_obj + + +def update_vpc_tags(connection, module, vpc_id, tags, name): + if tags is None: + tags = dict() + + tags.update({'Name': name}) + tags = dict((k, to_native(v)) for k, v in tags.items()) + try: + current_tags = dict((t['Key'], t['Value']) for t in connection.describe_tags(Filters=[{'Name': 'resource-id', 'Values': [vpc_id]}])['Tags']) + tags_to_update, dummy = compare_aws_tags(current_tags, tags, False) + if tags_to_update: + if not module.check_mode: + tags = ansible_dict_to_boto3_tag_list(tags_to_update) + vpc_obj = connection.create_tags(Resources=[vpc_id], Tags=tags, aws_retry=True) + + # Wait for tags to be updated + expected_tags = boto3_tag_list_to_ansible_dict(tags) + filters = [{'Name': 'tag:{0}'.format(key), 'Values': [value]} for key, value in expected_tags.items()] + connection.get_waiter('vpc_available').wait(VpcIds=[vpc_id], Filters=filters) + + return True + else: + return False + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to update tags") + + +def update_dhcp_opts(connection, module, vpc_obj, dhcp_id): + if vpc_obj['DhcpOptionsId'] != dhcp_id: + if not module.check_mode: + try: + connection.associate_dhcp_options(DhcpOptionsId=dhcp_id, VpcId=vpc_obj['VpcId']) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to associate DhcpOptionsId {0}".format(dhcp_id)) + + try: + # Wait for DhcpOptionsId to be updated + filters = [{'Name': 'dhcp-options-id', 'Values': [dhcp_id]}] + connection.get_waiter('vpc_available').wait(VpcIds=[vpc_obj['VpcId']], Filters=filters) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json(msg="Failed to wait for DhcpOptionsId to be updated") + + return True + else: + return False + + +def create_vpc(connection, module, cidr_block, tenancy): + try: + if not module.check_mode: + vpc_obj = connection.create_vpc(CidrBlock=cidr_block, InstanceTenancy=tenancy) + else: + module.exit_json(changed=True) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to create the VPC") + + # wait for vpc to exist + try: + 
connection.get_waiter('vpc_exists').wait(VpcIds=[vpc_obj['Vpc']['VpcId']]) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Unable to wait for VPC {0} to be created.".format(vpc_obj['Vpc']['VpcId'])) + + return vpc_obj['Vpc']['VpcId'] + + +def wait_for_vpc_attribute(connection, module, vpc_id, attribute, expected_value): + start_time = time() + updated = False + while time() < start_time + 300: + current_value = connection.describe_vpc_attribute( + Attribute=attribute, + VpcId=vpc_id + )['{0}{1}'.format(attribute[0].upper(), attribute[1:])]['Value'] + if current_value != expected_value: + sleep(3) + else: + updated = True + break + if not updated: + module.fail_json(msg="Failed to wait for {0} to be updated".format(attribute)) + + +def get_cidr_network_bits(module, cidr_block): + fixed_cidrs = [] + for cidr in cidr_block: + split_addr = cidr.split('/') + if len(split_addr) == 2: + # this_ip is a IPv4 CIDR that may or may not have host bits set + # Get the network bits. + valid_cidr = to_subnet(split_addr[0], split_addr[1]) + if cidr != valid_cidr: + module.warn("One of your CIDR addresses ({0}) has host bits set. To get rid of this warning, " + "check the network mask and make sure that only network bits are set: {1}.".format(cidr, valid_cidr)) + fixed_cidrs.append(valid_cidr) + else: + # let AWS handle invalid CIDRs + fixed_cidrs.append(cidr) + return fixed_cidrs + + +def main(): + argument_spec = dict( + name=dict(required=True), + cidr_block=dict(type='list', required=True), + ipv6_cidr=dict(type='bool', default=False), + tenancy=dict(choices=['default', 'dedicated'], default='default'), + dns_support=dict(type='bool', default=True), + dns_hostnames=dict(type='bool', default=True), + dhcp_opts_id=dict(), + tags=dict(type='dict', aliases=['resource_tags']), + state=dict(choices=['present', 'absent'], default='present'), + multi_ok=dict(type='bool', default=False), + purge_cidrs=dict(type='bool', default=False), + ) + + module = AnsibleAWSModule( + argument_spec=argument_spec, + supports_check_mode=True + ) + + name = module.params.get('name') + cidr_block = get_cidr_network_bits(module, module.params.get('cidr_block')) + ipv6_cidr = module.params.get('ipv6_cidr') + purge_cidrs = module.params.get('purge_cidrs') + tenancy = module.params.get('tenancy') + dns_support = module.params.get('dns_support') + dns_hostnames = module.params.get('dns_hostnames') + dhcp_id = module.params.get('dhcp_opts_id') + tags = module.params.get('tags') + state = module.params.get('state') + multi = module.params.get('multi_ok') + + changed = False + + connection = module.client( + 'ec2', + retry_decorator=AWSRetry.jittered_backoff( + retries=8, delay=3, catch_extra_error_codes=['InvalidVpcID.NotFound'] + ) + ) + + if dns_hostnames and not dns_support: + module.fail_json(msg='In order to enable DNS Hostnames you must also enable DNS support') + + if state == 'present': + + # Check if VPC exists + vpc_id = vpc_exists(module, connection, name, cidr_block, multi) + + if vpc_id is None: + vpc_id = create_vpc(connection, module, cidr_block[0], tenancy) + changed = True + + vpc_obj = get_vpc(module, connection, vpc_id) + + associated_cidrs = dict((cidr['CidrBlock'], cidr['AssociationId']) for cidr in vpc_obj.get('CidrBlockAssociationSet', []) + if cidr['CidrBlockState']['State'] != 'disassociated') + to_add = [cidr for cidr in cidr_block if cidr not in associated_cidrs] + to_remove = [associated_cidrs[cidr] for cidr in associated_cidrs if cidr not in 
cidr_block] + expected_cidrs = [cidr for cidr in associated_cidrs if associated_cidrs[cidr] not in to_remove] + to_add + + if len(cidr_block) > 1: + for cidr in to_add: + changed = True + try: + connection.associate_vpc_cidr_block(CidrBlock=cidr, VpcId=vpc_id) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Unable to associate CIDR {0}.".format(ipv6_cidr)) + if ipv6_cidr: + if 'Ipv6CidrBlockAssociationSet' in vpc_obj.keys(): + module.warn("Only one IPv6 CIDR is permitted per VPC, {0} already has CIDR {1}".format( + vpc_id, + vpc_obj['Ipv6CidrBlockAssociationSet'][0]['Ipv6CidrBlock'])) + else: + try: + connection.associate_vpc_cidr_block(AmazonProvidedIpv6CidrBlock=ipv6_cidr, VpcId=vpc_id) + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Unable to associate CIDR {0}.".format(ipv6_cidr)) + + if purge_cidrs: + for association_id in to_remove: + changed = True + try: + connection.disassociate_vpc_cidr_block(AssociationId=association_id) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Unable to disassociate {0}. You must detach or delete all gateways and resources that " + "are associated with the CIDR block before you can disassociate it.".format(association_id)) + + if dhcp_id is not None: + try: + if update_dhcp_opts(connection, module, vpc_obj, dhcp_id): + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to update DHCP options") + + if tags is not None or name is not None: + try: + if update_vpc_tags(connection, module, vpc_id, tags, name): + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to update tags") + + current_dns_enabled = connection.describe_vpc_attribute(Attribute='enableDnsSupport', VpcId=vpc_id, aws_retry=True)['EnableDnsSupport']['Value'] + current_dns_hostnames = connection.describe_vpc_attribute(Attribute='enableDnsHostnames', VpcId=vpc_id, aws_retry=True)['EnableDnsHostnames']['Value'] + if current_dns_enabled != dns_support: + changed = True + if not module.check_mode: + try: + connection.modify_vpc_attribute(VpcId=vpc_id, EnableDnsSupport={'Value': dns_support}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to update enabled dns support attribute") + if current_dns_hostnames != dns_hostnames: + changed = True + if not module.check_mode: + try: + connection.modify_vpc_attribute(VpcId=vpc_id, EnableDnsHostnames={'Value': dns_hostnames}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to update enabled dns hostnames attribute") + + # wait for associated cidrs to match + if to_add or to_remove: + try: + connection.get_waiter('vpc_available').wait( + VpcIds=[vpc_id], + Filters=[{'Name': 'cidr-block-association.cidr-block', 'Values': expected_cidrs}] + ) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Failed to wait for CIDRs to update") + + # try to wait for enableDnsSupport and enableDnsHostnames to match + wait_for_vpc_attribute(connection, module, vpc_id, 'enableDnsSupport', dns_support) + wait_for_vpc_attribute(connection, module, vpc_id, 'enableDnsHostnames', dns_hostnames) + + final_state = 
camel_dict_to_snake_dict(get_vpc(module, connection, vpc_id)) + final_state['tags'] = boto3_tag_list_to_ansible_dict(final_state.get('tags', [])) + final_state['id'] = final_state.pop('vpc_id') + + module.exit_json(changed=changed, vpc=final_state) + + elif state == 'absent': + + # Check if VPC exists + vpc_id = vpc_exists(module, connection, name, cidr_block, multi) + + if vpc_id is not None: + try: + if not module.check_mode: + connection.delete_vpc(VpcId=vpc_id) + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to delete VPC {0} You may want to use the ec2_vpc_subnet, ec2_vpc_igw, " + "and/or ec2_vpc_route_table modules to ensure the other components are absent.".format(vpc_id)) + + module.exit_json(changed=changed, vpc={}) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/ec2_vpc_subnet.py b/test/support/integration/plugins/modules/ec2_vpc_subnet.py new file mode 100644 index 00000000..5085e99b --- /dev/null +++ b/test/support/integration/plugins/modules/ec2_vpc_subnet.py @@ -0,0 +1,604 @@ +#!/usr/bin/python +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: ec2_vpc_subnet +short_description: Manage subnets in AWS virtual private clouds +description: + - Manage subnets in AWS virtual private clouds. +version_added: "2.0" +author: +- Robert Estelle (@erydo) +- Brad Davidson (@brandond) +requirements: [ boto3 ] +options: + az: + description: + - "The availability zone for the subnet." + type: str + cidr: + description: + - "The CIDR block for the subnet. E.g. 192.0.2.0/24." + type: str + required: true + ipv6_cidr: + description: + - "The IPv6 CIDR block for the subnet. The VPC must have a /56 block assigned and this value must be a valid IPv6 /64 that falls in the VPC range." + - "Required if I(assign_instances_ipv6=true)" + version_added: "2.5" + type: str + tags: + description: + - "A dict of tags to apply to the subnet. Any tags currently applied to the subnet and not present here will be removed." + aliases: [ 'resource_tags' ] + type: dict + state: + description: + - "Create or remove the subnet." + default: present + choices: [ 'present', 'absent' ] + type: str + vpc_id: + description: + - "VPC ID of the VPC in which to create or delete the subnet." + required: true + type: str + map_public: + description: + - "Specify C(yes) to indicate that instances launched into the subnet should be assigned public IP address by default." + type: bool + default: 'no' + version_added: "2.4" + assign_instances_ipv6: + description: + - "Specify C(yes) to indicate that instances launched into the subnet should be automatically assigned an IPv6 address." + type: bool + default: false + version_added: "2.5" + wait: + description: + - "When I(wait=true) and I(state=present), module will wait for subnet to be in available state before continuing." + type: bool + default: true + version_added: "2.5" + wait_timeout: + description: + - "Number of seconds to wait for subnet to become available I(wait=True)." + default: 300 + version_added: "2.5" + type: int + purge_tags: + description: + - Whether or not to remove tags that do not appear in the I(tags) list. 
+ type: bool + default: true + version_added: "2.5" +extends_documentation_fragment: + - aws + - ec2 +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +- name: Create subnet for database servers + ec2_vpc_subnet: + state: present + vpc_id: vpc-123456 + cidr: 10.0.1.16/28 + tags: + Name: Database Subnet + register: database_subnet + +- name: Remove subnet for database servers + ec2_vpc_subnet: + state: absent + vpc_id: vpc-123456 + cidr: 10.0.1.16/28 + +- name: Create subnet with IPv6 block assigned + ec2_vpc_subnet: + state: present + vpc_id: vpc-123456 + cidr: 10.1.100.0/24 + ipv6_cidr: 2001:db8:0:102::/64 + +- name: Remove IPv6 block assigned to subnet + ec2_vpc_subnet: + state: present + vpc_id: vpc-123456 + cidr: 10.1.100.0/24 + ipv6_cidr: '' +''' + +RETURN = ''' +subnet: + description: Dictionary of subnet values + returned: I(state=present) + type: complex + contains: + id: + description: Subnet resource id + returned: I(state=present) + type: str + sample: subnet-b883b2c4 + cidr_block: + description: The IPv4 CIDR of the Subnet + returned: I(state=present) + type: str + sample: "10.0.0.0/16" + ipv6_cidr_block: + description: The IPv6 CIDR block actively associated with the Subnet + returned: I(state=present) + type: str + sample: "2001:db8:0:102::/64" + availability_zone: + description: Availability zone of the Subnet + returned: I(state=present) + type: str + sample: us-east-1a + state: + description: state of the Subnet + returned: I(state=present) + type: str + sample: available + tags: + description: tags attached to the Subnet, includes name + returned: I(state=present) + type: dict + sample: {"Name": "My Subnet", "env": "staging"} + map_public_ip_on_launch: + description: whether public IP is auto-assigned to new instances + returned: I(state=present) + type: bool + sample: false + assign_ipv6_address_on_creation: + description: whether IPv6 address is auto-assigned to new instances + returned: I(state=present) + type: bool + sample: false + vpc_id: + description: the id of the VPC where this Subnet exists + returned: I(state=present) + type: str + sample: vpc-67236184 + available_ip_address_count: + description: number of available IPv4 addresses + returned: I(state=present) + type: str + sample: 251 + default_for_az: + description: indicates whether this is the default Subnet for this Availability Zone + returned: I(state=present) + type: bool + sample: false + ipv6_association_id: + description: The IPv6 association ID for the currently associated CIDR + returned: I(state=present) + type: str + sample: subnet-cidr-assoc-b85c74d2 + ipv6_cidr_block_association_set: + description: An array of IPv6 cidr block association set information. + returned: I(state=present) + type: complex + contains: + association_id: + description: The association ID + returned: always + type: str + ipv6_cidr_block: + description: The IPv6 CIDR block that is associated with the subnet. + returned: always + type: str + ipv6_cidr_block_state: + description: A hash/dict that contains a single item. The state of the cidr block association. + returned: always + type: dict + contains: + state: + description: The CIDR block association state. 
+ returned: always + type: str +''' + + +import time + +try: + import botocore +except ImportError: + pass # caught by AnsibleAWSModule + +from ansible.module_utils._text import to_text +from ansible.module_utils.aws.core import AnsibleAWSModule +from ansible.module_utils.aws.waiters import get_waiter +from ansible.module_utils.ec2 import (ansible_dict_to_boto3_filter_list, ansible_dict_to_boto3_tag_list, + camel_dict_to_snake_dict, boto3_tag_list_to_ansible_dict, compare_aws_tags, AWSRetry) + + +def get_subnet_info(subnet): + if 'Subnets' in subnet: + return [get_subnet_info(s) for s in subnet['Subnets']] + elif 'Subnet' in subnet: + subnet = camel_dict_to_snake_dict(subnet['Subnet']) + else: + subnet = camel_dict_to_snake_dict(subnet) + + if 'tags' in subnet: + subnet['tags'] = boto3_tag_list_to_ansible_dict(subnet['tags']) + else: + subnet['tags'] = dict() + + if 'subnet_id' in subnet: + subnet['id'] = subnet['subnet_id'] + del subnet['subnet_id'] + + subnet['ipv6_cidr_block'] = '' + subnet['ipv6_association_id'] = '' + ipv6set = subnet.get('ipv6_cidr_block_association_set') + if ipv6set: + for item in ipv6set: + if item.get('ipv6_cidr_block_state', {}).get('state') in ('associated', 'associating'): + subnet['ipv6_cidr_block'] = item['ipv6_cidr_block'] + subnet['ipv6_association_id'] = item['association_id'] + + return subnet + + +@AWSRetry.exponential_backoff() +def describe_subnets_with_backoff(client, **params): + return client.describe_subnets(**params) + + +def waiter_params(module, params, start_time): + if not module.botocore_at_least("1.7.0"): + remaining_wait_timeout = int(module.params['wait_timeout'] + start_time - time.time()) + params['WaiterConfig'] = {'Delay': 5, 'MaxAttempts': remaining_wait_timeout // 5} + return params + + +def handle_waiter(conn, module, waiter_name, params, start_time): + try: + get_waiter(conn, waiter_name).wait( + **waiter_params(module, params, start_time) + ) + except botocore.exceptions.WaiterError as e: + module.fail_json_aws(e, "Failed to wait for updates to complete") + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "An exception happened while trying to wait for updates") + + +def create_subnet(conn, module, vpc_id, cidr, ipv6_cidr=None, az=None, start_time=None): + wait = module.params['wait'] + wait_timeout = module.params['wait_timeout'] + + params = dict(VpcId=vpc_id, + CidrBlock=cidr) + + if ipv6_cidr: + params['Ipv6CidrBlock'] = ipv6_cidr + + if az: + params['AvailabilityZone'] = az + + try: + subnet = get_subnet_info(conn.create_subnet(**params)) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't create subnet") + + # Sometimes AWS takes its time to create a subnet and so using + # new subnets's id to do things like create tags results in + # exception. 
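+    # When wait=true, the code below therefore blocks in two stages before the new subnet
+    # is handed back to the caller: first until the subnet id becomes visible to describe
+    # calls (the 'subnet_exists' waiter, driven through handle_waiter above), then until the
+    # subnet's state reports 'available' (the boto3 'subnet_available' waiter). Roughly, with
+    # the timeout handling omitted for brevity:
+    #   get_waiter(conn, 'subnet_exists').wait(SubnetIds=[subnet['id']])
+    #   conn.get_waiter('subnet_available').wait(SubnetIds=[subnet['id']])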
+ if wait and subnet.get('state') != 'available': + handle_waiter(conn, module, 'subnet_exists', {'SubnetIds': [subnet['id']]}, start_time) + try: + conn.get_waiter('subnet_available').wait( + **waiter_params(module, {'SubnetIds': [subnet['id']]}, start_time) + ) + subnet['state'] = 'available' + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, "Create subnet action timed out waiting for subnet to become available") + + return subnet + + +def ensure_tags(conn, module, subnet, tags, purge_tags, start_time): + changed = False + + filters = ansible_dict_to_boto3_filter_list({'resource-id': subnet['id'], 'resource-type': 'subnet'}) + try: + cur_tags = conn.describe_tags(Filters=filters) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't describe tags") + + to_update, to_delete = compare_aws_tags(boto3_tag_list_to_ansible_dict(cur_tags.get('Tags')), tags, purge_tags) + + if to_update: + try: + if not module.check_mode: + AWSRetry.exponential_backoff( + catch_extra_error_codes=['InvalidSubnetID.NotFound'] + )(conn.create_tags)( + Resources=[subnet['id']], + Tags=ansible_dict_to_boto3_tag_list(to_update) + ) + + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't create tags") + + if to_delete: + try: + if not module.check_mode: + tags_list = [] + for key in to_delete: + tags_list.append({'Key': key}) + + AWSRetry.exponential_backoff( + catch_extra_error_codes=['InvalidSubnetID.NotFound'] + )(conn.delete_tags)(Resources=[subnet['id']], Tags=tags_list) + + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't delete tags") + + if module.params['wait'] and not module.check_mode: + # Wait for tags to be updated + filters = [{'Name': 'tag:{0}'.format(k), 'Values': [v]} for k, v in tags.items()] + handle_waiter(conn, module, 'subnet_exists', + {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time) + + return changed + + +def ensure_map_public(conn, module, subnet, map_public, check_mode, start_time): + if check_mode: + return + try: + conn.modify_subnet_attribute(SubnetId=subnet['id'], MapPublicIpOnLaunch={'Value': map_public}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't modify subnet attribute") + + +def ensure_assign_ipv6_on_create(conn, module, subnet, assign_instances_ipv6, check_mode, start_time): + if check_mode: + return + try: + conn.modify_subnet_attribute(SubnetId=subnet['id'], AssignIpv6AddressOnCreation={'Value': assign_instances_ipv6}) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't modify subnet attribute") + + +def disassociate_ipv6_cidr(conn, module, subnet, start_time): + if subnet.get('assign_ipv6_address_on_creation'): + ensure_assign_ipv6_on_create(conn, module, subnet, False, False, start_time) + + try: + conn.disassociate_subnet_cidr_block(AssociationId=subnet['ipv6_association_id']) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't disassociate ipv6 cidr block id {0} from subnet {1}" + .format(subnet['ipv6_association_id'], subnet['id'])) + + # Wait for cidr block to be disassociated + if module.params['wait']: + filters = 
ansible_dict_to_boto3_filter_list( + {'ipv6-cidr-block-association.state': ['disassociated'], + 'vpc-id': subnet['vpc_id']} + ) + handle_waiter(conn, module, 'subnet_exists', + {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time) + + +def ensure_ipv6_cidr_block(conn, module, subnet, ipv6_cidr, check_mode, start_time): + wait = module.params['wait'] + changed = False + + if subnet['ipv6_association_id'] and not ipv6_cidr: + if not check_mode: + disassociate_ipv6_cidr(conn, module, subnet, start_time) + changed = True + + if ipv6_cidr: + filters = ansible_dict_to_boto3_filter_list({'ipv6-cidr-block-association.ipv6-cidr-block': ipv6_cidr, + 'vpc-id': subnet['vpc_id']}) + + try: + check_subnets = get_subnet_info(describe_subnets_with_backoff(conn, Filters=filters)) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't get subnet info") + + if check_subnets and check_subnets[0]['ipv6_cidr_block']: + module.fail_json(msg="The IPv6 CIDR '{0}' conflicts with another subnet".format(ipv6_cidr)) + + if subnet['ipv6_association_id']: + if not check_mode: + disassociate_ipv6_cidr(conn, module, subnet, start_time) + changed = True + + try: + if not check_mode: + associate_resp = conn.associate_subnet_cidr_block(SubnetId=subnet['id'], Ipv6CidrBlock=ipv6_cidr) + changed = True + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't associate ipv6 cidr {0} to {1}".format(ipv6_cidr, subnet['id'])) + else: + if not check_mode and wait: + filters = ansible_dict_to_boto3_filter_list( + {'ipv6-cidr-block-association.state': ['associated'], + 'vpc-id': subnet['vpc_id']} + ) + handle_waiter(conn, module, 'subnet_exists', + {'SubnetIds': [subnet['id']], 'Filters': filters}, start_time) + + if associate_resp.get('Ipv6CidrBlockAssociation', {}).get('AssociationId'): + subnet['ipv6_association_id'] = associate_resp['Ipv6CidrBlockAssociation']['AssociationId'] + subnet['ipv6_cidr_block'] = associate_resp['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'] + if subnet['ipv6_cidr_block_association_set']: + subnet['ipv6_cidr_block_association_set'][0] = camel_dict_to_snake_dict(associate_resp['Ipv6CidrBlockAssociation']) + else: + subnet['ipv6_cidr_block_association_set'].append(camel_dict_to_snake_dict(associate_resp['Ipv6CidrBlockAssociation'])) + + return changed + + +def get_matching_subnet(conn, module, vpc_id, cidr): + filters = ansible_dict_to_boto3_filter_list({'vpc-id': vpc_id, 'cidr-block': cidr}) + try: + subnets = get_subnet_info(describe_subnets_with_backoff(conn, Filters=filters)) + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't get matching subnet") + + if subnets: + return subnets[0] + + return None + + +def ensure_subnet_present(conn, module): + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + changed = False + + # Initialize start so max time does not exceed the specified wait_timeout for multiple operations + start_time = time.time() + + if subnet is None: + if not module.check_mode: + subnet = create_subnet(conn, module, module.params['vpc_id'], module.params['cidr'], + ipv6_cidr=module.params['ipv6_cidr'], az=module.params['az'], start_time=start_time) + changed = True + # Subnet will be None when check_mode is true + if subnet is None: + return { + 'changed': changed, + 'subnet': {} + } + if module.params['wait']: + handle_waiter(conn, module, 
'subnet_exists', {'SubnetIds': [subnet['id']]}, start_time) + + if module.params['ipv6_cidr'] != subnet.get('ipv6_cidr_block'): + if ensure_ipv6_cidr_block(conn, module, subnet, module.params['ipv6_cidr'], module.check_mode, start_time): + changed = True + + if module.params['map_public'] != subnet['map_public_ip_on_launch']: + ensure_map_public(conn, module, subnet, module.params['map_public'], module.check_mode, start_time) + changed = True + + if module.params['assign_instances_ipv6'] != subnet.get('assign_ipv6_address_on_creation'): + ensure_assign_ipv6_on_create(conn, module, subnet, module.params['assign_instances_ipv6'], module.check_mode, start_time) + changed = True + + if module.params['tags'] != subnet['tags']: + stringified_tags_dict = dict((to_text(k), to_text(v)) for k, v in module.params['tags'].items()) + if ensure_tags(conn, module, subnet, stringified_tags_dict, module.params['purge_tags'], start_time): + changed = True + + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + if not module.check_mode and module.params['wait']: + # GET calls are not monotonic for map_public_ip_on_launch and assign_ipv6_address_on_creation + # so we only wait for those if necessary just before returning the subnet + subnet = ensure_final_subnet(conn, module, subnet, start_time) + + return { + 'changed': changed, + 'subnet': subnet + } + + +def ensure_final_subnet(conn, module, subnet, start_time): + for rewait in range(0, 30): + map_public_correct = False + assign_ipv6_correct = False + + if module.params['map_public'] == subnet['map_public_ip_on_launch']: + map_public_correct = True + else: + if module.params['map_public']: + handle_waiter(conn, module, 'subnet_has_map_public', {'SubnetIds': [subnet['id']]}, start_time) + else: + handle_waiter(conn, module, 'subnet_no_map_public', {'SubnetIds': [subnet['id']]}, start_time) + + if module.params['assign_instances_ipv6'] == subnet.get('assign_ipv6_address_on_creation'): + assign_ipv6_correct = True + else: + if module.params['assign_instances_ipv6']: + handle_waiter(conn, module, 'subnet_has_assign_ipv6', {'SubnetIds': [subnet['id']]}, start_time) + else: + handle_waiter(conn, module, 'subnet_no_assign_ipv6', {'SubnetIds': [subnet['id']]}, start_time) + + if map_public_correct and assign_ipv6_correct: + break + + time.sleep(5) + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + + return subnet + + +def ensure_subnet_absent(conn, module): + subnet = get_matching_subnet(conn, module, module.params['vpc_id'], module.params['cidr']) + if subnet is None: + return {'changed': False} + + try: + if not module.check_mode: + conn.delete_subnet(SubnetId=subnet['id']) + if module.params['wait']: + handle_waiter(conn, module, 'subnet_deleted', {'SubnetIds': [subnet['id']]}, time.time()) + return {'changed': True} + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + module.fail_json_aws(e, msg="Couldn't delete subnet") + + +def main(): + argument_spec = dict( + az=dict(default=None, required=False), + cidr=dict(required=True), + ipv6_cidr=dict(default='', required=False), + state=dict(default='present', choices=['present', 'absent']), + tags=dict(default={}, required=False, type='dict', aliases=['resource_tags']), + vpc_id=dict(required=True), + map_public=dict(default=False, required=False, type='bool'), + assign_instances_ipv6=dict(default=False, required=False, type='bool'), + wait=dict(type='bool', default=True), + 
wait_timeout=dict(type='int', default=300, required=False), + purge_tags=dict(default=True, type='bool') + ) + + required_if = [('assign_instances_ipv6', True, ['ipv6_cidr'])] + + module = AnsibleAWSModule(argument_spec=argument_spec, supports_check_mode=True, required_if=required_if) + + if module.params.get('assign_instances_ipv6') and not module.params.get('ipv6_cidr'): + module.fail_json(msg="assign_instances_ipv6 is True but ipv6_cidr is None or an empty string") + + if not module.botocore_at_least("1.7.0"): + module.warn("botocore >= 1.7.0 is required to use wait_timeout for custom wait times") + + connection = module.client('ec2') + + state = module.params.get('state') + + try: + if state == 'present': + result = ensure_subnet_present(connection, module) + elif state == 'absent': + result = ensure_subnet_absent(connection, module) + except botocore.exceptions.ClientError as e: + module.fail_json_aws(e) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/flatpak_remote.py b/test/support/integration/plugins/modules/flatpak_remote.py new file mode 100644 index 00000000..db208f1b --- /dev/null +++ b/test/support/integration/plugins/modules/flatpak_remote.py @@ -0,0 +1,243 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017 John Kwiatkoski (@JayKayy) <jkwiat40@gmail.com> +# Copyright: (c) 2018 Alexander Bethke (@oolongbrothers) <oolongbrothers@gmx.net> +# Copyright: (c) 2017 Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +# ATTENTION CONTRIBUTORS! +# +# TL;DR: Run this module's integration tests manually before opening a pull request +# +# Long explanation: +# The integration tests for this module are currently NOT run on the Ansible project's continuous +# delivery pipeline. So please: When you make changes to this module, make sure that you run the +# included integration tests manually for both Python 2 and Python 3: +# +# Python 2: +# ansible-test integration -v --docker fedora28 --docker-privileged --allow-unsupported --python 2.7 flatpak_remote +# Python 3: +# ansible-test integration -v --docker fedora28 --docker-privileged --allow-unsupported --python 3.6 flatpak_remote +# +# Because of external dependencies, the current integration tests are somewhat too slow and brittle +# to be included right now. I have plans to rewrite the integration tests based on a local flatpak +# repository so that they can be included into the normal CI pipeline. +# //oolongbrothers + + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: flatpak_remote +version_added: '2.6' +short_description: Manage flatpak repository remotes +description: +- Allows users to add or remove flatpak remotes. +- The flatpak remotes concept is comparable to what is called repositories in other packaging + formats. +- Currently, remote addition is only supported via I(flatpakrepo) file URLs. +- Existing remotes will not be updated. +- See the M(flatpak) module for managing flatpaks. +author: +- John Kwiatkoski (@JayKayy) +- Alexander Bethke (@oolongbrothers) +requirements: +- flatpak +options: + executable: + description: + - The path to the C(flatpak) executable to use. + - By default, this module looks for the C(flatpak) executable on the path. 
+ default: flatpak + flatpakrepo_url: + description: + - The URL to the I(flatpakrepo) file representing the repository remote to add. + - When used with I(state=present), the flatpak remote specified under the I(flatpakrepo_url) + is added using the specified installation C(method). + - When used with I(state=absent), this is not required. + - Required when I(state=present). + method: + description: + - The installation method to use. + - Defines if the I(flatpak) is supposed to be installed globally for the whole C(system) + or only for the current C(user). + choices: [ system, user ] + default: system + name: + description: + - The desired name for the flatpak remote to be registered under on the managed host. + - When used with I(state=present), the remote will be added to the managed host under + the specified I(name). + - When used with I(state=absent) the remote with that name will be removed. + required: true + state: + description: + - Indicates the desired package state. + choices: [ absent, present ] + default: present +''' + +EXAMPLES = r''' +- name: Add the Gnome flatpak remote to the system installation + flatpak_remote: + name: gnome + state: present + flatpakrepo_url: https://sdk.gnome.org/gnome-apps.flatpakrepo + +- name: Add the flathub flatpak repository remote to the user installation + flatpak_remote: + name: flathub + state: present + flatpakrepo_url: https://dl.flathub.org/repo/flathub.flatpakrepo + method: user + +- name: Remove the Gnome flatpak remote from the user installation + flatpak_remote: + name: gnome + state: absent + method: user + +- name: Remove the flathub remote from the system installation + flatpak_remote: + name: flathub + state: absent +''' + +RETURN = r''' +command: + description: The exact flatpak command that was executed + returned: When a flatpak command has been executed + type: str + sample: "/usr/bin/flatpak remote-add --system flatpak-test https://dl.flathub.org/repo/flathub.flatpakrepo" +msg: + description: Module error message + returned: failure + type: str + sample: "Executable '/usr/local/bin/flatpak' was not found on the system." 
+rc: + description: Return code from flatpak binary + returned: When a flatpak command has been executed + type: int + sample: 0 +stderr: + description: Error output from flatpak binary + returned: When a flatpak command has been executed + type: str + sample: "error: GPG verification enabled, but no summary found (check that the configured URL in remote config is correct)\n" +stdout: + description: Output from flatpak binary + returned: When a flatpak command has been executed + type: str + sample: "flathub\tFlathub\thttps://dl.flathub.org/repo/\t1\t\n" +''' + +import subprocess +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_bytes, to_native + + +def add_remote(module, binary, name, flatpakrepo_url, method): + """Add a new remote.""" + global result + command = "{0} remote-add --{1} {2} {3}".format( + binary, method, name, flatpakrepo_url) + _flatpak_command(module, module.check_mode, command) + result['changed'] = True + + +def remove_remote(module, binary, name, method): + """Remove an existing remote.""" + global result + command = "{0} remote-delete --{1} --force {2} ".format( + binary, method, name) + _flatpak_command(module, module.check_mode, command) + result['changed'] = True + + +def remote_exists(module, binary, name, method): + """Check if the remote exists.""" + command = "{0} remote-list -d --{1}".format(binary, method) + # The query operation for the remote needs to be run even in check mode + output = _flatpak_command(module, False, command) + for line in output.splitlines(): + listed_remote = line.split() + if len(listed_remote) == 0: + continue + if listed_remote[0] == to_native(name): + return True + return False + + +def _flatpak_command(module, noop, command): + global result + if noop: + result['rc'] = 0 + result['command'] = command + return "" + + process = subprocess.Popen( + command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout_data, stderr_data = process.communicate() + result['rc'] = process.returncode + result['command'] = command + result['stdout'] = stdout_data + result['stderr'] = stderr_data + if result['rc'] != 0: + module.fail_json(msg="Failed to execute flatpak command", **result) + return to_native(stdout_data) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(type='str', required=True), + flatpakrepo_url=dict(type='str'), + method=dict(type='str', default='system', + choices=['user', 'system']), + state=dict(type='str', default="present", + choices=['absent', 'present']), + executable=dict(type='str', default="flatpak") + ), + # This module supports check mode + supports_check_mode=True, + ) + + name = module.params['name'] + flatpakrepo_url = module.params['flatpakrepo_url'] + method = module.params['method'] + state = module.params['state'] + executable = module.params['executable'] + binary = module.get_bin_path(executable, None) + + if flatpakrepo_url is None: + flatpakrepo_url = '' + + global result + result = dict( + changed=False + ) + + # If the binary was not found, fail the operation + if not binary: + module.fail_json(msg="Executable '%s' was not found on the system." 
% executable, **result) + + remote_already_exists = remote_exists(module, binary, to_bytes(name), method) + + if state == 'present' and not remote_already_exists: + add_remote(module, binary, name, flatpakrepo_url, method) + elif state == 'absent' and remote_already_exists: + remove_remote(module, binary, name, method) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/htpasswd.py b/test/support/integration/plugins/modules/htpasswd.py new file mode 100644 index 00000000..ad12b0c0 --- /dev/null +++ b/test/support/integration/plugins/modules/htpasswd.py @@ -0,0 +1,275 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, Nimbis Services, Inc. +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = """ +module: htpasswd +version_added: "1.3" +short_description: manage user files for basic authentication +description: + - Add and remove username/password entries in a password file using htpasswd. + - This is used by web servers such as Apache and Nginx for basic authentication. +options: + path: + required: true + aliases: [ dest, destfile ] + description: + - Path to the file that contains the usernames and passwords + name: + required: true + aliases: [ username ] + description: + - User name to add or remove + password: + required: false + description: + - Password associated with user. + - Must be specified if user does not exist yet. + crypt_scheme: + required: false + choices: ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"] + default: "apr_md5_crypt" + description: + - Encryption scheme to be used. As well as the four choices listed + here, you can also use any other hash supported by passlib, such as + md5_crypt and sha256_crypt, which are linux passwd hashes. If you + do so the password file will not be compatible with Apache or Nginx + state: + required: false + choices: [ present, absent ] + default: "present" + description: + - Whether the user entry should be present or not + create: + required: false + type: bool + default: "yes" + description: + - Used with C(state=present). If specified, the file will be created + if it does not already exist. If set to "no", will fail if the + file does not exist +notes: + - "This module depends on the I(passlib) Python library, which needs to be installed on all target systems." + - "On Debian, Ubuntu, or Fedora: install I(python-passlib)." + - "On RHEL or CentOS: Enable EPEL, then install I(python-passlib)." 
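  - "Alternatively, passlib can be installed from PyPI, for example with C(pip install passlib)."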
+requirements: [ passlib>=1.6 ] +author: "Ansible Core Team" +extends_documentation_fragment: files +""" + +EXAMPLES = """ +# Add a user to a password file and ensure permissions are set +- htpasswd: + path: /etc/nginx/passwdfile + name: janedoe + password: '9s36?;fyNp' + owner: root + group: www-data + mode: 0640 + +# Remove a user from a password file +- htpasswd: + path: /etc/apache2/passwdfile + name: foobar + state: absent + +# Add a user to a password file suitable for use by libpam-pwdfile +- htpasswd: + path: /etc/mail/passwords + name: alex + password: oedu2eGh + crypt_scheme: md5_crypt +""" + + +import os +import tempfile +import traceback +from distutils.version import LooseVersion +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native + +PASSLIB_IMP_ERR = None +try: + from passlib.apache import HtpasswdFile, htpasswd_context + from passlib.context import CryptContext + import passlib +except ImportError: + PASSLIB_IMP_ERR = traceback.format_exc() + passlib_installed = False +else: + passlib_installed = True + +apache_hashes = ["apr_md5_crypt", "des_crypt", "ldap_sha1", "plaintext"] + + +def create_missing_directories(dest): + destpath = os.path.dirname(dest) + if not os.path.exists(destpath): + os.makedirs(destpath) + + +def present(dest, username, password, crypt_scheme, create, check_mode): + """ Ensures user is present + + Returns (msg, changed) """ + if crypt_scheme in apache_hashes: + context = htpasswd_context + else: + context = CryptContext(schemes=[crypt_scheme] + apache_hashes) + if not os.path.exists(dest): + if not create: + raise ValueError('Destination %s does not exist' % dest) + if check_mode: + return ("Create %s" % dest, True) + create_missing_directories(dest) + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=True, default_scheme=crypt_scheme, context=context) + else: + ht = HtpasswdFile(dest, autoload=False, default=crypt_scheme, context=context) + if getattr(ht, 'set_password', None): + ht.set_password(username, password) + else: + ht.update(username, password) + ht.save() + return ("Created %s and added %s" % (dest, username), True) + else: + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=False, default_scheme=crypt_scheme, context=context) + else: + ht = HtpasswdFile(dest, default=crypt_scheme, context=context) + + found = None + if getattr(ht, 'check_password', None): + found = ht.check_password(username, password) + else: + found = ht.verify(username, password) + + if found: + return ("%s already present" % username, False) + else: + if not check_mode: + if getattr(ht, 'set_password', None): + ht.set_password(username, password) + else: + ht.update(username, password) + ht.save() + return ("Add/update %s" % username, True) + + +def absent(dest, username, check_mode): + """ Ensures user is absent + + Returns (msg, changed) """ + if LooseVersion(passlib.__version__) >= LooseVersion('1.6'): + ht = HtpasswdFile(dest, new=False) + else: + ht = HtpasswdFile(dest) + + if username not in ht.users(): + return ("%s not present" % username, False) + else: + if not check_mode: + ht.delete(username) + ht.save() + return ("Remove %s" % username, True) + + +def check_file_attrs(module, changed, message): + + file_args = module.load_file_common_arguments(module.params) + if module.set_fs_attributes_if_different(file_args, False): + + if changed: + message += " and " + changed = True + message += "ownership, perms 
or SE linux context changed" + + return message, changed + + +def main(): + arg_spec = dict( + path=dict(required=True, aliases=["dest", "destfile"]), + name=dict(required=True, aliases=["username"]), + password=dict(required=False, default=None, no_log=True), + crypt_scheme=dict(required=False, default="apr_md5_crypt"), + state=dict(required=False, default="present"), + create=dict(type='bool', default='yes'), + + ) + module = AnsibleModule(argument_spec=arg_spec, + add_file_common_args=True, + supports_check_mode=True) + + path = module.params['path'] + username = module.params['name'] + password = module.params['password'] + crypt_scheme = module.params['crypt_scheme'] + state = module.params['state'] + create = module.params['create'] + check_mode = module.check_mode + + if not passlib_installed: + module.fail_json(msg=missing_required_lib("passlib"), exception=PASSLIB_IMP_ERR) + + # Check file for blank lines in effort to avoid "need more than 1 value to unpack" error. + try: + f = open(path, "r") + except IOError: + # No preexisting file to remove blank lines from + f = None + else: + try: + lines = f.readlines() + finally: + f.close() + + # If the file gets edited, it returns true, so only edit the file if it has blank lines + strip = False + for line in lines: + if not line.strip(): + strip = True + break + + if strip: + # If check mode, create a temporary file + if check_mode: + temp = tempfile.NamedTemporaryFile() + path = temp.name + f = open(path, "w") + try: + [f.write(line) for line in lines if line.strip()] + finally: + f.close() + + try: + if state == 'present': + (msg, changed) = present(path, username, password, crypt_scheme, create, check_mode) + elif state == 'absent': + if not os.path.exists(path): + module.exit_json(msg="%s not present" % username, + warnings="%s does not exist" % path, changed=False) + (msg, changed) = absent(path, username, check_mode) + else: + module.fail_json(msg="Invalid state: %s" % state) + + check_file_attrs(module, changed, msg) + module.exit_json(msg=msg, changed=changed) + except Exception as e: + module.fail_json(msg=to_native(e)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/locale_gen.py b/test/support/integration/plugins/modules/locale_gen.py new file mode 100644 index 00000000..4968b834 --- /dev/null +++ b/test/support/integration/plugins/modules/locale_gen.py @@ -0,0 +1,237 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +--- +module: locale_gen +short_description: Creates or removes locales +description: + - Manages locales by editing /etc/locale.gen and invoking locale-gen. +version_added: "1.6" +author: +- Augustus Kling (@AugustusKling) +options: + name: + description: + - Name and encoding of the locale, such as "en_GB.UTF-8". + required: true + state: + description: + - Whether the locale shall be present. 
+ choices: [ absent, present ] + default: present +''' + +EXAMPLES = ''' +- name: Ensure a locale exists + locale_gen: + name: de_CH.UTF-8 + state: present +''' + +import os +import re +from subprocess import Popen, PIPE, call + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_native + +LOCALE_NORMALIZATION = { + ".utf8": ".UTF-8", + ".eucjp": ".EUC-JP", + ".iso885915": ".ISO-8859-15", + ".cp1251": ".CP1251", + ".koi8r": ".KOI8-R", + ".armscii8": ".ARMSCII-8", + ".euckr": ".EUC-KR", + ".gbk": ".GBK", + ".gb18030": ".GB18030", + ".euctw": ".EUC-TW", +} + + +# =========================================== +# location module specific support methods. +# + +def is_available(name, ubuntuMode): + """Check if the given locale is available on the system. This is done by + checking either : + * if the locale is present in /etc/locales.gen + * or if the locale is present in /usr/share/i18n/SUPPORTED""" + if ubuntuMode: + __regexp = r'^(?P<locale>\S+_\S+) (?P<charset>\S+)\s*$' + __locales_available = '/usr/share/i18n/SUPPORTED' + else: + __regexp = r'^#{0,1}\s*(?P<locale>\S+_\S+) (?P<charset>\S+)\s*$' + __locales_available = '/etc/locale.gen' + + re_compiled = re.compile(__regexp) + fd = open(__locales_available, 'r') + for line in fd: + result = re_compiled.match(line) + if result and result.group('locale') == name: + return True + fd.close() + return False + + +def is_present(name): + """Checks if the given locale is currently installed.""" + output = Popen(["locale", "-a"], stdout=PIPE).communicate()[0] + output = to_native(output) + return any(fix_case(name) == fix_case(line) for line in output.splitlines()) + + +def fix_case(name): + """locale -a might return the encoding in either lower or upper case. + Passing through this function makes them uniform for comparisons.""" + for s, r in LOCALE_NORMALIZATION.items(): + name = name.replace(s, r) + return name + + +def replace_line(existing_line, new_line): + """Replaces lines in /etc/locale.gen""" + try: + f = open("/etc/locale.gen", "r") + lines = [line.replace(existing_line, new_line) for line in f] + finally: + f.close() + try: + f = open("/etc/locale.gen", "w") + f.write("".join(lines)) + finally: + f.close() + + +def set_locale(name, enabled=True): + """ Sets the state of the locale. Defaults to enabled. """ + search_string = r'#{0,1}\s*%s (?P<charset>.+)' % name + if enabled: + new_string = r'%s \g<charset>' % (name) + else: + new_string = r'# %s \g<charset>' % (name) + try: + f = open("/etc/locale.gen", "r") + lines = [re.sub(search_string, new_string, line) for line in f] + finally: + f.close() + try: + f = open("/etc/locale.gen", "w") + f.write("".join(lines)) + finally: + f.close() + + +def apply_change(targetState, name): + """Create or remove locale. + + Keyword arguments: + targetState -- Desired state, either present or absent. + name -- Name including encoding such as de_CH.UTF-8. + """ + if targetState == "present": + # Create locale. + set_locale(name, enabled=True) + else: + # Delete locale. + set_locale(name, enabled=False) + + localeGenExitValue = call("locale-gen") + if localeGenExitValue != 0: + raise EnvironmentError(localeGenExitValue, "locale.gen failed to execute, it returned " + str(localeGenExitValue)) + + +def apply_change_ubuntu(targetState, name): + """Create or remove locale. + + Keyword arguments: + targetState -- Desired state, either present or absent. + name -- Name including encoding such as de_CH.UTF-8. + """ + if targetState == "present": + # Create locale. 
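        # For reference, entries in /var/lib/locales/supported.d/local are
        # "<locale> <charset>" pairs, e.g.:
        #   de_CH.UTF-8 UTF-8
        # (the removal branch below parses them with line.split(' ')).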
+ # Ubuntu's patched locale-gen automatically adds the new locale to /var/lib/locales/supported.d/local + localeGenExitValue = call(["locale-gen", name]) + else: + # Delete locale involves discarding the locale from /var/lib/locales/supported.d/local and regenerating all locales. + try: + f = open("/var/lib/locales/supported.d/local", "r") + content = f.readlines() + finally: + f.close() + try: + f = open("/var/lib/locales/supported.d/local", "w") + for line in content: + locale, charset = line.split(' ') + if locale != name: + f.write(line) + finally: + f.close() + # Purge locales and regenerate. + # Please provide a patch if you know how to avoid regenerating the locales to keep! + localeGenExitValue = call(["locale-gen", "--purge"]) + + if localeGenExitValue != 0: + raise EnvironmentError(localeGenExitValue, "locale.gen failed to execute, it returned " + str(localeGenExitValue)) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(type='str', required=True), + state=dict(type='str', default='present', choices=['absent', 'present']), + ), + supports_check_mode=True, + ) + + name = module.params['name'] + state = module.params['state'] + + if not os.path.exists("/etc/locale.gen"): + if os.path.exists("/var/lib/locales/supported.d/"): + # Ubuntu created its own system to manage locales. + ubuntuMode = True + else: + module.fail_json(msg="/etc/locale.gen and /var/lib/locales/supported.d/local are missing. Is the package \"locales\" installed?") + else: + # We found the common way to manage locales. + ubuntuMode = False + + if not is_available(name, ubuntuMode): + module.fail_json(msg="The locale you've entered is not available " + "on your system.") + + if is_present(name): + prev_state = "present" + else: + prev_state = "absent" + changed = (prev_state != state) + + if module.check_mode: + module.exit_json(changed=changed) + else: + if changed: + try: + if ubuntuMode is False: + apply_change(state, name) + else: + apply_change_ubuntu(state, name) + except EnvironmentError as e: + module.fail_json(msg=to_native(e), exitValue=e.errno) + + module.exit_json(name=name, changed=changed, msg="OK") + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/lvg.py b/test/support/integration/plugins/modules/lvg.py new file mode 100644 index 00000000..e2035f68 --- /dev/null +++ b/test/support/integration/plugins/modules/lvg.py @@ -0,0 +1,295 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2013, Alexander Bulimov <lazywolf0@gmail.com> +# Based on lvol module by Jeroen Hoekx <jeroen.hoekx@dsquare.be> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +author: +- Alexander Bulimov (@abulimov) +module: lvg +short_description: Configure LVM volume groups +description: + - This module creates, removes or resizes volume groups. +version_added: "1.1" +options: + vg: + description: + - The name of the volume group. + type: str + required: true + pvs: + description: + - List of comma-separated devices to use as physical devices in this volume group. + - Required when creating or resizing volume group. + - The module will take care of running pvcreate if needed. + type: list + pesize: + description: + - "The size of the physical extent. 
I(pesize) must be a power of 2 of at least 1 sector + (where the sector size is the largest sector size of the PVs currently used in the VG), + or at least 128KiB." + - Since Ansible 2.6, pesize can be optionally suffixed by a UNIT (k/K/m/M/g/G), default unit is megabyte. + type: str + default: "4" + pv_options: + description: + - Additional options to pass to C(pvcreate) when creating the volume group. + type: str + version_added: "2.4" + vg_options: + description: + - Additional options to pass to C(vgcreate) when creating the volume group. + type: str + version_added: "1.6" + state: + description: + - Control if the volume group exists. + type: str + choices: [ absent, present ] + default: present + force: + description: + - If C(yes), allows to remove volume group with logical volumes. + type: bool + default: no +seealso: +- module: filesystem +- module: lvol +- module: parted +notes: + - This module does not modify PE size for already present volume group. +''' + +EXAMPLES = r''' +- name: Create a volume group on top of /dev/sda1 with physical extent size = 32MB + lvg: + vg: vg.services + pvs: /dev/sda1 + pesize: 32 + +- name: Create a volume group on top of /dev/sdb with physical extent size = 128KiB + lvg: + vg: vg.services + pvs: /dev/sdb + pesize: 128K + +# If, for example, we already have VG vg.services on top of /dev/sdb1, +# this VG will be extended by /dev/sdc5. Or if vg.services was created on +# top of /dev/sda5, we first extend it with /dev/sdb1 and /dev/sdc5, +# and then reduce by /dev/sda5. +- name: Create or resize a volume group on top of /dev/sdb1 and /dev/sdc5. + lvg: + vg: vg.services + pvs: /dev/sdb1,/dev/sdc5 + +- name: Remove a volume group with name vg.services + lvg: + vg: vg.services + state: absent +''' + +import itertools +import os + +from ansible.module_utils.basic import AnsibleModule + + +def parse_vgs(data): + vgs = [] + for line in data.splitlines(): + parts = line.strip().split(';') + vgs.append({ + 'name': parts[0], + 'pv_count': int(parts[1]), + 'lv_count': int(parts[2]), + }) + return vgs + + +def find_mapper_device_name(module, dm_device): + dmsetup_cmd = module.get_bin_path('dmsetup', True) + mapper_prefix = '/dev/mapper/' + rc, dm_name, err = module.run_command("%s info -C --noheadings -o name %s" % (dmsetup_cmd, dm_device)) + if rc != 0: + module.fail_json(msg="Failed executing dmsetup command.", rc=rc, err=err) + mapper_device = mapper_prefix + dm_name.rstrip() + return mapper_device + + +def parse_pvs(module, data): + pvs = [] + dm_prefix = '/dev/dm-' + for line in data.splitlines(): + parts = line.strip().split(';') + if parts[0].startswith(dm_prefix): + parts[0] = find_mapper_device_name(module, parts[0]) + pvs.append({ + 'name': parts[0], + 'vg_name': parts[1], + }) + return pvs + + +def main(): + module = AnsibleModule( + argument_spec=dict( + vg=dict(type='str', required=True), + pvs=dict(type='list'), + pesize=dict(type='str', default='4'), + pv_options=dict(type='str', default=''), + vg_options=dict(type='str', default=''), + state=dict(type='str', default='present', choices=['absent', 'present']), + force=dict(type='bool', default=False), + ), + supports_check_mode=True, + ) + + vg = module.params['vg'] + state = module.params['state'] + force = module.boolean(module.params['force']) + pesize = module.params['pesize'] + pvoptions = module.params['pv_options'].split() + vgoptions = module.params['vg_options'].split() + + dev_list = [] + if module.params['pvs']: + dev_list = list(module.params['pvs']) + elif state == 'present': + 
module.fail_json(msg="No physical volumes given.") + + # LVM always uses real paths not symlinks so replace symlinks with actual path + for idx, dev in enumerate(dev_list): + dev_list[idx] = os.path.realpath(dev) + + if state == 'present': + # check given devices + for test_dev in dev_list: + if not os.path.exists(test_dev): + module.fail_json(msg="Device %s not found." % test_dev) + + # get pv list + pvs_cmd = module.get_bin_path('pvs', True) + if dev_list: + pvs_filter_pv_name = ' || '.join( + 'pv_name = {0}'.format(x) + for x in itertools.chain(dev_list, module.params['pvs']) + ) + pvs_filter_vg_name = 'vg_name = {0}'.format(vg) + pvs_filter = "--select '{0} || {1}' ".format(pvs_filter_pv_name, pvs_filter_vg_name) + else: + pvs_filter = '' + rc, current_pvs, err = module.run_command("%s --noheadings -o pv_name,vg_name --separator ';' %s" % (pvs_cmd, pvs_filter)) + if rc != 0: + module.fail_json(msg="Failed executing pvs command.", rc=rc, err=err) + + # check pv for devices + pvs = parse_pvs(module, current_pvs) + used_pvs = [pv for pv in pvs if pv['name'] in dev_list and pv['vg_name'] and pv['vg_name'] != vg] + if used_pvs: + module.fail_json(msg="Device %s is already in %s volume group." % (used_pvs[0]['name'], used_pvs[0]['vg_name'])) + + vgs_cmd = module.get_bin_path('vgs', True) + rc, current_vgs, err = module.run_command("%s --noheadings -o vg_name,pv_count,lv_count --separator ';'" % vgs_cmd) + + if rc != 0: + module.fail_json(msg="Failed executing vgs command.", rc=rc, err=err) + + changed = False + + vgs = parse_vgs(current_vgs) + + for test_vg in vgs: + if test_vg['name'] == vg: + this_vg = test_vg + break + else: + this_vg = None + + if this_vg is None: + if state == 'present': + # create VG + if module.check_mode: + changed = True + else: + # create PV + pvcreate_cmd = module.get_bin_path('pvcreate', True) + for current_dev in dev_list: + rc, _, err = module.run_command([pvcreate_cmd] + pvoptions + ['-f', str(current_dev)]) + if rc == 0: + changed = True + else: + module.fail_json(msg="Creating physical volume '%s' failed" % current_dev, rc=rc, err=err) + vgcreate_cmd = module.get_bin_path('vgcreate') + rc, _, err = module.run_command([vgcreate_cmd] + vgoptions + ['-s', pesize, vg] + dev_list) + if rc == 0: + changed = True + else: + module.fail_json(msg="Creating volume group '%s' failed" % vg, rc=rc, err=err) + else: + if state == 'absent': + if module.check_mode: + module.exit_json(changed=True) + else: + if this_vg['lv_count'] == 0 or force: + # remove VG + vgremove_cmd = module.get_bin_path('vgremove', True) + rc, _, err = module.run_command("%s --force %s" % (vgremove_cmd, vg)) + if rc == 0: + module.exit_json(changed=True) + else: + module.fail_json(msg="Failed to remove volume group %s" % (vg), rc=rc, err=err) + else: + module.fail_json(msg="Refuse to remove non-empty volume group %s without force=yes" % (vg)) + + # resize VG + current_devs = [os.path.realpath(pv['name']) for pv in pvs if pv['vg_name'] == vg] + devs_to_remove = list(set(current_devs) - set(dev_list)) + devs_to_add = list(set(dev_list) - set(current_devs)) + + if devs_to_add or devs_to_remove: + if module.check_mode: + changed = True + else: + if devs_to_add: + devs_to_add_string = ' '.join(devs_to_add) + # create PV + pvcreate_cmd = module.get_bin_path('pvcreate', True) + for current_dev in devs_to_add: + rc, _, err = module.run_command([pvcreate_cmd] + pvoptions + ['-f', str(current_dev)]) + if rc == 0: + changed = True + else: + module.fail_json(msg="Creating physical volume '%s' failed" % 
current_dev, rc=rc, err=err) + # add PV to our VG + vgextend_cmd = module.get_bin_path('vgextend', True) + rc, _, err = module.run_command("%s %s %s" % (vgextend_cmd, vg, devs_to_add_string)) + if rc == 0: + changed = True + else: + module.fail_json(msg="Unable to extend %s by %s." % (vg, devs_to_add_string), rc=rc, err=err) + + # remove some PV from our VG + if devs_to_remove: + devs_to_remove_string = ' '.join(devs_to_remove) + vgreduce_cmd = module.get_bin_path('vgreduce', True) + rc, _, err = module.run_command("%s --force %s %s" % (vgreduce_cmd, vg, devs_to_remove_string)) + if rc == 0: + changed = True + else: + module.fail_json(msg="Unable to reduce %s by %s." % (vg, devs_to_remove_string), rc=rc, err=err) + + module.exit_json(changed=changed) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/mongodb_parameter.py b/test/support/integration/plugins/modules/mongodb_parameter.py new file mode 100644 index 00000000..05de42b2 --- /dev/null +++ b/test/support/integration/plugins/modules/mongodb_parameter.py @@ -0,0 +1,223 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2016, Loic Blot <loic.blot@unix-experience.fr> +# Sponsored by Infopro Digital. http://www.infopro-digital.com/ +# Sponsored by E.T.A.I. http://www.etai.fr/ +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = r''' +--- +module: mongodb_parameter +short_description: Change an administrative parameter on a MongoDB server +description: + - Change an administrative parameter on a MongoDB server. +version_added: "2.1" +options: + login_user: + description: + - The MongoDB username used to authenticate with. + type: str + login_password: + description: + - The login user's password used to authenticate with. + type: str + login_host: + description: + - The host running the database. + type: str + default: localhost + login_port: + description: + - The MongoDB port to connect to. + default: 27017 + type: int + login_database: + description: + - The database where login credentials are stored. + type: str + replica_set: + description: + - Replica set to connect to (automatically connects to primary for writes). + type: str + ssl: + description: + - Whether to use an SSL connection when connecting to the database. + type: bool + default: no + param: + description: + - MongoDB administrative parameter to modify. + type: str + required: true + value: + description: + - MongoDB administrative parameter value to set. + type: str + required: true + param_type: + description: + - Define the type of parameter value. + default: str + type: str + choices: [int, str] + +notes: + - Requires the pymongo Python package on the remote host, version 2.4.2+. + - This can be installed using pip or the OS package manager. 
+ - See also U(http://api.mongodb.org/python/current/installation.html) +requirements: [ "pymongo" ] +author: "Loic Blot (@nerzhul)" +''' + +EXAMPLES = r''' +- name: Set MongoDB syncdelay to 60 (this is an int) + mongodb_parameter: + param: syncdelay + value: 60 + param_type: int +''' + +RETURN = r''' +before: + description: value before modification + returned: success + type: str +after: + description: value after modification + returned: success + type: str +''' + +import os +import traceback + +try: + from pymongo.errors import ConnectionFailure + from pymongo.errors import OperationFailure + from pymongo import version as PyMongoVersion + from pymongo import MongoClient +except ImportError: + try: # for older PyMongo 2.2 + from pymongo import Connection as MongoClient + except ImportError: + pymongo_found = False + else: + pymongo_found = True +else: + pymongo_found = True + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.six.moves import configparser +from ansible.module_utils._text import to_native + + +# ========================================= +# MongoDB module specific support methods. +# + +def load_mongocnf(): + config = configparser.RawConfigParser() + mongocnf = os.path.expanduser('~/.mongodb.cnf') + + try: + config.readfp(open(mongocnf)) + creds = dict( + user=config.get('client', 'user'), + password=config.get('client', 'pass') + ) + except (configparser.NoOptionError, IOError): + return False + + return creds + + +# ========================================= +# Module execution. +# + +def main(): + module = AnsibleModule( + argument_spec=dict( + login_user=dict(default=None), + login_password=dict(default=None, no_log=True), + login_host=dict(default='localhost'), + login_port=dict(default=27017, type='int'), + login_database=dict(default=None), + replica_set=dict(default=None), + param=dict(required=True), + value=dict(required=True), + param_type=dict(default="str", choices=['str', 'int']), + ssl=dict(default=False, type='bool'), + ) + ) + + if not pymongo_found: + module.fail_json(msg=missing_required_lib('pymongo')) + + login_user = module.params['login_user'] + login_password = module.params['login_password'] + login_host = module.params['login_host'] + login_port = module.params['login_port'] + login_database = module.params['login_database'] + + replica_set = module.params['replica_set'] + ssl = module.params['ssl'] + + param = module.params['param'] + param_type = module.params['param_type'] + value = module.params['value'] + + # Verify parameter is coherent with specified type + try: + if param_type == 'int': + value = int(value) + except ValueError: + module.fail_json(msg="value '%s' is not %s" % (value, param_type)) + + try: + if replica_set: + client = MongoClient(login_host, int(login_port), replicaset=replica_set, ssl=ssl) + else: + client = MongoClient(login_host, int(login_port), ssl=ssl) + + if login_user is None and login_password is None: + mongocnf_creds = load_mongocnf() + if mongocnf_creds is not False: + login_user = mongocnf_creds['user'] + login_password = mongocnf_creds['password'] + elif login_password is None or login_user is None: + module.fail_json(msg='when supplying login arguments, both login_user and login_password must be provided') + + if login_user is not None and login_password is not None: + client.admin.authenticate(login_user, login_password, source=login_database) + + except ConnectionFailure as e: + module.fail_json(msg='unable to connect to database: %s' % to_native(e), 
exception=traceback.format_exc()) + + db = client.admin + + try: + after_value = db.command("setParameter", **{param: value}) + except OperationFailure as e: + module.fail_json(msg="unable to change parameter: %s" % to_native(e), exception=traceback.format_exc()) + + if "was" not in after_value: + module.exit_json(changed=True, msg="Unable to determine old value, assume it changed.") + else: + module.exit_json(changed=(value != after_value["was"]), before=after_value["was"], + after=value) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/mongodb_user.py b/test/support/integration/plugins/modules/mongodb_user.py new file mode 100644 index 00000000..362b3aa4 --- /dev/null +++ b/test/support/integration/plugins/modules/mongodb_user.py @@ -0,0 +1,474 @@ +#!/usr/bin/python + +# (c) 2012, Elliott Foster <elliott@fourkitchens.com> +# Sponsored by Four Kitchens http://fourkitchens.com. +# (c) 2014, Epic Games, Inc. +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: mongodb_user +short_description: Adds or removes a user from a MongoDB database +description: + - Adds or removes a user from a MongoDB database. +version_added: "1.1" +options: + login_user: + description: + - The MongoDB username used to authenticate with. + type: str + login_password: + description: + - The login user's password used to authenticate with. + type: str + login_host: + description: + - The host running the database. + default: localhost + type: str + login_port: + description: + - The MongoDB port to connect to. + default: '27017' + type: str + login_database: + version_added: "2.0" + description: + - The database where login credentials are stored. + type: str + replica_set: + version_added: "1.6" + description: + - Replica set to connect to (automatically connects to primary for writes). + type: str + database: + description: + - The name of the database to add/remove the user from. + required: true + type: str + aliases: [db] + name: + description: + - The name of the user to add or remove. + required: true + aliases: [user] + type: str + password: + description: + - The password to use for the user. + type: str + aliases: [pass] + ssl: + version_added: "1.8" + description: + - Whether to use an SSL connection when connecting to the database. + type: bool + ssl_cert_reqs: + version_added: "2.2" + description: + - Specifies whether a certificate is required from the other side of the connection, + and whether it will be validated if provided. + default: CERT_REQUIRED + choices: [CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED] + type: str + roles: + version_added: "1.3" + type: list + elements: raw + description: + - > + The database user roles valid values could either be one or more of the following strings: + 'read', 'readWrite', 'dbAdmin', 'userAdmin', 'clusterAdmin', 'readAnyDatabase', 'readWriteAnyDatabase', 'userAdminAnyDatabase', + 'dbAdminAnyDatabase' + - "Or the following dictionary '{ db: DATABASE_NAME, role: ROLE_NAME }'." + - "This param requires pymongo 2.5+. If it is a string, mongodb 2.4+ is also required. If it is a dictionary, mongo 2.6+ is required." + state: + description: + - The database user state. 
+ default: present + choices: [absent, present] + type: str + update_password: + default: always + choices: [always, on_create] + version_added: "2.1" + description: + - C(always) will update passwords if they differ. + - C(on_create) will only set the password for newly created users. + type: str + +notes: + - Requires the pymongo Python package on the remote host, version 2.4.2+. This + can be installed using pip or the OS package manager. @see http://api.mongodb.org/python/current/installation.html +requirements: [ "pymongo" ] +author: + - "Elliott Foster (@elliotttf)" + - "Julien Thebault (@Lujeni)" +''' + +EXAMPLES = ''' +- name: Create 'burgers' database user with name 'bob' and password '12345'. + mongodb_user: + database: burgers + name: bob + password: 12345 + state: present + +- name: Create a database user via SSL (MongoDB must be compiled with the SSL option and configured properly) + mongodb_user: + database: burgers + name: bob + password: 12345 + state: present + ssl: True + +- name: Delete 'burgers' database user with name 'bob'. + mongodb_user: + database: burgers + name: bob + state: absent + +- name: Define more users with various specific roles (if not defined, no roles is assigned, and the user will be added via pre mongo 2.2 style) + mongodb_user: + database: burgers + name: ben + password: 12345 + roles: read + state: present + +- name: Define roles + mongodb_user: + database: burgers + name: jim + password: 12345 + roles: readWrite,dbAdmin,userAdmin + state: present + +- name: Define roles + mongodb_user: + database: burgers + name: joe + password: 12345 + roles: readWriteAnyDatabase + state: present + +- name: Add a user to database in a replica set, the primary server is automatically discovered and written to + mongodb_user: + database: burgers + name: bob + replica_set: belcher + password: 12345 + roles: readWriteAnyDatabase + state: present + +# add a user 'oplog_reader' with read only access to the 'local' database on the replica_set 'belcher'. This is useful for oplog access (MONGO_OPLOG_URL). +# please notice the credentials must be added to the 'admin' database because the 'local' database is not synchronized and can't receive user credentials +# To login with such user, the connection string should be MONGO_OPLOG_URL="mongodb://oplog_reader:oplog_reader_password@server1,server2/local?authSource=admin" +# This syntax requires mongodb 2.6+ and pymongo 2.5+ +- name: Roles as a dictionary + mongodb_user: + login_user: root + login_password: root_password + database: admin + user: oplog_reader + password: oplog_reader_password + state: present + replica_set: belcher + roles: + - db: local + role: read + +''' + +RETURN = ''' +user: + description: The name of the user to add or remove. 
+ returned: success + type: str +''' + +import os +import ssl as ssl_lib +import traceback +from distutils.version import LooseVersion +from operator import itemgetter + +try: + from pymongo.errors import ConnectionFailure + from pymongo.errors import OperationFailure + from pymongo import version as PyMongoVersion + from pymongo import MongoClient +except ImportError: + try: # for older PyMongo 2.2 + from pymongo import Connection as MongoClient + except ImportError: + pymongo_found = False + else: + pymongo_found = True +else: + pymongo_found = True + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.six import binary_type, text_type +from ansible.module_utils.six.moves import configparser +from ansible.module_utils._text import to_native + + +# ========================================= +# MongoDB module specific support methods. +# + +def check_compatibility(module, client): + """Check the compatibility between the driver and the database. + + See: https://docs.mongodb.com/ecosystem/drivers/driver-compatibility-reference/#python-driver-compatibility + + Args: + module: Ansible module. + client (cursor): Mongodb cursor on admin database. + """ + loose_srv_version = LooseVersion(client.server_info()['version']) + loose_driver_version = LooseVersion(PyMongoVersion) + + if loose_srv_version >= LooseVersion('3.2') and loose_driver_version < LooseVersion('3.2'): + module.fail_json(msg=' (Note: you must use pymongo 3.2+ with MongoDB >= 3.2)') + + elif loose_srv_version >= LooseVersion('3.0') and loose_driver_version <= LooseVersion('2.8'): + module.fail_json(msg=' (Note: you must use pymongo 2.8+ with MongoDB 3.0)') + + elif loose_srv_version >= LooseVersion('2.6') and loose_driver_version <= LooseVersion('2.7'): + module.fail_json(msg=' (Note: you must use pymongo 2.7+ with MongoDB 2.6)') + + elif LooseVersion(PyMongoVersion) <= LooseVersion('2.5'): + module.fail_json(msg=' (Note: you must be on mongodb 2.4+ and pymongo 2.5+ to use the roles param)') + + +def user_find(client, user, db_name): + """Check if the user exists. + + Args: + client (cursor): Mongodb cursor on admin database. + user (str): User to check. + db_name (str): User's database. + + Returns: + dict: when user exists, False otherwise. + """ + for mongo_user in client["admin"].system.users.find(): + if mongo_user['user'] == user: + # NOTE: there is no 'db' field in mongo 2.4. 
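            # For example (shapes are illustrative): a 2.4-era document looks roughly like
            #   {"user": "bob", "pwd": "...", "roles": ["readWrite"]}
            # whereas 2.6+ stores
            #   {"_id": "burgers.bob", "user": "bob", "db": "burgers", "roles": [...]}
            # so the 'db' comparison below only applies to the newer layout.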
+ if 'db' not in mongo_user: + return mongo_user + + if mongo_user["db"] == db_name: + return mongo_user + return False + + +def user_add(module, client, db_name, user, password, roles): + # pymongo's user_add is a _create_or_update_user so we won't know if it was changed or updated + # without reproducing a lot of the logic in database.py of pymongo + db = client[db_name] + + if roles is None: + db.add_user(user, password, False) + else: + db.add_user(user, password, None, roles=roles) + + +def user_remove(module, client, db_name, user): + exists = user_find(client, user, db_name) + if exists: + if module.check_mode: + module.exit_json(changed=True, user=user) + db = client[db_name] + db.remove_user(user) + else: + module.exit_json(changed=False, user=user) + + +def load_mongocnf(): + config = configparser.RawConfigParser() + mongocnf = os.path.expanduser('~/.mongodb.cnf') + + try: + config.readfp(open(mongocnf)) + creds = dict( + user=config.get('client', 'user'), + password=config.get('client', 'pass') + ) + except (configparser.NoOptionError, IOError): + return False + + return creds + + +def check_if_roles_changed(uinfo, roles, db_name): + # We must be aware of users which can read the oplog on a replicaset + # Such users must have access to the local DB, but since this DB does not store users credentials + # and is not synchronized among replica sets, the user must be stored on the admin db + # Therefore their structure is the following : + # { + # "_id" : "admin.oplog_reader", + # "user" : "oplog_reader", + # "db" : "admin", # <-- admin DB + # "roles" : [ + # { + # "role" : "read", + # "db" : "local" # <-- local DB + # } + # ] + # } + + def make_sure_roles_are_a_list_of_dict(roles, db_name): + output = list() + for role in roles: + if isinstance(role, (binary_type, text_type)): + new_role = {"role": role, "db": db_name} + output.append(new_role) + else: + output.append(role) + return output + + roles_as_list_of_dict = make_sure_roles_are_a_list_of_dict(roles, db_name) + uinfo_roles = uinfo.get('roles', []) + + if sorted(roles_as_list_of_dict, key=itemgetter('db')) == sorted(uinfo_roles, key=itemgetter('db')): + return False + return True + + +# ========================================= +# Module execution. 
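# (Worked example: check_if_roles_changed() above first normalises plain role strings,
#  so roles=['read'] with db_name='burgers' is compared as
#  [{'role': 'read', 'db': 'burgers'}] against the user's stored roles.)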
+# + +def main(): + module = AnsibleModule( + argument_spec=dict( + login_user=dict(default=None), + login_password=dict(default=None, no_log=True), + login_host=dict(default='localhost'), + login_port=dict(default='27017'), + login_database=dict(default=None), + replica_set=dict(default=None), + database=dict(required=True, aliases=['db']), + name=dict(required=True, aliases=['user']), + password=dict(aliases=['pass'], no_log=True), + ssl=dict(default=False, type='bool'), + roles=dict(default=None, type='list', elements='raw'), + state=dict(default='present', choices=['absent', 'present']), + update_password=dict(default="always", choices=["always", "on_create"]), + ssl_cert_reqs=dict(default='CERT_REQUIRED', choices=['CERT_NONE', 'CERT_OPTIONAL', 'CERT_REQUIRED']), + ), + supports_check_mode=True + ) + + if not pymongo_found: + module.fail_json(msg=missing_required_lib('pymongo')) + + login_user = module.params['login_user'] + login_password = module.params['login_password'] + login_host = module.params['login_host'] + login_port = module.params['login_port'] + login_database = module.params['login_database'] + + replica_set = module.params['replica_set'] + db_name = module.params['database'] + user = module.params['name'] + password = module.params['password'] + ssl = module.params['ssl'] + roles = module.params['roles'] or [] + state = module.params['state'] + update_password = module.params['update_password'] + + try: + connection_params = { + "host": login_host, + "port": int(login_port), + } + + if replica_set: + connection_params["replicaset"] = replica_set + + if ssl: + connection_params["ssl"] = ssl + connection_params["ssl_cert_reqs"] = getattr(ssl_lib, module.params['ssl_cert_reqs']) + + client = MongoClient(**connection_params) + + # NOTE: this check must be done ASAP. 
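            # (For example, connecting to a MongoDB 3.2+ server with pymongo older
            # than 3.2 makes check_compatibility() fail the module here.)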
+ # We doesn't need to be authenticated (this ability has lost in PyMongo 3.6) + if LooseVersion(PyMongoVersion) <= LooseVersion('3.5'): + check_compatibility(module, client) + + if login_user is None and login_password is None: + mongocnf_creds = load_mongocnf() + if mongocnf_creds is not False: + login_user = mongocnf_creds['user'] + login_password = mongocnf_creds['password'] + elif login_password is None or login_user is None: + module.fail_json(msg='when supplying login arguments, both login_user and login_password must be provided') + + if login_user is not None and login_password is not None: + client.admin.authenticate(login_user, login_password, source=login_database) + elif LooseVersion(PyMongoVersion) >= LooseVersion('3.0'): + if db_name != "admin": + module.fail_json(msg='The localhost login exception only allows the first admin account to be created') + # else: this has to be the first admin user added + + except Exception as e: + module.fail_json(msg='unable to connect to database: %s' % to_native(e), exception=traceback.format_exc()) + + if state == 'present': + if password is None and update_password == 'always': + module.fail_json(msg='password parameter required when adding a user unless update_password is set to on_create') + + try: + if update_password != 'always': + uinfo = user_find(client, user, db_name) + if uinfo: + password = None + if not check_if_roles_changed(uinfo, roles, db_name): + module.exit_json(changed=False, user=user) + + if module.check_mode: + module.exit_json(changed=True, user=user) + + user_add(module, client, db_name, user, password, roles) + except Exception as e: + module.fail_json(msg='Unable to add or update user: %s' % to_native(e), exception=traceback.format_exc()) + finally: + try: + client.close() + except Exception: + pass + # Here we can check password change if mongo provide a query for that : https://jira.mongodb.org/browse/SERVER-22848 + # newuinfo = user_find(client, user, db_name) + # if uinfo['role'] == newuinfo['role'] and CheckPasswordHere: + # module.exit_json(changed=False, user=user) + + elif state == 'absent': + try: + user_remove(module, client, db_name, user) + except Exception as e: + module.fail_json(msg='Unable to remove user: %s' % to_native(e), exception=traceback.format_exc()) + finally: + try: + client.close() + except Exception: + pass + module.exit_json(changed=True, user=user) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/pids.py b/test/support/integration/plugins/modules/pids.py new file mode 100644 index 00000000..4cbf45a9 --- /dev/null +++ b/test/support/integration/plugins/modules/pids.py @@ -0,0 +1,89 @@ +#!/usr/bin/python +# Copyright: (c) 2019, Saranya Sridharan +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +module: pids +version_added: 2.8 +description: "Retrieves a list of PIDs of given process name in Ansible controller/controlled machines.Returns an empty list if no process in that name exists." +short_description: "Retrieves process IDs list if the process is running otherwise return empty list" +author: + - Saranya Sridharan (@saranyasridharan) +requirements: + - psutil(python module) +options: + name: + description: the name of the process you want to get PID for. 
+ required: true + type: str +''' + +EXAMPLES = ''' +# Pass the process name +- name: Getting process IDs of the process + pids: + name: python + register: pids_of_python + +- name: Printing the process IDs obtained + debug: + msg: "PIDS of python:{{pids_of_python.pids|join(',')}}" +''' + +RETURN = ''' +pids: + description: Process IDs of the given process + returned: list of none, one, or more process IDs + type: list + sample: [100,200] +''' + +from ansible.module_utils.basic import AnsibleModule +try: + import psutil + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + + +def compare_lower(a, b): + if a is None or b is None: + # this could just be "return False" but would lead to surprising behavior if both a and b are None + return a == b + + return a.lower() == b.lower() + + +def get_pid(name): + pids = [] + + for proc in psutil.process_iter(attrs=['name', 'cmdline']): + if compare_lower(proc.info['name'], name) or \ + proc.info['cmdline'] and compare_lower(proc.info['cmdline'][0], name): + pids.append(proc.pid) + + return pids + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(required=True, type="str"), + ), + supports_check_mode=True, + ) + if not HAS_PSUTIL: + module.fail_json(msg="Missing required 'psutil' python module. Try installing it with: pip install psutil") + name = module.params["name"] + response = dict(pids=get_pid(name)) + module.exit_json(**response) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/pkgng.py b/test/support/integration/plugins/modules/pkgng.py new file mode 100644 index 00000000..11363479 --- /dev/null +++ b/test/support/integration/plugins/modules/pkgng.py @@ -0,0 +1,406 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, bleader +# Written by bleader <bleader@ratonland.org> +# Based on pkgin module written by Shaun Zinck <shaun.zinck at gmail.com> +# that was based on pacman module written by Afterburn <https://github.com/afterburn> +# that was based on apt module written by Matthew Williams <matthew@flowroute.com> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: pkgng +short_description: Package manager for FreeBSD >= 9.0 +description: + - Manage binary packages for FreeBSD using 'pkgng' which is available in versions after 9.0. +version_added: "1.2" +options: + name: + description: + - Name or list of names of packages to install/remove. + required: true + state: + description: + - State of the package. + - 'Note: "latest" added in 2.7' + choices: [ 'present', 'latest', 'absent' ] + required: false + default: present + cached: + description: + - Use local package base instead of fetching an updated one. + type: bool + required: false + default: no + annotation: + description: + - A comma-separated list of keyvalue-pairs of the form + C(<+/-/:><key>[=<value>]). A C(+) denotes adding an annotation, a + C(-) denotes removing an annotation, and C(:) denotes modifying an + annotation. + If setting or modifying annotations, a value must be provided. + required: false + version_added: "1.6" + pkgsite: + description: + - For pkgng versions before 1.1.4, specify packagesite to use + for downloading packages. If not specified, use settings from + C(/usr/local/etc/pkg.conf). 
+ - For newer pkgng versions, specify a the name of a repository + configured in C(/usr/local/etc/pkg/repos). + required: false + rootdir: + description: + - For pkgng versions 1.5 and later, pkg will install all packages + within the specified root directory. + - Can not be used together with I(chroot) or I(jail) options. + required: false + chroot: + version_added: "2.1" + description: + - Pkg will chroot in the specified environment. + - Can not be used together with I(rootdir) or I(jail) options. + required: false + jail: + version_added: "2.4" + description: + - Pkg will execute in the given jail name or id. + - Can not be used together with I(chroot) or I(rootdir) options. + autoremove: + version_added: "2.2" + description: + - Remove automatically installed packages which are no longer needed. + required: false + type: bool + default: no +author: "bleader (@bleader)" +notes: + - When using pkgsite, be careful that already in cache packages won't be downloaded again. + - When used with a `loop:` each package will be processed individually, + it is much more efficient to pass the list directly to the `name` option. +''' + +EXAMPLES = ''' +- name: Install package foo + pkgng: + name: foo + state: present + +- name: Annotate package foo and bar + pkgng: + name: foo,bar + annotation: '+test1=baz,-test2,:test3=foobar' + +- name: Remove packages foo and bar + pkgng: + name: foo,bar + state: absent + +# "latest" support added in 2.7 +- name: Upgrade package baz + pkgng: + name: baz + state: latest +''' + + +import re +from ansible.module_utils.basic import AnsibleModule + + +def query_package(module, pkgng_path, name, dir_arg): + + rc, out, err = module.run_command("%s %s info -g -e %s" % (pkgng_path, dir_arg, name)) + + if rc == 0: + return True + + return False + + +def query_update(module, pkgng_path, name, dir_arg, old_pkgng, pkgsite): + + # Check to see if a package upgrade is available. 
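+    # Illustration only (package name and repository are made up): the dry run issued
+    # below comes out roughly as `pkg upgrade -g -n nginx`, or for pkg < 1.1.4 as
+    # `PACKAGESITE=<url> pkg upgrade -g -n nginx`; only its exit status is inspected.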
+ # rc = 0, no updates available or package not installed + # rc = 1, updates available + if old_pkgng: + rc, out, err = module.run_command("%s %s upgrade -g -n %s" % (pkgsite, pkgng_path, name)) + else: + rc, out, err = module.run_command("%s %s upgrade %s -g -n %s" % (pkgng_path, dir_arg, pkgsite, name)) + + if rc == 1: + return True + + return False + + +def pkgng_older_than(module, pkgng_path, compare_version): + + rc, out, err = module.run_command("%s -v" % pkgng_path) + version = [int(x) for x in re.split(r'[\._]', out)] + + i = 0 + new_pkgng = True + while compare_version[i] == version[i]: + i += 1 + if i == min(len(compare_version), len(version)): + break + else: + if compare_version[i] > version[i]: + new_pkgng = False + return not new_pkgng + + +def remove_packages(module, pkgng_path, packages, dir_arg): + + remove_c = 0 + # Using a for loop in case of error, we can report the package that failed + for package in packages: + # Query the package first, to see if we even need to remove + if not query_package(module, pkgng_path, package, dir_arg): + continue + + if not module.check_mode: + rc, out, err = module.run_command("%s %s delete -y %s" % (pkgng_path, dir_arg, package)) + + if not module.check_mode and query_package(module, pkgng_path, package, dir_arg): + module.fail_json(msg="failed to remove %s: %s" % (package, out)) + + remove_c += 1 + + if remove_c > 0: + + return (True, "removed %s package(s)" % remove_c) + + return (False, "package(s) already absent") + + +def install_packages(module, pkgng_path, packages, cached, pkgsite, dir_arg, state): + + install_c = 0 + + # as of pkg-1.1.4, PACKAGESITE is deprecated in favor of repository definitions + # in /usr/local/etc/pkg/repos + old_pkgng = pkgng_older_than(module, pkgng_path, [1, 1, 4]) + if pkgsite != "": + if old_pkgng: + pkgsite = "PACKAGESITE=%s" % (pkgsite) + else: + pkgsite = "-r %s" % (pkgsite) + + # This environment variable skips mid-install prompts, + # setting them to their default values. 
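+    # For illustration with made-up names, the install/upgrade commands assembled
+    # further down look roughly like:
+    #   env BATCH=yes PACKAGESITE=<url> pkg install -g -U -y foo        (pkg < 1.1.4)
+    #   env BATCH=yes pkg --jail web1 install -r myrepo -g -U -y foo    (newer pkg)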
+ batch_var = 'env BATCH=yes' + + if not module.check_mode and not cached: + if old_pkgng: + rc, out, err = module.run_command("%s %s update" % (pkgsite, pkgng_path)) + else: + rc, out, err = module.run_command("%s %s update" % (pkgng_path, dir_arg)) + if rc != 0: + module.fail_json(msg="Could not update catalogue [%d]: %s %s" % (rc, out, err)) + + for package in packages: + already_installed = query_package(module, pkgng_path, package, dir_arg) + if already_installed and state == "present": + continue + + update_available = query_update(module, pkgng_path, package, dir_arg, old_pkgng, pkgsite) + if not update_available and already_installed and state == "latest": + continue + + if not module.check_mode: + if already_installed: + action = "upgrade" + else: + action = "install" + if old_pkgng: + rc, out, err = module.run_command("%s %s %s %s -g -U -y %s" % (batch_var, pkgsite, pkgng_path, action, package)) + else: + rc, out, err = module.run_command("%s %s %s %s %s -g -U -y %s" % (batch_var, pkgng_path, dir_arg, action, pkgsite, package)) + + if not module.check_mode and not query_package(module, pkgng_path, package, dir_arg): + module.fail_json(msg="failed to %s %s: %s" % (action, package, out), stderr=err) + + install_c += 1 + + if install_c > 0: + return (True, "added %s package(s)" % (install_c)) + + return (False, "package(s) already %s" % (state)) + + +def annotation_query(module, pkgng_path, package, tag, dir_arg): + rc, out, err = module.run_command("%s %s info -g -A %s" % (pkgng_path, dir_arg, package)) + match = re.search(r'^\s*(?P<tag>%s)\s*:\s*(?P<value>\w+)' % tag, out, flags=re.MULTILINE) + if match: + return match.group('value') + return False + + +def annotation_add(module, pkgng_path, package, tag, value, dir_arg): + _value = annotation_query(module, pkgng_path, package, tag, dir_arg) + if not _value: + # Annotation does not exist, add it. 
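+        # For illustration with made-up names, this runs something like
+        #   pkg annotate -y -A nginx deployed_by "ansible"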
+        rc, out, err = module.run_command('%s %s annotate -y -A %s %s "%s"'
+                                          % (pkgng_path, dir_arg, package, tag, value))
+        if rc != 0:
+            module.fail_json(msg="could not annotate %s: %s"
+                             % (package, out), stderr=err)
+        return True
+    elif _value != value:
+        # Annotation exists, but value differs
+        module.fail_json(
+            msg="failed to annotate %s, because %s is already set to %s, but should be set to %s"
+            % (package, tag, _value, value))
+        return False
+    else:
+        # Annotation exists, nothing to do
+        return False
+
+
+def annotation_delete(module, pkgng_path, package, tag, value, dir_arg):
+    _value = annotation_query(module, pkgng_path, package, tag, dir_arg)
+    if _value:
+        rc, out, err = module.run_command('%s %s annotate -y -D %s %s'
+                                          % (pkgng_path, dir_arg, package, tag))
+        if rc != 0:
+            module.fail_json(msg="could not delete annotation from %s: %s"
+                             % (package, out), stderr=err)
+        return True
+    return False
+
+
+def annotation_modify(module, pkgng_path, package, tag, value, dir_arg):
+    _value = annotation_query(module, pkgng_path, package, tag, dir_arg)
+    if not _value:
+        # No such tag
+        module.fail_json(msg="could not change annotation to %s: tag %s does not exist"
+                         % (package, tag))
+    elif _value == value:
+        # No change in value
+        return False
+    else:
+        rc, out, err = module.run_command('%s %s annotate -y -M %s %s "%s"'
+                                          % (pkgng_path, dir_arg, package, tag, value))
+        if rc != 0:
+            module.fail_json(msg="could not change annotation on %s: %s"
+                             % (package, out), stderr=err)
+        return True
+
+
+def annotate_packages(module, pkgng_path, packages, annotation, dir_arg):
+    annotate_c = 0
+    # Materialize the parsed annotations so they can be iterated once per package
+    # (map() is a one-shot iterator on Python 3).
+    annotations = list(map(lambda _annotation:
+                           re.match(r'(?P<operation>[+\-:])(?P<tag>\w+)(=(?P<value>\w+))?',
+                                    _annotation).groupdict(),
+                           re.split(r',', annotation)))
+
+    operation = {
+        '+': annotation_add,
+        '-': annotation_delete,
+        ':': annotation_modify
+    }
+
+    for package in packages:
+        for _annotation in annotations:
+            if operation[_annotation['operation']](module, pkgng_path, package, _annotation['tag'], _annotation['value']):
+                annotate_c += 1
+
+    if annotate_c > 0:
+        return (True, "added %s annotations."
% annotate_c) + return (False, "changed no annotations") + + +def autoremove_packages(module, pkgng_path, dir_arg): + rc, out, err = module.run_command("%s %s autoremove -n" % (pkgng_path, dir_arg)) + + autoremove_c = 0 + + match = re.search('^Deinstallation has been requested for the following ([0-9]+) packages', out, re.MULTILINE) + if match: + autoremove_c = int(match.group(1)) + + if autoremove_c == 0: + return False, "no package(s) to autoremove" + + if not module.check_mode: + rc, out, err = module.run_command("%s %s autoremove -y" % (pkgng_path, dir_arg)) + + return True, "autoremoved %d package(s)" % (autoremove_c) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + state=dict(default="present", choices=["present", "latest", "absent"], required=False), + name=dict(aliases=["pkg"], required=True, type='list'), + cached=dict(default=False, type='bool'), + annotation=dict(default="", required=False), + pkgsite=dict(default="", required=False), + rootdir=dict(default="", required=False, type='path'), + chroot=dict(default="", required=False, type='path'), + jail=dict(default="", required=False, type='str'), + autoremove=dict(default=False, type='bool')), + supports_check_mode=True, + mutually_exclusive=[["rootdir", "chroot", "jail"]]) + + pkgng_path = module.get_bin_path('pkg', True) + + p = module.params + + pkgs = p["name"] + + changed = False + msgs = [] + dir_arg = "" + + if p["rootdir"] != "": + old_pkgng = pkgng_older_than(module, pkgng_path, [1, 5, 0]) + if old_pkgng: + module.fail_json(msg="To use option 'rootdir' pkg version must be 1.5 or greater") + else: + dir_arg = "--rootdir %s" % (p["rootdir"]) + + if p["chroot"] != "": + dir_arg = '--chroot %s' % (p["chroot"]) + + if p["jail"] != "": + dir_arg = '--jail %s' % (p["jail"]) + + if p["state"] in ("present", "latest"): + _changed, _msg = install_packages(module, pkgng_path, pkgs, p["cached"], p["pkgsite"], dir_arg, p["state"]) + changed = changed or _changed + msgs.append(_msg) + + elif p["state"] == "absent": + _changed, _msg = remove_packages(module, pkgng_path, pkgs, dir_arg) + changed = changed or _changed + msgs.append(_msg) + + if p["autoremove"]: + _changed, _msg = autoremove_packages(module, pkgng_path, dir_arg) + changed = changed or _changed + msgs.append(_msg) + + if p["annotation"]: + _changed, _msg = annotate_packages(module, pkgng_path, pkgs, p["annotation"], dir_arg) + changed = changed or _changed + msgs.append(_msg) + + module.exit_json(changed=changed, msg=", ".join(msgs)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_db.py b/test/support/integration/plugins/modules/postgresql_db.py new file mode 100644 index 00000000..40858d99 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_db.py @@ -0,0 +1,657 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: postgresql_db +short_description: Add or remove PostgreSQL databases from a remote host. +description: + - Add or remove PostgreSQL databases from a remote host. 
+version_added: '0.6' +options: + name: + description: + - Name of the database to add or remove + type: str + required: true + aliases: [ db ] + port: + description: + - Database port to connect (if needed) + type: int + default: 5432 + aliases: + - login_port + owner: + description: + - Name of the role to set as owner of the database + type: str + template: + description: + - Template used to create the database + type: str + encoding: + description: + - Encoding of the database + type: str + lc_collate: + description: + - Collation order (LC_COLLATE) to use in the database. Must match collation order of template database unless C(template0) is used as template. + type: str + lc_ctype: + description: + - Character classification (LC_CTYPE) to use in the database (e.g. lower, upper, ...) Must match LC_CTYPE of template database unless C(template0) + is used as template. + type: str + session_role: + description: + - Switch to session_role after connecting. The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally. + type: str + version_added: '2.8' + state: + description: + - The database state. + - C(present) implies that the database should be created if necessary. + - C(absent) implies that the database should be removed if present. + - C(dump) requires a target definition to which the database will be backed up. (Added in Ansible 2.4) + Note that in some PostgreSQL versions of pg_dump, which is an embedded PostgreSQL utility and is used by the module, + returns rc 0 even when errors occurred (e.g. the connection is forbidden by pg_hba.conf, etc.), + so the module returns changed=True but the dump has not actually been done. Please, be sure that your version of + pg_dump returns rc 1 in this case. + - C(restore) also requires a target definition from which the database will be restored. (Added in Ansible 2.4) + - The format of the backup will be detected based on the target name. + - Supported compression formats for dump and restore include C(.pgc), C(.bz2), C(.gz) and C(.xz) + - Supported formats for dump and restore include C(.sql) and C(.tar) + type: str + choices: [ absent, dump, present, restore ] + default: present + target: + description: + - File to back up or restore from. + - Used when I(state) is C(dump) or C(restore). + type: path + version_added: '2.4' + target_opts: + description: + - Further arguments for pg_dump or pg_restore. + - Used when I(state) is C(dump) or C(restore). + type: str + version_added: '2.4' + maintenance_db: + description: + - The value specifies the initial database (which is also called as maintenance DB) that Ansible connects to. + type: str + default: postgres + version_added: '2.5' + conn_limit: + description: + - Specifies the database connection limit. + type: str + version_added: '2.8' + tablespace: + description: + - The tablespace to set for the database + U(https://www.postgresql.org/docs/current/sql-alterdatabase.html). + - If you want to move the database back to the default tablespace, + explicitly set this to pg_default. + type: path + version_added: '2.9' + dump_extra_args: + description: + - Provides additional arguments when I(state) is C(dump). + - Cannot be used with dump-file-format-related arguments like ``--format=d``. 
+ type: str + version_added: '2.10' +seealso: +- name: CREATE DATABASE reference + description: Complete reference of the CREATE DATABASE command documentation. + link: https://www.postgresql.org/docs/current/sql-createdatabase.html +- name: DROP DATABASE reference + description: Complete reference of the DROP DATABASE command documentation. + link: https://www.postgresql.org/docs/current/sql-dropdatabase.html +- name: pg_dump reference + description: Complete reference of pg_dump documentation. + link: https://www.postgresql.org/docs/current/app-pgdump.html +- name: pg_restore reference + description: Complete reference of pg_restore documentation. + link: https://www.postgresql.org/docs/current/app-pgrestore.html +- module: postgresql_tablespace +- module: postgresql_info +- module: postgresql_ping +notes: +- State C(dump) and C(restore) don't require I(psycopg2) since version 2.8. +author: "Ansible Core Team" +extends_documentation_fragment: +- postgres +''' + +EXAMPLES = r''' +- name: Create a new database with name "acme" + postgresql_db: + name: acme + +# Note: If a template different from "template0" is specified, encoding and locale settings must match those of the template. +- name: Create a new database with name "acme" and specific encoding and locale # settings. + postgresql_db: + name: acme + encoding: UTF-8 + lc_collate: de_DE.UTF-8 + lc_ctype: de_DE.UTF-8 + template: template0 + +# Note: Default limit for the number of concurrent connections to a specific database is "-1", which means "unlimited" +- name: Create a new database with name "acme" which has a limit of 100 concurrent connections + postgresql_db: + name: acme + conn_limit: "100" + +- name: Dump an existing database to a file + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql + +- name: Dump an existing database to a file excluding the test table + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql + dump_extra_args: --exclude-table=test + +- name: Dump an existing database to a file (with compression) + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql.gz + +- name: Dump a single schema for an existing database + postgresql_db: + name: acme + state: dump + target: /tmp/acme.sql + target_opts: "-n public" + +# Note: In the example below, if database foo exists and has another tablespace +# the tablespace will be changed to foo. Access to the database will be locked +# until the copying of database files is finished. +- name: Create a new database called foo in tablespace bar + postgresql_db: + name: foo + tablespace: bar +''' + +RETURN = r''' +executed_commands: + description: List of commands which tried to run. + returned: always + type: list + sample: ["CREATE DATABASE acme"] + version_added: '2.10' +''' + + +import os +import subprocess +import traceback + +try: + import psycopg2 + import psycopg2.extras +except ImportError: + HAS_PSYCOPG2 = False +else: + HAS_PSYCOPG2 = True + +import ansible.module_utils.postgres as pgutils +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.database import SQLParseError, pg_quote_identifier +from ansible.module_utils.six import iteritems +from ansible.module_utils.six.moves import shlex_quote +from ansible.module_utils._text import to_native + +executed_commands = [] + + +class NotSupportedError(Exception): + pass + +# =========================================== +# PostgreSQL module specific support methods. 
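+# For illustration (database, role and value names are made up), the helpers below
+# end up issuing SQL such as:
+#   ALTER DATABASE "acme" OWNER TO "bob"
+#   ALTER DATABASE "acme" CONNECTION LIMIT 100
+#   CREATE DATABASE "acme" OWNER "bob" TEMPLATE "template0" ENCODING 'UTF-8'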
+# + + +def set_owner(cursor, db, owner): + query = 'ALTER DATABASE %s OWNER TO "%s"' % ( + pg_quote_identifier(db, 'database'), + owner) + executed_commands.append(query) + cursor.execute(query) + return True + + +def set_conn_limit(cursor, db, conn_limit): + query = "ALTER DATABASE %s CONNECTION LIMIT %s" % ( + pg_quote_identifier(db, 'database'), + conn_limit) + executed_commands.append(query) + cursor.execute(query) + return True + + +def get_encoding_id(cursor, encoding): + query = "SELECT pg_char_to_encoding(%(encoding)s) AS encoding_id;" + cursor.execute(query, {'encoding': encoding}) + return cursor.fetchone()['encoding_id'] + + +def get_db_info(cursor, db): + query = """ + SELECT rolname AS owner, + pg_encoding_to_char(encoding) AS encoding, encoding AS encoding_id, + datcollate AS lc_collate, datctype AS lc_ctype, pg_database.datconnlimit AS conn_limit, + spcname AS tablespace + FROM pg_database + JOIN pg_roles ON pg_roles.oid = pg_database.datdba + JOIN pg_tablespace ON pg_tablespace.oid = pg_database.dattablespace + WHERE datname = %(db)s + """ + cursor.execute(query, {'db': db}) + return cursor.fetchone() + + +def db_exists(cursor, db): + query = "SELECT * FROM pg_database WHERE datname=%(db)s" + cursor.execute(query, {'db': db}) + return cursor.rowcount == 1 + + +def db_delete(cursor, db): + if db_exists(cursor, db): + query = "DROP DATABASE %s" % pg_quote_identifier(db, 'database') + executed_commands.append(query) + cursor.execute(query) + return True + else: + return False + + +def db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace): + params = dict(enc=encoding, collate=lc_collate, ctype=lc_ctype, conn_limit=conn_limit, tablespace=tablespace) + if not db_exists(cursor, db): + query_fragments = ['CREATE DATABASE %s' % pg_quote_identifier(db, 'database')] + if owner: + query_fragments.append('OWNER "%s"' % owner) + if template: + query_fragments.append('TEMPLATE %s' % pg_quote_identifier(template, 'database')) + if encoding: + query_fragments.append('ENCODING %(enc)s') + if lc_collate: + query_fragments.append('LC_COLLATE %(collate)s') + if lc_ctype: + query_fragments.append('LC_CTYPE %(ctype)s') + if tablespace: + query_fragments.append('TABLESPACE %s' % pg_quote_identifier(tablespace, 'tablespace')) + if conn_limit: + query_fragments.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit}) + query = ' '.join(query_fragments) + executed_commands.append(cursor.mogrify(query, params)) + cursor.execute(query, params) + return True + else: + db_info = get_db_info(cursor, db) + if (encoding and get_encoding_id(cursor, encoding) != db_info['encoding_id']): + raise NotSupportedError( + 'Changing database encoding is not supported. ' + 'Current encoding: %s' % db_info['encoding'] + ) + elif lc_collate and lc_collate != db_info['lc_collate']: + raise NotSupportedError( + 'Changing LC_COLLATE is not supported. ' + 'Current LC_COLLATE: %s' % db_info['lc_collate'] + ) + elif lc_ctype and lc_ctype != db_info['lc_ctype']: + raise NotSupportedError( + 'Changing LC_CTYPE is not supported.' 
+ 'Current LC_CTYPE: %s' % db_info['lc_ctype'] + ) + else: + changed = False + + if owner and owner != db_info['owner']: + changed = set_owner(cursor, db, owner) + + if conn_limit and conn_limit != str(db_info['conn_limit']): + changed = set_conn_limit(cursor, db, conn_limit) + + if tablespace and tablespace != db_info['tablespace']: + changed = set_tablespace(cursor, db, tablespace) + + return changed + + +def db_matches(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace): + if not db_exists(cursor, db): + return False + else: + db_info = get_db_info(cursor, db) + if (encoding and get_encoding_id(cursor, encoding) != db_info['encoding_id']): + return False + elif lc_collate and lc_collate != db_info['lc_collate']: + return False + elif lc_ctype and lc_ctype != db_info['lc_ctype']: + return False + elif owner and owner != db_info['owner']: + return False + elif conn_limit and conn_limit != str(db_info['conn_limit']): + return False + elif tablespace and tablespace != db_info['tablespace']: + return False + else: + return True + + +def db_dump(module, target, target_opts="", + db=None, + dump_extra_args=None, + user=None, + password=None, + host=None, + port=None, + **kw): + + flags = login_flags(db, host, port, user, db_prefix=False) + cmd = module.get_bin_path('pg_dump', True) + comp_prog_path = None + + if os.path.splitext(target)[-1] == '.tar': + flags.append(' --format=t') + elif os.path.splitext(target)[-1] == '.pgc': + flags.append(' --format=c') + if os.path.splitext(target)[-1] == '.gz': + if module.get_bin_path('pigz'): + comp_prog_path = module.get_bin_path('pigz', True) + else: + comp_prog_path = module.get_bin_path('gzip', True) + elif os.path.splitext(target)[-1] == '.bz2': + comp_prog_path = module.get_bin_path('bzip2', True) + elif os.path.splitext(target)[-1] == '.xz': + comp_prog_path = module.get_bin_path('xz', True) + + cmd += "".join(flags) + + if dump_extra_args: + cmd += " {0} ".format(dump_extra_args) + + if target_opts: + cmd += " {0} ".format(target_opts) + + if comp_prog_path: + # Use a fifo to be notified of an error in pg_dump + # Using shell pipe has no way to return the code of the first command + # in a portable way. 
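+        # Illustrative example (paths assumed): for target /tmp/acme.sql.gz this builds
+        # roughly `gzip < <tmpdir>/pg_fifo > /tmp/acme.sql.gz & pg_dump --dbname=acme > <tmpdir>/pg_fifo`,
+        # so a non-zero pg_dump exit status can still be detected.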
+ fifo = os.path.join(module.tmpdir, 'pg_fifo') + os.mkfifo(fifo) + cmd = '{1} <{3} > {2} & {0} >{3}'.format(cmd, comp_prog_path, shlex_quote(target), fifo) + else: + cmd = '{0} > {1}'.format(cmd, shlex_quote(target)) + + return do_with_password(module, cmd, password) + + +def db_restore(module, target, target_opts="", + db=None, + user=None, + password=None, + host=None, + port=None, + **kw): + + flags = login_flags(db, host, port, user) + comp_prog_path = None + cmd = module.get_bin_path('psql', True) + + if os.path.splitext(target)[-1] == '.sql': + flags.append(' --file={0}'.format(target)) + + elif os.path.splitext(target)[-1] == '.tar': + flags.append(' --format=Tar') + cmd = module.get_bin_path('pg_restore', True) + + elif os.path.splitext(target)[-1] == '.pgc': + flags.append(' --format=Custom') + cmd = module.get_bin_path('pg_restore', True) + + elif os.path.splitext(target)[-1] == '.gz': + comp_prog_path = module.get_bin_path('zcat', True) + + elif os.path.splitext(target)[-1] == '.bz2': + comp_prog_path = module.get_bin_path('bzcat', True) + + elif os.path.splitext(target)[-1] == '.xz': + comp_prog_path = module.get_bin_path('xzcat', True) + + cmd += "".join(flags) + if target_opts: + cmd += " {0} ".format(target_opts) + + if comp_prog_path: + env = os.environ.copy() + if password: + env = {"PGPASSWORD": password} + p1 = subprocess.Popen([comp_prog_path, target], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p2 = subprocess.Popen(cmd, stdin=p1.stdout, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=env) + (stdout2, stderr2) = p2.communicate() + p1.stdout.close() + p1.wait() + if p1.returncode != 0: + stderr1 = p1.stderr.read() + return p1.returncode, '', stderr1, 'cmd: ****' + else: + return p2.returncode, '', stderr2, 'cmd: ****' + else: + cmd = '{0} < {1}'.format(cmd, shlex_quote(target)) + + return do_with_password(module, cmd, password) + + +def login_flags(db, host, port, user, db_prefix=True): + """ + returns a list of connection argument strings each prefixed + with a space and quoted where necessary to later be combined + in a single shell string with `"".join(rv)` + + db_prefix determines if "--dbname" is prefixed to the db argument, + since the argument was introduced in 9.3. + """ + flags = [] + if db: + if db_prefix: + flags.append(' --dbname={0}'.format(shlex_quote(db))) + else: + flags.append(' {0}'.format(shlex_quote(db))) + if host: + flags.append(' --host={0}'.format(host)) + if port: + flags.append(' --port={0}'.format(port)) + if user: + flags.append(' --username={0}'.format(user)) + return flags + + +def do_with_password(module, cmd, password): + env = {} + if password: + env = {"PGPASSWORD": password} + executed_commands.append(cmd) + rc, stderr, stdout = module.run_command(cmd, use_unsafe_shell=True, environ_update=env) + return rc, stderr, stdout, cmd + + +def set_tablespace(cursor, db, tablespace): + query = "ALTER DATABASE %s SET TABLESPACE %s" % ( + pg_quote_identifier(db, 'database'), + pg_quote_identifier(tablespace, 'tablespace')) + executed_commands.append(query) + cursor.execute(query) + return True + +# =========================================== +# Module execution. 
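+# For illustration (hypothetical connection values): with login_host=db1.example.com,
+# login_user=admin and maintenance_db=postgres, main() below ends up calling roughly
+#   psycopg2.connect(database='postgres', host='db1.example.com', user='admin', ...)
+# Parameters left empty are omitted from the call so libpq defaults apply.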
+# + + +def main(): + argument_spec = pgutils.postgres_common_argument_spec() + argument_spec.update( + db=dict(type='str', required=True, aliases=['name']), + owner=dict(type='str', default=''), + template=dict(type='str', default=''), + encoding=dict(type='str', default=''), + lc_collate=dict(type='str', default=''), + lc_ctype=dict(type='str', default=''), + state=dict(type='str', default='present', choices=['absent', 'dump', 'present', 'restore']), + target=dict(type='path', default=''), + target_opts=dict(type='str', default=''), + maintenance_db=dict(type='str', default="postgres"), + session_role=dict(type='str'), + conn_limit=dict(type='str', default=''), + tablespace=dict(type='path', default=''), + dump_extra_args=dict(type='str', default=None), + ) + + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True + ) + + db = module.params["db"] + owner = module.params["owner"] + template = module.params["template"] + encoding = module.params["encoding"] + lc_collate = module.params["lc_collate"] + lc_ctype = module.params["lc_ctype"] + target = module.params["target"] + target_opts = module.params["target_opts"] + state = module.params["state"] + changed = False + maintenance_db = module.params['maintenance_db'] + session_role = module.params["session_role"] + conn_limit = module.params['conn_limit'] + tablespace = module.params['tablespace'] + dump_extra_args = module.params['dump_extra_args'] + + raw_connection = state in ("dump", "restore") + + if not raw_connection: + pgutils.ensure_required_libs(module) + + # To use defaults values, keyword arguments must be absent, so + # check which values are empty and don't include in the **kw + # dictionary + params_map = { + "login_host": "host", + "login_user": "user", + "login_password": "password", + "port": "port", + "ssl_mode": "sslmode", + "ca_cert": "sslrootcert" + } + kw = dict((params_map[k], v) for (k, v) in iteritems(module.params) + if k in params_map and v != '' and v is not None) + + # If a login_unix_socket is specified, incorporate it here. + is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost" + + if is_localhost and module.params["login_unix_socket"] != "": + kw["host"] = module.params["login_unix_socket"] + + if target == "": + target = "{0}/{1}.sql".format(os.getcwd(), db) + target = os.path.expanduser(target) + + if not raw_connection: + try: + db_connection = psycopg2.connect(database=maintenance_db, **kw) + + # Enable autocommit so we can create databases + if psycopg2.__version__ >= '2.4.2': + db_connection.autocommit = True + else: + db_connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + cursor = db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) + + except TypeError as e: + if 'sslrootcert' in e.args[0]: + module.fail_json(msg='Postgresql server must be at least version 8.4 to support sslrootcert. 
Exception: {0}'.format(to_native(e)), + exception=traceback.format_exc()) + module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc()) + + except Exception as e: + module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc()) + + if session_role: + try: + cursor.execute('SET ROLE "%s"' % session_role) + except Exception as e: + module.fail_json(msg="Could not switch role: %s" % to_native(e), exception=traceback.format_exc()) + + try: + if module.check_mode: + if state == "absent": + changed = db_exists(cursor, db) + elif state == "present": + changed = not db_matches(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace) + module.exit_json(changed=changed, db=db, executed_commands=executed_commands) + + if state == "absent": + try: + changed = db_delete(cursor, db) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + elif state == "present": + try: + changed = db_create(cursor, db, owner, template, encoding, lc_collate, lc_ctype, conn_limit, tablespace) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + elif state in ("dump", "restore"): + method = state == "dump" and db_dump or db_restore + try: + if state == 'dump': + rc, stdout, stderr, cmd = method(module, target, target_opts, db, dump_extra_args, **kw) + else: + rc, stdout, stderr, cmd = method(module, target, target_opts, db, **kw) + + if rc != 0: + module.fail_json(msg=stderr, stdout=stdout, rc=rc, cmd=cmd) + else: + module.exit_json(changed=True, msg=stdout, stderr=stderr, rc=rc, cmd=cmd, + executed_commands=executed_commands) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + except NotSupportedError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + except SystemExit: + # Avoid catching this on Python 2.4 + raise + except Exception as e: + module.fail_json(msg="Database query failed: %s" % to_native(e), exception=traceback.format_exc()) + + module.exit_json(changed=changed, db=db, executed_commands=executed_commands) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_privs.py b/test/support/integration/plugins/modules/postgresql_privs.py new file mode 100644 index 00000000..ba8324dd --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_privs.py @@ -0,0 +1,1097 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# Copyright: (c) 2019, Tobias Birkefeld (@tcraxs) <t@craxs.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: postgresql_privs +version_added: '1.2' +short_description: Grant or revoke privileges on PostgreSQL database objects +description: +- Grant or revoke privileges on PostgreSQL database objects. +- This module is basically a wrapper around most of the functionality of + PostgreSQL's GRANT and REVOKE statements with detection of changes + (GRANT/REVOKE I(privs) ON I(type) I(objs) TO/FROM I(roles)). +options: + database: + description: + - Name of database to connect to. 
+ required: yes + type: str + aliases: + - db + - login_db + state: + description: + - If C(present), the specified privileges are granted, if C(absent) they are revoked. + type: str + default: present + choices: [ absent, present ] + privs: + description: + - Comma separated list of privileges to grant/revoke. + type: str + aliases: + - priv + type: + description: + - Type of database object to set privileges on. + - The C(default_privs) choice is available starting at version 2.7. + - The C(foreign_data_wrapper) and C(foreign_server) object types are available from Ansible version '2.8'. + - The C(type) choice is available from Ansible version '2.10'. + type: str + default: table + choices: [ database, default_privs, foreign_data_wrapper, foreign_server, function, + group, language, table, tablespace, schema, sequence, type ] + objs: + description: + - Comma separated list of database objects to set privileges on. + - If I(type) is C(table), C(partition table), C(sequence) or C(function), + the special valueC(ALL_IN_SCHEMA) can be provided instead to specify all + database objects of type I(type) in the schema specified via I(schema). + (This also works with PostgreSQL < 9.0.) (C(ALL_IN_SCHEMA) is available + for C(function) and C(partition table) from version 2.8) + - If I(type) is C(database), this parameter can be omitted, in which case + privileges are set for the database specified via I(database). + - 'If I(type) is I(function), colons (":") in object names will be + replaced with commas (needed to specify function signatures, see examples)' + type: str + aliases: + - obj + schema: + description: + - Schema that contains the database objects specified via I(objs). + - May only be provided if I(type) is C(table), C(sequence), C(function), C(type), + or C(default_privs). Defaults to C(public) in these cases. + - Pay attention, for embedded types when I(type=type) + I(schema) can be C(pg_catalog) or C(information_schema) respectively. + type: str + roles: + description: + - Comma separated list of role (user/group) names to set permissions for. + - The special value C(PUBLIC) can be provided instead to set permissions + for the implicitly defined PUBLIC group. + type: str + required: yes + aliases: + - role + fail_on_role: + version_added: '2.8' + description: + - If C(yes), fail when target role (for whom privs need to be granted) does not exist. + Otherwise just warn and continue. + default: yes + type: bool + session_role: + version_added: '2.8' + description: + - Switch to session_role after connecting. + - The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally. + type: str + target_roles: + description: + - A list of existing role (user/group) names to set as the + default permissions for database objects subsequently created by them. + - Parameter I(target_roles) is only available with C(type=default_privs). + type: str + version_added: '2.8' + grant_option: + description: + - Whether C(role) may grant/revoke the specified privileges/group memberships to others. + - Set to C(no) to revoke GRANT OPTION, leave unspecified to make no changes. + - I(grant_option) only has an effect if I(state) is C(present). + type: bool + aliases: + - admin_option + host: + description: + - Database host address. If unspecified, connect via Unix socket. 
+ type: str + aliases: + - login_host + port: + description: + - Database port to connect to. + type: int + default: 5432 + aliases: + - login_port + unix_socket: + description: + - Path to a Unix domain socket for local connections. + type: str + aliases: + - login_unix_socket + login: + description: + - The username to authenticate with. + type: str + default: postgres + aliases: + - login_user + password: + description: + - The password to authenticate with. + type: str + aliases: + - login_password + ssl_mode: + description: + - Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated with the server. + - See https://www.postgresql.org/docs/current/static/libpq-ssl.html for more information on the modes. + - Default of C(prefer) matches libpq default. + type: str + default: prefer + choices: [ allow, disable, prefer, require, verify-ca, verify-full ] + version_added: '2.3' + ca_cert: + description: + - Specifies the name of a file containing SSL certificate authority (CA) certificate(s). + - If the file exists, the server's certificate will be verified to be signed by one of these authorities. + version_added: '2.3' + type: str + aliases: + - ssl_rootcert + +notes: +- Parameters that accept comma separated lists (I(privs), I(objs), I(roles)) + have singular alias names (I(priv), I(obj), I(role)). +- To revoke only C(GRANT OPTION) for a specific object, set I(state) to + C(present) and I(grant_option) to C(no) (see examples). +- Note that when revoking privileges from a role R, this role may still have + access via privileges granted to any role R is a member of including C(PUBLIC). +- Note that when revoking privileges from a role R, you do so as the user + specified via I(login). If R has been granted the same privileges by + another user also, R can still access database objects via these privileges. +- When revoking privileges, C(RESTRICT) is assumed (see PostgreSQL docs). + +seealso: +- module: postgresql_user +- module: postgresql_owner +- module: postgresql_membership +- name: PostgreSQL privileges + description: General information about PostgreSQL privileges. + link: https://www.postgresql.org/docs/current/ddl-priv.html +- name: PostgreSQL GRANT command reference + description: Complete reference of the PostgreSQL GRANT command documentation. + link: https://www.postgresql.org/docs/current/sql-grant.html +- name: PostgreSQL REVOKE command reference + description: Complete reference of the PostgreSQL REVOKE command documentation. + link: https://www.postgresql.org/docs/current/sql-revoke.html + +extends_documentation_fragment: +- postgres + +author: +- Bernhard Weitzhofer (@b6d) +- Tobias Birkefeld (@tcraxs) +''' + +EXAMPLES = r''' +# On database "library": +# GRANT SELECT, INSERT, UPDATE ON TABLE public.books, public.authors +# TO librarian, reader WITH GRANT OPTION +- name: Grant privs to librarian and reader on database library + postgresql_privs: + database: library + state: present + privs: SELECT,INSERT,UPDATE + type: table + objs: books,authors + schema: public + roles: librarian,reader + grant_option: yes + +- name: Same as above leveraging default values + postgresql_privs: + db: library + privs: SELECT,INSERT,UPDATE + objs: books,authors + roles: librarian,reader + grant_option: yes + +# REVOKE GRANT OPTION FOR INSERT ON TABLE books FROM reader +# Note that role "reader" will be *granted* INSERT privilege itself if this +# isn't already the case (since state: present). 
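+# For illustration, the task below results in queries along the lines of
+# 'GRANT INSERT ON TABLE "books" TO "reader";' followed by
+# 'REVOKE GRANT OPTION FOR INSERT ON TABLE "books" FROM "reader";'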
+- name: Revoke privs from reader + postgresql_privs: + db: library + state: present + priv: INSERT + obj: books + role: reader + grant_option: no + +# "public" is the default schema. This also works for PostgreSQL 8.x. +- name: REVOKE INSERT, UPDATE ON ALL TABLES IN SCHEMA public FROM reader + postgresql_privs: + db: library + state: absent + privs: INSERT,UPDATE + objs: ALL_IN_SCHEMA + role: reader + +- name: GRANT ALL PRIVILEGES ON SCHEMA public, math TO librarian + postgresql_privs: + db: library + privs: ALL + type: schema + objs: public,math + role: librarian + +# Note the separation of arguments with colons. +- name: GRANT ALL PRIVILEGES ON FUNCTION math.add(int, int) TO librarian, reader + postgresql_privs: + db: library + privs: ALL + type: function + obj: add(int:int) + schema: math + roles: librarian,reader + +# Note that group role memberships apply cluster-wide and therefore are not +# restricted to database "library" here. +- name: GRANT librarian, reader TO alice, bob WITH ADMIN OPTION + postgresql_privs: + db: library + type: group + objs: librarian,reader + roles: alice,bob + admin_option: yes + +# Note that here "db: postgres" specifies the database to connect to, not the +# database to grant privileges on (which is specified via the "objs" param) +- name: GRANT ALL PRIVILEGES ON DATABASE library TO librarian + postgresql_privs: + db: postgres + privs: ALL + type: database + obj: library + role: librarian + +# If objs is omitted for type "database", it defaults to the database +# to which the connection is established +- name: GRANT ALL PRIVILEGES ON DATABASE library TO librarian + postgresql_privs: + db: library + privs: ALL + type: database + role: librarian + +# Available since version 2.7 +# Objs must be set, ALL_DEFAULT to TABLES/SEQUENCES/TYPES/FUNCTIONS +# ALL_DEFAULT works only with privs=ALL +# For specific +- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO librarian + postgresql_privs: + db: library + objs: ALL_DEFAULT + privs: ALL + type: default_privs + role: librarian + grant_option: yes + +# Available since version 2.7 +# Objs must be set, ALL_DEFAULT to TABLES/SEQUENCES/TYPES/FUNCTIONS +# ALL_DEFAULT works only with privs=ALL +# For specific +- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO reader, step 1 + postgresql_privs: + db: library + objs: TABLES,SEQUENCES + privs: SELECT + type: default_privs + role: reader + +- name: ALTER DEFAULT PRIVILEGES ON DATABASE library TO reader, step 2 + postgresql_privs: + db: library + objs: TYPES + privs: USAGE + type: default_privs + role: reader + +# Available since version 2.8 +- name: GRANT ALL PRIVILEGES ON FOREIGN DATA WRAPPER fdw TO reader + postgresql_privs: + db: test + objs: fdw + privs: ALL + type: foreign_data_wrapper + role: reader + +# Available since version 2.10 +- name: GRANT ALL PRIVILEGES ON TYPE customtype TO reader + postgresql_privs: + db: test + objs: customtype + privs: ALL + type: type + role: reader + +# Available since version 2.8 +- name: GRANT ALL PRIVILEGES ON FOREIGN SERVER fdw_server TO reader + postgresql_privs: + db: test + objs: fdw_server + privs: ALL + type: foreign_server + role: reader + +# Available since version 2.8 +# Grant 'execute' permissions on all functions in schema 'common' to role 'caller' +- name: GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA common TO caller + postgresql_privs: + type: function + state: present + privs: EXECUTE + roles: caller + objs: ALL_IN_SCHEMA + schema: common + +# Available since version 2.8 +# ALTER DEFAULT PRIVILEGES FOR ROLE librarian 
IN SCHEMA library GRANT SELECT ON TABLES TO reader +# GRANT SELECT privileges for new TABLES objects created by librarian as +# default to the role reader. +# For specific +- name: ALTER privs + postgresql_privs: + db: library + schema: library + objs: TABLES + privs: SELECT + type: default_privs + role: reader + target_roles: librarian + +# Available since version 2.8 +# ALTER DEFAULT PRIVILEGES FOR ROLE librarian IN SCHEMA library REVOKE SELECT ON TABLES FROM reader +# REVOKE SELECT privileges for new TABLES objects created by librarian as +# default from the role reader. +# For specific +- name: ALTER privs + postgresql_privs: + db: library + state: absent + schema: library + objs: TABLES + privs: SELECT + type: default_privs + role: reader + target_roles: librarian + +# Available since version 2.10 +- name: Grant type privileges for pg_catalog.numeric type to alice + postgresql_privs: + type: type + roles: alice + privs: ALL + objs: numeric + schema: pg_catalog + db: acme +''' + +RETURN = r''' +queries: + description: List of executed queries. + returned: always + type: list + sample: ['REVOKE GRANT OPTION FOR INSERT ON TABLE "books" FROM "reader";'] + version_added: '2.8' +''' + +import traceback + +PSYCOPG2_IMP_ERR = None +try: + import psycopg2 + import psycopg2.extensions +except ImportError: + PSYCOPG2_IMP_ERR = traceback.format_exc() + psycopg2 = None + +# import module snippets +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.database import pg_quote_identifier +from ansible.module_utils.postgres import postgres_common_argument_spec +from ansible.module_utils._text import to_native + +VALID_PRIVS = frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', + 'REFERENCES', 'TRIGGER', 'CREATE', 'CONNECT', + 'TEMPORARY', 'TEMP', 'EXECUTE', 'USAGE', 'ALL', 'USAGE')) +VALID_DEFAULT_OBJS = {'TABLES': ('ALL', 'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER'), + 'SEQUENCES': ('ALL', 'SELECT', 'UPDATE', 'USAGE'), + 'FUNCTIONS': ('ALL', 'EXECUTE'), + 'TYPES': ('ALL', 'USAGE')} + +executed_queries = [] + + +class Error(Exception): + pass + + +def role_exists(module, cursor, rolname): + """Check user exists or not""" + query = "SELECT 1 FROM pg_roles WHERE rolname = '%s'" % rolname + try: + cursor.execute(query) + return cursor.rowcount > 0 + + except Exception as e: + module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e))) + + return False + + +# We don't have functools.partial in Python < 2.5 +def partial(f, *args, **kwargs): + """Partial function application""" + + def g(*g_args, **g_kwargs): + new_kwargs = kwargs.copy() + new_kwargs.update(g_kwargs) + return f(*(args + g_args), **g_kwargs) + + g.f = f + g.args = args + g.kwargs = kwargs + return g + + +class Connection(object): + """Wrapper around a psycopg2 connection with some convenience methods""" + + def __init__(self, params, module): + self.database = params.database + self.module = module + # To use defaults values, keyword arguments must be absent, so + # check which values are empty and don't include in the **kw + # dictionary + params_map = { + "host": "host", + "login": "user", + "password": "password", + "port": "port", + "database": "database", + "ssl_mode": "sslmode", + "ca_cert": "sslrootcert" + } + + kw = dict((params_map[k], getattr(params, k)) for k in params_map + if getattr(params, k) != '' and getattr(params, k) is not None) + + # If a unix_socket is specified, incorporate it here. 
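+        # For example (socket path assumed): unix_socket=/var/run/postgresql becomes
+        # host='/var/run/postgresql'; libpq treats a host beginning with '/' as the
+        # directory containing the Unix-domain socket.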
+ is_localhost = "host" not in kw or kw["host"] == "" or kw["host"] == "localhost" + if is_localhost and params.unix_socket != "": + kw["host"] = params.unix_socket + + sslrootcert = params.ca_cert + if psycopg2.__version__ < '2.4.3' and sslrootcert is not None: + raise ValueError('psycopg2 must be at least 2.4.3 in order to user the ca_cert parameter') + + self.connection = psycopg2.connect(**kw) + self.cursor = self.connection.cursor() + + def commit(self): + self.connection.commit() + + def rollback(self): + self.connection.rollback() + + @property + def encoding(self): + """Connection encoding in Python-compatible form""" + return psycopg2.extensions.encodings[self.connection.encoding] + + # Methods for querying database objects + + # PostgreSQL < 9.0 doesn't support "ALL TABLES IN SCHEMA schema"-like + # phrases in GRANT or REVOKE statements, therefore alternative methods are + # provided here. + + def schema_exists(self, schema): + query = """SELECT count(*) + FROM pg_catalog.pg_namespace WHERE nspname = %s""" + self.cursor.execute(query, (schema,)) + return self.cursor.fetchone()[0] > 0 + + def get_all_tables_in_schema(self, schema): + if not self.schema_exists(schema): + raise Error('Schema "%s" does not exist.' % schema) + query = """SELECT relname + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind in ('r', 'v', 'm', 'p')""" + self.cursor.execute(query, (schema,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_all_sequences_in_schema(self, schema): + if not self.schema_exists(schema): + raise Error('Schema "%s" does not exist.' % schema) + query = """SELECT relname + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind = 'S'""" + self.cursor.execute(query, (schema,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_all_functions_in_schema(self, schema): + if not self.schema_exists(schema): + raise Error('Schema "%s" does not exist.' % schema) + query = """SELECT p.proname, oidvectortypes(p.proargtypes) + FROM pg_catalog.pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE nspname = %s""" + self.cursor.execute(query, (schema,)) + return ["%s(%s)" % (t[0], t[1]) for t in self.cursor.fetchall()] + + # Methods for getting access control lists and group membership info + + # To determine whether anything has changed after granting/revoking + # privileges, we compare the access control lists of the specified database + # objects before and afterwards. Python's list/string comparison should + # suffice for change detection, we should not actually have to parse ACLs. + # The same should apply to group membership information. 
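+    # For illustration, an ACL column such as relacl is returned as a string like
+    #   {librarian=arwdDxt/postgres,reader=r/postgres}
+    # (grantee=privileges/grantor), so comparing the lists before and after suffices.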
+ + def get_table_acls(self, schema, tables): + query = """SELECT relacl + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind in ('r','p','v','m') AND relname = ANY (%s) + ORDER BY relname""" + self.cursor.execute(query, (schema, tables)) + return [t[0] for t in self.cursor.fetchall()] + + def get_sequence_acls(self, schema, sequences): + query = """SELECT relacl + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE nspname = %s AND relkind = 'S' AND relname = ANY (%s) + ORDER BY relname""" + self.cursor.execute(query, (schema, sequences)) + return [t[0] for t in self.cursor.fetchall()] + + def get_function_acls(self, schema, function_signatures): + funcnames = [f.split('(', 1)[0] for f in function_signatures] + query = """SELECT proacl + FROM pg_catalog.pg_proc p + JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + WHERE nspname = %s AND proname = ANY (%s) + ORDER BY proname, proargtypes""" + self.cursor.execute(query, (schema, funcnames)) + return [t[0] for t in self.cursor.fetchall()] + + def get_schema_acls(self, schemas): + query = """SELECT nspacl FROM pg_catalog.pg_namespace + WHERE nspname = ANY (%s) ORDER BY nspname""" + self.cursor.execute(query, (schemas,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_language_acls(self, languages): + query = """SELECT lanacl FROM pg_catalog.pg_language + WHERE lanname = ANY (%s) ORDER BY lanname""" + self.cursor.execute(query, (languages,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_tablespace_acls(self, tablespaces): + query = """SELECT spcacl FROM pg_catalog.pg_tablespace + WHERE spcname = ANY (%s) ORDER BY spcname""" + self.cursor.execute(query, (tablespaces,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_database_acls(self, databases): + query = """SELECT datacl FROM pg_catalog.pg_database + WHERE datname = ANY (%s) ORDER BY datname""" + self.cursor.execute(query, (databases,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_group_memberships(self, groups): + query = """SELECT roleid, grantor, member, admin_option + FROM pg_catalog.pg_auth_members am + JOIN pg_catalog.pg_roles r ON r.oid = am.roleid + WHERE r.rolname = ANY(%s) + ORDER BY roleid, grantor, member""" + self.cursor.execute(query, (groups,)) + return self.cursor.fetchall() + + def get_default_privs(self, schema, *args): + query = """SELECT defaclacl + FROM pg_default_acl a + JOIN pg_namespace b ON a.defaclnamespace=b.oid + WHERE b.nspname = %s;""" + self.cursor.execute(query, (schema,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_foreign_data_wrapper_acls(self, fdws): + query = """SELECT fdwacl FROM pg_catalog.pg_foreign_data_wrapper + WHERE fdwname = ANY (%s) ORDER BY fdwname""" + self.cursor.execute(query, (fdws,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_foreign_server_acls(self, fs): + query = """SELECT srvacl FROM pg_catalog.pg_foreign_server + WHERE srvname = ANY (%s) ORDER BY srvname""" + self.cursor.execute(query, (fs,)) + return [t[0] for t in self.cursor.fetchall()] + + def get_type_acls(self, schema, types): + query = """SELECT t.typacl FROM pg_catalog.pg_type t + JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE n.nspname = %s AND t.typname = ANY (%s) ORDER BY typname""" + self.cursor.execute(query, (schema, types)) + return [t[0] for t in self.cursor.fetchall()] + + # Manipulating privileges + + def manipulate_privs(self, obj_type, privs, 
objs, roles, target_roles, + state, grant_option, schema_qualifier=None, fail_on_role=True): + """Manipulate database object privileges. + + :param obj_type: Type of database object to grant/revoke + privileges for. + :param privs: Either a list of privileges to grant/revoke + or None if type is "group". + :param objs: List of database objects to grant/revoke + privileges for. + :param roles: Either a list of role names or "PUBLIC" + for the implicitly defined "PUBLIC" group + :param target_roles: List of role names to grant/revoke + default privileges as. + :param state: "present" to grant privileges, "absent" to revoke. + :param grant_option: Only for state "present": If True, set + grant/admin option. If False, revoke it. + If None, don't change grant option. + :param schema_qualifier: Some object types ("TABLE", "SEQUENCE", + "FUNCTION") must be qualified by schema. + Ignored for other Types. + """ + # get_status: function to get current status + if obj_type == 'table': + get_status = partial(self.get_table_acls, schema_qualifier) + elif obj_type == 'sequence': + get_status = partial(self.get_sequence_acls, schema_qualifier) + elif obj_type == 'function': + get_status = partial(self.get_function_acls, schema_qualifier) + elif obj_type == 'schema': + get_status = self.get_schema_acls + elif obj_type == 'language': + get_status = self.get_language_acls + elif obj_type == 'tablespace': + get_status = self.get_tablespace_acls + elif obj_type == 'database': + get_status = self.get_database_acls + elif obj_type == 'group': + get_status = self.get_group_memberships + elif obj_type == 'default_privs': + get_status = partial(self.get_default_privs, schema_qualifier) + elif obj_type == 'foreign_data_wrapper': + get_status = self.get_foreign_data_wrapper_acls + elif obj_type == 'foreign_server': + get_status = self.get_foreign_server_acls + elif obj_type == 'type': + get_status = partial(self.get_type_acls, schema_qualifier) + else: + raise Error('Unsupported database object type "%s".' % obj_type) + + # Return False (nothing has changed) if there are no objs to work on. + if not objs: + return False + + # obj_ids: quoted db object identifiers (sometimes schema-qualified) + if obj_type == 'function': + obj_ids = [] + for obj in objs: + try: + f, args = obj.split('(', 1) + except Exception: + raise Error('Illegal function signature: "%s".' 
% obj) + obj_ids.append('"%s"."%s"(%s' % (schema_qualifier, f, args)) + elif obj_type in ['table', 'sequence', 'type']: + obj_ids = ['"%s"."%s"' % (schema_qualifier, o) for o in objs] + else: + obj_ids = ['"%s"' % o for o in objs] + + # set_what: SQL-fragment specifying what to set for the target roles: + # Either group membership or privileges on objects of a certain type + if obj_type == 'group': + set_what = ','.join('"%s"' % i for i in obj_ids) + elif obj_type == 'default_privs': + # We don't want privs to be quoted here + set_what = ','.join(privs) + else: + # function types are already quoted above + if obj_type != 'function': + obj_ids = [pg_quote_identifier(i, 'table') for i in obj_ids] + # Note: obj_type has been checked against a set of string literals + # and privs was escaped when it was parsed + # Note: Underscores are replaced with spaces to support multi-word obj_type + set_what = '%s ON %s %s' % (','.join(privs), obj_type.replace('_', ' '), + ','.join(obj_ids)) + + # for_whom: SQL-fragment specifying for whom to set the above + if roles == 'PUBLIC': + for_whom = 'PUBLIC' + else: + for_whom = [] + for r in roles: + if not role_exists(self.module, self.cursor, r): + if fail_on_role: + self.module.fail_json(msg="Role '%s' does not exist" % r.strip()) + + else: + self.module.warn("Role '%s' does not exist, pass it" % r.strip()) + else: + for_whom.append('"%s"' % r) + + if not for_whom: + return False + + for_whom = ','.join(for_whom) + + # as_who: + as_who = None + if target_roles: + as_who = ','.join('"%s"' % r for r in target_roles) + + status_before = get_status(objs) + + query = QueryBuilder(state) \ + .for_objtype(obj_type) \ + .with_grant_option(grant_option) \ + .for_whom(for_whom) \ + .as_who(as_who) \ + .for_schema(schema_qualifier) \ + .set_what(set_what) \ + .for_objs(objs) \ + .build() + + executed_queries.append(query) + self.cursor.execute(query) + status_after = get_status(objs) + + def nonesorted(e): + # For python 3+ that can fail trying + # to compare NoneType elements by sort method. 
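+            # e.g. sorted(['x', None]) raises TypeError on Python 3, hence the mapping to ''.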
+ if e is None: + return '' + return e + + status_before.sort(key=nonesorted) + status_after.sort(key=nonesorted) + return status_before != status_after + + +class QueryBuilder(object): + def __init__(self, state): + self._grant_option = None + self._for_whom = None + self._as_who = None + self._set_what = None + self._obj_type = None + self._state = state + self._schema = None + self._objs = None + self.query = [] + + def for_objs(self, objs): + self._objs = objs + return self + + def for_schema(self, schema): + self._schema = schema + return self + + def with_grant_option(self, option): + self._grant_option = option + return self + + def for_whom(self, who): + self._for_whom = who + return self + + def as_who(self, target_roles): + self._as_who = target_roles + return self + + def set_what(self, what): + self._set_what = what + return self + + def for_objtype(self, objtype): + self._obj_type = objtype + return self + + def build(self): + if self._state == 'present': + self.build_present() + elif self._state == 'absent': + self.build_absent() + else: + self.build_absent() + return '\n'.join(self.query) + + def add_default_revoke(self): + for obj in self._objs: + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} REVOKE ALL ON {2} FROM {3};'.format(self._as_who, + self._schema, obj, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} REVOKE ALL ON {1} FROM {2};'.format(self._schema, obj, + self._for_whom)) + + def add_grant_option(self): + if self._grant_option: + if self._obj_type == 'group': + self.query[-1] += ' WITH ADMIN OPTION;' + else: + self.query[-1] += ' WITH GRANT OPTION;' + else: + self.query[-1] += ';' + if self._obj_type == 'group': + self.query.append('REVOKE ADMIN OPTION FOR {0} FROM {1};'.format(self._set_what, self._for_whom)) + elif not self._obj_type == 'default_privs': + self.query.append('REVOKE GRANT OPTION FOR {0} FROM {1};'.format(self._set_what, self._for_whom)) + + def add_default_priv(self): + for obj in self._objs: + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} GRANT {2} ON {3} TO {4}'.format(self._as_who, + self._schema, + self._set_what, + obj, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} GRANT {1} ON {2} TO {3}'.format(self._schema, + self._set_what, + obj, + self._for_whom)) + self.add_grant_option() + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} GRANT USAGE ON TYPES TO {2}'.format(self._as_who, + self._schema, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} GRANT USAGE ON TYPES TO {1}'.format(self._schema, self._for_whom)) + self.add_grant_option() + + def build_present(self): + if self._obj_type == 'default_privs': + self.add_default_revoke() + self.add_default_priv() + else: + self.query.append('GRANT {0} TO {1}'.format(self._set_what, self._for_whom)) + self.add_grant_option() + + def build_absent(self): + if self._obj_type == 'default_privs': + self.query = [] + for obj in ['TABLES', 'SEQUENCES', 'TYPES']: + if self._as_who: + self.query.append( + 'ALTER DEFAULT PRIVILEGES FOR ROLE {0} IN SCHEMA {1} REVOKE ALL ON {2} FROM {3};'.format(self._as_who, + self._schema, obj, + self._for_whom)) + else: + self.query.append( + 'ALTER DEFAULT PRIVILEGES IN SCHEMA {0} REVOKE ALL ON {1} FROM {2};'.format(self._schema, obj, + self._for_whom)) + else: + self.query.append('REVOKE {0} FROM 
{1};'.format(self._set_what, self._for_whom)) + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + database=dict(required=True, aliases=['db', 'login_db']), + state=dict(default='present', choices=['present', 'absent']), + privs=dict(required=False, aliases=['priv']), + type=dict(default='table', + choices=['table', + 'sequence', + 'function', + 'database', + 'schema', + 'language', + 'tablespace', + 'group', + 'default_privs', + 'foreign_data_wrapper', + 'foreign_server', + 'type', ]), + objs=dict(required=False, aliases=['obj']), + schema=dict(required=False), + roles=dict(required=True, aliases=['role']), + session_role=dict(required=False), + target_roles=dict(required=False), + grant_option=dict(required=False, type='bool', + aliases=['admin_option']), + host=dict(default='', aliases=['login_host']), + unix_socket=dict(default='', aliases=['login_unix_socket']), + login=dict(default='postgres', aliases=['login_user']), + password=dict(default='', aliases=['login_password'], no_log=True), + fail_on_role=dict(type='bool', default=True), + ) + + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True, + ) + + fail_on_role = module.params['fail_on_role'] + + # Create type object as namespace for module params + p = type('Params', (), module.params) + # param "schema": default, allowed depends on param "type" + if p.type in ['table', 'sequence', 'function', 'type', 'default_privs']: + p.schema = p.schema or 'public' + elif p.schema: + module.fail_json(msg='Argument "schema" is not allowed ' + 'for type "%s".' % p.type) + + # param "objs": default, required depends on param "type" + if p.type == 'database': + p.objs = p.objs or p.database + elif not p.objs: + module.fail_json(msg='Argument "objs" is required ' + 'for type "%s".' % p.type) + + # param "privs": allowed, required depends on param "type" + if p.type == 'group': + if p.privs: + module.fail_json(msg='Argument "privs" is not allowed ' + 'for type "group".') + elif not p.privs: + module.fail_json(msg='Argument "privs" is required ' + 'for type "%s".' 
% p.type) + + # Connect to Database + if not psycopg2: + module.fail_json(msg=missing_required_lib('psycopg2'), exception=PSYCOPG2_IMP_ERR) + try: + conn = Connection(p, module) + except psycopg2.Error as e: + module.fail_json(msg='Could not connect to database: %s' % to_native(e), exception=traceback.format_exc()) + except TypeError as e: + if 'sslrootcert' in e.args[0]: + module.fail_json(msg='Postgresql server must be at least version 8.4 to support sslrootcert') + module.fail_json(msg="unable to connect to database: %s" % to_native(e), exception=traceback.format_exc()) + except ValueError as e: + # We raise this when the psycopg library is too old + module.fail_json(msg=to_native(e)) + + if p.session_role: + try: + conn.cursor.execute('SET ROLE "%s"' % p.session_role) + except Exception as e: + module.fail_json(msg="Could not switch to role %s: %s" % (p.session_role, to_native(e)), exception=traceback.format_exc()) + + try: + # privs + if p.privs: + privs = frozenset(pr.upper() for pr in p.privs.split(',')) + if not privs.issubset(VALID_PRIVS): + module.fail_json(msg='Invalid privileges specified: %s' % privs.difference(VALID_PRIVS)) + else: + privs = None + # objs: + if p.type == 'table' and p.objs == 'ALL_IN_SCHEMA': + objs = conn.get_all_tables_in_schema(p.schema) + elif p.type == 'sequence' and p.objs == 'ALL_IN_SCHEMA': + objs = conn.get_all_sequences_in_schema(p.schema) + elif p.type == 'function' and p.objs == 'ALL_IN_SCHEMA': + objs = conn.get_all_functions_in_schema(p.schema) + elif p.type == 'default_privs': + if p.objs == 'ALL_DEFAULT': + objs = frozenset(VALID_DEFAULT_OBJS.keys()) + else: + objs = frozenset(obj.upper() for obj in p.objs.split(',')) + if not objs.issubset(VALID_DEFAULT_OBJS): + module.fail_json( + msg='Invalid Object set specified: %s' % objs.difference(VALID_DEFAULT_OBJS.keys())) + # Again, do we have valid privs specified for object type: + valid_objects_for_priv = frozenset(obj for obj in objs if privs.issubset(VALID_DEFAULT_OBJS[obj])) + if not valid_objects_for_priv == objs: + module.fail_json( + msg='Invalid priv specified. Valid object for priv: {0}. 
Objects: {1}'.format( + valid_objects_for_priv, objs)) + else: + objs = p.objs.split(',') + + # function signatures are encoded using ':' to separate args + if p.type == 'function': + objs = [obj.replace(':', ',') for obj in objs] + + # roles + if p.roles == 'PUBLIC': + roles = 'PUBLIC' + else: + roles = p.roles.split(',') + + if len(roles) == 1 and not role_exists(module, conn.cursor, roles[0]): + module.exit_json(changed=False) + + if fail_on_role: + module.fail_json(msg="Role '%s' does not exist" % roles[0].strip()) + + else: + module.warn("Role '%s' does not exist, nothing to do" % roles[0].strip()) + + # check if target_roles is set with type: default_privs + if p.target_roles and not p.type == 'default_privs': + module.warn('"target_roles" will be ignored ' + 'Argument "type: default_privs" is required for usage of "target_roles".') + + # target roles + if p.target_roles: + target_roles = p.target_roles.split(',') + else: + target_roles = None + + changed = conn.manipulate_privs( + obj_type=p.type, + privs=privs, + objs=objs, + roles=roles, + target_roles=target_roles, + state=p.state, + grant_option=p.grant_option, + schema_qualifier=p.schema, + fail_on_role=fail_on_role, + ) + + except Error as e: + conn.rollback() + module.fail_json(msg=e.message, exception=traceback.format_exc()) + + except psycopg2.Error as e: + conn.rollback() + module.fail_json(msg=to_native(e.message)) + + if module.check_mode: + conn.rollback() + else: + conn.commit() + module.exit_json(changed=changed, queries=executed_queries) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_query.py b/test/support/integration/plugins/modules/postgresql_query.py new file mode 100644 index 00000000..18d63e33 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_query.py @@ -0,0 +1,364 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2017, Felix Archambault +# Copyright: (c) 2019, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'supported_by': 'community', + 'status': ['preview'] +} + +DOCUMENTATION = r''' +--- +module: postgresql_query +short_description: Run PostgreSQL queries +description: +- Runs arbitrary PostgreSQL queries. +- Can run queries from SQL script files. +- Does not run against backup files. Use M(postgresql_db) with I(state=restore) + to run queries on files made by pg_dump/pg_dumpall utilities. +version_added: '2.8' +options: + query: + description: + - SQL query to run. Variables can be escaped with psycopg2 syntax + U(http://initd.org/psycopg/docs/usage.html). + type: str + positional_args: + description: + - List of values to be passed as positional arguments to the query. + When the value is a list, it will be converted to PostgreSQL array. + - Mutually exclusive with I(named_args). + type: list + elements: raw + named_args: + description: + - Dictionary of key-value arguments to pass to the query. + When the value is a list, it will be converted to PostgreSQL array. + - Mutually exclusive with I(positional_args). + type: dict + path_to_script: + description: + - Path to SQL script on the remote host. + - Returns result of the last query in the script. + - Mutually exclusive with I(query). 
+ type: path + session_role: + description: + - Switch to session_role after connecting. The specified session_role must + be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though + the session_role were the one that had logged in originally. + type: str + db: + description: + - Name of database to connect to and run queries against. + type: str + aliases: + - login_db + autocommit: + description: + - Execute in autocommit mode when the query can't be run inside a transaction block + (e.g., VACUUM). + - Mutually exclusive with I(check_mode). + type: bool + default: no + version_added: '2.9' + encoding: + description: + - Set the client encoding for the current session (e.g. C(UTF-8)). + - The default is the encoding defined by the database. + type: str + version_added: '2.10' +seealso: +- module: postgresql_db +author: +- Felix Archambault (@archf) +- Andrew Klychkov (@Andersson007) +- Will Rouesnel (@wrouesnel) +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Simple select query to acme db + postgresql_query: + db: acme + query: SELECT version() + +- name: Select query to db acme with positional arguments and non-default credentials + postgresql_query: + db: acme + login_user: django + login_password: mysecretpass + query: SELECT * FROM acme WHERE id = %s AND story = %s + positional_args: + - 1 + - test + +- name: Select query to test_db with named_args + postgresql_query: + db: test_db + query: SELECT * FROM test WHERE id = %(id_val)s AND story = %(story_val)s + named_args: + id_val: 1 + story_val: test + +- name: Insert query to test_table in db test_db + postgresql_query: + db: test_db + query: INSERT INTO test_table (id, story) VALUES (2, 'my_long_story') + +- name: Run queries from SQL script using UTF-8 client encoding for session + postgresql_query: + db: test_db + path_to_script: /var/lib/pgsql/test.sql + positional_args: + - 1 + encoding: UTF-8 + +- name: Example of using autocommit parameter + postgresql_query: + db: test_db + query: VACUUM + autocommit: yes + +- name: > + Insert data to the column of array type using positional_args. + Note that we use quotes here, the same as for passing JSON, etc. + postgresql_query: + query: INSERT INTO test_table (array_column) VALUES (%s) + positional_args: + - '{1,2,3}' + +# Pass list and string vars as positional_args +- name: Set vars + set_fact: + my_list: + - 1 + - 2 + - 3 + my_arr: '{1, 2, 3}' + +- name: Select from test table by passing positional_args as arrays + postgresql_query: + query: SELECT * FROM test_array_table WHERE arr_col1 = %s AND arr_col2 = %s + positional_args: + - '{{ my_list }}' + - '{{ my_arr|string }}' +''' + +RETURN = r''' +query: + description: Query that was tried to be executed. + returned: always + type: str + sample: 'SELECT * FROM bar' +statusmessage: + description: Attribute containing the message returned by the command. + returned: always + type: str + sample: 'INSERT 0 1' +query_result: + description: + - List of dictionaries in column:value form representing returned rows. + returned: changed + type: list + sample: [{"Column": "Value1"},{"Column": "Value2"}] +rowcount: + description: Number of affected rows. 
+ returned: changed + type: int + sample: 5 +''' + +try: + from psycopg2 import ProgrammingError as Psycopg2ProgrammingError + from psycopg2.extras import DictCursor +except ImportError: + # it is needed for checking 'no result to fetch' in main(), + # psycopg2 availability will be checked by connect_to_db() into + # ansible.module_utils.postgres + pass + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) +from ansible.module_utils._text import to_native +from ansible.module_utils.six import iteritems + + +# =========================================== +# Module execution. +# + +def list_to_pg_array(elem): + """Convert the passed list to PostgreSQL array + represented as a string. + + Args: + elem (list): List that needs to be converted. + + Returns: + elem (str): String representation of PostgreSQL array. + """ + elem = str(elem).strip('[]') + elem = '{' + elem + '}' + return elem + + +def convert_elements_to_pg_arrays(obj): + """Convert list elements of the passed object + to PostgreSQL arrays represented as strings. + + Args: + obj (dict or list): Object whose elements need to be converted. + + Returns: + obj (dict or list): Object with converted elements. + """ + if isinstance(obj, dict): + for (key, elem) in iteritems(obj): + if isinstance(elem, list): + obj[key] = list_to_pg_array(elem) + + elif isinstance(obj, list): + for i, elem in enumerate(obj): + if isinstance(elem, list): + obj[i] = list_to_pg_array(elem) + + return obj + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + query=dict(type='str'), + db=dict(type='str', aliases=['login_db']), + positional_args=dict(type='list', elements='raw'), + named_args=dict(type='dict'), + session_role=dict(type='str'), + path_to_script=dict(type='path'), + autocommit=dict(type='bool', default=False), + encoding=dict(type='str'), + ) + + module = AnsibleModule( + argument_spec=argument_spec, + mutually_exclusive=(('positional_args', 'named_args'),), + supports_check_mode=True, + ) + + query = module.params["query"] + positional_args = module.params["positional_args"] + named_args = module.params["named_args"] + path_to_script = module.params["path_to_script"] + autocommit = module.params["autocommit"] + encoding = module.params["encoding"] + + if autocommit and module.check_mode: + module.fail_json(msg="Using autocommit is mutually exclusive with check_mode") + + if path_to_script and query: + module.fail_json(msg="path_to_script is mutually exclusive with query") + + if positional_args: + positional_args = convert_elements_to_pg_arrays(positional_args) + + elif named_args: + named_args = convert_elements_to_pg_arrays(named_args) + + if path_to_script: + try: + with open(path_to_script, 'rb') as f: + query = to_native(f.read()) + except Exception as e: + module.fail_json(msg="Cannot read file '%s' : %s" % (path_to_script, to_native(e))) + + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=autocommit) + if encoding is not None: + db_connection.set_client_encoding(encoding) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + # Prepare args: + if module.params.get("positional_args"): + arguments = module.params["positional_args"] + elif module.params.get("named_args"): + arguments = module.params["named_args"] + else: + arguments = None + + # Set defaults: + changed = False + + # Execute query: + try: + 
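        # Editorial note, not part of the module: psycopg2 binds the prepared
        # "arguments" through its standard pyformat placeholders, so the two
        # supported call shapes look like
        #   cursor.execute("SELECT * FROM t WHERE id = %s AND story = %s", [1, 'test'])
        #   cursor.execute("SELECT * FROM t WHERE id = %(id_val)s", {'id_val': 1})
        # List values were rewritten by convert_elements_to_pg_arrays() above into
        # PostgreSQL array literals such as '{1,2,3}', which are passed here as
        # plain strings.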
cursor.execute(query, arguments) + except Exception as e: + if not autocommit: + db_connection.rollback() + + cursor.close() + db_connection.close() + module.fail_json(msg="Cannot execute SQL '%s' %s: %s" % (query, arguments, to_native(e))) + + statusmessage = cursor.statusmessage + rowcount = cursor.rowcount + + try: + query_result = [dict(row) for row in cursor.fetchall()] + except Psycopg2ProgrammingError as e: + if to_native(e) == 'no results to fetch': + query_result = {} + + except Exception as e: + module.fail_json(msg="Cannot fetch rows from cursor: %s" % to_native(e)) + + if 'SELECT' not in statusmessage: + if 'UPDATE' in statusmessage or 'INSERT' in statusmessage or 'DELETE' in statusmessage: + s = statusmessage.split() + if len(s) == 3: + if statusmessage.split()[2] != '0': + changed = True + + elif len(s) == 2: + if statusmessage.split()[1] != '0': + changed = True + + else: + changed = True + + else: + changed = True + + if module.check_mode: + db_connection.rollback() + else: + if not autocommit: + db_connection.commit() + + kw = dict( + changed=changed, + query=cursor.query, + statusmessage=statusmessage, + query_result=query_result, + rowcount=rowcount if rowcount >= 0 else 0, + ) + + cursor.close() + db_connection.close() + + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_set.py b/test/support/integration/plugins/modules/postgresql_set.py new file mode 100644 index 00000000..cfbdae64 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_set.py @@ -0,0 +1,434 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2018, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community' +} + +DOCUMENTATION = r''' +--- +module: postgresql_set +short_description: Change a PostgreSQL server configuration parameter +description: + - Allows to change a PostgreSQL server configuration parameter. + - The module uses ALTER SYSTEM command and applies changes by reload server configuration. + - ALTER SYSTEM is used for changing server configuration parameters across the entire database cluster. + - It can be more convenient and safe than the traditional method of manually editing the postgresql.conf file. + - ALTER SYSTEM writes the given parameter setting to the $PGDATA/postgresql.auto.conf file, + which is read in addition to postgresql.conf. + - The module allows to reset parameter to boot_val (cluster initial value) by I(reset=yes) or remove parameter + string from postgresql.auto.conf and reload I(value=default) (for settings with postmaster context restart is required). + - After change you can see in the ansible output the previous and + the new parameter value and other information using returned values and M(debug) module. +version_added: '2.8' +options: + name: + description: + - Name of PostgreSQL server parameter. + type: str + required: true + value: + description: + - Parameter value to set. + - To remove parameter string from postgresql.auto.conf and + reload the server configuration you must pass I(value=default). + With I(value=default) the playbook always returns changed is true. + type: str + reset: + description: + - Restore parameter to initial state (boot_val). 
Mutually exclusive with I(value). + type: bool + default: false + session_role: + description: + - Switch to session_role after connecting. The specified session_role must + be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though + the session_role were the one that had logged in originally. + type: str + db: + description: + - Name of database to connect. + type: str + aliases: + - login_db +notes: +- Supported version of PostgreSQL is 9.4 and later. +- Pay attention, change setting with 'postmaster' context can return changed is true + when actually nothing changes because the same value may be presented in + several different form, for example, 1024MB, 1GB, etc. However in pg_settings + system view it can be defined like 131072 number of 8kB pages. + The final check of the parameter value cannot compare it because the server was + not restarted and the value in pg_settings is not updated yet. +- For some parameters restart of PostgreSQL server is required. + See official documentation U(https://www.postgresql.org/docs/current/view-pg-settings.html). +seealso: +- module: postgresql_info +- name: PostgreSQL server configuration + description: General information about PostgreSQL server configuration. + link: https://www.postgresql.org/docs/current/runtime-config.html +- name: PostgreSQL view pg_settings reference + description: Complete reference of the pg_settings view documentation. + link: https://www.postgresql.org/docs/current/view-pg-settings.html +- name: PostgreSQL ALTER SYSTEM command reference + description: Complete reference of the ALTER SYSTEM command documentation. + link: https://www.postgresql.org/docs/current/sql-altersystem.html +author: +- Andrew Klychkov (@Andersson007) +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Restore wal_keep_segments parameter to initial state + postgresql_set: + name: wal_keep_segments + reset: yes + +# Set work_mem parameter to 32MB and show what's been changed and restart is required or not +# (output example: "msg": "work_mem 4MB >> 64MB restart_req: False") +- name: Set work mem parameter + postgresql_set: + name: work_mem + value: 32mb + register: set + +- debug: + msg: "{{ set.name }} {{ set.prev_val_pretty }} >> {{ set.value_pretty }} restart_req: {{ set.restart_required }}" + when: set.changed +# Ensure that the restart of PostgreSQL server must be required for some parameters. +# In this situation you see the same parameter in prev_val and value_prettyue, but 'changed=True' +# (If you passed the value that was different from the current server setting). + +- name: Set log_min_duration_statement parameter to 1 second + postgresql_set: + name: log_min_duration_statement + value: 1s + +- name: Set wal_log_hints parameter to default value (remove parameter from postgresql.auto.conf) + postgresql_set: + name: wal_log_hints + value: default +''' + +RETURN = r''' +name: + description: Name of PostgreSQL server parameter. + returned: always + type: str + sample: 'shared_buffers' +restart_required: + description: Information about parameter current state. + returned: always + type: bool + sample: true +prev_val_pretty: + description: Information about previous state of the parameter. + returned: always + type: str + sample: '4MB' +value_pretty: + description: Information about current state of the parameter. 
+ returned: always + type: str + sample: '64MB' +value: + description: + - Dictionary that contains the current parameter value (at the time of playbook finish). + - Pay attention that for real change some parameters restart of PostgreSQL server is required. + - Returns the current value in the check mode. + returned: always + type: dict + sample: { "value": 67108864, "unit": "b" } +context: + description: + - PostgreSQL setting context. + returned: always + type: str + sample: user +''' + +try: + from psycopg2.extras import DictCursor +except Exception: + # psycopg2 is checked by connect_to_db() + # from ansible.module_utils.postgres + pass + +from copy import deepcopy + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + postgres_common_argument_spec, +) +from ansible.module_utils._text import to_native + +PG_REQ_VER = 90400 + +# To allow to set value like 1mb instead of 1MB, etc: +POSSIBLE_SIZE_UNITS = ("mb", "gb", "tb") + +# =========================================== +# PostgreSQL module specific support methods. +# + + +def param_get(cursor, module, name): + query = ("SELECT name, setting, unit, context, boot_val " + "FROM pg_settings WHERE name = %(name)s") + try: + cursor.execute(query, {'name': name}) + info = cursor.fetchall() + cursor.execute("SHOW %s" % name) + val = cursor.fetchone() + + except Exception as e: + module.fail_json(msg="Unable to get %s value due to : %s" % (name, to_native(e))) + + raw_val = info[0][1] + unit = info[0][2] + context = info[0][3] + boot_val = info[0][4] + + if val[0] == 'True': + val[0] = 'on' + elif val[0] == 'False': + val[0] = 'off' + + if unit == 'kB': + if int(raw_val) > 0: + raw_val = int(raw_val) * 1024 + if int(boot_val) > 0: + boot_val = int(boot_val) * 1024 + + unit = 'b' + + elif unit == 'MB': + if int(raw_val) > 0: + raw_val = int(raw_val) * 1024 * 1024 + if int(boot_val) > 0: + boot_val = int(boot_val) * 1024 * 1024 + + unit = 'b' + + return (val[0], raw_val, unit, boot_val, context) + + +def pretty_to_bytes(pretty_val): + # The function returns a value in bytes + # if the value contains 'B', 'kB', 'MB', 'GB', 'TB'. + # Otherwise it returns the passed argument. + + val_in_bytes = None + + if 'kB' in pretty_val: + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 + + elif 'MB' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 * 1024 + + elif 'GB' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 * 1024 * 1024 + + elif 'TB' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part * 1024 * 1024 * 1024 * 1024 + + elif 'B' in pretty_val.upper(): + num_part = int(''.join(d for d in pretty_val if d.isdigit())) + val_in_bytes = num_part + + else: + return pretty_val + + return val_in_bytes + + +def param_set(cursor, module, name, value, context): + try: + if str(value).lower() == 'default': + query = "ALTER SYSTEM SET %s = DEFAULT" % name + else: + query = "ALTER SYSTEM SET %s = '%s'" % (name, value) + cursor.execute(query) + + if context != 'postmaster': + cursor.execute("SELECT pg_reload_conf()") + + except Exception as e: + module.fail_json(msg="Unable to get %s value due to : %s" % (name, to_native(e))) + + return True + + +# =========================================== +# Module execution. 
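Outside of Ansible, the sequence that param_set() performs above can be reproduced directly with psycopg2. The following is a minimal sketch only, assuming a superuser connection; the DSN and the work_mem value are illustrative and not part of the module. Autocommit is required because ALTER SYSTEM cannot run inside a transaction block.

import psycopg2

# Hypothetical connection; adjust the DSN for your environment.
conn = psycopg2.connect("dbname=postgres user=postgres")
conn.autocommit = True  # ALTER SYSTEM refuses to run inside a transaction block
cur = conn.cursor()

# Write the setting to postgresql.auto.conf ...
cur.execute("ALTER SYSTEM SET work_mem = '64MB'")
# ... then ask the server to re-read its configuration. This is enough for
# parameters whose context is not 'postmaster'; those need a restart instead.
cur.execute("SELECT pg_reload_conf()")

cur.close()
conn.close()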
+# + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + name=dict(type='str', required=True), + db=dict(type='str', aliases=['login_db']), + value=dict(type='str'), + reset=dict(type='bool'), + session_role=dict(type='str'), + ) + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True, + ) + + name = module.params["name"] + value = module.params["value"] + reset = module.params["reset"] + + # Allow to pass values like 1mb instead of 1MB, etc: + if value: + for unit in POSSIBLE_SIZE_UNITS: + if value[:-2].isdigit() and unit in value[-2:]: + value = value.upper() + + if value and reset: + module.fail_json(msg="%s: value and reset params are mutually exclusive" % name) + + if not value and not reset: + module.fail_json(msg="%s: at least one of value or reset param must be specified" % name) + + conn_params = get_conn_params(module, module.params, warn_db_default=False) + db_connection = connect_to_db(module, conn_params, autocommit=True) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + kw = {} + # Check server version (needs 9.4 or later): + ver = db_connection.server_version + if ver < PG_REQ_VER: + module.warn("PostgreSQL is %s version but %s or later is required" % (ver, PG_REQ_VER)) + kw = dict( + changed=False, + restart_required=False, + value_pretty="", + prev_val_pretty="", + value={"value": "", "unit": ""}, + ) + kw['name'] = name + db_connection.close() + module.exit_json(**kw) + + # Set default returned values: + restart_required = False + changed = False + kw['name'] = name + kw['restart_required'] = False + + # Get info about param state: + res = param_get(cursor, module, name) + current_value = res[0] + raw_val = res[1] + unit = res[2] + boot_val = res[3] + context = res[4] + + if value == 'True': + value = 'on' + elif value == 'False': + value = 'off' + + kw['prev_val_pretty'] = current_value + kw['value_pretty'] = deepcopy(kw['prev_val_pretty']) + kw['context'] = context + + # Do job + if context == "internal": + module.fail_json(msg="%s: cannot be changed (internal context). 
See " + "https://www.postgresql.org/docs/current/runtime-config-preset.html" % name) + + if context == "postmaster": + restart_required = True + + # If check_mode, just compare and exit: + if module.check_mode: + if pretty_to_bytes(value) == pretty_to_bytes(current_value): + kw['changed'] = False + + else: + kw['value_pretty'] = value + kw['changed'] = True + + # Anyway returns current raw value in the check_mode: + kw['value'] = dict( + value=raw_val, + unit=unit, + ) + kw['restart_required'] = restart_required + module.exit_json(**kw) + + # Set param: + if value and value != current_value: + changed = param_set(cursor, module, name, value, context) + + kw['value_pretty'] = value + + # Reset param: + elif reset: + if raw_val == boot_val: + # nothing to change, exit: + kw['value'] = dict( + value=raw_val, + unit=unit, + ) + module.exit_json(**kw) + + changed = param_set(cursor, module, name, boot_val, context) + + if restart_required: + module.warn("Restart of PostgreSQL is required for setting %s" % name) + + cursor.close() + db_connection.close() + + # Reconnect and recheck current value: + if context in ('sighup', 'superuser-backend', 'backend', 'superuser', 'user'): + db_connection = connect_to_db(module, conn_params, autocommit=True) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + res = param_get(cursor, module, name) + # f_ means 'final' + f_value = res[0] + f_raw_val = res[1] + + if raw_val == f_raw_val: + changed = False + + else: + changed = True + + kw['value_pretty'] = f_value + kw['value'] = dict( + value=f_raw_val, + unit=unit, + ) + + cursor.close() + db_connection.close() + + kw['changed'] = changed + kw['restart_required'] = restart_required + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_table.py b/test/support/integration/plugins/modules/postgresql_table.py new file mode 100644 index 00000000..3bef03b0 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_table.py @@ -0,0 +1,601 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019, Andrew Klychkov (@Andersson007) <aaklychkov@mail.ru> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community' +} + +DOCUMENTATION = r''' +--- +module: postgresql_table +short_description: Create, drop, or modify a PostgreSQL table +description: +- Allows to create, drop, rename, truncate a table, or change some table attributes. +version_added: '2.8' +options: + table: + description: + - Table name. + required: true + aliases: + - name + type: str + state: + description: + - The table state. I(state=absent) is mutually exclusive with I(tablespace), I(owner), I(unlogged), + I(like), I(including), I(columns), I(truncate), I(storage_params) and, I(rename). + type: str + default: present + choices: [ absent, present ] + tablespace: + description: + - Set a tablespace for the table. + required: false + type: str + owner: + description: + - Set a table owner. + type: str + unlogged: + description: + - Create an unlogged table. + type: bool + default: no + like: + description: + - Create a table like another table (with similar DDL). + Mutually exclusive with I(columns), I(rename), and I(truncate). 
+ type: str + including: + description: + - Keywords that are used with like parameter, may be DEFAULTS, CONSTRAINTS, INDEXES, STORAGE, COMMENTS or ALL. + Needs I(like) specified. Mutually exclusive with I(columns), I(rename), and I(truncate). + type: str + columns: + description: + - Columns that are needed. + type: list + elements: str + rename: + description: + - New table name. Mutually exclusive with I(tablespace), I(owner), + I(unlogged), I(like), I(including), I(columns), I(truncate), and I(storage_params). + type: str + truncate: + description: + - Truncate a table. Mutually exclusive with I(tablespace), I(owner), I(unlogged), + I(like), I(including), I(columns), I(rename), and I(storage_params). + type: bool + default: no + storage_params: + description: + - Storage parameters like fillfactor, autovacuum_vacuum_treshold, etc. + Mutually exclusive with I(rename) and I(truncate). + type: list + elements: str + db: + description: + - Name of database to connect and where the table will be created. + type: str + aliases: + - login_db + session_role: + description: + - Switch to session_role after connecting. + The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though + the session_role were the one that had logged in originally. + type: str + cascade: + description: + - Automatically drop objects that depend on the table (such as views). + Used with I(state=absent) only. + type: bool + default: no + version_added: '2.9' +notes: +- If you do not pass db parameter, tables will be created in the database + named postgres. +- PostgreSQL allows to create columnless table, so columns param is optional. +- Unlogged tables are available from PostgreSQL server version 9.1. +seealso: +- module: postgresql_sequence +- module: postgresql_idx +- module: postgresql_info +- module: postgresql_tablespace +- module: postgresql_owner +- module: postgresql_privs +- module: postgresql_copy +- name: CREATE TABLE reference + description: Complete reference of the CREATE TABLE command documentation. + link: https://www.postgresql.org/docs/current/sql-createtable.html +- name: ALTER TABLE reference + description: Complete reference of the ALTER TABLE command documentation. + link: https://www.postgresql.org/docs/current/sql-altertable.html +- name: DROP TABLE reference + description: Complete reference of the DROP TABLE command documentation. + link: https://www.postgresql.org/docs/current/sql-droptable.html +- name: PostgreSQL data types + description: Complete reference of the PostgreSQL data types documentation. 
+ link: https://www.postgresql.org/docs/current/datatype.html +author: +- Andrei Klychkov (@Andersson007) +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Create tbl2 in the acme database with the DDL like tbl1 with testuser as an owner + postgresql_table: + db: acme + name: tbl2 + like: tbl1 + owner: testuser + +- name: Create tbl2 in the acme database and tablespace ssd with the DDL like tbl1 including comments and indexes + postgresql_table: + db: acme + table: tbl2 + like: tbl1 + including: comments, indexes + tablespace: ssd + +- name: Create test_table with several columns in ssd tablespace with fillfactor=10 and autovacuum_analyze_threshold=1 + postgresql_table: + name: test_table + columns: + - id bigserial primary key + - num bigint + - stories text + tablespace: ssd + storage_params: + - fillfactor=10 + - autovacuum_analyze_threshold=1 + +- name: Create an unlogged table in schema acme + postgresql_table: + name: acme.useless_data + columns: waste_id int + unlogged: true + +- name: Rename table foo to bar + postgresql_table: + table: foo + rename: bar + +- name: Rename table foo from schema acme to bar + postgresql_table: + name: acme.foo + rename: bar + +- name: Set owner to someuser + postgresql_table: + name: foo + owner: someuser + +- name: Change tablespace of foo table to new_tablespace and set owner to new_user + postgresql_table: + name: foo + tablespace: new_tablespace + owner: new_user + +- name: Truncate table foo + postgresql_table: + name: foo + truncate: yes + +- name: Drop table foo from schema acme + postgresql_table: + name: acme.foo + state: absent + +- name: Drop table bar cascade + postgresql_table: + name: bar + state: absent + cascade: yes +''' + +RETURN = r''' +table: + description: Name of a table. + returned: always + type: str + sample: 'foo' +state: + description: Table state. + returned: always + type: str + sample: 'present' +owner: + description: Table owner. + returned: always + type: str + sample: 'postgres' +tablespace: + description: Tablespace. + returned: always + type: str + sample: 'ssd_tablespace' +queries: + description: List of executed queries. + returned: always + type: str + sample: [ 'CREATE TABLE "test_table" (id bigint)' ] +storage_params: + description: Storage parameters. + returned: always + type: list + sample: [ "fillfactor=100", "autovacuum_analyze_threshold=1" ] +''' + +try: + from psycopg2.extras import DictCursor +except ImportError: + # psycopg2 is checked by connect_to_db() + # from ansible.module_utils.postgres + pass + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.database import pg_quote_identifier +from ansible.module_utils.postgres import ( + connect_to_db, + exec_sql, + get_conn_params, + postgres_common_argument_spec, +) + + +# =========================================== +# PostgreSQL module specific support methods. +# + +class Table(object): + def __init__(self, name, module, cursor): + self.name = name + self.module = module + self.cursor = cursor + self.info = { + 'owner': '', + 'tblspace': '', + 'storage_params': [], + } + self.exists = False + self.__exists_in_db() + self.executed_queries = [] + + def get_info(self): + """Getter to refresh and get table info""" + self.__exists_in_db() + + def __exists_in_db(self): + """Check table exists and refresh info""" + if "." 
in self.name: + schema = self.name.split('.')[-2] + tblname = self.name.split('.')[-1] + else: + schema = 'public' + tblname = self.name + + query = ("SELECT t.tableowner, t.tablespace, c.reloptions " + "FROM pg_tables AS t " + "INNER JOIN pg_class AS c ON c.relname = t.tablename " + "INNER JOIN pg_namespace AS n ON c.relnamespace = n.oid " + "WHERE t.tablename = %(tblname)s " + "AND n.nspname = %(schema)s") + res = exec_sql(self, query, query_params={'tblname': tblname, 'schema': schema}, + add_to_executed=False) + if res: + self.exists = True + self.info = dict( + owner=res[0][0], + tblspace=res[0][1] if res[0][1] else '', + storage_params=res[0][2] if res[0][2] else [], + ) + + return True + else: + self.exists = False + return False + + def create(self, columns='', params='', tblspace='', + unlogged=False, owner=''): + """ + Create table. + If table exists, check passed args (params, tblspace, owner) and, + if they're different from current, change them. + Arguments: + params - storage params (passed by "WITH (...)" in SQL), + comma separated. + tblspace - tablespace. + owner - table owner. + unlogged - create unlogged table. + columns - column string (comma separated). + """ + name = pg_quote_identifier(self.name, 'table') + + changed = False + + if self.exists: + if tblspace == 'pg_default' and self.info['tblspace'] is None: + pass # Because they have the same meaning + elif tblspace and self.info['tblspace'] != tblspace: + self.set_tblspace(tblspace) + changed = True + + if owner and self.info['owner'] != owner: + self.set_owner(owner) + changed = True + + if params: + param_list = [p.strip(' ') for p in params.split(',')] + + new_param = False + for p in param_list: + if p not in self.info['storage_params']: + new_param = True + + if new_param: + self.set_stor_params(params) + changed = True + + if changed: + return True + return False + + query = "CREATE" + if unlogged: + query += " UNLOGGED TABLE %s" % name + else: + query += " TABLE %s" % name + + if columns: + query += " (%s)" % columns + else: + query += " ()" + + if params: + query += " WITH (%s)" % params + + if tblspace: + query += " TABLESPACE %s" % pg_quote_identifier(tblspace, 'database') + + if exec_sql(self, query, ddl=True): + changed = True + + if owner: + changed = self.set_owner(owner) + + return changed + + def create_like(self, src_table, including='', tblspace='', + unlogged=False, params='', owner=''): + """ + Create table like another table (with similar DDL). + Arguments: + src_table - source table. + including - corresponds to optional INCLUDING expression + in CREATE TABLE ... LIKE statement. + params - storage params (passed by "WITH (...)" in SQL), + comma separated. + tblspace - tablespace. + owner - table owner. + unlogged - create unlogged table. 
+ """ + changed = False + + name = pg_quote_identifier(self.name, 'table') + + query = "CREATE" + if unlogged: + query += " UNLOGGED TABLE %s" % name + else: + query += " TABLE %s" % name + + query += " (LIKE %s" % pg_quote_identifier(src_table, 'table') + + if including: + including = including.split(',') + for i in including: + query += " INCLUDING %s" % i + + query += ')' + + if params: + query += " WITH (%s)" % params + + if tblspace: + query += " TABLESPACE %s" % pg_quote_identifier(tblspace, 'database') + + if exec_sql(self, query, ddl=True): + changed = True + + if owner: + changed = self.set_owner(owner) + + return changed + + def truncate(self): + query = "TRUNCATE TABLE %s" % pg_quote_identifier(self.name, 'table') + return exec_sql(self, query, ddl=True) + + def rename(self, newname): + query = "ALTER TABLE %s RENAME TO %s" % (pg_quote_identifier(self.name, 'table'), + pg_quote_identifier(newname, 'table')) + return exec_sql(self, query, ddl=True) + + def set_owner(self, username): + query = "ALTER TABLE %s OWNER TO %s" % (pg_quote_identifier(self.name, 'table'), + pg_quote_identifier(username, 'role')) + return exec_sql(self, query, ddl=True) + + def drop(self, cascade=False): + if not self.exists: + return False + + query = "DROP TABLE %s" % pg_quote_identifier(self.name, 'table') + if cascade: + query += " CASCADE" + return exec_sql(self, query, ddl=True) + + def set_tblspace(self, tblspace): + query = "ALTER TABLE %s SET TABLESPACE %s" % (pg_quote_identifier(self.name, 'table'), + pg_quote_identifier(tblspace, 'database')) + return exec_sql(self, query, ddl=True) + + def set_stor_params(self, params): + query = "ALTER TABLE %s SET (%s)" % (pg_quote_identifier(self.name, 'table'), params) + return exec_sql(self, query, ddl=True) + + +# =========================================== +# Module execution. 
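Table.__exists_in_db() above treats a dotted table name as schema.table and otherwise assumes the public schema. Below is a minimal standalone sketch of that convention; split_table_name is an illustrative helper, not part of the module.

def split_table_name(name):
    # "acme.foo" -> ("acme", "foo"); "foo" -> ("public", "foo")
    if "." in name:
        parts = name.split(".")
        return parts[-2], parts[-1]
    return "public", name


assert split_table_name("acme.foo") == ("acme", "foo")
assert split_table_name("foo") == ("public", "foo")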
+# + + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + table=dict(type='str', required=True, aliases=['name']), + state=dict(type='str', default="present", choices=["absent", "present"]), + db=dict(type='str', default='', aliases=['login_db']), + tablespace=dict(type='str'), + owner=dict(type='str'), + unlogged=dict(type='bool', default=False), + like=dict(type='str'), + including=dict(type='str'), + rename=dict(type='str'), + truncate=dict(type='bool', default=False), + columns=dict(type='list', elements='str'), + storage_params=dict(type='list', elements='str'), + session_role=dict(type='str'), + cascade=dict(type='bool', default=False), + ) + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True, + ) + + table = module.params["table"] + state = module.params["state"] + tablespace = module.params["tablespace"] + owner = module.params["owner"] + unlogged = module.params["unlogged"] + like = module.params["like"] + including = module.params["including"] + newname = module.params["rename"] + storage_params = module.params["storage_params"] + truncate = module.params["truncate"] + columns = module.params["columns"] + cascade = module.params["cascade"] + + if state == 'present' and cascade: + module.warn("cascade=true is ignored when state=present") + + # Check mutual exclusive parameters: + if state == 'absent' and (truncate or newname or columns or tablespace or like or storage_params or unlogged or owner or including): + module.fail_json(msg="%s: state=absent is mutually exclusive with: " + "truncate, rename, columns, tablespace, " + "including, like, storage_params, unlogged, owner" % table) + + if truncate and (newname or columns or like or unlogged or storage_params or owner or tablespace or including): + module.fail_json(msg="%s: truncate is mutually exclusive with: " + "rename, columns, like, unlogged, including, " + "storage_params, owner, tablespace" % table) + + if newname and (columns or like or unlogged or storage_params or owner or tablespace or including): + module.fail_json(msg="%s: rename is mutually exclusive with: " + "columns, like, unlogged, including, " + "storage_params, owner, tablespace" % table) + + if like and columns: + module.fail_json(msg="%s: like and columns params are mutually exclusive" % table) + if including and not like: + module.fail_json(msg="%s: including param needs like param specified" % table) + + conn_params = get_conn_params(module, module.params) + db_connection = connect_to_db(module, conn_params, autocommit=False) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + if storage_params: + storage_params = ','.join(storage_params) + + if columns: + columns = ','.join(columns) + + ############## + # Do main job: + table_obj = Table(table, module, cursor) + + # Set default returned values: + changed = False + kw = {} + kw['table'] = table + kw['state'] = '' + if table_obj.exists: + kw = dict( + table=table, + state='present', + owner=table_obj.info['owner'], + tablespace=table_obj.info['tblspace'], + storage_params=table_obj.info['storage_params'], + ) + + if state == 'absent': + changed = table_obj.drop(cascade=cascade) + + elif truncate: + changed = table_obj.truncate() + + elif newname: + changed = table_obj.rename(newname) + q = table_obj.executed_queries + table_obj = Table(newname, module, cursor) + table_obj.executed_queries = q + + elif state == 'present' and not like: + changed = table_obj.create(columns, storage_params, + tablespace, unlogged, owner) + + elif 
state == 'present' and like: + changed = table_obj.create_like(like, including, tablespace, + unlogged, storage_params) + + if changed: + if module.check_mode: + db_connection.rollback() + else: + db_connection.commit() + + # Refresh table info for RETURN. + # Note, if table has been renamed, it gets info by newname: + table_obj.get_info() + db_connection.commit() + if table_obj.exists: + kw = dict( + table=table, + state='present', + owner=table_obj.info['owner'], + tablespace=table_obj.info['tblspace'], + storage_params=table_obj.info['storage_params'], + ) + else: + # We just change the table state here + # to keep other information about the dropped table: + kw['state'] = 'absent' + + kw['queries'] = table_obj.executed_queries + kw['changed'] = changed + db_connection.close() + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/postgresql_user.py b/test/support/integration/plugins/modules/postgresql_user.py new file mode 100644 index 00000000..10afd0a0 --- /dev/null +++ b/test/support/integration/plugins/modules/postgresql_user.py @@ -0,0 +1,927 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'community' +} + +DOCUMENTATION = r''' +--- +module: postgresql_user +short_description: Add or remove a user (role) from a PostgreSQL server instance +description: +- Adds or removes a user (role) from a PostgreSQL server instance + ("cluster" in PostgreSQL terminology) and, optionally, + grants the user access to an existing database or tables. +- A user is a role with login privilege. +- The fundamental function of the module is to create, or delete, users from + a PostgreSQL instances. Privilege assignment, or removal, is an optional + step, which works on one database at a time. This allows for the module to + be called several times in the same module to modify the permissions on + different databases, or to grant permissions to already existing users. +- A user cannot be removed until all the privileges have been stripped from + the user. In such situation, if the module tries to remove the user it + will fail. To avoid this from happening the fail_on_user option signals + the module to try to remove the user, but if not possible keep going; the + module will report if changes happened and separately if the user was + removed or not. +version_added: '0.6' +options: + name: + description: + - Name of the user (role) to add or remove. + type: str + required: true + aliases: + - user + password: + description: + - Set the user's password, before 1.4 this was required. + - Password can be passed unhashed or hashed (MD5-hashed). + - Unhashed password will automatically be hashed when saved into the + database if C(encrypted) parameter is set, otherwise it will be save in + plain text format. + - When passing a hashed password it must be generated with the format + C('str["md5"] + md5[ password + username ]'), resulting in a total of + 35 characters. An easy way to do this is C(echo "md5$(echo -n + 'verysecretpasswordJOE' | md5sum | awk '{print $1}')"). + - Note that if the provided password string is already in MD5-hashed + format, then it is used as-is, regardless of C(encrypted) parameter. 
+ type: str + db: + description: + - Name of database to connect to and where user's permissions will be granted. + type: str + aliases: + - login_db + fail_on_user: + description: + - If C(yes), fail when user (role) can't be removed. Otherwise just log and continue. + default: 'yes' + type: bool + aliases: + - fail_on_role + priv: + description: + - "Slash-separated PostgreSQL privileges string: C(priv1/priv2), where + privileges can be defined for database ( allowed options - 'CREATE', + 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL'. For example C(CONNECT) ) or + for table ( allowed options - 'SELECT', 'INSERT', 'UPDATE', 'DELETE', + 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL'. For example + C(table:SELECT) ). Mixed example of this string: + C(CONNECT/CREATE/table1:SELECT/table2:INSERT)." + type: str + role_attr_flags: + description: + - "PostgreSQL user attributes string in the format: CREATEDB,CREATEROLE,SUPERUSER." + - Note that '[NO]CREATEUSER' is deprecated. + - To create a simple role for using it like a group, use C(NOLOGIN) flag. + type: str + choices: [ '[NO]SUPERUSER', '[NO]CREATEROLE', '[NO]CREATEDB', + '[NO]INHERIT', '[NO]LOGIN', '[NO]REPLICATION', '[NO]BYPASSRLS' ] + session_role: + version_added: '2.8' + description: + - Switch to session_role after connecting. + - The specified session_role must be a role that the current login_user is a member of. + - Permissions checking for SQL commands is carried out as though the session_role were the one that had logged in originally. + type: str + state: + description: + - The user (role) state. + type: str + default: present + choices: [ absent, present ] + encrypted: + description: + - Whether the password is stored hashed in the database. + - Passwords can be passed already hashed or unhashed, and postgresql + ensures the stored password is hashed when C(encrypted) is set. + - "Note: Postgresql 10 and newer doesn't support unhashed passwords." + - Previous to Ansible 2.6, this was C(no) by default. + default: 'yes' + type: bool + version_added: '1.4' + expires: + description: + - The date at which the user's password is to expire. + - If set to C('infinity'), user's password never expire. + - Note that this value should be a valid SQL date and time type. + type: str + version_added: '1.4' + no_password_changes: + description: + - If C(yes), don't inspect database for password changes. Effective when + C(pg_authid) is not accessible (such as AWS RDS). Otherwise, make + password changes as necessary. + default: 'no' + type: bool + version_added: '2.0' + conn_limit: + description: + - Specifies the user (role) connection limit. + type: int + version_added: '2.4' + ssl_mode: + description: + - Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated with the server. + - See https://www.postgresql.org/docs/current/static/libpq-ssl.html for more information on the modes. + - Default of C(prefer) matches libpq default. + type: str + default: prefer + choices: [ allow, disable, prefer, require, verify-ca, verify-full ] + version_added: '2.3' + ca_cert: + description: + - Specifies the name of a file containing SSL certificate authority (CA) certificate(s). + - If the file exists, the server's certificate will be verified to be signed by one of these authorities. + type: str + aliases: [ ssl_rootcert ] + version_added: '2.3' + groups: + description: + - The list of groups (roles) that need to be granted to the user. 
+ type: list + elements: str + version_added: '2.9' + comment: + description: + - Add a comment on the user (equal to the COMMENT ON ROLE statement result). + type: str + version_added: '2.10' +notes: +- The module creates a user (role) with login privilege by default. + Use NOLOGIN role_attr_flags to change this behaviour. +- If you specify PUBLIC as the user (role), then the privilege changes will apply to all users (roles). + You may not specify password or role_attr_flags when the PUBLIC user is specified. +seealso: +- module: postgresql_privs +- module: postgresql_membership +- module: postgresql_owner +- name: PostgreSQL database roles + description: Complete reference of the PostgreSQL database roles documentation. + link: https://www.postgresql.org/docs/current/user-manag.html +author: +- Ansible Core Team +extends_documentation_fragment: postgres +''' + +EXAMPLES = r''' +- name: Connect to acme database, create django user, and grant access to database and products table + postgresql_user: + db: acme + name: django + password: ceec4eif7ya + priv: "CONNECT/products:ALL" + expires: "Jan 31 2020" + +- name: Add a comment on django user + postgresql_user: + db: acme + name: django + comment: This is a test user + +# Connect to default database, create rails user, set its password (MD5-hashed), +# and grant privilege to create other databases and demote rails from super user status if user exists +- name: Create rails user, set MD5-hashed password, grant privs + postgresql_user: + name: rails + password: md59543f1d82624df2b31672ec0f7050460 + role_attr_flags: CREATEDB,NOSUPERUSER + +- name: Connect to acme database and remove test user privileges from there + postgresql_user: + db: acme + name: test + priv: "ALL/products:ALL" + state: absent + fail_on_user: no + +- name: Connect to test database, remove test user from cluster + postgresql_user: + db: test + name: test + priv: ALL + state: absent + +- name: Connect to acme database and set user's password with no expire date + postgresql_user: + db: acme + name: django + password: mysupersecretword + priv: "CONNECT/products:ALL" + expires: infinity + +# Example privileges string format +# INSERT,UPDATE/table:SELECT/anothertable:ALL + +- name: Connect to test database and remove an existing user's password + postgresql_user: + db: test + user: test + password: "" + +- name: Create user test and grant group user_ro and user_rw to it + postgresql_user: + name: test + groups: + - user_ro + - user_rw +''' + +RETURN = r''' +queries: + description: List of executed queries. 
+ returned: always + type: list + sample: ['CREATE USER "alice"', 'GRANT CONNECT ON DATABASE "acme" TO "alice"'] + version_added: '2.8' +''' + +import itertools +import re +import traceback +from hashlib import md5 + +try: + import psycopg2 + from psycopg2.extras import DictCursor +except ImportError: + # psycopg2 is checked by connect_to_db() + # from ansible.module_utils.postgres + pass + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils.database import pg_quote_identifier, SQLParseError +from ansible.module_utils.postgres import ( + connect_to_db, + get_conn_params, + PgMembership, + postgres_common_argument_spec, +) +from ansible.module_utils._text import to_bytes, to_native +from ansible.module_utils.six import iteritems + + +FLAGS = ('SUPERUSER', 'CREATEROLE', 'CREATEDB', 'INHERIT', 'LOGIN', 'REPLICATION') +FLAGS_BY_VERSION = {'BYPASSRLS': 90500} + +VALID_PRIVS = dict(table=frozenset(('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER', 'ALL')), + database=frozenset( + ('CREATE', 'CONNECT', 'TEMPORARY', 'TEMP', 'ALL')), + ) + +# map to cope with idiosyncracies of SUPERUSER and LOGIN +PRIV_TO_AUTHID_COLUMN = dict(SUPERUSER='rolsuper', CREATEROLE='rolcreaterole', + CREATEDB='rolcreatedb', INHERIT='rolinherit', LOGIN='rolcanlogin', + REPLICATION='rolreplication', BYPASSRLS='rolbypassrls') + +executed_queries = [] + + +class InvalidFlagsError(Exception): + pass + + +class InvalidPrivsError(Exception): + pass + +# =========================================== +# PostgreSQL module specific support methods. +# + + +def user_exists(cursor, user): + # The PUBLIC user is a special case that is always there + if user == 'PUBLIC': + return True + query = "SELECT rolname FROM pg_roles WHERE rolname=%(user)s" + cursor.execute(query, {'user': user}) + return cursor.rowcount > 0 + + +def user_add(cursor, user, password, role_attr_flags, encrypted, expires, conn_limit): + """Create a new database user (role).""" + # Note: role_attr_flags escaped by parse_role_attrs and encrypted is a + # literal + query_password_data = dict(password=password, expires=expires) + query = ['CREATE USER "%(user)s"' % + {"user": user}] + if password is not None and password != '': + query.append("WITH %(crypt)s" % {"crypt": encrypted}) + query.append("PASSWORD %(password)s") + if expires is not None: + query.append("VALID UNTIL %(expires)s") + if conn_limit is not None: + query.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit}) + query.append(role_attr_flags) + query = ' '.join(query) + executed_queries.append(query) + cursor.execute(query, query_password_data) + return True + + +def user_should_we_change_password(current_role_attrs, user, password, encrypted): + """Check if we should change the user's password. + + Compare the proposed password with the existing one, comparing + hashes if encrypted. If we can't access it assume yes. + """ + + if current_role_attrs is None: + # on some databases, E.g. AWS RDS instances, there is no access to + # the pg_authid relation to check the pre-existing password, so we + # just assume password is different + return True + + # Do we actually need to do anything? + pwchanging = False + if password is not None: + # Empty password means that the role shouldn't have a password, which + # means we need to check if the current password is None. 
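+ # (Illustrative note, for context only: PostgreSQL's legacy md5 role passwords
+ # are the text 'md5' followed by hex(md5(password + username)); e.g. for the
+ # documented example user 'django' with password 'ceec4eif7ya' that would be
+ # 'md5' + md5(b'ceec4eif7ya' + b'django').hexdigest(). The comparisons below
+ # rely on that layout.)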
+ if password == '': + if current_role_attrs['rolpassword'] is not None: + pwchanging = True + # 32: MD5 hashes are represented as a sequence of 32 hexadecimal digits + # 3: The size of the 'md5' prefix + # When the provided password looks like a MD5-hash, value of + # 'encrypted' is ignored. + elif (password.startswith('md5') and len(password) == 32 + 3) or encrypted == 'UNENCRYPTED': + if password != current_role_attrs['rolpassword']: + pwchanging = True + elif encrypted == 'ENCRYPTED': + hashed_password = 'md5{0}'.format(md5(to_bytes(password) + to_bytes(user)).hexdigest()) + if hashed_password != current_role_attrs['rolpassword']: + pwchanging = True + + return pwchanging + + +def user_alter(db_connection, module, user, password, role_attr_flags, encrypted, expires, no_password_changes, conn_limit): + """Change user password and/or attributes. Return True if changed, False otherwise.""" + changed = False + + cursor = db_connection.cursor(cursor_factory=DictCursor) + # Note: role_attr_flags escaped by parse_role_attrs and encrypted is a + # literal + if user == 'PUBLIC': + if password is not None: + module.fail_json(msg="cannot change the password for PUBLIC user") + elif role_attr_flags != '': + module.fail_json(msg="cannot change the role_attr_flags for PUBLIC user") + else: + return False + + # Handle passwords. + if not no_password_changes and (password is not None or role_attr_flags != '' or expires is not None or conn_limit is not None): + # Select password and all flag-like columns in order to verify changes. + try: + select = "SELECT * FROM pg_authid where rolname=%(user)s" + cursor.execute(select, {"user": user}) + # Grab current role attributes. + current_role_attrs = cursor.fetchone() + except psycopg2.ProgrammingError: + current_role_attrs = None + db_connection.rollback() + + pwchanging = user_should_we_change_password(current_role_attrs, user, password, encrypted) + + if current_role_attrs is None: + try: + # AWS RDS instances does not allow user to access pg_authid + # so try to get current_role_attrs from pg_roles tables + select = "SELECT * FROM pg_roles where rolname=%(user)s" + cursor.execute(select, {"user": user}) + # Grab current role attributes from pg_roles + current_role_attrs = cursor.fetchone() + except psycopg2.ProgrammingError as e: + db_connection.rollback() + module.fail_json(msg="Failed to get role details for current user %s: %s" % (user, e)) + + role_attr_flags_changing = False + if role_attr_flags: + role_attr_flags_dict = {} + for r in role_attr_flags.split(' '): + if r.startswith('NO'): + role_attr_flags_dict[r.replace('NO', '', 1)] = False + else: + role_attr_flags_dict[r] = True + + for role_attr_name, role_attr_value in role_attr_flags_dict.items(): + if current_role_attrs[PRIV_TO_AUTHID_COLUMN[role_attr_name]] != role_attr_value: + role_attr_flags_changing = True + + if expires is not None: + cursor.execute("SELECT %s::timestamptz;", (expires,)) + expires_with_tz = cursor.fetchone()[0] + expires_changing = expires_with_tz != current_role_attrs.get('rolvaliduntil') + else: + expires_changing = False + + conn_limit_changing = (conn_limit is not None and conn_limit != current_role_attrs['rolconnlimit']) + + if not pwchanging and not role_attr_flags_changing and not expires_changing and not conn_limit_changing: + return False + + alter = ['ALTER USER "%(user)s"' % {"user": user}] + if pwchanging: + if password != '': + alter.append("WITH %(crypt)s" % {"crypt": encrypted}) + alter.append("PASSWORD %(password)s") + else: + alter.append("WITH PASSWORD 
NULL") + alter.append(role_attr_flags) + elif role_attr_flags: + alter.append('WITH %s' % role_attr_flags) + if expires is not None: + alter.append("VALID UNTIL %(expires)s") + if conn_limit is not None: + alter.append("CONNECTION LIMIT %(conn_limit)s" % {"conn_limit": conn_limit}) + + query_password_data = dict(password=password, expires=expires) + try: + cursor.execute(' '.join(alter), query_password_data) + changed = True + except psycopg2.InternalError as e: + if e.pgcode == '25006': + # Handle errors due to read-only transactions indicated by pgcode 25006 + # ERROR: cannot execute ALTER ROLE in a read-only transaction + changed = False + module.fail_json(msg=e.pgerror, exception=traceback.format_exc()) + return changed + else: + raise psycopg2.InternalError(e) + except psycopg2.NotSupportedError as e: + module.fail_json(msg=e.pgerror, exception=traceback.format_exc()) + + elif no_password_changes and role_attr_flags != '': + # Grab role information from pg_roles instead of pg_authid + select = "SELECT * FROM pg_roles where rolname=%(user)s" + cursor.execute(select, {"user": user}) + # Grab current role attributes. + current_role_attrs = cursor.fetchone() + + role_attr_flags_changing = False + + if role_attr_flags: + role_attr_flags_dict = {} + for r in role_attr_flags.split(' '): + if r.startswith('NO'): + role_attr_flags_dict[r.replace('NO', '', 1)] = False + else: + role_attr_flags_dict[r] = True + + for role_attr_name, role_attr_value in role_attr_flags_dict.items(): + if current_role_attrs[PRIV_TO_AUTHID_COLUMN[role_attr_name]] != role_attr_value: + role_attr_flags_changing = True + + if not role_attr_flags_changing: + return False + + alter = ['ALTER USER "%(user)s"' % + {"user": user}] + if role_attr_flags: + alter.append('WITH %s' % role_attr_flags) + + try: + cursor.execute(' '.join(alter)) + except psycopg2.InternalError as e: + if e.pgcode == '25006': + # Handle errors due to read-only transactions indicated by pgcode 25006 + # ERROR: cannot execute ALTER ROLE in a read-only transaction + changed = False + module.fail_json(msg=e.pgerror, exception=traceback.format_exc()) + return changed + else: + raise psycopg2.InternalError(e) + + # Grab new role attributes. + cursor.execute(select, {"user": user}) + new_role_attrs = cursor.fetchone() + + # Detect any differences between current_ and new_role_attrs. + changed = current_role_attrs != new_role_attrs + + return changed + + +def user_delete(cursor, user): + """Try to remove a user. Returns True if successful otherwise False""" + cursor.execute("SAVEPOINT ansible_pgsql_user_delete") + try: + query = 'DROP USER "%s"' % user + executed_queries.append(query) + cursor.execute(query) + except Exception: + cursor.execute("ROLLBACK TO SAVEPOINT ansible_pgsql_user_delete") + cursor.execute("RELEASE SAVEPOINT ansible_pgsql_user_delete") + return False + + cursor.execute("RELEASE SAVEPOINT ansible_pgsql_user_delete") + return True + + +def has_table_privileges(cursor, user, table, privs): + """ + Return the difference between the privileges that a user already has and + the privileges that they desire to have. 
+ + :returns: tuple of: + * privileges that they have and were requested + * privileges they currently hold but were not requested + * privileges requested that they do not hold + """ + cur_privs = get_table_privileges(cursor, user, table) + have_currently = cur_privs.intersection(privs) + other_current = cur_privs.difference(privs) + desired = privs.difference(cur_privs) + return (have_currently, other_current, desired) + + +def get_table_privileges(cursor, user, table): + if '.' in table: + schema, table = table.split('.', 1) + else: + schema = 'public' + query = ("SELECT privilege_type FROM information_schema.role_table_grants " + "WHERE grantee=%(user)s AND table_name=%(table)s AND table_schema=%(schema)s") + cursor.execute(query, {'user': user, 'table': table, 'schema': schema}) + return frozenset([x[0] for x in cursor.fetchall()]) + + +def grant_table_privileges(cursor, user, table, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + query = 'GRANT %s ON TABLE %s TO "%s"' % ( + privs, pg_quote_identifier(table, 'table'), user) + executed_queries.append(query) + cursor.execute(query) + + +def revoke_table_privileges(cursor, user, table, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + query = 'REVOKE %s ON TABLE %s FROM "%s"' % ( + privs, pg_quote_identifier(table, 'table'), user) + executed_queries.append(query) + cursor.execute(query) + + +def get_database_privileges(cursor, user, db): + priv_map = { + 'C': 'CREATE', + 'T': 'TEMPORARY', + 'c': 'CONNECT', + } + query = 'SELECT datacl FROM pg_database WHERE datname = %s' + cursor.execute(query, (db,)) + datacl = cursor.fetchone()[0] + if datacl is None: + return set() + r = re.search(r'%s\\?"?=(C?T?c?)/[^,]+,?' % user, datacl) + if r is None: + return set() + o = set() + for v in r.group(1): + o.add(priv_map[v]) + return normalize_privileges(o, 'database') + + +def has_database_privileges(cursor, user, db, privs): + """ + Return the difference between the privileges that a user already has and + the privileges that they desire to have. 
+ + :returns: tuple of: + * privileges that they have and were requested + * privileges they currently hold but were not requested + * privileges requested that they do not hold + """ + cur_privs = get_database_privileges(cursor, user, db) + have_currently = cur_privs.intersection(privs) + other_current = cur_privs.difference(privs) + desired = privs.difference(cur_privs) + return (have_currently, other_current, desired) + + +def grant_database_privileges(cursor, user, db, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + if user == "PUBLIC": + query = 'GRANT %s ON DATABASE %s TO PUBLIC' % ( + privs, pg_quote_identifier(db, 'database')) + else: + query = 'GRANT %s ON DATABASE %s TO "%s"' % ( + privs, pg_quote_identifier(db, 'database'), user) + + executed_queries.append(query) + cursor.execute(query) + + +def revoke_database_privileges(cursor, user, db, privs): + # Note: priv escaped by parse_privs + privs = ', '.join(privs) + if user == "PUBLIC": + query = 'REVOKE %s ON DATABASE %s FROM PUBLIC' % ( + privs, pg_quote_identifier(db, 'database')) + else: + query = 'REVOKE %s ON DATABASE %s FROM "%s"' % ( + privs, pg_quote_identifier(db, 'database'), user) + + executed_queries.append(query) + cursor.execute(query) + + +def revoke_privileges(cursor, user, privs): + if privs is None: + return False + + revoke_funcs = dict(table=revoke_table_privileges, + database=revoke_database_privileges) + check_funcs = dict(table=has_table_privileges, + database=has_database_privileges) + + changed = False + for type_ in privs: + for name, privileges in iteritems(privs[type_]): + # Check that any of the privileges requested to be removed are + # currently granted to the user + differences = check_funcs[type_](cursor, user, name, privileges) + if differences[0]: + revoke_funcs[type_](cursor, user, name, privileges) + changed = True + return changed + + +def grant_privileges(cursor, user, privs): + if privs is None: + return False + + grant_funcs = dict(table=grant_table_privileges, + database=grant_database_privileges) + check_funcs = dict(table=has_table_privileges, + database=has_database_privileges) + + changed = False + for type_ in privs: + for name, privileges in iteritems(privs[type_]): + # Check that any of the privileges requested for the user are + # currently missing + differences = check_funcs[type_](cursor, user, name, privileges) + if differences[2]: + grant_funcs[type_](cursor, user, name, privileges) + changed = True + return changed + + +def parse_role_attrs(cursor, role_attr_flags): + """ + Parse role attributes string for user creation. + Format: + + attributes[,attributes,...] + + Where: + + attributes := CREATEDB,CREATEROLE,NOSUPERUSER,... + [ "[NO]SUPERUSER","[NO]CREATEROLE", "[NO]CREATEDB", + "[NO]INHERIT", "[NO]LOGIN", "[NO]REPLICATION", + "[NO]BYPASSRLS" ] + + Note: "[NO]BYPASSRLS" role attribute introduced in 9.5 + Note: "[NO]CREATEUSER" role attribute is deprecated. 
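+ Illustrative example: role_attr_flags='CREATEDB,NOLOGIN' is parsed into the
+ string 'CREATEDB NOLOGIN' (order is not guaranteed, since the flags are kept
+ in a frozenset), ready to be appended to a CREATE USER / ALTER USER statement.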
+ + """ + flags = frozenset(role.upper() for role in role_attr_flags.split(',') if role) + + valid_flags = frozenset(itertools.chain(FLAGS, get_valid_flags_by_version(cursor))) + valid_flags = frozenset(itertools.chain(valid_flags, ('NO%s' % flag for flag in valid_flags))) + + if not flags.issubset(valid_flags): + raise InvalidFlagsError('Invalid role_attr_flags specified: %s' % + ' '.join(flags.difference(valid_flags))) + + return ' '.join(flags) + + +def normalize_privileges(privs, type_): + new_privs = set(privs) + if 'ALL' in new_privs: + new_privs.update(VALID_PRIVS[type_]) + new_privs.remove('ALL') + if 'TEMP' in new_privs: + new_privs.add('TEMPORARY') + new_privs.remove('TEMP') + + return new_privs + + +def parse_privs(privs, db): + """ + Parse privilege string to determine permissions for database db. + Format: + + privileges[/privileges/...] + + Where: + + privileges := DATABASE_PRIVILEGES[,DATABASE_PRIVILEGES,...] | + TABLE_NAME:TABLE_PRIVILEGES[,TABLE_PRIVILEGES,...] + """ + if privs is None: + return privs + + o_privs = { + 'database': {}, + 'table': {} + } + for token in privs.split('/'): + if ':' not in token: + type_ = 'database' + name = db + priv_set = frozenset(x.strip().upper() + for x in token.split(',') if x.strip()) + else: + type_ = 'table' + name, privileges = token.split(':', 1) + priv_set = frozenset(x.strip().upper() + for x in privileges.split(',') if x.strip()) + + if not priv_set.issubset(VALID_PRIVS[type_]): + raise InvalidPrivsError('Invalid privs specified for %s: %s' % + (type_, ' '.join(priv_set.difference(VALID_PRIVS[type_])))) + + priv_set = normalize_privileges(priv_set, type_) + o_privs[type_][name] = priv_set + + return o_privs + + +def get_valid_flags_by_version(cursor): + """ + Some role attributes were introduced after certain versions. We want to + compile a list of valid flags against the current Postgres version. + """ + current_version = cursor.connection.server_version + + return [ + flag + for flag, version_introduced in FLAGS_BY_VERSION.items() + if current_version >= version_introduced + ] + + +def get_comment(cursor, user): + """Get user's comment.""" + query = ("SELECT pg_catalog.shobj_description(r.oid, 'pg_authid') " + "FROM pg_catalog.pg_roles r " + "WHERE r.rolname = %(user)s") + cursor.execute(query, {'user': user}) + return cursor.fetchone()[0] + + +def add_comment(cursor, user, comment): + """Add comment on user.""" + if comment != get_comment(cursor, user): + query = 'COMMENT ON ROLE "%s" IS ' % user + cursor.execute(query + '%(comment)s', {'comment': comment}) + executed_queries.append(cursor.mogrify(query + '%(comment)s', {'comment': comment})) + return True + else: + return False + + +# =========================================== +# Module execution. 
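+#
+# For reference (illustrative only, nothing here is executed): with the documented
+# example priv string "CONNECT/products:ALL" and db "acme", parse_privs() above
+# yields {'database': {'acme': {'CONNECT'}}, 'table': {'products': {'SELECT',
+# 'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'REFERENCES', 'TRIGGER'}}} once
+# normalize_privileges() has expanded ALL.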
+# + +def main(): + argument_spec = postgres_common_argument_spec() + argument_spec.update( + user=dict(type='str', required=True, aliases=['name']), + password=dict(type='str', default=None, no_log=True), + state=dict(type='str', default='present', choices=['absent', 'present']), + priv=dict(type='str', default=None), + db=dict(type='str', default='', aliases=['login_db']), + fail_on_user=dict(type='bool', default='yes', aliases=['fail_on_role']), + role_attr_flags=dict(type='str', default=''), + encrypted=dict(type='bool', default='yes'), + no_password_changes=dict(type='bool', default='no'), + expires=dict(type='str', default=None), + conn_limit=dict(type='int', default=None), + session_role=dict(type='str'), + groups=dict(type='list', elements='str'), + comment=dict(type='str', default=None), + ) + module = AnsibleModule( + argument_spec=argument_spec, + supports_check_mode=True + ) + + user = module.params["user"] + password = module.params["password"] + state = module.params["state"] + fail_on_user = module.params["fail_on_user"] + if module.params['db'] == '' and module.params["priv"] is not None: + module.fail_json(msg="privileges require a database to be specified") + privs = parse_privs(module.params["priv"], module.params["db"]) + no_password_changes = module.params["no_password_changes"] + if module.params["encrypted"]: + encrypted = "ENCRYPTED" + else: + encrypted = "UNENCRYPTED" + expires = module.params["expires"] + conn_limit = module.params["conn_limit"] + role_attr_flags = module.params["role_attr_flags"] + groups = module.params["groups"] + if groups: + groups = [e.strip() for e in groups] + comment = module.params["comment"] + + conn_params = get_conn_params(module, module.params, warn_db_default=False) + db_connection = connect_to_db(module, conn_params) + cursor = db_connection.cursor(cursor_factory=DictCursor) + + try: + role_attr_flags = parse_role_attrs(cursor, role_attr_flags) + except InvalidFlagsError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + kw = dict(user=user) + changed = False + user_removed = False + + if state == "present": + if user_exists(cursor, user): + try: + changed = user_alter(db_connection, module, user, password, + role_attr_flags, encrypted, expires, no_password_changes, conn_limit) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + else: + try: + changed = user_add(cursor, user, password, + role_attr_flags, encrypted, expires, conn_limit) + except psycopg2.ProgrammingError as e: + module.fail_json(msg="Unable to add user with given requirement " + "due to : %s" % to_native(e), + exception=traceback.format_exc()) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + try: + changed = grant_privileges(cursor, user, privs) or changed + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + + if groups: + target_roles = [] + target_roles.append(user) + pg_membership = PgMembership(module, cursor, groups, target_roles) + changed = pg_membership.grant() or changed + executed_queries.extend(pg_membership.executed_queries) + + if comment is not None: + try: + changed = add_comment(cursor, user, comment) or changed + except Exception as e: + module.fail_json(msg='Unable to add comment on role: %s' % to_native(e), + exception=traceback.format_exc()) + + else: + if user_exists(cursor, user): + if module.check_mode: + changed = True + kw['user_removed'] = True + else: + try: + 
changed = revoke_privileges(cursor, user, privs) + user_removed = user_delete(cursor, user) + except SQLParseError as e: + module.fail_json(msg=to_native(e), exception=traceback.format_exc()) + changed = changed or user_removed + if fail_on_user and not user_removed: + msg = "Unable to remove user" + module.fail_json(msg=msg) + kw['user_removed'] = user_removed + + if changed: + if module.check_mode: + db_connection.rollback() + else: + db_connection.commit() + + kw['changed'] = changed + kw['queries'] = executed_queries + module.exit_json(**kw) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/rabbitmq_plugin.py b/test/support/integration/plugins/modules/rabbitmq_plugin.py new file mode 100644 index 00000000..301bbfe2 --- /dev/null +++ b/test/support/integration/plugins/modules/rabbitmq_plugin.py @@ -0,0 +1,180 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2013, Chatham Financial <oss@chathamfinancial.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = { + 'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community' +} + + +DOCUMENTATION = ''' +--- +module: rabbitmq_plugin +short_description: Manage RabbitMQ plugins +description: + - This module can be used to enable or disable RabbitMQ plugins. +version_added: "1.1" +author: + - Chris Hoffman (@chrishoffman) +options: + names: + description: + - Comma-separated list of plugin names. Also, accepts plugin name. + required: true + aliases: [name] + new_only: + description: + - Only enable missing plugins. + - Does not disable plugins that are not in the names list. + type: bool + default: "no" + state: + description: + - Specify if plugins are to be enabled or disabled. + default: enabled + choices: [enabled, disabled] + prefix: + description: + - Specify a custom install prefix to a Rabbit. + version_added: "1.3" +''' + +EXAMPLES = ''' +- name: Enables the rabbitmq_management plugin + rabbitmq_plugin: + names: rabbitmq_management + state: enabled + +- name: Enable multiple rabbitmq plugins + rabbitmq_plugin: + names: rabbitmq_management,rabbitmq_management_visualiser + state: enabled + +- name: Disable plugin + rabbitmq_plugin: + names: rabbitmq_management + state: disabled + +- name: Enable every plugin in list with existing plugins + rabbitmq_plugin: + names: rabbitmq_management,rabbitmq_management_visualiser,rabbitmq_shovel,rabbitmq_shovel_management + state: enabled + new_only: 'yes' +''' + +RETURN = ''' +enabled: + description: list of plugins enabled during task run + returned: always + type: list + sample: ["rabbitmq_management"] +disabled: + description: list of plugins disabled during task run + returned: always + type: list + sample: ["rabbitmq_management"] +''' + +import os +from ansible.module_utils.basic import AnsibleModule + + +class RabbitMqPlugins(object): + + def __init__(self, module): + self.module = module + bin_path = '' + if module.params['prefix']: + if os.path.isdir(os.path.join(module.params['prefix'], 'bin')): + bin_path = os.path.join(module.params['prefix'], 'bin') + elif os.path.isdir(os.path.join(module.params['prefix'], 'sbin')): + bin_path = os.path.join(module.params['prefix'], 'sbin') + else: + # No such path exists. 
+ module.fail_json(msg="No binary folder in prefix %s" % module.params['prefix']) + + self._rabbitmq_plugins = os.path.join(bin_path, "rabbitmq-plugins") + else: + self._rabbitmq_plugins = module.get_bin_path('rabbitmq-plugins', True) + + def _exec(self, args, run_in_check_mode=False): + if not self.module.check_mode or (self.module.check_mode and run_in_check_mode): + cmd = [self._rabbitmq_plugins] + rc, out, err = self.module.run_command(cmd + args, check_rc=True) + return out.splitlines() + return list() + + def get_all(self): + list_output = self._exec(['list', '-E', '-m'], True) + plugins = [] + for plugin in list_output: + if not plugin: + break + plugins.append(plugin) + + return plugins + + def enable(self, name): + self._exec(['enable', name]) + + def disable(self, name): + self._exec(['disable', name]) + + +def main(): + arg_spec = dict( + names=dict(required=True, aliases=['name']), + new_only=dict(default='no', type='bool'), + state=dict(default='enabled', choices=['enabled', 'disabled']), + prefix=dict(required=False, default=None) + ) + module = AnsibleModule( + argument_spec=arg_spec, + supports_check_mode=True + ) + + result = dict() + names = module.params['names'].split(',') + new_only = module.params['new_only'] + state = module.params['state'] + + rabbitmq_plugins = RabbitMqPlugins(module) + enabled_plugins = rabbitmq_plugins.get_all() + + enabled = [] + disabled = [] + if state == 'enabled': + if not new_only: + for plugin in enabled_plugins: + if " " in plugin: + continue + if plugin not in names: + rabbitmq_plugins.disable(plugin) + disabled.append(plugin) + + for name in names: + if name not in enabled_plugins: + rabbitmq_plugins.enable(name) + enabled.append(name) + else: + for plugin in enabled_plugins: + if plugin in names: + rabbitmq_plugins.disable(plugin) + disabled.append(plugin) + + result['changed'] = len(enabled) > 0 or len(disabled) > 0 + result['enabled'] = enabled + result['disabled'] = disabled + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/rabbitmq_queue.py b/test/support/integration/plugins/modules/rabbitmq_queue.py new file mode 100644 index 00000000..567ec813 --- /dev/null +++ b/test/support/integration/plugins/modules/rabbitmq_queue.py @@ -0,0 +1,257 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2015, Manuel Sousa <manuel.sousa@gmail.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: rabbitmq_queue +author: Manuel Sousa (@manuel-sousa) +version_added: "2.0" + +short_description: Manage rabbitMQ queues +description: + - This module uses rabbitMQ Rest API to create/delete queues +requirements: [ "requests >= 1.0.0" ] +options: + name: + description: + - Name of the queue + required: true + state: + description: + - Whether the queue should be present or absent + choices: [ "present", "absent" ] + default: present + durable: + description: + - whether queue is durable or not + type: bool + default: 'yes' + auto_delete: + description: + - if the queue should delete itself after all queues/queues unbound from it + type: bool + default: 'no' + message_ttl: + description: + - How long a message can live in queue before it is discarded (milliseconds) + default: forever + 
auto_expires: + description: + - How long a queue can be unused before it is automatically deleted (milliseconds) + default: forever + max_length: + description: + - How many messages can the queue contain before it starts rejecting + default: no limit + dead_letter_exchange: + description: + - Optional name of an exchange to which messages will be republished if they + - are rejected or expire + dead_letter_routing_key: + description: + - Optional replacement routing key to use when a message is dead-lettered. + - Original routing key will be used if unset + max_priority: + description: + - Maximum number of priority levels for the queue to support. + - If not set, the queue will not support message priorities. + - Larger numbers indicate higher priority. + version_added: "2.4" + arguments: + description: + - extra arguments for queue. If defined this argument is a key/value dictionary + default: {} +extends_documentation_fragment: + - rabbitmq +''' + +EXAMPLES = ''' +# Create a queue +- rabbitmq_queue: + name: myQueue + +# Create a queue on remote host +- rabbitmq_queue: + name: myRemoteQueue + login_user: user + login_password: secret + login_host: remote.example.org +''' + +import json +import traceback + +REQUESTS_IMP_ERR = None +try: + import requests + HAS_REQUESTS = True +except ImportError: + REQUESTS_IMP_ERR = traceback.format_exc() + HAS_REQUESTS = False + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils.six.moves.urllib import parse as urllib_parse +from ansible.module_utils.rabbitmq import rabbitmq_argument_spec + + +def main(): + + argument_spec = rabbitmq_argument_spec() + argument_spec.update( + dict( + state=dict(default='present', choices=['present', 'absent'], type='str'), + name=dict(required=True, type='str'), + durable=dict(default=True, type='bool'), + auto_delete=dict(default=False, type='bool'), + message_ttl=dict(default=None, type='int'), + auto_expires=dict(default=None, type='int'), + max_length=dict(default=None, type='int'), + dead_letter_exchange=dict(default=None, type='str'), + dead_letter_routing_key=dict(default=None, type='str'), + arguments=dict(default=dict(), type='dict'), + max_priority=dict(default=None, type='int') + ) + ) + module = AnsibleModule(argument_spec=argument_spec, supports_check_mode=True) + + url = "%s://%s:%s/api/queues/%s/%s" % ( + module.params['login_protocol'], + module.params['login_host'], + module.params['login_port'], + urllib_parse.quote(module.params['vhost'], ''), + module.params['name'] + ) + + if not HAS_REQUESTS: + module.fail_json(msg=missing_required_lib("requests"), exception=REQUESTS_IMP_ERR) + + result = dict(changed=False, name=module.params['name']) + + # Check if queue already exists + r = requests.get(url, auth=(module.params['login_user'], module.params['login_password']), + verify=module.params['ca_cert'], cert=(module.params['client_cert'], module.params['client_key'])) + + if r.status_code == 200: + queue_exists = True + response = r.json() + elif r.status_code == 404: + queue_exists = False + response = r.text + else: + module.fail_json( + msg="Invalid response from RESTAPI when trying to check if queue exists", + details=r.text + ) + + if module.params['state'] == 'present': + change_required = not queue_exists + else: + change_required = queue_exists + + # Check if attributes change on existing queue + if not change_required and r.status_code == 200 and module.params['state'] == 'present': + if not ( + response['durable'] == module.params['durable'] 
and + response['auto_delete'] == module.params['auto_delete'] and + ( + ('x-message-ttl' in response['arguments'] and response['arguments']['x-message-ttl'] == module.params['message_ttl']) or + ('x-message-ttl' not in response['arguments'] and module.params['message_ttl'] is None) + ) and + ( + ('x-expires' in response['arguments'] and response['arguments']['x-expires'] == module.params['auto_expires']) or + ('x-expires' not in response['arguments'] and module.params['auto_expires'] is None) + ) and + ( + ('x-max-length' in response['arguments'] and response['arguments']['x-max-length'] == module.params['max_length']) or + ('x-max-length' not in response['arguments'] and module.params['max_length'] is None) + ) and + ( + ('x-dead-letter-exchange' in response['arguments'] and + response['arguments']['x-dead-letter-exchange'] == module.params['dead_letter_exchange']) or + ('x-dead-letter-exchange' not in response['arguments'] and module.params['dead_letter_exchange'] is None) + ) and + ( + ('x-dead-letter-routing-key' in response['arguments'] and + response['arguments']['x-dead-letter-routing-key'] == module.params['dead_letter_routing_key']) or + ('x-dead-letter-routing-key' not in response['arguments'] and module.params['dead_letter_routing_key'] is None) + ) and + ( + ('x-max-priority' in response['arguments'] and + response['arguments']['x-max-priority'] == module.params['max_priority']) or + ('x-max-priority' not in response['arguments'] and module.params['max_priority'] is None) + ) + ): + module.fail_json( + msg="RabbitMQ RESTAPI doesn't support attribute changes for existing queues", + ) + + # Copy parameters to arguments as used by RabbitMQ + for k, v in { + 'message_ttl': 'x-message-ttl', + 'auto_expires': 'x-expires', + 'max_length': 'x-max-length', + 'dead_letter_exchange': 'x-dead-letter-exchange', + 'dead_letter_routing_key': 'x-dead-letter-routing-key', + 'max_priority': 'x-max-priority' + }.items(): + if module.params[k] is not None: + module.params['arguments'][v] = module.params[k] + + # Exit if check_mode + if module.check_mode: + result['changed'] = change_required + result['details'] = response + result['arguments'] = module.params['arguments'] + module.exit_json(**result) + + # Do changes + if change_required: + if module.params['state'] == 'present': + r = requests.put( + url, + auth=(module.params['login_user'], module.params['login_password']), + headers={"content-type": "application/json"}, + data=json.dumps({ + "durable": module.params['durable'], + "auto_delete": module.params['auto_delete'], + "arguments": module.params['arguments'] + }), + verify=module.params['ca_cert'], + cert=(module.params['client_cert'], module.params['client_key']) + ) + elif module.params['state'] == 'absent': + r = requests.delete(url, auth=(module.params['login_user'], module.params['login_password']), + verify=module.params['ca_cert'], cert=(module.params['client_cert'], module.params['client_key'])) + + # RabbitMQ 3.6.7 changed this response code from 204 to 201 + if r.status_code == 204 or r.status_code == 201: + result['changed'] = True + module.exit_json(**result) + else: + module.fail_json( + msg="Error creating queue", + status=r.status_code, + details=r.text + ) + + else: + module.exit_json( + changed=False, + name=module.params['name'] + ) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/s3_bucket.py b/test/support/integration/plugins/modules/s3_bucket.py new file mode 100644 index 00000000..f35cf53b --- /dev/null +++ 
b/test/support/integration/plugins/modules/s3_bucket.py @@ -0,0 +1,740 @@ +#!/usr/bin/python +# +# This is a free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This Ansible library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this library. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['stableinterface'], + 'supported_by': 'core'} + + +DOCUMENTATION = ''' +--- +module: s3_bucket +short_description: Manage S3 buckets in AWS, DigitalOcean, Ceph, Walrus, FakeS3 and StorageGRID +description: + - Manage S3 buckets in AWS, DigitalOcean, Ceph, Walrus, FakeS3 and StorageGRID +version_added: "2.0" +requirements: [ boto3 ] +author: "Rob White (@wimnat)" +options: + force: + description: + - When trying to delete a bucket, delete all keys (including versions and delete markers) + in the bucket first (an s3 bucket must be empty for a successful deletion) + type: bool + default: 'no' + name: + description: + - Name of the s3 bucket + required: true + type: str + policy: + description: + - The JSON policy as a string. + type: json + s3_url: + description: + - S3 URL endpoint for usage with DigitalOcean, Ceph, Eucalyptus and fakes3 etc. + - Assumes AWS if not specified. + - For Walrus, use FQDN of the endpoint without scheme nor path. + aliases: [ S3_URL ] + type: str + ceph: + description: + - Enable API compatibility with Ceph. It takes into account the S3 API subset working + with Ceph in order to provide the same module behaviour where possible. + type: bool + version_added: "2.2" + requester_pays: + description: + - With Requester Pays buckets, the requester instead of the bucket owner pays the cost + of the request and the data download from the bucket. + type: bool + default: False + state: + description: + - Create or remove the s3 bucket + required: false + default: present + choices: [ 'present', 'absent' ] + type: str + tags: + description: + - tags dict to apply to bucket + type: dict + purge_tags: + description: + - whether to remove tags that aren't present in the C(tags) parameter + type: bool + default: True + version_added: "2.9" + versioning: + description: + - Whether versioning is enabled or disabled (note that once versioning is enabled, it can only be suspended) + type: bool + encryption: + description: + - Describes the default server-side encryption to apply to new objects in the bucket. + In order to remove the server-side encryption, the encryption needs to be set to 'none' explicitly. + choices: [ 'none', 'AES256', 'aws:kms' ] + version_added: "2.9" + type: str + encryption_key_id: + description: KMS master key ID to use for the default encryption. This parameter is allowed if encryption is aws:kms. If + not specified then it will default to the AWS provided KMS key. 
+ version_added: "2.9" + type: str +extends_documentation_fragment: + - aws + - ec2 +notes: + - If C(requestPayment), C(policy), C(tagging) or C(versioning) + operations/API aren't implemented by the endpoint, module doesn't fail + if each parameter satisfies the following condition. + I(requester_pays) is C(False), I(policy), I(tags), and I(versioning) are C(None). +''' + +EXAMPLES = ''' +# Note: These examples do not set authentication details, see the AWS Guide for details. + +# Create a simple s3 bucket +- s3_bucket: + name: mys3bucket + state: present + +# Create a simple s3 bucket on Ceph Rados Gateway +- s3_bucket: + name: mys3bucket + s3_url: http://your-ceph-rados-gateway-server.xxx + ceph: true + +# Remove an s3 bucket and any keys it contains +- s3_bucket: + name: mys3bucket + state: absent + force: yes + +# Create a bucket, add a policy from a file, enable requester pays, enable versioning and tag +- s3_bucket: + name: mys3bucket + policy: "{{ lookup('file','policy.json') }}" + requester_pays: yes + versioning: yes + tags: + example: tag1 + another: tag2 + +# Create a simple DigitalOcean Spaces bucket using their provided regional endpoint +- s3_bucket: + name: mydobucket + s3_url: 'https://nyc3.digitaloceanspaces.com' + +# Create a bucket with AES256 encryption +- s3_bucket: + name: mys3bucket + state: present + encryption: "AES256" + +# Create a bucket with aws:kms encryption, KMS key +- s3_bucket: + name: mys3bucket + state: present + encryption: "aws:kms" + encryption_key_id: "arn:aws:kms:us-east-1:1234/5678example" + +# Create a bucket with aws:kms encryption, default key +- s3_bucket: + name: mys3bucket + state: present + encryption: "aws:kms" +''' + +import json +import os +import time + +from ansible.module_utils.six.moves.urllib.parse import urlparse +from ansible.module_utils.six import string_types +from ansible.module_utils.basic import to_text +from ansible.module_utils.aws.core import AnsibleAWSModule, is_boto3_error_code +from ansible.module_utils.ec2 import compare_policies, ec2_argument_spec, boto3_tag_list_to_ansible_dict, ansible_dict_to_boto3_tag_list +from ansible.module_utils.ec2 import get_aws_connection_info, boto3_conn, AWSRetry + +try: + from botocore.exceptions import BotoCoreError, ClientError, EndpointConnectionError, WaiterError +except ImportError: + pass # handled by AnsibleAWSModule + + +def create_or_update_bucket(s3_client, module, location): + + policy = module.params.get("policy") + name = module.params.get("name") + requester_pays = module.params.get("requester_pays") + tags = module.params.get("tags") + purge_tags = module.params.get("purge_tags") + versioning = module.params.get("versioning") + encryption = module.params.get("encryption") + encryption_key_id = module.params.get("encryption_key_id") + changed = False + result = {} + + try: + bucket_is_present = bucket_exists(s3_client, name) + except EndpointConnectionError as e: + module.fail_json_aws(e, msg="Invalid endpoint provided: %s" % to_text(e)) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to check bucket presence") + + if not bucket_is_present: + try: + bucket_changed = create_bucket(s3_client, name, location) + s3_client.get_waiter('bucket_exists').wait(Bucket=name) + changed = changed or bucket_changed + except WaiterError as e: + module.fail_json_aws(e, msg='An error occurred waiting for the bucket to become available') + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed while creating bucket") + + # Versioning 
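+ # The boolean 'versioning' option is reconciled against the bucket's current
+ # status: True maps to 'Enabled', False to 'Suspended' (S3 never returns to
+ # 'Disabled' once versioning has been enabled), and None leaves the bucket as is.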
+ try: + versioning_status = get_bucket_versioning(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket versioning") + except ClientError as exp: + if exp.response['Error']['Code'] != 'NotImplemented' or versioning is not None: + module.fail_json_aws(exp, msg="Failed to get bucket versioning") + else: + if versioning is not None: + required_versioning = None + if versioning and versioning_status.get('Status') != "Enabled": + required_versioning = 'Enabled' + elif not versioning and versioning_status.get('Status') == "Enabled": + required_versioning = 'Suspended' + + if required_versioning: + try: + put_bucket_versioning(s3_client, name, required_versioning) + changed = True + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to update bucket versioning") + + versioning_status = wait_versioning_is_applied(module, s3_client, name, required_versioning) + + # This output format is there to ensure compatibility with previous versions of the module + result['versioning'] = { + 'Versioning': versioning_status.get('Status', 'Disabled'), + 'MfaDelete': versioning_status.get('MFADelete', 'Disabled'), + } + + # Requester pays + try: + requester_pays_status = get_bucket_request_payment(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket request payment") + except ClientError as exp: + if exp.response['Error']['Code'] not in ('NotImplemented', 'XNotImplemented') or requester_pays: + module.fail_json_aws(exp, msg="Failed to get bucket request payment") + else: + if requester_pays: + payer = 'Requester' if requester_pays else 'BucketOwner' + if requester_pays_status != payer: + put_bucket_request_payment(s3_client, name, payer) + requester_pays_status = wait_payer_is_applied(module, s3_client, name, payer, should_fail=False) + if requester_pays_status is None: + # We have seen that it happens quite a lot of times that the put request was not taken into + # account, so we retry one more time + put_bucket_request_payment(s3_client, name, payer) + requester_pays_status = wait_payer_is_applied(module, s3_client, name, payer, should_fail=True) + changed = True + + result['requester_pays'] = requester_pays + + # Policy + try: + current_policy = get_bucket_policy(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket policy") + except ClientError as exp: + if exp.response['Error']['Code'] != 'NotImplemented' or policy is not None: + module.fail_json_aws(exp, msg="Failed to get bucket policy") + else: + if policy is not None: + if isinstance(policy, string_types): + policy = json.loads(policy) + + if not policy and current_policy: + try: + delete_bucket_policy(s3_client, name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket policy") + current_policy = wait_policy_is_applied(module, s3_client, name, policy) + changed = True + elif compare_policies(current_policy, policy): + try: + put_bucket_policy(s3_client, name, policy) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to update bucket policy") + current_policy = wait_policy_is_applied(module, s3_client, name, policy, should_fail=False) + if current_policy is None: + # As for request payement, it happens quite a lot of times that the put request was not taken into + # account, so we retry one more time + put_bucket_policy(s3_client, name, policy) + current_policy = wait_policy_is_applied(module, s3_client, 
name, policy, should_fail=True) + changed = True + + result['policy'] = current_policy + + # Tags + try: + current_tags_dict = get_current_bucket_tags_dict(s3_client, name) + except BotoCoreError as exp: + module.fail_json_aws(exp, msg="Failed to get bucket tags") + except ClientError as exp: + if exp.response['Error']['Code'] not in ('NotImplemented', 'XNotImplemented') or tags is not None: + module.fail_json_aws(exp, msg="Failed to get bucket tags") + else: + if tags is not None: + # Tags are always returned as text + tags = dict((to_text(k), to_text(v)) for k, v in tags.items()) + if not purge_tags: + # Ensure existing tags that aren't updated by desired tags remain + current_copy = current_tags_dict.copy() + current_copy.update(tags) + tags = current_copy + if current_tags_dict != tags: + if tags: + try: + put_bucket_tagging(s3_client, name, tags) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to update bucket tags") + else: + if purge_tags: + try: + delete_bucket_tagging(s3_client, name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket tags") + current_tags_dict = wait_tags_are_applied(module, s3_client, name, tags) + changed = True + + result['tags'] = current_tags_dict + + # Encryption + if hasattr(s3_client, "get_bucket_encryption"): + try: + current_encryption = get_bucket_encryption(s3_client, name) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to get bucket encryption") + elif encryption is not None: + module.fail_json(msg="Using bucket encryption requires botocore version >= 1.7.41") + + if encryption is not None: + current_encryption_algorithm = current_encryption.get('SSEAlgorithm') if current_encryption else None + current_encryption_key = current_encryption.get('KMSMasterKeyID') if current_encryption else None + if encryption == 'none' and current_encryption_algorithm is not None: + try: + delete_bucket_encryption(s3_client, name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket encryption") + current_encryption = wait_encryption_is_applied(module, s3_client, name, None) + changed = True + elif encryption != 'none' and (encryption != current_encryption_algorithm) or (encryption == 'aws:kms' and current_encryption_key != encryption_key_id): + expected_encryption = {'SSEAlgorithm': encryption} + if encryption == 'aws:kms' and encryption_key_id is not None: + expected_encryption.update({'KMSMasterKeyID': encryption_key_id}) + try: + put_bucket_encryption(s3_client, name, expected_encryption) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to set bucket encryption") + current_encryption = wait_encryption_is_applied(module, s3_client, name, expected_encryption) + changed = True + + result['encryption'] = current_encryption + + module.exit_json(changed=changed, name=name, **result) + + +def bucket_exists(s3_client, bucket_name): + # head_bucket appeared to be really inconsistent, so we use list_buckets instead, + # and loop over all the buckets, even if we know it's less performant :( + all_buckets = s3_client.list_buckets(Bucket=bucket_name)['Buckets'] + return any(bucket['Name'] == bucket_name for bucket in all_buckets) + + +@AWSRetry.exponential_backoff(max_delay=120) +def create_bucket(s3_client, bucket_name, location): + try: + configuration = {} + if location not in ('us-east-1', None): + configuration['LocationConstraint'] = location + if len(configuration) > 0: + 
s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration=configuration) + else: + s3_client.create_bucket(Bucket=bucket_name) + return True + except ClientError as e: + error_code = e.response['Error']['Code'] + if error_code == 'BucketAlreadyOwnedByYou': + # We should never get there since we check the bucket presence before calling the create_or_update_bucket + # method. However, the AWS Api sometimes fails to report bucket presence, so we catch this exception + return False + else: + raise e + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_tagging(s3_client, bucket_name, tags): + s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={'TagSet': ansible_dict_to_boto3_tag_list(tags)}) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_policy(s3_client, bucket_name, policy): + s3_client.put_bucket_policy(Bucket=bucket_name, Policy=json.dumps(policy)) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def delete_bucket_policy(s3_client, bucket_name): + s3_client.delete_bucket_policy(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_policy(s3_client, bucket_name): + try: + current_policy = json.loads(s3_client.get_bucket_policy(Bucket=bucket_name).get('Policy')) + except ClientError as e: + if e.response['Error']['Code'] == 'NoSuchBucketPolicy': + current_policy = None + else: + raise e + return current_policy + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_request_payment(s3_client, bucket_name, payer): + s3_client.put_bucket_request_payment(Bucket=bucket_name, RequestPaymentConfiguration={'Payer': payer}) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_request_payment(s3_client, bucket_name): + return s3_client.get_bucket_request_payment(Bucket=bucket_name).get('Payer') + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_versioning(s3_client, bucket_name): + return s3_client.get_bucket_versioning(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_versioning(s3_client, bucket_name, required_versioning): + s3_client.put_bucket_versioning(Bucket=bucket_name, VersioningConfiguration={'Status': required_versioning}) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def get_bucket_encryption(s3_client, bucket_name): + try: + result = s3_client.get_bucket_encryption(Bucket=bucket_name) + return result.get('ServerSideEncryptionConfiguration', {}).get('Rules', [])[0].get('ApplyServerSideEncryptionByDefault') + except ClientError as e: + if e.response['Error']['Code'] == 'ServerSideEncryptionConfigurationNotFoundError': + return None + else: + raise e + except (IndexError, KeyError): + return None + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def put_bucket_encryption(s3_client, bucket_name, encryption): + server_side_encryption_configuration = {'Rules': [{'ApplyServerSideEncryptionByDefault': encryption}]} + s3_client.put_bucket_encryption(Bucket=bucket_name, ServerSideEncryptionConfiguration=server_side_encryption_configuration) + + +@AWSRetry.exponential_backoff(max_delay=120, 
catch_extra_error_codes=['NoSuchBucket']) +def delete_bucket_tagging(s3_client, bucket_name): + s3_client.delete_bucket_tagging(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120, catch_extra_error_codes=['NoSuchBucket']) +def delete_bucket_encryption(s3_client, bucket_name): + s3_client.delete_bucket_encryption(Bucket=bucket_name) + + +@AWSRetry.exponential_backoff(max_delay=120) +def delete_bucket(s3_client, bucket_name): + try: + s3_client.delete_bucket(Bucket=bucket_name) + except ClientError as e: + if e.response['Error']['Code'] == 'NoSuchBucket': + # This means bucket should have been in a deleting state when we checked it existence + # We just ignore the error + pass + else: + raise e + + +def wait_policy_is_applied(module, s3_client, bucket_name, expected_policy, should_fail=True): + for dummy in range(0, 12): + try: + current_policy = get_bucket_policy(s3_client, bucket_name) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to get bucket policy") + + if compare_policies(current_policy, expected_policy): + time.sleep(5) + else: + return current_policy + if should_fail: + module.fail_json(msg="Bucket policy failed to apply in the expected time") + else: + return None + + +def wait_payer_is_applied(module, s3_client, bucket_name, expected_payer, should_fail=True): + for dummy in range(0, 12): + try: + requester_pays_status = get_bucket_request_payment(s3_client, bucket_name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get bucket request payment") + if requester_pays_status != expected_payer: + time.sleep(5) + else: + return requester_pays_status + if should_fail: + module.fail_json(msg="Bucket request payment failed to apply in the expected time") + else: + return None + + +def wait_encryption_is_applied(module, s3_client, bucket_name, expected_encryption): + for dummy in range(0, 12): + try: + encryption = get_bucket_encryption(s3_client, bucket_name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get updated encryption for bucket") + if encryption != expected_encryption: + time.sleep(5) + else: + return encryption + module.fail_json(msg="Bucket encryption failed to apply in the expected time") + + +def wait_versioning_is_applied(module, s3_client, bucket_name, required_versioning): + for dummy in range(0, 24): + try: + versioning_status = get_bucket_versioning(s3_client, bucket_name) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get updated versioning for bucket") + if versioning_status.get('Status') != required_versioning: + time.sleep(8) + else: + return versioning_status + module.fail_json(msg="Bucket versioning failed to apply in the expected time") + + +def wait_tags_are_applied(module, s3_client, bucket_name, expected_tags_dict): + for dummy in range(0, 12): + try: + current_tags_dict = get_current_bucket_tags_dict(s3_client, bucket_name) + except (ClientError, BotoCoreError) as e: + module.fail_json_aws(e, msg="Failed to get bucket policy") + if current_tags_dict != expected_tags_dict: + time.sleep(5) + else: + return current_tags_dict + module.fail_json(msg="Bucket tags failed to apply in the expected time") + + +def get_current_bucket_tags_dict(s3_client, bucket_name): + try: + current_tags = s3_client.get_bucket_tagging(Bucket=bucket_name).get('TagSet') + except ClientError as e: + if e.response['Error']['Code'] == 'NoSuchTagSet': + return {} + raise e + + return 
boto3_tag_list_to_ansible_dict(current_tags) + + +def paginated_list(s3_client, **pagination_params): + pg = s3_client.get_paginator('list_objects_v2') + for page in pg.paginate(**pagination_params): + yield [data['Key'] for data in page.get('Contents', [])] + + +def paginated_versions_list(s3_client, **pagination_params): + try: + pg = s3_client.get_paginator('list_object_versions') + for page in pg.paginate(**pagination_params): + # We have to merge the Versions and DeleteMarker lists here, as DeleteMarkers can still prevent a bucket deletion + yield [(data['Key'], data['VersionId']) for data in (page.get('Versions', []) + page.get('DeleteMarkers', []))] + except is_boto3_error_code('NoSuchBucket'): + yield [] + + +def destroy_bucket(s3_client, module): + + force = module.params.get("force") + name = module.params.get("name") + try: + bucket_is_present = bucket_exists(s3_client, name) + except EndpointConnectionError as e: + module.fail_json_aws(e, msg="Invalid endpoint provided: %s" % to_text(e)) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to check bucket presence") + + if not bucket_is_present: + module.exit_json(changed=False) + + if force: + # if there are contents then we need to delete them (including versions) before we can delete the bucket + try: + for key_version_pairs in paginated_versions_list(s3_client, Bucket=name): + formatted_keys = [{'Key': key, 'VersionId': version} for key, version in key_version_pairs] + for fk in formatted_keys: + # remove VersionId from cases where they are `None` so that + # unversioned objects are deleted using `DeleteObject` + # rather than `DeleteObjectVersion`, improving backwards + # compatibility with older IAM policies. + if not fk.get('VersionId'): + fk.pop('VersionId') + + if formatted_keys: + resp = s3_client.delete_objects(Bucket=name, Delete={'Objects': formatted_keys}) + if resp.get('Errors'): + module.fail_json( + msg='Could not empty bucket before deleting. 
Could not delete objects: {0}'.format( + ', '.join([k['Key'] for k in resp['Errors']]) + ), + errors=resp['Errors'], response=resp + ) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed while deleting bucket") + + try: + delete_bucket(s3_client, name) + s3_client.get_waiter('bucket_not_exists').wait(Bucket=name, WaiterConfig=dict(Delay=5, MaxAttempts=60)) + except WaiterError as e: + module.fail_json_aws(e, msg='An error occurred waiting for the bucket to be deleted.') + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to delete bucket") + + module.exit_json(changed=True) + + +def is_fakes3(s3_url): + """ Return True if s3_url has scheme fakes3:// """ + if s3_url is not None: + return urlparse(s3_url).scheme in ('fakes3', 'fakes3s') + else: + return False + + +def get_s3_client(module, aws_connect_kwargs, location, ceph, s3_url): + if s3_url and ceph: # TODO - test this + ceph = urlparse(s3_url) + params = dict(module=module, conn_type='client', resource='s3', use_ssl=ceph.scheme == 'https', region=location, endpoint=s3_url, **aws_connect_kwargs) + elif is_fakes3(s3_url): + fakes3 = urlparse(s3_url) + port = fakes3.port + if fakes3.scheme == 'fakes3s': + protocol = "https" + if port is None: + port = 443 + else: + protocol = "http" + if port is None: + port = 80 + params = dict(module=module, conn_type='client', resource='s3', region=location, + endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)), + use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs) + else: + params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=s3_url, **aws_connect_kwargs) + return boto3_conn(**params) + + +def main(): + + argument_spec = ec2_argument_spec() + argument_spec.update( + dict( + force=dict(default=False, type='bool'), + policy=dict(type='json'), + name=dict(required=True), + requester_pays=dict(default=False, type='bool'), + s3_url=dict(aliases=['S3_URL']), + state=dict(default='present', choices=['present', 'absent']), + tags=dict(type='dict'), + purge_tags=dict(type='bool', default=True), + versioning=dict(type='bool'), + ceph=dict(default=False, type='bool'), + encryption=dict(choices=['none', 'AES256', 'aws:kms']), + encryption_key_id=dict() + ) + ) + + module = AnsibleAWSModule( + argument_spec=argument_spec, + ) + + region, ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True) + + if region in ('us-east-1', '', None): + # default to US Standard region + location = 'us-east-1' + else: + # Boto uses symbolic names for locations but region strings will + # actually work fine for everything except us-east-1 (US Standard) + location = region + + s3_url = module.params.get('s3_url') + ceph = module.params.get('ceph') + + # allow eucarc environment variables to be used if ansible vars aren't set + if not s3_url and 'S3_URL' in os.environ: + s3_url = os.environ['S3_URL'] + + if ceph and not s3_url: + module.fail_json(msg='ceph flavour requires s3_url') + + # Look at s3_url and tweak connection settings + # if connecting to Ceph RGW, Walrus or fakes3 + if s3_url: + for key in ['validate_certs', 'security_token', 'profile_name']: + aws_connect_kwargs.pop(key, None) + s3_client = get_s3_client(module, aws_connect_kwargs, location, ceph, s3_url) + + if s3_client is None: # this should never happen + module.fail_json(msg='Unknown error, failed to create s3 connection, no information from boto.') + + state = module.params.get("state") + encryption = 
module.params.get("encryption") + encryption_key_id = module.params.get("encryption_key_id") + + # Parameter validation + if encryption_key_id is not None and encryption is None: + module.fail_json(msg="You must specify encryption parameter along with encryption_key_id.") + elif encryption_key_id is not None and encryption != 'aws:kms': + module.fail_json(msg="Only 'aws:kms' is a valid option for encryption parameter when you specify encryption_key_id.") + + if state == 'present': + create_or_update_bucket(s3_client, module, location) + elif state == 'absent': + destroy_bucket(s3_client, module) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/sefcontext.py b/test/support/integration/plugins/modules/sefcontext.py new file mode 100644 index 00000000..33e3fd2e --- /dev/null +++ b/test/support/integration/plugins/modules/sefcontext.py @@ -0,0 +1,298 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Dag Wieers (@dagwieers) <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: sefcontext +short_description: Manages SELinux file context mapping definitions +description: +- Manages SELinux file context mapping definitions. +- Similar to the C(semanage fcontext) command. +version_added: '2.2' +options: + target: + description: + - Target path (expression). + type: str + required: yes + aliases: [ path ] + ftype: + description: + - The file type that should have SELinux contexts applied. + - "The following file type options are available:" + - C(a) for all files, + - C(b) for block devices, + - C(c) for character devices, + - C(d) for directories, + - C(f) for regular files, + - C(l) for symbolic links, + - C(p) for named pipes, + - C(s) for socket files. + type: str + choices: [ a, b, c, d, f, l, p, s ] + default: a + setype: + description: + - SELinux type for the specified target. + type: str + required: yes + seuser: + description: + - SELinux user for the specified target. + type: str + selevel: + description: + - SELinux range for the specified target. + type: str + aliases: [ serange ] + state: + description: + - Whether the SELinux file context must be C(absent) or C(present). + type: str + choices: [ absent, present ] + default: present + reload: + description: + - Reload SELinux policy after commit. + - Note that this does not apply SELinux file contexts to existing files. + type: bool + default: yes + ignore_selinux_state: + description: + - Useful for scenarios (chrooted environment) that you can't get the real SELinux state. + type: bool + default: no + version_added: '2.8' +notes: +- The changes are persistent across reboots. +- The M(sefcontext) module does not modify existing files to the new + SELinux context(s), so it is advisable to first create the SELinux + file contexts before creating files, or run C(restorecon) manually + for the existing files that require the new SELinux file contexts. +- Not applying SELinux fcontexts to existing files is a deliberate + decision as it would be unclear what reported changes would entail + to, and there's no guarantee that applying SELinux fcontext does + not pick up other unrelated prior changes. 
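+# Rough CLI analogy, for orientation only: state=present behaves much like
+#   semanage fcontext --add -t <setype> -f <ftype> '<target>'
+# and state=absent like
+#   semanage fcontext --delete '<target>'
+# with a manual C(restorecon) afterwards if existing files should pick up
+# the new context (see the notes above and the EXAMPLES section below).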
+requirements: +- libselinux-python +- policycoreutils-python +author: +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r''' +- name: Allow apache to modify files in /srv/git_repos + sefcontext: + target: '/srv/git_repos(/.*)?' + setype: httpd_git_rw_content_t + state: present + +- name: Apply new SELinux file context to filesystem + command: restorecon -irv /srv/git_repos +''' + +RETURN = r''' +# Default return values +''' + +import traceback + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native + +SELINUX_IMP_ERR = None +try: + import selinux + HAVE_SELINUX = True +except ImportError: + SELINUX_IMP_ERR = traceback.format_exc() + HAVE_SELINUX = False + +SEOBJECT_IMP_ERR = None +try: + import seobject + HAVE_SEOBJECT = True +except ImportError: + SEOBJECT_IMP_ERR = traceback.format_exc() + HAVE_SEOBJECT = False + +# Add missing entries (backward compatible) +if HAVE_SEOBJECT: + seobject.file_types.update( + a=seobject.SEMANAGE_FCONTEXT_ALL, + b=seobject.SEMANAGE_FCONTEXT_BLOCK, + c=seobject.SEMANAGE_FCONTEXT_CHAR, + d=seobject.SEMANAGE_FCONTEXT_DIR, + f=seobject.SEMANAGE_FCONTEXT_REG, + l=seobject.SEMANAGE_FCONTEXT_LINK, + p=seobject.SEMANAGE_FCONTEXT_PIPE, + s=seobject.SEMANAGE_FCONTEXT_SOCK, + ) + +# Make backward compatible +option_to_file_type_str = dict( + a='all files', + b='block device', + c='character device', + d='directory', + f='regular file', + l='symbolic link', + p='named pipe', + s='socket', +) + + +def get_runtime_status(ignore_selinux_state=False): + return True if ignore_selinux_state is True else selinux.is_selinux_enabled() + + +def semanage_fcontext_exists(sefcontext, target, ftype): + ''' Get the SELinux file context mapping definition from policy. Return None if it does not exist. ''' + + # Beware that records comprise of a string representation of the file_type + record = (target, option_to_file_type_str[ftype]) + records = sefcontext.get_all() + try: + return records[record] + except KeyError: + return None + + +def semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser, sestore=''): + ''' Add or modify SELinux file context mapping definition to the policy. 
''' + + changed = False + prepared_diff = '' + + try: + sefcontext = seobject.fcontextRecords(sestore) + sefcontext.set_reload(do_reload) + exists = semanage_fcontext_exists(sefcontext, target, ftype) + if exists: + # Modify existing entry + orig_seuser, orig_serole, orig_setype, orig_serange = exists + + if seuser is None: + seuser = orig_seuser + if serange is None: + serange = orig_serange + + if setype != orig_setype or seuser != orig_seuser or serange != orig_serange: + if not module.check_mode: + sefcontext.modify(target, setype, ftype, serange, seuser) + changed = True + + if module._diff: + prepared_diff += '# Change to semanage file context mappings\n' + prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, orig_seuser, orig_serole, orig_setype, orig_serange) + prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, orig_serole, setype, serange) + else: + # Add missing entry + if seuser is None: + seuser = 'system_u' + if serange is None: + serange = 's0' + + if not module.check_mode: + sefcontext.add(target, setype, ftype, serange, seuser) + changed = True + + if module._diff: + prepared_diff += '# Addition to semanage file context mappings\n' + prepared_diff += '+%s %s %s:%s:%s:%s\n' % (target, ftype, seuser, 'object_r', setype, serange) + + except Exception as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e))) + + if module._diff and prepared_diff: + result['diff'] = dict(prepared=prepared_diff) + + module.exit_json(changed=changed, seuser=seuser, serange=serange, **result) + + +def semanage_fcontext_delete(module, result, target, ftype, do_reload, sestore=''): + ''' Delete SELinux file context mapping definition from the policy. ''' + + changed = False + prepared_diff = '' + + try: + sefcontext = seobject.fcontextRecords(sestore) + sefcontext.set_reload(do_reload) + exists = semanage_fcontext_exists(sefcontext, target, ftype) + if exists: + # Remove existing entry + orig_seuser, orig_serole, orig_setype, orig_serange = exists + + if not module.check_mode: + sefcontext.delete(target, ftype) + changed = True + + if module._diff: + prepared_diff += '# Deletion to semanage file context mappings\n' + prepared_diff += '-%s %s %s:%s:%s:%s\n' % (target, ftype, exists[0], exists[1], exists[2], exists[3]) + + except Exception as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e))) + + if module._diff and prepared_diff: + result['diff'] = dict(prepared=prepared_diff) + + module.exit_json(changed=changed, **result) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + ignore_selinux_state=dict(type='bool', default=False), + target=dict(type='str', required=True, aliases=['path']), + ftype=dict(type='str', default='a', choices=option_to_file_type_str.keys()), + setype=dict(type='str', required=True), + seuser=dict(type='str'), + selevel=dict(type='str', aliases=['serange']), + state=dict(type='str', default='present', choices=['absent', 'present']), + reload=dict(type='bool', default=True), + ), + supports_check_mode=True, + ) + if not HAVE_SELINUX: + module.fail_json(msg=missing_required_lib("libselinux-python"), exception=SELINUX_IMP_ERR) + + if not HAVE_SEOBJECT: + module.fail_json(msg=missing_required_lib("policycoreutils-python"), exception=SEOBJECT_IMP_ERR) + + ignore_selinux_state = module.params['ignore_selinux_state'] + + if not get_runtime_status(ignore_selinux_state): + module.fail_json(msg="SELinux is disabled on this host.") + + target = module.params['target'] + ftype = module.params['ftype'] + 
setype = module.params['setype'] + seuser = module.params['seuser'] + serange = module.params['selevel'] + state = module.params['state'] + do_reload = module.params['reload'] + + result = dict(target=target, ftype=ftype, setype=setype, state=state) + + if state == 'present': + semanage_fcontext_modify(module, result, target, ftype, setype, do_reload, serange, seuser) + elif state == 'absent': + semanage_fcontext_delete(module, result, target, ftype, do_reload) + else: + module.fail_json(msg='Invalid value of argument "state": {0}'.format(state)) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/selogin.py b/test/support/integration/plugins/modules/selogin.py new file mode 100644 index 00000000..6429ef36 --- /dev/null +++ b/test/support/integration/plugins/modules/selogin.py @@ -0,0 +1,260 @@ +#!/usr/bin/python + +# (c) 2017, Petr Lautrbach <plautrba@redhat.com> +# Based on seport.py module (c) 2014, Dan Keder <dan.keder@gmail.com> + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = ''' +--- +module: selogin +short_description: Manages linux user to SELinux user mapping +description: + - Manages linux user to SELinux user mapping +version_added: "2.8" +options: + login: + description: + - a Linux user + required: true + seuser: + description: + - SELinux user name + required: true + selevel: + aliases: [ serange ] + description: + - MLS/MCS Security Range (MLS/MCS Systems only) SELinux Range for SELinux login mapping defaults to the SELinux user record range. + default: s0 + state: + description: + - Desired mapping value. + required: true + default: present + choices: [ 'present', 'absent' ] + reload: + description: + - Reload SELinux policy after commit. 
+ default: yes + ignore_selinux_state: + description: + - Run independent of selinux runtime state + type: bool + default: false +notes: + - The changes are persistent across reboots + - Not tested on any debian based system +requirements: [ 'libselinux', 'policycoreutils' ] +author: +- Dan Keder (@dankeder) +- Petr Lautrbach (@bachradsusi) +- James Cassell (@jamescassell) +''' + +EXAMPLES = ''' +# Modify the default user on the system to the guest_u user +- selogin: + login: __default__ + seuser: guest_u + state: present + +# Assign gijoe user on an MLS machine a range and to the staff_u user +- selogin: + login: gijoe + seuser: staff_u + serange: SystemLow-Secret + state: present + +# Assign all users in the engineering group to the staff_u user +- selogin: + login: '%engineering' + seuser: staff_u + state: present +''' + +RETURN = r''' +# Default return values +''' + + +import traceback + +SELINUX_IMP_ERR = None +try: + import selinux + HAVE_SELINUX = True +except ImportError: + SELINUX_IMP_ERR = traceback.format_exc() + HAVE_SELINUX = False + +SEOBJECT_IMP_ERR = None +try: + import seobject + HAVE_SEOBJECT = True +except ImportError: + SEOBJECT_IMP_ERR = traceback.format_exc() + HAVE_SEOBJECT = False + + +from ansible.module_utils.basic import AnsibleModule, missing_required_lib +from ansible.module_utils._text import to_native + + +def semanage_login_add(module, login, seuser, do_reload, serange='s0', sestore=''): + """ Add linux user to SELinux user mapping + + :type module: AnsibleModule + :param module: Ansible module + + :type login: str + :param login: a Linux User or a Linux group if it begins with % + + :type seuser: str + :param proto: An SELinux user ('__default__', 'unconfined_u', 'staff_u', ...), see 'semanage login -l' + + :type serange: str + :param serange: SELinux MLS/MCS range (defaults to 's0') + + :type do_reload: bool + :param do_reload: Whether to reload SELinux policy after commit + + :type sestore: str + :param sestore: SELinux store + + :rtype: bool + :return: True if the policy was changed, otherwise False + """ + try: + selogin = seobject.loginRecords(sestore) + selogin.set_reload(do_reload) + change = False + all_logins = selogin.get_all() + # module.fail_json(msg="%s: %s %s" % (all_logins, login, sestore)) + # for local_login in all_logins: + if login not in all_logins.keys(): + change = True + if not module.check_mode: + selogin.add(login, seuser, serange) + else: + if all_logins[login][0] != seuser or all_logins[login][1] != serange: + change = True + if not module.check_mode: + selogin.modify(login, seuser, serange) + + except (ValueError, KeyError, OSError, RuntimeError) as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)), exception=traceback.format_exc()) + + return change + + +def semanage_login_del(module, login, seuser, do_reload, sestore=''): + """ Delete linux user to SELinux user mapping + + :type module: AnsibleModule + :param module: Ansible module + + :type login: str + :param login: a Linux User or a Linux group if it begins with % + + :type seuser: str + :param proto: An SELinux user ('__default__', 'unconfined_u', 'staff_u', ...), see 'semanage login -l' + + :type do_reload: bool + :param do_reload: Whether to reload SELinux policy after commit + + :type sestore: str + :param sestore: SELinux store + + :rtype: bool + :return: True if the policy was changed, otherwise False + """ + try: + selogin = seobject.loginRecords(sestore) + selogin.set_reload(do_reload) + change = False + all_logins = selogin.get_all() + # 
module.fail_json(msg="%s: %s %s" % (all_logins, login, sestore)) + if login in all_logins.keys(): + change = True + if not module.check_mode: + selogin.delete(login) + + except (ValueError, KeyError, OSError, RuntimeError) as e: + module.fail_json(msg="%s: %s\n" % (e.__class__.__name__, to_native(e)), exception=traceback.format_exc()) + + return change + + +def get_runtime_status(ignore_selinux_state=False): + return True if ignore_selinux_state is True else selinux.is_selinux_enabled() + + +def main(): + module = AnsibleModule( + argument_spec=dict( + ignore_selinux_state=dict(type='bool', default=False), + login=dict(type='str', required=True), + seuser=dict(type='str'), + selevel=dict(type='str', aliases=['serange'], default='s0'), + state=dict(type='str', default='present', choices=['absent', 'present']), + reload=dict(type='bool', default=True), + ), + required_if=[ + ["state", "present", ["seuser"]] + ], + supports_check_mode=True + ) + if not HAVE_SELINUX: + module.fail_json(msg=missing_required_lib("libselinux"), exception=SELINUX_IMP_ERR) + + if not HAVE_SEOBJECT: + module.fail_json(msg=missing_required_lib("seobject from policycoreutils"), exception=SEOBJECT_IMP_ERR) + + ignore_selinux_state = module.params['ignore_selinux_state'] + + if not get_runtime_status(ignore_selinux_state): + module.fail_json(msg="SELinux is disabled on this host.") + + login = module.params['login'] + seuser = module.params['seuser'] + serange = module.params['selevel'] + state = module.params['state'] + do_reload = module.params['reload'] + + result = { + 'login': login, + 'seuser': seuser, + 'serange': serange, + 'state': state, + } + + if state == 'present': + result['changed'] = semanage_login_add(module, login, seuser, do_reload, serange) + elif state == 'absent': + result['changed'] = semanage_login_del(module, login, seuser, do_reload) + else: + module.fail_json(msg='Invalid value of argument "state": {0}'.format(state)) + + module.exit_json(**result) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/synchronize.py b/test/support/integration/plugins/modules/synchronize.py new file mode 100644 index 00000000..e4c520b7 --- /dev/null +++ b/test/support/integration/plugins/modules/synchronize.py @@ -0,0 +1,618 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2012-2013, Timothy Appnel <tim@appnel.com> +# Copyright: (c) 2017, Ansible Project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'core'} + +DOCUMENTATION = r''' +--- +module: synchronize +version_added: "1.4" +short_description: A wrapper around rsync to make common tasks in your playbooks quick and easy +description: + - C(synchronize) is a wrapper around rsync to make common tasks in your playbooks quick and easy. + - It is run and originates on the local host where Ansible is being run. + - Of course, you could just use the C(command) action to call rsync yourself, but you also have to add a fair number of + boilerplate options and host facts. + - This module is not intended to provide access to the full power of rsync, but does make the most common + invocations easier to implement. You `still` may need to call rsync directly via C(command) or C(shell) depending on your use case. 
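+# Rough sketch of the generated command (see main() below; the exact flags
+# depend on the parameters passed): a push with default options runs roughly
+#   rsync --delay-updates -F --compress --archive
+#     --rsh='ssh -S none -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
+#     --out-format='<<CHANGED>>%i %n%L' <src> <host>:<dest>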
+options: + src: + description: + - Path on the source host that will be synchronized to the destination. + - The path can be absolute or relative. + type: str + required: true + dest: + description: + - Path on the destination host that will be synchronized from the source. + - The path can be absolute or relative. + type: str + required: true + dest_port: + description: + - Port number for ssh on the destination host. + - Prior to Ansible 2.0, the ansible_ssh_port inventory var took precedence over this value. + - This parameter defaults to the value of C(ansible_ssh_port) or C(ansible_port), + the C(remote_port) config setting or the value from ssh client configuration + if none of the former have been set. + type: int + version_added: "1.5" + mode: + description: + - Specify the direction of the synchronization. + - In push mode the localhost or delegate is the source. + - In pull mode the remote host in context is the source. + type: str + choices: [ pull, push ] + default: push + archive: + description: + - Mirrors the rsync archive flag, enables recursive, links, perms, times, owner, group flags and -D. + type: bool + default: yes + checksum: + description: + - Skip based on checksum, rather than mod-time & size; Note that that "archive" option is still enabled by default - the "checksum" option will + not disable it. + type: bool + default: no + version_added: "1.6" + compress: + description: + - Compress file data during the transfer. + - In most cases, leave this enabled unless it causes problems. + type: bool + default: yes + version_added: "1.7" + existing_only: + description: + - Skip creating new files on receiver. + type: bool + default: no + version_added: "1.5" + delete: + description: + - Delete files in C(dest) that don't exist (after transfer, not before) in the C(src) path. + - This option requires C(recursive=yes). + - This option ignores excluded files and behaves like the rsync opt --delete-excluded. + type: bool + default: no + dirs: + description: + - Transfer directories without recursing. + type: bool + default: no + recursive: + description: + - Recurse into directories. + - This parameter defaults to the value of the archive option. + type: bool + links: + description: + - Copy symlinks as symlinks. + - This parameter defaults to the value of the archive option. + type: bool + copy_links: + description: + - Copy symlinks as the item that they point to (the referent) is copied, rather than the symlink. + type: bool + default: no + perms: + description: + - Preserve permissions. + - This parameter defaults to the value of the archive option. + type: bool + times: + description: + - Preserve modification times. + - This parameter defaults to the value of the archive option. + type: bool + owner: + description: + - Preserve owner (super user only). + - This parameter defaults to the value of the archive option. + type: bool + group: + description: + - Preserve group. + - This parameter defaults to the value of the archive option. + type: bool + rsync_path: + description: + - Specify the rsync command to run on the remote host. See C(--rsync-path) on the rsync man page. + - To specify the rsync command to run on the local host, you need to set this your task var C(ansible_rsync_path). + type: str + rsync_timeout: + description: + - Specify a C(--timeout) for the rsync command in seconds. + type: int + default: 0 + set_remote_user: + description: + - Put user@ for the remote paths. 
+ - If you have a custom ssh config to define the remote user for a host + that does not match the inventory user, you should set this parameter to C(no). + type: bool + default: yes + use_ssh_args: + description: + - Use the ssh_args specified in ansible.cfg. + type: bool + default: no + version_added: "2.0" + rsync_opts: + description: + - Specify additional rsync options by passing in an array. + - Note that an empty string in C(rsync_opts) will end up transfer the current working directory. + type: list + default: + version_added: "1.6" + partial: + description: + - Tells rsync to keep the partial file which should make a subsequent transfer of the rest of the file much faster. + type: bool + default: no + version_added: "2.0" + verify_host: + description: + - Verify destination host key. + type: bool + default: no + version_added: "2.0" + private_key: + description: + - Specify the private key to use for SSH-based rsync connections (e.g. C(~/.ssh/id_rsa)). + type: path + version_added: "1.6" + link_dest: + description: + - Add a destination to hard link against during the rsync. + type: list + default: + version_added: "2.5" +notes: + - rsync must be installed on both the local and remote host. + - For the C(synchronize) module, the "local host" is the host `the synchronize task originates on`, and the "destination host" is the host + `synchronize is connecting to`. + - The "local host" can be changed to a different host by using `delegate_to`. This enables copying between two remote hosts or entirely on one + remote machine. + - > + The user and permissions for the synchronize `src` are those of the user running the Ansible task on the local host (or the remote_user for a + delegate_to host when delegate_to is used). + - The user and permissions for the synchronize `dest` are those of the `remote_user` on the destination host or the `become_user` if `become=yes` is active. + - In Ansible 2.0 a bug in the synchronize module made become occur on the "local host". This was fixed in Ansible 2.0.1. + - Currently, synchronize is limited to elevating permissions via passwordless sudo. This is because rsync itself is connecting to the remote machine + and rsync doesn't give us a way to pass sudo credentials in. + - Currently there are only a few connection types which support synchronize (ssh, paramiko, local, and docker) because a sync strategy has been + determined for those connection types. Note that the connection for these must not need a password as rsync itself is making the connection and + rsync does not provide us a way to pass a password to the connection. + - Expect that dest=~/x will be ~<remote_user>/x even if using sudo. + - Inspect the verbose output to validate the destination user/host/path are what was expected. + - To exclude files and directories from being synchronized, you may add C(.rsync-filter) files to the source directory. + - rsync daemon must be up and running with correct permission when using rsync protocol in source or destination path. + - The C(synchronize) module forces `--delay-updates` to avoid leaving a destination in a broken in-between state if the underlying rsync process + encounters an error. Those synchronizing large numbers of files that are willing to trade safety for performance should call rsync directly. + - link_destination is subject to the same limitations as the underlying rsync daemon. Hard links are only preserved if the relative subtrees + of the source and destination are the same. 
Attempts to hardlink into a directory that is a subdirectory of the source will be prevented. +seealso: +- module: copy +- module: win_robocopy +author: +- Timothy Appnel (@tima) +''' + +EXAMPLES = ''' +- name: Synchronization of src on the control machine to dest on the remote hosts + synchronize: + src: some/relative/path + dest: /some/absolute/path + +- name: Synchronization using rsync protocol (push) + synchronize: + src: some/relative/path/ + dest: rsync://somehost.com/path/ + +- name: Synchronization using rsync protocol (pull) + synchronize: + mode: pull + src: rsync://somehost.com/path/ + dest: /some/absolute/path/ + +- name: Synchronization using rsync protocol on delegate host (push) + synchronize: + src: /some/absolute/path/ + dest: rsync://somehost.com/path/ + delegate_to: delegate.host + +- name: Synchronization using rsync protocol on delegate host (pull) + synchronize: + mode: pull + src: rsync://somehost.com/path/ + dest: /some/absolute/path/ + delegate_to: delegate.host + +- name: Synchronization without any --archive options enabled + synchronize: + src: some/relative/path + dest: /some/absolute/path + archive: no + +- name: Synchronization with --archive options enabled except for --recursive + synchronize: + src: some/relative/path + dest: /some/absolute/path + recursive: no + +- name: Synchronization with --archive options enabled except for --times, with --checksum option enabled + synchronize: + src: some/relative/path + dest: /some/absolute/path + checksum: yes + times: no + +- name: Synchronization without --archive options enabled except use --links + synchronize: + src: some/relative/path + dest: /some/absolute/path + archive: no + links: yes + +- name: Synchronization of two paths both on the control machine + synchronize: + src: some/relative/path + dest: /some/absolute/path + delegate_to: localhost + +- name: Synchronization of src on the inventory host to the dest on the localhost in pull mode + synchronize: + mode: pull + src: some/relative/path + dest: /some/absolute/path + +- name: Synchronization of src on delegate host to dest on the current inventory host. + synchronize: + src: /first/absolute/path + dest: /second/absolute/path + delegate_to: delegate.host + +- name: Synchronize two directories on one remote host. + synchronize: + src: /first/absolute/path + dest: /second/absolute/path + delegate_to: "{{ inventory_hostname }}" + +- name: Synchronize and delete files in dest on the remote host that are not found in src of localhost. 
+ synchronize: + src: some/relative/path + dest: /some/absolute/path + delete: yes + recursive: yes + +# This specific command is granted su privileges on the destination +- name: Synchronize using an alternate rsync command + synchronize: + src: some/relative/path + dest: /some/absolute/path + rsync_path: su -c rsync + +# Example .rsync-filter file in the source directory +# - var # exclude any path whose last part is 'var' +# - /var # exclude any path starting with 'var' starting at the source directory +# + /var/conf # include /var/conf even though it was previously excluded + +- name: Synchronize passing in extra rsync options + synchronize: + src: /tmp/helloworld + dest: /var/www/helloworld + rsync_opts: + - "--no-motd" + - "--exclude=.git" + +# Hardlink files if they didn't change +- name: Use hardlinks when synchronizing filesystems + synchronize: + src: /tmp/path_a/foo.txt + dest: /tmp/path_b/foo.txt + link_dest: /tmp/path_a/ + +# Specify the rsync binary to use on remote host and on local host +- hosts: groupofhosts + vars: + ansible_rsync_path: /usr/gnu/bin/rsync + + tasks: + - name: copy /tmp/localpath/ to remote location /tmp/remotepath + synchronize: + src: /tmp/localpath/ + dest: /tmp/remotepath + rsync_path: /usr/gnu/bin/rsync +''' + + +import os +import errno + +from ansible.module_utils.basic import AnsibleModule +from ansible.module_utils._text import to_bytes, to_native +from ansible.module_utils.six.moves import shlex_quote + + +client_addr = None + + +def substitute_controller(path): + global client_addr + if not client_addr: + ssh_env_string = os.environ.get('SSH_CLIENT', None) + try: + client_addr, _ = ssh_env_string.split(None, 1) + except AttributeError: + ssh_env_string = os.environ.get('SSH_CONNECTION', None) + try: + client_addr, _ = ssh_env_string.split(None, 1) + except AttributeError: + pass + if not client_addr: + raise ValueError + + if path.startswith('localhost:'): + path = path.replace('localhost', client_addr, 1) + return path + + +def is_rsh_needed(source, dest): + if source.startswith('rsync://') or dest.startswith('rsync://'): + return False + if ':' in source or ':' in dest: + return True + return False + + +def main(): + module = AnsibleModule( + argument_spec=dict( + src=dict(type='str', required=True), + dest=dict(type='str', required=True), + dest_port=dict(type='int'), + delete=dict(type='bool', default=False), + private_key=dict(type='path'), + rsync_path=dict(type='str'), + _local_rsync_path=dict(type='path', default='rsync'), + _local_rsync_password=dict(type='str', no_log=True), + _substitute_controller=dict(type='bool', default=False), + archive=dict(type='bool', default=True), + checksum=dict(type='bool', default=False), + compress=dict(type='bool', default=True), + existing_only=dict(type='bool', default=False), + dirs=dict(type='bool', default=False), + recursive=dict(type='bool'), + links=dict(type='bool'), + copy_links=dict(type='bool', default=False), + perms=dict(type='bool'), + times=dict(type='bool'), + owner=dict(type='bool'), + group=dict(type='bool'), + set_remote_user=dict(type='bool', default=True), + rsync_timeout=dict(type='int', default=0), + rsync_opts=dict(type='list', default=[]), + ssh_args=dict(type='str'), + partial=dict(type='bool', default=False), + verify_host=dict(type='bool', default=False), + mode=dict(type='str', default='push', choices=['pull', 'push']), + link_dest=dict(type='list') + ), + supports_check_mode=True, + ) + + if module.params['_substitute_controller']: + try: + source = 
substitute_controller(module.params['src']) + dest = substitute_controller(module.params['dest']) + except ValueError: + module.fail_json(msg='Could not determine controller hostname for rsync to send to') + else: + source = module.params['src'] + dest = module.params['dest'] + dest_port = module.params['dest_port'] + delete = module.params['delete'] + private_key = module.params['private_key'] + rsync_path = module.params['rsync_path'] + rsync = module.params.get('_local_rsync_path', 'rsync') + rsync_password = module.params.get('_local_rsync_password') + rsync_timeout = module.params.get('rsync_timeout', 'rsync_timeout') + archive = module.params['archive'] + checksum = module.params['checksum'] + compress = module.params['compress'] + existing_only = module.params['existing_only'] + dirs = module.params['dirs'] + partial = module.params['partial'] + # the default of these params depends on the value of archive + recursive = module.params['recursive'] + links = module.params['links'] + copy_links = module.params['copy_links'] + perms = module.params['perms'] + times = module.params['times'] + owner = module.params['owner'] + group = module.params['group'] + rsync_opts = module.params['rsync_opts'] + ssh_args = module.params['ssh_args'] + verify_host = module.params['verify_host'] + link_dest = module.params['link_dest'] + + if '/' not in rsync: + rsync = module.get_bin_path(rsync, required=True) + + cmd = [rsync, '--delay-updates', '-F'] + _sshpass_pipe = None + if rsync_password: + try: + module.run_command(["sshpass"]) + except OSError: + module.fail_json( + msg="to use rsync connection with passwords, you must install the sshpass program" + ) + _sshpass_pipe = os.pipe() + cmd = ['sshpass', '-d' + to_native(_sshpass_pipe[0], errors='surrogate_or_strict')] + cmd + if compress: + cmd.append('--compress') + if rsync_timeout: + cmd.append('--timeout=%s' % rsync_timeout) + if module.check_mode: + cmd.append('--dry-run') + if delete: + cmd.append('--delete-after') + if existing_only: + cmd.append('--existing') + if checksum: + cmd.append('--checksum') + if copy_links: + cmd.append('--copy-links') + if archive: + cmd.append('--archive') + if recursive is False: + cmd.append('--no-recursive') + if links is False: + cmd.append('--no-links') + if perms is False: + cmd.append('--no-perms') + if times is False: + cmd.append('--no-times') + if owner is False: + cmd.append('--no-owner') + if group is False: + cmd.append('--no-group') + else: + if recursive is True: + cmd.append('--recursive') + if links is True: + cmd.append('--links') + if perms is True: + cmd.append('--perms') + if times is True: + cmd.append('--times') + if owner is True: + cmd.append('--owner') + if group is True: + cmd.append('--group') + if dirs: + cmd.append('--dirs') + + if source.startswith('rsync://') and dest.startswith('rsync://'): + module.fail_json(msg='either src or dest must be a localhost', rc=1) + + if is_rsh_needed(source, dest): + + # https://github.com/ansible/ansible/issues/15907 + has_rsh = False + for rsync_opt in rsync_opts: + if '--rsh' in rsync_opt: + has_rsh = True + break + + # if the user has not supplied an --rsh option go ahead and add ours + if not has_rsh: + ssh_cmd = [module.get_bin_path('ssh', required=True), '-S', 'none'] + if private_key is not None: + ssh_cmd.extend(['-i', private_key]) + # If the user specified a port value + # Note: The action plugin takes care of setting this to a port from + # inventory if the user didn't specify an explicit dest_port + if dest_port is not None: + 
ssh_cmd.extend(['-o', 'Port=%s' % dest_port]) + if not verify_host: + ssh_cmd.extend(['-o', 'StrictHostKeyChecking=no', '-o', 'UserKnownHostsFile=/dev/null']) + ssh_cmd_str = ' '.join(shlex_quote(arg) for arg in ssh_cmd) + if ssh_args: + ssh_cmd_str += ' %s' % ssh_args + cmd.append('--rsh=%s' % ssh_cmd_str) + + if rsync_path: + cmd.append('--rsync-path=%s' % rsync_path) + + if rsync_opts: + if '' in rsync_opts: + module.warn('The empty string is present in rsync_opts which will cause rsync to' + ' transfer the current working directory. If this is intended, use "."' + ' instead to get rid of this warning. If this is unintended, check for' + ' problems in your playbook leading to empty string in rsync_opts.') + cmd.extend(rsync_opts) + + if partial: + cmd.append('--partial') + + if link_dest: + cmd.append('-H') + # verbose required because rsync does not believe that adding a + # hardlink is actually a change + cmd.append('-vv') + for x in link_dest: + link_path = os.path.abspath(os.path.expanduser(x)) + destination_path = os.path.abspath(os.path.dirname(dest)) + if destination_path.find(link_path) == 0: + module.fail_json(msg='Hardlinking into a subdirectory of the source would cause recursion. %s and %s' % (destination_path, dest)) + cmd.append('--link-dest=%s' % link_path) + + changed_marker = '<<CHANGED>>' + cmd.append('--out-format=' + changed_marker + '%i %n%L') + + # expand the paths + if '@' not in source: + source = os.path.expanduser(source) + if '@' not in dest: + dest = os.path.expanduser(dest) + + cmd.append(source) + cmd.append(dest) + cmdstr = ' '.join(cmd) + + # If we are using password authentication, write the password into the pipe + if rsync_password: + def _write_password_to_pipe(proc): + os.close(_sshpass_pipe[0]) + try: + os.write(_sshpass_pipe[1], to_bytes(rsync_password) + b'\n') + except OSError as exc: + # Ignore broken pipe errors if the sshpass process has exited. 
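+ # EPIPE is only swallowed when sshpass has already exited
+ # (proc.poll() is not None); any other errno, or EPIPE while
+ # sshpass is still running, is re-raised.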
+ if exc.errno != errno.EPIPE or proc.poll() is None: + raise + + (rc, out, err) = module.run_command( + cmd, pass_fds=_sshpass_pipe, + before_communicate_callback=_write_password_to_pipe) + else: + (rc, out, err) = module.run_command(cmd) + + if rc: + return module.fail_json(msg=err, rc=rc, cmd=cmdstr) + + if link_dest: + # a leading period indicates no change + changed = (changed_marker + '.') not in out + else: + changed = changed_marker in out + + out_clean = out.replace(changed_marker, '') + out_lines = out_clean.split('\n') + while '' in out_lines: + out_lines.remove('') + if module._diff: + diff = {'prepared': out_clean} + return module.exit_json(changed=changed, msg=out_clean, + rc=rc, cmd=cmdstr, stdout_lines=out_lines, + diff=diff) + + return module.exit_json(changed=changed, msg=out_clean, + rc=rc, cmd=cmdstr, stdout_lines=out_lines) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/timezone.py b/test/support/integration/plugins/modules/timezone.py new file mode 100644 index 00000000..b7439a12 --- /dev/null +++ b/test/support/integration/plugins/modules/timezone.py @@ -0,0 +1,909 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2016, Shinichi TAMURA (@tmshn) +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: timezone +short_description: Configure timezone setting +description: + - This module configures the timezone setting, both of the system clock and of the hardware clock. If you want to set up the NTP, use M(service) module. + - It is recommended to restart C(crond) after changing the timezone, otherwise the jobs may run at the wrong time. + - Several different tools are used depending on the OS/Distribution involved. + For Linux it can use C(timedatectl) or edit C(/etc/sysconfig/clock) or C(/etc/timezone) and C(hwclock). + On SmartOS, C(sm-set-timezone), for macOS, C(systemsetup), for BSD, C(/etc/localtime) is modified. + On AIX, C(chtz) is used. + - As of Ansible 2.3 support was added for SmartOS and BSDs. + - As of Ansible 2.4 support was added for macOS. + - As of Ansible 2.9 support was added for AIX 6.1+ + - Windows and HPUX are not supported, please let us know if you find any other OS/distro in which this fails. +version_added: "2.2" +options: + name: + description: + - Name of the timezone for the system clock. + - Default is to keep current setting. + - B(At least one of name and hwclock are required.) + type: str + hwclock: + description: + - Whether the hardware clock is in UTC or in local timezone. + - Default is to keep current setting. + - Note that this option is recommended not to change and may fail + to configure, especially on virtual environments such as AWS. + - B(At least one of name and hwclock are required.) + - I(Only used on Linux.) + type: str + aliases: [ rtc ] + choices: [ local, UTC ] +notes: + - On SmartOS the C(sm-set-timezone) utility (part of the smtools package) is required to set the zone timezone + - On AIX only Olson/tz database timezones are useable (POSIX is not supported). + - An OS reboot is also required on AIX for the new timezone setting to take effect. 
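+# Minimal usage sketch beyond the EXAMPLES below (illustrative only): pinning
+# the hardware clock to UTC at the same time would look like
+#   - timezone:
+#       name: Asia/Tokyo
+#       hwclock: UTC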
+author: + - Shinichi TAMURA (@tmshn) + - Jasper Lievisse Adriaanse (@jasperla) + - Indrajit Raychaudhuri (@indrajitr) +''' + +RETURN = r''' +diff: + description: The differences about the given arguments. + returned: success + type: complex + contains: + before: + description: The values before change + type: dict + after: + description: The values after change + type: dict +''' + +EXAMPLES = r''' +- name: Set timezone to Asia/Tokyo + timezone: + name: Asia/Tokyo +''' + +import errno +import os +import platform +import random +import re +import string +import filecmp + +from ansible.module_utils.basic import AnsibleModule, get_distribution +from ansible.module_utils.six import iteritems + + +class Timezone(object): + """This is a generic Timezone manipulation class that is subclassed based on platform. + + A subclass may wish to override the following action methods: + - get(key, phase) ... get the value from the system at `phase` + - set(key, value) ... set the value to the current system + """ + + def __new__(cls, module): + """Return the platform-specific subclass. + + It does not use load_platform_subclass() because it needs to judge based + on whether the `timedatectl` command exists and is available. + + Args: + module: The AnsibleModule. + """ + if platform.system() == 'Linux': + timedatectl = module.get_bin_path('timedatectl') + if timedatectl is not None: + rc, stdout, stderr = module.run_command(timedatectl) + if rc == 0: + return super(Timezone, SystemdTimezone).__new__(SystemdTimezone) + else: + module.warn('timedatectl command was found but not usable: %s. using other method.' % stderr) + return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone) + else: + return super(Timezone, NosystemdTimezone).__new__(NosystemdTimezone) + elif re.match('^joyent_.*Z', platform.version()): + # platform.system() returns SunOS, which is too broad. So look at the + # platform version instead. However we have to ensure that we're not + # running in the global zone where changing the timezone has no effect. + zonename_cmd = module.get_bin_path('zonename') + if zonename_cmd is not None: + (rc, stdout, _) = module.run_command(zonename_cmd) + if rc == 0 and stdout.strip() == 'global': + module.fail_json(msg='Adjusting timezone is not supported in Global Zone') + + return super(Timezone, SmartOSTimezone).__new__(SmartOSTimezone) + elif platform.system() == 'Darwin': + return super(Timezone, DarwinTimezone).__new__(DarwinTimezone) + elif re.match('^(Free|Net|Open)BSD', platform.platform()): + return super(Timezone, BSDTimezone).__new__(BSDTimezone) + elif platform.system() == 'AIX': + AIXoslevel = int(platform.version() + platform.release()) + if AIXoslevel >= 61: + return super(Timezone, AIXTimezone).__new__(AIXTimezone) + else: + module.fail_json(msg='AIX os level must be >= 61 for timezone module (Target: %s).' % AIXoslevel) + else: + # Not supported yet + return super(Timezone, Timezone).__new__(Timezone) + + def __init__(self, module): + """Initialize of the class. + + Args: + module: The AnsibleModule. + """ + super(Timezone, self).__init__() + self.msg = [] + # `self.value` holds the values for each params on each phases. + # Initially there's only info of "planned" phase, but the + # `self.check()` function will fill out it. + self.value = dict() + for key in module.argument_spec: + value = module.params[key] + if value is not None: + self.value[key] = dict(planned=value) + self.module = module + + def abort(self, msg): + """Abort the process with error message. 
+ + This is just the wrapper of module.fail_json(). + + Args: + msg: The error message. + """ + error_msg = ['Error message:', msg] + if len(self.msg) > 0: + error_msg.append('Other message(s):') + error_msg.extend(self.msg) + self.module.fail_json(msg='\n'.join(error_msg)) + + def execute(self, *commands, **kwargs): + """Execute the shell command. + + This is just the wrapper of module.run_command(). + + Args: + *commands: The command to execute. + It will be concatenated with single space. + **kwargs: Only 'log' key is checked. + If kwargs['log'] is true, record the command to self.msg. + + Returns: + stdout: Standard output of the command. + """ + command = ' '.join(commands) + (rc, stdout, stderr) = self.module.run_command(command, check_rc=True) + if kwargs.get('log', False): + self.msg.append('executed `%s`' % command) + return stdout + + def diff(self, phase1='before', phase2='after'): + """Calculate the difference between given 2 phases. + + Args: + phase1, phase2: The names of phase to compare. + + Returns: + diff: The difference of value between phase1 and phase2. + This is in the format which can be used with the + `--diff` option of ansible-playbook. + """ + diff = {phase1: {}, phase2: {}} + for key, value in iteritems(self.value): + diff[phase1][key] = value[phase1] + diff[phase2][key] = value[phase2] + return diff + + def check(self, phase): + """Check the state in given phase and set it to `self.value`. + + Args: + phase: The name of the phase to check. + + Returns: + NO RETURN VALUE + """ + if phase == 'planned': + return + for key, value in iteritems(self.value): + value[phase] = self.get(key, phase) + + def change(self): + """Make the changes effect based on `self.value`.""" + for key, value in iteritems(self.value): + if value['before'] != value['planned']: + self.set(key, value['planned']) + + # =========================================== + # Platform specific methods (must be replaced by subclass). + + def get(self, key, phase): + """Get the value for the key at the given phase. + + Called from self.check(). + + Args: + key: The key to get the value + phase: The phase to get the value + + Return: + value: The value for the key at the given phase. + """ + self.abort('get(key, phase) is not implemented on target platform') + + def set(self, key, value): + """Set the value for the key (of course, for the phase 'after'). + + Called from self.change(). + + Args: + key: Key to set the value + value: Value to set + """ + self.abort('set(key, value) is not implemented on target platform') + + def _verify_timezone(self): + tz = self.value['name']['planned'] + tzfile = '/usr/share/zoneinfo/%s' % tz + if not os.path.isfile(tzfile): + self.abort('given timezone "%s" is not available' % tz) + return tzfile + + +class SystemdTimezone(Timezone): + """This is a Timezone manipulation class for systemd-powered Linux. + + It uses the `timedatectl` command to check/set all arguments. 
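+ For illustration, the regexps defined below are written against lines of
+ `timedatectl status` output such as:
+ Time zone: Asia/Tokyo (JST, +0900)
+ RTC in local TZ: no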
+ """ + + regexps = dict( + hwclock=re.compile(r'^\s*RTC in local TZ\s*:\s*([^\s]+)', re.MULTILINE), + name=re.compile(r'^\s*Time ?zone\s*:\s*([^\s]+)', re.MULTILINE) + ) + + subcmds = dict( + hwclock='set-local-rtc', + name='set-timezone' + ) + + def __init__(self, module): + super(SystemdTimezone, self).__init__(module) + self.timedatectl = module.get_bin_path('timedatectl', required=True) + self.status = dict() + # Validate given timezone + if 'name' in self.value: + self._verify_timezone() + + def _get_status(self, phase): + if phase not in self.status: + self.status[phase] = self.execute(self.timedatectl, 'status') + return self.status[phase] + + def get(self, key, phase): + status = self._get_status(phase) + value = self.regexps[key].search(status).group(1) + if key == 'hwclock': + # For key='hwclock'; convert yes/no -> local/UTC + if self.module.boolean(value): + value = 'local' + else: + value = 'UTC' + return value + + def set(self, key, value): + # For key='hwclock'; convert UTC/local -> yes/no + if key == 'hwclock': + if value == 'local': + value = 'yes' + else: + value = 'no' + self.execute(self.timedatectl, self.subcmds[key], value, log=True) + + +class NosystemdTimezone(Timezone): + """This is a Timezone manipulation class for non systemd-powered Linux. + + For timezone setting, it edits the following file and reflect changes: + - /etc/sysconfig/clock ... RHEL/CentOS + - /etc/timezone ... Debian/Ubuntu + For hwclock setting, it executes `hwclock --systohc` command with the + '--utc' or '--localtime' option. + """ + + conf_files = dict( + name=None, # To be set in __init__ + hwclock=None, # To be set in __init__ + adjtime='/etc/adjtime' + ) + + # It's fine if all tree config files don't exist + allow_no_file = dict( + name=True, + hwclock=True, + adjtime=True + ) + + regexps = dict( + name=None, # To be set in __init__ + hwclock=re.compile(r'^UTC\s*=\s*([^\s]+)', re.MULTILINE), + adjtime=re.compile(r'^(UTC|LOCAL)$', re.MULTILINE) + ) + + dist_regexps = dict( + SuSE=re.compile(r'^TIMEZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE), + redhat=re.compile(r'^ZONE\s*=\s*"?([^"\s]+)"?', re.MULTILINE) + ) + + dist_tzline_format = dict( + SuSE='TIMEZONE="%s"\n', + redhat='ZONE="%s"\n' + ) + + def __init__(self, module): + super(NosystemdTimezone, self).__init__(module) + # Validate given timezone + if 'name' in self.value: + tzfile = self._verify_timezone() + # `--remove-destination` is needed if /etc/localtime is a symlink so + # that it overwrites it instead of following it. + self.update_timezone = ['%s --remove-destination %s /etc/localtime' % (self.module.get_bin_path('cp', required=True), tzfile)] + self.update_hwclock = self.module.get_bin_path('hwclock', required=True) + # Distribution-specific configurations + if self.module.get_bin_path('dpkg-reconfigure') is not None: + # Debian/Ubuntu + if 'name' in self.value: + self.update_timezone = ['%s -sf %s /etc/localtime' % (self.module.get_bin_path('ln', required=True), tzfile), + '%s --frontend noninteractive tzdata' % self.module.get_bin_path('dpkg-reconfigure', required=True)] + self.conf_files['name'] = '/etc/timezone' + self.conf_files['hwclock'] = '/etc/default/rcS' + self.regexps['name'] = re.compile(r'^([^\s]+)', re.MULTILINE) + self.tzline_format = '%s\n' + else: + # RHEL/CentOS/SUSE + if self.module.get_bin_path('tzdata-update') is not None: + # tzdata-update cannot update the timezone if /etc/localtime is + # a symlink so we have to use cp to update the time zone which + # was set above. 
+ if not os.path.islink('/etc/localtime'): + self.update_timezone = [self.module.get_bin_path('tzdata-update', required=True)] + # else: + # self.update_timezone = 'cp --remove-destination ...' <- configured above + self.conf_files['name'] = '/etc/sysconfig/clock' + self.conf_files['hwclock'] = '/etc/sysconfig/clock' + try: + f = open(self.conf_files['name'], 'r') + except IOError as err: + if self._allow_ioerror(err, 'name'): + # If the config file doesn't exist detect the distribution and set regexps. + distribution = get_distribution() + if distribution == 'SuSE': + # For SUSE + self.regexps['name'] = self.dist_regexps['SuSE'] + self.tzline_format = self.dist_tzline_format['SuSE'] + else: + # For RHEL/CentOS + self.regexps['name'] = self.dist_regexps['redhat'] + self.tzline_format = self.dist_tzline_format['redhat'] + else: + self.abort('could not read configuration file "%s"' % self.conf_files['name']) + else: + # The key for timezone might be `ZONE` or `TIMEZONE` + # (the former is used in RHEL/CentOS and the latter is used in SUSE linux). + # So check the content of /etc/sysconfig/clock and decide which key to use. + sysconfig_clock = f.read() + f.close() + if re.search(r'^TIMEZONE\s*=', sysconfig_clock, re.MULTILINE): + # For SUSE + self.regexps['name'] = self.dist_regexps['SuSE'] + self.tzline_format = self.dist_tzline_format['SuSE'] + else: + # For RHEL/CentOS + self.regexps['name'] = self.dist_regexps['redhat'] + self.tzline_format = self.dist_tzline_format['redhat'] + + def _allow_ioerror(self, err, key): + # In some cases, even if the target file does not exist, + # simply creating it may solve the problem. + # In such cases, we should continue the configuration rather than aborting. + if err.errno != errno.ENOENT: + # If the error is not ENOENT ("No such file or directory"), + # (e.g., permission error, etc), we should abort. + return False + return self.allow_no_file.get(key, False) + + def _edit_file(self, filename, regexp, value, key): + """Replace the first matched line with given `value`. + + If `regexp` matched more than once, other than the first line will be deleted. + + Args: + filename: The name of the file to edit. + regexp: The regular expression to search with. + value: The line which will be inserted. + key: For what key the file is being editted. 
+ """ + # Read the file + try: + file = open(filename, 'r') + except IOError as err: + if self._allow_ioerror(err, key): + lines = [] + else: + self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename)) + else: + lines = file.readlines() + file.close() + # Find the all matched lines + matched_indices = [] + for i, line in enumerate(lines): + if regexp.search(line): + matched_indices.append(i) + if len(matched_indices) > 0: + insert_line = matched_indices[0] + else: + insert_line = 0 + # Remove all matched lines + for i in matched_indices[::-1]: + del lines[i] + # ...and insert the value + lines.insert(insert_line, value) + # Write the changes + try: + file = open(filename, 'w') + except IOError: + self.abort('tried to configure %s using a file "%s", but could not write to it' % (key, filename)) + else: + file.writelines(lines) + file.close() + self.msg.append('Added 1 line and deleted %s line(s) on %s' % (len(matched_indices), filename)) + + def _get_value_from_config(self, key, phase): + filename = self.conf_files[key] + try: + file = open(filename, mode='r') + except IOError as err: + if self._allow_ioerror(err, key): + if key == 'hwclock': + return 'n/a' + elif key == 'adjtime': + return 'UTC' + elif key == 'name': + return 'n/a' + else: + self.abort('tried to configure %s using a file "%s", but could not read it' % (key, filename)) + else: + status = file.read() + file.close() + try: + value = self.regexps[key].search(status).group(1) + except AttributeError: + if key == 'hwclock': + # If we cannot find UTC in the config that's fine. + return 'n/a' + elif key == 'adjtime': + # If we cannot find UTC/LOCAL in /etc/cannot that means UTC + # will be used by default. + return 'UTC' + elif key == 'name': + if phase == 'before': + # In 'before' phase UTC/LOCAL doesn't need to be set in + # the timezone config file, so we ignore this error. + return 'n/a' + else: + self.abort('tried to configure %s using a file "%s", but could not find a valid value in it' % (key, filename)) + else: + if key == 'hwclock': + # convert yes/no -> UTC/local + if self.module.boolean(value): + value = 'UTC' + else: + value = 'local' + elif key == 'adjtime': + # convert LOCAL -> local + if value != 'UTC': + value = value.lower() + return value + + def get(self, key, phase): + planned = self.value[key]['planned'] + if key == 'hwclock': + value = self._get_value_from_config(key, phase) + if value == planned: + # If the value in the config file is the same as the 'planned' + # value, we need to check /etc/adjtime. + value = self._get_value_from_config('adjtime', phase) + elif key == 'name': + value = self._get_value_from_config(key, phase) + if value == planned: + # If the planned values is the same as the one in the config file + # we need to check if /etc/localtime is also set to the 'planned' zone. + if os.path.islink('/etc/localtime'): + # If /etc/localtime is a symlink and is not set to the TZ we 'planned' + # to set, we need to return the TZ which the symlink points to. + if os.path.exists('/etc/localtime'): + # We use readlink() because on some distros zone files are symlinks + # to other zone files, so it's hard to get which TZ is actually set + # if we follow the symlink. + path = os.readlink('/etc/localtime') + linktz = re.search(r'/usr/share/zoneinfo/(.*)', path, re.MULTILINE) + if linktz: + valuelink = linktz.group(1) + if valuelink != planned: + value = valuelink + else: + # Set current TZ to 'n/a' if the symlink points to a path + # which isn't a zone file. 
+ value = 'n/a' + else: + # Set current TZ to 'n/a' if the symlink to the zone file is broken. + value = 'n/a' + else: + # If /etc/localtime is not a symlink best we can do is compare it with + # the 'planned' zone info file and return 'n/a' if they are different. + try: + if not filecmp.cmp('/etc/localtime', '/usr/share/zoneinfo/' + planned): + return 'n/a' + except Exception: + return 'n/a' + else: + self.abort('unknown parameter "%s"' % key) + return value + + def set_timezone(self, value): + self._edit_file(filename=self.conf_files['name'], + regexp=self.regexps['name'], + value=self.tzline_format % value, + key='name') + for cmd in self.update_timezone: + self.execute(cmd) + + def set_hwclock(self, value): + if value == 'local': + option = '--localtime' + utc = 'no' + else: + option = '--utc' + utc = 'yes' + if self.conf_files['hwclock'] is not None: + self._edit_file(filename=self.conf_files['hwclock'], + regexp=self.regexps['hwclock'], + value='UTC=%s\n' % utc, + key='hwclock') + self.execute(self.update_hwclock, '--systohc', option, log=True) + + def set(self, key, value): + if key == 'name': + self.set_timezone(value) + elif key == 'hwclock': + self.set_hwclock(value) + else: + self.abort('unknown parameter "%s"' % key) + + +class SmartOSTimezone(Timezone): + """This is a Timezone manipulation class for SmartOS instances. + + It uses the C(sm-set-timezone) utility to set the timezone, and + inspects C(/etc/default/init) to determine the current timezone. + + NB: A zone needs to be rebooted in order for the change to be + activated. + """ + + def __init__(self, module): + super(SmartOSTimezone, self).__init__(module) + self.settimezone = self.module.get_bin_path('sm-set-timezone', required=False) + if not self.settimezone: + module.fail_json(msg='sm-set-timezone not found. Make sure the smtools package is installed.') + + def get(self, key, phase): + """Lookup the current timezone name in `/etc/default/init`. If anything else + is requested, or if the TZ field is not set we fail. + """ + if key == 'name': + try: + f = open('/etc/default/init', 'r') + for line in f: + m = re.match('^TZ=(.*)$', line.strip()) + if m: + return m.groups()[0] + except Exception: + self.module.fail_json(msg='Failed to read /etc/default/init') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + """Set the requested timezone through sm-set-timezone, an invalid timezone name + will be rejected and we have no further input validation to perform. + """ + if key == 'name': + cmd = 'sm-set-timezone %s' % value + + (rc, stdout, stderr) = self.module.run_command(cmd) + + if rc != 0: + self.module.fail_json(msg=stderr) + + # sm-set-timezone knows no state and will always set the timezone. + # XXX: https://github.com/joyent/smtools/pull/2 + m = re.match(r'^\* Changed (to)? timezone (to)? (%s).*' % value, stdout.splitlines()[1]) + if not (m and m.groups()[-1] == value): + self.module.fail_json(msg='Failed to set timezone') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class DarwinTimezone(Timezone): + """This is the timezone implementation for Darwin which, unlike other *BSD + implementations, uses the `systemsetup` command on Darwin to check/set + the timezone. 
+ """ + + regexps = dict( + name=re.compile(r'^\s*Time ?Zone\s*:\s*([^\s]+)', re.MULTILINE) + ) + + def __init__(self, module): + super(DarwinTimezone, self).__init__(module) + self.systemsetup = module.get_bin_path('systemsetup', required=True) + self.status = dict() + # Validate given timezone + if 'name' in self.value: + self._verify_timezone() + + def _get_current_timezone(self, phase): + """Lookup the current timezone via `systemsetup -gettimezone`.""" + if phase not in self.status: + self.status[phase] = self.execute(self.systemsetup, '-gettimezone') + return self.status[phase] + + def _verify_timezone(self): + tz = self.value['name']['planned'] + # Lookup the list of supported timezones via `systemsetup -listtimezones`. + # Note: Skip the first line that contains the label 'Time Zones:' + out = self.execute(self.systemsetup, '-listtimezones').splitlines()[1:] + tz_list = list(map(lambda x: x.strip(), out)) + if tz not in tz_list: + self.abort('given timezone "%s" is not available' % tz) + return tz + + def get(self, key, phase): + if key == 'name': + status = self._get_current_timezone(phase) + value = self.regexps[key].search(status).group(1) + return value + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + if key == 'name': + self.execute(self.systemsetup, '-settimezone', value, log=True) + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class BSDTimezone(Timezone): + """This is the timezone implementation for *BSD which works simply through + updating the `/etc/localtime` symlink to point to a valid timezone name under + `/usr/share/zoneinfo`. + """ + + def __init__(self, module): + super(BSDTimezone, self).__init__(module) + + def __get_timezone(self): + zoneinfo_dir = '/usr/share/zoneinfo/' + localtime_file = '/etc/localtime' + + # Strategy 1: + # If /etc/localtime does not exist, assum the timezone is UTC. + if not os.path.exists(localtime_file): + self.module.warn('Could not read /etc/localtime. Assuming UTC.') + return 'UTC' + + # Strategy 2: + # Follow symlink of /etc/localtime + zoneinfo_file = localtime_file + while not zoneinfo_file.startswith(zoneinfo_dir): + try: + zoneinfo_file = os.readlink(localtime_file) + except OSError: + # OSError means "end of symlink chain" or broken link. + break + else: + return zoneinfo_file.replace(zoneinfo_dir, '') + + # Strategy 3: + # (If /etc/localtime is not symlinked) + # Check all files in /usr/share/zoneinfo and return first non-link match. + for dname, _, fnames in sorted(os.walk(zoneinfo_dir)): + for fname in sorted(fnames): + zoneinfo_file = os.path.join(dname, fname) + if not os.path.islink(zoneinfo_file) and filecmp.cmp(zoneinfo_file, localtime_file): + return zoneinfo_file.replace(zoneinfo_dir, '') + + # Strategy 4: + # As a fall-back, return 'UTC' as default assumption. + self.module.warn('Could not identify timezone name from /etc/localtime. Assuming UTC.') + return 'UTC' + + def get(self, key, phase): + """Lookup the current timezone by resolving `/etc/localtime`.""" + if key == 'name': + return self.__get_timezone() + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + if key == 'name': + # First determine if the requested timezone is valid by looking in + # the zoneinfo directory. 
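+            # For example, value='America/New_York' is checked against the file
+            # /usr/share/zoneinfo/America/New_York below. A roughly equivalent
+            # validation (illustrative sketch only; assumes Python 3.9+ with the
+            # standard-library zoneinfo module, which this module does not require)
+            # would be:
+            #
+            #   from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
+            #   try:
+            #       ZoneInfo(value)
+            #   except ZoneInfoNotFoundError:
+            #       self.module.fail_json(msg='%s is not a recognized timezone' % value)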
+ zonefile = '/usr/share/zoneinfo/' + value + try: + if not os.path.isfile(zonefile): + self.module.fail_json(msg='%s is not a recognized timezone' % value) + except Exception: + self.module.fail_json(msg='Failed to stat %s' % zonefile) + + # Now (somewhat) atomically update the symlink by creating a new + # symlink and move it into place. Otherwise we have to remove the + # original symlink and create the new symlink, however that would + # create a race condition in case another process tries to read + # /etc/localtime between removal and creation. + suffix = "".join([random.choice(string.ascii_letters + string.digits) for x in range(0, 10)]) + new_localtime = '/etc/localtime.' + suffix + + try: + os.symlink(zonefile, new_localtime) + os.rename(new_localtime, '/etc/localtime') + except Exception: + os.remove(new_localtime) + self.module.fail_json(msg='Could not update /etc/localtime') + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +class AIXTimezone(Timezone): + """This is a Timezone manipulation class for AIX instances. + + It uses the C(chtz) utility to set the timezone, and + inspects C(/etc/environment) to determine the current timezone. + + While AIX time zones can be set using two formats (POSIX and + Olson) the prefered method is Olson. + See the following article for more information: + https://developer.ibm.com/articles/au-aix-posix/ + + NB: AIX needs to be rebooted in order for the change to be + activated. + """ + + def __init__(self, module): + super(AIXTimezone, self).__init__(module) + self.settimezone = self.module.get_bin_path('chtz', required=True) + + def __get_timezone(self): + """ Return the current value of TZ= in /etc/environment """ + try: + f = open('/etc/environment', 'r') + etcenvironment = f.read() + f.close() + except Exception: + self.module.fail_json(msg='Issue reading contents of /etc/environment') + + match = re.search(r'^TZ=(.*)$', etcenvironment, re.MULTILINE) + if match: + return match.group(1) + else: + return None + + def get(self, key, phase): + """Lookup the current timezone name in `/etc/environment`. If anything else + is requested, or if the TZ field is not set we fail. + """ + if key == 'name': + return self.__get_timezone() + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + def set(self, key, value): + """Set the requested timezone through chtz, an invalid timezone name + will be rejected and we have no further input validation to perform. + """ + if key == 'name': + # chtz seems to always return 0 on AIX 7.2, even for invalid timezone values. + # It will only return non-zero if the chtz command itself fails, it does not check for + # valid timezones. We need to perform a basic check to confirm that the timezone + # definition exists in /usr/share/lib/zoneinfo + # This does mean that we can only support Olson for now. The below commented out regex + # detects Olson date formats, so in the future we could detect Posix or Olson and + # act accordingly. + + # regex_olson = re.compile('^([a-z0-9_\-\+]+\/?)+$', re.IGNORECASE) + # if not regex_olson.match(value): + # msg = 'Supplied timezone (%s) does not appear to a be valid Olson string' % value + # self.module.fail_json(msg=msg) + + # First determine if the requested timezone is valid by looking in the zoneinfo + # directory. + zonefile = '/usr/share/lib/zoneinfo/' + value + try: + if not os.path.isfile(zonefile): + self.module.fail_json(msg='%s is not a recognized timezone.' 
% value) + except Exception: + self.module.fail_json(msg='Failed to check %s.' % zonefile) + + # Now set the TZ using chtz + cmd = 'chtz %s' % value + (rc, stdout, stderr) = self.module.run_command(cmd) + + if rc != 0: + self.module.fail_json(msg=stderr) + + # The best condition check we can do is to check the value of TZ after making the + # change. + TZ = self.__get_timezone() + if TZ != value: + msg = 'TZ value does not match post-change (Actual: %s, Expected: %s).' % (TZ, value) + self.module.fail_json(msg=msg) + + else: + self.module.fail_json(msg='%s is not a supported option on target platform' % key) + + +def main(): + # Construct 'module' and 'tz' + module = AnsibleModule( + argument_spec=dict( + hwclock=dict(type='str', choices=['local', 'UTC'], aliases=['rtc']), + name=dict(type='str'), + ), + required_one_of=[ + ['hwclock', 'name'] + ], + supports_check_mode=True, + ) + tz = Timezone(module) + + # Check the current state + tz.check(phase='before') + if module.check_mode: + diff = tz.diff('before', 'planned') + # In check mode, 'planned' state is treated as 'after' state + diff['after'] = diff.pop('planned') + else: + # Make change + tz.change() + # Check the current state + tz.check(phase='after') + # Examine if the current state matches planned state + (after, planned) = tz.diff('after', 'planned').values() + if after != planned: + tz.abort('still not desired state, though changes have made - ' + 'planned: %s, after: %s' % (str(planned), str(after))) + diff = tz.diff('before', 'after') + + changed = (diff['before'] != diff['after']) + if len(tz.msg) > 0: + module.exit_json(changed=changed, diff=diff, msg='\n'.join(tz.msg)) + else: + module.exit_json(changed=changed, diff=diff) + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/x509_crl.py b/test/support/integration/plugins/modules/x509_crl.py new file mode 100644 index 00000000..ef601eda --- /dev/null +++ b/test/support/integration/plugins/modules/x509_crl.py @@ -0,0 +1,783 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2019, Felix Fontein <felix@fontein.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: x509_crl +version_added: "2.10" +short_description: Generate Certificate Revocation Lists (CRLs) +description: + - This module allows one to (re)generate or update Certificate Revocation Lists (CRLs). + - Certificates on the revocation list can be either specified via serial number and (optionally) their issuer, + or as a path to a certificate file in PEM format. +requirements: + - cryptography >= 1.2 +author: + - Felix Fontein (@felixfontein) +options: + state: + description: + - Whether the CRL file should exist or not, taking action if the state is different from what is stated. + type: str + default: present + choices: [ absent, present ] + + mode: + description: + - Defines how to process entries of existing CRLs. + - If set to C(generate), makes sure that the CRL has the exact set of revoked certificates + as specified in I(revoked_certificates). + - If set to C(update), makes sure that the CRL contains the revoked certificates from + I(revoked_certificates), but can also contain other revoked certificates. 
If the CRL file + already exists, all entries from the existing CRL will also be included in the new CRL. + When using C(update), you might be interested in setting I(ignore_timestamps) to C(yes). + type: str + default: generate + choices: [ generate, update ] + + force: + description: + - Should the CRL be forced to be regenerated. + type: bool + default: no + + backup: + description: + - Create a backup file including a timestamp so you can get the original + CRL back if you overwrote it with a new one by accident. + type: bool + default: no + + path: + description: + - Remote absolute path where the generated CRL file should be created or is already located. + type: path + required: yes + + privatekey_path: + description: + - Path to the CA's private key to use when signing the CRL. + - Either I(privatekey_path) or I(privatekey_content) must be specified if I(state) is C(present), but not both. + type: path + + privatekey_content: + description: + - The content of the CA's private key to use when signing the CRL. + - Either I(privatekey_path) or I(privatekey_content) must be specified if I(state) is C(present), but not both. + type: str + + privatekey_passphrase: + description: + - The passphrase for the I(privatekey_path). + - This is required if the private key is password protected. + type: str + + issuer: + description: + - Key/value pairs that will be present in the issuer name field of the CRL. + - If you need to specify more than one value with the same key, use a list as value. + - Required if I(state) is C(present). + type: dict + + last_update: + description: + - The point in time from which this CRL can be trusted. + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent, except when + I(ignore_timestamps) is set to C(yes). + type: str + default: "+0s" + + next_update: + description: + - "The absolute latest point in time by which this I(issuer) is expected to have issued + another CRL. Many clients will treat a CRL as expired once I(next_update) occurs." + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent, except when + I(ignore_timestamps) is set to C(yes). + - Required if I(state) is C(present). + type: str + + digest: + description: + - Digest algorithm to be used when signing the CRL. + type: str + default: sha256 + + revoked_certificates: + description: + - List of certificates to be revoked. + - Required if I(state) is C(present). + type: list + elements: dict + suboptions: + path: + description: + - Path to a certificate in PEM format. + - The serial number and issuer will be extracted from the certificate. + - Mutually exclusive with I(content) and I(serial_number). One of these three options + must be specified. + type: path + content: + description: + - Content of a certificate in PEM format. + - The serial number and issuer will be extracted from the certificate. + - Mutually exclusive with I(path) and I(serial_number). One of these three options + must be specified. + type: str + serial_number: + description: + - Serial number of the certificate. 
+ - Mutually exclusive with I(path) and I(content). One of these three options must + be specified. + type: int + revocation_date: + description: + - The point in time the certificate was revoked. + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent, except when + I(ignore_timestamps) is set to C(yes). + type: str + default: "+0s" + issuer: + description: + - The certificate's issuer. + - "Example: C(DNS:ca.example.org)" + type: list + elements: str + issuer_critical: + description: + - Whether the certificate issuer extension should be critical. + type: bool + default: no + reason: + description: + - The value for the revocation reason extension. + type: str + choices: + - unspecified + - key_compromise + - ca_compromise + - affiliation_changed + - superseded + - cessation_of_operation + - certificate_hold + - privilege_withdrawn + - aa_compromise + - remove_from_crl + reason_critical: + description: + - Whether the revocation reason extension should be critical. + type: bool + default: no + invalidity_date: + description: + - The point in time it was known/suspected that the private key was compromised + or that the certificate otherwise became invalid. + - Time can be specified either as relative time or as absolute timestamp. + - Time will always be interpreted as UTC. + - Valid format is C([+-]timespec | ASN.1 TIME) where timespec can be an integer + + C([w | d | h | m | s]) (e.g. C(+32w1d2h). + - Note that if using relative time this module is NOT idempotent. This will NOT + change when I(ignore_timestamps) is set to C(yes). + type: str + invalidity_date_critical: + description: + - Whether the invalidity date extension should be critical. + type: bool + default: no + + ignore_timestamps: + description: + - Whether the timestamps I(last_update), I(next_update) and I(revocation_date) (in + I(revoked_certificates)) should be ignored for idempotency checks. The timestamp + I(invalidity_date) in I(revoked_certificates) will never be ignored. + - Use this in combination with relative timestamps for these values to get idempotency. + type: bool + default: no + + return_content: + description: + - If set to C(yes), will return the (current or generated) CRL's content as I(crl). + type: bool + default: no + +extends_documentation_fragment: + - files + +notes: + - All ASN.1 TIME values should be specified following the YYYYMMDDHHMMSSZ pattern. + - Date specified should be UTC. Minutes and seconds are mandatory. +''' + +EXAMPLES = r''' +- name: Generate a CRL + x509_crl: + path: /etc/ssl/my-ca.crl + privatekey_path: /etc/ssl/private/my-ca.pem + issuer: + CN: My CA + last_update: "+0s" + next_update: "+7d" + revoked_certificates: + - serial_number: 1234 + revocation_date: 20190331202428Z + issuer: + CN: My CA + - serial_number: 2345 + revocation_date: 20191013152910Z + reason: affiliation_changed + invalidity_date: 20191001000000Z + - path: /etc/ssl/crt/revoked-cert.pem + revocation_date: 20191010010203Z +''' + +RETURN = r''' +filename: + description: Path to the generated CRL + returned: changed or success + type: str + sample: /path/to/my-ca.crl +backup_file: + description: Name of backup file created. 
+ returned: changed and if I(backup) is C(yes) + type: str + sample: /path/to/my-ca.crl.2019-03-09@11:22~ +privatekey: + description: Path to the private CA key + returned: changed or success + type: str + sample: /path/to/my-ca.pem +issuer: + description: + - The CRL's issuer. + - Note that for repeated values, only the last one will be returned. + returned: success + type: dict + sample: '{"organizationName": "Ansible", "commonName": "ca.example.com"}' +issuer_ordered: + description: The CRL's issuer as an ordered list of tuples. + returned: success + type: list + elements: list + sample: '[["organizationName", "Ansible"], ["commonName": "ca.example.com"]]' +last_update: + description: The point in time from which this CRL can be trusted as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +next_update: + description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +digest: + description: The signature algorithm used to sign the CRL. + returned: success + type: str + sample: sha256WithRSAEncryption +revoked_certificates: + description: List of certificates to be revoked. + returned: success + type: list + elements: dict + contains: + serial_number: + description: Serial number of the certificate. + type: int + sample: 1234 + revocation_date: + description: The point in time the certificate was revoked as ASN.1 TIME. + type: str + sample: 20190413202428Z + issuer: + description: The certificate's issuer. + type: list + elements: str + sample: '["DNS:ca.example.org"]' + issuer_critical: + description: Whether the certificate issuer extension is critical. + type: bool + sample: no + reason: + description: + - The value for the revocation reason extension. + - One of C(unspecified), C(key_compromise), C(ca_compromise), C(affiliation_changed), C(superseded), + C(cessation_of_operation), C(certificate_hold), C(privilege_withdrawn), C(aa_compromise), and + C(remove_from_crl). + type: str + sample: key_compromise + reason_critical: + description: Whether the revocation reason extension is critical. + type: bool + sample: no + invalidity_date: + description: | + The point in time it was known/suspected that the private key was compromised + or that the certificate otherwise became invalid as ASN.1 TIME. + type: str + sample: 20190413202428Z + invalidity_date_critical: + description: Whether the invalidity date extension is critical. + type: bool + sample: no +crl: + description: The (current or generated) CRL's content. 
+ returned: if I(state) is C(present) and I(return_content) is C(yes) + type: str +''' + + +import os +import traceback +from distutils.version import LooseVersion + +from ansible.module_utils import crypto as crypto_utils +from ansible.module_utils._text import to_native, to_text +from ansible.module_utils.basic import AnsibleModule, missing_required_lib + +MINIMAL_CRYPTOGRAPHY_VERSION = '1.2' + +CRYPTOGRAPHY_IMP_ERR = None +try: + import cryptography + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.serialization import Encoding + from cryptography.x509 import ( + CertificateRevocationListBuilder, + RevokedCertificateBuilder, + NameAttribute, + Name, + ) + CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__) +except ImportError: + CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() + CRYPTOGRAPHY_FOUND = False +else: + CRYPTOGRAPHY_FOUND = True + + +TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ" + + +class CRLError(crypto_utils.OpenSSLObjectError): + pass + + +class CRL(crypto_utils.OpenSSLObject): + + def __init__(self, module): + super(CRL, self).__init__( + module.params['path'], + module.params['state'], + module.params['force'], + module.check_mode + ) + + self.update = module.params['mode'] == 'update' + self.ignore_timestamps = module.params['ignore_timestamps'] + self.return_content = module.params['return_content'] + self.crl_content = None + + self.privatekey_path = module.params['privatekey_path'] + self.privatekey_content = module.params['privatekey_content'] + if self.privatekey_content is not None: + self.privatekey_content = self.privatekey_content.encode('utf-8') + self.privatekey_passphrase = module.params['privatekey_passphrase'] + + self.issuer = crypto_utils.parse_name_field(module.params['issuer']) + self.issuer = [(entry[0], entry[1]) for entry in self.issuer if entry[1]] + + self.last_update = crypto_utils.get_relative_time_option(module.params['last_update'], 'last_update') + self.next_update = crypto_utils.get_relative_time_option(module.params['next_update'], 'next_update') + + self.digest = crypto_utils.select_message_digest(module.params['digest']) + if self.digest is None: + raise CRLError('The digest "{0}" is not supported'.format(module.params['digest'])) + + self.revoked_certificates = [] + for i, rc in enumerate(module.params['revoked_certificates']): + result = { + 'serial_number': None, + 'revocation_date': None, + 'issuer': None, + 'issuer_critical': False, + 'reason': None, + 'reason_critical': False, + 'invalidity_date': None, + 'invalidity_date_critical': False, + } + path_prefix = 'revoked_certificates[{0}].'.format(i) + if rc['path'] is not None or rc['content'] is not None: + # Load certificate from file or content + try: + if rc['content'] is not None: + rc['content'] = rc['content'].encode('utf-8') + cert = crypto_utils.load_certificate(rc['path'], content=rc['content'], backend='cryptography') + try: + result['serial_number'] = cert.serial_number + except AttributeError: + # The property was called "serial" before cryptography 1.4 + result['serial_number'] = cert.serial + except crypto_utils.OpenSSLObjectError as e: + if rc['content'] is not None: + module.fail_json( + msg='Cannot parse certificate from {0}content: {1}'.format(path_prefix, to_native(e)) + ) + else: + module.fail_json( + msg='Cannot read certificate "{1}" from {0}path: {2}'.format(path_prefix, rc['path'], to_native(e)) + ) + else: + # Specify serial_number (and potentially issuer) directly + 
result['serial_number'] = rc['serial_number'] + # All other options + if rc['issuer']: + result['issuer'] = [crypto_utils.cryptography_get_name(issuer) for issuer in rc['issuer']] + result['issuer_critical'] = rc['issuer_critical'] + result['revocation_date'] = crypto_utils.get_relative_time_option( + rc['revocation_date'], + path_prefix + 'revocation_date' + ) + if rc['reason']: + result['reason'] = crypto_utils.REVOCATION_REASON_MAP[rc['reason']] + result['reason_critical'] = rc['reason_critical'] + if rc['invalidity_date']: + result['invalidity_date'] = crypto_utils.get_relative_time_option( + rc['invalidity_date'], + path_prefix + 'invalidity_date' + ) + result['invalidity_date_critical'] = rc['invalidity_date_critical'] + self.revoked_certificates.append(result) + + self.module = module + + self.backup = module.params['backup'] + self.backup_file = None + + try: + self.privatekey = crypto_utils.load_privatekey( + path=self.privatekey_path, + content=self.privatekey_content, + passphrase=self.privatekey_passphrase, + backend='cryptography' + ) + except crypto_utils.OpenSSLBadPassphraseError as exc: + raise CRLError(exc) + + self.crl = None + try: + with open(self.path, 'rb') as f: + data = f.read() + self.crl = x509.load_pem_x509_crl(data, default_backend()) + if self.return_content: + self.crl_content = data + except Exception as dummy: + self.crl_content = None + + def remove(self): + if self.backup: + self.backup_file = self.module.backup_local(self.path) + super(CRL, self).remove(self.module) + + def _compress_entry(self, entry): + if self.ignore_timestamps: + # Throw out revocation_date + return ( + entry['serial_number'], + tuple(entry['issuer']) if entry['issuer'] is not None else None, + entry['issuer_critical'], + entry['reason'], + entry['reason_critical'], + entry['invalidity_date'], + entry['invalidity_date_critical'], + ) + else: + return ( + entry['serial_number'], + entry['revocation_date'], + tuple(entry['issuer']) if entry['issuer'] is not None else None, + entry['issuer_critical'], + entry['reason'], + entry['reason_critical'], + entry['invalidity_date'], + entry['invalidity_date_critical'], + ) + + def check(self, perms_required=True): + """Ensure the resource is in its desired state.""" + + state_and_perms = super(CRL, self).check(self.module, perms_required) + + if not state_and_perms: + return False + + if self.crl is None: + return False + + if self.last_update != self.crl.last_update and not self.ignore_timestamps: + return False + if self.next_update != self.crl.next_update and not self.ignore_timestamps: + return False + if self.digest.name != self.crl.signature_hash_algorithm.name: + return False + + want_issuer = [(crypto_utils.cryptography_name_to_oid(entry[0]), entry[1]) for entry in self.issuer] + if want_issuer != [(sub.oid, sub.value) for sub in self.crl.issuer]: + return False + + old_entries = [self._compress_entry(crypto_utils.cryptography_decode_revoked_certificate(cert)) for cert in self.crl] + new_entries = [self._compress_entry(cert) for cert in self.revoked_certificates] + if self.update: + # We don't simply use a set so that duplicate entries are treated correctly + for entry in new_entries: + try: + old_entries.remove(entry) + except ValueError: + return False + else: + if old_entries != new_entries: + return False + + return True + + def _generate_crl(self): + backend = default_backend() + crl = CertificateRevocationListBuilder() + + try: + crl = crl.issuer_name(Name([ + NameAttribute(crypto_utils.cryptography_name_to_oid(entry[0]), 
to_text(entry[1]))
+                for entry in self.issuer
+            ]))
+        except ValueError as e:
+            raise CRLError(e)
+
+        crl = crl.last_update(self.last_update)
+        crl = crl.next_update(self.next_update)
+
+        if self.update and self.crl:
+            new_entries = set([self._compress_entry(entry) for entry in self.revoked_certificates])
+            for entry in self.crl:
+                decoded_entry = self._compress_entry(crypto_utils.cryptography_decode_revoked_certificate(entry))
+                if decoded_entry not in new_entries:
+                    crl = crl.add_revoked_certificate(entry)
+        for entry in self.revoked_certificates:
+            revoked_cert = RevokedCertificateBuilder()
+            revoked_cert = revoked_cert.serial_number(entry['serial_number'])
+            revoked_cert = revoked_cert.revocation_date(entry['revocation_date'])
+            if entry['issuer'] is not None:
+                revoked_cert = revoked_cert.add_extension(
+                    x509.CertificateIssuer([
+                        crypto_utils.cryptography_get_name(name) for name in entry['issuer']
+                    ]),
+                    entry['issuer_critical']
+                )
+            if entry['reason'] is not None:
+                revoked_cert = revoked_cert.add_extension(
+                    x509.CRLReason(entry['reason']),
+                    entry['reason_critical']
+                )
+            if entry['invalidity_date'] is not None:
+                revoked_cert = revoked_cert.add_extension(
+                    x509.InvalidityDate(entry['invalidity_date']),
+                    entry['invalidity_date_critical']
+                )
+            crl = crl.add_revoked_certificate(revoked_cert.build(backend))
+
+        self.crl = crl.sign(self.privatekey, self.digest, backend=backend)
+        return self.crl.public_bytes(Encoding.PEM)
+
+    def generate(self):
+        if not self.check(perms_required=False) or self.force:
+            result = self._generate_crl()
+            if self.return_content:
+                self.crl_content = result
+            if self.backup:
+                self.backup_file = self.module.backup_local(self.path)
+            crypto_utils.write_file(self.module, result)
+            self.changed = True
+
+        file_args = self.module.load_file_common_arguments(self.module.params)
+        if self.module.set_fs_attributes_if_different(file_args, False):
+            self.changed = True
+
+    def _dump_revoked(self, entry):
+        return {
+            'serial_number': entry['serial_number'],
+            'revocation_date': entry['revocation_date'].strftime(TIMESTAMP_FORMAT),
+            'issuer':
+                [crypto_utils.cryptography_decode_name(issuer) for issuer in entry['issuer']]
+                if entry['issuer'] is not None else None,
+            'issuer_critical': entry['issuer_critical'],
+            'reason': crypto_utils.REVOCATION_REASON_MAP_INVERSE.get(entry['reason']) if entry['reason'] is not None else None,
+            'reason_critical': entry['reason_critical'],
+            'invalidity_date':
+                entry['invalidity_date'].strftime(TIMESTAMP_FORMAT)
+                if entry['invalidity_date'] is not None else None,
+            'invalidity_date_critical': entry['invalidity_date_critical'],
+        }
+
+    def dump(self, check_mode=False):
+        result = {
+            'changed': self.changed,
+            'filename': self.path,
+            'privatekey': self.privatekey_path,
+            'last_update': None,
+            'next_update': None,
+            'digest': None,
+            'issuer_ordered': None,
+            'issuer': None,
+            'revoked_certificates': [],
+        }
+        if self.backup_file:
+            result['backup_file'] = self.backup_file
+
+        if check_mode:
+            result['last_update'] = self.last_update.strftime(TIMESTAMP_FORMAT)
+            result['next_update'] = self.next_update.strftime(TIMESTAMP_FORMAT)
+            # result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid)
+            result['digest'] = self.module.params['digest']
+            result['issuer_ordered'] = self.issuer
+            result['issuer'] = {}
+            for k, v in self.issuer:
+                result['issuer'][k] = v
+            result['revoked_certificates'] = []
+            for entry in self.revoked_certificates:
+
result['revoked_certificates'].append(self._dump_revoked(entry)) + elif self.crl: + result['last_update'] = self.crl.last_update.strftime(TIMESTAMP_FORMAT) + result['next_update'] = self.crl.next_update.strftime(TIMESTAMP_FORMAT) + try: + result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid) + except AttributeError: + # Older cryptography versions don't have signature_algorithm_oid yet + dotted = crypto_utils._obj2txt( + self.crl._backend._lib, + self.crl._backend._ffi, + self.crl._x509_crl.sig_alg.algorithm + ) + oid = x509.oid.ObjectIdentifier(dotted) + result['digest'] = crypto_utils.cryptography_oid_to_name(oid) + issuer = [] + for attribute in self.crl.issuer: + issuer.append([crypto_utils.cryptography_oid_to_name(attribute.oid), attribute.value]) + result['issuer_ordered'] = issuer + result['issuer'] = {} + for k, v in issuer: + result['issuer'][k] = v + result['revoked_certificates'] = [] + for cert in self.crl: + entry = crypto_utils.cryptography_decode_revoked_certificate(cert) + result['revoked_certificates'].append(self._dump_revoked(entry)) + + if self.return_content: + result['crl'] = self.crl_content + + return result + + +def main(): + module = AnsibleModule( + argument_spec=dict( + state=dict(type='str', default='present', choices=['present', 'absent']), + mode=dict(type='str', default='generate', choices=['generate', 'update']), + force=dict(type='bool', default=False), + backup=dict(type='bool', default=False), + path=dict(type='path', required=True), + privatekey_path=dict(type='path'), + privatekey_content=dict(type='str'), + privatekey_passphrase=dict(type='str', no_log=True), + issuer=dict(type='dict'), + last_update=dict(type='str', default='+0s'), + next_update=dict(type='str'), + digest=dict(type='str', default='sha256'), + ignore_timestamps=dict(type='bool', default=False), + return_content=dict(type='bool', default=False), + revoked_certificates=dict( + type='list', + elements='dict', + options=dict( + path=dict(type='path'), + content=dict(type='str'), + serial_number=dict(type='int'), + revocation_date=dict(type='str', default='+0s'), + issuer=dict(type='list', elements='str'), + issuer_critical=dict(type='bool', default=False), + reason=dict( + type='str', + choices=[ + 'unspecified', 'key_compromise', 'ca_compromise', 'affiliation_changed', + 'superseded', 'cessation_of_operation', 'certificate_hold', + 'privilege_withdrawn', 'aa_compromise', 'remove_from_crl' + ] + ), + reason_critical=dict(type='bool', default=False), + invalidity_date=dict(type='str'), + invalidity_date_critical=dict(type='bool', default=False), + ), + required_one_of=[['path', 'content', 'serial_number']], + mutually_exclusive=[['path', 'content', 'serial_number']], + ), + ), + required_if=[ + ('state', 'present', ['privatekey_path', 'privatekey_content'], True), + ('state', 'present', ['issuer', 'next_update', 'revoked_certificates'], False), + ], + mutually_exclusive=( + ['privatekey_path', 'privatekey_content'], + ), + supports_check_mode=True, + add_file_common_args=True, + ) + + if not CRYPTOGRAPHY_FOUND: + module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)), + exception=CRYPTOGRAPHY_IMP_ERR) + + try: + crl = CRL(module) + + if module.params['state'] == 'present': + if module.check_mode: + result = crl.dump(check_mode=True) + result['changed'] = module.params['force'] or not crl.check() + module.exit_json(**result) + + crl.generate() + else: + if module.check_mode: + result = 
crl.dump(check_mode=True) + result['changed'] = os.path.exists(module.params['path']) + module.exit_json(**result) + + crl.remove() + + result = crl.dump() + module.exit_json(**result) + except crypto_utils.OpenSSLObjectError as exc: + module.fail_json(msg=to_native(exc)) + + +if __name__ == "__main__": + main() diff --git a/test/support/integration/plugins/modules/x509_crl_info.py b/test/support/integration/plugins/modules/x509_crl_info.py new file mode 100644 index 00000000..b61db26f --- /dev/null +++ b/test/support/integration/plugins/modules/x509_crl_info.py @@ -0,0 +1,281 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2020, Felix Fontein <felix@fontein.de> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: x509_crl_info +version_added: "2.10" +short_description: Retrieve information on Certificate Revocation Lists (CRLs) +description: + - This module allows one to retrieve information on Certificate Revocation Lists (CRLs). +requirements: + - cryptography >= 1.2 +author: + - Felix Fontein (@felixfontein) +options: + path: + description: + - Remote absolute path where the generated CRL file should be created or is already located. + - Either I(path) or I(content) must be specified, but not both. + type: path + content: + description: + - Content of the X.509 certificate in PEM format. + - Either I(path) or I(content) must be specified, but not both. + type: str + +notes: + - All timestamp values are provided in ASN.1 TIME format, i.e. following the C(YYYYMMDDHHMMSSZ) pattern. + They are all in UTC. +seealso: + - module: x509_crl +''' + +EXAMPLES = r''' +- name: Get information on CRL + x509_crl_info: + path: /etc/ssl/my-ca.crl + register: result + +- debug: + msg: "{{ result }}" +''' + +RETURN = r''' +issuer: + description: + - The CRL's issuer. + - Note that for repeated values, only the last one will be returned. + returned: success + type: dict + sample: '{"organizationName": "Ansible", "commonName": "ca.example.com"}' +issuer_ordered: + description: The CRL's issuer as an ordered list of tuples. + returned: success + type: list + elements: list + sample: '[["organizationName", "Ansible"], ["commonName": "ca.example.com"]]' +last_update: + description: The point in time from which this CRL can be trusted as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +next_update: + description: The point in time from which a new CRL will be issued and the client has to check for it as ASN.1 TIME. + returned: success + type: str + sample: 20190413202428Z +digest: + description: The signature algorithm used to sign the CRL. + returned: success + type: str + sample: sha256WithRSAEncryption +revoked_certificates: + description: List of certificates to be revoked. + returned: success + type: list + elements: dict + contains: + serial_number: + description: Serial number of the certificate. + type: int + sample: 1234 + revocation_date: + description: The point in time the certificate was revoked as ASN.1 TIME. + type: str + sample: 20190413202428Z + issuer: + description: The certificate's issuer. + type: list + elements: str + sample: '["DNS:ca.example.org"]' + issuer_critical: + description: Whether the certificate issuer extension is critical. 
+ type: bool + sample: no + reason: + description: + - The value for the revocation reason extension. + - One of C(unspecified), C(key_compromise), C(ca_compromise), C(affiliation_changed), C(superseded), + C(cessation_of_operation), C(certificate_hold), C(privilege_withdrawn), C(aa_compromise), and + C(remove_from_crl). + type: str + sample: key_compromise + reason_critical: + description: Whether the revocation reason extension is critical. + type: bool + sample: no + invalidity_date: + description: | + The point in time it was known/suspected that the private key was compromised + or that the certificate otherwise became invalid as ASN.1 TIME. + type: str + sample: 20190413202428Z + invalidity_date_critical: + description: Whether the invalidity date extension is critical. + type: bool + sample: no +''' + + +import traceback +from distutils.version import LooseVersion + +from ansible.module_utils import crypto as crypto_utils +from ansible.module_utils._text import to_native +from ansible.module_utils.basic import AnsibleModule, missing_required_lib + +MINIMAL_CRYPTOGRAPHY_VERSION = '1.2' + +CRYPTOGRAPHY_IMP_ERR = None +try: + import cryptography + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + CRYPTOGRAPHY_VERSION = LooseVersion(cryptography.__version__) +except ImportError: + CRYPTOGRAPHY_IMP_ERR = traceback.format_exc() + CRYPTOGRAPHY_FOUND = False +else: + CRYPTOGRAPHY_FOUND = True + + +TIMESTAMP_FORMAT = "%Y%m%d%H%M%SZ" + + +class CRLError(crypto_utils.OpenSSLObjectError): + pass + + +class CRLInfo(crypto_utils.OpenSSLObject): + """The main module implementation.""" + + def __init__(self, module): + super(CRLInfo, self).__init__( + module.params['path'] or '', + 'present', + False, + module.check_mode + ) + + self.content = module.params['content'] + + self.module = module + + self.crl = None + if self.content is None: + try: + with open(self.path, 'rb') as f: + data = f.read() + except Exception as e: + self.module.fail_json(msg='Error while reading CRL file from disk: {0}'.format(e)) + else: + data = self.content.encode('utf-8') + + try: + self.crl = x509.load_pem_x509_crl(data, default_backend()) + except Exception as e: + self.module.fail_json(msg='Error while decoding CRL: {0}'.format(e)) + + def _dump_revoked(self, entry): + return { + 'serial_number': entry['serial_number'], + 'revocation_date': entry['revocation_date'].strftime(TIMESTAMP_FORMAT), + 'issuer': + [crypto_utils.cryptography_decode_name(issuer) for issuer in entry['issuer']] + if entry['issuer'] is not None else None, + 'issuer_critical': entry['issuer_critical'], + 'reason': crypto_utils.REVOCATION_REASON_MAP_INVERSE.get(entry['reason']) if entry['reason'] is not None else None, + 'reason_critical': entry['reason_critical'], + 'invalidity_date': + entry['invalidity_date'].strftime(TIMESTAMP_FORMAT) + if entry['invalidity_date'] is not None else None, + 'invalidity_date_critical': entry['invalidity_date_critical'], + } + + def get_info(self): + result = { + 'changed': False, + 'last_update': None, + 'next_update': None, + 'digest': None, + 'issuer_ordered': None, + 'issuer': None, + 'revoked_certificates': [], + } + + result['last_update'] = self.crl.last_update.strftime(TIMESTAMP_FORMAT) + result['next_update'] = self.crl.next_update.strftime(TIMESTAMP_FORMAT) + try: + result['digest'] = crypto_utils.cryptography_oid_to_name(self.crl.signature_algorithm_oid) + except AttributeError: + # Older cryptography versions don't have signature_algorithm_oid yet + dotted = 
crypto_utils._obj2txt( + self.crl._backend._lib, + self.crl._backend._ffi, + self.crl._x509_crl.sig_alg.algorithm + ) + oid = x509.oid.ObjectIdentifier(dotted) + result['digest'] = crypto_utils.cryptography_oid_to_name(oid) + issuer = [] + for attribute in self.crl.issuer: + issuer.append([crypto_utils.cryptography_oid_to_name(attribute.oid), attribute.value]) + result['issuer_ordered'] = issuer + result['issuer'] = {} + for k, v in issuer: + result['issuer'][k] = v + result['revoked_certificates'] = [] + for cert in self.crl: + entry = crypto_utils.cryptography_decode_revoked_certificate(cert) + result['revoked_certificates'].append(self._dump_revoked(entry)) + + return result + + def generate(self): + # Empty method because crypto_utils.OpenSSLObject wants this + pass + + def dump(self): + # Empty method because crypto_utils.OpenSSLObject wants this + pass + + +def main(): + module = AnsibleModule( + argument_spec=dict( + path=dict(type='path'), + content=dict(type='str'), + ), + required_one_of=( + ['path', 'content'], + ), + mutually_exclusive=( + ['path', 'content'], + ), + supports_check_mode=True, + ) + + if not CRYPTOGRAPHY_FOUND: + module.fail_json(msg=missing_required_lib('cryptography >= {0}'.format(MINIMAL_CRYPTOGRAPHY_VERSION)), + exception=CRYPTOGRAPHY_IMP_ERR) + + try: + crl = CRLInfo(module) + result = crl.get_info() + module.exit_json(**result) + except crypto_utils.OpenSSLObjectError as e: + module.fail_json(msg=to_native(e)) + + +if __name__ == "__main__": + main() diff --git a/test/support/integration/plugins/modules/xml.py b/test/support/integration/plugins/modules/xml.py new file mode 100644 index 00000000..b5b35a38 --- /dev/null +++ b/test/support/integration/plugins/modules/xml.py @@ -0,0 +1,966 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright: (c) 2014, Red Hat, Inc. +# Copyright: (c) 2014, Tim Bielawa <tbielawa@redhat.com> +# Copyright: (c) 2014, Magnus Hedemark <mhedemar@redhat.com> +# Copyright: (c) 2017, Dag Wieers <dag@wieers.com> +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import (absolute_import, division, print_function) +__metaclass__ = type + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + +DOCUMENTATION = r''' +--- +module: xml +short_description: Manage bits and pieces of XML files or strings +description: +- A CRUD-like interface to managing bits of XML files. +version_added: '2.4' +options: + path: + description: + - Path to the file to operate on. + - This file must exist ahead of time. + - This parameter is required, unless C(xmlstring) is given. + type: path + required: yes + aliases: [ dest, file ] + xmlstring: + description: + - A string containing XML on which to operate. + - This parameter is required, unless C(path) is given. + type: str + required: yes + xpath: + description: + - A valid XPath expression describing the item(s) you want to manipulate. + - Operates on the document root, C(/), by default. + type: str + namespaces: + description: + - The namespace C(prefix:uri) mapping for the XPath expression. + - Needs to be a C(dict), not a C(list) of items. + type: dict + state: + description: + - Set or remove an xpath selection (node(s), attribute(s)). + type: str + choices: [ absent, present ] + default: present + aliases: [ ensure ] + attribute: + description: + - The attribute to select when using parameter C(value). + - This is a string, not prepended with C(@). 
+ type: raw + value: + description: + - Desired state of the selected attribute. + - Either a string, or to unset a value, the Python C(None) keyword (YAML Equivalent, C(null)). + - Elements default to no value (but present). + - Attributes default to an empty string. + type: raw + add_children: + description: + - Add additional child-element(s) to a selected element for a given C(xpath). + - Child elements must be given in a list and each item may be either a string + (eg. C(children=ansible) to add an empty C(<ansible/>) child element), + or a hash where the key is an element name and the value is the element value. + - This parameter requires C(xpath) to be set. + type: list + set_children: + description: + - Set the child-element(s) of a selected element for a given C(xpath). + - Removes any existing children. + - Child elements must be specified as in C(add_children). + - This parameter requires C(xpath) to be set. + type: list + count: + description: + - Search for a given C(xpath) and provide the count of any matches. + - This parameter requires C(xpath) to be set. + type: bool + default: no + print_match: + description: + - Search for a given C(xpath) and print out any matches. + - This parameter requires C(xpath) to be set. + type: bool + default: no + pretty_print: + description: + - Pretty print XML output. + type: bool + default: no + content: + description: + - Search for a given C(xpath) and get content. + - This parameter requires C(xpath) to be set. + type: str + choices: [ attribute, text ] + input_type: + description: + - Type of input for C(add_children) and C(set_children). + type: str + choices: [ xml, yaml ] + default: yaml + backup: + description: + - Create a backup file including the timestamp information so you can get + the original file back if you somehow clobbered it incorrectly. + type: bool + default: no + strip_cdata_tags: + description: + - Remove CDATA tags surrounding text values. + - Note that this might break your XML file if text values contain characters that could be interpreted as XML. + type: bool + default: no + version_added: '2.7' + insertbefore: + description: + - Add additional child-element(s) before the first selected element for a given C(xpath). + - Child elements must be given in a list and each item may be either a string + (eg. C(children=ansible) to add an empty C(<ansible/>) child element), + or a hash where the key is an element name and the value is the element value. + - This parameter requires C(xpath) to be set. + type: bool + default: no + version_added: '2.8' + insertafter: + description: + - Add additional child-element(s) after the last selected element for a given C(xpath). + - Child elements must be given in a list and each item may be either a string + (eg. C(children=ansible) to add an empty C(<ansible/>) child element), + or a hash where the key is an element name and the value is the element value. + - This parameter requires C(xpath) to be set. + type: bool + default: no + version_added: '2.8' +requirements: +- lxml >= 2.3.0 +notes: +- Use the C(--check) and C(--diff) options when testing your expressions. +- The diff output is automatically pretty-printed, so may not reflect the actual file content, only the file structure. +- This module does not handle complicated xpath expressions, so limit xpath selectors to simple expressions. +- Beware that in case your XML elements are namespaced, you need to use the C(namespaces) parameter, see the examples. 
+- Namespaces prefix should be used for all children of an element where namespace is defined, unless another namespace is defined for them. +seealso: +- name: Xml module development community wiki + description: More information related to the development of this xml module. + link: https://github.com/ansible/community/wiki/Module:-xml +- name: Introduction to XPath + description: A brief tutorial on XPath (w3schools.com). + link: https://www.w3schools.com/xml/xpath_intro.asp +- name: XPath Reference document + description: The reference documentation on XSLT/XPath (developer.mozilla.org). + link: https://developer.mozilla.org/en-US/docs/Web/XPath +author: +- Tim Bielawa (@tbielawa) +- Magnus Hedemark (@magnus919) +- Dag Wieers (@dagwieers) +''' + +EXAMPLES = r''' +# Consider the following XML file: +# +# <business type="bar"> +# <name>Tasty Beverage Co.</name> +# <beers> +# <beer>Rochefort 10</beer> +# <beer>St. Bernardus Abbot 12</beer> +# <beer>Schlitz</beer> +# </beers> +# <rating subjective="true">10</rating> +# <website> +# <mobilefriendly/> +# <address>http://tastybeverageco.com</address> +# </website> +# </business> + +- name: Remove the 'subjective' attribute of the 'rating' element + xml: + path: /foo/bar.xml + xpath: /business/rating/@subjective + state: absent + +- name: Set the rating to '11' + xml: + path: /foo/bar.xml + xpath: /business/rating + value: 11 + +# Retrieve and display the number of nodes +- name: Get count of 'beers' nodes + xml: + path: /foo/bar.xml + xpath: /business/beers/beer + count: yes + register: hits + +- debug: + var: hits.count + +# Example where parent XML nodes are created automatically +- name: Add a 'phonenumber' element to the 'business' element + xml: + path: /foo/bar.xml + xpath: /business/phonenumber + value: 555-555-1234 + +- name: Add several more beers to the 'beers' element + xml: + path: /foo/bar.xml + xpath: /business/beers + add_children: + - beer: Old Rasputin + - beer: Old Motor Oil + - beer: Old Curmudgeon + +- name: Add several more beers to the 'beers' element and add them before the 'Rochefort 10' element + xml: + path: /foo/bar.xml + xpath: '/business/beers/beer[text()="Rochefort 10"]' + insertbefore: yes + add_children: + - beer: Old Rasputin + - beer: Old Motor Oil + - beer: Old Curmudgeon + +# NOTE: The 'state' defaults to 'present' and 'value' defaults to 'null' for elements +- name: Add a 'validxhtml' element to the 'website' element + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml + +- name: Add an empty 'validatedon' attribute to the 'validxhtml' element + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml/@validatedon + +- name: Add or modify an attribute, add element if needed + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml + attribute: validatedon + value: 1976-08-05 + +# How to read an attribute value and access it in Ansible +- name: Read an element's attribute values + xml: + path: /foo/bar.xml + xpath: /business/website/validxhtml + content: attribute + register: xmlresp + +- name: Show an attribute value + debug: + var: xmlresp.matches[0].validxhtml.validatedon + +- name: Remove all children from the 'website' element (option 1) + xml: + path: /foo/bar.xml + xpath: /business/website/* + state: absent + +- name: Remove all children from the 'website' element (option 2) + xml: + path: /foo/bar.xml + xpath: /business/website + children: [] + +# In case of namespaces, like in below XML, they have to be explicitly stated. 
+# +# <foo xmlns="http://x.test" xmlns:attr="http://z.test"> +# <bar> +# <baz xmlns="http://y.test" attr:my_namespaced_attribute="true" /> +# </bar> +# </foo> + +# NOTE: There is the prefix 'x' in front of the 'bar' element, too. +- name: Set namespaced '/x:foo/x:bar/y:baz/@z:my_namespaced_attribute' to 'false' + xml: + path: foo.xml + xpath: /x:foo/x:bar/y:baz + namespaces: + x: http://x.test + y: http://y.test + z: http://z.test + attribute: z:my_namespaced_attribute + value: 'false' +''' + +RETURN = r''' +actions: + description: A dictionary with the original xpath, namespaces and state. + type: dict + returned: success + sample: {xpath: xpath, namespaces: [namespace1, namespace2], state=present} +backup_file: + description: The name of the backup file that was created + type: str + returned: when backup=yes + sample: /path/to/file.xml.1942.2017-08-24@14:16:01~ +count: + description: The count of xpath matches. + type: int + returned: when parameter 'count' is set + sample: 2 +matches: + description: The xpath matches found. + type: list + returned: when parameter 'print_match' is set +msg: + description: A message related to the performed action(s). + type: str + returned: always +xmlstring: + description: An XML string of the resulting output. + type: str + returned: when parameter 'xmlstring' is set +''' + +import copy +import json +import os +import re +import traceback + +from distutils.version import LooseVersion +from io import BytesIO + +LXML_IMP_ERR = None +try: + from lxml import etree, objectify + HAS_LXML = True +except ImportError: + LXML_IMP_ERR = traceback.format_exc() + HAS_LXML = False + +from ansible.module_utils.basic import AnsibleModule, json_dict_bytes_to_unicode, missing_required_lib +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils._text import to_bytes, to_native +from ansible.module_utils.common._collections_compat import MutableMapping + +_IDENT = r"[a-zA-Z-][a-zA-Z0-9_\-\.]*" +_NSIDENT = _IDENT + "|" + _IDENT + ":" + _IDENT +# Note: we can't reasonably support the 'if you need to put both ' and " in a string, concatenate +# strings wrapped by the other delimiter' XPath trick, especially as simple XPath. 
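+#
+# For example (illustrative, reusing the /business/rating/@subjective selector from
+# the EXAMPLES above), split_xpath_last() below turns
+#     /business/rating/@subjective
+# into ('/business/rating', [('@subjective', None)]) via _RE_SPLITSIMPLEATTRLAST,
+# which is the form consumed by check_or_make_target() when a missing node has to be created.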
+_XPSTR = "('(?:.*)'|\"(?:.*)\")" + +_RE_SPLITSIMPLELAST = re.compile("^(.*)/(" + _NSIDENT + ")$") +_RE_SPLITSIMPLELASTEQVALUE = re.compile("^(.*)/(" + _NSIDENT + ")/text\\(\\)=" + _XPSTR + "$") +_RE_SPLITSIMPLEATTRLAST = re.compile("^(.*)/(@(?:" + _NSIDENT + "))$") +_RE_SPLITSIMPLEATTRLASTEQVALUE = re.compile("^(.*)/(@(?:" + _NSIDENT + "))=" + _XPSTR + "$") +_RE_SPLITSUBLAST = re.compile("^(.*)/(" + _NSIDENT + ")\\[(.*)\\]$") +_RE_SPLITONLYEQVALUE = re.compile("^(.*)/text\\(\\)=" + _XPSTR + "$") + + +def has_changed(doc): + orig_obj = etree.tostring(objectify.fromstring(etree.tostring(orig_doc))) + obj = etree.tostring(objectify.fromstring(etree.tostring(doc))) + return (orig_obj != obj) + + +def do_print_match(module, tree, xpath, namespaces): + match = tree.xpath(xpath, namespaces=namespaces) + match_xpaths = [] + for m in match: + match_xpaths.append(tree.getpath(m)) + match_str = json.dumps(match_xpaths) + msg = "selector '%s' match: %s" % (xpath, match_str) + finish(module, tree, xpath, namespaces, changed=False, msg=msg) + + +def count_nodes(module, tree, xpath, namespaces): + """ Return the count of nodes matching the xpath """ + hits = tree.xpath("count(/%s)" % xpath, namespaces=namespaces) + msg = "found %d nodes" % hits + finish(module, tree, xpath, namespaces, changed=False, msg=msg, hitcount=int(hits)) + + +def is_node(tree, xpath, namespaces): + """ Test if a given xpath matches anything and if that match is a node. + + For now we just assume you're only searching for one specific thing.""" + if xpath_matches(tree, xpath, namespaces): + # OK, it found something + match = tree.xpath(xpath, namespaces=namespaces) + if isinstance(match[0], etree._Element): + return True + + return False + + +def is_attribute(tree, xpath, namespaces): + """ Test if a given xpath matches and that match is an attribute + + An xpath attribute search will only match one item""" + if xpath_matches(tree, xpath, namespaces): + match = tree.xpath(xpath, namespaces=namespaces) + if isinstance(match[0], etree._ElementStringResult): + return True + elif isinstance(match[0], etree._ElementUnicodeResult): + return True + return False + + +def xpath_matches(tree, xpath, namespaces): + """ Test if a node exists """ + if tree.xpath(xpath, namespaces=namespaces): + return True + return False + + +def delete_xpath_target(module, tree, xpath, namespaces): + """ Delete an attribute or element from a tree """ + try: + for result in tree.xpath(xpath, namespaces=namespaces): + # Get the xpath for this result + if is_attribute(tree, xpath, namespaces): + # Delete an attribute + parent = result.getparent() + # Pop this attribute match out of the parent + # node's 'attrib' dict by using this match's + # 'attrname' attribute for the key + parent.attrib.pop(result.attrname) + elif is_node(tree, xpath, namespaces): + # Delete an element + result.getparent().remove(result) + else: + raise Exception("Impossible error") + except Exception as e: + module.fail_json(msg="Couldn't delete xpath target: %s (%s)" % (xpath, e)) + else: + finish(module, tree, xpath, namespaces, changed=True) + + +def replace_children_of(children, match): + for element in list(match): + match.remove(element) + match.extend(children) + + +def set_target_children_inner(module, tree, xpath, namespaces, children, in_type): + matches = tree.xpath(xpath, namespaces=namespaces) + + # Create a list of our new children + children = children_to_nodes(module, children, in_type) + children_as_string = [etree.tostring(c) for c in children] + + changed = False + + # 
xpaths always return matches as a list, so.... + for match in matches: + # Check if elements differ + if len(list(match)) == len(children): + for idx, element in enumerate(list(match)): + if etree.tostring(element) != children_as_string[idx]: + replace_children_of(children, match) + changed = True + break + else: + replace_children_of(children, match) + changed = True + + return changed + + +def set_target_children(module, tree, xpath, namespaces, children, in_type): + changed = set_target_children_inner(module, tree, xpath, namespaces, children, in_type) + # Write it out + finish(module, tree, xpath, namespaces, changed=changed) + + +def add_target_children(module, tree, xpath, namespaces, children, in_type, insertbefore, insertafter): + if is_node(tree, xpath, namespaces): + new_kids = children_to_nodes(module, children, in_type) + if insertbefore or insertafter: + insert_target_children(tree, xpath, namespaces, new_kids, insertbefore, insertafter) + else: + for node in tree.xpath(xpath, namespaces=namespaces): + node.extend(new_kids) + finish(module, tree, xpath, namespaces, changed=True) + else: + finish(module, tree, xpath, namespaces) + + +def insert_target_children(tree, xpath, namespaces, children, insertbefore, insertafter): + """ + Insert the given children before or after the given xpath. If insertbefore is True, it is inserted before the + first xpath hit, with insertafter, it is inserted after the last xpath hit. + """ + insert_target = tree.xpath(xpath, namespaces=namespaces) + loc_index = 0 if insertbefore else -1 + index_in_parent = insert_target[loc_index].getparent().index(insert_target[loc_index]) + parent = insert_target[0].getparent() + if insertafter: + index_in_parent += 1 + for child in children: + parent.insert(index_in_parent, child) + index_in_parent += 1 + + +def _extract_xpstr(g): + return g[1:-1] + + +def split_xpath_last(xpath): + """split an XPath of the form /foo/bar/baz into /foo/bar and baz""" + xpath = xpath.strip() + m = _RE_SPLITSIMPLELAST.match(xpath) + if m: + # requesting an element to exist + return (m.group(1), [(m.group(2), None)]) + m = _RE_SPLITSIMPLELASTEQVALUE.match(xpath) + if m: + # requesting an element to exist with an inner text + return (m.group(1), [(m.group(2), _extract_xpstr(m.group(3)))]) + + m = _RE_SPLITSIMPLEATTRLAST.match(xpath) + if m: + # requesting an attribute to exist + return (m.group(1), [(m.group(2), None)]) + m = _RE_SPLITSIMPLEATTRLASTEQVALUE.match(xpath) + if m: + # requesting an attribute to exist with a value + return (m.group(1), [(m.group(2), _extract_xpstr(m.group(3)))]) + + m = _RE_SPLITSUBLAST.match(xpath) + if m: + content = [x.strip() for x in m.group(3).split(" and ")] + return (m.group(1), [('/' + m.group(2), content)]) + + m = _RE_SPLITONLYEQVALUE.match(xpath) + if m: + # requesting a change of inner text + return (m.group(1), [("", _extract_xpstr(m.group(2)))]) + return (xpath, []) + + +def nsnameToClark(name, namespaces): + if ":" in name: + (nsname, rawname) = name.split(":") + # return "{{%s}}%s" % (namespaces[nsname], rawname) + return "{{{0}}}{1}".format(namespaces[nsname], rawname) + + # no namespace name here + return name + + +def check_or_make_target(module, tree, xpath, namespaces): + (inner_xpath, changes) = split_xpath_last(xpath) + if (inner_xpath == xpath) or (changes is None): + module.fail_json(msg="Can't process Xpath %s in order to spawn nodes! 
tree is %s" % + (xpath, etree.tostring(tree, pretty_print=True))) + return False + + changed = False + + if not is_node(tree, inner_xpath, namespaces): + changed = check_or_make_target(module, tree, inner_xpath, namespaces) + + # we test again after calling check_or_make_target + if is_node(tree, inner_xpath, namespaces) and changes: + for (eoa, eoa_value) in changes: + if eoa and eoa[0] != '@' and eoa[0] != '/': + # implicitly creating an element + new_kids = children_to_nodes(module, [nsnameToClark(eoa, namespaces)], "yaml") + if eoa_value: + for nk in new_kids: + nk.text = eoa_value + + for node in tree.xpath(inner_xpath, namespaces=namespaces): + node.extend(new_kids) + changed = True + # module.fail_json(msg="now tree=%s" % etree.tostring(tree, pretty_print=True)) + elif eoa and eoa[0] == '/': + element = eoa[1:] + new_kids = children_to_nodes(module, [nsnameToClark(element, namespaces)], "yaml") + for node in tree.xpath(inner_xpath, namespaces=namespaces): + node.extend(new_kids) + for nk in new_kids: + for subexpr in eoa_value: + # module.fail_json(msg="element=%s subexpr=%s node=%s now tree=%s" % + # (element, subexpr, etree.tostring(node, pretty_print=True), etree.tostring(tree, pretty_print=True)) + check_or_make_target(module, nk, "./" + subexpr, namespaces) + changed = True + + # module.fail_json(msg="now tree=%s" % etree.tostring(tree, pretty_print=True)) + elif eoa == "": + for node in tree.xpath(inner_xpath, namespaces=namespaces): + if (node.text != eoa_value): + node.text = eoa_value + changed = True + + elif eoa and eoa[0] == '@': + attribute = nsnameToClark(eoa[1:], namespaces) + + for element in tree.xpath(inner_xpath, namespaces=namespaces): + changing = (attribute not in element.attrib or element.attrib[attribute] != eoa_value) + + if changing: + changed = changed or changing + if eoa_value is None: + value = "" + else: + value = eoa_value + element.attrib[attribute] = value + + # module.fail_json(msg="arf %s changing=%s as curval=%s changed tree=%s" % + # (xpath, changing, etree.tostring(tree, changing, element[attribute], pretty_print=True))) + + else: + module.fail_json(msg="unknown tree transformation=%s" % etree.tostring(tree, pretty_print=True)) + + return changed + + +def ensure_xpath_exists(module, tree, xpath, namespaces): + changed = False + + if not is_node(tree, xpath, namespaces): + changed = check_or_make_target(module, tree, xpath, namespaces) + + finish(module, tree, xpath, namespaces, changed) + + +def set_target_inner(module, tree, xpath, namespaces, attribute, value): + changed = False + + try: + if not is_node(tree, xpath, namespaces): + changed = check_or_make_target(module, tree, xpath, namespaces) + except Exception as e: + missing_namespace = "" + # NOTE: This checks only the namespaces defined in root element! + # TODO: Implement a more robust check to check for child namespaces' existence + if tree.getroot().nsmap and ":" not in xpath: + missing_namespace = "XML document has namespace(s) defined, but no namespace prefix(es) used in xpath!\n" + module.fail_json(msg="%sXpath %s causes a failure: %s\n -- tree is %s" % + (missing_namespace, xpath, e, etree.tostring(tree, pretty_print=True)), exception=traceback.format_exc()) + + if not is_node(tree, xpath, namespaces): + module.fail_json(msg="Xpath %s does not reference a node! 
tree is %s" % + (xpath, etree.tostring(tree, pretty_print=True))) + + for element in tree.xpath(xpath, namespaces=namespaces): + if not attribute: + changed = changed or (element.text != value) + if element.text != value: + element.text = value + else: + changed = changed or (element.get(attribute) != value) + if ":" in attribute: + attr_ns, attr_name = attribute.split(":") + # attribute = "{{%s}}%s" % (namespaces[attr_ns], attr_name) + attribute = "{{{0}}}{1}".format(namespaces[attr_ns], attr_name) + if element.get(attribute) != value: + element.set(attribute, value) + + return changed + + +def set_target(module, tree, xpath, namespaces, attribute, value): + changed = set_target_inner(module, tree, xpath, namespaces, attribute, value) + finish(module, tree, xpath, namespaces, changed) + + +def get_element_text(module, tree, xpath, namespaces): + if not is_node(tree, xpath, namespaces): + module.fail_json(msg="Xpath %s does not reference a node!" % xpath) + + elements = [] + for element in tree.xpath(xpath, namespaces=namespaces): + elements.append({element.tag: element.text}) + + finish(module, tree, xpath, namespaces, changed=False, msg=len(elements), hitcount=len(elements), matches=elements) + + +def get_element_attr(module, tree, xpath, namespaces): + if not is_node(tree, xpath, namespaces): + module.fail_json(msg="Xpath %s does not reference a node!" % xpath) + + elements = [] + for element in tree.xpath(xpath, namespaces=namespaces): + child = {} + for key in element.keys(): + value = element.get(key) + child.update({key: value}) + elements.append({element.tag: child}) + + finish(module, tree, xpath, namespaces, changed=False, msg=len(elements), hitcount=len(elements), matches=elements) + + +def child_to_element(module, child, in_type): + if in_type == 'xml': + infile = BytesIO(to_bytes(child, errors='surrogate_or_strict')) + + try: + parser = etree.XMLParser() + node = etree.parse(infile, parser) + return node.getroot() + except etree.XMLSyntaxError as e: + module.fail_json(msg="Error while parsing child element: %s" % e) + elif in_type == 'yaml': + if isinstance(child, string_types): + return etree.Element(child) + elif isinstance(child, MutableMapping): + if len(child) > 1: + module.fail_json(msg="Can only create children from hashes with one key") + + (key, value) = next(iteritems(child)) + if isinstance(value, MutableMapping): + children = value.pop('_', None) + + node = etree.Element(key, value) + + if children is not None: + if not isinstance(children, list): + module.fail_json(msg="Invalid children type: %s, must be list." % type(children)) + + subnodes = children_to_nodes(module, children) + node.extend(subnodes) + else: + node = etree.Element(key) + node.text = value + return node + else: + module.fail_json(msg="Invalid child type: %s. Children must be either strings or hashes." % type(child)) + else: + module.fail_json(msg="Invalid child input type: %s. Type must be either xml or yaml." 
% in_type) + + +def children_to_nodes(module=None, children=None, type='yaml'): + """turn a str/hash/list of str&hash into a list of elements""" + children = [] if children is None else children + + return [child_to_element(module, child, type) for child in children] + + +def make_pretty(module, tree): + xml_string = etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + result = dict( + changed=False, + ) + + if module.params['path']: + xml_file = module.params['path'] + with open(xml_file, 'rb') as xml_content: + if xml_string != xml_content.read(): + result['changed'] = True + if not module.check_mode: + if module.params['backup']: + result['backup_file'] = module.backup_local(module.params['path']) + tree.write(xml_file, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + elif module.params['xmlstring']: + result['xmlstring'] = xml_string + # NOTE: Modifying a string is not considered a change ! + if xml_string != module.params['xmlstring']: + result['changed'] = True + + module.exit_json(**result) + + +def finish(module, tree, xpath, namespaces, changed=False, msg='', hitcount=0, matches=tuple()): + + result = dict( + actions=dict( + xpath=xpath, + namespaces=namespaces, + state=module.params['state'] + ), + changed=has_changed(tree), + ) + + if module.params['count'] or hitcount: + result['count'] = hitcount + + if module.params['print_match'] or matches: + result['matches'] = matches + + if msg: + result['msg'] = msg + + if result['changed']: + if module._diff: + result['diff'] = dict( + before=etree.tostring(orig_doc, xml_declaration=True, encoding='UTF-8', pretty_print=True), + after=etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=True), + ) + + if module.params['path'] and not module.check_mode: + if module.params['backup']: + result['backup_file'] = module.backup_local(module.params['path']) + tree.write(module.params['path'], xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + if module.params['xmlstring']: + result['xmlstring'] = etree.tostring(tree, xml_declaration=True, encoding='UTF-8', pretty_print=module.params['pretty_print']) + + module.exit_json(**result) + + +def main(): + module = AnsibleModule( + argument_spec=dict( + path=dict(type='path', aliases=['dest', 'file']), + xmlstring=dict(type='str'), + xpath=dict(type='str'), + namespaces=dict(type='dict', default={}), + state=dict(type='str', default='present', choices=['absent', 'present'], aliases=['ensure']), + value=dict(type='raw'), + attribute=dict(type='raw'), + add_children=dict(type='list'), + set_children=dict(type='list'), + count=dict(type='bool', default=False), + print_match=dict(type='bool', default=False), + pretty_print=dict(type='bool', default=False), + content=dict(type='str', choices=['attribute', 'text']), + input_type=dict(type='str', default='yaml', choices=['xml', 'yaml']), + backup=dict(type='bool', default=False), + strip_cdata_tags=dict(type='bool', default=False), + insertbefore=dict(type='bool', default=False), + insertafter=dict(type='bool', default=False), + ), + supports_check_mode=True, + required_by=dict( + add_children=['xpath'], + # TODO: Reinstate this in Ansible v2.12 when we have deprecated the incorrect use below + # attribute=['value'], + content=['xpath'], + set_children=['xpath'], + value=['xpath'], + ), + required_if=[ + ['count', True, ['xpath']], + ['print_match', True, ['xpath']], + ['insertbefore', True, ['xpath']], + 
['insertafter', True, ['xpath']], + ], + required_one_of=[ + ['path', 'xmlstring'], + ['add_children', 'content', 'count', 'pretty_print', 'print_match', 'set_children', 'value'], + ], + mutually_exclusive=[ + ['add_children', 'content', 'count', 'print_match', 'set_children', 'value'], + ['path', 'xmlstring'], + ['insertbefore', 'insertafter'], + ], + ) + + xml_file = module.params['path'] + xml_string = module.params['xmlstring'] + xpath = module.params['xpath'] + namespaces = module.params['namespaces'] + state = module.params['state'] + value = json_dict_bytes_to_unicode(module.params['value']) + attribute = module.params['attribute'] + set_children = json_dict_bytes_to_unicode(module.params['set_children']) + add_children = json_dict_bytes_to_unicode(module.params['add_children']) + pretty_print = module.params['pretty_print'] + content = module.params['content'] + input_type = module.params['input_type'] + print_match = module.params['print_match'] + count = module.params['count'] + backup = module.params['backup'] + strip_cdata_tags = module.params['strip_cdata_tags'] + insertbefore = module.params['insertbefore'] + insertafter = module.params['insertafter'] + + # Check if we have lxml 2.3.0 or newer installed + if not HAS_LXML: + module.fail_json(msg=missing_required_lib("lxml"), exception=LXML_IMP_ERR) + elif LooseVersion('.'.join(to_native(f) for f in etree.LXML_VERSION)) < LooseVersion('2.3.0'): + module.fail_json(msg='The xml ansible module requires lxml 2.3.0 or newer installed on the managed machine') + elif LooseVersion('.'.join(to_native(f) for f in etree.LXML_VERSION)) < LooseVersion('3.0.0'): + module.warn('Using lxml version lower than 3.0.0 does not guarantee predictable element attribute order.') + + # Report wrongly used attribute parameter when using content=attribute + # TODO: Remove this in Ansible v2.12 (and reinstate strict parameter test above) and remove the integration test example + if content == 'attribute' and attribute is not None: + module.deprecate("Parameter 'attribute=%s' is ignored when using 'content=attribute' only 'xpath' is used. Please remove entry." % attribute, + '2.12', collection_name='ansible.builtin') + + # Check if the file exists + if xml_string: + infile = BytesIO(to_bytes(xml_string, errors='surrogate_or_strict')) + elif os.path.isfile(xml_file): + infile = open(xml_file, 'rb') + else: + module.fail_json(msg="The target XML source '%s' does not exist." 
% xml_file) + + # Parse and evaluate xpath expression + if xpath is not None: + try: + etree.XPath(xpath) + except etree.XPathSyntaxError as e: + module.fail_json(msg="Syntax error in xpath expression: %s (%s)" % (xpath, e)) + except etree.XPathEvalError as e: + module.fail_json(msg="Evaluation error in xpath expression: %s (%s)" % (xpath, e)) + + # Try to parse in the target XML file + try: + parser = etree.XMLParser(remove_blank_text=pretty_print, strip_cdata=strip_cdata_tags) + doc = etree.parse(infile, parser) + except etree.XMLSyntaxError as e: + module.fail_json(msg="Error while parsing document: %s (%s)" % (xml_file or 'xml_string', e)) + + # Ensure we have the original copy to compare + global orig_doc + orig_doc = copy.deepcopy(doc) + + if print_match: + do_print_match(module, doc, xpath, namespaces) + + if count: + count_nodes(module, doc, xpath, namespaces) + + if content == 'attribute': + get_element_attr(module, doc, xpath, namespaces) + elif content == 'text': + get_element_text(module, doc, xpath, namespaces) + + # File exists: + if state == 'absent': + # - absent: delete xpath target + delete_xpath_target(module, doc, xpath, namespaces) + + # - present: carry on + + # children && value both set?: should have already aborted by now + # add_children && set_children both set?: should have already aborted by now + + # set_children set? + if set_children: + set_target_children(module, doc, xpath, namespaces, set_children, input_type) + + # add_children set? + if add_children: + add_target_children(module, doc, xpath, namespaces, add_children, input_type, insertbefore, insertafter) + + # No?: Carry on + + # Is the xpath target an attribute selector? + if value is not None: + set_target(module, doc, xpath, namespaces, attribute, value) + + # If an xpath was provided, we need to do something with the data + if xpath is not None: + ensure_xpath_exists(module, doc, xpath, namespaces) + + # Otherwise only reformat the xml data? + if pretty_print: + make_pretty(module, doc) + + module.fail_json(msg="Don't know what to do") + + +if __name__ == '__main__': + main() diff --git a/test/support/integration/plugins/modules/zypper.py b/test/support/integration/plugins/modules/zypper.py new file mode 100644 index 00000000..bfb31819 --- /dev/null +++ b/test/support/integration/plugins/modules/zypper.py @@ -0,0 +1,540 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +# (c) 2013, Patrick Callahan <pmc@patrickcallahan.com> +# based on +# openbsd_pkg +# (c) 2013 +# Patrik Lundin <patrik.lundin.swe@gmail.com> +# +# yum +# (c) 2012, Red Hat, Inc +# Written by Seth Vidal <skvidal at fedoraproject.org> +# +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +ANSIBLE_METADATA = {'metadata_version': '1.1', + 'status': ['preview'], + 'supported_by': 'community'} + + +DOCUMENTATION = ''' +--- +module: zypper +author: + - "Patrick Callahan (@dirtyharrycallahan)" + - "Alexander Gubin (@alxgu)" + - "Thomas O'Donnell (@andytom)" + - "Robin Roth (@robinro)" + - "Andrii Radyk (@AnderEnder)" +version_added: "1.2" +short_description: Manage packages on SUSE and openSUSE +description: + - Manage packages on SUSE and openSUSE using the zypper and rpm tools. +options: + name: + description: + - Package name C(name) or package specifier or a list of either. + - Can include a version like C(name=1.0), C(name>3.4) or C(name<=2.7). 
If a version is given, C(oldpackage) is implied and zypper is allowed to
+        update the package within the version range given.
+      - You can also pass a URL or a local path to an RPM file.
+      - When using C(state=latest), this can be '*', which updates all installed packages.
+    required: true
+    aliases: [ 'pkg' ]
+  state:
+    description:
+      - C(present) will make sure the package is installed.
+        C(latest)  will make sure the latest version of the package is installed.
+        C(absent)  will make sure the specified package is not installed.
+        C(dist-upgrade) will make sure the latest version of all installed packages from all enabled repositories is installed.
+      - When using C(dist-upgrade), I(name) should be C('*').
+    required: false
+    choices: [ present, latest, absent, dist-upgrade ]
+    default: "present"
+  type:
+    description:
+      - The type of package to be operated on.
+    required: false
+    choices: [ package, patch, pattern, product, srcpackage, application ]
+    default: "package"
+    version_added: "2.0"
+  extra_args_precommand:
+    version_added: "2.6"
+    required: false
+    description:
+      - Add additional global target options to C(zypper).
+      - Options should be supplied in a single line as if given in the command line.
+  disable_gpg_check:
+    description:
+      - Whether to disable the GPG signature checking of the package being installed.
+        Has an effect only if state is I(present) or I(latest).
+    required: false
+    default: "no"
+    type: bool
+  disable_recommends:
+    version_added: "1.8"
+    description:
+      - Corresponds to the C(--no-recommends) option for I(zypper). The module default (C(yes)) overrides zypper's default and skips
+        recommended packages; C(no) installs recommended packages.
+    required: false
+    default: "yes"
+    type: bool
+  force:
+    version_added: "2.2"
+    description:
+      - Adds C(--force) option to I(zypper). Allows downgrading packages and changing vendor or architecture.
+    required: false
+    default: "no"
+    type: bool
+  force_resolution:
+    version_added: "2.10"
+    description:
+      - Adds C(--force-resolution) option to I(zypper). Allows installing or removing packages with conflicting requirements (the resolver will choose a solution).
+    required: false
+    default: "no"
+    type: bool
+  update_cache:
+    version_added: "2.2"
+    description:
+      - Run the equivalent of C(zypper refresh) before the operation. Disabled in check mode.
+    required: false
+    default: "no"
+    type: bool
+    aliases: [ "refresh" ]
+  oldpackage:
+    version_added: "2.2"
+    description:
+      - Adds C(--oldpackage) option to I(zypper). Allows downgrading packages with fewer side effects than C(force). This is implied as soon as a
+        version is specified as part of the package name.
+    required: false
+    default: "no"
+    type: bool
+  extra_args:
+    version_added: "2.4"
+    required: false
+    description:
+      - Add additional options to C(zypper) command.
+      - Options should be supplied in a single line as if given in the command line.
+notes:
+  - When used with a `loop:` each package will be processed individually;
+    it is much more efficient to pass the list directly to the `name` option.
+# informational: requirements for nodes +requirements: + - "zypper >= 1.0 # included in openSUSE >= 11.1 or SUSE Linux Enterprise Server/Desktop >= 11.0" + - python-xml + - rpm +''' + +EXAMPLES = ''' +# Install "nmap" +- zypper: + name: nmap + state: present + +# Install apache2 with recommended packages +- zypper: + name: apache2 + state: present + disable_recommends: no + +# Apply a given patch +- zypper: + name: openSUSE-2016-128 + state: present + type: patch + +# Remove the "nmap" package +- zypper: + name: nmap + state: absent + +# Install the nginx rpm from a remote repo +- zypper: + name: 'http://nginx.org/packages/sles/12/x86_64/RPMS/nginx-1.8.0-1.sles12.ngx.x86_64.rpm' + state: present + +# Install local rpm file +- zypper: + name: /tmp/fancy-software.rpm + state: present + +# Update all packages +- zypper: + name: '*' + state: latest + +# Apply all available patches +- zypper: + name: '*' + state: latest + type: patch + +# Perform a dist-upgrade with additional arguments +- zypper: + name: '*' + state: dist-upgrade + extra_args: '--no-allow-vendor-change --allow-arch-change' + +# Refresh repositories and update package "openssl" +- zypper: + name: openssl + state: present + update_cache: yes + +# Install specific version (possible comparisons: <, >, <=, >=, =) +- zypper: + name: 'docker>=1.10' + state: present + +# Wait 20 seconds to acquire the lock before failing +- zypper: + name: mosh + state: present + environment: + ZYPP_LOCK_TIMEOUT: 20 +''' + +import xml +import re +from xml.dom.minidom import parseString as parseXML +from ansible.module_utils.six import iteritems +from ansible.module_utils._text import to_native + +# import module snippets +from ansible.module_utils.basic import AnsibleModule + + +class Package: + def __init__(self, name, prefix, version): + self.name = name + self.prefix = prefix + self.version = version + self.shouldinstall = (prefix == '+') + + def __str__(self): + return self.prefix + self.name + self.version + + +def split_name_version(name): + """splits of the package name and desired version + + example formats: + - docker>=1.10 + - apache=2.4 + + Allowed version specifiers: <, >, <=, >=, = + Allowed version format: [0-9.-]* + + Also allows a prefix indicating remove "-", "~" or install "+" + """ + + prefix = '' + if name[0] in ['-', '~', '+']: + prefix = name[0] + name = name[1:] + if prefix == '~': + prefix = '-' + + version_check = re.compile('^(.*?)((?:<|>|<=|>=|=)[0-9.-]*)?$') + try: + reres = version_check.match(name) + name, version = reres.groups() + if version is None: + version = '' + return prefix, name, version + except Exception: + return prefix, name, '' + + +def get_want_state(names, remove=False): + packages = [] + urls = [] + for name in names: + if '://' in name or name.endswith('.rpm'): + urls.append(name) + else: + prefix, pname, version = split_name_version(name) + if prefix not in ['-', '+']: + if remove: + prefix = '-' + else: + prefix = '+' + packages.append(Package(pname, prefix, version)) + return packages, urls + + +def get_installed_state(m, packages): + "get installed state of packages" + + cmd = get_cmd(m, 'search') + cmd.extend(['--match-exact', '--details', '--installed-only']) + cmd.extend([p.name for p in packages]) + return parse_zypper_xml(m, cmd, fail_not_found=False)[0] + + +def parse_zypper_xml(m, cmd, fail_not_found=True, packages=None): + rc, stdout, stderr = m.run_command(cmd, check_rc=False) + + try: + dom = parseXML(stdout) + except xml.parsers.expat.ExpatError as exc: + m.fail_json(msg="Failed to 
parse zypper xml output: %s" % to_native(exc), + rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + + if rc == 104: + # exit code 104 is ZYPPER_EXIT_INF_CAP_NOT_FOUND (no packages found) + if fail_not_found: + errmsg = dom.getElementsByTagName('message')[-1].childNodes[0].data + m.fail_json(msg=errmsg, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + else: + return {}, rc, stdout, stderr + elif rc in [0, 106, 103]: + # zypper exit codes + # 0: success + # 106: signature verification failed + # 103: zypper was upgraded, run same command again + if packages is None: + firstrun = True + packages = {} + solvable_list = dom.getElementsByTagName('solvable') + for solvable in solvable_list: + name = solvable.getAttribute('name') + packages[name] = {} + packages[name]['version'] = solvable.getAttribute('edition') + packages[name]['oldversion'] = solvable.getAttribute('edition-old') + status = solvable.getAttribute('status') + packages[name]['installed'] = status == "installed" + packages[name]['group'] = solvable.parentNode.nodeName + if rc == 103 and firstrun: + # if this was the first run and it failed with 103 + # run zypper again with the same command to complete update + return parse_zypper_xml(m, cmd, fail_not_found=fail_not_found, packages=packages) + + return packages, rc, stdout, stderr + m.fail_json(msg='Zypper run command failed with return code %s.' % rc, rc=rc, stdout=stdout, stderr=stderr, cmd=cmd) + + +def get_cmd(m, subcommand): + "puts together the basic zypper command arguments with those passed to the module" + is_install = subcommand in ['install', 'update', 'patch', 'dist-upgrade'] + is_refresh = subcommand == 'refresh' + cmd = ['/usr/bin/zypper', '--quiet', '--non-interactive', '--xmlout'] + if m.params['extra_args_precommand']: + args_list = m.params['extra_args_precommand'].split() + cmd.extend(args_list) + # add global options before zypper command + if (is_install or is_refresh) and m.params['disable_gpg_check']: + cmd.append('--no-gpg-checks') + + if subcommand == 'search': + cmd.append('--disable-repositories') + + cmd.append(subcommand) + if subcommand not in ['patch', 'dist-upgrade'] and not is_refresh: + cmd.extend(['--type', m.params['type']]) + if m.check_mode and subcommand != 'search': + cmd.append('--dry-run') + if is_install: + cmd.append('--auto-agree-with-licenses') + if m.params['disable_recommends']: + cmd.append('--no-recommends') + if m.params['force']: + cmd.append('--force') + if m.params['force_resolution']: + cmd.append('--force-resolution') + if m.params['oldpackage']: + cmd.append('--oldpackage') + if m.params['extra_args']: + args_list = m.params['extra_args'].split(' ') + cmd.extend(args_list) + + return cmd + + +def set_diff(m, retvals, result): + # TODO: if there is only one package, set before/after to version numbers + packages = {'installed': [], 'removed': [], 'upgraded': []} + if result: + for p in result: + group = result[p]['group'] + if group == 'to-upgrade': + versions = ' (' + result[p]['oldversion'] + ' => ' + result[p]['version'] + ')' + packages['upgraded'].append(p + versions) + elif group == 'to-install': + packages['installed'].append(p) + elif group == 'to-remove': + packages['removed'].append(p) + + output = '' + for state in packages: + if packages[state]: + output += state + ': ' + ', '.join(packages[state]) + '\n' + if 'diff' not in retvals: + retvals['diff'] = {} + if 'prepared' not in retvals['diff']: + retvals['diff']['prepared'] = output + else: + retvals['diff']['prepared'] += '\n' + output + + +def package_present(m, name, 
want_latest): + "install and update (if want_latest) the packages in name_install, while removing the packages in name_remove" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + packages, urls = get_want_state(name) + + # add oldpackage flag when a version is given to allow downgrades + if any(p.version for p in packages): + m.params['oldpackage'] = True + + if not want_latest: + # for state=present: filter out already installed packages + # if a version is given leave the package in to let zypper handle the version + # resolution + packageswithoutversion = [p for p in packages if not p.version] + prerun_state = get_installed_state(m, packageswithoutversion) + # generate lists of packages to install or remove + packages = [p for p in packages if p.shouldinstall != (p.name in prerun_state)] + + if not packages and not urls: + # nothing to install/remove and nothing to update + return None, retvals + + # zypper install also updates packages + cmd = get_cmd(m, 'install') + cmd.append('--') + cmd.extend(urls) + # pass packages to zypper + # allow for + or - prefixes in install/remove lists + # also add version specifier if given + # do this in one zypper run to allow for dependency-resolution + # for example "-exim postfix" runs without removing packages depending on mailserver + cmd.extend([str(p) for p in packages]) + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + + return result, retvals + + +def package_update_all(m): + "run update or patch on all available packages" + + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + if m.params['type'] == 'patch': + cmdname = 'patch' + elif m.params['state'] == 'dist-upgrade': + cmdname = 'dist-upgrade' + else: + cmdname = 'update' + + cmd = get_cmd(m, cmdname) + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + return result, retvals + + +def package_absent(m, name): + "remove the packages in name" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + # Get package state + packages, urls = get_want_state(name, remove=True) + if any(p.prefix == '+' for p in packages): + m.fail_json(msg="Can not combine '+' prefix with state=remove/absent.") + if urls: + m.fail_json(msg="Can not remove via URL.") + if m.params['type'] == 'patch': + m.fail_json(msg="Can not remove patches.") + prerun_state = get_installed_state(m, packages) + packages = [p for p in packages if p.name in prerun_state] + + if not packages: + return None, retvals + + cmd = get_cmd(m, 'remove') + cmd.extend([p.name + p.version for p in packages]) + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + return result, retvals + + +def repo_refresh(m): + "update the repositories" + retvals = {'rc': 0, 'stdout': '', 'stderr': ''} + + cmd = get_cmd(m, 'refresh') + + retvals['cmd'] = cmd + result, retvals['rc'], retvals['stdout'], retvals['stderr'] = parse_zypper_xml(m, cmd) + + return retvals + +# =========================================== +# Main control flow + + +def main(): + module = AnsibleModule( + argument_spec=dict( + name=dict(required=True, aliases=['pkg'], type='list'), + state=dict(required=False, default='present', choices=['absent', 'installed', 'latest', 'present', 'removed', 'dist-upgrade']), + type=dict(required=False, default='package', choices=['package', 'patch', 'pattern', 'product', 'srcpackage', 'application']), + extra_args_precommand=dict(required=False, default=None), + 
disable_gpg_check=dict(required=False, default='no', type='bool'), + disable_recommends=dict(required=False, default='yes', type='bool'), + force=dict(required=False, default='no', type='bool'), + force_resolution=dict(required=False, default='no', type='bool'), + update_cache=dict(required=False, aliases=['refresh'], default='no', type='bool'), + oldpackage=dict(required=False, default='no', type='bool'), + extra_args=dict(required=False, default=None), + ), + supports_check_mode=True + ) + + name = module.params['name'] + state = module.params['state'] + update_cache = module.params['update_cache'] + + # remove empty strings from package list + name = list(filter(None, name)) + + # Refresh repositories + if update_cache and not module.check_mode: + retvals = repo_refresh(module) + + if retvals['rc'] != 0: + module.fail_json(msg="Zypper refresh run failed.", **retvals) + + # Perform requested action + if name == ['*'] and state in ['latest', 'dist-upgrade']: + packages_changed, retvals = package_update_all(module) + elif name != ['*'] and state == 'dist-upgrade': + module.fail_json(msg="Can not dist-upgrade specific packages.") + else: + if state in ['absent', 'removed']: + packages_changed, retvals = package_absent(module, name) + elif state in ['installed', 'present', 'latest']: + packages_changed, retvals = package_present(module, name, state == 'latest') + + retvals['changed'] = retvals['rc'] == 0 and bool(packages_changed) + + if module._diff: + set_diff(module, retvals, packages_changed) + + if retvals['rc'] != 0: + module.fail_json(msg="Zypper run failed.", **retvals) + + if not retvals['changed']: + del retvals['stdout'] + del retvals['stderr'] + + module.exit_json(name=name, state=state, update_cache=update_cache, **retvals) + + +if __name__ == "__main__": + main() |
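
A quick sketch (not part of the diff above) of how the zypper module's split_name_version() helper parses package specifiers. The import path is hypothetical, since the file lives under test/support/integration/plugins/modules/ and would need to be on sys.path:

    # Assumes zypper.py from the diff above is importable as "zypper"; adjust sys.path as needed.
    from zypper import split_name_version

    print(split_name_version('docker>=1.10'))  # ('', 'docker', '>=1.10')
    print(split_name_version('+apache=2.4'))   # ('+', 'apache', '=2.4')
    print(split_name_version('~nmap'))         # ('-', 'nmap', '')  - '~' is normalized to the remove prefix '-'
    print(split_name_version('nmap'))          # ('', 'nmap', '')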